use candle_core::{DType, Device, Tensor};
use serde::{Deserialize, Serialize};

use crate::error::{QLoraError, Result};
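
/// The 16 levels of the NF4 (4-bit NormalFloat) code used by QLoRA-style
/// quantization: quantiles of a standard normal distribution normalized to
/// the range [-1, 1], so that normally distributed weights map roughly
/// uniformly onto the 16 codes.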
#[allow(clippy::excessive_precision)]
pub const NF4_LEVELS: [f32; 16] = [
    -1.0,
    -0.696_192_800_998_688,
    -0.525_073_051_452_637,
    -0.394_917_488_098_145,
    -0.284_441_381_692_887,
    -0.184_773_430_228_233,
    -0.091_050_036_251_545,
    0.0,
    0.079_580_299_556_255,
    0.160_930_201_411_247,
    0.246_112_301_945_686,
    0.337_915_241_718_292,
    0.440_709_829_330_444,
    0.562_617_003_917_694,
    0.722_956_836_223_602,
    1.0,
];
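
/// Granularity at which quantization scales are computed: `PerTensor` uses one
/// scale per fixed-size block of the flattened tensor, while `PerChannel` uses
/// one scale per output channel (the leading dimension).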
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum QuantizationStrategy {
    PerTensor,
    PerChannel,
}
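
/// Configuration for NF4 quantization. `block_size` is the number of elements
/// sharing one scale under `PerTensor` quantization, `double_quant` additionally
/// quantizes the per-block scales to 8 bits, `compute_dtype` selects the dtype
/// produced after dequantization, `strategy` picks the scale granularity, and
/// `use_zero_point` stores an asymmetric offset per block or channel.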
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    pub block_size: usize,

    pub double_quant: bool,

    pub compute_dtype: ComputeDType,

    pub strategy: QuantizationStrategy,

    pub use_zero_point: bool,
}
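
/// Floating-point dtype used for the dequantized tensor during compute;
/// `F32` is the default.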
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub enum ComputeDType {
    #[default]
    F32,
    F16,
    BF16,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            block_size: 64,
            double_quant: true,
            compute_dtype: ComputeDType::F32,
            strategy: QuantizationStrategy::PerTensor,
            use_zero_point: false,
        }
    }
}
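
/// An NF4-quantized tensor: packed 4-bit codes plus the metadata needed to
/// reconstruct an approximation of the original values (per-block or
/// per-channel scales, optional zero points, and optionally double-quantized
/// scales).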
#[derive(Debug)]
pub struct QuantizedTensor {
    pub data: Vec<u8>,
    pub scales: Vec<f32>,
    pub zero_points: Option<Vec<f32>>,
    pub scales_quantized: Option<Vec<u8>>,
    pub scales_scales: Option<Vec<f32>>,
    pub shape: Vec<usize>,
    pub block_size: usize,
    pub double_quant_enabled: bool,
    pub strategy: QuantizationStrategy,
}

impl QuantizedTensor {
    #[must_use]
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    #[must_use]
    pub fn size_bytes(&self) -> usize {
        let mut size = self.data.len() + self.scales.len() * 4;
        if let Some(ref zp) = self.zero_points {
            size += zp.len() * 4;
        }
        if let Some(ref sq) = self.scales_quantized {
            size += sq.len();
        }
        if let Some(ref ss) = self.scales_scales {
            size += ss.len() * 4;
        }
        size
    }

    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn compression_ratio(&self) -> f64 {
        let fp32_size = self.numel() * 4;
        let quantized_size = self.size_bytes();
        fp32_size as f64 / quantized_size as f64
    }
}
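
/// Quantizes an `f32` tensor to NF4 with per-tensor (blockwise) absmax scaling
/// and no double quantization. The element count must be divisible by
/// `block_size`; use [`pad_for_quantization`] first if it is not.
///
/// A minimal usage sketch (the `use` path for this module is an assumption;
/// adjust it to where the module actually lives in the crate):
///
/// ```ignore
/// use candle_core::{Device, Tensor};
///
/// let device = Device::Cpu;
/// let weights = Tensor::randn(0.0f32, 1.0, (1024,), &device)?;
/// let q = quantize_nf4(&weights, 64)?;
/// assert!(q.compression_ratio() > 3.0);
/// ```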
pub fn quantize_nf4(tensor: &Tensor, block_size: usize) -> Result<QuantizedTensor> {
    quantize_nf4_with_config(
        tensor,
        &QuantizationConfig {
            block_size,
            double_quant: false,
            compute_dtype: ComputeDType::F32,
            strategy: QuantizationStrategy::PerTensor,
            use_zero_point: false,
        },
    )
}
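
/// Quantizes an `f32` tensor to NF4 using the full [`QuantizationConfig`],
/// dispatching to per-tensor or per-channel quantization based on
/// `config.strategy`.
///
/// A sketch of enabling double quantization (paths assumed, as above):
///
/// ```ignore
/// let config = QuantizationConfig {
///     double_quant: true,
///     ..QuantizationConfig::default()
/// };
/// let q = quantize_nf4_with_config(&weights, &config)?;
/// assert!(q.scales_quantized.is_some());
/// ```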
pub fn quantize_nf4_with_config(
    tensor: &Tensor,
    config: &QuantizationConfig,
) -> Result<QuantizedTensor> {
    match config.strategy {
        QuantizationStrategy::PerTensor => quantize_per_tensor(tensor, config),
        QuantizationStrategy::PerChannel => quantize_per_channel(tensor, config),
    }
}
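
// Blockwise per-tensor quantization: each `block_size` run of the flattened
// tensor gets one absmax scale, every value is mapped to the nearest NF4
// level, and two 4-bit codes are packed per byte (first element of each pair
// in the low nibble).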
fn quantize_per_tensor(tensor: &Tensor, config: &QuantizationConfig) -> Result<QuantizedTensor> {
    let shape = tensor.shape().dims().to_vec();
    let flat = tensor.flatten_all()?.to_vec1::<f32>()?;
    let numel = flat.len();

    if numel % config.block_size != 0 {
        return Err(QLoraError::InvalidConfig(format!(
            "tensor size {} not divisible by block size {}",
            numel, config.block_size
        )));
    }

    let num_blocks = numel / config.block_size;
    let mut scales = Vec::with_capacity(num_blocks);
    let mut quantized = Vec::with_capacity(numel.div_ceil(2));

    for block_idx in 0..num_blocks {
        let start = block_idx * config.block_size;
        let end = start + config.block_size;
        let block = &flat[start..end];

        let absmax = block.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let scale = if absmax > 0.0 { absmax } else { 1.0 };
        scales.push(scale);

        for chunk in block.chunks(2) {
            let q0 = quantize_value_nf4(chunk[0] / scale);
            let q1 = if chunk.len() > 1 {
                quantize_value_nf4(chunk[1] / scale)
            } else {
                0
            };
            quantized.push((q1 << 4) | q0);
        }
    }

    let (scales_quantized, scales_scales) = if config.double_quant {
        let (sq, ss) = double_quantize_scales(&scales, 255);
        (Some(sq), Some(ss))
    } else {
        (None, None)
    };

    let zero_points = if config.use_zero_point {
        Some(compute_zero_points(&flat, config.block_size, &scales))
    } else {
        None
    };

    Ok(QuantizedTensor {
        data: quantized,
        scales,
        zero_points,
        scales_quantized,
        scales_scales,
        shape,
        block_size: config.block_size,
        double_quant_enabled: config.double_quant,
        strategy: config.strategy,
    })
}
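
// Per-channel quantization: one absmax scale per leading-dimension channel,
// with the same NF4 nibble packing as the per-tensor path. `block_size` in the
// returned tensor is repurposed to store the channel size.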
fn quantize_per_channel(tensor: &Tensor, config: &QuantizationConfig) -> Result<QuantizedTensor> {
    let shape = tensor.shape().dims().to_vec();

    if shape.len() < 2 {
        return Err(QLoraError::InvalidConfig(
            "Per-channel quantization requires at least 2D tensor".to_string(),
        ));
    }

    let num_channels = shape[0];
    let channel_size = shape[1..].iter().product::<usize>();
    let flat = tensor.flatten_all()?.to_vec1::<f32>()?;

    let mut scales = Vec::with_capacity(num_channels);
    let mut quantized = Vec::with_capacity(flat.len().div_ceil(2));

    for ch_idx in 0..num_channels {
        let ch_start = ch_idx * channel_size;
        let ch_end = ch_start + channel_size;
        let channel_data = &flat[ch_start..ch_end];

        let absmax = channel_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let scale = if absmax > 0.0 { absmax } else { 1.0 };
        scales.push(scale);

        for chunk in channel_data.chunks(2) {
            let q0 = quantize_value_nf4(chunk[0] / scale);
            let q1 = if chunk.len() > 1 {
                quantize_value_nf4(chunk[1] / scale)
            } else {
                0
            };
            quantized.push((q1 << 4) | q0);
        }
    }

    let (scales_quantized, scales_scales) = if config.double_quant {
        let (sq, ss) = double_quantize_scales(&scales, 255);
        (Some(sq), Some(ss))
    } else {
        (None, None)
    };

    let zero_points = if config.use_zero_point {
        let mut zps = Vec::with_capacity(num_channels);
        #[allow(clippy::needless_range_loop)]
        for ch_idx in 0..num_channels {
            let ch_start = ch_idx * channel_size;
            let ch_end = ch_start + channel_size;
            let channel_data = &flat[ch_start..ch_end];
            let min_val = channel_data.iter().copied().fold(f32::INFINITY, f32::min);
            let zp = if scales[ch_idx] > 0.0 {
                -min_val / scales[ch_idx]
            } else {
                0.0
            };
            zps.push(zp);
        }
        Some(zps)
    } else {
        None
    };

    Ok(QuantizedTensor {
        data: quantized,
        scales,
        zero_points,
        scales_quantized,
        scales_scales,
        shape,
        block_size: channel_size,
        double_quant_enabled: config.double_quant,
        strategy: config.strategy,
    })
}
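
// Double quantization of the per-block scales: the f32 scales are themselves
// quantized to u8 around an offset of 128, with a single f32 scale factor
// (absmax of the scales divided by 127) kept for the whole vector. The
// `_max_val` parameter is currently unused.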
fn double_quantize_scales(scales: &[f32], _max_val: usize) -> (Vec<u8>, Vec<f32>) {
    if scales.is_empty() {
        return (Vec::new(), Vec::new());
    }

    let absmax = scales.iter().map(|s| s.abs()).fold(0.0f32, f32::max);

    if absmax == 0.0 {
        // All scales are zero: encode them as the offset value 128 so that
        // dequantization ((128 - 128) * factor) reproduces zeros.
        return (vec![128; scales.len()], vec![1.0]);
    }

    #[allow(clippy::cast_precision_loss)]
    let scale_factor = absmax / 127.0;
    let quantized_scales: Vec<u8> = scales
        .iter()
        .map(|&s| {
            let quantized = (s / scale_factor) + 128.0;
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let result = quantized.clamp(0.0, 255.0) as u8;
            result
        })
        .collect();

    (quantized_scales, vec![scale_factor])
}
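
// Inverse of `double_quantize_scales`: recovers approximate f32 scales from
// the u8 codes and the stored scale factor.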
fn dequantize_double_scales(scales_quantized: &[u8], scales_scales: &[f32]) -> Vec<f32> {
    if scales_quantized.is_empty() || scales_scales.is_empty() {
        return vec![];
    }

    let scale_factor = scales_scales[0];
    scales_quantized
        .iter()
        .map(|&sq| (f32::from(sq) - 128.0) * scale_factor)
        .collect()
}
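
// Computes one zero point per block as the negated block minimum divided by
// the block scale. Note that the quantization path does not subtract this
// offset before encoding; it is only added back in `dequantize_nf4`.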
fn compute_zero_points(data: &[f32], block_size: usize, scales: &[f32]) -> Vec<f32> {
    let num_blocks = data.len() / block_size;
    let mut zero_points = Vec::with_capacity(num_blocks);

    #[allow(clippy::needless_range_loop)]
    for block_idx in 0..num_blocks {
        let start = block_idx * block_size;
        let end = start + block_size;
        let block = &data[start..end];
        let scale = scales[block_idx];

        let min_val = block.iter().copied().fold(f32::INFINITY, f32::min);

        let zero_point = if scale > 0.0 { -min_val / scale } else { 0.0 };

        zero_points.push(zero_point);
    }

    zero_points
}
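
/// Dequantizes an NF4 tensor back to `f32` on the given device, undoing the
/// double quantization of scales when it was enabled and reshaping to the
/// original shape.
///
/// A minimal round-trip sketch (paths assumed, as above):
///
/// ```ignore
/// let q = quantize_nf4(&weights, 64)?;
/// let restored = dequantize_nf4(&q, &Device::Cpu)?;
/// assert_eq!(restored.shape().dims(), weights.shape().dims());
/// ```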
pub fn dequantize_nf4(quantized: &QuantizedTensor, device: &Device) -> Result<Tensor> {
    let numel = quantized.numel();
    let mut output = Vec::with_capacity(numel);

    let scales = if quantized.double_quant_enabled {
        if let (Some(ref sq), Some(ref ss)) =
            (&quantized.scales_quantized, &quantized.scales_scales)
        {
            dequantize_double_scales(sq, ss)
        } else {
            quantized.scales.clone()
        }
    } else {
        quantized.scales.clone()
    };

    match quantized.strategy {
        QuantizationStrategy::PerTensor => {
            let num_blocks = scales.len();

            for block_idx in 0..num_blocks {
                let scale = scales[block_idx];
                let zero_point = quantized
                    .zero_points
                    .as_ref()
                    .map_or(0.0, |zp| zp[block_idx]);
                let start_byte = (block_idx * quantized.block_size) / 2;

                for i in 0..quantized.block_size {
                    let byte_idx = start_byte + i / 2;
                    let byte = quantized.data[byte_idx];
                    let code = if i % 2 == 0 { byte & 0x0F } else { byte >> 4 };
                    let nf4_value = NF4_LEVELS[code as usize] + zero_point;
                    let value = nf4_value * scale;
                    output.push(value);
                }
            }
        }
        QuantizationStrategy::PerChannel => {
            let num_channels = scales.len();
            let channel_size = quantized.block_size;

            for ch_idx in 0..num_channels {
                let scale = scales[ch_idx];
                let zero_point = quantized.zero_points.as_ref().map_or(0.0, |zp| zp[ch_idx]);
                let ch_start_byte = (ch_idx * channel_size) / 2;

                for i in 0..channel_size {
                    let byte_idx = ch_start_byte + i / 2;
                    let byte = quantized.data[byte_idx];
                    let code = if i % 2 == 0 { byte & 0x0F } else { byte >> 4 };
                    let nf4_value = NF4_LEVELS[code as usize] + zero_point;
                    let value = nf4_value * scale;
                    output.push(value);
                }
            }
        }
    }

    let tensor = Tensor::from_vec(output, quantized.shape.clone(), device)?;
    Ok(tensor)
}
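
/// Dequantizes to `f32` and then converts to the requested compute dtype,
/// returning the `f32` tensor unchanged when `ComputeDType::F32` is requested.
///
/// A sketch of mixed-precision use (paths assumed, as above):
///
/// ```ignore
/// let half = dequantize_nf4_with_dtype(&q, &Device::Cpu, ComputeDType::F16)?;
/// assert_eq!(half.dtype(), candle_core::DType::F16);
/// ```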
pub fn dequantize_nf4_with_dtype(
    quantized: &QuantizedTensor,
    device: &Device,
    compute_dtype: ComputeDType,
) -> Result<Tensor> {
    let f32_tensor = dequantize_nf4(quantized, device)?;

    let dtype = match compute_dtype {
        ComputeDType::F32 => return Ok(f32_tensor),
        ComputeDType::F16 => DType::F16,
        ComputeDType::BF16 => DType::BF16,
    };

    let converted = f32_tensor.to_dtype(dtype)?;
    Ok(converted)
}
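
/// Flattens an `f32` tensor and pads it with `pad_value` so that its length is
/// a multiple of `block_size`. Only `f32` tensors are accepted; convert first
/// if needed. Use [`pad_for_quantization_with_info`] when the padding must be
/// undone later.
///
/// A minimal sketch (paths assumed, as above):
///
/// ```ignore
/// let t = Tensor::randn(0.0f32, 1.0, (100,), &Device::Cpu)?;
/// let padded = pad_for_quantization(&t, 64, 0.0)?;
/// assert_eq!(padded.elem_count(), 128);
/// ```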
pub fn pad_for_quantization(tensor: &Tensor, block_size: usize, pad_value: f32) -> Result<Tensor> {
    if tensor.dtype() != candle_core::DType::F32 {
        return Err(QLoraError::InvalidConfig(format!(
            "pad_for_quantization only supports F32 tensors, got {:?}. \
             Convert to F32 first with tensor.to_dtype(DType::F32)",
            tensor.dtype()
        )));
    }

    let numel = tensor.elem_count();
    let device = tensor.device();

    let remainder = numel % block_size;
    if remainder == 0 {
        let flat = tensor.flatten_all()?;
        return Ok(flat);
    }

    let pad_count = block_size - remainder;
    let flat = tensor.flatten_all()?.to_vec1::<f32>()?;

    let mut padded = flat;
    padded.extend(std::iter::repeat_n(pad_value, pad_count));

    let padded_tensor = Tensor::from_vec(padded, (numel + pad_count,), device)?;
    Ok(padded_tensor)
}
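
/// Records how a tensor was flattened and padded by
/// [`pad_for_quantization_with_info`], so [`unpad_tensor`] can restore the
/// original shape afterwards.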
#[derive(Debug, Clone)]
pub struct PaddingInfo {
    pub original_shape: Vec<usize>,
    pub padded_shape: Vec<usize>,
    pub pad_count: usize,
    pub block_size: usize,
}
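
/// Like [`pad_for_quantization`], but also returns a [`PaddingInfo`] describing
/// the original shape and how many elements were appended.
///
/// A round-trip sketch through quantization (paths assumed, as above):
///
/// ```ignore
/// let (padded, info) = pad_for_quantization_with_info(&t, 64, 0.0)?;
/// let q = quantize_nf4(&padded, 64)?;
/// let restored_padded = dequantize_nf4(&q, &Device::Cpu)?;
/// let restored = unpad_tensor(&restored_padded, &info)?;
/// assert_eq!(restored.shape().dims(), t.shape().dims());
/// ```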
pub fn pad_for_quantization_with_info(
    tensor: &Tensor,
    block_size: usize,
    pad_value: f32,
) -> Result<(Tensor, PaddingInfo)> {
    if tensor.dtype() != candle_core::DType::F32 {
        return Err(QLoraError::InvalidConfig(format!(
            "pad_for_quantization_with_info only supports F32 tensors, got {:?}. \
             Convert to F32 first with tensor.to_dtype(DType::F32)",
            tensor.dtype()
        )));
    }

    let original_shape = tensor.shape().dims().to_vec();
    let numel = tensor.elem_count();
    let device = tensor.device();

    let remainder = numel % block_size;
    let pad_count = if remainder == 0 {
        0
    } else {
        block_size - remainder
    };

    if pad_count == 0 {
        let flat = tensor.flatten_all()?;
        let info = PaddingInfo {
            original_shape: original_shape.clone(),
            padded_shape: vec![numel],
            pad_count: 0,
            block_size,
        };
        return Ok((flat, info));
    }

    let flat = tensor.flatten_all()?.to_vec1::<f32>()?;
    let mut padded = flat;
    padded.extend(std::iter::repeat_n(pad_value, pad_count));

    let padded_len = numel + pad_count;
    let padded_tensor = Tensor::from_vec(padded, (padded_len,), device)?;

    let info = PaddingInfo {
        original_shape,
        padded_shape: vec![padded_len],
        pad_count,
        block_size,
    };

    Ok((padded_tensor, info))
}
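
/// Removes the padding described by `padding_info` and reshapes the flat `f32`
/// tensor back to its original shape.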
pub fn unpad_tensor(tensor: &Tensor, padding_info: &PaddingInfo) -> Result<Tensor> {
    if tensor.dtype() != candle_core::DType::F32 {
        return Err(QLoraError::InvalidConfig(format!(
            "unpad_tensor only supports F32 tensors, got {:?}. \
             Convert to F32 first with tensor.to_dtype(DType::F32)",
            tensor.dtype()
        )));
    }

    if padding_info.pad_count == 0 {
        let flat = tensor.flatten_all()?;
        let reshaped = flat.reshape(padding_info.original_shape.clone())?;
        return Ok(reshaped);
    }

    let flat = tensor.flatten_all()?.to_vec1::<f32>()?;
    let original_numel: usize = padding_info.original_shape.iter().product();

    let unpadded: Vec<f32> = flat.into_iter().take(original_numel).collect();

    let unpadded_tensor = Tensor::from_vec(
        unpadded,
        padding_info.original_shape.clone(),
        tensor.device(),
    )?;
    Ok(unpadded_tensor)
}
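
// Maps a (scale-normalized) value to the index of the nearest NF4 level by
// linear search over the 16 levels.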
fn quantize_value_nf4(value: f32) -> u8 {
    let mut best_idx = 0;
    let mut best_dist = f32::MAX;

    for (idx, &level) in NF4_LEVELS.iter().enumerate() {
        let dist = (value - level).abs();
        if dist < best_dist {
            best_dist = dist;
            best_idx = idx;
        }
    }

    #[allow(clippy::cast_possible_truncation)]
    let result = best_idx as u8;
    result
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_nf4_levels_sorted() {
        for i in 1..NF4_LEVELS.len() {
            assert!(NF4_LEVELS[i] > NF4_LEVELS[i - 1]);
        }
    }

    #[test]
    fn test_quantize_dequantize_roundtrip() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();

        let quantized = quantize_nf4(&original, 64).unwrap();
        let restored = dequantize_nf4(&quantized, &device).unwrap();

        let original_vec: Vec<f32> = original.to_vec1().unwrap();
        let restored_vec: Vec<f32> = restored.to_vec1().unwrap();

        for (o, r) in original_vec.iter().zip(restored_vec.iter()) {
            let error = (o - r).abs();
            assert!(error < 0.5, "Error {error} too large for value {o}");
        }
    }

    #[test]
    fn test_quantize_preserves_shape() {
        let device = Device::Cpu;
        let original = Tensor::zeros(&[32, 64], DType::F32, &device).unwrap();

        let quantized = quantize_nf4(&original, 64).unwrap();
        let restored = dequantize_nf4(&quantized, &device).unwrap();

        assert_eq!(restored.shape().dims(), &[32, 64]);
    }

    #[test]
    fn test_memory_reduction() {
        let device = Device::Cpu;
        let original = Tensor::zeros(&[1024, 1024], DType::F32, &device).unwrap();
        let original_bytes: i32 = 1024 * 1024 * 4;

        let quantized = quantize_nf4(&original, 64).unwrap();
        let quantized_bytes = quantized.size_bytes();

        #[allow(clippy::cast_precision_loss)]
        let ratio = f64::from(original_bytes) / quantized_bytes as f64;
        assert!(ratio > 3.0, "Expected >3x reduction, got {ratio:.2}x");
    }

    #[test]
    fn test_double_quantize_compression() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (512,), &device).unwrap();

        let config = QuantizationConfig {
            block_size: 64,
            double_quant: true,
            compute_dtype: ComputeDType::F32,
            strategy: QuantizationStrategy::PerTensor,
            use_zero_point: false,
        };

        let quantized = quantize_nf4_with_config(&original, &config).unwrap();

        assert!(quantized.double_quant_enabled);
        assert!(quantized.scales_quantized.is_some());
        assert!(quantized.scales_scales.is_some());

        let non_dq_size = quantized.scales.len() * 4;
        let dq_scales_size = quantized.scales_quantized.as_ref().map_or(0, Vec::len)
            + quantized
                .scales_scales
                .as_ref()
                .map_or(0, |ss| ss.len() * 4);

        assert!(dq_scales_size < non_dq_size);
    }

    #[test]
    fn test_double_quantize_roundtrip() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (256,), &device).unwrap();

        let config = QuantizationConfig {
            block_size: 64,
            double_quant: true,
            compute_dtype: ComputeDType::F32,
            strategy: QuantizationStrategy::PerTensor,
            use_zero_point: false,
        };

        let quantized = quantize_nf4_with_config(&original, &config).unwrap();
        let restored = dequantize_nf4(&quantized, &device).unwrap();

        let original_vec: Vec<f32> = original.to_vec1().unwrap();
        let restored_vec: Vec<f32> = restored.to_vec1().unwrap();

        let mut max_error = 0.0f32;
        for (o, r) in original_vec.iter().zip(restored_vec.iter()) {
            let error = (o - r).abs();
            max_error = max_error.max(error);
        }
        assert!(max_error < 5.0, "Max error {max_error} too large");
    }

    #[test]
    fn test_double_quant_disabled_still_works() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (128,), &device).unwrap();

        let config = QuantizationConfig {
            block_size: 64,
            double_quant: false,
            compute_dtype: ComputeDType::F32,
            strategy: QuantizationStrategy::PerTensor,
            use_zero_point: false,
        };

        let quantized = quantize_nf4_with_config(&original, &config).unwrap();

        assert!(!quantized.double_quant_enabled);
        assert!(quantized.scales_quantized.is_none());
        assert!(quantized.scales_scales.is_none());

        let restored = dequantize_nf4(&quantized, &device).unwrap();
        let original_vec: Vec<f32> = original.to_vec1().unwrap();
        let restored_vec: Vec<f32> = restored.to_vec1().unwrap();

        for (o, r) in original_vec.iter().zip(restored_vec.iter()) {
            let error = (o - r).abs();
            assert!(error < 0.5, "Error {error} too large for value {o}");
        }
    }

    #[test]
    fn test_per_channel_quantization() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();

        let config = QuantizationConfig {
            block_size: 64,
            double_quant: false,
            compute_dtype: ComputeDType::F32,
            strategy: QuantizationStrategy::PerChannel,
            use_zero_point: false,
        };

        let quantized = quantize_nf4_with_config(&original, &config).unwrap();

        assert_eq!(
            quantized.scales.len(),
            4,
            "Should have 4 scales (one per channel)"
        );
        assert_eq!(quantized.strategy, QuantizationStrategy::PerChannel);

        let restored = dequantize_nf4(&quantized, &device).unwrap();
        assert_eq!(restored.shape().dims(), &[4, 128]);

        let original_vec: Vec<f32> = original.flatten_all().unwrap().to_vec1().unwrap();
        let restored_vec: Vec<f32> = restored.flatten_all().unwrap().to_vec1().unwrap();

        let mut max_error = 0.0f32;
        for (o, r) in original_vec.iter().zip(restored_vec.iter()) {
            let error = (o - r).abs();
            max_error = max_error.max(error);
        }
        assert!(max_error < 1.0, "Max error {max_error} too large");
    }

    #[test]
    fn test_zero_point_quantization() {
        let device = Device::Cpu;
        let original = Tensor::rand(0.0f32, 5.0, (256,), &device).unwrap();

        let config = QuantizationConfig {
            block_size: 64,
            double_quant: false,
            compute_dtype: ComputeDType::F32,
            strategy: QuantizationStrategy::PerTensor,
            use_zero_point: true,
        };

        let quantized = quantize_nf4_with_config(&original, &config).unwrap();

        assert!(quantized.zero_points.is_some());
        let zero_points = quantized.zero_points.as_ref().unwrap();
        assert_eq!(
            zero_points.len(),
            256 / 64,
            "Should have one zero point per block"
        );

        let restored = dequantize_nf4(&quantized, &device).unwrap();
        let original_vec: Vec<f32> = original.to_vec1().unwrap();
        let restored_vec: Vec<f32> = restored.to_vec1().unwrap();

        let mut max_error = 0.0f32;
        for (o, r) in original_vec.iter().zip(restored_vec.iter()) {
            let error = (o - r).abs();
            max_error = max_error.max(error);
        }
        assert!(max_error < 10.0, "Max error {max_error} too large");
    }

    #[test]
    fn test_per_channel_with_zero_point() {
        let device = Device::Cpu;
        let original = Tensor::rand(0.0f32, 3.0, (2, 128), &device).unwrap();

        let config = QuantizationConfig {
            block_size: 64,
            double_quant: false,
            compute_dtype: ComputeDType::F32,
            strategy: QuantizationStrategy::PerChannel,
            use_zero_point: true,
        };

        let quantized = quantize_nf4_with_config(&original, &config).unwrap();

        assert_eq!(quantized.scales.len(), 2);
        assert!(quantized.zero_points.is_some());
        let zps = quantized.zero_points.as_ref().unwrap();
        assert_eq!(zps.len(), 2, "Should have one zero point per channel");

        let restored = dequantize_nf4(&quantized, &device).unwrap();
        assert_eq!(restored.shape().dims(), &[2, 128]);
    }

    #[test]
    fn test_per_channel_with_double_quant() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (8, 64), &device).unwrap();

        let config = QuantizationConfig {
            block_size: 64,
            double_quant: true,
            compute_dtype: ComputeDType::F32,
            strategy: QuantizationStrategy::PerChannel,
            use_zero_point: false,
        };

        let quantized = quantize_nf4_with_config(&original, &config).unwrap();

        assert_eq!(quantized.scales.len(), 8);
        assert!(quantized.double_quant_enabled);
        assert!(quantized.scales_quantized.is_some());
        assert_eq!(quantized.strategy, QuantizationStrategy::PerChannel);

        let restored = dequantize_nf4(&quantized, &device).unwrap();
        assert_eq!(restored.shape().dims(), &[8, 64]);
    }

    #[test]
    fn test_mixed_precision_f16() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();

        let quantized = quantize_nf4(&original, 64).unwrap();
        let restored_f16 =
            dequantize_nf4_with_dtype(&quantized, &device, ComputeDType::F16).unwrap();

        assert_eq!(restored_f16.dtype(), DType::F16);
        assert_eq!(restored_f16.shape().dims(), &[64]);
    }

    #[test]
    fn test_mixed_precision_bf16() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();

        let quantized = quantize_nf4(&original, 64).unwrap();
        let restored_bf16 =
            dequantize_nf4_with_dtype(&quantized, &device, ComputeDType::BF16).unwrap();

        assert_eq!(restored_bf16.dtype(), DType::BF16);
        assert_eq!(restored_bf16.shape().dims(), &[64]);
    }

    #[test]
    fn test_mixed_precision_f32_passthrough() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();

        let quantized = quantize_nf4(&original, 64).unwrap();
        let restored_f32 =
            dequantize_nf4_with_dtype(&quantized, &device, ComputeDType::F32).unwrap();

        assert_eq!(restored_f32.dtype(), DType::F32);
        assert_eq!(restored_f32.shape().dims(), &[64]);
    }

    #[test]
    fn test_padding_for_quantization_needed() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (100,), &device).unwrap();

        let padded = pad_for_quantization(&original, 64, 0.0).unwrap();
        let padded_numel = padded.elem_count();

        assert_eq!(padded_numel % 64, 0);
        assert_eq!(padded_numel, 128);
        assert_eq!(padded.shape().dims().len(), 1);
    }

    #[test]
    fn test_padding_for_quantization_not_needed() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (128,), &device).unwrap();

        let padded = pad_for_quantization(&original, 64, 0.0).unwrap();

        assert_eq!(padded.elem_count(), 128);
        assert_eq!(padded.shape().dims().len(), 1);
    }

    #[test]
    fn test_padding_with_info_roundtrip() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (100,), &device).unwrap();

        let (padded, info) = pad_for_quantization_with_info(&original, 64, 0.0).unwrap();
        assert_eq!(info.pad_count, 28);
        assert_eq!(info.original_shape, vec![100]);
        assert_eq!(info.padded_shape, vec![128]);

        let quantized = quantize_nf4(&padded, 64).unwrap();
        let restored_padded = dequantize_nf4(&quantized, &device).unwrap();

        let restored = unpad_tensor(&restored_padded, &info).unwrap();
        assert_eq!(restored.shape().dims(), &[100]);
    }

    #[test]
    fn test_padding_2d_tensor() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (4, 10), &device).unwrap();

        let (padded, info) = pad_for_quantization_with_info(&original, 64, 0.0).unwrap();

        assert_eq!(info.pad_count, 24);
        assert_eq!(info.original_shape, vec![4, 10]);
        assert_eq!(info.padded_shape, vec![64]);

        assert_eq!(padded.elem_count(), 64);

        let padded_flat: Vec<f32> = padded.to_vec1().unwrap();
        assert_eq!(padded_flat.len(), 64);
    }

    #[test]
    fn test_padding_preserves_values() {
        let device = Device::Cpu;
        let original_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let original = Tensor::from_vec(original_data.clone(), (5,), &device).unwrap();

        let padded = pad_for_quantization(&original, 8, 0.0).unwrap();
        let padded_vec: Vec<f32> = padded.to_vec1().unwrap();

        assert_eq!(&padded_vec[..5], &original_data[..]);
        assert_eq!(&padded_vec[5..], &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_padding_2d_tensor_no_padding_needed() {
        let device = Device::Cpu;
        let original = Tensor::randn(0.0f32, 1.0, (8, 8), &device).unwrap();

        let (padded, info) = pad_for_quantization_with_info(&original, 64, 0.0).unwrap();

        assert_eq!(info.pad_count, 0);
        assert_eq!(info.original_shape, vec![8, 8]);
        assert_eq!(info.padded_shape, vec![64]);
        assert_eq!(padded.elem_count(), 64);

        let restored = unpad_tensor(&padded, &info).unwrap();
        assert_eq!(restored.shape().dims(), &[8, 8]);
    }

    #[test]
    fn test_padding_2d_no_padding_roundtrip() {
        let device = Device::Cpu;
        #[allow(clippy::cast_precision_loss)]
        let original_data: Vec<f32> = (0..64).map(|i| i as f32).collect();
        let original = Tensor::from_vec(original_data.clone(), (8, 8), &device).unwrap();

        let (padded, info) = pad_for_quantization_with_info(&original, 64, 0.0).unwrap();
        assert_eq!(info.pad_count, 0);

        let restored = unpad_tensor(&padded, &info).unwrap();

        assert_eq!(restored.shape().dims(), &[8, 8]);

        let restored_data: Vec<f32> = restored.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(
            restored_data, original_data,
            "Values should be preserved through pad/unpad cycle"
        );
    }

    #[test]
    fn test_padding_dtype_validation() {
        let device = Device::Cpu;
        let f16_tensor = Tensor::ones((64,), DType::F16, &device).unwrap();

        let result = pad_for_quantization(&f16_tensor, 64, 0.0);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("F32"),
            "Error should mention F32 requirement: {err_msg}"
        );

        let result = pad_for_quantization_with_info(&f16_tensor, 64, 0.0);
        assert!(result.is_err());

        let f16_tensor = Tensor::ones((64,), DType::F16, &device).unwrap();
        let dummy_info = PaddingInfo {
            original_shape: vec![8, 8],
            padded_shape: vec![64],
            pad_count: 0,
            block_size: 64,
        };
        let result = unpad_tensor(&f16_tensor, &dummy_info);
        assert!(result.is_err());
    }
}