// chess_vector_engine/utils/cache.rs
use lru::LruCache;
use std::collections::HashMap;
use std::hash::Hash;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

8pub struct TimedLruCache<K, V> {
10 cache: Arc<Mutex<LruCache<K, CacheEntry<V>>>>,
11 ttl: Duration,
12}
13
14#[derive(Debug, Clone)]
16struct CacheEntry<V> {
17 value: V,
18 timestamp: Instant,
19}
20
21impl<K, V> TimedLruCache<K, V>
22where
23 K: Hash + Eq + Clone,
24 V: Clone,
25{
26 pub fn new(capacity: usize, ttl: Duration) -> Self {
28 let non_zero_capacity =
29 NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).unwrap());
30 Self {
31 cache: Arc::new(Mutex::new(LruCache::new(non_zero_capacity))),
32 ttl,
33 }
34 }
35
36 pub fn insert(&self, key: K, value: V) {
38 let entry = CacheEntry {
39 value,
40 timestamp: Instant::now(),
41 };
42
43 if let Ok(mut cache) = self.cache.lock() {
44 cache.put(key, entry);
45 }
46 }
47
48 pub fn get(&self, key: &K) -> Option<V> {
50 if let Ok(mut cache) = self.cache.lock() {
51 if let Some(entry) = cache.get(key) {
52 if entry.timestamp.elapsed() < self.ttl {
54 return Some(entry.value.clone());
55 } else {
56 cache.pop(key);
58 }
59 }
60 }
61 None
62 }
63
64 pub fn contains(&self, key: &K) -> bool {
66 if let Ok(cache) = self.cache.lock() {
67 if let Some(entry) = cache.peek(key) {
68 return entry.timestamp.elapsed() < self.ttl;
69 }
70 }
71 false
72 }
73
74 pub fn clear(&self) {
76 if let Ok(mut cache) = self.cache.lock() {
77 cache.clear();
78 }
79 }
80
81 pub fn stats(&self) -> CacheStats {
83 if let Ok(cache) = self.cache.lock() {
84 let capacity = cache.cap().get();
85 let size = cache.len();
86 let expired_count = cache
87 .iter()
88 .filter(|(_, entry)| entry.timestamp.elapsed() >= self.ttl)
89 .count();
90
91 CacheStats {
92 capacity,
93 size,
94 expired_count,
95 hit_ratio: 0.0, }
97 } else {
98 CacheStats {
99 capacity: 0,
100 size: 0,
101 expired_count: 0,
102 hit_ratio: 0.0,
103 }
104 }
105 }
106
107 pub fn cleanup_expired(&self) {
109 if let Ok(mut cache) = self.cache.lock() {
110 let now = Instant::now();
111 let expired_keys: Vec<K> = cache
112 .iter()
113 .filter(|(_, entry)| now.duration_since(entry.timestamp) >= self.ttl)
114 .map(|(k, _)| k.clone())
115 .collect();
116
117 for key in expired_keys {
118 cache.pop(&key);
119 }
120 }
121 }
122}
123
124pub struct SimilarityCache {
126 cache: TimedLruCache<(usize, usize), f32>,
127 hit_count: Arc<Mutex<u64>>,
128 miss_count: Arc<Mutex<u64>>,
129}
130
131impl SimilarityCache {
132 pub fn new(capacity: usize, ttl: Duration) -> Self {
134 Self {
135 cache: TimedLruCache::new(capacity, ttl),
136 hit_count: Arc::new(Mutex::new(0)),
137 miss_count: Arc::new(Mutex::new(0)),
138 }
139 }
140
141 pub fn get_similarity(&self, pos1: usize, pos2: usize) -> Option<f32> {
143 let key = if pos1 <= pos2 {
145 (pos1, pos2)
146 } else {
147 (pos2, pos1)
148 };
149
150 if let Some(similarity) = self.cache.get(&key) {
151 if let Ok(mut hits) = self.hit_count.lock() {
152 *hits += 1;
153 }
154 Some(similarity)
155 } else {
156 if let Ok(mut misses) = self.miss_count.lock() {
157 *misses += 1;
158 }
159 None
160 }
161 }
162
163 pub fn store_similarity(&self, pos1: usize, pos2: usize, similarity: f32) {
165 let key = if pos1 <= pos2 {
167 (pos1, pos2)
168 } else {
169 (pos2, pos1)
170 };
171 self.cache.insert(key, similarity);
172 }
173
174 pub fn stats(&self) -> CacheStats {
176 let mut base_stats = self.cache.stats();
177
178 let hits = self.hit_count.lock().map(|h| *h).unwrap_or(0);
179 let misses = self.miss_count.lock().map(|m| *m).unwrap_or(0);
180
181 base_stats.hit_ratio = if hits + misses > 0 {
182 hits as f64 / (hits + misses) as f64
183 } else {
184 0.0
185 };
186
187 base_stats
188 }
189
190 pub fn clear(&self) {
192 self.cache.clear();
193 if let Ok(mut hits) = self.hit_count.lock() {
194 *hits = 0;
195 }
196 if let Ok(mut misses) = self.miss_count.lock() {
197 *misses = 0;
198 }
199 }
200}
201
202pub struct EvaluationCache {
204 cache: TimedLruCache<String, f32>,
205 hit_count: Arc<Mutex<u64>>,
206 miss_count: Arc<Mutex<u64>>,
207}
208
209impl EvaluationCache {
210 pub fn new(capacity: usize, ttl: Duration) -> Self {
212 Self {
213 cache: TimedLruCache::new(capacity, ttl),
214 hit_count: Arc::new(Mutex::new(0)),
215 miss_count: Arc::new(Mutex::new(0)),
216 }
217 }
218
219 pub fn get_evaluation(&self, fen: &str) -> Option<f32> {
221 if let Some(evaluation) = self.cache.get(&fen.to_string()) {
222 if let Ok(mut hits) = self.hit_count.lock() {
223 *hits += 1;
224 }
225 Some(evaluation)
226 } else {
227 if let Ok(mut misses) = self.miss_count.lock() {
228 *misses += 1;
229 }
230 None
231 }
232 }
233
234 pub fn store_evaluation(&self, fen: &str, evaluation: f32) {
236 self.cache.insert(fen.to_string(), evaluation);
237 }
238
239 pub fn stats(&self) -> CacheStats {
241 let mut base_stats = self.cache.stats();
242
243 let hits = self.hit_count.lock().map(|h| *h).unwrap_or(0);
244 let misses = self.miss_count.lock().map(|m| *m).unwrap_or(0);
245
246 base_stats.hit_ratio = if hits + misses > 0 {
247 hits as f64 / (hits + misses) as f64
248 } else {
249 0.0
250 };
251
252 base_stats
253 }
254
255 pub fn clear(&self) {
257 self.cache.clear();
258 if let Ok(mut hits) = self.hit_count.lock() {
259 *hits = 0;
260 }
261 if let Ok(mut misses) = self.miss_count.lock() {
262 *misses = 0;
263 }
264 }
265}
266
/// Two-tier key/value store: a bounded "hot" map in front of an unbounded backing map.
///
/// Every inserted pair is kept in the backing store; at most `max_size` pairs are also held
/// in the hot map. When the hot map is full, an arbitrary entry (whichever `HashMap`
/// iteration yields first) is evicted, so this is not LRU. Poisoned locks are skipped
/// silently (best effort).
pub struct PatternCache<K, V> {
    cache: Arc<Mutex<HashMap<K, V>>>,
    backing_store: Arc<Mutex<HashMap<K, V>>>,
    max_size: usize,
}

