// ares/memory/mod.rs
//! Memory management module for conversation context and user memory.
//!
//! This module provides utilities for:
//! - Building agent context with memory
//! - Formatting memory for LLM prompts
//! - Managing conversation history windows
//!
//! User memory facts and preferences are stored in the database (PostgresClient).
//! This module provides utilities for working with that stored memory.
10
11use crate::types::{AgentContext, MemoryFact, Message, Preference, UserMemory};
12
/// Default number of recent messages to include in context.
pub const DEFAULT_HISTORY_WINDOW: usize = 10;

/// Maximum number of facts to include in a prompt, to avoid token overflow.
pub const MAX_FACTS_IN_PROMPT: usize = 20;

/// Maximum number of preferences to include in a prompt.
pub const MAX_PREFERENCES_IN_PROMPT: usize = 10;
21
22/// Formats user memory into a string suitable for inclusion in system prompts.
23///
24/// # Arguments
25/// * `memory` - The user memory to format
26///
27/// # Returns
28/// A formatted string containing preferences and facts, or an empty string if memory is empty.
29///
30/// # Example
31/// ```ignore
32/// let memory = UserMemory { user_id: "123".into(), preferences: vec![...], facts: vec![...] };
33/// let context = format_memory_for_prompt(&memory);
34/// // context: "User Preferences:\n- communication: concise\n\nKnown Facts:\n- work: engineer"
35/// ```
36pub fn format_memory_for_prompt(memory: &UserMemory) -> String {
37    let mut parts = Vec::new();
38
39    // Format preferences (limited to avoid token overflow)
40    if !memory.preferences.is_empty() {
41        let prefs: Vec<String> = memory
42            .preferences
43            .iter()
44            .take(MAX_PREFERENCES_IN_PROMPT)
45            .filter(|p| p.confidence >= 0.5) // Only include confident preferences
46            .map(|p| format!("- {}/{}: {}", p.category, p.key, p.value))
47            .collect();
48
49        if !prefs.is_empty() {
50            parts.push(format!("User Preferences:\n{}", prefs.join("\n")));
51        }
52    }
53
54    // Format facts (limited and filtered by confidence)
55    if !memory.facts.is_empty() {
56        let facts: Vec<String> = memory
57            .facts
58            .iter()
59            .take(MAX_FACTS_IN_PROMPT)
60            .filter(|f| f.confidence >= 0.5) // Only include confident facts
61            .map(|f| format!("- {}/{}: {}", f.category, f.fact_key, f.fact_value))
62            .collect();
63
64        if !facts.is_empty() {
65            parts.push(format!("Known Facts about User:\n{}", facts.join("\n")));
66        }
67    }
68
69    parts.join("\n\n")
70}
71
72/// Formats user preferences into a compact string for prompt inclusion.
73///
74/// This is a lighter-weight alternative to `format_memory_for_prompt` when
75/// only preferences are needed (e.g., for routing decisions).
76pub fn format_preferences_compact(preferences: &[Preference]) -> String {
77    preferences
78        .iter()
79        .filter(|p| p.confidence >= 0.5)
80        .take(MAX_PREFERENCES_IN_PROMPT)
81        .map(|p| format!("{}: {}", p.key, p.value))
82        .collect::<Vec<_>>()
83        .join(", ")
84}
85
86/// Truncates conversation history to a window of recent messages.
87///
88/// # Arguments
89/// * `history` - Full conversation history
90/// * `window_size` - Maximum number of messages to keep
91///
92/// # Returns
93/// A new vector containing only the most recent messages.
94pub fn truncate_history(history: &[Message], window_size: usize) -> Vec<Message> {
95    if history.len() <= window_size {
96        history.to_vec()
97    } else {
98        history[history.len() - window_size..].to_vec()
99    }
100}
101
/// Estimates token count for a message.
///
/// Combines two heuristics and keeps the larger result:
/// - ~1.3 tokens per word for typical English text
/// - ~4 characters per token as a fallback floor
///
/// The higher estimate is used so billing overcounts rather than undercounts;
/// actual token counts vary by tokenizer (GPT-3/4, Claude, etc.).
pub fn estimate_tokens(text: &str) -> usize {
    let by_words = (text.split_whitespace().count() as f64 * 1.3) as usize;
    let by_chars = text.len().div_ceil(4);

    // Never return 0, even for empty input, and prefer the larger estimate.
    std::cmp::max(1, by_words.max(by_chars))
}
119
120/// Truncates history to fit within a token budget.
121///
122/// Removes oldest messages until the total estimated tokens is under the budget.
123///
124/// # Arguments
125/// * `history` - Full conversation history
126/// * `token_budget` - Maximum tokens to allow
127///
128/// # Returns
129/// A truncated history that fits within the token budget.
130pub fn truncate_history_to_tokens(history: &[Message], token_budget: usize) -> Vec<Message> {
131    let mut result: Vec<Message> = Vec::new();
132    let mut total_tokens = 0;
133
134    // Work backwards from most recent messages
135    for msg in history.iter().rev() {
136        let msg_tokens = estimate_tokens(&msg.content);
137        if total_tokens + msg_tokens > token_budget {
138            break;
139        }
140        result.push(msg.clone());
141        total_tokens += msg_tokens;
142    }
143
144    // Reverse to restore chronological order
145    result.reverse();
146    result
147}
148
149/// Builds an agent context from components.
150///
151/// This is a convenience function for constructing AgentContext with
152/// appropriate defaults and optional memory/history truncation.
153///
154/// # Arguments
155/// * `user_id` - User identifier
156/// * `session_id` - Session/conversation identifier
157/// * `history` - Full conversation history (will be truncated)
158/// * `memory` - Optional user memory
159/// * `history_window` - Maximum messages to include (defaults to DEFAULT_HISTORY_WINDOW)
160pub fn build_context(
161    user_id: String,
162    session_id: String,
163    history: Vec<Message>,
164    memory: Option<UserMemory>,
165    history_window: Option<usize>,
166) -> AgentContext {
167    let window = history_window.unwrap_or(DEFAULT_HISTORY_WINDOW);
168    let truncated_history = truncate_history(&history, window);
169
170    AgentContext {
171        user_id,
172        session_id,
173        conversation_history: truncated_history,
174        user_memory: memory,
175    }
176}
177
178/// Filters memory facts by category.
179///
180/// Useful for retrieving only relevant facts for specific agent types.
181pub fn filter_facts_by_category(facts: &[MemoryFact], category: &str) -> Vec<MemoryFact> {
182    facts
183        .iter()
184        .filter(|f| f.category == category)
185        .cloned()
186        .collect()
187}
188
189/// Filters preferences by category.
190pub fn filter_preferences_by_category(
191    preferences: &[Preference],
192    category: &str,
193) -> Vec<Preference> {
194    preferences
195        .iter()
196        .filter(|p| p.category == category)
197        .cloned()
198        .collect()
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MessageRole;
    use chrono::Utc;

    // Empty memory must produce an empty string — no section headers emitted.
    #[test]
    fn test_format_memory_for_prompt_empty() {
        let memory = UserMemory {
            user_id: "test".to_string(),
            preferences: vec![],
            facts: vec![],
        };
        assert_eq!(format_memory_for_prompt(&memory), "");
    }

    // A confident preference appears under the "User Preferences:" header
    // in "category/key: value" form.
    #[test]
    fn test_format_memory_for_prompt_with_preferences() {
        let memory = UserMemory {
            user_id: "test".to_string(),
            preferences: vec![Preference {
                category: "communication".to_string(),
                key: "style".to_string(),
                value: "concise".to_string(),
                confidence: 0.9,
            }],
            facts: vec![],
        };
        let result = format_memory_for_prompt(&memory);
        assert!(result.contains("User Preferences:"));
        assert!(result.contains("communication/style: concise"));
    }

    // Entries below the 0.5 confidence threshold must be excluded.
    #[test]
    fn test_format_memory_filters_low_confidence() {
        let memory = UserMemory {
            user_id: "test".to_string(),
            preferences: vec![
                Preference {
                    category: "test".to_string(),
                    key: "high".to_string(),
                    value: "yes".to_string(),
                    confidence: 0.8,
                },
                Preference {
                    category: "test".to_string(),
                    key: "low".to_string(),
                    value: "no".to_string(),
                    confidence: 0.3, // Below threshold
                },
            ],
            facts: vec![],
        };
        let result = format_memory_for_prompt(&memory);
        assert!(result.contains("high"));
        assert!(!result.contains("low"));
    }

    // Truncation keeps the MOST RECENT window (messages 7..=9 of 0..10).
    #[test]
    fn test_truncate_history() {
        let history: Vec<Message> = (0..10)
            .map(|i| Message {
                role: MessageRole::User,
                content: format!("Message {}", i),
                timestamp: Utc::now(),
            })
            .collect();

        let truncated = truncate_history(&history, 3);
        assert_eq!(truncated.len(), 3);
        assert!(truncated[0].content.contains("7"));
        assert!(truncated[2].content.contains("9"));
    }

    // Pins the heuristic: max(words * 1.3, ceil(chars / 4)), floored at 1.
    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 1); // floors at 1 for billing safety
        assert_eq!(estimate_tokens("test"), 1);
        assert_eq!(estimate_tokens("this is a longer test string"), 7);
    }

    // Compact format is "key: value" pairs joined with ", " (no categories).
    #[test]
    fn test_format_preferences_compact() {
        let prefs = vec![
            Preference {
                category: "output".to_string(),
                key: "format".to_string(),
                value: "markdown".to_string(),
                confidence: 0.9,
            },
            Preference {
                category: "output".to_string(),
                key: "length".to_string(),
                value: "brief".to_string(),
                confidence: 0.8,
            },
        ];
        let result = format_preferences_compact(&prefs);
        assert_eq!(result, "format: markdown, length: brief");
    }

    // build_context copies ids through, truncates history to the requested
    // window, and leaves memory as passed (None here).
    #[test]
    fn test_build_context() {
        let history: Vec<Message> = (0..20)
            .map(|i| Message {
                role: MessageRole::User,
                content: format!("Message {}", i),
                timestamp: Utc::now(),
            })
            .collect();

        let context = build_context(
            "user1".to_string(),
            "session1".to_string(),
            history,
            None,
            Some(5),
        );

        assert_eq!(context.user_id, "user1");
        assert_eq!(context.session_id, "session1");
        assert_eq!(context.conversation_history.len(), 5);
        assert!(context.user_memory.is_none());
    }

    // Category filter returns only the matching fact(s), untouched.
    #[test]
    fn test_filter_facts_by_category() {
        let facts = vec![
            MemoryFact {
                id: "1".to_string(),
                user_id: "test".to_string(),
                category: "work".to_string(),
                fact_key: "role".to_string(),
                fact_value: "engineer".to_string(),
                confidence: 0.9,
                created_at: Utc::now(),
                updated_at: Utc::now(),
            },
            MemoryFact {
                id: "2".to_string(),
                user_id: "test".to_string(),
                category: "personal".to_string(),
                fact_key: "hobby".to_string(),
                fact_value: "reading".to_string(),
                confidence: 0.8,
                created_at: Utc::now(),
                updated_at: Utc::now(),
            },
        ];

        let work_facts = filter_facts_by_category(&facts, "work");
        assert_eq!(work_facts.len(), 1);
        assert_eq!(work_facts[0].fact_key, "role");
    }
}