use crate::{QScheme, QuantConfig, TorshResult};
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Runs timed benchmarks of quantization schemes and accumulates their results.
#[derive(Debug, Clone)]
pub struct QuantizationBenchmarker {
/// Settings controlling warmup, measurement length and the batch size used for throughput.
pub config: BenchmarkConfig,
/// Results recorded so far, one per completed benchmark run, in the order they were run.
pub metrics: Vec<BenchmarkResult>,
}
/// Tunable parameters for a benchmark run.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
/// Number of untimed calls made to the operation before measurement starts.
pub warmup_iterations: usize,
/// Number of timed calls over which the latency is averaged.
pub measurement_iterations: usize,
/// Number of items attributed to one call; used as the numerator of the throughput figure.
pub batch_size: usize,
/// Whether memory usage should be measured.
/// NOTE(review): no code visible in this part of the file reads this flag — confirm it is used.
pub measure_memory: bool,
/// Whether accuracy should be measured.
/// NOTE(review): no code visible in this part of the file reads this flag — confirm it is used.
pub measure_accuracy: bool,
}
impl Default for BenchmarkConfig {
fn default() -> Self {
Self {
warmup_iterations: 10,
measurement_iterations: 100,
batch_size: 32,
measure_memory: true,
measure_accuracy: true,
}
}
}
impl Default for QuantizationBenchmarker {
fn default() -> Self {
Self::new(BenchmarkConfig::default())
}
}
impl QuantizationBenchmarker {
    /// Creates a benchmarker that uses `config` and starts with no recorded metrics.
    pub fn new(config: BenchmarkConfig) -> Self {
        Self {
            config,
            metrics: Vec::new(),
        }
    }

    /// Times `operation` and records the outcome for `scheme`.
    ///
    /// `operation` is first called `config.warmup_iterations` times without being
    /// timed, then called repeatedly while the wall clock is running. The average
    /// time per call is stored together with the per-scheme memory, accuracy and
    /// compression estimates. The result is appended to the recorded metrics and
    /// also returned.
    ///
    /// At least one timed call is always made, even when
    /// `config.measurement_iterations` is `0`, so the average is always defined.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `operation`; nothing is recorded in
    /// that case.
    pub fn benchmark_scheme(
        &mut self,
        scheme: QScheme,
        operation: impl Fn() -> TorshResult<()>,
    ) -> TorshResult<BenchmarkResult> {
        for _ in 0..self.config.warmup_iterations {
            operation()?;
        }

        // Dividing a `Duration` by zero panics, so take at least one sample.
        let iterations = self.config.measurement_iterations.max(1);

        let start = Instant::now();
        for _ in 0..iterations {
            operation()?;
        }
        let elapsed = start.elapsed();

        // Float division avoids a lossy `usize -> u32` cast of the iteration count.
        let avg_duration = elapsed.div_f64(iterations as f64);

        let result = BenchmarkResult {
            scheme,
            // `as_millis()` would truncate every sub-millisecond average to 0.
            avg_latency_ms: avg_duration.as_secs_f32() * 1000.0,
            throughput_ops_per_sec: self.calculate_throughput(avg_duration),
            memory_usage_mb: self.estimate_memory_usage(scheme),
            accuracy_preservation: self.estimate_accuracy_preservation(scheme),
            compression_ratio: self.estimate_compression_ratio(scheme),
        };
        self.metrics.push(result.clone());
        Ok(result)
    }

    /// Benchmarks every scheme in `schemes`, in order.
    ///
    /// `operation_factory` is called once per scheme to obtain the operation to
    /// time; each operation is then run through [`Self::benchmark_scheme`].
    ///
    /// # Errors
    ///
    /// Stops at the first failing scheme and returns its error. Results of the
    /// schemes that completed before it remain in the recorded metrics.
    pub fn benchmark_comparison(
        &mut self,
        schemes: &[QScheme],
        operation_factory: impl Fn(QScheme) -> Box<dyn Fn() -> TorshResult<()>>,
    ) -> TorshResult<Vec<BenchmarkResult>> {
        // Collecting into `Result` short-circuits on the first error.
        schemes
            .iter()
            .map(|&scheme| self.benchmark_scheme(scheme, operation_factory(scheme)))
            .collect()
    }

    /// Renders the recorded metrics and the configuration as a plain-text table.
    pub fn generate_report(&self) -> String {
        let mut report = String::new();
        report.push_str("Quantization Benchmarking Report\n");
        report.push_str(&"=".repeat(80));
        report.push('\n');
        report.push_str(&format!(
            "{:<20} | {:>12} | {:>12} | {:>10} | {:>10}\n",
            "Scheme", "Latency (ms)", "Throughput", "Memory", "Accuracy"
        ));
        report.push_str(&"-".repeat(80));
        report.push('\n');

        for metric in &self.metrics {
            // Row widths match the header (12, 12, 8 + "MB", 10) so columns line up.
            // The scheme is formatted to a `String` first so the width is applied to it.
            report.push_str(&format!(
                "{:<20} | {:>12.2} | {:>12.0} | {:>8.1}MB | {:>10.3}\n",
                format!("{:?}", metric.scheme),
                metric.avg_latency_ms,
                metric.throughput_ops_per_sec,
                metric.memory_usage_mb,
                metric.accuracy_preservation
            ));
        }

        report.push('\n');
        report.push_str(&format!(
            "Benchmark Configuration:\n\
             - Warmup iterations: {}\n\
             - Measurement iterations: {}\n\
             - Batch size: {}",
            self.config.warmup_iterations,
            self.config.measurement_iterations,
            self.config.batch_size
        ));
        report
    }

