Skip to main content

trueno/tiling/
q4k_matvec.rs

1#![allow(missing_docs)]
2//! Q4_K quantized matrix-vector tiling implementation.
3
4use super::config::TilingConfig;
5
/// Number of quantized weights encoded by one Q4_K superblock (GGML `QK_K`).
pub const Q4K_SUPERBLOCK_SIZE: usize = 256;
/// Encoded size of one Q4_K superblock in bytes:
/// `d` (f16, 2) + `dmin` (f16, 2) + `scales` (12) + `qs` (two 4-bit codes per byte, 128).
pub const Q4K_SUPERBLOCK_BYTES: usize = 2 + 2 + 12 + Q4K_SUPERBLOCK_SIZE / 2;
9
10/// Tiled Q4_K MatVec executor
11///
12/// Implements TCB-01 pattern: Cache-blocked matvec with 4×1 micro-kernel.
13///
14/// # Memory Layout
15///
16/// Weights are stored in Q4_K superblock format (144 bytes per 256 elements):
17/// - d: f16 (2 bytes) - block scale
18/// - dmin: f16 (2 bytes) - block minimum
19/// - scales: 12 bytes - 8 sub-block scales (6-bit packed)
20/// - qs: 128 bytes - 256 quantized values (4-bit packed)
21///
22/// # Performance Characteristics
23///
24/// - L2-resident: Process midi_tile.m rows at a time
25/// - Vectorized: 4×1 micro-kernel processes 4 output rows simultaneously
26/// - Aligned: K dimension aligned to Q4_K superblock (256)
27#[derive(Debug, Clone)]
28pub struct TiledQ4KMatvec {
29    /// Tiling configuration
30    pub config: TilingConfig,
31    /// Number of rows (M dimension)
32    pub m: usize,
33    /// Number of columns (K dimension)
34    pub k: usize,
35}
36
37impl TiledQ4KMatvec {
38    /// Create a new tiled Q4K matvec executor
39    ///
40    /// # Panics
41    /// Panics if K is not aligned to Q4_K superblock size (256).
42    #[must_use]
43    pub fn new(m: usize, k: usize) -> Self {
44        assert!(
45            k % Q4K_SUPERBLOCK_SIZE == 0,
46            "K dimension ({}) must be aligned to Q4_K superblock size ({})",
47            k,
48            Q4K_SUPERBLOCK_SIZE
49        );
50
51        Self { config: TilingConfig::cpu_avx2_q4k_matvec(), m, k }
52    }
53
54    /// Get number of superblocks per row
55    #[must_use]
56    pub fn superblocks_per_row(&self) -> usize {
57        self.k / Q4K_SUPERBLOCK_SIZE
58    }
59
60    /// Get total number of superblocks
61    #[must_use]
62    pub fn total_superblocks(&self) -> usize {
63        self.m * self.superblocks_per_row()
64    }
65
66    /// Get weight bytes offset for a given row
67    #[must_use]
68    #[inline]
69    pub fn weight_row_offset(&self, row: usize) -> usize {
70        row * self.superblocks_per_row() * Q4K_SUPERBLOCK_BYTES
71    }
72
73    /// Calculate optimal number of parallel rows based on L2 cache
74    ///
75    /// Goal: Keep working set in L2 (256KB typical)
76    /// Working set = midi_tile.m rows × K × sizeof(Q4K) + K × sizeof(f32)
77    #[must_use]
78    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation, clippy::cast_sign_loss)]
79    // SAFETY: k ≤ 2^24 for practical matrix dims so usize→f32 is lossless;
80    //         result of f32 multiply is non-negative and fits in usize.
81    pub fn optimal_parallel_rows(&self, l2_bytes: usize) -> usize {
82        // Q4K: 144 bytes per 256 elements = 0.5625 bytes/element
83        let row_bytes = (self.k as f32 * 0.5625) as usize;
84        // Input vector: K × 4 bytes
85        let input_bytes = self.k * 4;
86        // Available for rows
87        let available = l2_bytes.saturating_sub(input_bytes);
88        // Rows that fit (minimum 4 for micro-kernel)
89        (available / row_bytes).max(4)
90    }
91
92    /// Execute tiled matvec (reference scalar implementation)
93    ///
94    /// This is the reference implementation for correctness testing.
95    /// Actual SIMD implementation would be in the backends.
96    ///
97    /// For parallel execution, use [`execute_parallel`] when the `parallel` feature is enabled.
98    pub fn execute_scalar(&self, weights: &[u8], input: &[f32], output: &mut [f32]) {
99        assert_eq!(weights.len(), self.total_superblocks() * Q4K_SUPERBLOCK_BYTES);
100        assert_eq!(input.len(), self.k);
101        assert_eq!(output.len(), self.m);
102
103        let superblocks_per_row = self.superblocks_per_row();
104
105        for row in 0..self.m {
106            let mut sum = 0.0f32;
107            let row_offset = row * superblocks_per_row * Q4K_SUPERBLOCK_BYTES;
108
109            for sb in 0..superblocks_per_row {
110                let sb_offset = row_offset + sb * Q4K_SUPERBLOCK_BYTES;
111                let sb_data = &weights[sb_offset..sb_offset + Q4K_SUPERBLOCK_BYTES];
112
113                // Dequantize and dot product for this superblock
114                let input_offset = sb * Q4K_SUPERBLOCK_SIZE;
115                sum += self.scalar_superblock_dot(
116                    sb_data,
117                    &input[input_offset..input_offset + Q4K_SUPERBLOCK_SIZE],
118                );
119            }
120
121            output[row] = sum;
122        }
123    }
124
125    /// Execute tiled matvec with parallel row processing
126    ///
127    /// Uses Rayon to parallelize across rows for multi-core speedup.
128    /// Falls back to scalar execution if the `parallel` feature is not enabled.
129    ///
130    /// # Performance
131    ///
132    /// Achieves near-linear speedup with core count for large matrices.
133    /// For small matrices (< 256 rows), scalar may be faster due to overhead.
134    #[cfg(feature = "parallel")]
135    pub fn execute_parallel(&self, weights: &[u8], input: &[f32], output: &mut [f32]) {
136        use rayon::prelude::*;
137
138        assert_eq!(weights.len(), self.total_superblocks() * Q4K_SUPERBLOCK_BYTES);
139        assert_eq!(input.len(), self.k);
140        assert_eq!(output.len(), self.m);
141
142        let superblocks_per_row = self.superblocks_per_row();
143        let row_stride = superblocks_per_row * Q4K_SUPERBLOCK_BYTES;
144
145        output.par_iter_mut().enumerate().for_each(|(row, out)| {
146            let mut sum = 0.0f32;
147            let row_offset = row * row_stride;
148
149            for sb in 0..superblocks_per_row {
150                let sb_offset = row_offset + sb * Q4K_SUPERBLOCK_BYTES;
151                let sb_data = &weights[sb_offset..sb_offset + Q4K_SUPERBLOCK_BYTES];
152
153                let input_offset = sb * Q4K_SUPERBLOCK_SIZE;
154                sum += self.scalar_superblock_dot(
155                    sb_data,
156                    &input[input_offset..input_offset + Q4K_SUPERBLOCK_SIZE],
157                );
158            }
159
160            *out = sum;
161        });
162    }
163
164    /// Execute tiled matvec with parallel row processing (fallback)
165    ///
166    /// When `parallel` feature is not enabled, this is equivalent to `execute_scalar`.
167    #[cfg(not(feature = "parallel"))]
168    pub fn execute_parallel(&self, weights: &[u8], input: &[f32], output: &mut [f32]) {
169        self.execute_scalar(weights, input, output);
170    }
171
172    /// Scalar dot product for a single Q4_K superblock
173    ///
174    /// # Performance
175    ///
176    /// Optimized version with:
177    /// - Precomputed scale/min pairs
178    /// - Loop unrolling hints
179    /// - Minimized branching in inner loop
180    #[inline]
181    fn scalar_superblock_dot(&self, sb_data: &[u8], input: &[f32]) -> f32 {
182        // Read header (hot path optimized)
183        let d = f16_to_f32(sb_data.get(0..2).expect("Q4_K: need ≥2 bytes for d"));
184        let dmin = f16_to_f32(sb_data.get(2..4).expect("Q4_K: need ≥4 bytes for dmin"));
185        let scales = sb_data.get(4..16).expect("Q4_K: need ≥16 bytes for scales");
186        let qs = sb_data.get(16..144).expect("Q4_K: need ≥144 bytes for qs");
187
188        // Precompute all scale/min pairs upfront
189        let scale_mins = precompute_scales_mins(scales);
190
191        let mut sum = 0.0f32;
192
193        // Process 256 values in 4 pairs of sub-blocks (GGML Q4_K layout).
194        // Each pair shares 32 qs bytes: even SB uses low nibbles, odd SB uses high nibbles.
195        // pair 0: SB 0 (lo) + SB 1 (hi) → qs[0..31]   → K offsets 0..63
196        // pair 1: SB 2 (lo) + SB 3 (hi) → qs[32..63]  → K offsets 64..127
197        // pair 2: SB 4 (lo) + SB 5 (hi) → qs[64..95]  → K offsets 128..191
198        // pair 3: SB 6 (lo) + SB 7 (hi) → qs[96..127] → K offsets 192..255
199        for pair in 0..4 {
200            let sb_lo = pair * 2;
201            let sb_hi = pair * 2 + 1;
202
203            let (sc_lo, mn_lo) = scale_mins[sb_lo];
204            let (sc_hi, mn_hi) = scale_mins[sb_hi];
205
206            let d_scale_lo = d * sc_lo;
207            let dm_lo = dmin * mn_lo;
208            let d_scale_hi = d * sc_hi;
209            let dm_hi = dmin * mn_hi;
210
211            let q_offset = pair * 32; // 32 qs bytes per pair
212            let input_lo = pair * 64; // low nibbles: first 32 K values
213            let input_hi = pair * 64 + 32; // high nibbles: next 32 K values
214
215            let mut pair_sum = 0.0f32;
216
217            // 32 qs bytes: low nibble → SB_lo values, high nibble → SB_hi values
218            for i in 0..32 {
219                let byte = qs[q_offset + i];
220
221                let q_lo = (byte & 0x0F) as f32;
222                let q_hi = (byte >> 4) as f32;
223
224                let val_lo = d_scale_lo * q_lo - dm_lo;
225                let val_hi = d_scale_hi * q_hi - dm_hi;
226
227                pair_sum += val_lo * input[input_lo + i];
228                pair_sum += val_hi * input[input_hi + i];
229            }
230
231            sum += pair_sum;
232        }
233
234        sum
235    }
236
237    /// Get tiling statistics for profiling
238    #[must_use]
239    #[allow(clippy::cast_precision_loss)] // Matrix dimensions ≤ 2^24; precision loss acceptable for profiling stats
240    pub fn stats(&self) -> TilingStats {
241        let bytes_per_row = self.superblocks_per_row() * Q4K_SUPERBLOCK_BYTES;
242        let total_weight_bytes = self.m * bytes_per_row;
243        let input_bytes = self.k * 4;
244        let output_bytes = self.m * 4;
245
246        TilingStats {
247            total_weight_bytes,
248            input_bytes,
249            output_bytes,
250            superblocks: self.total_superblocks(),
251            arithmetic_ops: self.m * self.k * 2, // 2 ops per element (mul + add)
252            arithmetic_intensity: (self.m * self.k * 2) as f32
253                / (total_weight_bytes + input_bytes) as f32,
254        }
255    }
256}
257
/// Memory-traffic and arithmetic summary of one tiled operation.
///
/// Produced by `TiledQ4KMatvec::stats` for profiling.
#[derive(Debug, Clone)]
pub struct TilingStats {
    /// Bytes of quantized weight data read.
    pub total_weight_bytes: usize,
    /// Bytes occupied by the f32 input vector.
    pub input_bytes: usize,
    /// Bytes occupied by the f32 output vector.
    pub output_bytes: usize,
    /// Number of Q4_K superblocks processed.
    pub superblocks: usize,
    /// Arithmetic operation count (one multiply and one add per weight).
    pub arithmetic_ops: usize,
    /// Arithmetic intensity: operations per byte of weight + input traffic.
    pub arithmetic_intensity: f32,
}
274
/// Decode an IEEE 754 binary16 (f16) value stored little-endian in `bytes[0..2]`.
///
/// Hand-written to avoid a dependency on the `half` crate. Bytes past the first two
/// are ignored. NaN inputs decode to `f32::NAN` (sign and payload are not preserved).
///
/// # Panics
/// Panics if `bytes` holds fewer than 2 bytes.
#[inline]
pub fn f16_to_f32(bytes: &[u8]) -> f32 {
    let low = u16::from(bytes[0]);
    let high = u16::from(bytes[1]);
    f16_bits_to_f32((high << 8) | low)
}

