1use half::f16;
13
14use crate::error::{BonsaiError, BonsaiResult};
15
/// Number of weights stored in one K-quant super-block.
pub const QK_K: usize = 256;

/// Size in bytes of one [`BlockQ2K`] (16 + 64 + 2 + 2).
pub const BLOCK_Q2_K_BYTES: usize = 84;

/// Size in bytes of one [`BlockQ3K`] (32 + 64 + 12 + 2).
pub const BLOCK_Q3K_BYTES: usize = 110;

/// Size in bytes of one [`BlockQ4K`] (2 + 2 + 12 + 128).
pub const BLOCK_Q4_K_BYTES: usize = 144;

/// Size in bytes of one [`BlockQ8K`] (4 + 256 + 32).
pub const BLOCK_Q8K_BYTES: usize = 292;
34
/// One Q2_K super-block: [`QK_K`] weights stored as 2-bit codes in 16 sub-blocks
/// of 16 weights each.
///
/// `BlockQ2K::dequant` reconstructs a weight as `d * sc * q - dmin * mn`, where
/// `sc`/`mn` are the sub-block's 4-bit scale and min and `q` is the 2-bit code.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ2K {
    /// One byte per sub-block: low nibble is the 4-bit scale, high nibble the 4-bit min.
    pub scales: [u8; 16],
    /// 2-bit codes, four per byte; weight `i` lives in byte `i / 4` at bit `(i % 4) * 2`.
    pub qs: [u8; 64],
    /// Super-block multiplier applied to the 4-bit scales.
    pub d: f16,
    /// Super-block multiplier applied to the 4-bit mins.
    pub dmin: f16,
}

