// anno/joint/learning.rs
1//! Learning for joint entity analysis using softmax-margin objective.
2//!
3//! Implements the training procedure from Durrett & Klein (2014):
4//! - Softmax-margin objective (structured hinge loss)
5//! - AdaGrad optimization with adaptive learning rates
6//! - Mini-batch training with shuffling
7//!
8//! # Mathematical Framework
9//!
10//! ## Softmax-Margin Objective
11//!
12//! For a training example with gold assignment y* and predicted assignment ŷ:
13//!
14//! ```text
15//! L(θ) = log Σ_y exp(s(x,y;θ) + Δ(y,y*)) - s(x,y*;θ)
16//! ```
17//!
18//! Where:
19//! - s(x,y;θ) = Σ_f θ_f · φ_f(x,y) (sum of factor potentials)
20//! - Δ(y,y*) = Hamming loss between predicted and gold
21//! - θ_f are learnable weights for each factor type
22//!
23//! The gradient is:
24//! ```text
25//! ∇_θ L = E_ỹ[φ(x,ỹ)] - φ(x,y*)
26//! ```
27//!
28//! Where ỹ ~ softmax(s(x,y) + Δ(y,y*)) is the cost-augmented distribution.
29//!
30//! ## AdaGrad Optimization
31//!
32//! Per-parameter adaptive learning rate:
33//! ```text
34//! g_t = ∇_θ L_t
35//! G_t = G_{t-1} + g_t²
36//! θ_{t+1} = θ_t - η / √(G_t + ε) · g_t
37//! ```
38//!
39//! # Example
40//!
41//! ```rust,ignore
42//! use anno::joint::learning::{Trainer, TrainingConfig, TrainingExample};
43//!
44//! let config = TrainingConfig::default();
45//! let mut trainer = Trainer::new(config);
46//!
47//! // Add training examples
48//! trainer.add_example(example);
49//!
50//! // Train
51//! let losses = trainer.train();
52//!
53//! // Get learned weights
54//! let weights = trainer.get_weights();
55//! ```
56
57use super::factors::{CorefLinkWeights, CorefNerWeights, LinkNerWeights};
58use super::types::JointMention;
59use crate::{Entity, EntityType};
60use anno_core::CorefChain;
61use serde::{Deserialize, Serialize};
62use std::collections::HashMap;
63
// Type aliases for complex types

