Skip to main content

anno/joint/
factors.rs

1//! Factor definitions for the joint entity analysis model.
2//!
3//! In factor graphs, **factors** (also called "potential functions") encode
4//! dependencies between variables. Each factor is a function ψ(X_S) over a
5//! subset S of variables that returns a non-negative score indicating how
6//! "good" that configuration is.
7//!
8//! The joint distribution is then:
9//!
10//! ```text
11//! P(X) ∝ ∏_f ψ_f(X_{S_f})
12//! ```
13//!
14//! # Factor Types in Joint Entity Analysis
15//!
16//! ## Unary Factors (single variable)
17//!
18//! These capture task-specific features for individual mentions:
19//!
20//! | Factor | Variable | What It Encodes |
21//! |--------|----------|-----------------|
22//! | `UnaryCorefFactor` | a_i | Mention-ranking features (string match, distance, etc.) |
23//! | `UnaryNerFactor` | t_i | NER classifier features (word shape, context, etc.) |
24//! | `UnaryLinkFactor` | e_i | Entity linking features (prior probability, name match, etc.) |
25//!
26//! ## Pairwise Cross-Task Factors
27//!
28//! These encode dependencies **between** different tasks:
29//!
30//! | Factor | Variables | What It Encodes |
31//! |--------|-----------|-----------------|
32//! | `LinkNerFactor` | (e_i, t_i) | Wikipedia type should match NER type |
//! | `CorefNerFactor` | (a_i, t_i, t_j) | If a_i = j, mentions i and j should have the same type |
//! | `CorefLinkFactor` | (a_i, e_i, e_j) | If a_i = j, mentions i and j should link to the same or related entities |
35//!
36//! # Example: LinkNer Factor
37//!
38//! If mention m links to Wikipedia page "Barack Obama" (a Person), and the NER
39//! system types m as PERSON, the `LinkNerFactor` assigns high score. If NER
40//! types m as ORG, the factor assigns low score (penalty).
41//!
42//! ```text
43//! LinkNerFactor(e="Barack Obama", t=PERSON) → +3.0 (type match)
44//! LinkNerFactor(e="Barack Obama", t=ORG)    → -2.0 (type mismatch)
45//! ```
46//!
47//! # Weight Learning
48//!
49//! Factor weights are learned from annotated data using structured perceptron
50//! or max-margin methods. See [`super::learning`] for training.
51//!
52//! # References
53//!
//! - Durrett & Klein (2014): "A Joint Model for Entity Analysis: Coreference, Typing, and Linking" (TACL)
55//! - Kschischang et al. (2001): Factor Graphs and the Sum-Product Algorithm
56
57use super::types::{AntecedentValue, Assignment, LinkValue, VariableId, VariableType};
58use crate::linking::wikidata::{WikidataDictionary, WikidataNERType};
59use crate::EntityType;
60use serde::{Deserialize, Serialize};
61use std::collections::{HashMap, HashSet};
62use std::sync::Arc;
63
64// =============================================================================
65// Factor Trait
66// =============================================================================
67
68/// A factor (potential function) in the structured CRF.
69///
70/// Factors define the functions ψ(·) that encode soft constraints and
71/// dependencies between variables. The joint distribution over all
72/// variables is proportional to the product of all factor potentials:
73///
74/// ```text
75/// P(a,t,e|x) ∝ ∏_f ψ_f(vars_f, x)
76/// ```
77///
78/// where each factor ψ_f depends on a subset of variables (its **scope**).
79///
80/// # Implementation Notes
81///
82/// - Factors work in **log space** for numerical stability
83/// - `log_potential` should return `-f64::INFINITY` for impossible configurations
84/// - Factors must be `Send + Sync` for parallel inference
85///
86/// # Example
87///
88/// A simple type-consistency factor that prefers matching types:
89///
90/// ```ignore
91/// impl Factor for TypeConsistencyFactor {
92///     fn log_potential(&self, assignment: &Assignment) -> f64 {
93///         let type_i = assignment.get(&self.var_type_i);
94///         let type_j = assignment.get(&self.var_type_j);
95///         if type_i == type_j {
96///             self.match_weight // positive bonus
97///         } else {
98///             self.mismatch_penalty // negative penalty
99///         }
100///     }
101/// }
102/// ```
103pub trait Factor: Send + Sync {
104    /// The variables this factor touches (its "scope").
105    ///
106    /// A unary factor has scope of size 1.
107    /// A pairwise factor has scope of size 2.
108    /// Higher-order factors have larger scopes (but increase inference cost).
109    fn scope(&self) -> &[VariableId];
110
111    /// Log potential (unnormalized log probability) for an assignment.
112    ///
113    /// Returns the score θᵀφ(assignment) where θ are learned weights and
114    /// φ are feature functions. Higher scores indicate more likely configurations.
115    ///
116    /// Should return `-f64::INFINITY` for logically impossible configurations.
117    fn log_potential(&self, assignment: &Assignment) -> f64;
118
119    /// Human-readable factor name for debugging and visualization.
120    fn name(&self) -> &'static str;
121}
122
123// =============================================================================
124// Unary Factors
125// =============================================================================
126
127/// Unary factor for coreference (antecedent selection).
128///
129/// Features from mention-ranking model: string match, distance, type compatibility.
130#[derive(Debug, Clone)]
131pub struct UnaryCorefFactor {
132    /// Mention index
133    pub mention_idx: usize,
134    /// Variable ID
135    scope: Vec<VariableId>,
136    /// Precomputed scores for each antecedent candidate
137    pub scores: Vec<(AntecedentValue, f64)>,
138}
139
140impl UnaryCorefFactor {
141    /// Create a new unary coreference factor.
142    pub fn new(mention_idx: usize, scores: Vec<(AntecedentValue, f64)>) -> Self {
143        let scope = vec![VariableId {
144            mention_idx,
145            var_type: VariableType::Antecedent,
146        }];
147        Self {
148            mention_idx,
149            scope,
150            scores,
151        }
152    }
153}
154
155impl Factor for UnaryCorefFactor {
156    fn scope(&self) -> &[VariableId] {
157        &self.scope
158    }
159
160    fn log_potential(&self, assignment: &Assignment) -> f64 {
161        let antecedent = assignment.get_antecedent(self.mention_idx);
162        antecedent
163            .and_then(|a| self.scores.iter().find(|(v, _)| *v == a).map(|(_, s)| *s))
164            .unwrap_or(f64::NEG_INFINITY)
165    }
166
167    fn name(&self) -> &'static str {
168        "unary_coref"
169    }
170}
171
172/// Unary factor for NER (semantic typing).
173///
174/// Features: token features, gazetteer matches, context.
175#[derive(Debug, Clone)]
176pub struct UnaryNerFactor {
177    /// Mention index
178    pub mention_idx: usize,
179    /// Variable ID
180    scope: Vec<VariableId>,
181    /// Scores for each entity type
182    pub scores: Vec<(EntityType, f64)>,
183}
184
185impl UnaryNerFactor {
186    /// Create a new unary NER factor.
187    pub fn new(mention_idx: usize, scores: Vec<(EntityType, f64)>) -> Self {
188        let scope = vec![VariableId {
189            mention_idx,
190            var_type: VariableType::SemanticType,
191        }];
192        Self {
193            mention_idx,
194            scope,
195            scores,
196        }
197    }
198}
199
200impl Factor for UnaryNerFactor {
201    fn scope(&self) -> &[VariableId] {
202        &self.scope
203    }
204
205    fn log_potential(&self, assignment: &Assignment) -> f64 {
206        let entity_type = assignment.get_type(self.mention_idx);
207        entity_type
208            .and_then(|t| self.scores.iter().find(|(v, _)| *v == t).map(|(_, s)| *s))
209            .unwrap_or(f64::NEG_INFINITY)
210    }
211
212    fn name(&self) -> &'static str {
213        "unary_ner"
214    }
215}
216
217/// Unary factor for entity linking.
218///
219/// Features: string match, prior probability, context similarity.
220#[derive(Debug, Clone)]
221pub struct UnaryLinkFactor {
222    /// Mention index
223    pub mention_idx: usize,
224    /// Variable ID
225    scope: Vec<VariableId>,
226    /// Scores for each link candidate
227    pub scores: Vec<(LinkValue, f64)>,
228}
229
230impl UnaryLinkFactor {
231    /// Create a new unary link factor.
232    pub fn new(mention_idx: usize, scores: Vec<(LinkValue, f64)>) -> Self {
233        let scope = vec![VariableId {
234            mention_idx,
235            var_type: VariableType::EntityLink,
236        }];
237        Self {
238            mention_idx,
239            scope,
240            scores,
241        }
242    }
243}
244
245impl Factor for UnaryLinkFactor {
246    fn scope(&self) -> &[VariableId] {
247        &self.scope
248    }
249
250    fn log_potential(&self, assignment: &Assignment) -> f64 {
251        let link = assignment.get_link(self.mention_idx);
252        link.and_then(|l| self.scores.iter().find(|(v, _)| v == l).map(|(_, s)| *s))
253            .unwrap_or(f64::NEG_INFINITY)
254    }
255
256    fn name(&self) -> &'static str {
257        "unary_link"
258    }
259}
260
261// =============================================================================
262// Wikipedia Knowledge Store (for cross-task factors)
263// =============================================================================
264
265/// Store of Wikipedia/Wikidata knowledge for joint inference.
266///
267/// Provides semantic information about entities:
268/// - Type mappings (Q5 → Person, Q43229 → Organization)
269/// - Outgoing links (for relatedness computation)
270/// - Categories
271#[derive(Debug, Clone, Default)]
272pub struct WikipediaKnowledgeStore {
273    /// Entity types by KB ID
274    pub entity_types: HashMap<String, WikidataNERType>,
275    /// Outgoing links by KB ID
276    pub outlinks: HashMap<String, HashSet<String>>,
277    /// Categories by KB ID
278    pub categories: HashMap<String, Vec<String>>,
279    /// Wikidata dictionary for type lookups
280    pub dictionary: Option<Arc<WikidataDictionary>>,
281}
282
283impl WikipediaKnowledgeStore {
284    /// Create a new empty store.
285    pub fn new() -> Self {
286        Self::default()
287    }
288
289    /// Create from a Wikidata dictionary.
290    pub fn from_dictionary(dict: WikidataDictionary) -> Self {
291        Self {
292            dictionary: Some(Arc::new(dict)),
293            ..Default::default()
294        }
295    }
296
297    /// Get the entity type for a KB ID.
298    pub fn get_type(&self, kb_id: &str) -> Option<WikidataNERType> {
299        // First check explicit mappings
300        if let Some(t) = self.entity_types.get(kb_id) {
301            return Some(*t);
302        }
303
304        // Then check dictionary
305        if let Some(ref dict) = self.dictionary {
306            if let Some(entity) = dict.get(kb_id) {
307                return entity.entity_type;
308            }
309        }
310
311        None
312    }
313
314    /// Add a type mapping.
315    pub fn add_type(&mut self, kb_id: &str, ner_type: WikidataNERType) {
316        self.entity_types.insert(kb_id.to_string(), ner_type);
317    }
318
319    /// Add outgoing links for an entity.
320    pub fn add_outlinks(&mut self, kb_id: &str, links: impl IntoIterator<Item = String>) {
321        self.outlinks
322            .entry(kb_id.to_string())
323            .or_default()
324            .extend(links);
325    }
326
327    /// Check if two entities share outlinks.
328    pub fn shared_outlinks(&self, kb_id_a: &str, kb_id_b: &str) -> usize {
329        let links_a = self.outlinks.get(kb_id_a);
330        let links_b = self.outlinks.get(kb_id_b);
331
332        match (links_a, links_b) {
333            (Some(a), Some(b)) => a.intersection(b).count(),
334            _ => 0,
335        }
336    }
337
338    /// Check if one entity links to another.
339    pub fn has_link(&self, from: &str, to: &str) -> bool {
340        self.outlinks
341            .get(from)
342            .is_some_and(|links| links.contains(to))
343    }
344
345    /// Check if entities mutually link to each other.
346    pub fn mutual_link(&self, kb_id_a: &str, kb_id_b: &str) -> bool {
347        self.has_link(kb_id_a, kb_id_b) || self.has_link(kb_id_b, kb_id_a)
348    }
349
350    /// Compute relatedness score between two entities.
351    ///
352    /// Based on Wikipedia link structure (Milne & Witten 2008 style).
353    pub fn relatedness(&self, kb_id_a: &str, kb_id_b: &str) -> f64 {
354        if kb_id_a == kb_id_b {
355            return 1.0;
356        }
357
358        let shared = self.shared_outlinks(kb_id_a, kb_id_b) as f64;
359        let mutual = if self.mutual_link(kb_id_a, kb_id_b) {
360            1.0
361        } else {
362            0.0
363        };
364
365        // Simple relatedness: normalized shared links + bonus for mutual
366        let links_a = self.outlinks.get(kb_id_a).map_or(0, |l| l.len()) as f64;
367        let links_b = self.outlinks.get(kb_id_b).map_or(0, |l| l.len()) as f64;
368
369        if links_a + links_b == 0.0 {
370            return mutual * 0.5;
371        }
372
373        let jaccard = shared / (links_a + links_b - shared).max(1.0);
374        (jaccard + mutual * 0.3).min(1.0)
375    }
376}
377
378// =============================================================================
379// Cross-Task Factors
380// =============================================================================
381
382/// Factor coupling entity linking and NER.
383///
384/// Uses Wikipedia article semantics to inform NER type:
385/// - Infobox type (e.g., "company" → ORGANIZATION)
386/// - Categories (e.g., "American politicians" → PERSON)
387/// - First sentence copula (e.g., "is a British city" → LOCATION)
388///
389/// # Example
390///
391/// If mention links to `Dell` (Wikipedia):
392/// - Infobox type: company
393/// - Categories: Computer companies, Technology companies
394/// - → Strong signal for ORGANIZATION type
395#[derive(Debug, Clone)]
396pub struct LinkNerFactor {
397    /// Mention index
398    pub mention_idx: usize,
399    /// Variable IDs (type and link for same mention)
400    scope: Vec<VariableId>,
401    /// Weights for type-link compatibility
402    pub weights: LinkNerWeights,
403    /// Knowledge store for lookups
404    knowledge: Option<Arc<WikipediaKnowledgeStore>>,
405}
406
407/// Weights for Link+NER factor.
408#[derive(Debug, Clone, Serialize, Deserialize)]
409pub struct LinkNerWeights {
410    /// Weight for infobox/Wikidata type match
411    pub type_match: f64,
412    /// Weight for type mismatch (penalty, should be negative)
413    pub type_mismatch: f64,
414    /// Weight for category match
415    pub category_match: f64,
416    /// Weight for NIL link (no entity in KB)
417    pub nil_bonus: f64,
418}
419
420impl Default for LinkNerWeights {
421    fn default() -> Self {
422        Self {
423            type_match: 1.5,
424            type_mismatch: -1.0,
425            category_match: 0.5,
426            nil_bonus: 0.0,
427        }
428    }
429}
430
431impl LinkNerFactor {
432    /// Create a new Link+NER factor.
433    pub fn new(mention_idx: usize, weights: LinkNerWeights) -> Self {
434        let scope = vec![
435            VariableId {
436                mention_idx,
437                var_type: VariableType::SemanticType,
438            },
439            VariableId {
440                mention_idx,
441                var_type: VariableType::EntityLink,
442            },
443        ];
444        Self {
445            mention_idx,
446            scope,
447            weights,
448            knowledge: None,
449        }
450    }
451
452    /// Set knowledge store.
453    pub fn with_knowledge(mut self, knowledge: Arc<WikipediaKnowledgeStore>) -> Self {
454        self.knowledge = Some(knowledge);
455        self
456    }
457
458    /// Check if NER type matches Wikidata type.
459    fn types_compatible(ner_type: &EntityType, wiki_type: WikidataNERType) -> bool {
460        match wiki_type {
461            WikidataNERType::Person => matches!(ner_type, EntityType::Person),
462            WikidataNERType::Organization => matches!(ner_type, EntityType::Organization),
463            WikidataNERType::Location | WikidataNERType::GeopoliticalEntity => {
464                matches!(ner_type, EntityType::Location)
465                    || matches!(ner_type, EntityType::Other(ref s) if s == "GPE")
466            }
467            WikidataNERType::Event => {
468                matches!(ner_type, EntityType::Other(ref s) if s == "EVENT")
469            }
470            WikidataNERType::WorkOfArt => {
471                matches!(ner_type, EntityType::Other(ref s) if s == "WORK_OF_ART")
472            }
473            WikidataNERType::Product => {
474                matches!(ner_type, EntityType::Other(ref s) if s == "PRODUCT")
475            }
476            WikidataNERType::DateTime => {
477                matches!(ner_type, EntityType::Other(ref s) if s == "DATE")
478            }
479            WikidataNERType::Miscellaneous => true, // MISC compatible with anything
480        }
481    }
482}
483
484impl Factor for LinkNerFactor {
485    fn scope(&self) -> &[VariableId] {
486        &self.scope
487    }
488
489    fn log_potential(&self, assignment: &Assignment) -> f64 {
490        let entity_type = match assignment.get_type(self.mention_idx) {
491            Some(t) => t,
492            None => return 0.0,
493        };
494
495        let link = match assignment.get_link(self.mention_idx) {
496            Some(l) => l,
497            None => return 0.0,
498        };
499
500        // NIL links get a small bonus (or penalty depending on config)
501        let kb_id = match link {
502            LinkValue::KbId(id) => id,
503            LinkValue::Nil => return self.weights.nil_bonus,
504        };
505
506        // Look up entity type from knowledge store
507        let wiki_type = self.knowledge.as_ref().and_then(|k| k.get_type(kb_id));
508
509        match wiki_type {
510            Some(wt) => {
511                if Self::types_compatible(&entity_type, wt) {
512                    self.weights.type_match
513                } else {
514                    self.weights.type_mismatch
515                }
516            }
517            None => 0.0, // No type info available
518        }
519    }
520
521    fn name(&self) -> &'static str {
522        "link_ner"
523    }
524}
525
526/// Factor coupling coreference and NER.
527///
528/// Encourages consistent semantic types across coreference chains.
529/// Only fires when mention i is linked to mention j (a_i = j).
530///
531/// # Features
532///
533/// - Type pair: (t_i, t_j) indicator features
534/// - Monolexical: (t_i, head_j) and (t_j, head_i) features
535///
536/// # Example
537///
538/// If "he" → "John Smith" is a coref link:
539/// - t("John Smith") = PERSON
540/// - t("he") = PERSON
541/// - Factor score high for matching types
542#[derive(Debug, Clone)]
543pub struct CorefNerFactor {
544    /// Current mention index (i)
545    pub mention_i: usize,
546    /// Antecedent mention index (j)
547    pub mention_j: usize,
548    /// Variable IDs
549    scope: Vec<VariableId>,
550    /// Weights
551    pub weights: CorefNerWeights,
552    /// Head word of mention i for monolexical features.
553    pub head_i: Option<String>,
554    /// Head word of mention j for monolexical features.
555    pub head_j: Option<String>,
556    /// Monolexical feature lookup: (type, head) → weight
557    pub monolexical_weights: HashMap<(String, String), f64>,
558}
559
560/// Weights for Coref+NER factor.
561#[derive(Debug, Clone, Serialize, Deserialize)]
562pub struct CorefNerWeights {
563    /// Weight for same-type bonus
564    pub type_match: f64,
565    /// Weight for type mismatch penalty
566    pub type_mismatch: f64,
567    /// Weight for monolexical features (type + head word)
568    pub monolexical: f64,
569    /// Special weight for pronoun → proper name with matching type
570    pub pronoun_proper_match: f64,
571}
572
573impl Default for CorefNerWeights {
574    fn default() -> Self {
575        Self {
576            type_match: 1.0,
577            type_mismatch: -0.5,
578            monolexical: 0.3,
579            pronoun_proper_match: 0.5,
580        }
581    }
582}
583
584impl CorefNerFactor {
585    /// Create a new Coref+NER factor.
586    pub fn new(mention_i: usize, mention_j: usize, weights: CorefNerWeights) -> Self {
587        let scope = vec![
588            VariableId {
589                mention_idx: mention_i,
590                var_type: VariableType::Antecedent,
591            },
592            VariableId {
593                mention_idx: mention_i,
594                var_type: VariableType::SemanticType,
595            },
596            VariableId {
597                mention_idx: mention_j,
598                var_type: VariableType::SemanticType,
599            },
600        ];
601        Self {
602            mention_i,
603            mention_j,
604            scope,
605            weights,
606            head_i: None,
607            head_j: None,
608            monolexical_weights: HashMap::new(),
609        }
610    }
611
612    /// Set head words for monolexical features.
613    pub fn with_heads(mut self, head_i: &str, head_j: &str) -> Self {
614        self.head_i = Some(head_i.to_lowercase());
615        self.head_j = Some(head_j.to_lowercase());
616        self
617    }
618
619    /// Add monolexical weight.
620    pub fn add_monolexical_weight(&mut self, type_name: &str, head: &str, weight: f64) {
621        self.monolexical_weights
622            .insert((type_name.to_string(), head.to_lowercase()), weight);
623    }
624
625    /// Load default monolexical weights (common patterns).
626    pub fn with_default_monolexical(mut self) -> Self {
627        // Person type + common person head words
628        for head in ["mr", "mrs", "dr", "prof", "president", "ceo", "chairman"] {
629            self.add_monolexical_weight("Person", head, 0.5);
630        }
631
632        // Organization type + common org head words
633        for head in [
634            "company",
635            "corporation",
636            "inc",
637            "corp",
638            "llc",
639            "firm",
640            "bank",
641        ] {
642            self.add_monolexical_weight("Organization", head, 0.5);
643        }
644
645        // Location type + common location head words
646        for head in [
647            "city", "country", "state", "province", "region", "river", "mountain",
648        ] {
649            self.add_monolexical_weight("Location", head, 0.5);
650        }
651
652        self
653    }
654}
655
656impl Factor for CorefNerFactor {
657    fn scope(&self) -> &[VariableId] {
658        &self.scope
659    }
660
661    fn log_potential(&self, assignment: &Assignment) -> f64 {
662        // Only fires if i→j is a coreference link
663        let antecedent = assignment.get_antecedent(self.mention_i);
664        if antecedent != Some(AntecedentValue::Mention(self.mention_j)) {
665            return 0.0; // Factor doesn't contribute
666        }
667
668        let type_i = match assignment.get_type(self.mention_i) {
669            Some(t) => t,
670            None => return 0.0,
671        };
672
673        let type_j = match assignment.get_type(self.mention_j) {
674            Some(t) => t,
675            None => return 0.0,
676        };
677
678        let mut score = 0.0;
679
680        // Type pair feature
681        if type_i == type_j {
682            score += self.weights.type_match;
683        } else {
684            score += self.weights.type_mismatch;
685        }
686
687        // Monolexical features: (type_i, head_j)
688        if let Some(ref head_j) = self.head_j {
689            let type_name = format!("{:?}", type_i);
690            if let Some(&w) = self
691                .monolexical_weights
692                .get(&(type_name.clone(), head_j.clone()))
693            {
694                score += self.weights.monolexical * w;
695            }
696            // Also check simplified type name
697            let simple_type = match type_i {
698                EntityType::Person => "Person",
699                EntityType::Organization => "Organization",
700                EntityType::Location => "Location",
701                EntityType::Other(ref s) => s.as_str(),
702                _ => "Unknown",
703            };
704            if let Some(&w) = self
705                .monolexical_weights
706                .get(&(simple_type.to_string(), head_j.clone()))
707            {
708                score += self.weights.monolexical * w;
709            }
710        }
711
712        // Monolexical features: (type_j, head_i)
713        if let Some(ref head_i) = self.head_i {
714            let type_name = format!("{:?}", type_j);
715            if let Some(&w) = self
716                .monolexical_weights
717                .get(&(type_name.clone(), head_i.clone()))
718            {
719                score += self.weights.monolexical * w;
720            }
721            let simple_type = match type_j {
722                EntityType::Person => "Person",
723                EntityType::Organization => "Organization",
724                EntityType::Location => "Location",
725                EntityType::Other(ref s) => s.as_str(),
726                _ => "Unknown",
727            };
728            if let Some(&w) = self
729                .monolexical_weights
730                .get(&(simple_type.to_string(), head_i.clone()))
731            {
732                score += self.weights.monolexical * w;
733            }
734        }
735
736        score
737    }
738
739    fn name(&self) -> &'static str {
740        "coref_ner"
741    }
742}
743
744/// Factor coupling coreference and entity linking.
745///
746/// Encourages coreferent mentions to link to related Wikipedia articles.
747/// Only fires when mention i is linked to mention j (a_i = j).
748///
749/// # Features
750///
751/// - Same title: e_i = e_j (same article)
752/// - Shared outlinks: articles share outgoing links
753/// - Mutual links: one article links to the other
754///
755/// # Example
756///
757/// If "the company" → "Dell" is a coref link:
758/// - e("Dell") = Dell (company article)
759/// - e("the company") = Dell (should link to same)
760/// - Factor: high score for same entity
761#[derive(Debug, Clone)]
762pub struct CorefLinkFactor {
763    /// Current mention index (i)
764    pub mention_i: usize,
765    /// Antecedent mention index (j)
766    pub mention_j: usize,
767    /// Variable IDs
768    scope: Vec<VariableId>,
769    /// Weights
770    pub weights: CorefLinkWeights,
771    /// Knowledge store for Wikipedia graph
772    knowledge: Option<Arc<WikipediaKnowledgeStore>>,
773}
774
775/// Weights for Coref+Link factor.
776#[derive(Debug, Clone, Serialize, Deserialize)]
777pub struct CorefLinkWeights {
778    /// Weight for same entity (alias for same_title for learning.rs compatibility)
779    pub same_entity: f64,
780    /// Weight for different entity (negative, penalty for coreferent mentions with different links)
781    pub different_entity: f64,
782    /// Weight for same title (kept for backward compatibility)
783    pub same_title: f64,
784    /// Weight for shared outlinks (scaled by count)
785    pub shared_outlinks: f64,
786    /// Weight for mutual links
787    pub mutual_link: f64,
788    /// Weight for both being NIL
789    pub both_nil: f64,
790    /// Penalty for one NIL one not
791    pub nil_mismatch: f64,
792}
793
794impl Default for CorefLinkWeights {
795    fn default() -> Self {
796        Self {
797            same_entity: 2.0,       // Same as same_title
798            different_entity: -0.3, // Penalty for different links
799            same_title: 2.0,        // Backward compat
800            shared_outlinks: 0.1,
801            mutual_link: 1.0,
802            both_nil: 0.5,
803            nil_mismatch: -0.3,
804        }
805    }
806}
807
808impl CorefLinkFactor {
809    /// Create a new Coref+Link factor.
810    pub fn new(mention_i: usize, mention_j: usize, weights: CorefLinkWeights) -> Self {
811        let scope = vec![
812            VariableId {
813                mention_idx: mention_i,
814                var_type: VariableType::Antecedent,
815            },
816            VariableId {
817                mention_idx: mention_i,
818                var_type: VariableType::EntityLink,
819            },
820            VariableId {
821                mention_idx: mention_j,
822                var_type: VariableType::EntityLink,
823            },
824        ];
825        Self {
826            mention_i,
827            mention_j,
828            scope,
829            weights,
830            knowledge: None,
831        }
832    }
833
834    /// Set knowledge store.
835    pub fn with_knowledge(mut self, knowledge: Arc<WikipediaKnowledgeStore>) -> Self {
836        self.knowledge = Some(knowledge);
837        self
838    }
839}
840
841impl Factor for CorefLinkFactor {
842    fn scope(&self) -> &[VariableId] {
843        &self.scope
844    }
845
846    fn log_potential(&self, assignment: &Assignment) -> f64 {
847        // Only fires if i→j is a coreference link
848        let antecedent = assignment.get_antecedent(self.mention_i);
849        if antecedent != Some(AntecedentValue::Mention(self.mention_j)) {
850            return 0.0;
851        }
852
853        let link_i = match assignment.get_link(self.mention_i) {
854            Some(l) => l,
855            None => return 0.0,
856        };
857
858        let link_j = match assignment.get_link(self.mention_j) {
859            Some(l) => l,
860            None => return 0.0,
861        };
862
863        let mut score = 0.0;
864
865        // Handle NIL cases
866        match (link_i, link_j) {
867            (LinkValue::Nil, LinkValue::Nil) => {
868                return self.weights.both_nil;
869            }
870            (LinkValue::Nil, _) | (_, LinkValue::Nil) => {
871                return self.weights.nil_mismatch;
872            }
873            (LinkValue::KbId(id_i), LinkValue::KbId(id_j)) => {
874                // Same entity feature
875                if id_i == id_j {
876                    score += self.weights.same_entity;
877                    return score; // Same entity, no need for relatedness
878                }
879
880                // Wikipedia graph features
881                if let Some(ref knowledge) = self.knowledge {
882                    // Shared outlinks
883                    let shared = knowledge.shared_outlinks(id_i, id_j);
884                    score += self.weights.shared_outlinks * (shared as f64).ln().max(0.0);
885
886                    // Mutual links
887                    if knowledge.mutual_link(id_i, id_j) {
888                        score += self.weights.mutual_link;
889                    }
890                }
891            }
892        }
893
894        score
895    }
896
897    fn name(&self) -> &'static str {
898        "coref_link"
899    }
900}
901
902// =============================================================================
903// Tests
904// =============================================================================
905
#[cfg(test)]
mod tests {
    //! Unit tests for the unary and cross-task factors and the Wikipedia
    //! knowledge-store helpers they rely on.

