1use super::MemoryChunk;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::{Duration, Instant};
12
/// Settings that control active memory recall.
///
/// The container-level `#[serde(default)]` means any field missing from the
/// serialized form is filled in from this type's `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ActiveMemoryConfig {
    /// Master switch for the feature; `false` in the default configuration.
    pub enabled: bool,
    /// Selects how `build_query` assembles the search query.
    pub query_mode: QueryMode,
    /// In `QueryMode::Recent`, how many trailing conversation messages
    /// `build_query` places ahead of the current user message.
    pub max_recent_turns: usize,
    /// Presumably the maximum number of memory chunks requested per recall.
    /// Not consumed by the code visible here; confirm against the caller.
    pub max_results: usize,
    /// Presumably the size budget handed to `format_recalled_context`.
    /// Not consumed by the code visible here; confirm against the caller.
    pub max_chars: usize,
    /// Presumably the minimum relevance score a chunk needs to be used.
    /// Not consumed by the code visible here; confirm against the caller.
    pub min_score: f64,
    /// Presumably the time-to-live, in milliseconds, passed to
    /// `RecallCache::new`. Not consumed by the code visible here; confirm
    /// against the caller.
    pub cache_ttl_ms: u64,
}
32
33impl Default for ActiveMemoryConfig {
34 fn default() -> Self {
35 Self {
36 enabled: false,
37 query_mode: QueryMode::Message,
38 max_recent_turns: 4,
39 max_results: 3,
40 max_chars: 500,
41 min_score: 0.1,
42 cache_ttl_ms: 15_000,
43 }
44 }
45}
46
/// Strategy `build_query` uses to turn the conversation into a search query.
///
/// Because of `rename_all = "lowercase"`, the variants serialize as
/// `"message"` and `"recent"`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum QueryMode {
    /// Use only the current user message as the query. This is the default.
    #[default]
    Message,
    /// Prefix the current user message with the most recent conversation
    /// messages (at most `max_recent_turns` of them).
    Recent,
}
57
/// Outcome of a recall attempt.
///
/// NOTE(review): nothing in the code visible here constructs or matches this
/// enum, so the variant descriptions below are inferred from the names alone;
/// confirm them against the call site.
#[derive(Debug)]
pub enum RecallResult {
    /// Presumably a fresh lookup that produced context text (the payload).
    Recalled(String),
    /// Presumably a lookup that ran but produced nothing usable.
    Empty,
    /// Presumably recall was skipped because the feature is turned off.
    Disabled,
    /// Presumably the payload was served from the cache instead of a lookup.
    CacheHit(String),
}
70
/// Time- and size-bounded cache of recall results, keyed by query hash.
///
/// The cached value is an `Option<String>`: `None` records that a query
/// produced no context, so negative results are cached as well.
pub struct RecallCache {
    entries: HashMap<u64, CacheEntry>,
    ttl: Duration,
}

/// A cached result together with the moment it was stored.
struct CacheEntry {
    result: Option<String>,
    created_at: Instant,
}

impl RecallCache {
    /// Hard upper bound on the number of entries held at once.
    const MAX_ENTRIES: usize = 100;

    /// Creates an empty cache whose entries stay valid for `ttl_ms`
    /// milliseconds. A TTL of zero means nothing is ever served from cache.
    pub fn new(ttl_ms: u64) -> Self {
        Self {
            entries: HashMap::new(),
            ttl: Duration::from_millis(ttl_ms),
        }
    }

    /// Looks up the result stored for `query_hash`.
    ///
    /// Returns `None` when there is no entry or the entry is older than the
    /// TTL. Otherwise returns `Some(&cached)`, where `cached` itself may be
    /// `None` for a cached "nothing recalled" result. Expired entries are not
    /// removed here; `put` takes care of pruning.
    pub fn get(&self, query_hash: u64) -> Option<&Option<String>> {
        self.entries
            .get(&query_hash)
            .filter(|entry| entry.created_at.elapsed() < self.ttl)
            .map(|entry| &entry.result)
    }

