Skip to main content

mnemara_core/
evaluation.rs

1use serde::{Deserialize, Serialize};
2use std::collections::BTreeSet;
3
4use crate::Result;
5use crate::query::{RecallQuery, RecallResult};
6use crate::store::MemoryStore;
7
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
9pub struct JudgedRecallCase {
10    pub name: String,
11    pub query: RecallQuery,
12    pub relevant_record_ids: Vec<String>,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
16#[serde(default)]
17pub struct RecallEvaluationAssertions {
18    pub expected_record_ids: Vec<String>,
19    pub optional_record_ids: Vec<String>,
20    pub disallowed_record_ids: Vec<String>,
21    pub required_explanation_notes: Vec<String>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
25pub struct RecallEvaluationCase {
26    pub name: String,
27    pub query: RecallQuery,
28    pub assertions: RecallEvaluationAssertions,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
32pub struct RecallEvaluationCaseReport {
33    pub name: String,
34    pub passed: bool,
35    pub ranked_record_ids: Vec<String>,
36    pub missing_expected_record_ids: Vec<String>,
37    pub present_disallowed_record_ids: Vec<String>,
38    pub missing_explanation_notes: Vec<String>,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
42pub struct RecallEvaluationReport {
43    pub cases: usize,
44    pub passed_cases: usize,
45    pub failed_cases: usize,
46    pub pass_rate: f32,
47    pub ranking_metrics: RankingMetrics,
48    pub case_reports: Vec<RecallEvaluationCaseReport>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52pub struct RankingMetrics {
53    pub cases: usize,
54    pub hit_rate_at_k: f32,
55    pub recall_at_k: f32,
56    pub mrr: f32,
57    pub ndcg_at_k: f32,
58}
59
60pub fn evaluate_rankings_at_k(rankings: &[(&[String], &[String])], k: usize) -> RankingMetrics {
61    if rankings.is_empty() || k == 0 {
62        return RankingMetrics {
63            cases: rankings.len(),
64            hit_rate_at_k: 0.0,
65            recall_at_k: 0.0,
66            mrr: 0.0,
67            ndcg_at_k: 0.0,
68        };
69    }
70
71    let mut hits = 0.0f32;
72    let mut recall = 0.0f32;
73    let mut reciprocal_rank = 0.0f32;
74    let mut ndcg = 0.0f32;
75
76    for (ranked_ids, relevant_ids) in rankings {
77        let relevant = relevant_ids.iter().cloned().collect::<BTreeSet<_>>();
78        if relevant.is_empty() {
79            continue;
80        }
81        let top_k = ranked_ids.iter().take(k).collect::<Vec<_>>();
82        let matches = top_k
83            .iter()
84            .filter(|record_id| relevant.contains(record_id.as_str()))
85            .count() as f32;
86        if matches > 0.0 {
87            hits += 1.0;
88        }
89        recall += matches / relevant.len() as f32;
90        if let Some(rank) = ranked_ids
91            .iter()
92            .position(|record_id| relevant.contains(record_id.as_str()))
93        {
94            reciprocal_rank += 1.0 / (rank as f32 + 1.0);
95        }
96
97        let dcg = top_k
98            .iter()
99            .enumerate()
100            .filter(|(_, record_id)| relevant.contains(record_id.as_str()))
101            .map(|(index, _)| 1.0 / ((index as f32 + 2.0).log2()))
102            .sum::<f32>();
103        let ideal_hits = relevant.len().min(k);
104        let ideal_dcg = (0..ideal_hits)
105            .map(|index| 1.0 / ((index as f32 + 2.0).log2()))
106            .sum::<f32>();
107        if ideal_dcg > 0.0 {
108            ndcg += dcg / ideal_dcg;
109        }
110    }
111
112    let cases = rankings.len() as f32;
113    RankingMetrics {
114        cases: rankings.len(),
115        hit_rate_at_k: hits / cases,
116        recall_at_k: recall / cases,
117        mrr: reciprocal_rank / cases,
118        ndcg_at_k: ndcg / cases,
119    }
120}
121
122pub async fn run_recall_evaluation<S>(
123    store: &S,
124    cases: &[RecallEvaluationCase],
125    k: usize,
126) -> Result<RecallEvaluationReport>
127where
128    S: MemoryStore + ?Sized,
129{
130    let mut results = Vec::with_capacity(cases.len());
131    for case in cases {
132        results.push(store.recall(case.query.clone()).await?);
133    }
134    Ok(evaluate_recall_results(cases, &results, k))
135}
136
137pub fn evaluate_recall_results(
138    cases: &[RecallEvaluationCase],
139    results: &[RecallResult],
140    k: usize,
141) -> RecallEvaluationReport {
142    let mut ranked = Vec::<Vec<String>>::with_capacity(results.len());
143    let mut relevant = Vec::<Vec<String>>::with_capacity(cases.len());
144    let mut case_reports = Vec::with_capacity(cases.len());
145
146    for (case, result) in cases.iter().zip(results.iter()) {
147        let ranked_record_ids = result
148            .hits
149            .iter()
150            .map(|hit| hit.record.id.clone())
151            .collect::<Vec<_>>();
152        let ranked_set = ranked_record_ids.iter().cloned().collect::<BTreeSet<_>>();
153        let expected_set = case
154            .assertions
155            .expected_record_ids
156            .iter()
157            .cloned()
158            .collect::<BTreeSet<_>>();
159        let optional = case
160            .assertions
161            .optional_record_ids
162            .iter()
163            .cloned()
164            .collect::<BTreeSet<_>>();
165        let mut relevant_ids = expected_set
166            .union(&optional)
167            .cloned()
168            .collect::<Vec<String>>();
169        relevant_ids.sort();
170
171        let missing_expected_record_ids = expected_set
172            .difference(&ranked_set)
173            .cloned()
174            .collect::<Vec<_>>();
175        let present_disallowed_record_ids = case
176            .assertions
177            .disallowed_record_ids
178            .iter()
179            .filter(|record_id| ranked_set.contains(*record_id))
180            .cloned()
181            .collect::<Vec<_>>();
182
183        let explanation_notes = result
184            .explanation
185            .as_ref()
186            .map(|explanation| explanation.policy_notes.as_slice())
187            .unwrap_or_default();
188        let missing_explanation_notes = case
189            .assertions
190            .required_explanation_notes
191            .iter()
192            .filter(|required| {
193                !explanation_notes
194                    .iter()
195                    .any(|note| note.contains(required.as_str()))
196            })
197            .cloned()
198            .collect::<Vec<_>>();
199
200        let passed = missing_expected_record_ids.is_empty()
201            && present_disallowed_record_ids.is_empty()
202            && missing_explanation_notes.is_empty();
203
204        ranked.push(ranked_record_ids.clone());
205        relevant.push(relevant_ids);
206        case_reports.push(RecallEvaluationCaseReport {
207            name: case.name.clone(),
208            passed,
209            ranked_record_ids,
210            missing_expected_record_ids,
211            present_disallowed_record_ids,
212            missing_explanation_notes,
213        });
214    }
215
216    let ranking_pairs = ranked
217        .iter()
218        .zip(relevant.iter())
219        .map(|(ranked_ids, relevant_ids)| (ranked_ids.as_slice(), relevant_ids.as_slice()))
220        .collect::<Vec<_>>();
221    let passed_cases = case_reports.iter().filter(|report| report.passed).count();
222    let cases_len = cases.len();
223
224    RecallEvaluationReport {
225        cases: cases_len,
226        passed_cases,
227        failed_cases: cases_len.saturating_sub(passed_cases),
228        pass_rate: if cases_len == 0 {
229            0.0
230        } else {
231            passed_cases as f32 / cases_len as f32
232        },
233        ranking_metrics: evaluate_rankings_at_k(&ranking_pairs, k),
234        case_reports,
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::{
        RecallEvaluationAssertions, RecallEvaluationCase, evaluate_rankings_at_k,
        evaluate_recall_results,
    };
    use crate::{
        MemoryQualityState, MemoryRecord, MemoryRecordKind, MemoryScope, MemoryTrustLevel,
        RecallExplanation, RecallFilters, RecallHit, RecallQuery, RecallResult,
        RecallScoreBreakdown,
    };
    use std::collections::BTreeMap;

    /// Turns string literals into the owned ids the evaluation API works with.
    fn owned_ids(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| (*id).to_owned()).collect()
    }

    #[test]
    fn computes_standard_ranking_metrics() {
        // First ranking finds one of its two relevant ids at rank 2; the second finds its
        // only relevant id at rank 1.
        let first_ranked = owned_ids(&["a", "b", "c"]);
        let first_relevant = owned_ids(&["b", "d"]);
        let second_ranked = owned_ids(&["x", "y", "z"]);
        let second_relevant = owned_ids(&["x"]);
        let rankings = [
            (first_ranked.as_slice(), first_relevant.as_slice()),
            (second_ranked.as_slice(), second_relevant.as_slice()),
        ];

        let metrics = evaluate_rankings_at_k(&rankings, 3);

        assert_eq!(metrics.cases, 2);
        assert!(metrics.hit_rate_at_k > 0.9);
        assert!(metrics.recall_at_k > 0.7);
        assert!(metrics.mrr > 0.6);
        assert!(metrics.ndcg_at_k > 0.6);
    }

    /// Scope shared by every record and query in these tests.
    fn test_scope() -> MemoryScope {
        MemoryScope {
            tenant_id: "tenant-a".to_string(),
            namespace: "ops".to_string(),
            actor_id: "ava".to_string(),
            conversation_id: None,
            session_id: None,
            source: "test".to_string(),
            labels: vec!["eval".to_string()],
            trust_level: MemoryTrustLevel::Verified,
        }
    }

    /// A purely lexical hit for a minimal fact record with the given id.
    fn hit_for(id: &str) -> RecallHit {
        let record = MemoryRecord {
            id: id.to_string(),
            scope: test_scope(),
            kind: MemoryRecordKind::Fact,
            content: format!("record {id}"),
            summary: None,
            source_id: None,
            metadata: BTreeMap::new(),
            quality_state: MemoryQualityState::Active,
            created_at_unix_ms: 1,
            updated_at_unix_ms: 1,
            expires_at_unix_ms: None,
            importance_score: 0.5,
            artifact: None,
            episode: None,
            historical_state: Default::default(),
            lineage: Vec::new(),
            conflict: None,
        };
        let breakdown = RecallScoreBreakdown {
            lexical: 1.0,
            semantic: 0.0,
            graph: 0.0,
            temporal: 0.0,
            metadata: 0.0,
            episodic: 0.0,
            salience: 0.0,
            curation: 0.0,
            policy: 0.0,
            total: 1.0,
        };
        RecallHit {
            record,
            breakdown,
            explanation: None,
        }
    }

    #[test]
    fn evaluates_judged_recall_cases_with_assertions() {
        let assertions = RecallEvaluationAssertions {
            expected_record_ids: owned_ids(&["a"]),
            optional_record_ids: owned_ids(&["b"]),
            disallowed_record_ids: owned_ids(&["x"]),
            required_explanation_notes: owned_ids(&["policy=general"]),
        };
        let query = RecallQuery {
            scope: test_scope(),
            query_text: "release".to_string(),
            max_items: 3,
            token_budget: None,
            filters: RecallFilters::default(),
            include_explanation: true,
        };
        let cases = vec![RecallEvaluationCase {
            name: "release memory".to_string(),
            query,
            assertions,
        }];

        // The required note is matched as a substring of "recall_policy=general".
        let explanation = RecallExplanation {
            selected_channels: vec!["lexical".to_string()],
            policy_notes: vec!["recall_policy=general".to_string()],
            trace_id: None,
            planning_trace: None,
            planning_profile: None,
            policy_profile: None,
            scorer_kind: None,
            scoring_profile: None,
        };
        let results = vec![RecallResult {
            hits: vec![hit_for("a"), hit_for("b")],
            total_candidates_examined: 2,
            explanation: Some(explanation),
        }];

        let report = evaluate_recall_results(&cases, &results, 3);

        assert_eq!(report.cases, 1);
        assert_eq!(report.passed_cases, 1);
        assert_eq!(report.failed_cases, 0);
        assert!(report.pass_rate > 0.99);
        assert!(report.ranking_metrics.hit_rate_at_k > 0.99);
        assert!(report.case_reports[0].passed);
    }
}
358        assert!(report.pass_rate > 0.99);
359        assert!(report.ranking_metrics.hit_rate_at_k > 0.99);
360        assert!(report.case_reports[0].passed);
361    }
362}