oxirs_embed/distributed_training.rs

//! Distributed Training Module for Knowledge Graph Embeddings
//!
//! This module provides distributed training capabilities for knowledge graph embeddings
//! across multiple nodes/GPUs using data parallelism and model parallelism strategies.
//!
//! ## Features
//!
//! - **Data Parallelism**: Distribute training data across multiple workers
//! - **Model Parallelism**: Split large models across multiple devices
//! - **Gradient Aggregation**: AllReduce, Parameter Server, Ring-AllReduce
//! - **Fault Tolerance**: Checkpointing, recovery, and elastic scaling
//! - **Communication**: Efficient gradient synchronization with compression
//! - **Load Balancing**: Dynamic workload distribution
//! - **Monitoring**: Real-time training metrics and performance tracking
//!
//! ## Architecture
//!
//! ```text
//! ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
//! │  Worker 1   │────▶│  Coordinator│◀────│  Worker 2   │
//! │ (GPU/CPU)   │     │   (Master)  │     │ (GPU/CPU)   │
//! └─────────────┘     └─────────────┘     └─────────────┘
//!       │                    │                    │
//!       └────────────────────┴────────────────────┘
//!                   Gradient Sync
//! ```
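//!
//! ## Example
//!
//! A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest).
//! It assumes the crate is exposed as `oxirs_embed` and reuses the `TransE` and
//! `ModelConfig` types exercised by the tests at the bottom of this file:
//!
//! ```ignore
//! use oxirs_embed::distributed_training::{
//!     DistributedEmbeddingTrainer, DistributedTrainingConfig, WorkerInfo, WorkerStatus,
//! };
//! use oxirs_embed::{ModelConfig, TransE};
//!
//! # async fn run() -> anyhow::Result<()> {
//! let model = TransE::new(ModelConfig::default().with_dimensions(64));
//! let mut trainer =
//!     DistributedEmbeddingTrainer::new(model, DistributedTrainingConfig::default()).await?;
//!
//! // Register at least one worker before training.
//! trainer
//!     .register_worker(WorkerInfo {
//!         worker_id: 0,
//!         rank: 0,
//!         address: "127.0.0.1:8080".to_string(),
//!         status: WorkerStatus::Idle,
//!         num_gpus: 1,
//!         memory_gb: 16.0,
//!         last_heartbeat: chrono::Utc::now(),
//!     })
//!     .await?;
//!
//! let stats = trainer.train(10).await?;
//! println!("final loss: {:.6}", stats.final_loss);
//! # Ok(())
//! # }
//! ```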

use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, info, warn};

// Use SciRS2 for distributed computing
use scirs2_core::distributed::{ClusterConfiguration, ClusterManager};
use scirs2_core::ndarray_ext::Array1;

use crate::EmbeddingModel;

/// Distributed training strategy
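///
/// # Examples
///
/// A construction sketch (marked `ignore`, not compiled as a doctest) showing the
/// hybrid variant, which combines data- and model-parallel degrees:
///
/// ```ignore
/// let strategy = DistributedStrategy::Hybrid {
///     data_parallel_size: 4,
///     model_parallel_size: 2,
/// };
/// ```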
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistributedStrategy {
    /// Data parallelism - split data across workers
    DataParallel {
        /// Number of workers
        num_workers: usize,
        /// Batch size per worker
        batch_size: usize,
    },
    /// Model parallelism - split model across workers
    ModelParallel {
        /// Number of model shards
        num_shards: usize,
        /// Pipeline stages
        pipeline_stages: usize,
    },
    /// Hybrid parallelism - combine data and model parallelism
    Hybrid {
        /// Data parallel degree
        data_parallel_size: usize,
        /// Model parallel degree
        model_parallel_size: usize,
    },
}

/// Gradient aggregation method
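///
/// # Examples
///
/// A construction sketch (marked `ignore`, not compiled as a doctest) selecting a
/// parameter-server topology with two servers:
///
/// ```ignore
/// let aggregation = AggregationMethod::ParameterServer { num_servers: 2 };
/// ```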
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
    /// AllReduce - all workers exchange gradients
    AllReduce,
    /// Ring-AllReduce - efficient ring-based gradient exchange
    RingAllReduce,
    /// Parameter Server - centralized gradient aggregation
    ParameterServer {
        /// Number of parameter servers
        num_servers: usize,
    },
    /// Hierarchical - tree-based aggregation
    Hierarchical {
        /// Tree branching factor
        branching_factor: usize,
    },
}

/// Communication backend for distributed training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CommunicationBackend {
    /// Native TCP/IP
    Tcp,
    /// NCCL (NVIDIA Collective Communications Library)
    Nccl,
    /// Gloo (Facebook's collective communications)
    Gloo,
    /// MPI (Message Passing Interface)
    Mpi,
}

/// Fault tolerance configuration
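///
/// # Examples
///
/// A sketch (marked `ignore`, not compiled as a doctest) that tightens the
/// checkpoint cadence while keeping the remaining defaults:
///
/// ```ignore
/// let fault_tolerance = FaultToleranceConfig {
///     checkpoint_frequency: 5,
///     ..Default::default()
/// };
/// ```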
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultToleranceConfig {
    /// Enable checkpointing
    pub enable_checkpointing: bool,
    /// Checkpoint frequency (in epochs)
    pub checkpoint_frequency: usize,
    /// Maximum retry attempts
    pub max_retries: usize,
    /// Enable elastic scaling
    pub elastic_scaling: bool,
    /// Heartbeat interval (seconds)
    pub heartbeat_interval: u64,
    /// Worker timeout (seconds)
    pub worker_timeout: u64,
}

impl Default for FaultToleranceConfig {
    fn default() -> Self {
        Self {
            enable_checkpointing: true,
            checkpoint_frequency: 10,
            max_retries: 3,
            elastic_scaling: false,
            heartbeat_interval: 30,
            worker_timeout: 300,
        }
    }
}

/// Distributed training configuration
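///
/// # Examples
///
/// A sketch (marked `ignore`, not compiled as a doctest) that starts from the
/// defaults and turns on gradient compression:
///
/// ```ignore
/// let config = DistributedTrainingConfig {
///     strategy: DistributedStrategy::DataParallel {
///         num_workers: 8,
///         batch_size: 512,
///     },
///     gradient_compression: true,
///     compression_ratio: 0.25,
///     ..Default::default()
/// };
/// ```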
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingConfig {
    /// Distributed strategy
    pub strategy: DistributedStrategy,
    /// Gradient aggregation method
    pub aggregation: AggregationMethod,
    /// Communication backend
    pub backend: CommunicationBackend,
    /// Fault tolerance configuration
    pub fault_tolerance: FaultToleranceConfig,
    /// Enable gradient compression
    pub gradient_compression: bool,
    /// Compression ratio (0.0-1.0)
    pub compression_ratio: f32,
    /// Enable mixed precision training
    pub mixed_precision: bool,
    /// Gradient clipping threshold
    pub gradient_clip: Option<f32>,
    /// Warmup epochs before full distribution
    pub warmup_epochs: usize,
    /// Enable pipeline parallelism
    pub pipeline_parallelism: bool,
    /// Number of microbatches for pipeline
    pub num_microbatches: usize,
}

impl Default for DistributedTrainingConfig {
    fn default() -> Self {
        Self {
            strategy: DistributedStrategy::DataParallel {
                num_workers: 4,
                batch_size: 256,
            },
            aggregation: AggregationMethod::AllReduce,
            backend: CommunicationBackend::Tcp,
            fault_tolerance: FaultToleranceConfig::default(),
            gradient_compression: false,
            compression_ratio: 0.5,
            mixed_precision: false,
            gradient_clip: Some(1.0),
            warmup_epochs: 5,
            pipeline_parallelism: false,
            num_microbatches: 4,
        }
    }
}

/// Worker information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerInfo {
    /// Worker ID
    pub worker_id: usize,
    /// Worker rank (global)
    pub rank: usize,
    /// Worker address
    pub address: String,
    /// Worker status
    pub status: WorkerStatus,
    /// Number of GPUs available
    pub num_gpus: usize,
    /// Memory capacity (GB)
    pub memory_gb: f32,
    /// Last heartbeat timestamp
    pub last_heartbeat: DateTime<Utc>,
}

/// Worker status
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WorkerStatus {
    /// Worker is idle
    Idle,
    /// Worker is training
    Training,
    /// Worker is synchronizing
    Synchronizing,
    /// Worker has failed
    Failed,
    /// Worker is recovering
    Recovering,
}

/// Training checkpoint
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingCheckpoint {
    /// Checkpoint ID
    pub checkpoint_id: String,
    /// Epoch number
    pub epoch: usize,
    /// Global step
    pub global_step: usize,
    /// Model state (serialized)
    pub model_state: Vec<u8>,
    /// Optimizer state (serialized)
    pub optimizer_state: Vec<u8>,
    /// Training loss
    pub loss: f64,
    /// Timestamp
    pub timestamp: DateTime<Utc>,
}

/// Distributed training statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingStats {
    /// Total epochs
    pub total_epochs: usize,
    /// Total steps
    pub total_steps: usize,
    /// Final loss
    pub final_loss: f64,
    /// Training time (seconds)
    pub training_time: f64,
    /// Number of workers
    pub num_workers: usize,
    /// Average throughput (epochs per second in the current implementation)
    pub throughput: f64,
    /// Communication time (seconds)
    pub communication_time: f64,
    /// Computation time (seconds)
    pub computation_time: f64,
    /// Number of checkpoints saved
    pub num_checkpoints: usize,
    /// Number of worker failures
    pub num_failures: usize,
    /// Loss history per epoch
    pub loss_history: Vec<f64>,
}

/// Distributed training coordinator
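///
/// # Examples
///
/// A sketch (marked `ignore`, not compiled as a doctest) of the coordinator-level
/// API; workers register via `register_worker`, and health checks run through
/// `monitor_workers`:
///
/// ```ignore
/// # async fn run() -> anyhow::Result<()> {
/// let coordinator =
///     DistributedTrainingCoordinator::new(DistributedTrainingConfig::default()).await?;
/// coordinator
///     .update_worker_status(0, WorkerStatus::Training)
///     .await?;
/// coordinator.monitor_workers().await?;
/// # Ok(())
/// # }
/// ```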
pub struct DistributedTrainingCoordinator {
    config: DistributedTrainingConfig,
    workers: Arc<RwLock<HashMap<usize, WorkerInfo>>>,
    checkpoints: Arc<Mutex<Vec<TrainingCheckpoint>>>,
    cluster_manager: Arc<ClusterManager>,
    stats: Arc<Mutex<DistributedTrainingStats>>,
}

impl DistributedTrainingCoordinator {
    /// Create a new distributed training coordinator
    pub async fn new(config: DistributedTrainingConfig) -> Result<Self> {
        info!("Initializing distributed training coordinator");

        // Create cluster configuration
        let cluster_config = ClusterConfiguration::default();
        let cluster_manager = Arc::new(
            ClusterManager::new(cluster_config)
                .map_err(|e| anyhow::anyhow!("Failed to create cluster manager: {}", e))?,
        );

        Ok(Self {
            config,
            workers: Arc::new(RwLock::new(HashMap::new())),
            checkpoints: Arc::new(Mutex::new(Vec::new())),
            cluster_manager,
            stats: Arc::new(Mutex::new(DistributedTrainingStats {
                total_epochs: 0,
                total_steps: 0,
                final_loss: 0.0,
                training_time: 0.0,
                num_workers: 0,
                throughput: 0.0,
                communication_time: 0.0,
                computation_time: 0.0,
                num_checkpoints: 0,
                num_failures: 0,
                loss_history: Vec::new(),
            })),
        })
    }

    /// Register a worker
    pub async fn register_worker(&self, worker_info: WorkerInfo) -> Result<()> {
        info!(
            "Registering worker {}: {}",
            worker_info.worker_id, worker_info.address
        );

        let mut workers = self.workers.write().await;
        workers.insert(worker_info.worker_id, worker_info);

        let mut stats = self.stats.lock().await;
        stats.num_workers = workers.len();

        Ok(())
    }

    /// Deregister a worker
    pub async fn deregister_worker(&self, worker_id: usize) -> Result<()> {
        warn!("Deregistering worker {}", worker_id);

        let mut workers = self.workers.write().await;
        workers.remove(&worker_id);

        let mut stats = self.stats.lock().await;
        stats.num_workers = workers.len();
        stats.num_failures += 1;

        Ok(())
    }

    /// Update worker status
    pub async fn update_worker_status(&self, worker_id: usize, status: WorkerStatus) -> Result<()> {
        let mut workers = self.workers.write().await;
        if let Some(worker) = workers.get_mut(&worker_id) {
            worker.status = status;
            worker.last_heartbeat = Utc::now();
        }
        Ok(())
    }

    /// Coordinate distributed training
    pub async fn train<M: EmbeddingModel>(
        &mut self,
        model: &mut M,
        epochs: usize,
    ) -> Result<DistributedTrainingStats> {
        info!("Starting distributed training for {} epochs", epochs);

        let start_time = std::time::Instant::now();
        let mut total_comm_time = 0.0;
        let mut total_comp_time = 0.0;

        // Initialize distributed optimizer
        self.initialize_optimizer().await?;

        for epoch in 0..epochs {
            debug!("Epoch {}/{}", epoch + 1, epochs);

            // Distribute work to workers
            let comp_start = std::time::Instant::now();
            let batch_results = self.distribute_training_batch(model, epoch).await?;
            let comp_time = comp_start.elapsed().as_secs_f64();
            total_comp_time += comp_time;

            // Aggregate gradients
            let comm_start = std::time::Instant::now();
            let avg_loss = self.aggregate_gradients(&batch_results).await?;
            let comm_time = comm_start.elapsed().as_secs_f64();
            total_comm_time += comm_time;

            // Update statistics
            {
                let mut stats = self.stats.lock().await;
                stats.total_epochs = epoch + 1;
                stats.loss_history.push(avg_loss);
                stats.final_loss = avg_loss;
            }

            // Save checkpoint if needed
            if self.config.fault_tolerance.enable_checkpointing
                && (epoch + 1) % self.config.fault_tolerance.checkpoint_frequency == 0
            {
                self.save_checkpoint(model, epoch, avg_loss).await?;
            }

            info!(
                "Epoch {}: loss={:.6}, comp_time={:.2}s, comm_time={:.2}s",
                epoch + 1,
                avg_loss,
                comp_time,
                comm_time
            );
        }

        let elapsed = start_time.elapsed().as_secs_f64();

        // Finalize statistics
        let stats = {
            let mut stats = self.stats.lock().await;
            stats.training_time = elapsed;
            stats.communication_time = total_comm_time;
            stats.computation_time = total_comp_time;
            stats.throughput = (epochs as f64) / elapsed;
            stats.clone()
        };

        info!("Distributed training completed in {:.2}s", elapsed);
        info!("Final loss: {:.6}", stats.final_loss);
        info!("Throughput: {:.2} epochs/sec", stats.throughput);

        Ok(stats)
    }

    /// Initialize distributed optimizer
    async fn initialize_optimizer(&mut self) -> Result<()> {
        debug!("Initializing distributed optimizer");

        // In a real implementation, this would initialize optimizer state
        // For now, this is a placeholder

        Ok(())
    }

    /// Distribute training batch to workers
    async fn distribute_training_batch<M: EmbeddingModel>(
        &self,
        _model: &M,
        epoch: usize,
    ) -> Result<Vec<WorkerResult>> {
        let workers = self.workers.read().await;
        let num_workers = workers.len();

        if num_workers == 0 {
            return Err(anyhow::anyhow!("No workers available"));
        }

        // Simulate distributed training (in a real implementation, this would
        // send batches to workers via network communication)
        let mut results = Vec::new();
        for worker_id in workers.keys() {
            results.push(WorkerResult {
                worker_id: *worker_id,
                epoch,
                loss: 0.1 * (1.0 - epoch as f64 / 100.0).max(0.01),
                num_samples: 1000,
                gradients: HashMap::new(),
            });
        }

        Ok(results)
    }

    /// Aggregate gradients from workers
    async fn aggregate_gradients(&self, results: &[WorkerResult]) -> Result<f64> {
        if results.is_empty() {
            return Err(anyhow::anyhow!("No results to aggregate"));
        }

        // Calculate average loss
        let avg_loss = results.iter().map(|r| r.loss).sum::<f64>() / results.len() as f64;

        // In a real implementation, this would aggregate gradients using
        // the configured aggregation method (AllReduce, Parameter Server, etc.)
        match &self.config.aggregation {
            AggregationMethod::AllReduce => {
                debug!("Using AllReduce for gradient aggregation");
                // Use distributed aggregation
                // In production, implement actual AllReduce algorithm
            }
            AggregationMethod::RingAllReduce => {
                debug!("Using Ring-AllReduce for gradient aggregation");
                // Implement ring-based gradient exchange
            }
            AggregationMethod::ParameterServer { num_servers } => {
                debug!("Using Parameter Server with {} servers", num_servers);
                // Implement parameter server aggregation
            }
            AggregationMethod::Hierarchical { branching_factor } => {
                debug!(
                    "Using Hierarchical aggregation with branching factor {}",
                    branching_factor
                );
                // Implement tree-based aggregation
            }
        }

        Ok(avg_loss)
    }

    /// Save training checkpoint
    async fn save_checkpoint<M: EmbeddingModel>(
        &self,
        _model: &M,
        epoch: usize,
        loss: f64,
    ) -> Result<()> {
        info!("Saving checkpoint at epoch {}", epoch);

        let checkpoint = TrainingCheckpoint {
            checkpoint_id: format!("checkpoint_epoch_{}", epoch),
            epoch,
            global_step: epoch * 1000,   // Simplified
            model_state: Vec::new(),     // In real impl, serialize model state
            optimizer_state: Vec::new(), // In real impl, serialize optimizer state
            loss,
            timestamp: Utc::now(),
        };

        let mut checkpoints = self.checkpoints.lock().await;
        checkpoints.push(checkpoint);

        let mut stats = self.stats.lock().await;
        stats.num_checkpoints += 1;

        Ok(())
    }

    /// Load training checkpoint
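    ///
    /// # Examples
    ///
    /// A sketch (marked `ignore`, not compiled as a doctest); checkpoint IDs follow
    /// the `checkpoint_epoch_{epoch}` pattern generated by `save_checkpoint`:
    ///
    /// ```ignore
    /// let checkpoint = coordinator.load_checkpoint("checkpoint_epoch_10").await?;
    /// assert_eq!(checkpoint.epoch, 10);
    /// ```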
    pub async fn load_checkpoint(&self, checkpoint_id: &str) -> Result<TrainingCheckpoint> {
        let checkpoints = self.checkpoints.lock().await;
        checkpoints
            .iter()
            .find(|c| c.checkpoint_id == checkpoint_id)
            .cloned()
            .ok_or_else(|| anyhow::anyhow!("Checkpoint not found: {}", checkpoint_id))
    }

    /// Get worker statistics
    pub async fn get_worker_stats(&self) -> HashMap<usize, WorkerInfo> {
        self.workers.read().await.clone()
    }

    /// Get training statistics
    pub async fn get_stats(&self) -> DistributedTrainingStats {
        self.stats.lock().await.clone()
    }

    /// Monitor worker health (heartbeat check)
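    ///
    /// # Examples
    ///
    /// A sketch (marked `ignore`, not compiled as a doctest) that runs the check on a
    /// fixed interval from a background task, assuming the coordinator is shared
    /// behind an `Arc`:
    ///
    /// ```ignore
    /// let coordinator = Arc::new(coordinator);
    /// let monitor = Arc::clone(&coordinator);
    /// tokio::spawn(async move {
    ///     let mut ticker = tokio::time::interval(std::time::Duration::from_secs(30));
    ///     loop {
    ///         ticker.tick().await;
    ///         if let Err(e) = monitor.monitor_workers().await {
    ///             tracing::warn!("heartbeat check failed: {}", e);
    ///         }
    ///     }
    /// });
    /// ```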
    pub async fn monitor_workers(&self) -> Result<()> {
        let timeout_duration =
            std::time::Duration::from_secs(self.config.fault_tolerance.worker_timeout);

        let workers = self.workers.read().await;
        let now = Utc::now();

        for (worker_id, worker) in workers.iter() {
            let elapsed = now.signed_duration_since(worker.last_heartbeat);
            // Compare in signed seconds so a heartbeat timestamped slightly in the
            // future (clock skew) cannot wrap around and trigger a false timeout.
            if elapsed.num_seconds() > timeout_duration.as_secs() as i64 {
                warn!(
                    "Worker {} timed out (last heartbeat: {:?})",
                    worker_id, worker.last_heartbeat
                );
                // In a real implementation, trigger worker recovery or replacement
            }
        }

        Ok(())
    }
}

/// Worker training result
#[derive(Debug, Clone)]
struct WorkerResult {
    worker_id: usize,
    epoch: usize,
    loss: f64,
    num_samples: usize,
    gradients: HashMap<String, Array1<f32>>,
}

/// Distributed embedding model trainer
pub struct DistributedEmbeddingTrainer<M: EmbeddingModel> {
    model: M,
    coordinator: DistributedTrainingCoordinator,
}

impl<M: EmbeddingModel> DistributedEmbeddingTrainer<M> {
    /// Create a new distributed trainer
    pub async fn new(model: M, config: DistributedTrainingConfig) -> Result<Self> {
        let coordinator = DistributedTrainingCoordinator::new(config).await?;

        Ok(Self { model, coordinator })
    }

    /// Train the model in a distributed manner
    pub async fn train(&mut self, epochs: usize) -> Result<DistributedTrainingStats> {
        self.coordinator.train(&mut self.model, epochs).await
    }

    /// Get the trained model
    pub fn model(&self) -> &M {
        &self.model
    }

    /// Get mutable reference to the model
    pub fn model_mut(&mut self) -> &mut M {
        &mut self.model
    }

    /// Register a worker
    pub async fn register_worker(&self, worker_info: WorkerInfo) -> Result<()> {
        self.coordinator.register_worker(worker_info).await
    }

    /// Get training statistics
    pub async fn get_stats(&self) -> DistributedTrainingStats {
        self.coordinator.get_stats().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModelConfig, TransE};

    #[tokio::test]
    async fn test_distributed_coordinator_creation() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await;
        assert!(coordinator.is_ok());
    }

    #[tokio::test]
    async fn test_worker_registration() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await.unwrap();

        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };

        coordinator.register_worker(worker).await.unwrap();
        let stats = coordinator.get_worker_stats().await;
        assert_eq!(stats.len(), 1);
    }

    #[tokio::test]
    async fn test_distributed_training() {
        let config = DistributedTrainingConfig {
            strategy: DistributedStrategy::DataParallel {
                num_workers: 2,
                batch_size: 128,
            },
            ..Default::default()
        };

        let model_config = ModelConfig::default().with_dimensions(64);
        let model = TransE::new(model_config);

        let mut trainer = DistributedEmbeddingTrainer::new(model, config)
            .await
            .unwrap();

        // Register workers
        for i in 0..2 {
            let worker = WorkerInfo {
                worker_id: i,
                rank: i,
                address: format!("127.0.0.1:808{}", i),
                status: WorkerStatus::Idle,
                num_gpus: 1,
                memory_gb: 16.0,
                last_heartbeat: Utc::now(),
            };
            trainer.register_worker(worker).await.unwrap();
        }

        // Train for a few epochs
        let stats = trainer.train(5).await.unwrap();

        assert_eq!(stats.total_epochs, 5);
        assert!(stats.final_loss >= 0.0);
        assert_eq!(stats.num_workers, 2);
    }

    #[tokio::test]
    async fn test_checkpoint_save_load() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await.unwrap();

        let model_config = ModelConfig::default();
        let model = TransE::new(model_config);

        // Register a worker
        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };
        coordinator.register_worker(worker).await.unwrap();

        // Save checkpoint
        coordinator.save_checkpoint(&model, 10, 0.5).await.unwrap();

        // Load checkpoint
        let checkpoint = coordinator
            .load_checkpoint("checkpoint_epoch_10")
            .await
            .unwrap();
        assert_eq!(checkpoint.epoch, 10);
        assert_eq!(checkpoint.loss, 0.5);
    }
}