use mnem_embed_providers::Embedder;
/// A single summary sentence together with the score attached to it.
///
/// This is the row-oriented view of a [`Summary`]; see [`Summary::items`].
#[derive(Debug, Clone, PartialEq)]
pub struct SummaryItem {
/// Text of the sentence.
pub sentence: String,
/// Score associated with `sentence`.
pub score: f32,
}
/// Column-oriented summary: selected sentences and their scores, stored as
/// two parallel vectors.
///
/// `scores[i]` is the score for `sentences[i]`. When produced by
/// [`summarize_community`], entries appear in selection order and both
/// vectors have the same length.
#[derive(Debug, Clone, PartialEq)]
pub struct Summary {
/// Selected sentences, in the order they were picked.
pub sentences: Vec<String>,
/// Score recorded for each entry of `sentences`, at the same index.
pub scores: Vec<f32>,
}
impl Summary {
    /// Returns the summary as a list of sentence/score pairs.
    ///
    /// Sentences and scores are paired by index. If the two vectors have
    /// different lengths, pairing stops at the end of the shorter one.
    /// Each sentence is cloned into the returned items.
    #[must_use]
    pub fn items(&self) -> Vec<SummaryItem> {
        let owned_sentences = self.sentences.iter().cloned();
        let score_values = self.scores.iter().copied();
        owned_sentences
            .zip(score_values)
            .map(|(sentence, score)| SummaryItem { sentence, score })
            .collect()
    }
}
/// Builds an extractive summary by scoring every sentence and then greedily
/// selecting up to `k` of them with an MMR-style redundancy penalty.
///
/// # Algorithm
///
/// 1. Sentences are put into a canonical order: sorted by the BLAKE3 hash of
///    their bytes, with the sentence text as tie-breaker. Embedding and
///    selection operate on this order rather than on the caller's order.
/// 2. All sentences are embedded in a single `embed_batch` call and the mean
///    embedding (centroid) is computed.
/// 3. Each sentence gets a base score
///    `alpha * cos(sentence, centroid) + beta * cos(sentence, query)
///    + gamma * centrality / max_centrality`, with weights
///    `(0.5, 0.3, 0.2)` when `query_embed` is given and `(0.8, 0.0, 0.2)`
///    otherwise.
/// 4. Sentences are picked one at a time, each time maximising
///    `(1 - lambda) * base - lambda * penalty`, where `penalty` is the
///    largest cosine similarity (clamped to `[0, 1]`) to any sentence
///    already picked. Scores within `f32::EPSILON` of the current best are
///    resolved in favour of the lexicographically smaller sentence.
///
/// # Parameters
///
/// * `sentences` - candidate sentences.
/// * `embedder` - provides `dim()` and `embed_batch()`.
/// * `query_embed` - optional query embedding; its length must equal
///   `embedder.dim()`.
/// * `centrality` - called once per sentence with that sentence's index in
///   `sentences`; negative results are treated as `0.0`.
/// * `k` - maximum number of sentences to return.
/// * `mmr_lambda` - redundancy weight, clamped to `[0, 1]`.
///
/// # Returns
///
/// A [`Summary`] whose sentences are in selection order and whose scores are
/// the penalised scores at the moment each sentence was picked. The summary
/// is empty when `sentences` is empty or `k == 0`.
///
/// # Errors
///
/// * `EmbedError::DimMismatch` if `query_embed` has the wrong length. This is
///   checked before any embedding work is requested.
/// * Any error returned by `embedder.embed_batch`.
// Kept from the original declaration; the lint threshold is project-configurable.
#[allow(clippy::too_many_arguments)]
pub fn summarize_community(
    sentences: &[String],
    embedder: &dyn Embedder,
    query_embed: Option<&[f32]>,
    centrality: &dyn Fn(usize) -> f32,
    k: usize,
    mmr_lambda: f32,
) -> Result<Summary, mnem_embed_providers::EmbedError> {
    if sentences.is_empty() || k == 0 {
        return Ok(Summary {
            sentences: Vec::new(),
            scores: Vec::new(),
        });
    }

    // Validate the query up front so a malformed query never pays for a
    // (potentially expensive) batch embedding call.
    let dim = embedder.dim() as usize;
    if let Some(q) = query_embed
        && q.len() != dim
    {
        return Err(mnem_embed_providers::EmbedError::DimMismatch {
            expected: embedder.dim(),
            got: u32::try_from(q.len()).unwrap_or(u32::MAX),
        });
    }

    // Canonical order: by content hash, then by text. The cached-key sort
    // hashes each sentence once instead of on every comparison, and is stable
    // like the plain comparator sort.
    let mut perm: Vec<usize> = (0..sentences.len()).collect();
    perm.sort_by_cached_key(|&i| {
        let text = sentences[i].as_str();
        (*blake3::hash(text.as_bytes()).as_bytes(), text)
    });

    let texts: Vec<&str> = perm.iter().map(|&i| sentences[i].as_str()).collect();
    let embeds = embedder.embed_batch(&texts)?;

    // Mean embedding over all sentences.
    let mut centroid = vec![0.0_f32; dim];
    for v in &embeds {
        for (c, x) in centroid.iter_mut().zip(v.iter()) {
            *c += *x;
        }
    }
    let n_f = embeds.len() as f32;
    for c in &mut centroid {
        *c /= n_f;
    }

    // Without a query its weight is folded into the centroid term.
    let (alpha, beta, gamma) = if query_embed.is_some() {
        (0.5_f32, 0.3_f32, 0.2_f32)
    } else {
        (0.8_f32, 0.0_f32, 0.2_f32)
    };

    // Centrality is looked up by ORIGINAL index but stored in canonical order.
    let centralities_canon: Vec<f32> = perm
        .iter()
        .map(|&orig_i| centrality(orig_i).max(0.0))
        .collect();
    // Floor at EPSILON so an all-zero centrality never divides by zero.
    let max_centrality = centralities_canon
        .iter()
        .copied()
        .fold(0.0_f32, f32::max)
        .max(f32::EPSILON);

    let base_scores: Vec<f32> = embeds
        .iter()
        .enumerate()
        .map(|(i, v)| {
            let s_cent = cosine(v, &centroid);
            let s_query = query_embed.map_or(0.0, |q| cosine(v, q));
            let s_centrality = centralities_canon[i] / max_centrality;
            alpha * s_cent + beta * s_query + gamma * s_centrality
        })
        .collect();

    let lambda = mmr_lambda.clamp(0.0, 1.0);
    let n = embeds.len();
    let k_cap = k.min(n);
    let mut picked_set = vec![false; n];
    // Running maximum of the clamped similarity between each candidate and
    // the sentences picked so far; 0.0 while nothing has been picked.
    let mut max_sim_to_picked = vec![0.0_f32; n];
    let mut out_sentences: Vec<String> = Vec::with_capacity(k_cap);
    let mut out_scores: Vec<f32> = Vec::with_capacity(k_cap);

    while out_sentences.len() < k_cap {
        let mut best: Option<(usize, f32)> = None;
        for (i, &base) in base_scores.iter().enumerate() {
            if picked_set[i] {
                continue;
            }
            let eff = (1.0 - lambda) * base - lambda * max_sim_to_picked[i];
            let is_better = match best {
                None => true,
                // Strictly higher wins; near-ties go to the smaller text so
                // the outcome does not depend on iteration order.
                Some((bi, best_score)) => {
                    eff > best_score
                        || ((eff - best_score).abs() < f32::EPSILON && texts[i] < texts[bi])
                }
            };
            if is_better {
                best = Some((i, eff));
            }
        }

        let Some((bi, best_score)) = best else {
            break;
        };
        picked_set[bi] = true;
        out_sentences.push(texts[bi].to_owned());
        out_scores.push(best_score);

        // Fold the newly picked sentence into every remaining candidate's
        // redundancy penalty.
        for (i, sim) in max_sim_to_picked.iter_mut().enumerate() {
            if !picked_set[i] {
                *sim = sim.max(cosine(&embeds[i], &embeds[bi]).clamp(0.0, 1.0));
            }
        }
    }

    Ok(Summary {
        sentences: out_sentences,
        scores: out_scores,
    })
}
/// Cosine similarity between `a` and `b`.
///
/// Returns `0.0` when the product of the two norms is at most
/// `f32::EPSILON`, so zero vectors never cause a division by zero. In debug
/// builds a length mismatch trips an assertion; otherwise only the common
/// prefix of the two slices is compared.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "cosine: dim mismatch");
    // Accumulate dot product and both squared norms in a single pass.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b.iter()).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq_a, sq_b), (x, y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );
    let denom = sq_norm_a.sqrt() * sq_norm_b.sqrt();
    if denom <= f32::EPSILON {
        return 0.0;
    }
    dot / denom
}
#[cfg(test)]
mod tests {
    use super::*;
    use mnem_embed_providers::MockEmbedder;

