// engram/search/explain.rs

//! Search result explainability (RML-1242)
//!
//! Provides human-readable explanations of why a search result ranked where it
//! did, including per-signal score breakdowns and contribution percentages.
//!
//! This module is purely computational — it performs no database access.
7
8use serde::{Deserialize, Serialize};
9
10/// Breakdown of individual scoring signals for a search result.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ScoreBreakdown {
13    /// BM25 keyword relevance score (0.0–1.0)
14    pub bm25_score: f32,
15    /// Vector / semantic similarity score (0.0–1.0)
16    pub vector_score: f32,
17    /// Fuzzy match score (0.0–1.0)
18    pub fuzzy_score: f32,
19    /// Recency boost factor applied during reranking
20    pub recency_boost: f32,
21    /// Importance weight derived from `memory.importance`
22    pub importance_weight: f32,
23    /// Cross-encoder reranking score (`None` when the reranker is not active)
24    pub rerank_score: Option<f32>,
25    /// Final combined score after RRF / reranking
26    pub final_score: f32,
27}
28
29/// A named signal and its contribution to the final score.
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct SignalContribution {
32    /// Signal name, e.g. `"semantic similarity"`.
33    pub signal: String,
34    /// Raw score for this signal.
35    pub score: f32,
36    /// Percentage contribution to the final score (0–100).
37    pub contribution_pct: f32,
38}
39
40/// Human-readable explanation of why a result ranked where it did.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct SearchExplanation {
43    /// Memory ID of the explained result.
44    pub memory_id: i64,
45    /// 1-based rank position.
46    pub rank: usize,
47    /// Per-signal score breakdown.
48    pub scores: ScoreBreakdown,
49    /// Human-readable explanation text.
50    pub explanation: String,
51    /// Signals sorted by contribution percentage (descending).
52    pub top_signals: Vec<SignalContribution>,
53}
54
/// Generates [`SearchExplanation`] values for search results.
///
/// The explainer is a small, cheap-to-copy configuration holder; it derives
/// `Debug` and `Clone` like the other public types in this module so it can be
/// logged and shared alongside them.
#[derive(Debug, Clone)]
pub struct SearchExplainer {
    /// RRF *k* parameter (should match `SearchConfig::rrf_k`).
    pub rrf_k: f32,
    /// Whether a cross-encoder reranker was active during this search.
    ///
    /// When `false`, any supplied rerank score is kept in the score breakdown
    /// but is not counted as a signal or mentioned in the explanation text.
    pub reranking_active: bool,
}
62
63impl SearchExplainer {
64    /// Create a new explainer.
65    pub fn new(rrf_k: f32, reranking_active: bool) -> Self {
66        Self {
67            rrf_k,
68            reranking_active,
69        }
70    }
71
72    /// Explain a single search result.
73    ///
74    /// # Parameters
75    /// * `memory_id` — ID of the memory being explained.
76    /// * `rank` — 1-based rank position in the result set.
77    /// * `bm25` — BM25 keyword relevance score.
78    /// * `vector` — Vector/semantic similarity score.
79    /// * `fuzzy` — Fuzzy match score.
80    /// * `recency` — Recency boost applied.
81    /// * `importance` — Importance weight.
82    /// * `rerank` — Cross-encoder score (`None` if reranker inactive).
83    /// * `final_score` — Final combined score after fusion/reranking.
84    #[allow(clippy::too_many_arguments)]
85    pub fn explain_result(
86        &self,
87        memory_id: i64,
88        rank: usize,
89        bm25: f32,
90        vector: f32,
91        fuzzy: f32,
92        recency: f32,
93        importance: f32,
94        rerank: Option<f32>,
95        final_score: f32,
96    ) -> SearchExplanation {
97        let scores = ScoreBreakdown {
98            bm25_score: bm25,
99            vector_score: vector,
100            fuzzy_score: fuzzy,
101            recency_boost: recency,
102            importance_weight: importance,
103            rerank_score: rerank,
104            final_score,
105        };
106
107        let top_signals = self.compute_signal_contributions(&scores);
108        let explanation = self.generate_explanation(rank, &scores, &top_signals);
109
110        SearchExplanation {
111            memory_id,
112            rank,
113            scores,
114            explanation,
115            top_signals,
116        }
117    }
118
119    /// Explain all results in a batch, assigning ranks 1..N in the order given.
120    ///
121    /// Each tuple is `(memory_id, bm25, vector, fuzzy, recency, importance,
122    /// rerank, final_score)`.
123    pub fn explain_batch(
124        &self,
125        results: Vec<(i64, f32, f32, f32, f32, f32, Option<f32>, f32)>,
126    ) -> Vec<SearchExplanation> {
127        results
128            .into_iter()
129            .enumerate()
130            .map(
131                |(
132                    i,
133                    (memory_id, bm25, vector, fuzzy, recency, importance, rerank, final_score),
134                )| {
135                    self.explain_result(
136                        memory_id,
137                        i + 1,
138                        bm25,
139                        vector,
140                        fuzzy,
141                        recency,
142                        importance,
143                        rerank,
144                        final_score,
145                    )
146                },
147            )
148            .collect()
149    }
150
151    /// Build the human-readable explanation string.
152    pub fn generate_explanation(
153        &self,
154        rank: usize,
155        scores: &ScoreBreakdown,
156        signals: &[SignalContribution],
157    ) -> String {
158        let mut parts: Vec<String> = Vec::new();
159
160        // Lead: rank + final score
161        parts.push(format!(
162            "Ranked #{rank} (score: {:.2}).",
163            scores.final_score
164        ));
165
166        // Primary signal
167        if let Some(primary) = signals.first() {
168            parts.push(format!(
169                "Primary signal: {} ({:.0}%).",
170                primary.signal, primary.contribution_pct
171            ));
172        }
173
174        // Secondary signals (up to 3 more)
175        for signal in signals.iter().skip(1).take(3) {
176            if signal.contribution_pct >= 1.0 {
177                // Only mention signals that meaningfully contributed
178                let verb = match signal.signal.as_str() {
179                    "BM25 keyword match" => "BM25 keyword match contributed",
180                    "recency boost" => "Recency boost added",
181                    "importance weight" => "Importance weight contributed",
182                    "fuzzy match" => "Fuzzy match contributed",
183                    _ => "contributed",
184                };
185                parts.push(format!("{} {:.0}%.", verb, signal.contribution_pct));
186            }
187        }
188
189        // Cross-encoder note
190        if self.reranking_active && scores.rerank_score.is_some() {
191            parts.push("Cross-encoder reranking confirmed relevance.".to_string());
192        }
193
194        parts.join(" ")
195    }
196
197    // ------------------------------------------------------------------ //
198    // Private helpers                                                      //
199    // ------------------------------------------------------------------ //
200
201    /// Map raw signal scores to [`SignalContribution`] sorted by contribution.
202    fn compute_signal_contributions(&self, scores: &ScoreBreakdown) -> Vec<SignalContribution> {
203        let mut raw: Vec<(&str, f32)> = vec![
204            ("semantic similarity", scores.vector_score),
205            ("BM25 keyword match", scores.bm25_score),
206            ("fuzzy match", scores.fuzzy_score),
207            ("recency boost", scores.recency_boost),
208            ("importance weight", scores.importance_weight),
209        ];
210
211        // Include cross-encoder only when the reranker is active
212        if self.reranking_active {
213            if let Some(rs) = scores.rerank_score {
214                raw.push(("cross-encoder reranking", rs));
215            }
216        }
217
218        let total: f32 = raw.iter().map(|(_, s)| s).sum();
219
220        let mut contributions: Vec<SignalContribution> = raw
221            .into_iter()
222            .map(|(name, score)| {
223                let contribution_pct = if total > 0.0 {
224                    (score / total) * 100.0
225                } else {
226                    // Equal contribution when all scores are zero
227                    0.0
228                };
229                SignalContribution {
230                    signal: name.to_string(),
231                    score,
232                    contribution_pct,
233                }
234            })
235            .collect();
236
237        // Sort descending by contribution percentage
238        contributions.sort_by(|a, b| {
239            b.contribution_pct
240                .partial_cmp(&a.contribution_pct)
241                .unwrap_or(std::cmp::Ordering::Equal)
242        });
243
244        contributions
245    }
246}
247
248impl Default for SearchExplainer {
249    fn default() -> Self {
250        Self::new(60.0, false)
251    }
252}
253
// ------------------------------------------------------------------ //
// Tests                                                                //
// ------------------------------------------------------------------ //

