use crate::compress::{
CompressionCache, CompressionConfig, CacheConfig, PriorityScorer,
SemanticCompressor, SemanticStrategy, estimate_tokens,
FocusTracker, ConversationFocus,
};
use crate::compress::hardcode_config::HardcodeConfig;
use crate::providers::{Message, MessageContent, Role};
use anyhow::Result;
/// Conversation compressor that combines per-message priority scoring,
/// conversation-focus tracking and a cache of compressed messages.
///
/// Build one with `OptimizedCompressor::new` and call `compress` on the
/// message history before sending it to a provider.
pub struct OptimizedCompressor {
// Compression settings; `threshold` and `target_ratio` are each multiplied by
// the context limit to decide when to compress and how far to go.
config: CompressionConfig,
// Cache of compressed messages, keyed by the original message.
cache: CompressionCache,
// Produces the base priority score for each message.
scorer: PriorityScorer,
// Selects the per-message compression path.
semantic_strategy: SemanticStrategy,
// Detects the current topic/question and builds the injected focus message.
focus_tracker: FocusTracker,
// Hardcoded limits; `long_text_threshold` drives text truncation.
hardcode_config: HardcodeConfig,
// Consulted via `should_summarize` for the semantic strategies.
semantic_compressor: SemanticCompressor,
}
impl OptimizedCompressor {
    /// Messages whose combined (priority + focus) score reaches this value are
    /// always kept verbatim.
    const HIGH_SCORE_THRESHOLD: f32 = 0.7;

