1use std::collections::{HashMap, HashSet};
2use std::future::Future;
3use std::pin::Pin;
4use std::time::Duration;
5
6use crate::client::models::{Message as LLMMessage, MessageOptions};
7use crate::client::LLMClient;
8
9use crate::controller::types::{ContentBlock, Message, TextBlock, TurnId, UserMessage};
10
/// How a tool result outside the retained window is rewritten during compaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolCompaction {
    /// Replace the tool result content with a precomputed summary, when one is available.
    Summarize,
    /// Replace the tool result content with a fixed placeholder.
    Redact,
}

impl std::fmt::Display for ToolCompaction {
    /// Writes the lowercase strategy name (`summarize` / `redact`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Summarize => "summarize",
            Self::Redact => "redact",
        };
        f.write_str(label)
    }
}
28
/// Counters describing what a single compaction pass changed.
#[derive(Debug, Clone, Default)]
pub struct CompactionResult {
    /// Tool results whose content was replaced by a summary.
    pub tool_results_summarized: usize,
    /// Tool results whose content was replaced by a redaction placeholder.
    pub tool_results_redacted: usize,
    /// Number of turns that fell outside the retained window.
    pub turns_compacted: usize,
}

impl CompactionResult {
    /// Total number of tool results that were rewritten (summarized plus redacted).
    ///
    /// `turns_compacted` is not part of the sum; it counts turns, not tool results.
    pub fn total_compacted(&self) -> usize {
        let Self {
            tool_results_summarized,
            tool_results_redacted,
            ..
        } = self;
        tool_results_summarized + tool_results_redacted
    }
}
46
/// Failure modes of an asynchronous compaction pass.
#[derive(Debug)]
pub enum CompactionError {
    /// The LLM request failed; carries the error text.
    LLMError(String),
    /// The LLM request did not complete within the configured timeout.
    Timeout,
    /// A configuration problem, described by the message.
    ConfigError(String),
}

impl std::fmt::Display for CompactionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LLMError(detail) => write!(f, "LLM error: {}", detail),
            Self::Timeout => f.write_str("compaction timed out"),
            Self::ConfigError(detail) => write!(f, "config error: {}", detail),
        }
    }
}

