1use crate::compress::{
7 CompressionCache, CompressionConfig, CacheConfig, PriorityScorer,
8 SemanticCompressor, SemanticStrategy, estimate_tokens,
9 FocusTracker, ConversationFocus,
10};
11use crate::compress::hardcode_config::HardcodeConfig;
12use crate::providers::{Message, MessageContent, Role};
13use anyhow::Result;
14
15pub struct OptimizedCompressor {
17 config: CompressionConfig,
18 cache: CompressionCache,
19 scorer: PriorityScorer,
20 semantic_strategy: SemanticStrategy,
21 focus_tracker: FocusTracker,
22 hardcode_config: HardcodeConfig,
23 semantic_compressor: SemanticCompressor,
24}
25
26impl OptimizedCompressor {
27 pub fn new(
28 compression_config: CompressionConfig,
29 cache_config: CacheConfig,
30 semantic_strategy: SemanticStrategy,
31 ) -> Self {
32 Self {
33 config: compression_config,
34 cache: CompressionCache::new(cache_config),
35 scorer: PriorityScorer::default(),
36 semantic_strategy,
37 focus_tracker: FocusTracker::new(),
38 hardcode_config: HardcodeConfig::default(),
39 semantic_compressor: SemanticCompressor::default(),
40 }
41 }
42
43 pub async fn compress(&mut self, messages: Vec<Message>, context_size: Option<u32>) -> Result<Vec<Message>> {
45 if messages.is_empty() {
46 return Ok(messages);
47 }
48
49 let focus = self.focus_tracker.detect_focus(&messages);
51 log::info!(
52 "Detected focus - Topic: {:?}, Question: {:?}",
53 focus.current_topic,
54 focus.current_question
55 );
56
57 let current_tokens: u32 = messages.iter().map(|m| estimate_tokens(m)).sum();
59 let context_limit = context_size.unwrap_or(100_000);
60
61 log::info!(
62 "Current tokens: {}, Context limit: {}, Threshold: {}",
63 current_tokens,
64 context_limit,
65 (context_limit as f64 * self.config.threshold) as u32
66 );
67
68 if current_tokens < (context_limit as f64 * self.config.threshold) as u32 {
70 log::debug!("No compression needed");
71 return Ok(messages);
72 }
73
74 log::info!("Starting optimized compression with focus preservation");
75
76 let scored_messages = self.score_messages_with_focus(&messages, &focus);
78
79 let compressed = self.compress_with_cache_and_focus(scored_messages, &focus, context_limit)?;
81
82 let final_messages = self.inject_focus_message(compressed, &focus);
84
85 self.log_stats();
87
88 Ok(final_messages)
89 }
90
91 fn score_messages_with_focus(&self, messages: &[Message], focus: &ConversationFocus) -> Vec<(Message, f32)> {
93 messages
94 .iter()
95 .enumerate()
96 .map(|(idx, msg)| {
97 let priority_score = self.scorer.score(msg, idx, messages.len()).value();
99 let focus_score = self.focus_tracker.focus_score(msg, focus);
100
101 let combined_score = priority_score + focus_score;
104
105 log::trace!(
106 "Message {} - Priority: {:.2}, Focus: {:.2}, Combined: {:.2}",
107 idx,
108 priority_score,
109 focus_score,
110 combined_score
111 );
112
113 (msg.clone(), combined_score.min(1.0)) })
115 .collect()
116 }
117
118 fn compress_with_cache_and_focus(
120 &mut self,
121 scored_messages: Vec<(Message, f32)>,
122 focus: &ConversationFocus,
123 context_limit: u32,
124 ) -> Result<Vec<Message>> {
125 let target_tokens = (context_limit as f64 * self.config.target_ratio) as u32;
126 let mut compressed = Vec::new();
127 let mut current_tokens = 0u32;
128
129 for (msg, _score) in scored_messages.iter() {
131 if matches!(msg.role, Role::System) {
132 compressed.push(msg.clone());
133 current_tokens += estimate_tokens(msg);
134 }
135 }
136
137 for (msg, score) in scored_messages.iter() {
139 if *score >= 0.7 && !matches!(msg.role, Role::System) {
140 if let Some(entry) = self.cache.get(msg) {
142 log::debug!("Cache hit for high score message");
143 compressed.push(entry.compressed.clone());
144 current_tokens += estimate_tokens(&entry.compressed);
145 } else {
146 compressed.push(msg.clone());
147 current_tokens += estimate_tokens(msg);
148 }
149 }
150 }
151
152 for ctx_text in &focus.recent_context {
154 for (msg, score) in scored_messages.iter() {
156 if *score < 0.7 {
157 let msg_text = match &msg.content {
158 MessageContent::Text(t) => t.clone(),
159 MessageContent::Blocks(blocks) => {
160 blocks.iter()
161 .filter_map(|b| {
162 if let crate::providers::ContentBlock::Text { text } = b {
163 Some(text.clone())
164 } else {
165 None
166 }
167 })
168 .collect::<Vec<_>>()
169 .join(" ")
170 }
171 };
172
173 if msg_text.contains(ctx_text) && !compressed.contains(msg) {
174 compressed.push(msg.clone());
175 current_tokens += estimate_tokens(msg);
176 log::debug!("Preserved message for focus context: {}", ctx_text);
177 }
178 }
179 }
180 }
181
182 for (msg, score) in scored_messages.iter() {
184 if *score < 0.7 && !compressed.contains(msg) {
185 if current_tokens >= target_tokens {
186 let compressed_msg = self.compress_message(msg, score)?;
188
189 let msg_tokens = estimate_tokens(&compressed_msg);
191
192 self.cache.put(msg, compressed_msg.clone());
194
195 compressed.push(compressed_msg);
196 current_tokens += msg_tokens;
197 } else {
198 compressed.push(msg.clone());
200 current_tokens += estimate_tokens(msg);
201 }
202 }
203 }
204
205 Ok(compressed)
206 }
207
208 fn inject_focus_message(&self, mut compressed: Vec<Message>, focus: &ConversationFocus) -> Vec<Message> {
210 let focus_msg = self.focus_tracker.create_focus_message(focus);
212
213 let insert_pos = compressed.iter()
215 .position(|m| !matches!(m.role, Role::System))
216 .unwrap_or(1);
217
218 compressed.insert(insert_pos, focus_msg);
219
220 log::info!("Injected focus message at position {}", insert_pos);
221 compressed
222 }
223
224 fn compress_message(&self, message: &Message, _score: &f32) -> Result<Message> {
226 match self.semantic_strategy {
227 SemanticStrategy::None => {
228 self.truncate_message(message)
230 }
231 SemanticStrategy::OldOnly | SemanticStrategy::Aggressive => {
232 if self.semantic_compressor.should_summarize(&[message.clone()]) {
234 self.truncate_message(message)
237 } else {
238 self.truncate_message(message)
239 }
240 }
241 }
242 }
243
244 fn truncate_message(&self, message: &Message) -> Result<Message> {
246 match &message.content {
248 MessageContent::Text(text) => {
249 if text.len() > self.hardcode_config.long_text_threshold {
250 let keep_len = (self.hardcode_config.long_text_threshold as f64 * 0.75) as usize;
251 let truncated = format!("{}...[compressed]", &text.chars().take(keep_len).collect::<String>());
252 Ok(Message {
253 role: message.role,
254 content: MessageContent::Text(truncated),
255 })
256 } else {
257 Ok(message.clone())
258 }
259 }
260 MessageContent::Blocks(blocks) => {
261 let compressed_blocks = blocks
263 .iter()
264 .filter_map(|block| {
265 match block {
266 crate::providers::ContentBlock::Text { text } => {
267 if text.len() > self.hardcode_config.long_text_threshold {
268 let keep_len = (self.hardcode_config.long_text_threshold as f64 * 0.75) as usize;
269 Some(crate::providers::ContentBlock::Text {
270 text: format!("{}...[compressed]", &text.chars().take(keep_len).collect::<String>()),
271 })
272 } else {
273 Some(block.clone())
274 }
275 }
276 _ => Some(block.clone()),
277 }
278 })
279 .collect();
280
281 Ok(Message {
282 role: message.role,
283 content: MessageContent::Blocks(compressed_blocks),
284 })
285 }
286 }
287 }
288
289 fn log_stats(&self) {
291 let stats = self.cache.stats();
292 log::info!(
293 "Compression stats - Hits: {}, Misses: {}, Hit rate: {:.2}%, Entries: {}",
294 stats.hits,
295 stats.misses,
296 stats.hit_rate() * 100.0,
297 stats.entries
298 );
299 }
300}
301
302pub async fn example_optimized_compression() -> Result<()> {
304 let compression_config = CompressionConfig::default();
306
307 let cache_config = CacheConfig {
308 max_entries: 100,
309 ttl: std::time::Duration::from_secs(300),
310 min_size_to_cache: 100,
311 };
312
313 let mut compressor = OptimizedCompressor::new(
314 compression_config,
315 cache_config,
316 SemanticStrategy::OldOnly,
317 );
318
319 let messages = vec![
321 Message {
322 role: Role::System,
323 content: MessageContent::Text("You are a helpful coding assistant.".to_string()),
324 },
325 Message {
326 role: Role::User,
327 content: MessageContent::Text("Let's discuss compression algorithms.".to_string()),
328 },
329 Message {
330 role: Role::Assistant,
331 content: MessageContent::Text("Compression algorithms reduce data size...".to_string()),
332 },
333 Message {
334 role: Role::User,
335 content: MessageContent::Text("How do I implement Huffman coding?".to_string()),
336 },
337 Message {
338 role: Role::Assistant,
339 content: MessageContent::Text("Huffman coding uses frequency-based encoding...".to_string()),
340 },
341 Message {
343 role: Role::User,
344 content: MessageContent::Text("Wait, switching to a different topic: how to optimize database queries?".to_string()),
345 },
346 Message {
347 role: Role::Assistant,
348 content: MessageContent::Text("Database optimization involves indexing...".to_string()),
349 },
350 Message {
351 role: Role::User,
352 content: MessageContent::Text("Can you help me fix this slow query in PostgreSQL?".to_string()),
353 },
354 ];
355
356 let compressed = compressor.compress(messages.clone(), Some(50_000)).await?;
358
359 println!("Original messages: {}", messages.len());
360 println!("Compressed messages: {}", compressed.len());
361
362 for msg in compressed.iter() {
364 if let MessageContent::Text(text) = &msg.content {
365 if text.contains("Current Conversation Focus") {
366 println!("\nFocus message found:\n{}", text);
367 }
368 }
369 }
370
371 Ok(())
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Compressor with default compression/cache configs and the given strategy.
    fn compressor_with(strategy: SemanticStrategy) -> OptimizedCompressor {
        OptimizedCompressor::new(CompressionConfig::default(), CacheConfig::default(), strategy)
    }

    /// Shorthand for a plain-text message.
    fn text_msg(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_optimized_compressor_creation() {
        let compressor = compressor_with(SemanticStrategy::OldOnly);
        assert!(compressor.cache.is_empty());
    }

    #[test]
    fn test_focus_detection() {
        let mut compressor = compressor_with(SemanticStrategy::None);
        let messages = vec![
            text_msg(Role::User, "Test message"),
            text_msg(Role::Assistant, "Response"),
        ];

        let focus = compressor.focus_tracker.detect_focus(&messages);
        assert!(!focus.recent_context.is_empty());
    }

    #[test]
    fn test_combined_scoring() {
        let mut compressor = compressor_with(SemanticStrategy::None);
        let messages = vec![
            text_msg(Role::User, "Let's discuss database optimization"),
            text_msg(Role::Assistant, "Database optimization is important..."),
            text_msg(Role::User, "How to fix slow query?"),
        ];

        let focus = compressor.focus_tracker.detect_focus(&messages);
        let scored = compressor.score_messages_with_focus(&messages, &focus);

        // The latest question should outrank the opening message.
        assert!(scored[2].1 > scored[0].1);
    }

    #[test]
    fn test_focus_message_injection() {
        let compressor = compressor_with(SemanticStrategy::None);
        let focus = ConversationFocus {
            current_topic: Some("optimization".to_string()),
            current_question: Some("How to fix slow query?".to_string()),
            recent_context: vec!["Database discussion".to_string()],
            topic_transitions: vec![],
            detected_at: 2,
        };
        let messages = vec![
            text_msg(Role::System, "System prompt"),
            text_msg(Role::User, "User question"),
        ];

        let final_messages = compressor.inject_focus_message(messages, &focus);
        assert_eq!(final_messages.len(), 3);

        // The focus message belongs right after the system prompt. The expected marker
        // "焦点上下文" means "focus context".
        // NOTE(review): the example function looks for "Current Conversation Focus"
        // instead; confirm which one `create_focus_message` actually emits.
        match &final_messages[1].content {
            MessageContent::Text(text) => assert!(text.contains("焦点上下文")),
            _ => panic!("Expected text content"),
        }
    }
}