1use crate::accumulator::BinnedAccumulatorF64;
15
/// Affine quantization parameters for 8-bit signed values.
///
/// A quantized value `q` maps back to a real value via
/// `real = scale * (q - zero_point)`.
#[derive(Debug, Clone, Copy)]
pub struct QuantParamsI8 {
    /// Multiplicative step size between adjacent quantized levels.
    pub scale: f64,
    /// Quantized value that represents real 0.0.
    pub zero_point: i8,
}

impl QuantParamsI8 {
    /// Builds parameters from a scale and zero point.
    pub fn new(scale: f64, zero_point: i8) -> Self {
        Self { scale, zero_point }
    }

    /// Maps a single quantized value back to its real-valued form.
    ///
    /// The subtraction is done in i32, which holds any i8 difference
    /// exactly, and the i32 -> f64 conversion is lossless.
    #[inline]
    pub fn dequantize(&self, v: i8) -> f64 {
        let centered = i32::from(v) - i32::from(self.zero_point);
        self.scale * f64::from(centered)
    }

    /// Dequantizes every element of `src` into a freshly allocated Vec.
    pub fn dequantize_slice(&self, src: &[i8]) -> Vec<f64> {
        let mut out = Vec::with_capacity(src.len());
        for &v in src {
            out.push(self.dequantize(v));
        }
        out
    }
}
49
/// Affine quantization parameters for 4-bit signed values
/// (two nibbles packed per byte, range [-8, 7]).
#[derive(Debug, Clone, Copy)]
pub struct QuantParamsI4 {
    /// Multiplicative step size between adjacent quantized levels.
    pub scale: f64,
    /// Quantized value that represents real 0.0; must fit in 4 bits.
    pub zero_point: i8,
}

impl QuantParamsI4 {
    /// Builds parameters, validating that the zero point is a legal
    /// 4-bit signed value.
    ///
    /// # Panics
    /// Panics if `zero_point` lies outside [-8, 7].
    pub fn new(scale: f64, zero_point: i8) -> Self {
        assert!((-8..=7).contains(&zero_point), "i4 zero_point must be in [-8, 7]");
        Self { scale, zero_point }
    }

    /// Splits a packed byte into its (high, low) signed nibbles.
    ///
    /// The arithmetic right shift on i8 sign-extends, so each nibble
    /// comes back in [-8, 7].
    #[inline]
    pub fn unpack_byte(byte: u8) -> (i8, i8) {
        let hi = (byte as i8) >> 4;
        let lo = ((byte << 4) as i8) >> 4;
        (hi, lo)
    }

