1use std::fmt::Write;
2use std::time::Duration;
3
4use anyhow::Result;
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7
8use std::sync::Arc;
9
10use crate::memory::traits::Memory;
11use crate::providers::traits::{ChatMessage, Provider};
12
// Serde default providers for `ContextCompressionConfig`. Kept as functions
// (not consts) because `#[serde(default = "...")]` requires a callable path.
fn default_enabled() -> bool { true }
fn default_threshold_ratio() -> f64 { 0.50 }
fn default_protect_first_n() -> usize { 3 }
fn default_protect_last_n() -> usize { 4 }
fn default_max_passes() -> u32 { 3 }
fn default_summary_max_chars() -> usize { 4_000 }
fn default_source_max_chars() -> usize { 50_000 }
fn default_timeout_secs() -> u64 { 60 }
fn default_identifier_policy() -> String { "strict".to_string() }
fn default_tool_result_retrim_chars() -> usize { 2_000 }
47
/// Configuration for automatic conversation-history compression.
///
/// All fields have serde defaults, so an empty `{}` deserializes to the same
/// values as [`Default`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ContextCompressionConfig {
    /// Master switch; when false, `compress_if_needed` is a no-op.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// Fraction of the context window at which compression triggers.
    #[serde(default = "default_threshold_ratio")]
    pub threshold_ratio: f64,
    /// Number of leading messages never compressed or trimmed.
    #[serde(default = "default_protect_first_n")]
    pub protect_first_n: usize,
    /// Number of trailing messages never compressed or trimmed.
    #[serde(default = "default_protect_last_n")]
    pub protect_last_n: usize,
    /// Maximum summarization passes per compression attempt.
    #[serde(default = "default_max_passes")]
    pub max_passes: u32,
    /// Character cap applied to the generated summary.
    #[serde(default = "default_summary_max_chars")]
    pub summary_max_chars: usize,
    /// Character cap on the transcript sent to the summarizer.
    #[serde(default = "default_source_max_chars")]
    pub source_max_chars: usize,
    /// Timeout for the summarization LLM call, in seconds.
    #[serde(default = "default_timeout_secs")]
    pub timeout_secs: u64,
    /// Optional dedicated summarization model; falls back to the chat model.
    #[serde(default)]
    pub summary_model: Option<String>,
    /// When set to "strict", the summarizer is told to preserve identifiers verbatim.
    #[serde(default = "default_identifier_policy")]
    pub identifier_policy: String,
    /// Max characters kept when fast-trimming old tool results; 0 disables fast-trim.
    #[serde(default = "default_tool_result_retrim_chars")]
    pub tool_result_retrim_chars: usize,
    /// Substrings marking tool results that must never be fast-trimmed.
    #[serde(default)]
    pub tool_result_trim_exempt: Vec<String>,
}
87
88impl Default for ContextCompressionConfig {
89 fn default() -> Self {
90 Self {
91 enabled: default_enabled(),
92 threshold_ratio: default_threshold_ratio(),
93 protect_first_n: default_protect_first_n(),
94 protect_last_n: default_protect_last_n(),
95 max_passes: default_max_passes(),
96 summary_max_chars: default_summary_max_chars(),
97 source_max_chars: default_source_max_chars(),
98 timeout_secs: default_timeout_secs(),
99 summary_model: None,
100 identifier_policy: default_identifier_policy(),
101 tool_result_retrim_chars: default_tool_result_retrim_chars(),
102 tool_result_trim_exempt: Vec::new(),
103 }
104 }
105}
106
/// Outcome of a compression attempt.
#[derive(Debug, Clone)]
pub struct CompressionResult {
    /// True if the history was modified (fast-trim or at least one summary pass).
    pub compressed: bool,
    /// Estimated token count before compression.
    pub tokens_before: usize,
    /// Estimated token count after compression.
    pub tokens_after: usize,
    /// Number of LLM summarization passes performed (0 if fast-trim sufficed).
    pub passes_used: u32,
}
118
/// Context-window sizes to probe downward through when a provider rejects a
/// request but its error message does not reveal the actual limit.
const PROBE_TIERS: &[usize] = &[
    2_000_000, 1_000_000, 512_000, 200_000, 128_000, 64_000, 32_000,
];

