1use mnem_embed_providers::Embedder;
24use tracing::trace;
25use unicode_segmentation::UnicodeSegmentation;
26
27use crate::traits::{Entity, ExtractionSource, Extractor, Relation};
28
/// Default inclusive `(min, max)` number of words per candidate phrase.
pub const DEFAULT_NGRAM_RANGE: (usize, usize) = (1, 3);
/// Default upper bound on the number of entities returned per extraction.
pub const DEFAULT_TOP_K: usize = 10;
/// Default weight applied to the redundancy penalty during MMR selection.
pub const DEFAULT_MMR_DIVERSITY: f32 = 0.5;
38
/// Keyword extractor that ranks n-gram candidates from a text by the cosine
/// similarity of their embedding to the embedding of the whole chunk, then
/// picks a diverse subset with maximal marginal relevance (MMR).
pub struct KeyBertExtractor<'a> {
    /// Embedder used to embed each candidate phrase (batched, with a
    /// per-candidate fallback when the batch call fails).
    pub embedder: &'a dyn Embedder,
    /// Maximum number of entities returned by one extraction.
    pub top_k: usize,
    /// Inclusive `(min, max)` number of words per candidate phrase.
    pub ngram_range: (usize, usize),
    /// Weight of the redundancy penalty in the MMR score
    /// (`similarity - mmr_diversity * redundancy`).
    pub mmr_diversity: f32,
}
56
57impl std::fmt::Debug for KeyBertExtractor<'_> {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("KeyBertExtractor")
63 .field("embedder_model", &self.embedder.model())
64 .field("embedder_dim", &self.embedder.dim())
65 .field("top_k", &self.top_k)
66 .field("ngram_range", &self.ngram_range)
67 .field("mmr_diversity", &self.mmr_diversity)
68 .finish()
69 }
70}
71
72impl<'a> KeyBertExtractor<'a> {
73 #[must_use]
75 pub fn new(embedder: &'a dyn Embedder) -> Self {
76 Self {
77 embedder,
78 top_k: DEFAULT_TOP_K,
79 ngram_range: DEFAULT_NGRAM_RANGE,
80 mmr_diversity: DEFAULT_MMR_DIVERSITY,
81 }
82 }
83
84 #[must_use]
86 pub const fn with_top_k(mut self, k: usize) -> Self {
87 self.top_k = k;
88 self
89 }
90
91 #[must_use]
94 pub const fn with_ngram_range(mut self, min: usize, max: usize) -> Self {
95 let min = if min == 0 { 1 } else { min };
96 let max = if max < min { min } else { max };
97 self.ngram_range = (min, max);
98 self
99 }
100
101 #[must_use]
105 pub fn with_mmr_diversity(mut self, lambda: f32) -> Self {
106 self.mmr_diversity = lambda.clamp(0.0, 1.0);
107 self
108 }
109}
110
111impl Extractor for KeyBertExtractor<'_> {
112 fn extract_entities(&self, text: &str, chunk_embed: &[f32]) -> Vec<Entity> {
113 let words: Vec<(usize, &str)> = text.unicode_word_indices().collect();
115 if words.is_empty() || chunk_embed.is_empty() {
116 return Vec::new();
117 }
118
119 let (min_n, max_n) = self.ngram_range;
122 let mut candidates: Vec<Candidate> = Vec::new();
123 let mut seen_keys: std::collections::BTreeMap<String, usize> =
124 std::collections::BTreeMap::new();
125 for start_idx in 0..words.len() {
126 for n in min_n..=max_n {
127 if start_idx + n > words.len() {
128 break;
129 }
130 let (first_byte, first_tok) = words[start_idx];
131 let (last_byte, last_tok) = words[start_idx + n - 1];
132 let end_byte = last_byte + last_tok.len();
133 let surface = &text[first_byte..end_byte];
136 let normalised = normalise(surface);
137 if normalised.is_empty() {
138 continue;
139 }
140 if (start_idx..start_idx + n).all(|i| is_stopword(words[i].1)) {
143 continue;
144 }
145 if !normalised.chars().any(char::is_alphanumeric) {
148 continue;
149 }
150 if n == 1 && first_tok.chars().count() < 2 {
153 continue;
154 }
155 let key = normalised.clone();
156 if let std::collections::btree_map::Entry::Vacant(e) = seen_keys.entry(key.clone())
157 {
158 e.insert(candidates.len());
159 candidates.push(Candidate {
160 key,
161 surface: surface.to_string(),
162 span: (first_byte, end_byte),
163 });
164 }
165 }
166 }
167
168 if candidates.is_empty() {
169 return Vec::new();
170 }
171
172 candidates.sort_by(|a, b| a.key.cmp(&b.key));
176
177 let mut scored: Vec<Scored> = Vec::with_capacity(candidates.len());
184 let surfaces: Vec<&str> = candidates.iter().map(|c| c.surface.as_str()).collect();
185 match self.embedder.embed_batch(&surfaces) {
186 Ok(vecs) => {
187 for (c, vec) in candidates.iter().zip(vecs) {
188 if vec.len() != chunk_embed.len() {
189 trace!(
190 cand = %c.key,
191 expected = chunk_embed.len(),
192 got = vec.len(),
193 "dim mismatch, skipping candidate",
194 );
195 continue;
196 }
197 let sim = cosine(&vec, chunk_embed);
198 scored.push(Scored {
199 candidate: c.clone(),
200 embed: vec,
201 sim,
202 });
203 }
204 }
205 Err(batch_err) => {
206 trace!(
211 ?batch_err,
212 "embed_batch failed, falling back to per-candidate"
213 );
214 for c in &candidates {
215 match self.embedder.embed(&c.surface) {
216 Ok(vec) => {
217 if vec.len() != chunk_embed.len() {
218 trace!(
219 cand = %c.key,
220 expected = chunk_embed.len(),
221 got = vec.len(),
222 "dim mismatch, skipping candidate",
223 );
224 continue;
225 }
226 let sim = cosine(&vec, chunk_embed);
227 scored.push(Scored {
228 candidate: c.clone(),
229 embed: vec,
230 sim,
231 });
232 }
233 Err(err) => {
234 trace!(cand = %c.key, ?err, "embed failed, skipping candidate");
235 }
236 }
237 }
238 }
239 }
240 if scored.is_empty() {
241 return Vec::new();
242 }
243
244 let picks = mmr_select(&scored, self.top_k, self.mmr_diversity);
246 picks
247 .into_iter()
248 .map(|(s, mmr_score)| Entity {
249 mention: s.candidate.surface.clone(),
250 #[allow(clippy::cast_possible_truncation)]
251 score: (mmr_score as f32).clamp(-1.0, 1.0),
252 span: s.candidate.span,
253 })
254 .collect()
255 }
256
257 fn extract_relations(&self, text: &str, entities: &[Entity]) -> Vec<Relation> {
258 crate::cooccurrence::mine_relations(
259 text,
260 entities,
261 crate::cooccurrence::DEFAULT_PMI_THRESHOLD,
262 ExtractionSource::Statistical,
263 )
264 }
265}
266
/// One candidate phrase taken from the input text.
#[derive(Debug, Clone)]
struct Candidate {
    /// Normalised form (lowercased, whitespace-collapsed); used to
    /// deduplicate candidates, to order them, and to break MMR ties.
    key: String,
    /// The phrase exactly as it appears in the input text.
    surface: String,
    /// Byte offsets `(start, end)` of `surface` in the input text; `end` is
    /// exclusive.
    span: (usize, usize),
}
278
/// A candidate together with its embedding and relevance score.
#[derive(Debug, Clone)]
struct Scored {
    /// The phrase that was embedded.
    candidate: Candidate,
    /// Embedding of the candidate's surface text.
    embed: Vec<f32>,
    /// Cosine similarity between `embed` and the chunk embedding.
    sim: f64,
}
285
/// Returns `s` lowercased with leading and trailing whitespace removed and
/// every internal run of whitespace collapsed to a single ASCII space.
///
/// Lowercasing is done per character (`char::to_lowercase`), so it is not
/// context-sensitive.
fn normalise(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split_whitespace() {
        // Separator goes before every word except the first one.
        if !out.is_empty() {
            out.push(' ');
        }
        out.extend(word.chars().flat_map(char::to_lowercase));
    }
    out
}
309
/// Lowercase English stopwords.
///
/// The list must stay sorted in byte order: `is_stopword` looks tokens up
/// with `binary_search`.
#[rustfmt::skip]
const STOPWORDS: &[&str] = &[
    "a", "an", "and", "are", "as", "at", "be", "but", "by", "for",
    "from", "has", "have", "he", "her", "hers", "him", "his", "i",
    "if", "in", "into", "is", "it", "its", "me", "my", "no", "not",
    "of", "on", "or", "our", "ours", "over", "she", "so", "that",
    "the", "their", "theirs", "them", "then", "there", "they",
    "this", "those", "to", "too", "us", "was", "we", "were", "what",
    "when", "where", "which", "while", "who", "whom", "why", "will",
    "with", "you", "your", "yours",
];
325
326fn is_stopword(tok: &str) -> bool {
327 let lc: String = tok.chars().flat_map(char::to_lowercase).collect();
328 STOPWORDS.binary_search(&lc.as_str()).is_ok()
329}
330
/// Cosine similarity of `a` and `b`, accumulated in `f64`.
///
/// Returns `0.0` when either vector has zero magnitude (including empty
/// input). Both slices are expected to have the same length; this is checked
/// in debug builds only, and in release builds the longer slice is truncated
/// by `zip`.
fn cosine(a: &[f32], b: &[f32]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );
    // Kept as `<=` checks (not `> 0.0`) so NaN magnitudes fall through to the
    // division and propagate as NaN.
    if norm_a_sq <= 0.0 || norm_b_sq <= 0.0 {
        0.0
    } else {
        dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())
    }
}
350
351fn mmr_select(scored: &[Scored], top_k: usize, lambda: f32) -> Vec<(Scored, f64)> {
357 let lambda = f64::from(lambda);
358 let k = top_k.min(scored.len());
359 let mut picks: Vec<(Scored, f64)> = Vec::with_capacity(k);
360 let mut remaining: Vec<usize> = (0..scored.len()).collect();
361
362 while picks.len() < k && !remaining.is_empty() {
363 let mut best_idx_in_remaining: Option<usize> = None;
364 let mut best_score: f64 = f64::NEG_INFINITY;
365 let mut best_key: Option<&str> = None;
366 for (pos, &i) in remaining.iter().enumerate() {
367 let c = &scored[i];
368 let redundancy = picks
369 .iter()
370 .map(|(p, _)| cosine(&c.embed, &p.embed))
371 .fold(f64::NEG_INFINITY, f64::max)
372 .max(0.0_f64);
373 let redundancy = if picks.is_empty() { 0.0 } else { redundancy };
374 let mmr = c.sim - lambda * redundancy;
375 let tiebreak = c.candidate.key.as_str();
376 let better = mmr > best_score
377 || (approx_eq(mmr, best_score) && best_key.is_none_or(|bk| tiebreak < bk));
378 if better {
379 best_score = mmr;
380 best_idx_in_remaining = Some(pos);
381 best_key = Some(tiebreak);
382 }
383 }
384 match best_idx_in_remaining {
385 Some(pos) => {
386 let i = remaining.swap_remove(pos);
387 picks.push((scored[i].clone(), best_score));
388 }
389 None => break,
390 }
391 }
392 picks
393}
394
/// Returns `true` when `a` and `b` differ by strictly less than `1e-9`.
///
/// Any comparison involving NaN, or two infinities of the same sign (whose
/// difference is NaN), yields `false`.
fn approx_eq(a: f64, b: f64) -> bool {
    const TOLERANCE: f64 = 1e-9;
    let gap = (a - b).abs();
    gap < TOLERANCE
}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// `normalise` trims, collapses whitespace runs and lowercases.
    #[test]
    fn normalise_collapses_whitespace_and_lowercases() {
        let cases = [(" Hello World ", "hello world"), ("MixedCase", "mixedcase")];
        for (input, expected) in cases {
            assert_eq!(normalise(input), expected);
        }
    }

    /// `is_stopword` relies on `binary_search`, which needs a sorted table.
    #[test]
    fn stopwords_are_sorted_for_binary_search() {
        let in_order = STOPWORDS.windows(2).all(|pair| pair[0] <= pair[1]);
        assert!(in_order, "STOPWORDS must be sorted");
    }

    /// A non-zero vector has cosine similarity 1 with itself.
    #[test]
    fn cosine_identity() {
        let v = [1.0_f32, 2.0, 3.0];
        let c = cosine(&v, &v);
        assert!((c - 1.0).abs() < 1e-9, "cosine(v, v) = {c}");
    }

    /// A zero-magnitude operand short-circuits to exactly 0.
    #[test]
    fn cosine_zero_magnitude_returns_zero() {
        let zeros = [0.0_f32; 8];
        let ones = [1.0_f32; 8];
        assert_eq!(cosine(&zeros, &ones), 0.0);
    }
}