Skip to main content

oxirs_embed/
distributed_training.rs

1//! Distributed Training Module for Knowledge Graph Embeddings
2//!
3//! This module provides distributed training capabilities for knowledge graph embeddings
4//! across multiple nodes/GPUs using data parallelism and model parallelism strategies.
5//!
6//! ## Features
7//!
8//! - **Data Parallelism**: Distribute training data across multiple workers
9//! - **Model Parallelism**: Split large models across multiple devices
10//! - **Gradient Aggregation**: AllReduce, Parameter Server, Ring-AllReduce
11//! - **Fault Tolerance**: Checkpointing, recovery, and elastic scaling
12//! - **Communication**: Efficient gradient synchronization with compression
13//! - **Load Balancing**: Dynamic workload distribution
14//! - **Monitoring**: Real-time training metrics and performance tracking
15//!
16//! ## Architecture
17//!
18//! ```text
19//! ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
20//! │  Worker 1   │────▶│  Coordinator│◀────│  Worker 2   │
21//! │ (GPU/CPU)   │     │   (Master)  │     │ (GPU/CPU)   │
22//! └─────────────┘     └─────────────┘     └─────────────┘
23//!       │                    │                    │
24//!       └────────────────────┴────────────────────┘
25//!                   Gradient Sync
26//! ```
27
28use anyhow::Result;
29use chrono::{DateTime, Utc};
30use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32use std::sync::Arc;
33use tokio::sync::{Mutex, RwLock};
34use tracing::{debug, info, warn};
35
36// Use SciRS2 for distributed computing
37use scirs2_core::distributed::{ClusterConfiguration, ClusterManager};
38use scirs2_core::ndarray_ext::Array1;
39
40use crate::EmbeddingModel;
41
/// Distributed training strategy
///
/// Selects how work is partitioned across the cluster: by data, by model
/// shards, or a combination of both.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistributedStrategy {
    /// Data parallelism - split data across workers
    DataParallel {
        /// Number of workers
        num_workers: usize,
        /// Batch size per worker (effective global batch is
        /// `num_workers * batch_size`)
        batch_size: usize,
    },
    /// Model parallelism - split model across workers
    ModelParallel {
        /// Number of model shards
        num_shards: usize,
        /// Pipeline stages
        pipeline_stages: usize,
    },
    /// Hybrid parallelism - combine data and model parallelism
    Hybrid {
        /// Data parallel degree
        data_parallel_size: usize,
        /// Model parallel degree
        model_parallel_size: usize,
    },
}
67
/// Gradient aggregation method
///
/// Determines how per-worker gradients are combined each step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
    /// AllReduce - all workers exchange gradients
    AllReduce,
    /// Ring-AllReduce - efficient ring-based gradient exchange
    RingAllReduce,
    /// Parameter Server - centralized gradient aggregation
    ParameterServer {
        /// Number of parameter servers
        num_servers: usize,
    },
    /// Hierarchical - tree-based aggregation
    Hierarchical {
        /// Tree branching factor
        branching_factor: usize,
    },
}
86
/// Communication backend for distributed training
///
/// NOTE(review): only the variant is carried here; nothing in this module
/// currently switches behavior on the backend — confirm downstream usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CommunicationBackend {
    /// Native TCP/IP
    Tcp,
    /// NCCL (NVIDIA Collective Communications Library)
    Nccl,
    /// Gloo (Facebook's collective communications)
    Gloo,
    /// MPI (Message Passing Interface)
    Mpi,
}
99
/// Fault tolerance configuration
///
/// Governs checkpointing cadence and worker-liveness monitoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultToleranceConfig {
    /// Enable checkpointing
    pub enable_checkpointing: bool,
    /// Checkpoint frequency (in epochs)
    pub checkpoint_frequency: usize,
    /// Maximum retry attempts
    pub max_retries: usize,
    /// Enable elastic scaling
    pub elastic_scaling: bool,
    /// Heartbeat interval (seconds)
    pub heartbeat_interval: u64,
    /// Worker timeout (seconds); a worker whose last heartbeat is older than
    /// this is reported by `monitor_workers`
    pub worker_timeout: u64,
}
116
impl Default for FaultToleranceConfig {
    /// Conservative defaults: checkpoint every 10 epochs, up to 3 retries,
    /// no elastic scaling, 30 s heartbeats, 5 min worker timeout.
    fn default() -> Self {
        Self {
            enable_checkpointing: true,
            checkpoint_frequency: 10,
            max_retries: 3,
            elastic_scaling: false,
            heartbeat_interval: 30,
            worker_timeout: 300,
        }
    }
}
129
/// Distributed training configuration
///
/// Top-level knob set consumed by `DistributedTrainingCoordinator`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingConfig {
    /// Distributed strategy
    pub strategy: DistributedStrategy,
    /// Gradient aggregation method
    pub aggregation: AggregationMethod,
    /// Communication backend
    pub backend: CommunicationBackend,
    /// Fault tolerance configuration
    pub fault_tolerance: FaultToleranceConfig,
    /// Enable gradient compression
    pub gradient_compression: bool,
    /// Compression ratio (0.0-1.0)
    pub compression_ratio: f32,
    /// Enable mixed precision training
    pub mixed_precision: bool,
    /// Gradient clipping threshold (`None` disables clipping)
    pub gradient_clip: Option<f32>,
    /// Warmup epochs before full distribution
    pub warmup_epochs: usize,
    /// Enable pipeline parallelism
    pub pipeline_parallelism: bool,
    /// Number of microbatches for pipeline
    pub num_microbatches: usize,
}
156
impl Default for DistributedTrainingConfig {
    /// Defaults favor simple data parallelism over TCP with AllReduce:
    /// 4 workers × batch 256, gradient clipping at 1.0, compression and
    /// mixed precision disabled.
    fn default() -> Self {
        Self {
            strategy: DistributedStrategy::DataParallel {
                num_workers: 4,
                batch_size: 256,
            },
            aggregation: AggregationMethod::AllReduce,
            backend: CommunicationBackend::Tcp,
            fault_tolerance: FaultToleranceConfig::default(),
            gradient_compression: false,
            // Only meaningful when `gradient_compression` is enabled.
            compression_ratio: 0.5,
            mixed_precision: false,
            gradient_clip: Some(1.0),
            warmup_epochs: 5,
            pipeline_parallelism: false,
            num_microbatches: 4,
        }
    }
}
177
/// Worker information
///
/// Registry entry maintained per worker by the coordinator; the heartbeat
/// timestamp is refreshed whenever the worker's status is updated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerInfo {
    /// Worker ID (registry key in the coordinator's worker map)
    pub worker_id: usize,
    /// Worker rank (global)
    pub rank: usize,
    /// Worker address
    pub address: String,
    /// Worker status
    pub status: WorkerStatus,
    /// Number of GPUs available
    pub num_gpus: usize,
    /// Memory capacity (GB)
    pub memory_gb: f32,
    /// Last heartbeat timestamp (used by `monitor_workers` for timeouts)
    pub last_heartbeat: DateTime<Utc>,
}
196
/// Worker status
///
/// Lifecycle state of a registered worker.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WorkerStatus {
    /// Worker is idle
    Idle,
    /// Worker is training
    Training,
    /// Worker is synchronizing
    Synchronizing,
    /// Worker has failed
    Failed,
    /// Worker is recovering
    Recovering,
}
211
/// Training checkpoint
///
/// Serialized snapshot of training progress, looked up by `checkpoint_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingCheckpoint {
    /// Checkpoint ID (formatted as `checkpoint_epoch_{epoch}` by the coordinator)
    pub checkpoint_id: String,
    /// Epoch number
    pub epoch: usize,
    /// Global step
    pub global_step: usize,
    /// Model state (serialized); currently left empty by the coordinator
    pub model_state: Vec<u8>,
    /// Optimizer state (serialized); currently left empty by the coordinator
    pub optimizer_state: Vec<u8>,
    /// Training loss
    pub loss: f64,
    /// Timestamp
    pub timestamp: DateTime<Utc>,
}
230
/// Distributed training statistics
///
/// Aggregate progress snapshot maintained by the coordinator and returned
/// from `train` / `get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingStats {
    /// Total epochs
    pub total_epochs: usize,
    /// Total steps
    pub total_steps: usize,
    /// Final loss
    pub final_loss: f64,
    /// Training time (seconds)
    pub training_time: f64,
    /// Number of workers
    pub num_workers: usize,
    /// Average throughput.  NOTE(review): documented here as samples/sec,
    /// but the bundled coordinator currently records epochs/sec in this
    /// field — confirm which unit is intended.
    pub throughput: f64,
    /// Communication time (seconds)
    pub communication_time: f64,
    /// Computation time (seconds)
    pub computation_time: f64,
    /// Number of checkpoints saved
    pub num_checkpoints: usize,
    /// Number of worker failures (incremented on every worker deregistration)
    pub num_failures: usize,
    /// Loss history per epoch
    pub loss_history: Vec<f64>,
}
257
/// Distributed training coordinator
///
/// Central (master) role: tracks registered workers, owns the checkpoint
/// log and aggregate statistics, and drives the training loop.  All shared
/// state is behind `tokio` locks so the coordinator can be used from
/// multiple tasks.
pub struct DistributedTrainingCoordinator {
    // Immutable after construction.
    config: DistributedTrainingConfig,
    // worker_id -> registry entry; RwLock because reads (monitoring) dominate.
    workers: Arc<RwLock<HashMap<usize, WorkerInfo>>>,
    // Append-only checkpoint log (grows unbounded in memory).
    checkpoints: Arc<Mutex<Vec<TrainingCheckpoint>>>,
    // NOTE(review): constructed in `new` but not yet consulted by any method
    // in this file — presumably reserved for real network communication.
    cluster_manager: Arc<ClusterManager>,
    stats: Arc<Mutex<DistributedTrainingStats>>,
}
266
impl DistributedTrainingCoordinator {
    /// Create a new distributed training coordinator.
    ///
    /// Builds a `ClusterManager` from a default `ClusterConfiguration` and
    /// starts with an empty worker registry, no checkpoints, and zeroed
    /// statistics.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying cluster manager cannot be created.
    pub async fn new(config: DistributedTrainingConfig) -> Result<Self> {
        info!("Initializing distributed training coordinator");

        // Create cluster configuration
        let cluster_config = ClusterConfiguration::default();
        let cluster_manager = Arc::new(
            ClusterManager::new(cluster_config)
                .map_err(|e| anyhow::anyhow!("Failed to create cluster manager: {}", e))?,
        );

        Ok(Self {
            config,
            workers: Arc::new(RwLock::new(HashMap::new())),
            checkpoints: Arc::new(Mutex::new(Vec::new())),
            cluster_manager,
            stats: Arc::new(Mutex::new(DistributedTrainingStats {
                total_epochs: 0,
                total_steps: 0,
                final_loss: 0.0,
                training_time: 0.0,
                num_workers: 0,
                throughput: 0.0,
                communication_time: 0.0,
                computation_time: 0.0,
                num_checkpoints: 0,
                num_failures: 0,
                loss_history: Vec::new(),
            })),
        })
    }

