use std::sync::Arc;
use ferrotorch_core::storage::TensorStorage;
use ferrotorch_core::{FerrotorchResult, Float, Tensor};
use ferrotorch_nn::{Module, Parameter};
use crate::backend::Backend;
use crate::collective::{ReduceOp, all_gather, reduce_scatter};
/// Fully-sharded data-parallel wrapper around a [`Module`].
///
/// Construction replaces every parameter of the wrapped module with this
/// rank's flat shard. `forward` temporarily rebuilds the full parameters, and
/// `sync_gradients` converts gradients found on those full parameters into
/// gradients for the local shards.
pub struct FSDP<M: Module<T>, T: Float> {
/// The wrapped module. After construction and after a successful `forward`
/// its parameters hold only this rank's shard.
module: M,
/// Communication backend; supplies `rank()`, `world_size()` and is handed to
/// the `all_gather` / `reduce_scatter` collectives.
backend: Arc<dyn Backend>,
/// Shape of each parameter before sharding, in `parameters_mut()` order.
original_shapes: Vec<Vec<usize>>,
/// Full (unsharded) parameter tensors built by the most recent `forward`;
/// read and then cleared by `sync_gradients`.
full_params: Vec<Tensor<T>>,
/// Marker tying the wrapper to the element type `T`.
_marker: std::marker::PhantomData<T>,
}
impl<M: Module<T>, T: Float> FSDP<M, T> {
/// Wraps `module`, replacing each of its parameters with this rank's shard.
///
/// Every parameter is read as a flat buffer and split into `world_size`
/// equal contiguous chunks; the chunk at index `rank` becomes a new 1-D
/// parameter. The shape each parameter had before sharding is recorded so
/// the full tensor can be rebuilt later.
///
/// # Panics
///
/// Panics if a parameter's element count is not a multiple of the backend's
/// world size.
///
/// # Errors
///
/// Propagates any error from reading a parameter's data or from building
/// the shard tensor.
pub fn new(mut module: M, backend: Arc<dyn Backend>) -> FerrotorchResult<Self> {
    let rank = backend.rank();
    let world_size = backend.world_size();
    let mut original_shapes = Vec::new();

    for slot in module.parameters_mut() {
        let full = slot.tensor();
        let numel = full.numel();
        assert!(
            numel % world_size == 0,
            "FSDP: parameter with {} elements is not evenly divisible by world_size {}",
            numel,
            world_size,
        );
        original_shapes.push(full.shape().to_vec());

        // Keep only the contiguous slice that belongs to this rank.
        let values = full.data_vec()?;
        let chunk_size = numel / world_size;
        let first = rank * chunk_size;
        let local = values[first..first + chunk_size].to_vec();

        *slot = Parameter::new(Tensor::from_storage(
            TensorStorage::cpu(local),
            vec![chunk_size],
            true,
        )?);
    }

    Ok(Self {
        module,
        backend,
        original_shapes,
        full_params: Vec::new(),
        _marker: std::marker::PhantomData,
    })
}
/// Returns a shared reference to the wrapped module.
pub fn module(&self) -> &M {
&self.module
}
/// Returns a mutable reference to the wrapped module.
///
/// NOTE: `forward` looks up the recorded original shapes by parameter
/// position, so the module's parameter list should keep the same length and
/// order it had at construction.
pub fn module_mut(&mut self) -> &mut M {
&mut self.module
}
/// Consumes the wrapper and returns the wrapped module as it currently is;
/// no gathering is performed on its parameters.
pub fn into_inner(self) -> M {
self.module
}
/// Returns the communication backend this wrapper was constructed with.
pub fn backend(&self) -> &Arc<dyn Backend> {
&self.backend
}
/// Runs the wrapped module's forward pass on fully gathered parameters.
///
/// Steps:
/// 1. For every parameter, gather the shards from all ranks (skipped when
///    `world_size == 1`) and rebuild a tensor with the original shape.
/// 2. Swap the full tensors into the module and keep a handle to each in
///    `full_params` for a later `sync_gradients` call.
/// 3. Call the module's `forward`.
/// 4. Re-shard the parameters, whether or not step 3 succeeded.
///
/// The module is only modified once every gather has succeeded, so a failed
/// gather leaves all parameters sharded.
///
/// # Errors
///
/// Returns the first error from gathering, tensor construction, the
/// module's forward pass, or re-sharding. If the forward pass fails, its
/// error is returned (taking precedence over a re-sharding error) and the
/// saved full parameters are discarded.
///
/// # Panics
///
/// Panics if the module exposes more parameters than it did at
/// construction, because recorded shapes are looked up by position.
pub fn forward(&mut self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let world_size = self.backend.world_size();
    self.full_params.clear();

    // Phase 1: build every full tensor without touching the module, so an
    // error here cannot leave a mix of sharded and unsharded parameters.
    let mut gathered = Vec::with_capacity(self.original_shapes.len());
    for (i, param) in self.module.parameters_mut().into_iter().enumerate() {
        let shard = param.tensor().clone();
        let full = if world_size == 1 {
            shard
        } else {
            all_gather(&shard, self.backend.as_ref())?
        };
        let full = Tensor::from_storage(
            TensorStorage::cpu(full.data_vec()?),
            self.original_shapes[i].clone(),
            true,
        )?;
        gathered.push(full);
    }

    // Phase 2: infallible swap of the full tensors into the module.
    for (param, full) in self.module.parameters_mut().into_iter().zip(gathered) {
        self.full_params.push(full.clone());
        *param = Parameter::new(full);
    }

    // Always re-shard, even when the module's forward pass fails; otherwise
    // the next call would gather already-full parameters.
    let output = self.module.forward(input);
    let restored = self.restore_shards();
    match output {
        Ok(out) => {
            restored?;
            Ok(out)
        }
        Err(err) => {
            // No usable forward pass happened: drop the full-size copies.
            self.full_params.clear();
            Err(err)
        }
    }
}
/// Replaces every (currently full) parameter with this rank's shard.
///
/// Each parameter's data is read as a flat buffer, divided into
/// `world_size` equal chunks, and the chunk at index `rank` becomes a new
/// 1-D parameter. Expects every parameter to hold the full tensor, which is
/// the state `forward` leaves the module in just before calling this.
///
/// # Errors
///
/// Propagates any error from reading parameter data or building the shard
/// tensor.
fn restore_shards(&mut self) -> FerrotorchResult<()> {
    let rank = self.backend.rank();
    let world_size = self.backend.world_size();

    for param in self.module.parameters_mut() {
        let data = param.tensor().data_vec()?;
        let chunk_size = data.len() / world_size;
        let start = rank * chunk_size;
        let shard_data = data[start..start + chunk_size].to_vec();
        let shard_tensor =
            Tensor::from_storage(TensorStorage::cpu(shard_data), vec![chunk_size], true)?;
        *param = Parameter::new(shard_tensor);
    }
    Ok(())
}
/// Converts gradients on the saved full parameters into shard gradients.
///
/// For each full parameter kept by the last `forward`:
/// - take its gradient, or an all-zero tensor of the same shape when it has
///   none;
/// - flatten it to 1-D;
/// - pass it through `reduce_scatter` with `ReduceOp::Mean` (skipped when
///   `world_size == 1`);
/// - store the result as the gradient of the matching module parameter.
///
/// The saved full parameters are cleared once all gradients are set.
///
/// # Errors
///
/// Returns `InvalidArgument` when the number of saved full parameters does
/// not match the module's parameter count, and propagates any tensor or
/// collective error.
pub fn sync_gradients(&mut self) -> FerrotorchResult<()> {
    let world_size = self.backend.world_size();
    let params = self.module.parameters_mut();

    let (expected, available) = (params.len(), self.full_params.len());
    if available != expected {
        return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
            message: format!(
                "FSDP sync_gradients: expected {} full_params but have {}. \
                 Was forward() called before backward()?",
                expected, available,
            ),
        });
    }

    // Lengths were just checked to be equal, so zipping visits every pair.
    for (param, full_param) in params.into_iter().zip(&self.full_params) {
        let full_grad = if let Some(existing) = full_param.grad()? {
            existing
        } else {
            // No gradient recorded: contribute zeros of the same shape.
            let zeros = vec![<T as num_traits::Zero>::zero(); full_param.numel()];
            Tensor::from_storage(
                TensorStorage::cpu(zeros),
                full_param.shape().to_vec(),
                false,
            )?
        };

        let flat_values = full_grad.data_vec()?;
        let flat_len = full_grad.numel();
        let flat_grad =
            Tensor::from_storage(TensorStorage::cpu(flat_values), vec![flat_len], false)?;

        let shard_grad = match world_size {
            1 => flat_grad,
            _ => reduce_scatter(&flat_grad, self.backend.as_ref(), ReduceOp::Mean)?,
        };
        param.tensor().set_grad(Some(shard_grad))?;
    }

    self.full_params.clear();
    Ok(())
}
/// Overwrites every parameter's contents from one flat buffer.
///
/// `flat_data` is consumed in parameter order: each parameter takes as many
/// elements as it currently holds and keeps its current shape. Every
/// parameter is rebuilt as a fresh `Parameter`.
///
/// # Panics
///
/// Panics if `flat_data.len()` differs from the total number of elements
/// across all current parameters.
///
/// # Errors
///
/// Propagates any error from building the replacement tensors.
pub fn update_shards(&mut self, flat_data: &[T]) -> FerrotorchResult<()> {
    let params = self.module.parameters_mut();

    let mut total_shard_numel = 0usize;
    for p in params.iter() {
        total_shard_numel += p.tensor().numel();
    }
    assert!(
        flat_data.len() == total_shard_numel,
        "FSDP update_shards: expected {} elements but got {}",
        total_shard_numel,
        flat_data.len(),
    );

    // Peel each parameter's slice off the front of the remaining buffer.
    let mut remaining = flat_data;
    for param in params {
        let (head, tail) = remaining.split_at(param.tensor().numel());
        let replacement = Tensor::from_storage(
            TensorStorage::cpu(head.to_vec()),
            param.tensor().shape().to_vec(),
            true,
        )?;
        *param = Parameter::new(replacement);
        remaining = tail;
    }
    Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::backend::SimulatedBackend;
use ferrotorch_core::storage::TensorStorage;
use ferrotorch_core::{FerrotorchResult, Tensor};
use ferrotorch_nn::Parameter;
use std::thread;
/// Minimal single-parameter module used to exercise `FSDP` in these tests.
struct TestModule<T: Float> {
/// The module's only parameter.
weight: Parameter<T>,
/// Flag toggled by `train()` / `eval()` and reported by `is_training()`.
training: bool,
}
impl<T: Float> TestModule<T> {
    /// Builds a module whose 1-D weight holds a copy of `data`; the module
    /// starts in training mode.
    fn new(data: &[T]) -> FerrotorchResult<Self> {
        let weight = Parameter::from_slice(data, &[data.len()])?;
        Ok(Self {
            weight,
            training: true,
        })
    }
}
impl<T: Float> Module<T> for TestModule<T> {
    /// Multiplies every input element by the sum of all weight elements.
    ///
    /// The output has the input's shape and is created with the last
    /// `from_storage` argument set to `false`.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let mut w_sum = <T as num_traits::Zero>::zero();
        for w in self.weight.tensor().data_vec()? {
            w_sum = w_sum + w;
        }
        let scaled: Vec<T> = input
            .data_vec()?
            .into_iter()
            .map(|x| x * w_sum)
            .collect();
        Tensor::from_storage(TensorStorage::cpu(scaled), input.shape().to_vec(), false)
    }

    /// Shared references to the module's parameters (just `weight`).
    fn parameters(&self) -> Vec<&Parameter<T>> {
        Vec::from([&self.weight])
    }

    /// Mutable references to the module's parameters (just `weight`).
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        Vec::from([&mut self.weight])
    }

    /// Parameters paired with their names (just `"weight"`).
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        Vec::from([(String::from("weight"), &self.weight)])
    }

    /// Switches the module to training mode.
    fn train(&mut self) {
        self.training = true;
    }

    /// Switches the module to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
    }

    /// Reports whether the module is in training mode.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// With two ranks, rank 0 keeps the first half of the weight and rank 1 the
