Skip to main content

aprender_rag/
mmr.rs

1//! HELIX-IDEA-006 Phase 1 — Maximal Marginal Relevance reranking.
2//!
3//! Contract: `contracts/apr-rerank-v1.yaml` (ACTIVE).
4//! Pattern source: helix-db `helix_engine/reranker/fusion/mmr.rs`
5//! (re-implemented; no code lift). Reference:
6//!
7//! > Carbonell & Goldstein (1998). "The Use of MMR, Diversity-Based
8//! > Reranking for Reordering Documents and Producing Summaries."
9//! > <https://www.cs.cmu.edu/~jgc/publication/MMR.pdf>
10//!
11//! MMR balances *relevance* against *diversity* by iteratively
12//! selecting the candidate that maximises:
13//!
14//! ```text
15//!   MMR(d) = λ · sim_query(d) − (1 − λ) · max_{s ∈ Selected} sim_pair(d, s)
16//! ```
17//!
18//! At `λ=1`, the diversity term vanishes and the output is the
19//! input sorted by relevance descending — verified by
20//! `FALSIFY-RERANK-MMR-002`.
21//!
22//! # Example
23//!
24//! ```
25//! use aprender_rag::mmr::mmr_select;
26//!
27//! let candidates = vec![
28//!     ("doc-a", 0.9_f32),
29//!     ("doc-b", 0.8),
30//!     ("doc-c", 0.7),
31//!     ("doc-a-paraphrase", 0.85),
32//! ];
33//! // Pretend doc-a and doc-a-paraphrase are highly similar.
34//! let sim = |x: &&str, y: &&str| if x.contains("doc-a") && y.contains("doc-a") { 0.95 } else { 0.05 };
35//!
36//! // λ=1 → pure relevance.
37//! let by_rel = mmr_select(candidates.clone(), sim, 1.0, 3);
38//! assert_eq!(by_rel[0].0, "doc-a");
39//! assert_eq!(by_rel[1].0, "doc-a-paraphrase");
40//!
41//! // λ=0.5 → diversity penalises near-duplicate of doc-a.
42//! let diverse = mmr_select(candidates, sim, 0.5, 3);
43//! assert_eq!(diverse[0].0, "doc-a");
44//! assert_eq!(diverse[1].0, "doc-b"); // not the paraphrase
45//! ```
46
/// Select up to `top_k` items via Maximal Marginal Relevance.
///
/// Each round picks the remaining candidate that maximises
/// `λ · relevance − (1 − λ) · max_sim`, where `max_sim` is the highest
/// `similarity` between that candidate and any already-selected item.
/// `max_sim` starts at `0.0`, so the first pick carries no penalty and
/// negative similarities never turn into a bonus.
///
/// # Arguments
///
/// * `candidates` — items paired with their query-relevance score in
///   the second tuple slot.
/// * `similarity` — called as `similarity(candidate, selected)`; expected
///   to return a value in `[0, 1]` (1 = identical). It is evaluated once
///   per (remaining candidate, newly selected item) pair.
/// * `lambda` — trade-off in `[0, 1]`: `λ=1` is pure relevance, `λ=0` is
///   pure diversity. The value is used as given; it is not validated or
///   clamped.
/// * `top_k` — maximum number of items to return; clipped to
///   `candidates.len()`.
///
/// # Returns
///
/// The selected items in selection order. The `f32` in each tuple is the
/// MMR score (relevance term minus diversity penalty at the moment of
/// selection), not the original relevance.
///
/// # Ordering guarantees
///
/// * Candidates with equal MMR scores are selected in input order.
/// * A candidate whose MMR score is NaN ranks below every candidate with
///   a non-NaN score, wherever it appears in the input.
///
/// # Panics
///
/// Does not panic on its own (a panicking `similarity` still propagates).
/// If `candidates` is empty or `top_k` is `0`, returns an empty `Vec`.
pub fn mmr_select<T, F>(
    mut candidates: Vec<(T, f32)>,
    similarity: F,
    lambda: f32,
    top_k: usize,
) -> Vec<(T, f32)>
where
    T: Clone,
    F: Fn(&T, &T) -> f32,
{
    let cap = top_k.min(candidates.len());
    let mut selected: Vec<(T, f32)> = Vec::with_capacity(cap);

    // Running maximum similarity of each remaining candidate to the
    // selected set, kept index-aligned with `candidates`. Updating it
    // incrementally avoids re-evaluating `similarity` against every
    // selected item on every round.
    let mut max_sims = vec![0.0_f32; candidates.len()];

    while selected.len() < cap {
        // Linear scan for the best MMR score. The strict `>` keeps the
        // earliest candidate on ties; the NaN clause lets any real score
        // displace a NaN incumbent while a NaN challenger never wins.
        let mut best: Option<(usize, f32)> = None;
        for (idx, ((_, rel), max_sim)) in candidates.iter().zip(&max_sims).enumerate() {
            let score = lambda * rel - (1.0 - lambda) * max_sim;
            let improves = match best {
                None => true,
                Some((_, top)) => score > top || (top.is_nan() && !score.is_nan()),
            };
            if improves {
                best = Some((idx, score));
            }
        }

        // `cap <= candidates.len()` and both shrink/grow in lockstep, so
        // the pool is never empty here; `break` keeps this panic-free.
        let (best_idx, best_score) = match best {
            Some(found) => found,
            None => break,
        };

        // `remove` (not `swap_remove`) preserves input order among the
        // remaining candidates, which the tie-breaking rule relies on.
        let (item, _rel) = candidates.remove(best_idx);
        max_sims.remove(best_idx);

        // Fold the new pick into the running maxima, unless this was the
        // final pick and the values would never be read.
        if selected.len() + 1 < cap {
            for ((other, _), max_sim) in candidates.iter().zip(max_sims.iter_mut()) {
                *max_sim = (*max_sim).max(similarity(other, &item));
            }
        }

        selected.push((item, best_score));
    }

    selected
}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Similarity stub: every pair is completely dissimilar.
    fn sim_zero(_: &&str, _: &&str) -> f32 {
        0.0
    }

    /// Similarity stub: items sharing a first character count as
    /// identical (`1.0`); everything else is unrelated (`0.0`).
    fn sim_first_char(x: &&str, y: &&str) -> f32 {
        if x.chars().next() == y.chars().next() {
            1.0
        } else {
            0.0
        }
    }

    /// Similarity stub: any two items whose names contain `"doc-a"` are
    /// near-duplicates (`0.95`); all other pairs are nearly unrelated.
    fn sim_doc_a(x: &&str, y: &&str) -> f32 {
        if x.contains("doc-a") && y.contains("doc-a") {
            0.95
        } else {
            0.05
        }
    }

    #[test]
    fn empty_input_returns_empty() {
        let got: Vec<(&str, f32)> = mmr_select(Vec::new(), sim_zero, 0.5, 10);
        assert!(got.is_empty());
    }

    #[test]
    fn top_k_zero_returns_empty() {
        let cands = vec![("a", 0.9_f32), ("b", 0.5)];
        let got = mmr_select(cands, sim_zero, 0.5, 0);
        assert!(got.is_empty());
    }

    #[test]
    fn top_k_clipped_to_candidate_count() {
        let cands = vec![("a", 0.9_f32), ("b", 0.5)];
        let got = mmr_select(cands, sim_zero, 1.0, 10);
        assert_eq!(got.len(), 2);
    }

    #[test]
    fn lambda_one_yields_relevance_descending() {
        let cands = vec![("a", 0.5_f32), ("b", 0.9), ("c", 0.7)];
        let got = mmr_select(cands, sim_zero, 1.0, 3);
        assert_eq!(got[0].0, "b");
        assert_eq!(got[1].0, "c");
        assert_eq!(got[2].0, "a");
        // Scores at λ=1 equal the relevance scores exactly.
        assert!((got[0].1 - 0.9).abs() < f32::EPSILON);
        assert!((got[1].1 - 0.7).abs() < f32::EPSILON);
        assert!((got[2].1 - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn lambda_zero_with_uniform_relevance_picks_diverse() {
        // All candidates share the same relevance, so at λ=0 only the
        // diversity penalty separates them.
        let cands = vec![("a", 1.0_f32), ("a-dup", 1.0), ("b", 1.0)];
        let got = mmr_select(cands, sim_first_char, 0.0, 3);
        assert_eq!(got.len(), 3);

        // Whichever item is picked first, the second pick must come from
        // the other first-character group.
        assert_ne!(got[0].0.chars().next(), got[1].0.chars().next());

        // The first two picks are penalty-free; the third duplicates an
        // earlier pick and carries the full penalty of -(1 - λ) · 1.0.
        assert!(got[0].1.abs() < f32::EPSILON);
        assert!(got[1].1.abs() < f32::EPSILON);
        assert!((got[2].1 + 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn near_duplicate_is_demoted_at_half_lambda() {
        let cands = vec![
            ("doc-a", 0.9_f32),
            ("doc-b", 0.8),
            ("doc-c", 0.7),
            ("doc-a-paraphrase", 0.85),
        ];
        let got = mmr_select(cands, sim_doc_a, 0.5, 3);
        // The paraphrase has the second-highest relevance, but its
        // similarity to doc-a pushes it out of the top three.
        assert_eq!(got[0].0, "doc-a");
        assert_eq!(got[1].0, "doc-b");
        assert_eq!(got[2].0, "doc-c");
    }
}