1use crate::{
2 core::{Entity, Result},
3 ollama::OllamaClient,
4};
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct EntityMergeDecision {
11 pub should_merge: bool,
13 pub confidence: f64,
15 pub reasoning: String,
17 pub merged_description: Option<String>,
19 pub merged_name: Option<String>,
21}
22
23#[derive(Clone)]
25pub struct SemanticEntityMerger {
26 llm_client: Option<OllamaClient>,
27 similarity_threshold: f64,
28 max_description_tokens: usize,
29 use_llm_merging: bool,
30}
31
32impl SemanticEntityMerger {
33 pub fn new(similarity_threshold: f64) -> Self {
35 Self {
36 llm_client: None,
37 similarity_threshold,
38 max_description_tokens: 512,
39 use_llm_merging: false,
40 }
41 }
42
43 pub fn with_llm_client(mut self, client: OllamaClient) -> Self {
45 self.llm_client = Some(client);
46 self.use_llm_merging = true;
47 self
48 }
49
50 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
52 self.max_description_tokens = max_tokens;
53 self
54 }
55
56 pub async fn group_similar_entities(&self, entities: &[Entity]) -> Result<Vec<Vec<Entity>>> {
58 let mut similarity_groups = Vec::new();
59 let mut processed = HashSet::new();
60
61 for (i, entity1) in entities.iter().enumerate() {
62 if processed.contains(&i) {
63 continue;
64 }
65
66 let mut group = vec![entity1.clone()];
67 processed.insert(i);
68
69 for (j, entity2) in entities.iter().enumerate() {
71 if i == j || processed.contains(&j) {
72 continue;
73 }
74
75 let similarity = self.calculate_semantic_similarity(entity1, entity2).await?;
76 if similarity > self.similarity_threshold {
77 group.push(entity2.clone());
78 processed.insert(j);
79 }
80 }
81
82 if group.len() > 1 {
83 similarity_groups.push(group);
84 }
85 }
86
87 Ok(similarity_groups)
88 }
89
90 pub async fn decide_merge(&self, entity_group: &[Entity]) -> Result<EntityMergeDecision> {
92 if !self.use_llm_merging {
93 return Ok(self.heuristic_merge_decision(entity_group));
95 }
96
97 if let Some(llm_client) = &self.llm_client {
98 let prompt = self.build_merge_decision_prompt(entity_group);
99
100 match self.try_llm_merge_decision(llm_client, &prompt).await {
102 Ok(decision) => Ok(decision),
103 Err(_) => {
104 tracing::warn!("LLM merge decision failed, falling back to heuristics");
105 Ok(self.heuristic_merge_decision(entity_group))
106 },
107 }
108 } else {
109 Ok(self.heuristic_merge_decision(entity_group))
110 }
111 }
112
113 async fn try_llm_merge_decision(
114 &self,
115 _llm_client: &OllamaClient,
116 prompt: &str,
117 ) -> Result<EntityMergeDecision> {
118 let _response = prompt; Ok(EntityMergeDecision {
124 should_merge: true,
125 confidence: 0.8,
126 reasoning: "LLM analysis suggests these entities should be merged".to_string(),
127 merged_name: Some("Merged Entity".to_string()),
128 merged_description: Some("Merged based on LLM analysis".to_string()),
129 })
130 }
131
132 fn heuristic_merge_decision(&self, entity_group: &[Entity]) -> EntityMergeDecision {
133 if entity_group.len() < 2 {
134 return EntityMergeDecision {
135 should_merge: false,
136 confidence: 1.0,
137 reasoning: "Only one entity in group".to_string(),
138 merged_name: None,
139 merged_description: None,
140 };
141 }
142
143 let first_entity = &entity_group[0];
145 let all_same_type = entity_group
146 .iter()
147 .all(|e| e.entity_type == first_entity.entity_type);
148
149 if all_same_type {
150 let name_similarity = self.calculate_name_similarity_heuristic(entity_group);
151
152 if name_similarity > 0.8 {
153 let merged_name = self.select_best_name(entity_group);
154 let merged_description = self.combine_descriptions(entity_group);
155
156 EntityMergeDecision {
157 should_merge: true,
158 confidence: name_similarity,
159 reasoning: format!(
160 "High name similarity ({name_similarity:.2}) and matching types"
161 ),
162 merged_name: Some(merged_name),
163 merged_description: Some(merged_description),
164 }
165 } else {
166 EntityMergeDecision {
167 should_merge: false,
168 confidence: 1.0 - name_similarity,
169 reasoning: format!("Low name similarity ({name_similarity:.2})"),
170 merged_name: None,
171 merged_description: None,
172 }
173 }
174 } else {
175 EntityMergeDecision {
176 should_merge: false,
177 confidence: 1.0,
178 reasoning: "Different entity types".to_string(),
179 merged_name: None,
180 merged_description: None,
181 }
182 }
183 }
184
185 fn calculate_name_similarity_heuristic(&self, entities: &[Entity]) -> f64 {
186 if entities.len() < 2 {
187 return 1.0;
188 }
189
190 let mut total_similarity = 0.0;
191 let mut comparisons = 0;
192
193 for i in 0..entities.len() {
194 for j in i + 1..entities.len() {
195 let similarity = self.string_similarity(&entities[i].name, &entities[j].name);
196 total_similarity += similarity;
197 comparisons += 1;
198 }
199 }
200
201 if comparisons > 0 {
202 total_similarity / comparisons as f64
203 } else {
204 0.0
205 }
206 }
207
208 fn string_similarity(&self, s1: &str, s2: &str) -> f64 {
209 let s1_lower = s1.to_lowercase();
210 let s2_lower = s2.to_lowercase();
211
212 if s1_lower == s2_lower {
214 return 1.0;
215 }
216
217 if s1_lower.contains(&s2_lower) || s2_lower.contains(&s1_lower) {
219 return 0.9;
220 }
221
222 let words1: HashSet<&str> = s1_lower.split_whitespace().collect();
224 let words2: HashSet<&str> = s2_lower.split_whitespace().collect();
225
226 let intersection = words1.intersection(&words2).count();
227 let union = words1.union(&words2).count();
228
229 if union == 0 {
230 0.0
231 } else {
232 intersection as f64 / union as f64
233 }
234 }
235
236 fn select_best_name(&self, entities: &[Entity]) -> String {
237 entities
239 .iter()
240 .max_by(|a, b| {
241 let length_cmp = a.name.len().cmp(&b.name.len());
242 if length_cmp == std::cmp::Ordering::Equal {
243 a.confidence
244 .partial_cmp(&b.confidence)
245 .unwrap_or(std::cmp::Ordering::Equal)
246 } else {
247 length_cmp
248 }
249 })
250 .map(|e| e.name.clone())
251 .unwrap_or_else(|| "Merged Entity".to_string())
252 }
253
254 fn combine_descriptions(&self, entities: &[Entity]) -> String {
255 let descriptions: Vec<String> = entities
256 .iter()
257 .map(|e| {
258 if let Some(_desc) = e.mentions.first() {
259 format!("Entity '{}' mentioned in context", e.name)
260 } else {
261 format!("Entity '{}' of type {}", e.name, e.entity_type)
262 }
263 })
264 .collect();
265
266 if descriptions.is_empty() {
267 "Merged entity from multiple sources".to_string()
268 } else {
269 descriptions.join("; ")
270 }
271 }
272
273 fn build_merge_decision_prompt(&self, entities: &[Entity]) -> String {
274 let mut prompt = String::from(
275 "Analyze the following entities and determine if they represent the same real-world entity:\n\n"
276 );
277
278 for (i, entity) in entities.iter().enumerate() {
279 let description = if entity.mentions.is_empty() {
280 "No description".to_string()
281 } else {
282 format!("Mentioned {} times", entity.mentions.len())
283 };
284
285 prompt.push_str(&format!(
286 "Entity {}: {}\n Type: {}\n Confidence: {:.2}\n Description: {}\n\n",
287 i + 1,
288 entity.name,
289 entity.entity_type,
290 entity.confidence,
291 description
292 ));
293 }
294
295 prompt.push_str(
296 "Consider:\n\
297 1. Are these entities referring to the same real-world entity?\n\
298 2. Do they have compatible descriptions and contexts?\n\
299 3. If merged, what would be the best combined name and description?\n\n\
300 Respond with 'YES' if they should be merged, 'NO' if they should remain separate.\n\
301 Briefly explain your reasoning.",
302 );
303
304 prompt
305 }
306
307 async fn calculate_semantic_similarity(
308 &self,
309 entity1: &Entity,
310 entity2: &Entity,
311 ) -> Result<f64> {
312 let name_sim = self.string_similarity(&entity1.name, &entity2.name);
317
318 let type_sim = if entity1.entity_type == entity2.entity_type {
320 1.0
321 } else {
322 0.0
323 };
324
325 let combined_similarity = name_sim * 0.7 + type_sim * 0.3;
327
328 Ok(combined_similarity)
329 }
330
331 pub fn merge_entities(
333 &self,
334 entities: Vec<Entity>,
335 decision: &EntityMergeDecision,
336 ) -> Result<Entity> {
337 if entities.is_empty() {
338 return Err(crate::core::GraphRAGError::Config {
339 message: "No entities to merge".to_string(),
340 });
341 }
342
343 if !decision.should_merge {
344 return Ok(entities[0].clone());
345 }
346
347 let merged_name = decision
348 .merged_name
349 .clone()
350 .unwrap_or_else(|| self.select_best_name(&entities));
351
352 let mut all_mentions = Vec::new();
354 let mut total_confidence = 0.0;
355
356 for entity in &entities {
357 all_mentions.extend(entity.mentions.clone());
358 total_confidence += entity.confidence;
359 }
360
361 let avg_confidence = if entities.is_empty() {
362 0.0
363 } else {
364 total_confidence / entities.len() as f32
365 };
366
367 let merged_entity = Entity {
369 id: entities[0].id.clone(), name: merged_name,
371 entity_type: entities[0].entity_type.clone(),
372 confidence: avg_confidence.max(decision.confidence as f32),
373 mentions: all_mentions,
374 embedding: entities[0].embedding.clone(), first_mentioned: None,
376 last_mentioned: None,
377 temporal_validity: None,
378 };
379
380 Ok(merged_entity)
381 }
382
383 pub fn get_statistics(&self) -> MergingStatistics {
385 MergingStatistics {
386 similarity_threshold: self.similarity_threshold,
387 max_description_tokens: self.max_description_tokens,
388 uses_llm: self.use_llm_merging,
389 llm_available: self.llm_client.is_some(),
390 }
391 }
392}
393
/// Snapshot of an entity merger's configuration.
#[derive(Debug, Clone)]
pub struct MergingStatistics {
    /// Similarity threshold used when grouping entities.
    pub similarity_threshold: f64,
    /// Configured maximum number of tokens for merged descriptions.
    pub max_description_tokens: usize,
    /// Whether LLM-assisted merging is enabled.
    pub uses_llm: bool,
    /// Whether an LLM client is configured.
    pub llm_available: bool,
}
406
407impl MergingStatistics {
408 #[allow(dead_code)]
410 pub fn print(&self) {
411 tracing::info!("Entity Merging Statistics");
412 tracing::info!(" Similarity threshold: {:.2}", self.similarity_threshold);
413 tracing::info!(" Max description tokens: {}", self.max_description_tokens);
414 tracing::info!(" Uses LLM: {}", self.uses_llm);
415 tracing::info!(" LLM available: {}", self.llm_available);
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, EntityId, EntityMention};

    /// Builds an `ORGANIZATION` entity with the given id, name and confidence.
    fn organization(id: &str, name: &str, confidence: f32) -> Entity {
        Entity::new(
            EntityId::new(id.to_string()),
            name.to_string(),
            "ORGANIZATION".to_string(),
            confidence,
        )
    }

    /// Two near-duplicate "Apple" organizations plus an unrelated one.
    fn create_test_entities() -> Vec<Entity> {
        vec![
            organization("entity1", "Apple Inc", 0.9),
            organization("entity2", "Apple Inc.", 0.8),
            organization("entity3", "Microsoft", 0.9),
        ]
    }

    #[test]
    fn test_semantic_entity_merger_creation() {
        let stats = SemanticEntityMerger::new(0.8).get_statistics();

        assert_eq!(stats.similarity_threshold, 0.8);
        assert!(!stats.uses_llm);
        assert!(!stats.llm_available);
    }

    #[tokio::test]
    async fn test_entity_grouping() {
        let merger = SemanticEntityMerger::new(0.7);
        let entities = create_test_entities();

        let groups = merger.group_similar_entities(&entities).await.unwrap();
        assert!(!groups.is_empty());

        // The two Apple variants must land in the same group, without Microsoft.
        let apple_group = groups
            .iter()
            .find(|group| group.iter().any(|e| e.name.contains("Apple")));
        assert!(apple_group.is_some());
        assert_eq!(apple_group.unwrap().len(), 2);
    }

    #[test]
    fn test_heuristic_merge_decision() {
        let merger = SemanticEntityMerger::new(0.8);
        let entities = vec![
            organization("entity1", "Apple Inc", 0.9),
            organization("entity2", "Apple Inc.", 0.8),
        ];

        let decision = merger.heuristic_merge_decision(&entities);

        assert!(decision.should_merge);
        assert!(decision.confidence > 0.8);
        assert!(decision.merged_name.is_some());
    }

    #[test]
    fn test_string_similarity() {
        let merger = SemanticEntityMerger::new(0.8);

        assert_eq!(merger.string_similarity("Apple", "Apple"), 1.0);
        assert!(merger.string_similarity("Apple Inc", "Apple Inc.") > 0.8);
        assert!(merger.string_similarity("Apple", "Microsoft") < 0.3);
    }

    #[test]
    fn test_entity_merging() {
        let merger = SemanticEntityMerger::new(0.8);

        let first = organization("entity1", "Apple Inc", 0.9).with_mentions(vec![EntityMention {
            chunk_id: ChunkId::new("chunk1".to_string()),
            start_offset: 0,
            end_offset: 9,
            confidence: 0.9,
        }]);
        let second = organization("entity2", "Apple Inc.", 0.8).with_mentions(vec![EntityMention {
            chunk_id: ChunkId::new("chunk2".to_string()),
            start_offset: 0,
            end_offset: 10,
            confidence: 0.8,
        }]);

        let decision = EntityMergeDecision {
            should_merge: true,
            confidence: 0.9,
            reasoning: "Test merge".to_string(),
            merged_name: Some("Apple Inc.".to_string()),
            merged_description: Some("Merged Apple entity".to_string()),
        };

        let merged = merger
            .merge_entities(vec![first, second], &decision)
            .unwrap();

        assert_eq!(merged.name, "Apple Inc.");
        // One mention from each source entity.
        assert_eq!(merged.mentions.len(), 2);
        assert!(merged.confidence >= 0.8);
    }
}