//! Response caching for LLM chat requests.
//!
//! An in-memory, TTL-bounded cache with LRU eviction, keyed on request
//! content, plus a request deduplicator that lets concurrent identical
//! requests share one in-flight provider call.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

use super::types::{ChatRequest, ChatResponse};

/// Configuration for the LLM response cache.
#[derive(Debug, Clone)]
pub struct LlmCacheConfig {
    /// Maximum number of entries before LRU eviction kicks in.
    pub max_entries: usize,
    /// Time-to-live for each entry.
    pub ttl: Duration,
    /// Whether semantic (similarity-based) matching is enabled.
    pub semantic_matching: bool,
    /// Similarity threshold used when semantic matching is enabled.
    pub similarity_threshold: f64,
    /// Requests whose estimated prompt size exceeds this are not cached.
    pub max_prompt_tokens: usize,
}

impl Default for LlmCacheConfig {
    fn default() -> Self {
        Self {
            max_entries: 1000,
            ttl: Duration::from_secs(3600),
            semantic_matching: false,
            similarity_threshold: 0.95,
            max_prompt_tokens: 2000,
        }
    }
}

impl LlmCacheConfig {
    /// A small, short-lived cache suitable for development.
    #[must_use]
    pub fn development() -> Self {
        Self {
            max_entries: 100,
            ttl: Duration::from_secs(300),
            semantic_matching: false,
            similarity_threshold: 0.95,
            max_prompt_tokens: 1000,
        }
    }

    /// A larger, long-lived cache suitable for production.
    #[must_use]
    pub fn production() -> Self {
        Self {
            max_entries: 10_000,
            ttl: Duration::from_secs(86_400),
            semantic_matching: false,
            similarity_threshold: 0.95,
            max_prompt_tokens: 4000,
        }
    }
}

/// A single cached response plus bookkeeping metadata.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached response payload.
    response: CachedResponse,
    /// When this entry was inserted (used for TTL expiry).
    created_at: Instant,
    /// How many times this entry has been served.
    hit_count: u64,
    /// When this entry was last read (used for LRU eviction).
    last_accessed: Instant,
}

/// The subset of a `ChatResponse` that is stored in the cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResponse {
    /// The assistant's reply text.
    pub content: String,
    /// Why generation stopped, if the provider reported it.
    pub finish_reason: Option<String>,
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion.
    pub completion_tokens: u32,
}

impl From<&ChatResponse> for CachedResponse {
    fn from(response: &ChatResponse) -> Self {
        Self {
            content: response.message.content.clone(),
            finish_reason: response.finish_reason.clone(),
            prompt_tokens: response.prompt_tokens,
            completion_tokens: response.completion_tokens,
        }
    }
}

/// Cache key derived from the parts of a request that affect the response.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Hash over the ordered (role, content) pairs of all messages.
    messages_hash: u64,
    /// Temperature quantized to two decimal places (value x 100).
    temperature_q: u32,
    /// Requested completion-length cap, if any.
    max_tokens: Option<u32>,
}

impl CacheKey {
    fn from_request(request: &ChatRequest) -> Self {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();

        // Hash every message's role and content, in order.
        for msg in &request.messages {
            let role_str = match msg.role {
                super::types::ChatRole::System => "system",
                super::types::ChatRole::User => "user",
                super::types::ChatRole::Assistant => "assistant",
            };
            role_str.hash(&mut hasher);
            msg.content.hash(&mut hasher);
        }
        let messages_hash = hasher.finish();

        // Quantize the temperature so nearly identical floats share a key.
        let temperature_q = (request.temperature.unwrap_or(1.0) * 100.0) as u32;

        Self {
            messages_hash,
            temperature_q,
            max_tokens: request.max_tokens,
        }
    }
}

/// An in-memory, TTL-aware cache for LLM chat responses with LRU eviction.
pub struct LlmCache {
    config: LlmCacheConfig,
    cache: Arc<RwLock<HashMap<CacheKey, CacheEntry>>>,
    stats: Arc<RwLock<CacheStats>>,
}

/// Running counters describing cache effectiveness.
#[derive(Debug, Clone, Default, Serialize)]
pub struct CacheStats {
    /// Requests served from the cache.
    pub hits: u64,
    /// Requests that had to go to the provider.
    pub misses: u64,
    /// Current number of entries in the cache.
    pub entries: usize,
    /// Total prompt + completion tokens avoided by cache hits.
    pub tokens_saved: u64,
    /// Rough savings, estimated at one cent per 1000 tokens saved.
    pub estimated_cost_saved_cents: f64,
    /// Entries removed because they outlived the TTL.
    pub ttl_evictions: u64,
    /// Entries removed to stay under `max_entries`.
    pub capacity_evictions: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, or 0.0 if there were none.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}

impl LlmCache {
    /// Creates a cache with the given configuration.
    #[must_use]
    pub fn new(config: LlmCacheConfig) -> Self {
        Self {
            config,
            cache: Arc::new(RwLock::new(HashMap::new())),
            stats: Arc::new(RwLock::new(CacheStats::default())),
        }
    }

    /// Creates a cache with the default configuration.
    #[must_use]
    pub fn default_cache() -> Self {
        Self::new(LlmCacheConfig::default())
    }

    /// Looks up a cached response for `request`, expiring stale entries and
    /// updating hit/miss statistics along the way.
    pub async fn get(&self, request: &ChatRequest) -> Option<CachedResponse> {
        // Uncacheable requests bypass the cache entirely.
        if !self.is_cacheable(request) {
            return None;
        }

        let key = CacheKey::from_request(request);

        let mut cache = self.cache.write().await;

        if let Some(entry) = cache.get_mut(&key) {
            // The entry has outlived the TTL: evict it and count a miss.
            if entry.created_at.elapsed() > self.config.ttl {
                cache.remove(&key);

                let mut stats = self.stats.write().await;
                stats.ttl_evictions += 1;
                stats.misses += 1;
                return None;
            }

            // Fresh hit: update LRU bookkeeping.
            entry.hit_count += 1;
            entry.last_accessed = Instant::now();

            // Credit the tokens (and estimated cost) this hit avoided.
            let mut stats = self.stats.write().await;
            stats.hits += 1;
            stats.tokens_saved +=
                u64::from(entry.response.prompt_tokens + entry.response.completion_tokens);
            stats.estimated_cost_saved_cents +=
                f64::from(entry.response.prompt_tokens + entry.response.completion_tokens) / 1000.0;

            return Some(entry.response.clone());
        }

        let mut stats = self.stats.write().await;
        stats.misses += 1;

        None
    }

    /// Stores `response` for `request`, evicting the least recently used
    /// entry first if the cache is at capacity.
    pub async fn put(&self, request: &ChatRequest, response: &ChatResponse) {
        if !self.is_cacheable(request) {
            return;
        }

        let key = CacheKey::from_request(request);
        let entry = CacheEntry {
            response: CachedResponse::from(response),
            created_at: Instant::now(),
            hit_count: 0,
            last_accessed: Instant::now(),
        };

        let mut cache = self.cache.write().await;

        // Make room before inserting if we are at capacity.
        if cache.len() >= self.config.max_entries {
            self.evict_lru(&mut cache).await;
        }

        cache.insert(key, entry);

        let mut stats = self.stats.write().await;
        stats.entries = cache.len();
    }

    /// A request is cacheable only if it is (near-)deterministic and its
    /// prompt is small enough to be worth storing.
    fn is_cacheable(&self, request: &ChatRequest) -> bool {
        // High-temperature requests are non-deterministic; caching one sample
        // would pin callers to a single response.
        if request.temperature.unwrap_or(1.0) > 0.5 {
            return false;
        }

        // Rough token estimate: about four characters per token.
        let total_content_len: usize = request.messages.iter().map(|m| m.content.len()).sum();
        let estimated_tokens = total_content_len / 4;
        if estimated_tokens > self.config.max_prompt_tokens {
            return false;
        }

        true
    }

    /// Removes the least recently accessed entry.
    async fn evict_lru(&self, cache: &mut HashMap<CacheKey, CacheEntry>) {
        if let Some(key_to_remove) = cache
            .iter()
            .min_by_key(|(_, entry)| entry.last_accessed)
            .map(|(k, _)| k.clone())
        {
            cache.remove(&key_to_remove);

            let mut stats = self.stats.write().await;
            stats.capacity_evictions += 1;
        }
    }

    /// Removes all entries that have outlived the TTL.
    pub async fn clear_expired(&self) {
        let mut cache = self.cache.write().await;
        let now = Instant::now();
        let ttl = self.config.ttl;

        let expired_keys: Vec<CacheKey> = cache
            .iter()
            .filter(|(_, entry)| now.duration_since(entry.created_at) > ttl)
            .map(|(k, _)| k.clone())
            .collect();

        let expired_count = expired_keys.len();

        for key in expired_keys {
            cache.remove(&key);
        }

        if expired_count > 0 {
            let mut stats = self.stats.write().await;
            stats.ttl_evictions += expired_count as u64;
            stats.entries = cache.len();

            tracing::debug!(expired = expired_count, "Cleared expired cache entries");
        }
    }

    /// Removes every entry from the cache.
    pub async fn clear(&self) {
        let mut cache = self.cache.write().await;
        cache.clear();

        let mut stats = self.stats.write().await;
        stats.entries = 0;
    }

    /// Returns a snapshot of the current statistics.
    pub async fn get_stats(&self) -> CacheStats {
        let stats = self.stats.read().await;
        stats.clone()
    }

    /// Returns a point-in-time summary of the cache's contents and stats.
    pub async fn get_cache_info(&self) -> CacheInfo {
        let cache = self.cache.read().await;
        let stats = self.stats.read().await;

        let total_hit_count: u64 = cache.values().map(|e| e.hit_count).sum();
        let avg_hit_count = if cache.is_empty() {
            0.0
        } else {
            total_hit_count as f64 / cache.len() as f64
        };

        let oldest_entry = cache.values().map(|e| e.created_at).min();
        let newest_entry = cache.values().map(|e| e.created_at).max();

        CacheInfo {
            config: self.config.clone(),
            stats: stats.clone(),
            entry_count: cache.len(),
            total_hit_count,
            avg_hit_count_per_entry: avg_hit_count,
            oldest_entry_age_secs: oldest_entry.map(|t| t.elapsed().as_secs()),
            newest_entry_age_secs: newest_entry.map(|t| t.elapsed().as_secs()),
        }
    }
}

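// A minimal sketch, not part of the original API: drive `clear_expired` on a
// fixed interval from a background task so TTL eviction happens even when the
// cache sees no traffic. Assumes tokio's `rt` and `time` features are enabled.
pub fn spawn_ttl_sweeper(cache: Arc<LlmCache>, every: Duration) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(every);
        loop {
            ticker.tick().await;
            cache.clear_expired().await;
        }
    })
}
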
/// A point-in-time snapshot combining configuration, stats, and entry metadata.
#[derive(Debug, Clone, Serialize)]
pub struct CacheInfo {
    /// The active configuration (not serialized).
    #[serde(skip)]
    pub config: LlmCacheConfig,
    /// Hit/miss and eviction counters.
    pub stats: CacheStats,
    /// Number of entries currently cached.
    pub entry_count: usize,
    /// Sum of hit counts across all entries.
    pub total_hit_count: u64,
    /// Average number of hits per entry.
    pub avg_hit_count_per_entry: f64,
    /// Age of the oldest entry, if any.
    pub oldest_entry_age_secs: Option<u64>,
    /// Age of the newest entry, if any.
    pub newest_entry_age_secs: Option<u64>,
}

/// Wraps an LLM provider with a response cache.
pub struct CachedLlmClient<P> {
    provider: P,
    cache: LlmCache,
}

impl<P> CachedLlmClient<P> {
    /// Wraps `provider` with a cache using the given configuration.
    pub fn new(provider: P, cache_config: LlmCacheConfig) -> Self {
        Self {
            provider,
            cache: LlmCache::new(cache_config),
        }
    }

    /// Wraps `provider` with a default-configured cache.
    pub fn with_default_cache(provider: P) -> Self {
        Self {
            provider,
            cache: LlmCache::default_cache(),
        }
    }

    /// Returns a reference to the wrapped provider.
    pub fn provider(&self) -> &P {
        &self.provider
    }

    /// Returns a snapshot of the cache statistics.
    pub async fn cache_stats(&self) -> CacheStats {
        self.cache.get_stats().await
    }

    /// Returns a summary of the cache's contents and stats.
    pub async fn cache_info(&self) -> CacheInfo {
        self.cache.get_cache_info().await
    }

    /// Empties the cache.
    pub async fn clear_cache(&self) {
        self.cache.clear().await;
    }

    /// Removes expired entries from the cache.
    pub async fn clear_expired(&self) {
        self.cache.clear_expired().await;
    }
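
    /// A read-through helper sketched here for illustration; it is an assumed
    /// extension, not part of the original API. The provider call is supplied
    /// as a closure so no particular provider trait is required: cache hits
    /// short-circuit, misses invoke `call` and populate the cache.
    pub async fn chat_with<F, Fut, E>(
        &self,
        request: &ChatRequest,
        call: F,
    ) -> Result<CachedResponse, E>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<ChatResponse, E>>,
    {
        // Serve from cache when possible.
        if let Some(hit) = self.cache.get(request).await {
            return Ok(hit);
        }

        // Miss: ask the provider, then cache the fresh response.
        let response = call().await?;
        self.cache.put(request, &response).await;
        Ok(CachedResponse::from(&response))
    }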
}

/// Deduplicates identical in-flight requests: the first caller registers and
/// executes the request, while duplicates wait on a watch channel for the
/// shared result.
pub struct RequestDeduplicator {
    /// Maps a request hash to a receiver for the eventual response.
    in_flight: Arc<RwLock<HashMap<u64, tokio::sync::watch::Receiver<Option<CachedResponse>>>>>,
}

impl RequestDeduplicator {
    /// Creates an empty deduplicator.
    #[must_use]
    pub fn new() -> Self {
        Self {
            in_flight: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Collapses a request into a single deduplication hash via `CacheKey`.
    fn hash_request(request: &ChatRequest) -> u64 {
        let key = CacheKey::from_request(request);
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Returns true if an identical request is currently in flight.
    pub async fn is_in_flight(&self, request: &ChatRequest) -> bool {
        let hash = Self::hash_request(request);
        let in_flight = self.in_flight.read().await;
        in_flight.contains_key(&hash)
    }

    /// Registers `request` as in flight. The caller executes the request and
    /// publishes the result on the returned sender so duplicate callers
    /// waiting in `wait_for` can observe it.
    pub async fn register(
        &self,
        request: &ChatRequest,
    ) -> tokio::sync::watch::Sender<Option<CachedResponse>> {
        let hash = Self::hash_request(request);
        let (tx, rx) = tokio::sync::watch::channel(None);

        let mut in_flight = self.in_flight.write().await;
        in_flight.insert(hash, rx);

        tx
    }

    /// Waits for an identical in-flight request to finish and returns its
    /// response, or `None` if no such request is registered.
    pub async fn wait_for(&self, request: &ChatRequest) -> Option<CachedResponse> {
        let hash = Self::hash_request(request);

        // Clone the receiver so the lock is not held while waiting.
        let rx = {
            let in_flight = self.in_flight.read().await;
            in_flight.get(&hash).cloned()
        };

        if let Some(mut rx) = rx {
            // Resolves when the registrant sends a value (or drops the sender).
            let _ = rx.changed().await;
            rx.borrow().clone()
        } else {
            None
        }
    }

    /// Removes the in-flight marker for `request`. The response itself is
    /// delivered via the sender returned by `register`; the value passed here
    /// is unused and simply dropped.
    pub async fn complete(&self, request: &ChatRequest, response: Option<CachedResponse>) {
        let hash = Self::hash_request(request);

        let mut in_flight = self.in_flight.write().await;
        in_flight.remove(&hash);

        drop(response);
    }
}

impl Default for RequestDeduplicator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ChatMessage, ChatRole};

    #[test]
    fn test_cache_key_generation() {
        let request = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Hello".to_string(),
            }],
            temperature: Some(0.0),
            max_tokens: None,
            stop: None,
            images: None,
        };

        let key1 = CacheKey::from_request(&request);
        let key2 = CacheKey::from_request(&request);

        assert_eq!(key1, key2);
    }

    #[test]
    fn test_different_messages_different_keys() {
        let request1 = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Hello".to_string(),
            }],
            temperature: Some(0.0),
            max_tokens: None,
            stop: None,
            images: None,
        };

        let request2 = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Goodbye".to_string(),
            }],
            temperature: Some(0.0),
            max_tokens: None,
            stop: None,
            images: None,
        };

        let key1 = CacheKey::from_request(&request1);
        let key2 = CacheKey::from_request(&request2);

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_cache_config_defaults() {
        let config = LlmCacheConfig::default();
        assert_eq!(config.max_entries, 1000);
        assert_eq!(config.ttl, Duration::from_secs(3600));
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = LlmCache::default_cache();

        let request = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Test".to_string(),
            }],
            temperature: Some(0.0),
            max_tokens: None,
            stop: None,
            images: None,
        };

        let result = cache.get(&request).await;
        assert!(result.is_none());

        let stats = cache.get_stats().await;
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
    }

    #[test]
    fn test_not_cacheable_high_temperature() {
        let cache = LlmCache::default_cache();

        let request = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Test".to_string(),
            }],
            temperature: Some(0.9),
            max_tokens: None,
            stop: None,
            images: None,
        };

        assert!(!cache.is_cacheable(&request));
    }
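
    // The tests below are added sketches against the behavior defined above;
    // values and names are illustrative.

    // `hit_rate` is 0.0 with no traffic, hits / (hits + misses) otherwise.
    #[test]
    fn test_hit_rate() {
        let empty = CacheStats::default();
        assert!(empty.hit_rate().abs() < f64::EPSILON);

        let stats = CacheStats {
            hits: 3,
            misses: 1,
            ..CacheStats::default()
        };
        assert!((stats.hit_rate() - 0.75).abs() < f64::EPSILON);
    }

    // The default config caps prompts at 2000 estimated tokens at roughly
    // four characters per token, so ~9000 characters should be rejected.
    #[test]
    fn test_not_cacheable_long_prompt() {
        let cache = LlmCache::default_cache();

        let request = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "x".repeat(9000),
            }],
            temperature: Some(0.0),
            max_tokens: None,
            stop: None,
            images: None,
        };

        assert!(!cache.is_cacheable(&request));
    }

    // Round-trips the deduplicator: register, publish through the returned
    // sender, observe via `wait_for`, then clear the marker with `complete`.
    #[tokio::test]
    async fn test_deduplicator_round_trip() {
        let dedup = RequestDeduplicator::new();
        let request = ChatRequest {
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Test".to_string(),
            }],
            temperature: Some(0.0),
            max_tokens: None,
            stop: None,
            images: None,
        };

        assert!(!dedup.is_in_flight(&request).await);
        let tx = dedup.register(&request).await;
        assert!(dedup.is_in_flight(&request).await);

        let response = CachedResponse {
            content: "cached".to_string(),
            finish_reason: None,
            prompt_tokens: 1,
            completion_tokens: 1,
        };
        let _ = tx.send(Some(response));

        let waited = dedup.wait_for(&request).await;
        assert_eq!(waited.map(|r| r.content), Some("cached".to_string()));

        dedup.complete(&request, None).await;
        assert!(!dedup.is_in_flight(&request).await);
    }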
}