1use super::factors::{CorefLinkWeights, CorefNerWeights, LinkNerWeights};
58use super::types::JointMention;
59use crate::{Entity, EntityType};
60use anno_core::CorefChain;
61use serde::{Deserialize, Serialize};
62use std::collections::HashMap;
63
/// Result of a joint decode, all keyed by mention index:
/// NER type per mention, coreference antecedent per mention
/// (`None` = starts a new cluster), and entity link per mention
/// (`None` = NIL / unlinkable).
type DecodeResult = (
    HashMap<usize, EntityType>,
    HashMap<usize, Option<usize>>,
    HashMap<usize, Option<String>>,
);
70
/// Hyper-parameters for mini-batch AdaGrad training of the joint model.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// AdaGrad base learning rate.
    pub learning_rate: f64,
    /// Numerical-stability constant added to the AdaGrad denominator.
    pub epsilon: f64,
    /// Maximum number of passes over the training set.
    pub epochs: usize,
    /// Number of examples per mini-batch.
    pub batch_size: usize,
    /// L2 regularization strength.
    pub l2_lambda: f64,
    /// Epochs without sufficient improvement before early stopping.
    pub patience: usize,
    /// Minimum loss decrease that counts as an improvement.
    pub min_delta: f64,
    /// Weight of the Hamming cost in the structured hinge margin.
    pub cost_weight: f64,
    /// Per-component gradient clipping threshold.
    pub grad_clip: f64,
    /// Margin-rescaling toggle.
    /// NOTE(review): not consulted by the training code in this file —
    /// cost augmentation is always applied; confirm intent.
    pub margin_rescaling: bool,
    /// Optional dynamic context-batching configuration.
    pub dynamic_batching: Option<DynamicBatchConfig>,
}
101
102impl Default for TrainingConfig {
103 fn default() -> Self {
104 Self {
105 learning_rate: 0.1,
106 epsilon: 1e-8,
107 epochs: 50,
108 batch_size: 16,
109 l2_lambda: 1e-4,
110 patience: 5,
111 min_delta: 1e-4,
112 cost_weight: 1.0,
113 grad_clip: 5.0,
114 margin_rescaling: true,
115 dynamic_batching: None,
116 }
117 }
118}
119
/// Controls how documents are windowed into training contexts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicBatchConfig {
    /// Upper bound on a single context window's length
    /// (unit is whatever the caller measures documents in — presumably
    /// tokens; confirm against callers).
    pub max_context_length: usize,
    /// Assumed average sentence length. NOTE(review): not consulted by
    /// any method in this file; presumably used by callers.
    pub avg_sentence_length: usize,
    /// Inclusive lower bound on the sampled number of contexts.
    pub min_contexts: usize,
    /// Inclusive upper bound on the sampled number of contexts.
    pub max_contexts: usize,
    /// Whether all contexts in one batch come from the same document.
    pub same_document: bool,
    /// Overlap between adjacent windows (same unit as `max_context_length`).
    pub window_overlap: usize,
}
145
146impl Default for DynamicBatchConfig {
147 fn default() -> Self {
148 Self {
149 max_context_length: 4000,
150 avg_sentence_length: 25,
151 min_contexts: 1,
152 max_contexts: 20,
153 same_document: true, window_overlap: 256,
155 }
156 }
157}
158
159impl DynamicBatchConfig {
160 pub fn cross_document() -> Self {
162 Self {
163 max_context_length: 512, avg_sentence_length: 25,
165 min_contexts: 2,
166 max_contexts: 10,
167 same_document: false,
168 window_overlap: 0,
169 }
170 }
171
172 pub fn long_document() -> Self {
174 Self {
175 max_context_length: 4000,
176 avg_sentence_length: 25,
177 min_contexts: 1,
178 max_contexts: 20,
179 same_document: true,
180 window_overlap: 256,
181 }
182 }
183
184 pub fn sample_num_contexts(&self, rng_seed: u64) -> usize {
188 let x = rng_seed
190 .wrapping_mul(6364136223846793005)
191 .wrapping_add(1442695040888963407);
192 let range = self.max_contexts - self.min_contexts + 1;
193 self.min_contexts + (x as usize % range)
194 }
195
196 pub fn context_length(&self, num_contexts: usize, doc_length: usize) -> usize {
200 let base_length = self.max_context_length.min(doc_length);
201 if num_contexts > 0 {
202 base_length / num_contexts
203 } else {
204 base_length
205 }
206 }
207}
208
/// One supervised example: the raw text, its candidate mentions, and gold
/// labels for all three tasks, keyed by mention index.
#[derive(Debug, Clone)]
pub struct TrainingExample {
    /// Source text.
    pub text: String,
    /// Candidate mentions; a mention's index is its position here.
    pub mentions: Vec<JointMention>,
    /// Gold NER type per mention index.
    pub gold_ner: HashMap<usize, EntityType>,
    /// Gold antecedent per mention index (`None` = starts a new cluster).
    pub gold_coref: HashMap<usize, Option<usize>>,
    /// Gold KB link per mention index (`None` = NIL / unlinkable).
    pub gold_links: HashMap<usize, Option<String>>,
}
227
228impl TrainingExample {
229 pub fn from_gold(
231 text: &str,
232 entities: &[Entity],
233 chains: &[CorefChain],
234 links: &[(usize, Option<String>)],
235 ) -> Self {
236 let mentions: Vec<JointMention> = entities
237 .iter()
238 .enumerate()
239 .map(|(i, e)| JointMention::from_entity(i, e, text))
240 .collect();
241
242 let mut gold_ner = HashMap::new();
243 for (i, e) in entities.iter().enumerate() {
244 gold_ner.insert(i, e.entity_type.clone());
245 }
246
247 let mut gold_coref = HashMap::new();
249 for chain in chains {
250 let mut prev_idx: Option<usize> = None;
251 for mention in &chain.mentions {
252 if let Some(idx) = mentions
254 .iter()
255 .position(|m| m.start == mention.start && m.end == mention.end)
256 {
257 gold_coref.insert(idx, prev_idx);
258 prev_idx = Some(idx);
259 }
260 }
261 }
262
263 let gold_links: HashMap<usize, Option<String>> = links.iter().cloned().collect();
264
265 Self {
266 text: text.to_string(),
267 mentions,
268 gold_ner,
269 gold_coref,
270 gold_links,
271 }
272 }
273
274 fn prior_score(&self, idx: usize) -> f64 {
276 self.mentions[idx]
277 .entity
278 .as_ref()
279 .map(|e| e.confidence)
280 .unwrap_or(0.0)
281 }
282
283 pub fn hamming_loss(
285 &self,
286 pred_ner: &HashMap<usize, EntityType>,
287 pred_coref: &HashMap<usize, Option<usize>>,
288 pred_links: &HashMap<usize, Option<String>>,
289 ) -> f64 {
290 let mut loss = 0.0;
291 let n = self.mentions.len() as f64;
292
293 for (idx, gold_type) in &self.gold_ner {
295 if let Some(pred_type) = pred_ner.get(idx) {
296 if pred_type != gold_type {
297 loss += 1.0;
298 }
299 } else {
300 loss += 1.0;
301 }
302 }
303
304 for (idx, gold_ante) in &self.gold_coref {
306 if let Some(pred_ante) = pred_coref.get(idx) {
307 if pred_ante != gold_ante {
308 loss += 1.0;
309 }
310 } else {
311 loss += 1.0;
312 }
313 }
314
315 for (idx, gold_link) in &self.gold_links {
317 if let Some(pred_link) = pred_links.get(idx) {
318 if pred_link != gold_link {
319 loss += 1.0;
320 }
321 } else {
322 loss += 1.0;
323 }
324 }
325
326 if n > 0.0 {
327 loss / n
328 } else {
329 0.0
330 }
331 }
332}
333
/// Full weight set for the joint NER / coreference / linking model:
/// three unary factor groups plus three pairwise cross-task factor groups.
#[derive(Debug, Clone, Default)]
pub struct JointWeights {
    /// Unary NER weights (per-type biases, context weight).
    pub unary_ner: UnaryNerWeights,
    /// Unary coreference weights (new-cluster bias, distance, string match).
    pub unary_coref: UnaryCorefWeights,
    /// Unary linking weights (NIL bias, prior weight).
    pub unary_link: UnaryLinkWeights,
    /// Link <-> NER agreement weights.
    pub link_ner: LinkNerWeights,
    /// Coref <-> NER agreement weights.
    pub coref_ner: CorefNerWeights,
    /// Coref <-> link agreement weights.
    pub coref_link: CorefLinkWeights,
}
354
/// Unary NER factor weights.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UnaryNerWeights {
    /// Per-entity-type bias, keyed by the type's label string
    /// (see `EntityType::as_label`).
    pub type_bias: HashMap<String, f64>,
    /// Weight on contextual evidence. NOTE(review): updated by the
    /// trainer but never read by the scoring code in this file.
    pub context_weight: f64,
}
363
/// Unary coreference factor weights.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UnaryCorefWeights {
    /// Score awarded for starting a new cluster (no antecedent).
    pub new_cluster_bias: f64,
    /// Penalty scale applied to ln(mention distance) for antecedent links.
    pub distance_decay: f64,
    /// Bonus for a case-insensitive exact text match with the antecedent.
    pub string_match: f64,
}
374
/// Unary entity-linking factor weights.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UnaryLinkWeights {
    /// Score awarded for predicting NIL (no link).
    pub nil_bias: f64,
    /// Scale on the mention's entity-confidence prior for non-NIL links.
    pub prior_weight: f64,
}
383
/// Per-parameter AdaGrad accumulator: the running sum of squared
/// gradients for one scalar weight.
#[derive(Debug, Clone, Default)]
struct AdaGradState {
    // Accumulated squared-gradient mass G.
    sum_sq_grad: f64,
}

