// crates/axonml-quant/src/bitnet.rs — axonml_quant::bitnet
1//! BitNet b1.58 I2_S Ternary Quantization — Dequant + Fused Add-Only Matmul
2//!
3//! Implements Microsoft's `I2_S` quant type (GGUF dtype 36), used by the
4//! `bitnet.cpp` reference kernels and by every official BitNet b1.58 GGUF
5//! release (including `microsoft/bitnet-b1.58-2B-4T-gguf`).
6//!
7//! Contents:
8//! - Constants `I2S_BLOCK_SIZE`, `I2S_BYTES_PER_BLOCK`, `I2S_GROUP_SIZE`
9//!   mirroring Microsoft's `QK_I2_S` / group-strided layout.
10//! - Trit ↔ 2-bit code converters `decode_trit` / `encode_trit`.
11//! - `I2sBlock` struct with `pack` / `unpack` / `to_bytes` / `from_bytes`
12//!   implementing the group-strided byte layout.
13//! - `dequantize_i2s_block` and rayon-parallel `dequantize_i2s` for
14//!   recovering f32 weights from packed bytes.
15//! - `matmul_i2s` — fused add-only ternary matmul (f32 activations).
16//! - `matmul_i2s_i8` — int8-activation fused path with runtime AVX-VNNI
17//!   dispatch, scalar fallback `matmul_i2s_i8_scalar`, and an in-progress
18//!   `matmul_i2s_i8_avxvnni` unsafe target-feature stub.
19//! - `quantize_row_to_int8` per-row absmax int8 quantization for
20//!   activations entering the int8 fast path.
21//! - `bytes_for_elements` size helper and a test module covering
22//!   trit encode/decode, block roundtrip, layout correctness, reference
23//!   matmul agreement, int8 vs f32 agreement, and misaligned-`k` rejection.
24//!
25//! # Format (verified against `microsoft/BitNet` reference, 2026-04-14)
26//!
27//! - **Block size: 128 weights** (`QK_I2_S = 128` on x86_64).
28//! - **Block stride: 32 bytes** (128 × 2 bits packed).
29//! - **Encoding per 2-bit code:** `0 → -1`, `1 → 0`, `2 → +1`, `3 → unused`.
30//!   (From `quantize_i2_s` in `ggml-bitnet-mad.cpp`:
31//!   `"q8 -> 0, 1, 2 | | | -1, 0, 1"`.)
32//! - **Intra-block layout is NOT `4 consecutive weights / byte`** — it's a
33//!   SIMD-friendly group-strided layout. Each 32-byte block stores 128
34//!   weights as 4 groups of 32, multiplexed into the 2-bit positions of
35//!   the 32 bytes:
36//!
37//!     - byte `k` bits **6..7** → weight `k`       (group 0, shift 6)
38//!     - byte `k` bits **4..5** → weight `k + 32`  (group 1, shift 4)
39//!     - byte `k` bits **2..3** → weight `k + 64`  (group 2, shift 2)
40//!     - byte `k` bits **0..1** → weight `k + 96`  (group 3, shift 0)
41//!
42//!   (Encoder in the reference: `temp = q8 << (6 - 2*group_idx)`; decoder
43//!   shifts the byte right by `6 - 2*group_idx` and ANDs with `0x03`.) The
44//!   layout lets AVX2 load 32 bytes once, shift by 0/2/4/6 with a mask of
45//!   `0x03`, and extract 128 trit codes in four 32-wide vector registers.
46//!
47//! - **One tensor-wide f32 scale** follows the packed data at offset
48//!   `m * k / 4` bytes, padded to the next 32-byte boundary. The
49//!   dequantized weight is `scale × trit[i]`. This module requires the
50//!   caller to pass the scale separately — [`matmul_i2s`] and
51//!   [`dequantize_i2s`] both accept `scale: f32`.
52//!
53//! # Why this is fast on CPU
54//!
55//! Matmul between an f32 activation matrix and ternary weights becomes
56//! branchless accumulate-and-subtract: for each weight, either add the
57//! activation (+1), subtract it (-1), or skip (0). The tensor-wide scale
58//! applies once at the end of each output element. That's BitNet's
59//! "add-only matmul" performance story. A SIMD fast path mirroring the
60//! reference AVX2 kernel is a natural follow-up once the scalar path is
61//! verified against Microsoft's released weights.
62//!
63//! # References
64//! - Microsoft BitNet paper ("The Era of 1-bit LLMs"): <https://arxiv.org/abs/2402.17764>
65//! - `microsoft/BitNet` on GitHub (bitnet.cpp reference kernels)
66//!
67//! # File
68//! `crates/axonml-quant/src/bitnet.rs`
69//!
70//! # Author
71//! Andrew Jewell Sr. — AutomataNexus LLC
72//! ORCID: 0009-0005-2158-7060
73//!
74//! # Updated
75//! April 16, 2026 11:15 PM EST
76//!
77//! # Disclaimer
78//! Use at own risk. This software is provided "as is", without warranty of any
79//! kind, express or implied. The author and AutomataNexus shall not be held
80//! liable for any damages arising from the use of this software.
81
82use rayon::prelude::*;
83
84// =============================================================================
85// Constants
86// =============================================================================
87
/// Weights per I2_S block (Microsoft `QK_I2_S` on x86_64).
///
/// Invariant: `I2S_BLOCK_SIZE == 4 * I2S_GROUP_SIZE` — a block is four
/// 32-weight groups multiplexed into the same 32 bytes.
pub const I2S_BLOCK_SIZE: usize = 128;

