1use anyhow::Result;
8
9use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role};
10use super::hardcode_config::HardcodeConfig;
11
12use super::compressor::{compress_messages, estimate_total_tokens};
13use super::config::{
14 CircuitBreakerState, CompressionConfig, TIME_BASED_MC_CLEARED_MESSAGE, ThresholdLevel,
15};
16use super::dependency::DependencyBuilder;
17use super::phase_detector::PhaseDetector;
18use super::scorer::Scorer;
19use super::summarizer::Summarizer;
20use super::tool_compressor::ToolCompressor;
21use super::types::{
22 AiCompressionMode, CompressionStrategy, CompressionThresholds, DependencyGraph, ScoredMessage,
23};
24
25pub struct CompressionPipeline {
27 config: CompressionConfig,
29 scorer: Scorer,
31 tool_compressor: ToolCompressor,
33 circuit_breaker: CircuitBreakerState,
35 hardcode_config: HardcodeConfig,
37}
38
39pub struct CompressionOutcome {
41 pub messages: Vec<Message>,
43 pub threshold_level: ThresholdLevel,
45 pub percent_left: u32,
47 pub success: bool,
49 pub error: Option<String>,
51 pub circuit_breaker_tripped: bool,
53}
54
55#[derive(Debug, Clone)]
57pub enum ValidationError {
58 OrphanedToolResult { tool_use_id: String, index: usize },
60 OrphanedToolUse { tool_use_id: String, index: usize },
62 MissingFirstMessage,
64 OrderViolation {
66 expected_role: Role,
67 actual_role: Role,
68 index: usize,
69 },
70}
71
72impl CompressionPipeline {
73 pub fn new_rule_only(config: CompressionConfig) -> Self {
75 let thresholds = CompressionThresholds::default();
76 Self {
77 config,
78 scorer: Scorer::new_rule_only(),
79 tool_compressor: ToolCompressor::new_truncate_only(thresholds),
80 circuit_breaker: CircuitBreakerState::new(),
81 hardcode_config: HardcodeConfig::default(),
82 }
83 }
84
85 pub fn new_with_ai(config: CompressionConfig, fast_model: Box<dyn Provider>) -> Self {
87 let thresholds = CompressionThresholds::default();
88 let summarizer = Summarizer::new(fast_model.clone());
89
90 Self {
91 config,
92 scorer: Scorer::new_with_ai(fast_model),
93 tool_compressor: ToolCompressor::new_with_ai(summarizer, thresholds),
94 circuit_breaker: CircuitBreakerState::new(),
95 hardcode_config: HardcodeConfig::default(),
96 }
97 }
98
99 pub fn new_with_full_ai(
101 config: CompressionConfig,
102 fast_model: Box<dyn Provider>,
103 main_model: Box<dyn Provider>,
104 ) -> Self {
105 let thresholds = CompressionThresholds::default();
106 let summarizer = Summarizer::new_with_main(fast_model.clone(), main_model);
107
108 Self {
109 config,
110 scorer: Scorer::new_with_ai(fast_model),
111 tool_compressor: ToolCompressor::new_with_ai(summarizer, thresholds),
112 circuit_breaker: CircuitBreakerState::new(),
113 hardcode_config: HardcodeConfig::default(),
114 }
115 }
116
117 pub fn should_compress(&self, token_usage: u32, context_window: u32) -> (bool, ThresholdLevel) {
119 if self.circuit_breaker.should_skip() {
121 return (false, ThresholdLevel::Blocking);
122 }
123
124 let (level, _) = CompressionConfig::calculate_threshold_level(token_usage, context_window);
125
126 let should_compress = level != ThresholdLevel::Normal;
127 (should_compress, level)
128 }
129
130 pub fn should_time_based_clear(messages: &[Message]) -> bool {
133 let last_assistant = messages.iter().rev().find(|m| m.role == Role::Assistant);
134
135 if let Some(_msg) = last_assistant {
136 let messages_since = messages
139 .iter()
140 .rev()
141 .take_while(|m| m.role != Role::Assistant)
142 .count();
143 messages_since > 10
145 } else {
146 false
147 }
148 }
149
150 pub fn time_based_microcompact(messages: &[Message]) -> Vec<Message> {
152 let config = HardcodeConfig::default();
153 messages
154 .iter()
155 .map(|msg| {
156 if msg.role != Role::Tool {
157 return msg.clone();
158 }
159
160 match &msg.content {
162 MessageContent::Blocks(blocks) => {
163 let new_blocks: Vec<ContentBlock> = blocks
164 .iter()
165 .map(|b| {
166 if let ContentBlock::ToolResult {
167 tool_use_id,
168 content,
169 } = b
170 {
171 if content.len() > config.preserve_content_threshold
173 && content != TIME_BASED_MC_CLEARED_MESSAGE
174 {
175 ContentBlock::ToolResult {
176 tool_use_id: tool_use_id.clone(),
177 content: TIME_BASED_MC_CLEARED_MESSAGE.to_string(),
178 }
179 } else {
180 b.clone()
181 }
182 } else {
183 b.clone()
184 }
185 })
186 .collect();
187 Message {
188 role: msg.role.clone(),
189 content: MessageContent::Blocks(new_blocks),
190 }
191 }
192 _ => msg.clone(),
193 }
194 })
195 .collect()
196 }
197
198 pub fn strip_thinking(messages: &[Message]) -> Vec<Message> {
201 messages
202 .iter()
203 .map(|msg| {
204 match &msg.content {
205 MessageContent::Blocks(blocks) => {
206 let new_blocks: Vec<ContentBlock> = blocks
207 .iter()
208 .filter(|b| {
209 !matches!(b, ContentBlock::Thinking { .. })
211 })
212 .cloned()
213 .collect();
214 Message {
215 role: msg.role.clone(),
216 content: MessageContent::Blocks(new_blocks),
217 }
218 }
219 _ => msg.clone(),
220 }
221 })
222 .collect()
223 }
224
225 const COMPACTABLE_TOOLS: &[&str] = &[
228 "bash",
229 "read",
230 "glob",
231 "grep",
232 "ls",
233 "edit",
234 "write",
235 "notebook_edit",
236 "web_fetch",
237 "web_search",
238 ];
239
240 pub fn is_compactable_tool(tool_name: &str) -> bool {
242 Self::COMPACTABLE_TOOLS.contains(&tool_name)
243 }
244
245 pub fn clear_tool_results(messages: &[Message], _tool_names: &[&str]) -> Vec<Message> {
247 let config = HardcodeConfig::default();
248 messages
249 .iter()
250 .map(|msg| {
251 if msg.role != Role::Tool {
252 return msg.clone();
253 }
254
255 match &msg.content {
256 MessageContent::Blocks(blocks) => {
257 let new_blocks: Vec<ContentBlock> = blocks
258 .iter()
259 .map(|b| {
260 if let ContentBlock::ToolResult {
261 tool_use_id,
262 content,
263 } = b
264 {
265 if content.len() > config.preserve_content_threshold
268 && content != TIME_BASED_MC_CLEARED_MESSAGE
269 {
270 ContentBlock::ToolResult {
271 tool_use_id: tool_use_id.clone(),
272 content: TIME_BASED_MC_CLEARED_MESSAGE.to_string(),
273 }
274 } else {
275 b.clone()
276 }
277 } else {
278 b.clone()
279 }
280 })
281 .collect();
282 Message {
283 role: msg.role.clone(),
284 content: MessageContent::Blocks(new_blocks),
285 }
286 }
287 _ => msg.clone(),
288 }
289 })
290 .collect()
291 }
292
293 pub fn full_microcompact(messages: &[Message]) -> Vec<Message> {
295 let no_thinking = Self::strip_thinking(messages);
297 Self::time_based_microcompact(&no_thinking)
299 }
300
301 pub fn validate_compression(
307 messages: &[Message],
308 _original_deps: &DependencyGraph,
309 ) -> Vec<ValidationError> {
310 let mut errors = Vec::new();
311
312 if messages.is_empty() {
314 errors.push(ValidationError::MissingFirstMessage);
315 return errors;
316 }
317
318 let new_deps = DependencyBuilder::build(messages);
320
321 for (idx, msg) in messages.iter().enumerate() {
323 if msg.role == Role::Tool
324 && let MessageContent::Blocks(blocks) = &msg.content
325 {
326 for block in blocks {
327 if let ContentBlock::ToolResult { tool_use_id, .. } = block {
328 let has_tool_use = messages.iter().any(|m| {
330 if let MessageContent::Blocks(bs) = &m.content {
331 bs.iter().any(|b| {
332 if let ContentBlock::ToolUse { id, .. } = b {
333 id == tool_use_id
334 } else {
335 false
336 }
337 })
338 } else {
339 false
340 }
341 });
342
343 if !has_tool_use {
344 errors.push(ValidationError::OrphanedToolResult {
345 tool_use_id: tool_use_id.clone(),
346 index: idx,
347 });
348 }
349 }
350 }
351 }
352 }
353
354 for (idx, msg) in messages.iter().enumerate() {
356 if let MessageContent::Blocks(blocks) = &msg.content {
357 for block in blocks {
358 if let ContentBlock::ToolUse { id, .. } = block {
359 let has_tool_result = messages.iter().any(|m| {
361 if m.role == Role::Tool {
362 if let MessageContent::Blocks(bs) = &m.content {
363 bs.iter().any(|b| {
364 if let ContentBlock::ToolResult { tool_use_id, .. } = b {
365 tool_use_id == id
366 } else {
367 false
368 }
369 })
370 } else {
371 false
372 }
373 } else {
374 false
375 }
376 });
377
378 if !has_tool_result {
379 errors.push(ValidationError::OrphanedToolUse {
380 tool_use_id: id.clone(),
381 index: idx,
382 });
383 }
384 }
385 }
386 }
387 }
388
389 for dep in &new_deps.dependencies {
391 if dep.tool_use_idx >= messages.len() {
392 errors.push(ValidationError::OrphanedToolUse {
393 tool_use_id: dep.tool_name.clone(),
394 index: dep.tool_use_idx,
395 });
396 }
397 if dep.tool_result_idx >= messages.len() {
398 errors.push(ValidationError::OrphanedToolResult {
399 tool_use_id: dep.tool_name.clone(),
400 index: dep.tool_result_idx,
401 });
402 }
403 }
404
405 errors
406 }
407
408 pub fn is_valid_compression(messages: &[Message], original_deps: &DependencyGraph) -> bool {
410 Self::validate_compression(messages, original_deps).is_empty()
411 }
412
413 pub async fn execute(
415 &mut self,
416 messages: &[Message],
417 ai_mode: AiCompressionMode,
418 token_usage: u32,
419 context_window: u32,
420 ) -> Result<CompressionOutcome> {
421 if self.circuit_breaker.should_skip() {
423 return Ok(CompressionOutcome {
424 messages: messages.to_vec(),
425 threshold_level: ThresholdLevel::Blocking,
426 percent_left: 0,
427 success: false,
428 error: Some("Circuit breaker tripped - too many consecutive failures".to_string()),
429 circuit_breaker_tripped: true,
430 });
431 }
432
433 if messages.len() <= self.config.min_preserve_messages {
434 let (level, percent) =
435 CompressionConfig::calculate_threshold_level(token_usage, context_window);
436 return Ok(CompressionOutcome {
437 messages: messages.to_vec(),
438 threshold_level: level,
439 percent_left: percent,
440 success: true,
441 error: None,
442 circuit_breaker_tripped: false,
443 });
444 }
445
446 let pre_processed = if Self::should_time_based_clear(messages) {
448 Self::time_based_microcompact(messages)
449 } else {
450 messages.to_vec()
451 };
452
453 let phase = PhaseDetector::detect(&pre_processed);
455 let weights = phase.default_weights();
456 let deps = DependencyBuilder::build(&pre_processed);
457
458 let scored = self
460 .scorer
461 .score_all(&pre_processed, &weights, &deps, ai_mode)
462 .await?;
463
464 let compressed = self
466 .tool_compressor
467 .compress_results(&pre_processed, ai_mode)
468 .await?;
469
470 let target_count = calculate_target_count(pre_processed.len(), &self.config);
472 let selected = self.select_messages(scored, &deps, target_count, &compressed);
473
474 let final_messages = self.ensure_dependency_integrity(selected, &deps, &pre_processed);
476
477 self.circuit_breaker.record_success();
479
480 let post_tokens = estimate_total_tokens(&final_messages);
482 let (level, percent) =
483 CompressionConfig::calculate_threshold_level(post_tokens, context_window);
484
485 Ok(CompressionOutcome {
486 messages: final_messages,
487 threshold_level: level,
488 percent_left: percent,
489 success: true,
490 error: None,
491 circuit_breaker_tripped: false,
492 })
493 }
494
495 pub async fn execute_with_circuit_breaker(
497 &mut self,
498 messages: &[Message],
499 ai_mode: AiCompressionMode,
500 token_usage: u32,
501 context_window: u32,
502 ) -> Result<CompressionOutcome> {
503 let result = self
504 .execute(messages, ai_mode, token_usage, context_window)
505 .await;
506
507 match result {
508 Ok(res) => Ok(res),
509 Err(e) => {
510 let tripped = self.circuit_breaker.record_failure();
512
513 let (level, percent) =
514 CompressionConfig::calculate_threshold_level(token_usage, context_window);
515
516 Ok(CompressionOutcome {
517 messages: messages.to_vec(),
518 threshold_level: level,
519 percent_left: percent,
520 success: false,
521 error: Some(e.to_string()),
522 circuit_breaker_tripped: tripped,
523 })
524 }
525 }
526 }
527
528 pub fn execute_sync(&self, messages: &[Message]) -> Result<Vec<Message>> {
530 compress_messages(messages, CompressionStrategy::BiasBased, &self.config)
532 }
533
534 fn select_messages(
536 &self,
537 scored: Vec<ScoredMessage>,
538 deps: &DependencyGraph,
539 target_count: usize,
540 compressed_messages: &[Message],
541 ) -> Vec<Message> {
542 let mut sorted = scored;
544 sorted.sort_by(|a, b| b.final_score.partial_cmp(&a.final_score).unwrap());
545
546 let mut preserve_indices: std::collections::HashSet<usize> =
548 std::collections::HashSet::new();
549
550 for sm in sorted.iter().take(target_count) {
552 preserve_indices.insert(sm.index);
553
554 for pair_idx in deps.get_pair_indices(sm.index) {
556 preserve_indices.insert(pair_idx);
557 }
558 }
559
560 let selected: Vec<Message> = preserve_indices
562 .iter()
563 .filter_map(|idx| compressed_messages.get(*idx).cloned())
564 .collect();
565
566 selected
567 }
568
569 fn ensure_dependency_integrity(
571 &self,
572 selected: Vec<Message>,
573 _deps: &DependencyGraph,
574 _original: &[Message],
575 ) -> Vec<Message> {
576 selected
579 }
580
581 pub fn score_only(&self, messages: &[Message]) -> Vec<ScoredMessage> {
583 let phase = PhaseDetector::detect(messages);
584 let weights = phase.default_weights();
585 let deps = DependencyBuilder::build(messages);
586
587 let mut scored: Vec<ScoredMessage> = Vec::new();
589 for (idx, msg) in messages.iter().enumerate() {
590 let base_score = super::scorer::score_by_rules(msg, idx, &weights);
591 scored.push(ScoredMessage::new(idx, msg.clone(), base_score));
592 }
593
594 let bonus = weights.dependency_pair_bonus;
596 for dep in &deps.dependencies {
597 if let Some(sm) = scored.get_mut(dep.tool_use_idx) {
598 sm.with_dependency_bonus(bonus);
599 }
600 if let Some(sm) = scored.get_mut(dep.tool_result_idx) {
601 sm.with_dependency_bonus(bonus);
602 }
603 }
604
605 scored
606 }
607}
608
609fn calculate_target_count(total: usize, config: &CompressionConfig) -> usize {
611 let target = (total as f64 * config.target_ratio) as usize;
612 target.max(config.min_preserve_messages)
613}
614
615pub fn compress_with_pipeline(
617 messages: &[Message],
618 config: &CompressionConfig,
619 ai_mode: AiCompressionMode,
620 fast_model: Option<Box<dyn Provider>>,
621) -> Result<Vec<Message>> {
622 let pipeline = match (ai_mode, fast_model) {
624 (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
625 (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
626 CompressionPipeline::new_with_ai(config.clone(), model)
627 }
628 _ => CompressionPipeline::new_rule_only(config.clone()),
629 };
630
631 pipeline.execute_sync(messages)
633}
634
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Returns the blocks of `message`, failing the test for any other content
    /// shape so that assertions can never be skipped silently.
    fn blocks_of(message: &Message) -> &[ContentBlock] {
        let MessageContent::Blocks(blocks) = &message.content else {
            panic!("expected block content");
        };
        blocks
    }

    /// Returns the text of a tool-result block, failing the test for any other
    /// block kind.
    fn tool_result_text(block: &ContentBlock) -> &str {
        let ContentBlock::ToolResult { content, .. } = block else {
            panic!("expected a tool result block");
        };
        content.as_str()
    }

    #[test]
    fn test_pipeline_new_rule_only() {
        let config = CompressionConfig::default();
        let pipeline = CompressionPipeline::new_rule_only(config);
        let messages = vec![Message {
            role: Role::User,
            content: MessageContent::Text("Test".to_string()),
        }];
        let result = pipeline.execute_sync(&messages);
        assert!(result.is_ok());
    }

    #[test]
    fn test_calculate_target_count() {
        let config = CompressionConfig::default();
        let total = 100;
        let target = calculate_target_count(total, &config);
        assert!(target >= config.min_preserve_messages);
        assert!(target < total);
    }

    #[test]
    fn test_score_only() {
        let config = CompressionConfig::default();
        let pipeline = CompressionPipeline::new_rule_only(config);

        let messages = vec![
            Message {
                role: Role::User,
                content: MessageContent::Text("Hello".to_string()),
            },
            Message {
                role: Role::Assistant,
                content: MessageContent::Text("Hi".to_string()),
            },
        ];

        let scored = pipeline.score_only(&messages);
        assert_eq!(scored.len(), 2);
        assert!(scored[0].final_score > scored[1].final_score);
    }

    #[test]
    fn test_execute_sync_small() {
        let config = CompressionConfig::default();
        let pipeline = CompressionPipeline::new_rule_only(config);

        let messages = vec![Message {
            role: Role::User,
            content: MessageContent::Text("Hello".to_string()),
        }];

        let result = pipeline.execute_sync(&messages).unwrap();
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_time_based_microcompact() {
        let messages = vec![
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                    tool_use_id: "tool_1".to_string(),
                    content: "This is a very long tool result content that should be cleared..."
                        .repeat(20),
                }]),
            },
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                    tool_use_id: "tool_2".to_string(),
                    content: "Short content".to_string(),
                }]),
            },
        ];

        let compacted = CompressionPipeline::time_based_microcompact(&messages);
        assert_eq!(compacted.len(), 2);

        // The long result is replaced by the cleared marker.
        let long_blocks = blocks_of(&compacted[0]);
        assert_eq!(
            tool_result_text(&long_blocks[0]),
            TIME_BASED_MC_CLEARED_MESSAGE
        );

        // The short result is preserved verbatim.
        let short_blocks = blocks_of(&compacted[1]);
        assert_eq!(tool_result_text(&short_blocks[0]), "Short content");
    }

    #[test]
    fn test_strip_thinking() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: MessageContent::Blocks(vec![
                ContentBlock::Text {
                    text: "Response".to_string(),
                },
                ContentBlock::Thinking {
                    thinking: "Long thinking process...".to_string(),
                    signature: None,
                },
            ]),
        }];

        let stripped = CompressionPipeline::strip_thinking(&messages);
        assert_eq!(stripped.len(), 1);

        let blocks = blocks_of(&stripped[0]);
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
    }

    #[test]
    fn test_is_compactable_tool() {
        assert!(CompressionPipeline::is_compactable_tool("bash"));
        assert!(CompressionPipeline::is_compactable_tool("read"));
        assert!(CompressionPipeline::is_compactable_tool("glob"));
        assert!(!CompressionPipeline::is_compactable_tool("unknown_tool"));
    }

    #[test]
    fn test_should_time_based_clear() {
        let mut many_messages: Vec<Message> = vec![Message {
            role: Role::Assistant,
            content: MessageContent::Text("response".to_string()),
        }];
        for i in 0..15 {
            many_messages.push(Message {
                role: if i % 2 == 0 { Role::User } else { Role::Tool },
                content: MessageContent::Text("content".to_string()),
            });
        }

        assert!(CompressionPipeline::should_time_based_clear(&many_messages));

        let few_messages = vec![
            Message {
                role: Role::Assistant,
                content: MessageContent::Text("response".to_string()),
            },
            Message {
                role: Role::User,
                content: MessageContent::Text("follow-up".to_string()),
            },
        ];

        assert!(!CompressionPipeline::should_time_based_clear(&few_messages));
    }

    #[test]
    fn test_validate_compression_valid() {
        let messages = vec![
            Message {
                role: Role::User,
                content: MessageContent::Text("Request".to_string()),
            },
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                    id: "tool_1".to_string(),
                    name: "read".to_string(),
                    input: serde_json::json!({"path": "test.txt"}),
                }]),
            },
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                    tool_use_id: "tool_1".to_string(),
                    content: "File content".to_string(),
                }]),
            },
        ];

        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(errors.is_empty());
    }

    #[test]
    fn test_validate_compression_orphaned_tool_result() {
        let messages = vec![
            Message {
                role: Role::User,
                content: MessageContent::Text("Request".to_string()),
            },
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                    tool_use_id: "tool_missing".to_string(),
                    content: "Orphaned result".to_string(),
                }]),
            },
        ];

        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(!errors.is_empty());
        assert!(
            errors
                .iter()
                .any(|e| matches!(e, ValidationError::OrphanedToolResult { .. }))
        );
    }

    #[test]
    fn test_validate_compression_empty() {
        let messages: Vec<Message> = vec![];
        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(!errors.is_empty());
        assert!(
            errors
                .iter()
                .any(|e| matches!(e, ValidationError::MissingFirstMessage))
        );
    }
}