1use std::mem::MaybeUninit;
2
3use rten_simd::ops::{FloatOps, NarrowSaturate, NumOps};
4use rten_simd::{Isa, SimdOp, SliceWriter};
5
/// Operation that quantizes a slice of `f32` values into a small integer type.
///
/// This struct only carries the inputs of the operation. The computation
/// itself is supplied by `SimdOp` implementations for concrete `To` types.
pub struct Quantize<'src, 'dst, To> {
    /// Values to be quantized.
    src: &'src [f32],
    /// Output buffer, possibly uninitialized. Has the same length as `src`.
    dest: &'dst mut [MaybeUninit<To>],
    /// Factor each source value is multiplied by before rounding.
    inv_scale: f32,
    /// Offset added to each rounded value before it is narrowed to `To`.
    zero_point: To,
}
20
21impl<'s, 'd, To> Quantize<'s, 'd, To> {
22 pub fn new(
23 src: &'s [f32],
24 dest: &'d mut [MaybeUninit<To>],
25 inv_scale: f32,
26 zero_point: To,
27 ) -> Self {
28 assert_eq!(src.len(), dest.len());
29 Quantize {
30 src,
31 dest,
32 inv_scale,
33 zero_point,
34 }
35 }
36}
37
38impl<'d> SimdOp for Quantize<'_, 'd, u8> {
39 type Output = &'d mut [u8];
40
41 #[inline(always)]
42 fn eval<I: Isa>(self, isa: I) -> Self::Output {
43 let src_ops = isa.f32();
44 let i32_ops = isa.i32();
45
46 let zp_vec = i32_ops.splat(self.zero_point as i32);
47 let scale_vec = src_ops.splat(self.inv_scale);
48 let f32_v_len = src_ops.len();
49
50 let mut src_chunks = self.src.chunks_exact(f32_v_len * 4);
53 let mut dest_writer = SliceWriter::new(self.dest);
54
55 for src_chunk in src_chunks.by_ref() {
56 let src = src_ops.load_many::<4>(src_chunk);
57 let quant_i32 = src.map(|x| {
58 let y = src_ops.mul(x, scale_vec);
59 let y = src_ops.to_int_round(y);
60 i32_ops.add(y, zp_vec)
61 });
62 let quant_i16_low = i32_ops.narrow_saturate(quant_i32[0], quant_i32[1]);
63 let quant_i16_high = i32_ops.narrow_saturate(quant_i32[2], quant_i32[3]);
64 let quant_u8 = isa.i16().narrow_saturate(quant_i16_low, quant_i16_high);
65 dest_writer.write_vec(isa.u8(), quant_u8);
66 }
67
68 for src in src_chunks.remainder() {
70 let y = (src * self.inv_scale).round_ties_even() as i32;
71 let y = (y + self.zero_point as i32).clamp(0, u8::MAX as i32);
72 dest_writer.write_scalar(y as u8);
73 }
74
75 dest_writer.into_mut_slice()
76 }
77}
78
#[cfg(test)]
mod tests {
    use rten_simd::ops::NumOps;
    use rten_simd::{Isa, SimdOp};

    use super::Quantize;

    /// Scalar reference implementation of `u8` quantization that the
    /// `Quantize` operation is compared against.
    fn reference_quantize(src: &[f32], inv_scale: f32, zero_point: u8) -> Vec<u8> {
        src.iter()
            .map(|x| {
                let tmp = (x * inv_scale).round_ties_even() + zero_point as f32;
                // Float-to-int `as` casts saturate, clamping to `0..=255`.
                tmp as u8
            })
            .collect()
    }

    /// Returns the number of `u8` lanes in a vector for the ISA that
    /// `dispatch` selects.
    fn u8_vec_len() -> usize {
        struct U8VecLen {}
        impl SimdOp for U8VecLen {
            type Output = usize;
            fn eval<I: Isa>(self, isa: I) -> usize {
                isa.u8().len()
            }
        }
        U8VecLen {}.dispatch()
    }

    #[test]
    fn test_quantize() {
        let mut rng = fastrand::Rng::with_seed(1234);

        // One element more than a `u8` vector, so the input is not an exact
        // multiple of the vector length. This is intended to exercise both
        // the vectorized loop and the scalar tail.
        let len = u8_vec_len() + 1;
        let src: Vec<f32> = std::iter::from_fn(|| Some(rng.f32())).take(len).collect();
        let inv_scale = 5.2;
        let zero_point = 10;
        let expected = reference_quantize(&src, inv_scale, zero_point);

        let mut buf: Vec<u8> = Vec::with_capacity(src.len());
        // `with_capacity` only guarantees *at least* the requested capacity,
        // so trim the spare capacity to the exact length `Quantize::new`
        // asserts on.
        let actual = &mut buf.spare_capacity_mut()[..src.len()];
        let actual = Quantize::new(&src, actual, inv_scale, zero_point).dispatch();

        assert_eq!(actual, expected);
    }
}