Skip to main content

entrenar/quant/granularity/
metrics.rs

1//! Quantization error metrics and comparison functions
2
3use super::{
4    calibrate_per_channel, calibrate_per_tensor, dequantize_with_params, quantize_with_params,
5    QuantMode,
6};
7
/// Compute the quantization error as the mean squared error (MSE) between the
/// original values and their dequantized reconstruction.
///
/// Squared differences are accumulated in `f64`, and only the final mean is
/// narrowed to `f32`. Summing in `f32` silently drops small error terms once
/// the running total grows large: for example, `2^24 + 1.0` rounds back to
/// `2^24`. Over large tensors that would bias the reported error low.
///
/// # Arguments
/// * `original` - Reference values before quantization
/// * `dequantized` - Values after a quantize → dequantize round trip
///
/// # Returns
/// The MSE, or the sentinel `f32::MAX` when the slices differ in length or are
/// empty (no meaningful comparison is possible in either case).
pub fn quantization_mse(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.len() != dequantized.len() || original.is_empty() {
        return f32::MAX;
    }

    let sum_sq: f64 = original
        .iter()
        .zip(dequantized)
        .map(|(&a, &b)| {
            let diff = f64::from(a) - f64::from(b);
            diff * diff
        })
        .sum();

    // `original` is non-empty here, so the division is well-defined.
    (sum_sq / original.len() as f64) as f32
}
18
19/// Compare per-channel vs per-tensor quantization error
20///
21/// # Arguments
22/// * `values` - Input tensor values (row-major)
23/// * `num_channels` - Number of channels
24/// * `bits` - Bit width
25///
26/// # Returns
27/// (per_tensor_mse, per_channel_mse)
28pub fn compare_granularities(values: &[f32], num_channels: usize, bits: u8) -> (f32, f32) {
29    // Per-tensor
30    let pt_params = calibrate_per_tensor(values, bits, QuantMode::Symmetric);
31    let pt_quantized = quantize_with_params(values, &pt_params);
32    let pt_dequantized = dequantize_with_params(&pt_quantized, &pt_params);
33    let pt_mse = quantization_mse(values, &pt_dequantized);
34
35    // Per-channel
36    let pc_params = calibrate_per_channel(values, num_channels, bits, QuantMode::Symmetric);
37    let pc_quantized = quantize_with_params(values, &pc_params);
38    let pc_dequantized = dequantize_with_params(&pc_quantized, &pc_params);
39    let pc_mse = quantization_mse(values, &pc_dequantized);
40
41    (pt_mse, pc_mse)
42}