use std::collections::HashMap;
/// Rule for grouping model parameters into FSDP units.
///
/// NOTE(review): the policy is only stored on the config in this module;
/// nothing here reads it, so the per-variant notes below are inferred from
/// the names — confirm against the code that consumes it.
#[derive(Debug, Clone, PartialEq)]
pub enum WrappingPolicy {
    /// Presumably the whole model as a single unit.
    WholeProgramWrap,
    /// Presumably one unit per transformer layer; `min_params` is a
    /// parameter-count threshold (exact semantics unconfirmed).
    TransformerLayerWrap { min_params: usize },
    /// Presumably wraps a submodule once it reaches `min_num_params`
    /// parameters (exact semantics unconfirmed).
    SizeBasedWrap { min_num_params: usize },
}
/// Settings for one rank taking part in (simulated) FSDP.
///
/// NOTE(review): within this module only `world_size` and `local_rank`
/// influence behavior; the remaining fields are carried along unread.
#[derive(Debug, Clone)]
pub struct FsdpConfig {
    /// Number of ranks the parameters are sharded across.
    pub world_size: usize,
    /// This process's rank; selects which shard it stores.
    pub local_rank: usize,
    /// How parameters are grouped into units.
    pub wrapping_policy: WrappingPolicy,
    /// CPU-offload flag (off by default).
    pub cpu_offload: bool,
    /// Mixed-precision flag (off by default).
    pub mixed_precision: bool,
    /// Backward-prefetch flag (on by default).
    pub backward_prefetch: bool,
    /// Forward-prefetch flag (off by default).
    pub forward_prefetch: bool,
}

impl Default for FsdpConfig {
    /// Single-rank configuration (world size 1, rank 0) using
    /// `WrappingPolicy::WholeProgramWrap`, with only backward prefetch enabled.
    fn default() -> Self {
        FsdpConfig {
            world_size: 1,
            local_rank: 0,
            wrapping_policy: WrappingPolicy::WholeProgramWrap,
            cpu_offload: false,
            mixed_precision: false,
            backward_prefetch: true,
            forward_prefetch: false,
        }
    }
}
/// What gets sharded across ranks.
///
/// NOTE(review): the variant names appear to mirror PyTorch FSDP's
/// `ShardingStrategy`; nothing in this module consumes the value, so the
/// notes below are inferred from the names — TODO confirm.
#[derive(Debug, Clone, PartialEq)]
pub enum ShardingStrategy {
    /// Presumably shards parameters, gradients and optimizer state.
    FullShard,
    /// Presumably shards gradients and optimizer state only.
    ShardGradOp,
    /// Presumably no sharding at all.
    NoShard,
    /// Presumably shards within groups and replicates across groups.
    HybridShard {
        /// Number of model replicas.
        num_model_replicas: usize,
    },
}
/// Back-of-the-envelope per-rank memory estimates for FSDP training.
pub struct FsdpMemoryAnalyzer;

impl FsdpMemoryAnalyzer {
    /// Bytes per parameter for the weights.
    const PARAM_BYTES: usize = 4;
    /// Bytes per parameter for the gradients.
    const GRAD_BYTES: usize = 4;
    /// Bytes per parameter of optimizer state (presumably two 4-byte
    /// moments, Adam-style — TODO confirm).
    const OPT_BYTES: usize = 8;

    /// Estimated bytes on the most-loaded rank for a model with
    /// `total_params` parameters.
    ///
    /// The estimate counts weights, gradients and optimizer state, all sharded
    /// across `config.world_size` ranks. A `world_size` of 0 is treated as 1.
    ///
    /// Each rank is assumed to hold `ceil(total_params / world_size)` elements.
    /// This is the same shard size [`FsdpUnit::new`] uses. Flooring each
    /// component instead would report the average rank rather than the peak.
    /// The result saturates at `usize::MAX` instead of overflowing.
    ///
    /// NOTE(review): `mixed_precision`, `cpu_offload` and transient all-gather
    /// buffers are not reflected in this estimate.
    pub fn peak_memory(config: &FsdpConfig, total_params: usize) -> usize {
        let ranks = config.world_size.max(1);
        let per_rank_params = total_params.div_ceil(ranks);
        per_rank_params.saturating_mul(Self::PARAM_BYTES + Self::GRAD_BYTES + Self::OPT_BYTES)
    }

