//! Memory persistence for `roboticus_db` (`memory.rs`): working, episodic,
//! semantic, procedural, and relationship tiers, plus full-text search.

use crate::{Database, DbResultExt};
use chrono::Utc;
use roboticus_core::Result;
use rusqlite::OptionalExtension;
5
// ── Working memory ──────────────────────────────────────────────

/// A short-term memory row scoped to one session.
///
/// Rows live in `working_memory` and are mirrored into the `memory_fts`
/// index by `store_working` itself (this tier has no FTS trigger, unlike
/// episodic memory).
#[derive(Debug, Clone)]
pub struct WorkingEntry {
    // UUIDv4 primary key, generated by `store_working`.
    pub id: String,
    // Owning session; rows whose session is gone are removed by
    // `cleanup_orphaned_working_memory`.
    pub session_id: String,
    // Free-form tag (e.g. "goal", "observation") — not an enum in the schema.
    pub entry_type: String,
    pub content: String,
    // Higher values sort first on retrieval.
    pub importance: i32,
    // RFC 3339 timestamp string (see `Utc::now().to_rfc3339()` in the store fn).
    pub created_at: String,
}
17
18pub fn store_working(
19    db: &Database,
20    session_id: &str,
21    entry_type: &str,
22    content: &str,
23    importance: i32,
24) -> Result<String> {
25    let conn = db.conn();
26    let id = uuid::Uuid::new_v4().to_string();
27    let now = Utc::now().to_rfc3339();
28    let tx = conn.unchecked_transaction().db_err()?;
29    tx.execute(
30        "INSERT INTO working_memory (id, session_id, entry_type, content, importance, created_at) \
31         VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
32        rusqlite::params![id, session_id, entry_type, content, importance, now],
33    )
34    .db_err()?;
35    // Remove any existing FTS row before inserting to avoid duplicates.
36    tx.execute(
37        "DELETE FROM memory_fts WHERE source_table = 'working' AND source_id = ?1",
38        rusqlite::params![id],
39    )
40    .db_err()?;
41    tx.execute(
42        "INSERT INTO memory_fts (content, category, source_table, source_id) VALUES (?1, ?2, 'working', ?3)",
43        rusqlite::params![content, entry_type, id],
44    )
45    .db_err()?;
46    tx.commit().db_err()?;
47    Ok(id)
48}
49
50pub fn retrieve_working(db: &Database, session_id: &str) -> Result<Vec<WorkingEntry>> {
51    let conn = db.conn();
52    let mut stmt = conn
53        .prepare(
54            "SELECT id, session_id, entry_type, content, importance, created_at \
55             FROM working_memory WHERE session_id = ?1 ORDER BY importance DESC, created_at DESC",
56        )
57        .db_err()?;
58
59    let rows = stmt
60        .query_map([session_id], |row| {
61            Ok(WorkingEntry {
62                id: row.get(0)?,
63                session_id: row.get(1)?,
64                entry_type: row.get(2)?,
65                content: row.get(3)?,
66                importance: row.get(4)?,
67                created_at: row.get(5)?,
68            })
69        })
70        .db_err()?;
71
72    rows.collect::<std::result::Result<Vec<_>, _>>().db_err()
73}
74
75pub fn retrieve_working_all(db: &Database, limit: i64) -> Result<Vec<WorkingEntry>> {
76    let conn = db.conn();
77    let mut stmt = conn
78        .prepare(
79            "SELECT id, session_id, entry_type, content, importance, created_at \
80             FROM working_memory ORDER BY importance DESC, created_at DESC LIMIT ?1",
81        )
82        .db_err()?;
83
84    let rows = stmt
85        .query_map([limit], |row| {
86            Ok(WorkingEntry {
87                id: row.get(0)?,
88                session_id: row.get(1)?,
89                entry_type: row.get(2)?,
90                content: row.get(3)?,
91                importance: row.get(4)?,
92                created_at: row.get(5)?,
93            })
94        })
95        .db_err()?;
96
97    rows.collect::<std::result::Result<Vec<_>, _>>().db_err()
98}
99
// ── Episodic memory ─────────────────────────────────────────────

/// One event-style memory row from `episodic_memory`.
///
/// FTS indexing for this tier is handled by a database trigger
/// (`episodic_ai`), not by the Rust store function.
#[derive(Debug, Clone)]
pub struct EpisodicEntry {
    // UUIDv4 primary key.
    pub id: String,
    // Free-form label (e.g. "success", "failure").
    pub classification: String,
    pub content: String,
    // Higher values sort first on retrieval.
    pub importance: i32,
    // RFC 3339 timestamp string.
    pub created_at: String,
}
110
111pub fn store_episodic(
112    db: &Database,
113    classification: &str,
114    content: &str,
115    importance: i32,
116) -> Result<String> {
117    let conn = db.conn();
118    let id = uuid::Uuid::new_v4().to_string();
119    let now = Utc::now().to_rfc3339();
120    conn.execute(
121        "INSERT INTO episodic_memory (id, classification, content, importance, created_at) \
122         VALUES (?1, ?2, ?3, ?4, ?5)",
123        rusqlite::params![id, classification, content, importance, now],
124    )
125    .db_err()?;
126
127    // FTS insert handled by episodic_ai trigger
128
129    Ok(id)
130}
131
132pub fn retrieve_episodic(db: &Database, limit: i64) -> Result<Vec<EpisodicEntry>> {
133    let conn = db.conn();
134    let mut stmt = conn
135        .prepare(
136            "SELECT id, classification, content, importance, created_at \
137             FROM episodic_memory ORDER BY importance DESC, created_at DESC LIMIT ?1",
138        )
139        .db_err()?;
140
141    let rows = stmt
142        .query_map([limit], |row| {
143            Ok(EpisodicEntry {
144                id: row.get(0)?,
145                classification: row.get(1)?,
146                content: row.get(2)?,
147                importance: row.get(3)?,
148                created_at: row.get(4)?,
149            })
150        })
151        .db_err()?;
152
153    rows.collect::<std::result::Result<Vec<_>, _>>().db_err()
154}
155
// ── Semantic memory ─────────────────────────────────────────────

