1use std::collections::HashMap;
26use std::hash::{DefaultHasher, Hash, Hasher};
27use std::pin::Pin;
28use std::sync::{Arc, Mutex};
29use std::task::{Context, Poll};
30use std::time::{Duration, Instant};
31
32use async_trait::async_trait;
33use futures::Stream;
34use serde_json::{Map, Value};
35
36use crate::error::{ProviderError, Result};
37#[cfg(test)]
38use crate::language_model::TextPart;
39use crate::language_model::{
40 BoxStream, CallOptions, GenerateResult, LanguageModel, StreamPart, StreamResult,
41};
42
43use super::language_model::LanguageModelMiddleware;
44
/// Storage backend consulted by [`CacheMiddleware`] for cached model output.
///
/// Implementations must be `Send + Sync + Debug`; the middleware holds the
/// store behind an `Arc<dyn CacheStore>`.
pub trait CacheStore: Send + Sync + std::fmt::Debug {
    /// Returns the entry stored under `key`, or `None` when nothing is available.
    fn get(&self, key: &str) -> Option<CachedEntry>;

    /// Stores `value` under `key`.
    fn put(&self, key: String, value: CachedEntry);
}
55
#[derive(Debug, Clone)]
/// A cached model response, in one of the two shapes the middleware stores.
pub enum CachedEntry {
    /// Result of a non-streaming generate call (boxed).
    Generate(Box<GenerateResult>),
    /// Parts of a stream, in the order they were captured.
    Stream(Vec<StreamPart>),
}
68
#[derive(Debug, Default)]
/// In-memory [`CacheStore`] whose whole state lives behind a single `Mutex`.
///
/// The derived `Default` (also used by [`MemoryCacheStore::new`]) yields a store
/// with no entry cap and no age limit; [`MemoryCacheStore::builder`] configures
/// either limit.
pub struct MemoryCacheStore {
    /// All mutable cache state; every operation locks this mutex.
    inner: Mutex<MemoryCacheState>,
}
84
#[derive(Debug, Default)]
/// Mutable state guarded by the mutex inside `MemoryCacheStore`.
struct MemoryCacheState {
    /// Cached values by key, each with its own bookkeeping.
    entries: HashMap<String, CacheEntry>,
    /// Logical clock; `touch` increments it (saturating) to order accesses.
    tick: u64,
    /// Entry count enforced after each `put`; `None` means no cap.
    max_entries: Option<usize>,
    /// Age beyond which an entry is treated as expired; `None` disables expiry.
    max_age: Option<Duration>,
}
95
#[derive(Debug, Clone)]
/// A stored value plus the bookkeeping used for expiry and LRU eviction.
struct CacheEntry {
    /// The cached payload; `get` hands out clones of it.
    value: CachedEntry,
    /// When the entry was inserted; compared against `max_age`.
    inserted_at: Instant,
    /// Store tick at the most recent `put` or successful `get` of this entry.
    last_access: u64,
}
102
103impl MemoryCacheStore {
104 #[must_use]
106 pub fn new() -> Self {
107 Self::default()
108 }
109
110 #[must_use]
112 pub fn builder() -> MemoryCacheStoreBuilder {
113 MemoryCacheStoreBuilder::default()
114 }
115
116 #[must_use]
125 pub fn len(&self) -> usize {
126 self.inner
127 .lock()
128 .expect("cache mutex poisoned")
129 .entries
130 .len()
131 }
132
133 #[must_use]
139 pub fn is_empty(&self) -> bool {
140 self.inner
141 .lock()
142 .expect("cache mutex poisoned")
143 .entries
144 .is_empty()
145 }
146}
147
#[derive(Debug, Default, Clone, Copy)]
/// Builder for [`MemoryCacheStore`]; both limits start out unset.
pub struct MemoryCacheStoreBuilder {
    /// Entry cap to apply, if any.
    max_entries: Option<usize>,
    /// Maximum entry age to apply, if any.
    max_age: Option<Duration>,
}
154
155impl MemoryCacheStoreBuilder {
156 #[must_use]
158 pub fn max_entries(mut self, n: usize) -> Self {
159 self.max_entries = Some(n);
160 self
161 }
162
163 #[must_use]
166 pub fn max_age(mut self, age: Duration) -> Self {
167 self.max_age = Some(age);
168 self
169 }
170
171 #[must_use]
173 pub fn build(self) -> MemoryCacheStore {
174 MemoryCacheStore {
175 inner: Mutex::new(MemoryCacheState {
176 entries: HashMap::new(),
177 tick: 0,
178 max_entries: self.max_entries,
179 max_age: self.max_age,
180 }),
181 }
182 }
183}
184
185impl MemoryCacheState {
186 fn touch(&mut self) -> u64 {
187 self.tick = self.tick.saturating_add(1);
188 self.tick
189 }
190
191 fn evict_one_lru(&mut self) {
193 let victim = self
194 .entries
195 .iter()
196 .min_by_key(|(_, e)| e.last_access)
197 .map(|(k, _)| k.clone());
198 if let Some(k) = victim {
199 self.entries.remove(&k);
200 }
201 }
202
203 fn prune_expired(&mut self) {
205 let Some(age) = self.max_age else {
206 return;
207 };
208 let now = Instant::now();
209 self.entries
210 .retain(|_, e| now.duration_since(e.inserted_at) <= age);
211 }
212}
213
214impl CacheStore for MemoryCacheStore {
215 fn get(&self, key: &str) -> Option<CachedEntry> {
216 let mut guard = self.inner.lock().expect("cache mutex poisoned");
217 if let Some(age) = guard.max_age
219 && let Some(entry) = guard.entries.get(key)
220 && Instant::now().duration_since(entry.inserted_at) > age
221 {
222 guard.entries.remove(key);
223 return None;
224 }
225 let tick = guard.touch();
226 let entry = guard.entries.get_mut(key)?;
227 entry.last_access = tick;
228 Some(entry.value.clone())
229 }
230
231 fn put(&self, key: String, value: CachedEntry) {
232 let mut guard = self.inner.lock().expect("cache mutex poisoned");
233 guard.prune_expired();
234 let tick = guard.touch();
235 let new_entry = CacheEntry {
236 value,
237 inserted_at: Instant::now(),
238 last_access: tick,
239 };
240 guard.entries.insert(key, new_entry);
241 if let Some(cap) = guard.max_entries {
242 while guard.entries.len() > cap {
243 guard.evict_one_lru();
244 }
245 }
246 }
247}
248
#[derive(Debug, Clone)]
/// Middleware that serves generate and stream calls from a [`CacheStore`].
///
/// Cloning is cheap: clones share the same store through the `Arc`.
pub struct CacheMiddleware {
    /// Backend queried before, and filled after, calling the wrapped model.
    store: Arc<dyn CacheStore>,
}
259
impl CacheMiddleware {
    #[must_use]
    /// Creates a middleware that reads from and writes to `store`.
    pub fn new(store: Arc<dyn CacheStore>) -> Self {
        Self { store }
    }
}
267
268fn key_for(options: &CallOptions) -> Result<String> {
274 let bytes = serde_json::to_vec(options)
275 .map_err(|e| ProviderError::type_validation("call_options", Value::Null, e.to_string()))?;
276 let mut hasher = DefaultHasher::new();
277 bytes.hash(&mut hasher);
278 Ok(format!("{:016x}", hasher.finish()))
279}
280
281fn mark_generate_hit(result: &mut GenerateResult) {
284 let entry = result.provider_metadata.get_or_insert_with(HashMap::new);
285 let bucket = entry.entry("llmsdk".to_owned()).or_default();
286 bucket.insert("cache".to_owned(), Value::String("hit".to_owned()));
287}
288
289fn hit_metadata() -> crate::shared::ProviderMetadata {
292 let mut map: crate::shared::ProviderMetadata = HashMap::new();
293 let mut bucket = Map::new();
294 bucket.insert("cache".to_owned(), Value::String("hit".to_owned()));
295 map.insert("llmsdk".to_owned(), bucket);
296 map
297}
298
299fn annotate_stream_hit(parts: &mut Vec<StreamPart>) {
305 for part in parts.iter_mut() {
306 if matches!(part, StreamPart::StreamStart { .. }) {
307 continue;
308 }
309 if inject_metadata(part, &hit_metadata()) {
310 return;
311 }
312 }
313 parts.insert(
314 0,
315 StreamPart::Custom {
316 kind: "llmsdk.cache.hit".to_owned(),
317 provider_metadata: Some(hit_metadata()),
318 },
319 );
320}
321
322fn inject_metadata(part: &mut StreamPart, mark: &crate::shared::ProviderMetadata) -> bool {
325 let (StreamPart::TextStart {
326 provider_metadata: slot,
327 ..
328 }
329 | StreamPart::TextDelta {
330 provider_metadata: slot,
331 ..
332 }
333 | StreamPart::TextEnd {
334 provider_metadata: slot,
335 ..
336 }
337 | StreamPart::ReasoningStart {
338 provider_metadata: slot,
339 ..
340 }
341 | StreamPart::ReasoningDelta {
342 provider_metadata: slot,
343 ..
344 }
345 | StreamPart::ReasoningEnd {
346 provider_metadata: slot,
347 ..
348 }
349 | StreamPart::ToolInputStart {
350 provider_metadata: slot,
351 ..
352 }
353 | StreamPart::ToolInputDelta {
354 provider_metadata: slot,
355 ..
356 }
357 | StreamPart::ToolInputEnd {
358 provider_metadata: slot,
359 ..
360 }
361 | StreamPart::Custom {
362 provider_metadata: slot,
363 ..
364 }
365 | StreamPart::Finish {
366 provider_metadata: slot,
367 ..
368 }) = part
369 else {
370 return false;
371 };
372 let target = slot.get_or_insert_with(HashMap::new);
373 for (provider, bucket) in mark {
374 let dest = target.entry(provider.clone()).or_default();
375 for (k, v) in bucket {
376 dest.insert(k.clone(), v.clone());
377 }
378 }
379 true
380}
381
382#[async_trait]
383impl LanguageModelMiddleware for CacheMiddleware {
384 async fn wrap_generate(
385 &self,
386 next: &dyn LanguageModel,
387 params: CallOptions,
388 ) -> Result<GenerateResult> {
389 let key = key_for(¶ms)?;
390 if let Some(CachedEntry::Generate(mut hit)) = self.store.get(&key) {
391 mark_generate_hit(&mut hit);
392 return Ok(*hit);
393 }
394 let result = next.do_generate(params).await?;
395 self.store
396 .put(key, CachedEntry::Generate(Box::new(result.clone())));
397 Ok(result)
398 }
399
400 async fn wrap_stream(
401 &self,
402 next: &dyn LanguageModel,
403 params: CallOptions,
404 ) -> Result<StreamResult> {
405 let key = key_for(¶ms)?;
406 if let Some(CachedEntry::Stream(mut parts)) = self.store.get(&key) {
407 annotate_stream_hit(&mut parts);
408 let stream = futures::stream::iter(parts.into_iter().map(Ok));
409 return Ok(StreamResult {
410 stream: Box::pin(stream),
411 request: None,
412 response: None,
413 });
414 }
415 let StreamResult {
416 stream,
417 request,
418 response,
419 } = next.do_stream(params).await?;
420 let capturing = CapturingStream::new(stream, Arc::clone(&self.store), key);
421 Ok(StreamResult {
422 stream: Box::pin(capturing),
423 request,
424 response,
425 })
426 }
427}
428
/// Stream adapter that forwards the inner stream unchanged while recording its
/// parts, so they can be written to the cache once the stream completes.
struct CapturingStream {
    /// The wrapped model stream.
    inner: BoxStream<Result<StreamPart>>,
    /// Where the captured parts are stored on completion.
    store: Arc<dyn CacheStore>,
    /// Cache key for the capture; taken (set to `None`) when the capture is stored.
    key: Option<String>,
    /// Clones of the `Ok` parts seen so far.
    captured: Vec<StreamPart>,
    /// Set once the inner stream yields an error; a poisoned capture is never stored.
    poisoned: bool,
}
438
impl CapturingStream {
    /// Wraps `inner`, starting with an empty capture that will be stored in
    /// `store` under `key`.
    fn new(inner: BoxStream<Result<StreamPart>>, store: Arc<dyn CacheStore>, key: String) -> Self {
        Self {
            inner,
            store,
            key: Some(key),
            captured: Vec::new(),
            poisoned: false,
        }
    }
}
450
451impl Stream for CapturingStream {
452 type Item = Result<StreamPart>;
453
454 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
455 let polled = self.inner.as_mut().poll_next(cx);
456 match &polled {
457 Poll::Ready(Some(Ok(part))) => {
458 self.captured.push(part.clone());
459 }
460 Poll::Ready(Some(Err(_))) => {
461 self.poisoned = true;
462 }
463 Poll::Ready(None) => {
464 if !self.poisoned
465 && let Some(key) = self.key.take()
466 {
467 let captured = std::mem::take(&mut self.captured);
468 self.store.put(key, CachedEntry::Stream(captured));
469 }
470 }
471 Poll::Pending => {}
472 }
473 polled
474 }
475}
476
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use futures::StreamExt;

    use crate::language_model::{Content, FinishReason, FinishReasonKind, Usage};

    use super::*;

    /// Model double that always succeeds and counts calls per entry point.
    #[derive(Debug)]
    struct CountingModel {
        provider: String,
        model_id: String,
        generate_calls: AtomicUsize,
        stream_calls: AtomicUsize,
    }

    impl CountingModel {
        fn new() -> Self {
            Self {
                provider: String::from("test"),
                model_id: String::from("counter"),
                generate_calls: AtomicUsize::new(0),
                stream_calls: AtomicUsize::new(0),
            }
        }
    }

    /// A successful generate result containing a single text part.
    fn ok_generate(text: &str) -> GenerateResult {
        let body = Content::Text(TextPart {
            text: text.to_owned(),
            provider_options: None,
        });
        GenerateResult {
            content: vec![body],
            finish_reason: FinishReason::new(FinishReasonKind::Stop),
            usage: Usage::default(),
            provider_metadata: None,
            request: None,
            response: None,
            warnings: Vec::new(),
        }
    }

    /// A fresh in-memory store together with a middleware backed by it.
    fn store_and_middleware() -> (Arc<MemoryCacheStore>, CacheMiddleware) {
        let store = Arc::new(MemoryCacheStore::new());
        let middleware = CacheMiddleware::new(Arc::clone(&store) as Arc<dyn CacheStore>);
        (store, middleware)
    }

    /// Drains a stream result, keeping only the `Ok` parts.
    async fn collect_ok(result: StreamResult) -> Vec<StreamPart> {
        result
            .stream
            .filter_map(|item| async move { item.ok() })
            .collect()
            .await
    }

    /// Whether a text or finish part carries the `llmsdk.cache = "hit"` marker.
    fn carries_hit_marker(part: &StreamPart) -> bool {
        let (StreamPart::TextStart {
            provider_metadata, ..
        }
        | StreamPart::TextDelta {
            provider_metadata, ..
        }
        | StreamPart::TextEnd {
            provider_metadata, ..
        }
        | StreamPart::Finish {
            provider_metadata, ..
        }) = part
        else {
            return false;
        };
        provider_metadata
            .as_ref()
            .and_then(|m| m.get("llmsdk"))
            .and_then(|b| b.get("cache"))
            == Some(&Value::String("hit".to_owned()))
    }

    #[async_trait]
    impl LanguageModel for CountingModel {
        fn provider(&self) -> &str {
            &self.provider
        }

        fn model_id(&self) -> &str {
            &self.model_id
        }

        async fn do_generate(&self, _opts: CallOptions) -> Result<GenerateResult> {
            self.generate_calls.fetch_add(1, Ordering::SeqCst);
            Ok(ok_generate("hello"))
        }

        async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResult> {
            self.stream_calls.fetch_add(1, Ordering::SeqCst);
            let parts = vec![
                StreamPart::StreamStart {
                    warnings: Vec::new(),
                },
                StreamPart::TextStart {
                    id: "0".to_owned(),
                    provider_metadata: None,
                },
                StreamPart::TextDelta {
                    id: "0".to_owned(),
                    delta: "hi".to_owned(),
                    provider_metadata: None,
                },
                StreamPart::TextEnd {
                    id: "0".to_owned(),
                    provider_metadata: None,
                },
                StreamPart::Finish {
                    usage: Usage::default(),
                    finish_reason: FinishReason::new(FinishReasonKind::Stop),
                    provider_metadata: None,
                },
            ];
            let stream = futures::stream::iter(parts.into_iter().map(Ok));
            Ok(StreamResult {
                stream: Box::pin(stream),
                request: None,
                response: None,
            })
        }
    }

    /// Model double whose stream yields one part and then an error.
    #[derive(Debug)]
    struct FailingStreamModel {
        provider: String,
        model_id: String,
    }

    impl Default for FailingStreamModel {
        fn default() -> Self {
            Self {
                provider: String::from("test"),
                model_id: String::from("fail-stream"),
            }
        }
    }

    #[async_trait]
    impl LanguageModel for FailingStreamModel {
        fn provider(&self) -> &str {
            &self.provider
        }

        fn model_id(&self) -> &str {
            &self.model_id
        }

        async fn do_generate(&self, _opts: CallOptions) -> Result<GenerateResult> {
            Ok(ok_generate(""))
        }

        async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResult> {
            let parts: Vec<Result<StreamPart>> = vec![
                Ok(StreamPart::StreamStart {
                    warnings: Vec::new(),
                }),
                Err(ProviderError::empty_response_body()),
            ];
            Ok(StreamResult {
                stream: Box::pin(futures::stream::iter(parts)),
                request: None,
                response: None,
            })
        }
    }

    #[tokio::test]
    async fn generate_second_call_hits_cache() {
        let (store, mw) = store_and_middleware();
        let model = CountingModel::new();

        let first = mw
            .wrap_generate(&model, CallOptions::default())
            .await
            .expect("first call");
        assert!(first.provider_metadata.is_none(), "miss is not annotated");

        let second = mw
            .wrap_generate(&model, CallOptions::default())
            .await
            .expect("second call");
        assert_eq!(model.generate_calls.load(Ordering::SeqCst), 1);

        let llmsdk = second
            .provider_metadata
            .as_ref()
            .and_then(|m| m.get("llmsdk"))
            .expect("hit metadata present");
        assert_eq!(llmsdk.get("cache"), Some(&Value::String("hit".to_owned())));
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn stream_second_call_replays_cached_parts() {
        let (store, mw) = store_and_middleware();
        let model = CountingModel::new();

        // First call: goes to the model and is stored once fully drained.
        let first = mw
            .wrap_stream(&model, CallOptions::default())
            .await
            .expect("first stream");
        let first_parts = collect_ok(first).await;
        assert_eq!(first_parts.len(), 5);
        assert_eq!(model.stream_calls.load(Ordering::SeqCst), 1);
        assert_eq!(store.len(), 1, "stream committed after Ok completion");

        // Second call: replayed from the cache without touching the model.
        let second = mw
            .wrap_stream(&model, CallOptions::default())
            .await
            .expect("second stream");
        let second_parts = collect_ok(second).await;
        assert_eq!(
            model.stream_calls.load(Ordering::SeqCst),
            1,
            "no second call"
        );
        assert_eq!(second_parts.len(), first_parts.len());

        let any_hit = second_parts.iter().any(carries_hit_marker);
        assert!(any_hit, "cache hit marker must be visible on replay");
    }

    #[tokio::test]
    async fn stream_does_not_cache_when_inner_errors() {
        let (store, mw) = store_and_middleware();
        let model = FailingStreamModel::default();

        let result = mw
            .wrap_stream(&model, CallOptions::default())
            .await
            .expect("open succeeds");
        let parts: Vec<Result<StreamPart>> = result.stream.collect().await;
        assert_eq!(parts.len(), 2, "one Ok + one Err drained");
        assert!(parts[1].is_err());
        assert!(store.is_empty(), "must not cache a poisoned stream");
    }

    #[tokio::test]
    async fn generate_failure_is_not_cached() {
        /// Model double whose generate call always fails.
        #[derive(Debug)]
        struct AlwaysFail {
            provider: String,
            model_id: String,
        }

        #[async_trait]
        impl LanguageModel for AlwaysFail {
            fn provider(&self) -> &str {
                &self.provider
            }

            fn model_id(&self) -> &str {
                &self.model_id
            }

            async fn do_generate(&self, _opts: CallOptions) -> Result<GenerateResult> {
                Err(ProviderError::empty_response_body())
            }

            async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResult> {
                unreachable!()
            }
        }

        let model = AlwaysFail {
            provider: String::from("test"),
            model_id: String::from("fail"),
        };
        let (store, mw) = store_and_middleware();
        let _ = mw.wrap_generate(&model, CallOptions::default()).await;
        assert!(store.is_empty());
    }

    #[test]
    fn key_is_stable_for_equal_options() {
        let left = CallOptions::default();
        let right = CallOptions::default();
        assert_eq!(key_for(&left).unwrap(), key_for(&right).unwrap());
    }

    #[test]
    fn key_differs_when_temperature_changes() {
        let cool = CallOptions {
            temperature: Some(0.1),
            ..CallOptions::default()
        };
        let warm = CallOptions {
            temperature: Some(0.9),
            ..CallOptions::default()
        };
        assert_ne!(key_for(&cool).unwrap(), key_for(&warm).unwrap());
    }

    /// An arbitrary cache payload for the store-level tests.
    fn dummy_entry() -> CachedEntry {
        CachedEntry::Generate(Box::new(ok_generate("hello")))
    }

    #[test]
    fn lru_evicts_oldest_entry_over_capacity() {
        let store = MemoryCacheStore::builder().max_entries(2).build();
        store.put("a".into(), dummy_entry());
        store.put("b".into(), dummy_entry());
        // Reading "a" makes "b" the least recently used entry.
        let _ = store.get("a");
        store.put("c".into(), dummy_entry());

        assert!(store.get("a").is_some(), "a still present after touch");
        assert!(store.get("b").is_none(), "b evicted as LRU");
        assert!(store.get("c").is_some(), "c just inserted");
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn ttl_expires_entries_on_get() {
        let store = MemoryCacheStore::builder()
            .max_age(Duration::from_millis(10))
            .build();
        store.put("a".into(), dummy_entry());
        std::thread::sleep(Duration::from_millis(20));
        assert!(store.get("a").is_none(), "expired entry pruned");
        assert_eq!(store.len(), 0);
    }
}