1use alloc::vec;
36use alloc::vec::Vec;
37
/// Quantization precision mode.
///
/// Each mode fixes how many discrete levels a weight is mapped to and how
/// many quantized values are packed into a single `u32` word.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum QuantMode {
    /// 8 bits per value: 256 levels, 4 values per 32-bit word.
    Bits8,
    /// ~3.5 bits per value: 11 levels, 7 values per 32-bit word (base-11).
    Bits3_5,
    /// ~2.5 bits per value: 5 levels, 13 values per 32-bit word (base-5).
    Bits2_5,
}

impl QuantMode {
    /// Number of quantization levels for this mode.
    #[inline]
    fn n_levels(self) -> u32 {
        match self {
            QuantMode::Bits8 => N_LEVELS_8,
            QuantMode::Bits3_5 => N_LEVELS_3_5,
            QuantMode::Bits2_5 => N_LEVELS_2_5,
        }
    }

    /// How many quantized values fit in one packed `u32` word.
    #[inline]
    fn values_per_word(self) -> usize {
        match self {
            QuantMode::Bits8 => VALUES_PER_WORD_8,
            QuantMode::Bits3_5 => VALUES_PER_WORD_3_5,
            QuantMode::Bits2_5 => VALUES_PER_WORD_2_5,
        }
    }

    /// Stable on-disk tag for this mode, written by `to_bytes`.
    #[inline]
    fn to_u32(self) -> u32 {
        match self {
            QuantMode::Bits8 => 0,
            QuantMode::Bits3_5 => 1,
            QuantMode::Bits2_5 => 2,
        }
    }

    /// Inverse of `to_u32`; `None` for an unrecognized tag.
    #[inline]
    fn from_u32(v: u32) -> Option<Self> {
        match v {
            0 => Some(QuantMode::Bits8),
            1 => Some(QuantMode::Bits3_5),
            2 => Some(QuantMode::Bits2_5),
            _ => None,
        }
    }
}
99
// 8-bit mode: one byte per value, four bytes per u32 word.
const N_LEVELS_8: u32 = 256;
const VALUES_PER_WORD_8: usize = 4;

// ~3.5-bit mode: positional base-11 packing.
// 11^7 = 19_487_171 < 2^32, so 7 values fit in one word.
const N_LEVELS_3_5: u32 = 11;
const VALUES_PER_WORD_3_5: usize = 7;

