Skip to main content

matrixcode_core/compress/
cache.rs

//! Compression cache to avoid redundant compression operations.
//!
//! This module caches compression results to improve performance when
//! similar content is compressed repeatedly.
//!
//! # Extended Cache Features
//! - Focus prediction cache
//! - Priority score cache
//! - Coherence segment cache
//! - Complexity analysis cache
11
12use crate::providers::Message;
13use crate::compress::complexity::ComplexityLevel;
14use crate::compress::focus_point::FocusPoint;
15use std::collections::HashMap;
16use std::time::{Duration, Instant};
17use chrono::{DateTime, Utc};
18
19/// Cache entry for a compressed message.
20#[derive(Debug, Clone)]
21pub struct CacheEntry {
22    /// Compressed message
23    pub compressed: Message,
24    /// Original content hash
25    pub hash: u64,
26    /// When the entry was created
27    pub created_at: Instant,
28    /// Number of times this entry was used
29    pub hit_count: usize,
30}
31
/// Running counters describing how the compression cache is being used.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Lookups that returned a live entry.
    pub hits: usize,
    /// Lookups that found nothing, or only an expired entry.
    pub misses: usize,
    /// Entry count as of the last time the cache refreshed this counter.
    pub entries: usize,
    /// Running total of tokens reported as saved.
    pub total_saved_tokens: u32,
}

impl CacheStats {
    /// Fraction of recorded lookups that were hits.
    ///
    /// Returns `0.0` when no lookup has been recorded yet, which avoids a
    /// division by zero.
    pub fn hit_rate(&self) -> f32 {
        let lookups = self.hits + self.misses;
        match lookups {
            0 => 0.0,
            total => self.hits as f32 / total as f32,
        }
    }
}
50
51/// Cached priority score with validity tracking
52#[derive(Debug, Clone)]
53pub struct CachedPriorityScore {
54    /// Priority score value
55    pub score: f32,
56    /// When the score was calculated
57    pub calculated_at: DateTime<Utc>,
58    /// How long the score remains valid
59    pub valid_for: Duration,
60    /// Keywords that influenced this score
61    pub keywords: Vec<String>,
62}
63
64impl CachedPriorityScore {
65    pub fn new(score: f32, keywords: Vec<String>, valid_for: Duration) -> Self {
66        Self {
67            score,
68            calculated_at: Utc::now(),
69            valid_for,
70            keywords,
71        }
72    }
73    
74    pub fn is_valid(&self) -> bool {
75        let now = Utc::now();
76        now - self.calculated_at < chrono::Duration::from_std(self.valid_for).unwrap()
77    }
78}
79
80/// Cached focus prediction
81#[derive(Debug, Clone)]
82pub struct CachedFocusPrediction {
83    /// Predicted focus point
84    pub focus: FocusPoint,
85    /// Prediction confidence
86    pub confidence: f32,
87    /// When the prediction was made
88    pub predicted_at: DateTime<Utc>,
89}
90
91impl CachedFocusPrediction {
92    pub fn new(focus: FocusPoint, confidence: f32) -> Self {
93        Self {
94            focus,
95            confidence,
96            predicted_at: Utc::now(),
97        }
98    }
99}
100
101/// Cached complexity level
102#[derive(Debug, Clone)]
103pub struct CachedComplexity {
104    /// Complexity level
105    pub level: ComplexityLevel,
106    /// When the analysis was performed
107    pub analyzed_at: DateTime<Utc>,
108}
109
110impl CachedComplexity {
111    pub fn new(level: ComplexityLevel) -> Self {
112        Self {
113            level,
114            analyzed_at: Utc::now(),
115        }
116    }
117}
118
/// Tunable limits for [`CompressionCache`].
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Entry count at which an insertion first evicts the oldest entry.
    pub max_entries: usize,
    /// Maximum age of an entry before lookups treat it as expired.
    pub ttl: Duration,
    /// Smallest original message that will be cached.
    ///
    /// Text content is measured with `str::len`, i.e. in UTF-8 bytes.
    pub min_size_to_cache: usize,
}