/// Result of a cost-augmented decode, per mention index:
/// predicted NER type, antecedent (`None` = start a new cluster),
/// and KB link id (`None` = NIL).
type DecodeResult = (
    HashMap<usize, EntityType>,
    HashMap<usize, Option<usize>>,
    HashMap<usize, Option<String>>,
);
70
71// =============================================================================
72// Configuration
73// =============================================================================
74
/// Training configuration.
///
/// Sensible defaults are provided via [`Default`]; see that impl for the
/// concrete values.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Initial learning rate
    pub learning_rate: f64,
    /// AdaGrad epsilon (numerical stability)
    pub epsilon: f64,
    /// Number of training epochs
    pub epochs: usize,
    /// Mini-batch size
    pub batch_size: usize,
    /// L2 regularization coefficient
    pub l2_lambda: f64,
    /// Early stopping patience (epochs without improvement)
    pub patience: usize,
    /// Minimum delta for early stopping
    pub min_delta: f64,
    /// Hamming loss weight for cost-augmented inference
    pub cost_weight: f64,
    /// Gradient clipping threshold
    pub grad_clip: f64,
    /// Whether to use margin rescaling
    /// (NOTE(review): not read anywhere in the visible portion of this file —
    /// confirm it is consumed elsewhere)
    pub margin_rescaling: bool,
    /// Dynamic batching configuration (xCoRe-style)
    pub dynamic_batching: Option<DynamicBatchConfig>,
}
101
impl Default for TrainingConfig {
    /// Conservative defaults: AdaGrad at lr 0.1, 50 epochs with early
    /// stopping (patience 5, min_delta 1e-4), batch size 16, light L2
    /// (1e-4), gradient clipping at ±5, and no dynamic batching.
    fn default() -> Self {
        Self {
            learning_rate: 0.1,
            epsilon: 1e-8,
            epochs: 50,
            batch_size: 16,
            l2_lambda: 1e-4,
            patience: 5,
            min_delta: 1e-4,
            cost_weight: 1.0,
            grad_clip: 5.0,
            margin_rescaling: true,
            dynamic_batching: None,
        }
    }
}
119
/// Dynamic batching configuration for cross-context training.
///
/// From xCoRe (Section 3.3):
/// "At each step, we first sample the number of training contexts n in the
/// range (1, ⌊w/s⌋), then construct a training batch by sampling n continuous
/// contexts from d_i, with length equal to min(w, |d_i|)/n."
///
/// This allows models to learn with both:
/// - Many small contexts (for cross-context learning)
/// - Few large contexts (for within-context quality)
///
/// The [`Default`] values match [`DynamicBatchConfig::long_document`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicBatchConfig {
    /// Maximum context length (tokens)
    pub max_context_length: usize,
    /// Average sentence length (for computing max contexts)
    pub avg_sentence_length: usize,
    /// Minimum number of contexts per batch
    pub min_contexts: usize,
    /// Maximum number of contexts per batch
    pub max_contexts: usize,
    /// Whether to sample contexts from same document (long-doc) or different docs (cross-doc)
    pub same_document: bool,
    /// Overlap tokens between adjacent windows (for long-doc mode)
    pub window_overlap: usize,
}
145
146impl Default for DynamicBatchConfig {
147    fn default() -> Self {
148        Self {
149            max_context_length: 4000,
150            avg_sentence_length: 25,
151            min_contexts: 1,
152            max_contexts: 20,
153            same_document: true, // Long-document mode by default
154            window_overlap: 256,
155        }
156    }
157}
158
159impl DynamicBatchConfig {
160    /// Create config for cross-document training.
161    pub fn cross_document() -> Self {
162        Self {
163            max_context_length: 512, // Shorter contexts for cross-doc
164            avg_sentence_length: 25,
165            min_contexts: 2,
166            max_contexts: 10,
167            same_document: false,
168            window_overlap: 0,
169        }
170    }
171
172    /// Create config for long-document training.
173    pub fn long_document() -> Self {
174        Self {
175            max_context_length: 4000,
176            avg_sentence_length: 25,
177            min_contexts: 1,
178            max_contexts: 20,
179            same_document: true,
180            window_overlap: 256,
181        }
182    }
183
184    /// Compute the number of contexts to sample for this training step.
185    ///
186    /// Uses uniform sampling in range (min_contexts, max_contexts).
187    pub fn sample_num_contexts(&self, rng_seed: u64) -> usize {
188        // Simple LCG for reproducibility
189        let x = rng_seed
190            .wrapping_mul(6364136223846793005)
191            .wrapping_add(1442695040888963407);
192        let range = self.max_contexts - self.min_contexts + 1;
193        self.min_contexts + (x as usize % range)
194    }
195
196    /// Compute context length given number of contexts.
197    ///
198    /// From xCoRe: "length equal to min(w, |d_i|)/n"
199    pub fn context_length(&self, num_contexts: usize, doc_length: usize) -> usize {
200        let base_length = self.max_context_length.min(doc_length);
201        if num_contexts > 0 {
202            base_length / num_contexts
203        } else {
204            base_length
205        }
206    }
207}
208
209// =============================================================================
210// Training Data
211// =============================================================================
212
/// A training example for joint learning.
///
/// All three gold maps are keyed by the mention's index into `mentions`.
#[derive(Debug, Clone)]
pub struct TrainingExample {
    /// Document text
    pub text: String,
    /// Mentions extracted from text
    pub mentions: Vec<JointMention>,
    /// Gold NER labels (mention_idx -> EntityType)
    pub gold_ner: HashMap<usize, EntityType>,
    /// Gold coreference (mention_idx -> antecedent_idx, None for new cluster)
    pub gold_coref: HashMap<usize, Option<usize>>,
    /// Gold entity links (mention_idx -> KB_ID, None for NIL)
    pub gold_links: HashMap<usize, Option<String>>,
}
227
228impl TrainingExample {
229    /// Create from gold annotations.
230    pub fn from_gold(
231        text: &str,
232        entities: &[Entity],
233        chains: &[CorefChain],
234        links: &[(usize, Option<String>)],
235    ) -> Self {
236        let mentions: Vec<JointMention> = entities
237            .iter()
238            .enumerate()
239            .map(|(i, e)| JointMention::from_entity(i, e, text))
240            .collect();
241
242        let mut gold_ner = HashMap::new();
243        for (i, e) in entities.iter().enumerate() {
244            gold_ner.insert(i, e.entity_type.clone());
245        }
246
247        // Build coref map from chains
248        let mut gold_coref = HashMap::new();
249        for chain in chains {
250            let mut prev_idx: Option<usize> = None;
251            for mention in &chain.mentions {
252                // Find mention index by position
253                if let Some(idx) = mentions
254                    .iter()
255                    .position(|m| m.start == mention.start && m.end == mention.end)
256                {
257                    gold_coref.insert(idx, prev_idx);
258                    prev_idx = Some(idx);
259                }
260            }
261        }
262
263        let gold_links: HashMap<usize, Option<String>> = links.iter().cloned().collect();
264
265        Self {
266            text: text.to_string(),
267            mentions,
268            gold_ner,
269            gold_coref,
270            gold_links,
271        }
272    }
273
274    /// Get the prior score for a mention (from entity linking candidates).
275    fn prior_score(&self, idx: usize) -> f64 {
276        self.mentions[idx]
277            .entity
278            .as_ref()
279            .map(|e| e.confidence)
280            .unwrap_or(0.0)
281    }
282
283    /// Compute Hamming loss between predicted and gold assignment.
284    pub fn hamming_loss(
285        &self,
286        pred_ner: &HashMap<usize, EntityType>,
287        pred_coref: &HashMap<usize, Option<usize>>,
288        pred_links: &HashMap<usize, Option<String>>,
289    ) -> f64 {
290        let mut loss = 0.0;
291        let n = self.mentions.len() as f64;
292
293        // NER errors
294        for (idx, gold_type) in &self.gold_ner {
295            if let Some(pred_type) = pred_ner.get(idx) {
296                if pred_type != gold_type {
297                    loss += 1.0;
298                }
299            } else {
300                loss += 1.0;
301            }
302        }
303
304        // Coref errors
305        for (idx, gold_ante) in &self.gold_coref {
306            if let Some(pred_ante) = pred_coref.get(idx) {
307                if pred_ante != gold_ante {
308                    loss += 1.0;
309                }
310            } else {
311                loss += 1.0;
312            }
313        }
314
315        // Link errors
316        for (idx, gold_link) in &self.gold_links {
317            if let Some(pred_link) = pred_links.get(idx) {
318                if pred_link != gold_link {
319                    loss += 1.0;
320                }
321            } else {
322                loss += 1.0;
323            }
324        }
325
326        if n > 0.0 {
327            loss / n
328        } else {
329            0.0
330        }
331    }
332}
333
334// =============================================================================
335// Learnable Weights
336// =============================================================================
337
/// All learnable weights for the joint model.
///
/// Groups the three unary factors (NER, coref, link) with the three
/// pairwise cross-task factors; all start at zero via [`Default`].
#[derive(Debug, Clone, Default)]
pub struct JointWeights {
    /// Unary NER weights
    pub unary_ner: UnaryNerWeights,
    /// Unary coref weights
    pub unary_coref: UnaryCorefWeights,
    /// Unary link weights
    pub unary_link: UnaryLinkWeights,
    /// Link-NER pairwise weights
    pub link_ner: LinkNerWeights,
    /// Coref-NER pairwise weights
    pub coref_ner: CorefNerWeights,
    /// Coref-Link pairwise weights
    pub coref_link: CorefLinkWeights,
}
354
/// Unary NER factor weights.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UnaryNerWeights {
    /// Bias per entity type, keyed by `EntityType::as_label()`
    pub type_bias: HashMap<String, f64>,
    /// Context feature weights
    /// (NOTE(review): updated by the optimizer but not read by any scoring
    /// code visible in this file — confirm where it is consumed)
    pub context_weight: f64,
}
363
/// Unary coref factor weights.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UnaryCorefWeights {
    /// New cluster bias: score for taking no antecedent
    pub new_cluster_bias: f64,
    /// Distance decay: multiplies ln(mention distance) as a penalty
    pub distance_decay: f64,
    /// String match bonus: added on case-insensitive exact text match
    pub string_match: f64,
}
374
/// Unary link factor weights.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UnaryLinkWeights {
    /// NIL bias: score for linking a mention to no KB entry
    pub nil_bias: f64,
    /// Prior score weight: scales the candidate entity's confidence
    pub prior_weight: f64,
}
383
384// =============================================================================
385// AdaGrad Optimizer State
386// =============================================================================
387
/// AdaGrad optimizer state for a single scalar parameter.
#[derive(Debug, Clone, Default)]
struct AdaGradState {
    /// Running sum of squared gradients (G_t in the module docs)
    sum_sq_grad: f64,
}

