use std::collections::{HashMap, HashSet};
use std::path::Path;
use serde::{Deserialize, Serialize};
use terraphim_types::Thesaurus;
use crate::matcher::find_matches;
/// One labeled evaluation document: raw text together with the
/// normalized terms a correct matcher is expected to find in it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroundTruthDocument {
    /// Identifier reported back in `SystematicError::document_ids`.
    pub id: String,
    /// Raw document text fed to the matcher.
    pub text: String,
    /// Terms the matcher should produce for `text`.
    pub expected_terms: Vec<ExpectedMatch>,
}
/// A single term expected to be matched in a document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectedMatch {
    /// Normalized term value the matcher is expected to emit.
    pub term: String,
    /// Optional category label carried through from the ground-truth
    /// file; not consulted by `evaluate` itself.
    pub category: Option<String>,
}
/// Precision/recall/F1 together with the raw confusion counts they
/// were derived from (see `compute_metrics`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationMetrics {
    /// tp / (tp + fp); 0.0 when nothing was predicted.
    pub precision: f64,
    /// tp / (tp + fn); 0.0 when nothing was expected.
    pub recall: f64,
    /// Harmonic mean of precision and recall; 0.0 when both are zero.
    pub f1: f64,
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}
/// Per-term metrics: how one normalized term performed across all
/// ground-truth documents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TermReport {
    pub term: String,
    pub metrics: ClassificationMetrics,
}
/// Full output of `evaluate`: aggregate metrics, per-term breakdown
/// (sorted by term), and recurring false-positive terms.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationResult {
    /// Number of ground-truth documents evaluated.
    pub total_documents: usize,
    /// Metrics aggregated over every document.
    pub overall: ClassificationMetrics,
    /// One report per term seen as a tp/fp/fn, sorted by term.
    pub per_term: Vec<TermReport>,
    /// Terms that were false positives in multiple documents.
    pub systematic_errors: Vec<SystematicError>,
}
/// A term that was a false positive in at least
/// `SYSTEMATIC_ERROR_THRESHOLD` documents — a likely thesaurus
/// problem rather than a one-off.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystematicError {
    pub term: String,
    /// Number of documents in which the term was a false positive.
    pub false_positive_count: usize,
    /// Ids of the offending documents.
    pub document_ids: Vec<String>,
}
/// Minimum number of documents a false-positive term must appear in
/// before it is reported as a `SystematicError`.
const SYSTEMATIC_ERROR_THRESHOLD: usize = 2;
/// Build a `ClassificationMetrics` record from raw confusion counts.
///
/// Precision and recall fall back to 0.0 whenever their denominator
/// is zero (no predictions / no expected items), and F1 is 0.0 when
/// both precision and recall are zero.
fn compute_metrics(tp: usize, fp: usize, fn_count: usize) -> ClassificationMetrics {
    // Safe ratio: 0.0 when the denominator is empty.
    let ratio = |num: usize, denom: usize| {
        if denom > 0 {
            num as f64 / denom as f64
        } else {
            0.0
        }
    };
    let precision = ratio(tp, tp + fp);
    let recall = ratio(tp, tp + fn_count);
    // Harmonic mean of precision and recall.
    let f1 = if precision + recall > 0.0 {
        2.0 * precision * recall / (precision + recall)
    } else {
        0.0
    };
    ClassificationMetrics {
        precision,
        recall,
        f1,
        true_positives: tp,
        false_positives: fp,
        false_negatives: fn_count,
    }
}
/// Evaluate matcher output against labeled ground truth.
///
/// For every document the thesaurus matcher runs over `doc.text` and
/// the set of matched normalized terms is compared with
/// `doc.expected_terms`. Returns aggregate precision/recall/F1,
/// per-term metrics (sorted by term), and "systematic errors" —
/// terms that were false positives in at least
/// `SYSTEMATIC_ERROR_THRESHOLD` documents.
pub fn evaluate(ground_truth: &[GroundTruthDocument], thesaurus: Thesaurus) -> EvaluationResult {
    let total_documents = ground_truth.len();
    // Confusion counts aggregated over all documents.
    let (mut total_tp, mut total_fp, mut total_fn) = (0usize, 0usize, 0usize);
    // term -> (tp, fp, fn) counts across documents.
    let mut per_term_counts: HashMap<String, (usize, usize, usize)> = HashMap::new();
    // term -> ids of documents in which it was a false positive.
    let mut fp_by_term: HashMap<String, Vec<String>> = HashMap::new();

    for doc in ground_truth {
        // `find_matches` consumes the thesaurus, so it is cloned per
        // document; a matcher error is treated as "no matches".
        let found = find_matches(&doc.text, thesaurus.clone(), false).unwrap_or_default();
        let predicted: HashSet<String> = found
            .iter()
            .map(|m| m.normalized_term.value.as_str().to_string())
            .collect();
        let expected: HashSet<String> =
            doc.expected_terms.iter().map(|e| e.term.clone()).collect();

        // True positives: predicted AND expected.
        for term in predicted.intersection(&expected) {
            total_tp += 1;
            per_term_counts.entry(term.clone()).or_default().0 += 1;
        }
        // False positives: predicted but NOT expected; remember where.
        for term in predicted.difference(&expected) {
            total_fp += 1;
            per_term_counts.entry(term.clone()).or_default().1 += 1;
            fp_by_term
                .entry(term.clone())
                .or_default()
                .push(doc.id.clone());
        }
        // False negatives: expected but NOT predicted.
        for term in expected.difference(&predicted) {
            total_fn += 1;
            per_term_counts.entry(term.clone()).or_default().2 += 1;
        }
    }

    let overall = compute_metrics(total_tp, total_fp, total_fn);

    let mut per_term: Vec<TermReport> = per_term_counts
        .into_iter()
        .map(|(term, (tp, fp, fn_count))| TermReport {
            term,
            metrics: compute_metrics(tp, fp, fn_count),
        })
        .collect();
    // clippy's `sort_by_key` suggestion would clone every term.
    #[allow(clippy::unnecessary_sort_by)]
    per_term.sort_by(|a, b| a.term.cmp(&b.term));

    // A term is "systematic" once it is a false positive in at least
    // SYSTEMATIC_ERROR_THRESHOLD documents.
    let mut systematic_errors: Vec<SystematicError> = fp_by_term
        .into_iter()
        .filter(|(_, doc_ids)| doc_ids.len() >= SYSTEMATIC_ERROR_THRESHOLD)
        .map(|(term, document_ids)| SystematicError {
            false_positive_count: document_ids.len(),
            term,
            document_ids,
        })
        .collect();
    #[allow(clippy::unnecessary_sort_by)]
    systematic_errors.sort_by(|a, b| a.term.cmp(&b.term));

    EvaluationResult {
        total_documents,
        overall,
        per_term,
        systematic_errors,
    }
}
/// Load ground-truth documents from a JSON file.
///
/// Generalized to accept anything path-like (`&Path`, `&PathBuf`,
/// `&str`, …); existing `&Path` callers are unaffected. JSON parse
/// failures are surfaced as `io::Error` with `ErrorKind::InvalidData`
/// so callers only deal with one error type.
///
/// # Errors
/// Returns an error when the file cannot be read, or when its
/// contents are not a valid JSON array of `GroundTruthDocument`.
pub fn load_ground_truth(
    path: impl AsRef<Path>,
) -> std::result::Result<Vec<GroundTruthDocument>, std::io::Error> {
    let content = std::fs::read_to_string(path)?;
    serde_json::from_str(&content)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
}
#[cfg(test)]
mod tests {
    use super::*;
    use terraphim_types::{NormalizedTerm, NormalizedTermValue, Thesaurus};

