1use crate::backends::box_embeddings::BoxEmbedding;
46use anno_core::Entity;
47use anno_core::{CorefChain, CorefDocument};
48use serde::{Deserialize, Serialize};
49use std::collections::HashMap;
50
/// A box embedding parameterized for gradient-based training.
///
/// Instead of storing `min`/`max` corners directly, the box is stored as a
/// center (`mu`) and a per-dimension log-width (`delta`); the side length in
/// dimension `i` is `delta[i].exp()`, which keeps widths strictly positive
/// during optimization. See [`TrainableBox::to_box`] for the corner form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainableBox {
    /// Center of the box in each dimension.
    pub mu: Vec<f32>,
    /// Log of the box width per dimension (width = `delta[i].exp()`).
    pub delta: Vec<f32>,
    /// Embedding dimensionality; invariant: `mu.len() == delta.len() == dim`.
    pub dim: usize,
}
71
72impl TrainableBox {
73 #[must_use]
84 pub fn new(mu: Vec<f32>, delta: Vec<f32>) -> Self {
85 assert_eq!(
86 mu.len(),
87 delta.len(),
88 "mu and delta must have same dimension"
89 );
90 let dim = mu.len();
91 Self { mu, delta, dim }
92 }
93
94 #[must_use]
98 pub fn from_vector(vector: &[f32], init_width: f32) -> Self {
99 let mu = vector.to_vec();
100 let delta: Vec<f32> = vec![init_width.ln(); mu.len()];
101 Self::new(mu, delta)
102 }
103
104 #[must_use]
106 pub fn to_box(&self) -> BoxEmbedding {
107 let min: Vec<f32> = self
108 .mu
109 .iter()
110 .zip(self.delta.iter())
111 .map(|(&m, &d)| m - (d.exp() / 2.0))
112 .collect();
113 let max: Vec<f32> = self
114 .mu
115 .iter()
116 .zip(self.delta.iter())
117 .map(|(&m, &d)| m + (d.exp() / 2.0))
118 .collect();
119 BoxEmbedding::new(min, max)
120 }
121}
122
/// A single training document: its entity mentions plus gold coreference chains.
#[derive(Debug, Clone)]
pub struct TrainingExample {
    /// Entities in the document, one per mention.
    pub entities: Vec<Entity>,
    /// Gold coreference chains over those mentions.
    pub chains: Vec<CorefChain>,
}
139
140impl From<&CorefDocument> for TrainingExample {
144 fn from(doc: &CorefDocument) -> Self {
145 let mut entities = Vec::new();
147 let mut mention_to_entity_id = HashMap::new();
148
149 for chain in &doc.chains {
150 for mention in &chain.mentions {
151 let entity_id = mention.start;
153
154 let entity_type = mention
157 .entity_type
158 .as_ref()
159 .and_then(|s| match s.as_str() {
160 "PER" | "Person" | "person" => Some(anno_core::EntityType::Person),
161 "ORG" | "Organization" | "organization" => {
162 Some(anno_core::EntityType::Organization)
163 }
164 "LOC" | "Location" | "location" => Some(anno_core::EntityType::Location),
165 _ => None,
166 })
167 .unwrap_or(anno_core::EntityType::Person);
168
169 let entity = Entity::new(
171 mention.text.clone(),
172 entity_type,
173 entity_id,
174 mention.end,
175 1.0,
176 );
177
178 entities.push(entity);
179 mention_to_entity_id.insert((mention.start, mention.end), entity_id);
180 }
181 }
182
183 let chains = doc.chains.clone();
185
186 Self { entities, chains }
187 }
188}
189
190pub fn coref_documents_to_training_examples(docs: &[CorefDocument]) -> Vec<TrainingExample> {
192 docs.iter().map(TrainingExample::from).collect()
193}
194
/// Hyperparameters for box-embedding training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Base AMSGrad learning rate (rescheduled per epoch via `get_learning_rate`).
    pub learning_rate: f32,
    /// Base weight on negative-pair loss terms (rescaled adaptively each epoch).
    pub negative_weight: f32,
    /// Probability margin above which a negative pair incurs margin loss.
    pub margin: f32,
    /// Coefficient on the box-volume regularization term.
    pub regularization: f32,
    /// Maximum number of training epochs.
    pub epochs: usize,
    /// Number of examples per batch.
    pub batch_size: usize,
    /// Number of learning-rate warmup epochs.
    pub warmup_epochs: usize,
    /// If true, negatives are sampled self-adversarially (hardness-weighted).
    pub use_self_adversarial: bool,
    /// Temperature for the self-adversarial sampling distribution.
    pub adversarial_temperature: f32,
    /// Stop after this many epochs without improvement; `None` disables.
    pub early_stopping_patience: Option<usize>,
    /// Minimum loss decrease that counts as an improvement.
    pub early_stopping_min_delta: f32,
    /// Length of the positive-focused warm stage; `None` => `epochs / 3`.
    pub positive_focus_epochs: Option<usize>,
}
227
impl Default for TrainingConfig {
    /// Default hyperparameters used by the training loop.
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            negative_weight: 0.5,
            margin: 0.3,
            regularization: 0.0001,
            epochs: 100,
            batch_size: 32,
            warmup_epochs: 10,
            use_self_adversarial: true,
            adversarial_temperature: 1.0,
            early_stopping_patience: Some(10),
            early_stopping_min_delta: 0.001,
            positive_focus_epochs: None, // falls back to epochs / 3 in train()
        }
    }
}
246
/// Per-parameter optimizer state for AMSGrad (an Adam variant).
///
/// NOTE(review): the update rule itself (`update_amsgrad`) is defined outside
/// this section; field semantics below follow standard AMSGrad convention —
/// confirm against that implementation.
#[derive(Debug, Clone)]
pub struct AMSGradState {
    /// First-moment (gradient mean) estimate.
    pub m: Vec<f32>,
    /// Second-moment (uncentered variance) estimate.
    pub v: Vec<f32>,
    /// Running maximum of `v` (the AMSGrad correction), per convention.
    pub v_hat: Vec<f32>,
    /// Step counter.
    pub t: usize,
    /// Current learning rate (mutable via `set_lr` for scheduling).
    pub lr: f32,
    /// Exponential decay rate for the first moment.
    pub beta1: f32,
    /// Exponential decay rate for the second moment.
    pub beta2: f32,
    /// Numerical stability constant.
    pub epsilon: f32,
}
271
272impl AMSGradState {
273 pub fn new(dim: usize, learning_rate: f32) -> Self {
275 Self {
276 m: vec![0.0; dim],
277 v: vec![0.0; dim],
278 v_hat: vec![0.0; dim],
279 t: 0,
280 lr: learning_rate,
281 beta1: 0.9,
282 beta2: 0.999,
283 epsilon: 1e-8,
284 }
285 }
286
287 pub fn set_lr(&mut self, lr: f32) {
289 self.lr = lr;
290 }
291}
292
/// Trains per-entity box embeddings for coreference with AMSGrad.
pub struct BoxEmbeddingTrainer {
    /// Training hyperparameters. `negative_weight` is temporarily rescaled
    /// inside each epoch of `train()` and restored afterwards.
    config: TrainingConfig,
    /// One trainable box per entity id.
    boxes: HashMap<usize, TrainableBox>,
    /// One optimizer state per entity id, parallel to `boxes`.
    optimizer_states: HashMap<usize, AMSGradState>,
    /// Embedding dimensionality shared by all boxes.
    dim: usize,
}
308
309impl BoxEmbeddingTrainer {
310 pub fn new(
318 config: TrainingConfig,
319 dim: usize,
320 initial_embeddings: Option<HashMap<usize, Vec<f32>>>,
321 ) -> Self {
322 let mut boxes = HashMap::new();
323 let mut optimizer_states = HashMap::new();
324
325 if let Some(embeddings) = initial_embeddings {
326 for (entity_id, vector) in embeddings {
328 assert_eq!(vector.len(), dim);
329 let box_embedding = TrainableBox::from_vector(&vector, 0.1);
330 boxes.insert(entity_id, box_embedding.clone());
331 optimizer_states.insert(entity_id, AMSGradState::new(dim, config.learning_rate));
332 }
333 }
334
335 Self {
336 config,
337 boxes,
338 optimizer_states,
339 dim,
340 }
341 }
342
343 pub fn initialize_boxes(
358 &mut self,
359 examples: &[TrainingExample],
360 initial_embeddings: Option<&HashMap<usize, Vec<f32>>>,
361 ) {
362 let mut entity_ids = std::collections::HashSet::new();
364 let mut coref_groups: Vec<Vec<usize>> = Vec::new();
365
366 for example in examples {
367 for entity in &example.entities {
368 let entity_id = entity.start;
369 entity_ids.insert(entity_id);
370 }
371
372 for chain in &example.chains {
374 let group: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
375 if group.len() > 1 {
376 coref_groups.push(group);
377 }
378 }
379 }
380
381 for &entity_id in &entity_ids {
383 if let Some(embeddings) = initial_embeddings {
385 if let Some(vector) = embeddings.get(&entity_id) {
386 let norm: f32 = vector.iter().map(|&x| x * x).sum::<f32>().sqrt();
388 let normalized: Vec<f32> = if norm > 0.0 {
389 vector.iter().map(|&x| x / norm).collect()
390 } else {
391 vector.clone()
392 };
393
394 let box_embedding = TrainableBox::from_vector(&normalized, 0.2);
397 self.boxes.insert(entity_id, box_embedding.clone());
398 self.optimizer_states.insert(
399 entity_id,
400 AMSGradState::new(self.dim, self.config.learning_rate),
401 );
402 continue;
403 }
404 }
405
406 let mut group_center: Option<Vec<f32>> = None;
408 let mut in_coref_group = false;
409
410 for group in &coref_groups {
411 if group.contains(&entity_id) {
412 if group_center.is_none() {
414 group_center = Some(
415 (0..self.dim)
416 .map(|_| (simple_random() - 0.5) * 0.3) .collect(),
418 );
419 }
420 in_coref_group = true;
421 break;
422 }
423 }
424
425 let mu = if let Some(ref center) = group_center {
427 center
429 .iter()
430 .map(|&c| c + (simple_random() - 0.5) * 0.05) .collect()
432 } else {
433 (0..self.dim)
435 .map(|_| (simple_random() - 0.5) * 1.0)
436 .collect()
437 };
438
439 let initial_width = if in_coref_group {
443 1.1_f32 } else {
445 0.18_f32 };
447 let delta: Vec<f32> = vec![initial_width.ln(); self.dim];
448 let box_embedding = TrainableBox::new(mu, delta);
449 self.boxes.insert(entity_id, box_embedding.clone());
450 self.optimizer_states.insert(
451 entity_id,
452 AMSGradState::new(self.dim, self.config.learning_rate),
453 );
454 }
455 }
456
    /// Runs one optimization step over a single document.
    ///
    /// Builds all same-chain (positive) mention pairs and all cross-chain
    /// (negative) pairs, accumulates analytical gradients per entity, applies
    /// one AMSGrad update per touched entity, and returns the mean pair loss
    /// (0.0 when no pairs were formed).
    fn train_example(&mut self, example: &TrainingExample, epoch: usize) -> f32 {
        let mut total_loss = 0.0;
        let mut num_pairs = 0;

        // Refresh the scheduled learning rate on every optimizer state.
        let current_lr = get_learning_rate(
            epoch,
            self.config.epochs,
            self.config.learning_rate,
            self.config.warmup_epochs,
        );
        for state in self.optimizer_states.values_mut() {
            state.set_lr(current_lr);
        }

        // Positives: every unordered mention pair within each gold chain.
        let mut positive_pairs = Vec::new();
        for chain in &example.chains {
            let mentions: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
            for i in 0..mentions.len() {
                for j in (i + 1)..mentions.len() {
                    positive_pairs.push((mentions[i], mentions[j]));
                }
            }
        }

        // Negatives: the cross product of mentions across distinct chains.
        let mut negative_pairs = Vec::new();
        for i in 0..example.chains.len() {
            for j in (i + 1)..example.chains.len() {
                let chain_i: Vec<usize> =
                    example.chains[i].mentions.iter().map(|m| m.start).collect();
                let chain_j: Vec<usize> =
                    example.chains[j].mentions.iter().map(|m| m.start).collect();
                for &id_i in &chain_i {
                    for &id_j in &chain_j {
                        negative_pairs.push((id_i, id_j));
                    }
                }
            }
        }

        // Gradients accumulate per entity as (grad_mu, grad_delta) and are
        // applied once at the end, so all pairs see the same box snapshot.
        let mut gradients: HashMap<usize, (Vec<f32>, Vec<f32>)> = HashMap::new();

        for &(id_a, id_b) in &positive_pairs {
            // Clone the boxes so we can borrow `self.config` alongside them.
            let box_a = self.boxes.get(&id_a).cloned();
            let box_b = self.boxes.get(&id_b).cloned();

            if let (Some(box_a_ref), Some(box_b_ref)) = (box_a.as_ref(), box_b.as_ref()) {
                let loss = compute_pair_loss(box_a_ref, box_b_ref, true, &self.config);
                total_loss += loss;
                num_pairs += 1;

                let (grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b) =
                    compute_analytical_gradients(box_a_ref, box_b_ref, true, &self.config);

                // Skip any pair that produced NaN/inf gradients rather than
                // corrupting the accumulators.
                if grad_mu_a.iter().any(|&x| !x.is_finite())
                    || grad_delta_a.iter().any(|&x| !x.is_finite())
                    || grad_mu_b.iter().any(|&x| !x.is_finite())
                    || grad_delta_b.iter().any(|&x| !x.is_finite())
                {
                    continue;
                }

                let entry_a = gradients
                    .entry(id_a)
                    .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
                for i in 0..self.dim {
                    entry_a.0[i] += grad_mu_a[i];
                    entry_a.1[i] += grad_delta_a[i];
                }

                let entry_b = gradients
                    .entry(id_b)
                    .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
                for i in 0..self.dim {
                    entry_b.0[i] += grad_mu_b[i];
                    entry_b.1[i] += grad_delta_b[i];
                }
            }
        }

        // Subsample negatives down to at most the positive count, optionally
        // weighting the draw toward hard negatives (self-adversarial).
        let negative_samples: Vec<(usize, usize)> =
            if self.config.use_self_adversarial && !negative_pairs.is_empty() {
                let num_samples = positive_pairs.len().min(negative_pairs.len());
                let sampled_indices = sample_self_adversarial_negatives(
                    &negative_pairs,
                    &self.boxes,
                    num_samples,
                    self.config.adversarial_temperature,
                );
                sampled_indices
                    .iter()
                    .map(|&idx| negative_pairs[idx])
                    .collect()
            } else {
                let num_samples = positive_pairs.len().min(negative_pairs.len());
                negative_pairs.into_iter().take(num_samples).collect()
            };

        // Same accumulation as the positive loop, with is_positive = false.
        for &(id_a, id_b) in &negative_samples {
            let box_a = self.boxes.get(&id_a).cloned();
            let box_b = self.boxes.get(&id_b).cloned();

            if let (Some(box_a_ref), Some(box_b_ref)) = (box_a.as_ref(), box_b.as_ref()) {
                let loss = compute_pair_loss(box_a_ref, box_b_ref, false, &self.config);
                total_loss += loss;
                num_pairs += 1;

                let (grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b) =
                    compute_analytical_gradients(box_a_ref, box_b_ref, false, &self.config);

                if grad_mu_a.iter().any(|&x| !x.is_finite())
                    || grad_delta_a.iter().any(|&x| !x.is_finite())
                    || grad_mu_b.iter().any(|&x| !x.is_finite())
                    || grad_delta_b.iter().any(|&x| !x.is_finite())
                {
                    continue;
                }

                let entry_a = gradients
                    .entry(id_a)
                    .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
                for i in 0..self.dim {
                    entry_a.0[i] += grad_mu_a[i];
                    entry_a.1[i] += grad_delta_a[i];
                }

                let entry_b = gradients
                    .entry(id_b)
                    .or_insert_with(|| (vec![0.0; self.dim], vec![0.0; self.dim]));
                for i in 0..self.dim {
                    entry_b.0[i] += grad_mu_b[i];
                    entry_b.1[i] += grad_delta_b[i];
                }
            }
        }

        // Apply one AMSGrad update per entity with its accumulated gradients.
        for (entity_id, (grad_mu, grad_delta)) in gradients {
            if let (Some(box_mut), Some(state)) = (
                self.boxes.get_mut(&entity_id),
                self.optimizer_states.get_mut(&entity_id),
            ) {
                box_mut.update_amsgrad(&grad_mu, &grad_delta, state);
            }
        }

        if num_pairs > 0 {
            total_loss / num_pairs as f32
        } else {
            0.0
        }
    }
625
    /// Full training loop; returns the per-epoch mean losses.
    ///
    /// Each epoch: (1) measures the current positive/negative score gap,
    /// (2) derives an adaptive negative weight (gentle during the initial
    /// positive-focused stage, ramping up afterwards based on score
    /// statistics), (3) shuffles and batches the examples through
    /// `train_example`, and (4) applies early stopping when configured.
    /// `self.config.negative_weight` is temporarily overwritten with the
    /// adaptive value during the epoch and restored at its end.
    pub fn train(&mut self, examples: &[TrainingExample]) -> Vec<f32> {
        let mut losses = Vec::new();
        let mut best_loss = f32::INFINITY;
        let mut patience_counter = 0;

        let mut score_gap_history = Vec::new();

        for epoch in 0..self.config.epochs {
            // Diagnostics driving the adaptive negative weight below.
            let (avg_pos, avg_neg, _) = self.get_overlap_stats(examples);
            let current_gap = avg_pos - avg_neg;
            score_gap_history.push(current_gap);

            let positive_focus_epochs = self
                .config
                .positive_focus_epochs
                .unwrap_or(self.config.epochs / 3);
            let is_positive_stage = epoch < positive_focus_epochs;

            let adaptive_negative_weight = if is_positive_stage {
                // Early stage: keep negatives weak so positives can overlap.
                let stage_progress = epoch as f32 / positive_focus_epochs as f32;
                self.config.negative_weight * (0.2 + stage_progress * 0.1)
            } else if avg_pos > 0.05 && avg_neg > 0.3 {
                // Positives healthy but negatives scoring too high: push hard.
                let progress = ((epoch - positive_focus_epochs) as f32
                    / (self.config.epochs - positive_focus_epochs) as f32)
                    .min(1.0);
                let neg_penalty = (avg_neg / 0.4).min(1.0);
                self.config.negative_weight * (0.7 + progress * 0.8 + neg_penalty * 0.4).min(2.0)
            } else if avg_pos > 0.02 && current_gap > 0.0 {
                // Healthy separation: moderate, progress-scaled ramp.
                let progress = ((epoch - positive_focus_epochs) as f32
                    / (self.config.epochs - positive_focus_epochs) as f32)
                    .min(1.0);
                self.config.negative_weight * (0.5 + progress * 0.5).min(1.0 + (current_gap / 0.1))
            } else if avg_pos < 0.01 {
                // Positives collapsed: back off negatives to let them recover.
                self.config.negative_weight * 0.3
            } else {
                // Default gentle ramp.
                let progress = ((epoch - positive_focus_epochs) as f32
                    / (self.config.epochs - positive_focus_epochs) as f32)
                    .min(1.0);
                self.config.negative_weight * (0.4 + progress * 0.4)
            };

            // Swap in the adaptive weight for this epoch; restored below.
            let original_negative_weight = self.config.negative_weight;
            self.config.negative_weight = adaptive_negative_weight;
            // Fisher-Yates shuffle of the example order.
            let mut shuffled_indices: Vec<usize> = (0..examples.len()).collect();
            for i in (1..shuffled_indices.len()).rev() {
                let j = (simple_random() * (i + 1) as f32) as usize;
                shuffled_indices.swap(i, j);
            }

            let mut epoch_loss = 0.0;
            let mut num_batches = 0;

            for batch_start in (0..examples.len()).step_by(self.config.batch_size) {
                let batch_end = (batch_start + self.config.batch_size).min(examples.len());
                let batch_indices = &shuffled_indices[batch_start..batch_end];

                let mut batch_loss = 0.0;
                let mut batch_pairs = 0;

                for &idx in batch_indices {
                    let example = &examples[idx];
                    let loss = self.train_example(example, epoch);
                    batch_loss += loss;
                    batch_pairs += 1;
                }

                if batch_pairs > 0 {
                    epoch_loss += batch_loss / batch_pairs as f32;
                    num_batches += 1;
                }
            }

            let avg_loss = if num_batches > 0 {
                epoch_loss / num_batches as f32
            } else {
                0.0
            };
            losses.push(avg_loss);

            let current_lr = get_learning_rate(
                epoch,
                self.config.epochs,
                self.config.learning_rate,
                self.config.warmup_epochs,
            );

            // Improvement must beat best_loss by at least min_delta.
            let improved = avg_loss < best_loss - self.config.early_stopping_min_delta;
            if improved {
                best_loss = avg_loss;
                patience_counter = 0;
            } else {
                patience_counter += 1;
            }

            // Progress logging: every 10th epoch, the last epoch, and on improvement.
            if epoch % 10 == 0 || epoch == self.config.epochs - 1 || improved {
                let (avg_pos, avg_neg, overlap_rate) = self.get_overlap_stats(examples);
                let status = if improved { "✓" } else { " " };
                let patience_info = if let Some(patience) = self.config.early_stopping_patience {
                    format!(", patience={}/{}", patience_counter, patience)
                } else {
                    String::new()
                };
                let loss_reduction = if losses.len() > 1 {
                    format!(" ({:.1}%↓)", (1.0 - avg_loss / losses[0]) * 100.0)
                } else {
                    String::new()
                };
                let score_gap = avg_pos - avg_neg;
                let positive_focus_epochs = self
                    .config
                    .positive_focus_epochs
                    .unwrap_or(self.config.epochs / 3);
                let stage = if epoch < positive_focus_epochs {
                    "P+"
                } else {
                    "S-"
                };
                println!("Epoch {}: loss = {:.4}{}, lr = {:.6}, best = {:.4} {} ({} batches{}, neg_w={:.2}, stage={})",
                    epoch, avg_loss, loss_reduction, current_lr, best_loss, status, num_batches, patience_info, adaptive_negative_weight, stage);
                println!(
                    "  Overlap: {:.1}%, Pos: {:.4}, Neg: {:.4}, Gap: {:.4} {}",
                    overlap_rate * 100.0,
                    avg_pos,
                    avg_neg,
                    score_gap,
                    if score_gap > 0.0 { "✓" } else { "⚠" }
                );
            }

            // Restore the configured weight before the next epoch (or return).
            self.config.negative_weight = original_negative_weight;

            if let Some(patience) = self.config.early_stopping_patience {
                if patience_counter >= patience {
                    println!(
                        "Early stopping at epoch {} (no improvement for {} epochs)",
                        epoch, patience
                    );
                    break;
                }
            }
        }

        losses
    }
796
797 pub fn get_boxes(&self) -> HashMap<usize, BoxEmbedding> {
799 self.boxes
800 .iter()
801 .map(|(id, trainable)| (*id, trainable.to_box()))
802 .collect()
803 }
804
805 pub fn get_overlap_stats(&self, examples: &[TrainingExample]) -> (f32, f32, f32) {
809 let mut positive_scores = Vec::new();
810 let mut negative_scores = Vec::new();
811 let mut overlapping_pairs = 0;
812 let mut total_pairs = 0;
813
814 for example in examples {
815 for chain in &example.chains {
817 let mentions: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
818 for i in 0..mentions.len() {
819 for j in (i + 1)..mentions.len() {
820 if let (Some(box_a), Some(box_b)) =
821 (self.boxes.get(&mentions[i]), self.boxes.get(&mentions[j]))
822 {
823 let box_a_embed = box_a.to_box();
824 let box_b_embed = box_b.to_box();
825 let score = box_a_embed.coreference_score(&box_b_embed);
826 positive_scores.push(score);
827 if score > 0.01 {
828 overlapping_pairs += 1;
829 }
830 total_pairs += 1;
831 }
832 }
833 }
834 }
835
836 for i in 0..example.chains.len() {
838 for j in (i + 1)..example.chains.len() {
839 let chain_i: Vec<usize> =
840 example.chains[i].mentions.iter().map(|m| m.start).collect();
841 let chain_j: Vec<usize> =
842 example.chains[j].mentions.iter().map(|m| m.start).collect();
843 for &id_i in &chain_i {
844 for &id_j in &chain_j {
845 if let (Some(box_a), Some(box_b)) =
846 (self.boxes.get(&id_i), self.boxes.get(&id_j))
847 {
848 let box_a_embed = box_a.to_box();
849 let box_b_embed = box_b.to_box();
850 let score = box_a_embed.coreference_score(&box_b_embed);
851 negative_scores.push(score);
852 }
853 }
854 }
855 }
856 }
857 }
858
859 let avg_positive = if !positive_scores.is_empty() {
860 positive_scores.iter().sum::<f32>() / positive_scores.len() as f32
861 } else {
862 0.0
863 };
864
865 let avg_negative = if !negative_scores.is_empty() {
866 negative_scores.iter().sum::<f32>() / negative_scores.len() as f32
867 } else {
868 0.0
869 };
870
871 let overlap_rate = if total_pairs > 0 {
872 overlapping_pairs as f32 / total_pairs as f32
873 } else {
874 0.0
875 };
876
877 (avg_positive, avg_negative, overlap_rate)
878 }
879
    /// Pairwise evaluation at a fixed score `threshold`.
    ///
    /// Every same-chain mention pair is a gold positive and every cross-chain
    /// pair a gold negative; a pair is predicted coreferent when its box
    /// coreference score is >= `threshold`. A gold-positive pair with a
    /// missing box counts as a false negative.
    ///
    /// Returns `(accuracy, precision, recall, f1)`.
    pub fn evaluate(&self, examples: &[TrainingExample], threshold: f32) -> (f32, f32, f32, f32) {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut false_negatives = 0;
        let mut total_pairs = 0;

        for example in examples {
            // Gold positives: all unordered pairs within each chain.
            let mut positive_pairs = Vec::new();
            for chain in &example.chains {
                let mentions: Vec<usize> = chain.mentions.iter().map(|m| m.start).collect();
                for i in 0..mentions.len() {
                    for j in (i + 1)..mentions.len() {
                        positive_pairs.push((mentions[i], mentions[j]));
                    }
                }
            }

            // Gold negatives: all cross-chain mention pairs.
            let mut negative_pairs = Vec::new();
            for i in 0..example.chains.len() {
                for j in (i + 1)..example.chains.len() {
                    let chain_i: Vec<usize> =
                        example.chains[i].mentions.iter().map(|m| m.start).collect();
                    let chain_j: Vec<usize> =
                        example.chains[j].mentions.iter().map(|m| m.start).collect();
                    for &id_i in &chain_i {
                        for &id_j in &chain_j {
                            negative_pairs.push((id_i, id_j));
                        }
                    }
                }
            }

            for &(id_a, id_b) in &positive_pairs {
                total_pairs += 1;
                if let (Some(box_a), Some(box_b)) = (self.boxes.get(&id_a), self.boxes.get(&id_b)) {
                    let box_a_embed = box_a.to_box();
                    let box_b_embed = box_b.to_box();
                    let score = box_a_embed.coreference_score(&box_b_embed);
                    if score >= threshold {
                        true_positives += 1;
                    } else {
                        false_negatives += 1;
                    }
                } else {
                    // Missing box for a gold-positive pair counts as a miss.
                    false_negatives += 1;
                }
            }

            for &(id_a, id_b) in &negative_pairs {
                total_pairs += 1;
                if let (Some(box_a), Some(box_b)) = (self.boxes.get(&id_a), self.boxes.get(&id_b)) {
                    let box_a_embed = box_a.to_box();
                    let box_b_embed = box_b.to_box();
                    let score = box_a_embed.coreference_score(&box_b_embed);
                    if score >= threshold {
                        false_positives += 1;
                    }
                }
            }
        }

        // Guard every ratio against an empty denominator.
        let precision = if true_positives + false_positives > 0 {
            true_positives as f32 / (true_positives + false_positives) as f32
        } else {
            0.0
        };

        let recall = if true_positives + false_negatives > 0 {
            true_positives as f32 / (true_positives + false_negatives) as f32
        } else {
            0.0
        };

        let f1 = if precision + recall > 0.0 {
            2.0 * precision * recall / (precision + recall)
        } else {
            0.0
        };

        // Accuracy = (TP + TN) / total, with TN derived as the remainder.
        let accuracy = if total_pairs > 0 {
            (true_positives + (total_pairs - true_positives - false_positives - false_negatives))
                as f32
                / total_pairs as f32
        } else {
            0.0
        };

        (accuracy, precision, recall, f1)
    }
987
988 pub fn save_boxes(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
1000 use std::fs::File;
1001 use std::io::Write;
1002
1003 let serialized = serde_json::to_string_pretty(&self.boxes)?;
1004 let mut file = File::create(path)?;
1005 file.write_all(serialized.as_bytes())?;
1006 Ok(())
1007 }
1008
1009 pub fn load_boxes(
1022 path: &str,
1023 dim: usize,
1024 ) -> Result<HashMap<usize, TrainableBox>, Box<dyn std::error::Error>> {
1025 use std::fs::File;
1026 use std::io::Read;
1027
1028 let mut file = File::open(path)?;
1029 let mut contents = String::new();
1030 file.read_to_string(&mut contents)?;
1031 let boxes: HashMap<usize, TrainableBox> = serde_json::from_str(&contents)?;
1032
1033 for (id, box_embedding) in &boxes {
1035 if box_embedding.dim != dim {
1036 return Err(format!(
1037 "Box for entity {} has dimension {}, expected {}",
1038 id, box_embedding.dim, dim
1039 )
1040 .into());
1041 }
1042 }
1043
1044 Ok(boxes)
1045 }
1046
    /// Evaluates with the project's standard coreference metrics.
    ///
    /// For each example, the trained boxes are looked up per entity (falling
    /// back to a width-0.1 box at the origin for entities with no trained
    /// box), chains are resolved with [`BoxCorefResolver`] at `threshold`,
    /// and all predicted chains are scored against all gold chains pooled
    /// across examples.
    #[cfg(any(feature = "analysis", feature = "eval"))]
    pub fn evaluate_standard_metrics(
        &self,
        examples: &[TrainingExample],
        threshold: f32,
    ) -> crate::eval::coref_metrics::CorefEvaluation {
        use crate::backends::box_embeddings::BoxCorefConfig;
        use crate::eval::coref_metrics::CorefEvaluation;
        use crate::eval::coref_resolver::BoxCorefResolver;

        let mut all_predicted_chains = Vec::new();
        let mut all_gold_chains = Vec::new();

        for example in examples {
            all_gold_chains.extend(example.chains.clone());

            let entities = &example.entities;

            // One box per entity, in the same order as `entities`.
            let mut boxes = Vec::new();
            for entity in entities {
                if let Some(trainable_box) = self.boxes.get(&entity.start) {
                    boxes.push(trainable_box.to_box());
                } else {
                    // Entity never trained: small default box at the origin.
                    let center = vec![0.0; self.dim];
                    boxes.push(crate::backends::box_embeddings::BoxEmbedding::from_vector(
                        &center, 0.1,
                    ));
                }
            }

            let box_config = BoxCorefConfig {
                coreference_threshold: threshold,
                ..Default::default()
            };
            let resolver = BoxCorefResolver::new(box_config);
            let resolved_entities = resolver.resolve_with_boxes(entities, &boxes);

            let predicted_chains = anno_core::core::coref::entities_to_chains(&resolved_entities);
            all_predicted_chains.extend(predicted_chains);
        }

        CorefEvaluation::compute(&all_predicted_chains, &all_gold_chains)
    }
1112}
1113
1114pub fn split_train_val(
1125 examples: &[TrainingExample],
1126 val_ratio: f32,
1127) -> (Vec<TrainingExample>, Vec<TrainingExample>) {
1128 let val_size = (examples.len() as f32 * val_ratio) as usize;
1129 let mut shuffled: Vec<TrainingExample> = examples.to_vec();
1130
1131 for i in (1..shuffled.len()).rev() {
1133 let j = (simple_random() * (i + 1) as f32) as usize;
1134 shuffled.swap(i, j);
1135 }
1136
1137 let val_examples = shuffled.split_off(val_size);
1138 (shuffled, val_examples)
1139}
1140
/// Loss for a single pair of boxes.
///
/// Positive pairs are pulled together via the negative log of the weaker
/// conditional containment probability, plus penalties for being disjoint or
/// only shallowly overlapping, plus volume regularization. Negative pairs
/// are pushed apart via margin, overlap, base, and adaptive probability
/// penalties, all scaled by `config.negative_weight`.
fn compute_pair_loss(
    box_a: &TrainableBox,
    box_b: &TrainableBox,
    is_positive: bool,
    config: &TrainingConfig,
) -> f32 {
    let box_a_embed = box_a.to_box();
    let box_b_embed = box_b.to_box();

    if is_positive {
        // Symmetric containment: both P(A|B) and P(B|A) should be high.
        let p_a_b = box_a_embed.conditional_probability(&box_b_embed);
        let p_b_a = box_b_embed.conditional_probability(&box_a_embed);

        // Floor the probabilities so ln() stays finite.
        let p_a_b = p_a_b.max(1e-8);
        let p_b_a = p_b_a.max(1e-8);

        // Penalize the weaker direction of containment.
        let min_prob = p_a_b.min(p_b_a);
        let neg_log_prob = -min_prob.ln();

        let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
        let distance_penalty = if vol_intersection < 1e-10 {
            // Disjoint boxes: penalty proportional to center distance.
            let center_a = box_a_embed.center();
            let center_b = box_b_embed.center();
            let dist: f32 = center_a
                .iter()
                .zip(center_b.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f32>()
                .sqrt();
            0.3 * dist
        } else {
            // Overlapping but shallow: nudge toward >= 50% overlap of the
            // smaller box.
            let vol_a = box_a_embed.volume();
            let vol_b = box_b_embed.volume();
            let overlap_ratio = vol_intersection / vol_a.min(vol_b).max(1e-10);
            if overlap_ratio < 0.5 {
                0.1 * (0.5 - overlap_ratio)
            } else {
                0.0
            }
        };

        // Volume regularization keeps boxes from growing without bound.
        let vol_a = box_a_embed.volume();
        let vol_b = box_b_embed.volume();
        let reg = config.regularization * 1.0 * (vol_a + vol_b);

        (neg_log_prob + reg + distance_penalty).max(0.0)
    } else {
        let p_a_b = box_a_embed.conditional_probability(&box_b_embed);
        let p_b_a = box_b_embed.conditional_probability(&box_a_embed);

        // Penalize the stronger direction of containment.
        let max_prob = p_a_b.max(p_b_a);

        // Quadratic margin penalty, steepened for larger violations.
        let margin_loss = if max_prob > config.margin {
            let excess = max_prob - config.margin;
            excess.powi(2) * (1.0 + excess * 2.0)
        } else {
            0.0
        };

        // Computed but intentionally unused (leading underscore); kept for reference.
        let _high_prob_penalty = if max_prob > 0.1 {
            (max_prob - 0.1).powi(2) * 0.5
        } else {
            0.0
        };

        // Tiered penalty on geometric overlap relative to the smaller box.
        let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
        let vol_a = box_a_embed.volume();
        let vol_b = box_b_embed.volume();
        let overlap_penalty = if vol_intersection > 1e-10 {
            let overlap_ratio = vol_intersection / vol_a.min(vol_b).max(1e-10);
            if overlap_ratio > 0.5 {
                4.0 * overlap_ratio * overlap_ratio
            } else if overlap_ratio > 0.3 {
                3.0 * overlap_ratio
            } else {
                2.5 * overlap_ratio
            }
        } else {
            0.0
        };

        // Small linear pressure on any non-trivial probability.
        let base_loss = if max_prob > 0.01 {
            max_prob * 0.2
        } else {
            0.0
        };

        // Extra quadratic penalty that ramps up with the probability excess.
        let adaptive_penalty = if max_prob > 0.1 {
            let prob_excess = max_prob - 0.1;
            prob_excess.powi(2) * (3.0 + prob_excess * 7.0)
        } else if max_prob > 0.05 {
            (max_prob - 0.05).powi(2) * 1.5
        } else if max_prob > 0.02 {
            (max_prob - 0.02).powi(2) * 0.5
        } else {
            0.0
        };

        config.negative_weight * (margin_loss + overlap_penalty + base_loss + adaptive_penalty)
    }
}
1276
1277fn compute_analytical_gradients(
1279 box_a: &TrainableBox,
1280 box_b: &TrainableBox,
1281 is_positive: bool,
1282 config: &TrainingConfig,
1283) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
1284 let box_a_embed = box_a.to_box();
1285 let box_b_embed = box_b.to_box();
1286 let dim = box_a.dim;
1287
1288 let mut grad_mu_a = vec![0.0; dim];
1290 let mut grad_delta_a = vec![0.0; dim];
1291 let mut grad_mu_b = vec![0.0; dim];
1292 let mut grad_delta_b = vec![0.0; dim];
1293
1294 let vol_a = box_a_embed.volume();
1296 let vol_b = box_b_embed.volume();
1297 let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
1298
1299 if is_positive {
1300 let p_a_b = if vol_b > 0.0 {
1305 vol_intersection / vol_b
1306 } else {
1307 0.0
1308 };
1309 let p_b_a = if vol_a > 0.0 {
1310 vol_intersection / vol_a
1311 } else {
1312 0.0
1313 };
1314
1315 let p_a_b = p_a_b.max(1e-8);
1317 let p_b_a = p_b_a.max(1e-8);
1318
1319 let vol_intersection = box_a_embed.intersection_volume(&box_b_embed);
1325 let has_overlap = vol_intersection > 1e-10;
1326
1327 if !has_overlap {
1328 let center_a = box_a_embed.center();
1330 let center_b = box_b_embed.center();
1331 let center_dist = center_a
1332 .iter()
1333 .zip(center_b.iter())
1334 .map(|(a, b)| (a - b).powi(2))
1335 .sum::<f32>()
1336 .sqrt();
1337
1338 for i in 0..dim {
1339 let diff = center_b[i] - center_a[i];
1340 let distance_factor = (center_dist / dim as f32).clamp(0.5, 2.0);
1342 let attraction_strength = 4.0 * distance_factor; grad_mu_a[i] += attraction_strength * diff;
1345 grad_mu_b[i] += -attraction_strength * diff;
1346
1347 grad_delta_a[i] += 0.5 * distance_factor; grad_delta_b[i] += 0.5 * distance_factor;
1350 }
1351 }
1352
1353 for i in 0..dim {
1354 let overlap_i = if box_a_embed.min[i] < box_b_embed.max[i]
1360 && box_b_embed.min[i] < box_a_embed.max[i]
1361 {
1362 let min_overlap = box_a_embed.min[i].max(box_b_embed.min[i]);
1364 let max_overlap = box_a_embed.max[i].min(box_b_embed.max[i]);
1365 (max_overlap - min_overlap).max(0.0)
1366 } else {
1367 0.0
1368 };
1369
1370 if overlap_i > 0.0 && vol_intersection > 0.0 {
1371 let overlap_ratio_a = vol_intersection / vol_a.max(1e-10);
1374 let overlap_ratio_b = vol_intersection / vol_b.max(1e-10);
1375
1376 if overlap_ratio_a < 0.15 {
1379 grad_delta_a[i] += 0.35;
1381 } else if overlap_ratio_a < 0.3 {
1382 grad_delta_a[i] += 0.3;
1384 } else if overlap_ratio_a < 0.5 {
1385 grad_delta_a[i] += 0.2;
1387 } else if overlap_ratio_a < 0.7 {
1388 grad_delta_a[i] += 0.1;
1390 } else if overlap_ratio_a < 0.85 {
1391 grad_delta_a[i] += 0.05;
1393 }
1394 if overlap_ratio_b < 0.15 {
1397 grad_delta_b[i] += 0.35;
1399 } else if overlap_ratio_b < 0.3 {
1400 grad_delta_b[i] += 0.3;
1402 } else if overlap_ratio_b < 0.5 {
1403 grad_delta_b[i] += 0.2;
1405 } else if overlap_ratio_b < 0.7 {
1406 grad_delta_b[i] += 0.1;
1408 } else if overlap_ratio_b < 0.85 {
1409 grad_delta_b[i] += 0.05;
1411 }
1412
1413 let gradient_strength = if overlap_ratio_a < 0.1 {
1416 1.7 } else if overlap_ratio_a < 0.2 {
1418 1.6 } else if overlap_ratio_a < 0.4 {
1420 1.4 } else if overlap_ratio_a < 0.6 {
1422 1.1 } else {
1424 0.6 };
1426
1427 let grad_vol_intersection_delta_a = vol_intersection * 0.5 * gradient_strength;
1428 let grad_p_a_b_delta_a = grad_vol_intersection_delta_a / vol_b.max(1e-8);
1429 grad_delta_a[i] += -grad_p_a_b_delta_a / p_a_b.max(1e-8) * gradient_strength;
1430
1431 let grad_vol_intersection_delta_b = vol_intersection * 0.5 * gradient_strength;
1432 let grad_p_b_a_delta_b = grad_vol_intersection_delta_b / vol_a.max(1e-8);
1433 grad_delta_b[i] += -grad_p_b_a_delta_b / p_b_a.max(1e-8) * gradient_strength;
1434 } else {
1435 grad_delta_a[i] += 0.3; grad_delta_b[i] += 0.3; }
1439
1440 grad_delta_a[i] += config.regularization * 1.0 * vol_a; grad_delta_b[i] += config.regularization * 1.0 * vol_b;
1445 }
1446 } else {
1447 let p_a_b = if vol_b > 0.0 {
1449 vol_intersection / vol_b
1450 } else {
1451 0.0
1452 };
1453 let p_b_a = if vol_a > 0.0 {
1454 vol_intersection / vol_a
1455 } else {
1456 0.0
1457 };
1458 let max_prob = p_a_b.max(p_b_a);
1459
1460 for i in 0..dim {
1463 let overlap_i = if box_a_embed.min[i] < box_b_embed.max[i]
1465 && box_b_embed.min[i] < box_a_embed.max[i]
1466 {
1467 let min_overlap = box_a_embed.min[i].max(box_b_embed.min[i]);
1468 let max_overlap = box_a_embed.max[i].min(box_b_embed.max[i]);
1469 (max_overlap - min_overlap).max(0.0)
1470 } else {
1471 0.0
1472 };
1473
1474 if overlap_i > 0.0 {
1475 let center_a = box_a_embed.center();
1478 let center_b = box_b_embed.center();
1479 let diff = center_b[i] - center_a[i];
1480
1481 let overlap_factor =
1484 (overlap_i / (box_a_embed.max[i] - box_a_embed.min[i]).max(1e-6)).min(1.0);
1485 let separation_strength = 1.5 + overlap_factor * 2.0; if diff.abs() > 1e-6 {
1487 grad_mu_a[i] += -config.negative_weight * separation_strength * diff;
1488 grad_mu_b[i] += config.negative_weight * separation_strength * diff;
1489 } else {
1490 grad_mu_a[i] += -config.negative_weight * separation_strength * 2.5;
1492 grad_mu_b[i] += config.negative_weight * separation_strength * 2.5;
1493 }
1494
1495 let overlap_ratio_dim =
1498 overlap_i / (box_a_embed.max[i] - box_a_embed.min[i]).max(1e-6);
1499 let shrink_strength = if overlap_ratio_dim > 0.7 {
1500 0.7 } else if overlap_ratio_dim > 0.5 {
1502 0.6 } else if overlap_ratio_dim > 0.3 {
1504 0.5 } else {
1506 0.35 };
1508 grad_delta_a[i] += -config.negative_weight * shrink_strength;
1509 grad_delta_b[i] += -config.negative_weight * shrink_strength;
1510 } else {
1511 }
1514
1515 if overlap_i > 0.0 && vol_intersection > 1e-10 {
1519 let min_vol = vol_a.min(vol_b);
1520 let overlap_ratio = vol_intersection / min_vol.max(1e-10);
1521 let penalty_strength = if overlap_ratio > 0.5 {
1524 0.4 + overlap_ratio * 0.6 } else if overlap_ratio > 0.3 {
1526 0.3 + overlap_ratio * 0.5 } else {
1528 0.2 + overlap_ratio * 0.4 };
1530 let penalty_multiplier = if overlap_ratio > 0.5 {
1531 4.0
1532 } else if overlap_ratio > 0.3 {
1533 3.0
1534 } else {
1535 2.5
1536 };
1537 grad_delta_a[i] +=
1538 config.negative_weight * penalty_multiplier * overlap_ratio * penalty_strength;
1539 grad_delta_b[i] +=
1540 config.negative_weight * penalty_multiplier * overlap_ratio * penalty_strength;
1541 }
1542
1543 if p_a_b >= p_b_a {
1548 if overlap_i > 0.0 && vol_intersection > 1e-10 {
1550 let grad_vol_intersection_delta_a = vol_intersection * 0.4;
1551 let grad_p_a_b_delta_a = grad_vol_intersection_delta_a / vol_b.max(1e-8);
1552 grad_delta_a[i] += config.negative_weight * 0.2 * grad_p_a_b_delta_a;
1554
1555 if max_prob > config.margin {
1557 let excess = max_prob - config.margin;
1558 let margin_grad = 2.0 * excess * (1.0 + excess * 2.0) * grad_p_a_b_delta_a
1559 + 2.0 * excess.powi(2) * 2.0 * grad_p_a_b_delta_a; grad_delta_a[i] += config.negative_weight * margin_grad;
1561 }
1562
1563 if max_prob > 0.1 {
1565 let prob_excess = max_prob - 0.1;
1567 let adaptive_grad =
1568 2.0 * prob_excess * grad_p_a_b_delta_a * (3.0 + prob_excess * 7.0); grad_delta_a[i] += config.negative_weight * adaptive_grad;
1570 } else if max_prob > 0.05 {
1571 let prob_excess = max_prob - 0.05;
1573 let adaptive_grad = 2.0 * prob_excess * grad_p_a_b_delta_a * 1.5; grad_delta_a[i] += config.negative_weight * adaptive_grad;
1575 } else if max_prob > 0.02 {
1576 let prob_excess = max_prob - 0.02;
1578 let adaptive_grad = 2.0 * prob_excess * grad_p_a_b_delta_a * 0.5;
1579 grad_delta_a[i] += config.negative_weight * adaptive_grad;
1580 }
1581 }
1582 } else {
1584 if overlap_i > 0.0 && vol_intersection > 1e-10 {
1586 let grad_vol_intersection_delta_b = vol_intersection * 0.4;
1587 let grad_p_b_a_delta_b = grad_vol_intersection_delta_b / vol_a.max(1e-8);
1588 grad_delta_b[i] += config.negative_weight * 0.25 * grad_p_b_a_delta_b; if max_prob > config.margin {
1593 let excess = max_prob - config.margin;
1594 let margin_grad = 2.0 * excess * (1.0 + excess * 2.0) * grad_p_b_a_delta_b
1595 + 2.0 * excess.powi(2) * 2.0 * grad_p_b_a_delta_b; grad_delta_b[i] += config.negative_weight * margin_grad;
1597 }
1598
1599 if max_prob > 0.1 {
1601 let prob_excess = max_prob - 0.1;
1603 let adaptive_grad =
1604 2.0 * prob_excess * grad_p_b_a_delta_b * (2.0 + prob_excess * 5.0);
1605 grad_delta_b[i] += config.negative_weight * adaptive_grad;
1606 } else if max_prob > 0.05 {
1607 let prob_excess = max_prob - 0.05;
1609 let adaptive_grad = 2.0 * prob_excess * grad_p_b_a_delta_b * 1.0;
1610 grad_delta_b[i] += config.negative_weight * adaptive_grad;
1611 }
1612 }
1613 }
1615 }
1616 }
1617
1618 for grad in &mut grad_mu_a {
1620 *grad = grad.clamp(-10.0_f32, 10.0_f32);
1621 }
1622 for grad in &mut grad_delta_a {
1623 *grad = grad.clamp(-10.0_f32, 10.0_f32);
1624 }
1625 for grad in &mut grad_mu_b {
1626 *grad = grad.clamp(-10.0_f32, 10.0_f32);
1627 }
1628 for grad in &mut grad_delta_b {
1629 *grad = grad.clamp(-10.0_f32, 10.0_f32);
1630 }
1631
1632 (grad_mu_a, grad_delta_a, grad_mu_b, grad_delta_b)
1633}
1634
1635fn sample_self_adversarial_negatives(
1637 negative_pairs: &[(usize, usize)],
1638 boxes: &HashMap<usize, TrainableBox>,
1639 num_samples: usize,
1640 temperature: f32,
1641) -> Vec<usize> {
1642 let mut scores: Vec<(usize, f32)> = negative_pairs
1644 .iter()
1645 .enumerate()
1646 .filter_map(|(idx, &(id_a, id_b))| {
1647 if let (Some(box_a), Some(box_b)) = (boxes.get(&id_a), boxes.get(&id_b)) {
1648 let box_a_embed = box_a.to_box();
1649 let box_b_embed = box_b.to_box();
1650 let score = box_a_embed.coreference_score(&box_b_embed);
1651 Some((idx, score / temperature))
1652 } else {
1653 None
1654 }
1655 })
1656 .collect();
1657
1658 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1660
1661 scores
1663 .into_iter()
1664 .take(num_samples)
1665 .map(|(idx, _)| idx)
1666 .collect()
1667}
1668
/// Computes the learning rate for `epoch`: linear warmup followed by cosine
/// annealing.
///
/// During the first `warmup_epochs` the rate ramps linearly from
/// `0.1 * base_lr` up to `base_lr`; afterwards it follows a cosine decay
/// from `base_lr` back down to `0.1 * base_lr` at `total_epochs`.
///
/// Uses `saturating_sub` so a misconfigured `warmup_epochs >= total_epochs`
/// cannot underflow the `usize` subtraction (which would panic in debug
/// builds and wrap in release); the decay phase then collapses to the
/// minimum rate instead.
fn get_learning_rate(epoch: usize, total_epochs: usize, base_lr: f32, warmup_epochs: usize) -> f32 {
    let min_lr = base_lr * 0.1;
    if epoch < warmup_epochs {
        // Linear warmup: min_lr -> base_lr over `warmup_epochs` epochs.
        min_lr + (base_lr - min_lr) * (epoch as f32 / warmup_epochs as f32)
    } else {
        // Cosine annealing: base_lr -> min_lr over the remaining epochs.
        let decay_epochs = total_epochs.saturating_sub(warmup_epochs).max(1);
        let progress = epoch.saturating_sub(warmup_epochs) as f32 / decay_epochs as f32;
        min_lr + (base_lr - min_lr) * (1.0 + (std::f32::consts::PI * progress).cos()) / 2.0
    }
}
1683
impl TrainableBox {
    /// Applies one AMSGrad optimization step to this box's parameters.
    ///
    /// `mu` (box centers) is updated using the persistent optimizer moments
    /// in `state` (`m`, `v`, `v_hat`); `delta` (log-widths) is updated with
    /// locally-constructed moment buffers. Non-finite results are reset to
    /// safe defaults, and log-widths are clamped so box widths stay within
    /// `[0.01, 10.0]` in linear space.
    ///
    /// NOTE(review): `m_delta`/`v_delta`/`v_hat_delta` are re-zeroed on every
    /// call, so no momentum or second-moment history ever accumulates for
    /// `delta` — the delta update degenerates to roughly
    /// `lr * (1-beta1) * g / (sqrt((1-beta2)) * |g|)`, i.e. a scaled sign-SGD
    /// step. Fixing this properly requires extending `AMSGradState` with
    /// persistent delta moments; flagged rather than changed here.
    pub fn update_amsgrad(
        &mut self,
        grad_mu: &[f32],
        grad_delta: &[f32],
        state: &mut AMSGradState,
    ) {
        // Global step counter, used for the beta1 bias correction below.
        state.t += 1;
        let t = state.t as f32;

        // First moment (momentum) for mu: EMA of gradients.
        for (i, &grad) in grad_mu.iter().enumerate().take(self.dim) {
            state.m[i] = state.beta1 * state.m[i] + (1.0 - state.beta1) * grad;
        }

        // Second moment for mu, with the AMSGrad max-trick: v_hat is the
        // running element-wise maximum of v, never allowed to decrease.
        for (i, &grad) in grad_mu.iter().enumerate().take(self.dim) {
            let v_new = state.beta2 * state.v[i] + (1.0 - state.beta2) * grad * grad;
            state.v[i] = v_new;
            state.v_hat[i] = state.v_hat[i].max(v_new);
        }

        // Bias-corrected first moment. (v_hat is intentionally NOT
        // bias-corrected — consistent with the original AMSGrad formulation.)
        let m_hat: Vec<f32> = state
            .m
            .iter()
            .map(|&m| m / (1.0 - state.beta1.powf(t)))
            .collect();

        // Parameter step for mu; reset any non-finite center to the origin.
        for (i, &m_hat_val) in m_hat.iter().enumerate().take(self.dim) {
            let update = state.lr * m_hat_val / (state.v_hat[i].sqrt() + state.epsilon);
            self.mu[i] -= update;

            if !self.mu[i].is_finite() {
                self.mu[i] = 0.0;
            }
        }

        // Fresh (zeroed) moment buffers for delta — see NOTE(review) above:
        // these carry no history across calls.
        let mut m_delta = vec![0.0_f32; self.dim];
        let mut v_delta = vec![0.0_f32; self.dim];
        let mut v_hat_delta = vec![0.0_f32; self.dim];

        for i in 0..self.dim {
            // With m_delta starting at zero, this is just (1-beta1)*grad.
            m_delta[i] = state.beta1 * m_delta[i] + (1.0 - state.beta1) * grad_delta[i];
            let v_new: f32 =
                state.beta2 * v_delta[i] + (1.0 - state.beta2) * grad_delta[i] * grad_delta[i];
            v_delta[i] = v_new;
            v_hat_delta[i] = v_hat_delta[i].max(v_new);
        }

        // Same bias correction as for mu, applied to the local delta moments.
        let m_hat_delta: Vec<f32> = m_delta
            .iter()
            .map(|&m| m / (1.0 - state.beta1.powf(t)))
            .collect();

        for i in 0..self.dim {
            let update = state.lr * m_hat_delta[i] / (v_hat_delta[i].sqrt() + state.epsilon);
            self.delta[i] -= update;

            // delta stores log-width (see TrainableBox::to_box): clamp keeps
            // the linear width within [0.01, 10.0].
            self.delta[i] = self.delta[i].clamp(0.01_f32.ln(), 10.0_f32.ln());

            // f32::clamp propagates NaN, so check finiteness after clamping;
            // fall back to a moderate width of 0.5.
            if !self.delta[i].is_finite() {
                self.delta[i] = 0.5_f32.ln();
            }
        }
    }
}
1761
/// Generates a pseudo-random `f32` in `[0.0, 1.0]` without external crates.
///
/// Mixes the current wall-clock nanoseconds with a global call counter
/// through `DefaultHasher`, so successive calls yield different values even
/// within the same clock tick. Not cryptographically secure and not
/// reproducible — intended only for lightweight sampling during training.
fn simple_random() -> f32 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Monotone per-call counter: disambiguates calls within one clock tick.
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
    let tick = COUNTER.fetch_add(1, Ordering::Relaxed);

    // Nanoseconds since the epoch; fall back to the counter if the system
    // clock reads before the epoch.
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(tick as u128, |d| d.as_nanos());

    // Hash (time, counter) and normalize the 64-bit digest into [0, 1].
    let mut hasher = DefaultHasher::new();
    nanos.hash(&mut hasher);
    tick.hash(&mut hasher);
    (hasher.finish() as f32) / (u64::MAX as f32)
}