use crate::error::{CoreError, CoreResult};
/// Configuration for magnitude-based top-k gradient sparsification.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct TopKConfig {
/// Fraction of entries to keep each round (e.g. 0.01 keeps the top 1%);
/// the effective count is clamped to at least 1 and at most n.
pub k_fraction: f64,
/// When true, entries dropped by compression are accumulated and added
/// back into the gradient on the next round (error feedback).
pub use_error_feedback: bool,
}
impl Default for TopKConfig {
fn default() -> Self {
TopKConfig {
k_fraction: 0.01,
use_error_feedback: true,
}
}
}
/// Sparsifies gradients to their top-k entries by magnitude, optionally
/// carrying the dropped mass forward as error feedback.
pub struct TopKCompressor {
/// Static configuration (k fraction and error-feedback toggle).
config: TopKConfig,
/// Residual mass dropped in previous rounds; same length as the gradient
/// this compressor was initialised for.
error_feedback: Vec<f64>,
}
impl TopKCompressor {
    /// Creates a compressor for gradients of length `n_params`.
    ///
    /// The error-feedback buffer starts zeroed; it accumulates the entries
    /// dropped by each `compress` call when `config.use_error_feedback` is
    /// enabled.
    pub fn new(n_params: usize, config: TopKConfig) -> Self {
        TopKCompressor {
            config,
            error_feedback: vec![0.0; n_params],
        }
    }

