1use crate::context::pruner::ProgressivePruner;
22use crate::context::token_estimator::TokenEstimator;
23use crate::context::types::{CodeBlock, CompressionConfig, CompressionResult, PruningConfig};
24use crate::conversation::message::{Message, MessageContent};
25use regex::Regex;
26use std::sync::LazyLock;
27
/// Default maximum number of lines kept in a single fenced code block.
pub const DEFAULT_CODE_BLOCK_MAX_LINES: usize = 50;

/// Default size limit for a single tool output.
///
/// Not referenced in this part of the file — presumably the default for
/// `CompressionConfig::tool_output_max_chars` (verify in `context::types`).
pub const DEFAULT_TOOL_OUTPUT_MAX_CHARS: usize = 2000;

/// Default size limit for file content.
///
/// Not referenced in this part of the file — TODO confirm where it is consumed.
pub const DEFAULT_FILE_CONTENT_MAX_CHARS: usize = 1500;

/// Share of a code block's line budget spent on its leading lines when the
/// block is compressed; the rest of the budget goes to its trailing lines.
const HEAD_RATIO: f64 = 0.6;

/// Share of a code block's line budget spent on its trailing lines.
///
/// Derived from [`HEAD_RATIO`] so the two ratios can never disagree.
// Unused on purpose: the tail budget is computed as `max_lines - head_lines`,
// which avoids a second, independent rounding step. Kept to document the split.
#[allow(dead_code)]
const TAIL_RATIO: f64 = 1.0 - HEAD_RATIO;

