// ai_agents_memory/context.rs
1//! Conversation context for memory management
2
3use serde::{Deserialize, Serialize};
4
5use super::token_budget::TokenAllocation;
6use ai_agents_core::{ChatMessage, Role};
7
8#[derive(Debug, Clone, Serialize, Deserialize, Default)]
9pub struct ConversationContext {
10    pub summary: Option<String>,
11    pub messages: Vec<ChatMessage>,
12    pub total_messages: usize,
13    pub summarized_count: usize,
14}
15
16impl ConversationContext {
17    pub fn new() -> Self {
18        Self::default()
19    }
20
21    pub fn with_messages(messages: Vec<ChatMessage>) -> Self {
22        let total = messages.len();
23        Self {
24            summary: None,
25            messages,
26            total_messages: total,
27            summarized_count: 0,
28        }
29    }
30
31    pub fn with_summary(mut self, summary: String, summarized_count: usize) -> Self {
32        self.summary = Some(summary);
33        self.summarized_count = summarized_count;
34        self
35    }
36
37    pub fn to_llm_messages(&self) -> Vec<ChatMessage> {
38        let mut result = Vec::new();
39
40        if let Some(ref summary) = self.summary {
41            result.push(ChatMessage {
42                role: Role::System,
43                content: format!("[Previous conversation summary]\n{}", summary),
44                name: None,
45                timestamp: None,
46            });
47        }
48
49        result.extend(self.messages.clone());
50        result
51    }
52
53    /// Build LLM messages with per-component token budgets.
54    pub fn to_llm_messages_with_allocation(
55        &self,
56        allocation: &TokenAllocation,
57    ) -> Vec<ChatMessage> {
58        let mut result = Vec::new();
59
60        // Summary - capped to allocation.summary tokens
61        if let Some(ref summary) = self.summary {
62            let summary_content = format!("[Previous conversation summary]\n{}", summary);
63            let summary_tokens = estimate_tokens(&summary_content);
64
65            let final_content = if summary_tokens > allocation.summary {
66                let ratio = summary_content.len() as f64 / summary_tokens as f64;
67                let target_chars = (allocation.summary as f64 * ratio) as usize;
68                let truncated = &summary_content[..target_chars.min(summary_content.len())];
69                format!("{}...", truncated)
70            } else {
71                summary_content
72            };
73
74            result.push(ChatMessage {
75                role: Role::System,
76                content: final_content,
77                name: None,
78                timestamp: None,
79            });
80        }
81
82        // Recent messages - capped to allocation.recent_messages tokens
83        let mut used_message_tokens = 0u32;
84        let mut messages_to_add: Vec<&ChatMessage> = Vec::new();
85
86        for msg in self.messages.iter().rev() {
87            let tokens = estimate_message_tokens(msg);
88            if used_message_tokens + tokens <= allocation.recent_messages {
89                used_message_tokens += tokens;
90                messages_to_add.push(msg);
91            } else {
92                break;
93            }
94        }
95
96        messages_to_add.reverse();
97        for msg in messages_to_add {
98            result.push(msg.clone());
99        }
100
101        // TODO:
102        // Facts - reserved for 'Session Management' feature, not injected yet.
103
104        result
105    }
106
107    pub fn to_llm_messages_with_budget(&self, max_tokens: u32) -> Vec<ChatMessage> {
108        let mut result = Vec::new();
109        let mut used_tokens = 0u32;
110
111        if let Some(ref summary) = self.summary {
112            let summary_msg = ChatMessage {
113                role: Role::System,
114                content: format!("[Previous conversation summary]\n{}", summary),
115                name: None,
116                timestamp: None,
117            };
118            let tokens = estimate_message_tokens(&summary_msg);
119            if tokens <= max_tokens {
120                used_tokens = tokens;
121                result.push(summary_msg);
122            }
123        }
124
125        let mut messages_to_add: Vec<&ChatMessage> = Vec::new();
126        for msg in self.messages.iter().rev() {
127            let tokens = estimate_message_tokens(msg);
128            if used_tokens + tokens <= max_tokens {
129                used_tokens += tokens;
130                messages_to_add.push(msg);
131            } else {
132                break;
133            }
134        }
135
136        messages_to_add.reverse();
137        for msg in messages_to_add {
138            result.push(msg.clone());
139        }
140
141        result
142    }
143
144    pub fn estimated_tokens(&self) -> u32 {
145        let summary_tokens = self
146            .summary
147            .as_ref()
148            .map(|s| estimate_tokens(s))
149            .unwrap_or(0);
150
151        let message_tokens: u32 = self.messages.iter().map(estimate_message_tokens).sum();
152
153        summary_tokens + message_tokens
154    }
155
156    pub fn is_empty(&self) -> bool {
157        self.summary.is_none() && self.messages.is_empty()
158    }
159
160    pub fn message_count(&self) -> usize {
161        self.messages.len()
162    }
163}
164
/// Language-aware token estimation for multi-language support.
///
/// Heuristic: ~4 ASCII characters per token, 1.5 tokens per CJK character,
/// and 1 token per remaining non-ASCII character. Non-empty input always
/// yields at least 1 token.
pub fn estimate_tokens(text: &str) -> u32 {
    if text.is_empty() {
        return 0;
    }

    // Classify each character exactly once (the three classes are disjoint:
    // CJK ranges are entirely non-ASCII).
    let (mut ascii, mut cjk, mut other) = (0usize, 0usize, 0usize);
    for ch in text.chars() {
        if ch.is_ascii() {
            ascii += 1;
        } else if is_cjk(ch) {
            cjk += 1;
        } else {
            other += 1;
        }
    }

    let estimate = ascii as f64 / 4.0 + cjk as f64 * 1.5 + other as f64;
    estimate.ceil().max(1.0) as u32
}