    /// Register a worker.
    ///
    /// Registering an already-known `worker_id` silently replaces the
    /// previous entry.  The worker count in the shared stats is refreshed.
    pub async fn register_worker(&self, worker_info: WorkerInfo) -> Result<()> {
        info!(
            "Registering worker {}: {}",
            worker_info.worker_id, worker_info.address
        );

        let mut workers = self.workers.write().await;
        workers.insert(worker_info.worker_id, worker_info);

        let mut stats = self.stats.lock().await;
        stats.num_workers = workers.len();

        Ok(())
    }

    /// Deregister a worker.
    ///
    /// NOTE(review): `num_failures` is incremented for *every*
    /// deregistration, including a graceful shutdown — confirm whether
    /// voluntary departures should count as failures.
    pub async fn deregister_worker(&self, worker_id: usize) -> Result<()> {
        warn!("Deregistering worker {}", worker_id);

        let mut workers = self.workers.write().await;
        workers.remove(&worker_id);

        let mut stats = self.stats.lock().await;
        stats.num_workers = workers.len();
        stats.num_failures += 1;

        Ok(())
    }

    /// Update worker status and refresh its heartbeat timestamp.
    ///
    /// An unknown `worker_id` is silently ignored (still returns `Ok`).
    pub async fn update_worker_status(&self, worker_id: usize, status: WorkerStatus) -> Result<()> {
        let mut workers = self.workers.write().await;
        if let Some(worker) = workers.get_mut(&worker_id) {
            worker.status = status;
            worker.last_heartbeat = Utc::now();
        }
        Ok(())
    }

