use crate::{EmbeddingModel, TrainingStats};
use anyhow::Result;
use scirs2_core::ndarray_ext::Array2;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use tokio::sync::{broadcast, RwLock};
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};

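/// Bundles the pieces that drive a training run: the static configuration,
/// the optimizer choice, the learning-rate schedule, and optional early
/// stopping.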
pub struct TrainingScheduler {
    pub config: TrainingConfig,
    pub optimizer: OptimizerType,
    pub scheduler: LearningRateScheduler,
    pub early_stopping: Option<EarlyStopping>,
}

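/// Hyperparameters for the training loop. `patience` and `min_delta` are
/// only consulted when `use_early_stopping` is set.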
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    pub max_epochs: usize,
    pub batch_size: usize,
    pub learning_rate: f64,
    pub validation_freq: usize,
    pub checkpoint_freq: usize,
    pub log_freq: usize,
    pub use_early_stopping: bool,
    pub patience: usize,
    pub min_delta: f64,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            max_epochs: 1000,
            batch_size: 1024,
            learning_rate: 0.01,
            validation_freq: 10,
            checkpoint_freq: 100,
            log_freq: 10,
            use_early_stopping: true,
            patience: 50,
            min_delta: 1e-6,
        }
    }
}

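/// Available optimizers; each variant carries its own hyperparameters.
/// The default is Adam with the hyperparameter values recommended in the
/// original Adam paper.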
#[derive(Debug, Clone)]
pub enum OptimizerType {
    SGD,
    Adam {
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    },
    AdaGrad {
        epsilon: f64,
    },
    RMSprop {
        alpha: f64,
        epsilon: f64,
    },
}

impl Default for OptimizerType {
    fn default() -> Self {
        OptimizerType::Adam {
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        }
    }
}

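/// Learning-rate schedules. All variants are evaluated statelessly in
/// `get_lr`; `ReduceOnPlateau` therefore degrades to a constant rate there.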
#[derive(Debug, Clone)]
pub enum LearningRateScheduler {
    Constant,
    ExponentialDecay {
        decay_rate: f64,
        decay_steps: usize,
    },
    StepDecay {
        step_size: usize,
        gamma: f64,
    },
    CosineAnnealing {
        t_max: usize,
        eta_min: f64,
    },
    ReduceOnPlateau {
        factor: f64,
        patience: usize,
        threshold: f64,
    },
}

impl Default for LearningRateScheduler {
    fn default() -> Self {
        LearningRateScheduler::ExponentialDecay {
            decay_rate: 0.96,
            decay_steps: 100,
        }
    }
}

impl LearningRateScheduler {
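    /// Returns the learning rate for `epoch` given `base_lr`:
    ///
    /// * exponential decay: `base_lr * decay_rate^(epoch / decay_steps)`
    /// * step decay: `base_lr * gamma^floor(epoch / step_size)`
    /// * cosine annealing: `eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2`
    ///
    /// `_current_loss` is accepted for plateau-style schedules but unused,
    /// since this method keeps no loss history.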
    pub fn get_lr(&self, epoch: usize, base_lr: f64, _current_loss: Option<f64>) -> f64 {
        match self {
            LearningRateScheduler::Constant => base_lr,
            LearningRateScheduler::ExponentialDecay {
                decay_rate,
                decay_steps,
            } => base_lr * decay_rate.powf(epoch as f64 / *decay_steps as f64),
            LearningRateScheduler::StepDecay { step_size, gamma } => {
                base_lr * gamma.powf((epoch / step_size) as f64)
            }
            LearningRateScheduler::CosineAnnealing { t_max, eta_min } => {
                eta_min
                    + (base_lr - eta_min)
                        * (1.0 + (std::f64::consts::PI * epoch as f64 / *t_max as f64).cos())
                        / 2.0
            }
            // Plateau detection requires mutable loss history, which this
            // stateless method does not keep, so fall back to the base rate.
            LearningRateScheduler::ReduceOnPlateau { .. } => base_lr,
        }
    }
}

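/// Stops training once the monitored loss has failed to improve by at
/// least `min_delta` for more than `patience` consecutive updates.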
#[derive(Debug, Clone)]
pub struct EarlyStopping {
    patience: usize,
    min_delta: f64,
    best_loss: f64,
    wait_count: usize,
    stopped: bool,
}

impl EarlyStopping {
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            wait_count: 0,
            stopped: false,
        }
    }

    /// Records `current_loss` and returns `true` once training should stop.
    pub fn update(&mut self, current_loss: f64) -> bool {
        if current_loss < self.best_loss - self.min_delta {
            self.best_loss = current_loss;
            self.wait_count = 0;
        } else {
            self.wait_count += 1;
            if self.wait_count > self.patience {
                self.stopped = true;
            }
        }

        self.stopped
    }

    pub fn should_stop(&self) -> bool {
        self.stopped
    }
}

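/// Adam optimizer state. Keeps exponentially decayed first (`m`) and
/// second (`v`) moment estimates of the gradient and applies the
/// bias-corrected update
///
/// ```text
/// m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
/// v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2
/// theta_t = theta_{t-1} - lr * m_hat_t / (sqrt(v_hat_t) + epsilon)
/// ```
///
/// where `m_hat_t = m_t / (1 - beta1^t)` and `v_hat_t = v_t / (1 - beta2^t)`.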
#[derive(Debug, Clone)]
pub struct AdamOptimizer {
    beta1: f64,
    beta2: f64,
    epsilon: f64,
    t: usize,
    m: Option<Array2<f64>>,
    v: Option<Array2<f64>>,
}

impl AdamOptimizer {
    pub fn new(beta1: f64, beta2: f64, epsilon: f64) -> Self {
        Self {
            beta1,
            beta2,
            epsilon,
            t: 0,
            m: None,
            v: None,
        }
    }

    pub fn update(&mut self, params: &mut Array2<f64>, grads: &Array2<f64>, lr: f64) {
        self.t += 1;

        // Lazily initialize the moment buffers to match the parameter shape.
        if self.m.is_none() {
            self.m = Some(Array2::zeros(params.raw_dim()));
            self.v = Some(Array2::zeros(params.raw_dim()));
        }

        let m = self
            .m
            .as_mut()
            .expect("moment estimate m should be initialized");
        let v = self
            .v
            .as_mut()
            .expect("moment estimate v should be initialized");

        // Update the biased first and second moment estimates.
        *m = &*m * self.beta1 + grads * (1.0 - self.beta1);
        *v = &*v * self.beta2 + &(grads * grads) * (1.0 - self.beta2);

        // Bias-correct the moments for the warm-up phase.
        let m_hat = &*m / (1.0 - self.beta1.powi(self.t as i32));
        let v_hat = &*v / (1.0 - self.beta2.powi(self.t as i32));

        // Apply the parameter update.
        *params = &*params - &(&m_hat / (&v_hat.mapv(|x| x.sqrt()) + self.epsilon)) * lr;
    }
}

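/// Records per-epoch training metrics (loss, learning rate, wall-clock
/// time) plus any validation losses reported along the way.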
#[derive(Debug, Clone)]
pub struct MetricsTracker {
    pub losses: Vec<f64>,
    pub learning_rates: Vec<f64>,
    pub epochs: Vec<usize>,
    pub validation_losses: Vec<f64>,
    pub training_times: Vec<f64>,
}

impl MetricsTracker {
    pub fn new() -> Self {
        Self {
            losses: Vec::new(),
            learning_rates: Vec::new(),
            epochs: Vec::new(),
            validation_losses: Vec::new(),
            training_times: Vec::new(),
        }
    }

    pub fn record_epoch(&mut self, epoch: usize, loss: f64, lr: f64, training_time: f64) {
        self.epochs.push(epoch);
        self.losses.push(loss);
        self.learning_rates.push(lr);
        self.training_times.push(training_time);
    }

    pub fn record_validation(&mut self, val_loss: f64) {
        self.validation_losses.push(val_loss);
    }

    /// Smooths the loss curve with a trailing moving average over
    /// `window_size` points.
    pub fn get_smoothed_loss(&self, window_size: usize) -> Vec<f64> {
        if self.losses.len() < window_size {
            return self.losses.clone();
        }

        let mut smoothed = Vec::new();
        let mut window: VecDeque<f64> = VecDeque::new();

        for &loss in &self.losses {
            window.push_back(loss);
            if window.len() > window_size {
                window.pop_front();
            }

            let avg = window.iter().sum::<f64>() / window.len() as f64;
            smoothed.push(avg);
        }

        smoothed
    }
}

impl Default for MetricsTracker {
    fn default() -> Self {
        Self::new()
    }
}

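/// Single-process trainer that combines a [`TrainingConfig`], an optimizer,
/// a learning-rate schedule, and optional early stopping around any
/// [`EmbeddingModel`].
///
/// A minimal usage sketch; `model` is a placeholder for whatever
/// [`EmbeddingModel`] implementation is in play:
///
/// ```ignore
/// let mut trainer = AdvancedTrainer::new(TrainingConfig::default())
///     .with_scheduler(LearningRateScheduler::CosineAnnealing {
///         t_max: 100,
///         eta_min: 1e-5,
///     });
/// let stats = trainer.train(&mut model).await?;
/// println!("final loss: {}", stats.final_loss);
/// ```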
pub struct AdvancedTrainer {
    config: TrainingConfig,
    optimizer: OptimizerType,
    scheduler: LearningRateScheduler,
    early_stopping: Option<EarlyStopping>,
    metrics: MetricsTracker,
}

impl AdvancedTrainer {
    pub fn new(config: TrainingConfig) -> Self {
        let early_stopping = if config.use_early_stopping {
            Some(EarlyStopping::new(config.patience, config.min_delta))
        } else {
            None
        };

        Self {
            config,
            optimizer: OptimizerType::default(),
            scheduler: LearningRateScheduler::default(),
            early_stopping,
            metrics: MetricsTracker::new(),
        }
    }

    pub fn with_optimizer(mut self, optimizer: OptimizerType) -> Self {
        self.optimizer = optimizer;
        self
    }

    pub fn with_scheduler(mut self, scheduler: LearningRateScheduler) -> Self {
        self.scheduler = scheduler;
        self
    }

    pub async fn train(&mut self, model: &mut dyn EmbeddingModel) -> Result<TrainingStats> {
        let start_time = Instant::now();
        info!(
            "Starting advanced training with {} epochs",
            self.config.max_epochs
        );

        for epoch in 0..self.config.max_epochs {
            let epoch_start = Instant::now();

            // The scheduled rate is recorded for reporting; applying it is
            // left to the EmbeddingModel's own training step.
            let current_lr = self
                .scheduler
                .get_lr(epoch, self.config.learning_rate, None);

            // Run a single epoch on the underlying model.
            let epoch_stats = model.train(Some(1)).await?;
            let epoch_loss = epoch_stats.final_loss;
            let epoch_time = epoch_start.elapsed().as_secs_f64();

            self.metrics
                .record_epoch(epoch, epoch_loss, current_lr, epoch_time);

            if epoch % self.config.log_freq == 0 {
                debug!(
                    "Epoch {}: loss = {:.6}, lr = {:.6}, time = {:.3}s",
                    epoch, epoch_loss, current_lr, epoch_time
                );
            }

            if let Some(ref mut early_stop) = self.early_stopping {
                if early_stop.update(epoch_loss) {
                    info!("Early stopping triggered at epoch {}", epoch);
                    break;
                }
            }

            // Hard convergence check, independent of early stopping.
            if epoch > 10 && epoch_loss < 1e-8 {
                info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
                break;
            }
        }

        let training_time = start_time.elapsed().as_secs_f64();
        let final_loss = self.metrics.losses.last().copied().unwrap_or(0.0);

        Ok(TrainingStats {
            epochs_completed: self.metrics.epochs.len(),
            final_loss,
            training_time_seconds: training_time,
            convergence_achieved: final_loss < 1e-6,
            loss_history: self.metrics.losses.clone(),
        })
    }

    pub fn get_metrics(&self) -> &MetricsTracker {
        &self.metrics
    }
}

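/// Evaluates a model against a held-out set of (subject, predicate, object)
/// triples.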
pub struct ValidationSuite {
    pub test_triples: Vec<(String, String, String)>,
    pub validation_freq: usize,
}

impl ValidationSuite {
    pub fn new(test_triples: Vec<(String, String, String)>, validation_freq: usize) -> Self {
        Self {
            test_triples,
            validation_freq,
        }
    }

    /// Averages the model's scores over all triples it can score; triples
    /// whose scoring returns an error are skipped.
    pub fn evaluate_model(&self, model: &dyn EmbeddingModel) -> Result<ValidationMetrics> {
        let mut total_score = 0.0;
        let mut valid_predictions = 0;

        for (subject, predicate, object) in &self.test_triples {
            if let Ok(score) = model.score_triple(subject, predicate, object) {
                total_score += score;
                valid_predictions += 1;
            }
        }

        let avg_score = if valid_predictions > 0 {
            total_score / valid_predictions as f64
        } else {
            0.0
        };

        Ok(ValidationMetrics {
            average_score: avg_score,
            num_evaluated: valid_predictions,
            num_total: self.test_triples.len(),
        })
    }
}

#[derive(Debug, Clone)]
pub struct ValidationMetrics {
    pub average_score: f64,
    pub num_evaluated: usize,
    pub num_total: usize,
}

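/// Settings for multi-worker training: topology (`world_size`, `rank`,
/// `device_ids`), communication backend, sync cadence, and optional
/// gradient clipping.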
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    pub world_size: usize,
    pub rank: usize,
    pub device_ids: Vec<usize>,
    pub backend: DistributedBackend,
    pub sync_frequency: usize,
    pub gradient_clipping: Option<f64>,
    pub all_reduce_method: AllReduceMethod,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        Self {
            world_size: 1,
            rank: 0,
            device_ids: vec![0],
            backend: DistributedBackend::NCCL,
            sync_frequency: 1,
            gradient_clipping: Some(1.0),
            all_reduce_method: AllReduceMethod::Average,
        }
    }
}

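/// Communication backend used for inter-worker synchronization.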
#[derive(Debug, Clone)]
pub enum DistributedBackend {
    NCCL,
    MPI,
    Gloo,
}

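/// How gradients from multiple workers are combined during all-reduce.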
#[derive(Debug, Clone)]
pub enum AllReduceMethod {
    Sum,
    Average,
    WeightedAverage,
}

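/// Multi-worker trainer skeleton: spawns one Tokio task per configured
/// device plus a coordinator task, and synchronizes them with broadcast
/// [`SyncMessage`]s.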
#[allow(dead_code)]
pub struct DistributedTrainer {
    config: TrainingConfig,
    distributed_config: DistributedConfig,
    optimizer: OptimizerType,
    scheduler: LearningRateScheduler,
    early_stopping: Option<EarlyStopping>,
    metrics: Arc<RwLock<MetricsTracker>>,
    gradient_accumulator: Arc<Mutex<GradientAccumulator>>,
    sync_channel: (
        broadcast::Sender<SyncMessage>,
        broadcast::Receiver<SyncMessage>,
    ),
}

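/// Messages exchanged between worker tasks and the coordinator.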
#[derive(Debug, Clone)]
pub enum SyncMessage {
    GradientUpdate {
        epoch: usize,
        rank: usize,
        gradients: Vec<f64>,
    },
    ParameterSync {
        epoch: usize,
        parameters: Vec<f64>,
    },
    EarlyStop {
        epoch: usize,
        loss: f64,
    },
    Checkpoint {
        epoch: usize,
        model_state: Vec<u8>,
    },
}

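/// Sums gradient tensors from `target_count` contributions (typically one
/// per worker) and releases their element-wise average.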
#[derive(Debug)]
pub struct GradientAccumulator {
    accumulated_gradients: Vec<Array2<f64>>,
    accumulation_count: usize,
    target_count: usize,
}

impl GradientAccumulator {
    pub fn new(target_count: usize) -> Self {
        Self {
            accumulated_gradients: Vec::new(),
            accumulation_count: 0,
            target_count,
        }
    }

    pub fn accumulate(&mut self, gradients: Vec<Array2<f64>>) {
        if self.accumulated_gradients.is_empty() {
            self.accumulated_gradients = gradients;
        } else {
            for (i, grad) in gradients.into_iter().enumerate() {
                if i < self.accumulated_gradients.len() {
                    self.accumulated_gradients[i] = &self.accumulated_gradients[i] + &grad;
                } else {
                    self.accumulated_gradients.push(grad);
                }
            }
        }
        self.accumulation_count += 1;
    }

    pub fn is_ready(&self) -> bool {
        self.accumulation_count >= self.target_count
    }

    /// Returns the element-wise average of everything accumulated so far
    /// and resets the accumulator for the next round.
    pub fn get_averaged_gradients(&mut self) -> Vec<Array2<f64>> {
        let count = self.accumulation_count as f64;
        let result = self
            .accumulated_gradients
            .iter()
            .map(|grad| grad / count)
            .collect();
        self.reset();
        result
    }

    pub fn reset(&mut self) {
        self.accumulated_gradients.clear();
        self.accumulation_count = 0;
    }
}

impl DistributedTrainer {
    pub fn new(config: TrainingConfig, distributed_config: DistributedConfig) -> Self {
        let early_stopping = if config.use_early_stopping {
            Some(EarlyStopping::new(config.patience, config.min_delta))
        } else {
            None
        };

        let (sync_tx, sync_rx) = broadcast::channel(1000);
        let gradient_accumulator = Arc::new(Mutex::new(GradientAccumulator::new(
            distributed_config.world_size,
        )));

        Self {
            config,
            distributed_config,
            optimizer: OptimizerType::default(),
            scheduler: LearningRateScheduler::default(),
            early_stopping,
            metrics: Arc::new(RwLock::new(MetricsTracker::new())),
            gradient_accumulator,
            sync_channel: (sync_tx, sync_rx),
        }
    }

    pub fn with_optimizer(mut self, optimizer: OptimizerType) -> Self {
        self.optimizer = optimizer;
        self
    }

    pub fn with_scheduler(mut self, scheduler: LearningRateScheduler) -> Self {
        self.scheduler = scheduler;
        self
    }

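    /// Launches one worker task per configured device plus a coordinator,
    /// waits for the workers, and reports the last successful worker's
    /// stats (falling back to aggregate metrics if every worker failed).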
    pub async fn train_distributed(
        &mut self,
        model: Arc<RwLock<dyn EmbeddingModel + Send + Sync>>,
    ) -> Result<TrainingStats> {
        let start_time = Instant::now();
        info!(
            "Starting distributed training with {} workers on rank {}",
            self.distributed_config.world_size, self.distributed_config.rank
        );

        let mut worker_handles = Vec::new();

        for device_id in &self.distributed_config.device_ids {
            let worker_handle = self
                .spawn_worker_task(*device_id, Arc::clone(&model))
                .await?;
            worker_handles.push(worker_handle);
        }

        let coordinator_handle = self.spawn_coordinator_task().await?;

        // Wait for every worker; keep the stats of the last one to succeed.
        let mut final_stats = None;
        for handle in worker_handles {
            if let Ok(stats) = handle.await {
                match stats {
                    Ok(s) => final_stats = Some(s),
                    Err(e) => warn!("Worker failed: {}", e),
                }
            }
        }

        // The coordinator loops on the broadcast channel; stop it once the
        // workers are done.
        coordinator_handle.abort();

        let training_time = start_time.elapsed().as_secs_f64();
        let metrics = self.metrics.read().await;

        Ok(final_stats.unwrap_or_else(|| TrainingStats {
            epochs_completed: metrics.epochs.len(),
            final_loss: metrics.losses.last().copied().unwrap_or(0.0),
            training_time_seconds: training_time,
            convergence_achieved: false,
            loss_history: metrics.losses.clone(),
        }))
    }

    async fn spawn_worker_task(
        &self,
        device_id: usize,
        model: Arc<RwLock<dyn EmbeddingModel + Send + Sync>>,
    ) -> Result<JoinHandle<Result<TrainingStats>>> {
        let config = self.config.clone();
        let distributed_config = self.distributed_config.clone();
        let _optimizer = self.optimizer.clone();
        let scheduler = self.scheduler.clone();
        let metrics = Arc::clone(&self.metrics);
        let mut sync_rx = self.sync_channel.0.subscribe();
        let sync_tx = self.sync_channel.0.clone();

        let handle = tokio::spawn(async move {
            info!(
                "Worker {} starting on device {}",
                distributed_config.rank, device_id
            );

            let mut local_early_stopping = if config.use_early_stopping {
                Some(EarlyStopping::new(config.patience, config.min_delta))
            } else {
                None
            };

            let mut total_training_time = 0.0;

            for epoch in 0..config.max_epochs {
                let epoch_start = Instant::now();

                let current_lr = scheduler.get_lr(epoch, config.learning_rate, None);

                // Hold the write lock only for the duration of the epoch.
                let mut model_guard = model.write().await;
                let epoch_stats = model_guard.train(Some(1)).await?;
                drop(model_guard);

                let epoch_loss = epoch_stats.final_loss;
                let epoch_time = epoch_start.elapsed().as_secs_f64();
                total_training_time += epoch_time;

                {
                    let mut metrics_guard = metrics.write().await;
                    metrics_guard.record_epoch(epoch, epoch_loss, current_lr, epoch_time);
                }

                if epoch % distributed_config.sync_frequency == 0 {
                    // Stand-in payload: real gradient tensors are not exposed
                    // by the EmbeddingModel trait, so broadcast the loss.
                    let _ = sync_tx.send(SyncMessage::GradientUpdate {
                        epoch,
                        rank: distributed_config.rank,
                        gradients: vec![epoch_loss],
                    });

                    // Wait briefly for a parameter sync, but never block the
                    // epoch loop for more than 100 ms.
                    tokio::select! {
                        msg = sync_rx.recv() => {
                            if let Ok(SyncMessage::ParameterSync { .. }) = msg {
                                debug!("Received parameter sync for epoch {}", epoch);
                            }
                        }
                        _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
                            debug!("Sync timeout for epoch {}", epoch);
                        }
                    }
                }

                if epoch % config.log_freq == 0 {
                    debug!(
                        "Worker {} Epoch {}: loss = {:.6}, lr = {:.6}, time = {:.3}s",
                        distributed_config.rank, epoch, epoch_loss, current_lr, epoch_time
                    );
                }

                if let Some(ref mut early_stop) = local_early_stopping {
                    if early_stop.update(epoch_loss) {
                        info!(
                            "Worker {} early stopping triggered at epoch {}",
                            distributed_config.rank, epoch
                        );
                        let _ = sync_tx.send(SyncMessage::EarlyStop {
                            epoch,
                            loss: epoch_loss,
                        });
                        break;
                    }
                }

                if epoch > 10 && epoch_loss < 1e-8 {
                    info!(
                        "Worker {} converged at epoch {} with loss {:.6}",
                        distributed_config.rank, epoch, epoch_loss
                    );
                    break;
                }
            }

            let final_metrics = metrics.read().await;
            Ok(TrainingStats {
                epochs_completed: final_metrics.epochs.len(),
                final_loss: final_metrics.losses.last().copied().unwrap_or(0.0),
                training_time_seconds: total_training_time,
                convergence_achieved: final_metrics
                    .losses
                    .last()
                    .copied()
                    .unwrap_or(f64::INFINITY)
                    < 1e-6,
                loss_history: final_metrics.losses.clone(),
            })
        });

        Ok(handle)
    }

    async fn spawn_coordinator_task(&self) -> Result<JoinHandle<()>> {
        let mut sync_rx = self.sync_channel.0.subscribe();
        let sync_tx = self.sync_channel.0.clone();
        let gradient_accumulator = Arc::clone(&self.gradient_accumulator);
        let world_size = self.distributed_config.world_size;

        let handle = tokio::spawn(async move {
            info!("Coordinator starting for {} workers", world_size);

            while let Ok(msg) = sync_rx.recv().await {
                match msg {
                    SyncMessage::GradientUpdate {
                        epoch,
                        rank,
                        gradients,
                    } => {
                        debug!(
                            "Received gradients from worker {} for epoch {}",
                            rank, epoch
                        );

                        // Placeholder: a full implementation would feed the
                        // received gradients into the accumulator here.
                        {
                            let _accumulator = gradient_accumulator
                                .lock()
                                .expect("lock should not be poisoned");
                        }

                        // Echo the (stand-in) gradients back as a parameter sync.
                        let _ = sync_tx.send(SyncMessage::ParameterSync {
                            epoch,
                            parameters: gradients,
                        });
                    }
                    SyncMessage::EarlyStop { epoch, loss } => {
                        info!(
                            "Early stop signal received at epoch {} with loss {:.6}",
                            epoch, loss
                        );
                    }
                    _ => {}
                }
            }
        });

        Ok(handle)
    }

    #[allow(dead_code)]
    async fn all_reduce_gradients(&self, gradients: Vec<Array2<f64>>) -> Result<Vec<Array2<f64>>> {
        match self.distributed_config.all_reduce_method {
            AllReduceMethod::Average => {
                let world_size = self.distributed_config.world_size as f64;
                Ok(gradients.into_iter().map(|g| g / world_size).collect())
            }
            AllReduceMethod::Sum => Ok(gradients),
            // Per-worker weights are not tracked yet, so weighted averaging
            // falls back to a uniform average.
            AllReduceMethod::WeightedAverage => {
                let world_size = self.distributed_config.world_size as f64;
                Ok(gradients.into_iter().map(|g| g / world_size).collect())
            }
        }
    }

    /// Scales each gradient down to the configured maximum L2 norm.
    #[allow(dead_code)]
    fn clip_gradients(&self, gradients: &mut [Array2<f64>]) {
        if let Some(max_norm) = self.distributed_config.gradient_clipping {
            for grad in gradients.iter_mut() {
                let norm = grad.mapv(|x| x * x).sum().sqrt();
                if norm > max_norm {
                    *grad *= max_norm / norm;
                }
            }
        }
    }
}

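/// Helpers for bringing the distributed environment up and down. The
/// current implementations are stubs that only log; real backend setup
/// (NCCL/MPI/Gloo) would slot in here.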
pub struct DistributedUtils;

impl DistributedUtils {
    pub async fn init_distributed(rank: usize, world_size: usize) -> Result<()> {
        info!(
            "Initializing distributed training: rank {} of {}",
            rank, world_size
        );
        Ok(())
    }

    pub async fn cleanup_distributed() -> Result<()> {
        info!("Cleaning up distributed training environment");
        Ok(())
    }

    /// Always `true` in this build; a backend-aware check would go here.
    pub fn is_distributed_available() -> bool {
        true
    }

    /// Defaults the world size to the number of available CPU threads.
    pub fn get_optimal_world_size() -> usize {
        std::thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(1)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_learning_rate_scheduler() {
        let scheduler = LearningRateScheduler::ExponentialDecay {
            decay_rate: 0.9,
            decay_steps: 10,
        };

        let lr0 = scheduler.get_lr(0, 0.1, None);
        let lr10 = scheduler.get_lr(10, 0.1, None);
        let lr20 = scheduler.get_lr(20, 0.1, None);

        assert!((lr0 - 0.1).abs() < 1e-10);
        assert!(lr10 < lr0);
        assert!(lr20 < lr10);
    }

    #[test]
    fn test_early_stopping() {
        let mut early_stop = EarlyStopping::new(3, 0.01);

        assert!(!early_stop.update(1.0));
        assert!(!early_stop.update(0.5));
        assert!(!early_stop.update(0.51));
        assert!(!early_stop.update(0.52));
        assert!(!early_stop.update(0.53));
        assert!(early_stop.update(0.54));
    }

    #[test]
    fn test_metrics_tracker() {
        let mut tracker = MetricsTracker::new();

        tracker.record_epoch(0, 1.0, 0.01, 1.5);
        tracker.record_epoch(1, 0.5, 0.009, 1.4);
        tracker.record_epoch(2, 0.3, 0.008, 1.3);

        assert_eq!(tracker.losses.len(), 3);
        assert_eq!(tracker.epochs.len(), 3);

        let smoothed = tracker.get_smoothed_loss(2);
        assert_eq!(smoothed.len(), 3);
    }

    #[test]
    fn test_distributed_config() {
        let config = DistributedConfig::default();
        assert_eq!(config.world_size, 1);
        assert_eq!(config.rank, 0);
        assert_eq!(config.device_ids.len(), 1);
    }

    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = GradientAccumulator::new(2);
        assert!(!accumulator.is_ready());

        let grad1 = vec![Array2::from_elem((2, 2), 1.0)];
        let grad2 = vec![Array2::from_elem((2, 2), 2.0)];

        accumulator.accumulate(grad1);
        assert!(!accumulator.is_ready());

        accumulator.accumulate(grad2);
        assert!(accumulator.is_ready());

        let averaged = accumulator.get_averaged_gradients();
        assert_eq!(averaged.len(), 1);
        assert!((averaged[0][[0, 0]] - 1.5).abs() < 1e-10);
    }

    #[test]
    fn test_distributed_utils() {
        assert!(DistributedUtils::is_distributed_available());
        let world_size = DistributedUtils::get_optimal_world_size();
        assert!(world_size >= 1);
    }
}