1use super::{EdgeType, GraphEdge, GraphError, GraphNode, NodeType};
6use crate::RragResult;
7use async_trait::async_trait;
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11
/// A single entity mention found in a piece of source text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// The mention as it appears in the text (the rule-based extractor stores the trimmed match).
    pub text: String,

    /// Category assigned to the mention.
    pub entity_type: EntityType,

    /// Start offset of the mention in the source text (the rule-based extractor stores the
    /// regex match start, i.e. a byte offset, before trimming).
    pub start_pos: usize,

    /// End offset (exclusive) of the mention in the source text; same convention as `start_pos`.
    pub end_pos: usize,

    /// Extraction confidence; the rule-based extractor computes `base * type priority`,
    /// capped at 1.0.
    pub confidence: f32,

    /// Canonical form used as the entity's identity when building relationships and graph
    /// nodes (the rule-based extractor collapses runs of whitespace to single spaces).
    pub normalized_form: Option<String>,

    /// Free-form extra attributes; copied onto the graph node by `entities_to_nodes`.
    pub attributes: HashMap<String, serde_json::Value>,

    /// Identifier of the document the mention came from, as passed to the extractor.
    pub source_id: String,
}
39
/// A directed relationship between two entities found in the same source text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relationship {
    /// Identity of the source entity: its `normalized_form`, or its `text` when that is absent.
    /// `relationships_to_edges` looks this key up in its entity-to-node map.
    pub source_entity: String,

    /// Identity of the target entity, using the same convention as `source_entity`.
    pub target_entity: String,

    /// Kind of relationship.
    pub relation_type: RelationType,

    /// Text span or description explaining why the relationship was inferred.
    pub context: String,

    /// Start offset of the evidence in the source text.
    pub start_pos: usize,

    /// End offset (exclusive) of the evidence in the source text.
    pub end_pos: usize,

    /// Confidence of the inferred relationship.
    pub confidence: f32,

    /// Free-form extra attributes (co-occurrence relationships set `distance` and `type`).
    pub attributes: HashMap<String, serde_json::Value>,

    /// Identifier of the document the relationship came from, as passed to the extractor.
    pub source_id: String,
}
70
/// Category of an extracted entity.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum EntityType {
    /// People; the rule-based extractor matches pairs of capitalized words and titled names
    /// (Mr/Mrs/Dr/Prof/Ms).
    Person,

    /// Companies and institutions, recognized by suffixes such as Inc, Corp or University.
    Organization,

    /// Places: "City, ST" forms and a fixed list of US city names.
    Location,

    /// Dates written as "Month D, YYYY", "M/D/YYYY" or ISO "YYYY-MM-DD".
    DateTime,

    /// Monetary amounts ("$1,000.00", "5 dollars", "20 USD", "5 cents").
    Money,

    /// Percentages ("12%", "12.5 percent").
    Percentage,

    /// Acronyms and well-known technology terms (only when technical extraction is enabled).
    Technical,

    /// Abstract concepts; no built-in pattern in this file produces it. Mapped to
    /// `NodeType::Concept` by `entities_to_nodes`.
    Concept,

    /// Products; no built-in pattern in this file produces it.
    Product,

    /// Events; no built-in pattern in this file produces it.
    Event,

    /// User-defined type, named after the key of the matching entry in
    /// `EntityExtractionConfig::custom_patterns`.
    Custom(String),
}
107
/// Kind of relationship between two entities. `Display` renders the snake_case names used
/// as graph edge labels.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum RelationType {
    /// "X is a Y" (built-in "is a" / "is an" / "are the" phrasing).
    IsA,

    /// "X is part of Y" / "X belongs to Y".
    PartOf,

    /// "X is located in Y" / "X in Y".
    LocatedIn,

    /// "X works for/at Y" / "X was employed by Y".
    WorksFor,

    /// "X owns Y" / "X has Y" / "X possesses Y".
    Owns,

    /// Causal relationship; no built-in pattern in this file produces it.
    Causes,

    /// Similarity; no built-in pattern produces it. Mapped to `EdgeType::Similar`.
    SimilarTo,

    /// Temporal anchoring; no built-in pattern in this file produces it.
    HappenedOn,

    /// Proximity-based co-occurrence. Mapped to `EdgeType::CoOccurs`.
    MentionedWith,

    /// User-defined relation; its `Display` form is the contained name.
    Custom(String),
}
141
/// Tuning knobs for `RuleBasedEntityExtractor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityExtractionConfig {
    /// Entities (and co-occurrence relationships) scoring below this are discarded.
    pub min_confidence: f32,

    /// Maximum length of a matched entity's trimmed text; longer matches are skipped.
    pub max_entity_length: usize,

    /// When true, the `Technical` acronym and technology-term patterns are compiled.
    pub extract_technical_terms: bool,

    /// NOTE(review): not read by any code in this file — confirm whether it is used elsewhere.
    pub extract_concepts: bool,

    /// Extra patterns keyed by name; each match becomes an `EntityType::Custom(name)` entity.
    /// Skipped by serde because `Regex` is not serializable, so it is empty after deserializing.
    #[serde(skip)]
    pub custom_patterns: HashMap<String, Regex>,

    /// Words ignored when deciding whether a match is meaningful. Matches consisting only of
    /// these words are dropped. Words are lowercased before lookup, so entries should be lowercase.
    pub stop_words: HashSet<String>,

    /// Per-type multiplier applied to the base confidence. Types without an entry use 0.5.
    pub entity_priorities: HashMap<EntityType, f32>,
}
167
168impl Default for EntityExtractionConfig {
169 fn default() -> Self {
170 let mut stop_words = HashSet::new();
171 stop_words.extend(
172 vec![
173 "the",
174 "a",
175 "an",
176 "and",
177 "or",
178 "but",
179 "in",
180 "on",
181 "at",
182 "to",
183 "for",
184 "of",
185 "with",
186 "by",
187 "from",
188 "up",
189 "about",
190 "into",
191 "through",
192 "during",
193 "before",
194 "after",
195 "above",
196 "below",
197 "between",
198 "among",
199 "this",
200 "that",
201 "these",
202 "those",
203 "i",
204 "me",
205 "my",
206 "myself",
207 "we",
208 "our",
209 "ours",
210 "ourselves",
211 "you",
212 "your",
213 "yours",
214 "yourself",
215 "yourselves",
216 "he",
217 "him",
218 "his",
219 "himself",
220 "she",
221 "her",
222 "hers",
223 "herself",
224 "it",
225 "its",
226 "itself",
227 "they",
228 "them",
229 "their",
230 "theirs",
231 "themselves",
232 "what",
233 "which",
234 "who",
235 "whom",
236 "whose",
237 "this",
238 "that",
239 "these",
240 "those",
241 "am",
242 "is",
243 "are",
244 "was",
245 "were",
246 "be",
247 "been",
248 "being",
249 "have",
250 "has",
251 "had",
252 "having",
253 "do",
254 "does",
255 "did",
256 "doing",
257 "would",
258 "should",
259 "could",
260 "can",
261 "may",
262 "might",
263 "must",
264 "shall",
265 "will",
266 "would",
267 ]
268 .into_iter()
269 .map(|s| s.to_string()),
270 );
271
272 let mut entity_priorities = HashMap::new();
273 entity_priorities.insert(EntityType::Person, 0.9);
274 entity_priorities.insert(EntityType::Organization, 0.8);
275 entity_priorities.insert(EntityType::Location, 0.8);
276 entity_priorities.insert(EntityType::DateTime, 0.7);
277 entity_priorities.insert(EntityType::Technical, 0.6);
278 entity_priorities.insert(EntityType::Concept, 0.5);
279
280 Self {
281 min_confidence: 0.5,
282 max_entity_length: 100,
283 extract_technical_terms: true,
284 extract_concepts: true,
285 custom_patterns: HashMap::new(),
286 stop_words,
287 entity_priorities,
288 }
289 }
290}
291
292#[async_trait]
294pub trait EntityExtractor: Send + Sync {
295 async fn extract_entities(&self, text: &str, source_id: &str) -> RragResult<Vec<Entity>>;
297
298 async fn extract_relationships(
300 &self,
301 text: &str,
302 entities: &[Entity],
303 source_id: &str,
304 ) -> RragResult<Vec<Relationship>>;
305
306 async fn extract_all(
308 &self,
309 text: &str,
310 source_id: &str,
311 ) -> RragResult<(Vec<Entity>, Vec<Relationship>)> {
312 let entities = self.extract_entities(text, source_id).await?;
313 let relationships = self
314 .extract_relationships(text, &entities, source_id)
315 .await?;
316 Ok((entities, relationships))
317 }
318}
319
/// Regex-driven `EntityExtractor` that needs no model. Patterns are compiled once in `new`.
pub struct RuleBasedEntityExtractor {
    /// Thresholds, stop words, priorities and custom patterns.
    config: EntityExtractionConfig,