/// Widen raw f16 bits to f32.
///
/// Normal numbers (biased exponent 1..=30) are re-biased directly into an f32 bit
/// pattern. Zero, subnormals, infinities and NaN go to a separate cold function.
#[inline(always)]
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign = bits >> 15;
    let exponent = (bits >> 10) & 0x1F;
    let mantissa = bits & 0x3FF;

    match exponent {
        0 | 31 => f16_special_to_f32(sign, exponent, mantissa),
        _ => {
            // f16 bias is 15 and f32 bias is 127, so shift the exponent by 112.
            let f32_exponent = u32::from(exponent) + 112;
            // Widen the 10-bit mantissa to the 23-bit f32 field.
            let f32_mantissa = u32::from(mantissa) << 13;
            f32::from_bits((u32::from(sign) << 31) | (f32_exponent << 23) | f32_mantissa)
        }
    }
}

/// Decode f16 values whose biased exponent is 0 (zero / subnormal) or 31 (inf / NaN).
///
/// Kept out of line and marked cold because model weights rarely hit these cases.
#[cold]
#[inline(never)]
fn f16_special_to_f32(sign: u16, exponent: u16, mantissa: u16) -> f32 {
    // 2^-14: the scale of the f16 subnormal range.
    const TWO_POW_NEG_14: f32 = 6.103_515_625e-5;

    let magnitude = match (exponent, mantissa) {
        // Signed zero.
        (0, 0) => 0.0,
        // Subnormal: 0.mantissa × 2^-14, exactly representable as a normal f32.
        (0, m) => f32::from(m) * (1.0 / 1024.0) * TWO_POW_NEG_14,
        // Exponent all ones with zero mantissa: infinity.
        (_, 0) => f32::INFINITY,
        // Exponent all ones with non-zero mantissa: NaN, returned unsigned.
        _ => return f32::NAN,
    };

    if sign == 1 {
        -magnitude
    } else {
        magnitude
    }
}
345
/// Decode the 6-bit `(scale, min)` pair of one Q4_K sub-block (`idx` in `0..8`).
///
/// A Q4_K superblock has eight 32-element sub-blocks whose scales and mins are packed
/// into 12 bytes (GGML `block_q4_K.scales`). Byte ranges below use Rust's half-open
/// notation:
///
/// - `scales[0..4]`: bits `[5:0]` = scale of SB 0-3; bits `[7:6]` = high 2 bits of scale SB 4-7
/// - `scales[4..8]`: bits `[5:0]` = min of SB 0-3; bits `[7:6]` = high 2 bits of min SB 4-7
/// - `scales[8..12]`: bits `[3:0]` = low 4 bits of scale SB 4-7;
///   bits `[7:4]` = low 4 bits of min SB 4-7
///
/// Returns the raw integer values (0..=63) as `f32`; callers multiply them by the
/// superblock `d` / `dmin`.
///
/// Reference: ggml `get_scale_min_k4` (used by `dequantize_row_q4_K`),
///            trueno `backends::q4k::parse_q4k_header`.
///
/// # Panics
/// Panics on an out-of-bounds index if `scales` is too short for `idx`. `idx < 8` and
/// `scales.len() >= 12` are checked only in debug builds.
#[inline(always)]
pub fn extract_scale_min_6bit(scales: &[u8], idx: usize) -> (f32, f32) {
    debug_assert!(scales.len() >= 12, "scales array must be at least 12 bytes");
    debug_assert!(idx < 8, "idx must be < 8");

    let (scale, min) = if idx < 4 {
        // Sub-blocks 0-3: plain 6-bit fields.
        (scales[idx] & 0x3F, scales[4 + idx] & 0x3F)
    } else {
        // Sub-blocks 4-7: low nibble from the combo byte, high 2 bits borrowed from
        // the top of the corresponding byte in the first two groups.
        let i = idx - 4;
        let combo = scales[8 + i];
        let scale = (combo & 0x0F) | ((scales[i] >> 6) << 4);
        let min = (combo >> 4) | ((scales[4 + i] >> 6) << 4);
        (scale, min)
    };

    // Both values fit in 6 bits, so the u8 -> f32 conversion is lossless.
    (f32::from(scale), f32::from(min))
}
388
389/// Precompute all 8 scale/min pairs for a Q4_K superblock
390///
391/// More efficient than calling extract_scale_min_6bit 8 times when
392/// we need all values (which is the common case).
393#[inline]
394fn precompute_scales_mins(scales: &[u8]) -> [(f32, f32); 8] {
395    debug_assert!(scales.len() >= 12);
396
397    // Unroll the extraction for all 8 chunks
398    [
399        extract_scale_min_6bit(scales, 0),
400        extract_scale_min_6bit(scales, 1),
401        extract_scale_min_6bit(scales, 2),
402        extract_scale_min_6bit(scales, 3),
403        extract_scale_min_6bit(scales, 4),
404        extract_scale_min_6bit(scales, 5),
405        extract_scale_min_6bit(scales, 6),
406        extract_scale_min_6bit(scales, 7),
407    ]
408}