    use super::*;

    /// Absolute tolerance used when comparing factor log-potentials.
    const EPS: f64 = 1e-6;

    /// Builds an assignment in which mention 1 selects mention 0 as its
    /// antecedent (the two mentions are coreferent). Tests add types/links on top.
    fn coreferent_pair() -> Assignment {
        let mut assignment = Assignment::default();
        assignment.set_antecedent(1, AntecedentValue::Mention(0));
        assignment
    }

    #[test]
    fn test_unary_coref_factor() {
        let scores = vec![
            (AntecedentValue::NewCluster, -1.0),
            (AntecedentValue::Mention(0), 0.5),
        ];
        let factor = UnaryCorefFactor::new(1, scores);

        let assignment = coreferent_pair();

        // The score attached to the chosen antecedent (Mention(0)) is returned.
        let score = factor.log_potential(&assignment);
        assert!((score - 0.5).abs() < EPS, "expected 0.5, got {score}");
    }

    #[test]
    fn test_coref_ner_factor_same_type() {
        let factor = CorefNerFactor::new(1, 0, CorefNerWeights::default());

        let mut assignment = coreferent_pair();
        assignment.set_type(0, EntityType::Person);
        assignment.set_type(1, EntityType::Person);

        let score = factor.log_potential(&assignment);
        assert!(score > 0.0, "Same types should have positive score");
    }

