Skip to main content

oxios_memory/memory/
proactive.rs

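//! Proactive recall — automatically inject relevant memories into context.
//!
//! The module targets a 3-step selective recall:
//! 1. ROOT index triage (O(1) topic lookup)
//! 2. Manifest-based selection (keyword matching)
//! 3. HNSW vector search (semantic similarity)
//!
//! NOTE: the current [`ProactiveRecall::recall`] implementation performs two
//! steps. It first takes HOT-tier memories, then fills any remaining slots
//! from `MemoryManager::search`.
//!
//! Recall is gated by [`RecallTiming`] to avoid context bloat. It is meant to
//! fire on the first message of a session, on topic transitions, and
//! periodically.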
9
10use std::collections::HashSet;
11
12use anyhow::Result;
13
14use crate::memory::decay::DecayEngine;
15use crate::memory::manager::MemoryManager;
16use crate::memory::types::{MemoryEntry, MemoryTier};
17
18// ---------------------------------------------------------------------------
19// RecallTiming
20// ---------------------------------------------------------------------------
21
/// Tracks when proactive recall should be triggered.
///
/// The tracker remembers the query that triggered the most recent recall and
/// how many messages have been seen since then. [`RecallTiming::should_recall`]
/// uses both to decide whether the current message warrants a new recall.
#[derive(Debug, Clone, Default)]
pub struct RecallTiming {
    /// Query that triggered the most recent recall (`None` until the first recall).
    pub last_recall_topic: Option<String>,
    /// Messages seen since the last recall (reset to 0 whenever a recall fires).
    pub message_count_since_recall: usize,
}

impl RecallTiming {
    /// Minimum messages since the last recall before a topic change can trigger one.
    const TOPIC_CHANGE_MIN_MESSAGES: usize = 3;
    /// A recall fires after this many messages regardless of topic.
    const PERIODIC_INTERVAL: usize = 10;

    /// Create a new timing tracker.
    pub fn new() -> Self {
        Self::default()
    }

    /// Check if proactive recall should fire for the given query.
    ///
    /// Each call counts as one message. A recall fires on:
    /// - the first message ever seen (no recall has happened yet);
    /// - a topic change, once at least 3 messages have passed since the last recall;
    /// - periodically, once 10 messages have passed since the last recall.
    ///
    /// When a recall fires, `query` becomes the reference topic and the message
    /// counter is reset to 0.
    pub fn should_recall(&mut self, query: &str) -> bool {
        // The first message must be detected from the absence of a previous
        // recall. The counter cannot be used for this: it is reset to 0 after
        // every recall, which would make every following message look "first".
        let first_message = self.last_recall_topic.is_none();

        // Count the current message. Saturate because the field is public and
        // could be set to any value by callers.
        self.message_count_since_recall = self.message_count_since_recall.saturating_add(1);
        let count = self.message_count_since_recall;

        let topic_changed = self
            .last_recall_topic
            .as_deref()
            .is_none_or(|prev| !topics_similar(prev, query));

        let should = first_message
            || (topic_changed && count >= Self::TOPIC_CHANGE_MIN_MESSAGES)
            || count >= Self::PERIODIC_INTERVAL;

        if should {
            self.last_recall_topic = Some(query.to_string());
            self.message_count_since_recall = 0;
        }
        should
    }
}

/// Simple topic similarity check (keyword overlap).
///
/// Both inputs are reduced to keyword sets (see [`topic_keywords`]) and
/// compared with Jaccard similarity. The topics count as similar when the
/// similarity exceeds 0.3. Returns `false` if either input yields no keywords.
fn topics_similar(a: &str, b: &str) -> bool {
    let a_words = topic_keywords(a);
    let b_words = topic_keywords(b);

    if a_words.is_empty() || b_words.is_empty() {
        return false;
    }

    let overlap = a_words.intersection(&b_words).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|; avoids iterating the union.
    let union = a_words.len() + b_words.len() - overlap;
    // Jaccard similarity > 0.3
    overlap as f32 / union as f32 > 0.3
}

/// Extract the keyword set used by [`topics_similar`].
///
/// The input is split on whitespace, and leading/trailing punctuation is
/// stripped from each token (so `"rust,"` matches `"rust"`). Tokens are then
/// lowercased, and only those longer than three *characters* are kept. Length
/// is measured in chars, not bytes, so non-ASCII words are treated like ASCII ones.
fn topic_keywords(text: &str) -> HashSet<String> {
    text.split_whitespace()
        .map(|word| {
            word.trim_matches(|c: char| !c.is_alphanumeric())
                .to_lowercase()
        })
        .filter(|word| word.chars().count() > 3)
        .collect()
}
90
91// ---------------------------------------------------------------------------
92// ProactiveRecall
93// ---------------------------------------------------------------------------
94
95/// Proactive recall engine.
96///
97/// Combines ROOT index triage, manifest-based selection, and HNSW
98/// semantic search to find relevant memories for the current context.
99pub struct ProactiveRecall {
100    /// Maximum results to return.
101    pub limit: usize,
102    /// Minimum effective importance threshold.
103    pub threshold: f32,
104}
105
106impl ProactiveRecall {
107    /// Create with the given limit and threshold.
108    pub fn new(limit: usize, threshold: f32) -> Self {
109        Self { limit, threshold }
110    }
111
112    /// Execute 3-step proactive recall.
113    pub async fn recall(
114        &self,
115        mgr: &MemoryManager,
116        query: &str,
117        current_context: &[MemoryEntry],
118    ) -> Result<Vec<MemoryEntry>> {
119        let mut results = Vec::new();
120        let mut seen_ids: HashSet<String> = current_context.iter().map(|e| e.id.clone()).collect();
121
122        // Step 1: HOT tier memories (always included)
123        if let Ok(hot_entries) = mgr.list_by_tier(MemoryTier::Hot, self.limit).await {
124            for entry in hot_entries {
125                if !seen_ids.contains(&entry.id) {
126                    seen_ids.insert(entry.id.clone());
127                    results.push(entry);
128                }
129            }
130        }
131
132        // Step 2: Semantic + BM25 search
133        if results.len() < self.limit {
134            let remaining = self.limit - results.len();
135            let search_results = mgr
136                .search(query, None, remaining * 2)
137                .await
138                .unwrap_or_default();
139            for entry in search_results {
140                if !seen_ids.contains(&entry.id) {
141                    seen_ids.insert(entry.id.clone());
142                    results.push(entry);
143                }
144                if results.len() >= self.limit {
145                    break;
146                }
147            }
148        }
149
150        // Filter by importance threshold
151        results.retain(|e| DecayEngine::effective_importance(e) >= self.threshold);
152
153        Ok(results)
154    }
155}
156
157// ---------------------------------------------------------------------------
158// Tests
159// ---------------------------------------------------------------------------
160
#[cfg(test)]
mod tests {
    use super::*;

    // -- RecallTiming -------------------------------------------------------

    /// A brand-new tracker recalls on the very first message.
    #[test]
    fn test_recall_timing_first_message() {
        let mut tracker = RecallTiming::new();
        let fired = tracker.should_recall("hello");
        assert!(fired);
    }

    /// A disjoint topic after several messages triggers a recall.
    #[test]
    fn test_recall_timing_topic_change() {
        let mut tracker = RecallTiming::new();
        tracker.should_recall("rust programming");
        // Pretend a handful of messages have gone by on the first topic.
        tracker.message_count_since_recall = 5;
        let fired = tracker.should_recall("python deployment");
        assert!(fired);
    }

    /// Staying on an overlapping topic shortly after a recall does not re-fire.
    #[test]
    fn test_recall_timing_same_topic() {
        let mut tracker = RecallTiming::new();
        tracker.should_recall("rust async runtime");
        tracker.message_count_since_recall = 1;
        let fired = tracker.should_recall("rust async tokio");
        assert!(!fired);
    }

    /// Reaching the periodic message count forces a recall.
    #[test]
    fn test_recall_timing_periodic() {
        let mut tracker = RecallTiming::new();
        tracker.should_recall("rust");
        tracker.message_count_since_recall = 10;
        let fired = tracker.should_recall("rust continued");
        assert!(fired);
    }

    // -- topics_similar -----------------------------------------------------

    /// Identical inputs are similar.
    #[test]
    fn test_topics_similar_same() {
        let similar = topics_similar("rust async runtime", "rust async runtime");
        assert!(similar);
    }

    /// Three of five distinct keywords shared is above the similarity cutoff.
    #[test]
    fn test_topics_similar_overlap() {
        let left = "rust async runtime tokio";
        let right = "rust async runtime futures";
        assert!(topics_similar(left, right));
    }

    /// Inputs with no shared keywords are not similar.
    #[test]
    fn test_topics_different() {
        let similar = topics_similar("rust async runtime", "python data science");
        assert!(!similar);
    }
}