use crate::{TorshDistributedError, TorshResult};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// ZeRO (Zero Redundancy Optimizer) stage selector, mirroring DeepSpeed's
/// `zero_optimization.stage` JSON field. Discriminants match the numeric
/// stage values used in DeepSpeed config files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ZeroStage {
    /// Stage 0: ZeRO disabled (plain data parallelism).
    Stage0 = 0,
    /// Stage 1: optimizer state partitioning.
    Stage1 = 1,
    /// Stage 2: Stage 1 plus gradient partitioning.
    Stage2 = 2,
    /// Stage 3: Stage 2 plus parameter partitioning.
    Stage3 = 3,
}
/// Top-level DeepSpeed configuration: the subset of a DeepSpeed JSON config
/// file that this crate understands. `None` fields are simply absent from the
/// JSON and fall back to downstream defaults.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct DeepSpeedConfig {
    /// ZeRO memory-optimization settings (`zero_optimization` in the JSON).
    pub zero_optimization: ZeroOptimizationConfig,
    /// Gradient-norm clipping threshold, if set.
    pub gradient_clipping: Option<f32>,
    /// Micro-batches to accumulate before an optimizer step, if set.
    pub gradient_accumulation_steps: Option<u32>,
    /// FP16 mixed-precision settings (`fp16` in the JSON).
    pub fp16: Option<FP16Config>,
    /// Mirrors DeepSpeed's `zero_force_ds_cpu_optimizer` flag; this is also
    /// the value `get_stats` reports as `cpu_offload_enabled`.
    pub zero_force_ds_cpu_optimizer: Option<bool>,
    /// Activation checkpointing settings, if present.
    pub activation_checkpointing: Option<ActivationCheckpointingConfig>,
}
/// ZeRO settings, mirroring DeepSpeed's `zero_optimization` JSON object.
/// Unset (`None`) knobs receive defaults in the `initialize_zero_stage*`
/// methods of `DeepSpeedIntegration`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroOptimizationConfig {
    /// Which ZeRO stage to run.
    pub stage: ZeroStage,
    /// Bucket size for allgather operations (DeepSpeed `allgather_bucket_size`).
    pub allgather_bucket_size: Option<u64>,
    /// Bucket size for gradient reduction (DeepSpeed `reduce_bucket_size`).
    pub reduce_bucket_size: Option<u64>,
    /// Overlap communication with computation; defaults to `true` at init time.
    pub overlap_comm: Option<bool>,
    /// Keep gradients contiguous in memory (DeepSpeed `contiguous_gradients`).
    pub contiguous_gradients: Option<bool>,
    /// Mirrors DeepSpeed `sub_group_size`.
    /// NOTE(review): DeepSpeed examples use very large values here — confirm
    /// `u32` is wide enough, the other size knobs use `u64`.
    pub sub_group_size: Option<u32>,
    /// Mirrors DeepSpeed `reduce_scatter`; also the trigger that
    /// `to_gradient_compression_config` checks to enable compression.
    pub reduce_scatter: Option<bool>,
    /// Mirrors DeepSpeed `allgather_partitions`.
    pub allgather_partitions: Option<bool>,
    /// Stage 3: cap on parameters kept live at once. Required by
    /// `validate_config` when `stage == Stage3`.
    pub stage3_max_live_parameters: Option<u64>,
    /// Stage 3: mirrors DeepSpeed `stage3_max_reuse_distance`.
    pub stage3_max_reuse_distance: Option<u64>,
    /// Stage 3: mirrors DeepSpeed `stage3_prefetch_bucket_size`.
    pub stage3_prefetch_bucket_size: Option<u64>,
    /// Stage 3: mirrors DeepSpeed `stage3_param_persistence_threshold`.
    pub stage3_param_persistence_threshold: Option<u64>,
    /// Optional offload target for optimizer state; `device` must be non-empty.
    pub offload_optimizer: Option<OffloadOptimizerConfig>,
    /// Optional offload target for parameters; `device` must be non-empty.
    pub offload_param: Option<OffloadParamConfig>,
}
/// FP16 mixed-precision settings, mirroring DeepSpeed's `fp16` JSON object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FP16Config {
    /// Master switch; checked by `to_fsdp_config` and `get_stats`.
    pub enabled: bool,
    /// Mirrors DeepSpeed `loss_scale`. `None` leaves the choice to the
    /// consumer — presumably dynamic scaling; confirm against the runtime.
    pub loss_scale: Option<f32>,
    /// Mirrors DeepSpeed `loss_scale_window`.
    pub loss_scale_window: Option<u32>,
    /// Mirrors DeepSpeed `hysteresis`.
    pub hysteresis: Option<u32>,
    /// Mirrors DeepSpeed `min_loss_scale`.
    pub min_loss_scale: Option<f32>,
    /// Mirrors DeepSpeed `initial_scale_power` (scale expressed as 2^power).
    pub initial_scale_power: Option<u32>,
}
/// Activation checkpointing settings, mirroring DeepSpeed's
/// `activation_checkpointing` JSON object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationCheckpointingConfig {
    /// Mirrors DeepSpeed `partition_activations`; this flag alone is what
    /// `get_stats` reports as `activation_checkpointing_enabled`.
    pub partition_activations: Option<bool>,
    /// Mirrors DeepSpeed `cpu_checkpointing`.
    pub cpu_checkpointing: Option<bool>,
    /// Mirrors DeepSpeed `contiguous_memory_optimization`.
    pub contiguous_memory_optimization: Option<bool>,
    /// Mirrors DeepSpeed `number_checkpoints`.
    pub number_checkpoints: Option<u32>,
    /// Mirrors DeepSpeed `synchronize_checkpoint_boundary`.
    pub synchronize_checkpoint_boundary: Option<bool>,
    /// Mirrors DeepSpeed `profile`.
    pub profile: Option<bool>,
}
/// Optimizer-state offload settings, mirroring DeepSpeed's
/// `zero_optimization.offload_optimizer` JSON object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OffloadOptimizerConfig {
    /// Offload target device (e.g. "cpu"); rejected by `validate_config`
    /// when empty.
    pub device: String,
    /// Filesystem path for NVMe offload, if the target device needs one.
    pub nvme_path: Option<String>,
    /// Mirrors DeepSpeed `pin_memory`.
    pub pin_memory: Option<bool>,
    /// Mirrors DeepSpeed `buffer_count`.
    pub buffer_count: Option<u32>,
    /// Mirrors DeepSpeed `fast_init`.
    pub fast_init: Option<bool>,
}
/// Parameter offload settings, mirroring DeepSpeed's
/// `zero_optimization.offload_param` JSON object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OffloadParamConfig {
    /// Offload target device (e.g. "cpu"); rejected by `validate_config`
    /// when empty.
    pub device: String,
    /// Filesystem path for NVMe offload, if the target device needs one.
    pub nvme_path: Option<String>,
    /// Mirrors DeepSpeed `pin_memory`.
    pub pin_memory: Option<bool>,
    /// Mirrors DeepSpeed `buffer_count`.
    pub buffer_count: Option<u32>,
    /// Mirrors DeepSpeed `max_in_cpu`.
    pub max_in_cpu: Option<u64>,
}
/// Stateful wrapper around a [`DeepSpeedConfig`] that validates it, runs
/// stage-specific initialization, and converts it to this crate's FSDP and
/// gradient-compression configurations.
pub struct DeepSpeedIntegration {
    /// The configuration driving validation, initialization, and conversions.
    config: DeepSpeedConfig,
    /// Set after a successful `initialize()`; makes re-initialization a no-op.
    initialized: bool,
}
impl DeepSpeedIntegration {
    /// Default allgather/reduce bucket size; same value as the original
    /// `2e8 as u64` but spelled as an exact integer.
    const DEFAULT_BUCKET_SIZE: u64 = 200_000_000;
    /// Default cap on simultaneously materialized parameters (ZeRO Stage 3);
    /// exact integer form of `1e9 as u64`.
    const DEFAULT_MAX_LIVE_PARAMETERS: u64 = 1_000_000_000;
    /// Default parameter prefetch bucket size (ZeRO Stage 3); exact integer
    /// form of `5e8 as u64`.
    const DEFAULT_PREFETCH_BUCKET_SIZE: u64 = 500_000_000;
    /// Default maximum reuse distance (ZeRO Stage 3).
    const DEFAULT_MAX_REUSE_DISTANCE: u64 = 1000;

