Skip to main content

matrixcode_core/prompt/
cache.rs

1//! Section Cache System
2//!
3//! Provides caching for static sections to:
4//! - Reduce token costs (cached content not re-computed)
5//! - Enable prompt prefix caching for API efficiency
6//! - Track cache statistics for optimization
7
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10use std::time::{Duration, Instant};
11
/// Identifies one cached section.
///
/// Two keys are equal (and hash identically) only when the section name,
/// the profile and the optional content hash all match.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheKey {
    /// Name of the section this key refers to
    pub name: String,
    /// Profile the section was rendered for (default, safe, fast, review)
    pub profile: String,
    /// Optional hash of the content, used for validation
    pub content_hash: Option<u64>,
}

impl CacheKey {
    /// Builds a key from a section name and a profile, without a content hash.
    pub fn new(name: impl Into<String>, profile: impl Into<String>) -> Self {
        let name = name.into();
        let profile = profile.into();
        Self {
            name,
            profile,
            content_hash: None,
        }
    }

    /// Consumes the key and returns it with `content_hash` set to `hash`.
    pub fn with_hash(mut self, hash: u64) -> Self {
        self.content_hash = Some(hash);
        self
    }
}
36
37/// Cached entry with metadata
38#[derive(Debug, Clone)]
39pub struct CachedEntry {
40    /// Cached content
41    pub content: String,
42    /// When it was cached
43    pub cached_at: Instant,
44    /// Estimated token count
45    pub token_count: usize,
46    /// Number of times used
47    pub use_count: u64,
48}
49
50impl CachedEntry {
51    pub fn new(content: String) -> Self {
52        let token_count = estimate_tokens(&content);
53        Self {
54            content,
55            cached_at: Instant::now(),
56            token_count,
57            use_count: 0,
58        }
59    }
60
61    /// Check if entry is expired
62    pub fn is_expired(&self, max_age: Duration) -> bool {
63        self.cached_at.elapsed() > max_age
64    }
65
66    /// Mark as used
67    pub fn mark_used(&mut self) {
68        self.use_count += 1;
69    }
70}
71
72/// Section cache with statistics
73pub struct SectionCache {
74    /// Cached entries
75    entries: RwLock<HashMap<CacheKey, CachedEntry>>,
76    /// Maximum cache age
77    max_age: Duration,
78    /// Statistics
79    stats: RwLock<CacheStats>,
80}
81
/// Counters describing how a cache has been used
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of entries currently recorded
    pub total_entries: usize,
    /// Number of lookups that were hits
    pub total_hits: u64,
    /// Number of lookups that were misses
    pub total_misses: u64,
    /// Number of evictions
    pub total_evictions: u64,
    /// Estimated number of tokens saved
    pub tokens_saved: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in the range `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookup has been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.total_hits + self.total_misses {
            0 => 0.0,
            lookups => self.total_hits as f64 / lookups as f64,
        }
    }
}
106
107impl SectionCache {
108    /// Create a new cache with default max age
109    pub fn new() -> Self {
110        Self {
111            entries: RwLock::new(HashMap::new()),
112            max_age: Duration::from_secs(3600), // 1 hour default
113            stats: RwLock::new(CacheStats::default()),
114        }
115    }
116
117    /// Create cache with custom max age
118    pub fn with_max_age(max_age: Duration) -> Self {
119        Self {
120            entries: RwLock::new(HashMap::new()),
121            max_age,
122            stats: RwLock::new(CacheStats::default()),
123        }
124    }
125
126    /// Get a cached entry
127    pub fn get(&self, key: &CacheKey) -> Option<String> {
128        let mut entries = self.entries.write().unwrap();
129        let mut stats = self.stats.write().unwrap();
130        
131        if let Some(entry) = entries.get_mut(key) {
132            if entry.is_expired(self.max_age) {
133                // Expired, remove and count as miss
134                entries.remove(key);
135                stats.total_misses += 1;
136                stats.total_evictions += 1;
137                None
138            } else {
139                // Valid, mark as used
140                entry.mark_used();
141                stats.total_hits += 1;
142                stats.tokens_saved += entry.token_count as u64;
143                Some(entry.content.clone())
144            }
145        } else {
146            stats.total_misses += 1;
147            None
148        }
149    }
150
151    /// Set a cached entry
152    pub fn set(&self, key: CacheKey, content: String) {
153        let mut entries = self.entries.write().unwrap();
154        let mut stats = self.stats.write().unwrap();
155        
156        let entry = CachedEntry::new(content);
157        entries.insert(key, entry);
158        stats.total_entries = entries.len();
159    }
160
161    /// Get or compute (cache miss pattern)
162    pub fn get_or_compute<F>(&self, key: &CacheKey, compute: F) -> String
163    where
164        F: FnOnce() -> String,
165    {
166        if let Some(cached) = self.get(key) {
167            cached
168        } else {
169            let content = compute();
170            self.set(key.clone(), content.clone());
171            content
172        }
173    }
174
175    /// Clear all cache entries
176    pub fn clear(&self) {
177        let mut entries = self.entries.write().unwrap();
178        let mut stats = self.stats.write().unwrap();
179        
180        let evicted = entries.len();
181        entries.clear();
182        stats.total_entries = 0;
183        stats.total_evictions += evicted as u64;
184    }
185
186    /// Clear entries for a specific profile
187    pub fn clear_profile(&self, profile: &str) {
188        let mut entries = self.entries.write().unwrap();
189        let mut stats = self.stats.write().unwrap();
190        
191        entries.retain(|k, _| k.profile != profile);
192        stats.total_entries = entries.len();
193    }
194
195    /// Get statistics
196    pub fn stats(&self) -> CacheStats {
197        self.stats.read().unwrap().clone()
198    }
199
200    /// Get total cached token count
201    pub fn cached_tokens(&self) -> usize {
202        let entries = self.entries.read().unwrap();
203        entries.values().map(|e| e.token_count).sum()
204    }
205
206    /// Check if cache is empty
207    pub fn is_empty(&self) -> bool {
208        self.entries.read().unwrap().is_empty()
209    }
210
211    /// Get cache size
212    pub fn size(&self) -> usize {
213        self.entries.read().unwrap().len()
214    }
215}
216
217impl Default for SectionCache {
218    fn default() -> Self {
219        Self::new()
220    }
221}
222
223// Note: Clone implementation removed - use Arc<SectionCache> for sharing
224// Full cloning of potentially large cache entries is expensive and unnecessary
225// when Arc provides cheap reference counting
226
/// Estimate token count for content
///
/// Rough heuristic, the sum of two terms:
/// - one token per whitespace-delimited word, and
/// - one extra token per three multi-byte alphabetic characters (CJK
///   ideographs, but also accented Latin, Cyrillic, ...), rounded down.
///
/// Empty or whitespace-only content yields `0`.
///
/// Limitation: content without any whitespace counts as a single word, so
/// long unbroken ASCII runs (URLs, minified data) are underestimated.
pub fn estimate_tokens(content: &str) -> usize {
    let wide_alphabetic = content
        .chars()
        .filter(|c| c.is_alphabetic() && c.len_utf8() > 1)
        .count();
    let words = content.split_whitespace().count();

    // The former "no words but some non-whitespace characters" fallback was
    // removed: `split_whitespace` yields at least one item whenever any
    // non-whitespace character exists, so that branch could never run.
    wide_alphabetic / 3 + words
}
246
247/// Global cache instance
248static GLOBAL_CACHE: std::sync::OnceLock<Arc<SectionCache>> = std::sync::OnceLock::new();
249
250/// Get the global section cache
251pub fn global_cache() -> Arc<SectionCache> {
252    GLOBAL_CACHE.get_or_init(|| Arc::new(SectionCache::new())).clone()
253}
254
255/// Clear the global cache (for /clear, /compact, worktree switch)
256pub fn clear_global_cache() {
257    global_cache().clear();
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// A miss, a store and a hit are each reflected in the counters.
    #[test]
    fn test_cache_basic() {
        let key = CacheKey::new("test", "default");
        let cache = SectionCache::new();

        // Nothing stored yet: the first lookup is a miss.
        assert_eq!(cache.get(&key), None);

        cache.set(key.clone(), String::from("test content"));

        // The stored content comes back on the second lookup.
        assert_eq!(cache.get(&key), Some(String::from("test content")));

        let CacheStats { total_hits, total_misses, .. } = cache.stats();
        assert_eq!(total_hits, 1);
        assert_eq!(total_misses, 1);
    }

    /// An entry older than the maximum age is dropped on lookup.
    #[test]
    fn test_cache_expiry() {
        let key = CacheKey::new("test", "default");
        let cache = SectionCache::with_max_age(Duration::from_millis(10));

        cache.set(key.clone(), String::from("test"));

        // Sleep past the 10 ms lifetime so the entry is stale.
        std::thread::sleep(Duration::from_millis(20));

        assert_eq!(cache.get(&key), None);
        assert_eq!(cache.stats().total_evictions, 1);
    }

    /// The closure only runs when the key is not cached yet.
    #[test]
    fn test_get_or_compute() {
        let key = CacheKey::new("compute", "default");
        let cache = SectionCache::new();

        let first = cache.get_or_compute(&key, || String::from("computed"));
        assert_eq!(first, "computed");

        // The value from the first call is reused; the new closure is ignored.
        let second = cache.get_or_compute(&key, || String::from("different"));
        assert_eq!(second, "computed");
    }

    /// Clearing one profile leaves entries of other profiles in place.
    #[test]
    fn test_clear_profile() {
        let cache = SectionCache::new();
        let default_key = CacheKey::new("a", "default");
        let safe_key = CacheKey::new("b", "safe");

        cache.set(default_key.clone(), String::from("a"));
        cache.set(safe_key.clone(), String::from("b"));

        cache.clear_profile("default");

        assert_eq!(cache.get(&default_key), None);
        assert_eq!(cache.get(&safe_key), Some(String::from("b")));
    }

    /// Estimates stay within a loose, plausible range.
    #[test]
    fn test_estimate_tokens() {
        // Six whitespace-separated English words.
        let eng_tokens = estimate_tokens("Hello world this is a test");
        assert!((5..=10).contains(&eng_tokens), "English tokens: {}", eng_tokens);

        // Ten Chinese characters with no whitespace between them.
        let ch_tokens = estimate_tokens("你好世界这是一个测试");
        assert!((2..=10).contains(&ch_tokens), "Chinese tokens: {}", ch_tokens);
    }

    /// Every call to `global_cache` hands out the same underlying cache.
    #[test]
    fn test_global_cache() {
        clear_global_cache();

        let key = CacheKey::new("global_test", "default");
        let first_handle = global_cache();
        first_handle.set(key.clone(), String::from("global content"));

        // A second handle sees what the first one stored.
        let second_handle = global_cache();
        assert_eq!(second_handle.get(&key), Some(String::from("global content")));

        clear_global_cache();
        assert_eq!(second_handle.get(&key), None);
    }
}
355}