use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct Zero3CpuOffloadConfig {
pub offload_params: bool,
pub offload_grads: bool,
pub offload_optimizer_states: bool,
pub cpu_memory_budget: usize,
pub gpu_param_memory_budget: usize,
pub max_gpu_memory_mb: usize,
pub max_cpu_memory_mb: usize,
pub prefetch_buffer_size: usize,
pub async_prefetch: bool,
pub overlap_computation: bool,
pub pin_cpu_memory: bool,
pub cpu_compression: CpuCompressionMethod,
pub auto_memory_management: AutoMemoryStrategy,
}
impl Default for Zero3CpuOffloadConfig {
fn default() -> Self {
Self {
offload_params: true,
offload_grads: true,
offload_optimizer_states: true,
cpu_memory_budget: 32 * 1024 * 1024 * 1024, gpu_param_memory_budget: 2 * 1024 * 1024 * 1024, max_gpu_memory_mb: 8 * 1024, max_cpu_memory_mb: 64 * 1024, prefetch_buffer_size: 16,
async_prefetch: true,
overlap_computation: true,
pin_cpu_memory: true,
cpu_compression: CpuCompressionMethod::None,
auto_memory_management: AutoMemoryStrategy::Aggressive,
}
}
}
impl Zero3CpuOffloadConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_offload_params(mut self, offload: bool) -> Self {
self.offload_params = offload;
self
}
pub fn with_offload_grads(mut self, offload: bool) -> Self {
self.offload_grads = offload;
self
}
pub fn with_offload_optimizer_states(mut self, offload: bool) -> Self {
self.offload_optimizer_states = offload;
self
}
pub fn with_cpu_memory_budget(mut self, budget: usize) -> Self {
self.cpu_memory_budget = budget;
self
}
pub fn with_gpu_param_memory_budget(mut self, budget: usize) -> Self {
self.gpu_param_memory_budget = budget;
self
}
pub fn with_prefetch_buffer_size(mut self, size: usize) -> Self {
self.prefetch_buffer_size = size;
self
}
pub fn with_compression(mut self, compression: CpuCompressionMethod) -> Self {
self.cpu_compression = compression;
self
}
pub fn with_memory_strategy(mut self, strategy: AutoMemoryStrategy) -> Self {
self.auto_memory_management = strategy;
self
}
pub fn with_async_prefetch(mut self, async_prefetch: bool) -> Self {
self.async_prefetch = async_prefetch;
self
}
pub fn with_overlap_computation(mut self, overlap: bool) -> Self {
self.overlap_computation = overlap;
self
}
pub fn with_pin_cpu_memory(mut self, pin: bool) -> Self {
self.pin_cpu_memory = pin;
self
}
pub fn validate(&self) -> Result<(), String> {
if self.cpu_memory_budget == 0 {
return Err("CPU memory budget cannot be zero".to_string());
}
if self.gpu_param_memory_budget == 0 {
return Err("GPU parameter memory budget cannot be zero".to_string());
}
if self.prefetch_buffer_size == 0 {
return Err("Prefetch buffer size cannot be zero".to_string());
}
if self.max_gpu_memory_mb == 0 {
return Err("Maximum GPU memory cannot be zero".to_string());
}
if self.max_cpu_memory_mb == 0 {
return Err("Maximum CPU memory cannot be zero".to_string());
}
let gpu_budget_mb = self.gpu_param_memory_budget / (1024 * 1024);
if gpu_budget_mb > self.max_gpu_memory_mb {
return Err(format!(
"GPU parameter memory budget ({} MB) exceeds maximum GPU memory ({} MB)",
gpu_budget_mb, self.max_gpu_memory_mb
));
}
let cpu_budget_mb = self.cpu_memory_budget / (1024 * 1024);
if cpu_budget_mb > self.max_cpu_memory_mb {
return Err(format!(
"CPU memory budget ({} MB) exceeds maximum CPU memory ({} MB)",
cpu_budget_mb, self.max_cpu_memory_mb
));
}
Ok(())
}
pub fn compression_ratio(&self) -> f32 {
match self.cpu_compression {
CpuCompressionMethod::None => 1.0,
CpuCompressionMethod::FP16 => 0.5,
CpuCompressionMethod::BF16 => 0.5,
CpuCompressionMethod::INT8 => 0.25,
CpuCompressionMethod::Quantization => 0.25,
CpuCompressionMethod::LosslessCompression => 0.7, }
}
pub fn effective_cpu_memory_budget(&self) -> usize {
(self.cpu_memory_budget as f32 / self.compression_ratio()) as usize
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CpuCompressionMethod {
None,
FP16,
BF16,
INT8,
Quantization,
LosslessCompression,
}
impl CpuCompressionMethod {
pub fn ratio(&self) -> f32 {
match self {
CpuCompressionMethod::None => 1.0,
CpuCompressionMethod::FP16 => 0.5,
CpuCompressionMethod::BF16 => 0.5,
CpuCompressionMethod::INT8 => 0.25,
CpuCompressionMethod::Quantization => 0.25,
CpuCompressionMethod::LosslessCompression => 0.7,
}
}
pub fn is_lossy(&self) -> bool {
matches!(
self,
CpuCompressionMethod::FP16
| CpuCompressionMethod::BF16
| CpuCompressionMethod::INT8
| CpuCompressionMethod::Quantization
)
}
pub fn description(&self) -> &'static str {
match self {
CpuCompressionMethod::None => "No compression",
CpuCompressionMethod::FP16 => "16-bit floating point",
CpuCompressionMethod::BF16 => "BFloat16",
CpuCompressionMethod::INT8 => "8-bit integer quantization",
CpuCompressionMethod::Quantization => "Advanced quantization",
CpuCompressionMethod::LosslessCompression => "Lossless compression",
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AutoMemoryStrategy {
Conservative,
Balanced,
Aggressive,
Extreme,
}
impl AutoMemoryStrategy {
pub fn pressure_threshold(&self) -> f32 {
match self {
AutoMemoryStrategy::Conservative => 0.9, AutoMemoryStrategy::Balanced => 0.75, AutoMemoryStrategy::Aggressive => 0.6, AutoMemoryStrategy::Extreme => 0.4, }
}
pub fn aggressiveness(&self) -> f32 {
match self {
AutoMemoryStrategy::Conservative => 0.2,
AutoMemoryStrategy::Balanced => 0.5,
AutoMemoryStrategy::Aggressive => 0.8,
AutoMemoryStrategy::Extreme => 1.0,
}
}
pub fn description(&self) -> &'static str {
match self {
AutoMemoryStrategy::Conservative => "Conservative - minimal offloading",
AutoMemoryStrategy::Balanced => "Balanced - moderate offloading",
AutoMemoryStrategy::Aggressive => "Aggressive - maximize CPU utilization",
AutoMemoryStrategy::Extreme => "Extreme - offload everything possible",
}
}
}
#[derive(Debug, Clone)]
pub struct Zero3RankMapping {
rank: usize,
world_size: usize,
}
impl Zero3RankMapping {
pub fn new(rank: usize, world_size: usize) -> Self {
assert!(rank < world_size, "Rank must be less than world size");
Self { rank, world_size }
}
pub fn rank(&self) -> usize {
self.rank
}
pub fn world_size(&self) -> usize {
self.world_size
}
pub fn owns_partition(&self, partition_idx: usize) -> bool {
partition_idx % self.world_size == self.rank
}
pub fn get_parameter_owner(&self, param_idx: usize) -> usize {
param_idx % self.world_size
}
pub fn owned_partitions(&self, total_partitions: usize) -> Vec<usize> {
(0..total_partitions)
.filter(|&i| self.owns_partition(i))
.collect()
}
pub fn owned_partition_count(&self, total_partitions: usize) -> usize {
let base_count = total_partitions / self.world_size;
let remainder = total_partitions % self.world_size;
if self.rank < remainder {
base_count + 1
} else {
base_count
}
}
pub fn global_to_local_partition(&self, global_idx: usize) -> Option<usize> {
if self.owns_partition(global_idx) {
Some(global_idx / self.world_size)
} else {
None
}
}
pub fn local_to_global_partition(&self, local_idx: usize) -> usize {
local_idx * self.world_size + self.rank
}
pub fn communication_group(&self, param_indices: &[usize]) -> Vec<usize> {
let mut ranks = std::collections::HashSet::new();
for ¶m_idx in param_indices {
ranks.insert(self.get_parameter_owner(param_idx));
}
let mut result: Vec<usize> = ranks.into_iter().collect();
result.sort();
result
}
}
#[derive(Debug)]
pub struct ModelParameters {
pub parameter_count: usize,
pub parameter_names: Vec<String>,
pub parameter_shapes: HashMap<String, Vec<usize>>,
pub total_memory_bytes: usize,
}
impl ModelParameters {
pub fn new() -> Self {
Self {
parameter_count: 0,
parameter_names: Vec::new(),
parameter_shapes: HashMap::new(),
total_memory_bytes: 0,
}
}
pub fn add_parameter(&mut self, name: String, shape: Vec<usize>) {
let param_size = shape.iter().product::<usize>();
self.parameter_count += param_size;
self.total_memory_bytes += param_size * std::mem::size_of::<f32>();
self.parameter_shapes.insert(name.clone(), shape);
self.parameter_names.push(name);
}
pub fn has_parameter(&self, name: &str) -> bool {
self.parameter_shapes.contains_key(name)
}
pub fn add_parameter_with_size(
&mut self,
name: String,
shape: Vec<usize>,
element_size: usize,
) {
let param_size = shape.iter().product::<usize>();
self.parameter_count += param_size;
self.total_memory_bytes += param_size * element_size;
self.parameter_shapes.insert(name.clone(), shape);
self.parameter_names.push(name);
}
pub fn get_parameter_shape(&self, name: &str) -> Option<&Vec<usize>> {
self.parameter_shapes.get(name)
}
pub fn get_parameter_size(&self, name: &str) -> Option<usize> {
self.parameter_shapes
.get(name)
.map(|shape| shape.iter().product::<usize>())
}
pub fn total_parameters(&self) -> usize {
self.parameter_names.len()
}
pub fn memory_usage_mb(&self) -> f64 {
self.total_memory_bytes as f64 / (1024.0 * 1024.0)
}
pub fn get_statistics(&self) -> ModelParameterStats {
if self.parameter_names.is_empty() {
return ModelParameterStats::default();
}
let mut sizes: Vec<usize> = self
.parameter_shapes
.values()
.map(|shape| shape.iter().product::<usize>())
.collect();
sizes.sort();
let total_elements = sizes.iter().sum::<usize>();
let mean_size = total_elements as f64 / sizes.len() as f64;
let median_size = if sizes.len() % 2 == 0 {
(sizes[sizes.len() / 2 - 1] + sizes[sizes.len() / 2]) as f64 / 2.0
} else {
sizes[sizes.len() / 2] as f64
};
ModelParameterStats {
total_parameters: self.parameter_names.len(),
total_elements,
mean_parameter_size: mean_size,
median_parameter_size: median_size,
min_parameter_size: *sizes.first().unwrap_or(&0),
max_parameter_size: *sizes.last().unwrap_or(&0),
total_memory_bytes: self.total_memory_bytes,
}
}
}
impl Default for ModelParameters {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct ModelParameterStats {
pub total_parameters: usize,
pub total_elements: usize,
pub mean_parameter_size: f64,
pub median_parameter_size: f64,
pub min_parameter_size: usize,
pub max_parameter_size: usize,
pub total_memory_bytes: usize,
}
impl Default for ModelParameterStats {
fn default() -> Self {
Self {
total_parameters: 0,
total_elements: 0,
mean_parameter_size: 0.0,
median_parameter_size: 0.0,
min_parameter_size: 0,
max_parameter_size: 0,
total_memory_bytes: 0,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_zero3_config_default() {
let config = Zero3CpuOffloadConfig::default();
assert!(config.offload_params);
assert!(config.offload_grads);
assert!(config.offload_optimizer_states);
assert!(config.async_prefetch);
assert_eq!(config.cpu_compression, CpuCompressionMethod::None);
assert_eq!(
config.auto_memory_management,
AutoMemoryStrategy::Aggressive
);
}
#[test]
fn test_zero3_config_builder() {
let config = Zero3CpuOffloadConfig::new()
.with_offload_params(false)
.with_compression(CpuCompressionMethod::FP16)
.with_memory_strategy(AutoMemoryStrategy::Conservative)
.with_prefetch_buffer_size(32);
assert!(!config.offload_params);
assert_eq!(config.cpu_compression, CpuCompressionMethod::FP16);
assert_eq!(
config.auto_memory_management,
AutoMemoryStrategy::Conservative
);
assert_eq!(config.prefetch_buffer_size, 32);
}
#[test]
fn test_zero3_config_validation() {
let config = Zero3CpuOffloadConfig::default();
assert!(config.validate().is_ok());
let mut invalid_config = config.clone();
invalid_config.cpu_memory_budget = 0;
assert!(invalid_config.validate().is_err());
let mut invalid_config = config.clone();
invalid_config.gpu_param_memory_budget = 0;
assert!(invalid_config.validate().is_err());
}
#[test]
fn test_compression_methods() {
assert_eq!(CpuCompressionMethod::None.ratio(), 1.0);
assert_eq!(CpuCompressionMethod::FP16.ratio(), 0.5);
assert_eq!(CpuCompressionMethod::INT8.ratio(), 0.25);
assert!(!CpuCompressionMethod::None.is_lossy());
assert!(CpuCompressionMethod::FP16.is_lossy());
assert!(!CpuCompressionMethod::LosslessCompression.is_lossy());
}
#[test]
fn test_memory_strategies() {
assert_eq!(AutoMemoryStrategy::Conservative.pressure_threshold(), 0.9);
assert_eq!(AutoMemoryStrategy::Aggressive.pressure_threshold(), 0.6);
assert_eq!(AutoMemoryStrategy::Conservative.aggressiveness(), 0.2);
assert_eq!(AutoMemoryStrategy::Extreme.aggressiveness(), 1.0);
}
#[test]
fn test_rank_mapping() {
let mapping = Zero3RankMapping::new(1, 4);
assert_eq!(mapping.rank(), 1);
assert_eq!(mapping.world_size(), 4);
assert!(mapping.owns_partition(1)); assert!(mapping.owns_partition(5)); assert!(!mapping.owns_partition(0)); assert!(!mapping.owns_partition(2));
assert_eq!(mapping.get_parameter_owner(5), 1);
assert_eq!(mapping.get_parameter_owner(8), 0);
let owned = mapping.owned_partitions(10);
assert_eq!(owned, vec![1, 5, 9]);
assert_eq!(mapping.owned_partition_count(10), 3); assert_eq!(mapping.owned_partition_count(8), 2); }
#[test]
fn test_model_parameters() {
let mut params = ModelParameters::new();
params.add_parameter("layer1.weight".to_string(), vec![100, 50]);
params.add_parameter("layer1.bias".to_string(), vec![50]);
assert_eq!(params.total_parameters(), 2);
assert_eq!(params.parameter_count, 5050); assert_eq!(params.get_parameter_size("layer1.weight"), Some(5000));
assert_eq!(params.get_parameter_size("layer1.bias"), Some(50));
let stats = params.get_statistics();
assert_eq!(stats.total_parameters, 2);
assert_eq!(stats.total_elements, 5050);
assert_eq!(stats.min_parameter_size, 50);
assert_eq!(stats.max_parameter_size, 5000);
}
#[test]
fn test_rank_mapping_communication_group() {
let mapping = Zero3RankMapping::new(1, 4);
let param_indices = vec![0, 1, 4, 5, 8, 9];
let comm_group = mapping.communication_group(¶m_indices);
assert_eq!(comm_group, vec![0, 1]);
}
#[test]
fn test_effective_cpu_memory_budget() {
let config = Zero3CpuOffloadConfig::new()
.with_cpu_memory_budget(1000)
.with_compression(CpuCompressionMethod::FP16);
assert_eq!(config.effective_cpu_memory_budget(), 2000);
let config_no_compression = Zero3CpuOffloadConfig::new()
.with_cpu_memory_budget(1000)
.with_compression(CpuCompressionMethod::None);
assert_eq!(config_no_compression.effective_cpu_memory_budget(), 1000);
}
}