impl AdaGradState {
    /// Folds `grad` into the accumulator and returns the weight delta
    /// `-lr / (sqrt(G) + epsilon) * grad` (negative: a descent step).
    fn update(&mut self, grad: f64, lr: f64, epsilon: f64) -> f64 {
        self.sum_sq_grad += grad * grad;
        let scaled = lr / (self.sum_sq_grad.sqrt() + epsilon);
        -(scaled * grad)
    }
}
402
/// AdaGrad accumulators, one per scalar weight in `JointWeights`;
/// `type_bias_states` is keyed by entity-type label like `type_bias`.
#[derive(Debug, Clone, Default)]
struct OptimizerState {
    // Unary NER.
    type_bias_states: HashMap<String, AdaGradState>,
    context_weight_state: AdaGradState,
    // Unary coreference.
    new_cluster_bias_state: AdaGradState,
    distance_decay_state: AdaGradState,
    string_match_state: AdaGradState,
    // Unary linking.
    nil_bias_state: AdaGradState,
    prior_weight_state: AdaGradState,
    // Coref <-> NER pairwise.
    type_match_state: AdaGradState,
    type_mismatch_state: AdaGradState,
    // Link <-> NER pairwise ("wiki" prefix distinguishes from coref<->NER).
    wiki_type_match_state: AdaGradState,
    wiki_type_mismatch_state: AdaGradState,
    // Coref <-> link pairwise.
    same_link_state: AdaGradState,
    different_link_state: AdaGradState,
}
425
/// Accumulated (sub)gradients for one batch, one field per scalar weight
/// in `JointWeights`; layout mirrors `OptimizerState`.
#[derive(Debug, Clone, Default)]
struct Gradients {
    // Unary NER.
    type_bias: HashMap<String, f64>,
    context_weight: f64,
    // Unary coreference.
    new_cluster_bias: f64,
    distance_decay: f64,
    string_match: f64,
    // Unary linking.
    nil_bias: f64,
    prior_weight: f64,
    // Coref <-> NER pairwise.
    type_match: f64,
    type_mismatch: f64,
    // Link <-> NER pairwise. NOTE(review): these receive only the L2
    // term — `accumulate_feature_gradients` never touches them.
    wiki_type_match: f64,
    wiki_type_mismatch: f64,
    // Coref <-> link pairwise.
    same_link: f64,
    different_link: f64,
}
452
453impl Gradients {
454 fn clip(&mut self, threshold: f64) {
455 let clip = |x: &mut f64| {
456 if *x > threshold {
457 *x = threshold;
458 } else if *x < -threshold {
459 *x = -threshold;
460 }
461 };
462
463 for v in self.type_bias.values_mut() {
464 clip(v);
465 }
466 clip(&mut self.context_weight);
467 clip(&mut self.new_cluster_bias);
468 clip(&mut self.distance_decay);
469 clip(&mut self.string_match);
470 clip(&mut self.nil_bias);
471 clip(&mut self.prior_weight);
472 clip(&mut self.type_match);
473 clip(&mut self.type_mismatch);
474 clip(&mut self.wiki_type_match);
475 clip(&mut self.wiki_type_mismatch);
476 clip(&mut self.same_link);
477 clip(&mut self.different_link);
478 }
479
480 fn add_l2_regularization(&mut self, weights: &JointWeights, lambda: f64) {
481 for (type_name, bias) in &weights.unary_ner.type_bias {
483 *self.type_bias.entry(type_name.clone()).or_insert(0.0) += lambda * bias;
484 }
485 self.context_weight += lambda * weights.unary_ner.context_weight;
486 self.new_cluster_bias += lambda * weights.unary_coref.new_cluster_bias;
487 self.distance_decay += lambda * weights.unary_coref.distance_decay;
488 self.string_match += lambda * weights.unary_coref.string_match;
489 self.nil_bias += lambda * weights.unary_link.nil_bias;
490 self.prior_weight += lambda * weights.unary_link.prior_weight;
491 self.type_match += lambda * weights.coref_ner.type_match;
492 self.type_mismatch += lambda * weights.coref_ner.type_mismatch;
493 self.wiki_type_match += lambda * weights.link_ner.type_match;
494 self.wiki_type_mismatch += lambda * weights.link_ner.type_mismatch;
495 self.same_link += lambda * weights.coref_link.same_entity;
496 self.different_link += lambda * weights.coref_link.different_entity;
497 }
498}
499
/// Mini-batch AdaGrad trainer for the joint model, using a structured
/// hinge loss with cost-augmented decoding.
pub struct Trainer {
    /// Training hyper-parameters.
    config: TrainingConfig,
    /// Current model weights.
    weights: JointWeights,
    /// Per-weight AdaGrad accumulators.
    optimizer: OptimizerState,
    /// The training set.
    examples: Vec<TrainingExample>,
    /// Average loss per completed epoch.
    loss_history: Vec<f64>,
}
517
impl Trainer {
    /// Creates a trainer with zero-initialized weights and fresh
    /// optimizer state.
    pub fn new(config: TrainingConfig) -> Self {
        Self {
            config,
            weights: JointWeights::default(),
            optimizer: OptimizerState::default(),
            examples: Vec::new(),
            loss_history: Vec::new(),
        }
    }

