1use anyhow::Result;
29use chrono::{DateTime, Utc};
30use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32use std::sync::Arc;
33use tokio::sync::{Mutex, RwLock};
34use tracing::{debug, info, warn};
35
36use scirs2_core::distributed::{ClusterConfiguration, ClusterManager};
38use scirs2_core::ndarray_ext::Array1;
39
40use crate::EmbeddingModel;
41
/// How model training is parallelized across the cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistributedStrategy {
    /// Replicate the full model on every worker and split each batch.
    DataParallel {
        /// Number of data-parallel workers.
        num_workers: usize,
        /// Global batch size processed per step.
        batch_size: usize,
    },
    /// Partition the model itself across devices/hosts.
    ModelParallel {
        /// Number of model shards.
        num_shards: usize,
        /// Number of pipeline stages the shards are arranged into.
        pipeline_stages: usize,
    },
    /// Combine data and model parallelism.
    Hybrid {
        /// Replication factor (data-parallel dimension).
        data_parallel_size: usize,
        /// Sharding factor (model-parallel dimension).
        model_parallel_size: usize,
    },
}
67
/// Algorithm used to combine gradients from multiple workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
    /// Plain all-reduce across all workers.
    AllReduce,
    /// Bandwidth-optimal ring-based all-reduce.
    RingAllReduce,
    /// Centralized aggregation through dedicated parameter servers.
    ParameterServer {
        /// Number of parameter-server processes.
        num_servers: usize,
    },
    /// Tree-structured aggregation.
    Hierarchical {
        /// Fan-out of each node in the aggregation tree.
        branching_factor: usize,
    },
}
86
/// Transport used for inter-worker communication.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CommunicationBackend {
    /// Plain TCP sockets.
    Tcp,
    /// NVIDIA NCCL (GPU collective communication).
    Nccl,
    /// Facebook Gloo collective library.
    Gloo,
    /// Message Passing Interface.
    Mpi,
}
99
/// Fault-tolerance knobs for a distributed training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultToleranceConfig {
    /// Whether periodic checkpointing is enabled.
    pub enable_checkpointing: bool,
    /// Checkpoint every this many epochs (see coordinator's training loop).
    pub checkpoint_frequency: usize,
    /// Maximum retry attempts for failed operations.
    pub max_retries: usize,
    /// Whether workers may join/leave during training.
    pub elastic_scaling: bool,
    /// Expected heartbeat interval in seconds.
    pub heartbeat_interval: u64,
    /// Seconds without a heartbeat before a worker is considered timed out.
    pub worker_timeout: u64,
}
116
117impl Default for FaultToleranceConfig {
118 fn default() -> Self {
119 Self {
120 enable_checkpointing: true,
121 checkpoint_frequency: 10,
122 max_retries: 3,
123 elastic_scaling: false,
124 heartbeat_interval: 30,
125 worker_timeout: 300,
126 }
127 }
128}
129
/// Full configuration for a distributed training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingConfig {
    /// Parallelization strategy.
    pub strategy: DistributedStrategy,
    /// Gradient aggregation method.
    pub aggregation: AggregationMethod,
    /// Communication transport.
    pub backend: CommunicationBackend,
    /// Fault-tolerance settings.
    pub fault_tolerance: FaultToleranceConfig,
    /// Whether gradients are compressed before exchange.
    pub gradient_compression: bool,
    /// Target compression ratio when compression is enabled.
    pub compression_ratio: f32,
    /// Whether mixed-precision training is enabled.
    pub mixed_precision: bool,
    /// Optional gradient-clipping threshold; `None` disables clipping.
    pub gradient_clip: Option<f32>,
    /// Number of learning-rate warmup epochs.
    pub warmup_epochs: usize,
    /// Whether pipeline parallelism is enabled.
    pub pipeline_parallelism: bool,
    /// Number of microbatches per pipeline step.
    pub num_microbatches: usize,
}
156
157impl Default for DistributedTrainingConfig {
158 fn default() -> Self {
159 Self {
160 strategy: DistributedStrategy::DataParallel {
161 num_workers: 4,
162 batch_size: 256,
163 },
164 aggregation: AggregationMethod::AllReduce,
165 backend: CommunicationBackend::Tcp,
166 fault_tolerance: FaultToleranceConfig::default(),
167 gradient_compression: false,
168 compression_ratio: 0.5,
169 mixed_precision: false,
170 gradient_clip: Some(1.0),
171 warmup_epochs: 5,
172 pipeline_parallelism: false,
173 num_microbatches: 4,
174 }
175 }
176}
177
/// Registration record and health snapshot for a single worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerInfo {
    /// Unique worker identifier (registry key in the coordinator).
    pub worker_id: usize,
    /// Rank within the distributed job.
    pub rank: usize,
    /// Network address (e.g. "host:port").
    pub address: String,
    /// Current lifecycle status.
    pub status: WorkerStatus,
    /// Number of GPUs available on this worker.
    pub num_gpus: usize,
    /// Available memory in gigabytes.
    pub memory_gb: f32,
    /// Timestamp of the most recent heartbeat (UTC).
    pub last_heartbeat: DateTime<Utc>,
}
196
/// Lifecycle state of a worker as tracked by the coordinator.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WorkerStatus {
    /// Registered but not currently working.
    Idle,
    /// Actively executing a training step.
    Training,
    /// Participating in gradient synchronization.
    Synchronizing,
    /// Considered failed (e.g. after a timeout).
    Failed,
    /// Rejoining after a failure.
    Recovering,
}
211
/// A point-in-time snapshot of training progress.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingCheckpoint {
    /// Unique identifier, e.g. "checkpoint_epoch_10".
    pub checkpoint_id: String,
    /// Epoch at which the checkpoint was taken.
    pub epoch: usize,
    /// Global step counter at checkpoint time.
    pub global_step: usize,
    /// Opaque serialized model parameters (left empty by `save_checkpoint`
    /// in this file — real serialization not implemented here).
    pub model_state: Vec<u8>,
    /// Opaque serialized optimizer state (also left empty here).
    pub optimizer_state: Vec<u8>,
    /// Loss at checkpoint time.
    pub loss: f64,
    /// When the checkpoint was created (UTC).
    pub timestamp: DateTime<Utc>,
}
230
/// Aggregate statistics for a distributed training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingStats {
    /// Number of epochs completed so far.
    pub total_epochs: usize,
    /// Total optimization steps (not currently updated by the training loop).
    pub total_steps: usize,
    /// Loss recorded at the most recent epoch.
    pub final_loss: f64,
    /// Wall-clock training time in seconds.
    pub training_time: f64,
    /// Number of currently registered workers.
    pub num_workers: usize,
    /// Epochs per second over the whole run.
    pub throughput: f64,
    /// Total seconds spent in gradient aggregation/communication.
    pub communication_time: f64,
    /// Total seconds spent in per-worker computation.
    pub computation_time: f64,
    /// Number of checkpoints saved.
    pub num_checkpoints: usize,
    /// Number of worker failures/deregistrations observed.
    pub num_failures: usize,
    /// Per-epoch loss values in order.
    pub loss_history: Vec<f64>,
}
257
/// Coordinates a distributed training run: tracks worker membership and
/// health, drives the epoch loop, stores checkpoints, and accumulates stats.
pub struct DistributedTrainingCoordinator {
    /// Run configuration (strategy, aggregation, fault tolerance, ...).
    config: DistributedTrainingConfig,
    /// Registered workers keyed by `worker_id`, shared behind an async RwLock.
    workers: Arc<RwLock<HashMap<usize, WorkerInfo>>>,
    /// Checkpoints in insertion order.
    checkpoints: Arc<Mutex<Vec<TrainingCheckpoint>>>,
    /// Cluster manager handle; not otherwise used in this file after creation.
    cluster_manager: Arc<ClusterManager>,
    /// Accumulated run statistics.
    stats: Arc<Mutex<DistributedTrainingStats>>,
}
266
267impl DistributedTrainingCoordinator {
268 pub async fn new(config: DistributedTrainingConfig) -> Result<Self> {
270 info!("Initializing distributed training coordinator");
271
272 let cluster_config = ClusterConfiguration::default();
274 let cluster_manager = Arc::new(
275 ClusterManager::new(cluster_config)
276 .map_err(|e| anyhow::anyhow!("Failed to create cluster manager: {}", e))?,
277 );
278
279 Ok(Self {
280 config,
281 workers: Arc::new(RwLock::new(HashMap::new())),
282 checkpoints: Arc::new(Mutex::new(Vec::new())),
283 cluster_manager,
284 stats: Arc::new(Mutex::new(DistributedTrainingStats {
285 total_epochs: 0,
286 total_steps: 0,
287 final_loss: 0.0,
288 training_time: 0.0,
289 num_workers: 0,
290 throughput: 0.0,
291 communication_time: 0.0,
292 computation_time: 0.0,
293 num_checkpoints: 0,
294 num_failures: 0,
295 loss_history: Vec::new(),
296 })),
297 })
298 }
299
300 pub async fn register_worker(&self, worker_info: WorkerInfo) -> Result<()> {
302 info!(
303 "Registering worker {}: {}",
304 worker_info.worker_id, worker_info.address
305 );
306
307 let mut workers = self.workers.write().await;
308 workers.insert(worker_info.worker_id, worker_info);
309
310 let mut stats = self.stats.lock().await;
311 stats.num_workers = workers.len();
312
313 Ok(())
314 }
315
316 pub async fn deregister_worker(&self, worker_id: usize) -> Result<()> {
318 warn!("Deregistering worker {}", worker_id);
319
320 let mut workers = self.workers.write().await;
321 workers.remove(&worker_id);
322
323 let mut stats = self.stats.lock().await;
324 stats.num_workers = workers.len();
325 stats.num_failures += 1;
326
327 Ok(())
328 }
329
330 pub async fn update_worker_status(&self, worker_id: usize, status: WorkerStatus) -> Result<()> {
332 let mut workers = self.workers.write().await;
333 if let Some(worker) = workers.get_mut(&worker_id) {
334 worker.status = status;
335 worker.last_heartbeat = Utc::now();
336 }
337 Ok(())
338 }
339
340 pub async fn train<M: EmbeddingModel>(
342 &mut self,
343 model: &mut M,
344 epochs: usize,
345 ) -> Result<DistributedTrainingStats> {
346 info!("Starting distributed training for {} epochs", epochs);
347
348 let start_time = std::time::Instant::now();
349 let mut total_comm_time = 0.0;
350 let mut total_comp_time = 0.0;
351
352 self.initialize_optimizer().await?;
354
355 for epoch in 0..epochs {
356 debug!("Epoch {}/{}", epoch + 1, epochs);
357
358 let comp_start = std::time::Instant::now();
360 let batch_results = self.distribute_training_batch(model, epoch).await?;
361 let comp_time = comp_start.elapsed().as_secs_f64();
362 total_comp_time += comp_time;
363
364 let comm_start = std::time::Instant::now();
366 let avg_loss = self.aggregate_gradients(&batch_results).await?;
367 let comm_time = comm_start.elapsed().as_secs_f64();
368 total_comm_time += comm_time;
369
370 {
372 let mut stats = self.stats.lock().await;
373 stats.total_epochs = epoch + 1;
374 stats.loss_history.push(avg_loss);
375 stats.final_loss = avg_loss;
376 }
377
378 if self.config.fault_tolerance.enable_checkpointing
380 && (epoch + 1) % self.config.fault_tolerance.checkpoint_frequency == 0
381 {
382 self.save_checkpoint(model, epoch, avg_loss).await?;
383 }
384
385 info!(
386 "Epoch {}: loss={:.6}, comp_time={:.2}s, comm_time={:.2}s",
387 epoch + 1,
388 avg_loss,
389 comp_time,
390 comm_time
391 );
392 }
393
394 let elapsed = start_time.elapsed().as_secs_f64();
395
396 let stats = {
398 let mut stats = self.stats.lock().await;
399 stats.training_time = elapsed;
400 stats.communication_time = total_comm_time;
401 stats.computation_time = total_comp_time;
402 stats.throughput = (epochs as f64) / elapsed;
403 stats.clone()
404 };
405
406 info!("Distributed training completed in {:.2}s", elapsed);
407 info!("Final loss: {:.6}", stats.final_loss);
408 info!("Throughput: {:.2} epochs/sec", stats.throughput);
409
410 Ok(stats)
411 }
412
413 async fn initialize_optimizer(&mut self) -> Result<()> {
415 debug!("Initializing distributed optimizer");
416
417 Ok(())
421 }
422
423 async fn distribute_training_batch<M: EmbeddingModel>(
425 &self,
426 _model: &M,
427 epoch: usize,
428 ) -> Result<Vec<WorkerResult>> {
429 let workers = self.workers.read().await;
430 let num_workers = workers.len();
431
432 if num_workers == 0 {
433 return Err(anyhow::anyhow!("No workers available"));
434 }
435
436 let mut results = Vec::new();
439 for (worker_id, _) in workers.iter() {
440 results.push(WorkerResult {
441 worker_id: *worker_id,
442 epoch,
443 loss: 0.1 * (1.0 - epoch as f64 / 100.0).max(0.01),
444 num_samples: 1000,
445 gradients: HashMap::new(),
446 });
447 }
448
449 Ok(results)
450 }
451
452 async fn aggregate_gradients(&self, results: &[WorkerResult]) -> Result<f64> {
454 if results.is_empty() {
455 return Err(anyhow::anyhow!("No results to aggregate"));
456 }
457
458 let avg_loss = results.iter().map(|r| r.loss).sum::<f64>() / results.len() as f64;
460
461 match &self.config.aggregation {
464 AggregationMethod::AllReduce => {
465 debug!("Using AllReduce for gradient aggregation");
466 }
469 AggregationMethod::RingAllReduce => {
470 debug!("Using Ring-AllReduce for gradient aggregation");
471 }
473 AggregationMethod::ParameterServer { num_servers } => {
474 debug!("Using Parameter Server with {} servers", num_servers);
475 }
477 AggregationMethod::Hierarchical { branching_factor } => {
478 debug!(
479 "Using Hierarchical aggregation with branching factor {}",
480 branching_factor
481 );
482 }
484 }
485
486 Ok(avg_loss)
487 }
488
489 async fn save_checkpoint<M: EmbeddingModel>(
491 &self,
492 _model: &M,
493 epoch: usize,
494 loss: f64,
495 ) -> Result<()> {
496 info!("Saving checkpoint at epoch {}", epoch);
497
498 let checkpoint = TrainingCheckpoint {
499 checkpoint_id: format!("checkpoint_epoch_{}", epoch),
500 epoch,
501 global_step: epoch * 1000, model_state: Vec::new(), optimizer_state: Vec::new(), loss,
505 timestamp: Utc::now(),
506 };
507
508 let mut checkpoints = self.checkpoints.lock().await;
509 checkpoints.push(checkpoint);
510
511 let mut stats = self.stats.lock().await;
512 stats.num_checkpoints += 1;
513
514 Ok(())
515 }
516
517 pub async fn load_checkpoint(&self, checkpoint_id: &str) -> Result<TrainingCheckpoint> {
519 let checkpoints = self.checkpoints.lock().await;
520 checkpoints
521 .iter()
522 .find(|c| c.checkpoint_id == checkpoint_id)
523 .cloned()
524 .ok_or_else(|| anyhow::anyhow!("Checkpoint not found: {}", checkpoint_id))
525 }
526
527 pub async fn get_worker_stats(&self) -> HashMap<usize, WorkerInfo> {
529 self.workers.read().await.clone()
530 }
531
532 pub async fn get_stats(&self) -> DistributedTrainingStats {
534 self.stats.lock().await.clone()
535 }
536
537 pub async fn monitor_workers(&self) -> Result<()> {
539 let timeout_duration =
540 std::time::Duration::from_secs(self.config.fault_tolerance.worker_timeout);
541
542 let workers = self.workers.read().await;
543 let now = Utc::now();
544
545 for (worker_id, worker) in workers.iter() {
546 let elapsed = now.signed_duration_since(worker.last_heartbeat);
547 if elapsed.num_seconds() as u64 > timeout_duration.as_secs() {
548 warn!(
549 "Worker {} timed out (last heartbeat: {:?})",
550 worker_id, worker.last_heartbeat
551 );
552 }
554 }
555
556 Ok(())
557 }
558}
559
/// Result of one (simulated) training step on a single worker.
#[derive(Debug, Clone)]
struct WorkerResult {
    /// Id of the worker that produced this result.
    worker_id: usize,
    /// Epoch the result belongs to.
    epoch: usize,
    /// Loss reported by the worker for this step.
    loss: f64,
    /// Number of samples processed in the step.
    num_samples: usize,
    /// Named gradient tensors (empty in the current simulation).
    gradients: HashMap<String, Array1<f32>>,
}
569
/// All-reduce variant used by `GradientAggregator`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AllReduceStrategy {
    /// Ring-based reduce-scatter + all-gather.
    RingAllReduce,
    /// Tree-structured reduction.
    TreeAllReduce,
    /// Centralized parameter-server aggregation.
    ParameterServer,
}
584
/// Stateless helper that combines gradient vectors from multiple workers
/// into their element-wise mean.
#[derive(Debug, Clone, Default)]
pub struct GradientAggregator;
588
impl GradientAggregator {
    /// Creates a new (stateless) aggregator.
    pub fn new() -> Self {
        Self
    }

