/// AVX2-accelerated fused Q4_K dequant+dot kernel (PARITY-003: llama.cpp-style SIMD)
///
/// # Safety
///
/// Caller must ensure:
/// 1. AVX2 and FMA CPU features are available (use `is_x86_feature_detected!`)
/// 2. Input slices are valid (handled by Rust's slice guarantees)
///
/// This function is marked unsafe due to SIMD intrinsics, but is logically
/// equivalent to the scalar `fused_q4k_dot` (within ULP tolerance).
///
/// # Optimizations (PARITY-003)
/// - SIMD loads: 32-byte bulk loads via 256-bit unaligned AVX2 loads (vs per-byte scalar loads)
/// - SIMD nibble extraction: AVX2 bitwise AND with 0x0F mask (vs scalar & 0x0F)
/// - 4 independent accumulators to hide FMA latency
/// - Software prefetching for next super-block
/// - Matches llama.cpp ggml_vec_dot_q4_K_q8_K pattern
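///
/// # Example
///
/// A minimal dispatch sketch (not compiled; `q4k_data` and `activations` are assumed
/// caller-provided buffers, with the scalar `fused_q4k_dot` from this module as the fallback):
///
/// ```ignore
/// let result = if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
///     // SAFETY: the runtime feature check above satisfies precondition 1.
///     unsafe { fused_q4k_dot_avx2(&q4k_data, &activations)? }
/// } else {
///     fused_q4k_dot(&q4k_data, &activations)?
/// };
/// ```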
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[allow(unsafe_op_in_unsafe_fn)]
// SAFETY: Caller must satisfy the documented preconditions
unsafe fn fused_q4k_dot_avx2(q4k_data: &[u8], activations: &[f32]) -> Result<f32> {
    // Allow wildcard import for SIMD intrinsics (standard pattern for arch-specific code)
    #[allow(clippy::wildcard_imports)]
    use std::arch::x86_64::*;
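    // Q4_K super-block layout (ggml-compatible): 2 bytes d (f16) + 2 bytes dmin (f16)
    // + 12 bytes of packed 6-bit scales/mins + 128 bytes of 4-bit quants = 144 bytes.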
    const SUPER_BLOCK_BYTES: usize = 144;
    // Validate inputs (same as scalar)
    if !q4k_data.len().is_multiple_of(SUPER_BLOCK_BYTES) {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q4_K data length {} is not a multiple of super-block size {}",
                q4k_data.len(),
                SUPER_BLOCK_BYTES
            ),
        });
    }
    let num_super_blocks = q4k_data.len() / SUPER_BLOCK_BYTES;
    let expected_values = num_super_blocks * QK_K;
    if activations.len() != expected_values {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Activation length {} doesn't match Q4_K values count {}",
                activations.len(),
                expected_values
            ),
        });
    }
    // Nibble mask for extracting 4-bit values (llama.cpp pattern)
    let nibble_mask = _mm256_set1_epi8(0x0F_i8);
    // PARITY-003: 4 independent accumulators to hide FMA latency
    // FMA latency = 4 cycles, throughput = 2/cycle
    // With 4 independent chains, we saturate the FMA throughput
    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();
    let mut acc2 = _mm256_setzero_ps();
    let mut acc3 = _mm256_setzero_ps();
    let mut activation_idx = 0;
    for sb_idx in 0..num_super_blocks {
        let sb_start = sb_idx * SUPER_BLOCK_BYTES;
        // Prefetch next super-block (dual: Q4_K weights + activations)
        if sb_idx + 1 < num_super_blocks {
            let next_sb = (sb_idx + 1) * SUPER_BLOCK_BYTES;
            _mm_prefetch::<_MM_HINT_T0>(q4k_data.as_ptr().add(next_sb).cast::<i8>());
            _mm_prefetch::<_MM_HINT_T0>(
                activations.as_ptr().add((sb_idx + 1) * QK_K).cast::<i8>(),
            );
        }
        // Read d and dmin (f16 -> f32)
        let d = read_f16(&q4k_data[sb_start..sb_start + 2]);
        let dmin = read_f16(&q4k_data[sb_start + 2..sb_start + 4]);
        // Read scales (12 bytes)
        let mut scales = [0u8; 12];
        scales.copy_from_slice(&q4k_data[sb_start + 4..sb_start + 16]);
        // Pointer to quantized data (128 bytes = 256 nibbles = 256 values)
        let qs_ptr = q4k_data.as_ptr().add(sb_start + 16);
        // PAR-001: Match dequantize_q4_k layout (llama.cpp/candle compatible)
        // Process 4 chunks of 64 values each (j=0, 64, 128, 192)
        // Each chunk: 32 low nibbles (sc1), then 32 high nibbles (sc2)
        for j in (0..QK_K).step_by(64) {
            let q_start = j / 2; // 32 bytes per 64-value chunk
            // Get scales for the two 32-value halves
            let is = j / 32;
            let (sc1, m1) = extract_scale_min(&scales, is);
            let (sc2, m2) = extract_scale_min(&scales, is + 1);
            // Precompute d*scale and dmin*min for both halves
            let d_scale1 = d * sc1;
            let dm1 = dmin * m1;
            let d_scale2 = d * sc2;
            let dm2 = dmin * m2;
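            // Per-element dequantization is w = d*sc*q - dmin*m, so each fmsub below
            // produces 8 dequantized weights and each fmadd folds them into the dot product.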
            // SIMD OPTIMIZATION: Load 32 bytes at once (64 nibbles = 64 values)
            // llama.cpp pattern: _mm256_loadu_si256 + AND/SHIFT for nibble extraction
            let q_bytes = _mm256_loadu_si256(qs_ptr.add(q_start).cast::<__m256i>());
            // Extract low nibbles (first 32 values) and high nibbles (second 32 values)
            let q_lo = _mm256_and_si256(q_bytes, nibble_mask);
            let q_hi = _mm256_and_si256(_mm256_srli_epi16::<4>(q_bytes), nibble_mask);
            // Process low nibbles (32 values with scale sc1/m1)
            // Split into 4 groups of 8 for f32 conversion (AVX2 can convert 8 i32→f32)
            let d_scale1_vec = _mm256_set1_ps(d_scale1);
            let dm1_vec = _mm256_set1_ps(dm1);
            // Extract bytes 0-7 (low nibbles) → convert to f32
            let q_lo_128_0 = _mm256_castsi256_si128(q_lo);
            let q_lo_i32_0 = _mm256_cvtepu8_epi32(q_lo_128_0);
            let q_lo_f32_0 = _mm256_cvtepi32_ps(q_lo_i32_0);
            let dequant0 = _mm256_fmsub_ps(d_scale1_vec, q_lo_f32_0, dm1_vec);
            let act0 = _mm256_loadu_ps(activations.as_ptr().add(activation_idx));
            acc0 = _mm256_fmadd_ps(dequant0, act0, acc0);
            activation_idx += 8;
            // Extract bytes 8-15 (low nibbles)
            let q_lo_shifted = _mm_srli_si128::<8>(q_lo_128_0);
            let q_lo_i32_1 = _mm256_cvtepu8_epi32(q_lo_shifted);
            let q_lo_f32_1 = _mm256_cvtepi32_ps(q_lo_i32_1);
            let dequant1 = _mm256_fmsub_ps(d_scale1_vec, q_lo_f32_1, dm1_vec);
            let act1 = _mm256_loadu_ps(activations.as_ptr().add(activation_idx));
            acc1 = _mm256_fmadd_ps(dequant1, act1, acc1);
            activation_idx += 8;
            // Extract bytes 16-23 (low nibbles from high 128 bits)
            let q_lo_128_1 = _mm256_extracti128_si256::<1>(q_lo);
            let q_lo_i32_2 = _mm256_cvtepu8_epi32(q_lo_128_1);
            let q_lo_f32_2 = _mm256_cvtepi32_ps(q_lo_i32_2);
            let dequant2 = _mm256_fmsub_ps(d_scale1_vec, q_lo_f32_2, dm1_vec);
            let act2 = _mm256_loadu_ps(activations.as_ptr().add(activation_idx));
            acc2 = _mm256_fmadd_ps(dequant2, act2, acc2);
            activation_idx += 8;
            // Extract bytes 24-31 (low nibbles)
            let q_lo_shifted2 = _mm_srli_si128::<8>(q_lo_128_1);
            let q_lo_i32_3 = _mm256_cvtepu8_epi32(q_lo_shifted2);
            let q_lo_f32_3 = _mm256_cvtepi32_ps(q_lo_i32_3);
            let dequant3 = _mm256_fmsub_ps(d_scale1_vec, q_lo_f32_3, dm1_vec);
            let act3 = _mm256_loadu_ps(activations.as_ptr().add(activation_idx));
            acc3 = _mm256_fmadd_ps(dequant3, act3, acc3);
            activation_idx += 8;
            // Process high nibbles (32 values with scale sc2/m2)
            let d_scale2_vec = _mm256_set1_ps(d_scale2);
            let dm2_vec = _mm256_set1_ps(dm2);
            // Extract bytes 0-7 (high nibbles)
            let q_hi_128_0 = _mm256_castsi256_si128(q_hi);
            let q_hi_i32_0 = _mm256_cvtepu8_epi32(q_hi_128_0);
            let q_hi_f32_0 = _mm256_cvtepi32_ps(q_hi_i32_0);
            let dequant4 = _mm256_fmsub_ps(d_scale2_vec, q_hi_f32_0, dm2_vec);
            let act4 = _mm256_loadu_ps(activations.as_ptr().add(activation_idx));
            acc0 = _mm256_fmadd_ps(dequant4, act4, acc0);
            activation_idx += 8;
            // Extract bytes 8-15 (high nibbles)
            let q_hi_shifted = _mm_srli_si128::<8>(q_hi_128_0);
            let q_hi_i32_1 = _mm256_cvtepu8_epi32(q_hi_shifted);
            let q_hi_f32_1 = _mm256_cvtepi32_ps(q_hi_i32_1);
            let dequant5 = _mm256_fmsub_ps(d_scale2_vec, q_hi_f32_1, dm2_vec);
            let act5 = _mm256_loadu_ps(activations.as_ptr().add(activation_idx));
            acc1 = _mm256_fmadd_ps(dequant5, act5, acc1);
            activation_idx += 8;
            // Extract bytes 16-23 (high nibbles from high 128 bits)
            let q_hi_128_1 = _mm256_extracti128_si256::<1>(q_hi);
            let q_hi_i32_2 = _mm256_cvtepu8_epi32(q_hi_128_1);
            let q_hi_f32_2 = _mm256_cvtepi32_ps(q_hi_i32_2);
            let dequant6 = _mm256_fmsub_ps(d_scale2_vec, q_hi_f32_2, dm2_vec);
            let act6 = _mm256_loadu_ps(activations.as_ptr().add(activation_idx));
            acc2 = _mm256_fmadd_ps(dequant6, act6, acc2);
            activation_idx += 8;
            // Extract bytes 24-31 (high nibbles)
            let q_hi_shifted2 = _mm_srli_si128::<8>(q_hi_128_1);
            let q_hi_i32_3 = _mm256_cvtepu8_epi32(q_hi_shifted2);
            let q_hi_f32_3 = _mm256_cvtepi32_ps(q_hi_i32_3);
            let dequant7 = _mm256_fmsub_ps(d_scale2_vec, q_hi_f32_3, dm2_vec);
            let act7 = _mm256_loadu_ps(activations.as_ptr().add(activation_idx));
            acc3 = _mm256_fmadd_ps(dequant7, act7, acc3);
            activation_idx += 8;
        }
    }
    // Combine 4 accumulators → single accumulator
    let acc_01 = _mm256_add_ps(acc0, acc1);
    let acc_23 = _mm256_add_ps(acc2, acc3);
    let acc = _mm256_add_ps(acc_01, acc_23);
    // Horizontal sum: reduce 8 lanes to single value
    let sum_halves = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps::<1>(acc));
    let temp = _mm_add_ps(sum_halves, _mm_movehl_ps(sum_halves, sum_halves));
    let temp = _mm_add_ss(temp, _mm_shuffle_ps::<1>(temp, temp));
    let result = _mm_cvtss_f32(temp);
    Ok(result)
}
// ============================================================================
// Q4_K × Q8_K KERNELS (Super-block aligned integer-only arithmetic)
// ============================================================================
/// Fused Q4_K × Q8_K dot product (super-block aligned, llama.cpp-style)
///
/// Uses Q8_K format (256 values per super-block, single scale) for maximum
/// SIMD efficiency. This matches llama.cpp's `ggml_vec_dot_q4_K_q8_K`.
///
/// # Arguments
/// * `q4k_data` - Raw Q4_K quantized data (144 bytes per super-block)
/// * `q8k_scales` - Q8_K scales (one per super-block)
/// * `q8k_quants` - Q8_K quantized int8 values (256 per super-block)
///
/// # Performance
///
/// Compared to Q4_K × f32:
/// - ~4x less activation memory traffic (int8 quants vs f32)
/// - Integer-only inner loop (no f32 conversion until end)
/// - Single scale multiplication per super-block (vs 8 for Q8_0)
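///
/// # Example
///
/// A shape-only sketch, assuming a single zeroed super-block (256 values); it
/// passes validation and yields 0.0:
///
/// ```ignore
/// let q4k_data = vec![0u8; 144];     // one Q4_K super-block
/// let q8k_scales = vec![1.0f32; 1];  // one scale per super-block
/// let q8k_quants = vec![0i8; 256];   // 256 int8 values per super-block
/// let dot = fused_q4k_q8k_dot(&q4k_data, &q8k_scales, &q8k_quants)?;
/// assert_eq!(dot, 0.0);
/// ```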
pub fn fused_q4k_q8k_dot(q4k_data: &[u8], q8k_scales: &[f32], q8k_quants: &[i8]) -> Result<f32> {
    const SUPER_BLOCK_BYTES: usize = 144;
    if !q4k_data.len().is_multiple_of(SUPER_BLOCK_BYTES) {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q4_K data length {} is not a multiple of {}",
                q4k_data.len(),
                SUPER_BLOCK_BYTES
            ),
        });
    }
    let num_super_blocks = q4k_data.len() / SUPER_BLOCK_BYTES;
    let expected_values = num_super_blocks * QK_K;
    if q8k_scales.len() < num_super_blocks {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q8_K scales count {} < expected {}",
                q8k_scales.len(),
                num_super_blocks
            ),
        });
    }
    if q8k_quants.len() < expected_values {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q8_K quants count {} < expected {}",
                q8k_quants.len(),
                expected_values
            ),
        });
    }
    let mut total_acc = 0.0f32;
    for sb_idx in 0..num_super_blocks {
        let sb_start = sb_idx * SUPER_BLOCK_BYTES;
        let q8_start = sb_idx * QK_K;
        // Read Q4_K super-block header
        let d = read_f16(&q4k_data[sb_start..sb_start + 2]);
        let dmin = read_f16(&q4k_data[sb_start + 2..sb_start + 4]);
        // Read scales (12 bytes for 8 blocks)
        let mut scales = [0u8; 12];
        scales.copy_from_slice(&q4k_data[sb_start + 4..sb_start + 16]);
        // Q8_K scale for this super-block
        let q8_scale = q8k_scales[sb_idx];
        // Process 4 chunks of 64 values (matching dequantize_q4_k layout)
        // The dequantized output order is: 32 low nibbles, then 32 high nibbles
        // So activations[j..j+32] correspond to low nibbles, activations[j+32..j+64] to high
        for j in (0..QK_K).step_by(64) {
            let q_offset = sb_start + 16 + j / 2; // 32 bytes per 64-value chunk
            let q8_offset = q8_start + j;
            // Get scales for low and high nibbles
            let is = j / 32;
            let (sc1, m1) = extract_scale_min(&scales, is);
            let (sc2, m2) = extract_scale_min(&scales, is + 1);
            // Combined scale factors
            let d_sc1_q8 = d * sc1 * q8_scale;
            let dm1_q8 = dmin * m1 * q8_scale;
            let d_sc2_q8 = d * sc2 * q8_scale;
            let dm2_q8 = dmin * m2 * q8_scale;
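            // Each dequantized weight is d*sc*q4 - dmin*m, so the 32-value dot product
            // factors as d*sc*q8_scale*Σ(q4*q8) - dmin*m*q8_scale*Σ(q8); only the two
            // integer sums Σ(q4*q8) and Σ(q8) need to be accumulated below.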
            // Accumulators for low and high nibbles
            let mut sum_lo: i32 = 0; // q4_lo × q8 (for activations[j..j+32])
            let mut sum_hi: i32 = 0; // q4_hi × q8 (for activations[j+32..j+64])
            let mut q8_sum_lo: i32 = 0;
            let mut q8_sum_hi: i32 = 0;
            for b in 0..32 {
                let q4_byte = q4k_data[q_offset + b];
                // Low nibble × activation[j + b] (first 32 positions in dequant order)
                let q4_lo = (q4_byte & 0x0F) as i32;
                let q8_lo = q8k_quants[q8_offset + b] as i32;
                sum_lo += q4_lo * q8_lo;
                q8_sum_lo += q8_lo;
                // High nibble × activation[j + 32 + b] (second 32 positions in dequant order)
                let q4_hi = ((q4_byte >> 4) & 0x0F) as i32;
                let q8_hi = q8k_quants[q8_offset + 32 + b] as i32;
                sum_hi += q4_hi * q8_hi;
                q8_sum_hi += q8_hi;
            }
            // Apply formula: (d * scale * sum_q4_q8 - dmin * min * sum_q8) * q8_scale
            total_acc += d_sc1_q8 * (sum_lo as f32) - dm1_q8 * (q8_sum_lo as f32);
            total_acc += d_sc2_q8 * (sum_hi as f32) - dm2_q8 * (q8_sum_hi as f32);
        }
    }
    Ok(total_acc)
}
/// SIMD-accelerated Q4_K × Q8_K dot product
///
/// Uses AVX-512 VNNI (vpdpbusd) for maximum throughput, falls back to AVX2 or scalar.
/// Single scale per super-block eliminates per-block overhead.
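///
/// # Example
///
/// Takes the same arguments as [`fused_q4k_q8k_dot`] and picks the best kernel at
/// runtime, so callers need no feature detection of their own (sketch with the same
/// assumed buffers as above):
///
/// ```ignore
/// let dot = fused_q4k_q8k_dot_simd(&q4k_data, &q8k_scales, &q8k_quants)?;
/// ```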
pub fn fused_q4k_q8k_dot_simd(
    q4k_data: &[u8],
    q8k_scales: &[f32],
    q8k_quants: &[i8],
) -> Result<f32> {
    #[cfg(target_arch = "x86_64")]
    {
        // PAR-126: Use V2 optimized AVX-512 VNNI kernel (deferred horizontal sums)
        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vnni") {
            // SAFETY: required CPU features are verified by the runtime checks above;
            // slice lengths are validated inside the kernel.
            return unsafe { fused_q4k_q8k_dot_avx512vnni_v2(q4k_data, q8k_scales, q8k_quants) };
        }
        // pmat-ignore: hardware-path (AVX2 fallback never reached when AVX-512 VNNI available)
        // Fallback to AVX2 (layout issue resolved)
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support is verified by the runtime check above;
            // slice lengths are validated inside the kernel.
            return unsafe { fused_q4k_q8k_dot_avx2(q4k_data, q8k_scales, q8k_quants) };
        }
    }
    // pmat-ignore: hardware-path (scalar fallback tested directly via fused_q4k_q8k_dot)
    fused_q4k_q8k_dot(q4k_data, q8k_scales, q8k_quants)
}