// Compile-time guard: the in-memory size must equal the declared on-disk block size.
const _: () = assert!(std::mem::size_of::<BlockQ2K>() == BLOCK_Q2_K_BYTES);
63
64impl BlockQ2K {
65 pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
69 let expected_len = blocks.len() * QK_K;
70 if output.len() < expected_len {
71 return Err(BonsaiError::KQuantError {
72 reason: format!(
73 "Q2_K dequant: output len {} < expected {}",
74 output.len(),
75 expected_len
76 ),
77 });
78 }
79
80 for (block_idx, block) in blocks.iter().enumerate() {
81 let d = block.d.to_f32();
82 let dmin = block.dmin.to_f32();
83 let base = block_idx * QK_K;
84
85 for sub in 0..16 {
87 let scale_byte = block.scales[sub];
88 let sc = (scale_byte & 0x0F) as f32; let mn = ((scale_byte >> 4) & 0x0F) as f32; let sub_offset = sub * 16;
92 for j in 0..16 {
93 let global_idx = sub_offset + j;
94 let byte_idx = global_idx / 4;
96 let shift = (global_idx % 4) * 2;
97 let q = ((block.qs[byte_idx] >> shift) & 0x03) as f32;
98 output[base + global_idx] = d * sc * q - dmin * mn;
99 }
100 }
101 }
102 Ok(())
103 }
104
105 pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
109 if input.len() % QK_K != 0 {
110 return Err(BonsaiError::KQuantError {
111 reason: format!(
112 "Q2_K quantize: input len {} not a multiple of {}",
113 input.len(),
114 QK_K
115 ),
116 });
117 }
118
119 let num_blocks = input.len() / QK_K;
120 let mut blocks = Vec::with_capacity(num_blocks);
121
122 for block_idx in 0..num_blocks {
123 let base = block_idx * QK_K;
124 let chunk = &input[base..base + QK_K];
125
126 let mut sub_scales = [0.0f32; 16];
130 let mut sub_mins = [0.0f32; 16];
131
132 for sub in 0..16 {
133 let sub_offset = sub * 16;
134 let sub_chunk = &chunk[sub_offset..sub_offset + 16];
135
136 let mut smin = f32::MAX;
137 let mut smax = f32::MIN;
138 for &v in sub_chunk {
139 if v < smin {
140 smin = v;
141 }
142 if v > smax {
143 smax = v;
144 }
145 }
146
147 sub_mins[sub] = if smin < 0.0 { -smin } else { 0.0 };
149 let range = smax + sub_mins[sub];
150 sub_scales[sub] = if range > 0.0 { range / 3.0 } else { 0.0 };
151 }
152
153 let max_scale = sub_scales.iter().copied().fold(0.0f32, f32::max);
155 let max_min = sub_mins.iter().copied().fold(0.0f32, f32::max);
156
157 let d = if max_scale > 0.0 {
160 max_scale / 15.0
161 } else {
162 0.0
163 };
164 let dmin = if max_min > 0.0 { max_min / 15.0 } else { 0.0 };
165
166 let inv_d = if d > 0.0 { 1.0 / d } else { 0.0 };
167 let inv_dmin = if dmin > 0.0 { 1.0 / dmin } else { 0.0 };
168
169 let mut scales = [0u8; 16];
171 let mut quant_sc = [0u8; 16];
172 let mut quant_mn = [0u8; 16];
173
174 for sub in 0..16 {
175 let sc = (sub_scales[sub] * inv_d + 0.5).min(15.0) as u8;
176 let mn = (sub_mins[sub] * inv_dmin + 0.5).min(15.0) as u8;
177 quant_sc[sub] = sc;
178 quant_mn[sub] = mn;
179 scales[sub] = sc | (mn << 4);
180 }
181
182 let mut qs = [0u8; 64];
184 for sub in 0..16 {
185 let sub_offset = sub * 16;
186 let sc_f = d * (quant_sc[sub] as f32);
187 let mn_f = dmin * (quant_mn[sub] as f32);
188 let inv_sc = if sc_f > 0.0 { 1.0 / sc_f } else { 0.0 };
189
190 for j in 0..16 {
191 let global_idx = sub_offset + j;
192 let val = chunk[global_idx] + mn_f;
193 let q = (val * inv_sc + 0.5).clamp(0.0, 3.0) as u8;
194 let byte_idx = global_idx / 4;
195 let shift = (global_idx % 4) * 2;
196 qs[byte_idx] |= q << shift;
197 }
198 }
199
200 blocks.push(BlockQ2K {
201 scales,
202 qs,
203 d: f16::from_f32(d),
204 dmin: f16::from_f32(dmin),
205 });
206 }
207
208 Ok(blocks)
209 }
210
211 pub fn dequant_row_to_buf(blocks_for_row: &[Self], buf: &mut Vec<f32>) {
215 let start = buf.len();
216 let n = blocks_for_row.len() * QK_K;
217 buf.resize(start + n, 0.0f32);
218 let _ = Self::dequant(blocks_for_row, &mut buf[start..]);
219 }
220
221 pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
226 if data.len() % BLOCK_Q2_K_BYTES != 0 {
227 return Err(BonsaiError::KQuantError {
228 reason: format!(
229 "Q2_K slice_from_bytes: byte len {} not a multiple of {}",
230 data.len(),
231 BLOCK_Q2_K_BYTES
232 ),
233 });
234 }
235 if data.is_empty() {
236 return Ok(&[]);
237 }
238 let align = std::mem::align_of::<Self>();
239 if data.as_ptr().align_offset(align) != 0 {
240 return Err(BonsaiError::KQuantError {
241 reason: format!("Q2_K slice_from_bytes: pointer not {}-byte aligned", align),
242 });
243 }
244 let count = data.len() / BLOCK_Q2_K_BYTES;
245 let ptr = data.as_ptr() as *const Self;
246 Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
250 }
251}
252
/// One Q3_K super-block: [`QK_K`] weights stored as 3-bit codes in 16 sub-blocks
/// of 16 weights each.
///
/// `BlockQ3K::dequant` reconstructs a weight as `d * (scale - 8) * (code - 4)`,
/// where `scale` is the sub-block's 4-bit nibble and `code` the 3-bit value.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ3K {
    /// High (third) bit of every code; weight `i` uses bit `i % 8` of byte `i / 8`.
    pub hmask: [u8; 32],
    /// Low two bits of every code, four per byte; weight `i` uses byte `i / 4`.
    pub qs: [u8; 64],
    /// 4-bit sub-block scales biased by 8, two per byte (even sub-block in the low
    /// nibble). The code in this file only reads and writes the first 8 bytes.
    pub scales: [u8; 12],
    /// Super-block multiplier.
    pub d: f16,
}