    /// Aggregates a local gradient according to `strategy`.
    ///
    /// Only the local gradient is available here, so both all-reduce
    /// variants degenerate to single-participant reductions (identity), and
    /// `ParameterServer` simply forwards the input.
    pub fn aggregate_gradients(
        &self,
        local_grad: &[f64],
        strategy: &AllReduceStrategy,
    ) -> Vec<f64> {
        match strategy {
            AllReduceStrategy::RingAllReduce => {
                self.ring_all_reduce(vec![local_grad.to_vec()])
            }
            AllReduceStrategy::TreeAllReduce => self.tree_all_reduce(vec![local_grad.to_vec()]),
            AllReduceStrategy::ParameterServer => {
                // No server round-trip in this simulation: pass through.
                local_grad.to_vec()
            }
        }
    }

    /// In-process simulation of ring all-reduce over `gradients` (one vector
    /// per worker); returns the element-wise MEAN of all inputs.
    ///
    /// Implements reduce-scatter (n-1 steps in which each worker adds its
    /// left neighbor's partial chunk) followed by a gather of the fully
    /// reduced chunks.
    ///
    /// NOTE(review): assumes every gradient vector has the same length as
    /// `gradients[0]`; a shorter vector would panic on the chunk slicing
    /// below — confirm callers uphold this.
    pub fn ring_all_reduce(&self, gradients: Vec<Vec<f64>>) -> Vec<f64> {
        let n = gradients.len();
        if n == 0 {
            return Vec::new();
        }
        if n == 1 {
            // Single participant: the mean is the input itself.
            return gradients.into_iter().next().unwrap_or_default();
        }

        let len = gradients[0].len();
        if len == 0 {
            return Vec::new();
        }

        // Partition indices into n nearly-equal chunks; the first
        // `remainder` chunks receive one extra element.
        let base = len / n;
        let remainder = len % n;
        let chunk_sizes: Vec<usize> = (0..n)
            .map(|i| base + if i < remainder { 1 } else { 0 })
            .collect();
        let mut chunk_start = vec![0usize; n];
        for i in 1..n {
            chunk_start[i] = chunk_start[i - 1] + chunk_sizes[i - 1];
        }

        // partial[w][c] = worker w's running partial sum of chunk c.
        let mut partial: Vec<Vec<Vec<f64>>> = gradients
            .iter()
            .map(|g| {
                chunk_sizes
                    .iter()
                    .zip(chunk_start.iter())
                    .map(|(&sz, &s)| g[s..s + sz].to_vec())
                    .collect()
            })
            .collect();

        // Reduce-scatter: at each step, worker w accumulates chunk
        // c = (w - 1 - step) mod n from its left neighbor. Snapshotting
        // `prev` simulates all workers exchanging simultaneously.
        #[allow(clippy::needless_range_loop)]
        for step in 0..(n - 1) {
            let prev = partial.clone();
            for w in 0..n {
                let left = (w + n - 1) % n;
                let c = (w + n - 1 - step) % n;
                let sz = chunk_sizes[c];
                for i in 0..sz {
                    partial[w][c][i] += prev[left][c][i];
                }
            }
        }

        // After n-1 steps, worker w holds the COMPLETE sum of chunk
        // (w + 1) mod n; gather those chunks and divide by n for the mean.
        let mut result = vec![0.0_f64; len];
        #[allow(clippy::needless_range_loop)]
        for w in 0..n {
            let c = (w + 1) % n;
            let s = chunk_start[c];
            let sz = chunk_sizes[c];
            for i in 0..sz {
                result[s + i] = partial[w][c][i] / n as f64;
            }
        }

        result
    }