    #[test]
    fn test_coref_ner_factor_different_type() {
        let factor = CorefNerFactor::new(1, 0, CorefNerWeights::default());

        let mut assignment = coreferent_pair();
        assignment.set_type(0, EntityType::Person);
        assignment.set_type(1, EntityType::Organization);

        let score = factor.log_potential(&assignment);
        assert!(score < 0.0, "Different types should have penalty");
    }

    #[test]
    fn test_coref_ner_factor_no_link() {
        let factor = CorefNerFactor::new(1, 0, CorefNerWeights::default());

        // Mention 1 starts a new cluster, so it is NOT linked to mention 0 and
        // the type disagreement below must not be penalized.
        let mut assignment = Assignment::default();
        assignment.set_antecedent(1, AntecedentValue::NewCluster);
        assignment.set_type(0, EntityType::Person);
        assignment.set_type(1, EntityType::Organization);

        let score = factor.log_potential(&assignment);
        assert!(score.abs() < EPS, "Factor should not fire when not linked");
    }

    #[test]
    fn test_coref_ner_factor_with_heads() {
        let factor = CorefNerFactor::new(1, 0, CorefNerWeights::default())
            .with_heads("he", "president")
            .with_default_monolexical();

        let mut assignment = coreferent_pair();
        assignment.set_type(0, EntityType::Person);
        assignment.set_type(1, EntityType::Person);

        let score = factor.log_potential(&assignment);
        // Should include monolexical bonus for "president" + Person
        assert!(score > 1.0, "Should have monolexical bonus");
    }