// Compile-time guard: the in-memory size must equal the declared on-disk block size.
const _: () = assert!(std::mem::size_of::<BlockQ3K>() == BLOCK_Q3K_BYTES);
283
284impl BlockQ3K {
285 pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
289 let expected_len = blocks.len() * QK_K;
290 if output.len() < expected_len {
291 return Err(BonsaiError::KQuantError {
292 reason: format!(
293 "Q3_K dequant: output len {} < expected {}",
294 output.len(),
295 expected_len
296 ),
297 });
298 }
299
300 for (block_idx, block) in blocks.iter().enumerate() {
301 let d = block.d.to_f32();
302 let base = block_idx * QK_K;
303
304 for i in 0..QK_K {
306 let byte_idx = i / 4;
308 let bit_shift = (i % 4) * 2;
309 let lo2 = (block.qs[byte_idx] >> bit_shift) & 0x03;
310
311 let hi1 = (block.hmask[i / 8] >> (i % 8)) & 0x01;
313
314 let q3 = lo2 | (hi1 << 2);
316 let q3_signed = (q3 as i32) - 4;
317
318 let sub = i / 16;
320 let scale_nibble = (block.scales[sub / 2] >> (4 * (sub % 2))) & 0x0F;
322 let scale_signed = (scale_nibble as i8) as i32 - 8;
324
325 output[base + i] = d * (scale_signed as f32) * (q3_signed as f32);
326 }
327 }
328 Ok(())
329 }
330
331 pub fn dequant_row_to_buf(blocks_for_row: &[Self], buf: &mut Vec<f32>) {
335 let start = buf.len();
336 let n = blocks_for_row.len() * QK_K;
337 buf.resize(start + n, 0.0f32);
338 let _ = Self::dequant(blocks_for_row, &mut buf[start..]);
339 }
340
341 pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
348 if input.len() % QK_K != 0 {
349 return Err(BonsaiError::KQuantError {
350 reason: format!(
351 "Q3_K quantize: input len {} not a multiple of {}",
352 input.len(),
353 QK_K
354 ),
355 });
356 }
357
358 let num_blocks = input.len() / QK_K;
359 let mut blocks = Vec::with_capacity(num_blocks);
360
361 for block_idx in 0..num_blocks {
362 let chunk = &input[block_idx * QK_K..block_idx * QK_K + QK_K];
363
364 let mut sub_max_abs = [0.0f32; 16];
366 for (sub, slot) in sub_max_abs.iter_mut().enumerate() {
367 let sub_chunk = &chunk[sub * 16..(sub + 1) * 16];
368 *slot = sub_chunk.iter().map(|&v| v.abs()).fold(0.0f32, f32::max);
369 }
370
371 let overall_max = sub_max_abs.iter().copied().fold(0.0f32, f32::max);
376 let d = if overall_max > 0.0 {
377 overall_max / 21.0
378 } else {
379 0.0
380 };
381 let inv_d = if d > 0.0 { 1.0 / d } else { 0.0 };
382
383 let mut scale_nibbles = [0u8; 16];
385 for (sub, &max_abs) in sub_max_abs.iter().enumerate() {
386 let sc_f = if d > 0.0 { max_abs * inv_d / 3.0 } else { 0.0 };
388 let sc_signed = sc_f.round().clamp(-8.0, 7.0) as i32;
389 scale_nibbles[sub] = (sc_signed + 8).clamp(0, 15) as u8;
391 }
392
393 let mut scales = [0u8; 12];
395 for (sub, &nibble_val) in scale_nibbles.iter().enumerate() {
396 let byte_idx = sub / 2;
397 let nibble = nibble_val & 0x0F;
398 if sub % 2 == 0 {
399 scales[byte_idx] |= nibble;
400 } else {
401 scales[byte_idx] |= nibble << 4;
402 }
403 }
404
405 let mut hmask = [0u8; 32];
407 let mut qs = [0u8; 64];
408
409 for i in 0..QK_K {
410 let sub = i / 16;
411 let sc_signed = (scale_nibbles[sub] as i32) - 8;
412 let eff_scale = d * (sc_signed as f32);
414 let inv_eff = if eff_scale.abs() > 1e-9 {
415 1.0 / eff_scale
416 } else {
417 0.0
418 };
419
420 let q3_signed = (chunk[i] * inv_eff).round() as i32;
422 let q3 = (q3_signed + 4).clamp(0, 7) as u8;
423
424 let lo2 = q3 & 0x03;
426 let byte_idx = i / 4;
427 let bit_shift = (i % 4) * 2;
428 qs[byte_idx] |= lo2 << bit_shift;
429
430 let hi1 = (q3 >> 2) & 0x01;
432 hmask[i / 8] |= hi1 << (i % 8);
433 }
434
435 blocks.push(BlockQ3K {
436 hmask,
437 qs,
438 scales,
439 d: f16::from_f32(d),
440 });
441 }
442
443 Ok(blocks)
444 }
445
446 pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
451 if data.len() % BLOCK_Q3K_BYTES != 0 {
452 return Err(BonsaiError::KQuantError {
453 reason: format!(
454 "Q3_K slice_from_bytes: byte len {} not a multiple of {}",
455 data.len(),
456 BLOCK_Q3K_BYTES
457 ),
458 });
459 }
460 if data.is_empty() {
461 return Ok(&[]);
462 }
463 let align = std::mem::align_of::<Self>();
464 if data.as_ptr().align_offset(align) != 0 {
465 return Err(BonsaiError::KQuantError {
466 reason: format!("Q3_K slice_from_bytes: pointer not {}-byte aligned", align),
467 });
468 }
469 let count = data.len() / BLOCK_Q3K_BYTES;
470 let ptr = data.as_ptr() as *const Self;
471 Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
475 }
476}
477
/// One Q4_K super-block: [`QK_K`] weights stored as 4-bit codes in 8 sub-blocks
/// of 32 weights each.
///
/// `BlockQ4K::dequant` reconstructs a weight as `d * sc * q - dmin * mn`, where
/// `sc`/`mn` are the sub-block's 6-bit scale and min and `q` is the 4-bit code.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ4K {
    /// Super-block multiplier applied to the 6-bit scales.
    pub d: f16,
    /// Super-block multiplier applied to the 6-bit mins.
    pub dmin: f16,
    /// Eight 6-bit scales and eight 6-bit mins, packed as laid out by
    /// `encode_q4k_scales` / `decode_q4k_scales`.
    pub scales: [u8; 12],
    /// 4-bit codes, two per byte; even weight index in the low nibble.
    pub qs: [u8; 128],
}

