Skip to main content

matrixcode_core/compress/
cache.rs

1//! Compression cache to avoid redundant compression operations.
2//!
3//! This module provides caching for compression results to improve
4//! performance when dealing with repeated compression of similar content.
5//!
6//! # Extended Cache Features
//! - Focus prediction cache
//! - Priority score cache
//! - Coherence segment cache
//! - Complexity analysis cache
11
12use crate::providers::Message;
13use crate::compress::complexity::ComplexityLevel;
14use crate::compress::focus_point::FocusPoint;
15use std::collections::HashMap;
16use std::time::{Duration, Instant};
17use chrono::{DateTime, Utc};
18
/// Cache entry for a compressed message.
///
/// One entry is stored per original message, keyed by `hash`.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Compressed form of the original message.
    pub compressed: Message,
    /// Hash of the original (uncompressed) message; this is also the key
    /// under which the entry is stored in the cache map.
    pub hash: u64,
    /// Monotonic timestamp taken when the entry was inserted; used for the
    /// TTL check and for oldest-first eviction.
    pub created_at: Instant,
    /// Number of lookups that returned this entry.
    pub hit_count: usize,
}
31
/// Running counters describing how the compression cache is being used.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that returned a cached entry.
    pub hits: usize,
    /// Lookups that returned nothing.
    pub misses: usize,
    /// Number of entries as of the last bookkeeping update.
    pub entries: usize,
    /// Sum of the token savings reported by callers.
    pub total_saved_tokens: u32,
}

impl CacheStats {
    /// Returns the fraction of lookups that were hits.
    ///
    /// Yields `0.0` when no lookup has been recorded yet, so the result is
    /// never `NaN`.
    pub fn hit_rate(&self) -> f32 {
        let lookups = self.hits + self.misses;
        match lookups {
            0 => 0.0,
            total => self.hits as f32 / total as f32,
        }
    }
}
50
51/// Cached priority score with validity tracking
52#[derive(Debug, Clone)]
53pub struct CachedPriorityScore {
54    /// Priority score value
55    pub score: f32,
56    /// When the score was calculated
57    pub calculated_at: DateTime<Utc>,
58    /// How long the score remains valid
59    pub valid_for: Duration,
60    /// Keywords that influenced this score
61    pub keywords: Vec<String>,
62}
63
64impl CachedPriorityScore {
65    pub fn new(score: f32, keywords: Vec<String>, valid_for: Duration) -> Self {
66        Self {
67            score,
68            calculated_at: Utc::now(),
69            valid_for,
70            keywords,
71        }
72    }
73    
74    pub fn is_valid(&self) -> bool {
75        let now = Utc::now();
76        now - self.calculated_at < chrono::Duration::from_std(self.valid_for).unwrap()
77    }
78}
79
80/// Cached focus prediction
81#[derive(Debug, Clone)]
82pub struct CachedFocusPrediction {
83    /// Predicted focus point
84    pub focus: FocusPoint,
85    /// Prediction confidence
86    pub confidence: f32,
87    /// When the prediction was made
88    pub predicted_at: DateTime<Utc>,
89}
90
91impl CachedFocusPrediction {
92    pub fn new(focus: FocusPoint, confidence: f32) -> Self {
93        Self {
94            focus,
95            confidence,
96            predicted_at: Utc::now(),
97        }
98    }
99}
100
101/// Cached complexity level
102#[derive(Debug, Clone)]
103pub struct CachedComplexity {
104    /// Complexity level
105    pub level: ComplexityLevel,
106    /// When the analysis was performed
107    pub analyzed_at: DateTime<Utc>,
108}
109
110impl CachedComplexity {
111    pub fn new(level: ComplexityLevel) -> Self {
112        Self {
113            level,
114            analyzed_at: Utc::now(),
115        }
116    }
117}
118
/// Tunable limits for the compression cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Upper bound on the number of cached entries.
    pub max_entries: usize,
    /// How long an entry stays usable after it was inserted.
    pub ttl: Duration,
    /// Messages smaller than this are not cached. For text content the size
    /// is the byte length of the string.
    pub min_size_to_cache: usize,
}

