1use crate::core::{Entity, EntityId, KnowledgeGraph};
12use crate::Result;
13use std::collections::{HashMap, HashSet};
14
/// Settings that control how `StringSimilarityLinker` normalizes and compares
/// entity names.
#[derive(Debug, Clone)]
pub struct EntityLinkingConfig {
    /// Minimum similarity score (compared with `>=`) for two names to be
    /// treated as the same entity.
    pub min_similarity: f32,

    /// When `true`, names are lowercased before comparison.
    pub case_insensitive: bool,

    /// When `true`, every character that is neither alphanumeric nor
    /// whitespace is dropped before comparison.
    pub remove_punctuation: bool,

    /// When `true`, a Soundex-based phonetic score takes part in scoring.
    pub use_phonetic: bool,

    /// Minimum token-overlap threshold.
    ///
    /// NOTE(review): no code visible in this file reads this field — confirm
    /// whether it is consumed elsewhere or is meant to gate the Jaccard score.
    pub min_jaccard_overlap: f32,

    /// Largest edit distance for which the Levenshtein score is non-zero;
    /// pairs further apart score `0.0` on that metric.
    pub max_edit_distance: usize,

    /// When `true`, the Levenshtein-based score takes part in scoring.
    pub fuzzy_matching: bool,
}
39
impl Default for EntityLinkingConfig {
    /// Returns the default settings: case-insensitive, punctuation-stripped
    /// comparison with Levenshtein scoring enabled, phonetic scoring disabled,
    /// and a similarity threshold of `0.85`.
    fn default() -> Self {
        Self {
            min_similarity: 0.85,
            case_insensitive: true,
            remove_punctuation: true,
            use_phonetic: false,
            min_jaccard_overlap: 0.6,
            max_edit_distance: 2,
            fuzzy_matching: true,
        }
    }
}
53
/// Links entities whose names are similar according to several string
/// metrics (edit distance, Jaro-Winkler, token Jaccard and, optionally,
/// Soundex).
pub struct StringSimilarityLinker {
    /// Thresholds and feature toggles applied to every comparison.
    config: EntityLinkingConfig,
}
58
59impl StringSimilarityLinker {
60 pub fn new(config: EntityLinkingConfig) -> Self {
62 Self { config }
63 }
64
65 pub fn link_entities(&self, graph: &KnowledgeGraph) -> Result<HashMap<EntityId, EntityId>> {
69 let mut links: HashMap<EntityId, EntityId> = HashMap::new();
70 let entities: Vec<Entity> = graph.entities().cloned().collect();
71
72 let mut clusters: Vec<Vec<usize>> = Vec::new();
74 let mut clustered: HashSet<usize> = HashSet::new();
75
76 for i in 0..entities.len() {
77 if clustered.contains(&i) {
78 continue;
79 }
80
81 let mut cluster = vec![i];
82 clustered.insert(i);
83
84 for j in (i + 1)..entities.len() {
85 if clustered.contains(&j) {
86 continue;
87 }
88
89 let similarity = self.compute_similarity(&entities[i], &entities[j]);
90
91 if similarity >= self.config.min_similarity {
92 cluster.push(j);
93 clustered.insert(j);
94 }
95 }
96
97 if cluster.len() > 1 {
98 clusters.push(cluster);
99 }
100 }
101
102 for cluster in clusters {
104 let canonical_idx = cluster
105 .iter()
106 .max_by(|&&a, &&b| {
107 entities[a]
108 .confidence
109 .partial_cmp(&entities[b].confidence)
110 .unwrap_or(std::cmp::Ordering::Equal)
111 })
112 .unwrap();
113
114 let canonical_id = &entities[*canonical_idx].id;
115
116 for &entity_idx in &cluster {
117 if entity_idx != *canonical_idx {
118 links.insert(entities[entity_idx].id.clone(), canonical_id.clone());
119 }
120 }
121 }
122
123 Ok(links)
124 }
125
126 fn compute_similarity(&self, e1: &Entity, e2: &Entity) -> f32 {
128 if e1.entity_type != e2.entity_type {
130 return 0.0;
131 }
132
133 let name1 = self.normalize_string(&e1.name);
134 let name2 = self.normalize_string(&e2.name);
135
136 if name1 == name2 {
138 return 1.0;
139 }
140
141 let mut scores = Vec::new();
142
143 if self.config.fuzzy_matching {
145 let lev_sim = self.levenshtein_similarity(&name1, &name2);
146 scores.push(lev_sim);
147 }
148
149 let jaro_sim = self.jaro_winkler_similarity(&name1, &name2);
151 scores.push(jaro_sim);
152
153 let jaccard_sim = self.jaccard_similarity(&name1, &name2);
155 scores.push(jaccard_sim);
156
157 if self.config.use_phonetic {
159 let phonetic_sim = self.phonetic_similarity(&name1, &name2);
160 scores.push(phonetic_sim);
161 }
162
163 scores.into_iter().fold(0.0, f32::max)
165 }
166
167 fn normalize_string(&self, s: &str) -> String {
169 let mut normalized = s.to_string();
170
171 if self.config.case_insensitive {
172 normalized = normalized.to_lowercase();
173 }
174
175 if self.config.remove_punctuation {
176 normalized = normalized
177 .chars()
178 .filter(|c| c.is_alphanumeric() || c.is_whitespace())
179 .collect();
180 }
181
182 normalized.split_whitespace().collect::<Vec<_>>().join(" ")
184 }
185
186 fn levenshtein_similarity(&self, s1: &str, s2: &str) -> f32 {
188 let distance = self.levenshtein_distance(s1, s2);
189
190 if distance > self.config.max_edit_distance {
191 return 0.0;
192 }
193
194 let max_len = s1.len().max(s2.len());
195 if max_len == 0 {
196 return 1.0;
197 }
198
199 1.0 - (distance as f32 / max_len as f32)
200 }
201
202 fn levenshtein_distance(&self, s1: &str, s2: &str) -> usize {
204 let len1 = s1.chars().count();
205 let len2 = s2.chars().count();
206
207 if len1 == 0 {
208 return len2;
209 }
210 if len2 == 0 {
211 return len1;
212 }
213
214 let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
215
216 #[allow(clippy::needless_range_loop)]
218 for i in 0..=len1 {
219 matrix[i][0] = i;
220 }
221 #[allow(clippy::needless_range_loop)]
222 for j in 0..=len2 {
223 matrix[0][j] = j;
224 }
225
226 let s1_chars: Vec<char> = s1.chars().collect();
227 let s2_chars: Vec<char> = s2.chars().collect();
228
229 for i in 1..=len1 {
231 for j in 1..=len2 {
232 let cost = if s1_chars[i - 1] == s2_chars[j - 1] {
233 0
234 } else {
235 1
236 };
237
238 matrix[i][j] = (matrix[i - 1][j] + 1) .min(matrix[i][j - 1] + 1) .min(matrix[i - 1][j - 1] + cost); }
242 }
243
244 matrix[len1][len2]
245 }
246
247 fn jaro_winkler_similarity(&self, s1: &str, s2: &str) -> f32 {
249 let jaro = self.jaro_similarity(s1, s2);
250
251 let prefix_len = s1
253 .chars()
254 .zip(s2.chars())
255 .take(4)
256 .take_while(|(c1, c2)| c1 == c2)
257 .count();
258
259 jaro + (prefix_len as f32 * 0.1 * (1.0 - jaro))
260 }
261
262 fn jaro_similarity(&self, s1: &str, s2: &str) -> f32 {
264 let s1_chars: Vec<char> = s1.chars().collect();
265 let s2_chars: Vec<char> = s2.chars().collect();
266
267 let len1 = s1_chars.len();
268 let len2 = s2_chars.len();
269
270 if len1 == 0 && len2 == 0 {
271 return 1.0;
272 }
273 if len1 == 0 || len2 == 0 {
274 return 0.0;
275 }
276
277 let match_distance = (len1.max(len2) / 2).saturating_sub(1);
278
279 let mut s1_matches = vec![false; len1];
280 let mut s2_matches = vec![false; len2];
281
282 let mut matches = 0;
283 let mut transpositions = 0;
284
285 for i in 0..len1 {
287 let start = i.saturating_sub(match_distance);
288 let end = (i + match_distance + 1).min(len2);
289
290 for j in start..end {
291 if s2_matches[j] || s1_chars[i] != s2_chars[j] {
292 continue;
293 }
294 s1_matches[i] = true;
295 s2_matches[j] = true;
296 matches += 1;
297 break;
298 }
299 }
300
301 if matches == 0 {
302 return 0.0;
303 }
304
305 let mut k = 0;
307 for i in 0..len1 {
308 if !s1_matches[i] {
309 continue;
310 }
311 while !s2_matches[k] {
312 k += 1;
313 }
314 if s1_chars[i] != s2_chars[k] {
315 transpositions += 1;
316 }
317 k += 1;
318 }
319
320 let m = matches as f32;
321 (m / len1 as f32 + m / len2 as f32 + (m - transpositions as f32 / 2.0) / m) / 3.0
322 }
323
324 fn jaccard_similarity(&self, s1: &str, s2: &str) -> f32 {
326 let tokens1: HashSet<&str> = s1.split_whitespace().collect();
327 let tokens2: HashSet<&str> = s2.split_whitespace().collect();
328
329 if tokens1.is_empty() && tokens2.is_empty() {
330 return 1.0;
331 }
332
333 let intersection = tokens1.intersection(&tokens2).count();
334 let union = tokens1.union(&tokens2).count();
335
336 if union == 0 {
337 return 0.0;
338 }
339
340 intersection as f32 / union as f32
341 }
342
343 fn phonetic_similarity(&self, s1: &str, s2: &str) -> f32 {
345 let soundex1 = self.soundex(s1);
346 let soundex2 = self.soundex(s2);
347
348 if soundex1 == soundex2 {
349 0.9 } else {
351 0.0
352 }
353 }
354
355 fn soundex(&self, s: &str) -> String {
357 if s.is_empty() {
358 return String::new();
359 }
360
361 let chars: Vec<char> = s.to_uppercase().chars().collect();
362 let mut result = String::new();
363
364 if let Some(&first) = chars.first() {
366 if first.is_alphabetic() {
367 result.push(first);
368 }
369 }
370
371 let mut prev_code = self.soundex_code(chars[0]);
372
373 for &c in chars.iter().skip(1) {
374 let code = self.soundex_code(c);
375
376 if code != '0' && code != prev_code {
377 result.push(code);
378 prev_code = code;
379 }
380
381 if result.len() >= 4 {
382 break;
383 }
384 }
385
386 while result.len() < 4 {
388 result.push('0');
389 }
390
391 result
392 }
393
394 fn soundex_code(&self, c: char) -> char {
396 match c.to_ascii_uppercase() {
397 'B' | 'F' | 'P' | 'V' => '1',
398 'C' | 'G' | 'J' | 'K' | 'Q' | 'S' | 'X' | 'Z' => '2',
399 'D' | 'T' => '3',
400 'L' => '4',
401 'M' | 'N' => '5',
402 'R' => '6',
403 _ => '0',
404 }
405 }
406
407 pub fn find_canonical_entity(
409 &self,
410 mention: &str,
411 entity_type: &str,
412 candidates: &[Entity],
413 ) -> Option<EntityId> {
414 let normalized_mention = self.normalize_string(mention);
415
416 let mut best_match: Option<(EntityId, f32)> = None;
417
418 for candidate in candidates {
419 if candidate.entity_type != entity_type {
420 continue;
421 }
422
423 let normalized_candidate = self.normalize_string(&candidate.name);
424
425 if normalized_mention == normalized_candidate {
427 return Some(candidate.id.clone());
428 }
429
430 let mut scores = Vec::new();
432
433 if self.config.fuzzy_matching {
434 let lev_sim =
435 self.levenshtein_similarity(&normalized_mention, &normalized_candidate);
436 scores.push(lev_sim);
437 }
438
439 let jaro_sim = self.jaro_winkler_similarity(&normalized_mention, &normalized_candidate);
440 scores.push(jaro_sim);
441
442 let jaccard_sim = self.jaccard_similarity(&normalized_mention, &normalized_candidate);
443 scores.push(jaccard_sim);
444
445 if self.config.use_phonetic {
446 let phonetic_sim =
447 self.phonetic_similarity(&normalized_mention, &normalized_candidate);
448 scores.push(phonetic_sim);
449 }
450
451 let max_similarity = scores.into_iter().fold(0.0, f32::max);
452
453 if max_similarity >= self.config.min_similarity {
454 if let Some((_, current_best_score)) = &best_match {
455 if max_similarity > *current_best_score {
456 best_match = Some((candidate.id.clone(), max_similarity));
457 }
458 } else {
459 best_match = Some((candidate.id.clone(), max_similarity));
460 }
461 }
462 }
463
464 best_match.map(|(id, _)| id)
465 }
466}
467
/// Unit tests for the individual string metrics and for the two public
/// linking entry points.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, EntityMention};

    /// Edit distance on classic reference pairs and on trivial inputs.
    #[test]
    fn test_levenshtein_distance() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        assert_eq!(linker.levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(linker.levenshtein_distance("saturday", "sunday"), 3);
        assert_eq!(linker.levenshtein_distance("", ""), 0);
        assert_eq!(linker.levenshtein_distance("abc", "abc"), 0);
    }

    /// Jaro-Winkler ranks a transposition high and disjoint strings low.
    #[test]
    fn test_jaro_winkler_similarity() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        let sim = linker.jaro_winkler_similarity("martha", "marhta");
        assert!(sim > 0.9, "Expected high similarity for transposition");

        let sim2 = linker.jaro_winkler_similarity("dwayne", "duane");
        assert!(sim2 > 0.8, "Expected decent similarity");

        let sim3 = linker.jaro_winkler_similarity("abc", "xyz");
        assert!(sim3 < 0.3, "Expected low similarity");
    }

    /// Token Jaccard: 2 shared tokens out of 6 distinct, and an exact match.
    #[test]
    fn test_jaccard_similarity() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        let sim = linker.jaccard_similarity("the quick brown fox", "the lazy brown dog");
        assert!(sim > 0.3 && sim < 0.5, "Expected moderate similarity");

        let sim2 = linker.jaccard_similarity("apple orange banana", "apple orange banana");
        assert!((sim2 - 1.0).abs() < 0.001, "Expected perfect match");
    }

    /// Soundex codes for well-known reference names.
    #[test]
    fn test_soundex() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        assert_eq!(linker.soundex("Robert"), "R163");
        assert_eq!(linker.soundex("Rupert"), "R163");
        assert_eq!(linker.soundex("Rubin"), "R150");
        assert_eq!(linker.soundex("Smith"), "S530");
        assert_eq!(linker.soundex("Smyth"), "S530");
    }

    /// Default normalization lowercases and strips punctuation.
    #[test]
    fn test_entity_normalization() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        assert_eq!(linker.normalize_string("John Smith!"), "john smith");
        assert_eq!(linker.normalize_string("ACME Corp."), "acme corp");
    }

    /// Mentions resolve to a same-typed candidate with a similar name only.
    #[test]
    fn test_find_canonical_entity() {
        let config = EntityLinkingConfig {
            min_similarity: 0.8,
            ..Default::default()
        };
        let linker = StringSimilarityLinker::new(config);

        let candidates = vec![
            Entity::new(
                EntityId::new("e1".to_string()),
                "John Smith".to_string(),
                "PERSON".to_string(),
                0.9,
            ),
            Entity::new(
                EntityId::new("e2".to_string()),
                "Acme Corp".to_string(),
                "ORG".to_string(),
                0.85,
            ),
        ];

        // A one-letter typo still resolves to the matching person.
        let result = linker.find_canonical_entity("Jon Smith", "PERSON", &candidates);
        assert!(result.is_some());
        assert_eq!(result.unwrap(), EntityId::new("e1".to_string()));

        // The name matches, but only a candidate of another type carries it.
        let result = linker.find_canonical_entity("John Smith", "ORG", &candidates);
        assert!(result.is_none());

        // Swapped adjacent letters are tolerated as well.
        let result = linker.find_canonical_entity("Jhon Smith", "PERSON", &candidates);
        assert!(result.is_some());
    }

    /// Two same-typed entities with similar names are linked, and the one
    /// with the higher confidence becomes canonical.
    #[test]
    fn test_link_similar_entities() {
        let config = EntityLinkingConfig {
            min_similarity: 0.85,
            ..Default::default()
        };
        let linker = StringSimilarityLinker::new(config);

        let mut graph = KnowledgeGraph::new();

        let _ = graph.add_entity(Entity {
            id: EntityId::new("e1".to_string()),
            name: "New York".to_string(),
            entity_type: "LOCATION".to_string(),
            confidence: 0.9,
            mentions: vec![EntityMention {
                chunk_id: ChunkId::new("chunk1".to_string()),
                start_offset: 0,
                end_offset: 8,
                confidence: 0.9,
            }],
            embedding: None,
            first_mentioned: None,
            last_mentioned: None,
            temporal_validity: None,
        });

        let _ = graph.add_entity(Entity {
            id: EntityId::new("e2".to_string()),
            name: "New York City".to_string(),
            entity_type: "LOCATION".to_string(),
            confidence: 0.85,
            mentions: vec![EntityMention {
                chunk_id: ChunkId::new("chunk2".to_string()),
                start_offset: 0,
                end_offset: 13,
                confidence: 0.85,
            }],
            embedding: None,
            first_mentioned: None,
            last_mentioned: None,
            temporal_validity: None,
        });

        let links = linker.link_entities(&graph).unwrap();

        assert!(!links.is_empty(), "Expected some entities to be linked");
        // e1 has the higher confidence, so e2 must point at it.
        assert_eq!(
            links.get(&EntityId::new("e2".to_string())),
            Some(&EntityId::new("e1".to_string())),
            "Expected the lower-confidence entity to link to the canonical one"
        );
    }
}