// Compile-time guard: the in-memory size must equal the declared on-disk block size.
const _: () = assert!(std::mem::size_of::<BlockQ4K>() == BLOCK_Q4_K_BYTES);
507
/// Unpacks the 12-byte Q4_K scale field into eight 6-bit scales and eight 6-bit mins.
///
/// Layout of `scales_raw`:
/// - bytes 0..4: low nibbles of the scales, two per byte (even index in the low nibble)
/// - bytes 4..8: low nibbles of the mins, same arrangement
/// - bytes 8..10: top two bits of the scales, four per byte (index `i % 4` at bit `2 * (i % 4)`)
/// - bytes 10..12: top two bits of the mins, same arrangement
///
/// Returns `(scales, mins)`.
fn decode_q4k_scales(scales_raw: &[u8; 12]) -> ([u8; 8], [u8; 8]) {
    // Reassembles value `idx` of a group whose low nibbles start at byte `lo_base`
    // and whose 2-bit high parts start at byte `hi_base`.
    let unpack = |lo_base: usize, hi_base: usize, idx: usize| -> u8 {
        let low = (scales_raw[lo_base + idx / 2] >> (4 * (idx % 2))) & 0x0F;
        let high = (scales_raw[hi_base + idx / 4] >> (2 * (idx % 4))) & 0x03;
        low | (high << 4)
    };

    let mut sc = [0u8; 8];
    let mut mn = [0u8; 8];
    for (idx, (s, m)) in sc.iter_mut().zip(mn.iter_mut()).enumerate() {
        *s = unpack(0, 8, idx);
        *m = unpack(4, 10, idx);
    }

    (sc, mn)
}
551
/// Packs eight 6-bit scales and eight 6-bit mins into the 12-byte Q4_K scale field.
///
/// Layout of the result:
/// - bytes 0..4: low nibbles of the scales, two per byte (even index in the low nibble)
/// - bytes 4..8: low nibbles of the mins, same arrangement
/// - bytes 8..10: top two bits of the scales, four per byte (index `i % 4` at bit `2 * (i % 4)`)
/// - bytes 10..12: top two bits of the mins, same arrangement
///
/// Only the low six bits of each input value are stored; higher bits are dropped.
fn encode_q4k_scales(sc: &[u8; 8], mn: &[u8; 8]) -> [u8; 12] {
    let mut out = [0u8; 12];

    for (idx, (&s, &m)) in sc.iter().zip(mn.iter()).enumerate() {
        // Low four bits: two values share a byte.
        let nibble_shift = 4 * (idx % 2);
        out[idx / 2] |= (s & 0x0F) << nibble_shift;
        out[4 + idx / 2] |= (m & 0x0F) << nibble_shift;

        // Bits 4..6: four values share a byte.
        let pair_shift = 2 * (idx % 4);
        out[8 + idx / 4] |= ((s >> 4) & 0x03) << pair_shift;
        out[10 + idx / 4] |= ((m >> 4) & 0x03) << pair_shift;
    }

    out
}
581
582impl BlockQ4K {
583 pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
587 let expected_len = blocks.len() * QK_K;
588 if output.len() < expected_len {
589 return Err(BonsaiError::KQuantError {
590 reason: format!(
591 "Q4_K dequant: output len {} < expected {}",
592 output.len(),
593 expected_len
594 ),
595 });
596 }
597
598 for (block_idx, block) in blocks.iter().enumerate() {
599 let d = block.d.to_f32();
600 let dmin_val = block.dmin.to_f32();
601 let base = block_idx * QK_K;
602
603 let (sc, mn) = decode_q4k_scales(&block.scales);
604
605 for sub in 0..8 {
607 let sub_scale = d * (sc[sub] as f32);
608 let sub_min = dmin_val * (mn[sub] as f32);
609 let sub_offset = sub * 32;
610
611 for j in 0..32 {
612 let global_idx = sub_offset + j;
613 let byte_idx = global_idx / 2;
614 let q = if global_idx % 2 == 0 {
615 (block.qs[byte_idx] & 0x0F) as f32
616 } else {
617 ((block.qs[byte_idx] >> 4) & 0x0F) as f32
618 };
619 output[base + global_idx] = sub_scale * q - sub_min;
620 }
621 }
622 }
623 Ok(())
624 }
625
626 pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
630 if input.len() % QK_K != 0 {
631 return Err(BonsaiError::KQuantError {
632 reason: format!(
633 "Q4_K quantize: input len {} not a multiple of {}",
634 input.len(),
635 QK_K
636 ),
637 });
638 }
639
640 let num_blocks = input.len() / QK_K;
641 let mut blocks = Vec::with_capacity(num_blocks);
642
643 for block_idx in 0..num_blocks {
644 let base = block_idx * QK_K;
645 let chunk = &input[base..base + QK_K];
646
647 let mut sub_scales = [0.0f32; 8];
649 let mut sub_mins = [0.0f32; 8];
650
651 for sub in 0..8 {
652 let sub_offset = sub * 32;
653 let sub_chunk = &chunk[sub_offset..sub_offset + 32];
654
655 let mut smin = f32::MAX;
656 let mut smax = f32::MIN;
657 for &v in sub_chunk {
658 if v < smin {
659 smin = v;
660 }
661 if v > smax {
662 smax = v;
663 }
664 }
665
666 sub_mins[sub] = if smin < 0.0 { -smin } else { 0.0 };
667 let range = smax + sub_mins[sub];
668 sub_scales[sub] = if range > 0.0 { range / 15.0 } else { 0.0 };
669 }
670
671 let max_scale = sub_scales.iter().copied().fold(0.0f32, f32::max);
672 let max_min = sub_mins.iter().copied().fold(0.0f32, f32::max);
673
674 let d = if max_scale > 0.0 {
676 max_scale / 63.0
677 } else {
678 0.0
679 };
680 let dmin = if max_min > 0.0 { max_min / 63.0 } else { 0.0 };
681
682 let inv_d = if d > 0.0 { 1.0 / d } else { 0.0 };
683 let inv_dmin = if dmin > 0.0 { 1.0 / dmin } else { 0.0 };
684
685 let mut sc = [0u8; 8];
686 let mut mn = [0u8; 8];
687
688 for sub in 0..8 {
689 sc[sub] = (sub_scales[sub] * inv_d + 0.5).min(63.0) as u8;
690 mn[sub] = (sub_mins[sub] * inv_dmin + 0.5).min(63.0) as u8;
691 }
692
693 let scales = encode_q4k_scales(&sc, &mn);
694
695 let mut qs = [0u8; 128];
697 for sub in 0..8 {
698 let sub_offset = sub * 32;
699 let sc_f = d * (sc[sub] as f32);
700 let mn_f = dmin * (mn[sub] as f32);
701 let inv_sc = if sc_f > 0.0 { 1.0 / sc_f } else { 0.0 };
702
703 for j in 0..32 {
704 let global_idx = sub_offset + j;
705 let val = chunk[global_idx] + mn_f;
706 let q = (val * inv_sc + 0.5).clamp(0.0, 15.0) as u8;
707 let byte_idx = global_idx / 2;
708 if global_idx % 2 == 0 {
709 qs[byte_idx] |= q & 0x0F;
710 } else {
711 qs[byte_idx] |= (q & 0x0F) << 4;
712 }
713 }
714 }
715
716 blocks.push(BlockQ4K {
717 d: f16::from_f32(d),
718 dmin: f16::from_f32(dmin),
719 scales,
720 qs,
721 });
722 }
723
724 Ok(blocks)
725 }
726
727 pub fn dequant_row_to_buf(blocks_for_row: &[Self], buf: &mut Vec<f32>) {
731 let start = buf.len();
732 let n = blocks_for_row.len() * QK_K;
733 buf.resize(start + n, 0.0f32);
734 let _ = Self::dequant(blocks_for_row, &mut buf[start..]);
735 }
736
737 pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
742 if data.len() % BLOCK_Q4_K_BYTES != 0 {
743 return Err(BonsaiError::KQuantError {
744 reason: format!(
745 "Q4_K slice_from_bytes: byte len {} not a multiple of {}",
746 data.len(),
747 BLOCK_Q4_K_BYTES
748 ),
749 });
750 }
751 if data.is_empty() {
752 return Ok(&[]);
753 }
754 let align = std::mem::align_of::<Self>();
755 if data.as_ptr().align_offset(align) != 0 {
756 return Err(BonsaiError::KQuantError {
757 reason: format!("Q4_K slice_from_bytes: pointer not {}-byte aligned", align),
758 });
759 }
760 let count = data.len() / BLOCK_Q4_K_BYTES;
761 let ptr = data.as_ptr() as *const Self;
762 Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
766 }
767}
768
/// One Q8_K super-block: [`QK_K`] weights stored as signed 8-bit codes with a
/// single `f32` scale.
///
/// `BlockQ8K::dequant` reconstructs a weight as `d * q`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ8K {
    /// Block scale; `BlockQ8K::quantize` sets it to `max|x| / 127`.
    pub d: f32,
    /// Signed 8-bit codes, one per weight.
    pub qs: [i8; 256],
    /// Sum of the codes in each consecutive group of 16 (as filled by `quantize`).
    pub bsums: [i16; 16],
}