    /// Per-rank memory relative to fully replicated data parallelism.
    ///
    /// Returns `1 / world_size`, or `1.0` when `world_size` is 0 or 1.
    pub fn memory_vs_ddp_ratio(world_size: usize) -> f32 {
        if world_size <= 1 {
            1.0
        } else {
            1.0 / world_size as f32
        }
    }
}
/// One FSDP-wrapped group of parameters and its per-rank shard geometry.
#[derive(Debug, Clone)]
pub struct FsdpUnit {
    /// Id assigned when the unit is created.
    pub unit_id: usize,
    /// Names of the parameters flattened into this unit, in order.
    pub param_names: Vec<String>,
    /// Total number of scalar elements across all of the unit's parameters.
    pub total_params: usize,
    /// Elements held by each rank.
    ///
    /// This is `ceil(total_params / world_size)`, or `total_params` when
    /// `world_size` is 0.
    pub shard_size: usize,
    /// Offload flag; [`FsdpUnit::new`] always initializes it to `false`.
    pub is_offloaded: bool,
}

impl FsdpUnit {
    /// Builds a unit whose shard size comes from ceiling division.
    ///
    /// Rounding up means `world_size` shards together cover every element.
    /// A `world_size` of 0 leaves the unit unsharded: the shard spans all of
    /// `total_params`.
    pub fn new(
        unit_id: usize,
        param_names: Vec<String>,
        total_params: usize,
        world_size: usize,
    ) -> Self {
        let shard_size = match world_size {
            0 => total_params,
            ranks => total_params.div_ceil(ranks),
        };
        FsdpUnit {
            unit_id,
            param_names,
            total_params,
            shard_size,
            is_offloaded: false,
        }
    }

    /// Bytes of shard storage on one rank, with elements stored as `f64`.
    pub fn memory_bytes_per_rank(&self) -> usize {
        const ELEM_BYTES: usize = std::mem::size_of::<f64>();
        self.local_params() * ELEM_BYTES
    }

    /// Number of elements this rank holds (its shard size).
    pub fn local_params(&self) -> usize {
        self.shard_size
    }

    /// Elements needed to hold the fully gathered unit (`total_params`).
    pub fn all_gather_buffer_size(&self) -> usize {
        self.total_params
    }

