1use serde::{Deserialize, Serialize};
2use std::collections::BTreeSet;
3
4use crate::Result;
5use crate::query::{RecallQuery, RecallResult};
6use crate::store::MemoryStore;
7
/// A recall query paired with the ids of the records judged relevant to it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct JudgedRecallCase {
    /// Name identifying the case.
    pub name: String,
    /// Query to run against the store.
    pub query: RecallQuery,
    /// Ids of the records judged relevant to `query`.
    pub relevant_record_ids: Vec<String>,
}
14
/// Pass/fail criteria checked against the result of a recall query.
///
/// Because of `#[serde(default)]`, any field may be omitted when deserializing; an omitted list
/// is empty.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default)]
pub struct RecallEvaluationAssertions {
    /// Ids that must appear among the recalled hits for the case to pass. They also count as
    /// relevant when ranking metrics are computed.
    pub expected_record_ids: Vec<String>,
    /// Ids that are not required to appear but count as relevant for ranking metrics.
    pub optional_record_ids: Vec<String>,
    /// Ids that must not appear among the recalled hits for the case to pass.
    pub disallowed_record_ids: Vec<String>,
    /// Strings that must each occur as a substring of at least one policy note in the recall
    /// explanation for the case to pass.
    pub required_explanation_notes: Vec<String>,
}
23
/// A named recall query together with the assertions its result is judged against.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RecallEvaluationCase {
    /// Name identifying the case; copied into its case report.
    pub name: String,
    /// Query sent to the store for this case.
    pub query: RecallQuery,
    /// Criteria the recall result must satisfy.
    pub assertions: RecallEvaluationAssertions,
}
30
/// Outcome of judging one evaluation case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RecallEvaluationCaseReport {
    /// Name copied from the evaluated case.
    pub name: String,
    /// True when no expected id is missing, no disallowed id is present and no required
    /// explanation note is missing.
    pub passed: bool,
    /// Ids of the recalled hits, in the order they were returned.
    pub ranked_record_ids: Vec<String>,
    /// Expected ids that were not among the recalled hits.
    pub missing_expected_record_ids: Vec<String>,
    /// Disallowed ids that were among the recalled hits.
    pub present_disallowed_record_ids: Vec<String>,
    /// Required explanation notes not found in the result's policy notes.
    pub missing_explanation_notes: Vec<String>,
}
40
/// Aggregate outcome of a recall evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RecallEvaluationReport {
    /// Number of evaluation cases supplied.
    pub cases: usize,
    /// Number of case reports that passed.
    pub passed_cases: usize,
    /// `cases - passed_cases`.
    pub failed_cases: usize,
    /// `passed_cases / cases`, or `0.0` when there are no cases.
    pub pass_rate: f32,
    /// Ranking metrics over the recalled ids, with expected and optional ids treated as relevant.
    pub ranking_metrics: RankingMetrics,
    /// Per-case outcomes.
    pub case_reports: Vec<RecallEvaluationCaseReport>,
}
50
/// Ranking-quality metrics at a cutoff `k`, as produced by `evaluate_rankings_at_k`.
///
/// Relevance is binary. The metric fields are averages of per-case values; see
/// `evaluate_rankings_at_k` for which cases take part in the averages.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RankingMetrics {
    /// Number of (ranked, relevant) pairs passed to the evaluation.
    pub cases: usize,
    /// Average of the per-case indicator "at least one relevant id in the top k".
    pub hit_rate_at_k: f32,
    /// Average fraction of relevant ids found in the top k.
    pub recall_at_k: f32,
    /// Mean reciprocal rank of the first relevant id in the full ranked list (not cut at k).
    pub mrr: f32,
    /// Average normalized discounted cumulative gain over the top k (`log2` discount).
    pub ndcg_at_k: f32,
}
59
60pub fn evaluate_rankings_at_k(rankings: &[(&[String], &[String])], k: usize) -> RankingMetrics {
61 if rankings.is_empty() || k == 0 {
62 return RankingMetrics {
63 cases: rankings.len(),
64 hit_rate_at_k: 0.0,
65 recall_at_k: 0.0,
66 mrr: 0.0,
67 ndcg_at_k: 0.0,
68 };
69 }
70
71 let mut hits = 0.0f32;
72 let mut recall = 0.0f32;
73 let mut reciprocal_rank = 0.0f32;
74 let mut ndcg = 0.0f32;
75
76 for (ranked_ids, relevant_ids) in rankings {
77 let relevant = relevant_ids.iter().cloned().collect::<BTreeSet<_>>();
78 if relevant.is_empty() {
79 continue;
80 }
81 let top_k = ranked_ids.iter().take(k).collect::<Vec<_>>();
82 let matches = top_k
83 .iter()
84 .filter(|record_id| relevant.contains(record_id.as_str()))
85 .count() as f32;
86 if matches > 0.0 {
87 hits += 1.0;
88 }
89 recall += matches / relevant.len() as f32;
90 if let Some(rank) = ranked_ids
91 .iter()
92 .position(|record_id| relevant.contains(record_id.as_str()))
93 {
94 reciprocal_rank += 1.0 / (rank as f32 + 1.0);
95 }
96
97 let dcg = top_k
98 .iter()
99 .enumerate()
100 .filter(|(_, record_id)| relevant.contains(record_id.as_str()))
101 .map(|(index, _)| 1.0 / ((index as f32 + 2.0).log2()))
102 .sum::<f32>();
103 let ideal_hits = relevant.len().min(k);
104 let ideal_dcg = (0..ideal_hits)
105 .map(|index| 1.0 / ((index as f32 + 2.0).log2()))
106 .sum::<f32>();
107 if ideal_dcg > 0.0 {
108 ndcg += dcg / ideal_dcg;
109 }
110 }
111
112 let cases = rankings.len() as f32;
113 RankingMetrics {
114 cases: rankings.len(),
115 hit_rate_at_k: hits / cases,
116 recall_at_k: recall / cases,
117 mrr: reciprocal_rank / cases,
118 ndcg_at_k: ndcg / cases,
119 }
120}
121
122pub async fn run_recall_evaluation<S>(
123 store: &S,
124 cases: &[RecallEvaluationCase],
125 k: usize,
126) -> Result<RecallEvaluationReport>
127where
128 S: MemoryStore + ?Sized,
129{
130 let mut results = Vec::with_capacity(cases.len());
131 for case in cases {
132 results.push(store.recall(case.query.clone()).await?);
133 }
134 Ok(evaluate_recall_results(cases, &results, k))
135}
136
137pub fn evaluate_recall_results(
138 cases: &[RecallEvaluationCase],
139 results: &[RecallResult],
140 k: usize,
141) -> RecallEvaluationReport {
142 let mut ranked = Vec::<Vec<String>>::with_capacity(results.len());
143 let mut relevant = Vec::<Vec<String>>::with_capacity(cases.len());
144 let mut case_reports = Vec::with_capacity(cases.len());
145
146 for (case, result) in cases.iter().zip(results.iter()) {
147 let ranked_record_ids = result
148 .hits
149 .iter()
150 .map(|hit| hit.record.id.clone())
151 .collect::<Vec<_>>();
152 let ranked_set = ranked_record_ids.iter().cloned().collect::<BTreeSet<_>>();
153 let expected_set = case
154 .assertions
155 .expected_record_ids
156 .iter()
157 .cloned()
158 .collect::<BTreeSet<_>>();
159 let optional = case
160 .assertions
161 .optional_record_ids
162 .iter()
163 .cloned()
164 .collect::<BTreeSet<_>>();
165 let mut relevant_ids = expected_set
166 .union(&optional)
167 .cloned()
168 .collect::<Vec<String>>();
169 relevant_ids.sort();
170
171 let missing_expected_record_ids = expected_set
172 .difference(&ranked_set)
173 .cloned()
174 .collect::<Vec<_>>();
175 let present_disallowed_record_ids = case
176 .assertions
177 .disallowed_record_ids
178 .iter()
179 .filter(|record_id| ranked_set.contains(*record_id))
180 .cloned()
181 .collect::<Vec<_>>();
182
183 let explanation_notes = result
184 .explanation
185 .as_ref()
186 .map(|explanation| explanation.policy_notes.as_slice())
187 .unwrap_or_default();
188 let missing_explanation_notes = case
189 .assertions
190 .required_explanation_notes
191 .iter()
192 .filter(|required| {
193 !explanation_notes
194 .iter()
195 .any(|note| note.contains(required.as_str()))
196 })
197 .cloned()
198 .collect::<Vec<_>>();
199
200 let passed = missing_expected_record_ids.is_empty()
201 && present_disallowed_record_ids.is_empty()
202 && missing_explanation_notes.is_empty();
203
204 ranked.push(ranked_record_ids.clone());
205 relevant.push(relevant_ids);
206 case_reports.push(RecallEvaluationCaseReport {
207 name: case.name.clone(),
208 passed,
209 ranked_record_ids,
210 missing_expected_record_ids,
211 present_disallowed_record_ids,
212 missing_explanation_notes,
213 });
214 }
215
216 let ranking_pairs = ranked
217 .iter()
218 .zip(relevant.iter())
219 .map(|(ranked_ids, relevant_ids)| (ranked_ids.as_slice(), relevant_ids.as_slice()))
220 .collect::<Vec<_>>();
221 let passed_cases = case_reports.iter().filter(|report| report.passed).count();
222 let cases_len = cases.len();
223
224 RecallEvaluationReport {
225 cases: cases_len,
226 passed_cases,
227 failed_cases: cases_len.saturating_sub(passed_cases),
228 pass_rate: if cases_len == 0 {
229 0.0
230 } else {
231 passed_cases as f32 / cases_len as f32
232 },
233 ranking_metrics: evaluate_rankings_at_k(&ranking_pairs, k),
234 case_reports,
235 }
236}
237
// Unit tests for the ranking metrics and for assertion-driven case evaluation.
#[cfg(test)]
mod tests {
    use super::{
        RecallEvaluationAssertions, RecallEvaluationCase, evaluate_rankings_at_k,
        evaluate_recall_results,
    };
    use crate::{
        MemoryQualityState, MemoryRecord, MemoryRecordKind, MemoryScope, MemoryTrustLevel,
        RecallExplanation, RecallFilters, RecallHit, RecallQuery, RecallResult,
        RecallScoreBreakdown,
    };
    use std::collections::BTreeMap;

