use std::ops::Range;
use std::sync::Arc;
use mnem_ner_providers::NerProvider;
use regex::Regex;
use serde::{Deserialize, Serialize};
use crate::types::{ExtractorConfig, Section};
/// A single entity mention found in a [`Section`].
///
/// For spans produced by [`RuleExtractor`], every field comes straight from a
/// hit reported by its NER provider: `kind` is the provider's label, `text` is
/// the covered slice of the section's `text`, and `byte_range` indexes into
/// that same `text`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EntitySpan {
    /// Entity label. `RuleExtractor` copies the provider's label and drops
    /// hits whose label is blank after trimming.
    pub kind: String,
    /// The substring of the section text covered by `byte_range`; never empty
    /// for spans produced by `RuleExtractor`.
    pub text: String,
    /// Byte offsets into the owning section's `text` (as produced by
    /// `RuleExtractor`, which slices `section.text` with these offsets).
    pub byte_range: Range<usize>,
    /// Provider-reported confidence, copied through unchanged.
    // NOTE(review): presumably in [0, 1] (the tests check this for the default
    // provider), but nothing here clamps or validates it.
    pub confidence: f32,
}
/// A directed relation between two entities.
///
/// For relations produced by [`RuleExtractor`], `subject_span` and
/// `object_span` are indices into the `entities` slice that was passed to
/// [`Extractor::extract_relations`], not byte offsets.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RelationSpan {
    /// Relation label; `RuleExtractor` emits `"acts_on"` or `"co_occurs_with"`.
    pub kind: String,
    /// Index of the subject entity in the input entity slice.
    pub subject_span: usize,
    /// Index of the object entity in the input entity slice; for
    /// `RuleExtractor` output it is always greater than `subject_span`.
    pub object_span: usize,
    /// Heuristic confidence; `RuleExtractor` uses the fixed values 0.50
    /// (`acts_on`) and 0.40 (`co_occurs_with`).
    pub confidence: f32,
}
/// Pluggable entity and relation extraction over [`Section`]s.
///
/// Implementors must be `Send + Sync`, so one extractor can be shared across
/// threads.
pub trait Extractor: Send + Sync {
    /// Returns the entities found in `section`.
    fn extract_entities(&self, section: &Section) -> Vec<EntitySpan>;
    /// Returns relations among `entities` within `section`.
    // NOTE(review): `entities` is presumably the output of `extract_entities`
    // for the same `section`; `RuleExtractor` slices `section.text` with their
    // byte ranges — confirm the contract with callers.
    fn extract_relations(&self, entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan>;
    /// Optional setup hook over a batch of sections. The default
    /// implementation does nothing and returns `Ok(())`.
    // NOTE(review): presumably invoked once with all sections before
    // per-section extraction (e.g. to build corpus-level state) — confirm.
    fn prepare(&self, _sections: &[Section]) -> Result<(), crate::error::Error> {
        Ok(())
    }
}
/// Heuristic [`Extractor`]: entities come from a pluggable [`NerProvider`];
/// relations come from token proximity plus a verb regex.
pub struct RuleExtractor {
    /// Extraction settings (reads `extract_ner` and `relation_window_tokens`).
    cfg: ExtractorConfig,
    /// Case-insensitive verb matcher run on the text between two entities; a
    /// match turns a `co_occurs_with` relation into `acts_on`.
    verb_window: Regex,
    /// NER backend that supplies candidate entity spans.
    ner: Arc<dyn NerProvider>,
}
// Hand-written because the compiled regex and the trait-object provider are not
// useful to print: the provider is identified by its id, and the regex is omitted.
impl std::fmt::Debug for RuleExtractor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ner_id = self.ner.provider_id();
        let mut builder = f.debug_struct("RuleExtractor");
        builder.field("cfg", &self.cfg);
        builder.field("ner", &ner_id);
        builder.finish()
    }
}
impl RuleExtractor {
    /// Creates an extractor from `cfg` that delegates entity recognition to `ner`.
    ///
    /// Relation typing uses a case-insensitive, word-bounded regex over a fixed
    /// verb list (`joined`, `founded`, `acquired`, `owns`, `hired`, `created`,
    /// `launched`, `bought`, `leads`, `runs`).
    // The only panic site is `expect` on a constant regex literal: if it is
    // valid once it is valid always, and the unit tests construct an extractor,
    // so no `# Panics` section is documented.
    #[allow(clippy::missing_panics_doc)]
    #[must_use]
    pub fn new(cfg: ExtractorConfig, ner: Arc<dyn NerProvider>) -> Self {
        const VERB_PATTERN: &str =
            r"(?i)\b(?:joined|founded|acquired|owns|hired|created|launched|bought|leads|runs)\b";
        Self {
            verb_window: Regex::new(VERB_PATTERN).expect("verb regex compiles"),
            cfg,
            ner,
        }
    }

