use crate::error::ProviderError;
use crate::models::{ChatRequest, ChatResponse};
use ricecoder_storage::{CacheInvalidationStrategy, CacheManager};
use sha2::{Digest, Sha256};
use std::path::Path;
use std::sync::Arc;
use tracing::{debug, info};

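/// A TTL-based, on-disk cache for provider chat responses.
///
/// Responses are stored as JSON, keyed by a SHA-256 hash of the provider
/// name, model name, and serialized request.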
pub struct ProviderCache {
    cache: Arc<CacheManager>,
    ttl_seconds: u64,
}

impl ProviderCache {
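    /// Creates a provider cache rooted at `cache_dir` whose entries expire
    /// after `ttl_seconds`.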
    pub fn new(cache_dir: impl AsRef<Path>, ttl_seconds: u64) -> Result<Self, ProviderError> {
        let cache = CacheManager::new(cache_dir)
            .map_err(|e| ProviderError::Internal(format!("Failed to create cache: {}", e)))?;

        Ok(Self {
            cache: Arc::new(cache),
            ttl_seconds,
        })
    }

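    /// Looks up a cached response for the given provider, model, and request.
    ///
    /// Returns `Ok(None)` on a miss; deserialization failures and lookup
    /// errors are logged at debug level and treated as misses.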
    pub fn get(
        &self,
        provider: &str,
        model: &str,
        request: &ChatRequest,
    ) -> Result<Option<ChatResponse>, ProviderError> {
        let cache_key = self.make_cache_key(provider, model, request);

        match self.cache.get(&cache_key) {
            Ok(Some(cached_json_str)) => {
                match serde_json::from_str::<ChatResponse>(&cached_json_str) {
                    Ok(response) => {
                        debug!("Cache hit for provider response: {}/{}", provider, model);
                        Ok(Some(response))
                    }
                    Err(e) => {
                        debug!("Failed to deserialize cached response: {}", e);
                        // Drop the corrupt entry so a fresh response can replace it.
                        let _ = self.cache.invalidate(&cache_key);
                        Ok(None)
                    }
                }
            }
            Ok(None) => {
                debug!("Cache miss for provider response: {}/{}", provider, model);
                Ok(None)
            }
            Err(e) => {
                debug!("Cache lookup error: {}", e);
                Ok(None)
            }
        }
    }

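    /// Caches a response for the given provider, model, and request, subject
    /// to the configured TTL.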
    pub fn set(
        &self,
        provider: &str,
        model: &str,
        request: &ChatRequest,
        response: &ChatResponse,
    ) -> Result<(), ProviderError> {
        let cache_key = self.make_cache_key(provider, model, request);

        let response_json = serde_json::to_string(response)
            .map_err(|e| ProviderError::Internal(format!("Failed to serialize response: {}", e)))?;

        let json_len = response_json.len();

        self.cache
            .set(
                &cache_key,
                response_json,
                CacheInvalidationStrategy::Ttl(self.ttl_seconds),
            )
            .map_err(|e| ProviderError::Internal(format!("Failed to cache response: {}", e)))?;

        debug!(
            "Cached response for {}/{}: {} bytes",
            provider, model, json_len
        );

        Ok(())
    }

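    /// Removes the cached entry for the given provider, model, and request.
    ///
    /// Returns `true` if an entry was removed.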
    pub fn invalidate(
        &self,
        provider: &str,
        model: &str,
        request: &ChatRequest,
    ) -> Result<bool, ProviderError> {
        let cache_key = self.make_cache_key(provider, model, request);

        self.cache
            .invalidate(&cache_key)
            .map_err(|e| ProviderError::Internal(format!("Failed to invalidate cache: {}", e)))
    }

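    /// Removes all entries from the cache.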
    pub fn clear(&self) -> Result<(), ProviderError> {
        self.cache
            .clear()
            .map_err(|e| ProviderError::Internal(format!("Failed to clear cache: {}", e)))
    }

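    /// Removes expired entries, returning how many were cleaned up.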
    pub fn cleanup_expired(&self) -> Result<usize, ProviderError> {
        let cleaned = self
            .cache
            .cleanup_expired()
            .map_err(|e| ProviderError::Internal(format!("Failed to cleanup cache: {}", e)))?;

        if cleaned > 0 {
            info!("Cleaned up {} expired cache entries", cleaned);
        }

        Ok(cleaned)
    }

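    /// Derives a cache key by hashing the provider, model, and serialized
    /// request with SHA-256.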
    fn make_cache_key(&self, provider: &str, model: &str, request: &ChatRequest) -> String {
        // An unserializable request hashes as the empty string.
        let request_json = serde_json::to_string(request).unwrap_or_default();

        // Separate the fields so adjacent values cannot blur into one another.
        let mut hasher = Sha256::new();
        hasher.update(provider.as_bytes());
        hasher.update(b"|");
        hasher.update(model.as_bytes());
        hasher.update(b"|");
        hasher.update(request_json.as_bytes());

        let hash = format!("{:x}", hasher.finalize());

        format!("provider_response_{}", hash)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{FinishReason, Message, TokenUsage};
    use tempfile::TempDir;

    fn create_test_request() -> ChatRequest {
        ChatRequest {
            model: "gpt-4".to_string(),
            messages: vec![Message {
                role: "user".to_string(),
                content: "Hello".to_string(),
            }],
            temperature: Some(0.7),
            max_tokens: Some(100),
            stream: false,
        }
    }

    fn create_test_response() -> ChatResponse {
        ChatResponse {
            content: "Hi there!".to_string(),
            model: "gpt-4".to_string(),
            usage: TokenUsage {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
            },
            finish_reason: FinishReason::Stop,
        }
    }

    #[test]
    fn test_cache_set_and_get() -> Result<(), ProviderError> {
        let temp_dir = TempDir::new().unwrap();
        let cache = ProviderCache::new(temp_dir.path(), 3600)?;

        let request = create_test_request();
        let response = create_test_response();

        cache.set("openai", "gpt-4", &request, &response)?;

        let cached = cache.get("openai", "gpt-4", &request)?;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().content, "Hi there!");

        Ok(())
    }

    #[test]
    fn test_cache_miss() -> Result<(), ProviderError> {
        let temp_dir = TempDir::new().unwrap();
        let cache = ProviderCache::new(temp_dir.path(), 3600)?;

        let request = create_test_request();

        let cached = cache.get("openai", "gpt-4", &request)?;
        assert!(cached.is_none());

        Ok(())
    }

    #[test]
    fn test_cache_invalidate() -> Result<(), ProviderError> {
        let temp_dir = TempDir::new().unwrap();
        let cache = ProviderCache::new(temp_dir.path(), 3600)?;

        let request = create_test_request();
        let response = create_test_response();

        cache.set("openai", "gpt-4", &request, &response)?;

        let invalidated = cache.invalidate("openai", "gpt-4", &request)?;
        assert!(invalidated);

        let cached = cache.get("openai", "gpt-4", &request)?;
        assert!(cached.is_none());

        Ok(())
    }

    #[test]
    fn test_cache_clear() -> Result<(), ProviderError> {
        let temp_dir = TempDir::new().unwrap();
        let cache = ProviderCache::new(temp_dir.path(), 3600)?;

        let request = create_test_request();
        let response = create_test_response();

        cache.set("openai", "gpt-4", &request, &response)?;
        cache.set("anthropic", "claude-3", &request, &response)?;

        cache.clear()?;

        assert!(cache.get("openai", "gpt-4", &request)?.is_none());
        assert!(cache.get("anthropic", "claude-3", &request)?.is_none());

        Ok(())
    }

    #[test]
    fn test_different_requests_different_cache() -> Result<(), ProviderError> {
        let temp_dir = TempDir::new().unwrap();
        let cache = ProviderCache::new(temp_dir.path(), 3600)?;

        let request1 = create_test_request();
        let mut request2 = create_test_request();
        request2.messages[0].content = "Different message".to_string();

        let response1 = ChatResponse {
            content: "Response 1".to_string(),
            model: "gpt-4".to_string(),
            usage: TokenUsage {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
            },
            finish_reason: FinishReason::Stop,
        };

        let response2 = ChatResponse {
            content: "Response 2".to_string(),
            model: "gpt-4".to_string(),
            usage: TokenUsage {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
            },
            finish_reason: FinishReason::Stop,
        };

        cache.set("openai", "gpt-4", &request1, &response1)?;
        cache.set("openai", "gpt-4", &request2, &response2)?;

        let cached1 = cache.get("openai", "gpt-4", &request1)?;
        let cached2 = cache.get("openai", "gpt-4", &request2)?;

        assert_eq!(cached1.unwrap().content, "Response 1");
        assert_eq!(cached2.unwrap().content, "Response 2");

        Ok(())
    }
}