/// Generic marker for elided content.
// Unused on purpose: the omission markers this module emits embed a line or
// character count, so they are formatted inline instead of using this constant.
#[allow(dead_code)]
const OMISSION_MARKER: &str = "\n... [content omitted] ...\n";
51
52static CODE_BLOCK_REGEX: LazyLock<Regex> =
54 LazyLock::new(|| Regex::new(r"```(\w*)\n([\s\S]*?)```").expect("Invalid code block regex"));
55
56static FILE_PATH_REGEX: LazyLock<Regex> = LazyLock::new(|| {
58 Regex::new(r"(?:^|\s)([./~]?(?:[\w.-]+/)+[\w.-]+\.\w+)").expect("Invalid file path regex")
59});
60
61pub struct MessageCompressor;
67
68impl MessageCompressor {
69 pub fn compress_code_block(code: &str, max_lines: usize) -> String {
97 let lines: Vec<&str> = code.lines().collect();
98 let total_lines = lines.len();
99
100 if total_lines <= max_lines {
102 return code.to_string();
103 }
104
105 let head_lines = ((max_lines as f64) * HEAD_RATIO).ceil() as usize;
107 let tail_lines = max_lines.saturating_sub(head_lines);
108
109 let head_lines = head_lines.min(total_lines);
111 let tail_lines = tail_lines.min(total_lines.saturating_sub(head_lines));
112
113 let head: Vec<&str> = lines.iter().take(head_lines).copied().collect();
115 let tail: Vec<&str> = lines
116 .iter()
117 .skip(total_lines.saturating_sub(tail_lines))
118 .copied()
119 .collect();
120
121 let omitted_count = total_lines - head_lines - tail_lines;
122 let omission_text = format!("\n... [{} lines omitted] ...\n", omitted_count);
123
124 format!("{}{}{}", head.join("\n"), omission_text, tail.join("\n"))
125 }
126
127 pub fn extract_code_blocks(text: &str) -> Vec<CodeBlock> {
141 CODE_BLOCK_REGEX
142 .captures_iter(text)
143 .map(|cap| {
144 let full_match = cap.get(0).unwrap();
145 let language = cap.get(1).map(|m| m.as_str().to_string());
146 let code = cap
147 .get(2)
148 .map(|m| m.as_str().to_string())
149 .unwrap_or_default();
150
151 CodeBlock::new(
152 code,
153 if language.as_ref().map(|l| l.is_empty()).unwrap_or(true) {
154 None
155 } else {
156 language
157 },
158 full_match.start(),
159 full_match.end(),
160 )
161 })
162 .collect()
163 }
164
165 pub fn compress_code_blocks_in_text(text: &str, max_lines: usize) -> String {
178 let mut result = text.to_string();
179 let blocks = Self::extract_code_blocks(text);
180
181 for block in blocks.into_iter().rev() {
183 if block.line_count() > max_lines {
184 let compressed_code = Self::compress_code_block(&block.code, max_lines);
185 let language = block.language.as_deref().unwrap_or("");
186 let replacement = format!("```{}\n{}```", language, compressed_code);
187 result.replace_range(block.start..block.end, &replacement);
188 }
189 }
190
191 result
192 }
193
194 pub fn compress_tool_output(content: &str, max_chars: usize) -> String {
212 if content.len() <= max_chars {
213 return content.to_string();
214 }
215
216 let code_blocks = Self::extract_code_blocks(content);
218 if !code_blocks.is_empty() {
219 return Self::compress_tool_output_with_code(content, max_chars, &code_blocks);
220 }
221
222 let head_chars = ((max_chars as f64) * 0.7).ceil() as usize;
224 let tail_chars = max_chars.saturating_sub(head_chars);
225
226 let head = Self::safe_substring(content, 0, head_chars);
227 let tail = Self::safe_substring(
228 content,
229 content.len().saturating_sub(tail_chars),
230 content.len(),
231 );
232
233 let omitted = content.len() - head.len() - tail.len();
234 format!(
235 "{}\n... [{} characters omitted] ...\n{}",
236 head, omitted, tail
237 )
238 }
239
240 fn compress_tool_output_with_code(
242 content: &str,
243 max_chars: usize,
244 code_blocks: &[CodeBlock],
245 ) -> String {
246 let total_code_chars: usize = code_blocks.iter().map(|b| b.code.len()).sum();
248
249 if total_code_chars <= max_chars {
250 let remaining = max_chars.saturating_sub(total_code_chars);
252 let text_before_first = code_blocks
253 .first()
254 .map(|b| content.get(..b.start).unwrap_or(""))
255 .unwrap_or("");
256 let text_after_last = code_blocks
257 .last()
258 .map(|b| content.get(b.end..).unwrap_or(""))
259 .unwrap_or("");
260
261 let before_budget = remaining / 2;
262 let after_budget = remaining.saturating_sub(before_budget);
263
264 let compressed_before = if text_before_first.len() > before_budget {
265 format!(
266 "{}...",
267 Self::safe_substring(text_before_first, 0, before_budget)
268 )
269 } else {
270 text_before_first.to_string()
271 };
272
273 let compressed_after = if text_after_last.len() > after_budget {
274 format!(
275 "...{}",
276 Self::safe_substring(
277 text_after_last,
278 text_after_last.len().saturating_sub(after_budget),
279 text_after_last.len()
280 )
281 )
282 } else {
283 text_after_last.to_string()
284 };
285
286 let mut result = compressed_before;
288 for block in code_blocks {
289 let lang = block.language.as_deref().unwrap_or("");
290 let compressed_code =
291 Self::compress_code_block(&block.code, DEFAULT_CODE_BLOCK_MAX_LINES);
292 result.push_str(&format!("```{}\n{}```", lang, compressed_code));
293 }
294 result.push_str(&compressed_after);
295 result
296 } else {
297 let budget_per_block = max_chars / code_blocks.len().max(1);
299 let lines_budget = budget_per_block / 40; let mut result = String::new();
302 for block in code_blocks {
303 let lang = block.language.as_deref().unwrap_or("");
304 let compressed = Self::compress_code_block(&block.code, lines_budget.max(10));
305 result.push_str(&format!("```{}\n{}```\n", lang, compressed));
306 }
307 result
308 }
309 }
310
311 pub fn extract_file_references(text: &str) -> Vec<String> {
323 FILE_PATH_REGEX
324 .captures_iter(text)
325 .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string()))
326 .collect()
327 }
328
329 pub fn compress_message(message: &Message, config: &CompressionConfig) -> Message {
347 let compressed_content: Vec<MessageContent> = message
348 .content
349 .iter()
350 .map(|content| Self::compress_content(content, config))
351 .collect();
352
353 Message {
354 id: message.id.clone(),
355 role: message.role.clone(),
356 created: message.created,
357 content: compressed_content,
358 metadata: message.metadata,
359 }
360 }
361
362 fn compress_content(content: &MessageContent, config: &CompressionConfig) -> MessageContent {
364 match content {
365 MessageContent::Text(text_content) => {
366 let compressed_text = Self::compress_code_blocks_in_text(
367 &text_content.text,
368 config.code_block_max_lines,
369 );
370 MessageContent::text(compressed_text)
371 }
372 MessageContent::ToolResponse(tool_response) => {
373 Self::compress_tool_response(tool_response, config)
374 }
375 other => other.clone(),
377 }
378 }
379
380 fn compress_tool_response(
382 tool_response: &crate::conversation::message::ToolResponse,
383 config: &CompressionConfig,
384 ) -> MessageContent {
385 use rmcp::model::{CallToolResult, Content, RawContent, RawTextContent};
386
387 match &tool_response.tool_result {
388 Ok(result) => {
389 let compressed_content: Vec<Content> = result
390 .content
391 .iter()
392 .map(|c| {
393 if let RawContent::Text(text) = &c.raw {
394 let compressed = Self::compress_tool_output(
395 &text.text,
396 config.tool_output_max_chars,
397 );
398 Content {
399 raw: RawContent::Text(RawTextContent {
400 text: compressed,
401 meta: text.meta.clone(),
402 }),
403 annotations: c.annotations.clone(),
404 }
405 } else {
406 c.clone()
407 }
408 })
409 .collect();
410
411 MessageContent::ToolResponse(crate::conversation::message::ToolResponse {
412 id: tool_response.id.clone(),
413 tool_result: Ok(CallToolResult {
414 content: compressed_content,
415 is_error: result.is_error,
416 meta: result.meta.clone(),
417 structured_content: result.structured_content.clone(),
418 }),
419 metadata: tool_response.metadata.clone(),
420 })
421 }
422 Err(e) => MessageContent::ToolResponse(crate::conversation::message::ToolResponse {
423 id: tool_response.id.clone(),
424 tool_result: Err(e.clone()),
425 metadata: tool_response.metadata.clone(),
426 }),
427 }
428 }
429
430 pub fn batch_compress_tool_results(messages: &[Message], max_chars: usize) -> Vec<Message> {
441 let config = CompressionConfig {
442 tool_output_max_chars: max_chars,
443 ..Default::default()
444 };
445
446 messages
447 .iter()
448 .map(|msg| Self::compress_message(msg, &config))
449 .collect()
450 }
451
452 pub fn truncate_messages(
472 messages: &[Message],
473 max_tokens: usize,
474 keep_first: usize,
475 keep_last: usize,
476 ) -> Vec<Message> {
477 if messages.is_empty() {
478 return Vec::new();
479 }
480
481 let total_tokens = TokenEstimator::estimate_total_tokens(messages);
482 if total_tokens <= max_tokens {
483 return messages.to_vec();
484 }
485
486 let total_messages = messages.len();
487
488 if keep_first + keep_last >= total_messages {
490 return messages.to_vec();
491 }
492
493 let mut result: Vec<Message> = Vec::new();
495 let mut current_tokens = 0;
496
497 for msg in messages.iter().take(keep_first) {
499 let msg_tokens = TokenEstimator::estimate_message_tokens(msg);
500 if current_tokens + msg_tokens <= max_tokens {
501 result.push(msg.clone());
502 current_tokens += msg_tokens;
503 }
504 }
505
506 let last_messages: Vec<&Message> =
508 messages.iter().skip(total_messages - keep_last).collect();
509 let last_tokens: usize = last_messages
510 .iter()
511 .map(|m| TokenEstimator::estimate_message_tokens(m))
512 .sum();
513
514 let available_for_middle = max_tokens.saturating_sub(current_tokens + last_tokens);
516 let mut middle_tokens = 0;
517
518 for msg in messages
519 .iter()
520 .skip(keep_first)
521 .take(total_messages - keep_first - keep_last)
522 {
523 let msg_tokens = TokenEstimator::estimate_message_tokens(msg);
524 if middle_tokens + msg_tokens <= available_for_middle {
525 result.push(msg.clone());
526 middle_tokens += msg_tokens;
527 } else {
528 break;
529 }
530 }
531
532 for msg in last_messages {
534 result.push(msg.clone());
535 }
536
537 result
538 }
539
540 fn safe_substring(s: &str, start: usize, end: usize) -> &str {
551 if s.is_empty() || start >= s.len() {
552 return "";
553 }
554
555 let valid_start = s
557 .char_indices()
558 .map(|(i, _)| i)
559 .find(|&i| i >= start)
560 .unwrap_or(s.len());
561
562 let valid_end = if end >= s.len() {
565 s.len()
566 } else {
567 s.char_indices()
569 .map(|(i, _)| i)
570 .take_while(|&i| i <= end)
571 .last()
572 .unwrap_or(0)
573 };
574
575 if valid_start >= valid_end {
576 return "";
577 }
578
579 s.get(valid_start..valid_end).unwrap_or("")
580 }
581
582 pub fn calculate_compression_result(
584 original: &Message,
585 compressed: &Message,
586 ) -> CompressionResult {
587 let original_tokens = TokenEstimator::estimate_message_tokens(original);
588 let compressed_tokens = TokenEstimator::estimate_message_tokens(compressed);
589
590 CompressionResult::new(original_tokens, compressed_tokens, "message_compression")
591 }
592
593 pub fn compress_with_pruning(
625 messages: &[Message],
626 usage_ratio: f64,
627 compression_config: &CompressionConfig,
628 pruning_config: &PruningConfig,
629 ) -> Vec<Message> {
630 let compressed: Vec<Message> = messages
632 .iter()
633 .map(|msg| Self::compress_message(msg, compression_config))
634 .collect();
635
636 ProgressivePruner::prune_messages(&compressed, usage_ratio, pruning_config)
638 }
639
640 pub fn compress_tool_output_with_pruning(
656 content: &str,
657 max_chars: usize,
658 usage_ratio: f64,
659 pruning_config: &PruningConfig,
660 ) -> String {
661 let pruning_level = pruning_config.get_pruning_level(usage_ratio);
662
663 match pruning_level {
664 crate::context::types::PruningLevel::HardClear => {
665 ProgressivePruner::hard_clear(&pruning_config.hard_clear_placeholder)
666 }
667 crate::context::types::PruningLevel::SoftTrim => ProgressivePruner::soft_trim(
668 content,
669 pruning_config.soft_trim_head_chars,
670 pruning_config.soft_trim_tail_chars,
671 ),
672 crate::context::types::PruningLevel::None => {
673 Self::compress_tool_output(content, max_chars)
675 }
676 }
677 }
678}
679
#[cfg(test)]
mod tests {
    use super::*;

