Skip to main content

mmr_rerank/
lib.rs

//! # mmr-rerank
//!
//! Maximal Marginal Relevance (Carbonell & Goldstein, 1998) reranker.
//!
//! Given relevance scores `rel[i] = sim(query, doc_i)` and a pairwise
//! similarity matrix `pair[i][j] = sim(doc_i, doc_j)`, returns the
//! top-k indices that balance query relevance against novelty:
//!
//! ```text
//! pick = argmax_{i not picked}  λ * rel[i] - (1 - λ) * max_{j picked} pair[i][j]
//! ```
//!
//! ## Example
//!
//! ```
//! use mmr_rerank::mmr;
//! let rel = vec![0.9, 0.85, 0.6, 0.55];
//! // 4 docs; pair[i][j] = pairwise similarity
//! let pair = vec![
//!     vec![1.0, 0.95, 0.10, 0.10],
//!     vec![0.95, 1.0, 0.10, 0.10],
//!     vec![0.10, 0.10, 1.0, 0.95],
//!     vec![0.10, 0.10, 0.95, 1.0],
//! ];
//! // λ = 0.5 trades off relevance and diversity equally.
//! let picks = mmr(&rel, &pair, 0.5, 2);
//! // Expect 0 (top relevance), then 2 (low pair sim with 0).
//! assert_eq!(picks, vec![0, 2]);
//! ```
30
31#![deny(missing_docs)]
32
/// Pick up to `k` indices in MMR order. `lambda ∈ [0, 1]`.
///
/// On every round the not-yet-picked document with the highest score
///
/// ```text
/// lambda * rel[i] - (1 - lambda) * max_{j picked} pair[i][j]
/// ```
///
/// is selected. While nothing has been picked the `max` term is taken as `0`,
/// so the first pick is driven by `rel` alone. Afterwards the true maximum is
/// used, so negative similarities to the picked set raise a candidate's score
/// instead of being floored at zero.
///
/// # Arguments
///
/// * `rel` - relevance of each document to the query.
/// * `pair` - square matrix of pairwise similarities, one row per document.
/// * `lambda` - weight on relevance; `1 - lambda` weights the redundancy term.
///   The value is not range-checked.
/// * `k` - maximum number of indices to return.
///
/// # Returns
///
/// Indices into `rel`, in selection order. The result has
/// `min(k, rel.len())` entries. Ties go to the lowest index, because only a
/// strictly greater score replaces the current best.
///
/// # Panics
///
/// Panics if `pair` does not have exactly `rel.len()` rows, or if any row does
/// not have exactly `rel.len()` columns.
pub fn mmr(rel: &[f32], pair: &[Vec<f32>], lambda: f32, k: usize) -> Vec<usize> {
    let n = rel.len();
    assert_eq!(pair.len(), n, "pair rows must equal rel length");
    // Fail fast with a clear message instead of an index panic mid-selection.
    assert!(
        pair.iter().all(|row| row.len() == n),
        "every pair row must have rel.len() columns"
    );

    let target = k.min(n);
    let mut picked: Vec<usize> = Vec::with_capacity(target);
    // Kept in ascending index order so ties resolve to the lowest index.
    let mut remaining: Vec<usize> = (0..n).collect();
    // max_sim[i] is the largest similarity between doc i and any picked doc.
    // It stays 0 until the first pick so the first round ranks by relevance only.
    let mut max_sim = vec![0.0_f32; n];

    while picked.len() < target {
        let mut best_score = f32::NEG_INFINITY;
        let mut best_slot = 0;
        for (slot, &cand) in remaining.iter().enumerate() {
            let score = lambda * rel[cand] - (1.0 - lambda) * max_sim[cand];
            if score > best_score {
                best_score = score;
                best_slot = slot;
            }
        }

        // `remove` (not `swap_remove`) preserves the ascending order of `remaining`.
        let winner = remaining.remove(best_slot);

        // Fold the new pick into each candidate's running maximum. The first
        // pick overwrites the 0 placeholder so that negative similarities are
        // not clamped at zero.
        let first_pick = picked.is_empty();
        for &cand in &remaining {
            let sim = pair[cand][winner];
            max_sim[cand] = if first_pick { sim } else { max_sim[cand].max(sim) };
        }

        picked.push(winner);
    }

    picked
}