use crate::{EmbeddingModel, TrainingStats};
use anyhow::Result;
use scirs2_core::ndarray_ext::Array2;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use tokio::sync::{broadcast, RwLock};
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};

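/// Groups the static pieces of a training run: the [`TrainingConfig`], the
/// optimizer choice, the learning-rate schedule, and optional early stopping.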
pub struct TrainingScheduler {
    pub config: TrainingConfig,
    pub optimizer: OptimizerType,
    pub scheduler: LearningRateScheduler,
    pub early_stopping: Option<EarlyStopping>,
}

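/// Hyperparameters for the training loop: epoch budget, batch size, base
/// learning rate, early-stopping settings, and the validation, checkpoint,
/// and logging cadence.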
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    pub max_epochs: usize,
    pub batch_size: usize,
    pub learning_rate: f64,
    pub validation_freq: usize,
    pub checkpoint_freq: usize,
    pub log_freq: usize,
    pub use_early_stopping: bool,
    pub patience: usize,
    pub min_delta: f64,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            max_epochs: 1000,
            batch_size: 1024,
            learning_rate: 0.01,
            validation_freq: 10,
            checkpoint_freq: 100,
            log_freq: 10,
            use_early_stopping: true,
            patience: 50,
            min_delta: 1e-6,
        }
    }
}

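/// Optimizer selection with its hyperparameters. Only the settings live here;
/// mutable optimizer state (such as Adam's moment estimates) is kept in the
/// concrete optimizer structs like [`AdamOptimizer`].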
#[derive(Debug, Clone)]
pub enum OptimizerType {
    SGD,
    Adam {
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    },
    AdaGrad {
        epsilon: f64,
    },
    RMSprop {
        alpha: f64,
        epsilon: f64,
    },
}

impl Default for OptimizerType {
    fn default() -> Self {
        OptimizerType::Adam {
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        }
    }
}

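/// Learning-rate schedules. `get_lr` maps an epoch index and base learning
/// rate to the rate for that epoch; `ReduceOnPlateau` would also need a loss
/// history, which this stateless enum does not track.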
#[derive(Debug, Clone)]
pub enum LearningRateScheduler {
    Constant,
    ExponentialDecay {
        decay_rate: f64,
        decay_steps: usize,
    },
    StepDecay {
        step_size: usize,
        gamma: f64,
    },
    CosineAnnealing {
        t_max: usize,
        eta_min: f64,
    },
    ReduceOnPlateau {
        factor: f64,
        patience: usize,
        threshold: f64,
    },
}

impl Default for LearningRateScheduler {
    fn default() -> Self {
        LearningRateScheduler::ExponentialDecay {
            decay_rate: 0.96,
            decay_steps: 100,
        }
    }
}

impl LearningRateScheduler {
    pub fn get_lr(&self, epoch: usize, base_lr: f64, _current_loss: Option<f64>) -> f64 {
        match self {
            LearningRateScheduler::Constant => base_lr,
            LearningRateScheduler::ExponentialDecay {
                decay_rate,
                decay_steps,
            } => base_lr * decay_rate.powf(epoch as f64 / *decay_steps as f64),
            LearningRateScheduler::StepDecay { step_size, gamma } => {
                base_lr * gamma.powf((epoch / step_size) as f64)
            }
            LearningRateScheduler::CosineAnnealing { t_max, eta_min } => {
                eta_min
                    + (base_lr - eta_min)
                        * (1.0 + (std::f64::consts::PI * epoch as f64 / *t_max as f64).cos())
                        / 2.0
            }
            LearningRateScheduler::ReduceOnPlateau { .. } => {
                // Plateau detection needs a loss history that this stateless
                // helper does not keep, so fall back to the base learning rate.
                base_lr
            }
        }
    }
}

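/// Stops training once the loss has not improved by at least `min_delta` for
/// more than `patience` consecutive updates.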
#[derive(Debug, Clone)]
pub struct EarlyStopping {
    patience: usize,
    min_delta: f64,
    best_loss: f64,
    wait_count: usize,
    stopped: bool,
}

impl EarlyStopping {
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            wait_count: 0,
            stopped: false,
        }
    }

    /// Records `current_loss` and returns `true` once training should stop.
    pub fn update(&mut self, current_loss: f64) -> bool {
        if current_loss < self.best_loss - self.min_delta {
            self.best_loss = current_loss;
            self.wait_count = 0;
        } else {
            self.wait_count += 1;
            if self.wait_count > self.patience {
                self.stopped = true;
            }
        }

        self.stopped
    }

    pub fn should_stop(&self) -> bool {
        self.stopped
    }
}

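/// Adam optimizer with lazily initialized first/second moment estimates and
/// bias correction, i.e. the standard update
/// `theta <- theta - lr * m_hat / (sqrt(v_hat) + epsilon)`.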
#[derive(Debug, Clone)]
pub struct AdamOptimizer {
    beta1: f64,
    beta2: f64,
    epsilon: f64,
    t: usize,
    m: Option<Array2<f64>>,
    v: Option<Array2<f64>>,
}

impl AdamOptimizer {
    pub fn new(beta1: f64, beta2: f64, epsilon: f64) -> Self {
        Self {
            beta1,
            beta2,
            epsilon,
            t: 0,
            m: None,
            v: None,
        }
    }

    pub fn update(&mut self, params: &mut Array2<f64>, grads: &Array2<f64>, lr: f64) {
        self.t += 1;

        // Lazily initialize the moment estimates with the parameter shape.
        if self.m.is_none() {
            self.m = Some(Array2::zeros(params.raw_dim()));
            self.v = Some(Array2::zeros(params.raw_dim()));
        }

        let m = self.m.as_mut().unwrap();
        let v = self.v.as_mut().unwrap();

        // First moment estimate (exponential moving average of gradients).
        *m = &*m * self.beta1 + grads * (1.0 - self.beta1);

        // Second moment estimate (exponential moving average of squared gradients).
        *v = &*v * self.beta2 + &(grads * grads) * (1.0 - self.beta2);

        // Bias-corrected estimates.
        let m_hat = &*m / (1.0 - self.beta1.powi(self.t as i32));
        let v_hat = &*v / (1.0 - self.beta2.powi(self.t as i32));

        // Parameter update.
        *params = &*params - &(&m_hat / (&v_hat.mapv(|x| x.sqrt()) + self.epsilon)) * lr;
    }
}

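/// Per-epoch training history: losses, learning rates, wall-clock times, and
/// any validation losses recorded along the way.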
#[derive(Debug, Clone)]
pub struct MetricsTracker {
    pub losses: Vec<f64>,
    pub learning_rates: Vec<f64>,
    pub epochs: Vec<usize>,
    pub validation_losses: Vec<f64>,
    pub training_times: Vec<f64>,
}

impl MetricsTracker {
    pub fn new() -> Self {
        Self {
            losses: Vec::new(),
            learning_rates: Vec::new(),
            epochs: Vec::new(),
            validation_losses: Vec::new(),
            training_times: Vec::new(),
        }
    }

    pub fn record_epoch(&mut self, epoch: usize, loss: f64, lr: f64, training_time: f64) {
        self.epochs.push(epoch);
        self.losses.push(loss);
        self.learning_rates.push(lr);
        self.training_times.push(training_time);
    }

    pub fn record_validation(&mut self, val_loss: f64) {
        self.validation_losses.push(val_loss);
    }

    pub fn get_smoothed_loss(&self, window_size: usize) -> Vec<f64> {
        if self.losses.len() < window_size {
            return self.losses.clone();
        }

        let mut smoothed = Vec::new();
        let mut window: VecDeque<f64> = VecDeque::new();

        for &loss in &self.losses {
            window.push_back(loss);
            if window.len() > window_size {
                window.pop_front();
            }

            let avg = window.iter().sum::<f64>() / window.len() as f64;
            smoothed.push(avg);
        }

        smoothed
    }
}

impl Default for MetricsTracker {
    fn default() -> Self {
        Self::new()
    }
}

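/// Single-process trainer that combines a [`TrainingConfig`], an optimizer
/// choice, a learning-rate schedule, and optional early stopping.
///
/// A minimal usage sketch (assumes some `EmbeddingModel` implementation named
/// `model` is already constructed):
///
/// ```ignore
/// let mut trainer = AdvancedTrainer::new(TrainingConfig::default())
///     .with_optimizer(OptimizerType::default())
///     .with_scheduler(LearningRateScheduler::CosineAnnealing { t_max: 500, eta_min: 1e-5 });
/// let stats = trainer.train(&mut model).await?;
/// println!("finished after {} epochs", stats.epochs_completed);
/// ```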
pub struct AdvancedTrainer {
    config: TrainingConfig,
    optimizer: OptimizerType,
    scheduler: LearningRateScheduler,
    early_stopping: Option<EarlyStopping>,
    metrics: MetricsTracker,
}

impl AdvancedTrainer {
    pub fn new(config: TrainingConfig) -> Self {
        let early_stopping = if config.use_early_stopping {
            Some(EarlyStopping::new(config.patience, config.min_delta))
        } else {
            None
        };

        Self {
            config,
            optimizer: OptimizerType::default(),
            scheduler: LearningRateScheduler::default(),
            early_stopping,
            metrics: MetricsTracker::new(),
        }
    }

    pub fn with_optimizer(mut self, optimizer: OptimizerType) -> Self {
        self.optimizer = optimizer;
        self
    }

    pub fn with_scheduler(mut self, scheduler: LearningRateScheduler) -> Self {
        self.scheduler = scheduler;
        self
    }

    pub async fn train(&mut self, model: &mut dyn EmbeddingModel) -> Result<TrainingStats> {
        let start_time = Instant::now();
        info!(
            "Starting advanced training with {} epochs",
            self.config.max_epochs
        );

        for epoch in 0..self.config.max_epochs {
            let epoch_start = Instant::now();

            // Learning rate for this epoch from the schedule.
            let current_lr = self
                .scheduler
                .get_lr(epoch, self.config.learning_rate, None);

            // Run a single epoch of model training.
            let epoch_stats = model.train(Some(1)).await?;
            let epoch_loss = epoch_stats.final_loss;
            let epoch_time = epoch_start.elapsed().as_secs_f64();

            self.metrics
                .record_epoch(epoch, epoch_loss, current_lr, epoch_time);

            if epoch % self.config.log_freq == 0 {
                debug!(
                    "Epoch {}: loss = {:.6}, lr = {:.6}, time = {:.3}s",
                    epoch, epoch_loss, current_lr, epoch_time
                );
            }

            if let Some(ref mut early_stop) = self.early_stopping {
                if early_stop.update(epoch_loss) {
                    info!("Early stopping triggered at epoch {}", epoch);
                    break;
                }
            }

            // Convergence check on the raw loss.
            if epoch > 10 && epoch_loss < 1e-8 {
                info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
                break;
            }
        }

        let training_time = start_time.elapsed().as_secs_f64();
        let final_loss = self.metrics.losses.last().copied().unwrap_or(0.0);

        Ok(TrainingStats {
            epochs_completed: self.metrics.epochs.len(),
            final_loss,
            training_time_seconds: training_time,
            convergence_achieved: final_loss < 1e-6,
            loss_history: self.metrics.losses.clone(),
        })
    }

    pub fn get_metrics(&self) -> &MetricsTracker {
        &self.metrics
    }
}

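/// Holds held-out test triples and scores a model against them, averaging the
/// scores of the triples the model is able to score.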
pub struct ValidationSuite {
    pub test_triples: Vec<(String, String, String)>,
    pub validation_freq: usize,
}

impl ValidationSuite {
    pub fn new(test_triples: Vec<(String, String, String)>, validation_freq: usize) -> Self {
        Self {
            test_triples,
            validation_freq,
        }
    }

    pub fn evaluate_model(&self, model: &dyn EmbeddingModel) -> Result<ValidationMetrics> {
        let mut total_score = 0.0;
        let mut valid_predictions = 0;

        for (subject, predicate, object) in &self.test_triples {
            if let Ok(score) = model.score_triple(subject, predicate, object) {
                total_score += score;
                valid_predictions += 1;
            }
        }

        let avg_score = if valid_predictions > 0 {
            total_score / valid_predictions as f64
        } else {
            0.0
        };

        Ok(ValidationMetrics {
            average_score: avg_score,
            num_evaluated: valid_predictions,
            num_total: self.test_triples.len(),
        })
    }
}

#[derive(Debug, Clone)]
pub struct ValidationMetrics {
    pub average_score: f64,
    pub num_evaluated: usize,
    pub num_total: usize,
}

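/// Topology and synchronization settings for data-parallel training: the
/// number of workers (`world_size`), this worker's `rank`, the devices it
/// drives, the collective backend, the sync cadence, optional gradient
/// clipping, and the all-reduce method.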
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    pub world_size: usize,
    pub rank: usize,
    pub device_ids: Vec<usize>,
    pub backend: DistributedBackend,
    pub sync_frequency: usize,
    pub gradient_clipping: Option<f64>,
    pub all_reduce_method: AllReduceMethod,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        Self {
            world_size: 1,
            rank: 0,
            device_ids: vec![0],
            backend: DistributedBackend::NCCL,
            sync_frequency: 1,
            gradient_clipping: Some(1.0),
            all_reduce_method: AllReduceMethod::Average,
        }
    }
}

#[derive(Debug, Clone)]
pub enum DistributedBackend {
    NCCL,
    MPI,
    Gloo,
}

#[derive(Debug, Clone)]
pub enum AllReduceMethod {
    Sum,
    Average,
    WeightedAverage,
}

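/// Multi-worker trainer that spawns one task per configured device plus a
/// coordinator task; workers exchange [`SyncMessage`]s over a broadcast
/// channel.
///
/// A minimal usage sketch (assumes `model` is an
/// `Arc<RwLock<dyn EmbeddingModel + Send + Sync>>` built elsewhere):
///
/// ```ignore
/// let mut trainer = DistributedTrainer::new(
///     TrainingConfig::default(),
///     DistributedConfig { world_size: 2, device_ids: vec![0, 1], ..Default::default() },
/// );
/// let stats = trainer.train_distributed(model).await?;
/// ```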
#[allow(dead_code)]
pub struct DistributedTrainer {
    config: TrainingConfig,
    distributed_config: DistributedConfig,
    optimizer: OptimizerType,
    scheduler: LearningRateScheduler,
    early_stopping: Option<EarlyStopping>,
    metrics: Arc<RwLock<MetricsTracker>>,
    gradient_accumulator: Arc<Mutex<GradientAccumulator>>,
    sync_channel: (
        broadcast::Sender<SyncMessage>,
        broadcast::Receiver<SyncMessage>,
    ),
}

#[derive(Debug, Clone)]
pub enum SyncMessage {
    GradientUpdate {
        epoch: usize,
        rank: usize,
        gradients: Vec<f64>,
    },
    ParameterSync {
        epoch: usize,
        parameters: Vec<f64>,
    },
    EarlyStop {
        epoch: usize,
        loss: f64,
    },
    Checkpoint {
        epoch: usize,
        model_state: Vec<u8>,
    },
}

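/// Sums gradient sets from several workers (or micro-batches) and returns
/// their element-wise average once `target_count` contributions have arrived.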
#[derive(Debug)]
pub struct GradientAccumulator {
    accumulated_gradients: Vec<Array2<f64>>,
    accumulation_count: usize,
    target_count: usize,
}

impl GradientAccumulator {
    pub fn new(target_count: usize) -> Self {
        Self {
            accumulated_gradients: Vec::new(),
            accumulation_count: 0,
            target_count,
        }
    }

    pub fn accumulate(&mut self, gradients: Vec<Array2<f64>>) {
        if self.accumulated_gradients.is_empty() {
            self.accumulated_gradients = gradients;
        } else {
            for (i, grad) in gradients.into_iter().enumerate() {
                if i < self.accumulated_gradients.len() {
                    self.accumulated_gradients[i] = &self.accumulated_gradients[i] + &grad;
                } else {
                    self.accumulated_gradients.push(grad);
                }
            }
        }
        self.accumulation_count += 1;
    }

    pub fn is_ready(&self) -> bool {
        self.accumulation_count >= self.target_count
    }

    pub fn get_averaged_gradients(&mut self) -> Vec<Array2<f64>> {
        let count = self.accumulation_count as f64;
        let result = self
            .accumulated_gradients
            .iter()
            .map(|grad| grad / count)
            .collect();
        self.reset();
        result
    }

    pub fn reset(&mut self) {
        self.accumulated_gradients.clear();
        self.accumulation_count = 0;
    }
}

impl DistributedTrainer {
    pub fn new(config: TrainingConfig, distributed_config: DistributedConfig) -> Self {
        let early_stopping = if config.use_early_stopping {
            Some(EarlyStopping::new(config.patience, config.min_delta))
        } else {
            None
        };

        let (sync_tx, sync_rx) = broadcast::channel(1000);
        let gradient_accumulator = Arc::new(Mutex::new(GradientAccumulator::new(
            distributed_config.world_size,
        )));

        Self {
            config,
            distributed_config,
            optimizer: OptimizerType::default(),
            scheduler: LearningRateScheduler::default(),
            early_stopping,
            metrics: Arc::new(RwLock::new(MetricsTracker::new())),
            gradient_accumulator,
            sync_channel: (sync_tx, sync_rx),
        }
    }

    pub fn with_optimizer(mut self, optimizer: OptimizerType) -> Self {
        self.optimizer = optimizer;
        self
    }

    pub fn with_scheduler(mut self, scheduler: LearningRateScheduler) -> Self {
        self.scheduler = scheduler;
        self
    }

    pub async fn train_distributed(
        &mut self,
        model: Arc<RwLock<dyn EmbeddingModel + Send + Sync>>,
    ) -> Result<TrainingStats> {
        let start_time = Instant::now();
        info!(
            "Starting distributed training with {} workers on rank {}",
            self.distributed_config.world_size, self.distributed_config.rank
        );

        // Spawn one worker task per configured device.
        let mut worker_handles = Vec::new();
        for device_id in &self.distributed_config.device_ids {
            let worker_handle = self
                .spawn_worker_task(*device_id, Arc::clone(&model))
                .await?;
            worker_handles.push(worker_handle);
        }

        // Spawn the coordinator that relays gradient/parameter sync messages.
        let coordinator_handle = self.spawn_coordinator_task().await?;

        // Wait for all workers and keep the last successful result.
        let mut final_stats = None;
        for handle in worker_handles {
            if let Ok(stats) = handle.await {
                match stats {
                    Ok(s) => final_stats = Some(s),
                    Err(e) => warn!("Worker failed: {}", e),
                }
            }
        }

        coordinator_handle.abort();

        let training_time = start_time.elapsed().as_secs_f64();
        let metrics = self.metrics.read().await;

        Ok(final_stats.unwrap_or_else(|| TrainingStats {
            epochs_completed: metrics.epochs.len(),
            final_loss: metrics.losses.last().copied().unwrap_or(0.0),
            training_time_seconds: training_time,
            convergence_achieved: false,
            loss_history: metrics.losses.clone(),
        }))
    }

    async fn spawn_worker_task(
        &self,
        device_id: usize,
        model: Arc<RwLock<dyn EmbeddingModel + Send + Sync>>,
    ) -> Result<JoinHandle<Result<TrainingStats>>> {
        let config = self.config.clone();
        let distributed_config = self.distributed_config.clone();
        let _optimizer = self.optimizer.clone();
        let scheduler = self.scheduler.clone();
        let metrics = Arc::clone(&self.metrics);
        let mut sync_rx = self.sync_channel.0.subscribe();
        let sync_tx = self.sync_channel.0.clone();

        let handle = tokio::spawn(async move {
            info!(
                "Worker {} starting on device {}",
                distributed_config.rank, device_id
            );

            let mut local_early_stopping = if config.use_early_stopping {
                Some(EarlyStopping::new(config.patience, config.min_delta))
            } else {
                None
            };

            let mut total_training_time = 0.0;

            for epoch in 0..config.max_epochs {
                let epoch_start = Instant::now();

                let current_lr = scheduler.get_lr(epoch, config.learning_rate, None);

                // Run one epoch while holding the model write lock, then release it.
                let mut model_guard = model.write().await;
                let epoch_stats = model_guard.train(Some(1)).await?;
                drop(model_guard);

                let epoch_loss = epoch_stats.final_loss;
                let epoch_time = epoch_start.elapsed().as_secs_f64();
                total_training_time += epoch_time;

                {
                    let mut metrics_guard = metrics.write().await;
                    metrics_guard.record_epoch(epoch, epoch_loss, current_lr, epoch_time);
                }

                // Periodically exchange sync messages with the coordinator.
                if epoch % distributed_config.sync_frequency == 0 {
                    let _ = sync_tx.send(SyncMessage::GradientUpdate {
                        epoch,
                        rank: distributed_config.rank,
                        // The epoch loss stands in for a real gradient payload.
                        gradients: vec![epoch_loss],
                    });

                    // Wait briefly for a parameter sync; time out rather than block the epoch.
                    tokio::select! {
                        msg = sync_rx.recv() => {
                            if let Ok(SyncMessage::ParameterSync { .. }) = msg {
                                debug!("Received parameter sync for epoch {}", epoch);
                            }
                        }
                        _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
                            debug!("Sync timeout for epoch {}", epoch);
                        }
                    }
                }

                if epoch % config.log_freq == 0 {
                    debug!(
                        "Worker {} Epoch {}: loss = {:.6}, lr = {:.6}, time = {:.3}s",
                        distributed_config.rank, epoch, epoch_loss, current_lr, epoch_time
                    );
                }

                if let Some(ref mut early_stop) = local_early_stopping {
                    if early_stop.update(epoch_loss) {
                        info!(
                            "Worker {} early stopping triggered at epoch {}",
                            distributed_config.rank, epoch
                        );
                        let _ = sync_tx.send(SyncMessage::EarlyStop {
                            epoch,
                            loss: epoch_loss,
                        });
                        break;
                    }
                }

                if epoch > 10 && epoch_loss < 1e-8 {
                    info!(
                        "Worker {} converged at epoch {} with loss {:.6}",
                        distributed_config.rank, epoch, epoch_loss
                    );
                    break;
                }
            }

            let final_metrics = metrics.read().await;
            Ok(TrainingStats {
                epochs_completed: final_metrics.epochs.len(),
                final_loss: final_metrics.losses.last().copied().unwrap_or(0.0),
                training_time_seconds: total_training_time,
                convergence_achieved: final_metrics
                    .losses
                    .last()
                    .copied()
                    .unwrap_or(f64::INFINITY)
                    < 1e-6,
                loss_history: final_metrics.losses.clone(),
            })
        });

        Ok(handle)
    }

    async fn spawn_coordinator_task(&self) -> Result<JoinHandle<()>> {
        let mut sync_rx = self.sync_channel.0.subscribe();
        let sync_tx = self.sync_channel.0.clone();
        let gradient_accumulator = Arc::clone(&self.gradient_accumulator);
        let world_size = self.distributed_config.world_size;

        let handle = tokio::spawn(async move {
            info!("Coordinator starting for {} workers", world_size);

            while let Ok(msg) = sync_rx.recv().await {
                match msg {
                    SyncMessage::GradientUpdate {
                        epoch,
                        rank,
                        gradients,
                    } => {
                        debug!(
                            "Received gradients from worker {} for epoch {}",
                            rank, epoch
                        );

                        // Gradient accumulation would happen here; currently the
                        // lock is taken and released without touching the accumulator.
                        {
                            let _accumulator = gradient_accumulator.lock().unwrap();
                        }

                        // Echo the received values back to the workers as a parameter sync.
                        let _ = sync_tx.send(SyncMessage::ParameterSync {
                            epoch,
                            parameters: gradients,
                        });
                    }
                    SyncMessage::EarlyStop { epoch, loss } => {
                        info!(
                            "Early stop signal received at epoch {} with loss {:.6}",
                            epoch, loss
                        );
                    }
                    _ => {}
                }
            }
        });

        Ok(handle)
    }

    #[allow(dead_code)]
    async fn all_reduce_gradients(&self, gradients: Vec<Array2<f64>>) -> Result<Vec<Array2<f64>>> {
        match self.distributed_config.all_reduce_method {
            AllReduceMethod::Average => {
                let world_size = self.distributed_config.world_size as f64;
                Ok(gradients.into_iter().map(|g| g / world_size).collect())
            }
            AllReduceMethod::Sum => Ok(gradients),
            AllReduceMethod::WeightedAverage => {
                // Without per-worker weights, fall back to a uniform average.
                let world_size = self.distributed_config.world_size as f64;
                Ok(gradients.into_iter().map(|g| g / world_size).collect())
            }
        }
    }

    #[allow(dead_code)]
    fn clip_gradients(&self, gradients: &mut [Array2<f64>]) {
        if let Some(max_norm) = self.distributed_config.gradient_clipping {
            for grad in gradients.iter_mut() {
                let norm = grad.mapv(|x| x * x).sum().sqrt();
                if norm > max_norm {
                    *grad *= max_norm / norm;
                }
            }
        }
    }
}

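/// Helpers for setting up and tearing down the distributed environment. In
/// this module they are lightweight: initialization and cleanup only log,
/// and the suggested world size comes from available CPU parallelism.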
pub struct DistributedUtils;

impl DistributedUtils {
    pub async fn init_distributed(rank: usize, world_size: usize) -> Result<()> {
        info!(
            "Initializing distributed training: rank {} of {}",
            rank, world_size
        );
        Ok(())
    }

    pub async fn cleanup_distributed() -> Result<()> {
        info!("Cleaning up distributed training environment");
        Ok(())
    }

    pub fn is_distributed_available() -> bool {
        true
    }

    pub fn get_optimal_world_size() -> usize {
        std::thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(1)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_learning_rate_scheduler() {
        let scheduler = LearningRateScheduler::ExponentialDecay {
            decay_rate: 0.9,
            decay_steps: 10,
        };

        let lr0 = scheduler.get_lr(0, 0.1, None);
        let lr10 = scheduler.get_lr(10, 0.1, None);
        let lr20 = scheduler.get_lr(20, 0.1, None);

        assert!((lr0 - 0.1).abs() < 1e-10);
        assert!(lr10 < lr0);
        assert!(lr20 < lr10);
    }

    #[test]
    fn test_early_stopping() {
        let mut early_stop = EarlyStopping::new(3, 0.01);

        assert!(!early_stop.update(1.0));
        assert!(!early_stop.update(0.5));
        assert!(!early_stop.update(0.51));
        assert!(!early_stop.update(0.52));
        assert!(!early_stop.update(0.53));
        assert!(early_stop.update(0.54));
    }

    #[test]
    fn test_metrics_tracker() {
        let mut tracker = MetricsTracker::new();

        tracker.record_epoch(0, 1.0, 0.01, 1.5);
        tracker.record_epoch(1, 0.5, 0.009, 1.4);
        tracker.record_epoch(2, 0.3, 0.008, 1.3);

        assert_eq!(tracker.losses.len(), 3);
        assert_eq!(tracker.epochs.len(), 3);

        let smoothed = tracker.get_smoothed_loss(2);
        assert_eq!(smoothed.len(), 3);
    }

    #[test]
    fn test_distributed_config() {
        let config = DistributedConfig::default();
        assert_eq!(config.world_size, 1);
        assert_eq!(config.rank, 0);
        assert_eq!(config.device_ids.len(), 1);
    }

    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = GradientAccumulator::new(2);
        assert!(!accumulator.is_ready());

        let grad1 = vec![Array2::from_elem((2, 2), 1.0)];
        let grad2 = vec![Array2::from_elem((2, 2), 2.0)];

        accumulator.accumulate(grad1);
        assert!(!accumulator.is_ready());

        accumulator.accumulate(grad2);
        assert!(accumulator.is_ready());

        let averaged = accumulator.get_averaged_gradients();
        assert_eq!(averaged.len(), 1);
        assert!((averaged[0][[0, 0]] - 1.5).abs() < 1e-10);
    }

    #[test]
    fn test_distributed_utils() {
        assert!(DistributedUtils::is_distributed_available());
        let world_size = DistributedUtils::get_optimal_world_size();
        assert!(world_size >= 1);
    }
}