Skip to main content

anno/joint/
providers.rs

//! Score-provider implementations for pluggable unary factors.
//!
//! These adapters let the joint model draw on existing anno backends for
//! NER, coreference, and entity linking.
5
6use std::sync::Arc;
7
8use crate::backends::box_embeddings::BoxEmbedding;
9use crate::linking::candidate::CandidateGenerator;
10use crate::linking::linker::{EntityLinker, Mention};
11use anno_core::EntityType;
12
13use super::types::{
14    AntecedentValue, CorefScoreProvider, JointMention, LinkScoreProvider, NerScoreProvider,
15};
16
17// =============================================================================
18// EntityLinker-based Link Score Provider
19// =============================================================================
20
21/// A `LinkScoreProvider` that uses the existing `EntityLinker` infrastructure.
22///
23/// This adapter bridges the joint model with anno's entity linking system.
24///
25/// # Example
26///
27/// ```rust,ignore
28/// use anno::joint::{JointModelBuilder, EntityLinkerProvider};
29/// use anno::linking::EntityLinker;
30/// use std::sync::Arc;
31///
32/// let linker = EntityLinker::builder()
33///     .with_max_candidates(20)
34///     .build()?;
35///
36/// let provider = EntityLinkerProvider::new(Arc::new(linker));
37/// ```
38pub struct EntityLinkerProvider {
39    linker: Arc<EntityLinker>,
40    max_candidates: usize,
41}
42
43impl EntityLinkerProvider {
44    /// Create a new provider wrapping an entity linker.
45    pub fn new(linker: Arc<EntityLinker>) -> Self {
46        Self {
47            linker,
48            max_candidates: 20,
49        }
50    }
51
52    /// Set maximum candidates to return.
53    pub fn with_max_candidates(mut self, max: usize) -> Self {
54        self.max_candidates = max;
55        self
56    }
57}
58
59impl LinkScoreProvider for EntityLinkerProvider {
60    fn link_candidates(&self, mention: &JointMention, text: &str) -> Vec<(String, f64)> {
61        // Convert JointMention to linking::Mention
62        let linking_mention = Mention::new(&mention.text, mention.start, mention.end);
63        let linking_mention = if let Some(ref entity) = mention.entity {
64            linking_mention.with_type(entity.entity_type.clone())
65        } else {
66            linking_mention
67        };
68
69        // Use EntityLinker to get candidates
70        let result = self.linker.link(&[linking_mention], text);
71
72        if result.entities.is_empty() {
73            return vec![("NIL".to_string(), 0.0)];
74        }
75
76        let linked = &result.entities[0];
77
78        // Collect candidates with scores
79        let mut candidates: Vec<(String, f64)> = linked
80            .alternatives
81            .iter()
82            .take(self.max_candidates - 1)
83            .map(|alt| (alt.kb_id.clone(), alt.score.ln().max(-100.0)))
84            .collect();
85
86        // Add the top candidate
87        if let Some(ref kb_id) = linked.kb_id {
88            candidates.insert(0, (kb_id.clone(), linked.confidence.ln().max(-100.0)));
89        }
90
91        // Always include NIL option
92        candidates.push(("NIL".to_string(), (-2.0_f64).ln())); // ~0.13 prior for NIL
93
94        candidates
95    }
96}
97
98// =============================================================================
99// Box embedding based Coref Score Provider
100// =============================================================================
101
/// Lightweight `CorefScoreProvider` that compares mentions through box
/// embeddings.
///
/// Each mention is turned into a small axis-aligned 2D box whose position is
/// derived deterministically from a hash of the mention, so no trained
/// encoder is needed. Candidate antecedents are scored by the mutual overlap
/// of their boxes (`coreference_score`).
///
/// Intended as a stopgap so the joint model can consume box-based
/// coreference cues before a full box-training pipeline is wired in.
#[allow(dead_code)] // Future: wire up box-based coref in joint model
pub struct BoxCorefProvider {
    /// Half-width of every constructed box, applied in each dimension.
    pub radius: f32,
}

