chess_vector_engine/utils/
cache.rs

use lru::LruCache;
use std::collections::HashMap;
use std::hash::Hash;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
7
/// Thread-safe LRU cache with time-based expiration.
///
/// Entries older than `ttl` are treated as absent; they are evicted lazily
/// on read (`get`) or in bulk via `cleanup_expired`.
pub struct TimedLruCache<K, V> {
    // Mutex-guarded LRU store mapping keys to timestamped entries.
    cache: Arc<Mutex<LruCache<K, CacheEntry<V>>>>,
    // Maximum age before an entry is considered expired.
    ttl: Duration,
}

/// Cache entry with timestamp for TTL support
#[derive(Debug, Clone)]
struct CacheEntry<V> {
    // The cached value itself.
    value: V,
    // Insertion time; compared against the owning cache's TTL on read.
    timestamp: Instant,
}
20
21impl<K, V> TimedLruCache<K, V>
22where
23    K: Hash + Eq + Clone,
24    V: Clone,
25{
26    /// Create a new timed LRU cache
27    pub fn new(capacity: usize, ttl: Duration) -> Self {
28        let non_zero_capacity =
29            NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).unwrap());
30        Self {
31            cache: Arc::new(Mutex::new(LruCache::new(non_zero_capacity))),
32            ttl,
33        }
34    }
35
36    /// Insert a value into the cache
37    pub fn insert(&self, key: K, value: V) {
38        let entry = CacheEntry {
39            value,
40            timestamp: Instant::now(),
41        };
42
43        if let Ok(mut cache) = self.cache.lock() {
44            cache.put(key, entry);
45        }
46    }
47
48    /// Get a value from the cache
49    pub fn get(&self, key: &K) -> Option<V> {
50        if let Ok(mut cache) = self.cache.lock() {
51            if let Some(entry) = cache.get(key) {
52                // Check if entry has expired
53                if entry.timestamp.elapsed() < self.ttl {
54                    return Some(entry.value.clone());
55                } else {
56                    // Remove expired entry
57                    cache.pop(key);
58                }
59            }
60        }
61        None
62    }
63
64    /// Check if a key exists in cache (without updating LRU order)
65    pub fn contains(&self, key: &K) -> bool {
66        if let Ok(cache) = self.cache.lock() {
67            if let Some(entry) = cache.peek(key) {
68                return entry.timestamp.elapsed() < self.ttl;
69            }
70        }
71        false
72    }
73
74    /// Clear all entries from the cache
75    pub fn clear(&self) {
76        if let Ok(mut cache) = self.cache.lock() {
77            cache.clear();
78        }
79    }
80
81    /// Get cache statistics
82    pub fn stats(&self) -> CacheStats {
83        if let Ok(cache) = self.cache.lock() {
84            let capacity = cache.cap().get();
85            let size = cache.len();
86            let expired_count = cache
87                .iter()
88                .filter(|(_, entry)| entry.timestamp.elapsed() >= self.ttl)
89                .count();
90
91            CacheStats {
92                capacity,
93                size,
94                expired_count,
95                hit_ratio: 0.0, // Would need hit/miss tracking for accurate ratio
96            }
97        } else {
98            CacheStats {
99                capacity: 0,
100                size: 0,
101                expired_count: 0,
102                hit_ratio: 0.0,
103            }
104        }
105    }
106
107    /// Clean up expired entries
108    pub fn cleanup_expired(&self) {
109        if let Ok(mut cache) = self.cache.lock() {
110            let now = Instant::now();
111            let expired_keys: Vec<K> = cache
112                .iter()
113                .filter(|(_, entry)| now.duration_since(entry.timestamp) >= self.ttl)
114                .map(|(k, _)| k.clone())
115                .collect();
116
117            for key in expired_keys {
118                cache.pop(&key);
119            }
120        }
121    }
122}
123
124/// High-performance similarity cache for chess positions
125pub struct SimilarityCache {
126    cache: TimedLruCache<(usize, usize), f32>,
127    hit_count: Arc<Mutex<u64>>,
128    miss_count: Arc<Mutex<u64>>,
129}
130
131impl SimilarityCache {
132    /// Create a new similarity cache
133    pub fn new(capacity: usize, ttl: Duration) -> Self {
134        Self {
135            cache: TimedLruCache::new(capacity, ttl),
136            hit_count: Arc::new(Mutex::new(0)),
137            miss_count: Arc::new(Mutex::new(0)),
138        }
139    }
140
141    /// Get similarity from cache
142    pub fn get_similarity(&self, pos1: usize, pos2: usize) -> Option<f32> {
143        // Normalize key order (similarity is symmetric)
144        let key = if pos1 <= pos2 {
145            (pos1, pos2)
146        } else {
147            (pos2, pos1)
148        };
149
150        if let Some(similarity) = self.cache.get(&key) {
151            if let Ok(mut hits) = self.hit_count.lock() {
152                *hits += 1;
153            }
154            Some(similarity)
155        } else {
156            if let Ok(mut misses) = self.miss_count.lock() {
157                *misses += 1;
158            }
159            None
160        }
161    }
162
163    /// Store similarity in cache
164    pub fn store_similarity(&self, pos1: usize, pos2: usize, similarity: f32) {
165        // Normalize key order (similarity is symmetric)
166        let key = if pos1 <= pos2 {
167            (pos1, pos2)
168        } else {
169            (pos2, pos1)
170        };
171        self.cache.insert(key, similarity);
172    }
173
174    /// Get cache statistics with hit ratio
175    pub fn stats(&self) -> CacheStats {
176        let mut base_stats = self.cache.stats();
177
178        let hits = self.hit_count.lock().map(|h| *h).unwrap_or(0);
179        let misses = self.miss_count.lock().map(|m| *m).unwrap_or(0);
180
181        base_stats.hit_ratio = if hits + misses > 0 {
182            hits as f64 / (hits + misses) as f64
183        } else {
184            0.0
185        };
186
187        base_stats
188    }
189
190    /// Clear cache and reset statistics
191    pub fn clear(&self) {
192        self.cache.clear();
193        if let Ok(mut hits) = self.hit_count.lock() {
194            *hits = 0;
195        }
196        if let Ok(mut misses) = self.miss_count.lock() {
197            *misses = 0;
198        }
199    }
200}
201
202/// Evaluation result cache for chess positions
203pub struct EvaluationCache {
204    cache: TimedLruCache<String, f32>,
205    hit_count: Arc<Mutex<u64>>,
206    miss_count: Arc<Mutex<u64>>,
207}
208
209impl EvaluationCache {
210    /// Create a new evaluation cache
211    pub fn new(capacity: usize, ttl: Duration) -> Self {
212        Self {
213            cache: TimedLruCache::new(capacity, ttl),
214            hit_count: Arc::new(Mutex::new(0)),
215            miss_count: Arc::new(Mutex::new(0)),
216        }
217    }
218
219    /// Get evaluation from cache using FEN string as key
220    pub fn get_evaluation(&self, fen: &str) -> Option<f32> {
221        if let Some(evaluation) = self.cache.get(&fen.to_string()) {
222            if let Ok(mut hits) = self.hit_count.lock() {
223                *hits += 1;
224            }
225            Some(evaluation)
226        } else {
227            if let Ok(mut misses) = self.miss_count.lock() {
228                *misses += 1;
229            }
230            None
231        }
232    }
233
234    /// Store evaluation in cache
235    pub fn store_evaluation(&self, fen: &str, evaluation: f32) {
236        self.cache.insert(fen.to_string(), evaluation);
237    }
238
239    /// Get cache statistics with hit ratio
240    pub fn stats(&self) -> CacheStats {
241        let mut base_stats = self.cache.stats();
242
243        let hits = self.hit_count.lock().map(|h| *h).unwrap_or(0);
244        let misses = self.miss_count.lock().map(|m| *m).unwrap_or(0);
245
246        base_stats.hit_ratio = if hits + misses > 0 {
247            hits as f64 / (hits + misses) as f64
248        } else {
249            0.0
250        };
251
252        base_stats
253    }
254
255    /// Clear cache and reset statistics
256    pub fn clear(&self) {
257        self.cache.clear();
258        if let Ok(mut hits) = self.hit_count.lock() {
259            *hits = 0;
260        }
261        if let Ok(mut misses) = self.miss_count.lock() {
262            *misses = 0;
263        }
264    }
265}
266
267/// Write-through cache for pattern data
268pub struct PatternCache<K, V> {
269    cache: Arc<Mutex<HashMap<K, V>>>,
270    backing_store: Arc<Mutex<HashMap<K, V>>>,
271    max_size: usize,
272}
273
274impl<K, V> PatternCache<K, V>
275where
276    K: Hash + Eq + Clone,
277    V: Clone,
278{
279    /// Create a new pattern cache with backing store
280    pub fn new(max_size: usize) -> Self {
281        Self {
282            cache: Arc::new(Mutex::new(HashMap::new())),
283            backing_store: Arc::new(Mutex::new(HashMap::new())),
284            max_size,
285        }
286    }
287
288    /// Insert a value (writes to both cache and backing store)
289    pub fn insert(&self, key: K, value: V) {
290        // Write to backing store first
291        if let Ok(mut store) = self.backing_store.lock() {
292            store.insert(key.clone(), value.clone());
293        }
294
295        // Then write to cache
296        if let Ok(mut cache) = self.cache.lock() {
297            // If cache is full, remove random entry
298            if cache.len() >= self.max_size {
299                if let Some(key_to_remove) = cache.keys().next().cloned() {
300                    cache.remove(&key_to_remove);
301                }
302            }
303            cache.insert(key, value);
304        }
305    }
306
307    /// Get a value (checks cache first, then backing store)
308    pub fn get(&self, key: &K) -> Option<V> {
309        // Check cache first
310        if let Ok(cache) = self.cache.lock() {
311            if let Some(value) = cache.get(key) {
312                return Some(value.clone());
313            }
314        }
315
316        // Check backing store
317        if let Ok(store) = self.backing_store.lock() {
318            if let Some(value) = store.get(key) {
319                let value = value.clone();
320
321                // Promote to cache
322                if let Ok(mut cache) = self.cache.lock() {
323                    if cache.len() >= self.max_size {
324                        if let Some(key_to_remove) = cache.keys().next().cloned() {
325                            cache.remove(&key_to_remove);
326                        }
327                    }
328                    cache.insert(key.clone(), value.clone());
329                }
330
331                return Some(value);
332            }
333        }
334
335        None
336    }
337
338    /// Check if key exists (cache or backing store)
339    pub fn contains(&self, key: &K) -> bool {
340        if let Ok(cache) = self.cache.lock() {
341            if cache.contains_key(key) {
342                return true;
343            }
344        }
345
346        if let Ok(store) = self.backing_store.lock() {
347            return store.contains_key(key);
348        }
349
350        false
351    }
352
353    /// Clear both cache and backing store
354    pub fn clear(&self) {
355        if let Ok(mut cache) = self.cache.lock() {
356            cache.clear();
357        }
358        if let Ok(mut store) = self.backing_store.lock() {
359            store.clear();
360        }
361    }
362
363    /// Get cache statistics
364    pub fn stats(&self) -> PatternCacheStats {
365        let cache_size = self.cache.lock().map(|c| c.len()).unwrap_or(0);
366        let backing_size = self.backing_store.lock().map(|s| s.len()).unwrap_or(0);
367
368        PatternCacheStats {
369            cache_size,
370            backing_size,
371            max_cache_size: self.max_size,
372            cache_hit_ratio: 0.0, // Would need hit/miss tracking
373        }
374    }
375}
376
/// Cache statistics snapshot for `TimedLruCache` and the caches built on it.
#[derive(Debug, Clone)]
pub struct CacheStats {
    // Maximum number of entries the cache can hold.
    pub capacity: usize,
    // Current number of entries (including any not-yet-evicted expired ones).
    pub size: usize,
    // Entries past their TTL at the moment the snapshot was taken.
    pub expired_count: usize,
    // Hits / (hits + misses); 0.0 where no hit/miss tracking exists.
    pub hit_ratio: f64,
}
385
/// Pattern cache statistics snapshot.
#[derive(Debug, Clone)]
pub struct PatternCacheStats {
    // Current number of entries in the bounded in-memory tier.
    pub cache_size: usize,
    // Current number of entries in the unbounded backing store.
    pub backing_size: usize,
    // Configured capacity of the in-memory tier.
    pub max_cache_size: usize,
    // Always 0.0 currently: PatternCache does not track hits/misses.
    pub cache_hit_ratio: f64,
}
394
/// Batch cache operations for improved performance.
///
/// Inserts accumulate in a pending map and are moved into the main cache
/// once `batch_size` entries have built up, or on an explicit
/// `flush_batch` call.
pub struct BatchCache<K, V> {
    cache: Arc<Mutex<HashMap<K, V>>>,
    batch_size: usize,
    pending_inserts: Arc<Mutex<HashMap<K, V>>>,
}

