1use std::collections::HashMap;
7
/// A span of surface text suspected to name an entity.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EntityMention {
    /// The raw surface text of the mention.
    pub text: String,
    /// Start byte offset of the mention in the source text.
    pub start: usize,
    /// End byte offset (exclusive) of the mention in the source text.
    pub end: usize,
}

impl EntityMention {
    /// Builds a mention from its surface text and byte offsets.
    pub fn new(text: impl Into<String>, start: usize, end: usize) -> Self {
        let text = text.into();
        EntityMention { text, start, end }
    }
}
33
/// A knowledge-base entry proposed as the target of a mention.
#[derive(Debug, Clone)]
pub struct EntityCandidate {
    /// The IRI identifying the entity in the knowledge base.
    pub iri: String,
    /// Human-readable label of the entity.
    pub label: String,
    /// Match score; starts at 0.0 and is set during candidate scoring.
    pub score: f64,
    /// Alternative names for the entity.
    pub aliases: Vec<String>,
}

impl EntityCandidate {
    /// Builds a candidate with an initial score of 0.0; the real score is
    /// filled in by candidate generation / disambiguation.
    fn new(iri: impl Into<String>, label: impl Into<String>, aliases: Vec<String>) -> Self {
        let iri = iri.into();
        let label = label.into();
        EntityCandidate {
            iri,
            label,
            score: 0.0,
            aliases,
        }
    }
}
57
/// The result of linking one detected mention to a knowledge-base entity.
#[derive(Debug, Clone)]
pub struct LinkedEntity {
    /// The surface-text mention detected in the input.
    pub mention: EntityMention,
    /// The knowledge-base candidate chosen by disambiguation.
    pub entity: EntityCandidate,
    /// Combined match/context score; `link` only keeps values >= `nil_threshold`.
    pub confidence: f64,
}
68
69pub struct TfIdfIndex {
75 docs: Vec<(String, HashMap<String, f64>)>,
77 idf: HashMap<String, f64>,
79}
80
81impl TfIdfIndex {
82 pub fn new() -> Self {
84 Self {
85 docs: Vec::new(),
86 idf: HashMap::new(),
87 }
88 }
89
90 pub fn add_document(&mut self, doc_id: impl Into<String>, text: &str) {
92 let tokens = tokenize(text);
93 let total = tokens.len() as f64;
94 if total == 0.0 {
95 return;
96 }
97 let mut tf: HashMap<String, f64> = HashMap::new();
98 for tok in &tokens {
99 *tf.entry(tok.clone()).or_insert(0.0) += 1.0 / total;
100 }
101 self.docs.push((doc_id.into(), tf));
102 }
103
104 pub fn build(&mut self) {
106 let n = self.docs.len() as f64;
107 let mut df: HashMap<String, usize> = HashMap::new();
108 for (_, tf) in &self.docs {
109 for term in tf.keys() {
110 *df.entry(term.clone()).or_insert(0) += 1;
111 }
112 }
113 self.idf.clear();
114 for (term, count) in df {
115 self.idf.insert(term, (n / count as f64).ln() + 1.0);
116 }
117 }
118
119 pub fn similarity(&self, query: &str, doc_id: &str) -> f64 {
121 let doc = match self.docs.iter().find(|(id, _)| id == doc_id) {
122 Some((_, tf)) => tf,
123 None => return 0.0,
124 };
125
126 let q_tokens = tokenize(query);
127 let q_total = q_tokens.len() as f64;
128 if q_total == 0.0 {
129 return 0.0;
130 }
131 let mut q_tf: HashMap<String, f64> = HashMap::new();
132 for tok in &q_tokens {
133 *q_tf.entry(tok.clone()).or_insert(0.0) += 1.0 / q_total;
134 }
135
136 let mut dot = 0.0_f64;
138 let mut q_norm = 0.0_f64;
139 let mut d_norm = 0.0_f64;
140
141 let all_terms: std::collections::HashSet<&String> = q_tf.keys().chain(doc.keys()).collect();
142
143 for term in all_terms {
144 let idf = self.idf.get(term).copied().unwrap_or(1.0);
145 let q_val = q_tf.get(term).copied().unwrap_or(0.0) * idf;
146 let d_val = doc.get(term).copied().unwrap_or(0.0) * idf;
147 dot += q_val * d_val;
148 q_norm += q_val * q_val;
149 d_norm += d_val * d_val;
150 }
151
152 let denom = q_norm.sqrt() * d_norm.sqrt();
153 if denom < 1e-15 {
154 0.0
155 } else {
156 (dot / denom).clamp(0.0, 1.0)
157 }
158 }
159
160 pub fn doc_count(&self) -> usize {
162 self.docs.len()
163 }
164}
165
166impl Default for TfIdfIndex {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
/// Links entity mentions in free text to IRIs in a small in-memory
/// knowledge base, using string similarity plus TF-IDF context overlap.
pub struct EntityLinker {
    /// Knowledge base: IRI -> label + aliases.
    kb: HashMap<String, KbEntry>,
    /// TF-IDF index over each entity's "label + aliases" context document.
    tfidf: TfIdfIndex,
    /// Minimum combined confidence for `link` to keep a match (NIL threshold).
    pub nil_threshold: f64,
}
185
/// Internal knowledge-base record: an entity's display label and its aliases.
struct KbEntry {
    label: String,
    aliases: Vec<String>,
}
190
impl EntityLinker {
    /// Creates an empty linker with a default NIL threshold of 0.1.
    pub fn new() -> Self {
        Self {
            kb: HashMap::new(),
            tfidf: TfIdfIndex::new(),
            nil_threshold: 0.1,
        }
    }

    /// Registers an entity under `iri` with a display `label` and `aliases`.
    ///
    /// The label and aliases are joined into a TF-IDF "context document"
    /// keyed by the IRI, used later for disambiguation. Call `build_index`
    /// once after the last entity is added.
    pub fn add_entity(
        &mut self,
        iri: impl Into<String>,
        label: impl Into<String>,
        aliases: &[&str],
    ) {
        let iri = iri.into();
        let label = label.into();
        let aliases: Vec<String> = aliases.iter().map(|s| s.to_string()).collect();
        // Context document = label followed by all aliases, space-joined.
        let context = format!("{} {}", label, aliases.join(" "));
        self.tfidf.add_document(iri.clone(), &context);
        self.kb.insert(iri, KbEntry { label, aliases });
    }

    /// Recomputes IDF weights; must be called after the last `add_entity`.
    pub fn build_index(&mut self) {
        self.tfidf.build();
    }

    /// End-to-end pipeline: detect mentions in `text`, generate candidates
    /// for each, disambiguate, and keep only links whose confidence clears
    /// `nil_threshold`.
    pub fn link(&self, text: &str) -> Vec<LinkedEntity> {
        let mentions = detect_mentions(text);
        let mut linked = Vec::new();

        for mention in mentions {
            let candidates = self.candidate_generation(&mention.text);
            if candidates.is_empty() {
                continue;
            }
            let best = self.disambiguate(&mention, &candidates, text);
            if let Some(entity) = best {
                // `disambiguate` stores the combined score in `entity.score`.
                let confidence = entity.score;
                if confidence >= self.nil_threshold {
                    linked.push(LinkedEntity {
                        mention,
                        entity,
                        confidence,
                    });
                }
            }
        }
        linked
    }

    /// Scores every KB entry against `mention` (case-insensitive
    /// Jaro-Winkler on the label and each alias, best match wins) and
    /// returns entries scoring above 0.6, sorted best-first.
    ///
    /// NOTE(review): ties in the sort keep HashMap iteration order, which
    /// is nondeterministic across runs.
    pub fn candidate_generation(&self, mention: &str) -> Vec<EntityCandidate> {
        let mention_lower = mention.to_lowercase();
        let mut candidates: Vec<EntityCandidate> = self
            .kb
            .iter()
            .filter_map(|(iri, entry)| {
                let label_score = jaro_winkler(&mention_lower, &entry.label.to_lowercase());
                // Best alias score; 0.0 if there are no aliases.
                let alias_score = entry
                    .aliases
                    .iter()
                    .map(|a| jaro_winkler(&mention_lower, &a.to_lowercase()))
                    .fold(0.0_f64, f64::max);
                let score = label_score.max(alias_score);
                if score > 0.6 {
                    let mut c = EntityCandidate::new(
                        iri.clone(),
                        entry.label.clone(),
                        entry.aliases.clone(),
                    );
                    c.score = score;
                    Some(c)
                } else {
                    None
                }
            })
            .collect();

        // Descending by score; NaN-safe via partial_cmp fallback.
        candidates.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        candidates
    }

    /// Picks the candidate maximizing 0.6 * string score + 0.4 * TF-IDF
    /// context similarity, and returns it with `score` set to that
    /// combined value. Returns `None` only for an empty candidate slice.
    ///
    /// On a tie the earliest candidate in `candidates` wins (strict `>`).
    pub fn disambiguate(
        &self,
        _mention: &EntityMention,
        candidates: &[EntityCandidate],
        context: &str,
    ) -> Option<EntityCandidate> {
        if candidates.is_empty() {
            return None;
        }

        let mut best_score = f64::NEG_INFINITY;
        let mut best: Option<EntityCandidate> = None;

        for cand in candidates {
            // Similarity between the full input text and the candidate's
            // indexed "label + aliases" context document.
            let ctx_score = self.tfidf.similarity(context, &cand.iri);
            let combined = cand.score * 0.6 + ctx_score * 0.4;
            if combined > best_score {
                best_score = combined;
                let mut winner = cand.clone();
                winner.score = combined;
                best = Some(winner);
            }
        }
        best
    }

    /// Number of entities registered in the knowledge base.
    pub fn entity_count(&self) -> usize {
        self.kb.len()
    }
}
318
319impl Default for EntityLinker {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
325fn detect_mentions(text: &str) -> Vec<EntityMention> {
332 let mut mentions = Vec::new();
333 let mut chars = text.char_indices().peekable();
334 let bytes = text.as_bytes();
335 let len = bytes.len();
336
337 while let Some((start, ch)) = chars.next() {
338 if ch.is_uppercase() {
339 let mut end = start + ch.len_utf8();
341 while end < len {
342 let next_ch = text[end..].chars().next().unwrap_or('\0');
343 if next_ch.is_alphanumeric() || next_ch == ' ' {
344 if next_ch == ' ' {
346 let after_space = end + 1;
347 if after_space < len {
348 let nc2 = text[after_space..].chars().next().unwrap_or('\0');
349 if nc2.is_uppercase() {
350 end = after_space + nc2.len_utf8();
351 let _ = chars.next(); let _ = chars.next(); continue;
355 }
356 }
357 break;
358 }
359 end += next_ch.len_utf8();
360 let _ = chars.next();
361 } else {
362 break;
363 }
364 }
365 let mention_text = text[start..end].trim().to_string();
366 if mention_text.len() >= 2 {
367 mentions.push(EntityMention::new(mention_text, start, end));
368 }
369 }
370 }
371 mentions
372}
373
374fn jaro_winkler(s1: &str, s2: &str) -> f64 {
376 if s1 == s2 {
377 return 1.0;
378 }
379 let jaro = jaro(s1, s2);
380 let prefix_len = s1
381 .chars()
382 .zip(s2.chars())
383 .take(4)
384 .take_while(|(a, b)| a == b)
385 .count();
386 let p = 0.1_f64;
387 jaro + (prefix_len as f64 * p * (1.0 - jaro))
388}
389
/// Plain Jaro similarity over the characters of `s1` and `s2`.
///
/// Two equal characters "match" when they lie within the standard window
/// of `max(len)/2 - 1`; transpositions are matched characters appearing
/// in a different relative order.
fn jaro(s1: &str, s2: &str) -> f64 {
    let a: Vec<char> = s1.chars().collect();
    let b: Vec<char> = s2.chars().collect();
    let (la, lb) = (a.len(), b.len());
    // Two empty strings are identical; one empty string matches nothing.
    if la == 0 && lb == 0 {
        return 1.0;
    }
    if la == 0 || lb == 0 {
        return 0.0;
    }

    let window = (la.max(lb) / 2).saturating_sub(1);
    let mut matched_a = vec![false; la];
    let mut matched_b = vec![false; lb];

    // Greedily pair each char of `a` with the first unmatched equal char
    // of `b` inside its window.
    for i in 0..la {
        let lo = i.saturating_sub(window);
        let hi = (i + window + 1).min(lb);
        for j in lo..hi {
            if !matched_b[j] && a[i] == b[j] {
                matched_a[i] = true;
                matched_b[j] = true;
                break;
            }
        }
    }

    // Matched characters from each side, in their original order.
    let m1: Vec<char> = (0..la).filter(|&i| matched_a[i]).map(|i| a[i]).collect();
    let m2: Vec<char> = (0..lb).filter(|&j| matched_b[j]).map(|j| b[j]).collect();
    let m = m1.len();
    if m == 0 {
        return 0.0;
    }
    // Positions where the matched sequences disagree; halved in the formula.
    let t = m1.iter().zip(&m2).filter(|(x, y)| x != y).count();

    let m = m as f64;
    (m / la as f64 + m / lb as f64 + (m - t as f64 / 2.0) / m) / 3.0
}
442
/// Splits `text` on every non-alphanumeric character and lowercases each
/// non-empty token.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for piece in text.split(|c: char| !c.is_alphanumeric()) {
        if piece.is_empty() {
            continue;
        }
        tokens.push(piece.to_lowercase());
    }
    tokens
}
450
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: a linker populated with three scientists and a built
    // TF-IDF index, so individual tests don't repeat the setup.
    fn linker_with_persons() -> EntityLinker {
        let mut linker = EntityLinker::new();
        linker.add_entity(
            "http://example.org/Albert_Einstein",
            "Albert Einstein",
            &["Einstein", "A. Einstein"],
        );
        linker.add_entity(
            "http://example.org/Marie_Curie",
            "Marie Curie",
            &["Curie", "M. Curie"],
        );
        linker.add_entity(
            "http://example.org/Isaac_Newton",
            "Isaac Newton",
            &["Newton"],
        );
        linker.build_index();
        linker
    }

    // --- EntityMention ---

    #[test]
    fn test_mention_new() {
        let m = EntityMention::new("Alice", 0, 5);
        assert_eq!(m.text, "Alice");
        assert_eq!(m.start, 0);
        assert_eq!(m.end, 5);
    }

    #[test]
    fn test_mention_equality() {
        let m1 = EntityMention::new("Bob", 0, 3);
        let m2 = EntityMention::new("Bob", 0, 3);
        assert_eq!(m1, m2);
    }

    // --- TfIdfIndex ---

    #[test]
    fn test_tfidf_add_document() {
        let mut idx = TfIdfIndex::new();
        idx.add_document("doc1", "quantum physics relativity");
        idx.build();
        assert_eq!(idx.doc_count(), 1);
    }

    #[test]
    fn test_tfidf_similarity_same_doc() {
        let mut idx = TfIdfIndex::new();
        idx.add_document("doc1", "quantum physics relativity");
        idx.build();
        let sim = idx.similarity("quantum physics", "doc1");
        assert!(sim > 0.0, "similarity should be > 0, got {sim}");
    }

    #[test]
    fn test_tfidf_similarity_different_content() {
        let mut idx = TfIdfIndex::new();
        idx.add_document("doc1", "quantum physics relativity");
        idx.add_document("doc2", "cooking recipes baking bread");
        idx.build();
        let s1 = idx.similarity("quantum physics", "doc1");
        let s2 = idx.similarity("quantum physics", "doc2");
        assert!(s1 > s2, "physics query should match doc1 better");
    }

    #[test]
    fn test_tfidf_unknown_doc() {
        let idx = TfIdfIndex::new();
        assert_eq!(idx.similarity("anything", "unknown"), 0.0);
    }

    #[test]
    fn test_tfidf_empty_query() {
        let mut idx = TfIdfIndex::new();
        idx.add_document("d", "hello world");
        idx.build();
        assert_eq!(idx.similarity("", "d"), 0.0);
    }

    #[test]
    fn test_tfidf_default() {
        let idx = TfIdfIndex::default();
        assert_eq!(idx.doc_count(), 0);
    }

    // --- EntityLinker basics ---

    #[test]
    fn test_linker_entity_count() {
        let linker = linker_with_persons();
        assert_eq!(linker.entity_count(), 3);
    }

    #[test]
    fn test_linker_default() {
        let linker = EntityLinker::default();
        assert_eq!(linker.entity_count(), 0);
    }

    // --- Candidate generation ---

    #[test]
    fn test_candidate_generation_exact_label() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Einstein");
        assert!(!cands.is_empty());
        assert!(cands[0].iri.contains("Einstein"));
    }

    #[test]
    fn test_candidate_generation_partial() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Newton");
        assert!(!cands.is_empty());
        assert!(cands.iter().any(|c| c.iri.contains("Newton")));
    }

    #[test]
    fn test_candidate_generation_no_match() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Zorkblat");
        assert!(cands.is_empty());
    }

    #[test]
    fn test_candidate_generation_sorted_by_score() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Curie");
        // Scores must be non-increasing (best candidate first).
        for i in 1..cands.len() {
            assert!(cands[i - 1].score >= cands[i].score);
        }
    }

    #[test]
    fn test_candidate_generation_alias_match() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Curie");
        assert!(cands.iter().any(|c| c.iri.contains("Curie")));
    }

    // --- Disambiguation ---

    #[test]
    fn test_disambiguate_returns_best() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Einstein");
        let mention = EntityMention::new("Einstein", 0, 8);
        let best = linker.disambiguate(&mention, &cands, "Einstein worked on relativity");
        assert!(best.is_some());
        assert!(best.expect("should succeed").iri.contains("Einstein"));
    }

    #[test]
    fn test_disambiguate_empty_candidates() {
        let linker = linker_with_persons();
        let mention = EntityMention::new("X", 0, 1);
        let result = linker.disambiguate(&mention, &[], "context");
        assert!(result.is_none());
    }

    #[test]
    fn test_disambiguate_score_in_range() {
        let linker = linker_with_persons();
        let cands = linker.candidate_generation("Newton");
        let mention = EntityMention::new("Newton", 0, 6);
        if let Some(best) = linker.disambiguate(&mention, &cands, "gravity laws Newton") {
            assert!((0.0..=1.0).contains(&best.score));
        }
    }

    // --- End-to-end linking ---

    #[test]
    fn test_link_finds_entity() {
        let linker = linker_with_persons();
        let linked = linker.link("Einstein developed relativity theory.");
        assert!(!linked.is_empty());
        assert!(linked[0].entity.iri.contains("Einstein"));
    }

    #[test]
    fn test_link_confidence_above_threshold() {
        let linker = linker_with_persons();
        let linked = linker.link("Newton formulated laws of motion.");
        for le in &linked {
            assert!(le.confidence >= linker.nil_threshold);
        }
    }

    #[test]
    fn test_link_no_entities_in_empty_text() {
        let linker = linker_with_persons();
        let linked = linker.link("");
        assert!(linked.is_empty());
    }

    #[test]
    fn test_link_result_fields() {
        let linker = linker_with_persons();
        let linked = linker.link("Einstein and Curie were scientists.");
        for le in &linked {
            assert!(!le.mention.text.is_empty());
            assert!(!le.entity.iri.is_empty());
            assert!((0.0..=1.0).contains(&le.confidence));
        }
    }

    // --- String similarity ---

    #[test]
    fn test_jaro_winkler_identical() {
        assert!((jaro_winkler("hello", "hello") - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_jaro_winkler_completely_different() {
        let score = jaro_winkler("abc", "xyz");
        assert!(score < 0.5, "score = {score}");
    }

    #[test]
    fn test_jaro_winkler_prefix_boost() {
        // Transposed suffix but long shared prefix should still score high.
        let jw = jaro_winkler("einstein", "einstien");
        assert!(jw > 0.8, "score = {jw}");
    }

    #[test]
    fn test_jaro_winkler_empty_strings() {
        assert!((jaro("", "") - 1.0).abs() < 1e-9);
        assert!((jaro("abc", "") - 0.0).abs() < 1e-9);
    }

    // --- Mention detection ---

    #[test]
    fn test_detect_mentions_finds_capitalized() {
        let mentions = detect_mentions("Alice and Bob went to Paris.");
        let texts: Vec<&str> = mentions.iter().map(|m| m.text.as_str()).collect();
        assert!(texts
            .iter()
            .any(|t| *t == "Alice" || t.starts_with("Alice")));
    }

    #[test]
    fn test_detect_mentions_empty() {
        assert!(detect_mentions("").is_empty());
    }

    #[test]
    fn test_detect_mentions_lowercase_only() {
        let mentions = detect_mentions("all lowercase words here");
        assert!(mentions.is_empty());
    }

    // --- Tokenizer ---

    #[test]
    fn test_tokenize_basic() {
        let tokens = tokenize("Hello World");
        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_tokenize_empty() {
        assert!(tokenize("").is_empty());
    }

    #[test]
    fn test_tokenize_punctuation_split() {
        let tokens = tokenize("foo, bar; baz.");
        assert_eq!(tokens, vec!["foo", "bar", "baz"]);
    }

    // --- Integration ---

    #[test]
    fn test_full_pipeline() {
        let mut linker = EntityLinker::new();
        linker.add_entity("http://ex.org/Paris", "Paris", &["City of Light"]);
        linker.add_entity("http://ex.org/London", "London", &["British capital"]);
        linker.build_index();

        let linked = linker.link("Paris is a famous city in France.");
        if !linked.is_empty() {
            assert!(linked[0].entity.iri.contains("Paris"));
        }
    }

    #[test]
    fn test_nil_threshold_filters_low_confidence() {
        let mut linker = EntityLinker::new();
        linker.add_entity("http://ex.org/X", "Xyzzy", &[]);
        linker.build_index();
        linker.nil_threshold = 0.99; let linked = linker.link("Xyzzy something");
        for le in &linked {
            assert!(le.confidence >= 0.99);
        }
    }
}