1use anyhow::Result;
33use chrono::{DateTime, Utc};
34use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36use std::sync::Arc;
37use tokio::sync::{Mutex, RwLock};
38use tracing::{debug, info, warn};
39
40use scirs2_core::distributed::{ClusterConfiguration, ClusterManager};
42use scirs2_core::ndarray_ext::Array1;
43
44use crate::EmbeddingModel;
45
46pub mod parameter_server;
48pub mod shard_manager;
49pub mod worker;
50
51pub use parameter_server::{
52 ParameterServer, ParameterServerConfig, ParameterServerStats, ShardSnapshot, UpdateMode,
53};
54pub use shard_manager::{ModelShardManager, ShardAssignment, ShardingStrategy};
55pub use worker::{TripleSample, Worker, WorkerConfig, WorkerLoss};
56
/// How training work is partitioned across the cluster.
///
/// NOTE(review): nothing in this file reads the selected variant yet; the
/// field descriptions below follow the field names — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistributedStrategy {
    /// Data-parallel training.
    DataParallel {
        /// Number of workers taking part in training.
        num_workers: usize,
        /// Batch size (per-worker vs. global is not established here — TODO confirm).
        batch_size: usize,
    },
    /// Model-parallel training.
    ModelParallel {
        /// Number of model shards.
        num_shards: usize,
        /// Number of pipeline stages.
        pipeline_stages: usize,
    },
    /// Combination of data and model parallelism.
    Hybrid {
        /// Size of the data-parallel dimension.
        data_parallel_size: usize,
        /// Size of the model-parallel dimension.
        model_parallel_size: usize,
    },
}
82
/// Gradient aggregation scheme selected in [`DistributedTrainingConfig`].
///
/// In this file the coordinator only logs which variant is configured.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
    /// All-reduce aggregation.
    AllReduce,
    /// Ring all-reduce aggregation.
    RingAllReduce,
    /// Parameter-server aggregation.
    ParameterServer {
        /// Number of parameter servers.
        num_servers: usize,
    },
    /// Hierarchical aggregation.
    Hierarchical {
        /// Branching factor of the aggregation hierarchy.
        branching_factor: usize,
    },
}
101
/// Communication backend selected in [`DistributedTrainingConfig`].
///
/// NOTE(review): the variant is not consumed anywhere in this file; the names
/// presumably refer to plain TCP, NCCL, Gloo and MPI — confirm against the
/// transport layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CommunicationBackend {
    /// TCP backend.
    Tcp,
    /// NCCL backend.
    Nccl,
    /// Gloo backend.
    Gloo,
    /// MPI backend.
    Mpi,
}
114
/// Fault-tolerance settings for distributed training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultToleranceConfig {
    /// When `true`, the coordinator saves checkpoints during training.
    pub enable_checkpointing: bool,
    /// A checkpoint is saved every `checkpoint_frequency` epochs.
    pub checkpoint_frequency: usize,
    /// Maximum number of retries (not read in this file).
    pub max_retries: usize,
    /// Whether elastic scaling is enabled (not read in this file).
    pub elastic_scaling: bool,
    /// Heartbeat interval (presumably seconds, like `worker_timeout` — TODO confirm).
    pub heartbeat_interval: u64,
    /// Seconds without a heartbeat after which a worker is reported as timed out.
    pub worker_timeout: u64,
}
131
132impl Default for FaultToleranceConfig {
133 fn default() -> Self {
134 Self {
135 enable_checkpointing: true,
136 checkpoint_frequency: 10,
137 max_retries: 3,
138 elastic_scaling: false,
139 heartbeat_interval: 30,
140 worker_timeout: 300,
141 }
142 }
143}
144
/// Top-level configuration for distributed training.
///
/// NOTE(review): within this file only `aggregation` (for logging) and
/// `fault_tolerance` are read; the remaining fields are described from their
/// names and should be confirmed against the code that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingConfig {
    /// Parallelism strategy.
    pub strategy: DistributedStrategy,
    /// Gradient aggregation method.
    pub aggregation: AggregationMethod,
    /// Communication backend.
    pub backend: CommunicationBackend,
    /// Checkpointing and worker-timeout settings.
    pub fault_tolerance: FaultToleranceConfig,
    /// Whether gradient compression is enabled.
    pub gradient_compression: bool,
    /// Compression ratio used when gradient compression is enabled.
    pub compression_ratio: f32,
    /// Whether mixed-precision training is enabled.
    pub mixed_precision: bool,
    /// Optional gradient clipping threshold; `None` disables clipping.
    pub gradient_clip: Option<f32>,
    /// Number of warm-up epochs.
    pub warmup_epochs: usize,
    /// Whether pipeline parallelism is enabled.
    pub pipeline_parallelism: bool,
    /// Number of micro-batches.
    pub num_microbatches: usize,
}
171
172impl Default for DistributedTrainingConfig {
173 fn default() -> Self {
174 Self {
175 strategy: DistributedStrategy::DataParallel {
176 num_workers: 4,
177 batch_size: 256,
178 },
179 aggregation: AggregationMethod::AllReduce,
180 backend: CommunicationBackend::Tcp,
181 fault_tolerance: FaultToleranceConfig::default(),
182 gradient_compression: false,
183 compression_ratio: 0.5,
184 mixed_precision: false,
185 gradient_clip: Some(1.0),
186 warmup_epochs: 5,
187 pipeline_parallelism: false,
188 num_microbatches: 4,
189 }
190 }
191}
192
/// Registration record the coordinator keeps for one worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerInfo {
    /// Identifier used as the key in the coordinator's worker map.
    pub worker_id: usize,
    /// Rank of the worker.
    pub rank: usize,
    /// Address of the worker (the tests in this file use `host:port` strings).
    pub address: String,
    /// Current status of the worker.
    pub status: WorkerStatus,
    /// Number of GPUs on the worker.
    pub num_gpus: usize,
    /// Memory available on the worker (presumably gigabytes — TODO confirm).
    pub memory_gb: f32,
    /// Time of the last heartbeat; refreshed whenever the status is updated.
    pub last_heartbeat: DateTime<Utc>,
}
211
/// Lifecycle state of a worker as tracked by the coordinator.
///
/// NOTE(review): no state transitions are enforced in this file; the status is
/// whatever the caller passes to `update_worker_status`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WorkerStatus {
    /// Worker is idle.
    Idle,
    /// Worker is training.
    Training,
    /// Worker is synchronizing.
    Synchronizing,
    /// Worker has failed.
    Failed,
    /// Worker is recovering.
    Recovering,
}
226
/// Snapshot of training state held in memory by the coordinator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingCheckpoint {
    /// Identifier; the coordinator generates `checkpoint_epoch_<epoch>`.
    pub checkpoint_id: String,
    /// Epoch value passed to the coordinator when the checkpoint was saved.
    pub epoch: usize,
    /// Global step; the coordinator currently stores `epoch * 1000`.
    pub global_step: usize,
    /// Serialized model state (the coordinator currently stores an empty buffer).
    pub model_state: Vec<u8>,
    /// Serialized optimizer state (the coordinator currently stores an empty buffer).
    pub optimizer_state: Vec<u8>,
    /// Loss recorded with the checkpoint.
    pub loss: f64,
    /// Creation time (UTC).
    pub timestamp: DateTime<Utc>,
}
245
/// Aggregate statistics maintained by the coordinator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingStats {
    /// Number of epochs completed in the most recent training run.
    pub total_epochs: usize,
    /// Total steps (initialised to 0; never updated in this file).
    pub total_steps: usize,
    /// Average loss of the most recently completed epoch.
    pub final_loss: f64,
    /// Wall-clock duration of the most recent training run, in seconds.
    pub training_time: f64,
    /// Number of currently registered workers.
    pub num_workers: usize,
    /// Epochs per second over the most recent training run.
    pub throughput: f64,
    /// Seconds spent in gradient aggregation during the most recent run.
    pub communication_time: f64,
    /// Seconds spent in batch distribution/computation during the most recent run.
    pub computation_time: f64,
    /// Number of checkpoints saved.
    pub num_checkpoints: usize,
    /// Incremented each time a worker is deregistered.
    pub num_failures: usize,
    /// Average loss per epoch, appended in the order epochs complete.
    pub loss_history: Vec<f64>,
}
272
/// Coordinates distributed training: tracks workers, checkpoints and statistics.
pub struct DistributedTrainingCoordinator {
    /// Training configuration supplied at construction.
    config: DistributedTrainingConfig,
    /// Registered workers keyed by `worker_id`.
    workers: Arc<RwLock<HashMap<usize, WorkerInfo>>>,
    /// Checkpoints saved so far, kept in memory in insertion order.
    checkpoints: Arc<Mutex<Vec<TrainingCheckpoint>>>,
    /// Cluster manager created from the default cluster configuration.
    cluster_manager: Arc<ClusterManager>,
    /// Running statistics.
    stats: Arc<Mutex<DistributedTrainingStats>>,
}
281
282impl DistributedTrainingCoordinator {
283 pub async fn new(config: DistributedTrainingConfig) -> Result<Self> {
285 info!("Initializing distributed training coordinator");
286
287 let cluster_config = ClusterConfiguration::default();
289 let cluster_manager = Arc::new(
290 ClusterManager::new(cluster_config)
291 .map_err(|e| anyhow::anyhow!("Failed to create cluster manager: {}", e))?,
292 );
293
294 Ok(Self {
295 config,
296 workers: Arc::new(RwLock::new(HashMap::new())),
297 checkpoints: Arc::new(Mutex::new(Vec::new())),
298 cluster_manager,
299 stats: Arc::new(Mutex::new(DistributedTrainingStats {
300 total_epochs: 0,
301 total_steps: 0,
302 final_loss: 0.0,
303 training_time: 0.0,
304 num_workers: 0,
305 throughput: 0.0,
306 communication_time: 0.0,
307 computation_time: 0.0,
308 num_checkpoints: 0,
309 num_failures: 0,
310 loss_history: Vec::new(),
311 })),
312 })
313 }
314
315 pub async fn register_worker(&self, worker_info: WorkerInfo) -> Result<()> {
317 info!(
318 "Registering worker {}: {}",
319 worker_info.worker_id, worker_info.address
320 );
321
322 let mut workers = self.workers.write().await;
323 workers.insert(worker_info.worker_id, worker_info);
324
325 let mut stats = self.stats.lock().await;
326 stats.num_workers = workers.len();
327
328 Ok(())
329 }
330
331 pub async fn deregister_worker(&self, worker_id: usize) -> Result<()> {
333 warn!("Deregistering worker {}", worker_id);
334
335 let mut workers = self.workers.write().await;
336 workers.remove(&worker_id);
337
338 let mut stats = self.stats.lock().await;
339 stats.num_workers = workers.len();
340 stats.num_failures += 1;
341
342 Ok(())
343 }
344
345 pub async fn update_worker_status(&self, worker_id: usize, status: WorkerStatus) -> Result<()> {
347 let mut workers = self.workers.write().await;
348 if let Some(worker) = workers.get_mut(&worker_id) {
349 worker.status = status;
350 worker.last_heartbeat = Utc::now();
351 }
352 Ok(())
353 }
354
355 pub async fn train<M: EmbeddingModel>(
357 &mut self,
358 model: &mut M,
359 epochs: usize,
360 ) -> Result<DistributedTrainingStats> {
361 info!("Starting distributed training for {} epochs", epochs);
362
363 let start_time = std::time::Instant::now();
364 let mut total_comm_time = 0.0;
365 let mut total_comp_time = 0.0;
366
367 self.initialize_optimizer().await?;
369
370 for epoch in 0..epochs {
371 debug!("Epoch {}/{}", epoch + 1, epochs);
372
373 let comp_start = std::time::Instant::now();
375 let batch_results = self.distribute_training_batch(model, epoch).await?;
376 let comp_time = comp_start.elapsed().as_secs_f64();
377 total_comp_time += comp_time;
378
379 let comm_start = std::time::Instant::now();
381 let avg_loss = self.aggregate_gradients(&batch_results).await?;
382 let comm_time = comm_start.elapsed().as_secs_f64();
383 total_comm_time += comm_time;
384
385 {
387 let mut stats = self.stats.lock().await;
388 stats.total_epochs = epoch + 1;
389 stats.loss_history.push(avg_loss);
390 stats.final_loss = avg_loss;
391 }
392
393 if self.config.fault_tolerance.enable_checkpointing
395 && (epoch + 1) % self.config.fault_tolerance.checkpoint_frequency == 0
396 {
397 self.save_checkpoint(model, epoch, avg_loss).await?;
398 }
399
400 info!(
401 "Epoch {}: loss={:.6}, comp_time={:.2}s, comm_time={:.2}s",
402 epoch + 1,
403 avg_loss,
404 comp_time,
405 comm_time
406 );
407 }
408
409 let elapsed = start_time.elapsed().as_secs_f64();
410
411 let stats = {
413 let mut stats = self.stats.lock().await;
414 stats.training_time = elapsed;
415 stats.communication_time = total_comm_time;
416 stats.computation_time = total_comp_time;
417 stats.throughput = (epochs as f64) / elapsed;
418 stats.clone()
419 };
420
421 info!("Distributed training completed in {:.2}s", elapsed);
422 info!("Final loss: {:.6}", stats.final_loss);
423 info!("Throughput: {:.2} epochs/sec", stats.throughput);
424
425 Ok(stats)
426 }
427
428 async fn initialize_optimizer(&mut self) -> Result<()> {
430 debug!("Initializing distributed optimizer");
431
432 Ok(())
436 }
437
438 async fn distribute_training_batch<M: EmbeddingModel>(
440 &self,
441 _model: &M,
442 epoch: usize,
443 ) -> Result<Vec<WorkerResult>> {
444 let workers = self.workers.read().await;
445 let num_workers = workers.len();
446
447 if num_workers == 0 {
448 return Err(anyhow::anyhow!("No workers available"));
449 }
450
451 let mut results = Vec::new();
454 for (worker_id, _) in workers.iter() {
455 results.push(WorkerResult {
456 worker_id: *worker_id,
457 epoch,
458 loss: 0.1 * (1.0 - epoch as f64 / 100.0).max(0.01),
459 num_samples: 1000,
460 gradients: HashMap::new(),
461 });
462 }
463
464 Ok(results)
465 }
466
467 async fn aggregate_gradients(&self, results: &[WorkerResult]) -> Result<f64> {
469 if results.is_empty() {
470 return Err(anyhow::anyhow!("No results to aggregate"));
471 }
472
473 let avg_loss = results.iter().map(|r| r.loss).sum::<f64>() / results.len() as f64;
475
476 match &self.config.aggregation {
479 AggregationMethod::AllReduce => {
480 debug!("Using AllReduce for gradient aggregation");
481 }
484 AggregationMethod::RingAllReduce => {
485 debug!("Using Ring-AllReduce for gradient aggregation");
486 }
488 AggregationMethod::ParameterServer { num_servers } => {
489 debug!("Using Parameter Server with {} servers", num_servers);
490 }
492 AggregationMethod::Hierarchical { branching_factor } => {
493 debug!(
494 "Using Hierarchical aggregation with branching factor {}",
495 branching_factor
496 );
497 }
499 }
500
501 Ok(avg_loss)
502 }
503
504 async fn save_checkpoint<M: EmbeddingModel>(
506 &self,
507 _model: &M,
508 epoch: usize,
509 loss: f64,
510 ) -> Result<()> {
511 info!("Saving checkpoint at epoch {}", epoch);
512
513 let checkpoint = TrainingCheckpoint {
514 checkpoint_id: format!("checkpoint_epoch_{}", epoch),
515 epoch,
516 global_step: epoch * 1000, model_state: Vec::new(), optimizer_state: Vec::new(), loss,
520 timestamp: Utc::now(),
521 };
522
523 let mut checkpoints = self.checkpoints.lock().await;
524 checkpoints.push(checkpoint);
525
526 let mut stats = self.stats.lock().await;
527 stats.num_checkpoints += 1;
528
529 Ok(())
530 }
531
532 pub async fn load_checkpoint(&self, checkpoint_id: &str) -> Result<TrainingCheckpoint> {
534 let checkpoints = self.checkpoints.lock().await;
535 checkpoints
536 .iter()
537 .find(|c| c.checkpoint_id == checkpoint_id)
538 .cloned()
539 .ok_or_else(|| anyhow::anyhow!("Checkpoint not found: {}", checkpoint_id))
540 }
541
542 pub async fn get_worker_stats(&self) -> HashMap<usize, WorkerInfo> {
544 self.workers.read().await.clone()
545 }
546
547 pub async fn get_stats(&self) -> DistributedTrainingStats {
549 self.stats.lock().await.clone()
550 }
551
552 pub async fn monitor_workers(&self) -> Result<()> {
554 let timeout_duration =
555 std::time::Duration::from_secs(self.config.fault_tolerance.worker_timeout);
556
557 let workers = self.workers.read().await;
558 let now = Utc::now();
559
560 for (worker_id, worker) in workers.iter() {
561 let elapsed = now.signed_duration_since(worker.last_heartbeat);
562 if elapsed.num_seconds() as u64 > timeout_duration.as_secs() {
563 warn!(
564 "Worker {} timed out (last heartbeat: {:?})",
565 worker_id, worker.last_heartbeat
566 );
567 }
569 }
570
571 Ok(())
572 }
573}
574
/// Result reported by a single worker for one epoch (internal to the coordinator).
#[derive(Debug, Clone)]
struct WorkerResult {
    /// Id of the worker that produced the result.
    worker_id: usize,
    /// Epoch the result belongs to.
    epoch: usize,
    /// Loss reported by the worker.
    loss: f64,
    /// Number of samples the worker processed.
    num_samples: usize,
    /// Gradients keyed by name (presumably parameter names — TODO confirm;
    /// the coordinator currently always passes an empty map).
    gradients: HashMap<String, Array1<f32>>,
}
584
/// Strategy selector for [`GradientAggregator::aggregate_gradients`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AllReduceStrategy {
    /// Aggregate via the simulated ring all-reduce.
    RingAllReduce,
    /// Aggregate via the tree all-reduce (element-wise mean).
    TreeAllReduce,
    /// Parameter-server mode: the local gradient is returned unchanged.
    ParameterServer,
}
599
/// Stateless helper that averages gradient vectors using in-process
/// all-reduce implementations.
#[derive(Debug, Clone, Default)]
pub struct GradientAggregator;
603
604impl GradientAggregator {
605 pub fn new() -> Self {
607 Self
608 }
609
610 pub fn aggregate_gradients(
617 &self,
618 local_grad: &[f64],
619 strategy: &AllReduceStrategy,
620 ) -> Vec<f64> {
621 match strategy {
622 AllReduceStrategy::RingAllReduce => {
623 self.ring_all_reduce(vec![local_grad.to_vec()])
625 }
626 AllReduceStrategy::TreeAllReduce => self.tree_all_reduce(vec![local_grad.to_vec()]),
627 AllReduceStrategy::ParameterServer => {
628 local_grad.to_vec()
630 }
631 }
632 }
633
634 pub fn ring_all_reduce(&self, gradients: Vec<Vec<f64>>) -> Vec<f64> {
649 let n = gradients.len();
650 if n == 0 {
651 return Vec::new();
652 }
653 if n == 1 {
654 return gradients.into_iter().next().unwrap_or_default();
655 }
656
657 let len = gradients[0].len();
658 if len == 0 {
659 return Vec::new();
660 }
661
662 let base = len / n;
665 let remainder = len % n;
666 let chunk_sizes: Vec<usize> = (0..n)
667 .map(|i| base + if i < remainder { 1 } else { 0 })
668 .collect();
669 let mut chunk_start = vec![0usize; n];
670 for i in 1..n {
671 chunk_start[i] = chunk_start[i - 1] + chunk_sizes[i - 1];
672 }
673
674 let mut partial: Vec<Vec<Vec<f64>>> = gradients
677 .iter()
678 .map(|g| {
679 chunk_sizes
680 .iter()
681 .zip(chunk_start.iter())
682 .map(|(&sz, &s)| g[s..s + sz].to_vec())
683 .collect()
684 })
685 .collect();
686
687 #[allow(clippy::needless_range_loop)]
691 for step in 0..(n - 1) {
692 let prev = partial.clone();
693 for w in 0..n {
694 let left = (w + n - 1) % n;
695 let c = (w + n - 1 - step) % n;
696 let sz = chunk_sizes[c];
697 for i in 0..sz {
698 partial[w][c][i] += prev[left][c][i];
699 }
700 }
701 }
702
703 let mut result = vec![0.0_f64; len];
708 #[allow(clippy::needless_range_loop)]
709 for w in 0..n {
710 let c = (w + 1) % n;
711 let s = chunk_start[c];
712 let sz = chunk_sizes[c];
713 for i in 0..sz {
714 result[s + i] = partial[w][c][i] / n as f64;
715 }
716 }
717
718 result
719 }
720
721 fn tree_all_reduce(&self, gradients: Vec<Vec<f64>>) -> Vec<f64> {
723 let n = gradients.len();
724 if n == 0 {
725 return Vec::new();
726 }
727 if n == 1 {
728 return gradients.into_iter().next().unwrap_or_default();
729 }
730
731 let len = gradients[0].len();
732 let mut sums = vec![0.0_f64; len];
733 for grad in &gradients {
734 for (i, v) in grad.iter().enumerate() {
735 if i < len {
736 sums[i] += v;
737 }
738 }
739 }
740 sums.iter_mut().for_each(|v| *v /= n as f64);
741 sums
742 }
743}
744
/// Sparse representation of a gradient: selected entries plus the dense length.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseGradient {
    /// Positions of the retained entries in the dense gradient.
    pub indices: Vec<usize>,
    /// Values of the retained entries, parallel to `indices`.
    pub values: Vec<f64>,
    /// Length of the dense gradient this was produced from.
    pub original_len: usize,
}
755
/// Stateless top-k gradient sparsifier producing and expanding [`SparseGradient`]s.
#[derive(Debug, Clone, Default)]
pub struct GradientCompressor;
759
760impl GradientCompressor {
761 pub fn new() -> Self {
763 Self
764 }
765
766 pub fn compress(&self, grad: &[f64], sparsity: f64) -> SparseGradient {
771 let sparsity = sparsity.clamp(0.0, 0.9999);
772 let n = grad.len();
773 if n == 0 {
774 return SparseGradient {
775 indices: Vec::new(),
776 values: Vec::new(),
777 original_len: 0,
778 };
779 }
780
781 let keep = ((1.0 - sparsity) * n as f64).ceil() as usize;
782 let keep = keep.max(1).min(n);
783
784 let mut indexed: Vec<(usize, f64)> = grad
786 .iter()
787 .enumerate()
788 .map(|(i, &v)| (i, v.abs()))
789 .collect();
790 indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
791
792 let mut indices: Vec<usize> = indexed[..keep].iter().map(|(i, _)| *i).collect();
793 indices.sort_unstable();
794
795 let values: Vec<f64> = indices.iter().map(|&i| grad[i]).collect();
796
797 SparseGradient {
798 indices,
799 values,
800 original_len: n,
801 }
802 }
803
804 pub fn decompress(&self, sparse: &SparseGradient) -> Vec<f64> {
806 let mut dense = vec![0.0_f64; sparse.original_len];
807 for (&idx, &val) in sparse.indices.iter().zip(sparse.values.iter()) {
808 if idx < sparse.original_len {
809 dense[idx] = val;
810 }
811 }
812 dense
813 }
814}
815
/// A single training sample handed to data-parallel workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingSample {
    /// Feature vector.
    pub features: Vec<f64>,
    /// Target label.
    pub label: f64,
    /// Optional per-sample weight; `None` when unspecified.
    pub weight: Option<f64>,
}
832
833impl DistributedTrainingSample {
834 pub fn new(features: Vec<f64>, label: f64) -> Self {
836 Self {
837 features,
838 label,
839 weight: None,
840 }
841 }
842}
843
/// Update produced by one worker, consumed by
/// [`DataParallelTrainer::merge_worker_updates`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerUpdate {
    /// Id of the reporting worker.
    pub worker_id: u32,
    /// Gradient vector computed by the worker.
    pub gradients: Vec<f64>,
    /// Loss reported by the worker.
    pub loss: f64,
    /// Number of samples the worker processed; used as the merge weight.
    pub samples_processed: u32,
}
856
/// Result of merging several [`WorkerUpdate`]s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUpdate {
    /// Sample-weighted average of the workers' gradients.
    pub averaged_gradients: Vec<f64>,
    /// Sample-weighted average of the workers' losses.
    pub mean_loss: f64,
    /// Total number of samples across all merged updates.
    pub total_samples: u32,
}
867
/// Stateless helper for data-parallel training: splits batches across workers
/// and merges their updates.
#[derive(Debug, Clone, Default)]
pub struct DataParallelTrainer;
871
872impl DataParallelTrainer {
873 pub fn new() -> Self {
875 Self
876 }
877
878 pub fn split_batch(
884 &self,
885 data: &[DistributedTrainingSample],
886 n_workers: u32,
887 ) -> Vec<Vec<DistributedTrainingSample>> {
888 let n = n_workers as usize;
889 if n == 0 || data.is_empty() {
890 return Vec::new();
891 }
892
893 let mut buckets: Vec<Vec<DistributedTrainingSample>> = (0..n).map(|_| Vec::new()).collect();
894 for (i, sample) in data.iter().enumerate() {
895 buckets[i % n].push(sample.clone());
896 }
897 buckets
898 }
899
900 pub fn merge_worker_updates(&self, updates: Vec<WorkerUpdate>) -> ModelUpdate {
905 if updates.is_empty() {
906 return ModelUpdate {
907 averaged_gradients: Vec::new(),
908 mean_loss: 0.0,
909 total_samples: 0,
910 };
911 }
912
913 let total_samples: u32 = updates.iter().map(|u| u.samples_processed).sum();
914 if total_samples == 0 {
915 return ModelUpdate {
916 averaged_gradients: Vec::new(),
917 mean_loss: 0.0,
918 total_samples: 0,
919 };
920 }
921
922 let grad_len = updates.iter().map(|u| u.gradients.len()).max().unwrap_or(0);
924
925 let mut averaged_gradients = vec![0.0_f64; grad_len];
926 let mut weighted_loss = 0.0_f64;
927
928 for update in &updates {
929 let weight = update.samples_processed as f64 / total_samples as f64;
930 for (i, &g) in update.gradients.iter().enumerate() {
931 if i < grad_len {
932 averaged_gradients[i] += g * weight;
933 }
934 }
935 weighted_loss += update.loss * weight;
936 }
937
938 ModelUpdate {
939 averaged_gradients,
940 mean_loss: weighted_loss,
941 total_samples,
942 }
943 }
944}
945
/// Couples an embedding model with the coordinator that trains it.
pub struct DistributedEmbeddingTrainer<M: EmbeddingModel> {
    /// The model being trained.
    model: M,
    /// Coordinator that drives the training run and tracks workers and stats.
    coordinator: DistributedTrainingCoordinator,
}
951
952impl<M: EmbeddingModel> DistributedEmbeddingTrainer<M> {
953 pub async fn new(model: M, config: DistributedTrainingConfig) -> Result<Self> {
955 let coordinator = DistributedTrainingCoordinator::new(config).await?;
956
957 Ok(Self { model, coordinator })
958 }
959
960 pub async fn train(&mut self, epochs: usize) -> Result<DistributedTrainingStats> {
962 self.coordinator.train(&mut self.model, epochs).await
963 }
964
965 pub fn model(&self) -> &M {
967 &self.model
968 }
969
970 pub fn model_mut(&mut self) -> &mut M {
972 &mut self.model
973 }
974
975 pub async fn register_worker(&self, worker_info: WorkerInfo) -> Result<()> {
977 self.coordinator.register_worker(worker_info).await
978 }
979
980 pub async fn get_stats(&self) -> DistributedTrainingStats {
982 self.coordinator.get_stats().await
983 }
984}
985
/// Unit tests for the gradient aggregator, gradient compressor, data-parallel
/// helper and the async training coordinator.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModelConfig, TransE};

    // ---- GradientAggregator ----

    // Every strategy must return a vector of the input length.
    #[test]
    fn test_all_reduce_strategy_variants() {
        let strategies = [
            AllReduceStrategy::RingAllReduce,
            AllReduceStrategy::TreeAllReduce,
            AllReduceStrategy::ParameterServer,
        ];
        for s in &strategies {
            let agg = GradientAggregator::new();
            let grad = vec![1.0, 2.0, 3.0];
            let result = agg.aggregate_gradients(&grad, s);
            assert_eq!(result.len(), 3);
        }
    }

    // A single worker's gradient is returned unchanged.
    #[test]
    fn test_ring_all_reduce_single_worker() {
        let agg = GradientAggregator::new();
        let grads = vec![vec![1.0, 2.0, 3.0]];
        let result = agg.ring_all_reduce(grads);
        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }

    // Two identical gradients average to themselves.
    #[test]
    fn test_ring_all_reduce_two_workers() {
        let agg = GradientAggregator::new();
        let grads = vec![vec![2.0, 4.0, 6.0], vec![2.0, 4.0, 6.0]];
        let result = agg.ring_all_reduce(grads);
        assert_eq!(result.len(), 3);
        for (r, expected) in result.iter().zip([2.0, 4.0, 6.0].iter()) {
            assert!((r - expected).abs() < 1e-9, "expected {expected}, got {r}");
        }
    }

    // More workers than elements: some chunks are empty; the mean is (3, 6).
    #[test]
    fn test_ring_all_reduce_four_workers_mean() {
        let agg = GradientAggregator::new();
        let grads = vec![
            vec![4.0, 8.0],
            vec![2.0, 4.0],
            vec![0.0, 0.0],
            vec![6.0, 12.0],
        ];
        let result = agg.ring_all_reduce(grads);
        assert_eq!(result.len(), 2);
        assert!((result[0] - 3.0).abs() < 1e-6);
        assert!((result[1] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_ring_all_reduce_empty_input() {
        let agg = GradientAggregator::new();
        let result = agg.ring_all_reduce(vec![]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_ring_all_reduce_empty_gradient_vectors() {
        let agg = GradientAggregator::new();
        let result = agg.ring_all_reduce(vec![vec![], vec![]]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_aggregate_gradients_ring() {
        let agg = GradientAggregator::new();
        let grad = vec![1.0, 2.0, 3.0, 4.0];
        let result = agg.aggregate_gradients(&grad, &AllReduceStrategy::RingAllReduce);
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_aggregate_gradients_tree() {
        let agg = GradientAggregator::new();
        let grad = vec![5.0, 10.0];
        let result = agg.aggregate_gradients(&grad, &AllReduceStrategy::TreeAllReduce);
        assert_eq!(result, vec![5.0, 10.0]);
    }

    #[test]
    fn test_aggregate_gradients_parameter_server() {
        let agg = GradientAggregator::new();
        let grad = vec![3.0, 1.0, 4.0];
        let result = agg.aggregate_gradients(&grad, &AllReduceStrategy::ParameterServer);
        assert_eq!(result, grad);
    }

    // ---- GradientCompressor ----

    #[test]
    fn test_compress_empty_gradient() {
        let comp = GradientCompressor::new();
        let sparse = comp.compress(&[], 0.9);
        assert!(sparse.indices.is_empty());
        assert_eq!(sparse.original_len, 0);
    }

    // Sparsity 0.0 keeps every entry.
    #[test]
    fn test_compress_keep_all() {
        let comp = GradientCompressor::new();
        let grad = vec![1.0, -2.0, 3.0, -4.0];
        let sparse = comp.compress(&grad, 0.0);
        assert_eq!(sparse.indices.len(), 4);
        assert_eq!(sparse.original_len, 4);
    }

    // Sparsity 0.6 on 5 entries keeps the 2 largest magnitudes (9.0 and 5.0).
    #[test]
    fn test_compress_top_k_selects_largest() {
        let comp = GradientCompressor::new();
        let grad = vec![0.1, 5.0, 0.2, 9.0, 0.3];
        let sparse = comp.compress(&grad, 0.6);
        assert_eq!(sparse.indices.len(), 2);
        assert!(sparse.indices.contains(&3));
        assert!(sparse.indices.contains(&1));
    }

    // The two non-zero entries survive compression and decompression.
    #[test]
    fn test_decompress_roundtrip() {
        let comp = GradientCompressor::new();
        let grad = vec![0.0, 1.0, 0.0, -3.0, 0.0];
        let sparse = comp.compress(&grad, 0.6);
        let dense = comp.decompress(&sparse);
        assert_eq!(dense.len(), 5);
        assert!((dense[3] - (-3.0)).abs() < 1e-12);
        assert!((dense[1] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_decompress_empty_sparse() {
        let comp = GradientCompressor::new();
        let sparse = SparseGradient {
            indices: Vec::new(),
            values: Vec::new(),
            original_len: 5,
        };
        let dense = comp.decompress(&sparse);
        assert_eq!(dense, vec![0.0; 5]);
    }

    #[test]
    fn test_sparse_gradient_serialization() {
        let sg = SparseGradient {
            indices: vec![0, 2],
            values: vec![1.5, -2.5],
            original_len: 4,
        };
        let json = serde_json::to_string(&sg).expect("serialize");
        let sg2: SparseGradient = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(sg, sg2);
    }

    // ---- DataParallelTrainer ----

    #[test]
    fn test_split_batch_even() {
        let trainer = DataParallelTrainer::new();
        let samples: Vec<DistributedTrainingSample> = (0..8)
            .map(|i| DistributedTrainingSample::new(vec![i as f64], i as f64))
            .collect();
        let batches = trainer.split_batch(&samples, 4);
        assert_eq!(batches.len(), 4);
        for b in &batches {
            assert_eq!(b.len(), 2);
        }
    }

    // No sample may be lost when the split is uneven.
    #[test]
    fn test_split_batch_uneven() {
        let trainer = DataParallelTrainer::new();
        let samples: Vec<DistributedTrainingSample> = (0..10)
            .map(|i| DistributedTrainingSample::new(vec![i as f64], i as f64))
            .collect();
        let batches = trainer.split_batch(&samples, 3);
        assert_eq!(batches.len(), 3);
        let total: usize = batches.iter().map(|b| b.len()).sum();
        assert_eq!(total, 10);
    }

    #[test]
    fn test_split_batch_zero_workers() {
        let trainer = DataParallelTrainer::new();
        let samples = vec![DistributedTrainingSample::new(vec![1.0], 0.0)];
        let batches = trainer.split_batch(&samples, 0);
        assert!(batches.is_empty());
    }

    #[test]
    fn test_split_batch_empty_data() {
        let trainer = DataParallelTrainer::new();
        let batches = trainer.split_batch(&[], 4);
        assert!(batches.is_empty());
    }

    // Equal sample counts: the merge is a plain average.
    #[test]
    fn test_merge_worker_updates_basic() {
        let trainer = DataParallelTrainer::new();
        let updates = vec![
            WorkerUpdate {
                worker_id: 0,
                gradients: vec![2.0, 4.0],
                loss: 1.0,
                samples_processed: 10,
            },
            WorkerUpdate {
                worker_id: 1,
                gradients: vec![2.0, 4.0],
                loss: 1.0,
                samples_processed: 10,
            },
        ];
        let merged = trainer.merge_worker_updates(updates);
        assert_eq!(merged.total_samples, 20);
        assert!((merged.mean_loss - 1.0).abs() < 1e-9);
        assert!((merged.averaged_gradients[0] - 2.0).abs() < 1e-9);
        assert!((merged.averaged_gradients[1] - 4.0).abs() < 1e-9);
    }

    // Weights 1/4 and 3/4: gradient 4.0 * 0.25 = 1.0, loss 2.0 * 0.25 = 0.5.
    #[test]
    fn test_merge_worker_updates_weighted() {
        let trainer = DataParallelTrainer::new();
        let updates = vec![
            WorkerUpdate {
                worker_id: 0,
                gradients: vec![4.0],
                loss: 2.0,
                samples_processed: 1,
            },
            WorkerUpdate {
                worker_id: 1,
                gradients: vec![0.0],
                loss: 0.0,
                samples_processed: 3,
            },
        ];
        let merged = trainer.merge_worker_updates(updates);
        assert_eq!(merged.total_samples, 4);
        assert!((merged.averaged_gradients[0] - 1.0).abs() < 1e-9);
        assert!((merged.mean_loss - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_merge_worker_updates_empty() {
        let trainer = DataParallelTrainer::new();
        let merged = trainer.merge_worker_updates(vec![]);
        assert_eq!(merged.total_samples, 0);
        assert!(merged.averaged_gradients.is_empty());
    }

    #[test]
    fn test_worker_update_serialization() {
        let update = WorkerUpdate {
            worker_id: 7,
            gradients: vec![0.1, -0.2],
            loss: 0.42,
            samples_processed: 32,
        };
        let json = serde_json::to_string(&update).expect("serialize");
        let update2: WorkerUpdate = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(update.worker_id, update2.worker_id);
        assert_eq!(update.samples_processed, update2.samples_processed);
    }

    #[test]
    fn test_model_update_fields() {
        let mu = ModelUpdate {
            averaged_gradients: vec![1.0, 2.0],
            mean_loss: 0.5,
            total_samples: 100,
        };
        assert_eq!(mu.total_samples, 100);
        assert!((mu.mean_loss - 0.5).abs() < 1e-12);
    }

    // ---- DistributedTrainingCoordinator / DistributedEmbeddingTrainer ----

    #[tokio::test]
    async fn test_distributed_coordinator_creation() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await;
        assert!(coordinator.is_ok());
    }

    #[tokio::test]
    async fn test_worker_registration() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config)
            .await
            .expect("should succeed");

        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };

        coordinator
            .register_worker(worker)
            .await
            .expect("should succeed");
        let stats = coordinator.get_worker_stats().await;
        assert_eq!(stats.len(), 1);
    }

    // End-to-end run: two registered workers, five epochs.
    #[tokio::test]
    async fn test_distributed_training() {
        let config = DistributedTrainingConfig {
            strategy: DistributedStrategy::DataParallel {
                num_workers: 2,
                batch_size: 128,
            },
            ..Default::default()
        };

        let model_config = ModelConfig::default().with_dimensions(64);
        let model = TransE::new(model_config);

        let mut trainer = DistributedEmbeddingTrainer::new(model, config)
            .await
            .expect("should succeed");

        for i in 0..2 {
            let worker = WorkerInfo {
                worker_id: i,
                rank: i,
                address: format!("127.0.0.1:808{}", i),
                status: WorkerStatus::Idle,
                num_gpus: 1,
                memory_gb: 16.0,
                last_heartbeat: Utc::now(),
            };
            trainer
                .register_worker(worker)
                .await
                .expect("should succeed");
        }

        let stats = trainer.train(5).await.expect("should succeed");

        assert_eq!(stats.total_epochs, 5);
        assert!(stats.final_loss >= 0.0);
        assert_eq!(stats.num_workers, 2);
    }

    // A saved checkpoint can be loaded back by its generated id.
    #[tokio::test]
    async fn test_checkpoint_save_load() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config)
            .await
            .expect("should succeed");

        let model_config = ModelConfig::default();
        let model = TransE::new(model_config);

        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };
        coordinator
            .register_worker(worker)
            .await
            .expect("should succeed");

        coordinator
            .save_checkpoint(&model, 10, 0.5)
            .await
            .expect("should succeed");

        let checkpoint = coordinator
            .load_checkpoint("checkpoint_epoch_10")
            .await
            .expect("should succeed");
        assert_eq!(checkpoint.epoch, 10);
        assert_eq!(checkpoint.loss, 0.5);
    }
}