Skip to main content

graphrag_core/entity/
string_similarity_linker.rs

//! String-similarity-based entity linking.
//!
//! Deterministic entity linking using string similarity metrics, without ML.
//! Implements multiple algorithms:
//! - Levenshtein edit distance
//! - Jaro-Winkler similarity
//! - Jaccard similarity (token-based)
//! - Exact match after normalization
//! - Phonetic matching (simplified Soundex)
10
11use crate::core::{Entity, EntityId, KnowledgeGraph};
12use crate::Result;
13use std::collections::{HashMap, HashSet};
14
/// Configuration for string-similarity-based entity linking.
///
/// Two names are considered the same entity when the best score across the
/// enabled metrics (Jaro-Winkler and token Jaccard always; Levenshtein and
/// Soundex when enabled) reaches [`min_similarity`](Self::min_similarity).
#[derive(Debug, Clone)]
pub struct EntityLinkingConfig {
    /// Minimum similarity score (0.0-1.0) required to link two names.
    pub min_similarity: f32,

    /// Lowercase names before comparing them.
    pub case_insensitive: bool,

    /// Drop every character that is neither alphanumeric nor whitespace
    /// before comparing names.
    pub remove_punctuation: bool,

    /// Also score names by (simplified) Soundex code equality.
    pub use_phonetic: bool,

    /// Intended minimum token overlap for Jaccard (0.0-1.0).
    ///
    /// NOTE: `StringSimilarityLinker` does not read this field; the Jaccard
    /// score is compared against `min_similarity` like every other metric.
    pub min_jaccard_overlap: f32,

    /// Largest Levenshtein edit distance that still yields a non-zero
    /// Levenshtein score; larger distances score 0.0.
    pub max_edit_distance: usize,

