mnem_graphrag/summarize.rs
1//! Centroid + MMR extractive community summarizer.
2//!
//! Given a set of sentences (e.g. the text spans belonging to a
//! community of nodes), score each sentence by its cosine similarity to
//! the community centroid, optionally by its cosine similarity to a
//! query vector, and by a graph-centrality term (degree today; PPR once
//! E2 lands). Then greedily pick `k` sentences with MMR diversity so the
//! output is not dominated by near-duplicates.
9//!
10//! # Determinism
11//!
//! The function is input-order-insensitive: callers may pass
//! sentences in any order and the resulting [`Summary`] is byte-for-byte
//! identical (the `centrality` closure is always queried with the
//! caller's original index, so it must follow the caller's ordering).
//! This is achieved by sorting the working set by the BLAKE3 content
//! hash of each sentence (BLAKE3 rather than SHA-256 for speed and
//! because it is already a workspace dep; the guarantee the caller cares
//! about is stability, not a specific hash family).
//! MMR score ties fall back to lexicographic order of the sentence text.
19//!
20//! # Weights
21//!
22//! `alpha = 0.5`, `beta = 0.3`, `gamma = 0.2` (see spec §E4).
23//! If `query_embed` is `None`, `beta` is redistributed to `alpha`
24//! (effective `alpha = 0.8, gamma = 0.2`).
25
26use mnem_embed_providers::Embedder;
27
/// One sentence chosen by [`summarize_community`], paired with the
/// MMR-adjusted score it carried at the moment of selection.
#[derive(Clone, Debug, PartialEq)]
pub struct SummaryItem {
    /// Text of the selected sentence.
    pub sentence: String,
    /// Effective score when the sentence was picked, i.e. after the MMR
    /// penalty had been subtracted.
    pub score: f32,
}
37
/// Result of [`summarize_community`]: the selected sentences in MMR pick
/// order (the first entry had the highest effective score when it was
/// chosen), together with a score vector indexed in parallel for callers
/// that only need the numbers.
#[derive(Clone, Debug, PartialEq)]
pub struct Summary {
    /// Selected sentences, in the order MMR picked them.
    pub sentences: Vec<String>,
    /// Selection-time scores; `scores[i]` belongs to `sentences[i]`.
    pub scores: Vec<f32>,
}
49
50impl Summary {
51 /// Convenience: zip the parallel vectors into [`SummaryItem`]s.
52 #[must_use]
53 pub fn items(&self) -> Vec<SummaryItem> {
54 self.sentences
55 .iter()
56 .zip(self.scores.iter())
57 .map(|(s, &score)| SummaryItem {
58 sentence: s.clone(),
59 score,
60 })
61 .collect()
62 }
63}
64
65/// Summarize a community of sentences using Centroid + MMR.
66///
67/// # Arguments
68///
69/// - `sentences`: all sentences in the community. Order-insensitive.
70/// - `embedder`: any [`Embedder`] (typically the MiniLM MCP default,
71/// or a mock in tests). Reused from `mnem-embed-providers`.
72/// - `query_embed`: optional query vector for query-focused
73/// summarization. Must match `embedder.dim()` when provided.
74/// - `centrality`: closure returning a non-negative centrality weight
75/// for each sentence index. Today this is degree-centrality from
76/// the caller; when E2 lands a PPR vector can slot in unchanged.
77/// - `k`: maximum number of sentences to return. `min(k, sentences.len())`
78/// are actually picked.
79/// - `mmr_lambda`: diversity knob in `[0.0, 1.0]`. `0.0` = pure
80/// relevance, `1.0` = pure diversity. Values outside the range
81/// are clamped. Default from spec: `0.5`.
82///
83/// # Panics
84///
85/// Does not panic on empty input; returns an empty [`Summary`].
86///
87/// # Errors
88///
89/// Propagates any [`mnem_embed_providers::EmbedError`] from the
90/// underlying embedder.
91#[allow(clippy::too_many_arguments)]
92pub fn summarize_community(
93 sentences: &[String],
94 embedder: &dyn Embedder,
95 query_embed: Option<&[f32]>,
96 centrality: &dyn Fn(usize) -> f32,
97 k: usize,
98 mmr_lambda: f32,
99) -> Result<Summary, mnem_embed_providers::EmbedError> {
100 if sentences.is_empty() || k == 0 {
101 return Ok(Summary {
102 sentences: Vec::new(),
103 scores: Vec::new(),
104 });
105 }
106
107 // ------------------------------------------------------------
108 // Step 1: stable ordering via content hash.
109 //
110 // We build an index permutation `perm` so that
111 // `sentences[perm[i]]` is the i-th sentence in canonical order.
112 // The caller's `centrality` closure is still called with the
113 // ORIGINAL index so that a caller-provided degree/PPR vector
114 // does not have to be re-permuted.
115 // ------------------------------------------------------------
116 let mut perm: Vec<usize> = (0..sentences.len()).collect();
117 perm.sort_by(|&a, &b| {
118 let ha = blake3::hash(sentences[a].as_bytes());
119 let hb = blake3::hash(sentences[b].as_bytes());
120 ha.as_bytes()
121 .cmp(hb.as_bytes())
122 .then_with(|| sentences[a].cmp(&sentences[b]))
123 });
124
125 // ------------------------------------------------------------
126 // Step 2: embed every sentence in canonical order.
127 // ------------------------------------------------------------
128 let texts: Vec<&str> = perm.iter().map(|&i| sentences[i].as_str()).collect();
129 let embeds = embedder.embed_batch(&texts)?;
130
131 // ------------------------------------------------------------
132 // Step 3: centroid = mean of sentence embeddings.
133 // ------------------------------------------------------------
134 let dim = embedder.dim() as usize;
135 let mut centroid = vec![0.0_f32; dim];
136 for v in &embeds {
137 for (c, x) in centroid.iter_mut().zip(v.iter()) {
138 *c += *x;
139 }
140 }
141 let n_f = embeds.len() as f32;
142 for c in &mut centroid {
143 *c /= n_f;
144 }
145
146 // Validate query_embed dimension if supplied.
147 if let Some(q) = query_embed
148 && q.len() != dim
149 {
150 return Err(mnem_embed_providers::EmbedError::DimMismatch {
151 expected: embedder.dim(),
152 got: u32::try_from(q.len()).unwrap_or(u32::MAX),
153 });
154 }
155
156 // ------------------------------------------------------------
157 // Step 4: per-sentence base score.
158 //
159 // Score(s_i) = alpha * cos(s_i, centroid)
160 // + beta * cos(s_i, query) (if query)
161 // + gamma * centrality(orig_i)/max_centrality
162 //
163 // If no query, redistribute beta to alpha (alpha=0.8, gamma=0.2).
164 // ------------------------------------------------------------
165 let (alpha, beta, gamma) = if query_embed.is_some() {
166 (0.5_f32, 0.3_f32, 0.2_f32)
167 } else {
168 (0.8_f32, 0.0_f32, 0.2_f32)
169 };
170
171 // Materialise centralities in canonical order AND find their max.
172 let mut centralities_canon: Vec<f32> = Vec::with_capacity(perm.len());
173 for &orig_i in &perm {
174 let c = centrality(orig_i);
175 centralities_canon.push(c.max(0.0));
176 }
177 let max_centrality = centralities_canon
178 .iter()
179 .copied()
180 .fold(0.0_f32, f32::max)
181 .max(f32::EPSILON); // avoid /0
182
183 let base_scores: Vec<f32> = embeds
184 .iter()
185 .enumerate()
186 .map(|(i, v)| {
187 let s_cent = cosine(v, ¢roid);
188 let s_query = query_embed.map_or(0.0, |q| cosine(v, q));
189 let s_centrality = centralities_canon[i] / max_centrality;
190 alpha * s_cent + beta * s_query + gamma * s_centrality
191 })
192 .collect();
193
194 // ------------------------------------------------------------
195 // Step 5: MMR greedy selection.
196 //
197 // effective(i) = (1 - lambda) * base_scores[i]
198 // - lambda * max_{j in picked} cos(v_i, v_j)
199 //
200 // Note: the spec text in the worktree task says the penalty is
201 // `lambda * max(cos(..., picked))`. The standard MMR formulation
202 // balances relevance and diversity as
203 // MMR = lambda * rel - (1 - lambda) * max_sim
204 // We follow the standard interpretation with `mmr_lambda` being
205 // the *diversity* weight (high lambda -> strong penalty), which
206 // matches the spec's "diversity tradeoff" language and the
207 // lambda=0.5 default.
208 // ------------------------------------------------------------
209 let lambda = mmr_lambda.clamp(0.0, 1.0);
210 let k_cap = k.min(embeds.len());
211 let mut picked: Vec<usize> = Vec::with_capacity(k_cap);
212 let mut picked_set = vec![false; embeds.len()];
213 let mut out_sentences: Vec<String> = Vec::with_capacity(k_cap);
214 let mut out_scores: Vec<f32> = Vec::with_capacity(k_cap);
215
216 while picked.len() < k_cap {
217 let mut best_idx: Option<usize> = None;
218 let mut best_score = f32::NEG_INFINITY;
219
220 for i in 0..embeds.len() {
221 if picked_set[i] {
222 continue;
223 }
224 // MMR penalty: max cosine similarity to any already-picked
225 // sentence. Clamped to [0.0, 1.0] so that anti-correlated
226 // vectors (which the MockEmbedder can produce) cannot turn
227 // the penalty into a spurious BONUS. With normalized
228 // MiniLM embeddings cosines are in [0,1] already, but the
229 // clamp keeps the invariant "effective_score is
230 // non-increasing across greedy picks" under every embedder.
231 let penalty = if picked.is_empty() {
232 0.0
233 } else {
234 picked
235 .iter()
236 .map(|&j| cosine(&embeds[i], &embeds[j]).clamp(0.0, 1.0))
237 .fold(0.0_f32, f32::max)
238 };
239 let eff = (1.0 - lambda) * base_scores[i] - lambda * penalty;
240
241 // Tie-break: lexicographic on the sentence text.
242 let is_better = match best_idx {
243 None => true,
244 Some(bi) => {
245 if eff > best_score {
246 true
247 } else if (eff - best_score).abs() < f32::EPSILON {
248 texts[i] < texts[bi]
249 } else {
250 false
251 }
252 }
253 };
254 if is_better {
255 best_idx = Some(i);
256 best_score = eff;
257 }
258 }
259
260 if let Some(bi) = best_idx {
261 picked.push(bi);
262 picked_set[bi] = true;
263 out_sentences.push(texts[bi].to_owned());
264 out_scores.push(best_score);
265 } else {
266 break;
267 }
268 }
269
270 Ok(Summary {
271 sentences: out_sentences,
272 scores: out_scores,
273 })
274}
275
/// Cosine similarity of two equal-length vectors.
///
/// Returns `0.0` when either vector is (near-)zero-norm, and also when
/// the quotient would not be finite (NaN or infinite components, or
/// squared components that overflow `f32`), so the result is never NaN
/// and the function never panics in release builds. In debug builds a
/// length mismatch trips a `debug_assert`; in release the longer vector
/// is truncated by `zip`.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "cosine: dim mismatch");
    let mut dot = 0.0_f32;
    let mut na = 0.0_f32;
    let mut nb = 0.0_f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        na += x * x;
        nb += y * y;
    }
    let denom = na.sqrt() * nb.sqrt();
    if denom <= f32::EPSILON {
        return 0.0;
    }
    // `denom` may be NaN or +inf here (the comparison above is false for
    // NaN), in which case the quotient can be NaN; map that to "no
    // similarity" rather than leaking NaN into score comparisons.
    let sim = dot / denom;
    if sim.is_finite() { sim } else { 0.0 }
}
295
#[cfg(test)]
mod tests {
    use super::*;
    use mnem_embed_providers::MockEmbedder;

