use anyhow::{Context, Result, anyhow};
use chrono::{DateTime, Utc};
use redb::{Database, ReadableTable};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use uuid::Uuid;
use crate::memory_core::store::kg_store::RECALL_LOG;
/// Words dropped by [`normalize_query`] after lowercasing and punctuation removal.
const STOP_WORDS: &[&str] = &[
    "the", "a", "an", "is", "are", "was", "were", "in", "on", "at", "to", "of", "for", "with",
    "by", "from", "and", "or", "but", "not", "it", "this", "that", "be", "as", "do", "did", "has",
    "have", "had",
];

/// Canonicalises free-text query input.
///
/// The text is lowercased, every non-alphanumeric character acts as a word
/// separator, words listed in `STOP_WORDS` are discarded, and the surviving
/// words are joined with single spaces.
pub fn normalize_query(text: &str) -> String {
    let lowered = text.to_lowercase();
    let mut normalized = String::with_capacity(lowered.len());
    // Splitting on "not alphanumeric" yields the same words as blanking
    // punctuation and then splitting on whitespace; empty pieces between
    // adjacent separators are skipped.
    for token in lowered.split(|ch: char| !ch.is_alphanumeric()) {
        if token.is_empty() || STOP_WORDS.contains(&token) {
            continue;
        }
        if !normalized.is_empty() {
            normalized.push(' ');
        }
        normalized.push_str(token);
    }
    normalized
}
/// Computes the 64-bit FNV-1a hash of the UTF-8 bytes of `text`.
///
/// Each byte is XOR-ed into the running state, which is then multiplied by
/// the FNV prime with wrapping arithmetic.
pub fn fnv1a_hash(text: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    text.as_bytes()
        .iter()
        .fold(FNV_OFFSET_BASIS, |state, &octet| {
            (state ^ u64::from(octet)).wrapping_mul(FNV_PRIME)
        })
}
/// Hashes a query after canonicalising it, so inputs that normalise to the
/// same string produce the same value.
pub fn query_hash(text: &str) -> u64 {
    let canonical = normalize_query(text);
    fnv1a_hash(canonical.as_str())
}
/// One recall attempt as stored in the recall log.
///
/// `RecallLog::record` persists each event postcard-encoded.
/// NOTE(review): postcard is presumably order-sensitive, so reordering or
/// retyping fields would likely break decoding of existing rows — confirm
/// before changing the layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecallEvent {
/// Palace the event belongs to; the statistics queries filter on this value.
pub palace_id: String,
/// Hash identifying the query; the tests in this file fill it via `query_hash`.
pub query_hash: u64,
/// Layer number attached to the event (its meaning is not evident from this file).
pub layer: u8,
/// Drawer returned for the query; `None` is counted as a miss by
/// `miss_rate` and `missed_queries`.
pub drawer_id: Option<Uuid>,
/// Score attached to the event — presumably a relevance score; TODO confirm range.
pub score: f32,
/// Event timestamp; `miss_rate` compares it against its time window.
pub occurred_at: DateTime<Utc>,
}
/// Handle to the redb-backed log of `RecallEvent`s.
pub struct RecallLog {
/// Shared database handle; cloned into blocking tasks for reads and writes.
db: Arc<Database>,
/// Resolved path of the redb file, kept for error messages.
path: PathBuf,
/// Most recently allocated row key; seeded in `open` from the largest key
/// already in the table (0 when the table is empty).
next_id: AtomicU64,
}
impl RecallLog {
pub fn open(path: &Path) -> Result<Self> {
let redb_path = resolve_redb_path(path);
if let Some(parent) = redb_path.parent() {
if !parent.as_os_str().is_empty() {
std::fs::create_dir_all(parent).with_context(|| {
format!(
"failed to create recall log parent dir {}",
parent.display()
)
})?;
}
}
#[cfg(feature = "sqlite-kg")]
migrate_from_sqlite_if_present(path, &redb_path)?;
let db = Database::create(&redb_path).with_context(|| {
format!("failed to open redb recall log at {}", redb_path.display())
})?;
let mut max_seen: u64 = 0;
{
let wtx = db
.begin_write()
.context("failed to begin write txn for recall log init")?;
{
let table = wtx
.open_table(RECALL_LOG)
.context("failed to open RECALL_LOG table")?;
if let Some(entry) = table
.last()
.context("failed to read last key from RECALL_LOG")?
{
max_seen = entry.0.value();
}
}
wtx.commit().context("failed to commit recall log init")?;
}
Ok(Self {
db: Arc::new(db),
path: redb_path,
next_id: AtomicU64::new(max_seen),
})
}
fn alloc_id(&self) -> u64 {
let now_ms = Utc::now().timestamp_millis().max(0) as u64;
loop {
let current = self.next_id.load(Ordering::Acquire);
let candidate = now_ms.max(current + 1);
if self
.next_id
.compare_exchange(current, candidate, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
return candidate;
}
}
}
pub async fn record(&self, event: RecallEvent) -> Result<()> {
let id = self.alloc_id();
let bytes =
postcard::to_allocvec(&event).context("failed to postcard-encode RecallEvent")?;
let db = self.db.clone();
let path = self.path.clone();
tokio::task::spawn_blocking(move || -> Result<()> {
let wtx = db
.begin_write()
.with_context(|| format!("begin_write recall log {}", path.display()))?;
{
let mut table = wtx
.open_table(RECALL_LOG)
.context("open RECALL_LOG table")?;
table
.insert(id, bytes.as_slice())
.context("insert RecallEvent row")?;
}
wtx.commit().context("commit RecallEvent write")?;
Ok(())
})
.await
.context("record task join error")??;
Ok(())
}
fn snapshot(&self) -> Result<Vec<RecallEvent>> {
let db = self.db.clone();
let path = self.path.clone();
let rtx = db
.begin_read()
.with_context(|| format!("begin_read recall log {}", path.display()))?;
let table = rtx
.open_table(RECALL_LOG)
.context("open RECALL_LOG table (read)")?;
let mut out = Vec::new();
for entry in table.iter().context("iter RECALL_LOG")? {
let (_k, v) = entry.context("decode RECALL_LOG row")?;
let ev: RecallEvent =
postcard::from_bytes(v.value()).context("postcard decode RecallEvent")?;
out.push(ev);
}
Ok(out)
}
pub async fn hit_count(&self, drawer_id: Uuid) -> Result<u64> {
let events = self.snapshot_async().await?;
let mut count: u64 = 0;
for ev in events {
if ev.drawer_id == Some(drawer_id) {
count += 1;
}
}
Ok(count)
}
pub async fn miss_rate(&self, palace_id: &str, window_days: u32) -> Result<f32> {
let events = self.snapshot_async().await?;
let since = Utc::now() - chrono::Duration::days(window_days as i64);
use std::collections::HashSet;
let mut total: HashSet<u64> = HashSet::new();
let mut misses: HashSet<u64> = HashSet::new();
for ev in events {
if ev.palace_id != palace_id || ev.occurred_at < since {
continue;
}
total.insert(ev.query_hash);
if ev.drawer_id.is_none() {
misses.insert(ev.query_hash);
}
}
if total.is_empty() {
return Ok(0.0);
}
Ok(misses.len() as f32 / total.len() as f32)
}
pub async fn top_drawers(&self, palace_id: &str, limit: usize) -> Result<Vec<(Uuid, u64)>> {
let events = self.snapshot_async().await?;
use std::collections::HashMap;
let mut counts: HashMap<Uuid, u64> = HashMap::new();
for ev in events {
if ev.palace_id != palace_id {
continue;
}
if let Some(id) = ev.drawer_id {
*counts.entry(id).or_insert(0) += 1;
}
}
let mut out: Vec<(Uuid, u64)> = counts.into_iter().collect();
out.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
out.truncate(limit);
Ok(out)
}
pub async fn missed_queries(&self, palace_id: &str, limit: usize) -> Result<Vec<(u64, u64)>> {
let events = self.snapshot_async().await?;
use std::collections::HashMap;
let mut counts: HashMap<u64, u64> = HashMap::new();
for ev in events {
if ev.palace_id != palace_id || ev.drawer_id.is_some() {
continue;
}
*counts.entry(ev.query_hash).or_insert(0) += 1;
}
let mut out: Vec<(u64, u64)> = counts.into_iter().collect();
out.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
out.truncate(limit);
Ok(out)
}
async fn snapshot_async(&self) -> Result<Vec<RecallEvent>> {
let db = self.db.clone();
let path = self.path.clone();
tokio::task::spawn_blocking(move || -> Result<Vec<RecallEvent>> {
let rtx = db
.begin_read()
.with_context(|| format!("begin_read recall log {}", path.display()))?;
let table = rtx
.open_table(RECALL_LOG)
.context("open RECALL_LOG table (read)")?;
let mut out = Vec::new();
for entry in table.iter().context("iter RECALL_LOG")? {
let (_k, v) = entry.context("decode RECALL_LOG row")?;
let ev: RecallEvent =
postcard::from_bytes(v.value()).context("postcard decode RecallEvent")?;
out.push(ev);
}
Ok(out)
})
.await
.context("snapshot task join error")?
}
}
/// Maps a caller-supplied path to the redb file actually used.
///
/// A path whose extension is exactly `db` gets that extension replaced by
/// `redb`; every other path is returned unchanged.
fn resolve_redb_path(path: &Path) -> PathBuf {
    match path.extension() {
        Some(ext) if ext == "db" => path.with_extension("redb"),
        _ => path.to_path_buf(),
    }
}
/// Imports a legacy SQLite recall log into the redb file at `redb_path`.
///
/// The legacy file is `orig_path` itself when it ends in `.db`, otherwise
/// `recall.db` next to `redb_path`. When that file does not exist this is a
/// no-op. Otherwise every `recall_events` row is copied into `RECALL_LOG`
/// under keys `1..=n` (in ascending legacy `id` order) and the SQLite file is
/// renamed to `<name>.db.migrated` so the import does not run again.
///
/// # Errors
/// Fails when the SQLite file cannot be opened or queried, when a row holds
/// an invalid uuid, timestamp or out-of-range layer, when the redb write
/// fails, or when the final rename fails. On any such failure the legacy
/// file is left in place.
#[cfg(feature = "sqlite-kg")]
fn migrate_from_sqlite_if_present(orig_path: &Path, redb_path: &Path) -> Result<()> {
    use rusqlite::Connection;

    let sqlite_path = if orig_path.extension().is_some_and(|e| e == "db") {
        orig_path.to_path_buf()
    } else {
        let parent = redb_path.parent().unwrap_or(Path::new("."));
        parent.join("recall.db")
    };
    // Nothing to migrate (this also covers the already-migrated case, where
    // the legacy file has been renamed away).
    if !sqlite_path.exists() {
        return Ok(());
    }
    let migrated_marker = sqlite_path.with_extension("db.migrated");
    let conn = Connection::open_with_flags(
        &sqlite_path,
        rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY | rusqlite::OpenFlags::SQLITE_OPEN_URI,
    )
    .with_context(|| {
        format!(
            "open legacy sqlite recall log read-only: {}",
            sqlite_path.display()
        )
    })?;
    // Only "no such row" means the table is absent. Any other failure (locked
    // or unreadable database, ...) must not be mistaken for an empty legacy
    // log, because the file would then be renamed without being imported.
    let table_exists = match conn.query_row(
        "SELECT 1 FROM sqlite_master WHERE type='table' AND name='recall_events'",
        [],
        |_| Ok(true),
    ) {
        Ok(found) => found,
        Err(rusqlite::Error::QueryReturnedNoRows) => false,
        Err(e) => {
            return Err(anyhow::Error::new(e).context("probe legacy recall_events table"));
        }
    };
    if !table_exists {
        // Best effort: a failed rename only means the probe repeats on the next open.
        let _ = std::fs::rename(&sqlite_path, &migrated_marker);
        return Ok(());
    }
    let mut stmt = conn
        .prepare(
            "SELECT palace_id, query_hash, layer, drawer_id, score, occurred_at \
FROM recall_events ORDER BY id ASC",
        )
        .context("prepare legacy recall_events select")?;
    let rows_iter = stmt
        .query_map([], |row| {
            let palace_id: String = row.get(0)?;
            let query_hash_i: i64 = row.get(1)?;
            let layer_i: i64 = row.get(2)?;
            let drawer_id: Option<String> = row.get(3)?;
            let score: f64 = row.get(4)?;
            let occurred_at: String = row.get(5)?;
            Ok((
                palace_id,
                query_hash_i,
                layer_i,
                drawer_id,
                score,
                occurred_at,
            ))
        })
        .context("query legacy recall_events rows")?;
    // Stage everything in memory first so the SQLite handles can be closed
    // before the redb file is opened for writing.
    let mut staged: Vec<RecallEvent> = Vec::new();
    for row in rows_iter {
        let (palace_id, qh_i, layer_i, drawer_id_str, score, occurred_at_s) =
            row.context("read legacy recall_events row")?;
        let drawer_id = drawer_id_str
            .map(|s| {
                Uuid::parse_str(&s).map_err(|e| anyhow!("invalid uuid in legacy recall row: {e}"))
            })
            .transpose()?;
        let occurred_at = DateTime::parse_from_rfc3339(&occurred_at_s)
            .map_err(|e| anyhow!("invalid occurred_at in legacy recall row: {e}"))?
            .with_timezone(&Utc);
        // Reject out-of-range layers instead of silently truncating them.
        let layer = u8::try_from(layer_i)
            .map_err(|e| anyhow!("invalid layer in legacy recall row: {e}"))?;
        staged.push(RecallEvent {
            palace_id,
            // Deliberate bit-for-bit reinterpretation of the i64 column as u64.
            query_hash: qh_i as u64,
            layer,
            drawer_id,
            score: score as f32,
            occurred_at,
        });
    }
    drop(stmt);
    drop(conn);
    let db = Database::create(redb_path).with_context(|| {
        format!(
            "open redb recall log for migration write: {}",
            redb_path.display()
        )
    })?;
    let wtx = db
        .begin_write()
        .context("begin_write redb for recall migration")?;
    {
        let mut table = wtx
            .open_table(RECALL_LOG)
            .context("open RECALL_LOG table for migration")?;
        for (i, ev) in staged.iter().enumerate() {
            // Keys start at 1 and follow the legacy row order.
            let id = (i as u64).saturating_add(1);
            let bytes =
                postcard::to_allocvec(ev).context("postcard encode migrated RecallEvent")?;
            table
                .insert(id, bytes.as_slice())
                .context("insert migrated RecallEvent row")?;
        }
    }
    wtx.commit().context("commit migrated recall rows")?;
    // Close the redb handle before returning; the caller reopens the file.
    drop(db);
    std::fs::rename(&sqlite_path, &migrated_marker).with_context(|| {
        format!(
            "rename legacy recall db {} -> {}",
            sqlite_path.display(),
            migrated_marker.display()
        )
    })?;
    Ok(())
}
// Unit tests for query normalisation, hashing and the redb-backed RecallLog.
// Every log test works on a fresh temporary directory.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
// Stop words ("the") are dropped and the remaining words are lowercased.
#[test]
fn normalize_removes_stop_words() {
assert_eq!(normalize_query("The quick Brown Fox!"), "quick brown fox");
}
// Punctuation and surrounding whitespace are removed; "is" is a stop word.
#[test]
fn normalize_strips_punctuation() {
assert_eq!(normalize_query(" Rust is fast! "), "rust fast");
}
// Same input hashes to the same value; these two different inputs do not collide.
#[test]
fn fnv1a_is_deterministic() {
assert_eq!(fnv1a_hash("hello"), fnv1a_hash("hello"));
assert_ne!(fnv1a_hash("hello"), fnv1a_hash("world"));
}
// Two events recorded for the same drawer yield a hit count of 2.
#[tokio::test]
async fn record_then_hit_count() {
let dir = tempdir().unwrap();
let log = RecallLog::open(&dir.path().join("recall.db")).unwrap();
let id = Uuid::new_v4();
log.record(RecallEvent {
palace_id: "test".into(),
query_hash: query_hash("vector store"),
layer: 2,
drawer_id: Some(id),
score: 0.9,
occurred_at: Utc::now(),
})
.await
.unwrap();
log.record(RecallEvent {
palace_id: "test".into(),
query_hash: query_hash("vector store"),
layer: 2,
drawer_id: Some(id),
score: 0.85,
occurred_at: Utc::now(),
})
.await
.unwrap();
assert_eq!(log.hit_count(id).await.unwrap(), 2);
}
// A single event without a drawer gives a miss rate of 1.0.
#[tokio::test]
async fn miss_rate_all_miss() {
let dir = tempdir().unwrap();
let log = RecallLog::open(&dir.path().join("recall.db")).unwrap();
log.record(RecallEvent {
palace_id: "test".into(),
query_hash: query_hash("unknown topic"),
layer: 3,
drawer_id: None,
score: 0.0,
occurred_at: Utc::now(),
})
.await
.unwrap();
let rate = log.miss_rate("test", 7).await.unwrap();
assert!((rate - 1.0).abs() < 1e-4, "expected 1.0 got {rate}");
}
// A single event with a drawer gives a miss rate of 0.0.
#[tokio::test]
async fn miss_rate_all_hit() {
let dir = tempdir().unwrap();
let log = RecallLog::open(&dir.path().join("recall.db")).unwrap();
log.record(RecallEvent {
palace_id: "test".into(),
query_hash: query_hash("rust"),
layer: 2,
drawer_id: Some(Uuid::new_v4()),
score: 0.9,
occurred_at: Utc::now(),
})
.await
.unwrap();
let rate = log.miss_rate("test", 7).await.unwrap();
assert!((rate - 0.0).abs() < 1e-4, "expected 0.0 got {rate}");
}
// The drawer with three hits is listed before the drawer with one hit.
#[tokio::test]
async fn top_drawers_sorted_by_hits() {
let dir = tempdir().unwrap();
let log = RecallLog::open(&dir.path().join("recall.db")).unwrap();
let id_a = Uuid::new_v4();
let id_b = Uuid::new_v4();
for _ in 0..3 {
log.record(RecallEvent {
palace_id: "test".into(),
query_hash: 1,
layer: 2,
drawer_id: Some(id_a),
score: 0.9,
occurred_at: Utc::now(),
})
.await
.unwrap();
}
log.record(RecallEvent {
palace_id: "test".into(),
query_hash: 2,
layer: 2,
drawer_id: Some(id_b),
score: 0.8,
occurred_at: Utc::now(),
})
.await
.unwrap();
let top = log.top_drawers("test", 5).await.unwrap();
assert_eq!(top[0].0, id_a);
assert_eq!(top[0].1, 3);
assert_eq!(top[1].0, id_b);
}
// The query hash missed three times is reported first, with its count.
#[tokio::test]
async fn missed_queries_returns_most_missed_first() {
let dir = tempdir().unwrap();
let log = RecallLog::open(&dir.path().join("recall.db")).unwrap();
let h1 = query_hash("obscure topic");
let h2 = query_hash("another gap");
for _ in 0..3 {
log.record(RecallEvent {
palace_id: "test".into(),
query_hash: h1,
layer: 3,
drawer_id: None,
score: 0.0,
occurred_at: Utc::now(),
})
.await
.unwrap();
}
log.record(RecallEvent {
palace_id: "test".into(),
query_hash: h2,
layer: 3,
drawer_id: None,
score: 0.0,
occurred_at: Utc::now(),
})
.await
.unwrap();
let missed = log.missed_queries("test", 5).await.unwrap();
assert_eq!(missed[0].0, h1);
assert_eq!(missed[0].1, 3);
}
// Events survive dropping and reopening the log, and a write after the
// reopen adds a new row instead of replacing the existing one.
#[tokio::test]
async fn roundtrip_persists_across_reopen() {
let dir = tempdir().unwrap();
let path = dir.path().join("recall.db");
let id = Uuid::new_v4();
{
let log = RecallLog::open(&path).unwrap();
log.record(RecallEvent {
palace_id: "test".into(),
query_hash: 42,
layer: 2,
drawer_id: Some(id),
score: 0.5,
occurred_at: Utc::now(),
})
.await
.unwrap();
}
let log2 = RecallLog::open(&path).unwrap();
assert_eq!(log2.hit_count(id).await.unwrap(), 1);
log2.record(RecallEvent {
palace_id: "test".into(),
query_hash: 42,
layer: 2,
drawer_id: Some(id),
score: 0.7,
occurred_at: Utc::now(),
})
.await
.unwrap();
assert_eq!(log2.hit_count(id).await.unwrap(), 2);
}
// Opening with a `recall.db` path creates `recall.redb` next to it.
#[test]
fn callers_passing_recall_db_get_redb_sibling() {
let dir = tempdir().unwrap();
let legacy = dir.path().join("recall.db");
let _log = RecallLog::open(&legacy).unwrap();
let redb_path = dir.path().join("recall.redb");
assert!(
redb_path.exists(),
"expected redb sibling to be created at {}",
redb_path.display()
);
}
// Builds a legacy SQLite log with one hit row and one miss row, then checks
// that opening imports both, renames the SQLite file to the `.migrated`
// marker, and that a second open does not import the rows again.
#[cfg(feature = "sqlite-kg")]
#[tokio::test]
async fn migrates_legacy_sqlite_rows() {
use rusqlite::params;
let dir = tempdir().unwrap();
let legacy = dir.path().join("recall.db");
let drawer_a = Uuid::new_v4();
{
let conn = rusqlite::Connection::open(&legacy).unwrap();
conn.execute_batch(
"CREATE TABLE recall_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
palace_id TEXT NOT NULL,
query_hash INTEGER NOT NULL,
layer INTEGER NOT NULL,
drawer_id TEXT,
score REAL NOT NULL,
occurred_at TEXT NOT NULL
);",
)
.unwrap();
conn.execute(
"INSERT INTO recall_events
(palace_id, query_hash, layer, drawer_id, score, occurred_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
"test",
123_i64,
2_i64,
drawer_a.to_string(),
0.9_f64,
Utc::now().to_rfc3339(),
],
)
.unwrap();
conn.execute(
"INSERT INTO recall_events
(palace_id, query_hash, layer, drawer_id, score, occurred_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
"test",
456_i64,
3_i64,
Option::<String>::None,
0.0_f64,
Utc::now().to_rfc3339(),
],
)
.unwrap();
}
let log = RecallLog::open(&legacy).unwrap();
assert_eq!(log.hit_count(drawer_a).await.unwrap(), 1);
let rate = log.miss_rate("test", 7).await.unwrap();
assert!(rate > 0.0, "expected non-zero miss rate, got {rate}");
assert!(!legacy.exists(), "legacy recall.db should be renamed");
assert!(
dir.path().join("recall.db.migrated").exists(),
"expected migration marker file"
);
// Release the redb handle before reopening the same file.
drop(log);
let log2 = RecallLog::open(&legacy).unwrap();
assert_eq!(log2.hit_count(drawer_a).await.unwrap(), 1);
}
}