use anyhow::Result;
use crate::providers::{ContentBlock, Message, MessageContent, Provider, Role};
use super::compressor::{compress_messages, estimate_total_tokens};
use super::config::{CompressionConfig, CircuitBreakerState, ThresholdLevel,
TIME_BASED_MC_CLEARED_MESSAGE};
use super::dependency::DependencyBuilder;
use super::phase_detector::PhaseDetector;
use super::scorer::Scorer;
use super::summarizer::Summarizer;
use super::tool_compressor::ToolCompressor;
use super::types::{
AiCompressionMode, CompressionThresholds, DependencyGraph,
ScoredMessage, CompressionStrategy,
};
/// Message-compression pipeline combining scoring, tool-result compression
/// and a circuit breaker that can disable compression after failures.
pub struct CompressionPipeline {
/// Settings read by the pipeline (`min_preserve_messages` here; also passed
/// to `calculate_target_count` and `compress_messages`).
config: CompressionConfig,
/// Scores every message in `execute` via `score_all`.
scorer: Scorer,
/// Compresses tool results in `execute` via `compress_results`.
tool_compressor: ToolCompressor,
/// Consulted with `should_skip` before compressing; updated with
/// `record_success` / `record_failure`.
circuit_breaker: CircuitBreakerState,
}
/// Result of a pipeline run as produced by `execute` and
/// `execute_with_circuit_breaker`.
pub struct CompressionOutcome {
/// The resulting message list; a copy of the input when compression was
/// skipped or failed.
pub messages: Vec<Message>,
/// Threshold level from `CompressionConfig::calculate_threshold_level`, or
/// `ThresholdLevel::Blocking` when the circuit breaker skipped the run.
pub threshold_level: ThresholdLevel,
/// Second value returned by `calculate_threshold_level`; set to 0 when the
/// circuit breaker skipped the run.
pub percent_left: u32,
/// `false` when the run was skipped by the circuit breaker or failed.
pub success: bool,
/// Human-readable failure description when `success` is `false`.
pub error: Option<String>,
/// `true` when the run was skipped by the circuit breaker, or the value
/// returned by `record_failure` after a failed run.
pub circuit_breaker_tripped: bool,
}
/// Problems reported by `CompressionPipeline::validate_compression`.
#[derive(Debug, Clone)]
pub enum ValidationError {
/// A tool result at message `index` has no tool use with the same id.
OrphanedToolResult { tool_use_id: String, index: usize },
/// A tool use at message `index` has no tool result with the same id.
OrphanedToolUse { tool_use_id: String, index: usize },
/// The message list was empty.
MissingFirstMessage,
/// Role-ordering violation. NOTE(review): no code visible in this file
/// constructs this variant.
OrderViolation { expected_role: Role, actual_role: Role, index: usize },
}
impl CompressionPipeline {
/// Builds a pipeline that uses no AI model: a rule-only scorer and a
/// truncate-only tool compressor configured with default thresholds.
pub fn new_rule_only(config: CompressionConfig) -> Self {
    let default_thresholds = CompressionThresholds::default();
    let scorer = Scorer::new_rule_only();
    let tool_compressor = ToolCompressor::new_truncate_only(default_thresholds);
    let circuit_breaker = CircuitBreakerState::new();
    Self {
        config,
        scorer,
        tool_compressor,
        circuit_breaker,
    }
}
/// Builds a pipeline whose scorer and summarizer are both backed by
/// `fast_model`; the tool compressor uses default thresholds.
///
/// The model is cloned once so the summarizer and the scorer each own a
/// handle.
pub fn new_with_ai(
    config: CompressionConfig,
    fast_model: Box<dyn Provider>,
) -> Self {
    let default_thresholds = CompressionThresholds::default();
    let fast_summarizer = Summarizer::new(fast_model.clone());
    let scorer = Scorer::new_with_ai(fast_model);
    let tool_compressor = ToolCompressor::new_with_ai(fast_summarizer, default_thresholds);
    let circuit_breaker = CircuitBreakerState::new();
    Self {
        config,
        scorer,
        tool_compressor,
        circuit_breaker,
    }
}
/// Builds a pipeline whose summarizer receives both `fast_model` and
/// `main_model`, while the scorer is backed by `fast_model` only; the tool
/// compressor uses default thresholds.
pub fn new_with_full_ai(
    config: CompressionConfig,
    fast_model: Box<dyn Provider>,
    main_model: Box<dyn Provider>,
) -> Self {
    let default_thresholds = CompressionThresholds::default();
    let dual_summarizer = Summarizer::new_with_main(fast_model.clone(), main_model);
    let scorer = Scorer::new_with_ai(fast_model);
    let tool_compressor = ToolCompressor::new_with_ai(dual_summarizer, default_thresholds);
    let circuit_breaker = CircuitBreakerState::new();
    Self {
        config,
        scorer,
        tool_compressor,
        circuit_breaker,
    }
}
/// Decides whether compression should run for the given token usage.
///
/// Returns `(false, ThresholdLevel::Blocking)` while the circuit breaker
/// asks to skip. Otherwise returns the computed threshold level together
/// with `true` for every level other than `ThresholdLevel::Normal`.
pub fn should_compress(
    &self,
    token_usage: u32,
    context_window: u32,
) -> (bool, ThresholdLevel) {
    if self.circuit_breaker.should_skip() {
        return (false, ThresholdLevel::Blocking);
    }

    let (current_level, _) =
        CompressionConfig::calculate_threshold_level(token_usage, context_window);
    (current_level != ThresholdLevel::Normal, current_level)
}
/// Returns `true` when more than 10 messages follow the most recent
/// assistant message.
///
/// Returns `false` when the list contains no assistant message at all.
pub fn should_time_based_clear(messages: &[Message]) -> bool {
    // Scanning from the end, the position of the first assistant message is
    // exactly the number of non-assistant messages that trail it.
    messages
        .iter()
        .rev()
        .position(|message| message.role == Role::Assistant)
        .map_or(false, |trailing| trailing > 10)
}
/// Replaces long tool results with `TIME_BASED_MC_CLEARED_MESSAGE`.
///
/// Only messages with `Role::Tool` and block content are rewritten; within
/// them, a `ToolResult` whose content is longer than 500 bytes (and is not
/// already the cleared marker) has its content swapped for the marker while
/// keeping its `tool_use_id`. Everything else is cloned unchanged.
pub fn time_based_microcompact(messages: &[Message]) -> Vec<Message> {
    // Tool results at or below this many bytes are kept verbatim.
    const MAX_KEPT_LEN: usize = 500;

    let mut compacted = Vec::with_capacity(messages.len());
    for message in messages {
        let blocks = match &message.content {
            MessageContent::Blocks(blocks) if message.role == Role::Tool => blocks,
            _ => {
                compacted.push(message.clone());
                continue;
            }
        };

        let rewritten: Vec<ContentBlock> = blocks
            .iter()
            .map(|block| match block {
                ContentBlock::ToolResult { tool_use_id, content }
                    if content.len() > MAX_KEPT_LEN
                        && content != TIME_BASED_MC_CLEARED_MESSAGE =>
                {
                    ContentBlock::ToolResult {
                        tool_use_id: tool_use_id.clone(),
                        content: TIME_BASED_MC_CLEARED_MESSAGE.to_string(),
                    }
                }
                untouched => untouched.clone(),
            })
            .collect();

        compacted.push(Message {
            role: message.role.clone(),
            content: MessageContent::Blocks(rewritten),
        });
    }
    compacted
}
/// Returns a copy of `messages` with every `ContentBlock::Thinking` block
/// removed from block-based messages.
///
/// Messages with non-block content are cloned unchanged. A message whose
/// blocks were all thinking blocks is kept with an empty block list.
pub fn strip_thinking(messages: &[Message]) -> Vec<Message> {
    let mut stripped = Vec::with_capacity(messages.len());
    for message in messages {
        let cleaned = if let MessageContent::Blocks(blocks) = &message.content {
            let mut kept: Vec<ContentBlock> = Vec::new();
            for block in blocks {
                if !matches!(block, ContentBlock::Thinking { .. }) {
                    kept.push(block.clone());
                }
            }
            Message {
                role: message.role.clone(),
                content: MessageContent::Blocks(kept),
            }
        } else {
            message.clone()
        };
        stripped.push(cleaned);
    }
    stripped
}
/// Tool names that `is_compactable_tool` reports as compactable.
const COMPACTABLE_TOOLS: &[&str] = &[
"bash", "read", "glob", "grep", "ls", "edit", "write",
"notebook_edit", "web_fetch", "web_search",
];
/// Returns `true` when `tool_name` exactly matches one of the entries in
/// `COMPACTABLE_TOOLS` (case-sensitive).
pub fn is_compactable_tool(tool_name: &str) -> bool {
    Self::COMPACTABLE_TOOLS
        .iter()
        .any(|&known| known == tool_name)
}
/// Replaces long tool results with `TIME_BASED_MC_CLEARED_MESSAGE`.
///
/// For every `Role::Tool` message with block content, each `ToolResult`
/// longer than 500 bytes that is not already the cleared marker gets its
/// content replaced by the marker; all other messages and blocks are cloned.
///
/// NOTE(review): `_tool_names` is currently ignored, so results of *all*
/// tools are cleared regardless of the names passed in — confirm whether
/// filtering by tool name was intended.
pub fn clear_tool_results(messages: &[Message], _tool_names: &[&str]) -> Vec<Message> {
    /// Returns the cleared replacement for a long tool result, or a clone of
    /// `block` when it does not need clearing.
    fn cleared(block: &ContentBlock) -> ContentBlock {
        if let ContentBlock::ToolResult { tool_use_id, content } = block {
            if content.len() > 500 && content != TIME_BASED_MC_CLEARED_MESSAGE {
                return ContentBlock::ToolResult {
                    tool_use_id: tool_use_id.clone(),
                    content: TIME_BASED_MC_CLEARED_MESSAGE.to_string(),
                };
            }
        }
        block.clone()
    }

    messages
        .iter()
        .map(|message| {
            if message.role == Role::Tool {
                if let MessageContent::Blocks(blocks) = &message.content {
                    return Message {
                        role: message.role.clone(),
                        content: MessageContent::Blocks(blocks.iter().map(cleared).collect()),
                    };
                }
            }
            message.clone()
        })
        .collect()
}
/// Runs both micro-compaction passes: first removes thinking blocks, then
/// clears long tool results in what remains.
pub fn full_microcompact(messages: &[Message]) -> Vec<Message> {
    Self::time_based_microcompact(&Self::strip_thinking(messages))
}
/// Checks that tool uses and tool results in `messages` are still paired.
///
/// Returns every problem found, in this order:
/// 1. `MissingFirstMessage` (alone) when `messages` is empty;
/// 2. one `OrphanedToolResult` per `ToolResult` block inside a `Role::Tool`
///    message whose id matches no `ToolUse` block in any message;
/// 3. one `OrphanedToolUse` per `ToolUse` block (in any message) whose id
///    matches no `ToolResult` block inside a `Role::Tool` message;
/// 4. out-of-range indices in the dependency graph rebuilt from `messages`.
///
/// `_original_deps` is accepted for interface compatibility but not
/// consulted.
pub fn validate_compression(messages: &[Message], _original_deps: &DependencyGraph) -> Vec<ValidationError> {
    let mut errors = Vec::new();
    if messages.is_empty() {
        errors.push(ValidationError::MissingFirstMessage);
        return errors;
    }
    let new_deps = DependencyBuilder::build(messages);

    // Index all ids once so each lookup below is a set probe instead of a
    // rescan of the whole conversation.
    let mut tool_use_ids: std::collections::HashSet<&str> = std::collections::HashSet::new();
    let mut tool_result_ids: std::collections::HashSet<&str> = std::collections::HashSet::new();
    for msg in messages {
        if let MessageContent::Blocks(blocks) = &msg.content {
            for block in blocks {
                match block {
                    ContentBlock::ToolUse { id, .. } => {
                        tool_use_ids.insert(id.as_str());
                    }
                    // Results only count when carried by a tool-role message.
                    ContentBlock::ToolResult { tool_use_id, .. } if msg.role == Role::Tool => {
                        tool_result_ids.insert(tool_use_id.as_str());
                    }
                    _ => {}
                }
            }
        }
    }

    // Pass 1: tool results without a matching tool use.
    for (idx, msg) in messages.iter().enumerate() {
        if msg.role != Role::Tool {
            continue;
        }
        if let MessageContent::Blocks(blocks) = &msg.content {
            for block in blocks {
                if let ContentBlock::ToolResult { tool_use_id, .. } = block {
                    if !tool_use_ids.contains(tool_use_id.as_str()) {
                        errors.push(ValidationError::OrphanedToolResult {
                            tool_use_id: tool_use_id.clone(),
                            index: idx,
                        });
                    }
                }
            }
        }
    }

    // Pass 2: tool uses without a matching tool result.
    for (idx, msg) in messages.iter().enumerate() {
        if let MessageContent::Blocks(blocks) = &msg.content {
            for block in blocks {
                if let ContentBlock::ToolUse { id, .. } = block {
                    if !tool_result_ids.contains(id.as_str()) {
                        errors.push(ValidationError::OrphanedToolUse {
                            tool_use_id: id.clone(),
                            index: idx,
                        });
                    }
                }
            }
        }
    }

    // Pass 3: sanity-check indices in the rebuilt dependency graph.
    // NOTE(review): these errors carry `dep.tool_name` in the `tool_use_id`
    // field — confirm whether the tool-use id was intended instead.
    for dep in &new_deps.dependencies {
        if dep.tool_use_idx >= messages.len() {
            errors.push(ValidationError::OrphanedToolUse {
                tool_use_id: dep.tool_name.clone(),
                index: dep.tool_use_idx,
            });
        }
        if dep.tool_result_idx >= messages.len() {
            errors.push(ValidationError::OrphanedToolResult {
                tool_use_id: dep.tool_name.clone(),
                index: dep.tool_result_idx,
            });
        }
    }
    errors
}
/// Returns `true` when `validate_compression` reports no errors for
/// `messages`.
pub fn is_valid_compression(messages: &[Message], original_deps: &DependencyGraph) -> bool {
    let problems = Self::validate_compression(messages, original_deps);
    problems.is_empty()
}
/// Runs the full compression pipeline over `messages`.
///
/// Steps, in order:
/// 1. If the circuit breaker asks to skip, return the input unchanged with
///    `success: false` and `circuit_breaker_tripped: true`.
/// 2. If there are no more than `config.min_preserve_messages` messages,
///    return the input unchanged with `success: true`.
/// 3. Optionally clear long tool results (see `should_time_based_clear`).
/// 4. Detect the phase, build dependencies, score messages, compress tool
///    results, select the messages to keep and record a success.
///
/// The threshold level of a successful run is computed from the estimated
/// token count of the resulting messages.
///
/// # Errors
///
/// Propagates errors from `Scorer::score_all` and
/// `ToolCompressor::compress_results`. This method does not record a
/// failure on the circuit breaker itself.
pub async fn execute(
    &mut self,
    messages: &[Message],
    ai_mode: AiCompressionMode,
    token_usage: u32,
    context_window: u32,
) -> Result<CompressionOutcome> {
    // Guard: compression is currently disabled by the circuit breaker.
    if self.circuit_breaker.should_skip() {
        return Ok(CompressionOutcome {
            messages: messages.to_vec(),
            threshold_level: ThresholdLevel::Blocking,
            percent_left: 0,
            success: false,
            error: Some("Circuit breaker tripped - too many consecutive failures".to_string()),
            circuit_breaker_tripped: true,
        });
    }

    // Guard: too few messages to be worth compressing.
    if messages.len() <= self.config.min_preserve_messages {
        let (threshold_level, percent_left) =
            CompressionConfig::calculate_threshold_level(token_usage, context_window);
        return Ok(CompressionOutcome {
            messages: messages.to_vec(),
            threshold_level,
            percent_left,
            success: true,
            error: None,
            circuit_breaker_tripped: false,
        });
    }

    let working_set = if Self::should_time_based_clear(messages) {
        Self::time_based_microcompact(messages)
    } else {
        messages.to_vec()
    };

    let detected_phase = PhaseDetector::detect(&working_set);
    let phase_weights = detected_phase.default_weights();
    let dependency_graph = DependencyBuilder::build(&working_set);

    let scored_messages = self
        .scorer
        .score_all(&working_set, &phase_weights, &dependency_graph, ai_mode)
        .await?;
    let tool_compressed = self
        .tool_compressor
        .compress_results(&working_set, ai_mode)
        .await?;

    let keep_count = calculate_target_count(working_set.len(), &self.config);
    let kept = self.select_messages(scored_messages, &dependency_graph, keep_count, &tool_compressed);
    let result_messages = self.ensure_dependency_integrity(kept, &dependency_graph, &working_set);
    self.circuit_breaker.record_success();

    // Report the threshold level for the compressed conversation.
    let tokens_after = estimate_total_tokens(&result_messages);
    let (threshold_level, percent_left) =
        CompressionConfig::calculate_threshold_level(tokens_after, context_window);
    Ok(CompressionOutcome {
        messages: result_messages,
        threshold_level,
        percent_left,
        success: true,
        error: None,
        circuit_breaker_tripped: false,
    })
}
/// Runs `execute` and converts a failure into a non-error outcome.
///
/// On success the outcome from `execute` is returned as is. On failure the
/// circuit breaker records the failure and the returned outcome carries the
/// original messages, the threshold level for the supplied `token_usage`,
/// `success: false`, the error text, and the flag returned by
/// `record_failure`. As written, this method never returns `Err`.
pub async fn execute_with_circuit_breaker(
    &mut self,
    messages: &[Message],
    ai_mode: AiCompressionMode,
    token_usage: u32,
    context_window: u32,
) -> Result<CompressionOutcome> {
    let failure = match self.execute(messages, ai_mode, token_usage, context_window).await {
        Ok(outcome) => return Ok(outcome),
        Err(failure) => failure,
    };

    let circuit_breaker_tripped = self.circuit_breaker.record_failure();
    let (threshold_level, percent_left) =
        CompressionConfig::calculate_threshold_level(token_usage, context_window);
    Ok(CompressionOutcome {
        messages: messages.to_vec(),
        threshold_level,
        percent_left,
        success: false,
        error: Some(failure.to_string()),
        circuit_breaker_tripped,
    })
}
/// Synchronous compression path: delegates to `compress_messages` with
/// `CompressionStrategy::BiasBased` and this pipeline's config.
///
/// The scorer, tool compressor and circuit breaker are not involved here.
pub fn execute_sync(&self, messages: &[Message]) -> Result<Vec<Message>> {
compress_messages(messages, CompressionStrategy::BiasBased, &self.config)
}
/// Picks the messages to keep after scoring.
///
/// Takes the `target_count` highest-scoring entries of `scored`, adds every
/// index that `deps.get_pair_indices` pairs with them, and returns the
/// corresponding entries of `compressed_messages` **in ascending index
/// order**, so the original conversation order is preserved. Indices that
/// are out of range for `compressed_messages` are skipped.
fn select_messages(
    &self,
    scored: Vec<ScoredMessage>,
    deps: &DependencyGraph,
    target_count: usize,
    compressed_messages: &[Message],
) -> Vec<Message> {
    let mut ranked = scored;
    // Highest score first. Scores that cannot be compared (e.g. NaN) are
    // treated as equal instead of panicking on `unwrap`.
    ranked.sort_by(|a, b| {
        b.final_score
            .partial_cmp(&a.final_score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // An ordered set keeps the kept indices sorted; iterating a `HashSet`
    // here would return the messages in arbitrary order.
    let mut preserve_indices: std::collections::BTreeSet<usize> =
        std::collections::BTreeSet::new();
    for sm in ranked.iter().take(target_count) {
        preserve_indices.insert(sm.index);
        for pair_idx in deps.get_pair_indices(sm.index) {
            preserve_indices.insert(pair_idx);
        }
    }

    preserve_indices
        .into_iter()
        .filter_map(|idx| compressed_messages.get(idx).cloned())
        .collect()
}
/// Hook for repairing tool-use/tool-result pairs after selection.
///
/// NOTE(review): currently a pass-through — `selected` is returned
/// unchanged and `_deps` / `_original` are not consulted.
fn ensure_dependency_integrity(
&self,
selected: Vec<Message>,
_deps: &DependencyGraph,
_original: &[Message],
) -> Vec<Message> {
selected
}
/// Scores `messages` with the rule-based scorer only, without compressing.
///
/// Each message gets a `ScoredMessage` built from `score_by_rules` using the
/// weights of the detected phase. Afterwards `with_dependency_bonus` is
/// called with `dependency_pair_bonus` on both ends of every dependency
/// whose index is in range. The result is in message order.
pub fn score_only(&self, messages: &[Message]) -> Vec<ScoredMessage> {
    let detected_phase = PhaseDetector::detect(messages);
    let phase_weights = detected_phase.default_weights();
    let graph = DependencyBuilder::build(messages);

    let mut results: Vec<ScoredMessage> = messages
        .iter()
        .enumerate()
        .map(|(position, message)| {
            let rule_score = super::scorer::score_by_rules(message, position, &phase_weights);
            ScoredMessage::new(position, message.clone(), rule_score)
        })
        .collect();

    let pair_bonus = phase_weights.dependency_pair_bonus;
    for dependency in &graph.dependencies {
        // Tool-use side first, then tool-result side.
        for position in [dependency.tool_use_idx, dependency.tool_result_idx] {
            if let Some(entry) = results.get_mut(position) {
                entry.with_dependency_bonus(pair_bonus);
            }
        }
    }
    results
}
}
/// Number of messages to keep: `total * config.target_ratio` truncated to an
/// integer, but never less than `config.min_preserve_messages`.
fn calculate_target_count(total: usize, config: &CompressionConfig) -> usize {
    let scaled = total as f64 * config.target_ratio;
    std::cmp::max(scaled as usize, config.min_preserve_messages)
}
/// Convenience entry point: builds a pipeline and runs `execute_sync`.
///
/// An AI-backed pipeline is built only when `ai_mode` is `Light` or `Deep`
/// *and* a `fast_model` is supplied; every other combination falls back to
/// the rule-only pipeline.
pub fn compress_with_pipeline(
    messages: &[Message],
    config: &CompressionConfig,
    ai_mode: AiCompressionMode,
    fast_model: Option<Box<dyn Provider>>,
) -> Result<Vec<Message>> {
    let wants_ai = matches!(ai_mode, AiCompressionMode::Light | AiCompressionMode::Deep);
    let pipeline = match fast_model {
        Some(model) if wants_ai => CompressionPipeline::new_with_ai(config.clone(), model),
        _ => CompressionPipeline::new_rule_only(config.clone()),
    };
    pipeline.execute_sync(messages)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Builds a plain-text message with the given role.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Builds a tool-role message holding a single tool result.
    fn tool_result_message(tool_use_id: &str, content: String) -> Message {
        Message {
            role: Role::Tool,
            content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                tool_use_id: tool_use_id.to_string(),
                content,
            }]),
        }
    }