    // Two judged cases at k = 3:
    // - case a: first relevant id "b" at position 1, so recall 1/2 and RR 1/2;
    // - case b: relevant "x" at position 0, so recall 1 and RR 1.
    // Expected averages: hit rate 1.0, recall 0.75, MRR 0.75, nDCG ~0.69.
    #[test]
    fn computes_standard_ranking_metrics() {
        let ranked_a = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let relevant_a = vec!["b".to_string(), "d".to_string()];
        let ranked_b = vec!["x".to_string(), "y".to_string(), "z".to_string()];
        let relevant_b = vec!["x".to_string()];

        let metrics =
            evaluate_rankings_at_k(&[(&ranked_a, &relevant_a), (&ranked_b, &relevant_b)], 3);

        assert_eq!(metrics.cases, 2);
        assert!(metrics.hit_rate_at_k > 0.9);
        assert!(metrics.recall_at_k > 0.7);
        assert!(metrics.mrr > 0.6);
        assert!(metrics.ndcg_at_k > 0.6);
    }

    // Fixed scope shared by the fixture records and the query.
    fn scope() -> MemoryScope {
        MemoryScope {
            tenant_id: "tenant-a".to_string(),
            namespace: "ops".to_string(),
            actor_id: "ava".to_string(),
            conversation_id: None,
            session_id: None,
            source: "test".to_string(),
            labels: vec!["eval".to_string()],
            trust_level: MemoryTrustLevel::Verified,
        }
    }

