use std::collections::BTreeMap;
use std::time::{Duration, Instant};
use difflore_core::CoreError;
use difflore_core::context::eval as golden;
use difflore_core::context::retrieval::{self, RuleSearchRetrievalOptions};
use difflore_core::context::rule_source::RuleDocument;
use difflore_core::context::{index_db, rule_source};
use crate::runtime::CommandContext;
use crate::style::{self, sym};
/// Lowercase ASCII stop-words that `self_recall_query` skips when it picks
/// the significant words of a rule body.
const STOP: &[&str] = &[
"the", "a", "an", "and", "or", "of", "to", "for", "in", "on", "at", "by", "with", "when",
"use", "using", "as", "is", "are", "be", "this", "that", "from", "into", "do", "not", "should",
"must", "via", "than", "then", "but", "if", "else",
];
/// Sample count `handle_eval` requests when the caller does not pass one.
const DEFAULT_SAMPLES: usize = 20;
/// One self-recall probe: a query derived from a rule plus the id that rule
/// is expected to be retrieved under.
pub(crate) struct SelfRecallSample {
/// Id of the rule the query was derived from; a retrieval hit counts only
/// when it carries this id.
pub skill_id: String,
/// Query text built from the rule's content by `self_recall_query`.
pub query: String,
/// Language copied from the rule, if any; used to bucket per-language stats.
pub language: Option<String>,
}
/// Aggregated result of a self-recall run, filled in by `measure_self_recall`.
#[derive(Default)]
pub(crate) struct SelfRecallReport {
/// Number of samples attempted, including those whose retrieval call failed.
pub tested: usize,
/// Samples whose own rule was the first returned hit.
pub hits_at_1: usize,
/// Samples whose own rule appeared anywhere in the returned hits.
pub hits_at_5: usize,
/// Sum of `1 / rank` (1-based) over all samples that were hit.
pub reciprocal_rank_sum: f64,
/// Per-language tallies keyed by language (`"(unknown)"` when a sample has
/// none). Tuple layout: `(samples, hits@1, hits@5, reciprocal-rank sum)`.
pub per_lang: BTreeMap<String, (usize, usize, usize, f64)>,
/// Samples whose retrieval call returned an error.
pub retrieval_errors: usize,
/// Retrieval errors classified as rate limiting by `is_rate_limited_error`.
pub rate_limited_errors: usize,
/// Retrieval errors classified as embed-cap errors by `is_embed_cap_error`.
pub embed_cap_errors: usize,
/// Duration of every retrieval call in milliseconds; recorded for both
/// successful and failed calls.
latency_ms: Vec<u64>,
}
impl SelfRecallReport {
    /// Share of tested samples (0–100) whose own rule appeared in the hits.
    pub fn at5_pct(&self) -> f64 {
        pct(self.hits_at_5, self.tested)
    }

    /// Share of tested samples (0–100) whose own rule was the first hit.
    pub fn at1_pct(&self) -> f64 {
        pct(self.hits_at_1, self.tested)
    }

    /// Mean reciprocal rank over all tested samples; `0.0` when nothing was tested.
    pub fn mrr(&self) -> f64 {
        match self.tested {
            0 => 0.0,
            tested => self.reciprocal_rank_sum / tested as f64,
        }
    }

    /// Integer mean of the recorded latencies in milliseconds; `0` when none were recorded.
    pub fn avg_latency_ms(&self) -> u64 {
        let total: u64 = self.latency_ms.iter().sum();
        match self.latency_ms.len() as u64 {
            0 => 0,
            count => total / count,
        }
    }

    /// Largest recorded latency in milliseconds; `0` when none were recorded.
    pub fn max_latency_ms(&self) -> u64 {
        self.latency_ms.iter().max().copied().unwrap_or(0)
    }

    /// 95th-percentile latency in milliseconds, as computed by `percentile_latency_ms`.
    pub fn p95_latency_ms(&self) -> u64 {
        percentile_latency_ms(&self.latency_ms, 0.95)
    }

