ricecoder_providers/cache.rs

//! Provider response caching layer
//!
//! Caches AI provider responses to avoid redundant API calls.
//! Uses a file-based cache with TTL support.

use crate::error::ProviderError;
use crate::models::{ChatRequest, ChatResponse};
use ricecoder_storage::{CacheInvalidationStrategy, CacheManager};
use sha2::{Digest, Sha256};
use std::path::Path;
use std::sync::Arc;
use tracing::{debug, info};

/// Provider response cache
///
/// Caches AI provider responses to improve performance and reduce API calls.
/// Uses a SHA-256 hash of the provider, model, and request to build cache keys.
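///
/// # Examples
///
/// A minimal sketch of the round trip; the cache directory, TTL, and the
/// `request`/`response` values are illustrative stand-ins:
///
/// ```ignore
/// // Directory and TTL are illustrative
/// let cache = ProviderCache::new("/tmp/ricecoder-cache", 86_400)?;
/// cache.set("openai", "gpt-4", &request, &response)?;
/// assert!(cache.get("openai", "gpt-4", &request)?.is_some());
/// ```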
pub struct ProviderCache {
    cache: Arc<CacheManager>,
    ttl_seconds: u64,
}

impl ProviderCache {
    /// Create a new provider cache
    ///
    /// # Arguments
    ///
    /// * `cache_dir` - Directory to store cache files
    /// * `ttl_seconds` - Time-to-live for cache entries, in seconds (e.g., 86400 = 24 hours)
    ///
    /// # Errors
    ///
    /// Returns an error if the cache directory cannot be created
    pub fn new(cache_dir: impl AsRef<Path>, ttl_seconds: u64) -> Result<Self, ProviderError> {
        let cache = CacheManager::new(cache_dir)
            .map_err(|e| ProviderError::Internal(format!("Failed to create cache: {}", e)))?;

        Ok(Self {
            cache: Arc::new(cache),
            ttl_seconds,
        })
    }

    /// Get a cached response
    ///
    /// # Arguments
    ///
    /// * `provider` - Provider name (e.g., "openai", "anthropic")
    /// * `model` - Model name (e.g., "gpt-4", "claude-3")
    /// * `request` - Chat request
    ///
    /// # Returns
    ///
    /// Returns the cached response if found and not expired, `None` otherwise.
    /// Corrupted or unreadable entries are treated as misses.
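    ///
    /// # Examples
    ///
    /// A sketch of the cache-aside pattern this method enables (assumes a
    /// `cache` in scope; `call_provider` is a hypothetical stand-in for the
    /// real API call):
    ///
    /// ```ignore
    /// let response = match cache.get("openai", "gpt-4", &request)? {
    ///     Some(cached) => cached,
    ///     None => {
    ///         let fresh = call_provider(&request)?; // hypothetical provider call
    ///         cache.set("openai", "gpt-4", &request, &fresh)?;
    ///         fresh
    ///     }
    /// };
    /// ```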
    pub fn get(
        &self,
        provider: &str,
        model: &str,
        request: &ChatRequest,
    ) -> Result<Option<ChatResponse>, ProviderError> {
        let cache_key = self.make_cache_key(provider, model, request);

        match self.cache.get(&cache_key) {
            Ok(Some(cached_json_str)) => {
                match serde_json::from_str::<ChatResponse>(&cached_json_str) {
                    Ok(response) => {
                        debug!("Cache hit for provider response: {}/{}", provider, model);
                        Ok(Some(response))
                    }
                    Err(e) => {
                        debug!("Failed to deserialize cached response: {}", e);
                        // Invalidate the corrupted entry so the next set rebuilds it
                        let _ = self.cache.invalidate(&cache_key);
                        Ok(None)
                    }
                }
            }
            Ok(None) => {
                debug!("Cache miss for provider response: {}/{}", provider, model);
                Ok(None)
            }
            Err(e) => {
                // Treat lookup failures as misses so the caller falls back to the provider
                debug!("Cache lookup error: {}", e);
                Ok(None)
            }
        }
    }

    /// Cache a response
    ///
    /// # Arguments
    ///
    /// * `provider` - Provider name
    /// * `model` - Model name
    /// * `request` - Chat request
    /// * `response` - Chat response to cache
    ///
    /// # Errors
    ///
    /// Returns an error if the response cannot be serialized or written to the cache
    pub fn set(
        &self,
        provider: &str,
        model: &str,
        request: &ChatRequest,
        response: &ChatResponse,
    ) -> Result<(), ProviderError> {
        let cache_key = self.make_cache_key(provider, model, request);

        let response_json = serde_json::to_string(response)
            .map_err(|e| ProviderError::Internal(format!("Failed to serialize response: {}", e)))?;

        // Record the size before the string is moved into the cache
        let json_len = response_json.len();

        self.cache
            .set(
                &cache_key,
                response_json,
                CacheInvalidationStrategy::Ttl(self.ttl_seconds),
            )
            .map_err(|e| ProviderError::Internal(format!("Failed to cache response: {}", e)))?;

        debug!(
            "Cached response for {}/{}: {} bytes",
            provider, model, json_len
        );

        Ok(())
    }

    /// Invalidate a cached response
    ///
    /// # Arguments
    ///
    /// * `provider` - Provider name
    /// * `model` - Model name
    /// * `request` - Chat request
    ///
    /// # Returns
    ///
    /// Returns `Ok(true)` if the entry was deleted, `Ok(false)` if it didn't exist
    pub fn invalidate(
        &self,
        provider: &str,
        model: &str,
        request: &ChatRequest,
    ) -> Result<bool, ProviderError> {
        let cache_key = self.make_cache_key(provider, model, request);

        self.cache
            .invalidate(&cache_key)
            .map_err(|e| ProviderError::Internal(format!("Failed to invalidate cache: {}", e)))
    }

    /// Clear all cached responses
    ///
    /// # Errors
    ///
    /// Returns an error if the cache cannot be cleared
    pub fn clear(&self) -> Result<(), ProviderError> {
        self.cache
            .clear()
            .map_err(|e| ProviderError::Internal(format!("Failed to clear cache: {}", e)))
    }

    /// Clean up expired cache entries
    ///
    /// # Returns
    ///
    /// Returns the number of entries cleaned up
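    ///
    /// # Examples
    ///
    /// A sketch of running cleanup as periodic maintenance (assumes a `cache`
    /// in scope; error handling elided):
    ///
    /// ```ignore
    /// let removed = cache.cleanup_expired()?;
    /// if removed > 0 {
    ///     eprintln!("pruned {removed} expired provider responses");
    /// }
    /// ```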
    pub fn cleanup_expired(&self) -> Result<usize, ProviderError> {
        let cleaned = self
            .cache
            .cleanup_expired()
            .map_err(|e| ProviderError::Internal(format!("Failed to clean up cache: {}", e)))?;

        if cleaned > 0 {
            info!("Cleaned up {} expired cache entries", cleaned);
        }

        Ok(cleaned)
    }

    /// Create a cache key from provider, model, and request
    fn make_cache_key(&self, provider: &str, model: &str, request: &ChatRequest) -> String {
        // Create a deterministic hash of the request; serializing a ChatRequest should
        // not fail, so the empty-string fallback is purely defensive
        let request_json = serde_json::to_string(request).unwrap_or_default();

        // Separate components with "|" so e.g. ("ab", "c") and ("a", "bc") hash differently
        let mut hasher = Sha256::new();
        hasher.update(provider.as_bytes());
        hasher.update(b"|");
        hasher.update(model.as_bytes());
        hasher.update(b"|");
        hasher.update(request_json.as_bytes());

        let hash = format!("{:x}", hasher.finalize());

        // Keys look like "provider_response_<64 hex chars>"
        format!("provider_response_{}", hash)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{FinishReason, Message, TokenUsage};
    use tempfile::TempDir;

    fn create_test_request() -> ChatRequest {
        ChatRequest {
            model: "gpt-4".to_string(),
            messages: vec![Message {
                role: "user".to_string(),
                content: "Hello".to_string(),
            }],
            temperature: Some(0.7),
            max_tokens: Some(100),
            stream: false,
        }
    }

    fn create_test_response() -> ChatResponse {
        ChatResponse {
            content: "Hi there!".to_string(),
            model: "gpt-4".to_string(),
            usage: TokenUsage {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
            },
            finish_reason: FinishReason::Stop,
        }
    }

    #[test]
    fn test_cache_set_and_get() -> Result<(), ProviderError> {
        let temp_dir = TempDir::new().unwrap();
        let cache = ProviderCache::new(temp_dir.path(), 3600)?;

        let request = create_test_request();
        let response = create_test_response();

        // Cache response
        cache.set("openai", "gpt-4", &request, &response)?;

        // Retrieve from cache
        let cached = cache.get("openai", "gpt-4", &request)?;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().content, "Hi there!");

        Ok(())
    }

    #[test]
    fn test_cache_miss() -> Result<(), ProviderError> {
        let temp_dir = TempDir::new().unwrap();
        let cache = ProviderCache::new(temp_dir.path(), 3600)?;

        let request = create_test_request();

        // Try to get a non-existent entry
        let cached = cache.get("openai", "gpt-4", &request)?;
        assert!(cached.is_none());

        Ok(())
    }

    #[test]
    fn test_cache_invalidate() -> Result<(), ProviderError> {
        let temp_dir = TempDir::new().unwrap();
        let cache = ProviderCache::new(temp_dir.path(), 3600)?;

        let request = create_test_request();
        let response = create_test_response();

        // Cache response
        cache.set("openai", "gpt-4", &request, &response)?;

        // Invalidate
        let invalidated = cache.invalidate("openai", "gpt-4", &request)?;
        assert!(invalidated);

        // Should be gone now
        let cached = cache.get("openai", "gpt-4", &request)?;
        assert!(cached.is_none());

        Ok(())
    }

    #[test]
    fn test_cache_clear() -> Result<(), ProviderError> {
        let temp_dir = TempDir::new().unwrap();
        let cache = ProviderCache::new(temp_dir.path(), 3600)?;

        let request = create_test_request();
        let response = create_test_response();

        // Cache multiple responses
        cache.set("openai", "gpt-4", &request, &response)?;
        cache.set("anthropic", "claude-3", &request, &response)?;

        // Clear all
        cache.clear()?;

        // Both should be gone
        assert!(cache.get("openai", "gpt-4", &request)?.is_none());
        assert!(cache.get("anthropic", "claude-3", &request)?.is_none());

        Ok(())
    }

    #[test]
    fn test_different_requests_different_cache() -> Result<(), ProviderError> {
        let temp_dir = TempDir::new().unwrap();
        let cache = ProviderCache::new(temp_dir.path(), 3600)?;

        let request1 = create_test_request();
        let mut request2 = create_test_request();
        request2.messages[0].content = "Different message".to_string();

        let response1 = ChatResponse {
            content: "Response 1".to_string(),
            model: "gpt-4".to_string(),
            usage: TokenUsage {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
            },
            finish_reason: FinishReason::Stop,
        };

        let response2 = ChatResponse {
            content: "Response 2".to_string(),
            model: "gpt-4".to_string(),
            usage: TokenUsage {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
            },
            finish_reason: FinishReason::Stop,
        };

        // Cache different responses for different requests
        cache.set("openai", "gpt-4", &request1, &response1)?;
        cache.set("openai", "gpt-4", &request2, &response2)?;

        // Verify they're cached separately
        let cached1 = cache.get("openai", "gpt-4", &request1)?;
        let cached2 = cache.get("openai", "gpt-4", &request2)?;

        assert_eq!(cached1.unwrap().content, "Response 1");
        assert_eq!(cached2.unwrap().content, "Response 2");

        Ok(())
    }
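
    // A sketch of a determinism check: make_cache_key should return identical keys
    // for identical inputs, differ across providers, and carry the documented prefix.
    // Exercises the private helper directly, which this module's tests can access.
    #[test]
    fn test_cache_key_deterministic_and_provider_sensitive() -> Result<(), ProviderError> {
        let temp_dir = TempDir::new().unwrap();
        let cache = ProviderCache::new(temp_dir.path(), 3600)?;

        let request = create_test_request();

        // Same inputs must always hash to the same key
        let key_a = cache.make_cache_key("openai", "gpt-4", &request);
        let key_b = cache.make_cache_key("openai", "gpt-4", &request);
        assert_eq!(key_a, key_b);

        // A different provider must yield a different key
        let key_c = cache.make_cache_key("anthropic", "gpt-4", &request);
        assert_ne!(key_a, key_c);

        // Keys carry the expected prefix
        assert!(key_a.starts_with("provider_response_"));

        Ok(())
    }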
}