/// Returns the largest probe tier strictly below `current`, bottoming out at
/// the smallest tier (32k) once nothing lower exists.
fn next_probe_tier(current: usize) -> usize {
    for &tier in PROBE_TIERS {
        if tier < current {
            return tier;
        }
    }
    32_000
}
134
135pub fn parse_context_limit_from_error(msg: &str) -> Option<usize> {
141 let re_patterns: &[&str] = &[
144 r"(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})",
146 r"context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})",
148 r"(\d{4,})\s*(?:tokens?\s*)?(?:context|limit)",
150 r"available context size\s*\(\s*(\d{4,})",
152 r">\s*(\d{4,})\s*(?:maximum|max)?\s*(?:context)?\s*(?:length|size|window|tokens?)",
154 ];
155 let lower = msg.to_lowercase();
156 for pattern in re_patterns {
157 if let Ok(re) = regex::Regex::new(pattern) {
158 if let Some(caps) = re.captures(&lower) {
159 if let Some(m) = caps.get(1) {
160 if let Ok(limit) = m.as_str().parse::<usize>() {
161 if (1024..=10_000_000).contains(&limit) {
162 return Some(limit);
163 }
164 }
165 }
166 }
167 }
168 }
169 None
170}
171
172pub fn estimate_tokens(messages: &[ChatMessage]) -> usize {
179 let raw: usize = messages
180 .iter()
181 .map(|m| m.content.len().div_ceil(4) + 4)
182 .sum();
183 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
185 {
186 (raw as f64 * 1.2) as usize
187 }
188}
189
/// System prompt for the summarization LLM call made by `compress_once`.
const SUMMARIZER_SYSTEM: &str = "\
You are a conversation compaction engine. Summarize the conversation segment below into concise context.

PRESERVE exactly:
- All identifiers (UUIDs, hashes, file paths, URLs, tokens, IPs)
- Actions taken (tool calls, file operations, commands run)
- Key information obtained (data, results, error messages)
- Decisions made and user preferences expressed
- Current task status and unresolved items
- Constraints and requirements mentioned

OMIT:
- Verbose tool output (keep only key results)
- Repeated greetings or filler
- Redundant information already stated

Output concise bullet points. Be thorough but brief.";
211
/// Compresses conversation history in place when it approaches the model's
/// context window, via fast tool-result trimming and LLM summarization.
pub struct ContextCompressor {
    config: ContextCompressionConfig,
    /// Assumed context window in tokens; may be adjusted downward on provider errors.
    context_window: usize,
    /// Optional store where summaries are persisted before messages are discarded.
    memory: Option<Arc<dyn Memory>>,
}
221
impl ContextCompressor {
    /// Creates a compressor for a model with the given context window (in tokens).
    pub fn new(config: ContextCompressionConfig, context_window: usize) -> Self {
        Self {
            config,
            context_window,
            memory: None,
        }
    }