impl AdaGradState {
    /// Fold `grad` into the accumulator and return the parameter delta
    /// `-lr / (√G + ε) · grad` (to be added to the weight).
    fn update(&mut self, grad: f64, lr: f64, epsilon: f64) -> f64 {
        self.sum_sq_grad += grad.powi(2);
        let scale = self.sum_sq_grad.sqrt() + epsilon;
        -(lr / scale) * grad
    }
}
402
/// Optimizer state for all weights.
///
/// Field-for-field mirror of `Gradients`: one `AdaGradState` (accumulated
/// squared gradient) per scalar weight, plus a map for the per-type biases.
#[derive(Debug, Clone, Default)]
struct OptimizerState {
    /// States for type biases
    type_bias_states: HashMap<String, AdaGradState>,
    /// State for context weight
    context_weight_state: AdaGradState,
    /// States for coref weights
    new_cluster_bias_state: AdaGradState,
    distance_decay_state: AdaGradState,
    string_match_state: AdaGradState,
    /// States for link weights
    nil_bias_state: AdaGradState,
    prior_weight_state: AdaGradState,
    /// States for pairwise weights
    type_match_state: AdaGradState,
    type_mismatch_state: AdaGradState,
    wiki_type_match_state: AdaGradState,
    wiki_type_mismatch_state: AdaGradState,
    same_link_state: AdaGradState,
    different_link_state: AdaGradState,
}
425
426// =============================================================================
427// Gradient Accumulator
428// =============================================================================
429
/// Accumulated gradients for one training step.
///
/// The `wiki_type_*` fields pair with `LinkNerWeights` and the
/// `same_link`/`different_link` fields with `CorefLinkWeights`
/// (see `add_l2_regularization` for the exact mapping).
#[derive(Debug, Clone, Default)]
struct Gradients {
    /// Gradients for type biases
    type_bias: HashMap<String, f64>,
    /// Gradient for context weight
    context_weight: f64,
    /// Gradients for coref weights
    new_cluster_bias: f64,
    distance_decay: f64,
    string_match: f64,
    /// Gradients for link weights
    nil_bias: f64,
    prior_weight: f64,
    /// Gradients for pairwise weights
    type_match: f64,
    type_mismatch: f64,
    wiki_type_match: f64,
    wiki_type_mismatch: f64,
    same_link: f64,
    different_link: f64,
}
452
453impl Gradients {
454    fn clip(&mut self, threshold: f64) {
455        let clip = |x: &mut f64| {
456            if *x > threshold {
457                *x = threshold;
458            } else if *x < -threshold {
459                *x = -threshold;
460            }
461        };
462
463        for v in self.type_bias.values_mut() {
464            clip(v);
465        }
466        clip(&mut self.context_weight);
467        clip(&mut self.new_cluster_bias);
468        clip(&mut self.distance_decay);
469        clip(&mut self.string_match);
470        clip(&mut self.nil_bias);
471        clip(&mut self.prior_weight);
472        clip(&mut self.type_match);
473        clip(&mut self.type_mismatch);
474        clip(&mut self.wiki_type_match);
475        clip(&mut self.wiki_type_mismatch);
476        clip(&mut self.same_link);
477        clip(&mut self.different_link);
478    }
479
480    fn add_l2_regularization(&mut self, weights: &JointWeights, lambda: f64) {
481        // Add L2 gradient: λ * w
482        for (type_name, bias) in &weights.unary_ner.type_bias {
483            *self.type_bias.entry(type_name.clone()).or_insert(0.0) += lambda * bias;
484        }
485        self.context_weight += lambda * weights.unary_ner.context_weight;
486        self.new_cluster_bias += lambda * weights.unary_coref.new_cluster_bias;
487        self.distance_decay += lambda * weights.unary_coref.distance_decay;
488        self.string_match += lambda * weights.unary_coref.string_match;
489        self.nil_bias += lambda * weights.unary_link.nil_bias;
490        self.prior_weight += lambda * weights.unary_link.prior_weight;
491        self.type_match += lambda * weights.coref_ner.type_match;
492        self.type_mismatch += lambda * weights.coref_ner.type_mismatch;
493        self.wiki_type_match += lambda * weights.link_ner.type_match;
494        self.wiki_type_mismatch += lambda * weights.link_ner.type_mismatch;
495        self.same_link += lambda * weights.coref_link.same_entity;
496        self.different_link += lambda * weights.coref_link.different_entity;
497    }
498}
499
500// =============================================================================
501// Trainer
502// =============================================================================
503
/// Joint model trainer using softmax-margin objective.
///
/// Collects [`TrainingExample`]s, then optimizes a [`JointWeights`] with
/// mini-batch AdaGrad (see module docs for the objective).
pub struct Trainer {
    /// Training configuration
    config: TrainingConfig,
    /// Learnable weights
    weights: JointWeights,
    /// Optimizer state
    optimizer: OptimizerState,
    /// Training examples
    examples: Vec<TrainingExample>,
    /// Training loss history (one entry per completed epoch)
    loss_history: Vec<f64>,
}
517
518impl Trainer {
519    /// Create a new trainer.
520    pub fn new(config: TrainingConfig) -> Self {
521        Self {
522            config,
523            weights: JointWeights::default(),
524            optimizer: OptimizerState::default(),
525            examples: Vec::new(),
526            loss_history: Vec::new(),
527        }
528    }
529
    /// Add a training example.
    pub fn add_example(&mut self, example: TrainingExample) {
        self.examples.push(example);
    }

    /// Add multiple training examples.
    pub fn add_examples(&mut self, examples: impl IntoIterator<Item = TrainingExample>) {
        self.examples.extend(examples);
    }

    /// Get current weights.
    ///
    /// Borrowed view; clone if you need to keep them past the trainer.
    pub fn get_weights(&self) -> &JointWeights {
        &self.weights
    }

    /// Get loss history.
    ///
    /// One entry per completed epoch (average batch loss for that epoch).
    pub fn get_loss_history(&self) -> &[f64] {
        &self.loss_history
    }
549
550    /// Train the model.
551    pub fn train(&mut self) -> Vec<f64> {
552        let mut losses = Vec::new();
553        let mut best_loss = f64::INFINITY;
554        let mut patience_counter = 0;
555
556        for epoch in 0..self.config.epochs {
557            // Shuffle examples
558            let mut indices: Vec<usize> = (0..self.examples.len()).collect();
559            shuffle(&mut indices, epoch as u64);
560
561            let mut epoch_loss = 0.0;
562            let mut num_batches = 0;
563
564            // Mini-batch training
565            for batch_start in (0..self.examples.len()).step_by(self.config.batch_size) {
566                let batch_end = (batch_start + self.config.batch_size).min(self.examples.len());
567                let batch_indices = &indices[batch_start..batch_end];
568
569                let batch_loss = self.train_batch(batch_indices);
570                epoch_loss += batch_loss;
571                num_batches += 1;
572            }
573
574            let avg_loss = if num_batches > 0 {
575                epoch_loss / num_batches as f64
576            } else {
577                0.0
578            };
579            losses.push(avg_loss);
580            self.loss_history.push(avg_loss);
581
582            // Early stopping check
583            if avg_loss < best_loss - self.config.min_delta {
584                best_loss = avg_loss;
585                patience_counter = 0;
586            } else {
587                patience_counter += 1;
588                if patience_counter >= self.config.patience {
589                    break;
590                }
591            }
592        }
593
594        losses
595    }
596
597    fn train_batch(&mut self, indices: &[usize]) -> f64 {
598        let mut total_loss = 0.0;
599        let mut accumulated_grads = Gradients::default();
600
601        for &idx in indices {
602            let example = &self.examples[idx];
603            let (loss, grads) = self.compute_loss_and_gradients(example);
604            total_loss += loss;
605
606            // Accumulate gradients
607            for (type_name, grad) in grads.type_bias {
608                *accumulated_grads.type_bias.entry(type_name).or_insert(0.0) += grad;
609            }
610            accumulated_grads.context_weight += grads.context_weight;
611            accumulated_grads.new_cluster_bias += grads.new_cluster_bias;
612            accumulated_grads.distance_decay += grads.distance_decay;
613            accumulated_grads.string_match += grads.string_match;
614            accumulated_grads.nil_bias += grads.nil_bias;
615            accumulated_grads.prior_weight += grads.prior_weight;
616            accumulated_grads.type_match += grads.type_match;
617            accumulated_grads.type_mismatch += grads.type_mismatch;
618            accumulated_grads.wiki_type_match += grads.wiki_type_match;
619            accumulated_grads.wiki_type_mismatch += grads.wiki_type_mismatch;
620            accumulated_grads.same_link += grads.same_link;
621            accumulated_grads.different_link += grads.different_link;
622        }
623
624        // Average gradients
625        let n = indices.len() as f64;
626        if n > 0.0 {
627            for v in accumulated_grads.type_bias.values_mut() {
628                *v /= n;
629            }
630            accumulated_grads.context_weight /= n;
631            accumulated_grads.new_cluster_bias /= n;
632            accumulated_grads.distance_decay /= n;
633            accumulated_grads.string_match /= n;
634            accumulated_grads.nil_bias /= n;
635            accumulated_grads.prior_weight /= n;
636            accumulated_grads.type_match /= n;
637            accumulated_grads.type_mismatch /= n;
638            accumulated_grads.wiki_type_match /= n;
639            accumulated_grads.wiki_type_mismatch /= n;
640            accumulated_grads.same_link /= n;
641            accumulated_grads.different_link /= n;
642        }
643
644        // Add L2 regularization
645        accumulated_grads.add_l2_regularization(&self.weights, self.config.l2_lambda);
646
647        // Clip gradients
648        accumulated_grads.clip(self.config.grad_clip);
649
650        // Apply AdaGrad updates
651        self.apply_updates(&accumulated_grads);
652
653        total_loss / n.max(1.0)
654    }
655
656    fn compute_loss_and_gradients(&self, example: &TrainingExample) -> (f64, Gradients) {
657        let mut grads = Gradients::default();
658
659        // Compute gold score
660        let gold_score = self.compute_score(
661            example,
662            &example.gold_ner,
663            &example.gold_coref,
664            &example.gold_links,
665        );
666
667        // Compute cost-augmented score (for softmax-margin)
668        // We approximate by sampling predictions
669        let (pred_ner, pred_coref, pred_links) = self.decode_with_cost(example);
670        let pred_score = self.compute_score(example, &pred_ner, &pred_coref, &pred_links);
671
672        // Hamming loss (cost)
673        let cost = example.hamming_loss(&pred_ner, &pred_coref, &pred_links);
674
675        // Softmax-margin loss: max(0, pred_score + cost - gold_score)
676        let margin = pred_score + self.config.cost_weight * cost - gold_score;
677        let loss = if margin > 0.0 { margin } else { 0.0 };
678
679        if loss > 0.0 {
680            // Compute gradients: E[φ(pred)] - φ(gold)
681            self.accumulate_feature_gradients(
682                &mut grads,
683                example,
684                &pred_ner,
685                &pred_coref,
686                &pred_links,
687                1.0,
688            );
689            self.accumulate_feature_gradients(
690                &mut grads,
691                example,
692                &example.gold_ner,
693                &example.gold_coref,
694                &example.gold_links,
695                -1.0,
696            );
697        }
698
699        (loss, grads)
700    }
701
702    fn compute_score(
703        &self,
704        example: &TrainingExample,
705        ner: &HashMap<usize, EntityType>,
706        coref: &HashMap<usize, Option<usize>>,
707        links: &HashMap<usize, Option<String>>,
708    ) -> f64 {
709        let mut score = 0.0;
710
711        // Unary NER scores
712        for entity_type in ner.values() {
713            let type_label = entity_type.as_label();
714            if let Some(&bias) = self.weights.unary_ner.type_bias.get(type_label) {
715                score += bias;
716            }
717        }
718
719        // Unary coref scores
720        for (idx, ante) in coref {
721            if ante.is_none() {
722                score += self.weights.unary_coref.new_cluster_bias;
723            } else if let Some(ante_idx) = ante {
724                // Distance penalty
725                let dist = (*idx as i64 - *ante_idx as i64).unsigned_abs() as f64;
726                score -= self.weights.unary_coref.distance_decay * dist.ln();
727
728                // String match bonus
729                if idx < &example.mentions.len() && *ante_idx < example.mentions.len() {
730                    let m_i = &example.mentions[*idx];
731                    let m_j = &example.mentions[*ante_idx];
732                    if m_i.text.to_lowercase() == m_j.text.to_lowercase() {
733                        score += self.weights.unary_coref.string_match;
734                    }
735                }
736            }
737        }
738
739        // Unary link scores
740        for (idx, link) in links {
741            if link.is_none() {
742                score += self.weights.unary_link.nil_bias;
743            } else if *idx < example.mentions.len() {
744                score += self.weights.unary_link.prior_weight * example.prior_score(*idx);
745            }
746        }
747
748        // Pairwise coref-NER scores
749        for (idx, ante) in coref {
750            if let Some(ante_idx) = ante {
751                if let (Some(type_i), Some(type_j)) = (ner.get(idx), ner.get(ante_idx)) {
752                    if type_i == type_j {
753                        score += self.weights.coref_ner.type_match;
754                    } else {
755                        score += self.weights.coref_ner.type_mismatch;
756                    }
757                }
758            }
759        }
760
761        // Pairwise coref-link scores
762        for (idx, ante) in coref {
763            if let Some(ante_idx) = ante {
764                if let (Some(link_i), Some(link_j)) = (links.get(idx), links.get(ante_idx)) {
765                    if link_i == link_j {
766                        score += self.weights.coref_link.same_entity;
767                    } else {
768                        score += self.weights.coref_link.different_entity;
769                    }
770                }
771            }
772        }
773
774        score
775    }
776
    /// Greedy cost-augmented decode used inside the softmax-margin loss.
    ///
    /// Returns per-mention (NER types, antecedents, KB links). Note the
    /// asymmetry: only the coreference decision is actually cost-augmented;
    /// NER and links copy the gold assignment when it exists, so the margin
    /// violations that drive learning come from coref decisions.
    fn decode_with_cost(&self, example: &TrainingExample) -> DecodeResult {
        // Simple greedy decode with cost-augmented scoring
        let mut pred_ner = HashMap::new();
        let mut pred_coref = HashMap::new();
        let mut pred_links = HashMap::new();

        for (idx, mention) in example.mentions.iter().enumerate() {
            // NER: copy the gold type when present, otherwise fall back to
            // the mention's own type. (No cost augmentation here.)
            if let Some(gold_type) = example.gold_ner.get(&idx) {
                pred_ner.insert(idx, gold_type.clone());
            } else if let Some(ref t) = mention.entity_type {
                pred_ner.insert(idx, t.clone());
            }

            // Coref: greedy antecedent selection; "no antecedent" competes
            // against every earlier mention at the new-cluster bias score.
            let mut best_ante: Option<usize> = None;
            let mut best_score = self.weights.unary_coref.new_cluster_bias;

            for ante_idx in 0..idx {
                let mut ante_score = 0.0;

                // Log-distance penalty, clamped at 0 (dist = 1 is free).
                let dist = (idx - ante_idx) as f64;
                ante_score -= self.weights.unary_coref.distance_decay * dist.ln().max(0.0);

                // Case-insensitive exact string match bonus.
                if mention.text.to_lowercase() == example.mentions[ante_idx].text.to_lowercase() {
                    ante_score += self.weights.unary_coref.string_match;
                }

                // Pairwise NER-type (dis)agreement with the candidate.
                if let (Some(type_i), Some(type_j)) = (pred_ner.get(&idx), pred_ner.get(&ante_idx))
                {
                    if type_i == type_j {
                        ante_score += self.weights.coref_ner.type_match;
                    } else {
                        ante_score += self.weights.coref_ner.type_mismatch;
                    }
                }

                // Cost augmentation: boost every non-gold antecedent so the
                // decoder prefers margin-violating assignments during training.
                if let Some(gold_ante) = example.gold_coref.get(&idx) {
                    if gold_ante != &Some(ante_idx) {
                        ante_score += self.config.cost_weight;
                    }
                }

                if ante_score > best_score {
                    best_score = ante_score;
                    best_ante = Some(ante_idx);
                }
            }
            pred_coref.insert(idx, best_ante);

            // Links: copy gold (NIL when absent). Despite the original
            // intent, no cost augmentation is applied to links.
            if let Some(gold_link) = example.gold_links.get(&idx) {
                pred_links.insert(idx, gold_link.clone());
            } else {
                pred_links.insert(idx, None);
            }
        }

        (pred_ner, pred_coref, pred_links)
    }
842
    /// Accumulate feature counts φ(x, y) for one assignment into `grads`,
    /// scaled by `scale` (+1.0 for the prediction, -1.0 for gold).
    ///
    /// Feature extraction here must mirror `compute_score` exactly: every
    /// term added to the score there corresponds to a count added here.
    fn accumulate_feature_gradients(
        &self,
        grads: &mut Gradients,
        example: &TrainingExample,
        ner: &HashMap<usize, EntityType>,
        coref: &HashMap<usize, Option<usize>>,
        links: &HashMap<usize, Option<String>>,
        scale: f64,
    ) {
        // Unary NER features: one count per predicted type label.
        for entity_type in ner.values() {
            let type_label = entity_type.as_label().to_string();
            *grads.type_bias.entry(type_label).or_insert(0.0) += scale;
        }

        // Unary coref features
        for (idx, ante) in coref {
            if ante.is_none() {
                grads.new_cluster_bias += scale;
            } else if let Some(ante_idx) = ante {
                // NOTE(review): raw ln(dist) — non-negative whenever the
                // antecedent precedes the mention (dist >= 1); confirm gold
                // chains never yield a self/zero-distance antecedent.
                let dist = (*idx as i64 - *ante_idx as i64).unsigned_abs() as f64;
                grads.distance_decay -= scale * dist.ln();

                if idx < &example.mentions.len() && *ante_idx < example.mentions.len() {
                    let m_i = &example.mentions[*idx];
                    let m_j = &example.mentions[*ante_idx];
                    if m_i.text.to_lowercase() == m_j.text.to_lowercase() {
                        grads.string_match += scale;
                    }
                }
            }
        }

        // Unary link features: NIL bias or prior-confidence feature.
        for (idx, link) in links {
            if link.is_none() {
                grads.nil_bias += scale;
            } else if *idx < example.mentions.len() {
                grads.prior_weight += scale * example.prior_score(*idx);
            }
        }

        // Pairwise coref-NER features: type (dis)agreement counts.
        for (idx, ante) in coref {
            if let Some(ante_idx) = ante {
                if let (Some(type_i), Some(type_j)) = (ner.get(idx), ner.get(ante_idx)) {
                    if type_i == type_j {
                        grads.type_match += scale;
                    } else {
                        grads.type_mismatch += scale;
                    }
                }
            }
        }

        // Pairwise coref-link features: KB-id (dis)agreement counts.
        for (idx, ante) in coref {
            if let Some(ante_idx) = ante {
                if let (Some(link_i), Some(link_j)) = (links.get(idx), links.get(ante_idx)) {
                    if link_i == link_j {
                        grads.same_link += scale;
                    } else {
                        grads.different_link += scale;
                    }
                }
            }
        }
    }
911
912    fn apply_updates(&mut self, grads: &Gradients) {
913        let lr = self.config.learning_rate;
914        let eps = self.config.epsilon;
915
916        // Update type biases
917        for (type_name, &grad) in &grads.type_bias {
918            let state = self
919                .optimizer
920                .type_bias_states
921                .entry(type_name.clone())
922                .or_default();
923            let delta = state.update(grad, lr, eps);
924            *self
925                .weights
926                .unary_ner
927                .type_bias
928                .entry(type_name.clone())
929                .or_insert(0.0) += delta;
930        }
931
932        // Update scalar weights
933        let delta = self
934            .optimizer
935            .context_weight_state
936            .update(grads.context_weight, lr, eps);
937        self.weights.unary_ner.context_weight += delta;
938
939        let delta = self
940            .optimizer
941            .new_cluster_bias_state
942            .update(grads.new_cluster_bias, lr, eps);
943        self.weights.unary_coref.new_cluster_bias += delta;
944
945        let delta = self
946            .optimizer
947            .distance_decay_state
948            .update(grads.distance_decay, lr, eps);
949        self.weights.unary_coref.distance_decay += delta;
950
951        let delta = self
952            .optimizer
953            .string_match_state
954            .update(grads.string_match, lr, eps);
955        self.weights.unary_coref.string_match += delta;
956
957        let delta = self
958            .optimizer
959            .nil_bias_state
960            .update(grads.nil_bias, lr, eps);
961        self.weights.unary_link.nil_bias += delta;
962
963        let delta = self
964            .optimizer
965            .prior_weight_state
966            .update(grads.prior_weight, lr, eps);
967        self.weights.unary_link.prior_weight += delta;
968
969        let delta = self
970            .optimizer
971            .type_match_state
972            .update(grads.type_match, lr, eps);
973        self.weights.coref_ner.type_match += delta;
974
975        let delta = self
976            .optimizer
977            .type_mismatch_state
978            .update(grads.type_mismatch, lr, eps);
979        self.weights.coref_ner.type_mismatch += delta;
980
981        let delta = self
982            .optimizer
983            .wiki_type_match_state
984            .update(grads.wiki_type_match, lr, eps);
985        self.weights.link_ner.type_match += delta;
986
987        let delta =
988            self.optimizer
989                .wiki_type_mismatch_state
990                .update(grads.wiki_type_mismatch, lr, eps);
991        self.weights.link_ner.type_mismatch += delta;
992
993        let delta = self
994            .optimizer
995            .same_link_state
996            .update(grads.same_link, lr, eps);
997        self.weights.coref_link.same_entity += delta;
998
999        let delta = self
1000            .optimizer
1001            .different_link_state
1002            .update(grads.different_link, lr, eps);
1003        self.weights.coref_link.different_entity += delta;
1004    }
1005}
1006
1007// =============================================================================
1008// Utility Functions
1009// =============================================================================
1010
/// Simple Fisher-Yates shuffle with deterministic seed.
///
/// The random source is an LCG (Knuth's MMIX multiplier) over u64. The
/// low-order bits of a modulus-2^64 LCG have very short periods — bit k
/// cycles with period 2^(k+1), so the lowest bit strictly alternates —
/// which makes `rng % (i + 1)` on the raw state badly non-random for
/// small moduli. The swap index is therefore drawn from the high 32 bits.
///
/// Empty and single-element slices are left unchanged; identical seeds
/// always produce identical permutations.
fn shuffle<T>(slice: &mut [T], seed: u64) {
    let mut rng = seed;
    for i in (1..slice.len()).rev() {
        rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
        // Use the well-mixed high half of the state for the index draw.
        let j = ((rng >> 32) as usize) % (i + 1);
        slice.swap(i, j);
    }
}
1020
1021// =============================================================================
1022// Tests
1023// =============================================================================
1024
#[cfg(test)]
mod tests {
    //! Unit tests for the trainer, AdaGrad optimizer state, gradient
    //! clipping, Hamming loss, deterministic shuffling, and the dynamic
    //! batching configuration.
    use super::*;

