// litellm_rs/core/traits/cache.rs

//! Cache system trait definitions.
//!
//! Provides a unified cache interface that supports multiple cache backends.

5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::time::Duration;
8
9/// Core cache trait
10///
11/// Defines unified cache operation interface
12#[async_trait]
13pub trait Cache<K, V>: Send + Sync
14where
15    K: Send + Sync,
16    V: Send + Sync,
17{
18    /// Error
19    type Error: std::error::Error + Send + Sync + 'static;
20
21    /// Get
22    async fn get(&self, key: &K) -> Result<Option<V>, Self::Error>;
23
24    /// Settings
25    async fn set(&self, key: &K, value: V, ttl: Duration) -> Result<(), Self::Error>;
26
27    /// Delete
28    async fn delete(&self, key: &K) -> Result<bool, Self::Error>;
29
30    /// Check
31    async fn exists(&self, key: &K) -> Result<bool, Self::Error>;
32
33    /// Settings
34    async fn expire(&self, key: &K, ttl: Duration) -> Result<bool, Self::Error>;
35
36    /// Get
37    async fn ttl(&self, key: &K) -> Result<Option<Duration>, Self::Error>;
38
39    /// Clear all cache
40    async fn clear(&self) -> Result<(), Self::Error>;
41
42    /// Get
43    async fn size(&self) -> Result<usize, Self::Error>;
44
45    /// Get
46    async fn get_many(&self, keys: &[K]) -> Result<Vec<Option<V>>, Self::Error> {
47        let mut results = Vec::with_capacity(keys.len());
48        for key in keys {
49            results.push(self.get(key).await?);
50        }
51        Ok(results)
52    }
53
54    /// Settings
55    async fn set_many(&self, items: &[(K, V, Duration)]) -> Result<(), Self::Error>
56    where
57        K: Clone,
58        V: Clone,
59    {
60        for (key, value, ttl) in items {
61            self.set(key, value.clone(), *ttl).await?;
62        }
63        Ok(())
64    }
65}
66
67/// Cache key trait
68///
69/// Defines operations that cache keys must support
70pub trait CacheKey: Send + Sync + Clone + std::fmt::Debug + std::hash::Hash + Eq {
71    /// Serialize key to string
72    fn to_cache_key(&self) -> String;
73
74    /// Deserialize key from string
75    fn from_cache_key(s: &str) -> Result<Self, CacheError>
76    where
77        Self: Sized;
78}
79
80/// Cache value trait
81///
82/// Defines operations that cache values must support
83pub trait CacheValue: Send + Sync + Clone + std::fmt::Debug {
84    /// Serialize to bytes
85    fn to_bytes(&self) -> Result<Vec<u8>, CacheError>;
86
87    /// Deserialize from bytes
88    fn from_bytes(bytes: &[u8]) -> Result<Self, CacheError>
89    where
90        Self: Sized;
91}
92
93/// Implementation of CacheKey for String
94impl CacheKey for String {
95    fn to_cache_key(&self) -> String {
96        self.clone()
97    }
98
99    fn from_cache_key(s: &str) -> Result<Self, CacheError> {
100        Ok(s.to_string())
101    }
102}
103
104/// Implementation of CacheValue for all types that implement Serialize + DeserializeOwned
105impl<T> CacheValue for T
106where
107    T: Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + std::fmt::Debug,
108{
109    fn to_bytes(&self) -> Result<Vec<u8>, CacheError> {
110        bincode::serialize(self).map_err(CacheError::Serialization)
111    }
112
113    fn from_bytes(bytes: &[u8]) -> Result<Self, CacheError> {
114        bincode::deserialize(bytes).map_err(CacheError::Deserialization)
115    }
116}
117
/// Cache statistics.
///
/// Plain counters with public fields so implementations can fill them in
/// directly. `hit_rate` is derived from `hits` and `misses` and is only
/// refreshed by [`CacheStats::calculate_hit_rate`].
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Cache hit count.
    pub hits: u64,
    /// Cache miss count.
    pub misses: u64,
    /// Current key count.
    pub key_count: usize,
    /// Used memory amount, in bytes.
    pub memory_usage: usize,
    /// Hit rate, `hits / (hits + misses)`, as of the last call to
    /// [`CacheStats::calculate_hit_rate`]. That call sets it to `0.0` when no
    /// lookups were recorded.
    pub hit_rate: f64,
}

