1use anyhow::Result;
8
9use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role};
10
11use super::compressor::{compress_messages, estimate_total_tokens};
12use super::config::{
13 CircuitBreakerState, CompressionConfig, TIME_BASED_MC_CLEARED_MESSAGE, ThresholdLevel,
14};
15use super::dependency::DependencyBuilder;
16use super::phase_detector::PhaseDetector;
17use super::scorer::Scorer;
18use super::summarizer::Summarizer;
19use super::tool_compressor::ToolCompressor;
20use super::types::{
21 AiCompressionMode, CompressionStrategy, CompressionThresholds, DependencyGraph, ScoredMessage,
22};
23
24pub struct CompressionPipeline {
26 config: CompressionConfig,
28 scorer: Scorer,
30 tool_compressor: ToolCompressor,
32 circuit_breaker: CircuitBreakerState,
34}
35
36pub struct CompressionOutcome {
38 pub messages: Vec<Message>,
40 pub threshold_level: ThresholdLevel,
42 pub percent_left: u32,
44 pub success: bool,
46 pub error: Option<String>,
48 pub circuit_breaker_tripped: bool,
50}
51
52#[derive(Debug, Clone)]
54pub enum ValidationError {
55 OrphanedToolResult { tool_use_id: String, index: usize },
57 OrphanedToolUse { tool_use_id: String, index: usize },
59 MissingFirstMessage,
61 OrderViolation {
63 expected_role: Role,
64 actual_role: Role,
65 index: usize,
66 },
67}
68
69impl CompressionPipeline {
70 pub fn new_rule_only(config: CompressionConfig) -> Self {
72 let thresholds = CompressionThresholds::default();
73 Self {
74 config,
75 scorer: Scorer::new_rule_only(),
76 tool_compressor: ToolCompressor::new_truncate_only(thresholds),
77 circuit_breaker: CircuitBreakerState::new(),
78 }
79 }
80
81 pub fn new_with_ai(config: CompressionConfig, fast_model: Box<dyn Provider>) -> Self {
83 let thresholds = CompressionThresholds::default();
84 let summarizer = Summarizer::new(fast_model.clone());
85
86 Self {
87 config,
88 scorer: Scorer::new_with_ai(fast_model),
89 tool_compressor: ToolCompressor::new_with_ai(summarizer, thresholds),
90 circuit_breaker: CircuitBreakerState::new(),
91 }
92 }
93
94 pub fn new_with_full_ai(
96 config: CompressionConfig,
97 fast_model: Box<dyn Provider>,
98 main_model: Box<dyn Provider>,
99 ) -> Self {
100 let thresholds = CompressionThresholds::default();
101 let summarizer = Summarizer::new_with_main(fast_model.clone(), main_model);
102
103 Self {
104 config,
105 scorer: Scorer::new_with_ai(fast_model),
106 tool_compressor: ToolCompressor::new_with_ai(summarizer, thresholds),
107 circuit_breaker: CircuitBreakerState::new(),
108 }
109 }
110
111 pub fn should_compress(&self, token_usage: u32, context_window: u32) -> (bool, ThresholdLevel) {
113 if self.circuit_breaker.should_skip() {
115 return (false, ThresholdLevel::Blocking);
116 }
117
118 let (level, _) = CompressionConfig::calculate_threshold_level(token_usage, context_window);
119
120 let should_compress = level != ThresholdLevel::Normal;
121 (should_compress, level)
122 }
123
124 pub fn should_time_based_clear(messages: &[Message]) -> bool {
127 let last_assistant = messages.iter().rev().find(|m| m.role == Role::Assistant);
128
129 if let Some(_msg) = last_assistant {
130 let messages_since = messages
133 .iter()
134 .rev()
135 .take_while(|m| m.role != Role::Assistant)
136 .count();
137 messages_since > 10
139 } else {
140 false
141 }
142 }
143
144 pub fn time_based_microcompact(messages: &[Message]) -> Vec<Message> {
146 messages
147 .iter()
148 .map(|msg| {
149 if msg.role != Role::Tool {
150 return msg.clone();
151 }
152
153 match &msg.content {
155 MessageContent::Blocks(blocks) => {
156 let new_blocks: Vec<ContentBlock> = blocks
157 .iter()
158 .map(|b| {
159 if let ContentBlock::ToolResult {
160 tool_use_id,
161 content,
162 } = b
163 {
164 if content.len() > 500
166 && content != TIME_BASED_MC_CLEARED_MESSAGE
167 {
168 ContentBlock::ToolResult {
169 tool_use_id: tool_use_id.clone(),
170 content: TIME_BASED_MC_CLEARED_MESSAGE.to_string(),
171 }
172 } else {
173 b.clone()
174 }
175 } else {
176 b.clone()
177 }
178 })
179 .collect();
180 Message {
181 role: msg.role.clone(),
182 content: MessageContent::Blocks(new_blocks),
183 }
184 }
185 _ => msg.clone(),
186 }
187 })
188 .collect()
189 }
190
191 pub fn strip_thinking(messages: &[Message]) -> Vec<Message> {
194 messages
195 .iter()
196 .map(|msg| {
197 match &msg.content {
198 MessageContent::Blocks(blocks) => {
199 let new_blocks: Vec<ContentBlock> = blocks
200 .iter()
201 .filter(|b| {
202 !matches!(b, ContentBlock::Thinking { .. })
204 })
205 .cloned()
206 .collect();
207 Message {
208 role: msg.role.clone(),
209 content: MessageContent::Blocks(new_blocks),
210 }
211 }
212 _ => msg.clone(),
213 }
214 })
215 .collect()
216 }
217
/// Names of the tools that `is_compactable_tool` accepts.
const COMPACTABLE_TOOLS: &[&str] = &[
    // shell and file-system inspection
    "bash", "read", "glob", "grep", "ls",
    // file modification
    "edit", "write", "notebook_edit",
    // web access
    "web_fetch", "web_search",
];
232
233 pub fn is_compactable_tool(tool_name: &str) -> bool {
235 Self::COMPACTABLE_TOOLS.contains(&tool_name)
236 }
237
238 pub fn clear_tool_results(messages: &[Message], _tool_names: &[&str]) -> Vec<Message> {
240 messages
241 .iter()
242 .map(|msg| {
243 if msg.role != Role::Tool {
244 return msg.clone();
245 }
246
247 match &msg.content {
248 MessageContent::Blocks(blocks) => {
249 let new_blocks: Vec<ContentBlock> = blocks
250 .iter()
251 .map(|b| {
252 if let ContentBlock::ToolResult {
253 tool_use_id,
254 content,
255 } = b
256 {
257 if content.len() > 500
260 && content != TIME_BASED_MC_CLEARED_MESSAGE
261 {
262 ContentBlock::ToolResult {
263 tool_use_id: tool_use_id.clone(),
264 content: TIME_BASED_MC_CLEARED_MESSAGE.to_string(),
265 }
266 } else {
267 b.clone()
268 }
269 } else {
270 b.clone()
271 }
272 })
273 .collect();
274 Message {
275 role: msg.role.clone(),
276 content: MessageContent::Blocks(new_blocks),
277 }
278 }
279 _ => msg.clone(),
280 }
281 })
282 .collect()
283 }
284
285 pub fn full_microcompact(messages: &[Message]) -> Vec<Message> {
287 let no_thinking = Self::strip_thinking(messages);
289 Self::time_based_microcompact(&no_thinking)
291 }
292
293 pub fn validate_compression(
299 messages: &[Message],
300 _original_deps: &DependencyGraph,
301 ) -> Vec<ValidationError> {
302 let mut errors = Vec::new();
303
304 if messages.is_empty() {
306 errors.push(ValidationError::MissingFirstMessage);
307 return errors;
308 }
309
310 let new_deps = DependencyBuilder::build(messages);
312
313 for (idx, msg) in messages.iter().enumerate() {
315 if msg.role == Role::Tool
316 && let MessageContent::Blocks(blocks) = &msg.content
317 {
318 for block in blocks {
319 if let ContentBlock::ToolResult { tool_use_id, .. } = block {
320 let has_tool_use = messages.iter().any(|m| {
322 if let MessageContent::Blocks(bs) = &m.content {
323 bs.iter().any(|b| {
324 if let ContentBlock::ToolUse { id, .. } = b {
325 id == tool_use_id
326 } else {
327 false
328 }
329 })
330 } else {
331 false
332 }
333 });
334
335 if !has_tool_use {
336 errors.push(ValidationError::OrphanedToolResult {
337 tool_use_id: tool_use_id.clone(),
338 index: idx,
339 });
340 }
341 }
342 }
343 }
344 }
345
346 for (idx, msg) in messages.iter().enumerate() {
348 if let MessageContent::Blocks(blocks) = &msg.content {
349 for block in blocks {
350 if let ContentBlock::ToolUse { id, .. } = block {
351 let has_tool_result = messages.iter().any(|m| {
353 if m.role == Role::Tool {
354 if let MessageContent::Blocks(bs) = &m.content {
355 bs.iter().any(|b| {
356 if let ContentBlock::ToolResult { tool_use_id, .. } = b {
357 tool_use_id == id
358 } else {
359 false
360 }
361 })
362 } else {
363 false
364 }
365 } else {
366 false
367 }
368 });
369
370 if !has_tool_result {
371 errors.push(ValidationError::OrphanedToolUse {
372 tool_use_id: id.clone(),
373 index: idx,
374 });
375 }
376 }
377 }
378 }
379 }
380
381 for dep in &new_deps.dependencies {
383 if dep.tool_use_idx >= messages.len() {
384 errors.push(ValidationError::OrphanedToolUse {
385 tool_use_id: dep.tool_name.clone(),
386 index: dep.tool_use_idx,
387 });
388 }
389 if dep.tool_result_idx >= messages.len() {
390 errors.push(ValidationError::OrphanedToolResult {
391 tool_use_id: dep.tool_name.clone(),
392 index: dep.tool_result_idx,
393 });
394 }
395 }
396
397 errors
398 }
399
400 pub fn is_valid_compression(messages: &[Message], original_deps: &DependencyGraph) -> bool {
402 Self::validate_compression(messages, original_deps).is_empty()
403 }
404
405 pub async fn execute(
407 &mut self,
408 messages: &[Message],
409 ai_mode: AiCompressionMode,
410 token_usage: u32,
411 context_window: u32,
412 ) -> Result<CompressionOutcome> {
413 if self.circuit_breaker.should_skip() {
415 return Ok(CompressionOutcome {
416 messages: messages.to_vec(),
417 threshold_level: ThresholdLevel::Blocking,
418 percent_left: 0,
419 success: false,
420 error: Some("Circuit breaker tripped - too many consecutive failures".to_string()),
421 circuit_breaker_tripped: true,
422 });
423 }
424
425 if messages.len() <= self.config.min_preserve_messages {
426 let (level, percent) =
427 CompressionConfig::calculate_threshold_level(token_usage, context_window);
428 return Ok(CompressionOutcome {
429 messages: messages.to_vec(),
430 threshold_level: level,
431 percent_left: percent,
432 success: true,
433 error: None,
434 circuit_breaker_tripped: false,
435 });
436 }
437
438 let pre_processed = if Self::should_time_based_clear(messages) {
440 Self::time_based_microcompact(messages)
441 } else {
442 messages.to_vec()
443 };
444
445 let phase = PhaseDetector::detect(&pre_processed);
447 let weights = phase.default_weights();
448 let deps = DependencyBuilder::build(&pre_processed);
449
450 let scored = self
452 .scorer
453 .score_all(&pre_processed, &weights, &deps, ai_mode)
454 .await?;
455
456 let compressed = self
458 .tool_compressor
459 .compress_results(&pre_processed, ai_mode)
460 .await?;
461
462 let target_count = calculate_target_count(pre_processed.len(), &self.config);
464 let selected = self.select_messages(scored, &deps, target_count, &compressed);
465
466 let final_messages = self.ensure_dependency_integrity(selected, &deps, &pre_processed);
468
469 self.circuit_breaker.record_success();
471
472 let post_tokens = estimate_total_tokens(&final_messages);
474 let (level, percent) =
475 CompressionConfig::calculate_threshold_level(post_tokens, context_window);
476
477 Ok(CompressionOutcome {
478 messages: final_messages,
479 threshold_level: level,
480 percent_left: percent,
481 success: true,
482 error: None,
483 circuit_breaker_tripped: false,
484 })
485 }
486
487 pub async fn execute_with_circuit_breaker(
489 &mut self,
490 messages: &[Message],
491 ai_mode: AiCompressionMode,
492 token_usage: u32,
493 context_window: u32,
494 ) -> Result<CompressionOutcome> {
495 let result = self
496 .execute(messages, ai_mode, token_usage, context_window)
497 .await;
498
499 match result {
500 Ok(res) => Ok(res),
501 Err(e) => {
502 let tripped = self.circuit_breaker.record_failure();
504
505 let (level, percent) =
506 CompressionConfig::calculate_threshold_level(token_usage, context_window);
507
508 Ok(CompressionOutcome {
509 messages: messages.to_vec(),
510 threshold_level: level,
511 percent_left: percent,
512 success: false,
513 error: Some(e.to_string()),
514 circuit_breaker_tripped: tripped,
515 })
516 }
517 }
518 }
519
520 pub fn execute_sync(&self, messages: &[Message]) -> Result<Vec<Message>> {
522 compress_messages(messages, CompressionStrategy::BiasBased, &self.config)
524 }
525
526 fn select_messages(
528 &self,
529 scored: Vec<ScoredMessage>,
530 deps: &DependencyGraph,
531 target_count: usize,
532 compressed_messages: &[Message],
533 ) -> Vec<Message> {
534 let mut sorted = scored;
536 sorted.sort_by(|a, b| b.final_score.partial_cmp(&a.final_score).unwrap());
537
538 let mut preserve_indices: std::collections::HashSet<usize> =
540 std::collections::HashSet::new();
541
542 for sm in sorted.iter().take(target_count) {
544 preserve_indices.insert(sm.index);
545
546 for pair_idx in deps.get_pair_indices(sm.index) {
548 preserve_indices.insert(pair_idx);
549 }
550 }
551
552 let selected: Vec<Message> = preserve_indices
554 .iter()
555 .filter_map(|idx| compressed_messages.get(*idx).cloned())
556 .collect();
557
558 selected
559 }
560
561 fn ensure_dependency_integrity(
563 &self,
564 selected: Vec<Message>,
565 _deps: &DependencyGraph,
566 _original: &[Message],
567 ) -> Vec<Message> {
568 selected
571 }
572
573 pub fn score_only(&self, messages: &[Message]) -> Vec<ScoredMessage> {
575 let phase = PhaseDetector::detect(messages);
576 let weights = phase.default_weights();
577 let deps = DependencyBuilder::build(messages);
578
579 let mut scored: Vec<ScoredMessage> = Vec::new();
581 for (idx, msg) in messages.iter().enumerate() {
582 let base_score = super::scorer::score_by_rules(msg, idx, &weights);
583 scored.push(ScoredMessage::new(idx, msg.clone(), base_score));
584 }
585
586 let bonus = weights.dependency_pair_bonus;
588 for dep in &deps.dependencies {
589 if let Some(sm) = scored.get_mut(dep.tool_use_idx) {
590 sm.with_dependency_bonus(bonus);
591 }
592 if let Some(sm) = scored.get_mut(dep.tool_result_idx) {
593 sm.with_dependency_bonus(bonus);
594 }
595 }
596
597 scored
598 }
599}
600
601fn calculate_target_count(total: usize, config: &CompressionConfig) -> usize {
603 let target = (total as f64 * config.target_ratio) as usize;
604 target.max(config.min_preserve_messages)
605}
606
607pub fn compress_with_pipeline(
609 messages: &[Message],
610 config: &CompressionConfig,
611 ai_mode: AiCompressionMode,
612 fast_model: Option<Box<dyn Provider>>,
613) -> Result<Vec<Message>> {
614 let pipeline = match (ai_mode, fast_model) {
616 (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
617 (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
618 CompressionPipeline::new_with_ai(config.clone(), model)
619 }
620 _ => CompressionPipeline::new_rule_only(config.clone()),
621 };
622
623 pipeline.execute_sync(messages)
625}
626
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Builds a plain-text message with the given role.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Builds a `Role::Tool` message holding a single tool result.
    fn tool_result_message(tool_use_id: &str, content: String) -> Message {
        Message {
            role: Role::Tool,
            content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                tool_use_id: tool_use_id.to_string(),
                content,
            }]),
        }
    }

    /// Returns the content of the first block of `msg`, panicking unless that
    /// block is a tool result. Panicking keeps the assertions from being
    /// skipped silently when the message has an unexpected shape.
    fn first_tool_result_content(msg: &Message) -> &str {
        let MessageContent::Blocks(blocks) = &msg.content else {
            panic!("expected block content");
        };
        let Some(ContentBlock::ToolResult { content, .. }) = blocks.first() else {
            panic!("expected a tool result as the first block");
        };
        content.as_str()
    }

    #[test]
    fn test_pipeline_new_rule_only() {
        let pipeline = CompressionPipeline::new_rule_only(CompressionConfig::default());
        let messages = vec![text_message(Role::User, "Test")];
        assert!(pipeline.execute_sync(&messages).is_ok());
    }

    #[test]
    fn test_calculate_target_count() {
        let config = CompressionConfig::default();
        let total = 100;
        let target = calculate_target_count(total, &config);
        assert!(target >= config.min_preserve_messages);
        assert!(target < total);
    }

    #[test]
    fn test_score_only() {
        let pipeline = CompressionPipeline::new_rule_only(CompressionConfig::default());
        let messages = vec![
            text_message(Role::User, "Hello"),
            text_message(Role::Assistant, "Hi"),
        ];

        let scored = pipeline.score_only(&messages);
        assert_eq!(scored.len(), 2);
        assert!(scored[0].final_score > scored[1].final_score);
    }

    #[test]
    fn test_execute_sync_small() {
        let pipeline = CompressionPipeline::new_rule_only(CompressionConfig::default());
        let messages = vec![text_message(Role::User, "Hello")];

        let result = pipeline.execute_sync(&messages).unwrap();
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_time_based_microcompact() {
        let long_content =
            "This is a very long tool result content that should be cleared...".repeat(20);
        let messages = vec![
            tool_result_message("tool_1", long_content),
            tool_result_message("tool_2", "Short content".to_string()),
        ];

        let compacted = CompressionPipeline::time_based_microcompact(&messages);

        assert_eq!(compacted.len(), 2);
        assert_eq!(
            first_tool_result_content(&compacted[0]),
            TIME_BASED_MC_CLEARED_MESSAGE
        );
        assert_eq!(first_tool_result_content(&compacted[1]), "Short content");
    }

    #[test]
    fn test_strip_thinking() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: MessageContent::Blocks(vec![
                ContentBlock::Text {
                    text: "Response".to_string(),
                },
                ContentBlock::Thinking {
                    thinking: "Long thinking process...".to_string(),
                    signature: None,
                },
            ]),
        }];

        let stripped = CompressionPipeline::strip_thinking(&messages);

        let MessageContent::Blocks(blocks) = &stripped[0].content else {
            panic!("expected block content after stripping");
        };
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
    }

    #[test]
    fn test_is_compactable_tool() {
        for known in ["bash", "read", "glob"] {
            assert!(CompressionPipeline::is_compactable_tool(known));
        }
        assert!(!CompressionPipeline::is_compactable_tool("unknown_tool"));
    }

    #[test]
    fn test_should_time_based_clear() {
        // One assistant message followed by fifteen user/tool messages.
        let mut many_messages = vec![text_message(Role::Assistant, "response")];
        many_messages.extend((0..15).map(|i| {
            let role = if i % 2 == 0 { Role::User } else { Role::Tool };
            text_message(role, "content")
        }));
        assert!(CompressionPipeline::should_time_based_clear(&many_messages));

        let few_messages = vec![
            text_message(Role::Assistant, "response"),
            text_message(Role::User, "follow-up"),
        ];
        assert!(!CompressionPipeline::should_time_based_clear(&few_messages));
    }

    #[test]
    fn test_validate_compression_valid() {
        let messages = vec![
            text_message(Role::User, "Request"),
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                    id: "tool_1".to_string(),
                    name: "read".to_string(),
                    input: serde_json::json!({"path": "test.txt"}),
                }]),
            },
            tool_result_message("tool_1", "File content".to_string()),
        ];

        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(errors.is_empty());
    }

    #[test]
    fn test_validate_compression_orphaned_tool_result() {
        let messages = vec![
            text_message(Role::User, "Request"),
            tool_result_message("tool_missing", "Orphaned result".to_string()),
        ];

        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(!errors.is_empty());
        assert!(
            errors
                .iter()
                .any(|e| matches!(e, ValidationError::OrphanedToolResult { .. }))
        );
    }

    #[test]
    fn test_validate_compression_empty() {
        let messages: Vec<Message> = vec![];
        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(!errors.is_empty());
        assert!(
            errors
                .iter()
                .any(|e| matches!(e, ValidationError::MissingFirstMessage))
        );
    }
}