1use crate::error::{Result, TextError};
29use std::collections::HashMap;
30
/// Category assigned to an extracted event.
///
/// The eight unit variants are the built-in categories used by the default
/// English trigger lexicon; [`EventType::Custom`] carries a caller-supplied
/// category name.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum EventType {
    /// Movement (default triggers include "moved", "fled", "arrived").
    Move,
    /// Violent action (default triggers include "attacked", "bombed", "shot").
    Attack,
    /// Meeting or discussion (default triggers include "met", "gathered", "negotiated").
    Meet,
    /// Arrest or detention (default triggers include "arrested", "detained", "jailed").
    Arrest,
    /// Death (default triggers include "died", "killed", "assassinated").
    Die,
    /// Transfer of goods or money (default triggers include "sold", "donated", "paid").
    Transfer,
    /// Creation (default triggers include "built", "founded", "published").
    Create,
    /// Destruction (default triggers include "demolished", "burned", "wrecked").
    Destroy,
    /// Caller-defined category; the wrapped string is used as its label.
    Custom(String),
}
57
58impl EventType {
59 pub fn label(&self) -> &str {
61 match self {
62 EventType::Move => "Move",
63 EventType::Attack => "Attack",
64 EventType::Meet => "Meet",
65 EventType::Arrest => "Arrest",
66 EventType::Die => "Die",
67 EventType::Transfer => "Transfer",
68 EventType::Create => "Create",
69 EventType::Destroy => "Destroy",
70 EventType::Custom(s) => s.as_str(),
71 }
72 }
73}
74
75impl std::fmt::Display for EventType {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 write!(f, "{}", self.label())
78 }
79}
80
81pub struct TriggerLexicon {
90 pub triggers: HashMap<String, EventType>,
92}
93
94impl Default for TriggerLexicon {
95 fn default() -> Self {
96 Self::default_english()
97 }
98}
99
100impl TriggerLexicon {
101 pub fn new() -> Self {
103 Self {
104 triggers: HashMap::new(),
105 }
106 }
107
108 pub fn insert(&mut self, word: impl Into<String>, event_type: EventType) {
110 self.triggers.insert(word.into().to_lowercase(), event_type);
111 }
112
113 pub fn lookup(&self, word: &str) -> Option<&EventType> {
115 self.triggers.get(&word.to_lowercase())
116 }
117
118 pub fn default_english() -> Self {
120 let mut lex = Self::new();
121
122 for w in &[
124 "moved",
125 "moving",
126 "move",
127 "traveled",
128 "travel",
129 "travelled",
130 "fled",
131 "flee",
132 "departed",
133 "depart",
134 "arrived",
135 "arrive",
136 "entered",
137 "enter",
138 "left",
139 "evacuated",
140 "evacuate",
141 "migrated",
142 "migrate",
143 "relocated",
144 "relocate",
145 "walked",
146 "ran",
147 "run",
148 ] {
149 lex.insert(*w, EventType::Move);
150 }
151
152 for w in &[
154 "attacked",
155 "attack",
156 "assaulted",
157 "assault",
158 "bombed",
159 "bomb",
160 "shot",
161 "shoot",
162 "fired",
163 "fire",
164 "struck",
165 "strike",
166 "hit",
167 "targeted",
168 "target",
169 "raided",
170 "raid",
171 "invaded",
172 "invade",
173 "detonated",
174 "detonate",
175 "launched",
176 "launch",
177 "stabbed",
178 "stab",
179 ] {
180 lex.insert(*w, EventType::Attack);
181 }
182
183 for w in &[
185 "met",
186 "meet",
187 "meeting",
188 "gathered",
189 "gather",
190 "assembled",
191 "assemble",
192 "convened",
193 "convene",
194 "discussed",
195 "discuss",
196 "negotiated",
197 "negotiate",
198 "talked",
199 "talk",
200 "conferenced",
201 "conferred",
202 "confer",
203 "visited",
204 "visit",
205 ] {
206 lex.insert(*w, EventType::Meet);
207 }
208
209 for w in &[
211 "arrested",
212 "arrest",
213 "detained",
214 "detain",
215 "apprehended",
216 "apprehend",
217 "captured",
218 "capture",
219 "jailed",
220 "jail",
221 "imprisoned",
222 "imprison",
223 "charged",
224 "charge",
225 "indicted",
226 "indict",
227 "booked",
228 "book",
229 "handcuffed",
230 "handcuff",
231 ] {
232 lex.insert(*w, EventType::Arrest);
233 }
234
235 for w in &[
237 "died",
238 "die",
239 "killed",
240 "kill",
241 "murdered",
242 "murder",
243 "executed",
244 "execute",
245 "slain",
246 "slayed",
247 "slay",
248 "perished",
249 "perish",
250 "deceased",
251 "assassinated",
252 "assassinate",
253 "fatally",
254 ] {
255 lex.insert(*w, EventType::Die);
256 }
257
258 for w in &[
260 "transferred",
261 "transfer",
262 "sold",
263 "sell",
264 "purchased",
265 "purchase",
266 "bought",
267 "buy",
268 "donated",
269 "donate",
270 "paid",
271 "pay",
272 "sent",
273 "send",
274 "received",
275 "receive",
276 "wired",
277 "wire",
278 "awarded",
279 "award",
280 "granted",
281 "grant",
282 ] {
283 lex.insert(*w, EventType::Transfer);
284 }
285
286 for w in &[
288 "created",
289 "create",
290 "built",
291 "build",
292 "developed",
293 "develop",
294 "founded",
295 "found",
296 "established",
297 "establish",
298 "launched",
299 "produced",
300 "produce",
301 "manufactured",
302 "manufacture",
303 "invented",
304 "invent",
305 "designed",
306 "design",
307 "wrote",
308 "write",
309 "authored",
310 "author",
311 "published",
312 "publish",
313 "formed",
314 "form",
315 ] {
316 lex.insert(*w, EventType::Create);
317 }
318
319 for w in &[
321 "destroyed",
322 "destroy",
323 "demolished",
324 "demolish",
325 "burned",
326 "burn",
327 "razed",
328 "raze",
329 "collapsed",
330 "collapse",
331 "ruined",
332 "ruin",
333 "dismantled",
334 "dismantle",
335 "obliterated",
336 "obliterate",
337 "wrecked",
338 "wreck",
339 "shattered",
340 "shatter",
341 ] {
342 lex.insert(*w, EventType::Destroy);
343 }
344
345 lex
346 }
347}
348
/// One participant or circumstance attached to an [`Event`].
#[derive(Clone, Debug)]
pub struct Argument {
    /// Role name; `extract_events` emits "Agent", "Patient", "Location" and "Time".
    pub role: String,
    /// Surface text of the argument.
    pub text: String,
    /// `(start, end)` byte offsets of the argument, end exclusive.
    pub span: (usize, usize),
}
363
364#[derive(Debug, Clone)]
370pub struct Event {
371 pub trigger: String,
373 pub trigger_span: (usize, usize),
375 pub event_type: String,
377 pub arguments: Vec<Argument>,
379}
380
/// Returns `true` for characters that may appear inside a token:
/// alphanumerics, the apostrophe (as in "it's") and the hyphen (as in
/// "well-known").
fn is_word_char(c: char) -> bool {
    matches!(c, '\'' | '-') || c.is_alphanumeric()
}
389
390fn tokenise(text: &str) -> Vec<(usize, usize, String)> {
392 let mut tokens: Vec<(usize, usize, String)> = Vec::new();
393 let mut start: Option<usize> = None;
394 for (i, c) in text.char_indices() {
395 if is_word_char(c) {
396 if start.is_none() {
397 start = Some(i);
398 }
399 } else if let Some(s) = start.take() {
400 tokens.push((s, i, text[s..i].to_string()));
401 }
402 }
403 if let Some(s) = start {
404 tokens.push((s, text.len(), text[s..].to_string()));
405 }
406 tokens
407}
408
/// Splits `text` into sentences at `.`, `!` and `?`.
///
/// Every terminator ends a sentence, so abbreviations such as "Mr." are split
/// too. Spaces and newlines that directly follow a terminator are consumed
/// with the sentence they follow.
///
/// Returns `(offset, sentence)` pairs where `sentence` is the whitespace-
/// trimmed slice and `offset` is the byte index in `text` at which that
/// trimmed slice begins, so `&text[offset..offset + sentence.len()] == sentence`
/// always holds. Chunks that are empty after trimming are skipped.
fn sentences(text: &str) -> Vec<(usize, &str)> {
    let bytes = text.as_bytes();
    let len = bytes.len();
    let mut result = Vec::new();
    let mut start = 0usize;

    while start < len {
        // The terminators are ASCII, so scanning bytes never splits a
        // multi-byte character and every cut lands on a char boundary.
        let terminator = bytes[start..]
            .iter()
            .position(|b| matches!(*b, b'.' | b'!' | b'?'));
        let mut end = match terminator {
            Some(rel) => start + rel + 1,
            None => len,
        };
        while end < len && (bytes[end] == b' ' || bytes[end] == b'\n') {
            end += 1;
        }

        let chunk = &text[start..end];
        // `trim` may drop leading whitespace (e.g. at the start of the text or
        // a tab after the previous terminator); shift the offset by the same
        // amount so it keeps pointing at the first byte of the trimmed slice.
        let leading = chunk.len() - chunk.trim_start().len();
        let sentence = chunk.trim();
        if !sentence.is_empty() {
            result.push((start + leading, sentence));
        }
        start = end;
    }
    result
}
440
/// Finds candidate noun phrases as maximal runs of capitalised tokens.
///
/// `tokens` are `(start, end, surface)` triples with offsets relative to the
/// sentence; `sent_start_abs` is added to produce the offsets in the returned
/// `(start, end, surface)` spans. The surface joins the run's tokens with a
/// single space.
///
/// A run never begins at a token whose relative start is 0, so a capitalised
/// word that opens the sentence is not treated as a noun phrase on its own.
fn detect_np_spans(
    tokens: &[(usize, usize, String)],
    sent_start_abs: usize,
) -> Vec<(usize, usize, String)> {
    fn capitalised(word: &str) -> bool {
        word.chars().next().map_or(false, char::is_uppercase)
    }

    let mut found: Vec<(usize, usize, String)> = Vec::new();
    let mut idx = 0usize;

    while idx < tokens.len() {
        let head = &tokens[idx];
        if head.0 == 0 || !capitalised(&head.2) {
            idx += 1;
            continue;
        }

        // `head` itself is capitalised, so the run holds at least one token.
        let run_len = tokens[idx..]
            .iter()
            .take_while(|tok| capitalised(&tok.2))
            .count();
        let run = &tokens[idx..idx + run_len];
        let surface = run
            .iter()
            .map(|tok| tok.2.as_str())
            .collect::<Vec<_>>()
            .join(" ");

        found.push((
            sent_start_abs + head.0,
            sent_start_abs + run[run_len - 1].1,
            surface,
        ));
        idx += run_len;
    }
    found
}
483
/// Finds simple temporal expressions in a tokenised sentence.
///
/// `tokens` are `(start, end, surface)` triples relative to the sentence and
/// `sent_start_abs` is added to every returned offset. For each token the
/// rules below are tried in order (matching is case-insensitive) and the first
/// hit wins:
///
/// 1. a standalone adverb such as "yesterday" or "recently";
/// 2. an anchor ("last", "next", ...) followed by a day, month or unit;
/// 3. a number followed by a unit, optionally followed by "ago";
/// 4. a bare day or month name, or a four-digit number starting with 1 or 2.
///
/// Multi-token matches consume all of their tokens, so they never overlap.
fn detect_time_spans(
    tokens: &[(usize, usize, String)],
    sent_start_abs: usize,
) -> Vec<(usize, usize, String)> {
    const DAYS: &[&str] = &[
        "monday",
        "tuesday",
        "wednesday",
        "thursday",
        "friday",
        "saturday",
        "sunday",
    ];
    const MONTHS: &[&str] = &[
        "january",
        "february",
        "march",
        "april",
        "may",
        "june",
        "july",
        "august",
        "september",
        "october",
        "november",
        "december",
        "jan",
        "feb",
        "mar",
        "apr",
        "jun",
        "jul",
        "aug",
        "sep",
        "oct",
        "nov",
        "dec",
    ];
    const ABSOLUTE_TEMPS: &[&str] = &["yesterday", "today", "tomorrow", "now", "recently"];
    const REL_ANCHORS: &[&str] = &["last", "next", "this", "coming", "previous"];
    const UNITS: &[&str] = &[
        "second", "seconds", "minute", "minutes", "hour", "hours", "day", "days", "week", "weeks",
        "month", "months", "year", "years",
    ];

    fn is_one_of(set: &[&str], word: &str) -> bool {
        set.iter().any(|entry| *entry == word)
    }

    let mut found: Vec<(usize, usize, String)> = Vec::new();
    let mut idx = 0usize;

    while idx < tokens.len() {
        let (rel_start, rel_end, word) = &tokens[idx];
        let start = sent_start_abs + *rel_start;
        let end = sent_start_abs + *rel_end;
        let lower = word.to_lowercase();
        let numeric = lower.chars().all(|c| c.is_ascii_digit());

        // Rule 1: standalone temporal adverb.
        if is_one_of(ABSOLUTE_TEMPS, &lower) {
            found.push((start, end, word.clone()));
            idx += 1;
            continue;
        }

        // Rules 2 and 3 both need a following token.
        if let Some((_, next_end, next_word)) = tokens.get(idx + 1) {
            let next_lower = next_word.to_lowercase();
            let next_is_unit = is_one_of(UNITS, &next_lower);

            // Rule 2: "last Monday", "next week", "this June", ...
            if is_one_of(REL_ANCHORS, &lower)
                && (is_one_of(DAYS, &next_lower)
                    || is_one_of(MONTHS, &next_lower)
                    || next_is_unit)
            {
                found.push((
                    start,
                    sent_start_abs + *next_end,
                    format!("{} {}", word, next_word),
                ));
                idx += 2;
                continue;
            }

            // Rule 3: "3 days", optionally extended by a trailing "ago".
            if numeric && next_is_unit {
                let mut span_end = sent_start_abs + *next_end;
                let mut surface = format!("{} {}", word, next_word);
                let mut consumed = 2;
                if let Some((_, ago_end, ago_word)) = tokens.get(idx + 2) {
                    if ago_word.to_lowercase() == "ago" {
                        span_end = sent_start_abs + *ago_end;
                        surface = format!("{} ago", surface);
                        consumed = 3;
                    }
                }
                found.push((start, span_end, surface));
                idx += consumed;
                continue;
            }
        }

        // Rule 4: bare day/month name, or a year-like four-digit number.
        let year_like =
            numeric && lower.len() == 4 && matches!(lower.as_bytes()[0], b'1' | b'2');
        if is_one_of(DAYS, &lower) || is_one_of(MONTHS, &lower) || year_like {
            found.push((start, end, word.clone()));
        }
        idx += 1;
    }
    found
}
601
/// Finds location candidates: whatever directly follows a locative preposition.
///
/// For every token that is (case-insensitively) one of the prepositions in
/// `LOC_PREPS`, the next token is examined:
///
/// * if a span in `np_spans` starts at that token's absolute offset, that
///   whole span is reported;
/// * otherwise, if the token begins with an uppercase letter, the single
///   token is reported.
///
/// `tokens` carry sentence-relative offsets; `np_spans` and the returned spans
/// use offsets shifted by `sent_start_abs`.
fn detect_location_spans(
    tokens: &[(usize, usize, String)],
    sent_start_abs: usize,
    np_spans: &[(usize, usize, String)],
) -> Vec<(usize, usize, String)> {
    const LOC_PREPS: &[&str] = &["in", "at", "from", "to", "near", "around", "through"];

    tokens
        .windows(2)
        .filter(|pair| LOC_PREPS.contains(&pair[0].2.to_lowercase().as_str()))
        .filter_map(|pair| {
            let candidate = &pair[1];
            let cand_start = sent_start_abs + candidate.0;

            if let Some(np) = np_spans.iter().find(|np| np.0 == cand_start) {
                return Some(np.clone());
            }
            if candidate.2.starts_with(|c: char| c.is_uppercase()) {
                Some((
                    cand_start,
                    sent_start_abs + candidate.1,
                    candidate.2.clone(),
                ))
            } else {
                None
            }
        })
        .collect()
}
630
631pub fn extract_events(text: &str, triggers: &TriggerLexicon) -> Vec<Event> {
646 let mut events: Vec<Event> = Vec::new();
647
648 for (sent_off, sent_text) in sentences(text) {
649 let tokens = tokenise(sent_text);
650 if tokens.is_empty() {
651 continue;
652 }
653
654 let np_spans = detect_np_spans(&tokens, sent_off);
655 let time_spans = detect_time_spans(&tokens, sent_off);
656 let loc_spans = detect_location_spans(&tokens, sent_off, &np_spans);
657
658 for (tok_idx, (tok_s, tok_e, word)) in tokens.iter().enumerate() {
659 let abs_trig_s = sent_off + tok_s;
660 let abs_trig_e = sent_off + tok_e;
661
662 let etype = match triggers.lookup(word) {
663 Some(et) => et,
664 None => continue,
665 };
666
667 let mut args: Vec<Argument> = Vec::new();
668
669 let agent = np_spans
671 .iter()
672 .filter(|(_, ne, _)| *ne <= abs_trig_s)
673 .max_by_key(|(ns, _, _)| *ns);
674 if let Some((ns, ne, surf)) = agent {
675 args.push(Argument {
676 role: "Agent".to_string(),
677 text: surf.clone(),
678 span: (*ns, *ne),
679 });
680 }
681
682 let patient = np_spans
684 .iter()
685 .filter(|(ns, _, _)| *ns >= abs_trig_e)
686 .min_by_key(|(ns, _, _)| *ns);
687 if let Some((ns, ne, surf)) = patient {
688 args.push(Argument {
689 role: "Patient".to_string(),
690 text: surf.clone(),
691 span: (*ns, *ne),
692 });
693 }
694
695 let window_start = tok_idx.saturating_sub(5);
697 let window_end = (tok_idx + 6).min(tokens.len());
698 let window_abs_s = sent_off + tokens[window_start].0;
699 let window_abs_e = sent_off + tokens[window_end - 1].1;
700
701 for (ls, le, lsurf) in &loc_spans {
702 if *ls >= window_abs_s && *le <= window_abs_e {
703 args.push(Argument {
704 role: "Location".to_string(),
705 text: lsurf.clone(),
706 span: (*ls, *le),
707 });
708 }
709 }
710
711 let twin_start = tok_idx.saturating_sub(6);
713 let twin_end = (tok_idx + 7).min(tokens.len());
714 let twin_abs_s = sent_off + tokens[twin_start].0;
715 let twin_abs_e = sent_off + tokens[twin_end - 1].1;
716
717 for (ts, te, tsurf) in &time_spans {
718 if *ts >= twin_abs_s && *te <= twin_abs_e {
719 args.push(Argument {
720 role: "Time".to_string(),
721 text: tsurf.clone(),
722 span: (*ts, *te),
723 });
724 }
725 }
726
727 events.push(Event {
728 trigger: word.clone(),
729 trigger_span: (abs_trig_s, abs_trig_e),
730 event_type: etype.label().to_string(),
731 arguments: args,
732 });
733 }
734 }
735
736 events
737}
738
739fn event_similarity(e1: &Event, e2: &Event) -> f64 {
751 let mut score = 0.0f64;
752
753 if e1.event_type == e2.event_type {
755 score += 0.5;
756 }
757
758 let texts1: std::collections::HashSet<String> =
760 e1.arguments.iter().map(|a| a.text.to_lowercase()).collect();
761 let texts2: std::collections::HashSet<String> =
762 e2.arguments.iter().map(|a| a.text.to_lowercase()).collect();
763
764 let shared = texts1.intersection(&texts2).count();
765 let total = texts1.len().max(texts2.len());
766 if total > 0 {
767 score += 0.4 * (shared as f64 / total as f64);
768 }
769
770 let t1 = e1.trigger.to_lowercase();
772 let t2 = e2.trigger.to_lowercase();
773 if t1 == t2 || levenshtein(t1.as_bytes(), t2.as_bytes()) <= 2 {
774 score += 0.1;
775 }
776
777 score
778}
779
/// Computes the Levenshtein edit distance between two byte strings.
///
/// Insertions, deletions and substitutions each cost 1. The comparison is
/// byte-wise, so a multi-byte UTF-8 character counts as several units.
fn levenshtein(a: &[u8], b: &[u8]) -> usize {
    if a.is_empty() {
        return b.len();
    }
    if b.is_empty() {
        return a.len();
    }

    // `previous` is the finished row for the prefix of `a` handled so far;
    // `current` is the row being filled in for one more byte of `a`.
    let mut previous: Vec<usize> = (0..=b.len()).collect();
    let mut current: Vec<usize> = vec![0; b.len() + 1];

    for (row, &byte_a) in a.iter().enumerate() {
        current[0] = row + 1;
        for (col, &byte_b) in b.iter().enumerate() {
            current[col + 1] = if byte_a == byte_b {
                previous[col]
            } else {
                1 + previous[col].min(previous[col + 1]).min(current[col])
            };
        }
        std::mem::swap(&mut previous, &mut current);
    }

    previous[b.len()]
}
806
807pub fn event_coref(events: &[Event]) -> Vec<Vec<usize>> {
815 event_coref_with_threshold(events, 0.6)
816}
817
818pub fn event_coref_with_threshold(events: &[Event], threshold: f64) -> Vec<Vec<usize>> {
820 let n = events.len();
821 let mut parent: Vec<usize> = (0..n).collect();
823
824 fn find(parent: &mut Vec<usize>, x: usize) -> usize {
825 if parent[x] != x {
826 parent[x] = find(parent, parent[x]);
827 }
828 parent[x]
829 }
830
831 for i in 0..n {
832 for j in (i + 1)..n {
833 if event_similarity(&events[i], &events[j]) >= threshold {
834 let ri = find(&mut parent, i);
835 let rj = find(&mut parent, j);
836 if ri != rj {
837 parent[rj] = ri;
838 }
839 }
840 }
841 }
842
843 let mut chains: HashMap<usize, Vec<usize>> = HashMap::new();
845 for i in 0..n {
846 let root = find(&mut parent, i);
847 chains.entry(root).or_default().push(i);
848 }
849
850 let mut result: Vec<Vec<usize>> = chains.into_values().filter(|v| v.len() >= 2).collect();
852 result.sort_by_key(|v| v[0]);
853 result
854}
855
856pub struct EventExtractor {
862 lexicon: TriggerLexicon,
863 coref_threshold: f64,
864}
865
866impl Default for EventExtractor {
867 fn default() -> Self {
868 Self::new()
869 }
870}
871
872impl EventExtractor {
873 pub fn new() -> Self {
875 Self {
876 lexicon: TriggerLexicon::default_english(),
877 coref_threshold: 0.6,
878 }
879 }
880
881 pub fn with_lexicon(mut self, lexicon: TriggerLexicon) -> Self {
883 self.lexicon = lexicon;
884 self
885 }
886
887 pub fn with_coref_threshold(mut self, threshold: f64) -> Self {
889 self.coref_threshold = threshold;
890 self
891 }
892
893 pub fn extract(&self, text: &str) -> Vec<Event> {
895 extract_events(text, &self.lexicon)
896 }
897
898 pub fn extract_with_coref(&self, text: &str) -> Result<(Vec<Event>, Vec<Vec<usize>>)> {
900 if text.is_empty() {
901 return Err(TextError::InvalidInput(
902 "Input text must not be empty".to_string(),
903 ));
904 }
905 let events = self.extract(text);
906 let chains = event_coref_with_threshold(&events, self.coref_threshold);
907 Ok((events, chains))
908 }
909}
910
#[cfg(test)]
mod tests {
    //! Unit tests for lexicon lookup, event extraction and coreference.

