Skip to main content

graphrag_core/entity/
semantic_merging.rs

1use crate::{
2    core::{Entity, Result},
3    ollama::OllamaClient,
4};
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
8/// Decision about merging entities
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct EntityMergeDecision {
11    /// Whether the entities should be merged
12    pub should_merge: bool,
13    /// Confidence in the merge decision (0.0-1.0)
14    pub confidence: f64,
15    /// Reasoning for the decision
16    pub reasoning: String,
17    /// Merged entity description if merging
18    pub merged_description: Option<String>,
19    /// Merged entity name if merging
20    pub merged_name: Option<String>,
21}
22
23/// Entity merger using semantic similarity and optional LLM
24#[derive(Clone)]
25pub struct SemanticEntityMerger {
26    llm_client: Option<OllamaClient>,
27    similarity_threshold: f64,
28    max_description_tokens: usize,
29    use_llm_merging: bool,
30}
31
32impl SemanticEntityMerger {
33    /// Create a new semantic entity merger
34    pub fn new(similarity_threshold: f64) -> Self {
35        Self {
36            llm_client: None,
37            similarity_threshold,
38            max_description_tokens: 512,
39            use_llm_merging: false,
40        }
41    }
42
43    /// Add an LLM client for intelligent merging
44    pub fn with_llm_client(mut self, client: OllamaClient) -> Self {
45        self.llm_client = Some(client);
46        self.use_llm_merging = true;
47        self
48    }
49
50    /// Set maximum tokens for entity descriptions
51    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
52        self.max_description_tokens = max_tokens;
53        self
54    }
55
56    /// Group entities by semantic similarity for potential merging
57    pub async fn group_similar_entities(&self, entities: &[Entity]) -> Result<Vec<Vec<Entity>>> {
58        let mut similarity_groups = Vec::new();
59        let mut processed = HashSet::new();
60
61        for (i, entity1) in entities.iter().enumerate() {
62            if processed.contains(&i) {
63                continue;
64            }
65
66            let mut group = vec![entity1.clone()];
67            processed.insert(i);
68
69            // Find similar entities
70            for (j, entity2) in entities.iter().enumerate() {
71                if i == j || processed.contains(&j) {
72                    continue;
73                }
74
75                let similarity = self.calculate_semantic_similarity(entity1, entity2).await?;
76                if similarity > self.similarity_threshold {
77                    group.push(entity2.clone());
78                    processed.insert(j);
79                }
80            }
81
82            if group.len() > 1 {
83                similarity_groups.push(group);
84            }
85        }
86
87        Ok(similarity_groups)
88    }
89
90    /// Use LLM to decide if entities should be merged and how
91    pub async fn decide_merge(&self, entity_group: &[Entity]) -> Result<EntityMergeDecision> {
92        if !self.use_llm_merging {
93            // Fallback to simple heuristic-based merging
94            return Ok(self.heuristic_merge_decision(entity_group));
95        }
96
97        if let Some(llm_client) = &self.llm_client {
98            let prompt = self.build_merge_decision_prompt(entity_group);
99
100            // Try to get structured response from LLM
101            match self.try_llm_merge_decision(llm_client, &prompt).await {
102                Ok(decision) => Ok(decision),
103                Err(_) => {
104                    #[cfg(feature = "tracing")]
105                    tracing::warn!("LLM merge decision failed, falling back to heuristics");
106                    Ok(self.heuristic_merge_decision(entity_group))
107                },
108            }
109        } else {
110            Ok(self.heuristic_merge_decision(entity_group))
111        }
112    }
113
114    async fn try_llm_merge_decision(
115        &self,
116        _llm_client: &OllamaClient,
117        prompt: &str,
118    ) -> Result<EntityMergeDecision> {
119        // For now, simulate an LLM response with a simple heuristic
120        // In a real implementation, this would call the actual LLM
121        let _response = prompt; // Placeholder
122
123        // Simple heuristic for now since we don't have actual LLM integration
124        Ok(EntityMergeDecision {
125            should_merge: true,
126            confidence: 0.8,
127            reasoning: "LLM analysis suggests these entities should be merged".to_string(),
128            merged_name: Some("Merged Entity".to_string()),
129            merged_description: Some("Merged based on LLM analysis".to_string()),
130        })
131    }
132
133    fn heuristic_merge_decision(&self, entity_group: &[Entity]) -> EntityMergeDecision {
134        if entity_group.len() < 2 {
135            return EntityMergeDecision {
136                should_merge: false,
137                confidence: 1.0,
138                reasoning: "Only one entity in group".to_string(),
139                merged_name: None,
140                merged_description: None,
141            };
142        }
143
144        // Simple heuristic: merge if names are very similar and types match
145        let first_entity = &entity_group[0];
146        let all_same_type = entity_group
147            .iter()
148            .all(|e| e.entity_type == first_entity.entity_type);
149
150        if all_same_type {
151            let name_similarity = self.calculate_name_similarity_heuristic(entity_group);
152
153            if name_similarity > 0.8 {
154                let merged_name = self.select_best_name(entity_group);
155                let merged_description = self.combine_descriptions(entity_group);
156
157                EntityMergeDecision {
158                    should_merge: true,
159                    confidence: name_similarity,
160                    reasoning: format!(
161                        "High name similarity ({name_similarity:.2}) and matching types"
162                    ),
163                    merged_name: Some(merged_name),
164                    merged_description: Some(merged_description),
165                }
166            } else {
167                EntityMergeDecision {
168                    should_merge: false,
169                    confidence: 1.0 - name_similarity,
170                    reasoning: format!("Low name similarity ({name_similarity:.2})"),
171                    merged_name: None,
172                    merged_description: None,
173                }
174            }
175        } else {
176            EntityMergeDecision {
177                should_merge: false,
178                confidence: 1.0,
179                reasoning: "Different entity types".to_string(),
180                merged_name: None,
181                merged_description: None,
182            }
183        }
184    }
185
186    fn calculate_name_similarity_heuristic(&self, entities: &[Entity]) -> f64 {
187        if entities.len() < 2 {
188            return 1.0;
189        }
190
191        let mut total_similarity = 0.0;
192        let mut comparisons = 0;
193
194        for i in 0..entities.len() {
195            for j in i + 1..entities.len() {
196                let similarity = self.string_similarity(&entities[i].name, &entities[j].name);
197                total_similarity += similarity;
198                comparisons += 1;
199            }
200        }
201
202        if comparisons > 0 {
203            total_similarity / comparisons as f64
204        } else {
205            0.0
206        }
207    }
208
209    fn string_similarity(&self, s1: &str, s2: &str) -> f64 {
210        let s1_lower = s1.to_lowercase();
211        let s2_lower = s2.to_lowercase();
212
213        // Exact match
214        if s1_lower == s2_lower {
215            return 1.0;
216        }
217
218        // One contains the other
219        if s1_lower.contains(&s2_lower) || s2_lower.contains(&s1_lower) {
220            return 0.9;
221        }
222
223        // Jaccard similarity on words
224        let words1: HashSet<&str> = s1_lower.split_whitespace().collect();
225        let words2: HashSet<&str> = s2_lower.split_whitespace().collect();
226
227        let intersection = words1.intersection(&words2).count();
228        let union = words1.union(&words2).count();
229
230        if union == 0 {
231            0.0
232        } else {
233            intersection as f64 / union as f64
234        }
235    }
236
237    fn select_best_name(&self, entities: &[Entity]) -> String {
238        // Select the longest name or the one with highest confidence
239        entities
240            .iter()
241            .max_by(|a, b| {
242                let length_cmp = a.name.len().cmp(&b.name.len());
243                if length_cmp == std::cmp::Ordering::Equal {
244                    a.confidence
245                        .partial_cmp(&b.confidence)
246                        .unwrap_or(std::cmp::Ordering::Equal)
247                } else {
248                    length_cmp
249                }
250            })
251            .map(|e| e.name.clone())
252            .unwrap_or_else(|| "Merged Entity".to_string())
253    }
254
255    fn combine_descriptions(&self, entities: &[Entity]) -> String {
256        let descriptions: Vec<String> = entities
257            .iter()
258            .map(|e| {
259                if let Some(_desc) = e.mentions.first() {
260                    format!("Entity '{}' mentioned in context", e.name)
261                } else {
262                    format!("Entity '{}' of type {}", e.name, e.entity_type)
263                }
264            })
265            .collect();
266
267        if descriptions.is_empty() {
268            "Merged entity from multiple sources".to_string()
269        } else {
270            descriptions.join("; ")
271        }
272    }
273
274    fn build_merge_decision_prompt(&self, entities: &[Entity]) -> String {
275        let mut prompt = String::from(
276            "Analyze the following entities and determine if they represent the same real-world entity:\n\n"
277        );
278
279        for (i, entity) in entities.iter().enumerate() {
280            let description = if entity.mentions.is_empty() {
281                "No description".to_string()
282            } else {
283                format!("Mentioned {} times", entity.mentions.len())
284            };
285
286            prompt.push_str(&format!(
287                "Entity {}: {}\n  Type: {}\n  Confidence: {:.2}\n  Description: {}\n\n",
288                i + 1,
289                entity.name,
290                entity.entity_type,
291                entity.confidence,
292                description
293            ));
294        }
295
296        prompt.push_str(
297            "Consider:\n\
298             1. Are these entities referring to the same real-world entity?\n\
299             2. Do they have compatible descriptions and contexts?\n\
300             3. If merged, what would be the best combined name and description?\n\n\
301             Respond with 'YES' if they should be merged, 'NO' if they should remain separate.\n\
302             Briefly explain your reasoning.",
303        );
304
305        prompt
306    }
307
308    async fn calculate_semantic_similarity(
309        &self,
310        entity1: &Entity,
311        entity2: &Entity,
312    ) -> Result<f64> {
313        // For now, use string-based similarity
314        // In a real implementation with embeddings, this would use cosine similarity
315
316        // Check name similarity
317        let name_sim = self.string_similarity(&entity1.name, &entity2.name);
318
319        // Check type compatibility
320        let type_sim = if entity1.entity_type == entity2.entity_type {
321            1.0
322        } else {
323            0.0
324        };
325
326        // Weighted combination
327        let combined_similarity = name_sim * 0.7 + type_sim * 0.3;
328
329        Ok(combined_similarity)
330    }
331
332    /// Perform the actual entity merging based on decision
333    pub fn merge_entities(
334        &self,
335        entities: Vec<Entity>,
336        decision: &EntityMergeDecision,
337    ) -> Result<Entity> {
338        if entities.is_empty() {
339            return Err(crate::core::GraphRAGError::Config {
340                message: "No entities to merge".to_string(),
341            });
342        }
343
344        if !decision.should_merge {
345            return Ok(entities[0].clone());
346        }
347
348        let merged_name = decision
349            .merged_name
350            .clone()
351            .unwrap_or_else(|| self.select_best_name(&entities));
352
353        // Combine all mentions
354        let mut all_mentions = Vec::new();
355        let mut total_confidence = 0.0;
356
357        for entity in &entities {
358            all_mentions.extend(entity.mentions.clone());
359            total_confidence += entity.confidence;
360        }
361
362        let avg_confidence = if entities.is_empty() {
363            0.0
364        } else {
365            total_confidence / entities.len() as f32
366        };
367
368        // Create merged entity
369        let merged_entity = Entity {
370            id: entities[0].id.clone(), // Keep the first entity's ID
371            name: merged_name,
372            entity_type: entities[0].entity_type.clone(),
373            confidence: avg_confidence.max(decision.confidence as f32),
374            mentions: all_mentions,
375            embedding: entities[0].embedding.clone(), // Take first embedding
376            first_mentioned: None,
377            last_mentioned: None,
378            temporal_validity: None,
379        };
380
381        Ok(merged_entity)
382    }
383
384    /// Get merging statistics
385    pub fn get_statistics(&self) -> MergingStatistics {
386        MergingStatistics {
387            similarity_threshold: self.similarity_threshold,
388            max_description_tokens: self.max_description_tokens,
389            uses_llm: self.use_llm_merging,
390            llm_available: self.llm_client.is_some(),
391        }
392    }
393}
394
/// Snapshot of a `SemanticEntityMerger`'s configuration.
#[derive(Debug, Clone)]
pub struct MergingStatistics {
    /// Similarity threshold for merging (0.0-1.0)
    pub similarity_threshold: f64,
    /// Maximum tokens for entity descriptions
    pub max_description_tokens: usize,
    /// Whether LLM is used for merging
    pub uses_llm: bool,
    /// Whether LLM client is available
    pub llm_available: bool,
}