    /// Returns the content of the first block of `message`, panicking when
    /// the message is not block-based or its first block is not a tool
    /// result, so a shape mismatch fails the test instead of skipping the
    /// assertion.
    fn first_tool_result_content(message: &Message) -> &str {
        match &message.content {
            MessageContent::Blocks(blocks) => match &blocks[0] {
                ContentBlock::ToolResult { content, .. } => content.as_str(),
                _ => panic!("expected the first block to be a tool result"),
            },
            _ => panic!("expected block content"),
        }
    }

    #[test]
    fn test_pipeline_new_rule_only() {
        let config = CompressionConfig::default();
        let pipeline = CompressionPipeline::new_rule_only(config);
        let messages = vec![text_message(Role::User, "Test")];
        let result = pipeline.execute_sync(&messages);
        assert!(result.is_ok());
    }

    #[test]
    fn test_calculate_target_count() {
        let config = CompressionConfig::default();
        let total = 100;
        let target = calculate_target_count(total, &config);
        assert!(target >= config.min_preserve_messages);
        assert!(target < total);
    }

    #[test]
    fn test_score_only() {
        let config = CompressionConfig::default();
        let pipeline = CompressionPipeline::new_rule_only(config);
        let messages = vec![
            text_message(Role::User, "Hello"),
            text_message(Role::Assistant, "Hi"),
        ];
        let scored = pipeline.score_only(&messages);
        assert_eq!(scored.len(), 2);
        assert!(scored[0].final_score > scored[1].final_score);
    }