impl CacheStats {
    /// Creates statistics with every counter, and the hit rate, set to zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Recomputes `hit_rate` from `hits` and `misses`.
    ///
    /// When no lookups have been recorded (`hits + misses == 0`), the rate is
    /// reset to `0.0`. This keeps a rate computed before the counters were
    /// cleared from lingering.
    pub fn calculate_hit_rate(&mut self) {
        // Sum in f64: adding the two u64 counters directly could overflow
        // (panicking in debug builds) as they approach u64::MAX.
        let total = self.hits as f64 + self.misses as f64;
        self.hit_rate = if total > 0.0 {
            self.hits as f64 / total
        } else {
            0.0
        };
    }
}

158/// Cache trait with statistics functionality
159#[async_trait]
160pub trait CacheWithStats<K, V>: Cache<K, V>
161where
162    K: Send + Sync,
163    V: Send + Sync,
164{
165    /// Get
166    async fn stats(&self) -> Result<CacheStats, Self::Error>;
167
168    /// Reset statistics
169    async fn reset_stats(&self) -> Result<(), Self::Error>;
170}
171
/// A cache operation, as passed to `CacheEventListener::on_event`.
///
/// `K` is the key type and `V` the value type of the cache involved.
#[derive(Debug, Clone)]
pub enum CacheEvent<K, V> {
    /// Cache hit: a lookup of `key` found a value.
    Hit {
        /// Key that was looked up.
        key: K,
    },
    /// Cache miss: a lookup of `key` found nothing.
    Miss {
        /// Key that was looked up.
        key: K,
    },
    /// `value` was stored under `key`.
    Set {
        /// Key that was written.
        key: K,
        /// Value that was written.
        value: V,
    },
    /// `key` was deleted.
    Delete {
        /// Key that was deleted.
        key: K,
    },
    /// `key` expired.
    Expire {
        /// Key that expired.
        key: K,
    },
    /// The whole cache was cleared.
    Clear,
}

