use crate::checkpoint::EarlyStopping;
use crate::distributed::{GradientSync, LocalGradientSync};
use crate::error::ModelResult;
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
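/// A source of supervised training data: feature rows paired with scalar targets.
///
/// Implementors expose the dataset dimensions and gather batches by sample index;
/// a deterministic Fisher-Yates shuffle is provided as a default method.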
pub trait DataProvider: Send {
    /// Total number of samples in the dataset.
    fn num_samples(&self) -> usize;

    /// Number of features per sample.
    fn num_features(&self) -> usize;

    /// Gathers the samples at `indices` into a `(features, targets)` batch.
    fn get_batch(&self, indices: &[usize]) -> (Array2<f32>, Array1<f32>);

    /// Returns a deterministic permutation of `0..num_samples()` derived from `rng_seed`.
    fn shuffle_indices(&self, rng_seed: u64) -> Vec<usize> {
        let n = self.num_samples();
        let mut indices: Vec<usize> = (0..n).collect();
        let mut state = rng_seed.wrapping_add(1);
        // Fisher-Yates shuffle driven by a 64-bit LCG.
        for i in (1..n).rev() {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            let j = (state >> 33) as usize % (i + 1);
            indices.swap(i, j);
        }
        indices
    }
}
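/// In-memory [`DataProvider`] backed by dense `ndarray` feature and target arrays.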
pub struct ArrayDataProvider {
    features: Array2<f32>,
    targets: Array1<f32>,
}

impl ArrayDataProvider {
    pub fn new(features: Array2<f32>, targets: Array1<f32>) -> Self {
        debug_assert_eq!(
            features.nrows(),
            targets.len(),
            "features and targets must have the same number of samples"
        );
        Self { features, targets }
    }
}

impl DataProvider for ArrayDataProvider {
    fn num_samples(&self) -> usize {
        self.features.nrows()
    }

    fn num_features(&self) -> usize {
        self.features.ncols()
    }

    fn get_batch(&self, indices: &[usize]) -> (Array2<f32>, Array1<f32>) {
        let nf = self.num_features();
        let nb = indices.len();

        let mut feat = Array2::<f32>::zeros((nb, nf));
        let mut tgt = Array1::<f32>::zeros(nb);

        for (batch_idx, &sample_idx) in indices.iter().enumerate() {
            // Clamp out-of-range indices to the last sample instead of panicking.
            let sample_idx = sample_idx.min(self.features.nrows().saturating_sub(1));
            feat.row_mut(batch_idx)
                .assign(&self.features.row(sample_idx));
            tgt[batch_idx] = self.targets[sample_idx];
        }

        (feat, tgt)
    }
}
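/// First-order optimizer that updates a linear model's weights and bias in place
/// from per-batch gradients.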
pub trait Optimizer: Send {
    /// Applies one update to `weights` and `bias` from the supplied gradients.
    fn step(
        &mut self,
        weights: &mut Array1<f32>,
        bias: &mut f32,
        weight_grad: &Array1<f32>,
        bias_grad: f32,
    );

    /// Learning rate currently used by the optimizer.
    fn learning_rate(&self) -> f32;

    /// Overrides the learning rate (called by the training loop after each scheduler step).
    fn set_learning_rate(&mut self, lr: f32);
}
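/// Plain stochastic gradient descent: `param -= lr * grad`.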
pub struct SgdOptimizer {
    lr: f32,
}

impl SgdOptimizer {
    pub fn new(lr: f32) -> Self {
        Self { lr }
    }
}

impl Optimizer for SgdOptimizer {
    fn step(
        &mut self,
        weights: &mut Array1<f32>,
        bias: &mut f32,
        weight_grad: &Array1<f32>,
        bias_grad: f32,
    ) {
        *weights = weights.clone() - self.lr * weight_grad;
        *bias -= self.lr * bias_grad;
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}
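/// Adam optimizer with bias-corrected first- and second-moment estimates.
/// Weight moment buffers are allocated lazily on the first `step` call.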
pub struct AdamOptimizer {
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    /// First-moment estimate for the weights, allocated on the first step.
    m_w: Option<Array1<f32>>,
    /// Second-moment estimate for the weights, allocated on the first step.
    v_w: Option<Array1<f32>>,
    /// First-moment estimate for the bias.
    m_b: f32,
    /// Second-moment estimate for the bias.
    v_b: f32,
    /// Step counter used for bias correction.
    t: u64,
}

impl AdamOptimizer {
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            m_w: None,
            v_w: None,
            m_b: 0.0,
            v_b: 0.0,
            t: 0,
        }
    }
}

impl Optimizer for AdamOptimizer {
    fn step(
        &mut self,
        weights: &mut Array1<f32>,
        bias: &mut f32,
        weight_grad: &Array1<f32>,
        bias_grad: f32,
    ) {
        self.t += 1;
        let t = self.t as f32;

        let n = weights.len();
        let m_w = self.m_w.get_or_insert_with(|| Array1::<f32>::zeros(n));
        let v_w = self.v_w.get_or_insert_with(|| Array1::<f32>::zeros(n));

        // Update biased first and second moment estimates for the weights.
        *m_w = self.beta1 * m_w.clone() + (1.0 - self.beta1) * weight_grad;
        let grad_sq = weight_grad.mapv(|x| x * x);
        *v_w = self.beta2 * v_w.clone() + (1.0 - self.beta2) * &grad_sq;

        // Bias correction.
        let bc1 = 1.0 - self.beta1.powf(t);
        let bc2 = 1.0 - self.beta2.powf(t);
        let m_hat = m_w.clone() / bc1;
        let v_hat = v_w.clone() / bc2;

        *weights = weights.clone() - self.lr * &m_hat / (v_hat.mapv(|x| x.sqrt()) + self.epsilon);

        // Same update applied to the scalar bias.
        self.m_b = self.beta1 * self.m_b + (1.0 - self.beta1) * bias_grad;
        self.v_b = self.beta2 * self.v_b + (1.0 - self.beta2) * bias_grad * bias_grad;
        let mb_hat = self.m_b / bc1;
        let vb_hat = self.v_b / bc2;
        *bias -= self.lr * mb_hat / (vb_hat.sqrt() + self.epsilon);
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}
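/// Learning-rate schedule consulted once per epoch by the training loop.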
pub trait LrScheduler: Send {
    /// Advances the schedule at the end of `epoch` (with the validation loss when
    /// available) and returns the learning rate to use next.
    fn step(&mut self, epoch: usize, val_loss: Option<f32>) -> f32;

    /// Learning rate currently in effect.
    fn current_lr(&self) -> f32;
}
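/// Keeps the learning rate fixed for the entire run.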
pub struct ConstantScheduler {
    lr: f32,
}

impl ConstantScheduler {
    pub fn new(lr: f32) -> Self {
        Self { lr }
    }
}

impl LrScheduler for ConstantScheduler {
    fn step(&mut self, _epoch: usize, _val_loss: Option<f32>) -> f32 {
        self.lr
    }

    fn current_lr(&self) -> f32 {
        self.lr
    }
}
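/// Multiplies the learning rate by `decay_rate` each epoch, bounded below by `min_lr`.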
pub struct ExponentialScheduler {
    decay_rate: f32,
    min_lr: f32,
    current: f32,
}

impl ExponentialScheduler {
    pub fn new(initial_lr: f32, decay_rate: f32, min_lr: f32) -> Self {
        Self {
            decay_rate,
            min_lr,
            current: initial_lr,
        }
    }
}

impl LrScheduler for ExponentialScheduler {
    fn step(&mut self, _epoch: usize, _val_loss: Option<f32>) -> f32 {
        self.current = (self.current * self.decay_rate).max(self.min_lr);
        self.current
    }

    fn current_lr(&self) -> f32 {
        self.current
    }
}
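/// Scales the learning rate by `gamma` every `step_size` epochs.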
pub struct StepDecayScheduler {
    step_size: usize,
    gamma: f32,
    current: f32,
}

impl StepDecayScheduler {
    pub fn new(initial_lr: f32, step_size: usize, gamma: f32) -> Self {
        Self {
            step_size,
            gamma,
            current: initial_lr,
        }
    }
}

impl LrScheduler for StepDecayScheduler {
    fn step(&mut self, epoch: usize, _val_loss: Option<f32>) -> f32 {
        if epoch > 0 && epoch.is_multiple_of(self.step_size) {
            self.current *= self.gamma;
        }
        self.current
    }

    fn current_lr(&self) -> f32 {
        self.current
    }
}
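/// Hooks invoked by [`TrainingLoop`] at epoch, batch, and run boundaries.
/// Every method has an empty default implementation, so implementors only
/// override the events they care about.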
pub trait TrainingCallback: Send {
    /// Called before each epoch begins.
    fn on_epoch_start(&mut self, _epoch: usize) {}

    /// Called after each epoch with the mean training loss and optional validation loss.
    fn on_epoch_end(&mut self, _epoch: usize, _train_loss: f32, _val_loss: Option<f32>) {}

    /// Called after each mini-batch with that batch's loss.
    fn on_batch_end(&mut self, _epoch: usize, _batch: usize, _loss: f32) {}

    /// Called once when training finishes, including early stops.
    fn on_training_end(&mut self, _result: &TrainingResult) {}
}
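/// Summary of a completed training run: per-epoch loss history plus the best
/// validation epoch observed.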
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingResult {
    /// Mean training loss recorded for each epoch that ran.
    pub train_losses: Vec<f32>,
    /// Validation loss per epoch; `None` when no validation split was used.
    pub val_losses: Vec<Option<f32>>,
    /// Epoch with the lowest validation loss.
    pub best_epoch: usize,
    /// Lowest validation loss observed, if validation was enabled.
    pub best_val_loss: Option<f32>,
    /// Number of epochs actually trained; may be fewer than `max_epochs` when
    /// early stopping triggers.
    pub epochs_trained: usize,
    /// Training loss of the final epoch.
    pub final_train_loss: f32,
}
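/// Hyperparameters controlling the training loop.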
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Maximum number of epochs to train.
    pub max_epochs: usize,
    /// Number of samples per mini-batch.
    pub batch_size: usize,
    /// Initial learning rate handed to the optimizer.
    pub learning_rate: f32,
    /// Fraction of samples held out for validation.
    pub val_fraction: f32,
    /// Seed for the deterministic shuffles and the train/validation split.
    pub rng_seed: u64,
    /// Log progress every N epochs; `0` disables logging.
    pub log_every_n_epochs: usize,
}
impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            max_epochs: 100,
            batch_size: 32,
            learning_rate: 0.01,
            val_fraction: 0.1,
            rng_seed: 42,
            log_every_n_epochs: 10,
        }
    }
}
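/// Forward and backward pass for a linear model under mean squared error.
///
/// For predictions `y_hat = X . w + b`, returns `(loss, d_loss/d_w, d_loss/d_b)`
/// where `loss = mean((y_hat - y)^2)`.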
fn mse_linear_backward(
    features: &Array2<f32>,
    targets: &Array1<f32>,
    weights: &Array1<f32>,
    bias: f32,
) -> (f32, Array1<f32>, f32) {
    let n = features.nrows() as f32;
    let nf = features.ncols();

    // Forward pass: y_hat = X . w + b
    let mut predictions = Array1::<f32>::zeros(features.nrows());
    for (i, row) in features.rows().into_iter().enumerate() {
        let dot: f32 = row.iter().zip(weights.iter()).map(|(&x, &w)| x * w).sum();
        predictions[i] = dot + bias;
    }

    let residuals = &predictions - targets;

    // Mean squared error.
    let loss = residuals.iter().map(|&r| r * r).sum::<f32>() / n;

    // d(loss)/d(w_j) = 2/n * sum_i x_ij * r_i
    let mut weight_grad = Array1::<f32>::zeros(nf);
    for (i, row) in features.rows().into_iter().enumerate() {
        let r = residuals[i];
        for (j, &x) in row.iter().enumerate() {
            weight_grad[j] += 2.0 * x * r / n;
        }
    }

    // d(loss)/d(b) = 2/n * sum_i r_i
    let bias_grad = 2.0 * residuals.sum() / n;

    (loss, weight_grad, bias_grad)
}
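/// Deterministically splits `0..n` into shuffled train and validation index sets,
/// assigning roughly `val_fraction` of the samples to validation.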
fn train_val_split(n: usize, val_fraction: f32, rng_seed: u64) -> (Vec<usize>, Vec<usize>) {
    let val_fraction = val_fraction.clamp(0.0, 0.99);
    let all_indices: Vec<usize> = lcg_shuffle((0..n).collect(), rng_seed);
    let val_count = ((n as f32 * val_fraction).round() as usize).min(n.saturating_sub(1));
    let train_count = n - val_count;
    let train = all_indices[..train_count].to_vec();
    let val = all_indices[train_count..].to_vec();
    (train, val)
}
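/// Fisher-Yates shuffle driven by a 64-bit linear congruential generator so that
/// results are reproducible from a seed without an external RNG dependency.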
fn lcg_shuffle(mut v: Vec<usize>, seed: u64) -> Vec<usize> {
    let n = v.len();
    let mut state = seed.wrapping_add(1);
    for i in (1..n).rev() {
        state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let j = (state >> 33) as usize % (i + 1);
        v.swap(i, j);
    }
    v
}
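/// Mini-batch training loop for a linear model with pluggable data source,
/// optimizer, learning-rate scheduler, callbacks, early stopping, and gradient
/// synchronization.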
pub struct TrainingLoop {
    config: TrainingConfig,
    callbacks: Vec<Box<dyn TrainingCallback>>,
    gradient_sync: Box<dyn GradientSync>,
}

impl TrainingLoop {
    /// Creates a training loop with the given configuration, using `LocalGradientSync`
    /// for gradient synchronization.
    pub fn new(config: TrainingConfig) -> Self {
        Self {
            config,
            callbacks: Vec::new(),
            gradient_sync: Box::new(LocalGradientSync::new()),
        }
    }

    /// Registers a callback to be notified of training events.
    pub fn add_callback(&mut self, cb: Box<dyn TrainingCallback>) {
        self.callbacks.push(cb);
    }

    /// Replaces the gradient synchronization strategy, e.g. for multi-worker training.
    pub fn with_gradient_sync(mut self, sync: Box<dyn GradientSync>) -> Self {
        self.gradient_sync = sync;
        self
    }
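    /// Trains `model_weights` and `model_bias` in place on `data` until
    /// `max_epochs` is reached or early stopping triggers, returning the
    /// per-epoch loss history.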
    pub fn run(
        &mut self,
        data: &dyn DataProvider,
        optimizer: &mut dyn Optimizer,
        lr_scheduler: &mut dyn LrScheduler,
        mut early_stopping: Option<&mut EarlyStopping>,
        model_weights: &mut Array1<f32>,
        model_bias: &mut f32,
    ) -> ModelResult<TrainingResult> {
        let n = data.num_samples();
        let (train_indices, val_indices) =
            train_val_split(n, self.config.val_fraction, self.config.rng_seed);

        optimizer.set_learning_rate(self.config.learning_rate);

        let mut train_losses: Vec<f32> = Vec::with_capacity(self.config.max_epochs);
        let mut val_losses: Vec<Option<f32>> = Vec::with_capacity(self.config.max_epochs);
        let mut best_val_loss: Option<f32> = None;
        let mut best_epoch = 0_usize;

        'epoch_loop: for epoch in 0..self.config.max_epochs {
            for cb in self.callbacks.iter_mut() {
                cb.on_epoch_start(epoch);
            }

            // Re-shuffle the training indices every epoch, seeded deterministically.
            let shuffled = lcg_shuffle(
                train_indices.clone(),
                self.config.rng_seed.wrapping_add(epoch as u64),
            );

            let batch_size = self.config.batch_size.max(1);
            let mut epoch_loss_sum = 0.0_f32;
            let mut epoch_batches = 0_usize;

            let mut batch_idx = 0_usize;
            let mut offset = 0_usize;
            while offset < shuffled.len() {
                let end = (offset + batch_size).min(shuffled.len());
                let batch_sample_ids = &shuffled[offset..end];

                let (batch_feat, batch_tgt) = data.get_batch(batch_sample_ids);

                let (loss, mut weight_grad, bias_grad) =
                    mse_linear_backward(&batch_feat, &batch_tgt, model_weights, *model_bias);

                // Synchronize the weight gradient across workers before applying the update.
                self.gradient_sync.sync_gradients(&mut weight_grad)?;

                optimizer.step(model_weights, model_bias, &weight_grad, bias_grad);

                epoch_loss_sum += loss;
                epoch_batches += 1;

                for cb in self.callbacks.iter_mut() {
                    cb.on_batch_end(epoch, batch_idx, loss);
                }

                offset += batch_size;
                batch_idx += 1;
            }

            let epoch_train_loss = if epoch_batches > 0 {
                epoch_loss_sum / epoch_batches as f32
            } else {
                0.0
            };

            // Evaluate on the held-out split (the gradients are discarded).
            let epoch_val_loss = if !val_indices.is_empty() {
                let (val_feat, val_tgt) = data.get_batch(&val_indices);
                let (vloss, _, _) =
                    mse_linear_backward(&val_feat, &val_tgt, model_weights, *model_bias);
                Some(vloss)
            } else {
                None
            };

            train_losses.push(epoch_train_loss);
            val_losses.push(epoch_val_loss);

            if let Some(vl) = epoch_val_loss {
                if best_val_loss.is_none_or(|best| vl < best) {
                    best_val_loss = Some(vl);
                    best_epoch = epoch;
                }
            }

            let new_lr = lr_scheduler.step(epoch, epoch_val_loss);
            optimizer.set_learning_rate(new_lr);

            // Early stopping monitors the validation loss, falling back to the training loss.
            if let Some(ref mut es) = early_stopping {
                let check_loss = epoch_val_loss.unwrap_or(epoch_train_loss);
                if es.should_stop(check_loss) {
                    for cb in self.callbacks.iter_mut() {
                        cb.on_epoch_end(epoch, epoch_train_loss, epoch_val_loss);
                    }
                    break 'epoch_loop;
                }
            }

            if self.config.log_every_n_epochs > 0 && epoch % self.config.log_every_n_epochs == 0 {
                if let Some(vl) = epoch_val_loss {
                    tracing::info!(
                        "Epoch {:>4} | train_loss={:.6} | val_loss={:.6} | lr={:.6}",
                        epoch,
                        epoch_train_loss,
                        vl,
                        lr_scheduler.current_lr()
                    );
                } else {
                    tracing::info!(
                        "Epoch {:>4} | train_loss={:.6} | lr={:.6}",
                        epoch,
                        epoch_train_loss,
                        lr_scheduler.current_lr()
                    );
                }
            }

            for cb in self.callbacks.iter_mut() {
                cb.on_epoch_end(epoch, epoch_train_loss, epoch_val_loss);
            }
        }

        let epochs_trained = train_losses.len();
        let final_train_loss = train_losses.last().copied().unwrap_or(f32::NAN);

        let result = TrainingResult {
            train_losses,
            val_losses,
            best_epoch,
            best_val_loss,
            epochs_trained,
            final_train_loss,
        };

        for cb in self.callbacks.iter_mut() {
            cb.on_training_end(&result);
        }

        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
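    /// Builds a single-feature dataset following `y = 2x + 1` plus deterministic
    /// pseudo-noise of magnitude `noise`.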
    fn make_linear_dataset(n: usize, noise: f32) -> ArrayDataProvider {
        let mut feat_data = vec![0.0_f32; n];
        let mut tgt_data = vec![0.0_f32; n];
        let mut state: u64 = 12345;
        for i in 0..n {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            let x = i as f32 / n as f32;
            // Pseudo-random noise in [-noise, noise]; use the top 32 bits so the range is symmetric.
            let eps = ((state >> 32) as f32 / u32::MAX as f32 - 0.5) * 2.0 * noise;
            feat_data[i] = x;
            tgt_data[i] = 2.0 * x + 1.0 + eps;
        }
        let features = Array2::from_shape_vec((n, 1), feat_data).expect("shape ok");
        let targets = Array1::from_vec(tgt_data);
        ArrayDataProvider::new(features, targets)
    }
    #[test]
    fn test_array_data_provider_batch() {
        let provider = make_linear_dataset(50, 0.0);
        assert_eq!(provider.num_samples(), 50);
        assert_eq!(provider.num_features(), 1);

        let indices: Vec<usize> = (0..10).collect();
        let (feat, tgt) = provider.get_batch(&indices);

        assert_eq!(feat.shape(), &[10, 1]);
        assert_eq!(tgt.len(), 10);
    }
    #[test]
    fn test_training_loop_linear_regression_convergence() {
        let data = make_linear_dataset(100, 0.05);

        let config = TrainingConfig {
            max_epochs: 200,
            batch_size: 32,
            learning_rate: 0.1,
            val_fraction: 0.2,
            rng_seed: 7,
            log_every_n_epochs: 0,
        };

        let mut optimizer = SgdOptimizer::new(config.learning_rate);
        let mut scheduler = ConstantScheduler::new(config.learning_rate);
        let mut weights = Array1::<f32>::zeros(1);
        let mut bias = 0.0_f32;

        let mut training_loop = TrainingLoop::new(config);
        let result = training_loop
            .run(
                &data,
                &mut optimizer,
                &mut scheduler,
                None,
                &mut weights,
                &mut bias,
            )
            .expect("training should succeed");

        assert!(
            result.final_train_loss < 0.1,
            "expected final loss < 0.1, got {}",
            result.final_train_loss
        );
    }
    #[test]
    fn test_training_loop_early_stopping() {
        let data = make_linear_dataset(60, 0.0);

        let config = TrainingConfig {
            max_epochs: 500,
            batch_size: 60,
            learning_rate: 0.05,
            val_fraction: 0.3,
            rng_seed: 99,
            log_every_n_epochs: 0,
        };

        let mut optimizer = SgdOptimizer::new(config.learning_rate);
        let mut scheduler = ConstantScheduler::new(config.learning_rate);
        let mut es = EarlyStopping::new(3, 0.001);
        let mut weights = Array1::<f32>::zeros(1);
        let mut bias = 0.0_f32;

        let mut training_loop = TrainingLoop::new(config.clone());
        let result = training_loop
            .run(
                &data,
                &mut optimizer,
                &mut scheduler,
                Some(&mut es),
                &mut weights,
                &mut bias,
            )
            .expect("training should succeed");

        assert!(
            result.epochs_trained < config.max_epochs,
            "expected early stop before {} epochs, trained {} epochs",
            config.max_epochs,
            result.epochs_trained
        );
    }
    #[test]
    fn test_training_loop_lr_scheduling() {
        let data = make_linear_dataset(40, 0.0);

        let initial_lr = 0.1_f32;
        let config = TrainingConfig {
            max_epochs: 20,
            batch_size: 40,
            learning_rate: initial_lr,
            val_fraction: 0.0,
            rng_seed: 1,
            log_every_n_epochs: 0,
        };

        let mut optimizer = SgdOptimizer::new(initial_lr);
        let mut scheduler = StepDecayScheduler::new(initial_lr, 2, 0.9);
        let mut weights = Array1::<f32>::zeros(1);
        let mut bias = 0.0_f32;

        let mut training_loop = TrainingLoop::new(config.clone());
        training_loop
            .run(
                &data,
                &mut optimizer,
                &mut scheduler,
                None,
                &mut weights,
                &mut bias,
            )
            .expect("training should succeed");

        assert!(
            scheduler.current_lr() < initial_lr,
            "scheduler should have reduced LR from {initial_lr} but got {}",
            scheduler.current_lr()
        );
    }
    #[test]
    fn test_training_result_history() {
        let data = make_linear_dataset(30, 0.0);

        let config = TrainingConfig {
            max_epochs: 10,
            batch_size: 10,
            learning_rate: 0.01,
            val_fraction: 0.0,
            rng_seed: 5,
            log_every_n_epochs: 0,
        };

        let mut optimizer = SgdOptimizer::new(0.01);
        let mut scheduler = ConstantScheduler::new(0.01);
        let mut weights = Array1::<f32>::zeros(1);
        let mut bias = 0.0_f32;

        let mut training_loop = TrainingLoop::new(config.clone());
        let result = training_loop
            .run(
                &data,
                &mut optimizer,
                &mut scheduler,
                None,
                &mut weights,
                &mut bias,
            )
            .expect("training should succeed");

        assert_eq!(
            result.train_losses.len(),
            result.epochs_trained,
            "train_losses length must match epochs_trained"
        );
        assert_eq!(
            result.val_losses.len(),
            result.epochs_trained,
            "val_losses length must match epochs_trained"
        );
        assert_eq!(result.epochs_trained, config.max_epochs);
    }
    /// Test callback that counts `on_epoch_end` invocations through a shared
    /// atomic so the count stays observable after the callback is boxed and
    /// handed to the training loop.
    struct EpochCounter {
        count: Arc<AtomicUsize>,
    }

    impl TrainingCallback for EpochCounter {
        fn on_epoch_end(&mut self, _epoch: usize, _train_loss: f32, _val_loss: Option<f32>) {
            self.count.fetch_add(1, Ordering::SeqCst);
        }
    }
    #[test]
    fn test_training_callback_fired() {
        let data = make_linear_dataset(20, 0.0);
        let max_epochs = 7;
        let config = TrainingConfig {
            max_epochs,
            batch_size: 20,
            learning_rate: 0.01,
            val_fraction: 0.0,
            rng_seed: 3,
            log_every_n_epochs: 0,
        };

        let mut optimizer = SgdOptimizer::new(0.01);
        let mut scheduler = ConstantScheduler::new(0.01);
        let mut weights = Array1::<f32>::zeros(1);
        let mut bias = 0.0_f32;

        let epoch_count = Arc::new(AtomicUsize::new(0));
        let counter = EpochCounter {
            count: Arc::clone(&epoch_count),
        };

        let mut training_loop = TrainingLoop::new(config.clone());
        training_loop.add_callback(Box::new(counter));

        training_loop
            .run(
                &data,
                &mut optimizer,
                &mut scheduler,
                None,
                &mut weights,
                &mut bias,
            )
            .expect("training should succeed");

        assert_eq!(
            epoch_count.load(Ordering::SeqCst),
            max_epochs,
            "on_epoch_end should fire once per epoch"
        );
    }
}