    /// Coordinate distributed training.
    ///
    /// Runs `epochs` rounds of (currently simulated) batch distribution and
    /// gradient aggregation, recording per-epoch loss, computation and
    /// communication time, and periodic checkpoints according to the
    /// fault-tolerance settings.  Returns a snapshot of the final stats.
    ///
    /// # Errors
    ///
    /// Fails if no workers are registered when a batch is distributed, or if
    /// gradient aggregation receives no results.
    pub async fn train<M: EmbeddingModel>(
        &mut self,
        model: &mut M,
        epochs: usize,
    ) -> Result<DistributedTrainingStats> {
        info!("Starting distributed training for {} epochs", epochs);

        let start_time = std::time::Instant::now();
        let mut total_comm_time = 0.0;
        let mut total_comp_time = 0.0;

        // Initialize distributed optimizer
        self.initialize_optimizer().await?;

        for epoch in 0..epochs {
            debug!("Epoch {}/{}", epoch + 1, epochs);

            // Distribute work to workers (timed as computation)
            let comp_start = std::time::Instant::now();
            let batch_results = self.distribute_training_batch(model, epoch).await?;
            let comp_time = comp_start.elapsed().as_secs_f64();
            total_comp_time += comp_time;

            // Aggregate gradients (timed as communication)
            let comm_start = std::time::Instant::now();
            let avg_loss = self.aggregate_gradients(&batch_results).await?;
            let comm_time = comm_start.elapsed().as_secs_f64();
            total_comm_time += comm_time;

            // Update statistics (scoped so the lock is released before
            // checkpointing, which also takes the stats lock)
            {
                let mut stats = self.stats.lock().await;
                stats.total_epochs = epoch + 1;
                stats.loss_history.push(avg_loss);
                stats.final_loss = avg_loss;
            }

            // Save checkpoint if needed
            if self.config.fault_tolerance.enable_checkpointing
                && (epoch + 1) % self.config.fault_tolerance.checkpoint_frequency == 0
            {
                self.save_checkpoint(model, epoch, avg_loss).await?;
            }

            info!(
                "Epoch {}: loss={:.6}, comp_time={:.2}s, comm_time={:.2}s",
                epoch + 1,
                avg_loss,
                comp_time,
                comm_time
            );
        }

        let elapsed = start_time.elapsed().as_secs_f64();

        // Finalize statistics
        let stats = {
            let mut stats = self.stats.lock().await;
            stats.training_time = elapsed;
            stats.communication_time = total_comm_time;
            stats.computation_time = total_comp_time;
            // NOTE(review): this records epochs/sec, while the
            // `DistributedTrainingStats::throughput` doc says samples/sec.
            // `WorkerResult::num_samples` is available if samples/sec was
            // intended — confirm.
            stats.throughput = (epochs as f64) / elapsed;
            stats.clone()
        };

        info!("Distributed training completed in {:.2}s", elapsed);
        info!("Final loss: {:.6}", stats.final_loss);
        info!("Throughput: {:.2} epochs/sec", stats.throughput);

        Ok(stats)
    }

    /// Initialize distributed optimizer.
    ///
    /// Placeholder — no optimizer state exists yet.
    async fn initialize_optimizer(&mut self) -> Result<()> {
        debug!("Initializing distributed optimizer");

        // In a real implementation, this would initialize optimizer state
        // For now, this is a placeholder

        Ok(())
    }

    /// Distribute training batch to workers.
    ///
    /// Currently a simulation: produces one synthetic `WorkerResult` per
    /// registered worker with a loss that decays linearly with `epoch`
    /// (floored so it never reaches zero) and no actual gradients.
    ///
    /// Result order follows `HashMap` iteration order and is therefore
    /// nondeterministic across runs.
    ///
    /// # Errors
    ///
    /// Returns an error when no workers are registered.
    async fn distribute_training_batch<M: EmbeddingModel>(
        &self,
        _model: &M,
        epoch: usize,
    ) -> Result<Vec<WorkerResult>> {
        let workers = self.workers.read().await;
        let num_workers = workers.len();

        if num_workers == 0 {
            return Err(anyhow::anyhow!("No workers available"));
        }

        // Simulate distributed training (in a real implementation, this would
        // send batches to workers via network communication)
        let mut results = Vec::new();
        for (worker_id, _) in workers.iter() {
            results.push(WorkerResult {
                worker_id: *worker_id,
                epoch,
                // Synthetic loss schedule: 0.1 at epoch 0, decaying to a
                // floor of 0.1 * 0.01 = 0.001 from epoch 99 onward.
                loss: 0.1 * (1.0 - epoch as f64 / 100.0).max(0.01),
                num_samples: 1000,
                gradients: HashMap::new(),
            });
        }

        Ok(results)
    }

    /// Aggregate gradients from workers.
    ///
    /// Only the mean loss is actually computed; the per-method branches are
    /// placeholders that log which aggregation algorithm *would* run.
    ///
    /// # Errors
    ///
    /// Returns an error when `results` is empty.
    async fn aggregate_gradients(&self, results: &[WorkerResult]) -> Result<f64> {
        if results.is_empty() {
            return Err(anyhow::anyhow!("No results to aggregate"));
        }

        // Calculate average loss
        let avg_loss = results.iter().map(|r| r.loss).sum::<f64>() / results.len() as f64;

        // In a real implementation, this would aggregate gradients using
        // the configured aggregation method (AllReduce, Parameter Server, etc.)
        match &self.config.aggregation {
            AggregationMethod::AllReduce => {
                debug!("Using AllReduce for gradient aggregation");
                // Use distributed aggregation
                // In production, implement actual AllReduce algorithm
            }
            AggregationMethod::RingAllReduce => {
                debug!("Using Ring-AllReduce for gradient aggregation");
                // Implement ring-based gradient exchange
            }
            AggregationMethod::ParameterServer { num_servers } => {
                debug!("Using Parameter Server with {} servers", num_servers);
                // Implement parameter server aggregation
            }
            AggregationMethod::Hierarchical { branching_factor } => {
                debug!(
                    "Using Hierarchical aggregation with branching factor {}",
                    branching_factor
                );
                // Implement tree-based aggregation
            }
        }

        Ok(avg_loss)
    }

