1use std::sync::Arc;
7
8use crate::backends::box_embeddings::BoxEmbedding;
9use crate::linking::candidate::CandidateGenerator;
10use crate::linking::linker::{EntityLinker, Mention};
11use anno_core::EntityType;
12
13use super::types::{
14 AntecedentValue, CorefScoreProvider, JointMention, LinkScoreProvider, NerScoreProvider,
15};
16
/// Link-score provider backed by a shared [`EntityLinker`].
///
/// Each mention is linked on its own; the linker's chosen KB id and its
/// alternatives are turned into `(kb_id, log_score)` pairs with a `"NIL"`
/// option appended.
pub struct EntityLinkerProvider {
    /// Entity linker used to resolve each mention.
    linker: Arc<EntityLinker>,
    /// Cap on candidates taken from the linker; one slot is reserved for the
    /// linker's own top choice when it has one. Defaults to 20.
    max_candidates: usize,
}
42
43impl EntityLinkerProvider {
44 pub fn new(linker: Arc<EntityLinker>) -> Self {
46 Self {
47 linker,
48 max_candidates: 20,
49 }
50 }
51
52 pub fn with_max_candidates(mut self, max: usize) -> Self {
54 self.max_candidates = max;
55 self
56 }
57}
58
59impl LinkScoreProvider for EntityLinkerProvider {
60 fn link_candidates(&self, mention: &JointMention, text: &str) -> Vec<(String, f64)> {
61 let linking_mention = Mention::new(&mention.text, mention.start, mention.end);
63 let linking_mention = if let Some(ref entity) = mention.entity {
64 linking_mention.with_type(entity.entity_type.clone())
65 } else {
66 linking_mention
67 };
68
69 let result = self.linker.link(&[linking_mention], text);
71
72 if result.entities.is_empty() {
73 return vec![("NIL".to_string(), 0.0)];
74 }
75
76 let linked = &result.entities[0];
77
78 let mut candidates: Vec<(String, f64)> = linked
80 .alternatives
81 .iter()
82 .take(self.max_candidates - 1)
83 .map(|alt| (alt.kb_id.clone(), alt.score.ln().max(-100.0)))
84 .collect();
85
86 if let Some(ref kb_id) = linked.kb_id {
88 candidates.insert(0, (kb_id.clone(), linked.confidence.ln().max(-100.0)));
89 }
90
91 candidates.push(("NIL".to_string(), (-2.0_f64).ln())); candidates
95 }
96}
97
/// Coreference score provider that scores mention pairs by [`BoxEmbedding`] overlap.
///
/// NOTE(review): each mention's box is derived from a hash of its text and start
/// offset rather than a learned representation, so this looks like a
/// deterministic placeholder — confirm before relying on its scores.
// NOTE(review): the reason for the `dead_code` allowance is not visible here — TODO confirm it is still needed.
#[allow(dead_code)] pub struct BoxCorefProvider {
    /// Half-width of the square box placed around each mention's hashed point;
    /// values below `1e-3` are clamped up when boxes are built. Defaults to `0.1`.
    pub radius: f32,
}
113
114impl Default for BoxCorefProvider {
115 fn default() -> Self {
116 Self { radius: 0.1 }
117 }
118}
119
120impl BoxCorefProvider {
121 #[allow(dead_code)] fn mention_to_box(&self, mention: &JointMention) -> BoxEmbedding {
126 use std::hash::{Hash, Hasher};
127 let mut hasher = std::collections::hash_map::DefaultHasher::new();
128 mention.text.hash(&mut hasher);
129 mention.start.hash(&mut hasher);
130 let h = hasher.finish();
131 let v1 = ((h & 0xFFFF) as f32) / 65535.0;
132 let v2 = (((h >> 16) & 0xFFFF) as f32) / 65535.0;
133 let radius = self.radius.max(1e-3);
134 BoxEmbedding::new(
135 vec![v1 - radius, v2 - radius],
136 vec![v1 + radius, v2 + radius],
137 )
138 }
139}
140
141impl CorefScoreProvider for BoxCorefProvider {
142 fn antecedent_scores(
143 &self,
144 mention: &JointMention,
145 candidates: &[&JointMention],
146 _text: &str,
147 ) -> Vec<(AntecedentValue, f64)> {
148 let m_box = self.mention_to_box(mention);
149
150 let mut scores: Vec<(AntecedentValue, f64)> = candidates
152 .iter()
153 .map(|cand| {
154 let c_box = self.mention_to_box(cand);
155 let s = m_box.coreference_score(&c_box).max(1e-6);
156 (AntecedentValue::Mention(cand.idx), s.ln() as f64)
157 })
158 .collect();
159
160 scores.push((AntecedentValue::NewCluster, (-1.0_f64).ln()));
162 scores
163 }
164}
165
/// Link-score provider that asks a [`CandidateGenerator`] for KB candidates.
///
/// Candidate scores are turned into `(kb_id, log_score)` pairs with a `"NIL"`
/// option appended.
pub struct DictionaryLinkProvider {
    /// Source of KB candidates for a mention.
    generator: Arc<dyn CandidateGenerator>,
    /// Limit passed to the generator for each mention. Defaults to 20.
    max_candidates: usize,
}
177
178impl DictionaryLinkProvider {
179 pub fn new(generator: Arc<dyn CandidateGenerator>) -> Self {
181 Self {
182 generator,
183 max_candidates: 20,
184 }
185 }
186
187 pub fn with_max_candidates(mut self, max: usize) -> Self {
189 self.max_candidates = max;
190 self
191 }
192}
193
194impl LinkScoreProvider for DictionaryLinkProvider {
195 fn link_candidates(&self, mention: &JointMention, text: &str) -> Vec<(String, f64)> {
196 let entity_type_str = mention.entity.as_ref().map(|e| e.entity_type.to_string());
197
198 let mut candidates = self.generator.generate(
199 &mention.text,
200 text,
201 entity_type_str.as_deref(),
202 self.max_candidates,
203 );
204
205 let results: Vec<(String, f64)> = candidates
207 .iter_mut()
208 .map(|c| {
209 c.compute_score();
210 (c.kb_id.clone(), c.score.ln().max(-100.0))
211 })
212 .collect();
213
214 if results.is_empty() {
215 vec![("NIL".to_string(), 0.0)]
216 } else {
217 let mut results = results;
218 results.push(("NIL".to_string(), (-2.0_f64).ln()));
219 results
220 }
221 }
222}
223
/// NER score provider backed by a [`crate::Model`].
///
/// It uses an entity already attached to a mention when present. Otherwise it
/// runs the model on a window of text around the mention.
pub struct ModelNerProvider {
    /// Model used to extract entities from the context window.
    model: Arc<dyn crate::Model>,
    /// Types that receive a score; defaults to `model.supported_types()`.
    entity_types: Vec<EntityType>,
}
237
238impl ModelNerProvider {
239 pub fn new(model: Arc<dyn crate::Model>) -> Self {
241 let entity_types = model.supported_types();
242 Self {
243 model,
244 entity_types,
245 }
246 }
247
248 pub fn with_entity_types(mut self, types: Vec<EntityType>) -> Self {
250 self.entity_types = types;
251 self
252 }
253}
254
255impl NerScoreProvider for ModelNerProvider {
256 fn type_scores(&self, mention: &JointMention, text: &str) -> Vec<(EntityType, f64)> {
257 if let Some(ref entity) = mention.entity {
259 let prior_type = entity.entity_type.clone();
260 let confidence = entity.confidence;
261
262 return self
263 .entity_types
264 .iter()
265 .map(|et| {
266 if et == &prior_type {
267 (et.clone(), confidence.ln().max(-100.0))
268 } else {
269 (et.clone(), (1.0 - confidence).ln().max(-100.0) - 2.0)
270 }
271 })
272 .collect();
273 }
274
275 let context_start = mention.start.saturating_sub(50);
278 let context_end = (mention.end + 50).min(text.chars().count());
279 let context: String = text
280 .chars()
281 .skip(context_start)
282 .take(context_end - context_start)
283 .collect();
284
285 match self.model.extract_entities(&context, None) {
286 Ok(entities) => {
287 let mention_in_context_start = mention.start - context_start;
289 let mention_in_context_end = mention.end - context_start;
290
291 let matching_entity = entities.iter().find(|e| {
292 e.start <= mention_in_context_end && e.end >= mention_in_context_start
293 });
294
295 match matching_entity {
296 Some(e) => self
297 .entity_types
298 .iter()
299 .map(|et| {
300 if et == &e.entity_type {
301 (et.clone(), e.confidence.ln().max(-100.0))
302 } else {
303 (et.clone(), (1.0 - e.confidence).ln().max(-100.0) - 1.0)
304 }
305 })
306 .collect(),
307 None => {
308 let uniform = (-(self.entity_types.len() as f64)).ln();
310 self.entity_types
311 .iter()
312 .map(|et| (et.clone(), uniform))
313 .collect()
314 }
315 }
316 }
317 Err(_) => {
318 let uniform = (-(self.entity_types.len() as f64)).ln();
320 self.entity_types
321 .iter()
322 .map(|et| (et.clone(), uniform))
323 .collect()
324 }
325 }
326 }
327}
328
/// Coreference score provider built from additive string-match heuristics.
///
/// Comparisons are made on lowercased mention text and heads.
pub struct HeuristicCorefProvider {
    /// Bonus when mention and candidate texts are equal (default `5.0`).
    exact_match_weight: f64,
    /// Bonus when one mention's text is contained in the other's (default `2.0`).
    substring_weight: f64,
    /// Bonus when mention and candidate heads are equal (default `3.0`).
    head_match_weight: f64,
    /// Penalty subtracted per unit of candidate distance (default `0.1`).
    distance_penalty: f64,
}
346
347impl Default for HeuristicCorefProvider {
348 fn default() -> Self {
349 Self {
350 exact_match_weight: 5.0,
351 substring_weight: 2.0,
352 head_match_weight: 3.0,
353 distance_penalty: 0.1,
354 }
355 }
356}
357
358impl HeuristicCorefProvider {
359 pub fn new() -> Self {
361 Self::default()
362 }
363
364 pub fn with_exact_match_weight(mut self, weight: f64) -> Self {
366 self.exact_match_weight = weight;
367 self
368 }
369
370 pub fn with_distance_penalty(mut self, penalty: f64) -> Self {
372 self.distance_penalty = penalty;
373 self
374 }
375}
376
377impl CorefScoreProvider for HeuristicCorefProvider {
378 fn antecedent_scores(
379 &self,
380 mention: &JointMention,
381 candidates: &[&JointMention],
382 _text: &str,
383 ) -> Vec<(AntecedentValue, f64)> {
384 let mention_text_lower = mention.text.to_lowercase();
385 let mention_head_lower = mention.head.to_lowercase();
386
387 let mut scores: Vec<(AntecedentValue, f64)> = candidates
388 .iter()
389 .enumerate()
390 .map(|(i, cand)| {
391 let cand_text_lower = cand.text.to_lowercase();
392 let cand_head_lower = cand.head.to_lowercase();
393
394 let mut score = 0.0;
395
396 if mention_text_lower == cand_text_lower {
398 score += self.exact_match_weight;
399 }
400
401 if mention_text_lower.contains(&cand_text_lower)
403 || cand_text_lower.contains(&mention_text_lower)
404 {
405 score += self.substring_weight;
406 }
407
408 if mention_head_lower == cand_head_lower {
410 score += self.head_match_weight;
411 }
412
413 let distance = candidates.len() - i; score -= self.distance_penalty * distance as f64;
416
417 (AntecedentValue::Mention(cand.idx), score)
418 })
419 .collect();
420
421 let new_cluster_score = if mention.mention_kind.is_proper_name() {
424 1.0 } else {
426 -1.0 };
428 scores.push((AntecedentValue::NewCluster, new_cluster_score));
429
430 scores
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::super::MentionKind;
    use super::*;

    /// Builds a mention with no attached entity or type. Callers set `idx` via
    /// struct-update syntax so its integer type is inferred from the field.
    fn unannotated(
        text: &str,
        head: &str,
        start: usize,
        end: usize,
        kind: MentionKind,
    ) -> JointMention {
        JointMention {
            idx: Default::default(),
            text: text.to_owned(),
            head: head.to_owned(),
            start,
            end,
            mention_kind: kind,
            entity: None,
            entity_type: None,
        }
    }

    #[test]
    fn test_heuristic_coref_provider() {
        let provider = HeuristicCorefProvider::default();

        let pronoun = JointMention {
            idx: 2,
            ..unannotated("he", "he", 20, 22, MentionKind::Pronominal)
        };
        let full_name = JointMention {
            idx: 0,
            ..unannotated("John Smith", "Smith", 0, 10, MentionKind::Proper)
        };
        let title = JointMention {
            idx: 1,
            ..unannotated("the CEO", "CEO", 12, 19, MentionKind::Nominal)
        };

        let scores = provider.antecedent_scores(&pronoun, &[&full_name, &title], "");

        // One entry per candidate plus the NewCluster option.
        assert_eq!(scores.len(), 3);

        let new_cluster_score = scores
            .iter()
            .find_map(|(value, score)| match value {
                AntecedentValue::NewCluster => Some(*score),
                _ => None,
            })
            .expect("NewCluster option is always appended");
        // Non-proper mentions get a negative NewCluster score.
        assert!(new_cluster_score < 0.0);
    }

    #[test]
    fn test_heuristic_coref_exact_match() {
        let provider = HeuristicCorefProvider::default();

        let later = JointMention {
            idx: 1,
            ..unannotated("John Smith", "Smith", 50, 60, MentionKind::Proper)
        };
        let earlier = JointMention {
            idx: 0,
            ..unannotated("John Smith", "Smith", 0, 10, MentionKind::Proper)
        };

        let scores = provider.antecedent_scores(&later, &[&earlier], "");

        let antecedent_score = scores
            .iter()
            .find_map(|(value, score)| match value {
                AntecedentValue::Mention(0) => Some(*score),
                _ => None,
            })
            .expect("candidate 0 must be scored");
        // Exact + containment + head bonuses outweigh a single step of distance penalty.
        assert!(antecedent_score > 7.0);
    }
}