Skip to main content

qlora_rs/
quantization.rs

1//! 4-bit `NormalFloat` (NF4) quantization.
2//!
3//! NF4 quantization uses 16 levels optimized for normally-distributed weights,
4//! providing better accuracy than uniform 4-bit quantization.
5//!
6//! Features:
7//! - Per-tensor and per-channel quantization strategies
8//! - Zero-point (asymmetric) quantization for non-centered distributions
9//! - Double quantization for additional memory savings
10//! - Mixed precision dequantization (F32, F16, BF16)
11//! - Quantization-aware padding for block alignment
12//!
13//! Reference: <https://arxiv.org/abs/2305.14314> (`QLoRA` paper)
14
15use candle_core::{DType, Device, Tensor};
16use serde::{Deserialize, Serialize};
17
18use crate::error::{QLoraError, Result};
19
/// The 16 quantization levels for NF4, optimized for N(0,1) distribution.
/// These values minimize expected quantization error for normally-distributed data.
///
/// The table is sorted in strictly ascending order, spans `[-1.0, 1.0]`, and
/// holds an exact `0.0` at index 7. A 4-bit code is an index into this table.
// The literals are written with more digits than `f32` can hold, which is why
// the `excessive_precision` lint is silenced here.
#[allow(clippy::excessive_precision)]
pub const NF4_LEVELS: [f32; 16] = [
    -1.0,
    -0.696_192_800_998_688,
    -0.525_073_051_452_637,
    -0.394_917_488_098_145,
    -0.284_441_381_692_887,
    -0.184_773_430_228_233,
    -0.091_050_036_251_545,
    0.0,
    0.079_580_299_556_255,
    0.160_930_201_411_247,
    0.246_112_301_945_686,
    0.337_915_241_718_292,
    0.440_709_829_330_444,
    0.562_617_003_917_694,
    0.722_956_836_223_602,
    1.0,
];
41
42/// Quantization strategy for organizing scale factors.
43#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
44pub enum QuantizationStrategy {
45    /// Standard per-tensor block quantization.
46    PerTensor,
47    /// Per-channel quantization (one scale per output channel).
48    PerChannel,
49}
50
51/// Configuration for quantization.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct QuantizationConfig {
54    /// Block size for quantization (number of values sharing a scale).
55    pub block_size: usize,
56
57    /// Whether to use double quantization (quantize the scales).
58    pub double_quant: bool,
59
60    /// Data type for computation (usually bf16 or f16).
61    pub compute_dtype: ComputeDType,
62
63    /// Quantization strategy (per-tensor or per-channel).
64    pub strategy: QuantizationStrategy,
65
66    /// Whether to use zero-point quantization (asymmetric).
67    pub use_zero_point: bool,
68}
69
70/// Compute data type for dequantized values.
71#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
72pub enum ComputeDType {
73    /// 32-bit float
74    #[default]
75    F32,
76    /// 16-bit float
77    F16,
78    /// 16-bit brain float
79    BF16,
80}
81
82impl Default for QuantizationConfig {
83    fn default() -> Self {
84        Self {
85            block_size: 64,
86            double_quant: true,
87            compute_dtype: ComputeDType::F32,
88            strategy: QuantizationStrategy::PerTensor,
89            use_zero_point: false,
90        }
91    }
92}
93
94/// A quantized tensor with scale factors.
95#[derive(Debug)]
96pub struct QuantizedTensor {
97    /// Packed 4-bit values (2 values per byte).
98    pub data: Vec<u8>,
99    /// Scale factors per block.
100    pub scales: Vec<f32>,
101    /// Zero points per block (for asymmetric quantization).
102    pub zero_points: Option<Vec<f32>>,
103    /// Quantized scales (when double quantization is used).
104    pub scales_quantized: Option<Vec<u8>>,
105    /// Scale factors for the scales (when double quantization is used).
106    pub scales_scales: Option<Vec<f32>>,
107    /// Original shape.
108    pub shape: Vec<usize>,
109    /// Block size used for quantization.
110    pub block_size: usize,
111    /// Whether double quantization was applied.
112    pub double_quant_enabled: bool,
113    /// Quantization strategy used.
114    pub strategy: QuantizationStrategy,
115}
116
117impl QuantizedTensor {
118    /// Get the number of elements in the original tensor.
119    #[must_use]
120    pub fn numel(&self) -> usize {
121        self.shape.iter().product()
122    }
123
124    /// Get the memory size in bytes.
125    #[must_use]
126    pub fn size_bytes(&self) -> usize {
127        let mut size = self.data.len() + self.scales.len() * 4;
128        if let Some(ref zp) = self.zero_points {
129            size += zp.len() * 4;
130        }
131        if let Some(ref sq) = self.scales_quantized {
132            size += sq.len();
133        }
134        if let Some(ref ss) = self.scales_scales {
135            size += ss.len() * 4;
136        }
137        size
138    }
139
140    /// Get the compression ratio relative to FP32 format.
141    #[must_use]
142    #[allow(clippy::cast_precision_loss)]
143    pub fn compression_ratio(&self) -> f64 {
144        let fp32_size = self.numel() * 4;
145        let quantized_size = self.size_bytes();
146        fp32_size as f64 / quantized_size as f64
147    }
148}
149
150/// Quantize a tensor to NF4 format with optional double quantization.
151///
152/// # Arguments
153/// * `tensor` - Input tensor to quantize
154/// * `block_size` - Number of elements per quantization block
155///
156/// # Returns
157/// Quantized tensor with packed 4-bit values and scales
158///
159/// # Errors
160/// Returns error if tensor cannot be flattened or has invalid shape
161pub fn quantize_nf4(tensor: &Tensor, block_size: usize) -> Result<QuantizedTensor> {
162    quantize_nf4_with_config(
163        tensor,
164        &QuantizationConfig {
165            block_size,
166            double_quant: false, // Use non-double quantization by default
167            compute_dtype: ComputeDType::F32,
168            strategy: QuantizationStrategy::PerTensor,
169            use_zero_point: false,
170        },
171    )
172}
173
174/// Quantize a tensor to NF4 format with full configuration options.
175///
176/// # Arguments
177/// * `tensor` - Input tensor to quantize
178/// * `config` - Quantization configuration including double quant option
179///
180/// # Returns
181/// Quantized tensor with packed 4-bit values and optional double-quantized scales
182///
183/// # Errors
184/// Returns error if tensor cannot be flattened or has invalid shape
185pub fn quantize_nf4_with_config(
186    tensor: &Tensor,
187    config: &QuantizationConfig,
188) -> Result<QuantizedTensor> {
189    match config.strategy {
190        QuantizationStrategy::PerTensor => quantize_per_tensor(tensor, config),
191        QuantizationStrategy::PerChannel => quantize_per_channel(tensor, config),
192    }
193}
194
195/// Quantize a tensor using per-tensor strategy (standard block quantization).
196fn quantize_per_tensor(tensor: &Tensor, config: &QuantizationConfig) -> Result<QuantizedTensor> {
197    let shape = tensor.shape().dims().to_vec();
198    let flat = tensor.flatten_all()?.to_vec1::<f32>()?;
199    let numel = flat.len();
200
201    if numel % config.block_size != 0 {
202        return Err(QLoraError::InvalidConfig(format!(
203            "tensor size {} not divisible by block size {}",
204            numel, config.block_size
205        )));
206    }
207
208    let num_blocks = numel / config.block_size;
209    let mut scales = Vec::with_capacity(num_blocks);
210    let mut quantized = Vec::with_capacity(numel.div_ceil(2));
211
212    for block_idx in 0..num_blocks {
213        let start = block_idx * config.block_size;
214        let end = start + config.block_size;
215        let block = &flat[start..end];
216
217        // Compute absmax scale
218        let absmax = block.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
219        let scale = if absmax > 0.0 { absmax } else { 1.0 };
220        scales.push(scale);
221
222        // Quantize each value in the block
223        for chunk in block.chunks(2) {
224            let q0 = quantize_value_nf4(chunk[0] / scale);
225            let q1 = if chunk.len() > 1 {
226                quantize_value_nf4(chunk[1] / scale)
227            } else {
228                0
229            };
230            // Pack two 4-bit values into one byte
231            quantized.push((q1 << 4) | q0);
232        }
233    }
234
235    // Apply double quantization if enabled
236    let (scales_quantized, scales_scales) = if config.double_quant {
237        let (sq, ss) = double_quantize_scales(&scales, 255);
238        (Some(sq), Some(ss))
239    } else {
240        (None, None)
241    };
242
243    // Compute zero points if enabled
244    let zero_points = if config.use_zero_point {
245        Some(compute_zero_points(&flat, config.block_size, &scales))
246    } else {
247        None
248    };
249
250    Ok(QuantizedTensor {
251        data: quantized,
252        scales,
253        zero_points,
254        scales_quantized,
255        scales_scales,
256        shape,
257        block_size: config.block_size,
258        double_quant_enabled: config.double_quant,
259        strategy: config.strategy,
260    })
261}
262
263/// Quantize a tensor using per-channel strategy.
264///
265/// For per-channel quantization, each output channel (first dimension) gets its own
266/// set of scales. This provides better accuracy for weights with different ranges
267/// across channels.
268fn quantize_per_channel(tensor: &Tensor, config: &QuantizationConfig) -> Result<QuantizedTensor> {
269    let shape = tensor.shape().dims().to_vec();
270
271    // For per-channel, we need at least 2D tensor (channels x features)
272    if shape.len() < 2 {
273        return Err(QLoraError::InvalidConfig(
274            "Per-channel quantization requires at least 2D tensor".to_string(),
275        ));
276    }
277
278    let num_channels = shape[0];
279    let channel_size = shape[1..].iter().product::<usize>();
280    let flat = tensor.flatten_all()?.to_vec1::<f32>()?;
281
282    let mut scales = Vec::with_capacity(num_channels);
283    let mut quantized = Vec::with_capacity(flat.len().div_ceil(2));
284
285    // Quantize each channel separately
286    for ch_idx in 0..num_channels {
287        let ch_start = ch_idx * channel_size;
288        let ch_end = ch_start + channel_size;
289        let channel_data = &flat[ch_start..ch_end];
290
291        // Compute scale for entire channel
292        let absmax = channel_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
293        let scale = if absmax > 0.0 { absmax } else { 1.0 };
294        scales.push(scale);
295
296        // Quantize channel values
297        for chunk in channel_data.chunks(2) {
298            let q0 = quantize_value_nf4(chunk[0] / scale);
299            let q1 = if chunk.len() > 1 {
300                quantize_value_nf4(chunk[1] / scale)
301            } else {
302                0
303            };
304            quantized.push((q1 << 4) | q0);
305        }
306    }
307
308    // Apply double quantization if enabled
309    let (scales_quantized, scales_scales) = if config.double_quant {
310        let (sq, ss) = double_quantize_scales(&scales, 255);
311        (Some(sq), Some(ss))
312    } else {
313        (None, None)
314    };
315
316    // For per-channel, zero points apply per channel if enabled
317    let zero_points = if config.use_zero_point {
318        let mut zps = Vec::with_capacity(num_channels);
319        #[allow(clippy::needless_range_loop)]
320        for ch_idx in 0..num_channels {
321            let ch_start = ch_idx * channel_size;
322            let ch_end = ch_start + channel_size;
323            let channel_data = &flat[ch_start..ch_end];
324            let min_val = channel_data.iter().copied().fold(f32::INFINITY, f32::min);
325            let zp = if scales[ch_idx] > 0.0 {
326                -min_val / scales[ch_idx]
327            } else {
328                0.0
329            };
330            zps.push(zp);
331        }
332        Some(zps)
333    } else {
334        None
335    };
336
337    Ok(QuantizedTensor {
338        data: quantized,
339        scales,
340        zero_points,
341        scales_quantized,
342        scales_scales,
343        shape,
344        block_size: channel_size, // For per-channel, block_size = channel size
345        double_quant_enabled: config.double_quant,
346        strategy: config.strategy,
347    })
348}
349
/// Apply double quantization to scale factors.
///
/// Double quantization quantizes the scale factors themselves to reduce memory usage.
/// Each scale is stored as an 8-bit code `round(s / scale_factor) + 128`, where
/// `scale_factor = absmax / 127`. Code 128 therefore decodes to exactly `0.0`
/// and the largest magnitude maps to 1 or 255.
///
/// Values are rounded to the nearest code rather than truncated, which halves
/// the worst-case error and keeps `absmax` from slipping to the code below it
/// because of floating-point rounding in the division.
///
/// # Arguments
/// * `scales` - Original float32 scale factors
/// * `_max_val` - Currently ignored; the 8-bit range is fixed by the encoding
///
/// # Returns
/// Tuple of (`quantized_scales`, `scale_factors_for_scales`). The second vector
/// holds a single shared scale factor, or is empty when `scales` is empty.
fn double_quantize_scales(scales: &[f32], _max_val: usize) -> (Vec<u8>, Vec<f32>) {
    // Code that decodes back to 0.0 under the `(code - 128) * scale_factor` rule.
    const ZERO_CODE: u8 = 128;

    if scales.is_empty() {
        return (Vec::new(), Vec::new());
    }

    // Find the absolute maximum value in scales
    let absmax = scales.iter().map(|s| s.abs()).fold(0.0f32, f32::max);

    if absmax == 0.0 {
        // All scales are zero: emit the zero code so decoding yields 0.0.
        return (vec![ZERO_CODE; scales.len()], vec![1.0]);
    }

    // Symmetric signed range -127..=127, stored with an offset of 128.
    let scale_factor = absmax / 127.0;
    let quantized_scales: Vec<u8> = scales
        .iter()
        .map(|&s| {
            let code = (s / scale_factor).round() + f32::from(ZERO_CODE);
            // The value is rounded and clamped to 0..=255, so the cast is exact.
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let result = code.clamp(0.0, 255.0) as u8;
            result
        })
        .collect();

    (quantized_scales, vec![scale_factor])
}
388
/// Dequantize double-quantized scales back to float.
///
/// Each code is decoded as `(code - 128) * scales_scales[0]`. Only the first
/// entry of `scales_scales` is used. An empty vector is returned when either
/// input is empty.
fn dequantize_double_scales(scales_quantized: &[u8], scales_scales: &[f32]) -> Vec<f32> {
    let Some(&scale_factor) = scales_scales.first() else {
        return Vec::new();
    };

    let mut restored = Vec::with_capacity(scales_quantized.len());
    for &code in scales_quantized {
        // Remove the 128 offset, then scale back up.
        let centered = f32::from(code) - 128.0;
        restored.push(centered * scale_factor);
    }
    restored
}
401
/// Compute zero points for asymmetric quantization.
///
/// For every complete block of `block_size` values the zero point is the
/// negated block minimum divided by the block's scale, or `0.0` when the scale
/// is not positive. Trailing values that do not fill a whole block are ignored.
///
/// # Arguments
/// * `data` - Original float data
/// * `block_size` - Block size for quantization
/// * `scales` - Pre-computed scale factors
///
/// # Returns
/// Vector of zero points (one per block)
///
/// # Panics
/// Panics if `block_size` is zero or if `scales` has fewer entries than there
/// are complete blocks.
fn compute_zero_points(data: &[f32], block_size: usize, scales: &[f32]) -> Vec<f32> {
    data.chunks_exact(block_size)
        .enumerate()
        .map(|(block_idx, block)| {
            let scale = scales[block_idx];
            let lowest = block.iter().copied().fold(f32::INFINITY, f32::min);

            // Shift by the block minimum, expressed in units of the scale.
            if scale > 0.0 {
                -lowest / scale
            } else {
                0.0
            }
        })
        .collect()
}
437
438/// Dequantize an NF4 tensor back to float.
439///
440/// Automatically handles double-quantized scales if enabled.
441///
442/// # Arguments
443/// * `quantized` - Quantized tensor to dequantize
444/// * `device` - Device to create the output tensor on
445///
446/// # Returns
447/// Dequantized float tensor with original shape
448///
449/// # Errors
450/// Returns an error if the tensor cannot be created on the specified device
451pub fn dequantize_nf4(quantized: &QuantizedTensor, device: &Device) -> Result<Tensor> {
452    let numel = quantized.numel();
453    let mut output = Vec::with_capacity(numel);
454
455    // Get scales, applying double quantization reversal if needed
456    let scales = if quantized.double_quant_enabled {
457        if let (Some(ref sq), Some(ref ss)) =
458            (&quantized.scales_quantized, &quantized.scales_scales)
459        {
460            dequantize_double_scales(sq, ss)
461        } else {
462            quantized.scales.clone()
463        }
464    } else {
465        quantized.scales.clone()
466    };
467
468    match quantized.strategy {
469        QuantizationStrategy::PerTensor => {
470            let num_blocks = scales.len();
471
472            for block_idx in 0..num_blocks {
473                let scale = scales[block_idx];
474                let zero_point = quantized
475                    .zero_points
476                    .as_ref()
477                    .map_or(0.0, |zp| zp[block_idx]);
478                let start_byte = (block_idx * quantized.block_size) / 2;
479
480                for i in 0..quantized.block_size {
481                    let byte_idx = start_byte + i / 2;
482                    let byte = quantized.data[byte_idx];
483                    let code = if i % 2 == 0 { byte & 0x0F } else { byte >> 4 };
484                    // Apply zero point for asymmetric quantization
485                    let nf4_value = NF4_LEVELS[code as usize] + zero_point;
486                    let value = nf4_value * scale;
487                    output.push(value);
488                }
489            }
490        }
491        QuantizationStrategy::PerChannel => {
492            let num_channels = scales.len();
493            let channel_size = quantized.block_size;
494
495            for ch_idx in 0..num_channels {
496                let scale = scales[ch_idx];
497                let zero_point = quantized.zero_points.as_ref().map_or(0.0, |zp| zp[ch_idx]);
498                let ch_start_byte = (ch_idx * channel_size) / 2;
499
500                for i in 0..channel_size {
501                    let byte_idx = ch_start_byte + i / 2;
502                    let byte = quantized.data[byte_idx];
503                    let code = if i % 2 == 0 { byte & 0x0F } else { byte >> 4 };
504                    let nf4_value = NF4_LEVELS[code as usize] + zero_point;
505                    let value = nf4_value * scale;
506                    output.push(value);
507                }
508            }
509        }
510    }
511
512    let tensor = Tensor::from_vec(output, quantized.shape.clone(), device)?;
513    Ok(tensor)
514}
515
516/// Dequantize an NF4 tensor to a specific data type for mixed precision computation.
517///
518/// This function dequantizes the tensor and converts it to the specified compute
519/// data type (F16, BF16, or F32) for efficient mixed precision training.
520///
521/// # Arguments
522/// * `quantized` - Quantized tensor to dequantize
523/// * `device` - Device to create the output tensor on
524/// * `compute_dtype` - Target data type for the output tensor
525///
526/// # Returns
527/// Dequantized tensor in the specified data type
528///
529/// # Errors
530/// Returns an error if the tensor cannot be created or dtype conversion fails
531pub fn dequantize_nf4_with_dtype(
532    quantized: &QuantizedTensor,
533    device: &Device,
534    compute_dtype: ComputeDType,
535) -> Result<Tensor> {
536    // First dequantize to f32
537    let f32_tensor = dequantize_nf4(quantized, device)?;
538
539    // Convert to target dtype
540    let dtype = match compute_dtype {
541        ComputeDType::F32 => return Ok(f32_tensor),
542        ComputeDType::F16 => DType::F16,
543        ComputeDType::BF16 => DType::BF16,
544    };
545
546    let converted = f32_tensor.to_dtype(dtype)?;
547    Ok(converted)
548}
549
550/// Pad a tensor to ensure its size is divisible by the block size.
551///
552/// This is useful for quantization-aware model preparation, ensuring that
553/// all weight tensors can be cleanly divided into quantization blocks.
554/// The tensor is flattened, padded, and the padded flat tensor is returned.
555/// The original shape is preserved in the returned tensor's metadata via
556/// `pad_for_quantization_with_info`.
557///
558/// # Arguments
559/// * `tensor` - Input tensor to pad (must be F32 dtype)
560/// * `block_size` - Target block size for quantization
561/// * `pad_value` - Value to use for padding (typically 0.0)
562///
563/// # Returns
564/// 1D padded tensor with size divisible by `block_size`
565///
566/// # Errors
567/// Returns an error if the tensor cannot be processed
568///
569/// # Note
570/// This function only works with F32 tensors. Ensure your tensor is in F32 format
571/// before calling this function.
572///
573/// # Errors
574/// Returns an error if the tensor is not F32 dtype.
575pub fn pad_for_quantization(tensor: &Tensor, block_size: usize, pad_value: f32) -> Result<Tensor> {
576    // Validate dtype - only F32 is supported
577    if tensor.dtype() != candle_core::DType::F32 {
578        return Err(QLoraError::InvalidConfig(format!(
579            "pad_for_quantization only supports F32 tensors, got {:?}. \
580             Convert to F32 first with tensor.to_dtype(DType::F32)",
581            tensor.dtype()
582        )));
583    }
584
585    let numel = tensor.elem_count();
586    let device = tensor.device();
587
588    // Check if padding is needed
589    let remainder = numel % block_size;
590    if remainder == 0 {
591        // Flatten and return
592        let flat = tensor.flatten_all()?;
593        return Ok(flat);
594    }
595
596    // Calculate padding needed
597    let pad_count = block_size - remainder;
598    let flat = tensor.flatten_all()?.to_vec1::<f32>()?;
599
600    // Create padded vector
601    let mut padded = flat;
602    padded.extend(std::iter::repeat_n(pad_value, pad_count));
603
604    // Return as 1D tensor
605    let padded_tensor = Tensor::from_vec(padded, (numel + pad_count,), device)?;
606    Ok(padded_tensor)
607}
608
/// Information about padding applied for quantization.
///
/// Carries what is needed to undo the padding later: the shape to restore and
/// how many trailing elements were added. Equality is derived so that padding
/// metadata can be compared directly, e.g. in tests or after a round trip.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PaddingInfo {
    /// Original shape before padding
    pub original_shape: Vec<usize>,
    /// Padded shape
    pub padded_shape: Vec<usize>,
    /// Number of elements added as padding
    pub pad_count: usize,
    /// Block size used for padding calculation
    pub block_size: usize,
}
621
622/// Pad a tensor and return information for later unpadding.
623///
624/// This function pads the tensor for quantization and returns padding metadata
625/// that can be used to remove the padding after dequantization.
626/// The tensor is flattened and padded to be divisible by `block_size`.
627///
628/// # Arguments
629/// * `tensor` - Input tensor to pad (must be F32 dtype)
630/// * `block_size` - Target block size for quantization
631/// * `pad_value` - Value to use for padding (typically 0.0)
632///
633/// # Returns
634/// Tuple of (1D padded tensor, `PaddingInfo` with original shape)
635///
636/// # Errors
637/// Returns an error if the tensor cannot be processed
638///
639/// # Note
640/// This function only works with F32 tensors. Ensure your tensor is in F32 format
641/// before calling this function. For mixed precision workflows, pad before quantizing,
642/// then dequantize to your target dtype (F16/BF16), and convert back to F32 before
643/// unpadding.
644pub fn pad_for_quantization_with_info(
645    tensor: &Tensor,
646    block_size: usize,
647    pad_value: f32,
648) -> Result<(Tensor, PaddingInfo)> {
649    // Validate dtype - only F32 is supported
650    if tensor.dtype() != candle_core::DType::F32 {
651        return Err(QLoraError::InvalidConfig(format!(
652            "pad_for_quantization_with_info only supports F32 tensors, got {:?}. \
653             Convert to F32 first with tensor.to_dtype(DType::F32)",
654            tensor.dtype()
655        )));
656    }
657
658    let original_shape = tensor.shape().dims().to_vec();
659    let numel = tensor.elem_count();
660    let device = tensor.device();
661
662    let remainder = numel % block_size;
663    let pad_count = if remainder == 0 {
664        0
665    } else {
666        block_size - remainder
667    };
668
669    if pad_count == 0 {
670        let flat = tensor.flatten_all()?;
671        let info = PaddingInfo {
672            original_shape: original_shape.clone(),
673            padded_shape: vec![numel],
674            pad_count: 0,
675            block_size,
676        };
677        return Ok((flat, info));
678    }
679
680    let flat = tensor.flatten_all()?.to_vec1::<f32>()?;
681    let mut padded = flat;
682    padded.extend(std::iter::repeat_n(pad_value, pad_count));
683
684    let padded_len = numel + pad_count;
685    let padded_tensor = Tensor::from_vec(padded, (padded_len,), device)?;
686
687    let info = PaddingInfo {
688        original_shape,
689        padded_shape: vec![padded_len],
690        pad_count,
691        block_size,
692    };
693
694    Ok((padded_tensor, info))
695}
696
697/// Remove padding from a dequantized tensor using stored padding information.
698///
699/// # Arguments
700/// * `tensor` - Padded tensor to unpad (must be F32 dtype)
701/// * `padding_info` - Padding information from `pad_for_quantization_with_info`
702///
703/// # Returns
704/// Tensor with original shape (padding removed)
705///
706/// # Errors
707/// Returns an error if unpadding fails
708///
709/// # Note
710/// This function only works with F32 tensors. For mixed precision workflows,
711/// convert the tensor to F32 before unpadding.
712pub fn unpad_tensor(tensor: &Tensor, padding_info: &PaddingInfo) -> Result<Tensor> {
713    // Validate dtype - only F32 is supported
714    if tensor.dtype() != candle_core::DType::F32 {
715        return Err(QLoraError::InvalidConfig(format!(
716            "unpad_tensor only supports F32 tensors, got {:?}. \
717             Convert to F32 first with tensor.to_dtype(DType::F32)",
718            tensor.dtype()
719        )));
720    }
721
722    if padding_info.pad_count == 0 {
723        // Flatten first to ensure consistent behavior regardless of input shape
724        let flat = tensor.flatten_all()?;
725        let reshaped = flat.reshape(padding_info.original_shape.clone())?;
726        return Ok(reshaped);
727    }
728
729    let flat = tensor.flatten_all()?.to_vec1::<f32>()?;
730    let original_numel: usize = padding_info.original_shape.iter().product();
731
732    // Remove padding
733    let unpadded: Vec<f32> = flat.into_iter().take(original_numel).collect();
734
735    let unpadded_tensor = Tensor::from_vec(
736        unpadded,
737        padding_info.original_shape.clone(),
738        tensor.device(),
739    )?;
740    Ok(unpadded_tensor)
741}
742
743/// Quantize a single value to NF4 (returns 4-bit code).
744fn quantize_value_nf4(value: f32) -> u8 {
745    // Find closest NF4 level
746    let mut best_idx = 0;
747    let mut best_dist = f32::MAX;
748
749    for (idx, &level) in NF4_LEVELS.iter().enumerate() {
750        let dist = (value - level).abs();
751        if dist < best_dist {
752            best_dist = dist;
753            best_idx = idx;
754        }
755    }
756
757    #[allow(clippy::cast_possible_truncation)]
758    let result = best_idx as u8;
759    result
760}
761
762#[cfg(test)]
763mod tests {
764    use super::*;
765    use candle_core::DType;
766
767    #[test]
768    fn test_nf4_levels_sorted() {
769        for i in 1..NF4_LEVELS.len() {
770            assert!(NF4_LEVELS[i] > NF4_LEVELS[i - 1]);
771        }
772    }
773
774    #[test]
775    fn test_quantize_dequantize_roundtrip() {
776        let device = Device::Cpu;
777        let original = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();
778
779        let quantized = quantize_nf4(&original, 64).unwrap();
780        let restored = dequantize_nf4(&quantized, &device).unwrap();
781
782        let original_vec: Vec<f32> = original.to_vec1().unwrap();
783        let restored_vec: Vec<f32> = restored.to_vec1().unwrap();
784
785        // Check that error is bounded (NF4 should have <0.5 max error for normalized data)
786        for (o, r) in original_vec.iter().zip(restored_vec.iter()) {
787            let error = (o - r).abs();
788            assert!(error < 0.5, "Error {error} too large for value {o}");
789        }
790    }
791
792    #[test]
793    fn test_quantize_preserves_shape() {
794        let device = Device::Cpu;
795        let original = Tensor::zeros(&[32, 64], DType::F32, &device).unwrap();
796
797        let quantized = quantize_nf4(&original, 64).unwrap();
798        let restored = dequantize_nf4(&quantized, &device).unwrap();
799
800        assert_eq!(restored.shape().dims(), &[32, 64]);
801    }
802
803    #[test]
804    fn test_memory_reduction() {
805        let device = Device::Cpu;
806        let original = Tensor::zeros(&[1024, 1024], DType::F32, &device).unwrap();
807        let original_bytes: i32 = 1024 * 1024 * 4; // f32 = 4 bytes
808
809        let quantized = quantize_nf4(&original, 64).unwrap();
810        let quantized_bytes = quantized.size_bytes();
811
812        // Should be roughly 4x reduction (4-bit vs 32-bit) plus some overhead for scales
813        #[allow(clippy::cast_precision_loss)]
814        let ratio = f64::from(original_bytes) / quantized_bytes as f64;
815        assert!(ratio > 3.0, "Expected >3x reduction, got {ratio:.2}x");
816    }
817
818    #[test]
819    fn test_double_quantize_compression() {
820        let device = Device::Cpu;
821        let original = Tensor::randn(0.0f32, 1.0, (512,), &device).unwrap();
822
823        let config = QuantizationConfig {
824            block_size: 64,
825            double_quant: true,
826            compute_dtype: ComputeDType::F32,
827            strategy: QuantizationStrategy::PerTensor,
828            use_zero_point: false,
829        };
830
831        let quantized = quantize_nf4_with_config(&original, &config).unwrap();
832
833        // Verify double quantization was applied
834        assert!(quantized.double_quant_enabled);
835        assert!(quantized.scales_quantized.is_some());
836        assert!(quantized.scales_scales.is_some());
837
838        // Double quantized should use less memory than non-double quantized
839        let non_dq_size = quantized.scales.len() * 4; // Original scales
840        let dq_scales_size = quantized.scales_quantized.as_ref().map_or(0, Vec::len)
841            + quantized
842                .scales_scales
843                .as_ref()
844                .map_or(0, |ss| ss.len() * 4);
845
846        assert!(dq_scales_size < non_dq_size);
847    }
848
849    #[test]
850    fn test_double_quantize_roundtrip() {
851        let device = Device::Cpu;
852        let original = Tensor::randn(0.0f32, 1.0, (256,), &device).unwrap();
853
854        let config = QuantizationConfig {
855            block_size: 64,
856            double_quant: true,
857            compute_dtype: ComputeDType::F32,
858            strategy: QuantizationStrategy::PerTensor,
859            use_zero_point: false,
860        };
861
862        let quantized = quantize_nf4_with_config(&original, &config).unwrap();
863        let restored = dequantize_nf4(&quantized, &device).unwrap();
864
865        let original_vec: Vec<f32> = original.to_vec1().unwrap();
866        let restored_vec: Vec<f32> = restored.to_vec1().unwrap();
867
868        // With double quantization, error increases (scale quantization adds error)
869        // but should still be reasonable (typically 10-20% of value magnitude)
870        let mut max_error = 0.0f32;
871        for (o, r) in original_vec.iter().zip(restored_vec.iter()) {
872            let error = (o - r).abs();
873            max_error = max_error.max(error);
874        }
875        // Allow higher error for double quantization - scales are also quantized
876        assert!(max_error < 5.0, "Max error {max_error} too large");
877    }
878
879    #[test]
880    fn test_double_quant_disabled_still_works() {
881        let device = Device::Cpu;
882        let original = Tensor::randn(0.0f32, 1.0, (128,), &device).unwrap();
883
884        let config = QuantizationConfig {
885            block_size: 64,
886            double_quant: false,
887            compute_dtype: ComputeDType::F32,
888            strategy: QuantizationStrategy::PerTensor,
889            use_zero_point: false,
890        };
891
892        let quantized = quantize_nf4_with_config(&original, &config).unwrap();
893
894        // Verify double quantization was NOT applied
895        assert!(!quantized.double_quant_enabled);
896        assert!(quantized.scales_quantized.is_none());
897        assert!(quantized.scales_scales.is_none());
898
899        let restored = dequantize_nf4(&quantized, &device).unwrap();
900        let original_vec: Vec<f32> = original.to_vec1().unwrap();
901        let restored_vec: Vec<f32> = restored.to_vec1().unwrap();
902
903        // Regular quantization error bounds
904        for (o, r) in original_vec.iter().zip(restored_vec.iter()) {
905            let error = (o - r).abs();
906            assert!(error < 0.5, "Error {error} too large for value {o}");
907        }
908    }
909
910    #[test]
911    fn test_per_channel_quantization() {
912        let device = Device::Cpu;
913        // Create 2D tensor (4 channels x 128 features)
914        let original = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
915
916        let config = QuantizationConfig {
917            block_size: 64, // Not used for per-channel
918            double_quant: false,
919            compute_dtype: ComputeDType::F32,
920            strategy: QuantizationStrategy::PerChannel,
921            use_zero_point: false,
922        };
923
924        let quantized = quantize_nf4_with_config(&original, &config).unwrap();
925
926        // Verify we have one scale per channel
927        assert_eq!(
928            quantized.scales.len(),
929            4,
930            "Should have 4 scales (one per channel)"
931        );
932        assert_eq!(quantized.strategy, QuantizationStrategy::PerChannel);
933
934        // Verify roundtrip
935        let restored = dequantize_nf4(&quantized, &device).unwrap();
936        assert_eq!(restored.shape().dims(), &[4, 128]);
937
938        let original_vec: Vec<f32> = original.flatten_all().unwrap().to_vec1().unwrap();
939        let restored_vec: Vec<f32> = restored.flatten_all().unwrap().to_vec1().unwrap();
940
941        let mut max_error = 0.0f32;
942        for (o, r) in original_vec.iter().zip(restored_vec.iter()) {
943            let error = (o - r).abs();
944            max_error = max_error.max(error);
945        }
946        assert!(max_error < 1.0, "Max error {max_error} too large");
947    }
948
949    #[test]
950    fn test_zero_point_quantization() {
951        let device = Device::Cpu;
952        // Create asymmetric data (all positive values)
953        let original = Tensor::rand(0.0f32, 5.0, (256,), &device).unwrap();
954
955        let config = QuantizationConfig {
956            block_size: 64,
957            double_quant: false,
958            compute_dtype: ComputeDType::F32,
959            strategy: QuantizationStrategy::PerTensor,
960            use_zero_point: true,
961        };
962
963        let quantized = quantize_nf4_with_config(&original, &config).unwrap();
964
965        // Verify zero points were computed
966        assert!(quantized.zero_points.is_some());
967        let zero_points = quantized.zero_points.as_ref().unwrap();
968        assert_eq!(
969            zero_points.len(),
970            256 / 64,
971            "Should have one zero point per block"
972        );
973
974        // Verify roundtrip with zero points
975        let restored = dequantize_nf4(&quantized, &device).unwrap();
976        let original_vec: Vec<f32> = original.to_vec1().unwrap();
977        let restored_vec: Vec<f32> = restored.to_vec1().unwrap();
978
979        let mut max_error = 0.0f32;
980        for (o, r) in original_vec.iter().zip(restored_vec.iter()) {
981            let error = (o - r).abs();
982            max_error = max_error.max(error);
983        }
984        // Zero point quantization should handle asymmetric data better
985        assert!(max_error < 10.0, "Max error {max_error} too large");
986    }
987
988    #[test]
989    fn test_per_channel_with_zero_point() {
990        let device = Device::Cpu;
991        // Create asymmetric 2D data
992        let original = Tensor::rand(0.0f32, 3.0, (2, 128), &device).unwrap();
993
994        let config = QuantizationConfig {
995            block_size: 64,
996            double_quant: false,
997            compute_dtype: ComputeDType::F32,
998            strategy: QuantizationStrategy::PerChannel,
999            use_zero_point: true,
1000        };
1001
1002        let quantized = quantize_nf4_with_config(&original, &config).unwrap();
1003
1004        // Verify per-channel with zero points
1005        assert_eq!(quantized.scales.len(), 2);
1006        assert!(quantized.zero_points.is_some());
1007        let zps = quantized.zero_points.as_ref().unwrap();
1008        assert_eq!(zps.len(), 2, "Should have one zero point per channel");
1009
1010        // Verify roundtrip
1011        let restored = dequantize_nf4(&quantized, &device).unwrap();
1012        assert_eq!(restored.shape().dims(), &[2, 128]);
1013    }
1014
1015    #[test]
1016    fn test_per_channel_with_double_quant() {
1017        let device = Device::Cpu;
1018        let original = Tensor::randn(0.0f32, 1.0, (8, 64), &device).unwrap();
1019
1020        let config = QuantizationConfig {
1021            block_size: 64,
1022            double_quant: true,
1023            compute_dtype: ComputeDType::F32,
1024            strategy: QuantizationStrategy::PerChannel,
1025            use_zero_point: false,
1026        };
1027
1028        let quantized = quantize_nf4_with_config(&original, &config).unwrap();
1029
1030        // Verify both per-channel and double quantization applied
1031        assert_eq!(quantized.scales.len(), 8);
1032        assert!(quantized.double_quant_enabled);
1033        assert!(quantized.scales_quantized.is_some());
1034        assert_eq!(quantized.strategy, QuantizationStrategy::PerChannel);
1035
1036        // Verify roundtrip
1037        let restored = dequantize_nf4(&quantized, &device).unwrap();
1038        assert_eq!(restored.shape().dims(), &[8, 64]);
1039    }
1040
1041    #[test]
1042    fn test_mixed_precision_f16() {
1043        let device = Device::Cpu;
1044        let original = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();
1045
1046        let quantized = quantize_nf4(&original, 64).unwrap();
1047        let restored_f16 =
1048            dequantize_nf4_with_dtype(&quantized, &device, ComputeDType::F16).unwrap();
1049
1050        assert_eq!(restored_f16.dtype(), DType::F16);
1051        assert_eq!(restored_f16.shape().dims(), &[64]);
1052    }
1053
1054    #[test]
1055    fn test_mixed_precision_bf16() {
1056        let device = Device::Cpu;
1057        let original = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();
1058
1059        let quantized = quantize_nf4(&original, 64).unwrap();
1060        let restored_bf16 =
1061            dequantize_nf4_with_dtype(&quantized, &device, ComputeDType::BF16).unwrap();
1062
1063        assert_eq!(restored_bf16.dtype(), DType::BF16);
1064        assert_eq!(restored_bf16.shape().dims(), &[64]);
1065    }
1066
1067    #[test]
1068    fn test_mixed_precision_f32_passthrough() {
1069        let device = Device::Cpu;
1070        let original = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();
1071
1072        let quantized = quantize_nf4(&original, 64).unwrap();
1073        let restored_f32 =
1074            dequantize_nf4_with_dtype(&quantized, &device, ComputeDType::F32).unwrap();
1075
1076        assert_eq!(restored_f32.dtype(), DType::F32);
1077        assert_eq!(restored_f32.shape().dims(), &[64]);
1078    }
1079
1080    #[test]
1081    fn test_padding_for_quantization_needed() {
1082        let device = Device::Cpu;
1083        // 100 elements, block size 64 -> needs padding to 128
1084        let original = Tensor::randn(0.0f32, 1.0, (100,), &device).unwrap();
1085
1086        let padded = pad_for_quantization(&original, 64, 0.0).unwrap();
1087        let padded_numel = padded.elem_count();
1088
1089        // Should be padded to 128 (next multiple of 64)
1090        assert_eq!(padded_numel % 64, 0);
1091        assert_eq!(padded_numel, 128);
1092        // Result is always 1D
1093        assert_eq!(padded.shape().dims().len(), 1);
1094    }
1095
1096    #[test]
1097    fn test_padding_for_quantization_not_needed() {
1098        let device = Device::Cpu;
1099        // 128 elements already divisible by 64
1100        let original = Tensor::randn(0.0f32, 1.0, (128,), &device).unwrap();
1101
1102        let padded = pad_for_quantization(&original, 64, 0.0).unwrap();
1103
1104        // Should remain the same count
1105        assert_eq!(padded.elem_count(), 128);
1106        // Result is always 1D (flattened)
1107        assert_eq!(padded.shape().dims().len(), 1);
1108    }
1109
1110    #[test]
1111    fn test_padding_with_info_roundtrip() {
1112        let device = Device::Cpu;
1113        let original = Tensor::randn(0.0f32, 1.0, (100,), &device).unwrap();
1114
1115        // Pad
1116        let (padded, info) = pad_for_quantization_with_info(&original, 64, 0.0).unwrap();
1117        assert_eq!(info.pad_count, 28); // 128 - 100
1118        assert_eq!(info.original_shape, vec![100]);
1119        assert_eq!(info.padded_shape, vec![128]); // 1D padded
1120
1121        // Quantize the padded tensor
1122        let quantized = quantize_nf4(&padded, 64).unwrap();
1123
1124        // Dequantize
1125        let restored_padded = dequantize_nf4(&quantized, &device).unwrap();
1126
1127        // Unpad
1128        let restored = unpad_tensor(&restored_padded, &info).unwrap();
1129        assert_eq!(restored.shape().dims(), &[100]);
1130    }
1131
1132    #[test]
1133    fn test_padding_2d_tensor() {
1134        let device = Device::Cpu;
1135        // Create a 2D tensor that needs padding
1136        // 4x10 = 40 elements, needs padding to 64
1137        let original = Tensor::randn(0.0f32, 1.0, (4, 10), &device).unwrap();
1138
1139        let (padded, info) = pad_for_quantization_with_info(&original, 64, 0.0).unwrap();
1140
1141        assert_eq!(info.pad_count, 24); // 64 - 40 = 24
1142        assert_eq!(info.original_shape, vec![4, 10]);
1143        // Padded shape is 1D
1144        assert_eq!(info.padded_shape, vec![64]);
1145
1146        // Total elements should be 64
1147        assert_eq!(padded.elem_count(), 64);
1148
1149        // Verify padding actually happened
1150        let padded_flat: Vec<f32> = padded.to_vec1().unwrap();
1151        assert_eq!(padded_flat.len(), 64);
1152    }
1153
1154    #[test]
1155    fn test_padding_preserves_values() {
1156        let device = Device::Cpu;
1157        let original_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
1158        let original = Tensor::from_vec(original_data.clone(), (5,), &device).unwrap();
1159
1160        let padded = pad_for_quantization(&original, 8, 0.0).unwrap();
1161        let padded_vec: Vec<f32> = padded.to_vec1().unwrap();
1162
1163        // First 5 values should be preserved
1164        assert_eq!(&padded_vec[..5], &original_data[..]);
1165        // Remaining should be padding
1166        assert_eq!(&padded_vec[5..], &[0.0, 0.0, 0.0]);
1167    }
1168
1169    #[test]
1170    fn test_padding_2d_tensor_no_padding_needed() {
1171        let device = Device::Cpu;
1172        // Create a 2D tensor that doesn't need padding
1173        // 8x8 = 64 elements, exactly divisible by 64
1174        let original = Tensor::randn(0.0f32, 1.0, (8, 8), &device).unwrap();
1175
1176        let (padded, info) = pad_for_quantization_with_info(&original, 64, 0.0).unwrap();
1177
1178        // No padding needed
1179        assert_eq!(info.pad_count, 0);
1180        assert_eq!(info.original_shape, vec![8, 8]);
1181        assert_eq!(info.padded_shape, vec![64]);
1182        assert_eq!(padded.elem_count(), 64);
1183
1184        // Test roundtrip: unpad should restore original shape
1185        let restored = unpad_tensor(&padded, &info).unwrap();
1186        assert_eq!(restored.shape().dims(), &[8, 8]);
1187    }
1188
1189    #[test]
1190    fn test_padding_2d_no_padding_roundtrip() {
1191        let device = Device::Cpu;
1192        // 8x8 = 64 elements, exactly divisible by 64 (no padding needed)
1193        #[allow(clippy::cast_precision_loss)]
1194        let original_data: Vec<f32> = (0..64).map(|i| i as f32).collect();
1195        let original = Tensor::from_vec(original_data.clone(), (8, 8), &device).unwrap();
1196
1197        let (padded, info) = pad_for_quantization_with_info(&original, 64, 0.0).unwrap();
1198        assert_eq!(info.pad_count, 0);
1199
1200        let restored = unpad_tensor(&padded, &info).unwrap();
1201
1202        // Verify shape is restored
1203        assert_eq!(restored.shape().dims(), &[8, 8]);
1204
1205        // Verify VALUES are preserved through the entire pad/unpad cycle
1206        let restored_data: Vec<f32> = restored.flatten_all().unwrap().to_vec1().unwrap();
1207        assert_eq!(
1208            restored_data, original_data,
1209            "Values should be preserved through pad/unpad cycle"
1210        );
1211    }
1212
1213    #[test]
1214    fn test_padding_dtype_validation() {
1215        let device = Device::Cpu;
1216        // Create an F16 tensor - should fail with clear error
1217        let f16_tensor = Tensor::ones((64,), DType::F16, &device).unwrap();
1218
1219        let result = pad_for_quantization(&f16_tensor, 64, 0.0);
1220        assert!(result.is_err());
1221        let err_msg = result.unwrap_err().to_string();
1222        assert!(
1223            err_msg.contains("F32"),
1224            "Error should mention F32 requirement: {err_msg}"
1225        );
1226
1227        let result = pad_for_quantization_with_info(&f16_tensor, 64, 0.0);
1228        assert!(result.is_err());
1229
1230        // unpad_tensor dtype validation
1231        let f16_tensor = Tensor::ones((64,), DType::F16, &device).unwrap();
1232        let dummy_info = PaddingInfo {
1233            original_shape: vec![8, 8],
1234            padded_shape: vec![64],
1235            pad_count: 0,
1236            block_size: 64,
1237        };
1238        let result = unpad_tensor(&f16_tensor, &dummy_info);
1239        assert!(result.is_err());
1240    }
1241}