1use anyhow::Result;
8
9use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role};
10
11use super::compressor::{compress_messages, estimate_total_tokens};
12use super::config::{CompressionConfig, CircuitBreakerState, ThresholdLevel,
13 TIME_BASED_MC_CLEARED_MESSAGE};
14use super::dependency::DependencyBuilder;
15use super::phase_detector::PhaseDetector;
16use super::scorer::Scorer;
17use super::summarizer::Summarizer;
18use super::tool_compressor::ToolCompressor;
19use super::types::{
20 AiCompressionMode, CompressionThresholds, DependencyGraph,
21 ScoredMessage, CompressionStrategy,
22};
23
24pub struct CompressionPipeline {
26 config: CompressionConfig,
28 scorer: Scorer,
30 tool_compressor: ToolCompressor,
32 circuit_breaker: CircuitBreakerState,
34}
35
36pub struct CompressionOutcome {
38 pub messages: Vec<Message>,
40 pub threshold_level: ThresholdLevel,
42 pub percent_left: u32,
44 pub success: bool,
46 pub error: Option<String>,
48 pub circuit_breaker_tripped: bool,
50}
51
52#[derive(Debug, Clone)]
54pub enum ValidationError {
55 OrphanedToolResult { tool_use_id: String, index: usize },
57 OrphanedToolUse { tool_use_id: String, index: usize },
59 MissingFirstMessage,
61 OrderViolation { expected_role: Role, actual_role: Role, index: usize },
63}
64
65impl CompressionPipeline {
66 pub fn new_rule_only(config: CompressionConfig) -> Self {
68 let thresholds = CompressionThresholds::default();
69 Self {
70 config,
71 scorer: Scorer::new_rule_only(),
72 tool_compressor: ToolCompressor::new_truncate_only(thresholds),
73 circuit_breaker: CircuitBreakerState::new(),
74 }
75 }
76
77 pub fn new_with_ai(
79 config: CompressionConfig,
80 fast_model: Box<dyn Provider>,
81 ) -> Self {
82 let thresholds = CompressionThresholds::default();
83 let summarizer = Summarizer::new(fast_model.clone());
84
85 Self {
86 config,
87 scorer: Scorer::new_with_ai(fast_model),
88 tool_compressor: ToolCompressor::new_with_ai(summarizer, thresholds),
89 circuit_breaker: CircuitBreakerState::new(),
90 }
91 }
92
93 pub fn new_with_full_ai(
95 config: CompressionConfig,
96 fast_model: Box<dyn Provider>,
97 main_model: Box<dyn Provider>,
98 ) -> Self {
99 let thresholds = CompressionThresholds::default();
100 let summarizer = Summarizer::new_with_main(fast_model.clone(), main_model);
101
102 Self {
103 config,
104 scorer: Scorer::new_with_ai(fast_model),
105 tool_compressor: ToolCompressor::new_with_ai(summarizer, thresholds),
106 circuit_breaker: CircuitBreakerState::new(),
107 }
108 }
109
110 pub fn should_compress(
112 &self,
113 token_usage: u32,
114 context_window: u32,
115 ) -> (bool, ThresholdLevel) {
116 if self.circuit_breaker.should_skip() {
118 return (false, ThresholdLevel::Blocking);
119 }
120
121 let (level, _) = CompressionConfig::calculate_threshold_level(token_usage, context_window);
122
123 let should_compress = level != ThresholdLevel::Normal;
124 (should_compress, level)
125 }
126
127 pub fn should_time_based_clear(messages: &[Message]) -> bool {
130 let last_assistant = messages.iter().rev().find(|m| m.role == Role::Assistant);
131
132 if let Some(_msg) = last_assistant {
133 let messages_since = messages.iter().rev().take_while(|m| m.role != Role::Assistant).count();
136 messages_since > 10
138 } else {
139 false
140 }
141 }
142
143 pub fn time_based_microcompact(messages: &[Message]) -> Vec<Message> {
145 messages.iter().map(|msg| {
146 if msg.role != Role::Tool {
147 return msg.clone();
148 }
149
150 match &msg.content {
152 MessageContent::Blocks(blocks) => {
153 let new_blocks: Vec<ContentBlock> = blocks.iter().map(|b| {
154 if let ContentBlock::ToolResult { tool_use_id, content } = b {
155 if content.len() > 500 && content != TIME_BASED_MC_CLEARED_MESSAGE {
157 ContentBlock::ToolResult {
158 tool_use_id: tool_use_id.clone(),
159 content: TIME_BASED_MC_CLEARED_MESSAGE.to_string(),
160 }
161 } else {
162 b.clone()
163 }
164 } else {
165 b.clone()
166 }
167 }).collect();
168 Message {
169 role: msg.role.clone(),
170 content: MessageContent::Blocks(new_blocks),
171 }
172 }
173 _ => msg.clone(),
174 }
175 }).collect()
176 }
177
178 pub fn strip_thinking(messages: &[Message]) -> Vec<Message> {
181 messages.iter().map(|msg| {
182 match &msg.content {
183 MessageContent::Blocks(blocks) => {
184 let new_blocks: Vec<ContentBlock> = blocks.iter()
185 .filter(|b| {
186 !matches!(b, ContentBlock::Thinking { .. })
188 })
189 .cloned()
190 .collect();
191 Message {
192 role: msg.role.clone(),
193 content: MessageContent::Blocks(new_blocks),
194 }
195 }
196 _ => msg.clone(),
197 }
198 }).collect()
199 }
200
201 const COMPACTABLE_TOOLS: &[&str] = &[
204 "bash", "read", "glob", "grep", "ls", "edit", "write",
205 "notebook_edit", "web_fetch", "web_search",
206 ];
207
208 pub fn is_compactable_tool(tool_name: &str) -> bool {
210 Self::COMPACTABLE_TOOLS.contains(&tool_name)
211 }
212
213 pub fn clear_tool_results(messages: &[Message], _tool_names: &[&str]) -> Vec<Message> {
215 messages.iter().map(|msg| {
216 if msg.role != Role::Tool {
217 return msg.clone();
218 }
219
220 match &msg.content {
221 MessageContent::Blocks(blocks) => {
222 let new_blocks: Vec<ContentBlock> = blocks.iter().map(|b| {
223 if let ContentBlock::ToolResult { tool_use_id, content } = b {
224 if content.len() > 500 && content != TIME_BASED_MC_CLEARED_MESSAGE {
227 ContentBlock::ToolResult {
228 tool_use_id: tool_use_id.clone(),
229 content: TIME_BASED_MC_CLEARED_MESSAGE.to_string(),
230 }
231 } else {
232 b.clone()
233 }
234 } else {
235 b.clone()
236 }
237 }).collect();
238 Message {
239 role: msg.role.clone(),
240 content: MessageContent::Blocks(new_blocks),
241 }
242 }
243 _ => msg.clone(),
244 }
245 }).collect()
246 }
247
248 pub fn full_microcompact(messages: &[Message]) -> Vec<Message> {
250 let no_thinking = Self::strip_thinking(messages);
252 Self::time_based_microcompact(&no_thinking)
254 }
255
256 pub fn validate_compression(messages: &[Message], _original_deps: &DependencyGraph) -> Vec<ValidationError> {
262 let mut errors = Vec::new();
263
264 if messages.is_empty() {
266 errors.push(ValidationError::MissingFirstMessage);
267 return errors;
268 }
269
270 let new_deps = DependencyBuilder::build(messages);
272
273 for (idx, msg) in messages.iter().enumerate() {
275 if msg.role == Role::Tool {
276 if let MessageContent::Blocks(blocks) = &msg.content {
277 for block in blocks {
278 if let ContentBlock::ToolResult { tool_use_id, .. } = block {
279 let has_tool_use = messages.iter().any(|m| {
281 if let MessageContent::Blocks(bs) = &m.content {
282 bs.iter().any(|b| {
283 if let ContentBlock::ToolUse { id, .. } = b {
284 id == tool_use_id
285 } else {
286 false
287 }
288 })
289 } else {
290 false
291 }
292 });
293
294 if !has_tool_use {
295 errors.push(ValidationError::OrphanedToolResult {
296 tool_use_id: tool_use_id.clone(),
297 index: idx,
298 });
299 }
300 }
301 }
302 }
303 }
304 }
305
306 for (idx, msg) in messages.iter().enumerate() {
308 if let MessageContent::Blocks(blocks) = &msg.content {
309 for block in blocks {
310 if let ContentBlock::ToolUse { id, .. } = block {
311 let has_tool_result = messages.iter().any(|m| {
313 if m.role == Role::Tool {
314 if let MessageContent::Blocks(bs) = &m.content {
315 bs.iter().any(|b| {
316 if let ContentBlock::ToolResult { tool_use_id, .. } = b {
317 tool_use_id == id
318 } else {
319 false
320 }
321 })
322 } else {
323 false
324 }
325 } else {
326 false
327 }
328 });
329
330 if !has_tool_result {
331 errors.push(ValidationError::OrphanedToolUse {
332 tool_use_id: id.clone(),
333 index: idx,
334 });
335 }
336 }
337 }
338 }
339 }
340
341 for dep in &new_deps.dependencies {
343 if dep.tool_use_idx >= messages.len() {
344 errors.push(ValidationError::OrphanedToolUse {
345 tool_use_id: dep.tool_name.clone(),
346 index: dep.tool_use_idx,
347 });
348 }
349 if dep.tool_result_idx >= messages.len() {
350 errors.push(ValidationError::OrphanedToolResult {
351 tool_use_id: dep.tool_name.clone(),
352 index: dep.tool_result_idx,
353 });
354 }
355 }
356
357 errors
358 }
359
360 pub fn is_valid_compression(messages: &[Message], original_deps: &DependencyGraph) -> bool {
362 Self::validate_compression(messages, original_deps).is_empty()
363 }
364
365 pub async fn execute(
367 &mut self,
368 messages: &[Message],
369 ai_mode: AiCompressionMode,
370 token_usage: u32,
371 context_window: u32,
372 ) -> Result<CompressionOutcome> {
373 if self.circuit_breaker.should_skip() {
375 return Ok(CompressionOutcome {
376 messages: messages.to_vec(),
377 threshold_level: ThresholdLevel::Blocking,
378 percent_left: 0,
379 success: false,
380 error: Some("Circuit breaker tripped - too many consecutive failures".to_string()),
381 circuit_breaker_tripped: true,
382 });
383 }
384
385 if messages.len() <= self.config.min_preserve_messages {
386 let (level, percent) = CompressionConfig::calculate_threshold_level(token_usage, context_window);
387 return Ok(CompressionOutcome {
388 messages: messages.to_vec(),
389 threshold_level: level,
390 percent_left: percent,
391 success: true,
392 error: None,
393 circuit_breaker_tripped: false,
394 });
395 }
396
397 let pre_processed = if Self::should_time_based_clear(messages) {
399 Self::time_based_microcompact(messages)
400 } else {
401 messages.to_vec()
402 };
403
404 let phase = PhaseDetector::detect(&pre_processed);
406 let weights = phase.default_weights();
407 let deps = DependencyBuilder::build(&pre_processed);
408
409 let scored = self.scorer.score_all(&pre_processed, &weights, &deps, ai_mode).await?;
411
412 let compressed = self.tool_compressor.compress_results(&pre_processed, ai_mode).await?;
414
415 let target_count = calculate_target_count(pre_processed.len(), &self.config);
417 let selected = self.select_messages(scored, &deps, target_count, &compressed);
418
419 let final_messages = self.ensure_dependency_integrity(selected, &deps, &pre_processed);
421
422 self.circuit_breaker.record_success();
424
425 let post_tokens = estimate_total_tokens(&final_messages);
427 let (level, percent) = CompressionConfig::calculate_threshold_level(post_tokens, context_window);
428
429 Ok(CompressionOutcome {
430 messages: final_messages,
431 threshold_level: level,
432 percent_left: percent,
433 success: true,
434 error: None,
435 circuit_breaker_tripped: false,
436 })
437 }
438
439 pub async fn execute_with_circuit_breaker(
441 &mut self,
442 messages: &[Message],
443 ai_mode: AiCompressionMode,
444 token_usage: u32,
445 context_window: u32,
446 ) -> Result<CompressionOutcome> {
447 let result = self.execute(messages, ai_mode, token_usage, context_window).await;
448
449 match result {
450 Ok(res) => Ok(res),
451 Err(e) => {
452 let tripped = self.circuit_breaker.record_failure();
454
455 let (level, percent) = CompressionConfig::calculate_threshold_level(token_usage, context_window);
456
457 Ok(CompressionOutcome {
458 messages: messages.to_vec(),
459 threshold_level: level,
460 percent_left: percent,
461 success: false,
462 error: Some(e.to_string()),
463 circuit_breaker_tripped: tripped,
464 })
465 }
466 }
467 }
468
469 pub fn execute_sync(&self, messages: &[Message]) -> Result<Vec<Message>> {
471 compress_messages(messages, CompressionStrategy::BiasBased, &self.config)
473 }
474
475 fn select_messages(
477 &self,
478 scored: Vec<ScoredMessage>,
479 deps: &DependencyGraph,
480 target_count: usize,
481 compressed_messages: &[Message],
482 ) -> Vec<Message> {
483 let mut sorted = scored;
485 sorted.sort_by(|a, b| b.final_score.partial_cmp(&a.final_score).unwrap());
486
487 let mut preserve_indices: std::collections::HashSet<usize> = std::collections::HashSet::new();
489
490 for sm in sorted.iter().take(target_count) {
492 preserve_indices.insert(sm.index);
493
494 for pair_idx in deps.get_pair_indices(sm.index) {
496 preserve_indices.insert(pair_idx);
497 }
498 }
499
500 let selected: Vec<Message> = preserve_indices
502 .iter()
503 .filter_map(|idx| compressed_messages.get(*idx).cloned())
504 .collect();
505
506 selected
507 }
508
509 fn ensure_dependency_integrity(
511 &self,
512 selected: Vec<Message>,
513 _deps: &DependencyGraph,
514 _original: &[Message],
515 ) -> Vec<Message> {
516 selected
519 }
520
521 pub fn score_only(&self, messages: &[Message]) -> Vec<ScoredMessage> {
523 let phase = PhaseDetector::detect(messages);
524 let weights = phase.default_weights();
525 let deps = DependencyBuilder::build(messages);
526
527 let mut scored: Vec<ScoredMessage> = Vec::new();
529 for (idx, msg) in messages.iter().enumerate() {
530 let base_score = super::scorer::score_by_rules(msg, idx, &weights);
531 scored.push(ScoredMessage::new(idx, msg.clone(), base_score));
532 }
533
534 let bonus = weights.dependency_pair_bonus;
536 for dep in &deps.dependencies {
537 if let Some(sm) = scored.get_mut(dep.tool_use_idx) {
538 sm.with_dependency_bonus(bonus);
539 }
540 if let Some(sm) = scored.get_mut(dep.tool_result_idx) {
541 sm.with_dependency_bonus(bonus);
542 }
543 }
544
545 scored
546 }
547}
548
549fn calculate_target_count(total: usize, config: &CompressionConfig) -> usize {
551 let target = (total as f64 * config.target_ratio) as usize;
552 target.max(config.min_preserve_messages)
553}
554
555pub fn compress_with_pipeline(
557 messages: &[Message],
558 config: &CompressionConfig,
559 ai_mode: AiCompressionMode,
560 fast_model: Option<Box<dyn Provider>>,
561) -> Result<Vec<Message>> {
562 let pipeline = match (ai_mode, fast_model) {
564 (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
565 (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
566 CompressionPipeline::new_with_ai(config.clone(), model)
567 }
568 _ => CompressionPipeline::new_rule_only(config.clone()),
569 };
570
571 pipeline.execute_sync(messages)
573}
574
575#[cfg(test)]
576mod tests {
577 use super::*;
578 use crate::providers::{MessageContent, Role};
579
/// A rule-only pipeline can be built and run synchronously on one message.
#[test]
fn test_pipeline_new_rule_only() {
    let pipeline = CompressionPipeline::new_rule_only(CompressionConfig::default());
    let messages = vec![Message {
        role: Role::User,
        content: MessageContent::Text("Test".to_string()),
    }];

    assert!(pipeline.execute_sync(&messages).is_ok());
}
594
/// With the default config, 100 messages shrink to a target that still
/// honours the minimum.
#[test]
fn test_calculate_target_count() {
    let config = CompressionConfig::default();
    let total = 100;

    let target = calculate_target_count(total, &config);

    assert!(target >= config.min_preserve_messages);
    assert!(target < total);
}
603
/// `score_only` yields one score per message, and the user message outranks
/// the assistant reply.
#[test]
fn test_score_only() {
    let pipeline = CompressionPipeline::new_rule_only(CompressionConfig::default());
    let text = |role: Role, body: &str| Message {
        role,
        content: MessageContent::Text(body.to_string()),
    };
    let messages = vec![text(Role::User, "Hello"), text(Role::Assistant, "Hi")];

    let scored = pipeline.score_only(&messages);

    assert_eq!(scored.len(), 2);
    assert!(scored[0].final_score > scored[1].final_score);
}
624
/// A single-message history comes back with one message.
#[test]
fn test_execute_sync_small() {
    let pipeline = CompressionPipeline::new_rule_only(CompressionConfig::default());
    let messages = vec![Message {
        role: Role::User,
        content: MessageContent::Text("Hello".to_string()),
    }];

    let result = pipeline.execute_sync(&messages).unwrap();

    assert_eq!(result.len(), 1);
}
640
/// Long tool results are replaced by the cleared marker; short ones survive.
///
/// An unexpected content shape fails the test instead of silently skipping
/// the assertions.
#[test]
fn test_time_based_microcompact() {
    /// Extracts the text of the first block, which must be a tool result.
    fn first_tool_result(msg: &Message) -> &str {
        match &msg.content {
            MessageContent::Blocks(blocks) => match blocks.first() {
                Some(ContentBlock::ToolResult { content, .. }) => content.as_str(),
                _ => panic!("expected the first block to be a tool result"),
            },
            _ => panic!("expected block-structured content"),
        }
    }

    let messages = vec![
        Message {
            role: Role::Tool,
            content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                tool_use_id: "tool_1".to_string(),
                content: "This is a very long tool result content that should be cleared..."
                    .repeat(20),
            }]),
        },
        Message {
            role: Role::Tool,
            content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                tool_use_id: "tool_2".to_string(),
                content: "Short content".to_string(),
            }]),
        },
    ];

    let compacted = CompressionPipeline::time_based_microcompact(&messages);

    assert_eq!(compacted.len(), 2);
    assert_eq!(first_tool_result(&compacted[0]), TIME_BASED_MC_CLEARED_MESSAGE);
    assert_eq!(first_tool_result(&compacted[1]), "Short content");
}
680
/// Thinking blocks are removed while text blocks are kept.
///
/// An unexpected content shape fails the test instead of silently skipping
/// the assertions.
#[test]
fn test_strip_thinking() {
    let messages = vec![Message {
        role: Role::Assistant,
        content: MessageContent::Blocks(vec![
            ContentBlock::Text { text: "Response".to_string() },
            ContentBlock::Thinking {
                thinking: "Long thinking process...".to_string(),
                signature: None,
            },
        ]),
    }];

    let stripped = CompressionPipeline::strip_thinking(&messages);
    assert_eq!(stripped.len(), 1);

    match &stripped[0].content {
        MessageContent::Blocks(blocks) => {
            assert_eq!(blocks.len(), 1);
            assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
        }
        _ => panic!("expected block-structured content after strip_thinking"),
    }
}
701
/// Known tool names are compactable; unknown ones are not.
#[test]
fn test_is_compactable_tool() {
    for known in &["bash", "read", "glob"] {
        assert!(CompressionPipeline::is_compactable_tool(known));
    }
    assert!(!CompressionPipeline::is_compactable_tool("unknown_tool"));
}
709
/// Fifteen messages after the last assistant reply trigger clearing; a
/// single follow-up does not.
#[test]
fn test_should_time_based_clear() {
    let assistant_reply = || Message {
        role: Role::Assistant,
        content: MessageContent::Text("response".to_string()),
    };

    let mut many_messages: Vec<Message> = vec![assistant_reply()];
    many_messages.extend((0..15).map(|i| Message {
        role: if i % 2 == 0 { Role::User } else { Role::Tool },
        content: MessageContent::Text("content".to_string()),
    }));
    assert!(CompressionPipeline::should_time_based_clear(&many_messages));

    let few_messages = vec![
        assistant_reply(),
        Message {
            role: Role::User,
            content: MessageContent::Text("follow-up".to_string()),
        },
    ];
    assert!(!CompressionPipeline::should_time_based_clear(&few_messages));
}
743
/// A request followed by a matched tool use / tool result pair validates
/// cleanly.
#[test]
fn test_validate_compression_valid() {
    let request = Message {
        role: Role::User,
        content: MessageContent::Text("Request".to_string()),
    };
    let tool_call = Message {
        role: Role::Assistant,
        content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
            id: "tool_1".to_string(),
            name: "read".to_string(),
            input: serde_json::json!({"path": "test.txt"}),
        }]),
    };
    let tool_reply = Message {
        role: Role::Tool,
        content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
            tool_use_id: "tool_1".to_string(),
            content: "File content".to_string(),
        }]),
    };
    let messages = vec![request, tool_call, tool_reply];

    let deps = DependencyBuilder::build(&messages);
    let errors = CompressionPipeline::validate_compression(&messages, &deps);

    assert!(errors.is_empty());
}
776
/// A tool result without a matching tool use is reported as orphaned.
#[test]
fn test_validate_compression_orphaned_tool_result() {
    let request = Message {
        role: Role::User,
        content: MessageContent::Text("Request".to_string()),
    };
    let dangling_result = Message {
        role: Role::Tool,
        content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
            tool_use_id: "tool_missing".to_string(),
            content: "Orphaned result".to_string(),
        }]),
    };
    let messages = vec![request, dangling_result];

    let deps = DependencyBuilder::build(&messages);
    let errors = CompressionPipeline::validate_compression(&messages, &deps);

    assert!(!errors.is_empty());
    let reports_orphan = errors
        .iter()
        .any(|e| matches!(e, ValidationError::OrphanedToolResult { .. }));
    assert!(reports_orphan);
}
800
/// An empty history is reported as `MissingFirstMessage`.
#[test]
fn test_validate_compression_empty() {
    let messages: Vec<Message> = Vec::new();
    let deps = DependencyBuilder::build(&messages);

    let errors = CompressionPipeline::validate_compression(&messages, &deps);

    assert!(!errors.is_empty());
    let reports_missing = errors
        .iter()
        .any(|e| matches!(e, ValidationError::MissingFirstMessage));
    assert!(reports_missing);
}
809}