use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, info, warn};

use scirs2_core::distributed::{ClusterConfiguration, ClusterManager};
use scirs2_core::ndarray_ext::Array1;

use crate::EmbeddingModel;

/// Strategy used to parallelize training across workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistributedStrategy {
    /// Replicate the model on every worker and split each batch across them.
    DataParallel {
        /// Number of data-parallel workers.
        num_workers: usize,
        /// Training batch size.
        batch_size: usize,
    },
    /// Partition the model itself into shards and pipeline stages.
    ModelParallel {
        /// Number of model shards.
        num_shards: usize,
        /// Number of pipeline stages.
        pipeline_stages: usize,
    },
    /// Combine data parallelism and model parallelism.
    Hybrid {
        /// Size of each data-parallel group.
        data_parallel_size: usize,
        /// Size of each model-parallel group.
        model_parallel_size: usize,
    },
}

/// Method used to aggregate gradients across workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
    /// Standard all-reduce across all workers.
    AllReduce,
    /// Bandwidth-efficient ring all-reduce.
    RingAllReduce,
    /// Centralized aggregation through parameter servers.
    ParameterServer {
        /// Number of parameter servers.
        num_servers: usize,
    },
    /// Tree-structured aggregation.
    Hierarchical {
        /// Branching factor of the aggregation tree.
        branching_factor: usize,
    },
}

/// Communication backend used for inter-worker transport.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CommunicationBackend {
    /// Plain TCP sockets.
    Tcp,
    /// NVIDIA NCCL collectives.
    Nccl,
    /// Gloo collective communication library.
    Gloo,
    /// MPI-based communication.
    Mpi,
}

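/// Fault-tolerance settings for distributed training.
///
/// # Example
///
/// A minimal sketch of overriding a few fields while keeping the defaults for
/// the rest (uses only types from this module; import paths are assumed to be
/// re-exported by the crate root):
///
/// ```ignore
/// let ft = FaultToleranceConfig {
///     checkpoint_frequency: 5,
///     elastic_scaling: true,
///     ..FaultToleranceConfig::default()
/// };
/// assert!(ft.enable_checkpointing);
/// ```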
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultToleranceConfig {
    /// Whether periodic checkpointing is enabled.
    pub enable_checkpointing: bool,
    /// How often to checkpoint, in epochs.
    pub checkpoint_frequency: usize,
    /// Maximum number of retries for a failed worker.
    pub max_retries: usize,
    /// Whether workers may join or leave during training.
    pub elastic_scaling: bool,
    /// Heartbeat interval, in seconds.
    pub heartbeat_interval: u64,
    /// Worker timeout, in seconds.
    pub worker_timeout: u64,
}

impl Default for FaultToleranceConfig {
    fn default() -> Self {
        Self {
            enable_checkpointing: true,
            checkpoint_frequency: 10,
            max_retries: 3,
            elastic_scaling: false,
            heartbeat_interval: 30,
            worker_timeout: 300,
        }
    }
}

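/// Top-level configuration for distributed embedding training.
///
/// # Example
///
/// A minimal sketch, mirroring the pattern used in the tests at the bottom of
/// this file (struct update syntax over `Default`); import paths are assumed
/// to be re-exported by the crate root:
///
/// ```ignore
/// let config = DistributedTrainingConfig {
///     strategy: DistributedStrategy::DataParallel {
///         num_workers: 2,
///         batch_size: 128,
///     },
///     aggregation: AggregationMethod::RingAllReduce,
///     gradient_clip: Some(5.0),
///     ..Default::default()
/// };
/// ```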
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingConfig {
    /// Parallelization strategy.
    pub strategy: DistributedStrategy,
    /// Gradient aggregation method.
    pub aggregation: AggregationMethod,
    /// Communication backend.
    pub backend: CommunicationBackend,
    /// Fault-tolerance settings.
    pub fault_tolerance: FaultToleranceConfig,
    /// Whether to compress gradients before communication.
    pub gradient_compression: bool,
    /// Compression ratio applied when gradient compression is enabled.
    pub compression_ratio: f32,
    /// Whether to train with mixed precision.
    pub mixed_precision: bool,
    /// Optional gradient clipping threshold.
    pub gradient_clip: Option<f32>,
    /// Number of warmup epochs.
    pub warmup_epochs: usize,
    /// Whether pipeline parallelism is enabled.
    pub pipeline_parallelism: bool,
    /// Number of microbatches per pipeline step.
    pub num_microbatches: usize,
}

impl Default for DistributedTrainingConfig {
    fn default() -> Self {
        Self {
            strategy: DistributedStrategy::DataParallel {
                num_workers: 4,
                batch_size: 256,
            },
            aggregation: AggregationMethod::AllReduce,
            backend: CommunicationBackend::Tcp,
            fault_tolerance: FaultToleranceConfig::default(),
            gradient_compression: false,
            compression_ratio: 0.5,
            mixed_precision: false,
            gradient_clip: Some(1.0),
            warmup_epochs: 5,
            pipeline_parallelism: false,
            num_microbatches: 4,
        }
    }
}

/// Information about a registered worker node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerInfo {
    /// Unique worker identifier.
    pub worker_id: usize,
    /// Rank of the worker within the cluster.
    pub rank: usize,
    /// Network address of the worker.
    pub address: String,
    /// Current worker status.
    pub status: WorkerStatus,
    /// Number of GPUs available on the worker.
    pub num_gpus: usize,
    /// Available memory in gigabytes.
    pub memory_gb: f32,
    /// Timestamp of the last received heartbeat.
    pub last_heartbeat: DateTime<Utc>,
}

/// Lifecycle state of a worker.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WorkerStatus {
    /// Registered but not currently training.
    Idle,
    /// Actively running training steps.
    Training,
    /// Synchronizing gradients or parameters.
    Synchronizing,
    /// Marked as failed.
    Failed,
    /// Recovering after a failure.
    Recovering,
}

/// Serialized training checkpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingCheckpoint {
    /// Unique checkpoint identifier.
    pub checkpoint_id: String,
    /// Epoch at which the checkpoint was taken.
    pub epoch: usize,
    /// Global training step.
    pub global_step: usize,
    /// Serialized model state.
    pub model_state: Vec<u8>,
    /// Serialized optimizer state.
    pub optimizer_state: Vec<u8>,
    /// Loss at checkpoint time.
    pub loss: f64,
    /// Creation timestamp.
    pub timestamp: DateTime<Utc>,
}

/// Aggregate statistics for a distributed training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingStats {
    /// Total number of completed epochs.
    pub total_epochs: usize,
    /// Total number of training steps.
    pub total_steps: usize,
    /// Final loss after training.
    pub final_loss: f64,
    /// Wall-clock training time, in seconds.
    pub training_time: f64,
    /// Number of registered workers.
    pub num_workers: usize,
    /// Throughput in epochs per second.
    pub throughput: f64,
    /// Time spent in communication, in seconds.
    pub communication_time: f64,
    /// Time spent in computation, in seconds.
    pub computation_time: f64,
    /// Number of checkpoints written.
    pub num_checkpoints: usize,
    /// Number of worker failures observed.
    pub num_failures: usize,
    /// Per-epoch loss history.
    pub loss_history: Vec<f64>,
}

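/// Coordinates distributed training across registered workers: tracks worker
/// membership and health, aggregates per-worker results, and manages
/// checkpoints and run statistics.
///
/// # Example
///
/// A minimal sketch of the intended flow, mirroring the tests at the bottom of
/// this file (any `EmbeddingModel` implementation works; `TransE` and
/// `ModelConfig` are assumed here only because the tests use them):
///
/// ```ignore
/// let mut coordinator =
///     DistributedTrainingCoordinator::new(DistributedTrainingConfig::default()).await?;
///
/// coordinator
///     .register_worker(WorkerInfo {
///         worker_id: 0,
///         rank: 0,
///         address: "127.0.0.1:8080".to_string(),
///         status: WorkerStatus::Idle,
///         num_gpus: 1,
///         memory_gb: 16.0,
///         last_heartbeat: Utc::now(),
///     })
///     .await?;
///
/// let mut model = TransE::new(ModelConfig::default());
/// let stats = coordinator.train(&mut model, 10).await?;
/// println!("final loss: {:.6}", stats.final_loss);
/// ```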
pub struct DistributedTrainingCoordinator {
    /// Training configuration.
    config: DistributedTrainingConfig,
    /// Registered workers keyed by worker id.
    workers: Arc<RwLock<HashMap<usize, WorkerInfo>>>,
    /// Saved checkpoints.
    checkpoints: Arc<Mutex<Vec<TrainingCheckpoint>>>,
    /// Underlying cluster manager from scirs2-core.
    cluster_manager: Arc<ClusterManager>,
    /// Running statistics for the current training run.
    stats: Arc<Mutex<DistributedTrainingStats>>,
}

impl DistributedTrainingCoordinator {
    /// Creates a new coordinator with the given configuration.
    pub async fn new(config: DistributedTrainingConfig) -> Result<Self> {
        info!("Initializing distributed training coordinator");

        let cluster_config = ClusterConfiguration::default();
        let cluster_manager = Arc::new(
            ClusterManager::new(cluster_config)
                .map_err(|e| anyhow::anyhow!("Failed to create cluster manager: {}", e))?,
        );

        Ok(Self {
            config,
            workers: Arc::new(RwLock::new(HashMap::new())),
            checkpoints: Arc::new(Mutex::new(Vec::new())),
            cluster_manager,
            stats: Arc::new(Mutex::new(DistributedTrainingStats {
                total_epochs: 0,
                total_steps: 0,
                final_loss: 0.0,
                training_time: 0.0,
                num_workers: 0,
                throughput: 0.0,
                communication_time: 0.0,
                computation_time: 0.0,
                num_checkpoints: 0,
                num_failures: 0,
                loss_history: Vec::new(),
            })),
        })
    }

    /// Registers a worker with the coordinator.
    pub async fn register_worker(&self, worker_info: WorkerInfo) -> Result<()> {
        info!(
            "Registering worker {}: {}",
            worker_info.worker_id, worker_info.address
        );

        let mut workers = self.workers.write().await;
        workers.insert(worker_info.worker_id, worker_info);

        let mut stats = self.stats.lock().await;
        stats.num_workers = workers.len();

        Ok(())
    }

    /// Removes a worker and records the failure.
    pub async fn deregister_worker(&self, worker_id: usize) -> Result<()> {
        warn!("Deregistering worker {}", worker_id);

        let mut workers = self.workers.write().await;
        workers.remove(&worker_id);

        let mut stats = self.stats.lock().await;
        stats.num_workers = workers.len();
        stats.num_failures += 1;

        Ok(())
    }

    /// Updates a worker's status and refreshes its heartbeat timestamp.
    pub async fn update_worker_status(&self, worker_id: usize, status: WorkerStatus) -> Result<()> {
        let mut workers = self.workers.write().await;
        if let Some(worker) = workers.get_mut(&worker_id) {
            worker.status = status;
            worker.last_heartbeat = Utc::now();
        }
        Ok(())
    }

    /// Runs distributed training for the given number of epochs and returns
    /// the collected statistics.
    pub async fn train<M: EmbeddingModel>(
        &mut self,
        model: &mut M,
        epochs: usize,
    ) -> Result<DistributedTrainingStats> {
        info!("Starting distributed training for {} epochs", epochs);

        let start_time = std::time::Instant::now();
        let mut total_comm_time = 0.0;
        let mut total_comp_time = 0.0;

        self.initialize_optimizer().await?;

        for epoch in 0..epochs {
            debug!("Epoch {}/{}", epoch + 1, epochs);

            // Computation phase: dispatch the epoch's work to the workers.
            let comp_start = std::time::Instant::now();
            let batch_results = self.distribute_training_batch(model, epoch).await?;
            let comp_time = comp_start.elapsed().as_secs_f64();
            total_comp_time += comp_time;

            // Communication phase: aggregate gradients and losses.
            let comm_start = std::time::Instant::now();
            let avg_loss = self.aggregate_gradients(&batch_results).await?;
            let comm_time = comm_start.elapsed().as_secs_f64();
            total_comm_time += comm_time;

            {
                let mut stats = self.stats.lock().await;
                stats.total_epochs = epoch + 1;
                stats.loss_history.push(avg_loss);
                stats.final_loss = avg_loss;
            }

            // Periodic checkpointing according to the fault-tolerance settings.
            if self.config.fault_tolerance.enable_checkpointing
                && (epoch + 1) % self.config.fault_tolerance.checkpoint_frequency == 0
            {
                self.save_checkpoint(model, epoch, avg_loss).await?;
            }

            info!(
                "Epoch {}: loss={:.6}, comp_time={:.2}s, comm_time={:.2}s",
                epoch + 1,
                avg_loss,
                comp_time,
                comm_time
            );
        }

        let elapsed = start_time.elapsed().as_secs_f64();

        let stats = {
            let mut stats = self.stats.lock().await;
            stats.training_time = elapsed;
            stats.communication_time = total_comm_time;
            stats.computation_time = total_comp_time;
            stats.throughput = (epochs as f64) / elapsed;
            stats.clone()
        };

        info!("Distributed training completed in {:.2}s", elapsed);
        info!("Final loss: {:.6}", stats.final_loss);
        info!("Throughput: {:.2} epochs/sec", stats.throughput);

        Ok(stats)
    }

    /// Initializes the distributed optimizer state (placeholder for now).
    async fn initialize_optimizer(&mut self) -> Result<()> {
        debug!("Initializing distributed optimizer");

        Ok(())
    }

    /// Distributes one training epoch across the registered workers.
    async fn distribute_training_batch<M: EmbeddingModel>(
        &self,
        _model: &M,
        epoch: usize,
    ) -> Result<Vec<WorkerResult>> {
        let workers = self.workers.read().await;
        let num_workers = workers.len();

        if num_workers == 0 {
            return Err(anyhow::anyhow!("No workers available"));
        }

        // No real dispatch yet: synthesize one result per worker with a loss
        // that decays as training progresses.
        let mut results = Vec::new();
        for (worker_id, _) in workers.iter() {
            results.push(WorkerResult {
                worker_id: *worker_id,
                epoch,
                loss: 0.1 * (1.0 - epoch as f64 / 100.0).max(0.01),
                num_samples: 1000,
                gradients: HashMap::new(),
            });
        }

        Ok(results)
    }

451
452 async fn aggregate_gradients(&self, results: &[WorkerResult]) -> Result<f64> {
454 if results.is_empty() {
455 return Err(anyhow::anyhow!("No results to aggregate"));
456 }
457
458 let avg_loss = results.iter().map(|r| r.loss).sum::<f64>() / results.len() as f64;
460
461 match &self.config.aggregation {
464 AggregationMethod::AllReduce => {
465 debug!("Using AllReduce for gradient aggregation");
466 }
469 AggregationMethod::RingAllReduce => {
470 debug!("Using Ring-AllReduce for gradient aggregation");
471 }
473 AggregationMethod::ParameterServer { num_servers } => {
474 debug!("Using Parameter Server with {} servers", num_servers);
475 }
477 AggregationMethod::Hierarchical { branching_factor } => {
478 debug!(
479 "Using Hierarchical aggregation with branching factor {}",
480 branching_factor
481 );
482 }
484 }
485
486 Ok(avg_loss)
487 }
488
489 async fn save_checkpoint<M: EmbeddingModel>(
491 &self,
492 _model: &M,
493 epoch: usize,
494 loss: f64,
495 ) -> Result<()> {
496 info!("Saving checkpoint at epoch {}", epoch);
497
498 let checkpoint = TrainingCheckpoint {
499 checkpoint_id: format!("checkpoint_epoch_{}", epoch),
500 epoch,
501 global_step: epoch * 1000, model_state: Vec::new(), optimizer_state: Vec::new(), loss,
505 timestamp: Utc::now(),
506 };
507
508 let mut checkpoints = self.checkpoints.lock().await;
509 checkpoints.push(checkpoint);
510
511 let mut stats = self.stats.lock().await;
512 stats.num_checkpoints += 1;
513
514 Ok(())
515 }

    /// Loads a previously saved checkpoint by its identifier.
    pub async fn load_checkpoint(&self, checkpoint_id: &str) -> Result<TrainingCheckpoint> {
        let checkpoints = self.checkpoints.lock().await;
        checkpoints
            .iter()
            .find(|c| c.checkpoint_id == checkpoint_id)
            .cloned()
            .ok_or_else(|| anyhow::anyhow!("Checkpoint not found: {}", checkpoint_id))
    }

    /// Returns a snapshot of the currently registered workers.
    pub async fn get_worker_stats(&self) -> HashMap<usize, WorkerInfo> {
        self.workers.read().await.clone()
    }

    /// Returns a snapshot of the current training statistics.
    pub async fn get_stats(&self) -> DistributedTrainingStats {
        self.stats.lock().await.clone()
    }

    /// Checks worker heartbeats against the configured timeout and logs a
    /// warning for workers that have not reported in time.
    pub async fn monitor_workers(&self) -> Result<()> {
        let timeout_duration =
            std::time::Duration::from_secs(self.config.fault_tolerance.worker_timeout);

        let workers = self.workers.read().await;
        let now = Utc::now();

        for (worker_id, worker) in workers.iter() {
            let elapsed = now.signed_duration_since(worker.last_heartbeat);
            // Compare in i64 to avoid wrap-around if the heartbeat lies in the future.
            if elapsed.num_seconds() > timeout_duration.as_secs() as i64 {
                warn!(
                    "Worker {} timed out (last heartbeat: {:?})",
                    worker_id, worker.last_heartbeat
                );
            }
        }

        Ok(())
    }
}

/// Per-worker result of a single training epoch.
#[derive(Debug, Clone)]
struct WorkerResult {
    worker_id: usize,
    epoch: usize,
    loss: f64,
    num_samples: usize,
    gradients: HashMap<String, Array1<f32>>,
}

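/// Convenience wrapper that owns both an embedding model and a
/// [`DistributedTrainingCoordinator`].
///
/// # Example
///
/// A minimal sketch mirroring `test_distributed_training` below (`TransE` and
/// `ModelConfig` are the model types this crate's tests use; any
/// `EmbeddingModel` implementation works):
///
/// ```ignore
/// let config = DistributedTrainingConfig::default();
/// let model = TransE::new(ModelConfig::default().with_dimensions(64));
/// let mut trainer = DistributedEmbeddingTrainer::new(model, config).await?;
///
/// trainer
///     .register_worker(WorkerInfo {
///         worker_id: 0,
///         rank: 0,
///         address: "127.0.0.1:8080".to_string(),
///         status: WorkerStatus::Idle,
///         num_gpus: 1,
///         memory_gb: 16.0,
///         last_heartbeat: Utc::now(),
///     })
///     .await?;
///
/// let stats = trainer.train(10).await?;
/// assert_eq!(stats.num_workers, 1);
/// ```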
pub struct DistributedEmbeddingTrainer<M: EmbeddingModel> {
    model: M,
    coordinator: DistributedTrainingCoordinator,
}

impl<M: EmbeddingModel> DistributedEmbeddingTrainer<M> {
    /// Creates a trainer for the given model and configuration.
    pub async fn new(model: M, config: DistributedTrainingConfig) -> Result<Self> {
        let coordinator = DistributedTrainingCoordinator::new(config).await?;

        Ok(Self { model, coordinator })
    }

    /// Trains the wrapped model for the given number of epochs.
    pub async fn train(&mut self, epochs: usize) -> Result<DistributedTrainingStats> {
        self.coordinator.train(&mut self.model, epochs).await
    }

    /// Returns a shared reference to the wrapped model.
    pub fn model(&self) -> &M {
        &self.model
    }

    /// Returns a mutable reference to the wrapped model.
    pub fn model_mut(&mut self) -> &mut M {
        &mut self.model
    }

    /// Registers a worker with the underlying coordinator.
    pub async fn register_worker(&self, worker_info: WorkerInfo) -> Result<()> {
        self.coordinator.register_worker(worker_info).await
    }

    /// Returns a snapshot of the current training statistics.
    pub async fn get_stats(&self) -> DistributedTrainingStats {
        self.coordinator.get_stats().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModelConfig, TransE};

    #[tokio::test]
    async fn test_distributed_coordinator_creation() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await;
        assert!(coordinator.is_ok());
    }

    #[tokio::test]
    async fn test_worker_registration() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await.unwrap();

        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };

        coordinator.register_worker(worker).await.unwrap();
        let stats = coordinator.get_worker_stats().await;
        assert_eq!(stats.len(), 1);
    }

    #[tokio::test]
    async fn test_distributed_training() {
        let config = DistributedTrainingConfig {
            strategy: DistributedStrategy::DataParallel {
                num_workers: 2,
                batch_size: 128,
            },
            ..Default::default()
        };

        let model_config = ModelConfig::default().with_dimensions(64);
        let model = TransE::new(model_config);

        let mut trainer = DistributedEmbeddingTrainer::new(model, config)
            .await
            .unwrap();

        for i in 0..2 {
            let worker = WorkerInfo {
                worker_id: i,
                rank: i,
                address: format!("127.0.0.1:808{}", i),
                status: WorkerStatus::Idle,
                num_gpus: 1,
                memory_gb: 16.0,
                last_heartbeat: Utc::now(),
            };
            trainer.register_worker(worker).await.unwrap();
        }

        let stats = trainer.train(5).await.unwrap();

        assert_eq!(stats.total_epochs, 5);
        assert!(stats.final_loss >= 0.0);
        assert_eq!(stats.num_workers, 2);
    }

    #[tokio::test]
    async fn test_checkpoint_save_load() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await.unwrap();

        let model_config = ModelConfig::default();
        let model = TransE::new(model_config);

        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };
        coordinator.register_worker(worker).await.unwrap();

        coordinator.save_checkpoint(&model, 10, 0.5).await.unwrap();

        let checkpoint = coordinator
            .load_checkpoint("checkpoint_epoch_10")
            .await
            .unwrap();
        assert_eq!(checkpoint.epoch, 10);
        assert_eq!(checkpoint.loss, 0.5);
    }
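
    // Illustrative test: exercises worker status updates, timeout monitoring,
    // and deregistration using only the coordinator API defined above.
    #[tokio::test]
    async fn test_worker_status_and_monitoring() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await.unwrap();

        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };
        coordinator.register_worker(worker).await.unwrap();

        coordinator
            .update_worker_status(0, WorkerStatus::Training)
            .await
            .unwrap();
        let workers = coordinator.get_worker_stats().await;
        assert_eq!(workers[&0].status, WorkerStatus::Training);

        // No worker has timed out, so monitoring should succeed without error.
        coordinator.monitor_workers().await.unwrap();

        coordinator.deregister_worker(0).await.unwrap();
        let stats = coordinator.get_stats().await;
        assert_eq!(stats.num_workers, 0);
        assert_eq!(stats.num_failures, 1);
    }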
}