scirs2_autograd/distributed/mod.rs
use crate::{error::AutogradError, tensor::Tensor, Float, NdArray, Result};
use std::sync::{Arc, Mutex};
8
// Distributed-training submodules (implementations not shown in this file).
pub mod communication;
pub mod data_parallel;
pub mod model_parallel;
12
/// Accumulates this worker's local gradients and caches the result of the
/// most recent (simulated) allreduce.
pub struct DistributedGradient<T: Float> {
    /// Gradients produced locally on this worker, guarded for shared access.
    local_gradients: Arc<Mutex<Vec<NdArray<T>>>>,
    /// Averaged gradients from the last `allreduce` call, if one has run.
    accumulated: Arc<Mutex<Option<Vec<NdArray<T>>>>>,
    /// Total number of workers participating in the reduction.
    num_workers: usize,
    /// This worker's rank within the worker group.
    rank: usize,
}
24
25impl<T: Float + scirs2_core::ndarray::ScalarOperand> DistributedGradient<T> {
26 pub fn new(num_workers: usize, rank: usize) -> Self {
28 Self {
29 local_gradients: Arc::new(Mutex::new(Vec::new())),
30 accumulated: Arc::new(Mutex::new(None)),
31 num_workers,
32 rank,
33 }
34 }
35
36 pub fn add_local(&self, gradient: NdArray<T>) -> Result<()> {
38 let mut local = self
39 .local_gradients
40 .lock()
41 .map_err(|_| AutogradError::internal_error("Failed to lock local gradients"))?;
42 local.push(gradient);
43 Ok(())
44 }
45
46 pub fn allreduce(&self) -> Result<Vec<NdArray<T>>> {
48 let local = self
52 .local_gradients
53 .lock()
54 .map_err(|_| AutogradError::internal_error("Failed to lock local gradients"))?;
55
56 let num_grads = local.len();
58 if num_grads == 0 {
59 return Ok(Vec::new());
60 }
61
62 let mut result = Vec::with_capacity(num_grads);
63 for grad in local.iter() {
64 let averaged = grad
65 / T::from(self.num_workers).ok_or_else(|| {
66 AutogradError::compute_error("Failed to convert num_workers".to_string())
67 })?;
68 result.push(averaged);
69 }
70
71 let mut accumulated = self
73 .accumulated
74 .lock()
75 .map_err(|_| AutogradError::internal_error("Failed to lock accumulated gradients"))?;
76 *accumulated = Some(result.clone());
77
78 Ok(result)
79 }
80
81 pub fn rank(&self) -> usize {
83 self.rank
84 }
85
86 pub fn num_workers(&self) -> usize {
88 self.num_workers
89 }
90
91 pub fn clear(&self) -> Result<()> {
93 let mut local = self
94 .local_gradients
95 .lock()
96 .map_err(|_| AutogradError::internal_error("Failed to lock local gradients"))?;
97 local.clear();
98
99 let mut accumulated = self
100 .accumulated
101 .lock()
102 .map_err(|_| AutogradError::internal_error("Failed to lock accumulated gradients"))?;
103 *accumulated = None;
104
105 Ok(())
106 }
107}
108
/// Strategy used to distribute training work across workers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelStrategy {
    /// Replicate the model; partition the data across workers.
    DataParallel,
    /// Partition the model itself across workers.
    ModelParallel,
    /// Partition the model into sequential pipeline stages.
    PipelineParallel,
    /// Combination of the strategies above.
    Hybrid,
}
121
/// Configuration for distributed training.
pub struct DistributedConfig {
    /// Parallelization strategy to apply.
    pub strategy: ParallelStrategy,
    /// Total number of participating workers.
    pub num_workers: usize,
    /// Rank of this worker within the group.
    pub rank: usize,
    /// Number of local steps to buffer before gradients are synchronized
    /// (compared against the buffer length in `DistributedOptimizer::should_sync`).
    pub grad_accumulation_steps: usize,
    /// Whether gradients should be compressed before communication.
    /// NOTE(review): not referenced anywhere in this file — confirm where it
    /// is honored (presumably in the `communication` submodule).
    pub compress_gradients: bool,
}
135
136impl Default for DistributedConfig {
137 fn default() -> Self {
138 Self {
139 strategy: ParallelStrategy::DataParallel,
140 num_workers: 1,
141 rank: 0,
142 grad_accumulation_steps: 1,
143 compress_gradients: false,
144 }
145 }
146}
147
/// Backend abstraction for synchronizing gradients and parameters across
/// workers. Implementations must be shareable across threads.
pub trait SyncBackend<T: Float>: Send + Sync {
    /// Reduces gradients across workers and returns the combined result
    /// (the local implementation averages by worker count).
    fn allreduce(&self, gradients: &[NdArray<T>]) -> Result<Vec<NdArray<T>>>;

