use std::collections::HashSet;
use std::io::Write;
use std::path::PathBuf;
use anyhow::{Context, Result};
use clap::{Parser, Subcommand};
use uuid::Uuid;
use second_brain_api::eval::caching_store::CachingStore;
use second_brain_api::eval::{self, GateSweepReport, QueryRecord};
use second_brain_core::embedding::Embedder;
use second_brain_core::kuzu_store::KuzuStore;
/// Machine identifier passed to `KuzuStore::open` whenever this tool opens a snapshot.
const EVAL_MACHINE_ID: &str = "sb-recall-eval";
/// Default value of the `--limit` argument of the `run-ab` subcommand.
const DEFAULT_LIMIT: usize = 10;
// Top-level command-line definition: the binary takes exactly one subcommand.
// NOTE: plain `//` comments are used here on purpose; clap's derive macros turn
// `///` doc comments into help text, which would change the program's output.
#[derive(Parser)]
#[command(name = "sb-recall-eval", about = "Recall A/B eval harness for second-brain")]
struct Cli {
// The subcommand selected on the command line.
#[command(subcommand)]
command: Command,
}
// Subcommands of the eval harness.
// NOTE: plain `//` comments are used here on purpose; clap's derive macros turn
// `///` doc comments into help text, which would change the program's output.
#[derive(Subcommand)]
enum Command {
// Opens a snapshot and writes its corpus entries to `out`.
ExtractCorpus {
// Path of the snapshot to open.
#[arg(long)]
snapshot: PathBuf,
// Destination path for the extracted corpus.
#[arg(long)]
out: PathBuf,
},
// Runs the "bare" and "prefixed" arms over an eval set and writes an aggregate
// report to `out`, plus per-arm records next to it (`out` with its extension
// replaced by `bare.jsonl` / `prefixed.jsonl`).
RunAb {
// Path of the snapshot to open.
#[arg(long)]
snapshot: PathBuf,
// Path of the eval set to load.
#[arg(long)]
eval: PathBuf,
// Destination path for the aggregate report.
#[arg(long)]
out: PathBuf,
// Limit forwarded to each arm run; defaults to `DEFAULT_LIMIT`.
#[arg(long, default_value_t = DEFAULT_LIMIT)]
limit: usize,
// Hidden from help. When non-zero, only the first `max_queries` queries of
// the eval set are used; 0 (the default) keeps the whole set.
#[arg(long, default_value_t = 0, hide = true)]
max_queries: usize,
},
// Runs a gate sweep over previously written records and writes the report to `out`.
// Exactly one of `--records` / `--ab` is expected: clap rejects both together,
// and the handler rejects neither.
GateSweep {
// Path of a JSONL records file to read directly.
#[arg(long, conflicts_with = "ab")]
records: Option<PathBuf>,
// Path of a `run-ab` report; records are read from the same path with its
// extension replaced by `prefixed.jsonl`.
#[arg(long, conflicts_with = "records")]
ab: Option<PathBuf>,
// Destination path for the gate sweep report.
#[arg(long)]
out: PathBuf,
},
}
fn main() -> Result<()> {
let cli = Cli::parse();
match cli.command {
Command::ExtractCorpus { snapshot, out } => extract_corpus(&snapshot, &out),
Command::RunAb {
snapshot,
eval,
out,
limit,
max_queries,
} => run_ab(&snapshot, &eval, &out, limit, max_queries),
Command::GateSweep { records, ab, out } => gate_sweep(records, ab, &out),
}
}
/// Opens the snapshot at `snapshot` as a `KuzuStore`, identifying this tool
/// with `EVAL_MACHINE_ID`.
///
/// # Errors
/// Returns the underlying open error, annotated with the snapshot path.
fn open_store(snapshot: &std::path::Path) -> Result<KuzuStore> {
    let machine_id = String::from(EVAL_MACHINE_ID);
    let opened = KuzuStore::open(snapshot, machine_id);
    opened.with_context(|| format!("opening snapshot at {}", snapshot.display()))
}
/// Handler for the `extract-corpus` subcommand.
///
/// Opens the snapshot, delegates to `eval::extract_corpus` to write the corpus
/// to `out`, and prints how many entries were written.
///
/// # Errors
/// Fails if the snapshot cannot be opened or the extraction itself fails.
fn extract_corpus(snapshot: &std::path::Path, out: &std::path::Path) -> Result<()> {
    let store = open_store(snapshot)?;
    let count = eval::extract_corpus(&store, out)?;
    let destination = out.display();
    println!("wrote {count} corpus entries to {}", destination);
    Ok(())
}
/// Handler for the `run-ab` subcommand.
///
/// Loads the eval set from `eval_path`, embeds every query once, then runs two
/// arms against the snapshot through a prewarmed `CachingStore`: the "bare" arm
/// (flag `false`) and the "prefixed" arm (flag `true`). Per-arm records are
/// written next to `out` (extension replaced by `bare.jsonl` /
/// `prefixed.jsonl`) and the aggregate report is written to `out` itself.
///
/// `limit` is forwarded to each arm run. A non-zero `max_queries` caps the
/// number of queries taken from the front of the eval set; 0 means no cap.
///
/// # Errors
/// Fails if the eval set, snapshot, or embedding model cannot be loaded, if
/// either arm fails, or if any output file cannot be written.
fn run_ab(
    snapshot: &std::path::Path,
    eval_path: &std::path::Path,
    out: &std::path::Path,
    limit: usize,
    max_queries: usize,
) -> Result<()> {
    let mut queries = eval::load_eval_set(eval_path)?;
    // `truncate` leaves a shorter set untouched, so only the "0 = no cap" case
    // needs an explicit guard.
    if max_queries != 0 {
        queries.truncate(max_queries);
    }

    let store = open_store(snapshot)?;
    let embedder = Embedder::new().context("loading embedding model")?;
    eprintln!("embedding {} queries", queries.len());
    let embedded = eval::embed_all_queries(&embedder, &queries)?;

    let cache = CachingStore::new(&store);
    eprintln!("prewarming relation cache");
    cache.prewarm().context("prewarming relation cache")?;

    // Both arms share the same embedded queries and the same cache.
    eprintln!("running bare arm");
    let bare = eval::run_arm_parallel(&cache, &embedded, false, limit)?;
    eprintln!("running prefixed arm");
    let prefixed = eval::run_arm_parallel(&cache, &embedded, true, limit)?;

    // Map each query id to the set of memory ids considered relevant for it.
    let mut relevant_sets: std::collections::HashMap<String, HashSet<Uuid>> =
        std::collections::HashMap::with_capacity(queries.len());
    for query in queries.iter() {
        let relevant: HashSet<Uuid> = query.relevant_memory_ids.iter().copied().collect();
        relevant_sets.insert(query.query_id.clone(), relevant);
    }
    let report = eval::aggregate(&bare, &prefixed, &relevant_sets);

    let bare_path = out.with_extension("bare.jsonl");
    let prefixed_path = out.with_extension("prefixed.jsonl");
    write_records(&bare_path, &bare)?;
    write_records(&prefixed_path, &prefixed)?;
    write_json(out, &report)?;
    println!("aggregate report written to {}", out.display());
    Ok(())
}
/// Handler for the `gate-sweep` subcommand.
///
/// Reads query records either from `records` directly or, when only `ab` is
/// given, from `ab` with its extension replaced by `prefixed.jsonl`. `records`
/// wins if both are present. The sweep report is written to `out` as JSON.
///
/// # Errors
/// Fails if neither `records` nor `ab` is supplied, if the records cannot be
/// read or parsed, or if the report cannot be written.
fn gate_sweep(
    records: Option<PathBuf>,
    ab: Option<PathBuf>,
    out: &std::path::Path,
) -> Result<()> {
    let records_path = records
        .or_else(|| ab.map(|report_path| report_path.with_extension("prefixed.jsonl")))
        .ok_or_else(|| anyhow::anyhow!("gate-sweep requires either --records or --ab"))?;
    let prefixed = read_records(&records_path)?;

    let report: GateSweepReport = eval::gate_sweep(&prefixed);
    write_json(out, &report)?;
    println!("gate sweep report written to {}", out.display());
    Ok(())
}
/// Reads the whole file at `path` and parses it as JSONL query records.
///
/// # Errors
/// Fails with the path attached if the file cannot be read, or with the
/// parser's error if any record is malformed.
fn read_records(path: &std::path::Path) -> Result<Vec<QueryRecord>> {
    std::fs::read_to_string(path)
        .with_context(|| format!("reading records from {}", path.display()))
        .and_then(|contents| parse_records(&contents))
}
/// Parses JSONL text into query records, one JSON document per line.
///
/// Lines are trimmed before parsing and lines that are empty after trimming
/// are skipped.
///
/// # Errors
/// Fails on the first line that is not a valid record. Each line is parsed on
/// its own, so the position reported by `serde_json` is always relative to that
/// single line; the 1-based line number within `text` is attached as context so
/// the offending record can actually be located.
fn parse_records(text: &str) -> Result<Vec<QueryRecord>> {
    text.lines()
        .enumerate()
        .map(|(index, line)| (index + 1, line.trim()))
        .filter(|(_, line)| !line.is_empty())
        .map(|(line_no, line)| {
            serde_json::from_str::<QueryRecord>(line)
                .with_context(|| format!("parsing record on line {line_no}"))
        })
        .collect()
}
fn write_records(path: &std::path::Path, records: &[QueryRecord]) -> Result<()> {
let mut file = std::fs::File::create(path)?;
for r in records {
writeln!(file, "{}", serde_json::to_string(r)?)?;
}
Ok(())
}
/// Serializes `value` as pretty-printed JSON and writes it to `path`,
/// replacing any existing file contents.
///
/// # Errors
/// Fails if `value` cannot be serialized or the file cannot be written; both
/// failures carry the target path as context so the caller can tell which
/// output file was affected.
fn write_json<T: serde::Serialize>(path: &std::path::Path, value: &T) -> Result<()> {
    let json = serde_json::to_string_pretty(value)
        .with_context(|| format!("serializing JSON for {}", path.display()))?;
    std::fs::write(path, json).with_context(|| format!("writing JSON to {}", path.display()))
}