oxify_connect_llm/
cache.rs

//! LLM response caching for cost optimization

use crate::{LlmProvider, LlmRequest, LlmResponse, Result};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

/// Cache key for LLM requests
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    model: String,
    prompt: String,
    system_prompt: Option<String>,
    temperature: Option<u32>, // Floats are not Eq/Hash, so store as fixed-point (temperature * 1000)
    max_tokens: Option<u32>,
}

impl CacheKey {
    fn from_request(request: &LlmRequest, model: &str) -> Self {
        Self {
            model: model.to_string(),
            prompt: request.prompt.clone(),
            system_prompt: request.system_prompt.clone(),
            temperature: request.temperature.map(|t| (t * 1000.0) as u32),
            max_tokens: request.max_tokens,
        }
    }
}

/// Cached response with expiration
#[derive(Debug, Clone)]
struct CachedResponse {
    response: LlmResponse,
    inserted_at: Instant,
    ttl: Duration,
}

impl CachedResponse {
    fn is_expired(&self) -> bool {
        self.inserted_at.elapsed() > self.ttl
    }
}

/// In-memory LLM response cache
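///
/// Illustrative usage (a minimal sketch; `request` is an `LlmRequest` and
/// `call_provider` is a hypothetical fallback, so the snippet is not compiled):
///
/// ```ignore
/// let cache = LlmCache::new();
/// let answer = match cache.get(&request, "gpt-4") {
///     Some(cached) => cached, // reuse a previously stored response
///     None => {
///         let fresh = call_provider(&request); // hypothetical provider call
///         cache.put(&request, "gpt-4", fresh.clone());
///         fresh
///     }
/// };
/// ```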
#[derive(Debug)]
pub struct LlmCache {
    cache: Arc<Mutex<HashMap<CacheKey, CachedResponse>>>,
    default_ttl: Duration,
    max_size: usize,
    hits: Arc<AtomicU64>,
    misses: Arc<AtomicU64>,
}

impl Clone for LlmCache {
    fn clone(&self) -> Self {
        Self {
            cache: Arc::clone(&self.cache),
            default_ttl: self.default_ttl,
            max_size: self.max_size,
            hits: Arc::clone(&self.hits),
            misses: Arc::clone(&self.misses),
        }
    }
}

impl Default for LlmCache {
    fn default() -> Self {
        Self::new()
    }
}

impl LlmCache {
    /// Create a new cache with default settings
    /// - TTL: 1 hour
    /// - Max size: 1000 entries
    pub fn new() -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            default_ttl: Duration::from_secs(3600), // 1 hour
            max_size: 1000,
            hits: Arc::new(AtomicU64::new(0)),
            misses: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Create a new cache with custom settings
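    ///
    /// Example with illustrative values:
    ///
    /// ```ignore
    /// // Keep up to 500 responses, each for at most 5 minutes.
    /// let cache = LlmCache::with_config(Duration::from_secs(300), 500);
    /// ```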
    pub fn with_config(ttl: Duration, max_size: usize) -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            default_ttl: ttl,
            max_size,
            hits: Arc::new(AtomicU64::new(0)),
            misses: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Get a cached response if it exists and hasn't expired
    pub fn get(&self, request: &LlmRequest, model: &str) -> Option<LlmResponse> {
        let key = CacheKey::from_request(request, model);
        let mut cache = self.cache.lock().unwrap();

        if let Some(cached) = cache.get(&key) {
            if !cached.is_expired() {
                self.hits.fetch_add(1, Ordering::Relaxed);
                return Some(cached.response.clone());
            } else {
                // Remove expired entry
                cache.remove(&key);
            }
        }

        self.misses.fetch_add(1, Ordering::Relaxed);
        None
    }

    /// Store a response in the cache
    pub fn put(&self, request: &LlmRequest, model: &str, response: LlmResponse) {
        let key = CacheKey::from_request(request, model);
        let mut cache = self.cache.lock().unwrap();

        // If the cache is full, evict an arbitrary entry to make room.
        // (HashMap iteration order is unspecified, so this is not true FIFO/LRU.)
        if cache.len() >= self.max_size {
            if let Some(evict_key) = cache.keys().next().cloned() {
                cache.remove(&evict_key);
            }
        }

        cache.insert(
            key,
            CachedResponse {
                response,
                inserted_at: Instant::now(),
                ttl: self.default_ttl,
            },
        );
    }

    /// Clear all cached entries
    pub fn clear(&self) {
        let mut cache = self.cache.lock().unwrap();
        cache.clear();
    }

    /// Remove expired entries
    pub fn cleanup(&self) {
        let mut cache = self.cache.lock().unwrap();
        cache.retain(|_, v| !v.is_expired());
    }

    /// Get cache statistics
    pub fn stats(&self) -> CacheStats {
        let cache = self.cache.lock().unwrap();
        let total = cache.len();
        let expired = cache.values().filter(|v| v.is_expired()).count();
        let hits = self.hits.load(Ordering::Relaxed);
        let misses = self.misses.load(Ordering::Relaxed);

        CacheStats {
            total_entries: total,
            expired_entries: expired,
            active_entries: total - expired,
            hits,
            misses,
        }
    }

    /// Reset hit/miss counters
    pub fn reset_stats(&self) {
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
    }
}

/// Cache statistics
#[derive(Debug, Clone)]
pub struct CacheStats {
    pub total_entries: usize,
    pub expired_entries: usize,
    pub active_entries: usize,
    pub hits: u64,
    pub misses: u64,
}

impl CacheStats {
    /// Calculate hit rate as a percentage (0.0 to 100.0)
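    ///
    /// Worked example: 3 hits and 1 miss give `3.0 / 4.0 * 100.0 = 75.0`.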
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            (self.hits as f64 / total as f64) * 100.0
        }
    }
}

// ===== CachedProvider Wrapper =====

/// A wrapper that adds caching functionality to any LLM provider
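///
/// Illustrative wiring (a minimal sketch; `OpenAiProvider`, `api_key`, and
/// `request` are placeholders, so the snippet is not compiled):
///
/// ```ignore
/// let provider = OpenAiProvider::new(api_key);
/// let cached = CachedProvider::new(provider, "gpt-4".to_string());
/// // `complete` consults the cache before calling the wrapped provider.
/// let response = cached.complete(request).await?;
/// println!("hit rate: {:.1}%", cached.stats().hit_rate());
/// ```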
pub struct CachedProvider<P> {
    inner: P,
    cache: LlmCache,
    model_name: String,
}

impl<P> CachedProvider<P> {
    /// Create a new CachedProvider with default cache settings
    pub fn new(provider: P, model_name: String) -> Self {
        Self {
            inner: provider,
            cache: LlmCache::new(),
            model_name,
        }
    }

    /// Create a new CachedProvider with a custom cache
    pub fn with_cache(provider: P, model_name: String, cache: LlmCache) -> Self {
        Self {
            inner: provider,
            cache,
            model_name,
        }
    }

    /// Get a reference to the inner provider
    pub fn inner(&self) -> &P {
        &self.inner
    }

    /// Get a mutable reference to the inner provider
    pub fn inner_mut(&mut self) -> &mut P {
        &mut self.inner
    }

    /// Get a reference to the cache
    pub fn cache(&self) -> &LlmCache {
        &self.cache
    }

    /// Get cache statistics
    pub fn stats(&self) -> CacheStats {
        self.cache.stats()
    }
}

#[async_trait]
impl<P: LlmProvider> LlmProvider for CachedProvider<P> {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        // Check cache first
        if let Some(cached) = self.cache.get(&request, &self.model_name) {
            tracing::debug!(
                model = %self.model_name,
                "Cache hit for LLM request"
            );
            return Ok(cached);
        }

        // Cache miss - call the inner provider
        let response = self.inner.complete(request.clone()).await?;

        // Store in cache
        self.cache.put(&request, &self.model_name, response.clone());

        Ok(response)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Usage;

    #[test]
    fn test_cache_hit() {
        let cache = LlmCache::new();

        let request = LlmRequest {
            prompt: "Hello".to_string(),
            system_prompt: None,
            temperature: Some(0.7),
            max_tokens: Some(100),
            tools: Vec::new(),
            images: Vec::new(),
        };

        let response = LlmResponse {
            content: "Hi there!".to_string(),
            model: "gpt-4".to_string(),
            usage: Some(Usage {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
            }),
            tool_calls: Vec::new(),
        };

        // Cache miss
        assert!(cache.get(&request, "gpt-4").is_none());

        // Store in cache
        cache.put(&request, "gpt-4", response.clone());

        // Cache hit
        let cached = cache.get(&request, "gpt-4").unwrap();
        assert_eq!(cached.content, response.content);
    }

    #[test]
    fn test_cache_expiration() {
        let cache = LlmCache::with_config(Duration::from_millis(500), 100);

        let request = LlmRequest {
            prompt: "Hello".to_string(),
            system_prompt: None,
            temperature: Some(0.7),
            max_tokens: Some(100),
            tools: Vec::new(),
            images: Vec::new(),
        };

        let response = LlmResponse {
            content: "Hi there!".to_string(),
            model: "gpt-4".to_string(),
            usage: None,
            tool_calls: Vec::new(),
        };

        cache.put(&request, "gpt-4", response);

        // Should be in cache immediately
        assert!(cache.get(&request, "gpt-4").is_some());

        // Wait for expiration (sleep longer than TTL to ensure expiration)
        std::thread::sleep(Duration::from_millis(600));

        // Should be expired
        assert!(cache.get(&request, "gpt-4").is_none());
    }

    #[test]
    fn test_cache_max_size() {
        let cache = LlmCache::with_config(Duration::from_secs(3600), 3);

        for i in 0..5 {
            let request = LlmRequest {
                prompt: format!("Prompt {}", i),
                system_prompt: None,
                temperature: Some(0.7),
                max_tokens: Some(100),
                tools: Vec::new(),
                images: Vec::new(),
            };

            let response = LlmResponse {
                content: format!("Response {}", i),
                model: "gpt-4".to_string(),
                usage: None,
                tool_calls: Vec::new(),
            };

            cache.put(&request, "gpt-4", response);
        }

        let stats = cache.stats();
        assert_eq!(stats.total_entries, 3); // Max size enforced
    }

    #[test]
    fn test_cache_cleanup() {
        let cache = LlmCache::with_config(Duration::from_millis(10), 100);

        // Add some entries that will expire
        for i in 0..5 {
            let request = LlmRequest {
                prompt: format!("Prompt {}", i),
                system_prompt: None,
                temperature: Some(0.7),
                max_tokens: Some(100),
                tools: Vec::new(),
                images: Vec::new(),
            };

            let response = LlmResponse {
                content: format!("Response {}", i),
                model: "gpt-4".to_string(),
                usage: None,
                tool_calls: Vec::new(),
            };

            cache.put(&request, "gpt-4", response);
        }

        assert_eq!(cache.stats().total_entries, 5);

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(20));

        // Cleanup
        cache.cleanup();

        assert_eq!(cache.stats().total_entries, 0);
    }

    #[test]
    fn test_cache_hit_rate() {
        let cache = LlmCache::new();

        let request = LlmRequest {
            prompt: "Hello".to_string(),
            system_prompt: None,
            temperature: Some(0.7),
            max_tokens: Some(100),
            tools: Vec::new(),
            images: Vec::new(),
        };

        let response = LlmResponse {
            content: "Hi there!".to_string(),
            model: "gpt-4".to_string(),
            usage: None,
            tool_calls: Vec::new(),
        };

        // First get is a miss
        assert!(cache.get(&request, "gpt-4").is_none());
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().hit_rate(), 0.0);

        // Store in cache
        cache.put(&request, "gpt-4", response);

        // Second get is a hit
        assert!(cache.get(&request, "gpt-4").is_some());
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hit_rate(), 50.0);

        // Third get is also a hit
        assert!(cache.get(&request, "gpt-4").is_some());
        assert_eq!(cache.stats().hits, 2);
        assert_eq!(cache.stats().misses, 1);

        // Hit rate should be approximately 66.67%
        let hit_rate = cache.stats().hit_rate();
        assert!(hit_rate > 66.0 && hit_rate < 67.0);

        // Reset stats
        cache.reset_stats();
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
    }
}