1use std::sync::Arc;
12
13use ferrotorch_core::storage::TensorStorage;
14use ferrotorch_core::{FerrotorchResult, Float, Tensor};
15use ferrotorch_nn::{Module, Parameter};
16
17use crate::backend::Backend;
18use crate::collective::{ReduceOp, all_gather, reduce_scatter};
19
20pub struct FSDP<M: Module<T>, T: Float> {
55 module: M,
56 backend: Arc<dyn Backend>,
57 original_shapes: Vec<Vec<usize>>,
59 full_params: Vec<Tensor<T>>,
62 _marker: std::marker::PhantomData<T>,
63}
64
65impl<M: Module<T>, T: Float> FSDP<M, T> {
66 pub fn new(mut module: M, backend: Arc<dyn Backend>) -> FerrotorchResult<Self> {
77 let rank = backend.rank();
78 let world_size = backend.world_size();
79 let mut original_shapes = Vec::new();
80
81 {
82 let params = module.parameters_mut();
83 for param in params {
84 let tensor = param.tensor();
85 let shape = tensor.shape().to_vec();
86 let numel = tensor.numel();
87
88 assert!(
89 numel % world_size == 0,
90 "FSDP: parameter with {} elements is not evenly divisible by world_size {}",
91 numel,
92 world_size,
93 );
94
95 original_shapes.push(shape);
96
97 let data = tensor.data_vec()?;
98 let chunk_size = numel / world_size;
99 let start = rank * chunk_size;
100 let end = start + chunk_size;
101 let shard_data = data[start..end].to_vec();
102
103 let shard_tensor =
104 Tensor::from_storage(TensorStorage::cpu(shard_data), vec![chunk_size], true)?;
105 *param = Parameter::new(shard_tensor);
108 }
109 }
110
111 Ok(Self {
112 module,
113 backend,
114 original_shapes,
115 full_params: Vec::new(),
116 _marker: std::marker::PhantomData,
117 })
118 }
119
120 pub fn module(&self) -> &M {
122 &self.module
123 }
124
125 pub fn module_mut(&mut self) -> &mut M {
127 &mut self.module
128 }
129
130 pub fn into_inner(self) -> M {
132 self.module
133 }
134
135 pub fn backend(&self) -> &Arc<dyn Backend> {
137 &self.backend
138 }
139
140 pub fn forward(&mut self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
146 let world_size = self.backend.world_size();
147 self.full_params.clear();
148
149 {
150 let params = self.module.parameters_mut();
151 for (i, param) in params.into_iter().enumerate() {
152 let shard = param.tensor().clone();
153 let orig_shape = &self.original_shapes[i];
154
155 let full = if world_size == 1 {
157 shard
158 } else {
159 all_gather(&shard, self.backend.as_ref())?
160 };
161
162 let full = Tensor::from_storage(
164 TensorStorage::cpu(full.data_vec()?),
165 orig_shape.clone(),
166 true,
167 )?;
168
169 self.full_params.push(full.clone());
170
171 *param = Parameter::new(full);
173 }
174 }
175
176 let output = self.module.forward(input)?;
177
178 self.restore_shards()?;
181
182 Ok(output)
183 }
184
185 fn restore_shards(&mut self) -> FerrotorchResult<()> {
187 let rank = self.backend.rank();
188 let world_size = self.backend.world_size();
189
190 let params = self.module.parameters_mut();
191 for (i, param) in params.into_iter().enumerate() {
192 let tensor = param.tensor();
193 let data = tensor.data_vec()?;
194 let numel = data.len();
195 let chunk_size = numel / world_size;
196 let start = rank * chunk_size;
197 let end = start + chunk_size;
198 let shard_data = data[start..end].to_vec();
199
200 let shard_tensor =
201 Tensor::from_storage(TensorStorage::cpu(shard_data), vec![chunk_size], true)?;
202 *param = Parameter::new(shard_tensor);
203
204 let _ = &self.original_shapes[i];
206 }
207
208 Ok(())
209 }
210
211 pub fn sync_gradients(&mut self) -> FerrotorchResult<()> {
228 let world_size = self.backend.world_size();
229 let params = self.module.parameters_mut();
230
231 if self.full_params.len() != params.len() {
232 return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
233 message: format!(
234 "FSDP sync_gradients: expected {} full_params but have {}. \
235 Was forward() called before backward()?",
236 params.len(),
237 self.full_params.len(),
238 ),
239 });
240 }
241
242 for (i, param) in params.into_iter().enumerate() {
243 let full_param = &self.full_params[i];
244
245 let grad = full_param.grad()?;
249 let full_grad = match grad {
250 Some(g) => g,
251 None => {
252 let numel = full_param.numel();
253 Tensor::from_storage(
254 TensorStorage::cpu(vec![<T as num_traits::Zero>::zero(); numel]),
255 full_param.shape().to_vec(),
256 false,
257 )?
258 }
259 };
260
261 let grad_data = full_grad.data_vec()?;
263 let flat_grad = Tensor::from_storage(
264 TensorStorage::cpu(grad_data),
265 vec![full_grad.numel()],
266 false,
267 )?;
268
269 let shard_grad = if world_size == 1 {
271 flat_grad
272 } else {
273 reduce_scatter(&flat_grad, self.backend.as_ref(), ReduceOp::Mean)?
274 };
275
276 param.tensor().set_grad(Some(shard_grad))?;
279 }
280
281 self.full_params.clear();
283
284 Ok(())
285 }
286
287 pub fn update_shards(&mut self, flat_data: &[T]) -> FerrotorchResult<()> {
293 let params = self.module.parameters_mut();
294 let total_shard_numel: usize = params.iter().map(|p| p.tensor().numel()).sum();
295
296 assert!(
297 flat_data.len() == total_shard_numel,
298 "FSDP update_shards: expected {} elements but got {}",
299 total_shard_numel,
300 flat_data.len(),
301 );
302
303 let mut offset = 0;
304 for param in params {
305 let numel = param.tensor().numel();
306 let shard_data = flat_data[offset..offset + numel].to_vec();
307 let shard_tensor = Tensor::from_storage(
308 TensorStorage::cpu(shard_data),
309 param.tensor().shape().to_vec(),
310 true,
311 )?;
312 *param = Parameter::new(shard_tensor);
313 offset += numel;
314 }
315
316 Ok(())
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_core::storage::TensorStorage;
    use ferrotorch_core::{FerrotorchResult, Tensor};
    use ferrotorch_nn::Parameter;
    use std::thread;

    /// Minimal module with one 1-D weight; its forward pass multiplies every
    /// input element by the sum of the weight's elements.
    struct TestModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> TestModule<T> {
        /// Builds the module with `data` as a 1-D weight, in training mode.
        fn new(data: &[T]) -> FerrotorchResult<Self> {
            let weight = Parameter::from_slice(data, &[data.len()])?;
            Ok(Self {
                weight,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for TestModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            // Sum the weight left to right, starting from zero.
            let weights = self.weight.tensor().data_vec()?;
            let mut weight_total = <T as num_traits::Zero>::zero();
            for &w in weights.iter() {
                weight_total = weight_total + w;
            }

            let values = input.data_vec()?;
            let mut scaled: Vec<T> = Vec::new();
            for &x in values.iter() {
                scaled.push(x * weight_total);
            }

            Tensor::from_storage(TensorStorage::cpu(scaled), input.shape().to_vec(), false)
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".into(), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Runs `job` once per backend of `group`, each on its own thread, and
    /// returns the results in spawn order. Panics if any thread panicked.
    ///
    /// The backends are kept alive until every thread has been joined.
    fn run_on_ranks<G, R, F>(group: G, job: F) -> Vec<R>
    where
        G: IntoIterator<Item = SimulatedBackend>,
        R: Send + 'static,
        F: Fn(Arc<SimulatedBackend>) -> R + Clone + Send + 'static,
    {
        let backends: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();

        let mut workers = Vec::with_capacity(backends.len());
        for backend in backends.iter().cloned() {
            let job = job.clone();
            workers.push(thread::spawn(move || job(backend)));
        }

        let mut results = Vec::with_capacity(workers.len());
        for worker in workers {
            results.push(worker.join().unwrap());
        }
        results
    }

    /// Wraps a `TestModule` holding `weights` in an `FSDP` over a one-rank group.
    fn single_rank_fsdp(weights: &[f32]) -> FSDP<TestModule<f32>, f32> {
        let group = SimulatedBackend::create_group(1).unwrap();
        let backend: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let model = TestModule::<f32>::new(weights).unwrap();
        FSDP::new(model, backend).unwrap()
    }

    /// Builds a 1-D, non-grad-tracking tensor of four values to use as a gradient.
    fn grad_of_four(values: Vec<f32>) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(values), vec![4], false).unwrap()
    }

    // Each of two ranks keeps its own half of a four-element weight.
    #[test]
    fn test_fsdp_sharding() {
        let per_rank = run_on_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let rank = backend.rank();
            let model = TestModule::<f32>::new(&[10.0, 20.0, 30.0, 40.0]).unwrap();
            let fsdp = FSDP::new(model, backend).unwrap();

            let shard = fsdp.module().weight.tensor().data_vec().unwrap();
            (rank, shard)
        });

        for (rank, shard) in per_rank {
            if rank == 0 {
                assert_eq!(shard, &[10.0, 20.0]);
            } else {
                assert_eq!(shard, &[30.0, 40.0]);
            }
        }
    }

    // Shards created by `new` must track gradients.
    #[test]
    fn test_fsdp_shard_requires_grad() {
        let flags = run_on_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
            let fsdp = FSDP::new(model, backend).unwrap();
            fsdp.module().weight.tensor().requires_grad()
        });

        for tracks_grad in flags {
            assert!(tracks_grad, "shard must have requires_grad=true");
        }
    }

    // After `forward` the module holds a two-element, grad-tracking shard again.
    #[test]
    fn test_fsdp_forward_restores_shards() {
        run_on_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
            let mut fsdp = FSDP::new(model, backend).unwrap();

            let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
            let _output = fsdp.forward(&input).unwrap();

            let shard = fsdp.module().weight.tensor();
            assert_eq!(shard.numel(), 2);
            assert!(shard.requires_grad());
        });
    }

    // The forward pass sees the full weight: 2 * (1 + 2 + 3 + 4) = 20.
    #[test]
    fn test_fsdp_forward_produces_correct_output() {
        run_on_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
            let mut fsdp = FSDP::new(model, backend).unwrap();

            let input = ferrotorch_core::from_slice(&[2.0f32], &[1]).unwrap();
            let output = fsdp.forward(&input).unwrap();
            let data = output.data_vec().unwrap();
            assert!(
                (data[0] - 20.0).abs() < 1e-6,
                "expected 20.0, got {}",
                data[0]
            );
        });
    }

    // With one rank the shard is the whole weight, so all four values are replaced.
    #[test]
    fn test_fsdp_update_shards() {
        let mut fsdp = single_rank_fsdp(&[1.0, 2.0, 3.0, 4.0]);

        fsdp.update_shards(&[10.0, 20.0, 30.0, 40.0]).unwrap();

        let data = fsdp.module().weight.tensor().data_vec().unwrap();
        assert_eq!(data, &[10.0, 20.0, 30.0, 40.0]);
    }

    // A flat buffer of the wrong length must be rejected with a panic.
    #[test]
    #[should_panic(expected = "expected 4 elements but got 2")]
    fn test_fsdp_update_shards_size_validation() {
        let mut fsdp = single_rank_fsdp(&[1.0, 2.0, 3.0, 4.0]);

        fsdp.update_shards(&[10.0, 20.0]).unwrap();
    }

    // With one rank the full gradient is handed to the shard unchanged.
    #[test]
    fn test_fsdp_sync_gradients_single_rank() {
        let mut fsdp = single_rank_fsdp(&[1.0, 2.0, 3.0, 4.0]);

        let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
        let _output = fsdp.forward(&input).unwrap();

        // Stand in for backward(): put a gradient on the kept full parameter.
        fsdp.full_params[0]
            .set_grad(Some(grad_of_four(vec![0.1f32, 0.2, 0.3, 0.4])))
            .unwrap();

        fsdp.sync_gradients().unwrap();

        let shard_grad = fsdp.module().weight.tensor().grad().unwrap().unwrap();
        let data = shard_grad.data_vec().unwrap();
        assert_eq!(data, &[0.1, 0.2, 0.3, 0.4]);
    }

    // Both ranks supply the same gradient, so each rank's shard gradient is
    // expected to equal its own half of that gradient.
    #[test]
    fn test_fsdp_sync_gradients_multi_rank() {
        let per_rank = run_on_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let rank = backend.rank();
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
            let mut fsdp = FSDP::new(model, backend).unwrap();

            let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
            let _output = fsdp.forward(&input).unwrap();

            // Stand in for backward(): put a gradient on the kept full parameter.
            fsdp.full_params[0]
                .set_grad(Some(grad_of_four(vec![1.0f32, 2.0, 3.0, 4.0])))
                .unwrap();

            fsdp.sync_gradients().unwrap();

            let shard_grad = fsdp.module().weight.tensor().grad().unwrap().unwrap();
            (rank, shard_grad.data_vec().unwrap())
        });

        for (rank, data) in per_rank {
            assert_eq!(data.len(), 2);
            if rank == 0 {
                assert!(
                    (data[0] - 1.0).abs() < 1e-6,
                    "rank 0: expected 1.0, got {}",
                    data[0]
                );
                assert!(
                    (data[1] - 2.0).abs() < 1e-6,
                    "rank 0: expected 2.0, got {}",
                    data[1]
                );
            } else {
                assert!(
                    (data[0] - 3.0).abs() < 1e-6,
                    "rank 1: expected 3.0, got {}",
                    data[0]
                );
                assert!(
                    (data[1] - 4.0).abs() < 1e-6,
                    "rank 1: expected 4.0, got {}",
                    data[1]
                );
            }
        }
    }
}