Skip to main content

roboticus_agent/
retrieval.rs

1use roboticus_core::config::MemoryConfig;
2use roboticus_db::Database;
3use serde::Serialize;
4use std::collections::HashSet;
5
6use crate::context::{ComplexityLevel, token_budget};
7use crate::memory::MemoryBudgetManager;
8
9/// Metrics captured during memory retrieval for observability persistence.
10#[derive(Debug, Clone, Default, Serialize)]
11pub struct RetrievalMetrics {
12    /// Number of memories retrieved across all tiers.
13    pub retrieval_count: usize,
14    /// Whether any memories were retrieved (retrieval_count > 0).
15    pub retrieval_hit: bool,
16    /// Average similarity score across vector-search results (0.0 if none).
17    pub avg_similarity: f64,
18    /// Fraction of total context budget consumed by memory tokens.
19    pub budget_utilization: f64,
20    /// Per-tier breakdown of retrieved memory counts.
21    pub tiers: MemoryTierBreakdown,
22}
23
24/// Per-tier counts of memories retrieved.
25#[derive(Debug, Clone, Default, Serialize)]
26pub struct MemoryTierBreakdown {
27    pub working: usize,
28    pub episodic: usize,
29    pub semantic: usize,
30    pub procedural: usize,
31    pub relationship: usize,
32}
33
34/// Output of a retrieval call: formatted text + metrics.
35pub struct RetrievalOutput {
36    /// Formatted memory text for injection into the LLM prompt.
37    pub text: String,
38    /// Retrieval metrics for observability.
39    pub metrics: RetrievalMetrics,
40}
41
42/// Retrieves and formats memories from all five tiers for injection into the LLM prompt.
43pub struct MemoryRetriever {
44    budget_manager: MemoryBudgetManager,
45    hybrid_weight: f64,
46    similarity_threshold: f64,
47    /// Half-life (in days) for episodic memory decay during retrieval re-ranking.
48    /// Older episodic results have their similarity score discounted so that
49    /// recent memories surface above stale ones with similar cosine proximity.
50    decay_half_life_days: f64,
51}
52
53impl MemoryRetriever {
54    pub fn new(config: MemoryConfig) -> Self {
55        let hybrid_weight = config.hybrid_weight;
56        let similarity_threshold = config.similarity_threshold;
57        let decay_half_life_days = config.decay_half_life_days;
58        Self {
59            budget_manager: MemoryBudgetManager::new(config),
60            hybrid_weight,
61            similarity_threshold,
62            decay_half_life_days,
63        }
64    }
65
66    /// Override the episodic decay half-life (in days) used during retrieval re-ranking.
67    pub fn with_decay_half_life(mut self, days: f64) -> Self {
68        self.decay_half_life_days = days;
69        self
70    }
71
72    /// Retrieve memories from all tiers and format them into a single string
73    /// for context injection. Token budgets are respected per-tier.
74    pub fn retrieve(
75        &self,
76        db: &Database,
77        session_id: &str,
78        query: &str,
79        query_embedding: Option<&[f32]>,
80        complexity: ComplexityLevel,
81    ) -> String {
82        self.retrieve_with_ann(db, session_id, query, query_embedding, complexity, None)
83    }
84
85    /// Like `retrieve`, but optionally uses an ANN index for O(log n) nearest-neighbor
86    /// search instead of brute-force cosine scan.
87    pub fn retrieve_with_ann(
88        &self,
89        db: &Database,
90        session_id: &str,
91        query: &str,
92        query_embedding: Option<&[f32]>,
93        complexity: ComplexityLevel,
94        ann_index: Option<&roboticus_db::ann::AnnIndex>,
95    ) -> String {
96        self.retrieve_with_metrics(
97            db,
98            session_id,
99            query,
100            query_embedding,
101            complexity,
102            ann_index,
103        )
104        .text
105    }
106
107    /// Retrieve memories with full observability metrics.
108    ///
109    /// Returns both the formatted memory text and a [`RetrievalMetrics`] struct
110    /// containing tier breakdowns, retrieval counts, and similarity scores for
111    /// persistence into `context_snapshots`.
112    pub fn retrieve_with_metrics(
113        &self,
114        db: &Database,
115        session_id: &str,
116        query: &str,
117        query_embedding: Option<&[f32]>,
118        complexity: ComplexityLevel,
119        ann_index: Option<&roboticus_db::ann::AnnIndex>,
120    ) -> RetrievalOutput {
121        let total_budget = token_budget(complexity);
122        let budgets = self.budget_manager.allocate_budgets(total_budget);
123
124        let mut sections = Vec::new();
125        let mut tiers = MemoryTierBreakdown::default();
126
127        let working_count = if let Some(s) = self.retrieve_working(db, session_id, budgets.working)
128        {
129            // Count lines starting with "- " as individual memory entries
130            let count = s.lines().filter(|l| l.starts_with("- ")).count();
131            sections.push(s);
132            count
133        } else {
134            0
135        };
136        tiers.working = working_count;
137
138        // Try ANN index first for relevant memories; fall back to brute-force hybrid search
139        let relevant = if let (Some(ann), Some(emb)) = (ann_index, query_embedding) {
140            ann.search(emb, 10).map(|results| {
141                results
142                    .into_iter()
143                    .map(|r| roboticus_db::embeddings::SearchResult {
144                        source_table: r.source_table,
145                        source_id: r.source_id,
146                        content_preview: r.content_preview,
147                        similarity: r.similarity,
148                    })
149                    .collect::<Vec<_>>()
150            })
151        } else {
152            None
153        };
154        let mut relevant = relevant.unwrap_or_else(|| {
155            roboticus_db::embeddings::hybrid_search(
156                db,
157                query,
158                query_embedding,
159                10,
160                self.hybrid_weight,
161            )
162            .unwrap_or_default()
163        });
164
165        if self.similarity_threshold > 0.0 {
166            relevant.retain(|r| r.similarity >= self.similarity_threshold);
167        }
168
169        if !query_requests_inactive_memories(query) {
170            self.filter_inactive_memories(db, &mut relevant);
171        }
172
173        // Decay re-ranking: discount episodic results by age so recent memories
174        // surface above stale ones with similar cosine proximity.
175        if self.decay_half_life_days > 0.0 {
176            self.rerank_episodic_by_decay(db, &mut relevant);
177        }
178
179        // Compute similarity stats from vector-search results before formatting
180        // (formatting may drop some results due to budget constraints).
181        let avg_similarity = if relevant.is_empty() {
182            0.0
183        } else {
184            let sum: f64 = relevant.iter().map(|r| r.similarity).sum();
185            sum / relevant.len() as f64
186        };
187
188        // Count per-tier from relevant results
189        for r in &relevant {
190            match r.source_table.as_str() {
191                "episodic_memory" => tiers.episodic += 1,
192                "semantic_memory" => tiers.semantic += 1,
193                _ => {} // other tables map to episodic/semantic bucket
194            }
195        }
196
197        if let Some(s) = self.format_relevant(&relevant, budgets.episodic + budgets.semantic) {
198            sections.push(s);
199        }
200
201        let procedural_count = if let Some(s) = self.retrieve_procedural(db, budgets.procedural) {
202            let count = s.lines().filter(|l| l.starts_with("- ")).count();
203            sections.push(s);
204            count
205        } else {
206            0
207        };
208        tiers.procedural = procedural_count;
209
210        let relationship_count =
211            if let Some(s) = self.retrieve_relationships(db, query, budgets.relationship) {
212                let count = s.lines().filter(|l| l.starts_with("- ")).count();
213                sections.push(s);
214                count
215            } else {
216                0
217            };
218        tiers.relationship = relationship_count;
219
220        let text = if sections.is_empty() {
221            String::new()
222        } else {
223            format!("[Active Memory]\n{}", sections.join("\n\n"))
224        };
225
226        let memory_tokens = estimate_tokens(&text);
227        let retrieval_count =
228            tiers.working + tiers.episodic + tiers.semantic + tiers.procedural + tiers.relationship;
229
230        let metrics = RetrievalMetrics {
231            retrieval_count,
232            retrieval_hit: retrieval_count > 0,
233            avg_similarity,
234            budget_utilization: if total_budget > 0 {
235                memory_tokens as f64 / total_budget as f64
236            } else {
237                0.0
238            },
239            tiers,
240        };
241
242        RetrievalOutput { text, metrics }
243    }
244
245    fn retrieve_working(
246        &self,
247        db: &Database,
248        session_id: &str,
249        budget_tokens: usize,
250    ) -> Option<String> {
251        if budget_tokens == 0 {
252            return None;
253        }
254
255        let entries = roboticus_db::memory::retrieve_working(db, session_id)
256            .inspect_err(
257                |e| tracing::warn!(error = %e, session_id, "working memory retrieval failed"),
258            )
259            .ok()?;
260        if entries.is_empty() {
261            return None;
262        }
263
264        let mut text = String::from("[Working Memory]\n");
265        let mut used = estimate_tokens(&text);
266
267        for entry in &entries {
268            // `turn_summary` mirrors prior assistant output and can cause
269            // repetitive self-priming when injected into subsequent prompts.
270            if entry.entry_type.eq_ignore_ascii_case("turn_summary") {
271                continue;
272            }
273            let line = format!("- [{}] {}\n", entry.entry_type, entry.content);
274            let line_tokens = estimate_tokens(&line);
275            if used + line_tokens > budget_tokens {
276                break;
277            }
278            text.push_str(&line);
279            used += line_tokens;
280        }
281
282        if text.len() > "[Working Memory]\n".len() {
283            Some(text)
284        } else {
285            None
286        }
287    }
288
289    fn format_relevant(
290        &self,
291        results: &[roboticus_db::embeddings::SearchResult],
292        budget_tokens: usize,
293    ) -> Option<String> {
294        if budget_tokens == 0 || results.is_empty() {
295            return None;
296        }
297
298        let mut text = String::from("[Relevant Memories]\n");
299        let mut used = estimate_tokens(&text);
300
301        for result in results {
302            let line = format!(
303                "- [{} | sim={:.2}] {}\n",
304                result.source_table, result.similarity, result.content_preview,
305            );
306            let line_tokens = estimate_tokens(&line);
307            if used + line_tokens > budget_tokens {
308                break;
309            }
310            text.push_str(&line);
311            used += line_tokens;
312        }
313
314        if text.len() > "[Relevant Memories]\n".len() {
315            Some(text)
316        } else {
317            None
318        }
319    }
320
321    /// Re-rank search results by applying time-decay to episodic entries.
322    ///
323    /// For results from the `episodic_memory` table, look up their `created_at`
324    /// timestamp and scale the similarity score by an exponential decay factor.
325    /// Non-episodic results are left untouched.  The result list is re-sorted
326    /// by the adjusted similarity in descending order.
327    fn rerank_episodic_by_decay(
328        &self,
329        db: &Database,
330        results: &mut [roboticus_db::embeddings::SearchResult],
331    ) {
332        let now = chrono::Utc::now();
333
334        // Batch-query: collect all episodic IDs, look them up in one pass,
335        // then apply decay.  This avoids N separate queries holding the DB
336        // connection open in a loop.
337        let episodic_ids: Vec<&str> = results
338            .iter()
339            .filter(|r| r.source_table == "episodic_memory")
340            .map(|r| r.source_id.as_str())
341            .collect();
342
343        if episodic_ids.is_empty() {
344            return;
345        }
346
347        // Build a HashMap<id, age_days> from a single DB access
348        let age_map: std::collections::HashMap<String, f64> = {
349            let conn = db.conn();
350            let placeholders: Vec<String> =
351                (1..=episodic_ids.len()).map(|i| format!("?{i}")).collect();
352            let sql = format!(
353                "SELECT id, created_at FROM episodic_memory WHERE id IN ({})",
354                placeholders.join(", ")
355            );
356            let mut stmt = match conn.prepare(&sql) {
357                Ok(s) => s,
358                Err(_) => return,
359            };
360            let rows = match stmt
361                .query_map(roboticus_db::params_from_iter(episodic_ids.iter()), |row| {
362                    Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
363                }) {
364                Ok(r) => r,
365                Err(_) => return,
366            };
367            rows.filter_map(|r| {
368                r.inspect_err(|e| tracing::warn!("skipping corrupted episodic row: {e}"))
369                    .ok()
370            })
371            .filter_map(|(id, ts)| {
372                chrono::DateTime::parse_from_rfc3339(&ts)
373                    .ok()
374                    .map(|created| {
375                        // Age in days. Future timestamps (clock skew) yield a
376                        // negative chrono::Duration whose .to_std() returns Err,
377                        // mapping to age=0 (fresh). This is correct: only the
378                        // agent writes to episodic_memory so future-dated entries
379                        // are clock-skew artifacts, not attacker-injectable.
380                        let age = (now - created.with_timezone(&chrono::Utc))
381                            .to_std()
382                            .map(|d| d.as_secs_f64() / 86_400.0)
383                            .unwrap_or(0.0);
384                        (id, age)
385                    })
386            })
387            .collect()
388        }; // conn dropped here — DB connection released before mutation loop
389
390        for result in results.iter_mut() {
391            if result.source_table != "episodic_memory" {
392                continue;
393            }
394            if result.source_id.is_empty() {
395                // FTS-only results have no source_id and can't be looked up
396                // in episodic_memory. Apply a conservative default penalty so
397                // they don't bypass decay and outrank properly-aged results.
398                result.similarity *= 0.5;
399                continue;
400            }
401            if let Some(&age) = age_map.get(&result.source_id) {
402                let decay_factor = (0.5_f64).powf(age / self.decay_half_life_days);
403                // Floor at 0.05 so very old memories remain findable — they
404                // rank lower but never become completely invisible.
405                let clamped = decay_factor.max(0.05);
406                result.similarity *= clamped;
407            }
408        }
409
410        // Re-sort by adjusted similarity, descending
411        results.sort_by(|a, b| {
412            b.similarity
413                .partial_cmp(&a.similarity)
414                .unwrap_or(std::cmp::Ordering::Equal)
415        });
416    }
417
418    fn filter_inactive_memories(
419        &self,
420        db: &Database,
421        results: &mut Vec<roboticus_db::embeddings::SearchResult>,
422    ) {
423        let episodic_ids: Vec<&str> = results
424            .iter()
425            .filter(|r| r.source_table == "episodic_memory" && !r.source_id.is_empty())
426            .map(|r| r.source_id.as_str())
427            .collect();
428        let semantic_ids: Vec<&str> = results
429            .iter()
430            .filter(|r| r.source_table == "semantic_memory" && !r.source_id.is_empty())
431            .map(|r| r.source_id.as_str())
432            .collect();
433
434        let episodic_inactive = self.load_inactive_ids(db, "episodic_memory", &episodic_ids);
435        let semantic_inactive = self.load_inactive_ids(db, "semantic_memory", &semantic_ids);
436
437        results.retain(|r| match r.source_table.as_str() {
438            "episodic_memory" => !episodic_inactive.contains(r.source_id.as_str()),
439            "semantic_memory" => !semantic_inactive.contains(r.source_id.as_str()),
440            _ => true,
441        });
442    }
443
444    fn load_inactive_ids(&self, db: &Database, table: &str, ids: &[&str]) -> HashSet<String> {
445        if ids.is_empty() {
446            return HashSet::new();
447        }
448
449        let conn = db.conn();
450        let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("?{i}")).collect();
451        let sql = format!(
452            "SELECT id, memory_state FROM {table} WHERE id IN ({})",
453            placeholders.join(", ")
454        );
455        let mut stmt = match conn.prepare(&sql) {
456            Ok(stmt) => stmt,
457            Err(e) => {
458                tracing::warn!(error = %e, table, "failed to prepare inactive-memory query");
459                return HashSet::new();
460            }
461        };
462        let rows = match stmt.query_map(roboticus_db::params_from_iter(ids.iter()), |row| {
463            Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
464        }) {
465            Ok(rows) => rows,
466            Err(e) => {
467                tracing::warn!(error = %e, table, "failed to query inactive memories");
468                return HashSet::new();
469            }
470        };
471
472        let mut inactive = HashSet::new();
473        for row in rows {
474            match row {
475                Ok((id, state)) if !state.eq_ignore_ascii_case("active") => {
476                    inactive.insert(id);
477                }
478                Ok(_) => {}
479                Err(e) => tracing::warn!(error = %e, table, "skipping invalid memory-state row"),
480            }
481        }
482        inactive
483    }
484
485    fn retrieve_procedural(&self, db: &Database, budget_tokens: usize) -> Option<String> {
486        if budget_tokens == 0 {
487            return None;
488        }
489
490        // Retrieve all procedural entries and present those with meaningful history
491        let conn = db.conn();
492        let mut stmt = conn
493            .prepare(
494                "SELECT name, steps, success_count, failure_count FROM procedural_memory \
495                 WHERE success_count > 0 OR failure_count > 0 \
496                 ORDER BY success_count + failure_count DESC LIMIT 5",
497            )
498            .ok()?;
499
500        let rows: Vec<(String, String, i64, i64)> = stmt
501            .query_map([], |row| {
502                Ok((
503                    row.get::<_, String>(0)?,
504                    row.get::<_, String>(1)?,
505                    row.get::<_, i64>(2)?,
506                    row.get::<_, i64>(3)?,
507                ))
508            })
509            .inspect_err(|e| tracing::warn!("failed to query tool experience: {e}"))
510            .ok()?
511            .filter_map(|r| {
512                r.inspect_err(|e| tracing::warn!("skipping corrupted tool experience row: {e}"))
513                    .ok()
514            })
515            .collect();
516
517        if rows.is_empty() {
518            return None;
519        }
520
521        let mut text = String::from("[Tool Experience]\n");
522        let mut used = estimate_tokens(&text);
523
524        for (name, _steps, successes, failures) in &rows {
525            let total = *successes + *failures;
526            let rate = if total > 0 {
527                (*successes as f64 / total as f64 * 100.0) as u32
528            } else {
529                0
530            };
531            let line = format!("- {name}: {successes}/{total} success ({rate}%)\n");
532            let line_tokens = estimate_tokens(&line);
533            if used + line_tokens > budget_tokens {
534                break;
535            }
536            text.push_str(&line);
537            used += line_tokens;
538        }
539
540        if text.len() > "[Tool Experience]\n".len() {
541            Some(text)
542        } else {
543            None
544        }
545    }
546
547    fn retrieve_relationships(
548        &self,
549        db: &Database,
550        query: &str,
551        budget_tokens: usize,
552    ) -> Option<String> {
553        if budget_tokens == 0 {
554            return None;
555        }
556
557        let conn = db.conn();
558        let mut stmt = conn
559            .prepare(
560                "SELECT entity_id, entity_name, trust_score, interaction_count \
561                 FROM relationship_memory ORDER BY interaction_count DESC LIMIT 5",
562            )
563            .ok()?;
564
565        let rows: Vec<(String, Option<String>, f64, i64)> = stmt
566            .query_map([], |row| {
567                Ok((
568                    row.get::<_, String>(0)?,
569                    row.get::<_, Option<String>>(1)?,
570                    row.get::<_, f64>(2)?,
571                    row.get::<_, i64>(3)?,
572                ))
573            })
574            .inspect_err(|e| tracing::warn!("failed to query relationship memory: {e}"))
575            .ok()?
576            .filter_map(|r| {
577                r.inspect_err(|e| tracing::warn!("skipping corrupted relationship row: {e}"))
578                    .ok()
579            })
580            .collect();
581
582        if rows.is_empty() {
583            return None;
584        }
585
586        // Only include entities that might be relevant: name appears in query, or high interaction count
587        let query_lower = query.to_lowercase();
588        let relevant: Vec<_> = rows
589            .into_iter()
590            .filter(|(id, name, _, count)| {
591                *count > 2
592                    || query_lower.contains(&id.to_lowercase())
593                    || name
594                        .as_ref()
595                        .is_some_and(|n| query_lower.contains(&n.to_lowercase()))
596            })
597            .collect();
598
599        if relevant.is_empty() {
600            return None;
601        }
602
603        let mut text = String::from("[Known Entities]\n");
604        let mut used = estimate_tokens(&text);
605
606        for (entity_id, name, trust, count) in &relevant {
607            let display = name.as_deref().unwrap_or(entity_id);
608            let line = format!("- {display}: trust={trust:.1}, interactions={count}\n");
609            let line_tokens = estimate_tokens(&line);
610            if used + line_tokens > budget_tokens {
611                break;
612            }
613            text.push_str(&line);
614            used += line_tokens;
615        }
616
617        if text.len() > "[Known Entities]\n".len() {
618            Some(text)
619        } else {
620            None
621        }
622    }
623}
624
625fn query_requests_inactive_memories(query: &str) -> bool {
626    let lower = query.to_ascii_lowercase();
627    [
628        "history",
629        "historical",
630        "previous",
631        "previously",
632        "earlier",
633        "before",
634        "past",
635        "old",
636        "resolved",
637        "stale",
638        "archive",
639        "archived",
640    ]
641    .iter()
642    .any(|term| lower.contains(term))
643}
644
645fn estimate_tokens(text: &str) -> usize {
646    text.len().div_ceil(4)
647}
648
649// ── Content chunking ────────────────────────────────────────────
650
651pub struct ChunkConfig {
652    pub max_tokens: usize,
653    pub overlap_tokens: usize,
654}
655
656impl Default for ChunkConfig {
657    fn default() -> Self {
658        Self {
659            max_tokens: 512,
660            overlap_tokens: 64,
661        }
662    }
663}
664
665pub struct Chunk {
666    pub text: String,
667    pub index: usize,
668    pub start_char: usize,
669    pub end_char: usize,
670}
671
672/// Snap a byte offset to the nearest char boundary at or before `pos`.
673fn floor_char_boundary(text: &str, pos: usize) -> usize {
674    if pos >= text.len() {
675        return text.len();
676    }
677    let mut p = pos;
678    while p > 0 && !text.is_char_boundary(p) {
679        p -= 1;
680    }
681    p
682}
683
684/// Split text into overlapping chunks for embedding.
685pub fn chunk_text(text: &str, config: &ChunkConfig) -> Vec<Chunk> {
686    if text.is_empty() || config.max_tokens == 0 {
687        return Vec::new();
688    }
689
690    let max_bytes = config.max_tokens * 4;
691    let overlap_bytes = config.overlap_tokens * 4;
692
693    if text.len() <= max_bytes {
694        return vec![Chunk {
695            text: text.to_string(),
696            index: 0,
697            start_char: 0,
698            end_char: text.len(),
699        }];
700    }
701
702    let step = max_bytes.saturating_sub(overlap_bytes).max(1);
703    let mut chunks = Vec::new();
704    let mut start = 0;
705
706    while start < text.len() {
707        let raw_end = floor_char_boundary(text, (start + max_bytes).min(text.len()));
708
709        let end = find_break_point(text, start, raw_end);
710
711        chunks.push(Chunk {
712            text: text[start..end].to_string(),
713            index: chunks.len(),
714            start_char: start,
715            end_char: end,
716        });
717
718        if end >= text.len() {
719            break;
720        }
721
722        let advance = step.min(end - start).max(1);
723        start = floor_char_boundary(text, start + advance);
724    }
725
726    chunks
727}
728
729fn find_break_point(text: &str, start: usize, raw_end: usize) -> usize {
730    if raw_end >= text.len() {
731        return text.len();
732    }
733
734    let search_start = floor_char_boundary(text, start + (raw_end - start) / 2);
735    let window = &text[search_start..raw_end];
736
737    if let Some(pos) = window.rfind("\n\n") {
738        return search_start + pos + 2;
739    }
740    for delim in [". ", ".\n", "? ", "! "] {
741        if let Some(pos) = window.rfind(delim) {
742            return search_start + pos + delim.len();
743        }
744    }
745    if let Some(pos) = window.rfind(' ') {
746        return search_start + pos + 1;
747    }
748
749    raw_end
750}
751
752#[cfg(test)]
753mod tests {
754    use super::*;
755
756    fn test_db() -> Database {
757        Database::new(":memory:").unwrap()
758    }
759
760    fn default_config() -> MemoryConfig {
761        MemoryConfig::default()
762    }
763
764    #[test]
765    fn retriever_empty_db_returns_empty() {
766        let db = test_db();
767        let retriever = MemoryRetriever::new(default_config());
768        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
769        let result = retriever.retrieve(&db, &session_id, "hello", None, ComplexityLevel::L1);
770        assert!(result.is_empty());
771    }
772
773    #[test]
774    fn retriever_returns_working_memory() {
775        let db = test_db();
776        let retriever = MemoryRetriever::new(default_config());
777        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
778
779        roboticus_db::memory::store_working(&db, &session_id, "goal", "find documentation", 8)
780            .unwrap();
781
782        let result = retriever.retrieve(&db, &session_id, "hello", None, ComplexityLevel::L2);
783        assert!(result.contains("Working Memory"));
784        assert!(result.contains("find documentation"));
785    }
786
787    #[test]
788    fn retriever_skips_turn_summary_working_entries() {
789        let db = test_db();
790        let retriever = MemoryRetriever::new(default_config());
791        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
792
793        roboticus_db::memory::store_working(
794            &db,
795            &session_id,
796            "turn_summary",
797            "Good to be back on familiar ground.",
798            9,
799        )
800        .unwrap();
801        roboticus_db::memory::store_working(&db, &session_id, "goal", "fix Telegram loop", 8)
802            .unwrap();
803
804        let result = retriever.retrieve(&db, &session_id, "telegram", None, ComplexityLevel::L2);
805        assert!(result.contains("Working Memory"));
806        assert!(result.contains("fix Telegram loop"));
807        assert!(!result.contains("Good to be back on familiar ground."));
808    }
809
810    #[test]
811    fn retriever_returns_relevant_memories() {
812        let db = test_db();
813        let retriever = MemoryRetriever::new(default_config());
814        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
815
816        roboticus_db::memory::store_semantic(&db, "facts", "sky", "the sky is blue", 0.9).unwrap();
817
818        let result = retriever.retrieve(&db, &session_id, "sky", None, ComplexityLevel::L2);
819        assert!(result.contains("Active Memory"));
820    }
821
822    #[test]
823    fn retriever_returns_procedural_experience() {
824        let db = test_db();
825        let retriever = MemoryRetriever::new(default_config());
826        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
827
828        roboticus_db::memory::store_procedural(&db, "web_search", "search the web").unwrap();
829        roboticus_db::memory::record_procedural_success(&db, "web_search").unwrap();
830        roboticus_db::memory::record_procedural_success(&db, "web_search").unwrap();
831
832        let result = retriever.retrieve(&db, &session_id, "search", None, ComplexityLevel::L2);
833        assert!(result.contains("Tool Experience"));
834        assert!(result.contains("web_search"));
835    }
836
837    #[test]
838    fn retriever_returns_relationships() {
839        let db = test_db();
840        let retriever = MemoryRetriever::new(default_config());
841        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
842
843        roboticus_db::memory::store_relationship(&db, "user-1", "Jon", 0.9).unwrap();
844        // Need > 2 interactions or name in query
845        let result = retriever.retrieve(&db, &session_id, "Jon", None, ComplexityLevel::L2);
846        assert!(result.contains("Known Entities") || result.contains("Jon"));
847    }
848
849    #[test]
850    fn retriever_respects_zero_budget() {
851        let config = MemoryConfig {
852            working_budget_pct: 0.0,
853            episodic_budget_pct: 0.0,
854            semantic_budget_pct: 0.0,
855            procedural_budget_pct: 0.0,
856            relationship_budget_pct: 100.0,
857            ..default_config()
858        };
859        let db = test_db();
860        let retriever = MemoryRetriever::new(config);
861        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
862
863        roboticus_db::memory::store_working(&db, &session_id, "goal", "test", 5).unwrap();
864
865        let result = retriever.retrieve(&db, &session_id, "test", None, ComplexityLevel::L0);
866        assert!(!result.contains("Working Memory"));
867    }
868
869    #[test]
870    fn retriever_similarity_threshold_filters_low_similarity_results() {
871        let config = MemoryConfig {
872            similarity_threshold: 0.4,
873            ..default_config()
874        };
875        let db = test_db();
876        let retriever = MemoryRetriever::new(config);
877        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
878
879        let active_id = roboticus_db::memory::store_semantic(
880            &db,
881            "facts",
882            "high-match",
883            "deployment rollback stabilizes the incident",
884            0.9,
885        )
886        .unwrap();
887        let low_id = roboticus_db::memory::store_semantic(
888            &db,
889            "facts",
890            "low-match",
891            "botanical orchids in alpine valleys",
892            0.9,
893        )
894        .unwrap();
895
896        roboticus_db::embeddings::store_embedding(
897            &db,
898            "emb-high",
899            "semantic_memory",
900            &active_id,
901            "deployment rollback stabilizes the incident",
902            &[1.0, 0.0],
903        )
904        .unwrap();
905        roboticus_db::embeddings::store_embedding(
906            &db,
907            "emb-low",
908            "semantic_memory",
909            &low_id,
910            "botanical orchids in alpine valleys",
911            &[-1.0, 0.0],
912        )
913        .unwrap();
914
915        let output = retriever.retrieve_with_metrics(
916            &db,
917            &session_id,
918            "deployment rollback stabilizes the incident",
919            Some(&[1.0, 0.0]),
920            ComplexityLevel::L2,
921            None,
922        );
923        assert!(output.text.contains("deployment rollback"));
924        assert!(
925            !output.text.contains("botanical orchids"),
926            "low-similarity results should be removed by the threshold"
927        );
928        assert!(output.metrics.avg_similarity >= 0.4);
929    }
930
931    // ── Chunker tests ───────────────────────────────────────────
932
933    #[test]
934    fn chunk_empty_text() {
935        let chunks = chunk_text("", &ChunkConfig::default());
936        assert!(chunks.is_empty());
937    }
938
939    #[test]
940    fn chunk_short_text() {
941        let text = "This is a short sentence.";
942        let chunks = chunk_text(text, &ChunkConfig::default());
943        assert_eq!(chunks.len(), 1);
944        assert_eq!(chunks[0].text, text);
945        assert_eq!(chunks[0].index, 0);
946    }
947
948    #[test]
949    fn chunk_long_text_produces_overlapping_chunks() {
950        let text = "word ".repeat(1000);
951        let config = ChunkConfig {
952            max_tokens: 50,
953            overlap_tokens: 10,
954        };
955        let chunks = chunk_text(&text, &config);
956        assert!(chunks.len() > 1);
957
958        for (i, chunk) in chunks.iter().enumerate() {
959            assert_eq!(chunk.index, i);
960            assert!(!chunk.text.is_empty());
961        }
962
963        // Verify continuity: each chunk's start is before the previous chunk's end
964        for i in 1..chunks.len() {
965            assert!(chunks[i].start_char < chunks[i - 1].end_char);
966        }
967    }
968
969    #[test]
970    fn chunk_respects_sentence_boundaries() {
971        let text = "First sentence. Second sentence. Third sentence. Fourth sentence. Fifth sentence. \
972                    Sixth sentence. Seventh sentence. Eighth sentence. Ninth sentence. Tenth sentence.";
973        let config = ChunkConfig {
974            max_tokens: 20,
975            overlap_tokens: 5,
976        };
977        let chunks = chunk_text(text, &config);
978        // Chunks should end at sentence boundaries when possible
979        for chunk in &chunks {
980            if chunk.end_char < text.len() {
981                let ends_at_boundary = chunk.text.ends_with(". ")
982                    || chunk.text.ends_with('.')
983                    || chunk.text.ends_with(' ');
984                assert!(
985                    ends_at_boundary,
986                    "chunk should end at a boundary: {:?}",
987                    &chunk.text[chunk.text.len().saturating_sub(10)..]
988                );
989            }
990        }
991    }
992
993    #[test]
994    fn chunk_covers_full_text() {
995        let text = "a ".repeat(500);
996        let config = ChunkConfig {
997            max_tokens: 25,
998            overlap_tokens: 5,
999        };
1000        let chunks = chunk_text(&text, &config);
1001
1002        assert_eq!(chunks.first().unwrap().start_char, 0);
1003        assert_eq!(chunks.last().unwrap().end_char, text.len());
1004    }
1005
1006    #[test]
1007    fn chunk_zero_max_tokens() {
1008        let chunks = chunk_text(
1009            "some text",
1010            &ChunkConfig {
1011                max_tokens: 0,
1012                overlap_tokens: 0,
1013            },
1014        );
1015        assert!(chunks.is_empty());
1016    }
1017
1018    #[test]
1019    fn estimate_tokens_basic() {
1020        assert_eq!(estimate_tokens(""), 0);
1021        assert_eq!(estimate_tokens("abcd"), 1);
1022        assert_eq!(estimate_tokens("hello world!"), 3);
1023    }
1024
1025    #[test]
1026    fn chunk_multibyte_does_not_panic() {
1027        let text = "Hello \u{1F600} world. ".repeat(200);
1028        let config = ChunkConfig {
1029            max_tokens: 20,
1030            overlap_tokens: 5,
1031        };
1032        let chunks = chunk_text(&text, &config);
1033        assert!(chunks.len() > 1);
1034        for chunk in &chunks {
1035            assert!(!chunk.text.is_empty());
1036            // Verify each chunk is valid UTF-8 (would panic on slice if not)
1037            let _ = chunk.text.as_bytes();
1038        }
1039    }
1040
1041    #[test]
1042    fn chunk_cjk_text() {
1043        let text = "\u{4F60}\u{597D}\u{4E16}\u{754C} ".repeat(300);
1044        let config = ChunkConfig {
1045            max_tokens: 15,
1046            overlap_tokens: 3,
1047        };
1048        let chunks = chunk_text(&text, &config);
1049        assert!(chunks.len() > 1);
1050        assert_eq!(chunks.first().unwrap().start_char, 0);
1051        assert_eq!(chunks.last().unwrap().end_char, text.len());
1052    }
1053
1054    #[test]
1055    fn floor_char_boundary_ascii() {
1056        let text = "hello world";
1057        assert_eq!(floor_char_boundary(text, 5), 5);
1058        assert_eq!(floor_char_boundary(text, 0), 0);
1059        assert_eq!(floor_char_boundary(text, 100), text.len());
1060    }
1061
1062    #[test]
1063    fn floor_char_boundary_multibyte() {
1064        // "café" = c(1) a(1) f(1) é(2) = 5 bytes total
1065        let text = "caf\u{00E9}";
1066        assert_eq!(text.len(), 5);
1067        // Position 4 is inside the 2-byte é, should snap back to 3
1068        assert_eq!(floor_char_boundary(text, 4), 3);
1069        // Position 3 is a valid boundary (start of é)
1070        assert_eq!(floor_char_boundary(text, 3), 3);
1071        // Position 5 >= len, returns len
1072        assert_eq!(floor_char_boundary(text, 5), 5);
1073    }
1074
1075    #[test]
1076    fn floor_char_boundary_emoji() {
1077        let text = "a\u{1F600}b"; // a(1) + emoji(4) + b(1) = 6 bytes
1078        assert_eq!(text.len(), 6);
1079        // Position 2 is inside the emoji
1080        assert_eq!(floor_char_boundary(text, 2), 1);
1081        // Position 5 is the start of 'b'
1082        assert_eq!(floor_char_boundary(text, 5), 5);
1083    }
1084
1085    #[test]
1086    fn estimate_tokens_rounding() {
1087        // div_ceil(1, 4) = 1
1088        assert_eq!(estimate_tokens("a"), 1);
1089        // div_ceil(5, 4) = 2
1090        assert_eq!(estimate_tokens("abcde"), 2);
1091        // div_ceil(8, 4) = 2
1092        assert_eq!(estimate_tokens("abcdefgh"), 2);
1093    }
1094
1095    #[test]
1096    fn retriever_with_procedural_no_history() {
1097        // Procedural with no success/failure counts should return None
1098        let db = test_db();
1099        let retriever = MemoryRetriever::new(default_config());
1100        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
1101
1102        roboticus_db::memory::store_procedural(&db, "unused_tool", "a tool").unwrap();
1103
1104        let result = retriever.retrieve(&db, &session_id, "test", None, ComplexityLevel::L2);
1105        assert!(
1106            !result.contains("Tool Experience"),
1107            "tools with no success/failure should not appear"
1108        );
1109    }
1110
1111    #[test]
1112    fn chunk_with_paragraph_breaks() {
1113        let text = "Paragraph one content.\n\nParagraph two content.\n\nParagraph three content.\n\n\
1114                    Paragraph four content.\n\nParagraph five content.";
1115        let config = ChunkConfig {
1116            max_tokens: 15,
1117            overlap_tokens: 3,
1118        };
1119        let chunks = chunk_text(text, &config);
1120        // Should prefer breaking at paragraph boundaries
1121        for chunk in &chunks {
1122            if chunk.end_char < text.len() {
1123                // Many chunks should end at paragraph breaks
1124                let last_few = &chunk.text[chunk.text.len().saturating_sub(5)..];
1125                let has_good_break =
1126                    last_few.contains('\n') || last_few.contains(". ") || last_few.ends_with(' ');
1127                assert!(has_good_break, "chunk should end at a reasonable boundary");
1128            }
1129        }
1130    }
1131
1132    #[test]
1133    fn chunk_config_default() {
1134        let config = ChunkConfig::default();
1135        assert_eq!(config.max_tokens, 512);
1136        assert_eq!(config.overlap_tokens, 64);
1137    }
1138
1139    #[test]
1140    fn find_break_point_at_end_of_text() {
1141        let text = "Hello world.";
1142        assert_eq!(find_break_point(text, 0, text.len()), text.len());
1143    }
1144
1145    #[test]
1146    fn retriever_relationships_high_interaction_count() {
1147        let db = test_db();
1148        let retriever = MemoryRetriever::new(default_config());
1149        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
1150
1151        // store_relationship uses ON CONFLICT to increment interaction_count
1152        // Calling it 4 times gives interaction_count > 2
1153        for _ in 0..4 {
1154            roboticus_db::memory::store_relationship(&db, "alice", "Alice Smith", 0.8).unwrap();
1155        }
1156
1157        // Query that doesn't contain "alice" but high interaction count should still include it
1158        let result = retriever.retrieve(
1159            &db,
1160            &session_id,
1161            "some random query",
1162            None,
1163            ComplexityLevel::L2,
1164        );
1165        assert!(
1166            result.contains("Known Entities") && result.contains("Alice Smith"),
1167            "high interaction count entity should appear in results"
1168        );
1169    }
1170
1171    #[test]
1172    fn retriever_suppresses_stale_digests_by_default() {
1173        let db = test_db();
1174        let retriever = MemoryRetriever::new(default_config());
1175        let session_id = roboticus_db::sessions::find_or_create(&db, "agent-1", None).unwrap();
1176        let stale_id = roboticus_db::memory::store_episodic_with_meta(
1177            &db,
1178            "digest",
1179            "[Session Digest] alpha rollout incident resolved",
1180            9,
1181            Some("agent-1"),
1182            "active",
1183            None,
1184        )
1185        .unwrap();
1186        roboticus_db::memory::mark_episodic_digests_stale_for_owner(
1187            &db,
1188            "agent-1",
1189            "newer-digest",
1190            "superseded",
1191        )
1192        .unwrap();
1193        let conn = db.conn();
1194        conn.execute(
1195            "UPDATE episodic_memory SET memory_state = 'stale' WHERE id = ?1",
1196            [stale_id],
1197        )
1198        .unwrap();
1199        drop(conn);
1200        roboticus_db::memory::store_episodic_with_meta(
1201            &db,
1202            "digest",
1203            "[Session Digest] beta stabilization plan active",
1204            9,
1205            Some("agent-1"),
1206            "active",
1207            None,
1208        )
1209        .unwrap();
1210
1211        let result = retriever.retrieve(
1212            &db,
1213            &session_id,
1214            "alpha beta digest",
1215            None,
1216            ComplexityLevel::L2,
1217        );
1218        assert!(result.contains("beta stabilization plan"));
1219        assert!(
1220            !result.contains("alpha rollout incident resolved"),
1221            "stale digests should be suppressed unless the query explicitly asks for history"
1222        );
1223    }
1224
1225    #[test]
1226    fn retriever_includes_stale_digests_when_history_requested() {
1227        let db = test_db();
1228        let retriever = MemoryRetriever::new(default_config());
1229        let session_id = roboticus_db::sessions::find_or_create(&db, "agent-1", None).unwrap();
1230        let stale_id = roboticus_db::memory::store_episodic_with_meta(
1231            &db,
1232            "digest",
1233            "[Session Digest] alpha rollout incident resolved",
1234            9,
1235            Some("agent-1"),
1236            "stale",
1237            Some("superseded"),
1238        )
1239        .unwrap();
1240        assert!(!stale_id.is_empty());
1241        roboticus_db::memory::store_episodic_with_meta(
1242            &db,
1243            "digest",
1244            "[Session Digest] beta stabilization plan active",
1245            9,
1246            Some("agent-1"),
1247            "active",
1248            None,
1249        )
1250        .unwrap();
1251
1252        let result = retriever.retrieve(
1253            &db,
1254            &session_id,
1255            "show previous history for the alpha beta digest",
1256            None,
1257            ComplexityLevel::L2,
1258        );
1259        assert!(result.contains("alpha rollout incident resolved"));
1260        assert!(result.contains("beta stabilization plan active"));
1261    }
1262
1263    #[test]
1264    fn retriever_suppresses_stale_semantic_summaries_by_default() {
1265        let db = test_db();
1266        let retriever = MemoryRetriever::new(default_config());
1267        let session_id = roboticus_db::sessions::find_or_create(&db, "agent-1", None).unwrap();
1268
1269        roboticus_db::memory::store_semantic(
1270            &db,
1271            "learned",
1272            "session:agent-1:alpha",
1273            "alpha policy was retired after the incident",
1274            0.8,
1275        )
1276        .unwrap();
1277        let active_id = roboticus_db::memory::store_semantic(
1278            &db,
1279            "learned",
1280            "session:agent-1:beta",
1281            "beta policy is active with the latest safeguards",
1282            0.9,
1283        )
1284        .unwrap();
1285        roboticus_db::memory::mark_semantic_stale_by_category_and_key_prefix(
1286            &db,
1287            "learned",
1288            "session:agent-1:",
1289            &active_id,
1290            "superseded_by_newer_session_summary",
1291        )
1292        .unwrap();
1293
1294        let result = retriever.retrieve(
1295            &db,
1296            &session_id,
1297            "alpha beta policy safeguards",
1298            None,
1299            ComplexityLevel::L2,
1300        );
1301        assert!(result.contains("beta policy is active"));
1302        assert!(
1303            !result.contains("alpha policy was retired"),
1304            "stale semantic summaries should be suppressed unless history is requested"
1305        );
1306    }
1307
1308    #[test]
1309    fn retriever_includes_stale_semantic_summaries_when_history_requested() {
1310        let db = test_db();
1311        let retriever = MemoryRetriever::new(default_config());
1312        let session_id = roboticus_db::sessions::find_or_create(&db, "agent-1", None).unwrap();
1313
1314        roboticus_db::memory::store_semantic(
1315            &db,
1316            "learned",
1317            "session:agent-1:alpha",
1318            "alpha policy was retired after the incident",
1319            0.8,
1320        )
1321        .unwrap();
1322        let active_id = roboticus_db::memory::store_semantic(
1323            &db,
1324            "learned",
1325            "session:agent-1:beta",
1326            "beta policy is active with the latest safeguards",
1327            0.9,
1328        )
1329        .unwrap();
1330        roboticus_db::memory::mark_semantic_stale_by_category_and_key_prefix(
1331            &db,
1332            "learned",
1333            "session:agent-1:",
1334            &active_id,
1335            "superseded_by_newer_session_summary",
1336        )
1337        .unwrap();
1338
1339        let result = retriever.retrieve(
1340            &db,
1341            &session_id,
1342            "show history of the alpha beta policy change",
1343            None,
1344            ComplexityLevel::L2,
1345        );
1346        assert!(result.contains("beta policy is active"));
1347        assert!(result.contains("alpha policy was retired"));
1348    }
1349
1350    #[test]
1351    fn retrieve_with_metrics_empty_db() {
1352        let db = test_db();
1353        let retriever = MemoryRetriever::new(default_config());
1354        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
1355        let output = retriever.retrieve_with_metrics(
1356            &db,
1357            &session_id,
1358            "hello",
1359            None,
1360            ComplexityLevel::L1,
1361            None,
1362        );
1363        assert!(output.text.is_empty());
1364        assert!(!output.metrics.retrieval_hit);
1365        assert_eq!(output.metrics.retrieval_count, 0);
1366        assert_eq!(output.metrics.avg_similarity, 0.0);
1367        assert_eq!(output.metrics.budget_utilization, 0.0);
1368    }
1369
1370    #[test]
1371    fn retrieve_with_metrics_working_memory_counted() {
1372        let db = test_db();
1373        let retriever = MemoryRetriever::new(default_config());
1374        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
1375
1376        roboticus_db::memory::store_working(&db, &session_id, "goal", "fix the pipeline", 8)
1377            .unwrap();
1378        roboticus_db::memory::store_working(&db, &session_id, "note", "version 0.11", 7).unwrap();
1379
1380        let output = retriever.retrieve_with_metrics(
1381            &db,
1382            &session_id,
1383            "hello",
1384            None,
1385            ComplexityLevel::L2,
1386            None,
1387        );
1388        assert!(output.metrics.retrieval_hit);
1389        assert!(
1390            output.metrics.tiers.working >= 2,
1391            "working tier count should reflect stored entries"
1392        );
1393        assert!(output.metrics.retrieval_count >= 2);
1394        assert!(output.metrics.budget_utilization > 0.0);
1395
1396        // Serialization should produce valid JSON
1397        let json = serde_json::to_string(&output.metrics.tiers).unwrap();
1398        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
1399        assert!(parsed["working"].as_u64().unwrap() >= 2);
1400    }
1401
1402    #[test]
1403    fn retrieve_with_metrics_procedural_counted() {
1404        let db = test_db();
1405        let retriever = MemoryRetriever::new(default_config());
1406        let session_id = roboticus_db::sessions::find_or_create(&db, "test-agent", None).unwrap();
1407
1408        roboticus_db::memory::store_procedural(&db, "web_search", "search the web").unwrap();
1409        roboticus_db::memory::record_procedural_success(&db, "web_search").unwrap();
1410
1411        let output = retriever.retrieve_with_metrics(
1412            &db,
1413            &session_id,
1414            "search",
1415            None,
1416            ComplexityLevel::L2,
1417            None,
1418        );
1419        assert!(output.metrics.tiers.procedural >= 1);
1420    }
1421}