
pmetal_distributed/compression.rs

//! Gradient compression for bandwidth optimization.
//!
//! Provides several compression strategies:
//! - TopK: Keep only the k largest gradients
//! - Random sparsification: Randomly sample gradients
//! - Quantization: Reduce precision (FP16, BF16, INT8)
//! - Error feedback: Accumulate compression errors for future updates
//!
//! References:
//! - Deep Gradient Compression (Lin et al., 2018)
//! - 1-Bit SGD (Seide et al., 2014)
//! - PowerSGD (Vogels et al., 2019)
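//!
//! # Example
//!
//! A minimal sketch of the compress/decompress round trip (the crate path is an
//! assumption here, so the snippet is not compiled as a doctest):
//!
//! ```ignore
//! use pmetal_distributed::compression::{CompressionStrategy, GradientCompressor};
//!
//! let mut compressor = GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.01 }, true);
//! let gradients = vec![0.5f32; 1024];
//! let compressed = compressor.compress(&gradients);
//! let restored = compressor.decompress(&compressed);
//! assert_eq!(restored.len(), gradients.len());
//! ```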

use half::{bf16, f16};
use serde::{Deserialize, Serialize};
use std::collections::BinaryHeap;
use tracing::debug;

/// Compression strategy.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum CompressionStrategy {
    /// No compression.
    #[default]
    None,
    /// Keep only the top `ratio` fraction of gradients by magnitude.
    TopK { ratio: f32 },
    /// Random sparsification with given probability.
    Random { probability: f32 },
    /// Quantize to lower precision.
    Quantize(QuantizationType),
    /// PowerSGD low-rank approximation.
    PowerSGD { rank: usize },
}

/// Quantization type.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum QuantizationType {
    /// FP16 (IEEE half precision).
    FP16,
    /// BF16 (Brain float).
    BF16,
    /// 8-bit integer with scale.
    INT8,
    /// 1-bit sign only (with scale).
    OneBit,
}

/// Compressed gradient representation.
#[derive(Debug, Clone)]
pub struct CompressedGradient {
    /// Original number of elements (f32).
    pub original_size: usize,
    /// Compression strategy used.
    pub strategy: CompressionStrategy,
    /// Compressed data.
    pub data: CompressedData,
}

/// Compressed data variants.
#[derive(Debug, Clone)]
pub enum CompressedData {
    /// Full precision (no compression).
    Full(Vec<f32>),
    /// Sparse representation (indices + values).
    Sparse { indices: Vec<u32>, values: Vec<f32> },
    /// FP16 quantized.
    FP16(Vec<u16>),
    /// BF16 quantized.
    BF16(Vec<u16>),
    /// INT8 quantized with scale.
    INT8 { data: Vec<i8>, scale: f32 },
    /// 1-bit with scale.
    OneBit { signs: Vec<u8>, scale: f32 },
}

impl CompressedGradient {
    /// Get the compression ratio.
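    ///
    /// Illustrative arithmetic (assumed figures, not measured): top-k keeping 1%
    /// of a 1,000,000-element f32 tensor stores ~10,000 (u32 index, f32 value)
    /// pairs, so roughly 4 MB / 80 KB ≈ 50x.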
    pub fn compression_ratio(&self) -> f32 {
        let original_bytes = self.original_size * 4;
        let compressed_bytes = self.compressed_bytes();
        original_bytes as f32 / compressed_bytes as f32
    }

    /// Get compressed size in bytes.
    pub fn compressed_bytes(&self) -> usize {
        match &self.data {
            CompressedData::Full(v) => v.len() * 4,
            CompressedData::Sparse { indices, values } => indices.len() * 4 + values.len() * 4,
            CompressedData::FP16(v) => v.len() * 2,
            CompressedData::BF16(v) => v.len() * 2,
            CompressedData::INT8 { data, .. } => data.len() + 4,
            CompressedData::OneBit { signs, .. } => signs.len() + 4,
        }
    }
}

/// Gradient compressor with error feedback.
pub struct GradientCompressor {
    /// Compression strategy.
    strategy: CompressionStrategy,
    /// Error feedback buffer (accumulated residuals).
    error_feedback: Option<Vec<f32>>,
    /// Whether to use error feedback.
    use_error_feedback: bool,
    /// Random seed for reproducibility.
    rng_seed: u64,
}

impl GradientCompressor {
    /// Create a new compressor.
    pub fn new(strategy: CompressionStrategy, use_error_feedback: bool) -> Self {
        Self {
            strategy,
            error_feedback: None,
            use_error_feedback,
            rng_seed: 42,
        }
    }

    /// Compress gradients.
    pub fn compress(&mut self, gradients: &[f32]) -> CompressedGradient {
        let original_size = gradients.len();

        // Apply error feedback if enabled
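        // Error feedback (following Seide et al., 2014): the residual left over from
        // the previous step is added back to the incoming gradients, so
        // e_t = (g_t + e_{t-1}) - C(g_t + e_{t-1}) is carried forward rather than lost.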
        let working_grads = if self.use_error_feedback {
            if let Some(ref error) = self.error_feedback {
                gradients
                    .iter()
                    .zip(error.iter())
                    .map(|(g, e)| g + e)
                    .collect()
            } else {
                gradients.to_vec()
            }
        } else {
            gradients.to_vec()
        };

        let (data, residual) = match &self.strategy {
            CompressionStrategy::None => (CompressedData::Full(working_grads.clone()), None),
            CompressionStrategy::TopK { ratio } => self.compress_topk(&working_grads, *ratio),
            CompressionStrategy::Random { probability } => {
                self.compress_random(&working_grads, *probability)
            }
            CompressionStrategy::Quantize(qtype) => (self.quantize(&working_grads, *qtype), None),
            CompressionStrategy::PowerSGD { rank: _ } => {
                // PowerSGD requires state across iterations, simplified here
                (CompressedData::Full(working_grads.clone()), None)
            }
        };

        // Store residual for error feedback. Only the sparsification paths return a
        // residual; quantization and pass-through reset the buffer to None.
        if self.use_error_feedback {
            self.error_feedback = residual;
        }

        let result = CompressedGradient {
            original_size,
            strategy: self.strategy.clone(),
            data,
        };

        debug!(
            "Compressed {} floats, ratio={:.2}x",
            original_size,
            result.compression_ratio()
        );

        result
    }


    /// Decompress gradients.
    pub fn decompress(&self, compressed: &CompressedGradient) -> Vec<f32> {
        match &compressed.data {
            CompressedData::Full(v) => v.clone(),
            CompressedData::Sparse { indices, values } => {
                let mut result = vec![0.0f32; compressed.original_size];
                for (&idx, &val) in indices.iter().zip(values.iter()) {
                    result[idx as usize] = val;
                }
                result
            }
            CompressedData::FP16(v) => v.iter().map(|&x| f16::from_bits(x).to_f32()).collect(),
            CompressedData::BF16(v) => v.iter().map(|&x| bf16::from_bits(x).to_f32()).collect(),
            CompressedData::INT8 { data, scale } => {
                data.iter().map(|&x| x as f32 * scale).collect()
            }
            CompressedData::OneBit { signs, scale } => {
                let mut result = Vec::with_capacity(compressed.original_size);
                for byte in signs {
                    for bit in 0..8 {
                        if result.len() >= compressed.original_size {
                            break;
                        }
                        let sign = if (byte >> bit) & 1 == 1 { 1.0 } else { -1.0 };
                        result.push(sign * scale);
                    }
                }
                result
            }
        }
    }

    /// Top-K sparsification.
    fn compress_topk(&self, gradients: &[f32], ratio: f32) -> (CompressedData, Option<Vec<f32>>) {
        let k = ((gradients.len() as f32 * ratio) as usize).max(1);

        // Find top-k by magnitude using a min-heap
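        // (keeping the heap at size k makes this O(n log k) instead of sorting all n
        // elements; std's BinaryHeap is a max-heap, so Reverse turns it into a min-heap
        // whose root is the smallest magnitude kept so far)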
        let mut heap: BinaryHeap<std::cmp::Reverse<(ordered_float::OrderedFloat<f32>, u32)>> =
            BinaryHeap::with_capacity(k + 1);

        for (i, &val) in gradients.iter().enumerate() {
            let abs_val = ordered_float::OrderedFloat(val.abs());
            heap.push(std::cmp::Reverse((abs_val, i as u32)));
            if heap.len() > k {
                heap.pop();
            }
        }

        // Extract indices and values
        let mut indices: Vec<u32> = heap.iter().map(|x| x.0.1).collect();
        indices.sort_unstable();

        let values: Vec<f32> = indices.iter().map(|&i| gradients[i as usize]).collect();

        // Compute residual (unselected values)
        let mut residual = gradients.to_vec();
        for &idx in &indices {
            residual[idx as usize] = 0.0;
        }

        (CompressedData::Sparse { indices, values }, Some(residual))
    }

    /// Random sparsification.
    fn compress_random(
        &mut self,
        gradients: &[f32],
        probability: f32,
    ) -> (CompressedData, Option<Vec<f32>>) {
        let mut indices = Vec::new();
        let mut values = Vec::new();
        let mut residual = gradients.to_vec();

        // Simple PRNG
        let mut rng = self.rng_seed;

        for (i, &val) in gradients.iter().enumerate() {
            // LCG random number generator
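            // (the multiplier is Knuth's MMIX LCG constant; the top 31 bits of the
            // state are scaled to a float in [0, 1] below)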
            rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
            let rand_val = (rng >> 33) as f32 / (u32::MAX >> 1) as f32;

            if rand_val < probability {
                indices.push(i as u32);
                values.push(val / probability); // Scale to maintain expectation
                residual[i] = 0.0;
            }
        }

        self.rng_seed = rng;
        (CompressedData::Sparse { indices, values }, Some(residual))
    }

    /// Quantize gradients.
    fn quantize(&self, gradients: &[f32], qtype: QuantizationType) -> CompressedData {
        match qtype {
            QuantizationType::FP16 => {
                let data: Vec<u16> = gradients
                    .iter()
                    .map(|&x| f16::from_f32(x).to_bits())
                    .collect();
                CompressedData::FP16(data)
            }
            QuantizationType::BF16 => {
                let data: Vec<u16> = gradients
                    .iter()
                    .map(|&x| bf16::from_f32(x).to_bits())
                    .collect();
                CompressedData::BF16(data)
            }
            QuantizationType::INT8 => {
                let max_abs = gradients
                    .iter()
                    .map(|x| x.abs())
                    .fold(0.0f32, |a, b| a.max(b));
                // Guard against an all-zero tensor, where the scale would be 0 and
                // the division below would produce NaN.
                let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };

                let data: Vec<i8> = gradients
                    .iter()
                    .map(|&x| (x / scale).clamp(-127.0, 127.0) as i8)
                    .collect();

                CompressedData::INT8 { data, scale }
            }
            QuantizationType::OneBit => {
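                // 1-bit SGD style: keep only each element's sign bit and rescale by
                // the mean absolute value when decompressing.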
                let mean_abs =
                    gradients.iter().map(|x| x.abs()).sum::<f32>() / gradients.len() as f32;

                let num_bytes = gradients.len().div_ceil(8);
                let mut signs = vec![0u8; num_bytes];

                for (i, &val) in gradients.iter().enumerate() {
                    if val > 0.0 {
                        signs[i / 8] |= 1 << (i % 8);
                    }
                }

                CompressedData::OneBit {
                    signs,
                    scale: mean_abs,
                }
            }
        }
    }

    /// Reset error feedback.
    pub fn reset_error_feedback(&mut self) {
        self.error_feedback = None;
    }
}

/// Serialize compressed gradient to bytes.
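///
/// Wire format (little-endian): `original_size: u32`, then a one-byte tag
/// (0 = Full, 1 = Sparse, 2 = FP16, 3 = BF16, 4 = INT8, 5 = OneBit), then the
/// variant payload; the Sparse payload starts with a `u32` index count.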
pub fn serialize_compressed(compressed: &CompressedGradient) -> Vec<u8> {
    let mut result = Vec::new();

    // Header: original_size (4 bytes) + strategy_id (1 byte)
    result.extend_from_slice(&(compressed.original_size as u32).to_le_bytes());

    match &compressed.data {
        CompressedData::Full(v) => {
            result.push(0u8);
            for f in v {
                result.extend_from_slice(&f.to_le_bytes());
            }
        }
        CompressedData::Sparse { indices, values } => {
            result.push(1u8);
            result.extend_from_slice(&(indices.len() as u32).to_le_bytes());
            for &idx in indices {
                result.extend_from_slice(&idx.to_le_bytes());
            }
            for &val in values {
                result.extend_from_slice(&val.to_le_bytes());
            }
        }
        CompressedData::FP16(v) => {
            result.push(2u8);
            for &x in v {
                result.extend_from_slice(&x.to_le_bytes());
            }
        }
        CompressedData::BF16(v) => {
            result.push(3u8);
            for &x in v {
                result.extend_from_slice(&x.to_le_bytes());
            }
        }
        CompressedData::INT8 { data, scale } => {
            result.push(4u8);
            result.extend_from_slice(&scale.to_le_bytes());
            result.extend(data.iter().map(|&x| x as u8));
        }
        CompressedData::OneBit { signs, scale } => {
            result.push(5u8);
            result.extend_from_slice(&scale.to_le_bytes());
            result.extend_from_slice(signs);
        }
    }

    result
}

/// Deserialize compressed gradient from bytes.
pub fn deserialize_compressed(bytes: &[u8]) -> Option<CompressedGradient> {
    if bytes.len() < 5 {
        return None;
    }

    let original_size = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
    let strategy_id = bytes[4];

    let data = match strategy_id {
        0 => {
            // Full
            let floats: Vec<f32> = bytes[5..]
                .chunks_exact(4)
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect();
            CompressedData::Full(floats)
        }
        1 => {
            // Sparse
            if bytes.len() < 9 {
                return None;
            }
            let num_indices = u32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]) as usize;
            let indices_end = 9 + num_indices * 4;
            // Reject truncated input instead of panicking on the slice below
            if bytes.len() < indices_end {
                return None;
            }
            let indices: Vec<u32> = bytes[9..indices_end]
                .chunks_exact(4)
                .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect();
            let values: Vec<f32> = bytes[indices_end..]
                .chunks_exact(4)
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect();
            CompressedData::Sparse { indices, values }
        }
        2 => {
            // FP16
            let data: Vec<u16> = bytes[5..]
                .chunks_exact(2)
                .map(|c| u16::from_le_bytes([c[0], c[1]]))
                .collect();
            CompressedData::FP16(data)
        }
        3 => {
            // BF16
            let data: Vec<u16> = bytes[5..]
                .chunks_exact(2)
                .map(|c| u16::from_le_bytes([c[0], c[1]]))
                .collect();
            CompressedData::BF16(data)
        }
        4 => {
            // INT8
            if bytes.len() < 9 {
                return None;
            }
            let scale = f32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]);
            let data: Vec<i8> = bytes[9..].iter().map(|&x| x as i8).collect();
            CompressedData::INT8 { data, scale }
        }
        5 => {
            // OneBit
            if bytes.len() < 9 {
                return None;
            }
            let scale = f32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]);
            let signs = bytes[9..].to_vec();
            CompressedData::OneBit { signs, scale }
        }
        _ => return None,
    };

    Some(CompressedGradient {
        original_size,
        strategy: CompressionStrategy::None, // Not tracked in serialization
        data,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_compression() {
        let mut compressor = GradientCompressor::new(CompressionStrategy::None, false);
        let grads = vec![1.0, 2.0, 3.0, 4.0];

        let compressed = compressor.compress(&grads);
        let decompressed = compressor.decompress(&compressed);

        assert_eq!(grads, decompressed);
        assert!((compressed.compression_ratio() - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_topk() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.5 }, false);
        let grads = vec![1.0, 4.0, 2.0, 3.0];

        let compressed = compressor.compress(&grads);
        let decompressed = compressor.decompress(&compressed);

        // Top 50% should be 4.0 and 3.0
        assert!(decompressed[1] == 4.0);
        assert!(decompressed[3] == 3.0);
        assert!(decompressed[0] == 0.0);
        assert!(decompressed[2] == 0.0);

        // With 50% sparsity on 4 elements: 2 indices + 2 values = same as original
        // Compression becomes effective with larger tensors
        assert!(compressed.compression_ratio() >= 1.0);
    }

    #[test]
    fn test_fp16_quantization() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::Quantize(QuantizationType::FP16), false);
        let grads = vec![1.0, 2.5, 3.125, 4.0];

        let compressed = compressor.compress(&grads);
        let decompressed = compressor.decompress(&compressed);

        // FP16 should be approximately equal
        for (orig, decomp) in grads.iter().zip(decompressed.iter()) {
            assert!((orig - decomp).abs() < 0.01);
        }

        // 2x compression ratio
        assert!((compressed.compression_ratio() - 2.0).abs() < 0.1);
    }

    #[test]
    fn test_int8_quantization() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::Quantize(QuantizationType::INT8), false);
        let grads = vec![1.0, 2.0, 3.0, 4.0];

        let compressed = compressor.compress(&grads);
        let decompressed = compressor.decompress(&compressed);

        // INT8 should be approximately equal
        for (orig, decomp) in grads.iter().zip(decompressed.iter()) {
            assert!((orig - decomp).abs() < 0.1);
        }

        // INT8: 4 bytes data + 4 bytes scale = 8 bytes vs 16 bytes original = 2x ratio
        assert!(compressed.compression_ratio() >= 2.0);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.5 }, false);
        let grads = vec![1.0, 4.0, 2.0, 3.0];

        let compressed = compressor.compress(&grads);
        let bytes = serialize_compressed(&compressed);
        let restored = deserialize_compressed(&bytes).unwrap();

        let decompressed = compressor.decompress(&restored);

        // Should match original sparse decompression
        assert!(decompressed[1] == 4.0);
        assert!(decompressed[3] == 3.0);
    }

    #[test]
    fn test_error_feedback() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.5 }, true);

        // First compression - will accumulate residuals
        let grads1 = vec![1.0, 4.0, 2.0, 3.0];
        let _compressed1 = compressor.compress(&grads1);

        // Second compression - should include accumulated error
        let grads2 = vec![0.1, 0.1, 0.1, 0.1];
        let compressed2 = compressor.compress(&grads2);
        let decompressed2 = compressor.decompress(&compressed2);

        // Residual from first (1.0 and 2.0) should be added to second
        // Top-k should now pick the accumulated values
        assert!(decompressed2.iter().any(|&x| x > 1.0));
    }
}