1pub mod gc;
36
37use anyhow::Result;
38use rusqlite::{Connection, OptionalExtension, params};
39
/// One retrieved memory within a single recall, borrowed from the caller.
///
/// A slice of these is handed to [`record_recall`], which writes one
/// `recall_observations` row per candidate.
#[derive(Debug, Clone)]
pub struct Candidate<'a> {
    /// Identifier of the memory that was returned by the recall.
    pub memory_id: &'a str,
    /// Name of the retriever that produced this hit (the tests in this file
    /// use `"hybrid"` and `"fts5"`).
    pub retriever: &'a str,
    /// Position of the hit in the result list.
    /// NOTE(review): the tests use 1-based ranks — confirm the convention.
    pub rank: i64,
    /// Score the retriever assigned to the hit.
    pub score: f64,
}
53
54pub fn record_recall(
66 conn: &Connection,
67 recall_id: &str,
68 candidates: &[Candidate<'_>],
69) -> Result<usize> {
70 if candidates.is_empty() {
71 return Ok(0);
72 }
73 let mut stmt = conn.prepare_cached(
74 "INSERT OR IGNORE INTO recall_observations \
75 (recall_id, memory_id, retriever, rank, score) \
76 VALUES (?1, ?2, ?3, ?4, ?5)",
77 )?;
78 let mut written = 0_usize;
79 for c in candidates {
80 let n = stmt.execute(params![
81 recall_id,
82 c.memory_id,
83 c.retriever,
84 c.rank,
85 c.score
86 ])?;
87 written += n;
88 }
89 Ok(written)
90}
91
92pub fn mark_consumed(
109 conn: &Connection,
110 recall_id: &str,
111 cited_memory_ids: &[&str],
112 consumed_by: &str,
113) -> Result<usize> {
114 if cited_memory_ids.is_empty() {
115 return Ok(0);
116 }
117 let now = chrono::Utc::now().to_rfc3339();
118 let mut stmt = conn.prepare_cached(
119 "UPDATE recall_observations \
120 SET consumed = 1, \
121 consumed_at = ?1, \
122 consumed_by_memory_id = ?2 \
123 WHERE recall_id = ?3 \
124 AND memory_id = ?4 \
125 AND consumed = 0",
126 )?;
127 let mut flipped = 0_usize;
128 for mid in cited_memory_ids {
129 let n = stmt.execute(params![now, consumed_by, recall_id, mid])?;
130 flipped += n;
131 }
132 Ok(flipped)
133}
134
/// Owned, serializable view of one `recall_observations` row, as returned by
/// [`list_observations`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct Observation {
    /// Identifier of the recall this row belongs to.
    pub recall_id: String,
    /// Identifier of the memory that was returned by the recall.
    pub memory_id: String,
    /// Name of the retriever that produced the hit.
    pub retriever: String,
    /// Position of the hit in the recall's result list.
    pub rank: i64,
    /// Score the retriever assigned to the hit.
    pub score: f64,
    /// Whether the row has been marked consumed (stored as an integer; any
    /// non-zero value reads back as `true`).
    pub consumed: bool,
    /// When the row was recorded, as stored in the database.
    /// NOTE(review): `record_recall` does not supply this column, so it is
    /// presumably filled by a schema default — confirm against the schema.
    pub observed_at: String,
    /// RFC 3339 timestamp written by `mark_consumed`; `None` while the row is
    /// unconsumed, in which case the field is omitted from serialized output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub consumed_at: Option<String>,
    /// Value passed as `consumed_by` to `mark_consumed`; `None` while the row
    /// is unconsumed, in which case the field is omitted from serialized output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub consumed_by_memory_id: Option<String>,
}
153
154pub fn list_observations(
163 conn: &Connection,
164 recall_id: Option<&str>,
165 consumed: Option<bool>,
166 since: Option<&str>,
167 until: Option<&str>,
168 limit: usize,
169) -> Result<Vec<Observation>> {
170 let mut sql = String::from(
171 "SELECT recall_id, memory_id, retriever, rank, score, consumed, \
172 observed_at, consumed_at, consumed_by_memory_id \
173 FROM recall_observations \
174 WHERE 1=1",
175 );
176 let mut binds: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
177 if let Some(rid) = recall_id {
178 sql.push_str(" AND recall_id = ?");
179 binds.push(Box::new(rid.to_string()));
180 }
181 if let Some(c) = consumed {
182 sql.push_str(" AND consumed = ?");
183 binds.push(Box::new(i64::from(c)));
184 }
185 if let Some(s) = since {
186 sql.push_str(" AND observed_at >= ?");
187 binds.push(Box::new(s.to_string()));
188 }
189 if let Some(u) = until {
190 sql.push_str(" AND observed_at <= ?");
191 binds.push(Box::new(u.to_string()));
192 }
193 sql.push_str(" ORDER BY observed_at DESC LIMIT ?");
194 let lim_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
195 binds.push(Box::new(lim_i64));
196
197 let mut stmt = conn.prepare(&sql)?;
198 let rows = stmt
199 .query_map(rusqlite::params_from_iter(binds.iter()), |row| {
200 Ok(Observation {
201 recall_id: row.get(0)?,
202 memory_id: row.get(1)?,
203 retriever: row.get(2)?,
204 rank: row.get(3)?,
205 score: row.get(4)?,
206 consumed: row.get::<_, i64>(5)? != 0,
207 observed_at: row.get(6)?,
208 consumed_at: row.get(7).ok(),
209 consumed_by_memory_id: row.get(8).ok(),
210 })
211 })?
212 .collect::<rusqlite::Result<Vec<_>>>()?;
213 Ok(rows)
214}
215
216#[must_use]
230pub fn parse_cite_batch(params: &serde_json::Value) -> Option<(String, Vec<String>)> {
231 let recall_id = params
232 .get("recall_id")
233 .or_else(|| params.get("consumed_from_recall_id"))
234 .and_then(serde_json::Value::as_str)
235 .map(str::trim)
236 .filter(|s| !s.is_empty())?
237 .to_string();
238 let ids_raw = params.get("cited_memory_ids").and_then(|v| v.as_array())?;
239 let mut out: Vec<String> = Vec::new();
240 for v in ids_raw {
241 if let Some(s) = v.as_str() {
242 let s = s.trim();
243 if !s.is_empty() && !out.iter().any(|x| x == s) {
244 out.push(s.to_string());
245 }
246 }
247 }
248 if out.is_empty() {
249 return None;
250 }
251 Some((recall_id, out))
252}
253
254pub fn try_mark_consumed_from_params(
260 conn: &Connection,
261 params: &serde_json::Value,
262 consumed_by: &str,
263) {
264 let Some((recall_id, ids)) = parse_cite_batch(params) else {
265 return;
266 };
267 let refs: Vec<&str> = ids.iter().map(String::as_str).collect();
268 if let Err(e) = mark_consumed(conn, &recall_id, &refs, consumed_by) {
269 tracing::warn!(
270 target: "observations",
271 recall_id = %recall_id,
272 consumed_by,
273 "mark_consumed failed (non-fatal): {e}"
274 );
275 }
276}
277
278#[must_use]
286pub fn table_exists(conn: &Connection) -> bool {
287 conn.query_row(
288 "SELECT name FROM sqlite_master WHERE type='table' AND name='recall_observations'",
289 [],
290 |row| row.get::<_, String>(0),
291 )
292 .optional()
293 .ok()
294 .flatten()
295 .is_some()
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens a fresh in-memory database via `crate::storage::open`.
    fn fresh() -> Connection {
        let path = std::path::Path::new(":memory:");
        crate::storage::open(path).expect("open in-memory db")
    }

    /// Inserts a minimal `memories` row with the given id.
    /// NOTE(review): presumably needed because observations reference
    /// `memories` — confirm against the schema.
    fn seed_memory(conn: &Connection, id: &str) {
        let title = format!("title-{id}");
        conn.execute(
            "INSERT INTO memories \
             (id, tier, namespace, title, content, created_at, updated_at) \
             VALUES (?1, 'long', 'test', ?2, 'content', '2025-01-01T00:00:00Z', '2025-01-01T00:00:00Z')",
            params![id, title],
        )
        .expect("seed memory");
    }

    /// Shorthand for building a [`Candidate`].
    fn hit<'a>(memory_id: &'a str, retriever: &'a str, rank: i64, score: f64) -> Candidate<'a> {
        Candidate {
            memory_id,
            retriever,
            rank,
            score,
        }
    }

    /// Returns the observation for `memory_id`, panicking when it is absent.
    fn row_for<'a>(rows: &'a [Observation], memory_id: &str) -> &'a Observation {
        rows.iter().find(|o| o.memory_id == memory_id).unwrap()
    }

    #[test]
    fn record_recall_writes_one_row_per_candidate() {
        let conn = fresh();
        for id in ["m1", "m2"] {
            seed_memory(&conn, id);
        }
        let candidates = [hit("m1", "hybrid", 1, 0.9), hit("m2", "hybrid", 2, 0.8)];
        let written = record_recall(&conn, "r1", &candidates).expect("record");
        assert_eq!(written, 2);

        let rows = list_observations(&conn, Some("r1"), None, None, None, 10).expect("list");
        assert_eq!(rows.len(), 2);
        for expected in ["m1", "m2"] {
            assert!(rows.iter().any(|o| o.memory_id == expected));
        }
        // Freshly recorded rows start out unconsumed.
        assert!(!rows.iter().any(|o| o.consumed));
    }

    #[test]
    fn record_recall_is_idempotent_under_replay() {
        let conn = fresh();
        seed_memory(&conn, "m1");
        let batch = [hit("m1", "fts5", 1, 0.5)];
        record_recall(&conn, "r1", &batch).expect("first");
        // The second, identical call must not insert anything.
        let replayed = record_recall(&conn, "r1", &batch).expect("replay");
        assert_eq!(replayed, 0);
    }

    #[test]
    fn mark_consumed_flips_only_matching_rows() {
        let conn = fresh();
        for id in ["m1", "m2", "m3", "consumer"] {
            seed_memory(&conn, id);
        }
        let batch = [
            hit("m1", "hybrid", 1, 0.9),
            hit("m2", "hybrid", 2, 0.8),
            hit("m3", "hybrid", 3, 0.7),
        ];
        record_recall(&conn, "r1", &batch).expect("record");

        let flipped = mark_consumed(&conn, "r1", &["m1", "m3"], "consumer").expect("mark");
        assert_eq!(flipped, 2);

        let rows = list_observations(&conn, Some("r1"), None, None, None, 10).expect("list");
        let cited_first = row_for(&rows, "m1");
        let uncited = row_for(&rows, "m2");
        let cited_second = row_for(&rows, "m3");
        assert!(cited_first.consumed);
        assert!(cited_first.consumed_at.is_some());
        assert!(!uncited.consumed);
        assert!(uncited.consumed_at.is_none());
        assert!(cited_second.consumed);
        assert_eq!(
            cited_first.consumed_by_memory_id.as_deref(),
            Some("consumer")
        );
    }

    #[test]
    fn mark_consumed_idempotent_no_replay_flip() {
        let conn = fresh();
        for id in ["m1", "consumer"] {
            seed_memory(&conn, id);
        }
        record_recall(&conn, "r1", &[hit("m1", "hybrid", 1, 0.9)]).unwrap();

        let first = mark_consumed(&conn, "r1", &["m1"], "consumer").unwrap();
        assert_eq!(first, 1);
        let second = mark_consumed(&conn, "r1", &["m1"], "consumer").unwrap();
        assert_eq!(
            second, 0,
            "second call must be a no-op because consumed=1 already"
        );
    }

    #[test]
    fn parse_cite_batch_accepts_both_field_names() {
        let expected_ids = vec!["m1".to_string(), "m2".to_string()];
        let primary = serde_json::json!({
            "recall_id": "r1",
            "cited_memory_ids": ["m1", "m2"],
        });
        let alias = serde_json::json!({
            "consumed_from_recall_id": "r1",
            "cited_memory_ids": ["m1", "m2"],
        });

        let (rid, ids) = parse_cite_batch(&primary).expect("v1");
        assert_eq!(rid, "r1");
        assert_eq!(ids, expected_ids);

        let (alias_rid, alias_ids) = parse_cite_batch(&alias).expect("v2");
        assert_eq!(alias_rid, "r1");
        assert_eq!(alias_ids, ids);
    }

    #[test]
    fn parse_cite_batch_returns_none_on_missing_fields() {
        assert!(parse_cite_batch(&serde_json::json!({})).is_none());

        let rejected = [
            (
                serde_json::json!({"recall_id": "r1"}),
                "missing cited_memory_ids",
            ),
            (
                serde_json::json!({"cited_memory_ids": ["m1"]}),
                "missing recall_id",
            ),
            (
                serde_json::json!({"recall_id": "  ", "cited_memory_ids": ["m1"]}),
                "blank recall_id",
            ),
        ];
        for (input, why) in &rejected {
            assert!(parse_cite_batch(input).is_none(), "{why}");
        }
    }

    #[test]
    fn parse_cite_batch_dedupes_and_skips_blank_ids() {
        let input = serde_json::json!({
            "recall_id": "r1",
            "cited_memory_ids": ["m1", "m1", "", "  ", "m2"],
        });
        let (_, ids) = parse_cite_batch(&input).unwrap();
        assert_eq!(ids, vec!["m1".to_string(), "m2".to_string()]);
    }

    #[test]
    fn list_observations_filters_compose() {
        let conn = fresh();
        for id in ["m1", "m2", "m3", "consumer"] {
            seed_memory(&conn, id);
        }
        let first_recall = [hit("m1", "hybrid", 1, 0.9), hit("m2", "hybrid", 2, 0.8)];
        record_recall(&conn, "r1", &first_recall).unwrap();
        record_recall(&conn, "r2", &[hit("m3", "fts5", 1, 0.4)]).unwrap();
        mark_consumed(&conn, "r1", &["m1"], "consumer").unwrap();

        // Filtering by recall id returns only that recall's rows.
        let for_r1 = list_observations(&conn, Some("r1"), None, None, None, 10).unwrap();
        assert_eq!(for_r1.len(), 2);

        // Filtering by consumed state spans all recalls.
        let only_consumed = list_observations(&conn, None, Some(true), None, None, 10).unwrap();
        assert_eq!(only_consumed.len(), 1);
        assert_eq!(only_consumed[0].memory_id, "m1");

        let only_pending = list_observations(&conn, None, Some(false), None, None, 10).unwrap();
        assert_eq!(only_pending.len(), 2);
    }
}