impl std::error::Error for CompactionError {}
69
70pub trait Compactor: Send + Sync {
74 fn should_compact(&self, context_used: i64, context_limit: i32) -> bool;
80
81 fn compact(
88 &self,
89 conversation: &mut Vec<Message>,
90 compact_summaries: &HashMap<String, String>,
91 ) -> CompactionResult;
92
93 fn is_async(&self) -> bool {
96 false
97 }
98}
99
100pub trait AsyncCompactor: Compactor {
103 fn compact_async<'a>(
109 &'a self,
110 conversation: Vec<Message>,
111 compact_summaries: &'a HashMap<String, String>,
112 ) -> Pin<Box<dyn Future<Output = Result<(Vec<Message>, CompactionResult), CompactionError>> + Send + 'a>>;
113}
114
115pub struct ThresholdCompactor {
119 threshold: f64,
122
123 keep_recent_turns: usize,
126
127 tool_compaction: ToolCompaction,
129}
130
131impl ThresholdCompactor {
132 pub fn new(
142 threshold: f64,
143 keep_recent_turns: usize,
144 tool_compaction: ToolCompaction,
145 ) -> Result<Self, String> {
146 if threshold <= 0.0 || threshold >= 1.0 {
147 return Err(format!(
148 "threshold must be between 0 and 1 (exclusive), got {}",
149 threshold
150 ));
151 }
152
153 Ok(Self {
154 threshold,
155 keep_recent_turns,
156 tool_compaction,
157 })
158 }
159
160 fn unique_turn_ids(&self, conversation: &[Message]) -> Vec<TurnId> {
162 let mut seen = HashSet::new();
163 let mut ids = Vec::new();
164
165 for msg in conversation {
166 let turn_id = msg.turn_id();
167 if seen.insert(turn_id.clone()) {
168 ids.push(turn_id.clone());
169 }
170 }
171
172 ids
173 }
174
175 fn compact_message(
177 &self,
178 msg: &mut Message,
179 compact_summaries: &HashMap<String, String>,
180 ) -> (usize, usize) {
181 let mut summarized = 0;
182 let mut redacted = 0;
183
184 for block in msg.content_mut() {
185 if let ContentBlock::ToolResult(tool_result) = block {
186 match self.tool_compaction {
187 ToolCompaction::Summarize => {
188 if let Some(summary) = compact_summaries.get(&tool_result.tool_use_id) {
190 tool_result.content = summary.clone();
191 summarized += 1;
192 tracing::debug!(
193 tool_use_id = %tool_result.tool_use_id,
194 "Tool result summarized"
195 );
196 }
197 }
199 ToolCompaction::Redact => {
200 tool_result.content =
201 "[Tool result redacted during compaction]".to_string();
202 redacted += 1;
203 tracing::debug!(
204 tool_use_id = %tool_result.tool_use_id,
205 "Tool result redacted"
206 );
207 }
208 }
209 }
210 }
211
212 (summarized, redacted)
213 }
214}
215
216impl Compactor for ThresholdCompactor {
217 fn should_compact(&self, context_used: i64, context_limit: i32) -> bool {
218 if context_limit == 0 {
219 return false;
220 }
221
222 let utilization = context_used as f64 / context_limit as f64;
223 let should_compact = utilization > self.threshold;
224
225 if should_compact {
226 tracing::info!(
227 utilization = utilization,
228 threshold = self.threshold,
229 context_used,
230 context_limit,
231 "Compaction triggered - context utilization exceeded threshold"
232 );
233 }
234
235 should_compact
236 }
237
238 fn compact(
239 &self,
240 conversation: &mut Vec<Message>,
241 compact_summaries: &HashMap<String, String>,
242 ) -> CompactionResult {
243 if conversation.is_empty() {
244 return CompactionResult::default();
245 }
246
247 let turn_ids = self.unique_turn_ids(conversation);
249
250 if turn_ids.len() <= self.keep_recent_turns {
252 tracing::debug!(
253 total_turns = turn_ids.len(),
254 keep_recent = self.keep_recent_turns,
255 "Skipping compaction - not enough turns"
256 );
257 return CompactionResult::default();
258 }
259
260 let start_idx = turn_ids.len() - self.keep_recent_turns;
262 let turns_to_keep: HashSet<_> = turn_ids[start_idx..].iter().cloned().collect();
263 let turns_compacted = start_idx;
264
265 tracing::info!(
266 total_turns = turn_ids.len(),
267 keep_recent = self.keep_recent_turns,
268 compacting_turns = turns_compacted,
269 tool_compaction_strategy = %self.tool_compaction,
270 "Starting conversation compaction"
271 );
272
273 let mut total_summarized = 0;
275 let mut total_redacted = 0;
276
277 for msg in conversation.iter_mut() {
278 let turn_id = msg.turn_id();
279 if turns_to_keep.contains(turn_id) {
280 continue; }
282
283 let (summarized, redacted) = self.compact_message(msg, compact_summaries);
284 total_summarized += summarized;
285 total_redacted += redacted;
286 }
287
288 tracing::info!(
289 tool_results_summarized = total_summarized,
290 tool_results_redacted = total_redacted,
291 turns_compacted,
292 "Conversation compaction completed"
293 );
294
295 CompactionResult {
296 tool_results_summarized: total_summarized,
297 tool_results_redacted: total_redacted,
298 turns_compacted,
299 }
300 }
301}
302
/// Default system prompt used when asking the LLM to summarize older turns.
pub const DEFAULT_SUMMARY_SYSTEM_PROMPT: &str = r#"You are a conversation summarizer. Your task is to create a concise summary of the conversation history provided.

Guidelines:
- Capture the key topics discussed, decisions made, and important context
- Preserve any technical details, file paths, code snippets, or specific values that would be needed to continue the conversation
- Include the user's original goals and any progress made toward them
- Note any pending tasks or unresolved questions
- Keep the summary focused and actionable
- Format the summary as a narrative that provides context for continuing the conversation

Respond with only the summary, no additional commentary."#;

/// Default upper bound on summary length, in tokens.
pub const DEFAULT_MAX_SUMMARY_TOKENS: i64 = 2048;

