Skip to main content

graphrag_core/entity/
semantic_merging.rs

1use crate::{
2    core::{Entity, Result},
3    ollama::OllamaClient,
4};
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
8/// Decision about merging entities
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct EntityMergeDecision {
11    /// Whether the entities should be merged
12    pub should_merge: bool,
13    /// Confidence in the merge decision (0.0-1.0)
14    pub confidence: f64,
15    /// Reasoning for the decision
16    pub reasoning: String,
17    /// Merged entity description if merging
18    pub merged_description: Option<String>,
19    /// Merged entity name if merging
20    pub merged_name: Option<String>,
21}
22
23/// Entity merger using semantic similarity and optional LLM
24#[derive(Clone)]
25pub struct SemanticEntityMerger {
26    llm_client: Option<OllamaClient>,
27    similarity_threshold: f64,
28    max_description_tokens: usize,
29    use_llm_merging: bool,
30}
31
32impl SemanticEntityMerger {
33    /// Create a new semantic entity merger
34    pub fn new(similarity_threshold: f64) -> Self {
35        Self {
36            llm_client: None,
37            similarity_threshold,
38            max_description_tokens: 512,
39            use_llm_merging: false,
40        }
41    }
42
43    /// Add an LLM client for intelligent merging
44    pub fn with_llm_client(mut self, client: OllamaClient) -> Self {
45        self.llm_client = Some(client);
46        self.use_llm_merging = true;
47        self
48    }
49
50    /// Set maximum tokens for entity descriptions
51    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
52        self.max_description_tokens = max_tokens;
53        self
54    }
55
56    /// Group entities by semantic similarity for potential merging
57    pub async fn group_similar_entities(&self, entities: &[Entity]) -> Result<Vec<Vec<Entity>>> {
58        let mut similarity_groups = Vec::new();
59        let mut processed = HashSet::new();
60
61        for (i, entity1) in entities.iter().enumerate() {
62            if processed.contains(&i) {
63                continue;
64            }
65
66            let mut group = vec![entity1.clone()];
67            processed.insert(i);
68
69            // Find similar entities
70            for (j, entity2) in entities.iter().enumerate() {
71                if i == j || processed.contains(&j) {
72                    continue;
73                }
74
75                let similarity = self.calculate_semantic_similarity(entity1, entity2).await?;
76                if similarity > self.similarity_threshold {
77                    group.push(entity2.clone());
78                    processed.insert(j);
79                }
80            }
81
82            if group.len() > 1 {
83                similarity_groups.push(group);
84            }
85        }
86
87        Ok(similarity_groups)
88    }
89
90    /// Use LLM to decide if entities should be merged and how
91    pub async fn decide_merge(&self, entity_group: &[Entity]) -> Result<EntityMergeDecision> {
92        if !self.use_llm_merging {
93            // Fallback to simple heuristic-based merging
94            return Ok(self.heuristic_merge_decision(entity_group));
95        }
96
97        if let Some(llm_client) = &self.llm_client {
98            let prompt = self.build_merge_decision_prompt(entity_group);
99
100            // Try to get structured response from LLM
101            match self.try_llm_merge_decision(llm_client, &prompt).await {
102                Ok(decision) => Ok(decision),
103                Err(_) => {
104                    tracing::warn!("LLM merge decision failed, falling back to heuristics");
105                    Ok(self.heuristic_merge_decision(entity_group))
106                }
107            }
108        } else {
109            Ok(self.heuristic_merge_decision(entity_group))
110        }
111    }
112
113    async fn try_llm_merge_decision(
114        &self,
115        _llm_client: &OllamaClient,
116        prompt: &str,
117    ) -> Result<EntityMergeDecision> {
118        // For now, simulate an LLM response with a simple heuristic
119        // In a real implementation, this would call the actual LLM
120        let _response = prompt; // Placeholder
121
122        // Simple heuristic for now since we don't have actual LLM integration
123        Ok(EntityMergeDecision {
124            should_merge: true,
125            confidence: 0.8,
126            reasoning: "LLM analysis suggests these entities should be merged".to_string(),
127            merged_name: Some("Merged Entity".to_string()),
128            merged_description: Some("Merged based on LLM analysis".to_string()),
129        })
130    }
131
132    fn heuristic_merge_decision(&self, entity_group: &[Entity]) -> EntityMergeDecision {
133        if entity_group.len() < 2 {
134            return EntityMergeDecision {
135                should_merge: false,
136                confidence: 1.0,
137                reasoning: "Only one entity in group".to_string(),
138                merged_name: None,
139                merged_description: None,
140            };
141        }
142
143        // Simple heuristic: merge if names are very similar and types match
144        let first_entity = &entity_group[0];
145        let all_same_type = entity_group
146            .iter()
147            .all(|e| e.entity_type == first_entity.entity_type);
148
149        if all_same_type {
150            let name_similarity = self.calculate_name_similarity_heuristic(entity_group);
151
152            if name_similarity > 0.8 {
153                let merged_name = self.select_best_name(entity_group);
154                let merged_description = self.combine_descriptions(entity_group);
155
156                EntityMergeDecision {
157                    should_merge: true,
158                    confidence: name_similarity,
159                    reasoning: format!(
160                        "High name similarity ({name_similarity:.2}) and matching types"
161                    ),
162                    merged_name: Some(merged_name),
163                    merged_description: Some(merged_description),
164                }
165            } else {
166                EntityMergeDecision {
167                    should_merge: false,
168                    confidence: 1.0 - name_similarity,
169                    reasoning: format!("Low name similarity ({name_similarity:.2})"),
170                    merged_name: None,
171                    merged_description: None,
172                }
173            }
174        } else {
175            EntityMergeDecision {
176                should_merge: false,
177                confidence: 1.0,
178                reasoning: "Different entity types".to_string(),
179                merged_name: None,
180                merged_description: None,
181            }
182        }
183    }
184
185    fn calculate_name_similarity_heuristic(&self, entities: &[Entity]) -> f64 {
186        if entities.len() < 2 {
187            return 1.0;
188        }
189
190        let mut total_similarity = 0.0;
191        let mut comparisons = 0;
192
193        for i in 0..entities.len() {
194            for j in i + 1..entities.len() {
195                let similarity = self.string_similarity(&entities[i].name, &entities[j].name);
196                total_similarity += similarity;
197                comparisons += 1;
198            }
199        }
200
201        if comparisons > 0 {
202            total_similarity / comparisons as f64
203        } else {
204            0.0
205        }
206    }
207
208    fn string_similarity(&self, s1: &str, s2: &str) -> f64 {
209        let s1_lower = s1.to_lowercase();
210        let s2_lower = s2.to_lowercase();
211
212        // Exact match
213        if s1_lower == s2_lower {
214            return 1.0;
215        }
216
217        // One contains the other
218        if s1_lower.contains(&s2_lower) || s2_lower.contains(&s1_lower) {
219            return 0.9;
220        }
221
222        // Jaccard similarity on words
223        let words1: HashSet<&str> = s1_lower.split_whitespace().collect();
224        let words2: HashSet<&str> = s2_lower.split_whitespace().collect();
225
226        let intersection = words1.intersection(&words2).count();
227        let union = words1.union(&words2).count();
228
229        if union == 0 {
230            0.0
231        } else {
232            intersection as f64 / union as f64
233        }
234    }
235
236    fn select_best_name(&self, entities: &[Entity]) -> String {
237        // Select the longest name or the one with highest confidence
238        entities
239            .iter()
240            .max_by(|a, b| {
241                let length_cmp = a.name.len().cmp(&b.name.len());
242                if length_cmp == std::cmp::Ordering::Equal {
243                    a.confidence
244                        .partial_cmp(&b.confidence)
245                        .unwrap_or(std::cmp::Ordering::Equal)
246                } else {
247                    length_cmp
248                }
249            })
250            .map(|e| e.name.clone())
251            .unwrap_or_else(|| "Merged Entity".to_string())
252    }
253
254    fn combine_descriptions(&self, entities: &[Entity]) -> String {
255        let descriptions: Vec<String> = entities
256            .iter()
257            .map(|e| {
258                if let Some(_desc) = e.mentions.first() {
259                    format!("Entity '{}' mentioned in context", e.name)
260                } else {
261                    format!("Entity '{}' of type {}", e.name, e.entity_type)
262                }
263            })
264            .collect();
265
266        if descriptions.is_empty() {
267            "Merged entity from multiple sources".to_string()
268        } else {
269            descriptions.join("; ")
270        }
271    }
272
273    fn build_merge_decision_prompt(&self, entities: &[Entity]) -> String {
274        let mut prompt = String::from(
275            "Analyze the following entities and determine if they represent the same real-world entity:\n\n"
276        );
277
278        for (i, entity) in entities.iter().enumerate() {
279            let description = if entity.mentions.is_empty() {
280                "No description".to_string()
281            } else {
282                format!("Mentioned {} times", entity.mentions.len())
283            };
284
285            prompt.push_str(&format!(
286                "Entity {}: {}\n  Type: {}\n  Confidence: {:.2}\n  Description: {}\n\n",
287                i + 1,
288                entity.name,
289                entity.entity_type,
290                entity.confidence,
291                description
292            ));
293        }
294
295        prompt.push_str(
296            "Consider:\n\
297             1. Are these entities referring to the same real-world entity?\n\
298             2. Do they have compatible descriptions and contexts?\n\
299             3. If merged, what would be the best combined name and description?\n\n\
300             Respond with 'YES' if they should be merged, 'NO' if they should remain separate.\n\
301             Briefly explain your reasoning.",
302        );
303
304        prompt
305    }
306
307    async fn calculate_semantic_similarity(
308        &self,
309        entity1: &Entity,
310        entity2: &Entity,
311    ) -> Result<f64> {
312        // For now, use string-based similarity
313        // In a real implementation with embeddings, this would use cosine similarity
314
315        // Check name similarity
316        let name_sim = self.string_similarity(&entity1.name, &entity2.name);
317
318        // Check type compatibility
319        let type_sim = if entity1.entity_type == entity2.entity_type {
320            1.0
321        } else {
322            0.0
323        };
324
325        // Weighted combination
326        let combined_similarity = name_sim * 0.7 + type_sim * 0.3;
327
328        Ok(combined_similarity)
329    }
330
331    /// Perform the actual entity merging based on decision
332    pub fn merge_entities(
333        &self,
334        entities: Vec<Entity>,
335        decision: &EntityMergeDecision,
336    ) -> Result<Entity> {
337        if entities.is_empty() {
338            return Err(crate::core::GraphRAGError::Config {
339                message: "No entities to merge".to_string(),
340            });
341        }
342
343        if !decision.should_merge {
344            return Ok(entities[0].clone());
345        }
346
347        let merged_name = decision
348            .merged_name
349            .clone()
350            .unwrap_or_else(|| self.select_best_name(&entities));
351
352        // Combine all mentions
353        let mut all_mentions = Vec::new();
354        let mut total_confidence = 0.0;
355
356        for entity in &entities {
357            all_mentions.extend(entity.mentions.clone());
358            total_confidence += entity.confidence;
359        }
360
361        let avg_confidence = if entities.is_empty() {
362            0.0
363        } else {
364            total_confidence / entities.len() as f32
365        };
366
367        // Create merged entity
368        let merged_entity = Entity {
369            id: entities[0].id.clone(), // Keep the first entity's ID
370            name: merged_name,
371            entity_type: entities[0].entity_type.clone(),
372            confidence: avg_confidence.max(decision.confidence as f32),
373            mentions: all_mentions,
374            embedding: entities[0].embedding.clone(), // Take first embedding
375        };
376
377        Ok(merged_entity)
378    }
379
380    /// Get merging statistics
381    pub fn get_statistics(&self) -> MergingStatistics {
382        MergingStatistics {
383            similarity_threshold: self.similarity_threshold,
384            max_description_tokens: self.max_description_tokens,
385            uses_llm: self.use_llm_merging,
386            llm_available: self.llm_client.is_some(),
387        }
388    }
389}
390
/// Snapshot of an entity merger's configuration, as returned by
/// `SemanticEntityMerger::get_statistics`.
///
/// All fields are plain values, so the snapshot is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MergingStatistics {
    /// Similarity threshold for merging (0.0-1.0)
    pub similarity_threshold: f64,
    /// Maximum tokens for entity descriptions
    pub max_description_tokens: usize,
    /// Whether LLM-assisted merging is enabled
    pub uses_llm: bool,
    /// Whether an LLM client is attached
    pub llm_available: bool,
}
403
404impl MergingStatistics {
405    /// Print statistics to stdout
406    #[allow(dead_code)]
407    pub fn print(&self) {
408        tracing::info!("Entity Merging Statistics");
409        tracing::info!("  Similarity threshold: {:.2}", self.similarity_threshold);
410        tracing::info!("  Max description tokens: {}", self.max_description_tokens);
411        tracing::info!("  Uses LLM: {}", self.uses_llm);
412        tracing::info!("  LLM available: {}", self.llm_available);
413    }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, EntityId, EntityMention};

    /// Two near-duplicate "Apple" organizations plus an unrelated "Microsoft".
    fn create_test_entities() -> Vec<Entity> {
        vec![
            Entity::new(
                EntityId::new("entity1".to_string()),
                "Apple Inc".to_string(),
                "ORGANIZATION".to_string(),
                0.9,
            ),
            Entity::new(
                EntityId::new("entity2".to_string()),
                "Apple Inc.".to_string(),
                "ORGANIZATION".to_string(),
                0.8,
            ),
            Entity::new(
                EntityId::new("entity3".to_string()),
                "Microsoft".to_string(),
                "ORGANIZATION".to_string(),
                0.9,
            ),
        ]
    }

    /// A minimal decision with no merged name or description.
    fn bare_decision(should_merge: bool) -> EntityMergeDecision {
        EntityMergeDecision {
            should_merge,
            confidence: 1.0,
            reasoning: String::new(),
            merged_name: None,
            merged_description: None,
        }
    }

    #[test]
    fn test_semantic_entity_merger_creation() {
        let merger = SemanticEntityMerger::new(0.8);
        let stats = merger.get_statistics();

        assert_eq!(stats.similarity_threshold, 0.8);
        assert!(!stats.uses_llm);
        assert!(!stats.llm_available);
    }

    #[tokio::test]
    async fn test_entity_grouping() {
        let merger = SemanticEntityMerger::new(0.7);
        let entities = create_test_entities();

        let groups = merger.group_similar_entities(&entities).await.unwrap();

        // Only the two Apple entities are similar enough; Microsoft stays alone and
        // singleton groups are not returned.
        assert_eq!(groups.len(), 1);

        // Find the Apple group
        let apple_group = groups
            .iter()
            .find(|group| group.iter().any(|e| e.name.contains("Apple")));

        assert!(apple_group.is_some());
        let apple_group = apple_group.unwrap();
        assert_eq!(apple_group.len(), 2); // Apple Inc and Apple Inc.
    }

    #[tokio::test]
    async fn test_grouping_respects_threshold() {
        // "Apple Inc" vs "Apple Inc." scores 0.9 * 0.7 + 0.3 = 0.93, below 0.95.
        let merger = SemanticEntityMerger::new(0.95);
        let groups = merger
            .group_similar_entities(&create_test_entities())
            .await
            .unwrap();

        assert!(groups.is_empty());
    }

    #[test]
    fn test_heuristic_merge_decision() {
        let merger = SemanticEntityMerger::new(0.8);
        let entities = vec![
            Entity::new(
                EntityId::new("entity1".to_string()),
                "Apple Inc".to_string(),
                "ORGANIZATION".to_string(),
                0.9,
            ),
            Entity::new(
                EntityId::new("entity2".to_string()),
                "Apple Inc.".to_string(),
                "ORGANIZATION".to_string(),
                0.8,
            ),
        ];

        let decision = merger.heuristic_merge_decision(&entities);

        assert!(decision.should_merge);
        assert!(decision.confidence > 0.8);
        assert!(decision.merged_name.is_some());
    }

    #[test]
    fn test_heuristic_rejects_single_entity() {
        let merger = SemanticEntityMerger::new(0.8);
        let entities = create_test_entities();

        let decision = merger.heuristic_merge_decision(&entities[..1]);

        assert!(!decision.should_merge);
        assert!(decision.merged_name.is_none());
    }

    #[test]
    fn test_heuristic_rejects_mixed_types() {
        let merger = SemanticEntityMerger::new(0.8);
        let entities = vec![
            Entity::new(
                EntityId::new("entity1".to_string()),
                "Apple".to_string(),
                "ORGANIZATION".to_string(),
                0.9,
            ),
            Entity::new(
                EntityId::new("entity2".to_string()),
                "Apple".to_string(),
                "PRODUCT".to_string(),
                0.9,
            ),
        ];

        let decision = merger.heuristic_merge_decision(&entities);

        assert!(!decision.should_merge);
        assert_eq!(decision.reasoning, "Different entity types");
    }

    #[tokio::test]
    async fn test_decide_merge_without_llm_uses_heuristics() {
        let merger = SemanticEntityMerger::new(0.8);
        let entities = create_test_entities();

        let decision = merger.decide_merge(&entities[..2]).await.unwrap();

        assert!(decision.should_merge);
        // The longer of "Apple Inc" / "Apple Inc." is chosen.
        assert_eq!(decision.merged_name.as_deref(), Some("Apple Inc."));
    }

    #[test]
    fn test_string_similarity() {
        let merger = SemanticEntityMerger::new(0.8);

        assert_eq!(merger.string_similarity("Apple", "Apple"), 1.0);
        assert_eq!(merger.string_similarity("APPLE inc", "apple INC"), 1.0);
        assert!(merger.string_similarity("Apple Inc", "Apple Inc.") > 0.8);
        assert!(merger.string_similarity("Apple", "Microsoft") < 0.3);
    }

    #[test]
    fn test_merge_entities_rejects_empty_input() {
        let merger = SemanticEntityMerger::new(0.8);

        assert!(merger
            .merge_entities(Vec::new(), &bare_decision(true))
            .is_err());
    }

    #[test]
    fn test_merge_entities_without_merge_returns_first() {
        let merger = SemanticEntityMerger::new(0.8);

        let kept = merger
            .merge_entities(create_test_entities(), &bare_decision(false))
            .unwrap();

        assert_eq!(kept.name, "Apple Inc");
    }

    #[test]
    fn test_entity_merging() {
        let merger = SemanticEntityMerger::new(0.8);

        let entities = vec![
            Entity::new(
                EntityId::new("entity1".to_string()),
                "Apple Inc".to_string(),
                "ORGANIZATION".to_string(),
                0.9,
            )
            .with_mentions(vec![EntityMention {
                chunk_id: ChunkId::new("chunk1".to_string()),
                start_offset: 0,
                end_offset: 9,
                confidence: 0.9,
            }]),
            Entity::new(
                EntityId::new("entity2".to_string()),
                "Apple Inc.".to_string(),
                "ORGANIZATION".to_string(),
                0.8,
            )
            .with_mentions(vec![EntityMention {
                chunk_id: ChunkId::new("chunk2".to_string()),
                start_offset: 0,
                end_offset: 10,
                confidence: 0.8,
            }]),
        ];

        let decision = EntityMergeDecision {
            should_merge: true,
            confidence: 0.9,
            reasoning: "Test merge".to_string(),
            merged_name: Some("Apple Inc.".to_string()),
            merged_description: Some("Merged Apple entity".to_string()),
        };

        let merged = merger.merge_entities(entities, &decision).unwrap();

        assert_eq!(merged.name, "Apple Inc.");
        assert_eq!(merged.mentions.len(), 2); // Combined mentions
        assert!(merged.confidence >= 0.8);
    }
}