1use serde::{Deserialize, Serialize};
9
/// Raw per-signal scores for one search result, exactly as supplied to
/// [`SearchExplainer::explain_result`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreBreakdown {
    /// BM25 keyword-match score (labelled "BM25 keyword match" in explanations).
    pub bm25_score: f32,
    /// Vector similarity score (labelled "semantic similarity" in explanations).
    pub vector_score: f32,
    /// Fuzzy-match score (labelled "fuzzy match" in explanations).
    pub fuzzy_score: f32,
    /// Recency boost value (labelled "recency boost" in explanations).
    pub recency_boost: f32,
    /// Importance weight value (labelled "importance weight" in explanations).
    pub importance_weight: f32,
    /// Cross-encoder rerank score, or `None` when the caller supplied none.
    pub rerank_score: Option<f32>,
    /// Overall score for the result as supplied by the caller; it is not
    /// recomputed from the individual signals above.
    pub final_score: f32,
}
28
/// One signal's share of a result's combined signal score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SignalContribution {
    /// Human-readable signal name, e.g. `"semantic similarity"` or
    /// `"cross-encoder reranking"`.
    pub signal: String,
    /// The raw score this signal had in the [`ScoreBreakdown`].
    pub score: f32,
    /// This signal's share of the combined signal total, expressed in percent
    /// (so `25.0` means 25%). All shares are `0.0` when there is no positive
    /// total to divide by.
    pub contribution_pct: f32,
}
39
/// Complete explanation of why a single result was ranked where it was.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchExplanation {
    /// Identifier of the explained memory, passed through unchanged.
    pub memory_id: i64,
    /// Rank as given by the caller; [`SearchExplainer::explain_batch`] assigns
    /// 1-based ranks by list position.
    pub rank: usize,
    /// The raw scores the explanation was built from.
    pub scores: ScoreBreakdown,
    /// Short human-readable summary (rank, final score, primary and secondary
    /// signals).
    pub explanation: String,
    /// Every considered signal, sorted by contribution, highest first.
    pub top_signals: Vec<SignalContribution>,
}
54
/// Turns raw per-signal search scores into percentage breakdowns and short
/// human-readable explanations.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchExplainer {
    /// Stored fusion constant. NOTE(review): no method on this type reads it;
    /// presumably the reciprocal-rank-fusion `k` used elsewhere — confirm.
    pub rrf_k: f32,
    /// When `true`, a supplied rerank score is counted as its own signal and
    /// mentioned in the generated explanation text.
    pub reranking_active: bool,
}
62
63impl SearchExplainer {
64 pub fn new(rrf_k: f32, reranking_active: bool) -> Self {
66 Self {
67 rrf_k,
68 reranking_active,
69 }
70 }
71
72 #[allow(clippy::too_many_arguments)]
85 pub fn explain_result(
86 &self,
87 memory_id: i64,
88 rank: usize,
89 bm25: f32,
90 vector: f32,
91 fuzzy: f32,
92 recency: f32,
93 importance: f32,
94 rerank: Option<f32>,
95 final_score: f32,
96 ) -> SearchExplanation {
97 let scores = ScoreBreakdown {
98 bm25_score: bm25,
99 vector_score: vector,
100 fuzzy_score: fuzzy,
101 recency_boost: recency,
102 importance_weight: importance,
103 rerank_score: rerank,
104 final_score,
105 };
106
107 let top_signals = self.compute_signal_contributions(&scores);
108 let explanation = self.generate_explanation(rank, &scores, &top_signals);
109
110 SearchExplanation {
111 memory_id,
112 rank,
113 scores,
114 explanation,
115 top_signals,
116 }
117 }
118
119 pub fn explain_batch(
124 &self,
125 results: Vec<(i64, f32, f32, f32, f32, f32, Option<f32>, f32)>,
126 ) -> Vec<SearchExplanation> {
127 results
128 .into_iter()
129 .enumerate()
130 .map(
131 |(
132 i,
133 (memory_id, bm25, vector, fuzzy, recency, importance, rerank, final_score),
134 )| {
135 self.explain_result(
136 memory_id,
137 i + 1,
138 bm25,
139 vector,
140 fuzzy,
141 recency,
142 importance,
143 rerank,
144 final_score,
145 )
146 },
147 )
148 .collect()
149 }
150
151 pub fn generate_explanation(
153 &self,
154 rank: usize,
155 scores: &ScoreBreakdown,
156 signals: &[SignalContribution],
157 ) -> String {
158 let mut parts: Vec<String> = Vec::new();
159
160 parts.push(format!(
162 "Ranked #{rank} (score: {:.2}).",
163 scores.final_score
164 ));
165
166 if let Some(primary) = signals.first() {
168 parts.push(format!(
169 "Primary signal: {} ({:.0}%).",
170 primary.signal, primary.contribution_pct
171 ));
172 }
173
174 for signal in signals.iter().skip(1).take(3) {
176 if signal.contribution_pct >= 1.0 {
177 let verb = match signal.signal.as_str() {
179 "BM25 keyword match" => "BM25 keyword match contributed",
180 "recency boost" => "Recency boost added",
181 "importance weight" => "Importance weight contributed",
182 "fuzzy match" => "Fuzzy match contributed",
183 _ => "contributed",
184 };
185 parts.push(format!("{} {:.0}%.", verb, signal.contribution_pct));
186 }
187 }
188
189 if self.reranking_active && scores.rerank_score.is_some() {
191 parts.push("Cross-encoder reranking confirmed relevance.".to_string());
192 }
193
194 parts.join(" ")
195 }
196
197 fn compute_signal_contributions(&self, scores: &ScoreBreakdown) -> Vec<SignalContribution> {
203 let mut raw: Vec<(&str, f32)> = vec![
204 ("semantic similarity", scores.vector_score),
205 ("BM25 keyword match", scores.bm25_score),
206 ("fuzzy match", scores.fuzzy_score),
207 ("recency boost", scores.recency_boost),
208 ("importance weight", scores.importance_weight),
209 ];
210
211 if self.reranking_active {
213 if let Some(rs) = scores.rerank_score {
214 raw.push(("cross-encoder reranking", rs));
215 }
216 }
217
218 let total: f32 = raw.iter().map(|(_, s)| s).sum();
219
220 let mut contributions: Vec<SignalContribution> = raw
221 .into_iter()
222 .map(|(name, score)| {
223 let contribution_pct = if total > 0.0 {
224 (score / total) * 100.0
225 } else {
226 0.0
228 };
229 SignalContribution {
230 signal: name.to_string(),
231 score,
232 contribution_pct,
233 }
234 })
235 .collect();
236
237 contributions.sort_by(|a, b| {
239 b.contribution_pct
240 .partial_cmp(&a.contribution_pct)
241 .unwrap_or(std::cmp::Ordering::Equal)
242 });
243
244 contributions
245 }
246}
247
248impl Default for SearchExplainer {
249 fn default() -> Self {
250 Self::new(60.0, false)
251 }
252}
253
#[cfg(test)]
mod tests {
    //! Behavioural tests for `SearchExplainer`: field pass-through, signal
    //! ordering and percentages, reranking on/off, batch ranking and the
    //! generated explanation text.

