Skip to main content

matrixcode_core/prompt/
cache.rs

1//! Section Cache System
2//!
3//! Provides caching for static sections to:
4//! - Reduce token costs (cached content not re-computed)
5//! - Enable prompt prefix caching for API efficiency
6//! - Track cache statistics for optimization
7
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10use std::time::{Duration, Instant};
11
/// Identifies one cached section: its name, the profile it was rendered for,
/// and an optional content hash.
///
/// All three fields take part in the derived `Hash`/`Eq`, so a key that
/// carries a hash never matches an otherwise identical key without one.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheKey {
    /// Name of the section.
    pub name: String,
    /// Profile name (e.g. default, safe, fast, review).
    pub profile: String,
    /// Optional content hash; participates in key equality and hashing.
    pub content_hash: Option<u64>,
}

impl CacheKey {
    /// Builds a key for section `name` under `profile`, with no content hash.
    pub fn new(name: impl Into<String>, profile: impl Into<String>) -> Self {
        let name = name.into();
        let profile = profile.into();
        CacheKey {
            name,
            profile,
            content_hash: None,
        }
    }

    /// Returns this key with its content hash set to `hash`, replacing any
    /// hash it already had.
    pub fn with_hash(mut self, hash: u64) -> Self {
        self.content_hash = Some(hash);
        self
    }
}
39
40/// Cached entry with metadata
41#[derive(Debug, Clone)]
42pub struct CachedEntry {
43    /// Cached content
44    pub content: String,
45    /// When it was cached
46    pub cached_at: Instant,
47    /// Estimated token count
48    pub token_count: usize,
49    /// Number of times used
50    pub use_count: u64,
51}
52
53impl CachedEntry {
54    pub fn new(content: String) -> Self {
55        let token_count = estimate_tokens(&content);
56        Self {
57            content,
58            cached_at: Instant::now(),
59            token_count,
60            use_count: 0,
61        }
62    }
63
64    /// Check if entry is expired
65    pub fn is_expired(&self, max_age: Duration) -> bool {
66        self.cached_at.elapsed() > max_age
67    }
68
69    /// Mark as used
70    pub fn mark_used(&mut self) {
71        self.use_count += 1;
72    }
73}
74
75/// Section cache with statistics
76pub struct SectionCache {
77    /// Cached entries
78    entries: RwLock<HashMap<CacheKey, CachedEntry>>,
79    /// Maximum cache age
80    max_age: Duration,
81    /// Statistics
82    stats: RwLock<CacheStats>,
83}
84
/// Running counters describing how a section cache has been used.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of entries stored as of the last update to these stats.
    pub total_entries: usize,
    /// Lookups that returned a live entry.
    pub total_hits: u64,
    /// Lookups that found no entry or only an expired one.
    pub total_misses: u64,
    /// Entries counted as evicted (e.g. expired on lookup, or cleared).
    pub total_evictions: u64,
    /// Sum of the token estimates of every entry served as a hit.
    pub tokens_saved: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet, rather than NaN.
    pub fn hit_rate(&self) -> f64 {
        let lookups = self.total_hits + self.total_misses;
        match lookups {
            0 => 0.0,
            n => self.total_hits as f64 / n as f64,
        }
    }
}
109
110impl SectionCache {
111    /// Create a new cache with default max age
112    pub fn new() -> Self {
113        Self {
114            entries: RwLock::new(HashMap::new()),
115            max_age: Duration::from_secs(3600), // 1 hour default
116            stats: RwLock::new(CacheStats::default()),
117        }
118    }
119
120    /// Create cache with custom max age
121    pub fn with_max_age(max_age: Duration) -> Self {
122        Self {
123            entries: RwLock::new(HashMap::new()),
124            max_age,
125            stats: RwLock::new(CacheStats::default()),
126        }
127    }
128
129    /// Get a cached entry
130    pub fn get(&self, key: &CacheKey) -> Option<String> {
131        let mut entries = self.entries.write().unwrap();
132        let mut stats = self.stats.write().unwrap();
133
134        if let Some(entry) = entries.get_mut(key) {
135            if entry.is_expired(self.max_age) {
136                // Expired, remove and count as miss
137                entries.remove(key);
138                stats.total_misses += 1;
139                stats.total_evictions += 1;
140                None
141            } else {
142                // Valid, mark as used
143                entry.mark_used();
144                stats.total_hits += 1;
145                stats.tokens_saved += entry.token_count as u64;
146                Some(entry.content.clone())
147            }
148        } else {
149            stats.total_misses += 1;
150            None
151        }
152    }
153
154    /// Set a cached entry
155    pub fn set(&self, key: CacheKey, content: String) {
156        let mut entries = self.entries.write().unwrap();
157        let mut stats = self.stats.write().unwrap();
158
159        let entry = CachedEntry::new(content);
160        entries.insert(key, entry);
161        stats.total_entries = entries.len();
162    }
163
164    /// Get or compute (cache miss pattern)
165    pub fn get_or_compute<F>(&self, key: &CacheKey, compute: F) -> String
166    where
167        F: FnOnce() -> String,
168    {
169        if let Some(cached) = self.get(key) {
170            cached
171        } else {
172            let content = compute();
173            self.set(key.clone(), content.clone());
174            content
175        }
176    }
177
178    /// Clear all cache entries
179    pub fn clear(&self) {
180        let mut entries = self.entries.write().unwrap();
181        let mut stats = self.stats.write().unwrap();
182
183        let evicted = entries.len();
184        entries.clear();
185        stats.total_entries = 0;
186        stats.total_evictions += evicted as u64;
187    }
188
189    /// Clear entries for a specific profile
190    pub fn clear_profile(&self, profile: &str) {
191        let mut entries = self.entries.write().unwrap();
192        let mut stats = self.stats.write().unwrap();
193
194        entries.retain(|k, _| k.profile != profile);
195        stats.total_entries = entries.len();
196    }
197
198    /// Get statistics
199    pub fn stats(&self) -> CacheStats {
200        self.stats.read().unwrap().clone()
201    }
202
203    /// Get total cached token count
204    pub fn cached_tokens(&self) -> usize {
205        let entries = self.entries.read().unwrap();
206        entries.values().map(|e| e.token_count).sum()
207    }
208
209    /// Check if cache is empty
210    pub fn is_empty(&self) -> bool {
211        self.entries.read().unwrap().is_empty()
212    }
213
214    /// Get cache size
215    pub fn size(&self) -> usize {
216        self.entries.read().unwrap().len()
217    }
218}
219
220impl Default for SectionCache {
221    fn default() -> Self {
222        Self::new()
223    }
224}
225
226// Note: Clone implementation removed - use Arc<SectionCache> for sharing
227// Full cloning of potentially large cache entries is expensive and unnecessary
228// when Arc provides cheap reference counting
229
/// Estimates how many LLM tokens `content` will consume.
///
/// This is a cheap heuristic, not a tokenizer. The estimate is the sum of:
/// - one third of the number of multi-byte alphabetic characters. This covers
///   CJK ideographs, but also accented Latin, Cyrillic, etc. Integer division
///   is applied to the total.
/// - one token per whitespace-separated word.
///
/// Text that is a single unbroken run is handled differently. A run has no
/// interior whitespace; examples are a URL, an identifier, base64, or unspaced
/// CJK prose. Such text counts as about one token per four non-whitespace
/// characters, and never less than the one word it represents.
///
/// Empty or whitespace-only input yields `0`.
pub fn estimate_tokens(content: &str) -> usize {
    let multibyte_letters = content
        .chars()
        .filter(|c| c.is_alphabetic() && c.len_utf8() > 1)
        .count();

    let words = content.split_whitespace().count();
    let word_tokens = if words == 1 {
        // A single run has no word boundaries, so a word count would estimate
        // it as one token however long it is. Fall back to ~4 chars/token.
        let run_len = content.chars().filter(|c| !c.is_whitespace()).count();
        (run_len / 4).max(1)
    } else {
        words
    };

    multibyte_letters / 3 + word_tokens
}
252
253/// Global cache instance
254static GLOBAL_CACHE: std::sync::OnceLock<Arc<SectionCache>> = std::sync::OnceLock::new();
255
256/// Get the global section cache
257pub fn global_cache() -> Arc<SectionCache> {
258    GLOBAL_CACHE
259        .get_or_init(|| Arc::new(SectionCache::new()))
260        .clone()
261}
262
263/// Clear the global cache (for /clear, /compact, worktree switch)
264pub fn clear_global_cache() {
265    global_cache().clear();
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_basic() {
        let cache = SectionCache::new();
        let key = CacheKey::new("test", "default");

        // Nothing stored yet: first lookup is a miss.
        assert_eq!(cache.get(&key), None);

        cache.set(key.clone(), String::from("test content"));

        // Stored: second lookup is a hit.
        assert_eq!(cache.get(&key).as_deref(), Some("test content"));

        let CacheStats {
            total_hits,
            total_misses,
            ..
        } = cache.stats();
        assert_eq!(total_hits, 1);
        assert_eq!(total_misses, 1);
    }

    #[test]
    fn test_cache_expiry() {
        let ttl = Duration::from_millis(10);
        let cache = SectionCache::with_max_age(ttl);
        let key = CacheKey::new("test", "default");

        cache.set(key.clone(), "test".to_string());

        // Sleep for twice the TTL so the entry is definitely stale.
        std::thread::sleep(ttl * 2);

        assert!(cache.get(&key).is_none());
        assert_eq!(cache.stats().total_evictions, 1);
    }

    #[test]
    fn test_get_or_compute() {
        let cache = SectionCache::new();
        let key = CacheKey::new("compute", "default");

        let first = cache.get_or_compute(&key, || "computed".to_string());
        let second = cache.get_or_compute(&key, || "different".to_string());

        assert_eq!(first, "computed");
        // The second closure must not win: the cached value is returned.
        assert_eq!(second, "computed");
    }

    #[test]
    fn test_clear_profile() {
        let cache = SectionCache::new();
        for (name, profile) in [("a", "default"), ("b", "safe")] {
            cache.set(CacheKey::new(name, profile), name.to_string());
        }

        cache.clear_profile("default");

        assert_eq!(cache.get(&CacheKey::new("a", "default")), None);
        assert_eq!(
            cache.get(&CacheKey::new("b", "safe")),
            Some("b".to_string())
        );
    }

    #[test]
    fn test_estimate_tokens() {
        // Six whitespace-separated English words.
        let eng_tokens = estimate_tokens("Hello world this is a test");
        assert!((5..=10).contains(&eng_tokens), "English tokens: {}", eng_tokens);

        // Ten CJK characters with no whitespace between them.
        let ch_tokens = estimate_tokens("你好世界这是一个测试");
        assert!((2..=10).contains(&ch_tokens), "Chinese tokens: {}", ch_tokens);
    }

    #[test]
    fn test_global_cache() {
        clear_global_cache();

        let key = CacheKey::new("global_test", "default");
        global_cache().set(key.clone(), "global content".to_string());

        // A second handle must observe the same underlying cache.
        let handle = global_cache();
        assert_eq!(handle.get(&key), Some("global content".to_string()));

        clear_global_cache();
        assert!(handle.get(&key).is_none());
    }
}