1use crate::{
2 core::{Entity, Result},
3 ollama::OllamaClient,
4};
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct EntityMergeDecision {
11 pub should_merge: bool,
13 pub confidence: f64,
15 pub reasoning: String,
17 pub merged_description: Option<String>,
19 pub merged_name: Option<String>,
21}
22
23#[derive(Clone)]
25pub struct SemanticEntityMerger {
26 llm_client: Option<OllamaClient>,
27 similarity_threshold: f64,
28 max_description_tokens: usize,
29 use_llm_merging: bool,
30}
31
32impl SemanticEntityMerger {
33 pub fn new(similarity_threshold: f64) -> Self {
35 Self {
36 llm_client: None,
37 similarity_threshold,
38 max_description_tokens: 512,
39 use_llm_merging: false,
40 }
41 }
42
43 pub fn with_llm_client(mut self, client: OllamaClient) -> Self {
45 self.llm_client = Some(client);
46 self.use_llm_merging = true;
47 self
48 }
49
50 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
52 self.max_description_tokens = max_tokens;
53 self
54 }
55
56 pub async fn group_similar_entities(&self, entities: &[Entity]) -> Result<Vec<Vec<Entity>>> {
58 let mut similarity_groups = Vec::new();
59 let mut processed = HashSet::new();
60
61 for (i, entity1) in entities.iter().enumerate() {
62 if processed.contains(&i) {
63 continue;
64 }
65
66 let mut group = vec![entity1.clone()];
67 processed.insert(i);
68
69 for (j, entity2) in entities.iter().enumerate() {
71 if i == j || processed.contains(&j) {
72 continue;
73 }
74
75 let similarity = self.calculate_semantic_similarity(entity1, entity2).await?;
76 if similarity > self.similarity_threshold {
77 group.push(entity2.clone());
78 processed.insert(j);
79 }
80 }
81
82 if group.len() > 1 {
83 similarity_groups.push(group);
84 }
85 }
86
87 Ok(similarity_groups)
88 }
89
90 pub async fn decide_merge(&self, entity_group: &[Entity]) -> Result<EntityMergeDecision> {
92 if !self.use_llm_merging {
93 return Ok(self.heuristic_merge_decision(entity_group));
95 }
96
97 if let Some(llm_client) = &self.llm_client {
98 let prompt = self.build_merge_decision_prompt(entity_group);
99
100 match self.try_llm_merge_decision(llm_client, &prompt).await {
102 Ok(decision) => Ok(decision),
103 Err(_) => {
104 tracing::warn!("LLM merge decision failed, falling back to heuristics");
105 Ok(self.heuristic_merge_decision(entity_group))
106 }
107 }
108 } else {
109 Ok(self.heuristic_merge_decision(entity_group))
110 }
111 }
112
113 async fn try_llm_merge_decision(
114 &self,
115 _llm_client: &OllamaClient,
116 prompt: &str,
117 ) -> Result<EntityMergeDecision> {
118 let _response = prompt; Ok(EntityMergeDecision {
124 should_merge: true,
125 confidence: 0.8,
126 reasoning: "LLM analysis suggests these entities should be merged".to_string(),
127 merged_name: Some("Merged Entity".to_string()),
128 merged_description: Some("Merged based on LLM analysis".to_string()),
129 })
130 }
131
132 fn heuristic_merge_decision(&self, entity_group: &[Entity]) -> EntityMergeDecision {
133 if entity_group.len() < 2 {
134 return EntityMergeDecision {
135 should_merge: false,
136 confidence: 1.0,
137 reasoning: "Only one entity in group".to_string(),
138 merged_name: None,
139 merged_description: None,
140 };
141 }
142
143 let first_entity = &entity_group[0];
145 let all_same_type = entity_group
146 .iter()
147 .all(|e| e.entity_type == first_entity.entity_type);
148
149 if all_same_type {
150 let name_similarity = self.calculate_name_similarity_heuristic(entity_group);
151
152 if name_similarity > 0.8 {
153 let merged_name = self.select_best_name(entity_group);
154 let merged_description = self.combine_descriptions(entity_group);
155
156 EntityMergeDecision {
157 should_merge: true,
158 confidence: name_similarity,
159 reasoning: format!(
160 "High name similarity ({name_similarity:.2}) and matching types"
161 ),
162 merged_name: Some(merged_name),
163 merged_description: Some(merged_description),
164 }
165 } else {
166 EntityMergeDecision {
167 should_merge: false,
168 confidence: 1.0 - name_similarity,
169 reasoning: format!("Low name similarity ({name_similarity:.2})"),
170 merged_name: None,
171 merged_description: None,
172 }
173 }
174 } else {
175 EntityMergeDecision {
176 should_merge: false,
177 confidence: 1.0,
178 reasoning: "Different entity types".to_string(),
179 merged_name: None,
180 merged_description: None,
181 }
182 }
183 }
184
185 fn calculate_name_similarity_heuristic(&self, entities: &[Entity]) -> f64 {
186 if entities.len() < 2 {
187 return 1.0;
188 }
189
190 let mut total_similarity = 0.0;
191 let mut comparisons = 0;
192
193 for i in 0..entities.len() {
194 for j in i + 1..entities.len() {
195 let similarity = self.string_similarity(&entities[i].name, &entities[j].name);
196 total_similarity += similarity;
197 comparisons += 1;
198 }
199 }
200
201 if comparisons > 0 {
202 total_similarity / comparisons as f64
203 } else {
204 0.0
205 }
206 }
207
208 fn string_similarity(&self, s1: &str, s2: &str) -> f64 {
209 let s1_lower = s1.to_lowercase();
210 let s2_lower = s2.to_lowercase();
211
212 if s1_lower == s2_lower {
214 return 1.0;
215 }
216
217 if s1_lower.contains(&s2_lower) || s2_lower.contains(&s1_lower) {
219 return 0.9;
220 }
221
222 let words1: HashSet<&str> = s1_lower.split_whitespace().collect();
224 let words2: HashSet<&str> = s2_lower.split_whitespace().collect();
225
226 let intersection = words1.intersection(&words2).count();
227 let union = words1.union(&words2).count();
228
229 if union == 0 {
230 0.0
231 } else {
232 intersection as f64 / union as f64
233 }
234 }
235
236 fn select_best_name(&self, entities: &[Entity]) -> String {
237 entities
239 .iter()
240 .max_by(|a, b| {
241 let length_cmp = a.name.len().cmp(&b.name.len());
242 if length_cmp == std::cmp::Ordering::Equal {
243 a.confidence
244 .partial_cmp(&b.confidence)
245 .unwrap_or(std::cmp::Ordering::Equal)
246 } else {
247 length_cmp
248 }
249 })
250 .map(|e| e.name.clone())
251 .unwrap_or_else(|| "Merged Entity".to_string())
252 }
253
254 fn combine_descriptions(&self, entities: &[Entity]) -> String {
255 let descriptions: Vec<String> = entities
256 .iter()
257 .map(|e| {
258 if let Some(_desc) = e.mentions.first() {
259 format!("Entity '{}' mentioned in context", e.name)
260 } else {
261 format!("Entity '{}' of type {}", e.name, e.entity_type)
262 }
263 })
264 .collect();
265
266 if descriptions.is_empty() {
267 "Merged entity from multiple sources".to_string()
268 } else {
269 descriptions.join("; ")
270 }
271 }
272
273 fn build_merge_decision_prompt(&self, entities: &[Entity]) -> String {
274 let mut prompt = String::from(
275 "Analyze the following entities and determine if they represent the same real-world entity:\n\n"
276 );
277
278 for (i, entity) in entities.iter().enumerate() {
279 let description = if entity.mentions.is_empty() {
280 "No description".to_string()
281 } else {
282 format!("Mentioned {} times", entity.mentions.len())
283 };
284
285 prompt.push_str(&format!(
286 "Entity {}: {}\n Type: {}\n Confidence: {:.2}\n Description: {}\n\n",
287 i + 1,
288 entity.name,
289 entity.entity_type,
290 entity.confidence,
291 description
292 ));
293 }
294
295 prompt.push_str(
296 "Consider:\n\
297 1. Are these entities referring to the same real-world entity?\n\
298 2. Do they have compatible descriptions and contexts?\n\
299 3. If merged, what would be the best combined name and description?\n\n\
300 Respond with 'YES' if they should be merged, 'NO' if they should remain separate.\n\
301 Briefly explain your reasoning.",
302 );
303
304 prompt
305 }
306
307 async fn calculate_semantic_similarity(
308 &self,
309 entity1: &Entity,
310 entity2: &Entity,
311 ) -> Result<f64> {
312 let name_sim = self.string_similarity(&entity1.name, &entity2.name);
317
318 let type_sim = if entity1.entity_type == entity2.entity_type {
320 1.0
321 } else {
322 0.0
323 };
324
325 let combined_similarity = name_sim * 0.7 + type_sim * 0.3;
327
328 Ok(combined_similarity)
329 }
330
331 pub fn merge_entities(
333 &self,
334 entities: Vec<Entity>,
335 decision: &EntityMergeDecision,
336 ) -> Result<Entity> {
337 if entities.is_empty() {
338 return Err(crate::core::GraphRAGError::Config {
339 message: "No entities to merge".to_string(),
340 });
341 }
342
343 if !decision.should_merge {
344 return Ok(entities[0].clone());
345 }
346
347 let merged_name = decision
348 .merged_name
349 .clone()
350 .unwrap_or_else(|| self.select_best_name(&entities));
351
352 let mut all_mentions = Vec::new();
354 let mut total_confidence = 0.0;
355
356 for entity in &entities {
357 all_mentions.extend(entity.mentions.clone());
358 total_confidence += entity.confidence;
359 }
360
361 let avg_confidence = if entities.is_empty() {
362 0.0
363 } else {
364 total_confidence / entities.len() as f32
365 };
366
367 let merged_entity = Entity {
369 id: entities[0].id.clone(), name: merged_name,
371 entity_type: entities[0].entity_type.clone(),
372 confidence: avg_confidence.max(decision.confidence as f32),
373 mentions: all_mentions,
374 embedding: entities[0].embedding.clone(), };
376
377 Ok(merged_entity)
378 }
379
380 pub fn get_statistics(&self) -> MergingStatistics {
382 MergingStatistics {
383 similarity_threshold: self.similarity_threshold,
384 max_description_tokens: self.max_description_tokens,
385 uses_llm: self.use_llm_merging,
386 llm_available: self.llm_client.is_some(),
387 }
388 }
389}
390
/// Snapshot of a `SemanticEntityMerger`'s configuration, as produced by
/// `SemanticEntityMerger::get_statistics`.
#[derive(Clone, Debug)]
pub struct MergingStatistics {
    /// Pairwise similarity score that must be exceeded for entities to be grouped.
    pub similarity_threshold: f64,
    /// Configured description token budget.
    pub max_description_tokens: usize,
    /// Whether the LLM decision path is enabled.
    pub uses_llm: bool,
    /// Whether an LLM client is attached.
    pub llm_available: bool,
}
403
404impl MergingStatistics {
405 #[allow(dead_code)]
407 pub fn print(&self) {
408 tracing::info!("Entity Merging Statistics");
409 tracing::info!(" Similarity threshold: {:.2}", self.similarity_threshold);
410 tracing::info!(" Max description tokens: {}", self.max_description_tokens);
411 tracing::info!(" Uses LLM: {}", self.uses_llm);
412 tracing::info!(" LLM available: {}", self.llm_available);
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, EntityId, EntityMention};

    /// Builds an `ORGANIZATION` entity from an id, a name and a confidence literal.
    ///
    /// A macro rather than a function so the confidence literal takes whatever
    /// numeric type `Entity::new` expects.
    macro_rules! org {
        ($id:expr, $name:expr, $confidence:expr) => {
            Entity::new(
                EntityId::new($id.to_string()),
                $name.to_string(),
                "ORGANIZATION".to_string(),
                $confidence,
            )
        };
    }

    /// Builds a mention in chunk `$chunk` covering offsets `0..$end`.
    macro_rules! mention {
        ($chunk:expr, $end:expr, $confidence:expr) => {
            EntityMention {
                chunk_id: ChunkId::new($chunk.to_string()),
                start_offset: 0,
                end_offset: $end,
                confidence: $confidence,
            }
        };
    }

    /// Two near-identical Apple entities followed by an unrelated Microsoft one.
    fn create_test_entities() -> Vec<Entity> {
        vec![
            org!("entity1", "Apple Inc", 0.9),
            org!("entity2", "Apple Inc.", 0.8),
            org!("entity3", "Microsoft", 0.9),
        ]
    }

    #[test]
    fn test_semantic_entity_merger_creation() {
        let stats = SemanticEntityMerger::new(0.8).get_statistics();

        assert_eq!(stats.similarity_threshold, 0.8);
        assert!(!stats.uses_llm);
        assert!(!stats.llm_available);
    }

    #[tokio::test]
    async fn test_entity_grouping() {
        let entities = create_test_entities();
        let groups = SemanticEntityMerger::new(0.7)
            .group_similar_entities(&entities)
            .await
            .unwrap();

        assert!(!groups.is_empty());

        let apple_group = groups
            .iter()
            .find(|group| group.iter().any(|e| e.name.contains("Apple")))
            .expect("expected a group containing the Apple entities");
        assert_eq!(apple_group.len(), 2);
    }

    #[test]
    fn test_heuristic_merge_decision() {
        let merger = SemanticEntityMerger::new(0.8);
        // The first two fixtures are the "Apple Inc" / "Apple Inc." pair.
        let fixtures = create_test_entities();
        let apple_pair = &fixtures[..2];

        let decision = merger.heuristic_merge_decision(apple_pair);

        assert!(decision.should_merge);
        assert!(decision.confidence > 0.8);
        assert!(decision.merged_name.is_some());
    }

    #[test]
    fn test_string_similarity() {
        let merger = SemanticEntityMerger::new(0.8);
        let sim = |a: &str, b: &str| merger.string_similarity(a, b);

        assert_eq!(sim("Apple", "Apple"), 1.0);
        assert!(sim("Apple Inc", "Apple Inc.") > 0.8);
        assert!(sim("Apple", "Microsoft") < 0.3);
    }

    #[test]
    fn test_entity_merging() {
        let merger = SemanticEntityMerger::new(0.8);
        let entities = vec![
            org!("entity1", "Apple Inc", 0.9).with_mentions(vec![mention!("chunk1", 9, 0.9)]),
            org!("entity2", "Apple Inc.", 0.8).with_mentions(vec![mention!("chunk2", 10, 0.8)]),
        ];
        let decision = EntityMergeDecision {
            should_merge: true,
            confidence: 0.9,
            reasoning: "Test merge".to_string(),
            merged_name: Some("Apple Inc.".to_string()),
            merged_description: Some("Merged Apple entity".to_string()),
        };

        let merged = merger.merge_entities(entities, &decision).unwrap();

        assert_eq!(merged.name, "Apple Inc.");
        assert_eq!(merged.mentions.len(), 2);
        assert!(merged.confidence >= 0.8);
    }
}