    /// Constructs the mock embedder shared by every test in this module.
    fn make_mock() -> MockEmbedder {
        MockEmbedder::new("test:mock", 32)
    }

    /// No sentences in means no sentences and no scores out.
    #[test]
    fn empty_input_returns_empty_summary() {
        let embedder = make_mock();
        let uniform = |_: usize| 1.0_f32;
        let summary = summarize_community(&[], &embedder, None, &uniform, 5, 0.5).unwrap();
        assert!(summary.sentences.is_empty());
        assert!(summary.scores.is_empty());
    }

    /// Asking for zero sentences yields an empty summary even with input.
    #[test]
    fn k_zero_returns_empty() {
        let embedder = make_mock();
        let uniform = |_: usize| 1.0_f32;
        let inputs = vec![String::from("a"), String::from("b")];
        let summary = summarize_community(&inputs, &embedder, None, &uniform, 0, 0.5).unwrap();
        assert!(summary.sentences.is_empty());
    }

    /// A `k` beyond the number of sentences returns every sentence once.
    #[test]
    fn k_larger_than_n_is_clamped() {
        let embedder = make_mock();
        let uniform = |_: usize| 1.0_f32;
        let inputs = vec![String::from("a"), String::from("b")];
        let summary = summarize_community(&inputs, &embedder, None, &uniform, 99, 0.5).unwrap();
        assert_eq!(summary.sentences.len(), 2);
    }
}