trueno/brick/quant_ops/
mod.rs1use super::{Backend, ComputeOp};
21use crate::error::TruenoError;
22
/// One 256-element block of 5-bit quantized values.
///
/// Each element is stored as a 4-bit low part (in `qs`) plus one high bit
/// (in `qh`); see [`BlockQ5K::dequantize`] for the exact bit layout used here.
#[derive(Debug, Clone)]
pub struct BlockQ5K {
    /// Block-wide multiplier applied to every scaled quant.
    pub d: f32,
    /// Block-wide offset added to every dequantized value.
    pub dmin: f32,
    /// Packed sub-block scales. `dequantize` reads only the low 6 bits of the
    /// first 8 bytes; bytes 8..12 are not consulted by it.
    pub scales: [u8; 12],
    /// High (fifth) bits: bit `b` of `qh[i]` belongs to element `i` of sub-block `b`.
    pub qh: [u8; 32],
    /// Low 4 bits of each element, two per byte (even element in the low nibble).
    pub qs: [u8; 128],
}

impl BlockQ5K {
    /// Number of elements encoded by one block.
    pub const BLOCK_SIZE: usize = 256;

    /// Expands the block into `output[..256]`.
    ///
    /// The block is treated as 8 sub-blocks of 32 elements. For element `i` of
    /// sub-block `b` the value written is
    /// `d * scale_b * (q5 - 16) + dmin`, where `scale_b` is the low 6 bits of
    /// `scales[b]` minus 32 and `q5` is the 4-bit nibble from `qs` with bit
    /// `b` of `qh[i]` placed on top as bit 4.
    ///
    /// # Panics
    ///
    /// Panics if `output` holds fewer than [`Self::BLOCK_SIZE`] elements (via
    /// `debug_assert!` in debug builds, via slice indexing otherwise).
    pub fn dequantize(&self, output: &mut [f32]) {
        debug_assert!(output.len() >= Self::BLOCK_SIZE);

        for sub in 0..8 {
            // 6-bit scale re-centred around zero: 0..=63 -> -32..=31.
            let scale = f32::from((self.scales[sub] & 0x3F) as i8 - 32);
            let first = sub * 32;

            for lane in 0..32 {
                let packed = self.qs[first / 2 + lane / 2];
                let nibble = if lane % 2 == 0 { packed & 0x0F } else { packed >> 4 };

                // Each `qh` byte carries one high bit per sub-block.
                let high = (self.qh[lane] >> sub) & 1;
                let q5 = nibble | (high << 4);

                output[first + lane] = self.d * scale * (f32::from(q5) - 16.0) + self.dmin;
            }
        }
    }
}
93
/// One 256-element block of 6-bit quantized values.
///
/// Each element is stored as a 4-bit low part (in `ql`) plus a 2-bit high
/// part (in `qh`); see [`BlockQ6K::dequantize`] for the exact bit layout used here.
#[derive(Debug, Clone)]
pub struct BlockQ6K {
    /// Low 4 bits of each element, two per byte (even element in the low nibble).
    pub ql: [u8; 128],
    /// High 2 bits of each element, four per byte (element `i % 4` at bit `2 * (i % 4)`).
    pub qh: [u8; 64],
    /// Signed scale for each of the 16 sub-blocks.
    pub scales: [i8; 16],
    /// Block-wide multiplier applied to every scaled quant.
    pub d: f32,
}

impl BlockQ6K {
    /// Number of elements encoded by one block.
    pub const BLOCK_SIZE: usize = 256;

