use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::sync::RwLock;
use cognis_core::outputs::ChatResult;
use super::LlmCache;
/// One cached response together with the timestamps used for expiry and
/// eviction decisions.
#[derive(Clone)]
struct CacheEntry {
    /// The chat result handed back to callers on a cache hit.
    value: ChatResult,
    /// When the entry was stored; compared against the cache's TTL.
    created_at: Instant,
    /// When the entry was last stored or read; the entry with the oldest
    /// value is the one chosen for LRU eviction.
    last_accessed: Instant,
}
/// In-process cache of chat results keyed by string, with an optional entry
/// limit and an optional time-to-live.
pub struct InMemoryCache {
    /// Key-to-entry map shared behind an async read/write lock.
    store: Arc<RwLock<HashMap<String, CacheEntry>>>,
    /// Maximum age of an entry; `None` means entries never expire.
    ttl: Option<Duration>,
    /// Maximum number of entries; `None` means the cache is unbounded.
    max_size: Option<usize>,
}
/// Builder for [`InMemoryCache`].
///
/// Every setting is optional: a freshly created (or `Default`) builder
/// describes a cache with no size limit and no TTL.
#[derive(Debug, Clone, Default)]
pub struct InMemoryCacheBuilder {
    /// Maximum number of entries the cache may hold; `None` means unbounded.
    max_size: Option<usize>,
    /// Maximum age of an entry; `None` means entries never expire.
    ttl: Option<Duration>,
}
impl InMemoryCacheBuilder {
    /// Limits the cache to at most `size` entries.
    pub fn max_size(self, size: usize) -> Self {
        Self {
            max_size: Some(size),
            ..self
        }
    }

    /// Sets how long an entry stays valid after it has been stored.
    pub fn ttl(self, duration: Duration) -> Self {
        Self {
            ttl: Some(duration),
            ..self
        }
    }

    /// Consumes the builder and returns an empty cache carrying the
    /// configured limits.
    pub fn build(self) -> InMemoryCache {
        let Self { max_size, ttl } = self;
        let store = Arc::new(RwLock::new(HashMap::new()));
        InMemoryCache {
            store,
            max_size,
            ttl,
        }
    }
}
impl InMemoryCache {
pub fn new() -> Self {
Self {
store: Arc::new(RwLock::new(HashMap::new())),
max_size: None,
ttl: None,
}
}
pub fn builder() -> InMemoryCacheBuilder {
InMemoryCacheBuilder {
max_size: None,
ttl: None,
}
}
pub async fn len(&self) -> usize {
self.store.read().await.len()
}
pub async fn is_empty(&self) -> bool {
self.store.read().await.is_empty()
}
fn is_expired(&self, entry: &CacheEntry) -> bool {
if let Some(ttl) = self.ttl {
entry.created_at.elapsed() > ttl
} else {
false
}
}
fn evict_lru(store: &mut HashMap<String, CacheEntry>) {
if let Some(lru_key) = store
.iter()
.min_by_key(|(_, entry)| entry.last_accessed)
.map(|(k, _)| k.clone())
{
store.remove(&lru_key);
}
}
}
impl Default for InMemoryCache {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl LlmCache for InMemoryCache {
    /// Looks up `key` and returns a clone of the cached result.
    ///
    /// Returns `None` when the key is absent or when its entry has outlived
    /// the TTL; an expired entry is removed as a side effect. A successful
    /// lookup refreshes the entry's `last_accessed` timestamp so it counts as
    /// recently used for LRU eviction.
    async fn get(&self, key: &str) -> Option<ChatResult> {
        // Fast path: a plain miss only needs the shared lock. The guard lives
        // in its own scope so it is released before the write lock is taken.
        {
            let store = self.store.read().await;
            if !store.contains_key(key) {
                return None;
            }
        }

        // Every decision below is made under the write lock. Checking expiry
        // under a read lock and removing under a later write lock would let a
        // concurrent `put` insert a fresh entry in between, which would then
        // be deleted by mistake.
        let mut store = self.store.write().await;
        let entry = store.get_mut(key)?;
        if self.is_expired(entry) {
            store.remove(key);
            return None;
        }
        entry.last_accessed = Instant::now();
        Some(entry.value.clone())
    }