impl<K, V> PatternCache<K, V>
where
    K: Hash + Eq + Clone,
    V: Clone,
{
    /// Creates an empty cache whose hot tier holds at most `max_size` entries.
    ///
    /// A `max_size` of zero disables the hot tier; lookups are then served from the
    /// backing store only.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            backing_store: Arc::new(Mutex::new(HashMap::new())),
            max_size,
        }
    }

    /// Puts `key`/`value` into the hot map without letting it grow beyond `max_size`.
    fn admit(&self, hot: &mut HashMap<K, V>, key: K, value: V) {
        if self.max_size == 0 {
            return;
        }
        // Overwriting a key that is already hot does not grow the map, so only a new key
        // may force an eviction.
        if hot.len() >= self.max_size && !hot.contains_key(&key) {
            if let Some(victim) = hot.keys().next().cloned() {
                hot.remove(&victim);
            }
        }
        hot.insert(key, value);
    }

    /// Stores `value` under `key` in the backing store and in the hot map.
    pub fn insert(&self, key: K, value: V) {
        if let Ok(mut store) = self.backing_store.lock() {
            store.insert(key.clone(), value.clone());
        }
        if let Ok(mut hot) = self.cache.lock() {
            self.admit(&mut hot, key, value);
        }
    }

    /// Returns a clone of the value for `key`, checking the hot map first.
    ///
    /// A value found only in the backing store is promoted into the hot map.
    pub fn get(&self, key: &K) -> Option<V> {
        if let Ok(hot) = self.cache.lock() {
            if let Some(value) = hot.get(key) {
                return Some(value.clone());
            }
        }

        // Keep the backing-store lock while promoting, so a concurrent `insert` cannot update
        // the backing store between this read and the promotion (lock order: store, then hot).
        let store = self.backing_store.lock().ok()?;
        let value = store.get(key)?.clone();
        if let Ok(mut hot) = self.cache.lock() {
            self.admit(&mut hot, key.clone(), value.clone());
        }
        Some(value)
    }

    /// Reports whether `key` is present in either tier.
    pub fn contains(&self, key: &K) -> bool {
        if let Ok(hot) = self.cache.lock() {
            if hot.contains_key(key) {
                return true;
            }
        }
        if let Ok(store) = self.backing_store.lock() {
            return store.contains_key(key);
        }
        false
    }

    /// Removes every entry from both tiers.
    pub fn clear(&self) {
        // Hold both locks (store first, as in `get`) so a concurrent promotion cannot
        // re-populate the hot map between the two clears.
        let mut store = self.backing_store.lock();
        let mut hot = self.cache.lock();
        if let Ok(guard) = hot.as_mut() {
            guard.clear();
        }
        if let Ok(guard) = store.as_mut() {
            guard.clear();
        }
    }

    /// Returns the current sizes of both tiers (zero for a tier whose lock is poisoned).
    pub fn stats(&self) -> PatternCacheStats {
        let cache_size = self.cache.lock().map(|c| c.len()).unwrap_or(0);
        let backing_size = self.backing_store.lock().map(|s| s.len()).unwrap_or(0);

        PatternCacheStats {
            cache_size,
            backing_size,
            max_cache_size: self.max_size,
            // Hits are not tracked by `PatternCache`.
            cache_hit_ratio: 0.0,
        }
    }
}

