1use std::sync::Arc;
10
11use crate::backend::Backend;
12use crate::collective::{ReduceOp, allreduce};
13use ferrotorch_core::storage::TensorStorage;
14use ferrotorch_core::{FerrotorchResult, Float, Tensor};
15use ferrotorch_nn::{Module, Parameter};
16
/// Bucket capacity, in bytes, used by `DDP::new` when the caller does not
/// pick one explicitly: 25 MiB.
const DEFAULT_BUCKET_SIZE_BYTES: usize = 25 * (1 << 20);
19
20pub struct DDP<M: Module<T>, T: Float> {
39 module: M,
40 backend: Arc<dyn Backend>,
41 buckets: Vec<Vec<usize>>,
43 _marker: std::marker::PhantomData<T>,
44}
45
46impl<M: Module<T>, T: Float> DDP<M, T> {
47 pub fn new(module: M, backend: Arc<dyn Backend>) -> Self {
53 Self::with_bucket_size(module, backend, DEFAULT_BUCKET_SIZE_BYTES)
54 }
55
56 pub fn with_bucket_size(
58 module: M,
59 backend: Arc<dyn Backend>,
60 bucket_size_bytes: usize,
61 ) -> Self {
62 let params = module.parameters();
63 let buckets = compute_buckets::<T>(¶ms, bucket_size_bytes);
64 Self {
65 module,
66 backend,
67 buckets,
68 _marker: std::marker::PhantomData,
69 }
70 }
71
72 pub fn module(&self) -> &M {
74 &self.module
75 }
76
77 pub fn module_mut(&mut self) -> &mut M {
79 &mut self.module
80 }
81
82 pub fn into_inner(self) -> M {
84 self.module
85 }
86
87 pub fn backend(&self) -> &Arc<dyn Backend> {
89 &self.backend
90 }
91
92 pub fn sync_gradients(&self) -> FerrotorchResult<()> {
101 let params = self.module.parameters();
102 for bucket in &self.buckets {
103 sync_one_bucket::<T>(bucket, ¶ms, self.backend.as_ref())?;
104 }
105 Ok(())
106 }
107
108 pub fn overlapped_sync_gradients(&self) -> FerrotorchResult<()> {
119 let params = self.module.parameters();
120
121 let errors: std::sync::Mutex<Vec<ferrotorch_core::error::FerrotorchError>> =
123 std::sync::Mutex::new(Vec::new());
124
125 std::thread::scope(|s| {
126 for bucket in &self.buckets {
127 let params_ref = ¶ms;
128 let backend_ref = self.backend.as_ref();
129 let errors_ref = &errors;
130
131 s.spawn(move || {
132 let result = sync_one_bucket::<T>(bucket, params_ref, backend_ref);
133 if let Err(e) = result {
134 errors_ref.lock().unwrap().push(e);
135 }
136 });
137 }
138 });
139
140 let errs = errors.into_inner().unwrap();
141 if let Some(e) = errs.into_iter().next() {
142 return Err(e);
143 }
144
145 Ok(())
146 }
147
148 pub fn broadcast_parameters(&mut self, root: usize) -> FerrotorchResult<()> {
160 let params_mut = self.module.parameters_mut();
161
162 for param in params_mut {
163 let tensor = param.tensor().clone();
164 let synced = crate::collective::broadcast(&tensor, self.backend.as_ref(), root)?;
165 *param = Parameter::new(synced);
166 }
167
168 Ok(())
169 }
170}
171
172impl<M: Module<T>, T: Float> Module<T> for DDP<M, T> {
175 fn forward(
176 &self,
177 input: &ferrotorch_core::Tensor<T>,
178 ) -> FerrotorchResult<ferrotorch_core::Tensor<T>> {
179 self.module.forward(input)
180 }
181
182 fn parameters(&self) -> Vec<&Parameter<T>> {
183 self.module.parameters()
184 }
185
186 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
187 self.module.parameters_mut()
188 }
189
190 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
191 self.module.named_parameters()
192 }
193
194 fn train(&mut self) {
195 self.module.train();
196 }
197
198 fn eval(&mut self) {
199 self.module.eval();
200 }
201
202 fn is_training(&self) -> bool {
203 self.module.is_training()
204 }
205}
206
207fn compute_buckets<T: Float>(
214 params: &[&Parameter<T>],
215 bucket_size_bytes: usize,
216) -> Vec<Vec<usize>> {
217 let elem_size = std::mem::size_of::<T>();
218 let mut buckets: Vec<Vec<usize>> = Vec::new();
219 let mut current_bucket: Vec<usize> = Vec::new();
220 let mut current_bytes: usize = 0;
221
222 for i in (0..params.len()).rev() {
224 let param_bytes = params[i].tensor().numel() * elem_size;
225
226 if !current_bucket.is_empty() && current_bytes + param_bytes > bucket_size_bytes {
227 buckets.push(current_bucket);
228 current_bucket = Vec::new();
229 current_bytes = 0;
230 }
231
232 current_bucket.push(i);
233 current_bytes += param_bytes;
234 }
235
236 if !current_bucket.is_empty() {
237 buckets.push(current_bucket);
238 }
239
240 buckets
241}
242
243fn sync_one_bucket<T: Float>(
249 bucket: &[usize],
250 params: &[&Parameter<T>],
251 backend: &dyn Backend,
252) -> FerrotorchResult<()> {
253 let mut flat_data: Vec<T> = Vec::new();
254 let mut param_numels: Vec<usize> = Vec::new();
255
256 for &pi in bucket {
257 let numel = params[pi].tensor().numel();
258 param_numels.push(numel);
259
260 let grad = params[pi].tensor().grad()?;
261 match grad {
262 Some(g) => flat_data.extend(g.data_vec()?),
263 None => {
264 flat_data.extend(std::iter::repeat_n(<T as num_traits::Zero>::zero(), numel));
265 }
266 }
267 }
268
269 if flat_data.is_empty() {
270 return Ok(());
271 }
272
273 let flat_tensor = Tensor::from_storage(
274 TensorStorage::cpu(flat_data),
275 vec![param_numels.iter().sum()],
276 false,
277 )?;
278 let synced = allreduce(&flat_tensor, backend, ReduceOp::Mean)?;
279 let synced_data = synced.data()?;
280
281 let mut offset = 0;
282 for (&pi, &numel) in bucket.iter().zip(param_numels.iter()) {
283 let grad_slice = &synced_data[offset..offset + numel];
284 let grad_tensor = Tensor::from_storage(
285 TensorStorage::cpu(grad_slice.to_vec()),
286 params[pi].tensor().shape().to_vec(),
287 false,
288 )?;
289 params[pi].tensor().set_grad(Some(grad_tensor))?;
290 offset += numel;
291 }
292
293 Ok(())
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_core::storage::TensorStorage;
    use ferrotorch_core::{FerrotorchResult, Tensor};
    use ferrotorch_nn::Parameter;
    use std::thread;

    /// Smallest possible module: one 1-D parameter and a training flag.
    struct TestModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> TestModule<T> {
        /// Builds a module whose weight is a 1-D parameter holding `data`.
        fn new(data: &[T]) -> FerrotorchResult<Self> {
            let weight = Parameter::from_slice(data, &[data.len()])?;
            Ok(Self {
                weight,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for TestModule<T> {
        /// Identity forward pass: hands back a clone of the input.
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            Ok(input.clone())
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".into(), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Runs `body` once per backend in `group`, each on its own thread, and
    /// joins them all, propagating any worker panic.
    ///
    /// The shared handles in `ranks` stay alive until every worker has been
    /// joined, so no backend is dropped while another rank is still running.
    fn run_on_each_rank(
        group: impl IntoIterator<Item = SimulatedBackend>,
        body: fn(Arc<SimulatedBackend>),
    ) {
        let ranks: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();

        let workers: Vec<_> = ranks
            .iter()
            .map(|backend| {
                let backend = Arc::clone(backend);
                thread::spawn(move || body(backend))
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }
    }

    /// Four ranks each set a gradient equal to their rank; after syncing,
    /// every rank must see the mean, 1.5.
    #[test]
    fn test_ddp_sync_gradients() {
        let group = SimulatedBackend::create_group(4).unwrap();

        run_on_each_rank(group, |b| {
            let rank = b.rank();
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0]).unwrap();
            let ddp = DDP::new(model, b);

            let grad_val = rank as f32;
            let local_grad =
                Tensor::from_storage(TensorStorage::cpu(vec![grad_val; 3]), vec![3], false)
                    .unwrap();
            ddp.module()
                .weight
                .tensor()
                .set_grad(Some(local_grad))
                .unwrap();

            ddp.sync_gradients().unwrap();

            let synced_grad = ddp.module().weight.tensor().grad().unwrap().unwrap();
            let values = synced_grad.data().unwrap();
            for &v in values.iter() {
                assert!((v - 1.5).abs() < 1e-5, "rank {rank}: expected 1.5, got {v}");
            }
        });
    }

    /// Only rank 0 starts with non-zero weights; after broadcasting from
    /// rank 0, all three ranks must hold rank 0's values.
    #[test]
    fn test_ddp_broadcast_parameters() {
        let group = SimulatedBackend::create_group(3).unwrap();

        run_on_each_rank(group, |b| {
            let rank = b.rank();
            let initial: Vec<f32> = if rank == 0 {
                vec![10.0, 20.0, 30.0]
            } else {
                vec![0.0, 0.0, 0.0]
            };
            let model = TestModule::<f32>::new(&initial).unwrap();
            let mut ddp = DDP::new(model, b);

            ddp.broadcast_parameters(0).unwrap();

            let param_data = ddp.module().weight.tensor().data().unwrap();
            let (first, second, third) = (param_data[0], param_data[1], param_data[2]);
            assert!(
                (first - 10.0).abs() < 1e-5,
                "rank {rank}: expected 10.0, got {}",
                first
            );
            assert!(
                (second - 20.0).abs() < 1e-5,
                "rank {rank}: expected 20.0, got {}",
                second
            );
            assert!(
                (third - 30.0).abs() < 1e-5,
                "rank {rank}: expected 30.0, got {}",
                third
            );
        });
    }

    /// The wrapper must forward mode switches and parameter queries to the
    /// wrapped module.
    #[test]
    fn test_ddp_delegates_module_trait() {
        let only_rank = SimulatedBackend::create_group(1)
            .unwrap()
            .into_iter()
            .next()
            .unwrap();
        let backend: Arc<dyn Backend> = Arc::new(only_rank);
        let model = TestModule::<f32>::new(&[1.0, 2.0]).unwrap();
        let mut ddp = DDP::new(model, backend);

        // A fresh module starts in training mode.
        assert!(ddp.is_training());
        ddp.eval();
        assert!(!ddp.is_training());
        ddp.train();
        assert!(ddp.is_training());

        assert_eq!(ddp.parameters().len(), 1);
        assert_eq!(ddp.named_parameters()[0].0, "weight");
    }
}