/// A single subject–predicate–object statement extracted from text.
#[derive(Debug, Clone, PartialEq)]
pub struct TextTriple {
    /// The token immediately preceding the matched predicate words, with
    /// surrounding non-alphanumeric characters stripped.
    pub subject: String,
    /// The matched pattern's predicate words joined by spaces, optionally passed
    /// through `TripleExtractor::normalize_predicate`.
    pub predicate: String,
    /// The token immediately following the matched predicate words, with
    /// surrounding non-alphanumeric characters stripped.
    pub object: String,
    /// Score assigned by `TripleExtractor::confidence_for_pattern`; that function
    /// clamps its result to `0.0..=1.0`.
    pub confidence: f64,
    /// `(start, end)` byte offsets of the matched text; `end` is exclusive.
    /// Offsets are relative to the text passed to `extract`, or to the sentence
    /// itself when produced by `extract_sentence`.
    pub source_span: (usize, usize),
}
39
/// A surface pattern of the form `<subject> <predicate words…> <object>`.
#[derive(Debug, Clone)]
pub struct ExtractionPattern {
    /// Identifier for the pattern (e.g. `"works_at"`).
    pub name: String,
    /// Label for the subject slot.
    /// NOTE(review): not read by the extraction logic visible in this file.
    pub subject_token: String,
    /// Consecutive words that must appear (compared case-insensitively) between
    /// the subject and the object.
    pub predicate_words: Vec<String>,
    /// Label for the object slot.
    /// NOTE(review): not read by the extraction logic visible in this file.
    pub object_token: String,
}
52
53impl ExtractionPattern {
54 pub fn new(
56 name: impl Into<String>,
57 subject_token: impl Into<String>,
58 predicate_words: Vec<String>,
59 object_token: impl Into<String>,
60 ) -> Self {
61 Self {
62 name: name.into(),
63 subject_token: subject_token.into(),
64 predicate_words,
65 object_token: object_token.into(),
66 }
67 }
68}
69
/// Tunable settings for `TripleExtractor`.
#[derive(Debug, Clone)]
pub struct ExtractionConfig {
    /// Triples whose confidence is below this value are discarded.
    pub min_confidence: f64,
    /// Extraction for a sentence stops once this many triples have been collected.
    pub max_triples_per_sentence: usize,
    /// When `true`, predicates are passed through
    /// `TripleExtractor::normalize_predicate` before being stored.
    pub normalize_predicates: bool,
}
80
81impl Default for ExtractionConfig {
82 fn default() -> Self {
83 Self {
84 min_confidence: 0.3,
85 max_triples_per_sentence: 10,
86 normalize_predicates: true,
87 }
88 }
89}
90
91pub struct TripleExtractor {
93 patterns: Vec<ExtractionPattern>,
94 config: ExtractionConfig,
95}
96
97impl TripleExtractor {
98 pub fn new(config: ExtractionConfig) -> Self {
100 Self {
101 patterns: Vec::new(),
102 config,
103 }
104 }
105
106 pub fn with_defaults(config: ExtractionConfig) -> Self {
111 let mut extractor = Self::new(config);
112
113 let defaults: &[(&str, &[&str])] = &[
114 ("is_a", &["is", "a"]),
115 ("is_an", &["is", "an"]),
116 ("is", &["is"]),
117 ("has", &["has"]),
118 ("works_at", &["works", "at"]),
119 ("located_in", &["located", "in"]),
120 ("founded_by", &["founded", "by"]),
121 ("created_by", &["created", "by"]),
122 ("known_as", &["known", "as"]),
123 ("born_in", &["born", "in"]),
124 ("part_of", &["part", "of"]),
125 ];
126
127 for (name, words) in defaults {
128 extractor.patterns.push(ExtractionPattern::new(
129 *name,
130 "subject",
131 words.iter().map(|w| w.to_string()).collect(),
132 "object",
133 ));
134 }
135
136 extractor
137 }
138
139 pub fn add_pattern(&mut self, pattern: ExtractionPattern) {
141 self.patterns.push(pattern);
142 }
143
144 pub fn pattern_count(&self) -> usize {
146 self.patterns.len()
147 }
148
149 pub fn extract(&self, text: &str) -> Vec<TextTriple> {
151 if text.trim().is_empty() {
152 return Vec::new();
153 }
154
155 let mut results = Vec::new();
156 let mut offset = 0usize;
157
158 for sentence in text.split_terminator(['.', '!', '?']) {
159 let trimmed = sentence.trim();
160 if !trimmed.is_empty() {
161 let sentence_start = text[offset..]
163 .find(trimmed)
164 .map(|pos| offset + pos)
165 .unwrap_or(offset);
166
167 let triples = self.extract_sentence_with_offset(trimmed, sentence_start);
168 results.extend(triples);
169 offset = sentence_start + trimmed.len();
170 }
171 }
172
173 results
174 }
175
176 pub fn extract_sentence(&self, sentence: &str) -> Vec<TextTriple> {
178 self.extract_sentence_with_offset(sentence, 0)
179 }
180
181 fn extract_sentence_with_offset(&self, sentence: &str, base_offset: usize) -> Vec<TextTriple> {
184 let words: Vec<&str> = sentence.split_whitespace().collect();
185 if words.len() < 3 {
186 return Vec::new();
187 }
188
189 let mut results = Vec::new();
190
191 'pattern_loop: for pattern in &self.patterns {
192 let pw = &pattern.predicate_words;
193
194 if pw.is_empty() || words.len() < pw.len() + 2 {
195 continue;
196 }
197
198 for start in 1..words.len().saturating_sub(pw.len()) {
200 let window_end = start + pw.len();
202 if window_end >= words.len() {
203 continue;
204 }
205
206 let matches = words[start..window_end]
207 .iter()
208 .zip(pw.iter())
209 .all(|(w, p)| {
210 w.to_lowercase().trim_matches(|c: char| !c.is_alphabetic())
211 == p.to_lowercase()
212 });
213
214 if !matches {
215 continue;
216 }
217
218 let subject = sanitise_token(words[start - 1]);
219 let object_idx = window_end;
220 if object_idx >= words.len() {
221 continue;
222 }
223 let object = sanitise_token(words[object_idx]);
224
225 if subject.is_empty() || object.is_empty() {
226 continue;
227 }
228
229 let raw_predicate = pw.join(" ");
230 let predicate = if self.config.normalize_predicates {
231 Self::normalize_predicate(&raw_predicate)
232 } else {
233 raw_predicate
234 };
235
236 let confidence = Self::confidence_for_pattern(pw.len(), pw.len().max(1));
237
238 if confidence < self.config.min_confidence {
239 continue;
240 }
241
242 let span_start = sentence.find(words[start - 1]).unwrap_or(0);
244 let span_end = sentence
245 .rfind(words[object_idx])
246 .map(|p| p + words[object_idx].len())
247 .unwrap_or(sentence.len());
248
249 results.push(TextTriple {
250 subject: subject.to_string(),
251 predicate,
252 object: object.to_string(),
253 confidence,
254 source_span: (base_offset + span_start, base_offset + span_end),
255 });
256
257 if results.len() >= self.config.max_triples_per_sentence {
258 break 'pattern_loop;
259 }
260
261 break; }
263 }
264
265 results
266 }
267
268 pub fn normalize_predicate(predicate: &str) -> String {
271 const STOP_WORDS: &[&str] = &["a", "an", "the", "of", "by", "at", "in", "as"];
272 let lower = predicate.to_lowercase();
273 let parts: Vec<&str> = lower
274 .split_whitespace()
275 .filter(|w| !STOP_WORDS.contains(w))
276 .collect();
277 if parts.is_empty() {
278 lower.trim().to_string()
279 } else {
280 parts.join("_")
281 }
282 }
283
284 pub fn confidence_for_pattern(matched_words: usize, total_pattern_words: usize) -> f64 {
288 if total_pattern_words == 0 {
289 return 0.0;
290 }
291 (matched_words as f64 / total_pattern_words as f64).clamp(0.0, 1.0)
292 }
293
294 pub fn to_knowledge_graph(triples: &[TextTriple]) -> Vec<(String, String, String)> {
296 triples
297 .iter()
298 .map(|t| (t.subject.clone(), t.predicate.clone(), t.object.clone()))
299 .collect()
300 }
301}
302
/// Strips leading and trailing non-alphanumeric characters from `token`.
///
/// Characters in the interior are left untouched. A token containing no
/// alphanumeric character at all yields the empty string.
fn sanitise_token(token: &str) -> &str {
    token.trim_matches(|ch: char| !ch.is_alphanumeric())
}
318
#[cfg(test)]
mod tests {
    //! Unit tests for the triple extractor: sentence-level extraction, pattern
    //! management, confidence scoring, predicate normalisation, knowledge-graph
    //! conversion and source-span bookkeeping.