    /// Stores a clone of `result` under `key`, replacing any existing entry.
    ///
    /// When a size limit is configured and a new key would exceed it, expired
    /// entries are discarded first and then least-recently-used entries are
    /// evicted until there is room. A cache built with `max_size(0)` stores
    /// nothing.
    async fn put(&self, key: &str, result: &ChatResult) {
        // A zero-capacity cache can never hold an entry. Without this guard
        // the eviction loop below would spin forever on an empty map.
        if self.max_size == Some(0) {
            return;
        }

        // Clone before taking the lock to keep the critical section short.
        let value = result.clone();

        let mut store = self.store.write().await;
        let now = Instant::now();
        let entry = CacheEntry {
            value,
            last_accessed: now,
            created_at: now,
        };

        if let Some(max) = self.max_size {
            // Overwriting an existing key does not grow the map, so eviction
            // is only needed for new keys.
            if !store.contains_key(key) {
                if store.len() >= max && self.ttl.is_some() {
                    // Prefer dropping dead entries over evicting live ones.
                    store.retain(|_, existing| !self.is_expired(existing));
                }
                // `max >= 1` here, so the map is non-empty whenever the
                // condition holds and every iteration removes one entry.
                while store.len() >= max {
                    Self::evict_lru(&mut store);
                }
            }
        }

        store.insert(key.to_string(), entry);
    }

    /// Removes every entry from the cache.
    async fn clear(&self) {
        self.store.write().await.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::messages::AIMessage;
    use cognis_core::outputs::{ChatGeneration, ChatResult};

    /// Builds a `ChatResult` holding a single AI-message generation with the
    /// given text.
    fn make_result(text: &str) -> ChatResult {
        let generation = ChatGeneration::new(AIMessage::new(text));
        ChatResult {
            generations: vec![generation],
            llm_output: None,
        }
    }

    /// A stored value can be read back unchanged.
    #[tokio::test]
    async fn test_put_and_get() {
        let cache = InMemoryCache::new();
        let expected = make_result("hello");
        cache.put("k1", &expected).await;

        let fetched = cache.get("k1").await;
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap(), expected);
    }

    /// Looking up an unknown key yields `None`.
    #[tokio::test]
    async fn test_get_returns_none_on_miss() {
        let cache = InMemoryCache::new();
        let missing = cache.get("nonexistent").await;
        assert!(missing.is_none());
    }

    /// `clear` removes every entry.
    #[tokio::test]
    async fn test_clear_empties_cache() {
        let cache = InMemoryCache::new();
        for key in ["a", "b"] {
            cache.put(key, &make_result(key)).await;
        }
        assert_eq!(cache.len().await, 2);

        cache.clear().await;

        assert!(cache.is_empty().await);
        assert!(cache.get("a").await.is_none());
    }

    /// With a limit of two, inserting a third key evicts the entry that was
    /// touched least recently.
    #[tokio::test]
    async fn test_lru_eviction() {
        let cache = InMemoryCache::builder().max_size(2).build();
        cache.put("a", &make_result("first")).await;
        cache.put("b", &make_result("second")).await;

        // Touch "a" so that "b" becomes the least recently used entry.
        let _ = cache.get("a").await;
        cache.put("c", &make_result("third")).await;

        let a_hit = cache.get("a").await;
        assert!(a_hit.is_some(), "a should survive (recently accessed)");
        let b_hit = cache.get("b").await;
        assert!(b_hit.is_none(), "b should be evicted (LRU)");
        let c_hit = cache.get("c").await;
        assert!(c_hit.is_some(), "c should be present");
    }

    /// An entry is served before its TTL elapses and dropped afterwards.
    #[tokio::test]
    async fn test_ttl_expiry() {
        let ttl = Duration::from_millis(50);
        let cache = InMemoryCache::builder().ttl(ttl).build();
        cache.put("k", &make_result("ephemeral")).await;
        assert!(cache.get("k").await.is_some());

        tokio::time::sleep(Duration::from_millis(100)).await;

        let after_expiry = cache.get("k").await;
        assert!(after_expiry.is_none(), "entry should have expired");
    }

    /// Writing to an existing key replaces the value without growing the
    /// cache.
    #[tokio::test]
    async fn test_overwrite_existing_key() {
        let cache = InMemoryCache::builder().max_size(2).build();
        for text in ["v1", "v2"] {
            cache.put("k", &make_result(text)).await;
        }

        let latest = cache.get("k").await.unwrap();
        assert_eq!(latest.generations[0].text, "v2");
        assert_eq!(cache.len().await, 1, "overwrite should not increase size");
    }

    /// Without a size limit nothing is ever evicted.
    #[tokio::test]
    async fn test_unbounded_cache() {
        let cache = InMemoryCache::new();
        for i in 0..100 {
            let key = format!("k{i}");
            let value = make_result(&format!("v{i}"));
            cache.put(&key, &value).await;
        }

        assert_eq!(cache.len().await, 100);
        for key in ["k0", "k99"] {
            assert!(cache.get(key).await.is_some());
        }
    }
}