rexis_rag/graph_retrieval/
entity.rs

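//! # Entity and Relationship Extraction
//!
//! Entity recognition and relationship extraction for knowledge-graph construction:
//! the [`EntityExtractor`] trait, a regex-driven [`RuleBasedEntityExtractor`], and
//! helpers that turn extracted entities and relationships into graph nodes and edges.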
4
5use super::{EdgeType, GraphEdge, GraphError, GraphNode, NodeType};
6use crate::RragResult;
7use async_trait::async_trait;
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11
12/// Entity extracted from text
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Entity {
15    /// Entity text/mention
16    pub text: String,
17
18    /// Entity type
19    pub entity_type: EntityType,
20
21    /// Start position in source text
22    pub start_pos: usize,
23
24    /// End position in source text
25    pub end_pos: usize,
26
27    /// Confidence score (0.0 to 1.0)
28    pub confidence: f32,
29
30    /// Normalized form of the entity
31    pub normalized_form: Option<String>,
32
33    /// Additional attributes
34    pub attributes: HashMap<String, serde_json::Value>,
35
36    /// Source document/chunk ID
37    pub source_id: String,
38}
39
40/// Relationship between entities
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct Relationship {
43    /// Source entity
44    pub source_entity: String,
45
46    /// Target entity
47    pub target_entity: String,
48
49    /// Relationship type
50    pub relation_type: RelationType,
51
52    /// Relationship text/context
53    pub context: String,
54
55    /// Start position in source text
56    pub start_pos: usize,
57
58    /// End position in source text
59    pub end_pos: usize,
60
61    /// Confidence score (0.0 to 1.0)
62    pub confidence: f32,
63
64    /// Additional attributes
65    pub attributes: HashMap<String, serde_json::Value>,
66
67    /// Source document/chunk ID
68    pub source_id: String,
69}
70
71/// Entity types for classification
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
73pub enum EntityType {
74    /// Person names
75    Person,
76
77    /// Organization names
78    Organization,
79
80    /// Locations (cities, countries, etc.)
81    Location,
82
83    /// Dates and times
84    DateTime,
85
86    /// Monetary values
87    Money,
88
89    /// Percentages
90    Percentage,
91
92    /// Technical terms
93    Technical,
94
95    /// Concepts
96    Concept,
97
98    /// Products or services
99    Product,
100
101    /// Events
102    Event,
103
104    /// Custom entity type
105    Custom(String),
106}
107
108/// Relationship types
109#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
110pub enum RelationType {
111    /// "is a" relationship
112    IsA,
113
114    /// "part of" relationship
115    PartOf,
116
117    /// "located in" relationship
118    LocatedIn,
119
120    /// "works for" relationship
121    WorksFor,
122
123    /// "owns" relationship
124    Owns,
125
126    /// "causes" relationship
127    Causes,
128
129    /// "similar to" relationship
130    SimilarTo,
131
132    /// "happened on" relationship
133    HappenedOn,
134
135    /// "mentioned with" relationship
136    MentionedWith,
137
138    /// Custom relationship type
139    Custom(String),
140}
141
142/// Entity extraction configuration
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct EntityExtractionConfig {
145    /// Minimum confidence threshold
146    pub min_confidence: f32,
147
148    /// Maximum entity length in characters
149    pub max_entity_length: usize,
150
151    /// Whether to extract technical terms
152    pub extract_technical_terms: bool,
153
154    /// Whether to extract concepts
155    pub extract_concepts: bool,
156
157    /// Custom entity patterns
158    #[serde(skip)]
159    pub custom_patterns: HashMap<String, Regex>,
160
161    /// Stop words to ignore
162    pub stop_words: HashSet<String>,
163
164    /// Entity type priorities (higher = more important)
165    pub entity_priorities: HashMap<EntityType, f32>,
166}
167
168impl Default for EntityExtractionConfig {
169    fn default() -> Self {
170        let mut stop_words = HashSet::new();
171        stop_words.extend(
172            vec![
173                "the",
174                "a",
175                "an",
176                "and",
177                "or",
178                "but",
179                "in",
180                "on",
181                "at",
182                "to",
183                "for",
184                "of",
185                "with",
186                "by",
187                "from",
188                "up",
189                "about",
190                "into",
191                "through",
192                "during",
193                "before",
194                "after",
195                "above",
196                "below",
197                "between",
198                "among",
199                "this",
200                "that",
201                "these",
202                "those",
203                "i",
204                "me",
205                "my",
206                "myself",
207                "we",
208                "our",
209                "ours",
210                "ourselves",
211                "you",
212                "your",
213                "yours",
214                "yourself",
215                "yourselves",
216                "he",
217                "him",
218                "his",
219                "himself",
220                "she",
221                "her",
222                "hers",
223                "herself",
224                "it",
225                "its",
226                "itself",
227                "they",
228                "them",
229                "their",
230                "theirs",
231                "themselves",
232                "what",
233                "which",
234                "who",
235                "whom",
236                "whose",
237                "this",
238                "that",
239                "these",
240                "those",
241                "am",
242                "is",
243                "are",
244                "was",
245                "were",
246                "be",
247                "been",
248                "being",
249                "have",
250                "has",
251                "had",
252                "having",
253                "do",
254                "does",
255                "did",
256                "doing",
257                "would",
258                "should",
259                "could",
260                "can",
261                "may",
262                "might",
263                "must",
264                "shall",
265                "will",
266                "would",
267            ]
268            .into_iter()
269            .map(|s| s.to_string()),
270        );
271
272        let mut entity_priorities = HashMap::new();
273        entity_priorities.insert(EntityType::Person, 0.9);
274        entity_priorities.insert(EntityType::Organization, 0.8);
275        entity_priorities.insert(EntityType::Location, 0.8);
276        entity_priorities.insert(EntityType::DateTime, 0.7);
277        entity_priorities.insert(EntityType::Technical, 0.6);
278        entity_priorities.insert(EntityType::Concept, 0.5);
279
280        Self {
281            min_confidence: 0.5,
282            max_entity_length: 100,
283            extract_technical_terms: true,
284            extract_concepts: true,
285            custom_patterns: HashMap::new(),
286            stop_words,
287            entity_priorities,
288        }
289    }
290}
291
292/// Entity extractor trait
293#[async_trait]
294pub trait EntityExtractor: Send + Sync {
295    /// Extract entities from text
296    async fn extract_entities(&self, text: &str, source_id: &str) -> RragResult<Vec<Entity>>;
297
298    /// Extract relationships from text and entities
299    async fn extract_relationships(
300        &self,
301        text: &str,
302        entities: &[Entity],
303        source_id: &str,
304    ) -> RragResult<Vec<Relationship>>;
305
306    /// Extract both entities and relationships
307    async fn extract_all(
308        &self,
309        text: &str,
310        source_id: &str,
311    ) -> RragResult<(Vec<Entity>, Vec<Relationship>)> {
312        let entities = self.extract_entities(text, source_id).await?;
313        let relationships = self
314            .extract_relationships(text, &entities, source_id)
315            .await?;
316        Ok((entities, relationships))
317    }
318}
319
320/// Rule-based entity extractor
321pub struct RuleBasedEntityExtractor {
322    /// Configuration
323    config: EntityExtractionConfig,
324
325    /// Compiled regex patterns
326    patterns: HashMap<EntityType, Vec<Regex>>,
327
328    /// Relationship patterns
329    relationship_patterns: HashMap<RelationType, Vec<Regex>>,
330}
331
332impl RuleBasedEntityExtractor {
333    /// Create a new rule-based entity extractor
334    pub fn new(config: EntityExtractionConfig) -> RragResult<Self> {
335        let patterns = Self::compile_entity_patterns(&config)?;
336        let relationship_patterns = Self::compile_relationship_patterns()?;
337
338        Ok(Self {
339            config,
340            patterns,
341            relationship_patterns,
342        })
343    }
344
345    /// Compile entity recognition patterns
346    fn compile_entity_patterns(
347        config: &EntityExtractionConfig,
348    ) -> RragResult<HashMap<EntityType, Vec<Regex>>> {
349        let mut patterns = HashMap::new();
350
351        // Person patterns
352        let person_patterns = vec![
353            Regex::new(r"\b[A-Z][a-z]+\s+[A-Z][a-z]+\b").map_err(|e| {
354                GraphError::EntityExtraction {
355                    message: format!("Failed to compile person pattern: {}", e),
356                }
357            })?,
358            Regex::new(r"\b(?:Mr|Mrs|Dr|Prof|Ms)\.?\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b").map_err(
359                |e| GraphError::EntityExtraction {
360                    message: format!("Failed to compile person title pattern: {}", e),
361                },
362            )?,
363        ];
364        patterns.insert(EntityType::Person, person_patterns);
365
366        // Organization patterns
367        let org_patterns = vec![
368            Regex::new(r"\b[A-Z][a-zA-Z]*\s+(?:Inc|Corp|Company|Ltd|LLC|Organization|Institute|University|College|School)\b").map_err(|e| {
369                GraphError::EntityExtraction {
370                    message: format!("Failed to compile organization pattern: {}", e)
371                }
372            })?,
373            Regex::new(r"\b(?:the\s+)?[A-Z][a-zA-Z\s]+(?:Corporation|Foundation|Association|Agency|Department)\b").map_err(|e| {
374                GraphError::EntityExtraction {
375                    message: format!("Failed to compile organization pattern 2: {}", e)
376                }
377            })?,
378        ];
379        patterns.insert(EntityType::Organization, org_patterns);
380
381        // Location patterns
382        let location_patterns = vec![
383            Regex::new(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*,\s*[A-Z]{2}\b").map_err(|e| {
384                GraphError::EntityExtraction {
385                    message: format!("Failed to compile location pattern: {}", e)
386                }
387            })?,
388            Regex::new(r"\b(?:New York|Los Angeles|Chicago|Houston|Phoenix|Philadelphia|San Antonio|San Diego|Dallas|San Jose|Austin|Jacksonville|Fort Worth|Columbus|Charlotte|San Francisco|Indianapolis|Seattle|Denver|Washington|Boston|El Paso|Detroit|Nashville|Portland|Memphis|Oklahoma City|Las Vegas|Louisville|Baltimore|Milwaukee|Albuquerque|Tucson|Fresno|Sacramento|Mesa|Kansas City|Atlanta|Long Beach|Colorado Springs|Raleigh|Miami|Virginia Beach|Omaha|Oakland|Minneapolis|Tulsa|Arlington|Tampa|New Orleans|Wichita|Cleveland|Bakersfield|Aurora|Anaheim|Honolulu|Santa Ana|Riverside|Corpus Christi|Lexington|Stockton|Henderson|Saint Paul|St. Paul|Cincinnati|St. Louis|Pittsburgh|Greensboro|Lincoln|Plano|Anchorage|Durham|Jersey City|Chula Vista|Orlando|Chandler|Henderson|Laredo|Buffalo|North Las Vegas|Madison|Lubbock|Reno|Akron|Hialeah|Garland|Rochester|Modesto|Montgomery|Yonkers|Spokane|Tacoma|Shreveport|Des Moines|Fremont|Baton Rouge|Richmond|Birmingham|Chesapeake|Glendale|Irving|Scottsdale|North Hempstead|Fayetteville|Grand Rapids|Santa Clarita|Salt Lake City|Huntsville)\b").map_err(|e| {
389                GraphError::EntityExtraction {
390                    message: format!("Failed to compile major cities pattern: {}", e)
391                }
392            })?,
393        ];
394        patterns.insert(EntityType::Location, location_patterns);
395
396        // DateTime patterns
397        let datetime_patterns = vec![
398            Regex::new(r"\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},?\s+\d{4}\b").map_err(|e| {
399                GraphError::EntityExtraction {
400                    message: format!("Failed to compile date pattern: {}", e)
401                }
402            })?,
403            Regex::new(r"\b\d{1,2}/\d{1,2}/\d{4}\b").map_err(|e| {
404                GraphError::EntityExtraction {
405                    message: format!("Failed to compile date pattern 2: {}", e)
406                }
407            })?,
408            Regex::new(r"\b\d{4}-\d{2}-\d{2}\b").map_err(|e| {
409                GraphError::EntityExtraction {
410                    message: format!("Failed to compile ISO date pattern: {}", e)
411                }
412            })?,
413        ];
414        patterns.insert(EntityType::DateTime, datetime_patterns);
415
416        // Money patterns
417        let money_patterns = vec![
418            Regex::new(r"\$\d+(?:,\d{3})*(?:\.\d{2})?\b").map_err(|e| {
419                GraphError::EntityExtraction {
420                    message: format!("Failed to compile money pattern: {}", e),
421                }
422            })?,
423            Regex::new(r"\b\d+(?:,\d{3})*(?:\.\d{2})?\s*(?:dollars?|USD|cents?)\b").map_err(
424                |e| GraphError::EntityExtraction {
425                    message: format!("Failed to compile money pattern 2: {}", e),
426                },
427            )?,
428        ];
429        patterns.insert(EntityType::Money, money_patterns);
430
431        // Percentage patterns
432        let percentage_patterns = vec![
433            Regex::new(r"\b\d+(?:\.\d+)?%\b").map_err(|e| GraphError::EntityExtraction {
434                message: format!("Failed to compile percentage pattern: {}", e),
435            })?,
436            Regex::new(r"\b\d+(?:\.\d+)?\s*percent\b").map_err(|e| {
437                GraphError::EntityExtraction {
438                    message: format!("Failed to compile percentage pattern 2: {}", e),
439                }
440            })?,
441        ];
442        patterns.insert(EntityType::Percentage, percentage_patterns);
443
444        // Technical terms (if enabled)
445        if config.extract_technical_terms {
446            let technical_patterns = vec![
447                Regex::new(r"\b[A-Z]{2,}(?:\s+[A-Z]{2,})*\b").map_err(|e| {
448                    GraphError::EntityExtraction {
449                        message: format!("Failed to compile technical acronym pattern: {}", e)
450                    }
451                })?,
452                Regex::new(r"\b(?:API|SDK|HTTP|HTTPS|JSON|XML|SQL|NoSQL|REST|GraphQL|OAuth|JWT|SSL|TLS|CSS|HTML|JavaScript|TypeScript|Python|Java|Rust|Go|C\+\+|PHP|Ruby|Swift|Kotlin|React|Vue|Angular|Docker|Kubernetes|AWS|GCP|Azure|MongoDB|PostgreSQL|MySQL|Redis|Elasticsearch|TensorFlow|PyTorch|OpenAI|GPT|BERT|Transformer|Neural Network|Machine Learning|Deep Learning|AI|ML|DL|NLP|Computer Vision|Data Science|Big Data|Cloud Computing|DevOps|CI/CD|Git|GitHub|GitLab|Bitbucket|Jenkins|Travis CI|CircleCI|Terraform|Ansible|Chef|Puppet)\b").map_err(|e| {
453                    GraphError::EntityExtraction {
454                        message: format!("Failed to compile technical terms pattern: {}", e)
455                    }
456                })?,
457            ];
458            patterns.insert(EntityType::Technical, technical_patterns);
459        }
460
461        // Add custom patterns
462        for (pattern_name, regex) in &config.custom_patterns {
463            let entity_type = EntityType::Custom(pattern_name.clone());
464            patterns
465                .entry(entity_type)
466                .or_insert_with(Vec::new)
467                .push(regex.clone());
468        }
469
470        Ok(patterns)
471    }
472
473    /// Compile relationship patterns
474    fn compile_relationship_patterns() -> RragResult<HashMap<RelationType, Vec<Regex>>> {
475        let mut patterns = HashMap::new();
476
477        // "is a" relationships
478        let is_a_patterns = vec![
479            Regex::new(r"(.+?)\s+is\s+a\s+(.+?)").map_err(|e| GraphError::EntityExtraction {
480                message: format!("Failed to compile is-a pattern: {}", e),
481            })?,
482            Regex::new(r"(.+?)\s+(?:are|is)\s+(?:an?|the)\s+(.+?)").map_err(|e| {
483                GraphError::EntityExtraction {
484                    message: format!("Failed to compile is-a pattern 2: {}", e),
485                }
486            })?,
487        ];
488        patterns.insert(RelationType::IsA, is_a_patterns);
489
490        // "part of" relationships
491        let part_of_patterns = vec![
492            Regex::new(r"(.+?)\s+(?:is|are)\s+part\s+of\s+(.+?)").map_err(|e| {
493                GraphError::EntityExtraction {
494                    message: format!("Failed to compile part-of pattern: {}", e),
495                }
496            })?,
497            Regex::new(r"(.+?)\s+belongs\s+to\s+(.+?)").map_err(|e| {
498                GraphError::EntityExtraction {
499                    message: format!("Failed to compile belongs-to pattern: {}", e),
500                }
501            })?,
502        ];
503        patterns.insert(RelationType::PartOf, part_of_patterns);
504
505        // "located in" relationships
506        let located_in_patterns = vec![
507            Regex::new(r"(.+?)\s+(?:is|are)\s+located\s+in\s+(.+?)").map_err(|e| {
508                GraphError::EntityExtraction {
509                    message: format!("Failed to compile located-in pattern: {}", e),
510                }
511            })?,
512            Regex::new(r"(.+?)\s+in\s+(.+?)").map_err(|e| GraphError::EntityExtraction {
513                message: format!("Failed to compile in pattern: {}", e),
514            })?,
515        ];
516        patterns.insert(RelationType::LocatedIn, located_in_patterns);
517
518        // "works for" relationships
519        let works_for_patterns = vec![
520            Regex::new(r"(.+?)\s+works\s+(?:for|at)\s+(.+?)").map_err(|e| {
521                GraphError::EntityExtraction {
522                    message: format!("Failed to compile works-for pattern: {}", e),
523                }
524            })?,
525            Regex::new(r"(.+?)\s+(?:is|was)\s+employed\s+(?:by|at)\s+(.+?)").map_err(|e| {
526                GraphError::EntityExtraction {
527                    message: format!("Failed to compile employed-by pattern: {}", e),
528                }
529            })?,
530        ];
531        patterns.insert(RelationType::WorksFor, works_for_patterns);
532
533        // "owns" relationships
534        let owns_patterns = vec![
535            Regex::new(r"(.+?)\s+owns\s+(.+?)").map_err(|e| GraphError::EntityExtraction {
536                message: format!("Failed to compile owns pattern: {}", e),
537            })?,
538            Regex::new(r"(.+?)\s+(?:has|possesses)\s+(.+?)").map_err(|e| {
539                GraphError::EntityExtraction {
540                    message: format!("Failed to compile has pattern: {}", e),
541                }
542            })?,
543        ];
544        patterns.insert(RelationType::Owns, owns_patterns);
545
546        Ok(patterns)
547    }
548
549    /// Extract entities using pattern matching
550    fn extract_entities_with_patterns(&self, text: &str, source_id: &str) -> Vec<Entity> {
551        let mut entities = Vec::new();
552        let mut seen_positions = HashSet::new();
553
554        for (entity_type, patterns) in &self.patterns {
555            let priority = self
556                .config
557                .entity_priorities
558                .get(entity_type)
559                .copied()
560                .unwrap_or(0.5);
561
562            for pattern in patterns {
563                for mat in pattern.find_iter(text) {
564                    let start_pos = mat.start();
565                    let end_pos = mat.end();
566                    let entity_text = mat.as_str().trim();
567
568                    // Skip if we've already found an entity at this position
569                    if seen_positions.contains(&(start_pos, end_pos)) {
570                        continue;
571                    }
572
573                    // Skip if it's too long or contains only stop words
574                    if entity_text.len() > self.config.max_entity_length
575                        || self.is_stop_word_only(entity_text)
576                    {
577                        continue;
578                    }
579
580                    // Calculate confidence based on pattern match and entity type priority
581                    let base_confidence = match entity_type {
582                        EntityType::DateTime | EntityType::Money | EntityType::Percentage => 0.9,
583                        EntityType::Technical => 0.8,
584                        _ => 0.7,
585                    };
586                    let confidence = (base_confidence * priority).min(1.0);
587
588                    if confidence >= self.config.min_confidence {
589                        let entity = Entity {
590                            text: entity_text.to_string(),
591                            entity_type: entity_type.clone(),
592                            start_pos,
593                            end_pos,
594                            confidence,
595                            normalized_form: Some(self.normalize_entity(entity_text)),
596                            attributes: HashMap::new(),
597                            source_id: source_id.to_string(),
598                        };
599
600                        entities.push(entity);
601                        seen_positions.insert((start_pos, end_pos));
602                    }
603                }
604            }
605        }
606
607        // Sort by position for consistent ordering
608        entities.sort_by_key(|e| e.start_pos);
609        entities
610    }
611
612    /// Check if text contains only stop words
613    fn is_stop_word_only(&self, text: &str) -> bool {
614        let words: Vec<&str> = text.split_whitespace().collect();
615        if words.is_empty() {
616            return true;
617        }
618
619        words
620            .iter()
621            .all(|word| self.config.stop_words.contains(&word.to_lowercase()))
622    }
623
624    /// Normalize entity text
625    fn normalize_entity(&self, text: &str) -> String {
626        text.trim()
627            .chars()
628            .map(|c| if c.is_whitespace() { ' ' } else { c })
629            .collect::<String>()
630            .split_whitespace()
631            .collect::<Vec<_>>()
632            .join(" ")
633    }
634
635    /// Extract relationships using pattern matching
636    fn extract_relationships_with_patterns(
637        &self,
638        text: &str,
639        entities: &[Entity],
640        source_id: &str,
641    ) -> Vec<Relationship> {
642        let mut relationships = Vec::new();
643
644        // Create entity lookup by position
645        let mut entity_spans: Vec<(usize, usize, &Entity)> = entities
646            .iter()
647            .map(|e| (e.start_pos, e.end_pos, e))
648            .collect();
649        entity_spans.sort_by_key(|&(start, _, _)| start);
650
651        for (relation_type, patterns) in &self.relationship_patterns {
652            for pattern in patterns {
653                for mat in pattern.find_iter(text) {
654                    if let Some(captures) = pattern.captures(mat.as_str()) {
655                        if captures.len() >= 3 {
656                            let source_text = captures.get(1).unwrap().as_str().trim();
657                            let target_text = captures.get(2).unwrap().as_str().trim();
658
659                            // Find entities that match the captured groups
660                            if let (Some(source_entity), Some(target_entity)) = (
661                                self.find_matching_entity(source_text, &entity_spans),
662                                self.find_matching_entity(target_text, &entity_spans),
663                            ) {
664                                let relationship = Relationship {
665                                    source_entity: source_entity
666                                        .normalized_form
667                                        .as_ref()
668                                        .unwrap_or(&source_entity.text)
669                                        .clone(),
670                                    target_entity: target_entity
671                                        .normalized_form
672                                        .as_ref()
673                                        .unwrap_or(&target_entity.text)
674                                        .clone(),
675                                    relation_type: relation_type.clone(),
676                                    context: mat.as_str().to_string(),
677                                    start_pos: mat.start(),
678                                    end_pos: mat.end(),
679                                    confidence: 0.7, // Base confidence for pattern-matched relationships
680                                    attributes: HashMap::new(),
681                                    source_id: source_id.to_string(),
682                                };
683
684                                relationships.push(relationship);
685                            }
686                        }
687                    }
688                }
689            }
690        }
691
692        // Add co-occurrence relationships for entities that appear close together
693        self.extract_co_occurrence_relationships(text, entities, source_id, &mut relationships);
694
695        relationships
696    }
697
698    /// Find entity that matches the given text
699    fn find_matching_entity<'a>(
700        &self,
701        text: &str,
702        entity_spans: &'a [(usize, usize, &'a Entity)],
703    ) -> Option<&'a Entity> {
704        entity_spans
705            .iter()
706            .find(|(_, _, entity)| {
707                entity.text.eq_ignore_ascii_case(text)
708                    || entity
709                        .normalized_form
710                        .as_ref()
711                        .map_or(false, |norm| norm.eq_ignore_ascii_case(text))
712            })
713            .map(|(_, _, entity)| *entity)
714    }
715
716    /// Extract co-occurrence relationships
717    fn extract_co_occurrence_relationships(
718        &self,
719        _text: &str,
720        entities: &[Entity],
721        source_id: &str,
722        relationships: &mut Vec<Relationship>,
723    ) {
724        let max_distance = 100; // Maximum character distance for co-occurrence
725
726        for i in 0..entities.len() {
727            for j in (i + 1)..entities.len() {
728                let entity1 = &entities[i];
729                let entity2 = &entities[j];
730
731                // Check if entities are close enough
732                let distance = if entity1.end_pos < entity2.start_pos {
733                    entity2.start_pos - entity1.end_pos
734                } else if entity2.end_pos < entity1.start_pos {
735                    entity1.start_pos - entity2.end_pos
736                } else {
737                    0 // Overlapping
738                };
739
740                if distance <= max_distance {
741                    // Calculate confidence based on distance and entity types
742                    let base_confidence = 0.3;
743                    let distance_factor = 1.0 - (distance as f32 / max_distance as f32);
744                    let confidence = base_confidence * distance_factor;
745
746                    if confidence >= self.config.min_confidence {
747                        let relationship = Relationship {
748                            source_entity: entity1
749                                .normalized_form
750                                .as_ref()
751                                .unwrap_or(&entity1.text)
752                                .clone(),
753                            target_entity: entity2
754                                .normalized_form
755                                .as_ref()
756                                .unwrap_or(&entity2.text)
757                                .clone(),
758                            relation_type: RelationType::MentionedWith,
759                            context: format!("Co-occurrence within {} characters", distance),
760                            start_pos: entity1.start_pos.min(entity2.start_pos),
761                            end_pos: entity1.end_pos.max(entity2.end_pos),
762                            confidence,
763                            attributes: {
764                                let mut attrs = HashMap::new();
765                                attrs.insert(
766                                    "distance".to_string(),
767                                    serde_json::Value::Number(distance.into()),
768                                );
769                                attrs.insert(
770                                    "type".to_string(),
771                                    serde_json::Value::String("co_occurrence".to_string()),
772                                );
773                                attrs
774                            },
775                            source_id: source_id.to_string(),
776                        };
777
778                        relationships.push(relationship);
779                    }
780                }
781            }
782        }
783    }
784}
785
786#[async_trait]
787impl EntityExtractor for RuleBasedEntityExtractor {
788    async fn extract_entities(&self, text: &str, source_id: &str) -> RragResult<Vec<Entity>> {
789        Ok(self.extract_entities_with_patterns(text, source_id))
790    }
791
792    async fn extract_relationships(
793        &self,
794        text: &str,
795        entities: &[Entity],
796        source_id: &str,
797    ) -> RragResult<Vec<Relationship>> {
798        Ok(self.extract_relationships_with_patterns(text, entities, source_id))
799    }
800}
801
802/// Convert entities to graph nodes
803pub fn entities_to_nodes(entities: &[Entity]) -> Vec<GraphNode> {
804    entities
805        .iter()
806        .map(|entity| {
807            let node_type = match &entity.entity_type {
808                EntityType::Person => NodeType::Entity("Person".to_string()),
809                EntityType::Organization => NodeType::Entity("Organization".to_string()),
810                EntityType::Location => NodeType::Entity("Location".to_string()),
811                EntityType::DateTime => NodeType::Entity("DateTime".to_string()),
812                EntityType::Money => NodeType::Entity("Money".to_string()),
813                EntityType::Percentage => NodeType::Entity("Percentage".to_string()),
814                EntityType::Technical => NodeType::Entity("Technical".to_string()),
815                EntityType::Concept => NodeType::Concept,
816                EntityType::Product => NodeType::Entity("Product".to_string()),
817                EntityType::Event => NodeType::Entity("Event".to_string()),
818                EntityType::Custom(custom_type) => NodeType::Custom(custom_type.clone()),
819            };
820
821            let mut node = GraphNode::new(
822                entity.normalized_form.as_ref().unwrap_or(&entity.text),
823                node_type,
824            )
825            .with_confidence(entity.confidence)
826            .with_source_document(entity.source_id.clone());
827
828            // Add entity attributes
829            for (key, value) in &entity.attributes {
830                node = node.with_attribute(key, value.clone());
831            }
832
833            node = node.with_attribute(
834                "original_text",
835                serde_json::Value::String(entity.text.clone()),
836            );
837            node = node.with_attribute(
838                "start_pos",
839                serde_json::Value::Number(entity.start_pos.into()),
840            );
841            node = node.with_attribute("end_pos", serde_json::Value::Number(entity.end_pos.into()));
842
843            node
844        })
845        .collect()
846}
847
848/// Convert relationships to graph edges
849pub fn relationships_to_edges(
850    relationships: &[Relationship],
851    entity_node_map: &HashMap<String, String>,
852) -> Vec<GraphEdge> {
853    relationships
854        .iter()
855        .filter_map(|relationship| {
856            // Find node IDs for source and target entities
857            let source_node_id = entity_node_map.get(&relationship.source_entity)?;
858            let target_node_id = entity_node_map.get(&relationship.target_entity)?;
859
860            let edge_type = match &relationship.relation_type {
861                RelationType::IsA => EdgeType::Semantic("is_a".to_string()),
862                RelationType::PartOf => EdgeType::Semantic("part_of".to_string()),
863                RelationType::LocatedIn => EdgeType::Semantic("located_in".to_string()),
864                RelationType::WorksFor => EdgeType::Semantic("works_for".to_string()),
865                RelationType::Owns => EdgeType::Semantic("owns".to_string()),
866                RelationType::Causes => EdgeType::Semantic("causes".to_string()),
867                RelationType::SimilarTo => EdgeType::Similar,
868                RelationType::HappenedOn => EdgeType::Semantic("happened_on".to_string()),
869                RelationType::MentionedWith => EdgeType::CoOccurs,
870                RelationType::Custom(custom_type) => EdgeType::Custom(custom_type.clone()),
871            };
872
873            let mut edge = GraphEdge::new(
874                source_node_id,
875                target_node_id,
876                &relationship.relation_type.to_string(),
877                edge_type,
878            )
879            .with_confidence(relationship.confidence)
880            .with_weight(relationship.confidence)
881            .with_source_document(relationship.source_id.clone());
882
883            // Add relationship attributes
884            for (key, value) in &relationship.attributes {
885                edge = edge.with_attribute(key, value.clone());
886            }
887
888            edge = edge.with_attribute(
889                "context",
890                serde_json::Value::String(relationship.context.clone()),
891            );
892            edge = edge.with_attribute(
893                "start_pos",
894                serde_json::Value::Number(relationship.start_pos.into()),
895            );
896            edge = edge.with_attribute(
897                "end_pos",
898                serde_json::Value::Number(relationship.end_pos.into()),
899            );
900
901            Some(edge)
902        })
903        .collect()
904}
905
906impl std::fmt::Display for RelationType {
907    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
908        match self {
909            RelationType::IsA => write!(f, "is_a"),
910            RelationType::PartOf => write!(f, "part_of"),
911            RelationType::LocatedIn => write!(f, "located_in"),
912            RelationType::WorksFor => write!(f, "works_for"),
913            RelationType::Owns => write!(f, "owns"),
914            RelationType::Causes => write!(f, "causes"),
915            RelationType::SimilarTo => write!(f, "similar_to"),
916            RelationType::HappenedOn => write!(f, "happened_on"),
917            RelationType::MentionedWith => write!(f, "mentioned_with"),
918            RelationType::Custom(custom) => write!(f, "{}", custom),
919        }
920    }
921}
922
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_rule_based_entity_extraction() {
        let extractor = RuleBasedEntityExtractor::new(EntityExtractionConfig::default()).unwrap();
        let text = "John Smith works at Microsoft Corporation in Seattle. The company was founded in 1975.";

        let entities = extractor.extract_entities(text, "test_doc").await.unwrap();
        assert!(!entities.is_empty());

        // At minimum, a person and an organization must be recognized.
        let has_type = |wanted: EntityType| entities.iter().any(|e| e.entity_type == wanted);
        assert!(has_type(EntityType::Person));
        assert!(has_type(EntityType::Organization));
    }

    #[tokio::test]
    async fn test_relationship_extraction() {
        let extractor = RuleBasedEntityExtractor::new(EntityExtractionConfig::default()).unwrap();
        let text = "Alice is a software engineer. She works for Google.";

        let (entities, relationships) = extractor.extract_all(text, "test_doc").await.unwrap();
        assert!(!entities.is_empty());
        assert!(!relationships.is_empty());

        // "works for" in the text should yield at least one WorksFor relationship.
        assert!(relationships
            .iter()
            .any(|r| matches!(r.relation_type, RelationType::WorksFor)));
    }

    #[test]
    fn test_entity_to_node_conversion() {
        let entity = Entity {
            text: "John Smith".to_string(),
            entity_type: EntityType::Person,
            start_pos: 0,
            end_pos: 10,
            confidence: 0.9,
            normalized_form: Some("John Smith".to_string()),
            attributes: HashMap::new(),
            source_id: "test_doc".to_string(),
        };

        let nodes = entities_to_nodes(&[entity]);
        assert_eq!(nodes.len(), 1);

        let node = &nodes[0];
        assert!(matches!(node.node_type, NodeType::Entity(_)));
        assert_eq!(node.confidence, 0.9);
    }
}