// ailake_vec/quantize.rs

use half::f16;

/// Affine quantization parameters produced by [`Quantizer::f32_to_i8`] and
/// consumed by [`Quantizer::i8_to_f32`].
///
/// A quantized code `q` is decoded as `(q - zero_point) * scale`; a value `x`
/// is encoded as `round(x / scale + zero_point)` clamped to the `i8` range.
#[derive(Debug, Clone, Copy)]
pub struct ScalingParams {
    /// Width of one quantization step, in the units of the original values.
    pub scale: f32,
    /// Offset added to `x / scale` before rounding.
    pub zero_point: f32,
}

/// Stateless namespace for the vector quantization helpers implemented in
/// `impl Quantizer`; all of them are associated functions.
pub struct Quantizer;

11impl Quantizer {
12    pub fn f32_to_f16_bytes(v: &[f32]) -> Vec<u8> {
13        let mut out = Vec::with_capacity(v.len() * 2);
14        for &x in v {
15            out.extend_from_slice(&f16::from_f32(x).to_le_bytes());
16        }
17        out
18    }
19
20    pub fn f16_bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
21        bytes
22            .chunks_exact(2)
23            .map(|b| f16::from_le_bytes([b[0], b[1]]).to_f32())
24            .collect()
25    }
26
27    pub fn f32_to_i8(v: &[f32]) -> (Vec<i8>, ScalingParams) {
28        let min = v.iter().cloned().fold(f32::INFINITY, f32::min);
29        let max = v.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
30        let range = max - min;
31        let scale = if range == 0.0 { 1.0 } else { range / 254.0 };
32        let zero_point = -128.0 - min / scale;
33        let quant = v
34            .iter()
35            .map(|&x| ((x / scale + zero_point).round().clamp(-128.0, 127.0)) as i8)
36            .collect();
37        (quant, ScalingParams { scale, zero_point })
38    }
39
40    pub fn i8_to_f32(v: &[i8], params: &ScalingParams) -> Vec<f32> {
41        v.iter()
42            .map(|&x| (x as f32 - params.zero_point) * params.scale)
43            .collect()
44    }
45}
46
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` has the same length as `expected` and that every
    /// pair of elements differs by less than `tol`.
    ///
    /// The explicit length check matters: `zip` alone stops at the shorter
    /// side, so a decoder that dropped elements would go unnoticed.
    fn assert_close(expected: &[f32], actual: &[f32], tol: f32, what: &str) {
        assert_eq!(expected.len(), actual.len(), "{what}: length mismatch");
        for (a, b) in expected.iter().zip(actual) {
            assert!((a - b).abs() < tol, "{what} roundtrip error: {a} vs {b}");
        }
    }

    #[test]
    fn f16_roundtrip() {
        let original: Vec<f32> = vec![0.1, -0.5, 1.0, 0.0, 100.0];
        let bytes = Quantizer::f32_to_f16_bytes(&original);
        // Two bytes per half-precision value.
        assert_eq!(bytes.len(), original.len() * 2);
        let decoded = Quantizer::f16_bytes_to_f32(&bytes);
        assert_close(&original, &decoded, 0.01, "f16");
    }

    #[test]
    fn f16_decode_ignores_trailing_odd_byte() {
        let mut bytes = Quantizer::f32_to_f16_bytes(&[1.0]);
        bytes.push(0xFF);
        assert_eq!(Quantizer::f16_bytes_to_f32(&bytes), vec![1.0]);
    }

    #[test]
    fn i8_roundtrip() {
        let original: Vec<f32> = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let (quant, params) = Quantizer::f32_to_i8(&original);
        let decoded = Quantizer::i8_to_f32(&quant, &params);
        assert_close(&original, &decoded, 0.02, "i8");
    }

    #[test]
    fn i8_constant_input_roundtrip() {
        // Exercises the zero-range branch, where the step size falls back to 1.
        let original: Vec<f32> = vec![2.0, 2.0, 2.0];
        let (quant, params) = Quantizer::f32_to_i8(&original);
        assert_eq!(params.scale, 1.0);
        let decoded = Quantizer::i8_to_f32(&quant, &params);
        assert_close(&original, &decoded, 1e-6, "i8 constant");
    }
}