// mnem_extract/keybert.rs
1//! KeyBERT-style statistical keyword / entity extractor.
2//!
3//! The algorithm mirrors the canonical KeyBERT pipeline (Grootendorst
4//! 2020) but runs entirely against mnem's existing synchronous
5//! [`Embedder`] trait - no Python, no sklearn, no ONNX-only path.
6//!
7//! 1. **Tokenise** the chunk text with `unicode-segmentation` into
8//! words and sentence boundaries.
9//! 2. **Enumerate n-gram candidates** of length `ngram_range.0 ..=
10//! ngram_range.1`, skipping candidates that are pure stop-word
11//! sequences.
12//! 3. **Deduplicate** candidates to their earliest span; sort the
13//! deduped list lexicographically for determinism.
14//! 4. **Embed** each candidate via `Embedder::embed`; compute cosine
15//! similarity against the caller-supplied `chunk_embed`.
16//! 5. **MMR-diversify**: iteratively pick the highest-scoring
17//! candidate after subtracting `mmr_diversity * max_sim_to_picked`.
18//! Stable lex tiebreaks on exact-tie scores.
19//!
20//! The implementation allocates once per call and keeps all scoring in
21//! `f64` to dodge `f32` summation drift on long inputs.
22
23use mnem_embed_providers::Embedder;
24use tracing::trace;
25use unicode_segmentation::UnicodeSegmentation;
26
27use crate::traits::{Entity, ExtractionSource, Extractor, Relation};
28
/// Default KeyBERT extractor parameters. Picked to match the KeyBERT
/// paper's out-of-the-box behaviour: 1–3-grams, top-10 keywords,
/// MMR diversity 0.5.
pub const DEFAULT_NGRAM_RANGE: (usize, usize) = (1, 3);
/// Number of entities returned per call by default. Used by
/// [`KeyBertExtractor::new`]; override with `with_top_k`.
pub const DEFAULT_TOP_K: usize = 10;
/// Default MMR diversity coefficient (λ in the KeyBERT paper).
/// 0.0 → pure cosine ranking; 1.0 → maximal redundancy penalty.
pub const DEFAULT_MMR_DIVERSITY: f32 = 0.5;
38
/// KeyBERT-style extractor.
///
/// Holds a borrowed [`Embedder`] reference; the caller owns the
/// concrete provider (Ollama / OpenAI / ONNX / mock) and threads it in
/// for the duration of an ingest run.
pub struct KeyBertExtractor<'a> {
    /// Embedder used to encode candidate n-grams. MUST be the same
    /// provider + model that produced `chunk_embed`, otherwise cosine
    /// similarity is meaningless.
    pub embedder: &'a dyn Embedder,
    /// Number of entities to return per call. See [`DEFAULT_TOP_K`].
    pub top_k: usize,
    /// Inclusive `(min_n, max_n)` n-gram length range.
    /// `with_ngram_range` enforces `1 <= min_n <= max_n`; direct field
    /// construction bypasses that clamp.
    pub ngram_range: (usize, usize),
    /// MMR diversity coefficient in `[0.0, 1.0]`. Higher values push
    /// the selection away from candidates similar to earlier picks.
    pub mmr_diversity: f32,
}
56
57impl std::fmt::Debug for KeyBertExtractor<'_> {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        // `dyn Embedder` is not Debug; project the relevant fields
60        // instead so tracing / assertion output still identifies the
61        // configured model.
62        f.debug_struct("KeyBertExtractor")
63            .field("embedder_model", &self.embedder.model())
64            .field("embedder_dim", &self.embedder.dim())
65            .field("top_k", &self.top_k)
66            .field("ngram_range", &self.ngram_range)
67            .field("mmr_diversity", &self.mmr_diversity)
68            .finish()
69    }
70}
71
72impl<'a> KeyBertExtractor<'a> {
73    /// Construct a KeyBERT extractor with default parameters.
74    #[must_use]
75    pub fn new(embedder: &'a dyn Embedder) -> Self {
76        Self {
77            embedder,
78            top_k: DEFAULT_TOP_K,
79            ngram_range: DEFAULT_NGRAM_RANGE,
80            mmr_diversity: DEFAULT_MMR_DIVERSITY,
81        }
82    }
83
84    /// Override `top_k`. Returns `self` for chaining.
85    #[must_use]
86    pub const fn with_top_k(mut self, k: usize) -> Self {
87        self.top_k = k;
88        self
89    }
90
91    /// Override n-gram range. `min` must be >= 1; callers that pass 0
92    /// are clamped to 1. Returns `self` for chaining.
93    #[must_use]
94    pub const fn with_ngram_range(mut self, min: usize, max: usize) -> Self {
95        let min = if min == 0 { 1 } else { min };
96        let max = if max < min { min } else { max };
97        self.ngram_range = (min, max);
98        self
99    }
100
101    /// Override MMR diversity coefficient. Returns `self` for
102    /// chaining. Callers passing values outside `[0.0, 1.0]` get them
103    /// clamped.
104    #[must_use]
105    pub fn with_mmr_diversity(mut self, lambda: f32) -> Self {
106        self.mmr_diversity = lambda.clamp(0.0, 1.0);
107        self
108    }
109}
110
111impl Extractor for KeyBertExtractor<'_> {
112    fn extract_entities(&self, text: &str, chunk_embed: &[f32]) -> Vec<Entity> {
113        // 1. collect word spans.
114        let words: Vec<(usize, &str)> = text.unicode_word_indices().collect();
115        if words.is_empty() || chunk_embed.is_empty() {
116            return Vec::new();
117        }
118
119        // 2. enumerate n-gram candidates of length min..=max, deduped
120        // to their first occurrence span.
121        let (min_n, max_n) = self.ngram_range;
122        let mut candidates: Vec<Candidate> = Vec::new();
123        let mut seen_keys: std::collections::BTreeMap<String, usize> =
124            std::collections::BTreeMap::new();
125        for start_idx in 0..words.len() {
126            for n in min_n..=max_n {
127                if start_idx + n > words.len() {
128                    break;
129                }
130                let (first_byte, first_tok) = words[start_idx];
131                let (last_byte, last_tok) = words[start_idx + n - 1];
132                let end_byte = last_byte + last_tok.len();
133                // collect surface form over the exact byte span so
134                // punctuation between words is preserved when present.
135                let surface = &text[first_byte..end_byte];
136                let normalised = normalise(surface);
137                if normalised.is_empty() {
138                    continue;
139                }
140                // reject pure-stopword n-grams (single-token
141                // stopwords are also rejected).
142                if (start_idx..start_idx + n).all(|i| is_stopword(words[i].1)) {
143                    continue;
144                }
145                // skip candidates that don't have any alphanumeric
146                // content (e.g. all-punctuation windows).
147                if !normalised.chars().any(char::is_alphanumeric) {
148                    continue;
149                }
150                // For short single-word candidates, require length > 1
151                // to drop noise like "a", "I".
152                if n == 1 && first_tok.chars().count() < 2 {
153                    continue;
154                }
155                let key = normalised.clone();
156                if let std::collections::btree_map::Entry::Vacant(e) = seen_keys.entry(key.clone())
157                {
158                    e.insert(candidates.len());
159                    candidates.push(Candidate {
160                        key,
161                        surface: surface.to_string(),
162                        span: (first_byte, end_byte),
163                    });
164                }
165            }
166        }
167
168        if candidates.is_empty() {
169            return Vec::new();
170        }
171
172        // 3. sort candidates lexicographically before embedding - this
173        // is the determinism anchor even if the enumeration loop
174        // order changes.
175        candidates.sort_by(|a, b| a.key.cmp(&b.key));
176
177        // 4. embed every candidate via the provider's batch call when
178        // available (5-10x faster on ONNX/OpenAI vs the per-text
179        // loop; Ollama transparently falls back to sequential
180        // `embed` per its `embed_batch` default impl). Compute
181        // cosine vs `chunk_embed`. Result vectors line up with
182        // `candidates` index-for-index.
183        let mut scored: Vec<Scored> = Vec::with_capacity(candidates.len());
184        let surfaces: Vec<&str> = candidates.iter().map(|c| c.surface.as_str()).collect();
185        match self.embedder.embed_batch(&surfaces) {
186            Ok(vecs) => {
187                for (c, vec) in candidates.iter().zip(vecs) {
188                    if vec.len() != chunk_embed.len() {
189                        trace!(
190                        cand = %c.key,
191                        expected = chunk_embed.len(),
192                        got = vec.len(),
193                        "dim mismatch, skipping candidate",
194                        );
195                        continue;
196                    }
197                    let sim = cosine(&vec, chunk_embed);
198                    scored.push(Scored {
199                        candidate: c.clone(),
200                        embed: vec,
201                        sim,
202                    });
203                }
204            }
205            Err(batch_err) => {
206                // Per-candidate fallback: a single bad input shouldn't
207                // wipe a chunk's entire extraction. Keep the same
208                // skip-on-error / dim-mismatch contract as the
209                // pre-batch implementation.
210                trace!(
211                    ?batch_err,
212                    "embed_batch failed, falling back to per-candidate"
213                );
214                for c in &candidates {
215                    match self.embedder.embed(&c.surface) {
216                        Ok(vec) => {
217                            if vec.len() != chunk_embed.len() {
218                                trace!(
219                                cand = %c.key,
220                                expected = chunk_embed.len(),
221                                got = vec.len(),
222                                "dim mismatch, skipping candidate",
223                                );
224                                continue;
225                            }
226                            let sim = cosine(&vec, chunk_embed);
227                            scored.push(Scored {
228                                candidate: c.clone(),
229                                embed: vec,
230                                sim,
231                            });
232                        }
233                        Err(err) => {
234                            trace!(cand = %c.key, ?err, "embed failed, skipping candidate");
235                        }
236                    }
237                }
238            }
239        }
240        if scored.is_empty() {
241            return Vec::new();
242        }
243
244        // 5. MMR diversify.
245        let picks = mmr_select(&scored, self.top_k, self.mmr_diversity);
246        picks
247            .into_iter()
248            .map(|(s, mmr_score)| Entity {
249                mention: s.candidate.surface.clone(),
250                #[allow(clippy::cast_possible_truncation)]
251                score: (mmr_score as f32).clamp(-1.0, 1.0),
252                span: s.candidate.span,
253            })
254            .collect()
255    }
256
257    fn extract_relations(&self, text: &str, entities: &[Entity]) -> Vec<Relation> {
258        crate::cooccurrence::mine_relations(
259            text,
260            entities,
261            crate::cooccurrence::DEFAULT_PMI_THRESHOLD,
262            ExtractionSource::Statistical,
263        )
264    }
265}
266
267// -------- internals --------
268
/// A deduplicated n-gram candidate awaiting embedding and scoring.
#[derive(Debug, Clone)]
struct Candidate {
    /// Lower-cased, whitespace-normalised lookup key used for dedup
    /// and lex sort. The original surface form is preserved separately
    /// so rendering stays faithful to the source.
    key: String,
    /// Verbatim slice of the source text (case and punctuation kept).
    surface: String,
    /// `(start, end)` byte offsets of `surface` within the chunk text.
    span: (usize, usize),
}
278
/// A candidate paired with its embedding and its cosine similarity to
/// the chunk embedding; the unit of selection for `mmr_select`.
#[derive(Debug, Clone)]
struct Scored {
    /// The originating candidate (key, surface form, byte span).
    candidate: Candidate,
    /// The candidate's own embedding, kept for pairwise redundancy
    /// checks during MMR.
    embed: Vec<f32>,
    /// Cosine similarity vs the chunk embedding, in `f64` precision.
    sim: f64,
}
285
/// Lower-case + collapse whitespace. Pure-Rust, no regex so the crate
/// stays tiny; the extractor runs per candidate so the loop is cheap.
///
/// Lowercasing is char-by-char (`char::to_lowercase`), matching the
/// folding in `is_stopword`; leading/trailing whitespace is dropped
/// and internal runs collapse to a single ASCII space.
fn normalise(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let lowered_words = s
        .split_whitespace()
        .map(|word| word.chars().flat_map(char::to_lowercase).collect::<String>());
    for (idx, word) in lowered_words.enumerate() {
        if idx > 0 {
            out.push(' ');
        }
        out.push_str(&word);
    }
    out
}
309
/// Minimal English stop-word list. Deliberately short - the goal is
/// to reject pure-stopword n-grams like "the dog" from dominating the
/// top-k, not to replicate NLTK's 180-word list. Callers that need
/// broader coverage should post-filter the returned Entities.
///
/// INVARIANT: entries must stay in ascending lexicographic order -
/// `is_stopword` binary-searches this table (guarded by the
/// `stopwords_are_sorted_for_binary_search` test below).
#[rustfmt::skip]
const STOPWORDS: &[&str] = &[
 "a", "an", "and", "are", "as", "at", "be", "but", "by", "for",
 "from", "has", "have", "he", "her", "hers", "him", "his", "i",
 "if", "in", "into", "is", "it", "its", "me", "my", "no", "not",
 "of", "on", "or", "our", "ours", "over", "she", "so", "that",
 "the", "their", "theirs", "them", "then", "there", "they",
 "this", "those", "to", "too", "us", "was", "we", "were", "what",
 "when", "where", "which", "while", "who", "whom", "why", "will",
 "with", "you", "your", "yours",
];
325
326fn is_stopword(tok: &str) -> bool {
327    let lc: String = tok.chars().flat_map(char::to_lowercase).collect();
328    STOPWORDS.binary_search(&lc.as_str()).is_ok()
329}
330
/// Cosine similarity in `f64` to avoid `f32` accumulation drift.
/// Returns 0.0 for zero-magnitude inputs rather than `NaN`.
fn cosine(a: &[f32], b: &[f32]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    // Single pass accumulating dot product and both squared norms.
    let (dot, norm_a, norm_b) = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| (f64::from(x), f64::from(y)))
        .fold((0.0_f64, 0.0_f64, 0.0_f64), |(d, sa, sb), (xf, yf)| {
            (d + xf * yf, sa + xf * xf, sb + yf * yf)
        });
    // Same guard as before: any non-positive norm short-circuits to
    // 0.0 instead of dividing by zero.
    if norm_a <= 0.0 || norm_b <= 0.0 {
        0.0
    } else {
        dot / (norm_a.sqrt() * norm_b.sqrt())
    }
}
350
351/// Iteratively select up to `top_k` candidates by MMR.
352///
353/// Score function: `sim(cand, chunk) - lambda * max_i sim(cand, picked_i)`.
354/// Ties on score are broken by the candidate `key` (lex order) for
355/// determinism across runs / platforms.
356fn mmr_select(scored: &[Scored], top_k: usize, lambda: f32) -> Vec<(Scored, f64)> {
357    let lambda = f64::from(lambda);
358    let k = top_k.min(scored.len());
359    let mut picks: Vec<(Scored, f64)> = Vec::with_capacity(k);
360    let mut remaining: Vec<usize> = (0..scored.len()).collect();
361
362    while picks.len() < k && !remaining.is_empty() {
363        let mut best_idx_in_remaining: Option<usize> = None;
364        let mut best_score: f64 = f64::NEG_INFINITY;
365        let mut best_key: Option<&str> = None;
366        for (pos, &i) in remaining.iter().enumerate() {
367            let c = &scored[i];
368            let redundancy = picks
369                .iter()
370                .map(|(p, _)| cosine(&c.embed, &p.embed))
371                .fold(f64::NEG_INFINITY, f64::max)
372                .max(0.0_f64);
373            let redundancy = if picks.is_empty() { 0.0 } else { redundancy };
374            let mmr = c.sim - lambda * redundancy;
375            let tiebreak = c.candidate.key.as_str();
376            let better = mmr > best_score
377                || (approx_eq(mmr, best_score) && best_key.is_none_or(|bk| tiebreak < bk));
378            if better {
379                best_score = mmr;
380                best_idx_in_remaining = Some(pos);
381                best_key = Some(tiebreak);
382            }
383        }
384        match best_idx_in_remaining {
385            Some(pos) => {
386                let i = remaining.swap_remove(pos);
387                picks.push((scored[i].clone(), best_score));
388            }
389            None => break,
390        }
391    }
392    picks
393}
394
/// `f64` equality up to 1e-9; enough for our cosine-derived tiebreaks.
fn approx_eq(a: f64, b: f64) -> bool {
    // Absolute (not relative) tolerance: all compared values are MMR
    // scores bounded by cosine magnitudes, so the scale is ~1.
    const TOLERANCE: f64 = 1e-9;
    (a - b).abs() < TOLERANCE
}
399
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalise_collapses_whitespace_and_lowercases() {
        // trims edges, collapses runs, folds case.
        let cases = [(" Hello World ", "hello world"), ("MixedCase", "mixedcase")];
        for (input, expected) in cases {
            assert_eq!(normalise(input), expected);
        }
    }

    #[test]
    fn stopwords_are_sorted_for_binary_search() {
        // `is_stopword` binary-searches the table, so it must be in
        // non-descending lexicographic order.
        assert!(STOPWORDS.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn cosine_identity() {
        let v = vec![1.0_f32, 2.0, 3.0];
        let sim = cosine(&v, &v);
        assert!((sim - 1.0).abs() < 1e-9, "cosine(v, v) = {sim}");
    }

    #[test]
    fn cosine_zero_magnitude_returns_zero() {
        let zeros = vec![0.0_f32; 8];
        let ones = vec![1.0_f32; 8];
        assert_eq!(cosine(&zeros, &ones), 0.0);
    }
}