use crate::backend::ReduceOp;
use crate::collectives::{all_gather, all_reduce};
use crate::{ProcessGroup, Rank, TorshDistributedError, TorshResult};
use dashmap::DashMap;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use torsh_core::{device::DeviceType, error::Result, DType, Shape};
use torsh_nn::{Module, Parameter};
use torsh_tensor::Tensor;
use tracing::{debug, info};
/// Configuration for [`FullyShardedDataParallel`].
#[derive(Debug, Clone)]
pub struct FsdpConfig {
    /// Parameters with fewer elements than this are kept whole instead of
    /// being sharded (see `shard_parameters`).
    pub min_num_params: usize,
    /// Policy for automatically wrapping sub-modules.
    /// NOTE(review): stored but not evaluated anywhere in this file — confirm
    /// it is consumed elsewhere.
    pub auto_wrap_policy: AutoWrapPolicy,
    /// Requested sharding strategy. Currently only logged at init time; the
    /// sharding code in this file always behaves like full sharding.
    pub sharding_strategy: ShardingStrategy,
    /// Optional mixed-precision settings; `None` keeps original dtypes.
    /// Not consulted by the logic in this file.
    pub mixed_precision: Option<MixedPrecisionConfig>,
    /// CPU-offload flag (not consulted by the logic in this file).
    pub cpu_offload: bool,
    /// Memory-management knobs (not consulted by the logic in this file).
    pub memory_config: MemoryConfig,
    /// When, relative to backward, prefetching should happen
    /// (not consulted by the logic in this file).
    pub backward_prefetch: BackwardPrefetch,
}
impl Default for FsdpConfig {
fn default() -> Self {
Self {
min_num_params: 1000,
auto_wrap_policy: AutoWrapPolicy::SizeBasedAutoWrap {
min_num_params: 1000,
},
sharding_strategy: ShardingStrategy::FullShard,
mixed_precision: None,
cpu_offload: false,
memory_config: MemoryConfig::default(),
backward_prefetch: BackwardPrefetch::BackwardPre,
}
}
}
/// Strategy for deciding which sub-modules get wrapped in their own FSDP unit.
///
/// NOTE(review): none of these policies are evaluated in this file; the value
/// is only carried in [`FsdpConfig`].
#[derive(Debug, Clone)]
pub enum AutoWrapPolicy {
    /// Wrap modules whose parameter count reaches `min_num_params`.
    SizeBasedAutoWrap { min_num_params: usize },
    /// Wrap modules whose type name matches one of `module_types`.
    ModuleTypeBasedAutoWrap { module_types: Vec<String> },
    /// Caller-supplied wrapping logic.
    CustomAutoWrap,
    /// Disable automatic wrapping entirely.
    NoAutoWrap,
}
/// Degree of sharding applied to model state (naming follows the usual FSDP
/// conventions). In this file the value is only logged at construction time.
#[derive(Debug, Clone, PartialEq)]
pub enum ShardingStrategy {
    /// Intended: shard parameters, gradients, and optimizer state.
    FullShard,
    /// Intended: shard gradients and optimizer state only.
    ShardGradOp,
    /// Intended: no sharding (plain data parallelism).
    NoShard,
    /// Intended: shard within a node, replicate across nodes.
    HybridShard,
}
/// Mixed-precision settings for FSDP.
///
/// NOTE(review): carried in [`FsdpConfig`] but not consulted by the logic in
/// this file.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Dtype used for parameters during compute.
    pub param_dtype: DType,
    /// Dtype used for gradient reduction.
    pub reduce_dtype: DType,
    /// Dtype used for module buffers.
    pub buffer_dtype: DType,
    /// Keep gradients in the low-precision dtype after reduction.
    pub keep_low_precision_grads: bool,
}
/// Memory-management knobs for FSDP.
///
/// NOTE(review): carried in [`FsdpConfig`] but not consulted by the logic in
/// this file.
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// Rate-limit concurrent all-gathers to bound peak memory.
    pub limit_all_gathers: bool,
    /// Expose original (unflattened) parameters to the optimizer.
    pub use_orig_params: bool,
    /// Offload sharded state to host memory.
    pub offload_to_cpu: bool,
}
impl Default for MemoryConfig {
fn default() -> Self {
Self {
limit_all_gathers: true,
use_orig_params: false,
offload_to_cpu: false,
}
}
}
/// When the next set of parameters is prefetched relative to the backward
/// pass. Not consulted by the logic in this file.
#[derive(Debug, Clone, PartialEq)]
pub enum BackwardPrefetch {
    /// Prefetch before the current gradient computation.
    BackwardPre,
    /// Prefetch after the current gradient computation.
    BackwardPost,
    /// No prefetching.
    None,
}
/// Describes one rank's 1-D slice of a flattened parameter.
#[derive(Debug, Clone)]
pub struct ShardInfo {
    /// Rank that owns this shard.
    pub rank: Rank,
    /// Start offset (in elements) into the flattened parameter.
    pub start_idx: usize,
    /// Number of elements in this shard.
    pub shard_size: usize,
    /// Shape of the full (unflattened) parameter, used to reshape after
    /// all-gather.
    pub original_shape: Shape,
    /// True when the shard lives on the local rank.
    pub is_local: bool,
}
/// Lifecycle state of a wrapped parameter.
///
/// Small parameters (below `min_num_params`) stay permanently in `Gathered`;
/// large ones cycle Sharded -> Gathering -> Gathered -> Sharded.
#[derive(Debug)]
enum ParameterState {
    /// Only the local shard is materialized; `shard_info` records where it
    /// sits inside the full flattened tensor.
    Sharded {
        #[allow(dead_code)]
        shard_info: ShardInfo,
    },
    /// The full tensor is materialized on this rank.
    Gathered {
        #[allow(dead_code)]
        full_tensor: Tensor,
    },
    /// Transitional: an all-gather is in flight.
    #[allow(dead_code)]
    Gathering,
    /// Transitional: re-sharding is in progress.
    #[allow(dead_code)]
    Sharding,
}
/// Wraps a [`Module`] so its parameters are sharded across the ranks of a
/// [`ProcessGroup`] (FSDP-style).
pub struct FullyShardedDataParallel {
    /// The wrapped model. `RwLock` allows concurrent forward passes (reads)
    /// while `train`/`eval` take a write lock.
    module: Arc<RwLock<dyn Module>>,
    /// Collective-communication group used for all-gather / all-reduce.
    process_group: Arc<ProcessGroup>,
    config: FsdpConfig,
    /// Per-parameter lifecycle state, keyed by parameter name.
    param_states: Arc<DashMap<String, ParameterState>>,
    /// Local 1-D shard of each sharded parameter.
    sharded_params: Arc<DashMap<String, Tensor>>,
    /// Fully materialized parameters after an all-gather.
    #[allow(dead_code)]
    gathered_params: Arc<DashMap<String, Tensor>>,
    /// Gradient accumulation buffers (filled elsewhere; consumed by
    /// `reduce_scatter_gradients`).
    #[allow(dead_code)]
    grad_buffers: Arc<DashMap<String, Tensor>>,
    /// Training-mode flag mirrored from the wrapped module.
    training: Arc<Mutex<bool>>,
    /// Placeholder for a compute-stream handle; unused so far.
    #[allow(dead_code)]
    compute_stream: Arc<Mutex<Option<String>>>,
    /// Counters reported by `memory_stats()`.
    memory_stats: Arc<Mutex<MemoryStats>>,
}
/// Running counters describing FSDP memory/collective activity.
///
/// Snapshots are returned by `FullyShardedDataParallel::memory_stats`.
/// `Clone` is derived so callers (and `memory_stats()`) can take snapshots
/// without copying field-by-field.
#[derive(Debug, Default, Clone)]
pub struct MemoryStats {
    /// High-water mark of memory use, in megabytes.
    pub peak_memory_mb: f64,
    /// Currently allocated memory, in megabytes.
    pub current_memory_mb: f64,
    /// Estimated memory saved by sharding, in megabytes.
    pub memory_saved_mb: f64,
    /// Number of all-gather collectives issued so far.
    pub num_all_gathers: u64,
    /// Number of gradient-reduction collectives issued so far.
    pub num_reduce_scatters: u64,
}
impl FullyShardedDataParallel {
    /// Wraps `module` for fully-sharded data-parallel execution.
    ///
    /// Shards every eligible parameter across the ranks of `process_group`
    /// before returning.
    ///
    /// # Errors
    /// Propagates any error from the initial sharding pass.
    pub fn new(
        module: Arc<RwLock<dyn Module>>,
        process_group: Arc<ProcessGroup>,
        config: FsdpConfig,
    ) -> TorshResult<Self> {
        let fsdp = Self {
            module,
            process_group,
            config,
            param_states: Arc::new(DashMap::new()),
            sharded_params: Arc::new(DashMap::new()),
            gathered_params: Arc::new(DashMap::new()),
            grad_buffers: Arc::new(DashMap::new()),
            training: Arc::new(Mutex::new(true)),
            compute_stream: Arc::new(Mutex::new(None)),
            memory_stats: Arc::new(Mutex::new(MemoryStats::default())),
        };
        fsdp.shard_parameters()?;
        info!(
            "FSDP initialized with strategy {:?} for {} workers",
            fsdp.config.sharding_strategy,
            fsdp.process_group.world_size()
        );
        Ok(fsdp)
    }

