Skip to main content

trueno/brick/quant_ops/
mod.rs

1//! Q5_K and Q6_K Quantization Operations (llama.cpp compatible)
2//!
3//! This module provides quantization formats and compute operations
4//! for llama.cpp-compatible k-quant formats.
5//!
6//! # Formats
7//!
8//! - `BlockQ5K`: 5-bit quantization with super-blocks (256 values)
9//! - `BlockQ6K`: 6-bit quantization with super-blocks (256 values)
10//!
11//! # Operations
12//!
13//! - `DotQ5KOp`: Dot product with Q5_K quantized weights
14//! - `DotQ6KOp`: Dot product with Q6_K quantized weights
15//!
16//! # SIMD Optimization
17//!
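//! Both operations dequantize each super-block to `f32`; on x86_64 with AVX2/FMA
//! the products are accumulated with 256-bit FMA instructions, otherwise a
//! scalar loop is used.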
19
20use super::{Backend, ComputeOp};
21use crate::error::TruenoError;
22
23// ============================================================================
24// Q5_K Block Format
25// ============================================================================
26
/// Q5_K block format (5-bit with super-blocks).
///
/// Mirrors llama.cpp's `block_q5_K`:
/// - Super-block of 256 values, split into 8 sub-blocks of 32 values
/// - 5-bit quantization with a 6-bit scale and a 6-bit min per sub-block
/// - Higher precision than Q4_K, lower than Q6_K
///
/// Serialized (GGUF) layout of one block:
/// ```text
/// | d (fp16) | dmin (fp16) | scales[12] | qh[32] | qs[128] |
/// ```
///
/// This struct keeps the same fields but stores `d` and `dmin` already widened
/// to `f32`, so its in-memory layout is not the raw byte layout above.
#[derive(Debug, Clone)]
pub struct BlockQ5K {
    /// Super-block scale (fp16 in the file, stored here as `f32`).
    pub d: f32,
    /// Super-block scale applied to the sub-block mins (fp16 in the file, stored here as `f32`).
    pub dmin: f32,
    /// Eight 6-bit scales and eight 6-bit mins, bit-packed into 12 bytes.
    pub scales: [u8; 12],
    /// Fifth (high) bit of the quantized values (32 bytes).
    pub qh: [u8; 32],
    /// Low 4 bits of the quantized values (128 bytes, 2 values per byte).
    pub qs: [u8; 128],
}
51
52impl BlockQ5K {
53    /// Block size in elements
54    pub const BLOCK_SIZE: usize = 256;
55
56    /// Dequantize a Q5_K block to f32.
57    ///
58    /// # Safety
59    ///
60    /// Output buffer must have at least BLOCK_SIZE elements.
61    pub fn dequantize(&self, output: &mut [f32]) {
62        debug_assert!(output.len() >= Self::BLOCK_SIZE);
63
64        // Decode scales from packed format
65        let mut scales = [0i8; 8];
66        for i in 0..8 {
67            let low = (self.scales[i] & 0x3F) as i8;
68            scales[i] = low - 32;
69        }
70
71        // Dequantize each sub-block
72        for block_idx in 0..8 {
73            let scale = scales[block_idx] as f32;
74            let base_idx = block_idx * 32;
75
76            for i in 0..32 {
77                let out_idx = base_idx + i;
78                let byte_idx = base_idx / 2 + i / 2;
79
80                // Extract 4-bit low value
81                let q4 = if i % 2 == 0 { self.qs[byte_idx] & 0x0F } else { self.qs[byte_idx] >> 4 };
82
83                // Extract 5th bit from qh
84                let qh_bit = ((self.qh[i] >> block_idx) & 1) as u8;
85                let q5 = q4 | (qh_bit << 4);
86
87                // Dequantize: value = d * scale * (q5 - 16) + dmin
88                output[out_idx] = self.d * scale * (q5 as f32 - 16.0) + self.dmin;
89            }
90        }
91    }
92}
93
94// ============================================================================
95// Q6_K Block Format
96// ============================================================================
97
/// Q6_K block format (6-bit with super-blocks).
///
/// Mirrors llama.cpp's `block_q6_K`:
/// - Super-block of 256 values, split into 16 sub-blocks of 16 values
/// - 6-bit quantization with one signed 8-bit scale per sub-block
/// - Highest precision k-quant format
///
/// Serialized (GGUF) layout of one block:
/// ```text
/// | ql[128] | qh[64] | scales[16] | d (fp16) |
/// ```
///
/// `d` is stored here already widened to `f32`, so the in-memory layout of this
/// struct is not the raw byte layout above.
#[derive(Debug, Clone)]
pub struct BlockQ6K {
    /// Low 4 bits of the quantized values (128 bytes, 2 values per byte).
    pub ql: [u8; 128],
    /// High 2 bits of the quantized values (64 bytes, 4 values per byte).
    pub qh: [u8; 64],
    /// Signed scale for each 16-value sub-block (16 bytes).
    pub scales: [i8; 16],
    /// Super-block scale (fp16 in the file, stored here as `f32`).
    pub d: f32,
}
120
121impl BlockQ6K {
122    /// Block size in elements
123    pub const BLOCK_SIZE: usize = 256;
124
125    /// Dequantize a Q6_K block to f32.
126    ///
127    /// # Safety
128    ///
129    /// Output buffer must have at least BLOCK_SIZE elements.
130    pub fn dequantize(&self, output: &mut [f32]) {
131        debug_assert!(output.len() >= Self::BLOCK_SIZE);
132
133        // Dequantize each sub-block of 16 values
134        for block_idx in 0..16 {
135            let scale = self.scales[block_idx] as f32;
136            let base_idx = block_idx * 16;
137
138            for i in 0..16 {
139                let out_idx = base_idx + i;
140                let ql_idx = base_idx / 2 + i / 2;
141                let qh_idx = base_idx / 4 + i / 4;
142
143                // Extract 4-bit low value
144                let ql_val = if i % 2 == 0 { self.ql[ql_idx] & 0x0F } else { self.ql[ql_idx] >> 4 };
145
146                // Extract 2-bit high value
147                let qh_shift = (i % 4) * 2;
148                let qh_val = ((self.qh[qh_idx] >> qh_shift) & 0x03) as u8;
149
150                // Combine to 6-bit value
151                let q6 = ql_val | (qh_val << 4);
152
153                // Dequantize: value = d * scale * (q6 - 32)
154                output[out_idx] = self.d * scale * (q6 as f32 - 32.0);
155            }
156        }
157    }
158}
159
160// ============================================================================
161// Q5_K Dot Product Operation
162// ============================================================================
163
/// Dot product of Q5_K-quantized weights with dense `f32` activations.
///
/// The operation only remembers its size; the weight blocks and activation
/// vector are supplied at execution time.
#[derive(Debug, Clone)]
pub struct DotQ5KOp {
    /// How many 256-value Q5_K super-blocks the operation is sized for.
    pub n_blocks: usize,
}
172
173impl DotQ5KOp {
174    /// Create a new Q5_K dot product operation.
175    #[must_use]
176    pub fn new(n_elements: usize) -> Self {
177        Self { n_blocks: n_elements / BlockQ5K::BLOCK_SIZE }
178    }
179
180    /// Compute dot product with SIMD acceleration.
181    #[cfg(target_arch = "x86_64")]
182    #[target_feature(enable = "avx2", enable = "fma")]
183    // SAFETY: caller verifies AVX2 support, input slices meet alignment/length requirements
184    unsafe fn avx2_dot_block(block: &BlockQ5K, x: &[f32]) -> f32 {
185        unsafe {
186            use std::arch::x86_64::*;
187
188            let mut acc = _mm256_setzero_ps();
189            let mut dequant = [0.0f32; BlockQ5K::BLOCK_SIZE];
190            block.dequantize(&mut dequant);
191
192            let mut i = 0;
193            while i + 8 <= BlockQ5K::BLOCK_SIZE {
194                let vd = _mm256_loadu_ps(dequant.as_ptr().add(i));
195                let vx = _mm256_loadu_ps(x.as_ptr().add(i));
196                acc = _mm256_fmadd_ps(vd, vx, acc);
197                i += 8;
198            }
199
200            // Horizontal sum
201            let high = _mm256_extractf128_ps(acc, 1);
202            let low = _mm256_castps256_ps128(acc);
203            let sum128 = _mm_add_ps(high, low);
204            let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
205            let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
206            _mm_cvtss_f32(sum32)
207        }
208    }
209}
210
211impl ComputeOp for DotQ5KOp {
212    type Input = (Vec<BlockQ5K>, Vec<f32>);
213    type Output = f32;
214
215    fn name(&self) -> &'static str {
216        "dot_q5k"
217    }
218
219    fn execute(&self, input: Self::Input, backend: Backend) -> Result<Self::Output, TruenoError> {
220        let (blocks, x) = input;
221
222        if blocks.is_empty() || x.is_empty() {
223            return Ok(0.0);
224        }
225
226        let mut sum = 0.0f32;
227
228        #[cfg(target_arch = "x86_64")]
229        {
230            if matches!(backend, Backend::Avx2 | Backend::Auto) && is_x86_feature_detected!("avx2")
231            {
232                for (i, block) in blocks.iter().enumerate() {
233                    let x_slice = &x[i * BlockQ5K::BLOCK_SIZE..];
234                    // SAFETY: preconditions verified by caller
235                    sum += unsafe { Self::avx2_dot_block(block, x_slice) };
236                }
237                return Ok(sum);
238            }
239        }
240
241        // Scalar fallback
242        let mut dequant = [0.0f32; BlockQ5K::BLOCK_SIZE];
243        for (i, block) in blocks.iter().enumerate() {
244            block.dequantize(&mut dequant);
245            let x_slice = &x[i * BlockQ5K::BLOCK_SIZE..];
246            for j in 0..BlockQ5K::BLOCK_SIZE {
247                sum += dequant[j] * x_slice[j];
248            }
249        }
250
251        Ok(sum)
252    }
253
254    fn tokens(&self, _input: &Self::Input) -> usize {
255        self.n_blocks * BlockQ5K::BLOCK_SIZE
256    }
257}
258
259// ============================================================================
260// Q6_K Dot Product Operation
261// ============================================================================
262
/// Dot product of Q6_K-quantized weights with dense `f32` activations.
///
/// The operation only remembers its size; the weight blocks and activation
/// vector are supplied at execution time.
#[derive(Debug, Clone)]
pub struct DotQ6KOp {
    /// How many 256-value Q6_K super-blocks the operation is sized for.
    pub n_blocks: usize,
}
271
272impl DotQ6KOp {
273    /// Create a new Q6_K dot product operation.
274    #[must_use]
275    pub fn new(n_elements: usize) -> Self {
276        Self { n_blocks: n_elements / BlockQ6K::BLOCK_SIZE }
277    }
278
279    /// Compute dot product with SIMD acceleration.
280    #[cfg(target_arch = "x86_64")]
281    #[target_feature(enable = "avx2", enable = "fma")]
282    // SAFETY: caller verifies AVX2 support, input slices meet alignment/length requirements
283    unsafe fn avx2_dot_block(block: &BlockQ6K, x: &[f32]) -> f32 {
284        unsafe {
285            use std::arch::x86_64::*;
286
287            let mut acc = _mm256_setzero_ps();
288            let mut dequant = [0.0f32; BlockQ6K::BLOCK_SIZE];
289            block.dequantize(&mut dequant);
290
291            let mut i = 0;
292            while i + 8 <= BlockQ6K::BLOCK_SIZE {
293                let vd = _mm256_loadu_ps(dequant.as_ptr().add(i));
294                let vx = _mm256_loadu_ps(x.as_ptr().add(i));
295                acc = _mm256_fmadd_ps(vd, vx, acc);
296                i += 8;
297            }
298
299            // Horizontal sum
300            let high = _mm256_extractf128_ps(acc, 1);
301            let low = _mm256_castps256_ps128(acc);
302            let sum128 = _mm_add_ps(high, low);
303            let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
304            let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
305            _mm_cvtss_f32(sum32)
306        }
307    }
308}
309
310impl ComputeOp for DotQ6KOp {
311    type Input = (Vec<BlockQ6K>, Vec<f32>);
312    type Output = f32;
313
314    fn name(&self) -> &'static str {
315        "dot_q6k"
316    }
317
318    fn execute(&self, input: Self::Input, backend: Backend) -> Result<Self::Output, TruenoError> {
319        let (blocks, x) = input;
320
321        if blocks.is_empty() || x.is_empty() {
322            return Ok(0.0);
323        }
324
325        let mut sum = 0.0f32;
326
327        #[cfg(target_arch = "x86_64")]
328        {
329            if matches!(backend, Backend::Avx2 | Backend::Auto) && is_x86_feature_detected!("avx2")
330            {
331                for (i, block) in blocks.iter().enumerate() {
332                    let x_slice = &x[i * BlockQ6K::BLOCK_SIZE..];
333                    // SAFETY: preconditions verified by caller
334                    sum += unsafe { Self::avx2_dot_block(block, x_slice) };
335                }
336                return Ok(sum);
337            }
338        }
339
340        // Scalar fallback
341        let mut dequant = [0.0f32; BlockQ6K::BLOCK_SIZE];
342        for (i, block) in blocks.iter().enumerate() {
343            block.dequantize(&mut dequant);
344            let x_slice = &x[i * BlockQ6K::BLOCK_SIZE..];
345            for j in 0..BlockQ6K::BLOCK_SIZE {
346                sum += dequant[j] * x_slice[j];
347            }
348        }
349
350        Ok(sum)
351    }
352
353    fn tokens(&self, _input: &Self::Input) -> usize {
354        self.n_blocks * BlockQ6K::BLOCK_SIZE
355    }
356}
357
358#[cfg(test)]
359pub mod nf4;
360mod tests;