Skip to main content

ailake_vec/
quantize.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use half::f16;
3
/// Affine parameters produced by [`Quantizer::f32_to_i8`] and consumed by
/// [`Quantizer::i8_to_f32`], which decodes a code `q` as `(q - zero_point) * scale`.
///
/// `PartialEq` is derived so parameters can be compared directly (e.g. in tests or when
/// checking that two encoded vectors share the same mapping).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ScalingParams {
    /// Size of one quantization step, in the units of the original `f32` values.
    pub scale: f32,
    /// Offset added in code space before rounding; stored as `f32`, not rounded to an
    /// integer code.
    pub zero_point: f32,
}
9
/// Stateless namespace for this module's vector encodings (half-precision bytes and
/// per-vector affine `i8` quantization); all functionality is exposed as associated functions.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Quantizer;
11
12impl Quantizer {
13    pub fn f32_to_f16_bytes(v: &[f32]) -> Vec<u8> {
14        let mut out = Vec::with_capacity(v.len() * 2);
15        for &x in v {
16            out.extend_from_slice(&f16::from_f32(x).to_le_bytes());
17        }
18        out
19    }
20
21    pub fn f16_bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
22        bytes
23            .chunks_exact(2)
24            .map(|b| f16::from_le_bytes([b[0], b[1]]).to_f32())
25            .collect()
26    }
27
28    pub fn f32_to_i8(v: &[f32]) -> (Vec<i8>, ScalingParams) {
29        let min = v.iter().cloned().fold(f32::INFINITY, f32::min);
30        let max = v.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
31        let range = max - min;
32        let scale = if range == 0.0 { 1.0 } else { range / 254.0 };
33        let zero_point = -128.0 - min / scale;
34        let quant = v
35            .iter()
36            .map(|&x| ((x / scale + zero_point).round().clamp(-128.0, 127.0)) as i8)
37            .collect();
38        (quant, ScalingParams { scale, zero_point })
39    }
40
41    pub fn i8_to_f32(v: &[i8], params: &ScalingParams) -> Vec<f32> {
42        v.iter()
43            .map(|&x| (x as f32 - params.zero_point) * params.scale)
44            .collect()
45    }
46}
47
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn f16_roundtrip() {
        let original: Vec<f32> = vec![0.1, -0.5, 1.0, 0.0, 100.0];
        let bytes = Quantizer::f32_to_f16_bytes(&original);
        assert_eq!(bytes.len(), original.len() * 2);
        let decoded = Quantizer::f16_bytes_to_f32(&bytes);
        // `zip` stops at the shorter side, so a dropped element would otherwise go unnoticed.
        assert_eq!(decoded.len(), original.len());
        for (a, b) in original.iter().zip(decoded.iter()) {
            assert!((a - b).abs() < 0.01, "f16 roundtrip error: {a} vs {b}");
        }
    }

    #[test]
    fn f16_odd_trailing_byte_is_ignored() {
        // Two zero bytes decode to 0.0; the unpaired 0xFF is dropped.
        assert_eq!(Quantizer::f16_bytes_to_f32(&[0, 0, 0xFF]), vec![0.0_f32]);
    }

    #[test]
    fn i8_roundtrip() {
        let original: Vec<f32> = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let (quant, params) = Quantizer::f32_to_i8(&original);
        let decoded = Quantizer::i8_to_f32(&quant, &params);
        // `zip` stops at the shorter side, so a dropped element would otherwise go unnoticed.
        assert_eq!(decoded.len(), original.len());
        for (a, b) in original.iter().zip(decoded.iter()) {
            assert!((a - b).abs() < 0.02, "i8 roundtrip error: {a} vs {b}");
        }
    }

    #[test]
    fn i8_constant_vector_roundtrip() {
        // Zero range: the quantizer falls back to a unit scale, so values come back exactly.
        let original = vec![3.5_f32; 4];
        let (quant, params) = Quantizer::f32_to_i8(&original);
        let decoded = Quantizer::i8_to_f32(&quant, &params);
        assert_eq!(decoded, original);
    }

    #[test]
    fn i8_empty_input() {
        let (quant, params) = Quantizer::f32_to_i8(&[]);
        assert!(quant.is_empty());
        assert!(Quantizer::i8_to_f32(&quant, &params).is_empty());
    }
}