Skip to main content

oxirs_graphrag/
entity_linking.rs

//! Entity linking and disambiguation for knowledge graphs.
//!
//! Provides candidate generation via Jaro-Winkler string matching and
//! context-based disambiguation using TF-IDF cosine similarity.
5
6use std::collections::HashMap;
7
8// ──────────────────────────────────────────────────────────────────────────────
9// Types
10// ──────────────────────────────────────────────────────────────────────────────
11
/// A span of text that may refer to a named entity.
///
/// Positions are byte offsets (not character indices) into the text the
/// mention was found in.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EntityMention {
    /// Surface form of the mention.
    pub text: String,
    /// Start byte offset in the containing text.
    pub start: usize,
    /// End byte offset (exclusive) in the containing text.
    pub end: usize,
}

impl EntityMention {
    /// Create a mention from its surface form and byte offsets.
    ///
    /// `start` is inclusive and `end` is exclusive. Both are byte positions in
    /// the containing text, not character positions.
    pub fn new(text: impl Into<String>, start: usize, end: usize) -> Self {
        Self {
            text: text.into(),
            start,
            end,
        }
    }
}
33
/// A knowledge-base entity proposed as a possible referent of a mention.
#[derive(Debug, Clone)]
pub struct EntityCandidate {
    /// IRI identifying the entity.
    pub iri: String,
    /// Preferred label of the entity.
    pub label: String,
    /// Similarity score in 0.0–1.0. Holds the string-match score after
    /// candidate generation and the combined score after disambiguation.
    pub score: f64,
    /// Alternative labels by which the entity may be mentioned.
    pub aliases: Vec<String>,
}

