// trueno/backends/q4k/mod.rs
//! Fused Q4_K Matrix-Vector Multiply (F-GPU-130)
//!
//! This module implements fused quantized matrix-vector multiplication that operates
//! directly on Q4_K compressed weights without full dequantization.
//!
//! # Q4_K Format (llama.cpp compatible)
//!
//! Super-block layout (144 bytes per 256 elements):
//! - `d`: 2 bytes (f16 global scale)
//! - `dmin`: 2 bytes (f16 global min scale)
//! - `scales`: 12 bytes (packed 6-bit scales and mins for 8 sub-blocks)
//! - `qs`: 128 bytes (4-bit quantized values, interleaved low/high nibbles)
//!
//! # Golden Test Invariant (Section 12.4 of spec)
//!
//! For all Q4K weight W and input x:
//! ```text
//! matmul_q4k_f32(W, x) ≈ matmul(dequant_q4k_to_f32(W), x) within ε = 1e-3
//! ```
//!
//! # Performance Targets
//!
//! - Baseline (dequant+matmul): 0.27 tok/s
//! - Target (fused): >5 tok/s CPU, >100 tok/s GPU
//!
//! # Example
//!
//! ```rust,ignore
//! use trueno::backends::q4k::matmul_q4k_f32;
//!
//! let q4k_weights = load_q4k_tensor("gate_proj.weight");
//! let input = vec![1.0f32; 896];
//! let output = matmul_q4k_f32(&q4k_weights, &input, 4864, 896);
//! ```

36// Sub-modules
37mod colmajor;
38mod dequant;
39mod gemv;
40
41// Re-exports
42#[allow(deprecated)]
43pub use colmajor::{matmul_q4k_f32_colmajor, matmul_q4k_f32_colmajor_dispatch};
44pub use dequant::dequantize_q4k_to_f32;
45pub use gemv::{matmul_q4k_f32, matmul_q4k_f32_dispatch, matmul_q4k_f32_scalar};
46
47// Constants (pub(crate) for submodule access)
48pub(crate) const SUPER_BLOCK_SIZE: usize = 256;
49pub(crate) const SUPER_BLOCK_BYTES: usize = 144;
50pub(crate) const _SUB_BLOCK_SIZE: usize = 32; // Reserved for future sub-block optimizations
51
/// Convert IEEE 754 half-precision (f16) bits to f32.
///
/// Handles all f16 classes: signed zero, subnormal, normal, infinity, and NaN.
///
/// NOTE: F16C hardware instruction (`_mm_cvtph_ps`) was tested (2026-04-05)
/// but the per-call `is_x86_feature_detected` overhead negated the gain.
/// The scalar path is already well-optimized by LLVM for typical Q4K scales
/// (normal f16 values that hit the fast path without subnormal branching).
/// The Q4K bottleneck is the FMA dependency chain, not header parsing.
#[inline(always)]
fn f16_to_f32(bits: u16) -> f32 {
    // f16 layout: 1 sign bit | 5 exponent bits | 10 mantissa bits.
    let sign = ((bits & 0x8000) as u32) << 16; // move sign from bit 15 to f32 bit 31
    let exp = (bits >> 10) & 0x1F;
    let mantissa = (bits & 0x3FF) as u32;

    if exp == 0 {
        if mantissa == 0 {
            // Signed zero: only the sign bit survives.
            f32::from_bits(sign)
        } else {
            // Subnormal: normalize the mantissa. Every f16 subnormal is
            // representable as a normal f32, so shift until the implicit
            // bit (0x400) is set and adjust the exponent accordingly.
            let mut m = mantissa;
            let mut e = 0i32;
            while (m & 0x400) == 0 {
                m <<= 1;
                e -= 1;
            }
            // f32 bias 127, f16 bias 15; the +1 compensates for the
            // implicit leading bit gained by normalization.
            let new_exp = ((127 - 15 + 1 + e) as u32) << 23;
            let new_mantissa = (m & 0x3FF) << 13; // widen 10-bit -> 23-bit mantissa
            f32::from_bits(sign | new_exp | new_mantissa)
        }
    } else if exp == 31 {
        // Infinity (mantissa == 0) or NaN (mantissa != 0): max f32 exponent.
        f32::from_bits(sign | (0xFF << 23) | (mantissa << 13))
    } else {
        // Normal: rebias exponent (15 -> 127) and widen the mantissa.
        let new_exp = ((exp as i32 - 15 + 127) as u32) << 23;
        f32::from_bits(sign | new_exp | (mantissa << 13))
    }
}

88/// Parse Q4_K super-block header and scales
89///
90/// Returns (d, dmin, scales[8], mins[8])
91#[inline(always)]
92pub(crate) fn parse_q4k_header(block: &[u8]) -> (f32, f32, [u8; 8], [u8; 8]) {
93 debug_assert!(block.len() >= 16);
94
95 // Read d and dmin (f16)
96 let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
97 let dmin = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
98
99 // Unpack scales and mins (llama.cpp format)
100 let scales_bytes = block.get(4..16).expect("Q4_K: need ≥16 bytes for header");
101 let mut scales = [0u8; 8];
102 let mut mins = [0u8; 8];
103
104 for i in 0..4 {
105 // Blocks 0-3: lower 6 bits of bytes 0-3 and 4-7
106 scales[i] = scales_bytes[i] & 0x3F;
107 mins[i] = scales_bytes[i + 4] & 0x3F;
108 // Blocks 4-7: lower 4 bits from bytes 8-11, upper 2 bits from bytes 0-3/4-7
109 scales[i + 4] = (scales_bytes[i + 8] & 0x0F) | ((scales_bytes[i] >> 6) << 4);
110 mins[i + 4] = (scales_bytes[i + 8] >> 4) | ((scales_bytes[i + 4] >> 6) << 4);
111 }
112
113 (d, dmin, scales, mins)
114}
115
// Test sub-modules (compiled only under `cargo test`).
#[cfg(test)]
mod tests_core;
#[cfg(test)]
mod tests_coverage;
#[cfg(test)]
mod tests_golden;