1use anyhow::Result;
8
9use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role};
10use super::hardcode_config::HardcodeConfig;
11
12use super::compressor::{compress_messages, estimate_total_tokens};
13use super::config::{
14 CircuitBreakerState, CompressionConfig, TIME_BASED_MC_CLEARED_MESSAGE, ThresholdLevel,
15};
16use super::dependency::DependencyBuilder;
17use super::phase_detector::PhaseDetector;
18use super::scorer::Scorer;
19use super::summarizer::Summarizer;
20use super::tool_compressor::ToolCompressor;
21use super::types::{
22 AiCompressionMode, CompressionStrategy, CompressionThresholds, DependencyGraph, ScoredMessage,
23};
24
/// Multi-stage compressor for a conversation history.
///
/// Bundles the compression settings, the scoring and tool-result compression
/// stages, and a circuit breaker that makes `execute` skip work after repeated
/// failures.
pub struct CompressionPipeline {
    /// Compression settings; `execute` reads `min_preserve_messages` and
    /// `target_ratio` from it.
    config: CompressionConfig,
    /// Ranks messages; rule-only or model-assisted depending on the constructor used.
    scorer: Scorer,
    /// Shrinks tool-result content (truncation-only or summarizer-backed,
    /// depending on the constructor used).
    tool_compressor: ToolCompressor,
    /// Failure tracker; while `should_skip()` is true, compression is bypassed.
    circuit_breaker: CircuitBreakerState,
}
36
/// Result of one run of the compression pipeline.
pub struct CompressionOutcome {
    /// The resulting history; an unchanged copy of the input when compression
    /// was skipped, unnecessary, or failed.
    pub messages: Vec<Message>,
    /// Context-usage level as computed by
    /// `CompressionConfig::calculate_threshold_level` (`Blocking` when the
    /// circuit breaker short-circuited the run).
    pub threshold_level: ThresholdLevel,
    /// Second value returned by `calculate_threshold_level`; presumably the
    /// percentage of the context window still free — TODO confirm.
    pub percent_left: u32,
    /// `true` when the run completed (including the "nothing to do" case).
    pub success: bool,
    /// Human-readable failure reason when `success` is `false`.
    pub error: Option<String>,
    /// `true` when the circuit breaker is (or just became) tripped.
    pub circuit_breaker_tripped: bool,
}
52
/// A structural problem detected in a (compressed) message list by
/// `CompressionPipeline::validate_compression`.
#[derive(Debug, Clone)]
pub enum ValidationError {
    /// A `ToolResult` block whose `tool_use_id` matches no `ToolUse` block;
    /// `index` is the position of the message holding the result.
    OrphanedToolResult { tool_use_id: String, index: usize },
    /// A `ToolUse` block with no matching `ToolResult` in any tool message;
    /// `index` is the position of the message holding the use.
    OrphanedToolUse { tool_use_id: String, index: usize },
    /// The message list is empty.
    MissingFirstMessage,
    /// A message at `index` has `actual_role` where `expected_role` was required.
    /// NOTE(review): no code in this module currently produces this variant.
    OrderViolation {
        expected_role: Role,
        actual_role: Role,
        index: usize,
    },
}
69
70impl CompressionPipeline {
71 pub fn new_rule_only(config: CompressionConfig) -> Self {
73 let thresholds = CompressionThresholds::default();
74 Self {
75 config,
76 scorer: Scorer::new_rule_only(),
77 tool_compressor: ToolCompressor::new_truncate_only(thresholds),
78 circuit_breaker: CircuitBreakerState::new(),
79 }
80 }
81
82 pub fn new_with_ai(config: CompressionConfig, fast_model: Box<dyn Provider>) -> Self {
84 let thresholds = CompressionThresholds::default();
85 let summarizer = Summarizer::new(fast_model.clone());
86
87 Self {
88 config,
89 scorer: Scorer::new_with_ai(fast_model),
90 tool_compressor: ToolCompressor::new_with_ai(summarizer, thresholds),
91 circuit_breaker: CircuitBreakerState::new(),
92 }
93 }
94
95 pub fn new_with_full_ai(
97 config: CompressionConfig,
98 fast_model: Box<dyn Provider>,
99 main_model: Box<dyn Provider>,
100 ) -> Self {
101 let thresholds = CompressionThresholds::default();
102 let summarizer = Summarizer::new_with_main(fast_model.clone(), main_model);
103
104 Self {
105 config,
106 scorer: Scorer::new_with_ai(fast_model),
107 tool_compressor: ToolCompressor::new_with_ai(summarizer, thresholds),
108 circuit_breaker: CircuitBreakerState::new(),
109 }
110 }
111
112 pub fn should_compress(&self, token_usage: u32, context_window: u32) -> (bool, ThresholdLevel) {
114 if self.circuit_breaker.should_skip() {
116 return (false, ThresholdLevel::Blocking);
117 }
118
119 let (level, _) = CompressionConfig::calculate_threshold_level(token_usage, context_window);
120
121 let should_compress = level != ThresholdLevel::Normal;
122 (should_compress, level)
123 }
124
125 pub fn should_time_based_clear(messages: &[Message]) -> bool {
128 let last_assistant = messages.iter().rev().find(|m| m.role == Role::Assistant);
129
130 if let Some(_msg) = last_assistant {
131 let messages_since = messages
134 .iter()
135 .rev()
136 .take_while(|m| m.role != Role::Assistant)
137 .count();
138 messages_since > 10
140 } else {
141 false
142 }
143 }
144
145 pub fn time_based_microcompact(messages: &[Message]) -> Vec<Message> {
147 let config = HardcodeConfig::default();
148 messages
149 .iter()
150 .map(|msg| {
151 if msg.role != Role::Tool {
152 return msg.clone();
153 }
154
155 match &msg.content {
157 MessageContent::Blocks(blocks) => {
158 let new_blocks: Vec<ContentBlock> = blocks
159 .iter()
160 .map(|b| {
161 if let ContentBlock::ToolResult {
162 tool_use_id,
163 content,
164 } = b
165 {
166 if content.len() > config.preserve_content_threshold
168 && content != TIME_BASED_MC_CLEARED_MESSAGE
169 {
170 ContentBlock::ToolResult {
171 tool_use_id: tool_use_id.clone(),
172 content: TIME_BASED_MC_CLEARED_MESSAGE.to_string(),
173 }
174 } else {
175 b.clone()
176 }
177 } else {
178 b.clone()
179 }
180 })
181 .collect();
182 Message {
183 role: msg.role.clone(),
184 content: MessageContent::Blocks(new_blocks),
185 }
186 }
187 _ => msg.clone(),
188 }
189 })
190 .collect()
191 }
192
193 pub fn strip_thinking(messages: &[Message]) -> Vec<Message> {
196 messages
197 .iter()
198 .map(|msg| {
199 match &msg.content {
200 MessageContent::Blocks(blocks) => {
201 let new_blocks: Vec<ContentBlock> = blocks
202 .iter()
203 .filter(|b| {
204 !matches!(b, ContentBlock::Thinking { .. })
206 })
207 .cloned()
208 .collect();
209 Message {
210 role: msg.role.clone(),
211 content: MessageContent::Blocks(new_blocks),
212 }
213 }
214 _ => msg.clone(),
215 }
216 })
217 .collect()
218 }
219
220 const COMPACTABLE_TOOLS: &[&str] = &[
223 "bash",
224 "read",
225 "glob",
226 "grep",
227 "ls",
228 "edit",
229 "write",
230 "notebook_edit",
231 "web_fetch",
232 "web_search",
233 ];
234
235 pub fn is_compactable_tool(tool_name: &str) -> bool {
237 Self::COMPACTABLE_TOOLS.contains(&tool_name)
238 }
239
240 pub fn clear_tool_results(messages: &[Message], _tool_names: &[&str]) -> Vec<Message> {
242 let config = HardcodeConfig::default();
243 messages
244 .iter()
245 .map(|msg| {
246 if msg.role != Role::Tool {
247 return msg.clone();
248 }
249
250 match &msg.content {
251 MessageContent::Blocks(blocks) => {
252 let new_blocks: Vec<ContentBlock> = blocks
253 .iter()
254 .map(|b| {
255 if let ContentBlock::ToolResult {
256 tool_use_id,
257 content,
258 } = b
259 {
260 if content.len() > config.preserve_content_threshold
263 && content != TIME_BASED_MC_CLEARED_MESSAGE
264 {
265 ContentBlock::ToolResult {
266 tool_use_id: tool_use_id.clone(),
267 content: TIME_BASED_MC_CLEARED_MESSAGE.to_string(),
268 }
269 } else {
270 b.clone()
271 }
272 } else {
273 b.clone()
274 }
275 })
276 .collect();
277 Message {
278 role: msg.role.clone(),
279 content: MessageContent::Blocks(new_blocks),
280 }
281 }
282 _ => msg.clone(),
283 }
284 })
285 .collect()
286 }
287
288 pub fn full_microcompact(messages: &[Message]) -> Vec<Message> {
290 let no_thinking = Self::strip_thinking(messages);
292 Self::time_based_microcompact(&no_thinking)
294 }
295
296 pub fn validate_compression(
302 messages: &[Message],
303 _original_deps: &DependencyGraph,
304 ) -> Vec<ValidationError> {
305 let mut errors = Vec::new();
306
307 if messages.is_empty() {
309 errors.push(ValidationError::MissingFirstMessage);
310 return errors;
311 }
312
313 let new_deps = DependencyBuilder::build(messages);
315
316 for (idx, msg) in messages.iter().enumerate() {
318 if msg.role == Role::Tool
319 && let MessageContent::Blocks(blocks) = &msg.content
320 {
321 for block in blocks {
322 if let ContentBlock::ToolResult { tool_use_id, .. } = block {
323 let has_tool_use = messages.iter().any(|m| {
325 if let MessageContent::Blocks(bs) = &m.content {
326 bs.iter().any(|b| {
327 if let ContentBlock::ToolUse { id, .. } = b {
328 id == tool_use_id
329 } else {
330 false
331 }
332 })
333 } else {
334 false
335 }
336 });
337
338 if !has_tool_use {
339 errors.push(ValidationError::OrphanedToolResult {
340 tool_use_id: tool_use_id.clone(),
341 index: idx,
342 });
343 }
344 }
345 }
346 }
347 }
348
349 for (idx, msg) in messages.iter().enumerate() {
351 if let MessageContent::Blocks(blocks) = &msg.content {
352 for block in blocks {
353 if let ContentBlock::ToolUse { id, .. } = block {
354 let has_tool_result = messages.iter().any(|m| {
356 if m.role == Role::Tool {
357 if let MessageContent::Blocks(bs) = &m.content {
358 bs.iter().any(|b| {
359 if let ContentBlock::ToolResult { tool_use_id, .. } = b {
360 tool_use_id == id
361 } else {
362 false
363 }
364 })
365 } else {
366 false
367 }
368 } else {
369 false
370 }
371 });
372
373 if !has_tool_result {
374 errors.push(ValidationError::OrphanedToolUse {
375 tool_use_id: id.clone(),
376 index: idx,
377 });
378 }
379 }
380 }
381 }
382 }
383
384 for dep in &new_deps.dependencies {
386 if dep.tool_use_idx >= messages.len() {
387 errors.push(ValidationError::OrphanedToolUse {
388 tool_use_id: dep.tool_name.clone(),
389 index: dep.tool_use_idx,
390 });
391 }
392 if dep.tool_result_idx >= messages.len() {
393 errors.push(ValidationError::OrphanedToolResult {
394 tool_use_id: dep.tool_name.clone(),
395 index: dep.tool_result_idx,
396 });
397 }
398 }
399
400 errors
401 }
402
403 pub fn is_valid_compression(messages: &[Message], original_deps: &DependencyGraph) -> bool {
405 Self::validate_compression(messages, original_deps).is_empty()
406 }
407
408 pub async fn execute(
410 &mut self,
411 messages: &[Message],
412 ai_mode: AiCompressionMode,
413 token_usage: u32,
414 context_window: u32,
415 ) -> Result<CompressionOutcome> {
416 if self.circuit_breaker.should_skip() {
418 return Ok(CompressionOutcome {
419 messages: messages.to_vec(),
420 threshold_level: ThresholdLevel::Blocking,
421 percent_left: 0,
422 success: false,
423 error: Some("Circuit breaker tripped - too many consecutive failures".to_string()),
424 circuit_breaker_tripped: true,
425 });
426 }
427
428 if messages.len() <= self.config.min_preserve_messages {
429 let (level, percent) =
430 CompressionConfig::calculate_threshold_level(token_usage, context_window);
431 return Ok(CompressionOutcome {
432 messages: messages.to_vec(),
433 threshold_level: level,
434 percent_left: percent,
435 success: true,
436 error: None,
437 circuit_breaker_tripped: false,
438 });
439 }
440
441 let pre_processed = if Self::should_time_based_clear(messages) {
443 Self::time_based_microcompact(messages)
444 } else {
445 messages.to_vec()
446 };
447
448 let phase = PhaseDetector::detect(&pre_processed);
450 let weights = phase.default_weights();
451 let deps = DependencyBuilder::build(&pre_processed);
452
453 let scored = self
455 .scorer
456 .score_all(&pre_processed, &weights, &deps, ai_mode)
457 .await?;
458
459 let compressed = self
461 .tool_compressor
462 .compress_results(&pre_processed, ai_mode)
463 .await?;
464
465 let target_count = calculate_target_count(pre_processed.len(), &self.config);
467 let selected = self.select_messages(scored, &deps, target_count, &compressed);
468
469 let final_messages = self.ensure_dependency_integrity(selected, &deps, &pre_processed);
471
472 self.circuit_breaker.record_success();
474
475 let post_tokens = estimate_total_tokens(&final_messages);
477 let (level, percent) =
478 CompressionConfig::calculate_threshold_level(post_tokens, context_window);
479
480 Ok(CompressionOutcome {
481 messages: final_messages,
482 threshold_level: level,
483 percent_left: percent,
484 success: true,
485 error: None,
486 circuit_breaker_tripped: false,
487 })
488 }
489
490 pub async fn execute_with_circuit_breaker(
492 &mut self,
493 messages: &[Message],
494 ai_mode: AiCompressionMode,
495 token_usage: u32,
496 context_window: u32,
497 ) -> Result<CompressionOutcome> {
498 let result = self
499 .execute(messages, ai_mode, token_usage, context_window)
500 .await;
501
502 match result {
503 Ok(res) => Ok(res),
504 Err(e) => {
505 let tripped = self.circuit_breaker.record_failure();
507
508 let (level, percent) =
509 CompressionConfig::calculate_threshold_level(token_usage, context_window);
510
511 Ok(CompressionOutcome {
512 messages: messages.to_vec(),
513 threshold_level: level,
514 percent_left: percent,
515 success: false,
516 error: Some(e.to_string()),
517 circuit_breaker_tripped: tripped,
518 })
519 }
520 }
521 }
522
523 pub fn execute_sync(&self, messages: &[Message]) -> Result<Vec<Message>> {
525 compress_messages(messages, CompressionStrategy::BiasBased, &self.config)
527 }
528
529 fn select_messages(
531 &self,
532 scored: Vec<ScoredMessage>,
533 deps: &DependencyGraph,
534 target_count: usize,
535 compressed_messages: &[Message],
536 ) -> Vec<Message> {
537 let mut sorted = scored;
539 sorted.sort_by(|a, b| b.final_score.partial_cmp(&a.final_score).unwrap());
540
541 let mut preserve_indices: std::collections::HashSet<usize> =
543 std::collections::HashSet::new();
544
545 for sm in sorted.iter().take(target_count) {
547 preserve_indices.insert(sm.index);
548
549 for pair_idx in deps.get_pair_indices(sm.index) {
551 preserve_indices.insert(pair_idx);
552 }
553 }
554
555 let selected: Vec<Message> = preserve_indices
557 .iter()
558 .filter_map(|idx| compressed_messages.get(*idx).cloned())
559 .collect();
560
561 selected
562 }
563
564 fn ensure_dependency_integrity(
566 &self,
567 selected: Vec<Message>,
568 _deps: &DependencyGraph,
569 _original: &[Message],
570 ) -> Vec<Message> {
571 selected
574 }
575
576 pub fn score_only(&self, messages: &[Message]) -> Vec<ScoredMessage> {
578 let phase = PhaseDetector::detect(messages);
579 let weights = phase.default_weights();
580 let deps = DependencyBuilder::build(messages);
581
582 let mut scored: Vec<ScoredMessage> = Vec::new();
584 for (idx, msg) in messages.iter().enumerate() {
585 let base_score = super::scorer::score_by_rules(msg, idx, &weights);
586 scored.push(ScoredMessage::new(idx, msg.clone(), base_score));
587 }
588
589 let bonus = weights.dependency_pair_bonus;
591 for dep in &deps.dependencies {
592 if let Some(sm) = scored.get_mut(dep.tool_use_idx) {
593 sm.with_dependency_bonus(bonus);
594 }
595 if let Some(sm) = scored.get_mut(dep.tool_result_idx) {
596 sm.with_dependency_bonus(bonus);
597 }
598 }
599
600 scored
601 }
602}
603
604fn calculate_target_count(total: usize, config: &CompressionConfig) -> usize {
606 let target = (total as f64 * config.target_ratio) as usize;
607 target.max(config.min_preserve_messages)
608}
609
610pub fn compress_with_pipeline(
612 messages: &[Message],
613 config: &CompressionConfig,
614 ai_mode: AiCompressionMode,
615 fast_model: Option<Box<dyn Provider>>,
616) -> Result<Vec<Message>> {
617 let pipeline = match (ai_mode, fast_model) {
619 (AiCompressionMode::None, _) => CompressionPipeline::new_rule_only(config.clone()),
620 (AiCompressionMode::Light | AiCompressionMode::Deep, Some(model)) => {
621 CompressionPipeline::new_with_ai(config.clone(), model)
622 }
623 _ => CompressionPipeline::new_rule_only(config.clone()),
624 };
625
626 pipeline.execute_sync(messages)
628}
629
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Builds a message whose content is plain text.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Builds a tool message holding a single `ToolResult` block.
    fn tool_result_message(tool_use_id: &str, content: String) -> Message {
        Message {
            role: Role::Tool,
            content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                tool_use_id: tool_use_id.to_string(),
                content,
            }]),
        }
    }

    /// Builds an assistant message holding a single `ToolUse` block.
    fn tool_use_message(id: &str, name: &str) -> Message {
        Message {
            role: Role::Assistant,
            content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                id: id.to_string(),
                name: name.to_string(),
                input: serde_json::json!({"path": "test.txt"}),
            }]),
        }
    }

    /// Long tool output that exceeds the preserve threshold used by microcompaction.
    fn long_content() -> String {
        "This is a very long tool result content that should be cleared...".repeat(20)
    }

    /// Returns the blocks of `msg`, failing the test if it holds non-block content.
    fn blocks(msg: &Message) -> &[ContentBlock] {
        match &msg.content {
            MessageContent::Blocks(blocks) => blocks.as_slice(),
            _ => panic!("expected block content"),
        }
    }

    /// Returns the content of the first block, which must be a `ToolResult`.
    fn first_tool_result(msg: &Message) -> &str {
        match blocks(msg).first() {
            Some(ContentBlock::ToolResult { content, .. }) => content.as_str(),
            _ => panic!("expected a leading ToolResult block"),
        }
    }

    #[test]
    fn test_pipeline_new_rule_only() {
        let pipeline = CompressionPipeline::new_rule_only(CompressionConfig::default());
        let messages = vec![text_message(Role::User, "Test")];
        assert!(pipeline.execute_sync(&messages).is_ok());
    }

    #[test]
    fn test_calculate_target_count() {
        let config = CompressionConfig::default();
        let total = 100;
        let target = calculate_target_count(total, &config);
        assert!(target >= config.min_preserve_messages);
        assert!(target < total);
    }

    #[test]
    fn test_score_only() {
        let pipeline = CompressionPipeline::new_rule_only(CompressionConfig::default());
        let messages = vec![
            text_message(Role::User, "Hello"),
            text_message(Role::Assistant, "Hi"),
        ];

        let scored = pipeline.score_only(&messages);
        assert_eq!(scored.len(), 2);
        assert!(scored[0].final_score > scored[1].final_score);
    }

    #[test]
    fn test_execute_sync_small() {
        let pipeline = CompressionPipeline::new_rule_only(CompressionConfig::default());
        let messages = vec![text_message(Role::User, "Hello")];

        let result = pipeline.execute_sync(&messages).unwrap();
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_time_based_microcompact() {
        let messages = vec![
            tool_result_message("tool_1", long_content()),
            tool_result_message("tool_2", "Short content".to_string()),
        ];

        let compacted = CompressionPipeline::time_based_microcompact(&messages);

        assert_eq!(compacted.len(), 2);
        assert_eq!(first_tool_result(&compacted[0]), TIME_BASED_MC_CLEARED_MESSAGE);
        assert_eq!(first_tool_result(&compacted[1]), "Short content");
    }

    #[test]
    fn test_clear_tool_results_without_filter_clears_large_results() {
        let messages = vec![
            tool_result_message("tool_1", long_content()),
            tool_result_message("tool_2", "Short content".to_string()),
        ];

        let cleared = CompressionPipeline::clear_tool_results(&messages, &[]);

        assert_eq!(first_tool_result(&cleared[0]), TIME_BASED_MC_CLEARED_MESSAGE);
        assert_eq!(first_tool_result(&cleared[1]), "Short content");
    }

    #[test]
    fn test_strip_thinking() {
        let messages = vec![
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![
                    ContentBlock::Text {
                        text: "Response".to_string(),
                    },
                    ContentBlock::Thinking {
                        thinking: "Long thinking process...".to_string(),
                        signature: None,
                    },
                ]),
            },
            text_message(Role::User, "plain"),
        ];

        let stripped = CompressionPipeline::strip_thinking(&messages);

        assert_eq!(stripped.len(), 2);
        let remaining = blocks(&stripped[0]);
        assert_eq!(remaining.len(), 1);
        assert!(matches!(&remaining[0], ContentBlock::Text { .. }));
        assert!(matches!(&stripped[1].content, MessageContent::Text(t) if t == "plain"));
    }

    #[test]
    fn test_is_compactable_tool() {
        assert!(CompressionPipeline::is_compactable_tool("bash"));
        assert!(CompressionPipeline::is_compactable_tool("read"));
        assert!(CompressionPipeline::is_compactable_tool("glob"));
        assert!(!CompressionPipeline::is_compactable_tool("unknown_tool"));
    }

    #[test]
    fn test_should_time_based_clear() {
        let mut many_messages = vec![text_message(Role::Assistant, "response")];
        for i in 0..15 {
            let role = if i % 2 == 0 { Role::User } else { Role::Tool };
            many_messages.push(text_message(role, "content"));
        }
        assert!(CompressionPipeline::should_time_based_clear(&many_messages));

        let few_messages = vec![
            text_message(Role::Assistant, "response"),
            text_message(Role::User, "follow-up"),
        ];
        assert!(!CompressionPipeline::should_time_based_clear(&few_messages));

        // Without any assistant message the heuristic never fires.
        let no_assistant: Vec<Message> =
            (0..15).map(|_| text_message(Role::User, "content")).collect();
        assert!(!CompressionPipeline::should_time_based_clear(&no_assistant));
    }

    #[test]
    fn test_validate_compression_valid() {
        let messages = vec![
            text_message(Role::User, "Request"),
            tool_use_message("tool_1", "read"),
            tool_result_message("tool_1", "File content".to_string()),
        ];

        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(errors.is_empty());
    }

    #[test]
    fn test_validate_compression_orphaned_tool_result() {
        let messages = vec![
            text_message(Role::User, "Request"),
            tool_result_message("tool_missing", "Orphaned result".to_string()),
        ];

        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(!errors.is_empty());
        assert!(
            errors
                .iter()
                .any(|e| matches!(e, ValidationError::OrphanedToolResult { .. }))
        );
    }

    #[test]
    fn test_validate_compression_orphaned_tool_use() {
        let messages = vec![
            text_message(Role::User, "Request"),
            tool_use_message("tool_1", "read"),
        ];

        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(errors.iter().any(|e| matches!(
            e,
            ValidationError::OrphanedToolUse { tool_use_id, index: 1 } if tool_use_id == "tool_1"
        )));
    }

    #[test]
    fn test_validate_compression_empty() {
        let messages: Vec<Message> = vec![];
        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(!errors.is_empty());
        assert!(
            errors
                .iter()
                .any(|e| matches!(e, ValidationError::MissingFirstMessage))
        );
    }
}