use anyhow::{Context, Result};
use rayon::prelude::*;
use rusqlite::{params, Connection};
use std::sync::{Arc, Mutex};
use crate::similarity::cosine;
/// A cached text chunk together with its embedding vector.
#[derive(Debug, Clone, PartialEq)]
pub struct Chunk {
    /// The chunk's text content.
    pub text: String,
    /// The chunk's embedding; when loaded through `Cache` it is decoded from a
    /// little-endian `f32` blob.
    pub embedding: Vec<f32>,
}
/// Aggregate statistics about the cache, as produced by `Cache::stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize)]
pub struct CacheStats {
    /// Number of rows in the `queries` table (stored queries).
    pub entries: i64,
    /// Number of rows in the `chunks` table.
    pub total_chunks: i64,
    /// Sum of `hit_count` over all stored queries.
    pub cache_hits: i64,
    /// Sum of `original_tokens` over all stored queries.
    pub tokens_original: i64,
    /// Sum of `returned_tokens` over all stored queries.
    pub tokens_returned: i64,
    /// Size in bytes of the main database file at the path passed to `Cache::stats`
    /// (0 if it could not be read).
    pub db_bytes: u64,
}
impl CacheStats {
    /// Estimated number of tokens saved by the cache.
    ///
    /// The estimate is the sum of two parts:
    /// * compression savings: `tokens_original - tokens_returned`;
    /// * hit savings: each hit is credited with the average `tokens_original` per stored
    ///   entry, computed as `cache_hits * tokens_original / entries`.
    ///
    /// With no entries there is no average, so hit savings are zero. The product is formed
    /// in `i128` before dividing (so the average is not truncated first), and the result
    /// saturates at the `i64` bounds instead of overflowing.
    pub fn tokens_saved(&self) -> i64 {
        let compression = self.tokens_original.saturating_sub(self.tokens_returned);
        if self.entries <= 0 {
            return compression;
        }
        let from_hits = i128::from(self.cache_hits) * i128::from(self.tokens_original)
            / i128::from(self.entries);
        let from_hits = from_hits.clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64;
        compression.saturating_add(from_hits)
    }

    /// Hit rate as a percentage, computed as `cache_hits / (entries + cache_hits) * 100`.
    ///
    /// Returns `0.0` when both counts are zero.
    pub fn hit_rate(&self) -> f64 {
        let total = self.entries + self.cache_hits;
        if total == 0 {
            return 0.0;
        }
        self.cache_hits as f64 / total as f64 * 100.0
    }
}
/// SQLite-backed cache of query embeddings and the chunks returned for them.
///
/// The connection lives behind `Arc<Mutex<..>>`: cloning a `Cache` is cheap, and all clones
/// share the same connection, with access serialised by the mutex.
#[derive(Clone, Debug)]
pub struct Cache {
    conn: Arc<Mutex<Connection>>,
}
/// Schema version that `migrate` brings the database to, tracked in SQLite's
/// `PRAGMA user_version`. It must equal the highest version in `migrate`'s migration list;
/// a debug assertion in `migrate` checks this after migrating.
const SCHEMA_VERSION: u32 = 1;
impl Cache {
    /// Opens (or creates) the cache database at `path` and migrates it to the current schema.
    ///
    /// The connection uses WAL journaling with `synchronous = NORMAL`. Foreign keys are
    /// enforced so that deleting a row from `queries` cascades to its `chunks`.
    ///
    /// # Errors
    /// Fails if the database cannot be opened, a pragma fails, or migration fails.
    pub fn new(path: &str) -> Result<Self> {
        let conn = Connection::open(path)
            .with_context(|| format!("Cannot open cache database at {path}"))?;
        conn.execute_batch(
            "PRAGMA journal_mode = WAL;
             PRAGMA synchronous = NORMAL;
             PRAGMA foreign_keys = ON;",
        )?;
        migrate(&conn)?;
        Ok(Self {
            conn: Arc::new(Mutex::new(conn)),
        })
    }