#[cfg(test)]
mod tests {
    use super::*;

    /// Name under which the cross-encoder appears in `top_signals`.
    const CROSS_ENCODER: &str = "cross-encoder reranking";

    /// Explainer with the reranker marked active.
    fn make_explainer() -> SearchExplainer {
        SearchExplainer::new(60.0, true)
    }

    /// Helper: produce a deterministic explanation for most tests.
    fn default_explanation(explainer: &SearchExplainer) -> SearchExplanation {
        explainer.explain_result(
            42,        // memory_id
            1,         // rank
            0.5,       // bm25
            0.8,       // vector
            0.3,       // fuzzy
            0.1,       // recency
            0.6,       // importance
            Some(0.7), // rerank
            0.85,      // final_score
        )
    }

    /// Whether the cross-encoder shows up among the explained signals.
    fn has_cross_encoder_signal(exp: &SearchExplanation) -> bool {
        exp.top_signals.iter().any(|s| s.signal == CROSS_ENCODER)
    }

    // ------------------------------------------------------------------ //
    // Test 1: Single result has all required fields populated             //
    // ------------------------------------------------------------------ //
    #[test]
    fn test_single_result_has_all_fields() {
        let explainer = make_explainer();
        let exp = default_explanation(&explainer);

        assert_eq!(exp.memory_id, 42);
        assert_eq!(exp.rank, 1);
        assert!((exp.scores.final_score - 0.85).abs() < f32::EPSILON);
        assert!(!exp.explanation.is_empty());
        assert!(!exp.top_signals.is_empty());
    }