    /// Returns the scheme whose recorded result scores highest under `criteria`.
    ///
    /// When several results tie, the one recorded first wins. Returns `None` when
    /// no result has a score greater than negative infinity, which includes the
    /// case where nothing has been recorded.
    pub fn find_best_scheme(&self, criteria: OptimizationCriteria) -> Option<QScheme> {
        self.metrics
            .iter()
            .fold((f32::NEG_INFINITY, None), |best, metric| {
                let score = criteria.calculate_score(metric);
                // Strict comparison keeps the earliest maximum and skips NaN scores.
                if score > best.0 {
                    (score, Some(metric.scheme))
                } else {
                    best
                }
            })
            .1
    }

    /// Converts an average per-call duration into items per second, using
    /// `config.batch_size` as the number of items per call.
    ///
    /// A zero `avg_duration` yields a non-finite value (infinity or NaN).
    fn calculate_throughput(&self, avg_duration: Duration) -> f32 {
        self.config.batch_size as f32 / avg_duration.as_secs_f32()
    }

    /// Fixed per-scheme memory estimate in megabytes; this is a lookup, not a measurement.
    fn estimate_memory_usage(&self, scheme: QScheme) -> f32 {
        match scheme {
            QScheme::Binary => 0.5,
            QScheme::Ternary => 1.0,
            QScheme::Int4PerTensor | QScheme::Int4PerChannel => 2.0,
            QScheme::PerTensorAffine | QScheme::PerChannelAffine => 4.0,
            QScheme::PerTensorSymmetric | QScheme::PerChannelSymmetric => 4.0,
            QScheme::MixedPrecision => 8.0,
            QScheme::GroupWise => 3.0,
        }
    }

    /// Fixed per-scheme accuracy-preservation estimate; this is a lookup, not a measurement.
    fn estimate_accuracy_preservation(&self, scheme: QScheme) -> f32 {
        match scheme {
            QScheme::PerTensorAffine => 0.98,
            QScheme::PerChannelAffine => 0.99,
            QScheme::PerTensorSymmetric => 0.97,
            QScheme::PerChannelSymmetric => 0.98,
            QScheme::Int4PerTensor => 0.93,
            QScheme::Int4PerChannel => 0.95,
            QScheme::MixedPrecision => 0.99,
            QScheme::Binary => 0.75,
            QScheme::Ternary => 0.85,
            QScheme::GroupWise => 0.96,
        }
    }

    /// Fixed per-scheme compression-ratio estimate; this is a lookup, not a measurement.
    fn estimate_compression_ratio(&self, scheme: QScheme) -> f32 {
        match scheme {
            QScheme::PerTensorAffine => 4.0,
            QScheme::PerChannelAffine => 3.8,
            QScheme::PerTensorSymmetric => 4.0,
            QScheme::PerChannelSymmetric => 3.8,
            QScheme::Int4PerTensor => 8.0,
            QScheme::Int4PerChannel => 7.5,
            QScheme::MixedPrecision => 5.0,
            QScheme::Binary => 32.0,
            QScheme::Ternary => 16.0,
            QScheme::GroupWise => 6.0,
        }
    }

    /// Discards every recorded result.
    pub fn clear_metrics(&mut self) {
        self.metrics.clear();
    }

    /// Returns the recorded results in the order they were produced.
    pub fn get_metrics(&self) -> &[BenchmarkResult] {
        &self.metrics
    }
}
/// Outcome of benchmarking a single quantization scheme.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
/// The scheme this result belongs to.
pub scheme: QScheme,
/// Average time per measured call, in milliseconds.
pub avg_latency_ms: f32,
/// Configured batch size divided by the average time per call in seconds.
pub throughput_ops_per_sec: f32,
/// Memory figure for the scheme, in megabytes.
pub memory_usage_mb: f32,
/// Accuracy-preservation figure for the scheme; enters the optimization score as a linear term.
pub accuracy_preservation: f32,
/// Compression-ratio figure for the scheme; the optimization score normalizes it against 32.
pub compression_ratio: f32,
}
/// Weights used to combine the fields of a `BenchmarkResult` into a single score.
#[derive(Debug, Clone)]
pub struct OptimizationCriteria {
/// Weight of the inverse-latency term.
pub latency_weight: f32,
/// Weight of the throughput term.
pub throughput_weight: f32,
/// Weight of the inverse-memory term.
pub memory_weight: f32,
/// Weight of the accuracy term.
pub accuracy_weight: f32,
/// Weight of the compression-ratio term.
pub compression_weight: f32,
}
impl OptimizationCriteria {
pub fn optimize_for_speed() -> Self {
Self {
latency_weight: 0.4,
throughput_weight: 0.4,
memory_weight: 0.1,
accuracy_weight: 0.1,
compression_weight: 0.0,
}
}
pub fn optimize_for_accuracy() -> Self {
Self {
latency_weight: 0.1,
throughput_weight: 0.1,
memory_weight: 0.1,
accuracy_weight: 0.7,
compression_weight: 0.0,
}
}
pub fn optimize_for_size() -> Self {
Self {
latency_weight: 0.1,
throughput_weight: 0.1,
memory_weight: 0.3,
accuracy_weight: 0.2,
compression_weight: 0.3,
}
}
pub fn calculate_score(&self, result: &BenchmarkResult) -> f32 {
let latency_score = (1.0 / result.avg_latency_ms.max(0.001)) * self.latency_weight;
let throughput_score =
(result.throughput_ops_per_sec / 10000.0).min(1.0) * self.throughput_weight;
let memory_score = (1.0 / result.memory_usage_mb.max(0.1)) * self.memory_weight;
let accuracy_score = result.accuracy_preservation * self.accuracy_weight;
let compression_score =
(result.compression_ratio / 32.0).min(1.0) * self.compression_weight;
latency_score + throughput_score + memory_score + accuracy_score + compression_score
}
}