// Compile-time guard: the in-memory size must equal the declared on-disk block size.
const _: () = assert!(std::mem::size_of::<BlockQ8K>() == BLOCK_Q8K_BYTES);
793
794impl BlockQ8K {
795 pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
799 let expected_len = blocks.len() * QK_K;
800 if output.len() < expected_len {
801 return Err(BonsaiError::KQuantError {
802 reason: format!(
803 "Q8_K dequant: output len {} < expected {}",
804 output.len(),
805 expected_len
806 ),
807 });
808 }
809
810 for (block_idx, block) in blocks.iter().enumerate() {
811 let d = block.d;
812 let base = block_idx * QK_K;
813 for i in 0..QK_K {
814 output[base + i] = d * (block.qs[i] as f32);
815 }
816 }
817 Ok(())
818 }
819
820 pub fn dequant_row_to_buf(blocks_for_row: &[Self], buf: &mut Vec<f32>) {
824 let start = buf.len();
825 let n = blocks_for_row.len() * QK_K;
826 buf.resize(start + n, 0.0f32);
827 let _ = Self::dequant(blocks_for_row, &mut buf[start..]);
828 }
829
830 pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
838 if input.len() % QK_K != 0 {
839 return Err(BonsaiError::KQuantError {
840 reason: format!(
841 "Q8_K quantize: input len {} not a multiple of {}",
842 input.len(),
843 QK_K
844 ),
845 });
846 }
847
848 let num_blocks = input.len() / QK_K;
849 let mut blocks = Vec::with_capacity(num_blocks);
850
851 for block_idx in 0..num_blocks {
852 let chunk = &input[block_idx * QK_K..block_idx * QK_K + QK_K];
853
854 let max_abs = chunk.iter().map(|&v| v.abs()).fold(0.0f32, f32::max);
856
857 let d = if max_abs > 0.0 { max_abs / 127.0 } else { 0.0 };
858 let inv_d = if d > 0.0 { 1.0 / d } else { 0.0 };
859
860 let mut qs = [0i8; 256];
861 for (i, &w) in chunk.iter().enumerate() {
862 qs[i] = (w * inv_d).round().clamp(-127.0, 127.0) as i8;
863 }
864
865 let mut bsums = [0i16; 16];
867 for (group, slot) in bsums.iter_mut().enumerate() {
868 let group_start = group * 16;
869 let sum: i32 = qs[group_start..group_start + 16]
870 .iter()
871 .map(|&q| q as i32)
872 .sum();
873 *slot = sum.clamp(i16::MIN as i32, i16::MAX as i32) as i16;
874 }
875
876 blocks.push(BlockQ8K { d, qs, bsums });
877 }
878
879 Ok(blocks)
880 }
881
882 pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
887 if data.len() % BLOCK_Q8K_BYTES != 0 {
888 return Err(BonsaiError::KQuantError {
889 reason: format!(
890 "Q8_K slice_from_bytes: byte len {} not a multiple of {}",
891 data.len(),
892 BLOCK_Q8K_BYTES
893 ),
894 });
895 }
896 if data.is_empty() {
897 return Ok(&[]);
898 }
899 let align = std::mem::align_of::<Self>();
900 if data.as_ptr().align_offset(align) != 0 {
901 return Err(BonsaiError::KQuantError {
902 reason: format!("Q8_K slice_from_bytes: pointer not {}-byte aligned", align),
903 });
904 }
905 let count = data.len() / BLOCK_Q8K_BYTES;
906 let ptr = data.as_ptr() as *const Self;
907 Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
911 }
912}
913
#[cfg(test)]
mod tests {
    use super::*;