    /// Compiled entity patterns grouped by the type they produce.
    patterns: HashMap<EntityType, Vec<Regex>>,

    /// Compiled relationship patterns grouped by the relation they produce.
    relationship_patterns: HashMap<RelationType, Vec<Regex>>,
}
331
332impl RuleBasedEntityExtractor {
333 pub fn new(config: EntityExtractionConfig) -> RragResult<Self> {
335 let patterns = Self::compile_entity_patterns(&config)?;
336 let relationship_patterns = Self::compile_relationship_patterns()?;
337
338 Ok(Self {
339 config,
340 patterns,
341 relationship_patterns,
342 })
343 }
344
345 fn compile_entity_patterns(
347 config: &EntityExtractionConfig,
348 ) -> RragResult<HashMap<EntityType, Vec<Regex>>> {
349 let mut patterns = HashMap::new();
350
351 let person_patterns = vec![
353 Regex::new(r"\b[A-Z][a-z]+\s+[A-Z][a-z]+\b").map_err(|e| {
354 GraphError::EntityExtraction {
355 message: format!("Failed to compile person pattern: {}", e),
356 }
357 })?,
358 Regex::new(r"\b(?:Mr|Mrs|Dr|Prof|Ms)\.?\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b").map_err(
359 |e| GraphError::EntityExtraction {
360 message: format!("Failed to compile person title pattern: {}", e),
361 },
362 )?,
363 ];
364 patterns.insert(EntityType::Person, person_patterns);
365
366 let org_patterns = vec![
368 Regex::new(r"\b[A-Z][a-zA-Z]*\s+(?:Inc|Corp|Company|Ltd|LLC|Organization|Institute|University|College|School)\b").map_err(|e| {
369 GraphError::EntityExtraction {
370 message: format!("Failed to compile organization pattern: {}", e)
371 }
372 })?,
373 Regex::new(r"\b(?:the\s+)?[A-Z][a-zA-Z\s]+(?:Corporation|Foundation|Association|Agency|Department)\b").map_err(|e| {
374 GraphError::EntityExtraction {
375 message: format!("Failed to compile organization pattern 2: {}", e)
376 }
377 })?,
378 ];
379 patterns.insert(EntityType::Organization, org_patterns);
380
381 let location_patterns = vec![
383 Regex::new(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*,\s*[A-Z]{2}\b").map_err(|e| {
384 GraphError::EntityExtraction {
385 message: format!("Failed to compile location pattern: {}", e)
386 }
387 })?,
388 Regex::new(r"\b(?:New York|Los Angeles|Chicago|Houston|Phoenix|Philadelphia|San Antonio|San Diego|Dallas|San Jose|Austin|Jacksonville|Fort Worth|Columbus|Charlotte|San Francisco|Indianapolis|Seattle|Denver|Washington|Boston|El Paso|Detroit|Nashville|Portland|Memphis|Oklahoma City|Las Vegas|Louisville|Baltimore|Milwaukee|Albuquerque|Tucson|Fresno|Sacramento|Mesa|Kansas City|Atlanta|Long Beach|Colorado Springs|Raleigh|Miami|Virginia Beach|Omaha|Oakland|Minneapolis|Tulsa|Arlington|Tampa|New Orleans|Wichita|Cleveland|Bakersfield|Aurora|Anaheim|Honolulu|Santa Ana|Riverside|Corpus Christi|Lexington|Stockton|Henderson|Saint Paul|St. Paul|Cincinnati|St. Louis|Pittsburgh|Greensboro|Lincoln|Plano|Anchorage|Durham|Jersey City|Chula Vista|Orlando|Chandler|Henderson|Laredo|Buffalo|North Las Vegas|Madison|Lubbock|Reno|Akron|Hialeah|Garland|Rochester|Modesto|Montgomery|Yonkers|Spokane|Tacoma|Shreveport|Des Moines|Fremont|Baton Rouge|Richmond|Birmingham|Chesapeake|Glendale|Irving|Scottsdale|North Hempstead|Fayetteville|Grand Rapids|Santa Clarita|Salt Lake City|Huntsville)\b").map_err(|e| {
389 GraphError::EntityExtraction {
390 message: format!("Failed to compile major cities pattern: {}", e)
391 }
392 })?,
393 ];
394 patterns.insert(EntityType::Location, location_patterns);
395
396 let datetime_patterns = vec![
398 Regex::new(r"\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},?\s+\d{4}\b").map_err(|e| {
399 GraphError::EntityExtraction {
400 message: format!("Failed to compile date pattern: {}", e)
401 }
402 })?,
403 Regex::new(r"\b\d{1,2}/\d{1,2}/\d{4}\b").map_err(|e| {
404 GraphError::EntityExtraction {
405 message: format!("Failed to compile date pattern 2: {}", e)
406 }
407 })?,
408 Regex::new(r"\b\d{4}-\d{2}-\d{2}\b").map_err(|e| {
409 GraphError::EntityExtraction {
410 message: format!("Failed to compile ISO date pattern: {}", e)
411 }
412 })?,
413 ];
414 patterns.insert(EntityType::DateTime, datetime_patterns);
415
416 let money_patterns = vec![
418 Regex::new(r"\$\d+(?:,\d{3})*(?:\.\d{2})?\b").map_err(|e| {
419 GraphError::EntityExtraction {
420 message: format!("Failed to compile money pattern: {}", e),
421 }
422 })?,
423 Regex::new(r"\b\d+(?:,\d{3})*(?:\.\d{2})?\s*(?:dollars?|USD|cents?)\b").map_err(
424 |e| GraphError::EntityExtraction {
425 message: format!("Failed to compile money pattern 2: {}", e),
426 },
427 )?,
428 ];
429 patterns.insert(EntityType::Money, money_patterns);
430
431 let percentage_patterns = vec![
433 Regex::new(r"\b\d+(?:\.\d+)?%\b").map_err(|e| GraphError::EntityExtraction {
434 message: format!("Failed to compile percentage pattern: {}", e),
435 })?,
436 Regex::new(r"\b\d+(?:\.\d+)?\s*percent\b").map_err(|e| {
437 GraphError::EntityExtraction {
438 message: format!("Failed to compile percentage pattern 2: {}", e),
439 }
440 })?,
441 ];
442 patterns.insert(EntityType::Percentage, percentage_patterns);
443
444 if config.extract_technical_terms {
446 let technical_patterns = vec![
447 Regex::new(r"\b[A-Z]{2,}(?:\s+[A-Z]{2,})*\b").map_err(|e| {
448 GraphError::EntityExtraction {
449 message: format!("Failed to compile technical acronym pattern: {}", e)
450 }
451 })?,
452 Regex::new(r"\b(?:API|SDK|HTTP|HTTPS|JSON|XML|SQL|NoSQL|REST|GraphQL|OAuth|JWT|SSL|TLS|CSS|HTML|JavaScript|TypeScript|Python|Java|Rust|Go|C\+\+|PHP|Ruby|Swift|Kotlin|React|Vue|Angular|Docker|Kubernetes|AWS|GCP|Azure|MongoDB|PostgreSQL|MySQL|Redis|Elasticsearch|TensorFlow|PyTorch|OpenAI|GPT|BERT|Transformer|Neural Network|Machine Learning|Deep Learning|AI|ML|DL|NLP|Computer Vision|Data Science|Big Data|Cloud Computing|DevOps|CI/CD|Git|GitHub|GitLab|Bitbucket|Jenkins|Travis CI|CircleCI|Terraform|Ansible|Chef|Puppet)\b").map_err(|e| {
453 GraphError::EntityExtraction {
454 message: format!("Failed to compile technical terms pattern: {}", e)
455 }
456 })?,
457 ];
458 patterns.insert(EntityType::Technical, technical_patterns);
459 }
460
461 for (pattern_name, regex) in &config.custom_patterns {
463 let entity_type = EntityType::Custom(pattern_name.clone());
464 patterns
465 .entry(entity_type)
466 .or_insert_with(Vec::new)
467 .push(regex.clone());
468 }
469
470 Ok(patterns)
471 }
472
473 fn compile_relationship_patterns() -> RragResult<HashMap<RelationType, Vec<Regex>>> {
475 let mut patterns = HashMap::new();
476
477 let is_a_patterns = vec![
479 Regex::new(r"(.+?)\s+is\s+a\s+(.+?)").map_err(|e| GraphError::EntityExtraction {
480 message: format!("Failed to compile is-a pattern: {}", e),
481 })?,
482 Regex::new(r"(.+?)\s+(?:are|is)\s+(?:an?|the)\s+(.+?)").map_err(|e| {
483 GraphError::EntityExtraction {
484 message: format!("Failed to compile is-a pattern 2: {}", e),
485 }
486 })?,
487 ];
488 patterns.insert(RelationType::IsA, is_a_patterns);
489
490 let part_of_patterns = vec![
492 Regex::new(r"(.+?)\s+(?:is|are)\s+part\s+of\s+(.+?)").map_err(|e| {
493 GraphError::EntityExtraction {
494 message: format!("Failed to compile part-of pattern: {}", e),
495 }
496 })?,
497 Regex::new(r"(.+?)\s+belongs\s+to\s+(.+?)").map_err(|e| {
498 GraphError::EntityExtraction {
499 message: format!("Failed to compile belongs-to pattern: {}", e),
500 }
501 })?,
502 ];
503 patterns.insert(RelationType::PartOf, part_of_patterns);
504
505 let located_in_patterns = vec![
507 Regex::new(r"(.+?)\s+(?:is|are)\s+located\s+in\s+(.+?)").map_err(|e| {
508 GraphError::EntityExtraction {
509 message: format!("Failed to compile located-in pattern: {}", e),
510 }
511 })?,
512 Regex::new(r"(.+?)\s+in\s+(.+?)").map_err(|e| GraphError::EntityExtraction {
513 message: format!("Failed to compile in pattern: {}", e),
514 })?,
515 ];
516 patterns.insert(RelationType::LocatedIn, located_in_patterns);
517
518 let works_for_patterns = vec![
520 Regex::new(r"(.+?)\s+works\s+(?:for|at)\s+(.+?)").map_err(|e| {
521 GraphError::EntityExtraction {
522 message: format!("Failed to compile works-for pattern: {}", e),
523 }
524 })?,
525 Regex::new(r"(.+?)\s+(?:is|was)\s+employed\s+(?:by|at)\s+(.+?)").map_err(|e| {
526 GraphError::EntityExtraction {
527 message: format!("Failed to compile employed-by pattern: {}", e),
528 }
529 })?,
530 ];
531 patterns.insert(RelationType::WorksFor, works_for_patterns);
532
533 let owns_patterns = vec![
535 Regex::new(r"(.+?)\s+owns\s+(.+?)").map_err(|e| GraphError::EntityExtraction {
536 message: format!("Failed to compile owns pattern: {}", e),
537 })?,
538 Regex::new(r"(.+?)\s+(?:has|possesses)\s+(.+?)").map_err(|e| {
539 GraphError::EntityExtraction {
540 message: format!("Failed to compile has pattern: {}", e),
541 }
542 })?,
543 ];
544 patterns.insert(RelationType::Owns, owns_patterns);
545
546 Ok(patterns)
547 }
548
549 fn extract_entities_with_patterns(&self, text: &str, source_id: &str) -> Vec<Entity> {
551 let mut entities = Vec::new();
552 let mut seen_positions = HashSet::new();
553
554 for (entity_type, patterns) in &self.patterns {
555 let priority = self
556 .config
557 .entity_priorities
558 .get(entity_type)
559 .copied()
560 .unwrap_or(0.5);
561
562 for pattern in patterns {
563 for mat in pattern.find_iter(text) {
564 let start_pos = mat.start();
565 let end_pos = mat.end();
566 let entity_text = mat.as_str().trim();
567
568 if seen_positions.contains(&(start_pos, end_pos)) {
570 continue;
571 }
572
573 if entity_text.len() > self.config.max_entity_length
575 || self.is_stop_word_only(entity_text)
576 {
577 continue;
578 }
579
580 let base_confidence = match entity_type {
582 EntityType::DateTime | EntityType::Money | EntityType::Percentage => 0.9,
583 EntityType::Technical => 0.8,
584 _ => 0.7,
585 };
586 let confidence = (base_confidence * priority).min(1.0);
587
588 if confidence >= self.config.min_confidence {
589 let entity = Entity {
590 text: entity_text.to_string(),
591 entity_type: entity_type.clone(),
592 start_pos,
593 end_pos,
594 confidence,
595 normalized_form: Some(self.normalize_entity(entity_text)),
596 attributes: HashMap::new(),
597 source_id: source_id.to_string(),
598 };
599
600 entities.push(entity);
601 seen_positions.insert((start_pos, end_pos));
602 }
603 }
604 }
605 }
606
607 entities.sort_by_key(|e| e.start_pos);
609 entities
610 }
611
612 fn is_stop_word_only(&self, text: &str) -> bool {
614 let words: Vec<&str> = text.split_whitespace().collect();
615 if words.is_empty() {
616 return true;
617 }
618
619 words
620 .iter()
621 .all(|word| self.config.stop_words.contains(&word.to_lowercase()))
622 }
623
624 fn normalize_entity(&self, text: &str) -> String {
626 text.trim()
627 .chars()
628 .map(|c| if c.is_whitespace() { ' ' } else { c })
629 .collect::<String>()
630 .split_whitespace()
631 .collect::<Vec<_>>()
632 .join(" ")
633 }
634
635 fn extract_relationships_with_patterns(
637 &self,
638 text: &str,
639 entities: &[Entity],
640 source_id: &str,
641 ) -> Vec<Relationship> {
642 let mut relationships = Vec::new();
643
644 let mut entity_spans: Vec<(usize, usize, &Entity)> = entities
646 .iter()
647 .map(|e| (e.start_pos, e.end_pos, e))
648 .collect();
649 entity_spans.sort_by_key(|&(start, _, _)| start);
650
651 for (relation_type, patterns) in &self.relationship_patterns {
652 for pattern in patterns {
653 for mat in pattern.find_iter(text) {
654 if let Some(captures) = pattern.captures(mat.as_str()) {
655 if captures.len() >= 3 {
656 let source_text = captures.get(1).unwrap().as_str().trim();
657 let target_text = captures.get(2).unwrap().as_str().trim();
658
659 if let (Some(source_entity), Some(target_entity)) = (
661 self.find_matching_entity(source_text, &entity_spans),
662 self.find_matching_entity(target_text, &entity_spans),
663 ) {
664 let relationship = Relationship {
665 source_entity: source_entity
666 .normalized_form
667 .as_ref()
668 .unwrap_or(&source_entity.text)
669 .clone(),
670 target_entity: target_entity
671 .normalized_form
672 .as_ref()
673 .unwrap_or(&target_entity.text)
674 .clone(),
675 relation_type: relation_type.clone(),
676 context: mat.as_str().to_string(),
677 start_pos: mat.start(),
678 end_pos: mat.end(),
679 confidence: 0.7, attributes: HashMap::new(),
681 source_id: source_id.to_string(),
682 };
683
684 relationships.push(relationship);
685 }
686 }
687 }
688 }
689 }
690 }
691
692 self.extract_co_occurrence_relationships(text, entities, source_id, &mut relationships);
694
695 relationships
696 }
697
698 fn find_matching_entity<'a>(
700 &self,
701 text: &str,
702 entity_spans: &'a [(usize, usize, &'a Entity)],
703 ) -> Option<&'a Entity> {
704 entity_spans
705 .iter()
706 .find(|(_, _, entity)| {
707 entity.text.eq_ignore_ascii_case(text)
708 || entity
709 .normalized_form
710 .as_ref()
711 .map_or(false, |norm| norm.eq_ignore_ascii_case(text))
712 })
713 .map(|(_, _, entity)| *entity)
714 }
715
716 fn extract_co_occurrence_relationships(
718 &self,
719 _text: &str,
720 entities: &[Entity],
721 source_id: &str,
722 relationships: &mut Vec<Relationship>,
723 ) {
724 let max_distance = 100; for i in 0..entities.len() {
727 for j in (i + 1)..entities.len() {
728 let entity1 = &entities[i];
729 let entity2 = &entities[j];
730
731 let distance = if entity1.end_pos < entity2.start_pos {
733 entity2.start_pos - entity1.end_pos
734 } else if entity2.end_pos < entity1.start_pos {
735 entity1.start_pos - entity2.end_pos
736 } else {
737 0 };
739
740 if distance <= max_distance {
741 let base_confidence = 0.3;
743 let distance_factor = 1.0 - (distance as f32 / max_distance as f32);
744 let confidence = base_confidence * distance_factor;
745
746 if confidence >= self.config.min_confidence {
747 let relationship = Relationship {
748 source_entity: entity1
749 .normalized_form
750 .as_ref()
751 .unwrap_or(&entity1.text)
752 .clone(),
753 target_entity: entity2
754 .normalized_form
755 .as_ref()
756 .unwrap_or(&entity2.text)
757 .clone(),
758 relation_type: RelationType::MentionedWith,
759 context: format!("Co-occurrence within {} characters", distance),
760 start_pos: entity1.start_pos.min(entity2.start_pos),
761 end_pos: entity1.end_pos.max(entity2.end_pos),
762 confidence,
763 attributes: {
764 let mut attrs = HashMap::new();
765 attrs.insert(
766 "distance".to_string(),
767 serde_json::Value::Number(distance.into()),
768 );
769 attrs.insert(
770 "type".to_string(),
771 serde_json::Value::String("co_occurrence".to_string()),
772 );
773 attrs
774 },
775 source_id: source_id.to_string(),
776 };
777
778 relationships.push(relationship);
779 }
780 }
781 }
782 }
783 }
784}
785
786#[async_trait]
787impl EntityExtractor for RuleBasedEntityExtractor {
788 async fn extract_entities(&self, text: &str, source_id: &str) -> RragResult<Vec<Entity>> {
789 Ok(self.extract_entities_with_patterns(text, source_id))
790 }
791
792 async fn extract_relationships(
793 &self,
794 text: &str,
795 entities: &[Entity],
796 source_id: &str,
797 ) -> RragResult<Vec<Relationship>> {
798 Ok(self.extract_relationships_with_patterns(text, entities, source_id))
799 }
800}
801
802pub fn entities_to_nodes(entities: &[Entity]) -> Vec<GraphNode> {
804 entities
805 .iter()
806 .map(|entity| {
807 let node_type = match &entity.entity_type {
808 EntityType::Person => NodeType::Entity("Person".to_string()),
809 EntityType::Organization => NodeType::Entity("Organization".to_string()),
810 EntityType::Location => NodeType::Entity("Location".to_string()),
811 EntityType::DateTime => NodeType::Entity("DateTime".to_string()),
812 EntityType::Money => NodeType::Entity("Money".to_string()),
813 EntityType::Percentage => NodeType::Entity("Percentage".to_string()),
814 EntityType::Technical => NodeType::Entity("Technical".to_string()),
815 EntityType::Concept => NodeType::Concept,
816 EntityType::Product => NodeType::Entity("Product".to_string()),
817 EntityType::Event => NodeType::Entity("Event".to_string()),
818 EntityType::Custom(custom_type) => NodeType::Custom(custom_type.clone()),
819 };
820
821 let mut node = GraphNode::new(
822 entity.normalized_form.as_ref().unwrap_or(&entity.text),
823 node_type,
824 )
825 .with_confidence(entity.confidence)
826 .with_source_document(entity.source_id.clone());
827
828 for (key, value) in &entity.attributes {
830 node = node.with_attribute(key, value.clone());
831 }
832
833 node = node.with_attribute(
834 "original_text",
835 serde_json::Value::String(entity.text.clone()),
836 );
837 node = node.with_attribute(
838 "start_pos",
839 serde_json::Value::Number(entity.start_pos.into()),
840 );
841 node = node.with_attribute("end_pos", serde_json::Value::Number(entity.end_pos.into()));
842
843 node
844 })
845 .collect()
846}
847
848pub fn relationships_to_edges(
850 relationships: &[Relationship],
851 entity_node_map: &HashMap<String, String>,
852) -> Vec<GraphEdge> {
853 relationships
854 .iter()
855 .filter_map(|relationship| {
856 let source_node_id = entity_node_map.get(&relationship.source_entity)?;
858 let target_node_id = entity_node_map.get(&relationship.target_entity)?;
859
860 let edge_type = match &relationship.relation_type {
861 RelationType::IsA => EdgeType::Semantic("is_a".to_string()),
862 RelationType::PartOf => EdgeType::Semantic("part_of".to_string()),
863 RelationType::LocatedIn => EdgeType::Semantic("located_in".to_string()),
864 RelationType::WorksFor => EdgeType::Semantic("works_for".to_string()),
865 RelationType::Owns => EdgeType::Semantic("owns".to_string()),
866 RelationType::Causes => EdgeType::Semantic("causes".to_string()),
867 RelationType::SimilarTo => EdgeType::Similar,
868 RelationType::HappenedOn => EdgeType::Semantic("happened_on".to_string()),
869 RelationType::MentionedWith => EdgeType::CoOccurs,
870 RelationType::Custom(custom_type) => EdgeType::Custom(custom_type.clone()),
871 };
872
873 let mut edge = GraphEdge::new(
874 source_node_id,
875 target_node_id,
876 &relationship.relation_type.to_string(),
877 edge_type,
878 )
879 .with_confidence(relationship.confidence)
880 .with_weight(relationship.confidence)
881 .with_source_document(relationship.source_id.clone());
882
883 for (key, value) in &relationship.attributes {
885 edge = edge.with_attribute(key, value.clone());
886 }
887
888 edge = edge.with_attribute(
889 "context",
890 serde_json::Value::String(relationship.context.clone()),
891 );
892 edge = edge.with_attribute(
893 "start_pos",
894 serde_json::Value::Number(relationship.start_pos.into()),
895 );
896 edge = edge.with_attribute(
897 "end_pos",
898 serde_json::Value::Number(relationship.end_pos.into()),
899 );
900
901 Some(edge)
902 })
903 .collect()
904}
905
906impl std::fmt::Display for RelationType {
907 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
908 match self {
909 RelationType::IsA => write!(f, "is_a"),
910 RelationType::PartOf => write!(f, "part_of"),
911 RelationType::LocatedIn => write!(f, "located_in"),
912 RelationType::WorksFor => write!(f, "works_for"),
913 RelationType::Owns => write!(f, "owns"),
914 RelationType::Causes => write!(f, "causes"),
915 RelationType::SimilarTo => write!(f, "similar_to"),
916 RelationType::HappenedOn => write!(f, "happened_on"),
917 RelationType::MentionedWith => write!(f, "mentioned_with"),
918 RelationType::Custom(custom) => write!(f, "{}", custom),
919 }
920 }
921}
922
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_rule_based_entity_extraction() {
        let config = EntityExtractionConfig::default();
        let extractor = RuleBasedEntityExtractor::new(config).unwrap();

        let text = "John Smith works at Microsoft Corporation in Seattle. The company was founded in 1975.";
        let entities = extractor.extract_entities(text, "test_doc").await.unwrap();

        assert!(!entities.is_empty());

        let person_entities: Vec<_> = entities
            .iter()
            .filter(|e| matches!(e.entity_type, EntityType::Person))
            .collect();
        assert!(!person_entities.is_empty());

        let org_entities: Vec<_> = entities
            .iter()
            .filter(|e| matches!(e.entity_type, EntityType::Organization))
            .collect();
        assert!(!org_entities.is_empty());
    }

    #[tokio::test]
    async fn test_relationship_extraction() {
        // Co-occurrence confidence never exceeds 0.3, so lower the threshold to keep those edges.
        let config = EntityExtractionConfig {
            min_confidence: 0.1,
            ..EntityExtractionConfig::default()
        };
        let extractor = RuleBasedEntityExtractor::new(config).unwrap();

        // Two-word capitalized names and a listed city are recognized by the built-in
        // patterns. A single first name, a pronoun and an unlisted company name are not.
        let text = "John Smith met Jane Doe in Seattle.";
        let (entities, relationships) = extractor.extract_all(text, "test_doc").await.unwrap();

        assert!(!entities.is_empty());
        assert!(!relationships.is_empty());

        let co_occurrences: Vec<_> = relationships
            .iter()
            .filter(|r| matches!(r.relation_type, RelationType::MentionedWith))
            .collect();
        assert!(!co_occurrences.is_empty());
    }

    #[test]
    fn test_entity_to_node_conversion() {
        let entity = Entity {
            text: "John Smith".to_string(),
            entity_type: EntityType::Person,
            start_pos: 0,
            end_pos: 10,
            confidence: 0.9,
            normalized_form: Some("John Smith".to_string()),
            attributes: HashMap::new(),
            source_id: "test_doc".to_string(),
        };

        let nodes = entities_to_nodes(&[entity]);
        assert_eq!(nodes.len(), 1);
        assert!(matches!(nodes[0].node_type, NodeType::Entity(_)));
        assert_eq!(nodes[0].confidence, 0.9);
    }
}