    /// Adds one training example to the training set.
    pub fn add_example(&mut self, example: TrainingExample) {
        self.examples.push(example);
    }

    /// Adds a collection of training examples.
    pub fn add_examples(&mut self, examples: impl IntoIterator<Item = TrainingExample>) {
        self.examples.extend(examples);
    }

    /// Returns the current model weights.
    pub fn get_weights(&self) -> &JointWeights {
        &self.weights
    }

    /// Returns the average loss recorded for each completed epoch.
    pub fn get_loss_history(&self) -> &[f64] {
        &self.loss_history
    }

    /// Runs training and returns the per-epoch average losses.
    ///
    /// Each epoch shuffles the example order (seeded with the epoch
    /// index, so runs are fully deterministic) and processes the data in
    /// batches of `config.batch_size`. Stops early once the average loss
    /// has not improved by at least `config.min_delta` for
    /// `config.patience` consecutive epochs.
    pub fn train(&mut self) -> Vec<f64> {
        let mut losses = Vec::new();
        let mut best_loss = f64::INFINITY;
        let mut patience_counter = 0;

        for epoch in 0..self.config.epochs {
            // Deterministic per-epoch shuffle: the epoch index is the seed.
            let mut indices: Vec<usize> = (0..self.examples.len()).collect();
            shuffle(&mut indices, epoch as u64);

            let mut epoch_loss = 0.0;
            let mut num_batches = 0;

            for batch_start in (0..self.examples.len()).step_by(self.config.batch_size) {
                let batch_end = (batch_start + self.config.batch_size).min(self.examples.len());
                let batch_indices = &indices[batch_start..batch_end];

                let batch_loss = self.train_batch(batch_indices);
                epoch_loss += batch_loss;
                num_batches += 1;
            }

            let avg_loss = if num_batches > 0 {
                epoch_loss / num_batches as f64
            } else {
                0.0
            };
            losses.push(avg_loss);
            self.loss_history.push(avg_loss);

            // Early stopping: reset patience on a real improvement,
            // otherwise count down and stop.
            if avg_loss < best_loss - self.config.min_delta {
                best_loss = avg_loss;
                patience_counter = 0;
            } else {
                patience_counter += 1;
                if patience_counter >= self.config.patience {
                    break;
                }
            }
        }

        losses
    }