    /// Creates an extractor from `cfg` backed by the built-in
    /// `mnem_ner_providers::RuleNer` provider.
    #[must_use]
    pub fn with_default_ner(cfg: ExtractorConfig) -> Self {
        let provider: Arc<dyn NerProvider> = Arc::new(mnem_ner_providers::RuleNer);
        Self::new(cfg, provider)
    }
}
/// Equivalent to `RuleExtractor::with_default_ner(ExtractorConfig::default())`.
impl Default for RuleExtractor {
    fn default() -> Self {
        let cfg = ExtractorConfig::default();
        Self::with_default_ner(cfg)
    }
}
impl Extractor for RuleExtractor {
    /// Runs the NER provider over `section.text` and returns its hits as
    /// [`EntitySpan`]s.
    ///
    /// Returns an empty vector when `cfg.extract_ner` is `false`. A hit is
    /// dropped when its label is blank, or when its byte range is empty or does
    /// not slice `section.text` (out of bounds, reversed, or not on a UTF-8
    /// character boundary).
    ///
    /// The output is sorted by `(start, kind, end)`. Entries with identical
    /// `byte_range` and `kind` are collapsed to the first copy the provider
    /// emitted. A deterministic order matters because [`RelationSpan`] refers
    /// to entities by index.
    fn extract_entities(&self, section: &Section) -> Vec<EntitySpan> {
        if !self.cfg.extract_ner {
            return Vec::new();
        }
        let text = section.text.as_str();
        let mut out: Vec<EntitySpan> = self
            .ner
            .extract(text)
            .into_iter()
            .filter_map(|ne| {
                if ne.label.trim().is_empty() {
                    return None;
                }
                // `str::get` yields `None` (instead of panicking) for ranges that
                // are out of bounds, reversed, or split a UTF-8 character.
                let slice = text.get(ne.byte_start..ne.byte_end)?;
                if slice.is_empty() {
                    return None;
                }
                Some(EntitySpan {
                    kind: ne.label,
                    text: slice.to_string(),
                    byte_range: ne.byte_start..ne.byte_end,
                    confidence: ne.confidence,
                })
            })
            .collect();
        // `dedup_by` removes only *consecutive* duplicates. The end offset is
        // part of the key so that all copies of a (byte_range, kind) pair are
        // adjacent. Sorting on (start, kind) alone could leave a same-kind span
        // of a different length between two copies. `sort_by` is stable, so
        // the provider's first copy (and its confidence) is the one kept.
        out.sort_by(|a, b| {
            a.byte_range
                .start
                .cmp(&b.byte_range.start)
                .then_with(|| a.kind.cmp(&b.kind))
                .then_with(|| a.byte_range.end.cmp(&b.byte_range.end))
        });
        out.dedup_by(|a, b| a.byte_range == b.byte_range && a.kind == b.kind);
        out
    }