    /// Expands the block into `output[..256]`.
    ///
    /// The block is treated as 16 sub-blocks of 16 elements. Each value
    /// written is `d * scales[b] * (q6 - 32)`, where `q6` is the 4-bit nibble
    /// from `ql` with the matching 2-bit field from `qh` placed above it.
    ///
    /// # Panics
    ///
    /// Panics if `output` holds fewer than [`Self::BLOCK_SIZE`] elements (via
    /// `debug_assert!` in debug builds, via slice indexing otherwise).
    pub fn dequantize(&self, output: &mut [f32]) {
        debug_assert!(output.len() >= Self::BLOCK_SIZE);

        for (sub, &raw_scale) in self.scales.iter().enumerate() {
            let scale = f32::from(raw_scale);
            let first = sub * 16;

            for lane in 0..16 {
                let low_byte = self.ql[first / 2 + lane / 2];
                let low = if lane % 2 == 0 { low_byte & 0x0F } else { low_byte >> 4 };

                // Four 2-bit fields per `qh` byte, lowest field first.
                let high = (self.qh[first / 4 + lane / 4] >> ((lane % 4) * 2)) & 0x03;
                let q6 = low | (high << 4);

                output[first + lane] = self.d * scale * (f32::from(q6) - 32.0);
            }
        }
    }
}
159
160#[derive(Debug, Clone)]
168pub struct DotQ5KOp {
169 pub n_blocks: usize,
171}
172
173impl DotQ5KOp {
174 #[must_use]
176 pub fn new(n_elements: usize) -> Self {
177 Self { n_blocks: n_elements / BlockQ5K::BLOCK_SIZE }
178 }
179
180 #[cfg(target_arch = "x86_64")]
182 #[target_feature(enable = "avx2", enable = "fma")]
183 unsafe fn avx2_dot_block(block: &BlockQ5K, x: &[f32]) -> f32 {
185 unsafe {
186 use std::arch::x86_64::*;
187
188 let mut acc = _mm256_setzero_ps();
189 let mut dequant = [0.0f32; BlockQ5K::BLOCK_SIZE];
190 block.dequantize(&mut dequant);
191
192 let mut i = 0;
193 while i + 8 <= BlockQ5K::BLOCK_SIZE {
194 let vd = _mm256_loadu_ps(dequant.as_ptr().add(i));
195 let vx = _mm256_loadu_ps(x.as_ptr().add(i));
196 acc = _mm256_fmadd_ps(vd, vx, acc);
197 i += 8;
198 }
199
200 let high = _mm256_extractf128_ps(acc, 1);
202 let low = _mm256_castps256_ps128(acc);
203 let sum128 = _mm_add_ps(high, low);
204 let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
205 let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
206 _mm_cvtss_f32(sum32)
207 }
208 }
209}
210
211impl ComputeOp for DotQ5KOp {
212 type Input = (Vec<BlockQ5K>, Vec<f32>);
213 type Output = f32;
214
215 fn name(&self) -> &'static str {
216 "dot_q5k"
217 }
218
219 fn execute(&self, input: Self::Input, backend: Backend) -> Result<Self::Output, TruenoError> {
220 let (blocks, x) = input;
221
222 if blocks.is_empty() || x.is_empty() {
223 return Ok(0.0);
224 }
225
226 let mut sum = 0.0f32;
227
228 #[cfg(target_arch = "x86_64")]
229 {
230 if matches!(backend, Backend::Avx2 | Backend::Auto) && is_x86_feature_detected!("avx2")
231 {
232 for (i, block) in blocks.iter().enumerate() {
233 let x_slice = &x[i * BlockQ5K::BLOCK_SIZE..];
234 sum += unsafe { Self::avx2_dot_block(block, x_slice) };
236 }
237 return Ok(sum);
238 }
239 }
240
241 let mut dequant = [0.0f32; BlockQ5K::BLOCK_SIZE];
243 for (i, block) in blocks.iter().enumerate() {
244 block.dequantize(&mut dequant);
245 let x_slice = &x[i * BlockQ5K::BLOCK_SIZE..];
246 for j in 0..BlockQ5K::BLOCK_SIZE {
247 sum += dequant[j] * x_slice[j];
248 }
249 }
250
251 Ok(sum)
252 }
253
254 fn tokens(&self, _input: &Self::Input) -> usize {
255 self.n_blocks * BlockQ5K::BLOCK_SIZE
256 }
257}
258
259#[derive(Debug, Clone)]
267pub struct DotQ6KOp {
268 pub n_blocks: usize,
270}
271
272impl DotQ6KOp {
273 #[must_use]
275 pub fn new(n_elements: usize) -> Self {
276 Self { n_blocks: n_elements / BlockQ6K::BLOCK_SIZE }
277 }
278
279 #[cfg(target_arch = "x86_64")]
281 #[target_feature(enable = "avx2", enable = "fma")]
282 unsafe fn avx2_dot_block(block: &BlockQ6K, x: &[f32]) -> f32 {
284 unsafe {
285 use std::arch::x86_64::*;
286
287 let mut acc = _mm256_setzero_ps();
288 let mut dequant = [0.0f32; BlockQ6K::BLOCK_SIZE];
289 block.dequantize(&mut dequant);
290
291 let mut i = 0;
292 while i + 8 <= BlockQ6K::BLOCK_SIZE {
293 let vd = _mm256_loadu_ps(dequant.as_ptr().add(i));
294 let vx = _mm256_loadu_ps(x.as_ptr().add(i));
295 acc = _mm256_fmadd_ps(vd, vx, acc);
296 i += 8;
297 }
298
299 let high = _mm256_extractf128_ps(acc, 1);
301 let low = _mm256_castps256_ps128(acc);
302 let sum128 = _mm_add_ps(high, low);
303 let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
304 let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
305 _mm_cvtss_f32(sum32)
306 }
307 }
308}
309
310impl ComputeOp for DotQ6KOp {
311 type Input = (Vec<BlockQ6K>, Vec<f32>);
312 type Output = f32;
313
314 fn name(&self) -> &'static str {
315 "dot_q6k"
316 }
317
318 fn execute(&self, input: Self::Input, backend: Backend) -> Result<Self::Output, TruenoError> {
319 let (blocks, x) = input;
320
321 if blocks.is_empty() || x.is_empty() {
322 return Ok(0.0);
323 }
324
325 let mut sum = 0.0f32;
326
327 #[cfg(target_arch = "x86_64")]
328 {
329 if matches!(backend, Backend::Avx2 | Backend::Auto) && is_x86_feature_detected!("avx2")
330 {
331 for (i, block) in blocks.iter().enumerate() {
332 let x_slice = &x[i * BlockQ6K::BLOCK_SIZE..];
333 sum += unsafe { Self::avx2_dot_block(block, x_slice) };
335 }
336 return Ok(sum);
337 }
338 }
339
340 let mut dequant = [0.0f32; BlockQ6K::BLOCK_SIZE];
342 for (i, block) in blocks.iter().enumerate() {
343 block.dequantize(&mut dequant);
344 let x_slice = &x[i * BlockQ6K::BLOCK_SIZE..];
345 for j in 0..BlockQ6K::BLOCK_SIZE {
346 sum += dequant[j] * x_slice[j];
347 }
348 }
349
350 Ok(sum)
351 }
352
353 fn tokens(&self, _input: &Self::Input) -> usize {
354 self.n_blocks * BlockQ6K::BLOCK_SIZE
355 }
356}
357
358#[cfg(test)]
359pub mod nf4;
360mod tests;