/// Default time allowed for the summarization request.
pub const DEFAULT_SUMMARY_TIMEOUT: Duration = Duration::from_secs(60);

/// Configuration for the LLM-based compactor.
///
/// Optional fields fall back to the `DEFAULT_*` constants through their
/// accessor methods.
#[derive(Debug, Clone)]
pub struct LLMCompactorConfig {
    /// Context utilization fraction, strictly between 0 and 1, above which
    /// compaction runs.
    pub threshold: f64,

    /// Number of most recent turns kept verbatim.
    pub keep_recent_turns: usize,

    /// Overrides [`DEFAULT_SUMMARY_SYSTEM_PROMPT`].
    pub summary_system_prompt: Option<String>,

    /// Overrides [`DEFAULT_MAX_SUMMARY_TOKENS`]. Must be in `1..=u32::MAX`,
    /// because the request option is a `u32`.
    pub max_summary_tokens: Option<i64>,

    /// Overrides [`DEFAULT_SUMMARY_TIMEOUT`]. Must be non-zero.
    pub summary_timeout: Option<Duration>,
}

impl LLMCompactorConfig {
    /// Creates a config with the given threshold and window, using defaults for
    /// everything else.
    pub fn new(threshold: f64, keep_recent_turns: usize) -> Self {
        Self {
            threshold,
            keep_recent_turns,
            summary_system_prompt: None,
            max_summary_tokens: None,
            summary_timeout: None,
        }
    }

    /// Checks the config for values that would make compaction misbehave.
    ///
    /// # Errors
    ///
    /// Returns an error message if any of the following hold:
    /// - `threshold` is not strictly between 0 and 1 (NaN included).
    /// - `max_summary_tokens` is set but outside `1..=u32::MAX`.
    /// - `summary_timeout` is set to zero.
    pub fn validate(&self) -> Result<(), String> {
        // Negated in-range check so that NaN is rejected; `<=`/`>=` against NaN are both false.
        if !(self.threshold > 0.0 && self.threshold < 1.0) {
            return Err(format!(
                "threshold must be between 0 and 1 (exclusive), got {}",
                self.threshold
            ));
        }

        if let Some(tokens) = self.max_summary_tokens {
            if tokens <= 0 || tokens > i64::from(u32::MAX) {
                return Err(format!(
                    "max_summary_tokens must be between 1 and {}, got {}",
                    u32::MAX,
                    tokens
                ));
            }
        }

        if let Some(timeout) = self.summary_timeout {
            if timeout.is_zero() {
                return Err("summary_timeout must be greater than zero".to_string());
            }
        }

        Ok(())
    }

    /// The configured system prompt, or [`DEFAULT_SUMMARY_SYSTEM_PROMPT`].
    pub fn system_prompt(&self) -> &str {
        self.summary_system_prompt
            .as_deref()
            .unwrap_or(DEFAULT_SUMMARY_SYSTEM_PROMPT)
    }

    /// The configured summary token cap, or [`DEFAULT_MAX_SUMMARY_TOKENS`].
    pub fn max_tokens(&self) -> i64 {
        self.max_summary_tokens.unwrap_or(DEFAULT_MAX_SUMMARY_TOKENS)
    }

    /// The configured request timeout, or [`DEFAULT_SUMMARY_TIMEOUT`].
    pub fn timeout(&self) -> Duration {
        self.summary_timeout.unwrap_or(DEFAULT_SUMMARY_TIMEOUT)
    }
}

