trueno/tiling/q4k_matvec.rs
#![allow(missing_docs)]
//! Q4_K quantized matrix-vector tiling implementation.

use super::config::TilingConfig;

/// Q4_K superblock size in elements (per GGML specification)
pub const Q4K_SUPERBLOCK_SIZE: usize = 256;
/// Q4_K superblock size in bytes: 2 (d) + 2 (dmin) + 12 (scales) + 128 (qs)
pub const Q4K_SUPERBLOCK_BYTES: usize = 144;

/// Tiled Q4_K MatVec executor
///
/// Implements TCB-01 pattern: Cache-blocked matvec with 4×1 micro-kernel.
///
/// # Memory Layout
///
/// Weights are stored in Q4_K superblock format (144 bytes per 256 elements):
/// - d: f16 (2 bytes) - super-block scale (multiplies the sub-block scales)
/// - dmin: f16 (2 bytes) - super-block min scale (multiplies the sub-block mins)
/// - scales: 12 bytes - 8 sub-block scale/min pairs (6-bit packed)
/// - qs: 128 bytes - 256 quantized values (4-bit packed)
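///
/// Total: 2 + 2 + 12 + 128 = 144 bytes per 256 weights, i.e. 4.5 bits/weight.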
///
/// # Performance Characteristics
///
/// - L2-resident: Process midi_tile.m rows at a time
/// - Vectorized: 4×1 micro-kernel processes 4 output rows simultaneously
/// - Aligned: K dimension aligned to Q4_K superblock (256)
#[derive(Debug, Clone)]
pub struct TiledQ4KMatvec {
    /// Tiling configuration
    pub config: TilingConfig,
    /// Number of rows (M dimension)
    pub m: usize,
    /// Number of columns (K dimension)
    pub k: usize,
}

impl TiledQ4KMatvec {
    /// Create a new tiled Q4K matvec executor
    ///
    /// # Panics
    /// Panics if K is not aligned to Q4_K superblock size (256).
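    ///
    /// # Examples
    ///
    /// ```
    /// // Illustrative usage; the crate path is assumed from this file's
    /// // location. A 4096-wide layer has 4096/256 = 16 superblocks per row.
    /// use trueno::tiling::q4k_matvec::TiledQ4KMatvec;
    ///
    /// let mv = TiledQ4KMatvec::new(4096, 4096);
    /// assert_eq!(mv.superblocks_per_row(), 16);
    /// ```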
    #[must_use]
    pub fn new(m: usize, k: usize) -> Self {
        assert!(
            k % Q4K_SUPERBLOCK_SIZE == 0,
            "K dimension ({}) must be aligned to Q4_K superblock size ({})",
            k,
            Q4K_SUPERBLOCK_SIZE
        );

        Self { config: TilingConfig::cpu_avx2_q4k_matvec(), m, k }
    }

    /// Get number of superblocks per row
    #[must_use]
    pub fn superblocks_per_row(&self) -> usize {
        self.k / Q4K_SUPERBLOCK_SIZE
    }

    /// Get total number of superblocks
    #[must_use]
    pub fn total_superblocks(&self) -> usize {
        self.m * self.superblocks_per_row()
    }

    /// Get weight bytes offset for a given row
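    ///
    /// e.g. with K = 512 (2 superblocks per row), row 2 starts at byte
    /// 2 × 2 × 144 = 576.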
    #[must_use]
    #[inline]
    pub fn weight_row_offset(&self, row: usize) -> usize {
        row * self.superblocks_per_row() * Q4K_SUPERBLOCK_BYTES
    }

    /// Calculate optimal number of parallel rows based on L2 cache
    ///
    /// Goal: Keep working set in L2 (256KB typical)
    /// Working set = midi_tile.m rows × K × sizeof(Q4K) + K × sizeof(f32)
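    ///
    /// Worked example (values illustrative): with K = 4096, one Q4_K row is
    /// 4096 × 0.5625 = 2304 bytes and the f32 input is 16 KiB, so a 256 KiB
    /// L2 fits roughly (262144 − 16384) / 2304 ≈ 106 rows.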
    #[must_use]
    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    // SAFETY: k ≤ 2^24 for practical matrix dims so usize→f32 is lossless;
    // result of f32 multiply is non-negative and fits in usize.
    pub fn optimal_parallel_rows(&self, l2_bytes: usize) -> usize {
        // Q4K: 144 bytes per 256 elements = 0.5625 bytes/element
        // (.max(1) guards against division by zero for a degenerate k = 0)
        let row_bytes = ((self.k as f32 * 0.5625) as usize).max(1);
        // Input vector: K × 4 bytes
        let input_bytes = self.k * 4;
        // Available for rows
        let available = l2_bytes.saturating_sub(input_bytes);
        // Rows that fit (minimum 4 for micro-kernel)
        (available / row_bytes).max(4)
    }

    /// Execute tiled matvec (reference scalar implementation)
    ///
    /// This is the reference implementation for correctness testing.
    /// The actual SIMD implementations live in the backends.
    ///
    /// For parallel execution, use [`Self::execute_parallel`] when the `parallel` feature is enabled.
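    ///
    /// # Examples
    ///
    /// ```
    /// // Illustrative (crate path assumed from this file's location): an
    /// // all-zero superblock decodes to d = dmin = 0, so the output is zero.
    /// use trueno::tiling::q4k_matvec::{TiledQ4KMatvec, Q4K_SUPERBLOCK_BYTES};
    ///
    /// let mv = TiledQ4KMatvec::new(2, 256);
    /// let weights = vec![0u8; 2 * Q4K_SUPERBLOCK_BYTES];
    /// let input = vec![1.0f32; 256];
    /// let mut output = vec![0.0f32; 2];
    /// mv.execute_scalar(&weights, &input, &mut output);
    /// assert_eq!(output, vec![0.0, 0.0]);
    /// ```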
    pub fn execute_scalar(&self, weights: &[u8], input: &[f32], output: &mut [f32]) {
        assert_eq!(weights.len(), self.total_superblocks() * Q4K_SUPERBLOCK_BYTES);
        assert_eq!(input.len(), self.k);
        assert_eq!(output.len(), self.m);

        let superblocks_per_row = self.superblocks_per_row();

        for row in 0..self.m {
            let mut sum = 0.0f32;
            let row_offset = row * superblocks_per_row * Q4K_SUPERBLOCK_BYTES;

            for sb in 0..superblocks_per_row {
                let sb_offset = row_offset + sb * Q4K_SUPERBLOCK_BYTES;
                let sb_data = &weights[sb_offset..sb_offset + Q4K_SUPERBLOCK_BYTES];

                // Dequantize and dot product for this superblock
                let input_offset = sb * Q4K_SUPERBLOCK_SIZE;
                sum += self.scalar_superblock_dot(
                    sb_data,
                    &input[input_offset..input_offset + Q4K_SUPERBLOCK_SIZE],
                );
            }

            output[row] = sum;
        }
    }

    /// Execute tiled matvec with parallel row processing
    ///
    /// Uses Rayon to parallelize across rows for multi-core speedup.
    /// Falls back to scalar execution if the `parallel` feature is not enabled.
    ///
    /// # Performance
    ///
    /// Achieves near-linear speedup with core count for large matrices.
    /// For small matrices (< 256 rows), scalar may be faster due to overhead.
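    ///
    /// The signature and results match [`Self::execute_scalar`]; only the row
    /// loop is distributed across the Rayon thread pool.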
    #[cfg(feature = "parallel")]
    pub fn execute_parallel(&self, weights: &[u8], input: &[f32], output: &mut [f32]) {
        use rayon::prelude::*;

        assert_eq!(weights.len(), self.total_superblocks() * Q4K_SUPERBLOCK_BYTES);
        assert_eq!(input.len(), self.k);
        assert_eq!(output.len(), self.m);

        let superblocks_per_row = self.superblocks_per_row();
        let row_stride = superblocks_per_row * Q4K_SUPERBLOCK_BYTES;

        output.par_iter_mut().enumerate().for_each(|(row, out)| {
            let mut sum = 0.0f32;
            let row_offset = row * row_stride;

            for sb in 0..superblocks_per_row {
                let sb_offset = row_offset + sb * Q4K_SUPERBLOCK_BYTES;
                let sb_data = &weights[sb_offset..sb_offset + Q4K_SUPERBLOCK_BYTES];

                let input_offset = sb * Q4K_SUPERBLOCK_SIZE;
                sum += self.scalar_superblock_dot(
                    sb_data,
                    &input[input_offset..input_offset + Q4K_SUPERBLOCK_SIZE],
                );
            }

            *out = sum;
        });
    }

    /// Execute tiled matvec with parallel row processing (fallback)
    ///
    /// When the `parallel` feature is not enabled, this is equivalent to
    /// [`Self::execute_scalar`].
    #[cfg(not(feature = "parallel"))]
    pub fn execute_parallel(&self, weights: &[u8], input: &[f32], output: &mut [f32]) {
        self.execute_scalar(weights, input, output);
    }

    /// Scalar dot product for a single Q4_K superblock
    ///
    /// # Performance
    ///
    /// Optimized version with:
    /// - Precomputed scale/min pairs
    /// - Loop unrolling hints
    /// - Minimized branching in inner loop
    #[inline]
    fn scalar_superblock_dot(&self, sb_data: &[u8], input: &[f32]) -> f32 {
        // Read header (hot path optimized)
        let d = f16_to_f32(sb_data.get(0..2).expect("Q4_K: need ≥2 bytes for d"));
        let dmin = f16_to_f32(sb_data.get(2..4).expect("Q4_K: need ≥4 bytes for dmin"));
        let scales = sb_data.get(4..16).expect("Q4_K: need ≥16 bytes for scales");
        let qs = sb_data.get(16..144).expect("Q4_K: need ≥144 bytes for qs");

        // Precompute all scale/min pairs upfront
        let scale_mins = precompute_scales_mins(scales);

        let mut sum = 0.0f32;

        // Process 256 values in 4 pairs of sub-blocks (GGML Q4_K layout).
        // Each pair shares 32 qs bytes: even SB uses low nibbles, odd SB uses high nibbles.
        // pair 0: SB 0 (lo) + SB 1 (hi) → qs bytes 0-31   → K offsets 0-63
        // pair 1: SB 2 (lo) + SB 3 (hi) → qs bytes 32-63  → K offsets 64-127
        // pair 2: SB 4 (lo) + SB 5 (hi) → qs bytes 64-95  → K offsets 128-191
        // pair 3: SB 6 (lo) + SB 7 (hi) → qs bytes 96-127 → K offsets 192-255
        for pair in 0..4 {
            let sb_lo = pair * 2;
            let sb_hi = pair * 2 + 1;

            let (sc_lo, mn_lo) = scale_mins[sb_lo];
            let (sc_hi, mn_hi) = scale_mins[sb_hi];

            let d_scale_lo = d * sc_lo;
            let dm_lo = dmin * mn_lo;
            let d_scale_hi = d * sc_hi;
            let dm_hi = dmin * mn_hi;

            let q_offset = pair * 32; // 32 qs bytes per pair
            let input_lo = pair * 64; // low nibbles: first 32 K values
            let input_hi = pair * 64 + 32; // high nibbles: next 32 K values

            let mut pair_sum = 0.0f32;

            // 32 qs bytes: low nibble → SB_lo values, high nibble → SB_hi values
            for i in 0..32 {
                let byte = qs[q_offset + i];

                let q_lo = (byte & 0x0F) as f32;
                let q_hi = (byte >> 4) as f32;

                let val_lo = d_scale_lo * q_lo - dm_lo;
                let val_hi = d_scale_hi * q_hi - dm_hi;

                pair_sum += val_lo * input[input_lo + i];
                pair_sum += val_hi * input[input_hi + i];
            }

            sum += pair_sum;
        }

        sum
    }

    /// Get tiling statistics for profiling
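    ///
    /// e.g. a 4096×4096 layer: 2 · 4096 · 4096 ≈ 33.6M ops over
    /// 4096 · 16 · 144 + 16384 ≈ 9.45 MB read, ≈ 3.55 ops/byte.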
    #[must_use]
    #[allow(clippy::cast_precision_loss)] // Matrix dimensions ≤ 2^24; precision loss acceptable for profiling stats
    pub fn stats(&self) -> TilingStats {
        let bytes_per_row = self.superblocks_per_row() * Q4K_SUPERBLOCK_BYTES;
        let total_weight_bytes = self.m * bytes_per_row;
        let input_bytes = self.k * 4;
        let output_bytes = self.m * 4;

        TilingStats {
            total_weight_bytes,
            input_bytes,
            output_bytes,
            superblocks: self.total_superblocks(),
            arithmetic_ops: self.m * self.k * 2, // 2 ops per element (mul + add)
            arithmetic_intensity: (self.m * self.k * 2) as f32
                / (total_weight_bytes + input_bytes) as f32,
        }
    }
}

/// Statistics for a tiled operation
#[derive(Debug, Clone)]
pub struct TilingStats {
    /// Total weight bytes
    pub total_weight_bytes: usize,
    /// Input vector bytes
    pub input_bytes: usize,
    /// Output vector bytes
    pub output_bytes: usize,
    /// Number of superblocks
    pub superblocks: usize,
    /// Total arithmetic operations
    pub arithmetic_ops: usize,
    /// Arithmetic intensity (FLOPs/byte)
    pub arithmetic_intensity: f32,
}

/// Convert 2 bytes (f16 IEEE 754) to f32
///
/// Manual implementation to avoid a dependency on the `half` crate.
/// Format: 1 sign bit, 5 exponent bits, 10 mantissa bits.
///
/// # Performance
///
/// Optimized for the common case (normal numbers). Special cases (zero,
/// subnormal, inf, nan) use branches but are rare in practice for model weights.
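///
/// # Examples
///
/// ```
/// // Illustrative (crate path assumed): 0x3C00 is half-precision 1.0.
/// use trueno::tiling::q4k_matvec::f16_to_f32;
///
/// assert_eq!(f16_to_f32(&0x3C00u16.to_le_bytes()), 1.0);
/// ```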
#[inline]
pub fn f16_to_f32(bytes: &[u8]) -> f32 {
    let bits = u16::from_le_bytes([bytes[0], bytes[1]]);
    f16_bits_to_f32(bits)
}

/// Fast path f16 to f32 conversion from raw bits
///
/// Optimized version that handles the common case (normal numbers) with
/// minimal branching. Uses branchless bit manipulation for the hot path.
#[inline(always)]
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign = (bits >> 15) & 0x1;
    let exponent = (bits >> 10) & 0x1F;
    let mantissa = bits & 0x3FF;

    // Fast path: normal numbers (exponent != 0 && exponent != 31)
    // This is the common case for model weights
    if exponent != 0 && exponent != 31 {
        // Branchless conversion for normal numbers
        // f16 bias = 15, f32 bias = 127
        let f32_exp = exponent as u32 + 112; // 127 - 15 = 112
        let f32_mant = (mantissa as u32) << 13; // 10 bits -> 23 bits
        let f32_bits = ((sign as u32) << 31) | (f32_exp << 23) | f32_mant;
        return f32::from_bits(f32_bits);
    }

    // Cold path: special cases (zero, subnormal, inf, nan)
    f16_special_to_f32(sign, exponent, mantissa)
}

/// Handle f16 special cases (zero, subnormal, inf, nan)
///
/// Cold path - marked to help branch prediction
#[cold]
#[inline(never)]
fn f16_special_to_f32(sign: u16, exponent: u16, mantissa: u16) -> f32 {
    if exponent == 0 {
        if mantissa == 0 {
            // Zero (positive or negative)
            return if sign == 1 { -0.0 } else { 0.0 };
        }
        // Subnormal f16 -> normalized f32
        // 2^-14 as constant to avoid powi() call
        const TWO_POW_NEG_14: f32 = 6.103_515_625e-5; // 2^-14
        let m = mantissa as f32 * (1.0 / 1024.0);
        let result = m * TWO_POW_NEG_14;
        return if sign == 1 { -result } else { result };
    }

    // exponent == 31: Inf or NaN
    if mantissa == 0 {
        if sign == 1 {
            f32::NEG_INFINITY
        } else {
            f32::INFINITY
        }
    } else {
        f32::NAN
    }
}

/// Extract the 6-bit scale and min for a given sub-block index (0-7) from the
/// packed scales array.
///
/// GGML Q4_K split storage format (12 bytes encode 8 scale/min pairs):
/// bytes[0..3]: bits[5:0] = scale SB 0-3, bits[7:6] = high 2 bits of scale SB 4-7
/// bytes[4..7]: bits[5:0] = min SB 0-3, bits[7:6] = high 2 bits of min SB 4-7
/// bytes[8..11]: bits[3:0] = low 4 bits of scale SB 4-7,
///               bits[7:4] = low 4 bits of min SB 4-7
///
/// Reference: ggml `block_q4_K.scales`, llama.cpp `vec_dot_q4_K_q8_1`,
/// trueno `backends::q4k::parse_q4k_header`.
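///
/// ```
/// // Illustrative (crate path assumed): scales[0] = 0x3F encodes scale 63
/// // for sub-block 0; scales[4] = 0 encodes min 0.
/// use trueno::tiling::q4k_matvec::extract_scale_min_6bit;
///
/// let scales = [0x3Fu8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
/// assert_eq!(extract_scale_min_6bit(&scales, 0), (63.0, 0.0));
/// ```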
#[inline(always)]
#[allow(clippy::cast_precision_loss)]
// SAFETY: scale and min are 6-bit values (0-63) stored in u32; lossless in f32.
pub fn extract_scale_min_6bit(scales: &[u8], idx: usize) -> (f32, f32) {
    debug_assert!(scales.len() >= 12, "scales array must be at least 12 bytes");
    debug_assert!(idx < 8, "idx must be < 8");

    if idx < 4 {
        // Sub-blocks 0-3: 6-bit values from bytes 0-3 (scale) and 4-7 (min)
        let scale = (scales[idx] & 0x3F) as u32;
        let min = (scales[4 + idx] & 0x3F) as u32;
        (scale as f32, min as f32)
    } else {
        // Sub-blocks 4-7: low 4 bits from combo byte, high 2 from upper bits
        let i = idx - 4;
        let combo = scales[8 + i];

        // Scale: low 4 bits from combo[3:0], high 2 from scales[i][7:6]
        let sc_low4 = (combo & 0x0F) as u32;
        let sc_high2 = ((scales[i] >> 6) & 0x03) as u32;
        let scale = sc_low4 | (sc_high2 << 4);

        // Min: low 4 bits from combo[7:4], high 2 from scales[4+i][7:6]
        let mn_low4 = ((combo >> 4) & 0x0F) as u32;
        let mn_high2 = ((scales[4 + i] >> 6) & 0x03) as u32;
        let min = mn_low4 | (mn_high2 << 4);

        (scale as f32, min as f32)
    }
}

/// Precompute all 8 scale/min pairs for a Q4_K superblock
///
/// More efficient than calling `extract_scale_min_6bit` 8 times when
/// we need all values (which is the common case).
#[inline]
fn precompute_scales_mins(scales: &[u8]) -> [(f32, f32); 8] {
    debug_assert!(scales.len() >= 12);

    // Unroll the extraction for all 8 sub-blocks
    [
        extract_scale_min_6bit(scales, 0),
        extract_scale_min_6bit(scales, 1),
        extract_scale_min_6bit(scales, 2),
        extract_scale_min_6bit(scales, 3),
        extract_scale_min_6bit(scales, 4),
        extract_scale_min_6bit(scales, 5),
        extract_scale_min_6bit(scales, 6),
        extract_scale_min_6bit(scales, 7),
    ]
}
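
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal correctness sketches (illustrative, not an exhaustive suite).

    #[test]
    fn f16_one_converts_exactly() {
        // 0x3C00 is IEEE 754 half-precision 1.0.
        assert_eq!(f16_to_f32(&0x3C00u16.to_le_bytes()), 1.0);
    }

    #[test]
    fn single_weight_dequantizes_as_expected() {
        // Superblock: d = 1.0 (f16 0x3C00), dmin = 0, scale(SB 0) = 1, and
        // q = 1 at element 0; the dot with e0 must be d·scale·q − dmin·min = 1.
        let mut weights = vec![0u8; Q4K_SUPERBLOCK_BYTES];
        weights[0..2].copy_from_slice(&0x3C00u16.to_le_bytes());
        weights[4] = 1; // 6-bit scale for sub-block 0
        weights[16] = 0x01; // low nibble of first qs byte → element 0

        let mv = TiledQ4KMatvec::new(1, Q4K_SUPERBLOCK_SIZE);
        let mut input = vec![0.0f32; Q4K_SUPERBLOCK_SIZE];
        input[0] = 1.0;
        let mut output = vec![0.0f32; 1];
        mv.execute_scalar(&weights, &input, &mut output);

        assert!((output[0] - 1.0).abs() < 1e-6);
    }
}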