/// Returns true for characters in the major CJK (Chinese/Japanese/Korean)
/// Unicode ranges handled by the token estimator.
fn is_cjk(c: char) -> bool {
    ('\u{4E00}'..='\u{9FFF}').contains(&c)      // CJK Unified Ideographs
        || ('\u{3400}'..='\u{4DBF}').contains(&c) // CJK Extension A
        || ('\u{AC00}'..='\u{D7AF}').contains(&c) // Korean Hangul
        || ('\u{3040}'..='\u{30FF}').contains(&c) // Japanese Hiragana/Katakana
        || ('\u{31F0}'..='\u{31FF}').contains(&c) // Katakana Extensions
}
190
191pub fn estimate_message_tokens(message: &ChatMessage) -> u32 {
192    let role_tokens = 4u32;
193    let content_tokens = estimate_tokens(&message.content);
194    let name_tokens = message
195        .name
196        .as_ref()
197        .map(|n| estimate_tokens(n))
198        .unwrap_or(0);
199    role_tokens + content_tokens + name_tokens
200}
201
202#[derive(Debug, Clone, Serialize, Deserialize)]
203pub enum CompressResult {
204    NotNeeded,
205    Compressed {
206        messages_summarized: usize,
207        new_summary_length: usize,
208        tokens_saved: u32,
209    },
210    AlreadyCompressed,
211    Failed {
212        error: String,
213    },
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: build a ChatMessage with only role and content populated.
    fn msg(role: Role, content: &str) -> ChatMessage {
        ChatMessage {
            role,
            content: content.to_owned(),
            name: None,
            timestamp: None,
        }
    }

    #[test]
    fn test_conversation_context_new() {
        let ctx = ConversationContext::new();
        assert!(ctx.is_empty());
        assert_eq!(ctx.message_count(), 0);
        assert!(ctx.summary.is_none());
    }

    #[test]
    fn test_conversation_context_with_messages() {
        let ctx = ConversationContext::with_messages(vec![
            msg(Role::User, "Hello"),
            msg(Role::Assistant, "Hi there!"),
        ]);
        assert_eq!(ctx.message_count(), 2);
        assert_eq!(ctx.total_messages, 2);
        assert!(!ctx.is_empty());
    }

    #[test]
    fn test_conversation_context_with_summary() {
        let ctx = ConversationContext::with_messages(vec![msg(Role::User, "Current message")])
            .with_summary("Previous discussion about weather".to_string(), 5);

        assert!(ctx.summary.is_some());
        assert_eq!(ctx.summarized_count, 5);

        let rendered = ctx.to_llm_messages();
        assert_eq!(rendered.len(), 2);
        assert!(rendered[0].content.contains("Previous conversation summary"));
    }

    #[test]
    fn test_to_llm_messages_without_summary() {
        let ctx = ConversationContext::with_messages(vec![
            msg(Role::User, "Hello"),
            msg(Role::Assistant, "Hi!"),
        ]);

        let rendered = ctx.to_llm_messages();
        assert_eq!(rendered.len(), 2);
        assert_eq!(rendered[0].role, Role::User);
    }

    #[test]
    fn test_estimated_tokens() {
        let ctx = ConversationContext::with_messages(vec![
            msg(Role::User, "Hello world"),
            msg(Role::Assistant, "Hi there"),
        ]);
        assert!(ctx.estimated_tokens() > 0);
    }

    #[test]
    fn test_to_llm_messages_with_budget() {
        let messages: Vec<ChatMessage> = (0..10)
            .map(|i| msg(Role::User, &format!("Message number {}", i)))
            .collect();
        let ctx = ConversationContext::with_messages(messages);

        assert!(ctx.to_llm_messages_with_budget(50).len() < 10);
    }

    #[test]
    fn test_to_llm_messages_with_allocation_caps_summary() {
        let long_summary = "x".repeat(10000); // roughly 2500 estimated tokens
        let ctx = ConversationContext::with_messages(vec![msg(Role::User, "Hello")])
            .with_summary(long_summary, 50);

        let allocation = TokenAllocation {
            summary: 100,
            recent_messages: 2048,
            facts: 512,
        };

        let result = ctx.to_llm_messages_with_allocation(&allocation);
        // Summary should be truncated
        let summary_tokens = estimate_tokens(&result[0].content);
        assert!(
            summary_tokens <= 120,
            "Summary should be roughly capped: got {}",
            summary_tokens
        );
        // Recent message should still be present
        assert!(result.len() >= 2);
    }

    #[test]
    fn test_to_llm_messages_with_allocation_caps_recent() {
        let messages: Vec<ChatMessage> = (0..50)
            .map(|i| {
                msg(
                    Role::User,
                    &format!(
                        "Message number {} with some extra text to increase tokens",
                        i
                    ),
                )
            })
            .collect();
        let ctx = ConversationContext::with_messages(messages);

        let allocation = TokenAllocation {
            summary: 1024,
            recent_messages: 200,
            facts: 512,
        };

        let result = ctx.to_llm_messages_with_allocation(&allocation);
        assert!(
            result.len() < 50,
            "Should have fewer messages due to cap: got {}",
            result.len()
        );
        // Messages should be the most recent
        let newest = result.last().unwrap();
        assert!(
            newest.content.contains("49"),
            "Last message should be the most recent"
        );
    }

    #[test]
    fn test_to_llm_messages_with_allocation_no_summary() {
        let ctx = ConversationContext::with_messages(vec![
            msg(Role::User, "Hello"),
            msg(Role::Assistant, "Hi!"),
        ]);

        let allocation = TokenAllocation {
            summary: 1024,
            recent_messages: 2048,
            facts: 512,
        };

        assert_eq!(ctx.to_llm_messages_with_allocation(&allocation).len(), 2);
    }

    #[test]
    fn test_estimate_tokens_english() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("test"), 1);
        assert_eq!(estimate_tokens("hello world"), 3);
    }

    #[test]
    fn test_estimate_tokens_korean() {
        let tokens = estimate_tokens("안녕하세요");
        assert!(
            tokens >= 5,
            "Korean text should have more tokens: {}",
            tokens
        );
    }

    #[test]
    fn test_estimate_tokens_japanese() {
        let tokens = estimate_tokens("こんにちは");
        assert!(
            tokens >= 5,
            "Japanese text should have more tokens: {}",
            tokens
        );
    }

    #[test]
    fn test_estimate_tokens_chinese() {
        let tokens = estimate_tokens("你好世界");
        assert!(
            tokens >= 4,
            "Chinese text should have more tokens: {}",
            tokens
        );
    }

    #[test]
    fn test_estimate_tokens_mixed() {
        let tokens = estimate_tokens("Hello 안녕 World 世界");
        assert!(tokens >= 6, "Mixed text: {}", tokens);
    }

    #[test]
    fn test_compress_result_variants() {
        assert!(matches!(CompressResult::NotNeeded, CompressResult::NotNeeded));

        let compressed = CompressResult::Compressed {
            messages_summarized: 5,
            new_summary_length: 100,
            tokens_saved: 500,
        };
        match compressed {
            CompressResult::Compressed {
                messages_summarized,
                ..
            } => assert_eq!(messages_summarized, 5),
            _ => {}
        }

        let failed = CompressResult::Failed {
            error: "test error".to_string(),
        };
        match failed {
            CompressResult::Failed { error } => assert_eq!(error, "test error"),
            _ => {}
        }
    }
}