// ai_agent/services/compact/compact.rs
use crate::types::api_types::{Message, MessageRole};
use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
15#[serde(rename_all = "lowercase")]
16pub enum CompactDirection {
17 Head,
19 Tail,
21 #[default]
23 Smart,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct CompactResult {
29 pub success: bool,
31 pub messages_removed: usize,
33 pub tokens_before: u64,
35 pub tokens_after: u64,
37 pub direction: CompactDirection,
39 pub summary: String,
41 pub messages_to_keep: Vec<Message>,
43 pub error: Option<String>,
45}
46
47#[derive(Debug, Clone, Default)]
49pub struct CompactOptions {
50 pub max_tokens: Option<u64>,
52 pub direction: CompactDirection,
54 pub create_boundary: bool,
56 pub system_prompt: Option<String>,
58}
59
60#[derive(Debug, Clone)]
62struct MessageGroup {
63 start_index: usize,
65 messages: Vec<Message>,
67 token_count: u64,
69 is_boundary: bool,
71}
72
73fn group_messages(messages: &[Message]) -> Vec<MessageGroup> {
76 let mut groups = Vec::new();
77 let mut current_group = MessageGroup {
78 start_index: 0,
79 messages: Vec::new(),
80 token_count: 0,
81 is_boundary: false,
82 };
83
84 for (i, msg) in messages.iter().enumerate() {
85 match &msg.role {
86 MessageRole::User => {
87 if !current_group.messages.is_empty() {
89 groups.push(std::mem::replace(
90 &mut current_group,
91 MessageGroup {
92 start_index: i,
93 messages: Vec::new(),
94 token_count: 0,
95 is_boundary: false,
96 },
97 ));
98 }
99 current_group.messages.push(msg.clone());
100 current_group.token_count += estimate_tokens_for_message(msg);
101 }
102 MessageRole::Assistant | MessageRole::Tool | MessageRole::System => {
103 current_group.messages.push(msg.clone());
104 current_group.token_count += estimate_tokens_for_message(msg);
105 }
106 }
107 }
108
109 if !current_group.messages.is_empty() {
111 groups.push(current_group);
112 }
113
114 if let Some(last) = groups.last_mut() {
116 last.is_boundary = true;
117 }
118
119 groups
120}
121
122fn estimate_tokens_for_message(msg: &Message) -> u64 {
124 let content_tokens = (msg.content.len() as u64 + 3) / 4;
126
127 let tool_call_tokens = msg
129 .tool_calls
130 .as_ref()
131 .map(|calls| {
132 calls
133 .iter()
134 .map(|tc| {
135 let name_tokens = (tc.name.len() as u64 + 3) / 4;
136 let args_tokens = (tc.arguments.to_string().len() as u64 + 3) / 4;
137 name_tokens + args_tokens + 2 })
139 .sum::<u64>()
140 })
141 .unwrap_or(0);
142
143 let role_overhead: u64 = 4;
145
146 content_tokens + tool_call_tokens + role_overhead
147}
148
149pub async fn compact_messages(
157 messages: &[Message],
158 options: CompactOptions,
159) -> Result<CompactResult, String> {
160 if messages.is_empty() {
161 return Ok(CompactResult {
162 success: true,
163 messages_removed: 0,
164 tokens_before: 0,
165 tokens_after: 0,
166 direction: options.direction,
167 summary: String::new(),
168 messages_to_keep: Vec::new(),
169 error: None,
170 });
171 }
172
173 let tokens_before: u64 = messages.iter().map(estimate_tokens_for_message).sum();
175
176 let target_tokens = options.max_tokens.unwrap_or(tokens_before);
177
178 if tokens_before <= target_tokens {
180 return Ok(CompactResult {
181 success: true,
182 messages_removed: 0,
183 tokens_before,
184 tokens_after: tokens_before,
185 direction: options.direction,
186 summary: String::new(),
187 messages_to_keep: messages.to_vec(),
188 error: None,
189 });
190 }
191
192 let groups = group_messages(messages);
194
195 let direction = if options.direction == CompactDirection::Smart {
197 get_recommended_direction(messages.len(), tokens_before, target_tokens)
198 } else {
199 options.direction
200 };
201
202 let (kept_groups, compacted_groups) =
204 select_groups_to_compact(&groups, target_tokens, direction);
205
206 let messages_to_keep: Vec<Message> = kept_groups
208 .iter()
209 .flat_map(|g| g.messages.clone())
210 .collect();
211
212 let messages_removed: usize = compacted_groups.iter().map(|g| g.messages.len()).sum();
213
214 let summary = create_compact_summary(&compacted_groups);
216
217 let tokens_after: u64 = messages_to_keep
219 .iter()
220 .map(estimate_tokens_for_message)
221 .sum();
222
223 log::info!(
224 "[compact] Compacted {} messages: {} -> {} tokens (direction: {:?})",
225 messages_removed,
226 tokens_before,
227 tokens_after,
228 direction
229 );
230
231 Ok(CompactResult {
232 success: true,
233 messages_removed,
234 tokens_before,
235 tokens_after,
236 direction,
237 summary,
238 messages_to_keep,
239 error: None,
240 })
241}
242
243fn select_groups_to_compact(
245 groups: &[MessageGroup],
246 target_tokens: u64,
247 direction: CompactDirection,
248) -> (Vec<&MessageGroup>, Vec<&MessageGroup>) {
249 let (boundary, non_boundary): (Vec<_>, Vec<_>) = groups.iter().partition(|g| g.is_boundary);
251
252 let boundary_tokens: u64 = boundary.iter().map(|g| g.token_count).sum();
254 let mut remaining_budget = target_tokens.saturating_sub(boundary_tokens);
255
256 let mut kept = boundary;
257 let mut compacted = Vec::new();
258
259 match direction {
260 CompactDirection::Head => {
261 let mut non_boundary_iter = non_boundary.into_iter().peekable();
263 while let Some(group) = non_boundary_iter.next() {
264 if remaining_budget >= group.token_count {
265 kept.push(group);
266 remaining_budget -= group.token_count;
267 } else {
268 compacted.push(group);
269 compacted.extend(non_boundary_iter);
271 break;
272 }
273 }
274 }
275 CompactDirection::Tail => {
276 let mut non_boundary_iter = non_boundary.into_iter().rev().peekable();
278 while let Some(group) = non_boundary_iter.next() {
279 if remaining_budget >= group.token_count {
280 kept.push(group);
281 remaining_budget -= group.token_count;
282 } else {
283 compacted.push(group);
284 compacted.extend(non_boundary_iter);
286 break;
287 }
288 }
289 }
290 CompactDirection::Smart => {
291 let mut non_boundary_iter = non_boundary.into_iter().peekable();
294 while let Some(group) = non_boundary_iter.next() {
295 if remaining_budget >= group.token_count {
296 kept.push(group);
297 remaining_budget -= group.token_count;
298 } else {
299 compacted.push(group);
300 compacted.extend(non_boundary_iter);
301 break;
302 }
303 }
304 }
305 }
306
307 kept.sort_by_key(|g| g.start_index);
309
310 (kept, compacted)
311}
312
313fn create_compact_summary(compacted_groups: &[&MessageGroup]) -> String {
315 if compacted_groups.is_empty() {
316 return String::new();
317 }
318
319 let mut summary = String::new();
320 let total_compacted: usize = compacted_groups.iter().map(|g| g.messages.len()).sum();
321 let total_tokens: u64 = compacted_groups.iter().map(|g| g.token_count).sum();
322
323 summary.push_str(&format!(
324 "Compacted {} messages (~{} tokens) from the conversation history.\n\n",
325 total_compacted, total_tokens
326 ));
327
328 let mut user_messages = 0;
330 let mut assistant_messages = 0;
331 let mut tool_messages = 0;
332
333 for group in compacted_groups {
334 for msg in &group.messages {
335 match &msg.role {
336 MessageRole::User => user_messages += 1,
337 MessageRole::Assistant => assistant_messages += 1,
338 MessageRole::Tool => tool_messages += 1,
339 MessageRole::System => {}
340 }
341 }
342 }
343
344 if user_messages > 0 || assistant_messages > 0 {
345 summary.push_str(&format!(
346 "The compacted section contained {} user messages and {} assistant responses",
347 user_messages, assistant_messages
348 ));
349 if tool_messages > 0 {
350 summary.push_str(&format!(" with {} tool results", tool_messages));
351 }
352 summary.push_str(".\n");
353 }
354
355 summary
356}
357
358pub fn get_recommended_direction(
360 message_count: usize,
361 total_tokens: u64,
362 max_tokens: u64,
363) -> CompactDirection {
364 if total_tokens <= max_tokens {
365 return CompactDirection::Smart;
366 }
367
368 if message_count > 10 {
371 CompactDirection::Head
372 } else {
373 CompactDirection::Smart
374 }
375}
376
/// Estimates how many average-sized messages must be removed to bring
/// `current_tokens` down to `target_tokens`.
///
/// Returns `0` when the conversation is already within the target. Otherwise
/// returns the token excess divided by `avg_tokens_per_message`, rounded
/// down.
///
/// An average of `0` is treated as `1` token per message instead of
/// panicking on division by zero. A quotient that does not fit in `usize`
/// saturates at `usize::MAX` rather than being silently truncated.
pub fn calculate_messages_to_remove(
    current_tokens: u64,
    target_tokens: u64,
    avg_tokens_per_message: u64,
) -> usize {
    // Zero when already within budget, so the division below yields 0.
    let tokens_to_remove = current_tokens.saturating_sub(target_tokens);
    let per_message = avg_tokens_per_message.max(1);

    usize::try_from(tokens_to_remove / per_message).unwrap_or(usize::MAX)
}
390
/// Roughly estimates the number of tokens in `text`.
///
/// Uses one token per four bytes of UTF-8, rounded up; the empty string
/// yields `0`. The function is `const`, so it can also size compile-time
/// constants.
#[must_use]
pub const fn rough_token_estimation(text: &str) -> u64 {
    // Adding 3 before the integer division rounds up to the next token.
    (text.len() as u64 + 3) / 4
}
396
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a message with the given role and content; every other field
    // takes its default value.
    fn message(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            ..Default::default()
        }
    }

    // Two user/assistant exchanges: "Question N" estimates to 7 tokens and
    // "Answer N" to 6, so each exchange forms a 13-token group.
    fn two_exchanges() -> Vec<Message> {
        vec![
            message(MessageRole::User, "Question 1"),
            message(MessageRole::Assistant, "Answer 1"),
            message(MessageRole::User, "Question 2"),
            message(MessageRole::Assistant, "Answer 2"),
        ]
    }

    #[test]
    fn test_compact_direction_default() {
        let options = CompactOptions::default();
        assert_eq!(options.direction, CompactDirection::Smart);
    }

    #[test]
    fn test_get_recommended_direction_no_compact() {
        let dir = get_recommended_direction(5, 1000, 2000);
        assert_eq!(dir, CompactDirection::Smart);
    }

    #[test]
    fn test_calculate_messages_to_remove() {
        let count = calculate_messages_to_remove(5000, 2000, 500);
        assert_eq!(count, 6);
    }

    #[test]
    fn test_calculate_messages_to_remove_no_need() {
        let count = calculate_messages_to_remove(1000, 2000, 500);
        assert_eq!(count, 0);
    }

    #[test]
    fn test_rough_token_estimation() {
        let text = "Hello, this is a test message with some content.";
        assert!(rough_token_estimation(text) > 0);

        // One token per four bytes, rounded up.
        assert_eq!(rough_token_estimation(""), 0);
        assert_eq!(rough_token_estimation("abcd"), 1);
        assert_eq!(rough_token_estimation("abcde"), 2);
    }

    #[test]
    fn test_estimate_tokens_for_message() {
        let msg = message(MessageRole::User, "Hello, how are you?");
        let tokens = estimate_tokens_for_message(&msg);
        assert!(tokens > 0);
        // The per-message overhead makes the estimate exceed the content alone.
        assert!(tokens > rough_token_estimation(&msg.content));
    }

    #[test]
    fn test_group_messages_basic() {
        let messages = two_exchanges();

        let groups = group_messages(&messages);
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0].start_index, 0);
        assert_eq!(groups[1].start_index, 2);
        assert_eq!(groups[0].messages.len(), 2);
        assert_eq!(groups[1].messages.len(), 2);
        assert!(!groups[0].is_boundary);
        assert!(groups[1].is_boundary);
    }

    #[tokio::test]
    async fn test_compact_messages_empty() {
        let result = compact_messages(&[], CompactOptions::default())
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.messages_removed, 0);
        assert!(result.messages_to_keep.is_empty());
    }

    #[tokio::test]
    async fn test_compact_messages_within_budget() {
        let messages = vec![message(MessageRole::User, "Short message")];
        let options = CompactOptions {
            max_tokens: Some(1000000),
            ..Default::default()
        };
        let result = compact_messages(&messages, options).await.unwrap();
        assert!(result.success);
        assert_eq!(result.messages_removed, 0);
        assert_eq!(result.messages_to_keep.len(), 1);
        assert_eq!(result.tokens_after, result.tokens_before);
    }

    #[tokio::test]
    async fn test_compact_messages_over_budget() {
        let messages = two_exchanges();
        // Budget for exactly one 13-token group: the older exchange must go.
        let options = CompactOptions {
            max_tokens: Some(13),
            ..Default::default()
        };
        let result = compact_messages(&messages, options).await.unwrap();
        assert!(result.success);
        assert_eq!(result.tokens_before, 26);
        assert_eq!(result.tokens_after, 13);
        assert_eq!(result.messages_removed, 2);
        assert_eq!(result.messages_to_keep.len(), 2);
        assert_eq!(result.messages_to_keep[0].content, "Question 2");
        assert!(result.summary.contains("2 messages"));
    }

    // Plain `#[test]`: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_create_compact_summary() {
        let g1 = MessageGroup {
            start_index: 0,
            messages: vec![message(MessageRole::User, "Hello")],
            token_count: 10,
            is_boundary: false,
        };
        let g2 = MessageGroup {
            start_index: 1,
            messages: vec![message(MessageRole::Assistant, "Hi there")],
            token_count: 10,
            is_boundary: false,
        };
        let groups = vec![&g1, &g2];

        let summary = create_compact_summary(&groups);
        assert!(summary.contains("2 messages"));
        assert!(summary.contains("~20 tokens"));
        assert!(summary.contains("1 user messages and 1 assistant responses"));
    }
}