#![allow(dead_code)]
use crate::{TorshDistributedError, TorshResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use torsh_tensor::Tensor;
use tracing::{debug, info};
/// Configuration for [`GradientCompressor`].
#[derive(Debug, Clone)]
pub struct CompressionConfig {
/// Which compression algorithm to run (see `CompressionMethod`).
pub method: CompressionMethod,
/// Nominal target ratio. NOTE(review): not read by the compressor in this
/// file — the per-method parameters (e.g. `TopK { k }`) drive the actual
/// ratio; confirm whether external code consumes this field.
pub compression_ratio: f32,
/// When true, `compress` re-injects the residual (original - decompressed)
/// from the previous round before compressing.
pub error_feedback: bool,
/// Scale applied to the stored residual when it is re-added.
pub error_feedback_momentum: f32,
/// NOTE(review): unused in this file — confirm before relying on it.
pub memory_efficient: bool,
/// Number of initial steps during which gradients pass through uncompressed.
pub warmup_steps: usize,
}
impl Default for CompressionConfig {
/// Defaults: top-10% magnitude sparsification with error feedback
/// (momentum 0.9), after a 100-step uncompressed warmup.
fn default() -> Self {
Self {
method: CompressionMethod::TopK { k: 0.1 },
compression_ratio: 0.1,
error_feedback: true,
error_feedback_momentum: 0.9,
memory_efficient: true,
warmup_steps: 100,
}
}
}
/// Gradient compression algorithms supported by [`GradientCompressor`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CompressionMethod {
/// Keep the fraction `k` of elements with largest magnitude.
TopK { k: f32 },
/// Keep a fraction `k` of elements. NOTE(review): the current
/// implementation samples at a fixed stride, not randomly.
RandomK { k: f32 },
/// Keep all elements whose absolute value is >= `threshold`.
Threshold { threshold: f32 },
/// Uniform affine quantization to `bits` bits per element.
Quantization { bits: u8 },
/// Transmit only element signs plus the gradient L2 norm.
SignSGD,
/// Sketch of `sketch_size` entries (currently a prefix truncation).
Sketching { sketch_size: usize },
/// Low-rank factorization of 2D gradients (placeholder implementation).
PowerSGD { rank: usize },
/// Quantize each element to {-1, 0, +1} times a global scale.
TernaryQuant { threshold: f32 },
/// Quantize to `num_bins` uniformly spaced bin centers.
BimodalQuant { num_bins: usize },
/// Codebook compression keeping the most frequent quantized values.
NaturalCompression { compression_factor: f32 },
/// Top-k with a per-layer ratio derived from the parameter name.
LayerwiseAdaptive { base_ratio: f32, sensitivity: f32 },
/// EF21-style top-k with an internally managed residual buffer.
EF21 {
compression_ratio: f32,
momentum: f32,
},
/// Identity: gradient is transmitted uncompressed.
None,
}
/// A compressed gradient plus everything needed to reconstruct it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedGradient {
/// Method that produced `data` (used for bookkeeping; decompression
/// dispatches on the `data` variant itself).
pub method: CompressionMethod,
/// Method-specific compressed payload.
pub data: CompressedData,
/// Shape of the original tensor, required to rebuild it on decompression.
pub original_shape: Vec<usize>,
/// Ratio/norm/timestamp bookkeeping for statistics.
pub metadata: CompressionMetadata,
}
/// Method-specific payload carried by a [`CompressedGradient`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompressedData {
/// Sparse scatter: `values[i]` belongs at flat position `indices[i]`.
Sparse {
indices: Vec<usize>,
values: Vec<f32>,
},
/// Affine-quantized values: real = (q - zero_point) * scale.
Quantized {
values: Vec<u8>,
scale: f32,
zero_point: u8,
},
/// Element signs plus the original L2 norm (SignSGD).
Signs { signs: Vec<bool>, norm: f32 },
/// Low-rank factor pair (PowerSGD placeholder).
LowRank {
left_factor: Vec<f32>,
right_factor: Vec<f32>,
rank: usize,
},
/// Sketch values plus hash parameters (hashes currently unused on decode).
Sketch {
sketch: Vec<f32>,
hash_a: Vec<u32>,
hash_b: Vec<u32>,
},
/// {-1, 0, +1} codes times a global scale.
Ternary { values: Vec<i8>, scale: f32 },
/// Per-element bin index into `bin_centers`.
Bimodal {
bin_indices: Vec<u8>,
bin_centers: Vec<f32>,
},
/// Per-element codebook index (stored as f32) plus the codebook itself.
Natural {
values: Vec<f32>,
frequencies: Vec<u32>,
codebook: Vec<f32>,
},
/// Dense EF21 payload; `error_feedback` mirrors the sender's residual.
EF21 {
compressed_values: Vec<f32>,
error_feedback: Vec<f32>,
},
}
/// Bookkeeping attached to each compressed gradient.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionMetadata {
/// Achieved (or nominal) compressed/original size ratio.
pub compression_ratio: f32,
/// Norm of the compression error. NOTE(review): every compressor in this
/// file currently writes 0.0 here.
pub error_norm: f32,
/// L2 norm of the original gradient.
pub original_norm: f32,
/// Seconds since the UNIX epoch at compression time.
pub timestamp: u64,
}
/// Stateful gradient compressor with optional error feedback and warmup.
pub struct GradientCompressor {
// Active compression configuration.
config: CompressionConfig,
// Residual buffers keyed by parameter name (EF21 uses "ef21_<name>" keys).
error_buffers: HashMap<String, Tensor>,
// Number of compress() calls performed so far; gates the warmup phase.
step_count: usize,
// Running statistics across compressions.
stats: CompressionStats,
}
/// Aggregate statistics over all compressions performed by a compressor.
#[derive(Debug, Clone, Default)]
pub struct CompressionStats {
/// Number of gradients compressed.
pub total_compressions: u64,
/// Running mean of per-compression ratios.
pub avg_compression_ratio: f64,
/// NOTE(review): not updated anywhere in this file — confirm intent.
pub total_communication_reduction: u64,
/// Running mean of compression error norms.
pub avg_error_norm: f64,
/// Total wall-clock time spent compressing, in milliseconds.
pub compression_time_ms: f64,
}
impl GradientCompressor {
pub fn new(config: CompressionConfig) -> Self {
info!(
"Initializing gradient compressor with method: {:?}",
config.method
);
Self {
config,
error_buffers: HashMap::new(),
step_count: 0,
stats: CompressionStats::default(),
}
}
/// Compresses `gradient` according to the configured method.
///
/// During the first `warmup_steps` calls the gradient is passed through
/// uncompressed so training can stabilize before lossy compression starts.
/// When error feedback is enabled, the previous round's residual is
/// re-injected before compressing and the new residual is stored after.
pub fn compress(
    &mut self,
    gradient: &Tensor,
    param_name: &str,
) -> TorshResult<CompressedGradient> {
    let start_time = std::time::Instant::now();
    if self.step_count < self.config.warmup_steps {
        // Bug fix: the step counter must advance on the warmup path too.
        // Previously it was only incremented after a compressed step, so
        // `step_count < warmup_steps` held forever and compression never
        // engaged for any config with warmup_steps > 0.
        self.step_count += 1;
        return self.no_compression(gradient, param_name);
    }
    // Re-inject the residual left over from previous lossy rounds.
    let adjusted_gradient = if self.config.error_feedback {
        self.apply_error_feedback(gradient, param_name)?
    } else {
        gradient.clone()
    };
    let compressed = match &self.config.method {
        CompressionMethod::TopK { k } => self.compress_top_k(&adjusted_gradient, *k)?,
        CompressionMethod::RandomK { k } => self.compress_random_k(&adjusted_gradient, *k)?,
        CompressionMethod::Threshold { threshold } => {
            self.compress_threshold(&adjusted_gradient, *threshold)?
        }
        CompressionMethod::Quantization { bits } => {
            self.compress_quantization(&adjusted_gradient, *bits)?
        }
        CompressionMethod::SignSGD => self.compress_sign_sgd(&adjusted_gradient)?,
        CompressionMethod::Sketching { sketch_size } => {
            self.compress_sketching(&adjusted_gradient, *sketch_size)?
        }
        CompressionMethod::PowerSGD { rank } => {
            self.compress_power_sgd(&adjusted_gradient, *rank)?
        }
        CompressionMethod::TernaryQuant { threshold } => {
            self.compress_ternary(&adjusted_gradient, *threshold)?
        }
        CompressionMethod::BimodalQuant { num_bins } => {
            self.compress_bimodal(&adjusted_gradient, *num_bins)?
        }
        CompressionMethod::NaturalCompression { compression_factor } => {
            self.compress_natural(&adjusted_gradient, *compression_factor)?
        }
        CompressionMethod::LayerwiseAdaptive {
            base_ratio,
            sensitivity,
        } => self.compress_layerwise_adaptive(
            &adjusted_gradient,
            *base_ratio,
            *sensitivity,
            param_name,
        )?,
        CompressionMethod::EF21 {
            compression_ratio,
            momentum,
        } => self.compress_ef21(
            &adjusted_gradient,
            *compression_ratio,
            *momentum,
            param_name,
        )?,
        // Identity compression: early return skips error feedback and
        // statistics, as in the original implementation.
        CompressionMethod::None => return self.no_compression(gradient, param_name),
    };
    if self.config.error_feedback {
        // Store original - decompress(compressed) for the next round.
        self.update_error_feedback(&compressed, gradient, param_name)?;
    }
    let compression_time = start_time.elapsed().as_millis() as f64;
    self.update_stats(&compressed, compression_time);
    self.step_count += 1;
    Ok(compressed)
}
/// Reconstructs a dense tensor of `compressed.original_shape` from a
/// [`CompressedGradient`].
///
/// Dispatches on the payload variant; each helper expands or scatters the
/// compact representation back into a full tensor.
pub fn decompress(&self, compressed: &CompressedGradient) -> TorshResult<Tensor> {
match &compressed.data {
CompressedData::Sparse { indices, values } => {
self.decompress_sparse(indices, values, &compressed.original_shape)
}
CompressedData::Quantized {
values,
scale,
zero_point,
} => self.decompress_quantized(values, *scale, *zero_point, &compressed.original_shape),
CompressedData::Signs { signs, norm } => {
self.decompress_sign_sgd(signs, *norm, &compressed.original_shape)
}
CompressedData::LowRank {
left_factor,
right_factor,
rank,
} => self.decompress_power_sgd(
left_factor,
right_factor,
*rank,
&compressed.original_shape,
),
CompressedData::Sketch {
sketch,
hash_a,
hash_b,
} => self.decompress_sketching(sketch, hash_a, hash_b, &compressed.original_shape),
CompressedData::Ternary { values, scale } => {
self.decompress_ternary(values, *scale, &compressed.original_shape)
}
CompressedData::Bimodal {
bin_indices,
bin_centers,
} => self.decompress_bimodal(bin_indices, bin_centers, &compressed.original_shape),
// Frequencies are only needed for entropy coding, not reconstruction.
CompressedData::Natural {
values,
frequencies: _,
codebook,
} => self.decompress_natural(values, codebook, &compressed.original_shape),
// The receiver ignores the sender's residual buffer.
CompressedData::EF21 {
compressed_values,
error_feedback: _,
} => self.decompress_ef21(compressed_values, &compressed.original_shape),
}
}
/// Adds the momentum-scaled residual from the previous round back onto the
/// incoming gradient (classic error-feedback correction). Returns the
/// gradient unchanged when no residual has been stored yet.
fn apply_error_feedback(&mut self, gradient: &Tensor, param_name: &str) -> TorshResult<Tensor> {
    match self.error_buffers.get(param_name) {
        Some(buffer) => {
            let carried = buffer.mul_scalar(self.config.error_feedback_momentum)?;
            Ok(gradient.add(&carried)?)
        }
        None => Ok(gradient.clone()),
    }
}
/// Stores the compression residual (original - decompressed) so it can be
/// replayed into the next round via `apply_error_feedback`.
fn update_error_feedback(
    &mut self,
    compressed: &CompressedGradient,
    original: &Tensor,
    param_name: &str,
) -> TorshResult<()> {
    let reconstructed = self.decompress(compressed)?;
    let residual = original.sub(&reconstructed)?;
    self.error_buffers.insert(param_name.to_owned(), residual);
    Ok(())
}
/// Magnitude-based top-k sparsification: keeps the fraction `k` of elements
/// with largest absolute value and records them as index/value pairs.
fn compress_top_k(&self, gradient: &Tensor, k: f32) -> TorshResult<CompressedGradient> {
    let flat_grad = gradient.flatten()?;
    let numel = flat_grad.numel();
    // Clamp so a ratio > 1.0 (or ceil rounding) can never exceed the tensor.
    let k_elements = (((numel as f32) * k).ceil() as usize).min(numel);
    let grad_data = flat_grad.to_vec()?;
    let mut indexed: Vec<(usize, f32)> = grad_data
        .iter()
        .enumerate()
        .map(|(i, &v)| (i, v.abs()))
        .collect();
    // Partial selection is O(n) on average, vs the previous full O(n log n)
    // sort; the kept set is identical, only its internal order differs,
    // which is irrelevant for the sparse scatter on decompression.
    if k_elements > 0 && k_elements < numel {
        indexed.select_nth_unstable_by(k_elements - 1, |a, b| {
            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
        });
    }
    let mut indices = Vec::with_capacity(k_elements);
    let mut values = Vec::with_capacity(k_elements);
    for &(idx, _) in indexed.iter().take(k_elements) {
        indices.push(idx);
        values.push(grad_data[idx]);
    }
    debug!("Top-K compression: kept {}/{} elements", k_elements, numel);
    let original_norm = gradient.norm()?.item()?;
    let compression_ratio = k;
    Ok(CompressedGradient {
        method: CompressionMethod::TopK { k },
        data: CompressedData::Sparse { indices, values },
        original_shape: gradient.shape().dims().to_vec(),
        metadata: CompressionMetadata {
            compression_ratio,
            error_norm: 0.0,
            original_norm,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("time should be after UNIX_EPOCH")
                .as_secs(),
        },
    })
}
/// Strided ("random"-k) sparsification: samples ~`k * numel` elements at a
/// fixed stride. NOTE(review): sampling is deterministic, not random.
fn compress_random_k(&self, gradient: &Tensor, k: f32) -> TorshResult<CompressedGradient> {
    let flat_grad = gradient.flatten()?;
    let numel = flat_grad.numel();
    // Clamp so k > 1.0 cannot request more elements than exist.
    let k_elements = (((numel as f32) * k).ceil() as usize).min(numel);
    let grad_data = flat_grad.to_vec()?;
    let mut indices = Vec::with_capacity(k_elements);
    let mut values = Vec::with_capacity(k_elements);
    // The outer `.max(1)` fixes a panic: when k_elements exceeded numel the
    // old `numel / k_elements` could be 0 and `step_by(0)` panics.
    let step = (numel / k_elements.max(1)).max(1);
    for i in (0..numel).step_by(step).take(k_elements) {
        indices.push(i);
        values.push(grad_data[i]);
    }
    debug!(
        "Random-K compression: kept {}/{} elements",
        k_elements, numel
    );
    let original_norm = gradient.norm()?.item()?;
    Ok(CompressedGradient {
        method: CompressionMethod::RandomK { k },
        data: CompressedData::Sparse { indices, values },
        original_shape: gradient.shape().dims().to_vec(),
        metadata: CompressionMetadata {
            compression_ratio: k,
            error_norm: 0.0,
            original_norm,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("time should be after UNIX_EPOCH")
                .as_secs(),
        },
    })
}
/// Threshold sparsification: keeps every element whose absolute value is at
/// least `threshold`, stored as index/value pairs.
fn compress_threshold(
    &self,
    gradient: &Tensor,
    threshold: f32,
) -> TorshResult<CompressedGradient> {
    let flat = gradient.flatten()?;
    let grad_data = flat.to_vec()?;
    // Partition survivors into parallel index/value vectors in one pass.
    let (indices, values): (Vec<usize>, Vec<f32>) = grad_data
        .iter()
        .enumerate()
        .filter(|(_, v)| v.abs() >= threshold)
        .map(|(i, &v)| (i, v))
        .unzip();
    // Achieved ratio depends on the data, unlike the fixed-k methods.
    let compression_ratio = indices.len() as f32 / grad_data.len() as f32;
    debug!(
        "Threshold compression: kept {}/{} elements",
        indices.len(),
        grad_data.len()
    );
    let original_norm = gradient.norm()?.item()?;
    Ok(CompressedGradient {
        method: CompressionMethod::Threshold { threshold },
        data: CompressedData::Sparse { indices, values },
        original_shape: gradient.shape().dims().to_vec(),
        metadata: CompressionMetadata {
            compression_ratio,
            error_norm: 0.0,
            original_norm,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("time should be after UNIX_EPOCH")
                .as_secs(),
        },
    })
}
fn compress_quantization(
&self,
gradient: &Tensor,
bits: u8,
) -> TorshResult<CompressedGradient> {
let flat_grad = gradient.flatten()?;
let grad_data = flat_grad.to_vec()?;
let min_val = grad_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = grad_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let levels = (1 << bits) - 1;
let scale = (max_val - min_val) / levels as f32;
let zero_point = (-min_val / scale).round() as u8;
let mut quantized_values = Vec::new();
for &value in &grad_data {
let quantized = ((value / scale) + zero_point as f32)
.round()
.clamp(0.0, levels as f32) as u8;
quantized_values.push(quantized);
}
debug!("Quantization: {} bits, {} levels", bits, levels);
let original_norm = gradient.norm()?.item()?;
let compression_ratio = (bits as f32) / 32.0;
Ok(CompressedGradient {
method: CompressionMethod::Quantization { bits },
data: CompressedData::Quantized {
values: quantized_values,
scale,
zero_point,
},
original_shape: gradient.shape().dims().to_vec(),
metadata: CompressionMetadata {
compression_ratio,
error_norm: 0.0,
original_norm,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_secs(),
},
})
}
/// SignSGD compression: transmits one sign bit per element plus the
/// gradient's L2 norm, from which magnitudes are rebuilt on decompression.
fn compress_sign_sgd(&self, gradient: &Tensor) -> TorshResult<CompressedGradient> {
    let flat = gradient.flatten()?;
    let data = flat.to_vec()?;
    let norm = gradient.norm()?.item()?;
    // Zero maps to `true` (non-negative), matching x >= 0.0.
    let signs = data.iter().map(|&x| x >= 0.0).collect::<Vec<bool>>();
    debug!(
        "SignSGD compression: {} elements -> {} bits",
        data.len(),
        signs.len()
    );
    Ok(CompressedGradient {
        method: CompressionMethod::SignSGD,
        data: CompressedData::Signs { signs, norm },
        original_shape: gradient.shape().dims().to_vec(),
        metadata: CompressionMetadata {
            // One bit per former 32-bit float.
            compression_ratio: 1.0 / 32.0,
            error_norm: 0.0,
            original_norm: norm,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("time should be after UNIX_EPOCH")
                .as_secs(),
        },
    })
}
/// Sketch compression. NOTE(review): the "sketch" is currently a plain
/// prefix truncation of the flattened gradient; the hash parameters are
/// generated but not used as a count-sketch — confirm intended design.
fn compress_sketching(
    &self,
    gradient: &Tensor,
    sketch_size: usize,
) -> TorshResult<CompressedGradient> {
    let flat = gradient.flatten()?;
    let grad_data = flat.to_vec()?;
    let n = grad_data.len();
    let sketch: Vec<f32> = grad_data.iter().take(sketch_size).copied().collect();
    // Deterministic affine hash parameters, one pair per source element.
    let hash_a: Vec<u32> = (0..n).map(|i| (i * 17 + 23) as u32).collect();
    let hash_b: Vec<u32> = (0..n).map(|i| (i * 37 + 41) as u32).collect();
    let compression_ratio = sketch_size as f32 / n as f32;
    let original_norm = gradient.norm()?.item()?;
    debug!(
        "Sketching compression: {} -> {} elements",
        n, sketch_size
    );
    Ok(CompressedGradient {
        method: CompressionMethod::Sketching { sketch_size },
        data: CompressedData::Sketch {
            sketch,
            hash_a,
            hash_b,
        },
        original_shape: gradient.shape().dims().to_vec(),
        metadata: CompressionMetadata {
            compression_ratio,
            error_norm: 0.0,
            original_norm,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("time should be after UNIX_EPOCH")
                .as_secs(),
        },
    })
}
/// "PowerSGD"-style low-rank compression for 2D gradients.
///
/// NOTE(review): this is a placeholder, not the published PowerSGD
/// algorithm — the "factors" are simply the first `rows*rank` and the next
/// `cols*rank` entries of the flattened gradient, with no power iteration
/// or orthogonalization. Confirm whether a real implementation is planned.
fn compress_power_sgd(
&self,
gradient: &Tensor,
rank: usize,
) -> TorshResult<CompressedGradient> {
let shape_obj = gradient.shape();
let shape = shape_obj.dims();
// Low-rank factorization only makes sense for matrices.
if shape.len() != 2 {
return Err(TorshDistributedError::invalid_argument(
"gradient",
format!("PowerSGD requires 2D tensors, got {}D tensor", shape.len()),
"2D tensor with shape [rows, cols]",
));
}
let rows = shape[0];
let cols = shape[1];
let left_factor_size = rows * rank;
let right_factor_size = cols * rank;
let flat_grad = gradient.flatten()?;
let grad_data = flat_grad.to_vec()?;
// Slice out the two "factors" (see placeholder note above).
let left_factor: Vec<f32> = grad_data.iter().take(left_factor_size).copied().collect();
let right_factor: Vec<f32> = grad_data
.iter()
.skip(left_factor_size)
.take(right_factor_size)
.copied()
.collect();
let compression_ratio =
(left_factor_size + right_factor_size) as f32 / grad_data.len() as f32;
let original_norm = gradient.norm()?.item()?;
debug!(
"PowerSGD compression: rank {}, ratio {:.3}",
rank, compression_ratio
);
Ok(CompressedGradient {
method: CompressionMethod::PowerSGD { rank },
data: CompressedData::LowRank {
left_factor,
right_factor,
rank,
},
original_shape: gradient.shape().dims().to_vec(),
metadata: CompressionMetadata {
compression_ratio,
error_norm: 0.0,
original_norm,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_secs(),
},
})
}
/// Ternary quantization: each element becomes -1, 0, or +1 (relative to
/// `threshold` after normalization), times a global scale derived from the
/// gradient's RMS value.
fn compress_ternary(
    &self,
    gradient: &Tensor,
    threshold: f32,
) -> TorshResult<CompressedGradient> {
    let flat = gradient.flatten()?;
    let grad_data = flat.to_vec()?;
    let original_norm = gradient.norm()?.item()?;
    // Scale is the RMS magnitude: ||g|| / sqrt(n).
    let scale = original_norm / (grad_data.len() as f32).sqrt();
    let ternary_values: Vec<i8> = grad_data
        .iter()
        .map(|&v| {
            let normalized = v / scale;
            if normalized > threshold {
                1
            } else if normalized < -threshold {
                -1
            } else {
                0
            }
        })
        .collect();
    // Roughly 2 bits per former 32-bit float.
    let compression_ratio = 2.0 / 32.0;
    debug!(
        "Ternary compression: threshold {}, scale {:.6}",
        threshold, scale
    );
    Ok(CompressedGradient {
        method: CompressionMethod::TernaryQuant { threshold },
        data: CompressedData::Ternary {
            values: ternary_values,
            scale,
        },
        original_shape: gradient.shape().dims().to_vec(),
        metadata: CompressionMetadata {
            compression_ratio,
            error_norm: 0.0,
            original_norm,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("time should be after UNIX_EPOCH")
                .as_secs(),
        },
    })
}
fn compress_bimodal(
&self,
gradient: &Tensor,
num_bins: usize,
) -> TorshResult<CompressedGradient> {
let flat_grad = gradient.flatten()?;
let grad_data = flat_grad.to_vec()?;
let original_norm = gradient.norm()?.item()?;
let min_val = grad_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = grad_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let mut bin_centers = Vec::new();
for i in 0..num_bins {
let center = min_val + (max_val - min_val) * (i as f32 + 0.5) / (num_bins as f32);
bin_centers.push(center);
}
let mut bin_indices = Vec::new();
for &value in &grad_data {
let mut best_bin = 0;
let mut best_distance = f32::INFINITY;
for (bin_idx, ¢er) in bin_centers.iter().enumerate() {
let distance = (value - center).abs();
if distance < best_distance {
best_distance = distance;
best_bin = bin_idx;
}
}
bin_indices.push(best_bin as u8);
}
let bits_per_bin = (num_bins as f32).log2().ceil();
let compression_ratio = bits_per_bin / 32.0;
debug!(
"Bimodal compression: {} bins, {:.1} bits/value",
num_bins, bits_per_bin
);
Ok(CompressedGradient {
method: CompressionMethod::BimodalQuant { num_bins },
data: CompressedData::Bimodal {
bin_indices,
bin_centers,
},
original_shape: gradient.shape().dims().to_vec(),
metadata: CompressionMetadata {
compression_ratio,
error_norm: 0.0,
original_norm,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_secs(),
},
})
}
fn compress_natural(
&self,
gradient: &Tensor,
compression_factor: f32,
) -> TorshResult<CompressedGradient> {
let flat_grad = gradient.flatten()?;
let grad_data = flat_grad.to_vec()?;
let original_norm = gradient.norm()?.item()?;
let num_unique = (grad_data.len() as f32 * compression_factor).ceil() as usize;
let mut value_counts: std::collections::HashMap<i32, u32> =
std::collections::HashMap::new();
let scale = 10000.0; for &value in &grad_data {
let quantized = (value * scale).round() as i32;
*value_counts.entry(quantized).or_insert(0) += 1;
}
let mut sorted_values: Vec<_> = value_counts.into_iter().collect();
sorted_values.sort_by(|a, b| b.1.cmp(&a.1));
sorted_values.truncate(num_unique);
let codebook: Vec<f32> = sorted_values
.iter()
.map(|(v, _)| *v as f32 / scale)
.collect();
let frequencies: Vec<u32> = sorted_values.iter().map(|(_, f)| *f).collect();
let mut compressed_values = Vec::new();
for &value in &grad_data {
let mut best_idx = 0;
let mut best_distance = f32::INFINITY;
for (idx, &codebook_val) in codebook.iter().enumerate() {
let distance = (value - codebook_val).abs();
if distance < best_distance {
best_distance = distance;
best_idx = idx;
}
}
compressed_values.push(best_idx as f32);
}
debug!(
"Natural compression: {} unique values from {} total",
num_unique,
grad_data.len()
);
Ok(CompressedGradient {
method: CompressionMethod::NaturalCompression { compression_factor },
data: CompressedData::Natural {
values: compressed_values,
frequencies,
codebook,
},
original_shape: gradient.shape().dims().to_vec(),
metadata: CompressionMetadata {
compression_ratio: compression_factor,
error_norm: 0.0,
original_norm,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_secs(),
},
})
}
/// Layer-wise adaptive sparsification: picks a per-layer top-k ratio from
/// the parameter name, then delegates to `compress_top_k`.
///
/// Heuristic: tensors whose name contains "weight" use `base_ratio` as-is;
/// every other parameter's ratio is scaled by `sensitivity`.
/// NOTE(review): adaptivity is name-based only — the gradient norm is not
/// consulted (a previous version computed it and discarded the result;
/// that dead, fallible computation has been removed).
fn compress_layerwise_adaptive(
    &self,
    gradient: &Tensor,
    base_ratio: f32,
    sensitivity: f32,
    param_name: &str,
) -> TorshResult<CompressedGradient> {
    let layer_sensitivity = if param_name.contains("weight") {
        1.0
    } else {
        sensitivity
    };
    let adapted_ratio = base_ratio * layer_sensitivity;
    self.compress_top_k(gradient, adapted_ratio)
}
/// EF21-style compression: adds a momentum-weighted private residual to the
/// gradient, keeps the top-k entries of the result as a dense vector (zeros
/// elsewhere), and stores the untransmitted remainder as the new residual.
///
/// The residual lives in `error_buffers` under the key `"ef21_<param>"`,
/// separate from the generic error-feedback buffer keyed by `<param>`.
/// NOTE(review): when `config.error_feedback` is also enabled, `compress`
/// wraps this method in a second feedback pass — the residual may be
/// applied twice; confirm whether that interaction is intended.
fn compress_ef21(
&mut self,
gradient: &Tensor,
compression_ratio: f32,
momentum: f32,
param_name: &str,
) -> TorshResult<CompressedGradient> {
let flat_grad = gradient.flatten()?;
let grad_data = flat_grad.to_vec()?;
let original_norm = gradient.norm()?.item()?;
let error_key = format!("ef21_{}", param_name);
// First call for a parameter starts from a zero residual.
let error_feedback = if let Some(prev_error) = self.error_buffers.get(&error_key) {
prev_error.flatten()?.to_vec()?
} else {
vec![0.0; grad_data.len()]
};
// g' = g + momentum * residual
let mut adjusted_grad = Vec::new();
for (&grad_val, &error_val) in grad_data.iter().zip(error_feedback.iter()) {
adjusted_grad.push(grad_val + momentum * error_val);
}
let k_elements = (grad_data.len() as f32 * compression_ratio).ceil() as usize;
// Rank positions by magnitude of the adjusted gradient.
let mut indexed_values: Vec<(usize, f32)> = adjusted_grad
.iter()
.enumerate()
.map(|(i, &v)| (i, v.abs()))
.collect();
indexed_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
// Transmitted vector keeps only the top-k entries; everything that was
// not transmitted remains in the residual for the next round.
let mut compressed_values = vec![0.0; grad_data.len()];
let mut new_error_feedback = adjusted_grad.clone();
for &(idx, _) in indexed_values.iter().take(k_elements) {
compressed_values[idx] = adjusted_grad[idx];
new_error_feedback[idx] = 0.0; }
let error_tensor = Tensor::from_vec(new_error_feedback.clone(), gradient.shape().dims())?;
self.error_buffers.insert(error_key, error_tensor);
debug!(
"EF21 compression: kept {}/{} elements with momentum {}",
k_elements,
grad_data.len(),
momentum
);
Ok(CompressedGradient {
method: CompressionMethod::EF21 {
compression_ratio,
momentum,
},
data: CompressedData::EF21 {
compressed_values,
error_feedback: new_error_feedback,
},
original_shape: gradient.shape().dims().to_vec(),
metadata: CompressionMetadata {
compression_ratio,
error_norm: 0.0,
original_norm,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_secs(),
},
})
}
/// Identity "compression": every element is stored as an explicit
/// index/value pair, used during warmup and for `CompressionMethod::None`.
fn no_compression(
    &self,
    gradient: &Tensor,
    _param_name: &str,
) -> TorshResult<CompressedGradient> {
    let flat = gradient.flatten()?;
    let values = flat.to_vec()?;
    let indices: Vec<usize> = (0..values.len()).collect();
    let original_norm = gradient.norm()?.item()?;
    Ok(CompressedGradient {
        method: CompressionMethod::None,
        data: CompressedData::Sparse { indices, values },
        original_shape: gradient.shape().dims().to_vec(),
        metadata: CompressionMetadata {
            // Nothing was dropped, so the ratio is exactly 1.
            compression_ratio: 1.0,
            error_norm: 0.0,
            original_norm,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("time should be after UNIX_EPOCH")
                .as_secs(),
        },
    })
}
/// Scatters index/value pairs into a zero-filled dense tensor of `shape`.
/// Out-of-range indices are silently skipped.
fn decompress_sparse(
    &self,
    indices: &[usize],
    values: &[f32],
    shape: &[usize],
) -> TorshResult<Tensor> {
    let total_elements: usize = shape.iter().product();
    let mut dense = vec![0.0f32; total_elements];
    for (&idx, &val) in indices.iter().zip(values) {
        if let Some(slot) = dense.get_mut(idx) {
            *slot = val;
        }
    }
    Ok(Tensor::from_vec(dense, shape)?)
}
/// Inverse affine quantization: real = (q - zero_point) * scale.
fn decompress_quantized(
    &self,
    values: &[u8],
    scale: f32,
    zero_point: u8,
    shape: &[usize],
) -> TorshResult<Tensor> {
    let zp = f32::from(zero_point);
    let data: Vec<f32> = values.iter().map(|&q| (f32::from(q) - zp) * scale).collect();
    Ok(Tensor::from_vec(data, shape)?)
}
/// Rebuilds a tensor from sign bits: every element gets the same magnitude
/// (norm / sqrt(n)) with the transmitted sign.
fn decompress_sign_sgd(
    &self,
    signs: &[bool],
    norm: f32,
    shape: &[usize],
) -> TorshResult<Tensor> {
    let total_elements: usize = shape.iter().product();
    // Distribute the stored L2 norm evenly across all elements.
    let magnitude = norm / (total_elements as f32).sqrt();
    let data: Vec<f32> = signs
        .iter()
        .map(|&positive| if positive { magnitude } else { -magnitude })
        .collect();
    Ok(Tensor::from_vec(data, shape)?)
}
/// Reconstructs from the "low-rank" payload.
///
/// NOTE(review): placeholder matching `compress_power_sgd` — the factors
/// are copied back verbatim into the leading positions of the flattened
/// tensor (no matrix product is formed); the tail stays zero.
fn decompress_power_sgd(
&self,
left_factor: &[f32],
right_factor: &[f32],
_rank: usize,
shape: &[usize],
) -> TorshResult<Tensor> {
let total_elements: usize = shape.iter().product();
let mut data = vec![0.0; total_elements];
let left_len = left_factor.len();
let right_len = right_factor.len();
// Left factor fills positions [0, left_len), right factor follows.
for i in 0..total_elements.min(left_len + right_len) {
if i < left_len {
data[i] = left_factor[i];
} else {
data[i] = right_factor[i - left_len];
}
}
Ok(Tensor::from_vec(data, shape)?)
}
/// Restores a sketched tensor. The sketch is a prefix of the flattened
/// tensor (see `compress_sketching`); remaining positions stay zero and the
/// hash parameters are currently unused.
fn decompress_sketching(
    &self,
    sketch: &[f32],
    _hash_a: &[u32],
    _hash_b: &[u32],
    shape: &[usize],
) -> TorshResult<Tensor> {
    let total_elements: usize = shape.iter().product();
    let mut data = vec![0.0f32; total_elements];
    let prefix = sketch.len().min(total_elements);
    data[..prefix].copy_from_slice(&sketch[..prefix]);
    Ok(Tensor::from_vec(data, shape)?)
}
/// Expands ternary codes: real = code (-1/0/+1) * scale.
fn decompress_ternary(
    &self,
    values: &[i8],
    scale: f32,
    shape: &[usize],
) -> TorshResult<Tensor> {
    let data: Vec<f32> = values.iter().map(|&code| f32::from(code) * scale).collect();
    Ok(Tensor::from_vec(data, shape)?)
}
/// Maps each bin index back to its bin center; unknown indices become 0.0.
fn decompress_bimodal(
    &self,
    bin_indices: &[u8],
    bin_centers: &[f32],
    shape: &[usize],
) -> TorshResult<Tensor> {
    let mut data = Vec::with_capacity(bin_indices.len());
    for &bin_idx in bin_indices {
        let center = bin_centers.get(usize::from(bin_idx)).copied();
        data.push(center.unwrap_or(0.0));
    }
    Ok(Tensor::from_vec(data, shape)?)
}
/// Looks up each stored codebook index (carried as f32) in `codebook`;
/// out-of-range indices decode to 0.0.
fn decompress_natural(
    &self,
    values: &[f32],
    codebook: &[f32],
    shape: &[usize],
) -> TorshResult<Tensor> {
    let mut data = Vec::with_capacity(values.len());
    for &stored in values {
        // f32 -> usize cast saturates at 0 for negatives, same as before.
        let idx = stored as usize;
        data.push(if idx < codebook.len() { codebook[idx] } else { 0.0 });
    }
    Ok(Tensor::from_vec(data, shape)?)
}
/// EF21 transmits a dense vector, so decompression is just a reshape.
fn decompress_ef21(&self, compressed_values: &[f32], shape: &[usize]) -> TorshResult<Tensor> {
    let data = compressed_values.to_owned();
    Ok(Tensor::from_vec(data, shape)?)
}
/// Folds one compression result into the running statistics.
///
/// Uses the incremental-mean form (mean += (x - mean)/n), which is
/// algebraically equivalent to the previous rescale-and-divide update but
/// numerically more stable. Also fixes `avg_error_norm`, which existed in
/// `CompressionStats` but was never updated.
fn update_stats(&mut self, compressed: &CompressedGradient, compression_time: f64) {
    self.stats.total_compressions += 1;
    let n = self.stats.total_compressions as f64;
    self.stats.avg_compression_ratio +=
        (compressed.metadata.compression_ratio as f64 - self.stats.avg_compression_ratio) / n;
    self.stats.avg_error_norm +=
        (compressed.metadata.error_norm as f64 - self.stats.avg_error_norm) / n;
    self.stats.compression_time_ms += compression_time;
}
/// Returns a read-only view of the accumulated compression statistics.
pub fn get_stats(&self) -> &CompressionStats {
&self.stats
}
/// Drops all stored residuals (both generic and EF21 buffers), e.g. when
/// restarting training or switching compression methods.
pub fn reset_error_feedback(&mut self) {
self.error_buffers.clear();
}
/// Number of `compress` calls completed so far (drives warmup gating).
pub fn step_count(&self) -> usize {
self.step_count
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compression_config() {
        let config = CompressionConfig::default();
        assert_eq!(config.compression_ratio, 0.1);
        assert!(config.error_feedback);
        assert_eq!(config.warmup_steps, 100);
    }

    #[test]
    fn test_compression_methods() {
        assert_ne!(
            CompressionMethod::TopK { k: 0.1 },
            CompressionMethod::SignSGD
        );
        assert_ne!(
            CompressionMethod::Quantization { bits: 8 },
            CompressionMethod::None
        );
    }

    // None of the tests below contain an `.await`, so the tokio runtime that
    // `#[tokio::test]` previously spun up was pure overhead — a plain
    // `#[test]` (with `Result` return for `?`) is sufficient.
    #[test]
    fn test_gradient_compressor_creation() {
        let config = CompressionConfig::default();
        let compressor = GradientCompressor::new(config);
        assert_eq!(compressor.step_count(), 0);
        assert_eq!(compressor.get_stats().total_compressions, 0);
    }

    #[test]
    fn test_top_k_compression() -> TorshResult<()> {
        let config = CompressionConfig {
            method: CompressionMethod::TopK { k: 0.5 },
            warmup_steps: 0,
            ..Default::default()
        };
        let mut compressor = GradientCompressor::new(config);
        let gradient = torsh_tensor::creation::randn(&[10, 10])?;
        let compressed = compressor.compress(&gradient, "test_param")?;
        match &compressed.data {
            CompressedData::Sparse { indices, values } => {
                assert_eq!(indices.len(), values.len());
                // k = 0.5 of 100 elements.
                assert!(indices.len() <= 50);
            }
            _ => panic!("Expected sparse compression for TopK"),
        }
        let decompressed = compressor.decompress(&compressed)?;
        assert_eq!(decompressed.shape().dims(), gradient.shape().dims());
        Ok(())
    }

    #[test]
    fn test_sign_sgd_compression() -> TorshResult<()> {
        let config = CompressionConfig {
            method: CompressionMethod::SignSGD,
            warmup_steps: 0,
            ..Default::default()
        };
        let mut compressor = GradientCompressor::new(config);
        let gradient = torsh_tensor::creation::randn(&[5, 5])?;
        let compressed = compressor.compress(&gradient, "test_param")?;
        match &compressed.data {
            CompressedData::Signs { signs, norm } => {
                assert_eq!(signs.len(), 25);
                assert!(*norm > 0.0);
            }
            _ => panic!("Expected sign compression for SignSGD"),
        }
        let decompressed = compressor.decompress(&compressed)?;
        assert_eq!(decompressed.shape().dims(), gradient.shape().dims());
        Ok(())
    }

    #[test]
    fn test_quantization_compression() -> TorshResult<()> {
        let config = CompressionConfig {
            method: CompressionMethod::Quantization { bits: 8 },
            warmup_steps: 0,
            ..Default::default()
        };
        let mut compressor = GradientCompressor::new(config);
        let gradient = torsh_tensor::creation::randn(&[4, 4])?;
        let compressed = compressor.compress(&gradient, "test_param")?;
        match &compressed.data {
            CompressedData::Quantized {
                values,
                scale,
                zero_point: _,
            } => {
                assert_eq!(values.len(), 16);
                assert!(*scale > 0.0);
            }
            _ => panic!("Expected quantized compression"),
        }
        let decompressed = compressor.decompress(&compressed)?;
        assert_eq!(decompressed.shape().dims(), gradient.shape().dims());
        Ok(())
    }

    #[test]
    fn test_no_compression() -> TorshResult<()> {
        let config = CompressionConfig {
            method: CompressionMethod::None,
            warmup_steps: 0,
            ..Default::default()
        };
        let mut compressor = GradientCompressor::new(config);
        let gradient = torsh_tensor::creation::randn(&[3, 3])?;
        let compressed = compressor.compress(&gradient, "test_param")?;
        assert_eq!(compressed.metadata.compression_ratio, 1.0);
        let decompressed = compressor.decompress(&compressed)?;
        assert_eq!(decompressed.shape().dims(), gradient.shape().dims());
        Ok(())
    }

    #[test]
    fn test_compression_stats() {
        let stats = CompressionStats {
            total_compressions: 100,
            avg_compression_ratio: 0.25,
            total_communication_reduction: 1024 * 1024,
            avg_error_norm: 0.01,
            compression_time_ms: 250.5,
        };
        assert_eq!(stats.total_compressions, 100);
        assert_eq!(stats.avg_compression_ratio, 0.25);
        assert_eq!(stats.total_communication_reduction, 1024 * 1024);
    }
}