impl Default for CacheConfig {
    /// Defaults: 100 entries, a five-minute TTL and a minimum size of 100.
    fn default() -> Self {
        const DEFAULT_TTL_SECS: u64 = 5 * 60;

        Self {
            max_entries: 100,
            ttl: Duration::from_secs(DEFAULT_TTL_SECS),
            min_size_to_cache: 100,
        }
    }
}
139
140/// Compression cache implementation.
141#[derive(Debug)]
142pub struct CompressionCache {
143    entries: HashMap<u64, CacheEntry>,
144    config: CacheConfig,
145    stats: CacheStats,
146}
147
148impl Default for CompressionCache {
149    fn default() -> Self {
150        Self::new(CacheConfig::default())
151    }
152}
153
154impl CompressionCache {
155    pub fn new(config: CacheConfig) -> Self {
156        Self {
157            entries: HashMap::new(),
158            config,
159            stats: CacheStats::default(),
160        }
161    }
162
163    /// Calculate hash for a message.
164    fn hash_message(message: &Message) -> u64 {
165        use std::collections::hash_map::DefaultHasher;
166        use std::hash::{Hash, Hasher};
167
168        let mut hasher = DefaultHasher::new();
169        
170        // Hash role as string
171        let role_str = match message.role {
172            crate::providers::Role::User => "user",
173            crate::providers::Role::Assistant => "assistant",
174            crate::providers::Role::System => "system",
175            crate::providers::Role::Tool => "tool",
176        };
177        role_str.hash(&mut hasher);
178
179        // Hash content
180        match &message.content {
181            crate::providers::MessageContent::Text(text) => {
182                text.hash(&mut hasher);
183            }
184            crate::providers::MessageContent::Blocks(blocks) => {
185                // Hash each block's string representation
186                for block in blocks {
187                    let block_str = format!("{:?}", block);
188                    block_str.hash(&mut hasher);
189                }
190            }
191        }
192
193        hasher.finish()
194    }
195
196    /// Check if a message is in the cache.
197    pub fn get(&mut self, message: &Message) -> Option<&CacheEntry> {
198        let hash = Self::hash_message(message);
199
200        if let Some(entry) = self.entries.get(&hash) {
201            // Check TTL
202            if entry.created_at.elapsed() < self.config.ttl {
203                self.stats.hits += 1;
204                let entry = self.entries.get_mut(&hash).unwrap();
205                entry.hit_count += 1;
206                return Some(entry);
207            } else {
208                // Expired, remove it
209                self.entries.remove(&hash);
210            }
211        }
212
213        self.stats.misses += 1;
214        None
215    }
216
217    /// Add a compressed message to the cache.
218    pub fn put(&mut self, original: &Message, compressed: Message) {
219        let hash = Self::hash_message(original);
220
221        // Check minimum size
222        let size = match &original.content {
223            crate::providers::MessageContent::Text(text) => text.len(),
224            crate::providers::MessageContent::Blocks(blocks) => {
225                blocks.iter().map(|b| format!("{:?}", b).len()).sum()
226            }
227        };
228
229        if size < self.config.min_size_to_cache {
230            return;
231        }
232
233        // Evict old entries if at capacity
234        if self.entries.len() >= self.config.max_entries {
235            self.evict_oldest();
236        }
237
238        self.entries.insert(
239            hash,
240            CacheEntry {
241                compressed,
242                hash,
243                created_at: Instant::now(),
244                hit_count: 0,
245            },
246        );
247        self.stats.entries = self.entries.len();
248    }
249
250    /// Evict the oldest entry.
251    fn evict_oldest(&mut self) {
252        if let Some((&oldest_hash, _)) = self
253            .entries
254            .iter()
255            .min_by_key(|(_, entry)| entry.created_at)
256        {
257            self.entries.remove(&oldest_hash);
258        }
259    }
260
261    /// Evict expired entries.
262    pub fn evict_expired(&mut self) {
263        let now = Instant::now();
264        self.entries.retain(|_, entry| {
265            now.duration_since(entry.created_at) < self.config.ttl
266        });
267        self.stats.entries = self.entries.len();
268    }
269
270    /// Clear the cache.
271    pub fn clear(&mut self) {
272        self.entries.clear();
273        self.stats.entries = 0;
274    }
275
276    /// Get cache statistics.
277    pub fn stats(&self) -> &CacheStats {
278        &self.stats
279    }
280
281    /// Get the number of cached entries.
282    pub fn len(&self) -> usize {
283        self.entries.len()
284    }
285
286    /// Check if the cache is empty.
287    pub fn is_empty(&self) -> bool {
288        self.entries.is_empty()
289    }
290
291    /// Record token savings from cache hit.
292    pub fn record_token_savings(&mut self, tokens: u32) {
293        self.stats.total_saved_tokens += tokens;
294    }
295}
296
297/// 扩展压缩缓存:支持多种类型的缓存
298/// 
299/// 包括:焦点预测、优先级分数、复杂度分析等
300#[derive(Debug)]
301pub struct ExtendedCompressionCache {
302    /// 基础摘要缓存(现有)
303    base_cache: CompressionCache,
304    
305    /// 焦点预测缓存:message_id -> predicted focus
306    focus_predictions: HashMap<String, CachedFocusPrediction>,
307    
308    /// 优先级分数缓存:message_id -> priority score
309    priority_scores: HashMap<String, CachedPriorityScore>,
310    
311    /// 复杂度分析缓存:conversation_id -> complexity
312    complexity_cache: HashMap<String, CachedComplexity>,
313    
314    /// 配置
315    config: ExtendedCacheConfig,
316}
317
/// Configuration for the extended caches.
#[derive(Clone, Debug)]
pub struct ExtendedCacheConfig {
    /// Validity period for cached priority scores.
    priority_validity: Duration,
    /// Maximum number of cached focus predictions.
    max_focus_predictions: usize,
    /// Maximum number of cached complexity entries.
    max_complexity_entries: usize,
}

