1use crate::{Entity, Result};
44
45#[derive(Debug, Clone)]
47pub struct Event {
48 pub trigger: String,
50 pub trigger_start: usize,
52 pub trigger_end: usize,
54 pub event_type: String,
56 pub arguments: Vec<(String, Entity)>,
59 pub confidence: f64,
61}
62
63impl Event {
64 #[must_use]
66 pub fn new(
67 trigger: impl Into<String>,
68 trigger_start: usize,
69 trigger_end: usize,
70 event_type: impl Into<String>,
71 ) -> Self {
72 Self {
73 trigger: trigger.into(),
74 trigger_start,
75 trigger_end,
76 event_type: event_type.into(),
77 arguments: Vec::new(),
78 confidence: 1.0,
79 }
80 }
81
82 #[must_use]
84 pub fn with_argument(mut self, role: impl Into<String>, entity: Entity) -> Self {
85 self.arguments.push((role.into(), entity));
86 self
87 }
88
89 #[must_use]
91 pub fn with_confidence(mut self, confidence: f64) -> Self {
92 self.confidence = confidence.clamp(0.0, 1.0);
93 self
94 }
95}
96
/// Interface for components that extract [`Event`]s from text.
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds).
pub trait EventExtractor: Send + Sync {
    /// Extracts the events mentioned in `text`.
    ///
    /// `language` is an optional language hint; `None` leaves the choice to the
    /// implementation. NOTE(review): the rule-based extractor in this file reads
    /// it as a tag such as `"es"` or `"es-MX"` — confirm the expected format.
    ///
    /// # Errors
    ///
    /// Returns an error through the crate's `Result` type if extraction fails.
    fn extract_events(&self, text: &str, language: Option<&str>) -> Result<Vec<Event>>;

    /// Returns a static identifier for this extractor.
    fn name(&self) -> &'static str;

    /// Returns a static, human-readable description of this extractor.
    ///
    /// Defaults to `"Event extractor"`; implementors may override it.
    fn description(&self) -> &'static str {
        "Event extractor"
    }
}
123
/// Event extractor driven by fixed tables of trigger words.
///
/// The only configuration is the minimum confidence an event must reach in
/// order to be returned.
pub struct RuleBasedEventExtractor {
    /// Minimum confidence for an event to be emitted; both constructors keep
    /// finite values within `[0.0, 1.0]`.
    threshold: f64,
}

impl RuleBasedEventExtractor {
    /// Creates an extractor with a confidence threshold of `0.5`.
    #[must_use]
    pub fn new() -> Self {
        let threshold = 0.5;
        Self { threshold }
    }

    /// Creates an extractor with a custom threshold, clamped to `[0.0, 1.0]`.
    #[must_use]
    pub fn with_threshold(threshold: f64) -> Self {
        let threshold = threshold.clamp(0.0, 1.0);
        Self { threshold }
    }
}

