1use crate::compress::{
7 CompressionCache, CompressionConfig, CacheConfig, PriorityScorer,
8 SemanticCompressor, SemanticStrategy, estimate_tokens,
9 FocusTracker, ConversationFocus,
10};
11use crate::compress::hardcode_config::HardcodeConfig;
12use crate::providers::{Message, MessageContent, Role};
13use anyhow::Result;
14
15pub struct OptimizedCompressor {
17 config: CompressionConfig,
18 cache: CompressionCache,
19 scorer: PriorityScorer,
20 semantic_strategy: SemanticStrategy,
21 focus_tracker: FocusTracker,
22 hardcode_config: HardcodeConfig,
23 semantic_compressor: SemanticCompressor,
24}
25
26impl OptimizedCompressor {
27 pub fn new(
28 compression_config: CompressionConfig,
29 cache_config: CacheConfig,
30 semantic_strategy: SemanticStrategy,
31 ) -> Self {
32 Self {
33 config: compression_config,
34 cache: CompressionCache::new(cache_config),
35 scorer: PriorityScorer::default(),
36 semantic_strategy,
37 focus_tracker: FocusTracker::new(),
38 hardcode_config: HardcodeConfig::default(),
39 semantic_compressor: SemanticCompressor::default(),
40 }
41 }
42
43 pub async fn compress(&mut self, messages: Vec<Message>, context_size: Option<u32>) -> Result<Vec<Message>> {
45 if messages.is_empty() {
46 return Ok(messages);
47 }
48
49 let focus = self.focus_tracker.detect_focus(&messages);
51 log::info!(
52 "Detected focus - Topic: {:?}, Question: {:?}",
53 focus.current_topic,
54 focus.current_question
55 );
56
57 let current_tokens: u32 = messages.iter().map(|m| estimate_tokens(m)).sum();
59 let context_limit = context_size.unwrap_or(100_000);
60
61 log::info!(
62 "Current tokens: {}, Context limit: {}, Threshold: {}",
63 current_tokens,
64 context_limit,
65 (context_limit as f64 * self.config.threshold) as u32
66 );
67
68 if current_tokens < (context_limit as f64 * self.config.threshold) as u32 {
70 log::debug!("No compression needed");
71 return Ok(messages);
72 }
73
74 log::info!("Starting optimized compression with focus preservation");
75
76 let scored_messages = self.score_messages_with_focus(&messages, &focus);
78
79 let compressed = self.compress_with_cache_and_focus(scored_messages, &focus, context_limit)?;
81
82 let final_messages = self.inject_focus_message(compressed, &focus);
84
85 self.log_stats();
87
88 Ok(final_messages)
89 }
90
91 fn score_messages_with_focus(&self, messages: &[Message], focus: &ConversationFocus) -> Vec<(Message, f32)> {
93 messages
94 .iter()
95 .enumerate()
96 .map(|(idx, msg)| {
97 let priority_score = self.scorer.score(msg, idx, messages.len()).value();
99 let focus_score = self.focus_tracker.focus_score(msg, focus);
100
101 let combined_score = priority_score + focus_score;
104
105 log::trace!(
106 "Message {} - Priority: {:.2}, Focus: {:.2}, Combined: {:.2}",
107 idx,
108 priority_score,
109 focus_score,
110 combined_score
111 );
112
113 (msg.clone(), combined_score.min(1.0)) })
115 .collect()
116 }
117
118 fn compress_with_cache_and_focus(
120 &mut self,
121 scored_messages: Vec<(Message, f32)>,
122 focus: &ConversationFocus,
123 context_limit: u32,
124 ) -> Result<Vec<Message>> {
125 let target_tokens = (context_limit as f64 * self.config.target_ratio) as u32;
126 let mut compressed = Vec::new();
127 let mut current_tokens = 0u32;
128
129 for (msg, _score) in scored_messages.iter() {
131 if matches!(msg.role, Role::System) {
132 compressed.push(msg.clone());
133 current_tokens += estimate_tokens(msg);
134 }
135 }
136
137 for (msg, score) in scored_messages.iter() {
139 if *score >= 0.7 && !matches!(msg.role, Role::System) {
140 if let Some(entry) = self.cache.get(msg) {
142 log::debug!("Cache hit for high score message");
143 compressed.push(entry.compressed.clone());
144 current_tokens += estimate_tokens(&entry.compressed);
145 } else {
146 compressed.push(msg.clone());
147 current_tokens += estimate_tokens(msg);
148 }
149 }
150 }
151
152 for ctx_text in &focus.recent_context {
154 for (msg, score) in scored_messages.iter() {
156 if *score < 0.7 {
157 let msg_text = match &msg.content {
158 MessageContent::Text(t) => t.clone(),
159 MessageContent::Blocks(blocks) => {
160 blocks.iter()
161 .filter_map(|b| {
162 if let crate::providers::ContentBlock::Text { text } = b {
163 Some(text.clone())
164 } else {
165 None
166 }
167 })
168 .collect::<Vec<_>>()
169 .join(" ")
170 }
171 };
172
173 if msg_text.contains(ctx_text) && !compressed.contains(msg) {
174 compressed.push(msg.clone());
175 current_tokens += estimate_tokens(msg);
176 log::debug!("Preserved message for focus context: {}", ctx_text);
177 }
178 }
179 }
180 }
181
182 for (msg, score) in scored_messages.iter() {
184 if *score < 0.7 && !compressed.contains(msg) {
185 if current_tokens >= target_tokens {
186 let compressed_msg = self.compress_message(msg, score)?;
188
189 let msg_tokens = estimate_tokens(&compressed_msg);
191
192 self.cache.put(msg, compressed_msg.clone());
194
195 compressed.push(compressed_msg);
196 current_tokens += msg_tokens;
197 } else {
198 compressed.push(msg.clone());
200 current_tokens += estimate_tokens(msg);
201 }
202 }
203 }
204
205 Ok(compressed)
206 }
207
208 fn inject_focus_message(&self, mut compressed: Vec<Message>, focus: &ConversationFocus) -> Vec<Message> {
211 let focus_msg = self.focus_tracker.create_focus_message(focus);
213
214 let existing_focus_pos = compressed.iter().position(|m| {
216 if matches!(m.role, Role::System) {
217 match &m.content {
218 MessageContent::Text(t) => {
219 t.contains("焦点") || t.contains("Focus") || t.contains("【焦点上下文】")
220 }
221 _ => false
222 }
223 } else {
224 false
225 }
226 });
227
228 if let Some(pos) = existing_focus_pos {
229 compressed[pos] = focus_msg;
231 log::info!("Replaced existing focus message at position {}", pos);
232 } else {
233 let insert_pos = compressed.iter()
235 .position(|m| !matches!(m.role, Role::System))
236 .unwrap_or(1);
237
238 compressed.insert(insert_pos, focus_msg);
239 log::info!("Injected new focus message at position {}", insert_pos);
240 }
241
242 compressed
243 }
244
245 fn compress_message(&self, message: &Message, _score: &f32) -> Result<Message> {
247 match self.semantic_strategy {
248 SemanticStrategy::None => {
249 self.truncate_message(message)
251 }
252 SemanticStrategy::OldOnly | SemanticStrategy::Aggressive => {
253 if self.semantic_compressor.should_summarize(&[message.clone()]) {
255 self.truncate_message(message)
258 } else {
259 self.truncate_message(message)
260 }
261 }
262 }
263 }
264
265 fn truncate_message(&self, message: &Message) -> Result<Message> {
267 match &message.content {
269 MessageContent::Text(text) => {
270 if text.len() > self.hardcode_config.long_text_threshold {
271 let keep_len = (self.hardcode_config.long_text_threshold as f64 * 0.75) as usize;
272 let truncated = format!("{}...[compressed]", &text.chars().take(keep_len).collect::<String>());
273 Ok(Message {
274 role: message.role,
275 content: MessageContent::Text(truncated),
276 })
277 } else {
278 Ok(message.clone())
279 }
280 }
281 MessageContent::Blocks(blocks) => {
282 let compressed_blocks = blocks
284 .iter()
285 .filter_map(|block| {
286 match block {
287 crate::providers::ContentBlock::Text { text } => {
288 if text.len() > self.hardcode_config.long_text_threshold {
289 let keep_len = (self.hardcode_config.long_text_threshold as f64 * 0.75) as usize;
290 Some(crate::providers::ContentBlock::Text {
291 text: format!("{}...[compressed]", &text.chars().take(keep_len).collect::<String>()),
292 })
293 } else {
294 Some(block.clone())
295 }
296 }
297 _ => Some(block.clone()),
298 }
299 })
300 .collect();
301
302 Ok(Message {
303 role: message.role,
304 content: MessageContent::Blocks(compressed_blocks),
305 })
306 }
307 }
308 }
309
310 fn log_stats(&self) {
312 let stats = self.cache.stats();
313 log::info!(
314 "Compression stats - Hits: {}, Misses: {}, Hit rate: {:.2}%, Entries: {}",
315 stats.hits,
316 stats.misses,
317 stats.hit_rate() * 100.0,
318 stats.entries
319 );
320 }
321}
322
323pub async fn example_optimized_compression() -> Result<()> {
325 let compression_config = CompressionConfig::default();
327
328 let cache_config = CacheConfig {
329 max_entries: 100,
330 ttl: std::time::Duration::from_secs(300),
331 min_size_to_cache: 100,
332 };
333
334 let mut compressor = OptimizedCompressor::new(
335 compression_config,
336 cache_config,
337 SemanticStrategy::OldOnly,
338 );
339
340 let messages = vec![
342 Message {
343 role: Role::System,
344 content: MessageContent::Text("You are a helpful coding assistant.".to_string()),
345 },
346 Message {
347 role: Role::User,
348 content: MessageContent::Text("Let's discuss compression algorithms.".to_string()),
349 },
350 Message {
351 role: Role::Assistant,
352 content: MessageContent::Text("Compression algorithms reduce data size...".to_string()),
353 },
354 Message {
355 role: Role::User,
356 content: MessageContent::Text("How do I implement Huffman coding?".to_string()),
357 },
358 Message {
359 role: Role::Assistant,
360 content: MessageContent::Text("Huffman coding uses frequency-based encoding...".to_string()),
361 },
362 Message {
364 role: Role::User,
365 content: MessageContent::Text("Wait, switching to a different topic: how to optimize database queries?".to_string()),
366 },
367 Message {
368 role: Role::Assistant,
369 content: MessageContent::Text("Database optimization involves indexing...".to_string()),
370 },
371 Message {
372 role: Role::User,
373 content: MessageContent::Text("Can you help me fix this slow query in PostgreSQL?".to_string()),
374 },
375 ];
376
377 let compressed = compressor.compress(messages.clone(), Some(50_000)).await?;
379
380 println!("Original messages: {}", messages.len());
381 println!("Compressed messages: {}", compressed.len());
382
383 for msg in compressed.iter() {
385 if let MessageContent::Text(text) = &msg.content {
386 if text.contains("Current Conversation Focus") {
387 println!("\nFocus message found:\n{}", text);
388 }
389 }
390 }
391
392 Ok(())
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a compressor with default compression and cache settings.
    fn compressor_with(strategy: SemanticStrategy) -> OptimizedCompressor {
        OptimizedCompressor::new(
            CompressionConfig::default(),
            CacheConfig::default(),
            strategy,
        )
    }

    /// Builds a plain-text message with the given role.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// A freshly built compressor starts with an empty cache.
    #[test]
    fn test_optimized_compressor_creation() {
        let compressor = compressor_with(SemanticStrategy::OldOnly);
        assert!(compressor.cache.is_empty());
    }

    /// Focus detection reports some recent context for a non-empty conversation.
    #[test]
    fn test_focus_detection() {
        let mut compressor = compressor_with(SemanticStrategy::None);

        let messages = vec![
            text_message(Role::User, "Test message"),
            text_message(Role::Assistant, "Response"),
        ];

        let focus = compressor.focus_tracker.detect_focus(&messages);
        assert!(!focus.recent_context.is_empty());
    }

    /// The latest user question scores higher than the opening message.
    #[test]
    fn test_combined_scoring() {
        let mut compressor = compressor_with(SemanticStrategy::None);

        let messages = vec![
            text_message(Role::User, "Let's discuss database optimization"),
            text_message(Role::Assistant, "Database optimization is important..."),
            text_message(Role::User, "How to fix slow query?"),
        ];

        let focus = compressor.focus_tracker.detect_focus(&messages);
        let scored = compressor.score_messages_with_focus(&messages, &focus);

        let (first_score, last_score) = (scored[0].1, scored[2].1);
        assert!(last_score > first_score);
    }

    /// Without an existing focus message, one is inserted right after the
    /// system prompt.
    #[test]
    fn test_focus_message_injection() {
        let compressor = compressor_with(SemanticStrategy::None);

        let focus = ConversationFocus {
            current_topic: Some("optimization".to_string()),
            current_question: Some("How to fix slow query?".to_string()),
            recent_context: vec!["Database discussion".to_string()],
            topic_transitions: vec![],
            detected_at: 2,
        };

        let messages = vec![
            text_message(Role::System, "System prompt"),
            text_message(Role::User, "User question"),
        ];

        let final_messages = compressor.inject_focus_message(messages, &focus);

        assert_eq!(final_messages.len(), 3);

        match &final_messages[1].content {
            MessageContent::Text(text) => assert!(text.contains("焦点上下文")),
            _ => panic!("Expected text content"),
        }
    }
}