Skip to main content

trueno/backends/q6k/
gemv.rs

1#![allow(missing_docs)]
2//! Row-major Q6_K matrix-vector multiplication.
3//!
4//! This module implements row-major GEMV for Q6_K format.
5//! Includes scalar, AVX2-optimized, and parallel dispatch implementations.
6
7use super::{f16_to_f32, SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};
8
/// Extract a single Q6_K quantized value from packed `ql`/`qh` arrays.
///
/// Element `idx` of a super-block is rebuilt from two sources:
/// - its low 4 bits come from `ql[idx / 2]`: the low nibble for even `idx`,
///   the high nibble for odd `idx`;
/// - its high 2 bits come from `qh[idx / 4]`, at bit offset `2 * (idx % 4)`.
///
/// The resulting 6-bit unsigned value (0..=63) is re-centred by subtracting
/// 32, so the result lies in `-32..=31`.
///
/// NOTE(review): this packs *consecutive* elements into each `ql`/`qh` byte,
/// which is not the interleaved packing used by ggml's `dequantize_row_q6_K`.
/// Confirm it matches whatever produced `q6k_data`.
///
/// # Panics
/// Panics if `idx / 2 >= ql.len()` or `idx / 4 >= qh.len()`.
#[inline(always)]
fn extract_q6k_scalar(ql: &[u8], qh: &[u8], idx: usize) -> i8 {
    let ql_byte = ql[idx / 2];
    let low4 = if idx % 2 == 0 { ql_byte & 0x0F } else { ql_byte >> 4 };
    let qh_byte = qh[idx / 4];
    let high2 = (qh_byte >> ((idx % 4) * 2)) & 0x03;
    // `as` binds tighter than `-`: cast the 0..=63 value to i8, then re-centre.
    (low4 | (high2 << 4)) as i8 - 32
}
19
20/// Scalar dot product for one Q6K super-block row.
21#[inline(always)]
22fn process_q6k_superblock_scalar(
23    sb_data: &[u8],
24    input: &[f32],
25    input_offset: usize,
26    in_dim: usize,
27) -> f32 {
28    let ql = sb_data.get(0..128).expect("Q6_K: need ≥128 bytes for ql");
29    let qh = sb_data.get(128..192).expect("Q6_K: need ≥192 bytes for qh");
30    let scales = sb_data.get(192..208).expect("Q6_K: need ≥208 bytes for scales");
31    let d = f16_to_f32(u16::from_le_bytes([sb_data[208], sb_data[209]]));
32    let mut sum = 0.0f32;
33
34    for group in 0..16 {
35        let scale = (scales[group] as i8) as f32;
36        let group_offset = group * 16;
37
38        for j in 0..16 {
39            let idx = group_offset + j;
40            let input_idx = input_offset + idx;
41            if input_idx >= in_dim {
42                continue;
43            }
44            let q6 = extract_q6k_scalar(ql, qh, idx);
45            sum += d * scale * q6 as f32 * input[input_idx];
46        }
47    }
48    sum
49}
50
51pub fn matmul_q6k_f32_scalar(
52    q6k_data: &[u8],
53    input: &[f32],
54    out_dim: usize,
55    in_dim: usize,
56) -> Vec<f32> {
57    assert_eq!(input.len(), in_dim, "Input length mismatch");
58
59    let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
60    let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
61
62    // Uninit: output[out_idx] = sum (SET) for every out_idx.
63    let mut output: Vec<f32> = Vec::with_capacity(out_dim);
64    // SAFETY: Each output[out_idx] is SET from local accumulator sum.
65    unsafe {
66        output.set_len(out_dim);
67    }
68
69    for out_idx in 0..out_dim {
70        let row_start = out_idx * row_bytes;
71        let mut sum = 0.0f32;
72
73        for sb_idx in 0..num_blocks_per_row {
74            let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
75            if sb_start + SUPER_BLOCK_BYTES > q6k_data.len() {
76                break;
77            }
78            let sb_data = &q6k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
79            let input_offset = sb_idx * SUPER_BLOCK_SIZE;
80            sum += process_q6k_superblock_scalar(sb_data, input, input_offset, in_dim);
81        }
82
83        output[out_idx] = sum;
84    }
85
86    output
87}
88
/// Extract 8 consecutive Q6_K quantized values from packed `ql`/`qh` arrays.
///
/// Lane `k` holds element `idx_base + k`, decoded exactly as in
/// `extract_q6k_scalar` but widened to `i32` (range `-32..=31`). This lets
/// the result be loaded directly as a `__m256i`.
///
/// # Panics
/// Panics if any `(idx_base + k) / 2` is out of bounds for `ql`, or any
/// `(idx_base + k) / 4` is out of bounds for `qh`.
#[inline(always)]
fn extract_q6k_values(ql: &[u8], qh: &[u8], idx_base: usize) -> [i32; 8] {
    std::array::from_fn(|lane| {
        let idx = idx_base + lane;
        // Even elements use the low nibble, odd elements the high nibble.
        let nibble_shift = (idx & 1) * 4;
        let low4 = (ql[idx >> 1] >> nibble_shift) & 0x0F;
        let high2 = (qh[idx >> 2] >> ((idx & 3) * 2)) & 0x03;
        i32::from(low4 | (high2 << 4)) - 32
    })
}
104
/// Horizontally sum the eight `f32` lanes of `acc`.
///
/// Lanes are folded pairwise in three steps:
/// 1. the upper 128 bits onto the lower 128 bits;
/// 2. lanes 2..4 onto lanes 0..2;
/// 3. lane 1 onto lane 0.
///
/// The result is therefore `((a0 + a4) + (a2 + a6)) + ((a1 + a5) + (a3 + a7))`.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn hsum_q6k_avx2(acc: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_extractf128_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
        _mm_movehdup_ps, _mm_movehl_ps,
    };

    // 8 -> 4 lanes: upper half added onto lower half.
    let quad = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
    // 4 -> 2 lanes: lanes 2,3 added onto lanes 0,1.
    let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    // 2 -> 1 lane: movehdup copies lane 1 into lane 0, then a scalar add.
    _mm_cvtss_f32(_mm_add_ss(pair, _mm_movehdup_ps(pair)))
}
120
121/// Process one Q6K super-block with AVX2, accumulating into `acc`.
122#[cfg(target_arch = "x86_64")]
123#[target_feature(enable = "avx2", enable = "fma")]
124// SAFETY: Caller ensures AVX2+FMA are available and sb_data is a valid Q6_K super-block
125unsafe fn process_q6k_superblock_avx2(
126    sb_data: &[u8],
127    input: &[f32],
128    input_offset: usize,
129    in_dim: usize,
130    acc: &mut std::arch::x86_64::__m256,
131) {
132    unsafe {
133        use std::arch::x86_64::*;
134
135        let ql = sb_data.get(0..128).expect("Q6_K: need ≥128 bytes for ql");
136        let qh = sb_data.get(128..192).expect("Q6_K: need ≥192 bytes for qh");
137        let scales = sb_data.get(192..208).expect("Q6_K: need ≥208 bytes for scales");
138        let d = f16_to_f32(u16::from_le_bytes([sb_data[208], sb_data[209]]));
139        let d_vec = _mm256_set1_ps(d);
140
141        for group in 0..16 {
142            let scale = (scales[group] as i8) as f32;
143            let ds_vec = _mm256_mul_ps(d_vec, _mm256_set1_ps(scale));
144            let group_offset = group * 16;
145            let input_group = input_offset + group_offset;
146
147            for half in 0..2 {
148                let half_offset = half * 8;
149                let input_base = input_group + half_offset;
150                if input_base + 8 > in_dim {
151                    continue;
152                }
153
154                let q6_vals = extract_q6k_values(ql, qh, group_offset + half_offset);
155                let q6_i32 = _mm256_loadu_si256(q6_vals.as_ptr() as *const __m256i);
156                let q6_f32 = _mm256_cvtepi32_ps(q6_i32);
157                let x = _mm256_loadu_ps(input.as_ptr().add(input_base));
158                let dequant = _mm256_mul_ps(ds_vec, q6_f32);
159                *acc = _mm256_fmadd_ps(dequant, x, *acc);
160            }
161        }
162    }
163}
164
165/// Fused Q6_K matrix-vector multiply with AVX2 SIMD
166///
167/// Optimized to process groups of 8 values at a time, computing
168/// dequant and dot product in one pass without intermediate buffer.
169#[cfg(target_arch = "x86_64")]
170#[target_feature(enable = "avx2", enable = "fma")]
171// SAFETY: Caller ensures AVX2+FMA are available and q6k_data is valid Q6_K layout
172unsafe fn matmul_q6k_f32_avx2(
173    q6k_data: &[u8],
174    input: &[f32],
175    out_dim: usize,
176    in_dim: usize,
177) -> Vec<f32> {
178    unsafe {
179        use std::arch::x86_64::*;
180
181        let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
182        let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
183
184        // Uninit: output[out_idx] = hsum_q6k_avx2(acc) (SET) for every out_idx.
185        let mut output: Vec<f32> = Vec::with_capacity(out_dim);
186        // SAFETY: Each output[out_idx] is SET from local SIMD accumulator.
187        output.set_len(out_dim);
188
189        for out_idx in 0..out_dim {
190            let row_start = out_idx * row_bytes;
191            let mut acc = _mm256_setzero_ps();
192
193            for sb_idx in 0..num_blocks_per_row {
194                let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
195                if sb_start + SUPER_BLOCK_BYTES > q6k_data.len() {
196                    break;
197                }
198                let sb_data = &q6k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
199                let input_offset = sb_idx * SUPER_BLOCK_SIZE;
200                process_q6k_superblock_avx2(sb_data, input, input_offset, in_dim, &mut acc);
201            }
202
203            output[out_idx] = hsum_q6k_avx2(acc);
204        }
205
206        output
207    }
208}
209
210/// Runtime dispatch for Q6K matmul - uses AVX2 if available
211///
212/// # Contract (GH-279)
213///
214/// Preconditions validated via `debug_assert!` (zero-cost in release):
215/// - `q6k_data.len() >= contracts::Q6_K.expected_bytes(out_dim, in_dim)`
216/// - `input.len() == in_dim`
217///
218/// These guarantee that inner-loop `expect()` calls on super-block sub-slices
219/// are unreachable: each super-block is sliced to exactly `SUPER_BLOCK_BYTES`
220/// (210), and all sub-accesses (`get(0..128)`, `get(128..192)`, `get(192..208)`)
221/// fit within that.
222#[inline]
223pub fn matmul_q6k_f32_dispatch(
224    q6k_data: &[u8],
225    input: &[f32],
226    out_dim: usize,
227    in_dim: usize,
228) -> Vec<f32> {
229    // GH-279: Contract validation at dispatch boundary.
230    // Inner expect() calls are defense-in-depth — provably unreachable when
231    // this precondition holds, because every sb_data slice is SUPER_BLOCK_BYTES.
232    debug_assert_eq!(input.len(), in_dim, "Q6K dispatch: input length mismatch");
233    debug_assert!(
234        q6k_data.len() >= crate::contracts::Q6_K.expected_bytes(out_dim, in_dim),
235        "Q6K dispatch: buffer too small: {} bytes for [{}, {}] (need {})",
236        q6k_data.len(),
237        out_dim,
238        in_dim,
239        crate::contracts::Q6_K.expected_bytes(out_dim, in_dim),
240    );
241
242    // For large matmuls (total work >= ~8M ops), use parallel execution
243    // This catches FFN layers (8960x1536) and lm_head (151936x1536)
244    // Also catches ffn_down (1536x8960) where out_dim is small but in_dim is large
245    let total_work = out_dim * in_dim;
246    if total_work >= 8_000_000 {
247        return matmul_q6k_f32_parallel(q6k_data, input, out_dim, in_dim);
248    }
249
250    #[cfg(target_arch = "x86_64")]
251    {
252        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
253            // SAFETY: preconditions verified by caller
254            return unsafe { matmul_q6k_f32_avx2(q6k_data, input, out_dim, in_dim) };
255        }
256    }
257    matmul_q6k_f32_scalar(q6k_data, input, out_dim, in_dim)
258}
259
260/// Parallel Q6K matmul using multiple threads with AVX2
261#[cfg(target_arch = "x86_64")]
262fn matmul_q6k_f32_parallel(
263    q6k_data: &[u8],
264    input: &[f32],
265    out_dim: usize,
266    in_dim: usize,
267) -> Vec<f32> {
268    use std::thread;
269
270    // Use fewer threads with larger chunks for better cache efficiency
271    let num_threads = thread::available_parallelism().map(|p| p.get()).unwrap_or(4).min(12); // Use 12 threads max for better cache behavior
272
273    let chunk_size = (out_dim + num_threads - 1) / num_threads;
274    let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
275    let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
276
277    // Uninit: compute_chunk writes *out_val = sum/hsum(acc) (SET) for every element.
278    let mut output: Vec<f32> = Vec::with_capacity(out_dim);
279    // SAFETY: Each thread's compute_chunk writes every element in its chunk (SET).
280    unsafe {
281        output.set_len(out_dim);
282    }
283    let has_avx2 = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
284
285    thread::scope(|s| {
286        let input_ref = input;
287        let q6k_ref = q6k_data;
288        // CGP-DBUF: iterate directly instead of collecting into Vec.
289        for (chunk_idx, chunk) in output.chunks_mut(chunk_size).enumerate() {
290            let start_row = chunk_idx * chunk_size;
291
292            s.spawn(move || {
293                if has_avx2 {
294                    // SAFETY: AVX2+FMA availability verified via is_x86_feature_detected!()
295                    // before thread::scope entry; has_avx2 captures that result.
296                    unsafe {
297                        compute_chunk_avx2(
298                            q6k_ref,
299                            input_ref,
300                            chunk,
301                            start_row,
302                            out_dim,
303                            in_dim,
304                            num_blocks_per_row,
305                            row_bytes,
306                        );
307                    }
308                } else {
309                    compute_chunk_scalar(
310                        q6k_ref,
311                        input_ref,
312                        chunk,
313                        start_row,
314                        out_dim,
315                        in_dim,
316                        num_blocks_per_row,
317                        row_bytes,
318                    );
319                }
320            });
321        }
322    });
323
324    output
325}
326
/// Fallback for non-x86_64.
///
/// Targets other than x86_64 get no threaded kernel here. Large matmuls routed
/// to this function by `matmul_q6k_f32_dispatch` run single-threaded through
/// `matmul_q6k_f32_scalar`, which also asserts `input.len() == in_dim`.
#[cfg(not(target_arch = "x86_64"))]
fn matmul_q6k_f32_parallel(
    q6k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    matmul_q6k_f32_scalar(q6k_data, input, out_dim, in_dim)
}
337
338#[cfg(target_arch = "x86_64")]
339#[target_feature(enable = "avx2", enable = "fma")]
340// SAFETY: Caller ensures AVX2+FMA are available and chunk bounds are valid
341unsafe fn compute_chunk_avx2(
342    q6k_data: &[u8],
343    input: &[f32],
344    chunk: &mut [f32],
345    start_row: usize,
346    out_dim: usize,
347    in_dim: usize,
348    num_blocks_per_row: usize,
349    row_bytes: usize,
350) {
351    unsafe {
352        use std::arch::x86_64::*;
353
354        for (local_idx, out_val) in chunk.iter_mut().enumerate() {
355            let out_idx = start_row + local_idx;
356            if out_idx >= out_dim {
357                break;
358            }
359
360            let row_start = out_idx * row_bytes;
361            let mut acc = _mm256_setzero_ps();
362
363            for sb_idx in 0..num_blocks_per_row {
364                let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
365                if sb_start + SUPER_BLOCK_BYTES > q6k_data.len() {
366                    break;
367                }
368                let sb_data = &q6k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
369                let input_offset = sb_idx * SUPER_BLOCK_SIZE;
370                process_q6k_superblock_avx2(sb_data, input, input_offset, in_dim, &mut acc);
371            }
372
373            *out_val = hsum_q6k_avx2(acc);
374        }
375    }
376}
377
378pub(crate) fn compute_chunk_scalar(
379    q6k_data: &[u8],
380    input: &[f32],
381    chunk: &mut [f32],
382    start_row: usize,
383    out_dim: usize,
384    in_dim: usize,
385    num_blocks_per_row: usize,
386    row_bytes: usize,
387) {
388    for (local_idx, out_val) in chunk.iter_mut().enumerate() {
389        let out_idx = start_row + local_idx;
390        if out_idx >= out_dim {
391            break;
392        }
393
394        let row_start = out_idx * row_bytes;
395        let mut sum = 0.0f32;
396
397        for sb_idx in 0..num_blocks_per_row {
398            let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
399            if sb_start + SUPER_BLOCK_BYTES > q6k_data.len() {
400                break;
401            }
402            let sb_data = &q6k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
403            let input_offset = sb_idx * SUPER_BLOCK_SIZE;
404            sum += process_q6k_superblock_scalar(sb_data, input, input_offset, in_dim);
405        }
406
407        *out_val = sum;
408    }
409}
410
/// Public alias for the optimized Q6K matmul.
///
/// Forwards unchanged to `matmul_q6k_f32_dispatch`, which picks the threaded,
/// AVX2, or scalar kernel. See that function for the input contract.
pub fn matmul_q6k_f32(q6k_data: &[u8], input: &[f32], out_dim: usize, in_dim: usize) -> Vec<f32> {
    matmul_q6k_f32_dispatch(q6k_data, input, out_dim, in_dim)
}