use std::collections::HashMap;
use std::future::Future;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use chrono::Duration;
use lru::LruCache;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock as AsyncRwLock;
use tracing::{debug, info};
use super::types::{MemoEntry, MemoKey, MemoOpType, MemoStats, MemoValue};
use crate::error::Result;
use crate::utils::fingerprint::Fingerprint;
/// Time-to-live assigned to a freshly constructed store (seven days) until
/// it is overridden with `MemoStore::with_ttl`.
const DEFAULT_TTL: Duration = Duration::days(7);
/// Entry capacity used by `MemoStore::new`.
const DEFAULT_MAX_SIZE: usize = 10_000;
/// Serialized form of a memo store, written by `MemoStore::save` as JSON and
/// read back by `MemoStore::load`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MemoStoreData {
/// Version tag stored alongside the data; `save` currently always writes 1.
version: u32,
/// Cache entries keyed by the string key the store uses internally.
entries: HashMap<String, MemoEntry>,
/// Copy of the store statistics taken at save time.
stats: MemoStats,
}
/// Hit / miss / token counters stored in atomics so they can be updated
/// through a shared reference. All accesses use `Ordering::Relaxed`.
#[derive(Debug, Default)]
struct AtomicStats {
    hits: AtomicU64,
    misses: AtomicU64,
    tokens_saved: AtomicU64,
}

impl AtomicStats {
    /// Creates a counter set with every value at zero.
    fn new() -> Self {
        Self::default()
    }

    /// Adds `amount` to `counter` with relaxed ordering.
    fn bump(counter: &AtomicU64, amount: u64) {
        counter.fetch_add(amount, Ordering::Relaxed);
    }

    /// Counts one cache hit.
    fn record_hit(&self) {
        Self::bump(&self.hits, 1);
    }

    /// Counts one cache miss.
    fn record_miss(&self) {
        Self::bump(&self.misses, 1);
    }

    /// Adds `tokens` to the running total of saved tokens.
    fn add_tokens_saved(&self, tokens: u64) {
        Self::bump(&self.tokens_saved, tokens);
    }