    /// Locks the shared connection, recovering it if the mutex was poisoned.
    ///
    /// Recovery is sound here because every write in this impl is either a single
    /// autocommitted statement or runs inside a rusqlite `Transaction`, which rolls back
    /// when dropped during unwinding. A panic while the lock was held therefore cannot leave
    /// half-applied changes. Propagating the poison would make every later call panic.
    fn lock_conn(&self) -> std::sync::MutexGuard<'_, Connection> {
        self.conn
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Finds the stored query whose embedding is most similar to `query_emb`.
    ///
    /// Every stored query embedding is scored with `cosine`, in parallel via rayon. The best
    /// score must be `>= threshold` to count as a hit. NaN scores are discarded, so a NaN
    /// similarity or a NaN `threshold` never produces a hit. Equal scores are resolved in
    /// favour of the larger id, i.e. the most recently inserted query (ids are
    /// `AUTOINCREMENT`).
    ///
    /// Returns the matching query id and its chunks in stored order, or `None` on a miss.
    /// The connection lock is held for the whole lookup.
    ///
    /// # Errors
    /// Propagates SQLite errors, including failures reading individual rows.
    pub fn search(&self, query_emb: &[f32], threshold: f32) -> Result<Option<(i64, Vec<Chunk>)>> {
        let conn = self.lock_conn();
        let candidates: Vec<(i64, Vec<f32>)> = {
            let mut stmt = conn.prepare("SELECT id, query_embedding FROM queries")?;
            let rows = stmt.query_map([], |row| {
                Ok((row.get::<_, i64>(0)?, row.get::<_, Vec<u8>>(1)?))
            })?;
            rows.map(|row| row.map(|(id, blob)| (id, bytes_to_f32(&blob))))
                .collect::<rusqlite::Result<_>>()?
        };
        let best: Option<(i64, f32)> = candidates
            .par_iter()
            .map(|(id, emb)| (*id, cosine(query_emb, emb)))
            .filter(|(_, score)| !score.is_nan())
            .max_by(|a, b| a.1.total_cmp(&b.1).then(a.0.cmp(&b.0)));
        match best {
            Some((query_id, score)) if score >= threshold => {
                let chunks = self.load_chunks(&conn, query_id)?;
                Ok(Some((query_id, chunks)))
            }
            _ => Ok(None),
        }
    }

    /// Stores a query, its embedding, token counts and result chunks.
    ///
    /// `created_at` and `last_used_at` are both set to the current time. Chunks keep their
    /// slice order through the `position` column. The query row and all chunk rows are
    /// written in one transaction, so a failure part-way through stores nothing.
    ///
    /// # Errors
    /// Propagates SQLite errors; the transaction is rolled back on failure.
    pub fn store(
        &self,
        query_text: &str,
        query_emb: &[f32],
        chunks: &[(String, Vec<f32>)],
        original_tokens: usize,
        returned_tokens: usize,
    ) -> Result<()> {
        let mut conn = self.lock_conn();
        let tx = conn.transaction()?;
        let now = unix_now();
        tx.execute(
            "INSERT INTO queries
                 (query_text, query_embedding, original_tokens, returned_tokens, created_at, last_used_at)
             VALUES (?1, ?2, ?3, ?4, ?5, ?5)",
            params![
                query_text,
                f32_to_bytes(query_emb),
                original_tokens as i64,
                returned_tokens as i64,
                now
            ],
        )?;
        let query_id = tx.last_insert_rowid();
        {
            // Scoped so the statement's borrow of `tx` ends before `commit` consumes it.
            let mut insert_chunk = tx.prepare(
                "INSERT INTO chunks (query_id, text, embedding, position) VALUES (?1, ?2, ?3, ?4)",
            )?;
            for (pos, (text, emb)) in chunks.iter().enumerate() {
                insert_chunk.execute(params![query_id, text, f32_to_bytes(emb), pos as i64])?;
            }
        }
        tx.commit()?;
        Ok(())
    }