/// Snapshot of a TTL-bounded cache's occupancy and lookup effectiveness.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Maximum number of entries the cache can hold.
    pub capacity: usize,
    /// Entries currently stored, including expired ones not yet removed.
    pub size: usize,
    /// Stored entries whose TTL has already elapsed.
    pub expired_count: usize,
    /// Fraction of lookups that hit, in `[0.0, 1.0]`; `0.0` when no lookups were recorded
    /// or the producing cache does not count lookups.
    pub hit_ratio: f64,
}

/// Snapshot of a [`PatternCache`]'s tier sizes.
#[derive(Debug, Clone)]
pub struct PatternCacheStats {
    /// Entries currently held in the hot tier.
    pub cache_size: usize,
    /// Entries held in the backing store.
    pub backing_size: usize,
    /// Configured hot-tier limit.
    pub max_cache_size: usize,
    /// Hot-tier hit ratio; currently always `0.0` because hits are not tracked.
    pub cache_hit_ratio: f64,
}

/// Write-behind map: inserts accumulate in a pending buffer and are moved into the main
/// map once `batch_size` pending entries exist, or when [`BatchCache::flush_batch`] is called.
///
/// Lookups see pending entries immediately. Whenever both locks are held at once they are
/// taken in the order `pending_inserts`, then `cache`. Poisoned locks are skipped silently.
pub struct BatchCache<K, V> {
    cache: Arc<Mutex<HashMap<K, V>>>,
    batch_size: usize,
    pending_inserts: Arc<Mutex<HashMap<K, V>>>,
}

