1use half::f16;
15
16use crate::error::{BonsaiError, BonsaiResult};
17
/// Number of f32 weights packed into one Q4_0 block.
pub const QK_Q4_0: usize = 32;

/// Byte size of one Q4_0 block: 2-byte f16 scale + 16 bytes of packed nibbles.
pub const BLOCK_Q4_0_BYTES: usize = 18;

/// Number of f32 weights packed into one Q8_0 block.
pub const QK_Q8_0: usize = 32;

/// Byte size of one Q8_0 block: 2-byte f16 scale + 32 signed 8-bit quants.
pub const BLOCK_Q8_0_BYTES: usize = 34;
33
/// One Q4_0 block: 32 weights stored as unsigned 4-bit codes sharing one
/// f16 scale.
///
/// Each weight decodes as `d * (nibble - 8)`; element `2k` lives in the low
/// nibble of `qs[k]` and element `2k + 1` in the high nibble.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ4_0 {
    /// Shared f16 scale for the 32 weights in this block.
    pub d: f16,
    /// 16 bytes holding 32 packed 4-bit quantized values.
    pub qs: [u8; 16],
}

// Compile-time layout guard: #[repr(C)] gives 2 + 16 = 18 bytes with no
// padding, which slice_from_bytes relies on for the raw reinterpret cast.
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == BLOCK_Q4_0_BYTES);
56
57impl BlockQ4_0 {
58 pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
62 let expected_len = blocks.len() * QK_Q4_0;
63 if output.len() < expected_len {
64 return Err(BonsaiError::KQuantError {
65 reason: format!(
66 "Q4_0 dequant: output len {} < expected {}",
67 output.len(),
68 expected_len
69 ),
70 });
71 }
72 for (block_idx, block) in blocks.iter().enumerate() {
73 let d = block.d.to_f32();
74 let base = block_idx * QK_Q4_0;
75 for j in 0..QK_Q4_0 {
76 let nibble = if j % 2 == 0 {
77 (block.qs[j / 2] & 0x0F) as f32
78 } else {
79 ((block.qs[j / 2] >> 4) & 0x0F) as f32
80 };
81 output[base + j] = d * (nibble - 8.0);
82 }
83 }
84 Ok(())
85 }
86
87 pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
93 if input.len() % QK_Q4_0 != 0 {
94 return Err(BonsaiError::KQuantError {
95 reason: format!(
96 "Q4_0 quantize: input len {} not a multiple of {}",
97 input.len(),
98 QK_Q4_0
99 ),
100 });
101 }
102 let num_blocks = input.len() / QK_Q4_0;
103 let mut blocks = Vec::with_capacity(num_blocks);
104
105 for block_idx in 0..num_blocks {
106 let base = block_idx * QK_Q4_0;
107 let chunk = &input[base..base + QK_Q4_0];
108
109 let max_abs = chunk
110 .iter()
111 .filter(|v| !v.is_nan())
112 .map(|v| v.abs())
113 .fold(0.0f32, f32::max);
114
115 if max_abs == 0.0 {
116 blocks.push(BlockQ4_0 {
117 d: f16::ZERO,
118 qs: [0x88u8; 16], });
120 continue;
121 }
122
123 let scale = max_abs / 7.0;
125 let d = f16::from_f32(scale);
126 let scale_actual = d.to_f32();
128 let inv_scale = if scale_actual == 0.0 {
129 0.0
130 } else {
131 1.0 / scale_actual
132 };
133
134 let mut qs = [0u8; 16];
135 for j in 0..QK_Q4_0 {
136 let v = chunk[j];
137 let nibble = (v * inv_scale + 8.5).clamp(0.0, 15.0) as u8;
139 if j % 2 == 0 {
140 qs[j / 2] = nibble & 0x0F;
141 } else {
142 qs[j / 2] |= (nibble & 0x0F) << 4;
143 }
144 }
145
146 blocks.push(BlockQ4_0 { d, qs });
147 }
148 Ok(blocks)
149 }
150
151 pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
159 if data.len() % BLOCK_Q4_0_BYTES != 0 {
160 return Err(BonsaiError::KQuantError {
161 reason: format!(
162 "Q4_0 slice_from_bytes: byte len {} not a multiple of {}",
163 data.len(),
164 BLOCK_Q4_0_BYTES
165 ),
166 });
167 }
168 let align = std::mem::align_of::<Self>();
169 if data.as_ptr().align_offset(align) != 0 {
170 return Err(BonsaiError::KQuantError {
171 reason: format!("Q4_0 slice_from_bytes: pointer not {}-byte aligned", align),
172 });
173 }
174 let count = data.len() / BLOCK_Q4_0_BYTES;
175 let ptr = data.as_ptr() as *const Self;
176 Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
179 }
180
181 #[inline]
185 pub fn dequant_to_buf(&self, buf: &mut [f32; 32]) {
186 let d = self.d.to_f32();
187 for (j, out) in buf.iter_mut().enumerate() {
188 let nibble = if j % 2 == 0 {
189 (self.qs[j / 2] & 0x0F) as f32
190 } else {
191 ((self.qs[j / 2] >> 4) & 0x0F) as f32
192 };
193 *out = d * (nibble - 8.0);
194 }
195 }
196}
197
/// One Q8_0 block: 32 weights stored as signed 8-bit codes sharing one
/// f16 scale. Each weight decodes as `d * (q as f32)`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ8_0 {
    /// Shared f16 scale for the 32 weights in this block.
    pub d: f16,
    /// 32 signed 8-bit quantized values, one per weight.
    pub qs: [i8; 32],
}