    /// Substrings that identify a system message previously produced by the
    /// focus tracker. Kept specific so ordinary system prompts that merely
    /// mention the word "Focus" are never mistaken for (and overwritten as) a
    /// focus message.
    const FOCUS_MARKERS: [&'static str; 2] = ["焦点上下文", "Current Conversation Focus"];

    /// Creates a compressor from the given compression/cache settings and
    /// semantic strategy. The scorer, focus tracker, hardcoded limits and
    /// semantic compressor use their default configurations.
    pub fn new(
        compression_config: CompressionConfig,
        cache_config: CacheConfig,
        semantic_strategy: SemanticStrategy,
    ) -> Self {
        Self {
            config: compression_config,
            cache: CompressionCache::new(cache_config),
            scorer: PriorityScorer::default(),
            semantic_strategy,
            focus_tracker: FocusTracker::new(),
            hardcode_config: HardcodeConfig::default(),
            semantic_compressor: SemanticCompressor::default(),
        }
    }

    /// Compresses `messages` so they fit the context window while preserving
    /// the current conversation focus.
    ///
    /// `context_size` defaults to 100_000. If the estimated token count is
    /// below `context_size * config.threshold`, the messages are returned
    /// unchanged. Otherwise each message is scored. System, high-score and
    /// focus-related messages are kept verbatim. Remaining messages are
    /// compressed once the running total reaches
    /// `context_size * config.target_ratio`. Finally a focus message is
    /// injected (or an existing one replaced). Message order is preserved.
    ///
    /// # Errors
    /// Propagates any error returned while compressing individual messages.
    pub async fn compress(&mut self, messages: Vec<Message>, context_size: Option<u32>) -> Result<Vec<Message>> {
        if messages.is_empty() {
            return Ok(messages);
        }
        let focus = self.focus_tracker.detect_focus(&messages);
        log::info!(
            "Detected focus - Topic: {:?}, Question: {:?}",
            focus.current_topic,
            focus.current_question
        );
        let current_tokens: u32 = messages.iter().map(|m| estimate_tokens(m)).sum();
        let context_limit = context_size.unwrap_or(100_000);
        let threshold_tokens = (context_limit as f64 * self.config.threshold) as u32;
        log::info!(
            "Current tokens: {}, Context limit: {}, Threshold: {}",
            current_tokens,
            context_limit,
            threshold_tokens
        );
        if current_tokens < threshold_tokens {
            log::debug!("No compression needed");
            return Ok(messages);
        }
        log::info!("Starting optimized compression with focus preservation");
        let scored_messages = self.score_messages_with_focus(&messages, &focus);
        let compressed = self.compress_with_cache_and_focus(scored_messages, &focus, context_limit)?;
        let final_messages = self.inject_focus_message(compressed, &focus);
        self.log_stats();
        Ok(final_messages)
    }

    /// Pairs every message with `priority score + focus score`, capped at 1.0.
    fn score_messages_with_focus(&self, messages: &[Message], focus: &ConversationFocus) -> Vec<(Message, f32)> {
        let total = messages.len();
        messages
            .iter()
            .enumerate()
            .map(|(idx, msg)| {
                let priority_score = self.scorer.score(msg, idx, total).value();
                let focus_score = self.focus_tracker.focus_score(msg, focus);
                let combined_score = priority_score + focus_score;
                log::trace!(
                    "Message {} - Priority: {:.2}, Focus: {:.2}, Combined: {:.2}",
                    idx,
                    priority_score,
                    focus_score,
                    combined_score
                );
                (msg.clone(), combined_score.min(1.0))
            })
            .collect()
    }

    /// Compresses low-priority messages while keeping the output in the
    /// original order.
    ///
    /// A message is kept verbatim when it is a system message, when its score
    /// is at least `HIGH_SCORE_THRESHOLD`, or when its text contains one of the
    /// focus tracker's `recent_context` snippets. The tokens of all verbatim
    /// messages are counted first. The remaining messages are then visited
    /// oldest-first: each is kept as-is while the running total is below
    /// `context_limit * config.target_ratio`, and compressed (via the cache)
    /// after that.
    fn compress_with_cache_and_focus(
        &mut self,
        scored_messages: Vec<(Message, f32)>,
        focus: &ConversationFocus,
        context_limit: u32,
    ) -> Result<Vec<Message>> {
        let target_tokens = (context_limit as f64 * self.config.target_ratio) as u32;

        // Decide up front which messages are kept verbatim. Decisions are made by
        // index, so messages with identical content are no longer collapsed into one.
        let preserved: Vec<bool> = scored_messages
            .iter()
            .map(|(msg, score)| {
                if matches!(msg.role, Role::System) || *score >= Self::HIGH_SCORE_THRESHOLD {
                    return true;
                }
                match Self::focus_context_match(msg, focus) {
                    Some(ctx) => {
                        log::debug!("Preserved message for focus context: {}", ctx);
                        true
                    }
                    None => false,
                }
            })
            .collect();

        // Budget consumed by the verbatim messages before any low-score message is considered.
        let mut current_tokens: u32 = scored_messages
            .iter()
            .zip(&preserved)
            .filter(|(_, keep)| **keep)
            .map(|((msg, _), _)| estimate_tokens(msg))
            .sum();

        // Emit in the original order so the conversation's turn structure survives.
        let mut output = Vec::with_capacity(scored_messages.len());
        for ((msg, score), keep) in scored_messages.into_iter().zip(preserved) {
            if keep {
                output.push(msg);
                continue;
            }
            if current_tokens >= target_tokens {
                let compressed_msg = self.compress_cached(&msg, &score)?;
                current_tokens = current_tokens.saturating_add(estimate_tokens(&compressed_msg));
                output.push(compressed_msg);
            } else {
                current_tokens = current_tokens.saturating_add(estimate_tokens(&msg));
                output.push(msg);
            }
        }
        Ok(output)
    }

    /// Returns the cached compressed form of `msg` when present. Otherwise it
    /// compresses `msg` and stores the result in the cache.
    fn compress_cached(&mut self, msg: &Message, score: &f32) -> Result<Message> {
        if let Some(hit) = self.cache.get(msg).map(|entry| entry.compressed.clone()) {
            log::debug!("Cache hit for compressed message");
            return Ok(hit);
        }
        let compressed = self.compress_message(msg, score)?;
        self.cache.put(msg, compressed.clone());
        Ok(compressed)
    }

    /// Returns the first non-empty `recent_context` snippet contained in the
    /// message's text, if any. Empty snippets are skipped because they would
    /// match every message.
    fn focus_context_match<'a>(msg: &Message, focus: &'a ConversationFocus) -> Option<&'a str> {
        let text = Self::message_text(msg);
        focus
            .recent_context
            .iter()
            .map(String::as_str)
            .find(|ctx| !ctx.is_empty() && text.contains(*ctx))
    }

    /// Returns the message's text. For block content, the text blocks are
    /// joined with a single space and non-text blocks are ignored.
    fn message_text(msg: &Message) -> String {
        match &msg.content {
            MessageContent::Text(t) => t.clone(),
            MessageContent::Blocks(blocks) => blocks
                .iter()
                .filter_map(|b| match b {
                    crate::providers::ContentBlock::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join(" "),
        }
    }

    /// Whether `msg` is a system message previously produced by the focus tracker.
    fn is_focus_message(msg: &Message) -> bool {
        matches!(msg.role, Role::System)
            && matches!(
                &msg.content,
                MessageContent::Text(t) if Self::FOCUS_MARKERS.iter().any(|marker| t.contains(*marker))
            )
    }

    /// Replaces the existing focus message, or inserts a new one.
    ///
    /// A new focus message goes right after the leading run of system messages
    /// (at the end if every message is a system message). This never panics,
    /// even for an empty vector.
    fn inject_focus_message(&self, mut compressed: Vec<Message>, focus: &ConversationFocus) -> Vec<Message> {
        let focus_msg = self.focus_tracker.create_focus_message(focus);
        match compressed.iter().position(Self::is_focus_message) {
            Some(pos) => {
                compressed[pos] = focus_msg;
                log::info!("Replaced existing focus message at position {}", pos);
            }
            None => {
                let insert_pos = compressed
                    .iter()
                    .position(|m| !matches!(m.role, Role::System))
                    .unwrap_or(compressed.len());
                compressed.insert(insert_pos, focus_msg);
                log::info!("Injected new focus message at position {}", insert_pos);
            }
        }
        compressed
    }

    /// Compresses one message according to `semantic_strategy`.
    ///
    /// Semantic summarization is not implemented yet. Every strategy currently
    /// falls back to `truncate_message`; the semantic strategies only record
    /// (at trace level) that summarization would have been requested.
    fn compress_message(&self, message: &Message, _score: &f32) -> Result<Message> {
        match self.semantic_strategy {
            SemanticStrategy::None => {}
            SemanticStrategy::OldOnly | SemanticStrategy::Aggressive => {
                if self.semantic_compressor.should_summarize(std::slice::from_ref(message)) {
                    log::trace!("Semantic summarization not implemented; falling back to truncation");
                }
            }
        }
        self.truncate_message(message)
    }

    /// Truncates every text part longer than `long_text_threshold` characters
    /// to 75% of that threshold plus a `...[compressed]` suffix. Non-text
    /// blocks are copied unchanged.
    fn truncate_message(&self, message: &Message) -> Result<Message> {
        let content = match &message.content {
            MessageContent::Text(text) => match self.truncate_text(text) {
                Some(truncated) => MessageContent::Text(truncated),
                None => return Ok(message.clone()),
            },
            MessageContent::Blocks(blocks) => MessageContent::Blocks(
                blocks
                    .iter()
                    .map(|block| match block {
                        crate::providers::ContentBlock::Text { text } => match self.truncate_text(text) {
                            Some(truncated) => crate::providers::ContentBlock::Text { text: truncated },
                            None => block.clone(),
                        },
                        _ => block.clone(),
                    })
                    .collect(),
            ),
        };
        Ok(Message {
            role: message.role,
            content,
        })
    }

    /// Returns the truncated text, or `None` if `text` is within the threshold.
    ///
    /// Length is measured in `char`s on both sides of the comparison. (The
    /// previous byte-length check made multi-byte text such as CJK look longer
    /// than it is, so short messages got a suffix appended and grew.)
    fn truncate_text(&self, text: &str) -> Option<String> {
        let threshold = self.hardcode_config.long_text_threshold;
        if text.chars().count() <= threshold {
            return None;
        }
        let keep_len = (threshold as f64 * 0.75) as usize;
        let mut truncated: String = text.chars().take(keep_len).collect();
        truncated.push_str("...[compressed]");
        Some(truncated)
    }

    /// Logs the cache's hit/miss counters, hit rate and entry count.
    fn log_stats(&self) {
        let stats = self.cache.stats();
        log::info!(
            "Compression stats - Hits: {}, Misses: {}, Hit rate: {:.2}%, Entries: {}",
            stats.hits,
            stats.misses,
            stats.hit_rate() * 100.0,
            stats.entries
        );
    }
}
pub async fn example_optimized_compression() -> Result<()> {
let compression_config = CompressionConfig::default();
let cache_config = CacheConfig {
max_entries: 100,
ttl: std::time::Duration::from_secs(300),
min_size_to_cache: 100,
};
let mut compressor = OptimizedCompressor::new(
compression_config,
cache_config,
SemanticStrategy::OldOnly,
);
let messages = vec![
Message {
role: Role::System,
content: MessageContent::Text("You are a helpful coding assistant.".to_string()),
},
Message {
role: Role::User,
content: MessageContent::Text("Let's discuss compression algorithms.".to_string()),
},
Message {
role: Role::Assistant,
content: MessageContent::Text("Compression algorithms reduce data size...".to_string()),
},
Message {
role: Role::User,
content: MessageContent::Text("How do I implement Huffman coding?".to_string()),
},
Message {
role: Role::Assistant,
content: MessageContent::Text("Huffman coding uses frequency-based encoding...".to_string()),
},
Message {
role: Role::User,
content: MessageContent::Text("Wait, switching to a different topic: how to optimize database queries?".to_string()),
},
Message {
role: Role::Assistant,
content: MessageContent::Text("Database optimization involves indexing...".to_string()),
},
Message {
role: Role::User,
content: MessageContent::Text("Can you help me fix this slow query in PostgreSQL?".to_string()),
},
];
let compressed = compressor.compress(messages.clone(), Some(50_000)).await?;
println!("Original messages: {}", messages.len());
println!("Compressed messages: {}", compressed.len());
for msg in compressed.iter() {
if let MessageContent::Text(text) = &msg.content {
if text.contains("Current Conversation Focus") {
println!("\nFocus message found:\n{}", text);
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a message whose content is a single text string.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Compressor with default compression/cache settings and the given strategy.
    fn compressor_with(strategy: SemanticStrategy) -> OptimizedCompressor {
        OptimizedCompressor::new(
            CompressionConfig::default(),
            CacheConfig::default(),
            strategy,
        )
    }

    #[test]
    fn test_optimized_compressor_creation() {
        let fresh = compressor_with(SemanticStrategy::OldOnly);
        // Nothing has been compressed yet, so the cache starts out empty.
        assert!(fresh.cache.is_empty());
    }

    #[test]
    fn test_focus_detection() {
        let mut compressor = compressor_with(SemanticStrategy::None);
        let conversation = vec![
            text_message(Role::User, "Test message"),
            text_message(Role::Assistant, "Response"),
        ];
        let detected = compressor.focus_tracker.detect_focus(&conversation);
        assert!(!detected.recent_context.is_empty());
    }

    #[test]
    fn test_combined_scoring() {
        let mut compressor = compressor_with(SemanticStrategy::None);
        let conversation = vec![
            text_message(Role::User, "Let's discuss database optimization"),
            text_message(Role::Assistant, "Database optimization is important..."),
            text_message(Role::User, "How to fix slow query?"),
        ];
        let focus = compressor.focus_tracker.detect_focus(&conversation);
        let scored = compressor.score_messages_with_focus(&conversation, &focus);
        // The latest question should outrank the opening message.
        let (opening_score, latest_score) = (scored[0].1, scored[2].1);
        assert!(latest_score > opening_score);
    }

    #[test]
    fn test_focus_message_injection() {
        let compressor = compressor_with(SemanticStrategy::None);
        let focus = ConversationFocus {
            current_topic: Some("optimization".to_string()),
            current_question: Some("How to fix slow query?".to_string()),
            recent_context: vec!["Database discussion".to_string()],
            topic_transitions: vec![],
            detected_at: 2,
        };
        let conversation = vec![
            text_message(Role::System, "System prompt"),
            text_message(Role::User, "User question"),
        ];
        let result = compressor.inject_focus_message(conversation, &focus);
        assert_eq!(result.len(), 3);
        // The focus message is expected right after the system prompt.
        match &result[1].content {
            MessageContent::Text(text) => assert!(text.contains("焦点上下文")),
            MessageContent::Blocks(_) => panic!("Expected text content"),
        }
    }
}