use crate::fsdp::{AutoWrapPolicy, MemoryConfig};
use crate::{TorshDistributedError, TorshResult};
use serde::{Deserialize, Serialize};
use std::path::Path;
use torsh_core::DType;
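/// Top-level, JSON-serializable configuration mirroring the feature set of
/// Meta's FairScale library. Every section is optional; `None` disables the
/// corresponding feature.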
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScaleConfig {
pub fsdp: Option<FairScaleFsdpConfig>,
pub oss: Option<FairScaleOssConfig>,
pub sharded_grad_scaler: Option<FairScaleGradScalerConfig>,
pub activation_checkpointing: Option<FairScaleActivationCheckpointingConfig>,
pub pipeline_parallelism: Option<FairScalePipelineConfig>,
pub memory_optimization: Option<FairScaleMemoryOptimizationConfig>,
}
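/// Fully Sharded Data Parallel (FSDP) settings. `None` fields fall back to
/// the defaults resolved in `setup_fsdp` (100_000 parameters for size-based
/// auto-wrap, 25 MB buckets, resharding after forward, and so on).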
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScaleFsdpConfig {
pub auto_wrap_policy: Option<FairScaleAutoWrapPolicy>,
pub min_num_params: Option<u64>,
pub wrapper_cls: Option<String>,
pub mixed_precision: Option<bool>,
pub flatten_parameters: Option<bool>,
pub bucket_cap_mb: Option<f32>,
pub compute_dtype: Option<String>,
pub buffer_dtype: Option<String>,
pub reshard_after_forward: Option<bool>,
pub move_grads_to_cpu: Option<bool>,
pub move_params_to_cpu: Option<bool>,
}
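/// Policy for automatically wrapping submodules into FSDP units.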
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FairScaleAutoWrapPolicy {
None,
SizeBased,
TransformerBased,
CustomFunction,
}
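/// Optimizer State Sharding (OSS) settings. OSS partitions optimizer state
/// across data-parallel ranks in the style of ZeRO stage 1, wrapping the
/// named base optimizer.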
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScaleOssConfig {
pub optimizer: String,
pub broadcast_buffers: Option<bool>,
pub compress_gradients: Option<bool>,
pub gradient_compression: Option<String>,
pub partition_optimizer: Option<bool>,
pub gradient_predivide_factor: Option<f32>,
pub gradient_postdivide_factor: Option<f32>,
}
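/// Sharded gradient scaler settings for mixed-precision training; `None`
/// fields fall back to the usual AMP defaults resolved in
/// `setup_grad_scaler`.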
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScaleGradScalerConfig {
pub init_scale: Option<f32>,
pub growth_factor: Option<f32>,
pub backoff_factor: Option<f32>,
pub growth_interval: Option<u32>,
pub enabled: Option<bool>,
}
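/// Activation checkpointing settings, trading recomputation during the
/// backward pass for activation memory.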
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScaleActivationCheckpointingConfig {
pub strategy: FairScaleCheckpointingStrategy,
pub checkpoint_ratio: Option<f32>,
pub offload_to_cpu: Option<bool>,
pub checkpoint_every_n_layers: Option<u32>,
pub use_gradient_checkpointing: Option<bool>,
}
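/// How layers are selected for activation checkpointing.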
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FairScaleCheckpointingStrategy {
None,
Uniform,
Selective,
Adaptive,
}
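/// Pipeline-parallelism settings. `stages` must be positive and no larger
/// than the world size (enforced in `validate_config`).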
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScalePipelineConfig {
pub stages: u32,
pub micro_batch_size: Option<u32>,
pub balance_mode: Option<FairScaleBalanceMode>,
pub schedule: Option<FairScalePipelineSchedule>,
pub checkpoint_activation: Option<bool>,
pub distributed_backend: Option<String>,
}
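/// How model layers are balanced across pipeline stages.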
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FairScaleBalanceMode {
Auto,
Manual,
Parameters,
Time,
}
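/// Pipeline execution schedule: GPipe, one-forward-one-backward (1F1B), or
/// interleaved 1F1B.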
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FairScalePipelineSchedule {
GPipe,
OneF1B,
Interleaved,
}
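/// Miscellaneous memory-optimization toggles; every option defaults to off.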
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScaleMemoryOptimizationConfig {
pub cpu_offloading: Option<bool>,
pub parameter_offloading: Option<bool>,
pub optimizer_offloading: Option<bool>,
pub gradient_compression: Option<bool>,
pub memory_defragmentation: Option<bool>,
pub lazy_parameter_init: Option<bool>,
}
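/// Runtime counters accumulated by [`FairScaleIntegration`].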
#[derive(Debug, Clone, Default)]
pub struct FairScaleStats {
pub fsdp_ops: u64,
pub fsdp_time_sec: f64,
pub oss_ops: u64,
pub oss_time_sec: f64,
pub memory_saved_bytes: u64,
pub checkpointed_activations: u64,
pub pipeline_efficiency: f64,
pub gradient_scaling_events: u64,
pub average_shard_size: f64,
}
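/// Bridges a [`FairScaleConfig`] onto torsh's native distributed primitives
/// and tracks per-feature statistics.
///
/// A minimal sketch of the intended flow (not compiled as a doctest):
///
/// ```ignore
/// let mut fs = FairScaleIntegration::new(FairScaleConfig::default());
/// fs.initialize(0, 4, 0, 2)?; // rank 0 of 4, local rank 0 of 2
/// let fsdp_config = fs.to_fsdp_config()?;
/// let pipeline_config = fs.to_pipeline_config()?;
/// ```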
pub struct FairScaleIntegration {
config: FairScaleConfig,
stats: FairScaleStats,
initialized: bool,
rank: u32,
world_size: u32,
local_rank: u32,
local_size: u32,
}
impl FairScaleIntegration {
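    /// Creates an uninitialized integration with a single-process layout
    /// (rank 0, world size 1); call [`initialize`](Self::initialize) before
    /// recording any operations.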
pub fn new(config: FairScaleConfig) -> Self {
Self {
config,
stats: FairScaleStats::default(),
initialized: false,
rank: 0,
world_size: 1,
local_rank: 0,
local_size: 1,
}
}
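    /// Loads a configuration from a JSON file.
    ///
    /// All sections are optional and omitted `Option` fields deserialize to
    /// `None`, so a minimal file can configure a single feature:
    ///
    /// ```json
    /// {
    ///   "fsdp": {
    ///     "auto_wrap_policy": "SizeBased",
    ///     "min_num_params": 100000,
    ///     "mixed_precision": true
    ///   }
    /// }
    /// ```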
pub fn from_file<P: AsRef<Path>>(path: P) -> TorshResult<Self> {
let content = std::fs::read_to_string(path).map_err(|e| {
TorshDistributedError::configuration_error(format!(
"Failed to read FairScale config file: {}",
e
))
})?;
let config: FairScaleConfig = serde_json::from_str(&content).map_err(|e| {
TorshDistributedError::configuration_error(format!(
"Failed to parse FairScale config: {}",
e
))
})?;
Ok(Self::new(config))
}
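    /// Binds the integration to a process-group layout, validates the
    /// configuration, and runs the per-feature setup steps. Fails if called
    /// more than once or if validation rejects the configuration.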
pub fn initialize(
&mut self,
rank: u32,
world_size: u32,
local_rank: u32,
local_size: u32,
) -> TorshResult<()> {
if self.initialized {
return Err(TorshDistributedError::configuration_error(
"FairScale integration already initialized",
));
}
self.rank = rank;
self.world_size = world_size;
self.local_rank = local_rank;
self.local_size = local_size;
self.validate_config()?;
self.setup_fsdp()?;
self.setup_oss()?;
self.setup_grad_scaler()?;
self.setup_activation_checkpointing()?;
self.setup_pipeline_parallelism()?;
self.setup_memory_optimization()?;
self.initialized = true;
tracing::info!(
"FairScale integration initialized - rank: {}, world_size: {}, local_rank: {}",
self.rank,
self.world_size,
self.local_rank
);
Ok(())
}
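    /// Checks cross-field invariants: positive FSDP sizes, pipeline stages
    /// within the world size, and grad-scaler factors in their valid ranges.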
fn validate_config(&self) -> TorshResult<()> {
if let Some(ref fsdp) = self.config.fsdp {
if let Some(min_params) = fsdp.min_num_params {
if min_params == 0 {
return Err(TorshDistributedError::configuration_error(
"FSDP min_num_params must be greater than 0",
));
}
}
if let Some(bucket_cap) = fsdp.bucket_cap_mb {
if bucket_cap <= 0.0 {
return Err(TorshDistributedError::configuration_error(
"FSDP bucket_cap_mb must be greater than 0",
));
}
}
}
if let Some(ref pipeline) = self.config.pipeline_parallelism {
if pipeline.stages == 0 {
return Err(TorshDistributedError::configuration_error(
"Pipeline stages must be greater than 0",
));
}
if pipeline.stages > self.world_size {
return Err(TorshDistributedError::configuration_error(
"Pipeline stages cannot exceed world size",
));
}
if let Some(micro_batch_size) = pipeline.micro_batch_size {
if micro_batch_size == 0 {
return Err(TorshDistributedError::configuration_error(
"Pipeline micro_batch_size must be greater than 0",
));
}
}
}
if let Some(ref grad_scaler) = self.config.sharded_grad_scaler {
if let Some(init_scale) = grad_scaler.init_scale {
if init_scale <= 0.0 {
return Err(TorshDistributedError::configuration_error(
"GradScaler init_scale must be greater than 0",
));
}
}
if let Some(growth_factor) = grad_scaler.growth_factor {
if growth_factor <= 1.0 {
return Err(TorshDistributedError::configuration_error(
"GradScaler growth_factor must be greater than 1",
));
}
}
if let Some(backoff_factor) = grad_scaler.backoff_factor {
if backoff_factor <= 0.0 || backoff_factor >= 1.0 {
return Err(TorshDistributedError::configuration_error(
"GradScaler backoff_factor must be between 0 and 1",
));
}
}
}
Ok(())
}
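    // The setup_* methods below resolve defaults and log the effective
    // settings; actual engine wiring is assumed to happen elsewhere.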
fn setup_fsdp(&self) -> TorshResult<()> {
if let Some(ref fsdp) = self.config.fsdp {
tracing::info!("Setting up FairScale FSDP");
let auto_wrap_policy = fsdp
.auto_wrap_policy
.unwrap_or(FairScaleAutoWrapPolicy::SizeBased);
tracing::debug!("FSDP auto-wrap policy: {:?}", auto_wrap_policy);
let min_params = fsdp.min_num_params.unwrap_or(100_000);
tracing::debug!("FSDP min parameters for auto-wrap: {}", min_params);
let mixed_precision = fsdp.mixed_precision.unwrap_or(false);
tracing::debug!("FSDP mixed precision: {}", mixed_precision);
let flatten_parameters = fsdp.flatten_parameters.unwrap_or(true);
tracing::debug!("FSDP flatten parameters: {}", flatten_parameters);
let bucket_cap_mb = fsdp.bucket_cap_mb.unwrap_or(25.0);
tracing::debug!("FSDP bucket capacity: {} MB", bucket_cap_mb);
let reshard_after_forward = fsdp.reshard_after_forward.unwrap_or(true);
tracing::debug!("FSDP reshard after forward: {}", reshard_after_forward);
if fsdp.move_grads_to_cpu.unwrap_or(false) {
tracing::debug!("FSDP gradient CPU offloading enabled");
}
if fsdp.move_params_to_cpu.unwrap_or(false) {
tracing::debug!("FSDP parameter CPU offloading enabled");
}
}
Ok(())
}
fn setup_oss(&self) -> TorshResult<()> {
if let Some(ref oss) = self.config.oss {
tracing::info!("Setting up FairScale OSS (Optimizer State Sharding)");
tracing::debug!("OSS optimizer: {}", oss.optimizer);
let broadcast_buffers = oss.broadcast_buffers.unwrap_or(true);
tracing::debug!("OSS broadcast buffers: {}", broadcast_buffers);
let compress_gradients = oss.compress_gradients.unwrap_or(false);
tracing::debug!("OSS compress gradients: {}", compress_gradients);
if let Some(ref compression) = oss.gradient_compression {
tracing::debug!("OSS gradient compression algorithm: {}", compression);
}
let partition_optimizer = oss.partition_optimizer.unwrap_or(true);
tracing::debug!("OSS partition optimizer: {}", partition_optimizer);
let predivide_factor = oss.gradient_predivide_factor.unwrap_or(1.0);
tracing::debug!("OSS gradient predivide factor: {}", predivide_factor);
let postdivide_factor = oss.gradient_postdivide_factor.unwrap_or(1.0);
tracing::debug!("OSS gradient postdivide factor: {}", postdivide_factor);
}
Ok(())
}
fn setup_grad_scaler(&self) -> TorshResult<()> {
if let Some(ref grad_scaler) = self.config.sharded_grad_scaler {
tracing::info!("Setting up FairScale ShardedGradScaler");
let init_scale = grad_scaler.init_scale.unwrap_or(2.0_f32.powi(16));
tracing::debug!("GradScaler initial scale: {}", init_scale);
let growth_factor = grad_scaler.growth_factor.unwrap_or(2.0);
tracing::debug!("GradScaler growth factor: {}", growth_factor);
let backoff_factor = grad_scaler.backoff_factor.unwrap_or(0.5);
tracing::debug!("GradScaler backoff factor: {}", backoff_factor);
let growth_interval = grad_scaler.growth_interval.unwrap_or(2000);
tracing::debug!("GradScaler growth interval: {}", growth_interval);
let enabled = grad_scaler.enabled.unwrap_or(true);
tracing::debug!("GradScaler enabled: {}", enabled);
}
Ok(())
}
fn setup_activation_checkpointing(&self) -> TorshResult<()> {
if let Some(ref checkpoint) = self.config.activation_checkpointing {
tracing::info!("Setting up FairScale activation checkpointing");
tracing::debug!(
"Activation checkpointing strategy: {:?}",
checkpoint.strategy
);
let checkpoint_ratio = checkpoint.checkpoint_ratio.unwrap_or(0.5);
tracing::debug!("Activation checkpointing ratio: {}", checkpoint_ratio);
let offload_to_cpu = checkpoint.offload_to_cpu.unwrap_or(false);
tracing::debug!("Activation checkpointing CPU offload: {}", offload_to_cpu);
if let Some(every_n) = checkpoint.checkpoint_every_n_layers {
tracing::debug!("Checkpoint every {} layers", every_n);
}
let use_gradient_checkpointing = checkpoint.use_gradient_checkpointing.unwrap_or(false);
tracing::debug!("Use gradient checkpointing: {}", use_gradient_checkpointing);
}
Ok(())
}
fn setup_pipeline_parallelism(&self) -> TorshResult<()> {
if let Some(ref pipeline) = self.config.pipeline_parallelism {
tracing::info!("Setting up FairScale pipeline parallelism");
tracing::debug!("Pipeline stages: {}", pipeline.stages);
let micro_batch_size = pipeline.micro_batch_size.unwrap_or(1);
tracing::debug!("Pipeline micro batch size: {}", micro_batch_size);
let balance_mode = pipeline.balance_mode.unwrap_or(FairScaleBalanceMode::Auto);
tracing::debug!("Pipeline balance mode: {:?}", balance_mode);
let schedule = pipeline
.schedule
.unwrap_or(FairScalePipelineSchedule::GPipe);
tracing::debug!("Pipeline schedule: {:?}", schedule);
let checkpoint_activation = pipeline.checkpoint_activation.unwrap_or(false);
tracing::debug!("Pipeline checkpoint activation: {}", checkpoint_activation);
if let Some(ref backend) = pipeline.distributed_backend {
tracing::debug!("Pipeline distributed backend: {}", backend);
}
}
Ok(())
}
fn setup_memory_optimization(&self) -> TorshResult<()> {
if let Some(ref memory) = self.config.memory_optimization {
tracing::info!("Setting up FairScale memory optimization");
let cpu_offloading = memory.cpu_offloading.unwrap_or(false);
tracing::debug!("Memory CPU offloading: {}", cpu_offloading);
let parameter_offloading = memory.parameter_offloading.unwrap_or(false);
tracing::debug!("Memory parameter offloading: {}", parameter_offloading);
let optimizer_offloading = memory.optimizer_offloading.unwrap_or(false);
tracing::debug!("Memory optimizer offloading: {}", optimizer_offloading);
let gradient_compression = memory.gradient_compression.unwrap_or(false);
tracing::debug!("Memory gradient compression: {}", gradient_compression);
let memory_defragmentation = memory.memory_defragmentation.unwrap_or(false);
tracing::debug!("Memory defragmentation: {}", memory_defragmentation);
let lazy_parameter_init = memory.lazy_parameter_init.unwrap_or(false);
tracing::debug!(
"Memory lazy parameter initialization: {}",
lazy_parameter_init
);
}
Ok(())
}
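    /// Translates the FairScale-style FSDP section into torsh's native
    /// [`FsdpConfig`](crate::fsdp::FsdpConfig).
    ///
    /// The mapping is heuristic: `move_params_to_cpu` selects `NoShard` plus
    /// CPU offload, `reshard_after_forward` selects `FullShard`, and anything
    /// else falls back to `ShardGradOp`; mixed precision becomes an all-f16
    /// policy.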
pub fn to_fsdp_config(&self) -> TorshResult<crate::fsdp::FsdpConfig> {
use crate::fsdp::{BackwardPrefetch, FsdpConfig, MixedPrecisionConfig, ShardingStrategy};
let sharding_strategy = if let Some(ref fsdp) = self.config.fsdp {
if fsdp.move_params_to_cpu.unwrap_or(false) {
ShardingStrategy::NoShard
} else if fsdp.reshard_after_forward.unwrap_or(true) {
ShardingStrategy::FullShard
} else {
ShardingStrategy::ShardGradOp
}
} else {
ShardingStrategy::FullShard
};
let mixed_precision = if let Some(ref fsdp) = self.config.fsdp {
if fsdp.mixed_precision.unwrap_or(false) {
Some(MixedPrecisionConfig {
                    param_dtype: DType::F16,
                    reduce_dtype: DType::F16,
buffer_dtype: DType::F16,
keep_low_precision_grads: false,
})
} else {
None
}
} else {
None
};
        // Whether parameters are kept on the CPU drives both the offload
        // flag and the memory config below.
        let cpu_offload = self
            .config
            .fsdp
            .as_ref()
            .map(|f| f.move_params_to_cpu.unwrap_or(false))
            .unwrap_or(false);
        let config = FsdpConfig {
            // Fixed torsh-side auto-wrap threshold (asserted by the tests);
            // the FairScale-side min_num_params is only used for validation
            // and setup logging.
            min_num_params: 1000,
            auto_wrap_policy: AutoWrapPolicy::SizeBasedAutoWrap {
                min_num_params: 1000,
            },
            sharding_strategy,
            mixed_precision,
            cpu_offload,
            memory_config: MemoryConfig {
                limit_all_gathers: true,
                use_orig_params: false,
                offload_to_cpu: cpu_offload,
            },
            backward_prefetch: BackwardPrefetch::BackwardPre,
        };
Ok(config)
}
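    /// Translates the pipeline section into torsh's native
    /// [`PipelineConfig`](crate::pipeline::PipelineConfig), or `None` when
    /// pipeline parallelism is not configured. Note that the FairScale
    /// micro-batch *size* is reused directly as the torsh micro-batch
    /// *count*, and both 1F1B variants map onto the interleaved schedule
    /// types.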
pub fn to_pipeline_config(&self) -> TorshResult<Option<crate::pipeline::PipelineConfig>> {
if let Some(ref pipeline) = self.config.pipeline_parallelism {
use crate::pipeline::{PipelineConfig, ScheduleType};
let schedule_type = match pipeline
.schedule
.unwrap_or(FairScalePipelineSchedule::GPipe)
{
FairScalePipelineSchedule::GPipe => ScheduleType::GPipe,
FairScalePipelineSchedule::OneF1B => ScheduleType::OneFOneBInterleaved,
FairScalePipelineSchedule::Interleaved => ScheduleType::InterleavedOneFOneB,
};
let config = PipelineConfig {
num_micro_batches: pipeline.micro_batch_size.unwrap_or(1) as usize,
schedule: schedule_type,
accumulate_gradients: false,
base_tag: 0,
async_comm: true,
comm_timeout_ms: 5000,
};
Ok(Some(config))
} else {
Ok(None)
}
}
pub fn config(&self) -> &FairScaleConfig {
&self.config
}
pub fn stats(&self) -> &FairScaleStats {
&self.stats
}
pub fn is_initialized(&self) -> bool {
self.initialized
}
pub fn rank(&self) -> u32 {
self.rank
}
pub fn world_size(&self) -> u32 {
self.world_size
}
pub fn local_rank(&self) -> u32 {
self.local_rank
}
pub fn local_size(&self) -> u32 {
self.local_size
}
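    /// Records bookkeeping for one FSDP operation. The shard size is
    /// estimated as `parameter_count / world_size` and the memory saving as
    /// the non-resident parameters at 4 bytes each (f32 is assumed).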
pub fn fsdp_operation(
&mut self,
operation_name: &str,
parameter_count: usize,
) -> TorshResult<()> {
if !self.initialized {
return Err(TorshDistributedError::BackendNotInitialized);
}
        // Only the bookkeeping below falls inside this timer.
        let start_time = std::time::Instant::now();
        tracing::debug!(
            "FairScale FSDP operation: {} ({} params)",
            operation_name,
            parameter_count
        );
        self.stats.fsdp_ops += 1;
        self.stats.fsdp_time_sec += start_time.elapsed().as_secs_f64();
        // Each rank keeps 1/world_size of the parameters (integer division;
        // any remainder is ignored for this estimate).
        let shard_size = parameter_count / self.world_size as usize;
        let memory_saved = parameter_count - shard_size;
        // Savings are counted at 4 bytes per parameter (f32 assumed).
        self.stats.memory_saved_bytes += (memory_saved * 4) as u64;
        self.stats.average_shard_size = shard_size as f64;
Ok(())
}
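    /// Records bookkeeping for one OSS operation over the given optimizer
    /// state size.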
pub fn oss_operation(
&mut self,
operation_name: &str,
optimizer_state_size: usize,
) -> TorshResult<()> {
if !self.initialized {
return Err(TorshDistributedError::BackendNotInitialized);
}
let start_time = std::time::Instant::now();
tracing::debug!(
"FairScale OSS operation: {} ({} bytes)",
operation_name,
optimizer_state_size
);
self.stats.oss_ops += 1;
self.stats.oss_time_sec += start_time.elapsed().as_secs_f64();
Ok(())
}
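    /// Records one checkpointed activation and its memory saving; a no-op
    /// unless activation checkpointing is configured.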
pub fn record_activation_checkpoint(&mut self, layer_name: &str, memory_saved: usize) {
if self.config.activation_checkpointing.is_some() {
tracing::debug!(
"Activation checkpoint: {} (saved {} bytes)",
layer_name,
memory_saved
);
self.stats.checkpointed_activations += 1;
self.stats.memory_saved_bytes += memory_saved as u64;
}
}
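    /// Records one gradient-scaling event (growth or backoff of the scale);
    /// a no-op unless the sharded grad scaler is configured.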
pub fn record_gradient_scaling_event(&mut self, scale_factor: f32) {
if self.config.sharded_grad_scaler.is_some() {
tracing::debug!("Gradient scaling event: scale factor {}", scale_factor);
self.stats.gradient_scaling_events += 1;
}
}
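    /// Returns a conservative baseline: size-based FSDP auto-wrap without
    /// mixed precision, OSS around AdamW, and all memory optimizations off.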
pub fn default_config() -> FairScaleConfig {
FairScaleConfig {
fsdp: Some(FairScaleFsdpConfig {
auto_wrap_policy: Some(FairScaleAutoWrapPolicy::SizeBased),
min_num_params: Some(100_000),
wrapper_cls: None,
mixed_precision: Some(false),
flatten_parameters: Some(true),
bucket_cap_mb: Some(25.0),
compute_dtype: Some("float32".to_string()),
buffer_dtype: Some("float32".to_string()),
reshard_after_forward: Some(true),
move_grads_to_cpu: Some(false),
move_params_to_cpu: Some(false),
}),
oss: Some(FairScaleOssConfig {
optimizer: "AdamW".to_string(),
broadcast_buffers: Some(true),
compress_gradients: Some(false),
gradient_compression: None,
partition_optimizer: Some(true),
gradient_predivide_factor: Some(1.0),
gradient_postdivide_factor: Some(1.0),
}),
sharded_grad_scaler: None,
activation_checkpointing: None,
pipeline_parallelism: None,
memory_optimization: Some(FairScaleMemoryOptimizationConfig {
cpu_offloading: Some(false),
parameter_offloading: Some(false),
optimizer_offloading: Some(false),
gradient_compression: Some(false),
memory_defragmentation: Some(false),
lazy_parameter_init: Some(false),
}),
}
}
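    /// Baseline configuration plus f16 mixed precision and a sharded grad
    /// scaler using the usual AMP defaults (initial scale 2^16, growth 2.0,
    /// backoff 0.5, growth interval 2000).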
pub fn config_with_fsdp_mixed_precision() -> FairScaleConfig {
let mut config = Self::default_config();
if let Some(ref mut fsdp) = config.fsdp {
fsdp.mixed_precision = Some(true);
fsdp.compute_dtype = Some("float16".to_string());
fsdp.buffer_dtype = Some("float16".to_string());
}
config.sharded_grad_scaler = Some(FairScaleGradScalerConfig {
init_scale: Some(2.0_f32.powi(16)),
growth_factor: Some(2.0),
backoff_factor: Some(0.5),
growth_interval: Some(2000),
enabled: Some(true),
});
config
}
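    /// Baseline configuration plus `stages` pipeline stages on a 1F1B
    /// schedule with uniform activation checkpointing every 4 layers.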
pub fn config_with_pipeline_parallelism(stages: u32) -> FairScaleConfig {
let mut config = Self::default_config();
config.pipeline_parallelism = Some(FairScalePipelineConfig {
stages,
micro_batch_size: Some(1),
balance_mode: Some(FairScaleBalanceMode::Auto),
schedule: Some(FairScalePipelineSchedule::OneF1B),
checkpoint_activation: Some(true),
distributed_backend: Some("nccl".to_string()),
});
config.activation_checkpointing = Some(FairScaleActivationCheckpointingConfig {
strategy: FairScaleCheckpointingStrategy::Uniform,
checkpoint_ratio: Some(0.5),
offload_to_cpu: Some(false),
checkpoint_every_n_layers: Some(4),
use_gradient_checkpointing: Some(true),
});
config
}
}
impl Default for FairScaleConfig {
fn default() -> Self {
FairScaleIntegration::default_config()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_fairscale_config_validation() {
let config = FairScaleIntegration::default_config();
let mut integration = FairScaleIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
assert!(integration.is_initialized());
assert_eq!(integration.rank(), 0);
assert_eq!(integration.world_size(), 4);
assert_eq!(integration.local_rank(), 0);
}
#[test]
fn test_fairscale_fsdp_config_conversion() {
let config = FairScaleIntegration::default_config();
let mut integration = FairScaleIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
let fsdp_config = integration.to_fsdp_config().unwrap();
assert!(matches!(
fsdp_config.sharding_strategy,
crate::fsdp::ShardingStrategy::FullShard
));
assert!(!fsdp_config.cpu_offload);
assert_eq!(fsdp_config.min_num_params, 1000);
}
#[test]
fn test_fairscale_pipeline_config_conversion() {
let config = FairScaleIntegration::config_with_pipeline_parallelism(4);
let mut integration = FairScaleIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
let pipeline_config = integration.to_pipeline_config().unwrap();
assert!(pipeline_config.is_some());
if let Some(config) = pipeline_config {
assert_eq!(config.num_micro_batches, 1);
assert!(matches!(
config.schedule,
crate::pipeline::ScheduleType::OneFOneBInterleaved
));
}
}
#[test]
fn test_fairscale_fsdp_operations() {
let config = FairScaleIntegration::default_config();
let mut integration = FairScaleIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
assert!(integration.fsdp_operation("forward", 1000000).is_ok());
assert!(integration.fsdp_operation("backward", 1000000).is_ok());
let stats = integration.stats();
assert_eq!(stats.fsdp_ops, 2);
assert!(stats.fsdp_time_sec >= 0.0);
assert!(stats.memory_saved_bytes > 0);
        // 1_000_000 parameters sharded across world_size = 4 ranks.
        assert_eq!(stats.average_shard_size, 250_000.0);
    }
#[test]
fn test_fairscale_oss_operations() {
let config = FairScaleIntegration::default_config();
let mut integration = FairScaleIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
assert!(integration.oss_operation("step", 1024).is_ok());
assert!(integration.oss_operation("zero_grad", 1024).is_ok());
let stats = integration.stats();
assert_eq!(stats.oss_ops, 2);
assert!(stats.oss_time_sec >= 0.0);
}
#[test]
fn test_fairscale_mixed_precision_config() {
let config = FairScaleIntegration::config_with_fsdp_mixed_precision();
let mut integration = FairScaleIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
let fsdp_config = integration.to_fsdp_config().unwrap();
assert!(fsdp_config.mixed_precision.is_some());
if let Some(mp_config) = fsdp_config.mixed_precision {
assert_eq!(mp_config.param_dtype, DType::F16);
assert_eq!(mp_config.reduce_dtype, DType::F16);
assert_eq!(mp_config.buffer_dtype, DType::F16);
}
}
#[test]
fn test_fairscale_invalid_pipeline_stages() {
        let config = FairScaleIntegration::config_with_pipeline_parallelism(0);
        let mut integration = FairScaleIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_err());
}
#[test]
fn test_fairscale_config_serialization() {
let config = FairScaleIntegration::config_with_fsdp_mixed_precision();
let json = serde_json::to_string(&config).unwrap();
assert!(json.contains("float16"));
assert!(json.contains("fsdp"));
assert!(json.contains("sharded_grad_scaler"));
let deserialized: FairScaleConfig = serde_json::from_str(&json).unwrap();
assert!(deserialized.fsdp.is_some());
assert!(deserialized.sharded_grad_scaler.is_some());
}
#[test]
fn test_fairscale_activation_checkpointing() {
let config = FairScaleIntegration::config_with_pipeline_parallelism(4);
let mut integration = FairScaleIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
integration.record_activation_checkpoint("layer1", 1024);
integration.record_activation_checkpoint("layer2", 2048);
let stats = integration.stats();
assert_eq!(stats.checkpointed_activations, 2);
        // 1024 + 2048 bytes recorded by the two checkpoints above.
        assert!(stats.memory_saved_bytes >= 3072);
    }
#[test]
fn test_fairscale_gradient_scaling() {
let config = FairScaleIntegration::config_with_fsdp_mixed_precision();
let mut integration = FairScaleIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
integration.record_gradient_scaling_event(65536.0);
integration.record_gradient_scaling_event(32768.0);
let stats = integration.stats();
assert_eq!(stats.gradient_scaling_events, 2);
}
}