// mnem_ingest/extract_keybert.rs
//! Adapter that lets a [`mnem_extract::KeyBertExtractor`] drop into the
//! [`crate::extract::Extractor`] slot on [`crate::pipeline::Ingester`].
//!
//! Two-phase contract:
//!
//! 1. [`Extractor::prepare`] is called once per file with every
//!    section the parser produced. The adapter collects unique
//!    section texts and runs them through `Embedder::embed_batch` in
//!    a single ORT session.run, caching the resulting vectors.
//! 2. [`Extractor::extract_entities`] is then called per (section,
//!    chunk) pair by the pipeline; the adapter looks up the cached
//!    section embedding and runs KeyBERT candidate ranking + MMR
//!    against it. On a cache miss (e.g. a caller that did not invoke
//!    `prepare`) the adapter falls back to a single-section
//!    `Embedder::embed`, preserving the original drop-in contract.
//!
//! Pre-batching the section pass turns the Bible-scale wall-time
//! bottleneck (~1 sequential ORT call per section, dominated by long
//! chapters) into a single ORT batch per file. The same vectors land
//! in `Node.embed`; only the wall time changes.
//!
//! Relations returned by the statistical miner are mapped to the
//! existing [`RelationSpan`] shape with predicate `"co_occurs_with"`.
//!
//! Gated behind the `keybert` cargo feature so callers who ship only
//! the rule-based baseline pay zero compile / binary cost.

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use mnem_embed_providers::Embedder;
use mnem_extract::{Extractor as StatisticalExtractor, KeyBertExtractor};

use crate::extract::{EntitySpan, Extractor, RelationSpan};
use crate::types::Section;

/// Predicate emitted for co-occurrence edges by the KeyBERT adapter.
/// Mirrors the string the rule-based extractor uses so downstream
/// graph consumers don't need to learn a new vocabulary.
pub const KEYBERT_RELATION_LABEL: &str = "co_occurs_with";

/// Lower clamp bound for the confidence stamped onto every
/// [`EntitySpan`] the adapter emits. Statistical extraction has a
/// genuine score per candidate (cosine post-MMR), but the ingest
/// pipeline's [`EntitySpan::confidence`] is constrained to
/// `[0.0, 1.0]`; we preserve that by clamping each score into
/// `[KEYBERT_MIN_CONFIDENCE, 1.0]`.
pub const KEYBERT_MIN_CONFIDENCE: f32 = 0.0;

/// KeyBERT-backed [`Extractor`] adapter.
///
/// Construct with [`KeyBertAdapter::new`]; hand to
/// [`crate::pipeline::Ingester::with_extractor`] in place of the
/// default [`crate::extract::RuleExtractor`].
pub struct KeyBertAdapter {
    /// Embedder used both for the batched section pass in `prepare`
    /// and, inside [`KeyBertExtractor`], for candidate scoring.
    embedder: Arc<dyn Embedder>,
    /// KeyBERT `top_k`, forwarded to [`KeyBertExtractor`] on every
    /// `extract_entities` call.
    top_k: usize,
    /// KeyBERT candidate n-gram range `(min, max)`, forwarded to
    /// [`KeyBertExtractor`].
    ngram_range: (usize, usize),
    /// KeyBERT MMR diversity weight, forwarded to [`KeyBertExtractor`].
    mmr_diversity: f32,
    /// PMI threshold passed to `mine_relations` when mining
    /// co-occurrence edges in `extract_relations`.
    pmi_threshold: f32,
    /// ntype label stamped on every entity this adapter emits.
    /// Callers set this via [`KeyBertAdapter::with_label`]; there is no
    /// built-in default — the label vocabulary is entirely up to the caller.
    label: String,
    /// Section-text → embedding cache. Populated by [`Extractor::prepare`]
    /// in one batched `Embedder::embed_batch` call per file; queried
    /// by [`Extractor::extract_entities`] on every (section, chunk)
    /// pair the pipeline iterates over. Misses fall back to a single
    /// `Embedder::embed`, so callers who skip `prepare` still get
    /// correct behaviour. Keyed on the literal section text - same
    /// section content across files therefore reuses one entry,
    /// which is the dedup property `prepare` relies on internally.
    section_cache: Mutex<HashMap<String, Vec<f32>>>,
}

