Skip to main content

anno/backends/box_embeddings_training/
algorithm.rs

1//! Box embedding trainer: gradient steps, AMSGrad, coreference training loop.
2
3#[allow(unused_imports)]
4use super::types::*;
5#[allow(unused_imports)]
6use super::*;
7use std::collections::HashMap;
8
/// Trainer for box embedding models.
///
/// Owns one `TrainableBox` and one `AMSGradState` per entity, keyed by the
/// entity ID (the mention start offset — see `initialize_boxes`), and drives
/// the coreference training loop.
pub struct BoxEmbeddingTrainer {
    /// Training configuration (learning rate, epochs, loss weights, ...).
    /// NOTE: `negative_weight` is temporarily overridden per epoch in `train`.
    config: TrainingConfig,
    /// Entity ID → TrainableBox mapping
    boxes: HashMap<usize, TrainableBox>,
    /// Entity ID → AMSGradState mapping (per-entity optimizer state)
    optimizer_states: HashMap<usize, AMSGradState>,
    /// Embedding dimension (length of every box's mu/delta vectors)
    dim: usize,
}
20
21impl BoxEmbeddingTrainer {
22    /// Create a new trainer.
23    ///
24    /// # Arguments
25    ///
26    /// * `config` - Training configuration
27    /// * `dim` - Embedding dimension
28    /// * `initial_embeddings` - Optional initial vector embeddings (entity_id → vector)
29    pub fn new(
30        config: TrainingConfig,
31        dim: usize,
32        initial_embeddings: Option<HashMap<usize, Vec<f32>>>,
33    ) -> Self {
34        let mut boxes = HashMap::new();
35        let mut optimizer_states = HashMap::new();
36
37        if let Some(embeddings) = initial_embeddings {
38            // Initialize from vector embeddings
39            for (entity_id, vector) in embeddings {
40                assert_eq!(vector.len(), dim);
41                let box_embedding = TrainableBox::from_vector(&vector, 0.1);
42                boxes.insert(entity_id, box_embedding.clone());
43                optimizer_states.insert(entity_id, AMSGradState::new(dim, config.learning_rate));
44            }
45        }
46
47        Self {
48            config,
49            boxes,
50            optimizer_states,
51            dim,
52        }
53    }
54
55    /// Initialize boxes from entities.
56    ///
57    /// Creates trainable boxes for all entities, either from provided
58    /// vector embeddings or random initialization.
59    ///
60    /// **Key insight**: For positive pairs (entities that corefer), initialize
61    /// boxes to overlap so gradients can flow from the start.
62    ///
63    /// # Arguments
64    ///
65    /// * `examples` - Training examples with entities and coreference chains
66    /// * `initial_embeddings` - Optional pre-computed vector embeddings (entity_id → vector)
67    ///   If provided, boxes are initialized around these vectors. If None, uses smart
68    ///   random initialization with shared centers for coreferent entities.
69    pub fn initialize_boxes(
70        &mut self,
71        examples: &[TrainingExample],
72        initial_embeddings: Option<&HashMap<usize, Vec<f32>>>,
73    ) {
74        // Collect all unique entity IDs and build coreference groups
75        let mut entity_ids = std::collections::HashSet::new();
76        let mut coref_groups: Vec<Vec<usize>> = Vec::new();
77
78        for example in examples {
79            for entity in &example.entities {
80                let entity_id = entity.start;
81                entity_ids.insert(entity_id);
82            }
83
84            // Build groups of entities that corefer
85            for chain in &example.chains {
86                let group: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
87                if group.len() > 1 {
88                    coref_groups.push(group);
89                }
90            }
91        }
92
93        // Initialize boxes
94        for &entity_id in &entity_ids {
95            // If vector embeddings provided, use them (better initialization)
96            if let Some(embeddings) = initial_embeddings {
97                if let Some(vector) = embeddings.get(&entity_id) {
98                    // Normalize vector to unit length for better initialization
99                    let norm: f32 = vector.iter().map(|&x| x * x).sum::<f32>().sqrt();
100                    let normalized: Vec<f32> = if norm > 0.0 {
101                        vector.iter().map(|&x| x / norm).collect()
102                    } else {
103                        vector.clone()
104                    };
105
106                    // Use larger initial width (0.2) when starting from vectors
107                    // This ensures boxes can overlap even if vectors are slightly different
108                    let box_embedding = TrainableBox::from_vector(&normalized, 0.2);
109                    self.boxes.insert(entity_id, box_embedding.clone());
110                    self.optimizer_states.insert(
111                        entity_id,
112                        AMSGradState::new(self.dim, self.config.learning_rate),
113                    );
114                    continue;
115                }
116            }
117
118            // Check if this entity is in a coreference group
119            let mut group_center: Option<Vec<f32>> = None;
120            let mut in_coref_group = false;
121
122            for group in &coref_groups {
123                if group.contains(&entity_id) {
124                    // Use a shared center for all entities in the group
125                    if group_center.is_none() {
126                        group_center = Some(
127                            (0..self.dim)
128                                .map(|_| (simple_random() - 0.5) * 0.3) // Smaller region for better overlap
129                                .collect(),
130                        );
131                    }
132                    in_coref_group = true;
133                    break;
134                }
135            }
136
137            // Initialize: if in coref group, use shared center; otherwise random
138            let mu = if let Some(ref center) = group_center {
139                // Add very small random offset to shared center (ensures overlap)
140                center
141                    .iter()
142                    .map(|&c| c + (simple_random() - 0.5) * 0.05) // Very small offset
143                    .collect()
144            } else {
145                // Random center, but spread out more to avoid accidental overlap
146                (0..self.dim)
147                    .map(|_| (simple_random() - 0.5) * 1.0)
148                    .collect()
149            };
150
151            // Balanced initialization: coreferent entities should overlap significantly
152            // Non-coreferent entities should be distinct and compact
153            // Initialize with larger width for coreferent entities (to ensure overlap)
154            let initial_width = if in_coref_group {
155                1.1_f32 // Good width for coreferent entities (ensures overlap but allows learning)
156            } else {
157                0.18_f32 // Small width for non-coreferent (distinct but not too small)
158            };
159            let delta: Vec<f32> = vec![initial_width.ln(); self.dim];
160            let box_embedding = TrainableBox::new(mu, delta);
161            self.boxes.insert(entity_id, box_embedding.clone());
162            self.optimizer_states.insert(
163                entity_id,
164                AMSGradState::new(self.dim, self.config.learning_rate),
165            );
166        }
167    }
168
169    /// Train on a single example.
170    fn train_example(&mut self, example: &TrainingExample, epoch: usize) -> f32 {
171        let mut total_loss = 0.0;
172        let mut num_pairs = 0;
173
174        // Update learning rate with warmup and decay
175        let current_lr = get_learning_rate(
176            epoch,
177            self.config.epochs,
178            self.config.learning_rate,
179            self.config.warmup_epochs,
180        );
181        for state in self.optimizer_states.values_mut() {
182            state.set_lr(current_lr);
183        }
184
185        // Build positive pairs (entities in same chain)
186        let mut positive_pairs = Vec::new();
187        for chain in &example.chains {
188            let mentions: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
189            for i in 0..mentions.len() {
190                for j in (i + 1)..mentions.len() {
191                    positive_pairs.push((mentions[i], mentions[j]));
192                }
193            }
194        }
195
196        // Build negative pairs (entities in different chains)
197        let mut negative_pairs = Vec::new();
198        for i in 0..example.chains.len() {
199            for j in (i + 1)..example.chains.len() {
200                let chain_i: Vec<usize> =
201                    example.chains[i].mentions.iter().map(|m| m.start).collect();
202                let chain_j: Vec<usize> =
203                    example.chains[j].mentions.iter().map(|m| m.start).collect();
204                for &id_i in &chain_i {
205                    for &id_j in &chain_j {
206                        negative_pairs.push((id_i, id_j));
207                    }
208                }
209            }
210        }
211
212        // Accumulate gradients for all pairs
213        let mut gradients: HashMap<usize, (Vec<f32>, Vec<f32>)> = HashMap::new();
214
215        // Process positive pairs
216        for &(id_a, id_b) in &positive_pairs {
217            // Clone boxes for gradient computation
218            let box_a = self.boxes.get(&id_a).cloned();
219            let box_b = self.boxes.get(&id_b).cloned();
220
221            if let (Some(box_a_ref), Some(box_b_ref)) = (box_a.as_ref(), box_b.as_ref()) {
222                let loss = compute_pair_loss(box_a_ref, box_b_ref, true, &self.config);
223                total_loss += loss;
224                num_pairs += 1;
225
226                // Compute analytical gradients
227                let (grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b) =
228                    compute_analytical_gradients(box_a_ref, box_b_ref, true, &self.config);
229
230                // Skip if gradients are invalid
231                if grad_mu_a.iter().any(|&x| !x.is_finite())
232                    || grad_delta_a.iter().any(|&x| !x.is_finite())
233                    || grad_mu_b.iter().any(|&x| !x.is_finite())
234                    || grad_delta_b.iter().any(|&x| !x.is_finite())
235                {
236                    continue;
237                }
238
239                // Accumulate gradients
240                let entry_a = gradients
241                    .entry(id_a)
242                    .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
243                for i in 0..self.dim {
244                    entry_a.0[i] += grad_mu_a[i];
245                    entry_a.1[i] += grad_delta_a[i];
246                }
247
248                let entry_b = gradients
249                    .entry(id_b)
250                    .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
251                for i in 0..self.dim {
252                    entry_b.0[i] += grad_mu_b[i];
253                    entry_b.1[i] += grad_delta_b[i];
254                }
255            }
256        }
257
258        // Process negative pairs (with self-adversarial sampling if enabled)
259        let negative_samples: Vec<(usize, usize)> =
260            if self.config.use_self_adversarial && !negative_pairs.is_empty() {
261                // Sample based on current predictions
262                let num_samples = positive_pairs.len().min(negative_pairs.len());
263                let sampled_indices = sample_self_adversarial_negatives(
264                    &negative_pairs,
265                    &self.boxes,
266                    num_samples,
267                    self.config.adversarial_temperature,
268                );
269                sampled_indices
270                    .iter()
271                    .map(|&idx| negative_pairs[idx])
272                    .collect()
273            } else {
274                // Uniform sampling
275                let num_samples = positive_pairs.len().min(negative_pairs.len());
276                negative_pairs.into_iter().take(num_samples).collect()
277            };
278
279        for &(id_a, id_b) in &negative_samples {
280            // Clone boxes for gradient computation
281            let box_a = self.boxes.get(&id_a).cloned();
282            let box_b = self.boxes.get(&id_b).cloned();
283
284            if let (Some(box_a_ref), Some(box_b_ref)) = (box_a.as_ref(), box_b.as_ref()) {
285                let loss = compute_pair_loss(box_a_ref, box_b_ref, false, &self.config);
286                total_loss += loss;
287                num_pairs += 1;
288
289                // Compute analytical gradients
290                let (grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b) =
291                    compute_analytical_gradients(box_a_ref, box_b_ref, false, &self.config);
292
293                // Skip if gradients are invalid
294                if grad_mu_a.iter().any(|&x| !x.is_finite())
295                    || grad_delta_a.iter().any(|&x| !x.is_finite())
296                    || grad_mu_b.iter().any(|&x| !x.is_finite())
297                    || grad_delta_b.iter().any(|&x| !x.is_finite())
298                {
299                    continue;
300                }
301
302                // Accumulate gradients
303                let entry_a = gradients
304                    .entry(id_a)
305                    .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
306                for i in 0..self.dim {
307                    entry_a.0[i] += grad_mu_a[i];
308                    entry_a.1[i] += grad_delta_a[i];
309                }
310
311                let entry_b = gradients
312                    .entry(id_b)
313                    .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
314                for i in 0..self.dim {
315                    entry_b.0[i] += grad_mu_b[i];
316                    entry_b.1[i] += grad_delta_b[i];
317                }
318            }
319        }
320
321        // Apply accumulated gradients using AMSGrad
322        for (entity_id, (grad_mu, grad_delta)) in gradients {
323            if let (Some(box_mut), Some(state)) = (
324                self.boxes.get_mut(&entity_id),
325                self.optimizer_states.get_mut(&entity_id),
326            ) {
327                box_mut.update_amsgrad(&grad_mu, &grad_delta, state);
328            }
329        }
330
331        if num_pairs > 0 {
332            total_loss / num_pairs as f32
333        } else {
334            0.0
335        }
336    }
337
    /// Train on a dataset with mini-batching and early stopping.
    ///
    /// Uses adaptive negative weighting: starts with low weight to learn positives,
    /// then gradually increases to separate negatives. Each epoch temporarily
    /// overrides `config.negative_weight` with the adaptive value and restores
    /// the original at the end of the epoch (before the early-stopping check).
    ///
    /// Returns the per-epoch average losses (one entry per completed epoch).
    pub fn train(&mut self, examples: &[TrainingExample]) -> Vec<f32> {
        let mut losses = Vec::new();
        let mut best_loss = f32::INFINITY;
        let mut patience_counter = 0;

        // Track score gap for adaptive weighting
        // (currently write-only: pushed each epoch but never read back).
        let mut score_gap_history = Vec::new();

        for epoch in 0..self.config.epochs {
            // Multi-stage training: focus on positives first, then negatives.
            // Overlap stats are recomputed at the start of every epoch.
            let (avg_pos, avg_neg, _) = self.get_overlap_stats(examples);
            let current_gap = avg_pos - avg_neg;
            score_gap_history.push(current_gap);

            // Determine training stage; default focus phase is the first
            // third of all epochs when not configured explicitly.
            let positive_focus_epochs = self
                .config
                .positive_focus_epochs
                .unwrap_or(self.config.epochs / 3);
            let is_positive_stage = epoch < positive_focus_epochs;

            // Calculate adaptive negative weight based on stage and performance.
            // NOTE(review): when positive_focus_epochs == epochs the `progress`
            // denominators below are 0 → inf/NaN; f32::min discards NaN and
            // clamps inf, so the weight degrades to the capped value — confirm
            // this is the intended fallback.
            let adaptive_negative_weight = if is_positive_stage {
                // Stage 1: Focus on positive learning - low negative weight but not zero
                // Still apply some negative gradients to prevent negative scores from growing too much
                // Gradually increase from 0.2 to 0.3 during positive stage
                let stage_progress = epoch as f32 / positive_focus_epochs as f32;
                self.config.negative_weight * (0.2 + stage_progress * 0.1)
            } else if avg_pos > 0.05 && avg_neg > 0.3 {
                // Stage 2: Positive learning is good but negatives are too high - aggressive separation
                // Increase negative weight more aggressively
                let progress = ((epoch - positive_focus_epochs) as f32
                    / (self.config.epochs - positive_focus_epochs) as f32)
                    .min(1.0);
                // Scale based on how bad negatives are - more aggressive
                let neg_penalty = (avg_neg / 0.4).min(1.0); // Penalty factor for high negatives (lower threshold)
                self.config.negative_weight * (0.7 + progress * 0.8 + neg_penalty * 0.4).min(2.0)
            // Up to 2.0x
            } else if avg_pos > 0.02 && current_gap > 0.0 {
                // Stage 2: Positive learning is good, can focus on separation
                // Gradually increase negative weight as gap improves
                let progress = ((epoch - positive_focus_epochs) as f32
                    / (self.config.epochs - positive_focus_epochs) as f32)
                    .min(1.0);
                self.config.negative_weight * (0.5 + progress * 0.5).min(1.0 + (current_gap / 0.1))
            } else if avg_pos < 0.01 {
                // Positive scores too low, reduce negative weight
                self.config.negative_weight * 0.3
            } else {
                // Default behavior - moderate weight, gradually increase
                let progress = ((epoch - positive_focus_epochs) as f32
                    / (self.config.epochs - positive_focus_epochs) as f32)
                    .min(1.0);
                self.config.negative_weight * (0.4 + progress * 0.4)
            };

            // Temporarily override negative weight for this epoch;
            // restored below after the diagnostics block.
            let original_negative_weight = self.config.negative_weight;
            self.config.negative_weight = adaptive_negative_weight;
            // Shuffle examples for better training (simple Fisher-Yates).
            // NOTE(review): assumes simple_random() returns values in [0, 1);
            // a return of exactly 1.0 would make j == i + 1 and panic in
            // swap — confirm simple_random's range.
            let mut shuffled_indices: Vec<usize> = (0..examples.len()).collect();
            for i in (1..shuffled_indices.len()).rev() {
                let j = (simple_random() * (i + 1) as f32) as usize;
                shuffled_indices.swap(i, j);
            }

            let mut epoch_loss = 0.0;
            let mut num_batches = 0;

            // Mini-batch training: epoch loss is the mean of per-batch means.
            for batch_start in (0..examples.len()).step_by(self.config.batch_size) {
                let batch_end = (batch_start + self.config.batch_size).min(examples.len());
                let batch_indices = &shuffled_indices[batch_start..batch_end];

                let mut batch_loss = 0.0;
                let mut batch_pairs = 0;

                // Process batch (one gradient step per example inside
                // train_example; there is no cross-example accumulation).
                for &idx in batch_indices {
                    let example = &examples[idx];
                    let loss = self.train_example(example, epoch);
                    batch_loss += loss;
                    batch_pairs += 1;
                }

                if batch_pairs > 0 {
                    epoch_loss += batch_loss / batch_pairs as f32;
                    num_batches += 1;
                }
            }

            let avg_loss = if num_batches > 0 {
                epoch_loss / num_batches as f32
            } else {
                0.0
            };
            losses.push(avg_loss);

            // Recomputed here purely for the log line below; train_example
            // already applied the same schedule to the optimizer states.
            let current_lr = get_learning_rate(
                epoch,
                self.config.epochs,
                self.config.learning_rate,
                self.config.warmup_epochs,
            );

            // Early stopping bookkeeping: improvement must beat min_delta.
            let improved = avg_loss < best_loss - self.config.early_stopping_min_delta;
            if improved {
                best_loss = avg_loss;
                patience_counter = 0;
            } else {
                patience_counter += 1;
            }

            // Show overlap stats periodically (every 10 epochs, last epoch,
            // or whenever the loss improved).
            if epoch % 10 == 0 || epoch == self.config.epochs - 1 || improved {
                let (avg_pos, avg_neg, overlap_rate) = self.get_overlap_stats(examples);
                let status = if improved { "✓" } else { " " };
                let patience_info = if let Some(patience) = self.config.early_stopping_patience {
                    format!(", patience={}/{}", patience_counter, patience)
                } else {
                    String::new()
                };
                let loss_reduction = if losses.len() > 1 {
                    format!(" ({:.1}%↓)", (1.0 - avg_loss / losses[0]) * 100.0)
                } else {
                    String::new()
                };
                let score_gap = avg_pos - avg_neg; // Positive should be higher than negative
                let positive_focus_epochs = self
                    .config
                    .positive_focus_epochs
                    .unwrap_or(self.config.epochs / 3);
                let stage = if epoch < positive_focus_epochs {
                    "P+"
                } else {
                    "S-"
                };
                println!("Epoch {}: loss = {:.4}{}, lr = {:.6}, best = {:.4} {} ({} batches{}, neg_w={:.2}, stage={})",
                    epoch, avg_loss, loss_reduction, current_lr, best_loss, status, num_batches, patience_info, adaptive_negative_weight, stage);
                println!(
                    "  Overlap: {:.1}%, Pos: {:.4}, Neg: {:.4}, Gap: {:.4} {}",
                    overlap_rate * 100.0,
                    avg_pos,
                    avg_neg,
                    score_gap,
                    if score_gap > 0.0 { "✓" } else { "⚠" }
                );
            }

            // Restore original negative weight (must happen before any break
            // so the config is left unchanged after training).
            self.config.negative_weight = original_negative_weight;

            // Early stopping
            if let Some(patience) = self.config.early_stopping_patience {
                if patience_counter >= patience {
                    println!(
                        "Early stopping at epoch {} (no improvement for {} epochs)",
                        epoch, patience
                    );
                    break;
                }
            }
        }

        losses
    }
508
509    /// Get trained boxes for inference.
510    pub fn get_boxes(&self) -> HashMap<usize, BoxEmbedding> {
511        self.boxes
512            .iter()
513            .map(|(id, trainable)| (*id, trainable.to_box()))
514            .collect()
515    }
516
517    /// Get diagnostic statistics about box overlaps.
518    ///
519    /// Returns (avg_positive_score, avg_negative_score, overlap_rate)
520    pub fn get_overlap_stats(&self, examples: &[TrainingExample]) -> (f32, f32, f32) {
521        let mut positive_scores = Vec::new();
522        let mut negative_scores = Vec::new();
523        let mut overlapping_pairs = 0;
524        let mut total_pairs = 0;
525
526        for example in examples {
527            // Positive pairs
528            for chain in &example.chains {
529                let mentions: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
530                for i in 0..mentions.len() {
531                    for j in (i + 1)..mentions.len() {
532                        if let (Some(box_a), Some(box_b)) =
533                            (self.boxes.get(&mentions[i]), self.boxes.get(&mentions[j]))
534                        {
535                            let box_a_embed = box_a.to_box();
536                            let box_b_embed = box_b.to_box();
537                            let score = box_a_embed.coreference_score(&box_b_embed);
538                            positive_scores.push(score);
539                            if score > 0.01 {
540                                overlapping_pairs += 1;
541                            }
542                            total_pairs += 1;
543                        }
544                    }
545                }
546            }
547
548            // Negative pairs
549            for i in 0..example.chains.len() {
550                for j in (i + 1)..example.chains.len() {
551                    let chain_i: Vec<usize> =
552                        example.chains[i].mentions.iter().map(|m| m.start).collect();
553                    let chain_j: Vec<usize> =
554                        example.chains[j].mentions.iter().map(|m| m.start).collect();
555                    for &id_i in &chain_i {
556                        for &id_j in &chain_j {
557                            if let (Some(box_a), Some(box_b)) =
558                                (self.boxes.get(&id_i), self.boxes.get(&id_j))
559                            {
560                                let box_a_embed = box_a.to_box();
561                                let box_b_embed = box_b.to_box();
562                                let score = box_a_embed.coreference_score(&box_b_embed);
563                                negative_scores.push(score);
564                            }
565                        }
566                    }
567                }
568            }
569        }
570
571        let avg_positive = if !positive_scores.is_empty() {
572            positive_scores.iter().sum::<f32>() / positive_scores.len() as f32
573        } else {
574            0.0
575        };
576
577        let avg_negative = if !negative_scores.is_empty() {
578            negative_scores.iter().sum::<f32>() / negative_scores.len() as f32
579        } else {
580            0.0
581        };
582
583        let overlap_rate = if total_pairs > 0 {
584            overlapping_pairs as f32 / total_pairs as f32
585        } else {
586            0.0
587        };
588
589        (avg_positive, avg_negative, overlap_rate)
590    }
591
592    /// Evaluate coreference accuracy on a test set.
593    ///
594    /// Returns (accuracy, precision, recall, f1) where:
595    /// - Accuracy: fraction of pairs correctly classified
596    /// - Precision: fraction of predicted positives that are correct
597    /// - Recall: fraction of true positives that are predicted
598    /// - F1: harmonic mean of precision and recall
599    ///
600    /// **Note**: This is a simple pair-wise evaluation. For standard coreference metrics
601    /// (MUC, B³, CEAF, LEA, BLANC, CoNLL F1), use `evaluate_standard_metrics()` instead.
602    pub fn evaluate(&self, examples: &[TrainingExample], threshold: f32) -> (f32, f32, f32, f32) {
603        let mut true_positives = 0;
604        let mut false_positives = 0;
605        let mut false_negatives = 0;
606        let mut total_pairs = 0;
607
608        for example in examples {
609            // Build positive pairs (should corefer)
610            let mut positive_pairs = Vec::new();
611            for chain in &example.chains {
612                let mentions: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
613                for i in 0..mentions.len() {
614                    for j in (i + 1)..mentions.len() {
615                        positive_pairs.push((mentions[i], mentions[j]));
616                    }
617                }
618            }
619
620            // Build negative pairs (shouldn't corefer)
621            let mut negative_pairs = Vec::new();
622            for i in 0..example.chains.len() {
623                for j in (i + 1)..example.chains.len() {
624                    let chain_i: Vec<usize> =
625                        example.chains[i].mentions.iter().map(|m| m.start).collect();
626                    let chain_j: Vec<usize> =
627                        example.chains[j].mentions.iter().map(|m| m.start).collect();
628                    for &id_i in &chain_i {
629                        for &id_j in &chain_j {
630                            negative_pairs.push((id_i, id_j));
631                        }
632                    }
633                }
634            }
635
636            // Evaluate positive pairs
637            for &(id_a, id_b) in &positive_pairs {
638                total_pairs += 1;
639                if let (Some(box_a), Some(box_b)) = (self.boxes.get(&id_a), self.boxes.get(&id_b)) {
640                    let box_a_embed = box_a.to_box();
641                    let box_b_embed = box_b.to_box();
642                    let score = box_a_embed.coreference_score(&box_b_embed);
643                    if score >= threshold {
644                        true_positives += 1;
645                    } else {
646                        false_negatives += 1;
647                    }
648                } else {
649                    // Missing boxes - count as false negative (model can't predict)
650                    false_negatives += 1;
651                }
652            }
653
654            // Evaluate negative pairs
655            for &(id_a, id_b) in &negative_pairs {
656                total_pairs += 1;
657                if let (Some(box_a), Some(box_b)) = (self.boxes.get(&id_a), self.boxes.get(&id_b)) {
658                    let box_a_embed = box_a.to_box();
659                    let box_b_embed = box_b.to_box();
660                    let score = box_a_embed.coreference_score(&box_b_embed);
661                    if score >= threshold {
662                        false_positives += 1;
663                    }
664                    // If score < threshold, it's a true negative (correctly predicted as non-coreferent)
665                }
666                // If boxes are missing, we can't evaluate - don't count as error
667            }
668        }
669
670        // Compute metrics
671        let precision = if true_positives + false_positives > 0 {
672            true_positives as f32 / (true_positives + false_positives) as f32
673        } else {
674            0.0
675        };
676
677        let recall = if true_positives + false_negatives > 0 {
678            true_positives as f32 / (true_positives + false_negatives) as f32
679        } else {
680            0.0
681        };
682
683        let f1 = if precision + recall > 0.0 {
684            2.0 * precision * recall / (precision + recall)
685        } else {
686            0.0
687        };
688
689        let accuracy = if total_pairs > 0 {
690            (true_positives + (total_pairs - true_positives - false_positives - false_negatives))
691                as f32
692                / total_pairs as f32
693        } else {
694            0.0
695        };
696
697        (accuracy, precision, recall, f1)
698    }
699
700    /// Save trained boxes to a file (JSON format).
701    ///
702    /// # Arguments
703    ///
704    /// * `path` - File path to save to
705    ///
706    /// # Example
707    ///
708    /// ```rust,ignore
709    /// trainer.save_boxes("trained_boxes.json")?;
710    /// ```
711    pub fn save_boxes(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
712        use std::fs::File;
713        use std::io::Write;
714
715        let serialized = serde_json::to_string_pretty(&self.boxes)?;
716        let mut file = File::create(path)?;
717        file.write_all(serialized.as_bytes())?;
718        Ok(())
719    }
720
721    /// Load trained boxes from a file (JSON format).
722    ///
723    /// # Arguments
724    ///
725    /// * `path` - File path to load from
726    /// * `dim` - Expected embedding dimension
727    ///
728    /// # Example
729    ///
730    /// ```rust,ignore
731    /// let boxes = BoxEmbeddingTrainer::load_boxes("trained_boxes.json", 32)?;
732    /// ```
733    pub fn load_boxes(
734        path: &str,
735        dim: usize,
736    ) -> Result<HashMap<usize, TrainableBox>, Box<dyn std::error::Error>> {
737        use std::fs::File;
738        use std::io::Read;
739
740        let mut file = File::open(path)?;
741        let mut contents = String::new();
742        file.read_to_string(&mut contents)?;
743        let boxes: HashMap<usize, TrainableBox> = serde_json::from_str(&contents)?;
744
745        // Verify dimensions
746        for (id, box_embedding) in &boxes {
747            if box_embedding.dim != dim {
748                return Err(format!(
749                    "Box for entity {} has dimension {}, expected {}",
750                    id, box_embedding.dim, dim
751                )
752                .into());
753            }
754        }
755
756        Ok(boxes)
757    }
758
759    /// Evaluate using standard coreference metrics (MUC, B³, CEAF, LEA, BLANC, CoNLL F1).
760    ///
761    /// This converts the trained boxes into coreference chains and evaluates using
762    /// the standard metrics used in coreference research.
763    ///
764    /// # Arguments
765    ///
766    /// * `examples` - Test examples with gold coreference chains
767    /// * `threshold` - Coreference threshold for box overlap
768    ///
769    /// # Returns
770    ///
771    /// `CorefEvaluation` with all standard metrics
772    ///
773    /// Requires `analysis` (or legacy `eval`) feature for access to standard coref metrics.
774    #[cfg(any(feature = "analysis", feature = "eval"))]
775    pub fn evaluate_standard_metrics(
776        &self,
777        examples: &[TrainingExample],
778        threshold: f32,
779    ) -> crate::eval::coref_metrics::CorefEvaluation {
780        use crate::backends::box_embeddings::BoxCorefConfig;
781        use crate::eval::coref_metrics::CorefEvaluation;
782        use crate::eval::coref_resolver::BoxCorefResolver;
783
784        let mut all_predicted_chains = Vec::new();
785        let mut all_gold_chains = Vec::new();
786
787        for example in examples {
788            // Collect gold chains
789            all_gold_chains.extend(example.chains.clone());
790
791            // Get entities from example
792            let entities = &example.entities;
793
794            // Get boxes for entities (or create default boxes if missing)
795            let mut boxes = Vec::new();
796            for entity in entities {
797                if let Some(trainable_box) = self.boxes.get(&entity.start) {
798                    boxes.push(trainable_box.to_box());
799                } else {
800                    // Missing box - create a small default box
801                    let center = vec![0.0; self.dim];
802                    boxes.push(crate::backends::box_embeddings::BoxEmbedding::from_vector(
803                        &center, 0.1,
804                    ));
805                }
806            }
807
808            // Resolve coreference using boxes
809            let box_config = BoxCorefConfig {
810                coreference_threshold: threshold,
811                ..Default::default()
812            };
813            let resolver = BoxCorefResolver::new(box_config);
814            let resolved_entities = resolver.resolve_with_boxes(entities, &boxes);
815
816            // Convert resolved entities to chains
817            let predicted_chains = anno_core::core::coref::entities_to_chains(&resolved_entities);
818            all_predicted_chains.extend(predicted_chains);
819        }
820
821        // Compute standard metrics
822        CorefEvaluation::compute(&all_predicted_chains, &all_gold_chains)
823    }
824}
825
826/// Split training examples into train/validation sets.
827///
828/// # Arguments
829///
830/// * `examples` - All training examples
831/// * `val_ratio` - Fraction of examples to use for validation (0.0-1.0)
832///
833/// # Returns
834///
835/// (train_examples, val_examples)
836pub fn split_train_val(
837    examples: &[TrainingExample],
838    val_ratio: f32,
839) -> (Vec<TrainingExample>, Vec<TrainingExample>) {
840    let val_size = (examples.len() as f32 * val_ratio) as usize;
841    let mut shuffled: Vec<TrainingExample> = examples.to_vec();
842
843    // Simple shuffle
844    for i in (1..shuffled.len()).rev() {
845        let j = (simple_random() * (i + 1) as f32) as usize;
846        shuffled.swap(i, j);
847    }
848
849    let val_examples = shuffled.split_off(val_size);
850    (shuffled, val_examples)
851}
852
853// =============================================================================
854// Loss and Gradient Computation
855// =============================================================================
856
/// Compute loss for a pair of boxes.
///
/// Positive (coreferent) pairs: `-ln(min(P(A|B), P(B|A)))` plus a
/// distance / overlap-shortfall penalty and a light volume regularizer,
/// clamped to be non-negative.
///
/// Negative pairs: a margin hinge on `max(P(A|B), P(B|A))`, an overlap-ratio
/// penalty, a base penalty on non-trivial probabilities, and an adaptive
/// penalty for high probabilities — all scaled by `config.negative_weight`.
///
/// NOTE: the hand-tuned constants here are mirrored by matching terms in
/// `compute_analytical_gradients`; keep the two functions in sync when
/// changing either.
fn compute_pair_loss(
    box_a: &TrainableBox,
    box_b: &TrainableBox,
    is_positive: bool,
    config: &TrainingConfig,
) -> f32 {
    let box_a_embed = box_a.to_box();
    let box_b_embed = box_b.to_box();

    if is_positive {
        // Positive pair: maximize conditional probability
        let p_a_b = box_a_embed.conditional_probability(&box_b_embed);
        let p_b_a = box_b_embed.conditional_probability(&box_a_embed);

        // Clamp probabilities to avoid log(0)
        let p_a_b = p_a_b.max(1e-8);
        let p_b_a = p_b_a.max(1e-8);

        // Use symmetric score: min of both conditional probabilities
        // This ensures both boxes must overlap significantly
        let min_prob = p_a_b.min(p_b_a);
        let neg_log_prob = -min_prob.ln();

        // Also add penalty if boxes are too far apart (encourages movement)
        let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
        let distance_penalty = if vol_intersection < 1e-10 {
            // Boxes don't overlap - add distance penalty proportional to the
            // Euclidean distance between box centers
            let center_a = box_a_embed.center();
            let center_b = box_b_embed.center();
            let dist: f32 = center_a
                .iter()
                .zip(center_b.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f32>()
                .sqrt();
            0.3 * dist // Stronger penalty to encourage boxes to move closer
        } else {
            // Even if overlapping, encourage more overlap
            // Penalize if overlap is too small (relative to the smaller box)
            let vol_a = box_a_embed.volume();
            let vol_b = box_b_embed.volume();
            let overlap_ratio = vol_intersection / vol_a.min(vol_b).max(1e-10);
            if overlap_ratio < 0.5 {
                // Encourage more overlap
                0.1 * (0.5 - overlap_ratio)
            } else {
                0.0
            }
        };

        // Regularization: penalize large volumes
        let vol_a = box_a_embed.volume();
        let vol_b = box_b_embed.volume();
        // Light regularization to prevent boxes from growing too large
        let reg = config.regularization * 1.0 * (vol_a + vol_b);

        (neg_log_prob + reg + distance_penalty).max(0.0)
    } else {
        // Negative pair: enforce separation
        // We want conditional probability to be LOW (boxes should be disjoint)
        let p_a_b = box_a_embed.conditional_probability(&box_b_embed);
        let p_b_a = box_b_embed.conditional_probability(&box_a_embed);

        // Use max of both conditional probabilities
        let max_prob = p_a_b.max(p_b_a);

        // Loss: penalize high conditional probability
        // Use hinge loss instead of quadratic for smoother gradients
        let margin_loss = if max_prob > config.margin {
            // Stronger quadratic penalty for exceeding margin
            let excess = max_prob - config.margin;
            excess.powi(2) * (1.0 + excess * 2.0) // Exponential scaling
        } else {
            0.0 // No loss if already below margin (good!)
        };

        // Add extra penalty for very high probabilities (exponential decay)
        // Note: This is currently not used in the loss calculation but kept for future use
        let _high_prob_penalty = if max_prob > 0.1 {
            (max_prob - 0.1).powi(2) * 0.5 // Extra penalty for very high probabilities
        } else {
            0.0
        };

        // Add extra penalty if boxes overlap significantly
        let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
        let vol_a = box_a_embed.volume();
        let vol_b = box_b_embed.volume();
        let overlap_penalty = if vol_intersection > 1e-10 {
            // Boxes overlap - add strong penalty (more aggressive for higher overlap)
            let overlap_ratio = vol_intersection / vol_a.min(vol_b).max(1e-10);
            // Piecewise penalty; these 4.0/3.0/2.5 multipliers are mirrored in
            // the gradient code's `penalty_multiplier`
            if overlap_ratio > 0.5 {
                4.0 * overlap_ratio * overlap_ratio // Stronger quadratic penalty
            } else if overlap_ratio > 0.3 {
                3.0 * overlap_ratio // Stronger linear penalty for moderate overlap
            } else {
                2.5 * overlap_ratio // Linear penalty for low overlap
            }
        } else {
            0.0
        };

        // Base loss: only penalize if probability is significant (above a threshold)
        // Don't penalize tiny probabilities - they're fine
        let base_loss = if max_prob > 0.01 {
            max_prob * 0.2 // Stronger penalty for significant probabilities
        } else {
            0.0 // No penalty for tiny probabilities
        };

        // Adaptive penalty: stronger for very high probabilities
        // Use exponential scaling for probabilities > 0.1, moderate for > 0.05
        let adaptive_penalty = if max_prob > 0.1 {
            // Exponential penalty for very high probabilities (stronger scaling)
            let prob_excess = max_prob - 0.1;
            prob_excess.powi(2) * (3.0 + prob_excess * 7.0) // Stronger exponential scaling
        } else if max_prob > 0.05 {
            // Moderate penalty for medium-high probabilities
            (max_prob - 0.05).powi(2) * 1.5 // Stronger
        } else if max_prob > 0.02 {
            // Light penalty for low-medium probabilities
            (max_prob - 0.02).powi(2) * 0.5
        } else {
            0.0
        };

        config.negative_weight * (margin_loss + overlap_penalty + base_loss + adaptive_penalty)
    }
}
988
/// Compute analytical gradients for a pair of boxes.
///
/// Returns `(grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b)`: gradients of
/// the pair loss w.r.t. each box's center (`mu`) and size (`delta`)
/// parameters, clipped componentwise to [-10, 10] before returning.
///
/// The terms are hand-derived approximations matched to `compute_pair_loss`
/// (with additional heuristic attraction terms for positive pairs and
/// separation terms for negative pairs); keep the two functions in sync when
/// tuning constants.
fn compute_analytical_gradients(
    box_a: &TrainableBox,
    box_b: &TrainableBox,
    is_positive: bool,
    config: &TrainingConfig,
) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
    let box_a_embed = box_a.to_box();
    let box_b_embed = box_b.to_box();
    let dim = box_a.dim;

    // Initialize gradients
    let mut grad_mu_a = vec![0.0; dim];
    let mut grad_delta_a = vec![0.0; dim];
    let mut grad_mu_b = vec![0.0; dim];
    let mut grad_delta_b = vec![0.0; dim];

    // Compute intersection volume and individual volumes
    let vol_a = box_a_embed.volume();
    let vol_b = box_b_embed.volume();
    let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);

    if is_positive {
        // Positive pair: L = -log(P(A|B)) - log(P(B|A)) + reg
        // P(A|B) = Vol(A ∩ B) / Vol(B)
        // P(B|A) = Vol(A ∩ B) / Vol(A)

        let p_a_b = if vol_b > 0.0 {
            vol_intersection / vol_b
        } else {
            0.0
        };
        let p_b_a = if vol_a > 0.0 {
            vol_intersection / vol_a
        } else {
            0.0
        };

        // Clamp to avoid division by zero
        let p_a_b = p_a_b.max(1e-8);
        let p_b_a = p_b_a.max(1e-8);

        // Gradients through -log(P(A|B)) = -log(Vol_intersection) + log(Vol_B)
        // For positive pairs, we want to maximize overlap
        // If boxes don't overlap, we need gradients to move them together

        // Check if boxes overlap
        // NOTE(review): this shadows the identical `vol_intersection` computed
        // above; the recomputation is redundant but harmless.
        let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
        let has_overlap = vol_intersection > 1e-10;

        if !has_overlap {
            // Boxes don't overlap - add very strong gradient to move centers closer
            let center_a = box_a_embed.center();
            let center_b = box_b_embed.center();
            let center_dist = center_a
                .iter()
                .zip(center_b.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f32>()
                .sqrt();

            for i in 0..dim {
                let diff = center_b[i] - center_a[i];
                // Adaptive strength based on distance - stronger when far apart
                let distance_factor = (center_dist / dim as f32).clamp(0.5, 2.0);
                let attraction_strength = 4.0 * distance_factor; // Stronger when far apart

                // Positive gradient on mu_a moves A toward B (and vice versa)
                // once the optimizer applies its descent step.
                grad_mu_a[i] += attraction_strength * diff;
                grad_mu_b[i] += -attraction_strength * diff;

                // Strong increase in box sizes to help them overlap
                grad_delta_a[i] += 0.5 * distance_factor; // Stronger growth when far apart
                grad_delta_b[i] += 0.5 * distance_factor;
            }
        }

        for i in 0..dim {
            // Gradient w.r.t. box A
            // ∂(-log(P(A|B)))/∂δ_A = -1/P(A|B) * ∂P(A|B)/∂δ_A
            // ∂P(A|B)/∂δ_A = (1/Vol_B) * ∂Vol_intersection/∂δ_A
            // ∂Vol_intersection/∂δ_A = Vol_intersection (if boxes overlap in dim i)

            // Per-dimension overlap length (0 when the intervals are disjoint)
            let overlap_i = if box_a_embed.min[i] < box_b_embed.max[i]
                && box_b_embed.min[i] < box_a_embed.max[i]
            {
                // Boxes overlap in dimension i
                let min_overlap = box_a_embed.min[i].max(box_b_embed.min[i]);
                let max_overlap = box_a_embed.max[i].min(box_b_embed.max[i]);
                (max_overlap - min_overlap).max(0.0)
            } else {
                0.0
            };

            if overlap_i > 0.0 && vol_intersection > 0.0 {
                // Gradient through intersection volume
                // When boxes overlap, focus on improving overlap ratio (not just growing boxes)
                let overlap_ratio_a = vol_intersection / vol_a.max(1e-10);
                let overlap_ratio_b = vol_intersection / vol_b.max(1e-10);

                // If overlap ratio is low, encourage growth; if high, maintain
                // Adaptive growth based on current overlap - more aggressive
                if overlap_ratio_a < 0.15 {
                    // Extremely low overlap - encourage very strong growth
                    grad_delta_a[i] += 0.35;
                } else if overlap_ratio_a < 0.3 {
                    // Very low overlap - encourage extremely strong growth
                    grad_delta_a[i] += 0.3;
                } else if overlap_ratio_a < 0.5 {
                    // Low overlap - encourage very strong growth
                    grad_delta_a[i] += 0.2;
                } else if overlap_ratio_a < 0.7 {
                    // Moderate overlap - strong growth
                    grad_delta_a[i] += 0.1;
                } else if overlap_ratio_a < 0.85 {
                    // Good overlap - small growth
                    grad_delta_a[i] += 0.05;
                }
                // If overlap_ratio_a >= 0.85, don't grow (excellent)

                // Same tiered growth schedule for box B
                if overlap_ratio_b < 0.15 {
                    // Extremely low overlap - encourage very strong growth
                    grad_delta_b[i] += 0.35;
                } else if overlap_ratio_b < 0.3 {
                    // Very low overlap - encourage extremely strong growth
                    grad_delta_b[i] += 0.3;
                } else if overlap_ratio_b < 0.5 {
                    // Low overlap - encourage very strong growth
                    grad_delta_b[i] += 0.2;
                } else if overlap_ratio_b < 0.7 {
                    // Moderate overlap - strong growth
                    grad_delta_b[i] += 0.1;
                } else if overlap_ratio_b < 0.85 {
                    // Good overlap - small growth
                    grad_delta_b[i] += 0.05;
                }

                // Gradient through conditional probability (main signal) - adaptive strength
                // Stronger when overlap is low, gentler when overlap is good
                // NOTE(review): strength is keyed off overlap_ratio_a only,
                // even though it scales the gradients of both boxes.
                let gradient_strength = if overlap_ratio_a < 0.1 {
                    1.7 // Extremely strong when overlap is extremely low
                } else if overlap_ratio_a < 0.2 {
                    1.6 // Extremely strong when overlap is very low
                } else if overlap_ratio_a < 0.4 {
                    1.4 // Very strong when overlap is low
                } else if overlap_ratio_a < 0.6 {
                    1.1 // Strong when overlap is moderate
                } else {
                    0.6 // Gentle when overlap is good
                };

                let grad_vol_intersection_delta_a = vol_intersection * 0.5 * gradient_strength;
                let grad_p_a_b_delta_a = grad_vol_intersection_delta_a / vol_b.max(1e-8);
                grad_delta_a[i] += -grad_p_a_b_delta_a / p_a_b.max(1e-8) * gradient_strength;

                let grad_vol_intersection_delta_b = vol_intersection * 0.5 * gradient_strength;
                let grad_p_b_a_delta_b = grad_vol_intersection_delta_b / vol_a.max(1e-8);
                grad_delta_b[i] += -grad_p_b_a_delta_b / p_b_a.max(1e-8) * gradient_strength;
            } else {
                // Boxes don't overlap in this dimension - extremely strong growth to achieve overlap
                grad_delta_a[i] += 0.3; // Extremely strong growth for box A
                grad_delta_b[i] += 0.3; // Extremely strong growth for box B
            }

            // Regularization gradient: ∂(λ * Vol)/∂δ = λ * Vol
            // Light regularization to prevent boxes from growing too large
            // Apply to both positive and negative pairs (but lighter for positives)
            grad_delta_a[i] += config.regularization * 1.0 * vol_a; // Lighter regularization
            grad_delta_b[i] += config.regularization * 1.0 * vol_b;
        }
    } else {
        // Negative pair: L = max(0, margin - max(P(A|B), P(B|A))) * λ_neg + overlap_penalty
        let p_a_b = if vol_b > 0.0 {
            vol_intersection / vol_b
        } else {
            0.0
        };
        let p_b_a = if vol_a > 0.0 {
            vol_intersection / vol_a
        } else {
            0.0
        };
        let max_prob = p_a_b.max(p_b_a);

        // Always apply gradients for negative pairs (they should always be separated)
        // Don't check margin - always try to minimize conditional probability
        for i in 0..dim {
            // Check if boxes overlap in this dimension
            let overlap_i = if box_a_embed.min[i] < box_b_embed.max[i]
                && box_b_embed.min[i] < box_a_embed.max[i]
            {
                let min_overlap = box_a_embed.min[i].max(box_b_embed.min[i]);
                let max_overlap = box_a_embed.max[i].min(box_b_embed.max[i]);
                (max_overlap - min_overlap).max(0.0)
            } else {
                0.0
            };

            if overlap_i > 0.0 {
                // Boxes overlap - strong gradient to separate
                // Move centers apart
                let center_a = box_a_embed.center();
                let center_b = box_b_embed.center();
                let diff = center_b[i] - center_a[i];

                // Gradient to push boxes apart (adaptive strength based on overlap)
                // Stronger separation when overlap is high
                let overlap_factor =
                    (overlap_i / (box_a_embed.max[i] - box_a_embed.min[i]).max(1e-6)).min(1.0);
                let separation_strength = 1.5 + overlap_factor * 2.0; // 1.5 to 3.5 based on overlap (even stronger)
                if diff.abs() > 1e-6 {
                    grad_mu_a[i] += -config.negative_weight * separation_strength * diff;
                    grad_mu_b[i] += config.negative_weight * separation_strength * diff;
                } else {
                    // Centers coincide in this dimension - apply a very strong
                    // fixed-magnitude push (A one way, B the other).
                    // NOTE(review): earlier comment said "random direction",
                    // but the push direction is deterministic.
                    grad_mu_a[i] += -config.negative_weight * separation_strength * 2.5;
                    grad_mu_b[i] += config.negative_weight * separation_strength * 2.5;
                }

                // Shrink boxes to reduce overlap (adaptive strength based on overlap amount)
                // More aggressive shrinking for higher overlap
                let overlap_ratio_dim =
                    overlap_i / (box_a_embed.max[i] - box_a_embed.min[i]).max(1e-6);
                let shrink_strength = if overlap_ratio_dim > 0.7 {
                    0.7 // Extremely strong shrink for very high overlap
                } else if overlap_ratio_dim > 0.5 {
                    0.6 // Very strong shrink for high overlap
                } else if overlap_ratio_dim > 0.3 {
                    0.5 // Strong shrink for moderate overlap
                } else {
                    0.35 // Moderate shrink for low overlap
                };
                grad_delta_a[i] += -config.negative_weight * shrink_strength;
                grad_delta_b[i] += -config.negative_weight * shrink_strength;
            } else {
                // Boxes don't overlap - NO gradient (let positive pairs grow)
                // Don't shrink when already separated
            }

            // Gradient through overlap_penalty: ∂(2.0 * overlap_ratio)/∂δ
            // Overlap ratio = vol_intersection / min(vol_a, vol_b)
            // ONLY penalize if actually overlapping - stronger penalty for higher overlap
            if overlap_i > 0.0 && vol_intersection > 1e-10 {
                let min_vol = vol_a.min(vol_b);
                let overlap_ratio = vol_intersection / min_vol.max(1e-10);
                // Stronger penalty for higher overlap ratios (more aggressive)
                // Match the loss function: 2.5x/3.0x/4.0x based on overlap
                let penalty_strength = if overlap_ratio > 0.5 {
                    0.4 + overlap_ratio * 0.6 // 0.7 to 1.0 for very high overlap
                } else if overlap_ratio > 0.3 {
                    0.3 + overlap_ratio * 0.5 // 0.39 to 0.45 for moderate-high overlap
                } else {
                    0.2 + overlap_ratio * 0.4 // 0.2 to 0.32 for low overlap
                };
                let penalty_multiplier = if overlap_ratio > 0.5 {
                    4.0
                } else if overlap_ratio > 0.3 {
                    3.0
                } else {
                    2.5
                };
                grad_delta_a[i] +=
                    config.negative_weight * penalty_multiplier * overlap_ratio * penalty_strength;
                grad_delta_b[i] +=
                    config.negative_weight * penalty_multiplier * overlap_ratio * penalty_strength;
            }

            // Gradient through base_loss and margin_loss: minimize max_prob
            // ∂(0.2 * max_prob)/∂δ = 0.2 * ∂max_prob/∂δ
            // ∂((max_prob - margin)^2)/∂δ = 2 * (max_prob - margin) * ∂max_prob/∂δ
            // Always apply these gradients (not just when overlapping) to keep scores low
            if p_a_b >= p_b_a {
                // p_a_b is the max, minimize it
                if overlap_i > 0.0 && vol_intersection > 1e-10 {
                    let grad_vol_intersection_delta_a = vol_intersection * 0.4;
                    let grad_p_a_b_delta_a = grad_vol_intersection_delta_a / vol_b.max(1e-8);
                    // Positive gradient on delta = shrink box (reduce intersection)
                    grad_delta_a[i] += config.negative_weight * 0.2 * grad_p_a_b_delta_a;

                    // Also add margin loss gradient if active
                    if max_prob > config.margin {
                        let excess = max_prob - config.margin;
                        let margin_grad = 2.0 * excess * (1.0 + excess * 2.0) * grad_p_a_b_delta_a
                            + 2.0 * excess.powi(2) * 2.0 * grad_p_a_b_delta_a; // Exponential scaling
                        grad_delta_a[i] += config.negative_weight * margin_grad;
                    }

                    // Adaptive penalty gradient for very high probabilities (stronger for higher probs)
                    if max_prob > 0.1 {
                        // Exponential scaling: stronger penalty for very high probabilities
                        let prob_excess = max_prob - 0.1;
                        let adaptive_grad =
                            2.0 * prob_excess * grad_p_a_b_delta_a * (3.0 + prob_excess * 7.0); // Stronger
                        grad_delta_a[i] += config.negative_weight * adaptive_grad;
                    } else if max_prob > 0.05 {
                        // Moderate penalty gradient for medium-high probabilities
                        let prob_excess = max_prob - 0.05;
                        let adaptive_grad = 2.0 * prob_excess * grad_p_a_b_delta_a * 1.5; // Stronger
                        grad_delta_a[i] += config.negative_weight * adaptive_grad;
                    } else if max_prob > 0.02 {
                        // Light penalty gradient for low-medium probabilities
                        let prob_excess = max_prob - 0.02;
                        let adaptive_grad = 2.0 * prob_excess * grad_p_a_b_delta_a * 0.5;
                        grad_delta_a[i] += config.negative_weight * adaptive_grad;
                    }
                }
                // Don't add extra shrink when not overlapping - let positive pairs grow
            } else {
                // p_b_a is the max, minimize it
                // NOTE(review): constants in this branch (0.25, 2.0 + 5.0x,
                // 1.0, and the missing > 0.02 tier) differ from the A branch
                // above; presumably intentional tuning — confirm.
                if overlap_i > 0.0 && vol_intersection > 1e-10 {
                    let grad_vol_intersection_delta_b = vol_intersection * 0.4;
                    let grad_p_b_a_delta_b = grad_vol_intersection_delta_b / vol_a.max(1e-8);
                    // Positive gradient on delta = shrink box
                    grad_delta_b[i] += config.negative_weight * 0.25 * grad_p_b_a_delta_b; // Slightly stronger

                    // Also add margin loss gradient if active (stronger)
                    if max_prob > config.margin {
                        let excess = max_prob - config.margin;
                        let margin_grad = 2.0 * excess * (1.0 + excess * 2.0) * grad_p_b_a_delta_b
                            + 2.0 * excess.powi(2) * 2.0 * grad_p_b_a_delta_b; // Exponential scaling
                        grad_delta_b[i] += config.negative_weight * margin_grad;
                    }

                    // Adaptive penalty gradient for very high probabilities (stronger for higher probs)
                    if max_prob > 0.1 {
                        // Exponential scaling: stronger penalty for very high probabilities
                        let prob_excess = max_prob - 0.1;
                        let adaptive_grad =
                            2.0 * prob_excess * grad_p_b_a_delta_b * (2.0 + prob_excess * 5.0);
                        grad_delta_b[i] += config.negative_weight * adaptive_grad;
                    } else if max_prob > 0.05 {
                        // Moderate penalty gradient for medium-high probabilities
                        let prob_excess = max_prob - 0.05;
                        let adaptive_grad = 2.0 * prob_excess * grad_p_b_a_delta_b * 1.0;
                        grad_delta_b[i] += config.negative_weight * adaptive_grad;
                    }
                }
                // Don't add extra shrink when not overlapping - let positive pairs grow
            }
        }
    }

    // Clip gradients to prevent explosion
    for grad in &mut grad_mu_a {
        *grad = grad.clamp(-10.0_f32, 10.0_f32);
    }
    for grad in &mut grad_delta_a {
        *grad = grad.clamp(-10.0_f32, 10.0_f32);
    }
    for grad in &mut grad_mu_b {
        *grad = grad.clamp(-10.0_f32, 10.0_f32);
    }
    for grad in &mut grad_delta_b {
        *grad = grad.clamp(-10.0_f32, 10.0_f32);
    }

    (grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b)
}
1346
1347/// Sample negative pairs using self-adversarial sampling.
1348fn sample_self_adversarial_negatives(
1349    negative_pairs: &[(usize, usize)],
1350    boxes: &HashMap<usize, TrainableBox>,
1351    num_samples: usize,
1352    temperature: f32,
1353) -> Vec<usize> {
1354    // Compute scores for all negative pairs
1355    let mut scores: Vec<(usize, f32)> = negative_pairs
1356        .iter()
1357        .enumerate()
1358        .filter_map(|(idx, &(id_a, id_b))| {
1359            if let (Some(box_a), Some(box_b)) = (boxes.get(&id_a), boxes.get(&id_b)) {
1360                let box_a_embed = box_a.to_box();
1361                let box_b_embed = box_b.to_box();
1362                let score = box_a_embed.coreference_score(&box_b_embed);
1363                Some((idx, score / temperature))
1364            } else {
1365                None
1366            }
1367        })
1368        .collect();
1369
1370    // Sort by score (descending) - higher scores are "harder" negatives
1371    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1372
1373    // Sample top-k (hardest negatives)
1374    scores
1375        .into_iter()
1376        .take(num_samples)
1377        .map(|(idx, _)| idx)
1378        .collect()
1379}
1380
/// Get learning rate with linear warmup and cosine decay.
///
/// * Epochs `[0, warmup_epochs)`: linear ramp from `0.1 * base_lr` toward `base_lr`.
/// * Epochs `>= warmup_epochs`: cosine decay from `base_lr` down to `0.1 * base_lr`
///   over the remaining epochs.
///
/// # Arguments
///
/// * `epoch` - Current epoch (0-based)
/// * `total_epochs` - Total number of training epochs
/// * `base_lr` - Peak learning rate
/// * `warmup_epochs` - Number of warmup epochs
fn get_learning_rate(epoch: usize, total_epochs: usize, base_lr: f32, warmup_epochs: usize) -> f32 {
    if epoch < warmup_epochs {
        // Linear warmup: 0.1 * lr → lr
        let warmup_lr = base_lr * 0.1;
        warmup_lr + (base_lr - warmup_lr) * (epoch as f32 / warmup_epochs as f32)
    } else {
        // Cosine decay: lr → 0.1 * lr.
        // BUG FIX: `total_epochs - warmup_epochs` underflows (panicking in
        // debug builds) when warmup_epochs > total_epochs; saturating_sub
        // makes that case yield 0, and .max(1) keeps the divisor nonzero.
        let decay_epochs = total_epochs.saturating_sub(warmup_epochs).max(1);
        // epoch >= warmup_epochs here, so this subtraction cannot underflow.
        let progress = (epoch - warmup_epochs) as f32 / decay_epochs as f32;
        let min_lr = base_lr * 0.1;
        min_lr + (base_lr - min_lr) * (1.0 + (std::f32::consts::PI * progress).cos()) / 2.0
    }
}
1395
1396// =============================================================================
1397// TrainableBox AMSGrad Update
1398// =============================================================================
1399
impl TrainableBox {
    /// Update box parameters using AMSGrad optimizer.
    ///
    /// AMSGrad is Adam with a monotonically non-decreasing second-moment
    /// estimate (`v_hat`), used here in place of the raw second moment when
    /// scaling the step.
    ///
    /// # Arguments
    ///
    /// * `grad_mu` - Gradient w.r.t. the box center `mu`. If shorter than
    ///   `self.dim`, only the available prefix is applied (the loops use
    ///   `take(self.dim)` on the gradient iterator).
    /// * `grad_delta` - Gradient w.r.t. the log-width `delta`. Must have at
    ///   least `self.dim` elements: it is indexed directly and will panic
    ///   otherwise.
    /// * `state` - Shared optimizer state: step count `t`, hyperparameters
    ///   (`lr`, `beta1`, `beta2`, `epsilon`), and the persistent moments
    ///   `m` / `v` / `v_hat` for `mu`.
    ///
    /// NOTE(review): the moments for `mu` persist in `state` across calls,
    /// but the delta moments below (`m_delta`, `v_delta`, `v_hat_delta`) are
    /// zero-initialized locals rebuilt on every call — so `delta` gets a
    /// momentum-free, single-step update each time, yet is bias-corrected
    /// with the shared step count `t` as if `t` steps had accumulated. This
    /// looks unintentional; persisting delta moments would require extra
    /// fields on `AMSGradState`. Confirm before relying on this schedule.
    pub fn update_amsgrad(
        &mut self,
        grad_mu: &[f32],
        grad_delta: &[f32],
        state: &mut AMSGradState,
    ) {
        // Advance the shared step counter used for bias correction.
        state.t += 1;
        let t = state.t as f32;

        // Update first moment (m): exponential moving average of the gradient.
        for (i, &grad) in grad_mu.iter().enumerate().take(self.dim) {
            state.m[i] = state.beta1 * state.m[i] + (1.0 - state.beta1) * grad;
        }

        // Update second moment (v) and max (v_hat).
        // Keeping the running max (the AMSGrad modification) guarantees the
        // effective step size never grows between iterations.
        for (i, &grad) in grad_mu.iter().enumerate().take(self.dim) {
            let v_new = state.beta2 * state.v[i] + (1.0 - state.beta2) * grad * grad;
            state.v[i] = v_new;
            state.v_hat[i] = state.v_hat[i].max(v_new);
        }

        // Bias correction for first moment: m / (1 - beta1^t).
        let m_hat: Vec<f32> = state
            .m
            .iter()
            .map(|&m| m / (1.0 - state.beta1.powf(t)))
            .collect();

        // Update mu: standard Adam-style step scaled by sqrt(v_hat).
        for (i, &m_hat_val) in m_hat.iter().enumerate().take(self.dim) {
            let update = state.lr * m_hat_val / (state.v_hat[i].sqrt() + state.epsilon);
            self.mu[i] -= update;

            // Ensure finite: reset any NaN/inf coordinate to the origin.
            if !self.mu[i].is_finite() {
                self.mu[i] = 0.0;
            }
        }

        // Similar for delta — but with fresh zeroed moments each call
        // (see NOTE(review) in the doc comment above).
        let mut m_delta = vec![0.0_f32; self.dim];
        let mut v_delta = vec![0.0_f32; self.dim];
        let mut v_hat_delta = vec![0.0_f32; self.dim];

        for i in 0..self.dim {
            // With zeroed locals these reduce to (1 - beta) * grad terms.
            m_delta[i] = state.beta1 * m_delta[i] + (1.0 - state.beta1) * grad_delta[i];
            let v_new: f32 =
                state.beta2 * v_delta[i] + (1.0 - state.beta2) * grad_delta[i] * grad_delta[i];
            v_delta[i] = v_new;
            v_hat_delta[i] = v_hat_delta[i].max(v_new);
        }

        // Bias correction using the shared (multi-step) t, even though the
        // delta moments above hold only one step of history.
        let m_hat_delta: Vec<f32> = m_delta
            .iter()
            .map(|&m| m / (1.0 - state.beta1.powf(t)))
            .collect();

        for i in 0..self.dim {
            let update = state.lr * m_hat_delta[i] / (v_hat_delta[i].sqrt() + state.epsilon);
            self.delta[i] -= update;

            // Clamp delta to reasonable range (width between 0.01 and 10.0).
            // delta appears to be stored in log-space here (ln bounds, and
            // the ln(0.5) reset below) — TODO confirm against TrainableBox.
            self.delta[i] = self.delta[i].clamp(0.01_f32.ln(), 10.0_f32.ln());

            // Ensure finite: clamp propagates NaN, so reset any non-finite
            // coordinate to log-width ln(0.5).
            if !self.delta[i].is_finite() {
                self.delta[i] = 0.5_f32.ln();
            }
        }
    }
}
1473
1474// =============================================================================
1475// Simple Random Number Generator
1476// =============================================================================
1477
/// Simple random number generator (for when rand feature is not available).
///
/// Thread-safe: a global atomic counter (no `unsafe static mut`) is mixed
/// with the wall clock so concurrent and back-to-back calls hash distinct
/// inputs. Returns a value in `[0, 1]`.
fn simple_random() -> f32 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Monotonically increasing call counter; distinguishes calls that land
    // on the same clock tick.
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
    let tick = COUNTER.fetch_add(1, Ordering::Relaxed);

    // Nanoseconds since the Unix epoch; if the system clock reads as
    // pre-epoch, fall back to the counter value as the time source.
    let time_nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => tick as u128,
    };

    // Feed time then counter through the stdlib SipHash hasher and map the
    // 64-bit digest onto the unit interval.
    let mut mixer = DefaultHasher::new();
    time_nanos.hash(&mut mixer);
    tick.hash(&mut mixer);
    mixer.finish() as f32 / u64::MAX as f32
}