Skip to main content

localgpt_core/memory/
active_recall.rs

1//! Active Memory Recall — automatically search memory before generating replies
2//!
3//! When enabled, the agent searches its memory using the user's message as a query
4//! before the LLM call, injecting any relevant recalled context into the conversation.
5//! This ensures user preferences, past decisions, and important facts are surfaced
6//! without requiring the agent to explicitly call memory_search.
7
8use super::MemoryChunk;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::{Duration, Instant};
12
13/// Configuration for active memory recall
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(default)]
16pub struct ActiveMemoryConfig {
17    /// Enable active memory recall before replies. Default: false.
18    pub enabled: bool,
19    /// How to build the search query: "message" (user message only) or "recent" (include recent turns)
20    pub query_mode: QueryMode,
21    /// Number of recent turns to include when query_mode is "recent". Default: 4.
22    pub max_recent_turns: usize,
23    /// Maximum number of memory chunks to recall. Default: 3.
24    pub max_results: usize,
25    /// Maximum total characters of recalled context to inject. Default: 500.
26    pub max_chars: usize,
27    /// Minimum relevance score to include a result (0.0-1.0). Default: 0.1.
28    pub min_score: f64,
29    /// Cache TTL in milliseconds. Prevents redundant searches for repeated queries. Default: 15000.
30    pub cache_ttl_ms: u64,
31}
32
33impl Default for ActiveMemoryConfig {
34    fn default() -> Self {
35        Self {
36            enabled: false,
37            query_mode: QueryMode::Message,
38            max_recent_turns: 4,
39            max_results: 3,
40            max_chars: 500,
41            min_score: 0.1,
42            cache_ttl_ms: 15_000,
43        }
44    }
45}
46
47/// How to construct the memory search query
48#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
49#[serde(rename_all = "lowercase")]
50pub enum QueryMode {
51    /// Use only the current user message as the query
52    #[default]
53    Message,
54    /// Include recent conversation turns for better context
55    Recent,
56}
57
/// Outcome of one active memory recall attempt.
#[derive(Debug)]
pub enum RecallResult {
    /// A fresh search produced relevant context; the payload is the recalled text.
    Recalled(String),
    /// The search ran but nothing relevant was found.
    Empty,
    /// Active recall is turned off, so no search was attempted.
    Disabled,
    /// The text was served from the cache instead of a new search.
    CacheHit(String),
}
70
/// Time-limited, in-memory store of recall results keyed by query hash,
/// used to skip repeated searches for the same query.
pub struct RecallCache {
    entries: HashMap<u64, CacheEntry>,
    ttl: Duration,
}

/// One cached outcome together with the moment it was stored.
struct CacheEntry {
    /// `None` records that the search found nothing.
    result: Option<String>,
    created_at: Instant,
}

impl RecallCache {
    /// Creates an empty cache whose entries stay valid for `ttl_ms` milliseconds.
    pub fn new(ttl_ms: u64) -> Self {
        let ttl = Duration::from_millis(ttl_ms);
        Self {
            entries: HashMap::new(),
            ttl,
        }
    }

    /// Looks up the cached outcome for `query_hash`.
    ///
    /// Returns `None` when there is no entry or the entry is older than the
    /// TTL. `Some(&None)` means a cached "nothing found" outcome.
    pub fn get(&self, query_hash: u64) -> Option<&Option<String>> {
        self.entries
            .get(&query_hash)
            .filter(|cached| cached.created_at.elapsed() < self.ttl)
            .map(|cached| &cached.result)
    }

