use spool::domain::{
LifecycleCandidate, MatchedProject, MemoryLifecycleState, MemoryOrigin, MemoryRecord,
MemoryScope, MemorySourceKind, OutputFormat, RouteInput, TargetTool,
};
use spool::engine::selector;
use std::collections::HashSet;
use std::path::PathBuf;
/// One retrieval scenario: the routing input, the candidate records, and the
/// ids the selector is expected (or forbidden) to return.
struct EvalCase {
/// Label used in assertion messages and copied into `EvalResult::name`.
name: &'static str,
/// Task description copied into `RouteInput::task`.
task: &'static str,
/// Working directory, converted to a `PathBuf` for `RouteInput::cwd`.
cwd: &'static str,
/// File paths copied (as owned strings) into `RouteInput::files`.
files: Vec<&'static str>,
/// Matched project handed to the selector by reference, if any.
project: Option<MatchedProject>,
/// Candidate pool handed to the selector. The `String` appears to be the
/// record id (the fixtures reuse the same values in `expected_ids`) —
/// confirm against the selector's signature.
records: Vec<(String, MemoryRecord)>,
/// Ids counted as relevant when scoring the retrieved list.
expected_ids: Vec<&'static str>,
/// Ids that must not be retrieved; `compute_metrics` panics if one is.
negative_ids: Vec<&'static str>,
/// Cut-off: passed to the selector and used to truncate its results.
k: usize,
}
/// Metrics computed for one `EvalCase` over its first `k` retrieved ids.
struct EvalResult {
/// Name copied from the originating case.
name: &'static str,
/// Relevant retrieved ids divided by the number of ids retrieved
/// (the divisor is clamped to at least 1).
precision_at_k: f64,
/// Relevant retrieved ids divided by the number of expected ids;
/// 1.0 when the case expects nothing.
recall_at_k: f64,
/// Reciprocal of the 1-based rank of the first relevant id, or 0.0 when no
/// relevant id was retrieved.
mrr: f64,
/// Ids of the first `k` results, in the order the selector returned them.
retrieved_ids: Vec<String>,
}
/// Scores the selector output for `case` against its expected ids.
///
/// Only the first `case.k` entries of `results` are considered.
///
/// * Precision is taken over the number of ids actually retrieved (never less
///   than one, so an empty result list scores 0.0 rather than dividing by zero).
/// * Recall is taken over the number of distinct expected ids and is defined
///   as 1.0 when the case expects nothing.
/// * MRR is the reciprocal 1-based rank of the first relevant id, or 0.0 when
///   none was retrieved.
///
/// # Panics
///
/// Panics if any retrieved id is listed in `case.negative_ids`.
fn compute_metrics(case: &EvalCase, results: &[LifecycleCandidate]) -> EvalResult {
    let relevant: HashSet<&str> = case.expected_ids.iter().copied().collect();
    let forbidden: HashSet<&str> = case.negative_ids.iter().copied().collect();

    let mut retrieved_ids: Vec<String> = Vec::new();
    for candidate in results.iter().take(case.k) {
        retrieved_ids.push(candidate.record_id.clone());
    }

    // One flag per retrieved id, in rank order: is this id an expected one?
    let hit_mask: Vec<bool> = retrieved_ids
        .iter()
        .map(|id| relevant.contains(id.as_str()))
        .collect();
    let hits = hit_mask.iter().filter(|&&hit| hit).count();

    // `take(case.k)` already caps the list at `k`, so its length is the
    // precision denominator; clamp to 1 to avoid 0/0 on an empty list.
    let retrieved_count = retrieved_ids.len().max(1);
    let precision_at_k = hits as f64 / retrieved_count as f64;

    let recall_at_k = match relevant.len() {
        0 => 1.0,
        expected_count => hits as f64 / expected_count as f64,
    };

    let mrr = match hit_mask.iter().position(|&hit| hit) {
        Some(rank) => 1.0 / (rank as f64 + 1.0),
        None => 0.0,
    };

    // Report the highest-ranked forbidden id, if any slipped into the top-k.
    if let Some(id) = retrieved_ids
        .iter()
        .find(|id| forbidden.contains(id.as_str()))
    {
        panic!(
            "[{}] Negative ID '{}' appeared in results: {:?}",
            case.name, id, retrieved_ids
        );
    }

    EvalResult {
        name: case.name,
        precision_at_k,
        recall_at_k,
        mrr,
        retrieved_ids,
    }
}
/// Runs one eval case through `selector::select_lifecycle_candidates` and
/// scores the outcome with `compute_metrics`.
///
/// The routing input is built from the case's task, cwd and files, always
/// targeting `TargetTool::Codex` with `OutputFormat::Prompt`. An empty set and
/// `None` are supplied for the selector's last two parameters.
///
/// # Panics
///
/// Panics (via `compute_metrics`) if a negative id is retrieved.
fn run_lifecycle_eval(case: &EvalCase) -> EvalResult {
    let empty_set = HashSet::new();

    let request = RouteInput {
        task: case.task.to_owned(),
        cwd: PathBuf::from(case.cwd),
        files: case.files.iter().map(|path| String::from(*path)).collect(),
        target: TargetTool::Codex,
        format: OutputFormat::Prompt,
    };

    let ranked = selector::select_lifecycle_candidates(
        case.project.as_ref(),
        &case.records,
        &request,
        case.k,
        &empty_set,
        None,
    );

    compute_metrics(case, &ranked)
}
/// Builds a `MemoryRecord` fixture from the handful of fields the eval cases
/// vary.
///
/// The origin is always a manual source with reference `"eval"`. Every
/// optional field is `None` and every list field starts empty, so a case can
/// fill in only the signals it wants to exercise.
fn make_record(
    title: &str,
    summary: &str,
    memory_type: &str,
    scope: MemoryScope,
    project_id: Option<&str>,
    state: MemoryLifecycleState,
) -> MemoryRecord {
    let origin = MemoryOrigin {
        source_kind: MemorySourceKind::Manual,
        source_ref: String::from("eval"),
    };

    MemoryRecord {
        title: title.to_owned(),
        summary: summary.to_owned(),
        memory_type: memory_type.to_owned(),
        scope,
        state,
        origin,
        project_id: project_id.map(str::to_owned),
        user_id: None,
        sensitivity: None,
        entities: vec![],
        tags: vec![],
        triggers: vec![],
        related_files: vec![],
        related_records: vec![],
        supersedes: None,
        applies_to: vec![],
        valid_until: None,
    }
}
/// Builds a `MatchedProject` fixture whose name equals its `id` and whose
/// match reason is the fixed string `"eval"`.
fn make_project(id: &str) -> MatchedProject {
    let identifier = id.to_owned();
    MatchedProject {
        name: identifier.clone(),
        id: identifier,
        reason: String::from("eval"),
    }
}
fn build_eval_cases() -> Vec<EvalCase> {
vec![
{
let mut rec = make_record(
"Use SQLite for local storage",
"All local persistence uses SQLite via rusqlite",
"decision",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
rec.entities = vec!["SQLite".to_string(), "rusqlite".to_string()];
let distractor = make_record(
"Use PostgreSQL for production",
"Production database is PostgreSQL",
"decision",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
EvalCase {
name: "entity_match_sqlite",
task: "implement SQLite migration for local cache",
cwd: "/tmp/repo",
files: vec!["src/db.rs"],
project: None,
records: vec![
("rec-sqlite".to_string(), rec),
("rec-pg".to_string(), distractor),
],
expected_ids: vec!["rec-sqlite"],
negative_ids: vec![],
k: 5,
}
},
{
let mut rec = make_record(
"Database connection pooling",
"Use connection pool with max 5 connections",
"constraint",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
rec.related_files = vec!["src/db.rs".to_string(), "src/pool.rs".to_string()];
let unrelated = make_record(
"UI color scheme",
"Use dark mode by default",
"preference",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
EvalCase {
name: "related_files_match",
task: "refactor database layer",
cwd: "/tmp/repo",
files: vec!["src/db.rs"],
project: None,
records: vec![
("rec-pool".to_string(), rec),
("rec-ui".to_string(), unrelated),
],
expected_ids: vec!["rec-pool"],
negative_ids: vec![],
k: 5,
}
},
{
let mut rec = make_record(
"Zero-downtime deployment",
"Always use rolling deployment strategy",
"constraint",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
rec.tags = vec!["deployment".to_string(), "infrastructure".to_string()];
let noise = make_record(
"Code review process",
"All PRs need at least one approval",
"workflow",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
EvalCase {
name: "tags_match_deployment",
task: "fix deployment pipeline timeout",
cwd: "/tmp/repo",
files: vec!["infra/deploy.yml"],
project: None,
records: vec![
("rec-deploy".to_string(), rec),
("rec-review".to_string(), noise),
],
expected_ids: vec!["rec-deploy"],
negative_ids: vec![],
k: 5,
}
},
{
let mut rec = make_record(
"Security audit checklist",
"Run OWASP checks before any release",
"constraint",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
rec.triggers = vec!["release".to_string(), "security".to_string()];
let noise = make_record(
"Logging format",
"Use structured JSON logging",
"preference",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
EvalCase {
name: "triggers_exact_match",
task: "prepare release checklist",
cwd: "/tmp/repo",
files: vec![],
project: None,
records: vec![
("rec-security".to_string(), rec),
("rec-logging".to_string(), noise),
],
expected_ids: vec!["rec-security"],
negative_ids: vec![],
k: 5,
}
},
{
let mut rec_a = make_record(
"API versioning strategy",
"Use URL-based versioning for public APIs",
"decision",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
rec_a.tags = vec!["api".to_string()];
rec_a.related_records = vec!["rec-B".to_string()];
let mut rec_b = make_record(
"Deprecation policy",
"Deprecated endpoints stay live for 6 months",
"constraint",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
rec_b.scope = MemoryScope::Project;
rec_b.project_id = Some("other-project".to_string());
EvalCase {
name: "relation_expansion_pull_in",
task: "design new api endpoint",
cwd: "/tmp/repo",
files: vec!["src/api/routes.rs"],
project: Some(make_project("spool")),
records: vec![("rec-A".to_string(), rec_a), ("rec-B".to_string(), rec_b)],
expected_ids: vec!["rec-A", "rec-B"],
negative_ids: vec![],
k: 5,
}
},
{
let constraint = make_record(
"Never skip CI",
"All merges must pass CI pipeline",
"constraint",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
let session = make_record(
"CI discussion notes",
"Discussed CI improvements in standup",
"session",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
EvalCase {
name: "memory_type_priority",
task: "fix CI pipeline",
cwd: "/tmp/repo",
files: vec![],
project: None,
records: vec![
("rec-constraint".to_string(), constraint),
("rec-session".to_string(), session),
],
expected_ids: vec!["rec-constraint"],
negative_ids: vec![],
k: 1,
}
},
{
let mut matching = make_record(
"Spool coding style",
"Follow Rust 2021 edition idioms",
"preference",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
matching.applies_to = vec!["spool".to_string()];
let mut wrong_project = make_record(
"Other project style",
"Follow Python PEP8",
"preference",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
wrong_project.applies_to = vec!["other-project".to_string()];
EvalCase {
name: "applies_to_filter",
task: "refactor coding style",
cwd: "/tmp/repo",
files: vec![],
project: Some(make_project("spool")),
records: vec![
("rec-spool".to_string(), matching),
("rec-other".to_string(), wrong_project),
],
expected_ids: vec!["rec-spool"],
negative_ids: vec![],
k: 5,
}
},
{
let active = make_record(
"Active constraint",
"Must validate all inputs",
"constraint",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
let archived = make_record(
"Old constraint",
"Must validate all inputs (deprecated)",
"constraint",
MemoryScope::User,
None,
MemoryLifecycleState::Archived,
);
EvalCase {
name: "archived_excluded",
task: "add input validation",
cwd: "/tmp/repo",
files: vec!["src/validation.rs"],
project: None,
records: vec![
("rec-active".to_string(), active),
("rec-archived".to_string(), archived),
],
expected_ids: vec!["rec-active"],
negative_ids: vec!["rec-archived"],
k: 5,
}
},
{
let matching = make_record(
"Spool architecture",
"Use layered architecture with shared core",
"decision",
MemoryScope::Project,
Some("spool"),
MemoryLifecycleState::Accepted,
);
let wrong = make_record(
"Other architecture",
"Use microservices",
"decision",
MemoryScope::Project,
Some("other-project"),
MemoryLifecycleState::Accepted,
);
EvalCase {
name: "project_scope_filter",
task: "review architecture decisions",
cwd: "/tmp/repo",
files: vec![],
project: Some(make_project("spool")),
records: vec![
("rec-spool-arch".to_string(), matching),
("rec-other-arch".to_string(), wrong),
],
expected_ids: vec!["rec-spool-arch"],
negative_ids: vec!["rec-other-arch"],
k: 5,
}
},
{
let canonical = make_record(
"Error handling policy",
"Always use Result types, never panic",
"constraint",
MemoryScope::User,
None,
MemoryLifecycleState::Canonical,
);
let accepted = make_record(
"Error handling notes",
"Discussed error handling approaches",
"constraint",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
EvalCase {
name: "canonical_boost",
task: "implement error handling",
cwd: "/tmp/repo",
files: vec!["src/error.rs"],
project: None,
records: vec![
("rec-canonical".to_string(), canonical),
("rec-accepted".to_string(), accepted),
],
expected_ids: vec!["rec-canonical"],
negative_ids: vec![],
k: 1,
}
},
{
let mut rich = make_record(
"Database migration strategy",
"Use incremental migrations with rollback support",
"decision",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
rich.entities = vec!["SQLite".to_string(), "migration".to_string()];
rich.tags = vec!["database".to_string()];
rich.related_files = vec!["src/db/migrate.rs".to_string()];
rich.triggers = vec!["migration".to_string()];
let weak = make_record(
"General notes",
"Some notes about database things",
"pattern",
MemoryScope::User,
None,
MemoryLifecycleState::Accepted,
);
EvalCase {
name: "compound_signals",
task: "run database migration",
cwd: "/tmp/repo",
files: vec!["src/db/migrate.rs"],
project: None,
records: vec![
("rec-rich".to_string(), rich),
("rec-weak".to_string(), weak),
],
expected_ids: vec!["rec-rich"],
negative_ids: vec![],
k: 5,
}
},
{
let candidate = make_record(
"Testing strategy",
"Prefer integration tests over unit tests for IO-heavy code",
"preference",
MemoryScope::User,
None,
MemoryLifecycleState::Candidate,
);
EvalCase {
name: "candidate_retrievable",
task: "write tests for IO module",
cwd: "/tmp/repo",
files: vec!["src/io.rs"],
project: None,
records: vec![("rec-candidate".to_string(), candidate)],
expected_ids: vec!["rec-candidate"],
negative_ids: vec![],
k: 5,
}
},
]
}
/// Runs every eval case, prints a per-case metrics table followed by the
/// averages, and fails when average precision or recall falls below 0.6 or
/// average MRR falls below 0.7.
#[test]
fn retrieval_eval_lifecycle_precision() {
    let outcomes: Vec<EvalResult> = build_eval_cases()
        .iter()
        .map(run_lifecycle_eval)
        .collect();

    let heavy_rule = "=".repeat(70);
    let light_rule = "-".repeat(70);

    println!("\n{}", heavy_rule);
    println!(" RETRIEVAL EVAL RESULTS");
    println!("{}", heavy_rule);
    println!("{:<30} {:>10} {:>10} {:>10}", "Case", "P@K", "R@K", "MRR");
    println!("{}", light_rule);

    for outcome in &outcomes {
        println!(
            "{:<30} {:>10.3} {:>10.3} {:>10.3}",
            outcome.name, outcome.precision_at_k, outcome.recall_at_k, outcome.mrr
        );
    }

    // Accumulate left-to-right from 0.0 so the sums match a plain `+=` loop.
    let (sum_precision, sum_recall, sum_mrr) = outcomes.iter().fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(precision, recall, mrr), outcome| {
            (
                precision + outcome.precision_at_k,
                recall + outcome.recall_at_k,
                mrr + outcome.mrr,
            )
        },
    );

    let case_count = outcomes.len() as f64;
    let avg_precision = sum_precision / case_count;
    let avg_recall = sum_recall / case_count;
    let avg_mrr = sum_mrr / case_count;

    println!("{}", light_rule);
    println!(
        "{:<30} {:>10.3} {:>10.3} {:>10.3}",
        "AVERAGE", avg_precision, avg_recall, avg_mrr
    );
    println!("{}", heavy_rule);

    assert!(
        avg_precision >= 0.6,
        "Average precision@K = {:.3} is below 0.6 threshold",
        avg_precision
    );
    assert!(
        avg_recall >= 0.6,
        "Average recall@K = {:.3} is below 0.6 threshold",
        avg_recall
    );
    assert!(
        avg_mrr >= 0.7,
        "Average MRR = {:.3} is below 0.7 threshold",
        avg_mrr
    );
}
/// Checks, case by case, that every expected id shows up among the top-`k`
/// retrieved ids, failing on the first expected id that is missing.
#[test]
fn retrieval_eval_per_case_assertions() {
    for case in &build_eval_cases() {
        let outcome = run_lifecycle_eval(case);
        let found: HashSet<&str> = outcome.retrieved_ids.iter().map(String::as_str).collect();

        let first_missing = case
            .expected_ids
            .iter()
            .copied()
            .find(|id| !found.contains(*id));

        if let Some(expected) = first_missing {
            panic!(
                "[{}] Expected '{}' in top-{} results, got: {:?}",
                case.name, expected, case.k, outcome.retrieved_ids
            );
        }
    }
}