entrenar/quant/granularity/
quantize.rs1use super::{
4 calibrate_per_channel, calibrate_per_group, calibrate_per_tensor, QuantGranularity, QuantMode,
5 QuantParams, QuantizedTensor,
6};
7
8pub fn quantize_with_params(values: &[f32], params: &QuantParams) -> Vec<i8> {
14 let qmax_signed = ((1i32 << (params.bits - 1)) - 1) as f32;
15 let qmin_signed = -qmax_signed - 1.0;
16 let qmax_unsigned = ((1i32 << params.bits) - 1) as f32;
17
18 let group_size = match params.granularity {
19 QuantGranularity::PerTensor => values.len(),
20 QuantGranularity::PerChannel => values.len() / params.scales.len().max(1),
21 QuantGranularity::PerGroup(size) => size,
22 };
23
24 let mut result = Vec::with_capacity(values.len());
25
26 for (i, &val) in values.iter().enumerate() {
27 let group_idx = i / group_size.max(1);
28 let scale = params.scales.get(group_idx).copied().unwrap_or(1.0);
29
30 let q_val = match params.mode {
31 QuantMode::Symmetric => (val / scale).round().clamp(qmin_signed, qmax_signed) as i8,
32 QuantMode::Asymmetric => {
33 let zp = params.zero_points.get(group_idx).copied().unwrap_or(0);
34 let q = (val / scale + zp as f32).round().clamp(0.0, qmax_unsigned);
35 (q as i32 - 128) as i8
37 }
38 };
39
40 result.push(q_val);
41 }
42
43 result
44}
45
46pub fn dequantize_with_params(quantized: &[i8], params: &QuantParams) -> Vec<f32> {
52 let group_size = match params.granularity {
53 QuantGranularity::PerTensor => quantized.len(),
54 QuantGranularity::PerChannel => quantized.len() / params.scales.len().max(1),
55 QuantGranularity::PerGroup(size) => size,
56 };
57
58 let mut result = Vec::with_capacity(quantized.len());
59
60 for (i, &q_val) in quantized.iter().enumerate() {
61 let group_idx = i / group_size.max(1);
62 let scale = params.scales.get(group_idx).copied().unwrap_or(1.0);
63
64 let val = match params.mode {
65 QuantMode::Symmetric => f32::from(q_val) * scale,
66 QuantMode::Asymmetric => {
67 let zp = params.zero_points.get(group_idx).copied().unwrap_or(0);
68 let q_unsigned = (i32::from(q_val) + 128) as f32;
70 (q_unsigned - zp as f32) * scale
71 }
72 };
73
74 result.push(val);
75 }
76
77 result
78}
79
80pub fn quantize_tensor(
89 values: &[f32],
90 shape: &[usize],
91 granularity: QuantGranularity,
92 mode: QuantMode,
93 bits: u8,
94) -> QuantizedTensor {
95 let params = match granularity {
96 QuantGranularity::PerTensor => calibrate_per_tensor(values, bits, mode),
97 QuantGranularity::PerChannel => {
98 let num_channels = shape.first().copied().unwrap_or(1);
99 calibrate_per_channel(values, num_channels, bits, mode)
100 }
101 QuantGranularity::PerGroup(group_size) => {
102 calibrate_per_group(values, group_size, bits, mode)
103 }
104 };
105
106 let data = quantize_with_params(values, ¶ms);
107
108 QuantizedTensor { data, params, shape: shape.to_vec() }
109}
110
111pub fn dequantize_tensor(quantized: &QuantizedTensor) -> Vec<f32> {
113 dequantize_with_params(&quantized.data, &quantized.params)
114}