    /// Save training checkpoint.
    ///
    /// Model/optimizer state are left empty (serialization not implemented).
    /// NOTE(review): checkpoints accumulate in memory without bound — no
    /// pruning or persistence to disk yet.
    async fn save_checkpoint<M: EmbeddingModel>(
        &self,
        _model: &M,
        epoch: usize,
        loss: f64,
    ) -> Result<()> {
        info!("Saving checkpoint at epoch {}", epoch);

        let checkpoint = TrainingCheckpoint {
            checkpoint_id: format!("checkpoint_epoch_{}", epoch),
            epoch,
            global_step: epoch * 1000,   // Simplified
            model_state: Vec::new(),     // In real impl, serialize model state
            optimizer_state: Vec::new(), // In real impl, serialize optimizer state
            loss,
            timestamp: Utc::now(),
        };

        let mut checkpoints = self.checkpoints.lock().await;
        checkpoints.push(checkpoint);

        let mut stats = self.stats.lock().await;
        stats.num_checkpoints += 1;

        Ok(())
    }

    /// Load training checkpoint by ID (linear scan over the in-memory log).
    ///
    /// # Errors
    ///
    /// Returns an error when no checkpoint has the given `checkpoint_id`.
    pub async fn load_checkpoint(&self, checkpoint_id: &str) -> Result<TrainingCheckpoint> {
        let checkpoints = self.checkpoints.lock().await;
        checkpoints
            .iter()
            .find(|c| c.checkpoint_id == checkpoint_id)
            .cloned()
            .ok_or_else(|| anyhow::anyhow!("Checkpoint not found: {}", checkpoint_id))
    }

    /// Get a snapshot of all registered workers (cloned map).
    pub async fn get_worker_stats(&self) -> HashMap<usize, WorkerInfo> {
        self.workers.read().await.clone()
    }

    /// Get a snapshot of the training statistics.
    pub async fn get_stats(&self) -> DistributedTrainingStats {
        self.stats.lock().await.clone()
    }

    /// Monitor worker health (heartbeat check).
    ///
    /// Logs a warning for every worker whose last heartbeat is older than
    /// the configured `worker_timeout`.  NOTE(review): detection only — no
    /// status change, deregistration, or recovery is triggered yet.
    pub async fn monitor_workers(&self) -> Result<()> {
        let timeout_duration =
            std::time::Duration::from_secs(self.config.fault_tolerance.worker_timeout);

        let workers = self.workers.read().await;
        let now = Utc::now();

        for (worker_id, worker) in workers.iter() {
            let elapsed = now.signed_duration_since(worker.last_heartbeat);
            if elapsed.num_seconds() as u64 > timeout_duration.as_secs() {
                warn!(
                    "Worker {} timed out (last heartbeat: {:?})",
                    worker_id, worker.last_heartbeat
                );
                // In a real implementation, trigger worker recovery or replacement
            }
        }

        Ok(())
    }
}
559
/// Worker training result
///
/// Internal record produced by `distribute_training_batch` for one worker
/// and consumed by `aggregate_gradients`.
#[derive(Debug, Clone)]
struct WorkerResult {
    // ID of the worker that produced this result.
    worker_id: usize,
    // Epoch the result belongs to.
    epoch: usize,
    // Local training loss for this worker's batch.
    loss: f64,
    // Number of samples the worker processed (currently unused downstream).
    num_samples: usize,
    // Named gradient tensors (empty in the current simulation).
    gradients: HashMap<String, Array1<f32>>,
}
569
570// ─────────────────────────────────────────────────────────────
571// A. Gradient Aggregation & Compression
572// ─────────────────────────────────────────────────────────────
573
/// Strategy for all-reduce gradient aggregation across distributed workers.
///
/// Consumed by [`GradientAggregator::aggregate_gradients`] to choose the
/// reduction topology.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AllReduceStrategy {
    /// Ring-based all-reduce: workers arranged in a ring pass partial sums around.
    RingAllReduce,
    /// Tree-based all-reduce: hierarchical reduction over a binary tree topology.
    TreeAllReduce,
    /// Parameter server: a central server accumulates and broadcasts gradients.
    ParameterServer,
}
584
/// Aggregates gradients from distributed workers.
///
/// Stateless marker type; all logic lives in the associated methods, which
/// simulate the collective operations synchronously in-process.
#[derive(Debug, Clone, Default)]
pub struct GradientAggregator;
588
impl GradientAggregator {
    /// Create a new `GradientAggregator`.
    pub fn new() -> Self {
        Self
    }

    /// Aggregate `local_grad` from this worker together with gradients that have
    /// already been reduced on other workers, using the given `strategy`.
    ///
    /// For the single-worker case the function simply returns a normalised copy of
    /// `local_grad`.  In a real multi-node scenario the caller would pass the
    /// collected per-worker slices through `ring_all_reduce` or the tree variant.
    pub fn aggregate_gradients(
        &self,
        local_grad: &[f64],
        strategy: &AllReduceStrategy,
    ) -> Vec<f64> {
        match strategy {
            AllReduceStrategy::RingAllReduce => {
                // Treat the single local gradient as the only worker contribution.
                self.ring_all_reduce(vec![local_grad.to_vec()])
            }
            AllReduceStrategy::TreeAllReduce => self.tree_all_reduce(vec![local_grad.to_vec()]),
            AllReduceStrategy::ParameterServer => {
                // Parameter-server: accept local grad and average (single worker path).
                local_grad.to_vec()
            }
        }
    }

