use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;
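/// Strategy controlling how much model state is sharded across ranks.
///
/// - `FullShard`: parameters, gradients, and optimizer state are all sharded (ZeRO-3 style).
/// - `ShardGradOp`: gradients and optimizer state are sharded, parameters stay replicated (ZeRO-2 style).
/// - `NoShard`: nothing is sharded; gradients are all-reduced like classic data parallelism.
/// - `HybridShard`: intended to combine sharding within a group with replication across groups;
///   treated like `FullShard` for memory estimates and like `NoShard` for gradient sync here.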
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
#[default]
FullShard,
ShardGradOp,
NoShard,
HybridShard,
}
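/// One parameter's local shard plus the metadata needed to rebuild the full tensor:
/// its original shape, element count, and the zero padding added so the flattened
/// parameter divides evenly across ranks.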
#[derive(Debug)]
#[allow(dead_code)]
struct ShardedParam {
local_shard: Tensor<f32>,
original_shape: Vec<usize>,
numel: usize,
padding: usize,
}
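/// CPU offload policy: `None` keeps everything on device, `Params` offloads parameter
/// shards to host memory, and `Full` additionally offloads gradients and optimizer state.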
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CPUOffload {
#[default]
None,
Params,
Full,
}
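/// Fully Sharded Data Parallel wrapper: flattens and shards the wrapped module's
/// parameters across the ranks of a `ProcessGroup`, gathering them on demand.
///
/// Example (mirroring the unit tests at the bottom of this file):
/// ```ignore
/// let fsdp = FullyShardedDataParallel::new(Linear::new(10, 5), ProcessGroup::mock())
///     .sharding_strategy(ShardingStrategy::ShardGradOp)
///     .mixed_precision(true);
/// ```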
pub struct FullyShardedDataParallel<M: Module> {
module: M,
process_group: ProcessGroup,
sharding_strategy: ShardingStrategy,
cpu_offload: CPUOffload,
sharded_params: Vec<ShardedParam>,
is_gathered: bool,
mixed_precision: bool,
}
impl<M: Module> FullyShardedDataParallel<M> {
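/// Wraps `module`, using the default `FullShard` strategy, and immediately shards its parameters.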
pub fn new(module: M, process_group: ProcessGroup) -> Self {
let mut fsdp = Self {
module,
process_group,
sharding_strategy: ShardingStrategy::default(),
cpu_offload: CPUOffload::default(),
sharded_params: Vec::new(),
is_gathered: true,
mixed_precision: false,
};
fsdp.shard_parameters();
fsdp
}
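/// Builder-style setter for the sharding strategy; re-shards the parameters under the new strategy.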
pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
self.sharding_strategy = strategy;
self.shard_parameters();
self
}
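/// Builder-style setter for the CPU offload policy.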
pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
self.cpu_offload = offload;
self
}
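/// Builder-style setter for mixed-precision training.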
pub fn mixed_precision(mut self, enabled: bool) -> Self {
self.mixed_precision = enabled;
self
}
pub fn module(&self) -> &M {
&self.module
}
pub fn module_mut(&mut self) -> &mut M {
&mut self.module
}
pub fn process_group(&self) -> &ProcessGroup {
&self.process_group
}
pub fn strategy(&self) -> ShardingStrategy {
self.sharding_strategy
}
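/// Flattens each parameter, zero-pads it to a multiple of the world size, and keeps
/// only this rank's contiguous shard. Does nothing under `NoShard`.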
fn shard_parameters(&mut self) {
if self.sharding_strategy == ShardingStrategy::NoShard {
return;
}
let world_size = self.process_group.world_size();
let rank = self.process_group.rank();
self.sharded_params.clear();
for param in self.module.parameters() {
let data = param.data();
let shape = data.shape().to_vec();
let numel = data.numel();
let shard_size = numel.div_ceil(world_size);
let padding = shard_size * world_size - numel;
let flat_data = data.to_vec();
let start = rank * shard_size;
let end = ((rank + 1) * shard_size).min(flat_data.len());
let mut shard_data: Vec<f32> = if start < flat_data.len() {
flat_data[start..end].to_vec()
} else {
vec![0.0; shard_size]
};
shard_data.resize(shard_size, 0.0);
self.sharded_params.push(ShardedParam {
local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
original_shape: shape,
numel,
padding,
});
}
self.is_gathered = false;
}
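/// All-gathers the local shards and restores the full parameter tensors on the wrapped
/// module. No-op if already gathered or sharding is disabled.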
pub fn gather_parameters(&mut self) {
if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
return;
}
let params = self.module.parameters();
for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);
let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();
param.update_data(restored);
}
self.is_gathered = true;
}
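/// Re-creates the local shards from the currently gathered parameters (the inverse of `gather_parameters`).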
pub fn reshard_parameters(&mut self) {
if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
return;
}
self.shard_parameters();
}
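/// Synchronizes gradients across ranks after backward: all-reduce averaging for the
/// replicated strategies, reduce-scatter for the sharded ones.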
pub fn sync_gradients(&self) {
match self.sharding_strategy {
ShardingStrategy::NoShard | ShardingStrategy::HybridShard => {
// Replicated gradients: average across ranks and write the result back.
for param in self.module.parameters() {
if let Some(grad) = param.grad() {
let mut grad_tensor = grad.clone();
self.process_group
.all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
param.variable().set_grad(grad_tensor);
}
}
}
ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
// Sharded gradients: reduce-scatter so each rank receives its averaged shard.
// This simplified implementation does not yet retain the shard for the optimizer step.
for param in self.module.parameters() {
if let Some(grad) = param.grad() {
let _reduced = self
.process_group
.reduce_scatter_tensor(&grad, ReduceOp::Average);
}
}
}
}
}
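/// Clips gradients by their global L2 norm (accumulated across parameters and ranks)
/// and returns the pre-clipping norm.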
pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
let mut total_norm_sq = 0.0f32;
for param in self.module.parameters() {
if let Some(grad) = param.grad() {
let grad_vec = grad.to_vec();
let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
total_norm_sq += norm_sq;
}
}
let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
self.process_group
.all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
let global_norm = norm_tensor.to_vec()[0].sqrt();
if global_norm > max_norm {
let clip_coef = max_norm / (global_norm + 1e-6);
for param in self.module.parameters() {
if let Some(grad) = param.grad() {
let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
param.variable().set_grad(clipped_tensor);
}
}
}
global_norm
}
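/// Estimates per-rank memory for parameters, gradients, and optimizer state under the
/// current strategy, assuming f32 parameters and two optimizer slots per parameter
/// (as in Adam-style optimizers).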
pub fn memory_estimate(&self) -> FSDPMemoryStats {
let params = self.module.parameters();
let total_params: usize = params.iter().map(|p| p.numel()).sum();
let world_size = self.process_group.world_size();
let bytes_per_param = 4; // f32
let param_memory = total_params * bytes_per_param;
let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
ShardingStrategy::ShardGradOp => (
param_memory,
param_memory / world_size,
param_memory * 2 / world_size,
),
ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
param_memory / world_size,
param_memory / world_size,
param_memory * 2 / world_size,
),
};
FSDPMemoryStats {
total_params,
param_memory_bytes: sharded_params,
grad_memory_bytes: sharded_grads,
optim_memory_bytes: sharded_optim,
world_size,
}
}
}
impl<M: Module> Module for FullyShardedDataParallel<M> {
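/// Delegates to the wrapped module. Note that `forward` does not gather sharded
/// parameters automatically; call `gather_parameters` first when sharding is enabled.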
fn forward(&self, input: &Variable) -> Variable {
self.module.forward(input)
}
fn parameters(&self) -> Vec<Parameter> {
self.module.parameters()
}
fn train(&mut self) {
self.module.train();
}
fn eval(&mut self) {
self.module.eval();
}
fn is_training(&self) -> bool {
self.module.is_training()
}
}
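/// Per-rank memory estimate produced by `FullyShardedDataParallel::memory_estimate`.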
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
pub total_params: usize,
pub param_memory_bytes: usize,
pub grad_memory_bytes: usize,
pub optim_memory_bytes: usize,
pub world_size: usize,
}
impl FSDPMemoryStats {
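/// Total estimated per-rank memory (parameters + gradients + optimizer state) in MiB.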
pub fn total_memory_mb(&self) -> f32 {
(self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes) as f32
/ (1024.0 * 1024.0)
}
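/// Fraction of per-rank memory saved relative to full replication: `1 - 1/world_size`.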
pub fn memory_savings(&self) -> f32 {
if self.world_size > 1 {
1.0 - (1.0 / self.world_size as f32)
} else {
0.0
}
}
}
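/// Linear layer whose weight is split along the output (column) dimension: each rank
/// holds `out_features / world_size` output rows. With `gather_output`, the partial
/// outputs are all-gathered after the local matmul.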
#[allow(dead_code)]
pub struct ColumnParallelLinear {
weight: Parameter,
bias: Option<Parameter>,
process_group: ProcessGroup,
in_features: usize,
out_features: usize,
gather_output: bool,
}
impl ColumnParallelLinear {
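/// Creates a column-parallel linear layer. `out_features` is assumed to be divisible
/// by the process group's world size.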
pub fn new(
in_features: usize,
out_features: usize,
bias: bool,
process_group: ProcessGroup,
gather_output: bool,
) -> Self {
let world_size = process_group.world_size();
let local_out_features = out_features / world_size;
let weight_data = Tensor::randn(&[local_out_features, in_features]);
let weight = Parameter::new(weight_data, true);
let bias = if bias {
let bias_data = Tensor::zeros(&[local_out_features]);
Some(Parameter::new(bias_data, true))
} else {
None
};
Self {
weight,
bias,
process_group,
in_features,
out_features,
gather_output,
}
}
}
impl Module for ColumnParallelLinear {
fn forward(&self, input: &Variable) -> Variable {
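// Local projection against this rank's column shard: y_local = x * W_local^T (+ b_local).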
let weight_var = Variable::new(self.weight.data(), false);
let output = input.matmul(&weight_var.transpose(0, 1));
let output = if let Some(ref bias) = self.bias {
let bias_var = Variable::new(bias.data(), false);
output.add(&bias_var)
} else {
output
};
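// Optionally all-gather the column shards so every rank sees the full output.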
if self.gather_output {
let gathered = self.process_group.all_gather_tensor(&output.data());
Variable::new(gathered, output.requires_grad())
} else {
output
}
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = vec![self.weight.clone()];
if let Some(ref bias) = self.bias {
params.push(bias.clone());
}
params
}
}
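/// Linear layer whose weight is split along the input (row) dimension: each rank holds
/// `in_features / world_size` input columns, partial outputs are summed with an
/// all-reduce, and the bias (if any) lives only on rank 0.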
#[allow(dead_code)]
pub struct RowParallelLinear {
weight: Parameter,
bias: Option<Parameter>,
process_group: ProcessGroup,
in_features: usize,
out_features: usize,
input_is_parallel: bool,
}
impl RowParallelLinear {
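/// Creates a row-parallel linear layer. `in_features` is assumed to be divisible by the
/// world size; the bias, when requested, is allocated only on rank 0.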
pub fn new(
in_features: usize,
out_features: usize,
bias: bool,
process_group: ProcessGroup,
input_is_parallel: bool,
) -> Self {
let world_size = process_group.world_size();
let rank = process_group.rank();
let local_in_features = in_features / world_size;
let weight_data = Tensor::randn(&[out_features, local_in_features]);
let weight = Parameter::new(weight_data, true);
let bias = if bias && rank == 0 {
let bias_data = Tensor::zeros(&[out_features]);
Some(Parameter::new(bias_data, true))
} else {
None
};
Self {
weight,
bias,
process_group,
in_features,
out_features,
input_is_parallel,
}
}
}
impl Module for RowParallelLinear {
fn forward(&self, input: &Variable) -> Variable {
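// If the input still carries the full feature dimension, slice out this rank's chunk.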
let local_input = if self.input_is_parallel {
input.clone()
} else {
let world_size = self.process_group.world_size();
let rank = self.process_group.rank();
let data = input.data();
let shape = data.shape();
let feature_dim = shape[shape.len() - 1];
let local_features = feature_dim / world_size;
let start = rank * local_features;
let end = start + local_features;
let sliced = if shape.len() == 2 {
data.slice(&[0..shape[0], start..end])
} else {
// Non-2D inputs are passed through unsplit in this simplified path.
data.clone()
};
Variable::new(sliced, input.requires_grad())
};
let weight_var = Variable::new(self.weight.data(), false);
let local_output = local_input.matmul(&weight_var.transpose(0, 1));
let mut output_data = local_output.data().clone();
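// Each rank holds a partial product over its input columns; sum them across ranks.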
self.process_group
.all_reduce_tensor(&mut output_data, ReduceOp::Sum);
let output = Variable::new(output_data, local_output.requires_grad());
if let Some(ref bias) = self.bias {
let bias_var = Variable::new(bias.data(), false);
output.add(&bias_var)
} else {
output
}
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = vec![self.weight.clone()];
if let Some(ref bias) = self.bias {
params.push(bias.clone());
}
params
}
}
#[cfg(test)]
mod tests {
use super::*;
use axonml_nn::Linear;
#[test]
fn test_sharding_strategy_default() {
assert_eq!(ShardingStrategy::default(), ShardingStrategy::FullShard);
}
#[test]
fn test_fsdp_creation() {
let model = Linear::new(10, 5);
let pg = ProcessGroup::mock();
let fsdp = FullyShardedDataParallel::new(model, pg);
assert_eq!(fsdp.strategy(), ShardingStrategy::FullShard);
}
#[test]
fn test_fsdp_forward() {
let model = Linear::new(4, 2);
let pg = ProcessGroup::mock();
let mut fsdp = FullyShardedDataParallel::new(model, pg);
fsdp.gather_parameters();
let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
let output = fsdp.forward(&input);
assert_eq!(output.data().shape(), &[1, 2]);
}
#[test]
fn test_fsdp_builder() {
let model = Linear::new(10, 5);
let pg = ProcessGroup::mock();
let fsdp = FullyShardedDataParallel::new(model, pg)
.sharding_strategy(ShardingStrategy::ShardGradOp)
.cpu_offload(CPUOffload::Params)
.mixed_precision(true);
assert_eq!(fsdp.strategy(), ShardingStrategy::ShardGradOp);
}
#[test]
fn test_fsdp_memory_stats() {
let model = Linear::new(100, 50);
let pg = ProcessGroup::mock();
let fsdp = FullyShardedDataParallel::new(model, pg);
let stats = fsdp.memory_estimate();
assert!(stats.total_params > 0);
assert!(stats.total_memory_mb() > 0.0);
}
#[test]
fn test_fsdp_no_shard() {
let model = Linear::new(10, 5);
let pg = ProcessGroup::mock();
let fsdp =
FullyShardedDataParallel::new(model, pg).sharding_strategy(ShardingStrategy::NoShard);
assert_eq!(fsdp.strategy(), ShardingStrategy::NoShard);
}
#[test]
fn test_column_parallel_linear() {
let pg = ProcessGroup::mock();
let layer = ColumnParallelLinear::new(8, 4, true, pg, false);
let input = Variable::new(Tensor::randn(&[2, 8]), false);
let output = layer.forward(&input);
assert_eq!(output.data().shape(), &[2, 4]);
}
#[test]
fn test_row_parallel_linear() {
let pg = ProcessGroup::mock();
let layer = RowParallelLinear::new(8, 4, true, pg, false);
let input = Variable::new(Tensor::randn(&[2, 8]), false);
let output = layer.forward(&input);
assert_eq!(output.data().shape(), &[2, 4]);
}
}