    /// Compresses `gradient` to its top-k entries by absolute magnitude.
    ///
    /// Returns `(indices, values)` where `indices` is sorted ascending and
    /// `values[i]` is the (error-compensated, when feedback is enabled)
    /// gradient entry at `indices[i]`.
    ///
    /// # Errors
    /// Returns `ShapeError` when `gradient.len()` differs from the length
    /// this compressor was initialised with.
    pub fn compress(&mut self, gradient: &[f64]) -> CoreResult<(Vec<usize>, Vec<f64>)> {
        let n = gradient.len();
        // Validate the shape before the empty-input shortcut so a gradient
        // of the wrong length is reported instead of silently accepted.
        if n != self.error_feedback.len() {
            return Err(CoreError::ShapeError(crate::error::ErrorContext::new(
                format!(
                    "TopKCompressor: gradient len {} != initialised len {}",
                    n,
                    self.error_feedback.len()
                ),
            )));
        }
        if n == 0 {
            return Ok((vec![], vec![]));
        }
        // Error feedback: add the residual dropped in previous rounds
        // before selecting which entries to transmit.
        let mut g: Vec<f64> = if self.config.use_error_feedback {
            gradient
                .iter()
                .zip(self.error_feedback.iter())
                .map(|(a, b)| a + b)
                .collect()
        } else {
            gradient.to_vec()
        };
        // Keep at least one entry and never more than all of them.
        let k = ((n as f64 * self.config.k_fraction).ceil() as usize)
            .max(1)
            .min(n);
        // Partial selection: O(n) average instead of a full O(n log n)
        // sort — only membership in the top-k matters, not their order.
        let mut order: Vec<usize> = (0..n).collect();
        order.select_nth_unstable_by(k - 1, |&a, &b| {
            g[b].abs()
                .partial_cmp(&g[a].abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let mut indices: Vec<usize> = order[..k].to_vec();
        indices.sort_unstable();
        let values: Vec<f64> = indices.iter().map(|&i| g[i]).collect();
        if self.config.use_error_feedback {
            // Transmitted entries carry no residual; everything else is
            // remembered and compensated for on the next round.
            for &i in &indices {
                g[i] = 0.0;
            }
            self.error_feedback = g;
        }
        Ok((indices, values))
    }

    /// Expands sparse `(indices, values)` back into a dense vector of
    /// length `n_total`, zero-filling every position not transmitted.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when the slice lengths differ or an index
    /// is out of bounds.
    pub fn decompress(indices: &[usize], values: &[f64], n_total: usize) -> CoreResult<Vec<f64>> {
        if indices.len() != values.len() {
            return Err(CoreError::InvalidArgument(crate::error::ErrorContext::new(
                "decompress: indices and values length mismatch",
            )));
        }
        let mut out = vec![0.0f64; n_total];
        for (&i, &v) in indices.iter().zip(values.iter()) {
            if i >= n_total {
                return Err(CoreError::InvalidArgument(crate::error::ErrorContext::new(
                    format!(
                        "decompress: index {} out of bounds for n_total {}",
                        i, n_total
                    ),
                )));
            }
            out[i] = v;
        }
        Ok(out)
    }

    /// Nominal compression ratio `n_total / k` for this configuration.
    ///
    /// Returns 1.0 for `n_total == 0` (nothing to compress); the previous
    /// behaviour divided 0 by 0 and produced NaN.
    pub fn compression_ratio(&self, n_total: usize) -> f64 {
        if n_total == 0 {
            return 1.0;
        }
        let k = ((n_total as f64 * self.config.k_fraction).ceil() as usize)
            .max(1)
            .min(n_total);
        n_total as f64 / k as f64
    }
}
/// Sparsifies gradients by keeping a seeded pseudo-random subset of entries.
pub struct RandomKCompressor {
/// Fraction of entries to keep; the effective count is clamped to
/// [1, n] at compression time.
k_fraction: f64,
}
impl RandomKCompressor {
    /// Builds a compressor that keeps roughly `k_fraction` of the entries.
    pub fn new(k_fraction: f64) -> Self {
        Self { k_fraction }
    }

    /// Selects a pseudo-random subset of k entries, deterministically
    /// derived from `seed`, and returns their sorted indices and values.
    pub fn compress(&self, gradient: &[f64], seed: u64) -> CoreResult<(Vec<usize>, Vec<f64>)> {
        if gradient.is_empty() {
            return Ok((vec![], vec![]));
        }
        let n = gradient.len();
        // Always keep at least one entry, never more than all of them.
        let k = ((n as f64 * self.k_fraction).ceil() as usize).max(1).min(n);
        // Partial Fisher-Yates shuffle driven by a simple LCG: after k
        // swaps, the first k slots of `pool` hold the selected sample.
        const MULTIPLIER: u64 = 6364136223846793005;
        const INCREMENT: u64 = 1442695040888963407;
        let mut state = seed.wrapping_add(1);
        let mut pool: Vec<usize> = (0..n).collect();
        for i in 0..k {
            state = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
            // Use the high bits of the LCG state for better randomness.
            let pick = (state >> 33) as usize % (n - i) + i;
            pool.swap(i, pick);
        }
        let mut chosen = pool[..k].to_vec();
        chosen.sort_unstable();
        let chosen_values: Vec<f64> = chosen.iter().map(|&i| gradient[i]).collect();
        Ok((chosen, chosen_values))
    }

    /// Reconstructs a dense vector of length `n_total` from the sparse
    /// `(indices, values)` pair, zero-filling every unselected position.
    pub fn decompress(indices: &[usize], values: &[f64], n_total: usize) -> CoreResult<Vec<f64>> {
        if indices.len() != values.len() {
            return Err(CoreError::InvalidArgument(crate::error::ErrorContext::new(
                "decompress: indices and values length mismatch",
            )));
        }
        let mut dense = vec![0.0f64; n_total];
        for (&idx, &val) in indices.iter().zip(values.iter()) {
            if idx >= n_total {
                return Err(CoreError::InvalidArgument(crate::error::ErrorContext::new(
                    format!(
                        "decompress: index {} out of bounds for n_total {}",
                        idx, n_total
                    ),
                )));
            }
            dense[idx] = val;
        }
        Ok(dense)
    }
}
/// Sign-based 1-bit gradient quantiser: every entry is reduced to its sign
/// plus one shared scale (the mean absolute value of the gradient).
pub struct OneBitQuantizer;

impl OneBitQuantizer {
    /// Packs the signs of `gradient` into a bitset (bit set = non-negative)
    /// and returns it alongside the mean-absolute-value scale.
    pub fn quantize(gradient: &[f64]) -> CoreResult<(Vec<u64>, f64)> {
        if gradient.is_empty() {
            return Ok((vec![], 0.0));
        }
        let n = gradient.len();
        let scale = gradient.iter().map(|x| x.abs()).sum::<f64>() / n as f64;
        // One u64 word holds 64 sign bits.
        let mut bits = vec![0u64; n.div_ceil(64)];
        for (pos, &value) in gradient.iter().enumerate() {
            if value >= 0.0 {
                bits[pos / 64] |= 1u64 << (pos % 64);
            }
        }
        Ok((bits, scale))
    }

    /// Expands a packed sign bitset back into `n` values of `+scale` /
    /// `-scale`. Panics if `bits` holds fewer than `n` sign bits.
    pub fn dequantize(bits: &[u64], scale: f64, n: usize) -> Vec<f64> {
        let mut out = Vec::with_capacity(n);
        for pos in 0..n {
            let sign_bit = (bits[pos / 64] >> (pos % 64)) & 1;
            out.push(if sign_bit == 1 { scale } else { -scale });
        }
        out
    }

    /// Mean absolute difference between `original` and `quantized`,
    /// truncating to the shorter slice; returns 0.0 for empty `original`.
    pub fn quantization_error(original: &[f64], quantized: &[f64]) -> f64 {
        if original.is_empty() {
            return 0.0;
        }
        let len = original.len().min(quantized.len());
        let mut total = 0.0f64;
        for (a, b) in original[..len].iter().zip(&quantized[..len]) {
            total += (a - b).abs();
        }
        total / len as f64
    }
}
/// Configuration for PowerSGD-style low-rank gradient compression.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct PowerSgdConfig {
/// Target rank of the approximation; clamped to min(m, n) of the matrix.
pub rank: usize,
/// Number of power iterations used to refine the subspace (minimum 1).
pub n_power_iter: usize,
/// NOTE(review): not consulted by `low_rank_compress` in this file —
/// presumably meant for warm-starting factors across steps; confirm usage
/// at call sites.
pub reuse_momentum: bool,
}
impl Default for PowerSgdConfig {
fn default() -> Self {
PowerSgdConfig {
rank: 4,
n_power_iter: 1,
reuse_momentum: true,
}
}
}
/// Computes a rank-r approximation `p * q^T` of the m x n `gradient_matrix`
/// via randomised power iteration (PowerSGD-style).
///
/// Returns `(p, q)` where `p` is m x r and `q` is n x r with orthonormalised
/// columns, so decompression reconstructs the projection of the gradient
/// onto the learned subspace. The random start uses a fixed seed, so the
/// result is deterministic. `config.reuse_momentum` is not consulted here.
///
/// # Errors
/// Returns `InvalidArgument` when the effective rank would be zero.
pub fn low_rank_compress(
gradient_matrix: &[Vec<f64>],
config: &PowerSgdConfig,
) -> CoreResult<(Vec<Vec<f64>>, Vec<Vec<f64>>)> {
let m = gradient_matrix.len();
if m == 0 {
return Ok((vec![], vec![]));
}
let n = gradient_matrix[0].len();
if n == 0 {
// Degenerate width: keep the row count, but carry no data.
return Ok((vec![vec![]; m], vec![]));
}
// The effective rank can never exceed either matrix dimension.
let r = config.rank.min(m.min(n));
if r == 0 {
return Err(CoreError::InvalidArgument(crate::error::ErrorContext::new(
"low_rank_compress: rank must be >= 1",
)));
}
// Fixed-seed LCG start so repeated runs produce identical factors.
let mut q = vec![vec![0.0f64; r]; n];
let mut rng: u64 = 0xDEAD_BEEF_1234_5678;
let lcg_a: u64 = 6364136223846793005;
let lcg_c: u64 = 1442695040888963407;
for row in q.iter_mut() {
for x in row.iter_mut() {
rng = rng.wrapping_mul(lcg_a).wrapping_add(lcg_c);
// Roughly uniform in [-1, 1].
*x = (rng as i64 as f64) / (i64::MAX as f64);
}
}
orthonormalize_cols(&mut q)?;
let mut p = vec![vec![0.0f64; r]; m];
// Power iteration: alternately push q through G and G^T so its column
// space converges toward the top-r right singular vectors of G.
for _ in 0..config.n_power_iter.max(1) {
// p = G * q
for i in 0..m {
for j in 0..r {
p[i][j] = gradient_matrix[i]
.iter()
.enumerate()
.map(|(k, &g)| g * q[k][j])
.sum();
}
}
orthonormalize_cols(&mut p)?;
// q = G^T * p
for k in 0..n {
for j in 0..r {
q[k][j] = gradient_matrix
.iter()
.enumerate()
.map(|(i, row)| row[k] * p[i][j])
.sum();
}
}
orthonormalize_cols(&mut q)?;
}
// Final pass: p = G * q carries the magnitudes (q stays orthonormal), so
// p * q^T equals G projected onto the span of q's columns.
for i in 0..m {
for j in 0..r {
p[i][j] = gradient_matrix[i]
.iter()
.enumerate()
.map(|(k, &g)| g * q[k][j])
.sum();
}
}
Ok((p, q))
}
/// Reconstructs the dense m x n matrix `p * q^T` from its low-rank factors.
///
/// `p` is m x r and `q` is n x r, both row-major; the output is a row-major
/// m x n matrix. An empty `p` yields an empty matrix.
pub fn low_rank_decompress(p: &[Vec<f64>], q: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if p.is_empty() {
        return vec![];
    }
    let r = p[0].len();
    p.iter()
        .map(|p_row| {
            q.iter()
                .map(|q_row| (0..r).map(|j| p_row[j] * q_row[j]).sum())
                .collect()
        })
        .collect()
}
/// Orthonormalises the columns of `mat` (rows x cols, row-major) in place
/// using Gram-Schmidt.
///
/// Near-zero (degenerate) columns are replaced by a standard basis vector
/// that is itself orthogonalised against the previously processed columns,
/// so the output columns stay mutually orthonormal whenever cols <= rows.
fn orthonormalize_cols(mat: &mut Vec<Vec<f64>>) -> CoreResult<()> {
    let rows = mat.len();
    if rows == 0 {
        return Ok(());
    }
    let cols = mat[0].len();
    if cols == 0 {
        return Ok(());
    }
    for j in 0..cols {
        project_out_previous(mat, rows, j);
        let norm = column_norm(mat, rows, j);
        if norm < 1e-10 {
            // Column j collapsed to (numerically) zero. Try each standard
            // basis vector e_i, orthogonalise it against columns 0..j, and
            // keep the first candidate with a non-trivial remainder.
            // (Previously e_i was inserted verbatim behind a |component|
            // <= 0.9 heuristic, which is not orthogonal to the earlier
            // columns in general and broke the orthonormality contract.)
            let mut replaced = false;
            for candidate in 0..rows {
                for i in 0..rows {
                    mat[i][j] = if i == candidate { 1.0 } else { 0.0 };
                }
                project_out_previous(mat, rows, j);
                let cand_norm = column_norm(mat, rows, j);
                if cand_norm > 1e-10 {
                    for i in 0..rows {
                        mat[i][j] /= cand_norm;
                    }
                    replaced = true;
                    break;
                }
            }
            if !replaced {
                // Only reachable when cols > rows (no orthonormal set of
                // this size exists). Keep the old best-effort fallback of
                // a repeated basis vector rather than failing hard.
                for i in 0..rows {
                    mat[i][j] = if i == j % rows { 1.0 } else { 0.0 };
                }
            }
        } else {
            for i in 0..rows {
                mat[i][j] /= norm;
            }
        }
    }
    Ok(())
}

/// Subtracts from column `j` its projections onto columns 0..j.
fn project_out_previous(mat: &mut [Vec<f64>], rows: usize, j: usize) {
    for k in 0..j {
        let dot: f64 = (0..rows).map(|i| mat[i][j] * mat[i][k]).sum();
        for i in 0..rows {
            mat[i][j] -= dot * mat[i][k];
        }
    }
}

/// Euclidean norm of column `j`.
fn column_norm(mat: &[Vec<f64>], rows: usize, j: usize) -> f64 {
    (0..rows).map(|i| mat[i][j] * mat[i][j]).sum::<f64>().sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_topk_keeps_exactly_k_elements() {
        let cfg = TopKConfig {
            k_fraction: 0.25,
            use_error_feedback: false,
        };
        let mut comp = TopKCompressor::new(8, cfg);
        let grad = vec![0.1, 0.5, 0.3, 0.9, 0.2, 0.8, 0.4, 0.6];
        let (indices, values) = comp.compress(&grad).expect("compress failed");
        // ceil(8 * 0.25) = 2; the two largest magnitudes sit at indices 3 and 5.
        assert_eq!(indices.len(), 2);
        assert_eq!(values.len(), 2);
        assert!(indices.contains(&3));
        assert!(indices.contains(&5));
    }

    #[test]
    fn test_topk_error_feedback_reduces_over_rounds() {
        let cfg = TopKConfig {
            k_fraction: 0.5,
            use_error_feedback: true,
        };
        let mut comp = TopKCompressor::new(4, cfg);
        let grad = vec![1.0, 0.1, 0.1, 0.1];
        let (_, v1) = comp.compress(&grad).expect("compress round 1 failed");
        let (_, v2) = comp.compress(&grad).expect("compress round 2 failed");
        let sum1: f64 = v1.iter().map(|x| x.abs()).sum();
        let sum2: f64 = v2.iter().map(|x| x.abs()).sum();
        // Round 1 transmits 1.0 plus one of the 0.1 entries (sum 1.1); the
        // two dropped 0.1s are remembered, so round 2 transmits 1.0 plus a
        // compensated 0.2 entry (sum 1.2) regardless of tie-breaking.
        // (The old assertions only checked `> 0.0`, which never exercised
        // the feedback path at all.)
        assert!((sum1 - 1.1).abs() < 1e-12, "sum1 = {}", sum1);
        assert!((sum2 - 1.2).abs() < 1e-12, "sum2 = {}", sum2);
        assert!(sum2 > sum1);
    }

    #[test]
    fn test_randomk_correct_size() {
        let comp = RandomKCompressor::new(0.1);
        let grad: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let (indices, values) = comp.compress(&grad, 42).expect("compress failed");
        assert_eq!(indices.len(), 10);
        assert_eq!(values.len(), 10);
        // Uniqueness check: sort before dedup — `dedup` only removes
        // *consecutive* duplicates, so it was previously meaningful only by
        // accident of `compress` returning sorted indices.
        let mut sorted = indices.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), indices.len());
        // Every transmitted value must match the gradient at its index.
        for (&i, &v) in indices.iter().zip(values.iter()) {
            assert_eq!(v, i as f64);
        }
    }