impl Default for ExtendedCacheConfig {
    /// Defaults: ten-minute priority validity, 50 focus predictions and
    /// 20 complexity entries.
    fn default() -> Self {
        const PRIORITY_VALIDITY_SECS: u64 = 10 * 60;

        Self {
            priority_validity: Duration::from_secs(PRIORITY_VALIDITY_SECS),
            max_focus_predictions: 50,
            max_complexity_entries: 20,
        }
    }
}
338
339impl Default for ExtendedCompressionCache {
340    fn default() -> Self {
341        Self::new(ExtendedCacheConfig::default())
342    }
343}
344
345impl ExtendedCompressionCache {
346    pub fn new(config: ExtendedCacheConfig) -> Self {
347        Self {
348            base_cache: CompressionCache::default(),
349            focus_predictions: HashMap::new(),
350            priority_scores: HashMap::new(),
351            complexity_cache: HashMap::new(),
352            config,
353        }
354    }
355    
356    /// 获取优先级分数(如果缓存有效)
357    pub fn get_priority_score(&self, message_id: &str) -> Option<&CachedPriorityScore> {
358        self.priority_scores.get(message_id)
359            .filter(|cached| cached.is_valid())
360    }
361    
362    /// 添加优先级分数缓存
363    pub fn put_priority_score(&mut self, message_id: String, score: CachedPriorityScore) {
364        // 如果超出容量,移除最旧的
365        if self.priority_scores.len() >= self.config.max_focus_predictions {
366            self.evict_oldest_priority();
367        }
368        
369        self.priority_scores.insert(message_id, score);
370    }
371    
372    /// 获取焦点预测
373    pub fn get_focus_prediction(&self, message_id: &str) -> Option<&CachedFocusPrediction> {
374        self.focus_predictions.get(message_id)
375    }
376    
377    /// 添加焦点预测缓存
378    pub fn put_focus_prediction(&mut self, message_id: String, prediction: CachedFocusPrediction) {
379        if self.focus_predictions.len() >= self.config.max_focus_predictions {
380            self.evict_oldest_focus();
381        }
382        
383        self.focus_predictions.insert(message_id, prediction);
384    }
385    
386    /// 获取复杂度分析
387    pub fn get_complexity(&self, conversation_id: &str) -> Option<&CachedComplexity> {
388        self.complexity_cache.get(conversation_id)
389    }
390    
391    /// 添加复杂度缓存
392    pub fn put_complexity(&mut self, conversation_id: String, complexity: CachedComplexity) {
393        if self.complexity_cache.len() >= self.config.max_complexity_entries {
394            self.evict_oldest_complexity();
395        }
396        
397        self.complexity_cache.insert(conversation_id, complexity);
398    }
399    
400    /// 增量更新优先级分数(基于新消息)
401    pub fn update_priority_incremental(&mut self, new_keywords: &[String], existing_messages: &[Message]) {
402        let now = Utc::now();
403        
404        for (id, cached) in &mut self.priority_scores {
405            // 检查关键词重叠
406            let overlap_count = cached.keywords.iter()
407                .filter(|kw| new_keywords.contains(kw))
408                .count();
409            
410            // 如果有重叠,更新相关性分数
411            if overlap_count > 0 {
412                cached.score += overlap_count as f32 * 0.1;
413                cached.calculated_at = now;
414            }
415        }
416    }
417    
418    /// 清理过期缓存
419    pub fn cleanup_expired(&mut self) {
420        // 清理过期的优先级分数
421        self.priority_scores.retain(|_, cached| cached.is_valid());
422        
423        // 清理基础缓存过期项
424        self.base_cache.evict_expired();
425    }
426    
427    /// 移除最旧的优先级分数
428    fn evict_oldest_priority(&mut self) {
429        if let Some((oldest_id, _)) = self.priority_scores.iter()
430            .min_by_key(|(_, cached)| cached.calculated_at)
431        {
432            let id = oldest_id.clone();
433            self.priority_scores.remove(&id);
434        }
435    }
436    
437    /// 移除最旧的焦点预测
438    fn evict_oldest_focus(&mut self) {
439        if let Some((oldest_id, _)) = self.focus_predictions.iter()
440            .min_by_key(|(_, cached)| cached.predicted_at)
441        {
442            let id = oldest_id.clone();
443            self.focus_predictions.remove(&id);
444        }
445    }
446    
447    /// 移除最旧的复杂度缓存
448    fn evict_oldest_complexity(&mut self) {
449        if let Some((oldest_id, _)) = self.complexity_cache.iter()
450            .min_by_key(|(_, cached)| cached.analyzed_at)
451        {
452            let id = oldest_id.clone();
453            self.complexity_cache.remove(&id);
454        }
455    }
456    
457    /// 获取基础缓存
458    pub fn base_cache(&self) -> &CompressionCache {
459        &self.base_cache
460    }
461    
462    /// 获取基础缓存(可变)
463    pub fn base_cache_mut(&mut self) -> &mut CompressionCache {
464        &mut self.base_cache
465    }
466    
467    /// 清空所有缓存
468    pub fn clear_all(&mut self) {
469        self.base_cache.clear();
470        self.focus_predictions.clear();
471        self.priority_scores.clear();
472        self.complexity_cache.clear();
473    }
474    
475    /// 获取缓存统计
476    pub fn extended_stats(&self) -> ExtendedCacheStats {
477        ExtendedCacheStats {
478            base_stats: self.base_cache.stats().clone(),
479            focus_prediction_count: self.focus_predictions.len(),
480            priority_score_count: self.priority_scores.len(),
481            complexity_cache_count: self.complexity_cache.len(),
482        }
483    }
484}
485
486/// 扩展缓存统计
487#[derive(Debug, Clone)]
488pub struct ExtendedCacheStats {
489    pub base_stats: CacheStats,
490    pub focus_prediction_count: usize,
491    pub priority_score_count: usize,
492    pub complexity_cache_count: usize,
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Builds a plain-text user message.
    fn create_test_message(content: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(content.to_string()),
        }
    }

    /// Builds a cache that accepts messages of any length.
    ///
    /// The default `min_size_to_cache` is 100, which would silently reject
    /// the short fixtures used below.
    fn create_test_cache() -> CompressionCache {
        CompressionCache::new(CacheConfig {
            min_size_to_cache: 0,
            ..Default::default()
        })
    }

    #[test]
    fn test_cache_put_and_get() {
        let mut cache = create_test_cache();
        let original = create_test_message("This is a test message that is long enough to be cached");
        let compressed = create_test_message("This is a test message...");

        // Put in cache
        cache.put(&original, compressed);

        // Get from cache
        let entry = cache.get(&original);
        assert!(entry.is_some());
        assert_eq!(entry.unwrap().hit_count, 1);
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = CompressionCache::default();
        let msg = create_test_message("Test message");

        let entry = cache.get(&msg);
        assert!(entry.is_none());
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_cache_hit_increments_counter() {
        let mut cache = create_test_cache();
        let original = create_test_message("This is a longer test message for caching purposes");
        let compressed = create_test_message("Longer test message...");

        cache.put(&original, compressed);

        // Get multiple times
        cache.get(&original);
        cache.get(&original);
        cache.get(&original);

        assert_eq!(cache.stats().hits, 3);
    }

    #[test]
    fn test_cache_minimum_size() {
        let config = CacheConfig {
            min_size_to_cache: 50,
            ..Default::default()
        };
        let mut cache = CompressionCache::new(config);

        let small_msg = create_test_message("Short");
        let compressed = create_test_message("...");

        cache.put(&small_msg, compressed);

        // Should not be cached (too small)
        assert!(cache.is_empty());
        assert!(cache.get(&small_msg).is_none());
    }

    #[test]
    fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            min_size_to_cache: 0,
            ..Default::default()
        };
        let mut cache = CompressionCache::new(config);

        let msg1 = create_test_message("Message 1 - long enough for caching");
        let msg2 = create_test_message("Message 2 - also long enough");
        let msg3 = create_test_message("Message 3 - this one too");

        // Space the insertions out so every entry gets a distinct timestamp
        // and "oldest" is unambiguous even on coarse clocks.
        let tick = std::time::Duration::from_millis(2);

        cache.put(&msg1, msg1.clone());
        std::thread::sleep(tick);
        cache.put(&msg2, msg2.clone());
        std::thread::sleep(tick);
        assert_eq!(cache.len(), 2);

        // Adding a third should evict the oldest
        cache.put(&msg3, msg3.clone());
        assert_eq!(cache.len(), 2);

        // msg1 should have been evicted
        assert!(cache.get(&msg1).is_none());
        assert!(cache.get(&msg2).is_some());
        assert!(cache.get(&msg3).is_some());
    }

    #[test]
    fn test_cache_ttl_expiry() {
        // A zero TTL makes every entry expired as soon as it is stored.
        let config = CacheConfig {
            ttl: std::time::Duration::from_secs(0),
            min_size_to_cache: 0,
            ..Default::default()
        };
        let mut cache = CompressionCache::new(config);
        let msg = create_test_message("Entry that expires immediately");

        cache.put(&msg, msg.clone());
        assert_eq!(cache.len(), 1);

        // The lookup must report a miss and drop the stale entry.
        assert!(cache.get(&msg).is_none());
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_clear() {
        let mut cache = create_test_cache();
        let msg = create_test_message("Long enough message for the cache system");

        cache.put(&msg, msg.clone());
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_stats() {
        let mut cache = create_test_cache();
        let msg = create_test_message("This is a test message for statistics tracking");

        // Miss
        cache.get(&msg);
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 0);

        // Put and hit
        cache.put(&msg, msg.clone());
        cache.get(&msg);
        assert_eq!(cache.stats().hits, 1);

        // Hit rate: one hit out of two lookups
        assert_eq!(cache.stats().hit_rate(), 0.5);
    }

    #[test]
    fn test_message_hash_consistency() {
        let msg1 = create_test_message("Test message");
        let msg2 = create_test_message("Test message");
        let msg3 = create_test_message("Different message");

        let hash1 = CompressionCache::hash_message(&msg1);
        let hash2 = CompressionCache::hash_message(&msg2);
        let hash3 = CompressionCache::hash_message(&msg3);

        // Same content should have same hash
        assert_eq!(hash1, hash2);
        // Different content should have different hash
        assert_ne!(hash1, hash3);
    }
}