1use crate::{
2 core::{Entity, Result},
3 ollama::OllamaClient,
4};
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct EntityMergeDecision {
11 pub should_merge: bool,
13 pub confidence: f64,
15 pub reasoning: String,
17 pub merged_description: Option<String>,
19 pub merged_name: Option<String>,
21}
22
23#[derive(Clone)]
25pub struct SemanticEntityMerger {
26 llm_client: Option<OllamaClient>,
27 similarity_threshold: f64,
28 max_description_tokens: usize,
29 use_llm_merging: bool,
30}
31
32impl SemanticEntityMerger {
33 pub fn new(similarity_threshold: f64) -> Self {
35 Self {
36 llm_client: None,
37 similarity_threshold,
38 max_description_tokens: 512,
39 use_llm_merging: false,
40 }
41 }
42
43 pub fn with_llm_client(mut self, client: OllamaClient) -> Self {
45 self.llm_client = Some(client);
46 self.use_llm_merging = true;
47 self
48 }
49
50 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
52 self.max_description_tokens = max_tokens;
53 self
54 }
55
56 pub async fn group_similar_entities(&self, entities: &[Entity]) -> Result<Vec<Vec<Entity>>> {
58 let mut similarity_groups = Vec::new();
59 let mut processed = HashSet::new();
60
61 for (i, entity1) in entities.iter().enumerate() {
62 if processed.contains(&i) {
63 continue;
64 }
65
66 let mut group = vec![entity1.clone()];
67 processed.insert(i);
68
69 for (j, entity2) in entities.iter().enumerate() {
71 if i == j || processed.contains(&j) {
72 continue;
73 }
74
75 let similarity = self.calculate_semantic_similarity(entity1, entity2).await?;
76 if similarity > self.similarity_threshold {
77 group.push(entity2.clone());
78 processed.insert(j);
79 }
80 }
81
82 if group.len() > 1 {
83 similarity_groups.push(group);
84 }
85 }
86
87 Ok(similarity_groups)
88 }
89
90 pub async fn decide_merge(&self, entity_group: &[Entity]) -> Result<EntityMergeDecision> {
92 if !self.use_llm_merging {
93 return Ok(self.heuristic_merge_decision(entity_group));
95 }
96
97 if let Some(llm_client) = &self.llm_client {
98 let prompt = self.build_merge_decision_prompt(entity_group);
99
100 match self.try_llm_merge_decision(llm_client, &prompt).await {
102 Ok(decision) => Ok(decision),
103 Err(_) => {
104 #[cfg(feature = "tracing")]
105 tracing::warn!("LLM merge decision failed, falling back to heuristics");
106 Ok(self.heuristic_merge_decision(entity_group))
107 },
108 }
109 } else {
110 Ok(self.heuristic_merge_decision(entity_group))
111 }
112 }
113
114 async fn try_llm_merge_decision(
115 &self,
116 _llm_client: &OllamaClient,
117 prompt: &str,
118 ) -> Result<EntityMergeDecision> {
119 let _response = prompt; Ok(EntityMergeDecision {
125 should_merge: true,
126 confidence: 0.8,
127 reasoning: "LLM analysis suggests these entities should be merged".to_string(),
128 merged_name: Some("Merged Entity".to_string()),
129 merged_description: Some("Merged based on LLM analysis".to_string()),
130 })
131 }
132
133 fn heuristic_merge_decision(&self, entity_group: &[Entity]) -> EntityMergeDecision {
134 if entity_group.len() < 2 {
135 return EntityMergeDecision {
136 should_merge: false,
137 confidence: 1.0,
138 reasoning: "Only one entity in group".to_string(),
139 merged_name: None,
140 merged_description: None,
141 };
142 }
143
144 let first_entity = &entity_group[0];
146 let all_same_type = entity_group
147 .iter()
148 .all(|e| e.entity_type == first_entity.entity_type);
149
150 if all_same_type {
151 let name_similarity = self.calculate_name_similarity_heuristic(entity_group);
152
153 if name_similarity > 0.8 {
154 let merged_name = self.select_best_name(entity_group);
155 let merged_description = self.combine_descriptions(entity_group);
156
157 EntityMergeDecision {
158 should_merge: true,
159 confidence: name_similarity,
160 reasoning: format!(
161 "High name similarity ({name_similarity:.2}) and matching types"
162 ),
163 merged_name: Some(merged_name),
164 merged_description: Some(merged_description),
165 }
166 } else {
167 EntityMergeDecision {
168 should_merge: false,
169 confidence: 1.0 - name_similarity,
170 reasoning: format!("Low name similarity ({name_similarity:.2})"),
171 merged_name: None,
172 merged_description: None,
173 }
174 }
175 } else {
176 EntityMergeDecision {
177 should_merge: false,
178 confidence: 1.0,
179 reasoning: "Different entity types".to_string(),
180 merged_name: None,
181 merged_description: None,
182 }
183 }
184 }
185
186 fn calculate_name_similarity_heuristic(&self, entities: &[Entity]) -> f64 {
187 if entities.len() < 2 {
188 return 1.0;
189 }
190
191 let mut total_similarity = 0.0;
192 let mut comparisons = 0;
193
194 for i in 0..entities.len() {
195 for j in i + 1..entities.len() {
196 let similarity = self.string_similarity(&entities[i].name, &entities[j].name);
197 total_similarity += similarity;
198 comparisons += 1;
199 }
200 }
201
202 if comparisons > 0 {
203 total_similarity / comparisons as f64
204 } else {
205 0.0
206 }
207 }
208
209 fn string_similarity(&self, s1: &str, s2: &str) -> f64 {
210 let s1_lower = s1.to_lowercase();
211 let s2_lower = s2.to_lowercase();
212
213 if s1_lower == s2_lower {
215 return 1.0;
216 }
217
218 if s1_lower.contains(&s2_lower) || s2_lower.contains(&s1_lower) {
220 return 0.9;
221 }
222
223 let words1: HashSet<&str> = s1_lower.split_whitespace().collect();
225 let words2: HashSet<&str> = s2_lower.split_whitespace().collect();
226
227 let intersection = words1.intersection(&words2).count();
228 let union = words1.union(&words2).count();
229
230 if union == 0 {
231 0.0
232 } else {
233 intersection as f64 / union as f64
234 }
235 }
236
237 fn select_best_name(&self, entities: &[Entity]) -> String {
238 entities
240 .iter()
241 .max_by(|a, b| {
242 let length_cmp = a.name.len().cmp(&b.name.len());
243 if length_cmp == std::cmp::Ordering::Equal {
244 a.confidence
245 .partial_cmp(&b.confidence)
246 .unwrap_or(std::cmp::Ordering::Equal)
247 } else {
248 length_cmp
249 }
250 })
251 .map(|e| e.name.clone())
252 .unwrap_or_else(|| "Merged Entity".to_string())
253 }
254
255 fn combine_descriptions(&self, entities: &[Entity]) -> String {
256 let descriptions: Vec<String> = entities
257 .iter()
258 .map(|e| {
259 if let Some(_desc) = e.mentions.first() {
260 format!("Entity '{}' mentioned in context", e.name)
261 } else {
262 format!("Entity '{}' of type {}", e.name, e.entity_type)
263 }
264 })
265 .collect();
266
267 if descriptions.is_empty() {
268 "Merged entity from multiple sources".to_string()
269 } else {
270 descriptions.join("; ")
271 }
272 }
273
274 fn build_merge_decision_prompt(&self, entities: &[Entity]) -> String {
275 let mut prompt = String::from(
276 "Analyze the following entities and determine if they represent the same real-world entity:\n\n"
277 );
278
279 for (i, entity) in entities.iter().enumerate() {
280 let description = if entity.mentions.is_empty() {
281 "No description".to_string()
282 } else {
283 format!("Mentioned {} times", entity.mentions.len())
284 };
285
286 prompt.push_str(&format!(
287 "Entity {}: {}\n Type: {}\n Confidence: {:.2}\n Description: {}\n\n",
288 i + 1,
289 entity.name,
290 entity.entity_type,
291 entity.confidence,
292 description
293 ));
294 }
295
296 prompt.push_str(
297 "Consider:\n\
298 1. Are these entities referring to the same real-world entity?\n\
299 2. Do they have compatible descriptions and contexts?\n\
300 3. If merged, what would be the best combined name and description?\n\n\
301 Respond with 'YES' if they should be merged, 'NO' if they should remain separate.\n\
302 Briefly explain your reasoning.",
303 );
304
305 prompt
306 }
307
308 async fn calculate_semantic_similarity(
309 &self,
310 entity1: &Entity,
311 entity2: &Entity,
312 ) -> Result<f64> {
313 let name_sim = self.string_similarity(&entity1.name, &entity2.name);
318
319 let type_sim = if entity1.entity_type == entity2.entity_type {
321 1.0
322 } else {
323 0.0
324 };
325
326 let combined_similarity = name_sim * 0.7 + type_sim * 0.3;
328
329 Ok(combined_similarity)
330 }
331
332 pub fn merge_entities(
334 &self,
335 entities: Vec<Entity>,
336 decision: &EntityMergeDecision,
337 ) -> Result<Entity> {
338 if entities.is_empty() {
339 return Err(crate::core::GraphRAGError::Config {
340 message: "No entities to merge".to_string(),
341 });
342 }
343
344 if !decision.should_merge {
345 return Ok(entities[0].clone());
346 }
347
348 let merged_name = decision
349 .merged_name
350 .clone()
351 .unwrap_or_else(|| self.select_best_name(&entities));
352
353 let mut all_mentions = Vec::new();
355 let mut total_confidence = 0.0;
356
357 for entity in &entities {
358 all_mentions.extend(entity.mentions.clone());
359 total_confidence += entity.confidence;
360 }
361
362 let avg_confidence = if entities.is_empty() {
363 0.0
364 } else {
365 total_confidence / entities.len() as f32
366 };
367
368 let merged_entity = Entity {
370 id: entities[0].id.clone(), name: merged_name,
372 entity_type: entities[0].entity_type.clone(),
373 confidence: avg_confidence.max(decision.confidence as f32),
374 mentions: all_mentions,
375 embedding: entities[0].embedding.clone(), first_mentioned: None,
377 last_mentioned: None,
378 temporal_validity: None,
379 };
380
381 Ok(merged_entity)
382 }
383
384 pub fn get_statistics(&self) -> MergingStatistics {
386 MergingStatistics {
387 similarity_threshold: self.similarity_threshold,
388 max_description_tokens: self.max_description_tokens,
389 uses_llm: self.use_llm_merging,
390 llm_available: self.llm_client.is_some(),
391 }
392 }
393}
394
/// Snapshot of a [`SemanticEntityMerger`]'s configuration.
#[derive(Clone, Debug)]
pub struct MergingStatistics {
    /// Score a pair of entities must exceed to be grouped together.
    pub similarity_threshold: f64,
    /// Configured token budget for descriptions.
    pub max_description_tokens: usize,
    /// Whether LLM-backed merge decisions are switched on.
    pub uses_llm: bool,
    /// Whether an LLM client is configured.
    pub llm_available: bool,
}
407
408impl MergingStatistics {
409 #[allow(dead_code)]
411 pub fn print(&self) {
412 #[cfg(feature = "tracing")]
413 tracing::info!("Entity Merging Statistics");
414 #[cfg(feature = "tracing")]
415 tracing::info!(" Similarity threshold: {:.2}", self.similarity_threshold);
416 #[cfg(feature = "tracing")]
417 tracing::info!(" Max description tokens: {}", self.max_description_tokens);
418 #[cfg(feature = "tracing")]
419 tracing::info!(" Uses LLM: {}", self.uses_llm);
420 #[cfg(feature = "tracing")]
421 tracing::info!(" LLM available: {}", self.llm_available);
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, EntityId, EntityMention};

    /// Fixture: two spellings of Apple followed by an unrelated organization.
    fn create_test_entities() -> Vec<Entity> {
        let apple_plain = Entity::new(
            EntityId::new("entity1".to_string()),
            "Apple Inc".to_string(),
            "ORGANIZATION".to_string(),
            0.9,
        );
        let apple_dotted = Entity::new(
            EntityId::new("entity2".to_string()),
            "Apple Inc.".to_string(),
            "ORGANIZATION".to_string(),
            0.8,
        );
        let microsoft = Entity::new(
            EntityId::new("entity3".to_string()),
            "Microsoft".to_string(),
            "ORGANIZATION".to_string(),
            0.9,
        );

        vec![apple_plain, apple_dotted, microsoft]
    }

    /// A fresh merger reports its threshold and has no LLM configured.
    #[test]
    fn test_semantic_entity_merger_creation() {
        let stats = SemanticEntityMerger::new(0.8).get_statistics();

        assert_eq!(stats.similarity_threshold, 0.8);
        assert!(!stats.uses_llm);
        assert!(!stats.llm_available);
    }

    /// The two Apple spellings are grouped; Microsoft stays out.
    #[tokio::test]
    async fn test_entity_grouping() {
        let entities = create_test_entities();
        let merger = SemanticEntityMerger::new(0.7);

        let groups = merger.group_similar_entities(&entities).await.unwrap();
        assert!(!groups.is_empty());

        let apple_group = groups
            .iter()
            .find(|group| group.iter().any(|e| e.name.contains("Apple")));
        assert!(apple_group.is_some());
        assert_eq!(apple_group.unwrap().len(), 2);
    }

    /// Same-type entities with near-identical names should be merged.
    #[test]
    fn test_heuristic_merge_decision() {
        // Keep only the two Apple spellings from the shared fixture.
        let mut entities = create_test_entities();
        entities.truncate(2);

        let decision = SemanticEntityMerger::new(0.8).heuristic_merge_decision(&entities);

        assert!(decision.should_merge);
        assert!(decision.confidence > 0.8);
        assert!(decision.merged_name.is_some());
    }

    /// Spot checks for the identical, near-identical and unrelated cases.
    #[test]
    fn test_string_similarity() {
        let merger = SemanticEntityMerger::new(0.8);

        let identical = merger.string_similarity("Apple", "Apple");
        let near_identical = merger.string_similarity("Apple Inc", "Apple Inc.");
        let unrelated = merger.string_similarity("Apple", "Microsoft");

        assert_eq!(identical, 1.0);
        assert!(near_identical > 0.8);
        assert!(unrelated < 0.3);
    }

    /// Merging keeps the decided name and the mentions of both entities.
    #[test]
    fn test_entity_merging() {
        let mut fixture = create_test_entities().into_iter();

        let apple_plain = fixture.next().unwrap().with_mentions(vec![EntityMention {
            chunk_id: ChunkId::new("chunk1".to_string()),
            start_offset: 0,
            end_offset: 9,
            confidence: 0.9,
        }]);
        let apple_dotted = fixture.next().unwrap().with_mentions(vec![EntityMention {
            chunk_id: ChunkId::new("chunk2".to_string()),
            start_offset: 0,
            end_offset: 10,
            confidence: 0.8,
        }]);

        let decision = EntityMergeDecision {
            should_merge: true,
            confidence: 0.9,
            reasoning: "Test merge".to_string(),
            merged_name: Some("Apple Inc.".to_string()),
            merged_description: Some("Merged Apple entity".to_string()),
        };

        let merged = SemanticEntityMerger::new(0.8)
            .merge_entities(vec![apple_plain, apple_dotted], &decision)
            .unwrap();

        assert_eq!(merged.name, "Apple Inc.");
        assert_eq!(merged.mentions.len(), 2);
        assert!(merged.confidence >= 0.8);
    }
}