Skip to main content

anno/joint/
types.rs

1//! Core types for joint entity analysis.
2//!
3//! This module implements the factor graph approach to joint NER, coreference,
4//! and entity linking from Durrett & Klein (2014). The key insight is that
5//! these three tasks are **interdependent**:
6//!
7//! - **NER informs coreference**: "President Obama" and "the company" likely don't
8//!   corefer (PERSON vs ORG type mismatch)
9//! - **Coreference informs linking**: If "Microsoft" and "the company" corefer,
10//!   they should link to the same Wikipedia entity
11//! - **Linking informs NER**: If a mention links to a Wikipedia person page,
12//!   it's probably a PERSON mention
13//!
14//! # The Joint Model
15//!
16//! For each mention m_i, we have three random variables:
17//!
18//! | Variable | Domain | Meaning |
19//! |----------|--------|---------|
20//! | a_i | {1..i-1, NEW} | Antecedent (or start new entity) |
21//! | t_i | EntityTypes | Semantic type (PER, ORG, LOC, ...) |
22//! | e_i | WikiTitles ∪ {NIL} | Entity link (or no KB entry) |
23//!
24//! These are connected by **factors** that encode soft constraints:
25//!
26//! ```text
27//!   ┌─────────┐     ┌─────────┐     ┌─────────┐
28//!   │  NER    │─────│ Coref   │─────│  Link   │
29//!   │  (t_i)  │     │  (a_i)  │     │  (e_i)  │
30//!   └────┬────┘     └────┬────┘     └────┬────┘
31//!        │               │               │
32//!        └───────────────┴───────────────┘
33//!               Pairwise Factors
34//! ```
35//!
36//! # Inference
37//!
38//! We use loopy belief propagation to find marginal distributions over each
39//! variable, then decode via MAP or marginal inference.
40//!
41//! # References
42//!
43//! - Durrett & Klein (2014): "A Joint Model for Entity Analysis: Coreference,
44//!   Typing, and Linking" (TACL)
45//! - Zhao et al. (2025): RECB for cross-document event coreference (future)
46
47use crate::linking::candidate::CandidateSource;
48use crate::linking::linker::LinkedEntity;
49use crate::{Entity, EntityType, Result};
50use anno_core::{CorefChain, Mention as CorefMention};
51use serde::{Deserialize, Serialize};
52use std::collections::HashMap;
53
54use super::factors::{
55    CorefLinkFactor, CorefLinkWeights, CorefNerFactor, CorefNerWeights, Factor, LinkNerFactor,
56    LinkNerWeights, UnaryCorefFactor, UnaryLinkFactor, UnaryNerFactor, WikipediaKnowledgeStore,
57};
58use super::inference::{BeliefPropagation, InferenceConfig, Marginals};
59use std::sync::Arc;
60
61// =============================================================================
62// Variable Types
63// =============================================================================
64
/// Identifies one variable node in the joint factor graph.
///
/// A variable is determined by the mention that owns it together with which
/// of the three per-mention variables (antecedent, type, link) it is.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct VariableId {
    /// Index of the mention that owns this variable
    pub mention_idx: usize,
    /// Which per-mention variable this is
    pub var_type: VariableType,
}

/// The three kinds of per-mention variables in the joint model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VariableType {
    /// Antecedent choice: a_i ∈ {1,...,i-1,NEW}
    Antecedent,
    /// Semantic type: t_i ∈ EntityTypes
    SemanticType,
    /// Entity link: e_i ∈ WikiTitles ∪ {NIL}
    EntityLink,
}
84
85/// A variable in the joint model.
86#[derive(Debug, Clone)]
87pub enum JointVariable {
88    /// Antecedent for mention i
89    Antecedent {
90        /// Mention index
91        mention_idx: usize,
92        /// Possible antecedents (pruned)
93        candidates: Vec<usize>,
94    },
95    /// Semantic type for mention i
96    SemanticType {
97        /// Mention index
98        mention_idx: usize,
99        /// Possible types
100        types: Vec<EntityType>,
101    },
102    /// Entity link for mention i
103    EntityLink {
104        /// Mention index
105        mention_idx: usize,
106        /// Candidate KB IDs (e.g., Wikidata Q-numbers)
107        candidates: Vec<String>,
108    },
109}
110
111impl JointVariable {
112    /// Get the variable ID.
113    pub fn id(&self) -> VariableId {
114        match self {
115            JointVariable::Antecedent { mention_idx, .. } => VariableId {
116                mention_idx: *mention_idx,
117                var_type: VariableType::Antecedent,
118            },
119            JointVariable::SemanticType { mention_idx, .. } => VariableId {
120                mention_idx: *mention_idx,
121                var_type: VariableType::SemanticType,
122            },
123            JointVariable::EntityLink { mention_idx, .. } => VariableId {
124                mention_idx: *mention_idx,
125                var_type: VariableType::EntityLink,
126            },
127        }
128    }
129
130    /// Get domain size.
131    pub fn domain_size(&self) -> usize {
132        match self {
133            JointVariable::Antecedent { candidates, .. } => candidates.len() + 1, // +1 for NEW
134            JointVariable::SemanticType { types, .. } => types.len(),
135            JointVariable::EntityLink { candidates, .. } => candidates.len() + 1, // +1 for NIL
136        }
137    }
138}
139
140/// Domain of a variable (possible values).
141#[derive(Debug, Clone)]
142pub enum VariableDomain {
143    /// Antecedent domain: indices into mention list, plus NEW_CLUSTER
144    Antecedent(Vec<AntecedentValue>),
145    /// Type domain: entity types
146    SemanticType(Vec<EntityType>),
147    /// Link domain: KB IDs plus NIL
148    EntityLink(Vec<LinkValue>),
149}
150
/// One value of an antecedent variable a_i.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AntecedentValue {
    /// Corefers with the mention at this index
    Mention(usize),
    /// Opens a new coreference cluster
    NewCluster,
}

/// One value of an entity link variable e_i.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LinkValue {
    /// Links to this knowledge-base ID
    KbId(String),
    /// The mention has no knowledge-base entry
    Nil,
}
168
169// =============================================================================
170// Assignment
171// =============================================================================
172
173/// An assignment of values to variables.
174#[derive(Debug, Clone, Default)]
175pub struct Assignment {
176    /// Antecedent assignments: mention_idx → antecedent
177    pub antecedents: HashMap<usize, AntecedentValue>,
178    /// Type assignments: mention_idx → entity type
179    pub types: HashMap<usize, EntityType>,
180    /// Link assignments: mention_idx → link value
181    pub links: HashMap<usize, LinkValue>,
182}
183
184impl Assignment {
185    /// Get antecedent for mention.
186    pub fn get_antecedent(&self, mention_idx: usize) -> Option<AntecedentValue> {
187        self.antecedents.get(&mention_idx).copied()
188    }
189
190    /// Get type for mention.
191    pub fn get_type(&self, mention_idx: usize) -> Option<EntityType> {
192        self.types.get(&mention_idx).cloned()
193    }
194
195    /// Get link for mention.
196    pub fn get_link(&self, mention_idx: usize) -> Option<&LinkValue> {
197        self.links.get(&mention_idx)
198    }
199
200    /// Set antecedent.
201    pub fn set_antecedent(&mut self, mention_idx: usize, value: AntecedentValue) {
202        self.antecedents.insert(mention_idx, value);
203    }
204
205    /// Set type.
206    pub fn set_type(&mut self, mention_idx: usize, value: EntityType) {
207        self.types.insert(mention_idx, value);
208    }
209
210    /// Set link.
211    pub fn set_link(&mut self, mention_idx: usize, value: LinkValue) {
212        self.links.insert(mention_idx, value);
213    }
214}
215
216// =============================================================================
217// Mention Representation
218// =============================================================================
219
/// Syntactic kind of a mention, used by the coreference features.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MentionKind {
    /// Proper name (e.g., "Barack Obama")
    Proper,
    /// Common noun phrase (e.g., "the president")
    Nominal,
    /// Pronoun (e.g., "he", "she", "it")
    Pronominal,
}

