use std::sync::Mutex;
use std::time::Duration;
use serde::Deserialize;
/// A single atomic proposition produced by the curator.
///
/// Deserialized from one `{ "text": ... }` element of the curator's JSON reply.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct Atom {
/// The proposition's text.
pub text: String,
}
/// Top-level shape of the curator's JSON reply: an object with an `atoms` array.
#[derive(Debug, Clone, Deserialize)]
pub struct CuratorResponse {
/// Atoms returned by the model. The system prompt in this file asks the model for an
/// empty list when the input cannot be decomposed.
pub atoms: Vec<Atom>,
}
/// Failure modes of a curator decomposition. Each variant carries a free-form detail message.
#[derive(Debug)]
pub enum CuratorError {
/// The LLM backend call itself failed; the `OllamaClient` adapters in this file store the
/// backend error's `to_string()` here.
LlmUnavailable(String),
/// The LLM replied, but no attempt yielded a parseable [`CuratorResponse`]; `LlmCurator`
/// stores the last parse error text here.
MalformedResponse(String),
}
impl std::fmt::Display for CuratorError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::LlmUnavailable(m) => write!(f, "curator LLM unavailable: {m}"),
Self::MalformedResponse(m) => write!(f, "curator response malformed: {m}"),
}
}
}
// Marker impl: `CuratorError` already provides `Debug` and `Display` in this file, which is
// all `std::error::Error` requires; the trait's default methods are used as-is.
impl std::error::Error for CuratorError {}
/// Decomposes a text body into atomic propositions ([`Atom`]s).
///
/// Implementors must be `Send + Sync`.
pub trait Curator: Send + Sync {
/// Splits `body` into atoms.
///
/// * `body` — the text to decompose.
/// * `max_atom_tokens` — per-atom token budget.
/// * `max_retries` — retry allowance; the exact retry policy is up to the implementor
///   (the `LlmCurator` in this file retries only when the reply cannot be parsed).
///
/// # Errors
/// Returns a [`CuratorError`] when decomposition fails.
fn decompose(
&self,
body: &str,
max_atom_tokens: u32,
max_retries: u32,
) -> Result<Vec<Atom>, CuratorError>;
}
/// System-prompt template for the curator LLM.
///
/// Contains the literal placeholder `{max_atom_tokens}`, which `render_system_prompt`
/// fills in by plain string replacement. This is not a `format!` template, so the braces
/// in the JSON shape description are literal text.
pub const CURATOR_SYSTEM_PROMPT: &str =
"You are decomposing a long memory into atomic propositions.
Each atom must:
(1) Be self-contained — readable without the original context
(2) Be at most {max_atom_tokens} tokens
(3) Contain exactly one fact, decision, observation, or relation
(4) Preserve original meaning — no editorial additions
Return JSON: { atoms: [{ text: string }] } with 2 to 10 atoms.
If the input cannot be decomposed (already atomic, all-or-nothing),
return { atoms: [] }.";
/// Renders [`CURATOR_SYSTEM_PROMPT`] with its `{max_atom_tokens}` placeholder replaced by
/// the decimal form of `max_atom_tokens`.
#[must_use]
pub fn render_system_prompt(max_atom_tokens: u32) -> String {
    let budget = max_atom_tokens.to_string();
    CURATOR_SYSTEM_PROMPT.replace("{max_atom_tokens}", budget.as_str())
}
/// Parses a raw LLM reply into a [`CuratorResponse`], tolerating common wrapping noise.
///
/// Three candidates are tried in order and the first that deserializes wins:
/// 1. `body` verbatim;
/// 2. `body` with a surrounding Markdown code fence removed;
/// 3. the first balanced `{ ... }` object found in the unfenced text.
///
/// # Errors
/// If every candidate fails, returns the `serde_json` error message from the verbatim
/// parse of `body` (not from the later, salvaged candidates).
pub fn parse_response(body: &str) -> Result<CuratorResponse, String> {
    // Keep the verbatim failure around: it is what gets reported if nothing else works.
    let verbatim_error = match serde_json::from_str::<CuratorResponse>(body) {
        Ok(resp) => return Ok(resp),
        Err(e) => e,
    };

    let unfenced = strip_code_fence(body);
    if let Ok(resp) = serde_json::from_str::<CuratorResponse>(&unfenced) {
        return Ok(resp);
    }

    extract_first_json_object(&unfenced)
        .and_then(|candidate| serde_json::from_str::<CuratorResponse>(&candidate).ok())
        .ok_or_else(|| verbatim_error.to_string())
}
/// Removes a surrounding Markdown code fence from `s`, if present.
///
/// The input is trimmed, then at most one opening marker is removed (tried in the order
/// `` ```json ``, `` ```JSON ``, bare `` ``` ``), then leading newlines, then one trailing
/// `` ``` ``. The result is trimmed again. Input without a fence is returned trimmed.
fn strip_code_fence(s: &str) -> String {
    const OPENING_FENCES: [&str; 3] = ["```json", "```JSON", "```"];

    let text = s.trim();
    // Order matters: the tagged fences must be tried before the bare one.
    let after_open = OPENING_FENCES
        .iter()
        .find_map(|fence| text.strip_prefix(*fence))
        .unwrap_or(text)
        .trim_start_matches('\n');

    let inner = match after_open.strip_suffix("```") {
        Some(without_close) => without_close,
        None => after_open,
    };
    inner.trim().to_owned()
}
/// Returns the first balanced `{ ... }` span in `s`, or `None` if there is none.
///
/// The scan is byte-wise and tracks double-quoted strings (with backslash escapes) so
/// that braces inside string values do not affect nesting. Quotes are tracked from the
/// start of `s`, including any text before the object. The returned text is not
/// validated as JSON; callers are expected to parse it.
///
/// A `}` seen while no object is open is treated as stray prose and ignored, so a reply
/// such as `oops } {"atoms":[]}` still yields the object that follows.
fn extract_first_json_object(s: &str) -> Option<String> {
    let mut depth = 0usize;
    let mut start: Option<usize> = None;
    let mut in_string = false;
    let mut escaped = false;

    for (i, &b) in s.as_bytes().iter().enumerate() {
        if in_string {
            if b == b'"' && !escaped {
                in_string = false;
            }
            // A backslash escapes the next byte unless it was itself escaped.
            escaped = b == b'\\' && !escaped;
            continue;
        }
        match b {
            b'"' => in_string = true,
            b'{' => {
                if depth == 0 {
                    start = Some(i);
                }
                depth += 1;
            }
            // Only count closers that match an open brace; an unmatched `}` must not push
            // the depth below zero, or the next `{` would never register as a start.
            b'}' if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    if let Some(s0) = start {
                        // `{` and `}` are ASCII, so both indices are char boundaries.
                        return Some(s[s0..=i].to_string());
                    }
                }
            }
            _ => {}
        }
    }
    None
}
/// Filters `atoms` against a per-atom token budget.
///
/// Token counts come from `crate::storage::count_tokens_cl100k`. An atom is:
/// * kept silently when its count is at most `max_atom_tokens`;
/// * kept with a warning when it is over budget but within the hard cap of
///   `max_atom_tokens + max_atom_tokens / 4` (integer division, saturating add);
/// * dropped with a warning when it exceeds the hard cap.
///
/// Returns the kept atoms, in their original order, and the number dropped.
#[must_use]
pub fn enforce_token_budget(atoms: Vec<Atom>, max_atom_tokens: u32) -> (Vec<Atom>, usize) {
    let hard_cap = max_atom_tokens.saturating_add(max_atom_tokens / 4);
    let total = atoms.len();
    let mut kept = Vec::with_capacity(total);

    for atom in atoms {
        let tokens = crate::storage::count_tokens_cl100k(&atom.text);
        // Counts that do not fit in `u32` are treated as the maximum, i.e. over any cap.
        let atom_tokens = u32::try_from(tokens).unwrap_or(u32::MAX);

        if atom_tokens > hard_cap {
            tracing::warn!(
                target: "atomisation::curator",
                atom_tokens = atom_tokens,
                hard_cap,
                "atom grossly over token budget — dropping"
            );
            continue;
        }
        if atom_tokens > max_atom_tokens {
            tracing::warn!(
                target: "atomisation::curator",
                atom_tokens = atom_tokens,
                budget = max_atom_tokens,
                "atom slightly over token budget — accepting (fail-soft)"
            );
        }
        kept.push(atom);
    }

    let dropped = total - kept.len();
    (kept, dropped)
}
/// Returns the delay to wait after the zero-based retry `attempt`.
///
/// The schedule is 100 ms, 500 ms, then 2500 ms for attempt 2 and every later attempt.
#[must_use]
pub fn backoff_for_attempt(attempt: u32) -> Duration {
    let millis: u64 = match attempt {
        0 => 100,
        1 => 500,
        _ => 2500,
    };
    Duration::from_millis(millis)
}
/// [`Curator`] implementation backed by an LLM text-generation backend `L`.
pub struct LlmCurator<L: LlmGenerate + Send + Sync> {
/// Backend used to generate the decomposition.
llm: L,
/// Function called to wait between retries. It is a boxed `FnMut` behind a `Mutex` so it
/// can be invoked through `&self`; `new` installs `std::thread::sleep`.
sleep: Mutex<Box<dyn FnMut(Duration) + Send + Sync>>,
}
/// Minimal text-generation interface the curator needs from an LLM backend.
pub trait LlmGenerate {
/// Sends `prompt`, with an optional `system` prompt, to the backend and returns its
/// text reply.
///
/// # Errors
/// Returns a [`CuratorError`] if the backend call fails.
fn generate(&self, prompt: &str, system: Option<&str>) -> Result<String, CuratorError>;
}
/// Adapts `crate::llm::OllamaClient` to [`LlmGenerate`], mapping any client error to
/// [`CuratorError::LlmUnavailable`] carrying the error's `to_string()`.
impl LlmGenerate for crate::llm::OllamaClient {
fn generate(&self, prompt: &str, system: Option<&str>) -> Result<String, CuratorError> {
// NOTE(review): `Self::generate` presumably resolves to an inherent
// `OllamaClient::generate` (inherent methods take precedence over trait methods in
// path resolution). Without such a method this call would recurse into this trait
// method — confirm against `crate::llm`.
Self::generate(self, prompt, system)
.map_err(|e| CuratorError::LlmUnavailable(e.to_string()))
}
}
/// Same adapter for a shared `Arc<OllamaClient>`: forwards to the client behind the `Arc`
/// and maps any error to [`CuratorError::LlmUnavailable`] carrying its `to_string()`.
impl LlmGenerate for std::sync::Arc<crate::llm::OllamaClient> {
fn generate(&self, prompt: &str, system: Option<&str>) -> Result<String, CuratorError> {
// `self.as_ref()` borrows the inner client. NOTE(review): the call is presumably to the
// client's inherent `generate` — confirm against `crate::llm`.
crate::llm::OllamaClient::generate(self.as_ref(), prompt, system)
.map_err(|e| CuratorError::LlmUnavailable(e.to_string()))
}
}
impl<L: LlmGenerate + Send + Sync> LlmCurator<L> {
pub fn new(llm: L) -> Self {
Self {
llm,
sleep: Mutex::new(Box::new(std::thread::sleep)),
}
}
#[cfg(test)]
pub fn with_sleep<F>(llm: L, sleep: F) -> Self
where
F: FnMut(Duration) + Send + Sync + 'static,
{
Self {
llm,
sleep: Mutex::new(Box::new(sleep)),
}
}
}
impl<L: LlmGenerate + Send + Sync> Curator for LlmCurator<L> {
    /// Asks the LLM to decompose `body`, retrying when the reply cannot be parsed.
    ///
    /// Up to `max_retries + 1` generation attempts are made. After each unparseable
    /// reply except the last, the configured sleep function is called with
    /// `backoff_for_attempt(attempt)`. On the first parseable reply the atoms are
    /// filtered through `enforce_token_budget` and returned; the dropped count is
    /// discarded.
    ///
    /// # Errors
    /// * Any error from the backend's `generate` is returned immediately, with no retry.
    /// * [`CuratorError::MalformedResponse`] with the last parse error when every attempt
    ///   produced an unparseable reply.
    fn decompose(
        &self,
        body: &str,
        max_atom_tokens: u32,
        max_retries: u32,
    ) -> Result<Vec<Atom>, CuratorError> {
        let system_prompt = render_system_prompt(max_atom_tokens);
        let mut last_parse_error = String::from("no attempts made");

        for attempt in 0..=max_retries {
            let reply = self.llm.generate(body, Some(&system_prompt))?;

            last_parse_error = match parse_response(&reply) {
                Ok(parsed) => {
                    return Ok(enforce_token_budget(parsed.atoms, max_atom_tokens).0);
                }
                Err(reason) => reason,
            };

            // No point waiting after the final attempt.
            if attempt == max_retries {
                break;
            }
            let delay = backoff_for_attempt(attempt);
            // A poisoned lock skips the wait rather than failing the whole call.
            if let Ok(mut sleeper) = self.sleep.lock() {
                (sleeper)(delay);
            }
        }

        Err(CuratorError::MalformedResponse(last_parse_error))
    }
}
// Unit tests for prompt rendering, response parsing, token-budget enforcement, the backoff
// schedule, and the retry behaviour of `LlmCurator`.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
// Scripted LLM stand-in: hands out the queued responses in order and counts calls.
pub(crate) struct MockLlm {
// Remaining scripted replies; the front element is returned next.
responses: Mutex<Vec<Result<String, CuratorError>>>,
// Number of `generate` calls made so far.
calls: Mutex<usize>,
}
impl MockLlm {
// Builds a mock that will return `responses` front to back.
pub fn new(responses: Vec<Result<String, CuratorError>>) -> Self {
Self {
responses: Mutex::new(responses),
calls: Mutex::new(0),
}
}
// Number of times `generate` has been called.
pub fn call_count(&self) -> usize {
*self.calls.lock().unwrap()
}
}
// Implemented on `Arc<MockLlm>` so a test can keep a handle for `call_count` while the
// curator owns another clone.
impl LlmGenerate for Arc<MockLlm> {
fn generate(&self, _prompt: &str, _system: Option<&str>) -> Result<String, CuratorError> {
// The call is counted even when the script is exhausted.
let mut calls = self.calls.lock().unwrap();
*calls += 1;
let mut rs = self.responses.lock().unwrap();
if rs.is_empty() {
return Err(CuratorError::LlmUnavailable(
"mock: no responses left".into(),
));
}
rs.remove(0)
}
}
// The placeholder is replaced by the numeric budget and no longer appears.
#[test]
fn render_prompt_substitutes_max_atom_tokens() {
let p = render_system_prompt(200);
assert!(p.contains("at most 200 tokens"));
assert!(!p.contains("{max_atom_tokens}"));
}
// Plain JSON parses on the first (verbatim) attempt.
#[test]
fn parse_response_accepts_direct_json() {
let body = r#"{"atoms":[{"text":"alpha"},{"text":"beta"}]}"#;
let r = parse_response(body).unwrap();
assert_eq!(r.atoms.len(), 2);
assert_eq!(r.atoms[0].text, "alpha");
}
// A Markdown-fenced reply parses after the fence is stripped.
#[test]
fn parse_response_strips_markdown_fence() {
let body = "```json\n{\"atoms\":[{\"text\":\"alpha\"}]}\n```";
let r = parse_response(body).unwrap();
assert_eq!(r.atoms.len(), 1);
}
// Prose before and after the object is tolerated via first-object extraction.
#[test]
fn parse_response_extracts_object_with_preamble() {
let body = "Sure, here's the JSON:\n{\"atoms\":[{\"text\":\"alpha\"}]}\nThanks!";
let r = parse_response(body).unwrap();
assert_eq!(r.atoms.len(), 1);
}
// An empty `atoms` array is a valid reply, not an error.
#[test]
fn parse_response_empty_atoms_is_valid() {
let body = r#"{"atoms":[]}"#;
let r = parse_response(body).unwrap();
assert_eq!(r.atoms.len(), 0);
}
// Non-JSON text, empty input, and JSON of the wrong shape are all rejected.
#[test]
fn parse_response_rejects_garbage() {
assert!(parse_response("nope nope nope").is_err());
assert!(parse_response("").is_err());
assert!(parse_response(r#"{"wrong":"shape"}"#).is_err());
}
// Atoms well under the budget are all kept.
#[test]
fn enforce_token_budget_keeps_in_budget() {
let atoms = vec![
Atom {
text: "small atom".to_string(),
},
Atom {
text: "another small atom".to_string(),
},
];
let (kept, dropped) = enforce_token_budget(atoms, 200);
assert_eq!(kept.len(), 2);
assert_eq!(dropped, 0);
}
// An atom far beyond the hard cap is dropped and counted.
#[test]
fn enforce_token_budget_drops_grossly_over() {
let huge: String = "word ".repeat(500);
let atoms = vec![
Atom {
text: "fine".to_string(),
},
Atom { text: huge },
];
let (kept, dropped) = enforce_token_budget(atoms, 10);
assert_eq!(kept.len(), 1);
assert_eq!(dropped, 1);
}
// Pins the schedule: 100 ms, 500 ms, 2500 ms, then 2500 ms for any later attempt.
#[test]
fn backoff_schedule_is_monotonic_and_bounded() {
assert_eq!(backoff_for_attempt(0), Duration::from_millis(100));
assert_eq!(backoff_for_attempt(1), Duration::from_millis(500));
assert_eq!(backoff_for_attempt(2), Duration::from_millis(2500));
assert_eq!(backoff_for_attempt(99), Duration::from_millis(2500));
}
// A parseable first reply means exactly one LLM call.
#[test]
fn curator_succeeds_on_first_attempt() {
let mock = Arc::new(MockLlm::new(vec![Ok(
r#"{"atoms":[{"text":"alpha"},{"text":"beta"}]}"#.to_string(),
)]));
// No-op sleep keeps the test fast.
let curator = LlmCurator::with_sleep(mock.clone(), |_| {});
let atoms = curator.decompose("input", 200, 3).unwrap();
assert_eq!(atoms.len(), 2);
assert_eq!(mock.call_count(), 1);
}
// Two unparseable replies are retried; the third reply succeeds.
#[test]
fn curator_retries_on_malformed_then_succeeds() {
let mock = Arc::new(MockLlm::new(vec![
Ok("garbage".to_string()),
Ok("still garbage".to_string()),
Ok(r#"{"atoms":[{"text":"alpha"}]}"#.to_string()),
]));
let curator = LlmCurator::with_sleep(mock.clone(), |_| {});
let atoms = curator.decompose("input", 200, 3).unwrap();
assert_eq!(atoms.len(), 1);
assert_eq!(mock.call_count(), 3);
}
// With `max_retries = 3` there are 4 calls in total (initial attempt + 3 retries).
#[test]
fn curator_fails_after_max_retries() {
let mock = Arc::new(MockLlm::new(vec![
Ok("garbage 1".to_string()),
Ok("garbage 2".to_string()),
Ok("garbage 3".to_string()),
Ok("garbage 4".to_string()),
]));
let curator = LlmCurator::with_sleep(mock.clone(), |_| {});
let err = curator.decompose("input", 200, 3).unwrap_err();
assert!(matches!(err, CuratorError::MalformedResponse(_)));
assert_eq!(mock.call_count(), 4);
}
// A backend error is returned as-is rather than being converted to `MalformedResponse`.
#[test]
fn curator_propagates_llm_unavailable() {
let mock = Arc::new(MockLlm::new(vec![Err(CuratorError::LlmUnavailable(
"connection refused".into(),
))]));
let curator = LlmCurator::with_sleep(mock, |_| {});
let err = curator.decompose("input", 200, 3).unwrap_err();
assert!(matches!(err, CuratorError::LlmUnavailable(_)));
}
// Exercises `LlmCurator::new` with a real client. NOTE(review): presumably nothing
// listens on 127.0.0.1:1, so the request fails and is mapped to `LlmUnavailable`.
#[test]
fn llm_curator_new_with_real_ollama_client_maps_unavailable() {
let client = crate::llm::OllamaClient::new_with_url_no_health_check(
"http://127.0.0.1:1",
"test-model",
)
.expect("build no-health-check client");
let curator = LlmCurator::new(client);
let err = curator.decompose("body", 200, 0).unwrap_err();
assert!(
matches!(err, CuratorError::LlmUnavailable(_)),
"expected LlmUnavailable, got {err:?}"
);
}
// Same check through the `Arc<OllamaClient>` adapter.
#[test]
fn llm_curator_arc_ollama_passthrough_maps_unavailable() {
let client = Arc::new(
crate::llm::OllamaClient::new_with_url_no_health_check(
"http://127.0.0.1:1",
"test-model",
)
.expect("build no-health-check client"),
);
let curator = LlmCurator::with_sleep(client, |_| {});
let err = curator.decompose("body", 200, 0).unwrap_err();
assert!(matches!(err, CuratorError::LlmUnavailable(_)));
}
// Pins the `Display` text of both variants and checks the type coerces to `dyn Error`.
#[test]
fn curator_error_display_arms() {
assert_eq!(
CuratorError::LlmUnavailable("x".into()).to_string(),
"curator LLM unavailable: x"
);
assert_eq!(
CuratorError::MalformedResponse("y".into()).to_string(),
"curator response malformed: y"
);
let e = CuratorError::MalformedResponse("z".into());
let _: &dyn std::error::Error = &e;
}
// A `}` inside a string value must not end the extracted object early.
#[test]
fn extract_first_json_object_handles_braces_in_strings() {
let s = r#"prefix {"atoms":[{"text":"contains } brace"}]} suffix"#;
let extracted = extract_first_json_object(s).unwrap();
let parsed: CuratorResponse = serde_json::from_str(&extracted).unwrap();
assert_eq!(parsed.atoms.len(), 1);
assert_eq!(parsed.atoms[0].text, "contains } brace");
}
}