    /// Stores `result` under `query_hash`, stamping it with the current time.
    ///
    /// Any existing entry for the same hash is replaced.
    pub fn put(&mut self, query_hash: u64, result: Option<String>) {
        // Sweep out stale entries only once the map has grown past 100 items,
        // so the common path stays a plain insert.
        if self.entries.len() > 100 {
            let ttl = self.ttl;
            self.entries
                .retain(|_, cached| cached.created_at.elapsed() < ttl);
        }

        let created_at = Instant::now();
        self.entries
            .insert(query_hash, CacheEntry { result, created_at });
    }
}
115
116/// Build a search query from the user's message and optionally recent conversation turns
117pub fn build_query(
118    user_message: &str,
119    recent_messages: &[(String, String)], // (role, content) pairs
120    config: &ActiveMemoryConfig,
121) -> String {
122    match config.query_mode {
123        QueryMode::Message => user_message.to_string(),
124        QueryMode::Recent => {
125            let mut parts = Vec::new();
126            let start = recent_messages
127                .len()
128                .saturating_sub(config.max_recent_turns);
129            for (role, content) in &recent_messages[start..] {
130                // Truncate each turn to avoid excessive query length
131                let truncated = if content.len() > 200 {
132                    &content[..200]
133                } else {
134                    content.as_str()
135                };
136                parts.push(format!("{}: {}", role, truncated));
137            }
138            parts.push(format!("user: {}", user_message));
139            parts.join("\n")
140        }
141    }
142}
143
144/// Format recalled memory chunks into a context string for injection
145pub fn format_recalled_context(chunks: &[MemoryChunk], max_chars: usize) -> Option<String> {
146    if chunks.is_empty() {
147        return None;
148    }
149
150    let mut parts = Vec::new();
151    let mut total_chars = 0;
152
153    for chunk in chunks {
154        let entry = chunk.content.trim();
155        if entry.is_empty() {
156            continue;
157        }
158
159        if total_chars + entry.len() > max_chars {
160            // Include partial last entry if we have room
161            let remaining = max_chars.saturating_sub(total_chars);
162            if remaining > 50 {
163                parts.push(format!("- {}...", &entry[..remaining.min(entry.len())]));
164            }
165            break;
166        }
167
168        parts.push(format!("- {}", entry));
169        total_chars += entry.len();
170    }
171
172    if parts.is_empty() {
173        return None;
174    }
175
176    Some(format!(
177        "<recalled_context>\nThe following was automatically recalled from memory and may be relevant:\n{}\n</recalled_context>",
178        parts.join("\n")
179    ))
180}
181
/// Hash a query string into the `u64` key used by the recall cache.
///
/// Uses the standard library's `DefaultHasher`, so equal strings always map
/// to the same key within a running process.
pub fn query_hash(query: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(query, &mut state);
    state.finish()
}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// In `Message` mode the query is the user's text, untouched.
    #[test]
    fn test_build_query_message_mode() {
        let config = ActiveMemoryConfig {
            query_mode: QueryMode::Message,
            ..Default::default()
        };
        let no_history: Vec<(String, String)> = Vec::new();

        assert_eq!(
            build_query("What color do I prefer?", &no_history, &config),
            "What color do I prefer?"
        );
    }

    /// In `Recent` mode only the last `max_recent_turns` turns are included.
    #[test]
    fn test_build_query_recent_mode() {
        let config = ActiveMemoryConfig {
            query_mode: QueryMode::Recent,
            max_recent_turns: 2,
            ..Default::default()
        };
        let history: Vec<(String, String)> = [
            ("user", "Hello"),
            ("assistant", "Hi there!"),
            ("user", "Tell me about colors"),
            ("assistant", "Sure, what colors?"),
        ]
        .iter()
        .map(|&(role, text)| (role.to_string(), text.to_string()))
        .collect();

        let query = build_query("What color do I prefer?", &history, &config);

        for expected in [
            "Tell me about colors",
            "Sure, what colors?",
            "What color do I prefer?",
        ] {
            assert!(query.contains(expected));
        }
        // Should NOT contain the first turn (only last 2)
        assert!(!query.contains("Hello"));
    }

    /// No chunks means nothing to inject.
    #[test]
    fn test_format_recalled_context_empty() {
        let formatted = format_recalled_context(&[], 500);
        assert!(formatted.is_none());
    }

    /// Every chunk that fits is present inside the wrapper block.
    #[test]
    fn test_format_recalled_context_basic() {
        let dark_mode = MemoryChunk {
            file: String::from("test.md"),
            line_start: 1,
            line_end: 1,
            content: String::from("User prefers dark mode"),
            score: 0.9,
            updated_at: 0,
        };
        let employer = MemoryChunk {
            file: String::from("test.md"),
            line_start: 2,
            line_end: 2,
            content: String::from("User works at Acme Corp"),
            score: 0.7,
            updated_at: 0,
        };
        let chunks = vec![dark_mode, employer];

        let formatted = format_recalled_context(&chunks, 500).unwrap();

        assert!(formatted.contains("<recalled_context>"));
        assert!(formatted.contains("User prefers dark mode"));
        assert!(formatted.contains("User works at Acme Corp"));
    }

    /// An oversized chunk is cut down and marked with an ellipsis.
    #[test]
    fn test_format_recalled_context_truncation() {
        let oversized = MemoryChunk {
            file: String::from("test.md"),
            line_start: 1,
            line_end: 1,
            content: "A".repeat(600),
            score: 0.9,
            updated_at: 0,
        };

        let formatted = format_recalled_context(&[oversized], 100).unwrap();

        // Should be truncated
        assert!(formatted.len() < 600);
        assert!(formatted.contains("..."));
    }

    /// Covers a miss, a hit, and a cached "nothing found" outcome.
    #[test]
    fn test_recall_cache() {
        let mut cache = RecallCache::new(60_000); // 60s TTL

        // Miss
        assert!(cache.get(123).is_none());

        // Put and hit
        cache.put(123, Some("recalled text".to_string()));
        let cached = cache.get(123).unwrap();
        assert_eq!(cached.as_deref(), Some("recalled text"));

        // Empty result cache
        cache.put(456, None);
        let cached_empty = cache.get(456).unwrap();
        assert!(cached_empty.is_none());
    }

    /// With a zero TTL an entry is never served back.
    #[test]
    fn test_recall_cache_expired() {
        let mut cache = RecallCache::new(0); // 0ms TTL = immediate expiry
        cache.put(123, Some("text".to_string()));

        // Should be expired
        std::thread::sleep(std::time::Duration::from_millis(1));
        assert!(cache.get(123).is_none());
    }

    /// Equal queries hash equally; different queries hash differently.
    #[test]
    fn test_query_hash_deterministic() {
        let first = query_hash("test query");
        let repeat = query_hash("test query");
        let other = query_hash("different query");

        assert_eq!(first, repeat);
        assert_ne!(first, other);
    }

    /// Active recall is opt-in.
    #[test]
    fn test_default_config_disabled() {
        let config = ActiveMemoryConfig::default();
        assert!(!config.enabled);
    }
}