pub mod bootstrap;
pub mod caching_store;
pub mod metrics;
use std::collections::HashSet;
use std::io::Write;
use std::path::Path;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use second_brain_core::embedding::Embedder;
use second_brain_core::kuzu_store::KuzuStore;
use second_brain_core::query::{QueryEngine, QueryFilters, QueryRequest};
use second_brain_core::store::Store;
/// One labelled evaluation query, as stored in the eval-set file
/// (`load_eval_set` parses one JSON object of this shape per non-blank line).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EvalQuery {
/// Identifier carried into the resulting `QueryRecord` and used to look up relevance sets.
pub query_id: String,
/// Query text that is embedded and sent to the recall engine.
pub query: String,
/// Label for the query's phrasing (the tests use "literal" and "paraphrase").
/// Not read by the code visible in this file.
pub query_variant: String,
/// NOTE(review): presumably the memory this query was derived from — not read
/// by the code visible here; confirm against the eval-set generator.
pub seed_memory_id: Uuid,
/// Memory type label (e.g. "decision" in the tests). Not read by the code visible here.
pub memory_type: String,
/// Ids of the memories that count as correct results for this query.
pub relevant_memory_ids: Vec<Uuid>,
/// Optional free-form note; `None` when the key is absent from the JSON.
#[serde(default)]
pub note: Option<String>,
/// Optional tags; empty when the key is absent from the JSON.
#[serde(default)]
pub tags: Vec<String>,
}
/// Outcome of running a single query through one arm of the evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRecord {
/// Id of the `EvalQuery` this record was produced from.
pub query_id: String,
/// `true` when the prefixed (query-prompt) embedding was used, `false` for the bare one.
pub use_prefix: bool,
/// Memory ids returned by `QueryEngine::recall`, in result order.
pub ranked_ids: Vec<Uuid>,
/// Scores returned by `recall`, parallel to `ranked_ids`.
pub scores: Vec<f32>,
/// 1-based position of the first relevant id in `ranked_ids`; `None` if none was returned.
pub first_relevant_rank: Option<usize>,
/// 1-based position of the first relevant memory in the raw vector search
/// (requested with `limit * 3`); `None` if no relevant memory appeared there.
pub gold_raw_rank: Option<usize>,
/// Similarity the raw vector search reported for that same memory.
pub gold_raw_similarity: Option<f32>,
}
/// Retrieval metrics for one arm, each averaged over that arm's query records
/// by `aggregate` (all fields are `0.0` when the arm has no records).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArmMetrics {
/// Mean of `metrics::recall_at_k` with `k = 1`.
pub recall_at_1: f32,
/// Mean of `metrics::recall_at_k` with `k = 3`.
pub recall_at_3: f32,
/// Mean of `metrics::recall_at_k` with `k = 5`.
pub recall_at_5: f32,
/// Mean of `metrics::mrr`.
pub mrr: f32,
/// Mean of `metrics::precision_at_k` with `k = 5`.
pub precision_at_5: f32,
}
/// Side-by-side comparison of the bare and prefixed arms, produced by `aggregate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregateReport {
/// Averaged metrics for the bare-embedding arm.
pub bare: ArmMetrics,
/// Averaged metrics for the prefixed-embedding arm.
pub prefixed: ArmMetrics,
/// Interval returned by `bootstrap::paired_bootstrap_ci` (10000 resamples, 0.95)
/// over per-query `prefixed - bare` recall@3, paired by `query_id`.
/// NOTE(review): presumably `(lower, upper)` — `gate_sweep` destructures the
/// same return value as `(lo, _hi)`; confirm in `bootstrap`.
pub delta_recall_at_3_ci: (f32, f32),
/// Same interval, computed over per-query `prefixed - bare` MRR.
pub delta_mrr_ci: (f32, f32),
/// `metrics::gated_out_rate` over bare records, flagging those whose gold memory
/// was found by the raw search with similarity below `BASELINE_THRESHOLD`.
pub gated_out_rate_bare: f32,
/// Same rate for the prefixed records.
pub gated_out_rate_prefixed: f32,
}
/// Recall figures for a single similarity threshold in the `gate_sweep` grid.
///
/// A record counts as recalled at `k` when its `gold_raw_rank` is at most `k`
/// and its `gold_raw_similarity` is at least `threshold`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatePoint {
/// Similarity threshold this point was evaluated at.
pub threshold: f32,
/// Mean of the recalled-at-1 indicator over all records.
pub recall_at_1: f32,
/// Mean of the recalled-at-3 indicator over all records.
pub recall_at_3: f32,
/// Mean of the recalled-at-5 indicator over all records.
pub recall_at_5: f32,
/// `recall_at_5 / 5.0`.
pub precision_proxy: f32,
}
/// Result of sweeping the similarity gate over a fixed grid (see `gate_sweep`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GateSweepReport {
/// One point per grid threshold, in ascending threshold order (0.40 to 0.80 in 0.01 steps).
pub frontier: Vec<GatePoint>,
/// Always `BASELINE_THRESHOLD`.
pub baseline_threshold: f32,
/// Threshold selected by the sweep; equals the baseline when no grid point beat it.
pub chosen_threshold: f32,
/// `true` when some grid point had higher recall@3 than the baseline and the
/// lower bound of its bootstrap interval on the paired delta was above zero.
pub chosen_beats_baseline: bool,
}
/// One memory as written to the corpus dump by `extract_corpus` (one JSON object per line).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorpusEntry {
/// Memory id.
pub id: Uuid,
/// Memory content, copied verbatim.
pub content: String,
/// Lower-cased `Debug` rendering of the memory's type.
pub memory_type: String,
/// Creation timestamp formatted with `to_rfc3339`.
pub created_at: String,
/// Project path of the memory, if it has one.
pub project_path: Option<String>,
}
/// Reference similarity gate: `aggregate` flags gold hits whose raw similarity is
/// below it as gated out, and `gate_sweep` uses it as the baseline that a
/// candidate threshold must beat.
const BASELINE_THRESHOLD: f32 = 0.59;
pub fn load_eval_set(path: &Path) -> Result<Vec<EvalQuery>> {
let text = std::fs::read_to_string(path)?;
let mut out = Vec::new();
for line in text.lines() {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
out.push(serde_json::from_str(trimmed)?);
}
Ok(out)
}
/// Runs every query through one arm of the evaluation, embedding each query as it goes.
///
/// For each query this embeds the text (`Embedder::embed_query` when `use_prefix`
/// is set, `Embedder::embed` otherwise), asks the recall engine for up to `limit`
/// results, and then runs a raw vector search requested with `limit * 3` to find
/// where the first relevant memory sits before any gating.
///
/// # Errors
///
/// Stops at the first embedding, recall, or vector-search failure and returns it.
pub fn run_arm(
    store: &KuzuStore,
    embedder: &Embedder,
    queries: &[EvalQuery],
    use_prefix: bool,
    limit: usize,
) -> Result<Vec<QueryRecord>> {
    let engine = QueryEngine::new(store);
    queries
        .iter()
        .map(|eval| -> Result<QueryRecord> {
            let vector = match use_prefix {
                true => embedder.embed_query(&eval.query)?,
                false => embedder.embed(&eval.query)?,
            };
            let gold: HashSet<Uuid> = eval.relevant_memory_ids.iter().copied().collect();

            // The request owns its embedding; the raw search below still needs ours.
            let request = QueryRequest {
                text: eval.query.clone(),
                embedding: vector.clone(),
                limit,
                filters: QueryFilters::default(),
            };
            let hits = engine.recall(&request)?;
            let (ranked_ids, scores): (Vec<Uuid>, Vec<f32>) =
                hits.iter().map(|hit| (hit.memory.id, hit.score)).unzip();
            let first_relevant_rank = ranked_ids
                .iter()
                .zip(1usize..)
                .find(|(id, _)| gold.contains(*id))
                .map(|(_, rank)| rank);

            // First relevant memory in the ungated neighbour list, with its similarity.
            let raw = store.vector_search(&vector, limit * 3)?;
            let gold_hit = raw
                .iter()
                .enumerate()
                .find_map(|(idx, (mem, sim))| gold.contains(&mem.id).then_some((idx + 1, *sim)));

            Ok(QueryRecord {
                query_id: eval.query_id.clone(),
                use_prefix,
                ranked_ids,
                scores,
                first_relevant_rank,
                gold_raw_rank: gold_hit.map(|(rank, _)| rank),
                gold_raw_similarity: gold_hit.map(|(_, sim)| sim),
            })
        })
        .collect()
}
/// An evaluation query bundled with everything needed to run it through either
/// arm without calling the embedder again (built by `embed_all_queries`).
#[derive(Debug, Clone)]
pub struct EmbeddedQuery {
    /// The original query definition.
    pub query: EvalQuery,
    /// `query.relevant_memory_ids` collected into a set for membership checks.
    pub relevant: HashSet<Uuid>,
    /// Embedding of the raw query text.
    pub bare_embedding: Vec<f32>,
    /// Embedding of the query text after it was passed through `query_prompt`.
    pub prefixed_embedding: Vec<f32>,
}
pub fn embed_all_queries(
embedder: &Embedder,
queries: &[EvalQuery],
) -> Result<Vec<EmbeddedQuery>> {
let bare_texts: Vec<&str> = queries.iter().map(|q| q.query.as_str()).collect();
let bare = embedder.embed_batch(&bare_texts)?;
let prefixed_owned: Vec<String> = queries
.iter()
.map(|q| second_brain_core::embedding::query_prompt(&q.query))
.collect();
let prefixed_texts: Vec<&str> = prefixed_owned.iter().map(|s| s.as_str()).collect();
let prefixed = embedder.embed_batch(&prefixed_texts)?;
let mut out = Vec::with_capacity(queries.len());
for (i, q) in queries.iter().enumerate() {
out.push(EmbeddedQuery {
query: q.clone(),
relevant: q.relevant_memory_ids.iter().copied().collect(),
bare_embedding: bare[i].clone(),
prefixed_embedding: prefixed[i].clone(),
});
}
Ok(out)
}
/// Evaluates one pre-embedded query against `store` and returns its record.
///
/// `use_prefix` picks which of the two stored embeddings is used. The recall
/// engine is asked for up to `limit` results; a raw vector search requested with
/// `limit * 3` is then scanned for the first relevant memory.
///
/// # Errors
///
/// Propagates any failure from the recall engine or the raw vector search.
fn record_for<S: Store + Sync>(
    embedded: &EmbeddedQuery,
    store: &S,
    use_prefix: bool,
    limit: usize,
) -> Result<QueryRecord> {
    let engine = QueryEngine::new(store);
    let vector = match use_prefix {
        true => &embedded.prefixed_embedding,
        false => &embedded.bare_embedding,
    };
    let gold = &embedded.relevant;

    let request = QueryRequest {
        text: embedded.query.query.clone(),
        embedding: vector.clone(),
        limit,
        filters: QueryFilters::default(),
    };
    let hits = engine.recall(&request)?;
    let (ranked_ids, scores): (Vec<Uuid>, Vec<f32>) =
        hits.iter().map(|hit| (hit.memory.id, hit.score)).unzip();
    let first_relevant_rank = ranked_ids
        .iter()
        .zip(1usize..)
        .find(|(id, _)| gold.contains(*id))
        .map(|(_, rank)| rank);

    // First relevant memory in the ungated neighbour list, with its similarity.
    let raw = store.vector_search(vector, limit * 3)?;
    let gold_hit = raw
        .iter()
        .enumerate()
        .find_map(|(idx, (mem, sim))| gold.contains(&mem.id).then_some((idx + 1, *sim)));

    Ok(QueryRecord {
        query_id: embedded.query.query_id.clone(),
        use_prefix,
        ranked_ids,
        scores,
        first_relevant_rank,
        gold_raw_rank: gold_hit.map(|(rank, _)| rank),
        gold_raw_similarity: gold_hit.map(|(_, sim)| sim),
    })
}
pub fn run_arm_parallel<S: Store + Sync>(
store: &S,
embedded: &[EmbeddedQuery],
use_prefix: bool,
limit: usize,
) -> Result<Vec<QueryRecord>> {
let total = embedded.len();
if total == 0 {
return Ok(Vec::new());
}
let workers = thread::available_parallelism()
.map(|n| n.get().saturating_sub(1).max(1))
.unwrap_or(1);
let chunk_size = total.div_ceil(workers);
let done = AtomicUsize::new(0);
let collected: Mutex<Vec<(usize, QueryRecord)>> = Mutex::new(Vec::with_capacity(total));
let error: Mutex<Option<anyhow::Error>> = Mutex::new(None);
thread::scope(|scope| {
for chunk in 0..workers {
let start = chunk * chunk_size;
if start >= total {
break;
}
let end = (start + chunk_size).min(total);
let done = &done;
let collected = &collected;
let error = &error;
scope.spawn(move || {
let mut local: Vec<(usize, QueryRecord)> = Vec::with_capacity(end - start);
for (offset, eq) in embedded[start..end].iter().enumerate() {
if error.lock().unwrap().is_some() {
return;
}
match record_for(eq, store, use_prefix, limit) {
Ok(rec) => local.push((start + offset, rec)),
Err(e) => {
*error.lock().unwrap() = Some(e);
return;
}
}
let n = done.fetch_add(1, Ordering::Relaxed) + 1;
if n % 25 == 0 || n == total {
eprintln!(" {n}/{total} queries");
}
}
collected.lock().unwrap().extend(local);
});
}
});
if let Some(e) = error.into_inner().unwrap() {
return Err(e);
}
let mut indexed = collected.into_inner().unwrap();
indexed.sort_by_key(|(i, _)| *i);
Ok(indexed.into_iter().map(|(_, r)| r).collect())
}
/// Summarises both arms and the paired difference between them.
///
/// * Per-arm metrics are plain means over that arm's records (zeros for an empty arm).
/// * Deltas are `prefixed - bare`, computed only for prefixed records whose
///   `query_id` also appears in `bare`, and fed to `bootstrap::paired_bootstrap_ci`
///   (10000 resamples, 0.95, fixed seed).
/// * The gated-out figures come from `metrics::gated_out_rate`, flagging records
///   whose gold memory was found by the raw search with similarity below
///   `BASELINE_THRESHOLD`.
///
/// A query id missing from `relevant_sets` is scored against an empty relevant set.
pub fn aggregate(
    bare: &[QueryRecord],
    prefixed: &[QueryRecord],
    relevant_sets: &std::collections::HashMap<String, HashSet<Uuid>>,
) -> AggregateReport {
    const AGG_SEED: u64 = 0x4B1D_C0DE;
    let no_relevant: HashSet<Uuid> = HashSet::new();

    // Per-query metrics as [recall@1, recall@3, recall@5, mrr, precision@5].
    let score = |rec: &QueryRecord| -> [f32; 5] {
        let gold = relevant_sets.get(&rec.query_id).unwrap_or(&no_relevant);
        [
            metrics::recall_at_k(&rec.ranked_ids, gold, 1),
            metrics::recall_at_k(&rec.ranked_ids, gold, 3),
            metrics::recall_at_k(&rec.ranked_ids, gold, 5),
            metrics::mrr(&rec.ranked_ids, gold),
            metrics::precision_at_k(&rec.ranked_ids, gold, 5),
        ]
    };

    let summarize = |records: &[QueryRecord]| -> ArmMetrics {
        if records.is_empty() {
            return ArmMetrics {
                recall_at_1: 0.0,
                recall_at_3: 0.0,
                recall_at_5: 0.0,
                mrr: 0.0,
                precision_at_5: 0.0,
            };
        }
        let sums = records.iter().fold([0.0f32; 5], |mut sums, rec| {
            for (slot, value) in sums.iter_mut().zip(score(rec)) {
                *slot += value;
            }
            sums
        });
        let n = records.len() as f32;
        let [recall_at_1, recall_at_3, recall_at_5, mrr, precision_at_5] =
            sums.map(|total| total / n);
        ArmMetrics {
            recall_at_1,
            recall_at_3,
            recall_at_5,
            mrr,
            precision_at_5,
        }
    };

    // Pair the arms by query id; prefixed records without a bare partner are skipped.
    let bare_by_id: std::collections::HashMap<&str, &QueryRecord> = bare
        .iter()
        .map(|rec| (rec.query_id.as_str(), rec))
        .collect();
    let (delta_r3, delta_mrr): (Vec<f32>, Vec<f32>) = prefixed
        .iter()
        .filter_map(|p_rec| {
            let b_rec = bare_by_id.get(p_rec.query_id.as_str()).copied()?;
            let p = score(p_rec);
            let b = score(b_rec);
            Some((p[1] - b[1], p[3] - b[3]))
        })
        .unzip();

    let gated_rate = |records: &[QueryRecord]| -> f32 {
        let flags: Vec<bool> = records
            .iter()
            .map(|rec| {
                matches!(
                    (rec.gold_raw_rank, rec.gold_raw_similarity),
                    (Some(_), Some(sim)) if sim < BASELINE_THRESHOLD
                )
            })
            .collect();
        metrics::gated_out_rate(&flags)
    };

    AggregateReport {
        bare: summarize(bare),
        prefixed: summarize(prefixed),
        delta_recall_at_3_ci: bootstrap::paired_bootstrap_ci(&delta_r3, 10000, 0.95, AGG_SEED),
        delta_mrr_ci: bootstrap::paired_bootstrap_ci(&delta_mrr, 10000, 0.95, AGG_SEED),
        gated_out_rate_bare: gated_rate(bare),
        gated_out_rate_prefixed: gated_rate(prefixed),
    }
}
/// Sweeps the similarity gate from 0.40 to 0.80 in 0.01 steps over the prefixed arm.
///
/// For every threshold a record counts as recalled at `k` when its gold memory
/// sits within the top `k` of the raw search and its raw similarity is at least
/// the threshold. Each grid point is appended to the frontier. A threshold
/// replaces the current choice only when its recall@3 exceeds the best seen so
/// far (starting from the baseline's recall@3) by more than `1e-6` and the
/// lower bound of the bootstrap interval on its paired recall@3 delta against
/// the baseline is above zero.
pub fn gate_sweep(prefixed: &[QueryRecord]) -> GateSweepReport {
    const GRID_SEED: u64 = 0x5EED_6A7E;

    // 1.0 if the gold memory is inside the top-k and clears the gate, else 0.0.
    let hit = |rec: &QueryRecord, k: usize, t: f32| -> f32 {
        let recalled = matches!(
            (rec.gold_raw_rank, rec.gold_raw_similarity),
            (Some(rank), Some(sim)) if rank <= k && sim >= t
        );
        if recalled { 1.0 } else { 0.0 }
    };
    let hits_at = |k: usize, t: f32| -> Vec<f32> {
        prefixed.iter().map(|rec| hit(rec, k, t)).collect()
    };
    let mean = |vals: &[f32]| -> f32 {
        match vals.len() {
            0 => 0.0,
            n => vals.iter().sum::<f32>() / n as f32,
        }
    };

    let baseline_hits = hits_at(3, BASELINE_THRESHOLD);
    let mut best_recall_at_3 = mean(&baseline_hits);
    let mut winner: Option<f32> = None;
    let mut frontier = Vec::with_capacity(41);

    for step in 0..=40u32 {
        let t = 0.40 + step as f32 * 0.01;
        let at_1 = hits_at(1, t);
        let at_3 = hits_at(3, t);
        let at_5 = hits_at(5, t);
        let recall_at_3 = mean(&at_3);
        let recall_at_5 = mean(&at_5);
        frontier.push(GatePoint {
            threshold: t,
            recall_at_1: mean(&at_1),
            recall_at_3,
            recall_at_5,
            precision_proxy: recall_at_5 / 5.0,
        });

        // Per-record recall@3 difference versus the baseline gate.
        let deltas: Vec<f32> = at_3
            .iter()
            .zip(&baseline_hits)
            .map(|(candidate, baseline)| candidate - baseline)
            .collect();
        let (lo, _hi) = bootstrap::paired_bootstrap_ci(&deltas, 2000, 0.95, GRID_SEED);
        if lo > 0.0 && recall_at_3 > best_recall_at_3 + 1e-6 {
            best_recall_at_3 = recall_at_3;
            winner = Some(t);
        }
    }

    GateSweepReport {
        frontier,
        baseline_threshold: BASELINE_THRESHOLD,
        chosen_threshold: winner.unwrap_or(BASELINE_THRESHOLD),
        chosen_beats_baseline: winner.is_some(),
    }
}
/// Dumps every memory returned by `all_memories_with_embeddings` to `out` as
/// JSON Lines, one [`CorpusEntry`] per line, and returns how many were written.
///
/// `out` is created, or truncated if it already exists.
///
/// # Errors
///
/// Returns an error if the store query fails, the file cannot be created, an
/// entry cannot be serialized, or writing/flushing the file fails.
pub fn extract_corpus(store: &KuzuStore, out: &Path) -> Result<usize> {
    let memories = store.all_memories_with_embeddings()?;
    // Buffer the output: writing each line straight to the file costs one
    // syscall per entry.
    let mut writer = std::io::BufWriter::new(std::fs::File::create(out)?);
    let mut count = 0;
    for m in &memories {
        let entry = CorpusEntry {
            id: m.id,
            content: m.content.clone(),
            memory_type: format!("{:?}", m.memory_type).to_lowercase(),
            created_at: m.created_at.to_rfc3339(),
            project_path: m.project_path.clone(),
        };
        writeln!(writer, "{}", serde_json::to_string(&entry)?)?;
        count += 1;
    }
    // Dropping a BufWriter discards flush errors, so flush explicitly.
    writer.flush()?;
    Ok(count)
}
// Unit tests for eval-set loading and the gate sweep.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
// Builds a prefixed-arm record with a fixed id and empty result lists; only the
// rank and similarity fields vary between fixtures (they are all `gate_sweep` reads).
fn record(id_rank: Option<usize>, raw_rank: Option<usize>, raw_sim: Option<f32>) -> QueryRecord {
QueryRecord {
query_id: "q".to_string(),
use_prefix: true,
ranked_ids: Vec::new(),
scores: Vec::new(),
first_relevant_rank: id_rank,
gold_raw_rank: raw_rank,
gold_raw_similarity: raw_sim,
}
}
// Two JSON objects on two lines become two queries; the optional `tags` key is
// picked up when present.
#[test]
fn load_eval_set_parses_one_object_per_line() {
let dir = std::env::temp_dir();
// Random file name so concurrent test runs do not collide.
let path = dir.join(format!("eval_set_{}.jsonl", Uuid::new_v4()));
let id_a = Uuid::new_v4();
let id_b = Uuid::new_v4();
let line1 = format!(
r#"{{"query_id":"q1","query":"kuzu choice","query_variant":"literal","seed_memory_id":"{id_a}","memory_type":"decision","relevant_memory_ids":["{id_a}"]}}"#
);
let line2 = format!(
r#"{{"query_id":"q2","query":"sync design","query_variant":"paraphrase","seed_memory_id":"{id_b}","memory_type":"architecture","relevant_memory_ids":["{id_b}","{id_a}"],"tags":["sync"]}}"#
);
let mut f = std::fs::File::create(&path).unwrap();
writeln!(f, "{line1}").unwrap();
writeln!(f, "{line2}").unwrap();
drop(f);
let queries = load_eval_set(&path).unwrap();
// Best-effort cleanup before asserting, so a failed assertion does not leak the file.
std::fs::remove_file(&path).ok();
assert_eq!(queries.len(), 2);
assert_eq!(queries[0].query_id, "q1");
assert_eq!(queries[0].seed_memory_id, id_a);
assert_eq!(queries[1].relevant_memory_ids.len(), 2);
assert_eq!(queries[1].tags, vec!["sync".to_string()]);
}
// Leading, trailing, and repeated blank lines are skipped rather than parsed.
#[test]
fn load_eval_set_tolerates_blank_lines() {
let dir = std::env::temp_dir();
let path = dir.join(format!("eval_blank_{}.jsonl", Uuid::new_v4()));
let id = Uuid::new_v4();
let line = format!(
r#"{{"query_id":"q1","query":"x","query_variant":"v","seed_memory_id":"{id}","memory_type":"semantic","relevant_memory_ids":["{id}"]}}"#
);
std::fs::write(&path, format!("\n{line}\n\n")).unwrap();
let queries = load_eval_set(&path).unwrap();
std::fs::remove_file(&path).ok();
assert_eq!(queries.len(), 1);
}
// The frontier covers the whole 0.40..=0.80 grid (41 points), and recall@3 never
// rises as the threshold goes up.
#[test]
fn gate_sweep_emits_full_grid_and_monotone_recall() {
let records = vec![
record(Some(1), Some(1), Some(0.85)),
record(Some(2), Some(2), Some(0.62)),
record(Some(4), Some(4), Some(0.55)),
record(None, None, None),
];
let report = gate_sweep(&records);
assert_eq!(report.frontier.len(), 41);
assert!((report.frontier.first().unwrap().threshold - 0.40).abs() < 1e-4);
assert!((report.frontier.last().unwrap().threshold - 0.80).abs() < 1e-4);
for w in report.frontier.windows(2) {
assert!(
w[0].recall_at_3 >= w[1].recall_at_3 - 1e-6,
"recall must not increase as the gate tightens"
);
}
}
// Spot-checks two thresholds against hand-computed recall over the four fixtures.
#[test]
fn gate_sweep_recall_reflects_raw_rank_and_similarity() {
let records = vec![
record(Some(1), Some(1), Some(0.85)),
record(Some(2), Some(2), Some(0.62)),
record(Some(4), Some(4), Some(0.55)),
record(None, None, None),
];
let report = gate_sweep(&records);
// Looks up the frontier point for a threshold, tolerating f32 rounding in the grid.
let at = |t: f32| {
report
.frontier
.iter()
.find(|p| (p.threshold - t).abs() < 1e-4)
.unwrap()
};
// At 0.50 the three records with a similarity all clear the gate, so recall
// depends only on rank: 1 of 4 at k=1, 2 of 4 at k=3, 3 of 4 at k=5.
let p050 = at(0.50);
assert!((p050.recall_at_1 - 0.25).abs() < 1e-6, "recall@1 was {}", p050.recall_at_1);
assert!((p050.recall_at_3 - 0.5).abs() < 1e-6, "recall@3 was {}", p050.recall_at_3);
assert!((p050.recall_at_5 - 0.75).abs() < 1e-6, "recall@5 was {}", p050.recall_at_5);
// At 0.70 only the 0.85 record clears the gate, and it is at rank 1.
let p070 = at(0.70);
assert!((p070.recall_at_1 - 0.25).abs() < 1e-6);
assert!((p070.recall_at_3 - 0.25).abs() < 1e-6);
assert!((p070.recall_at_5 - 0.25).abs() < 1e-6);
}
// A single record at rank 1 with similarity 0.85 is recalled at every grid
// threshold, so no threshold has higher recall@3 than the baseline and the
// baseline is kept.
#[test]
fn gate_sweep_keeps_baseline_when_nothing_beats_it() {
let records = vec![record(Some(1), Some(1), Some(0.85))];
let report = gate_sweep(&records);
assert!((report.baseline_threshold - 0.59).abs() < 1e-6);
assert!(!report.chosen_beats_baseline);
assert!((report.chosen_threshold - 0.59).abs() < 1e-6);
}
}