use std::collections::HashSet;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use async_trait::async_trait;
use dashmap::DashMap;
use futures::Stream;
use sha2::{Digest, Sha256};

use crate::error::Result;
use crate::provider::Provider;
use crate::types::{CompletionRequest, CompletionResponse, StreamChunk};

/// Configuration for response caching.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Whether caching is enabled.
    pub enabled: bool,
    /// How long a cached response remains valid.
    pub ttl: Duration,
    /// Maximum number of entries kept before the oldest are evicted.
    pub max_entries: usize,
    /// Whether streaming responses should also be cached.
    pub cache_streaming: bool,
    /// Request fields that should not contribute to the cache key
    /// (sampling parameters by default).
    pub exclude_fields: HashSet<String>,
}

impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            ttl: Duration::from_secs(3600),
            max_entries: 10_000,
            cache_streaming: false,
            exclude_fields: HashSet::from_iter([
                "temperature".to_string(),
                "top_p".to_string(),
                "top_k".to_string(),
                "seed".to_string(),
            ]),
        }
    }
}

impl CacheConfig {
    /// Creates a configuration with default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the time-to-live for cached entries.
    pub fn with_ttl(mut self, ttl: Duration) -> Self {
        self.ttl = ttl;
        self
    }

    /// Sets the maximum number of cached entries.
    pub fn with_max_entries(mut self, max_entries: usize) -> Self {
        self.max_entries = max_entries;
        self
    }

    /// Enables or disables caching.
    pub fn with_enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Enables or disables caching of streaming responses.
    pub fn with_cache_streaming(mut self, cache_streaming: bool) -> Self {
        self.cache_streaming = cache_streaming;
        self
    }
}

/// A completion response stored in the cache, along with bookkeeping metadata.
#[derive(Debug, Clone)]
pub struct CachedResponse {
    /// The cached completion response.
    pub response: CompletionResponse,
    /// When the entry was created.
    pub created_at: SystemTime,
    /// How many times this entry has been served from the cache.
    pub hit_count: Arc<AtomicU64>,
}

impl CachedResponse {
    /// Wraps a response, stamping it with the current time.
    pub fn new(response: CompletionResponse) -> Self {
        Self {
            response,
            created_at: SystemTime::now(),
            hit_count: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Returns true if the entry is older than `ttl`; a clock error counts as expired.
    pub fn is_expired(&self, ttl: Duration) -> bool {
        self.created_at
            .elapsed()
            .map(|elapsed| elapsed > ttl)
            .unwrap_or(true)
    }

    /// Increments the hit counter and returns the new count.
    pub fn record_hit(&self) -> u64 {
        self.hit_count.fetch_add(1, Ordering::Relaxed) + 1
    }
}

/// Aggregate statistics for a cache backend.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of cache hits.
    pub hits: u64,
    /// Number of cache misses.
    pub misses: u64,
    /// Number of entries currently stored.
    pub entries: usize,
    /// Approximate size of the cache in bytes, if tracked.
    pub size_bytes: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, or 0.0 if there have been no lookups.
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}

/// Storage backend for cached completion responses.
#[async_trait]
pub trait CacheBackend: Send + Sync {
    /// Looks up a cached response by key.
    async fn get(&self, key: &str) -> Option<CachedResponse>;

    /// Stores a response under the given key.
    async fn set(&self, key: &str, response: CachedResponse);

    /// Removes a single entry.
    async fn invalidate(&self, key: &str);

    /// Removes all entries.
    async fn clear(&self);

    /// Returns current cache statistics.
    fn stats(&self) -> CacheStats;
}

/// Thread-safe in-memory cache backend backed by a `DashMap`.
pub struct InMemoryCache {
    entries: DashMap<String, CachedResponse>,
    config: CacheConfig,
    hits: AtomicU64,
    misses: AtomicU64,
}

impl InMemoryCache {
    /// Creates a new cache with the given configuration.
    pub fn new(config: CacheConfig) -> Arc<Self> {
        Arc::new(Self {
            entries: DashMap::new(),
            config,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        })
    }

    /// Creates a cache with the default configuration.
    pub fn default_cache() -> Arc<Self> {
        Self::new(CacheConfig::default())
    }

    /// Removes all entries whose TTL has elapsed.
    pub fn evict_expired(&self) {
        let ttl = self.config.ttl;
        self.entries.retain(|_, v| !v.is_expired(ttl));
    }

    /// Once the cache reaches its configured capacity, evicts the oldest
    /// `max_entries / 10` entries to make room for new ones.
    fn evict_if_needed(&self) {
        if self.entries.len() >= self.config.max_entries {
            let mut oldest_keys: Vec<(String, SystemTime)> = self
                .entries
                .iter()
                .map(|e| (e.key().clone(), e.value().created_at))
                .collect();

            oldest_keys.sort_by(|a, b| a.1.cmp(&b.1));

            let to_remove = self.config.max_entries / 10;
            for (key, _) in oldest_keys.into_iter().take(to_remove) {
                self.entries.remove(&key);
            }
        }
    }
}

#[async_trait]
impl CacheBackend for InMemoryCache {
    async fn get(&self, key: &str) -> Option<CachedResponse> {
        // Look up the entry and drop the map guard before any removal: calling
        // `remove` while still holding a reference into the DashMap can deadlock.
        let lookup = self.entries.get(key).map(|entry| {
            if entry.is_expired(self.config.ttl) {
                None
            } else {
                entry.record_hit();
                Some(entry.value().clone())
            }
        });

        match lookup {
            Some(Some(response)) => {
                self.hits.fetch_add(1, Ordering::Relaxed);
                Some(response)
            }
            Some(None) => {
                // The entry exists but has expired: drop it and count a miss.
                self.entries.remove(key);
                self.misses.fetch_add(1, Ordering::Relaxed);
                None
            }
            None => {
                self.misses.fetch_add(1, Ordering::Relaxed);
                None
            }
        }
    }

    async fn set(&self, key: &str, response: CachedResponse) {
        self.evict_if_needed();
        self.entries.insert(key.to_string(), response);
    }

    async fn invalidate(&self, key: &str) {
        self.entries.remove(key);
    }

    async fn clear(&self) {
        self.entries.clear();
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
    }

    fn stats(&self) -> CacheStats {
        CacheStats {
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            entries: self.entries.len(),
            size_bytes: 0, // byte size is not tracked by the in-memory backend
        }
    }
}

/// Provider wrapper that caches completion responses from an inner provider.
pub struct CachingProvider<P> {
    /// The wrapped provider.
    inner: P,
    /// Backend used to store cached responses.
    cache: Arc<dyn CacheBackend>,
    /// Cache behaviour configuration.
    config: CacheConfig,
}

impl<P> CachingProvider<P> {
    /// Wraps `inner` with caching using the default configuration.
    pub fn new(inner: P, cache: Arc<dyn CacheBackend>) -> Self {
        Self {
            inner,
            cache,
            config: CacheConfig::default(),
        }
    }

    /// Wraps `inner` with caching using a custom configuration.
    pub fn with_config(inner: P, cache: Arc<dyn CacheBackend>, config: CacheConfig) -> Self {
        Self {
            inner,
            cache,
            config,
        }
    }

    /// Returns a reference to the wrapped provider.
    pub fn inner(&self) -> &P {
        &self.inner
    }

    /// Returns statistics from the cache backend.
    pub fn stats(&self) -> CacheStats {
        self.cache.stats()
    }

    /// Clears all cached responses.
    pub async fn clear_cache(&self) {
        self.cache.clear().await;
    }

    /// Derives a deterministic cache key by hashing the request's model, system
    /// prompt, messages, tools, and response format. Sampling parameters such as
    /// temperature are not hashed, so they do not affect the key.
    fn compute_cache_key(&self, request: &CompletionRequest) -> String {
        let mut hasher = Sha256::new();

        hasher.update(request.model.as_bytes());
        hasher.update(b"|");

        if let Some(ref system) = request.system {
            hasher.update(system.as_bytes());
        }
        hasher.update(b"|");

        for msg in &request.messages {
            hasher.update(format!("{:?}", msg.role).as_bytes());
            hasher.update(b":");
            for block in &msg.content {
                hasher.update(format!("{:?}", block).as_bytes());
            }
            hasher.update(b";");
        }
        hasher.update(b"|");

        if let Some(ref tools) = request.tools {
            for tool in tools {
                hasher.update(tool.name.as_bytes());
                hasher.update(b":");
                hasher.update(tool.description.as_bytes());
                hasher.update(b";");
            }
        }
        hasher.update(b"|");

        if let Some(ref format) = request.response_format {
            hasher.update(format!("{:?}", format.format_type).as_bytes());
        }

        format!("cache:{}", hex::encode(hasher.finalize()))
    }
}

#[async_trait]
impl<P: Provider> Provider for CachingProvider<P> {
    fn name(&self) -> &str {
        self.inner.name()
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        if !self.config.enabled {
            return self.inner.complete(request).await;
        }

        let cache_key = self.compute_cache_key(&request);

        if let Some(cached) = self.cache.get(&cache_key).await {
            tracing::debug!(key = %cache_key, "Cache hit");
            return Ok(cached.response);
        }

        tracing::debug!(key = %cache_key, "Cache miss");
        let response = self.inner.complete(request).await?;

        let cached = CachedResponse::new(response.clone());
        self.cache.set(&cache_key, cached).await;

        Ok(response)
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>> {
        // Streaming responses are passed through uncached.
        self.inner.complete_stream(request).await
    }

    fn supports_tools(&self) -> bool {
        self.inner.supports_tools()
    }

    fn supports_vision(&self) -> bool {
        self.inner.supports_vision()
    }

    fn supports_streaming(&self) -> bool {
        self.inner.supports_streaming()
    }

    fn supported_models(&self) -> Option<&[&str]> {
        self.inner.supported_models()
    }

    fn default_model(&self) -> Option<&str> {
        self.inner.default_model()
    }
}

/// Helper for building cache keys from arbitrary string parts.
#[derive(Default)]
pub struct CacheKeyBuilder {
    parts: Vec<String>,
}

impl CacheKeyBuilder {
    /// Creates an empty builder.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a part to the key material.
    pub fn with_part(mut self, part: impl Into<String>) -> Self {
        self.parts.push(part.into());
        self
    }

    /// Hashes all parts and returns a `cache:`-prefixed hex key.
    pub fn build(self) -> String {
        let mut hasher = Sha256::new();
        for part in self.parts {
            hasher.update(part.as_bytes());
            hasher.update(b"|");
        }
        format!("cache:{}", hex::encode(hasher.finalize()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert!(config.enabled);
        assert_eq!(config.ttl, Duration::from_secs(3600));
        assert_eq!(config.max_entries, 10_000);
        assert!(!config.cache_streaming);
    }

    #[test]
    fn test_cache_config_builder() {
        let config = CacheConfig::new()
            .with_ttl(Duration::from_secs(600))
            .with_max_entries(1000)
            .with_enabled(false);

        assert!(!config.enabled);
        assert_eq!(config.ttl, Duration::from_secs(600));
        assert_eq!(config.max_entries, 1000);
    }

    #[test]
    fn test_cached_response_expiry() {
        let response = CompletionResponse {
            id: "test".to_string(),
            model: "test".to_string(),
            content: vec![],
            stop_reason: crate::types::StopReason::EndTurn,
            usage: crate::types::Usage::default(),
        };

        let cached = CachedResponse::new(response);

        assert!(!cached.is_expired(Duration::from_secs(3600)));

        assert!(cached.is_expired(Duration::from_secs(0)));
    }
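
    // Additional sketch of a test for the hit counter: record_hit returns the
    // running count, and clones share the same Arc-backed counter.
    #[test]
    fn test_cached_response_hit_count() {
        let response = CompletionResponse {
            id: "test".to_string(),
            model: "test".to_string(),
            content: vec![],
            stop_reason: crate::types::StopReason::EndTurn,
            usage: crate::types::Usage::default(),
        };

        let cached = CachedResponse::new(response);
        assert_eq!(cached.record_hit(), 1);
        assert_eq!(cached.record_hit(), 2);

        // A clone shares the counter rather than starting from zero.
        let clone = cached.clone();
        assert_eq!(clone.record_hit(), 3);
    }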

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 80,
            misses: 20,
            entries: 100,
            size_bytes: 0,
        };

        assert!((stats.hit_rate() - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_cache_stats_hit_rate_zero() {
        let stats = CacheStats::default();
        assert_eq!(stats.hit_rate(), 0.0);
    }

    #[tokio::test]
    async fn test_in_memory_cache() {
        let cache = InMemoryCache::new(CacheConfig::default());

        let response = CompletionResponse {
            id: "test".to_string(),
            model: "test".to_string(),
            content: vec![],
            stop_reason: crate::types::StopReason::EndTurn,
            usage: crate::types::Usage::default(),
        };

        // A lookup before insertion is a miss.
        assert!(cache.get("key1").await.is_none());

        // After insertion the same key is a hit.
        cache.set("key1", CachedResponse::new(response)).await;
        assert!(cache.get("key1").await.is_some());

        // One hit and one miss should have been recorded.
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.entries, 1);

        // Invalidation removes the entry.
        cache.invalidate("key1").await;
        assert!(cache.get("key1").await.is_none());

        // Clearing the cache removes all remaining entries.
        cache
            .set(
                "key2",
                CachedResponse::new(CompletionResponse {
                    id: "test2".to_string(),
                    model: "test".to_string(),
                    content: vec![],
                    stop_reason: crate::types::StopReason::EndTurn,
                    usage: crate::types::Usage::default(),
                }),
            )
            .await;
        cache.clear().await;
        assert_eq!(cache.stats().entries, 0);
    }
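
    // Sketch of a test for TTL-based eviction, mirroring the zero-TTL assumption
    // used in test_cached_response_expiry above: with a zero TTL every entry is
    // treated as expired, so evict_expired should drop it.
    #[tokio::test]
    async fn test_in_memory_cache_evict_expired() {
        let cache = InMemoryCache::new(CacheConfig::new().with_ttl(Duration::from_secs(0)));

        cache
            .set(
                "key",
                CachedResponse::new(CompletionResponse {
                    id: "test".to_string(),
                    model: "test".to_string(),
                    content: vec![],
                    stop_reason: crate::types::StopReason::EndTurn,
                    usage: crate::types::Usage::default(),
                }),
            )
            .await;
        assert_eq!(cache.stats().entries, 1);

        // Give the clock a moment so the entry's age is strictly greater than zero.
        std::thread::sleep(Duration::from_millis(5));
        cache.evict_expired();
        assert_eq!(cache.stats().entries, 0);
    }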

    #[test]
    fn test_cache_key_builder() {
        let key = CacheKeyBuilder::new()
            .with_part("model")
            .with_part("prompt")
            .build();

        assert!(key.starts_with("cache:"));
        // "cache:" prefix (6 chars) plus 64 hex characters of SHA-256 output.
        assert_eq!(key.len(), 6 + 64);
    }
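
    // Sketch of a capacity test: with max_entries set to 10, inserting well past
    // the limit should trigger eviction of the oldest entries on each set, so the
    // entry count stays bounded by the configured maximum.
    #[tokio::test]
    async fn test_in_memory_cache_capacity_eviction() {
        let cache = InMemoryCache::new(CacheConfig::new().with_max_entries(10));

        for i in 0..20 {
            cache
                .set(
                    &format!("key{i}"),
                    CachedResponse::new(CompletionResponse {
                        id: format!("test{i}"),
                        model: "test".to_string(),
                        content: vec![],
                        stop_reason: crate::types::StopReason::EndTurn,
                        usage: crate::types::Usage::default(),
                    }),
                )
                .await;
        }

        assert!(cache.stats().entries <= 10);
    }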
}