impl MergingStatistics {
    /// Logs the statistics at `INFO` level through `tracing`.
    ///
    /// Nothing is written to stdout. Without the `tracing` feature this
    /// method does nothing.
    // Diagnostics helper for callers; it may have no in-crate users.
    #[allow(dead_code)]
    pub fn print(&self) {
        #[cfg(feature = "tracing")]
        {
            tracing::info!("Entity Merging Statistics");
            tracing::info!("  Similarity threshold: {:.2}", self.similarity_threshold);
            tracing::info!("  Max description tokens: {}", self.max_description_tokens);
            tracing::info!("  Uses LLM: {}", self.uses_llm);
            tracing::info!("  LLM available: {}", self.llm_available);
        }
    }
}
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, EntityId, EntityMention};

    /// Two near-duplicate "Apple" organizations followed by an unrelated
    /// "Microsoft" organization.
    fn create_test_entities() -> Vec<Entity> {
        [
            ("entity1", "Apple Inc", 0.9),
            ("entity2", "Apple Inc.", 0.8),
            ("entity3", "Microsoft", 0.9),
        ]
        .iter()
        .map(|&(id, name, confidence)| {
            Entity::new(
                EntityId::new(id.to_string()),
                name.to_string(),
                "ORGANIZATION".to_string(),
                confidence,
            )
        })
        .collect()
    }

    /// Just the two "Apple" entities from [`create_test_entities`].
    fn apple_entities() -> Vec<Entity> {
        create_test_entities().into_iter().take(2).collect()
    }

    #[test]
    fn test_semantic_entity_merger_creation() {
        let stats = SemanticEntityMerger::new(0.8).get_statistics();

        assert_eq!(stats.similarity_threshold, 0.8);
        assert!(!stats.uses_llm);
        assert!(!stats.llm_available);
    }

    #[tokio::test]
    async fn test_entity_grouping() {
        let merger = SemanticEntityMerger::new(0.7);
        let groups = merger
            .group_similar_entities(&create_test_entities())
            .await
            .unwrap();

        // The two Apple entities should land in the same group.
        assert!(!groups.is_empty());
        let apple_group = groups
            .iter()
            .find(|group| group.iter().any(|e| e.name.contains("Apple")))
            .expect("expected a group containing the Apple entities");
        assert_eq!(apple_group.len(), 2); // Apple Inc and Apple Inc.
    }

    #[test]
    fn test_heuristic_merge_decision() {
        let merger = SemanticEntityMerger::new(0.8);
        let decision = merger.heuristic_merge_decision(&apple_entities());

        assert!(decision.should_merge);
        assert!(decision.confidence > 0.8);
        assert!(decision.merged_name.is_some());
    }

    #[test]
    fn test_string_similarity() {
        let merger = SemanticEntityMerger::new(0.8);

        assert_eq!(merger.string_similarity("Apple", "Apple"), 1.0);
        assert!(merger.string_similarity("Apple Inc", "Apple Inc.") > 0.8);
        assert!(merger.string_similarity("Apple", "Microsoft") < 0.3);
    }

    #[test]
    fn test_entity_merging() {
        let merger = SemanticEntityMerger::new(0.8);

        let mentions = vec![
            EntityMention {
                chunk_id: ChunkId::new("chunk1".to_string()),
                start_offset: 0,
                end_offset: 9,
                confidence: 0.9,
            },
            EntityMention {
                chunk_id: ChunkId::new("chunk2".to_string()),
                start_offset: 0,
                end_offset: 10,
                confidence: 0.8,
            },
        ];
        let entities: Vec<Entity> = apple_entities()
            .into_iter()
            .zip(mentions)
            .map(|(entity, mention)| entity.with_mentions(vec![mention]))
            .collect();

        let decision = EntityMergeDecision {
            should_merge: true,
            confidence: 0.9,
            reasoning: "Test merge".to_string(),
            merged_name: Some("Apple Inc.".to_string()),
            merged_description: Some("Merged Apple entity".to_string()),
        };

        let merged = merger.merge_entities(entities, &decision).unwrap();

        assert_eq!(merged.name, "Apple Inc.");
        assert_eq!(merged.mentions.len(), 2); // Combined mentions
        assert!(merged.confidence >= 0.8);
    }
}