impl Default for BoxCorefProvider {
    /// Boxes with a half-width of `0.1`.
    fn default() -> Self {
        let radius = 0.1;
        BoxCorefProvider { radius }
    }
}
119
120impl BoxCorefProvider {
121    /// Convert a mention into a deterministic 2D box embedding.
122    ///
123    /// The hash is mapped into [0,1]² and expanded by `radius` in each dim.
124    #[allow(dead_code)] // struct is currently not wired into main joint path
125    fn mention_to_box(&self, mention: &JointMention) -> BoxEmbedding {
126        use std::hash::{Hash, Hasher};
127        let mut hasher = std::collections::hash_map::DefaultHasher::new();
128        mention.text.hash(&mut hasher);
129        mention.start.hash(&mut hasher);
130        let h = hasher.finish();
131        let v1 = ((h & 0xFFFF) as f32) / 65535.0;
132        let v2 = (((h >> 16) & 0xFFFF) as f32) / 65535.0;
133        let radius = self.radius.max(1e-3);
134        BoxEmbedding::new(
135            vec![v1 - radius, v2 - radius],
136            vec![v1 + radius, v2 + radius],
137        )
138    }
139}
140
141impl CorefScoreProvider for BoxCorefProvider {
142    fn antecedent_scores(
143        &self,
144        mention: &JointMention,
145        candidates: &[&JointMention],
146        _text: &str,
147    ) -> Vec<(AntecedentValue, f64)> {
148        let m_box = self.mention_to_box(mention);
149
150        // Score each candidate via box overlap and convert to log-score
151        let mut scores: Vec<(AntecedentValue, f64)> = candidates
152            .iter()
153            .map(|cand| {
154                let c_box = self.mention_to_box(cand);
155                let s = m_box.coreference_score(&c_box).max(1e-6);
156                (AntecedentValue::Mention(cand.idx), s.ln() as f64)
157            })
158            .collect();
159
160        // NEW cluster prior (mild)
161        scores.push((AntecedentValue::NewCluster, (-1.0_f64).ln()));
162        scores
163    }
164}
165
166// =============================================================================
167// Dictionary-based Link Score Provider
168// =============================================================================
169
170/// A simpler `LinkScoreProvider` that uses a dictionary for candidate generation.
171///
172/// This is faster than the full EntityLinker but less accurate.
173pub struct DictionaryLinkProvider {
174    generator: Arc<dyn CandidateGenerator>,
175    max_candidates: usize,
176}
177
178impl DictionaryLinkProvider {
179    /// Create a new dictionary-based provider.
180    pub fn new(generator: Arc<dyn CandidateGenerator>) -> Self {
181        Self {
182            generator,
183            max_candidates: 20,
184        }
185    }
186
187    /// Set maximum candidates to return.
188    pub fn with_max_candidates(mut self, max: usize) -> Self {
189        self.max_candidates = max;
190        self
191    }
192}
193
194impl LinkScoreProvider for DictionaryLinkProvider {
195    fn link_candidates(&self, mention: &JointMention, text: &str) -> Vec<(String, f64)> {
196        let entity_type_str = mention.entity.as_ref().map(|e| e.entity_type.to_string());
197
198        let mut candidates = self.generator.generate(
199            &mention.text,
200            text,
201            entity_type_str.as_deref(),
202            self.max_candidates,
203        );
204
205        // Compute scores and convert to log-space
206        let results: Vec<(String, f64)> = candidates
207            .iter_mut()
208            .map(|c| {
209                c.compute_score();
210                (c.kb_id.clone(), c.score.ln().max(-100.0))
211            })
212            .collect();
213
214        if results.is_empty() {
215            vec![("NIL".to_string(), 0.0)]
216        } else {
217            let mut results = results;
218            results.push(("NIL".to_string(), (-2.0_f64).ln()));
219            results
220        }
221    }
222}
223
224// =============================================================================
225// Model-based NER Score Provider
226// =============================================================================
227
228/// A `NerScoreProvider` that uses any `Model` implementation.
229///
230/// This allows using GLiNER, NuNER, or any other NER backend
231/// to provide type scores.
232pub struct ModelNerProvider {
233    model: Arc<dyn crate::Model>,
234    /// Supported entity types
235    entity_types: Vec<EntityType>,
236}
237
238impl ModelNerProvider {
239    /// Create a new provider wrapping an NER model.
240    pub fn new(model: Arc<dyn crate::Model>) -> Self {
241        let entity_types = model.supported_types();
242        Self {
243            model,
244            entity_types,
245        }
246    }
247
248    /// Override entity types to consider.
249    pub fn with_entity_types(mut self, types: Vec<EntityType>) -> Self {
250        self.entity_types = types;
251        self
252    }
253}
254
255impl NerScoreProvider for ModelNerProvider {
256    fn type_scores(&self, mention: &JointMention, text: &str) -> Vec<(EntityType, f64)> {
257        // If the mention already has an entity with a type, use that as a strong prior
258        if let Some(ref entity) = mention.entity {
259            let prior_type = entity.entity_type.clone();
260            let confidence = entity.confidence;
261
262            return self
263                .entity_types
264                .iter()
265                .map(|et| {
266                    if et == &prior_type {
267                        (et.clone(), confidence.ln().max(-100.0))
268                    } else {
269                        (et.clone(), (1.0 - confidence).ln().max(-100.0) - 2.0)
270                    }
271                })
272                .collect();
273        }
274
275        // Otherwise, run NER on the mention span
276        // Extract a context window around the mention
277        let context_start = mention.start.saturating_sub(50);
278        let context_end = (mention.end + 50).min(text.chars().count());
279        let context: String = text
280            .chars()
281            .skip(context_start)
282            .take(context_end - context_start)
283            .collect();
284
285        match self.model.extract_entities(&context, None) {
286            Ok(entities) => {
287                // Find entity overlapping with the mention
288                let mention_in_context_start = mention.start - context_start;
289                let mention_in_context_end = mention.end - context_start;
290
291                let matching_entity = entities.iter().find(|e| {
292                    e.start <= mention_in_context_end && e.end >= mention_in_context_start
293                });
294
295                match matching_entity {
296                    Some(e) => self
297                        .entity_types
298                        .iter()
299                        .map(|et| {
300                            if et == &e.entity_type {
301                                (et.clone(), e.confidence.ln().max(-100.0))
302                            } else {
303                                (et.clone(), (1.0 - e.confidence).ln().max(-100.0) - 1.0)
304                            }
305                        })
306                        .collect(),
307                    None => {
308                        // Uniform distribution as fallback
309                        let uniform = (-(self.entity_types.len() as f64)).ln();
310                        self.entity_types
311                            .iter()
312                            .map(|et| (et.clone(), uniform))
313                            .collect()
314                    }
315                }
316            }
317            Err(_) => {
318                // Fallback to uniform distribution
319                let uniform = (-(self.entity_types.len() as f64)).ln();
320                self.entity_types
321                    .iter()
322                    .map(|et| (et.clone(), uniform))
323                    .collect()
324            }
325        }
326    }
327}
328
329// =============================================================================
330// Heuristic Coref Score Provider
331// =============================================================================
332
/// A simple heuristic `CorefScoreProvider` based on string matching.
///
/// This is a lightweight alternative to neural mention-ranking models. Each
/// weight is an additive contribution to a candidate antecedent's
/// (unnormalized) score. All four weights can be tuned with the `with_*`
/// builder methods.
pub struct HeuristicCorefProvider {
    /// Bonus when mention and candidate texts are equal (case-insensitive).
    exact_match_weight: f64,
    /// Bonus when one mention's text is contained in the other's.
    substring_weight: f64,
    /// Bonus when the head words are equal (case-insensitive).
    head_match_weight: f64,
    /// Penalty per position of distance between mention and candidate.
    distance_penalty: f64,
}