/// A key/value fact in `semantic_memory`, unique per (`category`, `key`) —
/// see the ON CONFLICT clause in `store_semantic`.
#[derive(Debug, Clone)]
pub struct SemanticEntry {
    // UUIDv4 primary key; stable across upserts of the same (category, key).
    pub id: String,
    pub category: String,
    pub key: String,
    pub value: String,
    // Retrieval sorts by this, highest first.
    pub confidence: f64,
    // RFC 3339 timestamp string.
    pub created_at: String,
    // RFC 3339 timestamp string; refreshed on upsert.
    pub updated_at: String,
}
168
169pub fn store_semantic(
170    db: &Database,
171    category: &str,
172    key: &str,
173    value: &str,
174    confidence: f64,
175) -> Result<String> {
176    let conn = db.conn();
177    let id = uuid::Uuid::new_v4().to_string();
178    let now = Utc::now().to_rfc3339();
179    let tx = conn.unchecked_transaction().db_err()?;
180    tx.execute(
181        "INSERT INTO semantic_memory (id, category, key, value, confidence, created_at) \
182         VALUES (?1, ?2, ?3, ?4, ?5, ?6) \
183         ON CONFLICT(category, key) DO UPDATE SET value = excluded.value, \
184         confidence = excluded.confidence, updated_at = ?6",
185        rusqlite::params![id, category, key, value, confidence, now],
186    )
187    .db_err()?;
188
189    let actual_id: String = tx
190        .query_row(
191            "SELECT id FROM semantic_memory WHERE category = ?1 AND key = ?2",
192            rusqlite::params![category, key],
193            |row| row.get(0),
194        )
195        .db_err()?;
196
197    // Remove any existing FTS row before re-inserting to avoid duplicates on upsert.
198    tx.execute(
199        "DELETE FROM memory_fts WHERE source_table = 'semantic' AND source_id = ?1",
200        rusqlite::params![actual_id],
201    )
202    .db_err()?;
203    tx.execute(
204        "INSERT INTO memory_fts (content, category, source_table, source_id) VALUES (?1, ?2, 'semantic', ?3)",
205        rusqlite::params![value, category, actual_id],
206    )
207    .db_err()?;
208    tx.commit().db_err()?;
209
210    Ok(actual_id)
211}
212
213pub fn retrieve_semantic(db: &Database, category: &str) -> Result<Vec<SemanticEntry>> {
214    let conn = db.conn();
215    let mut stmt = conn
216        .prepare(
217            "SELECT id, category, key, value, confidence, created_at, updated_at \
218             FROM semantic_memory WHERE category = ?1 ORDER BY confidence DESC",
219        )
220        .db_err()?;
221
222    let rows = stmt
223        .query_map([category], |row| {
224            Ok(SemanticEntry {
225                id: row.get(0)?,
226                category: row.get(1)?,
227                key: row.get(2)?,
228                value: row.get(3)?,
229                confidence: row.get(4)?,
230                created_at: row.get(5)?,
231                updated_at: row.get(6)?,
232            })
233        })
234        .db_err()?;
235
236    rows.collect::<std::result::Result<Vec<_>, _>>().db_err()
237}
238
239pub fn list_semantic_categories(db: &Database) -> Result<Vec<(String, i64)>> {
240    let conn = db.conn();
241    let mut stmt = conn
242        .prepare(
243            "SELECT category, COUNT(*) as cnt FROM semantic_memory \
244             GROUP BY category ORDER BY cnt DESC",
245        )
246        .db_err()?;
247
248    let rows = stmt
249        .query_map([], |row| {
250            Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?))
251        })
252        .db_err()?;
253
254    rows.collect::<std::result::Result<Vec<_>, _>>().db_err()
255}
256
257pub fn retrieve_semantic_all(db: &Database, limit: i64) -> Result<Vec<SemanticEntry>> {
258    let conn = db.conn();
259    let mut stmt = conn
260        .prepare(
261            "SELECT id, category, key, value, confidence, created_at, updated_at \
262             FROM semantic_memory ORDER BY confidence DESC, updated_at DESC LIMIT ?1",
263        )
264        .db_err()?;
265
266    let rows = stmt
267        .query_map([limit], |row| {
268            Ok(SemanticEntry {
269                id: row.get(0)?,
270                category: row.get(1)?,
271                key: row.get(2)?,
272                value: row.get(3)?,
273                confidence: row.get(4)?,
274                created_at: row.get(5)?,
275                updated_at: row.get(6)?,
276            })
277        })
278        .db_err()?;
279
280    rows.collect::<std::result::Result<Vec<_>, _>>().db_err()
281}
282
// ── Procedural memory ───────────────────────────────────────────