impl EntityCandidate {
    /// Build a candidate whose score starts at zero; callers assign `score`.
    fn new(iri: impl Into<String>, label: impl Into<String>, aliases: Vec<String>) -> Self {
        let iri: String = iri.into();
        let label: String = label.into();
        EntityCandidate {
            iri,
            label,
            score: 0.0,
            aliases,
        }
    }
}
57
58/// A successfully linked entity.
59#[derive(Debug, Clone)]
60pub struct LinkedEntity {
61    /// The text mention.
62    pub mention: EntityMention,
63    /// The best matching entity candidate.
64    pub entity: EntityCandidate,
65    /// Overall confidence (0.0–1.0) combining string + context similarity.
66    pub confidence: f64,
67}
68
69// ──────────────────────────────────────────────────────────────────────────────
70// TfIdfIndex
71// ──────────────────────────────────────────────────────────────────────────────
72
73/// A simple TF-IDF index for context-based disambiguation.
74pub struct TfIdfIndex {
75    /// Documents: (doc_id, term → tf).
76    docs: Vec<(String, HashMap<String, f64>)>,
77    /// Inverse document frequency: term → idf.
78    idf: HashMap<String, f64>,
79}
80
81impl TfIdfIndex {
82    /// Create an empty index.
83    pub fn new() -> Self {
84        Self {
85            docs: Vec::new(),
86            idf: HashMap::new(),
87        }
88    }
89
90    /// Add a document to the index.
91    pub fn add_document(&mut self, doc_id: impl Into<String>, text: &str) {
92        let tokens = tokenize(text);
93        let total = tokens.len() as f64;
94        if total == 0.0 {
95            return;
96        }
97        let mut tf: HashMap<String, f64> = HashMap::new();
98        for tok in &tokens {
99            *tf.entry(tok.clone()).or_insert(0.0) += 1.0 / total;
100        }
101        self.docs.push((doc_id.into(), tf));
102    }
103
104    /// Recompute IDF from all indexed documents.
105    pub fn build(&mut self) {
106        let n = self.docs.len() as f64;
107        let mut df: HashMap<String, usize> = HashMap::new();
108        for (_, tf) in &self.docs {
109            for term in tf.keys() {
110                *df.entry(term.clone()).or_insert(0) += 1;
111            }
112        }
113        self.idf.clear();
114        for (term, count) in df {
115            self.idf.insert(term, (n / count as f64).ln() + 1.0);
116        }
117    }
118
119    /// Compute TF-IDF cosine similarity between a query string and a document.
120    pub fn similarity(&self, query: &str, doc_id: &str) -> f64 {
121        let doc = match self.docs.iter().find(|(id, _)| id == doc_id) {
122            Some((_, tf)) => tf,
123            None => return 0.0,
124        };
125
126        let q_tokens = tokenize(query);
127        let q_total = q_tokens.len() as f64;
128        if q_total == 0.0 {
129            return 0.0;
130        }
131        let mut q_tf: HashMap<String, f64> = HashMap::new();
132        for tok in &q_tokens {
133            *q_tf.entry(tok.clone()).or_insert(0.0) += 1.0 / q_total;
134        }
135
136        // Cosine similarity of TF-IDF vectors
137        let mut dot = 0.0_f64;
138        let mut q_norm = 0.0_f64;
139        let mut d_norm = 0.0_f64;
140
141        let all_terms: std::collections::HashSet<&String> = q_tf.keys().chain(doc.keys()).collect();
142
143        for term in all_terms {
144            let idf = self.idf.get(term).copied().unwrap_or(1.0);
145            let q_val = q_tf.get(term).copied().unwrap_or(0.0) * idf;
146            let d_val = doc.get(term).copied().unwrap_or(0.0) * idf;
147            dot += q_val * d_val;
148            q_norm += q_val * q_val;
149            d_norm += d_val * d_val;
150        }
151
152        let denom = q_norm.sqrt() * d_norm.sqrt();
153        if denom < 1e-15 {
154            0.0
155        } else {
156            (dot / denom).clamp(0.0, 1.0)
157        }
158    }
159
160    /// Number of indexed documents.
161    pub fn doc_count(&self) -> usize {
162        self.docs.len()
163    }
164}
165
166impl Default for TfIdfIndex {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
172// ──────────────────────────────────────────────────────────────────────────────
173// EntityLinker
174// ──────────────────────────────────────────────────────────────────────────────
175
176/// Links entity mentions in text to knowledge-base entities.
177pub struct EntityLinker {
178    /// Knowledge base entries: iri → (label, aliases, context_doc_id).
179    kb: HashMap<String, KbEntry>,
180    /// TF-IDF index over entity descriptions (for context disambiguation).
181    tfidf: TfIdfIndex,
182    /// Minimum confidence threshold below which an entity is treated as NIL.
183    pub nil_threshold: f64,
184}
185
186struct KbEntry {
187    label: String,
188    aliases: Vec<String>,
189}
190
191impl EntityLinker {
192    /// Create an entity linker with default NIL threshold 0.1.
193    pub fn new() -> Self {
194        Self {
195            kb: HashMap::new(),
196            tfidf: TfIdfIndex::new(),
197            nil_threshold: 0.1,
198        }
199    }
200
201    /// Add an entity to the knowledge base.
202    ///
203    /// `context` is an optional textual description used for TF-IDF
204    /// disambiguation.
205    pub fn add_entity(
206        &mut self,
207        iri: impl Into<String>,
208        label: impl Into<String>,
209        aliases: &[&str],
210    ) {
211        let iri = iri.into();
212        let label = label.into();
213        let aliases: Vec<String> = aliases.iter().map(|s| s.to_string()).collect();
214        let context = format!("{} {}", label, aliases.join(" "));
215        self.tfidf.add_document(iri.clone(), &context);
216        self.kb.insert(iri, KbEntry { label, aliases });
217    }
218
219    /// Finalise the TF-IDF index (call after all entities are added).
220    pub fn build_index(&mut self) {
221        self.tfidf.build();
222    }
223
224    /// Find and link all entity mentions in `text`.
225    pub fn link(&self, text: &str) -> Vec<LinkedEntity> {
226        let mentions = detect_mentions(text);
227        let mut linked = Vec::new();
228
229        for mention in mentions {
230            let candidates = self.candidate_generation(&mention.text);
231            if candidates.is_empty() {
232                continue;
233            }
234            let best = self.disambiguate(&mention, &candidates, text);
235            if let Some(entity) = best {
236                let confidence = entity.score;
237                if confidence >= self.nil_threshold {
238                    linked.push(LinkedEntity {
239                        mention,
240                        entity,
241                        confidence,
242                    });
243                }
244            }
245        }
246        linked
247    }
248
249    /// Generate entity candidates matching the mention by string similarity.
250    pub fn candidate_generation(&self, mention: &str) -> Vec<EntityCandidate> {
251        let mention_lower = mention.to_lowercase();
252        let mut candidates: Vec<EntityCandidate> = self
253            .kb
254            .iter()
255            .filter_map(|(iri, entry)| {
256                let label_score = jaro_winkler(&mention_lower, &entry.label.to_lowercase());
257                let alias_score = entry
258                    .aliases
259                    .iter()
260                    .map(|a| jaro_winkler(&mention_lower, &a.to_lowercase()))
261                    .fold(0.0_f64, f64::max);
262                let score = label_score.max(alias_score);
263                if score > 0.6 {
264                    let mut c = EntityCandidate::new(
265                        iri.clone(),
266                        entry.label.clone(),
267                        entry.aliases.clone(),
268                    );
269                    c.score = score;
270                    Some(c)
271                } else {
272                    None
273                }
274            })
275            .collect();
276
277        candidates.sort_by(|a, b| {
278            b.score
279                .partial_cmp(&a.score)
280                .unwrap_or(std::cmp::Ordering::Equal)
281        });
282        candidates
283    }
284
285    /// Disambiguate among candidates using context TF-IDF similarity.
286    pub fn disambiguate(
287        &self,
288        _mention: &EntityMention,
289        candidates: &[EntityCandidate],
290        context: &str,
291    ) -> Option<EntityCandidate> {
292        if candidates.is_empty() {
293            return None;
294        }
295
296        let mut best_score = f64::NEG_INFINITY;
297        let mut best: Option<EntityCandidate> = None;
298
299        for cand in candidates {
300            let ctx_score = self.tfidf.similarity(context, &cand.iri);
301            // Combined score: string similarity × 0.6 + context × 0.4
302            let combined = cand.score * 0.6 + ctx_score * 0.4;
303            if combined > best_score {
304                best_score = combined;
305                let mut winner = cand.clone();
306                winner.score = combined;
307                best = Some(winner);
308            }
309        }
310        best
311    }
312
313    /// Number of entities in the knowledge base.
314    pub fn entity_count(&self) -> usize {
315        self.kb.len()
316    }
317}
318
319impl Default for EntityLinker {
320    fn default() -> Self {
321        Self::new()
322    }
323}
324
325// ──────────────────────────────────────────────────────────────────────────────
326// Private helpers
327// ──────────────────────────────────────────────────────────────────────────────
328
329/// Detect potential entity mentions in text by looking for capitalised tokens
330/// or sequences.
331fn detect_mentions(text: &str) -> Vec<EntityMention> {
332    let mut mentions = Vec::new();
333    let mut chars = text.char_indices().peekable();
334    let bytes = text.as_bytes();
335    let len = bytes.len();
336
337    while let Some((start, ch)) = chars.next() {
338        if ch.is_uppercase() {
339            // Consume a capitalised word sequence (handles "Albert Einstein")
340            let mut end = start + ch.len_utf8();
341            while end < len {
342                let next_ch = text[end..].chars().next().unwrap_or('\0');
343                if next_ch.is_alphanumeric() || next_ch == ' ' {
344                    // Allow one space if followed by uppercase (multi-word entity)
345                    if next_ch == ' ' {
346                        let after_space = end + 1;
347                        if after_space < len {
348                            let nc2 = text[after_space..].chars().next().unwrap_or('\0');
349                            if nc2.is_uppercase() {
350                                end = after_space + nc2.len_utf8();
351                                // advance the chars iterator past the space and the uppercase char
352                                let _ = chars.next(); // space
353                                let _ = chars.next(); // uppercase
354                                continue;
355                            }
356                        }
357                        break;
358                    }
359                    end += next_ch.len_utf8();
360                    let _ = chars.next();
361                } else {
362                    break;
363                }
364            }
365            let mention_text = text[start..end].trim().to_string();
366            if mention_text.len() >= 2 {
367                mentions.push(EntityMention::new(mention_text, start, end));
368            }
369        }
370    }
371    mentions
372}
373
374/// Jaro-Winkler string similarity (0.0–1.0).
375fn jaro_winkler(s1: &str, s2: &str) -> f64 {
376    if s1 == s2 {
377        return 1.0;
378    }
379    let jaro = jaro(s1, s2);
380    let prefix_len = s1
381        .chars()
382        .zip(s2.chars())
383        .take(4)
384        .take_while(|(a, b)| a == b)
385        .count();
386    let p = 0.1_f64;
387    jaro + (prefix_len as f64 * p * (1.0 - jaro))
388}
389
/// Jaro similarity between two strings, in the range 0.0–1.0.
///
/// Two characters match when they are equal and at most
/// `max(len1, len2) / 2 − 1` positions apart. The score combines:
/// - the fraction of matched characters in each string;
/// - the number of transpositions (matched characters that appear in a
///   different order).
///
/// Two empty strings score 1.0; an empty string against a non-empty one
/// scores 0.0.
fn jaro(s1: &str, s2: &str) -> f64 {
    let s1: Vec<char> = s1.chars().collect();
    let s2: Vec<char> = s2.chars().collect();
    let (len1, len2) = (s1.len(), s2.len());
    if len1 == 0 && len2 == 0 {
        return 1.0;
    }
    if len1 == 0 || len2 == 0 {
        return 0.0;
    }

    let match_window = (len1.max(len2) / 2).saturating_sub(1);
    let mut s1_matches = vec![false; len1];
    let mut s2_matches = vec![false; len2];
    let mut matches = 0usize;

    for (i, &c1) in s1.iter().enumerate() {
        let lo = i.saturating_sub(match_window);
        let hi = (i + match_window + 1).min(len2);
        // When `s1` is much longer than `s2`, the window for late positions
        // of `s1` starts at or past the end of `s2`. Slicing there used to
        // panic. The window start never moves left as `i` grows, so no later
        // position can match either.
        if lo >= hi {
            break;
        }
        if let Some(j) = (lo..hi).find(|&j| !s2_matches[j] && s2[j] == c1) {
            s1_matches[i] = true;
            s2_matches[j] = true;
            matches += 1;
        }
    }

    if matches == 0 {
        return 0.0;
    }

    // Walk the matched characters of both strings in order. Each position
    // where they differ counts as half a transposition.
    let matched_in_s2 = s2
        .iter()
        .zip(&s2_matches)
        .filter(|&(_, &hit)| hit)
        .map(|(c, _)| c);
    let half_transpositions = s1
        .iter()
        .zip(&s1_matches)
        .filter(|&(_, &hit)| hit)
        .map(|(c, _)| c)
        .zip(matched_in_s2)
        .filter(|(a, b)| a != b)
        .count();

    let m = matches as f64;
    let t = half_transpositions as f64 / 2.0;
    (m / len1 as f64 + m / len2 as f64 + (m - t) / m) / 3.0
}
442
/// Split `text` at every non-alphanumeric character and return the non-empty
/// pieces, lowercased.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for piece in text.split(|c: char| !c.is_alphanumeric()) {
        if piece.is_empty() {
            continue;
        }
        tokens.push(piece.to_lowercase());
    }
    tokens
}
450
451// ──────────────────────────────────────────────────────────────────────────────
452// Tests
453// ──────────────────────────────────────────────────────────────────────────────
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Linker pre-loaded with three scientists and a built TF-IDF index.
    fn linker_with_persons() -> EntityLinker {
        let mut linker = EntityLinker::new();
        linker.add_entity(
            "http://example.org/Albert_Einstein",
            "Albert Einstein",
            &["Einstein", "A. Einstein"],
        );
        linker.add_entity(
            "http://example.org/Marie_Curie",
            "Marie Curie",
            &["Curie", "M. Curie"],
        );
        linker.add_entity(
            "http://example.org/Isaac_Newton",
            "Isaac Newton",
            &["Newton"],
        );
        linker.build_index();
        linker
    }

    // ── EntityMention ─────────────────────────────────────────────────────────

    #[test]
    fn test_mention_new() {
        let m = EntityMention::new("Alice", 0, 5);
        assert_eq!(m.text, "Alice");
        assert_eq!(m.start, 0);
        assert_eq!(m.end, 5);
    }

    #[test]
    fn test_mention_equality() {
        let m1 = EntityMention::new("Bob", 0, 3);
        let m2 = EntityMention::new("Bob", 0, 3);
        assert_eq!(m1, m2);
        assert_ne!(m1, EntityMention::new("Bob", 1, 4));
    }

    // ── TfIdfIndex ────────────────────────────────────────────────────────────

    #[test]
    fn test_tfidf_add_document() {
        let mut idx = TfIdfIndex::new();
        idx.add_document("doc1", "quantum physics relativity");
        idx.build();
        assert_eq!(idx.doc_count(), 1);
    }

    #[test]
    fn test_tfidf_similarity_same_doc() {
        let mut idx = TfIdfIndex::new();
        idx.add_document("doc1", "quantum physics relativity");
        idx.build();
        let sim = idx.similarity("quantum physics", "doc1");
        assert!(sim > 0.0 && sim < 1.0, "partial overlap should be in (0, 1), got {sim}");
        let identical = idx.similarity("quantum physics relativity", "doc1");
        assert!((identical - 1.0).abs() < 1e-9, "identical text, got {identical}");
    }

    #[test]
    fn test_tfidf_similarity_different_content() {
        let mut idx = TfIdfIndex::new();
        idx.add_document("doc1", "quantum physics relativity");
        idx.add_document("doc2", "cooking recipes baking bread");
        idx.build();
        let s1 = idx.similarity("quantum physics", "doc1");
        let s2 = idx.similarity("quantum physics", "doc2");
        assert!(s1 > s2, "physics query should match doc1 better");
        assert_eq!(s2, 0.0, "no shared terms means zero similarity");
    }

    #[test]
    fn test_tfidf_unknown_doc() {
        let idx = TfIdfIndex::new();
        assert_eq!(idx.similarity("anything", "unknown"), 0.0);
    }

    #[test]
    fn test_tfidf_empty_query() {
        let mut idx = TfIdfIndex::new();
        idx.add_document("d", "hello world");
        idx.build();
        assert_eq!(idx.similarity("", "d"), 0.0);
    }

    #[test]
    fn test_tfidf_default() {
        let idx = TfIdfIndex::default();
        assert_eq!(idx.doc_count(), 0);
    }

    // ── EntityLinker ──────────────────────────────────────────────────────────

    #[test]
    fn test_linker_entity_count() {
        let linker = linker_with_persons();
        assert_eq!(linker.entity_count(), 3);
    }

    #[test]
    fn test_linker_default() {
        let linker = EntityLinker::default();
        assert_eq!(linker.entity_count(), 0);
        assert!((linker.nil_threshold - 0.1).abs() < 1e-12);
    }

    // ── candidate_generation ──────────────────────────────────────────────────

    #[test]
    fn test_candidate_generation_exact_label() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Einstein");
        assert!(!cands.is_empty());
        assert!(cands[0].iri.contains("Einstein"));
        assert!((cands[0].score - 1.0).abs() < 1e-12, "exact alias match");
    }

    #[test]
    fn test_candidate_generation_partial() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Newton");
        assert!(!cands.is_empty());
        assert!(cands.iter().any(|c| c.iri.contains("Newton")));
    }

    #[test]
    fn test_candidate_generation_no_match() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Zorkblat");
        assert!(cands.is_empty());
    }

    #[test]
    fn test_candidate_generation_sorted_by_score() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Curie");
        assert!(cands.windows(2).all(|pair| pair[0].score >= pair[1].score));
    }

    #[test]
    fn test_candidate_generation_alias_match() {
        let linker = linker_with_persons();
        // "Curie" is an alias for Marie Curie.
        let cands = linker.candidate_generation("Curie");
        assert!(cands.iter().any(|c| c.iri.contains("Curie")));
    }

    // ── disambiguate ─────────────────────────────────────────────────────────

    #[test]
    fn test_disambiguate_returns_best() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Einstein");
        let mention = EntityMention::new("Einstein", 0, 8);
        let best = linker
            .disambiguate(&mention, &cands, "Einstein worked on relativity")
            .expect("candidates are non-empty");
        assert!(best.iri.contains("Einstein"));
    }

    #[test]
    fn test_disambiguate_empty_candidates() {
        let linker = linker_with_persons();
        let mention = EntityMention::new("X", 0, 1);
        let result = linker.disambiguate(&mention, &[], "context");
        assert!(result.is_none());
    }

    #[test]
    fn test_disambiguate_score_in_range() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Newton");
        let mention = EntityMention::new("Newton", 0, 6);
        let best = linker
            .disambiguate(&mention, &cands, "gravity laws Newton")
            .expect("Newton has candidates");
        assert!((0.0..=1.0).contains(&best.score));
    }

    // ── link ──────────────────────────────────────────────────────────────────

    #[test]
    fn test_link_finds_entity() {
        let linker = linker_with_persons();
        let linked = linker.link("Einstein developed relativity theory.");
        assert!(!linked.is_empty());
        assert!(linked[0].entity.iri.contains("Einstein"));
    }

    #[test]
    fn test_link_confidence_above_threshold() {
        let linker = linker_with_persons();
        let linked = linker.link("Newton formulated laws of motion.");
        assert!(!linked.is_empty(), "Newton should be linked");
        assert!(linked
            .iter()
            .all(|le| le.confidence >= linker.nil_threshold));
    }

    #[test]
    fn test_link_no_entities_in_empty_text() {
        let linker = linker_with_persons();
        let linked = linker.link("");
        assert!(linked.is_empty());
    }

    #[test]
    fn test_link_result_fields() {
        let linker = linker_with_persons();
        let linked = linker.link("Einstein and Curie were scientists.");
        assert_eq!(linked.len(), 2);
        assert!(linked[0].entity.iri.contains("Einstein"));
        assert!(linked[1].entity.iri.contains("Curie"));
        for le in &linked {
            assert!(!le.mention.text.is_empty());
            assert!(!le.entity.iri.is_empty());
            assert!((0.0..=1.0).contains(&le.confidence));
        }
    }

    // ── Jaro-Winkler ──────────────────────────────────────────────────────────

    #[test]
    fn test_jaro_winkler_identical() {
        assert!((jaro_winkler("hello", "hello") - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_jaro_winkler_completely_different() {
        let score = jaro_winkler("abc", "xyz");
        assert!(score < 0.5, "score = {score}");
    }

    #[test]
    fn test_jaro_winkler_prefix_boost() {
        let jw = jaro_winkler("einstein", "einstien");
        assert!(jw > 0.8, "score = {jw}");
        assert!(jw > jaro("einstein", "einstien"), "shared prefix must boost");
    }

    #[test]
    fn test_jaro_winkler_empty_strings() {
        assert!((jaro("", "") - 1.0).abs() < 1e-9);
        assert!((jaro("abc", "") - 0.0).abs() < 1e-9);
        assert!((jaro_winkler("", "") - 1.0).abs() < 1e-9);
    }

    // ── detect_mentions ───────────────────────────────────────────────────────

    #[test]
    fn test_detect_mentions_finds_capitalized() {
        let mentions = detect_mentions("Alice and Bob went to Paris.");
        let texts: Vec<&str> = mentions.iter().map(|m| m.text.as_str()).collect();
        assert_eq!(texts, vec!["Alice", "Bob", "Paris"]);
        assert_eq!((mentions[2].start, mentions[2].end), (22, 27));
    }

    #[test]
    fn test_detect_mentions_multi_word() {
        let mentions = detect_mentions("Albert Einstein was born in Ulm.");
        assert_eq!(mentions[0].text, "Albert Einstein");
    }

    #[test]
    fn test_detect_mentions_empty() {
        assert!(detect_mentions("").is_empty());
    }

    #[test]
    fn test_detect_mentions_lowercase_only() {
        let mentions = detect_mentions("all lowercase words here");
        assert!(mentions.is_empty());
    }

    // ── tokenize ─────────────────────────────────────────────────────────────

    #[test]
    fn test_tokenize_basic() {
        let tokens = tokenize("Hello World");
        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_tokenize_empty() {
        assert!(tokenize("").is_empty());
    }

    #[test]
    fn test_tokenize_punctuation_split() {
        let tokens = tokenize("foo, bar; baz.");
        assert_eq!(tokens, vec!["foo", "bar", "baz"]);
    }

    // ── Full pipeline ─────────────────────────────────────────────────────────

    #[test]
    fn test_full_pipeline() {
        let mut linker = EntityLinker::new();
        linker.add_entity("http://ex.org/Paris", "Paris", &["City of Light"]);
        linker.add_entity("http://ex.org/London", "London", &["British capital"]);
        linker.build_index();

        let linked = linker.link("Paris is a famous city in France.");
        assert!(!linked.is_empty(), "Paris should be linked");
        assert!(linked[0].entity.iri.contains("Paris"));
        assert!(linked.iter().all(|le| !le.entity.iri.contains("London")));
    }

    #[test]
    fn test_nil_threshold_filters_low_confidence() {
        let mut linker = EntityLinker::new();
        linker.add_entity("http://ex.org/X", "Xyzzy", &[]);
        linker.build_index();

        // Combined confidence here is well below 0.99, so nothing survives.
        linker.nil_threshold = 0.99;
        assert!(linker.link("Xyzzy something").is_empty());

        // With a permissive threshold the same mention is linked.
        linker.nil_threshold = 0.5;
        let linked = linker.link("Xyzzy something");
        assert_eq!(linked.len(), 1);
        assert!(linked[0].confidence >= 0.5);
    }
}