Skip to main content

ares/memory/
mod.rs

//! Memory management module for conversation context and user memory.
//!
//! This module provides utilities for:
//! - Building agent context with memory
//! - Formatting memory for LLM prompts
//! - Managing conversation history windows
//!
//! User memory facts and preferences are stored in the database (`TursoClient`).
//! The functions here operate on already-loaded values and do not query the database themselves.
10
11use crate::types::{AgentContext, MemoryFact, Message, Preference, UserMemory};
12
/// How many of the most recent messages an agent context keeps when the
/// caller does not specify a window.
pub const DEFAULT_HISTORY_WINDOW: usize = 10;

/// Upper bound on the number of memory facts rendered into a prompt, so that a
/// large memory store cannot blow the token budget.
pub const MAX_FACTS_IN_PROMPT: usize = 20;

/// Upper bound on the number of user preferences rendered into a prompt.
pub const MAX_PREFERENCES_IN_PROMPT: usize = 10;
21
22/// Formats user memory into a string suitable for inclusion in system prompts.
23///
24/// # Arguments
25/// * `memory` - The user memory to format
26///
27/// # Returns
28/// A formatted string containing preferences and facts, or an empty string if memory is empty.
29///
30/// # Example
31/// ```ignore
32/// let memory = UserMemory { user_id: "123".into(), preferences: vec![...], facts: vec![...] };
33/// let context = format_memory_for_prompt(&memory);
34/// // context: "User Preferences:\n- communication: concise\n\nKnown Facts:\n- work: engineer"
35/// ```
36pub fn format_memory_for_prompt(memory: &UserMemory) -> String {
37    let mut parts = Vec::new();
38
39    // Format preferences (limited to avoid token overflow)
40    if !memory.preferences.is_empty() {
41        let prefs: Vec<String> = memory
42            .preferences
43            .iter()
44            .take(MAX_PREFERENCES_IN_PROMPT)
45            .filter(|p| p.confidence >= 0.5) // Only include confident preferences
46            .map(|p| format!("- {}/{}: {}", p.category, p.key, p.value))
47            .collect();
48
49        if !prefs.is_empty() {
50            parts.push(format!("User Preferences:\n{}", prefs.join("\n")));
51        }
52    }
53
54    // Format facts (limited and filtered by confidence)
55    if !memory.facts.is_empty() {
56        let facts: Vec<String> = memory
57            .facts
58            .iter()
59            .take(MAX_FACTS_IN_PROMPT)
60            .filter(|f| f.confidence >= 0.5) // Only include confident facts
61            .map(|f| format!("- {}/{}: {}", f.category, f.fact_key, f.fact_value))
62            .collect();
63
64        if !facts.is_empty() {
65            parts.push(format!("Known Facts about User:\n{}", facts.join("\n")));
66        }
67    }
68
69    parts.join("\n\n")
70}
71
72/// Formats user preferences into a compact string for prompt inclusion.
73///
74/// This is a lighter-weight alternative to `format_memory_for_prompt` when
75/// only preferences are needed (e.g., for routing decisions).
76pub fn format_preferences_compact(preferences: &[Preference]) -> String {
77    preferences
78        .iter()
79        .filter(|p| p.confidence >= 0.5)
80        .take(MAX_PREFERENCES_IN_PROMPT)
81        .map(|p| format!("{}: {}", p.key, p.value))
82        .collect::<Vec<_>>()
83        .join(", ")
84}
85
86/// Truncates conversation history to a window of recent messages.
87///
88/// # Arguments
89/// * `history` - Full conversation history
90/// * `window_size` - Maximum number of messages to keep
91///
92/// # Returns
93/// A new vector containing only the most recent messages.
94pub fn truncate_history(history: &[Message], window_size: usize) -> Vec<Message> {
95    if history.len() <= window_size {
96        history.to_vec()
97    } else {
98        history[history.len() - window_size..].to_vec()
99    }
100}
101
/// Gives a rough token estimate for `text`.
///
/// The heuristic assumes about four bytes of UTF-8 per token, which is roughly
/// right for English prose, and rounds up. Any non-empty string therefore counts
/// as at least one token.
///
/// The length used is the UTF-8 *byte* length, not the number of `char`s, so
/// multi-byte text yields a larger estimate than its character count suggests.
/// Real tokenizers differ; treat the result as an approximation only.
pub fn estimate_tokens(text: &str) -> usize {
    let byte_len = text.len();
    let whole_tokens = byte_len / 4;
    // One extra token for any trailing partial group of bytes.
    let partial_token = usize::from(byte_len % 4 != 0);
    whole_tokens + partial_token
}
110
111/// Truncates history to fit within a token budget.
112///
113/// Removes oldest messages until the total estimated tokens is under the budget.
114///
115/// # Arguments
116/// * `history` - Full conversation history
117/// * `token_budget` - Maximum tokens to allow
118///
119/// # Returns
120/// A truncated history that fits within the token budget.
121pub fn truncate_history_to_tokens(history: &[Message], token_budget: usize) -> Vec<Message> {
122    let mut result: Vec<Message> = Vec::new();
123    let mut total_tokens = 0;
124
125    // Work backwards from most recent messages
126    for msg in history.iter().rev() {
127        let msg_tokens = estimate_tokens(&msg.content);
128        if total_tokens + msg_tokens > token_budget {
129            break;
130        }
131        result.push(msg.clone());
132        total_tokens += msg_tokens;
133    }
134
135    // Reverse to restore chronological order
136    result.reverse();
137    result
138}
139
140/// Builds an agent context from components.
141///
142/// This is a convenience function for constructing AgentContext with
143/// appropriate defaults and optional memory/history truncation.
144///
145/// # Arguments
146/// * `user_id` - User identifier
147/// * `session_id` - Session/conversation identifier
148/// * `history` - Full conversation history (will be truncated)
149/// * `memory` - Optional user memory
150/// * `history_window` - Maximum messages to include (defaults to DEFAULT_HISTORY_WINDOW)
151pub fn build_context(
152    user_id: String,
153    session_id: String,
154    history: Vec<Message>,
155    memory: Option<UserMemory>,
156    history_window: Option<usize>,
157) -> AgentContext {
158    let window = history_window.unwrap_or(DEFAULT_HISTORY_WINDOW);
159    let truncated_history = truncate_history(&history, window);
160
161    AgentContext {
162        user_id,
163        session_id,
164        conversation_history: truncated_history,
165        user_memory: memory,
166    }
167}
168
169/// Filters memory facts by category.
170///
171/// Useful for retrieving only relevant facts for specific agent types.
172pub fn filter_facts_by_category(facts: &[MemoryFact], category: &str) -> Vec<MemoryFact> {
173    facts
174        .iter()
175        .filter(|f| f.category == category)
176        .cloned()
177        .collect()
178}
179
180/// Filters preferences by category.
181pub fn filter_preferences_by_category(
182    preferences: &[Preference],
183    category: &str,
184) -> Vec<Preference> {
185    preferences
186        .iter()
187        .filter(|p| p.category == category)
188        .cloned()
189        .collect()
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MessageRole;
    use chrono::Utc;

    /// Builds a user-role message with the given content, timestamped now.
    fn user_message(content: impl Into<String>) -> Message {
        Message {
            role: MessageRole::User,
            content: content.into(),
            timestamp: Utc::now(),
        }
    }

    /// Builds `count` user messages whose contents are "Message 0", "Message 1", ...
    fn numbered_history(count: usize) -> Vec<Message> {
        (0..count)
            .map(|i| user_message(format!("Message {}", i)))
            .collect()
    }

    /// Builds a preference with confidence 0.9; use struct-update syntax to
    /// override `confidence` when a test needs a different value.
    fn preference(category: &str, key: &str, value: &str) -> Preference {
        Preference {
            category: category.to_string(),
            key: key.to_string(),
            value: value.to_string(),
            confidence: 0.9,
        }
    }

    /// Builds a fact for user "test" with confidence 0.9.
    fn memory_fact(id: &str, category: &str, key: &str, value: &str) -> MemoryFact {
        MemoryFact {
            id: id.to_string(),
            user_id: "test".to_string(),
            category: category.to_string(),
            fact_key: key.to_string(),
            fact_value: value.to_string(),
            confidence: 0.9,
            created_at: Utc::now(),
            updated_at: Utc::now(),
        }
    }

    #[test]
    fn test_format_memory_for_prompt_empty() {
        let memory = UserMemory {
            user_id: "test".to_string(),
            preferences: vec![],
            facts: vec![],
        };
        assert_eq!(format_memory_for_prompt(&memory), "");
    }

    #[test]
    fn test_format_memory_for_prompt_with_preferences() {
        let memory = UserMemory {
            user_id: "test".to_string(),
            preferences: vec![preference("communication", "style", "concise")],
            facts: vec![],
        };
        let result = format_memory_for_prompt(&memory);
        assert!(result.contains("User Preferences:"));
        assert!(result.contains("communication/style: concise"));
    }

    #[test]
    fn test_format_memory_for_prompt_with_facts() {
        let memory = UserMemory {
            user_id: "test".to_string(),
            preferences: vec![],
            facts: vec![memory_fact("1", "work", "role", "engineer")],
        };
        assert_eq!(
            format_memory_for_prompt(&memory),
            "Known Facts about User:\n- work/role: engineer"
        );
    }

    #[test]
    fn test_format_memory_for_prompt_sections_separated_by_blank_line() {
        let memory = UserMemory {
            user_id: "test".to_string(),
            preferences: vec![preference("communication", "style", "concise")],
            facts: vec![memory_fact("1", "work", "role", "engineer")],
        };
        assert_eq!(
            format_memory_for_prompt(&memory),
            "User Preferences:\n- communication/style: concise\n\nKnown Facts about User:\n- work/role: engineer"
        );
    }

    #[test]
    fn test_format_memory_filters_low_confidence() {
        let memory = UserMemory {
            user_id: "test".to_string(),
            preferences: vec![
                Preference {
                    confidence: 0.8,
                    ..preference("test", "high", "yes")
                },
                Preference {
                    confidence: 0.3, // Below threshold
                    ..preference("test", "low", "no")
                },
            ],
            facts: vec![],
        };
        let result = format_memory_for_prompt(&memory);
        assert!(result.contains("high"));
        assert!(!result.contains("low"));
    }

    #[test]
    fn test_format_memory_only_low_confidence_facts_is_empty() {
        let memory = UserMemory {
            user_id: "test".to_string(),
            preferences: vec![],
            facts: vec![MemoryFact {
                confidence: 0.2,
                ..memory_fact("1", "work", "role", "engineer")
            }],
        };
        assert_eq!(format_memory_for_prompt(&memory), "");
    }

    #[test]
    fn test_truncate_history() {
        let history = numbered_history(10);
        let truncated = truncate_history(&history, 3);
        assert_eq!(truncated.len(), 3);
        assert_eq!(truncated[0].content, "Message 7");
        assert_eq!(truncated[2].content, "Message 9");
    }

    #[test]
    fn test_truncate_history_window_larger_than_history() {
        let history = numbered_history(3);
        let truncated = truncate_history(&history, 10);
        assert_eq!(truncated.len(), 3);
        assert_eq!(truncated[0].content, "Message 0");
        assert!(truncate_history(&history, 0).is_empty());
    }

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("test"), 1);
        assert_eq!(estimate_tokens("abcde"), 2);
        assert_eq!(estimate_tokens("this is a longer test string"), 7);
    }

    #[test]
    fn test_truncate_history_to_tokens() {
        // Each 8-byte message is estimated at 2 tokens.
        let history: Vec<Message> = ["aaaaaaaa", "bbbbbbbb", "cccccccc", "dddddddd"]
            .iter()
            .map(|content| user_message(*content))
            .collect();

        let kept = truncate_history_to_tokens(&history, 5);
        let contents: Vec<&str> = kept.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(contents, vec!["cccccccc", "dddddddd"]);

        assert_eq!(truncate_history_to_tokens(&history, 8).len(), 4);
        assert!(truncate_history_to_tokens(&history, 1).is_empty());
    }

    #[test]
    fn test_format_preferences_compact() {
        let prefs = vec![
            preference("output", "format", "markdown"),
            Preference {
                confidence: 0.8,
                ..preference("output", "length", "brief")
            },
        ];
        let result = format_preferences_compact(&prefs);
        assert_eq!(result, "format: markdown, length: brief");
    }

    #[test]
    fn test_format_preferences_compact_skips_low_confidence() {
        let prefs = vec![
            Preference {
                confidence: 0.2,
                ..preference("output", "tone", "formal")
            },
            preference("output", "format", "markdown"),
        ];
        assert_eq!(format_preferences_compact(&prefs), "format: markdown");
    }

    #[test]
    fn test_build_context() {
        let context = build_context(
            "user1".to_string(),
            "session1".to_string(),
            numbered_history(20),
            None,
            Some(5),
        );

        assert_eq!(context.user_id, "user1");
        assert_eq!(context.session_id, "session1");
        assert_eq!(context.conversation_history.len(), 5);
        assert_eq!(context.conversation_history[0].content, "Message 15");
        assert_eq!(context.conversation_history[4].content, "Message 19");
        assert!(context.user_memory.is_none());
    }

    #[test]
    fn test_build_context_default_window() {
        let total = DEFAULT_HISTORY_WINDOW + 5;
        let context = build_context(
            "user1".to_string(),
            "session1".to_string(),
            numbered_history(total),
            None,
            None,
        );

        assert_eq!(context.conversation_history.len(), DEFAULT_HISTORY_WINDOW);
        assert_eq!(
            context.conversation_history[0].content,
            format!("Message {}", total - DEFAULT_HISTORY_WINDOW)
        );
    }

    #[test]
    fn test_filter_facts_by_category() {
        let facts = vec![
            memory_fact("1", "work", "role", "engineer"),
            MemoryFact {
                confidence: 0.8,
                ..memory_fact("2", "personal", "hobby", "reading")
            },
        ];

        let work_facts = filter_facts_by_category(&facts, "work");
        assert_eq!(work_facts.len(), 1);
        assert_eq!(work_facts[0].fact_key, "role");
    }

    #[test]
    fn test_filter_preferences_by_category() {
        let prefs = vec![
            preference("output", "format", "markdown"),
            preference("communication", "style", "concise"),
            preference("output", "length", "brief"),
        ];

        let output_prefs = filter_preferences_by_category(&prefs, "output");
        let keys: Vec<&str> = output_prefs.iter().map(|p| p.key.as_str()).collect();
        assert_eq!(keys, vec!["format", "length"]);
        assert!(filter_preferences_by_category(&prefs, "missing").is_empty());
    }
}