use std::collections::{HashMap, VecDeque};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tracing::debug;
use crate::types::{
ChatRequest, ChatResponse, ChatStream, LlmCapabilities, LlmProvider, RunnerError,
};
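/// Tuning knobs for [`CacheProvider`].
///
/// The defaults keep up to 256 entries for five minutes each and bypass the
/// cache for requests with a nonzero temperature. A sketch of a tighter
/// configuration:
///
/// ```rust,ignore
/// let config = CacheConfig {
///     max_entries: 64,
///     ttl: Duration::from_secs(60),
///     ..CacheConfig::default()
/// };
/// ```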
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of cached responses before the oldest entry is evicted.
    pub max_entries: usize,
    /// How long a cached response remains valid.
    pub ttl: Duration,
    /// Also cache requests with a nonzero temperature. Off by default, since
    /// such responses are nondeterministic and replaying one can be surprising.
    pub cache_nonzero_temperature: bool,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 256,
ttl: Duration::from_secs(300),
cache_nonzero_temperature: false,
}
}
}
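/// Counters describing cache behavior; snapshots are available through
/// [`CacheProvider::cache_stats`].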
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Requests served from the cache.
    pub hits: u64,
    /// Requests that went through to the wrapped provider.
    pub misses: u64,
    /// Entries dropped because they expired or the cache was full.
    pub evictions: u64,
    /// Number of live entries when the snapshot was taken.
    pub size: usize,
}
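/// A cached response together with the instant it was stored, for TTL checks.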
#[derive(Debug, Clone)]
struct CacheEntry {
response: ChatResponse,
inserted_at: Instant,
}
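/// Mutable internals kept behind a single mutex: the entries, the insertion
/// order used for FIFO eviction, and the running counters.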
#[derive(Debug, Default)]
struct CacheState {
entries: HashMap<u64, CacheEntry>,
    insertion_order: VecDeque<u64>,
stats: CacheStats,
}
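/// A caching decorator around any [`LlmProvider`]: repeated identical
/// requests within the TTL are answered from memory instead of re-calling
/// the wrapped provider. Streaming requests always pass through uncached.
///
/// A minimal usage sketch, assuming `inner` is any boxed provider and
/// `request` is a prepared [`ChatRequest`]:
///
/// ```rust,ignore
/// let cached = CacheProvider::new(inner, CacheConfig::default());
/// let first = cached.complete(&request).await?;  // goes to the provider
/// let second = cached.complete(&request).await?; // served from the cache
/// assert_eq!(cached.cache_stats().hits, 1);
/// ```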
pub struct CacheProvider {
inner: Box<dyn LlmProvider>,
config: CacheConfig,
state: Arc<Mutex<CacheState>>,
}
impl CacheProvider {
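    /// Wraps `inner` with a fresh, empty cache using the given configuration.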
pub fn new(inner: Box<dyn LlmProvider>, config: CacheConfig) -> Self {
Self {
inner,
config,
state: Arc::new(Mutex::new(CacheState::default())),
}
}
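    /// Returns a snapshot of the hit/miss/eviction counters plus the current
    /// number of live entries.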
pub fn cache_stats(&self) -> CacheStats {
let state = self.state.lock().expect("cache lock poisoned");
let mut snapshot = state.stats.clone();
snapshot.size = state.entries.len();
snapshot
}
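    /// Derives a cache key by hashing the JSON serialization of the fields
    /// that determine the response: messages, model, temperature, and
    /// max_tokens. Requests differing only in other fields share a key. If
    /// serialization ever fails, the key degrades to the hash of an empty
    /// string, so distinct failing requests could collide; these types are
    /// expected to serialize infallibly in practice.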
fn cache_key(request: &ChatRequest) -> u64 {
let mut hasher = DefaultHasher::new();
serde_json::to_string(&(
&request.messages,
&request.model,
&request.temperature,
&request.max_tokens,
))
.unwrap_or_default()
.hash(&mut hasher);
hasher.finish()
}
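    /// A request skips the cache when its temperature is above zero (the
    /// output is nondeterministic) and the configuration has not opted in to
    /// caching such requests anyway. Requests without a temperature are
    /// always cached.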
fn should_bypass(&self, request: &ChatRequest) -> bool {
if self.config.cache_nonzero_temperature {
return false;
}
matches!(request.temperature, Some(t) if t > 0.0)
}
}
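// Provider metadata, streaming, and health checks delegate straight to the
// wrapped provider; only `complete` consults the cache.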
#[async_trait]
impl LlmProvider for CacheProvider {
fn name(&self) -> &'static str {
self.inner.name()
}
fn display_name(&self) -> &str {
self.inner.display_name()
}
fn capabilities(&self) -> LlmCapabilities {
self.inner.capabilities()
}
fn default_model(&self) -> &str {
self.inner.default_model()
}
fn available_models(&self) -> &[String] {
self.inner.available_models()
}
async fn complete(&self, request: &ChatRequest) -> Result<ChatResponse, RunnerError> {
if self.should_bypass(request) {
return self.inner.complete(request).await;
}
let key = Self::cache_key(request);
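        // Fast path: under the lock, return a clone of any unexpired entry.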
{
let mut state = self.state.lock().expect("cache lock poisoned");
let cached = state.entries.get(&key).and_then(|entry| {
if entry.inserted_at.elapsed() < self.config.ttl {
Some(entry.response.clone())
} else {
None
}
});
if let Some(response) = cached {
state.stats.hits += 1;
debug!(key, "cache hit");
return Ok(response);
}
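            // An entry exists but has expired: drop it and count an eviction.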
if state.entries.remove(&key).is_some() {
state.insertion_order.retain(|k| *k != key);
state.stats.evictions += 1;
}
state.stats.misses += 1;
}
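        // Miss: call the wrapped provider without holding the lock. Note that
        // concurrent identical requests may each call through before the
        // first response is inserted.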
let response = self.inner.complete(request).await?;
{
let mut state = self.state.lock().expect("cache lock poisoned");
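            // Evict the oldest entries (FIFO) until there is room to insert.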
            while state.entries.len() >= self.config.max_entries {
                match state.insertion_order.pop_front() {
                    Some(oldest_key) => {
                        state.entries.remove(&oldest_key);
                        state.stats.evictions += 1;
                    }
                    None => break,
                }
            }
state.entries.insert(
key,
CacheEntry {
response: response.clone(),
inserted_at: Instant::now(),
},
);
            state.insertion_order.push_back(key);
}
Ok(response)
}
async fn complete_stream(&self, request: &ChatRequest) -> Result<ChatStream, RunnerError> {
self.inner.complete_stream(request).await
}
async fn health_check(&self) -> Result<bool, RunnerError> {
self.inner.health_check().await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{
ChatMessage, ChatRequest, ChatResponse, ChatStream, LlmCapabilities, LlmProvider,
RunnerError,
};
use async_trait::async_trait;
use std::sync::atomic::{AtomicU32, Ordering};
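    /// Scripted provider for tests: hands out the queued responses in order
    /// (falling back to a default) and counts calls to `complete`.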
struct TestProvider {
responses: Mutex<Vec<Result<ChatResponse, RunnerError>>>,
call_count: AtomicU32,
}
impl TestProvider {
fn new(responses: Vec<Result<ChatResponse, RunnerError>>) -> Self {
Self {
responses: Mutex::new(responses),
call_count: AtomicU32::new(0),
}
}
}
#[async_trait]
impl LlmProvider for TestProvider {
fn name(&self) -> &'static str {
"test"
}
fn display_name(&self) -> &str {
"Test Provider"
}
fn capabilities(&self) -> LlmCapabilities {
LlmCapabilities::text_only()
}
        fn default_model(&self) -> &str {
"test-model"
}
fn available_models(&self) -> &[String] {
&[]
}
async fn complete(&self, _request: &ChatRequest) -> Result<ChatResponse, RunnerError> {
self.call_count.fetch_add(1, Ordering::SeqCst);
let mut responses = self.responses.lock().expect("test lock");
if responses.is_empty() {
Ok(ChatResponse {
content: "default".to_owned(),
model: "test-model".to_owned(),
usage: None,
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
})
} else {
responses.remove(0)
}
}
async fn complete_stream(&self, _request: &ChatRequest) -> Result<ChatStream, RunnerError> {
Err(RunnerError::internal("streaming not supported in test"))
}
async fn health_check(&self) -> Result<bool, RunnerError> {
Ok(true)
}
}
fn make_response(content: &str) -> ChatResponse {
ChatResponse {
content: content.to_owned(),
model: "test-model".to_owned(),
usage: None,
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
}
}
#[tokio::test]
async fn cache_hit() {
let provider = TestProvider::new(vec![Ok(make_response("cached"))]);
let cached = CacheProvider::new(Box::new(provider), CacheConfig::default());
let request = ChatRequest::new(vec![ChatMessage::user("hi")]);
let r1 = cached.complete(&request).await.expect("first call");
let r2 = cached.complete(&request).await.expect("second call");
assert_eq!(r1.content, "cached");
assert_eq!(r2.content, "cached");
let stats = cached.cache_stats();
assert_eq!(stats.hits, 1);
assert_eq!(stats.misses, 1);
}
#[tokio::test]
async fn cache_miss_different_request() {
let provider = TestProvider::new(vec![
Ok(make_response("first")),
Ok(make_response("second")),
]);
let cached = CacheProvider::new(Box::new(provider), CacheConfig::default());
let r1 = cached
.complete(&ChatRequest::new(vec![ChatMessage::user("hello")]))
.await
.expect("first");
let r2 = cached
.complete(&ChatRequest::new(vec![ChatMessage::user("goodbye")]))
.await
.expect("second");
assert_eq!(r1.content, "first");
assert_eq!(r2.content, "second");
let stats = cached.cache_stats();
assert_eq!(stats.misses, 2);
assert_eq!(stats.hits, 0);
}
#[tokio::test]
async fn bypass_nonzero_temp() {
let provider = TestProvider::new(vec![Ok(make_response("r1")), Ok(make_response("r2"))]);
let cached = CacheProvider::new(Box::new(provider), CacheConfig::default());
let request = ChatRequest::new(vec![ChatMessage::user("hi")]).with_temperature(0.7);
let r1 = cached.complete(&request).await.expect("first");
let r2 = cached.complete(&request).await.expect("second");
assert_eq!(r1.content, "r1");
assert_eq!(r2.content, "r2");
}
#[tokio::test]
async fn cache_with_temp_configured() {
let provider = TestProvider::new(vec![Ok(make_response("cached"))]);
let config = CacheConfig {
cache_nonzero_temperature: true,
..CacheConfig::default()
};
let cached = CacheProvider::new(Box::new(provider), config);
let request = ChatRequest::new(vec![ChatMessage::user("hi")]).with_temperature(0.7);
let r1 = cached.complete(&request).await.expect("first");
let r2 = cached.complete(&request).await.expect("second");
assert_eq!(r1.content, "cached");
assert_eq!(r2.content, "cached");
let stats = cached.cache_stats();
assert_eq!(stats.hits, 1);
}
#[tokio::test]
async fn ttl_expiration() {
let provider =
TestProvider::new(vec![Ok(make_response("old")), Ok(make_response("fresh"))]);
let config = CacheConfig {
ttl: Duration::from_millis(10),
..CacheConfig::default()
};
let cached = CacheProvider::new(Box::new(provider), config);
let request = ChatRequest::new(vec![ChatMessage::user("hi")]);
let r1 = cached.complete(&request).await.expect("first");
assert_eq!(r1.content, "old");
tokio::time::sleep(Duration::from_millis(20)).await;
let r2 = cached.complete(&request).await.expect("after expiry");
assert_eq!(r2.content, "fresh");
let stats = cached.cache_stats();
assert_eq!(stats.evictions, 1);
}
#[tokio::test]
async fn eviction_at_capacity() {
let provider = TestProvider::new(vec![
Ok(make_response("a")),
Ok(make_response("b")),
Ok(make_response("c")),
]);
let config = CacheConfig {
max_entries: 2,
..CacheConfig::default()
};
let cached = CacheProvider::new(Box::new(provider), config);
cached
.complete(&ChatRequest::new(vec![ChatMessage::user("1")]))
.await
.expect("a");
cached
.complete(&ChatRequest::new(vec![ChatMessage::user("2")]))
.await
.expect("b");
cached
.complete(&ChatRequest::new(vec![ChatMessage::user("3")]))
.await
.expect("c");
let stats = cached.cache_stats();
assert_eq!(stats.size, 2);
assert_eq!(stats.evictions, 1);
}
#[tokio::test]
async fn stats_tracking() {
let provider = TestProvider::new(vec![Ok(make_response("r1")), Ok(make_response("r2"))]);
let cached = CacheProvider::new(Box::new(provider), CacheConfig::default());
let req1 = ChatRequest::new(vec![ChatMessage::user("hello")]);
let req2 = ChatRequest::new(vec![ChatMessage::user("world")]);
cached.complete(&req1).await.expect("miss");
cached.complete(&req1).await.expect("hit");
cached.complete(&req2).await.expect("miss");
let stats = cached.cache_stats();
assert_eq!(stats.misses, 2);
assert_eq!(stats.hits, 1);
assert_eq!(stats.size, 2);
}
#[tokio::test]
async fn streaming_bypasses() {
let provider = TestProvider::new(vec![]);
let cached = CacheProvider::new(Box::new(provider), CacheConfig::default());
let request = ChatRequest::new(vec![ChatMessage::user("hi")]);
let result = cached.complete_stream(&request).await;
assert!(result.is_err());
}
#[test]
fn key_determinism() {
let req1 = ChatRequest::new(vec![ChatMessage::user("hello")]);
let req2 = ChatRequest::new(vec![ChatMessage::user("hello")]);
let req3 = ChatRequest::new(vec![ChatMessage::user("different")]);
assert_eq!(
CacheProvider::cache_key(&req1),
CacheProvider::cache_key(&req2)
);
assert_ne!(
CacheProvider::cache_key(&req1),
CacheProvider::cache_key(&req3)
);
}
}