oxios_memory/memory/
proactive.rs1use std::collections::HashSet;
11
12use anyhow::Result;
13
14use crate::memory::decay::DecayEngine;
15use crate::memory::manager::MemoryManager;
16use crate::memory::types::{MemoryEntry, MemoryTier};
17
18#[derive(Debug, Clone, Default)]
24pub struct RecallTiming {
25 pub last_recall_topic: Option<String>,
27 pub message_count_since_recall: usize,
29}
30
31impl RecallTiming {
32 pub fn new() -> Self {
34 Self {
35 last_recall_topic: None,
36 message_count_since_recall: 0,
37 }
38 }
39
40 pub fn should_recall(&mut self, query: &str) -> bool {
47 let topic_changed = self
48 .last_recall_topic
49 .as_ref()
50 .is_none_or(|prev| !topics_similar(prev, query));
51
52 let should = self.message_count_since_recall == 0 || (topic_changed && self.message_count_since_recall >= 3) || self.message_count_since_recall >= 10; if should {
57 self.last_recall_topic = Some(query.to_string());
58 self.message_count_since_recall = 0;
59 } else {
60 self.message_count_since_recall += 1;
61 }
62 should
63 }
64}
65
/// Reports whether two pieces of text appear to be about the same topic.
///
/// Each text is lowercased and split on whitespace; words of three bytes or
/// fewer are ignored. The remaining word sets are compared by the ratio of
/// shared words to distinct words overall, and the texts count as similar when
/// that ratio exceeds 0.3. If either text has no word longer than three bytes
/// the result is `false`.
fn topics_similar(a: &str, b: &str) -> bool {
    /// Collects the lowercased words of `text` that are longer than three bytes.
    fn keywords(text: &str) -> HashSet<String> {
        text.to_lowercase()
            .split_whitespace()
            .filter(|word| word.len() > 3)
            .map(str::to_owned)
            .collect()
    }

    let left = keywords(a);
    let right = keywords(b);
    if left.is_empty() || right.is_empty() {
        return false;
    }

    let shared = left.intersection(&right).count();
    // Size of the union by inclusion-exclusion: |A| + |B| - |A ∩ B|.
    let total = left.len() + right.len() - shared;
    shared as f32 / total as f32 > 0.3
}
90
91pub struct ProactiveRecall {
100 pub limit: usize,
102 pub threshold: f32,
104}
105
106impl ProactiveRecall {
107 pub fn new(limit: usize, threshold: f32) -> Self {
109 Self { limit, threshold }
110 }
111
112 pub async fn recall(
114 &self,
115 mgr: &MemoryManager,
116 query: &str,
117 current_context: &[MemoryEntry],
118 ) -> Result<Vec<MemoryEntry>> {
119 let mut results = Vec::new();
120 let mut seen_ids: HashSet<String> = current_context.iter().map(|e| e.id.clone()).collect();
121
122 if let Ok(hot_entries) = mgr.list_by_tier(MemoryTier::Hot, self.limit).await {
124 for entry in hot_entries {
125 if !seen_ids.contains(&entry.id) {
126 seen_ids.insert(entry.id.clone());
127 results.push(entry);
128 }
129 }
130 }
131
132 if results.len() < self.limit {
134 let remaining = self.limit - results.len();
135 let search_results = mgr
136 .search(query, None, remaining * 2)
137 .await
138 .unwrap_or_default();
139 for entry in search_results {
140 if !seen_ids.contains(&entry.id) {
141 seen_ids.insert(entry.id.clone());
142 results.push(entry);
143 }
144 if results.len() >= self.limit {
145 break;
146 }
147 }
148 }
149
150 results.retain(|e| DecayEngine::effective_importance(e) >= self.threshold);
152
153 Ok(results)
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created tracker triggers a recall on its first message.
    #[test]
    fn test_recall_timing_first_message() {
        let mut tracker = RecallTiming::new();
        let fired = tracker.should_recall("hello");
        assert!(fired);
    }

    /// With five messages since the last recall, an unrelated topic triggers
    /// a new recall.
    #[test]
    fn test_recall_timing_topic_change() {
        let mut tracker = RecallTiming::new();
        tracker.should_recall("rust programming");
        // Seed the counter directly instead of sending filler messages.
        tracker.message_count_since_recall = 5;
        let fired = tracker.should_recall("python deployment");
        assert!(fired);
    }

    /// One message after a recall, a query on a similar topic does not
    /// trigger another recall.
    #[test]
    fn test_recall_timing_same_topic() {
        let mut tracker = RecallTiming::new();
        tracker.should_recall("rust async runtime");
        tracker.message_count_since_recall = 1;
        let fired = tracker.should_recall("rust async tokio");
        assert!(!fired);
    }

    /// Ten messages after a recall, a new recall fires even though the topic
    /// is still similar.
    #[test]
    fn test_recall_timing_periodic() {
        let mut tracker = RecallTiming::new();
        tracker.should_recall("rust");
        tracker.message_count_since_recall = 10;
        let fired = tracker.should_recall("rust continued");
        assert!(fired);
    }

    /// Identical texts are similar.
    #[test]
    fn test_topics_similar_same() {
        let text = "rust async runtime";
        assert!(topics_similar(text, text));
    }

    /// Texts sharing three of five distinct keywords are similar.
    #[test]
    fn test_topics_similar_overlap() {
        let first = "rust async runtime tokio";
        let second = "rust async runtime futures";
        assert!(topics_similar(first, second));
    }

    /// Texts with no keywords in common are not similar.
    #[test]
    fn test_topics_different() {
        let first = "rust async runtime";
        let second = "python data science";
        assert!(!topics_similar(first, second));
    }
}