/// Bytes per I2_S block (128 × 2 bits).
pub const I2S_BYTES_PER_BLOCK: usize = 32;

/// Group size within a block — 4 groups of 32 weights share the same 32
/// bytes but live at different bit-positions.
pub const I2S_GROUP_SIZE: usize = 32;
97
98// =============================================================================
99// Trit <-> 2-bit encoding (Microsoft bitnet.cpp convention)
100// =============================================================================
101
/// Decode a 2-bit code to a trit in `{-1, 0, +1}`.
///
/// Encoding (Microsoft bitnet.cpp): `0 → -1`, `1 → 0`, `2 → +1`, `3 → 0`
/// (defensive; code 3 is unused by the reference encoder).
#[inline(always)]
pub const fn decode_trit(code: u8) -> i8 {
    // Branchless lookup; masking keeps the index in 0..=3 for any input.
    const TABLE: [i8; 4] = [-1, 0, 1, 0];
    TABLE[(code & 0b11) as usize]
}
115
/// Encode a trit to a 2-bit code (inverse of [`decode_trit`]).
///
/// Any positive input clamps to code 2 (+1) and any negative to code 0
/// (-1); zero maps to code 1.
#[inline(always)]
const fn encode_trit(v: i8) -> u8 {
    // signum collapses the full i8 range onto {-1, 0, +1}; shifting by +1
    // yields the 2-bit code {0, 1, 2} directly.
    (v.signum() + 1) as u8
}
127
128// =============================================================================
129// Block pack / unpack (primarily for tests; production loads raw bytes)
130// =============================================================================
131
132/// A single I2_S block: 128 ternary weights in the group-strided layout.
133///
134/// The tensor-wide scale is not stored on the block — see module docs.
135#[derive(Debug, Clone)]
136pub struct I2sBlock {
137    /// 32 bytes holding 128 × 2-bit trits in group-strided form.
138    pub data: [u8; I2S_BYTES_PER_BLOCK],
139}
140
141impl I2sBlock {
142    /// Pack 128 trit values `{-1, 0, +1}` using the Microsoft layout.
143    pub fn pack(values: &[i8; I2S_BLOCK_SIZE]) -> Self {
144        let mut data = [0u8; I2S_BYTES_PER_BLOCK];
145        // For each (group_idx, group_pos), OR the code at bits [6-2*g .. 7-2*g].
146        for group_idx in 0..4 {
147            let shift = 6 - 2 * group_idx;
148            for group_pos in 0..I2S_GROUP_SIZE {
149                let code = encode_trit(values[group_idx * I2S_GROUP_SIZE + group_pos]);
150                data[group_pos] |= code << shift;
151            }
152        }
153        Self { data }
154    }
155
156    /// 32-byte raw view.
157    pub fn to_bytes(&self) -> [u8; I2S_BYTES_PER_BLOCK] {
158        self.data
159    }
160
161    /// Parse a block from a 32-byte slice.
162    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
163        if bytes.len() < I2S_BYTES_PER_BLOCK {
164            return None;
165        }
166        let mut data = [0u8; I2S_BYTES_PER_BLOCK];
167        data.copy_from_slice(&bytes[..I2S_BYTES_PER_BLOCK]);
168        Some(Self { data })
169    }
170
171    /// Unpack all 128 trits in linear order `0..128`.
172    pub fn unpack(&self) -> [i8; I2S_BLOCK_SIZE] {
173        let mut out = [0i8; I2S_BLOCK_SIZE];
174        for group_idx in 0..4 {
175            let shift = 6 - 2 * group_idx;
176            for group_pos in 0..I2S_GROUP_SIZE {
177                let code = (self.data[group_pos] >> shift) & 0b11;
178                out[group_idx * I2S_GROUP_SIZE + group_pos] = decode_trit(code);
179            }
180        }
181        out
182    }
183}
184
185// =============================================================================
186// Dequantization
187// =============================================================================
188
189/// Dequantize a single I2_S block to 128 `f32` values.
190///
191/// Applies the tensor-wide `scale` so `out[i] = scale × trit[i]`.
192///
193/// # Panics
194/// Debug-only: panics if `bytes.len() < 32` or `out.len() < 128`.
195pub fn dequantize_i2s_block(bytes: &[u8], scale: f32, out: &mut [f32]) {
196    debug_assert!(bytes.len() >= I2S_BYTES_PER_BLOCK);
197    debug_assert!(out.len() >= I2S_BLOCK_SIZE);
198    let b = &bytes[..I2S_BYTES_PER_BLOCK];
199    for (group_pos, &byte) in b.iter().enumerate() {
200        let c0 = decode_trit((byte >> 6) & 0b11) as f32; // group 0
201        let c1 = decode_trit((byte >> 4) & 0b11) as f32; // group 1
202        let c2 = decode_trit((byte >> 2) & 0b11) as f32; // group 2
203        let c3 = decode_trit(byte & 0b11) as f32; // group 3
204        out[group_pos] = c0 * scale;
205        out[group_pos + I2S_GROUP_SIZE] = c1 * scale;
206        out[group_pos + 2 * I2S_GROUP_SIZE] = c2 * scale;
207        out[group_pos + 3 * I2S_GROUP_SIZE] = c3 * scale;
208    }
209}
210
211/// Dequantize a full I2_S weight buffer to f32.
212///
213/// `out.len()` must be a multiple of [`I2S_BLOCK_SIZE`]. Rayon-parallelized
214/// over blocks.
215pub fn dequantize_i2s(bytes: &[u8], scale: f32, out: &mut [f32]) {
216    let n_blocks = out.len() / I2S_BLOCK_SIZE;
217    out.par_chunks_mut(I2S_BLOCK_SIZE)
218        .take(n_blocks)
219        .zip(bytes.par_chunks(I2S_BYTES_PER_BLOCK).take(n_blocks))
220        .for_each(|(out_block, in_block)| {
221            if in_block.len() >= I2S_BYTES_PER_BLOCK {
222                dequantize_i2s_block(in_block, scale, out_block);
223            }
224        });
225}
226
227// =============================================================================
228// Fused ternary matmul
229// =============================================================================
230
231/// Fused add-only ternary matmul: `output = scale × (activations @ weights^T)`.
232///
233/// # Shapes
234/// - `activations`: `[m, k]` row-major f32
235/// - `weight_bytes`: `[n, k]` — each output row is a contiguous run of
236///   `k / 128` I2_S blocks (32 bytes each). `k` **must** be a multiple
237///   of [`I2S_BLOCK_SIZE`] (128).
238/// - `scale`: tensor-wide f32 scale read from the tail of the GGUF tensor
239///   (see module docs)
240/// - `output`: `[m, n]` row-major f32
241///
242/// # Panics
243/// Panics if `k % 128 != 0`, or if shapes don't line up.
244///
245/// # Parallelism
246/// Parallelizes over output columns `n`. Right grain for decode (`m == 1`)
247/// — one independent dot product per column.
248pub fn matmul_i2s(
249    activations: &[f32],
250    m: usize,
251    k: usize,
252    weight_bytes: &[u8],
253    n: usize,
254    scale: f32,
255    output: &mut [f32],
256) {
257    assert!(
258        k % I2S_BLOCK_SIZE == 0,
259        "matmul_i2s: k ({k}) must be a multiple of {I2S_BLOCK_SIZE}",
260    );
261    assert_eq!(activations.len(), m * k, "activations shape mismatch");
262    assert_eq!(output.len(), m * n, "output shape mismatch");
263    let blocks_per_row = k / I2S_BLOCK_SIZE;
264    let bytes_per_row = blocks_per_row * I2S_BYTES_PER_BLOCK;
265    assert_eq!(
266        weight_bytes.len(),
267        n * bytes_per_row,
268        "weight_bytes shape mismatch",
269    );
270
271    for i in 0..m {
272        let act_row = &activations[i * k..(i + 1) * k];
273        output[i * n..(i + 1) * n]
274            .par_iter_mut()
275            .enumerate()
276            .for_each(|(j, out_slot)| {
277                let wrow = &weight_bytes[j * bytes_per_row..(j + 1) * bytes_per_row];
278                *out_slot = dot_row_ternary(act_row, wrow, blocks_per_row) * scale;
279            });
280    }
281}
282
283/// Inner dot product for one activation row × one ternary weight row, in
284/// the group-strided layout. Returns the UNSCALED sum — caller multiplies
285/// by the tensor-wide scale.
286#[inline(always)]
287fn dot_row_ternary(act_row: &[f32], wrow: &[u8], blocks_per_row: usize) -> f32 {
288    let mut acc = 0.0f32;
289    for block_idx in 0..blocks_per_row {
290        let block_off = block_idx * I2S_BYTES_PER_BLOCK;
291        let block = &wrow[block_off..block_off + I2S_BYTES_PER_BLOCK];
292        let k_base = block_idx * I2S_BLOCK_SIZE;
293        // Autovec-friendly: pull the four group activations into locals,
294        // decode once per byte, accumulate into four separate accumulators
295        // (lets LLVM issue independent FMAs on wide targets).
296        let mut a0 = 0.0f32;
297        let mut a1 = 0.0f32;
298        let mut a2 = 0.0f32;
299        let mut a3 = 0.0f32;
300        for (group_pos, &byte) in block.iter().enumerate() {
301            let t0 = decode_trit((byte >> 6) & 0b11) as f32;
302            let t1 = decode_trit((byte >> 4) & 0b11) as f32;
303            let t2 = decode_trit((byte >> 2) & 0b11) as f32;
304            let t3 = decode_trit(byte & 0b11) as f32;
305            let base = k_base + group_pos;
306            a0 += act_row[base] * t0;
307            a1 += act_row[base + I2S_GROUP_SIZE] * t1;
308            a2 += act_row[base + 2 * I2S_GROUP_SIZE] * t2;
309            a3 += act_row[base + 3 * I2S_GROUP_SIZE] * t3;
310        }
311        acc += a0 + a1 + a2 + a3;
312    }
313    acc
314}
315
316// =============================================================================
317// Int8-activation fused ternary matmul (AVX-VNNI path)
318// =============================================================================
319//
320// This is the "30-50% over llama.cpp" lever. The scalar `matmul_i2s` above
321// materializes activations as f32 and walks through trit codes one byte at
322// a time — memory-BW-bound and one instruction per trit on CPU.
323//
324// The fused int8 path does three things differently:
325//
326// 1. Activations are quantized to int8 with a per-row absmax scale before
327//    the matmul. Bandwidth drops 4× (f32 → i8) on the activation side.
328// 2. Trit codes stay in their 2-bit packed form and never decode to f32.
329//    Bandwidth drops 16× on the weight side (f32 → 2 bits).
330// 3. The dot product uses the VNNI `dpbusd` instruction (`_mm256_dpbusd_epi32`
331//    on AVX-VNNI, `_mm512_dpbusd_epi32` on AVX-512 VNNI). dpbusd does 32
332//    unsigned-byte × signed-byte multiplies and sums groups of 4 into int32
333//    lanes, so one instruction handles 32 weight-activation pairs.
334//
335// Arithmetic trick: the trit codes `{0, 1, 2}` map to `{-1, 0, +1}` via
336// `trit = code - 1`. dpbusd computes `sum(code × act)`, not `sum(trit × act)`.
337// We recover the true dot product with a single per-row correction:
338//
339//     true_dot_j = code_dot_j - act_sum
340//
341// where `act_sum = sum_k(act_i[k])` (a single scalar per input row, computed
342// once). The correction is two cheap int32 ops per output element.
343//
344// Microsoft's bitnet.cpp does this same accounting in its AVX2 kernels — we
345// match their approach and then tune for Arrow Lake's AVX-VNNI-but-no-AVX-512
346// ISA profile. Reference in `ggml-bitnet-mad.cpp::ggml_vec_dot_i2_i8_s_1x1`.
347
/// Quantize an f32 row to int8 with a per-row absmax scale such that
/// `f32[i] ≈ int8[i] * scale`. The largest-magnitude element maps to ±127.
///
/// Returns the scale. Intended for **activation quantization** ahead of an
/// I2_S × int8 matmul — the matmul output multiplies by this scale (along
/// with the weight's tensor-wide scale) to recover f32 logits.
///
/// # Edge cases
/// - All-zero input → scale = 0, output is all zeros.
/// - Absmax is sensitive to a single enormous outlier; callers with
///   heavy-tailed activations may want clamping or percentile-absmax
///   instead. Not a concern for typical post-norm transformer activations.
pub fn quantize_row_to_int8(input: &[f32], output: &mut [i8]) -> f32 {
    debug_assert_eq!(input.len(), output.len());
    let mut absmax = 0.0f32;
    for &v in input {
        absmax = absmax.max(v.abs());
    }
    if absmax == 0.0 {
        output.fill(0);
        return 0.0;
    }
    let scale = absmax / 127.0;
    let inv_scale = 1.0 / scale;
    for (dst, &src) in output.iter_mut().zip(input) {
        // round() can overshoot the i8 range only via ±127.5 rounding away
        // from zero — clamp defensively before the cast.
        *dst = (src * inv_scale).round().clamp(-127.0, 127.0) as i8;
    }
    scale
}
380
381/// Fused I2_S × int8 matmul: `out = weight_scale × (acts_int8 @ weights^T) × act_scales_per_row`.
382///
383/// # Shapes
384/// - `acts_int8`: `[m, k]` row-major int8 (quantized via [`quantize_row_to_int8`])
385/// - `act_scales`: `[m]` — per-row f32 scales from the int8 quantization
386/// - `weight_bytes`: `[n, k]` as I2_S blocks (same layout as [`matmul_i2s`])
387/// - `weight_scale`: tensor-wide f32 scale (from GGUF tail)
388/// - `output`: `[m, n]` row-major f32
389///
390/// # Math
391/// For each `(i, j)`:
392///
393/// ```text
394/// out[i, j] = act_scales[i] * weight_scale * sum_k(trit[j, k] * act_i8[i, k])
395///           = act_scales[i] * weight_scale * (sum_k(code[j, k] * act_i8[i, k]) - act_sum[i])
396/// ```
397///
398/// where `code[j, k] ∈ {0, 1, 2}` is the raw 2-bit code and `act_sum[i] = sum_k(act_i8[i, k])`.
399///
400/// # Dispatch
401/// At runtime we check for AVX-VNNI via `is_x86_feature_detected!("avxvnni")`
402/// and take the SIMD path when available; otherwise fall back to a scalar
403/// reference that matches the SIMD path bit-for-bit (lets tests run on any
404/// host).
405pub fn matmul_i2s_i8(
406    acts_int8: &[i8],
407    act_scales: &[f32],
408    m: usize,
409    k: usize,
410    weight_bytes: &[u8],
411    n: usize,
412    weight_scale: f32,
413    output: &mut [f32],
414) {
415    assert!(
416        k % I2S_BLOCK_SIZE == 0,
417        "matmul_i2s_i8: k ({k}) must be a multiple of {I2S_BLOCK_SIZE}",
418    );
419    assert_eq!(acts_int8.len(), m * k, "acts_int8 shape mismatch");
420    assert_eq!(act_scales.len(), m, "act_scales length mismatch");
421    assert_eq!(output.len(), m * n, "output shape mismatch");
422    let blocks_per_row = k / I2S_BLOCK_SIZE;
423    let bytes_per_row = blocks_per_row * I2S_BYTES_PER_BLOCK;
424    assert_eq!(
425        weight_bytes.len(),
426        n * bytes_per_row,
427        "weight_bytes shape mismatch",
428    );
429
430    // Runtime dispatch — AVX-VNNI variant fills in on a follow-up commit;
431    // scalar path below is the correctness baseline.
432    #[cfg(target_arch = "x86_64")]
433    {
434        if std::is_x86_feature_detected!("avxvnni") && std::is_x86_feature_detected!("avx2") {
435            // SAFETY: feature-detected above.
436            unsafe {
437                matmul_i2s_i8_avxvnni(
438                    acts_int8,
439                    act_scales,
440                    m,
441                    k,
442                    weight_bytes,
443                    n,
444                    weight_scale,
445                    output,
446                );
447            }
448            return;
449        }
450    }
451
452    matmul_i2s_i8_scalar(
453        acts_int8,
454        act_scales,
455        m,
456        k,
457        weight_bytes,
458        n,
459        weight_scale,
460        output,
461    );
462}
463
464/// Scalar reference for [`matmul_i2s_i8`]. Used as a correctness baseline
465/// for the AVX-VNNI fast path and as a fallback on non-x86_64 or
466/// pre-AVX-VNNI CPUs.
467fn matmul_i2s_i8_scalar(
468    acts_int8: &[i8],
469    act_scales: &[f32],
470    m: usize,
471    k: usize,
472    weight_bytes: &[u8],
473    n: usize,
474    weight_scale: f32,
475    output: &mut [f32],
476) {
477    let blocks_per_row = k / I2S_BLOCK_SIZE;
478    let bytes_per_row = blocks_per_row * I2S_BYTES_PER_BLOCK;
479
480    for i in 0..m {
481        let act_row = &acts_int8[i * k..(i + 1) * k];
482        let act_scale = act_scales[i];
483        let act_sum: i32 = act_row.iter().map(|&x| x as i32).sum();
484        let combined_scale = weight_scale * act_scale;
485
486        output[i * n..(i + 1) * n]
487            .par_iter_mut()
488            .enumerate()
489            .for_each(|(j, out_slot)| {
490                let wrow = &weight_bytes[j * bytes_per_row..(j + 1) * bytes_per_row];
491                // sum_k(code[j,k] * act[k]) with code in {0,1,2}.
492                let mut code_dot: i32 = 0;
493                for block_idx in 0..blocks_per_row {
494                    let block_off = block_idx * I2S_BYTES_PER_BLOCK;
495                    let block = &wrow[block_off..block_off + I2S_BYTES_PER_BLOCK];
496                    let k_base = block_idx * I2S_BLOCK_SIZE;
497                    for (group_pos, &byte) in block.iter().enumerate() {
498                        // Bit layout: bits 6-7 → pos k_base+group_pos
499                        //             bits 4-5 → pos k_base+32+group_pos
500                        //             bits 2-3 → pos k_base+64+group_pos
501                        //             bits 0-1 → pos k_base+96+group_pos
502                        let c0 = ((byte >> 6) & 0b11) as i32;
503                        let c1 = ((byte >> 4) & 0b11) as i32;
504                        let c2 = ((byte >> 2) & 0b11) as i32;
505                        let c3 = (byte & 0b11) as i32;
506                        let base = k_base + group_pos;
507                        code_dot += c0 * act_row[base] as i32;
508                        code_dot += c1 * act_row[base + I2S_GROUP_SIZE] as i32;
509                        code_dot += c2 * act_row[base + 2 * I2S_GROUP_SIZE] as i32;
510                        code_dot += c3 * act_row[base + 3 * I2S_GROUP_SIZE] as i32;
511                    }
512                }
513                // trit = code - 1, so sum(trit*act) = sum(code*act) - sum(act).
514                let trit_dot = code_dot - act_sum;
515                *out_slot = (trit_dot as f32) * combined_scale;
516            });
517    }
518}
519
/// AVX-VNNI fast path — **unimplemented**. Drop-in replacement for
/// [`matmul_i2s_i8_scalar`] once filled in.
///
/// # Safety
/// Caller must ensure the host CPU supports both `avx2` and `avxvnni`
/// (e.g. via `is_x86_feature_detected!`) before calling; [`matmul_i2s_i8`]
/// performs that check. The current body delegates to safe scalar code,
/// but the `#[target_feature]` attribute already makes this contract part
/// of the signature.
///
/// # Planned inner loop (per block of 128 weights × 128 activations):
///
/// ```ignore
/// // Load 32 weight bytes (128 trits packed) and 128 int8 activations.
/// let bytes_v = _mm256_loadu_si256(block_ptr as *const __m256i);
/// let acts_g0 = _mm256_loadu_si256(act_ptr.add(k_base)         as *const __m256i); // pos k_base..+32
/// let acts_g1 = _mm256_loadu_si256(act_ptr.add(k_base + 32)    as *const __m256i);
/// let acts_g2 = _mm256_loadu_si256(act_ptr.add(k_base + 64)    as *const __m256i);
/// let acts_g3 = _mm256_loadu_si256(act_ptr.add(k_base + 96)    as *const __m256i);
///
/// // Extract 2-bit codes for each of 4 groups.
/// let mask = _mm256_set1_epi8(0x03);
/// let codes_g0 = _mm256_and_si256(_mm256_srli_epi16(bytes_v, 6), mask); // bits 6-7
/// let codes_g1 = _mm256_and_si256(_mm256_srli_epi16(bytes_v, 4), mask); // bits 4-5
/// let codes_g2 = _mm256_and_si256(_mm256_srli_epi16(bytes_v, 2), mask); // bits 2-3
/// let codes_g3 = _mm256_and_si256(bytes_v, mask);                      // bits 0-1
///
/// // 32 × (u8 × i8) → 8 × i32, accumulated.
/// acc = _mm256_dpbusd_epi32(acc, codes_g0, acts_g0);
/// acc = _mm256_dpbusd_epi32(acc, codes_g1, acts_g1);
/// acc = _mm256_dpbusd_epi32(acc, codes_g2, acts_g2);
/// acc = _mm256_dpbusd_epi32(acc, codes_g3, acts_g3);
/// ```
///
/// Four VNNI ops per 128-weight block. The outer loop iterates
/// `blocks_per_row` blocks per output column, then horizontally sums the
/// int32 lanes to a scalar, applies the `- act_sum` correction, and scales
/// by `combined_scale` to f32.
///
/// Rayon fan-out over output columns `n` (same as the scalar path). On
/// Arrow Lake (AVX-VNNI but no AVX-512), expect ~8-12× speedup over
/// scalar on the kernel alone; end-to-end wins compound because activation
/// bandwidth drops 4× and weight bandwidth stays 2-bit.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,avxvnni")]
unsafe fn matmul_i2s_i8_avxvnni(
    acts_int8: &[i8],
    act_scales: &[f32],
    m: usize,
    k: usize,
    weight_bytes: &[u8],
    n: usize,
    weight_scale: f32,
    output: &mut [f32],
) {
    // TODO: fill in. For now delegate to the scalar path so the public
    // API works on every machine — this is the scaffolding for the
    // follow-up perf commit.
    matmul_i2s_i8_scalar(
        acts_int8,
        act_scales,
        m,
        k,
        weight_bytes,
        n,
        weight_scale,
        output,
    );
}
582
583// =============================================================================
584// Size helpers
585// =============================================================================
586
587/// Bytes needed to store `n_elements` I2_S weights (excluding the 4-byte
588/// tensor scale and any alignment padding — those live outside the packed
589/// stream).
590pub fn bytes_for_elements(n_elements: usize) -> usize {
591    (n_elements / I2S_BLOCK_SIZE) * I2S_BYTES_PER_BLOCK
592}
593
594// =============================================================================
595// Tests
596// =============================================================================
597
598#[cfg(test)]
599mod tests {
600    use super::*;
601
602    fn make_trits(pattern: &[i8]) -> [i8; I2S_BLOCK_SIZE] {
603        let mut out = [0i8; I2S_BLOCK_SIZE];
604        for (i, v) in pattern.iter().cycle().take(I2S_BLOCK_SIZE).enumerate() {
605            out[i] = *v;
606        }
607        out
608    }
609
    #[test]
    fn trit_encode_decode_roundtrip() {
        // Decoder covers all four 2-bit codes; code 3 is defensive → 0.
        assert_eq!(decode_trit(0b00), -1);
        assert_eq!(decode_trit(0b01), 0);
        assert_eq!(decode_trit(0b10), 1);
        assert_eq!(decode_trit(0b11), 0);
        // Encoder is the inverse on canonical trits…
        assert_eq!(encode_trit(-1), 0);
        assert_eq!(encode_trit(0), 1);
        assert_eq!(encode_trit(1), 2);
        // …and clamps out-of-range magnitudes to the nearest trit code.
        assert_eq!(encode_trit(42), 2);
        assert_eq!(encode_trit(-42), 0);
    }
623
    #[test]
    fn block_pack_unpack_roundtrip() {
        // pack → unpack must be the identity for any in-range trit pattern.
        let values = make_trits(&[-1, 0, 1, 0, 1, -1, 0, 0]);
        let block = I2sBlock::pack(&values);
        let decoded = block.unpack();
        assert_eq!(&values[..], &decoded[..]);
    }
631
    #[test]
    fn block_bytes_roundtrip() {
        // to_bytes → from_bytes preserves both the raw bytes and the
        // decoded trits.
        let values = make_trits(&[1, -1, 0]);
        let block = I2sBlock::pack(&values);
        let bytes = block.to_bytes();
        let parsed = I2sBlock::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.data, block.data);
        assert_eq!(parsed.unpack(), values);
    }