    #[test]
    fn test_1bit_quantize_dequantize_preserves_sign() {
        let gradient = vec![-3.0, 1.5, -0.5, 2.0, -0.1, 0.8];
        let (bits, scale) = OneBitQuantizer::quantize(&gradient).expect("quantize failed");
        let dequantized = OneBitQuantizer::dequantize(&bits, scale, gradient.len());
        for (orig, deq) in gradient.iter().zip(dequantized.iter()) {
            let same_sign = (orig >= &0.0 && deq >= &0.0) || (orig < &0.0 && deq < &0.0);
            assert!(same_sign, "sign mismatch: orig={} deq={}", orig, deq);
        }
    }

    #[test]
    fn test_low_rank_compress_decompress_close_for_full_rank() {
        // A rank-2 matrix must be reconstructed (almost) exactly by a
        // rank-2 approximation with enough power iterations.
        let m = 4;
        let n = 4;
        let u1 = [1.0, 2.0, 3.0, 4.0];
        let v1 = [5.0, -1.0, 2.0, 0.5];
        let u2 = [0.5, -1.0, 1.5, -2.0];
        let v2 = [1.0, 3.0, -2.0, 4.0];
        let grad: Vec<Vec<f64>> = (0..m)
            .map(|i| (0..n).map(|j| u1[i] * v1[j] + u2[i] * v2[j]).collect())
            .collect();
        let cfg = PowerSgdConfig {
            rank: 2,
            n_power_iter: 10,
            reuse_momentum: false,
        };
        let (p, q) = low_rank_compress(&grad, &cfg).expect("compress failed");
        let approx = low_rank_decompress(&p, &q);
        let mut max_err = 0.0f64;
        for i in 0..m {
            for j in 0..n {
                let err = (grad[i][j] - approx[i][j]).abs();
                max_err = max_err.max(err);
            }
        }
        assert!(max_err < 1e-6, "max reconstruction error = {}", max_err);
    }