    /// Flattens each parameter and stores only this rank's contiguous slice.
    ///
    /// Parameters smaller than `config.min_num_params` are kept whole and
    /// marked `Gathered`. The remainder elements (when the count doesn't
    /// divide evenly) are distributed one-per-rank to the lowest ranks.
    fn shard_parameters(&self) -> TorshResult<()> {
        let module_guard = self.module.read();
        let parameters = module_guard.parameters();
        drop(module_guard);
        let world_size = self.process_group.world_size() as usize;
        let rank = self.process_group.rank() as usize;
        for (name, param) in parameters {
            let tensor_arc = param.tensor();
            let tensor_guard = tensor_arc.read();
            // Small parameters are not worth sharding: keep the full tensor.
            if tensor_guard.numel() < self.config.min_num_params {
                self.param_states.insert(
                    name.clone(),
                    ParameterState::Gathered {
                        full_tensor: tensor_guard.clone(),
                    },
                );
                continue;
            }
            let flat_param = tensor_guard.flatten()?;
            let total_elements = flat_param.numel();
            let base_shard_size = total_elements / world_size;
            let remainder = total_elements % world_size;
            let mut start_idx = 0;
            // Walk every rank's slice so `start_idx` stays correct even for
            // ranks before ours; only the local slice is materialized.
            for worker_rank in 0..world_size {
                let shard_size = base_shard_size + if worker_rank < remainder { 1 } else { 0 };
                if worker_rank == rank {
                    let shard = flat_param
                        .slice(0, start_idx, start_idx + shard_size)?
                        .to_tensor()?;
                    self.sharded_params.insert(name.clone(), shard);
                    let shard_info = ShardInfo {
                        rank: worker_rank as Rank,
                        start_idx,
                        shard_size,
                        original_shape: tensor_guard.shape().clone(),
                        is_local: true,
                    };
                    self.param_states
                        .insert(name.clone(), ParameterState::Sharded { shard_info });
                }
                start_idx += shard_size;
            }
            debug!(
                "Sharded parameter '{}' with {} elements across {} workers",
                name, total_elements, world_size
            );
            drop(tensor_guard);
        }
        Ok(())
    }