    /// Stores `result` under `query_hash`, replacing any previous entry.
    ///
    /// When a new key would push the cache past [`Self::MAX_ENTRIES`],
    /// expired entries are dropped first. If every entry is still fresh, the
    /// oldest one is evicted so the map cannot grow without bound during a
    /// burst of distinct queries.
    pub fn put(&mut self, query_hash: u64, result: Option<String>) {
        // Overwriting an existing key does not grow the map, so no pruning
        // is needed in that case.
        let adds_new_key = !self.entries.contains_key(&query_hash);
        if adds_new_key && self.entries.len() >= Self::MAX_ENTRIES {
            // Copy the TTL out so the closure does not borrow `self` while
            // `entries` is mutably borrowed.
            let ttl = self.ttl;
            self.entries
                .retain(|_, entry| entry.created_at.elapsed() < ttl);

            if self.entries.len() >= Self::MAX_ENTRIES {
                let oldest_key = self
                    .entries
                    .iter()
                    .min_by_key(|(_, entry)| entry.created_at)
                    .map(|(key, _)| *key);
                if let Some(key) = oldest_key {
                    self.entries.remove(&key);
                }
            }
        }

        self.entries.insert(
            query_hash,
            CacheEntry {
                result,
                created_at: Instant::now(),
            },
        );
    }
}
115
116pub fn build_query(
118 user_message: &str,
119 recent_messages: &[(String, String)], config: &ActiveMemoryConfig,
121) -> String {
122 match config.query_mode {
123 QueryMode::Message => user_message.to_string(),
124 QueryMode::Recent => {
125 let mut parts = Vec::new();
126 let start = recent_messages
127 .len()
128 .saturating_sub(config.max_recent_turns);
129 for (role, content) in &recent_messages[start..] {
130 let truncated = if content.len() > 200 {
132 &content[..200]
133 } else {
134 content.as_str()
135 };
136 parts.push(format!("{}: {}", role, truncated));
137 }
138 parts.push(format!("user: {}", user_message));
139 parts.join("\n")
140 }
141 }
142}
143
144pub fn format_recalled_context(chunks: &[MemoryChunk], max_chars: usize) -> Option<String> {
146 if chunks.is_empty() {
147 return None;
148 }
149
150 let mut parts = Vec::new();
151 let mut total_chars = 0;
152
153 for chunk in chunks {
154 let entry = chunk.content.trim();
155 if entry.is_empty() {
156 continue;
157 }
158
159 if total_chars + entry.len() > max_chars {
160 let remaining = max_chars.saturating_sub(total_chars);
162 if remaining > 50 {
163 parts.push(format!("- {}...", &entry[..remaining.min(entry.len())]));
164 }
165 break;
166 }
167
168 parts.push(format!("- {}", entry));
169 total_chars += entry.len();
170 }
171
172 if parts.is_empty() {
173 return None;
174 }
175
176 Some(format!(
177 "<recalled_context>\nThe following was automatically recalled from memory and may be relevant:\n{}\n</recalled_context>",
178 parts.join("\n")
179 ))
180}
181
/// Hashes a query string into the `u64` key used by the recall cache.
///
/// The value comes from the standard library's `DefaultHasher`, whose
/// algorithm is not guaranteed to stay the same across Rust releases, so the
/// result is only suitable as an in-process key and should not be persisted.
pub fn query_hash(query: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(query, &mut state);
    Hasher::finish(&state)
}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Message mode passes the user's text through untouched.
    #[test]
    fn test_build_query_message_mode() {
        let message_only = ActiveMemoryConfig {
            query_mode: QueryMode::Message,
            ..Default::default()
        };

        assert_eq!(
            build_query("What color do I prefer?", &[], &message_only),
            "What color do I prefer?"
        );
    }

    /// Recent mode keeps only the last `max_recent_turns` messages plus the
    /// current user message.
    #[test]
    fn test_build_query_recent_mode() {
        let last_two = ActiveMemoryConfig {
            query_mode: QueryMode::Recent,
            max_recent_turns: 2,
            ..Default::default()
        };
        let history: Vec<(String, String)> = [
            ("user", "Hello"),
            ("assistant", "Hi there!"),
            ("user", "Tell me about colors"),
            ("assistant", "Sure, what colors?"),
        ]
        .iter()
        .map(|&(role, text)| (role.to_string(), text.to_string()))
        .collect();

        let built = build_query("What color do I prefer?", &history, &last_two);

        for expected in [
            "Tell me about colors",
            "Sure, what colors?",
            "What color do I prefer?",
        ] {
            assert!(built.contains(expected));
        }
        // The oldest message falls outside the two-message window.
        assert!(!built.contains("Hello"));
    }

    /// No chunks means no context block.
    #[test]
    fn test_format_recalled_context_empty() {
        let rendered = format_recalled_context(&[], 500);
        assert!(rendered.is_none());
    }

    /// Every chunk that fits the budget appears inside the wrapper.
    #[test]
    fn test_format_recalled_context_basic() {
        let dark_mode = MemoryChunk {
            file: "test.md".to_string(),
            line_start: 1,
            line_end: 1,
            content: "User prefers dark mode".to_string(),
            score: 0.9,
            updated_at: 0,
        };
        let employer = MemoryChunk {
            file: "test.md".to_string(),
            line_start: 2,
            line_end: 2,
            content: "User works at Acme Corp".to_string(),
            score: 0.7,
            updated_at: 0,
        };

        let rendered = format_recalled_context(&[dark_mode, employer], 500).unwrap();

        for expected in [
            "<recalled_context>",
            "User prefers dark mode",
            "User works at Acme Corp",
        ] {
            assert!(rendered.contains(expected));
        }
    }

    /// An oversized chunk is cut down and marked with an ellipsis.
    #[test]
    fn test_format_recalled_context_truncation() {
        let oversized = MemoryChunk {
            file: "test.md".to_string(),
            line_start: 1,
            line_end: 1,
            content: "A".repeat(600),
            score: 0.9,
            updated_at: 0,
        };

        let rendered = format_recalled_context(&[oversized], 100).unwrap();

        assert!(rendered.len() < 600);
        assert!(rendered.contains("..."));
    }

    /// Misses, positive hits and cached negative results are all distinguishable.
    #[test]
    fn test_recall_cache() {
        let mut cache = RecallCache::new(60_000);
        assert!(cache.get(123).is_none());

        cache.put(123, Some("recalled text".to_string()));
        let positive = cache.get(123).unwrap();
        assert_eq!(positive.as_deref(), Some("recalled text"));

        cache.put(456, None);
        let negative = cache.get(456).unwrap();
        assert!(negative.is_none());
    }

    /// With a zero TTL an entry is stale as soon as it is stored.
    #[test]
    fn test_recall_cache_expired() {
        let mut cache = RecallCache::new(0);
        cache.put(123, Some("text".to_string()));

        std::thread::sleep(std::time::Duration::from_millis(1));

        assert!(cache.get(123).is_none());
    }

    /// Equal inputs hash equally; different inputs hash differently.
    #[test]
    fn test_query_hash_deterministic() {
        let first = query_hash("test query");
        let repeat = query_hash("test query");
        let other = query_hash("different query");

        assert_eq!(first, repeat);
        assert_ne!(first, other);
    }

    /// The feature is opt-in by default.
    #[test]
    fn test_default_config_disabled() {
        let defaults = ActiveMemoryConfig::default();
        assert!(!defaults.enabled);
    }
}