    // Pins the documented defaults: 50 epochs, learning rate 0.1.
    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert_eq!(config.epochs, 50);
        assert!((config.learning_rate - 0.1).abs() < 1e-6);
    }

    // A fresh trainer starts with no training examples queued.
    #[test]
    fn test_trainer_creation() {
        let trainer = Trainer::new(TrainingConfig::default());
        assert!(trainer.examples.is_empty());
    }

    // AdaGrad: step direction opposes the gradient, and the accumulated
    // squared gradient shrinks the effective learning rate over time.
    #[test]
    fn test_adagrad_state() {
        let mut state = AdaGradState::default();

        // First update
        let delta1 = state.update(1.0, 0.1, 1e-8);
        assert!(delta1 < 0.0); // Should move in negative gradient direction

        // Second update with same gradient - should have smaller step due to accumulated squared grad
        let delta2 = state.update(1.0, 0.1, 1e-8);
        assert!(delta2.abs() < delta1.abs()); // Adaptive LR should decrease
    }

    // Clipping bounds each scalar gradient to [-max, max], preserving sign.
    #[test]
    fn test_gradient_clipping() {
        let mut grads = Gradients {
            context_weight: 100.0,
            type_match: -100.0,
            ..Default::default()
        };

        grads.clip(5.0);

        assert!((grads.context_weight - 5.0).abs() < 1e-6);
        assert!((grads.type_match - (-5.0)).abs() < 1e-6);
    }

    // Hamming loss is zero for an exact NER match and positive when the
    // predicted type disagrees with gold.
    #[test]
    fn test_training_example_hamming_loss() {
        use crate::joint::MentionKind;

        let mentions = vec![JointMention {
            idx: 0,
            text: "Alice".to_string(),
            head: "Alice".to_string(),
            start: 0,
            end: 5,
            mention_kind: MentionKind::Proper,
            entity_type: Some(EntityType::Person),
            entity: Some(Entity::new("Alice", EntityType::Person, 0, 5, 0.9)),
        }];

        let mut gold_ner = HashMap::new();
        gold_ner.insert(0, EntityType::Person);

        let example = TrainingExample {
            text: "Alice".to_string(),
            mentions,
            gold_ner,
            gold_coref: HashMap::new(),
            gold_links: HashMap::new(),
        };

        // Perfect match
        let mut pred_ner = HashMap::new();
        pred_ner.insert(0, EntityType::Person);
        let loss = example.hamming_loss(&pred_ner, &HashMap::new(), &HashMap::new());
        assert!((loss - 0.0).abs() < 1e-6);

        // Wrong type
        let mut wrong_ner = HashMap::new();
        wrong_ner.insert(0, EntityType::Organization);
        let loss = example.hamming_loss(&wrong_ner, &HashMap::new(), &HashMap::new());
        assert!(loss > 0.0);
    }

    // Smoke test: training on a tiny two-mention example ("Alice" + a
    // coreferent pronoun) runs for the configured epochs without the loss
    // diverging. This checks stability, not convergence.
    #[test]
    fn test_trainer_single_example() {
        use crate::joint::MentionKind;

        let mut trainer = Trainer::new(TrainingConfig {
            epochs: 5,
            batch_size: 1,
            ..Default::default()
        });

        let mentions = vec![
            JointMention {
                idx: 0,
                text: "Alice".to_string(),
                head: "Alice".to_string(),
                start: 0,
                end: 5,
                mention_kind: MentionKind::Proper,
                entity_type: Some(EntityType::Person),
                entity: Some(Entity::new("Alice", EntityType::Person, 0, 5, 0.9)),
            },
            JointMention {
                idx: 1,
                text: "she".to_string(),
                head: "she".to_string(),
                start: 17,
                end: 20,
                mention_kind: MentionKind::Pronominal,
                entity_type: Some(EntityType::Person),
                entity: Some(Entity::new("she", EntityType::Person, 17, 20, 0.8)),
            },
        ];

        let mut gold_ner = HashMap::new();
        gold_ner.insert(0, EntityType::Person);
        gold_ner.insert(1, EntityType::Person);

        let mut gold_coref = HashMap::new();
        gold_coref.insert(0, None); // New cluster
        gold_coref.insert(1, Some(0)); // Links to Alice

        let example = TrainingExample {
            text: "Alice went home. she was tired.".to_string(),
            mentions,
            gold_ner,
            gold_coref,
            gold_links: HashMap::new(),
        };

        trainer.add_example(example);
        let losses = trainer.train();

        // Should have trained for some epochs
        assert!(!losses.is_empty());
        // Loss should generally decrease (or at least not explode)
        assert!(losses.iter().all(|&l| l < 1000.0));
    }

    // The seeded Fisher-Yates shuffle must be reproducible across calls.
    #[test]
    fn test_shuffle_deterministic() {
        let mut a = vec![1, 2, 3, 4, 5];
        let mut b = vec![1, 2, 3, 4, 5];

        shuffle(&mut a, 42);
        shuffle(&mut b, 42);

        assert_eq!(a, b); // Same seed should produce same shuffle
    }

    // Pins the default dynamic-batching knobs.
    #[test]
    fn test_dynamic_batch_config_default() {
        let config = DynamicBatchConfig::default();
        assert_eq!(config.max_context_length, 4000);
        assert_eq!(config.avg_sentence_length, 25);
        assert!(config.same_document);
    }

    // Cross-document preset: contexts come from different documents,
    // so no window overlap is used.
    #[test]
    fn test_dynamic_batch_config_cross_document() {
        let config = DynamicBatchConfig::cross_document();
        assert!(!config.same_document);
        assert_eq!(config.min_contexts, 2);
        assert_eq!(config.window_overlap, 0);
    }

    // Long-document preset keeps same-document sampling with overlap.
    #[test]
    fn test_dynamic_batch_config_long_document() {
        let config = DynamicBatchConfig::long_document();
        assert!(config.same_document);
        assert_eq!(config.window_overlap, 256);
    }

    // sample_num_contexts is deterministic per seed and stays within
    // [min_contexts, max_contexts].
    #[test]
    fn test_dynamic_batch_sample_contexts() {
        let config = DynamicBatchConfig {
            min_contexts: 2,
            max_contexts: 10,
            ..Default::default()
        };

        // Test deterministic sampling
        let n1 = config.sample_num_contexts(42);
        let n2 = config.sample_num_contexts(42);
        assert_eq!(n1, n2);

        // Should be in range
        assert!((2..=10).contains(&n1));

        // Different seeds should (usually) give different values
        let n3 = config.sample_num_contexts(123);
        // Note: this *could* fail by chance but is very unlikely
        assert!(n1 != n3 || config.max_contexts == config.min_contexts);
    }

    // context_length divides the budget by the number of contexts and is
    // capped by the document length.
    #[test]
    fn test_dynamic_batch_context_length() {
        let config = DynamicBatchConfig {
            max_context_length: 4000,
            ..Default::default()
        };

        // 1 context -> full length
        assert_eq!(config.context_length(1, 10000), 4000);

        // 4 contexts -> 1/4 length
        assert_eq!(config.context_length(4, 10000), 1000);

        // Short doc -> capped at doc length
        assert_eq!(config.context_length(2, 500), 250);
    }

    // Struct-update syntax lets callers opt into dynamic batching while
    // keeping every other default.
    #[test]
    fn test_training_config_with_dynamic_batching() {
        let config = TrainingConfig {
            dynamic_batching: Some(DynamicBatchConfig::cross_document()),
            ..Default::default()
        };

        assert!(config.dynamic_batching.is_some());
        let db = config.dynamic_batching.unwrap();
        assert!(!db.same_document);
    }
}