/// A named procedure with outcome counters, unique per `name`
/// (see the ON CONFLICT clause in `store_procedural`).
#[derive(Debug, Clone)]
pub struct ProceduralEntry {
    // Primary key: UUIDv4 from `store_procedural`, or randomblob hex for
    // rows auto-registered by the record_* helpers.
    pub id: String,
    // Unique procedure/tool name.
    pub name: String,
    // Opaque steps payload; callers in this file store JSON-array strings.
    pub steps: String,
    // Bumped by `record_procedural_success`.
    pub success_count: i64,
    // Bumped by `record_procedural_failure`.
    pub failure_count: i64,
    pub created_at: String,
    pub updated_at: String,
}
295
296pub fn store_procedural(db: &Database, name: &str, steps: &str) -> Result<String> {
297    let conn = db.conn();
298    let id = uuid::Uuid::new_v4().to_string();
299    let now = Utc::now().to_rfc3339();
300    conn.execute(
301        "INSERT INTO procedural_memory (id, name, steps, created_at) VALUES (?1, ?2, ?3, ?4) \
302         ON CONFLICT(name) DO UPDATE SET steps = excluded.steps, updated_at = ?4",
303        rusqlite::params![id, name, steps, now],
304    )
305    .db_err()?;
306    Ok(id)
307}
308
309pub fn retrieve_procedural(db: &Database, name: &str) -> Result<Option<ProceduralEntry>> {
310    let conn = db.conn();
311    conn.query_row(
312        "SELECT id, name, steps, success_count, failure_count, created_at, updated_at \
313         FROM procedural_memory WHERE name = ?1",
314        [name],
315        |row| {
316            Ok(ProceduralEntry {
317                id: row.get(0)?,
318                name: row.get(1)?,
319                steps: row.get(2)?,
320                success_count: row.get(3)?,
321                failure_count: row.get(4)?,
322                created_at: row.get(5)?,
323                updated_at: row.get(6)?,
324            })
325        },
326    )
327    .optional()
328    .db_err()
329}
330
331pub fn record_procedural_success(db: &Database, name: &str) -> Result<()> {
332    let conn = db.conn();
333    // Auto-register the tool if it hasn't been seen before, then increment.
334    // Must include `steps` (NOT NULL) — SQLite evaluates NOT NULL before the
335    // ON CONFLICT(name) upsert path, so omitting it causes a hard failure
336    // even when the row already exists.
337    conn.execute(
338        "INSERT INTO procedural_memory (id, name, steps, success_count, failure_count, created_at, updated_at) \
339         VALUES (lower(hex(randomblob(16))), ?1, '', 0, 0, datetime('now'), datetime('now')) \
340         ON CONFLICT(name) DO NOTHING",
341        [name],
342    )
343    .db_err()?;
344    conn.execute(
345        "UPDATE procedural_memory SET success_count = success_count + 1, updated_at = datetime('now') WHERE name = ?1",
346        [name],
347    )
348    .db_err()?;
349    Ok(())
350}
351
352pub fn record_procedural_failure(db: &Database, name: &str) -> Result<()> {
353    let conn = db.conn();
354    // Auto-register the tool if it hasn't been seen before, then increment.
355    // Must include `steps` (NOT NULL) — see record_procedural_success comment.
356    conn.execute(
357        "INSERT INTO procedural_memory (id, name, steps, success_count, failure_count, created_at, updated_at) \
358         VALUES (lower(hex(randomblob(16))), ?1, '', 0, 0, datetime('now'), datetime('now')) \
359         ON CONFLICT(name) DO NOTHING",
360        [name],
361    )
362    .db_err()?;
363    conn.execute(
364        "UPDATE procedural_memory SET failure_count = failure_count + 1, updated_at = datetime('now') WHERE name = ?1",
365        [name],
366    )
367    .db_err()?;
368    Ok(())
369}
370
371/// Delete procedural entries with zero activity (no successes AND no failures)
372/// that haven't been updated in at least `stale_days` days.
373///
374/// Returns the number of rows deleted.
375pub fn prune_stale_procedural(db: &Database, stale_days: u32) -> Result<usize> {
376    let conn = db.conn();
377    let deleted = conn
378        .execute(
379            "DELETE FROM procedural_memory \
380             WHERE success_count = 0 AND failure_count = 0 \
381               AND updated_at < datetime('now', ?1)",
382            [format!("-{stale_days} days")],
383        )
384        .db_err()?;
385    Ok(deleted)
386}
387
// ── Relationship memory ─────────────────────────────────────────

/// Per-entity relationship state from `relationship_memory`, one row per
/// `entity_id` (see the ON CONFLICT clause in `store_relationship`).
#[derive(Debug, Clone)]
pub struct RelationshipEntry {
    // UUIDv4 primary key.
    pub id: String,
    // External identifier; unique in the table.
    pub entity_id: String,
    // Display name; nullable in the schema.
    pub entity_name: Option<String>,
    pub trust_score: f64,
    // Nullable; not written by `store_relationship` in this file.
    pub interaction_summary: Option<String>,
    // Incremented on each upsert of an existing entity.
    pub interaction_count: i64,
    // Nullable; set only on the upsert (conflict) path of `store_relationship`.
    pub last_interaction: Option<String>,
    pub created_at: String,
}
401
402pub fn store_relationship(
403    db: &Database,
404    entity_id: &str,
405    entity_name: &str,
406    trust_score: f64,
407) -> Result<String> {
408    let conn = db.conn();
409    let id = uuid::Uuid::new_v4().to_string();
410    let now = Utc::now().to_rfc3339();
411    conn.execute(
412        "INSERT INTO relationship_memory (id, entity_id, entity_name, trust_score, created_at) \
413         VALUES (?1, ?2, ?3, ?4, ?5) \
414         ON CONFLICT(entity_id) DO UPDATE SET entity_name = excluded.entity_name, \
415         trust_score = excluded.trust_score, interaction_count = interaction_count + 1, \
416         last_interaction = ?5",
417        rusqlite::params![id, entity_id, entity_name, trust_score, now],
418    )
419    .db_err()?;
420    Ok(id)
421}
422
423pub fn retrieve_relationship(db: &Database, entity_id: &str) -> Result<Option<RelationshipEntry>> {
424    let conn = db.conn();
425    conn.query_row(
426        "SELECT id, entity_id, entity_name, trust_score, interaction_summary, \
427         interaction_count, last_interaction, created_at \
428         FROM relationship_memory WHERE entity_id = ?1",
429        [entity_id],
430        |row| {
431            Ok(RelationshipEntry {
432                id: row.get(0)?,
433                entity_id: row.get(1)?,
434                entity_name: row.get(2)?,
435                trust_score: row.get(3)?,
436                interaction_summary: row.get(4)?,
437                interaction_count: row.get(5)?,
438                last_interaction: row.get(6)?,
439                created_at: row.get(7)?,
440            })
441        },
442    )
443    .optional()
444    .db_err()
445}
446
// ── Full-text search across memory tiers ────────────────────────

