entrenar/quant/granularity/
calibrate.rs1use super::{QuantGranularity, QuantMode, QuantParams};
4
5pub fn calibrate_per_tensor(values: &[f32], bits: u8, mode: QuantMode) -> QuantParams {
12 let (scale, zero_point) = match mode {
13 QuantMode::Symmetric => {
14 let max_abs = values
15 .iter()
16 .map(|v| v.abs())
17 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
18 .unwrap_or(1e-8)
19 .max(1e-8);
20
21 let qmax = ((1i32 << (bits - 1)) - 1) as f32;
22 let scale = max_abs / qmax;
23 (scale, 0)
24 }
25 QuantMode::Asymmetric => {
26 let (min_val, max_val) =
27 values.iter().fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
28
29 let range = (max_val - min_val).max(1e-8);
30 let qmax = ((1i32 << bits) - 1) as f32;
31 let scale = range / qmax;
32 let zero_point = ((-min_val / scale).round() as i32).clamp(0, qmax as i32);
33 (scale, zero_point)
34 }
35 };
36
37 QuantParams {
38 scales: vec![scale],
39 zero_points: if mode == QuantMode::Asymmetric { vec![zero_point] } else { vec![] },
40 granularity: QuantGranularity::PerTensor,
41 mode,
42 bits,
43 }
44}
45
46pub fn calibrate_per_channel(
54 values: &[f32],
55 num_channels: usize,
56 bits: u8,
57 mode: QuantMode,
58) -> QuantParams {
59 if num_channels == 0 || values.is_empty() {
60 return QuantParams {
61 scales: vec![1.0],
62 zero_points: if mode == QuantMode::Asymmetric { vec![0] } else { vec![] },
63 granularity: QuantGranularity::PerChannel,
64 mode,
65 bits,
66 };
67 }
68
69 let features_per_channel = values.len() / num_channels;
70 let qmax_signed = ((1i32 << (bits - 1)) - 1) as f32;
71 let qmax_unsigned = ((1i32 << bits) - 1) as f32;
72
73 let mut scales = Vec::with_capacity(num_channels);
74 let mut zero_points = Vec::with_capacity(num_channels);
75
76 for ch in 0..num_channels {
77 let start = ch * features_per_channel;
78 let end = start + features_per_channel;
79 let channel_values = &values[start..end];
80
81 match mode {
82 QuantMode::Symmetric => {
83 let max_abs = channel_values
84 .iter()
85 .map(|v| v.abs())
86 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
87 .unwrap_or(1e-8)
88 .max(1e-8);
89
90 scales.push(max_abs / qmax_signed);
91 }
92 QuantMode::Asymmetric => {
93 let (min_val, max_val) = channel_values
94 .iter()
95 .fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
96
97 let range = (max_val - min_val).max(1e-8);
98 let scale = range / qmax_unsigned;
99 let zp = ((-min_val / scale).round() as i32).clamp(0, qmax_unsigned as i32);
100
101 scales.push(scale);
102 zero_points.push(zp);
103 }
104 }
105 }
106
107 QuantParams { scales, zero_points, granularity: QuantGranularity::PerChannel, mode, bits }
108}
109
110pub fn calibrate_per_group(
118 values: &[f32],
119 group_size: usize,
120 bits: u8,
121 mode: QuantMode,
122) -> QuantParams {
123 let group_size = group_size.max(1);
124 let num_groups = values.len().div_ceil(group_size);
125 let qmax_signed = ((1i32 << (bits - 1)) - 1) as f32;
126 let qmax_unsigned = ((1i32 << bits) - 1) as f32;
127
128 let mut scales = Vec::with_capacity(num_groups);
129 let mut zero_points = Vec::with_capacity(num_groups);
130
131 for group_idx in 0..num_groups {
132 let start = group_idx * group_size;
133 let end = (start + group_size).min(values.len());
134 let group_values = &values[start..end];
135
136 match mode {
137 QuantMode::Symmetric => {
138 let max_abs = group_values
139 .iter()
140 .map(|v| v.abs())
141 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
142 .unwrap_or(1e-8)
143 .max(1e-8);
144
145 scales.push(max_abs / qmax_signed);
146 }
147 QuantMode::Asymmetric => {
148 let (min_val, max_val) = group_values
149 .iter()
150 .fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
151
152 let range = (max_val - min_val).max(1e-8);
153 let scale = range / qmax_unsigned;
154 let zp = ((-min_val / scale).round() as i32).clamp(0, qmax_unsigned as i32);
155
156 scales.push(scale);
157 zero_points.push(zp);
158 }
159 }
160 }
161
162 QuantParams {
163 scales,
164 zero_points,
165 granularity: QuantGranularity::PerGroup(group_size),
166 mode,
167 bits,
168 }
169}