1use crate::error::Result;
20use crate::traits::{ChatMessage, CompletionOptions, EmbeddingProvider, LLMProvider, LLMResponse};
21use async_trait::async_trait;
22use std::collections::HashMap;
23use std::hash::{Hash, Hasher};
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26use tokio::sync::RwLock;
27
/// Configuration for [`LLMCache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries per cache map (the completion map and
    /// the embedding map are each capped at this size independently).
    pub max_entries: usize,
    /// How long an entry stays valid; expired entries are removed
    /// lazily on the next lookup.
    pub ttl: Duration,
    /// Whether completion responses are cached.
    pub cache_completions: bool,
    /// Whether embedding results are cached.
    pub cache_embeddings: bool,
}
40
41impl Default for CacheConfig {
42 fn default() -> Self {
43 Self {
44 max_entries: 1000,
45 ttl: Duration::from_secs(3600), cache_completions: true,
47 cache_embeddings: true,
48 }
49 }
50}
51
52impl CacheConfig {
53 pub fn new(max_entries: usize) -> Self {
55 Self {
56 max_entries,
57 ..Default::default()
58 }
59 }
60
61 pub fn with_ttl(mut self, ttl: Duration) -> Self {
63 self.ttl = ttl;
64 self
65 }
66
67 pub fn with_completion_caching(mut self, enabled: bool) -> Self {
69 self.cache_completions = enabled;
70 self
71 }
72
73 pub fn with_embedding_caching(mut self, enabled: bool) -> Self {
75 self.cache_embeddings = enabled;
76 self
77 }
78}
79
/// A cached value with the bookkeeping needed for TTL expiry and
/// least-used eviction.
#[derive(Debug, Clone)]
struct CacheEntry<T> {
    /// The cached payload.
    value: T,
    /// Insertion time; compared against the configured TTL.
    created_at: Instant,
    /// How many times `access` has returned this entry; the lowest
    /// count is evicted first when the map is full.
    access_count: usize,
}
87
88impl<T: Clone> CacheEntry<T> {
89 fn new(value: T) -> Self {
90 Self {
91 value,
92 created_at: Instant::now(),
93 access_count: 0,
94 }
95 }
96
97 fn is_expired(&self, ttl: Duration) -> bool {
98 self.created_at.elapsed() > ttl
99 }
100
101 fn access(&mut self) -> T {
102 self.access_count += 1;
103 self.value.clone()
104 }
105}
106
/// Cache key that stores only a 64-bit hash of the input text(s).
///
/// NOTE(review): because only the hash is kept, two different inputs
/// that collide would be treated as the same key and could serve the
/// wrong cached value. Likely acceptable for a best-effort cache, but
/// worth confirming for this use case.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct CacheKey {
    hash: u64,
}
112
113impl CacheKey {
114 fn from_prompt(prompt: &str) -> Self {
115 let mut hasher = std::collections::hash_map::DefaultHasher::new();
116 prompt.hash(&mut hasher);
117 Self {
118 hash: hasher.finish(),
119 }
120 }
121
122 fn from_texts(texts: &[&str]) -> Self {
123 let mut hasher = std::collections::hash_map::DefaultHasher::new();
124 for text in texts {
125 text.hash(&mut hasher);
126 }
127 Self {
128 hash: hasher.finish(),
129 }
130 }
131}
132
/// Counters describing cache effectiveness.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups answered from the cache.
    pub hits: usize,
    /// Lookups that found nothing (an expired entry also counts as a miss).
    pub misses: usize,
    /// Live entries across both maps; computed on demand by `LLMCache::stats`.
    pub entries: usize,
    /// Entries removed by TTL expiry, capacity eviction, or `clear`.
    pub evictions: usize,
}
145
146impl CacheStats {
147 pub fn hit_rate(&self) -> f64 {
149 let total = self.hits + self.misses;
150 if total == 0 {
151 0.0
152 } else {
153 self.hits as f64 / total as f64
154 }
155 }
156}
157
/// In-memory cache for LLM completion and embedding results with lazy
/// TTL expiry and least-used eviction, guarded by async `RwLock`s.
pub struct LLMCache {
    config: CacheConfig,
    /// Prompt-hash -> cached completion response.
    completions: RwLock<HashMap<CacheKey, CacheEntry<LLMResponse>>>,
    /// Texts-hash -> cached embedding vectors.
    embeddings: RwLock<HashMap<CacheKey, CacheEntry<Vec<Vec<f32>>>>>,
    /// Hit/miss/eviction counters; `entries` is filled in on demand.
    stats: RwLock<CacheStats>,
}
166
167impl LLMCache {
168 pub fn new(config: CacheConfig) -> Self {
170 Self {
171 config,
172 completions: RwLock::new(HashMap::new()),
173 embeddings: RwLock::new(HashMap::new()),
174 stats: RwLock::new(CacheStats::default()),
175 }
176 }
177
178 pub async fn stats(&self) -> CacheStats {
180 let stats = self.stats.read().await;
181 let completions = self.completions.read().await;
182 let embeddings = self.embeddings.read().await;
183
184 CacheStats {
185 entries: completions.len() + embeddings.len(),
186 ..*stats
187 }
188 }
189
190 pub async fn clear(&self) {
192 let mut completions = self.completions.write().await;
193 let mut embeddings = self.embeddings.write().await;
194 let mut stats = self.stats.write().await;
195
196 let evicted = completions.len() + embeddings.len();
197 completions.clear();
198 embeddings.clear();
199 stats.evictions += evicted;
200 }
201
202 pub async fn get_completion(&self, prompt: &str) -> Option<LLMResponse> {
204 if !self.config.cache_completions {
205 return None;
206 }
207
208 let key = CacheKey::from_prompt(prompt);
209 let mut cache = self.completions.write().await;
210
211 if let Some(entry) = cache.get_mut(&key) {
212 if entry.is_expired(self.config.ttl) {
213 cache.remove(&key);
214 let mut stats = self.stats.write().await;
215 stats.misses += 1;
216 stats.evictions += 1;
217 return None;
218 }
219
220 let mut stats = self.stats.write().await;
221 stats.hits += 1;
222 return Some(entry.access());
223 }
224
225 let mut stats = self.stats.write().await;
226 stats.misses += 1;
227 None
228 }
229
230 pub async fn put_completion(&self, prompt: &str, response: LLMResponse) {
232 if !self.config.cache_completions {
233 return;
234 }
235
236 let key = CacheKey::from_prompt(prompt);
237 let mut cache = self.completions.write().await;
238
239 if cache.len() >= self.config.max_entries {
241 self.evict_lru(&mut cache).await;
242 }
243
244 cache.insert(key, CacheEntry::new(response));
245 }
246
247 pub async fn get_embeddings(&self, texts: &[&str]) -> Option<Vec<Vec<f32>>> {
249 if !self.config.cache_embeddings {
250 return None;
251 }
252
253 let key = CacheKey::from_texts(texts);
254 let mut cache = self.embeddings.write().await;
255
256 if let Some(entry) = cache.get_mut(&key) {
257 if entry.is_expired(self.config.ttl) {
258 cache.remove(&key);
259 let mut stats = self.stats.write().await;
260 stats.misses += 1;
261 stats.evictions += 1;
262 return None;
263 }
264
265 let mut stats = self.stats.write().await;
266 stats.hits += 1;
267 return Some(entry.access());
268 }
269
270 let mut stats = self.stats.write().await;
271 stats.misses += 1;
272 None
273 }
274
275 pub async fn put_embeddings(&self, texts: &[&str], embeddings: Vec<Vec<f32>>) {
277 if !self.config.cache_embeddings {
278 return;
279 }
280
281 let key = CacheKey::from_texts(texts);
282 let mut cache = self.embeddings.write().await;
283
284 if cache.len() >= self.config.max_entries {
286 self.evict_lru_embeddings(&mut cache).await;
287 }
288
289 cache.insert(key, CacheEntry::new(embeddings));
290 }
291
292 async fn evict_lru<T: Clone>(&self, cache: &mut HashMap<CacheKey, CacheEntry<T>>) {
293 if let Some(key) = cache
295 .iter()
296 .min_by_key(|(_, entry)| (entry.access_count, entry.created_at))
297 .map(|(k, _)| k.clone())
298 {
299 cache.remove(&key);
300 let mut stats = self.stats.write().await;
301 stats.evictions += 1;
302 }
303 }
304
305 async fn evict_lru_embeddings(&self, cache: &mut HashMap<CacheKey, CacheEntry<Vec<Vec<f32>>>>) {
306 if let Some(key) = cache
307 .iter()
308 .min_by_key(|(_, entry)| (entry.access_count, entry.created_at))
309 .map(|(k, _)| k.clone())
310 {
311 cache.remove(&key);
312 let mut stats = self.stats.write().await;
313 stats.evictions += 1;
314 }
315 }
316}
317
/// Decorator that adds transparent caching in front of a provider.
pub struct CachedProvider<P> {
    /// The wrapped provider that performs the real requests.
    inner: P,
    /// Shared cache; `Arc` so multiple providers can reuse one cache.
    cache: Arc<LLMCache>,
}
323
324impl<P> CachedProvider<P> {
325 pub fn new(inner: P, cache: Arc<LLMCache>) -> Self {
327 Self { inner, cache }
328 }
329
330 pub fn with_default_cache(inner: P) -> Self {
332 Self {
333 inner,
334 cache: Arc::new(LLMCache::new(CacheConfig::default())),
335 }
336 }
337
338 pub async fn cache_stats(&self) -> CacheStats {
340 self.cache.stats().await
341 }
342
343 pub async fn clear_cache(&self) {
345 self.cache.clear().await;
346 }
347}
348
349#[async_trait]
350impl<P: LLMProvider> LLMProvider for CachedProvider<P> {
351 fn name(&self) -> &str {
352 self.inner.name()
353 }
354
355 fn model(&self) -> &str {
356 self.inner.model()
357 }
358
359 fn max_context_length(&self) -> usize {
360 self.inner.max_context_length()
361 }
362
363 async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
364 if let Some(cached) = self.cache.get_completion(prompt).await {
366 tracing::debug!("Cache hit for completion");
367 return Ok(cached);
368 }
369
370 let response = self.inner.complete(prompt).await?;
372
373 self.cache.put_completion(prompt, response.clone()).await;
375
376 Ok(response)
377 }
378
379 async fn complete_with_options(
380 &self,
381 prompt: &str,
382 options: &CompletionOptions,
383 ) -> Result<LLMResponse> {
384 self.inner.complete_with_options(prompt, options).await
386 }
387
388 async fn chat(
389 &self,
390 messages: &[ChatMessage],
391 options: Option<&CompletionOptions>,
392 ) -> Result<LLMResponse> {
393 self.inner.chat(messages, options).await
395 }
396}
397
398#[async_trait]
399impl<P: EmbeddingProvider> EmbeddingProvider for CachedProvider<P> {
400 fn name(&self) -> &str {
401 self.inner.name()
402 }
403
404 fn model(&self) -> &str {
405 self.inner.model()
406 }
407
408 fn dimension(&self) -> usize {
409 self.inner.dimension()
410 }
411
412 fn max_tokens(&self) -> usize {
413 self.inner.max_tokens()
414 }
415
416 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
417 let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
419
420 if let Some(cached) = self.cache.get_embeddings(&text_refs).await {
422 tracing::debug!("Cache hit for embeddings");
423 return Ok(cached);
424 }
425
426 let embeddings = self.inner.embed(texts).await?;
428
429 self.cache
431 .put_embeddings(&text_refs, embeddings.clone())
432 .await;
433
434 Ok(embeddings)
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_key_from_prompt() {
        let key1 = CacheKey::from_prompt("Hello world");
        let key2 = CacheKey::from_prompt("Hello world");
        let key3 = CacheKey::from_prompt("Different prompt");

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_cache_key_from_texts() {
        let key1 = CacheKey::from_texts(&["a", "b", "c"]);
        let key2 = CacheKey::from_texts(&["a", "b", "c"]);
        let key3 = CacheKey::from_texts(&["x", "y", "z"]);

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.max_entries, 1000);
        assert!(config.cache_completions);
        assert!(config.cache_embeddings);
    }

    #[test]
    fn test_cache_config_builder() {
        let config = CacheConfig::new(500)
            .with_ttl(Duration::from_secs(600))
            .with_completion_caching(false);

        assert_eq!(config.max_entries, 500);
        assert_eq!(config.ttl, Duration::from_secs(600));
        assert!(!config.cache_completions);
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = LLMCache::new(CacheConfig::default());
        let stats = cache.stats().await;

        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.entries, 0);
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = LLMCache::new(CacheConfig::default());
        let result = cache.get_completion("test prompt").await;

        assert!(result.is_none());

        let stats = cache.stats().await;
        assert_eq!(stats.misses, 1);
    }

    #[tokio::test]
    async fn test_cache_hit() {
        let cache = LLMCache::new(CacheConfig::default());

        let response = LLMResponse::new("test response", "gpt-4").with_usage(10, 5);

        cache.put_completion("test prompt", response.clone()).await;
        let result = cache.get_completion("test prompt").await;

        assert!(result.is_some());
        assert_eq!(result.unwrap().content, "test response");

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let cache = LLMCache::new(CacheConfig::default());

        let response = LLMResponse::new("test", "gpt-4").with_usage(1, 1);

        cache.put_completion("prompt", response).await;
        assert_eq!(cache.stats().await.entries, 1);

        cache.clear().await;
        assert_eq!(cache.stats().await.entries, 0);
    }

    #[test]
    fn test_hit_rate() {
        let mut stats = CacheStats::default();
        assert_eq!(stats.hit_rate(), 0.0);

        stats.hits = 3;
        stats.misses = 1;
        assert_eq!(stats.hit_rate(), 0.75);
    }

    #[tokio::test]
    async fn test_embedding_cache() {
        let cache = LLMCache::new(CacheConfig::default());

        let embeddings = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let texts = ["text1", "text2"];

        cache.put_embeddings(&texts, embeddings.clone()).await;
        let result = cache.get_embeddings(&texts).await;

        assert!(result.is_some());
        assert_eq!(result.unwrap(), embeddings);
    }

    #[tokio::test]
    async fn test_disabled_caching() {
        let config = CacheConfig::default()
            .with_completion_caching(false)
            .with_embedding_caching(false);
        let cache = LLMCache::new(config);

        let response = LLMResponse::new("test", "gpt-4").with_usage(1, 1);

        cache.put_completion("prompt", response).await;
        assert!(cache.get_completion("prompt").await.is_none());

        cache.put_embeddings(&["text"], vec![vec![1.0]]).await;
        assert!(cache.get_embeddings(&["text"]).await.is_none());
    }

    #[tokio::test]
    async fn test_ttl_expiration_completion() {
        let config = CacheConfig::new(100).with_ttl(Duration::from_millis(1));
        let cache = LLMCache::new(config);

        let response = LLMResponse::new("expires", "gpt-4").with_usage(5, 3);
        cache.put_completion("ephemeral", response).await;

        tokio::time::sleep(Duration::from_millis(10)).await;

        let result = cache.get_completion("ephemeral").await;
        assert!(result.is_none(), "Expired entry should return None");

        let stats = cache.stats().await;
        assert_eq!(stats.evictions, 1, "Expired entry should count as eviction");
        assert_eq!(stats.misses, 1);
    }

    #[tokio::test]
    async fn test_ttl_expiration_embeddings() {
        let config = CacheConfig::new(100).with_ttl(Duration::from_millis(1));
        let cache = LLMCache::new(config);

        cache.put_embeddings(&["txt"], vec![vec![1.0, 2.0]]).await;

        tokio::time::sleep(Duration::from_millis(10)).await;

        assert!(cache.get_embeddings(&["txt"]).await.is_none());
        let stats = cache.stats().await;
        assert_eq!(stats.evictions, 1);
    }

    #[tokio::test]
    async fn test_lru_eviction_completions() {
        let config = CacheConfig::new(2);
        let cache = LLMCache::new(config);

        let r1 = LLMResponse::new("first", "gpt-4").with_usage(1, 1);
        let r2 = LLMResponse::new("second", "gpt-4").with_usage(1, 1);
        let r3 = LLMResponse::new("third", "gpt-4").with_usage(1, 1);

        cache.put_completion("p1", r1).await;
        cache.put_completion("p2", r2).await;

        // Touch p2 so that p1 becomes the least-used entry.
        let _ = cache.get_completion("p2").await;

        cache.put_completion("p3", r3).await;

        assert!(
            cache.get_completion("p1").await.is_none(),
            "p1 should have been evicted"
        );
        assert!(cache.get_completion("p2").await.is_some());
        assert!(cache.get_completion("p3").await.is_some());
    }

    #[tokio::test]
    async fn test_lru_eviction_embeddings() {
        let config = CacheConfig::new(1);
        let cache = LLMCache::new(config);

        cache.put_embeddings(&["a"], vec![vec![1.0]]).await;
        cache.put_embeddings(&["b"], vec![vec![2.0]]).await;

        assert!(cache.get_embeddings(&["a"]).await.is_none());
        assert!(cache.get_embeddings(&["b"]).await.is_some());
    }

    #[tokio::test]
    async fn test_access_count_increments() {
        let cache = LLMCache::new(CacheConfig::default());
        let response = LLMResponse::new("counter", "gpt-4").with_usage(1, 1);

        cache.put_completion("cnt", response).await;

        for _ in 0..3 {
            let _ = cache.get_completion("cnt").await;
        }

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 3);
    }

    // Plain #[test]: nothing async here, no need to spin up a runtime.
    #[test]
    fn test_cache_entry_is_expired() {
        let entry = CacheEntry::new("value".to_string());
        assert!(!entry.is_expired(Duration::from_secs(3600)));
        // Let the clock tick so `elapsed() > 0` holds even on platforms
        // with coarse timers (avoids a flaky assertion).
        std::thread::sleep(Duration::from_millis(1));
        assert!(entry.is_expired(Duration::ZERO));
    }

    #[tokio::test]
    async fn test_cached_provider_complete_delegates() {
        use crate::providers::MockProvider;

        let mock = MockProvider::new();
        mock.add_response("cached answer").await;

        let cache = Arc::new(LLMCache::new(CacheConfig::default()));
        let provider = CachedProvider::new(mock, cache);

        // Call through the wrapper (not `provider.inner`) so the
        // caching layer itself is exercised.
        let r1 = provider.complete("hello").await.unwrap();
        assert_eq!(r1.content, "cached answer");

        // The second identical call must be served from the cache.
        let r2 = provider.complete("hello").await.unwrap();
        assert_eq!(r2.content, "cached answer");
        assert_eq!(provider.cache_stats().await.hits, 1);
    }

    #[tokio::test]
    async fn test_cached_provider_name_model_delegates() {
        use crate::providers::MockProvider;

        let mock = MockProvider::new();
        let cache = Arc::new(LLMCache::new(CacheConfig::default()));
        let provider = CachedProvider::new(mock, cache);

        assert_eq!(LLMProvider::name(&provider), "mock");
        assert_eq!(LLMProvider::model(&provider), "mock-model");
        assert_eq!(provider.max_context_length(), 4096);
    }

    #[tokio::test]
    async fn test_cached_provider_with_default_cache() {
        use crate::providers::MockProvider;

        let mock = MockProvider::new();
        let provider = CachedProvider::with_default_cache(mock);

        let stats = provider.cache_stats().await;
        assert_eq!(stats.entries, 0);
        assert_eq!(stats.hits, 0);
    }

    #[tokio::test]
    async fn test_cached_provider_clear_cache() {
        use crate::providers::MockProvider;

        let mock = MockProvider::new();
        let cache = Arc::new(LLMCache::new(CacheConfig::default()));
        let provider = CachedProvider::new(mock, cache);

        provider
            .cache
            .put_completion("test", LLMResponse::new("v", "m").with_usage(1, 1))
            .await;
        assert_eq!(provider.cache_stats().await.entries, 1);

        provider.clear_cache().await;
        assert_eq!(provider.cache_stats().await.entries, 0);
    }

    #[test]
    fn test_cache_key_empty_prompt() {
        let k1 = CacheKey::from_prompt("");
        let k2 = CacheKey::from_prompt("");
        assert_eq!(k1, k2);
    }

    #[test]
    fn test_cache_key_empty_texts() {
        let k = CacheKey::from_texts(&[]);
        let k2 = CacheKey::from_texts(&[]);
        assert_eq!(k, k2);
    }

    #[tokio::test]
    async fn test_multiple_put_same_key_overwrites() {
        let cache = LLMCache::new(CacheConfig::default());
        let r1 = LLMResponse::new("first", "m").with_usage(1, 1);
        let r2 = LLMResponse::new("second", "m").with_usage(1, 1);

        cache.put_completion("key", r1).await;
        cache.put_completion("key", r2).await;

        let result = cache.get_completion("key").await.unwrap();
        assert_eq!(result.content, "second");
        assert_eq!(cache.stats().await.entries, 1);
    }

    #[test]
    fn test_cache_config_with_embedding_caching() {
        let config = CacheConfig::default().with_embedding_caching(false);
        assert!(!config.cache_embeddings);
        assert!(config.cache_completions);
    }

    #[tokio::test]
    async fn test_clear_updates_eviction_count() {
        let cache = LLMCache::new(CacheConfig::default());
        let r = LLMResponse::new("a", "m").with_usage(1, 1);
        cache.put_completion("x", r).await;
        cache.put_embeddings(&["y"], vec![vec![1.0]]).await;

        cache.clear().await;
        let stats = cache.stats().await;
        assert_eq!(stats.evictions, 2, "Clear should count 2 evictions");
        assert_eq!(stats.entries, 0);
    }

    #[tokio::test]
    async fn test_cached_provider_embed_delegates() {
        use crate::providers::MockProvider;

        let mock = MockProvider::new();
        let cache = Arc::new(LLMCache::new(CacheConfig::default()));
        let provider = CachedProvider::new(mock, cache);

        let result = provider.embed(&["hello".to_string()]).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 1536);

        // Second call with the same input must come from the cache.
        let result2 = provider.embed(&["hello".to_string()]).await.unwrap();
        assert_eq!(result2, result);

        let stats = provider.cache_stats().await;
        assert_eq!(stats.hits, 1);
    }
}