Skip to main content

mnem_graphrag/
summarize.rs

1//! Centroid + MMR extractive community summarizer.
2//!
//! Given a set of sentences (e.g. the text spans belonging to a
//! community of nodes), score each sentence by its cosine similarity to
//! the community centroid, optionally by its similarity to a query
//! vector, and by a graph-centrality term (degree today; PPR once E2
//! lands). Then greedily pick up to `k` sentences with MMR diversity so
//! the output is not dominated by near-duplicates.
9//!
10//! # Determinism
11//!
//! The function is input-order-insensitive: callers may pass the
//! sentences in any order (as long as each sentence keeps its own
//! centrality value) and the resulting [`Summary`] is byte-for-byte
//! identical. This is achieved by sorting the working set by a content
//! hash of each sentence, with the sentence text as a secondary key. We
//! use BLAKE3 for speed and because it is already a workspace dep; the
//! guarantee callers care about is stability, not a specific hash family.
//! MMR tie-breaks between (near-)equal scores fall back to lexicographic
//! order of the sentence text.
19//!
20//! # Weights
21//!
//! `alpha = 0.5` (centroid similarity), `beta = 0.3` (query similarity)
//! and `gamma = 0.2` (normalized centrality); see spec §E4. If
//! `query_embed` is `None`, `beta` is redistributed to `alpha`
//! (effective `alpha = 0.8, gamma = 0.2`).
25
26use mnem_embed_providers::Embedder;
27
/// One sentence chosen by [`summarize_community`], paired with the
/// MMR-adjusted score it carried at the moment it was picked.
#[derive(Clone, Debug, PartialEq)]
pub struct SummaryItem {
    /// Sentence text, copied verbatim from the caller's input.
    pub sentence: String,
    /// Effective score at selection time, i.e. after the MMR
    /// redundancy penalty was applied.
    pub score: f32,
}
37
/// The output of [`summarize_community`]: picked sentences in MMR
/// order (first-picked = highest effective score) plus a
/// parallel-indexed score vector for callers that want the numbers
/// without the strings.
///
/// `Summary::default()` is the empty summary (no sentences, no scores).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Summary {
    /// Picked sentences in MMR selection order.
    pub sentences: Vec<String>,
    /// Scores aligned with [`Summary::sentences`]: `scores[i]` is the
    /// effective score `sentences[i]` had when it was selected.
    pub scores: Vec<f32>,
}
49
50impl Summary {
51    /// Convenience: zip the parallel vectors into [`SummaryItem`]s.
52    #[must_use]
53    pub fn items(&self) -> Vec<SummaryItem> {
54        self.sentences
55            .iter()
56            .zip(self.scores.iter())
57            .map(|(s, &score)| SummaryItem {
58                sentence: s.clone(),
59                score,
60            })
61            .collect()
62    }
63}
64
65/// Summarize a community of sentences using Centroid + MMR.
66///
67/// # Arguments
68///
69/// - `sentences`: all sentences in the community. Order-insensitive.
70/// - `embedder`: any [`Embedder`] (typically the MiniLM MCP default,
71///   or a mock in tests). Reused from `mnem-embed-providers`.
72/// - `query_embed`: optional query vector for query-focused
73///   summarization. Must match `embedder.dim()` when provided.
74/// - `centrality`: closure returning a non-negative centrality weight
75///   for each sentence index. Today this is degree-centrality from
76///   the caller; when E2 lands a PPR vector can slot in unchanged.
77/// - `k`: maximum number of sentences to return. `min(k, sentences.len())`
78///   are actually picked.
79/// - `mmr_lambda`: diversity knob in `[0.0, 1.0]`. `0.0` = pure
80///   relevance, `1.0` = pure diversity. Values outside the range
81///   are clamped. Default from spec: `0.5`.
82///
83/// # Panics
84///
85/// Does not panic on empty input; returns an empty [`Summary`].
86///
87/// # Errors
88///
89/// Propagates any [`mnem_embed_providers::EmbedError`] from the
90/// underlying embedder.
91#[allow(clippy::too_many_arguments)]
92pub fn summarize_community(
93    sentences: &[String],
94    embedder: &dyn Embedder,
95    query_embed: Option<&[f32]>,
96    centrality: &dyn Fn(usize) -> f32,
97    k: usize,
98    mmr_lambda: f32,
99) -> Result<Summary, mnem_embed_providers::EmbedError> {
100    if sentences.is_empty() || k == 0 {
101        return Ok(Summary {
102            sentences: Vec::new(),
103            scores: Vec::new(),
104        });
105    }
106
107    // ------------------------------------------------------------
108    // Step 1: stable ordering via content hash.
109    //
110    // We build an index permutation `perm` so that
111    // `sentences[perm[i]]` is the i-th sentence in canonical order.
112    // The caller's `centrality` closure is still called with the
113    // ORIGINAL index so that a caller-provided degree/PPR vector
114    // does not have to be re-permuted.
115    // ------------------------------------------------------------
116    let mut perm: Vec<usize> = (0..sentences.len()).collect();
117    perm.sort_by(|&a, &b| {
118        let ha = blake3::hash(sentences[a].as_bytes());
119        let hb = blake3::hash(sentences[b].as_bytes());
120        ha.as_bytes()
121            .cmp(hb.as_bytes())
122            .then_with(|| sentences[a].cmp(&sentences[b]))
123    });
124
125    // ------------------------------------------------------------
126    // Step 2: embed every sentence in canonical order.
127    // ------------------------------------------------------------
128    let texts: Vec<&str> = perm.iter().map(|&i| sentences[i].as_str()).collect();
129    let embeds = embedder.embed_batch(&texts)?;
130
131    // ------------------------------------------------------------
132    // Step 3: centroid = mean of sentence embeddings.
133    // ------------------------------------------------------------
134    let dim = embedder.dim() as usize;
135    let mut centroid = vec![0.0_f32; dim];
136    for v in &embeds {
137        for (c, x) in centroid.iter_mut().zip(v.iter()) {
138            *c += *x;
139        }
140    }
141    let n_f = embeds.len() as f32;
142    for c in &mut centroid {
143        *c /= n_f;
144    }
145
146    // Validate query_embed dimension if supplied.
147    if let Some(q) = query_embed
148        && q.len() != dim
149    {
150        return Err(mnem_embed_providers::EmbedError::DimMismatch {
151            expected: embedder.dim(),
152            got: u32::try_from(q.len()).unwrap_or(u32::MAX),
153        });
154    }
155
156    // ------------------------------------------------------------
157    // Step 4: per-sentence base score.
158    //
159    // Score(s_i) = alpha * cos(s_i, centroid)
160    //            + beta  * cos(s_i, query)         (if query)
161    //            + gamma * centrality(orig_i)/max_centrality
162    //
163    // If no query, redistribute beta to alpha (alpha=0.8, gamma=0.2).
164    // ------------------------------------------------------------
165    let (alpha, beta, gamma) = if query_embed.is_some() {
166        (0.5_f32, 0.3_f32, 0.2_f32)
167    } else {
168        (0.8_f32, 0.0_f32, 0.2_f32)
169    };
170
171    // Materialise centralities in canonical order AND find their max.
172    let mut centralities_canon: Vec<f32> = Vec::with_capacity(perm.len());
173    for &orig_i in &perm {
174        let c = centrality(orig_i);
175        centralities_canon.push(c.max(0.0));
176    }
177    let max_centrality = centralities_canon
178        .iter()
179        .copied()
180        .fold(0.0_f32, f32::max)
181        .max(f32::EPSILON); // avoid /0
182
183    let base_scores: Vec<f32> = embeds
184        .iter()
185        .enumerate()
186        .map(|(i, v)| {
187            let s_cent = cosine(v, &centroid);
188            let s_query = query_embed.map_or(0.0, |q| cosine(v, q));
189            let s_centrality = centralities_canon[i] / max_centrality;
190            alpha * s_cent + beta * s_query + gamma * s_centrality
191        })
192        .collect();
193
194    // ------------------------------------------------------------
195    // Step 5: MMR greedy selection.
196    //
197    // effective(i) = (1 - lambda) * base_scores[i]
198    //              - lambda * max_{j in picked} cos(v_i, v_j)
199    //
200    // Note: the spec text in the worktree task says the penalty is
201    // `lambda * max(cos(..., picked))`. The standard MMR formulation
202    // balances relevance and diversity as
203    //   MMR = lambda * rel - (1 - lambda) * max_sim
204    // We follow the standard interpretation with `mmr_lambda` being
205    // the *diversity* weight (high lambda -> strong penalty), which
206    // matches the spec's "diversity tradeoff" language and the
207    // lambda=0.5 default.
208    // ------------------------------------------------------------
209    let lambda = mmr_lambda.clamp(0.0, 1.0);
210    let k_cap = k.min(embeds.len());
211    let mut picked: Vec<usize> = Vec::with_capacity(k_cap);
212    let mut picked_set = vec![false; embeds.len()];
213    let mut out_sentences: Vec<String> = Vec::with_capacity(k_cap);
214    let mut out_scores: Vec<f32> = Vec::with_capacity(k_cap);
215
216    while picked.len() < k_cap {
217        let mut best_idx: Option<usize> = None;
218        let mut best_score = f32::NEG_INFINITY;
219
220        for i in 0..embeds.len() {
221            if picked_set[i] {
222                continue;
223            }
224            // MMR penalty: max cosine similarity to any already-picked
225            // sentence. Clamped to [0.0, 1.0] so that anti-correlated
226            // vectors (which the MockEmbedder can produce) cannot turn
227            // the penalty into a spurious BONUS. With normalized
228            // MiniLM embeddings cosines are in [0,1] already, but the
229            // clamp keeps the invariant "effective_score is
230            // non-increasing across greedy picks" under every embedder.
231            let penalty = if picked.is_empty() {
232                0.0
233            } else {
234                picked
235                    .iter()
236                    .map(|&j| cosine(&embeds[i], &embeds[j]).clamp(0.0, 1.0))
237                    .fold(0.0_f32, f32::max)
238            };
239            let eff = (1.0 - lambda) * base_scores[i] - lambda * penalty;
240
241            // Tie-break: lexicographic on the sentence text.
242            let is_better = match best_idx {
243                None => true,
244                Some(bi) => {
245                    if eff > best_score {
246                        true
247                    } else if (eff - best_score).abs() < f32::EPSILON {
248                        texts[i] < texts[bi]
249                    } else {
250                        false
251                    }
252                }
253            };
254            if is_better {
255                best_idx = Some(i);
256                best_score = eff;
257            }
258        }
259
260        if let Some(bi) = best_idx {
261            picked.push(bi);
262            picked_set[bi] = true;
263            out_sentences.push(texts[bi].to_owned());
264            out_scores.push(best_score);
265        } else {
266            break;
267        }
268    }
269
270    Ok(Summary {
271        sentences: out_sentences,
272        scores: out_scores,
273    })
274}
275
/// Cosine similarity between two equal-length vectors.
///
/// Yields `0.0` (never NaN, never a panic) when either input has
/// (near-)zero norm, i.e. when the product of the two norms is at most
/// `f32::EPSILON`.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "cosine: dim mismatch");
    // Accumulate dot product and both squared norms in a single pass.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(d, sa, sb), (&x, &y)| (d + x * y, sa + x * x, sb + y * y),
    );
    let norms = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if norms <= f32::EPSILON {
        return 0.0;
    }
    dot / norms
}
295
#[cfg(test)]
mod tests {
    use super::*;
    use mnem_embed_providers::MockEmbedder;