    /// Elements in the reduce-scatter buffer: the whole unit (`total_params`).
    pub fn reduce_scatter_buffer_size(&self) -> usize {
        self.total_params
    }
}
pub struct FsdpState {
config: FsdpConfig,
units: Vec<FsdpUnit>,
param_registry: HashMap<String, usize>,
gather_cache: HashMap<usize, Vec<f64>>,
shard_storage: HashMap<usize, Vec<f64>>,
}
impl FsdpState {
pub fn new(config: FsdpConfig) -> Self {
Self {
config,
units: Vec::new(),
param_registry: HashMap::new(),
gather_cache: HashMap::new(),
shard_storage: HashMap::new(),
}
}
pub fn wrap_unit(
&mut self,
param_names: Vec<String>,
param_values: HashMap<String, Vec<f64>>,
) -> Result<usize, FsdpError> {
if param_names.is_empty() {
return Err(FsdpError::EmptyUnit);
}
for name in ¶m_names {
if self.param_registry.contains_key(name) {
return Err(FsdpError::AlreadyRegistered(name.clone()));
}
}
let unit_id = self.units.len();
let total_params: usize =
param_names.iter().filter_map(|n| param_values.get(n)).map(|v| v.len()).sum();
let unit = FsdpUnit::new(
unit_id,
param_names.clone(),
total_params,
self.config.world_size,
);
let mut flat_params: Vec<f64> = Vec::with_capacity(total_params);
for name in ¶m_names {
if let Some(vals) = param_values.get(name) {
flat_params.extend_from_slice(vals);
}
}
let shard_start = self.config.local_rank * unit.shard_size;
let shard_end = (shard_start + unit.shard_size).min(flat_params.len());
let shard: Vec<f64> = if shard_start < flat_params.len() {
flat_params[shard_start..shard_end].to_vec()
} else {
Vec::new()
};
self.shard_storage.insert(unit_id, shard);
for name in ¶m_names {
self.param_registry.insert(name.clone(), unit_id);
}
self.units.push(unit);
Ok(unit_id)
}
pub fn allgather_unit(&mut self, unit_id: usize) -> Result<Vec<f64>, FsdpError> {
let unit = self.units.get(unit_id).ok_or(FsdpError::UnitNotFound(unit_id))?;
let total_params = unit.total_params;
let world_size = self.config.world_size;
let shard = self.shard_storage.get(&unit_id).cloned().unwrap_or_default();
let mut gathered: Vec<f64> = Vec::with_capacity(total_params);
if world_size == 0 || shard.is_empty() {
} else {
for _ in 0..world_size {
gathered.extend_from_slice(&shard);
if gathered.len() >= total_params {
break;
}
}
gathered.truncate(total_params);
}
self.gather_cache.insert(unit_id, gathered.clone());
Ok(gathered)
}
pub fn discard_unit_params(&mut self, unit_id: usize) -> Result<(), FsdpError> {
if unit_id >= self.units.len() {
return Err(FsdpError::UnitNotFound(unit_id));
}
self.gather_cache.remove(&unit_id);
Ok(())
}
pub fn reduce_scatter_grads(
&mut self,
unit_id: usize,
grads: Vec<f64>,
) -> Result<Vec<f64>, FsdpError> {
let unit = self.units.get(unit_id).ok_or(FsdpError::UnitNotFound(unit_id))?;
let shard_size = unit.shard_size;
let world_size = self.config.world_size.max(1) as f64;
let averaged: Vec<f64> = grads.iter().map(|&g| g / world_size).collect();
let rank = self.config.local_rank;
let start = rank * shard_size;
let end = (start + shard_size).min(averaged.len());
let local_grads: Vec<f64> =
if start < averaged.len() { averaged[start..end].to_vec() } else { Vec::new() };
Ok(local_grads)
}
pub fn memory_saving_ratio(&self) -> f64 {
self.config.world_size as f64
}
pub fn per_rank_memory_bytes(&self) -> usize {
self.units.iter().map(|u| u.memory_bytes_per_rank()).sum()
}
pub fn unit_count(&self) -> usize {
self.units.len()
}
pub fn total_params(&self) -> usize {
self.units.iter().map(|u| u.total_params).sum()
}
}
/// Errors returned by `FsdpState` operations.
#[derive(Debug, thiserror::Error)]
pub enum FsdpError {
    /// No unit with this id has been wrapped.
    #[error("Unit not found: {0}")]
    UnitNotFound(usize),

    /// The parameter name is already owned by a wrapped unit.
    #[error("Param already registered: {0}")]
    AlreadyRegistered(String),

    /// A unit was requested with no parameter names.
    #[error("Empty unit")]
    EmptyUnit,

    /// The unit's parameters were expected to be gathered but were not.
    ///
    /// NOTE(review): nothing in this module currently returns this variant.
    #[error("Not gathered: unit {0}")]
    NotGathered(usize),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config for `world_size` ranks, seen from `local_rank`; other fields
    /// take their defaults.
    fn config_for(world_size: usize, local_rank: usize) -> FsdpConfig {
        FsdpConfig {
            world_size,
            local_rank,
            ..Default::default()
        }
    }