// ~2.5-bit mode: positional base-5 packing.
// 5^13 = 1_220_703_125 < 2^32, so 13 values fit in one word.
const N_LEVELS_2_5: u32 = 5;
const VALUES_PER_WORD_2_5: usize = 13;
118
119#[inline]
/// Pack up to four byte values into one little-endian `u32` word.
///
/// Values beyond the first four are ignored; missing positions stay zero.
#[inline]
fn pack4_bytes(values: &[u8]) -> u32 {
    let mut bytes = [0u8; 4];
    for (slot, &v) in bytes.iter_mut().zip(values) {
        *slot = v;
    }
    u32::from_le_bytes(bytes)
}
132
133#[inline]
/// Unpack the first `count` byte values from a little-endian `u32` word.
///
/// Positions at or beyond `count` are left zero.
#[inline]
fn unpack4_bytes(packed: u32, count: usize) -> [u8; 4] {
    let mut values = [0u8; 4];
    for (slot, byte) in values.iter_mut().zip(packed.to_le_bytes()).take(count) {
        *slot = byte;
    }
    values
}
142
143#[inline]
154fn pack7(values: &[u8]) -> u32 {
155 debug_assert!(values.len() <= 7);
156 let mut packed: u32 = 0;
157 for &v in values.iter().rev() {
158 debug_assert!(v < N_LEVELS_3_5 as u8);
159 packed = packed * N_LEVELS_3_5 + v as u32;
160 }
161 packed
162}
163
164#[inline]
169fn unpack7(packed: u32, count: usize) -> [u8; 7] {
170 let mut values = [0u8; 7];
171 let mut p = packed;
172 for v in values.iter_mut().take(count) {
173 *v = (p % N_LEVELS_3_5) as u8;
174 p /= N_LEVELS_3_5;
175 }
176 values
177}
178
179#[inline]
188fn pack13(values: &[u8]) -> u32 {
189 debug_assert!(values.len() <= 13);
190 let mut packed: u32 = 0;
191 for &v in values.iter().rev() {
192 debug_assert!(v < N_LEVELS_2_5 as u8);
193 packed = packed * N_LEVELS_2_5 + v as u32;
194 }
195 packed
196}
197
198#[inline]
203fn unpack13(packed: u32, count: usize) -> [u8; 13] {
204 let mut values = [0u8; 13];
205 let mut p = packed;
206 for v in values.iter_mut().take(count) {
207 *v = (p % N_LEVELS_2_5) as u8;
208 p /= N_LEVELS_2_5;
209 }
210 values
211}
212
213#[inline]
/// Pack one chunk of quantized values into a `u32` using the encoding
/// selected by `mode`. `values.len()` must not exceed
/// `mode.values_per_word()` (enforced by debug_asserts in the packers).
#[inline]
fn pack_word(values: &[u8], mode: QuantMode) -> u32 {
    match mode {
        QuantMode::Bits8 => pack4_bytes(values),
        QuantMode::Bits3_5 => pack7(values),
        QuantMode::Bits2_5 => pack13(values),
    }
}
226
/// Owned quantized weight vector with random-rotation preprocessing.
pub struct TurboQuantized {
    /// Packed quantized levels, `mode.values_per_word()` values per word.
    packed: Vec<u32>,
    /// Number of original (unpadded) weights.
    n_weights: usize,
    /// Dequantization scale: weight ≈ level * scale + offset.
    scale: f64,
    /// Dequantization offset (the minimum rotated weight).
    offset: f64,
    /// Seed for the sign-flip half of the random rotation.
    seed: u64,
    /// Power-of-two length the weights were padded to before the FWHT.
    padded_len: usize,
    /// Quantization precision mode.
    mode: QuantMode,
}
258
259impl TurboQuantized {
260 pub fn predict(&self, features: &[f64]) -> f64 {
266 if self.n_weights == 0 {
267 return 0.0;
268 }
269 let mut rotated_features = Vec::with_capacity(self.padded_len);
271 let use_len = self.n_weights.min(features.len());
272 rotated_features.extend_from_slice(&features[..use_len]);
273 rotated_features.resize(self.padded_len, 0.0);
274 apply_rotation(&mut rotated_features, self.seed);
275
276 self.dot_with_rotated(&rotated_features)
277 }
278
279 pub fn predict_with_scratch(&self, features: &[f64], scratch: &mut [f64]) -> f64 {
284 if self.n_weights == 0 {
285 return 0.0;
286 }
287 assert!(
288 scratch.len() >= self.padded_len,
289 "scratch buffer too small: {} < {}",
290 scratch.len(),
291 self.padded_len
292 );
293
294 for v in scratch[..self.padded_len].iter_mut() {
296 *v = 0.0;
297 }
298 let use_len = self.n_weights.min(features.len());
299 scratch[..use_len].copy_from_slice(&features[..use_len]);
300
301 apply_rotation(&mut scratch[..self.padded_len], self.seed);
303
304 self.dot_with_rotated(&scratch[..self.padded_len])
305 }
306
307 fn dot_with_rotated(&self, rotated_features: &[f64]) -> f64 {
309 let mut sum = 0.0;
310 let mut feat_idx = 0;
311 let vpw = self.mode.values_per_word();
312
313 for &word in self.packed.iter() {
314 let remaining = self.padded_len - feat_idx;
315 let count = remaining.min(vpw);
316 match self.mode {
318 QuantMode::Bits8 => {
319 let values = unpack4_bytes(word, count);
320 for &q in values.iter().take(count) {
321 let w = q as f64 * self.scale + self.offset;
322 sum += w * rotated_features[feat_idx];
323 feat_idx += 1;
324 }
325 }
326 QuantMode::Bits3_5 => {
327 let values = unpack7(word, count);
328 for &q in values.iter().take(count) {
329 let w = q as f64 * self.scale + self.offset;
330 sum += w * rotated_features[feat_idx];
331 feat_idx += 1;
332 }
333 }
334 QuantMode::Bits2_5 => {
335 let values = unpack13(word, count);
336 for &q in values.iter().take(count) {
337 let w = q as f64 * self.scale + self.offset;
338 sum += w * rotated_features[feat_idx];
339 feat_idx += 1;
340 }
341 }
342 }
343 if feat_idx >= self.padded_len {
344 break;
345 }
346 }
347 sum
348 }
349
350 pub fn dequantize(&self) -> Vec<f64> {
355 let mut rotated = Vec::with_capacity(self.padded_len);
356 let mut count_total = 0;
357 let vpw = self.mode.values_per_word();
358
359 for &word in self.packed.iter() {
360 let remaining = self.padded_len - count_total;
361 let count = remaining.min(vpw);
362 match self.mode {
363 QuantMode::Bits8 => {
364 let values = unpack4_bytes(word, count);
365 for &q in values.iter().take(count) {
366 rotated.push(q as f64 * self.scale + self.offset);
367 count_total += 1;
368 }
369 }
370 QuantMode::Bits3_5 => {
371 let values = unpack7(word, count);
372 for &q in values.iter().take(count) {
373 rotated.push(q as f64 * self.scale + self.offset);
374 count_total += 1;
375 }
376 }
377 QuantMode::Bits2_5 => {
378 let values = unpack13(word, count);
379 for &q in values.iter().take(count) {
380 rotated.push(q as f64 * self.scale + self.offset);
381 count_total += 1;
382 }
383 }
384 }
385 if count_total >= self.padded_len {
386 break;
387 }
388 }
389 apply_inverse_rotation(&mut rotated, self.seed);
391 rotated.truncate(self.n_weights);
392 rotated
393 }
394
395 pub fn n_weights(&self) -> usize {
397 self.n_weights
398 }
399
400 pub fn padded_len(&self) -> usize {
402 self.padded_len
403 }
404
405 pub fn mode(&self) -> QuantMode {
407 self.mode
408 }
409
410 pub fn compression_ratio(&self) -> f64 {
412 let original_bytes = self.n_weights * 8; let packed_bytes = self.packed.len() * 4 + HEADER_SIZE;
414 original_bytes as f64 / packed_bytes as f64
415 }
416
417 pub fn to_bytes(&self) -> Vec<u8> {
430 let mut buf = Vec::with_capacity(HEADER_SIZE + self.packed.len() * 4);
431 buf.extend_from_slice(&(self.n_weights as u32).to_le_bytes());
432 buf.extend_from_slice(&self.mode.to_u32().to_le_bytes());
433 buf.extend_from_slice(&self.seed.to_le_bytes());
434 buf.extend_from_slice(&(self.padded_len as u32).to_le_bytes());
435 buf.extend_from_slice(&self.scale.to_le_bytes());
436 buf.extend_from_slice(&self.offset.to_le_bytes());
437 for &word in &self.packed {
438 buf.extend_from_slice(&word.to_le_bytes());
439 }
440 buf
441 }
442}
443
/// Zero-copy reader over bytes produced by `TurboQuantized::to_bytes`.
pub struct TurboQuantizedView<'a> {
    /// Borrowed packed words, still in little-endian byte form.
    packed: &'a [u8],
    /// Number of original (unpadded) weights.
    n_weights: usize,
    /// Seed for the sign-flip half of the rotation.
    seed: u64,
    /// Power-of-two rotated length.
    padded_len: usize,
    /// Dequantization scale: weight ≈ level * scale + offset.
    scale: f64,
    /// Dequantization offset.
    offset: f64,
    /// Quantization precision mode.
    mode: QuantMode,
}
470
/// Serialized header size in bytes:
/// n_weights (u32) + mode (u32) + seed (u64) + padded_len (u32)
/// + scale (f64) + offset (f64) = 4 + 4 + 8 + 4 + 8 + 8 = 36.
const HEADER_SIZE: usize = 36;
473
474impl<'a> TurboQuantizedView<'a> {
475 pub fn from_bytes(bytes: &'a [u8]) -> Result<Self, crate::error::FormatError> {
481 if bytes.len() < HEADER_SIZE {
482 return Err(crate::error::FormatError::Truncated);
483 }
484 let n_weights = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
485 let mode_raw = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
486 let mode = QuantMode::from_u32(mode_raw).ok_or(crate::error::FormatError::Truncated)?;
487 let seed = u64::from_le_bytes([
488 bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
489 ]);
490 let padded_len = u32::from_le_bytes([bytes[16], bytes[17], bytes[18], bytes[19]]) as usize;
491 let scale = f64::from_le_bytes([
492 bytes[20], bytes[21], bytes[22], bytes[23], bytes[24], bytes[25], bytes[26], bytes[27],
493 ]);
494 let offset = f64::from_le_bytes([
495 bytes[28], bytes[29], bytes[30], bytes[31], bytes[32], bytes[33], bytes[34], bytes[35],
496 ]);
497
498 let vpw = mode.values_per_word();
500 let n_words = padded_len.div_ceil(vpw);
501 let expected_len = HEADER_SIZE + n_words * 4;
502 if bytes.len() < expected_len {
503 return Err(crate::error::FormatError::Truncated);
504 }
505
506 Ok(Self {
507 packed: &bytes[HEADER_SIZE..HEADER_SIZE + n_words * 4],
508 n_weights,
509 seed,
510 padded_len,
511 scale,
512 offset,
513 mode,
514 })
515 }
516
517 pub fn predict(&self, features: &[f64]) -> f64 {
523 if self.n_weights == 0 {
524 return 0.0;
525 }
526 let mut rotated_features = Vec::with_capacity(self.padded_len);
528 let use_len = self.n_weights.min(features.len());
529 rotated_features.extend_from_slice(&features[..use_len]);
530 rotated_features.resize(self.padded_len, 0.0);
531 apply_rotation(&mut rotated_features, self.seed);
532
533 self.dot_with_rotated(&rotated_features)
534 }
535
536 pub fn predict_with_scratch(&self, features: &[f64], scratch: &mut [f64]) -> f64 {
541 if self.n_weights == 0 {
542 return 0.0;
543 }
544 assert!(
545 scratch.len() >= self.padded_len,
546 "scratch buffer too small: {} < {}",
547 scratch.len(),
548 self.padded_len
549 );
550
551 for v in scratch[..self.padded_len].iter_mut() {
552 *v = 0.0;
553 }
554 let use_len = self.n_weights.min(features.len());
555 scratch[..use_len].copy_from_slice(&features[..use_len]);
556 apply_rotation(&mut scratch[..self.padded_len], self.seed);
557
558 self.dot_with_rotated(&scratch[..self.padded_len])
559 }
560
561 fn dot_with_rotated(&self, rotated_features: &[f64]) -> f64 {
563 let mut sum = 0.0;
564 let mut feat_idx = 0;
565 let vpw = self.mode.values_per_word();
566 let n_words = self.packed.len() / 4;
567
568 for word_idx in 0..n_words {
569 let off = word_idx * 4;
570 let word = u32::from_le_bytes([
571 self.packed[off],
572 self.packed[off + 1],
573 self.packed[off + 2],
574 self.packed[off + 3],
575 ]);
576 let remaining = self.padded_len - feat_idx;
577 let count = remaining.min(vpw);
578 match self.mode {
579 QuantMode::Bits8 => {
580 let values = unpack4_bytes(word, count);
581 for &q in values.iter().take(count) {
582 let w = q as f64 * self.scale + self.offset;
583 sum += w * rotated_features[feat_idx];
584 feat_idx += 1;
585 }
586 }
587 QuantMode::Bits3_5 => {
588 let values = unpack7(word, count);
589 for &q in values.iter().take(count) {
590 let w = q as f64 * self.scale + self.offset;
591 sum += w * rotated_features[feat_idx];
592 feat_idx += 1;
593 }
594 }
595 QuantMode::Bits2_5 => {
596 let values = unpack13(word, count);
597 for &q in values.iter().take(count) {
598 let w = q as f64 * self.scale + self.offset;
599 sum += w * rotated_features[feat_idx];
600 feat_idx += 1;
601 }
602 }
603 }
604 if feat_idx >= self.padded_len {
605 break;
606 }
607 }
608 sum
609 }
610
611 pub fn n_weights(&self) -> usize {
613 self.n_weights
614 }
615
616 pub fn padded_len(&self) -> usize {
618 self.padded_len
619 }
620
621 pub fn mode(&self) -> QuantMode {
623 self.mode
624 }
625}
626
/// Rotation seed used by the convenience entry points that do not take an
/// explicit seed.
const DEFAULT_SEED: u64 = 0xDEAD_BEEF;
633
634#[inline]
/// Smallest power of two >= `n`, with `next_power_of_two(0) == 1`.
///
/// Delegates to `usize::next_power_of_two` from core instead of the
/// hand-rolled bit-smearing sequence; results are identical for every input
/// that does not overflow `usize` (the std method also maps 0 to 1).
#[inline]
fn next_power_of_two(n: usize) -> usize {
    n.next_power_of_two()
}
653
654fn fwht_inplace(x: &mut [f64]) {
659 let n = x.len();
660 debug_assert!(
661 n > 0 && (n & (n - 1)) == 0,
662 "FWHT requires power-of-2 length"
663 );
664 let mut h = 1;
665 while h < n {
666 for i in (0..n).step_by(h * 2) {
667 for j in i..i + h {
668 let a = x[j];
669 let b = x[j + h];
670 x[j] = a + b;
671 x[j + h] = a - b;
672 }
673 }
674 h *= 2;
675 }
676 let scale = 1.0 / crate::math::sqrt(n as f64);
677 for v in x.iter_mut() {
678 *v *= scale;
679 }
680}
681
682fn apply_sign_flip(x: &mut [f64], seed: u64) {
686 let mut state = seed;
687 for v in x.iter_mut() {
688 let r = crate::rng::xorshift64(&mut state);
689 if r & 1 == 0 {
690 *v = -*v;
691 }
692 }
693}
694
/// Apply the seeded randomized orthonormal rotation: per-element sign
/// flips followed by the normalized FWHT. Undone by
/// `apply_inverse_rotation` with the same seed.
fn apply_rotation(buf: &mut [f64], seed: u64) {
    apply_sign_flip(buf, seed);
    fwht_inplace(buf);
}
700
/// Undo `apply_rotation`: the normalized FWHT and the sign flips are each
/// self-inverse, so running the two steps in reverse order restores the
/// original buffer.
fn apply_inverse_rotation(buf: &mut [f64], seed: u64) {
    fwht_inplace(buf);
    apply_sign_flip(buf, seed);
}
706
/// Quantize `weights` with the default mode (`Bits3_5`, ~3.5 bits per
/// weight) and the default rotation seed.
pub fn quantize_weights(weights: &[f64]) -> TurboQuantized {
    quantize(weights, QuantMode::Bits3_5, DEFAULT_SEED)
}
741
/// Quantize `weights` with the default mode but a caller-chosen rotation
/// seed; the same seed always yields the same quantization.
pub fn quantize_weights_with_seed(weights: &[f64], seed: u64) -> TurboQuantized {
    quantize(weights, QuantMode::Bits3_5, seed)
}
748
/// Quantize `weights` into packed form.
///
/// Pipeline: pad to a power-of-two length, apply a seeded random rotation
/// (sign flips + FWHT) to decorrelate the weights, then uniformly quantize
/// the rotated values into `mode.n_levels()` levels spanning the rotated
/// min..max range.
pub fn quantize(weights: &[f64], mode: QuantMode, seed: u64) -> TurboQuantized {
    if weights.is_empty() {
        // Degenerate model: predict() short-circuits on n_weights == 0.
        return TurboQuantized {
            packed: vec![],
            n_weights: 0,
            scale: 0.0,
            offset: 0.0,
            seed,
            padded_len: 1,
            mode,
        };
    }

    let padded_len = next_power_of_two(weights.len());
    let mut rotated = Vec::with_capacity(padded_len);
    rotated.extend_from_slice(weights);
    rotated.resize(padded_len, 0.0);
    apply_rotation(&mut rotated, seed);

    let min_val = rotated.iter().copied().fold(f64::INFINITY, f64::min);
    let max_val = rotated.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let range = max_val - min_val;
    let n_levels = mode.n_levels();
    let max_level = n_levels - 1;
    // A (near-)zero range means every rotated value is equal; scale 0.0
    // marks that degenerate case for the branch below.
    let scale = if range < 1e-15 {
        0.0
    } else {
        range / max_level as f64
    };

    let quantized: Vec<u8> = rotated
        .iter()
        .map(|&w| {
            if scale < 1e-15 {
                // Constant weights: the offset alone carries the value, so
                // any level works; use the midpoint.
                (max_level / 2) as u8
            } else {
                let q = crate::math::round((w - min_val) / scale);
                // Clamp in case float rounding pushes past the top level.
                (q as u8).min(max_level as u8)
            }
        })
        .collect();

    let vpw = mode.values_per_word();
    let n_words = padded_len.div_ceil(vpw);
    let mut packed = Vec::with_capacity(n_words);
    for chunk in quantized.chunks(vpw) {
        packed.push(pack_word(chunk, mode));
    }

    TurboQuantized {
        packed,
        n_weights: weights.len(),
        scale,
        offset: min_val,
        seed,
        padded_len,
        mode,
    }
}
818
819pub fn quantize_f32(weights: &[f32], mode: QuantMode) -> TurboQuantized {
821 let f64_weights: Vec<f64> = weights.iter().map(|&w| w as f64).collect();
822 quantize(&f64_weights, mode, DEFAULT_SEED)
823}
824
825pub fn quantize_i16(weights: &[i16], scale: f64, mode: QuantMode) -> TurboQuantized {
830 let f64_weights: Vec<f64> = weights.iter().map(|&w| w as f64 * scale).collect();
831 quantize(&f64_weights, mode, DEFAULT_SEED)
832}
833
#[cfg(test)]
mod tests {
    use super::*;

