entrenar/quant/benchmarks/
types.rs1use super::super::granularity::{QuantGranularity, QuantMode};
6use serde::{Deserialize, Serialize};
7
8#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct QuantBenchmarkResult {
11 pub name: String,
13 pub num_elements: usize,
15 pub bits: u8,
17 pub granularity: QuantGranularity,
19 pub mode: QuantMode,
21 pub mse: f32,
23 pub max_error: f32,
25 pub sqnr_db: f32,
27 pub compression_ratio: f32,
29}
30
31impl QuantBenchmarkResult {
32 pub fn quality_score(&self) -> f32 {
34 if self.compression_ratio > 0.0 {
35 self.sqnr_db / self.compression_ratio.max(1.0)
36 } else {
37 0.0
38 }
39 }
40}
41
42#[derive(Clone, Debug, Serialize, Deserialize, Default)]
44pub struct BenchmarkSuite {
45 pub results: Vec<QuantBenchmarkResult>,
46}
47
48impl BenchmarkSuite {
49 pub fn add(&mut self, result: QuantBenchmarkResult) {
51 self.results.push(result);
52 }
53
54 pub fn best_by_sqnr(&self) -> Option<&QuantBenchmarkResult> {
56 self.results.iter().max_by(|a, b| a.sqnr_db.total_cmp(&b.sqnr_db))
57 }
58
59 pub fn best_by_mse(&self) -> Option<&QuantBenchmarkResult> {
61 self.results.iter().min_by(|a, b| a.mse.total_cmp(&b.mse))
62 }
63
64 pub fn sorted_by_quality(&self) -> Vec<&QuantBenchmarkResult> {
66 let mut sorted: Vec<_> = self.results.iter().collect();
67 sorted.sort_by(|a, b| b.quality_score().total_cmp(&a.quality_score()));
68 sorted
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an 8-bit, per-tensor, symmetric result over 1000 elements with
    /// the given metrics; `max_error` is set to twice the MSE.
    fn make_result(
        name: &str,
        sqnr_db: f32,
        mse: f32,
        compression_ratio: f32,
    ) -> QuantBenchmarkResult {
        QuantBenchmarkResult {
            name: name.to_owned(),
            num_elements: 1000,
            bits: 8,
            granularity: QuantGranularity::PerTensor,
            mode: QuantMode::Symmetric,
            mse,
            max_error: 2.0 * mse,
            sqnr_db,
            compression_ratio,
        }
    }

    /// Builds a suite from `(name, sqnr_db, mse, compression_ratio)` tuples,
    /// adding them in slice order.
    fn suite_of(entries: &[(&str, f32, f32, f32)]) -> BenchmarkSuite {
        let mut suite = BenchmarkSuite::default();
        for &(name, sqnr_db, mse, ratio) in entries {
            suite.add(make_result(name, sqnr_db, mse, ratio));
        }
        suite
    }

    #[test]
    fn test_quality_score_normal() {
        let score = make_result("test", 40.0, 0.01, 4.0).quality_score();
        assert!((score - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_quality_score_zero_compression() {
        let score = make_result("test", 40.0, 0.01, 0.0).quality_score();
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_quality_score_low_compression() {
        // A ratio below 1.0 is clamped to 1.0, so the score equals the SQNR.
        let score = make_result("test", 40.0, 0.01, 0.5).quality_score();
        assert!((score - 40.0).abs() < 1e-6);
    }

    #[test]
    fn test_benchmark_suite_default() {
        assert!(BenchmarkSuite::default().results.is_empty());
    }

    #[test]
    fn test_benchmark_suite_add() {
        let mut suite = BenchmarkSuite::default();
        let fixtures = [("test1", 40.0, 0.01, 4.0), ("test2", 50.0, 0.005, 4.0)];
        for (index, &(name, sqnr_db, mse, ratio)) in fixtures.iter().enumerate() {
            suite.add(make_result(name, sqnr_db, mse, ratio));
            assert_eq!(suite.results.len(), index + 1);
        }
    }

    #[test]
    fn test_best_by_sqnr() {
        let suite = suite_of(&[
            ("low", 30.0, 0.02, 4.0),
            ("high", 50.0, 0.01, 4.0),
            ("mid", 40.0, 0.015, 4.0),
        ]);
        let best = suite.best_by_sqnr().expect("operation should succeed");
        assert_eq!(best.name, "high");
        assert!((best.sqnr_db - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_best_by_sqnr_empty() {
        assert!(BenchmarkSuite::default().best_by_sqnr().is_none());
    }

    #[test]
    fn test_best_by_mse() {
        let suite = suite_of(&[
            ("high_error", 30.0, 0.02, 4.0),
            ("low_error", 50.0, 0.005, 4.0),
            ("mid_error", 40.0, 0.01, 4.0),
        ]);
        let best = suite.best_by_mse().expect("operation should succeed");
        assert_eq!(best.name, "low_error");
        assert!((best.mse - 0.005).abs() < 1e-6);
    }

    #[test]
    fn test_best_by_mse_empty() {
        assert!(BenchmarkSuite::default().best_by_mse().is_none());
    }

    #[test]
    fn test_sorted_by_quality() {
        let suite = suite_of(&[
            ("low_quality", 20.0, 0.02, 4.0),
            ("high_quality", 60.0, 0.01, 4.0),
            ("mid_quality", 40.0, 0.015, 4.0),
        ]);
        let names: Vec<&str> = suite
            .sorted_by_quality()
            .into_iter()
            .map(|r| r.name.as_str())
            .collect();
        assert_eq!(names, ["high_quality", "mid_quality", "low_quality"]);
    }

    #[test]
    fn test_sorted_by_quality_empty() {
        assert!(BenchmarkSuite::default().sorted_by_quality().is_empty());
    }

    #[test]
    fn test_quant_benchmark_result_serde() {
        let original = make_result("test", 40.0, 0.01, 4.0);
        let json = serde_json::to_string(&original).expect("JSON serialization should succeed");
        let restored: QuantBenchmarkResult =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(original.name, restored.name);
        assert!((original.sqnr_db - restored.sqnr_db).abs() < 1e-6);
    }

    #[test]
    fn test_benchmark_suite_serde() {
        let suite = suite_of(&[("test1", 40.0, 0.01, 4.0), ("test2", 50.0, 0.005, 4.0)]);
        let json = serde_json::to_string(&suite).expect("JSON serialization should succeed");
        let restored: BenchmarkSuite =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(suite.results.len(), restored.results.len());
    }
}