Skip to main content

scirs2_text/
event_extraction.rs

//! Event extraction from natural language text.
//!
//! This module provides lexicon-based event detection, argument extraction,
//! and event coreference resolution.  All processing is purely rule-based
//! and does not require trained weights.
//!
//! # Overview
//!
//! - [`EventType`] — enumeration of recognised event categories
//! - [`TriggerLexicon`] — mapping from action words to [`EventType`]
//! - [`Argument`] — an event participant with its semantic role
//! - [`Event`] — a single extracted event instance
//! - [`extract_events`] — top-level extraction entry point
//! - [`event_coref`] — group events that refer to the same occurrence
//! - [`EventExtractor`] — convenience builder bundling a lexicon with a
//!   coreference threshold
//!
//! # Example
//!
//! ```rust
//! use scirs2_text::event_extraction::{TriggerLexicon, extract_events};
//!
//! let lex = TriggerLexicon::default_english();
//! let text = "Police arrested the suspect yesterday in New York.";
//! let events = extract_events(text, &lex);
//! assert!(!events.is_empty());
//! assert_eq!(events[0].event_type, "Arrest");
//! ```
27
28use crate::error::{Result, TextError};
29use std::collections::HashMap;
30
// ---------------------------------------------------------------------------
// EventType
// ---------------------------------------------------------------------------

/// Coarse-grained category assigned to an extracted event.
///
/// The eight built-in variants cover common news-style event classes;
/// [`EventType::Custom`] carries any other, caller-defined category.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum EventType {
    /// Movement of a person, vehicle, or object.
    Move,
    /// Violent attack, assault, or strike.
    Attack,
    /// Meeting, gathering, or conference.
    Meet,
    /// Arrest, detention, or apprehension.
    Arrest,
    /// Death, killing, or fatality.
    Die,
    /// Transfer of money, ownership, or responsibility.
    Transfer,
    /// Creation, production, or manufacturing.
    Create,
    /// Destruction, demolition, or elimination.
    Destroy,
    /// User-defined category; the payload doubles as its label.
    Custom(String),
}

impl EventType {
    /// Returns the display label: the variant name for built-in categories,
    /// or the wrapped string (unchanged) for [`EventType::Custom`].
    pub fn label(&self) -> &str {
        match self {
            Self::Custom(name) => name.as_str(),
            Self::Move => "Move",
            Self::Attack => "Attack",
            Self::Meet => "Meet",
            Self::Arrest => "Arrest",
            Self::Die => "Die",
            Self::Transfer => "Transfer",
            Self::Create => "Create",
            Self::Destroy => "Destroy",
        }
    }
}