    // Builds a recall hit for a minimal fact record with the given id. Only the id matters to
    // the evaluation; the other fields are filler.
    fn hit(record_id: &str) -> RecallHit {
        RecallHit {
            record: MemoryRecord {
                id: record_id.to_string(),
                scope: scope(),
                kind: MemoryRecordKind::Fact,
                content: format!("record {record_id}"),
                summary: None,
                source_id: None,
                metadata: BTreeMap::new(),
                quality_state: MemoryQualityState::Active,
                created_at_unix_ms: 1,
                updated_at_unix_ms: 1,
                expires_at_unix_ms: None,
                importance_score: 0.5,
                artifact: None,
                episode: None,
                historical_state: Default::default(),
                lineage: Vec::new(),
                conflict: None,
            },
            breakdown: RecallScoreBreakdown {
                lexical: 1.0,
                semantic: 0.0,
                graph: 0.0,
                temporal: 0.0,
                metadata: 0.0,
                episodic: 0.0,
                salience: 0.0,
                curation: 0.0,
                policy: 0.0,
                total: 1.0,
            },
            explanation: None,
        }
    }

    // A single case that satisfies every assertion:
    // - expected "a" is recalled;
    // - disallowed "x" is absent;
    // - "policy=general" occurs as a substring of the policy note "recall_policy=general".
    #[test]
    fn evaluates_judged_recall_cases_with_assertions() {
        let query = RecallQuery {
            scope: scope(),
            query_text: "release".to_string(),
            max_items: 3,
            token_budget: None,
            filters: RecallFilters::default(),
            include_explanation: true,
        };
        let cases = vec![RecallEvaluationCase {
            name: "release memory".to_string(),
            query,
            assertions: RecallEvaluationAssertions {
                expected_record_ids: vec!["a".to_string()],
                optional_record_ids: vec!["b".to_string()],
                disallowed_record_ids: vec!["x".to_string()],
                required_explanation_notes: vec!["policy=general".to_string()],
            },
        }];
        let results = vec![RecallResult {
            hits: vec![hit("a"), hit("b")],
            total_candidates_examined: 2,
            explanation: Some(RecallExplanation {
                selected_channels: vec!["lexical".to_string()],
                policy_notes: vec!["recall_policy=general".to_string()],
                trace_id: None,
                planning_trace: None,
                planning_profile: None,
                policy_profile: None,
                scorer_kind: None,
                scoring_profile: None,
            }),
        }];

        let report = evaluate_recall_results(&cases, &results, 3);

        assert_eq!(report.cases, 1);
        assert_eq!(report.passed_cases, 1);
        assert_eq!(report.failed_cases, 0);
        assert!(report.pass_rate > 0.99);
        assert!(report.ranking_metrics.hit_rate_at_k > 0.99);
        assert!(report.case_reports[0].passed);
    }
}