Skip to main content

anno/backends/
box_embeddings_training.rs

1//! Training system for box embeddings in coreference resolution.
2//!
3//! This module implements the mathematical framework for learning box embeddings
4//! from coreference annotations, based on research from BERE, BoxTE, and UKGE.
5//!
6//! **Note**: This is the canonical training implementation. The [matryoshka-box](https://github.com/arclabs561/matryoshka-box)
7//! research project extends this with matryoshka-specific features (variable dimensions, etc.).
8//! For standard box embedding training, use this module. For research extensions, see matryoshka-box.
9//!
10//! # Mathematical Foundation
11//!
12//! ## Objective Function
13//!
14//! We maximize the conditional probability for positive pairs (entities that corefer)
15//! and minimize it for negative pairs (entities that don't corefer):
16//!
17//! ```text
18//! L = -Σ log P(box_i | box_j) for positive pairs (i,j)
19//!     + λ * Σ max(0, margin - P(box_i | box_j)) for negative pairs (i,j)
20//! ```
21//!
22//! Where:
23//! - P(box_i | box_j) = Vol(box_i ∩ box_j) / Vol(box_j) (conditional probability)
24//! - Positive pairs: entities in the same coreference chain
25//! - Negative pairs: entities in different chains
26//! - λ: negative sampling weight
27//! - margin: minimum separation for negative pairs
28//!
29//! ## Gradient Computation
30//!
31//! We use **analytical gradients** (not finite differences) for efficiency:
32//!
33//! For a box with min = μ - exp(δ)/2, max = μ + exp(δ)/2:
34//!
35//! ∂Vol/∂μᵢ = 0 (volume doesn't depend on center position)
36//! ∂Vol/∂δᵢ = Vol * 1 (since Vol = ∏exp(δᵢ))
37//!
38//! For intersection volume, gradients flow through each dimension's overlap.
39//!
40//! ## Box Constraints
41//!
//! We enforce mᵢ ≤ Mᵢ for all dimensions using:
//! - Reparameterization: mᵢ = μᵢ - exp(δᵢ)/2 and Mᵢ = μᵢ + exp(δᵢ)/2, where μᵢ (center)
//!   and δᵢ (log-width) are the learned parameters; exp(δᵢ) > 0 guarantees mᵢ ≤ Mᵢ.
44
45use crate::backends::box_embeddings::BoxEmbedding;
46use anno_core::Entity;
47use anno_core::{CorefChain, CorefDocument};
48use serde::{Deserialize, Serialize};
49use std::collections::HashMap;
50
51// =============================================================================
52// Trainable Box Embedding
53// =============================================================================
54
/// A trainable box embedding with learnable parameters.
///
/// Uses reparameterization to ensure min <= max:
/// - min = mu - exp(delta)/2
/// - max = mu + exp(delta)/2
///
/// Because exp(delta) > 0 for any real delta, the derived box is always
/// valid (min <= max) no matter what values the optimizer writes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainableBox {
    /// Mean position in each dimension (d-dimensional vector); the box center.
    pub mu: Vec<f32>,
    /// Log-width in each dimension (width = exp(delta)), so widths stay positive.
    pub delta: Vec<f32>,
    /// Embedding dimension (equals `mu.len()` and `delta.len()`).
    pub dim: usize,
}
71
72impl TrainableBox {
73    /// Create a new trainable box.
74    ///
75    /// # Arguments
76    ///
77    /// * `mu` - Mean position (center of box)
78    /// * `delta` - Log-width (width = exp(delta))
79    ///
80    /// The box will have:
81    /// - min = mu - exp(delta) / 2
82    /// - max = mu + exp(delta) / 2
83    #[must_use]
84    pub fn new(mu: Vec<f32>, delta: Vec<f32>) -> Self {
85        assert_eq!(
86            mu.len(),
87            delta.len(),
88            "mu and delta must have same dimension"
89        );
90        let dim = mu.len();
91        Self { mu, delta, dim }
92    }
93
94    /// Initialize from a vector embedding.
95    ///
96    /// Creates a small box around the vector with initial width `init_width`.
97    #[must_use]
98    pub fn from_vector(vector: &[f32], init_width: f32) -> Self {
99        let mu = vector.to_vec();
100        let delta: Vec<f32> = vec![init_width.ln(); mu.len()];
101        Self::new(mu, delta)
102    }
103
104    /// Convert to a BoxEmbedding (for inference).
105    #[must_use]
106    pub fn to_box(&self) -> BoxEmbedding {
107        let min: Vec<f32> = self
108            .mu
109            .iter()
110            .zip(self.delta.iter())
111            .map(|(&m, &d)| m - (d.exp() / 2.0))
112            .collect();
113        let max: Vec<f32> = self
114            .mu
115            .iter()
116            .zip(self.delta.iter())
117            .map(|(&m, &d)| m + (d.exp() / 2.0))
118            .collect();
119        BoxEmbedding::new(min, max)
120    }
121}
122
123// =============================================================================
124// Training Data Structures
125// =============================================================================
126
/// A single training example (one document with coreference chains).
///
/// Each example contains:
/// - Entities: All mentions in the document, flattened out of their chains
/// - Chains: Groups of mentions that refer to the same real-world entity
#[derive(Debug, Clone)]
pub struct TrainingExample {
    /// All entity mentions in the document.
    pub entities: Vec<Entity>,
    /// Coreference chains (groups of mentions that corefer). The trainer
    /// identifies a mention by its start character offset (`mention.start`),
    /// which doubles as the entity ID.
    pub chains: Vec<CorefChain>,
}
139
140/// Convert a CorefDocument to a TrainingExample.
141///
142/// This extracts all mentions as entities and preserves the coreference chains.
143impl From<&CorefDocument> for TrainingExample {
144    fn from(doc: &CorefDocument) -> Self {
145        // Collect all mentions as entities
146        let mut entities = Vec::new();
147        let mut mention_to_entity_id = HashMap::new();
148
149        for chain in &doc.chains {
150            for mention in &chain.mentions {
151                // Use character offset as entity ID (unique per document)
152                let entity_id = mention.start;
153
154                // Convert mention entity_type (Option<String>) to EntityType
155                // Default to Person if type is unknown
156                let entity_type = mention
157                    .entity_type
158                    .as_ref()
159                    .and_then(|s| match s.as_str() {
160                        "PER" | "Person" | "person" => Some(anno_core::EntityType::Person),
161                        "ORG" | "Organization" | "organization" => {
162                            Some(anno_core::EntityType::Organization)
163                        }
164                        "LOC" | "Location" | "location" => Some(anno_core::EntityType::Location),
165                        _ => None,
166                    })
167                    .unwrap_or(anno_core::EntityType::Person);
168
169                // Create Entity from Mention
170                let entity = Entity::new(
171                    mention.text.clone(),
172                    entity_type,
173                    entity_id,
174                    mention.end,
175                    1.0,
176                );
177
178                entities.push(entity);
179                mention_to_entity_id.insert((mention.start, mention.end), entity_id);
180            }
181        }
182
183        // Rebuild chains with same mentions (they already have correct offsets)
184        let chains = doc.chains.clone();
185
186        Self { entities, chains }
187    }
188}
189
190/// Convert multiple CorefDocuments to TrainingExamples.
191pub fn coref_documents_to_training_examples(docs: &[CorefDocument]) -> Vec<TrainingExample> {
192    docs.iter().map(TrainingExample::from).collect()
193}
194
195// =============================================================================
196// Training Configuration
197// =============================================================================
198
/// Training configuration for box embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Base learning rate (scheduled with warmup/decay during training).
    pub learning_rate: f32,
    /// Weight for negative pairs (λ in the objective; adaptively scaled per epoch).
    pub negative_weight: f32,
    /// Margin: minimum separation required for negative pairs.
    pub margin: f32,
    /// L2 regularization weight.
    pub regularization: f32,
    /// Number of training epochs.
    pub epochs: usize,
    /// Batch size (for mini-batch training).
    pub batch_size: usize,
    /// Warmup epochs (linear increase from 0.1*lr to lr).
    pub warmup_epochs: usize,
    /// Use self-adversarial negative sampling.
    pub use_self_adversarial: bool,
    /// Temperature for self-adversarial sampling.
    pub adversarial_temperature: f32,
    /// Early stopping patience (stop if no improvement for N epochs); `None` disables.
    pub early_stopping_patience: Option<usize>,
    /// Minimum improvement for early stopping (absolute loss delta; compared
    /// directly against `best_loss - avg_loss`).
    pub early_stopping_min_delta: f32,
    /// Multi-stage training: focus on positives for this many epochs, then
    /// negatives. `None` defaults to `epochs / 3` (see `train`).
    pub positive_focus_epochs: Option<usize>,
}
227
228impl Default for TrainingConfig {
229    fn default() -> Self {
230        Self {
231            learning_rate: 0.001,
232            negative_weight: 0.5,
233            margin: 0.3,
234            regularization: 0.0001,
235            epochs: 100,
236            batch_size: 32,
237            warmup_epochs: 10,
238            use_self_adversarial: true,
239            adversarial_temperature: 1.0,
240            early_stopping_patience: Some(10),
241            early_stopping_min_delta: 0.001,
242            positive_focus_epochs: None,
243        }
244    }
245}
246
247// =============================================================================
248// AMSGrad Optimizer State
249// =============================================================================
250
/// AMSGrad optimizer state for a single box.
///
/// AMSGrad is the Adam variant that additionally tracks the element-wise
/// maximum of the second-moment estimate (`v_hat`).
#[derive(Debug, Clone)]
pub struct AMSGradState {
    /// First moment estimate (m): running average of gradients.
    pub m: Vec<f32>,
    /// Second moment estimate (v): running average of squared gradients.
    pub v: Vec<f32>,
    /// Max second moment estimate (v_hat): element-wise maximum of `v` over time.
    pub v_hat: Vec<f32>,
    /// Iteration counter (number of updates applied).
    pub t: usize,
    /// Learning rate (updated each epoch by the scheduler via `set_lr`).
    pub lr: f32,
    /// Beta1 (momentum decay for the first moment).
    pub beta1: f32,
    /// Beta2 (decay for the second moment, RMSprop-style).
    pub beta2: f32,
    /// Epsilon (numerical stability term in the denominator).
    pub epsilon: f32,
}
271
272impl AMSGradState {
273    /// Create new AMSGrad state.
274    pub fn new(dim: usize, learning_rate: f32) -> Self {
275        Self {
276            m: vec![0.0; dim],
277            v: vec![0.0; dim],
278            v_hat: vec![0.0; dim],
279            t: 0,
280            lr: learning_rate,
281            beta1: 0.9,
282            beta2: 0.999,
283            epsilon: 1e-8,
284        }
285    }
286
287    /// Set learning rate.
288    pub fn set_lr(&mut self, lr: f32) {
289        self.lr = lr;
290    }
291}
292
293// =============================================================================
294// Trainer
295// =============================================================================
296
/// Trainer for box embeddings.
///
/// Owns one [`TrainableBox`] and one [`AMSGradState`] per entity ID and
/// updates them from coreference supervision (see `train`).
pub struct BoxEmbeddingTrainer {
    /// Training configuration. `negative_weight` is temporarily overridden
    /// each epoch by the adaptive schedule in `train` and then restored.
    config: TrainingConfig,
    /// Entity ID → TrainableBox mapping (the learned parameters).
    boxes: HashMap<usize, TrainableBox>,
    /// Entity ID → AMSGradState mapping (one optimizer state per box).
    optimizer_states: HashMap<usize, AMSGradState>,
    /// Embedding dimension shared by all boxes.
    dim: usize,
}
308
309impl BoxEmbeddingTrainer {
310    /// Create a new trainer.
311    ///
312    /// # Arguments
313    ///
314    /// * `config` - Training configuration
315    /// * `dim` - Embedding dimension
316    /// * `initial_embeddings` - Optional initial vector embeddings (entity_id → vector)
317    pub fn new(
318        config: TrainingConfig,
319        dim: usize,
320        initial_embeddings: Option<HashMap<usize, Vec<f32>>>,
321    ) -> Self {
322        let mut boxes = HashMap::new();
323        let mut optimizer_states = HashMap::new();
324
325        if let Some(embeddings) = initial_embeddings {
326            // Initialize from vector embeddings
327            for (entity_id, vector) in embeddings {
328                assert_eq!(vector.len(), dim);
329                let box_embedding = TrainableBox::from_vector(&vector, 0.1);
330                boxes.insert(entity_id, box_embedding.clone());
331                optimizer_states.insert(entity_id, AMSGradState::new(dim, config.learning_rate));
332            }
333        }
334
335        Self {
336            config,
337            boxes,
338            optimizer_states,
339            dim,
340        }
341    }
342
343    /// Initialize boxes from entities.
344    ///
345    /// Creates trainable boxes for all entities, either from provided
346    /// vector embeddings or random initialization.
347    ///
348    /// **Key insight**: For positive pairs (entities that corefer), initialize
349    /// boxes to overlap so gradients can flow from the start.
350    ///
351    /// # Arguments
352    ///
353    /// * `examples` - Training examples with entities and coreference chains
354    /// * `initial_embeddings` - Optional pre-computed vector embeddings (entity_id → vector)
355    ///   If provided, boxes are initialized around these vectors. If None, uses smart
356    ///   random initialization with shared centers for coreferent entities.
357    pub fn initialize_boxes(
358        &mut self,
359        examples: &[TrainingExample],
360        initial_embeddings: Option<&HashMap<usize, Vec<f32>>>,
361    ) {
362        // Collect all unique entity IDs and build coreference groups
363        let mut entity_ids = std::collections::HashSet::new();
364        let mut coref_groups: Vec<Vec<usize>> = Vec::new();
365
366        for example in examples {
367            for entity in &example.entities {
368                let entity_id = entity.start;
369                entity_ids.insert(entity_id);
370            }
371
372            // Build groups of entities that corefer
373            for chain in &example.chains {
374                let group: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
375                if group.len() > 1 {
376                    coref_groups.push(group);
377                }
378            }
379        }
380
381        // Initialize boxes
382        for &entity_id in &entity_ids {
383            // If vector embeddings provided, use them (better initialization)
384            if let Some(embeddings) = initial_embeddings {
385                if let Some(vector) = embeddings.get(&entity_id) {
386                    // Normalize vector to unit length for better initialization
387                    let norm: f32 = vector.iter().map(|&x| x * x).sum::<f32>().sqrt();
388                    let normalized: Vec<f32> = if norm > 0.0 {
389                        vector.iter().map(|&x| x / norm).collect()
390                    } else {
391                        vector.clone()
392                    };
393
394                    // Use larger initial width (0.2) when starting from vectors
395                    // This ensures boxes can overlap even if vectors are slightly different
396                    let box_embedding = TrainableBox::from_vector(&normalized, 0.2);
397                    self.boxes.insert(entity_id, box_embedding.clone());
398                    self.optimizer_states.insert(
399                        entity_id,
400                        AMSGradState::new(self.dim, self.config.learning_rate),
401                    );
402                    continue;
403                }
404            }
405
406            // Check if this entity is in a coreference group
407            let mut group_center: Option<Vec<f32>> = None;
408            let mut in_coref_group = false;
409
410            for group in &coref_groups {
411                if group.contains(&entity_id) {
412                    // Use a shared center for all entities in the group
413                    if group_center.is_none() {
414                        group_center = Some(
415                            (0..self.dim)
416                                .map(|_| (simple_random() - 0.5) * 0.3) // Smaller region for better overlap
417                                .collect(),
418                        );
419                    }
420                    in_coref_group = true;
421                    break;
422                }
423            }
424
425            // Initialize: if in coref group, use shared center; otherwise random
426            let mu = if let Some(ref center) = group_center {
427                // Add very small random offset to shared center (ensures overlap)
428                center
429                    .iter()
430                    .map(|&c| c + (simple_random() - 0.5) * 0.05) // Very small offset
431                    .collect()
432            } else {
433                // Random center, but spread out more to avoid accidental overlap
434                (0..self.dim)
435                    .map(|_| (simple_random() - 0.5) * 1.0)
436                    .collect()
437            };
438
439            // Balanced initialization: coreferent entities should overlap significantly
440            // Non-coreferent entities should be distinct and compact
441            // Initialize with larger width for coreferent entities (to ensure overlap)
442            let initial_width = if in_coref_group {
443                1.1_f32 // Good width for coreferent entities (ensures overlap but allows learning)
444            } else {
445                0.18_f32 // Small width for non-coreferent (distinct but not too small)
446            };
447            let delta: Vec<f32> = vec![initial_width.ln(); self.dim];
448            let box_embedding = TrainableBox::new(mu, delta);
449            self.boxes.insert(entity_id, box_embedding.clone());
450            self.optimizer_states.insert(
451                entity_id,
452                AMSGradState::new(self.dim, self.config.learning_rate),
453            );
454        }
455    }
456
    /// Train on a single example.
    ///
    /// Builds every positive pair (two mentions in the same chain) and every
    /// negative pair (mentions from two different chains), accumulates
    /// analytical gradients for both sets, then applies one AMSGrad update
    /// per touched box.
    ///
    /// Returns the mean loss over all pairs processed (0.0 if the example
    /// produced no pairs).
    fn train_example(&mut self, example: &TrainingExample, epoch: usize) -> f32 {
        let mut total_loss = 0.0;
        let mut num_pairs = 0;

        // Update learning rate with warmup and decay, and push the new rate
        // into every per-entity optimizer state.
        let current_lr = get_learning_rate(
            epoch,
            self.config.epochs,
            self.config.learning_rate,
            self.config.warmup_epochs,
        );
        for state in self.optimizer_states.values_mut() {
            state.set_lr(current_lr);
        }

        // Build positive pairs (entities in same chain). A mention's start
        // offset doubles as its entity ID throughout this module.
        let mut positive_pairs = Vec::new();
        for chain in &example.chains {
            let mentions: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
            for i in 0..mentions.len() {
                for j in (i + 1)..mentions.len() {
                    positive_pairs.push((mentions[i], mentions[j]));
                }
            }
        }

        // Build negative pairs (entities in different chains): full cross
        // product of mentions between every pair of distinct chains.
        let mut negative_pairs = Vec::new();
        for i in 0..example.chains.len() {
            for j in (i + 1)..example.chains.len() {
                let chain_i: Vec<usize> =
                    example.chains[i].mentions.iter().map(|m| m.start).collect();
                let chain_j: Vec<usize> =
                    example.chains[j].mentions.iter().map(|m| m.start).collect();
                for &id_i in &chain_i {
                    for &id_j in &chain_j {
                        negative_pairs.push((id_i, id_j));
                    }
                }
            }
        }

        // Accumulate gradients for all pairs before applying any update,
        // keyed by entity ID: (grad wrt mu, grad wrt delta).
        let mut gradients: HashMap<usize, (Vec<f32>, Vec<f32>)> = HashMap::new();

        // Process positive pairs
        for &(id_a, id_b) in &positive_pairs {
            // Clone boxes for gradient computation (avoids holding borrows
            // into `self.boxes` while accumulating).
            let box_a = self.boxes.get(&id_a).cloned();
            let box_b = self.boxes.get(&id_b).cloned();

            if let (Some(box_a_ref), Some(box_b_ref)) = (box_a.as_ref(), box_b.as_ref()) {
                let loss = compute_pair_loss(box_a_ref, box_b_ref, true, &self.config);
                total_loss += loss;
                num_pairs += 1;

                // Compute analytical gradients
                let (grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b) =
                    compute_analytical_gradients(box_a_ref, box_b_ref, true, &self.config);

                // Skip the pair entirely if any gradient is NaN/inf, so one
                // degenerate pair cannot poison the accumulated update.
                if grad_mu_a.iter().any(|&x| !x.is_finite())
                    || grad_delta_a.iter().any(|&x| !x.is_finite())
                    || grad_mu_b.iter().any(|&x| !x.is_finite())
                    || grad_delta_b.iter().any(|&x| !x.is_finite())
                {
                    continue;
                }

                // Accumulate gradients
                let entry_a = gradients
                    .entry(id_a)
                    .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
                for i in 0..self.dim {
                    entry_a.0[i] += grad_mu_a[i];
                    entry_a.1[i] += grad_delta_a[i];
                }

                let entry_b = gradients
                    .entry(id_b)
                    .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
                for i in 0..self.dim {
                    entry_b.0[i] += grad_mu_b[i];
                    entry_b.1[i] += grad_delta_b[i];
                }
            }
        }

        // Process negative pairs (with self-adversarial sampling if enabled).
        // Either way, at most `positive_pairs.len()` negatives are used so
        // positives and negatives stay balanced.
        let negative_samples: Vec<(usize, usize)> =
            if self.config.use_self_adversarial && !negative_pairs.is_empty() {
                // Sample based on current predictions (presumably favoring
                // hard negatives — see `sample_self_adversarial_negatives`).
                let num_samples = positive_pairs.len().min(negative_pairs.len());
                let sampled_indices = sample_self_adversarial_negatives(
                    &negative_pairs,
                    &self.boxes,
                    num_samples,
                    self.config.adversarial_temperature,
                );
                sampled_indices
                    .iter()
                    .map(|&idx| negative_pairs[idx])
                    .collect()
            } else {
                // Uniform sampling: take the first num_samples pairs.
                let num_samples = positive_pairs.len().min(negative_pairs.len());
                negative_pairs.into_iter().take(num_samples).collect()
            };

        for &(id_a, id_b) in &negative_samples {
            // Clone boxes for gradient computation
            let box_a = self.boxes.get(&id_a).cloned();
            let box_b = self.boxes.get(&id_b).cloned();

            if let (Some(box_a_ref), Some(box_b_ref)) = (box_a.as_ref(), box_b.as_ref()) {
                let loss = compute_pair_loss(box_a_ref, box_b_ref, false, &self.config);
                total_loss += loss;
                num_pairs += 1;

                // Compute analytical gradients
                let (grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b) =
                    compute_analytical_gradients(box_a_ref, box_b_ref, false, &self.config);

                // Skip if gradients are invalid (same guard as for positives)
                if grad_mu_a.iter().any(|&x| !x.is_finite())
                    || grad_delta_a.iter().any(|&x| !x.is_finite())
                    || grad_mu_b.iter().any(|&x| !x.is_finite())
                    || grad_delta_b.iter().any(|&x| !x.is_finite())
                {
                    continue;
                }

                // Accumulate gradients
                let entry_a = gradients
                    .entry(id_a)
                    .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
                for i in 0..self.dim {
                    entry_a.0[i] += grad_mu_a[i];
                    entry_a.1[i] += grad_delta_a[i];
                }

                let entry_b = gradients
                    .entry(id_b)
                    .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
                for i in 0..self.dim {
                    entry_b.0[i] += grad_mu_b[i];
                    entry_b.1[i] += grad_delta_b[i];
                }
            }
        }

        // Apply accumulated gradients using AMSGrad (one step per entity).
        for (entity_id, (grad_mu, grad_delta)) in gradients {
            if let (Some(box_mut), Some(state)) = (
                self.boxes.get_mut(&entity_id),
                self.optimizer_states.get_mut(&entity_id),
            ) {
                box_mut.update_amsgrad(&grad_mu, &grad_delta, state);
            }
        }

        // Mean loss per pair; 0.0 when the example contributed no pairs.
        if num_pairs > 0 {
            total_loss / num_pairs as f32
        } else {
            0.0
        }
    }
625
    /// Train on a dataset with mini-batching and early stopping.
    ///
    /// Uses adaptive negative weighting: starts with low weight to learn positives,
    /// then gradually increases to separate negatives. Each epoch temporarily
    /// overrides `config.negative_weight` with the adaptive value and restores
    /// it afterwards.
    ///
    /// Returns the per-epoch average losses (one entry per epoch actually run).
    pub fn train(&mut self, examples: &[TrainingExample]) -> Vec<f32> {
        let mut losses = Vec::new();
        let mut best_loss = f32::INFINITY;
        let mut patience_counter = 0;

        // Track score gap for adaptive weighting.
        // NOTE(review): score_gap_history is accumulated but never read in
        // this method — confirm whether it is still needed.
        let mut score_gap_history = Vec::new();

        for epoch in 0..self.config.epochs {
            // Multi-stage training: focus on positives first, then negatives.
            // avg_pos/avg_neg are the mean conditional-probability scores for
            // positive/negative pairs under the current boxes.
            let (avg_pos, avg_neg, _) = self.get_overlap_stats(examples);
            let current_gap = avg_pos - avg_neg;
            score_gap_history.push(current_gap);

            // Determine training stage (default positive stage: first third).
            let positive_focus_epochs = self
                .config
                .positive_focus_epochs
                .unwrap_or(self.config.epochs / 3);
            let is_positive_stage = epoch < positive_focus_epochs;

            // Calculate adaptive negative weight based on stage and performance.
            // NOTE(review): if positive_focus_epochs == epochs, the later
            // branches divide by zero (f32, so no panic) — confirm intended.
            let adaptive_negative_weight = if is_positive_stage {
                // Stage 1: Focus on positive learning - low negative weight but not zero.
                // Still apply some negative gradients to prevent negative scores from growing too much.
                // Gradually increase from 0.2 to 0.3 of the base weight during this stage.
                let stage_progress = epoch as f32 / positive_focus_epochs as f32;
                self.config.negative_weight * (0.2 + stage_progress * 0.1)
            } else if avg_pos > 0.05 && avg_neg > 0.3 {
                // Stage 2: Positives are learned but negatives score too high -
                // separate aggressively, scaling the weight by stage progress
                // and by how high the negative scores are.
                let progress = ((epoch - positive_focus_epochs) as f32
                    / (self.config.epochs - positive_focus_epochs) as f32)
                    .min(1.0);
                // Penalty factor for high negatives (lower threshold = harsher).
                let neg_penalty = (avg_neg / 0.4).min(1.0);
                self.config.negative_weight * (0.7 + progress * 0.8 + neg_penalty * 0.4).min(2.0)
            // Up to 2.0x the base weight
            } else if avg_pos > 0.02 && current_gap > 0.0 {
                // Stage 2: Positive learning is good and the gap is already
                // positive - gradually increase negative weight as training
                // progresses, capped relative to the current gap.
                let progress = ((epoch - positive_focus_epochs) as f32
                    / (self.config.epochs - positive_focus_epochs) as f32)
                    .min(1.0);
                self.config.negative_weight * (0.5 + progress * 0.5).min(1.0 + (current_gap / 0.1))
            } else if avg_pos < 0.01 {
                // Positive scores too low - back off negatives so positives
                // can be learned first.
                self.config.negative_weight * 0.3
            } else {
                // Default behavior - moderate weight, gradually increased.
                let progress = ((epoch - positive_focus_epochs) as f32
                    / (self.config.epochs - positive_focus_epochs) as f32)
                    .min(1.0);
                self.config.negative_weight * (0.4 + progress * 0.4)
            };

            // Temporarily override negative weight for this epoch (restored below).
            let original_negative_weight = self.config.negative_weight;
            self.config.negative_weight = adaptive_negative_weight;
            // Shuffle examples for better training (simple Fisher-Yates,
            // driven by the module's `simple_random` generator).
            let mut shuffled_indices: Vec<usize> = (0..examples.len()).collect();
            for i in (1..shuffled_indices.len()).rev() {
                let j = (simple_random() * (i + 1) as f32) as usize;
                shuffled_indices.swap(i, j);
            }

            let mut epoch_loss = 0.0;
            let mut num_batches = 0;

            // Mini-batch training: each batch averages the per-example losses.
            for batch_start in (0..examples.len()).step_by(self.config.batch_size) {
                let batch_end = (batch_start + self.config.batch_size).min(examples.len());
                let batch_indices = &shuffled_indices[batch_start..batch_end];

                let mut batch_loss = 0.0;
                let mut batch_pairs = 0;

                // Process batch (updates are applied per example inside
                // train_example, not deferred to batch end).
                for &idx in batch_indices {
                    let example = &examples[idx];
                    let loss = self.train_example(example, epoch);
                    batch_loss += loss;
                    batch_pairs += 1;
                }

                if batch_pairs > 0 {
                    epoch_loss += batch_loss / batch_pairs as f32;
                    num_batches += 1;
                }
            }

            // Epoch loss = mean of per-batch mean losses.
            let avg_loss = if num_batches > 0 {
                epoch_loss / num_batches as f32
            } else {
                0.0
            };
            losses.push(avg_loss);

            // Recompute the scheduled learning rate for logging only.
            let current_lr = get_learning_rate(
                epoch,
                self.config.epochs,
                self.config.learning_rate,
                self.config.warmup_epochs,
            );

            // Early stopping check: improvement must beat the absolute
            // min-delta threshold to reset patience.
            let improved = avg_loss < best_loss - self.config.early_stopping_min_delta;
            if improved {
                best_loss = avg_loss;
                patience_counter = 0;
            } else {
                patience_counter += 1;
            }

            // Show overlap stats periodically (every 10 epochs, on the last
            // epoch, and whenever the loss improves).
            if epoch % 10 == 0 || epoch == self.config.epochs - 1 || improved {
                let (avg_pos, avg_neg, overlap_rate) = self.get_overlap_stats(examples);
                let status = if improved { "✓" } else { " " };
                let patience_info = if let Some(patience) = self.config.early_stopping_patience {
                    format!(", patience={}/{}", patience_counter, patience)
                } else {
                    String::new()
                };
                let loss_reduction = if losses.len() > 1 {
                    format!(" ({:.1}%↓)", (1.0 - avg_loss / losses[0]) * 100.0)
                } else {
                    String::new()
                };
                let score_gap = avg_pos - avg_neg; // Positive should be higher than negative
                let positive_focus_epochs = self
                    .config
                    .positive_focus_epochs
                    .unwrap_or(self.config.epochs / 3);
                let stage = if epoch < positive_focus_epochs {
                    "P+"
                } else {
                    "S-"
                };
                println!("Epoch {}: loss = {:.4}{}, lr = {:.6}, best = {:.4} {} ({} batches{}, neg_w={:.2}, stage={})",
                    epoch, avg_loss, loss_reduction, current_lr, best_loss, status, num_batches, patience_info, adaptive_negative_weight, stage);
                println!(
                    "  Overlap: {:.1}%, Pos: {:.4}, Neg: {:.4}, Gap: {:.4} {}",
                    overlap_rate * 100.0,
                    avg_pos,
                    avg_neg,
                    score_gap,
                    if score_gap > 0.0 { "✓" } else { "⚠" }
                );
            }

            // Restore original negative weight before the next epoch (or exit).
            self.config.negative_weight = original_negative_weight;

            // Early stopping
            if let Some(patience) = self.config.early_stopping_patience {
                if patience_counter >= patience {
                    println!(
                        "Early stopping at epoch {} (no improvement for {} epochs)",
                        epoch, patience
                    );
                    break;
                }
            }
        }

        losses
    }
796
797    /// Get trained boxes for inference.
798    pub fn get_boxes(&self) -> HashMap<usize, BoxEmbedding> {
799        self.boxes
800            .iter()
801            .map(|(id, trainable)| (*id, trainable.to_box()))
802            .collect()
803    }
804
805    /// Get diagnostic statistics about box overlaps.
806    ///
807    /// Returns (avg_positive_score, avg_negative_score, overlap_rate)
808    pub fn get_overlap_stats(&self, examples: &[TrainingExample]) -> (f32, f32, f32) {
809        let mut positive_scores = Vec::new();
810        let mut negative_scores = Vec::new();
811        let mut overlapping_pairs = 0;
812        let mut total_pairs = 0;
813
814        for example in examples {
815            // Positive pairs
816            for chain in &example.chains {
817                let mentions: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
818                for i in 0..mentions.len() {
819                    for j in (i + 1)..mentions.len() {
820                        if let (Some(box_a), Some(box_b)) =
821                            (self.boxes.get(&mentions[i]), self.boxes.get(&mentions[j]))
822                        {
823                            let box_a_embed = box_a.to_box();
824                            let box_b_embed = box_b.to_box();
825                            let score = box_a_embed.coreference_score(&box_b_embed);
826                            positive_scores.push(score);
827                            if score > 0.01 {
828                                overlapping_pairs += 1;
829                            }
830                            total_pairs += 1;
831                        }
832                    }
833                }
834            }
835
836            // Negative pairs
837            for i in 0..example.chains.len() {
838                for j in (i + 1)..example.chains.len() {
839                    let chain_i: Vec<usize> =
840                        example.chains[i].mentions.iter().map(|m| m.start).collect();
841                    let chain_j: Vec<usize> =
842                        example.chains[j].mentions.iter().map(|m| m.start).collect();
843                    for &id_i in &chain_i {
844                        for &id_j in &chain_j {
845                            if let (Some(box_a), Some(box_b)) =
846                                (self.boxes.get(&id_i), self.boxes.get(&id_j))
847                            {
848                                let box_a_embed = box_a.to_box();
849                                let box_b_embed = box_b.to_box();
850                                let score = box_a_embed.coreference_score(&box_b_embed);
851                                negative_scores.push(score);
852                            }
853                        }
854                    }
855                }
856            }
857        }
858
859        let avg_positive = if !positive_scores.is_empty() {
860            positive_scores.iter().sum::<f32>() / positive_scores.len() as f32
861        } else {
862            0.0
863        };
864
865        let avg_negative = if !negative_scores.is_empty() {
866            negative_scores.iter().sum::<f32>() / negative_scores.len() as f32
867        } else {
868            0.0
869        };
870
871        let overlap_rate = if total_pairs > 0 {
872            overlapping_pairs as f32 / total_pairs as f32
873        } else {
874            0.0
875        };
876
877        (avg_positive, avg_negative, overlap_rate)
878    }
879
880    /// Evaluate coreference accuracy on a test set.
881    ///
882    /// Returns (accuracy, precision, recall, f1) where:
883    /// - Accuracy: fraction of pairs correctly classified
884    /// - Precision: fraction of predicted positives that are correct
885    /// - Recall: fraction of true positives that are predicted
886    /// - F1: harmonic mean of precision and recall
887    ///
888    /// **Note**: This is a simple pair-wise evaluation. For standard coreference metrics
889    /// (MUC, B³, CEAF, LEA, BLANC, CoNLL F1), use `evaluate_standard_metrics()` instead.
890    pub fn evaluate(&self, examples: &[TrainingExample], threshold: f32) -> (f32, f32, f32, f32) {
891        let mut true_positives = 0;
892        let mut false_positives = 0;
893        let mut false_negatives = 0;
894        let mut total_pairs = 0;
895
896        for example in examples {
897            // Build positive pairs (should corefer)
898            let mut positive_pairs = Vec::new();
899            for chain in &example.chains {
900                let mentions: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
901                for i in 0..mentions.len() {
902                    for j in (i + 1)..mentions.len() {
903                        positive_pairs.push((mentions[i], mentions[j]));
904                    }
905                }
906            }
907
908            // Build negative pairs (shouldn't corefer)
909            let mut negative_pairs = Vec::new();
910            for i in 0..example.chains.len() {
911                for j in (i + 1)..example.chains.len() {
912                    let chain_i: Vec<usize> =
913                        example.chains[i].mentions.iter().map(|m| m.start).collect();
914                    let chain_j: Vec<usize> =
915                        example.chains[j].mentions.iter().map(|m| m.start).collect();
916                    for &id_i in &chain_i {
917                        for &id_j in &chain_j {
918                            negative_pairs.push((id_i, id_j));
919                        }
920                    }
921                }
922            }
923
924            // Evaluate positive pairs
925            for &(id_a, id_b) in &positive_pairs {
926                total_pairs += 1;
927                if let (Some(box_a), Some(box_b)) = (self.boxes.get(&id_a), self.boxes.get(&id_b)) {
928                    let box_a_embed = box_a.to_box();
929                    let box_b_embed = box_b.to_box();
930                    let score = box_a_embed.coreference_score(&box_b_embed);
931                    if score >= threshold {
932                        true_positives += 1;
933                    } else {
934                        false_negatives += 1;
935                    }
936                } else {
937                    // Missing boxes - count as false negative (model can't predict)
938                    false_negatives += 1;
939                }
940            }
941
942            // Evaluate negative pairs
943            for &(id_a, id_b) in &negative_pairs {
944                total_pairs += 1;
945                if let (Some(box_a), Some(box_b)) = (self.boxes.get(&id_a), self.boxes.get(&id_b)) {
946                    let box_a_embed = box_a.to_box();
947                    let box_b_embed = box_b.to_box();
948                    let score = box_a_embed.coreference_score(&box_b_embed);
949                    if score >= threshold {
950                        false_positives += 1;
951                    }
952                    // If score < threshold, it's a true negative (correctly predicted as non-coreferent)
953                }
954                // If boxes are missing, we can't evaluate - don't count as error
955            }
956        }
957
958        // Compute metrics
959        let precision = if true_positives + false_positives > 0 {
960            true_positives as f32 / (true_positives + false_positives) as f32
961        } else {
962            0.0
963        };
964
965        let recall = if true_positives + false_negatives > 0 {
966            true_positives as f32 / (true_positives + false_negatives) as f32
967        } else {
968            0.0
969        };
970
971        let f1 = if precision + recall > 0.0 {
972            2.0 * precision * recall / (precision + recall)
973        } else {
974            0.0
975        };
976
977        let accuracy = if total_pairs > 0 {
978            (true_positives + (total_pairs - true_positives - false_positives - false_negatives))
979                as f32
980                / total_pairs as f32
981        } else {
982            0.0
983        };
984
985        (accuracy, precision, recall, f1)
986    }
987
988    /// Save trained boxes to a file (JSON format).
989    ///
990    /// # Arguments
991    ///
992    /// * `path` - File path to save to
993    ///
994    /// # Example
995    ///
996    /// ```rust,ignore
997    /// trainer.save_boxes("trained_boxes.json")?;
998    /// ```
999    pub fn save_boxes(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
1000        use std::fs::File;
1001        use std::io::Write;
1002
1003        let serialized = serde_json::to_string_pretty(&self.boxes)?;
1004        let mut file = File::create(path)?;
1005        file.write_all(serialized.as_bytes())?;
1006        Ok(())
1007    }
1008
1009    /// Load trained boxes from a file (JSON format).
1010    ///
1011    /// # Arguments
1012    ///
1013    /// * `path` - File path to load from
1014    /// * `dim` - Expected embedding dimension
1015    ///
1016    /// # Example
1017    ///
1018    /// ```rust,ignore
1019    /// let boxes = BoxEmbeddingTrainer::load_boxes("trained_boxes.json", 32)?;
1020    /// ```
1021    pub fn load_boxes(
1022        path: &str,
1023        dim: usize,
1024    ) -> Result<HashMap<usize, TrainableBox>, Box<dyn std::error::Error>> {
1025        use std::fs::File;
1026        use std::io::Read;
1027
1028        let mut file = File::open(path)?;
1029        let mut contents = String::new();
1030        file.read_to_string(&mut contents)?;
1031        let boxes: HashMap<usize, TrainableBox> = serde_json::from_str(&contents)?;
1032
1033        // Verify dimensions
1034        for (id, box_embedding) in &boxes {
1035            if box_embedding.dim != dim {
1036                return Err(format!(
1037                    "Box for entity {} has dimension {}, expected {}",
1038                    id, box_embedding.dim, dim
1039                )
1040                .into());
1041            }
1042        }
1043
1044        Ok(boxes)
1045    }
1046
1047    /// Evaluate using standard coreference metrics (MUC, B³, CEAF, LEA, BLANC, CoNLL F1).
1048    ///
1049    /// This converts the trained boxes into coreference chains and evaluates using
1050    /// the standard metrics used in coreference research.
1051    ///
1052    /// # Arguments
1053    ///
1054    /// * `examples` - Test examples with gold coreference chains
1055    /// * `threshold` - Coreference threshold for box overlap
1056    ///
1057    /// # Returns
1058    ///
1059    /// `CorefEvaluation` with all standard metrics
1060    ///
1061    /// Requires `analysis` (or legacy `eval`) feature for access to standard coref metrics.
1062    #[cfg(any(feature = "analysis", feature = "eval"))]
1063    pub fn evaluate_standard_metrics(
1064        &self,
1065        examples: &[TrainingExample],
1066        threshold: f32,
1067    ) -> crate::eval::coref_metrics::CorefEvaluation {
1068        use crate::backends::box_embeddings::BoxCorefConfig;
1069        use crate::eval::coref_metrics::CorefEvaluation;
1070        use crate::eval::coref_resolver::BoxCorefResolver;
1071
1072        let mut all_predicted_chains = Vec::new();
1073        let mut all_gold_chains = Vec::new();
1074
1075        for example in examples {
1076            // Collect gold chains
1077            all_gold_chains.extend(example.chains.clone());
1078
1079            // Get entities from example
1080            let entities = &example.entities;
1081
1082            // Get boxes for entities (or create default boxes if missing)
1083            let mut boxes = Vec::new();
1084            for entity in entities {
1085                if let Some(trainable_box) = self.boxes.get(&entity.start) {
1086                    boxes.push(trainable_box.to_box());
1087                } else {
1088                    // Missing box - create a small default box
1089                    let center = vec![0.0; self.dim];
1090                    boxes.push(crate::backends::box_embeddings::BoxEmbedding::from_vector(
1091                        &center, 0.1,
1092                    ));
1093                }
1094            }
1095
1096            // Resolve coreference using boxes
1097            let box_config = BoxCorefConfig {
1098                coreference_threshold: threshold,
1099                ..Default::default()
1100            };
1101            let resolver = BoxCorefResolver::new(box_config);
1102            let resolved_entities = resolver.resolve_with_boxes(entities, &boxes);
1103
1104            // Convert resolved entities to chains
1105            let predicted_chains = anno_core::core::coref::entities_to_chains(&resolved_entities);
1106            all_predicted_chains.extend(predicted_chains);
1107        }
1108
1109        // Compute standard metrics
1110        CorefEvaluation::compute(&all_predicted_chains, &all_gold_chains)
1111    }
1112}
1113
1114/// Split training examples into train/validation sets.
1115///
1116/// # Arguments
1117///
1118/// * `examples` - All training examples
1119/// * `val_ratio` - Fraction of examples to use for validation (0.0-1.0)
1120///
1121/// # Returns
1122///
1123/// (train_examples, val_examples)
1124pub fn split_train_val(
1125    examples: &[TrainingExample],
1126    val_ratio: f32,
1127) -> (Vec<TrainingExample>, Vec<TrainingExample>) {
1128    let val_size = (examples.len() as f32 * val_ratio) as usize;
1129    let mut shuffled: Vec<TrainingExample> = examples.to_vec();
1130
1131    // Simple shuffle
1132    for i in (1..shuffled.len()).rev() {
1133        let j = (simple_random() * (i + 1) as f32) as usize;
1134        shuffled.swap(i, j);
1135    }
1136
1137    let val_examples = shuffled.split_off(val_size);
1138    (shuffled, val_examples)
1139}
1140
1141// =============================================================================
1142// Loss and Gradient Computation
1143// =============================================================================
1144
1145/// Compute loss for a pair of boxes.
1146fn compute_pair_loss(
1147    box_a: &TrainableBox,
1148    box_b: &TrainableBox,
1149    is_positive: bool,
1150    config: &TrainingConfig,
1151) -> f32 {
1152    let box_a_embed = box_a.to_box();
1153    let box_b_embed = box_b.to_box();
1154
1155    if is_positive {
1156        // Positive pair: maximize conditional probability
1157        let p_a_b = box_a_embed.conditional_probability(&box_b_embed);
1158        let p_b_a = box_b_embed.conditional_probability(&box_a_embed);
1159
1160        // Clamp probabilities to avoid log(0)
1161        let p_a_b = p_a_b.max(1e-8);
1162        let p_b_a = p_b_a.max(1e-8);
1163
1164        // Use symmetric score: min of both conditional probabilities
1165        // This ensures both boxes must overlap significantly
1166        let min_prob = p_a_b.min(p_b_a);
1167        let neg_log_prob = -min_prob.ln();
1168
1169        // Also add penalty if boxes are too far apart (encourages movement)
1170        let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
1171        let distance_penalty = if vol_intersection < 1e-10 {
1172            // Boxes don't overlap - add distance penalty
1173            let center_a = box_a_embed.center();
1174            let center_b = box_b_embed.center();
1175            let dist: f32 = center_a
1176                .iter()
1177                .zip(center_b.iter())
1178                .map(|(a, b)| (a - b).powi(2))
1179                .sum::<f32>()
1180                .sqrt();
1181            0.3 * dist // Stronger penalty to encourage boxes to move closer
1182        } else {
1183            // Even if overlapping, encourage more overlap
1184            // Penalize if overlap is too small
1185            let vol_a = box_a_embed.volume();
1186            let vol_b = box_b_embed.volume();
1187            let overlap_ratio = vol_intersection / vol_a.min(vol_b).max(1e-10);
1188            if overlap_ratio < 0.5 {
1189                // Encourage more overlap
1190                0.1 * (0.5 - overlap_ratio)
1191            } else {
1192                0.0
1193            }
1194        };
1195
1196        // Regularization: penalize large volumes
1197        let vol_a = box_a_embed.volume();
1198        let vol_b = box_b_embed.volume();
1199        // Light regularization to prevent boxes from growing too large
1200        let reg = config.regularization * 1.0 * (vol_a + vol_b);
1201
1202        (neg_log_prob + reg + distance_penalty).max(0.0)
1203    } else {
1204        // Negative pair: enforce separation
1205        // We want conditional probability to be LOW (boxes should be disjoint)
1206        let p_a_b = box_a_embed.conditional_probability(&box_b_embed);
1207        let p_b_a = box_b_embed.conditional_probability(&box_a_embed);
1208
1209        // Use max of both conditional probabilities
1210        let max_prob = p_a_b.max(p_b_a);
1211
1212        // Loss: penalize high conditional probability
1213        // Use hinge loss instead of quadratic for smoother gradients
1214        let margin_loss = if max_prob > config.margin {
1215            // Stronger quadratic penalty for exceeding margin
1216            let excess = max_prob - config.margin;
1217            excess.powi(2) * (1.0 + excess * 2.0) // Exponential scaling
1218        } else {
1219            0.0 // No loss if already below margin (good!)
1220        };
1221
1222        // Add extra penalty for very high probabilities (exponential decay)
1223        // Note: This is currently not used in the loss calculation but kept for future use
1224        let _high_prob_penalty = if max_prob > 0.1 {
1225            (max_prob - 0.1).powi(2) * 0.5 // Extra penalty for very high probabilities
1226        } else {
1227            0.0
1228        };
1229
1230        // Add extra penalty if boxes overlap significantly
1231        let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
1232        let vol_a = box_a_embed.volume();
1233        let vol_b = box_b_embed.volume();
1234        let overlap_penalty = if vol_intersection > 1e-10 {
1235            // Boxes overlap - add strong penalty (more aggressive for higher overlap)
1236            let overlap_ratio = vol_intersection / vol_a.min(vol_b).max(1e-10);
1237            // Exponential penalty for high overlap
1238            if overlap_ratio > 0.5 {
1239                4.0 * overlap_ratio * overlap_ratio // Stronger quadratic penalty
1240            } else if overlap_ratio > 0.3 {
1241                3.0 * overlap_ratio // Stronger linear penalty for moderate overlap
1242            } else {
1243                2.5 * overlap_ratio // Linear penalty for low overlap
1244            }
1245        } else {
1246            0.0
1247        };
1248
1249        // Base loss: only penalize if probability is significant (above a threshold)
1250        // Don't penalize tiny probabilities - they're fine
1251        let base_loss = if max_prob > 0.01 {
1252            max_prob * 0.2 // Stronger penalty for significant probabilities
1253        } else {
1254            0.0 // No penalty for tiny probabilities
1255        };
1256
1257        // Adaptive penalty: stronger for very high probabilities
1258        // Use exponential scaling for probabilities > 0.1, moderate for > 0.05
1259        let adaptive_penalty = if max_prob > 0.1 {
1260            // Exponential penalty for very high probabilities (stronger scaling)
1261            let prob_excess = max_prob - 0.1;
1262            prob_excess.powi(2) * (3.0 + prob_excess * 7.0) // Stronger exponential scaling
1263        } else if max_prob > 0.05 {
1264            // Moderate penalty for medium-high probabilities
1265            (max_prob - 0.05).powi(2) * 1.5 // Stronger
1266        } else if max_prob > 0.02 {
1267            // Light penalty for low-medium probabilities
1268            (max_prob - 0.02).powi(2) * 0.5
1269        } else {
1270            0.0
1271        };
1272
1273        config.negative_weight * (margin_loss + overlap_penalty + base_loss + adaptive_penalty)
1274    }
1275}
1276
1277/// Compute analytical gradients for a pair of boxes.
1278fn compute_analytical_gradients(
1279    box_a: &TrainableBox,
1280    box_b: &TrainableBox,
1281    is_positive: bool,
1282    config: &TrainingConfig,
1283) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
1284    let box_a_embed = box_a.to_box();
1285    let box_b_embed = box_b.to_box();
1286    let dim = box_a.dim;
1287
1288    // Initialize gradients
1289    let mut grad_mu_a = vec![0.0; dim];
1290    let mut grad_delta_a = vec![0.0; dim];
1291    let mut grad_mu_b = vec![0.0; dim];
1292    let mut grad_delta_b = vec![0.0; dim];
1293
1294    // Compute intersection volume and individual volumes
1295    let vol_a = box_a_embed.volume();
1296    let vol_b = box_b_embed.volume();
1297    let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
1298
1299    if is_positive {
1300        // Positive pair: L = -log(P(A|B)) - log(P(B|A)) + reg
1301        // P(A|B) = Vol(A ∩ B) / Vol(B)
1302        // P(B|A) = Vol(A ∩ B) / Vol(A)
1303
1304        let p_a_b = if vol_b > 0.0 {
1305            vol_intersection / vol_b
1306        } else {
1307            0.0
1308        };
1309        let p_b_a = if vol_a > 0.0 {
1310            vol_intersection / vol_a
1311        } else {
1312            0.0
1313        };
1314
1315        // Clamp to avoid division by zero
1316        let p_a_b = p_a_b.max(1e-8);
1317        let p_b_a = p_b_a.max(1e-8);
1318
1319        // Gradients through -log(P(A|B)) = -log(Vol_intersection) + log(Vol_B)
1320        // For positive pairs, we want to maximize overlap
1321        // If boxes don't overlap, we need gradients to move them together
1322
1323        // Check if boxes overlap
1324        let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
1325        let has_overlap = vol_intersection > 1e-10;
1326
1327        if !has_overlap {
1328            // Boxes don't overlap - add very strong gradient to move centers closer
1329            let center_a = box_a_embed.center();
1330            let center_b = box_b_embed.center();
1331            let center_dist = center_a
1332                .iter()
1333                .zip(center_b.iter())
1334                .map(|(a, b)| (a - b).powi(2))
1335                .sum::<f32>()
1336                .sqrt();
1337
1338            for i in 0..dim {
1339                let diff = center_b[i] - center_a[i];
1340                // Adaptive strength based on distance - stronger when far apart
1341                let distance_factor = (center_dist / dim as f32).clamp(0.5, 2.0);
1342                let attraction_strength = 4.0 * distance_factor; // Stronger when far apart
1343
1344                grad_mu_a[i] += attraction_strength * diff;
1345                grad_mu_b[i] += -attraction_strength * diff;
1346
1347                // Strong increase in box sizes to help them overlap
1348                grad_delta_a[i] += 0.5 * distance_factor; // Stronger growth when far apart
1349                grad_delta_b[i] += 0.5 * distance_factor;
1350            }
1351        }
1352
1353        for i in 0..dim {
1354            // Gradient w.r.t. box A
1355            // ∂(-log(P(A|B)))/∂δ_A = -1/P(A|B) * ∂P(A|B)/∂δ_A
1356            // ∂P(A|B)/∂δ_A = (1/Vol_B) * ∂Vol_intersection/∂δ_A
1357            // ∂Vol_intersection/∂δ_A = Vol_intersection (if boxes overlap in dim i)
1358
1359            let overlap_i = if box_a_embed.min[i] < box_b_embed.max[i]
1360                && box_b_embed.min[i] < box_a_embed.max[i]
1361            {
1362                // Boxes overlap in dimension i
1363                let min_overlap = box_a_embed.min[i].max(box_b_embed.min[i]);
1364                let max_overlap = box_a_embed.max[i].min(box_b_embed.max[i]);
1365                (max_overlap - min_overlap).max(0.0)
1366            } else {
1367                0.0
1368            };
1369
1370            if overlap_i > 0.0 && vol_intersection > 0.0 {
1371                // Gradient through intersection volume
1372                // When boxes overlap, focus on improving overlap ratio (not just growing boxes)
1373                let overlap_ratio_a = vol_intersection / vol_a.max(1e-10);
1374                let overlap_ratio_b = vol_intersection / vol_b.max(1e-10);
1375
1376                // If overlap ratio is low, encourage growth; if high, maintain
1377                // Adaptive growth based on current overlap - more aggressive
1378                if overlap_ratio_a < 0.15 {
1379                    // Extremely low overlap - encourage very strong growth
1380                    grad_delta_a[i] += 0.35;
1381                } else if overlap_ratio_a < 0.3 {
1382                    // Very low overlap - encourage extremely strong growth
1383                    grad_delta_a[i] += 0.3;
1384                } else if overlap_ratio_a < 0.5 {
1385                    // Low overlap - encourage very strong growth
1386                    grad_delta_a[i] += 0.2;
1387                } else if overlap_ratio_a < 0.7 {
1388                    // Moderate overlap - strong growth
1389                    grad_delta_a[i] += 0.1;
1390                } else if overlap_ratio_a < 0.85 {
1391                    // Good overlap - small growth
1392                    grad_delta_a[i] += 0.05;
1393                }
1394                // If overlap_ratio_a >= 0.85, don't grow (excellent)
1395
1396                if overlap_ratio_b < 0.15 {
1397                    // Extremely low overlap - encourage very strong growth
1398                    grad_delta_b[i] += 0.35;
1399                } else if overlap_ratio_b < 0.3 {
1400                    // Very low overlap - encourage extremely strong growth
1401                    grad_delta_b[i] += 0.3;
1402                } else if overlap_ratio_b < 0.5 {
1403                    // Low overlap - encourage very strong growth
1404                    grad_delta_b[i] += 0.2;
1405                } else if overlap_ratio_b < 0.7 {
1406                    // Moderate overlap - strong growth
1407                    grad_delta_b[i] += 0.1;
1408                } else if overlap_ratio_b < 0.85 {
1409                    // Good overlap - small growth
1410                    grad_delta_b[i] += 0.05;
1411                }
1412
1413                // Gradient through conditional probability (main signal) - adaptive strength
1414                // Stronger when overlap is low, gentler when overlap is good
1415                let gradient_strength = if overlap_ratio_a < 0.1 {
1416                    1.7 // Extremely strong when overlap is extremely low
1417                } else if overlap_ratio_a < 0.2 {
1418                    1.6 // Extremely strong when overlap is very low
1419                } else if overlap_ratio_a < 0.4 {
1420                    1.4 // Very strong when overlap is low
1421                } else if overlap_ratio_a < 0.6 {
1422                    1.1 // Strong when overlap is moderate
1423                } else {
1424                    0.6 // Gentle when overlap is good
1425                };
1426
1427                let grad_vol_intersection_delta_a = vol_intersection * 0.5 * gradient_strength;
1428                let grad_p_a_b_delta_a = grad_vol_intersection_delta_a / vol_b.max(1e-8);
1429                grad_delta_a[i] += -grad_p_a_b_delta_a / p_a_b.max(1e-8) * gradient_strength;
1430
1431                let grad_vol_intersection_delta_b = vol_intersection * 0.5 * gradient_strength;
1432                let grad_p_b_a_delta_b = grad_vol_intersection_delta_b / vol_a.max(1e-8);
1433                grad_delta_b[i] += -grad_p_b_a_delta_b / p_b_a.max(1e-8) * gradient_strength;
1434            } else {
1435                // Boxes don't overlap in this dimension - extremely strong growth to achieve overlap
1436                grad_delta_a[i] += 0.3; // Extremely strong growth for box A
1437                grad_delta_b[i] += 0.3; // Extremely strong growth for box B
1438            }
1439
1440            // Regularization gradient: ∂(λ * Vol)/∂δ = λ * Vol
1441            // Light regularization to prevent boxes from growing too large
1442            // Apply to both positive and negative pairs (but lighter for positives)
1443            grad_delta_a[i] += config.regularization * 1.0 * vol_a; // Lighter regularization
1444            grad_delta_b[i] += config.regularization * 1.0 * vol_b;
1445        }
1446    } else {
1447        // Negative pair: L = max(0, margin - max(P(A|B), P(B|A))) * λ_neg + overlap_penalty
1448        let p_a_b = if vol_b > 0.0 {
1449            vol_intersection / vol_b
1450        } else {
1451            0.0
1452        };
1453        let p_b_a = if vol_a > 0.0 {
1454            vol_intersection / vol_a
1455        } else {
1456            0.0
1457        };
1458        let max_prob = p_a_b.max(p_b_a);
1459
1460        // Always apply gradients for negative pairs (they should always be separated)
1461        // Don't check margin - always try to minimize conditional probability
1462        for i in 0..dim {
1463            // Check if boxes overlap in this dimension
1464            let overlap_i = if box_a_embed.min[i] < box_b_embed.max[i]
1465                && box_b_embed.min[i] < box_a_embed.max[i]
1466            {
1467                let min_overlap = box_a_embed.min[i].max(box_b_embed.min[i]);
1468                let max_overlap = box_a_embed.max[i].min(box_b_embed.max[i]);
1469                (max_overlap - min_overlap).max(0.0)
1470            } else {
1471                0.0
1472            };
1473
1474            if overlap_i > 0.0 {
1475                // Boxes overlap - strong gradient to separate
1476                // Move centers apart
1477                let center_a = box_a_embed.center();
1478                let center_b = box_b_embed.center();
1479                let diff = center_b[i] - center_a[i];
1480
1481                // Gradient to push boxes apart (adaptive strength based on overlap)
1482                // Stronger separation when overlap is high
1483                let overlap_factor =
1484                    (overlap_i / (box_a_embed.max[i] - box_a_embed.min[i]).max(1e-6)).min(1.0);
1485                let separation_strength = 1.5 + overlap_factor * 2.0; // 1.5 to 3.5 based on overlap (even stronger)
1486                if diff.abs() > 1e-6 {
1487                    grad_mu_a[i] += -config.negative_weight * separation_strength * diff;
1488                    grad_mu_b[i] += config.negative_weight * separation_strength * diff;
1489                } else {
1490                    // Boxes are on top of each other - push in random direction (very strong)
1491                    grad_mu_a[i] += -config.negative_weight * separation_strength * 2.5;
1492                    grad_mu_b[i] += config.negative_weight * separation_strength * 2.5;
1493                }
1494
1495                // Shrink boxes to reduce overlap (adaptive strength based on overlap amount)
1496                // More aggressive shrinking for higher overlap
1497                let overlap_ratio_dim =
1498                    overlap_i / (box_a_embed.max[i] - box_a_embed.min[i]).max(1e-6);
1499                let shrink_strength = if overlap_ratio_dim > 0.7 {
1500                    0.7 // Extremely strong shrink for very high overlap
1501                } else if overlap_ratio_dim > 0.5 {
1502                    0.6 // Very strong shrink for high overlap
1503                } else if overlap_ratio_dim > 0.3 {
1504                    0.5 // Strong shrink for moderate overlap
1505                } else {
1506                    0.35 // Moderate shrink for low overlap
1507                };
1508                grad_delta_a[i] += -config.negative_weight * shrink_strength;
1509                grad_delta_b[i] += -config.negative_weight * shrink_strength;
1510            } else {
1511                // Boxes don't overlap - NO gradient (let positive pairs grow)
1512                // Don't shrink when already separated
1513            }
1514
1515            // Gradient through overlap_penalty: ∂(2.0 * overlap_ratio)/∂δ
1516            // Overlap ratio = vol_intersection / min(vol_a, vol_b)
1517            // ONLY penalize if actually overlapping - stronger penalty for higher overlap
1518            if overlap_i > 0.0 && vol_intersection > 1e-10 {
1519                let min_vol = vol_a.min(vol_b);
1520                let overlap_ratio = vol_intersection / min_vol.max(1e-10);
1521                // Stronger penalty for higher overlap ratios (more aggressive)
1522                // Match the loss function: 2.5x/3.0x/4.0x based on overlap
1523                let penalty_strength = if overlap_ratio > 0.5 {
1524                    0.4 + overlap_ratio * 0.6 // 0.7 to 1.0 for very high overlap
1525                } else if overlap_ratio > 0.3 {
1526                    0.3 + overlap_ratio * 0.5 // 0.39 to 0.45 for moderate-high overlap
1527                } else {
1528                    0.2 + overlap_ratio * 0.4 // 0.2 to 0.32 for low overlap
1529                };
1530                let penalty_multiplier = if overlap_ratio > 0.5 {
1531                    4.0
1532                } else if overlap_ratio > 0.3 {
1533                    3.0
1534                } else {
1535                    2.5
1536                };
1537                grad_delta_a[i] +=
1538                    config.negative_weight * penalty_multiplier * overlap_ratio * penalty_strength;
1539                grad_delta_b[i] +=
1540                    config.negative_weight * penalty_multiplier * overlap_ratio * penalty_strength;
1541            }
1542
1543            // Gradient through base_loss and margin_loss: minimize max_prob
1544            // ∂(0.2 * max_prob)/∂δ = 0.2 * ∂max_prob/∂δ
1545            // ∂((max_prob - margin)^2)/∂δ = 2 * (max_prob - margin) * ∂max_prob/∂δ
1546            // Always apply these gradients (not just when overlapping) to keep scores low
1547            if p_a_b >= p_b_a {
1548                // p_a_b is the max, minimize it
1549                if overlap_i > 0.0 && vol_intersection > 1e-10 {
1550                    let grad_vol_intersection_delta_a = vol_intersection * 0.4;
1551                    let grad_p_a_b_delta_a = grad_vol_intersection_delta_a / vol_b.max(1e-8);
1552                    // Positive gradient on delta = shrink box (reduce intersection)
1553                    grad_delta_a[i] += config.negative_weight * 0.2 * grad_p_a_b_delta_a;
1554
1555                    // Also add margin loss gradient if active
1556                    if max_prob > config.margin {
1557                        let excess = max_prob - config.margin;
1558                        let margin_grad = 2.0 * excess * (1.0 + excess * 2.0) * grad_p_a_b_delta_a
1559                            + 2.0 * excess.powi(2) * 2.0 * grad_p_a_b_delta_a; // Exponential scaling
1560                        grad_delta_a[i] += config.negative_weight * margin_grad;
1561                    }
1562
1563                    // Adaptive penalty gradient for very high probabilities (stronger for higher probs)
1564                    if max_prob > 0.1 {
1565                        // Exponential scaling: stronger penalty for very high probabilities
1566                        let prob_excess = max_prob - 0.1;
1567                        let adaptive_grad =
1568                            2.0 * prob_excess * grad_p_a_b_delta_a * (3.0 + prob_excess * 7.0); // Stronger
1569                        grad_delta_a[i] += config.negative_weight * adaptive_grad;
1570                    } else if max_prob > 0.05 {
1571                        // Moderate penalty gradient for medium-high probabilities
1572                        let prob_excess = max_prob - 0.05;
1573                        let adaptive_grad = 2.0 * prob_excess * grad_p_a_b_delta_a * 1.5; // Stronger
1574                        grad_delta_a[i] += config.negative_weight * adaptive_grad;
1575                    } else if max_prob > 0.02 {
1576                        // Light penalty gradient for low-medium probabilities
1577                        let prob_excess = max_prob - 0.02;
1578                        let adaptive_grad = 2.0 * prob_excess * grad_p_a_b_delta_a * 0.5;
1579                        grad_delta_a[i] += config.negative_weight * adaptive_grad;
1580                    }
1581                }
1582                // Don't add extra shrink when not overlapping - let positive pairs grow
1583            } else {
1584                // p_b_a is the max, minimize it
1585                if overlap_i > 0.0 && vol_intersection > 1e-10 {
1586                    let grad_vol_intersection_delta_b = vol_intersection * 0.4;
1587                    let grad_p_b_a_delta_b = grad_vol_intersection_delta_b / vol_a.max(1e-8);
1588                    // Positive gradient on delta = shrink box
1589                    grad_delta_b[i] += config.negative_weight * 0.25 * grad_p_b_a_delta_b; // Slightly stronger
1590
1591                    // Also add margin loss gradient if active (stronger)
1592                    if max_prob > config.margin {
1593                        let excess = max_prob - config.margin;
1594                        let margin_grad = 2.0 * excess * (1.0 + excess * 2.0) * grad_p_b_a_delta_b
1595                            + 2.0 * excess.powi(2) * 2.0 * grad_p_b_a_delta_b; // Exponential scaling
1596                        grad_delta_b[i] += config.negative_weight * margin_grad;
1597                    }
1598
1599                    // Adaptive penalty gradient for very high probabilities (stronger for higher probs)
1600                    if max_prob > 0.1 {
1601                        // Exponential scaling: stronger penalty for very high probabilities
1602                        let prob_excess = max_prob - 0.1;
1603                        let adaptive_grad =
1604                            2.0 * prob_excess * grad_p_b_a_delta_b * (2.0 + prob_excess * 5.0);
1605                        grad_delta_b[i] += config.negative_weight * adaptive_grad;
1606                    } else if max_prob > 0.05 {
1607                        // Moderate penalty gradient for medium-high probabilities
1608                        let prob_excess = max_prob - 0.05;
1609                        let adaptive_grad = 2.0 * prob_excess * grad_p_b_a_delta_b * 1.0;
1610                        grad_delta_b[i] += config.negative_weight * adaptive_grad;
1611                    }
1612                }
1613                // Don't add extra shrink when not overlapping - let positive pairs grow
1614            }
1615        }
1616    }
1617
1618    // Clip gradients to prevent explosion
1619    for grad in &mut grad_mu_a {
1620        *grad = grad.clamp(-10.0_f32, 10.0_f32);
1621    }
1622    for grad in &mut grad_delta_a {
1623        *grad = grad.clamp(-10.0_f32, 10.0_f32);
1624    }
1625    for grad in &mut grad_mu_b {
1626        *grad = grad.clamp(-10.0_f32, 10.0_f32);
1627    }
1628    for grad in &mut grad_delta_b {
1629        *grad = grad.clamp(-10.0_f32, 10.0_f32);
1630    }
1631
1632    (grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b)
1633}
1634
1635/// Sample negative pairs using self-adversarial sampling.
1636fn sample_self_adversarial_negatives(
1637    negative_pairs: &[(usize, usize)],
1638    boxes: &HashMap<usize, TrainableBox>,
1639    num_samples: usize,
1640    temperature: f32,
1641) -> Vec<usize> {
1642    // Compute scores for all negative pairs
1643    let mut scores: Vec<(usize, f32)> = negative_pairs
1644        .iter()
1645        .enumerate()
1646        .filter_map(|(idx, &(id_a, id_b))| {
1647            if let (Some(box_a), Some(box_b)) = (boxes.get(&id_a), boxes.get(&id_b)) {
1648                let box_a_embed = box_a.to_box();
1649                let box_b_embed = box_b.to_box();
1650                let score = box_a_embed.coreference_score(&box_b_embed);
1651                Some((idx, score / temperature))
1652            } else {
1653                None
1654            }
1655        })
1656        .collect();
1657
1658    // Sort by score (descending) - higher scores are "harder" negatives
1659    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1660
1661    // Sample top-k (hardest negatives)
1662    scores
1663        .into_iter()
1664        .take(num_samples)
1665        .map(|(idx, _)| idx)
1666        .collect()
1667}
1668
/// Get learning rate with warmup and cosine decay.
///
/// Schedule: linear warmup from `0.1 * base_lr` up to `base_lr` over
/// `warmup_epochs`, then cosine decay from `base_lr` down to `0.1 * base_lr`
/// by `total_epochs`. Past `total_epochs` the rate stays at the floor.
///
/// # Arguments
/// * `epoch` - current epoch (0-based)
/// * `total_epochs` - total number of training epochs
/// * `base_lr` - peak learning rate, reached at the end of warmup
/// * `warmup_epochs` - number of linear-warmup epochs
fn get_learning_rate(epoch: usize, total_epochs: usize, base_lr: f32, warmup_epochs: usize) -> f32 {
    let min_lr = base_lr * 0.1;
    if epoch < warmup_epochs {
        // Linear warmup: 0.1 * lr → lr
        min_lr + (base_lr - min_lr) * (epoch as f32 / warmup_epochs as f32)
    } else {
        // Cosine decay: lr → 0.1 * lr.
        // saturating_sub avoids usize underflow (a panic in debug builds)
        // when warmup_epochs >= total_epochs; clamping progress to [0, 1]
        // pins the LR at min_lr once training runs past total_epochs instead
        // of letting the cosine cycle it back up.
        let decay_epochs = total_epochs.saturating_sub(warmup_epochs).max(1) as f32;
        let progress = ((epoch - warmup_epochs) as f32 / decay_epochs).min(1.0);
        min_lr + (base_lr - min_lr) * (1.0 + (std::f32::consts::PI * progress).cos()) / 2.0
    }
}
1683
1684// =============================================================================
1685// TrainableBox AMSGrad Update
1686// =============================================================================
1687
impl TrainableBox {
    /// Update box parameters using AMSGrad optimizer.
    ///
    /// Applies one AMSGrad step to the box center (`mu`) using the persistent
    /// moments in `state`, then a separate step to the log-width parameter
    /// (`delta`). `delta` is clamped so the box width stays in [0.01, 10.0],
    /// and non-finite results are reset to safe defaults.
    ///
    /// # Arguments
    /// * `grad_mu` - gradient w.r.t. the box center, one entry per dimension
    /// * `grad_delta` - gradient w.r.t. the log-width, one entry per dimension
    /// * `state` - AMSGrad state (step counter, moments, hyperparameters)
    ///
    /// NOTE(review): the moment buffers for `delta` (`m_delta`, `v_delta`,
    /// `v_hat_delta`) are zero-initialized locals recreated on every call, so
    /// the `delta` update carries no momentum or max-second-moment history
    /// across steps — it degenerates to a per-call normalized gradient step.
    /// Persisting that state would require extending `AMSGradState`; confirm
    /// whether this is intentional.
    pub fn update_amsgrad(
        &mut self,
        grad_mu: &[f32],
        grad_delta: &[f32],
        state: &mut AMSGradState,
    ) {
        state.t += 1;
        let t = state.t as f32;

        // Update first moment (m): exponential moving average of the gradient.
        for (i, &grad) in grad_mu.iter().enumerate().take(self.dim) {
            state.m[i] = state.beta1 * state.m[i] + (1.0 - state.beta1) * grad;
        }

        // Update second moment (v) and max (v_hat).
        // AMSGrad normalizes by the running MAX of v, which guarantees a
        // non-increasing effective step size (unlike plain Adam).
        for (i, &grad) in grad_mu.iter().enumerate().take(self.dim) {
            let v_new = state.beta2 * state.v[i] + (1.0 - state.beta2) * grad * grad;
            state.v[i] = v_new;
            state.v_hat[i] = state.v_hat[i].max(v_new);
        }

        // Bias correction for first moment (AMSGrad applies no bias
        // correction to the second moment by design).
        let m_hat: Vec<f32> = state
            .m
            .iter()
            .map(|&m| m / (1.0 - state.beta1.powf(t)))
            .collect();

        // Update mu (box center).
        for (i, &m_hat_val) in m_hat.iter().enumerate().take(self.dim) {
            let update = state.lr * m_hat_val / (state.v_hat[i].sqrt() + state.epsilon);
            self.mu[i] -= update;

            // Ensure finite: guard against NaN/inf from degenerate gradients.
            if !self.mu[i].is_finite() {
                self.mu[i] = 0.0;
            }
        }

        // Similar for delta — but see NOTE(review) above: these buffers start
        // at zero each call, so no history accumulates between updates.
        let mut m_delta = vec![0.0_f32; self.dim];
        let mut v_delta = vec![0.0_f32; self.dim];
        let mut v_hat_delta = vec![0.0_f32; self.dim];

        for i in 0..self.dim {
            m_delta[i] = state.beta1 * m_delta[i] + (1.0 - state.beta1) * grad_delta[i];
            let v_new: f32 =
                state.beta2 * v_delta[i] + (1.0 - state.beta2) * grad_delta[i] * grad_delta[i];
            v_delta[i] = v_new;
            v_hat_delta[i] = v_hat_delta[i].max(v_new);
        }

        let m_hat_delta: Vec<f32> = m_delta
            .iter()
            .map(|&m| m / (1.0 - state.beta1.powf(t)))
            .collect();

        for i in 0..self.dim {
            let update = state.lr * m_hat_delta[i] / (v_hat_delta[i].sqrt() + state.epsilon);
            self.delta[i] -= update;

            // Clamp delta to reasonable range (width between 0.01 and 10.0)
            self.delta[i] = self.delta[i].clamp(0.01_f32.ln(), 10.0_f32.ln());

            // Ensure finite; 0.5.ln() corresponds to a box width of 0.5.
            if !self.delta[i].is_finite() {
                self.delta[i] = 0.5_f32.ln();
            }
        }
    }
}
1761
1762// =============================================================================
1763// Simple Random Number Generator
1764// =============================================================================
1765
/// Simple random number generator (for when rand feature is not available).
///
/// Thread-safe implementation using atomic counter to avoid unsafe static mut.
/// Returns a value in [0.0, 1.0].
fn simple_random() -> f32 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    // Per-call tick; atomic so concurrent callers never observe the same
    // (time, count) pair unless the clock collides at nanosecond resolution.
    let tick = COUNTER.fetch_add(1, Ordering::Relaxed);

    // Wall-clock nanoseconds since the epoch, falling back to the counter
    // when the clock is unavailable (e.g. time set before 1970).
    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => tick as u128,
    };

    // Mix time then counter through SipHash and map the 64-bit digest
    // onto the unit interval.
    let mut mixer = DefaultHasher::new();
    nanos.hash(&mut mixer);
    tick.hash(&mut mixer);
    (mixer.finish() as f32) / (u64::MAX as f32)
}