Skip to main content

oxirs_embed/distributed_training/
mod.rs

//! Distributed Training Module for Knowledge Graph Embeddings
//!
//! This module provides distributed training capabilities for knowledge graph embeddings
//! across multiple nodes/GPUs using data-parallelism and model-parallelism strategies.
//!
//! ## Features
//!
//! - **Data Parallelism**: Distribute training data across multiple workers
//! - **Model Parallelism**: Split large models across multiple devices
//! - **Gradient Aggregation**: AllReduce, Parameter Server, Ring-AllReduce
//! - **Fault Tolerance**: Checkpointing, recovery, and elastic scaling
//! - **Communication**: Efficient gradient synchronization with compression
//! - **Load Balancing**: Dynamic workload distribution
//! - **Monitoring**: Real-time training metrics and performance tracking
//! - **Parameter-Server Prototype** ([`parameter_server`], [`worker`], [`shard_manager`]):
//!   an in-process toy parameter server with sharded embeddings, sync/async update modes,
//!   and a `ModelShardManager` that partitions entity tables by entity-ID hash. Bounded to
//!   4-8 workers; not a full DistBelief/Horovod replacement.
//!
//! Note: `DistributedTrainingCoordinator` currently simulates worker computation
//! in-process; gradient exchange and checkpoint serialization are still placeholders.
//!
//! ## Architecture
//!
//! ```text
//! ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
//! │  Worker 1   │────▶│  Coordinator│◀────│  Worker 2   │
//! │ (GPU/CPU)   │     │   (Master)  │     │ (GPU/CPU)   │
//! └─────────────┘     └─────────────┘     └─────────────┘
//!       │                    │                    │
//!       └────────────────────┴────────────────────┘
//!                   Gradient Sync
//! ```
31
32use anyhow::Result;
33use chrono::{DateTime, Utc};
34use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36use std::sync::Arc;
37use tokio::sync::{Mutex, RwLock};
38use tracing::{debug, info, warn};
39
40// Use SciRS2 for distributed computing
41use scirs2_core::distributed::{ClusterConfiguration, ClusterManager};
42use scirs2_core::ndarray_ext::Array1;
43
44use crate::EmbeddingModel;
45
46// ── Parameter-server-style distributed training prototype ────────────────────
47pub mod parameter_server;
48pub mod shard_manager;
49pub mod worker;
50
51pub use parameter_server::{
52    ParameterServer, ParameterServerConfig, ParameterServerStats, ShardSnapshot, UpdateMode,
53};
54pub use shard_manager::{ModelShardManager, ShardAssignment, ShardingStrategy};
55pub use worker::{TripleSample, Worker, WorkerConfig, WorkerLoss};
56
57/// Distributed training strategy
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub enum DistributedStrategy {
60    /// Data parallelism - split data across workers
61    DataParallel {
62        /// Number of workers
63        num_workers: usize,
64        /// Batch size per worker
65        batch_size: usize,
66    },
67    /// Model parallelism - split model across workers
68    ModelParallel {
69        /// Number of model shards
70        num_shards: usize,
71        /// Pipeline stages
72        pipeline_stages: usize,
73    },
74    /// Hybrid parallelism - combine data and model parallelism
75    Hybrid {
76        /// Data parallel degree
77        data_parallel_size: usize,
78        /// Model parallel degree
79        model_parallel_size: usize,
80    },
81}
82
83/// Gradient aggregation method
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub enum AggregationMethod {
86    /// AllReduce - all workers exchange gradients
87    AllReduce,
88    /// Ring-AllReduce - efficient ring-based gradient exchange
89    RingAllReduce,
90    /// Parameter Server - centralized gradient aggregation
91    ParameterServer {
92        /// Number of parameter servers
93        num_servers: usize,
94    },
95    /// Hierarchical - tree-based aggregation
96    Hierarchical {
97        /// Tree branching factor
98        branching_factor: usize,
99    },
100}
101
102/// Communication backend for distributed training
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub enum CommunicationBackend {
105    /// Native TCP/IP
106    Tcp,
107    /// NCCL (NVIDIA Collective Communications Library)
108    Nccl,
109    /// Gloo (Facebook's collective communications)
110    Gloo,
111    /// MPI (Message Passing Interface)
112    Mpi,
113}
114
115/// Fault tolerance configuration
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct FaultToleranceConfig {
118    /// Enable checkpointing
119    pub enable_checkpointing: bool,
120    /// Checkpoint frequency (in epochs)
121    pub checkpoint_frequency: usize,
122    /// Maximum retry attempts
123    pub max_retries: usize,
124    /// Enable elastic scaling
125    pub elastic_scaling: bool,
126    /// Heartbeat interval (seconds)
127    pub heartbeat_interval: u64,
128    /// Worker timeout (seconds)
129    pub worker_timeout: u64,
130}
131
132impl Default for FaultToleranceConfig {
133    fn default() -> Self {
134        Self {
135            enable_checkpointing: true,
136            checkpoint_frequency: 10,
137            max_retries: 3,
138            elastic_scaling: false,
139            heartbeat_interval: 30,
140            worker_timeout: 300,
141        }
142    }
143}
144
145/// Distributed training configuration
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct DistributedTrainingConfig {
148    /// Distributed strategy
149    pub strategy: DistributedStrategy,
150    /// Gradient aggregation method
151    pub aggregation: AggregationMethod,
152    /// Communication backend
153    pub backend: CommunicationBackend,
154    /// Fault tolerance configuration
155    pub fault_tolerance: FaultToleranceConfig,
156    /// Enable gradient compression
157    pub gradient_compression: bool,
158    /// Compression ratio (0.0-1.0)
159    pub compression_ratio: f32,
160    /// Enable mixed precision training
161    pub mixed_precision: bool,
162    /// Gradient clipping threshold
163    pub gradient_clip: Option<f32>,
164    /// Warmup epochs before full distribution
165    pub warmup_epochs: usize,
166    /// Enable pipeline parallelism
167    pub pipeline_parallelism: bool,
168    /// Number of microbatches for pipeline
169    pub num_microbatches: usize,
170}
171
172impl Default for DistributedTrainingConfig {
173    fn default() -> Self {
174        Self {
175            strategy: DistributedStrategy::DataParallel {
176                num_workers: 4,
177                batch_size: 256,
178            },
179            aggregation: AggregationMethod::AllReduce,
180            backend: CommunicationBackend::Tcp,
181            fault_tolerance: FaultToleranceConfig::default(),
182            gradient_compression: false,
183            compression_ratio: 0.5,
184            mixed_precision: false,
185            gradient_clip: Some(1.0),
186            warmup_epochs: 5,
187            pipeline_parallelism: false,
188            num_microbatches: 4,
189        }
190    }
191}
192
193/// Worker information
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct WorkerInfo {
196    /// Worker ID
197    pub worker_id: usize,
198    /// Worker rank (global)
199    pub rank: usize,
200    /// Worker address
201    pub address: String,
202    /// Worker status
203    pub status: WorkerStatus,
204    /// Number of GPUs available
205    pub num_gpus: usize,
206    /// Memory capacity (GB)
207    pub memory_gb: f32,
208    /// Last heartbeat timestamp
209    pub last_heartbeat: DateTime<Utc>,
210}
211
212/// Worker status
213#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
214pub enum WorkerStatus {
215    /// Worker is idle
216    Idle,
217    /// Worker is training
218    Training,
219    /// Worker is synchronizing
220    Synchronizing,
221    /// Worker has failed
222    Failed,
223    /// Worker is recovering
224    Recovering,
225}
226
227/// Training checkpoint
228#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct TrainingCheckpoint {
230    /// Checkpoint ID
231    pub checkpoint_id: String,
232    /// Epoch number
233    pub epoch: usize,
234    /// Global step
235    pub global_step: usize,
236    /// Model state (serialized)
237    pub model_state: Vec<u8>,
238    /// Optimizer state (serialized)
239    pub optimizer_state: Vec<u8>,
240    /// Training loss
241    pub loss: f64,
242    /// Timestamp
243    pub timestamp: DateTime<Utc>,
244}
245
246/// Distributed training statistics
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct DistributedTrainingStats {
249    /// Total epochs
250    pub total_epochs: usize,
251    /// Total steps
252    pub total_steps: usize,
253    /// Final loss
254    pub final_loss: f64,
255    /// Training time (seconds)
256    pub training_time: f64,
257    /// Number of workers
258    pub num_workers: usize,
259    /// Average throughput (samples/sec)
260    pub throughput: f64,
261    /// Communication time (seconds)
262    pub communication_time: f64,
263    /// Computation time (seconds)
264    pub computation_time: f64,
265    /// Number of checkpoints saved
266    pub num_checkpoints: usize,
267    /// Number of worker failures
268    pub num_failures: usize,
269    /// Loss history per epoch
270    pub loss_history: Vec<f64>,
271}
272
273/// Distributed training coordinator
274pub struct DistributedTrainingCoordinator {
275    config: DistributedTrainingConfig,
276    workers: Arc<RwLock<HashMap<usize, WorkerInfo>>>,
277    checkpoints: Arc<Mutex<Vec<TrainingCheckpoint>>>,
278    cluster_manager: Arc<ClusterManager>,
279    stats: Arc<Mutex<DistributedTrainingStats>>,
280}
281
282impl DistributedTrainingCoordinator {
283    /// Create a new distributed training coordinator
284    pub async fn new(config: DistributedTrainingConfig) -> Result<Self> {
285        info!("Initializing distributed training coordinator");
286
287        // Create cluster configuration
288        let cluster_config = ClusterConfiguration::default();
289        let cluster_manager = Arc::new(
290            ClusterManager::new(cluster_config)
291                .map_err(|e| anyhow::anyhow!("Failed to create cluster manager: {}", e))?,
292        );
293
294        Ok(Self {
295            config,
296            workers: Arc::new(RwLock::new(HashMap::new())),
297            checkpoints: Arc::new(Mutex::new(Vec::new())),
298            cluster_manager,
299            stats: Arc::new(Mutex::new(DistributedTrainingStats {
300                total_epochs: 0,
301                total_steps: 0,
302                final_loss: 0.0,
303                training_time: 0.0,
304                num_workers: 0,
305                throughput: 0.0,
306                communication_time: 0.0,
307                computation_time: 0.0,
308                num_checkpoints: 0,
309                num_failures: 0,
310                loss_history: Vec::new(),
311            })),
312        })
313    }
314
315    /// Register a worker
316    pub async fn register_worker(&self, worker_info: WorkerInfo) -> Result<()> {
317        info!(
318            "Registering worker {}: {}",
319            worker_info.worker_id, worker_info.address
320        );
321
322        let mut workers = self.workers.write().await;
323        workers.insert(worker_info.worker_id, worker_info);
324
325        let mut stats = self.stats.lock().await;
326        stats.num_workers = workers.len();
327
328        Ok(())
329    }
330
331    /// Deregister a worker
332    pub async fn deregister_worker(&self, worker_id: usize) -> Result<()> {
333        warn!("Deregistering worker {}", worker_id);
334
335        let mut workers = self.workers.write().await;
336        workers.remove(&worker_id);
337
338        let mut stats = self.stats.lock().await;
339        stats.num_workers = workers.len();
340        stats.num_failures += 1;
341
342        Ok(())
343    }
344
345    /// Update worker status
346    pub async fn update_worker_status(&self, worker_id: usize, status: WorkerStatus) -> Result<()> {
347        let mut workers = self.workers.write().await;
348        if let Some(worker) = workers.get_mut(&worker_id) {
349            worker.status = status;
350            worker.last_heartbeat = Utc::now();
351        }
352        Ok(())
353    }
354
355    /// Coordinate distributed training
356    pub async fn train<M: EmbeddingModel>(
357        &mut self,
358        model: &mut M,
359        epochs: usize,
360    ) -> Result<DistributedTrainingStats> {
361        info!("Starting distributed training for {} epochs", epochs);
362
363        let start_time = std::time::Instant::now();
364        let mut total_comm_time = 0.0;
365        let mut total_comp_time = 0.0;
366
367        // Initialize distributed optimizer
368        self.initialize_optimizer().await?;
369
370        for epoch in 0..epochs {
371            debug!("Epoch {}/{}", epoch + 1, epochs);
372
373            // Distribute work to workers
374            let comp_start = std::time::Instant::now();
375            let batch_results = self.distribute_training_batch(model, epoch).await?;
376            let comp_time = comp_start.elapsed().as_secs_f64();
377            total_comp_time += comp_time;
378
379            // Aggregate gradients
380            let comm_start = std::time::Instant::now();
381            let avg_loss = self.aggregate_gradients(&batch_results).await?;
382            let comm_time = comm_start.elapsed().as_secs_f64();
383            total_comm_time += comm_time;
384
385            // Update statistics
386            {
387                let mut stats = self.stats.lock().await;
388                stats.total_epochs = epoch + 1;
389                stats.loss_history.push(avg_loss);
390                stats.final_loss = avg_loss;
391            }
392
393            // Save checkpoint if needed
394            if self.config.fault_tolerance.enable_checkpointing
395                && (epoch + 1) % self.config.fault_tolerance.checkpoint_frequency == 0
396            {
397                self.save_checkpoint(model, epoch, avg_loss).await?;
398            }
399
400            info!(
401                "Epoch {}: loss={:.6}, comp_time={:.2}s, comm_time={:.2}s",
402                epoch + 1,
403                avg_loss,
404                comp_time,
405                comm_time
406            );
407        }
408
409        let elapsed = start_time.elapsed().as_secs_f64();
410
411        // Finalize statistics
412        let stats = {
413            let mut stats = self.stats.lock().await;
414            stats.training_time = elapsed;
415            stats.communication_time = total_comm_time;
416            stats.computation_time = total_comp_time;
417            stats.throughput = (epochs as f64) / elapsed;
418            stats.clone()
419        };
420
421        info!("Distributed training completed in {:.2}s", elapsed);
422        info!("Final loss: {:.6}", stats.final_loss);
423        info!("Throughput: {:.2} epochs/sec", stats.throughput);
424
425        Ok(stats)
426    }
427
428    /// Initialize distributed optimizer
429    async fn initialize_optimizer(&mut self) -> Result<()> {
430        debug!("Initializing distributed optimizer");
431
432        // In a real implementation, this would initialize optimizer state
433        // For now, this is a placeholder
434
435        Ok(())
436    }
437
438    /// Distribute training batch to workers
439    async fn distribute_training_batch<M: EmbeddingModel>(
440        &self,
441        _model: &M,
442        epoch: usize,
443    ) -> Result<Vec<WorkerResult>> {
444        let workers = self.workers.read().await;
445        let num_workers = workers.len();
446
447        if num_workers == 0 {
448            return Err(anyhow::anyhow!("No workers available"));
449        }
450
451        // Simulate distributed training (in a real implementation, this would
452        // send batches to workers via network communication)
453        let mut results = Vec::new();
454        for (worker_id, _) in workers.iter() {
455            results.push(WorkerResult {
456                worker_id: *worker_id,
457                epoch,
458                loss: 0.1 * (1.0 - epoch as f64 / 100.0).max(0.01),
459                num_samples: 1000,
460                gradients: HashMap::new(),
461            });
462        }
463
464        Ok(results)
465    }
466
467    /// Aggregate gradients from workers
468    async fn aggregate_gradients(&self, results: &[WorkerResult]) -> Result<f64> {
469        if results.is_empty() {
470            return Err(anyhow::anyhow!("No results to aggregate"));
471        }
472
473        // Calculate average loss
474        let avg_loss = results.iter().map(|r| r.loss).sum::<f64>() / results.len() as f64;
475
476        // In a real implementation, this would aggregate gradients using
477        // the configured aggregation method (AllReduce, Parameter Server, etc.)
478        match &self.config.aggregation {
479            AggregationMethod::AllReduce => {
480                debug!("Using AllReduce for gradient aggregation");
481                // Use distributed aggregation
482                // In production, implement actual AllReduce algorithm
483            }
484            AggregationMethod::RingAllReduce => {
485                debug!("Using Ring-AllReduce for gradient aggregation");
486                // Implement ring-based gradient exchange
487            }
488            AggregationMethod::ParameterServer { num_servers } => {
489                debug!("Using Parameter Server with {} servers", num_servers);
490                // Implement parameter server aggregation
491            }
492            AggregationMethod::Hierarchical { branching_factor } => {
493                debug!(
494                    "Using Hierarchical aggregation with branching factor {}",
495                    branching_factor
496                );
497                // Implement tree-based aggregation
498            }
499        }
500
501        Ok(avg_loss)
502    }
503
504    /// Save training checkpoint
505    async fn save_checkpoint<M: EmbeddingModel>(
506        &self,
507        _model: &M,
508        epoch: usize,
509        loss: f64,
510    ) -> Result<()> {
511        info!("Saving checkpoint at epoch {}", epoch);
512
513        let checkpoint = TrainingCheckpoint {
514            checkpoint_id: format!("checkpoint_epoch_{}", epoch),
515            epoch,
516            global_step: epoch * 1000,   // Simplified
517            model_state: Vec::new(),     // In real impl, serialize model state
518            optimizer_state: Vec::new(), // In real impl, serialize optimizer state
519            loss,
520            timestamp: Utc::now(),
521        };
522
523        let mut checkpoints = self.checkpoints.lock().await;
524        checkpoints.push(checkpoint);
525
526        let mut stats = self.stats.lock().await;
527        stats.num_checkpoints += 1;
528
529        Ok(())
530    }
531
532    /// Load training checkpoint
533    pub async fn load_checkpoint(&self, checkpoint_id: &str) -> Result<TrainingCheckpoint> {
534        let checkpoints = self.checkpoints.lock().await;
535        checkpoints
536            .iter()
537            .find(|c| c.checkpoint_id == checkpoint_id)
538            .cloned()
539            .ok_or_else(|| anyhow::anyhow!("Checkpoint not found: {}", checkpoint_id))
540    }
541
542    /// Get worker statistics
543    pub async fn get_worker_stats(&self) -> HashMap<usize, WorkerInfo> {
544        self.workers.read().await.clone()
545    }
546
547    /// Get training statistics
548    pub async fn get_stats(&self) -> DistributedTrainingStats {
549        self.stats.lock().await.clone()
550    }
551
552    /// Monitor worker health (heartbeat check)
553    pub async fn monitor_workers(&self) -> Result<()> {
554        let timeout_duration =
555            std::time::Duration::from_secs(self.config.fault_tolerance.worker_timeout);
556
557        let workers = self.workers.read().await;
558        let now = Utc::now();
559
560        for (worker_id, worker) in workers.iter() {
561            let elapsed = now.signed_duration_since(worker.last_heartbeat);
562            if elapsed.num_seconds() as u64 > timeout_duration.as_secs() {
563                warn!(
564                    "Worker {} timed out (last heartbeat: {:?})",
565                    worker_id, worker.last_heartbeat
566                );
567                // In a real implementation, trigger worker recovery or replacement
568            }
569        }
570
571        Ok(())
572    }
573}
574
575/// Worker training result
576#[derive(Debug, Clone)]
577struct WorkerResult {
578    worker_id: usize,
579    epoch: usize,
580    loss: f64,
581    num_samples: usize,
582    gradients: HashMap<String, Array1<f32>>,
583}
584
// ─────────────────────────────────────────────────────────────
// A. Gradient Aggregation & Compression
// ─────────────────────────────────────────────────────────────
588
589/// Strategy for all-reduce gradient aggregation across distributed workers.
590#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
591pub enum AllReduceStrategy {
592    /// Ring-based all-reduce: workers arranged in a ring pass partial sums around.
593    RingAllReduce,
594    /// Tree-based all-reduce: hierarchical reduction over a binary tree topology.
595    TreeAllReduce,
596    /// Parameter server: a central server accumulates and broadcasts gradients.
597    ParameterServer,
598}
599
600/// Aggregates gradients from distributed workers.
601#[derive(Debug, Clone, Default)]
602pub struct GradientAggregator;
603
604impl GradientAggregator {
605    /// Create a new `GradientAggregator`.
606    pub fn new() -> Self {
607        Self
608    }
609
610    /// Aggregate `local_grad` from this worker together with gradients that have
611    /// already been reduced on other workers, using the given `strategy`.
612    ///
613    /// For the single-worker case the function simply returns a normalised copy of
614    /// `local_grad`.  In a real multi-node scenario the caller would pass the
615    /// collected per-worker slices through `ring_all_reduce` or the tree variant.
616    pub fn aggregate_gradients(
617        &self,
618        local_grad: &[f64],
619        strategy: &AllReduceStrategy,
620    ) -> Vec<f64> {
621        match strategy {
622            AllReduceStrategy::RingAllReduce => {
623                // Treat the single local gradient as the only worker contribution.
624                self.ring_all_reduce(vec![local_grad.to_vec()])
625            }
626            AllReduceStrategy::TreeAllReduce => self.tree_all_reduce(vec![local_grad.to_vec()]),
627            AllReduceStrategy::ParameterServer => {
628                // Parameter-server: accept local grad and average (single worker path).
629                local_grad.to_vec()
630            }
631        }
632    }
633
634    /// Simulate ring all-reduce over a set of per-worker gradient vectors.
635    ///
636    /// Ring all-reduce arranges `n` workers in a ring.  It runs in two phases:
637    ///
638    /// 1. **Scatter-reduce** (`n−1` steps): at step `s`, each worker `w` passes
639    ///    the accumulated data for chunk `(w − s) mod n` to its right neighbour
640    ///    `(w + 1) mod n`, which adds it to its own copy.  After `n−1` steps,
641    ///    worker `w` holds the fully-reduced sum for chunk `(w + 1) mod n`.
642    ///
643    /// 2. **All-gather**: collect the fully-reduced chunk from each owning worker
644    ///    and divide by `n` to obtain the mean.
645    ///
646    /// The mathematical result equals the element-wise mean of all input vectors.
647    /// This simulation runs synchronously on the calling thread with no I/O.
648    pub fn ring_all_reduce(&self, gradients: Vec<Vec<f64>>) -> Vec<f64> {
649        let n = gradients.len();
650        if n == 0 {
651            return Vec::new();
652        }
653        if n == 1 {
654            return gradients.into_iter().next().unwrap_or_default();
655        }
656
657        let len = gradients[0].len();
658        if len == 0 {
659            return Vec::new();
660        }
661
662        // Divide the gradient vector into `n` chunks.  Chunks are sized as
663        // evenly as possible; the first `remainder` chunks get one extra element.
664        let base = len / n;
665        let remainder = len % n;
666        let chunk_sizes: Vec<usize> = (0..n)
667            .map(|i| base + if i < remainder { 1 } else { 0 })
668            .collect();
669        let mut chunk_start = vec![0usize; n];
670        for i in 1..n {
671            chunk_start[i] = chunk_start[i - 1] + chunk_sizes[i - 1];
672        }
673
674        // `partial[w][c]` = partial sums of chunk `c` accumulated on worker `w`.
675        // Initially each worker contributes its own slice of the gradient.
676        let mut partial: Vec<Vec<Vec<f64>>> = gradients
677            .iter()
678            .map(|g| {
679                chunk_sizes
680                    .iter()
681                    .zip(chunk_start.iter())
682                    .map(|(&sz, &s)| g[s..s + sz].to_vec())
683                    .collect()
684            })
685            .collect();
686
687        // ── scatter-reduce phase ──────────────────────────────────────────────
688        // At each step, worker w receives from its left neighbour (w−1) the
689        // partial accumulation for chunk `(w − 1 − step) mod n`.
690        #[allow(clippy::needless_range_loop)]
691        for step in 0..(n - 1) {
692            let prev = partial.clone();
693            for w in 0..n {
694                let left = (w + n - 1) % n;
695                let c = (w + n - 1 - step) % n;
696                let sz = chunk_sizes[c];
697                for i in 0..sz {
698                    partial[w][c][i] += prev[left][c][i];
699                }
700            }
701        }
702
703        // After `n−1` scatter-reduce steps, worker `w` holds the fully-reduced
704        // sum in slot `(w + 1) mod n`.
705
706        // ── collect result (all-gather) ───────────────────────────────────────
707        let mut result = vec![0.0_f64; len];
708        #[allow(clippy::needless_range_loop)]
709        for w in 0..n {
710            let c = (w + 1) % n;
711            let s = chunk_start[c];
712            let sz = chunk_sizes[c];
713            for i in 0..sz {
714                result[s + i] = partial[w][c][i] / n as f64;
715            }
716        }
717
718        result
719    }
720
721    /// Simulate tree (binary) all-reduce: recursive halving/doubling.
722    fn tree_all_reduce(&self, gradients: Vec<Vec<f64>>) -> Vec<f64> {
723        let n = gradients.len();
724        if n == 0 {
725            return Vec::new();
726        }
727        if n == 1 {
728            return gradients.into_iter().next().unwrap_or_default();
729        }
730
731        let len = gradients[0].len();
732        let mut sums = vec![0.0_f64; len];
733        for grad in &gradients {
734            for (i, v) in grad.iter().enumerate() {
735                if i < len {
736                    sums[i] += v;
737                }
738            }
739        }
740        sums.iter_mut().for_each(|v| *v /= n as f64);
741        sums
742    }
743}
744
745/// Sparse gradient representation after top-k sparsification.
746#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
747pub struct SparseGradient {
748    /// Indices of the non-zero (kept) elements in the original gradient vector.
749    pub indices: Vec<usize>,
750    /// Values at the kept indices.
751    pub values: Vec<f64>,
752    /// Length of the original (dense) gradient vector.
753    pub original_len: usize,
754}
755
756/// Compresses gradient vectors via top-k sparsification to reduce communication overhead.
757#[derive(Debug, Clone, Default)]
758pub struct GradientCompressor;
759
760impl GradientCompressor {
761    /// Create a new `GradientCompressor`.
762    pub fn new() -> Self {
763        Self
764    }
765
766    /// Compress `grad` by retaining only the top-k largest-magnitude entries.
767    ///
768    /// * `sparsity` — fraction of entries to **zero out** (e.g. `0.9` keeps the top 10%).
769    ///   Clamped to `[0.0, 1.0)`.
770    pub fn compress(&self, grad: &[f64], sparsity: f64) -> SparseGradient {
771        let sparsity = sparsity.clamp(0.0, 0.9999);
772        let n = grad.len();
773        if n == 0 {
774            return SparseGradient {
775                indices: Vec::new(),
776                values: Vec::new(),
777                original_len: 0,
778            };
779        }
780
781        let keep = ((1.0 - sparsity) * n as f64).ceil() as usize;
782        let keep = keep.max(1).min(n);
783
784        // Collect (index, |value|) pairs and sort descending by magnitude.
785        let mut indexed: Vec<(usize, f64)> = grad
786            .iter()
787            .enumerate()
788            .map(|(i, &v)| (i, v.abs()))
789            .collect();
790        indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
791
792        let mut indices: Vec<usize> = indexed[..keep].iter().map(|(i, _)| *i).collect();
793        indices.sort_unstable();
794
795        let values: Vec<f64> = indices.iter().map(|&i| grad[i]).collect();
796
797        SparseGradient {
798            indices,
799            values,
800            original_len: n,
801        }
802    }
803
804    /// Decompress a `SparseGradient` back into a dense gradient vector (zero-filled elsewhere).
805    pub fn decompress(&self, sparse: &SparseGradient) -> Vec<f64> {
806        let mut dense = vec![0.0_f64; sparse.original_len];
807        for (&idx, &val) in sparse.indices.iter().zip(sparse.values.iter()) {
808            if idx < sparse.original_len {
809                dense[idx] = val;
810            }
811        }
812        dense
813    }
814}
815
// ─────────────────────────────────────────────────────────────
// B. Data-Parallel Training Coordinator
// ─────────────────────────────────────────────────────────────
819
820/// A training sample for data-parallel distribution.
821///
822/// Deliberately kept generic so it can represent any supervised/self-supervised sample.
823#[derive(Debug, Clone, Serialize, Deserialize)]
824pub struct DistributedTrainingSample {
825    /// Numeric feature vector for this sample.
826    pub features: Vec<f64>,
827    /// Scalar label or target value.
828    pub label: f64,
829    /// Optional sample weight (defaults to `1.0` if `None`).
830    pub weight: Option<f64>,
831}
832
833impl DistributedTrainingSample {
834    /// Create a new sample with equal weight.
835    pub fn new(features: Vec<f64>, label: f64) -> Self {
836        Self {
837            features,
838            label,
839            weight: None,
840        }
841    }
842}
843
844/// Per-worker gradient update produced after a local forward-backward pass.
845#[derive(Debug, Clone, Serialize, Deserialize)]
846pub struct WorkerUpdate {
847    /// Identifier of the worker that computed this update.
848    pub worker_id: u32,
849    /// Flattened gradient vector from this worker.
850    pub gradients: Vec<f64>,
851    /// Training loss on the local mini-batch.
852    pub loss: f64,
853    /// Number of samples processed in this update.
854    pub samples_processed: u32,
855}
856
857/// Merged model update produced after aggregating all worker updates.
858#[derive(Debug, Clone, Serialize, Deserialize)]
859pub struct ModelUpdate {
860    /// Averaged gradient vector across all workers (weighted by sample count).
861    pub averaged_gradients: Vec<f64>,
862    /// Weighted-mean loss across all workers.
863    pub mean_loss: f64,
864    /// Total number of samples processed.
865    pub total_samples: u32,
866}
867
/// Stateless helper for data-parallel training: it splits batches across
/// workers and merges their gradient updates.
#[derive(Debug, Clone, Default)]
pub struct DataParallelTrainer;
871
872impl DataParallelTrainer {
873    /// Create a new `DataParallelTrainer`.
874    pub fn new() -> Self {
875        Self
876    }
877
878    /// Evenly split `data` across `n_workers` workers.
879    ///
880    /// Returns a `Vec` of sub-batches, one per worker.  If `data.len()` is not
881    /// evenly divisible some workers receive one extra sample (round-robin
882    /// assignment).
883    pub fn split_batch(
884        &self,
885        data: &[DistributedTrainingSample],
886        n_workers: u32,
887    ) -> Vec<Vec<DistributedTrainingSample>> {
888        let n = n_workers as usize;
889        if n == 0 || data.is_empty() {
890            return Vec::new();
891        }
892
893        let mut buckets: Vec<Vec<DistributedTrainingSample>> = (0..n).map(|_| Vec::new()).collect();
894        for (i, sample) in data.iter().enumerate() {
895            buckets[i % n].push(sample.clone());
896        }
897        buckets
898    }
899
900    /// Merge gradient updates from all workers into a single `ModelUpdate`.
901    ///
902    /// Gradients are averaged weighted by `samples_processed` so that workers
903    /// with larger mini-batches contribute proportionally more.
904    pub fn merge_worker_updates(&self, updates: Vec<WorkerUpdate>) -> ModelUpdate {
905        if updates.is_empty() {
906            return ModelUpdate {
907                averaged_gradients: Vec::new(),
908                mean_loss: 0.0,
909                total_samples: 0,
910            };
911        }
912
913        let total_samples: u32 = updates.iter().map(|u| u.samples_processed).sum();
914        if total_samples == 0 {
915            return ModelUpdate {
916                averaged_gradients: Vec::new(),
917                mean_loss: 0.0,
918                total_samples: 0,
919            };
920        }
921
922        // Determine gradient length from the first update with non-empty gradients.
923        let grad_len = updates.iter().map(|u| u.gradients.len()).max().unwrap_or(0);
924
925        let mut averaged_gradients = vec![0.0_f64; grad_len];
926        let mut weighted_loss = 0.0_f64;
927
928        for update in &updates {
929            let weight = update.samples_processed as f64 / total_samples as f64;
930            for (i, &g) in update.gradients.iter().enumerate() {
931                if i < grad_len {
932                    averaged_gradients[i] += g * weight;
933                }
934            }
935            weighted_loss += update.loss * weight;
936        }
937
938        ModelUpdate {
939            averaged_gradients,
940            mean_loss: weighted_loss,
941            total_samples,
942        }
943    }
944}
945
946/// Distributed embedding model trainer
947pub struct DistributedEmbeddingTrainer<M: EmbeddingModel> {
948    model: M,
949    coordinator: DistributedTrainingCoordinator,
950}
951
952impl<M: EmbeddingModel> DistributedEmbeddingTrainer<M> {
953    /// Create a new distributed trainer
954    pub async fn new(model: M, config: DistributedTrainingConfig) -> Result<Self> {
955        let coordinator = DistributedTrainingCoordinator::new(config).await?;
956
957        Ok(Self { model, coordinator })
958    }
959
960    /// Train the model in a distributed manner
961    pub async fn train(&mut self, epochs: usize) -> Result<DistributedTrainingStats> {
962        self.coordinator.train(&mut self.model, epochs).await
963    }
964
965    /// Get the trained model
966    pub fn model(&self) -> &M {
967        &self.model
968    }
969
970    /// Get mutable reference to the model
971    pub fn model_mut(&mut self) -> &mut M {
972        &mut self.model
973    }
974
975    /// Register a worker
976    pub async fn register_worker(&self, worker_info: WorkerInfo) -> Result<()> {
977        self.coordinator.register_worker(worker_info).await
978    }
979
980    /// Get training statistics
981    pub async fn get_stats(&self) -> DistributedTrainingStats {
982        self.coordinator.get_stats().await
983    }
984}
985
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModelConfig, TransE};

    // ── AllReduceStrategy & GradientAggregator ────────────────────────────────

    #[test]
    fn test_all_reduce_strategy_variants() {
        let strategies = [
            AllReduceStrategy::RingAllReduce,
            AllReduceStrategy::TreeAllReduce,
            AllReduceStrategy::ParameterServer,
        ];
        for s in &strategies {
            let agg = GradientAggregator::new();
            let grad = vec![1.0, 2.0, 3.0];
            let result = agg.aggregate_gradients(&grad, s);
            assert_eq!(result.len(), 3);
        }
    }

    #[test]
    fn test_ring_all_reduce_single_worker() {
        let agg = GradientAggregator::new();
        let grads = vec![vec![1.0, 2.0, 3.0]];
        let result = agg.ring_all_reduce(grads);
        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_ring_all_reduce_two_workers() {
        let agg = GradientAggregator::new();
        let grads = vec![vec![2.0, 4.0, 6.0], vec![2.0, 4.0, 6.0]];
        let result = agg.ring_all_reduce(grads);
        assert_eq!(result.len(), 3);
        // Mean of equal vectors should be the vector itself.
        for (r, expected) in result.iter().zip([2.0, 4.0, 6.0].iter()) {
            assert!((r - expected).abs() < 1e-9, "expected {expected}, got {r}");
        }
    }

    #[test]
    fn test_ring_all_reduce_four_workers_mean() {
        let agg = GradientAggregator::new();
        let grads = vec![
            vec![4.0, 8.0],
            vec![2.0, 4.0],
            vec![0.0, 0.0],
            vec![6.0, 12.0],
        ];
        let result = agg.ring_all_reduce(grads);
        assert_eq!(result.len(), 2);
        // Mean: (4+2+0+6)/4 = 3, (8+4+0+12)/4 = 6
        assert!((result[0] - 3.0).abs() < 1e-6);
        assert!((result[1] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_ring_all_reduce_empty_input() {
        let agg = GradientAggregator::new();
        let result = agg.ring_all_reduce(vec![]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_ring_all_reduce_empty_gradient_vectors() {
        let agg = GradientAggregator::new();
        let result = agg.ring_all_reduce(vec![vec![], vec![]]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_aggregate_gradients_ring() {
        let agg = GradientAggregator::new();
        let grad = vec![1.0, 2.0, 3.0, 4.0];
        let result = agg.aggregate_gradients(&grad, &AllReduceStrategy::RingAllReduce);
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_aggregate_gradients_tree() {
        let agg = GradientAggregator::new();
        let grad = vec![5.0, 10.0];
        let result = agg.aggregate_gradients(&grad, &AllReduceStrategy::TreeAllReduce);
        assert_eq!(result, vec![5.0, 10.0]);
    }

    #[test]
    fn test_aggregate_gradients_parameter_server() {
        let agg = GradientAggregator::new();
        let grad = vec![3.0, 1.0, 4.0];
        let result = agg.aggregate_gradients(&grad, &AllReduceStrategy::ParameterServer);
        assert_eq!(result, grad);
    }

    // ── GradientCompressor ────────────────────────────────────────────────────

    #[test]
    fn test_compress_empty_gradient() {
        let comp = GradientCompressor::new();
        let sparse = comp.compress(&[], 0.9);
        assert!(sparse.indices.is_empty());
        assert_eq!(sparse.original_len, 0);
    }

    #[test]
    fn test_compress_keep_all() {
        let comp = GradientCompressor::new();
        let grad = vec![1.0, -2.0, 3.0, -4.0];
        let sparse = comp.compress(&grad, 0.0);
        // sparsity=0 → keep all
        assert_eq!(sparse.indices.len(), 4);
        assert_eq!(sparse.original_len, 4);
    }

    #[test]
    fn test_compress_top_k_selects_largest() {
        let comp = GradientCompressor::new();
        let grad = vec![0.1, 5.0, 0.2, 9.0, 0.3];
        // sparsity=0.6 → keep 40% = 2 entries → indices 1 (5.0) and 3 (9.0)
        let sparse = comp.compress(&grad, 0.6);
        assert_eq!(sparse.indices.len(), 2);
        assert!(sparse.indices.contains(&3)); // 9.0
        assert!(sparse.indices.contains(&1)); // 5.0
    }

    #[test]
    fn test_decompress_roundtrip() {
        let comp = GradientCompressor::new();
        let grad = vec![0.0, 1.0, 0.0, -3.0, 0.0];
        let sparse = comp.compress(&grad, 0.6);
        let dense = comp.decompress(&sparse);
        assert_eq!(dense.len(), 5);
        // The two largest-magnitude values must be preserved.
        assert!((dense[3] - (-3.0)).abs() < 1e-12);
        assert!((dense[1] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_decompress_empty_sparse() {
        let comp = GradientCompressor::new();
        let sparse = SparseGradient {
            indices: Vec::new(),
            values: Vec::new(),
            original_len: 5,
        };
        let dense = comp.decompress(&sparse);
        assert_eq!(dense, vec![0.0; 5]);
    }

    #[test]
    fn test_sparse_gradient_serialization() {
        let sg = SparseGradient {
            indices: vec![0, 2],
            values: vec![1.5, -2.5],
            original_len: 4,
        };
        let json = serde_json::to_string(&sg).expect("serialize");
        let sg2: SparseGradient = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(sg, sg2);
    }

    // ── DistributedTrainingSample ─────────────────────────────────────────────

    #[test]
    fn test_distributed_training_sample_new() {
        let sample = DistributedTrainingSample::new(vec![1.0, 2.0], 3.0);
        assert_eq!(sample.features.len(), 2);
        assert!((sample.features[0] - 1.0).abs() < 1e-12);
        assert!((sample.features[1] - 2.0).abs() < 1e-12);
        assert!((sample.label - 3.0).abs() < 1e-12);
        // `new` leaves the weight unset (meaning the default weight).
        assert!(sample.weight.is_none());

        let json = serde_json::to_string(&sample).expect("serialize");
        let back: DistributedTrainingSample = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back.features.len(), 2);
        assert!((back.label - 3.0).abs() < 1e-12);
        assert!(back.weight.is_none());
    }

    // ── DataParallelTrainer ───────────────────────────────────────────────────

    #[test]
    fn test_split_batch_even() {
        let trainer = DataParallelTrainer::new();
        let samples: Vec<DistributedTrainingSample> = (0..8)
            .map(|i| DistributedTrainingSample::new(vec![i as f64], i as f64))
            .collect();
        let batches = trainer.split_batch(&samples, 4);
        assert_eq!(batches.len(), 4);
        for b in &batches {
            assert_eq!(b.len(), 2);
        }
    }

    #[test]
    fn test_split_batch_uneven() {
        let trainer = DataParallelTrainer::new();
        let samples: Vec<DistributedTrainingSample> = (0..10)
            .map(|i| DistributedTrainingSample::new(vec![i as f64], i as f64))
            .collect();
        let batches = trainer.split_batch(&samples, 3);
        assert_eq!(batches.len(), 3);
        let total: usize = batches.iter().map(|b| b.len()).sum();
        assert_eq!(total, 10);
        // Round-robin: only the first `10 % 3 == 1` worker gets an extra sample.
        let sizes: Vec<usize> = batches.iter().map(|b| b.len()).collect();
        assert_eq!(sizes, vec![4, 3, 3]);
        // Worker 0 receives samples 0, 3, 6 and 9, in order.
        let labels: Vec<f64> = batches[0].iter().map(|s| s.label).collect();
        assert_eq!(labels, vec![0.0, 3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_split_batch_fewer_samples_than_workers() {
        let trainer = DataParallelTrainer::new();
        let samples = vec![
            DistributedTrainingSample::new(vec![0.0], 0.0),
            DistributedTrainingSample::new(vec![1.0], 1.0),
        ];
        let batches = trainer.split_batch(&samples, 4);
        // One bucket per worker; trailing workers get nothing.
        let sizes: Vec<usize> = batches.iter().map(|b| b.len()).collect();
        assert_eq!(sizes, vec![1, 1, 0, 0]);
    }

    #[test]
    fn test_split_batch_zero_workers() {
        let trainer = DataParallelTrainer::new();
        let samples = vec![DistributedTrainingSample::new(vec![1.0], 0.0)];
        let batches = trainer.split_batch(&samples, 0);
        assert!(batches.is_empty());
    }

    #[test]
    fn test_split_batch_empty_data() {
        let trainer = DataParallelTrainer::new();
        let batches = trainer.split_batch(&[], 4);
        assert!(batches.is_empty());
    }

    #[test]
    fn test_merge_worker_updates_basic() {
        let trainer = DataParallelTrainer::new();
        let updates = vec![
            WorkerUpdate {
                worker_id: 0,
                gradients: vec![2.0, 4.0],
                loss: 1.0,
                samples_processed: 10,
            },
            WorkerUpdate {
                worker_id: 1,
                gradients: vec![2.0, 4.0],
                loss: 1.0,
                samples_processed: 10,
            },
        ];
        let merged = trainer.merge_worker_updates(updates);
        assert_eq!(merged.total_samples, 20);
        assert!((merged.mean_loss - 1.0).abs() < 1e-9);
        assert!((merged.averaged_gradients[0] - 2.0).abs() < 1e-9);
        assert!((merged.averaged_gradients[1] - 4.0).abs() < 1e-9);
    }

    #[test]
    fn test_merge_worker_updates_weighted() {
        let trainer = DataParallelTrainer::new();
        // Worker 0 has 1 sample, worker 1 has 3 samples.
        let updates = vec![
            WorkerUpdate {
                worker_id: 0,
                gradients: vec![4.0],
                loss: 2.0,
                samples_processed: 1,
            },
            WorkerUpdate {
                worker_id: 1,
                gradients: vec![0.0],
                loss: 0.0,
                samples_processed: 3,
            },
        ];
        let merged = trainer.merge_worker_updates(updates);
        assert_eq!(merged.total_samples, 4);
        // Weighted mean gradient: 4*0.25 + 0*0.75 = 1.0
        assert!((merged.averaged_gradients[0] - 1.0).abs() < 1e-9);
        // Weighted mean loss: 2*0.25 + 0*0.75 = 0.5
        assert!((merged.mean_loss - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_merge_worker_updates_mismatched_lengths() {
        let trainer = DataParallelTrainer::new();
        let updates = vec![
            WorkerUpdate {
                worker_id: 0,
                gradients: vec![2.0, 2.0],
                loss: 1.0,
                samples_processed: 1,
            },
            WorkerUpdate {
                worker_id: 1,
                gradients: vec![4.0],
                loss: 1.0,
                samples_processed: 1,
            },
        ];
        let merged = trainer.merge_worker_updates(updates);
        // The result is as long as the longest gradient; missing entries count as 0.
        assert_eq!(merged.averaged_gradients.len(), 2);
        assert!((merged.averaged_gradients[0] - 3.0).abs() < 1e-9);
        assert!((merged.averaged_gradients[1] - 1.0).abs() < 1e-9);
        assert_eq!(merged.total_samples, 2);
    }

    #[test]
    fn test_merge_worker_updates_zero_samples() {
        let trainer = DataParallelTrainer::new();
        let updates = vec![WorkerUpdate {
            worker_id: 0,
            gradients: vec![1.0, 2.0],
            loss: 3.0,
            samples_processed: 0,
        }];
        let merged = trainer.merge_worker_updates(updates);
        // No samples means nothing to weight by: the result is the empty update.
        assert_eq!(merged.total_samples, 0);
        assert!(merged.averaged_gradients.is_empty());
        assert!(merged.mean_loss.abs() < 1e-12);
    }

    #[test]
    fn test_merge_worker_updates_empty() {
        let trainer = DataParallelTrainer::new();
        let merged = trainer.merge_worker_updates(vec![]);
        assert_eq!(merged.total_samples, 0);
        assert!(merged.averaged_gradients.is_empty());
        assert!(merged.mean_loss.abs() < 1e-12);
    }

    #[test]
    fn test_worker_update_serialization() {
        let update = WorkerUpdate {
            worker_id: 7,
            gradients: vec![0.1, -0.2],
            loss: 0.42,
            samples_processed: 32,
        };
        let json = serde_json::to_string(&update).expect("serialize");
        let update2: WorkerUpdate = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(update.worker_id, update2.worker_id);
        assert_eq!(update.samples_processed, update2.samples_processed);
        assert_eq!(update.gradients.len(), update2.gradients.len());
        for (a, b) in update.gradients.iter().zip(&update2.gradients) {
            assert!((a - b).abs() < 1e-12, "gradient mismatch: {a} vs {b}");
        }
        assert!((update.loss - update2.loss).abs() < 1e-12);
    }

    #[test]
    fn test_model_update_fields() {
        let mu = ModelUpdate {
            averaged_gradients: vec![1.0, 2.0],
            mean_loss: 0.5,
            total_samples: 100,
        };
        assert_eq!(mu.total_samples, 100);
        assert!((mu.mean_loss - 0.5).abs() < 1e-12);
    }

    #[tokio::test]
    async fn test_distributed_coordinator_creation() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config).await;
        assert!(coordinator.is_ok());
    }

    #[tokio::test]
    async fn test_worker_registration() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config)
            .await
            .expect("should succeed");

        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };

        coordinator
            .register_worker(worker)
            .await
            .expect("should succeed");
        let stats = coordinator.get_worker_stats().await;
        assert_eq!(stats.len(), 1);
    }

    #[tokio::test]
    async fn test_distributed_training() {
        let config = DistributedTrainingConfig {
            strategy: DistributedStrategy::DataParallel {
                num_workers: 2,
                batch_size: 128,
            },
            ..Default::default()
        };

        let model_config = ModelConfig::default().with_dimensions(64);
        let model = TransE::new(model_config);

        let mut trainer = DistributedEmbeddingTrainer::new(model, config)
            .await
            .expect("should succeed");

        // Register workers
        for i in 0..2 {
            let worker = WorkerInfo {
                worker_id: i,
                rank: i,
                address: format!("127.0.0.1:808{}", i),
                status: WorkerStatus::Idle,
                num_gpus: 1,
                memory_gb: 16.0,
                last_heartbeat: Utc::now(),
            };
            trainer
                .register_worker(worker)
                .await
                .expect("should succeed");
        }

        // Train for a few epochs
        let stats = trainer.train(5).await.expect("should succeed");

        assert_eq!(stats.total_epochs, 5);
        assert!(stats.final_loss >= 0.0);
        assert_eq!(stats.num_workers, 2);
    }

    #[tokio::test]
    async fn test_checkpoint_save_load() {
        let config = DistributedTrainingConfig::default();
        let coordinator = DistributedTrainingCoordinator::new(config)
            .await
            .expect("should succeed");

        let model_config = ModelConfig::default();
        let model = TransE::new(model_config);

        // Register a worker
        let worker = WorkerInfo {
            worker_id: 0,
            rank: 0,
            address: "127.0.0.1:8080".to_string(),
            status: WorkerStatus::Idle,
            num_gpus: 1,
            memory_gb: 16.0,
            last_heartbeat: Utc::now(),
        };
        coordinator
            .register_worker(worker)
            .await
            .expect("should succeed");

        // Save checkpoint
        coordinator
            .save_checkpoint(&model, 10, 0.5)
            .await
            .expect("should succeed");

        // Load checkpoint
        let checkpoint = coordinator
            .load_checkpoint("checkpoint_epoch_10")
            .await
            .expect("should succeed");
        assert_eq!(checkpoint.epoch, 10);
        assert_eq!(checkpoint.loss, 0.5);
    }
}