trueno/backends/q4k/gemv/mod.rs
1//! Row-major Q4_K matrix-vector multiplication.
2//!
3//! This module implements row-major GEMV where weights are stored row-first.
//! Includes scalar, AVX2- and AVX-512-optimized, and parallel dispatch implementations.
5
6mod scalar;
7
8#[cfg(target_arch = "x86_64")]
9mod avx2;
10
11#[cfg(target_arch = "x86_64")]
12mod avx512;
13
14use super::{SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};
15
16// Re-export public API (preserves exact public surface)
17pub use scalar::{matmul_q4k_f32, matmul_q4k_f32_scalar};
18
19// Re-export crate-internal API (used by sibling test modules)
20#[allow(unused_imports)]
21pub(crate) use scalar::compute_chunk_q4k_scalar;
22
23/// Runtime dispatch for Q4K matmul - uses AVX2 if available, otherwise scalar
24///
25/// # Contract (GH-279)
26///
27/// Preconditions validated via `debug_assert!` (zero-cost in release):
28/// - `q4k_data.len() >= contracts::Q4_K.expected_bytes(out_dim, in_dim)`
29/// - `input.len() == in_dim`
30///
31/// These guarantee that inner-loop `expect()` calls on super-block sub-slices
32/// are unreachable: each super-block is sliced to exactly `SUPER_BLOCK_BYTES`
33/// (144), and all sub-accesses (`get(4..16)`, `get(16..144)`) fit within that.
34#[inline]
35pub fn matmul_q4k_f32_dispatch(
36 q4k_data: &[u8],
37 input: &[f32],
38 out_dim: usize,
39 in_dim: usize,
40) -> Vec<f32> {
41 // GH-279: Contract validation at dispatch boundary.
42 // Inner expect() calls are defense-in-depth — provably unreachable when
43 // this precondition holds, because every sb_data slice is SUPER_BLOCK_BYTES.
44 debug_assert_eq!(input.len(), in_dim, "Q4K dispatch: input length mismatch");
45 debug_assert!(
46 q4k_data.len() >= crate::contracts::Q4_K.expected_bytes(out_dim, in_dim),
47 "Q4K dispatch: buffer too small: {} bytes for [{}, {}] (need {})",
48 q4k_data.len(),
49 out_dim,
50 in_dim,
51 crate::contracts::Q4_K.expected_bytes(out_dim, in_dim),
52 );
53
54 #[cfg(target_arch = "x86_64")]
55 {
56 // For large Q4K matmuls (total work >= ~8M elements), use parallel execution.
57 // This catches FFN layers (8960×1536 = 13.7M) and lm_head (151936×1536).
58 // Threshold tested at 2M (2026-04-05) but REGRESSED: 1536×1536 went from
59 // 17→14 GFLOPS because parallel overhead (~40µs) dominates at 277µs total.
60 // Contract: cgp-q4k-parallel-threshold-v1.yaml documents negative result.
61 let total_work = out_dim * in_dim;
62 if total_work >= 8_000_000 {
63 return matmul_q4k_f32_parallel(q4k_data, input, out_dim, in_dim);
64 }
65
66 // AVX-512: 16-wide dequant+FMA (2× throughput vs AVX2)
67 // Contract: avx512-q4k-v1.yaml (C-AVX512-Q4K-001, C-AVX512-Q4K-002)
68 if is_x86_feature_detected!("avx512f")
69 && is_x86_feature_detected!("avx512bw")
70 && is_x86_feature_detected!("fma")
71 {
72 return unsafe { avx512::matmul_q4k_f32_avx512(q4k_data, input, out_dim, in_dim) };
73 }
74
75 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
76 // SAFETY: We just verified AVX2 + FMA are available
77 return unsafe { avx2::matmul_q4k_f32_avx2(q4k_data, input, out_dim, in_dim) };
78 }
79 }
80
81 // Fallback to scalar with 4-way unroll
82 scalar::matmul_q4k_f32(q4k_data, input, out_dim, in_dim)
83}
84
// NOTE(review): the text below documents `matmul_q4k_f32_colmajor`, which is
// not defined in this module. As an outer `///` doc comment it would attach to
// the next item (`matmul_q4k_f32_parallel`) and document the wrong function,
// so it is kept as a plain comment until the column-major kernel lives here
// (or the text moves next to its real definition).
//
// Fused Q4_K matrix-vector multiply for GGML column-major layout
//
// Computes: output = input @ Q4K_weight (GGML convention: y = x @ W)
// where weight is stored in Q4_K format with GGML column-major super-block organization.
//
// GGML Column-Major Layout (PMAT-103)
//
// For a weight tensor with shape [ne0, ne1] in GGML notation:
// - ne0 is the output dimension (rows)
// - ne1 is the input/reduction dimension (columns)
// - Elements are stored column-major: W[i,j] at offset i + j*ne0
// - Each column j (length ne0) contains weights from input[j] to all outputs
// - Super-blocks are organized by columns: column j uses super-blocks [j*blocks_per_col, (j+1)*blocks_per_col)
//
// This matches GGUF tensor storage and enables fused kernel execution without transposition.
//
// Arguments
// * `q4k_data` - Raw Q4K bytes in GGML column-major layout [ne0, ne1]
// * `input` - F32 input vector [ne1] (input/reduction dimension)
// * `ne0` - Size of output dimension (rows in GGML, output size)
// * `ne1` - Size of input/reduction dimension (columns in GGML, input size)
//
// Returns
// F32 output vector [ne0]
//
// Example
// ```rust,ignore
// // GGUF ffn_gate: shape [intermediate_dim, hidden_dim] = [8960, 1536]
// // Computes: intermediate = hidden @ ffn_gate
// let output = matmul_q4k_f32_colmajor(&q4k_bytes, &hidden, 8960, 1536);
// // output has 8960 elements
// ```
117
118// ============================================================================
119// Parallel Execution Helpers
120// ============================================================================
121
122#[cfg(target_arch = "x86_64")]
123fn matmul_q4k_f32_parallel(
124 q4k_data: &[u8],
125 input: &[f32],
126 out_dim: usize,
127 in_dim: usize,
128) -> Vec<f32> {
129 use std::thread;
130
131 // Use fewer threads with larger chunks for better cache efficiency
132 let num_threads = thread::available_parallelism().map(|p| p.get()).unwrap_or(4).min(12);
133
134 let chunk_size = (out_dim + num_threads - 1) / num_threads;
135 let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
136 let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
137
138 // Uninit: compute_chunk_* writes *out_val = hsum(acc) for every element.
139 let mut output: Vec<f32> = Vec::with_capacity(out_dim);
140 // SAFETY: Each thread's compute_chunk writes every element in its chunk (SET).
141 unsafe {
142 output.set_len(out_dim);
143 }
144 let has_avx512 = is_x86_feature_detected!("avx512f")
145 && is_x86_feature_detected!("avx512bw")
146 && is_x86_feature_detected!("fma");
147 let has_avx2 = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
148
149 thread::scope(|s| {
150 let input_ref = input;
151 let q4k_ref = q4k_data;
152 // CGP-DBUF: iterate directly instead of collecting into Vec.
153 for (chunk_idx, chunk) in output.chunks_mut(chunk_size).enumerate() {
154 let start_row = chunk_idx * chunk_size;
155
156 s.spawn(move || {
157 if has_avx512 {
158 // Contract: avx512-q4k-v1.yaml (C-AVX512-Q4K-001)
159 unsafe {
160 avx512::compute_chunk_q4k_avx512(
161 q4k_ref,
162 input_ref,
163 chunk,
164 start_row,
165 out_dim,
166 in_dim,
167 num_blocks_per_row,
168 row_bytes,
169 );
170 }
171 } else if has_avx2 {
172 unsafe {
173 avx2::compute_chunk_q4k_avx2(
174 q4k_ref,
175 input_ref,
176 chunk,
177 start_row,
178 out_dim,
179 in_dim,
180 num_blocks_per_row,
181 row_bytes,
182 );
183 }
184 } else {
185 scalar::compute_chunk_q4k_scalar(
186 q4k_ref,
187 input_ref,
188 chunk,
189 start_row,
190 out_dim,
191 in_dim,
192 num_blocks_per_row,
193 row_bytes,
194 );
195 }
196 });
197 }
198 });
199
200 output
201}
202
/// Fallback for non-x86_64
///
/// On non-x86_64 targets no SIMD chunk kernels exist, so the "parallel" entry
/// point simply delegates to the scalar implementation on the calling thread.
/// Signature mirrors the x86_64 version so call sites compile unchanged.
#[cfg(not(target_arch = "x86_64"))]
fn matmul_q4k_f32_parallel(
    q4k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    scalar::matmul_q4k_f32(q4k_data, input, out_dim, in_dim)
}