1use serde::{Deserialize, Serialize};
4use std::sync::Arc;
5
6use super::candidate::{
7 Candidate, CandidateGenerator, CandidateSource, DictionaryCandidateGenerator,
8};
9use super::nil::{NilAction, NilDetector, NilReason};
10use anno_core::EntityType;
11
12#[derive(Debug, Clone)]
14pub struct Mention {
15 pub text: String,
17 pub start: usize,
19 pub end: usize,
21 pub entity_type: Option<EntityType>,
23}
24
25impl Mention {
26 pub fn new(text: &str, start: usize, end: usize) -> Self {
28 Self {
29 text: text.to_string(),
30 start,
31 end,
32 entity_type: None,
33 }
34 }
35
36 pub fn with_type(mut self, entity_type: EntityType) -> Self {
38 self.entity_type = Some(entity_type);
39 self
40 }
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct LinkedEntity {
46 pub mention_text: String,
48 pub start: usize,
50 pub end: usize,
52 pub kb_id: Option<String>,
54 pub source: CandidateSource,
56 pub label: Option<String>,
58 pub iri: Option<String>,
60 pub confidence: f64,
62 pub is_nil: bool,
64 pub nil_reason: Option<NilReason>,
66 pub nil_action: Option<NilAction>,
68 pub alternatives: Vec<CandidateSummary>,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct CandidateSummary {
75 pub kb_id: String,
77 pub label: String,
79 pub score: f64,
81}
82
83impl From<&Candidate> for CandidateSummary {
84 fn from(c: &Candidate) -> Self {
85 Self {
86 kb_id: c.kb_id.clone(),
87 label: c.label.clone(),
88 score: c.score,
89 }
90 }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct LinkingResult {
96 pub entities: Vec<LinkedEntity>,
98 pub total_mentions: usize,
100 pub linked_count: usize,
102 pub nil_count: usize,
104 pub avg_confidence: f64,
106}
107
108impl LinkingResult {
109 pub fn linking_rate(&self) -> f64 {
111 if self.total_mentions == 0 {
112 0.0
113 } else {
114 self.linked_count as f64 / self.total_mentions as f64
115 }
116 }
117}
118
119pub struct EntityLinker {
121 generator: Arc<dyn CandidateGenerator>,
123 nil_detector: NilDetector,
125 max_candidates: usize,
127 include_alternatives: bool,
129}
130
131impl EntityLinker {
132 pub fn builder() -> EntityLinkerBuilder {
134 EntityLinkerBuilder::default()
135 }
136
137 pub fn link(&self, mentions: &[Mention], context: &str) -> LinkingResult {
139 let mut entities = Vec::with_capacity(mentions.len());
140 let mut linked_count = 0;
141 let mut nil_count = 0;
142 let mut total_confidence = 0.0;
143
144 for mention in mentions {
145 let entity_type_str = mention.entity_type.as_ref().map(|et| et.to_string());
146
147 let mut candidates = self.generator.generate(
149 &mention.text,
150 context,
151 entity_type_str.as_deref(),
152 self.max_candidates,
153 );
154
155 for c in &mut candidates {
157 c.compute_score();
158 }
159 candidates.sort_by(|a, b| {
160 b.score
161 .partial_cmp(&a.score)
162 .unwrap_or(std::cmp::Ordering::Equal)
163 });
164
165 let nil_analysis =
167 self.nil_detector
168 .analyze(&mention.text, &candidates, entity_type_str.as_deref());
169
170 let linked_entity = if nil_analysis.is_nil {
171 nil_count += 1;
172
173 LinkedEntity {
174 mention_text: mention.text.clone(),
175 start: mention.start,
176 end: mention.end,
177 kb_id: None,
178 source: CandidateSource::default(),
179 label: None,
180 iri: None,
181 confidence: nil_analysis.confidence,
182 is_nil: true,
183 nil_reason: nil_analysis.reason,
184 nil_action: Some(nil_analysis.action),
185 alternatives: if self.include_alternatives {
186 candidates
187 .iter()
188 .take(5)
189 .map(CandidateSummary::from)
190 .collect()
191 } else {
192 Vec::new()
193 },
194 }
195 } else {
196 linked_count += 1;
197 let top_candidate = &candidates[0];
198 total_confidence += top_candidate.score;
199
200 LinkedEntity {
201 mention_text: mention.text.clone(),
202 start: mention.start,
203 end: mention.end,
204 kb_id: Some(top_candidate.kb_id.clone()),
205 source: top_candidate.source.clone(),
206 label: Some(top_candidate.label.clone()),
207 iri: Some(top_candidate.to_iri()),
208 confidence: top_candidate.score,
209 is_nil: false,
210 nil_reason: None,
211 nil_action: None,
212 alternatives: if self.include_alternatives && candidates.len() > 1 {
213 candidates[1..]
214 .iter()
215 .take(4)
216 .map(CandidateSummary::from)
217 .collect()
218 } else {
219 Vec::new()
220 },
221 }
222 };
223
224 entities.push(linked_entity);
225 }
226
227 let avg_confidence = if linked_count > 0 {
228 total_confidence / linked_count as f64
229 } else {
230 0.0
231 };
232
233 LinkingResult {
234 entities,
235 total_mentions: mentions.len(),
236 linked_count,
237 nil_count,
238 avg_confidence,
239 }
240 }
241
242 pub fn link_one(
244 &self,
245 mention: &str,
246 context: &str,
247 entity_type: Option<EntityType>,
248 ) -> Option<LinkedEntity> {
249 let m = if let Some(et) = entity_type {
250 Mention::new(mention, 0, mention.len()).with_type(et)
251 } else {
252 Mention::new(mention, 0, mention.len())
253 };
254
255 let result = self.link(&[m], context);
256 result.entities.into_iter().next()
257 }
258}
259
260pub struct EntityLinkerBuilder {
262 generator: Option<Arc<dyn CandidateGenerator>>,
263 nil_threshold: f64,
264 max_candidates: usize,
265 include_alternatives: bool,
266}
267
268impl Default for EntityLinkerBuilder {
269 fn default() -> Self {
270 Self {
271 generator: None,
272 nil_threshold: 0.3,
273 max_candidates: 20,
274 include_alternatives: true,
275 }
276 }
277}
278
279impl EntityLinkerBuilder {
280 pub fn with_candidate_generator<G: CandidateGenerator + 'static>(mut self, gen: G) -> Self {
282 self.generator = Some(Arc::new(gen));
283 self
284 }
285
286 pub fn with_nil_threshold(mut self, threshold: f64) -> Self {
288 self.nil_threshold = threshold;
289 self
290 }
291
292 pub fn with_max_candidates(mut self, max: usize) -> Self {
294 self.max_candidates = max;
295 self
296 }
297
298 pub fn include_alternatives(mut self, include: bool) -> Self {
300 self.include_alternatives = include;
301 self
302 }
303
304 pub fn build(self) -> EntityLinker {
306 let generator = self
307 .generator
308 .unwrap_or_else(|| Arc::new(DictionaryCandidateGenerator::new().with_well_known()));
309
310 EntityLinker {
311 generator,
312 nil_detector: NilDetector::new().with_score_threshold(self.nil_threshold),
313 max_candidates: self.max_candidates,
314 include_alternatives: self.include_alternatives,
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    //! Tests for the linker pipeline.
    //!
    //! Assertions that depend on dictionary contents or on the NIL detector's
    //! scoring stay conditional; assertions that follow directly from the
    //! structure of `EntityLinker::link` (one entity per mention, NIL and
    //! linked results being mutually exclusive) are unconditional.

    use super::*;

    /// Checks the field combinations that `link` produces for NIL and linked
    /// results.
    fn assert_consistent(entity: &LinkedEntity) {
        if entity.is_nil {
            assert!(entity.kb_id.is_none());
            assert!(entity.label.is_none());
            assert!(entity.iri.is_none());
            assert!(entity.nil_action.is_some());
        } else {
            assert!(entity.kb_id.is_some());
            assert!(entity.label.is_some());
            assert!(entity.iri.is_some());
            assert!(entity.nil_reason.is_none());
            assert!(entity.nil_action.is_none());
        }
    }

    #[test]
    fn test_entity_linker_basic() {
        let linker = EntityLinker::builder().build();

        let mentions = vec![Mention::new("Einstein", 0, 8).with_type(EntityType::Person)];

        let result = linker.link(&mentions, "Albert Einstein was a physicist.");

        assert_eq!(result.total_mentions, 1);
        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.linked_count + result.nil_count, 1);
        assert_consistent(&result.entities[0]);
    }

    #[test]
    fn test_entity_linker_known_entity() {
        let linker = EntityLinker::builder().with_nil_threshold(0.1).build();

        let entity = linker
            .link_one(
                "Albert Einstein",
                "He was a physicist.",
                Some(EntityType::Person),
            )
            .expect("link_one yields one entity per mention");

        assert_consistent(&entity);
        if !entity.is_nil {
            assert!(entity.kb_id.is_some());
            assert!(entity.iri.as_ref().unwrap().contains("wikidata"));
        }
    }

    #[test]
    fn test_entity_linker_nil() {
        let linker = EntityLinker::builder().build();

        let entity = linker
            .link_one("Xyzzy Qwerty Asdf", "Unknown person.", None)
            .expect("link_one yields one entity per mention");

        assert_consistent(&entity);
        assert!(entity.is_nil || entity.confidence < 0.5);
    }

    #[test]
    fn test_linking_result_stats() {
        let result = LinkingResult {
            entities: Vec::new(),
            total_mentions: 10,
            linked_count: 7,
            nil_count: 3,
            avg_confidence: 0.8,
        };

        assert!((result.linking_rate() - 0.7).abs() < 0.001);

        // No mentions must not divide by zero.
        let empty = LinkingResult {
            entities: Vec::new(),
            total_mentions: 0,
            linked_count: 0,
            nil_count: 0,
            avg_confidence: 0.0,
        };
        assert_eq!(empty.linking_rate(), 0.0);
    }

    #[test]
    fn test_multilingual_entity_linking() {
        let linker = EntityLinker::builder().with_nil_threshold(0.1).build();

        let entity = linker
            .link_one("北京", "Visit Beijing, China.", None)
            .expect("link_one yields one entity per mention");
        assert_consistent(&entity);
        assert_eq!(entity.mention_text, "北京");
        // `link_one` fills the end offset with the byte length of the text.
        assert_eq!(entity.end, "北京".len());

        let entity = linker
            .link_one("東京", "Tokyo is in Japan.", None)
            .expect("link_one yields one entity per mention");
        assert_consistent(&entity);
    }

    #[test]
    fn test_entity_type_aware_linking() {
        let linker = EntityLinker::builder().build();

        let person = linker.link_one(
            "Apple",
            "Steve Jobs founded Apple.",
            Some(EntityType::Person),
        );
        let org = linker.link_one(
            "Apple",
            "Apple is a tech company.",
            Some(EntityType::Organization),
        );

        let person = person.expect("link_one yields one entity per mention");
        let org = org.expect("link_one yields one entity per mention");
        assert_eq!(person.mention_text, "Apple");
        assert_eq!(org.mention_text, "Apple");
        assert_consistent(&person);
        assert_consistent(&org);
    }

    #[test]
    fn test_batch_linking_multiple_mentions() {
        let linker = EntityLinker::builder().build();

        let mentions = vec![
            Mention::new("Google", 0, 6).with_type(EntityType::Organization),
            Mention::new("Microsoft", 15, 24).with_type(EntityType::Organization),
            Mention::new("Apple", 30, 35).with_type(EntityType::Organization),
        ];

        let result = linker.link(&mentions, "Google and Microsoft and Apple are tech giants.");

        assert_eq!(result.total_mentions, 3);
        assert_eq!(result.entities.len(), 3);
        assert_eq!(result.linked_count + result.nil_count, 3);

        // Results keep input order and carry the mention offsets through.
        for (mention, entity) in mentions.iter().zip(&result.entities) {
            assert_eq!(entity.mention_text, mention.text);
            assert_eq!(entity.start, mention.start);
            assert_eq!(entity.end, mention.end);
            assert_consistent(entity);
        }
    }

    #[test]
    fn test_empty_mentions() {
        let linker = EntityLinker::builder().build();

        let result = linker.link(&[], "Some text without mentions.");

        assert_eq!(result.total_mentions, 0);
        assert_eq!(result.linked_count, 0);
        assert_eq!(result.nil_count, 0);
        assert!(result.entities.is_empty());
        assert_eq!(result.avg_confidence, 0.0);
        assert_eq!(result.linking_rate(), 0.0);
    }

    #[test]
    fn test_very_short_mention() {
        let linker = EntityLinker::builder().build();

        let entity = linker
            .link_one("X", "X marks the spot.", None)
            .expect("link_one yields one entity per mention");

        assert_consistent(&entity);
        assert!(entity.is_nil || entity.confidence < 0.3);
    }

    #[test]
    fn test_mention_builder_pattern() {
        let mention = Mention::new("Test", 0, 4).with_type(EntityType::Person);

        assert_eq!(mention.text, "Test");
        assert_eq!(mention.start, 0);
        assert_eq!(mention.end, 4);
        assert_eq!(mention.entity_type, Some(EntityType::Person));
    }

    #[test]
    fn test_linked_entity_serialization() {
        let entity = LinkedEntity {
            mention_text: "Einstein".to_string(),
            start: 0,
            end: 8,
            kb_id: Some("Q937".to_string()),
            source: CandidateSource::Wikidata,
            label: Some("Albert Einstein".to_string()),
            iri: Some("http://www.wikidata.org/entity/Q937".to_string()),
            confidence: 0.95,
            is_nil: false,
            nil_reason: None,
            nil_action: None,
            alternatives: vec![],
        };

        let json = serde_json::to_string(&entity).unwrap();
        let deserialized: LinkedEntity = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.kb_id, entity.kb_id);
        assert_eq!(deserialized.mention_text, entity.mention_text);
        assert_eq!(deserialized.start, entity.start);
        assert_eq!(deserialized.end, entity.end);
        assert_eq!(deserialized.label, entity.label);
        assert_eq!(deserialized.iri, entity.iri);
        assert_eq!(deserialized.is_nil, entity.is_nil);
        assert!((deserialized.confidence - entity.confidence).abs() < 1e-9);
        assert!(deserialized.alternatives.is_empty());
    }

    #[test]
    fn test_linker_with_custom_threshold() {
        let strict_linker = EntityLinker::builder().with_nil_threshold(0.9).build();

        let lenient_linker = EntityLinker::builder().with_nil_threshold(0.1).build();

        let result_strict = strict_linker.link_one("some entity", "context", None);
        let result_lenient = lenient_linker.link_one("some entity", "context", None);

        // Whatever the threshold, each linker must still produce a
        // well-formed result for the mention.
        let strict = result_strict.expect("link_one yields one entity per mention");
        let lenient = result_lenient.expect("link_one yields one entity per mention");
        assert_consistent(&strict);
        assert_consistent(&lenient);
    }

    #[test]
    fn test_alternatives_can_be_disabled() {
        let linker = EntityLinker::builder().include_alternatives(false).build();

        let entity = linker
            .link_one("Apple", "Apple is a tech company.", None)
            .expect("link_one yields one entity per mention");

        assert!(entity.alternatives.is_empty());
    }
}