matrixcode_core/prompt/
cache.rs1use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10use std::time::{Duration, Instant};
11
/// Identifies one cached section.
///
/// Equality and hashing cover all three fields, so the same section name
/// stored under a different profile or content hash occupies its own slot.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Section name.
    pub name: String,
    /// Profile the section belongs to (matched by `SectionCache::clear_profile`).
    pub profile: String,
    /// Optional hash of the section's content; `None` unless set via `with_hash`.
    pub content_hash: Option<u64>,
}

impl CacheKey {
    /// Builds a key from a section name and a profile, without a content hash.
    pub fn new(name: impl Into<String>, profile: impl Into<String>) -> Self {
        let name = name.into();
        let profile = profile.into();
        CacheKey {
            name,
            profile,
            content_hash: None,
        }
    }

    /// Consumes the key and returns it with `content_hash` set to `hash`,
    /// replacing any hash it already carried.
    pub fn with_hash(mut self, hash: u64) -> Self {
        self.content_hash = Some(hash);
        self
    }
}

37#[derive(Debug, Clone)]
39pub struct CachedEntry {
40 pub content: String,
42 pub cached_at: Instant,
44 pub token_count: usize,
46 pub use_count: u64,
48}
49
50impl CachedEntry {
51 pub fn new(content: String) -> Self {
52 let token_count = estimate_tokens(&content);
53 Self {
54 content,
55 cached_at: Instant::now(),
56 token_count,
57 use_count: 0,
58 }
59 }
60
61 pub fn is_expired(&self, max_age: Duration) -> bool {
63 self.cached_at.elapsed() > max_age
64 }
65
66 pub fn mark_used(&mut self) {
68 self.use_count += 1;
69 }
70}
71
72pub struct SectionCache {
74 entries: RwLock<HashMap<CacheKey, CachedEntry>>,
76 max_age: Duration,
78 stats: RwLock<CacheStats>,
80}
81
/// Cumulative counters describing how a [`SectionCache`] has been used.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Entry count recorded by the owning cache after its most recent update.
    pub total_entries: usize,
    /// Lookups that returned a live entry.
    pub total_hits: u64,
    /// Lookups that found no entry or found an expired one.
    pub total_misses: u64,
    /// Entries evicted by the owning cache (for example on expiry or `clear`).
    pub total_evictions: u64,
    /// Sum of the estimated token counts of all hits.
    pub tokens_saved: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded, instead of dividing by zero.
    pub fn hit_rate(&self) -> f64 {
        let lookups = self.total_hits + self.total_misses;
        match lookups {
            0 => 0.0,
            n => self.total_hits as f64 / n as f64,
        }
    }
}

107impl SectionCache {
108 pub fn new() -> Self {
110 Self {
111 entries: RwLock::new(HashMap::new()),
112 max_age: Duration::from_secs(3600), stats: RwLock::new(CacheStats::default()),
114 }
115 }
116
117 pub fn with_max_age(max_age: Duration) -> Self {
119 Self {
120 entries: RwLock::new(HashMap::new()),
121 max_age,
122 stats: RwLock::new(CacheStats::default()),
123 }
124 }
125
126 pub fn get(&self, key: &CacheKey) -> Option<String> {
128 let mut entries = self.entries.write().unwrap();
129 let mut stats = self.stats.write().unwrap();
130
131 if let Some(entry) = entries.get_mut(key) {
132 if entry.is_expired(self.max_age) {
133 entries.remove(key);
135 stats.total_misses += 1;
136 stats.total_evictions += 1;
137 None
138 } else {
139 entry.mark_used();
141 stats.total_hits += 1;
142 stats.tokens_saved += entry.token_count as u64;
143 Some(entry.content.clone())
144 }
145 } else {
146 stats.total_misses += 1;
147 None
148 }
149 }
150
151 pub fn set(&self, key: CacheKey, content: String) {
153 let mut entries = self.entries.write().unwrap();
154 let mut stats = self.stats.write().unwrap();
155
156 let entry = CachedEntry::new(content);
157 entries.insert(key, entry);
158 stats.total_entries = entries.len();
159 }
160
161 pub fn get_or_compute<F>(&self, key: &CacheKey, compute: F) -> String
163 where
164 F: FnOnce() -> String,
165 {
166 if let Some(cached) = self.get(key) {
167 cached
168 } else {
169 let content = compute();
170 self.set(key.clone(), content.clone());
171 content
172 }
173 }
174
175 pub fn clear(&self) {
177 let mut entries = self.entries.write().unwrap();
178 let mut stats = self.stats.write().unwrap();
179
180 let evicted = entries.len();
181 entries.clear();
182 stats.total_entries = 0;
183 stats.total_evictions += evicted as u64;
184 }
185
186 pub fn clear_profile(&self, profile: &str) {
188 let mut entries = self.entries.write().unwrap();
189 let mut stats = self.stats.write().unwrap();
190
191 entries.retain(|k, _| k.profile != profile);
192 stats.total_entries = entries.len();
193 }
194
195 pub fn stats(&self) -> CacheStats {
197 self.stats.read().unwrap().clone()
198 }
199
200 pub fn cached_tokens(&self) -> usize {
202 let entries = self.entries.read().unwrap();
203 entries.values().map(|e| e.token_count).sum()
204 }
205
206 pub fn is_empty(&self) -> bool {
208 self.entries.read().unwrap().is_empty()
209 }
210
211 pub fn size(&self) -> usize {
213 self.entries.read().unwrap().len()
214 }
215}
216
217impl Default for SectionCache {
218 fn default() -> Self {
219 Self::new()
220 }
221}
222
/// Roughly estimates how many tokens `content` would occupy.
///
/// The estimate is the sum of two terms:
/// * one token per whitespace-separated word, and
/// * one additional token per three alphabetic characters that take more
///   than one byte in UTF-8, rounded down. This covers CJK ideographs, but
///   also accented Latin, Cyrillic, Greek, and so on.
///
/// This is a cheap heuristic used here for the cache's token statistics.
/// Text with no whitespace counts as a single word, so long unbroken runs of
/// ASCII (URLs, minified code) are underestimated.
pub fn estimate_tokens(content: &str) -> usize {
    let multibyte_letters = content
        .chars()
        .filter(|c| c.is_alphabetic() && c.len_utf8() > 1)
        .count();
    let words = content.split_whitespace().count();
    multibyte_letters / 3 + words
}

