1use std::sync::Arc;
27
28use crate::backend::Backend;
29use crate::collective::{ReduceOp, allreduce};
30use ferrotorch_core::storage::TensorStorage;
31use ferrotorch_core::{FerrotorchResult, Float, Tensor};
32use ferrotorch_nn::{Module, Parameter};
33
/// Default upper bound, in bytes, on the gradient payload grouped into one
/// communication bucket: 25 MiB (25 * 2^20 bytes).
const DEFAULT_BUCKET_SIZE_BYTES: usize = 25 << 20;
36
37pub struct DDP<M: Module<T>, T: Float> {
56 module: M,
57 backend: Arc<dyn Backend>,
58 buckets: Vec<Vec<usize>>,
60 _marker: std::marker::PhantomData<T>,
61}
62
63impl<M: Module<T>, T: Float> DDP<M, T> {
64 pub fn new(module: M, backend: Arc<dyn Backend>) -> Self {
70 Self::with_bucket_size(module, backend, DEFAULT_BUCKET_SIZE_BYTES)
71 }
72
73 pub fn with_bucket_size(
75 module: M,
76 backend: Arc<dyn Backend>,
77 bucket_size_bytes: usize,
78 ) -> Self {
79 let params = module.parameters();
80 let buckets = compute_buckets::<T>(¶ms, bucket_size_bytes);
81 Self {
82 module,
83 backend,
84 buckets,
85 _marker: std::marker::PhantomData,
86 }
87 }
88
89 pub fn module(&self) -> &M {
91 &self.module
92 }
93
94 pub fn module_mut(&mut self) -> &mut M {
96 &mut self.module
97 }
98
99 pub fn into_inner(self) -> M {
101 self.module
102 }
103
104 pub fn backend(&self) -> &Arc<dyn Backend> {
106 &self.backend
107 }
108
109 pub fn sync_gradients(&self) -> FerrotorchResult<()> {
118 let params = self.module.parameters();
119 for bucket in &self.buckets {
120 sync_one_bucket::<T>(bucket, ¶ms, self.backend.as_ref())?;
121 }
122 Ok(())
123 }
124
125 pub fn overlapped_sync_gradients(&self) -> FerrotorchResult<()> {
136 let params = self.module.parameters();
137
138 let errors: std::sync::Mutex<Vec<ferrotorch_core::error::FerrotorchError>> =
140 std::sync::Mutex::new(Vec::new());
141
142 std::thread::scope(|s| {
143 for bucket in &self.buckets {
144 let params_ref = ¶ms;
145 let backend_ref = self.backend.as_ref();
146 let errors_ref = &errors;
147
148 s.spawn(move || {
149 let result = sync_one_bucket::<T>(bucket, params_ref, backend_ref);
150 if let Err(e) = result {
151 errors_ref.lock().unwrap().push(e);
152 }
153 });
154 }
155 });
156
157 let errs = errors.into_inner().unwrap();
158 if let Some(e) = errs.into_iter().next() {
159 return Err(e);
160 }
161
162 Ok(())
163 }
164
165 pub fn broadcast_parameters(&mut self, root: usize) -> FerrotorchResult<()> {
177 let params_mut = self.module.parameters_mut();
178
179 for param in params_mut {
180 let tensor = param.tensor().clone();
181 let synced = crate::collective::broadcast(&tensor, self.backend.as_ref(), root)?;
182 *param = Parameter::new(synced);
183 }
184
185 Ok(())
186 }
187}
188
189impl<M: Module<T>, T: Float> Module<T> for DDP<M, T> {
192 fn forward(
193 &self,
194 input: &ferrotorch_core::Tensor<T>,
195 ) -> FerrotorchResult<ferrotorch_core::Tensor<T>> {
196 self.module.forward(input)
197 }
198
199 fn parameters(&self) -> Vec<&Parameter<T>> {
200 self.module.parameters()
201 }
202
203 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
204 self.module.parameters_mut()
205 }
206
207 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
208 self.module.named_parameters()
209 }
210
211 fn train(&mut self) {
212 self.module.train();
213 }
214
215 fn eval(&mut self) {
216 self.module.eval();
217 }
218
219 fn is_training(&self) -> bool {
220 self.module.is_training()
221 }
222}
223
224fn compute_buckets<T: Float>(
231 params: &[&Parameter<T>],
232 bucket_size_bytes: usize,
233) -> Vec<Vec<usize>> {
234 let elem_size = std::mem::size_of::<T>();
235 let mut buckets: Vec<Vec<usize>> = Vec::new();
236 let mut current_bucket: Vec<usize> = Vec::new();
237 let mut current_bytes: usize = 0;
238
239 for i in (0..params.len()).rev() {
241 let param_bytes = params[i].tensor().numel() * elem_size;
242
243 if !current_bucket.is_empty() && current_bytes + param_bytes > bucket_size_bytes {
244 buckets.push(current_bucket);
245 current_bucket = Vec::new();
246 current_bytes = 0;
247 }
248
249 current_bucket.push(i);
250 current_bytes += param_bytes;
251 }
252
253 if !current_bucket.is_empty() {
254 buckets.push(current_bucket);
255 }
256
257 buckets
258}
259
260fn sync_one_bucket<T: Float>(
266 bucket: &[usize],
267 params: &[&Parameter<T>],
268 backend: &dyn Backend,
269) -> FerrotorchResult<()> {
270 let mut flat_data: Vec<T> = Vec::new();
271 let mut param_numels: Vec<usize> = Vec::new();
272
273 for &pi in bucket {
274 let numel = params[pi].tensor().numel();
275 param_numels.push(numel);
276
277 let grad = params[pi].tensor().grad()?;
278 match grad {
279 Some(g) => flat_data.extend(g.data_vec()?),
280 None => {
281 flat_data.extend(std::iter::repeat_n(<T as num_traits::Zero>::zero(), numel));
282 }
283 }
284 }
285
286 if flat_data.is_empty() {
287 return Ok(());
288 }
289
290 let flat_tensor = Tensor::from_storage(
291 TensorStorage::cpu(flat_data),
292 vec![param_numels.iter().sum()],
293 false,
294 )?;
295 let synced = allreduce(&flat_tensor, backend, ReduceOp::Mean)?;
296 let synced_data = synced.data()?;
297
298 let mut offset = 0;
299 for (&pi, &numel) in bucket.iter().zip(param_numels.iter()) {
300 let grad_slice = &synced_data[offset..offset + numel];
301 let grad_tensor = Tensor::from_storage(
302 TensorStorage::cpu(grad_slice.to_vec()),
303 params[pi].tensor().shape().to_vec(),
304 false,
305 )?;
306 params[pi].tensor().set_grad(Some(grad_tensor))?;
307 offset += numel;
308 }
309
310 Ok(())
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_core::storage::TensorStorage;
    use ferrotorch_core::{FerrotorchResult, Tensor};
    use ferrotorch_nn::Parameter;
    use std::thread;

    /// Minimal module with a single parameter and a training flag, used to
    /// drive `DDP` in the tests below.
    struct TestModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> TestModule<T> {
        /// Builds a module whose `weight` is a 1-D parameter holding `data`;
        /// the module starts in training mode.
        fn new(data: &[T]) -> FerrotorchResult<Self> {
            Ok(Self {
                weight: Parameter::from_slice(data, &[data.len()])?,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for TestModule<T> {
        // Identity forward pass: the tests never inspect the output.
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            Ok(input.clone())
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".into(), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Four ranks each set a gradient equal to their rank; after
    /// `sync_gradients` every rank must see the mean, (0 + 1 + 2 + 3) / 4 = 1.5.
    #[test]
    fn test_ddp_sync_gradients() {
        let group = SimulatedBackend::create_group(4).unwrap();
        // `arcs` stays in scope until the end of the test, so every backend
        // is still referenced while the rank threads run.
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();

        // One thread per rank.
        let handles: Vec<_> = arcs
            .iter()
            .cloned()
            .map(|b| {
                thread::spawn(move || {
                    let rank = b.rank();
                    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0]).unwrap();
                    let ddp = DDP::new(model, b);

                    // Every element of this rank's gradient equals its rank.
                    let grad_val = rank as f32;
                    let grad = Tensor::from_storage(
                        TensorStorage::cpu(vec![grad_val, grad_val, grad_val]),
                        vec![3],
                        false,
                    )
                    .unwrap();
                    ddp.module().weight.tensor().set_grad(Some(grad)).unwrap();

                    ddp.sync_gradients().unwrap();

                    let synced_grad = ddp.module().weight.tensor().grad().unwrap().unwrap();
                    let data = synced_grad.data().unwrap();
                    for &v in data {
                        assert!((v - 1.5).abs() < 1e-5, "rank {rank}: expected 1.5, got {v}");
                    }
                })
            })
            .collect();

        // Joining propagates any assertion failure from a rank thread.
        for h in handles {
            h.join().unwrap();
        }
    }

    /// Rank 0 starts with `[10, 20, 30]`, the other ranks with zeros; after
    /// `broadcast_parameters(0)` every rank must hold rank 0's values.
    #[test]
    fn test_ddp_broadcast_parameters() {
        let group = SimulatedBackend::create_group(3).unwrap();
        // `arcs` stays in scope until the end of the test, so every backend
        // is still referenced while the rank threads run.
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();

        // One thread per rank.
        let handles: Vec<_> = arcs
            .iter()
            .cloned()
            .map(|b| {
                thread::spawn(move || {
                    let rank = b.rank();
                    // Only the root rank carries the values to be distributed.
                    let data: Vec<f32> = if rank == 0 {
                        vec![10.0, 20.0, 30.0]
                    } else {
                        vec![0.0, 0.0, 0.0]
                    };
                    let model = TestModule::<f32>::new(&data).unwrap();
                    let mut ddp = DDP::new(model, b);

                    ddp.broadcast_parameters(0).unwrap();

                    let param_data = ddp.module().weight.tensor().data().unwrap();
                    assert!(
                        (param_data[0] - 10.0).abs() < 1e-5,
                        "rank {rank}: expected 10.0, got {}",
                        param_data[0]
                    );
                    assert!(
                        (param_data[1] - 20.0).abs() < 1e-5,
                        "rank {rank}: expected 20.0, got {}",
                        param_data[1]
                    );
                    assert!(
                        (param_data[2] - 30.0).abs() < 1e-5,
                        "rank {rank}: expected 30.0, got {}",
                        param_data[2]
                    );
                })
            })
            .collect();

        // Joining propagates any assertion failure from a rank thread.
        for h in handles {
            h.join().unwrap();
        }
    }

    /// The `Module` implementation on `DDP` must forward the training-mode
    /// toggles and the parameter accessors to the wrapped module.
    #[test]
    fn test_ddp_delegates_module_trait() {
        // A single-rank group is enough: no collective is issued here.
        let group = SimulatedBackend::create_group(1).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let model = TestModule::<f32>::new(&[1.0, 2.0]).unwrap();
        let mut ddp = DDP::new(model, b);

        // `TestModule::new` starts in training mode.
        assert!(ddp.is_training());
        ddp.eval();
        assert!(!ddp.is_training());
        ddp.train();
        assert!(ddp.is_training());

        assert_eq!(ddp.parameters().len(), 1);
        assert_eq!(ddp.named_parameters()[0].0, "weight");
    }
}