189/// Cache event listener
190#[async_trait]
191pub trait CacheEventListener<K, V>: Send + Sync
192where
193    K: Send + Sync,
194    V: Send + Sync,
195{
196    /// Handle
197    async fn on_event(&self, event: CacheEvent<K, V>);
198}
199
200/// Error
201#[derive(Debug, thiserror::Error)]
202pub enum CacheError {
203    #[error("Connection failed: {0}")]
204    Connection(String),
205
206    #[error("Serialization failed: {0}")]
207    Serialization(#[from] Box<bincode::ErrorKind>),
208
209    #[error("Deserialization failed: {0}")]
210    Deserialization(Box<bincode::ErrorKind>),
211
212    #[error("Key not found: {key}")]
213    KeyNotFound { key: String },
214
215    #[error("Cache is full")]
216    CacheFull,
217
218    #[error("Invalid TTL: {ttl_ms}ms")]
219    InvalidTTL { ttl_ms: u64 },
220
221    #[error("Cache operation timeout")]
222    Timeout,
223
224    #[error("Cache backend error: {0}")]
225    Backend(String),
226
227    #[error("Other cache error: {0}")]
228    Other(String),
229}
230
231impl CacheError {
232    pub fn connection(msg: impl Into<String>) -> Self {
233        Self::Connection(msg.into())
234    }
235
236    pub fn key_not_found(key: impl Into<String>) -> Self {
237        Self::KeyNotFound { key: key.into() }
238    }
239
240    pub fn backend(msg: impl Into<String>) -> Self {
241        Self::Backend(msg.into())
242    }
243
244    pub fn other(msg: impl Into<String>) -> Self {
245        Self::Other(msg.into())
246    }
247}
248
// Unit tests for the key/value conversions, statistics, events and errors
// defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // ==================== CacheKey Tests ====================
    // `String` keys convert to and from their string form unchanged.

    #[test]
    fn test_string_cache_key_to_cache_key() {
        let key = "my-cache-key".to_string();
        assert_eq!(key.to_cache_key(), "my-cache-key");
    }

    #[test]
    fn test_string_cache_key_from_cache_key() {
        let key = String::from_cache_key("restored-key").unwrap();
        assert_eq!(key, "restored-key");
    }

    #[test]
    fn test_string_cache_key_roundtrip() {
        let original = "test-key-123".to_string();
        let serialized = original.to_cache_key();
        let restored = String::from_cache_key(&serialized).unwrap();
        assert_eq!(original, restored);
    }

    // ==================== CacheValue Tests ====================
    // These exercise the blanket bincode-backed `CacheValue` impl for serde types.

    #[test]
    fn test_cache_value_to_bytes_string() {
        let value = "hello world".to_string();
        let bytes = value.to_bytes();
        assert!(bytes.is_ok());
        assert!(!bytes.unwrap().is_empty());
    }

    #[test]
    fn test_cache_value_from_bytes_string() {
        let value = "test value".to_string();
        let bytes = value.to_bytes().unwrap();
        let restored = String::from_bytes(&bytes).unwrap();
        assert_eq!(value, restored);
    }

    #[test]
    fn test_cache_value_roundtrip_integer() {
        let value: i32 = 42;
        let bytes = value.to_bytes().unwrap();
        let restored = i32::from_bytes(&bytes).unwrap();
        assert_eq!(value, restored);
    }

    #[test]
    fn test_cache_value_roundtrip_complex() {
        // A user-defined struct picks up `CacheValue` through its serde derives.
        #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
        struct TestData {
            id: u64,
            name: String,
            active: bool,
        }

        let value = TestData {
            id: 123,
            name: "test".to_string(),
            active: true,
        };
        let bytes = value.to_bytes().unwrap();
        let restored = TestData::from_bytes(&bytes).unwrap();
        assert_eq!(value, restored);
    }

    // ==================== CacheStats Tests ====================
    // `hit_rate` is an f64, so it is compared with a tolerance rather than `==`.

    #[test]
    fn test_cache_stats_new() {
        let stats = CacheStats::new();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.key_count, 0);
        assert_eq!(stats.memory_usage, 0);
        assert!((stats.hit_rate - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cache_stats_default() {
        let stats = CacheStats::default();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_cache_stats_calculate_hit_rate_zero_total() {
        let mut stats = CacheStats::new();
        stats.calculate_hit_rate();
        assert!((stats.hit_rate - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cache_stats_calculate_hit_rate_all_hits() {
        let mut stats = CacheStats::new();
        stats.hits = 100;
        stats.misses = 0;
        stats.calculate_hit_rate();
        assert!((stats.hit_rate - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cache_stats_calculate_hit_rate_all_misses() {
        let mut stats = CacheStats::new();
        stats.hits = 0;
        stats.misses = 100;
        stats.calculate_hit_rate();
        assert!((stats.hit_rate - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cache_stats_calculate_hit_rate_mixed() {
        let mut stats = CacheStats::new();
        stats.hits = 75;
        stats.misses = 25;
        stats.calculate_hit_rate();
        assert!((stats.hit_rate - 0.75).abs() < 0.001);
    }

    #[test]
    fn test_cache_stats_clone() {
        let mut stats = CacheStats::new();
        stats.hits = 10;
        stats.key_count = 5;
        let cloned = stats.clone();
        assert_eq!(stats.hits, cloned.hits);
        assert_eq!(stats.key_count, cloned.key_count);
    }

    #[test]
    fn test_cache_stats_debug() {
        let stats = CacheStats::new();
        let debug = format!("{:?}", stats);
        assert!(debug.contains("CacheStats"));
    }

    // ==================== CacheEvent Tests ====================
    // `CacheEvent` derives only Debug and Clone (no PartialEq), so variants
    // are checked with `matches!` and guard expressions.

    #[test]
    fn test_cache_event_hit() {
        let event: CacheEvent<String, i32> = CacheEvent::Hit {
            key: "key1".to_string(),
        };
        assert!(matches!(event, CacheEvent::Hit { key } if key == "key1"));
    }

    #[test]
    fn test_cache_event_miss() {
        let event: CacheEvent<String, i32> = CacheEvent::Miss {
            key: "key2".to_string(),
        };
        assert!(matches!(event, CacheEvent::Miss { key } if key == "key2"));
    }

    #[test]
    fn test_cache_event_set() {
        let event = CacheEvent::Set {
            key: "key3".to_string(),
            value: 42,
        };
        assert!(matches!(event, CacheEvent::Set { key, value } if key == "key3" && value == 42));
    }

    #[test]
    fn test_cache_event_delete() {
        let event: CacheEvent<String, i32> = CacheEvent::Delete {
            key: "key4".to_string(),
        };
        assert!(matches!(event, CacheEvent::Delete { key } if key == "key4"));
    }

    #[test]
    fn test_cache_event_expire() {
        let event: CacheEvent<String, i32> = CacheEvent::Expire {
            key: "key5".to_string(),
        };
        assert!(matches!(event, CacheEvent::Expire { key } if key == "key5"));
    }

    #[test]
    fn test_cache_event_clear() {
        let event: CacheEvent<String, i32> = CacheEvent::Clear;
        assert!(matches!(event, CacheEvent::Clear));
    }

    #[test]
    fn test_cache_event_clone() {
        let event = CacheEvent::Set {
            key: "key".to_string(),
            value: 100,
        };
        let cloned = event.clone();
        assert!(matches!(cloned, CacheEvent::Set { key, value } if key == "key" && value == 100));
    }

    // ==================== CacheError Tests ====================
    // Assertions check substrings of the thiserror-generated Display messages.

    #[test]
    fn test_cache_error_connection() {
        let err = CacheError::connection("Redis connection failed");
        assert!(matches!(err, CacheError::Connection(_)));
        assert!(err.to_string().contains("Connection failed"));
    }

    #[test]
    fn test_cache_error_key_not_found() {
        let err = CacheError::key_not_found("missing-key");
        assert!(matches!(err, CacheError::KeyNotFound { .. }));
        assert!(err.to_string().contains("Key not found"));
        assert!(err.to_string().contains("missing-key"));
    }

    #[test]
    fn test_cache_error_cache_full() {
        let err = CacheError::CacheFull;
        assert!(err.to_string().contains("Cache is full"));
    }

    #[test]
    fn test_cache_error_invalid_ttl() {
        let err = CacheError::InvalidTTL { ttl_ms: 0 };
        assert!(err.to_string().contains("Invalid TTL"));
    }

    #[test]
    fn test_cache_error_timeout() {
        let err = CacheError::Timeout;
        assert!(err.to_string().contains("timeout"));
    }

    #[test]
    fn test_cache_error_backend() {
        let err = CacheError::backend("Backend failure");
        assert!(matches!(err, CacheError::Backend(_)));
        assert!(err.to_string().contains("Backend"));
    }

    #[test]
    fn test_cache_error_other() {
        let err = CacheError::other("Some other error");
        assert!(matches!(err, CacheError::Other(_)));
    }

    #[test]
    fn test_cache_error_display() {
        let err = CacheError::connection("test error");
        let display = format!("{}", err);
        assert!(!display.is_empty());
    }

    #[test]
    fn test_cache_error_debug() {
        let err = CacheError::CacheFull;
        let debug = format!("{:?}", err);
        assert!(debug.contains("CacheFull"));
    }
}