    /// Build a small thesaurus from `(pattern, normalized_term)`
    /// pairs; ids are assigned sequentially starting at 1.
    fn build_test_thesaurus(entries: &[(&str, &str)]) -> Thesaurus {
        let mut thesaurus = Thesaurus::new("test".to_string());
        for (i, (pattern, nterm)) in entries.iter().enumerate() {
            let term = NormalizedTerm::new((i + 1) as u64, NormalizedTermValue::from(*nterm));
            thesaurus.insert(NormalizedTermValue::from(*pattern), term);
        }
        thesaurus
    }

    // Every expected term is matched: all metrics should be exactly 1.0
    // and there should be no systematic errors.
    #[test]
    fn test_evaluate_perfect_match() {
        let thesaurus = build_test_thesaurus(&[("rust", "rust"), ("async", "async")]);
        let ground_truth = vec![GroundTruthDocument {
            id: "doc1".to_string(),
            text: "I love rust and async programming".to_string(),
            expected_terms: vec![
                ExpectedMatch {
                    term: "rust".to_string(),
                    category: None,
                },
                ExpectedMatch {
                    term: "async".to_string(),
                    category: None,
                },
            ],
        }];
        let result = evaluate(&ground_truth, thesaurus);
        assert_eq!(result.total_documents, 1);
        assert_eq!(result.overall.true_positives, 2);
        assert_eq!(result.overall.false_positives, 0);
        assert_eq!(result.overall.false_negatives, 0);
        assert!((result.overall.precision - 1.0).abs() < f64::EPSILON);
        assert!((result.overall.recall - 1.0).abs() < f64::EPSILON);
        assert!((result.overall.f1 - 1.0).abs() < f64::EPSILON);
        assert!(result.systematic_errors.is_empty());
    }

