1use serde::{Deserialize, Serialize};
4
5use super::token_budget::TokenAllocation;
6use ai_agents_core::{ChatMessage, Role};
7
/// Conversation state prepared for LLM consumption: an optional rolling
/// summary standing in for older turns, plus recent messages kept verbatim.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ConversationContext {
    /// Summary text covering messages that were compressed away, if any.
    pub summary: Option<String>,
    /// Recent messages retained verbatim (chronological order).
    pub messages: Vec<ChatMessage>,
    /// Message count recorded at construction via `with_messages`;
    /// not updated afterwards in this module.
    pub total_messages: usize,
    /// Number of older messages the summary replaces (0 when no summary).
    pub summarized_count: usize,
}
15
16impl ConversationContext {
17 pub fn new() -> Self {
18 Self::default()
19 }
20
21 pub fn with_messages(messages: Vec<ChatMessage>) -> Self {
22 let total = messages.len();
23 Self {
24 summary: None,
25 messages,
26 total_messages: total,
27 summarized_count: 0,
28 }
29 }
30
31 pub fn with_summary(mut self, summary: String, summarized_count: usize) -> Self {
32 self.summary = Some(summary);
33 self.summarized_count = summarized_count;
34 self
35 }
36
37 pub fn to_llm_messages(&self) -> Vec<ChatMessage> {
38 let mut result = Vec::new();
39
40 if let Some(ref summary) = self.summary {
41 result.push(ChatMessage {
42 role: Role::System,
43 content: format!("[Previous conversation summary]\n{}", summary),
44 name: None,
45 timestamp: None,
46 });
47 }
48
49 result.extend(self.messages.clone());
50 result
51 }
52
53 pub fn to_llm_messages_with_allocation(
55 &self,
56 allocation: &TokenAllocation,
57 ) -> Vec<ChatMessage> {
58 let mut result = Vec::new();
59
60 if let Some(ref summary) = self.summary {
62 let summary_content = format!("[Previous conversation summary]\n{}", summary);
63 let summary_tokens = estimate_tokens(&summary_content);
64
65 let final_content = if summary_tokens > allocation.summary {
66 let ratio = summary_content.len() as f64 / summary_tokens as f64;
67 let target_chars = (allocation.summary as f64 * ratio) as usize;
68 let truncated = &summary_content[..target_chars.min(summary_content.len())];
69 format!("{}...", truncated)
70 } else {
71 summary_content
72 };
73
74 result.push(ChatMessage {
75 role: Role::System,
76 content: final_content,
77 name: None,
78 timestamp: None,
79 });
80 }
81
82 let mut used_message_tokens = 0u32;
84 let mut messages_to_add: Vec<&ChatMessage> = Vec::new();
85
86 for msg in self.messages.iter().rev() {
87 let tokens = estimate_message_tokens(msg);
88 if used_message_tokens + tokens <= allocation.recent_messages {
89 used_message_tokens += tokens;
90 messages_to_add.push(msg);
91 } else {
92 break;
93 }
94 }
95
96 messages_to_add.reverse();
97 for msg in messages_to_add {
98 result.push(msg.clone());
99 }
100
101 result
105 }
106
107 pub fn to_llm_messages_with_budget(&self, max_tokens: u32) -> Vec<ChatMessage> {
108 let mut result = Vec::new();
109 let mut used_tokens = 0u32;
110
111 if let Some(ref summary) = self.summary {
112 let summary_msg = ChatMessage {
113 role: Role::System,
114 content: format!("[Previous conversation summary]\n{}", summary),
115 name: None,
116 timestamp: None,
117 };
118 let tokens = estimate_message_tokens(&summary_msg);
119 if tokens <= max_tokens {
120 used_tokens = tokens;
121 result.push(summary_msg);
122 }
123 }
124
125 let mut messages_to_add: Vec<&ChatMessage> = Vec::new();
126 for msg in self.messages.iter().rev() {
127 let tokens = estimate_message_tokens(msg);
128 if used_tokens + tokens <= max_tokens {
129 used_tokens += tokens;
130 messages_to_add.push(msg);
131 } else {
132 break;
133 }
134 }
135
136 messages_to_add.reverse();
137 for msg in messages_to_add {
138 result.push(msg.clone());
139 }
140
141 result
142 }
143
144 pub fn estimated_tokens(&self) -> u32 {
145 let summary_tokens = self
146 .summary
147 .as_ref()
148 .map(|s| estimate_tokens(s))
149 .unwrap_or(0);
150
151 let message_tokens: u32 = self.messages.iter().map(estimate_message_tokens).sum();
152
153 summary_tokens + message_tokens
154 }
155
156 pub fn is_empty(&self) -> bool {
157 self.summary.is_none() && self.messages.is_empty()
158 }
159
160 pub fn message_count(&self) -> usize {
161 self.messages.len()
162 }
163}
164
/// Heuristic token count for `text`.
///
/// Weights: ~4 ASCII characters per token, 1.5 tokens per CJK character
/// (Han, Hangul, kana), and 1 token per remaining non-ASCII character.
/// Empty input yields 0; any non-empty input yields at least 1.
pub fn estimate_tokens(text: &str) -> u32 {
    if text.is_empty() {
        return 0;
    }

    // Single pass: classify every character into one of three buckets.
    // ASCII and CJK are disjoint, so the if/else chain loses nothing.
    let (ascii, cjk, other) =
        text.chars().fold((0usize, 0usize, 0usize), |(a, c, o), ch| {
            if ch.is_ascii() {
                (a + 1, c, o)
            } else if is_cjk(ch) {
                (a, c + 1, o)
            } else {
                (a, c, o + 1)
            }
        });

    let weighted = ascii as f64 / 4.0 + cjk as f64 * 1.5 + other as f64;
    weighted.ceil().max(1.0) as u32
}