    /// Include the Levenshtein (typo-tolerant) score in the comparison.
    pub fuzzy_matching: bool,
}
39
40impl Default for EntityLinkingConfig {
41    fn default() -> Self {
42        Self {
43            min_similarity: 0.85,
44            case_insensitive: true,
45            remove_punctuation: true,
46            use_phonetic: false,
47            min_jaccard_overlap: 0.6,
48            max_edit_distance: 2,
49            fuzzy_matching: true,
50        }
51    }
52}
53
54/// Entity linker using string similarity metrics
55pub struct StringSimilarityLinker {
56    config: EntityLinkingConfig,
57}
58
59impl StringSimilarityLinker {
60    /// Create a new entity linker with configuration
61    pub fn new(config: EntityLinkingConfig) -> Self {
62        Self { config }
63    }
64
65    /// Link entities in a knowledge graph based on string similarity
66    ///
67    /// Returns a mapping from entity IDs to their canonical entity ID
68    pub fn link_entities(
69        &self,
70        graph: &KnowledgeGraph,
71    ) -> Result<HashMap<EntityId, EntityId>> {
72        let mut links: HashMap<EntityId, EntityId> = HashMap::new();
73        let entities: Vec<Entity> = graph.entities().cloned().collect();
74
75        // Build clusters of similar entities
76        let mut clusters: Vec<Vec<usize>> = Vec::new();
77        let mut clustered: HashSet<usize> = HashSet::new();
78
79        for i in 0..entities.len() {
80            if clustered.contains(&i) {
81                continue;
82            }
83
84            let mut cluster = vec![i];
85            clustered.insert(i);
86
87            for j in (i + 1)..entities.len() {
88                if clustered.contains(&j) {
89                    continue;
90                }
91
92                let similarity = self.compute_similarity(&entities[i], &entities[j]);
93
94                if similarity >= self.config.min_similarity {
95                    cluster.push(j);
96                    clustered.insert(j);
97                }
98            }
99
100            if cluster.len() > 1 {
101                clusters.push(cluster);
102            }
103        }
104
105        // For each cluster, select canonical entity (highest confidence)
106        for cluster in clusters {
107            let canonical_idx = cluster
108                .iter()
109                .max_by(|&&a, &&b| {
110                    entities[a]
111                        .confidence
112                        .partial_cmp(&entities[b].confidence)
113                        .unwrap_or(std::cmp::Ordering::Equal)
114                })
115                .unwrap();
116
117            let canonical_id = &entities[*canonical_idx].id;
118
119            for &entity_idx in &cluster {
120                if entity_idx != *canonical_idx {
121                    links.insert(entities[entity_idx].id.clone(), canonical_id.clone());
122                }
123            }
124        }
125
126        Ok(links)
127    }
128
129    /// Compute overall similarity between two entities
130    fn compute_similarity(&self, e1: &Entity, e2: &Entity) -> f32 {
131        // Different entity types should not be linked
132        if e1.entity_type != e2.entity_type {
133            return 0.0;
134        }
135
136        let name1 = self.normalize_string(&e1.name);
137        let name2 = self.normalize_string(&e2.name);
138
139        // Exact match after normalization
140        if name1 == name2 {
141            return 1.0;
142        }
143
144        let mut scores = Vec::new();
145
146        // 1. Levenshtein-based similarity
147        if self.config.fuzzy_matching {
148            let lev_sim = self.levenshtein_similarity(&name1, &name2);
149            scores.push(lev_sim);
150        }
151
152        // 2. Jaro-Winkler similarity
153        let jaro_sim = self.jaro_winkler_similarity(&name1, &name2);
154        scores.push(jaro_sim);
155
156        // 3. Token-based Jaccard similarity
157        let jaccard_sim = self.jaccard_similarity(&name1, &name2);
158        scores.push(jaccard_sim);
159
160        // 4. Phonetic matching (if enabled)
161        if self.config.use_phonetic {
162            let phonetic_sim = self.phonetic_similarity(&name1, &name2);
163            scores.push(phonetic_sim);
164        }
165
166        // Return maximum similarity across all methods
167        scores.into_iter().fold(0.0, f32::max)
168    }
169
170    /// Normalize string for comparison
171    fn normalize_string(&self, s: &str) -> String {
172        let mut normalized = s.to_string();
173
174        if self.config.case_insensitive {
175            normalized = normalized.to_lowercase();
176        }
177
178        if self.config.remove_punctuation {
179            normalized = normalized
180                .chars()
181                .filter(|c| c.is_alphanumeric() || c.is_whitespace())
182                .collect();
183        }
184
185        // Normalize whitespace
186        normalized
187            .split_whitespace()
188            .collect::<Vec<_>>()
189            .join(" ")
190    }
191
192    /// Compute Levenshtein edit distance-based similarity
193    fn levenshtein_similarity(&self, s1: &str, s2: &str) -> f32 {
194        let distance = self.levenshtein_distance(s1, s2);
195
196        if distance > self.config.max_edit_distance {
197            return 0.0;
198        }
199
200        let max_len = s1.len().max(s2.len());
201        if max_len == 0 {
202            return 1.0;
203        }
204
205        1.0 - (distance as f32 / max_len as f32)
206    }
207
208    /// Compute Levenshtein edit distance
209    fn levenshtein_distance(&self, s1: &str, s2: &str) -> usize {
210        let len1 = s1.chars().count();
211        let len2 = s2.chars().count();
212
213        if len1 == 0 {
214            return len2;
215        }
216        if len2 == 0 {
217            return len1;
218        }
219
220        let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
221
222        // Initialize first row and column
223        for i in 0..=len1 {
224            matrix[i][0] = i;
225        }
226        for j in 0..=len2 {
227            matrix[0][j] = j;
228        }
229
230        let s1_chars: Vec<char> = s1.chars().collect();
231        let s2_chars: Vec<char> = s2.chars().collect();
232
233        // Fill matrix
234        for i in 1..=len1 {
235            for j in 1..=len2 {
236                let cost = if s1_chars[i - 1] == s2_chars[j - 1] {
237                    0
238                } else {
239                    1
240                };
241
242                matrix[i][j] = (matrix[i - 1][j] + 1) // deletion
243                    .min(matrix[i][j - 1] + 1) // insertion
244                    .min(matrix[i - 1][j - 1] + cost); // substitution
245            }
246        }
247
248        matrix[len1][len2]
249    }
250
251    /// Compute Jaro-Winkler similarity
252    fn jaro_winkler_similarity(&self, s1: &str, s2: &str) -> f32 {
253        let jaro = self.jaro_similarity(s1, s2);
254
255        // Apply Winkler prefix bonus
256        let prefix_len = s1
257            .chars()
258            .zip(s2.chars())
259            .take(4)
260            .take_while(|(c1, c2)| c1 == c2)
261            .count();
262
263        jaro + (prefix_len as f32 * 0.1 * (1.0 - jaro))
264    }
265
266    /// Compute Jaro similarity
267    fn jaro_similarity(&self, s1: &str, s2: &str) -> f32 {
268        let s1_chars: Vec<char> = s1.chars().collect();
269        let s2_chars: Vec<char> = s2.chars().collect();
270
271        let len1 = s1_chars.len();
272        let len2 = s2_chars.len();
273
274        if len1 == 0 && len2 == 0 {
275            return 1.0;
276        }
277        if len1 == 0 || len2 == 0 {
278            return 0.0;
279        }
280
281        let match_distance = (len1.max(len2) / 2).saturating_sub(1);
282
283        let mut s1_matches = vec![false; len1];
284        let mut s2_matches = vec![false; len2];
285
286        let mut matches = 0;
287        let mut transpositions = 0;
288
289        // Find matches
290        for i in 0..len1 {
291            let start = i.saturating_sub(match_distance);
292            let end = (i + match_distance + 1).min(len2);
293
294            for j in start..end {
295                if s2_matches[j] || s1_chars[i] != s2_chars[j] {
296                    continue;
297                }
298                s1_matches[i] = true;
299                s2_matches[j] = true;
300                matches += 1;
301                break;
302            }
303        }
304
305        if matches == 0 {
306            return 0.0;
307        }
308
309        // Count transpositions
310        let mut k = 0;
311        for i in 0..len1 {
312            if !s1_matches[i] {
313                continue;
314            }
315            while !s2_matches[k] {
316                k += 1;
317            }
318            if s1_chars[i] != s2_chars[k] {
319                transpositions += 1;
320            }
321            k += 1;
322        }
323
324        let m = matches as f32;
325        (m / len1 as f32 + m / len2 as f32 + (m - transpositions as f32 / 2.0) / m) / 3.0
326    }
327
328    /// Compute token-based Jaccard similarity
329    fn jaccard_similarity(&self, s1: &str, s2: &str) -> f32 {
330        let tokens1: HashSet<&str> = s1.split_whitespace().collect();
331        let tokens2: HashSet<&str> = s2.split_whitespace().collect();
332
333        if tokens1.is_empty() && tokens2.is_empty() {
334            return 1.0;
335        }
336
337        let intersection = tokens1.intersection(&tokens2).count();
338        let union = tokens1.union(&tokens2).count();
339
340        if union == 0 {
341            return 0.0;
342        }
343
344        intersection as f32 / union as f32
345    }
346
347    /// Compute phonetic similarity using simplified Soundex
348    fn phonetic_similarity(&self, s1: &str, s2: &str) -> f32 {
349        let soundex1 = self.soundex(s1);
350        let soundex2 = self.soundex(s2);
351
352        if soundex1 == soundex2 {
353            0.9 // High but not perfect score for phonetic match
354        } else {
355            0.0
356        }
357    }
358
359    /// Simple Soundex implementation
360    fn soundex(&self, s: &str) -> String {
361        if s.is_empty() {
362            return String::new();
363        }
364
365        let chars: Vec<char> = s.to_uppercase().chars().collect();
366        let mut result = String::new();
367
368        // Keep first letter
369        if let Some(&first) = chars.first() {
370            if first.is_alphabetic() {
371                result.push(first);
372            }
373        }
374
375        let mut prev_code = self.soundex_code(chars[0]);
376
377        for &c in chars.iter().skip(1) {
378            let code = self.soundex_code(c);
379
380            if code != '0' && code != prev_code {
381                result.push(code);
382                prev_code = code;
383            }
384
385            if result.len() >= 4 {
386                break;
387            }
388        }
389
390        // Pad with zeros
391        while result.len() < 4 {
392            result.push('0');
393        }
394
395        result
396    }
397
398    /// Get Soundex code for a character
399    fn soundex_code(&self, c: char) -> char {
400        match c.to_ascii_uppercase() {
401            'B' | 'F' | 'P' | 'V' => '1',
402            'C' | 'G' | 'J' | 'K' | 'Q' | 'S' | 'X' | 'Z' => '2',
403            'D' | 'T' => '3',
404            'L' => '4',
405            'M' | 'N' => '5',
406            'R' => '6',
407            _ => '0',
408        }
409    }
410
411    /// Find candidate entity for linking a new mention
412    pub fn find_canonical_entity(
413        &self,
414        mention: &str,
415        entity_type: &str,
416        candidates: &[Entity],
417    ) -> Option<EntityId> {
418        let normalized_mention = self.normalize_string(mention);
419
420        let mut best_match: Option<(EntityId, f32)> = None;
421
422        for candidate in candidates {
423            if candidate.entity_type != entity_type {
424                continue;
425            }
426
427            let normalized_candidate = self.normalize_string(&candidate.name);
428
429            // Quick exact match check
430            if normalized_mention == normalized_candidate {
431                return Some(candidate.id.clone());
432            }
433
434            // Compute similarity
435            let mut scores = Vec::new();
436
437            if self.config.fuzzy_matching {
438                let lev_sim = self.levenshtein_similarity(&normalized_mention, &normalized_candidate);
439                scores.push(lev_sim);
440            }
441
442            let jaro_sim = self.jaro_winkler_similarity(&normalized_mention, &normalized_candidate);
443            scores.push(jaro_sim);
444
445            let jaccard_sim = self.jaccard_similarity(&normalized_mention, &normalized_candidate);
446            scores.push(jaccard_sim);
447
448            if self.config.use_phonetic {
449                let phonetic_sim =
450                    self.phonetic_similarity(&normalized_mention, &normalized_candidate);
451                scores.push(phonetic_sim);
452            }
453
454            let max_similarity = scores.into_iter().fold(0.0, f32::max);
455
456            if max_similarity >= self.config.min_similarity {
457                if let Some((_, current_best_score)) = &best_match {
458                    if max_similarity > *current_best_score {
459                        best_match = Some((candidate.id.clone(), max_similarity));
460                    }
461                } else {
462                    best_match = Some((candidate.id.clone(), max_similarity));
463                }
464            }
465        }
466
467        best_match.map(|(id, _)| id)
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, EntityMention};

    /// Builds a linker with the default configuration.
    fn default_linker() -> StringSimilarityLinker {
        StringSimilarityLinker::new(EntityLinkingConfig::default())
    }

    #[test]
    fn test_levenshtein_distance() {
        let linker = default_linker();

        assert_eq!(linker.levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(linker.levenshtein_distance("saturday", "sunday"), 3);
        assert_eq!(linker.levenshtein_distance("", ""), 0);
        assert_eq!(linker.levenshtein_distance("abc", "abc"), 0);
    }

    #[test]
    fn test_jaro_winkler_similarity() {
        let linker = default_linker();

        let sim = linker.jaro_winkler_similarity("martha", "marhta");
        assert!(sim > 0.9, "Expected high similarity for transposition");

        let sim2 = linker.jaro_winkler_similarity("dwayne", "duane");
        assert!(sim2 > 0.8, "Expected decent similarity");

        let sim3 = linker.jaro_winkler_similarity("abc", "xyz");
        assert!(sim3 < 0.3, "Expected low similarity");
    }

    #[test]
    fn test_jaccard_similarity() {
        let linker = default_linker();

        let sim = linker.jaccard_similarity("the quick brown fox", "the lazy brown dog");
        assert!(sim > 0.3 && sim < 0.5, "Expected moderate similarity");

        let sim2 = linker.jaccard_similarity("apple orange banana", "apple orange banana");
        assert!((sim2 - 1.0).abs() < 0.001, "Expected perfect match");
    }

    #[test]
    fn test_soundex() {
        let linker = default_linker();

        assert_eq!(linker.soundex("Robert"), "R163");
        assert_eq!(linker.soundex("Rupert"), "R163");
        assert_eq!(linker.soundex("Rubin"), "R150");
        assert_eq!(linker.soundex("Smith"), "S530");
        assert_eq!(linker.soundex("Smyth"), "S530");
    }

    #[test]
    fn test_entity_normalization() {
        let linker = default_linker();

        assert_eq!(linker.normalize_string("John  Smith!"), "john smith");
        assert_eq!(linker.normalize_string("ACME Corp."), "acme corp");
    }

    #[test]
    fn test_find_canonical_entity() {
        let config = EntityLinkingConfig {
            min_similarity: 0.8,
            ..Default::default()
        };
        let linker = StringSimilarityLinker::new(config);

        let candidates = vec![
            Entity {
                id: EntityId::new("e1".to_string()),
                name: "John Smith".to_string(),
                entity_type: "PERSON".to_string(),
                confidence: 0.9,
                mentions: vec![],
                embedding: None,
            },
            Entity {
                id: EntityId::new("e2".to_string()),
                name: "Acme Corp".to_string(),
                entity_type: "ORG".to_string(),
                confidence: 0.85,
                mentions: vec![],
                embedding: None,
            },
        ];

        // Should match John Smith despite the missing letter.
        let result = linker.find_canonical_entity("Jon Smith", "PERSON", &candidates);
        assert_eq!(result, Some(EntityId::new("e1".to_string())));

        // Should not match a candidate of the wrong type.
        let result = linker.find_canonical_entity("John Smith", "ORG", &candidates);
        assert!(result.is_none());

        // Should match with a transposition typo.
        let result = linker.find_canonical_entity("Jhon Smith", "PERSON", &candidates);
        assert!(result.is_some());
    }

    #[test]
    fn test_link_similar_entities() {
        let config = EntityLinkingConfig {
            min_similarity: 0.85,
            ..Default::default()
        };
        let linker = StringSimilarityLinker::new(config);

        let mut graph = KnowledgeGraph::new();

        // Add similar entities.
        let _ = graph.add_entity(Entity {
            id: EntityId::new("e1".to_string()),
            name: "New York".to_string(),
            entity_type: "LOCATION".to_string(),
            confidence: 0.9,
            mentions: vec![EntityMention {
                chunk_id: ChunkId::new("chunk1".to_string()),
                start_offset: 0,
                end_offset: 8,
                confidence: 0.9,
            }],
            embedding: None,
        });

        let _ = graph.add_entity(Entity {
            id: EntityId::new("e2".to_string()),
            name: "New York City".to_string(),
            entity_type: "LOCATION".to_string(),
            confidence: 0.85,
            mentions: vec![EntityMention {
                chunk_id: ChunkId::new("chunk2".to_string()),
                start_offset: 0,
                end_offset: 13,
                confidence: 0.85,
            }],
            embedding: None,
        });

        let links = linker.link_entities(&graph).unwrap();

        // The lower-confidence "New York City" should point at "New York".
        assert!(!links.is_empty(), "Expected some entities to be linked");
        assert_eq!(links.len(), 1);
        assert_eq!(
            links.get(&EntityId::new("e2".to_string())),
            Some(&EntityId::new("e1".to_string()))
        );
    }

    #[test]
    fn test_different_types_not_linked() {
        let linker = default_linker();
        let mut graph = KnowledgeGraph::new();

        let _ = graph.add_entity(Entity {
            id: EntityId::new("e1".to_string()),
            name: "Paris".to_string(),
            entity_type: "LOCATION".to_string(),
            confidence: 0.9,
            mentions: vec![],
            embedding: None,
        });

        let _ = graph.add_entity(Entity {
            id: EntityId::new("e2".to_string()),
            name: "Paris".to_string(),
            entity_type: "PERSON".to_string(),
            confidence: 0.8,
            mentions: vec![],
            embedding: None,
        });

        let links = linker.link_entities(&graph).unwrap();
        assert!(links.is_empty(), "Entities of different types must not be linked");
    }
}