entrenar/quant/benchmarks/
runners.rs1use super::super::error_analysis::analyze_error;
6use super::super::granularity::{
7 calibrate_per_channel, calibrate_per_group, calibrate_per_tensor, dequantize_with_params,
8 quantization_mse, quantize_with_params, QuantGranularity, QuantMode,
9};
10use super::generators::{
11 generate_gaussian_weights, generate_multi_channel_weights, generate_uniform_weights,
12 generate_weights_with_outliers,
13};
14use super::types::{BenchmarkSuite, QuantBenchmarkResult};
15
16pub fn run_benchmark(
18 name: &str,
19 values: &[f32],
20 bits: u8,
21 granularity: QuantGranularity,
22 mode: QuantMode,
23) -> QuantBenchmarkResult {
24 let params = match granularity {
25 QuantGranularity::PerTensor => calibrate_per_tensor(values, bits, mode),
26 QuantGranularity::PerChannel => {
27 let num_channels = (values.len() as f32).sqrt() as usize;
29 calibrate_per_channel(values, num_channels.max(1), bits, mode)
30 }
31 QuantGranularity::PerGroup(size) => calibrate_per_group(values, size, bits, mode),
32 };
33
34 let stats = analyze_error(values, ¶ms, 0.1);
35
36 let original_bytes = values.len() * 4; let scale_bytes = params.scales.len() * 4;
39 let zp_bytes = params.zero_points.len() * 4;
40 let data_bytes = if bits == 4 { values.len().div_ceil(2) } else { values.len() };
41 let compressed_bytes = scale_bytes + zp_bytes + data_bytes;
42 let compression_ratio = original_bytes as f32 / compressed_bytes.max(1) as f32;
43
44 QuantBenchmarkResult {
45 name: name.to_string(),
46 num_elements: values.len(),
47 bits,
48 granularity,
49 mode,
50 mse: stats.mse,
51 max_error: stats.max_error,
52 sqnr_db: stats.sqnr_db,
53 compression_ratio,
54 }
55}
56
57pub fn run_full_benchmark_suite(size: usize) -> BenchmarkSuite {
59 let mut suite = BenchmarkSuite::default();
60
61 let gaussian = generate_gaussian_weights(size, 0.0, 1.0, 42);
63
64 for bits in [4u8, 8] {
66 for granularity in [
67 QuantGranularity::PerTensor,
68 QuantGranularity::PerChannel,
69 QuantGranularity::PerGroup(32),
70 ] {
71 let name = format!(
72 "gaussian_{}bit_{:?}",
73 bits,
74 match granularity {
75 QuantGranularity::PerTensor => "tensor",
76 QuantGranularity::PerChannel => "channel",
77 QuantGranularity::PerGroup(_) => "group",
78 }
79 );
80 suite.add(run_benchmark(&name, &gaussian, bits, granularity, QuantMode::Symmetric));
81 }
82 }
83
84 let uniform = generate_uniform_weights(size, -1.0, 1.0, 43);
86 suite.add(run_benchmark(
87 "uniform_8bit_tensor",
88 &uniform,
89 8,
90 QuantGranularity::PerTensor,
91 QuantMode::Symmetric,
92 ));
93
94 let outliers = generate_weights_with_outliers(size, 0.01, 10.0, 44);
96 suite.add(run_benchmark(
97 "outliers_8bit_tensor",
98 &outliers,
99 8,
100 QuantGranularity::PerTensor,
101 QuantMode::Symmetric,
102 ));
103 suite.add(run_benchmark(
104 "outliers_8bit_group32",
105 &outliers,
106 8,
107 QuantGranularity::PerGroup(32),
108 QuantMode::Symmetric,
109 ));
110
111 let multi_ch = generate_multi_channel_weights(16, size / 16, 5.0, 45);
113 suite.add(run_benchmark(
114 "multi_channel_8bit_tensor",
115 &multi_ch,
116 8,
117 QuantGranularity::PerTensor,
118 QuantMode::Symmetric,
119 ));
120 suite.add(run_benchmark(
121 "multi_channel_8bit_channel",
122 &multi_ch,
123 8,
124 QuantGranularity::PerChannel,
125 QuantMode::Symmetric,
126 ));
127
128 suite
129}
130
131pub fn compare_bit_width_degradation(values: &[f32]) -> Vec<(u8, f32, f32)> {
133 let mut results = Vec::new();
134
135 for bits in [4u8, 8] {
136 let params = calibrate_per_tensor(values, bits, QuantMode::Symmetric);
137 let quantized = quantize_with_params(values, ¶ms);
138 let dequantized = dequantize_with_params(&quantized, ¶ms);
139 let mse = quantization_mse(values, &dequantized);
140
141 let compression = if bits == 4 { 8.0 } else { 4.0 }; results.push((bits, mse, compression));
143 }
144
145 results
146}
147
/// Scores how closely `quantized_mse` tracks `original_mse`, as a percentage.
///
/// When `quantized_mse` exceeds `1e-10`, the score is
/// `100 * (1 - |quantized_mse - original_mse| / max(quantized_mse, original_mse))`.
/// Otherwise `quantized_mse` is treated as zero, and the score is `0.0` if `original_mse`
/// exceeds `1e-10` and `100.0` if both are negligible.
pub fn accuracy_retention(original_mse: f32, quantized_mse: f32) -> f32 {
    const NEGLIGIBLE: f32 = 1e-10;

    // Plain `>` comparisons: a NaN input counts as "not significant".
    let quantized_significant = quantized_mse > NEGLIGIBLE;
    let original_significant = original_mse > NEGLIGIBLE;

    match (quantized_significant, original_significant) {
        (true, _) => {
            let relative_gap =
                (quantized_mse - original_mse).abs() / quantized_mse.max(original_mse);
            100.0 * (1.0 - relative_gap)
        }
        (false, true) => 0.0,
        (false, false) => 100.0,
    }
}