use chrono::{Duration, Utc};
use rusqlite::{params, Connection};
use sha2::{Digest, Sha256};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::types::{MemoryError, MemoryResult};
/// Tenant and source-binding scope attached to a cached response.
///
/// The fields are hashed into scoped cache keys and written next to each
/// scoped row, which is what the invalidation queries filter on.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ResponseCacheScope {
    /// Organization the entry belongs to.
    pub tenant_org_id: String,
    /// Workspace the entry belongs to.
    pub tenant_workspace_id: String,
    /// Optional deployment id; `None` is fingerprinted and matched as an empty string.
    pub tenant_deployment_id: Option<String>,
    /// Source-binding ids; `with_source_bindings` stores them normalized,
    /// sorted and deduplicated.
    pub source_binding_ids: Vec<String>,
}
impl ResponseCacheScope {
    /// Builds a scope for the given tenant with no source bindings attached.
    pub fn tenant(
        tenant_org_id: impl Into<String>,
        tenant_workspace_id: impl Into<String>,
        tenant_deployment_id: Option<String>,
    ) -> Self {
        let tenant_org_id = tenant_org_id.into();
        let tenant_workspace_id = tenant_workspace_id.into();
        Self {
            tenant_org_id,
            tenant_workspace_id,
            tenant_deployment_id,
            source_binding_ids: Vec::new(),
        }
    }

    /// Returns the scope with its source bindings replaced by the normalized
    /// form of `source_binding_ids`.
    pub fn with_source_bindings(self, source_binding_ids: Vec<String>) -> Self {
        Self {
            source_binding_ids: normalized_source_binding_ids(source_binding_ids),
            ..self
        }
    }

    /// Pipe-delimited key for this scope's source bindings (empty when there are none).
    fn source_binding_key(&self) -> String {
        source_binding_key(&self.source_binding_ids)
    }

