Skip to main content

mnem_ingest/
extract.rs

1//! Entity + relation extraction over parsed [`Section`]s.
2//!
3//! Entity extraction is delegated entirely to the configured
4//! [`mnem_ner_providers::NerProvider`]. The default is
5//! [`mnem_ner_providers::RuleNer`] (capitalized-phrase heuristic).
6//! Swap for [`mnem_ner_providers::NullNer`] or any future provider via
//! [`IngestConfig::ner`]. Provider labels pass through unchanged (only
//! blank labels are dropped) — there is no fixed vocabulary.
9//!
//! Relations are proximity-based: for two non-overlapping entity spans in
//! the same [`Section`] (the first ending at or before the second starts),
//! a candidate `"co_occurs_with"` edge (confidence `0.40`) is emitted when
//! at most `relation_window_tokens` whitespace-separated tokens lie between
//! them. A lightweight verb-between check promotes that to `"acts_on"`
//! (confidence `0.50`) when a verb such as `"joined"`, `"founded"`,
//! `"acquired"`, `"owns"`, or `"hired"` sits between the two spans.
16
17use std::ops::Range;
18use std::sync::Arc;
19
20use mnem_ner_providers::NerProvider;
21use regex::Regex;
22use serde::{Deserialize, Serialize};
23
24use crate::types::{ExtractorConfig, Section};
25
26// ---------------- Types ----------------
27
/// A single entity mention inside a [`Section`].
///
/// `byte_range` refers to offsets within the section's `text` field
/// (not the original source). Downstream commit code combines it with
/// `Section::byte_range` when provenance-accurate source offsets are
/// needed.
///
/// `kind` is the namespaced ntype string (e.g. `"Entity:Person"`).
/// Using a `String` keeps the type open for any NER provider label vocabulary.
///
/// Field order is deliberate: the derived `Serialize`, `Debug`, and
/// `PartialEq` impls all follow declaration order.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EntitySpan {
    /// Namespaced ntype label string (e.g. `"Entity:Person"`), taken from the
    /// NER provider. [`RuleExtractor`] never emits a blank label.
    pub kind: String,
    /// Verbatim surface string as it appears in the section text. For spans
    /// produced by [`RuleExtractor`], this equals `section.text[byte_range]`
    /// and is never empty.
    pub text: String,
    /// Half-open byte range within the section's `text`.
    pub byte_range: Range<usize>,
    /// Heuristic confidence, expected in `[0.0, 1.0]`. [`RuleExtractor`]
    /// copies the provider's value as-is (no clamping).
    pub confidence: f32,
}
48
/// A candidate relation between two entities in the same section.
///
/// `subject_span` and `object_span` are indices into the entity slice
/// that was passed to the `extract_relations` call that produced this
/// relation. Relation identifiers are plain strings to keep the shape
/// open; [`RuleExtractor`] emits `"co_occurs_with"` or `"acts_on"` today.
///
/// Field order is deliberate: the derived `Serialize`, `Debug`, and
/// `PartialEq` impls all follow declaration order.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RelationSpan {
    /// Predicate label (e.g. `"co_occurs_with"`, `"acts_on"`).
    pub kind: String,
    /// Index of the subject entity within the accompanying `Vec<EntitySpan>`.
    /// In [`RuleExtractor`] output this is always less than `object_span`.
    pub subject_span: usize,
    /// Index of the object entity within the accompanying `Vec<EntitySpan>`.
    pub object_span: usize,
    /// Heuristic confidence in `[0.0, 1.0]` (`0.40` or `0.50` from
    /// [`RuleExtractor`]).
    pub confidence: f32,
}
66
67// ---------------- Extractor trait ----------------
68
69/// Pluggable entity + relation extractor.
70///
71/// Implementations must be `Send + Sync` so the [`crate::Ingester`]
72/// façade can hand them across thread boundaries in batch ingest paths
73/// scheduled by CLI/HTTP wrappers in later waves.
74pub trait Extractor: Send + Sync {
75    /// Extract entity mentions from a single section.
76    fn extract_entities(&self, section: &Section) -> Vec<EntitySpan>;
77
78    /// Extract candidate relations between already-extracted entities.
79    fn extract_relations(&self, entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan>;
80
81    /// Optional pre-extraction hook. Called once per file by
82    /// [`crate::pipeline::Ingester::ingest`] BEFORE any
83    /// `extract_entities` / `extract_relations` call, with the full
84    /// list of sections the file produced. The default implementation
85    /// is a no-op, so existing extractors keep their behaviour.
86    ///
87    /// # Errors
88    ///
89    /// Returns whatever the implementation chooses; the pipeline
90    /// passes the error through.
91    fn prepare(&self, _sections: &[Section]) -> Result<(), crate::error::Error> {
92        Ok(())
93    }
94}
95
96// ---------------- Default rule extractor ----------------
97
98/// [`Extractor`] implementation that delegates entity detection to the
99/// configured [`NerProvider`] and proximity-based relation detection to an
100/// internal verb-window regex.
101///
102/// Construct via [`RuleExtractor::new`] or [`RuleExtractor::with_default_ner`].
103pub struct RuleExtractor {
104    cfg: ExtractorConfig,
105    verb_window: Regex,
106    ner: Arc<dyn NerProvider>,
107}
108
109impl std::fmt::Debug for RuleExtractor {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        f.debug_struct("RuleExtractor")
112            .field("cfg", &self.cfg)
113            .field("ner", &self.ner.provider_id())
114            .finish()
115    }
116}
117
118impl RuleExtractor {
119    /// Build a new extractor from configuration and a NER provider.
120    #[allow(clippy::missing_panics_doc)]
121    #[must_use]
122    pub fn new(cfg: ExtractorConfig, ner: Arc<dyn NerProvider>) -> Self {
123        let verb_window = Regex::new(
124            r"(?i)\b(?:joined|founded|acquired|owns|hired|created|launched|bought|leads|runs)\b",
125        )
126        .expect("verb regex compiles");
127        Self {
128            cfg,
129            verb_window,
130            ner,
131        }
132    }
133
134    /// Build with the default [`mnem_ner_providers::RuleNer`] provider.
135    #[must_use]
136    pub fn with_default_ner(cfg: ExtractorConfig) -> Self {
137        Self::new(cfg, Arc::new(mnem_ner_providers::RuleNer))
138    }
139}
140
141impl Default for RuleExtractor {
142    fn default() -> Self {
143        Self::with_default_ner(ExtractorConfig::default())
144    }
145}
146
147impl Extractor for RuleExtractor {
148    fn extract_entities(&self, section: &Section) -> Vec<EntitySpan> {
149        if !self.cfg.extract_ner {
150            return Vec::new();
151        }
152        let text = section.text.as_str();
153        let mut out: Vec<EntitySpan> = self
154            .ner
155            .extract(text)
156            .into_iter()
157            .filter_map(|ne| {
158                if ne.label.trim().is_empty() {
159                    return None;
160                }
161                let slice = text.get(ne.byte_start..ne.byte_end)?.to_string();
162                if slice.is_empty() {
163                    return None;
164                }
165                Some(EntitySpan {
166                    kind: ne.label,
167                    text: slice,
168                    byte_range: ne.byte_start..ne.byte_end,
169                    confidence: ne.confidence,
170                })
171            })
172            .collect();
173
174        out.sort_by(|a, b| {
175            a.byte_range
176                .start
177                .cmp(&b.byte_range.start)
178                .then_with(|| a.kind.as_str().cmp(b.kind.as_str()))
179        });
180        out.dedup_by(|a, b| a.byte_range == b.byte_range && a.kind == b.kind);
181        out
182    }
183
184    fn extract_relations(&self, entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan> {
185        if entities.len() < 2 {
186            return Vec::new();
187        }
188        let text = section.text.as_str();
189        let window = self.cfg.relation_window_tokens;
190        let mut out = Vec::new();
191
192        for i in 0..entities.len() {
193            for j in (i + 1)..entities.len() {
194                let a = &entities[i];
195                let b = &entities[j];
196                if a.byte_range.end > b.byte_range.start {
197                    continue;
198                }
199                let between = &text[a.byte_range.end..b.byte_range.start];
200                let tokens_between = between.split_whitespace().count();
201                if tokens_between > window {
202                    continue;
203                }
204                let (kind, conf) = if self.verb_window.is_match(between) {
205                    ("acts_on".to_string(), 0.50_f32)
206                } else {
207                    ("co_occurs_with".to_string(), 0.40_f32)
208                };
209                out.push(RelationSpan {
210                    kind,
211                    subject_span: i,
212                    object_span: j,
213                    confidence: conf,
214                });
215            }
216        }
217        out
218    }
219}
220
221// ---------------- Free helpers ----------------
222
223/// Run [`RuleExtractor::default`] once against a section.
224#[must_use]
225pub fn extract_entities(section: &Section) -> Vec<EntitySpan> {
226    RuleExtractor::default().extract_entities(section)
227}
228
229/// Run [`RuleExtractor::default`] once to derive relations.
230#[must_use]
231pub fn extract_relations(entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan> {
232    RuleExtractor::default().extract_relations(entities, section)
233}
234
235// ---------------- Tests ----------------
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a depth-0, heading-less section whose byte range covers `body`.
    fn section(body: &str) -> Section {
        Section {
            heading: None,
            depth: 0,
            byte_range: 0..body.len(),
            text: body.to_owned(),
        }
    }

    /// Surface strings of the given entities, in order.
    fn surface_texts(ents: &[EntitySpan]) -> Vec<&str> {
        ents.iter().map(|e| e.text.as_str()).collect()
    }

    #[test]
    fn ner_detects_person() {
        let s = section("Alice Johnson met Bob Lee at the lobby.");
        let ents = extract_entities(&s);
        let texts = surface_texts(&ents);
        for wanted in ["Alice Johnson", "Bob Lee"] {
            assert!(texts.contains(&wanted), "got: {ents:?}");
        }
    }

    #[test]
    fn ner_detects_org() {
        let ents = extract_entities(&section("Acme Corp and Foo Inc signed the deal."));
        assert!(surface_texts(&ents).contains(&"Acme Corp"), "got: {ents:?}");
    }

    #[test]
    fn ner_single_token_not_detected() {
        let ents = extract_entities(&section("Alice then left."));
        assert!(ents.is_empty(), "single-token should not match: {ents:?}");
    }

    #[test]
    fn relations_proximity_co_occurs() {
        let s = section("Alice Johnson met Bob Lee today.");
        let rels = extract_relations(&extract_entities(&s), &s);
        let co_occurs = rels.iter().any(|r| r.kind == "co_occurs_with");
        assert!(co_occurs, "got rels: {rels:?}");
    }

    #[test]
    fn relations_verb_between_becomes_acts_on() {
        let s = section("Alice Johnson founded Acme Corp in 2022.");
        let ents = extract_entities(&s);
        let rels = extract_relations(&ents, &s);
        let promoted = rels.iter().any(|r| r.kind == "acts_on");
        assert!(promoted, "got rels: {rels:?}, ents: {ents:?}");
    }

    #[test]
    fn confidence_in_unit_range() {
        let ents = extract_entities(&section("Alice Johnson and Bob Lee work at Acme Corp."));
        assert!(!ents.is_empty(), "expected at least one entity from NER");
        if let Some(bad) = ents.iter().find(|e| !(0.0..=1.0).contains(&e.confidence)) {
            panic!("confidence {} out of [0,1] for {:?}", bad.confidence, bad);
        }
    }

    #[test]
    fn null_ner_produces_no_entities() {
        use mnem_ner_providers::NullNer;
        let ext = RuleExtractor::new(ExtractorConfig::default(), Arc::new(NullNer));
        let found = ext.extract_entities(&section("Alice Johnson founded Acme Corp."));
        assert!(found.is_empty(), "NullNer must produce nothing");
    }

    #[test]
    fn extract_ner_false_produces_no_entities() {
        let disabled = ExtractorConfig {
            extract_ner: false,
            ..ExtractorConfig::default()
        };
        let ext = RuleExtractor::with_default_ner(disabled);
        let found = ext.extract_entities(&section("Alice Johnson founded Acme Corp."));
        assert!(found.is_empty());
    }
}