Skip to main content

entrenar/quant/benchmarks/
runners.rs

1//! Benchmark runner functions
2//!
3//! Functions for running quantization benchmarks.
4
5use super::super::error_analysis::analyze_error;
6use super::super::granularity::{
7    calibrate_per_channel, calibrate_per_group, calibrate_per_tensor, dequantize_with_params,
8    quantization_mse, quantize_with_params, QuantGranularity, QuantMode,
9};
10use super::generators::{
11    generate_gaussian_weights, generate_multi_channel_weights, generate_uniform_weights,
12    generate_weights_with_outliers,
13};
14use super::types::{BenchmarkSuite, QuantBenchmarkResult};
15
16/// Run benchmark on given values with specified configuration
17pub fn run_benchmark(
18    name: &str,
19    values: &[f32],
20    bits: u8,
21    granularity: QuantGranularity,
22    mode: QuantMode,
23) -> QuantBenchmarkResult {
24    let params = match granularity {
25        QuantGranularity::PerTensor => calibrate_per_tensor(values, bits, mode),
26        QuantGranularity::PerChannel => {
27            // Assume square-ish shape for simplicity
28            let num_channels = (values.len() as f32).sqrt() as usize;
29            calibrate_per_channel(values, num_channels.max(1), bits, mode)
30        }
31        QuantGranularity::PerGroup(size) => calibrate_per_group(values, size, bits, mode),
32    };
33
34    let stats = analyze_error(values, &params, 0.1);
35
36    // Calculate compression ratio
37    let original_bytes = values.len() * 4; // f32 = 4 bytes
38    let scale_bytes = params.scales.len() * 4;
39    let zp_bytes = params.zero_points.len() * 4;
40    let data_bytes = if bits == 4 { values.len().div_ceil(2) } else { values.len() };
41    let compressed_bytes = scale_bytes + zp_bytes + data_bytes;
42    let compression_ratio = original_bytes as f32 / compressed_bytes.max(1) as f32;
43
44    QuantBenchmarkResult {
45        name: name.to_string(),
46        num_elements: values.len(),
47        bits,
48        granularity,
49        mode,
50        mse: stats.mse,
51        max_error: stats.max_error,
52        sqnr_db: stats.sqnr_db,
53        compression_ratio,
54    }
55}
56
57/// Run full benchmark suite on various weight patterns
58pub fn run_full_benchmark_suite(size: usize) -> BenchmarkSuite {
59    let mut suite = BenchmarkSuite::default();
60
61    // Gaussian weights
62    let gaussian = generate_gaussian_weights(size, 0.0, 1.0, 42);
63
64    // Test different configurations
65    for bits in [4u8, 8] {
66        for granularity in [
67            QuantGranularity::PerTensor,
68            QuantGranularity::PerChannel,
69            QuantGranularity::PerGroup(32),
70        ] {
71            let name = format!(
72                "gaussian_{}bit_{:?}",
73                bits,
74                match granularity {
75                    QuantGranularity::PerTensor => "tensor",
76                    QuantGranularity::PerChannel => "channel",
77                    QuantGranularity::PerGroup(_) => "group",
78                }
79            );
80            suite.add(run_benchmark(&name, &gaussian, bits, granularity, QuantMode::Symmetric));
81        }
82    }
83
84    // Uniform weights
85    let uniform = generate_uniform_weights(size, -1.0, 1.0, 43);
86    suite.add(run_benchmark(
87        "uniform_8bit_tensor",
88        &uniform,
89        8,
90        QuantGranularity::PerTensor,
91        QuantMode::Symmetric,
92    ));
93
94    // Weights with outliers
95    let outliers = generate_weights_with_outliers(size, 0.01, 10.0, 44);
96    suite.add(run_benchmark(
97        "outliers_8bit_tensor",
98        &outliers,
99        8,
100        QuantGranularity::PerTensor,
101        QuantMode::Symmetric,
102    ));
103    suite.add(run_benchmark(
104        "outliers_8bit_group32",
105        &outliers,
106        8,
107        QuantGranularity::PerGroup(32),
108        QuantMode::Symmetric,
109    ));
110
111    // Multi-channel weights
112    let multi_ch = generate_multi_channel_weights(16, size / 16, 5.0, 45);
113    suite.add(run_benchmark(
114        "multi_channel_8bit_tensor",
115        &multi_ch,
116        8,
117        QuantGranularity::PerTensor,
118        QuantMode::Symmetric,
119    ));
120    suite.add(run_benchmark(
121        "multi_channel_8bit_channel",
122        &multi_ch,
123        8,
124        QuantGranularity::PerChannel,
125        QuantMode::Symmetric,
126    ));
127
128    suite
129}
130
131/// Compare accuracy degradation across bit widths
132pub fn compare_bit_width_degradation(values: &[f32]) -> Vec<(u8, f32, f32)> {
133    let mut results = Vec::new();
134
135    for bits in [4u8, 8] {
136        let params = calibrate_per_tensor(values, bits, QuantMode::Symmetric);
137        let quantized = quantize_with_params(values, &params);
138        let dequantized = dequantize_with_params(&quantized, &params);
139        let mse = quantization_mse(values, &dequantized);
140
141        let compression = if bits == 4 { 8.0 } else { 4.0 }; // vs f32
142        results.push((bits, mse, compression));
143    }
144
145    results
146}
147
/// Express how closely `quantized_mse` tracks `original_mse` as a percentage.
///
/// The score is `(1 - |quantized - original| / max(quantized, original)) * 100`,
/// so 100 means the two errors are equal. For non-negative inputs the result
/// lies in `[0, 100]`. Values at or below `1e-10` are treated as negligible:
/// if both are negligible the result is 100. If only `quantized_mse` is
/// negligible the result is 0.
pub fn accuracy_retention(original_mse: f32, quantized_mse: f32) -> f32 {
    const NEGLIGIBLE: f32 = 1e-10;

    let quantized_significant = quantized_mse > NEGLIGIBLE;
    let original_significant = original_mse > NEGLIGIBLE;

    match (quantized_significant, original_significant) {
        (true, _) => {
            let gap = (quantized_mse - original_mse).abs();
            let scale = quantized_mse.max(original_mse);
            (1.0 - gap / scale) * 100.0
        }
        (false, true) => 0.0,
        (false, false) => 100.0,
    }
}