/// Formats the category as its [`EventType::label`].
impl std::fmt::Display for EventType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.label())
    }
}
80
81// ---------------------------------------------------------------------------
82// TriggerLexicon
83// ---------------------------------------------------------------------------
84
85/// Maps action words (triggers) to their event categories.
86///
87/// Trigger lookup is case-insensitive.  Each trigger word must appear as a
88/// standalone token boundary in the text.
89pub struct TriggerLexicon {
90    /// Lower-cased trigger word → event type.
91    pub triggers: HashMap<String, EventType>,
92}
93
94impl Default for TriggerLexicon {
95    fn default() -> Self {
96        Self::default_english()
97    }
98}
99
100impl TriggerLexicon {
101    /// Create an empty lexicon.
102    pub fn new() -> Self {
103        Self {
104            triggers: HashMap::new(),
105        }
106    }
107
108    /// Register a single trigger.
109    pub fn insert(&mut self, word: impl Into<String>, event_type: EventType) {
110        self.triggers.insert(word.into().to_lowercase(), event_type);
111    }
112
113    /// Look up an event type for a word (case-insensitive).
114    pub fn lookup(&self, word: &str) -> Option<&EventType> {
115        self.triggers.get(&word.to_lowercase())
116    }
117
118    /// Build a lexicon populated with common English triggers.
119    pub fn default_english() -> Self {
120        let mut lex = Self::new();
121
122        // Move
123        for w in &[
124            "moved",
125            "moving",
126            "move",
127            "traveled",
128            "travel",
129            "travelled",
130            "fled",
131            "flee",
132            "departed",
133            "depart",
134            "arrived",
135            "arrive",
136            "entered",
137            "enter",
138            "left",
139            "evacuated",
140            "evacuate",
141            "migrated",
142            "migrate",
143            "relocated",
144            "relocate",
145            "walked",
146            "ran",
147            "run",
148        ] {
149            lex.insert(*w, EventType::Move);
150        }
151
152        // Attack
153        for w in &[
154            "attacked",
155            "attack",
156            "assaulted",
157            "assault",
158            "bombed",
159            "bomb",
160            "shot",
161            "shoot",
162            "fired",
163            "fire",
164            "struck",
165            "strike",
166            "hit",
167            "targeted",
168            "target",
169            "raided",
170            "raid",
171            "invaded",
172            "invade",
173            "detonated",
174            "detonate",
175            "launched",
176            "launch",
177            "stabbed",
178            "stab",
179        ] {
180            lex.insert(*w, EventType::Attack);
181        }
182
183        // Meet
184        for w in &[
185            "met",
186            "meet",
187            "meeting",
188            "gathered",
189            "gather",
190            "assembled",
191            "assemble",
192            "convened",
193            "convene",
194            "discussed",
195            "discuss",
196            "negotiated",
197            "negotiate",
198            "talked",
199            "talk",
200            "conferenced",
201            "conferred",
202            "confer",
203            "visited",
204            "visit",
205        ] {
206            lex.insert(*w, EventType::Meet);
207        }
208
209        // Arrest
210        for w in &[
211            "arrested",
212            "arrest",
213            "detained",
214            "detain",
215            "apprehended",
216            "apprehend",
217            "captured",
218            "capture",
219            "jailed",
220            "jail",
221            "imprisoned",
222            "imprison",
223            "charged",
224            "charge",
225            "indicted",
226            "indict",
227            "booked",
228            "book",
229            "handcuffed",
230            "handcuff",
231        ] {
232            lex.insert(*w, EventType::Arrest);
233        }
234
235        // Die
236        for w in &[
237            "died",
238            "die",
239            "killed",
240            "kill",
241            "murdered",
242            "murder",
243            "executed",
244            "execute",
245            "slain",
246            "slayed",
247            "slay",
248            "perished",
249            "perish",
250            "deceased",
251            "assassinated",
252            "assassinate",
253            "fatally",
254        ] {
255            lex.insert(*w, EventType::Die);
256        }
257
258        // Transfer
259        for w in &[
260            "transferred",
261            "transfer",
262            "sold",
263            "sell",
264            "purchased",
265            "purchase",
266            "bought",
267            "buy",
268            "donated",
269            "donate",
270            "paid",
271            "pay",
272            "sent",
273            "send",
274            "received",
275            "receive",
276            "wired",
277            "wire",
278            "awarded",
279            "award",
280            "granted",
281            "grant",
282        ] {
283            lex.insert(*w, EventType::Transfer);
284        }
285
286        // Create
287        for w in &[
288            "created",
289            "create",
290            "built",
291            "build",
292            "developed",
293            "develop",
294            "founded",
295            "found",
296            "established",
297            "establish",
298            "launched",
299            "produced",
300            "produce",
301            "manufactured",
302            "manufacture",
303            "invented",
304            "invent",
305            "designed",
306            "design",
307            "wrote",
308            "write",
309            "authored",
310            "author",
311            "published",
312            "publish",
313            "formed",
314            "form",
315        ] {
316            lex.insert(*w, EventType::Create);
317        }
318
319        // Destroy
320        for w in &[
321            "destroyed",
322            "destroy",
323            "demolished",
324            "demolish",
325            "burned",
326            "burn",
327            "razed",
328            "raze",
329            "collapsed",
330            "collapse",
331            "ruined",
332            "ruin",
333            "dismantled",
334            "dismantle",
335            "obliterated",
336            "obliterate",
337            "wrecked",
338            "wreck",
339            "shattered",
340            "shatter",
341        ] {
342            lex.insert(*w, EventType::Destroy);
343        }
344
345        lex
346    }
347}
348
// ---------------------------------------------------------------------------
// Argument
// ---------------------------------------------------------------------------

/// A participant or modifier in an event.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Argument {
    /// Semantic role label (e.g. `"Agent"`, `"Patient"`, `"Location"`, `"Time"`).
    pub role: String,
    /// Surface text of the argument.
    pub text: String,
    /// Byte span `(start, end)` in the source document, end exclusive.
    ///
    /// These are byte offsets, not character counts, so they can be used to
    /// slice the source `&str` directly.
    pub span: (usize, usize),
}