    /// Reassembles the full tensor for each named (sharded) parameter via
    /// all-gather, storing the result in `gathered_params` and flipping the
    /// state to `Gathered`. Parameters not in `Sharded` state are skipped.
    #[allow(dead_code)]
    async fn gather_parameters(&self, param_names: &[String]) -> TorshResult<()> {
        for param_name in param_names {
            if let Some(mut state_ref) = self.param_states.get_mut(param_name) {
                if let ParameterState::Sharded { shard_info } = &*state_ref {
                    let original_shape = shard_info.original_shape.clone();
                    *state_ref = ParameterState::Gathering;
                    drop(state_ref);
                    // Clone the local shard so no DashMap guard is held across
                    // the `.await` below (a held guard can deadlock against
                    // concurrent map access).
                    let shard = self
                        .sharded_params
                        .get(param_name)
                        .map(|entry| entry.value().clone())
                        .ok_or_else(|| {
                            TorshDistributedError::backend_error(
                                "fsdp",
                                format!("Shard not found for parameter '{}'", param_name),
                            )
                        })?;
                    let mut gathered_tensors = Vec::new();
                    all_gather(&mut gathered_tensors, &shard, &self.process_group).await?;
                    // FIXME(review): only the first gathered shard is used.
                    // For world_size > 1 the per-rank shards should be
                    // concatenated before reshaping — the original code had
                    // two identical branches both taking the first tensor.
                    let gathered_tensor = gathered_tensors.into_iter().next().ok_or_else(|| {
                        TorshDistributedError::backend_error(
                            "fsdp",
                            format!("all_gather returned no tensors for '{}'", param_name),
                        )
                    })?;
                    let shape_dims: Vec<i32> =
                        original_shape.dims().iter().map(|&x| x as i32).collect();
                    let reshaped = gathered_tensor.reshape(&shape_dims)?;
                    self.gathered_params
                        .insert(param_name.clone(), reshaped.clone());
                    self.param_states.insert(
                        param_name.clone(),
                        ParameterState::Gathered {
                            full_tensor: reshaped,
                        },
                    );
                    let mut stats = self
                        .memory_stats
                        .lock()
                        .expect("lock should not be poisoned");
                    stats.num_all_gathers += 1;
                }
            }
        }
        Ok(())
    }