    /// Wraps `config` in an uninitialized integration; call
    /// [`Self::initialize`] before relying on stage-specific setup.
    pub fn new(config: DeepSpeedConfig) -> Self {
        Self {
            config,
            initialized: false,
        }
    }

    /// Loads a configuration from a DeepSpeed-style JSON file at `path`.
    ///
    /// # Errors
    /// Returns a configuration error naming the offending path when the file
    /// cannot be read, or a parse error when the JSON is invalid.
    pub fn from_file<P: AsRef<Path>>(path: P) -> TorshResult<Self> {
        let path = path.as_ref();
        let content = std::fs::read_to_string(path).map_err(|e| {
            TorshDistributedError::configuration_error(format!(
                "Failed to read DeepSpeed config file {}: {}",
                path.display(),
                e
            ))
        })?;
        // Delegate so file- and string-based loading share one parse path
        // (and one parse-error message).
        Self::from_json(&content)
    }

    /// Parses a configuration from a DeepSpeed-style JSON string.
    ///
    /// # Errors
    /// Returns a configuration error when the JSON does not deserialize into
    /// [`DeepSpeedConfig`].
    pub fn from_json(json: &str) -> TorshResult<Self> {
        let config: DeepSpeedConfig = serde_json::from_str(json).map_err(|e| {
            TorshDistributedError::configuration_error(format!(
                "Failed to parse DeepSpeed config: {}",
                e
            ))
        })?;
        Ok(Self::new(config))
    }

