use super::super::granularity::{QuantGranularity, QuantMode};
use serde::{Deserialize, Serialize};
/// Measurements collected from a single quantization benchmark run.
///
/// Serde's derive serializes fields in declaration order, so reordering the
/// fields below changes the serialized layout.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QuantBenchmarkResult {
    /// Label identifying this run / configuration.
    pub name: String,
    /// Number of elements quantized in this run.
    /// NOTE(review): presumably the input tensor length — confirm with the producer.
    pub num_elements: usize,
    /// Quantization bit width (inferred from the field name — confirm).
    pub bits: u8,
    /// Quantization granularity used for this run (e.g. `QuantGranularity::PerTensor`).
    pub granularity: QuantGranularity,
    /// Quantization mode used for this run (e.g. `QuantMode::Symmetric`).
    pub mode: QuantMode,
    /// Mean squared quantization error (per the field name); lower is better —
    /// `BenchmarkSuite::best_by_mse` selects the minimum.
    pub mse: f32,
    /// Largest observed error.
    /// NOTE(review): presumably the max absolute error — confirm with the producer.
    pub max_error: f32,
    /// Signal-to-quantization-noise ratio, in decibels per the field name; higher is
    /// better — `BenchmarkSuite::best_by_sqnr` selects the maximum.
    pub sqnr_db: f32,
    /// Compression achieved by this configuration. `quality_score` returns 0.0 when
    /// this is not positive and treats values below 1.0 as 1.0.
    pub compression_ratio: f32,
}
impl QuantBenchmarkResult {
    /// Returns a scalar quality score: `sqnr_db` divided by `compression_ratio`.
    ///
    /// The divisor is clamped to at least `1.0`, so ratios in `(0, 1)` leave the
    /// SQNR unscaled. A ratio that is zero, negative, or NaN yields `0.0`. A NaN
    /// `sqnr_db` with a positive ratio propagates as NaN.
    ///
    /// NOTE(review): a larger compression ratio *lowers* the score; confirm this
    /// trade-off is the intended ranking.
    pub fn quality_score(&self) -> f32 {
        match self.compression_ratio {
            // Only strictly positive ratios score; NaN fails this guard as well.
            ratio if ratio > 0.0 => self.sqnr_db / ratio.max(1.0),
            _ => 0.0,
        }
    }
}
/// An ordered collection of [`QuantBenchmarkResult`]s with helpers for picking
/// and ranking the best runs.
///
/// `Default` yields an empty suite.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct BenchmarkSuite {
    /// Results in insertion order (`add` appends to the end).
    pub results: Vec<QuantBenchmarkResult>,
}
impl BenchmarkSuite {
    /// Appends `result` to the end of the suite.
    pub fn add(&mut self, result: QuantBenchmarkResult) {
        self.results.push(result);
    }

    /// Returns the result with the highest `sqnr_db`, or `None` if there is none.
    ///
    /// Results whose `sqnr_db` is NaN are skipped. Under `f32::total_cmp` a
    /// positive NaN orders above `+inf`, so it would otherwise be reported as the
    /// best run. If every result is NaN, `None` is returned. On ties the last
    /// maximal result wins (`Iterator::max_by` semantics).
    pub fn best_by_sqnr(&self) -> Option<&QuantBenchmarkResult> {
        self.results
            .iter()
            .filter(|r| !r.sqnr_db.is_nan())
            .max_by(|a, b| a.sqnr_db.total_cmp(&b.sqnr_db))
    }

    /// Returns the result with the lowest `mse`, or `None` if there is none.
    ///
    /// Results whose `mse` is NaN are skipped. Under `f32::total_cmp` a negative
    /// NaN orders below `-inf`, so it would otherwise be reported as the best run.
    /// If every result is NaN, `None` is returned. On ties the first minimal result
    /// wins (`Iterator::min_by` semantics).
    pub fn best_by_mse(&self) -> Option<&QuantBenchmarkResult> {
        self.results
            .iter()
            .filter(|r| !r.mse.is_nan())
            .min_by(|a, b| a.mse.total_cmp(&b.mse))
    }

    /// Returns references to every result, ordered by descending `quality_score`.
    ///
    /// NaN scores always go last, regardless of their sign bit. The sort is
    /// stable, so equal scores keep their insertion order.
    pub fn sorted_by_quality(&self) -> Vec<&QuantBenchmarkResult> {
        let mut sorted: Vec<_> = self.results.iter().collect();
        sorted.sort_by(|a, b| Self::cmp_desc_nan_last(a.quality_score(), b.quality_score()));
        sorted
    }