75impl std::fmt::Debug for KeyBertAdapter {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        let cached = self.section_cache.lock().map(|c| c.len()).unwrap_or(0);
78        f.debug_struct("KeyBertAdapter")
79            .field("embedder_model", &self.embedder.model())
80            .field("embedder_dim", &self.embedder.dim())
81            .field("top_k", &self.top_k)
82            .field("ngram_range", &self.ngram_range)
83            .field("mmr_diversity", &self.mmr_diversity)
84            .field("pmi_threshold", &self.pmi_threshold)
85            .field("label", &self.label)
86            .field("section_cache_len", &cached)
87            .finish()
88    }
89}
90
91impl KeyBertAdapter {
92    /// Build an adapter around the supplied embedder with KeyBERT defaults
93    /// (`top_k = 10`, `ngram_range = (1, 3)`, `mmr_diversity = 0.5`,
94    /// `pmi_threshold = 1.0`).
95    ///
96    /// `label` is the ntype string stamped on every entity this adapter emits.
97    /// The caller owns the vocabulary — pass whatever label fits your graph
98    /// (e.g. `"Keyword"`, `"Tag"`, `"Concept"`, or any domain-specific type).
99    #[must_use]
100    pub fn new(embedder: Arc<dyn Embedder>, label: impl Into<String>) -> Self {
101        Self {
102            embedder,
103            top_k: mnem_extract::keybert::DEFAULT_TOP_K,
104            ngram_range: mnem_extract::keybert::DEFAULT_NGRAM_RANGE,
105            mmr_diversity: mnem_extract::keybert::DEFAULT_MMR_DIVERSITY,
106            pmi_threshold: mnem_extract::cooccurrence::DEFAULT_PMI_THRESHOLD,
107            label: label.into(),
108            section_cache: Mutex::new(HashMap::new()),
109        }
110    }
111
112    /// Override the entity label. Returns `self` for chaining.
113    #[must_use]
114    pub fn with_label(mut self, label: impl Into<String>) -> Self {
115        self.label = label.into();
116        self
117    }
118
119    /// Override `top_k`. Returns `self` for chaining.
120    #[must_use]
121    pub const fn with_top_k(mut self, k: usize) -> Self {
122        self.top_k = k;
123        self
124    }
125
126    /// Override the PMI threshold used when mining co-occurrence
127    /// edges. Returns `self` for chaining.
128    #[must_use]
129    pub const fn with_pmi_threshold(mut self, t: f32) -> Self {
130        self.pmi_threshold = t;
131        self
132    }
133}
134
135impl Extractor for KeyBertAdapter {
136    fn prepare(&self, sections: &[Section]) -> Result<(), crate::error::Error> {
137        // Collect unique non-empty section texts. Pipelines that
138        // re-emit identical content across sections (e.g. boilerplate
139        // headers) pay only once.
140        let mut unique: Vec<&str> = Vec::with_capacity(sections.len());
141        let mut seen: std::collections::BTreeSet<&str> = std::collections::BTreeSet::new();
142        for s in sections {
143            if s.text.is_empty() {
144                continue;
145            }
146            if seen.insert(s.text.as_str()) {
147                unique.push(s.text.as_str());
148            }
149        }
150        if unique.is_empty() {
151            return Ok(());
152        }
153
154        // Best-effort batch embed. Failure here downgrades silently
155        // to the per-section lazy path in `extract_entities` (which
156        // has its own error swallow), so a transient embedder hiccup
157        // never aborts the whole file ingest. This matches the
158        // legacy "skip section on embed failure" behaviour the
159        // adapter shipped with before pre-batching landed.
160        let vecs = match self.embedder.embed_batch(&unique) {
161            Ok(v) => v,
162            Err(_e) => return Ok(()),
163        };
164
165        if let Ok(mut cache) = self.section_cache.lock() {
166            // Store result indexed by the same text key
167            // `extract_entities` will look up. `embed_batch`'s
168            // contract preserves order, so unique[i] aligns with
169            // vecs[i].
170            for (text, vec) in unique.into_iter().zip(vecs) {
171                cache.entry(text.to_string()).or_insert(vec);
172            }
173        }
174        Ok(())
175    }
176
177    fn extract_entities(&self, section: &Section) -> Vec<EntitySpan> {
178        let text = &section.text;
179        if text.is_empty() {
180            return Vec::new();
181        }
182
183        // Cache hit path: `prepare` populated this entry in one
184        // batched ORT call; we just clone the f32 vector. Cache miss
185        // path: caller skipped `prepare` (or the batch failed at
186        // prepare time); embed the section in a single call so the
187        // adapter still works end-to-end. Either path produces the
188        // same vector for the same text on the same embedder.
189        let cached = self
190            .section_cache
191            .lock()
192            .ok()
193            .and_then(|cache| cache.get(text).cloned());
194        let section_embed = match cached {
195            Some(v) => v,
196            None => match self.embedder.embed(text) {
197                Ok(v) => v,
198                Err(_) => return Vec::new(),
199            },
200        };
201
202        let kb = KeyBertExtractor {
203            embedder: self.embedder.as_ref(),
204            top_k: self.top_k,
205            ngram_range: self.ngram_range,
206            mmr_diversity: self.mmr_diversity,
207        };
208        let entities = kb.extract_entities(text, &section_embed);
209        entities
210            .into_iter()
211            .map(|e| EntitySpan {
212                kind: self.label.clone(),
213                text: e.mention,
214                byte_range: e.span.0..e.span.1,
215                confidence: e.score.clamp(KEYBERT_MIN_CONFIDENCE, 1.0),
216            })
217            .collect()
218    }
219
220    fn extract_relations(&self, entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan> {
221        if entities.len() < 2 {
222            return Vec::new();
223        }
224        // Map EntitySpan → mnem_extract::Entity keeping the original
225        // index so we can refer back to it.
226        let bridged: Vec<mnem_extract::Entity> = entities
227            .iter()
228            .map(|e| mnem_extract::Entity {
229                mention: e.text.clone(),
230                score: e.confidence,
231                span: (e.byte_range.start, e.byte_range.end),
232            })
233            .collect();
234        let rels = mnem_extract::mine_relations(
235            &section.text,
236            &bridged,
237            self.pmi_threshold,
238            mnem_extract::ExtractionSource::Statistical,
239        );
240
241        // Reverse-lookup each Relation.src / .dst back to the
242        // EntitySpan index the pipeline expects.
243        let index_of =
244            |mention: &str| -> Option<usize> { entities.iter().position(|e| e.text == mention) };
245        let mut out = Vec::with_capacity(rels.len());
246        for r in rels {
247            let (Some(si), Some(oi)) = (index_of(&r.src), index_of(&r.dst)) else {
248                continue;
249            };
250            out.push(RelationSpan {
251                kind: KEYBERT_RELATION_LABEL.to_string(),
252                subject_span: si,
253                object_span: oi,
254                confidence: r.weight.clamp(0.0, 1.0),
255            });
256        }
257        out
258    }
259}