    /// Validates the configuration and runs stage-specific setup.
    ///
    /// Idempotent: once initialization has succeeded, later calls return
    /// `Ok(())` without re-running validation or setup.
    ///
    /// # Errors
    /// Propagates failures from [`Self::validate_config`].
    pub fn initialize(&mut self) -> TorshResult<()> {
        if self.initialized {
            return Ok(());
        }
        self.validate_config()?;
        match self.config.zero_optimization.stage {
            ZeroStage::Stage0 => {
                tracing::info!(
                    "DeepSpeed integration initialized with ZeRO Stage 0 (no optimization)"
                );
            }
            ZeroStage::Stage1 => self.initialize_zero_stage1()?,
            ZeroStage::Stage2 => self.initialize_zero_stage2()?,
            ZeroStage::Stage3 => self.initialize_zero_stage3()?,
        }
        self.initialized = true;
        Ok(())
    }

    /// Rejects internally inconsistent configurations before setup runs:
    /// Stage 3 without `stage3_max_live_parameters`, or an offload section
    /// with an empty `device`.
    ///
    /// NOTE(review): Stage 3 is required to set `stage3_max_live_parameters`
    /// here even though `initialize_zero_stage3` carries a fallback default —
    /// confirm which of the two policies is intended.
    fn validate_config(&self) -> TorshResult<()> {
        if matches!(self.config.zero_optimization.stage, ZeroStage::Stage3)
            && self
                .config
                .zero_optimization
                .stage3_max_live_parameters
                .is_none()
        {
            return Err(TorshDistributedError::configuration_error(
                "ZeRO Stage 3 requires stage3_max_live_parameters to be set",
            ));
        }
        if let Some(ref offload_config) = self.config.zero_optimization.offload_optimizer {
            if offload_config.device.is_empty() {
                return Err(TorshDistributedError::configuration_error(
                    "Offload optimizer device cannot be empty",
                ));
            }
        }
        if let Some(ref offload_config) = self.config.zero_optimization.offload_param {
            if offload_config.device.is_empty() {
                return Err(TorshDistributedError::configuration_error(
                    "Offload parameter device cannot be empty",
                ));
            }
        }
        Ok(())
    }

    /// Stage 1 setup (optimizer state partitioning). Currently only resolves
    /// defaults and emits tracing; no other side effects in this body.
    fn initialize_zero_stage1(&self) -> TorshResult<()> {
        tracing::info!("Initializing DeepSpeed ZeRO Stage 1 (optimizer state partitioning)");
        let bucket_size = self
            .config
            .zero_optimization
            .reduce_bucket_size
            .unwrap_or(Self::DEFAULT_BUCKET_SIZE);
        tracing::debug!("ZeRO Stage 1 - Reduce bucket size: {}", bucket_size);
        Ok(())
    }