641
    #[test]
    fn dequantize_single_block() {
        // Alternating +1 / -1 with scale = 2.5 → +2.5 / -2.5 at even/odd
        // linear positions.
        let values = make_trits(&[1, -1]);
        let block = I2sBlock::pack(&values);
        let bytes = block.to_bytes();
        let mut out = [0.0f32; I2S_BLOCK_SIZE];
        dequantize_i2s_block(&bytes, 2.5, &mut out);
        for (i, v) in out.iter().enumerate() {
            let expected = if i % 2 == 0 { 2.5 } else { -2.5 };
            assert!(
                (v - expected).abs() < 1e-6,
                "idx {i}: got {v}, expected {expected}",
            );
        }
    }
658
    #[test]
    fn group_strided_layout_is_correct() {
        // Weight at logical position `group_idx * 32 + group_pos` must be
        // stored in byte `group_pos`'s bit-slice `(6 - 2*group_idx)..(8 - 2*group_idx)`.
        // Set only position 0 to +1 (rest 0) and verify the raw bytes match
        // the Microsoft encoder formula directly.
        let mut values = [0i8; I2S_BLOCK_SIZE];
        values[0] = 1; // group 0, pos 0 → byte 0 bits 6..7 = code 2 = 0b10
        let block = I2sBlock::pack(&values);
        assert_eq!(
            block.data[0] & 0b1100_0000,
            0b1000_0000,
            "expected code=2 (+1) in byte 0 bits 6-7",
        );
        // Every other logical position holds trit 0, which encodes as code 1
        // (0b01). Byte 0 bits 4-5 carry group 1 / pos 0, so they must read 1.
        assert_eq!(
            (block.data[0] >> 4) & 0b11,
            1,
            "expected code=1 (0) in byte 0 bits 4-5",
        );
    }