    /// Flattens the whole scope into the string that scoped cache keys hash.
    ///
    /// A missing deployment id is rendered as an empty string.
    fn fingerprint(&self) -> String {
        let deployment = self.tenant_deployment_id.as_deref().unwrap_or_default();
        let bindings = self.source_binding_key();
        format!(
            "org={}|workspace={}|deployment={}|source_bindings={}",
            self.tenant_org_id, self.tenant_workspace_id, deployment, bindings
        )
    }
}
/// SQLite-backed cache of model responses with TTL expiry and LRU eviction.
pub struct ResponseCache {
    /// Single connection shared behind an async mutex; every operation locks it.
    conn: Arc<Mutex<Connection>>,
    /// Location of the database file. Not read by the methods visible here,
    /// hence the lint allowance.
    #[allow(dead_code)]
    db_path: PathBuf,
    /// Entry lifetime in minutes, measured from `created_at`.
    ttl_minutes: i64,
    /// Maximum number of rows kept after eviction.
    max_entries: usize,
}
impl ResponseCache {
/// Opens (creating if necessary) the cache database under `db_dir`.
///
/// Creates `db_dir`, opens `response_cache.db` inside it, applies the WAL,
/// synchronous and temp-store pragmas, and ensures the table, the scope
/// columns and all indexes exist.
///
/// # Errors
/// Returns `MemoryError::Io` if the directory cannot be created, or the
/// converted SQLite error if opening or initializing the database fails.
pub async fn new(db_dir: &Path, ttl_minutes: u32, max_entries: usize) -> MemoryResult<Self> {
    tokio::fs::create_dir_all(db_dir)
        .await
        .map_err(MemoryError::Io)?;
    let db_path = db_dir.join("response_cache.db");
    let conn = Connection::open(&db_path)?;
    conn.execute_batch(
        "PRAGMA journal_mode = WAL;
         PRAGMA synchronous = NORMAL;
         PRAGMA temp_store = MEMORY;",
    )?;
    // Step 1: the table plus the indexes on columns that every schema
    // version has.
    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS response_cache (
             prompt_hash TEXT PRIMARY KEY,
             model TEXT NOT NULL,
             response TEXT NOT NULL,
             token_count INTEGER NOT NULL DEFAULT 0,
             created_at TEXT NOT NULL,
             accessed_at TEXT NOT NULL,
             hit_count INTEGER NOT NULL DEFAULT 0,
             tenant_org_id TEXT,
             tenant_workspace_id TEXT,
             tenant_deployment_id TEXT,
             source_binding_key TEXT NOT NULL DEFAULT ''
         );
         CREATE INDEX IF NOT EXISTS idx_rc_accessed ON response_cache(accessed_at);
         CREATE INDEX IF NOT EXISTS idx_rc_created ON response_cache(created_at);",
    )?;
    // Step 2: add the scope columns to tables created before they existed.
    // This must run before the scope indexes are created: `CREATE INDEX`
    // on a missing column fails, which would abort initialization on a
    // legacy database before the migration ever ran.
    migrate_response_cache_scope_columns(&conn)?;
    // Step 3: indexes over the scope columns, which are now guaranteed to exist.
    conn.execute_batch(
        "CREATE INDEX IF NOT EXISTS idx_rc_tenant_scope
             ON response_cache(tenant_org_id, tenant_workspace_id, tenant_deployment_id);
         CREATE INDEX IF NOT EXISTS idx_rc_source_binding
             ON response_cache(source_binding_key);",
    )?;
    Ok(Self {
        conn: Arc::new(Mutex::new(conn)),
        db_path,
        ttl_minutes: i64::from(ttl_minutes),
        max_entries,
    })
}
/// Derives the unscoped cache key: the lowercase-hex SHA-256 of
/// `model | system_prompt | user_prompt`.
///
/// A missing system prompt contributes no bytes, so `None` and `Some("")`
/// produce the same key.
///
/// NOTE(review): the segments are joined with a bare `|` and are neither
/// escaped nor length-prefixed, so inputs that themselves contain `|` can
/// yield the same byte stream as a different split.
pub fn cache_key(model: &str, system_prompt: Option<&str>, user_prompt: &str) -> String {
    let segments: [&[u8]; 5] = [
        model.as_bytes(),
        b"|",
        system_prompt.unwrap_or_default().as_bytes(),
        b"|",
        user_prompt.as_bytes(),
    ];
    let mut hasher = Sha256::new();
    for segment in segments {
        hasher.update(segment);
    }
    format!("{:064x}", hasher.finalize())
}
/// Derives a scope-aware cache key: the lowercase-hex SHA-256 of
/// `model | system_prompt | user_prompt | scope fingerprint`.
///
/// A missing system prompt contributes no bytes, so `None` and `Some("")`
/// produce the same key.
///
/// NOTE(review): the segments are joined with a bare `|` and are neither
/// escaped nor length-prefixed, so inputs that themselves contain `|` can
/// yield the same byte stream as a different split.
pub fn cache_key_scoped(
    model: &str,
    system_prompt: Option<&str>,
    user_prompt: &str,
    scope: &ResponseCacheScope,
) -> String {
    let fingerprint = scope.fingerprint();
    let segments: [&[u8]; 7] = [
        model.as_bytes(),
        b"|",
        system_prompt.unwrap_or_default().as_bytes(),
        b"|",
        user_prompt.as_bytes(),
        b"|",
        fingerprint.as_bytes(),
    ];
    let mut hasher = Sha256::new();
    for segment in segments {
        hasher.update(segment);
    }
    format!("{:064x}", hasher.finalize())
}
pub async fn get(&self, key: &str) -> MemoryResult<Option<String>> {
let conn = self.conn.lock().await;
let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
let result: Option<String> = conn
.query_row(
"SELECT response FROM response_cache
WHERE prompt_hash = ?1 AND created_at > ?2",
params![key, cutoff],
|row| row.get(0),
)
.ok();
if result.is_some() {
let now = Utc::now().to_rfc3339();
conn.execute(
"UPDATE response_cache
SET accessed_at = ?1, hit_count = hit_count + 1
WHERE prompt_hash = ?2",
params![now, key],
)?;
}
Ok(result)
}
/// Stores (or replaces) an unscoped response under `key`, then evicts.
///
/// The row is written with `INSERT OR REPLACE`, so an existing entry with
/// the same key is replaced: its `created_at`/`accessed_at` become "now" and
/// its `hit_count` restarts at 0. The scope columns are not listed in the
/// insert and therefore take their column defaults.
///
/// # Errors
/// Propagates any SQLite failure from the insert or from eviction.
pub async fn put(
    &self,
    key: &str,
    model: &str,
    response: &str,
    token_count: u32,
) -> MemoryResult<()> {
    let conn = self.conn.lock().await;
    let now = Utc::now().to_rfc3339();
    conn.execute(
        "INSERT OR REPLACE INTO response_cache
         (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count)
         VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0)",
        params![key, model, response, token_count, now, now],
    )?;
    // Share the TTL + LRU eviction with `put_scoped` instead of keeping a
    // second copy of the same SQL here.
    self.evict_locked(&conn)
}
/// Stores (or replaces) a response under `key` together with its tenant
/// scope and source-binding key, then runs TTL and LRU eviction.
///
/// # Errors
/// Propagates any SQLite failure from the insert or from eviction.
pub async fn put_scoped(
    &self,
    key: &str,
    model: &str,
    response: &str,
    token_count: u32,
    scope: &ResponseCacheScope,
) -> MemoryResult<()> {
    let binding_key = scope.source_binding_key();
    let conn = self.conn.lock().await;
    // The same timestamp is used for both created_at and accessed_at.
    let stamp = Utc::now().to_rfc3339();
    conn.execute(
        "INSERT OR REPLACE INTO response_cache
(prompt_hash, model, response, token_count, created_at, accessed_at, hit_count,
tenant_org_id, tenant_workspace_id, tenant_deployment_id, source_binding_key)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0, ?7, ?8, ?9, ?10)",
        params![
            key,
            model,
            response,
            token_count,
            stamp,
            stamp,
            scope.tenant_org_id,
            scope.tenant_workspace_id,
            scope.tenant_deployment_id,
            binding_key
        ],
    )?;
    self.evict_locked(&conn)
}
/// Deletes every entry of the given tenant whose source-binding key
/// contains `source_binding_id`, returning the number of rows removed.
///
/// The id is normalized the same way as when entries are stored, then
/// matched as a whole `|id|` segment of the stored key. `%`, `_` and `\` in
/// the id are escaped so they match literally rather than acting as LIKE
/// wildcards, which previously could delete entries for unrelated bindings.
/// A `None` deployment matches rows whose deployment is NULL or empty.
///
/// NOTE: SQLite's `LIKE` is case-insensitive for ASCII by default, so ids
/// differing only in ASCII case are matched as well.
///
/// # Errors
/// Propagates any SQLite failure from the delete.
pub async fn invalidate_source_binding(
    &self,
    tenant_org_id: &str,
    tenant_workspace_id: &str,
    tenant_deployment_id: Option<&str>,
    source_binding_id: &str,
) -> MemoryResult<usize> {
    let binding_id = normalize_source_binding_id(source_binding_id);
    // Build `%|<escaped id>|%`; the backslash is declared as the escape
    // character in the query below.
    let mut pattern = String::with_capacity(binding_id.len() + 4);
    pattern.push_str("%|");
    for ch in binding_id.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            pattern.push('\\');
        }
        pattern.push(ch);
    }
    pattern.push_str("|%");

    let conn = self.conn.lock().await;
    let affected = conn.execute(
        "DELETE FROM response_cache
         WHERE tenant_org_id = ?1
           AND tenant_workspace_id = ?2
           AND IFNULL(tenant_deployment_id, '') = IFNULL(?3, '')
           AND source_binding_key LIKE ?4 ESCAPE '\\'",
        params![
            tenant_org_id,
            tenant_workspace_id,
            tenant_deployment_id,
            pattern
        ],
    )?;
    Ok(affected)
}
pub async fn invalidate_tenant(
&self,
tenant_org_id: &str,
tenant_workspace_id: &str,
tenant_deployment_id: Option<&str>,
) -> MemoryResult<usize> {
let conn = self.conn.lock().await;
let affected = conn.execute(
"DELETE FROM response_cache
WHERE tenant_org_id = ?1
AND tenant_workspace_id = ?2
AND IFNULL(tenant_deployment_id, '') = IFNULL(?3, '')",
params![tenant_org_id, tenant_workspace_id, tenant_deployment_id],
)?;
Ok(affected)
}
/// Reports `(entry count, total hits, tokens saved)` for the whole table.
///
/// "Tokens saved" is the sum of `token_count * hit_count` over all rows.
///
/// # Errors
/// Propagates any SQLite failure from the three aggregate queries.
pub async fn stats(&self) -> MemoryResult<(usize, u64, u64)> {
    let conn = self.conn.lock().await;
    // Each statistic is a single-row, single-column integer query.
    let scalar = |sql: &str| conn.query_row(sql, [], |row| row.get::<_, i64>(0));
    let entries = scalar("SELECT COUNT(*) FROM response_cache")?;
    let total_hits = scalar("SELECT COALESCE(SUM(hit_count), 0) FROM response_cache")?;
    let tokens_saved =
        scalar("SELECT COALESCE(SUM(token_count * hit_count), 0) FROM response_cache")?;
    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
    let totals = (entries as usize, total_hits as u64, tokens_saved as u64);
    Ok(totals)
}
pub async fn clear(&self) -> MemoryResult<usize> {
let conn = self.conn.lock().await;
let affected = conn.execute("DELETE FROM response_cache", [])?;
Ok(affected)
}
/// Evicts entries on an already-locked connection.
///
/// First drops rows whose `created_at` is at or before the TTL cutoff, then
/// deletes the least recently accessed rows until at most `max_entries`
/// remain.
///
/// # Errors
/// Propagates any SQLite failure from either delete.
fn evict_locked(&self, conn: &Connection) -> MemoryResult<()> {
    let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
    conn.execute(
        "DELETE FROM response_cache WHERE created_at <= ?1",
        params![cutoff],
    )?;
    // Saturate rather than cast: a `max_entries` above `i64::MAX` (for
    // example `usize::MAX` used as "unbounded") would wrap to a negative
    // number with `as`, making the LIMIT below cover every row and wiping
    // the cache on each insert.
    let max: i64 = std::convert::TryFrom::try_from(self.max_entries).unwrap_or(i64::MAX);
    conn.execute(
        "DELETE FROM response_cache WHERE prompt_hash IN (
             SELECT prompt_hash FROM response_cache
             ORDER BY accessed_at ASC
             LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
         )",
        params![max],
    )?;
    Ok(())
}
}
/// Adds any tenant-scope columns that the existing `response_cache` table
/// is missing.
///
/// Reads the current column names from `PRAGMA table_info` and issues one
/// `ALTER TABLE ... ADD COLUMN` per absent column.
///
/// # Errors
/// Propagates any SQLite failure from the pragma or the alterations.
fn migrate_response_cache_scope_columns(conn: &Connection) -> MemoryResult<()> {
    // (column name, statement that adds it)
    const SCOPE_COLUMNS: [(&str, &str); 4] = [
        (
            "tenant_org_id",
            "ALTER TABLE response_cache ADD COLUMN tenant_org_id TEXT",
        ),
        (
            "tenant_workspace_id",
            "ALTER TABLE response_cache ADD COLUMN tenant_workspace_id TEXT",
        ),
        (
            "tenant_deployment_id",
            "ALTER TABLE response_cache ADD COLUMN tenant_deployment_id TEXT",
        ),
        (
            "source_binding_key",
            "ALTER TABLE response_cache ADD COLUMN source_binding_key TEXT NOT NULL DEFAULT ''",
        ),
    ];

    // The statement is scoped so it is finalized before any ALTER runs.
    let existing = {
        let mut table_info = conn.prepare("PRAGMA table_info(response_cache)")?;
        // Column 1 of `table_info` holds the column name.
        let names = table_info
            .query_map([], |row| row.get::<_, String>(1))?
            .collect::<Result<std::collections::HashSet<_>, _>>()?;
        names
    };

    let missing = SCOPE_COLUMNS
        .iter()
        .copied()
        .filter(|(name, _)| !existing.contains(*name));
    for (_, ddl) in missing {
        conn.execute(ddl, [])?;
    }
    Ok(())
}
/// Normalizes each id, drops those that end up empty, and returns the
/// remainder sorted and without duplicates.
fn normalized_source_binding_ids(source_binding_ids: Vec<String>) -> Vec<String> {
    // An ordered set gives both the sort and the de-duplication.
    let unique: std::collections::BTreeSet<String> = source_binding_ids
        .into_iter()
        .filter_map(|raw| {
            let id = normalize_source_binding_id(&raw);
            if id.is_empty() {
                None
            } else {
                Some(id)
            }
        })
        .collect();
    unique.into_iter().collect()
}
/// Trims surrounding whitespace from an id, then removes every `|` so the id
/// cannot contain the delimiter used in source-binding keys.
///
/// NOTE: trimming happens before the pipes are removed, so whitespace that
/// was shielded by a leading or trailing `|` survives (`"| a |"` -> `" a "`).
fn normalize_source_binding_id(source_binding_id: &str) -> String {
    source_binding_id
        .trim()
        .chars()
        .filter(|&ch| ch != '|')
        .collect()
}
/// Joins ids into a key of the form `|a|b|`, with a pipe on both ends and
/// between ids. An empty slice yields an empty string.
fn source_binding_key(source_binding_ids: &[String]) -> String {
    match source_binding_ids {
        [] => String::new(),
        ids => {
            let mut key = String::from("|");
            for id in ids {
                key.push_str(id);
                key.push('|');
            }
            key
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a cache with room for 1000 entries in a fresh temporary directory.
    ///
    /// The `TempDir` is handed back so the caller keeps the directory alive
    /// while the cache is in use.
    async fn temp_cache(ttl_minutes: u32) -> (TempDir, ResponseCache) {
        let dir = TempDir::new().unwrap();
        let cache = ResponseCache::new(dir.path(), ttl_minutes, 1000)
            .await
            .unwrap();
        (dir, cache)
    }

    #[tokio::test]
    async fn cache_key_is_deterministic() {
        let first = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        let second = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        assert_eq!(first, second);
        // The key is rendered as 64 hex characters.
        assert_eq!(first.len(), 64);
    }

    #[tokio::test]
    async fn cache_key_varies_by_model() {
        let gpt_key = ResponseCache::cache_key("gpt-4", None, "hello");
        let claude_key = ResponseCache::cache_key("claude-3", None, "hello");
        assert_ne!(gpt_key, claude_key);
    }

    #[tokio::test]
    async fn scoped_cache_key_varies_by_tenant_and_source_binding() {
        // Same tenant, different source binding.
        let finance = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["finance-drive".to_string()]);
        let hr = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["hr-drive".to_string()]);
        let finance_key = ResponseCache::cache_key_scoped("gpt-4", Some("sys"), "hello", &finance);
        let hr_key = ResponseCache::cache_key_scoped("gpt-4", Some("sys"), "hello", &hr);
        assert_ne!(finance_key, hr_key);
    }

    #[tokio::test]
    async fn put_and_get_roundtrip() {
        let (_dir, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "What is Rust?");
        cache
            .put(&key, "gpt-4", "Rust is a systems programming language.", 25)
            .await
            .unwrap();
        let cached = cache.get(&key).await.unwrap();
        assert_eq!(
            cached.as_deref(),
            Some("Rust is a systems programming language.")
        );
    }

    #[tokio::test]
    async fn miss_returns_none() {
        let (_dir, cache) = temp_cache(60).await;
        let cached = cache.get("nonexistent").await.unwrap();
        assert!(cached.is_none());
    }

    #[tokio::test]
    async fn expired_entry_returns_none() {
        // A TTL of zero minutes makes every entry expired as soon as it is written.
        let (_dir, cache) = temp_cache(0).await;
        let key = ResponseCache::cache_key("gpt-4", None, "test");
        cache.put(&key, "gpt-4", "response", 10).await.unwrap();
        let cached = cache.get(&key).await.unwrap();
        assert!(cached.is_none());
    }

    #[tokio::test]
    async fn stats_tracks_hits_and_tokens() {
        let (_dir, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "explain rust");
        cache.put(&key, "gpt-4", "Rust is...", 100).await.unwrap();
        // Five hits on a 100-token entry.
        for _ in 0..5 {
            let _ = cache.get(&key).await.unwrap();
        }
        let (_, hits, tokens) = cache.stats().await.unwrap();
        assert_eq!(hits, 5);
        assert_eq!(tokens, 500);
    }

    #[tokio::test]
    async fn lru_eviction_respects_max_entries() {
        let dir = TempDir::new().unwrap();
        let cache = ResponseCache::new(dir.path(), 60, 3).await.unwrap();
        // Insert more entries than the cache is allowed to hold.
        for i in 0..5 {
            let prompt = format!("prompt {i}");
            let response = format!("response {i}");
            let key = ResponseCache::cache_key("gpt-4", None, &prompt);
            cache.put(&key, "gpt-4", &response, 10).await.unwrap();
        }
        let (count, _, _) = cache.stats().await.unwrap();
        assert!(count <= 3, "cache must not exceed max_entries");
    }

    #[tokio::test]
    async fn invalidate_source_binding_removes_only_matching_tenant_entries() {
        let (_dir, cache) = temp_cache(60).await;
        let finance_a = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["finance-drive".to_string()]);
        let hr_a = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["hr-drive".to_string()]);
        let finance_b = ResponseCacheScope::tenant("org-b", "workspace-b", None)
            .with_source_bindings(vec!["finance-drive".to_string()]);

        let scopes = [&finance_a, &hr_a, &finance_b];
        for (idx, scope) in scopes.into_iter().enumerate() {
            let prompt = format!("prompt {idx}");
            let response = format!("response {idx}");
            let key = ResponseCache::cache_key_scoped("gpt-4", None, &prompt, scope);
            cache
                .put_scoped(&key, "gpt-4", &response, 10, scope)
                .await
                .unwrap();
        }

        // Only org-a's finance-drive entry should go; org-b's entry with the
        // same binding and org-a's hr-drive entry must survive.
        let removed = cache
            .invalidate_source_binding("org-a", "workspace-a", None, "finance-drive")
            .await
            .unwrap();
        assert_eq!(removed, 1);
        let (count, _, _) = cache.stats().await.unwrap();
        assert_eq!(count, 2);
    }
}