    #[test]
    fn test_execute_sync_small() {
        let config = CompressionConfig::default();
        let pipeline = CompressionPipeline::new_rule_only(config);
        let messages = vec![text_message(Role::User, "Hello")];
        let result = pipeline.execute_sync(&messages).unwrap();
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_time_based_microcompact() {
        let messages = vec![
            tool_result_message(
                "tool_1",
                "This is a very long tool result content that should be cleared...".repeat(20),
            ),
            tool_result_message("tool_2", "Short content".to_string()),
        ];
        let compacted = CompressionPipeline::time_based_microcompact(&messages);
        assert_eq!(compacted.len(), 2);
        // The long result is replaced by the marker; the short one is kept.
        assert_eq!(
            first_tool_result_content(&compacted[0]),
            TIME_BASED_MC_CLEARED_MESSAGE
        );
        assert_eq!(first_tool_result_content(&compacted[1]), "Short content");
    }

    #[test]
    fn test_strip_thinking() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: MessageContent::Blocks(vec![
                ContentBlock::Text { text: "Response".to_string() },
                ContentBlock::Thinking {
                    thinking: "Long thinking process...".to_string(),
                    signature: None,
                },
            ]),
        }];
        let stripped = CompressionPipeline::strip_thinking(&messages);
        assert_eq!(stripped.len(), 1);
        match &stripped[0].content {
            MessageContent::Blocks(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
            }
            _ => panic!("expected block content"),
        }
    }

    #[test]
    fn test_is_compactable_tool() {
        assert!(CompressionPipeline::is_compactable_tool("bash"));
        assert!(CompressionPipeline::is_compactable_tool("read"));
        assert!(CompressionPipeline::is_compactable_tool("glob"));
        assert!(!CompressionPipeline::is_compactable_tool("unknown_tool"));
    }

    #[test]
    fn test_should_time_based_clear() {
        // One assistant message followed by 15 non-assistant messages.
        let mut many_messages: Vec<Message> = vec![text_message(Role::Assistant, "response")];
        for i in 0..15 {
            let role = if i % 2 == 0 { Role::User } else { Role::Tool };
            many_messages.push(text_message(role, "content"));
        }
        assert!(CompressionPipeline::should_time_based_clear(&many_messages));

        let few_messages = vec![
            text_message(Role::Assistant, "response"),
            text_message(Role::User, "follow-up"),
        ];
        assert!(!CompressionPipeline::should_time_based_clear(&few_messages));
    }

    #[test]
    fn test_validate_compression_valid() {
        let messages = vec![
            text_message(Role::User, "Request"),
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                    id: "tool_1".to_string(),
                    name: "read".to_string(),
                    input: serde_json::json!({"path": "test.txt"}),
                }]),
            },
            tool_result_message("tool_1", "File content".to_string()),
        ];
        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(errors.is_empty());
    }

    #[test]
    fn test_validate_compression_orphaned_tool_result() {
        let messages = vec![
            text_message(Role::User, "Request"),
            tool_result_message("tool_missing", "Orphaned result".to_string()),
        ];
        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(!errors.is_empty());
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::OrphanedToolResult { .. })));
    }

    #[test]
    fn test_validate_compression_empty() {
        let messages: Vec<Message> = vec![];
        let deps = DependencyBuilder::build(&messages);
        let errors = CompressionPipeline::validate_compression(&messages, &deps);
        assert!(!errors.is_empty());
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::MissingFirstMessage)));
    }
}