Skip to main content

cjc_runtime/
quantized.rs

//! Quantized BLAS — i8/i4 dequantization into BinnedAccumulator.
//!
//! # Design
//!
//! Zero-point-corrected integer products (i8×i8, held exactly in a wider
//! integer type) are converted to f64 and fed directly into the
//! BinnedAccumulator, bypassing intermediate f32 rounding entirely.
//! This eliminates a major source of non-determinism in quantized inference.
//!
//! # Saturation
//!
//! Integer overflow is handled via saturation arithmetic — values clamp to
//! `i32::MAX` / `i32::MIN` rather than wrapping silently.
13
14use crate::accumulator::BinnedAccumulatorF64;
15
16// ---------------------------------------------------------------------------
17// i8 Quantized Operations
18// ---------------------------------------------------------------------------
19
/// Affine quantization parameters for an `i8` tensor.
///
/// A stored level `q` (any `i8`, i.e. `-128..=127`) represents the real
/// value `scale * (q - zero_point)`.
#[derive(Debug, Clone, Copy)]
pub struct QuantParamsI8 {
    /// Real-valued distance between two adjacent quantized levels.
    pub scale: f64,
    /// Quantized level that represents exactly `0.0`.
    pub zero_point: i8,
}

impl QuantParamsI8 {
    /// Bundles a scale and a zero point. No validation is performed.
    pub fn new(scale: f64, zero_point: i8) -> Self {
        Self { scale, zero_point }
    }

    /// Converts one quantized level to its real (`f64`) value.
    ///
    /// The offset is formed in `i64`, so `v - zero_point` cannot overflow
    /// even when both operands sit at opposite ends of the `i8` range.
    #[inline]
    pub fn dequantize(&self, v: i8) -> f64 {
        let centered = i64::from(v) - i64::from(self.zero_point);
        self.scale * centered as f64
    }

    /// Converts every level in `src` to its real value, preserving order.
    pub fn dequantize_slice(&self, src: &[i8]) -> Vec<f64> {
        let mut real = Vec::with_capacity(src.len());
        real.extend(src.iter().map(|&level| self.dequantize(level)));
        real
    }
}
49
/// Affine quantization parameters for nibble-packed i4 tensors.
///
/// Two 4-bit values share each byte: the high nibble comes first, the low
/// nibble second. The unpacking in this type reads each nibble as a signed
/// two's-complement value in `-8..=7`.
#[derive(Debug, Clone, Copy)]
pub struct QuantParamsI4 {
    /// Real-valued distance between two adjacent quantized levels.
    pub scale: f64,
    /// Level that represents `0.0`; `new` requires it to lie in `-8..=7`.
    pub zero_point: i8,
}

impl QuantParamsI4 {
    /// Builds i4 parameters.
    ///
    /// # Panics
    ///
    /// Panics when `zero_point` is outside `-8..=7`.
    pub fn new(scale: f64, zero_point: i8) -> Self {
        assert!((-8..=7).contains(&zero_point), "i4 zero_point must be in [-8, 7]");
        Self { scale, zero_point }
    }

    /// Splits a packed byte into its signed `(high_nibble, low_nibble)` pair.
    #[inline]
    pub fn unpack_byte(byte: u8) -> (i8, i8) {
        // Reinterpreting the byte as i8 and shifting right arithmetically
        // sign-extends the top four bits.
        let hi = (byte as i8) >> 4;
        // Move the low nibble into the top four bits first, then sign-extend
        // it the same way.
        let lo = ((byte << 4) as i8) >> 4;
        (hi, lo)
    }