/// True for characters in the common CJK blocks: CJK Unified Ideographs
/// (and Extension A), Hangul syllables, hiragana/katakana, and the
/// katakana phonetic extensions.
fn is_cjk(c: char) -> bool {
    let cp = c as u32;
    (0x4E00..=0x9FFF).contains(&cp)
        || (0x3400..=0x4DBF).contains(&cp)
        || (0xAC00..=0xD7AF).contains(&cp)
        || (0x3040..=0x30FF).contains(&cp)
        || (0x31F0..=0x31FF).contains(&cp)
}
190
191pub fn estimate_message_tokens(message: &ChatMessage) -> u32 {
192 let role_tokens = 4u32;
193 let content_tokens = estimate_tokens(&message.content);
194 let name_tokens = message
195 .name
196 .as_ref()
197 .map(|n| estimate_tokens(n))
198 .unwrap_or(0);
199 role_tokens + content_tokens + name_tokens
200}
201
/// Outcome of a context-compression pass. The operation that produces these
/// variants lives elsewhere; semantics below follow the variant names and
/// fields — confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompressResult {
    /// No compression performed (presumably the context fit its budget).
    NotNeeded,
    /// Older messages were folded into a summary.
    Compressed {
        /// Number of messages replaced by the summary.
        messages_summarized: usize,
        /// Length of the newly produced summary.
        new_summary_length: usize,
        /// Estimated token count saved by the compression.
        tokens_saved: u32,
    },
    /// A summary already exists; no further compression done.
    AlreadyCompressed,
    /// Compression was attempted but failed.
    Failed {
        /// Human-readable description of the failure.
        error: String,
    },
}
215
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience constructor: a bare message with no name or timestamp.
    fn make_message(role: Role, content: &str) -> ChatMessage {
        ChatMessage {
            role,
            content: content.to_string(),
            name: None,
            timestamp: None,
        }
    }

    // A fresh context is empty in every observable way.
    #[test]
    fn test_conversation_context_new() {
        let ctx = ConversationContext::new();
        assert!(ctx.is_empty());
        assert_eq!(ctx.message_count(), 0);
        assert!(ctx.summary.is_none());
    }

    // `with_messages` records the initial count and is non-empty.
    #[test]
    fn test_conversation_context_with_messages() {
        let messages = vec![
            make_message(Role::User, "Hello"),
            make_message(Role::Assistant, "Hi there!"),
        ];
        let ctx = ConversationContext::with_messages(messages);
        assert_eq!(ctx.message_count(), 2);
        assert_eq!(ctx.total_messages, 2);
        assert!(!ctx.is_empty());
    }

    // A summary prepends one system message when flattening for the LLM.
    #[test]
    fn test_conversation_context_with_summary() {
        let messages = vec![make_message(Role::User, "Current message")];
        let ctx = ConversationContext::with_messages(messages)
            .with_summary("Previous discussion about weather".to_string(), 5);

        assert!(ctx.summary.is_some());
        assert_eq!(ctx.summarized_count, 5);

        let llm_messages = ctx.to_llm_messages();
        assert_eq!(llm_messages.len(), 2);
        assert!(
            llm_messages[0]
                .content
                .contains("Previous conversation summary")
        );
    }

    // Without a summary, flattening passes messages through unchanged.
    #[test]
    fn test_to_llm_messages_without_summary() {
        let messages = vec![
            make_message(Role::User, "Hello"),
            make_message(Role::Assistant, "Hi!"),
        ];
        let ctx = ConversationContext::with_messages(messages);

        let llm_messages = ctx.to_llm_messages();
        assert_eq!(llm_messages.len(), 2);
        assert_eq!(llm_messages[0].role, Role::User);
    }

    // Any non-empty context reports a positive token estimate.
    #[test]
    fn test_estimated_tokens() {
        let ctx = ConversationContext::with_messages(vec![
            make_message(Role::User, "Hello world"),
            make_message(Role::Assistant, "Hi there"),
        ]);

        let tokens = ctx.estimated_tokens();
        assert!(tokens > 0);
    }

    // A tight overall budget (50 tokens) must drop some of the 10 messages.
    #[test]
    fn test_to_llm_messages_with_budget() {
        let messages: Vec<ChatMessage> = (0..10)
            .map(|i| make_message(Role::User, &format!("Message number {}", i)))
            .collect();
        let ctx = ConversationContext::with_messages(messages);

        let limited = ctx.to_llm_messages_with_budget(50);
        assert!(limited.len() < 10);
    }

    // An oversized summary must be truncated to roughly the summary cap
    // (120 allows slack for the estimator's rounding and the "..." suffix).
    #[test]
    fn test_to_llm_messages_with_allocation_caps_summary() {
        let long_summary = "x".repeat(10000);
        let messages = vec![make_message(Role::User, "Hello")];
        let ctx = ConversationContext::with_messages(messages).with_summary(long_summary, 50);

        let allocation = TokenAllocation {
            summary: 100,
            recent_messages: 2048,
            facts: 512,
        };

        let result = ctx.to_llm_messages_with_allocation(&allocation);
        let summary_msg = &result[0];
        let summary_tokens = estimate_tokens(&summary_msg.content);
        assert!(
            summary_tokens <= 120,
            "Summary should be roughly capped: got {}",
            summary_tokens
        );
        assert!(result.len() >= 2);
    }

    // A tight recent-message budget keeps only the newest messages, and the
    // last entry in the output is still the most recent message.
    #[test]
    fn test_to_llm_messages_with_allocation_caps_recent() {
        let messages: Vec<ChatMessage> = (0..50)
            .map(|i| {
                make_message(
                    Role::User,
                    &format!(
                        "Message number {} with some extra text to increase tokens",
                        i
                    ),
                )
            })
            .collect();
        let ctx = ConversationContext::with_messages(messages);

        let allocation = TokenAllocation {
            summary: 1024,
            recent_messages: 200,
            facts: 512,
        };

        let result = ctx.to_llm_messages_with_allocation(&allocation);
        assert!(
            result.len() < 50,
            "Should have fewer messages due to cap: got {}",
            result.len()
        );
        let last = &result[result.len() - 1];
        assert!(
            last.content.contains("49"),
            "Last message should be the most recent"
        );
    }

    // With no summary and generous budgets, everything passes through.
    #[test]
    fn test_to_llm_messages_with_allocation_no_summary() {
        let messages = vec![
            make_message(Role::User, "Hello"),
            make_message(Role::Assistant, "Hi!"),
        ];
        let ctx = ConversationContext::with_messages(messages);

        let allocation = TokenAllocation {
            summary: 1024,
            recent_messages: 2048,
            facts: 512,
        };

        let result = ctx.to_llm_messages_with_allocation(&allocation);
        assert_eq!(result.len(), 2);
    }

    // Exact values pin the ASCII weighting: 4 chars/token, ceil, min 1.
    #[test]
    fn test_estimate_tokens_english() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("test"), 1);
        assert_eq!(estimate_tokens("hello world"), 3);
    }

    // 5 Hangul syllables at 1.5 tokens each => at least 5 tokens.
    #[test]
    fn test_estimate_tokens_korean() {
        let tokens = estimate_tokens("안녕하세요");
        assert!(
            tokens >= 5,
            "Korean text should have more tokens: {}",
            tokens
        );
    }

    // 5 hiragana characters at 1.5 tokens each => at least 5 tokens.
    #[test]
    fn test_estimate_tokens_japanese() {
        let tokens = estimate_tokens("こんにちは");
        assert!(
            tokens >= 5,
            "Japanese text should have more tokens: {}",
            tokens
        );
    }

    // 4 Han characters at 1.5 tokens each => at least 4 tokens.
    #[test]
    fn test_estimate_tokens_chinese() {
        let tokens = estimate_tokens("你好世界");
        assert!(
            tokens >= 4,
            "Chinese text should have more tokens: {}",
            tokens
        );
    }

    // Mixed ASCII + CJK text combines both weightings.
    #[test]
    fn test_estimate_tokens_mixed() {
        let tokens = estimate_tokens("Hello 안녕 World 世界");
        assert!(tokens >= 6, "Mixed text: {}", tokens);
    }

    // Construct and pattern-match each CompressResult variant.
    #[test]
    fn test_compress_result_variants() {
        let not_needed = CompressResult::NotNeeded;
        assert!(matches!(not_needed, CompressResult::NotNeeded));

        let compressed = CompressResult::Compressed {
            messages_summarized: 5,
            new_summary_length: 100,
            tokens_saved: 500,
        };
        if let CompressResult::Compressed {
            messages_summarized,
            ..
        } = compressed
        {
            assert_eq!(messages_summarized, 5);
        }

        let failed = CompressResult::Failed {
            error: "test error".to_string(),
        };
        if let CompressResult::Failed { error } = failed {
            assert_eq!(error, "test error");
        }
    }
}