    /// A parameter map holding exactly one named tensor.
    fn single(name: &str, vals: Vec<f64>) -> HashMap<String, Vec<f64>> {
        HashMap::from([(name.to_string(), vals)])
    }

    /// Builds `size` sequential values per name.
    ///
    /// The `i`-th name gets the values `i * size .. (i + 1) * size`.
    fn make_param_values(names: &[&str], size: usize) -> HashMap<String, Vec<f64>> {
        let mut out = HashMap::with_capacity(names.len());
        for (idx, name) in names.iter().enumerate() {
            let base = idx * size;
            out.insert((*name).to_string(), (base..base + size).map(|v| v as f64).collect());
        }
        out
    }

    #[test]
    fn test_wrap_unit_basic() {
        let mut state = FsdpState::new(config_for(4, 0));
        let names = vec![String::from("w1"), String::from("b1")];
        let unit_id = state
            .wrap_unit(names, make_param_values(&["w1", "b1"], 16))
            .expect("wrap failed");
        assert_eq!(unit_id, 0);
        assert_eq!(state.unit_count(), 1);
    }

    #[test]
    fn test_unit_shard_size() {
        const WORLD: usize = 4;
        const TOTAL: usize = 100;
        let mut state = FsdpState::new(config_for(WORLD, 0));
        let unit_id = state
            .wrap_unit(vec![String::from("weight")], single("weight", vec![1.0; TOTAL]))
            .expect("wrap failed");
        let unit = &state.units[unit_id];
        assert_eq!(unit.shard_size, TOTAL.div_ceil(WORLD));
        assert_eq!(unit.total_params, TOTAL);
    }

    #[test]
    fn test_allgather_unit_reconstruction() {
        let mut state = FsdpState::new(config_for(2, 0));
        let unit_id = state
            .wrap_unit(vec![String::from("p")], single("p", vec![1.0, 2.0, 3.0, 4.0]))
            .expect("wrap failed");
        let gathered = state.allgather_unit(unit_id).expect("gather failed");
        assert_eq!(gathered.len(), 4);
    }

    #[test]
    fn test_discard_clears_cache() {
        let mut state = FsdpState::new(config_for(2, 0));
        let unit_id = state
            .wrap_unit(vec![String::from("q")], single("q", vec![1.0, 2.0, 3.0, 4.0]))
            .expect("wrap failed");
        state.allgather_unit(unit_id).expect("gather failed");
        assert!(state.gather_cache.contains_key(&unit_id));
        state.discard_unit_params(unit_id).expect("discard failed");
        assert!(!state.gather_cache.contains_key(&unit_id));
    }

    #[test]
    fn test_reduce_scatter_averages() {
        const TOTAL: usize = 8;
        let mut state = FsdpState::new(config_for(4, 0));
        let unit_id = state
            .wrap_unit(vec![String::from("r")], single("r", vec![1.0; TOTAL]))
            .expect("wrap failed");
        let local_grads = state
            .reduce_scatter_grads(unit_id, vec![4.0; TOTAL])
            .expect("scatter failed");
        for g in local_grads {
            assert!((g - 1.0).abs() < 1e-9, "Expected 1.0, got {}", g);
        }
    }

    #[test]
    fn test_memory_saving_ratio() {
        let world_size = 8;
        let state = FsdpState::new(config_for(world_size, 2));
        assert!((state.memory_saving_ratio() - world_size as f64).abs() < 1e-9);
    }

    #[test]
    fn test_per_rank_memory_bytes() {
        let mut state = FsdpState::new(config_for(4, 0));
        let unit_id = state
            .wrap_unit(vec![String::from("w")], single("w", vec![1.0; 100]))
            .expect("wrap failed");
        let expected_bytes = state.units[unit_id].shard_size * 8;
        assert_eq!(state.per_rank_memory_bytes(), expected_bytes);
    }

