Skip to main content

rez_lsp_server/performance/
cache.rs

1//! Caching system for performance optimization.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::hash::Hash;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::RwLock;
9
/// One cached value plus the bookkeeping needed to decide when it has gone stale.
#[derive(Debug, Clone)]
struct CacheEntry<T> {
    /// The cached payload.
    value: T,
    /// Moment the entry was created; its age is measured from here.
    created_at: Instant,
    /// How long the entry remains valid after `created_at`.
    ttl: Duration,
}

impl<T> CacheEntry<T> {
    /// Wrap `value`, stamping it with the current instant and the given lifetime.
    fn new(value: T, ttl: Duration) -> Self {
        let created_at = Instant::now();
        Self {
            value,
            created_at,
            ttl,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since creation.
    fn is_expired(&self) -> bool {
        let age = self.created_at.elapsed();
        self.ttl < age
    }
}
31
32/// A thread-safe cache with TTL support.
33pub struct Cache<K, V> {
34    data: Arc<RwLock<HashMap<K, CacheEntry<V>>>>,
35    default_ttl: Duration,
36    max_size: usize,
37}
38
39impl<K, V> Cache<K, V>
40where
41    K: Eq + Hash + Clone,
42    V: Clone,
43{
44    /// Create a new cache with the given default TTL and maximum size.
45    pub fn new(default_ttl: Duration, max_size: usize) -> Self {
46        Self {
47            data: Arc::new(RwLock::new(HashMap::new())),
48            default_ttl,
49            max_size,
50        }
51    }
52
53    /// Get a value from the cache.
54    pub async fn get(&self, key: &K) -> Option<V> {
55        let data = self.data.read().await;
56        if let Some(entry) = data.get(key) {
57            if !entry.is_expired() {
58                return Some(entry.value.clone());
59            }
60        }
61        None
62    }
63
64    /// Put a value into the cache with default TTL.
65    pub async fn put(&self, key: K, value: V) {
66        self.put_with_ttl(key, value, self.default_ttl).await;
67    }
68
69    /// Put a value into the cache with custom TTL.
70    pub async fn put_with_ttl(&self, key: K, value: V, ttl: Duration) {
71        let mut data = self.data.write().await;
72
73        // Remove expired entries if we're at capacity
74        if data.len() >= self.max_size {
75            self.cleanup_expired(&mut data);
76
77            // If still at capacity, remove oldest entry
78            if data.len() >= self.max_size {
79                if let Some(oldest_key) = self.find_oldest_key(&data) {
80                    data.remove(&oldest_key);
81                }
82            }
83        }
84
85        data.insert(key, CacheEntry::new(value, ttl));
86    }
87
88    /// Remove a value from the cache.
89    pub async fn remove(&self, key: &K) -> Option<V> {
90        let mut data = self.data.write().await;
91        data.remove(key).map(|entry| entry.value)
92    }
93
94    /// Clear all entries from the cache.
95    pub async fn clear(&self) {
96        let mut data = self.data.write().await;
97        data.clear();
98    }
99
100    /// Get cache statistics.
101    pub async fn stats(&self) -> CacheStats {
102        let data = self.data.read().await;
103        let total_entries = data.len();
104        let expired_entries = data.values().filter(|entry| entry.is_expired()).count();
105        let active_entries = total_entries - expired_entries;
106
107        CacheStats {
108            total_entries,
109            active_entries,
110            expired_entries,
111            max_size: self.max_size,
112            hit_ratio: 0.0, // This would need to be tracked separately
113        }
114    }
115
116    /// Clean up expired entries.
117    fn cleanup_expired(&self, data: &mut HashMap<K, CacheEntry<V>>) {
118        data.retain(|_, entry| !entry.is_expired());
119    }
120
121    /// Find the oldest entry key.
122    fn find_oldest_key(&self, data: &HashMap<K, CacheEntry<V>>) -> Option<K> {
123        data.iter()
124            .min_by_key(|(_, entry)| entry.created_at)
125            .map(|(key, _)| key.clone())
126    }
127}
128
129/// Cache statistics.
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct CacheStats {
132    /// Total number of entries (including expired)
133    pub total_entries: usize,
134    /// Number of active (non-expired) entries
135    pub active_entries: usize,
136    /// Number of expired entries
137    pub expired_entries: usize,
138    /// Maximum cache size
139    pub max_size: usize,
140    /// Cache hit ratio (0.0 to 1.0)
141    pub hit_ratio: f64,
142}
143
144/// Cache manager that handles multiple caches.
145pub struct CacheManager {
146    /// Cache for package discovery results
147    package_cache: Cache<String, Vec<crate::core::types::Package>>,
148    /// Cache for validation results
149    validation_cache: Cache<String, crate::validation::ValidationResult>,
150    /// Cache for completion results
151    completion_cache: Cache<String, Vec<String>>,
152    /// Cache statistics
153    stats: Arc<RwLock<CacheManagerStats>>,
154}
155
/// Running counters maintained by `CacheManager`.
#[derive(Debug, Clone)]
struct CacheManagerStats {
    /// Lookups that returned a cached value.
    hits: u64,
    /// Lookups that returned nothing (absent or expired).
    misses: u64,
    /// Insertions across all managed caches.
    puts: u64,
    /// Evicted entries.
    // Nothing increments this yet: `Cache` evicts internally without reporting back,
    // so the field is never read or written after construction — presumably reserved
    // for future eviction tracking (TODO confirm). Hence the lint suppression.
    #[allow(dead_code)]
    evictions: u64,
}
164
165impl CacheManager {
166    /// Create a new cache manager.
167    pub fn new(config: &super::PerformanceConfig) -> Self {
168        let ttl = Duration::from_secs(config.cache_ttl_seconds);
169        let max_size = config.cache_size_mb * 1024 / 10; // Rough estimate of entries per MB
170
171        Self {
172            package_cache: Cache::new(ttl, max_size),
173            validation_cache: Cache::new(ttl, max_size / 2),
174            completion_cache: Cache::new(Duration::from_secs(60), max_size / 4), // Shorter TTL for completions
175            stats: Arc::new(RwLock::new(CacheManagerStats {
176                hits: 0,
177                misses: 0,
178                puts: 0,
179                evictions: 0,
180            })),
181        }
182    }
183
184    /// Get packages from cache.
185    pub async fn get_packages(&self, key: &str) -> Option<Vec<crate::core::types::Package>> {
186        let result = self.package_cache.get(&key.to_string()).await;
187        self.update_stats(result.is_some()).await;
188        result
189    }
190
191    /// Put packages into cache.
192    pub async fn put_packages(&self, key: String, packages: Vec<crate::core::types::Package>) {
193        self.package_cache.put(key, packages).await;
194        self.increment_puts().await;
195    }
196
197    /// Get validation result from cache.
198    pub async fn get_validation(&self, key: &str) -> Option<crate::validation::ValidationResult> {
199        let result = self.validation_cache.get(&key.to_string()).await;
200        self.update_stats(result.is_some()).await;
201        result
202    }
203
204    /// Put validation result into cache.
205    pub async fn put_validation(&self, key: String, result: crate::validation::ValidationResult) {
206        self.validation_cache.put(key, result).await;
207        self.increment_puts().await;
208    }
209
210    /// Get completion results from cache.
211    pub async fn get_completions(&self, key: &str) -> Option<Vec<String>> {
212        let result = self.completion_cache.get(&key.to_string()).await;
213        self.update_stats(result.is_some()).await;
214        result
215    }
216
217    /// Put completion results into cache.
218    pub async fn put_completions(&self, key: String, completions: Vec<String>) {
219        self.completion_cache.put(key, completions).await;
220        self.increment_puts().await;
221    }
222
223    /// Clear all caches.
224    pub async fn clear_all(&self) {
225        self.package_cache.clear().await;
226        self.validation_cache.clear().await;
227        self.completion_cache.clear().await;
228    }
229
230    /// Get overall cache statistics.
231    pub async fn get_stats(&self) -> CacheStats {
232        let stats = self.stats.read().await;
233        let total_requests = stats.hits + stats.misses;
234        let hit_ratio = if total_requests > 0 {
235            stats.hits as f64 / total_requests as f64
236        } else {
237            0.0
238        };
239
240        // Get individual cache stats
241        let package_stats = self.package_cache.stats().await;
242        let validation_stats = self.validation_cache.stats().await;
243        let completion_stats = self.completion_cache.stats().await;
244
245        CacheStats {
246            total_entries: package_stats.total_entries
247                + validation_stats.total_entries
248                + completion_stats.total_entries,
249            active_entries: package_stats.active_entries
250                + validation_stats.active_entries
251                + completion_stats.active_entries,
252            expired_entries: package_stats.expired_entries
253                + validation_stats.expired_entries
254                + completion_stats.expired_entries,
255            max_size: package_stats.max_size
256                + validation_stats.max_size
257                + completion_stats.max_size,
258            hit_ratio,
259        }
260    }
261
262    /// Update hit/miss statistics.
263    async fn update_stats(&self, is_hit: bool) {
264        let mut stats = self.stats.write().await;
265        if is_hit {
266            stats.hits += 1;
267        } else {
268            stats.misses += 1;
269        }
270    }
271
272    /// Increment put counter.
273    async fn increment_puts(&self) {
274        let mut stats = self.stats.write().await;
275        stats.puts += 1;
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    /// Shorthand for building owned `String` keys and values.
    fn owned(text: &str) -> String {
        text.to_string()
    }

    #[tokio::test]
    async fn test_cache_basic_operations() {
        let cache = Cache::new(Duration::from_secs(1), 10);
        let present = owned("key1");
        let absent = owned("key2");

        // A stored value can be read back.
        cache.put(present.clone(), owned("value1")).await;
        assert_eq!(cache.get(&present).await, Some(owned("value1")));

        // Unknown keys miss.
        assert_eq!(cache.get(&absent).await, None);

        // Removing hands back the value and makes later lookups miss.
        assert_eq!(cache.remove(&present).await, Some(owned("value1")));
        assert_eq!(cache.get(&present).await, None);
    }

    #[tokio::test]
    async fn test_cache_expiration() {
        let cache = Cache::new(Duration::from_millis(50), 10);
        let key = owned("key1");

        cache.put(key.clone(), owned("value1")).await;
        assert_eq!(cache.get(&key).await, Some(owned("value1")));

        // Sleep past the TTL so the entry is considered stale.
        sleep(Duration::from_millis(60)).await;
        assert_eq!(cache.get(&key).await, None);
    }

    #[tokio::test]
    async fn test_cache_size_limit() {
        let cache = Cache::new(Duration::from_secs(10), 2);

        for i in 1..=3 {
            cache.put(format!("key{}", i), format!("value{}", i)).await;
        }

        let total = cache.stats().await.total_entries;
        assert!(total <= 2);
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = Cache::new(Duration::from_secs(1), 10);

        for (key, value) in [("key1", "value1"), ("key2", "value2")] {
            cache.put(owned(key), owned(value)).await;
        }

        let stats = cache.stats().await;
        assert_eq!(stats.active_entries, 2);
        assert_eq!(stats.max_size, 10);
    }

    #[tokio::test]
    async fn test_cache_manager() {
        let config = super::super::PerformanceConfig::default();
        let manager = CacheManager::new(&config);
        let expected = vec![owned("comp1"), owned("comp2")];

        // Round-trip through the completion cache.
        manager
            .put_completions(owned("test_key"), expected.clone())
            .await;
        let completions = manager.get_completions("test_key").await;
        assert_eq!(completions, Some(expected));

        // The lookup above was a hit, so the ratio must be positive.
        let stats = manager.get_stats().await;
        assert!(stats.hit_ratio > 0.0);
    }
}