mnem_ingest/extract_keybert.rs
1//! Adapter that lets a [`mnem_extract::KeyBertExtractor`] drop into the
2//! [`crate::extract::Extractor`] slot on [`crate::pipeline::Ingester`].
3//!
4//! Two-phase contract:
5//!
6//! 1. [`Extractor::prepare`] is called once per file with every
7//! section the parser produced. The adapter collects unique
8//! section texts and runs them through `Embedder::embed_batch` in
9//! a single ORT session.run, caching the resulting vectors.
10//! 2. [`Extractor::extract_entities`] is then called per (section,
11//! chunk) pair by the pipeline; the adapter looks up the cached
12//! section embedding and runs KeyBERT candidate ranking + MMR
13//! against it. On a cache miss (e.g. a caller that did not invoke
14//! `prepare`) the adapter falls back to a single-section
15//! `Embedder::embed`, preserving the original drop-in contract.
16//!
//! Perf: pre-batching the section pass turns the Bible-scale
//! wall-time bottleneck (~1 sequential ORT call per section,
//! dominated by long chapters) into a single ORT batch per file.
//! The same vectors land in `Node.embed`; only the wall-time
//! changes.
22//!
23//! Relations returned by the statistical miner are mapped to the
24//! existing [`RelationSpan`] shape with predicate `"co_occurs_with"`.
25//!
26//! Gated behind the `keybert` cargo feature so callers who ship only
27//! the rule-based baseline pay zero compile / binary cost.
28
29use std::collections::HashMap;
30use std::sync::{Arc, Mutex};
31
32use mnem_embed_providers::Embedder;
33use mnem_extract::{Extractor as StatisticalExtractor, KeyBertExtractor};
34
35use crate::extract::{EntitySpan, Extractor, RelationSpan};
36use crate::types::Section;
37
/// Predicate emitted for co-occurrence edges by the KeyBERT adapter.
/// Mirrors the string the rule-based extractor uses so downstream
/// graph consumers don't need to learn a new vocabulary.
pub const KEYBERT_RELATION_LABEL: &str = "co_occurs_with";