    /// Builds the mock embedder shared by the tests below (the second
    /// constructor argument is presumably the embedding dimension).
    fn make_mock() -> MockEmbedder {
        MockEmbedder::new("test:mock", 32)
    }

    /// No sentences in, no sentences and no scores out.
    #[test]
    fn empty_input_returns_empty_summary() {
        let embedder = make_mock();
        let summary = summarize_community(&[], &embedder, None, &|_| 1.0, 5, 0.5).unwrap();
        assert!(summary.sentences.is_empty());
        assert!(summary.scores.is_empty());
    }

    /// Asking for zero sentences yields an empty selection.
    #[test]
    fn k_zero_returns_empty() {
        let embedder = make_mock();
        let input = vec!["a".to_string(), "b".to_string()];
        let summary = summarize_community(&input, &embedder, None, &|_| 1.0, 0, 0.5).unwrap();
        assert!(summary.sentences.is_empty());
    }

    /// A `k` larger than the input is capped at the number of sentences.
    #[test]
    fn k_larger_than_n_is_clamped() {
        let embedder = make_mock();
        let input = vec!["a".to_string(), "b".to_string()];
        let summary = summarize_community(&input, &embedder, None, &|_| 1.0, 99, 0.5).unwrap();
        assert_eq!(summary.sentences.len(), 2);
    }
}
328}