entrenar/quant/granularity/
metrics.rs1use super::{
4 calibrate_per_channel, calibrate_per_tensor, dequantize_with_params, quantize_with_params,
5 QuantMode,
6};
7
/// Mean squared error between an original buffer and its dequantized
/// reconstruction.
///
/// # Returns
///
/// The average of the squared element-wise differences. If the two slices
/// differ in length, or if they are empty, `f32::MAX` is returned as a
/// "worst possible error" sentinel instead of panicking.
#[must_use]
pub fn quantization_mse(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.len() != dequantized.len() || original.is_empty() {
        return f32::MAX;
    }

    // Accumulate in f64: an f32 running sum silently drops small squared
    // errors once the total grows large (e.g. 2^24 + 1.0 == 2^24 in f32),
    // which under-reports the error for big tensors or outlier-heavy data.
    let sum_sq: f64 = original
        .iter()
        .zip(dequantized)
        .map(|(&a, &b)| {
            let diff = f64::from(a) - f64::from(b);
            diff * diff
        })
        .sum();

    // The emptiness guard above guarantees a non-zero divisor. Narrowing to
    // f32 is intentional: the public return type is f32.
    (sum_sq / original.len() as f64) as f32
}
18
19pub fn compare_granularities(values: &[f32], num_channels: usize, bits: u8) -> (f32, f32) {
29 let pt_params = calibrate_per_tensor(values, bits, QuantMode::Symmetric);
31 let pt_quantized = quantize_with_params(values, &pt_params);
32 let pt_dequantized = dequantize_with_params(&pt_quantized, &pt_params);
33 let pt_mse = quantization_mse(values, &pt_dequantized);
34
35 let pc_params = calibrate_per_channel(values, num_channels, bits, QuantMode::Symmetric);
37 let pc_quantized = quantize_with_params(values, &pc_params);
38 let pc_dequantized = dequantize_with_params(&pc_quantized, &pc_params);
39 let pc_mse = quantization_mse(values, &pc_dequantized);
40
41 (pt_mse, pc_mse)
42}