use super::super::error_analysis::analyze_error;
use super::super::granularity::{
calibrate_per_channel, calibrate_per_group, calibrate_per_tensor, dequantize_with_params,
quantization_mse, quantize_with_params, QuantGranularity, QuantMode,
};
use super::generators::{
generate_gaussian_weights, generate_multi_channel_weights, generate_uniform_weights,
generate_weights_with_outliers,
};
use super::types::{BenchmarkSuite, QuantBenchmarkResult};
/// Runs a single quantization benchmark over `values` and summarises the result.
///
/// Quantization parameters are calibrated for the requested `granularity`, the
/// error statistics are obtained from `analyze_error`, and a compression ratio
/// is estimated from the sizes of the scales, zero points and packed data.
///
/// # Parameters
/// - `name`: label copied into the result.
/// - `values`: the `f32` weights to benchmark.
/// - `bits`: quantized bit width, used for calibration and for the size estimate.
/// - `granularity`: per-tensor, per-channel or per-group calibration.
/// - `mode`: quantization mode forwarded to the calibration routine.
///
/// # Returns
/// A [`QuantBenchmarkResult`] carrying the MSE, max error, SQNR (dB) and the
/// estimated compression ratio.
pub fn run_benchmark(
    name: &str,
    values: &[f32],
    bits: u8,
    granularity: QuantGranularity,
    mode: QuantMode,
) -> QuantBenchmarkResult {
    let params = match granularity {
        QuantGranularity::PerTensor => calibrate_per_tensor(values, bits, mode),
        QuantGranularity::PerChannel => {
            // No channel count is supplied by the caller, so one is derived as
            // floor(sqrt(len)), clamped to at least 1.
            let num_channels = (values.len() as f32).sqrt() as usize;
            calibrate_per_channel(values, num_channels.max(1), bits, mode)
        }
        QuantGranularity::PerGroup(size) => calibrate_per_group(values, size, bits, mode),
    };
    // NOTE(review): the meaning of the 0.1 argument is defined by `analyze_error`.
    let stats = analyze_error(values, &params, 0.1);

    // Size estimate: the source is 4 bytes per f32; each scale and each zero
    // point is counted as 4 bytes.
    let original_bytes = values.len() * 4;
    let scale_bytes = params.scales.len() * 4;
    let zp_bytes = params.zero_points.len() * 4;
    // Bit-packed payload size. This matches the previous figures for 4-bit
    // (two values per byte) and 8-bit (one value per byte) and stays correct
    // for other widths instead of assuming one byte per value.
    let data_bytes = (values.len() * usize::from(bits)).div_ceil(8);
    let compressed_bytes = scale_bytes + zp_bytes + data_bytes;
    // `max(1)` avoids a division by zero when nothing is stored.
    let compression_ratio = original_bytes as f32 / compressed_bytes.max(1) as f32;

    QuantBenchmarkResult {
        name: name.to_string(),
        num_elements: values.len(),
        bits,
        granularity,
        mode,
        mse: stats.mse,
        max_error: stats.max_error,
        sqnr_db: stats.sqnr_db,
        compression_ratio,
    }
}
/// Builds the standard benchmark suite over synthetic weight distributions.
///
/// The suite contains:
/// - Gaussian weights at 4 and 8 bits for per-tensor, per-channel and
///   per-group(32) granularity, named `gaussian_{bits}bit_{tensor|channel|group}`;
/// - uniform weights at 8 bits, per-tensor;
/// - weights with outliers at 8 bits, per-tensor and per-group(32);
/// - multi-channel weights at 8 bits, per-tensor and per-channel.
///
/// Every case uses `QuantMode::Symmetric` and a fixed generator seed (42 to 45).
///
/// # Parameters
/// - `size`: element count passed to each generator. The multi-channel case
///   requests 16 channels of `size / 16` elements each.
///
/// # Returns
/// A [`BenchmarkSuite`] with one result added per case.
pub fn run_full_benchmark_suite(size: usize) -> BenchmarkSuite {
    let mut suite = BenchmarkSuite::default();

    let gaussian = generate_gaussian_weights(size, 0.0, 1.0, 42);
    for bits in [4u8, 8] {
        for granularity in [
            QuantGranularity::PerTensor,
            QuantGranularity::PerChannel,
            QuantGranularity::PerGroup(32),
        ] {
            let label = match granularity {
                QuantGranularity::PerTensor => "tensor",
                QuantGranularity::PerChannel => "channel",
                QuantGranularity::PerGroup(_) => "group",
            };
            // Display formatting: `{:?}` on a `&str` would wrap the label in
            // quotes, e.g. `gaussian_4bit_"tensor"`.
            let name = format!("gaussian_{}bit_{}", bits, label);
            suite.add(run_benchmark(&name, &gaussian, bits, granularity, QuantMode::Symmetric));
        }
    }

    let uniform = generate_uniform_weights(size, -1.0, 1.0, 43);
    suite.add(run_benchmark(
        "uniform_8bit_tensor",
        &uniform,
        8,
        QuantGranularity::PerTensor,
        QuantMode::Symmetric,
    ));

    let outliers = generate_weights_with_outliers(size, 0.01, 10.0, 44);
    suite.add(run_benchmark(
        "outliers_8bit_tensor",
        &outliers,
        8,
        QuantGranularity::PerTensor,
        QuantMode::Symmetric,
    ));
    suite.add(run_benchmark(
        "outliers_8bit_group32",
        &outliers,
        8,
        QuantGranularity::PerGroup(32),
        QuantMode::Symmetric,
    ));

    // NOTE(review): `run_benchmark` derives its own channel count for
    // `PerChannel` from the data length, so it may not match the 16 channels
    // requested here. Confirm this is intended.
    let multi_ch = generate_multi_channel_weights(16, size / 16, 5.0, 45);
    suite.add(run_benchmark(
        "multi_channel_8bit_tensor",
        &multi_ch,
        8,
        QuantGranularity::PerTensor,
        QuantMode::Symmetric,
    ));
    suite.add(run_benchmark(
        "multi_channel_8bit_channel",
        &multi_ch,
        8,
        QuantGranularity::PerChannel,
        QuantMode::Symmetric,
    ));

    suite
}
/// Compares per-tensor symmetric quantization error at 4 and 8 bits.
///
/// For each bit width, `values` are calibrated, quantized and dequantized, and
/// the MSE between the originals and the round-tripped values is recorded.
///
/// # Returns
/// One `(bits, mse, compression)` tuple per bit width, in the order 4 then 8.
/// `compression` is the nominal ratio `32 / bits` (8.0 for 4-bit, 4.0 for
/// 8-bit); it depends only on the bit width.
pub fn compare_bit_width_degradation(values: &[f32]) -> Vec<(u8, f32, f32)> {
    [4u8, 8]
        .iter()
        .map(|&bits| {
            let params = calibrate_per_tensor(values, bits, QuantMode::Symmetric);
            let quantized = quantize_with_params(values, &params);
            let dequantized = dequantize_with_params(&quantized, &params);
            let mse = quantization_mse(values, &dequantized);
            // The source values are 32-bit floats.
            let compression: f32 = 32.0 / f32::from(bits);
            (bits, mse, compression)
        })
        .collect()
}
/// Expresses how close `quantized_mse` is to `original_mse` as a percentage.
///
/// Values at or below `1e-10` are treated as zero:
/// - when `quantized_mse` is above the threshold, the result is
///   `(1 - |quantized - original| / max(quantized, original)) * 100`;
/// - when only `original_mse` is above the threshold, the result is `0.0`;
/// - when neither is above the threshold, the result is `100.0`.
///
/// A NaN input fails its threshold comparison and is therefore handled like a
/// value at or below the threshold.
pub fn accuracy_retention(original_mse: f32, quantized_mse: f32) -> f32 {
    const THRESHOLD: f32 = 1e-10;

    let quantized_significant = quantized_mse > THRESHOLD;
    let original_significant = original_mse > THRESHOLD;

    match (quantized_significant, original_significant) {
        (true, _) => {
            let gap = (quantized_mse - original_mse).abs();
            let larger = quantized_mse.max(original_mse);
            (1.0 - gap / larger) * 100.0
        }
        (false, true) => 0.0,
        (false, false) => 100.0,
    }
}