Skip to main content

entrenar/quant/granularity/
calibrate.rs

1//! Calibration functions for different quantization granularities
2
3use super::{QuantGranularity, QuantMode, QuantParams};
4
5/// Calibrate quantization parameters for per-tensor quantization
6///
7/// # Arguments
8/// * `values` - Input tensor values
9/// * `bits` - Bit width (4 or 8)
10/// * `mode` - Symmetric or asymmetric quantization
11pub fn calibrate_per_tensor(values: &[f32], bits: u8, mode: QuantMode) -> QuantParams {
12    let (scale, zero_point) = match mode {
13        QuantMode::Symmetric => {
14            let max_abs = values
15                .iter()
16                .map(|v| v.abs())
17                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
18                .unwrap_or(1e-8)
19                .max(1e-8);
20
21            let qmax = ((1i32 << (bits - 1)) - 1) as f32;
22            let scale = max_abs / qmax;
23            (scale, 0)
24        }
25        QuantMode::Asymmetric => {
26            let (min_val, max_val) =
27                values.iter().fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
28
29            let range = (max_val - min_val).max(1e-8);
30            let qmax = ((1i32 << bits) - 1) as f32;
31            let scale = range / qmax;
32            let zero_point = ((-min_val / scale).round() as i32).clamp(0, qmax as i32);
33            (scale, zero_point)
34        }
35    };
36
37    QuantParams {
38        scales: vec![scale],
39        zero_points: if mode == QuantMode::Asymmetric { vec![zero_point] } else { vec![] },
40        granularity: QuantGranularity::PerTensor,
41        mode,
42        bits,
43    }
44}
45
46/// Calibrate quantization parameters for per-channel quantization
47///
48/// # Arguments
49/// * `values` - Input tensor values (row-major: [channels, features])
50/// * `num_channels` - Number of channels (first dimension)
51/// * `bits` - Bit width (4 or 8)
52/// * `mode` - Symmetric or asymmetric quantization
53pub fn calibrate_per_channel(
54    values: &[f32],
55    num_channels: usize,
56    bits: u8,
57    mode: QuantMode,
58) -> QuantParams {
59    if num_channels == 0 || values.is_empty() {
60        return QuantParams {
61            scales: vec![1.0],
62            zero_points: if mode == QuantMode::Asymmetric { vec![0] } else { vec![] },
63            granularity: QuantGranularity::PerChannel,
64            mode,
65            bits,
66        };
67    }
68
69    let features_per_channel = values.len() / num_channels;
70    let qmax_signed = ((1i32 << (bits - 1)) - 1) as f32;
71    let qmax_unsigned = ((1i32 << bits) - 1) as f32;
72
73    let mut scales = Vec::with_capacity(num_channels);
74    let mut zero_points = Vec::with_capacity(num_channels);
75
76    for ch in 0..num_channels {
77        let start = ch * features_per_channel;
78        let end = start + features_per_channel;
79        let channel_values = &values[start..end];
80
81        match mode {
82            QuantMode::Symmetric => {
83                let max_abs = channel_values
84                    .iter()
85                    .map(|v| v.abs())
86                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
87                    .unwrap_or(1e-8)
88                    .max(1e-8);
89
90                scales.push(max_abs / qmax_signed);
91            }
92            QuantMode::Asymmetric => {
93                let (min_val, max_val) = channel_values
94                    .iter()
95                    .fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
96
97                let range = (max_val - min_val).max(1e-8);
98                let scale = range / qmax_unsigned;
99                let zp = ((-min_val / scale).round() as i32).clamp(0, qmax_unsigned as i32);
100
101                scales.push(scale);
102                zero_points.push(zp);
103            }
104        }
105    }
106
107    QuantParams { scales, zero_points, granularity: QuantGranularity::PerChannel, mode, bits }
108}
109
110/// Calibrate quantization parameters for per-group quantization
111///
112/// # Arguments
113/// * `values` - Input tensor values
114/// * `group_size` - Number of elements per group
115/// * `bits` - Bit width (4 or 8)
116/// * `mode` - Symmetric or asymmetric quantization
117pub fn calibrate_per_group(
118    values: &[f32],
119    group_size: usize,
120    bits: u8,
121    mode: QuantMode,
122) -> QuantParams {
123    let group_size = group_size.max(1);
124    let num_groups = values.len().div_ceil(group_size);
125    let qmax_signed = ((1i32 << (bits - 1)) - 1) as f32;
126    let qmax_unsigned = ((1i32 << bits) - 1) as f32;
127
128    let mut scales = Vec::with_capacity(num_groups);
129    let mut zero_points = Vec::with_capacity(num_groups);
130
131    for group_idx in 0..num_groups {
132        let start = group_idx * group_size;
133        let end = (start + group_size).min(values.len());
134        let group_values = &values[start..end];
135
136        match mode {
137            QuantMode::Symmetric => {
138                let max_abs = group_values
139                    .iter()
140                    .map(|v| v.abs())
141                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
142                    .unwrap_or(1e-8)
143                    .max(1e-8);
144
145                scales.push(max_abs / qmax_signed);
146            }
147            QuantMode::Asymmetric => {
148                let (min_val, max_val) = group_values
149                    .iter()
150                    .fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
151
152                let range = (max_val - min_val).max(1e-8);
153                let scale = range / qmax_unsigned;
154                let zp = ((-min_val / scale).round() as i32).clamp(0, qmax_unsigned as i32);
155
156                scales.push(scale);
157                zero_points.push(zp);
158            }
159        }
160    }
161
162    QuantParams {
163        scales,
164        zero_points,
165        granularity: QuantGranularity::PerGroup(group_size),
166        mode,
167        bits,
168    }
169}