    /// Reduces gradients across ranks and applies each rank's slice to its
    /// local parameter shard, then drops gathered tensors and re-marks the
    /// parameters as sharded.
    ///
    /// NOTE(review): despite the name this performs an `all_reduce` (every
    /// rank receives the full reduced gradient) and then slices out the local
    /// portion, rather than a true reduce-scatter collective.
    #[allow(dead_code)]
    async fn reduce_scatter_gradients(&self, param_names: &[String]) -> TorshResult<()> {
        for param_name in param_names {
            // Clone the buffer so no DashMap guard is held across `.await`.
            let grad = self
                .grad_buffers
                .get(param_name)
                .map(|entry| entry.value().clone());
            if let Some(mut reduced_grad) = grad {
                all_reduce(&mut reduced_grad, ReduceOp::Sum, &self.process_group).await?;
                if let Some(state_ref) = self.param_states.get(param_name) {
                    if let ParameterState::Sharded { shard_info } = &*state_ref {
                        let grad_shard = reduced_grad.slice(
                            0,
                            shard_info.start_idx,
                            shard_info.start_idx + shard_info.shard_size,
                        )?;
                        if let Some(mut param_shard) = self.sharded_params.get_mut(param_name) {
                            // NOTE(review): applies the raw summed gradient
                            // with no learning rate — presumably the optimizer
                            // step lives elsewhere; confirm.
                            let grad_tensor = grad_shard.to_tensor()?;
                            *param_shard = param_shard.sub(&grad_tensor)?;
                        }
                    }
                }
                let mut stats = self
                    .memory_stats
                    .lock()
                    .expect("lock should not be poisoned");
                stats.num_reduce_scatters += 1;
            }
            // Re-mark as sharded only when shard metadata exists. Small
            // parameters are stored whole (`Gathered`) and previously made
            // this call fail with "not in sharded state" via `?`.
            if let Ok(shard_info) = self.get_shard_info(param_name) {
                self.param_states
                    .insert(param_name.clone(), ParameterState::Sharded { shard_info });
            }
            self.gathered_params.remove(param_name);
        }
        Ok(())
    }

