use crate::core::language_models::LLMResult;
use crate::schema::Message;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// An LLM response stored in the cache, paired with the moment it was stored.
#[derive(Debug, Clone)]
pub struct CachedLLMResult {
    /// The response exactly as it was handed to the cache.
    pub result: LLMResult,
    /// Insertion time; the cache compares its age against the configured TTL.
    pub cached_at: Instant,
}
/// Tuning knobs for `LLMCache`.
///
/// Build one with [`CacheConfig::new`] / [`Default`] and refine it with the
/// consuming builder methods, e.g.
/// `CacheConfig::new().with_max_entries(100).no_ttl()`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of stored entries. When the cache is full, `LLMCache::put`
    /// makes room by evicting the entry with the oldest insertion time.
    /// `0` disables the limit entirely.
    pub max_entries: usize,
    /// How long an entry remains valid after insertion; `None` means entries
    /// never expire.
    pub ttl: Option<Duration>,
    /// When `false`, lookups always miss and inserts are ignored.
    pub enabled: bool,
}

impl Default for CacheConfig {
    /// Up to 1000 entries, a one-hour TTL, caching enabled.
    fn default() -> Self {
        Self {
            max_entries: 1000,
            ttl: Some(Duration::from_secs(3600)),
            enabled: true,
        }
    }
}

impl CacheConfig {
    /// Same as [`CacheConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Disables expiry: entries stay valid until evicted or cleared.
    // `must_use`: these builders consume `self`; discarding the return value
    // silently loses the configuration change.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn no_ttl(mut self) -> Self {
        self.ttl = None;
        self
    }

    /// Sets how long an entry stays valid after insertion.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn with_ttl(mut self, ttl: Duration) -> Self {
        self.ttl = Some(ttl);
        self
    }

    /// Sets the entry limit (`0` means unlimited).
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn with_max_entries(mut self, max: usize) -> Self {
        self.max_entries = max;
        self
    }

    /// Turns the cache off: lookups miss and inserts are dropped.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn disabled(mut self) -> Self {
        self.enabled = false;
        self
    }
}
/// In-memory cache of LLM results keyed by caller-supplied strings.
///
/// Entries live in a `HashMap` guarded by a `tokio::sync::RwLock`: lookups take
/// the lock shared, mutations take it exclusively, and every method takes
/// `&self`.
pub struct LLMCache {
    config: CacheConfig,
    store: RwLock<HashMap<String, CachedLLMResult>>,
}

impl LLMCache {
    /// Creates an empty cache using `CacheConfig::default()`.
    pub fn new() -> Self {
        Self::with_config(CacheConfig::default())
    }

    /// Creates an empty cache governed by `config`.
    pub fn with_config(config: CacheConfig) -> Self {
        Self {
            config,
            store: RwLock::new(HashMap::new()),
        }
    }

    /// Builds a cache key of the form `"{model}:{json}"`, where `json` is the
    /// `serde_json` serialization of `messages`.
    ///
    /// NOTE(review): if serialization fails, `unwrap_or_default` substitutes an
    /// empty string, so every such input for one model shares the key
    /// `"{model}:"` and can yield false cache hits.
    pub fn build_key(messages: &[Message], model: &str) -> String {
        format!("{}:{}", model, serde_json::to_string(messages).unwrap_or_default())
    }

    /// Returns a clone of the entry stored under `key`, or `None` when caching
    /// is disabled, the key is absent, or the entry is older than the TTL.
    ///
    /// Expired entries are reported as misses but are not removed here; they
    /// remain until `evict_expired`, `clear`, or capacity eviction in `put`.
    pub async fn get(&self, key: &str) -> Option<CachedLLMResult> {
        if !self.config.enabled {
            return None;
        }
        let store = self.store.read().await;
        let entry = store.get(key)?;
        if let Some(ttl) = self.config.ttl {
            if entry.cached_at.elapsed() > ttl {
                return None;
            }
        }
        Some(entry.clone())
    }

    /// Stores `result` under `key`, stamped with the current time, replacing
    /// any previous entry for that key. Does nothing when caching is disabled.
    ///
    /// If the cache is at `max_entries` (and the limit is non-zero) and `key`
    /// is new, the entry with the oldest `cached_at` is evicted first; finding
    /// it is a linear scan over all entries.
    pub async fn put(&self, key: String, result: LLMResult) {
        if !self.config.enabled {
            return;
        }
        let mut store = self.store.write().await;
        let at_capacity =
            self.config.max_entries > 0 && store.len() >= self.config.max_entries;
        // Overwriting an existing key does not grow the map, so it must not
        // evict some other, still-useful entry.
        if at_capacity && !store.contains_key(&key) {
            let oldest_key = store
                .iter()
                .min_by_key(|(_, entry)| entry.cached_at)
                .map(|(k, _)| k.clone());
            if let Some(oldest_key) = oldest_key {
                store.remove(&oldest_key);
            }
        }
        store.insert(
            key,
            CachedLLMResult {
                result,
                cached_at: Instant::now(),
            },
        );
    }