    /// Returns the current `(hits, misses, tokens_saved)` values.
    ///
    /// The three loads are independent, so the tuple is not an atomic
    /// snapshot of all counters together.
    fn snapshot(&self) -> (u64, u64, u64) {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        (
            read(&self.hits),
            read(&self.misses),
            read(&self.tokens_saved),
        )
    }
}
/// Memoization cache backed by an LRU map, with TTL-based expiry and JSON
/// persistence (`save` / `load`).
///
/// The cache and statistics live behind `Arc`s, so clones of a store share
/// the same underlying data.
pub struct MemoStore {
/// LRU map from the string key built by `make_key` to the cached entry.
cache: Arc<RwLock<LruCache<String, MemoEntry>>>,
/// Aggregate hit / miss / token statistics, guarded by an async lock.
stats: Arc<AsyncRwLock<MemoStats>>,
/// Age limit passed to `MemoEntry::is_expired`.
ttl: Duration,
/// Model identifier substituted into keys that do not carry their own.
model_id: Option<String>,
/// Version substituted into keys whose version is 0.
version: u32,
}
impl std::fmt::Debug for MemoStore {
    /// Prints the configuration fields plus the current entry count; the
    /// cached entries and statistics themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoStore");
        out.field("ttl", &self.ttl);
        out.field("model_id", &self.model_id);
        out.field("version", &self.version);
        out.field("cache_len", &self.cache.read().len());
        out.finish()
    }
}
impl Clone for MemoStore {
    /// Produces a handle that shares the cache and statistics with `self`
    /// (only the `Arc`s are cloned) and copies the configuration fields.
    fn clone(&self) -> Self {
        let Self {
            cache,
            stats,
            ttl,
            model_id,
            version,
        } = self;

        Self {
            cache: Arc::clone(cache),
            stats: Arc::clone(stats),
            ttl: *ttl,
            model_id: model_id.clone(),
            version: *version,
        }
    }
}
impl MemoStore {
pub fn new() -> Self {
Self::with_capacity(DEFAULT_MAX_SIZE)
}
/// Creates an empty store holding at most `capacity` entries.
///
/// The store starts with `DEFAULT_TTL`, no model identifier and key
/// version 1. A `capacity` of 0 is replaced by 1000, because the LRU cache
/// requires a non-zero bound.
pub fn with_capacity(capacity: usize) -> Self {
    use std::num::NonZeroUsize;

    let bound = NonZeroUsize::new(capacity)
        .or_else(|| NonZeroUsize::new(1000))
        .expect("fallback capacity of 1000 is non-zero");
    let cache = LruCache::new(bound);
    let stats = MemoStats::default();

    Self {
        cache: Arc::new(RwLock::new(cache)),
        stats: Arc::new(AsyncRwLock::new(stats)),
        ttl: DEFAULT_TTL,
        model_id: None,
        version: 1,
    }
}
pub fn with_ttl(mut self, ttl: Duration) -> Self {
self.ttl = ttl;
self
}
pub fn with_model(mut self, model_id: &str) -> Self {
self.model_id = Some(model_id.to_string());
self
}
pub fn with_version(mut self, version: u32) -> Self {
self.version = version;
self
}
/// Looks up `key` and returns a clone of its cached value.
///
/// Returns `None` when the key is absent or when its entry has outlived
/// the store TTL; an expired entry is removed as a side effect. A
/// successful lookup records a hit on the entry. The write lock is taken
/// because the lookup mutates the cache.
pub fn get(&self, key: &MemoKey) -> Option<MemoValue> {
    let lookup = self.make_key(key);
    let mut guard = self.cache.write();

    let entry = guard.get_mut(&lookup)?;
    if !entry.is_expired(self.ttl) {
        entry.record_hit();
        debug!("Memo cache hit for {:?}", key.op_type);
        return Some(entry.value.clone());
    }

    // Stale entry: drop it so it no longer occupies a slot.
    guard.pop(&lookup);
    None
}
/// Stores `value` under `key` with a token count of zero.
pub fn put(&self, key: MemoKey, value: MemoValue) {
    const NO_TOKENS: u64 = 0;
    self.put_with_tokens(key, value, NO_TOKENS);
}
/// Stores `value` under `key`, recording `tokens_saved` on the new entry.
///
/// The entry is inserted under the store's write lock via `LruCache::put`.
pub fn put_with_tokens(&self, key: MemoKey, value: MemoValue, tokens_saved: u64) {
    let storage_key = self.make_key(&key);
    let fresh = MemoEntry::with_tokens(value, tokens_saved);
    {
        let mut guard = self.cache.write();
        guard.put(storage_key, fresh);
        debug!("Memo cache put for {:?}", key.op_type);
    }
}
/// Returns the cached value for `key`, or runs `compute` and caches its
/// result.
///
/// On a hit the store-level `hits` counter is incremented and `compute` is
/// never called. On a miss the `misses` counter is incremented, `compute`
/// is awaited, its value is stored together with the reported token count,
/// and that count is added to `tokens_saved`.
///
/// # Errors
///
/// Propagates any error returned by `compute`; in that case nothing is
/// cached, but the miss has already been counted.
///
/// Concurrent callers that miss on the same key each run their own
/// `compute`; there is no in-flight de-duplication.
pub async fn get_or_compute<F, Fut>(&self, key: MemoKey, compute: F) -> Result<MemoValue>
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = Result<(MemoValue, u64)>>,
{
    if let Some(value) = self.get(&key) {
        self.stats.write().await.hits += 1;
        return Ok(value);
    }
    // The guard is a statement temporary, so the stats lock is released
    // before the potentially slow computation starts.
    self.stats.write().await.misses += 1;

    let (value, tokens) = compute().await?;
    // `key` is owned and not needed afterwards, so it is moved instead of cloned.
    self.put_with_tokens(key, value.clone(), tokens);
    self.stats.write().await.tokens_saved += tokens;
    Ok(value)
}
/// Reports whether a live (non-expired) entry exists for `key`.
///
/// This agrees with `get`: an entry older than the store TTL is reported
/// as absent, although it is not removed here. Only the read lock is
/// taken, and the lookup goes through `LruCache::peek`.
pub fn contains(&self, key: &MemoKey) -> bool {
    let full_key = self.make_key(key);
    let cache = self.cache.read();
    cache
        .peek(&full_key)
        .map_or(false, |entry| !entry.is_expired(self.ttl))
}
/// Removes the entry for `key` and returns its value, or `None` when the
/// key was not cached. Expiry is not checked.
pub fn remove(&self, key: &MemoKey) -> Option<MemoValue> {
    let lookup = self.make_key(key);
    let evicted = self.cache.write().pop(&lookup)?;
    Some(evicted.value)
}
/// Drops every cached entry. Statistics are left untouched.
pub fn clear(&self) {
    {
        let mut guard = self.cache.write();
        guard.clear();
        debug!("Memo cache cleared");
    }
}
/// Number of entries currently held, including any that have expired but
/// have not yet been pruned.
pub fn len(&self) -> usize {
    self.cache.read().len()
}
/// Returns `true` when the store holds no entries at all.
pub fn is_empty(&self) -> bool {
    self.cache.read().is_empty()
}
pub async fn stats(&self) -> MemoStats {
let stats = self.stats.read().await;
let mut result = stats.clone();
result.entries = self.len();
result
}
pub fn stats_snapshot(&self) -> MemoStats {
match self.stats.try_read() {
Ok(stats) => {
let mut result = stats.clone();
result.entries = self.len();
result
}
Err(_) => MemoStats {
entries: self.len(),
..Default::default()
},
}
}
/// Removes every entry whose stored value variant corresponds to
/// `op_type` and returns how many entries were removed.
///
/// Matching is done on the `MemoValue` variant of each entry. Only the
/// four pairings listed below match; any other combination is kept.
pub fn invalidate_by_op_type(&self, op_type: MemoOpType) -> usize {
    let mut guard = self.cache.write();
    let initial_len = guard.len();

    // Collect first: the cache cannot be mutated while it is being iterated.
    let doomed: Vec<String> = guard
        .iter()
        .filter(|(_, entry)| {
            matches!(
                (&entry.value, op_type),
                (MemoValue::Summary(_), MemoOpType::Summary)
                    | (MemoValue::PilotDecision(_), MemoOpType::PilotDecision)
                    | (MemoValue::QueryAnalysis(_), MemoOpType::QueryAnalysis)
                    | (MemoValue::Extraction(_), MemoOpType::Extraction)
            )
        })
        .map(|(stored_key, _)| stored_key.clone())
        .collect();

    for stored_key in &doomed {
        guard.pop(stored_key);
    }

    let removed = initial_len - guard.len();
    if removed > 0 {
        debug!("Invalidated {} entries for op_type {:?}", removed, op_type);
    }
    removed
}
/// Clears the whole cache when the store's own model identifier starts
/// with `prefix`, returning the number of entries dropped.
///
/// Individual entries are not inspected: the decision depends only on the
/// store-level `model_id`. Returns 0 when the store has no model
/// identifier or it does not match.
pub fn invalidate_by_model_prefix(&self, prefix: &str) -> usize {
    let mut guard = self.cache.write();

    let model_matches = matches!(
        self.model_id.as_deref(),
        Some(model) if model.starts_with(prefix)
    );
    if !model_matches {
        return 0;
    }

    let removed = guard.len();
    guard.clear();
    debug!(
        "Invalidated all {} entries (model prefix '{}')",
        removed, prefix
    );
    removed
}
/// Removes every entry that has outlived the store TTL and returns the
/// number of entries removed.
pub fn prune_expired(&self) -> usize {
    let ttl = self.ttl;
    let mut guard = self.cache.write();
    let initial_len = guard.len();

    // Gather the stale keys first; removal needs mutable access to the cache.
    let stale: Vec<String> = guard
        .iter()
        .filter_map(|(stored_key, entry)| entry.is_expired(ttl).then(|| stored_key.clone()))
        .collect();

    for stored_key in &stale {
        guard.pop(stored_key);
    }

    let removed = initial_len - guard.len();
    if removed > 0 {
        debug!("Pruned {} expired memo entries", removed);
    }
    removed
}
/// Writes all cached entries and the current statistics to `path` as
/// pretty-printed JSON.
///
/// The data is first written to a sibling file with the `tmp` extension
/// and then renamed over `path`, after creating the parent directory if
/// needed.
///
/// The cache and statistics are copied under short-lived locks *before*
/// any file I/O. The synchronous cache guard is never held across an
/// `.await`, so other tasks are not blocked for the duration of the write.
///
/// # Errors
///
/// Returns `Error::Parse` when `path` has no parent or serialization
/// fails, and propagates I/O errors from creating the directory, writing
/// the temporary file, or renaming it.
pub async fn save(&self, path: &Path) -> Result<()> {
    // Scoped so the synchronous guard is gone before the first await point.
    let entries: HashMap<String, MemoEntry> = {
        let cache = self.cache.read();
        let copied: HashMap<String, MemoEntry> =
            cache.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
        copied
    };
    let stats = {
        let guard = self.stats.read().await;
        guard.clone()
    };

    let entry_count = entries.len();
    let data = MemoStoreData {
        version: 1,
        entries,
        stats,
    };

    let parent = path
        .parent()
        .ok_or_else(|| crate::Error::Parse("Invalid path for memo store".to_string()))?;
    tokio::fs::create_dir_all(parent).await?;

    let temp_path = path.with_extension("tmp");
    let json = serde_json::to_vec_pretty(&data)
        .map_err(|e| crate::Error::Parse(format!("Failed to serialize memo store: {}", e)))?;
    tokio::fs::write(&temp_path, &json).await?;
    tokio::fs::rename(&temp_path, path).await?;

    info!(
        "Saved memo store with {} entries to {:?}",
        entry_count, path
    );
    Ok(())
}
/// Loads entries and statistics previously written by `save` from `path`.
///
/// A missing file is not an error: the store is left unchanged and
/// `Ok(())` is returned. Entries that are already expired under the
/// store's TTL are skipped. Loaded entries are added to whatever the cache
/// already holds, and the stored counters overwrite the current ones.
///
/// The async statistics lock is acquired before the synchronous cache
/// lock, and nothing is awaited while the cache lock is held. This avoids
/// parking a task while it holds the cache lock, and matches the
/// stats-then-cache order used by `stats`.
///
/// # Errors
///
/// Propagates I/O errors from reading the file and returns `Error::Parse`
/// when the contents cannot be deserialized.
pub async fn load(&self, path: &Path) -> Result<()> {
    if !path.exists() {
        return Ok(());
    }

    let bytes = tokio::fs::read(path).await?;
    let data: MemoStoreData = serde_json::from_slice(&bytes)
        .map_err(|e| crate::Error::Parse(format!("Failed to deserialize memo store: {}", e)))?;

    // Last await point of the function: everything below is synchronous.
    let mut stats = self.stats.write().await;
    let mut cache = self.cache.write();

    for (key, entry) in data.entries {
        if !entry.is_expired(self.ttl) {
            cache.put(key, entry);
        }
    }

    stats.entries = cache.len();
    stats.hits = data.stats.hits;
    stats.misses = data.stats.misses;
    stats.tokens_saved = data.stats.tokens_saved;
    stats.cost_saved = data.stats.cost_saved;

    info!(
        "Loaded memo store with {} entries from {:?}",
        cache.len(),
        path
    );
    Ok(())
}
/// Builds the string under which `key` is stored in the cache.
///
/// Works on a copy of `key`: a missing model identifier is filled in from
/// the store, and a version of 0 is replaced by the store's version. The
/// resulting key's fingerprint is then rendered as a string.
fn make_key(&self, key: &MemoKey) -> String {
    let mut scoped = key.clone();
    scoped.model_id = scoped.model_id.take().or_else(|| self.model_id.clone());
    if scoped.version == 0 {
        scoped.version = self.version;
    }
    scoped.fingerprint().to_string()
}
}
impl Default for MemoStore {
fn default() -> Self {
Self::new()
}
}
/// Factory for `MemoKey`s that share one model identifier and version.
///
/// `Debug` and `Clone` are derived so the builder can be logged and
/// duplicated like other public configuration types.
#[derive(Debug, Clone)]
pub struct MemoKeyBuilder {
    /// Model identifier copied into every key produced by this builder.
    model_id: Option<String>,
    /// Version copied into every key produced by this builder.
    version: u32,
}
impl MemoKeyBuilder {
pub fn new() -> Self {
Self {
model_id: None,
version: 1,
}
}
pub fn with_model(mut self, model_id: &str) -> Self {
self.model_id = Some(model_id.to_string());
self
}
pub fn with_version(mut self, version: u32) -> Self {
self.version = version;
self
}
pub fn summary_key(&self, content_fp: &Fingerprint) -> MemoKey {
MemoKey {
op_type: super::types::MemoOpType::Summary,
input_fp: *content_fp,
model_id: self.model_id.clone(),
version: self.version,
context_fp: Fingerprint::zero(),
}
}
pub fn pilot_key(&self, context_fp: &Fingerprint, query_fp: &Fingerprint) -> MemoKey {
MemoKey {
op_type: super::types::MemoOpType::PilotDecision,
input_fp: *query_fp,
model_id: self.model_id.clone(),
version: self.version,
context_fp: *context_fp,
}
}
pub fn query_analysis_key(&self, query_fp: &Fingerprint) -> MemoKey {
MemoKey {
op_type: super::types::MemoOpType::QueryAnalysis,
input_fp: *query_fp,
model_id: self.model_id.clone(),
version: self.version,
context_fp: Fingerprint::zero(),
}
}
pub fn extraction_key(&self, content_fp: &Fingerprint) -> MemoKey {
MemoKey {
op_type: super::types::MemoOpType::Extraction,
input_fp: *content_fp,
model_id: self.model_id.clone(),
version: self.version,
context_fp: Fingerprint::zero(),
}
}
}
impl Default for MemoKeyBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Summary key over a fixed piece of content, shared by most tests.
    fn make_test_key() -> MemoKey {
        let content_fp = Fingerprint::from_str("test content");
        MemoKey::summary(&content_fp)
    }

    /// A value can be stored, detected and read back.
    #[test]
    fn test_memo_store_basic() {
        let store = MemoStore::new();
        let key = make_test_key();
        assert!(!store.contains(&key));

        store.put(key.clone(), MemoValue::Summary("Test summary".to_string()));
        assert!(store.contains(&key));

        let fetched = store.get(&key);
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().as_summary(), Some("Test summary"));
    }

    /// Inserting more entries than the capacity keeps the store at capacity.
    #[test]
    fn test_memo_store_lru_eviction() {
        let store = MemoStore::with_capacity(3);
        for index in 0..5 {
            let content_fp = Fingerprint::from_str(&format!("content {}", index));
            store.put(
                MemoKey::summary(&content_fp),
                MemoValue::Summary(format!("Summary {}", index)),
            );
        }
        assert_eq!(store.len(), 3);
    }

    /// The compute closure runs once; the second call is served from cache.
    #[tokio::test]
    async fn test_memo_store_get_or_compute() {
        let store = MemoStore::new();
        let key = make_test_key();
        let invocations = Arc::new(std::sync::atomic::AtomicU64::new(0));

        let counter = Arc::clone(&invocations);
        let first = store
            .get_or_compute(key.clone(), || async move {
                counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                Ok((MemoValue::Summary("Computed".to_string()), 100))
            })
            .await
            .unwrap();
        assert_eq!(first.as_summary(), Some("Computed"));
        assert_eq!(invocations.load(std::sync::atomic::Ordering::SeqCst), 1);

        let counter = Arc::clone(&invocations);
        let second = store
            .get_or_compute(key.clone(), || async move {
                counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                Ok((MemoValue::Summary("Should not be called".to_string()), 100))
            })
            .await
            .unwrap();
        assert_eq!(second.as_summary(), Some("Computed"));
        assert_eq!(invocations.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    /// An entry saved by one store is visible after loading into another.
    #[tokio::test]
    async fn test_memo_store_persistence() {
        let scratch = TempDir::new().unwrap();
        let file = scratch.path().join("memo.json");
        let key = make_test_key();

        let writer = MemoStore::new();
        writer.put_with_tokens(
            key.clone(),
            MemoValue::Summary("Test summary".to_string()),
            100,
        );
        writer.save(&file).await.unwrap();
        assert!(file.exists());

        let reader = MemoStore::new();
        reader.load(&file).await.unwrap();
        assert!(reader.contains(&key));
        let fetched = reader.get(&key);
        assert_eq!(fetched.unwrap().as_summary(), Some("Test summary"));
    }

    /// One miss followed by one hit is reflected in the statistics.
    #[tokio::test]
    async fn test_memo_store_stats() {
        let store = MemoStore::new();
        let key = make_test_key();

        store
            .get_or_compute(key.clone(), || async {
                Ok((MemoValue::Summary("Test".to_string()), 100))
            })
            .await
            .unwrap();
        store
            .get_or_compute(key.clone(), || async {
                Ok((MemoValue::Summary("Should not be called".to_string()), 0))
            })
            .await
            .unwrap();

        let snapshot = store.stats().await;
        assert_eq!(snapshot.misses, 1);
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.tokens_saved, 100);
    }

    /// Keys produced by the builder carry its model and version.
    #[test]
    fn test_memo_key_builder() {
        let builder = MemoKeyBuilder::new().with_model("gpt-4").with_version(2);
        let content_fp = Fingerprint::from_str("content");
        let key = builder.summary_key(&content_fp);
        assert_eq!(key.model_id, Some("gpt-4".to_string()));
        assert_eq!(key.version, 2);
    }
}