// trueno/backends/q4k/gemv/mod.rs
1//! Row-major Q4_K matrix-vector multiplication.
2//!
3//! This module implements row-major GEMV where weights are stored row-first.
4//! Includes scalar, AVX2-optimized, and parallel dispatch implementations.
5
6mod scalar;
7
8#[cfg(target_arch = "x86_64")]
9mod avx2;
10
11#[cfg(target_arch = "x86_64")]
12mod avx512;
13
14use super::{SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};
15
16// Re-export public API (preserves exact public surface)
17pub use scalar::{matmul_q4k_f32, matmul_q4k_f32_scalar};
18
19// Re-export crate-internal API (used by sibling test modules)
20#[allow(unused_imports)]
21pub(crate) use scalar::compute_chunk_q4k_scalar;
22
23/// Runtime dispatch for Q4K matmul - uses AVX2 if available, otherwise scalar
24///
25/// # Contract (GH-279)
26///
27/// Preconditions validated via `debug_assert!` (zero-cost in release):
28/// - `q4k_data.len() >= contracts::Q4_K.expected_bytes(out_dim, in_dim)`
29/// - `input.len() == in_dim`
30///
31/// These guarantee that inner-loop `expect()` calls on super-block sub-slices
32/// are unreachable: each super-block is sliced to exactly `SUPER_BLOCK_BYTES`
33/// (144), and all sub-accesses (`get(4..16)`, `get(16..144)`) fit within that.
34#[inline]
35pub fn matmul_q4k_f32_dispatch(
36    q4k_data: &[u8],
37    input: &[f32],
38    out_dim: usize,
39    in_dim: usize,
40) -> Vec<f32> {
41    // GH-279: Contract validation at dispatch boundary.
42    // Inner expect() calls are defense-in-depth — provably unreachable when
43    // this precondition holds, because every sb_data slice is SUPER_BLOCK_BYTES.
44    debug_assert_eq!(input.len(), in_dim, "Q4K dispatch: input length mismatch");
45    debug_assert!(
46        q4k_data.len() >= crate::contracts::Q4_K.expected_bytes(out_dim, in_dim),
47        "Q4K dispatch: buffer too small: {} bytes for [{}, {}] (need {})",
48        q4k_data.len(),
49        out_dim,
50        in_dim,
51        crate::contracts::Q4_K.expected_bytes(out_dim, in_dim),
52    );
53
54    #[cfg(target_arch = "x86_64")]
55    {
56        // For large Q4K matmuls (total work >= ~8M elements), use parallel execution.
57        // This catches FFN layers (8960×1536 = 13.7M) and lm_head (151936×1536).
58        // Threshold tested at 2M (2026-04-05) but REGRESSED: 1536×1536 went from
59        // 17→14 GFLOPS because parallel overhead (~40µs) dominates at 277µs total.
60        // Contract: cgp-q4k-parallel-threshold-v1.yaml documents negative result.
61        let total_work = out_dim * in_dim;
62        if total_work >= 8_000_000 {
63            return matmul_q4k_f32_parallel(q4k_data, input, out_dim, in_dim);
64        }
65
66        // AVX-512: 16-wide dequant+FMA (2× throughput vs AVX2)
67        // Contract: avx512-q4k-v1.yaml (C-AVX512-Q4K-001, C-AVX512-Q4K-002)
68        if is_x86_feature_detected!("avx512f")
69            && is_x86_feature_detected!("avx512bw")
70            && is_x86_feature_detected!("fma")
71        {
72            return unsafe { avx512::matmul_q4k_f32_avx512(q4k_data, input, out_dim, in_dim) };
73        }
74
75        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
76            // SAFETY: We just verified AVX2 + FMA are available
77            return unsafe { avx2::matmul_q4k_f32_avx2(q4k_data, input, out_dim, in_dim) };
78        }
79    }
80
81    // Fallback to scalar with 4-way unroll
82    scalar::matmul_q4k_f32(q4k_data, input, out_dim, in_dim)
83}
84
// NOTE(review): orphaned rustdoc — it describes `matmul_q4k_f32_colmajor`,
// which is not defined in this module (presumably moved elsewhere; confirm).
// Converted from `///` to `//` so it no longer attaches to the next item,
// `matmul_q4k_f32_parallel`, which it does not describe.
//
// Fused Q4_K matrix-vector multiply for GGML column-major layout
//
// Computes: output = input @ Q4K_weight (GGML convention: y = x @ W)
// where weight is stored in Q4_K format with GGML column-major super-block organization.
//
// GGML Column-Major Layout (PMAT-103)
//
// For a weight tensor with shape [ne0, ne1] in GGML notation:
// - ne0 is the output dimension (rows)
// - ne1 is the input/reduction dimension (columns)
// - Elements are stored column-major: W[i,j] at offset i + j*ne0
// - Each column j (length ne0) contains weights from input[j] to all outputs
// - Super-blocks are organized by columns: column j uses super-blocks [j*blocks_per_col, (j+1)*blocks_per_col)
//
// This matches GGUF tensor storage and enables fused kernel execution without transposition.
//
// Arguments
// * `q4k_data` - Raw Q4K bytes in GGML column-major layout [ne0, ne1]
// * `input` - F32 input vector [ne1] (input/reduction dimension)
// * `ne0` - Size of output dimension (rows in GGML, output size)
// * `ne1` - Size of input/reduction dimension (columns in GGML, input size)
//
// Returns: F32 output vector [ne0]
//
// Example:
//   // GGUF ffn_gate: shape [intermediate_dim, hidden_dim] = [8960, 1536]
//   // Computes: intermediate = hidden @ ffn_gate
//   let output = matmul_q4k_f32_colmajor(&q4k_bytes, &hidden, 8960, 1536);
//   // output has 8960 elements
117
118// ============================================================================
119// Parallel Execution Helpers
120// ============================================================================
121
122#[cfg(target_arch = "x86_64")]
123fn matmul_q4k_f32_parallel(
124    q4k_data: &[u8],
125    input: &[f32],
126    out_dim: usize,
127    in_dim: usize,
128) -> Vec<f32> {
129    use std::thread;
130
131    // Use fewer threads with larger chunks for better cache efficiency
132    let num_threads = thread::available_parallelism().map(|p| p.get()).unwrap_or(4).min(12);
133
134    let chunk_size = (out_dim + num_threads - 1) / num_threads;
135    let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
136    let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
137
138    // Uninit: compute_chunk_* writes *out_val = hsum(acc) for every element.
139    let mut output: Vec<f32> = Vec::with_capacity(out_dim);
140    // SAFETY: Each thread's compute_chunk writes every element in its chunk (SET).
141    unsafe {
142        output.set_len(out_dim);
143    }
144    let has_avx512 = is_x86_feature_detected!("avx512f")
145        && is_x86_feature_detected!("avx512bw")
146        && is_x86_feature_detected!("fma");
147    let has_avx2 = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
148
149    thread::scope(|s| {
150        let input_ref = input;
151        let q4k_ref = q4k_data;
152        // CGP-DBUF: iterate directly instead of collecting into Vec.
153        for (chunk_idx, chunk) in output.chunks_mut(chunk_size).enumerate() {
154            let start_row = chunk_idx * chunk_size;
155
156            s.spawn(move || {
157                if has_avx512 {
158                    // Contract: avx512-q4k-v1.yaml (C-AVX512-Q4K-001)
159                    unsafe {
160                        avx512::compute_chunk_q4k_avx512(
161                            q4k_ref,
162                            input_ref,
163                            chunk,
164                            start_row,
165                            out_dim,
166                            in_dim,
167                            num_blocks_per_row,
168                            row_bytes,
169                        );
170                    }
171                } else if has_avx2 {
172                    unsafe {
173                        avx2::compute_chunk_q4k_avx2(
174                            q4k_ref,
175                            input_ref,
176                            chunk,
177                            start_row,
178                            out_dim,
179                            in_dim,
180                            num_blocks_per_row,
181                            row_bytes,
182                        );
183                    }
184                } else {
185                    scalar::compute_chunk_q4k_scalar(
186                        q4k_ref,
187                        input_ref,
188                        chunk,
189                        start_row,
190                        out_dim,
191                        in_dim,
192                        num_blocks_per_row,
193                        row_bytes,
194                    );
195                }
196            });
197        }
198    });
199
200    output
201}
202
/// Fallback for non-x86_64: the SIMD chunk kernels are gated on x86_64, so the
/// "parallel" entry point simply delegates to the scalar implementation.
#[cfg(not(target_arch = "x86_64"))]
fn matmul_q4k_f32_parallel(q4k_data: &[u8], input: &[f32], out_dim: usize, in_dim: usize) -> Vec<f32> {
    scalar::matmul_q4k_f32(q4k_data, input, out_dim, in_dim)
}