247static GLOBAL_CACHE: std::sync::OnceLock<Arc<SectionCache>> = std::sync::OnceLock::new();
249
250pub fn global_cache() -> Arc<SectionCache> {
252 GLOBAL_CACHE.get_or_init(|| Arc::new(SectionCache::new())).clone()
253}
254
255pub fn clear_global_cache() {
257 global_cache().clear();
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_basic() {
        let cache = SectionCache::new();
        let key = CacheKey::new("test", "default");

        // Nothing stored yet: this lookup is a miss.
        assert_eq!(cache.get(&key), None);

        cache.set(key.clone(), String::from("test content"));

        // Now it is a hit.
        assert_eq!(cache.get(&key).as_deref(), Some("test content"));

        let CacheStats { total_hits, total_misses, .. } = cache.stats();
        assert_eq!(total_hits, 1);
        assert_eq!(total_misses, 1);
    }

    #[test]
    fn test_cache_expiry() {
        let ttl = Duration::from_millis(10);
        let cache = SectionCache::with_max_age(ttl);
        let key = CacheKey::new("test", "default");
        cache.set(key.clone(), "test".to_string());

        // Wait past the TTL so the next lookup finds a stale entry.
        std::thread::sleep(ttl * 2);

        assert_eq!(cache.get(&key), None);
        assert_eq!(cache.stats().total_evictions, 1);
    }

    #[test]
    fn test_get_or_compute() {
        let cache = SectionCache::new();
        let key = CacheKey::new("compute", "default");

        let first = cache.get_or_compute(&key, || String::from("computed"));
        assert_eq!(first, "computed");

        // The cached value wins; the second closure's result is ignored.
        let second = cache.get_or_compute(&key, || String::from("different"));
        assert_eq!(second, "computed");
    }

    #[test]
    fn test_clear_profile() {
        let cache = SectionCache::new();
        for (name, profile) in [("a", "default"), ("b", "safe")] {
            cache.set(CacheKey::new(name, profile), name.to_string());
        }

        cache.clear_profile("default");

        assert_eq!(cache.get(&CacheKey::new("a", "default")), None);
        assert_eq!(cache.get(&CacheKey::new("b", "safe")).as_deref(), Some("b"));
    }

    #[test]
    fn test_estimate_tokens() {
        let eng_tokens = estimate_tokens("Hello world this is a test");
        assert!((5..=10).contains(&eng_tokens), "English tokens: {}", eng_tokens);

        let ch_tokens = estimate_tokens("你好世界这是一个测试");
        assert!((2..=10).contains(&ch_tokens), "Chinese tokens: {}", ch_tokens);
    }

    #[test]
    fn test_global_cache() {
        clear_global_cache();
        let key = CacheKey::new("global_test", "default");
        global_cache().set(key.clone(), "global content".to_string());

        // A separately obtained handle must see the same underlying cache.
        let other = global_cache();
        assert_eq!(other.get(&key), Some("global content".to_string()));

        clear_global_cache();
        assert!(other.get(&key).is_none());
    }
}
355}