    // --- Packing primitives -----------------------------------------------

    #[test]
    fn pack_unpack_roundtrip() {
        let values = [0u8, 5, 10, 3, 7, 1, 9];
        let packed = pack7(&values);
        let unpacked = unpack7(packed, 7);
        assert_eq!(&unpacked, &values, "pack/unpack roundtrip failed");
    }

    #[test]
    fn pack_unpack_partial() {
        let values = [2u8, 8, 4];
        let packed = pack7(&values);
        let unpacked = unpack7(packed, 3);
        assert_eq!(&unpacked[..3], &values, "partial pack/unpack failed");
    }

    // --- Quantize / dequantize --------------------------------------------

    #[test]
    fn quantize_empty() {
        let q = quantize_weights(&[]);
        assert_eq!(q.n_weights(), 0);
        assert_eq!(q.predict(&[]), 0.0);
    }

    #[test]
    fn quantize_single_weight() {
        let q = quantize_weights(&[3.125]);
        assert_eq!(q.n_weights(), 1);
        let pred = q.predict(&[1.0]);
        assert!(
            (pred - 3.125).abs() < 0.5,
            "single weight should roundtrip reasonably, got {pred}"
        );
    }

    #[test]
    fn quantize_constant_weights() {
        let q = quantize_weights(&[2.5, 2.5, 2.5, 2.5]);
        let dq = q.dequantize();
        for (i, &w) in dq.iter().enumerate() {
            assert!(
                (w - 2.5).abs() < 0.05,
                "constant weights should dequantize closely, got {w} at [{i}]"
            );
        }
    }