    /// Simulate ring all-reduce over a set of per-worker gradient vectors.
    ///
    /// Ring all-reduce arranges `n` workers in a ring.  It runs in two phases:
    ///
    /// 1. **Scatter-reduce** (`n−1` steps): at step `s`, each worker `w` passes
    ///    the accumulated data for chunk `(w − s) mod n` to its right neighbour
    ///    `(w + 1) mod n`, which adds it to its own copy.  After `n−1` steps,
    ///    worker `w` holds the fully-reduced sum for chunk `(w + 1) mod n`.
    ///
    /// 2. **All-gather**: collect the fully-reduced chunk from each owning worker
    ///    and divide by `n` to obtain the mean.
    ///
    /// The mathematical result equals the element-wise mean of all input vectors.
    /// This simulation runs synchronously on the calling thread with no I/O;
    /// it clones the full partial-sum state once per scatter-reduce step, so
    /// the cost is O(n² · len) — fine for a simulation, not for production.
    ///
    /// # Panics
    ///
    /// Panics if the per-worker vectors do not all have the same length:
    /// chunk boundaries are derived from `gradients[0].len()` and each
    /// worker's slice `g[s..s + sz]` is taken against those boundaries.
    /// (Contrast with `tree_all_reduce`, which tolerates shorter vectors.)
    pub fn ring_all_reduce(&self, gradients: Vec<Vec<f64>>) -> Vec<f64> {
        let n = gradients.len();
        if n == 0 {
            return Vec::new();
        }
        if n == 1 {
            return gradients.into_iter().next().unwrap_or_default();
        }

        let len = gradients[0].len();
        if len == 0 {
            return Vec::new();
        }

        // Divide the gradient vector into `n` chunks.  Chunks are sized as
        // evenly as possible; the first `remainder` chunks get one extra element.
        let base = len / n;
        let remainder = len % n;
        let chunk_sizes: Vec<usize> = (0..n)
            .map(|i| base + if i < remainder { 1 } else { 0 })
            .collect();
        let mut chunk_start = vec![0usize; n];
        for i in 1..n {
            chunk_start[i] = chunk_start[i - 1] + chunk_sizes[i - 1];
        }

        // `partial[w][c]` = partial sums of chunk `c` accumulated on worker `w`.
        // Initially each worker contributes its own slice of the gradient.
        let mut partial: Vec<Vec<Vec<f64>>> = gradients
            .iter()
            .map(|g| {
                chunk_sizes
                    .iter()
                    .zip(chunk_start.iter())
                    .map(|(&sz, &s)| g[s..s + sz].to_vec())
                    .collect()
            })
            .collect();

        // ── scatter-reduce phase ──────────────────────────────────────────────
        // At each step, worker w receives from its left neighbour (w−1) the
        // partial accumulation for chunk `(w − 1 − step) mod n`.
        // `prev` snapshots the state at the start of the step so all workers
        // "receive" values from the same logical round.
        #[allow(clippy::needless_range_loop)]
        for step in 0..(n - 1) {
            let prev = partial.clone();
            for w in 0..n {
                let left = (w + n - 1) % n;
                // `+ n` keeps the subtraction from underflowing on usize.
                let c = (w + n - 1 - step) % n;
                let sz = chunk_sizes[c];
                for i in 0..sz {
                    partial[w][c][i] += prev[left][c][i];
                }
            }
        }

        // After `n−1` scatter-reduce steps, worker `w` holds the fully-reduced
        // sum in slot `(w + 1) mod n`.

        // ── collect result (all-gather) ───────────────────────────────────────
        let mut result = vec![0.0_f64; len];
        #[allow(clippy::needless_range_loop)]
        for w in 0..n {
            let c = (w + 1) % n;
            let s = chunk_start[c];
            let sz = chunk_sizes[c];
            for i in 0..sz {
                // Divide by n to turn the reduced sum into the mean.
                result[s + i] = partial[w][c][i] / n as f64;
            }
        }

        result
    }

