use crate::error::{Result, TextError};
use std::collections::HashMap;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EventType {
Move,
Attack,
Meet,
Arrest,
Die,
Transfer,
Create,
Destroy,
Custom(String),
}
impl EventType {
pub fn label(&self) -> &str {
match self {
EventType::Move => "Move",
EventType::Attack => "Attack",
EventType::Meet => "Meet",
EventType::Arrest => "Arrest",
EventType::Die => "Die",
EventType::Transfer => "Transfer",
EventType::Create => "Create",
EventType::Destroy => "Destroy",
EventType::Custom(s) => s.as_str(),
}
}
}
impl std::fmt::Display for EventType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.label())
}
}
pub struct TriggerLexicon {
pub triggers: HashMap<String, EventType>,
}
impl Default for TriggerLexicon {
fn default() -> Self {
Self::default_english()
}
}
impl TriggerLexicon {
pub fn new() -> Self {
Self {
triggers: HashMap::new(),
}
}
pub fn insert(&mut self, word: impl Into<String>, event_type: EventType) {
self.triggers.insert(word.into().to_lowercase(), event_type);
}
pub fn lookup(&self, word: &str) -> Option<&EventType> {
self.triggers.get(&word.to_lowercase())
}
pub fn default_english() -> Self {
let mut lex = Self::new();
for w in &[
"moved", "moving", "move", "traveled", "travel", "travelled",
"fled", "flee", "departed", "depart", "arrived", "arrive",
"entered", "enter", "left", "evacuated", "evacuate", "migrated",
"migrate", "relocated", "relocate", "walked", "ran", "run",
] {
lex.insert(*w, EventType::Move);
}
for w in &[
"attacked", "attack", "assaulted", "assault", "bombed", "bomb",
"shot", "shoot", "fired", "fire", "struck", "strike", "hit",
"targeted", "target", "raided", "raid", "invaded", "invade",
"detonated", "detonate", "launched", "launch", "stabbed", "stab",
] {
lex.insert(*w, EventType::Attack);
}
for w in &[
"met", "meet", "meeting", "gathered", "gather", "assembled",
"assemble", "convened", "convene", "discussed", "discuss",
"negotiated", "negotiate", "talked", "talk", "conferenced",
"conferred", "confer", "visited", "visit",
] {
lex.insert(*w, EventType::Meet);
}
for w in &[
"arrested", "arrest", "detained", "detain", "apprehended",
"apprehend", "captured", "capture", "jailed", "jail",
"imprisoned", "imprison", "charged", "charge", "indicted",
"indict", "booked", "book", "handcuffed", "handcuff",
] {
lex.insert(*w, EventType::Arrest);
}
for w in &[
"died", "die", "killed", "kill", "murdered", "murder", "executed",
"execute", "slain", "slayed", "slay", "perished", "perish",
"deceased", "assassinated", "assassinate", "fatally",
] {
lex.insert(*w, EventType::Die);
}
for w in &[
"transferred", "transfer", "sold", "sell", "purchased", "purchase",
"bought", "buy", "donated", "donate", "paid", "pay", "sent",
"send", "received", "receive", "wired", "wire", "awarded",
"award", "granted", "grant",
] {
lex.insert(*w, EventType::Transfer);
}
for w in &[
"created", "create", "built", "build", "developed", "develop",
"founded", "found", "established", "establish", "launched",
"produced", "produce", "manufactured", "manufacture", "invented",
"invent", "designed", "design", "wrote", "write", "authored",
"author", "published", "publish", "formed", "form",
] {
lex.insert(*w, EventType::Create);
}
for w in &[
"destroyed", "destroy", "demolished", "demolish", "burned",
"burn", "razed", "raze", "collapsed", "collapse", "ruined",
"ruin", "dismantled", "dismantle", "obliterated", "obliterate",
"wrecked", "wreck", "shattered", "shatter",
] {
lex.insert(*w, EventType::Destroy);
}
lex
}
}
#[derive(Debug, Clone)]
pub struct Argument {
pub role: String,
pub text: String,
pub span: (usize, usize),
}
#[derive(Debug, Clone)]
pub struct Event {
pub trigger: String,
pub trigger_span: (usize, usize),
pub event_type: String,
pub arguments: Vec<Argument>,
}
fn is_word_char(c: char) -> bool {
c.is_alphanumeric() || c == '\'' || c == '-'
}
fn tokenise(text: &str) -> Vec<(usize, usize, String)> {
let mut tokens: Vec<(usize, usize, String)> = Vec::new();
let mut start: Option<usize> = None;
for (i, c) in text.char_indices() {
if is_word_char(c) {
if start.is_none() {
start = Some(i);
}
} else if let Some(s) = start.take() {
tokens.push((s, i, text[s..i].to_string()));
}
}
if let Some(s) = start {
tokens.push((s, text.len(), text[s..].to_string()));
}
tokens
}
fn sentences(text: &str) -> Vec<(usize, &str)> {
let mut result = Vec::new();
let mut start = 0usize;
let bytes = text.as_bytes();
let len = bytes.len();
while start < len {
let mut end = start;
while end < len {
let b = bytes[end];
if b == b'.' || b == b'!' || b == b'?' {
end += 1;
while end < len && (bytes[end] == b' ' || bytes[end] == b'\n') {
end += 1;
}
break;
}
end += 1;
}
let raw = text[start..end].trim();
if !raw.is_empty() {
result.push((start, raw));
}
start = end;
}
result
}
fn detect_np_spans(
tokens: &[(usize, usize, String)],
sent_start_abs: usize,
) -> Vec<(usize, usize, String)> {
let mut spans: Vec<(usize, usize, String)> = Vec::new();
let mut i = 0usize;
while i < tokens.len() {
let (tok_s, tok_e, word) = &tokens[i];
let abs_start = sent_start_abs + tok_s;
let abs_end = sent_start_abs + tok_e;
if word.starts_with(|c: char| c.is_uppercase()) && abs_start > sent_start_abs {
let mut j = i;
while j < tokens.len()
&& tokens[j]
.2
.starts_with(|c: char| c.is_uppercase())
{
j += 1;
}
if j > i {
let span_s = sent_start_abs + tokens[i].0;
let span_e = sent_start_abs + tokens[j - 1].1;
let surface: String = tokens[i..j]
.iter()
.map(|(_, _, w)| w.as_str())
.collect::<Vec<_>>()
.join(" ");
spans.push((span_s, span_e, surface));
i = j;
continue;
}
}
i += 1;
}
spans
}
fn detect_time_spans(tokens: &[(usize, usize, String)], sent_start_abs: usize)
-> Vec<(usize, usize, String)> {
const DAYS: &[&str] = &[
"monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday",
];
const MONTHS: &[&str] = &[
"january", "february", "march", "april", "may", "june",
"july", "august", "september", "october", "november", "december",
"jan", "feb", "mar", "apr", "jun", "jul", "aug", "sep", "oct",
"nov", "dec",
];
const ABSOLUTE_TEMPS: &[&str] = &["yesterday", "today", "tomorrow", "now", "recently"];
const REL_ANCHORS: &[&str] = &["last", "next", "this", "coming", "previous"];
const UNITS: &[&str] = &[
"second", "seconds", "minute", "minutes", "hour", "hours",
"day", "days", "week", "weeks", "month", "months", "year", "years",
];
let mut spans: Vec<(usize, usize, String)> = Vec::new();
let mut i = 0usize;
while i < tokens.len() {
let (tok_s, tok_e, word) = &tokens[i];
let abs_s = sent_start_abs + tok_s;
let abs_e = sent_start_abs + tok_e;
let lower = word.to_lowercase();
if ABSOLUTE_TEMPS.contains(&lower.as_str()) {
spans.push((abs_s, abs_e, word.clone()));
i += 1;
continue;
}
if REL_ANCHORS.contains(&lower.as_str()) && i + 1 < tokens.len() {
let next_lower = tokens[i + 1].2.to_lowercase();
if DAYS.contains(&next_lower.as_str())
|| MONTHS.contains(&next_lower.as_str())
|| UNITS.contains(&next_lower.as_str())
{
let span_e = sent_start_abs + tokens[i + 1].1;
let surface = format!("{} {}", word, tokens[i + 1].2);
spans.push((abs_s, span_e, surface));
i += 2;
continue;
}
}
if lower.chars().all(|c| c.is_ascii_digit()) && i + 1 < tokens.len() {
let unit_lower = tokens[i + 1].2.to_lowercase();
if UNITS.contains(&unit_lower.as_str()) {
let mut span_e = sent_start_abs + tokens[i + 1].1;
let mut surface = format!("{} {}", word, tokens[i + 1].2);
if i + 2 < tokens.len() && tokens[i + 2].2.to_lowercase() == "ago" {
span_e = sent_start_abs + tokens[i + 2].1;
surface = format!("{} ago", surface);
i += 3;
} else {
i += 2;
}
spans.push((abs_s, span_e, surface));
continue;
}
}
if DAYS.contains(&lower.as_str()) || MONTHS.contains(&lower.as_str()) {
spans.push((abs_s, abs_e, word.clone()));
i += 1;
continue;
}
if lower.len() == 4
&& lower.starts_with(|c: char| c == '1' || c == '2')
&& lower.chars().all(|c| c.is_ascii_digit())
{
spans.push((abs_s, abs_e, word.clone()));
i += 1;
continue;
}
i += 1;
}
spans
}
fn detect_location_spans(
tokens: &[(usize, usize, String)],
sent_start_abs: usize,
np_spans: &[(usize, usize, String)],
) -> Vec<(usize, usize, String)> {
const LOC_PREPS: &[&str] = &["in", "at", "from", "to", "near", "around", "through"];
let mut locs: Vec<(usize, usize, String)> = Vec::new();
for (i, (tok_s, _tok_e, word)) in tokens.iter().enumerate() {
let lower = word.to_lowercase();
if LOC_PREPS.contains(&lower.as_str()) {
if let Some(next) = tokens.get(i + 1) {
let next_abs_s = sent_start_abs + next.0;
let next_abs_e = sent_start_abs + next.1;
let found = np_spans
.iter()
.find(|(ns, _ne, _surf)| *ns == next_abs_s);
if let Some((ns, ne, surf)) = found {
locs.push((*ns, *ne, surf.clone()));
} else if next.2.starts_with(|c: char| c.is_uppercase()) {
locs.push((next_abs_s, next_abs_e, next.2.clone()));
}
}
}
let _ = tok_s;
}
locs
}
pub fn extract_events(text: &str, triggers: &TriggerLexicon) -> Vec<Event> {
let mut events: Vec<Event> = Vec::new();
for (sent_off, sent_text) in sentences(text) {
let tokens = tokenise(sent_text);
if tokens.is_empty() {
continue;
}
let np_spans = detect_np_spans(&tokens, sent_off);
let time_spans = detect_time_spans(&tokens, sent_off);
let loc_spans = detect_location_spans(&tokens, sent_off, &np_spans);
for (tok_idx, (tok_s, tok_e, word)) in tokens.iter().enumerate() {
let abs_trig_s = sent_off + tok_s;
let abs_trig_e = sent_off + tok_e;
let etype = match triggers.lookup(word) {
Some(et) => et,
None => continue,
};
let mut args: Vec<Argument> = Vec::new();
let agent = np_spans
.iter()
.filter(|(_, ne, _)| *ne <= abs_trig_s)
.max_by_key(|(ns, _, _)| *ns);
if let Some((ns, ne, surf)) = agent {
args.push(Argument {
role: "Agent".to_string(),
text: surf.clone(),
span: (*ns, *ne),
});
}
let patient = np_spans
.iter()
.filter(|(ns, _, _)| *ns >= abs_trig_e)
.min_by_key(|(ns, _, _)| *ns);
if let Some((ns, ne, surf)) = patient {
args.push(Argument {
role: "Patient".to_string(),
text: surf.clone(),
span: (*ns, *ne),
});
}
let window_start = tok_idx.saturating_sub(5);
let window_end = (tok_idx + 6).min(tokens.len());
let window_abs_s = sent_off + tokens[window_start].0;
let window_abs_e = sent_off + tokens[window_end - 1].1;
for (ls, le, lsurf) in &loc_spans {
if *ls >= window_abs_s && *le <= window_abs_e {
args.push(Argument {
role: "Location".to_string(),
text: lsurf.clone(),
span: (*ls, *le),
});
}
}
let twin_start = tok_idx.saturating_sub(6);
let twin_end = (tok_idx + 7).min(tokens.len());
let twin_abs_s = sent_off + tokens[twin_start].0;
let twin_abs_e = sent_off + tokens[twin_end - 1].1;
for (ts, te, tsurf) in &time_spans {
if *ts >= twin_abs_s && *te <= twin_abs_e {
args.push(Argument {
role: "Time".to_string(),
text: tsurf.clone(),
span: (*ts, *te),
});
}
}
events.push(Event {
trigger: word.clone(),
trigger_span: (abs_trig_s, abs_trig_e),
event_type: etype.label().to_string(),
arguments: args,
});
}
}
events
}
fn event_similarity(e1: &Event, e2: &Event) -> f64 {
let mut score = 0.0f64;
if e1.event_type == e2.event_type {
score += 0.5;
}
let texts1: std::collections::HashSet<String> = e1
.arguments
.iter()
.map(|a| a.text.to_lowercase())
.collect();
let texts2: std::collections::HashSet<String> = e2
.arguments
.iter()
.map(|a| a.text.to_lowercase())
.collect();
let shared = texts1.intersection(&texts2).count();
let total = texts1.len().max(texts2.len());
if total > 0 {
score += 0.4 * (shared as f64 / total as f64);
}
let t1 = e1.trigger.to_lowercase();
let t2 = e2.trigger.to_lowercase();
if t1 == t2 || levenshtein(t1.as_bytes(), t2.as_bytes()) <= 2 {
score += 0.1;
}
score
}
fn levenshtein(a: &[u8], b: &[u8]) -> usize {
let m = a.len();
let n = b.len();
if m == 0 {
return n;
}
if n == 0 {
return m;
}
let mut dp: Vec<usize> = (0..=n).collect();
for i in 1..=m {
let mut prev = dp[0];
dp[0] = i;
for j in 1..=n {
let tmp = dp[j];
dp[j] = if a[i - 1] == b[j - 1] {
prev
} else {
1 + prev.min(dp[j]).min(dp[j - 1])
};
prev = tmp;
}
}
dp[n]
}
pub fn event_coref(events: &[Event]) -> Vec<Vec<usize>> {
event_coref_with_threshold(events, 0.6)
}
pub fn event_coref_with_threshold(events: &[Event], threshold: f64) -> Vec<Vec<usize>> {
let n = events.len();
let mut parent: Vec<usize> = (0..n).collect();
fn find(parent: &mut Vec<usize>, x: usize) -> usize {
if parent[x] != x {
parent[x] = find(parent, parent[x]);
}
parent[x]
}
for i in 0..n {
for j in (i + 1)..n {
if event_similarity(&events[i], &events[j]) >= threshold {
let ri = find(&mut parent, i);
let rj = find(&mut parent, j);
if ri != rj {
parent[rj] = ri;
}
}
}
}
let mut chains: HashMap<usize, Vec<usize>> = HashMap::new();
for i in 0..n {
let root = find(&mut parent, i);
chains.entry(root).or_default().push(i);
}
let mut result: Vec<Vec<usize>> = chains
.into_values()
.filter(|v| v.len() >= 2)
.collect();
result.sort_by_key(|v| v[0]);
result
}
pub struct EventExtractor {
lexicon: TriggerLexicon,
coref_threshold: f64,
}
impl Default for EventExtractor {
fn default() -> Self {
Self::new()
}
}
impl EventExtractor {
pub fn new() -> Self {
Self {
lexicon: TriggerLexicon::default_english(),
coref_threshold: 0.6,
}
}
pub fn with_lexicon(mut self, lexicon: TriggerLexicon) -> Self {
self.lexicon = lexicon;
self
}
pub fn with_coref_threshold(mut self, threshold: f64) -> Self {
self.coref_threshold = threshold;
self
}
pub fn extract(&self, text: &str) -> Vec<Event> {
extract_events(text, &self.lexicon)
}
pub fn extract_with_coref(&self, text: &str) -> Result<(Vec<Event>, Vec<Vec<usize>>)> {
if text.is_empty() {
return Err(TextError::InvalidInput(
"Input text must not be empty".to_string(),
));
}
let events = self.extract(text);
let chains = event_coref_with_threshold(&events, self.coref_threshold);
Ok((events, chains))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_trigger_lexicon_lookup() {
let lex = TriggerLexicon::default_english();
assert_eq!(lex.lookup("arrested"), Some(&EventType::Arrest));
assert_eq!(lex.lookup("ARRESTED"), Some(&EventType::Arrest));
assert_eq!(lex.lookup("died"), Some(&EventType::Die));
assert_eq!(lex.lookup("unknown_verb"), None);
}
#[test]
fn test_extract_events_arrest() {
let lex = TriggerLexicon::default_english();
let text = "Police arrested the suspect yesterday in New York.";
let events = extract_events(text, &lex);
assert!(!events.is_empty());
let e = events.iter().find(|e| e.event_type == "Arrest");
assert!(e.is_some(), "Expected an Arrest event");
let ev = e.expect("already checked");
assert!(ev.arguments.iter().any(|a| a.role == "Patient"));
}
#[test]
fn test_extract_events_die_with_agent() {
let lex = TriggerLexicon::default_english();
let text = "The soldier died in battle last week.";
let events = extract_events(text, &lex);
assert!(!events.is_empty());
let e = events.iter().find(|e| e.event_type == "Die");
assert!(e.is_some());
}
#[test]
fn test_extract_events_transfer() {
let lex = TriggerLexicon::default_english();
let text = "The company sold its assets to the buyer yesterday.";
let events = extract_events(text, &lex);
assert!(events.iter().any(|e| e.event_type == "Transfer"));
}
#[test]
fn test_extract_events_multiple_sentences() {
let lex = TriggerLexicon::default_english();
let text = "Alice attacked the base. Bob fled to safety.";
let events = extract_events(text, &lex);
assert!(events.len() >= 2);
let types: Vec<&str> = events.iter().map(|e| e.event_type.as_str()).collect();
assert!(types.contains(&"Attack"));
assert!(types.contains(&"Move"));
}
#[test]
fn test_event_coref_same_type_and_argument() {
let lex = TriggerLexicon::default_english();
let text = "Alice arrested Bob on Monday. Police arrested Bob again on Tuesday.";
let events = extract_events(text, &lex);
let chains = event_coref(&events);
if !chains.is_empty() {
assert!(chains.iter().any(|c| c.len() >= 2));
}
}
#[test]
fn test_event_coref_different_types() {
let lex = TriggerLexicon::default_english();
let text = "Alice attacked the fort. Bob fled to the hills.";
let events = extract_events(text, &lex);
assert!(events.len() >= 2);
let chains = event_coref(&events);
assert!(chains.is_empty() || !chains.iter().any(|c| c.len() >= 2));
}
#[test]
fn test_extractor_builder() {
let extractor = EventExtractor::new().with_coref_threshold(0.5);
let (events, _chains) = extractor
.extract_with_coref("Police arrested the suspect in London.")
.expect("should not fail");
assert!(!events.is_empty());
}
#[test]
fn test_extractor_empty_text_error() {
let extractor = EventExtractor::new();
let result = extractor.extract_with_coref("");
assert!(result.is_err());
}
#[test]
fn test_custom_lexicon() {
let mut lex = TriggerLexicon::new();
lex.insert("deployed", EventType::Move);
lex.insert("commissioned", EventType::Create);
let text = "The company deployed a new service and commissioned a report.";
let events = extract_events(text, &lex);
assert!(events.iter().any(|e| e.event_type == "Move"));
assert!(events.iter().any(|e| e.event_type == "Create"));
}
#[test]
fn test_event_type_label() {
assert_eq!(EventType::Move.label(), "Move");
assert_eq!(EventType::Arrest.label(), "Arrest");
assert_eq!(EventType::Custom("Foo".to_string()).label(), "Foo");
}
}