    use super::*;

    /// Explainer with reranking switched on, used by most tests.
    fn reranking_explainer() -> SearchExplainer {
        SearchExplainer::new(60.0, true)
    }

    /// One representative result with every signal (including rerank) set.
    fn sample_explanation(explainer: &SearchExplainer) -> SearchExplanation {
        explainer.explain_result(
            42,        // memory_id
            1,         // rank
            0.5,       // bm25
            0.8,       // vector
            0.3,       // fuzzy
            0.1,       // recency
            0.6,       // importance
            Some(0.7), // rerank
            0.85,      // final_score
        )
    }

    /// Every field of the explanation is populated from the inputs.
    #[test]
    fn test_single_result_has_all_fields() {
        let SearchExplanation {
            memory_id,
            rank,
            scores,
            explanation,
            top_signals,
        } = sample_explanation(&reranking_explainer());

        assert_eq!(memory_id, 42);
        assert_eq!(rank, 1);
        assert!((scores.final_score - 0.85).abs() < f32::EPSILON);
        assert!(!explanation.is_empty());
        assert!(!top_signals.is_empty());
    }

    /// Signals come back ordered from largest to smallest contribution.
    #[test]
    fn test_top_signals_sorted_descending() {
        let exp = sample_explanation(&reranking_explainer());

        exp.top_signals.windows(2).for_each(|pair| {
            let (earlier, later) = (&pair[0], &pair[1]);
            assert!(
                earlier.contribution_pct >= later.contribution_pct,
                "signals not sorted: {} ({:.2}%) before {} ({:.2}%)",
                earlier.signal,
                earlier.contribution_pct,
                later.signal,
                later.contribution_pct
            );
        });
    }

