kaccy_ai/llm/cache.rs

//! LLM response caching for cost optimization
//!
//! This module provides caching capabilities for LLM responses to reduce
//! API costs and improve response times for repeated queries.
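//!
//! # Example
//!
//! A minimal sketch of the intended flow (illustrative only: `request` is a
//! `ChatRequest` built elsewhere, and `provider.chat` stands in for whatever
//! provider call you use):
//!
//! ```ignore
//! let cache = LlmCache::new(LlmCacheConfig::default());
//!
//! // Serve repeated, deterministic requests from the cache
//! if let Some(cached) = cache.get(&request).await {
//!     return Ok(cached.content);
//! }
//!
//! // Otherwise call the provider and store the result for next time
//! let response = provider.chat(&request).await?;
//! cache.put(&request, &response).await;
//! ```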

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

use super::types::{ChatRequest, ChatResponse};

/// Cache configuration
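///
/// The [`LlmCacheConfig::development`] and [`LlmCacheConfig::production`]
/// presets below cover the common cases. One illustrative way to pick a preset
/// at startup (assuming an `APP_ENV` environment variable, which is not part of
/// this crate):
///
/// ```ignore
/// let config = match std::env::var("APP_ENV").as_deref() {
///     Ok("production") => LlmCacheConfig::production(),
///     Ok("development") => LlmCacheConfig::development(),
///     _ => LlmCacheConfig::default(),
/// };
/// ```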
#[derive(Debug, Clone)]
pub struct LlmCacheConfig {
    /// Maximum number of entries in the cache
    pub max_entries: usize,
    /// Time-to-live for cache entries
    pub ttl: Duration,
    /// Whether to enable semantic similarity matching
    pub semantic_matching: bool,
    /// Similarity threshold for semantic matching (0.0-1.0)
    pub similarity_threshold: f64,
    /// Maximum tokens in cached prompts (to avoid caching long conversations)
    pub max_prompt_tokens: usize,
}

impl Default for LlmCacheConfig {
    fn default() -> Self {
        Self {
            max_entries: 1000,
            ttl: Duration::from_secs(3600), // 1 hour
            semantic_matching: false,       // Disabled by default (requires embedding model)
            similarity_threshold: 0.95,
            max_prompt_tokens: 2000,
        }
    }
}

impl LlmCacheConfig {
    /// Create config optimized for development (shorter TTL, smaller cache)
    #[must_use]
    pub fn development() -> Self {
        Self {
            max_entries: 100,
            ttl: Duration::from_secs(300), // 5 minutes
            semantic_matching: false,
            similarity_threshold: 0.95,
            max_prompt_tokens: 1000,
        }
    }

    /// Create config optimized for production (larger cache, longer TTL)
    #[must_use]
    pub fn production() -> Self {
        Self {
            max_entries: 10000,
            ttl: Duration::from_secs(86400), // 24 hours
            semantic_matching: false,
            similarity_threshold: 0.95,
            max_prompt_tokens: 4000,
        }
    }
}

/// Cached LLM response entry
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached response
    response: CachedResponse,
    /// When the entry was created
    created_at: Instant,
    /// Number of times this entry was hit
    hit_count: u64,
    /// Last access time
    last_accessed: Instant,
}

/// Cached response data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResponse {
    /// Response content
    pub content: String,
    /// Finish reason
    pub finish_reason: Option<String>,
    /// Tokens used in prompt
    pub prompt_tokens: u32,
    /// Tokens used in completion
    pub completion_tokens: u32,
}

impl From<&ChatResponse> for CachedResponse {
    fn from(response: &ChatResponse) -> Self {
        Self {
            content: response.message.content.clone(),
            finish_reason: response.finish_reason.clone(),
            prompt_tokens: response.prompt_tokens,
            completion_tokens: response.completion_tokens,
        }
    }
}

/// Cache key derived from request
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Hash of the messages
    messages_hash: u64,
    /// Temperature (quantized to avoid floating point issues)
    temperature_q: u32,
    /// Max tokens (affects response length)
    max_tokens: Option<u32>,
}

impl CacheKey {
    fn from_request(request: &ChatRequest) -> Self {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();

        // Hash all messages (role + content)
        for msg in &request.messages {
            // Hash role as string representation
            let role_str = match msg.role {
                super::types::ChatRole::System => "system",
                super::types::ChatRole::User => "user",
                super::types::ChatRole::Assistant => "assistant",
            };
            role_str.hash(&mut hasher);
            msg.content.hash(&mut hasher);
        }
        let messages_hash = hasher.finish();

        // Quantize temperature to avoid floating point comparison issues
        let temperature_q = (request.temperature.unwrap_or(1.0) * 100.0) as u32;

        Self {
            messages_hash,
            temperature_q,
            max_tokens: request.max_tokens,
        }
    }
}

/// LLM response cache
pub struct LlmCache {
    config: LlmCacheConfig,
    cache: Arc<RwLock<HashMap<CacheKey, CacheEntry>>>,
    stats: Arc<RwLock<CacheStats>>,
}

/// Cache statistics
#[derive(Debug, Clone, Default, Serialize)]
pub struct CacheStats {
    /// Total cache hits
    pub hits: u64,
    /// Total cache misses
    pub misses: u64,
    /// Total entries currently in cache
    pub entries: usize,
    /// Total tokens saved by serving responses from the cache
    pub tokens_saved: u64,
    /// Estimated cost saved (in cents, assuming $0.01 per 1K tokens)
    pub estimated_cost_saved_cents: f64,
    /// Evictions due to TTL
    pub ttl_evictions: u64,
    /// Evictions due to capacity
    pub capacity_evictions: u64,
}

impl CacheStats {
    /// Calculate hit rate
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}

impl LlmCache {
    /// Create a new LLM cache
    #[must_use]
    pub fn new(config: LlmCacheConfig) -> Self {
        Self {
            config,
            cache: Arc::new(RwLock::new(HashMap::new())),
            stats: Arc::new(RwLock::new(CacheStats::default())),
        }
    }

    /// Create with default configuration
    #[must_use]
    pub fn default_cache() -> Self {
        Self::new(LlmCacheConfig::default())
    }

    /// Try to get a cached response
    pub async fn get(&self, request: &ChatRequest) -> Option<CachedResponse> {
        // Check if request is cacheable
        if !self.is_cacheable(request) {
            return None;
        }

        let key = CacheKey::from_request(request);

        let mut cache = self.cache.write().await;

        if let Some(entry) = cache.get_mut(&key) {
            // Check TTL
            if entry.created_at.elapsed() > self.config.ttl {
                // Entry expired
                cache.remove(&key);
                let mut stats = self.stats.write().await;
                stats.ttl_evictions += 1;
                stats.misses += 1;
                stats.entries = cache.len();
                return None;
            }

            // Update access stats
            entry.hit_count += 1;
            entry.last_accessed = Instant::now();

            // Update global stats
            let mut stats = self.stats.write().await;
            stats.hits += 1;
            stats.tokens_saved +=
                u64::from(entry.response.prompt_tokens + entry.response.completion_tokens);
            stats.estimated_cost_saved_cents +=
                f64::from(entry.response.prompt_tokens + entry.response.completion_tokens) / 1000.0;

            return Some(entry.response.clone());
        }

        // Cache miss
        let mut stats = self.stats.write().await;
        stats.misses += 1;

        None
    }

    /// Store a response in the cache
    pub async fn put(&self, request: &ChatRequest, response: &ChatResponse) {
        // Check if request is cacheable
        if !self.is_cacheable(request) {
            return;
        }

        let key = CacheKey::from_request(request);
        let entry = CacheEntry {
            response: CachedResponse::from(response),
            created_at: Instant::now(),
            hit_count: 0,
            last_accessed: Instant::now(),
        };

        let mut cache = self.cache.write().await;

        // Evict if at capacity
        if cache.len() >= self.config.max_entries {
            self.evict_lru(&mut cache).await;
        }

        cache.insert(key, entry);

        // Update stats
        let mut stats = self.stats.write().await;
        stats.entries = cache.len();
    }

    /// Check if a request is cacheable
    fn is_cacheable(&self, request: &ChatRequest) -> bool {
        // Don't cache if temperature is high (non-deterministic)
        if request.temperature.unwrap_or(1.0) > 0.5 {
            return false;
        }

        // Don't cache very long conversations
        let total_content_len: usize = request.messages.iter().map(|m| m.content.len()).sum();
        let estimated_tokens = total_content_len / 4; // Rough estimate

        if estimated_tokens > self.config.max_prompt_tokens {
            return false;
        }

        // Streaming requests are not cached here
        // (streaming responses are handled differently)
        true
    }

    /// Evict least recently used entry
    async fn evict_lru(&self, cache: &mut HashMap<CacheKey, CacheEntry>) {
        if let Some(key_to_remove) = cache
            .iter()
            .min_by_key(|(_, entry)| entry.last_accessed)
            .map(|(k, _)| k.clone())
        {
            cache.remove(&key_to_remove);

            let mut stats = self.stats.write().await;
            stats.capacity_evictions += 1;
        }
    }

    /// Clear expired entries
    pub async fn clear_expired(&self) {
        let mut cache = self.cache.write().await;
        let now = Instant::now();
        let ttl = self.config.ttl;

        let expired_keys: Vec<CacheKey> = cache
            .iter()
            .filter(|(_, entry)| now.duration_since(entry.created_at) > ttl)
            .map(|(k, _)| k.clone())
            .collect();

        let expired_count = expired_keys.len();

        for key in expired_keys {
            cache.remove(&key);
        }

        if expired_count > 0 {
            let mut stats = self.stats.write().await;
            stats.ttl_evictions += expired_count as u64;
            stats.entries = cache.len();

            tracing::debug!(expired = expired_count, "Cleared expired cache entries");
        }
    }

    /// Clear all cache entries
    pub async fn clear(&self) {
        let mut cache = self.cache.write().await;
        cache.clear();

        let mut stats = self.stats.write().await;
        stats.entries = 0;
    }

    /// Get cache statistics
    pub async fn get_stats(&self) -> CacheStats {
        let stats = self.stats.read().await;
        stats.clone()
    }

    /// Get detailed cache info
    pub async fn get_cache_info(&self) -> CacheInfo {
        let cache = self.cache.read().await;
        let stats = self.stats.read().await;

        let total_hit_count: u64 = cache.values().map(|e| e.hit_count).sum();
        let avg_hit_count = if cache.is_empty() {
            0.0
        } else {
            total_hit_count as f64 / cache.len() as f64
        };

        let oldest_entry = cache.values().map(|e| e.created_at).min();
        let newest_entry = cache.values().map(|e| e.created_at).max();

        CacheInfo {
            config: self.config.clone(),
            stats: stats.clone(),
            entry_count: cache.len(),
            total_hit_count,
            avg_hit_count_per_entry: avg_hit_count,
            oldest_entry_age_secs: oldest_entry.map(|t| t.elapsed().as_secs()),
            newest_entry_age_secs: newest_entry.map(|t| t.elapsed().as_secs()),
        }
    }
}

/// Detailed cache information
#[derive(Debug, Clone, Serialize)]
pub struct CacheInfo {
    /// Cache configuration
    #[serde(skip)]
    pub config: LlmCacheConfig,
    /// Cache statistics
    pub stats: CacheStats,
    /// Number of entries
    pub entry_count: usize,
    /// Total hit count across all entries
    pub total_hit_count: u64,
    /// Average hit count per entry
    pub avg_hit_count_per_entry: f64,
    /// Age of oldest entry in seconds
    pub oldest_entry_age_secs: Option<u64>,
    /// Age of newest entry in seconds
    pub newest_entry_age_secs: Option<u64>,
}

/// Cached LLM client wrapper
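///
/// The wrapper owns the cache and exposes cache management alongside the
/// wrapped provider. Construction is sketched below (`my_provider` is a
/// placeholder for whatever provider type `P` you use):
///
/// ```ignore
/// let client = CachedLlmClient::new(my_provider, LlmCacheConfig::production());
/// let stats = client.cache_stats().await;
/// ```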
pub struct CachedLlmClient<P> {
    /// Underlying provider
    provider: P,
    /// Cache
    cache: LlmCache,
}

impl<P> CachedLlmClient<P> {
    /// Create a new cached client
    pub fn new(provider: P, cache_config: LlmCacheConfig) -> Self {
        Self {
            provider,
            cache: LlmCache::new(cache_config),
        }
    }

    /// Create with default cache config
    pub fn with_default_cache(provider: P) -> Self {
        Self {
            provider,
            cache: LlmCache::default_cache(),
        }
    }

    /// Get the underlying provider
    pub fn provider(&self) -> &P {
        &self.provider
    }

    /// Get cache statistics
    pub async fn cache_stats(&self) -> CacheStats {
        self.cache.get_stats().await
    }

    /// Get detailed cache info
    pub async fn cache_info(&self) -> CacheInfo {
        self.cache.get_cache_info().await
    }

    /// Clear the cache
    pub async fn clear_cache(&self) {
        self.cache.clear().await;
    }

    /// Clear expired entries
    pub async fn clear_expired(&self) {
        self.cache.clear_expired().await;
    }
}

/// Request deduplication to prevent duplicate in-flight requests
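///
/// Intended flow (an illustrative sketch, not a prescribed API; `call_provider`
/// is a placeholder for your actual provider call):
///
/// ```ignore
/// if dedup.is_in_flight(&request).await {
///     // Another task is already making this call; wait for its result.
///     if let Some(shared) = dedup.wait_for(&request).await {
///         return Ok(shared);
///     }
/// }
///
/// dedup.register(&request).await;
/// let response = call_provider(&request).await?;
/// dedup.complete(&request, Some(CachedResponse::from(&response))).await;
/// ```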
pub struct RequestDeduplicator {
    /// Senders for in-flight requests, keyed by request hash
    in_flight: Arc<RwLock<HashMap<u64, tokio::sync::watch::Sender<Option<CachedResponse>>>>>,
}

impl RequestDeduplicator {
    /// Create a new deduplicator
    #[must_use]
    pub fn new() -> Self {
        Self {
            in_flight: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Hash a request for deduplication
    fn hash_request(request: &ChatRequest) -> u64 {
        let key = CacheKey::from_request(request);
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Check if a request is in flight
    pub async fn is_in_flight(&self, request: &ChatRequest) -> bool {
        let hash = Self::hash_request(request);
        let in_flight = self.in_flight.read().await;
        in_flight.contains_key(&hash)
    }

    /// Register an in-flight request
    pub async fn register(&self, request: &ChatRequest) {
        let hash = Self::hash_request(request);
        let (tx, _rx) = tokio::sync::watch::channel(None);

        let mut in_flight = self.in_flight.write().await;
        in_flight.insert(hash, tx);
    }

    /// Wait for an in-flight request to complete
    pub async fn wait_for(&self, request: &ChatRequest) -> Option<CachedResponse> {
        let hash = Self::hash_request(request);

        let rx = {
            let in_flight = self.in_flight.read().await;
            in_flight.get(&hash).map(|tx| tx.subscribe())
        };

        if let Some(mut rx) = rx {
            // Wait for the result to be published (or the sender to be dropped)
            let _ = rx.changed().await;
            rx.borrow().clone()
        } else {
            None
        }
    }

    /// Complete an in-flight request, publishing the result to any waiters
    pub async fn complete(&self, request: &ChatRequest, response: Option<CachedResponse>) {
        let hash = Self::hash_request(request);

        let tx = {
            let mut in_flight = self.in_flight.write().await;
            in_flight.remove(&hash)
        };

        // Publish the result; receivers obtained via `wait_for` keep the
        // channel alive, so they still observe the value after the sender is
        // dropped here.
        if let Some(tx) = tx {
            let _ = tx.send(response);
        }
    }
}

impl Default for RequestDeduplicator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ChatMessage, ChatRole};

    #[test]
    fn test_cache_key_generation() {
        let request = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Hello".to_string(),
            }],
            temperature: Some(0.0),
            max_tokens: None,
            stop: None,
            images: None,
        };

        let key1 = CacheKey::from_request(&request);
        let key2 = CacheKey::from_request(&request);

        assert_eq!(key1, key2);
    }

    #[test]
    fn test_different_messages_different_keys() {
        let request1 = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Hello".to_string(),
            }],
            temperature: Some(0.0),
            max_tokens: None,
            stop: None,
            images: None,
        };

        let request2 = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Goodbye".to_string(),
            }],
            temperature: Some(0.0),
            max_tokens: None,
            stop: None,
            images: None,
        };

        let key1 = CacheKey::from_request(&request1);
        let key2 = CacheKey::from_request(&request2);

        assert_ne!(key1, key2);
    }

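    // Illustrative check of the temperature quantization in `CacheKey`:
    // temperatures that fall into the same 0.01-wide bucket produce identical
    // keys, while clearly different temperatures do not.
    #[test]
    fn test_temperature_quantization_in_keys() {
        let mut request = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Hello".to_string(),
            }],
            temperature: Some(0.30),
            max_tokens: None,
            stop: None,
            images: None,
        };
        let key_low = CacheKey::from_request(&request);

        // Same bucket: both 0.30 and 0.304 quantize to 30
        request.temperature = Some(0.304);
        assert_eq!(CacheKey::from_request(&request), key_low);

        // Different bucket: 0.40 quantizes to 40
        request.temperature = Some(0.40);
        assert_ne!(CacheKey::from_request(&request), key_low);
    }
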
    #[test]
    fn test_cache_config_defaults() {
        let config = LlmCacheConfig::default();
        assert_eq!(config.max_entries, 1000);
        assert_eq!(config.ttl, Duration::from_secs(3600));
    }

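    // Small arithmetic check of the hit-rate formula: hits / (hits + misses),
    // with no traffic defined as 0.0 rather than a division by zero.
    #[test]
    fn test_hit_rate() {
        let stats = CacheStats {
            hits: 3,
            misses: 1,
            ..Default::default()
        };
        assert!((stats.hit_rate() - 0.75).abs() < f64::EPSILON);

        assert!(CacheStats::default().hit_rate().abs() < f64::EPSILON);
    }
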
    #[tokio::test]
    async fn test_cache_miss() {
        let cache = LlmCache::default_cache();

        let request = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Test".to_string(),
            }],
            temperature: Some(0.0),
            max_tokens: None,
            stop: None,
            images: None,
        };

        let result = cache.get(&request).await;
        assert!(result.is_none());

        let stats = cache.get_stats().await;
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
    }

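    // Illustrative check of the deduplicator bookkeeping: a request is marked
    // in flight after `register` and cleared again by `complete`.
    #[tokio::test]
    async fn test_deduplicator_register_and_complete() {
        let dedup = RequestDeduplicator::new();

        let request = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Dedup".to_string(),
            }],
            temperature: Some(0.0),
            max_tokens: None,
            stop: None,
            images: None,
        };

        assert!(!dedup.is_in_flight(&request).await);

        dedup.register(&request).await;
        assert!(dedup.is_in_flight(&request).await);

        dedup.complete(&request, None).await;
        assert!(!dedup.is_in_flight(&request).await);
    }
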
    #[test]
    fn test_not_cacheable_high_temperature() {
        let cache = LlmCache::default_cache();

        let request = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Test".to_string(),
            }],
            temperature: Some(0.9), // High temperature
            max_tokens: None,
            stop: None,
            images: None,
        };

        assert!(!cache.is_cacheable(&request));
    }
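
    // Illustrative checks for the cacheability heuristics: a short,
    // low-temperature request is cacheable, while a prompt whose estimated
    // token count (content length / 4) exceeds `max_prompt_tokens` is not.
    #[test]
    fn test_not_cacheable_long_prompt() {
        let cache = LlmCache::default_cache();

        let mut request = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Test".to_string(),
            }],
            temperature: Some(0.0),
            max_tokens: None,
            stop: None,
            images: None,
        };

        // Short, deterministic request: cacheable
        assert!(cache.is_cacheable(&request));

        // ~2,500 estimated tokens exceeds the default limit of 2,000
        request.messages[0].content = "x".repeat(10_000);
        assert!(!cache.is_cacheable(&request));
    }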
}