    /// Proposes relations between entities that appear close together in
    /// `section.text`.
    ///
    /// Consider every pair `(i, j)` with `i < j` where entity `i` ends at or
    /// before entity `j` starts. If the text between them holds at most
    /// `cfg.relation_window_tokens` whitespace-separated tokens, the pair emits
    /// one [`RelationSpan`] with `subject_span = i` and `object_span = j`. The
    /// kind is `"acts_on"` (confidence 0.50) when that gap matches the verb
    /// regex, and `"co_occurs_with"` (confidence 0.40) otherwise.
    ///
    /// A pair is skipped, rather than causing a panic, when its gap cannot be
    /// sliced from `section.text`. Examples are spans from a different section,
    /// or offsets off a UTF-8 character boundary. Returns an empty vector when
    /// fewer than two entities are given.
    fn extract_relations(&self, entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan> {
        const ACTS_ON_CONFIDENCE: f32 = 0.50;
        const CO_OCCURS_CONFIDENCE: f32 = 0.40;

        if entities.len() < 2 {
            return Vec::new();
        }
        let text = section.text.as_str();
        let window = self.cfg.relation_window_tokens;
        let mut out = Vec::new();
        for (i, a) in entities.iter().enumerate() {
            for (j, b) in entities.iter().enumerate().skip(i + 1) {
                // Overlapping or out-of-order pairs have no gap to inspect.
                if a.byte_range.end > b.byte_range.start {
                    continue;
                }
                // `EntitySpan`s are caller-supplied (and deserializable), so their
                // offsets are not guaranteed to be valid for this section's text.
                let between = match text.get(a.byte_range.end..b.byte_range.start) {
                    Some(gap) => gap,
                    None => continue,
                };
                if between.split_whitespace().count() > window {
                    continue;
                }
                let (kind, confidence) = if self.verb_window.is_match(between) {
                    ("acts_on", ACTS_ON_CONFIDENCE)
                } else {
                    ("co_occurs_with", CO_OCCURS_CONFIDENCE)
                };
                out.push(RelationSpan {
                    kind: kind.to_string(),
                    subject_span: i,
                    object_span: j,
                    confidence,
                });
            }
        }
        out
    }
}
#[must_use]
pub fn extract_entities(section: &Section) -> Vec<EntitySpan> {
RuleExtractor::default().extract_entities(section)
}
#[must_use]
pub fn extract_relations(entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan> {
RuleExtractor::default().extract_relations(entities, section)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Wraps `text` in a heading-less, depth-0 section whose range covers it all.
    fn make_section(text: &str) -> Section {
        let byte_range = 0..text.len();
        Section {
            heading: None,
            depth: 0,
            text: text.to_owned(),
            byte_range,
        }
    }

    #[test]
    fn ner_detects_person() {
        let ents = extract_entities(&make_section("Alice Johnson met Bob Lee at the lobby."));
        let has = |needle: &str| ents.iter().any(|e| e.text == needle);
        assert!(has("Alice Johnson"), "got: {ents:?}");
        assert!(has("Bob Lee"), "got: {ents:?}");
    }

    #[test]
    fn ner_detects_org() {
        let ents = extract_entities(&make_section("Acme Corp and Foo Inc signed the deal."));
        let found = ents.iter().any(|e| e.text == "Acme Corp");
        assert!(found, "got: {ents:?}");
    }

    #[test]
    fn ner_single_token_not_detected() {
        let ents = extract_entities(&make_section("Alice then left."));
        assert!(ents.is_empty(), "single-token should not match: {ents:?}");
    }

    #[test]
    fn relations_proximity_co_occurs() {
        let sec = make_section("Alice Johnson met Bob Lee today.");
        let rels = extract_relations(&extract_entities(&sec), &sec);
        let found = rels.iter().any(|r| r.kind == "co_occurs_with");
        assert!(found, "got rels: {rels:?}");
    }

    #[test]
    fn relations_verb_between_becomes_acts_on() {
        let sec = make_section("Alice Johnson founded Acme Corp in 2022.");
        let ents = extract_entities(&sec);
        let rels = extract_relations(&ents, &sec);
        let found = rels.iter().any(|r| r.kind == "acts_on");
        assert!(found, "got rels: {rels:?}, ents: {ents:?}");
    }

    #[test]
    fn confidence_in_unit_range() {
        let ents = extract_entities(&make_section("Alice Johnson and Bob Lee work at Acme Corp."));
        assert!(!ents.is_empty(), "expected at least one entity from NER");
        let unit = 0.0_f32..=1.0_f32;
        for span in ents.iter() {
            assert!(
                unit.contains(&span.confidence),
                "confidence {} out of [0,1] for {:?}",
                span.confidence,
                span
            );
        }
    }

    #[test]
    fn null_ner_produces_no_entities() {
        use mnem_ner_providers::NullNer;
        let ext = RuleExtractor::new(ExtractorConfig::default(), Arc::new(NullNer));
        let ents = ext.extract_entities(&make_section("Alice Johnson founded Acme Corp."));
        assert!(ents.is_empty(), "NullNer must produce nothing");
    }

    #[test]
    fn extract_ner_false_produces_no_entities() {
        let ext = RuleExtractor::with_default_ner(ExtractorConfig {
            extract_ner: false,
            ..ExtractorConfig::default()
        });
        let sec = make_section("Alice Johnson founded Acme Corp.");
        assert!(ext.extract_entities(&sec).is_empty());
    }
}