// Compile-time layout guard: #[repr(C)] gives 2 + 32 = 34 bytes with no
// padding, which slice_from_bytes relies on for the raw reinterpret cast.
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == BLOCK_Q8_0_BYTES);
219
220impl BlockQ8_0 {
221 pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
225 let expected_len = blocks.len() * QK_Q8_0;
226 if output.len() < expected_len {
227 return Err(BonsaiError::KQuantError {
228 reason: format!(
229 "Q8_0 dequant: output len {} < expected {}",
230 output.len(),
231 expected_len
232 ),
233 });
234 }
235 for (block_idx, block) in blocks.iter().enumerate() {
236 let d = block.d.to_f32();
237 let base = block_idx * QK_Q8_0;
238 for (j, &q) in block.qs.iter().enumerate() {
239 output[base + j] = d * (q as f32);
240 }
241 }
242 Ok(())
243 }
244
245 pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
251 if input.len() % QK_Q8_0 != 0 {
252 return Err(BonsaiError::KQuantError {
253 reason: format!(
254 "Q8_0 quantize: input len {} not a multiple of {}",
255 input.len(),
256 QK_Q8_0
257 ),
258 });
259 }
260 let num_blocks = input.len() / QK_Q8_0;
261 let mut blocks = Vec::with_capacity(num_blocks);
262
263 for block_idx in 0..num_blocks {
264 let base = block_idx * QK_Q8_0;
265 let chunk = &input[base..base + QK_Q8_0];
266
267 let max_abs = chunk
268 .iter()
269 .filter(|v| !v.is_nan())
270 .map(|v| v.abs())
271 .fold(0.0f32, f32::max);
272
273 if max_abs == 0.0 {
274 blocks.push(BlockQ8_0 {
275 d: f16::ZERO,
276 qs: [0i8; 32],
277 });
278 continue;
279 }
280
281 let scale = max_abs / 127.0;
282 let d = f16::from_f32(scale);
283 let scale_actual = d.to_f32();
284 let inv_scale = if scale_actual == 0.0 {
285 0.0
286 } else {
287 1.0 / scale_actual
288 };
289
290 let mut qs = [0i8; 32];
291 for (j, &v) in chunk.iter().enumerate() {
292 let q = (v * inv_scale).round().clamp(-127.0, 127.0) as i8;
293 qs[j] = q;
294 }
295
296 blocks.push(BlockQ8_0 { d, qs });
297 }
298 Ok(blocks)
299 }
300
301 pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
309 if data.len() % BLOCK_Q8_0_BYTES != 0 {
310 return Err(BonsaiError::KQuantError {
311 reason: format!(
312 "Q8_0 slice_from_bytes: byte len {} not a multiple of {}",
313 data.len(),
314 BLOCK_Q8_0_BYTES
315 ),
316 });
317 }
318 let align = std::mem::align_of::<Self>();
319 if data.as_ptr().align_offset(align) != 0 {
320 return Err(BonsaiError::KQuantError {
321 reason: format!("Q8_0 slice_from_bytes: pointer not {}-byte aligned", align),
322 });
323 }
324 let count = data.len() / BLOCK_Q8_0_BYTES;
325 let ptr = data.as_ptr() as *const Self;
326 Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
329 }
330
331 #[inline]
335 pub fn dequant_to_buf(&self, buf: &mut [f32; 32]) {
336 let d = self.d.to_f32();
337 for (j, &q) in self.qs.iter().enumerate() {
338 buf[j] = d * (q as f32);
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    // ---- layout & constant sanity ----

    #[test]
    fn q4_0_block_size_correct() {
        assert_eq!(std::mem::size_of::<BlockQ4_0>(), BLOCK_Q4_0_BYTES);
        assert_eq!(BLOCK_Q4_0_BYTES, 18);
    }

    #[test]
    fn q8_0_block_size_correct() {
        assert_eq!(std::mem::size_of::<BlockQ8_0>(), BLOCK_Q8_0_BYTES);
        assert_eq!(BLOCK_Q8_0_BYTES, 34);
    }

    #[test]
    fn qk_constants_correct() {
        assert_eq!(QK_Q4_0, 32);
        assert_eq!(QK_Q8_0, 32);
    }

    // ---- Q4_0 ----

    #[test]
    fn q4_0_dequant_roundtrip() {
        let values: Vec<f32> = (0..32).map(|i| (i as f32) * 0.5 - 7.5).collect();
        let blocks = BlockQ4_0::quantize(&values).unwrap();
        assert_eq!(blocks.len(), 1);
        let mut output = vec![0.0f32; 32];
        BlockQ4_0::dequant(&blocks, &mut output).unwrap();
        let max_err: f32 = values
            .iter()
            .zip(output.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(
            max_err < 1.5,
            "Q4_0 round-trip max error: {max_err} (values range ±7.5)"
        );
    }

    #[test]
    fn q4_0_all_zeros() {
        let values = vec![0.0f32; 32];
        let blocks = BlockQ4_0::quantize(&values).unwrap();
        let mut output = vec![0.0f32; 32];
        BlockQ4_0::dequant(&blocks, &mut output).unwrap();
        assert!(
            output.iter().all(|&x| x == 0.0),
            "all-zero input should give all-zero output"
        );
    }

    #[test]
    fn q4_0_nibble_extremes() {
        // Push one value to each end of the representable range.
        let mut values = vec![0.0f32; 32];
        values[0] = 7.0;
        values[1] = -7.0;
        let blocks = BlockQ4_0::quantize(&values).unwrap();
        let mut output = vec![0.0f32; 32];
        BlockQ4_0::dequant(&blocks, &mut output).unwrap();
        assert!(
            (output[0] - 7.0).abs() < 1.1,
            "max weight round-trip: got {}",
            output[0]
        );
        assert!(
            (output[1] + 7.0).abs() < 1.1,
            "min weight round-trip: got {}",
            output[1]
        );
    }

    #[test]
    fn q4_0_slice_from_bytes_valid() {
        let block = BlockQ4_0 {
            d: f16::from_f32(1.0),
            qs: [0x88u8; 16],
        };
        let bytes: &[u8] = unsafe {
            std::slice::from_raw_parts((&block as *const BlockQ4_0).cast::<u8>(), BLOCK_Q4_0_BYTES)
        };
        let result = BlockQ4_0::slice_from_bytes(bytes).expect("aligned slice should succeed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].d, f16::from_f32(1.0));
    }

    #[test]
    fn q4_0_slice_from_bytes_bad_len() {
        // 17 bytes is not a multiple of the 18-byte block size.
        let data = vec![0u8; 17];
        assert!(
            BlockQ4_0::slice_from_bytes(&data).is_err(),
            "bad length should be rejected"
        );
    }

    #[test]
    fn q4_0_block_count_validation() {
        // 96 values -> exactly 3 blocks of 32.
        let values = vec![1.0f32; 96];
        let blocks = BlockQ4_0::quantize(&values).unwrap();
        assert_eq!(blocks.len(), 3);
    }

    #[test]
    fn q4_0_quantize_wrong_len() {
        assert!(
            BlockQ4_0::quantize(&[1.0f32; 15]).is_err(),
            "non-multiple of 32 should be rejected"
        );
    }

    #[test]
    fn q4_0_dequant_too_small_buffer() {
        let blocks = BlockQ4_0::quantize(&[1.0f32; 32]).unwrap();
        let mut out = vec![0.0f32; 10];
        assert!(
            BlockQ4_0::dequant(&blocks, &mut out).is_err(),
            "output too small should be rejected"
        );
    }

    #[test]
    fn q4_0_dequant_to_buf_matches_dequant() {
        let values: Vec<f32> = (0..32).map(|i| (i as f32) - 16.0).collect();
        let blocks = BlockQ4_0::quantize(&values).unwrap();
        let mut full_out = vec![0.0f32; 32];
        BlockQ4_0::dequant(&blocks, &mut full_out).unwrap();
        let mut buf = [0.0f32; 32];
        blocks[0].dequant_to_buf(&mut buf);
        for (a, b) in full_out.iter().zip(buf.iter()) {
            assert!((a - b).abs() < 1e-6, "dequant_to_buf must match dequant");
        }
    }

    #[test]
    fn q4_0_multi_block_no_nan() {
        let values: Vec<f32> = (0..64).map(|i| (i as f32) * 0.25 - 8.0).collect();
        let blocks = BlockQ4_0::quantize(&values).unwrap();
        assert_eq!(blocks.len(), 2);
        let mut out = vec![0.0f32; 64];
        BlockQ4_0::dequant(&blocks, &mut out).unwrap();
        assert!(out.iter().all(|x| !x.is_nan()), "no NaN in output");
    }

    #[test]
    fn q4_0_scale_nonzero_for_nonzero_input() {
        let values = vec![1.0f32; 32];
        let blocks = BlockQ4_0::quantize(&values).unwrap();
        assert_ne!(blocks[0].d, f16::ZERO, "scale must be non-zero");
    }

    // ---- Q8_0 ----

    #[test]
    fn q8_0_dequant_roundtrip() {
        let values: Vec<f32> = (0..32).map(|i| (i as f32) * 0.1 - 1.6).collect();
        let blocks = BlockQ8_0::quantize(&values).unwrap();
        let mut output = vec![0.0f32; 32];
        BlockQ8_0::dequant(&blocks, &mut output).unwrap();
        let max_err: f32 = values
            .iter()
            .zip(output.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(
            max_err < 0.05,
            "Q8_0 round-trip max error: {max_err} (8-bit should be very accurate)"
        );
    }

    #[test]
    fn q8_0_all_zeros() {
        let values = vec![0.0f32; 32];
        let blocks = BlockQ8_0::quantize(&values).unwrap();
        let mut output = vec![0.0f32; 32];
        BlockQ8_0::dequant(&blocks, &mut output).unwrap();
        assert!(output.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn q8_0_int8_extremes() {
        // max_abs = 127 -> scale = 127 / 127 = 1.0.
        let mut values = vec![0.0f32; 32];
        values[0] = 127.0;
        values[1] = -127.0;
        let blocks = BlockQ8_0::quantize(&values).unwrap();
        let scale = blocks[0].d.to_f32();
        assert!((scale - 1.0).abs() < 0.01, "scale should be ~1.0: {scale}");
        assert_eq!(blocks[0].qs[0], 127, "max quantized to 127");
        assert_eq!(blocks[0].qs[1], -127, "min quantized to -127");
    }

    #[test]
    fn q8_0_slice_alignment() {
        let block = BlockQ8_0 {
            d: f16::from_f32(2.0),
            qs: [0i8; 32],
        };
        let bytes: &[u8] = unsafe {
            std::slice::from_raw_parts((&block as *const BlockQ8_0).cast::<u8>(), BLOCK_Q8_0_BYTES)
        };
        let result = BlockQ8_0::slice_from_bytes(bytes).expect("aligned slice should succeed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].d, f16::from_f32(2.0));
    }

    #[test]
    fn q8_0_quantize_scale() {
        // max_abs = 63.5 -> scale = 63.5 / 127 = 0.5.
        let mut values = vec![0.0f32; 32];
        values[5] = 63.5;
        let blocks = BlockQ8_0::quantize(&values).unwrap();
        let scale = blocks[0].d.to_f32();
        assert!(
            (scale - 0.5).abs() < 0.02,
            "scale should be ~0.5 for max=63.5, got {scale}"
        );
    }

    #[test]
    fn q8_0_slice_bad_len() {
        // 35 bytes is not a multiple of the 34-byte block size.
        let data = vec![0u8; 35];
        assert!(BlockQ8_0::slice_from_bytes(&data).is_err());
    }

    #[test]
    fn q8_0_quantize_wrong_len() {
        assert!(BlockQ8_0::quantize(&[1.0f32; 17]).is_err());
    }

    #[test]
    fn q8_0_dequant_too_small_buffer() {
        let blocks = BlockQ8_0::quantize(&[0.0f32; 32]).unwrap();
        let mut out = vec![0.0f32; 5];
        assert!(BlockQ8_0::dequant(&blocks, &mut out).is_err());
    }

    #[test]
    fn q8_0_dequant_to_buf_matches_dequant() {
        let values: Vec<f32> = (0..32).map(|i| (i as f32) * 3.0 - 48.0).collect();
        let blocks = BlockQ8_0::quantize(&values).unwrap();
        let mut full_out = vec![0.0f32; 32];
        BlockQ8_0::dequant(&blocks, &mut full_out).unwrap();
        let mut buf = [0.0f32; 32];
        blocks[0].dequant_to_buf(&mut buf);
        for (a, b) in full_out.iter().zip(buf.iter()) {
            assert!((a - b).abs() < 1e-6, "dequant_to_buf must match dequant");
        }
    }

    #[test]
    fn q8_0_positive_negative_mix() {
        let values: Vec<f32> = (0..32)
            .map(|i| if i % 2 == 0 { i as f32 } else { -(i as f32) })
            .collect();
        let blocks = BlockQ8_0::quantize(&values).unwrap();
        let mut out = vec![0.0f32; 32];
        BlockQ8_0::dequant(&blocks, &mut out).unwrap();
        // Index 0 quantizes from 0.0, so the sign check starts at 2.
        for i in (2..32).step_by(2) {
            assert!(
                out[i] >= 0.0,
                "even index should be non-negative: {}",
                out[i]
            );
        }
        for i in (1..32).step_by(2) {
            assert!(
                out[i] <= 0.0,
                "odd index should be non-positive: {}",
                out[i]
            );
        }
    }

    #[test]
    fn q8_0_block_count_correct() {
        // 96 values -> exactly 3 blocks of 32.
        let values = vec![1.0f32; 96];
        let blocks = BlockQ8_0::quantize(&values).unwrap();
        assert_eq!(blocks.len(), 3);
    }
}