Skip to main content

entrenar/quant/granularity/
quantize.rs

1//! Quantization and dequantization functions
2
3use super::{
4    calibrate_per_channel, calibrate_per_group, calibrate_per_tensor, QuantGranularity, QuantMode,
5    QuantParams, QuantizedTensor,
6};
7
8/// Quantize values using given parameters
9///
10/// # Arguments
11/// * `values` - Input f32 values
12/// * `params` - Quantization parameters
13pub fn quantize_with_params(values: &[f32], params: &QuantParams) -> Vec<i8> {
14    let qmax_signed = ((1i32 << (params.bits - 1)) - 1) as f32;
15    let qmin_signed = -qmax_signed - 1.0;
16    let qmax_unsigned = ((1i32 << params.bits) - 1) as f32;
17
18    let group_size = match params.granularity {
19        QuantGranularity::PerTensor => values.len(),
20        QuantGranularity::PerChannel => values.len() / params.scales.len().max(1),
21        QuantGranularity::PerGroup(size) => size,
22    };
23
24    let mut result = Vec::with_capacity(values.len());
25
26    for (i, &val) in values.iter().enumerate() {
27        let group_idx = i / group_size.max(1);
28        let scale = params.scales.get(group_idx).copied().unwrap_or(1.0);
29
30        let q_val = match params.mode {
31            QuantMode::Symmetric => (val / scale).round().clamp(qmin_signed, qmax_signed) as i8,
32            QuantMode::Asymmetric => {
33                let zp = params.zero_points.get(group_idx).copied().unwrap_or(0);
34                let q = (val / scale + zp as f32).round().clamp(0.0, qmax_unsigned);
35                // Store as signed for uniform representation
36                (q as i32 - 128) as i8
37            }
38        };
39
40        result.push(q_val);
41    }
42
43    result
44}
45
46/// Dequantize values using given parameters
47///
48/// # Arguments
49/// * `quantized` - Quantized i8 values
50/// * `params` - Quantization parameters
51pub fn dequantize_with_params(quantized: &[i8], params: &QuantParams) -> Vec<f32> {
52    let group_size = match params.granularity {
53        QuantGranularity::PerTensor => quantized.len(),
54        QuantGranularity::PerChannel => quantized.len() / params.scales.len().max(1),
55        QuantGranularity::PerGroup(size) => size,
56    };
57
58    let mut result = Vec::with_capacity(quantized.len());
59
60    for (i, &q_val) in quantized.iter().enumerate() {
61        let group_idx = i / group_size.max(1);
62        let scale = params.scales.get(group_idx).copied().unwrap_or(1.0);
63
64        let val = match params.mode {
65            QuantMode::Symmetric => f32::from(q_val) * scale,
66            QuantMode::Asymmetric => {
67                let zp = params.zero_points.get(group_idx).copied().unwrap_or(0);
68                // Convert back from signed storage
69                let q_unsigned = (i32::from(q_val) + 128) as f32;
70                (q_unsigned - zp as f32) * scale
71            }
72        };
73
74        result.push(val);
75    }
76
77    result
78}
79
80/// Quantize tensor with specified granularity
81///
82/// # Arguments
83/// * `values` - Input tensor values
84/// * `shape` - Tensor shape
85/// * `granularity` - Quantization granularity
86/// * `mode` - Quantization mode
87/// * `bits` - Bit width (4 or 8)
88pub fn quantize_tensor(
89    values: &[f32],
90    shape: &[usize],
91    granularity: QuantGranularity,
92    mode: QuantMode,
93    bits: u8,
94) -> QuantizedTensor {
95    let params = match granularity {
96        QuantGranularity::PerTensor => calibrate_per_tensor(values, bits, mode),
97        QuantGranularity::PerChannel => {
98            let num_channels = shape.first().copied().unwrap_or(1);
99            calibrate_per_channel(values, num_channels, bits, mode)
100        }
101        QuantGranularity::PerGroup(group_size) => {
102            calibrate_per_group(values, group_size, bits, mode)
103        }
104    };
105
106    let data = quantize_with_params(values, &params);
107
108    QuantizedTensor { data, params, shape: shape.to_vec() }
109}
110
111/// Dequantize tensor
112pub fn dequantize_tensor(quantized: &QuantizedTensor) -> Vec<f32> {
113    dequantize_with_params(&quantized.data, &quantized.params)
114}