Skip to main content

graphrag_core/entity/
string_similarity_linker.rs

//! String Similarity-based Entity Linking
//!
//! Deterministic entity linking using string similarity metrics without ML.
//! Implements multiple algorithms:
//! - Levenshtein edit distance
//! - Jaro-Winkler similarity
//! - Jaccard similarity (token-based)
//! - Exact match with normalization
//! - Phonetic matching (simplified Soundex)
10
11use crate::core::{Entity, EntityId, KnowledgeGraph};
12use crate::Result;
13use std::collections::{HashMap, HashSet};
14
/// Configuration for string similarity-based entity linking.
///
/// All thresholds are compared against scores in the range `0.0..=1.0`.
#[derive(Debug, Clone, PartialEq)]
pub struct EntityLinkingConfig {
    /// Minimum similarity (0.0-1.0) two names must reach to be linked.
    pub min_similarity: f32,

    /// Lowercase names before comparing them.
    pub case_insensitive: bool,

    /// Drop every character that is neither alphanumeric nor whitespace
    /// before comparing names.
    pub remove_punctuation: bool,

    /// Enable phonetic matching (Soundex) as an additional metric.
    pub use_phonetic: bool,

    /// Minimum token overlap for Jaccard (0.0-1.0).
    ///
    /// NOTE(review): no code visible in this file reads this field; confirm
    /// whether Jaccard scores are meant to be gated by it.
    pub min_jaccard_overlap: f32,

    /// Maximum edit distance for Levenshtein; pairs further apart than this
    /// get a Levenshtein score of `0.0`.
    pub max_edit_distance: usize,