682
    #[test]
    fn dequantize_multi_block_tensor() {
        // Three blocks with distinct cyclic patterns; dequantize the whole
        // buffer and spot-check the first trit of each block.
        let n_blocks = 3;
        let n_elem = n_blocks * I2S_BLOCK_SIZE;
        let mut bytes = Vec::with_capacity(n_blocks * I2S_BYTES_PER_BLOCK);
        let patterns: &[&[i8]] = &[&[1, 0, -1], &[-1, 1, 0], &[0, 0, 1, -1]];
        for b in 0..n_blocks {
            let block = I2sBlock::pack(&make_trits(patterns[b]));
            bytes.extend_from_slice(&block.to_bytes());
        }
        let mut out = vec![0.0f32; n_elem];
        dequantize_i2s(&bytes, 1.0, &mut out);

        // Spot-check first trit of each block (scale = 1.0, so values are
        // the raw trits).
        assert_eq!(out[0], 1.0);
        assert_eq!(out[I2S_BLOCK_SIZE], -1.0);
        assert_eq!(out[2 * I2S_BLOCK_SIZE], 0.0);
    }
701
702    fn reference_matmul(
703        activations: &[f32],
704        m: usize,
705        k: usize,
706        weight_bytes: &[u8],
707        n: usize,
708        scale: f32,
709        output: &mut [f32],
710    ) {
711        let mut w = vec![0.0f32; n * k];
712        dequantize_i2s(weight_bytes, scale, &mut w);
713        for i in 0..m {
714            for j in 0..n {
715                let mut s = 0.0f32;
716                for kk in 0..k {
717                    s += activations[i * k + kk] * w[j * k + kk];
718                }
719                output[i * n + j] = s;
720            }
721        }
722    }
723
    #[test]
    fn matmul_matches_reference_small() {
        // Single-block k with 4 distinct weight columns: the fused kernel
        // must agree with the dequantize-then-naive-matmul baseline.
        let m = 2;
        let k = I2S_BLOCK_SIZE;
        let n = 4;
        let scale = 0.125f32;

        let mut weight_bytes = Vec::new();
        let patterns: &[&[i8]] = &[&[1, 0, -1], &[-1, 1, 0], &[0, -1, 1], &[1, 1, -1, -1]];
        for j in 0..n {
            let vals = make_trits(patterns[j]);
            let block = I2sBlock::pack(&vals);
            weight_bytes.extend_from_slice(&block.to_bytes());
        }

        // Deterministic activations: row i is a ramp scaled by (i + 1).
        let mut activations = vec![0.0f32; m * k];
        for i in 0..m {
            for kk in 0..k {
                activations[i * k + kk] = (i as f32 + 1.0) * (kk as f32 / k as f32);
            }
        }

        let mut fused_out = vec![0.0f32; m * n];
        let mut ref_out = vec![0.0f32; m * n];
        matmul_i2s(&activations, m, k, &weight_bytes, n, scale, &mut fused_out);
        reference_matmul(&activations, m, k, &weight_bytes, n, scale, &mut ref_out);

        for (i, (f, r)) in fused_out.iter().zip(ref_out.iter()).enumerate() {
            assert!((f - r).abs() < 1e-5, "mismatch at {i}: fused={f}, ref={r}",);
        }
    }
