use crate::{LlmProvider, LlmRequest, LlmResponse, Result};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

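/// Cache key built from the request fields that influence a completion,
/// plus the target model name. Temperature is scaled to an integer so the
/// key can derive `Eq` and `Hash`.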
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    model: String,
    prompt: String,
    system_prompt: Option<String>,
    temperature: Option<u32>,
    max_tokens: Option<u32>,
}

impl CacheKey {
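    /// Derives a key from a request and the model it will be sent to.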
    fn from_request(request: &LlmRequest, model: &str) -> Self {
        Self {
            model: model.to_string(),
            prompt: request.prompt.clone(),
            system_prompt: request.system_prompt.clone(),
            temperature: request.temperature.map(|t| (t * 1000.0) as u32),
            max_tokens: request.max_tokens,
        }
    }
}

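/// A cached response together with its insertion time and time-to-live.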
#[derive(Debug, Clone)]
struct CachedResponse {
    response: LlmResponse,
    inserted_at: Instant,
    ttl: Duration,
}

impl CachedResponse {
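    /// Returns `true` once the entry has outlived its TTL.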
    fn is_expired(&self) -> bool {
        self.inserted_at.elapsed() > self.ttl
    }
}

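/// Thread-safe, in-memory cache for LLM responses with TTL-based expiry,
/// a bounded entry count, and hit/miss counters.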
#[derive(Debug)]
pub struct LlmCache {
    cache: Arc<Mutex<HashMap<CacheKey, CachedResponse>>>,
    default_ttl: Duration,
    max_size: usize,
    hits: Arc<AtomicU64>,
    misses: Arc<AtomicU64>,
}

impl Clone for LlmCache {
    fn clone(&self) -> Self {
        Self {
            cache: Arc::clone(&self.cache),
            default_ttl: self.default_ttl,
            max_size: self.max_size,
            hits: Arc::clone(&self.hits),
            misses: Arc::clone(&self.misses),
        }
    }
}

impl Default for LlmCache {
    fn default() -> Self {
        Self::new()
    }
}

impl LlmCache {
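    /// Creates a cache with a one-hour TTL and a 1000-entry capacity.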
    pub fn new() -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            default_ttl: Duration::from_secs(3600),
            max_size: 1000,
            hits: Arc::new(AtomicU64::new(0)),
            misses: Arc::new(AtomicU64::new(0)),
        }
    }

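    /// Creates a cache with a custom TTL and maximum entry count.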
    pub fn with_config(ttl: Duration, max_size: usize) -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            default_ttl: ttl,
            max_size,
            hits: Arc::new(AtomicU64::new(0)),
            misses: Arc::new(AtomicU64::new(0)),
        }
    }

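    /// Returns the cached response for this request/model pair, if any.
    /// Expired entries are removed on access; hit/miss counters are updated.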
    pub fn get(&self, request: &LlmRequest, model: &str) -> Option<LlmResponse> {
        let key = CacheKey::from_request(request, model);
        let mut cache = self.cache.lock().unwrap();

        if let Some(cached) = cache.get(&key) {
            if !cached.is_expired() {
                self.hits.fetch_add(1, Ordering::Relaxed);
                return Some(cached.response.clone());
            } else {
                cache.remove(&key);
            }
        }

        self.misses.fetch_add(1, Ordering::Relaxed);
        None
    }

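    /// Caches a response for this request/model pair, evicting an entry
    /// first if the cache is at capacity.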
    pub fn put(&self, request: &LlmRequest, model: &str, response: LlmResponse) {
        let key = CacheKey::from_request(request, model);
        let mut cache = self.cache.lock().unwrap();

        if cache.len() >= self.max_size {
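            // At capacity: evict an arbitrary entry (HashMap iteration order
            // is unspecified, so this is not necessarily the oldest one).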
            if let Some(oldest_key) = cache.keys().next().cloned() {
                cache.remove(&oldest_key);
            }
        }

        cache.insert(
            key,
            CachedResponse {
                response,
                inserted_at: Instant::now(),
                ttl: self.default_ttl,
            },
        );
    }

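    /// Removes every entry from the cache.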
    pub fn clear(&self) {
        let mut cache = self.cache.lock().unwrap();
        cache.clear();
    }

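    /// Drops all expired entries.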
    pub fn cleanup(&self) {
        let mut cache = self.cache.lock().unwrap();
        cache.retain(|_, v| !v.is_expired());
    }

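    /// Returns a snapshot of entry counts and hit/miss totals.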
    pub fn stats(&self) -> CacheStats {
        let cache = self.cache.lock().unwrap();
        let total = cache.len();
        let expired = cache.values().filter(|v| v.is_expired()).count();
        let hits = self.hits.load(Ordering::Relaxed);
        let misses = self.misses.load(Ordering::Relaxed);

        CacheStats {
            total_entries: total,
            expired_entries: expired,
            active_entries: total - expired,
            hits,
            misses,
        }
    }

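    /// Resets the hit and miss counters to zero.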
    pub fn reset_stats(&self) {
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
    }
}

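/// Point-in-time cache statistics returned by [`LlmCache::stats`].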
#[derive(Debug, Clone)]
pub struct CacheStats {
    pub total_entries: usize,
    pub expired_entries: usize,
    pub active_entries: usize,
    pub hits: u64,
    pub misses: u64,
}

impl CacheStats {
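    /// Hit rate as a percentage of all lookups, or 0.0 if there were none.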
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            (self.hits as f64 / total as f64) * 100.0
        }
    }
}

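/// Wraps an [`LlmProvider`] and serves repeated identical requests from an
/// [`LlmCache`] instead of calling the underlying provider again.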
pub struct CachedProvider<P> {
    inner: P,
    cache: LlmCache,
    model_name: String,
}

impl<P> CachedProvider<P> {
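    /// Wraps `provider` with a freshly created default cache.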
    pub fn new(provider: P, model_name: String) -> Self {
        Self {
            inner: provider,
            cache: LlmCache::new(),
            model_name,
        }
    }

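    /// Wraps `provider` with a caller-supplied cache, e.g. one shared with
    /// other providers via `Clone`.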
    pub fn with_cache(provider: P, model_name: String, cache: LlmCache) -> Self {
        Self {
            inner: provider,
            cache,
            model_name,
        }
    }

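    /// Returns a reference to the wrapped provider.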
    pub fn inner(&self) -> &P {
        &self.inner
    }

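    /// Returns a mutable reference to the wrapped provider.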
    pub fn inner_mut(&mut self) -> &mut P {
        &mut self.inner
    }

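    /// Returns a reference to the underlying cache.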
    pub fn cache(&self) -> &LlmCache {
        &self.cache
    }

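    /// Returns the current cache statistics.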
    pub fn stats(&self) -> CacheStats {
        self.cache.stats()
    }
}

#[async_trait]
impl<P: LlmProvider> LlmProvider for CachedProvider<P> {
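    /// Checks the cache first; on a miss, delegates to the inner provider
    /// and stores the response for future identical requests.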
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        if let Some(cached) = self.cache.get(&request, &self.model_name) {
            tracing::debug!(
                model = %self.model_name,
                "Cache hit for LLM request"
            );
            return Ok(cached);
        }

        let response = self.inner.complete(request.clone()).await?;

        self.cache.put(&request, &self.model_name, response.clone());

        Ok(response)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Usage;

    #[test]
    fn test_cache_hit() {
        let cache = LlmCache::new();

        let request = LlmRequest {
            prompt: "Hello".to_string(),
            system_prompt: None,
            temperature: Some(0.7),
            max_tokens: Some(100),
            tools: Vec::new(),
            images: Vec::new(),
        };

        let response = LlmResponse {
            content: "Hi there!".to_string(),
            model: "gpt-4".to_string(),
            usage: Some(Usage {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
            }),
            tool_calls: Vec::new(),
        };

        assert!(cache.get(&request, "gpt-4").is_none());

        cache.put(&request, "gpt-4", response.clone());

        let cached = cache.get(&request, "gpt-4").unwrap();
        assert_eq!(cached.content, response.content);
    }

    #[test]
    fn test_cache_expiration() {
        let cache = LlmCache::with_config(Duration::from_millis(500), 100);

        let request = LlmRequest {
            prompt: "Hello".to_string(),
            system_prompt: None,
            temperature: Some(0.7),
            max_tokens: Some(100),
            tools: Vec::new(),
            images: Vec::new(),
        };

        let response = LlmResponse {
            content: "Hi there!".to_string(),
            model: "gpt-4".to_string(),
            usage: None,
            tool_calls: Vec::new(),
        };

        cache.put(&request, "gpt-4", response);

        assert!(cache.get(&request, "gpt-4").is_some());

        std::thread::sleep(Duration::from_millis(600));

        assert!(cache.get(&request, "gpt-4").is_none());
    }

    #[test]
    fn test_cache_max_size() {
        let cache = LlmCache::with_config(Duration::from_secs(3600), 3);

        for i in 0..5 {
            let request = LlmRequest {
                prompt: format!("Prompt {}", i),
                system_prompt: None,
                temperature: Some(0.7),
                max_tokens: Some(100),
                tools: Vec::new(),
                images: Vec::new(),
            };

            let response = LlmResponse {
                content: format!("Response {}", i),
                model: "gpt-4".to_string(),
                usage: None,
                tool_calls: Vec::new(),
            };

            cache.put(&request, "gpt-4", response);
        }

        let stats = cache.stats();
        assert_eq!(stats.total_entries, 3);
    }

    #[test]
    fn test_cache_cleanup() {
        let cache = LlmCache::with_config(Duration::from_millis(10), 100);

        for i in 0..5 {
            let request = LlmRequest {
                prompt: format!("Prompt {}", i),
                system_prompt: None,
                temperature: Some(0.7),
                max_tokens: Some(100),
                tools: Vec::new(),
                images: Vec::new(),
            };

            let response = LlmResponse {
                content: format!("Response {}", i),
                model: "gpt-4".to_string(),
                usage: None,
                tool_calls: Vec::new(),
            };

            cache.put(&request, "gpt-4", response);
        }

        assert_eq!(cache.stats().total_entries, 5);

        std::thread::sleep(Duration::from_millis(20));

        cache.cleanup();

        assert_eq!(cache.stats().total_entries, 0);
    }

    #[test]
    fn test_cache_hit_rate() {
        let cache = LlmCache::new();

        let request = LlmRequest {
            prompt: "Hello".to_string(),
            system_prompt: None,
            temperature: Some(0.7),
            max_tokens: Some(100),
            tools: Vec::new(),
            images: Vec::new(),
        };

        let response = LlmResponse {
            content: "Hi there!".to_string(),
            model: "gpt-4".to_string(),
            usage: None,
            tool_calls: Vec::new(),
        };

        cache.get(&request, "gpt-4");
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().hit_rate(), 0.0);

        cache.put(&request, "gpt-4", response);

        cache.get(&request, "gpt-4");
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hit_rate(), 50.0);

        cache.get(&request, "gpt-4");
        assert_eq!(cache.stats().hits, 2);
        assert_eq!(cache.stats().misses, 1);

        let hit_rate = cache.stats().hit_rate();
        assert!(hit_rate > 66.0 && hit_rate < 67.0);

        cache.reset_stats();
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
    }
}