    /// Processes one batch: sums per-example gradients, averages them,
    /// adds the L2 term, clips, and applies one AdaGrad step.
    /// Returns the batch's average loss.
    fn train_batch(&mut self, indices: &[usize]) -> f64 {
        let mut total_loss = 0.0;
        let mut accumulated_grads = Gradients::default();

        // Sum losses and gradients over the batch, field by field.
        for &idx in indices {
            let example = &self.examples[idx];
            let (loss, grads) = self.compute_loss_and_gradients(example);
            total_loss += loss;

            for (type_name, grad) in grads.type_bias {
                *accumulated_grads.type_bias.entry(type_name).or_insert(0.0) += grad;
            }
            accumulated_grads.context_weight += grads.context_weight;
            accumulated_grads.new_cluster_bias += grads.new_cluster_bias;
            accumulated_grads.distance_decay += grads.distance_decay;
            accumulated_grads.string_match += grads.string_match;
            accumulated_grads.nil_bias += grads.nil_bias;
            accumulated_grads.prior_weight += grads.prior_weight;
            accumulated_grads.type_match += grads.type_match;
            accumulated_grads.type_mismatch += grads.type_mismatch;
            accumulated_grads.wiki_type_match += grads.wiki_type_match;
            accumulated_grads.wiki_type_mismatch += grads.wiki_type_mismatch;
            accumulated_grads.same_link += grads.same_link;
            accumulated_grads.different_link += grads.different_link;
        }

        // Average the accumulated gradients over the batch size.
        let n = indices.len() as f64;
        if n > 0.0 {
            for v in accumulated_grads.type_bias.values_mut() {
                *v /= n;
            }
            accumulated_grads.context_weight /= n;
            accumulated_grads.new_cluster_bias /= n;
            accumulated_grads.distance_decay /= n;
            accumulated_grads.string_match /= n;
            accumulated_grads.nil_bias /= n;
            accumulated_grads.prior_weight /= n;
            accumulated_grads.type_match /= n;
            accumulated_grads.type_mismatch /= n;
            accumulated_grads.wiki_type_match /= n;
            accumulated_grads.wiki_type_mismatch /= n;
            accumulated_grads.same_link /= n;
            accumulated_grads.different_link /= n;
        }

        // L2 penalty, then clipping, then the parameter update.
        accumulated_grads.add_l2_regularization(&self.weights, self.config.l2_lambda);

        accumulated_grads.clip(self.config.grad_clip);

        self.apply_updates(&accumulated_grads);

        // `max(1.0)` guards the empty-batch case against division by zero.
        total_loss / n.max(1.0)
    }