    /// One super-block worth (256 values) of identical weights.
    fn constant_block(value: f32) -> Vec<f32> {
        vec![value; 256]
    }

    // ── Q2_K ────────────────────────────────────────────────────────────────

    #[test]
    fn q2k_block_size_correct() {
        let actual = std::mem::size_of::<BlockQ2K>();
        assert_eq!(actual, BLOCK_Q2_K_BYTES);
        assert_eq!(BLOCK_Q2_K_BYTES, 84);
    }

    #[test]
    fn q2k_roundtrip_zero_weights() {
        let blocks = BlockQ2K::quantize(&constant_block(0.0)).expect("quantize ok");
        let mut out = constant_block(0.0);
        BlockQ2K::dequant(&blocks, &mut out).expect("dequant ok");
        for v in out {
            assert!(
                v.abs() < 1e-4,
                "all-zero input should dequant to near-zero, got {v}"
            );
        }
    }

    #[test]
    fn q2k_roundtrip_uniform() {
        let blocks = BlockQ2K::quantize(&constant_block(1.0)).expect("quantize ok");
        let mut out = constant_block(0.0);
        BlockQ2K::dequant(&blocks, &mut out).expect("dequant ok");
        for v in out {
            let err = (v - 1.0).abs();
            assert!(err < 0.2, "uniform round-trip error {err} too high");
        }
    }