/// second half.
#[test]
fn test_fsdp_sharding() {
    // The backends stay alive in this vector until every worker has joined.
    let backends: Vec<Arc<SimulatedBackend>> = SimulatedBackend::create_group(2)
        .unwrap()
        .into_iter()
        .map(Arc::new)
        .collect();

    let mut workers = Vec::with_capacity(backends.len());
    for backend in &backends {
        let backend = Arc::clone(backend);
        workers.push(thread::spawn(move || {
            let rank = backend.rank();
            let model = TestModule::<f32>::new(&[10.0, 20.0, 30.0, 40.0]).unwrap();
            let fsdp = FSDP::new(model, backend).unwrap();
            let shard = fsdp.module().weight.tensor().data_vec().unwrap();
            (rank, shard)
        }));
    }

    for worker in workers {
        let (rank, shard) = worker.join().unwrap();
        match rank {
            0 => assert_eq!(shard, &[10.0, 20.0]),
            _ => assert_eq!(shard, &[30.0, 40.0]),
        }
    }
}
/// Shards produced by `FSDP::new` must report `requires_grad() == true`.
#[test]
fn test_fsdp_shard_requires_grad() {
    // The backends stay alive in this vector until every worker has joined.
    let backends: Vec<Arc<SimulatedBackend>> = SimulatedBackend::create_group(2)
        .unwrap()
        .into_iter()
        .map(Arc::new)
        .collect();

    let mut workers = Vec::with_capacity(backends.len());
    for backend in &backends {
        let backend = Arc::clone(backend);
        workers.push(thread::spawn(move || {
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
            let fsdp = FSDP::new(model, backend).unwrap();
            fsdp.module().weight.tensor().requires_grad()
        }));
    }

    for worker in workers {
        let needs_grad = worker.join().unwrap();
        assert!(needs_grad, "shard must have requires_grad=true");
    }
}
/// After `forward`, the module's weight is a two-element shard again and
/// still requires grad.
#[test]
fn test_fsdp_forward_restores_shards() {
    // The backends stay alive in this vector until every worker has joined.
    let backends: Vec<Arc<SimulatedBackend>> = SimulatedBackend::create_group(2)
        .unwrap()
        .into_iter()
        .map(Arc::new)
        .collect();

    let mut workers = Vec::with_capacity(backends.len());
    for backend in &backends {
        let backend = Arc::clone(backend);
        workers.push(thread::spawn(move || {
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
            let mut fsdp = FSDP::new(model, backend).unwrap();
            let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
            let _result = fsdp.forward(&input).unwrap();

            let weight = fsdp.module().weight.tensor();
            assert_eq!(weight.numel(), 2);
            assert!(weight.requires_grad());
        }));
    }

    for worker in workers {
        worker.join().unwrap();
    }
}
/// The forward pass must see the full weight: sum(1..=4) = 10, times the
/// input 2.0, gives 20.0 on every rank.
#[test]
fn test_fsdp_forward_produces_correct_output() {
    // The backends stay alive in this vector until every worker has joined.
    let backends: Vec<Arc<SimulatedBackend>> = SimulatedBackend::create_group(2)
        .unwrap()
        .into_iter()
        .map(Arc::new)
        .collect();

    let mut workers = Vec::with_capacity(backends.len());
    for backend in &backends {
        let backend = Arc::clone(backend);
        workers.push(thread::spawn(move || {
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
            let mut fsdp = FSDP::new(model, backend).unwrap();
            let input = ferrotorch_core::from_slice(&[2.0f32], &[1]).unwrap();
            let values = fsdp.forward(&input).unwrap().data_vec().unwrap();
            assert!(
                (values[0] - 20.0).abs() < 1e-6,
                "expected 20.0, got {}",
                values[0]
            );
        }));
    }

    for worker in workers {
        worker.join().unwrap();
    }
}
/// With a single rank, `update_shards` replaces the whole weight.
#[test]
fn test_fsdp_update_shards() {
    let only_rank = SimulatedBackend::create_group(1)
        .unwrap()
        .into_iter()
        .next()
        .unwrap();
    let backend: Arc<dyn Backend> = Arc::new(only_rank);

    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
    let mut fsdp = FSDP::new(model, backend).unwrap();
    fsdp.update_shards(&[10.0, 20.0, 30.0, 40.0]).unwrap();

    let updated = fsdp.module().weight.tensor().data_vec().unwrap();
    assert_eq!(updated, &[10.0, 20.0, 30.0, 40.0]);
}
/// A flat buffer of the wrong length must make `update_shards` panic.
#[test]
#[should_panic(expected = "expected 4 elements but got 2")]
fn test_fsdp_update_shards_size_validation() {
    let only_rank = SimulatedBackend::create_group(1)
        .unwrap()
        .into_iter()
        .next()
        .unwrap();
    let backend: Arc<dyn Backend> = Arc::new(only_rank);

    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
    let mut fsdp = FSDP::new(model, backend).unwrap();

    // Two values for a four-element shard: expected to panic.
    fsdp.update_shards(&[10.0, 20.0]).unwrap();
}
/// With one rank, `sync_gradients` hands the full gradient to the shard
/// unchanged.
#[test]
fn test_fsdp_sync_gradients_single_rank() {
    let only_rank = SimulatedBackend::create_group(1)
        .unwrap()
        .into_iter()
        .next()
        .unwrap();
    let backend: Arc<dyn Backend> = Arc::new(only_rank);

    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
    let mut fsdp = FSDP::new(model, backend).unwrap();
    let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
    let _result = fsdp.forward(&input).unwrap();

    // Attach a gradient to the saved full parameter by hand.
    let storage = TensorStorage::cpu(vec![0.1f32, 0.2, 0.3, 0.4]);
    let injected = Tensor::from_storage(storage, vec![4], false).unwrap();
    fsdp.full_params[0].set_grad(Some(injected)).unwrap();

    fsdp.sync_gradients().unwrap();

    let synced = fsdp
        .module()
        .weight
        .tensor()
        .grad()
        .unwrap()
        .unwrap()
        .data_vec()
        .unwrap();
    assert_eq!(synced, &[0.1, 0.2, 0.3, 0.4]);
}
/// With two ranks holding identical full gradients, each rank ends up with
/// its own half: rank 0 gets [1, 2], rank 1 gets [3, 4].
#[test]
fn test_fsdp_sync_gradients_multi_rank() {
    // The backends stay alive in this vector until every worker has joined.
    let backends: Vec<Arc<SimulatedBackend>> = SimulatedBackend::create_group(2)
        .unwrap()
        .into_iter()
        .map(Arc::new)
        .collect();

    let mut workers = Vec::with_capacity(backends.len());
    for backend in &backends {
        let backend = Arc::clone(backend);
        workers.push(thread::spawn(move || {
            let rank = backend.rank();
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
            let mut fsdp = FSDP::new(model, backend).unwrap();
            let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
            let _result = fsdp.forward(&input).unwrap();

            // Attach the same gradient to the saved full parameter on every rank.
            let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]);
            let injected = Tensor::from_storage(storage, vec![4], false).unwrap();
            fsdp.full_params[0].set_grad(Some(injected)).unwrap();

            fsdp.sync_gradients().unwrap();

            let shard_grad = fsdp.module().weight.tensor().grad().unwrap().unwrap();
            (rank, shard_grad.data_vec().unwrap())
        }));
    }

    for worker in workers {
        let (rank, data) = worker.join().unwrap();
        assert_eq!(data.len(), 2);
        match rank {
            0 => {
                assert!(
                    (data[0] - 1.0).abs() < 1e-6,
                    "rank 0: expected 1.0, got {}",
                    data[0]
                );
                assert!(
                    (data[1] - 2.0).abs() < 1e-6,
                    "rank 0: expected 2.0, got {}",
                    data[1]
                );
            }
            _ => {
                assert!(
                    (data[0] - 3.0).abs() < 1e-6,
                    "rank 1: expected 3.0, got {}",
                    data[0]
                );
                assert!(
                    (data[1] - 4.0).abs() < 1e-6,
                    "rank 1: expected 4.0, got {}",
                    data[1]
                );
            }
        }
    }
}
}