    #[test]
    fn test_coref_link_factor_same_title() {
        let factor = CorefLinkFactor::new(1, 0, CorefLinkWeights::default());

        let mut assignment = coreferent_pair();
        assignment.set_link(0, LinkValue::KbId("Q42".to_string()));
        assignment.set_link(1, LinkValue::KbId("Q42".to_string()));

        let score = factor.log_potential(&assignment);
        assert!(score > 0.0, "Same title should have high score");
    }

    #[test]
    fn test_coref_link_factor_with_knowledge() {
        let mut knowledge = WikipediaKnowledgeStore::new();
        knowledge.add_outlinks("Q1", vec!["Q2".to_string(), "Q3".to_string()]);
        knowledge.add_outlinks("Q2", vec!["Q1".to_string(), "Q3".to_string()]);

        let factor = CorefLinkFactor::new(1, 0, CorefLinkWeights::default())
            .with_knowledge(Arc::new(knowledge));

        let mut assignment = coreferent_pair();
        assignment.set_link(0, LinkValue::KbId("Q1".to_string()));
        assignment.set_link(1, LinkValue::KbId("Q2".to_string()));

        let score = factor.log_potential(&assignment);
        // Q1 and Q2 link to each other (mutual link) and share one outlink (Q3).
        // The shared-outlink feature is scaled by ln(count), and ln(1) == 0, so it
        // contributes nothing here; the positive score comes from the mutual link.
        assert!(score > 0.0, "Related entities should have positive score");
    }