    #[test]
    fn q2k_quantize_output_length() {
        let blocks = BlockQ2K::quantize(&constant_block(0.5)).expect("quantize ok");
        assert_eq!(blocks.len(), 1);
    }

    #[test]
    fn q2k_slice_from_bytes_empty() {
        let data = Vec::<u8>::new();
        let result = BlockQ2K::slice_from_bytes(&data).expect("empty slice ok");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn q2k_slice_from_bytes_bad_length() {
        // One byte short of a full block.
        let data = [0u8; 83];
        assert!(BlockQ2K::slice_from_bytes(&data).is_err());
    }

    // ── Q4_K ────────────────────────────────────────────────────────────────

    #[test]
    fn q4k_block_size_correct() {
        let actual = std::mem::size_of::<BlockQ4K>();
        assert_eq!(actual, BLOCK_Q4_K_BYTES);
        assert_eq!(BLOCK_Q4_K_BYTES, 144);
    }

    #[test]
    fn q4k_scale_encode_decode_roundtrip() {
        let sc = [1, 2, 3, 4, 5, 63, 32, 0];
        let mn = [10, 20, 30, 40, 50, 60, 15, 7];
        let (sc2, mn2) = decode_q4k_scales(&encode_q4k_scales(&sc, &mn));
        assert_eq!(sc, sc2);
        assert_eq!(mn, mn2);
    }

    #[test]
    fn q4k_scale_encode_decode_all_zeros() {
        let sc = [0u8; 8];
        let mn = [0u8; 8];
        let (sc2, mn2) = decode_q4k_scales(&encode_q4k_scales(&sc, &mn));
        assert_eq!(sc, sc2);
        assert_eq!(mn, mn2);
    }

    #[test]
    fn q4k_scale_encode_decode_max_values() {
        let sc = [63u8; 8];
        let mn = [63u8; 8];
        let (sc2, mn2) = decode_q4k_scales(&encode_q4k_scales(&sc, &mn));
        assert_eq!(sc, sc2);
        assert_eq!(mn, mn2);
    }

    #[test]
    fn q4k_slice_from_bytes_empty() {
        let data = Vec::<u8>::new();
        let result = BlockQ4K::slice_from_bytes(&data).expect("empty slice ok");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn q4k_slice_from_bytes_bad_length() {
        // 100 is not a multiple of the 144-byte block size.
        let data = [0u8; 100];
        assert!(BlockQ4K::slice_from_bytes(&data).is_err());
    }

    // ── Q3_K ────────────────────────────────────────────────────────────────

    #[test]
    fn q3k_block_size_assertion() {
        let actual = std::mem::size_of::<BlockQ3K>();
        assert_eq!(actual, BLOCK_Q3K_BYTES);
        assert_eq!(BLOCK_Q3K_BYTES, 110);
    }

    #[test]
    fn q3k_roundtrip_zero_weights() {
        let blocks = BlockQ3K::quantize(&constant_block(0.0)).expect("quantize ok");
        let mut out = constant_block(0.0);
        BlockQ3K::dequant(&blocks, &mut out).expect("dequant ok");
        for v in out {
            assert!(
                v.abs() < 1e-4,
                "all-zero input should dequant to near-zero, got {v}"
            );
        }
    }

    #[test]
    fn q3k_roundtrip_uniform() {
        let expected = 1.0f32;
        let blocks = BlockQ3K::quantize(&constant_block(expected)).expect("quantize ok");
        let mut out = constant_block(0.0);
        BlockQ3K::dequant(&blocks, &mut out).expect("dequant ok");
        for v in out {
            let err = (v - expected).abs() / expected;
            assert!(
                err < 0.5,
                "uniform round-trip rel error {err} too high, got {v}"
            );
        }
    }

    #[test]
    fn q3k_slice_from_bytes() {
        // Heap-backed on purpose: the allocation's alignment satisfies the block's.
        let data = vec![0u8; BLOCK_Q3K_BYTES];
        let result = BlockQ3K::slice_from_bytes(&data).expect("single block should parse");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn q3k_slice_from_bytes_empty() {
        let data = Vec::<u8>::new();
        let result = BlockQ3K::slice_from_bytes(&data).expect("empty slice ok");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn q3k_slice_from_bytes_bad_length() {
        // 100 is not a multiple of the 110-byte block size.
        let data = [0u8; 100];
        assert!(BlockQ3K::slice_from_bytes(&data).is_err());
    }

    #[test]
    fn q3k_quantize_output_length() {
        let blocks = BlockQ3K::quantize(&constant_block(0.5)).expect("quantize ok");
        assert_eq!(blocks.len(), 1, "256 weights → 1 block");
    }

    #[test]
    fn q3k_quantize_non_multiple_errors() {
        let too_few = [1.0f32; 100];
        assert!(BlockQ3K::quantize(&too_few).is_err());
    }

    #[test]
    fn q3k_dequant_output_too_small_errors() {
        let blocks = BlockQ3K::quantize(&constant_block(1.0)).expect("quantize ok");
        let mut short = [0.0f32; 100];
        assert!(BlockQ3K::dequant(&blocks, &mut short).is_err());
    }

    #[test]
    fn q3k_dequant_row_to_buf_works() {
        let blocks = BlockQ3K::quantize(&constant_block(0.5)).expect("quantize ok");
        let mut buf = Vec::new();
        BlockQ3K::dequant_row_to_buf(&blocks, &mut buf);
        assert_eq!(buf.len(), 256);
    }

    // ── Q8_K ────────────────────────────────────────────────────────────────

    #[test]
    fn q8k_block_size_assertion() {
        let actual = std::mem::size_of::<BlockQ8K>();
        assert_eq!(actual, BLOCK_Q8K_BYTES);
        assert_eq!(BLOCK_Q8K_BYTES, 292);
    }

    #[test]
    fn q8k_roundtrip_zero_weights() {
        let blocks = BlockQ8K::quantize(&constant_block(0.0)).expect("quantize ok");
        let mut out = constant_block(0.0);
        BlockQ8K::dequant(&blocks, &mut out).expect("dequant ok");
        for v in out {
            assert!(
                v.abs() < 1e-6,
                "all-zero input should dequant to exactly zero, got {v}"
            );
        }
    }

    #[test]
    fn q8k_roundtrip_uniform() {
        let blocks = BlockQ8K::quantize(&constant_block(1.0)).expect("quantize ok");
        let mut out = constant_block(0.0);
        BlockQ8K::dequant(&blocks, &mut out).expect("dequant ok");
        for v in out {
            let err = (v - 1.0).abs();
            assert!(err < 0.02, "Q8_K uniform round-trip error {err} too high");
        }
    }

    #[test]
    fn q8k_slice_from_bytes() {
        // Heap-backed on purpose: the allocation's alignment satisfies the block's.
        let data = vec![0u8; BLOCK_Q8K_BYTES];
        let result = BlockQ8K::slice_from_bytes(&data).expect("single block should parse");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn q8k_slice_from_bytes_empty() {
        let data = Vec::<u8>::new();
        let result = BlockQ8K::slice_from_bytes(&data).expect("empty slice ok");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn q8k_slice_from_bytes_bad_length() {
        // 100 is not a multiple of the 292-byte block size.
        let data = [0u8; 100];
        assert!(BlockQ8K::slice_from_bytes(&data).is_err());
    }

    #[test]
    fn q8k_quantize_output_length() {
        let blocks = BlockQ8K::quantize(&constant_block(0.5)).expect("quantize ok");
        assert_eq!(blocks.len(), 1, "256 weights → 1 block");
    }

    #[test]
    fn q8k_quantize_non_multiple_errors() {
        let too_few = [1.0f32; 100];
        assert!(BlockQ8K::quantize(&too_few).is_err());
    }

    #[test]
    fn q8k_dequant_output_too_small_errors() {
        let blocks = BlockQ8K::quantize(&constant_block(1.0)).expect("quantize ok");
        let mut short = [0.0f32; 100];
        assert!(BlockQ8K::dequant(&blocks, &mut short).is_err());
    }

    #[test]
    fn q8k_dequant_row_to_buf_works() {
        let blocks = BlockQ8K::quantize(&constant_block(0.5)).expect("quantize ok");
        let mut buf = Vec::new();
        BlockQ8K::dequant_row_to_buf(&blocks, &mut buf);
        assert_eq!(buf.len(), 256);
        for v in buf {
            assert!((v - 0.5).abs() < 0.01, "expected ~0.5, got {v}");
        }
    }

    #[test]
    fn q8k_bsums_roundtrip_sign() {
        let blocks_pos = BlockQ8K::quantize(&constant_block(0.5)).expect("quantize ok");
        for bs in blocks_pos[0].bsums.iter().copied() {
            assert!(
                bs > 0,
                "positive input should yield positive bsums, got {bs}"
            );
        }

        let blocks_neg = BlockQ8K::quantize(&constant_block(-0.5)).expect("quantize ok");
        for bs in blocks_neg[0].bsums.iter().copied() {
            assert!(
                bs < 0,
                "negative input should yield negative bsums, got {bs}"
            );
        }
    }
}