    /// Removes every entry, regardless of `enabled`.
    pub async fn clear(&self) {
        let mut store = self.store.write().await;
        store.clear();
    }

    /// Number of stored entries, including expired ones not yet evicted.
    pub async fn len(&self) -> usize {
        let store = self.store.read().await;
        store.len()
    }

    /// `true` when no entries are stored (expired-but-present entries count).
    pub async fn is_empty(&self) -> bool {
        self.len().await == 0
    }

    /// Drops all entries older than the TTL and returns how many were removed.
    /// Returns `0` without touching the store when no TTL is configured.
    pub async fn evict_expired(&self) -> usize {
        if let Some(ttl) = self.config.ttl {
            let mut store = self.store.write().await;
            let before = store.len();
            store.retain(|_, v| v.cached_at.elapsed() <= ttl);
            before - store.len()
        } else {
            0
        }
    }
}

impl Default for LLMCache {
    /// Equivalent to [`LLMCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::language_models::TokenUsage;

    /// Builds an `LLMResult` with the given content and fixed model/usage data.
    fn make_result(content: &str) -> LLMResult {
        LLMResult {
            content: content.to_string(),
            model: "test-model".to_string(),
            token_usage: Some(TokenUsage {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
            }),
            tool_calls: None,
        }
    }

    #[tokio::test]
    async fn test_cache_put_and_get() {
        let cache = LLMCache::new();
        let key = "test-key";
        cache.put(key.to_string(), make_result("Hello, world!")).await;
        let cached = cache.get(key).await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().result.content, "Hello, world!");
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = LLMCache::new();
        let cached = cache.get("non-existent").await;
        assert!(cached.is_none());
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let cache = LLMCache::new();
        cache.put("k1".to_string(), make_result("r1")).await;
        cache.put("k2".to_string(), make_result("r2")).await;
        assert_eq!(cache.len().await, 2);
        cache.clear().await;
        assert_eq!(cache.len().await, 0);
    }

    #[tokio::test]
    async fn test_cache_disabled() {
        let config = CacheConfig::new().disabled();
        let cache = LLMCache::with_config(config);
        cache.put("key".to_string(), make_result("test")).await;
        let cached = cache.get("key").await;
        assert!(cached.is_none());
    }

    #[tokio::test]
    async fn test_cache_ttl_expiry() {
        let config = CacheConfig::new().with_ttl(Duration::from_millis(10));
        let cache = LLMCache::with_config(config);
        cache.put("key".to_string(), make_result("test")).await;
        assert!(cache.get("key").await.is_some());
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(cache.get("key").await.is_none());
    }

    #[tokio::test]
    async fn test_cache_max_entries() {
        let config = CacheConfig::new().with_max_entries(3).no_ttl();
        let cache = LLMCache::with_config(config);
        // Eviction removes the entry with the smallest `cached_at`. Pause between
        // inserts so timestamps are strictly increasing even on coarse clocks;
        // a tie would otherwise be broken by HashMap iteration order.
        cache.put("a".to_string(), make_result("1")).await;
        tokio::time::sleep(Duration::from_millis(2)).await;
        cache.put("b".to_string(), make_result("2")).await;
        tokio::time::sleep(Duration::from_millis(2)).await;
        cache.put("c".to_string(), make_result("3")).await;
        assert_eq!(cache.len().await, 3);
        tokio::time::sleep(Duration::from_millis(2)).await;
        cache.put("d".to_string(), make_result("4")).await;
        assert_eq!(cache.len().await, 3);
        assert!(cache.get("a").await.is_none());
        assert!(cache.get("b").await.is_some());
        assert!(cache.get("c").await.is_some());
        assert!(cache.get("d").await.is_some());
    }

    #[tokio::test]
    async fn test_cache_no_ttl() {
        let config = CacheConfig::new().no_ttl();
        let cache = LLMCache::with_config(config);
        cache.put("key".to_string(), make_result("persist")).await;
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(cache.get("key").await.is_some());
    }

    #[tokio::test]
    async fn test_cache_evict_expired() {
        let config = CacheConfig::new().with_ttl(Duration::from_millis(0));
        let cache = LLMCache::with_config(config);
        cache.put("key".to_string(), make_result("test")).await;
        tokio::time::sleep(Duration::from_millis(1)).await;
        let evicted = cache.evict_expired().await;
        assert_eq!(evicted, 1);
        assert!(cache.is_empty().await);
    }

    // `build_key` is synchronous, so no async runtime is needed here.
    #[test]
    fn test_cache_build_key() {
        let messages = vec![Message::human("Hello"), Message::ai("Hi!")];
        let key = LLMCache::build_key(&messages, "gpt-4");
        assert!(key.contains("gpt-4"));
        assert!(key.contains("Hello"));
    }

    #[tokio::test]
    async fn test_cache_is_empty() {
        let cache = LLMCache::new();
        assert!(cache.is_empty().await);
        cache.put("key".to_string(), make_result("test")).await;
        assert!(!cache.is_empty().await);
    }
}