    #[test]
    fn test_unit_count_and_total_params() {
        let mut state = FsdpState::new(config_for(2, 0));
        assert_eq!(state.unit_count(), 0);
        assert_eq!(state.total_params(), 0);
        state
            .wrap_unit(vec![String::from("a")], single("a", vec![0.0; 50]))
            .expect("wrap failed");
        assert_eq!(state.unit_count(), 1);
        assert_eq!(state.total_params(), 50);
    }

    #[test]
    fn test_transformer_layer_wrap_policy() {
        let policy = WrappingPolicy::TransformerLayerWrap { min_params: 1024 };
        let state = FsdpState::new(FsdpConfig {
            wrapping_policy: policy.clone(),
            ..config_for(4, 1)
        });
        assert_eq!(state.config.wrapping_policy, policy);
    }

    #[test]
    fn test_fsdp_unit_memory_bytes_per_rank() {
        let unit = FsdpUnit::new(0, vec![String::from("w")], 1000, 4);
        assert_eq!(unit.shard_size, 250);
        assert_eq!(unit.memory_bytes_per_rank(), 250 * 8);
    }

    #[test]
    fn test_double_register_error() {
        let mut state = FsdpState::new(config_for(2, 0));
        state
            .wrap_unit(
                vec![String::from("shared_param")],
                single("shared_param", vec![1.0; 10]),
            )
            .expect("first wrap should succeed");
        let result = state.wrap_unit(
            vec![String::from("shared_param")],
            single("shared_param", vec![2.0; 10]),
        );
        assert!(matches!(result, Err(FsdpError::AlreadyRegistered(_))));
    }

    #[test]
    fn test_unit_not_found_error() {
        let mut state = FsdpState::new(config_for(2, 0));
        assert!(matches!(
            state.allgather_unit(99),
            Err(FsdpError::UnitNotFound(99))
        ));
    }

    #[test]
    fn test_multi_unit_model() {
        let mut state = FsdpState::new(config_for(4, 0));
        let layers: [(&str, usize, usize); 3] = [("l0", 128, 4), ("l1", 256, 8), ("l2", 512, 16)];
        let mut total_expected = 0;
        for (prefix, w_size, b_size) in layers {
            let w = format!("{prefix}.w");
            let b = format!("{prefix}.b");
            let values = HashMap::from([(w.clone(), vec![1.0; w_size]), (b.clone(), vec![0.0; b_size])]);
            state.wrap_unit(vec![w, b], values).expect("wrap failed");
            total_expected += w_size + b_size;
        }
        assert_eq!(state.unit_count(), 3);
        assert_eq!(state.total_params(), total_expected);
    }

    #[test]
    fn test_cpu_offload_flag() {
        let config = FsdpConfig {
            cpu_offload: true,
            mixed_precision: true,
            ..config_for(4, 1)
        };
        assert!(config.cpu_offload && config.mixed_precision);
        let state = FsdpState::new(config);
        assert!(state.config.cpu_offload && state.config.mixed_precision);
        assert!(!FsdpConfig::default().cpu_offload);
    }

    #[test]
    fn test_sharding_strategy_full_shard() {
        assert_eq!(ShardingStrategy::FullShard, ShardingStrategy::FullShard);
    }

    #[test]
    fn test_sharding_strategy_no_shard() {
        assert_ne!(ShardingStrategy::NoShard, ShardingStrategy::FullShard);
    }

    #[test]
    fn test_sharding_strategy_shard_grad_op() {
        let strategy = ShardingStrategy::ShardGradOp;
        assert_eq!(strategy, ShardingStrategy::ShardGradOp);
        assert_ne!(strategy, ShardingStrategy::FullShard);
    }

