1use half::f16;
3
/// Affine quantization parameters returned by `Quantizer::f32_to_i8` and
/// consumed by `Quantizer::i8_to_f32`.
///
/// A value `x` is encoded as `round(x / scale + zero_point)` and a code `q`
/// is decoded as `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ScalingParams {
    /// Width of one quantization step, in units of the original data.
    pub scale: f32,
    /// Offset added after dividing by `scale`, expressed in quantized units.
    pub zero_point: f32,
}
9
10pub struct Quantizer;
11
12impl Quantizer {
13 pub fn f32_to_f16_bytes(v: &[f32]) -> Vec<u8> {
14 let mut out = Vec::with_capacity(v.len() * 2);
15 for &x in v {
16 out.extend_from_slice(&f16::from_f32(x).to_le_bytes());
17 }
18 out
19 }
20
21 pub fn f16_bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
22 bytes
23 .chunks_exact(2)
24 .map(|b| f16::from_le_bytes([b[0], b[1]]).to_f32())
25 .collect()
26 }
27
28 pub fn f32_to_i8(v: &[f32]) -> (Vec<i8>, ScalingParams) {
29 let min = v.iter().cloned().fold(f32::INFINITY, f32::min);
30 let max = v.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
31 let range = max - min;
32 let scale = if range == 0.0 { 1.0 } else { range / 254.0 };
33 let zero_point = -128.0 - min / scale;
34 let quant = v
35 .iter()
36 .map(|&x| ((x / scale + zero_point).round().clamp(-128.0, 127.0)) as i8)
37 .collect();
38 (quant, ScalingParams { scale, zero_point })
39 }
40
41 pub fn i8_to_f32(v: &[i8], params: &ScalingParams) -> Vec<f32> {
42 v.iter()
43 .map(|&x| (x as f32 - params.zero_point) * params.scale)
44 .collect()
45 }
46}
47
#[cfg(test)]
mod tests {
    use super::*;

    /// Values survive an f32 -> f16 bytes -> f32 round trip within 0.01.
    #[test]
    fn f16_roundtrip() {
        let original: Vec<f32> = vec![0.1, -0.5, 1.0, 0.0, 100.0];
        let bytes = Quantizer::f32_to_f16_bytes(&original);
        assert_eq!(bytes.len(), original.len() * 2);
        let decoded = Quantizer::f16_bytes_to_f32(&bytes);
        // `zip` stops at the shorter side, so check the length explicitly.
        assert_eq!(decoded.len(), original.len());
        for (a, b) in original.iter().zip(decoded.iter()) {
            assert!((a - b).abs() < 0.01, "f16 roundtrip error: {a} vs {b}");
        }
    }

    /// A dangling byte after the last complete pair is ignored by the decoder.
    #[test]
    fn f16_decode_ignores_trailing_byte() {
        // 0x3C00 is 1.0 in half precision; stored little-endian.
        let decoded = Quantizer::f16_bytes_to_f32(&[0x00, 0x3C, 0xFF]);
        assert_eq!(decoded, vec![1.0]);
    }

    /// Values survive an f32 -> i8 -> f32 round trip within 0.02.
    #[test]
    fn i8_roundtrip() {
        let original: Vec<f32> = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let (quant, params) = Quantizer::f32_to_i8(&original);
        assert_eq!(quant.len(), original.len());
        let decoded = Quantizer::i8_to_f32(&quant, &params);
        // `zip` stops at the shorter side, so check the length explicitly.
        assert_eq!(decoded.len(), original.len());
        for (a, b) in original.iter().zip(decoded.iter()) {
            assert!((a - b).abs() < 0.02, "i8 roundtrip error: {a} vs {b}");
        }
    }

    /// A constant vector has zero range and must still decode to its value.
    #[test]
    fn i8_constant_input_roundtrip() {
        let original = [3.5_f32; 4];
        let (quant, params) = Quantizer::f32_to_i8(&original);
        assert_eq!(quant, vec![-128_i8; 4]);
        let decoded = Quantizer::i8_to_f32(&quant, &params);
        assert_eq!(decoded, vec![3.5_f32; 4]);
    }
}