    use super::*;

    #[test]
    fn test_trigger_lexicon_lookup() {
        let lex = TriggerLexicon::default_english();
        assert_eq!(lex.lookup("arrested"), Some(&EventType::Arrest));
        assert_eq!(lex.lookup("ARRESTED"), Some(&EventType::Arrest));
        assert_eq!(lex.lookup("died"), Some(&EventType::Die));
        assert_eq!(lex.lookup("unknown_verb"), None);
    }

    #[test]
    fn test_extract_events_arrest() {
        let lex = TriggerLexicon::default_english();
        let text = "Police arrested the suspect yesterday in New York.";
        let events = extract_events(text, &lex);
        assert!(!events.is_empty());
        let e = events.iter().find(|e| e.event_type == "Arrest");
        assert!(e.is_some(), "Expected an Arrest event");
        let ev = e.expect("already checked");
        assert!(ev.arguments.iter().any(|a| a.role == "Patient"));
    }

    #[test]
    fn test_extract_events_die_with_agent() {
        let lex = TriggerLexicon::default_english();
        let text = "The soldier died in battle last week.";
        let events = extract_events(text, &lex);
        assert!(!events.is_empty());
        let e = events.iter().find(|e| e.event_type == "Die");
        assert!(e.is_some());
    }

    #[test]
    fn test_extract_events_transfer() {
        let lex = TriggerLexicon::default_english();
        let text = "The company sold its assets to the buyer yesterday.";
        let events = extract_events(text, &lex);
        assert!(events.iter().any(|e| e.event_type == "Transfer"));
    }

