#![allow(missing_docs)]
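//! An LRU-caching decorator for [`Llm`] implementations.
//!
//! [`CachedLlm`] wraps any [`Llm`] and memoizes successful generations, keyed
//! by a hash of `(prompt, context, feedback)`. Repeated identical requests
//! are served from the cache without touching the inner model.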
use crate::error::Result;
use crate::recursive::llm::{Llm, LmOutput};
use lru::LruCache;
use std::collections::hash_map::DefaultHasher;
use std::future::Future;
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::pin::Pin;
use std::sync::Mutex;
use std::task::{ready, Context, Poll};
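/// A caching decorator that stores successful [`Llm`] outputs in a bounded
/// LRU cache. The cache sits behind a [`Mutex`], so the wrapper remains
/// usable through a shared reference from multiple tasks.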
pub struct CachedLlm<L: Llm> {
inner: L,
cache: Mutex<LruCache<u64, LmOutput>>,
}
impl<L: Llm> CachedLlm<L> {
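    /// Wraps `inner` with an LRU cache holding at most `capacity` entries.
    /// A `capacity` of zero is clamped to one.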
pub fn new(inner: L, capacity: usize) -> Self {
Self {
inner,
cache: Mutex::new(LruCache::new(NonZeroUsize::new(capacity.max(1)).unwrap())),
}
}
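    /// Returns the number of entries currently cached.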
pub fn cache_len(&self) -> usize {
self.cache.lock().expect("cache lock poisoned").len()
}
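    /// Drops every cached entry, forcing subsequent calls through to `inner`.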
pub fn clear_cache(&self) {
self.cache.lock().expect("cache lock poisoned").clear();
}
}
pin_project_lite::pin_project! {
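    /// Future returned by [`CachedLlm::generate`]: either an immediate cache
    /// hit, or the inner model's future whose successful output is written
    /// back to the cache on completion.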
#[project = CachedFutProj]
pub enum CachedFut<'a, L>
where
L: Llm,
L: 'a,
{
Hit { result: Option<Result<LmOutput>> },
Miss {
#[pin]
inner: L::GenerateFut<'a>,
cache: &'a Mutex<LruCache<u64, LmOutput>>,
key: u64,
},
}
}
impl<'a, L: Llm + 'a> Future for CachedFut<'a, L> {
type Output = Result<LmOutput>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project() {
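            // A hit resolves on the first poll; the `Option` is taken so a
            // second poll panics loudly instead of returning stale data.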
CachedFutProj::Hit { result } => Poll::Ready(
result
.take()
.expect("CachedFut::Hit polled after completion"),
),
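            // A miss drives the inner future; a successful output is cloned
            // into the cache before being handed to the caller.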
CachedFutProj::Miss { inner, cache, key } => {
let output = ready!(inner.poll(cx));
if let Ok(ref lm_out) = output {
cache
.lock()
.expect("cache lock poisoned")
.put(*key, lm_out.clone());
}
Poll::Ready(output)
}
}
}
}
impl<L: Llm> Llm for CachedLlm<L> {
type GenerateFut<'a>
= CachedFut<'a, L>
where
Self: 'a;
fn generate<'a>(
&'a self,
prompt: &'a str,
context: &'a str,
feedback: Option<&'a str>,
) -> Self::GenerateFut<'a> {
let key = cache_key(prompt, context, feedback);
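        // Fast path: an identical request has already succeeded, so hand
        // back a clone of the stored output without calling the inner model.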
if let Some(cached) = self.cache.lock().expect("cache lock poisoned").get(&key) {
return CachedFut::Hit {
result: Some(Ok(cached.clone())),
};
}
CachedFut::Miss {
inner: self.inner.generate(prompt, context, feedback),
cache: &self.cache,
key,
}
}
fn model_name(&self) -> &str {
self.inner.model_name()
}
fn max_context(&self) -> usize {
self.inner.max_context()
}
}
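/// Blanket extension trait that adds [`CacheExt::with_cache`] to every
/// [`Llm`].
///
/// A sketch of intended usage (`MyLlm` is a placeholder type, not part of
/// this crate):
///
/// ```ignore
/// let llm = MyLlm::default().with_cache(128);
/// let first = llm.generate("prompt", "context", None).await?;
/// let second = llm.generate("prompt", "context", None).await?; // cache hit
/// ```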
pub trait CacheExt: Llm + Sized {
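    /// Wraps `self` in a [`CachedLlm`] with the given capacity.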
fn with_cache(self, capacity: usize) -> CachedLlm<Self> {
CachedLlm::new(self, capacity)
}
}
impl<L: Llm> CacheExt for L {}
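/// Hashes the full request (prompt, context, optional feedback) into a cache
/// key. `DefaultHasher` is deterministic within a process but its algorithm
/// is not guaranteed stable across Rust releases, so keys must never be
/// persisted.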
fn cache_key(prompt: &str, context: &str, feedback: Option<&str>) -> u64 {
let mut hasher = DefaultHasher::new();
prompt.hash(&mut hasher);
context.hash(&mut hasher);
feedback.hash(&mut hasher);
hasher.finish()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::recursive::llm::MockLlm;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
#[tokio::test]
async fn test_cache_hit() {
        // Count how many times the inner model is actually invoked.
        let call_count = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&call_count);
        let llm = MockLlm::new(move |_, _| {
            counter.fetch_add(1, Ordering::SeqCst);
            "response".to_string()
        })
        .with_cache(10);
        let r1 = llm.generate("hello", "", None).await.unwrap();
        assert_eq!(&*r1.text, "response");
        assert_eq!(llm.cache_len(), 1);
        let r2 = llm.generate("hello", "", None).await.unwrap();
        assert_eq!(&*r2.text, "response");
        assert_eq!(llm.cache_len(), 1);
        // The second, identical request must be served from the cache.
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
#[tokio::test]
async fn test_cache_different_prompts() {
let llm = MockLlm::new(|prompt, _| format!("reply to: {}", prompt)).with_cache(10);
let r1 = llm.generate("a", "", None).await.unwrap();
let r2 = llm.generate("b", "", None).await.unwrap();
assert_eq!(&*r1.text, "reply to: a");
assert_eq!(&*r2.text, "reply to: b");
assert_eq!(llm.cache_len(), 2);
}
#[tokio::test]
async fn test_cache_with_feedback() {
let llm = MockLlm::new(|_, fb| {
fb.map(|f| format!("with: {}", f))
.unwrap_or("no feedback".to_string())
})
.with_cache(10);
let r1 = llm.generate("p", "", None).await.unwrap();
let r2 = llm.generate("p", "", Some("improve")).await.unwrap();
assert_eq!(&*r1.text, "no feedback");
assert_eq!(&*r2.text, "with: improve");
        assert_eq!(llm.cache_len(), 2);
    }
#[tokio::test]
async fn test_cache_eviction() {
let llm = MockLlm::new(|prompt, _| prompt.to_string()).with_cache(2);
llm.generate("a", "", None).await.unwrap();
llm.generate("b", "", None).await.unwrap();
assert_eq!(llm.cache_len(), 2);
llm.generate("c", "", None).await.unwrap();
assert_eq!(llm.cache_len(), 2);
}
#[tokio::test]
async fn test_cache_clear() {
let llm = MockLlm::new(|_, _| "ok".to_string()).with_cache(10);
llm.generate("a", "", None).await.unwrap();
llm.generate("b", "", None).await.unwrap();
assert_eq!(llm.cache_len(), 2);
llm.clear_cache();
assert_eq!(llm.cache_len(), 0);
}
#[test]
fn test_model_name_preserved() {
let llm = MockLlm::new(|_, _| "ok".to_string())
.with_name("gpt-4")
.with_cache(10);
assert_eq!(llm.model_name(), "gpt-4");
}
#[test]
fn test_cache_composable_with_retry() {
use crate::recursive::retry::LlmExt;
let llm = MockLlm::new(|_, _| "ok".to_string())
.with_cache(10)
.with_retry(3);
assert_eq!(llm.model_name(), "mock");
}
}