pmetal_distributed/
activation_codec.rs1use half::f16;
7
/// Selects how activation bytes are encoded by [`compress_activation`] and
/// decoded by [`decompress_activation`].
///
/// The default codec is [`ActivationCodec::Float16`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ActivationCodec {
    /// No encoding: bytes are passed through unchanged in both directions.
    None,
    /// Half-precision encoding: `f32` payloads are narrowed to 2-byte `f16`
    /// values on compression and widened back on decompression.
    #[default]
    Float16,
    /// Column-sparse encoding.
    ///
    /// NOTE(review): the codec functions in this module currently treat this
    /// variant as a passthrough (bytes copied unchanged).
    SparseColumns {
        /// Fraction of columns to keep. The unit (percent, per-mille, …) is not
        /// defined in this module — TODO confirm against the producer/consumer.
        keep_ratio: u32,
    },
}
22
23pub fn compress_f32_to_f16(data: &[f32]) -> Vec<u8> {
28 let mut out = Vec::with_capacity(data.len() * 2);
29 for &val in data {
30 let h = f16::from_f32(val);
31 out.extend_from_slice(&h.to_le_bytes());
32 }
33 out
34}
35
36pub fn decompress_f16_to_f32(data: &[u8]) -> Vec<f32> {
41 assert!(
42 data.len().is_multiple_of(2),
43 "fp16 data must be even length"
44 );
45 let mut out = Vec::with_capacity(data.len() / 2);
46 for chunk in data.chunks_exact(2) {
47 let h = f16::from_le_bytes([chunk[0], chunk[1]]);
48 out.push(h.to_f32());
49 }
50 out
51}
52
53pub fn compress_activation(data: &[u8], src_is_f32: bool, codec: ActivationCodec) -> Vec<u8> {
57 match codec {
58 ActivationCodec::None => data.to_vec(),
59 ActivationCodec::Float16 => {
60 if src_is_f32 {
61 let f32_data: Vec<f32> = data
63 .chunks_exact(4)
64 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
65 .collect();
66 compress_f32_to_f16(&f32_data)
67 } else {
68 data.to_vec()
69 }
70 }
71 ActivationCodec::SparseColumns { .. } => {
72 data.to_vec()
74 }
75 }
76}
77
78pub fn decompress_activation(data: &[u8], codec: ActivationCodec, target_is_f32: bool) -> Vec<u8> {
80 match codec {
81 ActivationCodec::None => data.to_vec(),
82 ActivationCodec::Float16 => {
83 if target_is_f32 {
84 let f32_vals = decompress_f16_to_f32(data);
85 let mut bytes = Vec::with_capacity(f32_vals.len() * 4);
86 for val in &f32_vals {
87 bytes.extend_from_slice(&val.to_le_bytes());
88 }
89 bytes
90 } else {
91 data.to_vec()
92 }
93 }
94 ActivationCodec::SparseColumns { .. } => data.to_vec(),
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `vals` as consecutive little-endian `f32` bytes.
    fn f32_bytes(vals: &[f32]) -> Vec<u8> {
        vals.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    #[test]
    fn f16_roundtrip() {
        let original = vec![1.0f32, 2.0, 3.5, -0.5, 0.0, 100.0];
        let compressed = compress_f32_to_f16(&original);
        assert_eq!(compressed.len(), original.len() * 2);

        let decompressed = decompress_f16_to_f32(&compressed);
        assert_eq!(decompressed.len(), original.len());

        for (orig, decomp) in original.iter().zip(decompressed.iter()) {
            let diff = (orig - decomp).abs();
            assert!(diff < 0.1, "f16 roundtrip drift: {orig} -> {decomp}");
        }
    }

    #[test]
    fn activation_float16_roundtrip_exact_values() {
        // All of these are exactly representable in f16 (65504 is f16::MAX),
        // so the round trip must reproduce the input bytes exactly.
        let original = [1.0f32, -2.25, 0.0, 65504.0];
        let raw = f32_bytes(&original);

        let compressed = compress_activation(&raw, true, ActivationCodec::Float16);
        assert_eq!(compressed.len(), raw.len() / 2);

        let restored = decompress_activation(&compressed, ActivationCodec::Float16, true);
        assert_eq!(restored, raw);
    }

    #[test]
    fn float16_without_f32_conversion_is_passthrough() {
        let raw = vec![0x00u8, 0x3c];
        assert_eq!(compress_activation(&raw, false, ActivationCodec::Float16), raw);
        assert_eq!(decompress_activation(&raw, ActivationCodec::Float16, false), raw);
    }

    #[test]
    fn passthrough_codecs_leave_bytes_untouched() {
        let raw = vec![1u8, 2, 3, 4, 5];
        for codec in [
            ActivationCodec::None,
            ActivationCodec::SparseColumns { keep_ratio: 50 },
        ] {
            assert_eq!(compress_activation(&raw, true, codec), raw);
            assert_eq!(decompress_activation(&raw, codec, true), raw);
        }
    }

    #[test]
    fn default_codec_is_float16() {
        assert_eq!(ActivationCodec::default(), ActivationCodec::Float16);
    }

    #[test]
    #[should_panic(expected = "fp16 data must be even length")]
    fn decompress_rejects_odd_length() {
        decompress_f16_to_f32(&[0u8; 3]);
    }
}