impl Default for RuleBasedEventExtractor {
    /// Equivalent to [`RuleBasedEventExtractor::new`].
    fn default() -> Self {
        RuleBasedEventExtractor::new()
    }
}
154
155impl EventExtractor for RuleBasedEventExtractor {
156 fn extract_events(&self, text: &str, language: Option<&str>) -> Result<Vec<Event>> {
157 let mut events = Vec::new();
158
159 let lang_code = language.map(|l| l.split('-').next().unwrap_or(l).to_lowercase());
162
163 let trigger_patterns: Vec<(&str, &str)> = match lang_code.as_deref() {
165 Some("es") => vec![
166 ("invadió", "conflict:attack"),
168 ("atacó", "conflict:attack"),
169 ("bombardeó", "conflict:attack"),
170 ("guerra", "conflict:attack"),
171 ("viajó", "movement:transport"),
173 ("movió", "movement:transport"),
174 ("desplegó", "movement:transport"),
175 ("compró", "transaction:transfer-ownership"),
177 ("vendió", "transaction:transfer-ownership"),
178 ("adquirió", "transaction:transfer-ownership"),
179 ("fundó", "business:start-org"),
181 ("inició", "business:start-org"),
182 ("fusionó", "business:merge-org"),
183 ("anunció", "communication:announce"),
185 ("declaró", "communication:announce"),
186 ("informó", "communication:announce"),
187 ("dijo", "communication:announce"),
188 ("nació", "life:be-born"),
190 ("murió", "life:die"),
191 ("se casó", "life:marry"),
192 ("divorció", "life:divorce"),
193 ("arrestó", "justice:arrest-jail"),
195 ("acusó", "justice:charge-indict"),
196 ("condenó", "justice:convict"),
197 ("sentenció", "justice:sentence"),
198 ],
199 Some("fr") => vec![
200 ("envahi", "conflict:attack"),
202 ("attaqué", "conflict:attack"),
203 ("bombardé", "conflict:attack"),
204 ("guerre", "conflict:attack"),
205 ("voyagé", "movement:transport"),
207 ("déplacé", "movement:transport"),
208 ("déployé", "movement:transport"),
209 ("acheté", "transaction:transfer-ownership"),
211 ("vendu", "transaction:transfer-ownership"),
212 ("acquis", "transaction:transfer-ownership"),
213 ("fondé", "business:start-org"),
215 ("créé", "business:start-org"),
216 ("fusionné", "business:merge-org"),
217 ("annoncé", "communication:announce"),
219 ("déclaré", "communication:announce"),
220 ("rapporté", "communication:announce"),
221 ("dit", "communication:announce"),
222 ("né", "life:be-born"),
224 ("mort", "life:die"),
225 ("marié", "life:marry"),
226 ("divorcé", "life:divorce"),
227 ("arrêté", "justice:arrest-jail"),
229 ("accusé", "justice:charge-indict"),
230 ("condamné", "justice:convict"),
231 ("condamné", "justice:sentence"),
232 ],
233 Some("de") => vec![
234 ("invadiert", "conflict:attack"),
236 ("angegriffen", "conflict:attack"),
237 ("bombardiert", "conflict:attack"),
238 ("krieg", "conflict:attack"),
239 ("gereist", "movement:transport"),
241 ("bewegt", "movement:transport"),
242 ("verlegt", "movement:transport"),
243 ("gekauft", "transaction:transfer-ownership"),
245 ("verkauft", "transaction:transfer-ownership"),
246 ("erworben", "transaction:transfer-ownership"),
247 ("gegründet", "business:start-org"),
249 ("gestartet", "business:start-org"),
250 ("fusioniert", "business:merge-org"),
251 ("angekündigt", "communication:announce"),
253 ("erklärt", "communication:announce"),
254 ("berichtet", "communication:announce"),
255 ("sagte", "communication:announce"),
256 ("geboren", "life:be-born"),
258 ("gestorben", "life:die"),
259 ("geheiratet", "life:marry"),
260 ("geschieden", "life:divorce"),
261 ("verhaftet", "justice:arrest-jail"),
263 ("angeklagt", "justice:charge-indict"),
264 ("verurteilt", "justice:convict"),
265 ("verurteilt", "justice:sentence"),
266 ],
267 Some("zh") | Some("ja") | Some("ko") | Some("ar") | Some("ru") => {
268 vec![]
271 }
272 _ => vec![
273 ("invaded", "conflict:attack"),
275 ("attacked", "conflict:attack"),
276 ("bombed", "conflict:attack"),
277 ("fired", "conflict:attack"),
278 ("war", "conflict:attack"),
279 ("traveled", "movement:transport"),
280 ("moved", "movement:transport"),
281 ("transported", "movement:transport"),
282 ("deployed", "movement:transport"),
283 ("bought", "transaction:transfer-ownership"),
284 ("sold", "transaction:transfer-ownership"),
285 ("purchased", "transaction:transfer-ownership"),
286 ("acquired", "transaction:transfer-ownership"),
287 ("founded", "business:start-org"),
288 ("started", "business:start-org"),
289 ("merged", "business:merge-org"),
290 ("bankruptcy", "business:declare-bankruptcy"),
291 ("announced", "communication:announce"),
292 ("stated", "communication:announce"),
293 ("reported", "communication:announce"),
294 ("said", "communication:announce"),
295 ("born", "life:be-born"),
296 ("died", "life:die"),
297 ("married", "life:marry"),
298 ("divorced", "life:divorce"),
299 ("arrested", "justice:arrest-jail"),
300 ("charged", "justice:charge-indict"),
301 ("convicted", "justice:convict"),
302 ("sentenced", "justice:sentence"),
303 ],
304 };
305
306 let text_lower = text.to_lowercase();
309 for (trigger_word, event_type) in trigger_patterns {
310 let search_text = if lang_code
312 .as_deref()
313 .is_some_and(|l| matches!(l, "zh" | "ja" | "ko" | "ar" | "ru"))
314 {
315 text } else {
317 &text_lower
318 };
319
320 if let Some(pos) = search_text.find(trigger_word) {
321 let char_start = text
323 .char_indices()
324 .nth(pos)
325 .map(|(i, _)| i)
326 .unwrap_or(text.len());
327 let char_end = text
328 .char_indices()
329 .nth(pos + trigger_word.chars().count())
330 .map(|(i, _)| i)
331 .unwrap_or(text.len());
332
333 let trigger_text: String = text
335 .chars()
336 .skip(pos)
337 .take(trigger_word.chars().count())
338 .collect();
339
340 let event = Event::new(trigger_text, char_start, char_end, event_type.to_string())
341 .with_confidence(0.7); if event.confidence >= self.threshold {
344 events.push(event);
345 }
346 }
347 }
348
349 Ok(events)
350 }
351
352 fn name(&self) -> &'static str {
353 "rule-based-event"
354 }
355
356 fn description(&self) -> &'static str {
357 "Rule-based event extraction using trigger word patterns"
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// English text with no language hint yields an attack event for "invaded".
    #[test]
    fn test_rule_based_event_extraction() {
        let found = RuleBasedEventExtractor::new()
            .extract_events("Russia invaded Ukraine in 2022.", None)
            .unwrap();

        assert!(!found.is_empty());
        let has_invasion = found.iter().any(|event| {
            event.event_type == "conflict:attack" && event.trigger.to_lowercase() == "invaded"
        });
        assert!(has_invasion);
    }

    /// A freshly constructed event carries no arguments.
    #[test]
    fn test_event_with_arguments() {
        let fresh = Event::new("invaded", 7, 14, "conflict:attack");
        assert_eq!(fresh.arguments.len(), 0);
    }

    /// Offsets of any extracted event must be ordered and lie within the
    /// character length of the text.
    ///
    /// Note: the rule-based extractor has no Japanese triggers, so the loop
    /// body may never run for this input.
    #[test]
    fn test_event_unicode_offsets() {
        let text = "ロシアがウクライナを侵攻した。";
        let found = RuleBasedEventExtractor::new()
            .extract_events(text, Some("ja"))
            .unwrap();

        for event in found.iter() {
            assert!(event.trigger_start <= event.trigger_end);
            assert!(event.trigger_end <= text.chars().count());
        }
    }

    /// Spanish, French and German inputs each produce at least one event.
    #[test]
    fn test_multilingual_event_extraction() {
        let extractor = RuleBasedEventExtractor::new();

        let cases = [
            (
                "Rusia invadió Ucrania en 2022.",
                "es",
                "Should extract Spanish events",
            ),
            (
                "La Russie a envahi l'Ukraine en 2022.",
                "fr",
                "Should extract French events",
            ),
            (
                "Russland hat die Ukraine 2022 angegriffen.",
                "de",
                "Should extract German events",
            ),
        ];

        for (text, language, message) in cases {
            let found = extractor.extract_events(text, Some(language)).unwrap();
            assert!(!found.is_empty(), "{}", message);
        }
    }
}