Skip to main content

entrenar/quant/benchmarks/
types.rs

1//! Benchmark result types
2//!
3//! Data structures for quantization benchmark results.
4
5use super::super::granularity::{QuantGranularity, QuantMode};
6use serde::{Deserialize, Serialize};
7
8/// Benchmark results for quantization accuracy
9#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct QuantBenchmarkResult {
11    /// Benchmark name
12    pub name: String,
13    /// Number of elements tested
14    pub num_elements: usize,
15    /// Bits used for quantization
16    pub bits: u8,
17    /// Granularity used
18    pub granularity: QuantGranularity,
19    /// Mode used (symmetric/asymmetric)
20    pub mode: QuantMode,
21    /// MSE error
22    pub mse: f32,
23    /// Max error
24    pub max_error: f32,
25    /// SQNR in dB
26    pub sqnr_db: f32,
27    /// Compression ratio
28    pub compression_ratio: f32,
29}
30
31impl QuantBenchmarkResult {
32    /// Quality score (higher is better): SQNR / compression overhead
33    pub fn quality_score(&self) -> f32 {
34        if self.compression_ratio > 0.0 {
35            self.sqnr_db / self.compression_ratio.max(1.0)
36        } else {
37            0.0
38        }
39    }
40}
41
42/// Suite of benchmark results
43#[derive(Clone, Debug, Serialize, Deserialize, Default)]
44pub struct BenchmarkSuite {
45    pub results: Vec<QuantBenchmarkResult>,
46}
47
48impl BenchmarkSuite {
49    /// Add a benchmark result
50    pub fn add(&mut self, result: QuantBenchmarkResult) {
51        self.results.push(result);
52    }
53
54    /// Get best result by SQNR
55    pub fn best_by_sqnr(&self) -> Option<&QuantBenchmarkResult> {
56        self.results.iter().max_by(|a, b| a.sqnr_db.total_cmp(&b.sqnr_db))
57    }
58
59    /// Get best result by MSE (lowest)
60    pub fn best_by_mse(&self) -> Option<&QuantBenchmarkResult> {
61        self.results.iter().min_by(|a, b| a.mse.total_cmp(&b.mse))
62    }
63
64    /// Get results sorted by quality score
65    pub fn sorted_by_quality(&self) -> Vec<&QuantBenchmarkResult> {
66        let mut sorted: Vec<_> = self.results.iter().collect();
67        sorted.sort_by(|a, b| b.quality_score().total_cmp(&a.quality_score()));
68        sorted
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const TOLERANCE: f32 = 1e-6;

    /// Builds a result with fixed element count, bit width, granularity and
    /// mode; `max_error` is derived as twice the MSE.
    fn make_result(name: &str, sqnr: f32, mse: f32, compression: f32) -> QuantBenchmarkResult {
        QuantBenchmarkResult {
            name: name.to_string(),
            num_elements: 1000,
            bits: 8,
            granularity: QuantGranularity::PerTensor,
            mode: QuantMode::Symmetric,
            mse,
            max_error: mse * 2.0,
            sqnr_db: sqnr,
            compression_ratio: compression,
        }
    }

    /// Builds a suite from `(name, sqnr, mse, compression)` tuples, added in order.
    fn suite_of(entries: &[(&str, f32, f32, f32)]) -> BenchmarkSuite {
        let mut suite = BenchmarkSuite::default();
        for &(name, sqnr, mse, compression) in entries {
            suite.add(make_result(name, sqnr, mse, compression));
        }
        suite
    }

    #[test]
    fn test_quality_score_normal() {
        // 40 / 4 = 10
        let score = make_result("test", 40.0, 0.01, 4.0).quality_score();
        assert!((score - 10.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_quality_score_zero_compression() {
        let score = make_result("test", 40.0, 0.01, 0.0).quality_score();
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_quality_score_low_compression() {
        // The divisor is clamped to 1.0, so 40 / 1 = 40
        let score = make_result("test", 40.0, 0.01, 0.5).quality_score();
        assert!((score - 40.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_benchmark_suite_default() {
        assert!(BenchmarkSuite::default().results.is_empty());
    }

    #[test]
    fn test_benchmark_suite_add() {
        let mut suite = BenchmarkSuite::default();

        suite.add(make_result("test1", 40.0, 0.01, 4.0));
        assert_eq!(suite.results.len(), 1);

        suite.add(make_result("test2", 50.0, 0.005, 4.0));
        assert_eq!(suite.results.len(), 2);
    }

    #[test]
    fn test_best_by_sqnr() {
        let suite = suite_of(&[
            ("low", 30.0, 0.02, 4.0),
            ("high", 50.0, 0.01, 4.0),
            ("mid", 40.0, 0.015, 4.0),
        ]);

        let winner = suite.best_by_sqnr().expect("operation should succeed");
        assert_eq!(winner.name, "high");
        assert!((winner.sqnr_db - 50.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_best_by_sqnr_empty() {
        assert!(BenchmarkSuite::default().best_by_sqnr().is_none());
    }

    #[test]
    fn test_best_by_mse() {
        let suite = suite_of(&[
            ("high_error", 30.0, 0.02, 4.0),
            ("low_error", 50.0, 0.005, 4.0),
            ("mid_error", 40.0, 0.01, 4.0),
        ]);

        let winner = suite.best_by_mse().expect("operation should succeed");
        assert_eq!(winner.name, "low_error");
        assert!((winner.mse - 0.005).abs() < TOLERANCE);
    }

    #[test]
    fn test_best_by_mse_empty() {
        assert!(BenchmarkSuite::default().best_by_mse().is_none());
    }

    #[test]
    fn test_sorted_by_quality() {
        let suite = suite_of(&[
            ("low_quality", 20.0, 0.02, 4.0),   // 20/4 = 5
            ("high_quality", 60.0, 0.01, 4.0),  // 60/4 = 15
            ("mid_quality", 40.0, 0.015, 4.0),  // 40/4 = 10
        ]);

        let ranked = suite.sorted_by_quality();
        assert_eq!(ranked.len(), 3);

        let names: Vec<&str> = ranked.iter().map(|r| r.name.as_str()).collect();
        assert_eq!(names, ["high_quality", "mid_quality", "low_quality"]);
    }

    #[test]
    fn test_sorted_by_quality_empty() {
        assert!(BenchmarkSuite::default().sorted_by_quality().is_empty());
    }

    #[test]
    fn test_quant_benchmark_result_serde() {
        let original = make_result("test", 40.0, 0.01, 4.0);

        let encoded =
            serde_json::to_string(&original).expect("JSON serialization should succeed");
        let decoded: QuantBenchmarkResult =
            serde_json::from_str(&encoded).expect("JSON deserialization should succeed");

        assert_eq!(original.name, decoded.name);
        assert!((original.sqnr_db - decoded.sqnr_db).abs() < TOLERANCE);
    }

    #[test]
    fn test_benchmark_suite_serde() {
        let original = suite_of(&[("test1", 40.0, 0.01, 4.0), ("test2", 50.0, 0.005, 4.0)]);

        let encoded =
            serde_json::to_string(&original).expect("JSON serialization should succeed");
        let decoded: BenchmarkSuite =
            serde_json::from_str(&encoded).expect("JSON deserialization should succeed");

        assert_eq!(original.results.len(), decoded.results.len());
    }
}