    /// Builder-style: attaches a memory backend so summaries are persisted
    /// before the messages they replace are discarded.
    pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
        self.memory = Some(memory);
        self
    }

    /// Overrides the assumed context window (e.g. after a provider error
    /// reveals the real limit).
    pub fn set_context_window(&mut self, window: usize) {
        self.context_window = window;
    }

    /// Cheap first line of defense: re-truncates oversized tool results in the
    /// unprotected middle of `history`. Returns the number of characters saved.
    fn fast_trim_tool_results(&self, history: &mut [ChatMessage]) -> usize {
        let max = self.config.tool_result_retrim_chars;
        // A limit of 0 disables fast-trim entirely.
        if max == 0 {
            return 0;
        }
        let mut saved = 0;
        let protect_start = self.config.protect_first_n.min(history.len());
        let protect_end = history.len().saturating_sub(self.config.protect_last_n);

        // Protected regions overlap: nothing is eligible for trimming.
        if protect_start >= protect_end {
            return 0;
        }

        for msg in &mut history[protect_start..protect_end] {
            if msg.role != "tool" {
                continue;
            }
            if msg.content.len() <= max {
                continue;
            }
            // Skip results containing any configured exemption marker.
            if self
                .config
                .tool_result_trim_exempt
                .iter()
                .any(|t| msg.content.contains(t.as_str()))
            {
                continue;
            }
            // Skip results embedding data-URI images; truncation would corrupt them.
            if msg.content.contains("data:image/") {
                continue;
            }
            let original_len = msg.content.len();
            msg.content = crate::agent::loop_::truncate_tool_result(&msg.content, max);
            saved += original_len - msg.content.len();
        }
        saved
    }

    /// Compresses `history` in place if its estimated token count exceeds
    /// `threshold_ratio * context_window`.
    ///
    /// Strategy: first a cheap fast-trim of old tool results; if still over
    /// threshold, up to `max_passes` LLM summarization passes.
    ///
    /// # Errors
    /// Propagates failures from `compress_once` (summary passes).
    pub async fn compress_if_needed(
        &self,
        history: &mut Vec<ChatMessage>,
        provider: &dyn Provider,
        model: &str,
    ) -> Result<CompressionResult> {
        if !self.config.enabled {
            let tokens = estimate_tokens(history);
            return Ok(CompressionResult {
                compressed: false,
                tokens_before: tokens,
                tokens_after: tokens,
                passes_used: 0,
            });
        }

        let tokens_before = estimate_tokens(history);
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let threshold = (self.context_window as f64 * self.config.threshold_ratio) as usize;

        // Below threshold: nothing to do.
        if tokens_before <= threshold {
            return Ok(CompressionResult {
                compressed: false,
                tokens_before,
                tokens_after: tokens_before,
                passes_used: 0,
            });
        }

        // Fast path: trimming old tool results alone may get us under threshold.
        let chars_saved = self.fast_trim_tool_results(history);
        if chars_saved > 0 {
            tracing::info!(chars_saved, "Fast-trim saved chars from old tool results");
            let recheck = estimate_tokens(history);
            if recheck <= threshold {
                return Ok(CompressionResult {
                    compressed: true,
                    tokens_before,
                    tokens_after: recheck,
                    passes_used: 0,
                });
            }
        }

        // Slow path: summarize the middle of the history, repeatedly if needed.
        let mut passes_used = 0;
        for _ in 0..self.config.max_passes {
            let did_compress = self.compress_once(history, provider, model).await?;
            if did_compress {
                passes_used += 1;
            }
            // Stop once under threshold, or when a pass made no progress.
            if estimate_tokens(history) <= threshold || !did_compress {
                break;
            }
        }

        let tokens_after = estimate_tokens(history);
        Ok(CompressionResult {
            compressed: passes_used > 0,
            tokens_before,
            tokens_after,
            passes_used,
        })
    }

    /// Reacts to a provider context-limit error: shrinks the assumed context
    /// window (parsed from the error message, or the next probe tier down)
    /// and re-runs compression. Returns whether any compression happened.
    pub async fn compress_on_error(
        &mut self,
        history: &mut Vec<ChatMessage>,
        provider: &dyn Provider,
        model: &str,
        error_msg: &str,
    ) -> Result<bool> {
        if let Some(limit) = parse_context_limit_from_error(error_msg) {
            self.context_window = limit;
        } else {
            self.context_window = next_probe_tier(self.context_window);
        }

        tracing::info!(
            context_window = self.context_window,
            "Context limit adjusted, re-compressing"
        );

        let result = self.compress_if_needed(history, provider, model).await?;
        Ok(result.compressed)
    }

    /// One summarization pass: replaces the unprotected middle of `history`
    /// with a single assistant message containing an LLM-generated summary.
    ///
    /// Returns `Ok(false)` when there is nothing eligible to compress; on LLM
    /// failure or timeout it falls back to a truncated transcript rather than
    /// erroring out.
    async fn compress_once(
        &self,
        history: &mut Vec<ChatMessage>,
        provider: &dyn Provider,
        model: &str,
    ) -> Result<bool> {
        let n = history.len();
        let protected_total = self.config.protect_first_n + self.config.protect_last_n;
        if n <= protected_total {
            return Ok(false);
        }

        let mut start = self.config.protect_first_n.min(n);
        let mut end = n.saturating_sub(self.config.protect_last_n);

        // Nudge the boundaries so tool results are not split from their callers.
        start = align_boundary_forward(history, start);
        end = align_boundary_backward(history, end);

        if start >= end {
            return Ok(false);
        }

        let middle = &history[start..end];
        let transcript = build_transcript(middle, self.config.source_max_chars);

        if transcript.is_empty() {
            return Ok(false);
        }

        let message_count = end - start;
        let summary_model = self.config.summary_model.as_deref().unwrap_or(model);

        let identifier_note = if self.config.identifier_policy == "strict" {
            "\nIMPORTANT: Preserve all identifiers exactly as they appear."
        } else {
            ""
        };

        let user_prompt = format!(
            "Summarize the following conversation history ({message_count} messages) for context preservation. \
            Keep it concise (max 20 bullet points).{identifier_note}\n\n{transcript}"
        );

        // Degrade gracefully: on LLM error or timeout, fall back to a
        // truncated raw transcript instead of failing the whole pass.
        let timeout = Duration::from_secs(self.config.timeout_secs);
        let summary_raw = match tokio::time::timeout(
            timeout,
            provider.chat_with_system(Some(SUMMARIZER_SYSTEM), &user_prompt, summary_model, 0.1),
        )
        .await
        {
            Ok(Ok(s)) => s,
            Ok(Err(e)) => {
                tracing::warn!(error = %e, "Summarization LLM call failed, using transcript truncation");
                truncate_chars(&transcript, self.config.summary_max_chars)
            }
            Err(_) => {
                tracing::warn!(
                    "Summarization timed out after {}s, using transcript truncation",
                    self.config.timeout_secs
                );
                truncate_chars(&transcript, self.config.summary_max_chars)
            }
        };

        let summary = truncate_chars(&summary_raw, self.config.summary_max_chars);

        // Best-effort: persist the summary before the originals are dropped.
        if let Some(ref memory) = self.memory {
            let facts_key = format!("compressed_context_{}", uuid::Uuid::new_v4());
            if let Err(e) = memory
                .store(
                    &facts_key,
                    &summary,
                    crate::memory::traits::MemoryCategory::Daily,
                    None,
                )
                .await
            {
                tracing::debug!("Failed to save compression summary to memory: {e}");
            } else {
                tracing::debug!(
                    "Saved compression summary to memory before discarding {message_count} messages"
                );
            }
        }

        let summary_msg = ChatMessage::assistant(format!(
            "[CONTEXT SUMMARY \u{2014} {message_count} earlier messages compressed]\n\n{summary}"
        ));
        history.splice(start..end, std::iter::once(summary_msg));

        // The splice may have orphaned tool results at the seam; clean them up.
        repair_tool_pairs(history);

        Ok(true)
    }
}
479
480fn align_boundary_forward(messages: &[ChatMessage], idx: usize) -> usize {
486 let mut i = idx;
487 while i < messages.len() && messages[i].role == "tool" {
488 i += 1;
489 }
490 i
491}
492
493fn align_boundary_backward(messages: &[ChatMessage], idx: usize) -> usize {
496 let mut i = idx;
497 while i > 0 && i < messages.len() && messages[i].role == "tool" {
500 i -= 1;
502 }
503 i
504}
505
506fn repair_tool_pairs(messages: &mut Vec<ChatMessage>) {
517 let mut i = 0;
523 while i < messages.len() {
524 if messages[i].content.contains("[CONTEXT SUMMARY") {
525 while i + 1 < messages.len() && messages[i + 1].role == "tool" {
527 messages.remove(i + 1);
528 }
529 }
530 i += 1;
531 }
532
533 let start = if messages.first().map_or(false, |m| m.role == "system") {
536 1
537 } else {
538 0
539 };
540 while start < messages.len() && messages[start].role == "tool" {
541 messages.remove(start);
542 }
543}
544
545fn build_transcript(messages: &[ChatMessage], max_chars: usize) -> String {
550 let mut transcript = String::new();
551 for msg in messages {
552 let role = msg.role.to_uppercase();
553 let _ = writeln!(transcript, "{role}: {}", msg.content.trim());
554 }
555
556 if transcript.len() > max_chars {
557 truncate_chars(&transcript, max_chars)
558 } else {
559 transcript
560 }
561}
562
/// Truncates `s` to at most `max` bytes (backing off to the nearest UTF-8
/// char boundary) and appends "..." when anything was cut.
fn truncate_chars(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // Never slice inside a multi-byte character.
    let cut = (0..=max).rev().find(|&i| s.is_char_boundary(i)).unwrap_or(0);
    format!("{}...", &s[..cut])
}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ChatMessage` with the given role and content.
    fn msg(role: &str, content: &str) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: content.to_string(),
        }
    }

    #[test]
    fn test_estimate_tokens() {
        let messages = vec![msg("user", "hello world")];
        let tokens = estimate_tokens(&messages);
        assert!(tokens > 0);
    }

    #[test]
    fn test_estimate_tokens_empty() {
        assert_eq!(estimate_tokens(&[]), 0);
    }

    #[test]
    fn test_parse_context_limit_anthropic() {
        let err = "prompt is too long: 150000 tokens > 128000 maximum context length";
        assert_eq!(parse_context_limit_from_error(err), Some(128_000));
    }

    #[test]
    fn test_parse_context_limit_openai() {
        let err = "This model's maximum context length is 128000 tokens. However, your messages resulted in 150000 tokens.";
        assert_eq!(parse_context_limit_from_error(err), Some(128_000));
    }

    #[test]
    fn test_parse_context_limit_llamacpp() {
        let err = "request (8968 tokens) exceeds the available context size (8448 tokens)";
        assert_eq!(parse_context_limit_from_error(err), Some(8448));
    }

    #[test]
    fn test_parse_context_limit_none() {
        assert_eq!(parse_context_limit_from_error("some random error"), None);
    }

    #[test]
    fn test_parse_context_limit_rejects_small() {
        // Numbers below the 1024-token sanity floor are ignored.
        assert_eq!(parse_context_limit_from_error("limit is 100 tokens"), None);
    }

    #[test]
    fn test_next_probe_tier() {
        assert_eq!(next_probe_tier(2_000_001), 2_000_000);
        assert_eq!(next_probe_tier(2_000_000), 1_000_000);
        assert_eq!(next_probe_tier(200_000), 128_000);
        assert_eq!(next_probe_tier(64_000), 32_000);
        assert_eq!(next_probe_tier(32_000), 32_000);
        assert_eq!(next_probe_tier(10_000), 32_000);
    }

    #[test]
    fn test_align_boundary_forward_skips_tool() {
        let messages = vec![
            msg("system", "sys"),
            msg("user", "q"),
            msg("tool", "result1"),
            msg("tool", "result2"),
            msg("user", "next"),
        ];
        assert_eq!(align_boundary_forward(&messages, 2), 4);
    }

    #[test]
    fn test_align_boundary_forward_noop() {
        let messages = vec![
            msg("system", "sys"),
            msg("user", "q"),
            msg("assistant", "a"),
        ];
        assert_eq!(align_boundary_forward(&messages, 1), 1);
    }

    #[test]
    fn test_repair_tool_pairs_removes_orphaned() {
        let mut messages = vec![
            msg("system", "sys"),
            msg(
                "assistant",
                "[CONTEXT SUMMARY — 5 earlier messages compressed]\nstuff",
            ),
            msg("tool", "orphaned result"),
            msg("user", "next question"),
        ];
        repair_tool_pairs(&mut messages);
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[2].role, "user");
    }

    #[test]
    fn test_repair_tool_pairs_no_false_positives() {
        let mut messages = vec![
            msg("system", "sys"),
            msg("user", "q"),
            msg("assistant", "calling tool"),
            msg("tool", "result"),
            msg("user", "thanks"),
        ];
        repair_tool_pairs(&mut messages);
        assert_eq!(messages.len(), 5);
    }

    #[test]
    fn test_build_transcript() {
        let messages = vec![msg("user", "hello"), msg("assistant", "hi there")];
        let transcript = build_transcript(&messages, 10_000);
        assert!(transcript.contains("USER: hello"));
        assert!(transcript.contains("ASSISTANT: hi there"));
    }

    #[test]
    fn test_build_transcript_truncates() {
        let messages = vec![msg("user", &"x".repeat(1000))];
        let transcript = build_transcript(&messages, 100);
        // 100 bytes plus the "..." suffix.
        assert!(transcript.len() <= 103);
    }

    #[test]
    fn test_truncate_chars() {
        assert_eq!(truncate_chars("hello world", 5), "hello...");
        assert_eq!(truncate_chars("hi", 10), "hi");
    }

    #[test]
    fn test_config_defaults() {
        let config = ContextCompressionConfig::default();
        assert!(config.enabled);
        assert!((config.threshold_ratio - 0.50).abs() < f64::EPSILON);
        assert_eq!(config.protect_first_n, 3);
        assert_eq!(config.protect_last_n, 4);
        assert_eq!(config.max_passes, 3);
        assert_eq!(config.summary_max_chars, 4_000);
        assert_eq!(config.source_max_chars, 50_000);
        assert_eq!(config.timeout_secs, 60);
        assert!(config.summary_model.is_none());
        assert_eq!(config.identifier_policy, "strict");
    }

    #[test]
    fn test_config_serde_defaults() {
        let config: ContextCompressionConfig = serde_json::from_str("{}").unwrap();
        assert!(config.enabled);
        assert_eq!(config.protect_first_n, 3);
        assert_eq!(config.max_passes, 3);
    }

    #[test]
    fn test_config_serde_override() {
        let json = r#"{"enabled": false, "protect_first_n": 5, "max_passes": 1}"#;
        let config: ContextCompressionConfig = serde_json::from_str(json).unwrap();
        assert!(!config.enabled);
        assert_eq!(config.protect_first_n, 5);
        assert_eq!(config.max_passes, 1);
    }

    #[test]
    fn test_fast_trim_protects_first_and_last_n() {
        let config = ContextCompressionConfig {
            protect_first_n: 2,
            protect_last_n: 2,
            tool_result_retrim_chars: 100,
            ..Default::default()
        };
        let compressor = ContextCompressor::new(config, 128_000);
        let big = "x".repeat(5_000);
        let mut history = vec![
            msg("system", "sys"),
            msg("tool", &big),
            msg("user", "q"),
            msg("tool", &big),
            msg("user", "next"),
            msg("tool", &big),
        ];
        let saved = compressor.fast_trim_tool_results(&mut history);
        assert!(saved > 0);
        // Protected head and tail are untouched; only the middle tool result shrinks.
        assert_eq!(history[1].content.len(), 5_000);
        assert_eq!(history[5].content.len(), 5_000);
        assert!(history[3].content.len() <= 200);
    }

    #[test]
    fn test_fast_trim_skips_images() {
        let config = ContextCompressionConfig {
            protect_first_n: 0,
            protect_last_n: 0,
            tool_result_retrim_chars: 100,
            ..Default::default()
        };
        let compressor = ContextCompressor::new(config, 128_000);
        let img = format!("data:image/{}", "x".repeat(5_000));
        let mut history = vec![msg("tool", &img)];
        let saved = compressor.fast_trim_tool_results(&mut history);
        assert_eq!(saved, 0);
        assert!(history[0].content.len() > 5_000);
    }

    #[test]
    fn test_fast_trim_skips_exempt_tools() {
        let config = ContextCompressionConfig {
            protect_first_n: 0,
            protect_last_n: 0,
            tool_result_retrim_chars: 100,
            tool_result_trim_exempt: vec!["KEEPME".to_string()],
            ..Default::default()
        };
        let compressor = ContextCompressor::new(config, 128_000);
        let content = format!("KEEPME {}", "x".repeat(5_000));
        let mut history = vec![msg("tool", &content)];
        assert_eq!(compressor.fast_trim_tool_results(&mut history), 0);
    }

    #[test]
    fn test_fast_trim_skips_small_results() {
        let config = ContextCompressionConfig {
            protect_first_n: 0,
            protect_last_n: 0,
            tool_result_retrim_chars: 2_000,
            ..Default::default()
        };
        let compressor = ContextCompressor::new(config, 128_000);
        let mut history = vec![msg("tool", "small result")];
        assert_eq!(compressor.fast_trim_tool_results(&mut history), 0);
    }

    #[test]
    fn test_fast_trim_skips_non_tool_messages() {
        let config = ContextCompressionConfig {
            protect_first_n: 0,
            protect_last_n: 0,
            tool_result_retrim_chars: 100,
            ..Default::default()
        };
        let compressor = ContextCompressor::new(config, 128_000);
        let big = "x".repeat(5_000);
        let mut history = vec![msg("user", &big), msg("assistant", &big)];
        assert_eq!(compressor.fast_trim_tool_results(&mut history), 0);
    }

    #[test]
    fn test_fast_trim_config_defaults() {
        let config = ContextCompressionConfig::default();
        assert_eq!(config.tool_result_retrim_chars, 2_000);
        assert!(config.tool_result_trim_exempt.is_empty());
    }

    #[test]
    fn test_fast_trim_disabled_when_zero() {
        let config = ContextCompressionConfig {
            protect_first_n: 0,
            protect_last_n: 0,
            tool_result_retrim_chars: 0,
            ..Default::default()
        };
        let compressor = ContextCompressor::new(config, 128_000);
        let big = "x".repeat(5_000);
        let mut history = vec![msg("tool", &big)];
        assert_eq!(compressor.fast_trim_tool_results(&mut history), 0);
    }
}