Skip to main content

oxibonsai_core/
quant_k.rs

1//! K-quant block types for Q2_K, Q3_K, Q4_K, and Q8_K quantization formats.
2//!
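//! These mirror the GGML K-quant super-block formats (same field order and byte sizes).
//! NOTE(review): the bit packing inside each block (sequential LSB-first `qs`, 4-bit Q3_K
//! scales, the Q4_K scale/min encoding) appears to differ from upstream ggml; verify
//! before decoding GGUF tensors with these types. The formats are: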
4//! - **Q2_K**: 2-bit quantization with 4-bit scales, super-block of 256 weights (84 bytes)
5//! - **Q3_K**: 3-bit quantization with 4-bit scales, super-block of 256 weights (110 bytes)
6//! - **Q4_K**: 4-bit quantization with 6-bit scales, super-block of 256 weights (144 bytes)
7//! - **Q8_K**: 8-bit quantization with FP32 scale, super-block of 256 weights (292 bytes)
8//!
//! Each super-block stores a global scale `d` (FP16, except FP32 for Q8_K) and, for Q2_K
//! and Q4_K, an FP16 super-block minimum `dmin`, plus per-sub-block scales and the packed
//! quantized weights.
11
12use half::f16;
13
14use crate::error::{BonsaiError, BonsaiResult};
15
16// ---------------------------------------------------------------------------
17// Constants
18// ---------------------------------------------------------------------------
19
20/// Number of weights per K-quant super-block.
21pub const QK_K: usize = 256;
22
23/// Number of bytes per Q2_K block.
24pub const BLOCK_Q2_K_BYTES: usize = 84;
25
26/// Number of bytes per Q3_K block.
27pub const BLOCK_Q3K_BYTES: usize = 110;
28
29/// Number of bytes per Q4_K block.
30pub const BLOCK_Q4_K_BYTES: usize = 144;
31
32/// Number of bytes per Q8_K block.
33pub const BLOCK_Q8K_BYTES: usize = 292;
34
35// ---------------------------------------------------------------------------
36// BlockQ2K
37// ---------------------------------------------------------------------------
38
39/// Q2_K super-block: 256 weights quantized to 2 bits each.
40///
41/// Layout (84 bytes):
42/// - `scales`: 16 bytes — packed 4-bit scale/min pairs for 16 sub-blocks of 16 weights.
43///   Each byte holds two 4-bit values: low nibble = scale, high nibble = min.
44/// - `qs`: 64 bytes — 256 x 2-bit quantized weights (4 per byte, LSB first).
45/// - `d`: FP16 super-block scale.
46/// - `dmin`: FP16 super-block minimum.
47///
48/// Dequant: `w[i] = d * sub_scale * q[i] - dmin * sub_min`
49#[derive(Debug, Clone, Copy, PartialEq)]
50#[repr(C)]
51pub struct BlockQ2K {
52    /// Packed 4-bit scale/min pairs for 16 sub-blocks.
53    pub scales: [u8; 16],
54    /// 256 x 2-bit quantized weights, 4 per byte.
55    pub qs: [u8; 64],
56    /// Super-block scale (FP16).
57    pub d: f16,
58    /// Super-block minimum (FP16).
59    pub dmin: f16,
60}
61
62const _: () = assert!(std::mem::size_of::<BlockQ2K>() == BLOCK_Q2_K_BYTES);
63
64impl BlockQ2K {
65    /// Dequantize a slice of Q2_K blocks into f32 output.
66    ///
67    /// `output` must have length `blocks.len() * QK_K`.
68    pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
69        let expected_len = blocks.len() * QK_K;
70        if output.len() < expected_len {
71            return Err(BonsaiError::KQuantError {
72                reason: format!(
73                    "Q2_K dequant: output len {} < expected {}",
74                    output.len(),
75                    expected_len
76                ),
77            });
78        }
79
80        for (block_idx, block) in blocks.iter().enumerate() {
81            let d = block.d.to_f32();
82            let dmin = block.dmin.to_f32();
83            let base = block_idx * QK_K;
84
85            // 16 sub-blocks of 16 weights each
86            for sub in 0..16 {
87                let scale_byte = block.scales[sub];
88                let sc = (scale_byte & 0x0F) as f32; // low nibble = scale
89                let mn = ((scale_byte >> 4) & 0x0F) as f32; // high nibble = min
90
91                let sub_offset = sub * 16;
92                for j in 0..16 {
93                    let global_idx = sub_offset + j;
94                    // Each byte holds 4 x 2-bit values
95                    let byte_idx = global_idx / 4;
96                    let shift = (global_idx % 4) * 2;
97                    let q = ((block.qs[byte_idx] >> shift) & 0x03) as f32;
98                    output[base + global_idx] = d * sc * q - dmin * mn;
99                }
100            }
101        }
102        Ok(())
103    }
104
105    /// Quantize f32 input into Q2_K blocks.
106    ///
107    /// Input length must be a multiple of `QK_K` (256).
108    pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
109        if input.len() % QK_K != 0 {
110            return Err(BonsaiError::KQuantError {
111                reason: format!(
112                    "Q2_K quantize: input len {} not a multiple of {}",
113                    input.len(),
114                    QK_K
115                ),
116            });
117        }
118
119        let num_blocks = input.len() / QK_K;
120        let mut blocks = Vec::with_capacity(num_blocks);
121
122        for block_idx in 0..num_blocks {
123            let base = block_idx * QK_K;
124            let chunk = &input[base..base + QK_K];
125
126            // Pass 1: find global max absolute value and min value across
127            // all sub-blocks to set d and dmin.
128            // For each sub-block of 16 weights, we find the range [min, max].
129            let mut sub_scales = [0.0f32; 16];
130            let mut sub_mins = [0.0f32; 16];
131
132            for sub in 0..16 {
133                let sub_offset = sub * 16;
134                let sub_chunk = &chunk[sub_offset..sub_offset + 16];
135
136                let mut smin = f32::MAX;
137                let mut smax = f32::MIN;
138                for &v in sub_chunk {
139                    if v < smin {
140                        smin = v;
141                    }
142                    if v > smax {
143                        smax = v;
144                    }
145                }
146
147                // The offset (min) removes the minimum, then scale maps remainder to 0..3
148                sub_mins[sub] = if smin < 0.0 { -smin } else { 0.0 };
149                let range = smax + sub_mins[sub];
150                sub_scales[sub] = if range > 0.0 { range / 3.0 } else { 0.0 };
151            }
152
153            // Find the global maximum scale and minimum across sub-blocks
154            let max_scale = sub_scales.iter().copied().fold(0.0f32, f32::max);
155            let max_min = sub_mins.iter().copied().fold(0.0f32, f32::max);
156
157            // Compute d and dmin so that 4-bit sub-block factors (0..15) can represent
158            // the per-sub-block scales and mins.
159            let d = if max_scale > 0.0 {
160                max_scale / 15.0
161            } else {
162                0.0
163            };
164            let dmin = if max_min > 0.0 { max_min / 15.0 } else { 0.0 };
165
166            let inv_d = if d > 0.0 { 1.0 / d } else { 0.0 };
167            let inv_dmin = if dmin > 0.0 { 1.0 / dmin } else { 0.0 };
168
169            // Quantize per-sub-block scales and mins to 4 bits
170            let mut scales = [0u8; 16];
171            let mut quant_sc = [0u8; 16];
172            let mut quant_mn = [0u8; 16];
173
174            for sub in 0..16 {
175                let sc = (sub_scales[sub] * inv_d + 0.5).min(15.0) as u8;
176                let mn = (sub_mins[sub] * inv_dmin + 0.5).min(15.0) as u8;
177                quant_sc[sub] = sc;
178                quant_mn[sub] = mn;
179                scales[sub] = sc | (mn << 4);
180            }
181
182            // Quantize weights to 2 bits
183            let mut qs = [0u8; 64];
184            for sub in 0..16 {
185                let sub_offset = sub * 16;
186                let sc_f = d * (quant_sc[sub] as f32);
187                let mn_f = dmin * (quant_mn[sub] as f32);
188                let inv_sc = if sc_f > 0.0 { 1.0 / sc_f } else { 0.0 };
189
190                for j in 0..16 {
191                    let global_idx = sub_offset + j;
192                    let val = chunk[global_idx] + mn_f;
193                    let q = (val * inv_sc + 0.5).clamp(0.0, 3.0) as u8;
194                    let byte_idx = global_idx / 4;
195                    let shift = (global_idx % 4) * 2;
196                    qs[byte_idx] |= q << shift;
197                }
198            }
199
200            blocks.push(BlockQ2K {
201                scales,
202                qs,
203                d: f16::from_f32(d),
204                dmin: f16::from_f32(dmin),
205            });
206        }
207
208        Ok(blocks)
209    }
210
211    /// Dequantize a single row's worth of Q2_K blocks into a pre-allocated buffer.
212    ///
213    /// `buf` will be extended by `blocks_for_row.len() * 256` elements.
214    pub fn dequant_row_to_buf(blocks_for_row: &[Self], buf: &mut Vec<f32>) {
215        let start = buf.len();
216        let n = blocks_for_row.len() * QK_K;
217        buf.resize(start + n, 0.0f32);
218        let _ = Self::dequant(blocks_for_row, &mut buf[start..]);
219    }
220
221    /// Zero-copy cast of a byte slice to a slice of `BlockQ2K`.
222    ///
223    /// Returns error if length is not a multiple of `BLOCK_Q2_K_BYTES` (84)
224    /// or if the pointer is not properly aligned.
225    pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
226        if data.len() % BLOCK_Q2_K_BYTES != 0 {
227            return Err(BonsaiError::KQuantError {
228                reason: format!(
229                    "Q2_K slice_from_bytes: byte len {} not a multiple of {}",
230                    data.len(),
231                    BLOCK_Q2_K_BYTES
232                ),
233            });
234        }
235        if data.is_empty() {
236            return Ok(&[]);
237        }
238        let align = std::mem::align_of::<Self>();
239        if data.as_ptr().align_offset(align) != 0 {
240            return Err(BonsaiError::KQuantError {
241                reason: format!("Q2_K slice_from_bytes: pointer not {}-byte aligned", align),
242            });
243        }
244        let count = data.len() / BLOCK_Q2_K_BYTES;
245        let ptr = data.as_ptr() as *const Self;
246        // SAFETY: repr(C) layout validated by compile-time size assert;
247        // length is a multiple of BLOCK_Q2_K_BYTES; pointer alignment verified above;
248        // lifetime is tied to the input slice.
249        Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
250    }
251}
252
253// ---------------------------------------------------------------------------
254// BlockQ3K
255// ---------------------------------------------------------------------------
256
257/// Q3_K super-block: 256 weights quantized to 3 bits each.
258///
259/// Layout (110 bytes):
260/// - `hmask`:  32 bytes — high bit (bit 2) for each of the 256 weights, packed 8 per byte.
261/// - `qs`:     64 bytes — low 2 bits for each of the 256 weights, packed 4 per byte.
262/// - `scales`: 12 bytes — 4-bit scale values for 16 sub-blocks of 16 weights each.
263///   Each nibble is a signed 4-bit value (stored as u4, subtract 8 for range [-8..7]).
264///   Packing: `scales[j/2] >> (4*(j%2)) & 0xF` gives sub-block j's raw scale.
265/// - `d`: FP16 super-block scale.
266///
267/// Dequant: `w[i] = d * sub_scale * q3_signed[i]`
268/// where `q3_signed = ((low2 | (high1<<2)) as i32) - 4`, range [-4, 3].
269#[derive(Debug, Clone, Copy, PartialEq)]
270#[repr(C)]
271pub struct BlockQ3K {
272    /// High bit (bit 2) for each of 256 weights, packed 8 per byte.
273    pub hmask: [u8; 32],
274    /// Low 2 bits for each of 256 weights, packed 4 per byte (2 bits each, LSB first).
275    pub qs: [u8; 64],
276    /// 16 × 4-bit sub-block scales, 2 per byte (low nibble = sub 2i, high nibble = sub 2i+1).
277    pub scales: [u8; 12],
278    /// Super-block scale (FP16).
279    pub d: f16,
280}
281
282const _: () = assert!(std::mem::size_of::<BlockQ3K>() == BLOCK_Q3K_BYTES);
283
284impl BlockQ3K {
285    /// Dequantize a slice of Q3_K blocks into f32 output.
286    ///
287    /// `output` must have length >= `blocks.len() * QK_K` (256 per block).
288    pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
289        let expected_len = blocks.len() * QK_K;
290        if output.len() < expected_len {
291            return Err(BonsaiError::KQuantError {
292                reason: format!(
293                    "Q3_K dequant: output len {} < expected {}",
294                    output.len(),
295                    expected_len
296                ),
297            });
298        }
299
300        for (block_idx, block) in blocks.iter().enumerate() {
301            let d = block.d.to_f32();
302            let base = block_idx * QK_K;
303
304            // 16 sub-blocks of 16 weights each; scale is 4-bit signed nibble
305            for i in 0..QK_K {
306                // Low 2 bits from qs: each byte holds 4 × 2-bit values (2 bits per weight)
307                let byte_idx = i / 4;
308                let bit_shift = (i % 4) * 2;
309                let lo2 = (block.qs[byte_idx] >> bit_shift) & 0x03;
310
311                // High bit (bit 2) from hmask: each byte holds 8 bits, one per weight
312                let hi1 = (block.hmask[i / 8] >> (i % 8)) & 0x01;
313
314                // 3-bit code in [0..7], centered: range [-4..3]
315                let q3 = lo2 | (hi1 << 2);
316                let q3_signed = (q3 as i32) - 4;
317
318                // Sub-block index: 16 sub-blocks of 16 weights each
319                let sub = i / 16;
320                // 4-bit nibble from scales (2 per byte)
321                let scale_nibble = (block.scales[sub / 2] >> (4 * (sub % 2))) & 0x0F;
322                // Signed 4-bit scale: stored as 0..15 representing -8..7
323                let scale_signed = (scale_nibble as i8) as i32 - 8;
324
325                output[base + i] = d * (scale_signed as f32) * (q3_signed as f32);
326            }
327        }
328        Ok(())
329    }
330
331    /// Dequantize a single row's worth of Q3_K blocks into a pre-allocated buffer.
332    ///
333    /// `buf` will be extended by `blocks_for_row.len() * 256` elements.
334    pub fn dequant_row_to_buf(blocks_for_row: &[Self], buf: &mut Vec<f32>) {
335        let start = buf.len();
336        let n = blocks_for_row.len() * QK_K;
337        buf.resize(start + n, 0.0f32);
338        let _ = Self::dequant(blocks_for_row, &mut buf[start..]);
339    }
340
341    /// Quantize f32 input into Q3_K blocks.
342    ///
343    /// Input length must be a multiple of `QK_K` (256).
344    ///
345    /// Uses symmetric per-sub-block quantization: 16 sub-blocks of 16 weights each,
346    /// mapping each sub-block to the range [-4..3] with a 4-bit signed scale.
347    pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
348        if input.len() % QK_K != 0 {
349            return Err(BonsaiError::KQuantError {
350                reason: format!(
351                    "Q3_K quantize: input len {} not a multiple of {}",
352                    input.len(),
353                    QK_K
354                ),
355            });
356        }
357
358        let num_blocks = input.len() / QK_K;
359        let mut blocks = Vec::with_capacity(num_blocks);
360
361        for block_idx in 0..num_blocks {
362            let chunk = &input[block_idx * QK_K..block_idx * QK_K + QK_K];
363
364            // Compute per-sub-block max absolute value for 16 sub-blocks of 16 weights
365            let mut sub_max_abs = [0.0f32; 16];
366            for (sub, slot) in sub_max_abs.iter_mut().enumerate() {
367                let sub_chunk = &chunk[sub * 16..(sub + 1) * 16];
368                *slot = sub_chunk.iter().map(|&v| v.abs()).fold(0.0f32, f32::max);
369            }
370
371            // Super-block scale: d * max_scale_nibble * 4 ≈ overall max abs
372            // max scale nibble is 7 (for signed scale value 7 - 8 + 8 = 7 in [0..15])
373            // and max 3-bit centered code is 3 (q3_signed in [-4..3])
374            // Effective range = d * 7 * 3 = d * 21
375            let overall_max = sub_max_abs.iter().copied().fold(0.0f32, f32::max);
376            let d = if overall_max > 0.0 {
377                overall_max / 21.0
378            } else {
379                0.0
380            };
381            let inv_d = if d > 0.0 { 1.0 / d } else { 0.0 };
382
383            // Compute per-sub-block 4-bit signed scale nibbles
384            let mut scale_nibbles = [0u8; 16];
385            for (sub, &max_abs) in sub_max_abs.iter().enumerate() {
386                // scale_signed = max_abs / (d * 3), clamped to [-8..7]
387                let sc_f = if d > 0.0 { max_abs * inv_d / 3.0 } else { 0.0 };
388                let sc_signed = sc_f.round().clamp(-8.0, 7.0) as i32;
389                // Store as 0..15 (add 8 to shift from signed to unsigned nibble)
390                scale_nibbles[sub] = (sc_signed + 8).clamp(0, 15) as u8;
391            }
392
393            // Pack nibbles into scales[12]: 2 per byte
394            let mut scales = [0u8; 12];
395            for (sub, &nibble_val) in scale_nibbles.iter().enumerate() {
396                let byte_idx = sub / 2;
397                let nibble = nibble_val & 0x0F;
398                if sub % 2 == 0 {
399                    scales[byte_idx] |= nibble;
400                } else {
401                    scales[byte_idx] |= nibble << 4;
402                }
403            }
404
405            // Quantize weights to 3-bit codes
406            let mut hmask = [0u8; 32];
407            let mut qs = [0u8; 64];
408
409            for i in 0..QK_K {
410                let sub = i / 16;
411                let sc_signed = (scale_nibbles[sub] as i32) - 8;
412                // Effective scale for this sub-block
413                let eff_scale = d * (sc_signed as f32);
414                let inv_eff = if eff_scale.abs() > 1e-9 {
415                    1.0 / eff_scale
416                } else {
417                    0.0
418                };
419
420                // Compute 3-bit code: map w → q3_signed in [-4..3], then add 4 → [0..7]
421                let q3_signed = (chunk[i] * inv_eff).round() as i32;
422                let q3 = (q3_signed + 4).clamp(0, 7) as u8;
423
424                // Low 2 bits → qs (4 × 2-bit per byte)
425                let lo2 = q3 & 0x03;
426                let byte_idx = i / 4;
427                let bit_shift = (i % 4) * 2;
428                qs[byte_idx] |= lo2 << bit_shift;
429
430                // High bit (bit 2) → hmask (8 × 1-bit per byte)
431                let hi1 = (q3 >> 2) & 0x01;
432                hmask[i / 8] |= hi1 << (i % 8);
433            }
434
435            blocks.push(BlockQ3K {
436                hmask,
437                qs,
438                scales,
439                d: f16::from_f32(d),
440            });
441        }
442
443        Ok(blocks)
444    }
445
446    /// Zero-copy cast of a byte slice to a slice of `BlockQ3K`.
447    ///
448    /// Returns error if length is not a multiple of `BLOCK_Q3K_BYTES` (110)
449    /// or if the pointer is not properly aligned.
450    pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
451        if data.len() % BLOCK_Q3K_BYTES != 0 {
452            return Err(BonsaiError::KQuantError {
453                reason: format!(
454                    "Q3_K slice_from_bytes: byte len {} not a multiple of {}",
455                    data.len(),
456                    BLOCK_Q3K_BYTES
457                ),
458            });
459        }
460        if data.is_empty() {
461            return Ok(&[]);
462        }
463        let align = std::mem::align_of::<Self>();
464        if data.as_ptr().align_offset(align) != 0 {
465            return Err(BonsaiError::KQuantError {
466                reason: format!("Q3_K slice_from_bytes: pointer not {}-byte aligned", align),
467            });
468        }
469        let count = data.len() / BLOCK_Q3K_BYTES;
470        let ptr = data.as_ptr() as *const Self;
471        // SAFETY: repr(C) layout validated by compile-time size assert;
472        // length is a multiple of BLOCK_Q3K_BYTES; pointer alignment verified above;
473        // lifetime is tied to the input slice.
474        Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
475    }
476}
477
478// ---------------------------------------------------------------------------
479// BlockQ4K
480// ---------------------------------------------------------------------------
481
482/// Q4_K super-block: 256 weights quantized to 4 bits each.
483///
484/// Layout (144 bytes):
485/// - `d`: FP16 super-block scale.
486/// - `dmin`: FP16 super-block minimum.
487/// - `scales`: 12 bytes — packed 6-bit scale/min values for 8 sub-blocks of 32 weights.
488///   Encoding: bytes 0..3 hold low 4 bits of scale[0..7], bytes 4..7 hold low 4 bits
489///   of min[0..7], bytes 8..11 hold the upper 2 bits of scales and mins packed.
490/// - `qs`: 128 bytes — 256 x 4-bit quantized weights (2 per byte).
491///
492/// Dequant: `w[i] = d * sub_scale * q[i] - dmin * sub_min`
493#[derive(Debug, Clone, Copy, PartialEq)]
494#[repr(C)]
495pub struct BlockQ4K {
496    /// Super-block scale (FP16).
497    pub d: f16,
498    /// Super-block minimum (FP16).
499    pub dmin: f16,
500    /// Packed 6-bit scales for 8 sub-blocks.
501    pub scales: [u8; 12],
502    /// 256 x 4-bit quantized weights, 2 per byte.
503    pub qs: [u8; 128],
504}
505
506const _: () = assert!(std::mem::size_of::<BlockQ4K>() == BLOCK_Q4_K_BYTES);
507
/// Unpacks the 12-byte Q4_K `scales` field into eight 6-bit scale factors and eight
/// 6-bit min factors, returned as `(scales, mins)`.
///
/// Packing for sub-block `k` (0..8):
/// - low 4 bits of scale `k`: nibble `k % 2` of byte `k / 2` (bytes 0..3)
/// - low 4 bits of min `k`:   nibble `k % 2` of byte `4 + k / 2` (bytes 4..7)
/// - high 2 bits of scale `k`: bits `2*(k%4)..2*(k%4)+1` of byte `8 + k / 4` (bytes 8..9)
/// - high 2 bits of min `k`:   bits `2*(k%4)..2*(k%4)+1` of byte `10 + k / 4` (bytes 10..11)
fn decode_q4k_scales(scales_raw: &[u8; 12]) -> ([u8; 8], [u8; 8]) {
    let mut sc = [0u8; 8];
    let mut mn = [0u8; 8];

    for (k, (s, m)) in sc.iter_mut().zip(mn.iter_mut()).enumerate() {
        let nibble_shift = 4 * (k % 2);
        let hi_shift = 2 * (k % 4);

        let lo_scale = (scales_raw[k / 2] >> nibble_shift) & 0x0F;
        let lo_min = (scales_raw[4 + k / 2] >> nibble_shift) & 0x0F;
        let hi_scale = (scales_raw[8 + k / 4] >> hi_shift) & 0x03;
        let hi_min = (scales_raw[10 + k / 4] >> hi_shift) & 0x03;

        *s = lo_scale | (hi_scale << 4);
        *m = lo_min | (hi_min << 4);
    }

    (sc, mn)
}
551
/// Packs eight 6-bit scale factors and eight 6-bit min factors into the 12-byte Q4_K
/// `scales` layout (the inverse of `decode_q4k_scales`).
///
/// - Bytes 0..3 take the low nibbles of the scales, two per byte, even sub-block first.
/// - Bytes 4..7 do the same for the mins.
/// - Bytes 8..9 and 10..11 collect bits 4..5 of the scales and of the mins, respectively,
///   four 2-bit fields per byte.
///
/// Bits above bit 5 of any input are discarded.
fn encode_q4k_scales(sc: &[u8; 8], mn: &[u8; 8]) -> [u8; 12] {
    let mut packed = [0u8; 12];

    for (k, (&s, &m)) in sc.iter().zip(mn.iter()).enumerate() {
        let nibble_shift = 4 * (k % 2);
        let hi_shift = 2 * (k % 4);

        packed[k / 2] |= (s & 0x0F) << nibble_shift;
        packed[4 + k / 2] |= (m & 0x0F) << nibble_shift;
        packed[8 + k / 4] |= ((s >> 4) & 0x03) << hi_shift;
        packed[10 + k / 4] |= ((m >> 4) & 0x03) << hi_shift;
    }

    packed
}
581
582impl BlockQ4K {
583    /// Dequantize a slice of Q4_K blocks into f32 output.
584    ///
585    /// `output` must have length >= `blocks.len() * QK_K`.
586    pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
587        let expected_len = blocks.len() * QK_K;
588        if output.len() < expected_len {
589            return Err(BonsaiError::KQuantError {
590                reason: format!(
591                    "Q4_K dequant: output len {} < expected {}",
592                    output.len(),
593                    expected_len
594                ),
595            });
596        }
597
598        for (block_idx, block) in blocks.iter().enumerate() {
599            let d = block.d.to_f32();
600            let dmin_val = block.dmin.to_f32();
601            let base = block_idx * QK_K;
602
603            let (sc, mn) = decode_q4k_scales(&block.scales);
604
605            // 8 sub-blocks of 32 weights each
606            for sub in 0..8 {
607                let sub_scale = d * (sc[sub] as f32);
608                let sub_min = dmin_val * (mn[sub] as f32);
609                let sub_offset = sub * 32;
610
611                for j in 0..32 {
612                    let global_idx = sub_offset + j;
613                    let byte_idx = global_idx / 2;
614                    let q = if global_idx % 2 == 0 {
615                        (block.qs[byte_idx] & 0x0F) as f32
616                    } else {
617                        ((block.qs[byte_idx] >> 4) & 0x0F) as f32
618                    };
619                    output[base + global_idx] = sub_scale * q - sub_min;
620                }
621            }
622        }
623        Ok(())
624    }
625
626    /// Quantize f32 input into Q4_K blocks.
627    ///
628    /// Input length must be a multiple of `QK_K` (256).
629    pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
630        if input.len() % QK_K != 0 {
631            return Err(BonsaiError::KQuantError {
632                reason: format!(
633                    "Q4_K quantize: input len {} not a multiple of {}",
634                    input.len(),
635                    QK_K
636                ),
637            });
638        }
639
640        let num_blocks = input.len() / QK_K;
641        let mut blocks = Vec::with_capacity(num_blocks);
642
643        for block_idx in 0..num_blocks {
644            let base = block_idx * QK_K;
645            let chunk = &input[base..base + QK_K];
646
647            // 8 sub-blocks of 32 weights
648            let mut sub_scales = [0.0f32; 8];
649            let mut sub_mins = [0.0f32; 8];
650
651            for sub in 0..8 {
652                let sub_offset = sub * 32;
653                let sub_chunk = &chunk[sub_offset..sub_offset + 32];
654
655                let mut smin = f32::MAX;
656                let mut smax = f32::MIN;
657                for &v in sub_chunk {
658                    if v < smin {
659                        smin = v;
660                    }
661                    if v > smax {
662                        smax = v;
663                    }
664                }
665
666                sub_mins[sub] = if smin < 0.0 { -smin } else { 0.0 };
667                let range = smax + sub_mins[sub];
668                sub_scales[sub] = if range > 0.0 { range / 15.0 } else { 0.0 };
669            }
670
671            let max_scale = sub_scales.iter().copied().fold(0.0f32, f32::max);
672            let max_min = sub_mins.iter().copied().fold(0.0f32, f32::max);
673
674            // 6-bit sub-block factors: 0..63
675            let d = if max_scale > 0.0 {
676                max_scale / 63.0
677            } else {
678                0.0
679            };
680            let dmin = if max_min > 0.0 { max_min / 63.0 } else { 0.0 };
681
682            let inv_d = if d > 0.0 { 1.0 / d } else { 0.0 };
683            let inv_dmin = if dmin > 0.0 { 1.0 / dmin } else { 0.0 };
684
685            let mut sc = [0u8; 8];
686            let mut mn = [0u8; 8];
687
688            for sub in 0..8 {
689                sc[sub] = (sub_scales[sub] * inv_d + 0.5).min(63.0) as u8;
690                mn[sub] = (sub_mins[sub] * inv_dmin + 0.5).min(63.0) as u8;
691            }
692
693            let scales = encode_q4k_scales(&sc, &mn);
694
695            // Quantize weights to 4 bits
696            let mut qs = [0u8; 128];
697            for sub in 0..8 {
698                let sub_offset = sub * 32;
699                let sc_f = d * (sc[sub] as f32);
700                let mn_f = dmin * (mn[sub] as f32);
701                let inv_sc = if sc_f > 0.0 { 1.0 / sc_f } else { 0.0 };
702
703                for j in 0..32 {
704                    let global_idx = sub_offset + j;
705                    let val = chunk[global_idx] + mn_f;
706                    let q = (val * inv_sc + 0.5).clamp(0.0, 15.0) as u8;
707                    let byte_idx = global_idx / 2;
708                    if global_idx % 2 == 0 {
709                        qs[byte_idx] |= q & 0x0F;
710                    } else {
711                        qs[byte_idx] |= (q & 0x0F) << 4;
712                    }
713                }
714            }
715
716            blocks.push(BlockQ4K {
717                d: f16::from_f32(d),
718                dmin: f16::from_f32(dmin),
719                scales,
720                qs,
721            });
722        }
723
724        Ok(blocks)
725    }
726
727    /// Dequantize a single row's worth of Q4_K blocks into a pre-allocated buffer.
728    ///
729    /// `buf` will be extended by `blocks_for_row.len() * 256` elements.
730    pub fn dequant_row_to_buf(blocks_for_row: &[Self], buf: &mut Vec<f32>) {
731        let start = buf.len();
732        let n = blocks_for_row.len() * QK_K;
733        buf.resize(start + n, 0.0f32);
734        let _ = Self::dequant(blocks_for_row, &mut buf[start..]);
735    }
736
737    /// Zero-copy cast of a byte slice to a slice of `BlockQ4K`.
738    ///
739    /// Returns error if length is not a multiple of `BLOCK_Q4_K_BYTES` (144)
740    /// or if the pointer is not properly aligned.
741    pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
742        if data.len() % BLOCK_Q4_K_BYTES != 0 {
743            return Err(BonsaiError::KQuantError {
744                reason: format!(
745                    "Q4_K slice_from_bytes: byte len {} not a multiple of {}",
746                    data.len(),
747                    BLOCK_Q4_K_BYTES
748                ),
749            });
750        }
751        if data.is_empty() {
752            return Ok(&[]);
753        }
754        let align = std::mem::align_of::<Self>();
755        if data.as_ptr().align_offset(align) != 0 {
756            return Err(BonsaiError::KQuantError {
757                reason: format!("Q4_K slice_from_bytes: pointer not {}-byte aligned", align),
758            });
759        }
760        let count = data.len() / BLOCK_Q4_K_BYTES;
761        let ptr = data.as_ptr() as *const Self;
762        // SAFETY: repr(C) layout validated by compile-time size assert;
763        // length is a multiple of BLOCK_Q4_K_BYTES; pointer alignment verified above;
764        // lifetime is tied to the input slice.
765        Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
766    }
767}
768
769// ---------------------------------------------------------------------------
770// BlockQ8K
771// ---------------------------------------------------------------------------
772
/// Q8_K super-block: 256 weights quantized to 8 bits (int8) each.
///
/// Layout (292 bytes, `#[repr(C)]`, no padding — enforced by the size assert below):
/// - `d`:      4 bytes — f32 super-block scale (NOT f16, unlike other K-quant formats).
/// - `qs`:     256 bytes — int8 quantized weight values.
/// - `bsums`:  32 bytes — 16 x int16; `bsums[g]` is the sum of the *quantized* values
///   `qs[16 * g .. 16 * g + 16]` (as computed by [`BlockQ8K::quantize`]), kept for
///   dot-product optimization.
///
/// The `f32` field gives the struct the alignment of `f32` (4 bytes on common
/// targets); [`BlockQ8K::slice_from_bytes`] verifies that alignment before
/// reinterpreting raw bytes as blocks.
///
/// Dequant: `w[i] = d * qs[i]` (bsums are not needed for scalar dequant).
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ8K {
    /// Super-block scale (f32, NOT f16).
    pub d: f32,
    /// 256 int8 quantized weight values.
    pub qs: [i8; 256],
    /// Sum of the quantized `qs` values in each of the 16 consecutive groups of 16
    /// (for SIMD dot-product optimization).
    pub bsums: [i16; 16],
}

// Compile-time guard: the in-memory layout must match the 292-byte block size,
// because `slice_from_bytes` reinterprets raw bytes directly as `BlockQ8K`.
const _: () = assert!(std::mem::size_of::<BlockQ8K>() == BLOCK_Q8K_BYTES);
793
794impl BlockQ8K {
795    /// Dequantize a slice of Q8_K blocks into f32 output.
796    ///
797    /// `output` must have length >= `blocks.len() * QK_K` (256 per block).
798    pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
799        let expected_len = blocks.len() * QK_K;
800        if output.len() < expected_len {
801            return Err(BonsaiError::KQuantError {
802                reason: format!(
803                    "Q8_K dequant: output len {} < expected {}",
804                    output.len(),
805                    expected_len
806                ),
807            });
808        }
809
810        for (block_idx, block) in blocks.iter().enumerate() {
811            let d = block.d;
812            let base = block_idx * QK_K;
813            for i in 0..QK_K {
814                output[base + i] = d * (block.qs[i] as f32);
815            }
816        }
817        Ok(())
818    }
819
820    /// Dequantize a single row's worth of Q8_K blocks into a pre-allocated buffer.
821    ///
822    /// `buf` will be extended by `blocks_for_row.len() * 256` elements.
823    pub fn dequant_row_to_buf(blocks_for_row: &[Self], buf: &mut Vec<f32>) {
824        let start = buf.len();
825        let n = blocks_for_row.len() * QK_K;
826        buf.resize(start + n, 0.0f32);
827        let _ = Self::dequant(blocks_for_row, &mut buf[start..]);
828    }
829
830    /// Quantize f32 input into Q8_K blocks.
831    ///
832    /// Input length must be a multiple of `QK_K` (256).
833    ///
834    /// Uses a single super-block scale `d = max_abs / 127`. The `bsums` field is
835    /// populated with the sum of each group of 16 weights (useful for SIMD optimized
836    /// dot-product computation in other implementations).
837    pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
838        if input.len() % QK_K != 0 {
839            return Err(BonsaiError::KQuantError {
840                reason: format!(
841                    "Q8_K quantize: input len {} not a multiple of {}",
842                    input.len(),
843                    QK_K
844                ),
845            });
846        }
847
848        let num_blocks = input.len() / QK_K;
849        let mut blocks = Vec::with_capacity(num_blocks);
850
851        for block_idx in 0..num_blocks {
852            let chunk = &input[block_idx * QK_K..block_idx * QK_K + QK_K];
853
854            // Find max absolute value across all 256 weights
855            let max_abs = chunk.iter().map(|&v| v.abs()).fold(0.0f32, f32::max);
856
857            let d = if max_abs > 0.0 { max_abs / 127.0 } else { 0.0 };
858            let inv_d = if d > 0.0 { 1.0 / d } else { 0.0 };
859
860            let mut qs = [0i8; 256];
861            for (i, &w) in chunk.iter().enumerate() {
862                qs[i] = (w * inv_d).round().clamp(-127.0, 127.0) as i8;
863            }
864
865            // Precompute bsums: sum of each group of 16 weights (as int16)
866            let mut bsums = [0i16; 16];
867            for (group, slot) in bsums.iter_mut().enumerate() {
868                let group_start = group * 16;
869                let sum: i32 = qs[group_start..group_start + 16]
870                    .iter()
871                    .map(|&q| q as i32)
872                    .sum();
873                *slot = sum.clamp(i16::MIN as i32, i16::MAX as i32) as i16;
874            }
875
876            blocks.push(BlockQ8K { d, qs, bsums });
877        }
878
879        Ok(blocks)
880    }
881
882    /// Zero-copy cast of a byte slice to a slice of `BlockQ8K`.
883    ///
884    /// Returns error if length is not a multiple of `BLOCK_Q8K_BYTES` (292)
885    /// or if the pointer is not properly aligned.
886    pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
887        if data.len() % BLOCK_Q8K_BYTES != 0 {
888            return Err(BonsaiError::KQuantError {
889                reason: format!(
890                    "Q8_K slice_from_bytes: byte len {} not a multiple of {}",
891                    data.len(),
892                    BLOCK_Q8K_BYTES
893                ),
894            });
895        }
896        if data.is_empty() {
897            return Ok(&[]);
898        }
899        let align = std::mem::align_of::<Self>();
900        if data.as_ptr().align_offset(align) != 0 {
901            return Err(BonsaiError::KQuantError {
902                reason: format!("Q8_K slice_from_bytes: pointer not {}-byte aligned", align),
903            });
904        }
905        let count = data.len() / BLOCK_Q8K_BYTES;
906        let ptr = data.as_ptr() as *const Self;
907        // SAFETY: repr(C) layout validated by compile-time size assert;
908        // length is a multiple of BLOCK_Q8K_BYTES; pointer alignment verified above;
909        // lifetime is tied to the input slice.
910        Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
911    }
912}
913
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte buffer whose start is 16-byte aligned.
    ///
    /// The `slice_from_bytes` tests that expect success use this instead of a
    /// `Vec<u8>`, whose allocation is only guaranteed to be 1-byte aligned, so
    /// they do not depend on the allocator happening to over-align.
    #[repr(C, align(16))]
    struct Aligned<const N: usize>([u8; N]);

    #[test]
    fn q2k_block_size_correct() {
        assert_eq!(std::mem::size_of::<BlockQ2K>(), BLOCK_Q2_K_BYTES);
        assert_eq!(BLOCK_Q2_K_BYTES, 84);
    }

    #[test]
    fn q2k_roundtrip_zero_weights() {
        let blocks = BlockQ2K::quantize(&vec![0.0f32; 256]).expect("quantize ok");
        let mut out = vec![0.0f32; 256];
        BlockQ2K::dequant(&blocks, &mut out).expect("dequant ok");
        for &v in &out {
            assert!(
                v.abs() < 1e-4,
                "all-zero input should dequant to near-zero, got {v}"
            );
        }
    }

    #[test]
    fn q2k_roundtrip_uniform() {
        let input = vec![1.0f32; 256];
        let blocks = BlockQ2K::quantize(&input).expect("quantize ok");
        let mut out = vec![0.0f32; 256];
        BlockQ2K::dequant(&blocks, &mut out).expect("dequant ok");
        for &v in &out {
            let err = (v - 1.0).abs();
            assert!(err < 0.2, "uniform round-trip error {err} too high");
        }
    }

    #[test]
    fn q2k_quantize_output_length() {
        let input = vec![0.5f32; 256];
        let blocks = BlockQ2K::quantize(&input).expect("quantize ok");
        assert_eq!(blocks.len(), 1);
    }

    #[test]
    fn q2k_slice_from_bytes_empty() {
        let data: Vec<u8> = vec![];
        let result = BlockQ2K::slice_from_bytes(&data).expect("empty slice ok");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn q2k_slice_from_bytes_bad_length() {
        let data = vec![0u8; 83]; // not a multiple of 84
        assert!(BlockQ2K::slice_from_bytes(&data).is_err());
    }

    #[test]
    fn q4k_block_size_correct() {
        assert_eq!(std::mem::size_of::<BlockQ4K>(), BLOCK_Q4_K_BYTES);
        assert_eq!(BLOCK_Q4_K_BYTES, 144);
    }

    #[test]
    fn q4k_scale_encode_decode_roundtrip() {
        let sc = [1, 2, 3, 4, 5, 63, 32, 0];
        let mn = [10, 20, 30, 40, 50, 60, 15, 7];
        let encoded = encode_q4k_scales(&sc, &mn);
        let (sc2, mn2) = decode_q4k_scales(&encoded);
        assert_eq!(sc, sc2);
        assert_eq!(mn, mn2);
    }

    #[test]
    fn q4k_scale_encode_decode_all_zeros() {
        let sc = [0u8; 8];
        let mn = [0u8; 8];
        let encoded = encode_q4k_scales(&sc, &mn);
        let (sc2, mn2) = decode_q4k_scales(&encoded);
        assert_eq!(sc, sc2);
        assert_eq!(mn, mn2);
    }

    #[test]
    fn q4k_scale_encode_decode_max_values() {
        let sc = [63u8; 8];
        let mn = [63u8; 8];
        let encoded = encode_q4k_scales(&sc, &mn);
        let (sc2, mn2) = decode_q4k_scales(&encoded);
        assert_eq!(sc, sc2);
        assert_eq!(mn, mn2);
    }

    #[test]
    fn q4k_slice_from_bytes_empty() {
        let data: Vec<u8> = vec![];
        let result = BlockQ4K::slice_from_bytes(&data).expect("empty slice ok");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn q4k_slice_from_bytes_bad_length() {
        let data = vec![0u8; 100]; // not a multiple of 144
        assert!(BlockQ4K::slice_from_bytes(&data).is_err());
    }

    #[test]
    fn q4k_slice_from_bytes_single_block() {
        let data = Aligned([0u8; BLOCK_Q4_K_BYTES]);
        let result =
            BlockQ4K::slice_from_bytes(&data.0).expect("single aligned block should parse");
        assert_eq!(result.len(), 1);
    }

    // -----------------------------------------------------------------------
    // BlockQ3K tests
    // -----------------------------------------------------------------------

    #[test]
    fn q3k_block_size_assertion() {
        assert_eq!(std::mem::size_of::<BlockQ3K>(), BLOCK_Q3K_BYTES);
        assert_eq!(BLOCK_Q3K_BYTES, 110);
    }

    #[test]
    fn q3k_roundtrip_zero_weights() {
        let blocks = BlockQ3K::quantize(&vec![0.0f32; 256]).expect("quantize ok");
        let mut out = vec![0.0f32; 256];
        BlockQ3K::dequant(&blocks, &mut out).expect("dequant ok");
        for &v in &out {
            assert!(
                v.abs() < 1e-4,
                "all-zero input should dequant to near-zero, got {v}"
            );
        }
    }

    #[test]
    fn q3k_roundtrip_uniform() {
        // Uniform positive values should round-trip with relative error below 50%
        // (a loose sanity bound, not a precision target). The reference value is
        // 1.0, so the absolute error equals the relative error.
        let input = vec![1.0f32; 256];
        let blocks = BlockQ3K::quantize(&input).expect("quantize ok");
        let mut out = vec![0.0f32; 256];
        BlockQ3K::dequant(&blocks, &mut out).expect("dequant ok");
        for &v in &out {
            let err = (v - 1.0).abs();
            assert!(
                err < 0.5,
                "uniform round-trip rel error {err} too high, got {v}"
            );
        }
    }

    #[test]
    fn q3k_slice_from_bytes() {
        // Exactly one 110-byte block in an explicitly aligned buffer.
        let data = Aligned([0u8; BLOCK_Q3K_BYTES]);
        let result = BlockQ3K::slice_from_bytes(&data.0).expect("single block should parse");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn q3k_slice_from_bytes_empty() {
        let data: Vec<u8> = vec![];
        let result = BlockQ3K::slice_from_bytes(&data).expect("empty slice ok");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn q3k_slice_from_bytes_bad_length() {
        let data = vec![0u8; 100]; // not a multiple of 110
        assert!(BlockQ3K::slice_from_bytes(&data).is_err());
    }

    #[test]
    fn q3k_quantize_output_length() {
        let input = vec![0.5f32; 256];
        let blocks = BlockQ3K::quantize(&input).expect("quantize ok");
        assert_eq!(blocks.len(), 1, "256 weights → 1 block");
    }

    #[test]
    fn q3k_quantize_non_multiple_errors() {
        assert!(BlockQ3K::quantize(&vec![1.0f32; 100]).is_err());
    }

    #[test]
    fn q3k_dequant_output_too_small_errors() {
        let blocks = BlockQ3K::quantize(&vec![1.0f32; 256]).expect("quantize ok");
        let mut out = vec![0.0f32; 100];
        assert!(BlockQ3K::dequant(&blocks, &mut out).is_err());
    }

    #[test]
    fn q3k_dequant_row_to_buf_works() {
        let input = vec![0.5f32; 256];
        let blocks = BlockQ3K::quantize(&input).expect("quantize ok");
        let mut buf = Vec::new();
        BlockQ3K::dequant_row_to_buf(&blocks, &mut buf);
        assert_eq!(buf.len(), 256);
    }

    // -----------------------------------------------------------------------
    // BlockQ8K tests
    // -----------------------------------------------------------------------

    #[test]
    fn q8k_block_size_assertion() {
        assert_eq!(std::mem::size_of::<BlockQ8K>(), BLOCK_Q8K_BYTES);
        assert_eq!(BLOCK_Q8K_BYTES, 292);
    }

    #[test]
    fn q8k_roundtrip_zero_weights() {
        // Zero input yields d = 0 and qs = 0, so every output is exactly 0.0.
        let blocks = BlockQ8K::quantize(&vec![0.0f32; 256]).expect("quantize ok");
        let mut out = vec![0.0f32; 256];
        BlockQ8K::dequant(&blocks, &mut out).expect("dequant ok");
        for &v in &out {
            assert_eq!(
                v, 0.0,
                "all-zero input should dequant to exactly zero, got {v}"
            );
        }
    }

    #[test]
    fn q8k_roundtrip_uniform() {
        let input = vec![1.0f32; 256];
        let blocks = BlockQ8K::quantize(&input).expect("quantize ok");
        let mut out = vec![0.0f32; 256];
        BlockQ8K::dequant(&blocks, &mut out).expect("dequant ok");
        for &v in &out {
            let err = (v - 1.0).abs();
            assert!(err < 0.02, "Q8_K uniform round-trip error {err} too high");
        }
    }

    #[test]
    fn q8k_slice_from_bytes() {
        let data = Aligned([0u8; BLOCK_Q8K_BYTES]);
        let result = BlockQ8K::slice_from_bytes(&data.0).expect("single block should parse");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn q8k_slice_from_bytes_misaligned() {
        // Shift by one byte: the length is still exactly one block, but the pointer
        // is no longer aligned for BlockQ8K (which contains an f32).
        let data = Aligned([0u8; 2 * BLOCK_Q8K_BYTES]);
        let shifted = &data.0[1..=BLOCK_Q8K_BYTES];
        assert!(BlockQ8K::slice_from_bytes(shifted).is_err());
    }

    #[test]
    fn q8k_slice_from_bytes_empty() {
        let data: Vec<u8> = vec![];
        let result = BlockQ8K::slice_from_bytes(&data).expect("empty slice ok");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn q8k_slice_from_bytes_bad_length() {
        let data = vec![0u8; 100]; // not a multiple of 292
        assert!(BlockQ8K::slice_from_bytes(&data).is_err());
    }

    #[test]
    fn q8k_quantize_output_length() {
        let input = vec![0.5f32; 256];
        let blocks = BlockQ8K::quantize(&input).expect("quantize ok");
        assert_eq!(blocks.len(), 1, "256 weights → 1 block");
    }

    #[test]
    fn q8k_quantize_non_multiple_errors() {
        assert!(BlockQ8K::quantize(&vec![1.0f32; 100]).is_err());
    }

    #[test]
    fn q8k_dequant_output_too_small_errors() {
        let blocks = BlockQ8K::quantize(&vec![1.0f32; 256]).expect("quantize ok");
        let mut out = vec![0.0f32; 100];
        assert!(BlockQ8K::dequant(&blocks, &mut out).is_err());
    }

    #[test]
    fn q8k_dequant_row_to_buf_works() {
        let input = vec![0.5f32; 256];
        let blocks = BlockQ8K::quantize(&input).expect("quantize ok");
        let mut buf = Vec::new();
        BlockQ8K::dequant_row_to_buf(&blocks, &mut buf);
        assert_eq!(buf.len(), 256);
        for &v in &buf {
            assert!((v - 0.5).abs() < 0.01, "expected ~0.5, got {v}");
        }
    }

    #[test]
    fn q8k_bsums_roundtrip_sign() {
        // Verify bsums signs: positive input → positive bsums, negative → negative bsums.
        let input_pos = vec![0.5f32; 256];
        let blocks_pos = BlockQ8K::quantize(&input_pos).expect("quantize ok");
        for &bs in &blocks_pos[0].bsums {
            assert!(
                bs > 0,
                "positive input should yield positive bsums, got {bs}"
            );
        }

        let input_neg = vec![-0.5f32; 256];
        let blocks_neg = BlockQ8K::quantize(&input_neg).expect("quantize ok");
        for &bs in &blocks_neg[0].bsums {
            assert!(
                bs < 0,
                "negative input should yield negative bsums, got {bs}"
            );
        }
    }

    #[test]
    fn q8k_bsums_match_qs_groups() {
        // Each bsums entry must equal the sum of its 16 quantized values.
        let input: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 64.0).collect();
        let blocks = BlockQ8K::quantize(&input).expect("quantize ok");
        let block = &blocks[0];
        for (group, &bs) in block.bsums.iter().enumerate() {
            let expected: i32 = block.qs[group * 16..group * 16 + 16]
                .iter()
                .map(|&q| i32::from(q))
                .sum();
            assert_eq!(i32::from(bs), expected, "bsums[{group}] mismatch");
        }
    }
}