    // Thesaurus shares no terms with the text: everything expected
    // becomes a false negative and all metrics degrade to 0.0.
    #[test]
    fn test_evaluate_no_matches() {
        let thesaurus = build_test_thesaurus(&[("python", "python"), ("java", "java")]);
        let ground_truth = vec![GroundTruthDocument {
            id: "doc1".to_string(),
            text: "I love rust and async programming".to_string(),
            expected_terms: vec![
                ExpectedMatch {
                    term: "rust".to_string(),
                    category: None,
                },
                ExpectedMatch {
                    term: "async".to_string(),
                    category: None,
                },
            ],
        }];
        let result = evaluate(&ground_truth, thesaurus);
        assert_eq!(result.overall.true_positives, 0);
        assert_eq!(result.overall.false_positives, 0);
        assert_eq!(result.overall.false_negatives, 2);
        assert!((result.overall.precision - 0.0).abs() < f64::EPSILON);
        assert!((result.overall.recall - 0.0).abs() < f64::EPSILON);
        assert!((result.overall.f1 - 0.0).abs() < f64::EPSILON);
    }

    // One of two expected terms matched: precision 1.0, recall 0.5,
    // F1 = harmonic mean = 2/3.
    #[test]
    fn test_evaluate_partial_match() {
        let thesaurus = build_test_thesaurus(&[("rust", "rust")]);
        let ground_truth = vec![GroundTruthDocument {
            id: "doc1".to_string(),
            text: "I love rust and async programming".to_string(),
            expected_terms: vec![
                ExpectedMatch {
                    term: "rust".to_string(),
                    category: None,
                },
                ExpectedMatch {
                    term: "async".to_string(),
                    category: None,
                },
            ],
        }];
        let result = evaluate(&ground_truth, thesaurus);
        assert_eq!(result.overall.true_positives, 1);
        assert_eq!(result.overall.false_positives, 0);
        assert_eq!(result.overall.false_negatives, 1);
        assert!((result.overall.precision - 1.0).abs() < f64::EPSILON);
        assert!((result.overall.recall - 0.5).abs() < f64::EPSILON);
        let expected_f1 = 2.0 / 3.0;
        assert!((result.overall.f1 - expected_f1).abs() < 1e-10);
    }

    // "love" matches but was not expected: one false positive drops
    // precision to 0.5 while recall stays at 1.0.
    #[test]
    fn test_evaluate_false_positives() {
        let thesaurus = build_test_thesaurus(&[("rust", "rust"), ("love", "love")]);
        let ground_truth = vec![GroundTruthDocument {
            id: "doc1".to_string(),
            text: "I love rust programming".to_string(),
            expected_terms: vec![ExpectedMatch {
                term: "rust".to_string(),
                category: None,
            }],
        }];
        let result = evaluate(&ground_truth, thesaurus);
        assert_eq!(result.overall.true_positives, 1);
        assert_eq!(result.overall.false_positives, 1);
        assert_eq!(result.overall.false_negatives, 0);
        assert!((result.overall.precision - 0.5).abs() < f64::EPSILON);
        assert!((result.overall.recall - 1.0).abs() < f64::EPSILON);
        let expected_f1 = 2.0 / 3.0;
        assert!((result.overall.f1 - expected_f1).abs() < 1e-10);
    }