755
    #[test]
    fn matmul_matches_reference_multi_block() {
        // Multi-block rows (k = 3 × 128) with per-(column, block) alternating
        // patterns — exercises the block-stride walk in the fused kernel.
        let m = 3;
        let k = 3 * I2S_BLOCK_SIZE;
        let n = 5;
        let scale = 0.25f32;

        let mut weight_bytes = Vec::new();
        for j in 0..n {
            for b in 0..(k / I2S_BLOCK_SIZE) {
                let pattern = if (j + b) % 2 == 0 {
                    &[1, 0, -1, 1, -1][..]
                } else {
                    &[-1, -1, 1, 0, 1][..]
                };
                let block = I2sBlock::pack(&make_trits(pattern));
                weight_bytes.extend_from_slice(&block.to_bytes());
            }
        }

        // Deterministic, non-trivial activations (sin ramp per row).
        let mut activations = vec![0.0f32; m * k];
        for i in 0..m {
            for kk in 0..k {
                activations[i * k + kk] = ((i + 1) as f32) * ((kk as f32).sin());
            }
        }

        let mut fused_out = vec![0.0f32; m * n];
        let mut ref_out = vec![0.0f32; m * n];
        matmul_i2s(&activations, m, k, &weight_bytes, n, scale, &mut fused_out);
        reference_matmul(&activations, m, k, &weight_bytes, n, scale, &mut ref_out);

        // Looser tolerance than the single-block test — longer FP sums.
        for (i, (f, r)) in fused_out.iter().zip(ref_out.iter()).enumerate() {
            assert!((f - r).abs() < 1e-4, "mismatch at {i}: fused={f}, ref={r}",);
        }
    }
