use std::collections::HashMap;
use crate::error::Result;
/// A single labelled arc of a dependency parse: `head --relation--> dependent`.
///
/// All three parts are plain owned strings. `DependencyRelationExtractor::extract`
/// groups arcs by `head` and inspects `relation` to pick out subject and object
/// dependents.
///
/// Equality and hashing are derived field-by-field so arcs can be compared in
/// tests and deduplicated in sets or maps.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DependencyRelation {
    /// Governing word of the arc.
    pub head: String,
    /// Relation label of the arc (for example `"nsubj"` or `"obj"`).
    pub relation: String,
    /// Word attached to `head` through `relation`.
    pub dependent: String,
}
/// Describes which subject/object pairs receive a given relation label.
///
/// `Default` yields a catch-all pattern (no constraints, empty label), which
/// allows struct-update syntax such as
/// `RelationPattern { label: "SVO".into(), ..Default::default() }`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RelationPattern {
    /// Presumably a part-of-speech constraint on the subject.
    /// NOTE(review): `DependencyRelationExtractor::extract` in this file never
    /// reads this field, so it currently has no effect on extraction.
    pub subject_pos: Option<String>,
    /// Head words this pattern applies to, compared case-insensitively by
    /// `extract`. An empty list matches every head.
    pub predicate: Vec<String>,
    /// Presumably a part-of-speech constraint on the object.
    /// NOTE(review): like `subject_pos`, not consulted by `extract`.
    pub object_pos: Option<String>,
    /// Label placed in the middle slot of every triple produced by this pattern.
    pub label: String,
}
/// Turns a dependency parse into `(subject, label, object)` triples using a
/// list of [`RelationPattern`]s.
///
/// `Debug` and `Clone` are derived so a configured extractor can be logged and
/// duplicated; both only require the same traits on [`RelationPattern`].
#[derive(Debug, Clone)]
pub struct DependencyRelationExtractor {
    /// Registered patterns, kept in insertion order.
    patterns: Vec<RelationPattern>,
}
impl Default for DependencyRelationExtractor {
fn default() -> Self {
Self::new()
}
}
impl DependencyRelationExtractor {
    /// Creates an extractor with no patterns registered.
    ///
    /// Without at least one pattern, [`extract`](Self::extract) never emits a
    /// triple, because every triple takes its label from a pattern.
    pub fn new() -> Self {
        Self {
            patterns: Vec::new(),
        }
    }

    /// Appends `pattern` to the end of the pattern list.
    ///
    /// Patterns are applied in insertion order by [`extract`](Self::extract).
    pub fn add_pattern(&mut self, pattern: RelationPattern) {
        self.patterns.push(pattern);
    }

    /// Creates an extractor holding a single catch-all pattern labelled `"SVO"`.
    ///
    /// The pattern has an empty predicate list, so it matches every head word.
    pub fn with_svo_defaults() -> Self {
        let mut extractor = Self::new();
        extractor.add_pattern(RelationPattern {
            subject_pos: None,
            predicate: Vec::new(),
            object_pos: None,
            label: "SVO".to_string(),
        });
        extractor
    }