    /// Percentages across all signals add up to roughly 100.
    #[test]
    fn test_contribution_percentages_sum_to_100() {
        let total: f32 = sample_explanation(&reranking_explainer())
            .top_signals
            .iter()
            .map(|s| s.contribution_pct)
            .sum();

        assert!(
            (total - 100.0).abs() < 0.1,
            "percentages sum to {total:.2}, expected ~100"
        );
    }

    /// With reranking on, the rerank score is kept and counted as a signal.
    #[test]
    fn test_rerank_score_included_when_active() {
        let exp = SearchExplainer::new(60.0, true)
            .explain_result(1, 1, 0.4, 0.6, 0.2, 0.05, 0.5, Some(0.9), 0.75);
        let mentions_cross_encoder = exp
            .top_signals
            .iter()
            .any(|s| s.signal == "cross-encoder reranking");

        assert!(
            exp.scores.rerank_score.is_some(),
            "rerank_score should be Some when active"
        );
        assert!(
            mentions_cross_encoder,
            "cross-encoder signal missing from top_signals"
        );
    }

    /// With reranking off and no rerank score, no cross-encoder signal appears.
    #[test]
    fn test_rerank_score_none_when_inactive() {
        let exp = SearchExplainer::new(60.0, false)
            .explain_result(1, 1, 0.4, 0.6, 0.2, 0.05, 0.5, None, 0.75);
        let mentions_cross_encoder = exp
            .top_signals
            .iter()
            .any(|s| s.signal == "cross-encoder reranking");

        assert!(
            exp.scores.rerank_score.is_none(),
            "rerank_score should be None when inactive"
        );
        assert!(
            !mentions_cross_encoder,
            "cross-encoder signal must not appear when reranker is inactive"
        );
    }

    /// Batch explanations keep input order and get 1-based ranks.
    #[test]
    fn test_batch_assigns_correct_ranks() {
        let results: Vec<(i64, f32, f32, f32, f32, f32, Option<f32>, f32)> = vec![
            (1, 0.9, 0.8, 0.1, 0.05, 0.7, None, 0.92),
            (2, 0.7, 0.6, 0.0, 0.02, 0.5, None, 0.72),
            (3, 0.5, 0.4, 0.2, 0.01, 0.3, None, 0.55),
        ];

        let explanations = SearchExplainer::new(60.0, false).explain_batch(results);

        assert_eq!(explanations.len(), 3);
        for (i, exp) in explanations.iter().enumerate() {
            assert_eq!(exp.rank, i + 1, "rank mismatch at index {i}");
        }
        let ids: Vec<i64> = explanations.iter().map(|e| e.memory_id).collect();
        assert_eq!(ids, [1, 2, 3]);
    }

    /// The text names the rank and the highest-contributing signal.
    #[test]
    fn test_explanation_text_contains_rank_and_top_signal() {
        let exp = sample_explanation(&reranking_explainer());
        let text = &exp.explanation;

        assert!(
            text.contains("#1"),
            "explanation should reference rank #1: {:?}",
            text
        );

        let top_signal_name = &exp.top_signals[0].signal;
        assert!(
            text.contains(top_signal_name.as_str()),
            "explanation should mention top signal '{top_signal_name}': {:?}",
            text
        );
    }

    /// All-zero inputs give 0% everywhere and still produce text.
    #[test]
    fn test_zero_scores_handled_gracefully() {
        let exp = SearchExplainer::new(60.0, false)
            .explain_result(99, 5, 0.0, 0.0, 0.0, 0.0, 0.0, None, 0.0);

        exp.top_signals.iter().for_each(|signal| {
            assert!(
                (signal.contribution_pct - 0.0).abs() < f32::EPSILON,
                "expected 0% contribution, got {:.2}% for {}",
                signal.contribution_pct,
                signal.signal
            );
        });

        assert!(exp.explanation.contains("#5"));
    }

    /// Identical raw scores split the total evenly across all six signals.
    #[test]
    fn test_equal_signals_have_roughly_equal_contributions() {
        let exp = SearchExplainer::new(60.0, true)
            .explain_result(7, 2, 1.0, 1.0, 1.0, 1.0, 1.0, Some(1.0), 1.0);
        // Five base signals plus the rerank signal.
        let expected_pct: f32 = 100.0 / 6.0;

        exp.top_signals.iter().for_each(|signal| {
            assert!(
                (signal.contribution_pct - expected_pct).abs() < 1.0,
                "signal '{}' has {:.2}%, expected ~{:.2}%",
                signal.signal,
                signal.contribution_pct,
                expected_pct
            );
        });
    }
}