    /// Distributes parameters to all workers.
    fn broadcast(&self, parameters: &[NdArray<T>]) -> Result<Vec<NdArray<T>>>;

    /// Collects one gradient per worker into a `Vec`.
    fn gather(&self, gradient: &NdArray<T>) -> Result<Vec<NdArray<T>>>;

    /// Selects this worker's shard from `data`.
    fn scatter(&self, data: &[NdArray<T>]) -> Result<NdArray<T>>;
}
162
/// Single-process `SyncBackend` that simulates `num_workers` workers locally.
pub struct LocalSyncBackend<T: Float> {
    /// Simulated worker count, used for averaging and gather fan-out.
    num_workers: usize,
    /// Ties the element type `T` to the backend without storing any data.
    _phantom: std::marker::PhantomData<T>,
}
168
169impl<T: Float> LocalSyncBackend<T> {
170 pub fn new(num_workers: usize) -> Self {
172 Self {
173 num_workers,
174 _phantom: std::marker::PhantomData,
175 }
176 }
177}
178
179impl<T: Float + scirs2_core::ndarray::ScalarOperand> SyncBackend<T> for LocalSyncBackend<T> {
180 fn allreduce(&self, gradients: &[NdArray<T>]) -> Result<Vec<NdArray<T>>> {
181 let divisor = T::from(self.num_workers).ok_or_else(|| {
183 AutogradError::compute_error("Failed to convert num_workers".to_string())
184 })?;
185
186 Ok(gradients.iter().map(|g| g / divisor).collect())
187 }
188
189 fn broadcast(&self, parameters: &[NdArray<T>]) -> Result<Vec<NdArray<T>>> {
190 Ok(parameters.to_vec())
192 }
193
194 fn gather(&self, gradient: &NdArray<T>) -> Result<Vec<NdArray<T>>> {
195 Ok(vec![gradient.clone(); self.num_workers])
197 }
198
199 fn scatter(&self, data: &[NdArray<T>]) -> Result<NdArray<T>> {
200 data.first()
202 .cloned()
203 .ok_or_else(|| AutogradError::invalid_argument("Empty data for scatter".to_string()))
204 }
205}
206
/// Wraps a `SyncBackend` and buffers gradients so that several local steps
/// can be accumulated before one synchronization round.
pub struct DistributedOptimizer<T: Float> {
    /// Communication backend used when synchronizing gradients.
    backend: Arc<dyn SyncBackend<T>>,
    /// Distributed-training configuration (accumulation steps, etc.).
    config: DistributedConfig,
    /// One `Vec<NdArray<T>>` per accumulated step, drained by `sync_gradients`.
    grad_buffer: Arc<Mutex<Vec<Vec<NdArray<T>>>>>,
}
216
217impl<T: Float + scirs2_core::ndarray::ScalarOperand> DistributedOptimizer<T> {
218 pub fn new(backend: Arc<dyn SyncBackend<T>>, config: DistributedConfig) -> Self {
220 Self {
221 backend,
222 config,
223 grad_buffer: Arc::new(Mutex::new(Vec::new())),
224 }
225 }
226
227 pub fn accumulate_gradient(&self, gradients: Vec<NdArray<T>>) -> Result<()> {
229 let mut buffer = self
230 .grad_buffer
231 .lock()
232 .map_err(|_| AutogradError::internal_error("Failed to lock gradient buffer"))?;
233 buffer.push(gradients);
234 Ok(())
235 }
236
237 pub fn should_sync(&self) -> Result<bool> {
239 let buffer = self
240 .grad_buffer
241 .lock()
242 .map_err(|_| AutogradError::internal_error("Failed to lock gradient buffer"))?;
243 Ok(buffer.len() >= self.config.grad_accumulation_steps)
244 }
245
246 pub fn sync_gradients(&self) -> Result<Vec<NdArray<T>>> {
248 let mut buffer = self
249 .grad_buffer
250 .lock()
251 .map_err(|_| AutogradError::internal_error("Failed to lock gradient buffer"))?;
252
253 if buffer.is_empty() {
254 return Ok(Vec::new());
255 }
256
257 let num_grads = buffer[0].len();
259 let num_steps = buffer.len();
260 let mut averaged = Vec::with_capacity(num_grads);
261
262 for i in 0..num_grads {
263 let mut sum = buffer[0][i].clone();
264 for step in buffer.iter().skip(1) {
265 sum += &step[i];
266 }
267 let avg = sum
268 / T::from(num_steps).ok_or_else(|| {
269 AutogradError::compute_error("Failed to convert num_steps".to_string())
270 })?;
271 averaged.push(avg);
272 }
273
274 let synced = self.backend.allreduce(&averaged)?;
276
277 buffer.clear();
279
280 Ok(synced)
281 }
282
283 pub fn config(&self) -> &DistributedConfig {
285 &self.config
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_distributed_gradient() {
        // Rank 0 of 4 workers: one local gradient, averaged as 1/4 per entry.
        let accumulator: DistributedGradient<f32> = DistributedGradient::new(4, 0);

        let local: Array1<f32> = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        accumulator.add_local(local.into_dyn()).expect("Should add");

        let averaged = accumulator.allreduce().expect("Should allreduce");
        assert_eq!(averaged.len(), 1);

        let values = averaged[0].as_slice().expect("Should get slice");
        assert!((values[0] - 0.25).abs() < 1e-6);
    }

    #[test]
    fn test_parallel_strategy() {
        // Sanity-check the derived PartialEq/Eq.
        assert_eq!(
            ParallelStrategy::DataParallel,
            ParallelStrategy::DataParallel
        );
        assert_ne!(
            ParallelStrategy::DataParallel,
            ParallelStrategy::ModelParallel
        );
    }

    #[test]
    fn test_local_sync_backend() {
        // Two simulated workers: allreduce halves each entry.
        let backend: LocalSyncBackend<f64> = LocalSyncBackend::new(2);

        let gradient: Array1<f64> = Array1::from_vec(vec![4.0, 6.0]);
        let reduced = backend
            .allreduce(&[gradient.into_dyn()])
            .expect("Should allreduce");

        let values = reduced[0].as_slice().expect("Should get slice");
        assert_eq!(values[0], 2.0);
        assert_eq!(values[1], 3.0);
    }

    #[test]
    fn test_distributed_optimizer() {
        let backend = Arc::new(LocalSyncBackend::<f32>::new(1));
        let config = DistributedConfig {
            grad_accumulation_steps: 2,
            ..Default::default()
        };
        let optimizer = DistributedOptimizer::new(backend, config);

        // First step alone does not reach the accumulation threshold.
        let first: Array1<f32> = Array1::from_vec(vec![1.0]);
        optimizer
            .accumulate_gradient(vec![first.into_dyn()])
            .expect("Should accumulate");
        assert!(!optimizer.should_sync().expect("Should check"));

        // Second step reaches the threshold of 2.
        let second: Array1<f32> = Array1::from_vec(vec![3.0]);
        optimizer
            .accumulate_gradient(vec![second.into_dyn()])
            .expect("Should accumulate");
        assert!(optimizer.should_sync().expect("Should check"));

        // (1 + 3) / 2 steps = 2.0; single-worker allreduce leaves it unchanged.
        let synced = optimizer.sync_gradients().expect("Should sync");
        let synced_val = synced[0].as_slice().expect("Should get slice")[0];
        assert_eq!(synced_val, 2.0);
    }
}