    /// Simulate tree (binary) all-reduce: recursive halving/doubling.
    ///
    /// The simulation collapses the tree to a direct element-wise sum followed
    /// by division by the worker count — numerically the same mean as the
    /// ring variant.  Vectors shorter than `gradients[0]` are tolerated: the
    /// `i < len` guard simply skips their missing tail positions.
    fn tree_all_reduce(&self, gradients: Vec<Vec<f64>>) -> Vec<f64> {
        let n = gradients.len();
        if n == 0 {
            return Vec::new();
        }
        if n == 1 {
            return gradients.into_iter().next().unwrap_or_default();
        }

        let len = gradients[0].len();
        let mut sums = vec![0.0_f64; len];
        for grad in &gradients {
            for (i, v) in grad.iter().enumerate() {
                if i < len {
                    sums[i] += v;
                }
            }
        }
        sums.iter_mut().for_each(|v| *v /= n as f64);
        sums
    }
}
729
/// Sparse gradient representation after top-k sparsification.
///
/// Produced by [`GradientCompressor::compress`]; `indices` are sorted
/// ascending and pair positionally with `values`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseGradient {
    /// Indices of the non-zero (kept) elements in the original gradient vector.
    pub indices: Vec<usize>,
    /// Values at the kept indices.
    pub values: Vec<f64>,
    /// Length of the original (dense) gradient vector.
    pub original_len: usize,
}
740
/// Compresses gradient vectors via top-k sparsification to reduce communication overhead.
///
/// Stateless marker type; see `compress` / `decompress` for the round-trip.
#[derive(Debug, Clone, Default)]
pub struct GradientCompressor;
744
745impl GradientCompressor {
746    /// Create a new `GradientCompressor`.
747    pub fn new() -> Self {
748        Self
749    }
750
751    /// Compress `grad` by retaining only the top-k largest-magnitude entries.
752    ///
753    /// * `sparsity` — fraction of entries to **zero out** (e.g. `0.9` keeps the top 10%).
754    ///   Clamped to `[0.0, 1.0)`.
755    pub fn compress(&self, grad: &[f64], sparsity: f64) -> SparseGradient {
756        let sparsity = sparsity.clamp(0.0, 0.9999);
757        let n = grad.len();
758        if n == 0 {
759            return SparseGradient {
760                indices: Vec::new(),
761                values: Vec::new(),
762                original_len: 0,
763            };
764        }
765
766        let keep = ((1.0 - sparsity) * n as f64).ceil() as usize;
767        let keep = keep.max(1).min(n);
768
769        // Collect (index, |value|) pairs and sort descending by magnitude.
770        let mut indexed: Vec<(usize, f64)> = grad
771            .iter()
772            .enumerate()
773            .map(|(i, &v)| (i, v.abs()))
774            .collect();
775        indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
776
777        let mut indices: Vec<usize> = indexed[..keep].iter().map(|(i, _)| *i).collect();
778        indices.sort_unstable();
779
780        let values: Vec<f64> = indices.iter().map(|&i| grad[i]).collect();
781
782        SparseGradient {
783            indices,
784            values,
785            original_len: n,
786        }
787    }
788
789    /// Decompress a `SparseGradient` back into a dense gradient vector (zero-filled elsewhere).
790    pub fn decompress(&self, sparse: &SparseGradient) -> Vec<f64> {
791        let mut dense = vec![0.0_f64; sparse.original_len];
792        for (&idx, &val) in sparse.indices.iter().zip(sparse.values.iter()) {
793            if idx < sparse.original_len {
794                dense[idx] = val;
795            }
796        }
797        dense
798    }
799}
800
801// ─────────────────────────────────────────────────────────────
802// B. Data-Parallel Training Coordinator
803// ─────────────────────────────────────────────────────────────
804
/// A training sample for data-parallel distribution.
///
/// Deliberately kept generic so it can represent any supervised/self-supervised sample.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingSample {
    /// Numeric feature vector for this sample.
    pub features: Vec<f64>,
    /// Scalar label or target value.
    pub label: f64,
    /// Optional sample weight (defaults to `1.0` if `None`).
    pub weight: Option<f64>,
}
817
818impl DistributedTrainingSample {
819    /// Create a new sample with equal weight.
820    pub fn new(features: Vec<f64>, label: f64) -> Self {
821        Self {
822            features,
823            label,
824            weight: None,
825        }
826    }
827}
828
/// Per-worker gradient update produced after a local forward-backward pass.
///
/// Fed into [`DataParallelTrainer::merge_worker_updates`]; `samples_processed`
/// determines this worker's weight in the merged result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerUpdate {
    /// Identifier of the worker that computed this update.
    pub worker_id: u32,
    /// Flattened gradient vector from this worker.
    pub gradients: Vec<f64>,
    /// Training loss on the local mini-batch.
    pub loss: f64,
    /// Number of samples processed in this update.
    pub samples_processed: u32,
}
841
/// Merged model update produced after aggregating all worker updates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUpdate {
    /// Averaged gradient vector across all workers (weighted by sample count).
    pub averaged_gradients: Vec<f64>,
    /// Weighted-mean loss across all workers.
    pub mean_loss: f64,
    /// Total number of samples processed.
    pub total_samples: u32,
}
852
/// Coordinates data-parallel training by splitting batches and merging worker gradients.
///
/// Stateless marker type; see `split_batch` and `merge_worker_updates`.
#[derive(Debug, Clone, Default)]
pub struct DataParallelTrainer;
856
857impl DataParallelTrainer {
858    /// Create a new `DataParallelTrainer`.
859    pub fn new() -> Self {
860        Self
861    }
862
863    /// Evenly split `data` across `n_workers` workers.
864    ///
865    /// Returns a `Vec` of sub-batches, one per worker.  If `data.len()` is not
866    /// evenly divisible some workers receive one extra sample (round-robin
867    /// assignment).
868    pub fn split_batch(
869        &self,
870        data: &[DistributedTrainingSample],
871        n_workers: u32,
872    ) -> Vec<Vec<DistributedTrainingSample>> {
873        let n = n_workers as usize;
874        if n == 0 || data.is_empty() {
875            return Vec::new();
876        }
877
878        let mut buckets: Vec<Vec<DistributedTrainingSample>> = (0..n).map(|_| Vec::new()).collect();
879        for (i, sample) in data.iter().enumerate() {
880            buckets[i % n].push(sample.clone());
881        }
882        buckets
883    }
884
885    /// Merge gradient updates from all workers into a single `ModelUpdate`.
886    ///
887    /// Gradients are averaged weighted by `samples_processed` so that workers
888    /// with larger mini-batches contribute proportionally more.
889    pub fn merge_worker_updates(&self, updates: Vec<WorkerUpdate>) -> ModelUpdate {
890        if updates.is_empty() {
891            return ModelUpdate {
892                averaged_gradients: Vec::new(),
893                mean_loss: 0.0,
894                total_samples: 0,
895            };
896        }
897
898        let total_samples: u32 = updates.iter().map(|u| u.samples_processed).sum();
899        if total_samples == 0 {
900            return ModelUpdate {
901                averaged_gradients: Vec::new(),
902                mean_loss: 0.0,
903                total_samples: 0,
904            };
905        }
906
907        // Determine gradient length from the first update with non-empty gradients.
908        let grad_len = updates.iter().map(|u| u.gradients.len()).max().unwrap_or(0);
909
910        let mut averaged_gradients = vec![0.0_f64; grad_len];
911        let mut weighted_loss = 0.0_f64;
912
913        for update in &updates {
914            let weight = update.samples_processed as f64 / total_samples as f64;
915            for (i, &g) in update.gradients.iter().enumerate() {
916                if i < grad_len {
917                    averaged_gradients[i] += g * weight;
918                }
919            }
920            weighted_loss += update.loss * weight;
921        }
922
923        ModelUpdate {
924            averaged_gradients,
925            mean_loss: weighted_loss,
926            total_samples,
927        }
928    }
929}
930
/// Distributed embedding model trainer
///
/// Owns an [`EmbeddingModel`] together with the coordinator that drives its
/// distributed training (worker registration, epochs, statistics).
pub struct DistributedEmbeddingTrainer<M: EmbeddingModel> {
    /// The embedding model being trained (updated in place during `train`).
    model: M,
    /// Coordinator handling worker registration, the training loop, and stats.
    coordinator: DistributedTrainingCoordinator,
}
936
impl<M: EmbeddingModel> DistributedEmbeddingTrainer<M> {
    /// Create a new distributed trainer
    ///
    /// Builds a [`DistributedTrainingCoordinator`] from `config` and takes
    /// ownership of `model`.
    ///
    /// # Errors
    ///
    /// Propagates any error from coordinator construction.
    pub async fn new(model: M, config: DistributedTrainingConfig) -> Result<Self> {
        let coordinator = DistributedTrainingCoordinator::new(config).await?;

        Ok(Self { model, coordinator })
    }

    /// Train the model in a distributed manner
    ///
    /// Delegates the full `epochs`-epoch training loop to the coordinator,
    /// passing the owned model mutably so updates are applied in place.
    pub async fn train(&mut self, epochs: usize) -> Result<DistributedTrainingStats> {
        self.coordinator.train(&mut self.model, epochs).await
    }

    /// Get the trained model
    pub fn model(&self) -> &M {
        &self.model
    }

    /// Get mutable reference to the model
    pub fn model_mut(&mut self) -> &mut M {
        &mut self.model
    }

    /// Register a worker
    ///
    /// Forwards `worker_info` to the coordinator's worker registry.
    pub async fn register_worker(&self, worker_info: WorkerInfo) -> Result<()> {
        self.coordinator.register_worker(worker_info).await
    }

    /// Get training statistics
    ///
    /// Snapshot of the coordinator's current training stats.
    pub async fn get_stats(&self) -> DistributedTrainingStats {
        self.coordinator.get_stats().await
    }
}
970
#[cfg(test)]
mod tests {
    //! Unit tests covering gradient aggregation, gradient compression,
    //! data-parallel batch splitting/merging, and the async coordinator
    //! and trainer lifecycle.

    use super::*;
    use crate::{ModelConfig, TransE};

    // ── AllReduceStrategy & GradientAggregator ────────────────────────────────

    // Every strategy must preserve gradient dimensionality.
    #[test]
    fn test_all_reduce_strategy_variants() {
        let strategies = [
            AllReduceStrategy::RingAllReduce,
            AllReduceStrategy::TreeAllReduce,
            AllReduceStrategy::ParameterServer,
        ];
        for s in &strategies {
            let agg = GradientAggregator::new();
            let grad = vec![1.0, 2.0, 3.0];
            let result = agg.aggregate_gradients(&grad, s);
            assert_eq!(result.len(), 3);
        }
    }