impl<K, V> BatchCache<K, V>
where
    K: Hash + Eq + Clone,
    V: Clone,
{
    /// Create a new batch cache that auto-flushes every `batch_size` inserts.
    pub fn new(batch_size: usize) -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            batch_size,
            pending_inserts: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Add item to pending batch, flushing to the main cache when full.
    ///
    /// BUG FIX: the previous version called `flush_batch()` while still
    /// holding the `pending_inserts` lock. `std::sync::Mutex` is not
    /// reentrant, so re-locking `pending_inserts` inside `flush_batch`
    /// deadlocked the calling thread as soon as a batch filled. We now
    /// drain the batch under the lock, release it, and only then touch
    /// the main cache.
    pub fn batch_insert(&self, key: K, value: V) {
        let drained: Option<Vec<(K, V)>> = match self.pending_inserts.lock() {
            Ok(mut pending) => {
                pending.insert(key, value);
                if pending.len() >= self.batch_size {
                    Some(pending.drain().collect())
                } else {
                    None
                }
            }
            Err(_) => None,
        };

        // The pending lock is released here, so locking the main cache
        // cannot self-deadlock.
        if let Some(entries) = drained {
            if let Ok(mut cache) = self.cache.lock() {
                cache.extend(entries);
            }
        }
    }

    /// Flush pending batch to main cache.
    pub fn flush_batch(&self) {
        if let (Ok(mut cache), Ok(mut pending)) = (self.cache.lock(), self.pending_inserts.lock()) {
            cache.extend(pending.drain());
        }
    }

    /// Get value from cache (including pending batch).
    pub fn get(&self, key: &K) -> Option<V> {
        // Check main cache first
        if let Ok(cache) = self.cache.lock() {
            if let Some(value) = cache.get(key) {
                return Some(value.clone());
            }
        }

        // Check pending batch
        if let Ok(pending) = self.pending_inserts.lock() {
            if let Some(value) = pending.get(key) {
                return Some(value.clone());
            }
        }

        None
    }

    /// Clear cache and pending batch.
    pub fn clear(&self) {
        if let Ok(mut cache) = self.cache.lock() {
            cache.clear();
        }
        if let Ok(mut pending) = self.pending_inserts.lock() {
            pending.clear();
        }
    }
}
466
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_timed_lru_cache() {
        let cache = TimedLruCache::new(3, Duration::from_millis(100));

        // Insert values
        cache.insert("key1", "value1");
        cache.insert("key2", "value2");
        cache.insert("key3", "value3");

        // Should be able to retrieve all
        assert_eq!(cache.get(&"key1"), Some("value1"));
        assert_eq!(cache.get(&"key2"), Some("value2"));
        assert_eq!(cache.get(&"key3"), Some("value3"));

        // Insert one more (should evict LRU). The gets above refreshed
        // recency in order key1, key2, key3 — so key1 is least-recently-used.
        cache.insert("key4", "value4");
        assert_eq!(cache.get(&"key1"), None); // Should be evicted
        assert_eq!(cache.get(&"key4"), Some("value4"));
    }

    #[test]
    fn test_similarity_cache() {
        let cache = SimilarityCache::new(100, Duration::from_secs(1));

        // Store similarity
        cache.store_similarity(1, 2, 0.8);

        // Should be able to retrieve it (order shouldn't matter —
        // keys are normalized to (min, max))
        assert_eq!(cache.get_similarity(1, 2), Some(0.8));
        assert_eq!(cache.get_similarity(2, 1), Some(0.8));

        // Non-existent similarity
        assert_eq!(cache.get_similarity(3, 4), None);

        // Check statistics
        let stats = cache.stats();
        assert_eq!(stats.hit_ratio, 2.0 / 3.0); // 2 hits out of 3 total requests
    }

    #[test]
    fn test_evaluation_cache() {
        let cache = EvaluationCache::new(100, Duration::from_secs(1));

        // Store evaluation (keyed by the standard starting-position FEN)
        cache.store_evaluation(
            "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
            0.0,
        );

        // Should be able to retrieve it
        assert_eq!(
            cache.get_evaluation("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"),
            Some(0.0)
        );

        // Non-existent evaluation
        assert_eq!(cache.get_evaluation("8/8/8/8/8/8/8/8 w - - 0 1"), None);
    }

    #[test]
    fn test_pattern_cache() {
        let cache = PatternCache::new(2);

        // Insert values
        cache.insert("pattern1", "data1");
        cache.insert("pattern2", "data2");

        // Should be able to retrieve
        assert_eq!(cache.get(&"pattern1"), Some("data1"));
        assert_eq!(cache.get(&"pattern2"), Some("data2"));

        // Insert one more (should evict from cache but keep in backing store)
        cache.insert("pattern3", "data3");

        // Should still be able to retrieve all (from backing store)
        assert_eq!(cache.get(&"pattern1"), Some("data1"));
        assert_eq!(cache.get(&"pattern2"), Some("data2"));
        assert_eq!(cache.get(&"pattern3"), Some("data3"));
    }

    #[test]
    fn test_batch_cache() {
        let cache = BatchCache::new(2);

        // Add items to batch. The second insert fills the batch
        // (batch_size == 2), which flushes both entries to the main cache.
        cache.batch_insert("key1", "value1");
        cache.batch_insert("key2", "value2");

        // Retrievable regardless of which tier holds them — get() checks
        // the main cache first, then the pending batch.
        assert_eq!(cache.get(&"key1"), Some("value1"));
        assert_eq!(cache.get(&"key2"), Some("value2"));

        // Add one more; it stays in the (now empty) pending batch until
        // the next flush — no flush is triggered here.
        cache.batch_insert("key3", "value3");

        // Should still be able to retrieve all
        assert_eq!(cache.get(&"key1"), Some("value1"));
        assert_eq!(cache.get(&"key2"), Some("value2"));
        assert_eq!(cache.get(&"key3"), Some("value3"));
    }
}