    fn make_mock() -> MockEmbedder {
        MockEmbedder::new("test:mock", 32)
    }

    /// Owned copies of string literals, for building test inputs.
    fn strings(xs: &[&str]) -> Vec<String> {
        xs.iter().map(|s| (*s).to_string()).collect()
    }

    #[test]
    fn empty_input_returns_empty_summary() {
        let e = make_mock();
        let s = summarize_community(&[], &e, None, &|_| 1.0, 5, 0.5).unwrap();
        assert!(s.sentences.is_empty());
        assert!(s.scores.is_empty());
    }

    #[test]
    fn k_zero_returns_empty() {
        let e = make_mock();
        let xs = vec!["a".to_string(), "b".to_string()];
        let s = summarize_community(&xs, &e, None, &|_| 1.0, 0, 0.5).unwrap();
        assert!(s.sentences.is_empty());
    }

    #[test]
    fn k_larger_than_n_is_clamped() {
        let e = make_mock();
        let xs = vec!["a".to_string(), "b".to_string()];
        let s = summarize_community(&xs, &e, None, &|_| 1.0, 99, 0.5).unwrap();
        assert_eq!(s.sentences.len(), 2);
    }

    /// The module-level determinism guarantee: permuting the input (while
    /// keeping each sentence paired with its centrality) changes nothing.
    #[test]
    fn output_is_input_order_insensitive() {
        let e = make_mock();
        let xs = strings(&[
            "graph nodes share many edges",
            "edges connect nodes",
            "a totally different topic",
            "nodes and edges again",
        ]);
        let mut rev = xs.clone();
        rev.reverse();
        let a = summarize_community(&xs, &e, None, &|i: usize| xs[i].len() as f32, 3, 0.5)
            .unwrap();
        let b = summarize_community(&rev, &e, None, &|i: usize| rev[i].len() as f32, 3, 0.5)
            .unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn scores_are_aligned_and_non_increasing() {
        let e = make_mock();
        let xs = strings(&["alpha", "beta gamma", "delta epsilon zeta", "eta", "theta iota"]);
        let s = summarize_community(&xs, &e, None, &|_| 1.0, 4, 0.5).unwrap();
        assert_eq!(s.sentences.len(), 4);
        assert_eq!(s.scores.len(), s.sentences.len());
        for w in s.scores.windows(2) {
            assert!(w[1] <= w[0] + 1e-4, "scores must not increase: {:?}", s.scores);
        }
    }

    #[test]
    fn query_dim_mismatch_is_rejected() {
        let e = make_mock();
        let xs = strings(&["a", "b"]);
        let q = vec![0.0_f32; 3];
        let r = summarize_community(&xs, &e, Some(&q), &|_| 1.0, 1, 0.5);
        assert!(matches!(
            r,
            Err(mnem_embed_providers::EmbedError::DimMismatch { got: 3, .. })
        ));
    }

    #[test]
    fn query_focused_summary_succeeds() {
        let e = make_mock();
        let xs = strings(&["a", "b", "c"]);
        let q = vec![1.0_f32; e.dim() as usize];
        let s = summarize_community(&xs, &e, Some(&q), &|_| 1.0, 2, 0.5).unwrap();
        assert_eq!(s.sentences.len(), 2);
        assert_eq!(s.scores.len(), 2);
    }

    #[test]
    fn out_of_range_lambda_is_clamped() {
        let e = make_mock();
        let xs = strings(&["a", "b", "c"]);
        let hi = summarize_community(&xs, &e, None, &|_| 1.0, 2, 5.0).unwrap();
        let one = summarize_community(&xs, &e, None, &|_| 1.0, 2, 1.0).unwrap();
        assert_eq!(hi, one);
        let lo = summarize_community(&xs, &e, None, &|_| 1.0, 2, -3.0).unwrap();
        let zero = summarize_community(&xs, &e, None, &|_| 1.0, 2, 0.0).unwrap();
        assert_eq!(lo, zero);
    }
}