impl Default for CacheConfig {
    /// Defaults: 100 entries, a five-minute TTL and a 100-byte size threshold.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(300);
        Self {
            max_entries: 100,
            ttl: five_minutes,
            min_size_to_cache: 100,
        }
    }
}
139
140/// Compression cache implementation.
141#[derive(Debug)]
142pub struct CompressionCache {
143    entries: HashMap<u64, CacheEntry>,
144    config: CacheConfig,
145    stats: CacheStats,
146}
147
148impl Default for CompressionCache {
149    fn default() -> Self {
150        Self::new(CacheConfig::default())
151    }
152}
153
154impl CompressionCache {
155    pub fn new(config: CacheConfig) -> Self {
156        Self {
157            entries: HashMap::new(),
158            config,
159            stats: CacheStats::default(),
160        }
161    }
162
163    /// Calculate hash for a message.
164    fn hash_message(message: &Message) -> u64 {
165        use std::collections::hash_map::DefaultHasher;
166        use std::hash::{Hash, Hasher};
167
168        let mut hasher = DefaultHasher::new();
169        
170        // Hash role as string
171        let role_str = match message.role {
172            crate::providers::Role::User => "user",
173            crate::providers::Role::Assistant => "assistant",
174            crate::providers::Role::System => "system",
175            crate::providers::Role::Tool => "tool",
176        };
177        role_str.hash(&mut hasher);
178
179        // Hash content
180        match &message.content {
181            crate::providers::MessageContent::Text(text) => {
182                text.hash(&mut hasher);
183            }
184            crate::providers::MessageContent::Blocks(blocks) => {
185                // Hash each block's string representation
186                for block in blocks {
187                    let block_str = format!("{:?}", block);
188                    block_str.hash(&mut hasher);
189                }
190            }
191        }
192
193        hasher.finish()
194    }
195
196    /// Check if a message is in the cache.
197    pub fn get(&mut self, message: &Message) -> Option<&CacheEntry> {
198        let hash = Self::hash_message(message);
199
200        if let Some(entry) = self.entries.get(&hash) {
201            // Check TTL
202            if entry.created_at.elapsed() < self.config.ttl {
203                self.stats.hits += 1;
204                let entry = self.entries.get_mut(&hash).unwrap();
205                entry.hit_count += 1;
206                return Some(entry);
207            } else {
208                // Expired, remove it
209                self.entries.remove(&hash);
210            }
211        }
212
213        self.stats.misses += 1;
214        None
215    }
216
217    /// Add a compressed message to the cache.
218    pub fn put(&mut self, original: &Message, compressed: Message) {
219        let hash = Self::hash_message(original);
220
221        // Check minimum size
222        let size = match &original.content {
223            crate::providers::MessageContent::Text(text) => text.len(),
224            crate::providers::MessageContent::Blocks(blocks) => {
225                blocks.iter().map(|b| format!("{:?}", b).len()).sum()
226            }
227        };
228
229        if size < self.config.min_size_to_cache {
230            return;
231        }
232
233        // Evict old entries if at capacity
234        if self.entries.len() >= self.config.max_entries {
235            self.evict_oldest();
236        }
237
238        self.entries.insert(
239            hash,
240            CacheEntry {
241                compressed,
242                hash,
243                created_at: Instant::now(),
244                hit_count: 0,
245            },
246        );
247        self.stats.entries = self.entries.len();
248    }
249
250    /// Evict the oldest entry.
251    fn evict_oldest(&mut self) {
252        if let Some((&oldest_hash, _)) = self
253            .entries
254            .iter()
255            .min_by_key(|(_, entry)| entry.created_at)
256        {
257            self.entries.remove(&oldest_hash);
258        }
259    }
260
261    /// Evict expired entries.
262    pub fn evict_expired(&mut self) {
263        let now = Instant::now();
264        self.entries.retain(|_, entry| {
265            now.duration_since(entry.created_at) < self.config.ttl
266        });
267        self.stats.entries = self.entries.len();
268    }
269
270    /// Clear the cache.
271    pub fn clear(&mut self) {
272        self.entries.clear();
273        self.stats.entries = 0;
274    }
275
276    /// Get cache statistics.
277    pub fn stats(&self) -> &CacheStats {
278        &self.stats
279    }
280
281    /// Get the number of cached entries.
282    pub fn len(&self) -> usize {
283        self.entries.len()
284    }
285
286    /// Check if the cache is empty.
287    pub fn is_empty(&self) -> bool {
288        self.entries.is_empty()
289    }
290
291    /// Record token savings from cache hit.
292    pub fn record_token_savings(&mut self, tokens: u32) {
293        self.stats.total_saved_tokens += tokens;
294    }
295}
296
297/// 扩展压缩缓存:支持多种类型的缓存
298/// 
299/// 包括:焦点预测、优先级分数、复杂度分析等
300#[derive(Debug)]
301pub struct ExtendedCompressionCache {
302    /// 基础摘要缓存(现有)
303    base_cache: CompressionCache,
304    
305    /// 焦点预测缓存:message_id -> predicted focus
306    focus_predictions: HashMap<String, CachedFocusPrediction>,
307    
308    /// 优先级分数缓存:message_id -> priority score
309    priority_scores: HashMap<String, CachedPriorityScore>,
310    
311    /// 复杂度分析缓存:conversation_id -> complexity
312    complexity_cache: HashMap<String, CachedComplexity>,
313    
314    /// 配置
315    config: ExtendedCacheConfig,
316}
317
318/// 扩展缓存配置
319#[derive(Debug, Clone)]
320pub struct ExtendedCacheConfig {
321    /// 焦点预测缓存最大数量
322    max_focus_predictions: usize,
323    /// 复杂度缓存最大数量
324    max_complexity_entries: usize,
325}
326
327impl Default for ExtendedCacheConfig {
328    fn default() -> Self {
329        Self {
330            max_focus_predictions: 50,
331            max_complexity_entries: 20,
332        }
333    }
334}
335
336impl Default for ExtendedCompressionCache {
337    fn default() -> Self {
338        Self::new(ExtendedCacheConfig::default())
339    }
340}
341
342impl ExtendedCompressionCache {
343    pub fn new(config: ExtendedCacheConfig) -> Self {
344        Self {
345            base_cache: CompressionCache::default(),
346            focus_predictions: HashMap::new(),
347            priority_scores: HashMap::new(),
348            complexity_cache: HashMap::new(),
349            config,
350        }
351    }
352    
353    /// 获取优先级分数(如果缓存有效)
354    pub fn get_priority_score(&self, message_id: &str) -> Option<&CachedPriorityScore> {
355        self.priority_scores.get(message_id)
356            .filter(|cached| cached.is_valid())
357    }
358    
359    /// 添加优先级分数缓存
360    pub fn put_priority_score(&mut self, message_id: String, score: CachedPriorityScore) {
361        // 如果超出容量,移除最旧的
362        if self.priority_scores.len() >= self.config.max_focus_predictions {
363            self.evict_oldest_priority();
364        }
365        
366        self.priority_scores.insert(message_id, score);
367    }
368    
369    /// 获取焦点预测
370    pub fn get_focus_prediction(&self, message_id: &str) -> Option<&CachedFocusPrediction> {
371        self.focus_predictions.get(message_id)
372    }
373    
374    /// 添加焦点预测缓存
375    pub fn put_focus_prediction(&mut self, message_id: String, prediction: CachedFocusPrediction) {
376        if self.focus_predictions.len() >= self.config.max_focus_predictions {
377            self.evict_oldest_focus();
378        }
379        
380        self.focus_predictions.insert(message_id, prediction);
381    }
382    
383    /// 获取复杂度分析
384    pub fn get_complexity(&self, conversation_id: &str) -> Option<&CachedComplexity> {
385        self.complexity_cache.get(conversation_id)
386    }
387    
388    /// 添加复杂度缓存
389    pub fn put_complexity(&mut self, conversation_id: String, complexity: CachedComplexity) {
390        if self.complexity_cache.len() >= self.config.max_complexity_entries {
391            self.evict_oldest_complexity();
392        }
393        
394        self.complexity_cache.insert(conversation_id, complexity);
395    }
396    
397    /// 增量更新优先级分数(基于新消息)
398    pub fn update_priority_incremental(&mut self, new_keywords: &[String], _existing_messages: &[Message]) {
399        let now = Utc::now();
400
401        for (_id, cached) in &mut self.priority_scores {
402            // 检查关键词重叠
403            let overlap_count = cached.keywords.iter()
404                .filter(|kw| new_keywords.contains(kw))
405                .count();
406            
407            // 如果有重叠,更新相关性分数
408            if overlap_count > 0 {
409                cached.score += overlap_count as f32 * 0.1;
410                cached.calculated_at = now;
411            }
412        }
413    }
414    
415    /// 清理过期缓存
416    pub fn cleanup_expired(&mut self) {
417        // 清理过期的优先级分数
418        self.priority_scores.retain(|_, cached| cached.is_valid());
419        
420        // 清理基础缓存过期项
421        self.base_cache.evict_expired();
422    }
423    
424    /// 移除最旧的优先级分数
425    fn evict_oldest_priority(&mut self) {
426        if let Some((oldest_id, _)) = self.priority_scores.iter()
427            .min_by_key(|(_, cached)| cached.calculated_at)
428        {
429            let id = oldest_id.clone();
430            self.priority_scores.remove(&id);
431        }
432    }
433    
434    /// 移除最旧的焦点预测
435    fn evict_oldest_focus(&mut self) {
436        if let Some((oldest_id, _)) = self.focus_predictions.iter()
437            .min_by_key(|(_, cached)| cached.predicted_at)
438        {
439            let id = oldest_id.clone();
440            self.focus_predictions.remove(&id);
441        }
442    }
443    
444    /// 移除最旧的复杂度缓存
445    fn evict_oldest_complexity(&mut self) {
446        if let Some((oldest_id, _)) = self.complexity_cache.iter()
447            .min_by_key(|(_, cached)| cached.analyzed_at)
448        {
449            let id = oldest_id.clone();
450            self.complexity_cache.remove(&id);
451        }
452    }
453    
454    /// 获取基础缓存
455    pub fn base_cache(&self) -> &CompressionCache {
456        &self.base_cache
457    }
458    
459    /// 获取基础缓存(可变)
460    pub fn base_cache_mut(&mut self) -> &mut CompressionCache {
461        &mut self.base_cache
462    }
463    
464    /// 清空所有缓存
465    pub fn clear_all(&mut self) {
466        self.base_cache.clear();
467        self.focus_predictions.clear();
468        self.priority_scores.clear();
469        self.complexity_cache.clear();
470    }
471    
472    /// 获取缓存统计
473    pub fn extended_stats(&self) -> ExtendedCacheStats {
474        ExtendedCacheStats {
475            base_stats: self.base_cache.stats().clone(),
476            focus_prediction_count: self.focus_predictions.len(),
477            priority_score_count: self.priority_scores.len(),
478            complexity_cache_count: self.complexity_cache.len(),
479        }
480    }
481}
482
/// Statistics snapshot for the extended cache.
#[derive(Debug, Clone)]
pub struct ExtendedCacheStats {
    /// Statistics of the base compression cache.
    pub base_stats: CacheStats,
    /// Number of cached focus predictions.
    pub focus_prediction_count: usize,
    /// Number of cached priority scores (expired ones included until cleanup).
    pub priority_score_count: usize,
    /// Number of cached complexity analyses.
    pub complexity_cache_count: usize,
}
491
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Builds a user message with plain-text content.
    fn create_test_message(content: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(content.to_string()),
        }
    }

    /// Config that caches messages of any size.
    ///
    /// The default threshold is 100 bytes, which is longer than the short
    /// fixtures used below; with the default config nothing would be stored.
    fn any_size_config() -> CacheConfig {
        CacheConfig {
            min_size_to_cache: 0,
            ..Default::default()
        }
    }

    #[test]
    fn test_cache_put_and_get() {
        let mut cache = CompressionCache::new(any_size_config());
        let original = create_test_message("This is a test message that is long enough to be cached");
        let compressed = create_test_message("This is a test message...");

        cache.put(&original, compressed);

        let entry = cache.get(&original);
        assert!(entry.is_some());
        assert_eq!(entry.unwrap().hit_count, 1);
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = CompressionCache::default();
        let msg = create_test_message("Test message");

        let entry = cache.get(&msg);
        assert!(entry.is_none());
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_cache_hit_increments_counter() {
        let mut cache = CompressionCache::new(any_size_config());
        let original = create_test_message("This is a longer test message for caching purposes");
        let compressed = create_test_message("Longer test message...");

        cache.put(&original, compressed);

        // Look the message up several times.
        cache.get(&original);
        cache.get(&original);
        cache.get(&original);

        assert_eq!(cache.stats().hits, 3);
    }

    #[test]
    fn test_cache_minimum_size() {
        let config = CacheConfig {
            min_size_to_cache: 50,
            ..Default::default()
        };
        let mut cache = CompressionCache::new(config);

        let small_msg = create_test_message("Short");
        let compressed = create_test_message("...");

        cache.put(&small_msg, compressed);

        // Below the threshold, so it must not have been stored.
        assert!(cache.is_empty());
        assert!(cache.get(&small_msg).is_none());
    }

    #[test]
    fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            ..any_size_config()
        };
        let mut cache = CompressionCache::new(config);

        let msg1 = create_test_message("Message 1 - long enough for caching");
        let msg2 = create_test_message("Message 2 - also long enough");
        let msg3 = create_test_message("Message 3 - this one too");

        cache.put(&msg1, msg1.clone());
        cache.put(&msg2, msg2.clone());
        assert_eq!(cache.len(), 2);

        // Adding a third entry must evict the oldest one.
        cache.put(&msg3, msg3.clone());
        assert_eq!(cache.len(), 2);

        assert!(cache.get(&msg1).is_none());
        assert!(cache.get(&msg2).is_some());
        assert!(cache.get(&msg3).is_some());
    }

    #[test]
    fn test_cache_clear() {
        let mut cache = CompressionCache::new(any_size_config());
        let msg = create_test_message("Long enough message for the cache system");

        cache.put(&msg, msg.clone());
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_stats() {
        let mut cache = CompressionCache::new(any_size_config());
        let msg = create_test_message("This is a test message for statistics tracking");

        // Miss
        cache.get(&msg);
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 0);

        // Put and hit
        cache.put(&msg, msg.clone());
        cache.get(&msg);
        assert_eq!(cache.stats().hits, 1);

        // One hit out of two lookups.
        assert_eq!(cache.stats().hit_rate(), 0.5);
    }

    #[test]
    fn test_message_hash_consistency() {
        let msg1 = create_test_message("Test message");
        let msg2 = create_test_message("Test message");
        let msg3 = create_test_message("Different message");

        let hash1 = CompressionCache::hash_message(&msg1);
        let hash2 = CompressionCache::hash_message(&msg2);
        let hash3 = CompressionCache::hash_message(&msg3);

        // Same content must produce the same hash.
        assert_eq!(hash1, hash2);
        // Different content should produce a different hash.
        assert_ne!(hash1, hash3);
    }
}