llmkit/
cache.rs

1//! Response caching infrastructure for LLM API calls.
2//!
3//! This module provides a flexible caching system to reduce API costs
4//! and improve response times by caching identical requests.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use llmkit::{CacheConfig, CachingProvider, InMemoryCache, OpenAIProvider};
10//!
11//! // Create a caching provider
12//! let inner = OpenAIProvider::from_env()?;
13//! let cache = InMemoryCache::new(CacheConfig::default());
14//! let provider = CachingProvider::new(inner, cache);
15//!
16//! // First request hits the API
17//! let response1 = provider.complete(request.clone()).await?;
18//!
19//! // Second identical request hits the cache
20//! let response2 = provider.complete(request).await?;
21//! ```
22//!
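//! # Cache Key Computation
//!
//! Cache keys are computed from:
//! - Model name
//! - System prompt
//! - Messages (role and content)
//! - Tools (if any)
//! - Response format (if any)
//!
//! No other request fields are hashed. In particular, sampling parameters
//! (temperature, top_p) never contribute to the key, so identical prompts share
//! a cached response regardless of sampling settings. The same applies to any
//! other generation setting the request carries. `CacheConfig::exclude_fields`
//! does not currently change this.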
33
34use std::collections::HashSet;
35use std::pin::Pin;
36use std::sync::atomic::{AtomicU64, Ordering};
37use std::sync::Arc;
38use std::time::{Duration, SystemTime};
39
40use async_trait::async_trait;
41use dashmap::DashMap;
42use futures::Stream;
43use sha2::{Digest, Sha256};
44
45use crate::error::Result;
46use crate::provider::Provider;
47use crate::types::{CompletionRequest, CompletionResponse, StreamChunk};
48
/// Configuration for the caching system.
///
/// `CachingProvider` reads `enabled`; `InMemoryCache` reads `ttl` and
/// `max_entries`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Whether caching is enabled. When `false`, `CachingProvider` forwards
    /// every request straight to the inner provider.
    pub enabled: bool,
    /// Time-to-live for cached entries (default: one hour).
    pub ttl: Duration,
    /// Maximum number of entries an `InMemoryCache` keeps before evicting
    /// (default: 10 000).
    pub max_entries: usize,
    /// Whether to cache streaming responses (after collection). Default `false`.
    ///
    /// NOTE(review): `CachingProvider::complete_stream` in this module always
    /// forwards to the inner provider without consulting this flag.
    pub cache_streaming: bool,
    /// Request fields intended to be left out of the cache key. Defaults to the
    /// sampling parameters `temperature`, `top_p`, `top_k` and `seed`.
    ///
    /// NOTE(review): the `CachingProvider` key computation in this module does
    /// not read this set; it never hashes sampling parameters regardless of
    /// the set's contents.
    pub exclude_fields: HashSet<String>,
}

impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            ttl: Duration::from_secs(3600), // 1 hour
            max_entries: 10_000,
            cache_streaming: false,
            exclude_fields: ["temperature", "top_p", "top_k", "seed"]
                .iter()
                .map(|&name| name.to_owned())
                .collect(),
        }
    }
}

impl CacheConfig {
    /// Create a new cache configuration with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the TTL for cached entries.
    #[must_use]
    pub fn with_ttl(mut self, ttl: Duration) -> Self {
        self.ttl = ttl;
        self
    }

    /// Set the maximum number of entries.
    #[must_use]
    pub fn with_max_entries(mut self, max_entries: usize) -> Self {
        self.max_entries = max_entries;
        self
    }

    /// Enable or disable caching.
    #[must_use]
    pub fn with_enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Enable or disable streaming cache (see the note on `cache_streaming`).
    #[must_use]
    pub fn with_cache_streaming(mut self, cache_streaming: bool) -> Self {
        self.cache_streaming = cache_streaming;
        self
    }
}
111
112/// A cached response with metadata.
113#[derive(Debug, Clone)]
114pub struct CachedResponse {
115    /// The cached response.
116    pub response: CompletionResponse,
117    /// When this entry was created.
118    pub created_at: SystemTime,
119    /// Number of times this entry was hit.
120    pub hit_count: Arc<AtomicU64>,
121}
122
123impl CachedResponse {
124    /// Create a new cached response.
125    pub fn new(response: CompletionResponse) -> Self {
126        Self {
127            response,
128            created_at: SystemTime::now(),
129            hit_count: Arc::new(AtomicU64::new(0)),
130        }
131    }
132
133    /// Check if this entry has expired.
134    pub fn is_expired(&self, ttl: Duration) -> bool {
135        self.created_at
136            .elapsed()
137            .map(|elapsed| elapsed > ttl)
138            .unwrap_or(true)
139    }
140
141    /// Increment the hit count and return the new value.
142    pub fn record_hit(&self) -> u64 {
143        self.hit_count.fetch_add(1, Ordering::Relaxed) + 1
144    }
145}
146
/// Snapshot of cache performance counters.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that returned a cached response.
    pub hits: u64,
    /// Lookups that found nothing usable.
    pub misses: u64,
    /// Number of entries held when the snapshot was taken.
    pub entries: usize,
    /// Approximate size in bytes; may be 0 when the backend does not track it.
    pub size_bytes: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, between 0.0 and 1.0.
    ///
    /// Returns 0.0 when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
171
172/// Trait for cache backends.
173#[async_trait]
174pub trait CacheBackend: Send + Sync {
175    /// Get a cached response by key.
176    async fn get(&self, key: &str) -> Option<CachedResponse>;
177
178    /// Store a response in the cache.
179    async fn set(&self, key: &str, response: CachedResponse);
180
181    /// Invalidate a specific key.
182    async fn invalidate(&self, key: &str);
183
184    /// Clear all cache entries.
185    async fn clear(&self);
186
187    /// Get cache statistics.
188    fn stats(&self) -> CacheStats;
189}
190
191/// In-memory cache backend using DashMap.
192pub struct InMemoryCache {
193    entries: DashMap<String, CachedResponse>,
194    config: CacheConfig,
195    hits: AtomicU64,
196    misses: AtomicU64,
197}
198
199impl InMemoryCache {
200    /// Create a new in-memory cache with the given configuration.
201    pub fn new(config: CacheConfig) -> Arc<Self> {
202        Arc::new(Self {
203            entries: DashMap::new(),
204            config,
205            hits: AtomicU64::new(0),
206            misses: AtomicU64::new(0),
207        })
208    }
209
210    /// Create a new in-memory cache with default configuration.
211    pub fn default_cache() -> Arc<Self> {
212        Self::new(CacheConfig::default())
213    }
214
215    /// Evict expired entries.
216    pub fn evict_expired(&self) {
217        let ttl = self.config.ttl;
218        self.entries.retain(|_, v| !v.is_expired(ttl));
219    }
220
221    /// Evict entries to meet the max size.
222    fn evict_if_needed(&self) {
223        if self.entries.len() >= self.config.max_entries {
224            // Simple LRU-like eviction: remove oldest entries
225            // In a production system, you'd want a proper LRU implementation
226            let mut oldest_keys: Vec<(String, SystemTime)> = self
227                .entries
228                .iter()
229                .map(|e| (e.key().clone(), e.value().created_at))
230                .collect();
231
232            oldest_keys.sort_by(|a, b| a.1.cmp(&b.1));
233
234            // Remove 10% of oldest entries
235            let to_remove = self.config.max_entries / 10;
236            for (key, _) in oldest_keys.into_iter().take(to_remove) {
237                self.entries.remove(&key);
238            }
239        }
240    }
241}
242
243#[async_trait]
244impl CacheBackend for InMemoryCache {
245    async fn get(&self, key: &str) -> Option<CachedResponse> {
246        if let Some(entry) = self.entries.get(key) {
247            if entry.is_expired(self.config.ttl) {
248                self.entries.remove(key);
249                self.misses.fetch_add(1, Ordering::Relaxed);
250                None
251            } else {
252                entry.record_hit();
253                self.hits.fetch_add(1, Ordering::Relaxed);
254                Some(entry.clone())
255            }
256        } else {
257            self.misses.fetch_add(1, Ordering::Relaxed);
258            None
259        }
260    }
261
262    async fn set(&self, key: &str, response: CachedResponse) {
263        self.evict_if_needed();
264        self.entries.insert(key.to_string(), response);
265    }
266
267    async fn invalidate(&self, key: &str) {
268        self.entries.remove(key);
269    }
270
271    async fn clear(&self) {
272        self.entries.clear();
273        self.hits.store(0, Ordering::Relaxed);
274        self.misses.store(0, Ordering::Relaxed);
275    }
276
277    fn stats(&self) -> CacheStats {
278        CacheStats {
279            hits: self.hits.load(Ordering::Relaxed),
280            misses: self.misses.load(Ordering::Relaxed),
281            entries: self.entries.len(),
282            size_bytes: 0, // Would need serialization to compute accurately
283        }
284    }
285}
286
287/// A provider wrapper that caches responses.
288pub struct CachingProvider<P> {
289    /// The inner provider.
290    inner: P,
291    /// The cache backend.
292    cache: Arc<dyn CacheBackend>,
293    /// Cache configuration.
294    config: CacheConfig,
295}
296
297impl<P> CachingProvider<P> {
298    /// Create a new caching provider.
299    pub fn new(inner: P, cache: Arc<dyn CacheBackend>) -> Self {
300        Self {
301            inner,
302            cache,
303            config: CacheConfig::default(),
304        }
305    }
306
307    /// Create a new caching provider with custom configuration.
308    pub fn with_config(inner: P, cache: Arc<dyn CacheBackend>, config: CacheConfig) -> Self {
309        Self {
310            inner,
311            cache,
312            config,
313        }
314    }
315
316    /// Get the inner provider.
317    pub fn inner(&self) -> &P {
318        &self.inner
319    }
320
321    /// Get cache statistics.
322    pub fn stats(&self) -> CacheStats {
323        self.cache.stats()
324    }
325
326    /// Clear the cache.
327    pub async fn clear_cache(&self) {
328        self.cache.clear().await;
329    }
330
331    /// Compute a cache key for a request.
332    fn compute_cache_key(&self, request: &CompletionRequest) -> String {
333        // Create a normalized representation for hashing
334        let mut hasher = Sha256::new();
335
336        // Include model
337        hasher.update(request.model.as_bytes());
338        hasher.update(b"|");
339
340        // Include system prompt
341        if let Some(ref system) = request.system {
342            hasher.update(system.as_bytes());
343        }
344        hasher.update(b"|");
345
346        // Include messages
347        for msg in &request.messages {
348            hasher.update(format!("{:?}", msg.role).as_bytes());
349            hasher.update(b":");
350            for block in &msg.content {
351                hasher.update(format!("{:?}", block).as_bytes());
352            }
353            hasher.update(b";");
354        }
355        hasher.update(b"|");
356
357        // Include tools
358        if let Some(ref tools) = request.tools {
359            for tool in tools {
360                hasher.update(tool.name.as_bytes());
361                hasher.update(b":");
362                hasher.update(tool.description.as_bytes());
363                hasher.update(b";");
364            }
365        }
366        hasher.update(b"|");
367
368        // Include response format
369        if let Some(ref format) = request.response_format {
370            hasher.update(format!("{:?}", format.format_type).as_bytes());
371        }
372
373        format!("cache:{}", hex::encode(hasher.finalize()))
374    }
375}
376
377#[async_trait]
378impl<P: Provider> Provider for CachingProvider<P> {
379    fn name(&self) -> &str {
380        self.inner.name()
381    }
382
383    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
384        if !self.config.enabled {
385            return self.inner.complete(request).await;
386        }
387
388        let cache_key = self.compute_cache_key(&request);
389
390        // Check cache
391        if let Some(cached) = self.cache.get(&cache_key).await {
392            tracing::debug!(key = %cache_key, "Cache hit");
393            return Ok(cached.response);
394        }
395
396        // Call inner provider
397        tracing::debug!(key = %cache_key, "Cache miss");
398        let response = self.inner.complete(request).await?;
399
400        // Store in cache
401        let cached = CachedResponse::new(response.clone());
402        self.cache.set(&cache_key, cached).await;
403
404        Ok(response)
405    }
406
407    async fn complete_stream(
408        &self,
409        request: CompletionRequest,
410    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>> {
411        // For streaming, we don't cache by default
412        // (would need to collect the stream and replay it)
413        self.inner.complete_stream(request).await
414    }
415
416    fn supports_tools(&self) -> bool {
417        self.inner.supports_tools()
418    }
419
420    fn supports_vision(&self) -> bool {
421        self.inner.supports_vision()
422    }
423
424    fn supports_streaming(&self) -> bool {
425        self.inner.supports_streaming()
426    }
427
428    fn supported_models(&self) -> Option<&[&str]> {
429        self.inner.supported_models()
430    }
431
432    fn default_model(&self) -> Option<&str> {
433        self.inner.default_model()
434    }
435}
436
437/// Cache key builder for custom cache key computation.
438#[derive(Default)]
439pub struct CacheKeyBuilder {
440    parts: Vec<String>,
441}
442
443impl CacheKeyBuilder {
444    /// Create a new cache key builder.
445    pub fn new() -> Self {
446        Self::default()
447    }
448
449    /// Add a part to the cache key.
450    pub fn with_part(mut self, part: impl Into<String>) -> Self {
451        self.parts.push(part.into());
452        self
453    }
454
455    /// Build the cache key.
456    pub fn build(self) -> String {
457        let mut hasher = Sha256::new();
458        for part in self.parts {
459            hasher.update(part.as_bytes());
460            hasher.update(b"|");
461        }
462        format!("cache:{}", hex::encode(hasher.finalize()))
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal response fixture; only `id` varies between tests.
    fn response_with_id(id: &str) -> CompletionResponse {
        CompletionResponse {
            id: id.to_string(),
            model: "test".to_string(),
            content: vec![],
            stop_reason: crate::types::StopReason::EndTurn,
            usage: crate::types::Usage::default(),
        }
    }

    #[test]
    fn test_cache_config_default() {
        let defaults = CacheConfig::default();
        assert!(defaults.enabled);
        assert_eq!(defaults.ttl, Duration::from_secs(3600));
        assert_eq!(defaults.max_entries, 10_000);
        assert!(!defaults.cache_streaming);
    }

    #[test]
    fn test_cache_config_builder() {
        let tuned = CacheConfig::new()
            .with_ttl(Duration::from_secs(600))
            .with_max_entries(1000)
            .with_enabled(false);

        assert!(!tuned.enabled);
        assert_eq!(tuned.ttl, Duration::from_secs(600));
        assert_eq!(tuned.max_entries, 1000);
    }

    #[test]
    fn test_cached_response_expiry() {
        let entry = CachedResponse::new(response_with_id("test"));

        // A freshly created entry is well within a one-hour TTL...
        assert!(!entry.is_expired(Duration::from_secs(3600)));
        // ...but a zero TTL treats it as stale.
        assert!(entry.is_expired(Duration::from_secs(0)));
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 80,
            misses: 20,
            entries: 100,
            size_bytes: 0,
        };
        assert!((stats.hit_rate() - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_cache_stats_hit_rate_zero() {
        assert_eq!(CacheStats::default().hit_rate(), 0.0);
    }

    #[tokio::test]
    async fn test_in_memory_cache() {
        let cache = InMemoryCache::new(CacheConfig::default());

        // A lookup before any insert is a miss.
        assert!(cache.get("key1").await.is_none());

        // Once stored, the same key is a hit.
        cache
            .set("key1", CachedResponse::new(response_with_id("test")))
            .await;
        assert!(cache.get("key1").await.is_some());

        let stats = cache.stats();
        assert_eq!((stats.hits, stats.misses, stats.entries), (1, 1, 1));

        // Invalidated keys are gone.
        cache.invalidate("key1").await;
        assert!(cache.get("key1").await.is_none());

        // Clearing drops every entry.
        cache
            .set("key2", CachedResponse::new(response_with_id("test2")))
            .await;
        cache.clear().await;
        assert_eq!(cache.stats().entries, 0);
    }

    #[test]
    fn test_cache_key_builder() {
        let key = CacheKeyBuilder::new()
            .with_part("model")
            .with_part("prompt")
            .build();

        // "cache:" followed by a 64-character hex SHA-256 digest.
        let digest = key
            .strip_prefix("cache:")
            .expect("key should start with `cache:`");
        assert_eq!(digest.len(), 64);
    }
}