    /// Extracts `(subject, label, object)` triples from `dependency_tree`.
    ///
    /// Arcs are grouped by head word. For every head that has at least one
    /// subject dependent (`nsubj`, `nsubjpass`) and at least one object
    /// dependent (`obj`, `dobj`, `iobj`, `obl`, `xobj`), one triple is emitted
    /// per subject × object × matching pattern. A pattern matches when its
    /// `predicate` list is empty or contains the head word (case-insensitive).
    ///
    /// Triples are ordered by first appearance of the head in
    /// `dependency_tree`, then by subject, object and pattern order.
    ///
    /// `_text` is accepted for interface compatibility but is not read.
    /// The `subject_pos` / `object_pos` fields of the patterns are not consulted.
    ///
    /// # Errors
    ///
    /// This implementation has no failing path and always returns `Ok`.
    pub fn extract(
        &self,
        _text: &str,
        dependency_tree: &[DependencyRelation],
    ) -> Result<Vec<(String, String, String)>> {
        // Every triple needs a pattern label, so no patterns means no output.
        if self.patterns.is_empty() {
            return Ok(Vec::new());
        }

        // Group dependents by head and record heads in first-seen order in the
        // same pass; borrowing the words avoids cloning them.
        let mut head_order: Vec<&str> = Vec::new();
        let mut deps_by_head: HashMap<&str, Vec<(&str, &str)>> = HashMap::new();
        for arc in dependency_tree {
            let head = arc.head.as_str();
            deps_by_head
                .entry(head)
                .or_insert_with(|| {
                    head_order.push(head);
                    Vec::new()
                })
                .push((arc.relation.as_str(), arc.dependent.as_str()));
        }

        // Lowercase each pattern's predicates once per call rather than once
        // per (subject, object, pattern) combination.
        let lowered_predicates: Vec<Vec<String>> = self
            .patterns
            .iter()
            .map(|pattern| pattern.predicate.iter().map(|p| p.to_lowercase()).collect())
            .collect();

        let mut triples = Vec::new();
        for head in head_order {
            let Some(deps) = deps_by_head.get(head) else {
                continue;
            };

            let subjects: Vec<&str> = deps
                .iter()
                .filter(|(rel, _)| matches!(*rel, "nsubj" | "nsubjpass"))
                .map(|(_, dep)| *dep)
                .collect();
            let objects: Vec<&str> = deps
                .iter()
                .filter(|(rel, _)| matches!(*rel, "obj" | "dobj" | "iobj" | "obl" | "xobj"))
                .map(|(_, dep)| *dep)
                .collect();
            if subjects.is_empty() || objects.is_empty() {
                continue;
            }

            // Which patterns accept this head depends only on the head, so
            // resolve the matching labels once before pairing subjects/objects.
            let head_lower = head.to_lowercase();
            let labels: Vec<&str> = self
                .patterns
                .iter()
                .zip(&lowered_predicates)
                .filter(|(_, predicates)| {
                    predicates.is_empty() || predicates.iter().any(|p| *p == head_lower)
                })
                .map(|(pattern, _)| pattern.label.as_str())
                .collect();
            if labels.is_empty() {
                continue;
            }

            for &subj in &subjects {
                for &obj in &objects {
                    for &label in &labels {
                        triples.push((subj.to_owned(), label.to_owned(), obj.to_owned()));
                    }
                }
            }
        }
        Ok(triples)
    }
}
/// Recency-based pronoun resolver over a bounded window of noun phrases.
///
/// Noun phrases are registered in reading order; [`CorefResolver::resolve`]
/// returns the most recently registered phrase that is compatible with a
/// pronoun's number and gender.
#[derive(Debug, Clone)]
pub struct CorefResolver {
    /// Registered phrases as `(phrase, is_plural, gender)`, oldest at the
    /// front. A deque makes evicting the oldest entry O(1).
    history: std::collections::VecDeque<(String, bool, PronounGender)>,
    /// Maximum number of phrases kept in `history`.
    window: usize,
}

/// Gender/number category attached to noun phrases and pronouns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PronounGender {
    /// Category of "he", "him", "his", "himself".
    Masculine,
    /// Category of "she", "her", "hers", "herself".
    Feminine,
    /// Category of "it", "its", "itself".
    Neutral,
    /// Category of the plural pronouns ("they", "we", ...).
    Plural,
    /// No gender information; treated as compatible with any pronoun gender.
    Unknown,
}

impl CorefResolver {
    /// Creates a resolver that remembers at most `window` noun phrases.
    ///
    /// A `window` of `0` produces a resolver that never retains anything, so
    /// every call to [`resolve`](Self::resolve) returns `None`.
    pub fn new(window: usize) -> Self {
        Self {
            history: std::collections::VecDeque::new(),
            window,
        }
    }

    /// Records `noun_phrase` as the most recent antecedent candidate.
    ///
    /// When the window is already full, the oldest candidate is dropped first.
    /// With a zero-sized window the phrase is discarded.
    pub fn register(
        &mut self,
        noun_phrase: impl Into<String>,
        is_plural: bool,
        gender: PronounGender,
    ) {
        // A zero-capacity window holds nothing. Evicting from the empty
        // history here used to panic, so bail out before touching it.
        if self.window == 0 {
            return;
        }
        while self.history.len() >= self.window {
            self.history.pop_front();
        }
        self.history
            .push_back((noun_phrase.into(), is_plural, gender));
    }

    /// Returns the most recently registered noun phrase compatible with `pronoun`.
    ///
    /// A candidate is compatible when its plurality equals the pronoun's and
    /// its gender either equals the pronoun's or one of the two is
    /// [`PronounGender::Unknown`]. Plural pronouns carry
    /// [`PronounGender::Plural`], so a plural candidate registered with
    /// `Masculine`, `Feminine` or `Neutral` does not match them.
    ///
    /// Returns `None` when `pronoun` is not a recognised pronoun or when no
    /// candidate in the window is compatible.
    pub fn resolve(&self, pronoun: &str) -> Option<&str> {
        let (target_plural, target_gender) = pronoun_attributes(pronoun)?;
        self.history
            .iter()
            .rev() // newest first
            .find(|(_, is_plural, gender)| {
                *is_plural == target_plural
                    && (target_gender == PronounGender::Unknown
                        || *gender == PronounGender::Unknown
                        || *gender == target_gender)
            })
            .map(|(np, _, _)| np.as_str())
    }
}