    /// Returns the shard metadata for `param_name`.
    ///
    /// # Errors
    /// Fails when the parameter is unknown or not currently in the
    /// `Sharded` state.
    #[allow(dead_code)]
    fn get_shard_info(&self, param_name: &str) -> TorshResult<ShardInfo> {
        if let Some(state_ref) = self.param_states.get(param_name) {
            match &*state_ref {
                ParameterState::Sharded { shard_info } => Ok(shard_info.clone()),
                _ => Err(TorshDistributedError::backend_error(
                    "fsdp",
                    format!("Parameter '{}' is not in sharded state", param_name),
                )),
            }
        } else {
            Err(TorshDistributedError::backend_error(
                "fsdp",
                format!("Parameter '{}' not found", param_name),
            ))
        }
    }

    /// Sets training/eval mode on both the wrapper flag and the wrapped
    /// module (`true` = train, `false` = eval).
    pub fn train(&self, mode: bool) {
        *self.training.lock().expect("lock should not be poisoned") = mode;
        let mut module_guard = self.module.write();
        if mode {
            module_guard.train();
        } else {
            module_guard.eval();
        }
    }

    /// Returns true while the wrapper is in training mode.
    pub fn is_training(&self) -> bool {
        *self.training.lock().expect("lock should not be poisoned")
    }

    /// Returns a snapshot of the collective/memory counters.
    pub fn memory_stats(&self) -> MemoryStats {
        let stats = self
            .memory_stats
            .lock()
            .expect("lock should not be poisoned");
        // Field-by-field snapshot taken while holding the lock.
        MemoryStats {
            peak_memory_mb: stats.peak_memory_mb,
            current_memory_mb: stats.current_memory_mb,
            memory_saved_mb: stats.memory_saved_mb,
            num_all_gathers: stats.num_all_gathers,
            num_reduce_scatters: stats.num_reduce_scatters,
        }
    }

    /// Total element count of all parameters in the wrapped module
    /// (full, unsharded sizes).
    pub fn num_parameters(&self) -> usize {
        let module_guard = self.module.read();
        let parameters = module_guard.parameters();
        parameters.values().map(|p| p.tensor().read().numel()).sum()
    }