    // "the" is a false positive in all 3 docs, crossing the
    // SYSTEMATIC_ERROR_THRESHOLD, so it is reported as systematic.
    #[test]
    fn test_evaluate_systematic_errors() {
        let thesaurus =
            build_test_thesaurus(&[("rust", "rust"), ("the", "the"), ("async", "async")]);
        let ground_truth = vec![
            GroundTruthDocument {
                id: "doc1".to_string(),
                text: "the rust language is great".to_string(),
                expected_terms: vec![ExpectedMatch {
                    term: "rust".to_string(),
                    category: None,
                }],
            },
            GroundTruthDocument {
                id: "doc2".to_string(),
                text: "the async runtime is powerful".to_string(),
                expected_terms: vec![ExpectedMatch {
                    term: "async".to_string(),
                    category: None,
                }],
            },
            GroundTruthDocument {
                id: "doc3".to_string(),
                text: "the compiler catches errors at compile time".to_string(),
                expected_terms: vec![],
            },
        ];
        let result = evaluate(&ground_truth, thesaurus);
        assert_eq!(result.systematic_errors.len(), 1);
        let error = &result.systematic_errors[0];
        assert_eq!(error.term, "the");
        assert_eq!(error.false_positive_count, 3);
        assert_eq!(error.document_ids.len(), 3);
    }

    // No documents at all: zero counts everywhere, metrics stay 0.0
    // (no divide-by-zero), and both report lists are empty.
    #[test]
    fn test_evaluate_empty_ground_truth() {
        let thesaurus = build_test_thesaurus(&[("rust", "rust")]);
        let ground_truth: Vec<GroundTruthDocument> = vec![];
        let result = evaluate(&ground_truth, thesaurus);
        assert_eq!(result.total_documents, 0);
        assert_eq!(result.overall.true_positives, 0);
        assert_eq!(result.overall.false_positives, 0);
        assert_eq!(result.overall.false_negatives, 0);
        assert!((result.overall.precision - 0.0).abs() < f64::EPSILON);
        assert!((result.overall.recall - 0.0).abs() < f64::EPSILON);
        assert!((result.overall.f1 - 0.0).abs() < f64::EPSILON);
        assert!(result.per_term.is_empty());
        assert!(result.systematic_errors.is_empty());
    }

    // Per-term aggregation across two documents: "rust" and "async"
    // each earn 2 true positives, "tokio" earns 1.
    #[test]
    fn test_evaluate_per_term_metrics() {
        let thesaurus =
            build_test_thesaurus(&[("rust", "rust"), ("async", "async"), ("tokio", "tokio")]);
        let ground_truth = vec![
            GroundTruthDocument {
                id: "doc1".to_string(),
                text: "rust and async are great together".to_string(),
                expected_terms: vec![
                    ExpectedMatch {
                        term: "rust".to_string(),
                        category: None,
                    },
                    ExpectedMatch {
                        term: "async".to_string(),
                        category: None,
                    },
                ],
            },
            GroundTruthDocument {
                id: "doc2".to_string(),
                text: "tokio powers async rust".to_string(),
                expected_terms: vec![
                    ExpectedMatch {
                        term: "tokio".to_string(),
                        category: None,
                    },
                    ExpectedMatch {
                        term: "rust".to_string(),
                        category: None,
                    },
                    ExpectedMatch {
                        term: "async".to_string(),
                        category: None,
                    },
                ],
            },
        ];
        let result = evaluate(&ground_truth, thesaurus);
        assert_eq!(result.overall.true_positives, 5);
        assert_eq!(result.overall.false_positives, 0);
        assert_eq!(result.overall.false_negatives, 0);
        assert_eq!(result.per_term.len(), 3);
        let async_report = result.per_term.iter().find(|r| r.term == "async").unwrap();
        assert_eq!(async_report.metrics.true_positives, 2);
        assert!((async_report.metrics.precision - 1.0).abs() < f64::EPSILON);
        let rust_report = result.per_term.iter().find(|r| r.term == "rust").unwrap();
        assert_eq!(rust_report.metrics.true_positives, 2);
        assert!((rust_report.metrics.precision - 1.0).abs() < f64::EPSILON);
        let tokio_report = result.per_term.iter().find(|r| r.term == "tokio").unwrap();
        assert_eq!(tokio_report.metrics.true_positives, 1);
        assert!((tokio_report.metrics.precision - 1.0).abs() < f64::EPSILON);
    }

