rten_vecmath/
quantize.rs

1use std::mem::MaybeUninit;
2
3use rten_simd::ops::{FloatOps, NarrowSaturate, NumOps};
4use rten_simd::{Isa, SimdOp, SliceWriter};
5
/// Quantizes a slice of `f32` values into 8-bit integers.
///
/// Every output element is produced by:
///
/// ```text
/// y = saturate(round(x * inv_scale) + zero_point)
/// ```
///
/// where `round` rounds to the nearest `i32`, resolving ties towards the
/// even value, and `saturate` converts the `i32` result to the small integer
/// type `To`, saturating at that type's bounds.
pub struct Quantize<'s, 'd, To> {
    /// Values to quantize.
    src: &'s [f32],
    /// Output buffer, one slot per element of `src`.
    dest: &'d mut [MaybeUninit<To>],
    /// Factor each input is multiplied by before rounding.
    inv_scale: f32,
    /// Offset added to each rounded value before saturation.
    zero_point: To,
}

impl<'s, 'd, To> Quantize<'s, 'd, To> {
    /// Creates an operation that quantizes `src` into `dest` using
    /// `inv_scale` and `zero_point`.
    ///
    /// # Panics
    ///
    /// Panics if `src` and `dest` do not have the same length.
    pub fn new(
        src: &'s [f32],
        dest: &'d mut [MaybeUninit<To>],
        inv_scale: f32,
        zero_point: To,
    ) -> Self {
        assert_eq!(src.len(), dest.len());
        Self {
            inv_scale,
            zero_point,
            src,
            dest,
        }
    }
}
37
38impl<'d> SimdOp for Quantize<'_, 'd, u8> {
39    type Output = &'d mut [u8];
40
41    #[inline(always)]
42    fn eval<I: Isa>(self, isa: I) -> Self::Output {
43        let src_ops = isa.f32();
44        let i32_ops = isa.i32();
45
46        let zp_vec = i32_ops.splat(self.zero_point as i32);
47        let scale_vec = src_ops.splat(self.inv_scale);
48        let f32_v_len = src_ops.len();
49
50        // Generate one vector of u8 elements in each iteration by quantizing
51        // 4 vectors of f32 elements.
52        let mut src_chunks = self.src.chunks_exact(f32_v_len * 4);
53        let mut dest_writer = SliceWriter::new(self.dest);
54
55        for src_chunk in src_chunks.by_ref() {
56            let src = src_ops.load_many::<4>(src_chunk);
57            let quant_i32 = src.map(|x| {
58                let y = src_ops.mul(x, scale_vec);
59                let y = src_ops.to_int_round(y);
60                i32_ops.add(y, zp_vec)
61            });
62            let quant_i16_low = i32_ops.narrow_saturate(quant_i32[0], quant_i32[1]);
63            let quant_i16_high = i32_ops.narrow_saturate(quant_i32[2], quant_i32[3]);
64            let quant_u8 = isa.i16().narrow_saturate(quant_i16_low, quant_i16_high);
65            dest_writer.write_vec(isa.u8(), quant_u8);
66        }
67
68        // Quantize tail elements.
69        for src in src_chunks.remainder() {
70            let y = (src * self.inv_scale).round_ties_even() as i32;
71            let y = (y + self.zero_point as i32).clamp(0, u8::MAX as i32);
72            dest_writer.write_scalar(y as u8);
73        }
74
75        dest_writer.into_mut_slice()
76    }
77}
78
#[cfg(test)]
mod tests {
    use rten_simd::ops::NumOps;
    use rten_simd::{Isa, SimdOp};

    use super::Quantize;

    /// Scalar reference for `Quantize`, computed in the `f32` domain.
    ///
    /// Relies on `f32 as u8` being a saturating cast.
    fn reference_quantize(src: &[f32], inv_scale: f32, zero_point: u8) -> Vec<u8> {
        src.iter()
            .map(|x| {
                let tmp = (x * inv_scale).round_ties_even() + zero_point as f32;
                tmp as u8 // Saturating cast
            })
            .collect()
    }

    /// Return number of u8 lanes supported in a SIMD vector.
    fn u8_vec_len() -> usize {
        struct U8VecLen {}
        impl SimdOp for U8VecLen {
            type Output = usize;
            fn eval<I: Isa>(self, isa: I) -> usize {
                isa.u8().len()
            }
        }
        U8VecLen {}.dispatch()
    }

    /// Quantize `src` with the SIMD implementation and return the output.
    fn quantize(src: &[f32], inv_scale: f32, zero_point: u8) -> Vec<u8> {
        let mut buf: Vec<u8> = Vec::with_capacity(src.len());
        // `with_capacity` may allocate more than requested, so trim the spare
        // capacity to exactly `src.len()` to satisfy `Quantize::new`.
        let dest = &mut buf.spare_capacity_mut()[..src.len()];
        Quantize::new(src, dest, inv_scale, zero_point)
            .dispatch()
            .to_vec()
    }

    #[test]
    fn test_quantize() {
        let mut rng = fastrand::Rng::with_seed(1234);
        let v_len = u8_vec_len();
        let inv_scale = 5.2;
        let zero_point = 10;

        // Empty input, tail-only inputs, an exact multiple of the vector
        // length, and vector chunks followed by a tail.
        for len in [0, 1, v_len - 1, v_len, v_len + 1, 3 * v_len + 5] {
            // Inputs in [-10, 60) map to roughly [-42, 322) after scaling and
            // adding the zero point, so both saturation bounds are exercised.
            let src: Vec<f32> = std::iter::from_fn(|| Some(rng.f32() * 70.0 - 10.0))
                .take(len)
                .collect();
            let expected = reference_quantize(&src, inv_scale, zero_point);
            let actual = quantize(&src, inv_scale, zero_point);
            assert_eq!(actual, expected, "mismatch for len {len}");
        }
    }

    #[test]
    fn test_quantize_rounding_and_saturation() {
        // Ties that must round to even, plus values far outside the u8 range.
        let pattern: [f32; 9] = [-1000.0, -1.0, -0.5, 0.5, 1.5, 2.5, 254.5, 255.5, 1000.0];

        // Long enough that the pattern passes through both the vector loop
        // and the scalar tail.
        let len = 3 * u8_vec_len() + pattern.len();
        let src: Vec<f32> = pattern.iter().copied().cycle().take(len).collect();

        let expected = reference_quantize(&src, 1.0, 0);
        assert_eq!(
            expected[..pattern.len()],
            [0u8, 0, 0, 0, 2, 2, 254, 255, 255]
        );
        assert_eq!(quantize(&src, 1.0, 0), expected);
    }
}
125}