1#[allow(unused_imports)]
4use super::types::*;
5#[allow(unused_imports)]
6use super::*;
7use std::collections::HashMap;
8
/// Trains per-entity box embeddings for coreference with an AMSGrad optimizer.
pub struct BoxEmbeddingTrainer {
    /// Training hyperparameters (learning rate, epochs, loss weights, …).
    config: TrainingConfig,
    /// One trainable box per entity, keyed by the entity's start offset.
    boxes: HashMap<usize, TrainableBox>,
    /// Per-entity AMSGrad optimizer state, parallel to `boxes`.
    optimizer_states: HashMap<usize, AMSGradState>,
    /// Embedding dimensionality shared by every box.
    dim: usize,
}
20
21impl BoxEmbeddingTrainer {
22 pub fn new(
30 config: TrainingConfig,
31 dim: usize,
32 initial_embeddings: Option<HashMap<usize, Vec<f32>>>,
33 ) -> Self {
34 let mut boxes = HashMap::new();
35 let mut optimizer_states = HashMap::new();
36
37 if let Some(embeddings) = initial_embeddings {
38 for (entity_id, vector) in embeddings {
40 assert_eq!(vector.len(), dim);
41 let box_embedding = TrainableBox::from_vector(&vector, 0.1);
42 boxes.insert(entity_id, box_embedding.clone());
43 optimizer_states.insert(entity_id, AMSGradState::new(dim, config.learning_rate));
44 }
45 }
46
47 Self {
48 config,
49 boxes,
50 optimizer_states,
51 dim,
52 }
53 }
54
55 pub fn initialize_boxes(
70 &mut self,
71 examples: &[TrainingExample],
72 initial_embeddings: Option<&HashMap<usize, Vec<f32>>>,
73 ) {
74 let mut entity_ids = std::collections::HashSet::new();
76 let mut coref_groups: Vec<Vec<usize>> = Vec::new();
77
78 for example in examples {
79 for entity in &example.entities {
80 let entity_id = entity.start;
81 entity_ids.insert(entity_id);
82 }
83
84 for chain in &example.chains {
86 let group: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
87 if group.len() > 1 {
88 coref_groups.push(group);
89 }
90 }
91 }
92
93 for &entity_id in &entity_ids {
95 if let Some(embeddings) = initial_embeddings {
97 if let Some(vector) = embeddings.get(&entity_id) {
98 let norm: f32 = vector.iter().map(|&x| x * x).sum::<f32>().sqrt();
100 let normalized: Vec<f32> = if norm > 0.0 {
101 vector.iter().map(|&x| x / norm).collect()
102 } else {
103 vector.clone()
104 };
105
106 let box_embedding = TrainableBox::from_vector(&normalized, 0.2);
109 self.boxes.insert(entity_id, box_embedding.clone());
110 self.optimizer_states.insert(
111 entity_id,
112 AMSGradState::new(self.dim, self.config.learning_rate),
113 );
114 continue;
115 }
116 }
117
118 let mut group_center: Option<Vec<f32>> = None;
120 let mut in_coref_group = false;
121
122 for group in &coref_groups {
123 if group.contains(&entity_id) {
124 if group_center.is_none() {
126 group_center = Some(
127 (0..self.dim)
128 .map(|_| (simple_random() - 0.5) * 0.3) .collect(),
130 );
131 }
132 in_coref_group = true;
133 break;
134 }
135 }
136
137 let mu = if let Some(ref center) = group_center {
139 center
141 .iter()
142 .map(|&c| c + (simple_random() - 0.5) * 0.05) .collect()
144 } else {
145 (0..self.dim)
147 .map(|_| (simple_random() - 0.5) * 1.0)
148 .collect()
149 };
150
151 let initial_width = if in_coref_group {
155 1.1_f32 } else {
157 0.18_f32 };
159 let delta: Vec<f32> = vec![initial_width.ln(); self.dim];
160 let box_embedding = TrainableBox::new(mu, delta);
161 self.boxes.insert(entity_id, box_embedding.clone());
162 self.optimizer_states.insert(
163 entity_id,
164 AMSGradState::new(self.dim, self.config.learning_rate),
165 );
166 }
167 }
168
169 fn train_example(&mut self, example: &TrainingExample, epoch: usize) -> f32 {
171 let mut total_loss = 0.0;
172 let mut num_pairs = 0;
173
174 let current_lr = get_learning_rate(
176 epoch,
177 self.config.epochs,
178 self.config.learning_rate,
179 self.config.warmup_epochs,
180 );
181 for state in self.optimizer_states.values_mut() {
182 state.set_lr(current_lr);
183 }
184
185 let mut positive_pairs = Vec::new();
187 for chain in &example.chains {
188 let mentions: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
189 for i in 0..mentions.len() {
190 for j in (i + 1)..mentions.len() {
191 positive_pairs.push((mentions[i], mentions[j]));
192 }
193 }
194 }
195
196 let mut negative_pairs = Vec::new();
198 for i in 0..example.chains.len() {
199 for j in (i + 1)..example.chains.len() {
200 let chain_i: Vec<usize> =
201 example.chains[i].mentions.iter().map(|m| m.start).collect();
202 let chain_j: Vec<usize> =
203 example.chains[j].mentions.iter().map(|m| m.start).collect();
204 for &id_i in &chain_i {
205 for &id_j in &chain_j {
206 negative_pairs.push((id_i, id_j));
207 }
208 }
209 }
210 }
211
212 let mut gradients: HashMap<usize, (Vec<f32>, Vec<f32>)> = HashMap::new();
214
215 for &(id_a, id_b) in &positive_pairs {
217 let box_a = self.boxes.get(&id_a).cloned();
219 let box_b = self.boxes.get(&id_b).cloned();
220
221 if let (Some(box_a_ref), Some(box_b_ref)) = (box_a.as_ref(), box_b.as_ref()) {
222 let loss = compute_pair_loss(box_a_ref, box_b_ref, true, &self.config);
223 total_loss += loss;
224 num_pairs += 1;
225
226 let (grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b) =
228 compute_analytical_gradients(box_a_ref, box_b_ref, true, &self.config);
229
230 if grad_mu_a.iter().any(|&x| !x.is_finite())
232 || grad_delta_a.iter().any(|&x| !x.is_finite())
233 || grad_mu_b.iter().any(|&x| !x.is_finite())
234 || grad_delta_b.iter().any(|&x| !x.is_finite())
235 {
236 continue;
237 }
238
239 let entry_a = gradients
241 .entry(id_a)
242 .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
243 for i in 0..self.dim {
244 entry_a.0[i] += grad_mu_a[i];
245 entry_a.1[i] += grad_delta_a[i];
246 }
247
248 let entry_b = gradients
249 .entry(id_b)
250 .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
251 for i in 0..self.dim {
252 entry_b.0[i] += grad_mu_b[i];
253 entry_b.1[i] += grad_delta_b[i];
254 }
255 }
256 }
257
258 let negative_samples: Vec<(usize, usize)> =
260 if self.config.use_self_adversarial && !negative_pairs.is_empty() {
261 let num_samples = positive_pairs.len().min(negative_pairs.len());
263 let sampled_indices = sample_self_adversarial_negatives(
264 &negative_pairs,
265 &self.boxes,
266 num_samples,
267 self.config.adversarial_temperature,
268 );
269 sampled_indices
270 .iter()
271 .map(|&idx| negative_pairs[idx])
272 .collect()
273 } else {
274 let num_samples = positive_pairs.len().min(negative_pairs.len());
276 negative_pairs.into_iter().take(num_samples).collect()
277 };
278
279 for &(id_a, id_b) in &negative_samples {
280 let box_a = self.boxes.get(&id_a).cloned();
282 let box_b = self.boxes.get(&id_b).cloned();
283
284 if let (Some(box_a_ref), Some(box_b_ref)) = (box_a.as_ref(), box_b.as_ref()) {
285 let loss = compute_pair_loss(box_a_ref, box_b_ref, false, &self.config);
286 total_loss += loss;
287 num_pairs += 1;
288
289 let (grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b) =
291 compute_analytical_gradients(box_a_ref, box_b_ref, false, &self.config);
292
293 if grad_mu_a.iter().any(|&x| !x.is_finite())
295 || grad_delta_a.iter().any(|&x| !x.is_finite())
296 || grad_mu_b.iter().any(|&x| !x.is_finite())
297 || grad_delta_b.iter().any(|&x| !x.is_finite())
298 {
299 continue;
300 }
301
302 let entry_a = gradients
304 .entry(id_a)
305 .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
306 for i in 0..self.dim {
307 entry_a.0[i] += grad_mu_a[i];
308 entry_a.1[i] += grad_delta_a[i];
309 }
310
311 let entry_b = gradients
312 .entry(id_b)
313 .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
314 for i in 0..self.dim {
315 entry_b.0[i] += grad_mu_b[i];
316 entry_b.1[i] += grad_delta_b[i];
317 }
318 }
319 }
320
321 for (entity_id, (grad_mu, grad_delta)) in gradients {
323 if let (Some(box_mut), Some(state)) = (
324 self.boxes.get_mut(&entity_id),
325 self.optimizer_states.get_mut(&entity_id),
326 ) {
327 box_mut.update_amsgrad(&grad_mu, &grad_delta, state);
328 }
329 }
330
331 if num_pairs > 0 {
332 total_loss / num_pairs as f32
333 } else {
334 0.0
335 }
336 }
337
    /// Runs the full training loop and returns the per-epoch average losses.
    ///
    /// Each epoch: (1) measures current positive/negative score statistics,
    /// (2) derives an adaptive negative-pair weight from them (temporarily
    /// overriding `config.negative_weight` for the epoch), (3) shuffles the
    /// examples and trains in batches, (4) logs progress and applies early
    /// stopping when configured.
    pub fn train(&mut self, examples: &[TrainingExample]) -> Vec<f32> {
        let mut losses = Vec::new();
        let mut best_loss = f32::INFINITY;
        let mut patience_counter = 0;

        // Gap between average positive and negative scores, one entry per epoch.
        let mut score_gap_history = Vec::new();

        for epoch in 0..self.config.epochs {
            let (avg_pos, avg_neg, _) = self.get_overlap_stats(examples);
            let current_gap = avg_pos - avg_neg;
            score_gap_history.push(current_gap);

            // First third of training (by default) focuses on positive pairs.
            let positive_focus_epochs = self
                .config
                .positive_focus_epochs
                .unwrap_or(self.config.epochs / 3);
            let is_positive_stage = epoch < positive_focus_epochs;

            // Negative-weight schedule:
            // - positive stage: small weight ramping 0.2x -> 0.3x;
            // - positives learned but negatives still score high: ramp hard, cap 2x;
            // - positives learned and the gap is already positive: moderate ramp;
            // - positives barely learned (avg_pos < 0.01): keep negatives weak;
            // - otherwise: gentle ramp.
            let adaptive_negative_weight = if is_positive_stage {
                let stage_progress = epoch as f32 / positive_focus_epochs as f32;
                self.config.negative_weight * (0.2 + stage_progress * 0.1)
            } else if avg_pos > 0.05 && avg_neg > 0.3 {
                let progress = ((epoch - positive_focus_epochs) as f32
                    / (self.config.epochs - positive_focus_epochs) as f32)
                    .min(1.0);
                let neg_penalty = (avg_neg / 0.4).min(1.0);
                self.config.negative_weight * (0.7 + progress * 0.8 + neg_penalty * 0.4).min(2.0)
            } else if avg_pos > 0.02 && current_gap > 0.0 {
                let progress = ((epoch - positive_focus_epochs) as f32
                    / (self.config.epochs - positive_focus_epochs) as f32)
                    .min(1.0);
                self.config.negative_weight * (0.5 + progress * 0.5).min(1.0 + (current_gap / 0.1))
            } else if avg_pos < 0.01 {
                self.config.negative_weight * 0.3
            } else {
                let progress = ((epoch - positive_focus_epochs) as f32
                    / (self.config.epochs - positive_focus_epochs) as f32)
                    .min(1.0);
                self.config.negative_weight * (0.4 + progress * 0.4)
            };

            // Temporarily override the configured weight for this epoch; it is
            // restored below before the early-stopping check.
            let original_negative_weight = self.config.negative_weight;
            self.config.negative_weight = adaptive_negative_weight;

            // Fisher-Yates shuffle of example indices.
            let mut shuffled_indices: Vec<usize> = (0..examples.len()).collect();
            for i in (1..shuffled_indices.len()).rev() {
                let j = (simple_random() * (i + 1) as f32) as usize;
                shuffled_indices.swap(i, j);
            }

            let mut epoch_loss = 0.0;
            let mut num_batches = 0;

            for batch_start in (0..examples.len()).step_by(self.config.batch_size) {
                let batch_end = (batch_start + self.config.batch_size).min(examples.len());
                let batch_indices = &shuffled_indices[batch_start..batch_end];

                let mut batch_loss = 0.0;
                let mut batch_pairs = 0;

                for &idx in batch_indices {
                    let example = &examples[idx];
                    let loss = self.train_example(example, epoch);
                    batch_loss += loss;
                    batch_pairs += 1;
                }

                if batch_pairs > 0 {
                    epoch_loss += batch_loss / batch_pairs as f32;
                    num_batches += 1;
                }
            }

            let avg_loss = if num_batches > 0 {
                epoch_loss / num_batches as f32
            } else {
                0.0
            };
            losses.push(avg_loss);

            // Recomputed for logging only.
            let current_lr = get_learning_rate(
                epoch,
                self.config.epochs,
                self.config.learning_rate,
                self.config.warmup_epochs,
            );

            // Improvement requires beating the best loss by at least min_delta.
            let improved = avg_loss < best_loss - self.config.early_stopping_min_delta;
            if improved {
                best_loss = avg_loss;
                patience_counter = 0;
            } else {
                patience_counter += 1;
            }

            // Progress logging: every 10th epoch, the last epoch, and on improvement.
            if epoch % 10 == 0 || epoch == self.config.epochs - 1 || improved {
                let (avg_pos, avg_neg, overlap_rate) = self.get_overlap_stats(examples);
                let status = if improved { "✓" } else { " " };
                let patience_info = if let Some(patience) = self.config.early_stopping_patience {
                    format!(", patience={}/{}", patience_counter, patience)
                } else {
                    String::new()
                };
                let loss_reduction = if losses.len() > 1 {
                    format!(" ({:.1}%↓)", (1.0 - avg_loss / losses[0]) * 100.0)
                } else {
                    String::new()
                };
                let score_gap = avg_pos - avg_neg;
                let positive_focus_epochs = self
                    .config
                    .positive_focus_epochs
                    .unwrap_or(self.config.epochs / 3);
                // "P+" = positive-focus stage, "S-" = separation stage.
                let stage = if epoch < positive_focus_epochs {
                    "P+"
                } else {
                    "S-"
                };
                println!("Epoch {}: loss = {:.4}{}, lr = {:.6}, best = {:.4} {} ({} batches{}, neg_w={:.2}, stage={})",
                    epoch, avg_loss, loss_reduction, current_lr, best_loss, status, num_batches, patience_info, adaptive_negative_weight, stage);
                println!(
                    "  Overlap: {:.1}%, Pos: {:.4}, Neg: {:.4}, Gap: {:.4} {}",
                    overlap_rate * 100.0,
                    avg_pos,
                    avg_neg,
                    score_gap,
                    if score_gap > 0.0 { "✓" } else { "⚠" }
                );
            }

            // Restore the configured weight before possibly breaking out.
            self.config.negative_weight = original_negative_weight;

            if let Some(patience) = self.config.early_stopping_patience {
                if patience_counter >= patience {
                    println!(
                        "Early stopping at epoch {} (no improvement for {} epochs)",
                        epoch, patience
                    );
                    break;
                }
            }
        }

        losses
    }
508
509 pub fn get_boxes(&self) -> HashMap<usize, BoxEmbedding> {
511 self.boxes
512 .iter()
513 .map(|(id, trainable)| (*id, trainable.to_box()))
514 .collect()
515 }
516
517 pub fn get_overlap_stats(&self, examples: &[TrainingExample]) -> (f32, f32, f32) {
521 let mut positive_scores = Vec::new();
522 let mut negative_scores = Vec::new();
523 let mut overlapping_pairs = 0;
524 let mut total_pairs = 0;
525
526 for example in examples {
527 for chain in &example.chains {
529 let mentions: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
530 for i in 0..mentions.len() {
531 for j in (i + 1)..mentions.len() {
532 if let (Some(box_a), Some(box_b)) =
533 (self.boxes.get(&mentions[i]), self.boxes.get(&mentions[j]))
534 {
535 let box_a_embed = box_a.to_box();
536 let box_b_embed = box_b.to_box();
537 let score = box_a_embed.coreference_score(&box_b_embed);
538 positive_scores.push(score);
539 if score > 0.01 {
540 overlapping_pairs += 1;
541 }
542 total_pairs += 1;
543 }
544 }
545 }
546 }
547
548 for i in 0..example.chains.len() {
550 for j in (i + 1)..example.chains.len() {
551 let chain_i: Vec<usize> =
552 example.chains[i].mentions.iter().map(|m| m.start).collect();
553 let chain_j: Vec<usize> =
554 example.chains[j].mentions.iter().map(|m| m.start).collect();
555 for &id_i in &chain_i {
556 for &id_j in &chain_j {
557 if let (Some(box_a), Some(box_b)) =
558 (self.boxes.get(&id_i), self.boxes.get(&id_j))
559 {
560 let box_a_embed = box_a.to_box();
561 let box_b_embed = box_b.to_box();
562 let score = box_a_embed.coreference_score(&box_b_embed);
563 negative_scores.push(score);
564 }
565 }
566 }
567 }
568 }
569 }
570
571 let avg_positive = if !positive_scores.is_empty() {
572 positive_scores.iter().sum::<f32>() / positive_scores.len() as f32
573 } else {
574 0.0
575 };
576
577 let avg_negative = if !negative_scores.is_empty() {
578 negative_scores.iter().sum::<f32>() / negative_scores.len() as f32
579 } else {
580 0.0
581 };
582
583 let overlap_rate = if total_pairs > 0 {
584 overlapping_pairs as f32 / total_pairs as f32
585 } else {
586 0.0
587 };
588
589 (avg_positive, avg_negative, overlap_rate)
590 }
591
592 pub fn evaluate(&self, examples: &[TrainingExample], threshold: f32) -> (f32, f32, f32, f32) {
603 let mut true_positives = 0;
604 let mut false_positives = 0;
605 let mut false_negatives = 0;
606 let mut total_pairs = 0;
607
608 for example in examples {
609 let mut positive_pairs = Vec::new();
611 for chain in &example.chains {
612 let mentions: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
613 for i in 0..mentions.len() {
614 for j in (i + 1)..mentions.len() {
615 positive_pairs.push((mentions[i], mentions[j]));
616 }
617 }
618 }
619
620 let mut negative_pairs = Vec::new();
622 for i in 0..example.chains.len() {
623 for j in (i + 1)..example.chains.len() {
624 let chain_i: Vec<usize> =
625 example.chains[i].mentions.iter().map(|m| m.start).collect();
626 let chain_j: Vec<usize> =
627 example.chains[j].mentions.iter().map(|m| m.start).collect();
628 for &id_i in &chain_i {
629 for &id_j in &chain_j {
630 negative_pairs.push((id_i, id_j));
631 }
632 }
633 }
634 }
635
636 for &(id_a, id_b) in &positive_pairs {
638 total_pairs += 1;
639 if let (Some(box_a), Some(box_b)) = (self.boxes.get(&id_a), self.boxes.get(&id_b)) {
640 let box_a_embed = box_a.to_box();
641 let box_b_embed = box_b.to_box();
642 let score = box_a_embed.coreference_score(&box_b_embed);
643 if score >= threshold {
644 true_positives += 1;
645 } else {
646 false_negatives += 1;
647 }
648 } else {
649 false_negatives += 1;
651 }
652 }
653
654 for &(id_a, id_b) in &negative_pairs {
656 total_pairs += 1;
657 if let (Some(box_a), Some(box_b)) = (self.boxes.get(&id_a), self.boxes.get(&id_b)) {
658 let box_a_embed = box_a.to_box();
659 let box_b_embed = box_b.to_box();
660 let score = box_a_embed.coreference_score(&box_b_embed);
661 if score >= threshold {
662 false_positives += 1;
663 }
664 }
666 }
668 }
669
670 let precision = if true_positives + false_positives > 0 {
672 true_positives as f32 / (true_positives + false_positives) as f32
673 } else {
674 0.0
675 };
676
677 let recall = if true_positives + false_negatives > 0 {
678 true_positives as f32 / (true_positives + false_negatives) as f32
679 } else {
680 0.0
681 };
682
683 let f1 = if precision + recall > 0.0 {
684 2.0 * precision * recall / (precision + recall)
685 } else {
686 0.0
687 };
688
689 let accuracy = if total_pairs > 0 {
690 (true_positives + (total_pairs - true_positives - false_positives - false_negatives))
691 as f32
692 / total_pairs as f32
693 } else {
694 0.0
695 };
696
697 (accuracy, precision, recall, f1)
698 }
699
700 pub fn save_boxes(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
712 use std::fs::File;
713 use std::io::Write;
714
715 let serialized = serde_json::to_string_pretty(&self.boxes)?;
716 let mut file = File::create(path)?;
717 file.write_all(serialized.as_bytes())?;
718 Ok(())
719 }
720
721 pub fn load_boxes(
734 path: &str,
735 dim: usize,
736 ) -> Result<HashMap<usize, TrainableBox>, Box<dyn std::error::Error>> {
737 use std::fs::File;
738 use std::io::Read;
739
740 let mut file = File::open(path)?;
741 let mut contents = String::new();
742 file.read_to_string(&mut contents)?;
743 let boxes: HashMap<usize, TrainableBox> = serde_json::from_str(&contents)?;
744
745 for (id, box_embedding) in &boxes {
747 if box_embedding.dim != dim {
748 return Err(format!(
749 "Box for entity {} has dimension {}, expected {}",
750 id, box_embedding.dim, dim
751 )
752 .into());
753 }
754 }
755
756 Ok(boxes)
757 }
758
    /// Evaluates with the project's aggregate coreference metrics by running
    /// the full box-based resolver over each example and comparing the
    /// predicted chains against the gold chains.
    ///
    /// Only compiled with the `analysis` or `eval` feature.
    #[cfg(any(feature = "analysis", feature = "eval"))]
    pub fn evaluate_standard_metrics(
        &self,
        examples: &[TrainingExample],
        threshold: f32,
    ) -> crate::eval::coref_metrics::CorefEvaluation {
        use crate::backends::box_embeddings::BoxCorefConfig;
        use crate::eval::coref_metrics::CorefEvaluation;
        use crate::eval::coref_resolver::BoxCorefResolver;

        let mut all_predicted_chains = Vec::new();
        let mut all_gold_chains = Vec::new();

        for example in examples {
            all_gold_chains.extend(example.chains.clone());

            let entities = &example.entities;

            // One box per entity; entities never seen during training fall
            // back to a small box centered at the origin.
            let mut boxes = Vec::new();
            for entity in entities {
                if let Some(trainable_box) = self.boxes.get(&entity.start) {
                    boxes.push(trainable_box.to_box());
                } else {
                    let center = vec![0.0; self.dim];
                    boxes.push(crate::backends::box_embeddings::BoxEmbedding::from_vector(
                        &center, 0.1,
                    ));
                }
            }

            // Resolve coreference with the trained boxes at the given threshold.
            let box_config = BoxCorefConfig {
                coreference_threshold: threshold,
                ..Default::default()
            };
            let resolver = BoxCorefResolver::new(box_config);
            let resolved_entities = resolver.resolve_with_boxes(entities, &boxes);

            let predicted_chains = anno_core::core::coref::entities_to_chains(&resolved_entities);
            all_predicted_chains.extend(predicted_chains);
        }

        CorefEvaluation::compute(&all_predicted_chains, &all_gold_chains)
    }
824}
825
826pub fn split_train_val(
837 examples: &[TrainingExample],
838 val_ratio: f32,
839) -> (Vec<TrainingExample>, Vec<TrainingExample>) {
840 let val_size = (examples.len() as f32 * val_ratio) as usize;
841 let mut shuffled: Vec<TrainingExample> = examples.to_vec();
842
843 for i in (1..shuffled.len()).rev() {
845 let j = (simple_random() * (i + 1) as f32) as usize;
846 shuffled.swap(i, j);
847 }
848
849 let val_examples = shuffled.split_off(val_size);
850 (shuffled, val_examples)
851}
852
/// Loss for a single `(box_a, box_b)` pair.
///
/// Positive pairs are pulled together with a negative-log conditional
/// probability plus distance/overlap shaping terms; negative pairs are pushed
/// apart with margin, overlap, and adaptive probability penalties, all scaled
/// by `config.negative_weight`.
fn compute_pair_loss(
    box_a: &TrainableBox,
    box_b: &TrainableBox,
    is_positive: bool,
    config: &TrainingConfig,
) -> f32 {
    let box_a_embed = box_a.to_box();
    let box_b_embed = box_b.to_box();

    if is_positive {
        // Containment probabilities in both directions.
        let p_a_b = box_a_embed.conditional_probability(&box_b_embed);
        let p_b_a = box_b_embed.conditional_probability(&box_a_embed);

        // Floor to avoid ln(0).
        let p_a_b = p_a_b.max(1e-8);
        let p_b_a = p_b_a.max(1e-8);

        // Penalize the weaker containment direction.
        let min_prob = p_a_b.min(p_b_a);
        let neg_log_prob = -min_prob.ln();

        let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
        let distance_penalty = if vol_intersection < 1e-10 {
            // Disjoint boxes: penalize proportionally to center distance.
            let center_a = box_a_embed.center();
            let center_b = box_b_embed.center();
            let dist: f32 = center_a
                .iter()
                .zip(center_b.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f32>()
                .sqrt();
            0.3 * dist
        } else {
            // Overlapping but not enough: mild penalty below 50% overlap of
            // the smaller box.
            let vol_a = box_a_embed.volume();
            let vol_b = box_b_embed.volume();
            let overlap_ratio = vol_intersection / vol_a.min(vol_b).max(1e-10);
            if overlap_ratio < 0.5 {
                0.1 * (0.5 - overlap_ratio)
            } else {
                0.0
            }
        };

        // Volume regularization keeps boxes from growing without bound.
        let vol_a = box_a_embed.volume();
        let vol_b = box_b_embed.volume();
        let reg = config.regularization * 1.0 * (vol_a + vol_b);

        (neg_log_prob + reg + distance_penalty).max(0.0)
    } else {
        let p_a_b = box_a_embed.conditional_probability(&box_b_embed);
        let p_b_a = box_b_embed.conditional_probability(&box_a_embed);

        // For negatives, penalize the stronger containment direction.
        let max_prob = p_a_b.max(p_b_a);

        // Superlinear penalty once the probability exceeds the margin.
        let margin_loss = if max_prob > config.margin {
            let excess = max_prob - config.margin;
            excess.powi(2) * (1.0 + excess * 2.0)
        } else {
            0.0
        };

        // NOTE(review): computed but never added to the loss — presumably
        // superseded by `adaptive_penalty` below; confirm before deleting.
        let _high_prob_penalty = if max_prob > 0.1 {
            (max_prob - 0.1).powi(2) * 0.5
        } else {
            0.0
        };

        // Tiered penalty on the fraction of the smaller box that overlaps.
        let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
        let vol_a = box_a_embed.volume();
        let vol_b = box_b_embed.volume();
        let overlap_penalty = if vol_intersection > 1e-10 {
            let overlap_ratio = vol_intersection / vol_a.min(vol_b).max(1e-10);
            if overlap_ratio > 0.5 {
                4.0 * overlap_ratio * overlap_ratio
            } else if overlap_ratio > 0.3 {
                3.0 * overlap_ratio
            } else {
                2.5 * overlap_ratio
            }
        } else {
            0.0
        };

        // Small linear push on any non-negligible probability.
        let base_loss = if max_prob > 0.01 {
            max_prob * 0.2
        } else {
            0.0
        };

        // Piecewise-quadratic penalty, steeper at higher probabilities.
        let adaptive_penalty = if max_prob > 0.1 {
            let prob_excess = max_prob - 0.1;
            prob_excess.powi(2) * (3.0 + prob_excess * 7.0)
        } else if max_prob > 0.05 {
            (max_prob - 0.05).powi(2) * 1.5
        } else if max_prob > 0.02 {
            (max_prob - 0.02).powi(2) * 0.5
        } else {
            0.0
        };

        config.negative_weight * (margin_loss + overlap_penalty + base_loss + adaptive_penalty)
    }
}
988
/// Hand-derived (approximate) gradients for one pair of boxes.
///
/// Returns `(grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b)`; every
/// component is clamped to [-10, 10] before returning. The terms mirror the
/// shaping terms in `compute_pair_loss` but use heuristic magnitudes rather
/// than exact derivatives of that loss.
fn compute_analytical_gradients(
    box_a: &TrainableBox,
    box_b: &TrainableBox,
    is_positive: bool,
    config: &TrainingConfig,
) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
    let box_a_embed = box_a.to_box();
    let box_b_embed = box_b.to_box();
    let dim = box_a.dim;

    let mut grad_mu_a = vec![0.0; dim];
    let mut grad_delta_a = vec![0.0; dim];
    let mut grad_mu_b = vec![0.0; dim];
    let mut grad_delta_b = vec![0.0; dim];

    let vol_a = box_a_embed.volume();
    let vol_b = box_b_embed.volume();
    let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);

    if is_positive {
        // P(A|B) = |A∩B| / |B| and symmetrically.
        let p_a_b = if vol_b > 0.0 {
            vol_intersection / vol_b
        } else {
            0.0
        };
        let p_b_a = if vol_a > 0.0 {
            vol_intersection / vol_a
        } else {
            0.0
        };

        // Floor to avoid division by zero later.
        let p_a_b = p_a_b.max(1e-8);
        let p_b_a = p_b_a.max(1e-8);

        // Recomputed; same value as the outer `vol_intersection` above.
        let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
        let has_overlap = vol_intersection > 1e-10;

        if !has_overlap {
            // Disjoint positive pair: attract centers and widen both boxes.
            let center_a = box_a_embed.center();
            let center_b = box_b_embed.center();
            let center_dist = center_a
                .iter()
                .zip(center_b.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f32>()
                .sqrt();

            for i in 0..dim {
                let diff = center_b[i] - center_a[i];
                // Per-dimension share of the distance, clamped to a sane range.
                let distance_factor = (center_dist / dim as f32).clamp(0.5, 2.0);
                let attraction_strength = 4.0 * distance_factor;
                grad_mu_a[i] += attraction_strength * diff;
                grad_mu_b[i] += -attraction_strength * diff;

                // Widening both boxes also helps them reach each other.
                grad_delta_a[i] += 0.5 * distance_factor;
                grad_delta_b[i] += 0.5 * distance_factor;
            }
        }

        for i in 0..dim {
            // Overlap length along dimension i (0 when disjoint in i).
            let overlap_i = if box_a_embed.min[i] < box_b_embed.max[i]
                && box_b_embed.min[i] < box_a_embed.max[i]
            {
                let min_overlap = box_a_embed.min[i].max(box_b_embed.min[i]);
                let max_overlap = box_a_embed.max[i].min(box_b_embed.max[i]);
                (max_overlap - min_overlap).max(0.0)
            } else {
                0.0
            };

            if overlap_i > 0.0 && vol_intersection > 0.0 {
                let overlap_ratio_a = vol_intersection / vol_a.max(1e-10);
                let overlap_ratio_b = vol_intersection / vol_b.max(1e-10);

                // Widen each box more aggressively the less of it overlaps.
                if overlap_ratio_a < 0.15 {
                    grad_delta_a[i] += 0.35;
                } else if overlap_ratio_a < 0.3 {
                    grad_delta_a[i] += 0.3;
                } else if overlap_ratio_a < 0.5 {
                    grad_delta_a[i] += 0.2;
                } else if overlap_ratio_a < 0.7 {
                    grad_delta_a[i] += 0.1;
                } else if overlap_ratio_a < 0.85 {
                    grad_delta_a[i] += 0.05;
                }
                if overlap_ratio_b < 0.15 {
                    grad_delta_b[i] += 0.35;
                } else if overlap_ratio_b < 0.3 {
                    grad_delta_b[i] += 0.3;
                } else if overlap_ratio_b < 0.5 {
                    grad_delta_b[i] += 0.2;
                } else if overlap_ratio_b < 0.7 {
                    grad_delta_b[i] += 0.1;
                } else if overlap_ratio_b < 0.85 {
                    grad_delta_b[i] += 0.05;
                }

                // Stronger pull when overlap is small.
                // NOTE(review): keyed on A's ratio only, and applied twice
                // below (once in the volume term, once as a final multiplier)
                // — confirm both are intended.
                let gradient_strength = if overlap_ratio_a < 0.1 {
                    1.7
                } else if overlap_ratio_a < 0.2 {
                    1.6
                } else if overlap_ratio_a < 0.4 {
                    1.4
                } else if overlap_ratio_a < 0.6 {
                    1.1
                } else {
                    0.6
                };

                // Approximate d(-ln P(A|B))/d(delta_a).
                let grad_vol_intersection_delta_a = vol_intersection * 0.5 * gradient_strength;
                let grad_p_a_b_delta_a = grad_vol_intersection_delta_a / vol_b.max(1e-8);
                grad_delta_a[i] += -grad_p_a_b_delta_a / p_a_b.max(1e-8) * gradient_strength;

                // Symmetric term for box B.
                let grad_vol_intersection_delta_b = vol_intersection * 0.5 * gradient_strength;
                let grad_p_b_a_delta_b = grad_vol_intersection_delta_b / vol_a.max(1e-8);
                grad_delta_b[i] += -grad_p_b_a_delta_b / p_b_a.max(1e-8) * gradient_strength;
            } else {
                // No overlap contribution in this dimension: widen both boxes.
                grad_delta_a[i] += 0.3;
                grad_delta_b[i] += 0.3;
            }

            // Volume regularization (matches the `reg` term in the loss).
            grad_delta_a[i] += config.regularization * 1.0 * vol_a;
            grad_delta_b[i] += config.regularization * 1.0 * vol_b;
        }
    } else {
        // Negative pair: compute both containment probabilities.
        let p_a_b = if vol_b > 0.0 {
            vol_intersection / vol_b
        } else {
            0.0
        };
        let p_b_a = if vol_a > 0.0 {
            vol_intersection / vol_a
        } else {
            0.0
        };
        let max_prob = p_a_b.max(p_b_a);

        for i in 0..dim {
            // Overlap length along dimension i (0 when disjoint in i).
            let overlap_i = if box_a_embed.min[i] < box_b_embed.max[i]
                && box_b_embed.min[i] < box_a_embed.max[i]
            {
                let min_overlap = box_a_embed.min[i].max(box_b_embed.min[i]);
                let max_overlap = box_a_embed.max[i].min(box_b_embed.max[i]);
                (max_overlap - min_overlap).max(0.0)
            } else {
                0.0
            };

            if overlap_i > 0.0 {
                let center_a = box_a_embed.center();
                let center_b = box_b_embed.center();
                let diff = center_b[i] - center_a[i];

                // Push centers apart, harder the more of A's extent overlaps.
                let overlap_factor =
                    (overlap_i / (box_a_embed.max[i] - box_a_embed.min[i]).max(1e-6)).min(1.0);
                let separation_strength = 1.5 + overlap_factor * 2.0;
                if diff.abs() > 1e-6 {
                    grad_mu_a[i] += -config.negative_weight * separation_strength * diff;
                    grad_mu_b[i] += config.negative_weight * separation_strength * diff;
                } else {
                    // Centers coincide in this dimension: apply a fixed kick.
                    grad_mu_a[i] += -config.negative_weight * separation_strength * 2.5;
                    grad_mu_b[i] += config.negative_weight * separation_strength * 2.5;
                }

                // Shrink both boxes, more when the per-dimension overlap ratio
                // (relative to A's extent) is high.
                let overlap_ratio_dim =
                    overlap_i / (box_a_embed.max[i] - box_a_embed.min[i]).max(1e-6);
                let shrink_strength = if overlap_ratio_dim > 0.7 {
                    0.7
                } else if overlap_ratio_dim > 0.5 {
                    0.6
                } else if overlap_ratio_dim > 0.3 {
                    0.5
                } else {
                    0.35
                };
                grad_delta_a[i] += -config.negative_weight * shrink_strength;
                grad_delta_b[i] += -config.negative_weight * shrink_strength;
            } else {
                // Already separated along this dimension: no force applied.
            }

            // Extra shrink proportional to the global overlap ratio (mirrors
            // the tiered `overlap_penalty` in compute_pair_loss).
            if overlap_i > 0.0 && vol_intersection > 1e-10 {
                let min_vol = vol_a.min(vol_b);
                let overlap_ratio = vol_intersection / min_vol.max(1e-10);
                let penalty_strength = if overlap_ratio > 0.5 {
                    0.4 + overlap_ratio * 0.6
                } else if overlap_ratio > 0.3 {
                    0.3 + overlap_ratio * 0.5
                } else {
                    0.2 + overlap_ratio * 0.4
                };
                let penalty_multiplier = if overlap_ratio > 0.5 {
                    4.0
                } else if overlap_ratio > 0.3 {
                    3.0
                } else {
                    2.5
                };
                grad_delta_a[i] +=
                    config.negative_weight * penalty_multiplier * overlap_ratio * penalty_strength;
                grad_delta_b[i] +=
                    config.negative_weight * penalty_multiplier * overlap_ratio * penalty_strength;
            }

            // Asymmetric term: shrink whichever box dominates max_prob.
            if p_a_b >= p_b_a {
                if overlap_i > 0.0 && vol_intersection > 1e-10 {
                    let grad_vol_intersection_delta_a = vol_intersection * 0.4;
                    let grad_p_a_b_delta_a = grad_vol_intersection_delta_a / vol_b.max(1e-8);
                    grad_delta_a[i] += config.negative_weight * 0.2 * grad_p_a_b_delta_a;

                    // Gradient of the margin_loss term w.r.t. delta_a.
                    if max_prob > config.margin {
                        let excess = max_prob - config.margin;
                        let margin_grad = 2.0 * excess * (1.0 + excess * 2.0) * grad_p_a_b_delta_a
                            + 2.0 * excess.powi(2) * 2.0 * grad_p_a_b_delta_a;
                        grad_delta_a[i] += config.negative_weight * margin_grad;
                    }

                    // Gradient of the adaptive_penalty tiers w.r.t. delta_a.
                    if max_prob > 0.1 {
                        let prob_excess = max_prob - 0.1;
                        let adaptive_grad =
                            2.0 * prob_excess * grad_p_a_b_delta_a * (3.0 + prob_excess * 7.0);
                        grad_delta_a[i] += config.negative_weight * adaptive_grad;
                    } else if max_prob > 0.05 {
                        let prob_excess = max_prob - 0.05;
                        let adaptive_grad = 2.0 * prob_excess * grad_p_a_b_delta_a * 1.5;
                        grad_delta_a[i] += config.negative_weight * adaptive_grad;
                    } else if max_prob > 0.02 {
                        let prob_excess = max_prob - 0.02;
                        let adaptive_grad = 2.0 * prob_excess * grad_p_a_b_delta_a * 0.5;
                        grad_delta_a[i] += config.negative_weight * adaptive_grad;
                    }
                }
            } else {
                // Mirror of the branch above for box B.
                // NOTE(review): coefficients differ from the A side (0.25 vs
                // 0.2; 2.0 + 5.0x vs 3.0 + 7.0x; no 0.02 tier) — confirm
                // whether the asymmetry is intended.
                if overlap_i > 0.0 && vol_intersection > 1e-10 {
                    let grad_vol_intersection_delta_b = vol_intersection * 0.4;
                    let grad_p_b_a_delta_b = grad_vol_intersection_delta_b / vol_a.max(1e-8);
                    grad_delta_b[i] += config.negative_weight * 0.25 * grad_p_b_a_delta_b;
                    if max_prob > config.margin {
                        let excess = max_prob - config.margin;
                        let margin_grad = 2.0 * excess * (1.0 + excess * 2.0) * grad_p_b_a_delta_b
                            + 2.0 * excess.powi(2) * 2.0 * grad_p_b_a_delta_b;
                        grad_delta_b[i] += config.negative_weight * margin_grad;
                    }

                    if max_prob > 0.1 {
                        let prob_excess = max_prob - 0.1;
                        let adaptive_grad =
                            2.0 * prob_excess * grad_p_b_a_delta_b * (2.0 + prob_excess * 5.0);
                        grad_delta_b[i] += config.negative_weight * adaptive_grad;
                    } else if max_prob > 0.05 {
                        let prob_excess = max_prob - 0.05;
                        let adaptive_grad = 2.0 * prob_excess * grad_p_b_a_delta_b * 1.0;
                        grad_delta_b[i] += config.negative_weight * adaptive_grad;
                    }
                }
            }
        }
    }

    // Clip every gradient component to keep optimizer updates bounded.
    for grad in &mut grad_mu_a {
        *grad = grad.clamp(-10.0_f32, 10.0_f32);
    }
    for grad in &mut grad_delta_a {
        *grad = grad.clamp(-10.0_f32, 10.0_f32);
    }
    for grad in &mut grad_mu_b {
        *grad = grad.clamp(-10.0_f32, 10.0_f32);
    }
    for grad in &mut grad_delta_b {
        *grad = grad.clamp(-10.0_f32, 10.0_f32);
    }

    (grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b)
}
1346
1347fn sample_self_adversarial_negatives(
1349 negative_pairs: &[(usize, usize)],
1350 boxes: &HashMap<usize, TrainableBox>,
1351 num_samples: usize,
1352 temperature: f32,
1353) -> Vec<usize> {
1354 let mut scores: Vec<(usize, f32)> = negative_pairs
1356 .iter()
1357 .enumerate()
1358 .filter_map(|(idx, &(id_a, id_b))| {
1359 if let (Some(box_a), Some(box_b)) = (boxes.get(&id_a), boxes.get(&id_b)) {
1360 let box_a_embed = box_a.to_box();
1361 let box_b_embed = box_b.to_box();
1362 let score = box_a_embed.coreference_score(&box_b_embed);
1363 Some((idx, score / temperature))
1364 } else {
1365 None
1366 }
1367 })
1368 .collect();
1369
1370 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1372
1373 scores
1375 .into_iter()
1376 .take(num_samples)
1377 .map(|(idx, _)| idx)
1378 .collect()
1379}
1380
/// Learning-rate schedule: linear warmup from `0.1 * base_lr` up to `base_lr`
/// over `warmup_epochs`, then cosine decay back down to `0.1 * base_lr` by
/// `total_epochs`.
fn get_learning_rate(epoch: usize, total_epochs: usize, base_lr: f32, warmup_epochs: usize) -> f32 {
    let min_lr = base_lr * 0.1;
    if epoch < warmup_epochs {
        // Linear ramp; `warmup_epochs` is non-zero here since epoch < warmup_epochs.
        min_lr + (base_lr - min_lr) * (epoch as f32 / warmup_epochs as f32)
    } else {
        // saturating_sub guards against warmup_epochs > total_epochs, which
        // previously underflowed usize subtraction (panic in debug builds).
        let span = total_epochs.saturating_sub(warmup_epochs).max(1) as f32;
        let progress = (epoch - warmup_epochs) as f32 / span;
        min_lr + (base_lr - min_lr) * (1.0 + (std::f32::consts::PI * progress).cos()) / 2.0
    }
}
1395
impl TrainableBox {
    /// Applies one AMSGrad step to `mu` and `delta` using the per-entity
    /// optimizer state. Non-finite coordinates are reset to safe defaults
    /// rather than propagated.
    pub fn update_amsgrad(
        &mut self,
        grad_mu: &[f32],
        grad_delta: &[f32],
        state: &mut AMSGradState,
    ) {
        // Step counter used for bias correction.
        state.t += 1;
        let t = state.t as f32;

        // First moment (momentum) for mu.
        for (i, &grad) in grad_mu.iter().enumerate().take(self.dim) {
            state.m[i] = state.beta1 * state.m[i] + (1.0 - state.beta1) * grad;
        }

        // Second moment with the AMSGrad max trick: v_hat never decreases.
        for (i, &grad) in grad_mu.iter().enumerate().take(self.dim) {
            let v_new = state.beta2 * state.v[i] + (1.0 - state.beta2) * grad * grad;
            state.v[i] = v_new;
            state.v_hat[i] = state.v_hat[i].max(v_new);
        }

        // Bias-corrected first moment.
        let m_hat: Vec<f32> = state
            .m
            .iter()
            .map(|&m| m / (1.0 - state.beta1.powf(t)))
            .collect();

        for (i, &m_hat_val) in m_hat.iter().enumerate().take(self.dim) {
            let update = state.lr * m_hat_val / (state.v_hat[i].sqrt() + state.epsilon);
            self.mu[i] -= update;

            // Reset any non-finite coordinate instead of propagating NaN/Inf.
            if !self.mu[i].is_finite() {
                self.mu[i] = 0.0;
            }
        }

        // NOTE(review): the moments for `delta` are zero-initialized on every
        // call, so unlike `mu` they carry no history between updates — with
        // m/v starting at zero this reduces to a bias-corrected single-step
        // estimate. Confirm whether `delta` was meant to have its own
        // persistent moments in AMSGradState.
        let mut m_delta = vec![0.0_f32; self.dim];
        let mut v_delta = vec![0.0_f32; self.dim];
        let mut v_hat_delta = vec![0.0_f32; self.dim];

        for i in 0..self.dim {
            m_delta[i] = state.beta1 * m_delta[i] + (1.0 - state.beta1) * grad_delta[i];
            let v_new: f32 =
                state.beta2 * v_delta[i] + (1.0 - state.beta2) * grad_delta[i] * grad_delta[i];
            v_delta[i] = v_new;
            v_hat_delta[i] = v_hat_delta[i].max(v_new);
        }

        let m_hat_delta: Vec<f32> = m_delta
            .iter()
            .map(|&m| m / (1.0 - state.beta1.powf(t)))
            .collect();

        for i in 0..self.dim {
            let update = state.lr * m_hat_delta[i] / (v_hat_delta[i].sqrt() + state.epsilon);
            self.delta[i] -= update;

            // `delta` is stored in log-space; keep widths within [0.01, 10].
            self.delta[i] = self.delta[i].clamp(0.01_f32.ln(), 10.0_f32.ln());

            if !self.delta[i].is_finite() {
                self.delta[i] = 0.5_f32.ln();
            }
        }
    }
}
1473
/// Quick-and-dirty pseudo-random `f32` in roughly [0, 1].
///
/// Hashes the current wall-clock nanoseconds together with a process-wide
/// call counter. Not cryptographic and not reproducible across runs; the
/// counter keeps successive same-nanosecond calls from colliding.
fn simple_random() -> f32 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    let tick = COUNTER.fetch_add(1, Ordering::Relaxed);

    // Fall back to the counter alone if the clock reads before the epoch.
    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(d) => d.as_nanos(),
        Err(_) => tick as u128,
    };

    let mut hasher = DefaultHasher::new();
    nanos.hash(&mut hasher);
    tick.hash(&mut hasher);
    (hasher.finish() as f32) / (u64::MAX as f32)
}