    /// Joins `count` lines shaped like `{indent}line {i}` with `\n` (no trailing newline).
    fn numbered_lines(indent: &str, count: usize) -> String {
        (0..count)
            .map(|i| format!("{indent}line {i}"))
            .collect::<Vec<_>>()
            .join("\n")
    }

    #[test]
    fn test_compress_code_block_within_limit() {
        let code = "line 1\nline 2\nline 3";
        assert_eq!(MessageCompressor::compress_code_block(code, 10), code);
    }

    #[test]
    fn test_compress_code_block_exceeds_limit() {
        let compressed = MessageCompressor::compress_code_block(&numbered_lines("", 100), 50);

        // Head lines, the omission marker and tail lines must all survive.
        for needle in ["line 0", "line 29", "lines omitted", "line 80", "line 99"] {
            assert!(compressed.contains(needle), "expected {needle:?} in output");
        }
        // A line from the middle must be gone.
        assert!(!compressed.contains("line 50"));
    }

    #[test]
    fn test_extract_code_blocks() {
        let text = r#"
Some text before

```rust
fn main() {
    println!("Hello");
}
```

More text

```python
print("world")
```
"#;

        let blocks = MessageCompressor::extract_code_blocks(text);
        assert_eq!(blocks.len(), 2);

        let rust_block = &blocks[0];
        assert_eq!(rust_block.language.as_deref(), Some("rust"));
        assert!(rust_block.code.contains("fn main()"));

        let python_block = &blocks[1];
        assert_eq!(python_block.language.as_deref(), Some("python"));
        assert!(python_block.code.contains("print"));
    }