    // ------------------------------------------------------------------ //
    // Test 2: Top signals are sorted by contribution (descending)        //
    // ------------------------------------------------------------------ //
    #[test]
    fn test_top_signals_sorted_descending() {
        let explainer = make_explainer();
        let exp = default_explanation(&explainer);

        for window in exp.top_signals.windows(2) {
            assert!(
                window[0].contribution_pct >= window[1].contribution_pct,
                "signals not sorted: {} ({:.2}%) before {} ({:.2}%)",
                window[0].signal,
                window[0].contribution_pct,
                window[1].signal,
                window[1].contribution_pct
            );
        }
    }

    // ------------------------------------------------------------------ //
    // Test 3: Contribution percentages sum to ~100 %                     //
    // ------------------------------------------------------------------ //
    #[test]
    fn test_contribution_percentages_sum_to_100() {
        let explainer = make_explainer();
        let exp = default_explanation(&explainer);

        let total: f32 = exp.top_signals.iter().map(|s| s.contribution_pct).sum();
        assert!(
            (total - 100.0).abs() < 0.1,
            "percentages sum to {total:.2}, expected ~100"
        );
    }

    // ------------------------------------------------------------------ //
    // Test 4: Rerank score included when reranker is active              //
    // ------------------------------------------------------------------ //
    #[test]
    fn test_rerank_score_included_when_active() {
        let explainer = SearchExplainer::new(60.0, true);
        let exp = explainer.explain_result(1, 1, 0.4, 0.6, 0.2, 0.05, 0.5, Some(0.9), 0.75);

        assert!(
            exp.scores.rerank_score.is_some(),
            "rerank_score should be Some when active"
        );
        // Cross-encoder signal must appear in top_signals
        assert!(
            has_cross_encoder_signal(&exp),
            "cross-encoder signal missing from top_signals"
        );
    }

    // ------------------------------------------------------------------ //
    // Test 5: Rerank score is None when reranker is inactive             //
    // ------------------------------------------------------------------ //
    #[test]
    fn test_rerank_score_none_when_inactive() {
        let explainer = SearchExplainer::new(60.0, false);
        // No rerank score is supplied and the reranker is inactive: the
        // breakdown must carry `None` and no cross-encoder signal may appear.
        let exp = explainer.explain_result(1, 1, 0.4, 0.6, 0.2, 0.05, 0.5, None, 0.75);

        assert!(
            exp.scores.rerank_score.is_none(),
            "rerank_score should be None when inactive"
        );
        assert!(
            !has_cross_encoder_signal(&exp),
            "cross-encoder signal must not appear when reranker is inactive"
        );
    }

