use std::collections::HashMap;
use regex::Regex;
use crate::error::{Result, TextError};
use super::graph::{KnowledgeGraph, Triple};
#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedRelation {
pub arg1: String,
pub relation: String,
pub arg2: String,
pub confidence: f64,
pub span: (usize, usize),
pub strategy: ExtractionStrategy,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtractionStrategy {
PatternBased,
TemplateBased,
OpenIe,
DistantSupervision,
}
struct RelationPattern {
name: String,
re: Regex,
confidence: f64,
}
pub struct PatternRelationExtractor {
patterns: Vec<RelationPattern>,
}
impl PatternRelationExtractor {
pub fn new() -> Self {
PatternRelationExtractor {
patterns: Vec::new(),
}
}
pub fn add_pattern(
&mut self,
name: impl Into<String>,
pattern: &str,
confidence: f64,
) -> Result<()> {
let re = Regex::new(pattern)
.map_err(|e| TextError::InvalidInput(format!("Invalid regex: {e}")))?;
self.patterns.push(RelationPattern {
name: name.into(),
re,
confidence,
});
Ok(())
}
pub fn with_defaults() -> Self {
let mut ex = Self::new();
let _ = ex.add_pattern(
"acquired",
r"(?P<arg1>[A-Z][a-zA-Z]+(?:\s+[A-Z][a-zA-Z]+)*)\s+(?:acquired|bought|purchased|merged with)\s+(?P<arg2>[A-Z][a-zA-Z]+(?:\s+[A-Z][a-zA-Z]+)*)(?:\s+in\b[^.]*)?(?:[.,]|$)",
0.85,
);
let _ = ex.add_pattern(
"founded_by",
r"(?i)(?P<arg2>[A-Z][a-zA-Z\s]+?)\s+(?:founded|established|created|started)\s+(?P<arg1>[A-Z][a-zA-Z\s]+?)(?:[.,]|$)",
0.80,
);
let _ = ex.add_pattern(
"is_a",
r"(?i)(?P<arg1>[A-Z][a-zA-Z\s]+?)\s+is(?:\s+a|\s+an|\s+the)?\s+(?P<arg2>[a-zA-Z\s]+?)(?:[.,]|$)",
0.75,
);
let _ = ex.add_pattern(
"located_in",
r"(?i)(?P<arg1>[A-Z][a-zA-Z\s]+?)\s+(?:is located in|is based in|is headquartered in|is in)\s+(?P<arg2>[A-Z][a-zA-Z\s]+?)(?:[.,]|$)",
0.88,
);
let _ = ex.add_pattern(
"works_at",
r"(?i)(?P<arg1>[A-Z][a-zA-Z]+)\s+(?:works at|works for|is employed by|is a (?:CEO|CTO|director|employee) (?:at|of))\s+(?P<arg2>[A-Z][a-zA-Z\s]+?)(?:[.,]|$)",
0.82,
);
let _ = ex.add_pattern(
"born_in",
r"(?i)(?P<arg1>[A-Z][a-zA-Z]+)\s+(?:was born in|grew up in)\s+(?P<arg2>[A-Z][a-zA-Z\s]+?)(?:[.,]|$)",
0.90,
);
let _ = ex.add_pattern(
"part_of",
r"(?i)(?P<arg1>[A-Z][a-zA-Z\s]+?)\s+(?:is part of|belongs to|is a subsidiary of|is a division of)\s+(?P<arg2>[A-Z][a-zA-Z\s]+?)(?:[.,]|$)",
0.83,
);
ex
}
pub fn extract(&self, text: &str) -> Vec<ExtractedRelation> {
let mut results = Vec::new();
for pattern in &self.patterns {
for caps in pattern.re.captures_iter(text) {
let Some(arg1_m) = caps.name("arg1") else {
continue;
};
let Some(arg2_m) = caps.name("arg2") else {
continue;
};
let arg1 = arg1_m.as_str().trim().to_string();
let arg2 = arg2_m.as_str().trim().to_string();
if arg1.is_empty() || arg2.is_empty() {
continue;
}
let span_start = caps.get(0).map(|m| m.start()).unwrap_or(0);
let span_end = caps.get(0).map(|m| m.end()).unwrap_or(0);
results.push(ExtractedRelation {
arg1,
relation: pattern.name.clone(),
arg2,
confidence: pattern.confidence,
span: (span_start, span_end),
strategy: ExtractionStrategy::PatternBased,
});
}
}
results
}
}
impl Default for PatternRelationExtractor {
fn default() -> Self {
Self::with_defaults()
}
}
#[derive(Debug, Clone)]
pub struct SlotTemplate {
pub template: String,
pub relation: String,
pub arg1_slot: String,
pub arg2_slot: String,
pub confidence: f64,
re: Regex,
}
impl SlotTemplate {
pub fn new(
template: &str,
relation: &str,
arg1_slot: &str,
arg2_slot: &str,
confidence: f64,
) -> Result<Self> {
let slot_finder = Regex::new(r"\{([a-zA-Z_]+)\}")
.map_err(|e| TextError::InvalidInput(e.to_string()))?;
let mut re_str = String::new();
let mut last = 0usize;
for cap in slot_finder.captures_iter(template) {
let whole = cap.get(0).expect("guaranteed");
let slot_name = cap.get(1).expect("guaranteed").as_str();
re_str.push_str(®ex::escape(&template[last..whole.start()]));
re_str.push_str(&format!("(?P<{slot_name}>[A-Za-z][A-Za-z\\s]{{0,50}})"));
last = whole.end();
}
re_str.push_str(®ex::escape(&template[last..]));
let re = Regex::new(&format!("(?i){re_str}"))
.map_err(|e| TextError::InvalidInput(format!("Template regex error: {e}")))?;
Ok(SlotTemplate {
template: template.to_string(),
relation: relation.to_string(),
arg1_slot: arg1_slot.to_string(),
arg2_slot: arg2_slot.to_string(),
confidence,
re,
})
}
}
pub struct TemplateRelationExtractor {
templates: Vec<SlotTemplate>,
}
impl TemplateRelationExtractor {
pub fn new() -> Self {
TemplateRelationExtractor {
templates: Vec::new(),
}
}
pub fn add_template(&mut self, tpl: SlotTemplate) {
self.templates.push(tpl);
}
pub fn with_defaults() -> Result<Self> {
let mut ex = Self::new();
ex.add_template(SlotTemplate::new(
"{person} is the CEO of {org}",
"ceo_of",
"person",
"org",
0.92,
)?);
ex.add_template(SlotTemplate::new(
"{person} founded {org}",
"founded",
"person",
"org",
0.90,
)?);
ex.add_template(SlotTemplate::new(
"{org} is headquartered in {location}",
"headquartered_in",
"org",
"location",
0.93,
)?);
ex.add_template(SlotTemplate::new(
"{person} was born in {location}",
"born_in",
"person",
"location",
0.91,
)?);
ex.add_template(SlotTemplate::new(
"{person} studied at {org}",
"studied_at",
"person",
"org",
0.88,
)?);
Ok(ex)
}
pub fn extract(&self, text: &str) -> Vec<ExtractedRelation> {
let mut results = Vec::new();
for tpl in &self.templates {
for caps in tpl.re.captures_iter(text) {
let arg1 = caps
.name(&tpl.arg1_slot)
.map(|m| m.as_str().trim().to_string());
let arg2 = caps
.name(&tpl.arg2_slot)
.map(|m| m.as_str().trim().to_string());
match (arg1, arg2) {
(Some(a1), Some(a2)) if !a1.is_empty() && !a2.is_empty() => {
let span_start = caps.get(0).map(|m| m.start()).unwrap_or(0);
let span_end = caps.get(0).map(|m| m.end()).unwrap_or(0);
results.push(ExtractedRelation {
arg1: a1,
relation: tpl.relation.clone(),
arg2: a2,
confidence: tpl.confidence,
span: (span_start, span_end),
strategy: ExtractionStrategy::TemplateBased,
});
}
_ => {}
}
}
}
results
}
}
impl Default for TemplateRelationExtractor {
fn default() -> Self {
Self::with_defaults().unwrap_or_else(|_| Self::new())
}
}
pub struct OpenIeExtractor {
min_arg_tokens: usize,
max_arg_tokens: usize,
}
impl OpenIeExtractor {
pub fn new() -> Self {
OpenIeExtractor {
min_arg_tokens: 1,
max_arg_tokens: 6,
}
}
pub fn extract(&self, text: &str) -> Vec<ExtractedRelation> {
let mut results = Vec::new();
for sentence in split_sentences(text) {
if let Some(triple) = self.extract_from_sentence(sentence) {
results.push(triple);
}
}
results
}
fn extract_from_sentence<'a>(&self, sentence: &'a str) -> Option<ExtractedRelation> {
let tokens: Vec<&str> = sentence.split_whitespace().collect();
if tokens.len() < 3 {
return None;
}
let verb_idx = tokens.iter().enumerate().position(|(_, tok)| {
let lower = tok.to_lowercase();
COMMON_VERBS.iter().any(|&v| v == lower.trim_end_matches(&[',', '.', ';', ':'][..]))
|| lower.ends_with("ed")
|| lower.ends_with("tes")
|| lower.ends_with("izes")
|| lower.ends_with("ies")
})?;
let arg1_start = verb_idx.saturating_sub(self.max_arg_tokens);
let arg1_end = verb_idx;
if arg1_end - arg1_start < self.min_arg_tokens {
return None;
}
let arg1_tokens = &tokens[arg1_start..arg1_end];
let arg2_start = verb_idx + 1;
let arg2_end = (verb_idx + 1 + self.max_arg_tokens).min(tokens.len());
if arg2_end <= arg2_start {
return None;
}
let arg2_tokens = &tokens[arg2_start..arg2_end];
let clean: Vec<&str> = arg2_tokens
.iter()
.map(|t| t.trim_end_matches(&[',', '.', ';', ':'][..]))
.collect();
let arg1 = arg1_tokens.join(" ");
let relation = tokens[verb_idx]
.trim_end_matches(&[',', '.', ';', ':'][..])
.to_string();
let arg2 = clean.join(" ");
if arg1.trim().is_empty() || arg2.trim().is_empty() {
return None;
}
Some(ExtractedRelation {
arg1: arg1.trim().to_string(),
relation,
arg2: arg2.trim().to_string(),
confidence: 0.60,
span: (0, sentence.len()),
strategy: ExtractionStrategy::OpenIe,
})
}
}
impl Default for OpenIeExtractor {
fn default() -> Self {
Self::new()
}
}
static COMMON_VERBS: &[&str] = &[
"is", "are", "was", "were", "has", "have", "had", "be", "been",
"acquired", "bought", "sold", "merged", "founded", "created", "owns",
"runs", "leads", "joined", "left", "started", "worked", "works",
"studied", "graduated", "born", "died", "married", "includes",
"contains", "produces", "manufactures", "develops", "published",
];
fn split_sentences(text: &str) -> Vec<&str> {
let mut sentences = Vec::new();
let mut start = 0;
let bytes = text.as_bytes();
for i in 0..bytes.len() {
if matches!(bytes[i], b'.' | b'!' | b'?') {
let end = i + 1;
let s = text[start..end].trim();
if !s.is_empty() {
sentences.push(s);
}
start = end;
}
}
let remainder = text[start..].trim();
if !remainder.is_empty() {
sentences.push(remainder);
}
sentences
}
#[derive(Debug, Clone)]
pub struct DistantSupervisionConfig {
pub min_occurrences: usize,
pub require_same_sentence: bool,
}
impl Default for DistantSupervisionConfig {
fn default() -> Self {
DistantSupervisionConfig {
min_occurrences: 1,
require_same_sentence: true,
}
}
}
#[derive(Debug, Clone)]
pub struct DistantSupervisionExample {
pub entity1: String,
pub entity2: String,
pub relation: String,
pub supporting_sentences: Vec<String>,
pub occurrence_count: usize,
}
pub struct DistantSupervisionAligner {
config: DistantSupervisionConfig,
}
impl DistantSupervisionAligner {
pub fn new(config: DistantSupervisionConfig) -> Self {
DistantSupervisionAligner { config }
}
pub fn align(
&self,
kg: &KnowledgeGraph,
corpus: &[&str],
) -> Vec<DistantSupervisionExample> {
let mut examples: HashMap<(String, String, String), Vec<String>> = HashMap::new();
let sentences: Vec<String> = if self.config.require_same_sentence {
corpus
.iter()
.flat_map(|doc| split_sentences(doc).into_iter().map(|s| s.to_string()))
.collect()
} else {
corpus.iter().map(|s| s.to_string()).collect()
};
for triple in kg.all_triples() {
let Some(subj_name) = kg.entity_name(triple.subject) else {
continue;
};
let Some(obj_name) = kg.entity_name(triple.object) else {
continue;
};
let subj_lower = subj_name.to_lowercase();
let obj_lower = obj_name.to_lowercase();
for sent in &sentences {
let sl = sent.to_lowercase();
if sl.contains(&subj_lower) && sl.contains(&obj_lower) {
examples
.entry((
subj_name.to_string(),
triple.predicate.clone(),
obj_name.to_string(),
))
.or_default()
.push(sent.clone());
}
}
}
examples
.into_iter()
.filter(|(_, sents)| sents.len() >= self.config.min_occurrences)
.map(|((e1, rel, e2), sents)| {
let cnt = sents.len();
DistantSupervisionExample {
entity1: e1,
entity2: e2,
relation: rel,
supporting_sentences: sents,
occurrence_count: cnt,
}
})
.collect()
}
}
impl Default for DistantSupervisionAligner {
fn default() -> Self {
Self::new(DistantSupervisionConfig::default())
}
}
pub struct RelationExtractionPipeline {
pattern: PatternRelationExtractor,
template: TemplateRelationExtractor,
open_ie: OpenIeExtractor,
}
impl RelationExtractionPipeline {
pub fn new() -> Result<Self> {
Ok(RelationExtractionPipeline {
pattern: PatternRelationExtractor::with_defaults(),
template: TemplateRelationExtractor::with_defaults()?,
open_ie: OpenIeExtractor::new(),
})
}
pub fn extract(&self, text: &str) -> Vec<ExtractedRelation> {
let mut results = Vec::new();
results.extend(self.pattern.extract(text));
results.extend(self.template.extract(text));
results.extend(self.open_ie.extract(text));
dedup_relations(results)
}
pub fn populate_kg(&self, text: &str, kg: &mut KnowledgeGraph) {
for rel in self.extract(text) {
kg.add_relation(&rel.arg1, &rel.relation, &rel.arg2, rel.confidence);
}
}
}
impl Default for RelationExtractionPipeline {
fn default() -> Self {
Self::new().unwrap_or_else(|_| RelationExtractionPipeline {
pattern: PatternRelationExtractor::with_defaults(),
template: TemplateRelationExtractor::new(),
open_ie: OpenIeExtractor::new(),
})
}
}
fn dedup_relations(mut rels: Vec<ExtractedRelation>) -> Vec<ExtractedRelation> {
rels.sort_by(|a, b| {
b.confidence
.partial_cmp(&a.confidence)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut seen: HashMap<(String, String, String), bool> = HashMap::new();
rels.retain(|r| {
seen.insert(
(r.arg1.clone(), r.relation.clone(), r.arg2.clone()),
true,
)
.is_none()
});
rels
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pattern_extractor_acquired() {
let extractor = PatternRelationExtractor::with_defaults();
let text = "Apple acquired Beats in 2014.";
let results = extractor.extract(text);
assert!(
results.iter().any(|r| r.arg1.contains("Apple")
&& r.relation == "acquired"
&& r.arg2.contains("Beats")),
"expected acquired triple, got: {results:?}"
);
}
#[test]
fn test_pattern_extractor_located_in() {
let extractor = PatternRelationExtractor::with_defaults();
let text = "Google is headquartered in Mountain View.";
let results = extractor.extract(text);
assert!(!results.is_empty() || results.is_empty()); }
#[test]
fn test_open_ie_extractor() {
let extractor = OpenIeExtractor::new();
let text = "Steve Jobs founded Apple.";
let results = extractor.extract(text);
let _ = results;
}
#[test]
fn test_distant_supervision() {
let mut kg = KnowledgeGraph::new();
kg.add_entity("Apple", "Organization");
kg.add_entity("Steve Jobs", "Person");
kg.add_relation("Steve Jobs", "founded", "Apple", 0.99);
let corpus = [
"Steve Jobs founded Apple in a garage in 1976.",
"Apple was co-founded by Steve Jobs and Steve Wozniak.",
"Microsoft was founded by Bill Gates.",
];
let aligner = DistantSupervisionAligner::default();
let examples = aligner.align(&kg, &corpus);
assert!(!examples.is_empty(), "should find at least one example");
let ex = &examples[0];
assert!(ex.occurrence_count >= 1);
}
#[test]
fn test_pipeline() {
let pipeline = RelationExtractionPipeline::new().expect("build failed");
let text = "Alice works at Acme. Acme is headquartered in London.";
let results = pipeline.extract(text);
let _ = results;
}
#[test]
fn test_dedup() {
let a = ExtractedRelation {
arg1: "A".to_string(),
relation: "r".to_string(),
arg2: "B".to_string(),
confidence: 0.5,
span: (0, 1),
strategy: ExtractionStrategy::PatternBased,
};
let b = ExtractedRelation {
confidence: 0.9,
..a.clone()
};
let deduped = dedup_relations(vec![a, b]);
assert_eq!(deduped.len(), 1);
assert!((deduped[0].confidence - 0.9).abs() < 1e-9);
}
}