/// Maps an English pronoun (any letter case) to `(is_plural, gender)`.
///
/// Returns `None` for any word that is not in the table below.
fn pronoun_attributes(pronoun: &str) -> Option<(bool, PronounGender)> {
    let attributes = match pronoun.to_lowercase().as_str() {
        "he" | "him" | "his" | "himself" => (false, PronounGender::Masculine),
        "she" | "her" | "hers" | "herself" => (false, PronounGender::Feminine),
        "it" | "its" | "itself" => (false, PronounGender::Neutral),
        "they" | "them" | "their" | "theirs" | "themselves" => (true, PronounGender::Plural),
        "we" | "us" | "our" | "ours" | "ourselves" => (true, PronounGender::Plural),
        _ => return None,
    };
    Some(attributes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one dependency arc from string slices.
    fn arc(head: &str, relation: &str, dependent: &str) -> DependencyRelation {
        DependencyRelation {
            head: head.to_string(),
            relation: relation.to_string(),
            dependent: dependent.to_string(),
        }
    }

    /// Parse of "John loves Mary": one subject arc and one object arc.
    fn make_svo_tree() -> Vec<DependencyRelation> {
        vec![arc("loves", "nsubj", "John"), arc("loves", "obj", "Mary")]
    }

    /// The default SVO pattern yields exactly one triple for a simple clause.
    #[test]
    fn test_svo_extraction() {
        let extractor = DependencyRelationExtractor::with_svo_defaults();
        let tree = make_svo_tree();
        let triples = extractor
            .extract("John loves Mary", &tree)
            .expect("extract failed");

        assert_eq!(triples.len(), 1);
        let (subject, _, object) = &triples[0];
        assert_eq!(subject.as_str(), "John");
        assert_eq!(object.as_str(), "Mary");
    }

    /// A head with an object but no subject produces nothing.
    #[test]
    fn test_no_triples_without_subject() {
        let extractor = DependencyRelationExtractor::with_svo_defaults();
        let tree = vec![arc("runs", "obj", "race")];
        let triples = extractor
            .extract("runs race", &tree)
            .expect("extract failed");

        assert!(triples.is_empty());
    }

    /// A pattern with a predicate list only fires for listed head words.
    #[test]
    fn test_predicate_filter() {
        let mut extractor = DependencyRelationExtractor::new();
        extractor.add_pattern(RelationPattern {
            subject_pos: None,
            predicate: vec!["loves".to_string()],
            object_pos: None,
            label: "LOVE".to_string(),
        });

        // Matching head: one triple carrying the pattern's label.
        let matching = extractor
            .extract("John loves Mary", &make_svo_tree())
            .expect("extract failed");
        assert_eq!(matching.len(), 1);
        assert_eq!(matching[0].1.as_str(), "LOVE");

        // Non-matching head: filtered out entirely.
        let other_tree = vec![arc("hates", "nsubj", "John"), arc("hates", "obj", "Mary")];
        let non_matching = extractor
            .extract("John hates Mary", &other_tree)
            .expect("extract failed");
        assert!(non_matching.is_empty());
    }

    /// A masculine pronoun resolves to the only masculine candidate.
    #[test]
    fn test_coref_resolver_basic() {
        let mut resolver = CorefResolver::new(5);
        resolver.register("John Smith", false, PronounGender::Masculine);

        assert_eq!(resolver.resolve("he"), Some("John Smith"));
    }

    /// A candidate of a different known gender is rejected.
    #[test]
    fn test_coref_resolver_gender_mismatch() {
        let mut resolver = CorefResolver::new(5);
        resolver.register("Alice", false, PronounGender::Feminine);

        assert!(resolver.resolve("he").is_none());
    }

    /// The newest candidate is skipped when its gender conflicts.
    #[test]
    fn test_coref_resolver_recency() {
        let mut resolver = CorefResolver::new(5);
        resolver.register("Bob", false, PronounGender::Masculine);
        resolver.register("Alice", false, PronounGender::Feminine);

        assert_eq!(resolver.resolve("he"), Some("Bob"));
    }

    /// Registering past the window size drops the oldest entry.
    #[test]
    fn test_coref_resolver_window_eviction() {
        let mut resolver = CorefResolver::new(2);
        resolver.register("Old Guy", false, PronounGender::Masculine);
        resolver.register("Middle Person", false, PronounGender::Unknown);
        resolver.register("New Person", false, PronounGender::Unknown);

        let old_guy_gone = resolver
            .history
            .iter()
            .all(|(name, _, _)| name.as_str() != "Old Guy");
        assert!(old_guy_gone);
    }
}