impl MentionKind {
    /// Lowercase English pronouns; a mention whose whole text is one of these is
    /// [`MentionKind::Pronominal`].
    const PRONOUNS: &'static [&'static str] = &[
        // third person
        "he", "she", "it", "they", "him", "her", "them", "his", "hers", "its", "their",
        "theirs", "himself", "herself", "itself", "themselves",
        // first and second person (capitalised "I"/"We" must not look like names)
        "i", "me", "my", "mine", "myself", "we", "us", "our", "ours", "ourselves", "you",
        "your", "yours", "yourself", "yourselves",
        // relative / demonstrative
        "who", "whom", "whose", "which", "that", "this", "these", "those",
    ];

    /// Lowercase determiners skipped before the capitalisation test, so a
    /// sentence-initial "The company" is not mistaken for a proper name.
    const DETERMINERS: &'static [&'static str] = &[
        "the", "a", "an", "this", "that", "these", "those", "his", "her", "its", "their", "my",
        "our", "your",
    ];

    /// Infer the mention kind from its surface text with a lightweight heuristic.
    ///
    /// - `Pronominal`: the whitespace-trimmed text, compared case-insensitively,
    ///   is a pronoun ("He", "I", " it ").
    /// - `Proper`: the first word that is not a determiner starts with an
    ///   uppercase letter ("Barack Obama", "The New York Times").
    /// - `Nominal`: everything else, including "The company" and empty text.
    pub fn from_text(text: &str) -> Self {
        let trimmed = text.trim();
        let lower = trimmed.to_lowercase();
        if Self::PRONOUNS.contains(&lower.as_str()) {
            return MentionKind::Pronominal;
        }

        // First content word; if every word is a determiner, use the first word.
        let first_content_word = trimmed
            .split_whitespace()
            .find(|word| !Self::DETERMINERS.contains(&word.to_lowercase().as_str()))
            .or_else(|| trimmed.split_whitespace().next());

        if first_content_word
            .and_then(|word| word.chars().next())
            .is_some_and(char::is_uppercase)
        {
            MentionKind::Proper
        } else {
            MentionKind::Nominal
        }
    }

    /// Check if this is a proper noun mention.
    pub fn is_proper_name(&self) -> bool {
        matches!(self, MentionKind::Proper)
    }

    /// Check if this is a pronoun mention.
    pub fn is_pronoun(&self) -> bool {
        matches!(self, MentionKind::Pronominal)
    }

    /// Check if this is a nominal mention.
    pub fn is_nominal(&self) -> bool {
        matches!(self, MentionKind::Nominal)
    }
}
281
282// =============================================================================
283// Cross-Document Event Coreference (RECB)
284// =============================================================================
285
/// Event coreference relations from RECB (Zhao et al., 2025).
///
/// RECB refines binary event coreference with fine-grained near-identity
/// relations, which allows richer annotation and evaluation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventCorefRelation {
    /// Full identity: both mentions denote the same event instance
    Identity,
    /// Concept-instance: a general event vs. a specific instance of it,
    /// e.g. "protests" vs "the March 15th protest"
    ConceptInstance,
    /// Whole-subevent: one event contains the other,
    /// e.g. "the war" vs "the Battle of Gettysburg"
    WholeSubevent,
    /// Set-member: one event belongs to a set described by the other,
    /// e.g. "the attacks" vs "the September 11 attack"
    SetMember,
    /// Topically related but not coreferent
    TopicallyRelated,
    /// Unrelated
    NotRelated,
    /// Annotator could not decide (guideline escape hatch)
    CannotDecide,
}

impl EventCorefRelation {
    /// True for identity and the three near-identity relations
    /// (concept-instance, whole-subevent, set-member).
    pub fn is_positive(&self) -> bool {
        match self {
            Self::Identity | Self::ConceptInstance | Self::WholeSubevent | Self::SetMember => true,
            Self::TopicallyRelated | Self::NotRelated | Self::CannotDecide => false,
        }
    }

    /// Strict binary label: only [`EventCorefRelation::Identity`] counts as coreferent.
    pub fn to_binary(&self) -> bool {
        *self == Self::Identity
    }

    /// Binary label that also counts near-identity relations as coreferent.
    ///
    /// Despite its name this is the *lenient* mapping; it is equivalent to
    /// [`EventCorefRelation::is_positive`]. Use [`EventCorefRelation::to_binary`]
    /// for identity-only labels.
    pub fn to_strict_binary(&self) -> bool {
        self.is_positive()
    }
}
339
/// An event mention rewritten so that it stands on its own.
///
/// Decontextualization (Choi et al., 2021) rewrites an event mention into a
/// self-contained sentence that can be understood without the surrounding
/// document.
///
/// # Example
///
/// - Original: "The company announced it yesterday."
/// - Decontextualized: "Apple Inc. announced the new iPhone on March 15, 2024."
#[derive(Debug, Clone)]
pub struct DecontextualizedMention {
    /// Mention text as it appears in the source document
    pub original_text: String,
    /// Self-contained rewrite of the mention
    pub decontextualized: String,
    /// Identifier of the source document
    pub doc_id: String,
    /// Character offset where the mention starts in the source document.
    pub original_start: usize,
    /// Character offset where the mention ends in the source document.
    pub original_end: usize,
    /// `(reference, resolved value)` pairs substituted during the rewrite
    pub resolved_entities: Vec<(String, String)>,
}

impl DecontextualizedMention {
    /// Creates a decontextualized mention with no resolved references yet.
    pub fn new(
        original_text: impl Into<String>,
        decontextualized: impl Into<String>,
        doc_id: impl Into<String>,
        original_start: usize,
        original_end: usize,
    ) -> Self {
        let original_text = original_text.into();
        let decontextualized = decontextualized.into();
        let doc_id = doc_id.into();
        Self {
            original_text,
            decontextualized,
            doc_id,
            original_start,
            original_end,
            resolved_entities: Vec::new(),
        }
    }