    /// Fraction of total parameter elements held locally as shards
    /// (0.0 when the module has no parameters).
    pub fn local_sharding_ratio(&self) -> f64 {
        let total_params = self.num_parameters();
        let local_params: usize = self
            .sharded_params
            .iter()
            .map(|entry| entry.value().numel())
            .sum();
        if total_params > 0 {
            local_params as f64 / total_params as f64
        } else {
            0.0
        }
    }
}
impl Module for FullyShardedDataParallel {
    /// Forward pass through the wrapped module.
    ///
    /// NOTE(review): no parameter gathering happens here yet — the wrapped
    /// module runs as-is. (The previous version also built an unused list of
    /// sharded parameter names on every call; that dead work was removed.)
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let output = {
            let module_guard = self.module.read();
            module_guard.forward(input)?
        };
        if self.is_training() {
            debug!("Forward pass completed, gradients will be reduce-scattered in backward");
        }
        Ok(output)
    }

    /// Returns only the *local shards*, not the full parameters.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.sharded_params
            .iter()
            .map(|entry| (entry.key().clone(), Parameter::new(entry.value().clone())))
            .collect()
    }

    /// Delegates to [`Self::parameters`].
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.parameters()
    }

    fn training(&self) -> bool {
        *self.training.lock().expect("lock should not be poisoned")
    }

    /// Switches to training mode and propagates it to the wrapped module
    /// (previously only the wrapper flag was set, desyncing this trait impl
    /// from the inherent `train(true)` helper).
    fn train(&mut self) {
        *self.training.lock().expect("lock should not be poisoned") = true;
        self.module.write().train();
    }

    /// Switches to eval mode; see `train` for the propagation rationale.
    fn eval(&mut self) {
        *self.training.lock().expect("lock should not be poisoned") = false;
        self.module.write().eval();
    }

    /// Device movement for sharded parameters is not implemented yet.
    fn to_device(&mut self, _device: DeviceType) -> torsh_core::Result<()> {
        // TODO: move local shards and the wrapped module to `_device`.
        Ok(())
    }
}
/// Convenience wrapper: moves `module` behind an `Arc<RwLock<_>>` and builds
/// a [`FullyShardedDataParallel`] around it. `None` config means defaults.
pub fn fsdp_wrap<M: Module + 'static>(
    module: M,
    process_group: Arc<ProcessGroup>,
    config: Option<FsdpConfig>,
) -> TorshResult<FullyShardedDataParallel> {
    FullyShardedDataParallel::new(
        Arc::new(RwLock::new(module)),
        process_group,
        config.unwrap_or_default(),
    )
}
pub fn auto_wrap_modules<M: Module + 'static>(
module: M,
process_group: Arc<ProcessGroup>,
auto_wrap_policy: AutoWrapPolicy,
) -> TorshResult<FullyShardedDataParallel> {
let config = FsdpConfig {
auto_wrap_policy,
..Default::default()
};
fsdp_wrap(module, process_group, Some(config))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{init_process_group, BackendType};
    use torsh_nn::{prelude::Linear, Module};
    // Each async test uses a distinct port (12345-12347) so process groups
    // don't collide when tests run in parallel.
    #[tokio::test]
    async fn test_fsdp_initialization() -> TorshResult<()> {
        let process_group =
            Arc::new(init_process_group(BackendType::Gloo, 0, 2, "127.0.0.1", 12345).await?);
        let linear = Linear::new(128, 64, true);
        let config = FsdpConfig::default();
        let fsdp =
            FullyShardedDataParallel::new(Arc::new(RwLock::new(linear)), process_group, config)?;
        // Some (but at most all) parameter elements should live locally.
        assert!(fsdp.local_sharding_ratio() > 0.0);
        assert!(fsdp.local_sharding_ratio() <= 1.0);
        Ok(())
    }
    #[tokio::test]
    async fn test_fsdp_forward_pass() -> TorshResult<()> {
        let process_group =
            Arc::new(init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 12346).await?);
        let linear = Linear::new(64, 32, true);
        let fsdp = fsdp_wrap(linear, process_group, None)?;
        let input = torsh_tensor::creation::randn(&[8, 64])?;
        let output = fsdp.forward(&input)?;
        // Linear(64 -> 32) should map [8, 64] to [8, 32].
        assert_eq!(output.shape().dims(), &[8, 32]);
        Ok(())
    }
    // Pure config checks — no process group needed.
    #[test]
    fn test_fsdp_config() {
        let config = FsdpConfig::default();
        assert_eq!(config.min_num_params, 1000);
        assert_eq!(config.sharding_strategy, ShardingStrategy::FullShard);
        assert_eq!(config.backward_prefetch, BackwardPrefetch::BackwardPre);
        let custom_config = FsdpConfig {
            min_num_params: 500,
            sharding_strategy: ShardingStrategy::ShardGradOp,
            cpu_offload: true,
            ..Default::default()
        };
        assert_eq!(custom_config.min_num_params, 500);
        assert_eq!(
            custom_config.sharding_strategy,
            ShardingStrategy::ShardGradOp
        );
        assert!(custom_config.cpu_offload);
    }
    #[test]
    fn test_shard_info() {
        let shard_info = ShardInfo {
            rank: 0,
            start_idx: 0,
            shard_size: 1000,
            original_shape: Shape::new(vec![10, 100]),
            is_local: true,
        };
        assert_eq!(shard_info.rank, 0);
        assert_eq!(shard_info.shard_size, 1000);
        assert!(shard_info.is_local);
    }
    #[test]
    fn test_memory_stats() {
        // Default counters all start at zero.
        let stats = MemoryStats::default();
        assert_eq!(stats.peak_memory_mb, 0.0);
        assert_eq!(stats.num_all_gathers, 0);
        assert_eq!(stats.num_reduce_scatters, 0);
    }
    #[tokio::test]
    async fn test_auto_wrap() -> TorshResult<()> {
        let process_group =
            Arc::new(init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 12347).await?);
        let linear = Linear::new(100, 50, true);
        let policy = AutoWrapPolicy::SizeBasedAutoWrap {
            min_num_params: 1000,
        };
        let fsdp = auto_wrap_modules(linear, process_group, policy)?;
        // With a single rank, nearly all sharded elements are local.
        assert!(fsdp.local_sharding_ratio() >= 0.9);
        Ok(())
    }
}