    /// Structured hinge loss for one example:
    /// `max(0, score(pred) + cost_weight * Hamming(pred) - score(gold))`,
    /// where `pred` comes from cost-augmented decoding. On a margin
    /// violation the subgradient is `features(pred) - features(gold)`,
    /// realized as two accumulation passes with scales +1 and -1.
    ///
    /// NOTE(review): `config.margin_rescaling` is not consulted here —
    /// the cost term is always added; confirm that is intended.
    fn compute_loss_and_gradients(&self, example: &TrainingExample) -> (f64, Gradients) {
        let mut grads = Gradients::default();

        let gold_score = self.compute_score(
            example,
            &example.gold_ner,
            &example.gold_coref,
            &example.gold_links,
        );

        let (pred_ner, pred_coref, pred_links) = self.decode_with_cost(example);
        let pred_score = self.compute_score(example, &pred_ner, &pred_coref, &pred_links);

        let cost = example.hamming_loss(&pred_ner, &pred_coref, &pred_links);

        let margin = pred_score + self.config.cost_weight * cost - gold_score;
        let loss = if margin > 0.0 { margin } else { 0.0 };

        // Only a violated margin produces a (sub)gradient.
        if loss > 0.0 {
            self.accumulate_feature_gradients(
                &mut grads,
                example,
                &pred_ner,
                &pred_coref,
                &pred_links,
                1.0,
            );
            self.accumulate_feature_gradients(
                &mut grads,
                example,
                &example.gold_ner,
                &example.gold_coref,
                &example.gold_links,
                -1.0,
            );
        }

        (loss, grads)
    }

    /// Scores a complete joint assignment under the current weights:
    /// unary NER type biases, unary coref (new-cluster bias, log-distance
    /// penalty, string-match bonus), unary link (NIL bias, prior), plus
    /// the coref<->NER and coref<->link agreement factors.
    ///
    /// NOTE(review): the link<->NER factor (`weights.link_ner`) is never
    /// scored in this file even though its weights are updated.
    fn compute_score(
        &self,
        example: &TrainingExample,
        ner: &HashMap<usize, EntityType>,
        coref: &HashMap<usize, Option<usize>>,
        links: &HashMap<usize, Option<String>>,
    ) -> f64 {
        let mut score = 0.0;

        // Unary NER: one bias per assigned type; unknown labels score 0.
        for entity_type in ner.values() {
            let type_label = entity_type.as_label();
            if let Some(&bias) = self.weights.unary_ner.type_bias.get(type_label) {
                score += bias;
            }
        }

        // Unary coreference factors.
        for (idx, ante) in coref {
            if ante.is_none() {
                score += self.weights.unary_coref.new_cluster_bias;
            } else if let Some(ante_idx) = ante {
                // Log-distance penalty; distinct indices give dist >= 1,
                // so ln(dist) >= 0.
                let dist = (*idx as i64 - *ante_idx as i64).unsigned_abs() as f64;
                score -= self.weights.unary_coref.distance_decay * dist.ln();

                // Case-insensitive exact-text match bonus (bounds-checked).
                if idx < &example.mentions.len() && *ante_idx < example.mentions.len() {
                    let m_i = &example.mentions[*idx];
                    let m_j = &example.mentions[*ante_idx];
                    if m_i.text.to_lowercase() == m_j.text.to_lowercase() {
                        score += self.weights.unary_coref.string_match;
                    }
                }
            }
        }

        // Unary linking: NIL bias, or confidence-weighted prior.
        for (idx, link) in links {
            if link.is_none() {
                score += self.weights.unary_link.nil_bias;
            } else if *idx < example.mentions.len() {
                score += self.weights.unary_link.prior_weight * example.prior_score(*idx);
            }
        }

        // Coref <-> NER agreement on linked mention pairs.
        for (idx, ante) in coref {
            if let Some(ante_idx) = ante {
                if let (Some(type_i), Some(type_j)) = (ner.get(idx), ner.get(ante_idx)) {
                    if type_i == type_j {
                        score += self.weights.coref_ner.type_match;
                    } else {
                        score += self.weights.coref_ner.type_mismatch;
                    }
                }
            }
        }

        // Coref <-> link agreement on linked mention pairs.
        for (idx, ante) in coref {
            if let Some(ante_idx) = ante {
                if let (Some(link_i), Some(link_j)) = (links.get(idx), links.get(ante_idx)) {
                    if link_i == link_j {
                        score += self.weights.coref_link.same_entity;
                    } else {
                        score += self.weights.coref_link.different_entity;
                    }
                }
            }
        }

        score
    }