impl Default for LLMCompactorConfig {
    /// Threshold 0.75, keeping the 5 most recent turns.
    fn default() -> Self {
        Self::new(0.75, 5)
    }
}
391
392pub struct LLMCompactor {
395 client: LLMClient,
397
398 config: LLMCompactorConfig,
400}
401
402impl LLMCompactor {
403 pub fn new(client: LLMClient, config: LLMCompactorConfig) -> Result<Self, String> {
412 config.validate()?;
413
414 tracing::info!(
415 threshold = config.threshold,
416 keep_recent_turns = config.keep_recent_turns,
417 max_summary_tokens = config.max_tokens(),
418 "LLM compactor created"
419 );
420
421 Ok(Self { client, config })
422 }
423
424 fn unique_turn_ids(&self, conversation: &[Message]) -> Vec<TurnId> {
426 let mut seen = HashSet::new();
427 let mut ids = Vec::new();
428
429 for msg in conversation {
430 let turn_id = msg.turn_id();
431 if seen.insert(turn_id.clone()) {
432 ids.push(turn_id.clone());
433 }
434 }
435
436 ids
437 }
438
439 fn format_messages_for_summary(&self, messages: &[Message]) -> String {
441 let mut builder = String::new();
442
443 for msg in messages {
444 if msg.is_user() {
446 builder.push_str("User: ");
447 } else {
448 builder.push_str("Assistant: ");
449 }
450
451 for block in msg.content() {
453 match block {
454 ContentBlock::Text(text) => {
455 builder.push_str(&text.text);
456 }
457 ContentBlock::ToolUse(tool_use) => {
458 builder.push_str(&format!(
459 "[Called tool: {} with input: {:?}]",
460 tool_use.name, tool_use.input
461 ));
462 }
463 ContentBlock::ToolResult(tool_result) => {
464 let content = truncate_content(&tool_result.content, 1000);
465 if tool_result.is_error {
466 builder.push_str(&format!("[Tool error: {}]", content));
467 } else {
468 builder.push_str(&format!("[Tool result: {}]", content));
469 }
470 }
471 }
472 }
473 builder.push_str("\n\n");
474 }
475
476 builder
477 }
478
479 fn create_summary_message(&self, summary: &str, session_id: &str) -> Message {
481 let turn_id = TurnId::new_user_turn(0);
483
484 let now = std::time::SystemTime::now()
486 .duration_since(std::time::UNIX_EPOCH)
487 .unwrap_or_default();
488
489 Message::User(UserMessage {
490 id: format!("summary-{}", now.as_nanos()),
491 session_id: session_id.to_string(),
492 turn_id,
493 created_at: now.as_secs() as i64,
494 content: vec![ContentBlock::Text(TextBlock {
495 text: format!("[Previous conversation summary]:\n\n{}", summary),
496 })],
497 })
498 }
499
500 fn get_session_id(&self, conversation: &[Message]) -> String {
502 conversation
503 .first()
504 .map(|m| m.session_id().to_string())
505 .unwrap_or_default()
506 }
507}
508
509impl Compactor for LLMCompactor {
510 fn should_compact(&self, context_used: i64, context_limit: i32) -> bool {
511 if context_limit == 0 {
512 return false;
513 }
514
515 let utilization = context_used as f64 / context_limit as f64;
516 let should_compact = utilization > self.config.threshold;
517
518 if should_compact {
519 tracing::info!(
520 utilization = utilization,
521 threshold = self.config.threshold,
522 context_used,
523 context_limit,
524 "LLM compaction triggered - context utilization exceeded threshold"
525 );
526 }
527
528 should_compact
529 }
530
531 fn compact(
532 &self,
533 _conversation: &mut Vec<Message>,
534 _compact_summaries: &HashMap<String, String>,
535 ) -> CompactionResult {
536 tracing::warn!("LLMCompactor::compact() called - use compact_async() instead");
538 CompactionResult::default()
539 }
540
541 fn is_async(&self) -> bool {
542 true
543 }
544}
545
546impl AsyncCompactor for LLMCompactor {
547 fn compact_async<'a>(
548 &'a self,
549 conversation: Vec<Message>,
550 _compact_summaries: &'a HashMap<String, String>,
551 ) -> Pin<Box<dyn Future<Output = Result<(Vec<Message>, CompactionResult), CompactionError>> + Send + 'a>>
552 {
553 Box::pin(async move {
554 if conversation.is_empty() {
555 return Ok((conversation, CompactionResult::default()));
556 }
557
558 let turn_ids = self.unique_turn_ids(&conversation);
560
561 if turn_ids.len() <= self.config.keep_recent_turns {
563 tracing::debug!(
564 total_turns = turn_ids.len(),
565 keep_recent = self.config.keep_recent_turns,
566 "Skipping LLM compaction - not enough turns"
567 );
568 return Ok((conversation, CompactionResult::default()));
569 }
570
571 let start_idx = turn_ids.len() - self.config.keep_recent_turns;
573 let turns_to_keep: HashSet<_> = turn_ids[start_idx..].iter().cloned().collect();
574
575 let mut old_messages = Vec::new();
576 let mut recent_messages = Vec::new();
577
578 for msg in conversation {
579 if turns_to_keep.contains(msg.turn_id()) {
580 recent_messages.push(msg);
581 } else {
582 old_messages.push(msg);
583 }
584 }
585
586 if old_messages.is_empty() {
587 return Ok((recent_messages, CompactionResult::default()));
588 }
589
590 let session_id = self.get_session_id(&old_messages);
591
592 let formatted_conversation = self.format_messages_for_summary(&old_messages);
594
595 tracing::info!(
596 total_turns = turn_ids.len(),
597 turns_to_summarize = start_idx,
598 turns_to_keep = self.config.keep_recent_turns,
599 messages_to_summarize = old_messages.len(),
600 formatted_length = formatted_conversation.len(),
601 "Starting LLM conversation compaction"
602 );
603
604 let options = MessageOptions {
606 max_tokens: Some(self.config.max_tokens() as u32),
607 ..Default::default()
608 };
609
610 let llm_messages = vec![
612 LLMMessage::system(self.config.system_prompt()),
613 LLMMessage::user(formatted_conversation),
614 ];
615
616 let result = tokio::time::timeout(
618 self.config.timeout(),
619 self.client.send_message(&llm_messages, &options),
620 )
621 .await;
622
623 let response = match result {
624 Ok(Ok(msg)) => msg,
625 Ok(Err(e)) => {
626 tracing::error!(error = %e, "LLM compaction failed");
627 return Err(CompactionError::LLMError(e.to_string()));
628 }
629 Err(_) => {
630 tracing::error!("LLM compaction timed out");
631 return Err(CompactionError::Timeout);
632 }
633 };
634
635 let summary_text = response
637 .content
638 .iter()
639 .filter_map(|c| {
640 if let crate::client::models::Content::Text(t) = c {
641 Some(t.as_str())
642 } else {
643 None
644 }
645 })
646 .collect::<Vec<_>>()
647 .join("");
648
649 let summary_message = self.create_summary_message(&summary_text, &session_id);
651
652 let mut new_conversation = Vec::with_capacity(1 + recent_messages.len());
654 new_conversation.push(summary_message);
655 new_conversation.extend(recent_messages);
656
657 let result = CompactionResult {
658 tool_results_summarized: 0,
659 tool_results_redacted: 0,
660 turns_compacted: start_idx,
661 };
662
663 tracing::info!(
664 original_messages = old_messages.len() + result.turns_compacted,
665 new_messages = new_conversation.len(),
666 summary_length = summary_text.len(),
667 turns_compacted = result.turns_compacted,
668 "LLM compaction completed"
669 );
670
671 Ok((new_conversation, result))
672 })
673 }
674}
675
/// Shortens `content` to roughly `max_len` bytes, marking truncation with `...`.
///
/// Content of at most `max_len` bytes is returned unchanged. Otherwise the kept
/// prefix is `max_len - 3` bytes, reduced further if needed to land on a UTF-8
/// character boundary, and `...` is appended. Multi-byte text therefore never
/// panics. When `max_len < 3` the result is just `"..."`, which is longer than
/// `max_len`.
fn truncate_content(content: &str, max_len: usize) -> String {
    if content.len() <= max_len {
        return content.to_string();
    }

    // Reserve room for the marker, then back off to a char boundary, because
    // slicing a `str` mid-character panics. Index 0 is always a boundary, so the
    // loop terminates.
    let mut cut = max_len.saturating_sub(3);
    while !content.is_char_boundary(cut) {
        cut -= 1;
    }

    format!("{}...", &content[..cut])
}
684
// Unit tests for ThresholdCompactor, LLMCompactorConfig, and truncate_content.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::types::{UserMessage, AssistantMessage};

    // Builds a user message with a single "test" text block for `turn_id`.
    fn make_user_message(turn_id: TurnId) -> Message {
        Message::User(UserMessage {
            id: format!("msg_{}", turn_id),
            session_id: "test_session".to_string(),
            turn_id,
            created_at: 0,
            content: vec![ContentBlock::Text(crate::controller::types::TextBlock {
                text: "test".to_string(),
            })],
        })
    }

    // Builds an assistant message with a single "test" text block and zeroed
    // usage/metadata fields for `turn_id`.
    fn make_assistant_message(turn_id: TurnId) -> Message {
        Message::Assistant(AssistantMessage {
            id: format!("msg_{}", turn_id),
            session_id: "test_session".to_string(),
            turn_id,
            parent_id: String::new(),
            created_at: 0,
            completed_at: None,
            model_id: "test_model".to_string(),
            provider_id: "test_provider".to_string(),
            input_tokens: 0,
            output_tokens: 0,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
            finish_reason: None,
            error: None,
            content: vec![ContentBlock::Text(crate::controller::types::TextBlock {
                text: "test".to_string(),
            })],
        })
    }

    // Builds a user message whose only block is a (non-error) tool result.
    fn make_tool_result_message(tool_use_id: &str, content: &str, turn_id: TurnId) -> Message {
        Message::User(UserMessage {
            id: format!("msg_{}", turn_id),
            session_id: "test_session".to_string(),
            turn_id,
            created_at: 0,
            content: vec![ContentBlock::ToolResult(crate::controller::types::ToolResultBlock {
                tool_use_id: tool_use_id.to_string(),
                content: content.to_string(),
                is_error: false,
                compact_summary: None,
            })],
        })
    }

    // The threshold must be strictly inside (0, 1).
    #[test]
    fn test_threshold_compactor_creation() {
        let compactor = ThresholdCompactor::new(0.75, 3, ToolCompaction::Redact);
        assert!(compactor.is_ok());

        let compactor = ThresholdCompactor::new(0.0, 3, ToolCompaction::Redact);
        assert!(compactor.is_err());

        let compactor = ThresholdCompactor::new(1.0, 3, ToolCompaction::Redact);
        assert!(compactor.is_err());
    }

    #[test]
    fn test_should_compact() {
        let compactor = ThresholdCompactor::new(0.75, 3, ToolCompaction::Redact).unwrap();

        // 70% utilization is below the 75% threshold.
        assert!(!compactor.should_compact(7000, 10000));

        // 80% utilization is above it.
        assert!(compactor.should_compact(8000, 10000));

        // A zero context limit never triggers compaction.
        assert!(!compactor.should_compact(8000, 0));
    }

    // With fewer distinct turns than `keep_recent_turns`, nothing is compacted.
    #[test]
    fn test_compact_not_enough_turns() {
        let compactor = ThresholdCompactor::new(0.75, 3, ToolCompaction::Redact).unwrap();

        let mut conversation = vec![
            make_user_message(TurnId::new_user_turn(1)),
            make_assistant_message(TurnId::new_assistant_turn(1)),
        ];

        let summaries = std::collections::HashMap::new();
        let result = compactor.compact(&mut conversation, &summaries);

        assert_eq!(result.turns_compacted, 0);
    }

    // Four distinct turns, keep 2: the tool result in the oldest turn is redacted,
    // and the one in a recent turn is untouched. This assumes user and assistant
    // TurnIds with the same number are distinct, as the expected count of 2 implies.
    #[test]
    fn test_compact_redacts_old_tool_results() {
        let compactor = ThresholdCompactor::new(0.75, 2, ToolCompaction::Redact).unwrap();

        let mut conversation = vec![
            make_tool_result_message("tool_1", "old result", TurnId::new_user_turn(1)),
            make_assistant_message(TurnId::new_assistant_turn(1)),
            make_tool_result_message("tool_2", "new result", TurnId::new_user_turn(2)),
            make_assistant_message(TurnId::new_assistant_turn(2)),
        ];

        let summaries = std::collections::HashMap::new();
        let result = compactor.compact(&mut conversation, &summaries);

        assert_eq!(result.tool_results_redacted, 1);
        assert_eq!(result.turns_compacted, 2);

        // Index 0 is the old tool result; index 2 is the recent one.
        if let ContentBlock::ToolResult(tr) = &conversation[0].content()[0] {
            assert!(tr.content.contains("redacted"));
        } else {
            panic!("Expected ToolResult");
        }

        if let ContentBlock::ToolResult(tr) = &conversation[2].content()[0] {
            assert_eq!(tr.content, "new result");
        } else {
            panic!("Expected ToolResult");
        }
    }

    // Under Summarize, an old tool result with a matching summary entry is replaced by it.
    #[test]
    fn test_compact_summarizes_with_summary() {
        let compactor = ThresholdCompactor::new(0.75, 2, ToolCompaction::Summarize).unwrap();

        let mut conversation = vec![
            make_tool_result_message("tool_1", "very long result", TurnId::new_user_turn(1)),
            make_assistant_message(TurnId::new_assistant_turn(1)),
            make_user_message(TurnId::new_user_turn(2)),
            make_assistant_message(TurnId::new_assistant_turn(2)),
        ];

        let mut summaries = std::collections::HashMap::new();
        summaries.insert("tool_1".to_string(), "[summary]".to_string());

        let result = compactor.compact(&mut conversation, &summaries);

        assert_eq!(result.tool_results_summarized, 1);

        if let ContentBlock::ToolResult(tr) = &conversation[0].content()[0] {
            assert_eq!(tr.content, "[summary]");
        } else {
            panic!("Expected ToolResult");
        }
    }

    // `new` sets the two required fields and leaves every override unset.
    #[test]
    fn test_llm_compactor_config_creation() {
        let config = LLMCompactorConfig::new(0.75, 5);
        assert_eq!(config.threshold, 0.75);
        assert_eq!(config.keep_recent_turns, 5);
        assert!(config.summary_system_prompt.is_none());
        assert!(config.max_summary_tokens.is_none());
        assert!(config.summary_timeout.is_none());
    }

    // Thresholds at the boundaries are rejected; values just inside are accepted.
    #[test]
    fn test_llm_compactor_config_validation() {
        let config = LLMCompactorConfig::new(0.75, 5);
        assert!(config.validate().is_ok());

        let config = LLMCompactorConfig::new(0.0, 5);
        assert!(config.validate().is_err());

        let config = LLMCompactorConfig::new(1.0, 5);
        assert!(config.validate().is_err());

        let config = LLMCompactorConfig::new(0.01, 5);
        assert!(config.validate().is_ok());

        let config = LLMCompactorConfig::new(0.99, 5);
        assert!(config.validate().is_ok());
    }

    // Accessors fall back to the DEFAULT_* constants when no override is set.
    #[test]
    fn test_llm_compactor_config_defaults() {
        let config = LLMCompactorConfig::default();
        assert_eq!(config.threshold, 0.75);
        assert_eq!(config.keep_recent_turns, 5);

        assert_eq!(config.system_prompt(), DEFAULT_SUMMARY_SYSTEM_PROMPT);
        assert_eq!(config.max_tokens(), DEFAULT_MAX_SUMMARY_TOKENS);
        assert_eq!(config.timeout(), DEFAULT_SUMMARY_TIMEOUT);
    }

    // Accessors return overrides when they are set.
    #[test]
    fn test_llm_compactor_config_custom_values() {
        let config = LLMCompactorConfig {
            threshold: 0.8,
            keep_recent_turns: 3,
            summary_system_prompt: Some("Custom prompt".to_string()),
            max_summary_tokens: Some(4096),
            summary_timeout: Some(Duration::from_secs(120)),
        };

        assert_eq!(config.system_prompt(), "Custom prompt");
        assert_eq!(config.max_tokens(), 4096);
        assert_eq!(config.timeout(), Duration::from_secs(120));
    }

    #[test]
    fn test_truncate_content() {
        // Shorter than the limit: unchanged.
        assert_eq!(truncate_content("hello", 10), "hello");

        // Exactly at the limit: unchanged.
        assert_eq!(truncate_content("hello", 5), "hello");

        // Over the limit: keep max_len - 3 bytes plus "...".
        assert_eq!(truncate_content("hello world", 8), "hello...");

        // A limit of 3 leaves room only for the marker.
        assert_eq!(truncate_content("hello", 3), "...");
    }
}