    #[test]
    fn test_powersgd_config_defaults() {
        let cfg = PowerSgdConfig::default();
        assert_eq!(cfg.rank, 4);
        assert_eq!(cfg.n_power_iter, 1);
        assert!(cfg.reuse_momentum);
    }

    #[test]
    fn test_compression_ratio_computation() {
        let cfg = TopKConfig {
            k_fraction: 0.01,
            use_error_feedback: true,
        };
        let comp = TopKCompressor::new(1000, cfg);
        let ratio = comp.compression_ratio(1000);
        assert!((ratio - 100.0).abs() < 1e-9);
    }

    #[test]
    fn test_empty_gradient_handling() {
        let cfg = TopKConfig::default();
        let mut comp = TopKCompressor::new(0, cfg);
        let (idx, val) = comp.compress(&[]).expect("compress empty failed");
        assert!(idx.is_empty());
        assert!(val.is_empty());
        let comp2 = RandomKCompressor::new(0.1);
        let (idx2, val2) = comp2.compress(&[], 0).expect("compress empty failed");
        assert!(idx2.is_empty());
        assert!(val2.is_empty());
        let (bits, scale) = OneBitQuantizer::quantize(&[]).expect("quantize empty failed");
        assert!(bits.is_empty());
        assert_eq!(scale, 0.0);
    }
}