matrixcode_core/prompt/
cache.rs1use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10use std::time::{Duration, Instant};
11
/// Identifies one cached prompt section.
///
/// Equality and hashing cover all three fields, so two keys that differ only
/// in `content_hash` address different cache entries.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Section name.
    pub name: String,
    /// Profile the section belongs to; `SectionCache::clear_profile` matches on this field.
    pub profile: String,
    /// Optional content hash: `CacheKey::new` leaves it `None`, `CacheKey::with_hash` sets it.
    pub content_hash: Option<u64>,
}
22
23impl CacheKey {
24 pub fn new(name: impl Into<String>, profile: impl Into<String>) -> Self {
25 Self {
26 name: name.into(),
27 profile: profile.into(),
28 content_hash: None,
29 }
30 }
31
32 pub fn with_hash(self, hash: u64) -> Self {
33 Self {
34 content_hash: Some(hash),
35 ..self
36 }
37 }
38}
39
/// A single cached section together with its bookkeeping data.
#[derive(Clone, Debug)]
pub struct CachedEntry {
    /// The cached text returned on a hit.
    pub content: String,
    /// Moment the entry was created; its age is measured from here.
    pub cached_at: Instant,
    /// Token estimate for `content`, computed when the entry is created.
    pub token_count: usize,
    /// How many times the entry has been marked as used.
    pub use_count: u64,
}
52
53impl CachedEntry {
54 pub fn new(content: String) -> Self {
55 let token_count = estimate_tokens(&content);
56 Self {
57 content,
58 cached_at: Instant::now(),
59 token_count,
60 use_count: 0,
61 }
62 }
63
64 pub fn is_expired(&self, max_age: Duration) -> bool {
66 self.cached_at.elapsed() > max_age
67 }
68
69 pub fn mark_used(&mut self) {
71 self.use_count += 1;
72 }
73}
74
75pub struct SectionCache {
77 entries: RwLock<HashMap<CacheKey, CachedEntry>>,
79 max_age: Duration,
81 stats: RwLock<CacheStats>,
83}
84
/// Counters describing cache activity; obtain a snapshot via `SectionCache::stats`.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Number of stored entries, as last recorded by the cache.
    pub total_entries: usize,
    /// Lookups that returned a live entry.
    pub total_hits: u64,
    /// Lookups that found nothing, or found only an expired entry.
    pub total_misses: u64,
    /// Entries removed because they expired or the cache was cleared.
    pub total_evictions: u64,
    /// Sum of the token estimates of every entry served from the cache.
    pub tokens_saved: u64,
}
99
100impl CacheStats {
101 pub fn hit_rate(&self) -> f64 {
102 if self.total_hits + self.total_misses == 0 {
103 0.0
104 } else {
105 self.total_hits as f64 / (self.total_hits + self.total_misses) as f64
106 }
107 }
108}
109
110impl SectionCache {
111 pub fn new() -> Self {
113 Self {
114 entries: RwLock::new(HashMap::new()),
115 max_age: Duration::from_secs(3600), stats: RwLock::new(CacheStats::default()),
117 }
118 }
119
120 pub fn with_max_age(max_age: Duration) -> Self {
122 Self {
123 entries: RwLock::new(HashMap::new()),
124 max_age,
125 stats: RwLock::new(CacheStats::default()),
126 }
127 }
128
129 pub fn get(&self, key: &CacheKey) -> Option<String> {
131 let mut entries = self.entries.write().unwrap();
132 let mut stats = self.stats.write().unwrap();
133
134 if let Some(entry) = entries.get_mut(key) {
135 if entry.is_expired(self.max_age) {
136 entries.remove(key);
138 stats.total_misses += 1;
139 stats.total_evictions += 1;
140 None
141 } else {
142 entry.mark_used();
144 stats.total_hits += 1;
145 stats.tokens_saved += entry.token_count as u64;
146 Some(entry.content.clone())
147 }
148 } else {
149 stats.total_misses += 1;
150 None
151 }
152 }
153
154 pub fn set(&self, key: CacheKey, content: String) {
156 let mut entries = self.entries.write().unwrap();
157 let mut stats = self.stats.write().unwrap();
158
159 let entry = CachedEntry::new(content);
160 entries.insert(key, entry);
161 stats.total_entries = entries.len();
162 }
163
164 pub fn get_or_compute<F>(&self, key: &CacheKey, compute: F) -> String
166 where
167 F: FnOnce() -> String,
168 {
169 if let Some(cached) = self.get(key) {
170 cached
171 } else {
172 let content = compute();
173 self.set(key.clone(), content.clone());
174 content
175 }
176 }
177
178 pub fn clear(&self) {
180 let mut entries = self.entries.write().unwrap();
181 let mut stats = self.stats.write().unwrap();
182
183 let evicted = entries.len();
184 entries.clear();
185 stats.total_entries = 0;
186 stats.total_evictions += evicted as u64;
187 }
188
189 pub fn clear_profile(&self, profile: &str) {
191 let mut entries = self.entries.write().unwrap();
192 let mut stats = self.stats.write().unwrap();
193
194 entries.retain(|k, _| k.profile != profile);
195 stats.total_entries = entries.len();
196 }
197
198 pub fn stats(&self) -> CacheStats {
200 self.stats.read().unwrap().clone()
201 }
202
203 pub fn cached_tokens(&self) -> usize {
205 let entries = self.entries.read().unwrap();
206 entries.values().map(|e| e.token_count).sum()
207 }
208
209 pub fn is_empty(&self) -> bool {
211 self.entries.read().unwrap().is_empty()
212 }
213
214 pub fn size(&self) -> usize {
216 self.entries.read().unwrap().len()
217 }
218}
219
220impl Default for SectionCache {
221 fn default() -> Self {
222 Self::new()
223 }
224}
225
/// Roughly estimates how many LLM tokens `content` will occupy.
///
/// Heuristic:
/// * every multi-byte alphabetic character (e.g. CJK ideographs) counts as a
///   third of a token, rounded up so a short snippet never estimates to zero;
/// * every whitespace-separated word containing at least one other character
///   (ASCII letters, digits, punctuation, ...) counts as one token.
///
/// A word made up solely of multi-byte letters is already covered by the first
/// rule, so it is not counted a second time as a word.
pub fn estimate_tokens(content: &str) -> usize {
    /// Alphabetic characters that need more than one UTF-8 byte (CJK, accented letters, ...).
    fn is_wide_letter(c: char) -> bool {
        c.is_alphabetic() && c.len_utf8() > 1
    }

    let wide_letters = content.chars().filter(|&c| is_wide_letter(c)).count();
    let words = content
        .split_whitespace()
        .filter(|word| !word.chars().all(is_wide_letter))
        .count();

    wide_letters.div_ceil(3) + words
}
252
253static GLOBAL_CACHE: std::sync::OnceLock<Arc<SectionCache>> = std::sync::OnceLock::new();
255
256pub fn global_cache() -> Arc<SectionCache> {
258 GLOBAL_CACHE
259 .get_or_init(|| Arc::new(SectionCache::new()))
260 .clone()
261}
262
263pub fn clear_global_cache() {
265 global_cache().clear();
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a key in the `default` profile.
    fn default_key(name: &str) -> CacheKey {
        CacheKey::new(name, "default")
    }

    #[test]
    fn test_cache_basic() {
        let section_cache = SectionCache::new();
        let k = default_key("test");

        // Cold lookup: nothing has been cached yet.
        assert!(section_cache.get(&k).is_none());

        section_cache.set(k.clone(), String::from("test content"));
        assert_eq!(section_cache.get(&k).as_deref(), Some("test content"));

        let CacheStats {
            total_hits,
            total_misses,
            ..
        } = section_cache.stats();
        assert_eq!(total_hits, 1);
        assert_eq!(total_misses, 1);
    }

    #[test]
    fn test_cache_expiry() {
        let ttl = Duration::from_millis(10);
        let section_cache = SectionCache::with_max_age(ttl);
        let k = default_key("test");

        section_cache.set(k.clone(), "test".to_owned());

        // Outlive the TTL so the next lookup finds a stale entry.
        std::thread::sleep(ttl * 2);

        assert_eq!(section_cache.get(&k), None);
        assert_eq!(section_cache.stats().total_evictions, 1);
    }

    #[test]
    fn test_get_or_compute() {
        let section_cache = SectionCache::new();
        let k = default_key("compute");

        let first = section_cache.get_or_compute(&k, || "computed".to_string());
        assert_eq!(first, "computed");

        // Served from the cache, so this closure's value must be ignored.
        let second = section_cache.get_or_compute(&k, || "different".to_string());
        assert_eq!(second, "computed");
    }

    #[test]
    fn test_clear_profile() {
        let section_cache = SectionCache::new();
        for (name, profile) in [("a", "default"), ("b", "safe")] {
            section_cache.set(CacheKey::new(name, profile), name.to_string());
        }

        section_cache.clear_profile("default");

        assert!(section_cache.get(&default_key("a")).is_none());
        assert_eq!(
            section_cache.get(&CacheKey::new("b", "safe")).as_deref(),
            Some("b")
        );
    }

    #[test]
    fn test_estimate_tokens() {
        let eng_tokens = estimate_tokens("Hello world this is a test");
        assert!(
            (5..=10).contains(&eng_tokens),
            "English tokens: {}",
            eng_tokens
        );

        let ch_tokens = estimate_tokens("你好世界这是一个测试");
        assert!(
            (2..=10).contains(&ch_tokens),
            "Chinese tokens: {}",
            ch_tokens
        );
    }

    #[test]
    fn test_global_cache() {
        clear_global_cache();
        let writer = global_cache();
        let k = default_key("global_test");
        writer.set(k.clone(), "global content".to_string());

        // A second handle must observe the same underlying cache.
        let reader = global_cache();
        assert_eq!(reader.get(&k), Some("global content".to_string()));

        clear_global_cache();
        assert!(reader.get(&k).is_none());
    }
}