Skip to main content

graphrag_core/
inference.rs

1//! Implicit relationship inference system
2
3use crate::core::{Entity, EntityId, KnowledgeGraph, TextChunk};
4use std::collections::HashMap;
5
6/// Represents a relationship inferred between two entities
7///
8/// This structure contains information about a relationship discovered through
9/// co-occurrence analysis and contextual pattern matching.
10#[derive(Debug, Clone)]
11pub struct InferredRelation {
12    /// Source entity in the relationship
13    pub source: EntityId,
14    /// Target entity in the relationship
15    pub target: EntityId,
16    /// Type of relationship (e.g., "FRIENDS", "DISCUSSES", "LOCATED_IN")
17    pub relation_type: String,
18    /// Confidence score for this inference (0.0-1.0)
19    pub confidence: f32,
20    /// Number of text chunks providing evidence for this relationship
21    pub evidence_count: usize,
22}
23
/// Configuration for the relationship inference engine
///
/// Controls the thresholds and limits used when inferring implicit
/// relationships between entities based on their co-occurrence in text.
/// The type derives `PartialEq` so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct InferenceConfig {
    /// Minimum confidence threshold for accepting an inferred relationship.
    ///
    /// A candidate is kept only when its normalized score is greater than or
    /// equal to this value.
    pub min_confidence: f32,
    /// Maximum number of candidate relationships to return per query
    pub max_candidates: usize,
    /// Threshold for determining if entities co-occur frequently enough.
    ///
    /// NOTE(review): no engine code in this file reads this field; confirm
    /// whether it is consumed elsewhere or is reserved for future use.
    pub co_occurrence_threshold: f32,
}
37
38impl Default for InferenceConfig {
39    fn default() -> Self {
40        Self {
41            min_confidence: 0.3,
42            max_candidates: 10,
43            co_occurrence_threshold: 0.4,
44        }
45    }
46}
47
48/// Engine for inferring implicit relationships between entities
49///
50/// The inference engine analyzes entity co-occurrence patterns, proximity,
51/// and contextual clues to discover relationships that may not be explicitly
52/// stated in the text.
53pub struct InferenceEngine {
54    /// Configuration controlling inference behavior
55    config: InferenceConfig,
56}
57
58impl InferenceEngine {
59    /// Create a new inference engine with the given configuration
60    ///
61    /// # Arguments
62    ///
63    /// * `config` - Configuration controlling inference thresholds and limits
64    pub fn new(config: InferenceConfig) -> Self {
65        Self { config }
66    }
67
68    /// Infer relationships for a target entity
69    ///
70    /// Analyzes the knowledge graph to find entities that frequently co-occur with
71    /// the target entity and have contextual evidence of a relationship.
72    ///
73    /// # Arguments
74    ///
75    /// * `target_entity` - The entity to find relationships for
76    /// * `relation_type` - The type of relationship to infer (e.g., "FRIENDS")
77    /// * `knowledge_graph` - The knowledge graph containing entities and chunks
78    ///
79    /// # Returns
80    ///
81    /// Returns a vector of inferred relationships, sorted by confidence score,
82    /// limited to `max_candidates` from the configuration.
83    pub fn infer_relationships(
84        &self,
85        target_entity: &EntityId,
86        relation_type: &str,
87        knowledge_graph: &KnowledgeGraph,
88    ) -> Vec<InferredRelation> {
89        let mut inferred_relations = Vec::new();
90
91        // Find target entity
92        let target_ent = knowledge_graph.entities().find(|e| &e.id == target_entity);
93
94        if target_ent.is_none() {
95            return inferred_relations;
96        }
97
98        // Get chunks containing target entity
99        let target_chunks: Vec<_> = knowledge_graph
100            .chunks()
101            .filter(|chunk| chunk.entities.contains(target_entity))
102            .collect();
103
104        // Find co-occurring entities
105        let mut entity_scores: HashMap<EntityId, f32> = HashMap::new();
106
107        for chunk in &target_chunks {
108            for entity_id in &chunk.entities {
109                if entity_id != target_entity {
110                    let evidence_score =
111                        self.calculate_evidence_score(chunk, target_entity, entity_id);
112                    *entity_scores.entry(entity_id.clone()).or_insert(0.0) += evidence_score;
113                }
114            }
115        }
116
117        // Create inferred relations for high-scoring entities
118        for (entity_id, score) in entity_scores {
119            let normalized_score = (score / target_chunks.len() as f32).min(1.0);
120
121            if normalized_score >= self.config.min_confidence {
122                inferred_relations.push(InferredRelation {
123                    source: target_entity.clone(),
124                    target: entity_id,
125                    relation_type: relation_type.to_string(),
126                    confidence: normalized_score,
127                    evidence_count: target_chunks.len(),
128                });
129            }
130        }
131
132        // Sort by confidence and limit results
133        inferred_relations.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
134        inferred_relations.truncate(self.config.max_candidates);
135
136        inferred_relations
137    }
138
139    /// Calculate evidence score for a potential relationship
140    ///
141    /// Analyzes a text chunk to determine how strongly it suggests a relationship
142    /// between two entities. Uses proximity analysis, pattern matching, and
143    /// contextual clues.
144    ///
145    /// # Arguments
146    ///
147    /// * `chunk` - The text chunk containing both entities
148    /// * `entity_a` - First entity ID
149    /// * `entity_b` - Second entity ID
150    ///
151    /// # Returns
152    ///
153    /// Returns a score between 0.0 and 1.0 indicating relationship strength.
154    /// Higher scores indicate stronger evidence of a relationship.
155    fn calculate_evidence_score(
156        &self,
157        chunk: &TextChunk,
158        entity_a: &EntityId,
159        entity_b: &EntityId,
160    ) -> f32 {
161        let content = &chunk.content.to_lowercase();
162        let mut score: f32 = 0.2; // Lower base co-occurrence score
163
164        // Get entity names for contextual analysis
165        let entity_a_name = self.extract_entity_name(entity_a);
166        let entity_b_name = self.extract_entity_name(entity_b);
167
168        // Calculate proximity score between entities in text
169        let proximity_bonus =
170            self.calculate_proximity_score(content, &entity_a_name, &entity_b_name);
171        score += proximity_bonus;
172
173        // Enhanced friendship indicators with contextual patterns
174        let friendship_patterns = [
175            // Direct friendship terms
176            ("best friend", 0.8),
177            ("close friend", 0.7),
178            ("good friend", 0.6),
179            ("friend", 0.4),
180            ("friends", 0.4),
181            ("friendship", 0.5),
182            // Activity-based friendship indicators
183            ("played together", 0.6),
184            ("went together", 0.5),
185            ("talked with", 0.4),
186            ("helped each other", 0.7),
187            ("shared", 0.3),
188            ("together", 0.3),
189            // Emotional bonding indicators
190            ("trusted", 0.6),
191            ("loyal", 0.5),
192            ("bond", 0.5),
193            ("close", 0.4),
194            ("cared for", 0.6),
195            ("looked after", 0.5),
196            ("protected", 0.6),
197            // Adventure/activity companionship
198            ("adventure", 0.4),
199            ("explore", 0.3),
200            ("journey", 0.3),
201            ("companion", 0.6),
202            ("partner", 0.5),
203            ("ally", 0.5),
204        ];
205
206        // Contextual pattern matching with weighted scores
207        for (pattern, weight) in &friendship_patterns {
208            if content.contains(pattern) {
209                // Additional context bonus if entities are mentioned near the pattern
210                let context_bonus =
211                    if self.entities_near_pattern(content, &entity_a_name, &entity_b_name, pattern)
212                    {
213                        weight * 0.5
214                    } else {
215                        *weight * 0.3
216                    };
217                score += context_bonus;
218            }
219        }
220
221        // Enhanced negative indicators with contextual analysis
222        let negative_patterns = [
223            ("enemy", -0.8),
224            ("enemies", -0.8),
225            ("rival", -0.6),
226            ("rivals", -0.6),
227            ("fought", -0.5),
228            ("fight", -0.4),
229            ("battle", -0.4),
230            ("conflict", -0.5),
231            ("angry at", -0.6),
232            ("hate", -0.7),
233            ("hated", -0.7),
234            ("despise", -0.6),
235            ("betrayed", -0.8),
236            ("betrayal", -0.7),
237            ("argued", -0.3),
238            ("quarrel", -0.4),
239            ("against", -0.2),
240            ("opposed", -0.4),
241            ("disagree", -0.2),
242        ];
243
244        for (pattern, weight) in &negative_patterns {
245            if content.contains(pattern) {
246                let context_penalty =
247                    if self.entities_near_pattern(content, &entity_a_name, &entity_b_name, pattern)
248                    {
249                        weight * 1.2
250                    } else {
251                        weight * 0.8
252                    };
253                score += context_penalty; // weight is already negative
254            }
255        }
256
257        // Family relationship indicators (neutral for friendship)
258        let family_patterns = ["brother", "sister", "cousin", "aunt", "uncle", "family"];
259        let mut has_family_relation = false;
260        for pattern in &family_patterns {
261            if content.contains(pattern) {
262                has_family_relation = true;
263                break;
264            }
265        }
266
267        // Family relations can still be friendships, but lower weight
268        if has_family_relation {
269            score *= 0.8;
270        }
271
272        score.clamp(0.0, 1.0)
273    }
274
275    /// Extract clean entity name from an entity ID
276    ///
277    /// Entity IDs typically have format "TYPE_normalized_name". This method
278    /// extracts just the name portion and formats it for matching.
279    ///
280    /// # Arguments
281    ///
282    /// * `entity_id` - The entity ID to extract the name from
283    ///
284    /// # Returns
285    ///
286    /// Returns the cleaned, lowercase entity name with underscores replaced by spaces
287    fn extract_entity_name(&self, entity_id: &EntityId) -> String {
288        // EntityId format is typically "TYPE_normalized_name"
289        let id_str = &entity_id.0;
290        if let Some(underscore_pos) = id_str.find('_') {
291            id_str[underscore_pos + 1..]
292                .replace('_', " ")
293                .to_lowercase()
294        } else {
295            id_str.to_lowercase()
296        }
297    }
298
299    /// Calculate proximity score between entities in text
300    ///
301    /// Determines how close two entities are mentioned in the text. Closer proximity
302    /// suggests a stronger relationship between the entities.
303    ///
304    /// # Arguments
305    ///
306    /// * `content` - The text content to analyze
307    /// * `entity_a` - Name of the first entity
308    /// * `entity_b` - Name of the second entity
309    ///
310    /// # Returns
311    ///
312    /// Returns a proximity score:
313    /// - 0.4 for very close (0-2 words apart)
314    /// - 0.3 for close (3-5 words apart)
315    /// - 0.2 for medium distance (6-10 words apart)
316    /// - 0.1 for far (11-20 words apart)
317    /// - 0.05 for very far (20+ words apart)
318    fn calculate_proximity_score(&self, content: &str, entity_a: &str, entity_b: &str) -> f32 {
319        let words: Vec<&str> = content.split_whitespace().collect();
320        let mut positions_a = Vec::new();
321        let mut positions_b = Vec::new();
322
323        // Find all positions of entity mentions
324        for (i, word) in words.iter().enumerate() {
325            if word.to_lowercase().contains(entity_a) {
326                positions_a.push(i);
327            }
328            if word.to_lowercase().contains(entity_b) {
329                positions_b.push(i);
330            }
331        }
332
333        if positions_a.is_empty() || positions_b.is_empty() {
334            return 0.0;
335        }
336
337        // Find minimum distance between any mentions
338        let mut min_distance = usize::MAX;
339        for &pos_a in &positions_a {
340            for &pos_b in &positions_b {
341                let distance = pos_a.abs_diff(pos_b);
342                min_distance = min_distance.min(distance);
343            }
344        }
345
346        // Convert distance to proximity score (closer = higher score)
347        match min_distance {
348            0..=2 => 0.4,   // Very close (same sentence likely)
349            3..=5 => 0.3,   // Close
350            6..=10 => 0.2,  // Medium distance
351            11..=20 => 0.1, // Far
352            _ => 0.05,      // Very far
353        }
354    }
355
356    /// Check if entities are mentioned near a relationship pattern
357    ///
358    /// Determines if both entities appear within a 200-character window
359    /// around a relationship keyword or pattern. This helps determine if
360    /// a relationship pattern actually applies to these specific entities.
361    ///
362    /// # Arguments
363    ///
364    /// * `content` - The text content to search
365    /// * `entity_a` - Name of the first entity
366    /// * `entity_b` - Name of the second entity
367    /// * `pattern` - The relationship pattern to search for (e.g., "friend", "enemy")
368    ///
369    /// # Returns
370    ///
371    /// Returns `true` if both entities are found within 100 characters before and
372    /// after the pattern, `false` otherwise.
373    fn entities_near_pattern(
374        &self,
375        content: &str,
376        entity_a: &str,
377        entity_b: &str,
378        pattern: &str,
379    ) -> bool {
380        if let Some(pattern_pos) = content.find(pattern) {
381            let start = pattern_pos.saturating_sub(100); // 100 chars before
382            let end = (pattern_pos + pattern.len() + 100).min(content.len()); // 100 chars after
383            let context = &content[start..end];
384
385            context.contains(entity_a) && context.contains(entity_b)
386        } else {
387            false
388        }
389    }
390
391    /// Find an entity in the knowledge graph by name
392    ///
393    /// Performs a case-insensitive substring search to find an entity whose
394    /// name contains the given search string.
395    ///
396    /// # Arguments
397    ///
398    /// * `knowledge_graph` - The knowledge graph to search
399    /// * `name` - The name (or partial name) to search for
400    ///
401    /// # Returns
402    ///
403    /// Returns `Some(&Entity)` if a matching entity is found, `None` otherwise.
404    pub fn find_entity_by_name<'a>(
405        &self,
406        knowledge_graph: &'a KnowledgeGraph,
407        name: &str,
408    ) -> Option<&'a Entity> {
409        knowledge_graph
410            .entities()
411            .find(|e| e.name.to_lowercase().contains(&name.to_lowercase()))
412    }
413}