    #[test]
    fn test_extract_code_blocks_no_language() {
        let blocks = MessageCompressor::extract_code_blocks("```\nplain code\n```");
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].language.is_none());
    }

    #[test]
    fn test_compress_tool_output_within_limit() {
        let short = "Short output";
        assert_eq!(MessageCompressor::compress_tool_output(short, 100), short);
    }

    #[test]
    fn test_compress_tool_output_exceeds_limit() {
        let content = "A".repeat(1000);
        let compressed = MessageCompressor::compress_tool_output(&content, 100);

        assert!(compressed.len() < content.len());
        assert!(compressed.contains("characters omitted"));
        assert!(compressed.starts_with("AAAA") && compressed.ends_with("AAAA"));
    }

    #[test]
    fn test_extract_file_references() {
        let refs = MessageCompressor::extract_file_references(
            "Check src/main.rs and ./lib/utils.ts for details",
        );

        for expected in ["src/main.rs", "./lib/utils.ts"] {
            assert!(
                refs.iter().any(|found| found.as_str() == expected),
                "missing reference {expected}"
            );
        }
    }

    #[test]
    fn test_compress_code_blocks_in_text() {
        let text = format!("Before\n```rust\n{}```\nAfter", numbered_lines("    ", 100));
        let result = MessageCompressor::compress_code_blocks_in_text(&text, 50);

        for needle in ["Before", "After", "lines omitted"] {
            assert!(result.contains(needle), "expected {needle:?} in output");
        }
    }

    #[test]
    fn test_truncate_messages_within_limit() {
        let messages = [
            Message::user().with_text("Hello"),
            Message::assistant().with_text("Hi there"),
        ];

        let kept = MessageCompressor::truncate_messages(&messages, 10_000, 1, 1);
        assert_eq!(kept.len(), 2);
    }

    #[test]
    fn test_safe_substring() {
        let s = "Hello, 世界!";
        assert_eq!(MessageCompressor::safe_substring(s, 0, 7), "Hello, ");

        // Bytes 7..13 hold the two three-byte CJK characters.
        assert!(MessageCompressor::safe_substring(s, 7, 13).contains('世'));
    }

    #[test]
    fn test_head_tail_ratio() {
        let compressed = MessageCompressor::compress_code_block(&numbered_lines("", 100), 50);
        let kept = compressed
            .lines()
            .filter(|line| !line.contains("omitted"))
            .count();

        assert!((48..=52).contains(&kept), "kept {kept} content lines");
    }

    #[test]
    fn test_compress_tool_output_with_pruning_no_pruning() {
        let content = "A".repeat(1000);
        let result = MessageCompressor::compress_tool_output_with_pruning(
            &content,
            2000,
            0.2,
            &PruningConfig::default(),
        );

        assert_eq!(result, content);
    }

    #[test]
    fn test_compress_tool_output_with_pruning_soft_trim() {
        let content = "A".repeat(2000);
        let result = MessageCompressor::compress_tool_output_with_pruning(
            &content,
            3000,
            0.4,
            &PruningConfig::default(),
        );

        assert!(result.contains("chars omitted"));
        assert!(result.len() < content.len());
    }

    #[test]
    fn test_compress_tool_output_with_pruning_hard_clear() {
        let content = "A".repeat(2000);
        let result = MessageCompressor::compress_tool_output_with_pruning(
            &content,
            3000,
            0.6,
            &PruningConfig::default(),
        );

        assert_eq!(result, "[content cleared]");
    }

    #[test]
    fn test_compress_with_pruning() {
        let messages = vec![
            Message::user().with_text("Hello"),
            Message::assistant().with_text("Hi there"),
        ];

        let result = MessageCompressor::compress_with_pruning(
            &messages,
            0.2,
            &CompressionConfig::default(),
            &PruningConfig::default(),
        );

        assert_eq!(result.len(), messages.len());
    }
}