    /// Stage 2 setup (gradient partitioning). Currently only resolves
    /// defaults and emits tracing; no other side effects in this body.
    fn initialize_zero_stage2(&self) -> TorshResult<()> {
        tracing::info!("Initializing DeepSpeed ZeRO Stage 2 (gradient partitioning)");
        let allgather_bucket_size = self
            .config
            .zero_optimization
            .allgather_bucket_size
            .unwrap_or(Self::DEFAULT_BUCKET_SIZE);
        let reduce_bucket_size = self
            .config
            .zero_optimization
            .reduce_bucket_size
            .unwrap_or(Self::DEFAULT_BUCKET_SIZE);
        let overlap_comm = self.config.zero_optimization.overlap_comm.unwrap_or(true);
        tracing::debug!(
            "ZeRO Stage 2 - Allgather bucket size: {}",
            allgather_bucket_size
        );
        tracing::debug!("ZeRO Stage 2 - Reduce bucket size: {}", reduce_bucket_size);
        tracing::debug!("ZeRO Stage 2 - Overlap communication: {}", overlap_comm);
        Ok(())
    }

    /// Stage 3 setup (parameter partitioning). Currently only resolves
    /// defaults and emits tracing; no other side effects in this body.
    fn initialize_zero_stage3(&self) -> TorshResult<()> {
        tracing::info!("Initializing DeepSpeed ZeRO Stage 3 (parameter partitioning)");
        // Fallback is unreachable when called through `initialize` (validation
        // requires the value for Stage 3) but kept for safety.
        let max_live_params = self
            .config
            .zero_optimization
            .stage3_max_live_parameters
            .unwrap_or(Self::DEFAULT_MAX_LIVE_PARAMETERS);
        let max_reuse_distance = self
            .config
            .zero_optimization
            .stage3_max_reuse_distance
            .unwrap_or(Self::DEFAULT_MAX_REUSE_DISTANCE);
        let prefetch_bucket_size = self
            .config
            .zero_optimization
            .stage3_prefetch_bucket_size
            .unwrap_or(Self::DEFAULT_PREFETCH_BUCKET_SIZE);
        tracing::debug!("ZeRO Stage 3 - Max live parameters: {}", max_live_params);
        tracing::debug!("ZeRO Stage 3 - Max reuse distance: {}", max_reuse_distance);
        tracing::debug!(
            "ZeRO Stage 3 - Prefetch bucket size: {}",
            prefetch_bucket_size
        );
        Ok(())
    }

    /// Read access to the wrapped configuration.
    pub fn config(&self) -> &DeepSpeedConfig {
        &self.config
    }

    /// Whether `initialize()` has completed successfully.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Translates this DeepSpeed configuration into the crate's FSDP config.
    ///
    /// Mapping: Stage 0 → `NoShard`, Stages 1/2 → `ShardGradOp`,
    /// Stage 3 → `FullShard`; enabled FP16 becomes an all-F16 mixed-precision
    /// policy; CPU offload is on when either offload section is present.
    ///
    /// NOTE(review): Stage 1 partitions only optimizer state, so mapping it
    /// to `ShardGradOp` (gradients + optimizer state) is an approximation —
    /// confirm this is the intended translation.
    pub fn to_fsdp_config(&self) -> TorshResult<crate::fsdp::FsdpConfig> {
        use crate::fsdp::{FsdpConfig, MixedPrecisionConfig, ShardingStrategy};
        let sharding_strategy = match self.config.zero_optimization.stage {
            ZeroStage::Stage0 => ShardingStrategy::NoShard,
            ZeroStage::Stage1 => ShardingStrategy::ShardGradOp,
            ZeroStage::Stage2 => ShardingStrategy::ShardGradOp,
            ZeroStage::Stage3 => ShardingStrategy::FullShard,
        };
        // FP16 must be both present and enabled to produce a policy.
        let mixed_precision = match self.config.fp16 {
            Some(ref fp16_config) if fp16_config.enabled => Some(MixedPrecisionConfig {
                param_dtype: torsh_core::DType::F16,
                reduce_dtype: torsh_core::DType::F16,
                buffer_dtype: torsh_core::DType::F16,
                keep_low_precision_grads: false,
            }),
            _ => None,
        };
        Ok(FsdpConfig {
            min_num_params: 1000,
            auto_wrap_policy: crate::fsdp::AutoWrapPolicy::SizeBasedAutoWrap {
                min_num_params: 1000,
            },
            sharding_strategy,
            mixed_precision,
            cpu_offload: self.config.zero_optimization.offload_optimizer.is_some()
                || self.config.zero_optimization.offload_param.is_some(),
            memory_config: crate::fsdp::MemoryConfig::default(),
            backward_prefetch: crate::fsdp::BackwardPrefetch::BackwardPre,
        })
    }