    use super::*;

    /// Extractor with the built-in patterns and the default configuration.
    fn default_extractor() -> TripleExtractor {
        TripleExtractor::with_defaults(ExtractionConfig::default())
    }

    // ── Basic extraction ────────────────────────────────────────────────────

    #[test]
    fn test_extract_simple_is_sentence() {
        let extractor = default_extractor();
        let triples = extractor.extract("Alice is an engineer.");
        assert!(!triples.is_empty(), "expected at least one triple");
        let t = &triples[0];
        assert_eq!(t.subject, "Alice");
        assert_eq!(t.object, "engineer");
    }

    #[test]
    fn test_extract_has_relation() {
        let extractor = default_extractor();
        let triples = extractor.extract("Bob has a degree.");
        assert!(!triples.is_empty());
        assert_eq!(triples[0].subject, "Bob");
        assert!(!triples[0].object.is_empty());
    }

    #[test]
    fn test_extract_works_at() {
        let extractor = default_extractor();
        let triples = extractor.extract("Carol works at Google.");
        assert!(!triples.is_empty());
        let t = &triples[0];
        assert_eq!(t.subject, "Carol");
        assert_eq!(t.object, "Google");
    }

    #[test]
    fn test_extract_located_in() {
        let extractor = default_extractor();
        let triples = extractor.extract("Paris is located in France.");
        let located: Vec<_> = triples
            .iter()
            .filter(|t| t.predicate.contains("located"))
            .collect();
        assert!(!located.is_empty(), "expected located_in triple");
    }