    /// Converts one i4 level (carried in an `i8`) to its real value.
    ///
    /// `v` is not range-checked; the formula is `scale * (v - zero_point)`.
    #[inline]
    pub fn dequantize(&self, v: i8) -> f64 {
        let centered = i64::from(v) - i64::from(self.zero_point);
        self.scale * centered as f64
    }
}
83
84// ---------------------------------------------------------------------------
85// Saturating i32 Arithmetic
86// ---------------------------------------------------------------------------
87
/// Multiplies two `i8` values, returning the exact product as an `i32`.
///
/// No clamping is ever needed: the largest possible magnitude is
/// `(-128) * (-128) = 16384`, far inside the `i32` range.
#[inline]
pub fn saturating_mul_i8(a: i8, b: i8) -> i32 {
    let lhs = i32::from(a);
    let rhs = i32::from(b);
    lhs * rhs
}
94
/// Dot product of two `i8` slices, clamped to the `i32` range.
///
/// Products are accumulated exactly in `i64` and the total is clamped to
/// `i32::MIN..=i32::MAX` once at the end. Clamping the running sum instead
/// would make the result depend on element order: after pinning at
/// `i32::MAX`, later negative products would pull the value below the true
/// (in-range) total.
///
/// Each product has magnitude at most 2^14, so the `i64` accumulator cannot
/// overflow for any slice shorter than 2^49 elements.
///
/// # Panics
///
/// Debug builds assert that both slices have the same length. In all builds
/// the call panics if `b` is shorter than `a`; surplus elements of `b` are
/// ignored.
#[inline]
pub fn saturating_dot_i8(a: &[i8], b: &[i8]) -> i32 {
    debug_assert_eq!(a.len(), b.len());
    // Re-slice so a too-short `b` still panics (as per-element indexing did)
    // while letting the loop below run without bounds checks.
    let b = &b[..a.len()];
    let exact: i64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| i64::from(x) * i64::from(y))
        .sum();
    // The clamp guarantees the value fits, so the narrowing cast is lossless.
    exact.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
}
108
109// ---------------------------------------------------------------------------
110// Quantized GEMM via BinnedAccumulator
111// ---------------------------------------------------------------------------
112
113/// Quantized matrix multiply: C[m,n] = dequant(A[m,k]) × dequant(B[k,n])
114///
115/// The i8×i8 products are computed in i32, then dequantized to f64 and
116/// accumulated via BinnedAccumulator for deterministic summation.
117///
118/// This avoids intermediate f32 rounding: integer products go directly
119/// to f64 dequantization, then into binned accumulation.
120///
121/// # Arguments
122/// * `a` - Row-major i8 matrix [m, k]
123/// * `b` - Row-major i8 matrix [k, n]
124/// * `params_a` - Quantization parameters for A
125/// * `params_b` - Quantization parameters for B
126/// * `out` - Output buffer [m, n] (pre-allocated)
127pub fn quantized_matmul_i8(
128    a: &[i8], b: &[i8], out: &mut [f64],
129    m: usize, k: usize, n: usize,
130    params_a: &QuantParamsI8, params_b: &QuantParamsI8,
131) {
132    debug_assert_eq!(a.len(), m * k);
133    debug_assert_eq!(b.len(), k * n);
134    debug_assert_eq!(out.len(), m * n);
135
136    // Combined scale factor.
137    let combined_scale = params_a.scale * params_b.scale;
138
139    for i in 0..m {
140        for j in 0..n {
141            let mut acc = BinnedAccumulatorF64::new();
142            for p in 0..k {
143                // Integer product: no rounding.
144                let int_prod = (a[i * k + p] as i64 - params_a.zero_point as i64)
145                    * (b[p * n + j] as i64 - params_b.zero_point as i64);
146                // Dequantize directly to f64: combined_scale * int_prod.
147                acc.add(combined_scale * int_prod as f64);
148            }
149            out[i * n + j] = acc.finalize();
150        }
151    }
152}
153
154/// Quantized dot product of two i8 vectors, returning f64.
155///
156/// Dequantizes products into BinnedAccumulator for determinism.
157pub fn quantized_dot_i8(
158    a: &[i8], b: &[i8],
159    params_a: &QuantParamsI8, params_b: &QuantParamsI8,
160) -> f64 {
161    debug_assert_eq!(a.len(), b.len());
162    let combined_scale = params_a.scale * params_b.scale;
163    let mut acc = BinnedAccumulatorF64::new();
164    for i in 0..a.len() {
165        let int_prod = (a[i] as i64 - params_a.zero_point as i64)
166            * (b[i] as i64 - params_b.zero_point as i64);
167        acc.add(combined_scale * int_prod as f64);
168    }
169    acc.finalize()
170}
171
172/// Sum dequantized i8 values using BinnedAccumulator.
173pub fn quantized_sum_i8(values: &[i8], params: &QuantParamsI8) -> f64 {
174    let mut acc = BinnedAccumulatorF64::new();
175    for &v in values {
176        acc.add(params.dequantize(v));
177    }
178    acc.finalize()
179}
180
181/// Sum dequantized i4 (packed) values using BinnedAccumulator.
182///
183/// `packed` contains pairs of i4 values packed into bytes.
184/// `count` is the total number of i4 elements (may be odd).
185pub fn quantized_sum_i4(packed: &[u8], count: usize, params: &QuantParamsI4) -> f64 {
186    let mut acc = BinnedAccumulatorF64::new();
187    let mut remaining = count;
188    for &byte in packed {
189        if remaining == 0 { break; }
190        let (hi, lo) = QuantParamsI4::unpack_byte(byte);
191        acc.add(params.dequantize(hi));
192        remaining -= 1;
193        if remaining == 0 { break; }
194        acc.add(params.dequantize(lo));
195        remaining -= 1;
196    }
197    acc.finalize()
198}
199
200// ---------------------------------------------------------------------------
201// Inline tests
202// ---------------------------------------------------------------------------
203
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dequantize_i8_basic() {
        let params = QuantParamsI8::new(0.1, 0);
        assert_eq!(params.dequantize(10), 1.0);
        assert_eq!(params.dequantize(-10), -1.0);
        assert_eq!(params.dequantize(0), 0.0);
    }

    #[test]
    fn test_dequantize_i8_with_zero_point() {
        let params = QuantParamsI8::new(0.5, 10);
        // float = 0.5 * (20 - 10) = 5.0
        assert_eq!(params.dequantize(20), 5.0);
        // float = 0.5 * (10 - 10) = 0.0
        assert_eq!(params.dequantize(10), 0.0);
    }

    #[test]
    fn test_saturating_dot_i8() {
        let a = vec![1i8, 2, 3, 4];
        let b = vec![5i8, 6, 7, 8];
        assert_eq!(saturating_dot_i8(&a, &b), 70); // 5+12+21+32
    }

    #[test]
    fn test_saturating_dot_no_overflow() {
        // 127*127 = 16129; 16129 * 1000 = 16_129_000 fits in i32, so the
        // exact sum must come back untouched.
        let a = vec![127i8; 1000];
        let b = vec![127i8; 1000];
        assert_eq!(saturating_dot_i8(&a, &b), 16_129_000);
    }

    #[test]
    fn test_saturating_dot_overflow() {
        // 16129 * 200_000 = 3_225_800_000 exceeds i32::MAX (2_147_483_647),
        // so the result must clamp instead of wrapping around.
        let a = vec![127i8; 200_000];
        let b = vec![127i8; 200_000];
        assert_eq!(saturating_dot_i8(&a, &b), i32::MAX);
    }

    #[test]
    fn test_saturating_dot_negative_overflow() {
        // Same magnitude with the opposite sign must clamp at i32::MIN.
        let a = vec![127i8; 200_000];
        let b = vec![-127i8; 200_000];
        assert_eq!(saturating_dot_i8(&a, &b), i32::MIN);
    }

    #[test]
    fn test_quantized_matmul_identity() {
        // 2x2 identity via i8 with scale=1.0, zp=0
        let params = QuantParamsI8::new(1.0, 0);
        let a = vec![1i8, 0, 0, 1]; // identity
        let b = vec![3i8, 4, 5, 6];
        let mut out = vec![0.0f64; 4];
        quantized_matmul_i8(&a, &b, &mut out, 2, 2, 2, &params, &params);
        assert_eq!(out, vec![3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_quantized_matmul_scaling() {
        let params_a = QuantParamsI8::new(0.5, 0);
        let params_b = QuantParamsI8::new(2.0, 0);
        // combined_scale = 0.5 * 2.0 = 1.0
        let a = vec![2i8, 3];
        let b = vec![4i8, 5];
        let mut out = vec![0.0f64; 1];
        quantized_matmul_i8(&a, &b, &mut out, 1, 2, 1, &params_a, &params_b);
        // dot = (0.5*2)*(2.0*4) + (0.5*3)*(2.0*5) = 8 + 15 = 23
        // via combined: 1.0 * (2*4 + 3*5) = 1.0 * 23 = 23.0
        assert_eq!(out[0], 23.0);
    }

    #[test]
    fn test_quantized_dot_deterministic() {
        let params = QuantParamsI8::new(0.001, 0);
        let a: Vec<i8> = (0..100).map(|i| (i % 127) as i8).collect();
        let b: Vec<i8> = (0..100).map(|i| ((100 - i) % 127) as i8).collect();

        let r1 = quantized_dot_i8(&a, &b, &params, &params);
        let r2 = quantized_dot_i8(&a, &b, &params, &params);
        assert_eq!(r1.to_bits(), r2.to_bits());
    }

    #[test]
    fn test_i4_unpack() {
        // Pack: high=3, low=-2 → 0x3E
        // 3 in 4-bit signed: 0011
        // -2 in 4-bit signed: 1110
        // byte: 0011_1110 = 0x3E
        let (hi, lo) = QuantParamsI4::unpack_byte(0x3E);
        assert_eq!(hi, 3);
        assert_eq!(lo, -2);
    }

    #[test]
    fn test_i4_unpack_negatives() {
        // high=-1 (1111), low=-8 (1000) → 0xF8
        let (hi, lo) = QuantParamsI4::unpack_byte(0xF8);
        assert_eq!(hi, -1);
        assert_eq!(lo, -8);
    }

    #[test]
    fn test_quantized_sum_i4() {
        let params = QuantParamsI4::new(1.0, 0);
        // Pack: (2, 3), (4, 5) = 0x23, 0x45
        let packed = vec![0x23u8, 0x45];
        let result = quantized_sum_i4(&packed, 4, &params);
        assert_eq!(result, 14.0); // 2 + 3 + 4 + 5
    }

    #[test]
    fn test_quantized_sum_i4_odd_count() {
        let params = QuantParamsI4::new(1.0, 0);
        // Same packing as above, but only the first three nibbles count:
        // the low nibble of the second byte (5) must be skipped.
        let packed = vec![0x23u8, 0x45];
        let result = quantized_sum_i4(&packed, 3, &params);
        assert_eq!(result, 9.0); // 2 + 3 + 4
    }

    #[test]
    fn test_quantized_sum_i8_near_order_invariant() {
        let params = QuantParamsI8::new(0.001, 0);
        let values: Vec<i8> = (0..200).map(|i| ((i as i16 - 100) % 128) as i8).collect();

        let r1 = quantized_sum_i8(&values, &params);

        // Reverse order.
        let mut rev = values.clone();
        rev.reverse();
        let r2 = quantized_sum_i8(&rev, &params);

        // Within-bin accumulation order may cause a few ULPs of difference
        // due to IEEE-754 non-associativity. The BinnedAccumulator minimizes
        // this by binning values with similar exponents, but doesn't eliminate it.
        let ulps = (r1.to_bits() as i64 - r2.to_bits() as i64).unsigned_abs();
        assert!(ulps < 10,
            "Quantized sum should be near-order-invariant: {r1} vs {r2} ({ulps} ULPs)");
    }

    #[test]
    fn test_quantized_sum_i8_merge_order_invariant() {
        // Merge-based accumulation IS fully order-invariant (Knuth 2Sum merge).
        let params = QuantParamsI8::new(0.001, 0);
        let values: Vec<i8> = (0..200).map(|i| ((i as i16 - 100) % 128) as i8).collect();

        // Chunk into 20s, merge forward.
        let mut fwd = BinnedAccumulatorF64::new();
        for chunk in values.chunks(20) {
            let mut c = BinnedAccumulatorF64::new();
            for &v in chunk {
                c.add(params.dequantize(v));
            }
            fwd.merge(&c);
        }

        // Chunk into 20s, merge reverse.
        let chunks: Vec<Vec<i8>> = values.chunks(20).map(|c| c.to_vec()).collect();
        let mut rev = BinnedAccumulatorF64::new();
        for chunk in chunks.iter().rev() {
            let mut c = BinnedAccumulatorF64::new();
            for &v in chunk.iter() {
                c.add(params.dequantize(v));
            }
            rev.merge(&c);
        }

        assert_eq!(fwd.finalize().to_bits(), rev.finalize().to_bits(),
            "Merge-based quantized sum must be order-invariant");
    }
}
356}