    /// Enable fuzzy matching with typo tolerance (the Levenshtein metric).
    pub fuzzy_matching: bool,
}
39
40impl Default for EntityLinkingConfig {
41    fn default() -> Self {
42        Self {
43            min_similarity: 0.85,
44            case_insensitive: true,
45            remove_punctuation: true,
46            use_phonetic: false,
47            min_jaccard_overlap: 0.6,
48            max_edit_distance: 2,
49            fuzzy_matching: true,
50        }
51    }
52}
53
54/// Entity linker using string similarity metrics
55pub struct StringSimilarityLinker {
56    config: EntityLinkingConfig,
57}
58
59impl StringSimilarityLinker {
60    /// Create a new entity linker with configuration
61    pub fn new(config: EntityLinkingConfig) -> Self {
62        Self { config }
63    }
64
65    /// Link entities in a knowledge graph based on string similarity
66    ///
67    /// Returns a mapping from entity IDs to their canonical entity ID
68    pub fn link_entities(&self, graph: &KnowledgeGraph) -> Result<HashMap<EntityId, EntityId>> {
69        let mut links: HashMap<EntityId, EntityId> = HashMap::new();
70        let entities: Vec<Entity> = graph.entities().cloned().collect();
71
72        // Build clusters of similar entities
73        let mut clusters: Vec<Vec<usize>> = Vec::new();
74        let mut clustered: HashSet<usize> = HashSet::new();
75
76        for i in 0..entities.len() {
77            if clustered.contains(&i) {
78                continue;
79            }
80
81            let mut cluster = vec![i];
82            clustered.insert(i);
83
84            for j in (i + 1)..entities.len() {
85                if clustered.contains(&j) {
86                    continue;
87                }
88
89                let similarity = self.compute_similarity(&entities[i], &entities[j]);
90
91                if similarity >= self.config.min_similarity {
92                    cluster.push(j);
93                    clustered.insert(j);
94                }
95            }
96
97            if cluster.len() > 1 {
98                clusters.push(cluster);
99            }
100        }
101
102        // For each cluster, select canonical entity (highest confidence)
103        for cluster in clusters {
104            let canonical_idx = cluster
105                .iter()
106                .max_by(|&&a, &&b| {
107                    entities[a]
108                        .confidence
109                        .partial_cmp(&entities[b].confidence)
110                        .unwrap_or(std::cmp::Ordering::Equal)
111                })
112                .unwrap();
113
114            let canonical_id = &entities[*canonical_idx].id;
115
116            for &entity_idx in &cluster {
117                if entity_idx != *canonical_idx {
118                    links.insert(entities[entity_idx].id.clone(), canonical_id.clone());
119                }
120            }
121        }
122
123        Ok(links)
124    }
125
126    /// Compute overall similarity between two entities
127    fn compute_similarity(&self, e1: &Entity, e2: &Entity) -> f32 {
128        // Different entity types should not be linked
129        if e1.entity_type != e2.entity_type {
130            return 0.0;
131        }
132
133        let name1 = self.normalize_string(&e1.name);
134        let name2 = self.normalize_string(&e2.name);
135
136        // Exact match after normalization
137        if name1 == name2 {
138            return 1.0;
139        }
140
141        let mut scores = Vec::new();
142
143        // 1. Levenshtein-based similarity
144        if self.config.fuzzy_matching {
145            let lev_sim = self.levenshtein_similarity(&name1, &name2);
146            scores.push(lev_sim);
147        }
148
149        // 2. Jaro-Winkler similarity
150        let jaro_sim = self.jaro_winkler_similarity(&name1, &name2);
151        scores.push(jaro_sim);
152
153        // 3. Token-based Jaccard similarity
154        let jaccard_sim = self.jaccard_similarity(&name1, &name2);
155        scores.push(jaccard_sim);
156
157        // 4. Phonetic matching (if enabled)
158        if self.config.use_phonetic {
159            let phonetic_sim = self.phonetic_similarity(&name1, &name2);
160            scores.push(phonetic_sim);
161        }
162
163        // Return maximum similarity across all methods
164        scores.into_iter().fold(0.0, f32::max)
165    }
166
167    /// Normalize string for comparison
168    fn normalize_string(&self, s: &str) -> String {
169        let mut normalized = s.to_string();
170
171        if self.config.case_insensitive {
172            normalized = normalized.to_lowercase();
173        }
174
175        if self.config.remove_punctuation {
176            normalized = normalized
177                .chars()
178                .filter(|c| c.is_alphanumeric() || c.is_whitespace())
179                .collect();
180        }
181
182        // Normalize whitespace
183        normalized.split_whitespace().collect::<Vec<_>>().join(" ")
184    }
185
186    /// Compute Levenshtein edit distance-based similarity
187    fn levenshtein_similarity(&self, s1: &str, s2: &str) -> f32 {
188        let distance = self.levenshtein_distance(s1, s2);
189
190        if distance > self.config.max_edit_distance {
191            return 0.0;
192        }
193
194        let max_len = s1.len().max(s2.len());
195        if max_len == 0 {
196            return 1.0;
197        }
198
199        1.0 - (distance as f32 / max_len as f32)
200    }
201
202    /// Compute Levenshtein edit distance
203    fn levenshtein_distance(&self, s1: &str, s2: &str) -> usize {
204        let len1 = s1.chars().count();
205        let len2 = s2.chars().count();
206
207        if len1 == 0 {
208            return len2;
209        }
210        if len2 == 0 {
211            return len1;
212        }
213
214        let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
215
216        // Initialize first row and column
217        #[allow(clippy::needless_range_loop)]
218        for i in 0..=len1 {
219            matrix[i][0] = i;
220        }
221        #[allow(clippy::needless_range_loop)]
222        for j in 0..=len2 {
223            matrix[0][j] = j;
224        }
225
226        let s1_chars: Vec<char> = s1.chars().collect();
227        let s2_chars: Vec<char> = s2.chars().collect();
228
229        // Fill matrix
230        for i in 1..=len1 {
231            for j in 1..=len2 {
232                let cost = if s1_chars[i - 1] == s2_chars[j - 1] {
233                    0
234                } else {
235                    1
236                };
237
238                matrix[i][j] = (matrix[i - 1][j] + 1) // deletion
239                    .min(matrix[i][j - 1] + 1) // insertion
240                    .min(matrix[i - 1][j - 1] + cost); // substitution
241            }
242        }
243
244        matrix[len1][len2]
245    }
246
247    /// Compute Jaro-Winkler similarity
248    fn jaro_winkler_similarity(&self, s1: &str, s2: &str) -> f32 {
249        let jaro = self.jaro_similarity(s1, s2);
250
251        // Apply Winkler prefix bonus
252        let prefix_len = s1
253            .chars()
254            .zip(s2.chars())
255            .take(4)
256            .take_while(|(c1, c2)| c1 == c2)
257            .count();
258
259        jaro + (prefix_len as f32 * 0.1 * (1.0 - jaro))
260    }
261
262    /// Compute Jaro similarity
263    fn jaro_similarity(&self, s1: &str, s2: &str) -> f32 {
264        let s1_chars: Vec<char> = s1.chars().collect();
265        let s2_chars: Vec<char> = s2.chars().collect();
266
267        let len1 = s1_chars.len();
268        let len2 = s2_chars.len();
269
270        if len1 == 0 && len2 == 0 {
271            return 1.0;
272        }
273        if len1 == 0 || len2 == 0 {
274            return 0.0;
275        }
276
277        let match_distance = (len1.max(len2) / 2).saturating_sub(1);
278
279        let mut s1_matches = vec![false; len1];
280        let mut s2_matches = vec![false; len2];
281
282        let mut matches = 0;
283        let mut transpositions = 0;
284
285        // Find matches
286        for i in 0..len1 {
287            let start = i.saturating_sub(match_distance);
288            let end = (i + match_distance + 1).min(len2);
289
290            for j in start..end {
291                if s2_matches[j] || s1_chars[i] != s2_chars[j] {
292                    continue;
293                }
294                s1_matches[i] = true;
295                s2_matches[j] = true;
296                matches += 1;
297                break;
298            }
299        }
300
301        if matches == 0 {
302            return 0.0;
303        }
304
305        // Count transpositions
306        let mut k = 0;
307        for i in 0..len1 {
308            if !s1_matches[i] {
309                continue;
310            }
311            while !s2_matches[k] {
312                k += 1;
313            }
314            if s1_chars[i] != s2_chars[k] {
315                transpositions += 1;
316            }
317            k += 1;
318        }
319
320        let m = matches as f32;
321        (m / len1 as f32 + m / len2 as f32 + (m - transpositions as f32 / 2.0) / m) / 3.0
322    }
323
324    /// Compute token-based Jaccard similarity
325    fn jaccard_similarity(&self, s1: &str, s2: &str) -> f32 {
326        let tokens1: HashSet<&str> = s1.split_whitespace().collect();
327        let tokens2: HashSet<&str> = s2.split_whitespace().collect();
328
329        if tokens1.is_empty() && tokens2.is_empty() {
330            return 1.0;
331        }
332
333        let intersection = tokens1.intersection(&tokens2).count();
334        let union = tokens1.union(&tokens2).count();
335
336        if union == 0 {
337            return 0.0;
338        }
339
340        intersection as f32 / union as f32
341    }
342
343    /// Compute phonetic similarity using simplified Soundex
344    fn phonetic_similarity(&self, s1: &str, s2: &str) -> f32 {
345        let soundex1 = self.soundex(s1);
346        let soundex2 = self.soundex(s2);
347
348        if soundex1 == soundex2 {
349            0.9 // High but not perfect score for phonetic match
350        } else {
351            0.0
352        }
353    }
354
355    /// Simple Soundex implementation
356    fn soundex(&self, s: &str) -> String {
357        if s.is_empty() {
358            return String::new();
359        }
360
361        let chars: Vec<char> = s.to_uppercase().chars().collect();
362        let mut result = String::new();
363
364        // Keep first letter
365        if let Some(&first) = chars.first() {
366            if first.is_alphabetic() {
367                result.push(first);
368            }
369        }
370
371        let mut prev_code = self.soundex_code(chars[0]);
372
373        for &c in chars.iter().skip(1) {
374            let code = self.soundex_code(c);
375
376            if code != '0' && code != prev_code {
377                result.push(code);
378                prev_code = code;
379            }
380
381            if result.len() >= 4 {
382                break;
383            }
384        }
385
386        // Pad with zeros
387        while result.len() < 4 {
388            result.push('0');
389        }
390
391        result
392    }
393
394    /// Get Soundex code for a character
395    fn soundex_code(&self, c: char) -> char {
396        match c.to_ascii_uppercase() {
397            'B' | 'F' | 'P' | 'V' => '1',
398            'C' | 'G' | 'J' | 'K' | 'Q' | 'S' | 'X' | 'Z' => '2',
399            'D' | 'T' => '3',
400            'L' => '4',
401            'M' | 'N' => '5',
402            'R' => '6',
403            _ => '0',
404        }
405    }
406
407    /// Find candidate entity for linking a new mention
408    pub fn find_canonical_entity(
409        &self,
410        mention: &str,
411        entity_type: &str,
412        candidates: &[Entity],
413    ) -> Option<EntityId> {
414        let normalized_mention = self.normalize_string(mention);
415
416        let mut best_match: Option<(EntityId, f32)> = None;
417
418        for candidate in candidates {
419            if candidate.entity_type != entity_type {
420                continue;
421            }
422
423            let normalized_candidate = self.normalize_string(&candidate.name);
424
425            // Quick exact match check
426            if normalized_mention == normalized_candidate {
427                return Some(candidate.id.clone());
428            }
429
430            // Compute similarity
431            let mut scores = Vec::new();
432
433            if self.config.fuzzy_matching {
434                let lev_sim =
435                    self.levenshtein_similarity(&normalized_mention, &normalized_candidate);
436                scores.push(lev_sim);
437            }
438
439            let jaro_sim = self.jaro_winkler_similarity(&normalized_mention, &normalized_candidate);
440            scores.push(jaro_sim);
441
442            let jaccard_sim = self.jaccard_similarity(&normalized_mention, &normalized_candidate);
443            scores.push(jaccard_sim);
444
445            if self.config.use_phonetic {
446                let phonetic_sim =
447                    self.phonetic_similarity(&normalized_mention, &normalized_candidate);
448                scores.push(phonetic_sim);
449            }
450
451            let max_similarity = scores.into_iter().fold(0.0, f32::max);
452
453            if max_similarity >= self.config.min_similarity {
454                if let Some((_, current_best_score)) = &best_match {
455                    if max_similarity > *current_best_score {
456                        best_match = Some((candidate.id.clone(), max_similarity));
457                    }
458                } else {
459                    best_match = Some((candidate.id.clone(), max_similarity));
460                }
461            }
462        }
463
464        best_match.map(|(id, _)| id)
465    }
466}
467
/// Unit tests for the individual metrics and the two public linking entry
/// points.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, EntityMention};

    /// Edit distances for well-known word pairs and trivial inputs.
    #[test]
    fn test_levenshtein_distance() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        assert_eq!(linker.levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(linker.levenshtein_distance("saturday", "sunday"), 3);
        assert_eq!(linker.levenshtein_distance("", ""), 0);
        assert_eq!(linker.levenshtein_distance("abc", "abc"), 0);
    }

    /// Jaro-Winkler scores high for transpositions and low for disjoint strings.
    #[test]
    fn test_jaro_winkler_similarity() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        let sim = linker.jaro_winkler_similarity("martha", "marhta");
        assert!(sim > 0.9, "Expected high similarity for transposition");

        let sim2 = linker.jaro_winkler_similarity("dwayne", "duane");
        assert!(sim2 > 0.8, "Expected decent similarity");

        let sim3 = linker.jaro_winkler_similarity("abc", "xyz");
        assert!(sim3 < 0.3, "Expected low similarity");
    }

    /// Jaccard works on whitespace-separated tokens.
    #[test]
    fn test_jaccard_similarity() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        let sim = linker.jaccard_similarity("the quick brown fox", "the lazy brown dog");
        assert!(sim > 0.3 && sim < 0.5, "Expected moderate similarity");

        let sim2 = linker.jaccard_similarity("apple orange banana", "apple orange banana");
        assert!((sim2 - 1.0).abs() < 0.001, "Expected perfect match");
    }

    /// Similar-sounding names share a Soundex code.
    #[test]
    fn test_soundex() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        assert_eq!(linker.soundex("Robert"), "R163");
        assert_eq!(linker.soundex("Rupert"), "R163");
        assert_eq!(linker.soundex("Rubin"), "R150");
        assert_eq!(linker.soundex("Smith"), "S530");
        assert_eq!(linker.soundex("Smyth"), "S530");
    }

    /// Default normalization lowercases, strips punctuation and collapses spaces.
    #[test]
    fn test_entity_normalization() {
        let linker = StringSimilarityLinker::new(EntityLinkingConfig::default());

        assert_eq!(linker.normalize_string("John  Smith!"), "john smith");
        assert_eq!(linker.normalize_string("ACME Corp."), "acme corp");
    }

    /// Mentions resolve to a same-typed candidate, tolerating small typos.
    #[test]
    fn test_find_canonical_entity() {
        let config = EntityLinkingConfig {
            min_similarity: 0.8,
            ..Default::default()
        };
        let linker = StringSimilarityLinker::new(config);

        let candidates = vec![
            Entity::new(
                EntityId::new("e1".to_string()),
                "John Smith".to_string(),
                "PERSON".to_string(),
                0.9,
            ),
            Entity::new(
                EntityId::new("e2".to_string()),
                "Acme Corp".to_string(),
                "ORG".to_string(),
                0.85,
            ),
        ];

        // Should match John Smith
        let result = linker.find_canonical_entity("Jon Smith", "PERSON", &candidates);
        assert!(result.is_some());
        assert_eq!(result.unwrap(), EntityId::new("e1".to_string()));

        // Should not match wrong type
        let result = linker.find_canonical_entity("John Smith", "ORG", &candidates);
        assert!(result.is_none());

        // Should match with typo
        let result = linker.find_canonical_entity("Jhon Smith", "PERSON", &candidates);
        assert!(result.is_some());
    }

    /// Two near-identical locations are linked, the more confident one being
    /// canonical.
    #[test]
    fn test_link_similar_entities() {
        let config = EntityLinkingConfig {
            min_similarity: 0.85,
            ..Default::default()
        };
        let linker = StringSimilarityLinker::new(config);

        let mut graph = KnowledgeGraph::new();

        // Add similar entities
        let _ = graph.add_entity(Entity {
            id: EntityId::new("e1".to_string()),
            name: "New York".to_string(),
            entity_type: "LOCATION".to_string(),
            confidence: 0.9,
            mentions: vec![EntityMention {
                chunk_id: ChunkId::new("chunk1".to_string()),
                start_offset: 0,
                end_offset: 8,
                confidence: 0.9,
            }],
            embedding: None,
            first_mentioned: None,
            last_mentioned: None,
            temporal_validity: None,
        });

        let _ = graph.add_entity(Entity {
            id: EntityId::new("e2".to_string()),
            name: "New York City".to_string(),
            entity_type: "LOCATION".to_string(),
            confidence: 0.85,
            mentions: vec![EntityMention {
                chunk_id: ChunkId::new("chunk2".to_string()),
                start_offset: 0,
                end_offset: 13,
                confidence: 0.85,
            }],
            embedding: None,
            first_mentioned: None,
            last_mentioned: None,
            temporal_validity: None,
        });

        let links = linker.link_entities(&graph).unwrap();

        // Should link similar location names
        assert!(!links.is_empty(), "Expected some entities to be linked");

        // "New York" has the higher confidence, so it is the canonical entity.
        assert_eq!(
            links.get(&EntityId::new("e2".to_string())),
            Some(&EntityId::new("e1".to_string()))
        );
    }
}