Skip to main content

scirs2_autograd/distributed/
mod.rs

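//! Distributed automatic differentiation
//!
//! This module provides support for distributed gradient computation across multiple
//! processes or machines, enabling data-parallel and model-parallel training.
//! The accumulator and the local backend defined in this file only simulate
//! cross-worker communication within a single process.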
5
6use crate::{error::AutogradError, tensor::Tensor, Float, NdArray, Result};
7use std::sync::{Arc, Mutex};
8
9pub mod communication;
10pub mod data_parallel;
11pub mod model_parallel;
12
13/// Distributed gradient accumulator
14pub struct DistributedGradient<T: Float> {
15    /// Local gradients
16    local_gradients: Arc<Mutex<Vec<NdArray<T>>>>,
17    /// Accumulated gradients across all workers
18    accumulated: Arc<Mutex<Option<Vec<NdArray<T>>>>>,
19    /// Number of workers
20    num_workers: usize,
21    /// Current worker rank
22    rank: usize,
23}
24
25impl<T: Float + scirs2_core::ndarray::ScalarOperand> DistributedGradient<T> {
26    /// Create a new distributed gradient accumulator
27    pub fn new(num_workers: usize, rank: usize) -> Self {
28        Self {
29            local_gradients: Arc::new(Mutex::new(Vec::new())),
30            accumulated: Arc::new(Mutex::new(None)),
31            num_workers,
32            rank,
33        }
34    }
35
36    /// Add local gradient
37    pub fn add_local(&self, gradient: NdArray<T>) -> Result<()> {
38        let mut local = self
39            .local_gradients
40            .lock()
41            .map_err(|_| AutogradError::internal_error("Failed to lock local gradients"))?;
42        local.push(gradient);
43        Ok(())
44    }
45
46    /// Synchronize gradients across all workers (allreduce)
47    pub fn allreduce(&self) -> Result<Vec<NdArray<T>>> {
48        // In a real implementation, this would communicate with other workers
49        // For now, we simulate by averaging local gradients
50
51        let local = self
52            .local_gradients
53            .lock()
54            .map_err(|_| AutogradError::internal_error("Failed to lock local gradients"))?;
55
56        // Average gradients (simulating allreduce)
57        let num_grads = local.len();
58        if num_grads == 0 {
59            return Ok(Vec::new());
60        }
61
62        let mut result = Vec::with_capacity(num_grads);
63        for grad in local.iter() {
64            let averaged = grad
65                / T::from(self.num_workers).ok_or_else(|| {
66                    AutogradError::compute_error("Failed to convert num_workers".to_string())
67                })?;
68            result.push(averaged);
69        }
70
71        // Store accumulated result
72        let mut accumulated = self
73            .accumulated
74            .lock()
75            .map_err(|_| AutogradError::internal_error("Failed to lock accumulated gradients"))?;
76        *accumulated = Some(result.clone());
77
78        Ok(result)
79    }
80
81    /// Get current worker rank
82    pub fn rank(&self) -> usize {
83        self.rank
84    }
85
86    /// Get total number of workers
87    pub fn num_workers(&self) -> usize {
88        self.num_workers
89    }
90
91    /// Clear accumulated gradients
92    pub fn clear(&self) -> Result<()> {
93        let mut local = self
94            .local_gradients
95            .lock()
96            .map_err(|_| AutogradError::internal_error("Failed to lock local gradients"))?;
97        local.clear();
98
99        let mut accumulated = self
100            .accumulated
101            .lock()
102            .map_err(|_| AutogradError::internal_error("Failed to lock accumulated gradients"))?;
103        *accumulated = None;
104
105        Ok(())
106    }
107}
108
/// Strategy for splitting training work across workers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParallelStrategy {
    /// Data parallel - replicate model, split data
    DataParallel,
    /// Model parallel - split model across devices
    ModelParallel,
    /// Pipeline parallel - split model into stages
    PipelineParallel,
    /// Hybrid - combination of strategies
    Hybrid,
}

/// Distributed training configuration.
///
/// The `Default` value describes a single worker (rank 0) doing data-parallel
/// training with one accumulation step and no gradient compression.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DistributedConfig {
    /// Parallel strategy
    pub strategy: ParallelStrategy,
    /// Number of workers
    pub num_workers: usize,
    /// Current worker rank
    pub rank: usize,
    /// Number of buffered gradient steps required before a sync is due
    pub grad_accumulation_steps: usize,
    /// Whether gradients should be compressed (not read by the code in this file)
    pub compress_gradients: bool,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        Self {
            strategy: ParallelStrategy::DataParallel,
            num_workers: 1,
            rank: 0,
            grad_accumulation_steps: 1,
            compress_gradients: false,
        }
    }
}
147
148/// Gradient synchronization backend
149pub trait SyncBackend<T: Float>: Send + Sync {
150    /// Perform allreduce operation on gradients
151    fn allreduce(&self, gradients: &[NdArray<T>]) -> Result<Vec<NdArray<T>>>;
152
153    /// Broadcast parameters from rank 0 to all workers
154    fn broadcast(&self, parameters: &[NdArray<T>]) -> Result<Vec<NdArray<T>>>;
155
156    /// Gather gradients from all workers to rank 0
157    fn gather(&self, gradient: &NdArray<T>) -> Result<Vec<NdArray<T>>>;
158
159    /// Scatter data from rank 0 to all workers
160    fn scatter(&self, data: &[NdArray<T>]) -> Result<NdArray<T>>;
161}
162
163/// Local (single-process) sync backend for testing
164pub struct LocalSyncBackend<T: Float> {
165    num_workers: usize,
166    _phantom: std::marker::PhantomData<T>,
167}
168
169impl<T: Float> LocalSyncBackend<T> {
170    /// Create a new local sync backend
171    pub fn new(num_workers: usize) -> Self {
172        Self {
173            num_workers,
174            _phantom: std::marker::PhantomData,
175        }
176    }
177}
178
179impl<T: Float + scirs2_core::ndarray::ScalarOperand> SyncBackend<T> for LocalSyncBackend<T> {
180    fn allreduce(&self, gradients: &[NdArray<T>]) -> Result<Vec<NdArray<T>>> {
181        // Simulate allreduce by averaging (for testing)
182        let divisor = T::from(self.num_workers).ok_or_else(|| {
183            AutogradError::compute_error("Failed to convert num_workers".to_string())
184        })?;
185
186        Ok(gradients.iter().map(|g| g / divisor).collect())
187    }
188
189    fn broadcast(&self, parameters: &[NdArray<T>]) -> Result<Vec<NdArray<T>>> {
190        // In local mode, just return copies
191        Ok(parameters.to_vec())
192    }
193
194    fn gather(&self, gradient: &NdArray<T>) -> Result<Vec<NdArray<T>>> {
195        // Simulate gather by replicating
196        Ok(vec![gradient.clone(); self.num_workers])
197    }
198
199    fn scatter(&self, data: &[NdArray<T>]) -> Result<NdArray<T>> {
200        // Return first element (simulating scatter to rank 0)
201        data.first()
202            .cloned()
203            .ok_or_else(|| AutogradError::invalid_argument("Empty data for scatter".to_string()))
204    }
205}
206
207/// Distributed optimizer wrapper
208pub struct DistributedOptimizer<T: Float> {
209    /// Synchronization backend
210    backend: Arc<dyn SyncBackend<T>>,
211    /// Configuration
212    config: DistributedConfig,
213    /// Gradient accumulation buffer
214    grad_buffer: Arc<Mutex<Vec<Vec<NdArray<T>>>>>,
215}
216
217impl<T: Float + scirs2_core::ndarray::ScalarOperand> DistributedOptimizer<T> {
218    /// Create a new distributed optimizer
219    pub fn new(backend: Arc<dyn SyncBackend<T>>, config: DistributedConfig) -> Self {
220        Self {
221            backend,
222            config,
223            grad_buffer: Arc::new(Mutex::new(Vec::new())),
224        }
225    }
226
227    /// Accumulate gradient
228    pub fn accumulate_gradient(&self, gradients: Vec<NdArray<T>>) -> Result<()> {
229        let mut buffer = self
230            .grad_buffer
231            .lock()
232            .map_err(|_| AutogradError::internal_error("Failed to lock gradient buffer"))?;
233        buffer.push(gradients);
234        Ok(())
235    }
236
237    /// Check if ready to synchronize
238    pub fn should_sync(&self) -> Result<bool> {
239        let buffer = self
240            .grad_buffer
241            .lock()
242            .map_err(|_| AutogradError::internal_error("Failed to lock gradient buffer"))?;
243        Ok(buffer.len() >= self.config.grad_accumulation_steps)
244    }
245
246    /// Synchronize accumulated gradients
247    pub fn sync_gradients(&self) -> Result<Vec<NdArray<T>>> {
248        let mut buffer = self
249            .grad_buffer
250            .lock()
251            .map_err(|_| AutogradError::internal_error("Failed to lock gradient buffer"))?;
252
253        if buffer.is_empty() {
254            return Ok(Vec::new());
255        }
256
257        // Average accumulated gradients
258        let num_grads = buffer[0].len();
259        let num_steps = buffer.len();
260        let mut averaged = Vec::with_capacity(num_grads);
261
262        for i in 0..num_grads {
263            let mut sum = buffer[0][i].clone();
264            for step in buffer.iter().skip(1) {
265                sum += &step[i];
266            }
267            let avg = sum
268                / T::from(num_steps).ok_or_else(|| {
269                    AutogradError::compute_error("Failed to convert num_steps".to_string())
270                })?;
271            averaged.push(avg);
272        }
273
274        // Synchronize across workers
275        let synced = self.backend.allreduce(&averaged)?;
276
277        // Clear buffer
278        buffer.clear();
279
280        Ok(synced)
281    }
282
283    /// Get configuration
284    pub fn config(&self) -> &DistributedConfig {
285        &self.config
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// With four workers, a single local gradient comes back scaled by 1/4.
    #[test]
    fn test_distributed_gradient() {
        let accumulator: DistributedGradient<f32> = DistributedGradient::new(4, 0);
        let local = Array1::<f32>::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
        accumulator.add_local(local).expect("Should add");

        let reduced = accumulator.allreduce().expect("Should allreduce");
        assert_eq!(reduced.len(), 1);

        let first = reduced[0].as_slice().expect("Should get slice")[0];
        assert!((first - 0.25).abs() < 1e-6);
    }

    /// Variants compare equal only to themselves.
    #[test]
    fn test_parallel_strategy() {
        let data = ParallelStrategy::DataParallel;
        assert_eq!(data, ParallelStrategy::DataParallel);
        assert_ne!(data, ParallelStrategy::ModelParallel);
    }

    /// The local backend's simulated allreduce halves values for two workers.
    #[test]
    fn test_local_sync_backend() {
        let backend: LocalSyncBackend<f64> = LocalSyncBackend::new(2);
        let input = Array1::<f64>::from_vec(vec![4.0, 6.0]).into_dyn();

        let output = backend.allreduce(&[input]).expect("Should allreduce");

        let values = output[0].as_slice().expect("Should get slice");
        assert_eq!(values.to_vec(), vec![2.0, 3.0]);
    }

    /// Two accumulation steps are required before syncing, and the synced value
    /// is their mean.
    #[test]
    fn test_distributed_optimizer() {
        let config = DistributedConfig {
            grad_accumulation_steps: 2,
            ..Default::default()
        };
        let optimizer =
            DistributedOptimizer::new(Arc::new(LocalSyncBackend::<f32>::new(1)), config);

        for (step, value) in [1.0_f32, 3.0].iter().enumerate() {
            optimizer
                .accumulate_gradient(vec![Array1::from_vec(vec![*value]).into_dyn()])
                .expect("Should accumulate");
            // Only the second step makes the optimizer ready.
            let ready = optimizer.should_sync().expect("Should check");
            assert_eq!(ready, step == 1);
        }

        let synced = optimizer.sync_gradients().expect("Should sync");
        let mean = synced[0].as_slice().expect("Should get slice")[0];
        assert_eq!(mean, 2.0); // (1 + 3) / 2
    }
}