1use crate::core::{Entity, EntityId, KnowledgeGraph};
12use crate::Result;
13use std::collections::{HashMap, HashSet};
14
/// Tuning knobs for [`StringSimilarityLinker`].
///
/// Similarity thresholds are compared against scores in `[0.0, 1.0]`, where `1.0` means
/// the two names are identical after normalization.
#[derive(Debug, Clone, PartialEq)]
pub struct EntityLinkingConfig {
    /// Minimum similarity score for two names to be treated as the same entity.
    pub min_similarity: f32,

    /// Lowercase names before comparing them.
    pub case_insensitive: bool,

    /// Drop every character that is neither alphanumeric nor whitespace before comparing.
    pub remove_punctuation: bool,

    /// Also score names by whether their Soundex (phonetic) codes match.
    pub use_phonetic: bool,

    /// Minimum word-token Jaccard overlap between two normalized names.
    pub min_jaccard_overlap: f32,

    /// Largest Levenshtein edit distance that still yields a non-zero edit-distance score.
    pub max_edit_distance: usize,

    /// Include the Levenshtein (edit-distance) score when comparing names.
    pub fuzzy_matching: bool,
}

impl Default for EntityLinkingConfig {
    /// Defaults: 0.85 minimum similarity, case-insensitive, punctuation stripped,
    /// phonetic matching off, 0.6 minimum Jaccard overlap, at most 2 edits, fuzzy matching on.
    fn default() -> Self {
        Self {
            min_similarity: 0.85,
            case_insensitive: true,
            remove_punctuation: true,
            use_phonetic: false,
            min_jaccard_overlap: 0.6,
            max_edit_distance: 2,
            fuzzy_matching: true,
        }
    }
}
53
54pub struct StringSimilarityLinker {
56 config: EntityLinkingConfig,
57}
58
59impl StringSimilarityLinker {
60 pub fn new(config: EntityLinkingConfig) -> Self {
62 Self { config }
63 }
64
65 pub fn link_entities(
69 &self,
70 graph: &KnowledgeGraph,
71 ) -> Result<HashMap<EntityId, EntityId>> {
72 let mut links: HashMap<EntityId, EntityId> = HashMap::new();
73 let entities: Vec<Entity> = graph.entities().cloned().collect();
74
75 let mut clusters: Vec<Vec<usize>> = Vec::new();
77 let mut clustered: HashSet<usize> = HashSet::new();
78
79 for i in 0..entities.len() {
80 if clustered.contains(&i) {
81 continue;
82 }
83
84 let mut cluster = vec![i];
85 clustered.insert(i);
86
87 for j in (i + 1)..entities.len() {
88 if clustered.contains(&j) {
89 continue;
90 }
91
92 let similarity = self.compute_similarity(&entities[i], &entities[j]);
93
94 if similarity >= self.config.min_similarity {
95 cluster.push(j);
96 clustered.insert(j);
97 }
98 }
99
100 if cluster.len() > 1 {
101 clusters.push(cluster);
102 }
103 }
104
105 for cluster in clusters {
107 let canonical_idx = cluster
108 .iter()
109 .max_by(|&&a, &&b| {
110 entities[a]
111 .confidence
112 .partial_cmp(&entities[b].confidence)
113 .unwrap_or(std::cmp::Ordering::Equal)
114 })
115 .unwrap();
116
117 let canonical_id = &entities[*canonical_idx].id;
118
119 for &entity_idx in &cluster {
120 if entity_idx != *canonical_idx {
121 links.insert(entities[entity_idx].id.clone(), canonical_id.clone());
122 }
123 }
124 }
125
126 Ok(links)
127 }
128
129 fn compute_similarity(&self, e1: &Entity, e2: &Entity) -> f32 {
131 if e1.entity_type != e2.entity_type {
133 return 0.0;
134 }
135
136 let name1 = self.normalize_string(&e1.name);
137 let name2 = self.normalize_string(&e2.name);
138
139 if name1 == name2 {
141 return 1.0;
142 }
143
144 let mut scores = Vec::new();
145
146 if self.config.fuzzy_matching {
148 let lev_sim = self.levenshtein_similarity(&name1, &name2);
149 scores.push(lev_sim);
150 }
151
152 let jaro_sim = self.jaro_winkler_similarity(&name1, &name2);
154 scores.push(jaro_sim);
155
156 let jaccard_sim = self.jaccard_similarity(&name1, &name2);
158 scores.push(jaccard_sim);
159
160 if self.config.use_phonetic {
162 let phonetic_sim = self.phonetic_similarity(&name1, &name2);
163 scores.push(phonetic_sim);
164 }
165
166 scores.into_iter().fold(0.0, f32::max)
168 }
169
170 fn normalize_string(&self, s: &str) -> String {
172 let mut normalized = s.to_string();
173
174 if self.config.case_insensitive {
175 normalized = normalized.to_lowercase();
176 }
177
178 if self.config.remove_punctuation {
179 normalized = normalized
180 .chars()
181 .filter(|c| c.is_alphanumeric() || c.is_whitespace())
182 .collect();
183 }
184
185 normalized
187 .split_whitespace()
188 .collect::<Vec<_>>()
189 .join(" ")
190 }
191
192 fn levenshtein_similarity(&self, s1: &str, s2: &str) -> f32 {
194 let distance = self.levenshtein_distance(s1, s2);
195
196 if distance > self.config.max_edit_distance {
197 return 0.0;
198 }
199
200 let max_len = s1.len().max(s2.len());
201 if max_len == 0 {
202 return 1.0;
203 }
204
205 1.0 - (distance as f32 / max_len as f32)
206 }
207
208 fn levenshtein_distance(&self, s1: &str, s2: &str) -> usize {
210 let len1 = s1.chars().count();
211 let len2 = s2.chars().count();
212
213 if len1 == 0 {
214 return len2;
215 }
216 if len2 == 0 {
217 return len1;
218 }
219
220 let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
221
222 for i in 0..=len1 {
224 matrix[i][0] = i;
225 }
226 for j in 0..=len2 {
227 matrix[0][j] = j;
228 }
229
230 let s1_chars: Vec<char> = s1.chars().collect();
231 let s2_chars: Vec<char> = s2.chars().collect();
232
233 for i in 1..=len1 {
235 for j in 1..=len2 {
236 let cost = if s1_chars[i - 1] == s2_chars[j - 1] {
237 0
238 } else {
239 1
240 };
241
242 matrix[i][j] = (matrix[i - 1][j] + 1) .min(matrix[i][j - 1] + 1) .min(matrix[i - 1][j - 1] + cost); }
246 }
247
248 matrix[len1][len2]
249 }
250
251 fn jaro_winkler_similarity(&self, s1: &str, s2: &str) -> f32 {
253 let jaro = self.jaro_similarity(s1, s2);
254
255 let prefix_len = s1
257 .chars()
258 .zip(s2.chars())
259 .take(4)
260 .take_while(|(c1, c2)| c1 == c2)
261 .count();
262
263 jaro + (prefix_len as f32 * 0.1 * (1.0 - jaro))
264 }
265
266 fn jaro_similarity(&self, s1: &str, s2: &str) -> f32 {
268 let s1_chars: Vec<char> = s1.chars().collect();
269 let s2_chars: Vec<char> = s2.chars().collect();
270
271 let len1 = s1_chars.len();
272 let len2 = s2_chars.len();
273
274 if len1 == 0 && len2 == 0 {
275 return 1.0;
276 }
277 if len1 == 0 || len2 == 0 {
278 return 0.0;
279 }
280
281 let match_distance = (len1.max(len2) / 2).saturating_sub(1);
282
283 let mut s1_matches = vec![false; len1];
284 let mut s2_matches = vec![false; len2];
285
286 let mut matches = 0;
287 let mut transpositions = 0;
288
289 for i in 0..len1 {
291 let start = i.saturating_sub(match_distance);
292 let end = (i + match_distance + 1).min(len2);
293
294 for j in start..end {
295 if s2_matches[j] || s1_chars[i] != s2_chars[j] {
296 continue;
297 }
298 s1_matches[i] = true;
299 s2_matches[j] = true;
300 matches += 1;
301 break;
302 }
303 }
304
305 if matches == 0 {
306 return 0.0;
307 }
308
309 let mut k = 0;
311 for i in 0..len1 {
312 if !s1_matches[i] {
313 continue;
314 }
315 while !s2_matches[k] {
316 k += 1;
317 }
318 if s1_chars[i] != s2_chars[k] {
319 transpositions += 1;
320 }
321 k += 1;
322 }
323
324 let m = matches as f32;
325 (m / len1 as f32 + m / len2 as f32 + (m - transpositions as f32 / 2.0) / m) / 3.0
326 }
327
328 fn jaccard_similarity(&self, s1: &str, s2: &str) -> f32 {
330 let tokens1: HashSet<&str> = s1.split_whitespace().collect();
331 let tokens2: HashSet<&str> = s2.split_whitespace().collect();
332
333 if tokens1.is_empty() && tokens2.is_empty() {
334 return 1.0;
335 }
336
337 let intersection = tokens1.intersection(&tokens2).count();
338 let union = tokens1.union(&tokens2).count();
339
340 if union == 0 {
341 return 0.0;
342 }
343
344 intersection as f32 / union as f32
345 }
346
347 fn phonetic_similarity(&self, s1: &str, s2: &str) -> f32 {
349 let soundex1 = self.soundex(s1);
350 let soundex2 = self.soundex(s2);
351
352 if soundex1 == soundex2 {
353 0.9 } else {
355 0.0
356 }
357 }
358
359 fn soundex(&self, s: &str) -> String {
361 if s.is_empty() {
362 return String::new();
363 }
364
365 let chars: Vec<char> = s.to_uppercase().chars().collect();
366 let mut result = String::new();
367
368 if let Some(&first) = chars.first() {
370 if first.is_alphabetic() {
371 result.push(first);
372 }
373 }
374
375 let mut prev_code = self.soundex_code(chars[0]);
376
377 for &c in chars.iter().skip(1) {
378 let code = self.soundex_code(c);
379
380 if code != '0' && code != prev_code {
381 result.push(code);
382 prev_code = code;
383 }
384
385 if result.len() >= 4 {
386 break;
387 }
388 }
389
390 while result.len() < 4 {
392 result.push('0');
393 }
394
395 result
396 }
397
398 fn soundex_code(&self, c: char) -> char {
400 match c.to_ascii_uppercase() {
401 'B' | 'F' | 'P' | 'V' => '1',
402 'C' | 'G' | 'J' | 'K' | 'Q' | 'S' | 'X' | 'Z' => '2',
403 'D' | 'T' => '3',
404 'L' => '4',
405 'M' | 'N' => '5',
406 'R' => '6',
407 _ => '0',
408 }
409 }
410
411 pub fn find_canonical_entity(
413 &self,
414 mention: &str,
415 entity_type: &str,
416 candidates: &[Entity],
417 ) -> Option<EntityId> {
418 let normalized_mention = self.normalize_string(mention);
419
420 let mut best_match: Option<(EntityId, f32)> = None;
421
422 for candidate in candidates {
423 if candidate.entity_type != entity_type {
424 continue;
425 }
426
427 let normalized_candidate = self.normalize_string(&candidate.name);
428
429 if normalized_mention == normalized_candidate {
431 return Some(candidate.id.clone());
432 }
433
434 let mut scores = Vec::new();
436
437 if self.config.fuzzy_matching {
438 let lev_sim = self.levenshtein_similarity(&normalized_mention, &normalized_candidate);
439 scores.push(lev_sim);
440 }
441
442 let jaro_sim = self.jaro_winkler_similarity(&normalized_mention, &normalized_candidate);
443 scores.push(jaro_sim);
444
445 let jaccard_sim = self.jaccard_similarity(&normalized_mention, &normalized_candidate);
446 scores.push(jaccard_sim);
447
448 if self.config.use_phonetic {
449 let phonetic_sim =
450 self.phonetic_similarity(&normalized_mention, &normalized_candidate);
451 scores.push(phonetic_sim);
452 }
453
454 let max_similarity = scores.into_iter().fold(0.0, f32::max);
455
456 if max_similarity >= self.config.min_similarity {
457 if let Some((_, current_best_score)) = &best_match {
458 if max_similarity > *current_best_score {
459 best_match = Some((candidate.id.clone(), max_similarity));
460 }
461 } else {
462 best_match = Some((candidate.id.clone(), max_similarity));
463 }
464 }
465 }
466
467 best_match.map(|(id, _)| id)
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, EntityMention};

    #[test]
    fn test_levenshtein_distance() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        assert_eq!(linker.levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(linker.levenshtein_distance("saturday", "sunday"), 3);
        assert_eq!(linker.levenshtein_distance("", ""), 0);
        assert_eq!(linker.levenshtein_distance("", "abc"), 3);
        assert_eq!(linker.levenshtein_distance("abc", "abc"), 0);
    }

    #[test]
    fn test_jaro_winkler_similarity() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        let sim = linker.jaro_winkler_similarity("martha", "marhta");
        assert!(sim > 0.9, "Expected high similarity for transposition");

        let sim2 = linker.jaro_winkler_similarity("dwayne", "duane");
        assert!(sim2 > 0.8, "Expected decent similarity");

        let sim3 = linker.jaro_winkler_similarity("abc", "xyz");
        assert!(sim3 < 0.3, "Expected low similarity");
    }

    #[test]
    fn test_jaccard_similarity() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        let sim = linker.jaccard_similarity("the quick brown fox", "the lazy brown dog");
        assert!(sim > 0.3 && sim < 0.5, "Expected moderate similarity");

        let sim2 = linker.jaccard_similarity("apple orange banana", "apple orange banana");
        assert!((sim2 - 1.0).abs() < 0.001, "Expected perfect match");

        let sim3 = linker.jaccard_similarity("", "");
        assert!((sim3 - 1.0).abs() < 0.001, "Two empty inputs count as identical");
    }

    #[test]
    fn test_soundex() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        assert_eq!(linker.soundex("Robert"), "R163");
        assert_eq!(linker.soundex("Rupert"), "R163");
        assert_eq!(linker.soundex("Rubin"), "R150");
        assert_eq!(linker.soundex("Smith"), "S530");
        assert_eq!(linker.soundex("Smyth"), "S530");
        assert_eq!(linker.soundex(""), "");
    }

    #[test]
    fn test_entity_normalization() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        assert_eq!(linker.normalize_string("John Smith!"), "john smith");
        assert_eq!(linker.normalize_string("ACME Corp."), "acme corp");
        assert_eq!(linker.normalize_string("  New   York  "), "new york");
    }

    #[test]
    fn test_find_canonical_entity() {
        let config = EntityLinkingConfig {
            min_similarity: 0.8,
            ..Default::default()
        };
        let linker = StringSimilarityLinker::new(config);

        let candidates = vec![
            Entity {
                id: EntityId::new("e1".to_string()),
                name: "John Smith".to_string(),
                entity_type: "PERSON".to_string(),
                confidence: 0.9,
                mentions: vec![],
                embedding: None,
            },
            Entity {
                id: EntityId::new("e2".to_string()),
                name: "Acme Corp".to_string(),
                entity_type: "ORG".to_string(),
                confidence: 0.85,
                mentions: vec![],
                embedding: None,
            },
        ];

        // One-letter deletion: Levenshtein score 0.9 >= 0.8.
        let result = linker.find_canonical_entity("Jon Smith", "PERSON", &candidates);
        assert_eq!(result, Some(EntityId::new("e1".to_string())));

        // Right name, wrong type: never matched.
        let result = linker.find_canonical_entity("John Smith", "ORG", &candidates);
        assert_eq!(result, None);

        // Swapped letters.
        let result = linker.find_canonical_entity("Jhon Smith", "PERSON", &candidates);
        assert_eq!(result, Some(EntityId::new("e1".to_string())));
    }

    #[test]
    fn test_link_similar_entities() {
        let config = EntityLinkingConfig {
            min_similarity: 0.85,
            ..Default::default()
        };
        let linker = StringSimilarityLinker::new(config);

        let mut graph = KnowledgeGraph::new();

        let _ = graph.add_entity(Entity {
            id: EntityId::new("e1".to_string()),
            name: "New York".to_string(),
            entity_type: "LOCATION".to_string(),
            confidence: 0.9,
            mentions: vec![EntityMention {
                chunk_id: ChunkId::new("chunk1".to_string()),
                start_offset: 0,
                end_offset: 8,
                confidence: 0.9,
            }],
            embedding: None,
        });

        let _ = graph.add_entity(Entity {
            id: EntityId::new("e2".to_string()),
            name: "New York City".to_string(),
            entity_type: "LOCATION".to_string(),
            confidence: 0.85,
            mentions: vec![EntityMention {
                chunk_id: ChunkId::new("chunk2".to_string()),
                start_offset: 0,
                end_offset: 13,
                confidence: 0.85,
            }],
            embedding: None,
        });

        // Guard against a silently failed insertion making the link assertions vacuous.
        assert_eq!(graph.entities().count(), 2, "both entities should be in the graph");

        let links = linker.link_entities(&graph).unwrap();

        // The higher-confidence "New York" (e1) is canonical; "New York City" links to it.
        assert_eq!(links.len(), 1, "Expected exactly one linked entity");
        assert_eq!(
            links.get(&EntityId::new("e2".to_string())),
            Some(&EntityId::new("e1".to_string()))
        );
    }
}