use std::sync::Arc;
use ferrotorch_core::storage::TensorStorage;
use ferrotorch_core::{FerrotorchResult, Float, Tensor};
use ferrotorch_nn::{Module, Parameter};
use crate::async_collective::{PendingCollective, async_all_gather};
use crate::backend::{Backend, SubBackend};
use crate::collective::{ReduceOp, all_gather, allreduce, reduce_scatter};
/// Selects how an [`FSDP`] wrapper lays out parameters and synchronizes
/// gradients across the ranks of its backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
/// Default. Every parameter is replaced at construction by this rank's
/// `numel / world_size` slice; `forward` all-gathers the full tensors and
/// re-slices them afterwards, and `sync_gradients` reduce-scatters (mean)
/// gradients so each rank ends up with the gradient of its own slice.
#[default]
FullShard,
/// Parameters are left whole. `sync_gradients` reduce-scatters (mean) the
/// gradient and stores this rank's chunk inside an otherwise zero-filled,
/// full-shape gradient; `broadcast_updated_params` all-gathers each rank's
/// slice of the parameter back into a full tensor.
ShardGradOp,
/// Parameters are left whole and gradients are all-reduced (mean) over the
/// whole backend.
NoShard,
/// Parameters are sliced across groups of `intra_node_size` consecutive
/// ranks; gradients are reduce-scattered (mean) inside such a group and then
/// all-reduced (mean) across the ranks holding the same position in every
/// group. The constructor rejects the strategy unless the backend's world
/// size is a positive multiple of `intra_node_size`.
HybridShard { intra_node_size: usize },
}
/// Wrapper around a [`Module`] that shards its parameters and/or gradients
/// across the ranks of a [`Backend`] according to a [`ShardingStrategy`].
pub struct FSDP<M: Module<T>, T: Float> {
// The wrapped module; its parameters are swapped between shard and full form.
module: M,
// Communicator spanning every rank taking part in training.
backend: Arc<dyn Backend>,
// Strategy chosen at construction time.
strategy: ShardingStrategy,
// Shape of every parameter as seen at construction, in `parameters_mut()` order.
original_shapes: Vec<Vec<usize>>,
// Full tensors installed by the most recent `forward`; read and then cleared
// by `sync_gradients`.
full_params: Vec<Tensor<T>>,
// Handles started by `prefetch_forward_params` and consumed by the next `forward`.
pending_prefetch: Option<Vec<PendingCollective<T>>>,
// Sub-communicators built only for `HybridShard`; `None` for other strategies.
intra_node_group: Option<Arc<SubBackend>>,
inter_node_group: Option<Arc<SubBackend>>,
_marker: std::marker::PhantomData<T>,
}
impl<M: Module<T>, T: Float> FSDP<M, T> {
/// Wraps `module` with the [`ShardingStrategy::FullShard`] strategy.
///
/// This is a shorthand for [`FSDP::new_with_strategy`]; see that function for
/// the error conditions.
pub fn new(module: M, backend: Arc<dyn Backend>) -> FerrotorchResult<Self> {
Self::new_with_strategy(module, backend, ShardingStrategy::FullShard)
}
/// Wraps `module` and lays its parameters out according to `strategy`.
///
/// The shape of every parameter is recorded first. Under `FullShard` each
/// parameter is then replaced by this rank's contiguous `numel / world_size`
/// slice (flattened, `requires_grad = true`); under `HybridShard` the slice is
/// taken over the intra-node group instead. `ShardGradOp` and `NoShard` leave
/// the parameters untouched.
///
/// # Errors
///
/// Returns `InvalidArgument` when
/// * `HybridShard` is requested and the world size is not a positive multiple
///   of `intra_node_size`, or
/// * a parameter's element count is not divisible by the size of the group it
///   is sharded over.
///
/// Errors from building the sub-backends or from reading/creating tensors are
/// propagated unchanged.
pub fn new_with_strategy(
    mut module: M,
    backend: Arc<dyn Backend>,
    strategy: ShardingStrategy,
) -> FerrotorchResult<Self> {
    let rank = backend.rank();
    let world_size = backend.world_size();

    let (intra_node_group, inter_node_group) = match strategy {
        ShardingStrategy::HybridShard { intra_node_size } => {
            if intra_node_size == 0 || world_size % intra_node_size != 0 {
                return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
                    message: format!(
                        "HybridShard: world_size={world_size} must be a positive multiple of intra_node_size={intra_node_size}"
                    ),
                });
            }
            let node_idx = rank / intra_node_size;
            let local_idx = rank % intra_node_size;
            // Ranks of one node are consecutive: [node_start, node_start + size).
            let node_start = node_idx * intra_node_size;
            let intra_members: Vec<usize> = (node_start..node_start + intra_node_size).collect();
            let intra = Arc::new(SubBackend::new(Arc::clone(&backend), intra_members)?);
            // Peers across nodes share this rank's position inside its node.
            let inter_members: Vec<usize> = (0..world_size / intra_node_size)
                .map(|n| n * intra_node_size + local_idx)
                .collect();
            let inter = Arc::new(SubBackend::new(Arc::clone(&backend), inter_members)?);
            (Some(intra), Some(inter))
        }
        ShardingStrategy::FullShard
        | ShardingStrategy::ShardGradOp
        | ShardingStrategy::NoShard => (None, None),
    };

    // `(rank, size)` of the group parameters are sliced over, or `None` when
    // the strategy keeps parameters whole.
    let shard_group: Option<(usize, usize)> = match strategy {
        ShardingStrategy::FullShard => Some((rank, world_size)),
        ShardingStrategy::HybridShard { .. } => {
            let intra = intra_node_group.as_ref().expect(
                "FSDP::new_with_strategy: intra_node_group is Some \
                 for HybridShard (set ~50 lines above on the same \
                 match arm; never reassigned)",
            );
            Some((intra.rank(), intra.world_size()))
        }
        ShardingStrategy::ShardGradOp | ShardingStrategy::NoShard => None,
    };

    let mut original_shapes = Vec::new();
    for param in module.parameters_mut() {
        let tensor = param.tensor();
        original_shapes.push(tensor.shape().to_vec());

        if let Some((group_rank, group_size)) = shard_group {
            let numel = tensor.numel();
            if numel % group_size != 0 {
                // Recoverable misconfiguration: report it instead of panicking.
                let message = if matches!(strategy, ShardingStrategy::HybridShard { .. }) {
                    format!(
                        "FSDP HybridShard: parameter with {numel} elements is not evenly divisible by intra_node_size {group_size}"
                    )
                } else {
                    format!(
                        "FSDP: parameter with {numel} elements is not evenly divisible by world_size {group_size}"
                    )
                };
                return Err(ferrotorch_core::FerrotorchError::InvalidArgument { message });
            }
            let data = tensor.data_vec()?;
            let chunk_size = numel / group_size;
            let start = group_rank * chunk_size;
            let shard_data = data[start..start + chunk_size].to_vec();
            let shard_tensor =
                Tensor::from_storage(TensorStorage::cpu(shard_data), vec![chunk_size], true)?;
            *param = Parameter::new(shard_tensor);
        }
    }

    Ok(Self {
        module,
        backend,
        strategy,
        original_shapes,
        full_params: Vec::new(),
        pending_prefetch: None,
        intra_node_group,
        inter_node_group,
        _marker: std::marker::PhantomData,
    })
}
/// Returns the sharding strategy this wrapper was constructed with.
pub fn strategy(&self) -> ShardingStrategy {
self.strategy
}
/// Starts an asynchronous all-gather for every parameter shard so the next
/// [`FSDP::forward`] can wait on the results instead of gathering itself.
///
/// # Errors
///
/// Returns `InvalidArgument` when the strategy is not `FullShard`, or when a
/// previous prefetch has not yet been consumed by `forward`.
pub fn prefetch_forward_params(&mut self) -> FerrotorchResult<()> {
    if !matches!(self.strategy, ShardingStrategy::FullShard) {
        let message = format!(
            "FSDP::prefetch_forward_params: prefetch is only valid for FullShard, got {:?}",
            self.strategy
        );
        return Err(ferrotorch_core::FerrotorchError::InvalidArgument { message });
    }
    if self.pending_prefetch.is_some() {
        return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
            message: "FSDP::prefetch_forward_params called twice without intervening forward()"
                .into(),
        });
    }

    // One handle per parameter, in `parameters()` order.
    let backend = &self.backend;
    let in_flight: Vec<_> = self
        .module
        .parameters()
        .into_iter()
        .map(|param| async_all_gather(param.tensor().clone(), Arc::clone(backend)))
        .collect();

    self.pending_prefetch = Some(in_flight);
    Ok(())
}
/// Returns `true` while handles started by `prefetch_forward_params` are
/// waiting to be consumed by the next `forward`.
pub fn has_pending_prefetch(&self) -> bool {
self.pending_prefetch.is_some()
}
/// Borrows the wrapped module. Its parameters are in whatever form the
/// wrapper currently holds them (shards under the sharding strategies).
pub fn module(&self) -> &M {
&self.module
}
/// Mutably borrows the wrapped module.
///
/// NOTE(review): parameter shapes were recorded at construction and are
/// indexed by position in `forward`; changing the number or size of
/// parameters through this reference is presumably unsupported — confirm.
pub fn module_mut(&mut self) -> &mut M {
&mut self.module
}
/// Consumes the wrapper and returns the wrapped module with its parameters
/// in their current form (no gathering is performed here).
pub fn into_inner(self) -> M {
self.module
}
/// Returns the backend spanning all ranks this wrapper communicates over.
pub fn backend(&self) -> &Arc<dyn Backend> {
&self.backend
}
/// Runs the wrapped module on `input` with full parameters in place.
///
/// Steps:
/// 1. Any stale `full_params` are dropped and a pending prefetch is taken.
/// 2. Full parameters are materialised: `FullShard` waits on the prefetched
///    handles (or all-gathers over the whole backend), `HybridShard`
///    all-gathers over the intra-node group, and `ShardGradOp` / `NoShard`
///    rebuild each already-full parameter as a fresh tensor. Every full tensor
///    is installed in the module and remembered in `full_params` for
///    `sync_gradients`.
/// 3. The module's forward is evaluated.
/// 4. Under `FullShard` / `HybridShard` the parameters are sliced back to this
///    rank's shard. This happens even when step 3 failed, so a failing module
///    does not leave gathered tensors installed where shards are expected.
///
/// # Errors
///
/// Returns `InvalidArgument` if the number of prefetched handles does not
/// match the number of parameters, or if a prefetch is pending under a
/// strategy other than `FullShard`. Collective, tensor and module errors are
/// propagated; when both the module and the re-sharding fail, the module's
/// error is the one returned. If materialising the full parameters (step 2)
/// fails part-way, parameters processed so far stay in full form.
pub fn forward(&mut self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let world_size = self.backend.world_size();
    self.full_params.clear();
    let pending = self.pending_prefetch.take();

    match self.strategy {
        ShardingStrategy::FullShard => {
            // Handles were pushed in parameter order, so consume them in order.
            let mut prefetched = match pending {
                Some(handles) => {
                    let param_count = self.module.parameters().len();
                    if handles.len() != param_count {
                        return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
                            message: format!(
                                "FSDP prefetch: have {} pending handles but module has {} parameters",
                                handles.len(),
                                param_count,
                            ),
                        });
                    }
                    Some(handles.into_iter())
                }
                None => None,
            };

            let params = self.module.parameters_mut();
            for (i, param) in params.into_iter().enumerate() {
                let orig_shape = &self.original_shapes[i];
                let gathered = match prefetched.as_mut() {
                    Some(handles) => {
                        let handle = handles.next().ok_or_else(|| {
                            ferrotorch_core::FerrotorchError::InvalidArgument {
                                message: "FSDP prefetch: exhausted pending handles".into(),
                            }
                        })?;
                        handle.wait()?
                    }
                    None => {
                        let shard = param.tensor().clone();
                        if world_size == 1 {
                            shard
                        } else {
                            all_gather(&shard, self.backend.as_ref())?
                        }
                    }
                };
                // Re-wrap the flat gathered data in the parameter's original shape.
                let full = Tensor::from_storage(
                    TensorStorage::cpu(gathered.data_vec()?),
                    orig_shape.clone(),
                    true,
                )?;
                self.full_params.push(full.clone());
                *param = Parameter::new(full);
            }
        }
        ShardingStrategy::HybridShard { .. } => {
            if pending.is_some() {
                return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
                    message:
                        "FSDP prefetch_forward_params is not yet implemented for HybridShard"
                            .into(),
                });
            }
            let intra = self.intra_node_group.as_ref().expect(
                "FSDP::forward (HybridShard): intra_node_group is Some for \
                 HybridShard strategy (paired with strategy in \
                 new_with_strategy; never reassigned)",
            );
            let intra_ref: &dyn Backend = &**intra;
            let intra_size = intra.world_size();

            let params = self.module.parameters_mut();
            for (i, param) in params.into_iter().enumerate() {
                let orig_shape = &self.original_shapes[i];
                let shard = param.tensor().clone();
                let gathered = if intra_size == 1 {
                    shard
                } else {
                    all_gather(&shard, intra_ref)?
                };
                let full = Tensor::from_storage(
                    TensorStorage::cpu(gathered.data_vec()?),
                    orig_shape.clone(),
                    true,
                )?;
                self.full_params.push(full.clone());
                *param = Parameter::new(full);
            }
        }
        ShardingStrategy::ShardGradOp | ShardingStrategy::NoShard => {
            if pending.is_some() {
                return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
                    message: "FSDP prefetch_forward_params is only valid for FullShard; \
                              use ShardingStrategy::FullShard or don't prefetch"
                        .into(),
                });
            }
            // Parameters are already full; install a fresh tensor per parameter
            // and remember it for `sync_gradients`.
            for param in self.module.parameters_mut() {
                let current = param.tensor().clone();
                let data = current.data_vec()?;
                let shape = current.shape().to_vec();
                let full = Tensor::from_storage(TensorStorage::cpu(data), shape, true)?;
                self.full_params.push(full.clone());
                *param = Parameter::new(full);
            }
        }
    }

    // Capture the result first: shards must be restored whether or not the
    // wrapped module succeeded.
    let output = self.module.forward(input);
    let restored = match self.strategy {
        ShardingStrategy::FullShard => self.restore_shards(),
        ShardingStrategy::HybridShard { .. } => self.restore_hybrid_shards(),
        ShardingStrategy::ShardGradOp | ShardingStrategy::NoShard => Ok(()),
    };
    let output = output?;
    restored?;
    Ok(output)
}
fn restore_hybrid_shards(&mut self) -> FerrotorchResult<()> {
let intra = self
.intra_node_group
.as_ref()
.expect(
"FSDP::restore_hybrid_shards: intra_node_group is Some for \
HybridShard (only callsite is forward() under HybridShard arm)",
)
.clone();
let intra_size = intra.world_size();
let intra_rank = intra.rank();
let params = self.module.parameters_mut();
for (i, param) in params.into_iter().enumerate() {
let tensor = param.tensor();
let data = tensor.data_vec()?;
let numel = data.len();
let chunk_size = numel / intra_size;
let start = intra_rank * chunk_size;
let end = start + chunk_size;
let shard_data = data[start..end].to_vec();
let shard_tensor =
Tensor::from_storage(TensorStorage::cpu(shard_data), vec![chunk_size], true)?;
*param = Parameter::new(shard_tensor);
let _ = &self.original_shapes[i];
}
Ok(())
}
fn restore_shards(&mut self) -> FerrotorchResult<()> {
let rank = self.backend.rank();
let world_size = self.backend.world_size();
let params = self.module.parameters_mut();
for (i, param) in params.into_iter().enumerate() {
let tensor = param.tensor();
let data = tensor.data_vec()?;
let numel = data.len();
let chunk_size = numel / world_size;
let start = rank * chunk_size;
let end = start + chunk_size;
let shard_data = data[start..end].to_vec();
let shard_tensor =
Tensor::from_storage(TensorStorage::cpu(shard_data), vec![chunk_size], true)?;
*param = Parameter::new(shard_tensor);
let _ = &self.original_shapes[i];
}
Ok(())
}
/// Reduces the gradients accumulated on the full parameters of the last
/// `forward` and attaches the result to the module's current parameters.
///
/// For each parameter the gradient of the matching `full_params` entry is
/// read (a zero tensor is substituted when none is set), flattened, and then
/// * `FullShard`: reduce-scattered (mean) over the backend; the chunk becomes
///   the shard's gradient,
/// * `ShardGradOp`: reduce-scattered (mean); this rank's chunk is written into
///   an otherwise zero, full-shape gradient,
/// * `HybridShard`: reduce-scattered (mean) inside the intra-node group, then
///   all-reduced (mean) across the inter-node group,
/// * `NoShard`: all-reduced (mean) over the backend and reshaped to the full
///   parameter shape.
///
/// Groups of size one skip the collective. `full_params` is cleared on
/// success.
///
/// # Errors
///
/// Returns `InvalidArgument` when the number of remembered full parameters
/// differs from the module's parameter count, or (under `ShardGradOp`) when a
/// gradient's element count is not divisible by the world size. Collective
/// and tensor errors are propagated.
pub fn sync_gradients(&mut self) -> FerrotorchResult<()> {
    let rank = self.backend.rank();
    let world_size = self.backend.world_size();
    let params = self.module.parameters_mut();
    if self.full_params.len() != params.len() {
        return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
            message: format!(
                "FSDP sync_gradients: expected {} full_params but have {}. \
                 Was forward() called before backward()?",
                params.len(),
                self.full_params.len(),
            ),
        });
    }

    for (param, full_param) in params.into_iter().zip(self.full_params.iter()) {
        let full_grad = match full_param.grad()? {
            Some(g) => g,
            // No gradient recorded: contribute zeros so every rank still
            // takes part in the collective.
            None => Tensor::from_storage(
                TensorStorage::cpu(vec![<T as num_traits::Zero>::zero(); full_param.numel()]),
                full_param.shape().to_vec(),
                false,
            )?,
        };
        let flat_grad = Tensor::from_storage(
            TensorStorage::cpu(full_grad.data_vec()?),
            vec![full_grad.numel()],
            false,
        )?;

        match self.strategy {
            ShardingStrategy::FullShard => {
                let shard_grad = if world_size == 1 {
                    flat_grad
                } else {
                    reduce_scatter(&flat_grad, self.backend.as_ref(), ReduceOp::Mean)?
                };
                param.tensor().set_grad(Some(shard_grad))?;
            }
            ShardingStrategy::ShardGradOp => {
                let numel = flat_grad.numel();
                if numel % world_size != 0 {
                    // Checked before the collective so every rank bails out
                    // consistently instead of panicking.
                    return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
                        message: format!(
                            "FSDP ShardGradOp: parameter with {numel} elements is not evenly \
                             divisible by world_size {world_size}"
                        ),
                    });
                }
                let shard_grad_flat = if world_size == 1 {
                    flat_grad
                } else {
                    reduce_scatter(&flat_grad, self.backend.as_ref(), ReduceOp::Mean)?
                };
                let chunk_size = numel / world_size;
                let shard_data = shard_grad_flat.data_vec()?;
                // Only this rank's chunk is non-zero in the full-shape gradient.
                let mut padded = vec![<T as num_traits::Zero>::zero(); numel];
                let start = rank * chunk_size;
                padded[start..start + chunk_size].copy_from_slice(&shard_data);
                let padded_grad = Tensor::from_storage(
                    TensorStorage::cpu(padded),
                    full_param.shape().to_vec(),
                    false,
                )?;
                param.tensor().set_grad(Some(padded_grad))?;
            }
            ShardingStrategy::HybridShard { .. } => {
                let intra = self.intra_node_group.as_ref().expect(
                    "FSDP::sync_gradients (HybridShard): intra_node_group \
                     is Some for HybridShard (paired with strategy in \
                     new_with_strategy; never reassigned)",
                );
                let inter = self.inter_node_group.as_ref().expect(
                    "FSDP::sync_gradients (HybridShard): inter_node_group \
                     is Some for HybridShard (paired with strategy in \
                     new_with_strategy; never reassigned)",
                );
                let intra_ref: &dyn Backend = &**intra;
                let inter_ref: &dyn Backend = &**inter;
                let intra_shard = if intra.world_size() == 1 {
                    flat_grad
                } else {
                    reduce_scatter(&flat_grad, intra_ref, ReduceOp::Mean)?
                };
                let replicated = if inter.world_size() == 1 {
                    intra_shard
                } else {
                    allreduce(&intra_shard, inter_ref, ReduceOp::Mean)?
                };
                param.tensor().set_grad(Some(replicated))?;
            }
            ShardingStrategy::NoShard => {
                let reduced = if world_size == 1 {
                    flat_grad
                } else {
                    allreduce(&flat_grad, self.backend.as_ref(), ReduceOp::Mean)?
                };
                let reduced_full = Tensor::from_storage(
                    TensorStorage::cpu(reduced.data_vec()?),
                    full_param.shape().to_vec(),
                    false,
                )?;
                param.tensor().set_grad(Some(reduced_full))?;
            }
        }
    }

    self.full_params.clear();
    Ok(())
}
/// Rebuilds every full parameter from the slices each rank owns.
///
/// Does nothing unless the strategy is `ShardGradOp` and the backend has more
/// than one rank. Otherwise each rank contributes the `rank`-th of
/// `world_size` equal chunks of its local parameter data to an all-gather, and
/// the gathered data replaces the parameter (same shape,
/// `requires_grad = true`).
///
/// # Panics
///
/// Panics if a parameter's element count is not divisible by the world size.
///
/// # Errors
///
/// Propagates tensor and collective errors.
pub fn broadcast_updated_params(&mut self) -> FerrotorchResult<()> {
    if !matches!(self.strategy, ShardingStrategy::ShardGradOp) {
        return Ok(());
    }
    let comm = self.backend.as_ref();
    let rank = comm.rank();
    let world_size = comm.world_size();
    if world_size == 1 {
        return Ok(());
    }

    for param in self.module.parameters_mut() {
        let local = param.tensor();
        let local_values = local.data_vec()?;
        let numel = local_values.len();
        assert!(
            numel % world_size == 0,
            "FSDP broadcast_updated_params: parameter with {numel} elements is not evenly \
             divisible by world_size {world_size}"
        );
        let chunk_size = numel / world_size;
        let first = rank * chunk_size;
        // Only the slice this rank owns is contributed to the gather.
        let owned_slice = local_values[first..first + chunk_size].to_vec();
        let contribution =
            Tensor::from_storage(TensorStorage::cpu(owned_slice), vec![chunk_size], false)?;
        let gathered = all_gather(&contribution, comm)?;
        let shape = local.shape().to_vec();
        let rebuilt = Tensor::from_storage(TensorStorage::cpu(gathered.data_vec()?), shape, true)?;
        *param = Parameter::new(rebuilt);
    }
    Ok(())
}
/// Overwrites the module's current parameters from one flat buffer.
///
/// `flat_data` is consumed front to back: each parameter, in
/// `parameters_mut()` order, takes as many elements as it currently holds and
/// keeps its current shape. The replacement tensors have
/// `requires_grad = true`.
///
/// # Panics
///
/// Panics if `flat_data.len()` differs from the total element count of the
/// current parameters.
///
/// # Errors
///
/// Propagates tensor-construction errors.
pub fn update_shards(&mut self, flat_data: &[T]) -> FerrotorchResult<()> {
    let params = self.module.parameters_mut();
    let total_shard_numel = params
        .iter()
        .fold(0usize, |acc, p| acc + p.tensor().numel());
    assert!(
        flat_data.len() == total_shard_numel,
        "FSDP update_shards: expected {} elements but got {}",
        total_shard_numel,
        flat_data.len(),
    );

    // Peel one parameter's worth of values off the front on each iteration.
    let mut remaining = flat_data;
    for param in params {
        let numel = param.tensor().numel();
        let (values, rest) = remaining.split_at(numel);
        let shape = param.tensor().shape().to_vec();
        let replacement = Tensor::from_storage(TensorStorage::cpu(values.to_vec()), shape, true)?;
        *param = Parameter::new(replacement);
        remaining = rest;
    }
    Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_core::storage::TensorStorage;
    use ferrotorch_core::{FerrotorchResult, Tensor};
    use ferrotorch_nn::Parameter;
    use std::thread;

    /// Single-parameter module whose output is `input * sum(weight)`.
    struct TestModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> TestModule<T> {
        fn new(data: &[T]) -> FerrotorchResult<Self> {
            let weight = Parameter::from_slice(data, &[data.len()])?;
            Ok(Self {
                weight,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for TestModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            let mut w_sum = <T as num_traits::Zero>::zero();
            for w in self.weight.tensor().data_vec()? {
                w_sum = w_sum + w;
            }
            let scaled: Vec<T> = input.data_vec()?.iter().map(|&x| x * w_sum).collect();
            Tensor::from_storage(TensorStorage::cpu(scaled), input.shape().to_vec(), false)
        }
        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }
        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }
        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".into(), &self.weight)]
        }
        fn train(&mut self) {
            self.training = true;
        }
        fn eval(&mut self) {
            self.training = false;
        }
        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Runs `body` once per backend of `group`, each on its own thread, and
    /// returns the results in spawn order. All threads are started before any
    /// is joined so collectives inside `body` can make progress.
    fn spawn_ranks<G, R, F>(group: G, body: F) -> Vec<R>
    where
        G: IntoIterator<Item = SimulatedBackend>,
        R: Send + 'static,
        F: Fn(Arc<SimulatedBackend>) -> R + Send + Sync + 'static,
    {
        // Kept alive until every worker has been joined.
        let backends: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
        let body = Arc::new(body);
        let workers: Vec<_> = backends
            .iter()
            .cloned()
            .map(|backend| {
                let body = Arc::clone(&body);
                thread::spawn(move || body(backend))
            })
            .collect();
        workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .collect()
    }

    /// Backend of a one-rank simulated group.
    fn solo_backend() -> Arc<dyn Backend> {
        let group = SimulatedBackend::create_group(1).unwrap();
        Arc::new(group.into_iter().next().unwrap())
    }

    /// `f32` test module holding `values` as its single weight.
    fn model_of(values: &[f32]) -> TestModule<f32> {
        TestModule::<f32>::new(values).unwrap()
    }

    /// One-element input tensor.
    fn scalar_input(value: f32) -> Tensor<f32> {
        ferrotorch_core::from_slice(&[value], &[1]).unwrap()
    }

    /// Flat, non-differentiable tensor used as a hand-written gradient.
    fn grad_of(values: Vec<f32>) -> Tensor<f32> {
        let len = values.len();
        Tensor::from_storage(TensorStorage::cpu(values), vec![len], false).unwrap()
    }

    #[test]
    fn test_fsdp_sharding() {
        let results = spawn_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let rank = backend.rank();
            let fsdp = FSDP::new(model_of(&[10.0, 20.0, 30.0, 40.0]), backend).unwrap();
            let shard = fsdp.module().weight.tensor().data_vec().unwrap();
            (rank, shard)
        });
        for (rank, shard) in results {
            if rank == 0 {
                assert_eq!(shard, &[10.0, 20.0]);
            } else {
                assert_eq!(shard, &[30.0, 40.0]);
            }
        }
    }

    #[test]
    fn test_fsdp_shard_requires_grad() {
        let flags = spawn_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let fsdp = FSDP::new(model_of(&[1.0, 2.0, 3.0, 4.0]), backend).unwrap();
            fsdp.module().weight.tensor().requires_grad()
        });
        for flag in flags {
            assert!(flag, "shard must have requires_grad=true");
        }
    }

    #[test]
    fn test_fsdp_forward_restores_shards() {
        spawn_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let mut fsdp = FSDP::new(model_of(&[1.0, 2.0, 3.0, 4.0]), backend).unwrap();
            let input = scalar_input(1.0);
            let _output = fsdp.forward(&input).unwrap();
            let shard = fsdp.module().weight.tensor();
            assert_eq!(shard.numel(), 2);
            assert!(shard.requires_grad());
        });
    }

    #[test]
    fn test_fsdp_forward_produces_correct_output() {
        spawn_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let mut fsdp = FSDP::new(model_of(&[1.0, 2.0, 3.0, 4.0]), backend).unwrap();
            let input = scalar_input(2.0);
            let output = fsdp.forward(&input).unwrap();
            let data = output.data_vec().unwrap();
            assert!(
                (data[0] - 20.0).abs() < 1e-6,
                "expected 20.0, got {}",
                data[0]
            );
        });
    }

    #[test]
    fn test_fsdp_update_shards() {
        let mut fsdp = FSDP::new(model_of(&[1.0, 2.0, 3.0, 4.0]), solo_backend()).unwrap();
        fsdp.update_shards(&[10.0, 20.0, 30.0, 40.0]).unwrap();
        let data = fsdp.module().weight.tensor().data_vec().unwrap();
        assert_eq!(data, &[10.0, 20.0, 30.0, 40.0]);
    }

    #[test]
    #[should_panic(expected = "expected 4 elements but got 2")]
    fn test_fsdp_update_shards_size_validation() {
        let mut fsdp = FSDP::new(model_of(&[1.0, 2.0, 3.0, 4.0]), solo_backend()).unwrap();
        fsdp.update_shards(&[10.0, 20.0]).unwrap();
    }

    #[test]
    fn test_fsdp_sync_gradients_single_rank() {
        let mut fsdp = FSDP::new(model_of(&[1.0, 2.0, 3.0, 4.0]), solo_backend()).unwrap();
        let input = scalar_input(1.0);
        let _output = fsdp.forward(&input).unwrap();
        let grad = grad_of(vec![0.1f32, 0.2, 0.3, 0.4]);
        fsdp.full_params[0].set_grad(Some(grad)).unwrap();
        fsdp.sync_gradients().unwrap();
        let shard_grad = fsdp.module().weight.tensor().grad().unwrap().unwrap();
        let data = shard_grad.data_vec().unwrap();
        assert_eq!(data, &[0.1, 0.2, 0.3, 0.4]);
    }

    #[test]
    fn test_fsdp_shard_grad_op_keeps_full_params() {
        let weights = spawn_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let fsdp = FSDP::new_with_strategy(
                model_of(&[1.0, 2.0, 3.0, 4.0]),
                backend,
                ShardingStrategy::ShardGradOp,
            )
            .unwrap();
            assert_eq!(fsdp.strategy(), ShardingStrategy::ShardGradOp);
            fsdp.module().weight.tensor().data_vec().unwrap()
        });
        for data in weights {
            assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]);
        }
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn test_fsdp_shard_grad_op_sync_gradients_multi_rank() {
        let results = spawn_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let rank = backend.rank();
            let mut fsdp = FSDP::new_with_strategy(
                model_of(&[1.0, 2.0, 3.0, 4.0]),
                backend,
                ShardingStrategy::ShardGradOp,
            )
            .unwrap();
            let input = scalar_input(1.0);
            let _output = fsdp.forward(&input).unwrap();
            let grad = grad_of(vec![1.0f32, 2.0, 3.0, 4.0]);
            fsdp.full_params[0].set_grad(Some(grad)).unwrap();
            fsdp.sync_gradients().unwrap();
            let w = fsdp.module().weight.tensor();
            assert_eq!(w.numel(), 4, "ShardGradOp keeps params full");
            let g = w.grad().unwrap().unwrap();
            let gd = g.data_vec().unwrap();
            assert_eq!(gd.len(), 4, "grad should be full-shape");
            (rank, gd)
        });
        for (rank, gd) in results {
            if rank == 0 {
                assert!((gd[0] - 1.0).abs() < 1e-6, "rank 0 [0]: {}", gd[0]);
                assert!((gd[1] - 2.0).abs() < 1e-6, "rank 0 [1]: {}", gd[1]);
                assert_eq!(gd[2], 0.0, "rank 0 [2] should be zero");
                assert_eq!(gd[3], 0.0, "rank 0 [3] should be zero");
            } else {
                assert_eq!(gd[0], 0.0, "rank 1 [0] should be zero");
                assert_eq!(gd[1], 0.0, "rank 1 [1] should be zero");
                assert!((gd[2] - 3.0).abs() < 1e-6, "rank 1 [2]: {}", gd[2]);
                assert!((gd[3] - 4.0).abs() < 1e-6, "rank 1 [3]: {}", gd[3]);
            }
        }
    }

    #[test]
    fn test_fsdp_shard_grad_op_broadcast_updated_params() {
        let gathered = spawn_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let rank = backend.rank();
            let mut fsdp = FSDP::new_with_strategy(
                model_of(&[1.0, 2.0, 3.0, 4.0]),
                backend,
                ShardingStrategy::ShardGradOp,
            )
            .unwrap();
            let mut local = fsdp.module().weight.tensor().data_vec().unwrap();
            // Each rank only updates the half it owns.
            let (owned, bump) = if rank == 0 {
                (0..2, 10.0)
            } else {
                (2..4, 20.0)
            };
            for value in &mut local[owned] {
                *value += bump;
            }
            let new_param = Tensor::from_storage(TensorStorage::cpu(local), vec![4], true).unwrap();
            *fsdp.module.parameters_mut()[0] = Parameter::new(new_param);
            fsdp.broadcast_updated_params().unwrap();
            fsdp.module().weight.tensor().data_vec().unwrap()
        });
        for data in gathered {
            assert_eq!(data, &[11.0, 12.0, 23.0, 24.0]);
        }
    }

    #[test]
    fn test_fsdp_no_shard_is_ddp_equivalent() {
        let grads = spawn_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let mut fsdp = FSDP::new_with_strategy(
                model_of(&[1.0, 2.0, 3.0, 4.0]),
                backend,
                ShardingStrategy::NoShard,
            )
            .unwrap();
            assert_eq!(fsdp.strategy(), ShardingStrategy::NoShard);
            assert_eq!(fsdp.module().weight.tensor().numel(), 4);
            let input = scalar_input(1.0);
            let _output = fsdp.forward(&input).unwrap();
            let grad = grad_of(vec![1.0f32, 2.0, 3.0, 4.0]);
            fsdp.full_params[0].set_grad(Some(grad)).unwrap();
            fsdp.sync_gradients().unwrap();
            let w = fsdp.module().weight.tensor();
            assert_eq!(w.numel(), 4, "NoShard keeps params full");
            w.grad().unwrap().unwrap().data_vec().unwrap()
        });
        for gd in grads {
            assert_eq!(gd.len(), 4);
            for (i, expected) in [1.0f32, 2.0, 3.0, 4.0].iter().enumerate() {
                assert!(
                    (gd[i] - expected).abs() < 1e-6,
                    "NoShard allreduce: got {} at {}, expected {}",
                    gd[i],
                    i,
                    expected
                );
            }
        }
    }

    #[test]
    fn test_fsdp_prefetched_forward_matches_sync_forward() {
        let outputs = spawn_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let mut fsdp = FSDP::new(model_of(&[1.0, 2.0, 3.0, 4.0]), backend).unwrap();
            assert!(!fsdp.has_pending_prefetch());
            fsdp.prefetch_forward_params().unwrap();
            assert!(fsdp.has_pending_prefetch());
            // Unrelated work between starting the prefetch and consuming it.
            let _scratch: f32 = (0..100).map(|i| i as f32).sum();
            let input = scalar_input(2.0);
            let output = fsdp.forward(&input).unwrap();
            assert!(!fsdp.has_pending_prefetch());
            output.data_vec().unwrap()[0]
        });
        for v in outputs {
            assert!((v - 20.0).abs() < 1e-6, "expected 20.0, got {v}");
        }
    }

    #[test]
    fn test_fsdp_forward_without_prefetch_still_works() {
        let outputs = spawn_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let mut fsdp = FSDP::new(model_of(&[1.0, 2.0, 3.0, 4.0]), backend).unwrap();
            assert!(!fsdp.has_pending_prefetch());
            let input = scalar_input(2.0);
            let output = fsdp.forward(&input).unwrap();
            output.data_vec().unwrap()[0]
        });
        for v in outputs {
            assert!((v - 20.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_fsdp_prefetch_rejects_double_call() {
        let mut fsdp = FSDP::new(model_of(&[1.0, 2.0, 3.0, 4.0]), solo_backend()).unwrap();
        fsdp.prefetch_forward_params().unwrap();
        let r = fsdp.prefetch_forward_params();
        assert!(r.is_err());
        let err = format!("{}", r.unwrap_err());
        assert!(err.contains("called twice"), "err = {err}");
        // The first prefetch is still pending and must be consumable.
        let input = scalar_input(1.0);
        let _ = fsdp.forward(&input).unwrap();
    }

    #[test]
    fn test_fsdp_prefetch_rejects_non_fullshard() {
        let mut fsdp = FSDP::new_with_strategy(
            model_of(&[1.0, 2.0, 3.0, 4.0]),
            solo_backend(),
            ShardingStrategy::ShardGradOp,
        )
        .unwrap();
        let r = fsdp.prefetch_forward_params();
        assert!(r.is_err());
        let err = format!("{}", r.unwrap_err());
        assert!(err.contains("FullShard"), "err = {err}");
    }

    #[test]
    fn test_fsdp_no_shard_broadcast_is_noop() {
        let mut fsdp = FSDP::new_with_strategy(
            model_of(&[1.0, 2.0, 3.0, 4.0]),
            solo_backend(),
            ShardingStrategy::NoShard,
        )
        .unwrap();
        fsdp.broadcast_updated_params().unwrap();
        assert_eq!(
            fsdp.module().weight.tensor().data_vec().unwrap(),
            &[1.0, 2.0, 3.0, 4.0]
        );
    }

    #[test]
    fn test_fsdp_sync_gradients_multi_rank() {
        let results = spawn_ranks(SimulatedBackend::create_group(2).unwrap(), |backend| {
            let rank = backend.rank();
            let mut fsdp = FSDP::new(model_of(&[1.0, 2.0, 3.0, 4.0]), backend).unwrap();
            let input = scalar_input(1.0);
            let _output = fsdp.forward(&input).unwrap();
            let grad = grad_of(vec![1.0f32, 2.0, 3.0, 4.0]);
            fsdp.full_params[0].set_grad(Some(grad)).unwrap();
            fsdp.sync_gradients().unwrap();
            let shard_grad = fsdp.module().weight.tensor().grad().unwrap().unwrap();
            let data = shard_grad.data_vec().unwrap();
            (rank, data)
        });
        for (rank, data) in results {
            if rank == 0 {
                assert_eq!(data.len(), 2);
                assert!(
                    (data[0] - 1.0).abs() < 1e-6,
                    "rank 0: expected 1.0, got {}",
                    data[0]
                );
                assert!(
                    (data[1] - 2.0).abs() < 1e-6,
                    "rank 0: expected 2.0, got {}",
                    data[1]
                );
            } else {
                assert_eq!(data.len(), 2);
                assert!(
                    (data[0] - 3.0).abs() < 1e-6,
                    "rank 1: expected 3.0, got {}",
                    data[0]
                );
                assert!(
                    (data[1] - 4.0).abs() < 1e-6,
                    "rank 1: expected 4.0, got {}",
                    data[1]
                );
            }
        }
    }

    #[test]
    fn test_fsdp_hybrid_shard_rejects_uneven_world_size() {
        let group = SimulatedBackend::create_group(4).unwrap();
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
        let b = arcs[0].clone();
        let result = FSDP::new_with_strategy(
            model_of(&[1.0f32; 8]),
            b,
            ShardingStrategy::HybridShard { intra_node_size: 3 },
        );
        assert!(result.is_err(), "expected uneven intra_node_size to fail");
    }

    #[test]
    fn test_fsdp_hybrid_shard_intra_node_sharding() {
        let results = spawn_ranks(SimulatedBackend::create_group(4).unwrap(), |backend| {
            let rank = backend.rank();
            let fsdp = FSDP::new_with_strategy(
                model_of(&[10.0, 20.0, 30.0, 40.0]),
                backend,
                ShardingStrategy::HybridShard { intra_node_size: 2 },
            )
            .unwrap();
            let shard = fsdp.module().weight.tensor().data_vec().unwrap();
            (rank, shard)
        });
        for (rank, shard) in results {
            assert_eq!(shard.len(), 2, "each rank gets 4/2 = 2 elements");
            let expected: &[f32] = if rank % 2 == 0 {
                &[10.0, 20.0]
            } else {
                &[30.0, 40.0]
            };
            assert_eq!(shard, expected, "rank {} shard mismatch", rank);
        }
    }

    #[test]
    fn test_fsdp_hybrid_shard_sync_gradients() {
        let results = spawn_ranks(SimulatedBackend::create_group(4).unwrap(), |backend| {
            let rank = backend.rank();
            let mut fsdp = FSDP::new_with_strategy(
                model_of(&[10.0, 20.0, 30.0, 40.0]),
                backend,
                ShardingStrategy::HybridShard { intra_node_size: 2 },
            )
            .unwrap();
            let input = scalar_input(1.0);
            let _output = fsdp.forward(&input).unwrap();
            let grad = grad_of(vec![1.0f32, 2.0, 3.0, 4.0]);
            fsdp.full_params[0].set_grad(Some(grad)).unwrap();
            fsdp.sync_gradients().unwrap();
            let w = fsdp.module().weight.tensor();
            assert_eq!(w.numel(), 2, "HybridShard keeps params as intra-shards");
            let g = w.grad().unwrap().unwrap();
            let gd = g.data_vec().unwrap();
            (rank, gd)
        });
        for (rank, gd) in results {
            assert_eq!(gd.len(), 2, "shard grad should have 2 elements");
            let expected: &[f32] = if rank % 2 == 0 {
                &[1.0, 2.0]
            } else {
                &[3.0, 4.0]
            };
            for (i, e) in expected.iter().enumerate() {
                assert!(
                    (gd[i] - e).abs() < 1e-6,
                    "rank {} [{}]: expected {}, got {}",
                    rank,
                    i,
                    e,
                    gd[i]
                );
            }
        }
    }

    #[test]
    fn test_fsdp_hybrid_shard_inter_node_averaging() {
        let results = spawn_ranks(SimulatedBackend::create_group(4).unwrap(), |backend| {
            let rank = backend.rank();
            let mut fsdp = FSDP::new_with_strategy(
                model_of(&[0.0f32; 4]),
                backend,
                ShardingStrategy::HybridShard { intra_node_size: 2 },
            )
            .unwrap();
            let input = scalar_input(1.0);
            let _ = fsdp.forward(&input).unwrap();
            // Ranks 0-1 form one node, ranks 2-3 the other; each node sees a
            // different gradient so the inter-node mean is observable.
            let grad_vec: Vec<f32> = if rank < 2 {
                vec![2.0, 4.0, 6.0, 8.0]
            } else {
                vec![10.0, 20.0, 30.0, 40.0]
            };
            let grad = grad_of(grad_vec);
            fsdp.full_params[0].set_grad(Some(grad)).unwrap();
            fsdp.sync_gradients().unwrap();
            let gd = fsdp
                .module()
                .weight
                .tensor()
                .grad()
                .unwrap()
                .unwrap()
                .data_vec()
                .unwrap();
            (rank, gd)
        });
        for (rank, gd) in results {
            assert_eq!(gd.len(), 2);
            let expected: &[f32] = if rank % 2 == 0 {
                &[6.0, 12.0]
            } else {
                &[18.0, 24.0]
            };
            for (i, e) in expected.iter().enumerate() {
                assert!(
                    (gd[i] - e).abs() < 1e-4,
                    "hybrid inter-node mean rank {} [{}]: expected {}, got {}",
                    rank,
                    i,
                    e,
                    gd[i]
                );
            }
        }
    }
}