Skip to main content

oxirs_graphrag/
entity_linker.rs

1//! String-to-RDF entity linking: mention detection and candidate ranking.
2//!
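//! Mentions are detected by exact, case-insensitive matching of entity labels
//! and aliases in a knowledge base. Candidates are ranked by a blend of
//! edit-distance similarity and entity popularity. `link_mention` offers fuzzy
//! lookup of arbitrary strings, and a configurable score threshold selects the
//! best link.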
5
6use std::collections::HashMap;
7
8// ---------------------------------------------------------------------------
9// Domain types
10// ---------------------------------------------------------------------------
11
/// A knowledge-base entry that text mentions can be linked to.
#[derive(Clone, Debug)]
pub struct Entity {
    /// IRI that uniquely identifies the entity.
    pub iri: String,
    /// Preferred (canonical) human-readable label.
    pub label: String,
    /// Additional surface forms under which the entity may appear in text.
    pub aliases: Vec<String>,
    /// Free-form type name, e.g. `"Person"`, `"Organization"` or `"Place"`.
    pub entity_type: String,
    /// Short human-readable description, when one is available.
    pub description: Option<String>,
    /// Prior popularity, expected to lie in \[0.0, 1.0\]; used as a ranking prior.
    pub popularity: f64,
}
28
29/// A mention of an entity surface form found in text.
30#[derive(Debug, Clone)]
31pub struct EntityMention {
32    /// The surface form detected in the input text.
33    pub text: String,
34    /// Byte offset of the first character.
35    pub start_char: usize,
36    /// Byte offset one past the last character.
37    pub end_char: usize,
38    /// Ranked candidate links for this mention.
39    pub candidates: Vec<LinkCandidate>,
40}
41
42/// A single candidate link returned for a mention.
43#[derive(Debug, Clone)]
44pub struct LinkCandidate {
45    /// The candidate entity.
46    pub entity: Entity,
47    /// Combined ranking score.
48    pub score: f64,
49    /// Normalised string similarity to the mention text.
50    pub string_similarity: f64,
51    /// Entity popularity used as prior.
52    pub prior_probability: f64,
53}
54
55/// A resolved link between a text mention and its best-matching entity.
56#[derive(Debug, Clone)]
57pub struct LinkedEntity {
58    /// The text mention.
59    pub mention: EntityMention,
60    /// Highest-scoring candidate (if any exceeds the linker threshold).
61    pub best_candidate: Option<LinkCandidate>,
62    /// Confidence score of the best candidate (0.0 if none).
63    pub confidence: f64,
64}
65
66// ---------------------------------------------------------------------------
67// LinkerError
68// ---------------------------------------------------------------------------
69
/// Failure modes reported by `EntityLinker`.
#[derive(Debug)]
pub enum LinkerError {
    /// The input text was empty (or only whitespace).
    EmptyText,
    /// No entities have been added to the knowledge base yet.
    EmptyKnowledgeBase,
    /// A knowledge-base entity failed validation; the payload explains why.
    InvalidEntity(String),
}
80
81impl std::fmt::Display for LinkerError {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        match self {
84            LinkerError::EmptyText => write!(f, "Input text is empty"),
85            LinkerError::EmptyKnowledgeBase => write!(f, "Knowledge base contains no entities"),
86            LinkerError::InvalidEntity(msg) => write!(f, "Invalid entity: {}", msg),
87        }
88    }
89}
90
91impl std::error::Error for LinkerError {}
92
93// ---------------------------------------------------------------------------
94// EntityLinker
95// ---------------------------------------------------------------------------
96
97/// String-to-entity linker backed by an in-memory knowledge base.
98pub struct EntityLinker {
99    entities: Vec<Entity>,
100    /// label (lowercased) → list of entity indices
101    label_index: HashMap<String, Vec<usize>>,
102    /// Minimum character length of a text span to be considered as a mention.
103    pub min_mention_len: usize,
104    /// Minimum combined score for a candidate to be returned.
105    pub threshold: f64,
106}
107
108impl EntityLinker {
109    // -----------------------------------------------------------------------
110    // Construction
111    // -----------------------------------------------------------------------
112
113    /// Create an empty linker with the given minimum combined score threshold.
114    pub fn new(threshold: f64) -> Self {
115        Self {
116            entities: Vec::new(),
117            label_index: HashMap::new(),
118            min_mention_len: 2,
119            threshold,
120        }
121    }
122
123    /// Add a single entity to the knowledge base and update the index.
124    pub fn add_entity(&mut self, entity: Entity) {
125        let idx = self.entities.len();
126        // Index the label and all aliases
127        let mut keys: Vec<String> = entity.aliases.iter().map(|a| a.to_lowercase()).collect();
128        keys.push(entity.label.to_lowercase());
129
130        for key in keys {
131            self.label_index.entry(key).or_default().push(idx);
132        }
133        self.entities.push(entity);
134    }
135
136    /// Add multiple entities in bulk and rebuild the index.
137    pub fn add_entities(&mut self, entities: Vec<Entity>) {
138        for entity in entities {
139            self.add_entity(entity);
140        }
141    }
142
143    /// Total number of entities in the knowledge base.
144    pub fn entity_count(&self) -> usize {
145        self.entities.len()
146    }
147
148    // -----------------------------------------------------------------------
149    // Mention detection
150    // -----------------------------------------------------------------------
151
152    /// Scan `text` for all known entity labels / aliases and return mentions.
153    ///
154    /// Both exact (case-insensitive) and fuzzy (similarity ≥ 0.7) matches are
155    /// considered.  Overlapping spans from the same starting position are
156    /// de-duplicated in favour of the longer match.
157    pub fn detect_mentions(&self, text: &str) -> Vec<EntityMention> {
158        let text_lower = text.to_lowercase();
159        let mut mentions: Vec<EntityMention> = Vec::new();
160
161        for (surface, indices) in &self.label_index {
162            if surface.len() < self.min_mention_len {
163                continue;
164            }
165            // Find all (non-overlapping) occurrences of `surface` in text_lower
166            let mut search_start = 0;
167            while let Some(pos) = text_lower[search_start..].find(surface.as_str()) {
168                let abs_start = search_start + pos;
169                let abs_end = abs_start + surface.len();
170
171                // Collect candidates for this surface form
172                let candidates = self.candidates_for_surface(surface, indices);
173                if !candidates.is_empty() {
174                    // Use the original casing from the source text
175                    let original_text = &text[abs_start..abs_end];
176                    mentions.push(EntityMention {
177                        text: original_text.to_string(),
178                        start_char: abs_start,
179                        end_char: abs_end,
180                        candidates,
181                    });
182                }
183                search_start = abs_start + 1;
184            }
185        }
186
187        // Sort by position, then deduplicate overlapping spans (keep longest)
188        mentions.sort_by_key(|m| (m.start_char, usize::MAX - (m.end_char - m.start_char)));
189        let mut deduped: Vec<EntityMention> = Vec::new();
190        for mention in mentions {
191            // Skip if fully contained in the last kept mention
192            if let Some(last) = deduped.last() {
193                if mention.start_char >= last.start_char && mention.end_char <= last.end_char {
194                    continue;
195                }
196            }
197            deduped.push(mention);
198        }
199        deduped
200    }
201
202    // -----------------------------------------------------------------------
203    // Linking
204    // -----------------------------------------------------------------------
205
206    /// Link all mentions found in `text` to their best entity candidates.
207    ///
208    /// # Errors
209    /// - `LinkerError::EmptyText` if `text` is empty after trimming.
210    /// - `LinkerError::EmptyKnowledgeBase` if no entities have been added.
211    pub fn link(&self, text: &str) -> Result<Vec<LinkedEntity>, LinkerError> {
212        if text.trim().is_empty() {
213            return Err(LinkerError::EmptyText);
214        }
215        if self.entities.is_empty() {
216            return Err(LinkerError::EmptyKnowledgeBase);
217        }
218
219        let mentions = self.detect_mentions(text);
220        let linked = mentions
221            .into_iter()
222            .map(|mention| {
223                let best = mention
224                    .candidates
225                    .iter()
226                    .find(|c| c.score >= self.threshold)
227                    .cloned();
228                let confidence = best.as_ref().map(|c| c.score).unwrap_or(0.0);
229                LinkedEntity {
230                    mention,
231                    best_candidate: best,
232                    confidence,
233                }
234            })
235            .collect();
236        Ok(linked)
237    }
238
239    /// Find all entity candidates for an arbitrary mention string.
240    ///
241    /// Searches all entity labels and aliases using both exact (lowercase) and
242    /// fuzzy matching.  Results are sorted by `score` descending.
243    pub fn link_mention(&self, mention: &str) -> Vec<LinkCandidate> {
244        let mention_lower = mention.to_lowercase();
245        let mut seen: std::collections::HashSet<usize> = std::collections::HashSet::new();
246        let mut candidates: Vec<LinkCandidate> = Vec::new();
247
248        for (idx, entity) in self.entities.iter().enumerate() {
249            // Collect all surface forms of this entity
250            let mut surfaces: Vec<String> =
251                entity.aliases.iter().map(|a| a.to_lowercase()).collect();
252            surfaces.push(entity.label.to_lowercase());
253
254            let best_sim = surfaces
255                .iter()
256                .map(|s| Self::string_similarity(&mention_lower, s))
257                .fold(0.0_f64, f64::max);
258
259            if best_sim > 0.0 && !seen.contains(&idx) {
260                let score = Self::combined_score(best_sim, entity.popularity);
261                candidates.push(LinkCandidate {
262                    entity: entity.clone(),
263                    score,
264                    string_similarity: best_sim,
265                    prior_probability: entity.popularity,
266                });
267                seen.insert(idx);
268            }
269        }
270
271        // Sort by score descending
272        candidates.sort_by(|a, b| {
273            b.score
274                .partial_cmp(&a.score)
275                .unwrap_or(std::cmp::Ordering::Equal)
276        });
277        candidates
278    }
279
280    /// Look up an entity by its exact IRI.
281    pub fn find_by_iri(&self, iri: &str) -> Option<&Entity> {
282        self.entities.iter().find(|e| e.iri == iri)
283    }
284
285    // -----------------------------------------------------------------------
286    // String similarity
287    // -----------------------------------------------------------------------
288
289    /// Normalised edit-distance similarity: 1 − edit_distance(a, b) / max(|a|, |b|).
290    ///
291    /// Returns 1.0 for identical strings and 0.0 when the edit distance equals
292    /// the length of the longer string.
293    pub fn string_similarity(a: &str, b: &str) -> f64 {
294        if a == b {
295            return 1.0;
296        }
297        let len_a = a.chars().count();
298        let len_b = b.chars().count();
299        let max_len = len_a.max(len_b);
300        if max_len == 0 {
301            return 1.0; // both empty
302        }
303        let dist = Self::edit_distance(a, b);
304        1.0 - (dist as f64 / max_len as f64)
305    }
306
307    /// Levenshtein edit distance between two strings.
308    fn edit_distance(a: &str, b: &str) -> usize {
309        let a_chars: Vec<char> = a.chars().collect();
310        let b_chars: Vec<char> = b.chars().collect();
311        let la = a_chars.len();
312        let lb = b_chars.len();
313
314        if la == 0 {
315            return lb;
316        }
317        if lb == 0 {
318            return la;
319        }
320
321        let mut dp = vec![vec![0usize; lb + 1]; la + 1];
322        for (i, row) in dp.iter_mut().enumerate() {
323            row[0] = i;
324        }
325        for (j, cell) in dp[0].iter_mut().enumerate() {
326            *cell = j;
327        }
328        for i in 1..=la {
329            for j in 1..=lb {
330                let cost = if a_chars[i - 1] == b_chars[j - 1] {
331                    0
332                } else {
333                    1
334                };
335                dp[i][j] = (dp[i - 1][j] + 1)
336                    .min(dp[i][j - 1] + 1)
337                    .min(dp[i - 1][j - 1] + cost);
338            }
339        }
340        dp[la][lb]
341    }
342
343    // -----------------------------------------------------------------------
344    // Internal helpers
345    // -----------------------------------------------------------------------
346
347    /// Rebuild the label index from scratch.
348    #[allow(dead_code)]
349    fn rebuild_index(&mut self) {
350        self.label_index.clear();
351        for (idx, entity) in self.entities.iter().enumerate() {
352            let mut keys: Vec<String> = entity.aliases.iter().map(|a| a.to_lowercase()).collect();
353            keys.push(entity.label.to_lowercase());
354            for key in keys {
355                self.label_index.entry(key).or_default().push(idx);
356            }
357        }
358    }
359
360    /// Build candidates for a known surface form and entity index list.
361    fn candidates_for_surface(&self, surface: &str, indices: &[usize]) -> Vec<LinkCandidate> {
362        let mut candidates: Vec<LinkCandidate> = indices
363            .iter()
364            .filter_map(|&idx| {
365                let entity = self.entities.get(idx)?;
366                let sim = Self::string_similarity(surface, &entity.label.to_lowercase());
367                let score = Self::combined_score(sim, entity.popularity);
368                Some(LinkCandidate {
369                    entity: entity.clone(),
370                    score,
371                    string_similarity: sim,
372                    prior_probability: entity.popularity,
373                })
374            })
375            .collect();
376        candidates.sort_by(|a, b| {
377            b.score
378                .partial_cmp(&a.score)
379                .unwrap_or(std::cmp::Ordering::Equal)
380        });
381        candidates
382    }
383
384    /// Combined ranking score: 0.7 × string_similarity + 0.3 × popularity.
385    fn combined_score(sim: f64, popularity: f64) -> f64 {
386        0.7 * sim + 0.3 * popularity
387    }
388}
389
390// ---------------------------------------------------------------------------
391// Tests
392// ---------------------------------------------------------------------------
393
#[cfg(test)]
mod tests {
    use super::*;

    fn make_entity(iri: &str, label: &str, aliases: &[&str], pop: f64) -> Entity {
        Entity {
            iri: iri.to_string(),
            label: label.to_string(),
            aliases: aliases.iter().map(|s| s.to_string()).collect(),
            entity_type: "Thing".to_string(),
            description: None,
            popularity: pop,
        }
    }

    /// Three-entity knowledge base (Einstein 0.95, Curie 0.90, Berlin 0.80)
    /// with threshold 0.3.
    fn sample_linker() -> EntityLinker {
        let mut linker = EntityLinker::new(0.3);
        linker.add_entity(make_entity(
            "http://example.org/einstein",
            "Albert Einstein",
            &["Einstein", "A. Einstein"],
            0.95,
        ));
        linker.add_entity(make_entity(
            "http://example.org/curie",
            "Marie Curie",
            &["Curie", "M. Curie"],
            0.90,
        ));
        linker.add_entity(make_entity(
            "http://example.org/berlin",
            "Berlin",
            &["Berlin City"],
            0.80,
        ));
        linker
    }

    /// Surface texts of `mentions`, in returned order.
    fn mention_texts(mentions: &[EntityMention]) -> Vec<&str> {
        mentions.iter().map(|m| m.text.as_str()).collect()
    }

    // --- entity_count -------------------------------------------------------

    #[test]
    fn test_entity_count_empty() {
        let linker = EntityLinker::new(0.5);
        assert_eq!(linker.entity_count(), 0);
    }

    #[test]
    fn test_entity_count_after_add() {
        let mut linker = EntityLinker::new(0.5);
        linker.add_entity(make_entity("http://x.org/a", "Alpha", &[], 0.5));
        assert_eq!(linker.entity_count(), 1);
        linker.add_entity(make_entity("http://x.org/b", "Beta", &[], 0.5));
        assert_eq!(linker.entity_count(), 2);
    }

    #[test]
    fn test_add_entities_bulk() {
        let mut linker = EntityLinker::new(0.5);
        linker.add_entities(vec![
            make_entity("http://x.org/a", "Alpha", &[], 0.5),
            make_entity("http://x.org/b", "Beta", &[], 0.6),
            make_entity("http://x.org/c", "Gamma", &[], 0.7),
        ]);
        assert_eq!(linker.entity_count(), 3);
    }

    // --- find_by_iri --------------------------------------------------------

    #[test]
    fn test_find_by_iri_exists() {
        let linker = sample_linker();
        let entity = linker
            .find_by_iri("http://example.org/einstein")
            .expect("einstein is in the sample KB");
        assert_eq!(entity.label, "Albert Einstein");
    }

    #[test]
    fn test_find_by_iri_not_found() {
        let linker = sample_linker();
        assert!(linker.find_by_iri("http://example.org/nobody").is_none());
    }

    // --- string_similarity --------------------------------------------------

    #[test]
    fn test_string_similarity_exact() {
        assert!((EntityLinker::string_similarity("hello", "hello") - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_string_similarity_completely_different() {
        // Three substitutions over three characters → similarity 0.
        let sim = EntityLinker::string_similarity("abc", "xyz");
        assert!(sim.abs() < 1e-9, "sim = {}", sim);
    }

    #[test]
    fn test_string_similarity_both_empty() {
        assert!((EntityLinker::string_similarity("", "") - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_string_similarity_one_empty() {
        let sim = EntityLinker::string_similarity("", "hello");
        assert!(sim.abs() < 1e-9);
    }

    #[test]
    fn test_string_similarity_near_match() {
        // One deletion out of eight characters: 1 - 1/8.
        let sim = EntityLinker::string_similarity("Einstein", "Einsten");
        assert!((sim - 0.875).abs() < 1e-9, "sim = {}", sim);
    }

    #[test]
    fn test_string_similarity_range() {
        // kitten → sitting needs 3 edits; the longer string has 7 chars.
        let sim = EntityLinker::string_similarity("kitten", "sitting");
        assert!((0.0..=1.0).contains(&sim));
        assert!((sim - 4.0 / 7.0).abs() < 1e-9, "sim = {}", sim);
    }

    // --- edit_distance -------------------------------------------------------

    #[test]
    fn test_edit_distance_identical() {
        assert_eq!(EntityLinker::edit_distance("abc", "abc"), 0);
    }

    #[test]
    fn test_edit_distance_one_empty() {
        assert_eq!(EntityLinker::edit_distance("", "abc"), 3);
        assert_eq!(EntityLinker::edit_distance("abc", ""), 3);
    }

    #[test]
    fn test_edit_distance_kitten_sitting() {
        assert_eq!(EntityLinker::edit_distance("kitten", "sitting"), 3);
    }

    #[test]
    fn test_edit_distance_sunday_saturday() {
        assert_eq!(EntityLinker::edit_distance("sunday", "saturday"), 3);
    }

    #[test]
    fn test_edit_distance_single_char() {
        assert_eq!(EntityLinker::edit_distance("a", "b"), 1);
        assert_eq!(EntityLinker::edit_distance("a", "a"), 0);
    }

    // --- detect_mentions ----------------------------------------------------

    #[test]
    fn test_detect_mentions_exact_label() {
        let linker = sample_linker();
        let mentions = linker.detect_mentions("Albert Einstein was a physicist.");
        // "Einstein" lies inside "Albert Einstein" and is de-duplicated away.
        assert_eq!(mention_texts(&mentions), vec!["Albert Einstein"]);
        assert_eq!((mentions[0].start_char, mentions[0].end_char), (0, 15));
    }

    #[test]
    fn test_detect_mentions_alias() {
        let linker = sample_linker();
        let mentions = linker.detect_mentions("Curie discovered radium.");
        assert_eq!(mention_texts(&mentions), vec!["Curie"]);
    }

    #[test]
    fn test_detect_mentions_case_insensitive() {
        let linker = sample_linker();
        let mentions = linker.detect_mentions("berlin is a great city.");
        assert_eq!(mention_texts(&mentions), vec!["berlin"]);
        assert_eq!(
            mentions[0].candidates[0].entity.iri,
            "http://example.org/berlin"
        );
    }

    #[test]
    fn test_detect_mentions_multiple() {
        let linker = sample_linker();
        let mentions = linker.detect_mentions("Einstein visited Berlin.");
        // Two distinct mentions, ordered by position.
        assert_eq!(mention_texts(&mentions), vec!["Einstein", "Berlin"]);
    }

    #[test]
    fn test_detect_mentions_no_match() {
        let linker = sample_linker();
        let mentions = linker.detect_mentions("The quick brown fox jumps.");
        assert!(
            mentions.is_empty(),
            "expected no mentions, got {:?}",
            mentions
        );
    }

    // --- link ---------------------------------------------------------------

    #[test]
    fn test_link_basic() {
        let linker = sample_linker();
        let linked = linker
            .link("Albert Einstein won the Nobel Prize.")
            .expect("ok");
        assert_eq!(linked.len(), 1);
        let best = linked[0]
            .best_candidate
            .as_ref()
            .expect("exact label match passes threshold 0.3");
        assert_eq!(best.entity.iri, "http://example.org/einstein");
    }

    #[test]
    fn test_link_empty_text_error() {
        let linker = sample_linker();
        assert!(matches!(linker.link(""), Err(LinkerError::EmptyText)));
        assert!(matches!(linker.link("   "), Err(LinkerError::EmptyText)));
    }

    #[test]
    fn test_link_empty_kb_error() {
        let linker = EntityLinker::new(0.5);
        assert!(matches!(
            linker.link("some text"),
            Err(LinkerError::EmptyKnowledgeBase)
        ));
    }

    #[test]
    fn test_link_confidence_populated() {
        let linker = sample_linker();
        let linked = linker.link("Curie was born in Poland.").expect("ok");
        assert_eq!(linked.len(), 1);
        let best = linked[0]
            .best_candidate
            .as_ref()
            .expect("Curie passes threshold 0.3");
        assert_eq!(best.entity.iri, "http://example.org/curie");
        assert!(linked[0].confidence > 0.0);
        assert!((linked[0].confidence - best.score).abs() < 1e-12);
    }

    // --- link_mention -------------------------------------------------------

    #[test]
    fn test_link_mention_exact() {
        let linker = sample_linker();
        let candidates = linker.link_mention("Einstein");
        assert!(!candidates.is_empty());
        // Top candidate should be Einstein
        assert_eq!(candidates[0].entity.iri, "http://example.org/einstein");
    }

    #[test]
    fn test_link_mention_sorted_descending() {
        let linker = sample_linker();
        let candidates = linker.link_mention("Berlin");
        assert_eq!(
            candidates.first().map(|c| c.entity.iri.as_str()),
            Some("http://example.org/berlin")
        );
        for window in candidates.windows(2) {
            assert!(
                window[0].score >= window[1].score,
                "candidates not sorted: {} < {}",
                window[0].score,
                window[1].score
            );
        }
    }

    #[test]
    fn test_link_mention_returns_all_above_zero() {
        let linker = sample_linker();
        let candidates = linker.link_mention("Curie");
        assert_eq!(
            candidates.first().map(|c| c.entity.iri.as_str()),
            Some("http://example.org/curie")
        );
        // All returned candidates should have positive score
        for c in &candidates {
            assert!(c.score > 0.0);
        }
    }

    // --- threshold ----------------------------------------------------------

    #[test]
    fn test_threshold_filters_low_confidence() {
        let mut linker = EntityLinker::new(0.99); // very high threshold
        linker.add_entity(make_entity(
            "http://x.org/z",
            "Zephyr",
            &[],
            0.1, // low popularity: 0.7 * 1.0 + 0.3 * 0.1 = 0.73 < 0.99
        ));
        let linked = linker.link("There is a zephyr wind.").expect("ok");
        // The mention is still detected, but no candidate passes the threshold.
        assert_eq!(linked.len(), 1);
        assert!(linked[0].best_candidate.is_none());
        assert!(linked[0].confidence.abs() < 1e-12);
    }

    #[test]
    fn test_min_mention_len_filter() {
        let mut linker = EntityLinker::new(0.0);
        linker.add_entity(make_entity("http://x.org/a", "AI", &[], 0.5));
        linker.min_mention_len = 5;
        // "AI" is 2 chars — below the minimum, so it is never reported.
        let mentions = linker.detect_mentions("AI is transforming the world.");
        assert!(mentions.is_empty(), "unexpected mentions: {:?}", mentions);
    }

    // --- Error display -------------------------------------------------------

    #[test]
    fn test_linker_error_display() {
        assert!(LinkerError::EmptyText.to_string().contains("empty"));
        assert!(LinkerError::EmptyKnowledgeBase
            .to_string()
            .contains("no entities"));
        assert!(LinkerError::InvalidEntity("bad".to_string())
            .to_string()
            .contains("bad"));
    }

    // --- Combined score ------------------------------------------------------

    #[test]
    fn test_combined_score_perfect() {
        // Exact label match: 0.7 * 1.0 + 0.3 * 0.95 (Einstein's popularity) = 0.985
        let linker = sample_linker();
        let candidates = linker.link_mention("Albert Einstein");
        let top = candidates.first().expect("at least one candidate");
        assert_eq!(top.entity.iri, "http://example.org/einstein");
        assert!((top.string_similarity - 1.0).abs() < 1e-9);
        assert!((top.score - 0.985).abs() < 1e-9, "score = {}", top.score);
    }

    // --- Alias detection ----------------------------------------------------

    #[test]
    fn test_alias_detection_full_alias() {
        let linker = sample_linker();
        let mentions = linker.detect_mentions("A. Einstein changed physics.");
        // The full alias wins over the contained "Einstein" alias.
        assert_eq!(mention_texts(&mentions), vec!["A. Einstein"]);
    }
}