    /// Cost-augmented (loss-augmented) decoding used to find the margin
    /// violator.
    ///
    /// NER and link decisions copy the gold label when one exists
    /// (falling back to the mention's own type, or NIL), so the cost
    /// augmentation effectively drives only the greedy antecedent search:
    /// every non-gold antecedent candidate gets a `cost_weight` bonus,
    /// pushing the decoder toward high-scoring mistakes.
    fn decode_with_cost(&self, example: &TrainingExample) -> DecodeResult {
        let mut pred_ner = HashMap::new();
        let mut pred_coref = HashMap::new();
        let mut pred_links = HashMap::new();

        for (idx, mention) in example.mentions.iter().enumerate() {
            // NER: gold type if annotated, else the mention's own type.
            if let Some(gold_type) = example.gold_ner.get(&idx) {
                pred_ner.insert(idx, gold_type.clone());
            } else if let Some(ref t) = mention.entity_type {
                pred_ner.insert(idx, t.clone());
            }

            // Greedy antecedent search; the baseline is "new cluster".
            let mut best_ante: Option<usize> = None;
            let mut best_score = self.weights.unary_coref.new_cluster_bias;

            for ante_idx in 0..idx {
                let mut ante_score = 0.0;

                // Log-distance penalty (dist >= 1, so ln >= 0; the
                // `.max(0.0)` is a defensive clamp).
                let dist = (idx - ante_idx) as f64;
                ante_score -= self.weights.unary_coref.distance_decay * dist.ln().max(0.0);

                if mention.text.to_lowercase() == example.mentions[ante_idx].text.to_lowercase() {
                    ante_score += self.weights.unary_coref.string_match;
                }

                if let (Some(type_i), Some(type_j)) = (pred_ner.get(&idx), pred_ner.get(&ante_idx))
                {
                    if type_i == type_j {
                        ante_score += self.weights.coref_ner.type_match;
                    } else {
                        ante_score += self.weights.coref_ner.type_mismatch;
                    }
                }

                // Cost augmentation: reward disagreement with gold.
                if let Some(gold_ante) = example.gold_coref.get(&idx) {
                    if gold_ante != &Some(ante_idx) {
                        ante_score += self.config.cost_weight;
                    }
                }

                if ante_score > best_score {
                    best_score = ante_score;
                    best_ante = Some(ante_idx);
                }
            }
            pred_coref.insert(idx, best_ante);

            // Linking: copy gold when present, otherwise NIL.
            if let Some(gold_link) = example.gold_links.get(&idx) {
                pred_links.insert(idx, gold_link.clone());
            } else {
                pred_links.insert(idx, None);
            }
        }

        (pred_ner, pred_coref, pred_links)
    }

    /// Adds `scale` times the feature vector of the given assignment into
    /// `grads`. Called with +1 for the predicted assignment and -1 for
    /// gold, so the sum is the structured-hinge subgradient.
    ///
    /// Mirrors `compute_score` feature-for-feature; like it, touches no
    /// link<->NER (`wiki_type_*`) components.
    fn accumulate_feature_gradients(
        &self,
        grads: &mut Gradients,
        example: &TrainingExample,
        ner: &HashMap<usize, EntityType>,
        coref: &HashMap<usize, Option<usize>>,
        links: &HashMap<usize, Option<String>>,
        scale: f64,
    ) {
        // Unary NER features: count of each assigned type label.
        for entity_type in ner.values() {
            let type_label = entity_type.as_label().to_string();
            *grads.type_bias.entry(type_label).or_insert(0.0) += scale;
        }

        // Unary coreference features.
        for (idx, ante) in coref {
            if ante.is_none() {
                grads.new_cluster_bias += scale;
            } else if let Some(ante_idx) = ante {
                let dist = (*idx as i64 - *ante_idx as i64).unsigned_abs() as f64;
                // Negative: the score subtracts decay * ln(dist).
                grads.distance_decay -= scale * dist.ln();

                if idx < &example.mentions.len() && *ante_idx < example.mentions.len() {
                    let m_i = &example.mentions[*idx];
                    let m_j = &example.mentions[*ante_idx];
                    if m_i.text.to_lowercase() == m_j.text.to_lowercase() {
                        grads.string_match += scale;
                    }
                }
            }
        }

        // Unary linking features.
        for (idx, link) in links {
            if link.is_none() {
                grads.nil_bias += scale;
            } else if *idx < example.mentions.len() {
                grads.prior_weight += scale * example.prior_score(*idx);
            }
        }

        // Coref <-> NER agreement features.
        for (idx, ante) in coref {
            if let Some(ante_idx) = ante {
                if let (Some(type_i), Some(type_j)) = (ner.get(idx), ner.get(ante_idx)) {
                    if type_i == type_j {
                        grads.type_match += scale;
                    } else {
                        grads.type_mismatch += scale;
                    }
                }
            }
        }

        // Coref <-> link agreement features.
        for (idx, ante) in coref {
            if let Some(ante_idx) = ante {
                if let (Some(link_i), Some(link_j)) = (links.get(idx), links.get(ante_idx)) {
                    if link_i == link_j {
                        grads.same_link += scale;
                    } else {
                        grads.different_link += scale;
                    }
                }
            }
        }
    }