    /// Tree-style reduction: element-wise mean of all gradient vectors.
    /// Vectors shorter than `gradients[0]` contribute only to the indices
    /// they cover (guarded by the `i < len` check).
    fn tree_all_reduce(&self, gradients: Vec<Vec<f64>>) -> Vec<f64> {
        let n = gradients.len();
        if n == 0 {
            return Vec::new();
        }
        if n == 1 {
            return gradients.into_iter().next().unwrap_or_default();
        }

        let len = gradients[0].len();
        let mut sums = vec![0.0_f64; len];
        for grad in &gradients {
            for (i, v) in grad.iter().enumerate() {
                if i < len {
                    sums[i] += v;
                }
            }
        }
        sums.iter_mut().for_each(|v| *v /= n as f64);
        sums
    }
}
729
/// A top-k sparsified gradient: only the retained entries are stored.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseGradient {
    /// Sorted (ascending) indices of the retained entries.
    pub indices: Vec<usize>,
    /// Values at `indices`, in the same order.
    pub values: Vec<f64>,
    /// Length of the original dense gradient.
    pub original_len: usize,
}
740
/// Stateless top-k magnitude gradient compressor.
#[derive(Debug, Clone, Default)]
pub struct GradientCompressor;
744
745impl GradientCompressor {
746 pub fn new() -> Self {
748 Self
749 }
750
751 pub fn compress(&self, grad: &[f64], sparsity: f64) -> SparseGradient {
756 let sparsity = sparsity.clamp(0.0, 0.9999);
757 let n = grad.len();
758 if n == 0 {
759 return SparseGradient {
760 indices: Vec::new(),
761 values: Vec::new(),
762 original_len: 0,
763 };
764 }
765
766 let keep = ((1.0 - sparsity) * n as f64).ceil() as usize;
767 let keep = keep.max(1).min(n);
768
769 let mut indexed: Vec<(usize, f64)> = grad
771 .iter()
772 .enumerate()
773 .map(|(i, &v)| (i, v.abs()))
774 .collect();
775 indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
776
777 let mut indices: Vec<usize> = indexed[..keep].iter().map(|(i, _)| *i).collect();
778 indices.sort_unstable();
779
780 let values: Vec<f64> = indices.iter().map(|&i| grad[i]).collect();
781
782 SparseGradient {
783 indices,
784 values,
785 original_len: n,
786 }
787 }
788
789 pub fn decompress(&self, sparse: &SparseGradient) -> Vec<f64> {
791 let mut dense = vec![0.0_f64; sparse.original_len];
792 for (&idx, &val) in sparse.indices.iter().zip(sparse.values.iter()) {
793 if idx < sparse.original_len {
794 dense[idx] = val;
795 }
796 }
797 dense
798 }
799}
800
/// One labeled training example for data-parallel training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingSample {
    /// Dense feature vector.
    pub features: Vec<f64>,
    /// Target label.
    pub label: f64,
    /// Optional per-sample weight; `None` means unweighted.
    pub weight: Option<f64>,
}
817
818impl DistributedTrainingSample {
819 pub fn new(features: Vec<f64>, label: f64) -> Self {
821 Self {
822 features,
823 label,
824 weight: None,
825 }
826 }
827}
828
/// Gradient/loss contribution from one worker for a single step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerUpdate {
    /// Id of the reporting worker.
    pub worker_id: u32,
    /// Dense gradient produced by this worker.
    pub gradients: Vec<f64>,
    /// Loss observed by this worker.
    pub loss: f64,
    /// Number of samples this worker processed (used as the merge weight).
    pub samples_processed: u32,
}
841
/// Result of merging all worker updates for one step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUpdate {
    /// Sample-weighted average of the worker gradients.
    pub averaged_gradients: Vec<f64>,
    /// Sample-weighted mean loss.
    pub mean_loss: f64,
    /// Total samples across all workers.
    pub total_samples: u32,
}
852
/// Stateless helper for splitting batches across workers and merging their
/// updates back together.
#[derive(Debug, Clone, Default)]
pub struct DataParallelTrainer;
856
857impl DataParallelTrainer {
858 pub fn new() -> Self {
860 Self
861 }
862
863 pub fn split_batch(
869 &self,
870 data: &[DistributedTrainingSample],
871 n_workers: u32,
872 ) -> Vec<Vec<DistributedTrainingSample>> {
873 let n = n_workers as usize;
874 if n == 0 || data.is_empty() {
875 return Vec::new();
876 }
877
878 let mut buckets: Vec<Vec<DistributedTrainingSample>> = (0..n).map(|_| Vec::new()).collect();
879 for (i, sample) in data.iter().enumerate() {
880 buckets[i % n].push(sample.clone());
881 }
882 buckets
883 }
884
885 pub fn merge_worker_updates(&self, updates: Vec<WorkerUpdate>) -> ModelUpdate {
890 if updates.is_empty() {
891 return ModelUpdate {
892 averaged_gradients: Vec::new(),
893 mean_loss: 0.0,
894 total_samples: 0,
895 };
896 }
897
898 let total_samples: u32 = updates.iter().map(|u| u.samples_processed).sum();
899 if total_samples == 0 {
900 return ModelUpdate {
901 averaged_gradients: Vec::new(),
902 mean_loss: 0.0,
903 total_samples: 0,
904 };
905 }
906
907 let grad_len = updates.iter().map(|u| u.gradients.len()).max().unwrap_or(0);
909
910 let mut averaged_gradients = vec![0.0_f64; grad_len];
911 let mut weighted_loss = 0.0_f64;
912
913 for update in &updates {
914 let weight = update.samples_processed as f64 / total_samples as f64;
915 for (i, &g) in update.gradients.iter().enumerate() {
916 if i < grad_len {
917 averaged_gradients[i] += g * weight;
918 }
919 }
920 weighted_loss += update.loss * weight;
921 }
922
923 ModelUpdate {
924 averaged_gradients,
925 mean_loss: weighted_loss,
926 total_samples,
927 }
928 }
929}
930
/// Convenience wrapper pairing an embedding model with a training
/// coordinator.
pub struct DistributedEmbeddingTrainer<M: EmbeddingModel> {
    /// The model being trained.
    model: M,
    /// Coordinator driving the distributed training loop.
    coordinator: DistributedTrainingCoordinator,
}
936
impl<M: EmbeddingModel> DistributedEmbeddingTrainer<M> {
    /// Creates a trainer for `model` with a freshly built coordinator.
    ///
    /// # Errors
    /// Fails if the coordinator cannot be constructed.
    pub async fn new(model: M, config: DistributedTrainingConfig) -> Result<Self> {
        let coordinator = DistributedTrainingCoordinator::new(config).await?;

        Ok(Self { model, coordinator })
    }