    #[test]
    fn test_link_ner_factor_type_match() {
        let mut knowledge = WikipediaKnowledgeStore::new();
        knowledge.add_type("Q937", WikidataNERType::Person);

        let factor =
            LinkNerFactor::new(0, LinkNerWeights::default()).with_knowledge(Arc::new(knowledge));

        let mut assignment = Assignment::default();
        assignment.set_type(0, EntityType::Person);
        assignment.set_link(0, LinkValue::KbId("Q937".to_string()));

        let score = factor.log_potential(&assignment);
        assert!(score > 0.0, "Type match should have positive score");
    }

    #[test]
    fn test_link_ner_factor_type_mismatch() {
        let mut knowledge = WikipediaKnowledgeStore::new();
        knowledge.add_type("Q937", WikidataNERType::Person);

        let factor =
            LinkNerFactor::new(0, LinkNerWeights::default()).with_knowledge(Arc::new(knowledge));

        let mut assignment = Assignment::default();
        assignment.set_type(0, EntityType::Organization); // Mismatch!
        assignment.set_link(0, LinkValue::KbId("Q937".to_string()));

        let score = factor.log_potential(&assignment);
        assert!(score < 0.0, "Type mismatch should have negative score");
    }

    #[test]
    fn test_link_ner_factor_nil() {
        let factor = LinkNerFactor::new(0, LinkNerWeights::default());

        let mut assignment = Assignment::default();
        assignment.set_type(0, EntityType::Person);
        assignment.set_link(0, LinkValue::Nil);

        let score = factor.log_potential(&assignment);
        // NIL should return nil_bonus, expected to be 0 under the default weights.
        assert!(score.abs() < EPS, "expected 0 for a NIL link, got {score}");
    }