    #[test]
    fn quantize_predict_accuracy() {
        let weights = vec![0.1, -0.5, 0.3, 0.0, -0.2, 0.4, 0.1, -0.1, 0.2];
        let features = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let exact: f64 = weights
            .iter()
            .zip(features.iter())
            .map(|(w, f)| w * f)
            .sum();
        let q = quantize_weights(&weights);
        let pred = q.predict(&features);
        let rel_err = if exact.abs() > 1e-10 {
            (pred - exact).abs() / exact.abs()
        } else {
            (pred - exact).abs()
        };
        assert!(
            rel_err < 0.25,
            "relative error should be < 25%, got {rel_err:.4} (exact={exact:.4}, pred={pred:.4})"
        );
    }

    #[test]
    fn quantize_dequantize_bounded_error() {
        let weights: Vec<f64> = (0..100).map(|i| (i as f64 - 50.0) / 50.0).collect();
        let q = quantize_weights(&weights);
        let dq = q.dequantize();
        let max_err = weights
            .iter()
            .zip(dq.iter())
            .map(|(w, d)| (w - d).abs())
            .fold(0.0f64, f64::max);
        assert!(
            max_err < 0.25,
            "max dequantize error should be < 0.25, got {max_err}"
        );
    }

    // --- Serialization ----------------------------------------------------