    /// Applies one AdaGrad step per weight using the already averaged,
    /// regularized, and clipped batch gradients.
    fn apply_updates(&mut self, grads: &Gradients) {
        let lr = self.config.learning_rate;
        let eps = self.config.epsilon;

        // Per-type biases: state and weight entries are created on demand.
        for (type_name, &grad) in &grads.type_bias {
            let state = self
                .optimizer
                .type_bias_states
                .entry(type_name.clone())
                .or_default();
            let delta = state.update(grad, lr, eps);
            *self
                .weights
                .unary_ner
                .type_bias
                .entry(type_name.clone())
                .or_insert(0.0) += delta;
        }

        let delta = self
            .optimizer
            .context_weight_state
            .update(grads.context_weight, lr, eps);
        self.weights.unary_ner.context_weight += delta;

        let delta = self
            .optimizer
            .new_cluster_bias_state
            .update(grads.new_cluster_bias, lr, eps);
        self.weights.unary_coref.new_cluster_bias += delta;

        let delta = self
            .optimizer
            .distance_decay_state
            .update(grads.distance_decay, lr, eps);
        self.weights.unary_coref.distance_decay += delta;

        let delta = self
            .optimizer
            .string_match_state
            .update(grads.string_match, lr, eps);
        self.weights.unary_coref.string_match += delta;

        let delta = self
            .optimizer
            .nil_bias_state
            .update(grads.nil_bias, lr, eps);
        self.weights.unary_link.nil_bias += delta;

        let delta = self
            .optimizer
            .prior_weight_state
            .update(grads.prior_weight, lr, eps);
        self.weights.unary_link.prior_weight += delta;

        let delta = self
            .optimizer
            .type_match_state
            .update(grads.type_match, lr, eps);
        self.weights.coref_ner.type_match += delta;

        let delta = self
            .optimizer
            .type_mismatch_state
            .update(grads.type_mismatch, lr, eps);
        self.weights.coref_ner.type_mismatch += delta;

        let delta = self
            .optimizer
            .wiki_type_match_state
            .update(grads.wiki_type_match, lr, eps);
        self.weights.link_ner.type_match += delta;

        let delta =
            self.optimizer
                .wiki_type_mismatch_state
                .update(grads.wiki_type_mismatch, lr, eps);
        self.weights.link_ner.type_mismatch += delta;

        let delta = self
            .optimizer
            .same_link_state
            .update(grads.same_link, lr, eps);
        self.weights.coref_link.same_entity += delta;

        let delta = self
            .optimizer
            .different_link_state
            .update(grads.different_link, lr, eps);
        self.weights.coref_link.different_entity += delta;
    }
}
1006
/// Deterministic in-place Fisher–Yates shuffle driven by a 64-bit LCG
/// seeded with `seed`; identical seeds always yield identical
/// permutations (used to make epoch ordering reproducible).
fn shuffle<T>(slice: &mut [T], seed: u64) {
    let mut state = seed;
    // One LCG step per draw; result reduced into [0, i].
    let mut draw = |i: usize| {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (state as usize) % (i + 1)
    };
    for i in (1..slice.len()).rev() {
        let j = draw(i);
        slice.swap(i, j);
    }
}
1020
#[cfg(test)]
mod tests {
    use super::*;