    /// Records that `reference` was resolved to `resolved`; returns `self` for chaining.
    pub fn with_resolved(
        mut self,
        reference: impl Into<String>,
        resolved: impl Into<String>,
    ) -> Self {
        let pair = (reference.into(), resolved.into());
        self.resolved_entities.push(pair);
        self
    }
}
395
396/// An event mention for cross-document coreference.
397#[derive(Debug, Clone)]
398pub struct EventMention {
399    /// Unique mention ID
400    pub id: String,
401    /// Event trigger text
402    pub trigger: String,
403    /// Event type (if known)
404    pub event_type: Option<String>,
405    /// Source document ID
406    pub doc_id: String,
407    /// Character start offset in source document.
408    pub start: usize,
409    /// Character end offset in source document.
410    pub end: usize,
411    /// Decontextualized form (for improved annotation/modeling)
412    pub decontextualized: Option<DecontextualizedMention>,
413}
414
415// =============================================================================
416// Mention Representation
417// =============================================================================
418
419/// A mention in the joint model with all relevant context.
420#[derive(Debug, Clone)]
421pub struct JointMention {
422    /// Mention index in document
423    pub idx: usize,
424    /// Surface text
425    pub text: String,
426    /// Head word
427    pub head: String,
428    /// Start character offset
429    pub start: usize,
430    /// End character offset
431    pub end: usize,
432    /// Mention kind (proper/nominal/pronominal)
433    pub mention_kind: MentionKind,
434    /// Entity type (if known from NER)
435    pub entity_type: Option<EntityType>,
436    /// Original entity (if available)
437    pub entity: Option<Entity>,
438}
439
440impl JointMention {
441    /// Create from an Entity.
442    pub fn from_entity(idx: usize, entity: &Entity, text: &str) -> Self {
443        let mention_text = text
444            .chars()
445            .skip(entity.start)
446            .take(entity.end - entity.start)
447            .collect::<String>();
448
449        let head = mention_text
450            .split_whitespace()
451            .last()
452            .unwrap_or(&mention_text)
453            .to_string();
454
455        Self {
456            idx,
457            text: mention_text.clone(),
458            head,
459            start: entity.start,
460            end: entity.end,
461            mention_kind: MentionKind::from_text(&mention_text),
462            entity_type: Some(entity.entity_type.clone()),
463            entity: Some(entity.clone()),
464        }
465    }
466}
467
468// =============================================================================
469// Configuration
470// =============================================================================
471
472/// Configuration for joint model.
473#[derive(Debug, Clone)]
474pub struct JointConfig {
475    /// Enable Link+NER factors
476    pub enable_link_ner: bool,
477    /// Enable Coref+NER factors
478    pub enable_coref_ner: bool,
479    /// Enable Coref+Link factors
480    pub enable_coref_link: bool,
481
482    /// Maximum iterations for belief propagation
483    pub max_iterations: usize,
484    /// Convergence threshold for message changes
485    pub convergence_threshold: f64,
486
487    /// Pruning threshold for antecedent candidates (log space)
488    pub pruning_threshold: f64,
489    /// Maximum antecedent candidates to keep after pruning
490    pub max_antecedent_candidates: usize,
491
492    /// Maximum link candidates per mention
493    pub max_link_candidates: usize,
494
495    /// Entity types to consider
496    pub entity_types: Vec<EntityType>,
497}
498
499impl Default for JointConfig {
500    fn default() -> Self {
501        Self {
502            enable_link_ner: true,
503            enable_coref_ner: true,
504            enable_coref_link: true,
505
506            max_iterations: 5,
507            convergence_threshold: 1e-4,
508
509            pruning_threshold: 5.0, // Paper uses k=5
510            max_antecedent_candidates: 50,
511
512            max_link_candidates: 20,
513
514            // Include all common NER types so we can preserve original type
515            entity_types: vec![
516                EntityType::Person,
517                EntityType::Organization,
518                EntityType::Location,
519                EntityType::Date,
520                EntityType::Time,
521                EntityType::Money,
522                EntityType::Percent,
523                EntityType::Other("MISC".to_string()),
524            ],
525        }
526    }
527}
528
529// =============================================================================
530// Results
531// =============================================================================
532
533/// Result of joint entity analysis.
534#[derive(Debug, Clone, Serialize, Deserialize)]
535pub struct JointResult {
536    /// Typed entity mentions
537    pub entities: Vec<Entity>,
538    /// Coreference chains
539    pub chains: Vec<CorefChain>,
540    /// Entity links
541    pub links: Vec<LinkedEntity>,
542    /// Confidence scores per mention (averaged over variables)
543    pub confidences: Vec<f64>,
544}
545
546// =============================================================================
547// Coarse Pruning (§5 of Durrett & Klein 2014)
548// =============================================================================
549
/// Coarse pruner for antecedent candidates (Durrett & Klein 2014, §5).
///
/// The paper prunes "the domains of the coreference variables using a coarse
/// model consisting of the coreference factors trained in isolation". Here the
/// coarse model is a hand-weighted string-match / mention-kind / distance score.
pub struct CoarsePruner {
    /// Margin below the best score within which candidates survive (paper: k=5)
    pub threshold: f64,
    /// Hard cap on surviving candidates, applied after the margin test
    pub max_candidates: usize,
    /// Scale applied to the string-match features
    pub string_match_weight: f64,
    /// Scale applied to the log-distance penalty
    pub distance_weight: f64,
}

