// pmetal_distributed/compression.rs
1//! Gradient compression for bandwidth optimization.
2//!
3//! Provides several compression strategies:
4//! - TopK: Keep only the k largest gradients
5//! - Random sparsification: Randomly sample gradients
6//! - Quantization: Reduce precision (FP16, BF16, INT8)
7//! - Error feedback: Accumulate compression errors for future updates
8//!
9//! References:
10//! - Deep Gradient Compression (Lin et al., 2018)
11//! - 1-Bit SGD (Seide et al., 2014)
12//! - PowerSGD (Vogels et al., 2019)
13
14use half::{bf16, f16};
15use serde::{Deserialize, Serialize};
16use std::collections::BinaryHeap;
17use tracing::debug;
18
/// Compression strategy.
///
/// Selects how [`GradientCompressor`] reduces the gradient payload before
/// transmission. `None` is the `Default` variant.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum CompressionStrategy {
    /// No compression.
    #[default]
    None,
    /// Keep only the top `ratio` fraction of gradients by magnitude
    /// (at least one element is always kept).
    TopK { ratio: f32 },
    /// Random sparsification: each element is kept independently with the
    /// given probability, and kept values are rescaled by 1/probability so
    /// the result stays unbiased in expectation.
    Random { probability: f32 },
    /// Quantize to lower precision.
    Quantize(QuantizationType),
    /// PowerSGD low-rank approximation with the given rank.
    /// NOTE(review): currently falls back to uncompressed in `compress`.
    PowerSGD { rank: usize },
}
34
/// Quantization type.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum QuantizationType {
    /// FP16 (IEEE 754 half precision, 16 bits per element).
    FP16,
    /// BF16 (Brain float: FP32 range, reduced mantissa, 16 bits per element).
    BF16,
    /// 8-bit integer with a single symmetric scale factor.
    INT8,
    /// 1-bit sign only (with a single scale factor).
    OneBit,
}
47
/// Compressed gradient representation.
///
/// Pairs the compressed payload with the metadata needed to restore a dense
/// vector: the original element count and the strategy that produced it.
#[derive(Debug, Clone)]
pub struct CompressedGradient {
    /// Number of f32 elements in the original (uncompressed) gradient.
    pub original_size: usize,
    /// Compression strategy used to produce `data`.
    pub strategy: CompressionStrategy,
    /// Compressed data payload.
    pub data: CompressedData,
}
58
/// Compressed data variants.
#[derive(Debug, Clone)]
pub enum CompressedData {
    /// Full precision (no compression).
    Full(Vec<f32>),
    /// Sparse representation: parallel arrays of element indices (into the
    /// original dense vector) and their corresponding values.
    Sparse { indices: Vec<u32>, values: Vec<f32> },
    /// FP16 quantized, stored as raw IEEE half bit patterns.
    FP16(Vec<u16>),
    /// BF16 quantized, stored as raw bfloat16 bit patterns.
    BF16(Vec<u16>),
    /// INT8 quantized with a single symmetric scale (value ≈ data * scale).
    INT8 { data: Vec<i8>, scale: f32 },
    /// 1-bit signs packed LSB-first, 8 per byte (bit set = positive), with
    /// one shared scale factor.
    OneBit { signs: Vec<u8>, scale: f32 },
}
75
76impl CompressedGradient {
77    /// Get the compression ratio.
78    pub fn compression_ratio(&self) -> f32 {
79        let original_bytes = self.original_size * 4;
80        let compressed_bytes = self.compressed_bytes();
81        original_bytes as f32 / compressed_bytes as f32
82    }
83
84    /// Get compressed size in bytes.
85    pub fn compressed_bytes(&self) -> usize {
86        match &self.data {
87            CompressedData::Full(v) => v.len() * 4,
88            CompressedData::Sparse { indices, values } => indices.len() * 4 + values.len() * 4,
89            CompressedData::FP16(v) => v.len() * 2,
90            CompressedData::BF16(v) => v.len() * 2,
91            CompressedData::INT8 { data, .. } => data.len() + 4,
92            CompressedData::OneBit { signs, .. } => signs.len() + 4,
93        }
94    }
95}
96
/// Gradient compressor with error feedback.
///
/// Wraps a [`CompressionStrategy`] and (optionally) accumulates the residual
/// lost by each compression round so it can be folded into the next one.
pub struct GradientCompressor {
    /// Compression strategy.
    strategy: CompressionStrategy,
    /// Error feedback buffer: residuals accumulated by the last `compress`.
    error_feedback: Option<Vec<f32>>,
    /// Whether to use error feedback.
    use_error_feedback: bool,
    /// PRNG state for random sparsification; starts from a fixed seed for
    /// reproducibility and is advanced on each `Random` compression.
    rng_seed: u64,
}
108
109impl GradientCompressor {
110    /// Create a new compressor.
111    pub fn new(strategy: CompressionStrategy, use_error_feedback: bool) -> Self {
112        Self {
113            strategy,
114            error_feedback: None,
115            use_error_feedback,
116            rng_seed: 42,
117        }
118    }
119
120    /// Compress gradients.
121    pub fn compress(&mut self, gradients: &[f32]) -> CompressedGradient {
122        let original_size = gradients.len();
123
124        // Apply error feedback if enabled
125        let working_grads = if self.use_error_feedback {
126            if let Some(ref error) = self.error_feedback {
127                gradients
128                    .iter()
129                    .zip(error.iter())
130                    .map(|(g, e)| g + e)
131                    .collect()
132            } else {
133                gradients.to_vec()
134            }
135        } else {
136            gradients.to_vec()
137        };
138
139        let (data, residual) = match &self.strategy {
140            CompressionStrategy::None => (CompressedData::Full(working_grads.clone()), None),
141            CompressionStrategy::TopK { ratio } => self.compress_topk(&working_grads, *ratio),
142            CompressionStrategy::Random { probability } => {
143                self.compress_random(&working_grads, *probability)
144            }
145            CompressionStrategy::Quantize(qtype) => (self.quantize(&working_grads, *qtype), None),
146            CompressionStrategy::PowerSGD { rank } => {
147                // PowerSGD is not yet implemented — fall back to uncompressed
148                // with a warning so the caller can see it in logs.
149                tracing::warn!(
150                    rank,
151                    "PowerSGD compression not yet implemented; sending uncompressed gradients"
152                );
153                (CompressedData::Full(working_grads.clone()), None)
154            }
155        };
156
157        // Store residual for error feedback
158        if self.use_error_feedback {
159            self.error_feedback = residual;
160        }
161
162        let result = CompressedGradient {
163            original_size,
164            strategy: self.strategy.clone(),
165            data,
166        };
167
168        debug!(
169            "Compressed {} floats, ratio={:.2}x",
170            original_size,
171            result.compression_ratio()
172        );
173
174        result
175    }
176
177    /// Decompress gradients.
178    pub fn decompress(&self, compressed: &CompressedGradient) -> Vec<f32> {
179        match &compressed.data {
180            CompressedData::Full(v) => v.clone(),
181            CompressedData::Sparse { indices, values } => {
182                let mut result = vec![0.0f32; compressed.original_size];
183                for (&idx, &val) in indices.iter().zip(values.iter()) {
184                    result[idx as usize] = val;
185                }
186                result
187            }
188            CompressedData::FP16(v) => v.iter().map(|&x| f16::from_bits(x).to_f32()).collect(),
189            CompressedData::BF16(v) => v.iter().map(|&x| bf16::from_bits(x).to_f32()).collect(),
190            CompressedData::INT8 { data, scale } => {
191                data.iter().map(|&x| x as f32 * scale).collect()
192            }
193            CompressedData::OneBit { signs, scale } => {
194                let mut result = Vec::with_capacity(compressed.original_size);
195                for byte in signs {
196                    for bit in 0..8 {
197                        if result.len() >= compressed.original_size {
198                            break;
199                        }
200                        let sign = if (byte >> bit) & 1 == 1 { 1.0 } else { -1.0 };
201                        result.push(sign * scale);
202                    }
203                }
204                result
205            }
206        }
207    }
208
209    /// Top-K sparsification.
210    fn compress_topk(&self, gradients: &[f32], ratio: f32) -> (CompressedData, Option<Vec<f32>>) {
211        let k = ((gradients.len() as f32 * ratio) as usize).max(1);
212
213        // Find top-k by magnitude using a min-heap
214        let mut heap: BinaryHeap<std::cmp::Reverse<(ordered_float::OrderedFloat<f32>, u32)>> =
215            BinaryHeap::with_capacity(k + 1);
216
217        for (i, &val) in gradients.iter().enumerate() {
218            let abs_val = ordered_float::OrderedFloat(val.abs());
219            heap.push(std::cmp::Reverse((abs_val, i as u32)));
220            if heap.len() > k {
221                heap.pop();
222            }
223        }
224
225        // Extract indices and values
226        let mut indices: Vec<u32> = heap.iter().map(|x| x.0.1).collect();
227        indices.sort_unstable();
228
229        let values: Vec<f32> = indices.iter().map(|&i| gradients[i as usize]).collect();
230
231        // Compute residual (unselected values)
232        let mut residual = gradients.to_vec();
233        for &idx in &indices {
234            residual[idx as usize] = 0.0;
235        }
236
237        (CompressedData::Sparse { indices, values }, Some(residual))
238    }
239
240    /// Random sparsification.
241    fn compress_random(
242        &mut self,
243        gradients: &[f32],
244        probability: f32,
245    ) -> (CompressedData, Option<Vec<f32>>) {
246        let mut indices = Vec::new();
247        let mut values = Vec::new();
248        let mut residual = gradients.to_vec();
249
250        // Simple PRNG
251        let mut rng = self.rng_seed;
252
253        for (i, &val) in gradients.iter().enumerate() {
254            // LCG random number generator
255            rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
256            let rand_val = (rng >> 33) as f32 / (u32::MAX >> 1) as f32;
257
258            if rand_val < probability {
259                indices.push(i as u32);
260                values.push(val / probability); // Scale to maintain expectation
261                residual[i] = 0.0;
262            }
263        }
264
265        self.rng_seed = rng;
266        (CompressedData::Sparse { indices, values }, Some(residual))
267    }
268
269    /// Quantize gradients.
270    fn quantize(&self, gradients: &[f32], qtype: QuantizationType) -> CompressedData {
271        match qtype {
272            QuantizationType::FP16 => {
273                let data: Vec<u16> = gradients
274                    .iter()
275                    .map(|&x| f16::from_f32(x).to_bits())
276                    .collect();
277                CompressedData::FP16(data)
278            }
279            QuantizationType::BF16 => {
280                let data: Vec<u16> = gradients
281                    .iter()
282                    .map(|&x| bf16::from_f32(x).to_bits())
283                    .collect();
284                CompressedData::BF16(data)
285            }
286            QuantizationType::INT8 => {
287                let max_abs = gradients
288                    .iter()
289                    .map(|x| x.abs())
290                    .fold(0.0f32, |a, b| a.max(b));
291                let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
292
293                let data: Vec<i8> = gradients
294                    .iter()
295                    .map(|&x| (x / scale).clamp(-127.0, 127.0) as i8)
296                    .collect();
297
298                CompressedData::INT8 { data, scale }
299            }
300            QuantizationType::OneBit => {
301                let mean_abs =
302                    gradients.iter().map(|x| x.abs()).sum::<f32>() / gradients.len() as f32;
303
304                let num_bytes = gradients.len().div_ceil(8);
305                let mut signs = vec![0u8; num_bytes];
306
307                for (i, &val) in gradients.iter().enumerate() {
308                    if val > 0.0 {
309                        signs[i / 8] |= 1 << (i % 8);
310                    }
311                }
312
313                CompressedData::OneBit {
314                    signs,
315                    scale: mean_abs,
316                }
317            }
318        }
319    }
320
321    /// Reset error feedback.
322    pub fn reset_error_feedback(&mut self) {
323        self.error_feedback = None;
324    }
325}
326
327/// Serialize compressed gradient to bytes.
328pub fn serialize_compressed(compressed: &CompressedGradient) -> Vec<u8> {
329    let mut result = Vec::new();
330
331    // Header: original_size (4 bytes) + strategy_id (1 byte)
332    result.extend_from_slice(&(compressed.original_size as u32).to_le_bytes());
333
334    match &compressed.data {
335        CompressedData::Full(v) => {
336            result.push(0u8);
337            for f in v {
338                result.extend_from_slice(&f.to_le_bytes());
339            }
340        }
341        CompressedData::Sparse { indices, values } => {
342            result.push(1u8);
343            result.extend_from_slice(&(indices.len() as u32).to_le_bytes());
344            for &idx in indices {
345                result.extend_from_slice(&idx.to_le_bytes());
346            }
347            for &val in values {
348                result.extend_from_slice(&val.to_le_bytes());
349            }
350        }
351        CompressedData::FP16(v) => {
352            result.push(2u8);
353            for &x in v {
354                result.extend_from_slice(&x.to_le_bytes());
355            }
356        }
357        CompressedData::BF16(v) => {
358            result.push(3u8);
359            for &x in v {
360                result.extend_from_slice(&x.to_le_bytes());
361            }
362        }
363        CompressedData::INT8 { data, scale } => {
364            result.push(4u8);
365            result.extend_from_slice(&scale.to_le_bytes());
366            result.extend_from_slice(data.iter().map(|&x| x as u8).collect::<Vec<_>>().as_slice());
367        }
368        CompressedData::OneBit { signs, scale } => {
369            result.push(5u8);
370            result.extend_from_slice(&scale.to_le_bytes());
371            result.extend_from_slice(signs);
372        }
373    }
374
375    result
376}
377
378/// Deserialize compressed gradient from bytes.
379pub fn deserialize_compressed(bytes: &[u8]) -> Option<CompressedGradient> {
380    if bytes.len() < 5 {
381        return None;
382    }
383
384    let original_size = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
385    let strategy_id = bytes[4];
386
387    let data = match strategy_id {
388        0 => {
389            // Full
390            let floats: Vec<f32> = bytes[5..]
391                .chunks_exact(4)
392                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
393                .collect();
394            CompressedData::Full(floats)
395        }
396        1 => {
397            // Sparse
398            let num_indices = u32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]) as usize;
399            let indices_end = 9 + num_indices * 4;
400            let indices: Vec<u32> = bytes[9..indices_end]
401                .chunks_exact(4)
402                .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
403                .collect();
404            let values: Vec<f32> = bytes[indices_end..]
405                .chunks_exact(4)
406                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
407                .collect();
408            CompressedData::Sparse { indices, values }
409        }
410        2 => {
411            // FP16
412            let data: Vec<u16> = bytes[5..]
413                .chunks_exact(2)
414                .map(|c| u16::from_le_bytes([c[0], c[1]]))
415                .collect();
416            CompressedData::FP16(data)
417        }
418        3 => {
419            // BF16
420            let data: Vec<u16> = bytes[5..]
421                .chunks_exact(2)
422                .map(|c| u16::from_le_bytes([c[0], c[1]]))
423                .collect();
424            CompressedData::BF16(data)
425        }
426        4 => {
427            // INT8
428            let scale = f32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]);
429            let data: Vec<i8> = bytes[9..].iter().map(|&x| x as i8).collect();
430            CompressedData::INT8 { data, scale }
431        }
432        5 => {
433            // OneBit
434            let scale = f32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]);
435            let signs = bytes[9..].to_vec();
436            CompressedData::OneBit { signs, scale }
437        }
438        _ => return None,
439    };
440
441    Some(CompressedGradient {
442        original_size,
443        strategy: CompressionStrategy::None, // Not tracked in serialization
444        data,
445    })
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_compression() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mut compressor = GradientCompressor::new(CompressionStrategy::None, false);

        let packed = compressor.compress(&input);
        let restored = compressor.decompress(&packed);

        // Identity roundtrip at roughly 1x ratio.
        assert_eq!(input, restored);
        assert!((packed.compression_ratio() - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_topk() {
        let input = vec![1.0, 4.0, 2.0, 3.0];
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.5 }, false);

        let packed = compressor.compress(&input);
        let restored = compressor.decompress(&packed);

        // The two largest magnitudes (4.0 at index 1, 3.0 at index 3)
        // survive; everything else is zeroed out.
        assert_eq!(restored, vec![0.0, 4.0, 0.0, 3.0]);

        // With 50% sparsity on 4 elements: 2 indices + 2 values = same size
        // as the original. Compression only pays off on larger tensors.
        assert!(packed.compression_ratio() >= 1.0);
    }

    #[test]
    fn test_fp16_quantization() {
        let input = vec![1.0, 2.5, 3.125, 4.0];
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::Quantize(QuantizationType::FP16), false);

        let packed = compressor.compress(&input);
        let restored = compressor.decompress(&packed);

        // FP16 roundtrip should be close to the original values.
        for (orig, got) in input.iter().zip(restored.iter()) {
            assert!((orig - got).abs() < 0.01);
        }

        // 16-bit storage versus 32-bit input: about 2x.
        assert!((packed.compression_ratio() - 2.0).abs() < 0.1);
    }

    #[test]
    fn test_int8_quantization() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::Quantize(QuantizationType::INT8), false);

        let packed = compressor.compress(&input);
        let restored = compressor.decompress(&packed);

        // INT8 roundtrip should be close to the original values.
        for (orig, got) in input.iter().zip(restored.iter()) {
            assert!((orig - got).abs() < 0.1);
        }

        // 4 bytes of data + 4 bytes of scale = 8 bytes vs 16 original: >= 2x.
        assert!(packed.compression_ratio() >= 2.0);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let input = vec![1.0, 4.0, 2.0, 3.0];
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.5 }, false);

        let packed = compressor.compress(&input);
        let wire = serialize_compressed(&packed);
        let unpacked = deserialize_compressed(&wire).unwrap();

        let restored = compressor.decompress(&unpacked);

        // The sparse payload survives the byte roundtrip intact.
        assert_eq!(restored[1], 4.0);
        assert_eq!(restored[3], 3.0);
    }

    #[test]
    fn test_error_feedback() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.5 }, true);

        // Round 1: the dropped coordinates (1.0 and 2.0) become residuals.
        let round1 = vec![1.0, 4.0, 2.0, 3.0];
        let _ = compressor.compress(&round1);

        // Round 2: tiny gradients, so the accumulated residual dominates
        // and top-k now selects the error-boosted coordinates.
        let round2 = vec![0.1, 0.1, 0.1, 0.1];
        let packed = compressor.compress(&round2);
        let restored = compressor.decompress(&packed);

        assert!(restored.iter().any(|&x| x > 1.0));
    }
}