1use half::{bf16, f16};
15use serde::{Deserialize, Serialize};
16use std::collections::BinaryHeap;
17use tracing::debug;
18
/// Strategy used to shrink gradients before they are sent over the wire.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum CompressionStrategy {
    /// Send the full f32 gradient unchanged.
    #[default]
    None,
    /// Keep only the `ratio` fraction of largest-magnitude entries.
    TopK { ratio: f32 },
    /// Keep each entry independently with the given probability,
    /// rescaling kept values by 1/probability to stay unbiased in expectation.
    Random { probability: f32 },
    /// Quantize every entry to a lower-precision representation.
    Quantize(QuantizationType),
    /// Low-rank approximation. NOTE(review): not implemented yet —
    /// `compress` currently falls back to sending the full gradient.
    PowerSGD { rank: usize },
}
34
/// Numeric format used by [`CompressionStrategy::Quantize`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum QuantizationType {
    /// IEEE-754 half precision (16 bits per value).
    FP16,
    /// bfloat16 (16 bits; f32 exponent range, reduced mantissa).
    BF16,
    /// Symmetric 8-bit integers with one shared f32 scale.
    INT8,
    /// One sign bit per value plus one shared f32 magnitude.
    OneBit,
}
47
/// A gradient after compression, with the metadata needed to decompress it.
#[derive(Debug, Clone)]
pub struct CompressedGradient {
    /// Number of f32 elements in the original (uncompressed) gradient.
    pub original_size: usize,
    /// Strategy that produced `data`.
    pub strategy: CompressionStrategy,
    /// The compressed payload.
    pub data: CompressedData,
}
58
/// Compressed payload variants; each corresponds to one strategy family.
#[derive(Debug, Clone)]
pub enum CompressedData {
    /// Uncompressed f32 values.
    Full(Vec<f32>),
    /// Sparse representation: parallel arrays of positions and values.
    Sparse { indices: Vec<u32>, values: Vec<f32> },
    /// f16 values stored as raw bit patterns.
    FP16(Vec<u16>),
    /// bf16 values stored as raw bit patterns.
    BF16(Vec<u16>),
    /// Quantized bytes; reconstruct each value as `data[i] as f32 * scale`.
    INT8 { data: Vec<i8>, scale: f32 },
    /// Sign bits packed LSB-first; reconstruct each value as `±scale`.
    OneBit { signs: Vec<u8>, scale: f32 },
}
75
76impl CompressedGradient {
77 pub fn compression_ratio(&self) -> f32 {
79 let original_bytes = self.original_size * 4;
80 let compressed_bytes = self.compressed_bytes();
81 original_bytes as f32 / compressed_bytes as f32
82 }
83
84 pub fn compressed_bytes(&self) -> usize {
86 match &self.data {
87 CompressedData::Full(v) => v.len() * 4,
88 CompressedData::Sparse { indices, values } => indices.len() * 4 + values.len() * 4,
89 CompressedData::FP16(v) => v.len() * 2,
90 CompressedData::BF16(v) => v.len() * 2,
91 CompressedData::INT8 { data, .. } => data.len() + 4,
92 CompressedData::OneBit { signs, .. } => signs.len() + 4,
93 }
94 }
95}
96
/// Stateful gradient compressor with optional error feedback.
pub struct GradientCompressor {
    // Strategy applied on every `compress` call.
    strategy: CompressionStrategy,
    // Residual left behind by the previous lossy compression; added back to
    // the next gradient when error feedback is enabled.
    error_feedback: Option<Vec<f32>>,
    // Whether to accumulate and reapply compression error.
    use_error_feedback: bool,
    // State of the internal LCG used by the Random strategy.
    rng_seed: u64,
}
108
109impl GradientCompressor {
110 pub fn new(strategy: CompressionStrategy, use_error_feedback: bool) -> Self {
112 Self {
113 strategy,
114 error_feedback: None,
115 use_error_feedback,
116 rng_seed: 42,
117 }
118 }
119
120 pub fn compress(&mut self, gradients: &[f32]) -> CompressedGradient {
122 let original_size = gradients.len();
123
124 let working_grads = if self.use_error_feedback {
126 if let Some(ref error) = self.error_feedback {
127 gradients
128 .iter()
129 .zip(error.iter())
130 .map(|(g, e)| g + e)
131 .collect()
132 } else {
133 gradients.to_vec()
134 }
135 } else {
136 gradients.to_vec()
137 };
138
139 let (data, residual) = match &self.strategy {
140 CompressionStrategy::None => (CompressedData::Full(working_grads.clone()), None),
141 CompressionStrategy::TopK { ratio } => self.compress_topk(&working_grads, *ratio),
142 CompressionStrategy::Random { probability } => {
143 self.compress_random(&working_grads, *probability)
144 }
145 CompressionStrategy::Quantize(qtype) => (self.quantize(&working_grads, *qtype), None),
146 CompressionStrategy::PowerSGD { rank: _ } => {
147 (CompressedData::Full(working_grads.clone()), None)
149 }
150 };
151
152 if self.use_error_feedback {
154 self.error_feedback = residual;
155 }
156
157 let result = CompressedGradient {
158 original_size,
159 strategy: self.strategy.clone(),
160 data,
161 };
162
163 debug!(
164 "Compressed {} floats, ratio={:.2}x",
165 original_size,
166 result.compression_ratio()
167 );
168
169 result
170 }
171
172 pub fn decompress(&self, compressed: &CompressedGradient) -> Vec<f32> {
174 match &compressed.data {
175 CompressedData::Full(v) => v.clone(),
176 CompressedData::Sparse { indices, values } => {
177 let mut result = vec![0.0f32; compressed.original_size];
178 for (&idx, &val) in indices.iter().zip(values.iter()) {
179 result[idx as usize] = val;
180 }
181 result
182 }
183 CompressedData::FP16(v) => v.iter().map(|&x| f16::from_bits(x).to_f32()).collect(),
184 CompressedData::BF16(v) => v.iter().map(|&x| bf16::from_bits(x).to_f32()).collect(),
185 CompressedData::INT8 { data, scale } => {
186 data.iter().map(|&x| x as f32 * scale).collect()
187 }
188 CompressedData::OneBit { signs, scale } => {
189 let mut result = Vec::with_capacity(compressed.original_size);
190 for byte in signs {
191 for bit in 0..8 {
192 if result.len() >= compressed.original_size {
193 break;
194 }
195 let sign = if (byte >> bit) & 1 == 1 { 1.0 } else { -1.0 };
196 result.push(sign * scale);
197 }
198 }
199 result
200 }
201 }
202 }
203
204 fn compress_topk(&self, gradients: &[f32], ratio: f32) -> (CompressedData, Option<Vec<f32>>) {
206 let k = ((gradients.len() as f32 * ratio) as usize).max(1);
207
208 let mut heap: BinaryHeap<std::cmp::Reverse<(ordered_float::OrderedFloat<f32>, u32)>> =
210 BinaryHeap::with_capacity(k + 1);
211
212 for (i, &val) in gradients.iter().enumerate() {
213 let abs_val = ordered_float::OrderedFloat(val.abs());
214 heap.push(std::cmp::Reverse((abs_val, i as u32)));
215 if heap.len() > k {
216 heap.pop();
217 }
218 }
219
220 let mut indices: Vec<u32> = heap.iter().map(|x| x.0.1).collect();
222 indices.sort_unstable();
223
224 let values: Vec<f32> = indices.iter().map(|&i| gradients[i as usize]).collect();
225
226 let mut residual = gradients.to_vec();
228 for &idx in &indices {
229 residual[idx as usize] = 0.0;
230 }
231
232 (CompressedData::Sparse { indices, values }, Some(residual))
233 }
234
235 fn compress_random(
237 &mut self,
238 gradients: &[f32],
239 probability: f32,
240 ) -> (CompressedData, Option<Vec<f32>>) {
241 let mut indices = Vec::new();
242 let mut values = Vec::new();
243 let mut residual = gradients.to_vec();
244
245 let mut rng = self.rng_seed;
247
248 for (i, &val) in gradients.iter().enumerate() {
249 rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
251 let rand_val = (rng >> 33) as f32 / (u32::MAX >> 1) as f32;
252
253 if rand_val < probability {
254 indices.push(i as u32);
255 values.push(val / probability); residual[i] = 0.0;
257 }
258 }
259
260 self.rng_seed = rng;
261 (CompressedData::Sparse { indices, values }, Some(residual))
262 }
263
264 fn quantize(&self, gradients: &[f32], qtype: QuantizationType) -> CompressedData {
266 match qtype {
267 QuantizationType::FP16 => {
268 let data: Vec<u16> = gradients
269 .iter()
270 .map(|&x| f16::from_f32(x).to_bits())
271 .collect();
272 CompressedData::FP16(data)
273 }
274 QuantizationType::BF16 => {
275 let data: Vec<u16> = gradients
276 .iter()
277 .map(|&x| bf16::from_f32(x).to_bits())
278 .collect();
279 CompressedData::BF16(data)
280 }
281 QuantizationType::INT8 => {
282 let max_abs = gradients
283 .iter()
284 .map(|x| x.abs())
285 .fold(0.0f32, |a, b| a.max(b));
286 let scale = max_abs / 127.0;
287
288 let data: Vec<i8> = gradients
289 .iter()
290 .map(|&x| (x / scale).clamp(-127.0, 127.0) as i8)
291 .collect();
292
293 CompressedData::INT8 { data, scale }
294 }
295 QuantizationType::OneBit => {
296 let mean_abs =
297 gradients.iter().map(|x| x.abs()).sum::<f32>() / gradients.len() as f32;
298
299 let num_bytes = gradients.len().div_ceil(8);
300 let mut signs = vec![0u8; num_bytes];
301
302 for (i, &val) in gradients.iter().enumerate() {
303 if val > 0.0 {
304 signs[i / 8] |= 1 << (i % 8);
305 }
306 }
307
308 CompressedData::OneBit {
309 signs,
310 scale: mean_abs,
311 }
312 }
313 }
314 }
315
316 pub fn reset_error_feedback(&mut self) {
318 self.error_feedback = None;
319 }
320}
321
322pub fn serialize_compressed(compressed: &CompressedGradient) -> Vec<u8> {
324 let mut result = Vec::new();
325
326 result.extend_from_slice(&(compressed.original_size as u32).to_le_bytes());
328
329 match &compressed.data {
330 CompressedData::Full(v) => {
331 result.push(0u8);
332 for f in v {
333 result.extend_from_slice(&f.to_le_bytes());
334 }
335 }
336 CompressedData::Sparse { indices, values } => {
337 result.push(1u8);
338 result.extend_from_slice(&(indices.len() as u32).to_le_bytes());
339 for &idx in indices {
340 result.extend_from_slice(&idx.to_le_bytes());
341 }
342 for &val in values {
343 result.extend_from_slice(&val.to_le_bytes());
344 }
345 }
346 CompressedData::FP16(v) => {
347 result.push(2u8);
348 for &x in v {
349 result.extend_from_slice(&x.to_le_bytes());
350 }
351 }
352 CompressedData::BF16(v) => {
353 result.push(3u8);
354 for &x in v {
355 result.extend_from_slice(&x.to_le_bytes());
356 }
357 }
358 CompressedData::INT8 { data, scale } => {
359 result.push(4u8);
360 result.extend_from_slice(&scale.to_le_bytes());
361 result.extend_from_slice(data.iter().map(|&x| x as u8).collect::<Vec<_>>().as_slice());
362 }
363 CompressedData::OneBit { signs, scale } => {
364 result.push(5u8);
365 result.extend_from_slice(&scale.to_le_bytes());
366 result.extend_from_slice(signs);
367 }
368 }
369
370 result
371}
372
373pub fn deserialize_compressed(bytes: &[u8]) -> Option<CompressedGradient> {
375 if bytes.len() < 5 {
376 return None;
377 }
378
379 let original_size = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
380 let strategy_id = bytes[4];
381
382 let data = match strategy_id {
383 0 => {
384 let floats: Vec<f32> = bytes[5..]
386 .chunks_exact(4)
387 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
388 .collect();
389 CompressedData::Full(floats)
390 }
391 1 => {
392 let num_indices = u32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]) as usize;
394 let indices_end = 9 + num_indices * 4;
395 let indices: Vec<u32> = bytes[9..indices_end]
396 .chunks_exact(4)
397 .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
398 .collect();
399 let values: Vec<f32> = bytes[indices_end..]
400 .chunks_exact(4)
401 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
402 .collect();
403 CompressedData::Sparse { indices, values }
404 }
405 2 => {
406 let data: Vec<u16> = bytes[5..]
408 .chunks_exact(2)
409 .map(|c| u16::from_le_bytes([c[0], c[1]]))
410 .collect();
411 CompressedData::FP16(data)
412 }
413 3 => {
414 let data: Vec<u16> = bytes[5..]
416 .chunks_exact(2)
417 .map(|c| u16::from_le_bytes([c[0], c[1]]))
418 .collect();
419 CompressedData::BF16(data)
420 }
421 4 => {
422 let scale = f32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]);
424 let data: Vec<i8> = bytes[9..].iter().map(|&x| x as i8).collect();
425 CompressedData::INT8 { data, scale }
426 }
427 5 => {
428 let scale = f32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]);
430 let signs = bytes[9..].to_vec();
431 CompressedData::OneBit { signs, scale }
432 }
433 _ => return None,
434 };
435
436 Some(CompressedGradient {
437 original_size,
438 strategy: CompressionStrategy::None, data,
440 })
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_compression() {
        let mut compressor = GradientCompressor::new(CompressionStrategy::None, false);
        let input = vec![1.0, 2.0, 3.0, 4.0];

        let packed = compressor.compress(&input);
        let restored = compressor.decompress(&packed);

        // The None strategy must round-trip exactly at a ~1x ratio.
        assert_eq!(input, restored);
        assert!((packed.compression_ratio() - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_topk() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.5 }, false);
        let input = vec![1.0, 4.0, 2.0, 3.0];

        let packed = compressor.compress(&input);
        let restored = compressor.decompress(&packed);

        // The two largest-magnitude entries survive; the rest become zero.
        assert_eq!(restored, vec![0.0, 4.0, 0.0, 3.0]);

        assert!(packed.compression_ratio() >= 1.0);
    }

    #[test]
    fn test_fp16_quantization() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::Quantize(QuantizationType::FP16), false);
        let input = vec![1.0, 2.5, 3.125, 4.0];

        let packed = compressor.compress(&input);
        let restored = compressor.decompress(&packed);

        // These values are exactly representable-ish in f16; loss is tiny.
        for (a, b) in input.iter().zip(restored.iter()) {
            assert!((a - b).abs() < 0.01);
        }

        // 32 bits -> 16 bits per value gives a 2x ratio.
        assert!((packed.compression_ratio() - 2.0).abs() < 0.1);
    }

    #[test]
    fn test_int8_quantization() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::Quantize(QuantizationType::INT8), false);
        let input = vec![1.0, 2.0, 3.0, 4.0];

        let packed = compressor.compress(&input);
        let restored = compressor.decompress(&packed);

        for (a, b) in input.iter().zip(restored.iter()) {
            assert!((a - b).abs() < 0.1);
        }

        assert!(packed.compression_ratio() >= 2.0);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.5 }, false);
        let input = vec![1.0, 4.0, 2.0, 3.0];

        let wire = serialize_compressed(&compressor.compress(&input));
        let restored = deserialize_compressed(&wire).unwrap();

        let unpacked = compressor.decompress(&restored);

        assert_eq!(unpacked[1], 4.0);
        assert_eq!(unpacked[3], 3.0);
    }

    #[test]
    fn test_error_feedback() {
        let mut compressor =
            GradientCompressor::new(CompressionStrategy::TopK { ratio: 0.5 }, true);

        // First round drops the two small entries; their error is remembered.
        let _ = compressor.compress(&[1.0, 4.0, 2.0, 3.0]);

        // Second round: the residual is added back, so values larger than the
        // raw 0.1 inputs must appear in the output.
        let packed = compressor.compress(&[0.1, 0.1, 0.1, 0.1]);
        let restored = compressor.decompress(&packed);

        assert!(restored.iter().any(|&x| x > 1.0));
    }
}
550}