    /// Increments `hit_count` and refreshes `last_used_at` for `query_id`.
    ///
    /// An unknown id updates no rows and is not an error.
    pub fn record_hit(&self, query_id: i64) -> Result<()> {
        let conn = self.lock_conn();
        conn.execute(
            "UPDATE queries SET hit_count = hit_count + 1, last_used_at = ?1 WHERE id = ?2",
            params![unix_now(), query_id],
        )?;
        Ok(())
    }

    /// Deletes queries not used within the last `days` days and returns how many were removed.
    ///
    /// Their chunks are removed by the `ON DELETE CASCADE` foreign key.
    pub fn evict_expired(&self, days: u64) -> Result<usize> {
        // Saturate instead of overflowing. `days * 86_400` can exceed u64, and a value above
        // i64::MAX would wrap negative under `as i64`. That would push the cutoff into the
        // future and delete every entry.
        let max_age_secs = days.saturating_mul(86_400).min(i64::MAX as u64) as i64;
        let cutoff = unix_now().saturating_sub(max_age_secs);
        let conn = self.lock_conn();
        let n = conn.execute(
            "DELETE FROM queries WHERE last_used_at < ?1",
            params![cutoff],
        )?;
        Ok(n)
    }

    /// Alias for [`Cache::evict_expired`].
    pub fn clear_older_than(&self, days: u64) -> Result<usize> {
        self.evict_expired(days)
    }

    /// Deletes every stored query (chunks cascade) and returns the number of queries removed.
    pub fn clear_all(&self) -> Result<usize> {
        let conn = self.lock_conn();
        let n = conn.execute("DELETE FROM queries", [])?;
        Ok(n)
    }