    /// Orders larger values first; NaN compares after every number and equal to NaN.
    fn cmp_desc_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            // Reversed operands give descending order.
            (false, false) => b.total_cmp(&a),
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (true, true) => Ordering::Equal,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f32` values in these tests.
    const EPS: f32 = 1e-6;

    /// Builds an 8-bit, per-tensor, symmetric result over 1000 elements;
    /// `max_error` is derived as `2 * mse`.
    fn make_result(name: &str, sqnr: f32, mse: f32, compression: f32) -> QuantBenchmarkResult {
        QuantBenchmarkResult {
            name: name.to_string(),
            num_elements: 1000,
            bits: 8,
            granularity: QuantGranularity::PerTensor,
            mode: QuantMode::Symmetric,
            mse,
            max_error: mse * 2.0,
            sqnr_db: sqnr,
            compression_ratio: compression,
        }
    }

    /// Asserts that `actual` lies strictly within `EPS` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_quality_score_normal() {
        let result = make_result("test", 40.0, 0.01, 4.0);
        assert_close(result.quality_score(), 10.0);
    }

    #[test]
    fn test_quality_score_zero_compression() {
        let result = make_result("test", 40.0, 0.01, 0.0);
        assert_eq!(result.quality_score(), 0.0);
    }

    #[test]
    fn test_quality_score_low_compression() {
        // Ratios below 1.0 are clamped to 1.0, so the SQNR passes through unscaled.
        let result = make_result("test", 40.0, 0.01, 0.5);
        assert_close(result.quality_score(), 40.0);
    }

    #[test]
    fn test_benchmark_suite_default() {
        let suite = BenchmarkSuite::default();
        assert!(suite.results.is_empty());
    }

    #[test]
    fn test_benchmark_suite_add() {
        let mut suite = BenchmarkSuite::default();
        suite.add(make_result("test1", 40.0, 0.01, 4.0));
        assert_eq!(suite.results.len(), 1);
        suite.add(make_result("test2", 50.0, 0.005, 4.0));
        assert_eq!(suite.results.len(), 2);
    }

    #[test]
    fn test_best_by_sqnr() {
        let mut suite = BenchmarkSuite::default();
        suite.add(make_result("low", 30.0, 0.02, 4.0));
        suite.add(make_result("high", 50.0, 0.01, 4.0));
        suite.add(make_result("mid", 40.0, 0.015, 4.0));
        let best = suite.best_by_sqnr().expect("operation should succeed");
        assert_eq!(best.name, "high");
        assert_close(best.sqnr_db, 50.0);
    }

    #[test]
    fn test_best_by_sqnr_empty() {
        let suite = BenchmarkSuite::default();
        assert!(suite.best_by_sqnr().is_none());
    }

    #[test]
    fn test_best_by_mse() {
        let mut suite = BenchmarkSuite::default();
        suite.add(make_result("high_error", 30.0, 0.02, 4.0));
        suite.add(make_result("low_error", 50.0, 0.005, 4.0));
        suite.add(make_result("mid_error", 40.0, 0.01, 4.0));
        let best = suite.best_by_mse().expect("operation should succeed");
        assert_eq!(best.name, "low_error");
        assert_close(best.mse, 0.005);
    }

    #[test]
    fn test_best_by_mse_empty() {
        let suite = BenchmarkSuite::default();
        assert!(suite.best_by_mse().is_none());
    }

    #[test]
    fn test_sorted_by_quality() {
        let mut suite = BenchmarkSuite::default();
        suite.add(make_result("low_quality", 20.0, 0.02, 4.0));
        suite.add(make_result("high_quality", 60.0, 0.01, 4.0));
        suite.add(make_result("mid_quality", 40.0, 0.015, 4.0));
        let sorted = suite.sorted_by_quality();
        assert_eq!(sorted.len(), 3);
        assert_eq!(sorted[0].name, "high_quality");
        assert_eq!(sorted[1].name, "mid_quality");
        assert_eq!(sorted[2].name, "low_quality");
    }

    #[test]
    fn test_sorted_by_quality_empty() {
        let suite = BenchmarkSuite::default();
        let sorted = suite.sorted_by_quality();
        assert!(sorted.is_empty());
    }

    #[test]
    fn test_quant_benchmark_result_serde() {
        let result = make_result("test", 40.0, 0.01, 4.0);
        let json = serde_json::to_string(&result).expect("JSON serialization should succeed");
        let deserialized: QuantBenchmarkResult =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(result.name, deserialized.name);
        assert_close(deserialized.sqnr_db, result.sqnr_db);
    }

    #[test]
    fn test_benchmark_suite_serde() {
        let mut suite = BenchmarkSuite::default();
        suite.add(make_result("test1", 40.0, 0.01, 4.0));
        suite.add(make_result("test2", 50.0, 0.005, 4.0));
        let json = serde_json::to_string(&suite).expect("JSON serialization should succeed");
        let deserialized: BenchmarkSuite =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(suite.results.len(), deserialized.results.len());
    }
}