//! BitNet b1.58 I2_S Ternary Quantization — Dequant + Fused Add-Only Matmul
//!
//! Implements Microsoft's `I2_S` quant type (GGUF dtype 36), used by the
//! `bitnet.cpp` reference kernels and by every official BitNet b1.58 GGUF
//! release (including `microsoft/bitnet-b1.58-2B-4T-gguf`).
//!
//! Contents:
//! - Constants `I2S_BLOCK_SIZE`, `I2S_BYTES_PER_BLOCK`, `I2S_GROUP_SIZE`
//!   mirroring Microsoft's `QK_I2_S` / group-strided layout.
//! - Trit ↔ 2-bit code converters `decode_trit` / `encode_trit`.
//! - `I2sBlock` struct with `pack` / `unpack` / `to_bytes` / `from_bytes`
//!   implementing the group-strided byte layout.
//! - `dequantize_i2s_block` and rayon-parallel `dequantize_i2s` for
//!   recovering f32 weights from packed bytes.
//! - `matmul_i2s` — fused add-only ternary matmul (f32 activations).
//! - `matmul_i2s_i8` — int8-activation fused path with runtime AVX-VNNI
//!   dispatch, scalar fallback `matmul_i2s_i8_scalar`, and an in-progress
//!   `matmul_i2s_i8_avxvnni` unsafe target-feature stub.
//! - `quantize_row_to_int8` per-row absmax int8 quantization for
//!   activations entering the int8 fast path.
//! - `bytes_for_elements` size helper and a test module covering
//!   trit encode/decode, block roundtrip, layout correctness, reference
//!   matmul agreement, int8 vs f32 agreement, and misaligned-`k` rejection.
//!
//! # Format (verified against `microsoft/BitNet` reference, 2026-04-14)
//!
//! - **Block size: 128 weights** (`QK_I2_S = 128` on x86_64).
//! - **Block stride: 32 bytes** (128 × 2 bits packed).
//! - **Encoding per 2-bit code:** `0 → -1`, `1 → 0`, `2 → +1`, `3 → unused`.
//!   (From `quantize_i2_s` in `ggml-bitnet-mad.cpp`:
//!   `"q8 -> 0, 1, 2 | | | -1, 0, 1"`.)
//! - **Intra-block layout is NOT `4 consecutive weights / byte`** — it's a
//!   SIMD-friendly group-strided layout. Each 32-byte block stores 128
//!   weights as 4 groups of 32, multiplexed into the 2-bit positions of
//!   the 32 bytes:
//!
//!   - byte `k` bits **6..7** → weight `k` (group 0, shift 6)
//!   - byte `k` bits **4..5** → weight `k + 32` (group 1, shift 4)
//!   - byte `k` bits **2..3** → weight `k + 64` (group 2, shift 2)
//!   - byte `k` bits **0..1** → weight `k + 96` (group 3, shift 0)
//!
//!   (Encoder in the reference: `temp = q8 << (6 - 2*group_idx)`; decoder
//!   shifts the byte right by `6 - 2*group_idx` and ANDs with `0x03`.) The
//!   layout lets AVX2 load 32 bytes once, shift by 0/2/4/6 with a mask of
//!   `0x03`, and extract 128 trit codes in four 32-wide vector registers.
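//!
//!   Worked example: byte value `0b10_01_00_11` at byte position `k = 5`
//!   decodes to weight `5` = +1 (code 2), weight `37` = 0 (code 1),
//!   weight `69` = -1 (code 0), and weight `101` = 0 (code 3, decoded
//!   defensively).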
//!
//! - **One tensor-wide f32 scale** follows the packed data at offset
//!   `m * k / 4` bytes, padded to the next 32-byte boundary. The
//!   dequantized weight is `scale × trit[i]`. This module requires the
//!   caller to pass the scale separately — [`matmul_i2s`] and
//!   [`dequantize_i2s`] both accept `scale: f32`.
//!
//! # Why this is fast on CPU
//!
//! Matmul between an f32 activation matrix and ternary weights becomes
//! branchless accumulate-and-subtract: for each weight, either add the
//! activation (+1), subtract it (-1), or skip (0). The tensor-wide scale
//! applies once at the end of each output element. That's BitNet's
//! "add-only matmul" performance story. A SIMD fast path mirroring the
//! reference AVX2 kernel is a natural follow-up once the scalar path is
//! verified against Microsoft's released weights.
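//!
//! In scalar form the add-only step looks like the following sketch
//! (illustrative only; the real inner loop lives in `dot_row_ternary`
//! below):
//!
//! ```ignore
//! // Per weight: add, subtract, or skip the activation. No multiply.
//! match trit {
//!     1 => acc += act,
//!     -1 => acc -= act,
//!     _ => {}
//! }
//! ```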
//!
//! # References
//! - Microsoft BitNet paper ("The Era of 1-bit LLMs"): <https://arxiv.org/abs/2402.17764>
//! - `microsoft/BitNet` on GitHub (bitnet.cpp reference kernels)
//!
//! # File
//! `crates/axonml-quant/src/bitnet.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 16, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use rayon::prelude::*;

// =============================================================================
// Constants
// =============================================================================

/// Weights per I2_S block (Microsoft `QK_I2_S` on x86_64).
pub const I2S_BLOCK_SIZE: usize = 128;

/// Bytes per I2_S block (128 × 2 bits).
pub const I2S_BYTES_PER_BLOCK: usize = 32;

/// Group size within a block — 4 groups of 32 weights share the same 32
/// bytes but live at different bit positions.
pub const I2S_GROUP_SIZE: usize = 32;

// =============================================================================
// Trit <-> 2-bit encoding (Microsoft bitnet.cpp convention)
// =============================================================================

/// Decode a 2-bit code to a trit in `{-1, 0, +1}`.
///
/// Encoding (Microsoft bitnet.cpp): `0 → -1`, `1 → 0`, `2 → +1`, `3 → 0`
/// (defensive; code 3 is unused by the reference encoder).
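///
/// # Example
///
/// A quick sanity check of the code table (import path is illustrative):
///
/// ```ignore
/// use axonml_quant::bitnet::decode_trit;
/// assert_eq!(decode_trit(0b00), -1);
/// assert_eq!(decode_trit(0b01), 0);
/// assert_eq!(decode_trit(0b10), 1);
/// ```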
#[inline(always)]
pub const fn decode_trit(code: u8) -> i8 {
    match code & 0b11 {
        0 => -1,
        1 => 0,
        2 => 1,
        _ => 0,
    }
}

/// Encode a trit to a 2-bit code (inverse of [`decode_trit`]). Values
/// outside `{-1, 0, +1}` are clamped by sign.
#[inline(always)]
const fn encode_trit(v: i8) -> u8 {
    if v > 0 {
        2
    } else if v < 0 {
        0
    } else {
        1
    }
}

// =============================================================================
// Block pack / unpack (primarily for tests; production loads raw bytes)
// =============================================================================

/// A single I2_S block: 128 ternary weights in the group-strided layout.
///
/// The tensor-wide scale is not stored on the block — see module docs.
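///
/// # Example
///
/// Pack/unpack roundtrip (import path is illustrative):
///
/// ```ignore
/// use axonml_quant::bitnet::{I2sBlock, I2S_BLOCK_SIZE};
/// let mut trits = [0i8; I2S_BLOCK_SIZE];
/// trits[0] = 1;
/// trits[127] = -1;
/// let block = I2sBlock::pack(&trits);
/// assert_eq!(block.unpack(), trits);
/// ```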
#[derive(Debug, Clone)]
pub struct I2sBlock {
    /// 32 bytes holding 128 × 2-bit trits in group-strided form.
    pub data: [u8; I2S_BYTES_PER_BLOCK],
}

impl I2sBlock {
    /// Pack 128 trit values `{-1, 0, +1}` using the Microsoft layout.
    pub fn pack(values: &[i8; I2S_BLOCK_SIZE]) -> Self {
        let mut data = [0u8; I2S_BYTES_PER_BLOCK];
        // For each (group_idx, group_pos), OR the code at bits [6-2*g .. 7-2*g].
        for group_idx in 0..4 {
            let shift = 6 - 2 * group_idx;
            for group_pos in 0..I2S_GROUP_SIZE {
                let code = encode_trit(values[group_idx * I2S_GROUP_SIZE + group_pos]);
                data[group_pos] |= code << shift;
            }
        }
        Self { data }
    }

    /// 32-byte raw view.
    pub fn to_bytes(&self) -> [u8; I2S_BYTES_PER_BLOCK] {
        self.data
    }

    /// Parse a block from a 32-byte slice.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        if bytes.len() < I2S_BYTES_PER_BLOCK {
            return None;
        }
        let mut data = [0u8; I2S_BYTES_PER_BLOCK];
        data.copy_from_slice(&bytes[..I2S_BYTES_PER_BLOCK]);
        Some(Self { data })
    }

    /// Unpack all 128 trits in linear order `0..128`.
    pub fn unpack(&self) -> [i8; I2S_BLOCK_SIZE] {
        let mut out = [0i8; I2S_BLOCK_SIZE];
        for group_idx in 0..4 {
            let shift = 6 - 2 * group_idx;
            for group_pos in 0..I2S_GROUP_SIZE {
                let code = (self.data[group_pos] >> shift) & 0b11;
                out[group_idx * I2S_GROUP_SIZE + group_pos] = decode_trit(code);
            }
        }
        out
    }
}

// =============================================================================
// Dequantization
// =============================================================================

/// Dequantize a single I2_S block to 128 `f32` values.
///
/// Applies the tensor-wide `scale` so `out[i] = scale × trit[i]`.
///
/// # Panics
/// Debug-only: panics if `bytes.len() < 32` or `out.len() < 128`.
pub fn dequantize_i2s_block(bytes: &[u8], scale: f32, out: &mut [f32]) {
    debug_assert!(bytes.len() >= I2S_BYTES_PER_BLOCK);
    debug_assert!(out.len() >= I2S_BLOCK_SIZE);
    let b = &bytes[..I2S_BYTES_PER_BLOCK];
    for (group_pos, &byte) in b.iter().enumerate() {
        let c0 = decode_trit((byte >> 6) & 0b11) as f32; // group 0
        let c1 = decode_trit((byte >> 4) & 0b11) as f32; // group 1
        let c2 = decode_trit((byte >> 2) & 0b11) as f32; // group 2
        let c3 = decode_trit(byte & 0b11) as f32; // group 3
        out[group_pos] = c0 * scale;
        out[group_pos + I2S_GROUP_SIZE] = c1 * scale;
        out[group_pos + 2 * I2S_GROUP_SIZE] = c2 * scale;
        out[group_pos + 3 * I2S_GROUP_SIZE] = c3 * scale;
    }
}
/// Dequantize a full I2_S weight buffer to f32.
///
/// `out.len()` should be a multiple of [`I2S_BLOCK_SIZE`]; any trailing
/// partial block is skipped rather than dequantized. Rayon-parallelized
/// over blocks.
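///
/// # Example
///
/// A minimal sketch, assuming the crate is importable as `axonml_quant`;
/// a byte of `0b0101_0101` encodes four zero trits:
///
/// ```ignore
/// use axonml_quant::bitnet::{dequantize_i2s, I2S_BLOCK_SIZE, I2S_BYTES_PER_BLOCK};
/// let packed = vec![0b0101_0101u8; 2 * I2S_BYTES_PER_BLOCK]; // all-zero trits
/// let mut weights = vec![0.0f32; 2 * I2S_BLOCK_SIZE];
/// dequantize_i2s(&packed, 0.5, &mut weights);
/// assert!(weights.iter().all(|&w| w == 0.0));
/// ```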
pub fn dequantize_i2s(bytes: &[u8], scale: f32, out: &mut [f32]) {
    let n_blocks = out.len() / I2S_BLOCK_SIZE;
    out.par_chunks_mut(I2S_BLOCK_SIZE)
        .take(n_blocks)
        .zip(bytes.par_chunks(I2S_BYTES_PER_BLOCK).take(n_blocks))
        .for_each(|(out_block, in_block)| {
            if in_block.len() >= I2S_BYTES_PER_BLOCK {
                dequantize_i2s_block(in_block, scale, out_block);
            }
        });
}

// =============================================================================
// Fused ternary matmul
// =============================================================================

/// Fused add-only ternary matmul: `output = scale × (activations @ weights^T)`.
///
/// # Shapes
/// - `activations`: `[m, k]` row-major f32
/// - `weight_bytes`: `[n, k]` — each output row is a contiguous run of
///   `k / 128` I2_S blocks (32 bytes each). `k` **must** be a multiple
///   of [`I2S_BLOCK_SIZE`] (128).
/// - `scale`: tensor-wide f32 scale read from the tail of the GGUF tensor
///   (see module docs)
/// - `output`: `[m, n]` row-major f32
///
/// # Panics
/// Panics if `k % 128 != 0`, or if shapes don't line up.
///
/// # Parallelism
/// Parallelizes over the `n` output columns. That is the right grain for
/// decode (`m == 1`): one independent dot product per column.
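///
/// # Example
///
/// A minimal sketch, assuming the crate is importable as `axonml_quant`:
///
/// ```ignore
/// use axonml_quant::bitnet::{matmul_i2s, I2S_BLOCK_SIZE, I2S_BYTES_PER_BLOCK};
/// let (m, k, n) = (1, I2S_BLOCK_SIZE, 2);
/// let acts = vec![1.0f32; m * k];
/// let weights = vec![0b0101_0101u8; n * I2S_BYTES_PER_BLOCK]; // all-zero trits
/// let mut out = vec![0.0f32; m * n];
/// matmul_i2s(&acts, m, k, &weights, n, 0.25, &mut out);
/// assert_eq!(out, vec![0.0, 0.0]);
/// ```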
pub fn matmul_i2s(
    activations: &[f32],
    m: usize,
    k: usize,
    weight_bytes: &[u8],
    n: usize,
    scale: f32,
    output: &mut [f32],
) {
    assert!(
        k % I2S_BLOCK_SIZE == 0,
        "matmul_i2s: k ({k}) must be a multiple of {I2S_BLOCK_SIZE}",
    );
    assert_eq!(activations.len(), m * k, "activations shape mismatch");
    assert_eq!(output.len(), m * n, "output shape mismatch");
    let blocks_per_row = k / I2S_BLOCK_SIZE;
    let bytes_per_row = blocks_per_row * I2S_BYTES_PER_BLOCK;
    assert_eq!(
        weight_bytes.len(),
        n * bytes_per_row,
        "weight_bytes shape mismatch",
    );

    for i in 0..m {
        let act_row = &activations[i * k..(i + 1) * k];
        output[i * n..(i + 1) * n]
            .par_iter_mut()
            .enumerate()
            .for_each(|(j, out_slot)| {
                let wrow = &weight_bytes[j * bytes_per_row..(j + 1) * bytes_per_row];
                *out_slot = dot_row_ternary(act_row, wrow, blocks_per_row) * scale;
            });
    }
}

/// Inner dot product for one activation row × one ternary weight row, in
/// the group-strided layout. Returns the UNSCALED sum — caller multiplies
/// by the tensor-wide scale.
#[inline(always)]
fn dot_row_ternary(act_row: &[f32], wrow: &[u8], blocks_per_row: usize) -> f32 {
    let mut acc = 0.0f32;
    for block_idx in 0..blocks_per_row {
        let block_off = block_idx * I2S_BYTES_PER_BLOCK;
        let block = &wrow[block_off..block_off + I2S_BYTES_PER_BLOCK];
        let k_base = block_idx * I2S_BLOCK_SIZE;
        // Autovec-friendly: pull the four group activations into locals,
        // decode once per byte, accumulate into four separate accumulators
        // (lets LLVM issue independent FMAs on wide targets).
        let mut a0 = 0.0f32;
        let mut a1 = 0.0f32;
        let mut a2 = 0.0f32;
        let mut a3 = 0.0f32;
        for (group_pos, &byte) in block.iter().enumerate() {
            let t0 = decode_trit((byte >> 6) & 0b11) as f32;
            let t1 = decode_trit((byte >> 4) & 0b11) as f32;
            let t2 = decode_trit((byte >> 2) & 0b11) as f32;
            let t3 = decode_trit(byte & 0b11) as f32;
            let base = k_base + group_pos;
            a0 += act_row[base] * t0;
            a1 += act_row[base + I2S_GROUP_SIZE] * t1;
            a2 += act_row[base + 2 * I2S_GROUP_SIZE] * t2;
            a3 += act_row[base + 3 * I2S_GROUP_SIZE] * t3;
        }
        acc += a0 + a1 + a2 + a3;
    }
    acc
}

// =============================================================================
// Int8-activation fused ternary matmul (AVX-VNNI path)
// =============================================================================
//
// This is the main performance lever (the targeted "30-50% over llama.cpp"
// win). The scalar `matmul_i2s` above works on f32 activations and decodes
// trit codes to f32 one byte at a time — memory-BW-bound, with roughly one
// instruction per trit on CPU.
//
// The fused int8 path does three things differently:
//
// 1. Activations are quantized to int8 with a per-row absmax scale before
//    the matmul. Bandwidth drops 4× (f32 → i8) on the activation side.
// 2. Trit codes stay in their 2-bit packed form and never decode to f32.
//    Bandwidth drops 16× on the weight side (f32 → 2 bits).
// 3. The dot product uses the VNNI `dpbusd` instruction (`_mm256_dpbusd_epi32`
//    on AVX-VNNI, `_mm512_dpbusd_epi32` on AVX-512 VNNI). dpbusd does 32
//    unsigned-byte × signed-byte multiplies and sums groups of 4 into int32
//    lanes, so one instruction handles 32 weight-activation pairs.
//
// Arithmetic trick: the trit codes `{0, 1, 2}` map to `{-1, 0, +1}` via
// `trit = code - 1`. dpbusd computes `sum(code × act)`, not `sum(trit × act)`.
// We recover the true dot product with a single per-row correction:
//
//     true_dot_j = code_dot_j - act_sum
//
// where `act_sum = sum_k(act_i[k])` (a single scalar per input row, computed
// once). The correction is two cheap int32 ops per output element.
//
// Microsoft's bitnet.cpp does this same accounting in its AVX2 kernels — we
// match their approach and then tune for Arrow Lake's AVX-VNNI-but-no-AVX-512
// ISA profile. Reference in `ggml-bitnet-mad.cpp::ggml_vec_dot_i2_i8_s_1x1`.
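//
// Worked example (k = 4): acts = [3, -2, 1, 0] and trits = [+1, -1, 0, +1],
// i.e. codes = [2, 0, 1, 2]. Then code_dot = 2*3 + 0*(-2) + 1*1 + 2*0 = 7
// and act_sum = 3 - 2 + 1 + 0 = 2, so true_dot = 7 - 2 = 5, which matches
// (+1)*3 + (-1)*(-2) + 0*1 + (+1)*0 = 5.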

/// Quantize an f32 row to int8 with a per-row absmax scale such that
/// `f32[i] ≈ int8[i] * scale`. The scale is chosen so the largest-
/// magnitude element maps to ±127.
///
/// Returns the scale. Intended for **activation quantization** at the
/// beginning of an I2_S × int8 matmul — the matmul output multiplies by
/// this scale (along with the weight's tensor-wide scale) to recover f32
/// logits.
///
/// # Edge cases
/// - All-zero input → scale = 0, output is all zeros.
/// - Absmax is sensitive to a single enormous outlier; callers with
///   outlier-heavy activations may want to clamp or use percentile-absmax
///   instead. Not a concern for typical post-norm transformer activations.
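///
/// # Example
///
/// A minimal sketch, assuming the crate is importable as `axonml_quant`:
///
/// ```ignore
/// use axonml_quant::bitnet::quantize_row_to_int8;
/// let row = [0.5f32, -1.0, 0.25, 0.0];
/// let mut q = [0i8; 4];
/// let scale = quantize_row_to_int8(&row, &mut q);
/// assert_eq!(q[1], -127); // the absmax element maps to ±127
/// assert!((scale - 1.0 / 127.0).abs() < 1e-9);
/// ```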
pub fn quantize_row_to_int8(input: &[f32], output: &mut [i8]) -> f32 {
    debug_assert_eq!(input.len(), output.len());
    let absmax = input.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
    if absmax == 0.0 {
        for o in output.iter_mut() {
            *o = 0;
        }
        return 0.0;
    }
    let scale = absmax / 127.0;
    let inv_scale = 1.0 / scale;
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        let q = (v * inv_scale).round();
        // Clamp to i8 range defensively — float rounding error in
        // `inv_scale` could otherwise push a value past ±127.
        *o = q.clamp(-127.0, 127.0) as i8;
    }
    scale
}

/// Fused I2_S × int8 matmul: `out = weight_scale × (acts_int8 @ weights^T) × act_scales_per_row`.
///
/// # Shapes
/// - `acts_int8`: `[m, k]` row-major int8 (quantized via [`quantize_row_to_int8`])
/// - `act_scales`: `[m]` — per-row f32 scales from the int8 quantization
/// - `weight_bytes`: `[n, k]` as I2_S blocks (same layout as [`matmul_i2s`])
/// - `weight_scale`: tensor-wide f32 scale (from GGUF tail)
/// - `output`: `[m, n]` row-major f32
///
/// # Math
/// For each `(i, j)`:
///
/// ```text
/// out[i, j] = act_scales[i] * weight_scale * sum_k(trit[j, k] * act_i8[i, k])
///           = act_scales[i] * weight_scale * (sum_k(code[j, k] * act_i8[i, k]) - act_sum[i])
/// ```
///
/// where `code[j, k] ∈ {0, 1, 2}` is the raw 2-bit code and `act_sum[i] = sum_k(act_i8[i, k])`.
///
/// # Dispatch
/// At runtime we check for AVX-VNNI via `is_x86_feature_detected!("avxvnni")`
/// and take the SIMD path when available; otherwise we fall back to a scalar
/// reference that matches the SIMD path bit-for-bit (lets tests run on any
/// host).
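///
/// # Example
///
/// End-to-end sketch of the int8 pipeline, assuming the crate is importable
/// as `axonml_quant`: quantize each activation row, then call the fused
/// matmul with the per-row scales.
///
/// ```ignore
/// use axonml_quant::bitnet::{
///     matmul_i2s_i8, quantize_row_to_int8, I2S_BLOCK_SIZE, I2S_BYTES_PER_BLOCK,
/// };
/// let (m, k, n) = (1, I2S_BLOCK_SIZE, 2);
/// let acts_f32 = vec![0.5f32; m * k];
/// let mut acts_i8 = vec![0i8; m * k];
/// let act_scale = quantize_row_to_int8(&acts_f32, &mut acts_i8);
/// let weights = vec![0b0101_0101u8; n * I2S_BYTES_PER_BLOCK]; // all-zero trits
/// let mut out = vec![0.0f32; m * n];
/// matmul_i2s_i8(&acts_i8, &[act_scale], m, k, &weights, n, 0.1, &mut out);
/// assert_eq!(out, vec![0.0, 0.0]);
/// ```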
pub fn matmul_i2s_i8(
    acts_int8: &[i8],
    act_scales: &[f32],
    m: usize,
    k: usize,
    weight_bytes: &[u8],
    n: usize,
    weight_scale: f32,
    output: &mut [f32],
) {
    assert!(
        k % I2S_BLOCK_SIZE == 0,
        "matmul_i2s_i8: k ({k}) must be a multiple of {I2S_BLOCK_SIZE}",
    );
    assert_eq!(acts_int8.len(), m * k, "acts_int8 shape mismatch");
    assert_eq!(act_scales.len(), m, "act_scales length mismatch");
    assert_eq!(output.len(), m * n, "output shape mismatch");
    let blocks_per_row = k / I2S_BLOCK_SIZE;
    let bytes_per_row = blocks_per_row * I2S_BYTES_PER_BLOCK;
    assert_eq!(
        weight_bytes.len(),
        n * bytes_per_row,
        "weight_bytes shape mismatch",
    );

    // Runtime dispatch — the AVX-VNNI variant lands in a follow-up commit;
    // the scalar path below is the correctness baseline.
    #[cfg(target_arch = "x86_64")]
    {
        if std::is_x86_feature_detected!("avxvnni") && std::is_x86_feature_detected!("avx2") {
            // SAFETY: feature-detected above.
            unsafe {
                matmul_i2s_i8_avxvnni(
                    acts_int8,
                    act_scales,
                    m,
                    k,
                    weight_bytes,
                    n,
                    weight_scale,
                    output,
                );
            }
            return;
        }
    }

    matmul_i2s_i8_scalar(
        acts_int8,
        act_scales,
        m,
        k,
        weight_bytes,
        n,
        weight_scale,
        output,
    );
}

/// Scalar reference for [`matmul_i2s_i8`]. Used as a correctness baseline
/// for the AVX-VNNI fast path and as a fallback on non-x86_64 or
/// pre-AVX-VNNI CPUs.
fn matmul_i2s_i8_scalar(
    acts_int8: &[i8],
    act_scales: &[f32],
    m: usize,
    k: usize,
    weight_bytes: &[u8],
    n: usize,
    weight_scale: f32,
    output: &mut [f32],
) {
    let blocks_per_row = k / I2S_BLOCK_SIZE;
    let bytes_per_row = blocks_per_row * I2S_BYTES_PER_BLOCK;

    for i in 0..m {
        let act_row = &acts_int8[i * k..(i + 1) * k];
        let act_scale = act_scales[i];
        let act_sum: i32 = act_row.iter().map(|&x| x as i32).sum();
        let combined_scale = weight_scale * act_scale;

        output[i * n..(i + 1) * n]
            .par_iter_mut()
            .enumerate()
            .for_each(|(j, out_slot)| {
                let wrow = &weight_bytes[j * bytes_per_row..(j + 1) * bytes_per_row];
                // sum_k(code[j,k] * act[k]) with code in {0,1,2}.
                let mut code_dot: i32 = 0;
                for block_idx in 0..blocks_per_row {
                    let block_off = block_idx * I2S_BYTES_PER_BLOCK;
                    let block = &wrow[block_off..block_off + I2S_BYTES_PER_BLOCK];
                    let k_base = block_idx * I2S_BLOCK_SIZE;
                    for (group_pos, &byte) in block.iter().enumerate() {
                        // Bit layout: bits 6-7 → pos k_base+group_pos
                        //             bits 4-5 → pos k_base+32+group_pos
                        //             bits 2-3 → pos k_base+64+group_pos
                        //             bits 0-1 → pos k_base+96+group_pos
                        let c0 = ((byte >> 6) & 0b11) as i32;
                        let c1 = ((byte >> 4) & 0b11) as i32;
                        let c2 = ((byte >> 2) & 0b11) as i32;
                        let c3 = (byte & 0b11) as i32;
                        let base = k_base + group_pos;
                        code_dot += c0 * act_row[base] as i32;
                        code_dot += c1 * act_row[base + I2S_GROUP_SIZE] as i32;
                        code_dot += c2 * act_row[base + 2 * I2S_GROUP_SIZE] as i32;
                        code_dot += c3 * act_row[base + 3 * I2S_GROUP_SIZE] as i32;
                    }
                }
                // trit = code - 1, so sum(trit*act) = sum(code*act) - sum(act).
                let trit_dot = code_dot - act_sum;
                *out_slot = (trit_dot as f32) * combined_scale;
            });
    }
}

/// AVX-VNNI fast path — **unimplemented**. Drop-in replacement for
/// [`matmul_i2s_i8_scalar`] once filled in.
///
/// # Planned inner loop (per block of 128 weights × 128 activations)
///
/// ```ignore
/// // Load 32 weight bytes (128 trits packed) and 128 int8 activations.
/// let bytes_v = _mm256_loadu_si256(block_ptr as *const __m256i);
/// let acts_g0 = _mm256_loadu_si256(act_ptr.add(k_base) as *const __m256i); // pos k_base..+32
/// let acts_g1 = _mm256_loadu_si256(act_ptr.add(k_base + 32) as *const __m256i);
/// let acts_g2 = _mm256_loadu_si256(act_ptr.add(k_base + 64) as *const __m256i);
/// let acts_g3 = _mm256_loadu_si256(act_ptr.add(k_base + 96) as *const __m256i);
///
/// // Extract 2-bit codes for each of 4 groups.
/// let mask = _mm256_set1_epi8(0x03);
/// let codes_g0 = _mm256_and_si256(_mm256_srli_epi16(bytes_v, 6), mask); // bits 6-7
/// let codes_g1 = _mm256_and_si256(_mm256_srli_epi16(bytes_v, 4), mask); // bits 4-5
/// let codes_g2 = _mm256_and_si256(_mm256_srli_epi16(bytes_v, 2), mask); // bits 2-3
/// let codes_g3 = _mm256_and_si256(bytes_v, mask);                       // bits 0-1
///
/// // 32 × (u8 × i8) → 8 × i32, accumulated.
/// acc = _mm256_dpbusd_epi32(acc, codes_g0, acts_g0);
/// acc = _mm256_dpbusd_epi32(acc, codes_g1, acts_g1);
/// acc = _mm256_dpbusd_epi32(acc, codes_g2, acts_g2);
/// acc = _mm256_dpbusd_epi32(acc, codes_g3, acts_g3);
/// ```
///
/// Four VNNI ops per 128-weight block. The outer loop iterates over the
/// `blocks_per_row` blocks of each output column, then horizontally sums the
/// int32 lanes to a scalar, applies the `- act_sum` correction, and scales
/// by `combined_scale` to f32.
///
/// Rayon fan-out over output columns `n` (same as the scalar path). On
/// Arrow Lake (AVX-VNNI but no AVX-512), expect roughly 8-12× speedup over
/// scalar on the kernel alone; end-to-end wins compound because activation
/// bandwidth drops 4× and weight bandwidth stays 2-bit.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,avxvnni")]
unsafe fn matmul_i2s_i8_avxvnni(
    acts_int8: &[i8],
    act_scales: &[f32],
    m: usize,
    k: usize,
    weight_bytes: &[u8],
    n: usize,
    weight_scale: f32,
    output: &mut [f32],
) {
    // TODO: fill in. For now delegate to the scalar path so the public
    // API works on every machine — this is the scaffolding for the
    // follow-up perf commit.
    matmul_i2s_i8_scalar(
        acts_int8,
        act_scales,
        m,
        k,
        weight_bytes,
        n,
        weight_scale,
        output,
    );
}

// =============================================================================
// Size helpers
// =============================================================================

/// Bytes needed to store `n_elements` I2_S weights (excluding the 4-byte
/// tensor scale and any alignment padding — those live outside the packed
/// stream). `n_elements` should be a multiple of [`I2S_BLOCK_SIZE`]; any
/// remainder is truncated.
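///
/// ```ignore
/// // Import path is illustrative — adjust to your workspace layout.
/// use axonml_quant::bitnet::bytes_for_elements;
/// // 4096 ternary weights → 32 blocks of 128 → 1024 packed bytes.
/// assert_eq!(bytes_for_elements(4096), 1024);
/// ```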
pub fn bytes_for_elements(n_elements: usize) -> usize {
    (n_elements / I2S_BLOCK_SIZE) * I2S_BYTES_PER_BLOCK
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    fn make_trits(pattern: &[i8]) -> [i8; I2S_BLOCK_SIZE] {
        let mut out = [0i8; I2S_BLOCK_SIZE];
        for (i, v) in pattern.iter().cycle().take(I2S_BLOCK_SIZE).enumerate() {
            out[i] = *v;
        }
        out
    }

    #[test]
    fn trit_encode_decode_roundtrip() {
        assert_eq!(decode_trit(0b00), -1);
        assert_eq!(decode_trit(0b01), 0);
        assert_eq!(decode_trit(0b10), 1);
        assert_eq!(decode_trit(0b11), 0);
        assert_eq!(encode_trit(-1), 0);
        assert_eq!(encode_trit(0), 1);
        assert_eq!(encode_trit(1), 2);
        // Clamp.
        assert_eq!(encode_trit(42), 2);
        assert_eq!(encode_trit(-42), 0);
    }

    #[test]
    fn block_pack_unpack_roundtrip() {
        let values = make_trits(&[-1, 0, 1, 0, 1, -1, 0, 0]);
        let block = I2sBlock::pack(&values);
        let decoded = block.unpack();
        assert_eq!(&values[..], &decoded[..]);
    }

    #[test]
    fn block_bytes_roundtrip() {
        let values = make_trits(&[1, -1, 0]);
        let block = I2sBlock::pack(&values);
        let bytes = block.to_bytes();
        let parsed = I2sBlock::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.data, block.data);
        assert_eq!(parsed.unpack(), values);
    }

    #[test]
    fn dequantize_single_block() {
        // Alternating +1 / -1 with scale = 2.5 → +2.5 / -2.5.
        let values = make_trits(&[1, -1]);
        let block = I2sBlock::pack(&values);
        let bytes = block.to_bytes();
        let mut out = [0.0f32; I2S_BLOCK_SIZE];
        dequantize_i2s_block(&bytes, 2.5, &mut out);
        for (i, v) in out.iter().enumerate() {
            let expected = if i % 2 == 0 { 2.5 } else { -2.5 };
            assert!(
                (v - expected).abs() < 1e-6,
                "idx {i}: got {v}, expected {expected}",
            );
        }
    }

    #[test]
    fn group_strided_layout_is_correct() {
        // Weight at logical position `group_idx * 32 + group_pos` must be
        // stored in byte `group_pos`'s bit-slice `(6 - 2*group_idx)..(8 - 2*group_idx)`.
        // Build a block where only position 0 is +1, rest are 0, and verify the
        // byte layout matches the Microsoft encoder formula directly.
        let mut values = [0i8; I2S_BLOCK_SIZE];
        values[0] = 1; // group 0, pos 0 → byte 0 bits 6..7 = code 2 = 0b10
        let block = I2sBlock::pack(&values);
        assert_eq!(
            block.data[0] & 0b1100_0000,
            0b1000_0000,
            "expected code=2 (+1) in byte 0 bits 6-7",
        );
        // Every other position holds trit 0, whose code is 1 (0b01), not
        // code 0 (which would decode to -1). Byte 0 bits 4-5 hold group 1,
        // pos 0 (logical position 32), so they should read code 1:
        assert_eq!(
            (block.data[0] >> 4) & 0b11,
            1,
            "expected code=1 (0) in byte 0 bits 4-5",
        );
    }

    #[test]
    fn dequantize_multi_block_tensor() {
        let n_blocks = 3;
        let n_elem = n_blocks * I2S_BLOCK_SIZE;
        let mut bytes = Vec::with_capacity(n_blocks * I2S_BYTES_PER_BLOCK);
        let patterns: &[&[i8]] = &[&[1, 0, -1], &[-1, 1, 0], &[0, 0, 1, -1]];
        for b in 0..n_blocks {
            let block = I2sBlock::pack(&make_trits(patterns[b]));
            bytes.extend_from_slice(&block.to_bytes());
        }
        let mut out = vec![0.0f32; n_elem];
        dequantize_i2s(&bytes, 1.0, &mut out);

        // Spot-check first trit of each block.
        assert_eq!(out[0], 1.0);
        assert_eq!(out[I2S_BLOCK_SIZE], -1.0);
        assert_eq!(out[2 * I2S_BLOCK_SIZE], 0.0);
    }

    fn reference_matmul(
        activations: &[f32],
        m: usize,
        k: usize,
        weight_bytes: &[u8],
        n: usize,
        scale: f32,
        output: &mut [f32],
    ) {
        let mut w = vec![0.0f32; n * k];
        dequantize_i2s(weight_bytes, scale, &mut w);
        for i in 0..m {
            for j in 0..n {
                let mut s = 0.0f32;
                for kk in 0..k {
                    s += activations[i * k + kk] * w[j * k + kk];
                }
                output[i * n + j] = s;
            }
        }
    }

    #[test]
    fn matmul_matches_reference_small() {
        let m = 2;
        let k = I2S_BLOCK_SIZE;
        let n = 4;
        let scale = 0.125f32;

        let mut weight_bytes = Vec::new();
        let patterns: &[&[i8]] = &[&[1, 0, -1], &[-1, 1, 0], &[0, -1, 1], &[1, 1, -1, -1]];
        for j in 0..n {
            let vals = make_trits(patterns[j]);
            let block = I2sBlock::pack(&vals);
            weight_bytes.extend_from_slice(&block.to_bytes());
        }

        let mut activations = vec![0.0f32; m * k];
        for i in 0..m {
            for kk in 0..k {
                activations[i * k + kk] = (i as f32 + 1.0) * (kk as f32 / k as f32);
            }
        }

        let mut fused_out = vec![0.0f32; m * n];
        let mut ref_out = vec![0.0f32; m * n];
        matmul_i2s(&activations, m, k, &weight_bytes, n, scale, &mut fused_out);
        reference_matmul(&activations, m, k, &weight_bytes, n, scale, &mut ref_out);

        for (i, (f, r)) in fused_out.iter().zip(ref_out.iter()).enumerate() {
            assert!((f - r).abs() < 1e-5, "mismatch at {i}: fused={f}, ref={r}");
        }
    }

    #[test]
    fn matmul_matches_reference_multi_block() {
        let m = 3;
        let k = 3 * I2S_BLOCK_SIZE;
        let n = 5;
        let scale = 0.25f32;

        let mut weight_bytes = Vec::new();
        for j in 0..n {
            for b in 0..(k / I2S_BLOCK_SIZE) {
                let pattern = if (j + b) % 2 == 0 {
                    &[1, 0, -1, 1, -1][..]
                } else {
                    &[-1, -1, 1, 0, 1][..]
                };
                let block = I2sBlock::pack(&make_trits(pattern));
                weight_bytes.extend_from_slice(&block.to_bytes());
            }
        }

        let mut activations = vec![0.0f32; m * k];
        for i in 0..m {
            for kk in 0..k {
                activations[i * k + kk] = ((i + 1) as f32) * ((kk as f32).sin());
            }
        }

        let mut fused_out = vec![0.0f32; m * n];
        let mut ref_out = vec![0.0f32; m * n];
        matmul_i2s(&activations, m, k, &weight_bytes, n, scale, &mut fused_out);
        reference_matmul(&activations, m, k, &weight_bytes, n, scale, &mut ref_out);

        for (i, (f, r)) in fused_out.iter().zip(ref_out.iter()).enumerate() {
            assert!((f - r).abs() < 1e-4, "mismatch at {i}: fused={f}, ref={r}");
        }
    }

    #[test]
    fn bytes_for_elements_calculation() {
        assert_eq!(bytes_for_elements(128), 32);
        assert_eq!(bytes_for_elements(256), 64);
        assert_eq!(bytes_for_elements(1024), 256);
        assert_eq!(bytes_for_elements(0), 0);
    }

    #[test]
    fn int8_matmul_matches_f32_within_quant_error() {
        // Quantize activations + run both paths. They should agree within
        // the per-row absmax/127 quantization noise.
        let m = 2;
        let k = 2 * I2S_BLOCK_SIZE;
        let n = 6;
        let weight_scale = 0.1f32;

        let mut weight_bytes = Vec::new();
        for j in 0..n {
            for b in 0..(k / I2S_BLOCK_SIZE) {
                let pattern: &[i8] = if (j + b) % 2 == 0 {
                    &[1, 0, -1, 1]
                } else {
                    &[-1, 1, 0, -1]
                };
                let block = I2sBlock::pack(&make_trits(pattern));
                weight_bytes.extend_from_slice(&block.to_bytes());
            }
        }

        // f32 activations, deterministic-ish.
        let mut activations = vec![0.0f32; m * k];
        for i in 0..m {
            for kk in 0..k {
                activations[i * k + kk] = ((kk as f32) * 0.13 - 2.0).sin() * (1.0 + i as f32 * 0.1);
            }
        }

        let mut ref_out = vec![0.0f32; m * n];
        matmul_i2s(
            &activations,
            m,
            k,
            &weight_bytes,
            n,
            weight_scale,
            &mut ref_out,
        );

        // Quantize activations per row.
        let mut acts_i8 = vec![0i8; m * k];
        let mut act_scales = vec![0.0f32; m];
        for i in 0..m {
            act_scales[i] = quantize_row_to_int8(
                &activations[i * k..(i + 1) * k],
                &mut acts_i8[i * k..(i + 1) * k],
            );
        }

        let mut i8_out = vec![0.0f32; m * n];
        matmul_i2s_i8(
            &acts_i8,
            &act_scales,
            m,
            k,
            &weight_bytes,
            n,
            weight_scale,
            &mut i8_out,
        );

        // Expected error: each quantized activation is off by at most
        // scale/2, and the dot product sums k = 256 terms, so the error
        // grows roughly with sqrt(k) × weight magnitude × act_scale/2 in
        // expectation (linear in k at absolute worst). The 5% relative /
        // 1e-3 absolute tolerance below is generous enough to catch logic
        // errors without flaking.
        for (i, (&r, &q)) in ref_out.iter().zip(i8_out.iter()).enumerate() {
            let abs_err = (r - q).abs();
            let rel_err = abs_err / r.abs().max(1e-6);
            assert!(
                rel_err < 0.05 || abs_err < 1e-3,
                "idx {i}: f32 ref = {r}, int8 quantized = {q}, rel_err = {rel_err}",
            );
        }
    }

    #[test]
    fn quantize_row_to_int8_roundtrip() {
        let input = [1.0f32, -2.0, 0.5, -0.5, 0.0, 2.0, -1.5];
        let mut output = [0i8; 7];
        let scale = quantize_row_to_int8(&input, &mut output);
        assert!(scale > 0.0);
        // Largest magnitude is 2.0; should map to ±127.
        let max_idx = input
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.abs().partial_cmp(&b.1.abs()).unwrap())
            .unwrap()
            .0;
        assert_eq!(output[max_idx].unsigned_abs(), 127);
        // Dequantized values should be close to the originals.
        for (i, &v) in input.iter().enumerate() {
            let recovered = output[i] as f32 * scale;
            assert!(
                (recovered - v).abs() < scale,
                "idx {i}: {v} → {recovered} (scale={scale})",
            );
        }
    }

    #[test]
    fn quantize_row_to_int8_zero_input() {
        let input = [0.0f32; 8];
        let mut output = [0i8; 8];
        let scale = quantize_row_to_int8(&input, &mut output);
        assert_eq!(scale, 0.0);
        assert!(output.iter().all(|&x| x == 0));
    }

    #[test]
    #[should_panic(expected = "k")]
    fn matmul_rejects_misaligned_k() {
        let m = 1;
        let k = 100;
        let n = 1;
        let acts = vec![0.0; m * k];
        let weight_bytes = vec![0u8; I2S_BYTES_PER_BLOCK];
        let mut out = vec![0.0; m * n];
        matmul_i2s(&acts, m, k, &weight_bytes, n, 1.0, &mut out);
    }
}
924}