792
793    #[test]
794    fn bytes_for_elements_calculation() {
795        assert_eq!(bytes_for_elements(128), 32);
796        assert_eq!(bytes_for_elements(256), 64);
797        assert_eq!(bytes_for_elements(1024), 256);
798        assert_eq!(bytes_for_elements(0), 0);
799    }
800
801    #[test]
802    fn int8_matmul_matches_f32_within_quant_error() {
803        // Quantize activations + run both paths. They should agree within
804        // the per-row absmax/127 quantization noise.
805        let m = 2;
806        let k = 2 * I2S_BLOCK_SIZE;
807        let n = 6;
808        let weight_scale = 0.1f32;
809
810        let mut weight_bytes = Vec::new();
811        for j in 0..n {
812            for b in 0..(k / I2S_BLOCK_SIZE) {
813                let pattern: &[i8] = if (j + b) % 2 == 0 {
814                    &[1, 0, -1, 1]
815                } else {
816                    &[-1, 1, 0, -1]
817                };
818                let block = I2sBlock::pack(&make_trits(pattern));
819                weight_bytes.extend_from_slice(&block.to_bytes());
820            }
821        }
822
823        // f32 activations, deterministic-ish.
824        let mut activations = vec![0.0f32; m * k];
825        for i in 0..m {
826            for kk in 0..k {
827                activations[i * k + kk] = ((kk as f32) * 0.13 - 2.0).sin() * (1.0 + i as f32 * 0.1);
828            }
829        }
830
831        let mut ref_out = vec![0.0f32; m * n];
832        matmul_i2s(
833            &activations,
834            m,
835            k,
836            &weight_bytes,
837            n,
838            weight_scale,
839            &mut ref_out,
840        );
841
842        // Quantize activations per row.
843        let mut acts_i8 = vec![0i8; m * k];
844        let mut act_scales = vec![0.0f32; m];
845        for i in 0..m {
846            act_scales[i] = quantize_row_to_int8(
847                &activations[i * k..(i + 1) * k],
848                &mut acts_i8[i * k..(i + 1) * k],
849            );
850        }
851
852        let mut i8_out = vec![0.0f32; m * n];
853        matmul_i2s_i8(
854            &acts_i8,
855            &act_scales,
856            m,
857            k,
858            &weight_bytes,
859            n,
860            weight_scale,
861            &mut i8_out,
862        );
863
864        // Expected error: per-activation quantization error is ≤ scale/2,
865        // and the dot product sums k=256 terms, so worst-case error scales
866        // with sqrt(k) × max_weight_scale × act_scale/2. For our tiny test,
867        // a relative tolerance of ~3% against the f32 reference is generous
868        // enough to catch logic errors without flaking.
869        for (i, (&r, &q)) in ref_out.iter().zip(i8_out.iter()).enumerate() {
870            let abs_err = (r - q).abs();
871            let rel_err = abs_err / r.abs().max(1e-6);
872            assert!(
873                rel_err < 0.05 || abs_err < 1e-3,
874                "idx {i}: f32 ref = {r}, int8 quantized = {q}, rel_err = {rel_err}",
875            );
876        }
877    }
878
879    #[test]
880    fn quantize_row_to_int8_roundtrip() {
881        let input = [1.0f32, -2.0, 0.5, -0.5, 0.0, 2.0, -1.5];
882        let mut output = [0i8; 7];
883        let scale = quantize_row_to_int8(&input, &mut output);
884        assert!(scale > 0.0);
885        // Largest magnitude is 2.0; should map to ±127.
886        let max_idx = input
887            .iter()
888            .enumerate()
889            .max_by(|a, b| a.1.abs().partial_cmp(&b.1.abs()).unwrap())
890            .unwrap()
891            .0;
892        assert_eq!(output[max_idx].unsigned_abs(), 127);
893        // Dequantized values should be close to the originals.
894        for (i, &v) in input.iter().enumerate() {
895            let recovered = output[i] as f32 * scale;
896            assert!(
897                (recovered - v).abs() < scale,
898                "idx {i}: {v} → {} (scale={scale})",
899                recovered
900            );
901        }
902    }
903
904    #[test]
905    fn quantize_row_to_int8_zero_input() {
906        let input = [0.0f32; 8];
907        let mut output = [0i8; 8];
908        let scale = quantize_row_to_int8(&input, &mut output);
909        assert_eq!(scale, 0.0);
910        assert!(output.iter().all(|&x| x == 0));
911    }
912
913    #[test]
914    #[should_panic(expected = "k")]
915    fn matmul_rejects_misaligned_k() {
916        let m = 1;
917        let k = 100;
918        let n = 1;
919        let acts = vec![0.0; m * k];
920        let weight_bytes = vec![0u8; I2S_BYTES_PER_BLOCK];
921        let mut out = vec![0.0; m * n];
922        matmul_i2s(&acts, m, k, &weight_bytes, n, 1.0, &mut out);
923    }
924}