1use std::ops::Range;
18use std::sync::Arc;
19
20use mnem_ner_providers::NerProvider;
21use regex::Regex;
22use serde::{Deserialize, Serialize};
23
24use crate::types::{ExtractorConfig, Section};
25
26#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
38pub struct EntitySpan {
39 pub kind: String,
41 pub text: String,
43 pub byte_range: Range<usize>,
45 pub confidence: f32,
47}
48
49#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
56pub struct RelationSpan {
57 pub kind: String,
59 pub subject_span: usize,
61 pub object_span: usize,
63 pub confidence: f32,
65}
66
67pub trait Extractor: Send + Sync {
75 fn extract_entities(&self, section: &Section) -> Vec<EntitySpan>;
77
78 fn extract_relations(&self, entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan>;
80
81 fn prepare(&self, _sections: &[Section]) -> Result<(), crate::error::Error> {
92 Ok(())
93 }
94}
95
96pub struct RuleExtractor {
104 cfg: ExtractorConfig,
105 verb_window: Regex,
106 ner: Arc<dyn NerProvider>,
107}
108
109impl std::fmt::Debug for RuleExtractor {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 f.debug_struct("RuleExtractor")
112 .field("cfg", &self.cfg)
113 .field("ner", &self.ner.provider_id())
114 .finish()
115 }
116}
117
118impl RuleExtractor {
119 #[allow(clippy::missing_panics_doc)]
121 #[must_use]
122 pub fn new(cfg: ExtractorConfig, ner: Arc<dyn NerProvider>) -> Self {
123 let verb_window = Regex::new(
124 r"(?i)\b(?:joined|founded|acquired|owns|hired|created|launched|bought|leads|runs)\b",
125 )
126 .expect("verb regex compiles");
127 Self {
128 cfg,
129 verb_window,
130 ner,
131 }
132 }
133
134 #[must_use]
136 pub fn with_default_ner(cfg: ExtractorConfig) -> Self {
137 Self::new(cfg, Arc::new(mnem_ner_providers::RuleNer))
138 }
139}
140
141impl Default for RuleExtractor {
142 fn default() -> Self {
143 Self::with_default_ner(ExtractorConfig::default())
144 }
145}
146
147impl Extractor for RuleExtractor {
148 fn extract_entities(&self, section: &Section) -> Vec<EntitySpan> {
149 if !self.cfg.extract_ner {
150 return Vec::new();
151 }
152 let text = section.text.as_str();
153 let mut out: Vec<EntitySpan> = self
154 .ner
155 .extract(text)
156 .into_iter()
157 .filter_map(|ne| {
158 if ne.label.trim().is_empty() {
159 return None;
160 }
161 let slice = text.get(ne.byte_start..ne.byte_end)?.to_string();
162 if slice.is_empty() {
163 return None;
164 }
165 Some(EntitySpan {
166 kind: ne.label,
167 text: slice,
168 byte_range: ne.byte_start..ne.byte_end,
169 confidence: ne.confidence,
170 })
171 })
172 .collect();
173
174 out.sort_by(|a, b| {
175 a.byte_range
176 .start
177 .cmp(&b.byte_range.start)
178 .then_with(|| a.kind.as_str().cmp(b.kind.as_str()))
179 });
180 out.dedup_by(|a, b| a.byte_range == b.byte_range && a.kind == b.kind);
181 out
182 }
183
184 fn extract_relations(&self, entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan> {
185 if entities.len() < 2 {
186 return Vec::new();
187 }
188 let text = section.text.as_str();
189 let window = self.cfg.relation_window_tokens;
190 let mut out = Vec::new();
191
192 for i in 0..entities.len() {
193 for j in (i + 1)..entities.len() {
194 let a = &entities[i];
195 let b = &entities[j];
196 if a.byte_range.end > b.byte_range.start {
197 continue;
198 }
199 let between = &text[a.byte_range.end..b.byte_range.start];
200 let tokens_between = between.split_whitespace().count();
201 if tokens_between > window {
202 continue;
203 }
204 let (kind, conf) = if self.verb_window.is_match(between) {
205 ("acts_on".to_string(), 0.50_f32)
206 } else {
207 ("co_occurs_with".to_string(), 0.40_f32)
208 };
209 out.push(RelationSpan {
210 kind,
211 subject_span: i,
212 object_span: j,
213 confidence: conf,
214 });
215 }
216 }
217 out
218 }
219}
220
221#[must_use]
225pub fn extract_entities(section: &Section) -> Vec<EntitySpan> {
226 RuleExtractor::default().extract_entities(section)
227}
228
229#[must_use]
231pub fn extract_relations(entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan> {
232 RuleExtractor::default().extract_relations(entities, section)
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `text` in a heading-less, depth-0 section spanning the whole
    /// string.
    fn section(text: &str) -> Section {
        let byte_range = 0..text.len();
        Section {
            heading: None,
            depth: 0,
            text: text.to_string(),
            byte_range,
        }
    }

    /// True when at least one entity's text equals `needle`.
    fn has_text(ents: &[EntitySpan], needle: &str) -> bool {
        ents.iter().any(|e| e.text == needle)
    }

    /// True when at least one relation carries the label `kind`.
    fn has_kind(rels: &[RelationSpan], kind: &str) -> bool {
        rels.iter().any(|r| r.kind == kind)
    }

    #[test]
    fn ner_detects_person() {
        let s = section("Alice Johnson met Bob Lee at the lobby.");
        let ents = extract_entities(&s);
        assert!(has_text(&ents, "Alice Johnson"), "got: {ents:?}");
        assert!(has_text(&ents, "Bob Lee"), "got: {ents:?}");
    }

    #[test]
    fn ner_detects_org() {
        let s = section("Acme Corp and Foo Inc signed the deal.");
        let ents = extract_entities(&s);
        assert!(has_text(&ents, "Acme Corp"), "got: {ents:?}");
    }

    #[test]
    fn ner_single_token_not_detected() {
        let s = section("Alice then left.");
        let ents = extract_entities(&s);
        assert!(ents.is_empty(), "single-token should not match: {ents:?}");
    }

    #[test]
    fn relations_proximity_co_occurs() {
        let s = section("Alice Johnson met Bob Lee today.");
        let ents = extract_entities(&s);
        let rels = extract_relations(&ents, &s);
        assert!(has_kind(&rels, "co_occurs_with"), "got rels: {rels:?}");
    }

    #[test]
    fn relations_verb_between_becomes_acts_on() {
        let s = section("Alice Johnson founded Acme Corp in 2022.");
        let ents = extract_entities(&s);
        let rels = extract_relations(&ents, &s);
        assert!(
            has_kind(&rels, "acts_on"),
            "got rels: {rels:?}, ents: {ents:?}"
        );
    }

    #[test]
    fn confidence_in_unit_range() {
        let s = section("Alice Johnson and Bob Lee work at Acme Corp.");
        let ents = extract_entities(&s);
        assert!(!ents.is_empty(), "expected at least one entity from NER");
        for e in &ents {
            let in_range = (0.0..=1.0).contains(&e.confidence);
            assert!(
                in_range,
                "confidence {} out of [0,1] for {:?}",
                e.confidence,
                e
            );
        }
    }

    #[test]
    fn null_ner_produces_no_entities() {
        use mnem_ner_providers::NullNer;
        let s = section("Alice Johnson founded Acme Corp.");
        let ext = RuleExtractor::new(ExtractorConfig::default(), Arc::new(NullNer));
        let ents = ext.extract_entities(&s);
        assert!(ents.is_empty(), "NullNer must produce nothing");
    }

    #[test]
    fn extract_ner_false_produces_no_entities() {
        let cfg = ExtractorConfig {
            extract_ner: false,
            ..ExtractorConfig::default()
        };
        let s = section("Alice Johnson founded Acme Corp.");
        let ext = RuleExtractor::with_default_ner(cfg);
        assert!(ext.extract_entities(&s).is_empty());
    }
}