    // Round-trip: write a well-formed JSON fixture to a temp dir and
    // check fields (including an optional category) deserialize.
    #[test]
    fn test_load_ground_truth() {
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("ground_truth.json");
        let data = serde_json::json!([
            {
                "id": "doc1",
                "text": "hello world",
                "expected_terms": [
                    {"term": "hello", "category": null}
                ]
            },
            {
                "id": "doc2",
                "text": "foo bar baz",
                "expected_terms": [
                    {"term": "foo", "category": "test"},
                    {"term": "bar", "category": null}
                ]
            }
        ]);
        std::fs::write(&file_path, serde_json::to_string_pretty(&data).unwrap()).unwrap();
        let docs = load_ground_truth(&file_path).unwrap();
        assert_eq!(docs.len(), 2);
        assert_eq!(docs[0].id, "doc1");
        assert_eq!(docs[0].text, "hello world");
        assert_eq!(docs[0].expected_terms.len(), 1);
        assert_eq!(docs[0].expected_terms[0].term, "hello");
        assert_eq!(docs[1].id, "doc2");
        assert_eq!(docs[1].expected_terms.len(), 2);
        assert_eq!(docs[1].expected_terms[0].category, Some("test".to_string()));
    }

    // Malformed JSON should surface as an Err (InvalidData), not panic.
    #[test]
    fn test_load_ground_truth_invalid_file() {
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("bad.json");
        std::fs::write(&file_path, "not valid json").unwrap();
        let result = load_ground_truth(&file_path);
        assert!(result.is_err());
    }

    // A missing file propagates the underlying I/O error.
    #[test]
    fn test_load_ground_truth_missing_file() {
        let result = load_ground_truth(Path::new("/nonexistent/path/file.json"));
        assert!(result.is_err());
    }

    // Zero counts must not divide by zero: all ratios are 0.0.
    #[test]
    fn test_compute_metrics_all_zero() {
        let m = compute_metrics(0, 0, 0);
        assert!((m.precision - 0.0).abs() < f64::EPSILON);
        assert!((m.recall - 0.0).abs() < f64::EPSILON);
        assert!((m.f1 - 0.0).abs() < f64::EPSILON);
    }

    // Only true positives: precision, recall, and F1 are all 1.0.
    #[test]
    fn test_compute_metrics_perfect() {
        let m = compute_metrics(10, 0, 0);
        assert!((m.precision - 1.0).abs() < f64::EPSILON);
        assert!((m.recall - 1.0).abs() < f64::EPSILON);
        assert!((m.f1 - 1.0).abs() < f64::EPSILON);
    }

    // "Rust" (capitalized) in the text still counts as a match for
    // the lowercase thesaurus entry — the matcher is case-insensitive.
    #[test]
    fn test_evaluate_case_insensitive_matching() {
        let thesaurus = build_test_thesaurus(&[("rust", "rust")]);
        let ground_truth = vec![GroundTruthDocument {
            id: "doc1".to_string(),
            text: "I love Rust programming".to_string(),
            expected_terms: vec![ExpectedMatch {
                term: "rust".to_string(),
                category: None,
            }],
        }];
        let result = evaluate(&ground_truth, thesaurus);
        assert_eq!(result.overall.true_positives, 1);
        assert_eq!(result.overall.false_positives, 0);
        assert_eq!(result.overall.false_negatives, 0);
    }

    // Counts aggregate across documents; "rust" expected-but-absent
    // in doc2 contributes the single false negative (recall 2/3).
    #[test]
    fn test_evaluate_multiple_docs_aggregation() {
        let thesaurus = build_test_thesaurus(&[("rust", "rust"), ("go lang", "go lang")]);
        let ground_truth = vec![
            GroundTruthDocument {
                id: "doc1".to_string(),
                text: "rust is great".to_string(),
                expected_terms: vec![ExpectedMatch {
                    term: "rust".to_string(),
                    category: None,
                }],
            },
            GroundTruthDocument {
                id: "doc2".to_string(),
                text: "go lang is also great".to_string(),
                expected_terms: vec![
                    ExpectedMatch {
                        term: "go lang".to_string(),
                        category: None,
                    },
                    ExpectedMatch {
                        term: "rust".to_string(),
                        category: None,
                    },
                ],
            },
        ];
        let result = evaluate(&ground_truth, thesaurus);
        assert_eq!(result.overall.true_positives, 2);
        assert_eq!(result.overall.false_positives, 0);
        assert_eq!(result.overall.false_negatives, 1);
        assert!((result.overall.precision - 1.0).abs() < f64::EPSILON);
        let expected_recall = 2.0 / 3.0;
        assert!((result.overall.recall - expected_recall).abs() < 1e-10);
    }
}