#![allow(dead_code)]
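//! Gradient partitioning and offload storage for ZeRO-3 style CPU offloading.
//!
//! Three pieces cooperate here: [`GradientPartitioner`] splits per-layer
//! gradients into per-rank shards, [`CpuGradientStore`] holds shards in host
//! memory under a configurable budget, and [`GpuGradientBuffer`] stages shards
//! on the device side under its own budget.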
use crate::{TorshDistributedError, TorshResult};
use log::info;
use std::collections::HashMap;
use std::sync::{Mutex, RwLock};
use torsh_tensor::Tensor;
use super::config::{Zero3CpuOffloadConfig, Zero3RankMapping};
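/// Splits per-layer gradients into `world_size` contiguous, roughly equal
/// shards so that each rank only needs to retain its own shard.
///
/// A minimal usage sketch, mirroring the unit tests at the bottom of this
/// file:
///
/// ```ignore
/// let config = Zero3CpuOffloadConfig::default();
/// let rank_mapping = Zero3RankMapping::new(0, 4);
/// let partitioner = GradientPartitioner::new(&config, &rank_mapping)?;
///
/// let weight_grad = Tensor::ones(&[100], torsh_core::DeviceType::Cpu)?;
/// let param_grads = ParameterGradients::new(weight_grad, None);
/// let partitions = partitioner.partition_gradients("layer1", &param_grads)?;
/// assert_eq!(partitions.len(), 4);
/// ```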
pub struct GradientPartitioner {
config: Zero3CpuOffloadConfig,
rank_mapping: Zero3RankMapping,
partition_metadata: Mutex<HashMap<String, GradientPartitionMetadata>>,
}
impl GradientPartitioner {
pub fn new(
config: &Zero3CpuOffloadConfig,
rank_mapping: &Zero3RankMapping,
) -> TorshResult<Self> {
info!(
" Gradient Partitioner initialized for rank {}/{}",
rank_mapping.rank(),
rank_mapping.world_size()
);
Ok(Self {
config: config.clone(),
rank_mapping: rank_mapping.clone(),
partition_metadata: Mutex::new(HashMap::new()),
})
}
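/// Partitions the weight (and, if present, bias) gradients of one layer into
/// `world_size` contiguous 1-D chunks.
///
/// Chunk sizes use ceiling division, so trailing ranks may receive a shorter
/// or empty shard. The shards are flattened: the original gradient shape is
/// recorded only implicitly via `start_idx`/`end_idx`.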
pub fn partition_gradients(
&self,
layer_name: &str,
grads: &ParameterGradients,
) -> TorshResult<Vec<GradientPartition>> {
let mut partitions = Vec::new();
let weight_grad = &grads.weight_grad;
let grad_data = weight_grad.to_vec()?;
let total_elements = grad_data.len();
// Ceiling division: every element is covered even when the total does not
// divide evenly across ranks; trailing ranks may get short or empty shards.
let elements_per_partition = total_elements.div_ceil(self.rank_mapping.world_size());
info!(
" Partitioning gradients for '{}': {} elements across {} ranks",
layer_name,
total_elements,
self.rank_mapping.world_size()
);
for rank in 0..self.rank_mapping.world_size() {
let start_idx = rank * elements_per_partition;
let end_idx = ((rank + 1) * elements_per_partition).min(total_elements);
if start_idx < total_elements {
let partition_data = grad_data[start_idx..end_idx].to_vec();
let partition_size = end_idx - start_idx;
let partition_shape = vec![partition_size];
let partition_tensor = Tensor::from_vec(partition_data, &partition_shape)?;
partitions.push(GradientPartition {
layer_name: layer_name.to_string(),
rank,
partition_idx: rank,
start_idx,
end_idx,
size_elements: partition_size,
weight_gradient: partition_tensor,
bias_gradient: None,
});
} else {
let empty_tensor = Tensor::from_vec(vec![], &[0])?;
partitions.push(GradientPartition {
layer_name: layer_name.to_string(),
rank,
partition_idx: rank,
start_idx: total_elements,
end_idx: total_elements,
size_elements: 0,
weight_gradient: empty_tensor,
bias_gradient: None,
});
}
}
if let Some(ref bias_grad) = grads.bias_grad {
let bias_data = bias_grad.to_vec()?;
let bias_elements = bias_data.len();
let bias_elements_per_partition =
bias_elements.div_ceil(self.rank_mapping.world_size());
for (rank, partition) in partitions.iter_mut().enumerate() {
let bias_start = rank * bias_elements_per_partition;
let bias_end = ((rank + 1) * bias_elements_per_partition).min(bias_elements);
if bias_start < bias_elements {
let bias_partition_data = bias_data[bias_start..bias_end].to_vec();
let bias_partition_shape = vec![bias_partition_data.len()];
let bias_partition_tensor =
Tensor::from_vec(bias_partition_data, &bias_partition_shape)?;
partition.bias_gradient = Some(bias_partition_tensor);
}
}
}
let metadata = GradientPartitionMetadata {
layer_name: layer_name.to_string(),
total_weight_elements: total_elements,
total_bias_elements: grads.bias_grad.as_ref().map(|b| b.numel()).unwrap_or(0),
elements_per_partition,
world_size: self.rank_mapping.world_size(),
};
{
let mut meta = self
.partition_metadata
.lock()
.expect("lock should not be poisoned");
meta.insert(layer_name.to_string(), metadata);
}
Ok(partitions)
}
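/// Returns the partition of `layer_name` owned by this rank, if any.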
pub fn get_owned_partition<'a>(
&self,
layer_name: &str,
partitions: &'a [GradientPartition],
) -> Option<&'a GradientPartition> {
partitions
.iter()
.find(|p| p.rank == self.rank_mapping.rank() && p.layer_name == layer_name)
}
pub fn get_layer_metadata(&self, layer_name: &str) -> Option<GradientPartitionMetadata> {
let meta = self
.partition_metadata
.lock()
.expect("lock should not be poisoned");
meta.get(layer_name).cloned()
}
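/// Sums the memory (in bytes) of the partitions owned by this rank across
/// all layers.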
pub fn calculate_owned_gradient_memory(
&self,
layer_gradients: &HashMap<String, Vec<GradientPartition>>,
) -> usize {
let mut total_memory = 0;
for partitions in layer_gradients.values() {
if let Some(owned_partition) = partitions
.iter()
.find(|p| p.rank == self.rank_mapping.rank())
{
total_memory += owned_partition.memory_size();
}
}
total_memory
}
pub fn get_statistics(&self) -> GradientPartitionerStats {
let meta = self
.partition_metadata
.lock()
.expect("lock should not be poisoned");
let total_layers = meta.len();
let total_elements: usize = meta
.values()
.map(|m| m.total_weight_elements + m.total_bias_elements)
.sum();
GradientPartitionerStats {
total_layers,
total_elements,
rank: self.rank_mapping.rank(),
world_size: self.rank_mapping.world_size(),
}
}
}
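/// Bookkeeping recorded per layer when its gradients are partitioned.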
#[derive(Debug, Clone)]
pub struct GradientPartitionMetadata {
pub layer_name: String,
pub total_weight_elements: usize,
pub total_bias_elements: usize,
pub elements_per_partition: usize,
pub world_size: usize,
}
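/// One rank's shard of a layer's gradients. `size_elements` counts weight
/// elements only; a bias shard, when present, is tracked separately in
/// `bias_gradient`.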
#[derive(Debug, Clone)]
pub struct GradientPartition {
pub layer_name: String,
pub rank: usize,
pub partition_idx: usize,
pub start_idx: usize,
pub end_idx: usize,
pub size_elements: usize,
pub weight_gradient: Tensor<f32>,
pub bias_gradient: Option<Tensor<f32>>,
}
impl GradientPartition {
pub fn memory_size(&self) -> usize {
let weight_size = self.weight_gradient.numel() * std::mem::size_of::<f32>();
let bias_size = self
.bias_gradient
.as_ref()
.map(|b| b.numel() * std::mem::size_of::<f32>())
.unwrap_or(0);
weight_size + bias_size
}
pub fn is_empty(&self) -> bool {
self.size_elements == 0
}
pub fn total_elements(&self) -> usize {
let weight_elements = self.weight_gradient.numel();
let bias_elements = self.bias_gradient.as_ref().map(|b| b.numel()).unwrap_or(0);
weight_elements + bias_elements
}
pub fn has_bias(&self) -> bool {
self.bias_gradient.is_some()
}
}
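/// Summary statistics reported by [`GradientPartitioner::get_statistics`].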
#[derive(Debug, Clone)]
pub struct GradientPartitionerStats {
pub total_layers: usize,
pub total_elements: usize,
pub rank: usize,
pub world_size: usize,
}
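/// The weight and optional bias gradients of a single layer, prior to
/// partitioning.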
#[derive(Debug, Clone)]
pub struct ParameterGradients {
pub weight_grad: Tensor<f32>,
pub bias_grad: Option<Tensor<f32>>,
}
impl ParameterGradients {
pub fn new(weight_grad: Tensor<f32>, bias_grad: Option<Tensor<f32>>) -> Self {
Self {
weight_grad,
bias_grad,
}
}
pub fn memory_size(&self) -> usize {
let weight_size = self.weight_grad.numel() * std::mem::size_of::<f32>();
let bias_size = self
.bias_grad
.as_ref()
.map(|b| b.numel() * std::mem::size_of::<f32>())
.unwrap_or(0);
weight_size + bias_size
}
pub fn total_elements(&self) -> usize {
let weight_elements = self.weight_grad.numel();
let bias_elements = self.bias_grad.as_ref().map(|b| b.numel()).unwrap_or(0);
weight_elements + bias_elements
}
pub fn has_bias(&self) -> bool {
self.bias_grad.is_some()
}
}
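/// Host-memory store for gradient shards, bounded by
/// `Zero3CpuOffloadConfig::cpu_memory_budget`.
///
/// Entries are keyed as `"<layer_name>_<partition_idx>"`. A minimal usage
/// sketch, mirroring the async unit tests below:
///
/// ```ignore
/// let store = CpuGradientStore::new(&config)?;
/// store.store("layer1", 0, &gradient).await?;
/// let shard = store.get_gradient("layer1", 0).await?;
/// store.remove_gradient("layer1", 0).await?;
/// ```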
pub struct CpuGradientStore {
config: Zero3CpuOffloadConfig,
stored_gradients: RwLock<HashMap<String, Tensor<f32>>>,
memory_used: std::sync::atomic::AtomicUsize,
gradient_metadata: Mutex<HashMap<String, GradientStoreMetadata>>,
}
impl CpuGradientStore {
pub fn new(config: &Zero3CpuOffloadConfig) -> TorshResult<Self> {
info!(
" CPU Gradient Store initialized with {} MB budget",
config.cpu_memory_budget / (1024 * 1024)
);
Ok(Self {
config: config.clone(),
stored_gradients: RwLock::new(HashMap::new()),
memory_used: std::sync::atomic::AtomicUsize::new(0),
gradient_metadata: Mutex::new(HashMap::new()),
})
}
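/// Stores a gradient shard under the key `"<layer_name>_<partition_idx>"`.
///
/// Fails with a memory-allocation error when the configured CPU budget would
/// be exceeded. Note the budget check is optimistic: it reads current usage
/// before inserting, so concurrent stores may briefly race past it.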
pub async fn store(
&self,
layer_name: &str,
partition_idx: usize,
gradient: &Tensor<f32>,
) -> TorshResult<()> {
let key = format!("{}_{}", layer_name, partition_idx);
let grad_size = gradient.numel() * std::mem::size_of::<f32>();
let new_memory_usage = self.memory_used() + grad_size;
if new_memory_usage > self.config.cpu_memory_budget {
return Err(TorshDistributedError::memory_allocation_failed(
new_memory_usage,
"CPU memory budget exceeded for gradient storage",
));
}
{
let mut grads = self
.stored_gradients
.write()
.expect("lock should not be poisoned");
grads.insert(key.clone(), gradient.clone());
}
{
let mut metadata = self
.gradient_metadata
.lock()
.expect("lock should not be poisoned");
metadata.insert(
key.clone(),
GradientStoreMetadata {
layer_name: layer_name.to_string(),
partition_idx,
size_bytes: grad_size,
elements: gradient.numel(),
},
);
}
self.memory_used
.fetch_add(grad_size, std::sync::atomic::Ordering::SeqCst);
info!(
" Stored gradient partition '{}_{}' in CPU ({} bytes)",
layer_name, partition_idx, grad_size
);
Ok(())
}
pub async fn get_gradient(
&self,
layer_name: &str,
partition_idx: usize,
) -> TorshResult<Option<Tensor<f32>>> {
let key = format!("{}_{}", layer_name, partition_idx);
let grads = self
.stored_gradients
.read()
.expect("lock should not be poisoned");
Ok(grads.get(&key).cloned())
}
pub async fn get_all_gradients(&self) -> TorshResult<HashMap<String, Tensor<f32>>> {
let grads = self
.stored_gradients
.read()
.expect("lock should not be poisoned");
Ok(grads.clone())
}
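/// Returns the stored shards owned by `rank`, assuming partition indices are
/// assigned round-robin (`partition_idx % world_size == rank`).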
pub async fn get_owned_gradients(
&self,
rank: usize,
world_size: usize,
) -> TorshResult<HashMap<String, Tensor<f32>>> {
let grads = self
.stored_gradients
.read()
.expect("lock should not be poisoned");
let owned_grads: HashMap<String, Tensor<f32>> = grads
.iter()
.filter(|(key, _)| {
// Keys have the form "<layer_name>_<partition_idx>". Split on the
// last '_' so layer names that themselves contain underscores
// still parse correctly.
if let Some((_layer, partition_str)) = key.rsplit_once('_') {
if let Ok(partition_idx) = partition_str.parse::<usize>() {
return partition_idx % world_size == rank;
}
}
false
})
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
Ok(owned_grads)
}
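/// Inserts (or overwrites) a reduced gradient under an explicit key,
/// adjusting the memory accounting for any entry it replaces. Unlike
/// [`CpuGradientStore::store`], this does not enforce the CPU budget or
/// record per-entry metadata.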
pub async fn store_reduced_gradient(
&self,
key: &str,
gradient: &Tensor<f32>,
) -> TorshResult<()> {
let grad_size = gradient.numel() * std::mem::size_of::<f32>();
{
let mut grads = self
.stored_gradients
.write()
.expect("lock should not be poisoned");
if let Some(old_gradient) = grads.insert(key.to_string(), gradient.clone()) {
// Replacing an existing entry: release the old allocation before
// accounting for the new one.
let old_size = old_gradient.numel() * std::mem::size_of::<f32>();
self.memory_used
.fetch_sub(old_size, std::sync::atomic::Ordering::SeqCst);
}
self.memory_used
.fetch_add(grad_size, std::sync::atomic::Ordering::SeqCst);
}
info!(
" Stored reduced gradient '{}' in CPU ({} bytes)",
key, grad_size
);
Ok(())
}
pub async fn remove_gradient(
&self,
layer_name: &str,
partition_idx: usize,
) -> TorshResult<Option<Tensor<f32>>> {
let key = format!("{}_{}", layer_name, partition_idx);
let removed_gradient = {
let mut grads = self
.stored_gradients
.write()
.expect("lock should not be poisoned");
grads.remove(&key)
};
if let Some(ref gradient) = removed_gradient {
let grad_size = gradient.numel() * std::mem::size_of::<f32>();
self.memory_used
.fetch_sub(grad_size, std::sync::atomic::Ordering::SeqCst);
let mut metadata = self
.gradient_metadata
.lock()
.expect("lock should not be poisoned");
metadata.remove(&key);
}
Ok(removed_gradient)
}
pub async fn clear(&self) -> TorshResult<()> {
{
let mut grads = self
.stored_gradients
.write()
.expect("lock should not be poisoned");
grads.clear();
}
{
let mut metadata = self
.gradient_metadata
.lock()
.expect("lock should not be poisoned");
metadata.clear();
}
self.memory_used
.store(0, std::sync::atomic::Ordering::SeqCst);
info!(" 🗑️ Cleared all gradients from CPU store");
Ok(())
}
pub fn memory_used(&self) -> usize {
self.memory_used.load(std::sync::atomic::Ordering::SeqCst)
}
pub fn gradient_count(&self) -> usize {
self.stored_gradients
.read()
.expect("lock should not be poisoned")
.len()
}
pub fn memory_utilization(&self) -> f32 {
(self.memory_used() as f32) / (self.config.cpu_memory_budget as f32)
}
pub fn get_statistics(&self) -> CpuGradientStoreStats {
let grads = self
.stored_gradients
.read()
.expect("lock should not be poisoned");
let metadata = self
.gradient_metadata
.lock()
.expect("lock should not be poisoned");
let total_elements: usize = metadata.values().map(|m| m.elements).sum();
CpuGradientStoreStats {
gradient_count: grads.len(),
memory_used_bytes: self.memory_used(),
memory_budget_bytes: self.config.cpu_memory_budget,
memory_utilization: self.memory_utilization(),
total_elements,
}
}
}
#[derive(Debug, Clone)]
struct GradientStoreMetadata {
layer_name: String,
partition_idx: usize,
size_bytes: usize,
elements: usize,
}
#[derive(Debug, Clone)]
pub struct CpuGradientStoreStats {
pub gradient_count: usize,
pub memory_used_bytes: usize,
pub memory_budget_bytes: usize,
pub memory_utilization: f32,
pub total_elements: usize,
}
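/// Device-side staging buffer for gradient shards, bounded by
/// `Zero3CpuOffloadConfig::gpu_param_memory_budget`. Mirrors the API of
/// [`CpuGradientStore`] with the same `"<layer_name>_<partition_idx>"`
/// key scheme.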
pub struct GpuGradientBuffer {
config: Zero3CpuOffloadConfig,
stored_gradients: RwLock<HashMap<String, Tensor<f32>>>,
memory_used: std::sync::atomic::AtomicUsize,
buffer_metadata: Mutex<HashMap<String, GradientBufferMetadata>>,
}
impl GpuGradientBuffer {
pub fn new(config: &Zero3CpuOffloadConfig) -> TorshResult<Self> {
info!(
" GPU Gradient Buffer initialized with {} MB budget",
config.gpu_param_memory_budget / (1024 * 1024)
);
Ok(Self {
config: config.clone(),
stored_gradients: RwLock::new(HashMap::new()),
memory_used: std::sync::atomic::AtomicUsize::new(0),
buffer_metadata: Mutex::new(HashMap::new()),
})
}
pub async fn store(
&self,
layer_name: &str,
partition_idx: usize,
gradient: &Tensor<f32>,
) -> TorshResult<()> {
let key = format!("{}_{}", layer_name, partition_idx);
let grad_size = gradient.numel() * std::mem::size_of::<f32>();
let new_memory_usage = self.memory_used() + grad_size;
if new_memory_usage > self.config.gpu_param_memory_budget {
return Err(TorshDistributedError::memory_allocation_failed(
new_memory_usage,
"GPU memory budget exceeded for gradient buffer",
));
}
{
let mut grads = self
.stored_gradients
.write()
.expect("lock should not be poisoned");
grads.insert(key.clone(), gradient.clone());
}
{
let mut metadata = self
.buffer_metadata
.lock()
.expect("lock should not be poisoned");
metadata.insert(
key.clone(),
GradientBufferMetadata {
layer_name: layer_name.to_string(),
partition_idx,
size_bytes: grad_size,
elements: gradient.numel(),
},
);
}
self.memory_used
.fetch_add(grad_size, std::sync::atomic::Ordering::SeqCst);
info!(
" Buffered gradient partition '{}_{}' in GPU ({} bytes)",
layer_name, partition_idx, grad_size
);
Ok(())
}
pub async fn get_gradient(
&self,
layer_name: &str,
partition_idx: usize,
) -> TorshResult<Option<Tensor<f32>>> {
let key = format!("{}_{}", layer_name, partition_idx);
let grads = self
.stored_gradients
.read()
.expect("lock should not be poisoned");
Ok(grads.get(&key).cloned())
}
pub async fn get_all_gradients(&self) -> TorshResult<HashMap<String, Tensor<f32>>> {
let grads = self
.stored_gradients
.read()
.expect("lock should not be poisoned");
Ok(grads.clone())
}
pub async fn remove_gradient(
&self,
layer_name: &str,
partition_idx: usize,
) -> TorshResult<Option<Tensor<f32>>> {
let key = format!("{}_{}", layer_name, partition_idx);
let removed_gradient = {
let mut grads = self
.stored_gradients
.write()
.expect("lock should not be poisoned");
grads.remove(&key)
};
if let Some(ref gradient) = removed_gradient {
let grad_size = gradient.numel() * std::mem::size_of::<f32>();
self.memory_used
.fetch_sub(grad_size, std::sync::atomic::Ordering::SeqCst);
let mut metadata = self
.buffer_metadata
.lock()
.expect("lock should not be poisoned");
metadata.remove(&key);
}
Ok(removed_gradient)
}
pub fn memory_used(&self) -> usize {
self.memory_used.load(std::sync::atomic::Ordering::SeqCst)
}
pub fn gradient_count(&self) -> usize {
self.stored_gradients
.read()
.expect("lock should not be poisoned")
.len()
}
pub fn clear(&self) -> TorshResult<()> {
{
let mut grads = self
.stored_gradients
.write()
.expect("lock should not be poisoned");
grads.clear();
}
{
let mut metadata = self
.buffer_metadata
.lock()
.expect("lock should not be poisoned");
metadata.clear();
}
self.memory_used
.store(0, std::sync::atomic::Ordering::SeqCst);
info!(" 🗑️ Cleared all gradients from GPU buffer");
Ok(())
}
pub fn memory_utilization(&self) -> f32 {
(self.memory_used() as f32) / (self.config.gpu_param_memory_budget as f32)
}
pub fn get_statistics(&self) -> GpuGradientBufferStats {
let grads = self
.stored_gradients
.read()
.expect("lock should not be poisoned");
let metadata = self
.buffer_metadata
.lock()
.expect("lock should not be poisoned");
let total_elements: usize = metadata.values().map(|m| m.elements).sum();
GpuGradientBufferStats {
gradient_count: grads.len(),
memory_used_bytes: self.memory_used(),
memory_budget_bytes: self.config.gpu_param_memory_budget,
memory_utilization: self.memory_utilization(),
total_elements,
}
}
}
#[derive(Debug, Clone)]
struct GradientBufferMetadata {
layer_name: String,
partition_idx: usize,
size_bytes: usize,
elements: usize,
}
#[derive(Debug, Clone)]
pub struct GpuGradientBufferStats {
pub gradient_count: usize,
pub memory_used_bytes: usize,
pub memory_budget_bytes: usize,
pub memory_utilization: f32,
pub total_elements: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parameter_gradients() -> Result<(), Box<dyn std::error::Error>> {
use torsh_tensor::Tensor;
let weight_grad = Tensor::zeros(&[10, 5], torsh_core::DeviceType::Cpu)?;
let bias_grad = Some(Tensor::zeros(&[5], torsh_core::DeviceType::Cpu)?);
let param_grads = ParameterGradients::new(weight_grad, bias_grad);
assert!(param_grads.has_bias());
assert_eq!(param_grads.total_elements(), 55);
Ok(())
}
#[test]
fn test_gradient_partition() -> Result<(), Box<dyn std::error::Error>> {
use torsh_tensor::Tensor;
let weight_grad = Tensor::zeros(&[20], torsh_core::DeviceType::Cpu)?;
let bias_grad = Some(Tensor::zeros(&[5], torsh_core::DeviceType::Cpu)?);
let partition = GradientPartition {
layer_name: "layer1".to_string(),
rank: 0,
partition_idx: 0,
start_idx: 0,
end_idx: 20,
size_elements: 20,
weight_gradient: weight_grad,
bias_gradient: bias_grad,
};
assert!(partition.has_bias());
assert_eq!(partition.total_elements(), 25);
assert!(!partition.is_empty());
Ok(())
}
#[test]
fn test_gradient_partitioner() -> Result<(), Box<dyn std::error::Error>> {
use torsh_tensor::Tensor;
let config = Zero3CpuOffloadConfig::default();
let rank_mapping = Zero3RankMapping::new(0, 4);
let partitioner = GradientPartitioner::new(&config, &rank_mapping)
.expect("Gradient Partitioner should succeed");
let weight_grad = Tensor::ones(&[100], torsh_core::DeviceType::Cpu)?;
let param_grads = ParameterGradients::new(weight_grad, None);
let partitions = partitioner
.partition_gradients("layer1", &param_grads)
.expect("operation should succeed");
assert_eq!(partitions.len(), 4);
assert_eq!(partitions[0].rank, 0);
assert_eq!(partitions[0].size_elements, 25);
let owned = partitioner.get_owned_partition("layer1", &partitions);
assert!(owned.is_some());
assert_eq!(owned.expect("operation should succeed").rank, 0);
Ok(())
}
#[tokio::test]
async fn test_cpu_gradient_store() -> Result<(), Box<dyn std::error::Error>> {
use torsh_tensor::Tensor;
let config = Zero3CpuOffloadConfig::default();
let store = CpuGradientStore::new(&config).expect("Cpu Gradient Store should succeed");
let gradient = Tensor::ones(&[100], torsh_core::DeviceType::Cpu)?;
store
.store("layer1", 0, &gradient)
.await
.expect("operation should succeed");
assert_eq!(store.gradient_count(), 1);
let retrieved = store
.get_gradient("layer1", 0)
.await
.expect("operation should succeed");
assert!(retrieved.is_some());
let removed = store
.remove_gradient("layer1", 0)
.await
.expect("operation should succeed");
assert!(removed.is_some());
assert_eq!(store.gradient_count(), 0);
Ok(())
}
#[tokio::test]
async fn test_gpu_gradient_buffer() -> Result<(), Box<dyn std::error::Error>> {
use torsh_tensor::Tensor;
let config = Zero3CpuOffloadConfig::default();
let buffer = GpuGradientBuffer::new(&config).expect("Gpu Gradient Buffer should succeed");
let gradient = Tensor::ones(&[50], torsh_core::DeviceType::Cpu)?;
buffer
.store("layer1", 0, &gradient)
.await
.expect("operation should succeed");
assert_eq!(buffer.gradient_count(), 1);
let retrieved = buffer
.get_gradient("layer1", 0)
.await
.expect("operation should succeed");
assert!(retrieved.is_some());
buffer.clear().expect("clear should succeed");
assert_eq!(buffer.gradient_count(), 0);
Ok(())
}
#[tokio::test]
async fn test_gradient_store_owned_gradients() -> Result<(), Box<dyn std::error::Error>> {
use torsh_tensor::Tensor;
let config = Zero3CpuOffloadConfig::default();
let store = CpuGradientStore::new(&config).expect("Cpu Gradient Store should succeed");
for i in 0..8 {
let gradient = Tensor::ones(&[10], torsh_core::DeviceType::Cpu)?;
store
.store("layer1", i, &gradient)
.await
.expect("operation should succeed");
}
let owned = store
.get_owned_gradients(0, 4)
.await
.expect("operation should succeed");
assert_eq!(owned.len(), 2);
assert!(owned.contains_key("layer1_0"));
assert!(owned.contains_key("layer1_4"));
Ok(())
}
}