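//! graphrag_core/entity/semantic_merging.rs
//!
//! Semantic entity merging: groups similar entities and merges them,
//! optionally consulting an LLM for the merge decision.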

1use crate::{
2    core::{Entity, Result},
3    ollama::OllamaClient,
4};
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
8/// Decision about merging entities
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct EntityMergeDecision {
11    /// Whether the entities should be merged
12    pub should_merge: bool,
13    /// Confidence in the merge decision (0.0-1.0)
14    pub confidence: f64,
15    /// Reasoning for the decision
16    pub reasoning: String,
17    /// Merged entity description if merging
18    pub merged_description: Option<String>,
19    /// Merged entity name if merging
20    pub merged_name: Option<String>,
21}
22
23/// Entity merger using semantic similarity and optional LLM
24#[derive(Clone)]
25pub struct SemanticEntityMerger {
26    llm_client: Option<OllamaClient>,
27    similarity_threshold: f64,
28    max_description_tokens: usize,
29    use_llm_merging: bool,
30}
31
32impl SemanticEntityMerger {
33    /// Create a new semantic entity merger
34    pub fn new(similarity_threshold: f64) -> Self {
35        Self {
36            llm_client: None,
37            similarity_threshold,
38            max_description_tokens: 512,
39            use_llm_merging: false,
40        }
41    }
42
43    /// Add an LLM client for intelligent merging
44    pub fn with_llm_client(mut self, client: OllamaClient) -> Self {
45        self.llm_client = Some(client);
46        self.use_llm_merging = true;
47        self
48    }
49
50    /// Set maximum tokens for entity descriptions
51    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
52        self.max_description_tokens = max_tokens;
53        self
54    }
55
56    /// Group entities by semantic similarity for potential merging
57    pub async fn group_similar_entities(&self, entities: &[Entity]) -> Result<Vec<Vec<Entity>>> {
58        let mut similarity_groups = Vec::new();
59        let mut processed = HashSet::new();
60
61        for (i, entity1) in entities.iter().enumerate() {
62            if processed.contains(&i) {
63                continue;
64            }
65
66            let mut group = vec![entity1.clone()];
67            processed.insert(i);
68
69            // Find similar entities
70            for (j, entity2) in entities.iter().enumerate() {
71                if i == j || processed.contains(&j) {
72                    continue;
73                }
74
75                let similarity = self.calculate_semantic_similarity(entity1, entity2).await?;
76                if similarity > self.similarity_threshold {
77                    group.push(entity2.clone());
78                    processed.insert(j);
79                }
80            }
81
82            if group.len() > 1 {
83                similarity_groups.push(group);
84            }
85        }
86
87        Ok(similarity_groups)
88    }
89
90    /// Use LLM to decide if entities should be merged and how
91    pub async fn decide_merge(&self, entity_group: &[Entity]) -> Result<EntityMergeDecision> {
92        if !self.use_llm_merging {
93            // Fallback to simple heuristic-based merging
94            return Ok(self.heuristic_merge_decision(entity_group));
95        }
96
97        if let Some(llm_client) = &self.llm_client {
98            let prompt = self.build_merge_decision_prompt(entity_group);
99
100            // Try to get structured response from LLM
101            match self.try_llm_merge_decision(llm_client, &prompt).await {
102                Ok(decision) => Ok(decision),
103                Err(_) => {
104                    tracing::warn!("LLM merge decision failed, falling back to heuristics");
105                    Ok(self.heuristic_merge_decision(entity_group))
106                },
107            }
108        } else {
109            Ok(self.heuristic_merge_decision(entity_group))
110        }
111    }
112
113    async fn try_llm_merge_decision(
114        &self,
115        _llm_client: &OllamaClient,
116        prompt: &str,
117    ) -> Result<EntityMergeDecision> {
118        // For now, simulate an LLM response with a simple heuristic
119        // In a real implementation, this would call the actual LLM
120        let _response = prompt; // Placeholder
121
122        // Simple heuristic for now since we don't have actual LLM integration
123        Ok(EntityMergeDecision {
124            should_merge: true,
125            confidence: 0.8,
126            reasoning: "LLM analysis suggests these entities should be merged".to_string(),
127            merged_name: Some("Merged Entity".to_string()),
128            merged_description: Some("Merged based on LLM analysis".to_string()),
129        })
130    }
131
132    fn heuristic_merge_decision(&self, entity_group: &[Entity]) -> EntityMergeDecision {
133        if entity_group.len() < 2 {
134            return EntityMergeDecision {
135                should_merge: false,
136                confidence: 1.0,
137                reasoning: "Only one entity in group".to_string(),
138                merged_name: None,
139                merged_description: None,
140            };
141        }
142
143        // Simple heuristic: merge if names are very similar and types match
144        let first_entity = &entity_group[0];
145        let all_same_type = entity_group
146            .iter()
147            .all(|e| e.entity_type == first_entity.entity_type);
148
149        if all_same_type {
150            let name_similarity = self.calculate_name_similarity_heuristic(entity_group);
151
152            if name_similarity > 0.8 {
153                let merged_name = self.select_best_name(entity_group);
154                let merged_description = self.combine_descriptions(entity_group);
155
156                EntityMergeDecision {
157                    should_merge: true,
158                    confidence: name_similarity,
159                    reasoning: format!(
160                        "High name similarity ({name_similarity:.2}) and matching types"
161                    ),
162                    merged_name: Some(merged_name),
163                    merged_description: Some(merged_description),
164                }
165            } else {
166                EntityMergeDecision {
167                    should_merge: false,
168                    confidence: 1.0 - name_similarity,
169                    reasoning: format!("Low name similarity ({name_similarity:.2})"),
170                    merged_name: None,
171                    merged_description: None,
172                }
173            }
174        } else {
175            EntityMergeDecision {
176                should_merge: false,
177                confidence: 1.0,
178                reasoning: "Different entity types".to_string(),
179                merged_name: None,
180                merged_description: None,
181            }
182        }
183    }
184
185    fn calculate_name_similarity_heuristic(&self, entities: &[Entity]) -> f64 {
186        if entities.len() < 2 {
187            return 1.0;
188        }
189
190        let mut total_similarity = 0.0;
191        let mut comparisons = 0;
192
193        for i in 0..entities.len() {
194            for j in i + 1..entities.len() {
195                let similarity = self.string_similarity(&entities[i].name, &entities[j].name);
196                total_similarity += similarity;
197                comparisons += 1;
198            }
199        }
200
201        if comparisons > 0 {
202            total_similarity / comparisons as f64
203        } else {
204            0.0
205        }
206    }
207
208    fn string_similarity(&self, s1: &str, s2: &str) -> f64 {
209        let s1_lower = s1.to_lowercase();
210        let s2_lower = s2.to_lowercase();
211
212        // Exact match
213        if s1_lower == s2_lower {
214            return 1.0;
215        }
216
217        // One contains the other
218        if s1_lower.contains(&s2_lower) || s2_lower.contains(&s1_lower) {
219            return 0.9;
220        }
221
222        // Jaccard similarity on words
223        let words1: HashSet<&str> = s1_lower.split_whitespace().collect();
224        let words2: HashSet<&str> = s2_lower.split_whitespace().collect();
225
226        let intersection = words1.intersection(&words2).count();
227        let union = words1.union(&words2).count();
228
229        if union == 0 {
230            0.0
231        } else {
232            intersection as f64 / union as f64
233        }
234    }
235
236    fn select_best_name(&self, entities: &[Entity]) -> String {
237        // Select the longest name or the one with highest confidence
238        entities
239            .iter()
240            .max_by(|a, b| {
241                let length_cmp = a.name.len().cmp(&b.name.len());
242                if length_cmp == std::cmp::Ordering::Equal {
243                    a.confidence
244                        .partial_cmp(&b.confidence)
245                        .unwrap_or(std::cmp::Ordering::Equal)
246                } else {
247                    length_cmp
248                }
249            })
250            .map(|e| e.name.clone())
251            .unwrap_or_else(|| "Merged Entity".to_string())
252    }
253
254    fn combine_descriptions(&self, entities: &[Entity]) -> String {
255        let descriptions: Vec<String> = entities
256            .iter()
257            .map(|e| {
258                if let Some(_desc) = e.mentions.first() {
259                    format!("Entity '{}' mentioned in context", e.name)
260                } else {
261                    format!("Entity '{}' of type {}", e.name, e.entity_type)
262                }
263            })
264            .collect();
265
266        if descriptions.is_empty() {
267            "Merged entity from multiple sources".to_string()
268        } else {
269            descriptions.join("; ")
270        }
271    }
272
273    fn build_merge_decision_prompt(&self, entities: &[Entity]) -> String {
274        let mut prompt = String::from(
275            "Analyze the following entities and determine if they represent the same real-world entity:\n\n"
276        );
277
278        for (i, entity) in entities.iter().enumerate() {
279            let description = if entity.mentions.is_empty() {
280                "No description".to_string()
281            } else {
282                format!("Mentioned {} times", entity.mentions.len())
283            };
284
285            prompt.push_str(&format!(
286                "Entity {}: {}\n  Type: {}\n  Confidence: {:.2}\n  Description: {}\n\n",
287                i + 1,
288                entity.name,
289                entity.entity_type,
290                entity.confidence,
291                description
292            ));
293        }
294
295        prompt.push_str(
296            "Consider:\n\
297             1. Are these entities referring to the same real-world entity?\n\
298             2. Do they have compatible descriptions and contexts?\n\
299             3. If merged, what would be the best combined name and description?\n\n\
300             Respond with 'YES' if they should be merged, 'NO' if they should remain separate.\n\
301             Briefly explain your reasoning.",
302        );
303
304        prompt
305    }
306
307    async fn calculate_semantic_similarity(
308        &self,
309        entity1: &Entity,
310        entity2: &Entity,
311    ) -> Result<f64> {
312        // For now, use string-based similarity
313        // In a real implementation with embeddings, this would use cosine similarity
314
315        // Check name similarity
316        let name_sim = self.string_similarity(&entity1.name, &entity2.name);
317
318        // Check type compatibility
319        let type_sim = if entity1.entity_type == entity2.entity_type {
320            1.0
321        } else {
322            0.0
323        };
324
325        // Weighted combination
326        let combined_similarity = name_sim * 0.7 + type_sim * 0.3;
327
328        Ok(combined_similarity)
329    }
330
331    /// Perform the actual entity merging based on decision
332    pub fn merge_entities(
333        &self,
334        entities: Vec<Entity>,
335        decision: &EntityMergeDecision,
336    ) -> Result<Entity> {
337        if entities.is_empty() {
338            return Err(crate::core::GraphRAGError::Config {
339                message: "No entities to merge".to_string(),
340            });
341        }
342
343        if !decision.should_merge {
344            return Ok(entities[0].clone());
345        }
346
347        let merged_name = decision
348            .merged_name
349            .clone()
350            .unwrap_or_else(|| self.select_best_name(&entities));
351
352        // Combine all mentions
353        let mut all_mentions = Vec::new();
354        let mut total_confidence = 0.0;
355
356        for entity in &entities {
357            all_mentions.extend(entity.mentions.clone());
358            total_confidence += entity.confidence;
359        }
360
361        let avg_confidence = if entities.is_empty() {
362            0.0
363        } else {
364            total_confidence / entities.len() as f32
365        };
366
367        // Create merged entity
368        let merged_entity = Entity {
369            id: entities[0].id.clone(), // Keep the first entity's ID
370            name: merged_name,
371            entity_type: entities[0].entity_type.clone(),
372            confidence: avg_confidence.max(decision.confidence as f32),
373            mentions: all_mentions,
374            embedding: entities[0].embedding.clone(), // Take first embedding
375            first_mentioned: None,
376            last_mentioned: None,
377            temporal_validity: None,
378        };
379
380        Ok(merged_entity)
381    }
382
383    /// Get merging statistics
384    pub fn get_statistics(&self) -> MergingStatistics {
385        MergingStatistics {
386            similarity_threshold: self.similarity_threshold,
387            max_description_tokens: self.max_description_tokens,
388            uses_llm: self.use_llm_merging,
389            llm_available: self.llm_client.is_some(),
390        }
391    }
392}
393
/// Snapshot of an entity merger's configuration.
#[derive(Clone, Debug)]
pub struct MergingStatistics {
    /// Similarity score above which entities are grouped (0.0-1.0).
    pub similarity_threshold: f64,
    /// Token budget for entity descriptions.
    pub max_description_tokens: usize,
    /// `true` when LLM-based merge decisions are enabled.
    pub uses_llm: bool,
    /// `true` when an LLM client is attached.
    pub llm_available: bool,
}
406
407impl MergingStatistics {
408    /// Print statistics to stdout
409    #[allow(dead_code)]
410    pub fn print(&self) {
411        tracing::info!("Entity Merging Statistics");
412        tracing::info!("  Similarity threshold: {:.2}", self.similarity_threshold);
413        tracing::info!("  Max description tokens: {}", self.max_description_tokens);
414        tracing::info!("  Uses LLM: {}", self.uses_llm);
415        tracing::info!("  LLM available: {}", self.llm_available);
416    }
417}
418
419#[cfg(test)]
420mod tests {
421    use super::*;
422    use crate::core::{ChunkId, EntityId, EntityMention};
423
424    fn create_test_entities() -> Vec<Entity> {
425        vec![
426            Entity::new(
427                EntityId::new("entity1".to_string()),
428                "Apple Inc".to_string(),
429                "ORGANIZATION".to_string(),
430                0.9,
431            ),
432            Entity::new(
433                EntityId::new("entity2".to_string()),
434                "Apple Inc.".to_string(),
435                "ORGANIZATION".to_string(),
436                0.8,
437            ),
438            Entity::new(
439                EntityId::new("entity3".to_string()),
440                "Microsoft".to_string(),
441                "ORGANIZATION".to_string(),
442                0.9,
443            ),
444        ]
445    }
446
447    #[test]
448    fn test_semantic_entity_merger_creation() {
449        let merger = SemanticEntityMerger::new(0.8);
450        let stats = merger.get_statistics();
451
452        assert_eq!(stats.similarity_threshold, 0.8);
453        assert!(!stats.uses_llm);
454        assert!(!stats.llm_available);
455    }
456
457    #[tokio::test]
458    async fn test_entity_grouping() {
459        let merger = SemanticEntityMerger::new(0.7);
460        let entities = create_test_entities();
461
462        let groups = merger.group_similar_entities(&entities).await.unwrap();
463
464        // Should group Apple entities together
465        assert!(!groups.is_empty());
466
467        // Find the Apple group
468        let apple_group = groups
469            .iter()
470            .find(|group| group.iter().any(|e| e.name.contains("Apple")));
471
472        assert!(apple_group.is_some());
473        let apple_group = apple_group.unwrap();
474        assert_eq!(apple_group.len(), 2); // Apple Inc and Apple Inc.
475    }
476
477    #[test]
478    fn test_heuristic_merge_decision() {
479        let merger = SemanticEntityMerger::new(0.8);
480        let entities = vec![
481            Entity::new(
482                EntityId::new("entity1".to_string()),
483                "Apple Inc".to_string(),
484                "ORGANIZATION".to_string(),
485                0.9,
486            ),
487            Entity::new(
488                EntityId::new("entity2".to_string()),
489                "Apple Inc.".to_string(),
490                "ORGANIZATION".to_string(),
491                0.8,
492            ),
493        ];
494
495        let decision = merger.heuristic_merge_decision(&entities);
496
497        assert!(decision.should_merge);
498        assert!(decision.confidence > 0.8);
499        assert!(decision.merged_name.is_some());
500    }
501
502    #[test]
503    fn test_string_similarity() {
504        let merger = SemanticEntityMerger::new(0.8);
505
506        assert_eq!(merger.string_similarity("Apple", "Apple"), 1.0);
507        assert!(merger.string_similarity("Apple Inc", "Apple Inc.") > 0.8);
508        assert!(merger.string_similarity("Apple", "Microsoft") < 0.3);
509    }
510
511    #[test]
512    fn test_entity_merging() {
513        let merger = SemanticEntityMerger::new(0.8);
514
515        let entities = vec![
516            Entity::new(
517                EntityId::new("entity1".to_string()),
518                "Apple Inc".to_string(),
519                "ORGANIZATION".to_string(),
520                0.9,
521            )
522            .with_mentions(vec![EntityMention {
523                chunk_id: ChunkId::new("chunk1".to_string()),
524                start_offset: 0,
525                end_offset: 9,
526                confidence: 0.9,
527            }]),
528            Entity::new(
529                EntityId::new("entity2".to_string()),
530                "Apple Inc.".to_string(),
531                "ORGANIZATION".to_string(),
532                0.8,
533            )
534            .with_mentions(vec![EntityMention {
535                chunk_id: ChunkId::new("chunk2".to_string()),
536                start_offset: 0,
537                end_offset: 10,
538                confidence: 0.8,
539            }]),
540        ];
541
542        let decision = EntityMergeDecision {
543            should_merge: true,
544            confidence: 0.9,
545            reasoning: "Test merge".to_string(),
546            merged_name: Some("Apple Inc.".to_string()),
547            merged_description: Some("Merged Apple entity".to_string()),
548        };
549
550        let merged = merger.merge_entities(entities, &decision).unwrap();
551
552        assert_eq!(merged.name, "Apple Inc.");
553        assert_eq!(merged.mentions.len(), 2); // Combined mentions
554        assert!(merged.confidence >= 0.8);
555    }
556}