    #[test]
    fn test_extract_empty_text() {
        let extractor = default_extractor();
        let triples = extractor.extract("");
        assert!(triples.is_empty());
    }

    #[test]
    fn test_extract_whitespace_only() {
        let extractor = default_extractor();
        let triples = extractor.extract(" ");
        assert!(triples.is_empty());
    }

    #[test]
    fn test_extract_sentence_direct() {
        let extractor = default_extractor();
        let triples = extractor.extract_sentence("Dave is a scientist");
        assert!(!triples.is_empty());
        assert_eq!(triples[0].subject, "Dave");
    }

    #[test]
    fn test_extract_multiple_sentences() {
        let extractor = default_extractor();
        let text = "Alice is an engineer. Bob has a job.";
        let triples = extractor.extract(text);
        assert!(triples.len() >= 2, "got {} triples", triples.len());
    }

    // ── Pattern management ──────────────────────────────────────────────────

    #[test]
    fn test_with_defaults_has_patterns() {
        let extractor = default_extractor();
        assert!(extractor.pattern_count() > 0);
    }

    #[test]
    fn test_new_extractor_no_patterns() {
        let extractor = TripleExtractor::new(ExtractionConfig::default());
        assert_eq!(extractor.pattern_count(), 0);
    }

    #[test]
    fn test_add_pattern_increases_count() {
        let mut extractor = TripleExtractor::new(ExtractionConfig::default());
        let initial = extractor.pattern_count();
        extractor.add_pattern(ExtractionPattern::new(
            "likes",
            "subject",
            vec!["likes".to_string()],
            "object",
        ));
        assert_eq!(extractor.pattern_count(), initial + 1);
    }