    /// Maps a single (already unpacked) quantized nibble back to its
    /// real-valued form: `scale * (v - zero_point)`.
    #[inline]
    pub fn dequantize(&self, v: i8) -> f64 {
        self.scale * (v as i64 - self.zero_point as i64) as f64
    }
}
83
/// Multiplies two i8 values in a widened i32.
///
/// No clamping is required: |a * b| is at most 128 * 128 = 16384,
/// which always fits in i32, so the product is exact.
#[inline]
pub fn saturating_mul_i8(a: i8, b: i8) -> i32 {
    i32::from(a) * i32::from(b)
}
94
/// Integer dot product of two i8 slices, accumulated in i32 with
/// saturation: once the running sum hits `i32::MAX`/`i32::MIN` it
/// clamps instead of wrapping.
///
/// Each per-element product fits i32 exactly (|a*b| <= 16384); only
/// the accumulation can saturate. Slices must be the same length
/// (checked in debug builds); extra elements of a longer slice are
/// ignored in release builds.
#[inline]
pub fn saturating_dot_i8(a: &[i8], b: &[i8]) -> i32 {
    debug_assert_eq!(a.len(), b.len());
    // zip avoids per-element bounds checks that indexed access pays.
    a.iter()
        .zip(b.iter())
        .fold(0i32, |sum, (&x, &y)| {
            sum.saturating_add(i32::from(x) * i32::from(y))
        })
}
108
109pub fn quantized_matmul_i8(
128 a: &[i8], b: &[i8], out: &mut [f64],
129 m: usize, k: usize, n: usize,
130 params_a: &QuantParamsI8, params_b: &QuantParamsI8,
131) {
132 debug_assert_eq!(a.len(), m * k);
133 debug_assert_eq!(b.len(), k * n);
134 debug_assert_eq!(out.len(), m * n);
135
136 let combined_scale = params_a.scale * params_b.scale;
138
139 for i in 0..m {
140 for j in 0..n {
141 let mut acc = BinnedAccumulatorF64::new();
142 for p in 0..k {
143 let int_prod = (a[i * k + p] as i64 - params_a.zero_point as i64)
145 * (b[p * n + j] as i64 - params_b.zero_point as i64);
146 acc.add(combined_scale * int_prod as f64);
148 }
149 out[i * n + j] = acc.finalize();
150 }
151 }
152}
153
154pub fn quantized_dot_i8(
158 a: &[i8], b: &[i8],
159 params_a: &QuantParamsI8, params_b: &QuantParamsI8,
160) -> f64 {
161 debug_assert_eq!(a.len(), b.len());
162 let combined_scale = params_a.scale * params_b.scale;
163 let mut acc = BinnedAccumulatorF64::new();
164 for i in 0..a.len() {
165 let int_prod = (a[i] as i64 - params_a.zero_point as i64)
166 * (b[i] as i64 - params_b.zero_point as i64);
167 acc.add(combined_scale * int_prod as f64);
168 }
169 acc.finalize()
170}
171
172pub fn quantized_sum_i8(values: &[i8], params: &QuantParamsI8) -> f64 {
174 let mut acc = BinnedAccumulatorF64::new();
175 for &v in values {
176 acc.add(params.dequantize(v));
177 }
178 acc.finalize()
179}
180
181pub fn quantized_sum_i4(packed: &[u8], count: usize, params: &QuantParamsI4) -> f64 {
186 let mut acc = BinnedAccumulatorF64::new();
187 let mut remaining = count;
188 for &byte in packed {
189 if remaining == 0 { break; }
190 let (hi, lo) = QuantParamsI4::unpack_byte(byte);
191 acc.add(params.dequantize(hi));
192 remaining -= 1;
193 if remaining == 0 { break; }
194 acc.add(params.dequantize(lo));
195 remaining -= 1;
196 }
197 acc.finalize()
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dequantize_i8_basic() {
        let params = QuantParamsI8::new(0.1, 0);
        assert_eq!(params.dequantize(10), 1.0);
        assert_eq!(params.dequantize(-10), -1.0);
        assert_eq!(params.dequantize(0), 0.0);
    }

    #[test]
    fn test_dequantize_i8_with_zero_point() {
        let params = QuantParamsI8::new(0.5, 10);
        assert_eq!(params.dequantize(20), 5.0);
        assert_eq!(params.dequantize(10), 0.0);
    }

    #[test]
    fn test_saturating_dot_i8() {
        let a = vec![1i8, 2, 3, 4];
        let b = vec![5i8, 6, 7, 8];
        assert_eq!(saturating_dot_i8(&a, &b), 70);
    }

    #[test]
    fn test_saturating_dot_overflow() {
        // Large but still below i32::MAX: the exact sum must come back.
        let a = vec![127i8; 1000];
        let b = vec![127i8; 1000];
        let result = saturating_dot_i8(&a, &b);
        assert_eq!(result, 16_129_000);
        // Enough terms to exceed i32::MAX (16129 * 150_000 > 2^31 - 1):
        // the accumulator must clamp rather than wrap.
        let big = vec![127i8; 150_000];
        assert_eq!(saturating_dot_i8(&big, &big), i32::MAX);
    }

    #[test]
    fn test_quantized_matmul_identity() {
        let params = QuantParamsI8::new(1.0, 0);
        let a = vec![1i8, 0, 0, 1];
        let b = vec![3i8, 4, 5, 6];
        let mut out = vec![0.0f64; 4];
        quantized_matmul_i8(&a, &b, &mut out, 2, 2, 2, &params, &params);
        assert_eq!(out, vec![3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_quantized_matmul_scaling() {
        let params_a = QuantParamsI8::new(0.5, 0);
        let params_b = QuantParamsI8::new(2.0, 0);
        let a = vec![2i8, 3];
        let b = vec![4i8, 5];
        let mut out = vec![0.0f64; 1];
        quantized_matmul_i8(&a, &b, &mut out, 1, 2, 1, &params_a, &params_b);
        // (0.5 * 2.0) * (2*4 + 3*5) = 23.0
        assert_eq!(out[0], 23.0);
    }

    #[test]
    fn test_quantized_dot_deterministic() {
        let params = QuantParamsI8::new(0.001, 0);
        let a: Vec<i8> = (0..100).map(|i| (i % 127) as i8).collect();
        let b: Vec<i8> = (0..100).map(|i| ((100 - i) % 127) as i8).collect();

        // Bitwise-identical results across repeated runs.
        let r1 = quantized_dot_i8(&a, &b, &params, &params);
        let r2 = quantized_dot_i8(&a, &b, &params, &params);
        assert_eq!(r1.to_bits(), r2.to_bits());
    }

    #[test]
    fn test_i4_unpack() {
        let (hi, lo) = QuantParamsI4::unpack_byte(0x3E);
        assert_eq!(hi, 3);
        assert_eq!(lo, -2);
    }

    #[test]
    fn test_i4_unpack_negatives() {
        let (hi, lo) = QuantParamsI4::unpack_byte(0xF8);
        assert_eq!(hi, -1);
        assert_eq!(lo, -8);
    }

    #[test]
    fn test_quantized_sum_i4() {
        let params = QuantParamsI4::new(1.0, 0);
        let packed = vec![0x23u8, 0x45];
        let result = quantized_sum_i4(&packed, 4, &params);
        // Nibbles: 2, 3, 4, 5 -> sum 14.
        assert_eq!(result, 14.0);
    }

    #[test]
    fn test_quantized_sum_i8_near_order_invariant() {
        let params = QuantParamsI8::new(0.001, 0);
        let values: Vec<i8> = (0..200).map(|i| ((i as i16 - 100) % 128) as i8).collect();

        let r1 = quantized_sum_i8(&values, &params);

        let mut rev = values.clone();
        rev.reverse();
        let r2 = quantized_sum_i8(&rev, &params);

        // Compare as ULPs: tight bound, independent of magnitude.
        let ulps = (r1.to_bits() as i64 - r2.to_bits() as i64).unsigned_abs();
        assert!(ulps < 10,
            "Quantized sum should be near-order-invariant: {r1} vs {r2} ({ulps} ULPs)");
    }

    #[test]
    fn test_quantized_sum_i8_merge_order_invariant() {
        let params = QuantParamsI8::new(0.001, 0);
        let values: Vec<i8> = (0..200).map(|i| ((i as i16 - 100) % 128) as i8).collect();

        // Accumulate chunk-by-chunk, merging in forward order.
        let mut fwd = BinnedAccumulatorF64::new();
        for chunk in values.chunks(20) {
            let mut c = BinnedAccumulatorF64::new();
            for &v in chunk {
                c.add(params.dequantize(v));
            }
            fwd.merge(&c);
        }

        // Same chunks merged in reverse order must match bit-for-bit.
        let chunks: Vec<Vec<i8>> = values.chunks(20).map(|c| c.to_vec()).collect();
        let mut rev = BinnedAccumulatorF64::new();
        for chunk in chunks.iter().rev() {
            let mut c = BinnedAccumulatorF64::new();
            for &v in chunk.iter() {
                c.add(params.dequantize(v));
            }
            rev.merge(&c);
        }

        assert_eq!(fwd.finalize().to_bits(), rev.finalize().to_bits(),
            "Merge-based quantized sum must be order-invariant");
    }
}