impl Default for HeuristicCorefProvider {
    /// Exact 5.0, substring 2.0, head 3.0, distance penalty 0.1.
    fn default() -> Self {
        Self {
            exact_match_weight: 5.0,
            substring_weight: 2.0,
            head_match_weight: 3.0,
            distance_penalty: 0.1,
        }
    }
}

impl HeuristicCorefProvider {
    /// Create a new heuristic provider with default weights.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the exact-match weight.
    #[must_use]
    pub fn with_exact_match_weight(mut self, weight: f64) -> Self {
        self.exact_match_weight = weight;
        self
    }

    /// Set the substring-match weight.
    #[must_use]
    pub fn with_substring_weight(mut self, weight: f64) -> Self {
        self.substring_weight = weight;
        self
    }

    /// Set the head-word-match weight.
    #[must_use]
    pub fn with_head_match_weight(mut self, weight: f64) -> Self {
        self.head_match_weight = weight;
        self
    }

    /// Set the per-position distance penalty.
    #[must_use]
    pub fn with_distance_penalty(mut self, penalty: f64) -> Self {
        self.distance_penalty = penalty;
        self
    }
}
376
377impl CorefScoreProvider for HeuristicCorefProvider {
378    fn antecedent_scores(
379        &self,
380        mention: &JointMention,
381        candidates: &[&JointMention],
382        _text: &str,
383    ) -> Vec<(AntecedentValue, f64)> {
384        let mention_text_lower = mention.text.to_lowercase();
385        let mention_head_lower = mention.head.to_lowercase();
386
387        let mut scores: Vec<(AntecedentValue, f64)> = candidates
388            .iter()
389            .enumerate()
390            .map(|(i, cand)| {
391                let cand_text_lower = cand.text.to_lowercase();
392                let cand_head_lower = cand.head.to_lowercase();
393
394                let mut score = 0.0;
395
396                // Exact match bonus
397                if mention_text_lower == cand_text_lower {
398                    score += self.exact_match_weight;
399                }
400
401                // Substring match
402                if mention_text_lower.contains(&cand_text_lower)
403                    || cand_text_lower.contains(&mention_text_lower)
404                {
405                    score += self.substring_weight;
406                }
407
408                // Head word match
409                if mention_head_lower == cand_head_lower {
410                    score += self.head_match_weight;
411                }
412
413                // Distance penalty
414                let distance = candidates.len() - i; // More recent = higher score
415                score -= self.distance_penalty * distance as f64;
416
417                (AntecedentValue::Mention(cand.idx), score)
418            })
419            .collect();
420
421        // Add NEW_CLUSTER option
422        // New cluster is preferred for proper nouns that don't match anything
423        let new_cluster_score = if mention.mention_kind.is_proper_name() {
424            1.0 // Proper nouns more likely to start new clusters
425        } else {
426            -1.0 // Pronouns/nominals more likely to be anaphoric
427        };
428        scores.push((AntecedentValue::NewCluster, new_cluster_score));
429
430        scores
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a `JointMention` with no attached entity or entity type, so each
    // test only spells out the fields it actually varies.
    macro_rules! joint_mention {
        ($idx:expr, $text:expr, $head:expr, $span:expr, $kind:ident) => {{
            let (start, end) = $span;
            JointMention {
                idx: $idx,
                text: String::from($text),
                head: String::from($head),
                start,
                end,
                mention_kind: super::super::MentionKind::$kind,
                entity: None,
                entity_type: None,
            }
        }};
    }

    // Score attached to the first option accepted by `wanted`.
    fn score_of(
        scores: &[(AntecedentValue, f64)],
        wanted: impl Fn(&AntecedentValue) -> bool,
    ) -> f64 {
        scores
            .iter()
            .find_map(|(option, score)| if wanted(option) { Some(*score) } else { None })
            .unwrap()
    }

    #[test]
    fn test_heuristic_coref_provider() {
        let provider = HeuristicCorefProvider::default();

        // A pronoun preceded by a proper name and a nominal.
        let pronoun = joint_mention!(2, "he", "he", (20, 22), Pronominal);
        let proper = joint_mention!(0, "John Smith", "Smith", (0, 10), Proper);
        let nominal = joint_mention!(1, "the CEO", "CEO", (12, 19), Nominal);

        let scores = provider.antecedent_scores(&pronoun, &[&proper, &nominal], "");

        // One entry per candidate plus NEW_CLUSTER.
        assert_eq!(scores.len(), 3);

        // Pronouns should be discouraged from opening a new cluster.
        let new_cluster = score_of(&scores, |v| matches!(v, AntecedentValue::NewCluster));
        assert!(new_cluster < 0.0);
    }

    #[test]
    fn test_heuristic_coref_exact_match() {
        let provider = HeuristicCorefProvider::default();

        let later = joint_mention!(1, "John Smith", "Smith", (50, 60), Proper);
        let earlier = joint_mention!(0, "John Smith", "Smith", (0, 10), Proper);

        let scores = provider.antecedent_scores(&later, &[&earlier], "");

        // Exact, substring and head bonuses minus a small distance penalty.
        let antecedent = score_of(&scores, |v| matches!(v, AntecedentValue::Mention(0)));
        assert!(antecedent > 7.0);
    }
}