impl<K, V> BatchCache<K, V>
where
    K: Hash + Eq + Clone,
    V: Clone,
{
    /// Creates an empty cache that flushes automatically every `batch_size` pending inserts.
    pub fn new(batch_size: usize) -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            batch_size,
            pending_inserts: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Buffers `key`/`value`; flushes the whole buffer once it holds `batch_size` entries.
    pub fn batch_insert(&self, key: K, value: V) {
        if let Ok(mut pending) = self.pending_inserts.lock() {
            pending.insert(key, value);
            if pending.len() >= self.batch_size {
                // Flush through the guard we already hold: calling `flush_batch` here would
                // lock `pending_inserts` a second time on this thread, which a std `Mutex`
                // never returns from.
                self.drain_into_cache(&mut pending);
            }
        }
    }

    /// Moves every pending entry into the main map.
    pub fn flush_batch(&self) {
        if let Ok(mut pending) = self.pending_inserts.lock() {
            self.drain_into_cache(&mut pending);
        }
    }

    /// Moves `pending` into the main map. The caller must hold the `pending_inserts` lock.
    ///
    /// If the main map's lock is poisoned, the pending entries are left in place.
    fn drain_into_cache(&self, pending: &mut HashMap<K, V>) {
        if pending.is_empty() {
            return;
        }
        if let Ok(mut cache) = self.cache.lock() {
            cache.extend(pending.drain());
        }
    }

    /// Returns a clone of the most recent value for `key`, pending or flushed.
    pub fn get(&self, key: &K) -> Option<V> {
        // Pending first: it holds the newest write for a key. Entries only ever move
        // pending -> cache, so this order also cannot miss an entry flushed concurrently.
        if let Ok(pending) = self.pending_inserts.lock() {
            if let Some(value) = pending.get(key) {
                return Some(value.clone());
            }
        }
        if let Ok(cache) = self.cache.lock() {
            if let Some(value) = cache.get(key) {
                return Some(value.clone());
            }
        }
        None
    }

    /// Discards all pending and flushed entries.
    pub fn clear(&self) {
        // Hold both locks (in flush order) so a concurrent flush cannot refill the main map
        // between the two clears.
        let mut pending = self.pending_inserts.lock();
        let mut cache = self.cache.lock();
        if let Ok(guard) = pending.as_mut() {
            guard.clear();
        }
        if let Ok(guard) = cache.as_mut() {
            guard.clear();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    const START_FEN: &str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
    const EMPTY_BOARD_FEN: &str = "8/8/8/8/8/8/8/8 w - - 0 1";

    #[test]
    fn test_timed_lru_cache() {
        let lru = TimedLruCache::new(3, Duration::from_millis(100));
        let seed = [("key1", "value1"), ("key2", "value2"), ("key3", "value3")];

        for (key, value) in seed {
            lru.insert(key, value);
        }
        // Reading back in insertion order leaves "key1" as the least recently used entry.
        for (key, value) in seed {
            assert_eq!(lru.get(&key), Some(value));
        }

        // A fourth insert into a capacity-3 cache evicts the least recently used entry.
        lru.insert("key4", "value4");
        assert_eq!(lru.get(&"key1"), None);
        assert_eq!(lru.get(&"key4"), Some("value4"));
    }

    #[test]
    fn test_similarity_cache() {
        let similarities = SimilarityCache::new(100, Duration::from_secs(1));
        similarities.store_similarity(1, 2, 0.8);

        // The position pair is unordered.
        assert_eq!(similarities.get_similarity(1, 2), Some(0.8));
        assert_eq!(similarities.get_similarity(2, 1), Some(0.8));
        assert_eq!(similarities.get_similarity(3, 4), None);

        // Two hits, one miss.
        assert_eq!(similarities.stats().hit_ratio, 2.0 / 3.0);
    }

    #[test]
    fn test_evaluation_cache() {
        let evaluations = EvaluationCache::new(100, Duration::from_secs(1));
        evaluations.store_evaluation(START_FEN, 0.0);

        assert_eq!(evaluations.get_evaluation(START_FEN), Some(0.0));
        assert_eq!(evaluations.get_evaluation(EMPTY_BOARD_FEN), None);
    }

    #[test]
    fn test_pattern_cache() {
        let patterns = PatternCache::new(2);
        patterns.insert("pattern1", "data1");
        patterns.insert("pattern2", "data2");
        assert_eq!(patterns.get(&"pattern1"), Some("data1"));
        assert_eq!(patterns.get(&"pattern2"), Some("data2"));

        // Overflowing the hot tier must not lose data: the backing store keeps every entry.
        patterns.insert("pattern3", "data3");
        for (name, data) in [("pattern1", "data1"), ("pattern2", "data2"), ("pattern3", "data3")] {
            assert_eq!(patterns.get(&name), Some(data));
        }
    }

    #[test]
    fn test_batch_cache() {
        let batched = BatchCache::new(2);
        batched.batch_insert("key1", "value1");
        batched.batch_insert("key2", "value2");
        assert_eq!(batched.get(&"key1"), Some("value1"));
        assert_eq!(batched.get(&"key2"), Some("value2"));

        batched.batch_insert("key3", "value3");
        for (key, value) in [("key1", "value1"), ("key2", "value2"), ("key3", "value3")] {
            assert_eq!(batched.get(&key), Some(value));
        }
    }
}