impl Default for CoarsePruner {
    fn default() -> Self {
        let threshold = 5.0; // Paper: k=5
        Self {
            threshold,
            max_candidates: 50,
            string_match_weight: 2.0,
            distance_weight: 0.1,
        }
    }
}
576
577impl CoarsePruner {
578    /// Prune antecedent candidates for a mention.
579    pub fn prune_candidates(&self, mention_idx: usize, mentions: &[JointMention]) -> Vec<usize> {
580        if mention_idx == 0 {
581            return vec![];
582        }
583
584        let mention = &mentions[mention_idx];
585
586        // Score all preceding mentions
587        let mut scored: Vec<(usize, f64)> = (0..mention_idx)
588            .map(|ante_idx| {
589                let score = self.score_pair(mention, &mentions[ante_idx], mention_idx - ante_idx);
590                (ante_idx, score)
591            })
592            .collect();
593
594        // Sort by score (best first)
595        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
596
597        if scored.is_empty() {
598            return vec![];
599        }
600
601        // Find best score
602        let best_score = scored[0].1;
603
604        // Prune: keep candidates within threshold of best
605        scored
606            .into_iter()
607            .take_while(|(_, score)| best_score - *score <= self.threshold)
608            .take(self.max_candidates)
609            .map(|(idx, _)| idx)
610            .collect()
611    }
612
613    /// Score a mention-antecedent pair.
614    fn score_pair(
615        &self,
616        mention: &JointMention,
617        antecedent: &JointMention,
618        distance: usize,
619    ) -> f64 {
620        let mut score = 0.0;
621
622        // String match features
623        let m_lower = mention.text.to_lowercase();
624        let a_lower = antecedent.text.to_lowercase();
625        let m_head = mention.head.to_lowercase();
626        let a_head = antecedent.head.to_lowercase();
627
628        // Exact match (strongest signal)
629        if m_lower == a_lower {
630            score += self.string_match_weight * 1.0;
631        }
632        // Head match
633        else if m_head == a_head {
634            score += self.string_match_weight * 0.6;
635        }
636        // Substring match
637        else if m_lower.contains(&a_lower) || a_lower.contains(&m_lower) {
638            score += self.string_match_weight * 0.3;
639        }
640
641        // Type compatibility
642        match (mention.mention_kind, antecedent.mention_kind) {
643            // Pronouns resolve to proper nouns well
644            (MentionKind::Pronominal, MentionKind::Proper) => score += 0.5,
645            // Same type is good
646            (a, b) if a == b => score += 0.3,
647            _ => {}
648        }
649
650        // Distance penalty (log distance)
651        score -= self.distance_weight * (distance as f64 + 1.0).ln();
652
653        score
654    }
655}
656
657// =============================================================================
658// Model
659// =============================================================================
660
661/// Joint model for entity analysis.
662///
663/// Combines NER, coreference, and entity linking in a single factor graph
664/// following Durrett & Klein (2014).
665pub struct JointModel {
666    config: JointConfig,
667    /// Coarse pruner for antecedent candidates
668    pruner: CoarsePruner,
669    /// Wikipedia knowledge store for semantics lookups
670    knowledge_store: Option<Arc<WikipediaKnowledgeStore>>,
671    /// Optional custom NER score provider
672    ner_provider: Option<Arc<dyn NerScoreProvider>>,
673    /// Optional custom coref score provider
674    coref_provider: Option<Arc<dyn CorefScoreProvider>>,
675    /// Optional custom link score provider
676    link_provider: Option<Arc<dyn LinkScoreProvider>>,
677}
678
679impl Default for JointModel {
680    fn default() -> Self {
681        Self::new(JointConfig::default()).expect("default config should always succeed")
682    }
683}
684
685impl JointModel {
686    /// Create a new joint model.
687    pub fn new(config: JointConfig) -> Result<Self> {
688        let pruner = CoarsePruner {
689            threshold: config.pruning_threshold,
690            max_candidates: config.max_antecedent_candidates,
691            ..Default::default()
692        };
693
694        Ok(Self {
695            config,
696            pruner,
697            knowledge_store: None,
698            ner_provider: None,
699            coref_provider: None,
700            link_provider: None,
701        })
702    }
703
704    /// Add Wikipedia knowledge store for semantics lookups (Link+NER factors).
705    pub fn with_knowledge(mut self, store: Arc<WikipediaKnowledgeStore>) -> Self {
706        self.knowledge_store = Some(store);
707        self
708    }
709
710    /// Attach a custom NER score provider.
711    pub fn with_ner_provider(mut self, provider: Arc<dyn NerScoreProvider>) -> Self {
712        self.ner_provider = Some(provider);
713        self
714    }
715
716    /// Attach a custom coreference score provider.
717    pub fn with_coref_provider(mut self, provider: Arc<dyn CorefScoreProvider>) -> Self {
718        self.coref_provider = Some(provider);
719        self
720    }
721
722    /// Attach a custom link score provider.
723    pub fn with_link_provider(mut self, provider: Arc<dyn LinkScoreProvider>) -> Self {
724        self.link_provider = Some(provider);
725        self
726    }
727
728    /// Analyze text jointly for entities, coreference, and links.
729    pub fn analyze(&self, text: &str, entities: &[Entity]) -> Result<JointResult> {
730        // 1. Create joint mentions from entities
731        let mentions: Vec<JointMention> = entities
732            .iter()
733            .enumerate()
734            .map(|(i, e)| JointMention::from_entity(i, e, text))
735            .collect();
736
737        if mentions.is_empty() {
738            return Ok(JointResult {
739                entities: vec![],
740                chains: vec![],
741                links: vec![],
742                confidences: vec![],
743            });
744        }
745
746        // 2. Build variables
747        let variables = self.build_variables(&mentions);
748
749        // 3. Build factors
750        let factors = self.build_factors(&mentions, &variables);
751
752        // 4. Run belief propagation
753        let inference_config = InferenceConfig {
754            max_iterations: self.config.max_iterations,
755            convergence_threshold: self.config.convergence_threshold,
756            ..Default::default()
757        };
758        let mut bp = BeliefPropagation::new(factors, variables.clone(), inference_config);
759        let marginals = bp.run();
760
761        // 5. Decode using MBR
762        let (entities_out, chains, links, confidences) =
763            self.decode(&mentions, &variables, &marginals);
764
765        Ok(JointResult {
766            entities: entities_out,
767            chains,
768            links,
769            confidences,
770        })
771    }
772
773    /// Build variables for all mentions.
774    fn build_variables(&self, mentions: &[JointMention]) -> Vec<JointVariable> {
775        let mut variables = Vec::new();
776
777        for (i, _mention) in mentions.iter().enumerate() {
778            // Antecedent variable (for all except first)
779            if i > 0 {
780                let pruned = self.pruner.prune_candidates(i, mentions);
781                variables.push(JointVariable::Antecedent {
782                    mention_idx: i,
783                    candidates: pruned,
784                });
785            }
786
787            // Semantic type variable
788            variables.push(JointVariable::SemanticType {
789                mention_idx: i,
790                types: self.config.entity_types.clone(),
791            });
792
793            // Entity link variable
794            // In production: query an external linker for candidates (only for Proper mentions).
795            let link_candidates: Vec<String> = vec![];
796            variables.push(JointVariable::EntityLink {
797                mention_idx: i,
798                candidates: link_candidates,
799            });
800        }
801
802        variables
803    }
804
805    /// Build factors for the model.
806    fn build_factors(
807        &self,
808        mentions: &[JointMention],
809        _variables: &[JointVariable],
810    ) -> Vec<Box<dyn Factor>> {
811        let mut factors: Vec<Box<dyn Factor>> = Vec::new();
812
813        for mention in mentions {
814            let i = mention.idx;
815
816            // Unary NER factor
817            let type_scores: Vec<(EntityType, f64)> = if let Some(ref provider) = self.ner_provider
818            {
819                provider.type_scores(mention, mention.text.as_str())
820            } else {
821                let original_type = mention.entity.as_ref().map(|e| &e.entity_type);
822                self.config
823                    .entity_types
824                    .iter()
825                    .map(|t| {
826                        let score = if original_type == Some(t) {
827                            10.0 // Strong prior from NER
828                        } else {
829                            -5.0 // Penalize non-matching types
830                        };
831                        (t.clone(), score)
832                    })
833                    .collect()
834            };
835            factors.push(Box::new(UnaryNerFactor::new(i, type_scores)));
836
837            // Unary coref factor (for mentions after first)
838            if i > 0 {
839                let candidates: Vec<usize> =
840                    (0..i).take(self.config.max_antecedent_candidates).collect();
841                let coref_scores: Vec<(AntecedentValue, f64)> =
842                    if let Some(ref provider) = self.coref_provider {
843                        // Build candidate refs
844                        let cand_refs: Vec<&JointMention> =
845                            candidates.iter().map(|&idx| &mentions[idx]).collect();
846                        provider.antecedent_scores(mention, &cand_refs, mention.text.as_str())
847                    } else {
848                        let mut scores: Vec<(AntecedentValue, f64)> = candidates
849                            .iter()
850                            .map(|&ante| {
851                                let ante_mention = &mentions[ante];
852                                let head_match = if mention.head.to_lowercase()
853                                    == ante_mention.head.to_lowercase()
854                                {
855                                    2.0
856                                } else {
857                                    0.0
858                                };
859                                let distance_penalty = -0.1 * (i - ante) as f64;
860                                (
861                                    AntecedentValue::Mention(ante),
862                                    head_match + distance_penalty,
863                                )
864                            })
865                            .collect();
866                        scores.push((AntecedentValue::NewCluster, 0.0));
867                        scores
868                    };
869                factors.push(Box::new(UnaryCorefFactor::new(i, coref_scores)));
870            }
871
872            // Unary link factor
873            let link_candidates_raw = if let Some(ref provider) = self.link_provider {
874                provider.link_candidates(mention, mention.text.as_str())
875            } else {
876                vec![]
877            };
878            let link_candidates: Vec<(LinkValue, f64)> = link_candidates_raw
879                .into_iter()
880                .map(|(id, score)| {
881                    let lv = if id == "NIL" {
882                        LinkValue::Nil
883                    } else {
884                        LinkValue::KbId(id)
885                    };
886                    (lv, score)
887                })
888                .collect();
889            factors.push(Box::new(UnaryLinkFactor::new(i, link_candidates)));
890        }
891
892        // Cross-task factors
893        for mention in mentions {
894            let i = mention.idx;
895
896            if i > 0 {
897                let candidates: Vec<usize> =
898                    (0..i).take(self.config.max_antecedent_candidates).collect();
899
900                for &ante in &candidates {
901                    // Coref+NER factor
902                    if self.config.enable_coref_ner {
903                        factors.push(Box::new(CorefNerFactor::new(
904                            i,
905                            ante,
906                            CorefNerWeights::default(),
907                        )));
908                    }
909
910                    // Coref+Link factor
911                    if self.config.enable_coref_link {
912                        let mut factor = CorefLinkFactor::new(i, ante, CorefLinkWeights::default());
913                        if let Some(ref store) = self.knowledge_store {
914                            factor = factor.with_knowledge(store.clone());
915                        }
916                        factors.push(Box::new(factor));
917                    }
918                }
919            }
920
921            // Link+NER factor
922            if self.config.enable_link_ner {
923                let mut factor = LinkNerFactor::new(i, LinkNerWeights::default());
924                if let Some(ref store) = self.knowledge_store {
925                    factor = factor.with_knowledge(store.clone());
926                }
927                factors.push(Box::new(factor));
928            }
929        }
930
931        factors
932    }
933
    /// Decode assignments from marginals using MBR.
    ///
    /// For every variable this picks the value with the highest marginal
    /// (`Marginals::argmax`) together with its probability (`Marginals::prob`,
    /// defaulting to 0.0). Choosing each variable's max-marginal value
    /// independently is the minimum-Bayes-risk decision under a per-variable
    /// (Hamming) loss, which is what "MBR" refers to here.
    ///
    /// * `Antecedent` winners become `AntecedentValue::Mention(candidates[i])`,
    ///   or `AntecedentValue::NewCluster` when the winning index lies past the
    ///   candidate list; the collected antecedents are turned into chains by
    ///   `build_chains`.
    /// * `SemanticType` winners become one `Entity` per mention (see the inline
    ///   notes for the fallback to the mention's original NER entity).
    /// * `EntityLink` winners become a `LinkedEntity` only when a KB id wins;
    ///   a NIL decision produces no link at all.
    ///
    /// Returns `(entities, chains, links, confidences)`, where `confidences[k]`
    /// is the confidence attached to `entities[k]`.
    ///
    /// NOTE(review): a variable for which `argmax` returns `None` is skipped
    /// entirely. For a `SemanticType` variable this drops the mention even if
    /// it carries an original NER entity — confirm this is intended.
    ///
    /// # Panics
    ///
    /// Indexes `mentions` with each variable's `mention_idx`, so it panics if a
    /// variable refers to a mention outside `mentions`.
    fn decode(
        &self,
        mentions: &[JointMention],
        variables: &[JointVariable],
        marginals: &Marginals,
    ) -> (Vec<Entity>, Vec<CorefChain>, Vec<LinkedEntity>, Vec<f64>) {
        let mut entities = Vec::new();
        let mut links = Vec::new();
        let mut confidences = Vec::new();
        // mention index -> decoded antecedent; consumed by `build_chains`.
        let mut antecedents: HashMap<usize, AntecedentValue> = HashMap::new();

        for var in variables {
            let var_id = var.id();
            if let Some(best_idx) = marginals.argmax(&var_id) {
                let prob = marginals.prob(&var_id, best_idx).unwrap_or(0.0);

                match var {
                    JointVariable::Antecedent {
                        mention_idx,
                        candidates,
                    } => {
                        // Indices past the candidate list denote "start a new
                        // cluster" (presumably NEW is the last value in this
                        // variable's domain — confirm against JointVariable).
                        let value = if best_idx < candidates.len() {
                            AntecedentValue::Mention(candidates[best_idx])
                        } else {
                            AntecedentValue::NewCluster
                        };
                        antecedents.insert(*mention_idx, value);
                    }
                    JointVariable::SemanticType {
                        mention_idx, types, ..
                    } => {
                        let m = &mentions[*mention_idx];
                        // Prefer the inferred type when its marginal is high
                        // enough; otherwise fall back to the mention's original
                        // NER entity (type and confidence).
                        let (entity_type, conf) = if let Some(inferred_type) = types.get(best_idx) {
                            // NOTE(review): 0.3 is a hard-coded confidence cutoff.
                            if prob > 0.3 {
                                (inferred_type.clone(), prob)
                            } else if let Some(original) = &m.entity {
                                // Low-confidence inference: keep the original NER type.
                                (original.entity_type.clone(), original.confidence)
                            } else {
                                // Low confidence but nothing to fall back to.
                                (inferred_type.clone(), prob)
                            }
                        } else if let Some(original) = &m.entity {
                            // Winning index is outside `types`: use the original NER type.
                            (original.entity_type.clone(), original.confidence)
                        } else {
                            // No inferred type and no original entity: emit nothing
                            // for this mention (not expected in practice).
                            continue;
                        };
                        entities.push(Entity::new(&m.text, entity_type, m.start, m.end, conf));
                        // Kept parallel to `entities`.
                        confidences.push(conf);
                    }
                    JointVariable::EntityLink {
                        mention_idx,
                        candidates,
                    } => {
                        // Indices past the candidate list denote NIL (no KB entry).
                        let link_value = if best_idx < candidates.len() {
                            LinkValue::KbId(candidates[best_idx].clone())
                        } else {
                            LinkValue::Nil
                        };
                        // Only concrete KB ids produce a LinkedEntity; NIL is dropped.
                        if let LinkValue::KbId(kb_id) = link_value {
                            let m = &mentions[*mention_idx];
                            links.push(LinkedEntity {
                                mention_text: m.text.clone(),
                                start: m.start,
                                end: m.end,
                                kb_id: Some(kb_id),
                                // NOTE(review): source is hard-coded to Wikidata
                                // regardless of which link provider produced the
                                // candidates — confirm KB ids are always Wikidata ids.
                                source: CandidateSource::Wikidata,
                                label: None,
                                iri: None,
                                confidence: prob,
                                is_nil: false,
                                nil_reason: None,
                                nil_action: None,
                                alternatives: Vec::new(),
                            });
                        }
                    }
                }
            }
        }

        // Build coreference chains from antecedent assignments
        let chains = self.build_chains(&antecedents, mentions);

        (entities, chains, links, confidences)
    }
1024
1025    /// Build coreference chains from antecedent assignments.
1026    fn build_chains(
1027        &self,
1028        antecedents: &HashMap<usize, AntecedentValue>,
1029        mentions: &[JointMention],
1030    ) -> Vec<CorefChain> {
1031        let n_mentions = mentions.len();
1032        // Union-find to group mentions
1033        let mut parent: Vec<usize> = (0..n_mentions).collect();
1034
1035        fn find(parent: &mut [usize], i: usize) -> usize {
1036            if parent[i] != i {
1037                parent[i] = find(parent, parent[i]);
1038            }
1039            parent[i]
1040        }
1041
1042        fn union(parent: &mut [usize], i: usize, j: usize) {
1043            let pi = find(parent, i);
1044            let pj = find(parent, j);
1045            if pi != pj {
1046                parent[pi] = pj;
1047            }
1048        }
1049
1050        // Process antecedent assignments
1051        for (&mention_idx, &ante_value) in antecedents {
1052            if let AntecedentValue::Mention(ante_idx) = ante_value {
1053                union(&mut parent, mention_idx, ante_idx);
1054            }
1055        }
1056
1057        // Group by root
1058        let mut clusters: HashMap<usize, Vec<usize>> = HashMap::new();
1059        for i in 0..n_mentions {
1060            let root = find(&mut parent, i);
1061            clusters.entry(root).or_default().push(i);
1062        }
1063
1064        // Convert to CorefChain
1065        clusters
1066            .into_iter()
1067            .filter(|(_, members)| members.len() > 1) // Only non-singleton
1068            .enumerate()
1069            .map(|(chain_id, (_, mut members))| {
1070                members.sort();
1071                let coref_mentions: Vec<CorefMention> = members
1072                    .iter()
1073                    .map(|&idx| {
1074                        let m = &mentions[idx];
1075                        CorefMention {
1076                            text: m.text.clone(),
1077                            start: m.start,
1078                            end: m.end,
1079                            head_start: None,
1080                            head_end: None,
1081                            entity_type: m.entity.as_ref().map(|e| format!("{:?}", e.entity_type)),
1082                            mention_type: None,
1083                        }
1084                    })
1085                    .collect();
1086                CorefChain {
1087                    cluster_id: Some(anno_core::CanonicalId::new(chain_id as u64)),
1088                    mentions: coref_mentions,
1089                    entity_type: None,
1090                }
1091            })
1092            .collect()
1093    }
1094
1095    /// Get configuration.
1096    pub fn config(&self) -> &JointConfig {
1097        &self.config
1098    }
1099
1100    /// Extract entities from raw text (requires external NER first).
1101    ///
1102    /// This is a convenience method for pipelines that want to use
1103    /// JointModel as the final step after mention detection.
1104    pub fn extract_entities_from_mentions(
1105        &self,
1106        text: &str,
1107        mentions: &[JointMention],
1108    ) -> Result<Vec<Entity>> {
1109        let entities: Vec<Entity> = mentions.iter().filter_map(|m| m.entity.clone()).collect();
1110
1111        let result = self.analyze(text, &entities)?;
1112        Ok(result.entities)
1113    }
1114}
1115
1116// =============================================================================
1117// Trait Implementations
1118// =============================================================================
1119
1120/// Implement the `Model` trait for JointModel to allow it to be used as an NER backend.
1121///
1122/// Note: JointModel requires pre-extracted entities as input, so `extract_entities`
1123/// uses an internal regex-based mention detector as a fallback.
1124impl crate::Model for JointModel {
1125    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
1126        // For raw text input, we need mention detection first.
1127        // Use a simple regex-based approach for common entity patterns.
1128        let initial_entities = detect_mentions_heuristic(text);
1129        let result = self.analyze(text, &initial_entities)?;
1130        Ok(result.entities)
1131    }
1132
1133    fn supported_types(&self) -> Vec<EntityType> {
1134        self.config.entity_types.clone()
1135    }
1136
1137    fn is_available(&self) -> bool {
1138        true
1139    }
1140
1141    fn name(&self) -> &'static str {
1142        "joint-model"
1143    }
1144
1145    fn description(&self) -> &'static str {
1146        "Joint Entity Analysis: NER + Coreference + Entity Linking (Durrett & Klein 2014)"
1147    }
1148}
1149
1150/// Implement the `CoreferenceResolver` trait for JointModel.
1151impl anno_core::CoreferenceResolver for JointModel {
1152    fn resolve(&self, entities: &[Entity]) -> Vec<Entity> {
1153        if entities.is_empty() {
1154            return vec![];
1155        }
1156
1157        // Create a dummy text for position-based analysis
1158        // In practice, you should use `analyze_with_text` if you have the text
1159        let max_end = entities.iter().map(|e| e.end).max().unwrap_or(0);
1160        let text = " ".repeat(max_end + 1);
1161
1162        match self.analyze(&text, entities) {
1163            Ok(result) => {
1164                // Assign canonical IDs based on coreference chains
1165                let mut resolved = entities.to_vec();
1166
1167                for chain in &result.chains {
1168                    let cluster_id = chain.cluster_id.unwrap_or(anno_core::CanonicalId::ZERO);
1169                    for mention in &chain.mentions {
1170                        // Find matching entity by position
1171                        for entity in &mut resolved {
1172                            if entity.start == mention.start && entity.end == mention.end {
1173                                entity.canonical_id = Some(cluster_id);
1174                            }
1175                        }
1176                    }
1177                }
1178
1179                // Assign unique IDs to singletons
1180                let mut next_id = anno_core::CanonicalId::new(result.chains.len() as u64);
1181                for entity in &mut resolved {
1182                    if entity.canonical_id.is_none() {
1183                        entity.canonical_id = Some(next_id);
1184                        next_id += 1;
1185                    }
1186                }
1187
1188                resolved
1189            }
1190            Err(_) => entities.to_vec(),
1191        }
1192    }
1193
1194    fn name(&self) -> &'static str {
1195        "joint-model-coref"
1196    }
1197}
1198
1199// =============================================================================
1200// Builder Pattern
1201// =============================================================================
1202
1203/// Builder for `JointModel` with fluent configuration.
1204///
1205/// # Example
1206///
1207/// ```rust,ignore
1208/// use anno::joint::{JointModelBuilder, WikipediaKnowledgeStore};
1209///
1210/// let model = JointModelBuilder::new()
1211///     .with_max_iterations(10)
1212///     .with_convergence_threshold(1e-5)
1213///     .enable_link_ner(true)
1214///     .enable_coref_ner(true)
1215///     .enable_coref_link(true)
1216///     .with_knowledge(knowledge_store)
1217///     .build()?;
1218/// ```
1219#[derive(Clone, Default)]
1220pub struct JointModelBuilder {
1221    config: JointConfig,
1222    knowledge_store: Option<Arc<WikipediaKnowledgeStore>>,
1223    ner_provider: Option<Arc<dyn NerScoreProvider>>,
1224    coref_provider: Option<Arc<dyn CorefScoreProvider>>,
1225    link_provider: Option<Arc<dyn LinkScoreProvider>>,
1226}
1227
1228impl JointModelBuilder {
1229    /// Create a new builder with default configuration.
1230    pub fn new() -> Self {
1231        Self::default()
1232    }
1233
1234    /// Set maximum iterations for belief propagation.
1235    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
1236        self.config.max_iterations = max_iterations;
1237        self
1238    }
1239
1240    /// Set convergence threshold for belief propagation.
1241    pub fn with_convergence_threshold(mut self, threshold: f64) -> Self {
1242        self.config.convergence_threshold = threshold;
1243        self
1244    }
1245
1246    /// Set pruning threshold for antecedent candidates.
1247    pub fn with_pruning_threshold(mut self, threshold: f64) -> Self {
1248        self.config.pruning_threshold = threshold;
1249        self
1250    }
1251
1252    /// Set maximum antecedent candidates to keep after pruning.
1253    pub fn with_max_antecedent_candidates(mut self, max: usize) -> Self {
1254        self.config.max_antecedent_candidates = max;
1255        self
1256    }
1257
1258    /// Set maximum link candidates per mention.
1259    pub fn with_max_link_candidates(mut self, max: usize) -> Self {
1260        self.config.max_link_candidates = max;
1261        self
1262    }
1263
1264    /// Enable or disable Link+NER factors.
1265    pub fn enable_link_ner(mut self, enable: bool) -> Self {
1266        self.config.enable_link_ner = enable;
1267        self
1268    }
1269
1270    /// Enable or disable Coref+NER factors.
1271    pub fn enable_coref_ner(mut self, enable: bool) -> Self {
1272        self.config.enable_coref_ner = enable;
1273        self
1274    }
1275
1276    /// Enable or disable Coref+Link factors.
1277    pub fn enable_coref_link(mut self, enable: bool) -> Self {
1278        self.config.enable_coref_link = enable;
1279        self
1280    }
1281
1282    /// Set entity types to consider.
1283    pub fn with_entity_types(mut self, types: Vec<EntityType>) -> Self {
1284        self.config.entity_types = types;
1285        self
1286    }
1287
1288    /// Add Wikipedia knowledge store for semantics lookups.
1289    pub fn with_knowledge(mut self, store: Arc<WikipediaKnowledgeStore>) -> Self {
1290        self.knowledge_store = Some(store);
1291        self
1292    }
1293
1294    /// Plug in a custom NER score provider (unary NER factors).
1295    pub fn with_ner_provider(mut self, provider: Arc<dyn NerScoreProvider>) -> Self {
1296        self.ner_provider = Some(provider);
1297        self
1298    }
1299
1300    /// Plug in a custom coreference score provider (unary coref factors).
1301    pub fn with_coref_provider(mut self, provider: Arc<dyn CorefScoreProvider>) -> Self {
1302        self.coref_provider = Some(provider);
1303        self
1304    }
1305
1306    /// Plug in a custom link score provider (unary link factors).
1307    pub fn with_link_provider(mut self, provider: Arc<dyn LinkScoreProvider>) -> Self {
1308        self.link_provider = Some(provider);
1309        self
1310    }
1311
1312    /// Build the JointModel.
1313    pub fn build(self) -> Result<JointModel> {
1314        let mut model = JointModel::new(self.config)?;
1315        if let Some(store) = self.knowledge_store {
1316            model = model.with_knowledge(store);
1317        }
1318        if let Some(ner) = self.ner_provider {
1319            model = model.with_ner_provider(ner);
1320        }
1321        if let Some(coref) = self.coref_provider {
1322            model = model.with_coref_provider(coref);
1323        }
1324        if let Some(link) = self.link_provider {
1325            model = model.with_link_provider(link);
1326        }
1327        Ok(model)
1328    }
1329}
1330
1331// =============================================================================
1332// Score Provider Traits
1333// =============================================================================
1334
1335/// Trait for providing NER scores for mentions.
1336///
1337/// Allows plugging in different NER backends to provide unary type scores.
1338pub trait NerScoreProvider: Send + Sync {
1339    /// Get type scores for a mention.
1340    ///
1341    /// Returns a vector of (EntityType, log_score) pairs.
1342    fn type_scores(&self, mention: &JointMention, text: &str) -> Vec<(EntityType, f64)>;
1343}
1344
1345/// Trait for providing coreference scores for mention pairs.
1346///
1347/// Allows plugging in different mention-ranking models.
1348pub trait CorefScoreProvider: Send + Sync {
1349    /// Get antecedent scores for a mention.
1350    ///
1351    /// Returns scores for each candidate antecedent plus NEW_CLUSTER.
1352    fn antecedent_scores(
1353        &self,
1354        mention: &JointMention,
1355        candidates: &[&JointMention],
1356        text: &str,
1357    ) -> Vec<(AntecedentValue, f64)>;
1358}
1359
1360/// Trait for providing entity linking scores.
1361///
1362/// Allows plugging in different candidate generators and rankers.
1363pub trait LinkScoreProvider: Send + Sync {
1364    /// Get link candidates for a mention.
1365    ///
1366    /// Returns KB IDs and their log scores.
1367    fn link_candidates(&self, mention: &JointMention, text: &str) -> Vec<(String, f64)>;
1368}
1369
1370// =============================================================================
1371// Heuristic Mention Detection (fallback for Model trait)
1372// =============================================================================
1373
1374/// Simple heuristic mention detection for when no external NER is available.
1375///
1376/// This is a basic fallback - for best results, use a proper NER backend first.
1377/// Uses CHARACTER offsets (not byte offsets) as required by Entity.
1378fn detect_mentions_heuristic(text: &str) -> Vec<Entity> {
1379    let mut entities = Vec::new();
1380
1381    // Simple capitalized word sequence detection
1382    // Track character position explicitly
1383    let mut in_name = false;
1384    let mut name_start_char = 0;
1385    let mut char_pos = 0;
1386
1387    let chars: Vec<char> = text.chars().collect();
1388
1389    for c in &chars {
1390        if c.is_whitespace() || c.is_ascii_punctuation() {
1391            if in_name {
1392                // End of name - extract text using character positions
1393                let name_text: String = chars[name_start_char..char_pos].iter().collect();
1394
1395                if name_text.chars().count() > 1 {
1396                    entities.push(Entity::new(
1397                        &name_text,
1398                        EntityType::Other("MENTION".to_string()),
1399                        name_start_char,
1400                        char_pos,
1401                        0.5,
1402                    ));
1403                }
1404                in_name = false;
1405            }
1406        } else if c.is_uppercase() && !in_name {
1407            // Start of potential name
1408            in_name = true;
1409            name_start_char = char_pos;
1410        }
1411
1412        char_pos += 1;
1413    }
1414
1415    // Handle trailing name
1416    if in_name {
1417        let name_text: String = chars[name_start_char..char_pos].iter().collect();
1418
1419        if name_text.chars().count() > 1 {
1420            entities.push(Entity::new(
1421                &name_text,
1422                EntityType::Other("MENTION".to_string()),
1423                name_start_char,
1424                char_pos,
1425                0.5,
1426            ));
1427        }
1428    }
1429
1430    entities
1431}
1432
1433#[cfg(test)]
1434mod tests {
1435    use super::*;
1436
1437    #[test]
1438    fn test_variable_id() {
1439        let id = VariableId {
1440            mention_idx: 0,
1441            var_type: VariableType::Antecedent,
1442        };
1443        assert_eq!(id.mention_idx, 0);
1444    }
1445
1446    #[test]
1447    fn test_assignment() {
1448        let mut assignment = Assignment::default();
1449        assignment.set_antecedent(1, AntecedentValue::Mention(0));
1450        assignment.set_type(0, EntityType::Person);
1451        assignment.set_link(0, LinkValue::KbId("Q42".to_string()));
1452
1453        assert_eq!(
1454            assignment.get_antecedent(1),
1455            Some(AntecedentValue::Mention(0))
1456        );
1457        assert_eq!(assignment.get_type(0), Some(EntityType::Person));
1458        assert_eq!(
1459            assignment.get_link(0),
1460            Some(&LinkValue::KbId("Q42".to_string()))
1461        );
1462    }
1463
1464    #[test]
1465    fn test_joint_config_default() {
1466        let config = JointConfig::default();
1467        assert!(config.enable_link_ner);
1468        assert!(config.enable_coref_ner);
1469        assert!(config.enable_coref_link);
1470        assert_eq!(config.max_iterations, 5);
1471    }
1472
1473    #[test]
1474    fn test_joint_model_creation() {
1475        let model = JointModel::new(JointConfig::default());
1476        assert!(model.is_ok());
1477    }
1478
1479    #[test]
1480    fn test_mention_kind_detection() {
1481        assert_eq!(MentionKind::from_text("he"), MentionKind::Pronominal);
1482        assert_eq!(MentionKind::from_text("She"), MentionKind::Pronominal);
1483        assert_eq!(MentionKind::from_text("Barack Obama"), MentionKind::Proper);
1484        assert_eq!(
1485            MentionKind::from_text("the president"),
1486            MentionKind::Nominal
1487        );
1488    }
1489
1490    #[test]
1491    fn test_joint_model_analyze_empty() {
1492        let model = JointModel::new(JointConfig::default()).unwrap();
1493        let result = model.analyze("Hello world", &[]).unwrap();
1494
1495        assert!(result.entities.is_empty());
1496        assert!(result.chains.is_empty());
1497    }
1498
1499    #[test]
1500    fn test_joint_model_analyze_single_entity() {
1501        let model = JointModel::new(JointConfig::default()).unwrap();
1502        let entities = vec![Entity::new("Obama", EntityType::Person, 0, 5, 0.9)];
1503
1504        let result = model.analyze("Obama was here.", &entities).unwrap();
1505        assert!(!result.entities.is_empty());
1506    }
1507
1508    #[test]
1509    fn test_coarse_pruner() {
1510        let pruner = CoarsePruner::default();
1511
1512        let mentions = vec![
1513            JointMention {
1514                idx: 0,
1515                text: "Barack Obama".to_string(),
1516                head: "Obama".to_string(),
1517                start: 0,
1518                end: 12,
1519                mention_kind: MentionKind::Proper,
1520                entity_type: Some(EntityType::Person),
1521                entity: None,
1522            },
1523            JointMention {
1524                idx: 1,
1525                text: "France".to_string(),
1526                head: "France".to_string(),
1527                start: 21,
1528                end: 27,
1529                mention_kind: MentionKind::Proper,
1530                entity_type: Some(EntityType::Location),
1531                entity: None,
1532            },
1533            JointMention {
1534                idx: 2,
1535                text: "Obama".to_string(),
1536                head: "Obama".to_string(),
1537                start: 40,
1538                end: 45,
1539                mention_kind: MentionKind::Proper,
1540                entity_type: Some(EntityType::Person),
1541                entity: None,
1542            },
1543        ];
1544
1545        let candidates = pruner.prune_candidates(2, &mentions);
1546        // Should include mention 0 (head match "Obama") but maybe not mention 1
1547        assert!(!candidates.is_empty());
1548        // Mention 0 should be the best candidate due to head match
1549        assert!(candidates.contains(&0));
1550    }
1551
1552    // =========================================================================
1553    // Cross-Document Event Coreference Tests
1554    // =========================================================================
1555
1556    #[test]
1557    fn test_event_coref_relation_is_positive() {
1558        // Positive relations (should be clustered together)
1559        assert!(EventCorefRelation::Identity.is_positive());
1560        assert!(EventCorefRelation::ConceptInstance.is_positive());
1561        assert!(EventCorefRelation::WholeSubevent.is_positive());
1562        assert!(EventCorefRelation::SetMember.is_positive());
1563
1564        // Negative relations (not coreferent)
1565        assert!(!EventCorefRelation::TopicallyRelated.is_positive());
1566        assert!(!EventCorefRelation::NotRelated.is_positive());
1567        assert!(!EventCorefRelation::CannotDecide.is_positive());
1568    }
1569
1570    #[test]
1571    fn test_event_coref_relation_to_binary() {
1572        // Standard binary: only Identity is positive
1573        assert!(EventCorefRelation::Identity.to_binary());
1574        assert!(!EventCorefRelation::ConceptInstance.to_binary());
1575        assert!(!EventCorefRelation::WholeSubevent.to_binary());
1576        assert!(!EventCorefRelation::SetMember.to_binary());
1577        assert!(!EventCorefRelation::NotRelated.to_binary());
1578    }
1579
1580    #[test]
1581    fn test_event_coref_relation_to_strict_binary() {
1582        // Strict binary: all positive near-identity relations count
1583        assert!(EventCorefRelation::Identity.to_strict_binary());
1584        assert!(EventCorefRelation::ConceptInstance.to_strict_binary());
1585        assert!(EventCorefRelation::WholeSubevent.to_strict_binary());
1586        assert!(EventCorefRelation::SetMember.to_strict_binary());
1587        assert!(!EventCorefRelation::NotRelated.to_strict_binary());
1588        assert!(!EventCorefRelation::TopicallyRelated.to_strict_binary());
1589    }
1590
1591    #[test]
1592    fn test_decontextualized_mention() {
1593        let mention = DecontextualizedMention::new("it", "Apple Inc.", "doc001", 10, 12)
1594            .with_resolved("it", "Apple Inc.");
1595
1596        assert_eq!(mention.original_text, "it");
1597        assert_eq!(mention.decontextualized, "Apple Inc.");
1598        assert_eq!(mention.doc_id, "doc001");
1599        assert_eq!(mention.resolved_entities.len(), 1);
1600        assert_eq!(
1601            mention.resolved_entities[0],
1602            ("it".to_string(), "Apple Inc.".to_string())
1603        );
1604    }
1605
1606    #[test]
1607    fn test_event_mention() {
1608        let event = EventMention {
1609            id: "e001".to_string(),
1610            trigger: "announced".to_string(),
1611            event_type: Some("Communication".to_string()),
1612            doc_id: "doc001".to_string(),
1613            start: 15,
1614            end: 24,
1615            decontextualized: Some(DecontextualizedMention::new(
1616                "The company announced it yesterday",
1617                "Apple Inc. announced the new iPhone on March 15, 2024",
1618                "doc001",
1619                0,
1620                34,
1621            )),
1622        };
1623
1624        assert_eq!(event.id, "e001");
1625        assert_eq!(event.trigger, "announced");
1626        assert!(event.decontextualized.is_some());
1627        let decon = event.decontextualized.unwrap();
1628        assert!(decon.decontextualized.contains("Apple Inc."));
1629    }
1630
1631    // ==========================================================================
1632    // Trait Implementation Tests
1633    // ==========================================================================
1634
1635    #[test]
1636    fn test_model_trait_implementation() {
1637        use crate::Model;
1638
1639        let model = JointModel::default();
1640
1641        // Test Model trait methods
1642        assert_eq!(model.name(), "joint-model");
1643        assert!(model.description().contains("Durrett"));
1644        assert!(model.is_available());
1645
1646        let types = model.supported_types();
1647        assert!(!types.is_empty());
1648    }
1649
1650    #[test]
1651    fn test_model_extract_entities_simple() {
1652        use crate::Model;
1653
1654        let model = JointModel::default();
1655
1656        // Test with simple text containing capitalized words
1657        let text = "John Smith visited New York";
1658        let entities = model.extract_entities(text, None).unwrap();
1659
1660        // Heuristic detection may legitimately return empty output; this test only asserts no error.
1661        let _ = entities;
1662    }
1663
1664    #[test]
1665    fn test_coref_resolver_trait_implementation() {
1666        use anno_core::CoreferenceResolver;
1667
1668        let model = JointModel::default();
1669
1670        // Test CoreferenceResolver trait
1671        assert_eq!(model.name(), "joint-model-coref");
1672
1673        // Test with empty input
1674        let empty_result = model.resolve(&[]);
1675        assert!(empty_result.is_empty());
1676    }
1677
1678    #[test]
1679    fn test_coref_resolver_assigns_canonical_ids() {
1680        use anno_core::CoreferenceResolver;
1681
1682        let model = JointModel::default();
1683
1684        let entities = vec![
1685            Entity::new("John", EntityType::Person, 0, 4, 0.9),
1686            Entity::new("he", EntityType::Person, 10, 12, 0.8),
1687            Entity::new("Microsoft", EntityType::Organization, 20, 29, 0.95),
1688        ];
1689
1690        let resolved = model.resolve(&entities);
1691
1692        // All entities should have canonical IDs assigned
1693        assert_eq!(resolved.len(), 3);
1694        for entity in &resolved {
1695            assert!(entity.canonical_id.is_some());
1696        }
1697    }
1698
1699    #[test]
1700    fn test_builder_default() {
1701        let model = JointModelBuilder::new().build().unwrap();
1702
1703        // Default configuration matches JointConfig::default()
1704        let config = model.config();
1705        assert_eq!(config.max_iterations, 5); // Default is 5
1706        assert!(config.enable_link_ner);
1707        assert!(config.enable_coref_ner);
1708        assert!(config.enable_coref_link);
1709    }
1710
1711    #[test]
1712    fn test_builder_fluent_api() {
1713        let model = JointModelBuilder::new()
1714            .with_max_iterations(50)
1715            .with_convergence_threshold(1e-6)
1716            .with_pruning_threshold(0.5)
1717            .with_max_antecedent_candidates(100)
1718            .with_max_link_candidates(20)
1719            .enable_link_ner(false)
1720            .enable_coref_ner(true)
1721            .enable_coref_link(false)
1722            .build()
1723            .unwrap();
1724
1725        let config = model.config();
1726        assert_eq!(config.max_iterations, 50);
1727        assert!((config.convergence_threshold - 1e-6).abs() < 1e-10);
1728        assert!((config.pruning_threshold - 0.5).abs() < 1e-10);
1729        assert_eq!(config.max_antecedent_candidates, 100);
1730        assert_eq!(config.max_link_candidates, 20);
1731        assert!(!config.enable_link_ner);
1732        assert!(config.enable_coref_ner);
1733        assert!(!config.enable_coref_link);
1734    }
1735
1736    #[test]
1737    fn test_builder_with_entity_types() {
1738        let custom_types = vec![EntityType::Person, EntityType::Organization];
1739
1740        let model = JointModelBuilder::new()
1741            .with_entity_types(custom_types.clone())
1742            .build()
1743            .unwrap();
1744
1745        assert_eq!(model.config().entity_types, custom_types);
1746    }
1747
1748    #[test]
1749    fn test_heuristic_mention_detection() {
1750        // Test the heuristic detection directly
1751        let text = "Barack Obama met Angela Merkel in Berlin";
1752        let entities = detect_mentions_heuristic(text);
1753
1754        // Should detect capitalized sequences
1755        // Note: May vary based on implementation
1756        assert!(!entities.is_empty());
1757
1758        // All detected entities should have valid spans
1759        for entity in &entities {
1760            assert!(entity.start < entity.end);
1761            assert!(entity.end <= text.chars().count());
1762        }
1763    }
1764
1765    #[test]
1766    fn test_heuristic_mention_detection_unicode() {
1767        // Test with Unicode characters
1768        let text = "François Müller visited München";
1769        let entities = detect_mentions_heuristic(text);
1770
1771        // Should handle Unicode correctly
1772        for entity in &entities {
1773            assert!(entity.start <= entity.end);
1774            let char_count = text.chars().count();
1775            assert!(entity.end <= char_count);
1776        }
1777    }
1778
1779    #[test]
1780    fn test_extract_entities_from_mentions() {
1781        let model = JointModel::default();
1782
1783        let text = "John Smith visited New York. He liked the city.";
1784        let mentions = vec![
1785            JointMention::from_entity(
1786                0,
1787                &Entity::new("John Smith", EntityType::Person, 0, 10, 0.9),
1788                text,
1789            ),
1790            JointMention::from_entity(
1791                1,
1792                &Entity::new("New York", EntityType::Location, 19, 27, 0.85),
1793                text,
1794            ),
1795            JointMention::from_entity(2, &Entity::new("He", EntityType::Person, 29, 31, 0.7), text),
1796        ];
1797
1798        let result = model.extract_entities_from_mentions(text, &mentions);
1799        assert!(result.is_ok());
1800
1801        let entities = result.unwrap();
1802        // Should return entities with updated types/links based on joint inference
1803        assert!(!entities.is_empty());
1804    }
1805}