// ---------------------------------------------------------------------------
// Event
// ---------------------------------------------------------------------------

/// A single event instance extracted from text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Event {
    /// The trigger word that anchored this event, as written in the text.
    pub trigger: String,
    /// Byte span of the trigger `(start, end)` in the source document, end
    /// exclusive.
    pub trigger_span: (usize, usize),
    /// Human-readable event category label (e.g. `"Arrest"`).
    pub event_type: String,
    /// Extracted arguments (participants, location, time, etc.).
    pub arguments: Vec<Argument>,
}
380
// ---------------------------------------------------------------------------
// Tokenisation helpers
// ---------------------------------------------------------------------------

/// Returns `true` for characters that may appear inside a token: any
/// alphanumeric character, plus the apostrophe and hyphen so that words such
/// as `don't` or `well-known` stay whole.
fn is_word_char(c: char) -> bool {
    matches!(c, '\'' | '-') || c.is_alphanumeric()
}
389
390/// Tokenise `text` into `(start, end, surface)` triples (byte offsets).
391fn tokenise(text: &str) -> Vec<(usize, usize, String)> {
392    let mut tokens: Vec<(usize, usize, String)> = Vec::new();
393    let mut start: Option<usize> = None;
394    for (i, c) in text.char_indices() {
395        if is_word_char(c) {
396            if start.is_none() {
397                start = Some(i);
398            }
399        } else if let Some(s) = start.take() {
400            tokens.push((s, i, text[s..i].to_string()));
401        }
402    }
403    if let Some(s) = start {
404        tokens.push((s, text.len(), text[s..].to_string()));
405    }
406    tokens
407}
408
// ---------------------------------------------------------------------------
// Sentence splitter
// ---------------------------------------------------------------------------

/// Split `text` into sentences, returning `(offset, sentence)` pairs.
///
/// A sentence ends after `.`, `!`, or `?` together with any spaces or
/// newlines that immediately follow.  Whatever remains after the last
/// terminator forms a final sentence.  Each sentence is trimmed, and
/// `offset` is the byte position in `text` where the *trimmed* sentence
/// starts, so `text[offset..].starts_with(sentence)` always holds.
///
/// Splitting is purely punctuation-based: abbreviations such as `"Mr."` and
/// decimals such as `"3.5"` also end a sentence.
fn sentences(text: &str) -> Vec<(usize, &str)> {
    let bytes = text.as_bytes();
    let mut result = Vec::new();
    let mut start = 0usize;
    while start < bytes.len() {
        // Each boundary sits right after an ASCII byte or at `text.len()`,
        // so slicing there is always on a char boundary.
        let end = match bytes[start..]
            .iter()
            .position(|&b| matches!(b, b'.' | b'!' | b'?'))
        {
            Some(rel) => {
                let mut e = start + rel + 1;
                while e < bytes.len() && matches!(bytes[e], b' ' | b'\n') {
                    e += 1;
                }
                e
            }
            None => bytes.len(),
        };

        let chunk = &text[start..end];
        let trimmed = chunk.trim();
        if !trimmed.is_empty() {
            // Report the offset of the trimmed text, not of the raw chunk.
            // Leading whitespace (at the very start of `text`, or a `\t`/`\r`
            // after a terminator) would otherwise shift every derived span.
            let lead = chunk.len() - chunk.trim_start().len();
            result.push((start + lead, trimmed));
        }
        start = end;
    }
    result
}
440
// ---------------------------------------------------------------------------
// Named-entity detection heuristics (no external model)
// ---------------------------------------------------------------------------