    /// Runs distributed training for `epochs` epochs; delegates to the
    /// coordinator.
    pub async fn train(&mut self, epochs: usize) -> Result<DistributedTrainingStats> {
        self.coordinator.train(&mut self.model, epochs).await
    }

    /// Borrows the underlying model.
    pub fn model(&self) -> &M {
        &self.model
    }

    /// Mutably borrows the underlying model.
    pub fn model_mut(&mut self) -> &mut M {
        &mut self.model
    }

    /// Registers a worker with the coordinator.
    pub async fn register_worker(&self, worker_info: WorkerInfo) -> Result<()> {
        self.coordinator.register_worker(worker_info).await
    }

    /// Returns a snapshot of the coordinator's training statistics.
    pub async fn get_stats(&self) -> DistributedTrainingStats {
        self.coordinator.get_stats().await
    }
}
970
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModelConfig, TransE};

    // ---- GradientAggregator ----

    #[test]
    fn test_all_reduce_strategy_variants() {
        let strategies = [
            AllReduceStrategy::RingAllReduce,
            AllReduceStrategy::TreeAllReduce,
            AllReduceStrategy::ParameterServer,
        ];
        for s in &strategies {
            let agg = GradientAggregator::new();
            let grad = vec![1.0, 2.0, 3.0];
            let result = agg.aggregate_gradients(&grad, s);
            assert_eq!(result.len(), 3);
        }
    }

    #[test]
    fn test_ring_all_reduce_single_worker() {
        let agg = GradientAggregator::new();
        let grads = vec![vec![1.0, 2.0, 3.0]];
        let result = agg.ring_all_reduce(grads);
        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_ring_all_reduce_two_workers() {
        let agg = GradientAggregator::new();
        // Identical inputs: the mean equals either input.
        let grads = vec![vec![2.0, 4.0, 6.0], vec![2.0, 4.0, 6.0]];
        let result = agg.ring_all_reduce(grads);
        assert_eq!(result.len(), 3);
        for (r, expected) in result.iter().zip([2.0, 4.0, 6.0].iter()) {
            assert!((r - expected).abs() < 1e-9, "expected {expected}, got {r}");
        }
    }

    #[test]
    fn test_ring_all_reduce_four_workers_mean() {
        let agg = GradientAggregator::new();
        let grads = vec![
            vec![4.0, 8.0],
            vec![2.0, 4.0],
            vec![0.0, 0.0],
            vec![6.0, 12.0],
        ];
        let result = agg.ring_all_reduce(grads);
        assert_eq!(result.len(), 2);
        // Element-wise means: (4+2+0+6)/4 = 3, (8+4+0+12)/4 = 6.
        assert!((result[0] - 3.0).abs() < 1e-6);
        assert!((result[1] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_ring_all_reduce_empty_input() {
        let agg = GradientAggregator::new();
        let result = agg.ring_all_reduce(vec![]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_ring_all_reduce_empty_gradient_vectors() {
        let agg = GradientAggregator::new();
        let result = agg.ring_all_reduce(vec![vec![], vec![]]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_aggregate_gradients_ring() {
        let agg = GradientAggregator::new();
        let grad = vec![1.0, 2.0, 3.0, 4.0];
        let result = agg.aggregate_gradients(&grad, &AllReduceStrategy::RingAllReduce);
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_aggregate_gradients_tree() {
        let agg = GradientAggregator::new();
        let grad = vec![5.0, 10.0];
        let result = agg.aggregate_gradients(&grad, &AllReduceStrategy::TreeAllReduce);
        assert_eq!(result, vec![5.0, 10.0]);
    }

    #[test]
    fn test_aggregate_gradients_parameter_server() {
        let agg = GradientAggregator::new();
        let grad = vec![3.0, 1.0, 4.0];
        let result = agg.aggregate_gradients(&grad, &AllReduceStrategy::ParameterServer);
        assert_eq!(result, grad);
    }

    // ---- GradientCompressor ----

    #[test]
    fn test_compress_empty_gradient() {
        let comp = GradientCompressor::new();
        let sparse = comp.compress(&[], 0.9);
        assert!(sparse.indices.is_empty());
        assert_eq!(sparse.original_len, 0);
    }

    #[test]
    fn test_compress_keep_all() {
        let comp = GradientCompressor::new();
        let grad = vec![1.0, -2.0, 3.0, -4.0];
        // Sparsity 0.0 retains every entry.
        let sparse = comp.compress(&grad, 0.0);
        assert_eq!(sparse.indices.len(), 4);
        assert_eq!(sparse.original_len, 4);
    }

    #[test]
    fn test_compress_top_k_selects_largest() {
        let comp = GradientCompressor::new();
        let grad = vec![0.1, 5.0, 0.2, 9.0, 0.3];
        // keep = ceil(0.4 * 5) = 2 -> the two largest magnitudes (9.0, 5.0).
        let sparse = comp.compress(&grad, 0.6);
        assert_eq!(sparse.indices.len(), 2);
        assert!(sparse.indices.contains(&3));
        assert!(sparse.indices.contains(&1));
    }

    #[test]
    fn test_decompress_roundtrip() {
        let comp = GradientCompressor::new();
        let grad = vec![0.0, 1.0, 0.0, -3.0, 0.0];
        let sparse = comp.compress(&grad, 0.6);
        let dense = comp.decompress(&sparse);
        assert_eq!(dense.len(), 5);
        assert!((dense[3] - (-3.0)).abs() < 1e-12);
        assert!((dense[1] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_decompress_empty_sparse() {
        let comp = GradientCompressor::new();
        let sparse = SparseGradient {
            indices: Vec::new(),
            values: Vec::new(),
            original_len: 5,
        };
        let dense = comp.decompress(&sparse);
        assert_eq!(dense, vec![0.0; 5]);
    }

    #[test]
    fn test_sparse_gradient_serialization() {
        let sg = SparseGradient {
            indices: vec![0, 2],
            values: vec![1.5, -2.5],
            original_len: 4,
        };
        let json = serde_json::to_string(&sg).expect("serialize");
        let sg2: SparseGradient = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(sg, sg2);
    }

    // ---- DataParallelTrainer ----

    #[test]
    fn test_split_batch_even() {
        let trainer = DataParallelTrainer::new();
        let samples: Vec<DistributedTrainingSample> = (0..8)
            .map(|i| DistributedTrainingSample::new(vec![i as f64], i as f64))
            .collect();
        let batches = trainer.split_batch(&samples, 4);
        assert_eq!(batches.len(), 4);
        for b in &batches {
            assert_eq!(b.len(), 2);
        }
    }

    #[test]
    fn test_split_batch_uneven() {
        let trainer = DataParallelTrainer::new();
        let samples: Vec<DistributedTrainingSample> = (0..10)
            .map(|i| DistributedTrainingSample::new(vec![i as f64], i as f64))
            .collect();
        let batches = trainer.split_batch(&samples, 3);
        assert_eq!(batches.len(), 3);
        // No sample is lost even when the split is uneven.
        let total: usize = batches.iter().map(|b| b.len()).sum();
        assert_eq!(total, 10);
    }

    #[test]
    fn test_split_batch_zero_workers() {
        let trainer = DataParallelTrainer::new();
        let samples = vec![DistributedTrainingSample::new(vec![1.0], 0.0)];
        let batches = trainer.split_batch(&samples, 0);
        assert!(batches.is_empty());
    }

    #[test]
    fn test_split_batch_empty_data() {
        let trainer = DataParallelTrainer::new();
        let batches = trainer.split_batch(&[], 4);
        assert!(batches.is_empty());
    }

    #[test]
    fn test_merge_worker_updates_basic() {
        let trainer = DataParallelTrainer::new();
        let updates = vec![
            WorkerUpdate {
                worker_id: 0,
                gradients: vec![2.0, 4.0],
                loss: 1.0,
                samples_processed: 10,
            },
            WorkerUpdate {
                worker_id: 1,
                gradients: vec![2.0, 4.0],
                loss: 1.0,
                samples_processed: 10,
            },
        ];
        let merged = trainer.merge_worker_updates(updates);
        assert_eq!(merged.total_samples, 20);
        assert!((merged.mean_loss - 1.0).abs() < 1e-9);
        assert!((merged.averaged_gradients[0] - 2.0).abs() < 1e-9);
        assert!((merged.averaged_gradients[1] - 4.0).abs() < 1e-9);
    }

    #[test]
    fn test_merge_worker_updates_weighted() {
        let trainer = DataParallelTrainer::new();
        // Worker 0 carries 1/4 of the samples, worker 1 carries 3/4.
        let updates = vec![
            WorkerUpdate {
                worker_id: 0,
                gradients: vec![4.0],
                loss: 2.0,
                samples_processed: 1,
            },
            WorkerUpdate {
                worker_id: 1,
                gradients: vec![0.0],
                loss: 0.0,
                samples_processed: 3,
            },
        ];
        let merged = trainer.merge_worker_updates(updates);
        assert_eq!(merged.total_samples, 4);
        // 4.0 * 1/4 + 0.0 * 3/4 = 1.0
        assert!((merged.averaged_gradients[0] - 1.0).abs() < 1e-9);
        // 2.0 * 1/4 + 0.0 * 3/4 = 0.5
        assert!((merged.mean_loss - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_merge_worker_updates_empty() {
        let trainer = DataParallelTrainer::new();
        let merged = trainer.merge_worker_updates(vec![]);
        assert_eq!(merged.total_samples, 0);
        assert!(merged.averaged_gradients.is_empty());
    }

    #[test]
    fn test_worker_update_serialization() {
        let update = WorkerUpdate {
            worker_id: 7,
            gradients: vec![0.1, -0.2],
            loss: 0.42,
            samples_processed: 32,
        };
        let json = serde_json::to_string(&update).expect("serialize");
        let update2: WorkerUpdate = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(update.worker_id, update2.worker_id);
        assert_eq!(update.samples_processed, update2.samples_processed);
    }

    #[test]
    fn test_model_update_fields() {
        let mu = ModelUpdate {
            averaged_gradients: vec![1.0, 2.0],
            mean_loss: 0.5,
            total_samples: 100,
        };
        assert_eq!(mu.total_samples, 100);
        assert!((mu.mean_loss - 0.5).abs() < 1e-12);
    }

    // ---- DistributedTrainingCoordinator (async) ----

    #[tokio::test]
    async fn test_distributed_coordinator_creation() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await;
        assert!(coordinator.is_ok());
    }

    #[tokio::test]
    async fn test_worker_registration() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await.unwrap();

        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };

        coordinator.register_worker(worker).await.unwrap();
        let stats = coordinator.get_worker_stats().await;
        assert_eq!(stats.len(), 1);
    }

    #[tokio::test]
    async fn test_distributed_training() {
        let config = DistributedTrainingConfig {
            strategy: DistributedStrategy::DataParallel {
                num_workers: 2,
                batch_size: 128,
            },
            ..Default::default()
        };

        let model_config = ModelConfig::default().with_dimensions(64);
        let model = TransE::new(model_config);

        let mut trainer = DistributedEmbeddingTrainer::new(model, config)
            .await
            .unwrap();

        // Register two workers so the training loop has somewhere to dispatch.
        for i in 0..2 {
            let worker = WorkerInfo {
                worker_id: i,
                rank: i,
                address: format!("127.0.0.1:808{}", i),
                status: WorkerStatus::Idle,
                num_gpus: 1,
                memory_gb: 16.0,
                last_heartbeat: Utc::now(),
            };
            trainer.register_worker(worker).await.unwrap();
        }

        let stats = trainer.train(5).await.unwrap();

        assert_eq!(stats.total_epochs, 5);
        assert!(stats.final_loss >= 0.0);
        assert_eq!(stats.num_workers, 2);
    }

    #[tokio::test]
    async fn test_checkpoint_save_load() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await.unwrap();

        let model_config = ModelConfig::default();
        let model = TransE::new(model_config);

        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };
        coordinator.register_worker(worker).await.unwrap();

        coordinator.save_checkpoint(&model, 10, 0.5).await.unwrap();

        let checkpoint = coordinator
            .load_checkpoint("checkpoint_epoch_10")
            .await
            .unwrap();
        assert_eq!(checkpoint.epoch, 10);
        assert_eq!(checkpoint.loss, 0.5);
    }
}