    // With a single worker the "mean" is just that worker's gradient.
    #[test]
    fn test_ring_all_reduce_single_worker() {
        let agg = GradientAggregator::new();
        let grads = vec![vec![1.0, 2.0, 3.0]];
        let result = agg.ring_all_reduce(grads);
        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_ring_all_reduce_two_workers() {
        let agg = GradientAggregator::new();
        let grads = vec![vec![2.0, 4.0, 6.0], vec![2.0, 4.0, 6.0]];
        let result = agg.ring_all_reduce(grads);
        assert_eq!(result.len(), 3);
        // Mean of equal vectors should be the vector itself.
        for (r, expected) in result.iter().zip([2.0, 4.0, 6.0].iter()) {
            assert!((r - expected).abs() < 1e-9, "expected {expected}, got {r}");
        }
    }

    #[test]
    fn test_ring_all_reduce_four_workers_mean() {
        let agg = GradientAggregator::new();
        let grads = vec![
            vec![4.0, 8.0],
            vec![2.0, 4.0],
            vec![0.0, 0.0],
            vec![6.0, 12.0],
        ];
        let result = agg.ring_all_reduce(grads);
        assert_eq!(result.len(), 2);
        // Mean: (4+2+0+6)/4 = 3, (8+4+0+12)/4 = 6
        assert!((result[0] - 3.0).abs() < 1e-6);
        assert!((result[1] - 6.0).abs() < 1e-6);
    }

    // Degenerate inputs must not panic: no workers at all…
    #[test]
    fn test_ring_all_reduce_empty_input() {
        let agg = GradientAggregator::new();
        let result = agg.ring_all_reduce(vec![]);
        assert!(result.is_empty());
    }

    // …and workers reporting zero-length gradient vectors.
    #[test]
    fn test_ring_all_reduce_empty_gradient_vectors() {
        let agg = GradientAggregator::new();
        let result = agg.ring_all_reduce(vec![vec![], vec![]]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_aggregate_gradients_ring() {
        let agg = GradientAggregator::new();
        let grad = vec![1.0, 2.0, 3.0, 4.0];
        let result = agg.aggregate_gradients(&grad, &AllReduceStrategy::RingAllReduce);
        assert_eq!(result.len(), 4);
    }

    // NOTE(review): these assume tree and parameter-server aggregation pass a
    // single local gradient through unchanged — confirm against the
    // GradientAggregator implementation.
    #[test]
    fn test_aggregate_gradients_tree() {
        let agg = GradientAggregator::new();
        let grad = vec![5.0, 10.0];
        let result = agg.aggregate_gradients(&grad, &AllReduceStrategy::TreeAllReduce);
        assert_eq!(result, vec![5.0, 10.0]);
    }

    #[test]
    fn test_aggregate_gradients_parameter_server() {
        let agg = GradientAggregator::new();
        let grad = vec![3.0, 1.0, 4.0];
        let result = agg.aggregate_gradients(&grad, &AllReduceStrategy::ParameterServer);
        assert_eq!(result, grad);
    }

    // ── GradientCompressor ────────────────────────────────────────────────────

    #[test]
    fn test_compress_empty_gradient() {
        let comp = GradientCompressor::new();
        let sparse = comp.compress(&[], 0.9);
        assert!(sparse.indices.is_empty());
        assert_eq!(sparse.original_len, 0);
    }

    #[test]
    fn test_compress_keep_all() {
        let comp = GradientCompressor::new();
        let grad = vec![1.0, -2.0, 3.0, -4.0];
        let sparse = comp.compress(&grad, 0.0);
        // sparsity=0 → keep all
        assert_eq!(sparse.indices.len(), 4);
        assert_eq!(sparse.original_len, 4);
    }

    // Top-k selection is expected to be by magnitude, not signed value.
    #[test]
    fn test_compress_top_k_selects_largest() {
        let comp = GradientCompressor::new();
        let grad = vec![0.1, 5.0, 0.2, 9.0, 0.3];
        // sparsity=0.6 → keep 40% = 2 entries → indices 1 (5.0) and 3 (9.0)
        let sparse = comp.compress(&grad, 0.6);
        assert_eq!(sparse.indices.len(), 2);
        assert!(sparse.indices.contains(&3)); // 9.0
        assert!(sparse.indices.contains(&1)); // 5.0
    }

    // Compression is lossy for dropped entries, but decompression must
    // restore the original length and reproduce the kept entries exactly.
    #[test]
    fn test_decompress_roundtrip() {
        let comp = GradientCompressor::new();
        let grad = vec![0.0, 1.0, 0.0, -3.0, 0.0];
        let sparse = comp.compress(&grad, 0.6);
        let dense = comp.decompress(&sparse);
        assert_eq!(dense.len(), 5);
        // The two largest-magnitude values must be preserved.
        assert!((dense[3] - (-3.0)).abs() < 1e-12);
        assert!((dense[1] - 1.0).abs() < 1e-12);
    }

    // An all-empty sparse gradient decompresses to a zero vector of the
    // recorded original length.
    #[test]
    fn test_decompress_empty_sparse() {
        let comp = GradientCompressor::new();
        let sparse = SparseGradient {
            indices: Vec::new(),
            values: Vec::new(),
            original_len: 5,
        };
        let dense = comp.decompress(&sparse);
        assert_eq!(dense, vec![0.0; 5]);
    }

    // Sparse gradients cross the wire between workers, so the serde JSON
    // round-trip must be exact.
    #[test]
    fn test_sparse_gradient_serialization() {
        let sg = SparseGradient {
            indices: vec![0, 2],
            values: vec![1.5, -2.5],
            original_len: 4,
        };
        let json = serde_json::to_string(&sg).expect("serialize");
        let sg2: SparseGradient = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(sg, sg2);
    }

    // ── DataParallelTrainer ───────────────────────────────────────────────────

    // 8 samples over 4 workers → exactly 2 per worker.
    #[test]
    fn test_split_batch_even() {
        let trainer = DataParallelTrainer::new();
        let samples: Vec<DistributedTrainingSample> = (0..8)
            .map(|i| DistributedTrainingSample::new(vec![i as f64], i as f64))
            .collect();
        let batches = trainer.split_batch(&samples, 4);
        assert_eq!(batches.len(), 4);
        for b in &batches {
            assert_eq!(b.len(), 2);
        }
    }

    // Uneven split: no sample may be dropped or duplicated.
    #[test]
    fn test_split_batch_uneven() {
        let trainer = DataParallelTrainer::new();
        let samples: Vec<DistributedTrainingSample> = (0..10)
            .map(|i| DistributedTrainingSample::new(vec![i as f64], i as f64))
            .collect();
        let batches = trainer.split_batch(&samples, 3);
        assert_eq!(batches.len(), 3);
        let total: usize = batches.iter().map(|b| b.len()).sum();
        assert_eq!(total, 10);
    }

    // Zero workers is a caller error but must degrade gracefully (no panic).
    #[test]
    fn test_split_batch_zero_workers() {
        let trainer = DataParallelTrainer::new();
        let samples = vec![DistributedTrainingSample::new(vec![1.0], 0.0)];
        let batches = trainer.split_batch(&samples, 0);
        assert!(batches.is_empty());
    }

    #[test]
    fn test_split_batch_empty_data() {
        let trainer = DataParallelTrainer::new();
        let batches = trainer.split_batch(&[], 4);
        assert!(batches.is_empty());
    }

    // Equal-weight workers: merged gradient/loss equal each worker's values.
    #[test]
    fn test_merge_worker_updates_basic() {
        let trainer = DataParallelTrainer::new();
        let updates = vec![
            WorkerUpdate {
                worker_id: 0,
                gradients: vec![2.0, 4.0],
                loss: 1.0,
                samples_processed: 10,
            },
            WorkerUpdate {
                worker_id: 1,
                gradients: vec![2.0, 4.0],
                loss: 1.0,
                samples_processed: 10,
            },
        ];
        let merged = trainer.merge_worker_updates(updates);
        assert_eq!(merged.total_samples, 20);
        assert!((merged.mean_loss - 1.0).abs() < 1e-9);
        assert!((merged.averaged_gradients[0] - 2.0).abs() < 1e-9);
        assert!((merged.averaged_gradients[1] - 4.0).abs() < 1e-9);
    }

    // Unequal sample counts: averaging must weight by samples_processed.
    #[test]
    fn test_merge_worker_updates_weighted() {
        let trainer = DataParallelTrainer::new();
        // Worker 0 has 1 sample, worker 1 has 3 samples.
        let updates = vec![
            WorkerUpdate {
                worker_id: 0,
                gradients: vec![4.0],
                loss: 2.0,
                samples_processed: 1,
            },
            WorkerUpdate {
                worker_id: 1,
                gradients: vec![0.0],
                loss: 0.0,
                samples_processed: 3,
            },
        ];
        let merged = trainer.merge_worker_updates(updates);
        assert_eq!(merged.total_samples, 4);
        // Weighted mean gradient: 4*0.25 + 0*0.75 = 1.0
        assert!((merged.averaged_gradients[0] - 1.0).abs() < 1e-9);
        // Weighted mean loss: 2*0.25 + 0*0.75 = 0.5
        assert!((merged.mean_loss - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_merge_worker_updates_empty() {
        let trainer = DataParallelTrainer::new();
        let merged = trainer.merge_worker_updates(vec![]);
        assert_eq!(merged.total_samples, 0);
        assert!(merged.averaged_gradients.is_empty());
    }

    // Worker updates are sent between nodes, so serde must round-trip them.
    #[test]
    fn test_worker_update_serialization() {
        let update = WorkerUpdate {
            worker_id: 7,
            gradients: vec![0.1, -0.2],
            loss: 0.42,
            samples_processed: 32,
        };
        let json = serde_json::to_string(&update).expect("serialize");
        let update2: WorkerUpdate = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(update.worker_id, update2.worker_id);
        assert_eq!(update.samples_processed, update2.samples_processed);
    }

    #[test]
    fn test_model_update_fields() {
        let mu = ModelUpdate {
            averaged_gradients: vec![1.0, 2.0],
            mean_loss: 0.5,
            total_samples: 100,
        };
        assert_eq!(mu.total_samples, 100);
        assert!((mu.mean_loss - 0.5).abs() < 1e-12);
    }

    // NOTE(review): the async tests below assume the coordinator runs fully
    // in-process (no real network or GPU workers) — confirm against the
    // DistributedTrainingCoordinator implementation.
    #[tokio::test]
    async fn test_distributed_coordinator_creation() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await;
        assert!(coordinator.is_ok());
    }

    #[tokio::test]
    async fn test_worker_registration() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await.unwrap();

        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };

        coordinator.register_worker(worker).await.unwrap();
        let stats = coordinator.get_worker_stats().await;
        assert_eq!(stats.len(), 1);
    }

    // End-to-end: register two workers, train 5 epochs, check the stats.
    #[tokio::test]
    async fn test_distributed_training() {
        let config = DistributedTrainingConfig {
            strategy: DistributedStrategy::DataParallel {
                num_workers: 2,
                batch_size: 128,
            },
            ..Default::default()
        };

        let model_config = ModelConfig::default().with_dimensions(64);
        let model = TransE::new(model_config);

        let mut trainer = DistributedEmbeddingTrainer::new(model, config)
            .await
            .unwrap();

        // Register workers
        for i in 0..2 {
            let worker = WorkerInfo {
                worker_id: i,
                rank: i,
                address: format!("127.0.0.1:808{}", i),
                status: WorkerStatus::Idle,
                num_gpus: 1,
                memory_gb: 16.0,
                last_heartbeat: Utc::now(),
            };
            trainer.register_worker(worker).await.unwrap();
        }

        // Train for a few epochs
        let stats = trainer.train(5).await.unwrap();

        assert_eq!(stats.total_epochs, 5);
        assert!(stats.final_loss >= 0.0);
        assert_eq!(stats.num_workers, 2);
    }

    // Checkpoints must round-trip epoch and loss under the expected name
    // ("checkpoint_epoch_<epoch>").
    #[tokio::test]
    async fn test_checkpoint_save_load() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await.unwrap();

        let model_config = ModelConfig::default();
        let model = TransE::new(model_config);

        // Register a worker
        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };
        coordinator.register_worker(worker).await.unwrap();

        // Save checkpoint
        coordinator.save_checkpoint(&model, 10, 0.5).await.unwrap();

        // Load checkpoint
        let checkpoint = coordinator
            .load_checkpoint("checkpoint_epoch_10")
            .await
            .unwrap();
        assert_eq!(checkpoint.epoch, 10);
        assert_eq!(checkpoint.loss, 0.5);
    }
}