    #[test]
    fn test_extract_events_multiple_sentences() {
        let lex = TriggerLexicon::default_english();
        let text = "Alice attacked the base. Bob fled to safety.";
        let events = extract_events(text, &lex);
        assert!(events.len() >= 2);
        let types: Vec<&str> = events.iter().map(|e| e.event_type.as_str()).collect();
        assert!(types.contains(&"Attack"));
        assert!(types.contains(&"Move"));
    }

    #[test]
    fn test_event_coref_same_type_and_argument() {
        let lex = TriggerLexicon::default_english();
        let text = "Alice arrested Bob on Monday. Police arrested Bob again on Tuesday.";
        let events = extract_events(text, &lex);
        assert_eq!(events.len(), 2);
        // Same type (0.5), one of two argument texts shared (0.2) and the same
        // trigger (0.1) give 0.8, which is above the default 0.6 threshold.
        let chains = event_coref(&events);
        assert_eq!(chains, vec![vec![0, 1]]);
    }

    #[test]
    fn test_event_coref_different_types() {
        let lex = TriggerLexicon::default_english();
        let text = "Alice attacked the fort. Bob fled to the hills.";
        let events = extract_events(text, &lex);
        assert!(events.len() >= 2);
        // Different types, no arguments and dissimilar triggers: nothing links.
        let chains = event_coref(&events);
        assert!(chains.is_empty());
    }

    #[test]
    fn test_extractor_builder() {
        let extractor = EventExtractor::new().with_coref_threshold(0.5);
        let (events, _chains) = extractor
            .extract_with_coref("Police arrested the suspect in London.")
            .expect("should not fail");
        assert!(!events.is_empty());
    }

    #[test]
    fn test_extractor_empty_text_error() {
        let extractor = EventExtractor::new();
        let result = extractor.extract_with_coref("");
        assert!(result.is_err());
    }

    #[test]
    fn test_custom_lexicon() {
        let mut lex = TriggerLexicon::new();
        lex.insert("deployed", EventType::Move);
        lex.insert("commissioned", EventType::Create);
        let text = "The company deployed a new service and commissioned a report.";
        let events = extract_events(text, &lex);
        assert!(events.iter().any(|e| e.event_type == "Move"));
        assert!(events.iter().any(|e| e.event_type == "Create"));
    }

    #[test]
    fn test_event_type_label() {
        assert_eq!(EventType::Move.label(), "Move");
        assert_eq!(EventType::Arrest.label(), "Arrest");
        assert_eq!(EventType::Custom("Foo".to_string()).label(), "Foo");
    }
}