/// Detect spans that look like person, organisation, or place names: maximal
/// runs of consecutive capitalised tokens.
///
/// `tokens` are `(start, end, surface)` triples relative to the sentence, and
/// `sent_start_abs` is the sentence's absolute byte offset.  The returned
/// spans are absolute.
///
/// A token at sentence-relative offset 0 never *starts* a span, because
/// sentence-initial capitalisation carries no signal.  As a consequence, a
/// sentence-initial name such as `"Alice"` is not detected.
///
/// The surface text joins the run's tokens with single spaces, so
/// punctuation between them (e.g. `"Paris, France"`) is not reflected in the
/// surface, even though the span covers it.
fn detect_np_spans(
    tokens: &[(usize, usize, String)],
    sent_start_abs: usize,
) -> Vec<(usize, usize, String)> {
    fn is_capitalised(word: &str) -> bool {
        word.starts_with(|c: char| c.is_uppercase())
    }

    let mut spans: Vec<(usize, usize, String)> = Vec::new();
    let mut i = 0usize;
    while i < tokens.len() {
        let (tok_s, _, word) = &tokens[i];
        if *tok_s == 0 || !is_capitalised(word) {
            i += 1;
            continue;
        }
        // Token `i` itself qualifies, so the run always has at least one token.
        let run_len = tokens[i..]
            .iter()
            .take_while(|(_, _, w)| is_capitalised(w))
            .count();
        let run = &tokens[i..i + run_len];
        let surface = run
            .iter()
            .map(|(_, _, w)| w.as_str())
            .collect::<Vec<_>>()
            .join(" ");
        spans.push((
            sent_start_abs + run[0].0,
            sent_start_abs + run[run_len - 1].1,
            surface,
        ));
        i += run_len;
    }
    spans
}
483
/// Detect temporal expressions in a tokenised sentence.
///
/// Recognised patterns (case-insensitive unless noted):
///
/// - absolute adverbs: `yesterday`, `today`, `tomorrow`, `now`, `recently`
/// - `last/next/this/coming/previous` followed by a day name, month name, or
///   time unit (e.g. `"last week"`)
/// - `<digits> <unit>`, optionally followed by `ago` (e.g. `"3 days ago"`)
/// - a standalone day or month name
/// - a 4-digit year between 1000 and 2099
///
/// The ambiguous month words `may`, `march`, and `mar` only count as months
/// when capitalised, so the modal verb in "police may arrest" or
/// "this may fail" is not tagged as a date.
///
/// `tokens` are sentence-relative `(start, end, surface)` triples, and
/// `sent_start_abs` is the sentence's absolute byte offset.  The returned
/// spans are absolute.
fn detect_time_spans(
    tokens: &[(usize, usize, String)],
    sent_start_abs: usize,
) -> Vec<(usize, usize, String)> {
    const DAYS: &[&str] = &[
        "monday",
        "tuesday",
        "wednesday",
        "thursday",
        "friday",
        "saturday",
        "sunday",
    ];
    const MONTHS: &[&str] = &[
        "january",
        "february",
        "march",
        "april",
        "may",
        "june",
        "july",
        "august",
        "september",
        "october",
        "november",
        "december",
        "jan",
        "feb",
        "mar",
        "apr",
        "jun",
        "jul",
        "aug",
        "sep",
        "oct",
        "nov",
        "dec",
    ];
    // Month words that are far more common as ordinary English words; they
    // only count as months when written with a capital letter.
    const AMBIGUOUS_MONTHS: &[&str] = &["may", "march", "mar"];
    const ABSOLUTE_TEMPS: &[&str] = &["yesterday", "today", "tomorrow", "now", "recently"];
    const REL_ANCHORS: &[&str] = &["last", "next", "this", "coming", "previous"];
    const UNITS: &[&str] = &[
        "second", "seconds", "minute", "minutes", "hour", "hours", "day", "days", "week", "weeks",
        "month", "months", "year", "years",
    ];

    // `surface` is the token as written; `lower` is its lower-cased form.
    let is_month = |surface: &str, lower: &str| {
        MONTHS.contains(&lower)
            && (!AMBIGUOUS_MONTHS.contains(&lower)
                || surface.starts_with(|c: char| c.is_uppercase()))
    };

    let mut spans: Vec<(usize, usize, String)> = Vec::new();
    let mut i = 0usize;
    while i < tokens.len() {
        let (tok_s, tok_e, word) = &tokens[i];
        let abs_s = sent_start_abs + tok_s;
        let abs_e = sent_start_abs + tok_e;
        let lower = word.to_lowercase();

        // Absolute temporal adverbs
        if ABSOLUTE_TEMPS.contains(&lower.as_str()) {
            spans.push((abs_s, abs_e, word.clone()));
            i += 1;
            continue;
        }

        // "last/next/this <day|month|unit>"
        if REL_ANCHORS.contains(&lower.as_str()) {
            if let Some((_, next_e, next_word)) = tokens.get(i + 1) {
                let next_lower = next_word.to_lowercase();
                if DAYS.contains(&next_lower.as_str())
                    || is_month(next_word.as_str(), next_lower.as_str())
                    || UNITS.contains(&next_lower.as_str())
                {
                    spans.push((
                        abs_s,
                        sent_start_abs + next_e,
                        format!("{} {}", word, next_word),
                    ));
                    i += 2;
                    continue;
                }
            }
        }

        // "<N> <unit>", optionally followed by "ago"
        if lower.chars().all(|c| c.is_ascii_digit()) {
            if let Some((_, unit_e, unit)) = tokens.get(i + 1) {
                if UNITS.contains(&unit.to_lowercase().as_str()) {
                    let mut span_e = sent_start_abs + unit_e;
                    let mut surface = format!("{} {}", word, unit);
                    i += 2;
                    if let Some((_, ago_e, ago)) = tokens.get(i) {
                        if ago.to_lowercase() == "ago" {
                            span_e = sent_start_abs + ago_e;
                            surface.push_str(" ago");
                            i += 1;
                        }
                    }
                    spans.push((abs_s, span_e, surface));
                    continue;
                }
            }
        }

        // Day / month standalone
        if DAYS.contains(&lower.as_str()) || is_month(word.as_str(), lower.as_str()) {
            spans.push((abs_s, abs_e, word.clone()));
            i += 1;
            continue;
        }

        // 4-digit year in 1000–2099; larger 4-digit numbers (e.g. "2500
        // people") are more likely counts than dates.
        if lower.len() == 4
            && lower.chars().all(|c| c.is_ascii_digit())
            && matches!(lower.parse::<u16>(), Ok(1000..=2099))
        {
            spans.push((abs_s, abs_e, word.clone()));
            i += 1;
            continue;
        }

        i += 1;
    }
    spans
}
601
/// Detect locations: a preposition (`in`, `at`, `from`, `to`, `near`,
/// `around`, `through`; matched case-insensitively) followed by a
/// capitalised phrase.
///
/// If an entry of `np_spans` starts at the token right after the preposition,
/// that whole noun phrase is returned.  Otherwise the next token is accepted
/// on its own when it is capitalised.  `tokens` are sentence-relative, while
/// `np_spans` and the returned spans are absolute (offset by
/// `sent_start_abs`).
fn detect_location_spans(
    tokens: &[(usize, usize, String)],
    sent_start_abs: usize,
    np_spans: &[(usize, usize, String)],
) -> Vec<(usize, usize, String)> {
    const LOC_PREPS: &[&str] = &["in", "at", "from", "to", "near", "around", "through"];

    tokens
        .windows(2)
        .filter(|pair| LOC_PREPS.contains(&pair[0].2.to_lowercase().as_str()))
        .filter_map(|pair| {
            let (next_s, next_e, next_word) = &pair[1];
            let abs_s = sent_start_abs + next_s;
            match np_spans.iter().find(|(ns, _, _)| *ns == abs_s) {
                Some(np) => Some(np.clone()),
                None if next_word.starts_with(|c: char| c.is_uppercase()) => {
                    Some((abs_s, sent_start_abs + next_e, next_word.clone()))
                }
                None => None,
            }
        })
        .collect()
}
630
631// ---------------------------------------------------------------------------
632// Core extraction logic
633// ---------------------------------------------------------------------------
634
635/// Extract events from `text` using the supplied trigger lexicon.
636///
637/// For each sentence we identify trigger tokens and then apply heuristic
638/// dependency patterns to fill `Agent`, `Patient`, `Location`, and `Time`
639/// argument roles:
640///
641/// - **Agent**: the first NP before the trigger (subject position)
642/// - **Patient**: the first NP after the trigger (object position)
643/// - **Location**: any `in/at/from/to + NP` within ±3 tokens of the trigger
644/// - **Time**: any temporal expression within ±5 tokens of the trigger
645pub fn extract_events(text: &str, triggers: &TriggerLexicon) -> Vec<Event> {
646    let mut events: Vec<Event> = Vec::new();
647
648    for (sent_off, sent_text) in sentences(text) {
649        let tokens = tokenise(sent_text);
650        if tokens.is_empty() {
651            continue;
652        }
653
654        let np_spans = detect_np_spans(&tokens, sent_off);
655        let time_spans = detect_time_spans(&tokens, sent_off);
656        let loc_spans = detect_location_spans(&tokens, sent_off, &np_spans);
657
658        for (tok_idx, (tok_s, tok_e, word)) in tokens.iter().enumerate() {
659            let abs_trig_s = sent_off + tok_s;
660            let abs_trig_e = sent_off + tok_e;
661
662            let etype = match triggers.lookup(word) {
663                Some(et) => et,
664                None => continue,
665            };
666
667            let mut args: Vec<Argument> = Vec::new();
668
669            // Agent: closest NP whose end ≤ trigger start
670            let agent = np_spans
671                .iter()
672                .filter(|(_, ne, _)| *ne <= abs_trig_s)
673                .max_by_key(|(ns, _, _)| *ns);
674            if let Some((ns, ne, surf)) = agent {
675                args.push(Argument {
676                    role: "Agent".to_string(),
677                    text: surf.clone(),
678                    span: (*ns, *ne),
679                });
680            }
681
682            // Patient: closest NP whose start ≥ trigger end
683            let patient = np_spans
684                .iter()
685                .filter(|(ns, _, _)| *ns >= abs_trig_e)
686                .min_by_key(|(ns, _, _)| *ns);
687            if let Some((ns, ne, surf)) = patient {
688                args.push(Argument {
689                    role: "Patient".to_string(),
690                    text: surf.clone(),
691                    span: (*ns, *ne),
692                });
693            }
694
695            // Location: within ±5 token window
696            let window_start = tok_idx.saturating_sub(5);
697            let window_end = (tok_idx + 6).min(tokens.len());
698            let window_abs_s = sent_off + tokens[window_start].0;
699            let window_abs_e = sent_off + tokens[window_end - 1].1;
700
701            for (ls, le, lsurf) in &loc_spans {
702                if *ls >= window_abs_s && *le <= window_abs_e {
703                    args.push(Argument {
704                        role: "Location".to_string(),
705                        text: lsurf.clone(),
706                        span: (*ls, *le),
707                    });
708                }
709            }
710
711            // Time: within ±6 token window (slightly wider)
712            let twin_start = tok_idx.saturating_sub(6);
713            let twin_end = (tok_idx + 7).min(tokens.len());
714            let twin_abs_s = sent_off + tokens[twin_start].0;
715            let twin_abs_e = sent_off + tokens[twin_end - 1].1;
716
717            for (ts, te, tsurf) in &time_spans {
718                if *ts >= twin_abs_s && *te <= twin_abs_e {
719                    args.push(Argument {
720                        role: "Time".to_string(),
721                        text: tsurf.clone(),
722                        span: (*ts, *te),
723                    });
724                }
725            }
726
727            events.push(Event {
728                trigger: word.clone(),
729                trigger_span: (abs_trig_s, abs_trig_e),
730                event_type: etype.label().to_string(),
731                arguments: args,
732            });
733        }
734    }
735
736    events
737}
738
739// ---------------------------------------------------------------------------
740// Event coreference
741// ---------------------------------------------------------------------------
742
743/// Compute a simple similarity score between two events for coreference
744/// clustering purposes.
745///
746/// Factors:
747/// - Same event type (0.5)
748/// - Overlapping argument text (up to 0.4 weighted by number of shared args)
749/// - Trigger edit distance ≤ 2 (0.1)
750fn event_similarity(e1: &Event, e2: &Event) -> f64 {
751    let mut score = 0.0f64;
752
753    // Same event type
754    if e1.event_type == e2.event_type {
755        score += 0.5;
756    }
757
758    // Shared argument texts
759    let texts1: std::collections::HashSet<String> =
760        e1.arguments.iter().map(|a| a.text.to_lowercase()).collect();
761    let texts2: std::collections::HashSet<String> =
762        e2.arguments.iter().map(|a| a.text.to_lowercase()).collect();
763
764    let shared = texts1.intersection(&texts2).count();
765    let total = texts1.len().max(texts2.len());
766    if total > 0 {
767        score += 0.4 * (shared as f64 / total as f64);
768    }
769
770    // Trigger similarity (simple character-level)
771    let t1 = e1.trigger.to_lowercase();
772    let t2 = e2.trigger.to_lowercase();
773    if t1 == t2 || levenshtein(t1.as_bytes(), t2.as_bytes()) <= 2 {
774        score += 0.1;
775    }
776
777    score
778}
779
/// Levenshtein (edit) distance between two byte strings, counting single-byte
/// insertions, deletions, and substitutions.
///
/// Keeps only one rolling row of the DP table, so extra memory is
/// `O(b.len())`.
fn levenshtein(a: &[u8], b: &[u8]) -> usize {
    if a.is_empty() || b.is_empty() {
        return a.len().max(b.len());
    }
    let mut row: Vec<usize> = (0..=b.len()).collect();
    for (i, &ca) in a.iter().enumerate() {
        // `diag` holds the previous row's value at column j (the diagonal).
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, &cb) in b.iter().enumerate() {
            let above = row[j + 1];
            row[j + 1] = if ca == cb {
                diag
            } else {
                1 + diag.min(above).min(row[j])
            };
            diag = above;
        }
    }
    row[b.len()]
}
806
807/// Group events into coreference chains using single-linkage clustering.
808///
809/// Two events are placed in the same chain when their similarity score
810/// (see `event_similarity`) exceeds `threshold` (default 0.6).
811///
812/// Returns a `Vec<Vec<usize>>` where each inner vector holds the indices of
813/// coreferent events (index into the input `events` slice).
814pub fn event_coref(events: &[Event]) -> Vec<Vec<usize>> {
815    event_coref_with_threshold(events, 0.6)
816}
817
818/// Like [`event_coref`] but with a configurable similarity threshold.
819pub fn event_coref_with_threshold(events: &[Event], threshold: f64) -> Vec<Vec<usize>> {
820    let n = events.len();
821    // Union-Find
822    let mut parent: Vec<usize> = (0..n).collect();
823
824    fn find(parent: &mut Vec<usize>, x: usize) -> usize {
825        if parent[x] != x {
826            parent[x] = find(parent, parent[x]);
827        }
828        parent[x]
829    }
830
831    for i in 0..n {
832        for j in (i + 1)..n {
833            if event_similarity(&events[i], &events[j]) >= threshold {
834                let ri = find(&mut parent, i);
835                let rj = find(&mut parent, j);
836                if ri != rj {
837                    parent[rj] = ri;
838                }
839            }
840        }
841    }
842
843    // Collect chains
844    let mut chains: HashMap<usize, Vec<usize>> = HashMap::new();
845    for i in 0..n {
846        let root = find(&mut parent, i);
847        chains.entry(root).or_default().push(i);
848    }
849
850    // Keep only chains with > 1 member (actual coreference)
851    let mut result: Vec<Vec<usize>> = chains.into_values().filter(|v| v.len() >= 2).collect();
852    result.sort_by_key(|v| v[0]);
853    result
854}
855
856// ---------------------------------------------------------------------------
857// Builder / convenience
858// ---------------------------------------------------------------------------
859
860/// High-level interface for configurable event extraction.
861pub struct EventExtractor {
862    lexicon: TriggerLexicon,
863    coref_threshold: f64,
864}
865
866impl Default for EventExtractor {
867    fn default() -> Self {
868        Self::new()
869    }
870}
871
872impl EventExtractor {
873    /// Create an extractor with the default English trigger lexicon.
874    pub fn new() -> Self {
875        Self {
876            lexicon: TriggerLexicon::default_english(),
877            coref_threshold: 0.6,
878        }
879    }
880
881    /// Replace the trigger lexicon.
882    pub fn with_lexicon(mut self, lexicon: TriggerLexicon) -> Self {
883        self.lexicon = lexicon;
884        self
885    }
886
887    /// Set the similarity threshold used for event coreference clustering.
888    pub fn with_coref_threshold(mut self, threshold: f64) -> Self {
889        self.coref_threshold = threshold;
890        self
891    }
892
893    /// Extract events from `text`.
894    pub fn extract(&self, text: &str) -> Vec<Event> {
895        extract_events(text, &self.lexicon)
896    }
897
898    /// Extract events and return coreference chains alongside.
899    pub fn extract_with_coref(&self, text: &str) -> Result<(Vec<Event>, Vec<Vec<usize>>)> {
900        if text.is_empty() {
901            return Err(TextError::InvalidInput(
902                "Input text must not be empty".to_string(),
903            ));
904        }
905        let events = self.extract(text);
906        let chains = event_coref_with_threshold(&events, self.coref_threshold);
907        Ok((events, chains))
908    }
909}
910
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trigger_lexicon_lookup() {
        let lex = TriggerLexicon::default_english();
        assert_eq!(lex.lookup("arrested"), Some(&EventType::Arrest));
        assert_eq!(lex.lookup("ARRESTED"), Some(&EventType::Arrest));
        assert_eq!(lex.lookup("died"), Some(&EventType::Die));
        assert_eq!(lex.lookup("unknown_verb"), None);
    }

    #[test]
    fn test_extract_events_arrest() {
        let lex = TriggerLexicon::default_english();
        let text = "Police arrested the suspect yesterday in New York.";
        let events = extract_events(text, &lex);
        let ev = events
            .iter()
            .find(|e| e.event_type == "Arrest")
            .expect("Expected an Arrest event");
        assert_eq!(ev.trigger, "arrested");
        assert_eq!(ev.trigger_span, (7, 15));
        assert_eq!(&text[ev.trigger_span.0..ev.trigger_span.1], "arrested");
        // Patient argument expected (first noun phrase after the trigger)
        assert!(ev.arguments.iter().any(|a| a.role == "Patient"));
        assert!(ev
            .arguments
            .iter()
            .any(|a| a.role == "Time" && a.text == "yesterday"));
    }

    #[test]
    fn test_extract_events_die_with_time() {
        let lex = TriggerLexicon::default_english();
        let text = "The soldier died in battle last week.";
        let events = extract_events(text, &lex);
        let ev = events
            .iter()
            .find(|e| e.event_type == "Die")
            .expect("Expected a Die event");
        assert!(ev
            .arguments
            .iter()
            .any(|a| a.role == "Time" && a.text == "last week"));
    }

    #[test]
    fn test_extract_events_transfer() {
        let lex = TriggerLexicon::default_english();
        let text = "The company sold its assets to the buyer yesterday.";
        let events = extract_events(text, &lex);
        assert!(events.iter().any(|e| e.event_type == "Transfer"));
    }

    #[test]
    fn test_extract_events_multiple_sentences() {
        let lex = TriggerLexicon::default_english();
        let text = "Alice attacked the base. Bob fled to safety.";
        let events = extract_events(text, &lex);
        assert!(events.len() >= 2);
        let types: Vec<&str> = events.iter().map(|e| e.event_type.as_str()).collect();
        assert!(types.contains(&"Attack"));
        assert!(types.contains(&"Move"));
    }

    #[test]
    fn test_event_coref_same_type_and_argument() {
        let lex = TriggerLexicon::default_english();
        let text = "Alice arrested Bob on Monday. Police arrested Bob again on Tuesday.";
        let events = extract_events(text, &lex);
        assert_eq!(events.len(), 2);
        // Same type, same trigger, shared "Bob" argument: one chain of both.
        let chains = event_coref(&events);
        assert_eq!(chains, vec![vec![0, 1]]);
    }

    #[test]
    fn test_event_coref_different_types() {
        let lex = TriggerLexicon::default_english();
        let text = "Alice attacked the fort. Bob fled to the hills.";
        let events = extract_events(text, &lex);
        assert!(events.len() >= 2);
        // Attack and Move should not be coreferent.  Chains only ever contain
        // two or more members, so "no coreference" means no chains at all.
        let chains = event_coref(&events);
        assert!(chains.is_empty());
    }

    #[test]
    fn test_extractor_builder() {
        let extractor = EventExtractor::new().with_coref_threshold(0.5);
        let (events, _chains) = extractor
            .extract_with_coref("Police arrested the suspect in London.")
            .expect("should not fail");
        assert!(!events.is_empty());
    }

    #[test]
    fn test_extractor_empty_text_error() {
        let extractor = EventExtractor::new();
        let result = extractor.extract_with_coref("");
        assert!(result.is_err());
    }

    #[test]
    fn test_custom_lexicon() {
        let mut lex = TriggerLexicon::new();
        lex.insert("deployed", EventType::Move);
        lex.insert("commissioned", EventType::Create);
        let text = "The company deployed a new service and commissioned a report.";
        let events = extract_events(text, &lex);
        assert!(events.iter().any(|e| e.event_type == "Move"));
        assert!(events.iter().any(|e| e.event_type == "Create"));
    }

    #[test]
    fn test_event_type_label() {
        assert_eq!(EventType::Move.label(), "Move");
        assert_eq!(EventType::Arrest.label(), "Arrest");
        assert_eq!(EventType::Custom("Foo".to_string()).label(), "Foo");
    }
}