    /// Number of latency measurements recorded.
    pub const fn latency_samples(&self) -> usize {
        self.latency_ms.len()
    }
}
/// Returns `hits / total` as a percentage (0–100), or `0.0` when `total` is zero.
fn pct(hits: usize, total: usize) -> f64 {
    match total {
        0 => 0.0,
        _ => (hits as f64 / total as f64) * 100.0,
    }
}
/// Builds a short search query from a rule's own text.
///
/// Everything up to and including the first blank line (`"\n\n"`) is dropped
/// when one exists; otherwise the whole `content` is used. From the remaining
/// text, the first eight whitespace-separated tokens that are neither empty
/// after stripping surrounding non-alphanumeric characters nor listed in
/// `STOP` (compared ASCII case-insensitively) are joined with single spaces.
///
/// Tokens are emitted as they appear in the text: the stripped form is used
/// only for the stop-word test, so surrounding punctuation is preserved.
pub(crate) fn self_recall_query(content: &str) -> String {
    let body = match content.split_once("\n\n") {
        Some((_header, rest)) => rest,
        None => content,
    };
    body.split_whitespace()
        .filter(|token| {
            let core = token.trim_matches(|c: char| !c.is_alphanumeric());
            // STOP entries are lowercase ASCII, so an ASCII case-insensitive
            // comparison matches exactly what lowercasing `core` would match.
            !core.is_empty() && !STOP.iter().any(|stop| stop.eq_ignore_ascii_case(core))
        })
        .take(8)
        .collect::<Vec<&str>>()
        .join(" ")
}
/// Picks up to `n_target` rules, evenly spread over `rules`, and turns each
/// into a [`SelfRecallSample`].
///
/// The selection is deterministic: the first candidates are evenly spaced
/// indices from the first rule to the last, followed by every remaining
/// index in ascending order as backfill. Rules whose `self_recall_query` is
/// empty are skipped, so the backfill keeps the result at
/// `min(n_target, rules.len())` samples whenever enough usable rules exist.
///
/// Returns an empty vector when `rules` is empty or `n_target` is zero.
pub(crate) fn build_samples(rules: &[RuleDocument], n_target: usize) -> Vec<SelfRecallSample> {
    if rules.is_empty() || n_target == 0 {
        return Vec::new();
    }
    let sample_count = n_target.min(rules.len());
    let last = rules.len() - 1;

    // A visited bitmap gives O(1) membership checks; scanning the index list
    // for every candidate would be quadratic in the corpus size.
    let mut visited = vec![false; rules.len()];
    let mut order = Vec::with_capacity(rules.len());
    for i in 0..sample_count {
        let index = if sample_count == 1 {
            0
        } else {
            i * last / (sample_count - 1)
        };
        if !visited[index] {
            visited[index] = true;
            order.push(index);
        }
    }
    // Backfill candidates: every index the stride did not pick, ascending.
    order.extend((0..rules.len()).filter(|&index| !visited[index]));

    order
        .into_iter()
        .filter_map(|index| {
            let rule = &rules[index];
            let query = self_recall_query(&rule.content);
            (!query.is_empty()).then(|| SelfRecallSample {
                skill_id: rule.skill_id.clone(),
                query,
                language: rule.language.clone(),
            })
        })
        .take(sample_count)
        .collect()
}
/// Runs every sample's query through `retrieval::retrieve_rules_for_search`
/// (top 5) and tallies how often the sample's own rule comes back.
///
/// Each sample is counted in `tested` and in its language bucket before the
/// retrieval call, so a failed call counts as a miss. The call's duration is
/// recorded whether it succeeds or fails. Failures also increment
/// `retrieval_errors` and, when the classifiers match, `embed_cap_errors` /
/// `rate_limited_errors`.
pub(crate) async fn measure_self_recall(
    index_pool: &difflore_core::SqlitePool,
    samples: &[SelfRecallSample],
    embedding_timeout: Option<Duration>,
) -> SelfRecallReport {
    let mut report = SelfRecallReport::default();
    for sample in samples {
        let language = sample.language.as_deref().unwrap_or("(unknown)").to_owned();
        // Tuple layout: (samples, hits@1, hits@5, reciprocal-rank sum).
        let lang_stats = report.per_lang.entry(language).or_insert((0, 0, 0, 0.0));
        lang_stats.0 += 1;
        report.tested += 1;

        let timer = Instant::now();
        let options = RuleSearchRetrievalOptions {
            query: &sample.query,
            lexical_query: &sample.query,
            top_k: 5,
            confidence_map: None,
            age_days_map: None,
            effectiveness_map: None,
            target_scope: None,
            repo_scopes: &[],
            ann_enabled: false,
            local_query_embedding: false,
            embedding_timeout,
            cold_start_retry: false,
            adaptive_prune: false,
        };
        let outcome = retrieval::retrieve_rules_for_search(index_pool, options).await;
        // Latency is recorded for failed calls as well as successful ones.
        report.latency_ms.push(duration_ms(timer.elapsed()));

        let hits = match outcome {
            Ok(hits) => hits,
            Err(err) => {
                report.retrieval_errors += 1;
                report.embed_cap_errors += usize::from(is_embed_cap_error(&err));
                report.rate_limited_errors += usize::from(is_rate_limited_error(&err));
                continue;
            }
        };

        let found_at = hits.iter().position(|hit| hit.skill_id == sample.skill_id);
        if let Some(zero_based) = found_at {
            let reciprocal_rank = 1.0 / (zero_based as f64 + 1.0);
            report.hits_at_5 += 1;
            report.reciprocal_rank_sum += reciprocal_rank;
            lang_stats.2 += 1;
            lang_stats.3 += reciprocal_rank;
            if zero_based == 0 {
                report.hits_at_1 += 1;
                lang_stats.1 += 1;
            }
        }
    }
    report
}
/// Converts a duration to whole milliseconds, saturating at `u64::MAX` when
/// the value does not fit in a `u64`.
fn duration_ms(duration: Duration) -> u64 {
    match u64::try_from(duration.as_millis()) {
        Ok(millis) => millis,
        Err(_) => u64::MAX,
    }
}
/// Returns the value at the given percentile of `samples`, or `0` when
/// `samples` is empty.
///
/// `percentile` is a fraction that is clamped to `0.0..=1.0`. The samples are
/// sorted ascending and the element at index `ceil((len - 1) * percentile)`
/// is returned, so the result is always one of the input values.
fn percentile_latency_ms(samples: &[u64], percentile: f64) -> u64 {
    match samples.len().checked_sub(1) {
        None => 0,
        Some(last_index) => {
            let mut ordered = Vec::from(samples);
            ordered.sort_unstable();
            let fraction = percentile.clamp(0.0, 1.0);
            let rank = (last_index as f64 * fraction).ceil() as usize;
            ordered[rank]
        }
    }
}
/// Returns `true` when `err` is the `CoreError::EmbedCapReached` variant.
const fn is_embed_cap_error(err: &CoreError) -> bool {
    match err {
        CoreError::EmbedCapReached { .. } => true,
        _ => false,
    }
}
fn is_rate_limited_error(err: &CoreError) -> bool {
let text = difflore_core::error::error_chain_text(err).to_ascii_lowercase();
text.contains("rate limit") || text.contains("too many requests") || text.contains("429")
}
/// Status symbol for a self-recall@5 percentage: `sym::OK` at 80 or above,
/// `sym::WARN` at 50 or above, otherwise `sym::ERR`.
pub(crate) fn at5_mark(pct: f64) -> &'static str {
    match pct {
        p if p >= 80.0 => sym::OK,
        p if p >= 50.0 => sym::WARN,
        _ => sym::ERR,
    }
}
/// Status symbol for a self-recall@1 percentage: `sym::OK` at 50 or above,
/// `sym::WARN` at 25 or above, otherwise `sym::ERR`.
pub(crate) fn at1_mark(pct: f64) -> &'static str {
    match pct {
        p if p >= 50.0 => sym::OK,
        p if p >= 25.0 => sym::WARN,
        _ => sym::ERR,
    }
}
/// Status symbol for a mean reciprocal rank: `sym::OK` at 0.7 or above,
/// `sym::WARN` at 0.5 or above, otherwise `sym::ERR`.
pub(crate) fn mrr_mark(mrr: f64) -> &'static str {
    match mrr {
        m if m >= 0.7 => sym::OK,
        m if m >= 0.5 => sym::WARN,
        _ => sym::ERR,
    }
}
/// Runs the self-recall evaluation and prints the result as text or JSON.
///
/// Steps: load the rules from `ctx.db`, bail out via `emit_too_few` when
/// there are fewer than five, pick the samples, index all rules into a fresh
/// database inside a temporary directory, measure recall with
/// `measure_self_recall`, then emit the report.
///
/// `samples` is the requested sample count; it defaults to `DEFAULT_SAMPLES`
/// and is clamped to `1..=200`. Setup failures are reported through
/// `style::report_error` and end the command without a report.
pub(crate) async fn handle_eval(ctx: &CommandContext, samples: Option<usize>, json: bool) {
    let started = Instant::now();
    let n = samples.unwrap_or(DEFAULT_SAMPLES).clamp(1, 200);
    let rules = match rule_source::load_rules_from_db(&ctx.db).await {
        Ok(r) => r,
        Err(e) => {
            style::report_error("could not load rules for eval", &e.to_string(), &[]);
            return;
        }
    };
    if rules.len() < 5 {
        emit_too_few(rules.len(), json);
        return;
    }
    // Choose the samples before announcing the run: `build_samples` caps the
    // count at the corpus size and skips rules with an empty query, so the
    // requested `n` can overstate how many queries will actually be measured.
    let sample_set = build_samples(&rules, n);
    if !json {
        let planned = sample_set.len();
        eprintln!(
            " {} measuring recall over {} rules ({} sample{})…",
            style::pewter(sym::BULLET),
            rules.len(),
            planned,
            if planned == 1 { "" } else { "s" },
        );
    }
    // `tmp` must outlive `index_pool`: the index database lives inside it.
    let tmp = match tempfile::tempdir() {
        Ok(t) => t,
        Err(e) => {
            style::report_error("could not create eval index", &e.to_string(), &[]);
            return;
        }
    };
    let index_pool = match index_db::open_index_pool_at(&tmp.path().join("eval.db")).await {
        Ok(p) => p,
        Err(e) => {
            style::report_error("could not open eval index", &e.to_string(), &[]);
            return;
        }
    };
    if let Err(e) = index_db::upsert_rule_chunks_isolated(&index_pool, &rules).await {
        style::report_error("could not build eval index", &e.to_string(), &[]);
        return;
    }
    let report = measure_self_recall(&index_pool, &sample_set, None).await;
    if json {
        emit_json(&report, rules.len(), started.elapsed());
    } else {
        emit_text(&report, rules.len(), started.elapsed());
    }
}
/// Tells the user the corpus is too small to evaluate: a JSON object on
/// stdout when `json` is set, otherwise a one-line hint naming commands that
/// add rules.
fn emit_too_few(count: usize, json: bool) {
    if !json {
        println!(
            " {} only {count} rule(s) — need ≥5 to measure recall. Try {} or {}.",
            style::pewter(sym::WARN),
            style::cmd("difflore try"),
            style::cmd("difflore import-reviews"),
        );
        return;
    }
    let payload = serde_json::json!({ "ok": false, "reason": "too_few_rules", "rules": count });
    println!("{}", payload);
}
/// Prints the human-readable self-recall report to stdout.
///
/// `corpus` is the total number of rules and `elapsed` the wall-clock time of
/// the whole command; both appear only in the footer line.
fn emit_text(report: &SelfRecallReport, corpus: usize, elapsed: Duration) {
let at5 = report.at5_pct();
let at1 = report.at1_pct();
let mrr = report.mrr();
// Header: command name plus a reminder of what this metric does and does not measure.
println!();
println!(
" {} {}",
style::cmd("difflore eval"),
style::pewter(
"· self-recall sanity check · local lexical (SHA1) · the reranked search path"
),
);
println!(
" {}",
style::pewter(
"query = the rule's own text → an optimistic upper bound, NOT real-world recall"
),
);
println!();
// Headline metrics, each prefixed with its threshold-based status symbol.
println!(
" {} self-recall@5 {}/{} ({:.0}%)",
style::pewter(at5_mark(at5)),
report.hits_at_5,
report.tested,
at5,
);
println!(
" {} self-recall@1 {}/{} ({:.0}%)",
style::pewter(at1_mark(at1)),
report.hits_at_1,
report.tested,
at1,
);
println!(
" {} MRR {:.3}",
style::pewter(mrr_mark(mrr)),
mrr,
);
// Per-query retrieval latency summary.
println!(
" {} latency avg {} ms · p95 {} ms · max {} ms ({} query sample{})",
style::pewter(sym::BULLET),
report.avg_latency_ms(),
report.p95_latency_ms(),
report.max_latency_ms(),
report.latency_samples(),
if report.latency_samples() == 1 {
""
} else {
"s"
},
);
// Error counters; the line is marked WARN as soon as any retrieval failed.
println!(
" {} retrieval errs {} · rate-limit {} · embed-cap {}",
style::pewter(if report.retrieval_errors == 0 {
sym::OK
} else {
sym::WARN
}),
report.retrieval_errors,
report.rate_limited_errors,
report.embed_cap_errors,
);
// Language breakdown (top four plus a folded "other" row); only shown when
// there are at least two rows to compare.
let by_lang = top_languages(report, 4);
if by_lang.len() >= 2 {
println!();
println!(" {}", style::pewter("by language:"));
for (lang, (n, h1, h5, rr)) in by_lang {
let lang_mrr = if n == 0 { 0.0 } else { rr / n as f64 };
println!(
" {} @1 {}/{} · @5 {}/{} · MRR {:.2}",
style::pewter(&lang),
h1,
n,
h5,
n,
lang_mrr,
);
}
}
// Footer: corpus size, sample count, total elapsed time and a caveat.
println!();
println!(
" {}",
style::pewter(&format!(
"{corpus} rules · {} sampled · {} ms · same rerank path recall/fix/MCP/hook use",
report.tested,
elapsed.as_millis(),
)),
);
println!(
" {}",
style::pewter("real-world paraphrase recall needs separate task-query evaluation"),
);
// Below the @5 OK threshold, point the user at the embeddings setup commands.
if at5 < 80.0 {
println!(
" {} low @5 — semantic embeddings usually lift ranking: {} or {}",
style::pewter(sym::TIP),
style::cmd("difflore cloud login"),
style::cmd("difflore embeddings setup"),
);
}
}
/// Prints the self-recall report to stdout as a single JSON object.
///
/// `corpus` is the total number of rules and `elapsed` the wall-clock time of
/// the whole command. `by_language` holds one entry per language bucket with
/// its sample count, hit counts and mean reciprocal rank.
fn emit_json(report: &SelfRecallReport, corpus: usize, elapsed: Duration) {
    let mut by_language = serde_json::Map::new();
    for (language, &(n, at1, at5, rr_sum)) in &report.per_lang {
        let mrr = if n == 0 { 0.0 } else { rr_sum / n as f64 };
        by_language.insert(
            language.clone(),
            serde_json::json!({ "n": n, "at1": at1, "at5": at5, "mrr": mrr }),
        );
    }
    let payload = serde_json::json!({
        "ok": true,
        "mode": "sha1",
        "metric": "self-recall",
        "real_world_recall_note": "requires separate task-query evaluation",
        "path": "reranked_search",
        "corpus_rules": corpus,
        "samples": report.tested,
        "at1": report.hits_at_1,
        "at5": report.hits_at_5,
        "at5_pct": report.at5_pct(),
        "at1_pct": report.at1_pct(),
        "mrr": report.mrr(),
        "elapsed_ms": elapsed.as_millis(),
        "latency": {
            "samples": report.latency_samples(),
            "avg_ms": report.avg_latency_ms(),
            "p95_ms": report.p95_latency_ms(),
            "max_ms": report.max_latency_ms(),
        },
        "retrieval_errors": report.retrieval_errors,
        "rate_limited_errors": report.rate_limited_errors,
        "embed_cap_errors": report.embed_cap_errors,
        "by_language": by_language,
    });
    println!("{}", payload);
}
/// Returns the per-language tallies ordered by sample count (descending),
/// ties broken by language name (ascending).
///
/// When there are more than `limit` languages, only the first `limit` are
/// kept and the remainder is summed field by field into one trailing
/// `"other"` entry.
fn top_languages(
    report: &SelfRecallReport,
    limit: usize,
) -> Vec<(String, (usize, usize, usize, f64))> {
    let mut ranked: Vec<(String, (usize, usize, usize, f64))> = report
        .per_lang
        .iter()
        .map(|(lang, stats)| (lang.clone(), *stats))
        .collect();
    ranked.sort_by(|lhs, rhs| rhs.1.0.cmp(&lhs.1.0).then_with(|| lhs.0.cmp(&rhs.0)));

    if ranked.len() > limit {
        let tail = ranked.split_off(limit);
        let mut other = (0, 0, 0, 0.0);
        for (_, (n, at1, at5, rr_sum)) in tail {
            other = (other.0 + n, other.1 + at1, other.2 + at5, other.3 + rr_sum);
        }
        ranked.push(("other".to_owned(), other));
    }
    ranked
}
/// Runs the built-in golden-case evaluation and prints the result as text or JSON.
///
/// The fixture `golden::GOLDEN_SMOKE_FIXTURE` is parsed, its rules are
/// indexed into a fresh database inside a temporary directory, and the cases
/// are scored with `golden::score_golden_cases`. Any failure along the way is
/// reported through `style::report_error` and ends the command.
pub(crate) async fn handle_golden_eval(json: bool) {
    // One place for the "report and stop" reaction shared by every step.
    let fail = |what: &'static str, err: &dyn std::fmt::Display| {
        style::report_error(what, &err.to_string(), &[]);
    };

    let fixture = match golden::parse_golden_fixture(golden::GOLDEN_SMOKE_FIXTURE) {
        Ok(parsed) => parsed,
        Err(e) => return fail("could not parse golden fixture", &e),
    };
    let docs = golden::golden_rules_to_documents(&fixture);

    // `scratch` stays alive until the end of the function: the index
    // database is created inside it.
    let scratch = match tempfile::tempdir() {
        Ok(dir) => dir,
        Err(e) => return fail("could not create eval index", &e),
    };
    let index_pool = match index_db::open_index_pool_at(&scratch.path().join("golden.db")).await {
        Ok(pool) => pool,
        Err(e) => return fail("could not open eval index", &e),
    };
    if let Err(e) = index_db::upsert_rule_chunks_isolated(&index_pool, &docs).await {
        return fail("could not build eval index", &e);
    }

    // Retrieve at least as many hits as there are fixture rules.
    let top_k = fixture.rules.len().max(golden::GOLDEN_K);
    let report = match golden::score_golden_cases(&index_pool, &fixture, top_k).await {
        Ok(scored) => scored,
        Err(e) => return fail("golden eval failed", &e),
    };

    if json {
        emit_golden_json(&report);
    } else {
        emit_golden_text(&report);
    }
}
/// Maps a pass/fail flag to `sym::OK` or `sym::ERR`.
const fn golden_mark(ok: bool) -> &'static str {
    match ok {
        true => sym::OK,
        false => sym::ERR,
    }
}
/// Prints the human-readable golden-case report to stdout: aggregate
/// metrics first, then one line per case.
fn emit_golden_text(report: &golden::GoldenReport) {
// Pass criteria for the marked lines: mean recall of at least 0.8, no
// forbidden hits in positive cases, and every negative case clean.
let recall_ok = report.mean_recall_at_k >= 0.8;
let forbid_ok = report.positive_forbidden_hits == 0;
let abstain_ok = report.negative_clean == report.negative_cases;
println!();
println!(
" {} {}",
style::cmd("difflore eval --golden"),
style::pewter(
"· golden-case precision/recall · local lexical (SHA1) · the reranked search path"
),
);
println!();
println!(
" {} recall@{} {:.0}% (mean over {} positive case{})",
style::pewter(golden_mark(recall_ok)),
report.k,
report.mean_recall_at_k * 100.0,
report.positive_cases,
if report.positive_cases == 1 { "" } else { "s" },
);
// Precision and MRR are informational: they carry a bullet, not a pass/fail mark.
println!(
" {} precision@{} {:.0}%",
style::pewter(sym::BULLET),
report.k,
report.mean_precision_at_k * 100.0,
);
println!(
" {} MRR {:.3}",
style::pewter(sym::BULLET),
report.mean_reciprocal_rank,
);
println!(
" {} forbidden leak {} (in top-{} of positive cases — must be 0)",
style::pewter(golden_mark(forbid_ok)),
report.positive_forbidden_hits,
report.k,
);
println!(
" {} abstention {}/{} doc-only case{} recalled nothing forbidden",
style::pewter(golden_mark(abstain_ok)),
report.negative_clean,
report.negative_cases,
if report.negative_cases == 1 { "" } else { "s" },
);
// The strict-file line is printed only when the report counted any strict-file checks.
if report.strict_file_total > 0 {
println!(
" {} strict-file {}/{} recalled rules matched the edited file's globs",
style::pewter(golden_mark(
report.strict_file_correct == report.strict_file_total
)),
report.strict_file_correct,
report.strict_file_total,
);
}
println!();
println!(" {}", style::pewter("per case:"));
for case in &report.cases {
// A dash stands in for the rank when no relevant rule was retrieved.
let rank = case
.first_relevant_rank
.map_or_else(|| "—".to_owned(), |r| format!("#{r}"));
// Cases with `expected == 0` are shown as abstention checks; all others
// show the first relevant rank and the forbidden-hit count.
let detail = if case.expected == 0 {
format!(
"abstain {}",
if case.abstained_correctly == Some(true) {
"clean"
} else {
"LEAKED"
},
)
} else {
format!(
"expected @{rank} · forbidden in top-{}: {}",
report.k, case.forbidden_hits
)
};
println!(
" {} {}",
style::pewter(&case.case_id),
style::pewter(&detail)
);
}
println!();
println!(
" {}",
style::pewter("self-recall is an upper bound; this is the paraphrase-recall guardrail"),
);
}
/// Prints the golden-case report to stdout as a single JSON object, or
/// reports a rendering error through `style::report_error` if serialization fails.
fn emit_golden_json(report: &golden::GoldenReport) {
    let payload = serde_json::json!({
        "ok": true,
        "mode": "sha1",
        "metric": "golden-case",
        "path": "reranked_search",
        "k": report.k,
        "total_cases": report.total_cases,
        "positive_cases": report.positive_cases,
        "negative_cases": report.negative_cases,
        "mean_recall_at_k": report.mean_recall_at_k,
        "mean_precision_at_k": report.mean_precision_at_k,
        "mean_reciprocal_rank": report.mean_reciprocal_rank,
        "positive_forbidden_hits": report.positive_forbidden_hits,
        "negative_clean": report.negative_clean,
        "strict_file_correct": report.strict_file_correct,
        "strict_file_total": report.strict_file_total,
        "cases": report.cases,
    });
    match serde_json::to_string(&payload) {
        Err(e) => style::report_error("could not render golden json", &e.to_string(), &[]),
        Ok(rendered) => println!("{rendered}"),
    }
}
// Unit tests for the pure helpers in this file: query building, report
// arithmetic, latency percentiles, error classification, status marks and
// sample selection.
#[cfg(test)]
mod tests {
use super::*;
// The header before the blank line is ignored, stop-words are skipped and
// the query is capped at eight words.
#[test]
fn self_recall_query_takes_body_significant_words() {
let content = "Rule ID: x\nRule Name: T\nType: review\nSource: r\nTags: t\n\n\
Return 413 when the request body exceeds the configured size limit always";
let q = self_recall_query(content);
assert_eq!(q.split_whitespace().count(), 8);
assert!(q.starts_with("Return 413"), "got {q:?}");
assert!(
!q.to_lowercase().contains(" the "),
"stop-words must be dropped: {q:?}"
);
}
// Percentages and MRR follow their definitions and do not divide by zero.
#[test]
fn report_math_matches_definitions() {
let mut r = SelfRecallReport {
tested: 4,
hits_at_1: 2,
hits_at_5: 3,
reciprocal_rank_sum: 1.0 + 0.5 + 0.25, per_lang: BTreeMap::new(),
retrieval_errors: 0,
rate_limited_errors: 0,
embed_cap_errors: 0,
latency_ms: Vec::new(),
};
assert!((r.at5_pct() - 75.0).abs() < 1e-9);
assert!((r.at1_pct() - 50.0).abs() < 1e-9);
assert!((r.mrr() - (1.75 / 4.0)).abs() < 1e-9);
r.tested = 0;
assert!(r.mrr().abs() < 1e-9, "no divide-by-zero on empty");
}
// With four samples the p95 index rounds up to the last (largest) element.
#[test]
fn latency_percentiles_are_stable_for_small_samples() {
let report = SelfRecallReport {
latency_ms: vec![10, 30, 20, 100],
..SelfRecallReport::default()
};
assert_eq!(report.avg_latency_ms(), 40);
assert_eq!(report.p95_latency_ms(), 100);
assert_eq!(report.max_latency_ms(), 100);
assert_eq!(percentile_latency_ms(&[10, 20, 30, 40], 0.50), 30);
}
// Embed-cap errors are matched by variant, rate limits by message text.
#[test]
fn retrieval_error_classifier_counts_rate_limit_and_embed_cap() {
let cap = CoreError::EmbedCapReached {
cap: 200,
used: 200,
};
assert!(is_embed_cap_error(&cap));
assert!(is_rate_limited_error(&CoreError::Internal(
"provider returned 429 Too Many Requests".to_owned()
)));
}
// Threshold boundaries are inclusive on the better side.
#[test]
fn marks_follow_thresholds() {
assert_eq!(mrr_mark(0.7), sym::OK);
assert_eq!(mrr_mark(0.5), sym::WARN);
assert_eq!(mrr_mark(0.49), sym::ERR);
assert_eq!(at5_mark(80.0), sym::OK);
assert_eq!(at1_mark(24.0), sym::ERR);
}
// Sample selection is capped at the corpus size, evenly spread and repeatable.
#[test]
fn build_samples_is_deterministic_and_capped() {
let rules: Vec<RuleDocument> = (0..50)
.map(|i| RuleDocument {
skill_id: format!("r{i}"),
title: format!("t{i}"),
content: format!("Rule ID: r{i}\nRule Name: t{i}\n\nbody token alpha{i} bravo"),
confidence: 0.7,
file_patterns: None,
language: None,
repo_scope: None,
})
.collect();
let a = build_samples(&rules, 10);
let b = build_samples(&rules, 10);
assert_eq!(a.len(), 10);
assert_eq!(build_samples(&rules, 20).len(), 20);
assert_eq!(build_samples(&rules, 100).len(), 50);
assert_eq!(
build_samples(&rules[..10], 4)
.iter()
.map(|s| s.skill_id.as_str())
.collect::<Vec<_>>(),
vec!["r0", "r3", "r6", "r9"]
);
assert_eq!(
a.iter().map(|s| &s.skill_id).collect::<Vec<_>>(),
b.iter().map(|s| &s.skill_id).collect::<Vec<_>>(),
"stride sampling must be deterministic"
);
}
// A rule with empty content yields an empty query; it must be skipped and
// replaced by another rule so the requested count is still reached.
#[test]
fn build_samples_backfills_when_stride_hits_empty_queries() {
let rules: Vec<RuleDocument> = (0..6)
.map(|i| RuleDocument {
skill_id: format!("r{i}"),
title: format!("t{i}"),
content: if i == 0 {
String::new()
} else {
format!("Rule ID: r{i}\nRule Name: t{i}\n\nbody token alpha{i} bravo")
},
confidence: 0.7,
file_patterns: None,
language: None,
repo_scope: None,
})
.collect();
let samples = build_samples(&rules, 3);
assert_eq!(samples.len(), 3);
let sample_ids = samples
.iter()
.map(|sample| sample.skill_id.as_str())
.collect::<Vec<_>>();
assert!(
sample_ids.iter().all(|sample_id| *sample_id != "r0"),
"{sample_ids:?}"
);
}
}