Skip to main content

pmetal_distributed/
activation_codec.rs

1//! Activation compression for pipeline inference.
2//!
3//! Reduces bandwidth between pipeline stages by compressing activations.
4//! Initial implementation: fp16 cast (same as dnet default wire dtype).
5
6use half::f16;
7
/// Compression codec for activations transferred between pipeline stages.
///
/// The default codec is [`ActivationCodec::Float16`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ActivationCodec {
    /// No compression — bytes are transferred as-is.
    None,
    /// Cast to fp16 for transfer (2x compression for f32 activations).
    #[default]
    Float16,
    /// Column-wise sparsity: keep the top-k columns by L2 norm.
    ///
    /// Not implemented yet: [`compress_activation`] and [`decompress_activation`]
    /// currently copy the data through unchanged for this variant.
    SparseColumns {
        /// Fraction of columns to keep, expressed in per-mille (thousandths) and
        /// stored as an integer rather than a float: the intended range is
        /// `0..=1000`, e.g. `100` keeps the top 10%.
        keep_ratio: u32,
    },
}
22
23/// Compress f32 activations to fp16 bytes.
24///
25/// Input: `&[f32]` of hidden states
26/// Output: `Vec<u8>` of fp16 bytes (half the size)
27pub fn compress_f32_to_f16(data: &[f32]) -> Vec<u8> {
28    let mut out = Vec::with_capacity(data.len() * 2);
29    for &val in data {
30        let h = f16::from_f32(val);
31        out.extend_from_slice(&h.to_le_bytes());
32    }
33    out
34}
35
36/// Decompress fp16 bytes back to f32.
37///
38/// Input: `&[u8]` of fp16 bytes
39/// Output: `Vec<f32>` of f32 values
40pub fn decompress_f16_to_f32(data: &[u8]) -> Vec<f32> {
41    assert!(
42        data.len().is_multiple_of(2),
43        "fp16 data must be even length"
44    );
45    let mut out = Vec::with_capacity(data.len() / 2);
46    for chunk in data.chunks_exact(2) {
47        let h = f16::from_le_bytes([chunk[0], chunk[1]]);
48        out.push(h.to_f32());
49    }
50    out
51}
52
53/// Compress activations according to the codec.
54///
55/// Returns compressed bytes and a tag indicating the codec used.
56pub fn compress_activation(data: &[u8], src_is_f32: bool, codec: ActivationCodec) -> Vec<u8> {
57    match codec {
58        ActivationCodec::None => data.to_vec(),
59        ActivationCodec::Float16 => {
60            if src_is_f32 {
61                // Reinterpret bytes as f32 via zerocopy-safe conversion
62                let f32_data: Vec<f32> = data
63                    .chunks_exact(4)
64                    .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
65                    .collect();
66                compress_f32_to_f16(&f32_data)
67            } else {
68                data.to_vec()
69            }
70        }
71        ActivationCodec::SparseColumns { .. } => {
72            // Future: column-wise sparsity
73            data.to_vec()
74        }
75    }
76}
77
78/// Decompress activations back to the original format.
79pub fn decompress_activation(data: &[u8], codec: ActivationCodec, target_is_f32: bool) -> Vec<u8> {
80    match codec {
81        ActivationCodec::None => data.to_vec(),
82        ActivationCodec::Float16 => {
83            if target_is_f32 {
84                let f32_vals = decompress_f16_to_f32(data);
85                let mut bytes = Vec::with_capacity(f32_vals.len() * 4);
86                for val in &f32_vals {
87                    bytes.extend_from_slice(&val.to_le_bytes());
88                }
89                bytes
90            } else {
91                data.to_vec()
92            }
93        }
94        ActivationCodec::SparseColumns { .. } => data.to_vec(),
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn f16_roundtrip() {
        // Every value here is exactly representable in fp16, so the roundtrip
        // must be lossless; a tolerance would hide encoding bugs.
        let original = vec![1.0f32, 2.0, 3.5, -0.5, 0.0, 100.0];
        let compressed = compress_f32_to_f16(&original);
        assert_eq!(compressed.len(), original.len() * 2);

        let decompressed = decompress_f16_to_f32(&compressed);
        assert_eq!(decompressed, original);
    }

    #[test]
    fn f16_roundtrip_stays_within_fp16_precision() {
        // These values are not exactly representable in fp16. With 11 significant
        // bits, round-to-nearest keeps the relative error below about 5e-4.
        let original = [0.1f32, 1.0 / 3.0, -123.456];
        let decompressed = decompress_f16_to_f32(&compress_f32_to_f16(&original));
        assert_eq!(decompressed.len(), original.len());
        for (orig, decomp) in original.iter().zip(&decompressed) {
            let rel = ((orig - decomp) / orig).abs();
            assert!(rel < 1e-3, "f16 roundtrip drift: {orig} -> {decomp}");
        }
    }

    #[test]
    fn activation_float16_roundtrip() {
        let values = [1.0f32, -2.5, 0.25, 8.0];
        let raw: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();

        let wire = compress_activation(&raw, true, ActivationCodec::Float16);
        assert_eq!(wire.len(), raw.len() / 2);

        let restored = decompress_activation(&wire, ActivationCodec::Float16, true);
        assert_eq!(restored, raw);
    }

    #[test]
    fn passthrough_codecs_copy_bytes() {
        let raw = [1u8, 2, 3, 4, 5];
        for codec in [
            ActivationCodec::None,
            ActivationCodec::SparseColumns { keep_ratio: 100 },
        ] {
            assert_eq!(compress_activation(&raw, true, codec), raw);
            assert_eq!(decompress_activation(&raw, codec, true), raw);
        }
        // The fp16 codec leaves non-f32 sources/targets untouched.
        assert_eq!(compress_activation(&raw, false, ActivationCodec::Float16), raw);
        assert_eq!(decompress_activation(&raw, ActivationCodec::Float16, false), raw);
    }

    #[test]
    #[should_panic(expected = "fp16 data must be even length")]
    fn decompress_rejects_odd_length() {
        let _ = decompress_f16_to_f32(&[0u8; 3]);
    }

    #[test]
    fn default_codec_is_float16() {
        assert_eq!(ActivationCodec::default(), ActivationCodec::Float16);
    }
}