    // Default hyper-parameters are what the struct documents.
    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert_eq!(config.epochs, 50);
        assert!((config.learning_rate - 0.1).abs() < 1e-6);
    }

    // A fresh trainer starts with no examples.
    #[test]
    fn test_trainer_creation() {
        let trainer = Trainer::new(TrainingConfig::default());
        assert!(trainer.examples.is_empty());
    }

    // AdaGrad deltas oppose the gradient and shrink as the squared-
    // gradient accumulator grows.
    #[test]
    fn test_adagrad_state() {
        let mut state = AdaGradState::default();

        // First step moves against a positive gradient.
        let delta1 = state.update(1.0, 0.1, 1e-8);
        assert!(delta1 < 0.0);
        // Second identical step is smaller: the accumulator has grown.
        let delta2 = state.update(1.0, 0.1, 1e-8);
        assert!(delta2.abs() < delta1.abs());
    }

    // clip() saturates components at +/- threshold.
    #[test]
    fn test_gradient_clipping() {
        let mut grads = Gradients {
            context_weight: 100.0,
            type_match: -100.0,
            ..Default::default()
        };

        grads.clip(5.0);

        assert!((grads.context_weight - 5.0).abs() < 1e-6);
        assert!((grads.type_match - (-5.0)).abs() < 1e-6);
    }

    // Hamming loss is 0 for an exact NER match, positive for a wrong type.
    #[test]
    fn test_training_example_hamming_loss() {
        use crate::joint::MentionKind;

        let mentions = vec![JointMention {
            idx: 0,
            text: "Alice".to_string(),
            head: "Alice".to_string(),
            start: 0,
            end: 5,
            mention_kind: MentionKind::Proper,
            entity_type: Some(EntityType::Person),
            entity: Some(Entity::new("Alice", EntityType::Person, 0, 5, 0.9)),
        }];

        let mut gold_ner = HashMap::new();
        gold_ner.insert(0, EntityType::Person);

        let example = TrainingExample {
            text: "Alice".to_string(),
            mentions,
            gold_ner,
            gold_coref: HashMap::new(),
            gold_links: HashMap::new(),
        };

        // Correct prediction: zero loss.
        let mut pred_ner = HashMap::new();
        pred_ner.insert(0, EntityType::Person);
        let loss = example.hamming_loss(&pred_ner, &HashMap::new(), &HashMap::new());
        assert!((loss - 0.0).abs() < 1e-6);

        // Wrong type: positive loss.
        let mut wrong_ner = HashMap::new();
        wrong_ner.insert(0, EntityType::Organization);
        let loss = example.hamming_loss(&wrong_ner, &HashMap::new(), &HashMap::new());
        assert!(loss > 0.0);
    }

    // End-to-end smoke test: training on one two-mention example runs
    // and produces finite, bounded losses.
    #[test]
    fn test_trainer_single_example() {
        use crate::joint::MentionKind;

        let mut trainer = Trainer::new(TrainingConfig {
            epochs: 5,
            batch_size: 1,
            ..Default::default()
        });

        let mentions = vec![
            JointMention {
                idx: 0,
                text: "Alice".to_string(),
                head: "Alice".to_string(),
                start: 0,
                end: 5,
                mention_kind: MentionKind::Proper,
                entity_type: Some(EntityType::Person),
                entity: Some(Entity::new("Alice", EntityType::Person, 0, 5, 0.9)),
            },
            JointMention {
                idx: 1,
                text: "she".to_string(),
                head: "she".to_string(),
                start: 17,
                end: 20,
                mention_kind: MentionKind::Pronominal,
                entity_type: Some(EntityType::Person),
                entity: Some(Entity::new("she", EntityType::Person, 17, 20, 0.8)),
            },
        ];

        let mut gold_ner = HashMap::new();
        gold_ner.insert(0, EntityType::Person);
        gold_ner.insert(1, EntityType::Person);

        // "Alice" starts a cluster; "she" corefers with it.
        let mut gold_coref = HashMap::new();
        gold_coref.insert(0, None);
        gold_coref.insert(1, Some(0));

        let example = TrainingExample {
            text: "Alice went home. she was tired.".to_string(),
            mentions,
            gold_ner,
            gold_coref,
            gold_links: HashMap::new(),
        };

        trainer.add_example(example);
        let losses = trainer.train();

        assert!(!losses.is_empty());
        assert!(losses.iter().all(|&l| l < 1000.0));
    }

    // Same seed must produce the same permutation.
    #[test]
    fn test_shuffle_deterministic() {
        let mut a = vec![1, 2, 3, 4, 5];
        let mut b = vec![1, 2, 3, 4, 5];

        shuffle(&mut a, 42);
        shuffle(&mut b, 42);

        assert_eq!(a, b);
    }

    #[test]
    fn test_dynamic_batch_config_default() {
        let config = DynamicBatchConfig::default();
        assert_eq!(config.max_context_length, 4000);
        assert_eq!(config.avg_sentence_length, 25);
        assert!(config.same_document);
    }

    #[test]
    fn test_dynamic_batch_config_cross_document() {
        let config = DynamicBatchConfig::cross_document();
        assert!(!config.same_document);
        assert_eq!(config.min_contexts, 2);
        assert_eq!(config.window_overlap, 0);
    }

    #[test]
    fn test_dynamic_batch_config_long_document() {
        let config = DynamicBatchConfig::long_document();
        assert!(config.same_document);
        assert_eq!(config.window_overlap, 256);
    }

    // Sampling is deterministic per seed, stays in range, and varies
    // across seeds (unless the range is degenerate).
    #[test]
    fn test_dynamic_batch_sample_contexts() {
        let config = DynamicBatchConfig {
            min_contexts: 2,
            max_contexts: 10,
            ..Default::default()
        };

        let n1 = config.sample_num_contexts(42);
        let n2 = config.sample_num_contexts(42);
        assert_eq!(n1, n2);

        assert!((2..=10).contains(&n1));

        let n3 = config.sample_num_contexts(123);
        assert!(n1 != n3 || config.max_contexts == config.min_contexts);
    }

    // Context length divides the capped base length across contexts.
    #[test]
    fn test_dynamic_batch_context_length() {
        let config = DynamicBatchConfig {
            max_context_length: 4000,
            ..Default::default()
        };

        assert_eq!(config.context_length(1, 10000), 4000);

        assert_eq!(config.context_length(4, 10000), 1000);

        assert_eq!(config.context_length(2, 500), 250);
    }

    #[test]
    fn test_training_config_with_dynamic_batching() {
        let config = TrainingConfig {
            dynamic_batching: Some(DynamicBatchConfig::cross_document()),
            ..Default::default()
        };

        assert!(config.dynamic_batching.is_some());
        let db = config.dynamic_batching.unwrap();
        assert!(!db.same_document);
    }
}