    #[test]
    fn test_wikipedia_knowledge_store_relatedness() {
        let mut knowledge = WikipediaKnowledgeStore::new();
        knowledge.add_outlinks("A", vec!["C".to_string(), "D".to_string()]);
        knowledge.add_outlinks("B", vec!["C".to_string(), "E".to_string()]);

        // A and B share outlink to C
        let shared = knowledge.shared_outlinks("A", "B");
        assert_eq!(shared, 1);

        // Self-relatedness is 1.0
        let self_rel = knowledge.relatedness("A", "A");
        assert!((self_rel - 1.0).abs() < 1e-10, "got {self_rel}");

        // A-B have some relatedness
        let rel = knowledge.relatedness("A", "B");
        assert!(rel > 0.0, "got {rel}");
    }

    #[test]
    fn test_types_compatible() {
        assert!(LinkNerFactor::types_compatible(
            &EntityType::Person,
            WikidataNERType::Person
        ));
        assert!(LinkNerFactor::types_compatible(
            &EntityType::Organization,
            WikidataNERType::Organization
        ));
        assert!(LinkNerFactor::types_compatible(
            &EntityType::Location,
            WikidataNERType::Location
        ));
        assert!(!LinkNerFactor::types_compatible(
            &EntityType::Person,
            WikidataNERType::Organization
        ));
    }
}