    /// Evicts the least-recently-used ~20% of queries when the database exceeds `max_bytes`.
    ///
    /// Both checks must exceed `max_bytes` before anything is deleted:
    /// 1. the size of the main database file at `path`. This check is cheap and does not
    ///    take the lock; a missing or unreadable file counts as 0 bytes.
    /// 2. the bytes in pages SQLite actually uses: `(page_count - freelist_count) * page_size`.
    ///
    /// The second check is needed because `DELETE` does not shrink the file; freed pages go
    /// to the freelist for reuse. With only the file-size check, the database would still
    /// look too big after an eviction, and every later call would delete another 20% until
    /// the cache was empty.
    pub fn evict_lru_if_needed(&self, path: &str, max_bytes: u64) -> Result<()> {
        let file_bytes = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
        if file_bytes <= max_bytes {
            return Ok(());
        }
        let conn = self.lock_conn();
        if Self::used_bytes(&conn)? <= max_bytes {
            return Ok(());
        }
        let total: i64 = conn.query_row("SELECT COUNT(*) FROM queries", [], |r| r.get(0))?;
        // ceil(20% of total) in integer arithmetic.
        let to_delete = (total + 4) / 5;
        conn.execute(
            "DELETE FROM queries WHERE id IN (
                 SELECT id FROM queries ORDER BY last_used_at ASC LIMIT ?1
             )",
            params![to_delete],
        )?;
        Ok(())
    }

    /// Bytes held by pages SQLite is actually using (freelist pages excluded).
    fn used_bytes(conn: &Connection) -> Result<u64> {
        let pragma_i64 = |name: &str| -> Result<i64> {
            let value: i64 = conn.query_row(&format!("PRAGMA {name}"), [], |r| r.get(0))?;
            Ok(value)
        };
        let used_pages = pragma_i64("page_count")?.saturating_sub(pragma_i64("freelist_count")?);
        let bytes = used_pages
            .max(0)
            .saturating_mul(pragma_i64("page_size")?.max(0));
        Ok(bytes as u64)
    }

    /// Collects aggregate counts from the database plus the on-disk size of the file at `path`.
    ///
    /// Sums over an empty table are reported as 0. `db_bytes` is 0 if the file's metadata
    /// cannot be read.
    pub fn stats(&self, path: &str) -> Result<CacheStats> {
        let conn = self.lock_conn();
        let (entries, total_chunks, cache_hits, tokens_original, tokens_returned) = conn
            .query_row(
                "SELECT
                     (SELECT COUNT(*) FROM queries),
                     (SELECT COUNT(*) FROM chunks),
                     COALESCE((SELECT SUM(hit_count) FROM queries), 0),
                     COALESCE((SELECT SUM(original_tokens) FROM queries), 0),
                     COALESCE((SELECT SUM(returned_tokens) FROM queries), 0)",
                [],
                |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?, r.get(4)?)),
            )?;
        let db_bytes = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
        Ok(CacheStats {
            entries,
            total_chunks,
            cache_hits,
            tokens_original,
            tokens_returned,
            db_bytes,
        })
    }

    /// Loads the chunks stored for `query_id`, ordered by their `position`.
    ///
    /// # Errors
    /// Propagates SQLite errors, including a failure reading any single row. Such rows are
    /// not silently skipped.
    fn load_chunks(&self, conn: &Connection, query_id: i64) -> Result<Vec<Chunk>> {
        let mut stmt = conn
            .prepare("SELECT text, embedding FROM chunks WHERE query_id = ?1 ORDER BY position")?;
        let rows = stmt.query_map(params![query_id], |row| {
            Ok((row.get::<_, String>(0)?, row.get::<_, Vec<u8>>(1)?))
        })?;
        let chunks = rows
            .map(|row| {
                row.map(|(text, blob)| Chunk {
                    text,
                    embedding: bytes_to_f32(&blob),
                })
            })
            .collect::<rusqlite::Result<Vec<_>>>()?;
        Ok(chunks)
    }
}
fn migrate(conn: &Connection) -> Result<()> {
const MIGRATIONS: &[(u32, &str)] = &[(
1,
"CREATE TABLE IF NOT EXISTS queries (
id INTEGER PRIMARY KEY AUTOINCREMENT,
query_text TEXT NOT NULL,
query_embedding BLOB NOT NULL,
original_tokens INTEGER NOT NULL DEFAULT 0,
returned_tokens INTEGER NOT NULL DEFAULT 0,
hit_count INTEGER NOT NULL DEFAULT 0,
created_at INTEGER NOT NULL,
last_used_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS chunks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
query_id INTEGER NOT NULL REFERENCES queries(id) ON DELETE CASCADE,
text TEXT NOT NULL,
embedding BLOB NOT NULL,
position INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_chunks_query ON chunks(query_id);
CREATE INDEX IF NOT EXISTS idx_queries_last ON queries(last_used_at);
CREATE INDEX IF NOT EXISTS idx_queries_created ON queries(created_at);",
)];
let current: u32 = conn.query_row("PRAGMA user_version", [], |r| r.get(0))?;
for &(version, sql) in MIGRATIONS {
if current < version {
conn.execute_batch(sql)?;
conn.execute_batch(&format!("PRAGMA user_version = {version}"))?;
}
}
debug_assert_eq!(
{
let v: u32 = conn.query_row("PRAGMA user_version", [], |r| r.get(0))?;
v
},
SCHEMA_VERSION,
"schema version mismatch after migration"
);
Ok(())
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock is set before 1970, this returns a negative value instead of
/// panicking. Values beyond `i64` range saturate.
fn unix_now() -> i64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since) => since.as_secs().min(i64::MAX as u64) as i64,
        Err(before) => -(before.duration().as_secs().min(i64::MAX as u64) as i64),
    }
}
/// Serialises `v` as a contiguous little-endian byte buffer, 4 bytes per value.
fn f32_to_bytes(v: &[f32]) -> Vec<u8> {
    let mut out = Vec::with_capacity(v.len() * std::mem::size_of::<f32>());
    for value in v {
        out.extend_from_slice(&value.to_le_bytes());
    }
    out
}
/// Decodes a little-endian byte buffer into `f32` values, 4 bytes per value.
///
/// Trailing bytes that do not form a complete 4-byte group are ignored.
fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        let &[b0, b1, b2, b3] = word else {
            unreachable!("chunks_exact(4) only yields 4-byte slices")
        };
        floats.push(f32::from_le_bytes([b0, b1, b2, b3]));
    }
    floats
}