// ── Search results ──────────────────────────────────────────────

/// One hit from `fts_search`: the matched text plus where it came from.
#[derive(Debug, Clone, serde::Serialize)]
pub struct MemorySearchResult {
    // The matched text (FTS content column or the LIKE-matched column value).
    pub content: String,
    // Tier category: memory_fts.category for FTS hits, or the table name
    // with "_memory" stripped for LIKE-fallback hits.
    pub category: String,
    // memory_fts.source_table ('working'/'episodic'/'semantic') for FTS hits,
    // or the full table name for LIKE-fallback hits.
    pub source: String,
}
457
/// Sanitize user input for FTS5: keep only alphanumeric and whitespace
/// characters, then wrap the result in double quotes so it is treated as a
/// single phrase query and FTS5 operators (AND, OR, NOT, `*`, `^`, …) cannot
/// be injected.
///
/// The character filter already removes every `"` from the input, so no
/// additional quote-escaping is required — the escaping step the previous
/// version performed on the filtered string was dead code and has been
/// removed.
pub(crate) fn sanitize_fts_query(query: &str) -> String {
    let stripped: String = query
        .chars()
        .filter(|c| c.is_alphanumeric() || c.is_whitespace())
        .collect();
    format!("\"{stripped}\"")
}
468
/// Search memory: FTS5 MATCH on memory_fts (working, episodic, semantic), LIKE fallback for others.
/// Returns matching structured entries (content + category + source) up to `limit`, deduplicated.
///
/// Both phases are best-effort: preparation or row-mapping failures are
/// logged via `tracing::warn!` and skipped, so a broken FTS index degrades
/// search instead of failing the whole call.
pub fn fts_search(db: &Database, query: &str, limit: i64) -> Result<Vec<MemorySearchResult>> {
    let conn = db.conn();
    let mut results: Vec<MemorySearchResult> = Vec::new();
    // Dedup key is "source|content", so identical text can still surface once
    // per source table.
    let mut seen = std::collections::HashSet::new();

    // FTS5 MATCH on memory_fts (populated from working_memory, episodic_memory, semantic_memory)
    let fts_query = sanitize_fts_query(query);
    match conn.prepare(
        "SELECT content, category, source_table FROM memory_fts WHERE memory_fts MATCH ?1 LIMIT ?2",
    ) {
        Ok(mut stmt) => {
            match stmt.query_map(rusqlite::params![fts_query, limit], |row| {
                Ok((
                    row.get::<_, String>(0)?,
                    row.get::<_, String>(1)?,
                    row.get::<_, String>(2)?,
                ))
            }) {
                Ok(rows) => {
                    // flatten() silently drops rows that failed to decode —
                    // deliberate best-effort behavior.
                    for row in rows.flatten() {
                        let key = format!("{}|{}", row.2, row.0);
                        if seen.insert(key) {
                            results.push(MemorySearchResult {
                                content: row.0,
                                category: row.1,
                                source: row.2,
                            });
                            // Early exit once the global cap is reached.
                            if results.len() as i64 >= limit {
                                return Ok(results);
                            }
                        }
                    }
                }
                Err(e) => tracing::warn!(error = %e, "FTS5 query_map failed"),
            }
        }
        Err(e) => tracing::warn!(error = %e, "FTS5 query preparation failed"),
    }

    // LIKE fallback for tables not in FTS: procedural_memory.steps, relationship_memory.interaction_summary.
    // Safety: table and column names below are hardcoded constants, not user input,
    // so the string interpolation into SQL is safe from injection.
    // Escape % and _ so they are literal, and use ESCAPE '\\'.
    let escaped_query = query
        .replace('\\', "\\\\")
        .replace('%', "\\%")
        .replace('_', "\\_");
    let pattern = format!("%{escaped_query}%");
    let tables_and_cols: &[(&str, &str)] = &[
        ("procedural_memory", "steps"),
        ("relationship_memory", "interaction_summary"),
    ];

    for &(table, col) in tables_and_cols {
        let sql = format!("SELECT {col} FROM {table} WHERE {col} LIKE ?1 ESCAPE '\\' LIMIT ?2");
        match conn.prepare(&sql) {
            Ok(mut stmt) => {
                match stmt.query_map(rusqlite::params![pattern, limit], |row| {
                    row.get::<_, String>(0)
                }) {
                    Ok(rows) => {
                        for row in rows.flatten() {
                            let key = format!("{table}|{row}");
                            if seen.insert(key) {
                                results.push(MemorySearchResult {
                                    content: row,
                                    // e.g. "procedural_memory" -> "procedural"
                                    category: table.replace("_memory", ""),
                                    source: table.to_string(),
                                });
                                if results.len() as i64 >= limit {
                                    return Ok(results);
                                }
                            }
                        }
                    }
                    Err(e) => {
                        tracing::warn!(error = %e, table, col, "LIKE fallback query_map failed")
                    }
                }
            }
            Err(e) => {
                tracing::warn!(error = %e, table, col, "LIKE fallback query preparation failed")
            }
        }
    }

    Ok(results)
}
559
560// ── Episodic dead-entry cleanup ────────────────────────────────
561
562/// Delete episodic entries with `importance <= 1` that are older than
563/// `stale_days` days.  These low-signal entries accumulate over time and
564/// bloat the episodic tier without contributing useful retrieval context.
565///
566/// Returns the number of rows deleted.
567pub fn prune_dead_episodic(db: &Database, stale_days: u32) -> Result<usize> {
568    let conn = db.conn();
569    let deleted = conn
570        .execute(
571            "DELETE FROM episodic_memory \
572             WHERE importance <= 1 \
573               AND created_at < datetime('now', ?1)",
574            [format!("-{stale_days} days")],
575        )
576        .db_err()?;
577    Ok(deleted)
578}
579
580// ── Orphan cleanup ─────────────────────────────────────────────
581
582/// Delete working_memory rows whose `session_id` no longer exists in `sessions`.
583///
584/// Returns the number of orphaned rows removed.
585pub fn cleanup_orphaned_working_memory(db: &Database) -> Result<usize> {
586    let conn = db.conn();
587    let deleted = conn
588        .execute(
589            "DELETE FROM working_memory \
590             WHERE session_id NOT IN (SELECT id FROM sessions)",
591            [],
592        )
593        .db_err()?;
594    Ok(deleted)
595}
596
597#[cfg(test)]
598mod tests {
599    use super::*;
600
601    fn test_db() -> Database {
602        Database::new(":memory:").unwrap()
603    }
604
605    #[test]
606    fn working_memory_roundtrip() {
607        let db = test_db();
608        store_working(&db, "sess-1", "goal", "find food", 8).unwrap();
609        store_working(&db, "sess-1", "observation", "sun is up", 3).unwrap();
610
611        let entries = retrieve_working(&db, "sess-1").unwrap();
612        assert_eq!(entries.len(), 2);
613        assert_eq!(entries[0].importance, 8, "higher importance first");
614    }
615
616    #[test]
617    fn episodic_memory_roundtrip() {
618        let db = test_db();
619        store_episodic(&db, "success", "deployed v1.0", 9).unwrap();
620        store_episodic(&db, "failure", "ran out of credits", 7).unwrap();
621
622        let entries = retrieve_episodic(&db, 10).unwrap();
623        assert_eq!(entries.len(), 2);
624        assert_eq!(entries[0].classification, "success");
625    }
626
627    #[test]
628    fn semantic_memory_upsert() {
629        let db = test_db();
630        store_semantic(&db, "facts", "sky_color", "blue", 0.9).unwrap();
631        store_semantic(&db, "facts", "sky_color", "grey", 0.7).unwrap();
632
633        let entries = retrieve_semantic(&db, "facts").unwrap();
634        assert_eq!(entries.len(), 1);
635        assert_eq!(entries[0].value, "grey");
636    }
637
638    #[test]
639    fn procedural_memory_roundtrip() {
640        let db = test_db();
641        store_procedural(&db, "deploy", r#"["build","push","verify"]"#).unwrap();
642        let entry = retrieve_procedural(&db, "deploy").unwrap().unwrap();
643        assert_eq!(entry.name, "deploy");
644    }
645
646    #[test]
647    fn relationship_memory_roundtrip() {
648        let db = test_db();
649        store_relationship(&db, "user-42", "Jon", 0.9).unwrap();
650        let entry = retrieve_relationship(&db, "user-42").unwrap().unwrap();
651        assert_eq!(entry.entity_name.as_deref(), Some("Jon"));
652        assert!((entry.trust_score - 0.9).abs() < f64::EPSILON);
653    }
654
655    #[test]
656    fn fts_search_finds_across_tiers() {
657        let db = test_db();
658        store_working(&db, "s1", "note", "the quick brown fox", 5).unwrap();
659        store_episodic(&db, "event", "a lazy dog appeared", 5).unwrap();
660        store_semantic(&db, "facts", "animal", "fox is quick", 0.8).unwrap();
661
662        let hits = fts_search(&db, "quick", 10).unwrap();
663        assert_eq!(hits.len(), 2, "should match working + semantic");
664    }
665
666    #[test]
667    fn fts_search_finds_episodic_via_trigger() {
668        let db = test_db();
669        store_episodic(&db, "discovery", "the quantum engine hummed", 9).unwrap();
670
671        let hits = fts_search(&db, "quantum", 10).unwrap();
672        assert_eq!(hits.len(), 1);
673        assert!(hits[0].content.contains("quantum"));
674    }
675
676    #[test]
677    fn fts_respects_limit() {
678        let db = test_db();
679        for i in 0..5 {
680            store_working(&db, "s1", "note", &format!("alpha item {i}"), 1).unwrap();
681        }
682        let hits = fts_search(&db, "alpha", 3).unwrap();
683        assert_eq!(hits.len(), 3);
684    }
685
686    #[test]
687    fn semantic_upsert_returns_existing_id() {
688        let db = test_db();
689        let id1 = store_semantic(&db, "prefs", "color", "blue", 0.9).unwrap();
690        let id2 = store_semantic(&db, "prefs", "color", "red", 0.8).unwrap();
691        assert_eq!(id1, id2, "upsert should return the original row id");
692    }
693
694    #[test]
695    fn procedural_failure_tracking() {
696        let db = test_db();
697        store_procedural(&db, "deploy", r#"["build","push"]"#).unwrap();
698        let entry = retrieve_procedural(&db, "deploy").unwrap().unwrap();
699        assert_eq!(entry.failure_count, 0);
700
701        record_procedural_failure(&db, "deploy").unwrap();
702        record_procedural_failure(&db, "deploy").unwrap();
703        let entry = retrieve_procedural(&db, "deploy").unwrap().unwrap();
704        assert_eq!(entry.failure_count, 2);
705    }
706
707    #[test]
708    fn store_working_writes_both_tables() {
709        let db = test_db();
710        let id = store_working(&db, "sess-1", "fact", "the sky is blue", 5).unwrap();
711
712        let conn = db.conn();
713        let count: i64 = conn
714            .query_row(
715                "SELECT COUNT(*) FROM working_memory WHERE id = ?1",
716                [&id],
717                |r| r.get(0),
718            )
719            .unwrap();
720        assert_eq!(count, 1);
721
722        let fts_count: i64 = conn
723            .query_row(
724                "SELECT COUNT(*) FROM memory_fts WHERE source_id = ?1",
725                [&id],
726                |r| r.get(0),
727            )
728            .unwrap();
729        assert_eq!(fts_count, 1);
730    }
731
732    #[test]
733    fn record_procedural_success_tracking() {
734        let db = test_db();
735        store_procedural(&db, "deploy", r#"["build","push"]"#).unwrap();
736        record_procedural_success(&db, "deploy").unwrap();
737        record_procedural_success(&db, "deploy").unwrap();
738        record_procedural_success(&db, "deploy").unwrap();
739        let entry = retrieve_procedural(&db, "deploy").unwrap().unwrap();
740        assert_eq!(entry.success_count, 3);
741    }
742
743    #[test]
744    fn retrieve_working_empty_session() {
745        let db = test_db();
746        let entries = retrieve_working(&db, "nonexistent-session").unwrap();
747        assert!(entries.is_empty());
748    }
749
750    #[test]
751    fn retrieve_working_is_session_isolated() {
752        let db = test_db();
753        store_working(&db, "sess-a", "note", "alpha", 5).unwrap();
754        store_working(&db, "sess-b", "note", "beta", 5).unwrap();
755
756        let a = retrieve_working(&db, "sess-a").unwrap();
757        let b = retrieve_working(&db, "sess-b").unwrap();
758        assert_eq!(a.len(), 1);
759        assert_eq!(b.len(), 1);
760        assert_eq!(a[0].content, "alpha");
761        assert_eq!(b[0].content, "beta");
762    }
763
764    #[test]
765    fn retrieve_episodic_limit_zero() {
766        let db = test_db();
767        store_episodic(&db, "event", "something happened", 5).unwrap();
768        let entries = retrieve_episodic(&db, 0).unwrap();
769        assert!(entries.is_empty());
770    }
771
772    #[test]
773    fn retrieve_semantic_empty_category() {
774        let db = test_db();
775        let entries = retrieve_semantic(&db, "no-such-category").unwrap();
776        assert!(entries.is_empty());
777    }
778
779    #[test]
780    fn retrieve_procedural_nonexistent() {
781        let db = test_db();
782        let entry = retrieve_procedural(&db, "nonexistent").unwrap();
783        assert!(entry.is_none());
784    }
785
786    #[test]
787    fn retrieve_relationship_nonexistent() {
788        let db = test_db();
789        let entry = retrieve_relationship(&db, "no-such-entity").unwrap();
790        assert!(entry.is_none());
791    }
792
793    #[test]
794    fn store_relationship_upsert_increments_interaction() {
795        let db = test_db();
796        store_relationship(&db, "user-1", "Alice", 0.5).unwrap();
797        store_relationship(&db, "user-1", "Alice Updated", 0.8).unwrap();
798        let entry = retrieve_relationship(&db, "user-1").unwrap().unwrap();
799        assert_eq!(entry.interaction_count, 1);
800    }
801
802    #[test]
803    fn store_procedural_upsert_updates_steps() {
804        let db = test_db();
805        store_procedural(&db, "deploy", r#"["build"]"#).unwrap();
806        store_procedural(&db, "deploy", r#"["build","push","verify"]"#).unwrap();
807        let entry = retrieve_procedural(&db, "deploy").unwrap().unwrap();
808        assert_eq!(entry.steps, r#"["build","push","verify"]"#);
809    }
810
811    #[test]
812    fn fts_search_no_matches() {
813        let db = test_db();
814        store_working(&db, "s1", "note", "hello world", 5).unwrap();
815        let hits = fts_search(&db, "zzzznotfound", 10).unwrap();
816        assert!(hits.is_empty());
817    }
818
819    #[test]
820    fn fts_search_like_fallback_procedural() {
821        let db = test_db();
822        store_procedural(&db, "backup", "step one: tar the archive and compress").unwrap();
823        let hits = fts_search(&db, "tar the archive", 10).unwrap();
824        assert!(!hits.is_empty());
825    }
826
827    // ── retrieve_working_all tests ────────────────────────────
828
829    #[test]
830    fn retrieve_working_all_returns_across_sessions() {
831        let db = test_db();
832        store_working(&db, "sess-a", "note", "alpha entry", 5).unwrap();
833        store_working(&db, "sess-b", "note", "beta entry", 8).unwrap();
834        store_working(&db, "sess-c", "note", "gamma entry", 3).unwrap();
835
836        let entries = retrieve_working_all(&db, 100).unwrap();
837        assert_eq!(entries.len(), 3);
838        // Ordered by importance DESC
839        assert_eq!(entries[0].importance, 8);
840        assert_eq!(entries[1].importance, 5);
841        assert_eq!(entries[2].importance, 3);
842    }
843
844    #[test]
845    fn retrieve_working_all_respects_limit() {
846        let db = test_db();
847        for i in 0..5 {
848            store_working(&db, "sess", "note", &format!("entry {i}"), i).unwrap();
849        }
850        let entries = retrieve_working_all(&db, 2).unwrap();
851        assert_eq!(entries.len(), 2);
852    }
853
854    #[test]
855    fn retrieve_working_all_empty_db() {
856        let db = test_db();
857        let entries = retrieve_working_all(&db, 10).unwrap();
858        assert!(entries.is_empty());
859    }
860
861    // ── list_semantic_categories tests ────────────────────────
862
863    #[test]
864    fn list_semantic_categories_returns_grouped() {
865        let db = test_db();
866        store_semantic(&db, "facts", "sky_color", "blue", 0.9).unwrap();
867        store_semantic(&db, "facts", "grass_color", "green", 0.8).unwrap();
868        store_semantic(&db, "prefs", "theme", "dark", 0.7).unwrap();
869
870        let categories = list_semantic_categories(&db).unwrap();
871        assert_eq!(categories.len(), 2);
872        // Ordered by count DESC
873        assert_eq!(categories[0].0, "facts");
874        assert_eq!(categories[0].1, 2);
875        assert_eq!(categories[1].0, "prefs");
876        assert_eq!(categories[1].1, 1);
877    }
878
879    #[test]
880    fn list_semantic_categories_empty() {
881        let db = test_db();
882        let categories = list_semantic_categories(&db).unwrap();
883        assert!(categories.is_empty());
884    }
885
886    // ── retrieve_semantic_all tests ──────────────────────────
887
888    #[test]
889    fn retrieve_semantic_all_returns_across_categories() {
890        let db = test_db();
891        store_semantic(&db, "facts", "sky", "blue", 0.9).unwrap();
892        store_semantic(&db, "prefs", "theme", "dark", 0.7).unwrap();
893        store_semantic(&db, "facts", "grass", "green", 0.8).unwrap();
894
895        let entries = retrieve_semantic_all(&db, 100).unwrap();
896        assert_eq!(entries.len(), 3);
897        // Ordered by confidence DESC
898        assert!((entries[0].confidence - 0.9).abs() < f64::EPSILON);
899    }
900
901    #[test]
902    fn retrieve_semantic_all_respects_limit() {
903        let db = test_db();
904        for i in 0..5 {
905            store_semantic(
906                &db,
907                "cat",
908                &format!("key{i}"),
909                &format!("val{i}"),
910                0.5 + i as f64 * 0.1,
911            )
912            .unwrap();
913        }
914        let entries = retrieve_semantic_all(&db, 2).unwrap();
915        assert_eq!(entries.len(), 2);
916    }
917
918    #[test]
919    fn retrieve_semantic_all_empty() {
920        let db = test_db();
921        let entries = retrieve_semantic_all(&db, 10).unwrap();
922        assert!(entries.is_empty());
923    }
924
925    // ── fts_search LIKE fallback additional paths ────────────
926
927    #[test]
928    fn fts_search_like_fallback_relationship() {
929        let db = test_db();
930        // Store a relationship with an interaction_summary that can be found via LIKE fallback
931        {
932            let conn = db.conn();
933            conn.execute(
934                "INSERT INTO relationship_memory (id, entity_id, entity_name, trust_score, interaction_summary) \
935                 VALUES ('r1', 'user-99', 'TestUser', 0.8, 'discussed the quantum physics experiment')",
936                [],
937            ).unwrap();
938        }
939
940        let hits = fts_search(&db, "quantum physics", 10).unwrap();
941        assert!(
942            !hits.is_empty(),
943            "LIKE fallback should find relationship interaction_summary"
944        );
945    }
946
947    #[test]
948    fn fts_search_limit_reached_in_fts_phase() {
949        let db = test_db();
950        // Create enough FTS entries so the limit is reached during the FTS phase
951        for i in 0..5 {
952            store_working(
953                &db,
954                "sess",
955                "note",
956                &format!("searchable keyword item {i}"),
957                5,
958            )
959            .unwrap();
960        }
961        let hits = fts_search(&db, "keyword", 2).unwrap();
962        assert_eq!(hits.len(), 2, "should stop at limit during FTS phase");
963    }
964
965    #[test]
966    fn fts_search_limit_reached_in_like_phase() {
967        let db = test_db();
968        // Store items in procedural memory (LIKE fallback) with a common pattern
969        for i in 0..5 {
970            store_procedural(
971                &db,
972                &format!("proc_{i}"),
973                &format!("step: run the xyzzy command {i}"),
974            )
975            .unwrap();
976        }
977        let hits = fts_search(&db, "xyzzy command", 2).unwrap();
978        assert_eq!(
979            hits.len(),
980            2,
981            "should stop at limit during LIKE fallback phase"
982        );
983    }
984
985    #[test]
986    fn fts_search_special_chars_in_query() {
987        let db = test_db();
988        store_working(
989            &db,
990            "sess",
991            "note",
992            "test with percent % and underscore _",
993            5,
994        )
995        .unwrap();
996        // This tests the sanitize_fts_query and the LIKE escape logic
997        let hits = fts_search(&db, "percent", 10).unwrap();
998        assert!(!hits.is_empty());
999    }
1000
1001    #[test]
1002    fn sanitize_fts_query_strips_operators() {
1003        // FTS5 operators like AND, OR, NOT should be neutralized by the sanitizer
1004        let result = sanitize_fts_query("hello AND world");
1005        // Should wrap in quotes, stripping non-alnum/space
1006        assert!(result.starts_with('"'));
1007        assert!(result.ends_with('"'));
1008    }
1009
1010    #[test]
1011    fn sanitize_fts_query_empty() {
1012        let result = sanitize_fts_query("");
1013        assert_eq!(result, "\"\"");
1014    }
1015
1016    #[test]
1017    fn sanitize_fts_query_special_chars_stripped() {
1018        let result = sanitize_fts_query("hello* OR world");
1019        // * and OR should be kept as alphanumeric/space
1020        assert!(!result.contains('*'));
1021    }
1022
1023    #[test]
1024    fn prune_stale_procedural_removes_zero_activity_entries() {
1025        let db = test_db();
1026        // Create a procedural entry via store (will have success_count=0, failure_count=0)
1027        store_procedural(&db, "stale-tool", "do something").unwrap();
1028
1029        // Backdate its updated_at to 60 days ago
1030        db.conn()
1031            .execute(
1032                "UPDATE procedural_memory SET updated_at = datetime('now', '-60 days') WHERE name = ?1",
1033                ["stale-tool"],
1034            )
1035            .unwrap();
1036
1037        // Also create one with activity — should NOT be pruned
1038        store_procedural(&db, "active-tool", "steps").unwrap();
1039        record_procedural_success(&db, "active-tool").unwrap();
1040        db.conn()
1041            .execute(
1042                "UPDATE procedural_memory SET updated_at = datetime('now', '-60 days') WHERE name = ?1",
1043                ["active-tool"],
1044            )
1045            .unwrap();
1046
1047        let pruned = prune_stale_procedural(&db, 30).unwrap();
1048        assert_eq!(pruned, 1);
1049
1050        // stale-tool gone, active-tool remains
1051        assert!(retrieve_procedural(&db, "stale-tool").unwrap().is_none());
1052        assert!(retrieve_procedural(&db, "active-tool").unwrap().is_some());
1053    }
1054
1055    #[test]
1056    fn prune_stale_procedural_ignores_recent_entries() {
1057        let db = test_db();
1058        store_procedural(&db, "fresh-tool", "steps").unwrap();
1059        // Don't backdate — should not be pruned
1060        let pruned = prune_stale_procedural(&db, 30).unwrap();
1061        assert_eq!(pruned, 0);
1062        assert!(retrieve_procedural(&db, "fresh-tool").unwrap().is_some());
1063    }
1064
1065    // ── Episodic dead-entry cleanup tests ─────────────────────
1066
1067    #[test]
1068    fn prune_dead_episodic_removes_low_importance_old() {
1069        let db = test_db();
1070        store_episodic(&db, "noise", "irrelevant chatter", 1).unwrap();
1071        store_episodic(&db, "signal", "critical event", 8).unwrap();
1072
1073        // Backdate the low-importance entry
1074        db.conn()
1075            .execute(
1076                "UPDATE episodic_memory SET created_at = datetime('now', '-60 days') \
1077                 WHERE importance <= 1",
1078                [],
1079            )
1080            .unwrap();
1081
1082        let pruned = prune_dead_episodic(&db, 30).unwrap();
1083        assert_eq!(pruned, 1);
1084
1085        let remaining = retrieve_episodic(&db, 100).unwrap();
1086        assert_eq!(remaining.len(), 1);
1087        assert_eq!(remaining[0].content, "critical event");
1088    }
1089
1090    #[test]
1091    fn prune_dead_episodic_keeps_recent_low_importance() {
1092        let db = test_db();
1093        store_episodic(&db, "recent-noise", "just happened", 1).unwrap();
1094        // Don't backdate — should not be pruned
1095        let pruned = prune_dead_episodic(&db, 30).unwrap();
1096        assert_eq!(pruned, 0);
1097    }
1098
1099    #[test]
1100    fn prune_dead_episodic_keeps_old_high_importance() {
1101        let db = test_db();
1102        store_episodic(&db, "important", "old but critical", 5).unwrap();
1103        db.conn()
1104            .execute(
1105                "UPDATE episodic_memory SET created_at = datetime('now', '-90 days')",
1106                [],
1107            )
1108            .unwrap();
1109        let pruned = prune_dead_episodic(&db, 30).unwrap();
1110        assert_eq!(pruned, 0);
1111    }
1112
1113    // ── Orphan cleanup tests ─────────────────────────────────
1114
1115    #[test]
1116    fn cleanup_orphaned_working_memory_removes_dangling() {
1117        let db = test_db();
1118        // Create a real session so its working_memory survives.
1119        let conn = db.conn();
1120        conn.execute(
1121            "INSERT INTO sessions (id, agent_id) VALUES ('live-sess', 'a')",
1122            [],
1123        )
1124        .unwrap();
1125        drop(conn);
1126
1127        store_working(&db, "live-sess", "note", "survives", 5).unwrap();
1128        store_working(&db, "dead-sess", "note", "orphaned", 5).unwrap();
1129
1130        let deleted = cleanup_orphaned_working_memory(&db).unwrap();
1131        assert_eq!(deleted, 1);
1132
1133        let remaining = retrieve_working(&db, "live-sess").unwrap();
1134        assert_eq!(remaining.len(), 1);
1135        let gone = retrieve_working(&db, "dead-sess").unwrap();
1136        assert!(gone.is_empty());
1137    }
1138
1139    #[test]
1140    fn cleanup_orphaned_working_memory_noop_when_clean() {
1141        let db = test_db();
1142        let conn = db.conn();
1143        conn.execute("INSERT INTO sessions (id, agent_id) VALUES ('s1', 'a')", [])
1144            .unwrap();
1145        drop(conn);
1146
1147        store_working(&db, "s1", "note", "ok", 5).unwrap();
1148        let deleted = cleanup_orphaned_working_memory(&db).unwrap();
1149        assert_eq!(deleted, 0);
1150    }
1151}