Skip to main content

graphrag_core/
inference.rs

1//! Implicit relationship inference system
2
3use crate::core::{Entity, EntityId, KnowledgeGraph, TextChunk};
4use std::collections::HashMap;
5
6/// Represents a relationship inferred between two entities
7///
8/// This structure contains information about a relationship discovered through
9/// co-occurrence analysis and contextual pattern matching.
10#[derive(Debug, Clone)]
11pub struct InferredRelation {
12    /// Source entity in the relationship
13    pub source: EntityId,
14    /// Target entity in the relationship
15    pub target: EntityId,
16    /// Type of relationship (e.g., "FRIENDS", "DISCUSSES", "LOCATED_IN")
17    pub relation_type: String,
18    /// Confidence score for this inference (0.0-1.0)
19    pub confidence: f32,
20    /// Number of text chunks providing evidence for this relationship
21    pub evidence_count: usize,
22}
23
/// Configuration for the relationship inference engine.
///
/// Controls the thresholds and limits used when inferring implicit
/// relationships between entities based on their co-occurrence in text.
///
/// `PartialEq` is derived so configurations can be compared directly; the
/// `f32` fields are compared with exact floating-point equality.
#[derive(Debug, Clone, PartialEq)]
pub struct InferenceConfig {
    /// Minimum confidence threshold for accepting an inferred relationship
    pub min_confidence: f32,
    /// Maximum number of candidate relationships to return per query
    pub max_candidates: usize,
    /// Threshold for determining if entities co-occur frequently enough
    ///
    /// NOTE(review): no code visible in this file reads this field — confirm
    /// whether it is consumed elsewhere or is still waiting to be wired in.
    pub co_occurrence_threshold: f32,
}

