// graphrag_core/inference.rs
//! Implicit relationship inference system
2
3use crate::core::{Entity, EntityId, KnowledgeGraph, TextChunk};
4use std::collections::HashMap;
5
6/// Represents a relationship inferred between two entities
7///
8/// This structure contains information about a relationship discovered through
9/// co-occurrence analysis and contextual pattern matching.
10#[derive(Debug, Clone)]
11pub struct InferredRelation {
12 /// Source entity in the relationship
13 pub source: EntityId,
14 /// Target entity in the relationship
15 pub target: EntityId,
16 /// Type of relationship (e.g., "FRIENDS", "DISCUSSES", "LOCATED_IN")
17 pub relation_type: String,
18 /// Confidence score for this inference (0.0-1.0)
19 pub confidence: f32,
20 /// Number of text chunks providing evidence for this relationship
21 pub evidence_count: usize,
22}
23
/// Tunable thresholds and limits for the relationship inference engine.
///
/// Controls how strict the engine is when it turns entity co-occurrence in
/// text into inferred relationships, and how many results it hands back.
///
/// `PartialEq` is derived so configurations can be compared directly; `Eq` is
/// not available because the thresholds are `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct InferenceConfig {
    /// Minimum confidence threshold for accepting an inferred relationship
    pub min_confidence: f32,
    /// Maximum number of candidate relationships to return per query
    pub max_candidates: usize,
    /// Threshold for determining if entities co-occur frequently enough
    ///
    /// NOTE(review): the inference code visible in this file does not read
    /// this field; confirm whether another module consumes it.
    pub co_occurrence_threshold: f32,
}
37
38impl Default for InferenceConfig {
39 fn default() -> Self {
40 Self {
41 min_confidence: 0.3,
42 max_candidates: 10,
43 co_occurrence_threshold: 0.4,
44 }
45 }
46}
47
48/// Engine for inferring implicit relationships between entities
49///
50/// The inference engine analyzes entity co-occurrence patterns, proximity,
51/// and contextual clues to discover relationships that may not be explicitly
52/// stated in the text.
53pub struct InferenceEngine {
54 /// Configuration controlling inference behavior
55 config: InferenceConfig,
56}
57
58impl InferenceEngine {
59 /// Create a new inference engine with the given configuration
60 ///
61 /// # Arguments
62 ///
63 /// * `config` - Configuration controlling inference thresholds and limits
64 pub fn new(config: InferenceConfig) -> Self {
65 Self { config }
66 }
67
68 /// Infer relationships for a target entity
69 ///
70 /// Analyzes the knowledge graph to find entities that frequently co-occur with
71 /// the target entity and have contextual evidence of a relationship.
72 ///
73 /// # Arguments
74 ///
75 /// * `target_entity` - The entity to find relationships for
76 /// * `relation_type` - The type of relationship to infer (e.g., "FRIENDS")
77 /// * `knowledge_graph` - The knowledge graph containing entities and chunks
78 ///
79 /// # Returns
80 ///
81 /// Returns a vector of inferred relationships, sorted by confidence score,
82 /// limited to `max_candidates` from the configuration.
83 pub fn infer_relationships(
84 &self,
85 target_entity: &EntityId,
86 relation_type: &str,
87 knowledge_graph: &KnowledgeGraph,
88 ) -> Vec<InferredRelation> {
89 let mut inferred_relations = Vec::new();
90
91 // Find target entity
92 let target_ent = knowledge_graph.entities().find(|e| &e.id == target_entity);
93
94 if target_ent.is_none() {
95 return inferred_relations;
96 }
97
98 // Get chunks containing target entity
99 let target_chunks: Vec<_> = knowledge_graph
100 .chunks()
101 .filter(|chunk| chunk.entities.contains(target_entity))
102 .collect();
103
104 // Find co-occurring entities
105 let mut entity_scores: HashMap<EntityId, f32> = HashMap::new();
106
107 for chunk in &target_chunks {
108 for entity_id in &chunk.entities {
109 if entity_id != target_entity {
110 let evidence_score =
111 self.calculate_evidence_score(chunk, target_entity, entity_id);
112 *entity_scores.entry(entity_id.clone()).or_insert(0.0) += evidence_score;
113 }
114 }
115 }
116
117 // Create inferred relations for high-scoring entities
118 for (entity_id, score) in entity_scores {
119 let normalized_score = (score / target_chunks.len() as f32).min(1.0);
120
121 if normalized_score >= self.config.min_confidence {
122 inferred_relations.push(InferredRelation {
123 source: target_entity.clone(),
124 target: entity_id,
125 relation_type: relation_type.to_string(),
126 confidence: normalized_score,
127 evidence_count: target_chunks.len(),
128 });
129 }
130 }
131
132 // Sort by confidence and limit results
133 inferred_relations.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
134 inferred_relations.truncate(self.config.max_candidates);
135
136 inferred_relations
137 }
138
139 /// Calculate evidence score for a potential relationship
140 ///
141 /// Analyzes a text chunk to determine how strongly it suggests a relationship
142 /// between two entities. Uses proximity analysis, pattern matching, and
143 /// contextual clues.
144 ///
145 /// # Arguments
146 ///
147 /// * `chunk` - The text chunk containing both entities
148 /// * `entity_a` - First entity ID
149 /// * `entity_b` - Second entity ID
150 ///
151 /// # Returns
152 ///
153 /// Returns a score between 0.0 and 1.0 indicating relationship strength.
154 /// Higher scores indicate stronger evidence of a relationship.
155 fn calculate_evidence_score(
156 &self,
157 chunk: &TextChunk,
158 entity_a: &EntityId,
159 entity_b: &EntityId,
160 ) -> f32 {
161 let content = &chunk.content.to_lowercase();
162 let mut score: f32 = 0.2; // Lower base co-occurrence score
163
164 // Get entity names for contextual analysis
165 let entity_a_name = self.extract_entity_name(entity_a);
166 let entity_b_name = self.extract_entity_name(entity_b);
167
168 // Calculate proximity score between entities in text
169 let proximity_bonus =
170 self.calculate_proximity_score(content, &entity_a_name, &entity_b_name);
171 score += proximity_bonus;
172
173 // Enhanced friendship indicators with contextual patterns
174 let friendship_patterns = [
175 // Direct friendship terms
176 ("best friend", 0.8),
177 ("close friend", 0.7),
178 ("good friend", 0.6),
179 ("friend", 0.4),
180 ("friends", 0.4),
181 ("friendship", 0.5),
182 // Activity-based friendship indicators
183 ("played together", 0.6),
184 ("went together", 0.5),
185 ("talked with", 0.4),
186 ("helped each other", 0.7),
187 ("shared", 0.3),
188 ("together", 0.3),
189 // Emotional bonding indicators
190 ("trusted", 0.6),
191 ("loyal", 0.5),
192 ("bond", 0.5),
193 ("close", 0.4),
194 ("cared for", 0.6),
195 ("looked after", 0.5),
196 ("protected", 0.6),
197 // Adventure/activity companionship
198 ("adventure", 0.4),
199 ("explore", 0.3),
200 ("journey", 0.3),
201 ("companion", 0.6),
202 ("partner", 0.5),
203 ("ally", 0.5),
204 ];
205
206 // Contextual pattern matching with weighted scores
207 for (pattern, weight) in &friendship_patterns {
208 if content.contains(pattern) {
209 // Additional context bonus if entities are mentioned near the pattern
210 let context_bonus =
211 if self.entities_near_pattern(content, &entity_a_name, &entity_b_name, pattern)
212 {
213 weight * 0.5
214 } else {
215 *weight * 0.3
216 };
217 score += context_bonus;
218 }
219 }
220
221 // Enhanced negative indicators with contextual analysis
222 let negative_patterns = [
223 ("enemy", -0.8),
224 ("enemies", -0.8),
225 ("rival", -0.6),
226 ("rivals", -0.6),
227 ("fought", -0.5),
228 ("fight", -0.4),
229 ("battle", -0.4),
230 ("conflict", -0.5),
231 ("angry at", -0.6),
232 ("hate", -0.7),
233 ("hated", -0.7),
234 ("despise", -0.6),
235 ("betrayed", -0.8),
236 ("betrayal", -0.7),
237 ("argued", -0.3),
238 ("quarrel", -0.4),
239 ("against", -0.2),
240 ("opposed", -0.4),
241 ("disagree", -0.2),
242 ];
243
244 for (pattern, weight) in &negative_patterns {
245 if content.contains(pattern) {
246 let context_penalty =
247 if self.entities_near_pattern(content, &entity_a_name, &entity_b_name, pattern)
248 {
249 weight * 1.2
250 } else {
251 weight * 0.8
252 };
253 score += context_penalty; // weight is already negative
254 }
255 }
256
257 // Family relationship indicators (neutral for friendship)
258 let family_patterns = ["brother", "sister", "cousin", "aunt", "uncle", "family"];
259 let mut has_family_relation = false;
260 for pattern in &family_patterns {
261 if content.contains(pattern) {
262 has_family_relation = true;
263 break;
264 }
265 }
266
267 // Family relations can still be friendships, but lower weight
268 if has_family_relation {
269 score *= 0.8;
270 }
271
272 score.clamp(0.0, 1.0)
273 }
274
275 /// Extract clean entity name from an entity ID
276 ///
277 /// Entity IDs typically have format "TYPE_normalized_name". This method
278 /// extracts just the name portion and formats it for matching.
279 ///
280 /// # Arguments
281 ///
282 /// * `entity_id` - The entity ID to extract the name from
283 ///
284 /// # Returns
285 ///
286 /// Returns the cleaned, lowercase entity name with underscores replaced by spaces
287 fn extract_entity_name(&self, entity_id: &EntityId) -> String {
288 // EntityId format is typically "TYPE_normalized_name"
289 let id_str = &entity_id.0;
290 if let Some(underscore_pos) = id_str.find('_') {
291 id_str[underscore_pos + 1..]
292 .replace('_', " ")
293 .to_lowercase()
294 } else {
295 id_str.to_lowercase()
296 }
297 }
298
299 /// Calculate proximity score between entities in text
300 ///
301 /// Determines how close two entities are mentioned in the text. Closer proximity
302 /// suggests a stronger relationship between the entities.
303 ///
304 /// # Arguments
305 ///
306 /// * `content` - The text content to analyze
307 /// * `entity_a` - Name of the first entity
308 /// * `entity_b` - Name of the second entity
309 ///
310 /// # Returns
311 ///
312 /// Returns a proximity score:
313 /// - 0.4 for very close (0-2 words apart)
314 /// - 0.3 for close (3-5 words apart)
315 /// - 0.2 for medium distance (6-10 words apart)
316 /// - 0.1 for far (11-20 words apart)
317 /// - 0.05 for very far (20+ words apart)
318 fn calculate_proximity_score(&self, content: &str, entity_a: &str, entity_b: &str) -> f32 {
319 let words: Vec<&str> = content.split_whitespace().collect();
320 let mut positions_a = Vec::new();
321 let mut positions_b = Vec::new();
322
323 // Find all positions of entity mentions
324 for (i, word) in words.iter().enumerate() {
325 if word.to_lowercase().contains(entity_a) {
326 positions_a.push(i);
327 }
328 if word.to_lowercase().contains(entity_b) {
329 positions_b.push(i);
330 }
331 }
332
333 if positions_a.is_empty() || positions_b.is_empty() {
334 return 0.0;
335 }
336
337 // Find minimum distance between any mentions
338 let mut min_distance = usize::MAX;
339 for &pos_a in &positions_a {
340 for &pos_b in &positions_b {
341 let distance = pos_a.abs_diff(pos_b);
342 min_distance = min_distance.min(distance);
343 }
344 }
345
346 // Convert distance to proximity score (closer = higher score)
347 match min_distance {
348 0..=2 => 0.4, // Very close (same sentence likely)
349 3..=5 => 0.3, // Close
350 6..=10 => 0.2, // Medium distance
351 11..=20 => 0.1, // Far
352 _ => 0.05, // Very far
353 }
354 }
355
356 /// Check if entities are mentioned near a relationship pattern
357 ///
358 /// Determines if both entities appear within a 200-character window
359 /// around a relationship keyword or pattern. This helps determine if
360 /// a relationship pattern actually applies to these specific entities.
361 ///
362 /// # Arguments
363 ///
364 /// * `content` - The text content to search
365 /// * `entity_a` - Name of the first entity
366 /// * `entity_b` - Name of the second entity
367 /// * `pattern` - The relationship pattern to search for (e.g., "friend", "enemy")
368 ///
369 /// # Returns
370 ///
371 /// Returns `true` if both entities are found within 100 characters before and
372 /// after the pattern, `false` otherwise.
373 fn entities_near_pattern(
374 &self,
375 content: &str,
376 entity_a: &str,
377 entity_b: &str,
378 pattern: &str,
379 ) -> bool {
380 if let Some(pattern_pos) = content.find(pattern) {
381 let start = pattern_pos.saturating_sub(100); // 100 chars before
382 let end = (pattern_pos + pattern.len() + 100).min(content.len()); // 100 chars after
383 let context = &content[start..end];
384
385 context.contains(entity_a) && context.contains(entity_b)
386 } else {
387 false
388 }
389 }
390
391 /// Find an entity in the knowledge graph by name
392 ///
393 /// Performs a case-insensitive substring search to find an entity whose
394 /// name contains the given search string.
395 ///
396 /// # Arguments
397 ///
398 /// * `knowledge_graph` - The knowledge graph to search
399 /// * `name` - The name (or partial name) to search for
400 ///
401 /// # Returns
402 ///
403 /// Returns `Some(&Entity)` if a matching entity is found, `None` otherwise.
404 pub fn find_entity_by_name<'a>(
405 &self,
406 knowledge_graph: &'a KnowledgeGraph,
407 name: &str,
408 ) -> Option<&'a Entity> {
409 knowledge_graph
410 .entities()
411 .find(|e| e.name.to_lowercase().contains(&name.to_lowercase()))
412 }
413}