    /// Builds a gradient-compression config when `reduce_scatter` is enabled;
    /// returns `None` otherwise.
    ///
    /// NOTE(review): the TopK/10% settings below are fixed here rather than
    /// derived from the DeepSpeed config — confirm they are intended defaults.
    pub fn to_gradient_compression_config(
        &self,
    ) -> Option<crate::gradient_compression::CompressionConfig> {
        if self
            .config
            .zero_optimization
            .reduce_scatter
            .unwrap_or(false)
        {
            Some(crate::gradient_compression::CompressionConfig {
                method: crate::gradient_compression::CompressionMethod::TopK { k: 0.1 },
                compression_ratio: 0.1,
                error_feedback: true,
                error_feedback_momentum: 0.9,
                memory_efficient: true,
                warmup_steps: 5,
            })
        } else {
            None
        }
    }

    /// Returns a snapshot of the current configuration and init state.
    ///
    /// NOTE(review): `cpu_offload_enabled` reflects the
    /// `zero_force_ds_cpu_optimizer` flag, not the presence of
    /// `offload_optimizer`/`offload_param` (which is what `to_fsdp_config`
    /// uses) — confirm which signal stats should report.
    pub fn get_stats(&self) -> DeepSpeedStats {
        DeepSpeedStats {
            zero_stage: self.config.zero_optimization.stage,
            initialized: self.initialized,
            fp16_enabled: self
                .config
                .fp16
                .as_ref()
                .map(|c| c.enabled)
                .unwrap_or(false),
            cpu_offload_enabled: self.config.zero_force_ds_cpu_optimizer.unwrap_or(false),
            activation_checkpointing_enabled: self
                .config
                .activation_checkpointing
                .as_ref()
                .map(|c| c.partition_activations.unwrap_or(false))
                .unwrap_or(false),
        }
    }
}
impl Default for DeepSpeedIntegration {
fn default() -> Self {
Self::new(DeepSpeedConfig::default())
}
}
/// Read-only snapshot of integration state, produced by
/// `DeepSpeedIntegration::get_stats`.
#[derive(Debug, Clone)]
pub struct DeepSpeedStats {
    /// Configured ZeRO stage.
    pub zero_stage: ZeroStage,
    /// Whether `initialize()` has completed successfully.
    pub initialized: bool,
    /// Whether an FP16 section is present with `enabled: true`.
    pub fp16_enabled: bool,
    /// Derived from `zero_force_ds_cpu_optimizer`, not from the presence of
    /// the offload config sections.
    pub cpu_offload_enabled: bool,
    /// Whether `activation_checkpointing.partition_activations` is `Some(true)`.
    pub activation_checkpointing_enabled: bool,
}
/// Manual `Default`: Stage 0 (ZeRO disabled) with every optional knob unset,
/// so serialized JSON carries no tuning fields. Manual rather than derived
/// because `ZeroStage` does not implement `Default`.
impl Default for ZeroOptimizationConfig {
    fn default() -> Self {
        Self {
            stage: ZeroStage::Stage0,
            allgather_bucket_size: None,
            reduce_bucket_size: None,
            overlap_comm: None,
            contiguous_gradients: None,
            sub_group_size: None,
            reduce_scatter: None,
            allgather_partitions: None,
            stage3_max_live_parameters: None,
            stage3_max_reuse_distance: None,
            stage3_prefetch_bucket_size: None,
            stage3_param_persistence_threshold: None,
            offload_optimizer: None,
            offload_param: None,
        }
    }
}
pub mod utils {
use super::*;
pub fn create_zero_stage1_config() -> DeepSpeedConfig {
DeepSpeedConfig {
zero_optimization: ZeroOptimizationConfig {
stage: ZeroStage::Stage1,
overlap_comm: Some(true),
contiguous_gradients: Some(true),
reduce_bucket_size: Some(2e8 as u64),
..Default::default()
},
..Default::default()
}
}
pub fn create_zero_stage2_config() -> DeepSpeedConfig {
DeepSpeedConfig {
zero_optimization: ZeroOptimizationConfig {
stage: ZeroStage::Stage2,
overlap_comm: Some(true),
contiguous_gradients: Some(true),
reduce_bucket_size: Some(2e8 as u64),
allgather_bucket_size: Some(2e8 as u64),
..Default::default()
},
..Default::default()
}
}
pub fn create_zero_stage3_config() -> DeepSpeedConfig {
DeepSpeedConfig {
zero_optimization: ZeroOptimizationConfig {
stage: ZeroStage::Stage3,
overlap_comm: Some(true),
contiguous_gradients: Some(true),
reduce_bucket_size: Some(2e8 as u64),
allgather_bucket_size: Some(2e8 as u64),
stage3_max_live_parameters: Some(1e9 as u64),
stage3_max_reuse_distance: Some(1000),
stage3_prefetch_bucket_size: Some(5e8 as u64),
stage3_param_persistence_threshold: Some(1e6 as u64),
..Default::default()
},
..Default::default()
}
}
pub fn create_fp16_config() -> DeepSpeedConfig {
DeepSpeedConfig {
zero_optimization: ZeroOptimizationConfig {
stage: ZeroStage::Stage2,
overlap_comm: Some(true),
contiguous_gradients: Some(true),
reduce_bucket_size: Some(2e8 as u64),
allgather_bucket_size: Some(2e8 as u64),
..Default::default()
},
fp16: Some(FP16Config {
enabled: true,
loss_scale: None,
loss_scale_window: Some(1000),
hysteresis: Some(2),
min_loss_scale: Some(1.0),
initial_scale_power: Some(16),
}),
..Default::default()
}
}
pub fn create_cpu_offload_config() -> DeepSpeedConfig {
DeepSpeedConfig {
zero_optimization: ZeroOptimizationConfig {
stage: ZeroStage::Stage3,
overlap_comm: Some(true),
contiguous_gradients: Some(true),
reduce_bucket_size: Some(2e8 as u64),
allgather_bucket_size: Some(2e8 as u64),
stage3_max_live_parameters: Some(1e9 as u64),
stage3_max_reuse_distance: Some(1000),
stage3_prefetch_bucket_size: Some(5e8 as u64),
stage3_param_persistence_threshold: Some(1e6 as u64),
offload_optimizer: Some(OffloadOptimizerConfig {
device: "cpu".to_string(),
nvme_path: None,
pin_memory: Some(false),
buffer_count: Some(4),
fast_init: Some(false),
}),
offload_param: Some(OffloadParamConfig {
device: "cpu".to_string(),
nvme_path: None,
pin_memory: Some(false),
buffer_count: Some(4),
max_in_cpu: Some(1e9 as u64),
}),
..Default::default()
},
zero_force_ds_cpu_optimizer: Some(true),
..Default::default()
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deepspeed_config_serialization() {
        // Round-trip a Stage 2 preset through JSON; key fields must survive.
        let original = utils::create_zero_stage2_config();
        let json = serde_json::to_string(&original).unwrap();
        let restored: DeepSpeedConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(
            original.zero_optimization.stage,
            restored.zero_optimization.stage
        );
        assert_eq!(
            original.zero_optimization.overlap_comm,
            restored.zero_optimization.overlap_comm
        );
    }

    #[test]
    fn test_deepspeed_integration_initialization() {
        // A fresh integration starts uninitialized; a successful initialize
        // flips the flag.
        let mut ds = DeepSpeedIntegration::new(utils::create_zero_stage1_config());
        assert!(!ds.is_initialized());
        ds.initialize().unwrap();
        assert!(ds.is_initialized());
    }

    #[test]
    fn test_deepspeed_to_fsdp_config() {
        // ZeRO Stage 3 must map onto FSDP's fully-sharded strategy.
        let ds = DeepSpeedIntegration::new(utils::create_zero_stage3_config());
        let fsdp = ds.to_fsdp_config().unwrap();
        assert_eq!(
            fsdp.sharding_strategy,
            crate::fsdp::ShardingStrategy::FullShard
        );
    }

    #[test]
    fn test_deepspeed_stats() {
        // The stats snapshot reflects the FP16 preset before initialization.
        let ds = DeepSpeedIntegration::new(utils::create_fp16_config());
        let stats = ds.get_stats();
        assert_eq!(stats.zero_stage, ZeroStage::Stage2);
        assert!(!stats.initialized);
        assert!(stats.fp16_enabled);
    }

    #[test]
    fn test_deepspeed_config_validation() {
        // Removing the Stage 3 required field must make initialize fail.
        let mut config = utils::create_zero_stage3_config();
        config.zero_optimization.stage3_max_live_parameters = None;
        let mut ds = DeepSpeedIntegration::new(config);
        assert!(ds.initialize().is_err());
    }
}