impl Default for InferenceConfig {
    /// Returns the default configuration: `min_confidence = 0.3`,
    /// `max_candidates = 10`, `co_occurrence_threshold = 0.4`.
    fn default() -> Self {
        Self {
            min_confidence: 0.3,
            max_candidates: 10,
            co_occurrence_threshold: 0.4,
        }
    }
}
47
48/// Engine for inferring implicit relationships between entities
49///
50/// The inference engine analyzes entity co-occurrence patterns, proximity,
51/// and contextual clues to discover relationships that may not be explicitly
52/// stated in the text.
53pub struct InferenceEngine {
54    /// Configuration controlling inference behavior
55    config: InferenceConfig,
56}
57
58impl InferenceEngine {
59    /// Create a new inference engine with the given configuration
60    ///
61    /// # Arguments
62    ///
63    /// * `config` - Configuration controlling inference thresholds and limits
64    pub fn new(config: InferenceConfig) -> Self {
65        Self { config }
66    }
67
68    /// Infer relationships for a target entity
69    ///
70    /// Analyzes the knowledge graph to find entities that frequently co-occur with
71    /// the target entity and have contextual evidence of a relationship.
72    ///
73    /// # Arguments
74    ///
75    /// * `target_entity` - The entity to find relationships for
76    /// * `relation_type` - The type of relationship to infer (e.g., "FRIENDS")
77    /// * `knowledge_graph` - The knowledge graph containing entities and chunks
78    ///
79    /// # Returns
80    ///
81    /// Returns a vector of inferred relationships, sorted by confidence score,
82    /// limited to `max_candidates` from the configuration.
83    pub fn infer_relationships(
84        &self,
85        target_entity: &EntityId,
86        relation_type: &str,
87        knowledge_graph: &KnowledgeGraph,
88    ) -> Vec<InferredRelation> {
89        let mut inferred_relations = Vec::new();
90
91        // Find target entity
92        let target_ent = knowledge_graph.entities().find(|e| &e.id == target_entity);
93
94        if target_ent.is_none() {
95            return inferred_relations;
96        }
97
98        // Get chunks containing target entity
99        let target_chunks: Vec<_> = knowledge_graph
100            .chunks()
101            .filter(|chunk| chunk.entities.contains(target_entity))
102            .collect();
103
104        // Find co-occurring entities
105        let mut entity_scores: HashMap<EntityId, f32> = HashMap::new();
106
107        for chunk in &target_chunks {
108            for entity_id in &chunk.entities {
109                if entity_id != target_entity {
110                    let evidence_score =
111                        self.calculate_evidence_score(chunk, target_entity, entity_id);
112                    *entity_scores.entry(entity_id.clone()).or_insert(0.0) += evidence_score;
113                }
114            }
115        }
116
117        // Create inferred relations for high-scoring entities
118        for (entity_id, score) in entity_scores {
119            let normalized_score = (score / target_chunks.len() as f32).min(1.0);
120
121            if normalized_score >= self.config.min_confidence {
122                inferred_relations.push(InferredRelation {
123                    source: target_entity.clone(),
124                    target: entity_id,
125                    relation_type: relation_type.to_string(),
126                    confidence: normalized_score,
127                    evidence_count: target_chunks.len(),
128                });
129            }
130        }
131
132        // Sort by confidence and limit results
133        inferred_relations.sort_by(|a, b| {
134            b.confidence
135                .partial_cmp(&a.confidence)
136                .unwrap_or(std::cmp::Ordering::Equal)
137        });
138        inferred_relations.truncate(self.config.max_candidates);
139
140        inferred_relations
141    }
142
143    /// Calculate evidence score for a potential relationship
144    ///
145    /// Analyzes a text chunk to determine how strongly it suggests a relationship
146    /// between two entities. Uses proximity analysis, pattern matching, and
147    /// contextual clues.
148    ///
149    /// # Arguments
150    ///
151    /// * `chunk` - The text chunk containing both entities
152    /// * `entity_a` - First entity ID
153    /// * `entity_b` - Second entity ID
154    ///
155    /// # Returns
156    ///
157    /// Returns a score between 0.0 and 1.0 indicating relationship strength.
158    /// Higher scores indicate stronger evidence of a relationship.
159    fn calculate_evidence_score(
160        &self,
161        chunk: &TextChunk,
162        entity_a: &EntityId,
163        entity_b: &EntityId,
164    ) -> f32 {
165        let content = &chunk.content.to_lowercase();
166        let mut score: f32 = 0.2; // Lower base co-occurrence score
167
168        // Get entity names for contextual analysis
169        let entity_a_name = self.extract_entity_name(entity_a);
170        let entity_b_name = self.extract_entity_name(entity_b);
171
172        // Calculate proximity score between entities in text
173        let proximity_bonus =
174            self.calculate_proximity_score(content, &entity_a_name, &entity_b_name);
175        score += proximity_bonus;
176
177        // Enhanced friendship indicators with contextual patterns
178        let friendship_patterns = [
179            // Direct friendship terms
180            ("best friend", 0.8),
181            ("close friend", 0.7),
182            ("good friend", 0.6),
183            ("friend", 0.4),
184            ("friends", 0.4),
185            ("friendship", 0.5),
186            // Activity-based friendship indicators
187            ("played together", 0.6),
188            ("went together", 0.5),
189            ("talked with", 0.4),
190            ("helped each other", 0.7),
191            ("shared", 0.3),
192            ("together", 0.3),
193            // Emotional bonding indicators
194            ("trusted", 0.6),
195            ("loyal", 0.5),
196            ("bond", 0.5),
197            ("close", 0.4),
198            ("cared for", 0.6),
199            ("looked after", 0.5),
200            ("protected", 0.6),
201            // Adventure/activity companionship
202            ("adventure", 0.4),
203            ("explore", 0.3),
204            ("journey", 0.3),
205            ("companion", 0.6),
206            ("partner", 0.5),
207            ("ally", 0.5),
208        ];
209
210        // Contextual pattern matching with weighted scores
211        for (pattern, weight) in &friendship_patterns {
212            if content.contains(pattern) {
213                // Additional context bonus if entities are mentioned near the pattern
214                let context_bonus =
215                    if self.entities_near_pattern(content, &entity_a_name, &entity_b_name, pattern)
216                    {
217                        weight * 0.5
218                    } else {
219                        *weight * 0.3
220                    };
221                score += context_bonus;
222            }
223        }
224
225        // Enhanced negative indicators with contextual analysis
226        let negative_patterns = [
227            ("enemy", -0.8),
228            ("enemies", -0.8),
229            ("rival", -0.6),
230            ("rivals", -0.6),
231            ("fought", -0.5),
232            ("fight", -0.4),
233            ("battle", -0.4),
234            ("conflict", -0.5),
235            ("angry at", -0.6),
236            ("hate", -0.7),
237            ("hated", -0.7),
238            ("despise", -0.6),
239            ("betrayed", -0.8),
240            ("betrayal", -0.7),
241            ("argued", -0.3),
242            ("quarrel", -0.4),
243            ("against", -0.2),
244            ("opposed", -0.4),
245            ("disagree", -0.2),
246        ];
247
248        for (pattern, weight) in &negative_patterns {
249            if content.contains(pattern) {
250                let context_penalty =
251                    if self.entities_near_pattern(content, &entity_a_name, &entity_b_name, pattern)
252                    {
253                        weight * 1.2
254                    } else {
255                        weight * 0.8
256                    };
257                score += context_penalty; // weight is already negative
258            }
259        }
260
261        // Family relationship indicators (neutral for friendship)
262        let family_patterns = ["brother", "sister", "cousin", "aunt", "uncle", "family"];
263        let mut has_family_relation = false;
264        for pattern in &family_patterns {
265            if content.contains(pattern) {
266                has_family_relation = true;
267                break;
268            }
269        }
270
271        // Family relations can still be friendships, but lower weight
272        if has_family_relation {
273            score *= 0.8;
274        }
275
276        score.clamp(0.0, 1.0)
277    }
278
279    /// Extract clean entity name from an entity ID
280    ///
281    /// Entity IDs typically have format "TYPE_normalized_name". This method
282    /// extracts just the name portion and formats it for matching.
283    ///
284    /// # Arguments
285    ///
286    /// * `entity_id` - The entity ID to extract the name from
287    ///
288    /// # Returns
289    ///
290    /// Returns the cleaned, lowercase entity name with underscores replaced by spaces
291    fn extract_entity_name(&self, entity_id: &EntityId) -> String {
292        // EntityId format is typically "TYPE_normalized_name"
293        let id_str = &entity_id.0;
294        if let Some(underscore_pos) = id_str.find('_') {
295            id_str[underscore_pos + 1..]
296                .replace('_', " ")
297                .to_lowercase()
298        } else {
299            id_str.to_lowercase()
300        }
301    }
302
303    /// Calculate proximity score between entities in text
304    ///
305    /// Determines how close two entities are mentioned in the text. Closer proximity
306    /// suggests a stronger relationship between the entities.
307    ///
308    /// # Arguments
309    ///
310    /// * `content` - The text content to analyze
311    /// * `entity_a` - Name of the first entity
312    /// * `entity_b` - Name of the second entity
313    ///
314    /// # Returns
315    ///
316    /// Returns a proximity score:
317    /// - 0.4 for very close (0-2 words apart)
318    /// - 0.3 for close (3-5 words apart)
319    /// - 0.2 for medium distance (6-10 words apart)
320    /// - 0.1 for far (11-20 words apart)
321    /// - 0.05 for very far (20+ words apart)
322    fn calculate_proximity_score(&self, content: &str, entity_a: &str, entity_b: &str) -> f32 {
323        let words: Vec<&str> = content.split_whitespace().collect();
324        let mut positions_a = Vec::new();
325        let mut positions_b = Vec::new();
326
327        // Find all positions of entity mentions
328        for (i, word) in words.iter().enumerate() {
329            if word.to_lowercase().contains(entity_a) {
330                positions_a.push(i);
331            }
332            if word.to_lowercase().contains(entity_b) {
333                positions_b.push(i);
334            }
335        }
336
337        if positions_a.is_empty() || positions_b.is_empty() {
338            return 0.0;
339        }
340
341        // Find minimum distance between any mentions
342        let mut min_distance = usize::MAX;
343        for &pos_a in &positions_a {
344            for &pos_b in &positions_b {
345                let distance = pos_a.abs_diff(pos_b);
346                min_distance = min_distance.min(distance);
347            }
348        }
349
350        // Convert distance to proximity score (closer = higher score)
351        match min_distance {
352            0..=2 => 0.4,   // Very close (same sentence likely)
353            3..=5 => 0.3,   // Close
354            6..=10 => 0.2,  // Medium distance
355            11..=20 => 0.1, // Far
356            _ => 0.05,      // Very far
357        }
358    }
359
360    /// Check if entities are mentioned near a relationship pattern
361    ///
362    /// Determines if both entities appear within a 200-character window
363    /// around a relationship keyword or pattern. This helps determine if
364    /// a relationship pattern actually applies to these specific entities.
365    ///
366    /// # Arguments
367    ///
368    /// * `content` - The text content to search
369    /// * `entity_a` - Name of the first entity
370    /// * `entity_b` - Name of the second entity
371    /// * `pattern` - The relationship pattern to search for (e.g., "friend", "enemy")
372    ///
373    /// # Returns
374    ///
375    /// Returns `true` if both entities are found within 100 characters before and
376    /// after the pattern, `false` otherwise.
377    fn entities_near_pattern(
378        &self,
379        content: &str,
380        entity_a: &str,
381        entity_b: &str,
382        pattern: &str,
383    ) -> bool {
384        if let Some(pattern_pos) = content.find(pattern) {
385            let start = pattern_pos.saturating_sub(100); // 100 chars before
386            let end = (pattern_pos + pattern.len() + 100).min(content.len()); // 100 chars after
387            let context = &content[start..end];
388
389            context.contains(entity_a) && context.contains(entity_b)
390        } else {
391            false
392        }
393    }
394
395    /// Find an entity in the knowledge graph by name
396    ///
397    /// Performs a case-insensitive substring search to find an entity whose
398    /// name contains the given search string.
399    ///
400    /// # Arguments
401    ///
402    /// * `knowledge_graph` - The knowledge graph to search
403    /// * `name` - The name (or partial name) to search for
404    ///
405    /// # Returns
406    ///
407    /// Returns `Some(&Entity)` if a matching entity is found, `None` otherwise.
408    pub fn find_entity_by_name<'a>(
409        &self,
410        knowledge_graph: &'a KnowledgeGraph,
411        name: &str,
412    ) -> Option<&'a Entity> {
413        knowledge_graph
414            .entities()
415            .find(|e| e.name.to_lowercase().contains(&name.to_lowercase()))
416    }
417}