Skip to main content

ai_agents_memory/
context.rs

1//! Conversation context for memory management
2
3use serde::{Deserialize, Serialize};
4
5use super::token_budget::TokenAllocation;
6use ai_agents_core::{ChatMessage, Role};
7
/// Returns the longest prefix of `text` that contains at most `max_chars`
/// characters (Unicode scalar values).
///
/// The cut always lands on a UTF-8 character boundary, so slicing can never
/// panic on multi-byte text. If `text` has `max_chars` characters or fewer,
/// the whole string is returned.
fn prefix_at_char_boundary(text: &str, max_chars: usize) -> &str {
    // Byte offset of the first character that must be excluded; when there is
    // no such character the entire string fits.
    let end = text
        .char_indices()
        .nth(max_chars)
        .map_or(text.len(), |(byte_idx, _)| byte_idx);
    &text[..end]
}
18
19#[derive(Debug, Clone, Serialize, Deserialize, Default)]
20pub struct ConversationContext {
21    pub summary: Option<String>,
22    pub messages: Vec<ChatMessage>,
23    pub total_messages: usize,
24    pub summarized_count: usize,
25}
26
27impl ConversationContext {
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    pub fn with_messages(messages: Vec<ChatMessage>) -> Self {
33        let total = messages.len();
34        Self {
35            summary: None,
36            messages,
37            total_messages: total,
38            summarized_count: 0,
39        }
40    }
41
42    pub fn with_summary(mut self, summary: String, summarized_count: usize) -> Self {
43        self.summary = Some(summary);
44        self.summarized_count = summarized_count;
45        self
46    }
47
48    pub fn to_llm_messages(&self) -> Vec<ChatMessage> {
49        let mut result = Vec::new();
50
51        if let Some(ref summary) = self.summary {
52            result.push(ChatMessage {
53                role: Role::System,
54                content: format!("[Previous conversation summary]\n{}", summary),
55                name: None,
56                timestamp: None,
57            });
58        }
59
60        result.extend(self.messages.clone());
61        result
62    }
63
64    /// Build LLM messages with per-component token budgets.
65    pub fn to_llm_messages_with_allocation(
66        &self,
67        allocation: &TokenAllocation,
68    ) -> Vec<ChatMessage> {
69        let mut result = Vec::new();
70
71        // Summary - capped to allocation.summary tokens
72        if let Some(ref summary) = self.summary {
73            let summary_content = format!("[Previous conversation summary]\n{}", summary);
74            let summary_tokens = estimate_tokens(&summary_content);
75
76            let final_content = if summary_tokens > allocation.summary {
77                let char_count = summary_content.chars().count() as f64;
78                let ratio = char_count / summary_tokens as f64;
79                let target_chars = (allocation.summary as f64 * ratio) as usize;
80                let truncated = prefix_at_char_boundary(&summary_content, target_chars);
81                format!("{}...", truncated)
82            } else {
83                summary_content
84            };
85
86            result.push(ChatMessage {
87                role: Role::System,
88                content: final_content,
89                name: None,
90                timestamp: None,
91            });
92        }
93
94        // Recent messages - capped to allocation.recent_messages tokens
95        let mut used_message_tokens = 0u32;
96        let mut messages_to_add: Vec<&ChatMessage> = Vec::new();
97
98        for msg in self.messages.iter().rev() {
99            let tokens = estimate_message_tokens(msg);
100            if used_message_tokens + tokens <= allocation.recent_messages {
101                used_message_tokens += tokens;
102                messages_to_add.push(msg);
103            } else {
104                break;
105            }
106        }
107
108        messages_to_add.reverse();
109        for msg in messages_to_add {
110            result.push(msg.clone());
111        }
112
113        // TODO:
114        // Facts - reserved for 'Session Management' feature, not injected yet.
115
116        result
117    }
118
119    pub fn to_llm_messages_with_budget(&self, max_tokens: u32) -> Vec<ChatMessage> {
120        let mut result = Vec::new();
121        let mut used_tokens = 0u32;
122
123        if let Some(ref summary) = self.summary {
124            let summary_msg = ChatMessage {
125                role: Role::System,
126                content: format!("[Previous conversation summary]\n{}", summary),
127                name: None,
128                timestamp: None,
129            };
130            let tokens = estimate_message_tokens(&summary_msg);
131            if tokens <= max_tokens {
132                used_tokens = tokens;
133                result.push(summary_msg);
134            }
135        }
136
137        let mut messages_to_add: Vec<&ChatMessage> = Vec::new();
138        for msg in self.messages.iter().rev() {
139            let tokens = estimate_message_tokens(msg);
140            if used_tokens + tokens <= max_tokens {
141                used_tokens += tokens;
142                messages_to_add.push(msg);
143            } else {
144                break;
145            }
146        }
147
148        messages_to_add.reverse();
149        for msg in messages_to_add {
150            result.push(msg.clone());
151        }
152
153        result
154    }
155
156    pub fn estimated_tokens(&self) -> u32 {
157        let summary_tokens = self
158            .summary
159            .as_ref()
160            .map(|s| estimate_tokens(s))
161            .unwrap_or(0);
162
163        let message_tokens: u32 = self.messages.iter().map(estimate_message_tokens).sum();
164
165        summary_tokens + message_tokens
166    }
167
168    pub fn is_empty(&self) -> bool {
169        self.summary.is_none() && self.messages.is_empty()
170    }
171
172    pub fn message_count(&self) -> usize {
173        self.messages.len()
174    }
175}
176
177/// Language-aware token estimation for multi-language support
178pub fn estimate_tokens(text: &str) -> u32 {
179    if text.is_empty() {
180        return 0;
181    }
182
183    let ascii_chars = text.chars().filter(|c| c.is_ascii()).count();
184    let cjk_chars = text.chars().filter(|c| is_cjk(*c)).count();
185    let other_chars = text.chars().count() - ascii_chars - cjk_chars;
186
187    let estimated =
188        (ascii_chars as f64 / 4.0) + (cjk_chars as f64 * 1.5) + (other_chars as f64 * 1.0);
189
190    estimated.ceil().max(1.0) as u32
191}
192
/// Returns `true` when `c` lies in one of the East Asian ranges that the token
/// estimator weights as CJK. The ranges, in code-point order, are:
/// Hiragana/Katakana, the Katakana phonetic extensions, CJK Extension A,
/// CJK Unified Ideographs, and precomposed Hangul syllables.
fn is_cjk(c: char) -> bool {
    const CJK_RANGES: [(char, char); 5] = [
        ('\u{3040}', '\u{30FF}'), // Japanese Hiragana/Katakana
        ('\u{31F0}', '\u{31FF}'), // Katakana Extensions
        ('\u{3400}', '\u{4DBF}'), // CJK Extension A
        ('\u{4E00}', '\u{9FFF}'), // CJK Unified Ideographs
        ('\u{AC00}', '\u{D7AF}'), // Korean Hangul
    ];

    CJK_RANGES
        .iter()
        .any(|&(first, last)| (first..=last).contains(&c))
}
202
203pub fn estimate_message_tokens(message: &ChatMessage) -> u32 {
204    let role_tokens = 4u32;
205    let content_tokens = estimate_tokens(&message.content);
206    let name_tokens = message
207        .name
208        .as_ref()
209        .map(|n| estimate_tokens(n))
210        .unwrap_or(0);
211    role_tokens + content_tokens + name_tokens
212}
213
214#[derive(Debug, Clone, Serialize, Deserialize)]
215pub enum CompressResult {
216    NotNeeded,
217    Compressed {
218        messages_summarized: usize,
219        new_summary_length: usize,
220        tokens_saved: u32,
221    },
222    AlreadyCompressed,
223    Failed {
224        error: String,
225    },
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    fn make_message(role: Role, content: &str) -> ChatMessage {
        ChatMessage {
            role,
            content: content.to_string(),
            name: None,
            timestamp: None,
        }
    }

    /// Allocation with the given summary / recent-message budgets; the other
    /// components are fixed because these tests do not exercise them.
    fn allocation(summary: u32, recent_messages: u32) -> TokenAllocation {
        TokenAllocation {
            summary,
            recent_messages,
            facts: 512,
            relationships: 0,
        }
    }

    #[test]
    fn test_conversation_context_new() {
        let ctx = ConversationContext::new();
        assert!(ctx.is_empty());
        assert_eq!(ctx.message_count(), 0);
        assert!(ctx.summary.is_none());
    }

    #[test]
    fn test_conversation_context_with_messages() {
        let messages = vec![
            make_message(Role::User, "Hello"),
            make_message(Role::Assistant, "Hi there!"),
        ];
        let ctx = ConversationContext::with_messages(messages);
        assert_eq!(ctx.message_count(), 2);
        assert_eq!(ctx.total_messages, 2);
        assert!(!ctx.is_empty());
    }

    #[test]
    fn test_conversation_context_with_summary() {
        let messages = vec![make_message(Role::User, "Current message")];
        let ctx = ConversationContext::with_messages(messages)
            .with_summary("Previous discussion about weather".to_string(), 5);

        assert!(ctx.summary.is_some());
        assert_eq!(ctx.summarized_count, 5);

        let llm_messages = ctx.to_llm_messages();
        assert_eq!(llm_messages.len(), 2);
        assert_eq!(llm_messages[0].role, Role::System);
        assert!(
            llm_messages[0]
                .content
                .contains("Previous conversation summary")
        );
    }

    #[test]
    fn test_to_llm_messages_without_summary() {
        let messages = vec![
            make_message(Role::User, "Hello"),
            make_message(Role::Assistant, "Hi!"),
        ];
        let ctx = ConversationContext::with_messages(messages);

        let llm_messages = ctx.to_llm_messages();
        assert_eq!(llm_messages.len(), 2);
        assert_eq!(llm_messages[0].role, Role::User);
    }

    #[test]
    fn test_estimated_tokens() {
        let ctx = ConversationContext::with_messages(vec![
            make_message(Role::User, "Hello world"),
            make_message(Role::Assistant, "Hi there"),
        ]);

        // (4 role + 3 content) + (4 role + 2 content)
        assert_eq!(ctx.estimated_tokens(), 13);
    }

    #[test]
    fn test_to_llm_messages_with_budget() {
        let messages: Vec<ChatMessage> = (0..10)
            .map(|i| make_message(Role::User, &format!("Message number {}", i)))
            .collect();
        let ctx = ConversationContext::with_messages(messages);

        let limited = ctx.to_llm_messages_with_budget(50);
        assert!(!limited.is_empty());
        assert!(limited.len() < 10);
        // The kept messages are the most recent ones, still in order.
        assert_eq!(
            limited.last().expect("non-empty").content,
            "Message number 9"
        );
    }

    #[test]
    fn test_to_llm_messages_with_budget_drops_oversized_summary() {
        let ctx = ConversationContext::with_messages(vec![make_message(Role::User, "Hello")])
            .with_summary("x".repeat(1000), 10);

        let limited = ctx.to_llm_messages_with_budget(50);
        assert_eq!(limited.len(), 1);
        assert_eq!(limited[0].role, Role::User);
        assert_eq!(limited[0].content, "Hello");
    }

    #[test]
    fn test_to_llm_messages_with_allocation_caps_summary() {
        let long_summary = "x".repeat(10000); // ~2500 tokens
        let messages = vec![make_message(Role::User, "Hello")];
        let ctx = ConversationContext::with_messages(messages).with_summary(long_summary, 50);

        let result = ctx.to_llm_messages_with_allocation(&allocation(100, 2048));
        // Summary should be truncated
        let summary_msg = &result[0];
        let summary_tokens = estimate_tokens(&summary_msg.content);
        assert!(
            summary_tokens <= 120,
            "Summary should be roughly capped: got {}",
            summary_tokens
        );
        // Recent message should still be present
        assert!(result.len() >= 2);
    }

    #[test]
    fn test_to_llm_messages_with_allocation_caps_recent() {
        let messages: Vec<ChatMessage> = (0..50)
            .map(|i| {
                make_message(
                    Role::User,
                    &format!(
                        "Message number {} with some extra text to increase tokens",
                        i
                    ),
                )
            })
            .collect();
        let ctx = ConversationContext::with_messages(messages);

        let result = ctx.to_llm_messages_with_allocation(&allocation(1024, 200));
        assert!(
            result.len() < 50,
            "Should have fewer messages due to cap: got {}",
            result.len()
        );
        // Messages should be the most recent
        let last = result.last().expect("at least one message fits");
        assert!(
            last.content.contains("49"),
            "Last message should be the most recent"
        );
    }

    #[test]
    fn test_prefix_at_char_boundary_handles_unicode() {
        let text = "제 이름은 Jay이고 계약서를 확인하고 싶어요";
        let prefix = prefix_at_char_boundary(text, 7);
        assert_eq!(prefix.chars().count(), 7);
        assert!(text.starts_with(prefix));
        assert_eq!(prefix_at_char_boundary(text, 0), "");
    }

    #[test]
    fn test_to_llm_messages_with_allocation_no_summary() {
        let messages = vec![
            make_message(Role::User, "Hello"),
            make_message(Role::Assistant, "Hi!"),
        ];
        let ctx = ConversationContext::with_messages(messages);

        let result = ctx.to_llm_messages_with_allocation(&allocation(1024, 2048));
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_estimate_tokens_english() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("test"), 1);
        assert_eq!(estimate_tokens("hello world"), 3);
    }

    #[test]
    fn test_estimate_tokens_korean() {
        let tokens = estimate_tokens("안녕하세요");
        assert!(
            tokens >= 5,
            "Korean text should have more tokens: {}",
            tokens
        );
    }

    #[test]
    fn test_estimate_tokens_japanese() {
        let tokens = estimate_tokens("こんにちは");
        assert!(
            tokens >= 5,
            "Japanese text should have more tokens: {}",
            tokens
        );
    }

    #[test]
    fn test_estimate_tokens_chinese() {
        let tokens = estimate_tokens("你好世界");
        assert!(
            tokens >= 4,
            "Chinese text should have more tokens: {}",
            tokens
        );
    }

    #[test]
    fn test_estimate_tokens_mixed() {
        let tokens = estimate_tokens("Hello 안녕 World 世界");
        assert!(tokens >= 6, "Mixed text: {}", tokens);
    }

    #[test]
    fn test_estimate_message_tokens_adds_role_overhead() {
        assert_eq!(estimate_message_tokens(&make_message(Role::User, "test")), 5);
        assert_eq!(estimate_message_tokens(&make_message(Role::Assistant, "")), 4);
    }

    #[test]
    fn test_is_cjk_ranges() {
        assert!(is_cjk('中'));
        assert!(is_cjk('한'));
        assert!(is_cjk('あ'));
        assert!(!is_cjk('a'));
        assert!(!is_cjk('\u{A000}'));
    }

    #[test]
    fn test_compress_result_variants() {
        let not_needed = CompressResult::NotNeeded;
        assert!(matches!(not_needed, CompressResult::NotNeeded));

        let compressed = CompressResult::Compressed {
            messages_summarized: 5,
            new_summary_length: 100,
            tokens_saved: 500,
        };
        // A plain `if let` would silently pass on a wrong variant; fail loudly.
        match compressed {
            CompressResult::Compressed {
                messages_summarized,
                new_summary_length,
                tokens_saved,
            } => {
                assert_eq!(messages_summarized, 5);
                assert_eq!(new_summary_length, 100);
                assert_eq!(tokens_saved, 500);
            }
            other => panic!("expected Compressed, got {:?}", other),
        }

        let failed = CompressResult::Failed {
            error: "test error".to_string(),
        };
        match failed {
            CompressResult::Failed { error } => assert_eq!(error, "test error"),
            other => panic!("expected Failed, got {:?}", other),
        }
    }
}