    #[test]
    fn test_hybrid_shard_replicas() {
        let hybrid = |n| ShardingStrategy::HybridShard {
            num_model_replicas: n,
        };
        let strategy = hybrid(4);
        assert_eq!(strategy, hybrid(4));
        assert_ne!(strategy, hybrid(2));
    }

    #[test]
    fn test_fsdp_unit_local_params() {
        let unit = FsdpUnit::new(0, vec![String::from("w")], 100, 4);
        assert_eq!(unit.local_params(), unit.shard_size);
    }

    #[test]
    fn test_fsdp_unit_all_gather_buffer() {
        let unit = FsdpUnit::new(0, vec![String::from("w")], 100, 4);
        assert_eq!(unit.all_gather_buffer_size(), 100);
    }

    #[test]
    fn test_fsdp_unit_reduce_scatter_buffer() {
        let unit = FsdpUnit::new(0, vec![String::from("w")], 64, 8);
        assert_eq!(unit.reduce_scatter_buffer_size(), 64);
    }

    #[test]
    fn test_memory_analyzer_peak_memory_world_size_1() {
        assert_eq!(FsdpMemoryAnalyzer::peak_memory(&config_for(1, 0), 1000), 16_000);
    }

    #[test]
    fn test_memory_analyzer_peak_memory_world_size_4() {
        assert_eq!(FsdpMemoryAnalyzer::peak_memory(&config_for(4, 0), 1000), 4_000);
    }

    #[test]
    fn test_memory_vs_ddp_ratio_world_size_1() {
        let ratio = FsdpMemoryAnalyzer::memory_vs_ddp_ratio(1);
        assert!((ratio - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_memory_vs_ddp_ratio_world_size_8() {
        let ratio = FsdpMemoryAnalyzer::memory_vs_ddp_ratio(8);
        assert!((ratio - 0.125).abs() < 1e-6);
    }

    #[test]
    fn test_memory_vs_ddp_ratio_world_size_0() {
        let ratio = FsdpMemoryAnalyzer::memory_vs_ddp_ratio(0);
        assert!((ratio - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_fsdp_state_rank_gets_correct_shard() {
        let mut state = FsdpState::new(config_for(4, 2));
        let vals: Vec<f64> = (0u32..100).map(f64::from).collect();
        let unit_id = state
            .wrap_unit(vec![String::from("params")], single("params", vals))
            .expect("wrap failed");
        assert_eq!(state.units[unit_id].shard_size, 25);
        let shard = state.shard_storage.get(&unit_id).expect("shard missing");
        assert_eq!(shard.len(), 25);
        let first = shard[0];
        assert!(
            (first - 50.0).abs() < 1e-9,
            "first element should be 50.0, got {}",
            first
        );
    }

    #[test]
    fn test_forward_prefetch_flag() {
        let config = FsdpConfig {
            forward_prefetch: true,
            ..config_for(4, 0)
        };
        assert!(config.forward_prefetch);
        let state = FsdpState::new(config);
        assert!(state.config.forward_prefetch);
        assert!(!FsdpConfig::default().forward_prefetch);
    }

    #[test]
    fn test_reduce_scatter_rank1_of_4() {
        let mut state = FsdpState::new(config_for(4, 1));
        let unit_id = state
            .wrap_unit(vec![String::from("r")], single("r", vec![1.0; 8]))
            .expect("wrap");
        let local = state.reduce_scatter_grads(unit_id, vec![4.0; 8]).expect("scatter");
        assert_eq!(local.len(), 2, "rank 1 should hold 2 elements");
        for g in local {
            assert!((g - 1.0).abs() < 1e-9, "Expected 1.0, got {g}");
        }
    }

    #[test]
    fn test_wrap_empty_unit_error() {
        let mut state = FsdpState::new(config_for(2, 0));
        assert!(matches!(
            state.wrap_unit(Vec::new(), HashMap::new()),
            Err(FsdpError::EmptyUnit)
        ));
    }
}