Skip to main content

mockforge_intelligence/intelligent_behavior/
cache.rs

1//! Response caching for intelligent behavior
2//!
3//! This module provides a simple cache for LLM responses to improve performance
4//! and reduce API costs.
5
6use std::collections::HashMap;
7use std::hash::{Hash, Hasher};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::RwLock;
11
/// A single cached value plus the bookkeeping needed to decide when it expires.
#[derive(Clone)]
struct CacheEntry<V> {
    /// The cached payload.
    value: V,
    /// Moment the entry was created; its age is measured from here.
    inserted_at: Instant,
    /// How long the entry stays valid after `inserted_at`.
    ttl: Duration,
}

impl<V> CacheEntry<V> {
    /// Wraps `value`, stamping it with the current time and the given lifetime.
    fn new(value: V, ttl: Duration) -> Self {
        let inserted_at = Instant::now();
        CacheEntry {
            value,
            inserted_at,
            ttl,
        }
    }

    /// Returns `true` once strictly more than `ttl` has passed since insertion.
    fn is_expired(&self) -> bool {
        let age = self.inserted_at.elapsed();
        self.ttl < age
    }
}
33
34/// Simple TTL-based cache for responses
35pub struct ResponseCache {
36    /// Cache storage
37    storage: Arc<RwLock<HashMap<String, CacheEntry<serde_json::Value>>>>,
38
39    /// Default TTL
40    default_ttl: Duration,
41}
42
43impl ResponseCache {
44    /// Create a new response cache
45    pub fn new(ttl_seconds: u64) -> Self {
46        Self {
47            storage: Arc::new(RwLock::new(HashMap::new())),
48            default_ttl: Duration::from_secs(ttl_seconds),
49        }
50    }
51
52    /// Get a value from cache
53    pub async fn get(&self, key: &str) -> Option<serde_json::Value> {
54        let storage = self.storage.read().await;
55
56        if let Some(entry) = storage.get(key) {
57            if !entry.is_expired() {
58                return Some(entry.value.clone());
59            }
60        }
61
62        None
63    }
64
65    /// Put a value in cache
66    pub async fn put(&self, key: String, value: serde_json::Value) {
67        let mut storage = self.storage.write().await;
68        storage.insert(key, CacheEntry::new(value, self.default_ttl));
69    }
70
71    /// Put a value with custom TTL
72    pub async fn put_with_ttl(&self, key: String, value: serde_json::Value, ttl: Duration) {
73        let mut storage = self.storage.write().await;
74        storage.insert(key, CacheEntry::new(value, ttl));
75    }
76
77    /// Remove a value from cache
78    pub async fn remove(&self, key: &str) -> Option<serde_json::Value> {
79        let mut storage = self.storage.write().await;
80        storage.remove(key).map(|entry| entry.value)
81    }
82
83    /// Clear all expired entries
84    pub async fn cleanup_expired(&self) -> usize {
85        let mut storage = self.storage.write().await;
86
87        let expired_keys: Vec<String> = storage
88            .iter()
89            .filter(|(_, entry)| entry.is_expired())
90            .map(|(key, _)| key.clone())
91            .collect();
92
93        let count = expired_keys.len();
94        for key in expired_keys {
95            storage.remove(&key);
96        }
97
98        count
99    }
100
101    /// Clear all cache entries
102    pub async fn clear(&self) {
103        let mut storage = self.storage.write().await;
104        storage.clear();
105    }
106
107    /// Get cache size
108    pub async fn size(&self) -> usize {
109        let storage = self.storage.read().await;
110        storage.len()
111    }
112}
113
114/// Generate a cache key from method, path, and request body
115pub fn generate_cache_key(method: &str, path: &str, body: Option<&serde_json::Value>) -> String {
116    use std::collections::hash_map::DefaultHasher;
117
118    let mut hasher = DefaultHasher::new();
119    method.hash(&mut hasher);
120    path.hash(&mut hasher);
121
122    if let Some(body) = body {
123        if let Ok(json_str) = serde_json::to_string(body) {
124            json_str.hash(&mut hasher);
125        }
126    }
127
128    format!("{}:{}:{:x}", method, path, hasher.finish())
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Per-entry TTL for expiry tests, so they don't need multi-second sleeps.
    const SHORT_TTL: Duration = Duration::from_millis(50);
    /// Sleep long enough that an entry stored with `SHORT_TTL` has certainly expired.
    const PAST_SHORT_TTL: Duration = Duration::from_millis(120);

    #[tokio::test]
    async fn test_cache_get_put() {
        let cache = ResponseCache::new(60);

        let value = json!({"message": "test"});
        cache.put("test_key".to_string(), value.clone()).await;

        let retrieved = cache.get("test_key").await;
        assert_eq!(retrieved, Some(value));
    }

    #[tokio::test]
    async fn test_cache_expiration() {
        // `ResponseCache::new` only accepts whole seconds, so exercising the
        // default TTL needs a real sleep just past the 1 s boundary.
        let cache = ResponseCache::new(1);

        let value = json!({"message": "test"});
        cache.put("test_key".to_string(), value.clone()).await;

        // Should be present initially
        assert!(cache.get("test_key").await.is_some());

        // Wait for expiration
        tokio::time::sleep(Duration::from_millis(1100)).await;

        // Should be expired
        assert!(cache.get("test_key").await.is_none());
    }

    #[tokio::test]
    async fn test_put_with_ttl_overrides_default() {
        let cache = ResponseCache::new(60);

        cache
            .put_with_ttl("short".to_string(), json!("value"), SHORT_TTL)
            .await;
        assert!(cache.get("short").await.is_some());

        tokio::time::sleep(PAST_SHORT_TTL).await;
        assert!(cache.get("short").await.is_none());
    }

    #[tokio::test]
    async fn test_cache_cleanup() {
        let cache = ResponseCache::new(60);

        cache
            .put_with_ttl("key1".to_string(), json!("value1"), SHORT_TTL)
            .await;
        cache
            .put_with_ttl("key2".to_string(), json!("value2"), SHORT_TTL)
            .await;
        // Uses the 60 s default TTL, so it must survive cleanup.
        cache.put("key3".to_string(), json!("value3")).await;

        // Wait for the short-lived entries to expire
        tokio::time::sleep(PAST_SHORT_TTL).await;

        // Expired entries still count towards size until cleaned up.
        assert_eq!(cache.size().await, 3);

        let cleaned = cache.cleanup_expired().await;
        assert_eq!(cleaned, 2);
        assert_eq!(cache.size().await, 1);
        assert_eq!(cache.get("key3").await, Some(json!("value3")));
    }

    #[tokio::test]
    async fn test_remove_and_clear() {
        let cache = ResponseCache::new(60);

        cache.put("a".to_string(), json!(1)).await;
        cache.put("b".to_string(), json!(2)).await;

        assert_eq!(cache.remove("a").await, Some(json!(1)));
        assert_eq!(cache.remove("a").await, None);
        assert!(cache.get("a").await.is_none());
        assert_eq!(cache.size().await, 1);

        cache.clear().await;
        assert_eq!(cache.size().await, 0);
        assert!(cache.get("b").await.is_none());
    }

    #[test]
    fn test_cache_key_generation() {
        let key1 = generate_cache_key("GET", "/api/users", None);
        let key2 = generate_cache_key("GET", "/api/users", None);
        let key3 = generate_cache_key("POST", "/api/users", None);

        // Same request should generate same key
        assert_eq!(key1, key2);

        // Different method should generate different key
        assert_ne!(key1, key3);

        // Method and path are embedded verbatim as a readable prefix
        assert!(key1.starts_with("GET:/api/users:"));
    }

    #[test]
    fn test_cache_key_with_body() {
        let body1 = json!({"name": "Alice"});
        let body2 = json!({"name": "Bob"});

        let key1 = generate_cache_key("POST", "/api/users", Some(&body1));
        let key1_again = generate_cache_key("POST", "/api/users", Some(&body1));
        let key2 = generate_cache_key("POST", "/api/users", Some(&body2));
        let no_body = generate_cache_key("POST", "/api/users", None);

        // Same body should generate same key
        assert_eq!(key1, key1_again);

        // Different body should generate different key
        assert_ne!(key1, key2);

        // Having a body should differ from having none
        assert_ne!(key1, no_body);
    }
}