trueno/backends/q4k/gemv/avx512.rs

//! AVX-512 SIMD Q4_K GEMV implementation.
//!
//! Contract: avx512-q4k-v1.yaml (C-AVX512-Q4K-001)
//! Processes 16 elements per iteration using zmm registers (2× throughput vs AVX2).
//! References: [46] GPTQ, [47] QuIP# AVX-512 dequant methodology.
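//!
//! Dequantization sketch (the arithmetic the kernels below fuse, not extra
//! API): each 4-bit quant `q` contributes `w = (d * scale) * q - (dmin * min)`,
//! accumulated against the input via FMA as `acc += w * x`.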

use super::super::{parse_q4k_header, SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};

/// Fused Q4_K matrix-vector multiply with AVX-512 SIMD (16-wide).
///
/// Contract: avx512-q4k-v1.yaml (C-AVX512-Q4K-001, C-AVX512-Q4K-002)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "fma")]
pub(crate) unsafe fn matmul_q4k_f32_avx512(
    q4k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    unsafe {
        use std::arch::x86_64::*;

        let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
        let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
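        // Layout note (descriptive, matching the indexing below): row r's
        // superblocks occupy q4k_data[r * row_bytes .. (r + 1) * row_bytes],
        // one SUPER_BLOCK_BYTES-sized block per 256 input elements, with
        // num_blocks_per_row rounded up via the ceiling division above.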
        let low_mask = _mm512_set1_epi32(0x0F);

        // Uninit: output[out_idx] = hsum_avx512(acc) (SET) for every out_idx.
        let mut output: Vec<f32> = Vec::with_capacity(out_dim);
        // SAFETY: Each output[out_idx] is SET from local AVX-512 accumulator.
        output.set_len(out_dim);

        for out_idx in 0..out_dim {
            let row_start = out_idx * row_bytes;
            let mut acc = _mm512_setzero_ps();

            for sb_idx in 0..num_blocks_per_row {
                let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
                let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
                let input_offset = sb_idx * SUPER_BLOCK_SIZE;
                process_q4k_superblock_avx512(
                    sb_data,
                    input,
                    input_offset,
                    in_dim,
                    low_mask,
                    &mut acc,
                );
            }

            output[out_idx] = hsum_avx512(acc);
        }

        contract_post_dequant!(output);
        output
    }
}
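
// Caller-side dispatch sketch (illustrative; `dispatch_gemv_q4k` is a
// hypothetical wrapper, not part of this module). The `unsafe` call is only
// sound after runtime verification of the required target features:
//
//     fn dispatch_gemv_q4k(q4k: &[u8], x: &[f32], out_dim: usize, in_dim: usize) -> Vec<f32> {
//         if is_x86_feature_detected!("avx512f")
//             && is_x86_feature_detected!("avx512bw")
//             && is_x86_feature_detected!("fma")
//         {
//             // SAFETY: all features required by `#[target_feature]` are present.
//             return unsafe { matmul_q4k_f32_avx512(q4k, x, out_dim, in_dim) };
//         }
//         unimplemented!("scalar/AVX2 fallback elsewhere in the crate")
//     }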

/// Process one Q4_K super-block with AVX-512 (16-wide), fully unrolled.
///
/// Each super-block = 256 elements in 4 chunks of 64.
/// Each chunk: 32 low nibbles + 32 high nibbles.
/// AVX-512: 16 elements per iteration → 2 iterations per 32 nibbles.
///
/// Optimization (Phase 4, 2026-04-05):
/// - Fully unrolled inner loops (was a while loop with 2 iterations)
/// - Bounds check hoisted out of the hot loop (in_dim validated by caller)
/// - Software prefetch of the next superblock's quantized data
///
/// NOTE: Dual-accumulator (low→acc0, high→acc1) was tested (2026-04-05)
/// but showed NO improvement. Zen 4's OOO engine already hides the FMA
/// dependency chain across iterations — adding a second accumulator just
/// adds merge overhead without helping the pipeline.
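///
/// Scalar reference for one 64-element chunk (an illustrative sketch of the
/// same arithmetic as the unrolled SIMD below, with
/// `base = input_offset + chunk_start`):
///
/// ```text
/// for j in 0..32 {
///     let b = qs[q_start + j];
///     acc += (d1 * f32::from(b & 0x0F) - dm1) * x[base + j];      // low nibble
///     acc += (d2 * f32::from(b >> 4)   - dm2) * x[base + 32 + j]; // high nibble
/// }
/// ```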
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "fma")]
unsafe fn process_q4k_superblock_avx512(
    sb_data: &[u8],
    input: &[f32],
    input_offset: usize,
    in_dim: usize,
    low_mask: std::arch::x86_64::__m512i,
    acc: &mut std::arch::x86_64::__m512,
) {
    unsafe {
        use std::arch::x86_64::*;

        let (d, dmin, scales, mins) = parse_q4k_header(sb_data);
        let qs = sb_data.get(16..144).expect("Q4_K: need ≥144 bytes for qs");
        let qs_ptr = qs.as_ptr();
        let input_ptr = input.as_ptr();

        // Software prefetch: next superblock's header + first quant bytes,
        // two cache lines ahead (128 bytes = most of the next superblock).
        // `wrapping_add` keeps the address computation defined even when the
        // final superblock has no successor; prefetching a stale address is
        // architecturally harmless.
        _mm_prefetch(sb_data.as_ptr().wrapping_add(SUPER_BLOCK_BYTES) as *const i8, _MM_HINT_T0);
        _mm_prefetch(sb_data.as_ptr().wrapping_add(SUPER_BLOCK_BYTES + 64) as *const i8, _MM_HINT_T0);

        for chunk_i in 0..4 {
            let chunk_start = chunk_i * 64;
            let q_start = chunk_i * 32;

            let d1 = d * f32::from(scales[chunk_i * 2]);
            let dm1 = dmin * f32::from(mins[chunk_i * 2]);
            let d2 = d * f32::from(scales[chunk_i * 2 + 1]);
            let dm2 = dmin * f32::from(mins[chunk_i * 2 + 1]);

            let d1_vec = _mm512_set1_ps(d1);
            let dm1_vec = _mm512_set1_ps(dm1);
            let d2_vec = _mm512_set1_ps(d2);
            let dm2_vec = _mm512_set1_ps(dm2);

            // Defensive guard over the full 64-element chunk: the body reads
            // 32 low-nibble inputs at +0..+32 and 32 high-nibble inputs at
            // +32..+64, so all 64 must be in bounds.
            let input_base_lo0 = input_offset + chunk_start;
            if input_base_lo0 + 64 <= in_dim {
                // Low nibbles: 2×16 = 32 elements, fully unrolled
                // First 16 low nibbles
                let q0 = _mm_loadu_si128(qs_ptr.add(q_start) as *const __m128i);
                let q0_i32 = _mm512_cvtepu8_epi32(q0);
                let q0_low = _mm512_and_si512(q0_i32, low_mask);
                let q0_f32 = _mm512_cvtepi32_ps(q0_low);
                let x0 = _mm512_loadu_ps(input_ptr.add(input_base_lo0));
                let dq0 = _mm512_fmsub_ps(d1_vec, q0_f32, dm1_vec);
                *acc = _mm512_fmadd_ps(dq0, x0, *acc);

                // Second 16 low nibbles
                let q1 = _mm_loadu_si128(qs_ptr.add(q_start + 16) as *const __m128i);
                let q1_i32 = _mm512_cvtepu8_epi32(q1);
                let q1_low = _mm512_and_si512(q1_i32, low_mask);
                let q1_f32 = _mm512_cvtepi32_ps(q1_low);
                let x1 = _mm512_loadu_ps(input_ptr.add(input_base_lo0 + 16));
                let dq1 = _mm512_fmsub_ps(d1_vec, q1_f32, dm1_vec);
                *acc = _mm512_fmadd_ps(dq1, x1, *acc);

                // High nibbles: 2×16 = 32 elements, fully unrolled
                let input_base_hi0 = input_offset + chunk_start + 32;

                // First 16 high nibbles (reuse q0 loaded above; zero-extension
                // in cvtepu8 means the shift alone isolates the high nibble)
                let q0_high = _mm512_srli_epi32(q0_i32, 4);
                let q0h_f32 = _mm512_cvtepi32_ps(q0_high);
                let xh0 = _mm512_loadu_ps(input_ptr.add(input_base_hi0));
                let dqh0 = _mm512_fmsub_ps(d2_vec, q0h_f32, dm2_vec);
                *acc = _mm512_fmadd_ps(dqh0, xh0, *acc);

                // Second 16 high nibbles (reuse q1 loaded above)
                let q1_high = _mm512_srli_epi32(q1_i32, 4);
                let q1h_f32 = _mm512_cvtepi32_ps(q1_high);
                let xh1 = _mm512_loadu_ps(input_ptr.add(input_base_hi0 + 16));
                let dqh1 = _mm512_fmsub_ps(d2_vec, q1h_f32, dm2_vec);
                *acc = _mm512_fmadd_ps(dqh1, xh1, *acc);
            }
        }
    }
}

/// AVX-512 horizontal sum of 16 f32 lanes.
/// Uses avx512f-only intrinsics (no avx512dq dependency).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn hsum_avx512(v: std::arch::x86_64::__m512) -> f32 {
    use std::arch::x86_64::*;
    // Reduce 512→256 using shuffle instead of extractf32x8 (which needs avx512dq)
    let lo256 = _mm512_castps512_ps256(v);
    // Shift upper 256 bits down: use _mm512_shuffle_f32x4 to bring lanes 8-15 into 0-7
    let hi_shifted = _mm512_shuffle_f32x4(v, v, 0b_01_00_11_10); // swap upper and lower 256
    let hi256 = _mm512_castps512_ps256(hi_shifted);
    let sum256 = _mm256_add_ps(lo256, hi256);
    // Now reduce 256→scalar
    let hi128 = _mm256_extractf128_ps(sum256, 1);
    let lo128 = _mm256_castps256_ps128(sum256);
    let sum128 = _mm_add_ps(lo128, hi128);
    let hi64 = _mm_movehl_ps(sum128, sum128);
    let sum64 = _mm_add_ps(sum128, hi64);
    let hi32 = _mm_shuffle_ps(sum64, sum64, 1);
    let sum32 = _mm_add_ss(sum64, hi32);
    _mm_cvtss_f32(sum32)
}
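
// Equivalence sketch (illustrative, not a shipped test): the reduction above
// should agree with a plain scalar sum over the 16 lanes, up to
// floating-point association differences, e.g.
//
//     let lanes: [f32; 16] = unsafe { std::mem::transmute(v) }; // same size/layout
//     let expected: f32 = lanes.iter().sum();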

/// AVX-512 chunk processor for parallel dispatch.
/// Contract: avx512-q4k-v1.yaml
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "fma")]
pub(crate) unsafe fn compute_chunk_q4k_avx512(
    q4k_data: &[u8],
    input: &[f32],
    chunk: &mut [f32],
    start_row: usize,
    out_dim: usize,
    in_dim: usize,
    num_blocks_per_row: usize,
    row_bytes: usize,
) {
    unsafe {
        use std::arch::x86_64::*;

        let low_mask = _mm512_set1_epi32(0x0F);

        for (local_idx, out_val) in chunk.iter_mut().enumerate() {
            let out_idx = start_row + local_idx;
            if out_idx >= out_dim {
                break;
            }
            let row_start = out_idx * row_bytes;
            let mut acc = _mm512_setzero_ps();

            for sb_idx in 0..num_blocks_per_row {
                let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
                let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
                let input_offset = sb_idx * SUPER_BLOCK_SIZE;
                process_q4k_superblock_avx512(
                    sb_data,
                    input,
                    input_offset,
                    in_dim,
                    low_mask,
                    &mut acc,
                );
            }

            *out_val = hsum_avx512(acc);
        }
    }
    contract_post_dequant!(chunk);
}
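
// Usage sketch for the chunk processor (illustrative; the chunking shown
// here is an assumption, not this crate's actual dispatcher). Each slice of
// output rows can be handed to a worker independently:
//
//     let mut output = vec![0.0f32; out_dim];
//     for (i, rows) in output.chunks_mut(rows_per_chunk).enumerate() {
//         // SAFETY: caller verified avx512f/avx512bw/fma at runtime.
//         unsafe {
//             compute_chunk_q4k_avx512(
//                 q4k_data, input, rows, i * rows_per_chunk,
//                 out_dim, in_dim, num_blocks_per_row, row_bytes,
//             );
//         }
//     }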