1use half::{bf16, f16};
15use serde::{Deserialize, Serialize};
16use std::collections::BinaryHeap;
17use tracing::debug;
18
/// Strategy used to compress gradients before they are communicated.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum CompressionStrategy {
    /// Send gradients uncompressed (the default).
    #[default]
    None,
    /// Keep only the `ratio` fraction of entries with the largest magnitude.
    TopK { ratio: f32 },
    /// Keep each entry independently with the given probability, rescaling
    /// kept values by `1 / probability`.
    Random { probability: f32 },
    /// Quantize every entry to a lower-precision numeric format.
    Quantize(QuantizationType),
    /// Low-rank PowerSGD approximation.
    /// NOTE(review): not implemented yet — `compress` logs a warning and
    /// falls back to sending full gradients.
    PowerSGD { rank: usize },
}
34
/// Numeric format used by `CompressionStrategy::Quantize`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum QuantizationType {
    /// IEEE 754 half precision (16 bits per value, via the `half` crate).
    FP16,
    /// bfloat16 (16 bits per value, via the `half` crate).
    BF16,
    /// Symmetric 8-bit integers sharing a single f32 scale.
    INT8,
    /// Sign only (1 bit per value) sharing a single f32 scale.
    OneBit,
}
47
/// A gradient vector after compression, plus the metadata needed to
/// decompress it back to a dense `Vec<f32>`.
#[derive(Debug, Clone)]
pub struct CompressedGradient {
    /// Number of f32 elements in the original (uncompressed) gradient.
    pub original_size: usize,
    /// Strategy that produced `data`.
    pub strategy: CompressionStrategy,
    /// The compressed payload.
    pub data: CompressedData,
}
58
/// Compressed representation of a gradient vector.
#[derive(Debug, Clone)]
pub enum CompressedData {
    /// Uncompressed f32 values.
    Full(Vec<f32>),
    /// Sparse representation: element positions and their (possibly rescaled) values.
    Sparse { indices: Vec<u32>, values: Vec<f32> },
    /// IEEE half-precision bit patterns (`f16::to_bits`).
    FP16(Vec<u16>),
    /// bfloat16 bit patterns (`bf16::to_bits`).
    BF16(Vec<u16>),
    /// Linearly quantized values; reconstruct each as `value as f32 * scale`.
    INT8 { data: Vec<i8>, scale: f32 },
    /// Packed sign bits (LSB-first within each byte); reconstruct as `±scale`.
    OneBit { signs: Vec<u8>, scale: f32 },
}
75
76impl CompressedGradient {
77 pub fn compression_ratio(&self) -> f32 {
79 let original_bytes = self.original_size * 4;
80 let compressed_bytes = self.compressed_bytes();
81 original_bytes as f32 / compressed_bytes as f32
82 }
83
84 pub fn compressed_bytes(&self) -> usize {
86 match &self.data {
87 CompressedData::Full(v) => v.len() * 4,
88 CompressedData::Sparse { indices, values } => indices.len() * 4 + values.len() * 4,
89 CompressedData::FP16(v) => v.len() * 2,
90 CompressedData::BF16(v) => v.len() * 2,
91 CompressedData::INT8 { data, .. } => data.len() + 4,
92 CompressedData::OneBit { signs, .. } => signs.len() + 4,
93 }
94 }
95}
96
/// Stateful gradient compressor with optional error feedback.
pub struct GradientCompressor {
    /// Compression strategy applied on every `compress` call.
    strategy: CompressionStrategy,
    /// Residual left over from the previous lossy compression round
    /// (only maintained when `use_error_feedback` is set).
    error_feedback: Option<Vec<f32>>,
    /// Whether to accumulate the compression residual and add it back onto
    /// the next gradient batch.
    use_error_feedback: bool,
    /// State of the internal linear-congruential generator used by random
    /// sparsification (doubles as the seed for the next call).
    rng_seed: u64,
}
108
109impl GradientCompressor {
110 pub fn new(strategy: CompressionStrategy, use_error_feedback: bool) -> Self {
112 Self {
113 strategy,
114 error_feedback: None,
115 use_error_feedback,
116 rng_seed: 42,
117 }
118 }
119
120 pub fn compress(&mut self, gradients: &[f32]) -> CompressedGradient {
122 let original_size = gradients.len();
123
124 let working_grads = if self.use_error_feedback {
126 if let Some(ref error) = self.error_feedback {
127 gradients
128 .iter()
129 .zip(error.iter())
130 .map(|(g, e)| g + e)
131 .collect()
132 } else {
133 gradients.to_vec()
134 }
135 } else {
136 gradients.to_vec()
137 };
138
139 let (data, residual) = match &self.strategy {
140 CompressionStrategy::None => (CompressedData::Full(working_grads.clone()), None),
141 CompressionStrategy::TopK { ratio } => self.compress_topk(&working_grads, *ratio),
142 CompressionStrategy::Random { probability } => {
143 self.compress_random(&working_grads, *probability)
144 }
145 CompressionStrategy::Quantize(qtype) => (self.quantize(&working_grads, *qtype), None),
146 CompressionStrategy::PowerSGD { rank } => {
147 tracing::warn!(
150 rank,
151 "PowerSGD compression not yet implemented; sending uncompressed gradients"
152 );
153 (CompressedData::Full(working_grads.clone()), None)
154 }
155 };
156
157 if self.use_error_feedback {
159 self.error_feedback = residual;
160 }
161
162 let result = CompressedGradient {
163 original_size,
164 strategy: self.strategy.clone(),
165 data,
166 };
167
168 debug!(
169 "Compressed {} floats, ratio={:.2}x",
170 original_size,
171 result.compression_ratio()
172 );
173
174 result
175 }
176
177 pub fn decompress(&self, compressed: &CompressedGradient) -> Vec<f32> {
179 match &compressed.data {
180 CompressedData::Full(v) => v.clone(),
181 CompressedData::Sparse { indices, values } => {
182 let mut result = vec![0.0f32; compressed.original_size];
183 for (&idx, &val) in indices.iter().zip(values.iter()) {
184 result[idx as usize] = val;
185 }
186 result
187 }
188 CompressedData::FP16(v) => v.iter().map(|&x| f16::from_bits(x).to_f32()).collect(),
189 CompressedData::BF16(v) => v.iter().map(|&x| bf16::from_bits(x).to_f32()).collect(),
190 CompressedData::INT8 { data, scale } => {
191 data.iter().map(|&x| x as f32 * scale).collect()
192 }
193 CompressedData::OneBit { signs, scale } => {
194 let mut result = Vec::with_capacity(compressed.original_size);
195 for byte in signs {
196 for bit in 0..8 {
197 if result.len() >= compressed.original_size {
198 break;
199 }
200 let sign = if (byte >> bit) & 1 == 1 { 1.0 } else { -1.0 };
201 result.push(sign * scale);
202 }
203 }
204 result
205 }
206 }
207 }
208
209 fn compress_topk(&self, gradients: &[f32], ratio: f32) -> (CompressedData, Option<Vec<f32>>) {
211 let k = ((gradients.len() as f32 * ratio) as usize).max(1);
212
213 let mut heap: BinaryHeap<std::cmp::Reverse<(ordered_float::OrderedFloat<f32>, u32)>> =
215 BinaryHeap::with_capacity(k + 1);
216
217 for (i, &val) in gradients.iter().enumerate() {
218 let abs_val = ordered_float::OrderedFloat(val.abs());
219 heap.push(std::cmp::Reverse((abs_val, i as u32)));
220 if heap.len() > k {
221 heap.pop();
222 }
223 }
224
225 let mut indices: Vec<u32> = heap.iter().map(|x| x.0.1).collect();
227 indices.sort_unstable();
228
229 let values: Vec<f32> = indices.iter().map(|&i| gradients[i as usize]).collect();
230
231 let mut residual = gradients.to_vec();
233 for &idx in &indices {
234 residual[idx as usize] = 0.0;
235 }
236
237 (CompressedData::Sparse { indices, values }, Some(residual))
238 }
239
240 fn compress_random(
242 &mut self,
243 gradients: &[f32],
244 probability: f32,
245 ) -> (CompressedData, Option<Vec<f32>>) {
246 let mut indices = Vec::new();
247 let mut values = Vec::new();
248 let mut residual = gradients.to_vec();
249
250 let mut rng = self.rng_seed;
252
253 for (i, &val) in gradients.iter().enumerate() {
254 rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
256 let rand_val = (rng >> 33) as f32 / (u32::MAX >> 1) as f32;
257
258 if rand_val < probability {
259 indices.push(i as u32);
260 values.push(val / probability); residual[i] = 0.0;
262 }
263 }
264
265 self.rng_seed = rng;
266 (CompressedData::Sparse { indices, values }, Some(residual))
267 }
268
269 fn quantize(&self, gradients: &[f32], qtype: QuantizationType) -> CompressedData {
271 match qtype {
272 QuantizationType::FP16 => {
273 let data: Vec<u16> = gradients
274 .iter()
275 .map(|&x| f16::from_f32(x).to_bits())
276 .collect();
277 CompressedData::FP16(data)
278 }
279 QuantizationType::BF16 => {
280 let data: Vec<u16> = gradients
281 .iter()
282 .map(|&x| bf16::from_f32(x).to_bits())
283 .collect();
284 CompressedData::BF16(data)
285 }
286 QuantizationType::INT8 => {
287 let max_abs = gradients
288 .iter()
289 .map(|x| x.abs())
290 .fold(0.0f32, |a, b| a.max(b));
291 let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
292
293 let data: Vec<i8> = gradients
294 .iter()
295 .map(|&x| (x / scale).clamp(-127.0, 127.0) as i8)
296 .collect();
297
298 CompressedData::INT8 { data, scale }
299 }
300 QuantizationType::OneBit => {
301 let mean_abs =
302 gradients.iter().map(|x| x.abs()).sum::<f32>() / gradients.len() as f32;
303
304 let num_bytes = gradients.len().div_ceil(8);
305 let mut signs = vec![0u8; num_bytes];
306
307 for (i, &val) in gradients.iter().enumerate() {
308 if val > 0.0 {
309 signs[i / 8] |= 1 << (i % 8);
310 }
311 }
312
313 CompressedData::OneBit {
314 signs,
315 scale: mean_abs,
316 }
317 }
318 }
319 }
320
321 pub fn reset_error_feedback(&mut self) {
323 self.error_feedback = None;
324 }
325}
326
327pub fn serialize_compressed(compressed: &CompressedGradient) -> Vec<u8> {
329 let mut result = Vec::new();
330
331 result.extend_from_slice(&(compressed.original_size as u32).to_le_bytes());
333
334 match &compressed.data {
335 CompressedData::Full(v) => {
336 result.push(0u8);
337 for f in v {
338 result.extend_from_slice(&f.to_le_bytes());
339 }
340 }
341 CompressedData::Sparse { indices, values } => {
342 result.push(1u8);
343 result.extend_from_slice(&(indices.len() as u32).to_le_bytes());
344 for &idx in indices {
345 result.extend_from_slice(&idx.to_le_bytes());
346 }
347 for &val in values {
348 result.extend_from_slice(&val.to_le_bytes());
349 }
350 }
351 CompressedData::FP16(v) => {
352 result.push(2u8);
353 for &x in v {
354 result.extend_from_slice(&x.to_le_bytes());
355 }
356 }
357 CompressedData::BF16(v) => {
358 result.push(3u8);
359 for &x in v {
360 result.extend_from_slice(&x.to_le_bytes());
361 }
362 }
363 CompressedData::INT8 { data, scale } => {
364 result.push(4u8);
365 result.extend_from_slice(&scale.to_le_bytes());
366 result.extend_from_slice(data.iter().map(|&x| x as u8).collect::<Vec<_>>().as_slice());
367 }
368 CompressedData::OneBit { signs, scale } => {
369 result.push(5u8);
370 result.extend_from_slice(&scale.to_le_bytes());
371 result.extend_from_slice(signs);
372 }
373 }
374
375 result
376}
377
378pub fn deserialize_compressed(bytes: &[u8]) -> Option<CompressedGradient> {
380 if bytes.len() < 5 {
381 return None;
382 }
383
384 let original_size = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
385 let strategy_id = bytes[4];
386
387 let data = match strategy_id {
388 0 => {
389 let floats: Vec<f32> = bytes[5..]
391 .chunks_exact(4)
392 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
393 .collect();
394 CompressedData::Full(floats)
395 }
396 1 => {
397 let num_indices = u32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]) as usize;
399 let indices_end = 9 + num_indices * 4;
400 let indices: Vec<u32> = bytes[9..indices_end]
401 .chunks_exact(4)
402 .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
403 .collect();
404 let values: Vec<f32> = bytes[indices_end..]
405 .chunks_exact(4)
406 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
407 .collect();
408 CompressedData::Sparse { indices, values }
409 }
410 2 => {
411 let data: Vec<u16> = bytes[5..]
413 .chunks_exact(2)
414 .map(|c| u16::from_le_bytes([c[0], c[1]]))
415 .collect();
416 CompressedData::FP16(data)
417 }
418 3 => {
419 let data: Vec<u16> = bytes[5..]
421 .chunks_exact(2)
422 .map(|c| u16::from_le_bytes([c[0], c[1]]))
423 .collect();
424 CompressedData::BF16(data)
425 }
426 4 => {
427 let scale = f32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]);
429 let data: Vec<i8> = bytes[9..].iter().map(|&x| x as i8).collect();
430 CompressedData::INT8 { data, scale }
431 }
432 5 => {
433 let scale = f32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]);
435 let signs = bytes[9..].to_vec();
436 CompressedData::OneBit { signs, scale }
437 }
438 _ => return None,
439 };
440
441 Some(CompressedGradient {
442 original_size,
443 strategy: CompressionStrategy::None, data,
445 })
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// The `None` strategy must round-trip exactly, at a ratio of ~1.0.
    #[test]
    fn test_no_compression() {
        let mut compressor = GradientCompressor::new(CompressionStrategy::None, false);
        let input = vec![1.0, 2.0, 3.0, 4.0];

        let packed = compressor.compress(&input);
        let restored = compressor.decompress(&packed);

        assert_eq!(input, restored);
        assert!((packed.compression_ratio() - 1.0).abs() < 0.01);
    }

    /// Top-k at ratio 0.5 keeps the two largest magnitudes and zeroes the rest.
    #[test]
    fn test_topk() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.5 }, false);

        let packed = compressor.compress(&[1.0, 4.0, 2.0, 3.0]);
        let restored = compressor.decompress(&packed);

        assert_eq!(restored, vec![0.0, 4.0, 0.0, 3.0]);
        assert!(packed.compression_ratio() >= 1.0);
    }

    /// FP16 round-trips small values with low error at a ~2x ratio.
    #[test]
    fn test_fp16_quantization() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::Quantize(QuantizationType::FP16), false);
        let input = vec![1.0, 2.5, 3.125, 4.0];

        let packed = compressor.compress(&input);
        let restored = compressor.decompress(&packed);

        for (expected, actual) in input.iter().zip(&restored) {
            assert!((expected - actual).abs() < 0.01);
        }

        assert!((packed.compression_ratio() - 2.0).abs() < 0.1);
    }

    /// INT8 round-trips within the quantization step at >= 2x compression.
    #[test]
    fn test_int8_quantization() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::Quantize(QuantizationType::INT8), false);
        let input = vec![1.0, 2.0, 3.0, 4.0];

        let packed = compressor.compress(&input);
        let restored = compressor.decompress(&packed);

        for (expected, actual) in input.iter().zip(&restored) {
            assert!((expected - actual).abs() < 0.1);
        }

        assert!(packed.compression_ratio() >= 2.0);
    }

    /// Serialize -> deserialize -> decompress preserves the kept top-k values.
    #[test]
    fn test_serialization_roundtrip() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.5 }, false);

        let packed = compressor.compress(&[1.0, 4.0, 2.0, 3.0]);
        let wire = serialize_compressed(&packed);
        let unpacked = deserialize_compressed(&wire).unwrap();

        let restored = compressor.decompress(&unpacked);

        assert_eq!(restored[1], 4.0);
        assert_eq!(restored[3], 3.0);
    }

    /// Residual from round 1 must be folded into round 2's compression, so a
    /// tiny second batch still surfaces a large carried-over value.
    #[test]
    fn test_error_feedback() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.5 }, true);

        let _ = compressor.compress(&[1.0, 4.0, 2.0, 3.0]);

        let packed = compressor.compress(&[0.1, 0.1, 0.1, 0.1]);
        let restored = compressor.decompress(&packed);

        assert!(restored.iter().any(|&x| x > 1.0));
    }
}