    // ------------------------------------------------------------------ //
    // Test 5b: A supplied rerank score is ignored when reranker inactive //
    // ------------------------------------------------------------------ //
    #[test]
    fn test_rerank_signal_ignored_when_inactive_even_if_score_supplied() {
        let explainer = SearchExplainer::new(60.0, false);
        // The caller passes Some(0.9) although the reranker is inactive. The
        // stored rerank_score comes straight from the caller's value, but the
        // signal must NOT be counted or mentioned.
        let exp = explainer.explain_result(1, 1, 0.4, 0.6, 0.2, 0.05, 0.5, Some(0.9), 0.75);

        assert_eq!(exp.scores.rerank_score, Some(0.9));
        assert_eq!(exp.top_signals.len(), 5);
        assert!(
            !has_cross_encoder_signal(&exp),
            "cross-encoder signal must not appear when reranker is inactive"
        );
        assert!(
            !exp.explanation
                .contains("Cross-encoder reranking confirmed relevance."),
            "explanation must not mention the reranker: {:?}",
            exp.explanation
        );

        // The five base signals alone still account for ~100 %.
        let total: f32 = exp.top_signals.iter().map(|s| s.contribution_pct).sum();
        assert!(
            (total - 100.0).abs() < 0.1,
            "percentages sum to {total:.2}, expected ~100"
        );
    }

    // ------------------------------------------------------------------ //
    // Test 6: Batch explanation assigns correct sequential ranks         //
    // ------------------------------------------------------------------ //
    #[test]
    fn test_batch_assigns_correct_ranks() {
        let explainer = SearchExplainer::new(60.0, false);
        let results = vec![
            (
                1_i64, 0.9_f32, 0.8_f32, 0.1_f32, 0.05_f32, 0.7_f32, None, 0.92_f32,
            ),
            (
                2_i64, 0.7_f32, 0.6_f32, 0.0_f32, 0.02_f32, 0.5_f32, None, 0.72_f32,
            ),
            (
                3_i64, 0.5_f32, 0.4_f32, 0.2_f32, 0.01_f32, 0.3_f32, None, 0.55_f32,
            ),
        ];

        let explanations = explainer.explain_batch(results);

        assert_eq!(explanations.len(), 3);
        for (i, exp) in explanations.iter().enumerate() {
            assert_eq!(exp.rank, i + 1, "rank mismatch at index {i}");
        }
        assert_eq!(explanations[0].memory_id, 1);
        assert_eq!(explanations[1].memory_id, 2);
        assert_eq!(explanations[2].memory_id, 3);
    }

    // ------------------------------------------------------------------ //
    // Test 7: Human-readable text contains rank and top signal name      //
    // ------------------------------------------------------------------ //
    #[test]
    fn test_explanation_text_contains_rank_and_top_signal() {
        let explainer = make_explainer();
        let exp = default_explanation(&explainer);

        assert!(
            exp.explanation.contains("#1"),
            "explanation should reference rank #1: {:?}",
            exp.explanation
        );

        let top_signal_name = &exp.top_signals[0].signal;
        assert!(
            exp.explanation.contains(top_signal_name.as_str()),
            "explanation should mention top signal '{top_signal_name}': {:?}",
            exp.explanation
        );
    }

    // ------------------------------------------------------------------ //
    // Test 8: Zero scores handled gracefully (no panic, 0% contributions) //
    // ------------------------------------------------------------------ //
    #[test]
    fn test_zero_scores_handled_gracefully() {
        let explainer = SearchExplainer::new(60.0, false);
        let exp = explainer.explain_result(99, 5, 0.0, 0.0, 0.0, 0.0, 0.0, None, 0.0);

        // Should not panic; all contributions are 0
        for signal in &exp.top_signals {
            assert!(
                signal.contribution_pct.abs() < f32::EPSILON,
                "expected 0% contribution, got {:.2}% for {}",
                signal.contribution_pct,
                signal.signal
            );
        }

        // Explanation should still be generated
        assert!(exp.explanation.contains("#5"));
    }

    // ------------------------------------------------------------------ //
    // Test 9: All signals at equal score → roughly equal contributions   //
    // ------------------------------------------------------------------ //
    #[test]
    fn test_equal_signals_have_roughly_equal_contributions() {
        let explainer = SearchExplainer::new(60.0, true);
        // Six signals each at 1.0 (5 base + 1 rerank)
        let exp = explainer.explain_result(7, 2, 1.0, 1.0, 1.0, 1.0, 1.0, Some(1.0), 1.0);

        let expected_pct = 100.0 / 6.0;
        for signal in &exp.top_signals {
            assert!(
                (signal.contribution_pct - expected_pct).abs() < 1.0,
                "signal '{}' has {:.2}%, expected ~{:.2}%",
                signal.signal,
                signal.contribution_pct,
                expected_pct
            );
        }
    }
}
468}