Skip to main content

anno/linking/
linker.rs

1//! Main entity linker combining candidate generation, ranking, and NIL detection.
2
3use serde::{Deserialize, Serialize};
4use std::sync::Arc;
5
6use super::candidate::{
7    Candidate, CandidateGenerator, CandidateSource, DictionaryCandidateGenerator,
8};
9use super::nil::{NilAction, NilDetector, NilReason};
10use anno_core::EntityType;
11
12/// A mention to be linked.
13#[derive(Debug, Clone)]
14pub struct Mention {
15    /// Mention text
16    pub text: String,
17    /// Start offset in document
18    pub start: usize,
19    /// End offset in document
20    pub end: usize,
21    /// Entity type from NER (optional)
22    pub entity_type: Option<EntityType>,
23}
24
25impl Mention {
26    /// Create a new mention.
27    pub fn new(text: &str, start: usize, end: usize) -> Self {
28        Self {
29            text: text.to_string(),
30            start,
31            end,
32            entity_type: None,
33        }
34    }
35
36    /// Set entity type.
37    pub fn with_type(mut self, entity_type: EntityType) -> Self {
38        self.entity_type = Some(entity_type);
39        self
40    }
41}
42
/// The outcome of linking one [`Mention`]: either a KB entry or a NIL verdict.
///
/// As produced by [`EntityLinker::link`]:
/// - linked (`is_nil == false`): `kb_id`, `label` and `iri` come from the
///   top-ranked candidate, `source` is that candidate's source, and
///   `nil_reason`/`nil_action` are `None`;
/// - NIL (`is_nil == true`): `kb_id`, `label` and `iri` are `None`, `source` is
///   `CandidateSource::default()`, and `nil_action` is `Some`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinkedEntity {
    /// Original mention text (copied from the input mention).
    pub mention_text: String,
    /// Start offset (copied from the input mention).
    pub start: usize,
    /// End offset (copied from the input mention).
    pub end: usize,
    /// Linked KB ID (`None` if NIL).
    pub kb_id: Option<String>,
    /// KB source of the linked candidate; the default source for NIL results.
    pub source: CandidateSource,
    /// Canonical label from the KB (`None` if NIL).
    pub label: Option<String>,
    /// Full IRI/URI of the linked entry (`None` if NIL).
    pub iri: Option<String>,
    /// Linking confidence: the top candidate's score when linked, or the NIL
    /// detector's confidence when NIL.
    pub confidence: f64,
    /// Whether the mention was judged NIL (not linked to any KB entry).
    pub is_nil: bool,
    /// NIL reason reported by the NIL detector, if any.
    pub nil_reason: Option<NilReason>,
    /// Suggested NIL action; set for NIL results, `None` for linked ones.
    pub nil_action: Option<NilAction>,
    /// Other ranked candidates, for debugging/review: the runners-up after the
    /// chosen candidate when linked, or the top-ranked candidates when NIL.
    /// Empty when the linker was built with alternatives disabled.
    pub alternatives: Vec<CandidateSummary>,
}
71
/// Compact view of a [`Candidate`], used in [`LinkedEntity::alternatives`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CandidateSummary {
    /// KB ID of the candidate.
    pub kb_id: String,
    /// Candidate label.
    pub label: String,
    /// Candidate score, copied from the candidate at summary time.
    pub score: f64,
}
82
83impl From<&Candidate> for CandidateSummary {
84    fn from(c: &Candidate) -> Self {
85        Self {
86            kb_id: c.kb_id.clone(),
87            label: c.label.clone(),
88            score: c.score,
89        }
90    }
91}
92
/// Aggregate linking result for a batch of mentions from one document.
///
/// When produced by [`EntityLinker::link`], `entities` holds one entry per input
/// mention in input order, and `total_mentions` is the number of input mentions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinkingResult {
    /// Linked (or NIL) entities.
    pub entities: Vec<LinkedEntity>,
    /// Total mentions processed.
    pub total_mentions: usize,
    /// Mentions successfully linked to a KB entry.
    pub linked_count: usize,
    /// Mentions judged NIL.
    pub nil_count: usize,
    /// Mean confidence over *linked* mentions only (NIL mentions are excluded);
    /// `0.0` when nothing was linked.
    pub avg_confidence: f64,
}
107
108impl LinkingResult {
109    /// Get linking rate.
110    pub fn linking_rate(&self) -> f64 {
111        if self.total_mentions == 0 {
112            0.0
113        } else {
114            self.linked_count as f64 / self.total_mentions as f64
115        }
116    }
117}
118
/// Entity linker combining candidate generation, ranking, and NIL detection.
///
/// Construct one with [`EntityLinker::builder`].
pub struct EntityLinker {
    /// Source of knowledge-base candidates for each mention.
    generator: Arc<dyn CandidateGenerator>,
    /// Decides, from the ranked candidates, whether a mention is NIL.
    nil_detector: NilDetector,
    /// Maximum number of candidates requested from the generator per mention.
    max_candidates: usize,
    /// Whether `LinkedEntity::alternatives` is populated in the output.
    include_alternatives: bool,
}
130
131impl EntityLinker {
132    /// Create a builder.
133    pub fn builder() -> EntityLinkerBuilder {
134        EntityLinkerBuilder::default()
135    }
136
137    /// Link mentions in a document.
138    pub fn link(&self, mentions: &[Mention], context: &str) -> LinkingResult {
139        let mut entities = Vec::with_capacity(mentions.len());
140        let mut linked_count = 0;
141        let mut nil_count = 0;
142        let mut total_confidence = 0.0;
143
144        for mention in mentions {
145            let entity_type_str = mention.entity_type.as_ref().map(|et| et.to_string());
146
147            // Generate candidates
148            let mut candidates = self.generator.generate(
149                &mention.text,
150                context,
151                entity_type_str.as_deref(),
152                self.max_candidates,
153            );
154
155            // Score candidates
156            for c in &mut candidates {
157                c.compute_score();
158            }
159            candidates.sort_by(|a, b| {
160                b.score
161                    .partial_cmp(&a.score)
162                    .unwrap_or(std::cmp::Ordering::Equal)
163            });
164
165            // NIL analysis
166            let nil_analysis =
167                self.nil_detector
168                    .analyze(&mention.text, &candidates, entity_type_str.as_deref());
169
170            let linked_entity = if nil_analysis.is_nil {
171                nil_count += 1;
172
173                LinkedEntity {
174                    mention_text: mention.text.clone(),
175                    start: mention.start,
176                    end: mention.end,
177                    kb_id: None,
178                    source: CandidateSource::default(),
179                    label: None,
180                    iri: None,
181                    confidence: nil_analysis.confidence,
182                    is_nil: true,
183                    nil_reason: nil_analysis.reason,
184                    nil_action: Some(nil_analysis.action),
185                    alternatives: if self.include_alternatives {
186                        candidates
187                            .iter()
188                            .take(5)
189                            .map(CandidateSummary::from)
190                            .collect()
191                    } else {
192                        Vec::new()
193                    },
194                }
195            } else {
196                linked_count += 1;
197                let top_candidate = &candidates[0];
198                total_confidence += top_candidate.score;
199
200                LinkedEntity {
201                    mention_text: mention.text.clone(),
202                    start: mention.start,
203                    end: mention.end,
204                    kb_id: Some(top_candidate.kb_id.clone()),
205                    source: top_candidate.source.clone(),
206                    label: Some(top_candidate.label.clone()),
207                    iri: Some(top_candidate.to_iri()),
208                    confidence: top_candidate.score,
209                    is_nil: false,
210                    nil_reason: None,
211                    nil_action: None,
212                    alternatives: if self.include_alternatives && candidates.len() > 1 {
213                        candidates[1..]
214                            .iter()
215                            .take(4)
216                            .map(CandidateSummary::from)
217                            .collect()
218                    } else {
219                        Vec::new()
220                    },
221                }
222            };
223
224            entities.push(linked_entity);
225        }
226
227        let avg_confidence = if linked_count > 0 {
228            total_confidence / linked_count as f64
229        } else {
230            0.0
231        };
232
233        LinkingResult {
234            entities,
235            total_mentions: mentions.len(),
236            linked_count,
237            nil_count,
238            avg_confidence,
239        }
240    }
241
242    /// Link a single mention (convenience method).
243    pub fn link_one(
244        &self,
245        mention: &str,
246        context: &str,
247        entity_type: Option<EntityType>,
248    ) -> Option<LinkedEntity> {
249        let m = if let Some(et) = entity_type {
250            Mention::new(mention, 0, mention.len()).with_type(et)
251        } else {
252            Mention::new(mention, 0, mention.len())
253        };
254
255        let result = self.link(&[m], context);
256        result.entities.into_iter().next()
257    }
258}
259
/// Builder for [`EntityLinker`].
///
/// Obtain one via [`EntityLinker::builder`] and finish with `build`.
pub struct EntityLinkerBuilder {
    /// Custom candidate generator; `None` makes `build` fall back to the
    /// dictionary generator with well-known entries.
    generator: Option<Arc<dyn CandidateGenerator>>,
    /// Score threshold handed to the NIL detector.
    nil_threshold: f64,
    /// Maximum candidates requested from the generator per mention.
    max_candidates: usize,
    /// Whether linking results carry alternative candidates.
    include_alternatives: bool,
}
267
268impl Default for EntityLinkerBuilder {
269    fn default() -> Self {
270        Self {
271            generator: None,
272            nil_threshold: 0.3,
273            max_candidates: 20,
274            include_alternatives: true,
275        }
276    }
277}
278
279impl EntityLinkerBuilder {
280    /// Set the candidate generator.
281    pub fn with_candidate_generator<G: CandidateGenerator + 'static>(mut self, gen: G) -> Self {
282        self.generator = Some(Arc::new(gen));
283        self
284    }
285
286    /// Set NIL threshold.
287    pub fn with_nil_threshold(mut self, threshold: f64) -> Self {
288        self.nil_threshold = threshold;
289        self
290    }
291
292    /// Set max candidates.
293    pub fn with_max_candidates(mut self, max: usize) -> Self {
294        self.max_candidates = max;
295        self
296    }
297
298    /// Set whether to include alternatives.
299    pub fn include_alternatives(mut self, include: bool) -> Self {
300        self.include_alternatives = include;
301        self
302    }
303
304    /// Build the linker.
305    pub fn build(self) -> EntityLinker {
306        let generator = self
307            .generator
308            .unwrap_or_else(|| Arc::new(DictionaryCandidateGenerator::new().with_well_known()));
309
310        EntityLinker {
311            generator,
312            nil_detector: NilDetector::new().with_score_threshold(self.nil_threshold),
313            max_candidates: self.max_candidates,
314            include_alternatives: self.include_alternatives,
315        }
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_entity_linker_basic() {
        let linker = EntityLinker::builder().build();

        // "Einstein" occupies 7..15 in the context below.
        let mentions = vec![Mention::new("Einstein", 7, 15).with_type(EntityType::Person)];

        let result = linker.link(&mentions, "Albert Einstein was a physicist.");

        assert_eq!(result.total_mentions, 1);
        // Whether it links depends on fuzzy matching, but exactly one entity
        // is always produced and it is counted exactly once.
        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.linked_count + result.nil_count, 1);
        assert_eq!((result.entities[0].start, result.entities[0].end), (7, 15));
    }

    #[test]
    fn test_entity_linker_known_entity() {
        let linker = EntityLinker::builder().with_nil_threshold(0.1).build();

        let linked = linker.link_one(
            "Albert Einstein",
            "He was a physicist.",
            Some(EntityType::Person),
        );

        if let Some(entity) = linked {
            if !entity.is_nil {
                assert!(entity.kb_id.is_some());
                assert!(entity.iri.as_ref().unwrap().contains("wikidata"));
            }
        }
    }

    #[test]
    fn test_entity_linker_nil() {
        let linker = EntityLinker::builder().build();

        let linked = linker.link_one("Xyzzy Qwerty Asdf", "Unknown person.", None);

        if let Some(entity) = linked {
            // Should be NIL
            assert!(entity.is_nil || entity.confidence < 0.5);
        }
    }

    #[test]
    fn test_linking_result_stats() {
        let result = LinkingResult {
            entities: Vec::new(),
            total_mentions: 10,
            linked_count: 7,
            nil_count: 3,
            avg_confidence: 0.8,
        };

        assert!((result.linking_rate() - 0.7).abs() < 0.001);
    }

    // === Additional tests for coverage ===

    #[test]
    fn test_multilingual_entity_linking() {
        // Test CJK entities (per multicultural guidelines)
        let linker = EntityLinker::builder().with_nil_threshold(0.1).build();

        // Chinese name
        let linked = linker.link_one("北京", "Visit Beijing, China.", None);
        if let Some(entity) = &linked {
            // Beijing should be in our KB
            if !entity.is_nil {
                assert!(entity.kb_id.is_some());
            }
        }

        // Japanese name
        let linked = linker.link_one("東京", "Tokyo is in Japan.", None);
        assert!(linked.is_some()); // Should at least return a result
    }

    #[test]
    fn test_entity_type_aware_linking() {
        let linker = EntityLinker::builder().build();

        // Same text, different entity types should still work
        let person = linker.link_one(
            "Apple",
            "Steve Jobs founded Apple.",
            Some(EntityType::Person),
        );
        let org = linker.link_one(
            "Apple",
            "Apple is a tech company.",
            Some(EntityType::Organization),
        );

        // Both should return results (possibly different)
        assert!(person.is_some());
        assert!(org.is_some());
    }

    #[test]
    fn test_batch_linking_multiple_mentions() {
        let linker = EntityLinker::builder().build();

        let context = "Google and Microsoft and Apple are tech giants.";
        let mentions = vec![
            Mention::new("Google", 0, 6).with_type(EntityType::Organization),
            Mention::new("Microsoft", 11, 20).with_type(EntityType::Organization),
            Mention::new("Apple", 25, 30).with_type(EntityType::Organization),
        ];

        let result = linker.link(&mentions, context);

        assert_eq!(result.total_mentions, 3);
        // `link` produces exactly one entity per mention, in input order.
        assert_eq!(result.entities.len(), 3);
        assert_eq!(result.linked_count + result.nil_count, 3);
        for (mention, entity) in mentions.iter().zip(&result.entities) {
            assert_eq!(entity.mention_text, mention.text);
            assert_eq!((entity.start, entity.end), (mention.start, mention.end));
            assert_eq!(
                &context[entity.start..entity.end],
                entity.mention_text.as_str()
            );
        }
    }

    #[test]
    fn test_empty_mentions() {
        let linker = EntityLinker::builder().build();

        let result = linker.link(&[], "Some text without mentions.");

        assert_eq!(result.total_mentions, 0);
        assert_eq!(result.linked_count, 0);
        assert_eq!(result.nil_count, 0);
        assert!(result.entities.is_empty());
        assert_eq!(result.avg_confidence, 0.0);
        assert_eq!(result.linking_rate(), 0.0);
    }

    #[test]
    fn test_very_short_mention() {
        let linker = EntityLinker::builder().build();

        // Single character mentions are likely noise
        let linked = linker.link_one("X", "X marks the spot.", None);

        // Should handle gracefully (probably NIL or low confidence)
        if let Some(entity) = linked {
            // Short mentions typically get flagged as noisy
            assert!(entity.is_nil || entity.confidence < 0.3);
        }
    }

    #[test]
    fn test_mention_builder_pattern() {
        let mention = Mention::new("Test", 0, 4).with_type(EntityType::Person);

        assert_eq!(mention.text, "Test");
        assert_eq!(mention.start, 0);
        assert_eq!(mention.end, 4);
        assert_eq!(mention.entity_type, Some(EntityType::Person));
    }

    #[test]
    fn test_linked_entity_serialization() {
        let entity = LinkedEntity {
            mention_text: "Einstein".to_string(),
            start: 0,
            end: 8,
            kb_id: Some("Q937".to_string()),
            source: CandidateSource::Wikidata,
            label: Some("Albert Einstein".to_string()),
            iri: Some("http://www.wikidata.org/entity/Q937".to_string()),
            confidence: 0.95,
            is_nil: false,
            nil_reason: None,
            nil_action: None,
            alternatives: vec![],
        };

        // Test serialization round-trip
        let json = serde_json::to_string(&entity).unwrap();
        let deserialized: LinkedEntity = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.kb_id, entity.kb_id);
        assert_eq!(deserialized.mention_text, entity.mention_text);
        assert_eq!(deserialized.label, entity.label);
        assert_eq!(deserialized.iri, entity.iri);
        assert_eq!((deserialized.start, deserialized.end), (0, 8));
        assert!(!deserialized.is_nil);
        assert!(deserialized.alternatives.is_empty());
    }

    #[test]
    fn test_linker_with_custom_threshold() {
        // High threshold should increase NIL rate
        let strict_linker = EntityLinker::builder().with_nil_threshold(0.9).build();

        // Low threshold should decrease NIL rate
        let lenient_linker = EntityLinker::builder().with_nil_threshold(0.1).build();

        let result_strict = strict_linker.link_one("some entity", "context", None);
        let result_lenient = lenient_linker.link_one("some entity", "context", None);

        // Both must work without panicking and yield an entity.
        assert!(result_strict.is_some());
        assert!(result_lenient.is_some());
    }
}