    #[test]
    fn test_custom_pattern_extraction() {
        let mut extractor = TripleExtractor::new(ExtractionConfig {
            min_confidence: 0.0,
            normalize_predicates: false,
            max_triples_per_sentence: 10,
        });
        extractor.add_pattern(ExtractionPattern::new(
            "likes",
            "S",
            vec!["likes".to_string()],
            "O",
        ));
        let triples = extractor.extract_sentence("Alice likes cats");
        assert!(!triples.is_empty());
        assert_eq!(triples[0].subject, "Alice");
        assert_eq!(triples[0].object, "cats");
    }

    // ── Confidence scoring ──────────────────────────────────────────────────

    #[test]
    fn test_confidence_for_pattern_full_match() {
        let c = TripleExtractor::confidence_for_pattern(3, 3);
        assert!((c - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_confidence_for_pattern_half_match() {
        let c = TripleExtractor::confidence_for_pattern(1, 2);
        assert!((c - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_confidence_for_pattern_zero_words() {
        let c = TripleExtractor::confidence_for_pattern(0, 0);
        assert_eq!(c, 0.0);
    }

    #[test]
    fn test_confidence_clamped_to_one() {
        let c = TripleExtractor::confidence_for_pattern(10, 5);
        assert!((c - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_triple_confidence_is_positive() {
        let extractor = default_extractor();
        let triples = extractor.extract("Alice is a coder.");
        // Guard against the loop below passing vacuously.
        assert!(!triples.is_empty(), "expected at least one triple");
        for t in &triples {
            assert!(t.confidence > 0.0, "confidence should be positive");
        }
    }

    // ── min_confidence filtering ────────────────────────────────────────────

    #[test]
    fn test_min_confidence_filter_excludes_low() {
        // A threshold above the maximum possible confidence rejects everything.
        let mut extractor = TripleExtractor::new(ExtractionConfig {
            min_confidence: 2.0,
            max_triples_per_sentence: 10,
            normalize_predicates: true,
        });
        extractor.add_pattern(ExtractionPattern::new(
            "is",
            "S",
            vec!["is".to_string()],
            "O",
        ));
        let triples = extractor.extract_sentence("Alice is Bob");
        assert!(triples.is_empty());
    }

    #[test]
    fn test_min_confidence_zero_allows_all() {
        let config = ExtractionConfig {
            min_confidence: 0.0,
            max_triples_per_sentence: 10,
            normalize_predicates: true,
        };
        let extractor = TripleExtractor::with_defaults(config);
        let triples = extractor.extract_sentence("Alice is Bob");
        assert!(!triples.is_empty());
    }

    // ── max_triples_per_sentence ────────────────────────────────────────────

    #[test]
    fn test_max_triples_per_sentence_limits_output() {
        let config = ExtractionConfig {
            min_confidence: 0.0,
            max_triples_per_sentence: 1,
            normalize_predicates: true,
        };
        let extractor = TripleExtractor::with_defaults(config);
        let triples = extractor.extract_sentence("Alice is Bob");
        assert!(triples.len() <= 1);
    }

    // ── Predicate normalisation ─────────────────────────────────────────────

    #[test]
    fn test_normalize_predicate_lowercase() {
        let norm = TripleExtractor::normalize_predicate("IS");
        assert_eq!(norm, "is");
    }

    #[test]
    fn test_normalize_predicate_removes_articles() {
        let norm = TripleExtractor::normalize_predicate("is a");
        assert!(!norm.contains(" a"), "got: {}", norm);
        assert!(norm.contains("is"));
    }

    #[test]
    fn test_normalize_predicate_removes_stopwords() {
        let norm = TripleExtractor::normalize_predicate("born in");
        assert!(!norm.ends_with("_in"), "got: {}", norm);
        assert!(norm.contains("born"));
    }

    #[test]
    fn test_normalize_predicate_empty() {
        let norm = TripleExtractor::normalize_predicate("");
        assert_eq!(norm, "");
    }

    #[test]
    fn test_normalize_predicate_single_word() {
        let norm = TripleExtractor::normalize_predicate("Has");
        assert_eq!(norm, "has");
    }

    // ── Knowledge-graph conversion ──────────────────────────────────────────

    #[test]
    fn test_to_knowledge_graph_format() {
        let triples = vec![TextTriple {
            subject: "Alice".to_string(),
            predicate: "knows".to_string(),
            object: "Bob".to_string(),
            confidence: 0.9,
            source_span: (0, 10),
        }];
        let kg = TripleExtractor::to_knowledge_graph(&triples);
        assert_eq!(kg.len(), 1);
        assert_eq!(kg[0].0, "Alice");
        assert_eq!(kg[0].1, "knows");
        assert_eq!(kg[0].2, "Bob");
    }

    #[test]
    fn test_to_knowledge_graph_empty() {
        let kg = TripleExtractor::to_knowledge_graph(&[]);
        assert!(kg.is_empty());
    }

    #[test]
    fn test_to_knowledge_graph_multiple() {
        let triples = vec![
            TextTriple {
                subject: "A".to_string(),
                predicate: "p".to_string(),
                object: "B".to_string(),
                confidence: 1.0,
                source_span: (0, 5),
            },
            TextTriple {
                subject: "C".to_string(),
                predicate: "q".to_string(),
                object: "D".to_string(),
                confidence: 0.8,
                source_span: (6, 11),
            },
        ];
        let kg = TripleExtractor::to_knowledge_graph(&triples);
        assert_eq!(kg.len(), 2);
    }

    // ── Source spans ────────────────────────────────────────────────────────

    #[test]
    fn test_source_span_non_zero_for_offset() {
        let extractor = default_extractor();
        let text = "First sentence. Alice is a tester.";
        let triples = extractor.extract(text);
        // The triple must exist; a missing triple is a failure, not a skip.
        let t = triples
            .iter()
            .find(|t| t.object == "tester")
            .expect("expected a triple whose object is `tester`");
        assert!(
            t.source_span.0 > 0,
            "span start should reflect sentence offset"
        );
        assert_eq!(&text[t.source_span.0..t.source_span.1], "Alice is a tester");
    }

    #[test]
    fn test_source_span_end_geq_start() {
        let extractor = default_extractor();
        let triples = extractor.extract("Alice is a developer.");
        // Guard against the loop below passing vacuously.
        assert!(!triples.is_empty(), "expected at least one triple");
        for t in &triples {
            assert!(t.source_span.1 >= t.source_span.0);
        }
    }

    // ── normalize_predicates flag ───────────────────────────────────────────

    #[test]
    fn test_normalize_predicates_false_preserves_case() {
        let config = ExtractionConfig {
            min_confidence: 0.0,
            max_triples_per_sentence: 10,
            normalize_predicates: false,
        };
        let mut extractor = TripleExtractor::new(config);
        extractor.add_pattern(ExtractionPattern::new(
            "IS",
            "S",
            vec!["IS".to_string()],
            "O",
        ));
        let triples = extractor.extract_sentence("Alice IS Bob");
        assert!(!triples.is_empty(), "pattern should match case-insensitively");
        // The predicate is taken from the pattern words, so their case survives.
        assert_eq!(triples[0].predicate, "IS");
    }

    #[test]
    fn test_normalize_predicates_true_lowercases() {
        let config = ExtractionConfig {
            min_confidence: 0.0,
            max_triples_per_sentence: 10,
            normalize_predicates: true,
        };
        let mut extractor = TripleExtractor::new(config);
        extractor.add_pattern(ExtractionPattern::new(
            "has",
            "S",
            vec!["has".to_string()],
            "O",
        ));
        let triples = extractor.extract_sentence("Alice has job");
        assert!(!triples.is_empty(), "expected the `has` pattern to match");
        assert_eq!(triples[0].predicate, "has");
    }

    // ── Edge cases and miscellany ───────────────────────────────────────────

    #[test]
    fn test_extract_sentence_too_short() {
        let extractor = default_extractor();
        let triples = extractor.extract_sentence("Hello");
        assert!(triples.is_empty());
    }

    #[test]
    fn test_extraction_pattern_new() {
        let p = ExtractionPattern::new("test", "S", vec!["relates".to_string()], "O");
        assert_eq!(p.name, "test");
        assert_eq!(p.predicate_words, vec!["relates"]);
    }

    #[test]
    fn test_extraction_config_default() {
        let cfg = ExtractionConfig::default();
        assert!(cfg.min_confidence >= 0.0);
        assert!(cfg.max_triples_per_sentence > 0);
    }

    #[test]
    fn test_extract_multiple_sentences_second_sentence() {
        let extractor = default_extractor();
        let text = "X is Y. Bob is a manager.";
        let triples = extractor.extract(text);
        // The second sentence must contribute its own triple.
        assert!(triples
            .iter()
            .any(|t| t.subject == "Bob" && t.object == "manager"));
    }

    #[test]
    fn test_confidence_for_pattern_larger_match() {
        let c = TripleExtractor::confidence_for_pattern(4, 4);
        assert!((c - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_confidence_for_pattern_zero_matched() {
        let c = TripleExtractor::confidence_for_pattern(0, 5);
        assert_eq!(c, 0.0);
    }

    #[test]
    fn test_add_multiple_patterns() {
        let mut extractor = TripleExtractor::new(ExtractionConfig::default());
        for i in 0..5 {
            extractor.add_pattern(ExtractionPattern::new(
                format!("p{}", i),
                "S",
                vec![format!("verb{}", i)],
                "O",
            ));
        }
        assert_eq!(extractor.pattern_count(), 5);
    }

    #[test]
    fn test_text_triple_fields() {
        let t = TextTriple {
            subject: "Alice".to_string(),
            predicate: "knows".to_string(),
            object: "Bob".to_string(),
            confidence: 0.75,
            source_span: (5, 20),
        };
        assert_eq!(t.subject, "Alice");
        assert_eq!(t.predicate, "knows");
        assert_eq!(t.object, "Bob");
        assert!((t.confidence - 0.75).abs() < 1e-10);
        assert_eq!(t.source_span, (5, 20));
    }

    #[test]
    fn test_normalize_predicate_all_stop_words() {
        // When every word is a stop word the lowercased input is returned as-is
        // rather than collapsing to an empty predicate.
        let norm = TripleExtractor::normalize_predicate("a the");
        assert_eq!(norm, "a the");
    }

    #[test]
    fn test_extract_exclamation_sentence() {
        let extractor = default_extractor();
        let triples = extractor.extract("Alice is great! Bob is better.");
        assert!(!triples.is_empty());
    }

    #[test]
    fn test_extract_question_mark_sentence() {
        let extractor = default_extractor();
        let triples = extractor.extract("Alice is here? Bob is there.");
        assert!(!triples.is_empty());
    }

    #[test]
    fn test_extraction_pattern_object_token() {
        let p = ExtractionPattern::new("test", "SUBJ", vec!["verb".to_string()], "OBJ");
        assert_eq!(p.object_token, "OBJ");
        assert_eq!(p.subject_token, "SUBJ");
    }

    #[test]
    fn test_to_knowledge_graph_preserves_confidence_order() {
        let triples = vec![
            TextTriple {
                subject: "A".to_string(),
                predicate: "p".to_string(),
                object: "B".to_string(),
                confidence: 0.9,
                source_span: (0, 5),
            },
            TextTriple {
                subject: "C".to_string(),
                predicate: "q".to_string(),
                object: "D".to_string(),
                confidence: 0.5,
                source_span: (6, 11),
            },
        ];
        let kg = TripleExtractor::to_knowledge_graph(&triples);
        assert_eq!(kg[0].0, "A");
        assert_eq!(kg[1].0, "C");
    }
}