graphrag_core/inference.rs
1//! Implicit relationship inference system
2
3use crate::core::{Entity, EntityId, KnowledgeGraph, TextChunk};
4use std::collections::HashMap;
5
6/// Represents a relationship inferred between two entities
7///
8/// This structure contains information about a relationship discovered through
9/// co-occurrence analysis and contextual pattern matching.
10#[derive(Debug, Clone)]
11pub struct InferredRelation {
12 /// Source entity in the relationship
13 pub source: EntityId,
14 /// Target entity in the relationship
15 pub target: EntityId,
16 /// Type of relationship (e.g., "FRIENDS", "DISCUSSES", "LOCATED_IN")
17 pub relation_type: String,
18 /// Confidence score for this inference (0.0-1.0)
19 pub confidence: f32,
20 /// Number of text chunks providing evidence for this relationship
21 pub evidence_count: usize,
22}
23
/// Tuning knobs for [`InferenceEngine`].
///
/// Holds the thresholds and limits applied when implicit relationships are inferred
/// from entity co-occurrence in text.
///
/// NOTE(review): `co_occurrence_threshold` is not read by `InferenceEngine` in this
/// module; confirm whether anything else consumes it.
#[derive(Clone, Debug)]
pub struct InferenceConfig {
    /// Lowest confidence an inferred relationship may have and still be returned.
    pub min_confidence: f32,
    /// Upper bound on the number of relationships returned by a single query.
    pub max_candidates: usize,
    /// Threshold intended to decide whether two entities co-occur often enough.
    pub co_occurrence_threshold: f32,
}
37
38impl Default for InferenceConfig {
39 fn default() -> Self {
40 Self {
41 min_confidence: 0.3,
42 max_candidates: 10,
43 co_occurrence_threshold: 0.4,
44 }
45 }
46}
47
48/// Engine for inferring implicit relationships between entities
49///
50/// The inference engine analyzes entity co-occurrence patterns, proximity,
51/// and contextual clues to discover relationships that may not be explicitly
52/// stated in the text.
53pub struct InferenceEngine {
54 /// Configuration controlling inference behavior
55 config: InferenceConfig,
56}
57
58impl InferenceEngine {
59 /// Create a new inference engine with the given configuration
60 ///
61 /// # Arguments
62 ///
63 /// * `config` - Configuration controlling inference thresholds and limits
64 pub fn new(config: InferenceConfig) -> Self {
65 Self { config }
66 }
67
68 /// Infer relationships for a target entity
69 ///
70 /// Analyzes the knowledge graph to find entities that frequently co-occur with
71 /// the target entity and have contextual evidence of a relationship.
72 ///
73 /// # Arguments
74 ///
75 /// * `target_entity` - The entity to find relationships for
76 /// * `relation_type` - The type of relationship to infer (e.g., "FRIENDS")
77 /// * `knowledge_graph` - The knowledge graph containing entities and chunks
78 ///
79 /// # Returns
80 ///
81 /// Returns a vector of inferred relationships, sorted by confidence score,
82 /// limited to `max_candidates` from the configuration.
83 pub fn infer_relationships(
84 &self,
85 target_entity: &EntityId,
86 relation_type: &str,
87 knowledge_graph: &KnowledgeGraph,
88 ) -> Vec<InferredRelation> {
89 let mut inferred_relations = Vec::new();
90
91 // Find target entity
92 let target_ent = knowledge_graph.entities().find(|e| &e.id == target_entity);
93
94 if target_ent.is_none() {
95 return inferred_relations;
96 }
97
98 // Get chunks containing target entity
99 let target_chunks: Vec<_> = knowledge_graph
100 .chunks()
101 .filter(|chunk| chunk.entities.contains(target_entity))
102 .collect();
103
104 // Find co-occurring entities
105 let mut entity_scores: HashMap<EntityId, f32> = HashMap::new();
106
107 for chunk in &target_chunks {
108 for entity_id in &chunk.entities {
109 if entity_id != target_entity {
110 let evidence_score =
111 self.calculate_evidence_score(chunk, target_entity, entity_id);
112 *entity_scores.entry(entity_id.clone()).or_insert(0.0) += evidence_score;
113 }
114 }
115 }
116
117 // Create inferred relations for high-scoring entities
118 for (entity_id, score) in entity_scores {
119 let normalized_score = (score / target_chunks.len() as f32).min(1.0);
120
121 if normalized_score >= self.config.min_confidence {
122 inferred_relations.push(InferredRelation {
123 source: target_entity.clone(),
124 target: entity_id,
125 relation_type: relation_type.to_string(),
126 confidence: normalized_score,
127 evidence_count: target_chunks.len(),
128 });
129 }
130 }
131
132 // Sort by confidence and limit results
133 inferred_relations.sort_by(|a, b| {
134 b.confidence
135 .partial_cmp(&a.confidence)
136 .unwrap_or(std::cmp::Ordering::Equal)
137 });
138 inferred_relations.truncate(self.config.max_candidates);
139
140 inferred_relations
141 }
142
143 /// Calculate evidence score for a potential relationship
144 ///
145 /// Analyzes a text chunk to determine how strongly it suggests a relationship
146 /// between two entities. Uses proximity analysis, pattern matching, and
147 /// contextual clues.
148 ///
149 /// # Arguments
150 ///
151 /// * `chunk` - The text chunk containing both entities
152 /// * `entity_a` - First entity ID
153 /// * `entity_b` - Second entity ID
154 ///
155 /// # Returns
156 ///
157 /// Returns a score between 0.0 and 1.0 indicating relationship strength.
158 /// Higher scores indicate stronger evidence of a relationship.
159 fn calculate_evidence_score(
160 &self,
161 chunk: &TextChunk,
162 entity_a: &EntityId,
163 entity_b: &EntityId,
164 ) -> f32 {
165 let content = &chunk.content.to_lowercase();
166 let mut score: f32 = 0.2; // Lower base co-occurrence score
167
168 // Get entity names for contextual analysis
169 let entity_a_name = self.extract_entity_name(entity_a);
170 let entity_b_name = self.extract_entity_name(entity_b);
171
172 // Calculate proximity score between entities in text
173 let proximity_bonus =
174 self.calculate_proximity_score(content, &entity_a_name, &entity_b_name);
175 score += proximity_bonus;
176
177 // Enhanced friendship indicators with contextual patterns
178 let friendship_patterns = [
179 // Direct friendship terms
180 ("best friend", 0.8),
181 ("close friend", 0.7),
182 ("good friend", 0.6),
183 ("friend", 0.4),
184 ("friends", 0.4),
185 ("friendship", 0.5),
186 // Activity-based friendship indicators
187 ("played together", 0.6),
188 ("went together", 0.5),
189 ("talked with", 0.4),
190 ("helped each other", 0.7),
191 ("shared", 0.3),
192 ("together", 0.3),
193 // Emotional bonding indicators
194 ("trusted", 0.6),
195 ("loyal", 0.5),
196 ("bond", 0.5),
197 ("close", 0.4),
198 ("cared for", 0.6),
199 ("looked after", 0.5),
200 ("protected", 0.6),
201 // Adventure/activity companionship
202 ("adventure", 0.4),
203 ("explore", 0.3),
204 ("journey", 0.3),
205 ("companion", 0.6),
206 ("partner", 0.5),
207 ("ally", 0.5),
208 ];
209
210 // Contextual pattern matching with weighted scores
211 for (pattern, weight) in &friendship_patterns {
212 if content.contains(pattern) {
213 // Additional context bonus if entities are mentioned near the pattern
214 let context_bonus =
215 if self.entities_near_pattern(content, &entity_a_name, &entity_b_name, pattern)
216 {
217 weight * 0.5
218 } else {
219 *weight * 0.3
220 };
221 score += context_bonus;
222 }
223 }
224
225 // Enhanced negative indicators with contextual analysis
226 let negative_patterns = [
227 ("enemy", -0.8),
228 ("enemies", -0.8),
229 ("rival", -0.6),
230 ("rivals", -0.6),
231 ("fought", -0.5),
232 ("fight", -0.4),
233 ("battle", -0.4),
234 ("conflict", -0.5),
235 ("angry at", -0.6),
236 ("hate", -0.7),
237 ("hated", -0.7),
238 ("despise", -0.6),
239 ("betrayed", -0.8),
240 ("betrayal", -0.7),
241 ("argued", -0.3),
242 ("quarrel", -0.4),
243 ("against", -0.2),
244 ("opposed", -0.4),
245 ("disagree", -0.2),
246 ];
247
248 for (pattern, weight) in &negative_patterns {
249 if content.contains(pattern) {
250 let context_penalty =
251 if self.entities_near_pattern(content, &entity_a_name, &entity_b_name, pattern)
252 {
253 weight * 1.2
254 } else {
255 weight * 0.8
256 };
257 score += context_penalty; // weight is already negative
258 }
259 }
260
261 // Family relationship indicators (neutral for friendship)
262 let family_patterns = ["brother", "sister", "cousin", "aunt", "uncle", "family"];
263 let mut has_family_relation = false;
264 for pattern in &family_patterns {
265 if content.contains(pattern) {
266 has_family_relation = true;
267 break;
268 }
269 }
270
271 // Family relations can still be friendships, but lower weight
272 if has_family_relation {
273 score *= 0.8;
274 }
275
276 score.clamp(0.0, 1.0)
277 }
278
279 /// Extract clean entity name from an entity ID
280 ///
281 /// Entity IDs typically have format "TYPE_normalized_name". This method
282 /// extracts just the name portion and formats it for matching.
283 ///
284 /// # Arguments
285 ///
286 /// * `entity_id` - The entity ID to extract the name from
287 ///
288 /// # Returns
289 ///
290 /// Returns the cleaned, lowercase entity name with underscores replaced by spaces
291 fn extract_entity_name(&self, entity_id: &EntityId) -> String {
292 // EntityId format is typically "TYPE_normalized_name"
293 let id_str = &entity_id.0;
294 if let Some(underscore_pos) = id_str.find('_') {
295 id_str[underscore_pos + 1..]
296 .replace('_', " ")
297 .to_lowercase()
298 } else {
299 id_str.to_lowercase()
300 }
301 }
302
303 /// Calculate proximity score between entities in text
304 ///
305 /// Determines how close two entities are mentioned in the text. Closer proximity
306 /// suggests a stronger relationship between the entities.
307 ///
308 /// # Arguments
309 ///
310 /// * `content` - The text content to analyze
311 /// * `entity_a` - Name of the first entity
312 /// * `entity_b` - Name of the second entity
313 ///
314 /// # Returns
315 ///
316 /// Returns a proximity score:
317 /// - 0.4 for very close (0-2 words apart)
318 /// - 0.3 for close (3-5 words apart)
319 /// - 0.2 for medium distance (6-10 words apart)
320 /// - 0.1 for far (11-20 words apart)
321 /// - 0.05 for very far (20+ words apart)
322 fn calculate_proximity_score(&self, content: &str, entity_a: &str, entity_b: &str) -> f32 {
323 let words: Vec<&str> = content.split_whitespace().collect();
324 let mut positions_a = Vec::new();
325 let mut positions_b = Vec::new();
326
327 // Find all positions of entity mentions
328 for (i, word) in words.iter().enumerate() {
329 if word.to_lowercase().contains(entity_a) {
330 positions_a.push(i);
331 }
332 if word.to_lowercase().contains(entity_b) {
333 positions_b.push(i);
334 }
335 }
336
337 if positions_a.is_empty() || positions_b.is_empty() {
338 return 0.0;
339 }
340
341 // Find minimum distance between any mentions
342 let mut min_distance = usize::MAX;
343 for &pos_a in &positions_a {
344 for &pos_b in &positions_b {
345 let distance = pos_a.abs_diff(pos_b);
346 min_distance = min_distance.min(distance);
347 }
348 }
349
350 // Convert distance to proximity score (closer = higher score)
351 match min_distance {
352 0..=2 => 0.4, // Very close (same sentence likely)
353 3..=5 => 0.3, // Close
354 6..=10 => 0.2, // Medium distance
355 11..=20 => 0.1, // Far
356 _ => 0.05, // Very far
357 }
358 }
359
360 /// Check if entities are mentioned near a relationship pattern
361 ///
362 /// Determines if both entities appear within a 200-character window
363 /// around a relationship keyword or pattern. This helps determine if
364 /// a relationship pattern actually applies to these specific entities.
365 ///
366 /// # Arguments
367 ///
368 /// * `content` - The text content to search
369 /// * `entity_a` - Name of the first entity
370 /// * `entity_b` - Name of the second entity
371 /// * `pattern` - The relationship pattern to search for (e.g., "friend", "enemy")
372 ///
373 /// # Returns
374 ///
375 /// Returns `true` if both entities are found within 100 characters before and
376 /// after the pattern, `false` otherwise.
377 fn entities_near_pattern(
378 &self,
379 content: &str,
380 entity_a: &str,
381 entity_b: &str,
382 pattern: &str,
383 ) -> bool {
384 if let Some(pattern_pos) = content.find(pattern) {
385 let start = pattern_pos.saturating_sub(100); // 100 chars before
386 let end = (pattern_pos + pattern.len() + 100).min(content.len()); // 100 chars after
387 let context = &content[start..end];
388
389 context.contains(entity_a) && context.contains(entity_b)
390 } else {
391 false
392 }
393 }
394
395 /// Find an entity in the knowledge graph by name
396 ///
397 /// Performs a case-insensitive substring search to find an entity whose
398 /// name contains the given search string.
399 ///
400 /// # Arguments
401 ///
402 /// * `knowledge_graph` - The knowledge graph to search
403 /// * `name` - The name (or partial name) to search for
404 ///
405 /// # Returns
406 ///
407 /// Returns `Some(&Entity)` if a matching entity is found, `None` otherwise.
408 pub fn find_entity_by_name<'a>(
409 &self,
410 knowledge_graph: &'a KnowledgeGraph,
411 name: &str,
412 ) -> Option<&'a Entity> {
413 knowledge_graph
414 .entities()
415 .find(|e| e.name.to_lowercase().contains(&name.to_lowercase()))
416 }
417}