1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
//! AVX2-accelerated dequantization fast paths for Q4_0 / Q8_0 (GH-386).
//!
//! The scalar implementations in [`crate::format::quantize`] (`Q4_0Quantizer`,
//! `Q8_0Quantizer`) iterate one element at a time; LLVM's auto-vectorizer
//! handles the trivial multiply but bottlenecks on the i8→i32→f32 sign-extend
//! cascade and the nibble unpack, hitting ~1.2 Gelem/s (≈5× below memcpy
//! ceiling).
//!
//! This module provides AVX2 implementations that:
//!
//! - **Q8_0**: load 32 i8 elements per block, sign-extend them into four
//! 256-bit vectors of 8 × i32 via `_mm256_cvtepi8_epi32`, convert to f32,
//! multiply by a broadcast f16 scale, and store the 32-element output as
//! four 256-bit f32 vectors.
//!
//! - **Q4_0**: load 16 packed nibble bytes, extract low and high nibbles
//! (mask + shift), interleave so that `byte_i` produces output positions
//! `2i` and `2i+1` (matching the existing `Q4_0Quantizer::quantize` pack
//! layout — NOT the GGML half-half layout used in `format::gguf::dequant`),
//! subtract 8, convert to f32, multiply by the scale, and store.
//!
//! Runtime dispatch is via [`is_x86_feature_detected`]. x86 CPUs without
//! AVX2 fall back to the scalar path unchanged, as do all non-x86
//! architectures. The fast paths produce **bit-exact** output relative to
//! the scalar reference (verified by the `tests::scalar_simd_parity_*`
//! tests).
//!
//! # Safety
//!
//! Each `_avx2` function is `unsafe fn` and is only reachable from
//! [`dequantize_q8_0_avx2_dispatch`] / [`dequantize_q4_0_avx2_dispatch`],
//! which check `is_x86_feature_detected!("avx2")` immediately before the
//! call. The functions are marked `#[target_feature(enable = "avx2")]`
//! so the codegen for the AVX2 intrinsics is correct; the caller is
//! responsible for the runtime feature gate.
//!
//! Bounds invariants:
//!
//! - Caller passes `blocks` with `blocks.len() == num_blocks * BLOCK_BYTES`
//! (Q8_0: 34, Q4_0: 18) and `out.len() == num_blocks * BLOCK_SIZE` (32).
//! - All loads/stores go through the unaligned `_mm_loadu_si128` /
//! `_mm256_storeu_ps` intrinsics, so neither input nor output alignment is
//! required.
#![allow(unsafe_code)] // GH-386: documented AVX2 fast path; runtime-gated.
use crate::format::quantize::BLOCK_SIZE;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
use half::f16;
/// Dispatch entry point for Q8_0 dequant.
///
/// Calls the AVX2 fast path when available, otherwise returns `false` so the
/// caller can use its scalar reference path. Writes exactly `num_blocks *
/// BLOCK_SIZE` elements to `out` when it returns `true`; `out.len()` must be
/// `≥ num_blocks * BLOCK_SIZE`.
#[inline]
pub(crate) fn dequantize_q8_0_avx2_dispatch(
blocks: &[u8],
num_blocks: usize,
out: &mut [f32],
) -> bool {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") {
// SAFETY: feature gate above guarantees AVX2 is available. Callers
// (the `Q8_0Quantizer::dequantize` path) pre-validate that
// `blocks.len() == num_blocks * Q8_0_BLOCK_BYTES (34)` and
// `out.len() == num_blocks * BLOCK_SIZE (32)` — see
// `tests::scalar_simd_parity_q8_0`.
unsafe { dequantize_q8_0_avx2(blocks, num_blocks, out) };
return true;
}
}
let _ = (blocks, num_blocks, out);
false
}
/// Dispatch entry point for Q4_0 dequant. See [`dequantize_q8_0_avx2_dispatch`].
#[inline]
pub(crate) fn dequantize_q4_0_avx2_dispatch(
blocks: &[u8],
num_blocks: usize,
out: &mut [f32],
) -> bool {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") {
// SAFETY: as for Q8_0 dispatch; preconditions on
// `blocks.len() == num_blocks * Q4_0_BLOCK_BYTES (18)` and
// `out.len() == num_blocks * BLOCK_SIZE (32)` enforced by callers.
unsafe { dequantize_q4_0_avx2(blocks, num_blocks, out) };
return true;
}
}
let _ = (blocks, num_blocks, out);
false
}
// ---------------------------------------------------------------------------
// AVX2 implementations
// ---------------------------------------------------------------------------
/// AVX2 Q8_0 dequantization kernel.
///
/// Decodes the first `num_blocks` Q8_0 blocks of `blocks` into the first
/// `num_blocks * BLOCK_SIZE` elements of `out`.
///
/// - Block layout: 34 bytes, a little-endian f16 scale followed by 32 `i8`
///   quants.
/// - Each output is `q as f32 * scale`. This is the same operation sequence
///   as the scalar reference, so results are bit-identical.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`). Buffer sizes are validated here, so
/// every load and store stays in bounds regardless of the arguments.
///
/// # Panics
///
/// Panics if:
/// - `blocks.len() < num_blocks * 34`;
/// - `out.len() < num_blocks * BLOCK_SIZE`;
/// - either product overflows `usize`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn dequantize_q8_0_avx2(blocks: &[u8], num_blocks: usize, out: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::{
        __m128i, _mm256_cvtepi32_ps, _mm256_cvtepi8_epi32, _mm256_mul_ps, _mm256_set1_ps,
        _mm256_storeu_ps, _mm_loadu_si128, _mm_srli_si128,
    };
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::{
        __m128i, _mm256_cvtepi32_ps, _mm256_cvtepi8_epi32, _mm256_mul_ps, _mm256_set1_ps,
        _mm256_storeu_ps, _mm_loadu_si128, _mm_srli_si128,
    };
    const BLOCK_BYTES: usize = 34;
    // The four 8-lane stores per block cover exactly 32 outputs.
    const _: () = assert!(BLOCK_SIZE == 32, "AVX2 Q8_0 kernel assumes 32-element blocks");

    // Re-slice once up front so every chunk below is exactly sized. The
    // pointer-based loads/stores are then in bounds by construction instead
    // of relying on unchecked caller invariants.
    let in_len = num_blocks
        .checked_mul(BLOCK_BYTES)
        .expect("Q8_0 AVX2: num_blocks * 34 overflows usize");
    let out_len = num_blocks
        .checked_mul(BLOCK_SIZE)
        .expect("Q8_0 AVX2: num_blocks * BLOCK_SIZE overflows usize");
    let blocks = &blocks[..in_len];
    let out = &mut out[..out_len];

    for (block, dst) in blocks
        .chunks_exact(BLOCK_BYTES)
        .zip(out.chunks_exact_mut(BLOCK_SIZE))
    {
        let scale = f16::from_le_bytes([block[0], block[1]]).to_f32();
        let src_ptr = block.as_ptr();
        let dst_ptr = dst.as_mut_ptr();
        // SAFETY: AVX2 availability is this function's documented
        // precondition.
        // - `block` is exactly 34 bytes, so the two 16-byte loads at offsets
        //   2 and 18 read bytes 2..34, in bounds.
        // - `dst` is exactly 32 f32s, so the four 8-lane stores at offsets
        //   0, 8, 16 and 24 write within it.
        unsafe {
            // Broadcast the f32 scale across all 8 lanes.
            let scale_v = _mm256_set1_ps(scale);
            // Quants 0..16 and 16..32. `_mm256_cvtepi8_epi32` sign-extends
            // the low 8 bytes of its argument, so do it 4× (once per
            // 8-element slice).
            let lo16 = _mm_loadu_si128(src_ptr.add(2).cast::<__m128i>());
            let hi16 = _mm_loadu_si128(src_ptr.add(18).cast::<__m128i>());
            let q_i32_0 = _mm256_cvtepi8_epi32(lo16);
            let q_i32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128::<8>(lo16));
            let q_i32_2 = _mm256_cvtepi8_epi32(hi16);
            let q_i32_3 = _mm256_cvtepi8_epi32(_mm_srli_si128::<8>(hi16));
            let f0 = _mm256_mul_ps(_mm256_cvtepi32_ps(q_i32_0), scale_v);
            let f1 = _mm256_mul_ps(_mm256_cvtepi32_ps(q_i32_1), scale_v);
            let f2 = _mm256_mul_ps(_mm256_cvtepi32_ps(q_i32_2), scale_v);
            let f3 = _mm256_mul_ps(_mm256_cvtepi32_ps(q_i32_3), scale_v);
            _mm256_storeu_ps(dst_ptr, f0);
            _mm256_storeu_ps(dst_ptr.add(8), f1);
            _mm256_storeu_ps(dst_ptr.add(16), f2);
            _mm256_storeu_ps(dst_ptr.add(24), f3);
        }
    }
}
/// AVX2 Q4_0 dequantization kernel.
///
/// Decodes the first `num_blocks` Q4_0 blocks of `blocks` into the first
/// `num_blocks * BLOCK_SIZE` elements of `out`.
///
/// - Block layout: 18 bytes, a little-endian f16 scale followed by 16 packed
///   bytes.
/// - Packed byte `i` holds output `2i` in its low nibble and output `2i + 1`
///   in its high nibble.
/// - Each nibble `q` decodes to `(q - 8) as f32 * scale`, matching the scalar
///   reference bit-for-bit.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`). Buffer sizes are validated here, so
/// every load and store stays in bounds regardless of the arguments.
///
/// # Panics
///
/// Panics if:
/// - `blocks.len() < num_blocks * 18`;
/// - `out.len() < num_blocks * BLOCK_SIZE`;
/// - either product overflows `usize`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn dequantize_q4_0_avx2(blocks: &[u8], num_blocks: usize, out: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::{
        __m128i, _mm256_cvtepi32_ps, _mm256_cvtepu8_epi32, _mm256_mul_ps, _mm256_set1_epi32,
        _mm256_set1_ps, _mm256_storeu_ps, _mm256_sub_epi32, _mm_and_si128, _mm_loadu_si128,
        _mm_set1_epi8, _mm_srli_epi16, _mm_srli_si128, _mm_unpackhi_epi8, _mm_unpacklo_epi8,
    };
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::{
        __m128i, _mm256_cvtepi32_ps, _mm256_cvtepu8_epi32, _mm256_mul_ps, _mm256_set1_epi32,
        _mm256_set1_ps, _mm256_storeu_ps, _mm256_sub_epi32, _mm_and_si128, _mm_loadu_si128,
        _mm_set1_epi8, _mm_srli_epi16, _mm_srli_si128, _mm_unpackhi_epi8, _mm_unpacklo_epi8,
    };
    const BLOCK_BYTES: usize = 18;
    // The four 8-lane stores per block cover exactly 32 outputs.
    const _: () = assert!(BLOCK_SIZE == 32, "AVX2 Q4_0 kernel assumes 32-element blocks");

    // Re-slice once up front so every chunk below is exactly sized. The
    // pointer-based loads/stores are then in bounds by construction instead
    // of relying on unchecked caller invariants.
    let in_len = num_blocks
        .checked_mul(BLOCK_BYTES)
        .expect("Q4_0 AVX2: num_blocks * 18 overflows usize");
    let out_len = num_blocks
        .checked_mul(BLOCK_SIZE)
        .expect("Q4_0 AVX2: num_blocks * BLOCK_SIZE overflows usize");
    let blocks = &blocks[..in_len];
    let out = &mut out[..out_len];

    for (block, dst) in blocks
        .chunks_exact(BLOCK_BYTES)
        .zip(out.chunks_exact_mut(BLOCK_SIZE))
    {
        let scale = f16::from_le_bytes([block[0], block[1]]).to_f32();
        let src_ptr = block.as_ptr();
        let dst_ptr = dst.as_mut_ptr();
        // SAFETY: AVX2 availability is this function's documented
        // precondition.
        // - `block` is exactly 18 bytes, so the 16-byte load at offset 2 reads
        //   bytes 2..18, in bounds.
        // - `dst` is exactly 32 f32s, so the four 8-lane stores at offsets
        //   0, 8, 16 and 24 write within it.
        unsafe {
            // Loop-invariant constants; LLVM hoists these out of the loop.
            let mask_lo_nib = _mm_set1_epi8(0x0F);
            let bias_i32 = _mm256_set1_epi32(8);
            let scale_v = _mm256_set1_ps(scale);
            // Load the 16 packed bytes (32 nibbles).
            let packed = _mm_loadu_si128(src_ptr.add(2).cast::<__m128i>());
            // Low nibble of byte i → output position 2i.
            let nib_lo = _mm_and_si128(packed, mask_lo_nib);
            // High nibble of byte i → output position 2i+1. x86 has no 8-bit
            // shift, so shift 16-bit lanes and mask off the bits pulled in
            // from the neighbouring byte.
            let nib_hi = _mm_and_si128(_mm_srli_epi16::<4>(packed), mask_lo_nib);
            // Interleave to (lo_0, hi_0, lo_1, hi_1, ...). This matches the
            // `Q4_0Quantizer::quantize` pack layout
            // (byte_i = (q_2i+1) << 4 | q_2i).
            let inter_lo = _mm_unpacklo_epi8(nib_lo, nib_hi); // → out 0..16
            let inter_hi = _mm_unpackhi_epi8(nib_lo, nib_hi); // → out 16..32
            // Zero-extend each group of 8 nibbles (0..=15) to i32, then
            // subtract the centering bias 8 to get values in -8..=7.
            let q_i32_0 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(inter_lo), bias_i32);
            let q_i32_1 = _mm256_sub_epi32(
                _mm256_cvtepu8_epi32(_mm_srli_si128::<8>(inter_lo)),
                bias_i32,
            );
            let q_i32_2 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(inter_hi), bias_i32);
            let q_i32_3 = _mm256_sub_epi32(
                _mm256_cvtepu8_epi32(_mm_srli_si128::<8>(inter_hi)),
                bias_i32,
            );
            let f0 = _mm256_mul_ps(_mm256_cvtepi32_ps(q_i32_0), scale_v);
            let f1 = _mm256_mul_ps(_mm256_cvtepi32_ps(q_i32_1), scale_v);
            let f2 = _mm256_mul_ps(_mm256_cvtepi32_ps(q_i32_2), scale_v);
            let f3 = _mm256_mul_ps(_mm256_cvtepi32_ps(q_i32_3), scale_v);
            _mm256_storeu_ps(dst_ptr, f0);
            _mm256_storeu_ps(dst_ptr.add(8), f1);
            _mm256_storeu_ps(dst_ptr.add(16), f2);
            _mm256_storeu_ps(dst_ptr.add(24), f3);
        }
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    // The scalar references, payload generator and comparison helper below
    // only back the x86-gated parity tests. Gating them keeps non-x86 test
    // builds free of unused-import / dead-code warnings, which would fail a
    // `-D warnings` CI run.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    use crate::format::quantize::{
        quantize, QuantType, BLOCK_SIZE as BS, Q4_0_BLOCK_BYTES, Q8_0_BLOCK_BYTES,
    };

    /// Reference scalar Q8_0 dequant (mirrors `Q8_0Quantizer::dequantize` body).
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn ref_dequantize_q8_0(blocks: &[u8], num_blocks: usize, total_elems: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; num_blocks * BS];
        for block_idx in 0..num_blocks {
            let b_start = block_idx * Q8_0_BLOCK_BYTES;
            let scale = half::f16::from_le_bytes([blocks[b_start], blocks[b_start + 1]]).to_f32();
            let qs = &blocks[b_start + 2..b_start + 2 + BS];
            let out_off = block_idx * BS;
            for (j, &q) in qs.iter().enumerate() {
                out[out_off + j] = (q as i8) as f32 * scale;
            }
        }
        out.truncate(total_elems);
        out
    }

    /// Reference scalar Q4_0 dequant (mirrors `Q4_0Quantizer::dequantize` body).
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn ref_dequantize_q4_0(blocks: &[u8], num_blocks: usize, total_elems: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; num_blocks * BS];
        for block_idx in 0..num_blocks {
            let b_start = block_idx * Q4_0_BLOCK_BYTES;
            let scale = half::f16::from_le_bytes([blocks[b_start], blocks[b_start + 1]]).to_f32();
            let packed = &blocks[b_start + 2..b_start + 2 + 16];
            let out_off = block_idx * BS;
            for (i, &p) in packed.iter().enumerate() {
                let q0 = (p & 0x0F) as i8 - 8;
                let q1 = ((p >> 4) & 0x0F) as i8 - 8;
                out[out_off + i * 2] = (q0 as f32) * scale;
                out[out_off + i * 2 + 1] = (q1 as f32) * scale;
            }
        }
        out.truncate(total_elems);
        out
    }

    /// Deterministic pseudo-random payload in [-1, 1] for a given seed.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn make_payload(n: usize, seed: u32) -> Vec<f32> {
        (0..n)
            .map(|i| {
                let x = (i as u32).wrapping_mul(2_654_435_761).wrapping_add(seed) as f32;
                (x * 1.0e-9).sin()
            })
            .collect()
    }

    /// Assert `simd` matches `scalar` bit-for-bit, reporting the first
    /// mismatching index.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn assert_bit_exact(kind: &str, n: usize, scalar: &[f32], simd: &[f32]) {
        assert_eq!(scalar.len(), simd.len());
        for (i, (r, s)) in scalar.iter().zip(simd).enumerate() {
            assert!(
                r.to_bits() == s.to_bits(),
                "{kind} mismatch at i={i} n={n}: scalar={r} simd={s}"
            );
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn scalar_simd_parity_q8_0() {
        if !is_x86_feature_detected!("avx2") {
            eprintln!("skipping AVX2 parity test — CPU lacks avx2");
            return;
        }
        for n in [32, 64, 256, 1024, 32 * 71] {
            let data = make_payload(n, 7);
            let q = quantize(&data, &[n], QuantType::Q8_0).expect("quantize");
            let num_blocks = q.num_blocks();
            let ref_out = ref_dequantize_q8_0(&q.blocks, num_blocks, n);
            let mut simd_out = vec![0.0f32; num_blocks * BS];
            let dispatched = dequantize_q8_0_avx2_dispatch(&q.blocks, num_blocks, &mut simd_out);
            assert!(dispatched, "AVX2 dispatch must run on avx2 host");
            simd_out.truncate(n);
            assert_bit_exact("Q8_0", n, &ref_out, &simd_out);
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn scalar_simd_parity_q4_0() {
        if !is_x86_feature_detected!("avx2") {
            eprintln!("skipping AVX2 parity test — CPU lacks avx2");
            return;
        }
        for n in [32, 64, 256, 1024, 32 * 71] {
            let data = make_payload(n, 13);
            let q = quantize(&data, &[n], QuantType::Q4_0).expect("quantize");
            let num_blocks = q.num_blocks();
            let ref_out = ref_dequantize_q4_0(&q.blocks, num_blocks, n);
            let mut simd_out = vec![0.0f32; num_blocks * BS];
            let dispatched = dequantize_q4_0_avx2_dispatch(&q.blocks, num_blocks, &mut simd_out);
            assert!(dispatched, "AVX2 dispatch must run on avx2 host");
            simd_out.truncate(n);
            assert_bit_exact("Q4_0", n, &ref_out, &simd_out);
        }
    }

    /// The dispatchers report `true` exactly when the AVX2 fast path ran:
    /// always `false` on non-x86 targets, and on x86 `true` iff the CPU
    /// supports AVX2. This covers the x86-without-AVX2 case too, which the
    /// parity tests skip.
    #[test]
    fn dispatch_returns_false_without_avx2() {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        let expected = is_x86_feature_detected!("avx2");
        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
        let expected = false;
        let mut out = vec![0.0f32; 32];
        assert_eq!(dequantize_q8_0_avx2_dispatch(&[0u8; 34], 1, &mut out), expected);
        assert_eq!(dequantize_q4_0_avx2_dispatch(&[0u8; 18], 1, &mut out), expected);
    }
}