use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use redb::{Database, ReadableDatabase, ReadableTable};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use uuid::Uuid;
use crate::memory_core::store::kg_store::RECALL_LOG;
/// Common English function words that [`normalize_query`] drops from queries.
const STOP_WORDS: &[&str] = &[
    "the", "a", "an", "is", "are", "was", "were", "in", "on", "at", "to", "of", "for", "with",
    "by", "from", "and", "or", "but", "not", "it", "this", "that", "be", "as", "do", "did", "has",
    "have", "had",
];

/// Canonicalizes free-form query text so that equivalent queries compare equal.
///
/// The text is lowercased, every character that is neither alphanumeric nor a
/// plain space is turned into a space, the result is split on whitespace,
/// entries of [`STOP_WORDS`] are removed, and the surviving words are joined
/// with single spaces.
pub fn normalize_query(text: &str) -> String {
    // Punctuation becomes a separator rather than being deleted, so
    // "foo,bar" yields two words instead of "foobar".
    let cleaned: String = text
        .to_lowercase()
        .chars()
        .map(|ch| if ch.is_alphanumeric() || ch == ' ' { ch } else { ' ' })
        .collect();
    cleaned
        .split_whitespace()
        .filter(|word| !STOP_WORDS.contains(word))
        .collect::<Vec<_>>()
        .join(" ")
}
/// Computes the 64-bit FNV-1a hash of the UTF-8 bytes of `text`.
///
/// Each byte is XOR-ed into the running hash, which is then multiplied by the
/// FNV prime with wrapping arithmetic. The empty string hashes to the offset
/// basis.
pub fn fnv1a_hash(text: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    text.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
/// Hashes a query after normalization, so queries that differ only in case,
/// punctuation, spacing, or stop words produce the same value.
pub fn query_hash(text: &str) -> u64 {
    let normalized = normalize_query(text);
    fnv1a_hash(&normalized)
}
/// One recorded recall lookup, stored as a row of the recall log.
///
/// NOTE(review): rows are encoded with postcard in `RecallLog::record`; that
/// format is not self-describing, so reordering or retyping fields presumably
/// breaks decoding of existing rows — confirm before changing the layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecallEvent {
/// Identifier the `RecallLog` analytics methods filter events by.
pub palace_id: String,
/// Hash identifying the query, e.g. as produced by [`query_hash`].
pub query_hash: u64,
/// Layer tag for the lookup. NOTE(review): meaning of the values is not
/// visible in this file — confirm against the caller.
pub layer: u8,
/// Drawer that was returned, or `None` when nothing was found; `None` is
/// what `miss_rate` and `missed_queries` count as a miss.
pub drawer_id: Option<Uuid>,
/// Score attached to the result. NOTE(review): scale/range not visible
/// here — confirm against the caller.
pub score: f32,
/// When the lookup happened; `miss_rate` uses it for its time window.
pub occurred_at: DateTime<Utc>,
}
/// Handle to the redb-backed log of [`RecallEvent`] rows.
pub struct RecallLog {
// Shared database handle; cloned into blocking tasks for reads and writes.
db: Arc<Database>,
// Resolved path of the redb file, kept for error messages.
path: PathBuf,
// Most recently allocated row key; seeded in `open` from the table's last key.
next_id: AtomicU64,
}
impl RecallLog {
pub fn open(path: &Path) -> Result<Self> {
let redb_path = resolve_redb_path(path);
if let Some(parent) = redb_path.parent()
&& !parent.as_os_str().is_empty()
{
std::fs::create_dir_all(parent).with_context(|| {
format!(
"failed to create recall log parent dir {}",
parent.display()
)
})?;
}
let db = super::store::open_or_recreate(&redb_path).with_context(|| {
format!("failed to open redb recall log at {}", redb_path.display())
})?;
let mut max_seen: u64 = 0;
{
let wtx = db
.begin_write()
.context("failed to begin write txn for recall log init")?;
{
let table = wtx
.open_table(RECALL_LOG)
.context("failed to open RECALL_LOG table")?;
if let Some(entry) = table
.last()
.context("failed to read last key from RECALL_LOG")?
{
max_seen = entry.0.value();
}
}
wtx.commit().context("failed to commit recall log init")?;
}
Ok(Self {
db: Arc::new(db),
path: redb_path,
next_id: AtomicU64::new(max_seen),
})
}
fn alloc_id(&self) -> u64 {
let now_ms = Utc::now().timestamp_millis().max(0) as u64;
loop {
let current = self.next_id.load(Ordering::Acquire);
let candidate = now_ms.max(current + 1);
if self
.next_id
.compare_exchange(current, candidate, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
return candidate;
}
}
}
pub async fn record(&self, event: RecallEvent) -> Result<()> {
let id = self.alloc_id();
let bytes =
postcard::to_allocvec(&event).context("failed to postcard-encode RecallEvent")?;
let db = self.db.clone();
let path = self.path.clone();
tokio::task::spawn_blocking(move || -> Result<()> {
let wtx = db
.begin_write()
.with_context(|| format!("begin_write recall log {}", path.display()))?;
{
let mut table = wtx
.open_table(RECALL_LOG)
.context("open RECALL_LOG table")?;
table
.insert(id, bytes.as_slice())
.context("insert RecallEvent row")?;
}
wtx.commit().context("commit RecallEvent write")?;
Ok(())
})
.await
.context("record task join error")??;
Ok(())
}
#[allow(dead_code)]
fn snapshot(&self) -> Result<Vec<RecallEvent>> {
let db = self.db.clone();
let path = self.path.clone();
let rtx = db
.begin_read()
.with_context(|| format!("begin_read recall log {}", path.display()))?;
let table = rtx
.open_table(RECALL_LOG)
.context("open RECALL_LOG table (read)")?;
let mut out = Vec::new();
for entry in table.iter().context("iter RECALL_LOG")? {
let (_k, v) = entry.context("decode RECALL_LOG row")?;
let ev: RecallEvent =
postcard::from_bytes(v.value()).context("postcard decode RecallEvent")?;
out.push(ev);
}
Ok(out)
}
pub async fn hit_count(&self, drawer_id: Uuid) -> Result<u64> {
let events = self.snapshot_async().await?;
let mut count: u64 = 0;
for ev in events {
if ev.drawer_id == Some(drawer_id) {
count += 1;
}
}
Ok(count)
}
pub async fn miss_rate(&self, palace_id: &str, window_days: u32) -> Result<f32> {
let events = self.snapshot_async().await?;
let since = Utc::now() - chrono::Duration::days(window_days as i64);
use std::collections::HashSet;
let mut total: HashSet<u64> = HashSet::new();
let mut misses: HashSet<u64> = HashSet::new();
for ev in events {
if ev.palace_id != palace_id || ev.occurred_at < since {
continue;
}
total.insert(ev.query_hash);
if ev.drawer_id.is_none() {
misses.insert(ev.query_hash);
}
}
if total.is_empty() {
return Ok(0.0);
}
Ok(misses.len() as f32 / total.len() as f32)
}
pub async fn top_drawers(&self, palace_id: &str, limit: usize) -> Result<Vec<(Uuid, u64)>> {
let events = self.snapshot_async().await?;
use std::collections::HashMap;
let mut counts: HashMap<Uuid, u64> = HashMap::new();
for ev in events {
if ev.palace_id != palace_id {
continue;
}
if let Some(id) = ev.drawer_id {
*counts.entry(id).or_insert(0) += 1;
}
}
let mut out: Vec<(Uuid, u64)> = counts.into_iter().collect();
out.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
out.truncate(limit);
Ok(out)
}
pub async fn missed_queries(&self, palace_id: &str, limit: usize) -> Result<Vec<(u64, u64)>> {
let events = self.snapshot_async().await?;
use std::collections::HashMap;
let mut counts: HashMap<u64, u64> = HashMap::new();
for ev in events {
if ev.palace_id != palace_id || ev.drawer_id.is_some() {
continue;
}
*counts.entry(ev.query_hash).or_insert(0) += 1;
}
let mut out: Vec<(u64, u64)> = counts.into_iter().collect();
out.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
out.truncate(limit);
Ok(out)
}
async fn snapshot_async(&self) -> Result<Vec<RecallEvent>> {
let db = self.db.clone();
let path = self.path.clone();
tokio::task::spawn_blocking(move || -> Result<Vec<RecallEvent>> {
let rtx = db
.begin_read()
.with_context(|| format!("begin_read recall log {}", path.display()))?;
let table = rtx
.open_table(RECALL_LOG)
.context("open RECALL_LOG table (read)")?;
let mut out = Vec::new();
for entry in table.iter().context("iter RECALL_LOG")? {
let (_k, v) = entry.context("decode RECALL_LOG row")?;
let ev: RecallEvent =
postcard::from_bytes(v.value()).context("postcard decode RecallEvent")?;
out.push(ev);
}
Ok(out)
})
.await
.context("snapshot task join error")?
}
}
/// Maps a caller-supplied path to the file actually used for the redb store.
///
/// A path whose extension is exactly `db` gets that extension replaced with
/// `redb`; any other path is returned unchanged.
fn resolve_redb_path(path: &Path) -> PathBuf {
    match path.extension() {
        Some(ext) if ext == "db" => path.with_extension("redb"),
        _ => path.to_path_buf(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds an event for palace `"test"` stamped with the current time.
    fn event(hash: u64, layer: u8, drawer_id: Option<Uuid>, score: f32) -> RecallEvent {
        RecallEvent {
            palace_id: "test".into(),
            query_hash: hash,
            layer,
            drawer_id,
            score,
            occurred_at: Utc::now(),
        }
    }

    /// Opens a log via the legacy `recall.db` name inside `dir`.
    fn open_log(dir: &Path) -> RecallLog {
        RecallLog::open(&dir.join("recall.db")).unwrap()
    }

    #[test]
    fn normalize_removes_stop_words() {
        assert_eq!(normalize_query("The quick Brown Fox!"), "quick brown fox");
    }

    #[test]
    fn normalize_strips_punctuation() {
        assert_eq!(normalize_query(" Rust is fast! "), "rust fast");
    }

    #[test]
    fn fnv1a_is_deterministic() {
        assert_eq!(fnv1a_hash("hello"), fnv1a_hash("hello"));
        assert_ne!(fnv1a_hash("hello"), fnv1a_hash("world"));
    }

    #[tokio::test]
    async fn record_then_hit_count() {
        let dir = tempdir().unwrap();
        let log = open_log(dir.path());
        let drawer = Uuid::new_v4();
        for score in [0.9, 0.85] {
            log.record(event(query_hash("vector store"), 2, Some(drawer), score))
                .await
                .unwrap();
        }
        assert_eq!(log.hit_count(drawer).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn miss_rate_all_miss() {
        let dir = tempdir().unwrap();
        let log = open_log(dir.path());
        log.record(event(query_hash("unknown topic"), 3, None, 0.0))
            .await
            .unwrap();
        let rate = log.miss_rate("test", 7).await.unwrap();
        assert!((rate - 1.0).abs() < 1e-4, "expected 1.0 got {rate}");
    }

    #[tokio::test]
    async fn miss_rate_all_hit() {
        let dir = tempdir().unwrap();
        let log = open_log(dir.path());
        log.record(event(query_hash("rust"), 2, Some(Uuid::new_v4()), 0.9))
            .await
            .unwrap();
        let rate = log.miss_rate("test", 7).await.unwrap();
        assert!((rate - 0.0).abs() < 1e-4, "expected 0.0 got {rate}");
    }

    #[tokio::test]
    async fn top_drawers_sorted_by_hits() {
        let dir = tempdir().unwrap();
        let log = open_log(dir.path());
        let id_a = Uuid::new_v4();
        let id_b = Uuid::new_v4();
        for _ in 0..3 {
            log.record(event(1, 2, Some(id_a), 0.9)).await.unwrap();
        }
        log.record(event(2, 2, Some(id_b), 0.8)).await.unwrap();
        let top = log.top_drawers("test", 5).await.unwrap();
        assert_eq!(top[0].0, id_a);
        assert_eq!(top[0].1, 3);
        assert_eq!(top[1].0, id_b);
    }

    #[tokio::test]
    async fn missed_queries_returns_most_missed_first() {
        let dir = tempdir().unwrap();
        let log = open_log(dir.path());
        let h1 = query_hash("obscure topic");
        let h2 = query_hash("another gap");
        for _ in 0..3 {
            log.record(event(h1, 3, None, 0.0)).await.unwrap();
        }
        log.record(event(h2, 3, None, 0.0)).await.unwrap();
        let missed = log.missed_queries("test", 5).await.unwrap();
        assert_eq!(missed[0].0, h1);
        assert_eq!(missed[0].1, 3);
    }

    #[tokio::test]
    async fn roundtrip_persists_across_reopen() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("recall.db");
        let drawer = Uuid::new_v4();
        {
            // Scoped so the first handle is dropped before the file is reopened.
            let first = RecallLog::open(&path).unwrap();
            first
                .record(event(42, 2, Some(drawer), 0.5))
                .await
                .unwrap();
        }
        let reopened = RecallLog::open(&path).unwrap();
        assert_eq!(reopened.hit_count(drawer).await.unwrap(), 1);
        reopened
            .record(event(42, 2, Some(drawer), 0.7))
            .await
            .unwrap();
        assert_eq!(reopened.hit_count(drawer).await.unwrap(), 2);
    }

    #[test]
    fn callers_passing_recall_db_get_redb_sibling() {
        let dir = tempdir().unwrap();
        let legacy = dir.path().join("recall.db");
        // Bound to a name so the log stays open while the file is checked.
        let _log = RecallLog::open(&legacy).unwrap();
        let redb_path = dir.path().join("recall.redb");
        assert!(
            redb_path.exists(),
            "expected redb sibling to be created at {}",
            redb_path.display()
        );
    }
}