/// Lower clamp bound for the confidence stamped onto every
/// [`EntitySpan`] the adapter emits. Statistical extraction has a
/// genuine score per candidate (cosine post-MMR), but the ingest
/// pipeline's [`EntitySpan::confidence`] is constrained to
/// `[0.0, 1.0]`, so scores are clamped into that range.
pub const KEYBERT_MIN_CONFIDENCE: f32 = 0.0;
48
/// KeyBERT-backed [`Extractor`] adapter.
///
/// Construct with [`KeyBertAdapter::new`]; hand to
/// [`crate::pipeline::Ingester::with_extractor`] in place of the
/// default [`crate::extract::RuleExtractor`].
pub struct KeyBertAdapter {
    /// Embedding backend shared by the batched section pass and the
    /// per-section KeyBERT candidate scoring.
    embedder: Arc<dyn Embedder>,
    /// Number of keyword candidates kept per section.
    top_k: usize,
    /// Candidate n-gram length range `(min, max)`.
    ngram_range: (usize, usize),
    /// MMR diversity weight used during candidate ranking.
    mmr_diversity: f32,
    /// PMI cutoff handed to the co-occurrence relation miner.
    pmi_threshold: f32,
    /// ntype label stamped on every entity this adapter emits.
    /// Set in [`KeyBertAdapter::new`] and overridable via
    /// [`KeyBertAdapter::with_label`]; there is no built-in default —
    /// the label vocabulary is entirely up to the caller.
    label: String,
    /// Section-text → embedding cache. Populated by [`Extractor::prepare`]
    /// in one batched `Embedder::embed_batch` call per file; queried
    /// by [`Extractor::extract_entities`] on every (section, chunk)
    /// pair the pipeline iterates over. Misses fall back to a single
    /// `Embedder::embed`, so callers who skip `prepare` still get
    /// correct behaviour. Keyed on the literal section text - same
    /// section content across files therefore reuses one entry,
    /// which is the dedup property `prepare` relies on internally.
    section_cache: Mutex<HashMap<String, Vec<f32>>>,
}
74
75impl std::fmt::Debug for KeyBertAdapter {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 let cached = self.section_cache.lock().map(|c| c.len()).unwrap_or(0);
78 f.debug_struct("KeyBertAdapter")
79 .field("embedder_model", &self.embedder.model())
80 .field("embedder_dim", &self.embedder.dim())
81 .field("top_k", &self.top_k)
82 .field("ngram_range", &self.ngram_range)
83 .field("mmr_diversity", &self.mmr_diversity)
84 .field("pmi_threshold", &self.pmi_threshold)
85 .field("label", &self.label)
86 .field("section_cache_len", &cached)
87 .finish()
88 }
89}
90
91impl KeyBertAdapter {
92 /// Build an adapter around the supplied embedder with KeyBERT defaults
93 /// (`top_k = 10`, `ngram_range = (1, 3)`, `mmr_diversity = 0.5`,
94 /// `pmi_threshold = 1.0`).
95 ///
96 /// `label` is the ntype string stamped on every entity this adapter emits.
97 /// The caller owns the vocabulary — pass whatever label fits your graph
98 /// (e.g. `"Keyword"`, `"Tag"`, `"Concept"`, or any domain-specific type).
99 #[must_use]
100 pub fn new(embedder: Arc<dyn Embedder>, label: impl Into<String>) -> Self {
101 Self {
102 embedder,
103 top_k: mnem_extract::keybert::DEFAULT_TOP_K,
104 ngram_range: mnem_extract::keybert::DEFAULT_NGRAM_RANGE,
105 mmr_diversity: mnem_extract::keybert::DEFAULT_MMR_DIVERSITY,
106 pmi_threshold: mnem_extract::cooccurrence::DEFAULT_PMI_THRESHOLD,
107 label: label.into(),
108 section_cache: Mutex::new(HashMap::new()),
109 }
110 }
111
112 /// Override the entity label. Returns `self` for chaining.
113 #[must_use]
114 pub fn with_label(mut self, label: impl Into<String>) -> Self {
115 self.label = label.into();
116 self
117 }
118
119 /// Override `top_k`. Returns `self` for chaining.
120 #[must_use]
121 pub const fn with_top_k(mut self, k: usize) -> Self {
122 self.top_k = k;
123 self
124 }
125
126 /// Override the PMI threshold used when mining co-occurrence
127 /// edges. Returns `self` for chaining.
128 #[must_use]
129 pub const fn with_pmi_threshold(mut self, t: f32) -> Self {
130 self.pmi_threshold = t;
131 self
132 }
133}
134
135impl Extractor for KeyBertAdapter {
136 fn prepare(&self, sections: &[Section]) -> Result<(), crate::error::Error> {
137 // Collect unique non-empty section texts. Pipelines that
138 // re-emit identical content across sections (e.g. boilerplate
139 // headers) pay only once.
140 let mut unique: Vec<&str> = Vec::with_capacity(sections.len());
141 let mut seen: std::collections::BTreeSet<&str> = std::collections::BTreeSet::new();
142 for s in sections {
143 if s.text.is_empty() {
144 continue;
145 }
146 if seen.insert(s.text.as_str()) {
147 unique.push(s.text.as_str());
148 }
149 }
150 if unique.is_empty() {
151 return Ok(());
152 }
153
154 // Best-effort batch embed. Failure here downgrades silently
155 // to the per-section lazy path in `extract_entities` (which
156 // has its own error swallow), so a transient embedder hiccup
157 // never aborts the whole file ingest. This matches the
158 // legacy "skip section on embed failure" behaviour the
159 // adapter shipped with before pre-batching landed.
160 let vecs = match self.embedder.embed_batch(&unique) {
161 Ok(v) => v,
162 Err(_e) => return Ok(()),
163 };
164
165 if let Ok(mut cache) = self.section_cache.lock() {
166 // Store result indexed by the same text key
167 // `extract_entities` will look up. `embed_batch`'s
168 // contract preserves order, so unique[i] aligns with
169 // vecs[i].
170 for (text, vec) in unique.into_iter().zip(vecs) {
171 cache.entry(text.to_string()).or_insert(vec);
172 }
173 }
174 Ok(())
175 }
176
177 fn extract_entities(&self, section: &Section) -> Vec<EntitySpan> {
178 let text = §ion.text;
179 if text.is_empty() {
180 return Vec::new();
181 }
182
183 // Cache hit path: `prepare` populated this entry in one
184 // batched ORT call; we just clone the f32 vector. Cache miss
185 // path: caller skipped `prepare` (or the batch failed at
186 // prepare time); embed the section in a single call so the
187 // adapter still works end-to-end. Either path produces the
188 // same vector for the same text on the same embedder.
189 let cached = self
190 .section_cache
191 .lock()
192 .ok()
193 .and_then(|cache| cache.get(text).cloned());
194 let section_embed = match cached {
195 Some(v) => v,
196 None => match self.embedder.embed(text) {
197 Ok(v) => v,
198 Err(_) => return Vec::new(),
199 },
200 };
201
202 let kb = KeyBertExtractor {
203 embedder: self.embedder.as_ref(),
204 top_k: self.top_k,
205 ngram_range: self.ngram_range,
206 mmr_diversity: self.mmr_diversity,
207 };
208 let entities = kb.extract_entities(text, §ion_embed);
209 entities
210 .into_iter()
211 .map(|e| EntitySpan {
212 kind: self.label.clone(),
213 text: e.mention,
214 byte_range: e.span.0..e.span.1,
215 confidence: e.score.clamp(KEYBERT_MIN_CONFIDENCE, 1.0),
216 })
217 .collect()
218 }
219
220 fn extract_relations(&self, entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan> {
221 if entities.len() < 2 {
222 return Vec::new();
223 }
224 // Map EntitySpan → mnem_extract::Entity keeping the original
225 // index so we can refer back to it.
226 let bridged: Vec<mnem_extract::Entity> = entities
227 .iter()
228 .map(|e| mnem_extract::Entity {
229 mention: e.text.clone(),
230 score: e.confidence,
231 span: (e.byte_range.start, e.byte_range.end),
232 })
233 .collect();
234 let rels = mnem_extract::mine_relations(
235 §ion.text,
236 &bridged,
237 self.pmi_threshold,
238 mnem_extract::ExtractionSource::Statistical,
239 );
240
241 // Reverse-lookup each Relation.src / .dst back to the
242 // EntitySpan index the pipeline expects.
243 let index_of =
244 |mention: &str| -> Option<usize> { entities.iter().position(|e| e.text == mention) };
245 let mut out = Vec::with_capacity(rels.len());
246 for r in rels {
247 let (Some(si), Some(oi)) = (index_of(&r.src), index_of(&r.dst)) else {
248 continue;
249 };
250 out.push(RelationSpan {
251 kind: KEYBERT_RELATION_LABEL.to_string(),
252 subject_span: si,
253 object_span: oi,
254 confidence: r.weight.clamp(0.0, 1.0),
255 });
256 }
257 out
258 }
259}