    #[test]
    fn to_bytes_from_bytes_roundtrip() {
        let weights = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.2, 0.7, -0.9, 0.4, 0.6];
        let q = quantize_weights(&weights);
        let bytes = q.to_bytes();
        let view = TurboQuantizedView::from_bytes(&bytes).expect("valid bytes");
        assert_eq!(view.n_weights(), q.n_weights());
        let features = vec![1.0; 10];
        let pred_owned = q.predict(&features);
        let pred_view = view.predict(&features);
        assert!(
            (pred_owned - pred_view).abs() < 1e-15,
            "owned vs view predict mismatch: {pred_owned} vs {pred_view}"
        );
    }

    #[test]
    fn from_bytes_rejects_short() {
        assert!(TurboQuantizedView::from_bytes(&[0u8; 10]).is_err());
        assert!(TurboQuantizedView::from_bytes(&[0u8; 35]).is_err());
    }

    #[test]
    fn compression_ratio_reasonable() {
        let weights: Vec<f64> = (0..100).map(|i| i as f64 * 0.01).collect();
        let q = quantize_weights(&weights);
        let ratio = q.compression_ratio();
        assert!(
            ratio > 3.0,
            "compression ratio should be > 3x for 100 weights, got {ratio:.2}"
        );
    }

    #[test]
    fn predict_large_vector() {
        let n = 1000;
        let weights: Vec<f64> = (0..n).map(|i| ((i as f64) * 0.1).sin()).collect();
        let features: Vec<f64> = (0..n).map(|i| ((i as f64) * 0.05).cos()).collect();
        let exact: f64 = weights
            .iter()
            .zip(features.iter())
            .map(|(w, f)| w * f)
            .sum();
        let q = quantize_weights(&weights);
        let pred = q.predict(&features);
        assert!(pred.is_finite(), "prediction should be finite");
        let abs_err = (pred - exact).abs();
        assert!(
            abs_err < exact.abs() * 0.5 + 5.0,
            "absolute error too large: {abs_err} for exact {exact}"
        );
    }

    // --- Rotation machinery -----------------------------------------------

    #[test]
    fn next_power_of_two_correctness() {
        assert_eq!(next_power_of_two(0), 1);
        assert_eq!(next_power_of_two(1), 1);
        assert_eq!(next_power_of_two(2), 2);
        assert_eq!(next_power_of_two(3), 4);
        assert_eq!(next_power_of_two(4), 4);
        assert_eq!(next_power_of_two(5), 8);
        assert_eq!(next_power_of_two(7), 8);
        assert_eq!(next_power_of_two(8), 8);
        assert_eq!(next_power_of_two(9), 16);
        assert_eq!(next_power_of_two(100), 128);
        assert_eq!(next_power_of_two(1024), 1024);
        assert_eq!(next_power_of_two(1025), 2048);
    }

    #[test]
    fn fwht_roundtrip() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        let original = data.clone();
        fwht_inplace(&mut data);
        fwht_inplace(&mut data);
        for (i, (&a, &b)) in data.iter().zip(original.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-10,
                "FWHT roundtrip failed at [{i}]: {a} vs {b}"
            );
        }
    }

    #[test]
    fn fwht_roundtrip_large() {
        let n = 64;
        let mut data: Vec<f64> = (0..n).map(|i| (i as f64) * 0.1 - 3.0).collect();
        let original = data.clone();
        fwht_inplace(&mut data);
        fwht_inplace(&mut data);
        for (i, (&a, &b)) in data.iter().zip(original.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-10,
                "FWHT large roundtrip failed at [{i}]: {a} vs {b}"
            );
        }
    }

    #[test]
    fn sign_flip_is_self_inverse() {
        let seed = 42u64;
        let mut data = vec![1.0, -2.5, 3.7, 0.0, -1.1, 5.5, 2.2, -0.8];
        let original = data.clone();
        apply_sign_flip(&mut data, seed);
        apply_sign_flip(&mut data, seed);
        for (i, (&a, &b)) in data.iter().zip(original.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-15,
                "sign flip self-inverse failed at [{i}]: {a} vs {b}"
            );
        }
    }

    #[test]
    fn full_rotation_roundtrip() {
        let seed = 0xCAFE_u64;
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut buf = original.clone();
        apply_rotation(&mut buf, seed);
        apply_inverse_rotation(&mut buf, seed);
        for (i, (&a, &b)) in buf.iter().zip(original.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-10,
                "rotation roundtrip failed at [{i}]: {a} vs {b}"
            );
        }
    }

    #[test]
    fn rotation_preserves_norm() {
        let seed = 0xBEEF_u64;
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let norm_before: f64 = data.iter().map(|x| x * x).sum();
        let mut rotated = data;
        apply_rotation(&mut rotated, seed);
        let norm_after: f64 = rotated.iter().map(|x| x * x).sum();
        assert!(
            (norm_before - norm_after).abs() < 1e-10,
            "rotation should preserve norm: {norm_before} vs {norm_after}"
        );
    }

    #[test]
    fn rotation_improves_correlated_weights() {
        let weights = vec![1.0, 1.01, 0.99, 1.02, 0.98, 1.01, 0.99, 1.0];
        let q = quantize_weights(&weights);
        let dq = q.dequantize();
        let max_err: f64 = weights
            .iter()
            .zip(dq.iter())
            .map(|(w, d)| (w - d).abs())
            .fold(0.0f64, f64::max);
        assert!(
            max_err < 0.05,
            "rotation should improve correlated weight quantization, max_err={max_err}"
        );
    }

    // --- Seeding ----------------------------------------------------------

    #[test]
    fn quantize_with_seed_deterministic() {
        let weights = vec![0.1, -0.5, 0.3, 0.0, -0.2, 0.4, 0.1, -0.1];
        let features = vec![1.0; 8];
        let q1 = quantize_weights_with_seed(&weights, 123);
        let q2 = quantize_weights_with_seed(&weights, 123);
        let p1 = q1.predict(&features);
        let p2 = q2.predict(&features);
        assert!(
            (p1 - p2).abs() < 1e-15,
            "same seed should give identical results: {p1} vs {p2}"
        );
    }

    #[test]
    fn different_seeds_produce_different_quantizations() {
        let weights = vec![0.1, -0.5, 0.3, 0.0, -0.2, 0.4, 0.1, -0.1];
        let q1 = quantize_weights_with_seed(&weights, 111);
        let q2 = quantize_weights_with_seed(&weights, 222);
        assert_ne!(
            q1.packed, q2.packed,
            "different seeds should produce different packed data"
        );
    }

    #[test]
    fn to_bytes_from_bytes_preserves_seed_and_padded_len() {
        let weights = vec![0.5, -0.3, 0.8, -0.1, 0.0];
        let q = quantize_weights_with_seed(&weights, 0xABCD);
        let bytes = q.to_bytes();
        let view = TurboQuantizedView::from_bytes(&bytes).expect("valid bytes");
        assert_eq!(view.seed, 0xABCD);
        assert_eq!(view.padded_len, q.padded_len);
        assert_eq!(view.n_weights(), q.n_weights());
    }

    // --- Per-mode coverage ------------------------------------------------

    #[test]
    fn bits8_pack_unpack_roundtrip() {
        let values = [0u8, 127, 255, 42];
        let packed = pack4_bytes(&values);
        let unpacked = unpack4_bytes(packed, 4);
        assert_eq!(&unpacked, &values, "8-bit pack/unpack roundtrip failed");
    }

    #[test]
    fn bits8_near_lossless() {
        let weights: Vec<f64> = (0..64).map(|i| (i as f64 - 32.0) / 32.0).collect();
        let q = quantize(&weights, QuantMode::Bits8, DEFAULT_SEED);
        let dq = q.dequantize();
        let max_err = weights
            .iter()
            .zip(dq.iter())
            .map(|(w, d)| (w - d).abs())
            .fold(0.0f64, f64::max);
        assert!(
            max_err < 0.02,
            "8-bit should be near-lossless, max_err={max_err}"
        );
    }

    #[test]
    fn bits8_predict_accuracy() {
        let weights: Vec<f64> = (0..32).map(|i| (i as f64).sin() * 0.5).collect();
        let features: Vec<f64> = (0..32).map(|i| (i as f64).cos() * 0.3).collect();
        let exact: f64 = weights
            .iter()
            .zip(features.iter())
            .map(|(w, f)| w * f)
            .sum();
        let q = quantize(&weights, QuantMode::Bits8, DEFAULT_SEED);
        let pred = q.predict(&features);
        let rel_err = (pred - exact).abs() / exact.abs().max(1e-10);
        assert!(
            rel_err < 0.10,
            "8-bit predict should have <10% relative error, got {rel_err:.4}"
        );
    }

    #[test]
    fn bits2_5_packing_roundtrip() {
        let values = [0u8, 4, 2, 1, 3, 0, 4, 2, 1, 3, 0, 4, 2];
        let packed = pack13(&values);
        let unpacked = unpack13(packed, 13);
        assert_eq!(&unpacked, &values, "2.5-bit pack/unpack roundtrip failed");
    }

    #[test]
    fn bits2_5_quantize_and_predict() {
        let weights: Vec<f64> = (0..16).map(|i| (i as f64 - 8.0) / 8.0).collect();
        let features = vec![1.0; 16];
        let q = quantize(&weights, QuantMode::Bits2_5, DEFAULT_SEED);
        let pred = q.predict(&features);
        assert!(pred.is_finite(), "2.5-bit predict should be finite");
    }

    #[test]
    fn all_modes_serialize_roundtrip() {
        let weights = vec![0.1, -0.3, 0.5, 0.0, -0.2, 0.4, 0.3, -0.1];
        for mode in [QuantMode::Bits8, QuantMode::Bits3_5, QuantMode::Bits2_5] {
            let q = quantize(&weights, mode, 42);
            let bytes = q.to_bytes();
            let view = TurboQuantizedView::from_bytes(&bytes).expect("valid bytes");
            assert_eq!(view.n_weights(), q.n_weights());
            assert_eq!(view.mode(), mode);
            let features = vec![1.0; 8];
            let p1 = q.predict(&features);
            let p2 = view.predict(&features);
            assert!(
                (p1 - p2).abs() < 1e-15,
                "mode {mode:?}: owned={p1} vs view={p2}"
            );
        }
    }

    // --- Scratch-buffer paths ---------------------------------------------

    #[test]
    fn predict_with_scratch_matches_predict() {
        let weights = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.2];
        let features = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let q = quantize(&weights, QuantMode::Bits3_5, DEFAULT_SEED);
        let pred = q.predict(&features);
        let mut scratch = vec![0.0; q.padded_len()];
        let pred_scratch = q.predict_with_scratch(&features, &mut scratch);
        assert!(
            (pred - pred_scratch).abs() < 1e-15,
            "scratch predict should match: {pred} vs {pred_scratch}"
        );
    }

    #[test]
    fn predict_with_scratch_view_matches_predict() {
        let weights = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.2];
        let features = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let q = quantize(&weights, QuantMode::Bits8, DEFAULT_SEED);
        let bytes = q.to_bytes();
        let view = TurboQuantizedView::from_bytes(&bytes).expect("valid bytes");
        let pred = view.predict(&features);
        let mut scratch = vec![0.0; view.padded_len()];
        let pred_scratch = view.predict_with_scratch(&features, &mut scratch);
        assert!(
            (pred - pred_scratch).abs() < 1e-15,
            "view scratch predict should match: {pred} vs {pred_scratch}"
        );
    }

    // --- Alternate input types --------------------------------------------

    #[test]
    fn quantize_f32_works() {
        let weights = vec![0.5f32, -0.3, 0.8, -0.1];
        let q = quantize_f32(&weights, QuantMode::Bits8);
        assert_eq!(q.n_weights(), 4);
        let pred = q.predict(&[1.0, 1.0, 1.0, 1.0]);
        assert!(pred.is_finite());
    }

    #[test]
    fn quantize_i16_works() {
        let weights = vec![1000i16, -500, 2000, -1000];
        let scale = 1.0 / 32767.0;
        let q = quantize_i16(&weights, scale, QuantMode::Bits3_5);
        assert_eq!(q.n_weights(), 4);
    }

    // --- Compression ratios and degenerate inputs -------------------------

    #[test]
    fn bits8_compression_ratio() {
        let weights: Vec<f64> = (0..256).map(|i| i as f64 * 0.01).collect();
        let q = quantize(&weights, QuantMode::Bits8, DEFAULT_SEED);
        let ratio = q.compression_ratio();
        assert!(
            ratio > 5.0,
            "8-bit compression ratio should be > 5x, got {ratio:.2}"
        );
    }

    #[test]
    fn bits2_5_compression_ratio() {
        let weights: Vec<f64> = (0..256).map(|i| i as f64 * 0.01).collect();
        let q = quantize(&weights, QuantMode::Bits2_5, DEFAULT_SEED);
        let ratio = q.compression_ratio();
        assert!(
            ratio > 10.0,
            "2.5-bit compression ratio should be > 10x, got {ratio:.2}"
        );
    }

    #[test]
    fn quantize_empty_all_modes() {
        for mode in [QuantMode::Bits8, QuantMode::Bits3_5, QuantMode::Bits2_5] {
            let q = quantize(&[], mode, DEFAULT_SEED);
            assert_eq!(q.n_weights(), 0);
            assert_eq!(q.predict(&[]), 0.0);
        }
    }

    #[test]
    fn predict_with_scratch_empty() {
        let q = quantize(&[], QuantMode::Bits3_5, DEFAULT_SEED);
        let mut scratch = vec![0.0; 1];
        assert_eq!(q.predict_with_scratch(&[], &mut scratch), 0.0);
    }
}