//! Benchmarking utilities for trustformers models: a configurable suite that measures
//! forward-pass latency and throughput, plus report generation and result comparison helpers.

use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, Instant};

use trustformers_core::tensor::Tensor;
/// Configuration for a benchmark run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkConfig {
    /// Untimed iterations run before measurement to warm caches and lazy state.
    pub warmup_iterations: usize,
    /// Timed iterations per (batch size, sequence length) combination.
    pub benchmark_iterations: usize,
    /// Batch sizes to benchmark.
    pub batch_sizes: Vec<usize>,
    /// Sequence lengths to benchmark.
    pub sequence_lengths: Vec<usize>,
}

impl Default for BenchmarkConfig {
    fn default() -> Self {
        Self {
            warmup_iterations: 5,
            benchmark_iterations: 20,
            batch_sizes: vec![1, 4, 8, 16],
            sequence_lengths: vec![128, 256, 512],
        }
    }
}

/// Results from a single benchmark case.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Identifier of the benchmarked case (model, operation, batch x sequence length).
    pub test_name: String,
    /// Mean latency per forward pass, in milliseconds.
    pub avg_latency_ms: f64,
    /// Standard deviation of the latency, in milliseconds.
    pub std_latency_ms: f64,
    /// Throughput in samples (batch elements) per second.
    pub throughput: f64,
    /// Peak memory usage in megabytes, if measured.
    pub memory_mb: Option<f64>,
}

/// A boxed model forward function as registered with the benchmark suite.
pub type ModelBenchmarkFn = Box<dyn Fn(&Tensor) -> Result<Tensor>>;
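
/// Benchmarks registered models across the configured batch sizes and sequence lengths.
///
/// A minimal usage sketch, assuming some `model` value whose `forward` method takes
/// `&Tensor` and returns `Result<Tensor>` (the `model.forward` call is illustrative,
/// not an API of this crate):
///
/// ```ignore
/// let mut suite = BenchmarkSuite::new(BenchmarkConfig::default());
/// // Any closure matching Fn(&Tensor) -> Result<Tensor> can be registered.
/// suite.add_model("my_model", move |input| model.forward(input));
/// let results = suite.run_benchmarks()?;
/// println!("{}", suite.generate_report(&results));
/// ```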
pub struct BenchmarkSuite {
    config: BenchmarkConfig,
    models: HashMap<String, ModelBenchmarkFn>,
}

impl BenchmarkSuite {
    /// Creates an empty suite with the given configuration.
    pub fn new(config: BenchmarkConfig) -> Self {
        Self {
            config,
            models: HashMap::new(),
        }
    }

    /// Registers a model under `name`; the closure receives an input tensor and
    /// returns the model output.
    pub fn add_model<F>(&mut self, name: &str, model_fn: F)
    where
        F: Fn(&Tensor) -> Result<Tensor> + 'static,
    {
        self.models.insert(name.to_string(), Box::new(model_fn));
    }

    /// Runs every registered model over all configured batch size and
    /// sequence length combinations.
    pub fn run_benchmarks(&self) -> Result<Vec<BenchmarkResult>> {
        let mut results = Vec::new();
        for (model_name, model_fn) in &self.models {
            for &batch_size in &self.config.batch_sizes {
                for &seq_len in &self.config.sequence_lengths {
                    let test_name = format!("{}_forward_{}x{}", model_name, batch_size, seq_len);
                    let input = Tensor::randn(&[batch_size, seq_len])?;
                    let result = self.benchmark_model_with_input(
                        &test_name,
                        model_fn,
                        &input,
                        batch_size,
                    )?;
                    results.push(result);
                }
            }
        }
        Ok(results)
    }

    fn benchmark_model_with_input(
        &self,
        test_name: &str,
        model_fn: &ModelBenchmarkFn,
        input: &Tensor,
        batch_size: usize,
    ) -> Result<BenchmarkResult> {
        // Untimed warmup iterations to stabilize caches and lazy initialization.
        for _ in 0..self.config.warmup_iterations {
            model_fn(input)?;
        }

        // Timed iterations; `as_secs_f64` keeps sub-millisecond resolution that
        // `as_millis` would truncate to zero for fast models.
        let mut durations = Vec::new();
        for _ in 0..self.config.benchmark_iterations {
            let start = Instant::now();
            model_fn(input)?;
            durations.push(start.elapsed().as_secs_f64() * 1000.0);
        }

        let avg_latency_ms = durations.iter().sum::<f64>() / durations.len() as f64;
        // Population variance of the measured latencies.
        let variance = durations
            .iter()
            .map(|&x| (x - avg_latency_ms).powi(2))
            .sum::<f64>()
            / durations.len() as f64;
        let std_latency_ms = variance.sqrt();
        // Samples (batch elements) processed per second.
        let throughput = batch_size as f64 * 1000.0 / avg_latency_ms;

        Ok(BenchmarkResult {
            test_name: test_name.to_string(),
            avg_latency_ms,
            std_latency_ms,
            throughput,
            memory_mb: None,
        })
    }

    /// Renders the benchmark results as a Markdown table.
    pub fn generate_report(&self, results: &[BenchmarkResult]) -> String {
        let mut report = String::new();
        report.push_str("# Benchmark Report\n\n");
        report
            .push_str("| Test Name | Avg Latency (ms) | Std Dev (ms) | Throughput (samples/s) |\n");
        report.push_str("|-----------|------------------|--------------|----------------------|\n");
        for result in results {
            report.push_str(&format!(
                "| {} | {:.2} | {:.2} | {:.2} |\n",
                result.test_name, result.avg_latency_ms, result.std_latency_ms, result.throughput
            ));
        }
        report
    }
}
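
/// Compares two sets of benchmark results, e.g. a baseline run against an optimized run.
///
/// A minimal usage sketch, assuming `baseline_results` and `optimized_results` each came
/// from a `BenchmarkSuite::run_benchmarks` call over the same configuration (the variable
/// names are illustrative):
///
/// ```ignore
/// let report = ModelComparator::compare_results(&baseline_results, &optimized_results);
/// println!("{}", report);
/// ```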
pub struct ModelComparator;

impl ModelComparator {
    /// Produces a Markdown table comparing `comparison` results against `baseline`
    /// results with matching test names.
    pub fn compare_results(baseline: &[BenchmarkResult], comparison: &[BenchmarkResult]) -> String {
        let mut report = String::new();
        report.push_str("# Performance Comparison\n\n");

        // Index baseline results by test name for lookup.
        let baseline_map: HashMap<_, _> =
            baseline.iter().map(|r| (r.test_name.clone(), r)).collect();

        report
            .push_str("| Test | Baseline (ms) | Comparison (ms) | Speedup | Throughput Change |\n");
        report.push_str("|------|---------------|-----------------|---------|------------------|\n");
        for comp_result in comparison {
            if let Some(base_result) = baseline_map.get(&comp_result.test_name) {
                // Speedup > 1.0 means the comparison run is faster than the baseline.
                let speedup = base_result.avg_latency_ms / comp_result.avg_latency_ms;
                let throughput_change = (comp_result.throughput - base_result.throughput)
                    / base_result.throughput
                    * 100.0;
                report.push_str(&format!(
                    "| {} | {:.2} | {:.2} | {:.2}x | {:.1}% |\n",
                    comp_result.test_name,
                    base_result.avg_latency_ms,
                    comp_result.avg_latency_ms,
                    speedup,
                    throughput_change
                ));
            }
        }
        report
    }
}
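
/// Assorted helpers for constructing benchmark inputs and timing closures.
///
/// For example, `measure_execution_time` runs any closure once and returns its value
/// together with the elapsed wall-clock time (the closure below is just a stand-in
/// workload):
///
/// ```ignore
/// let (sum, elapsed) = BenchmarkUtils::measure_execution_time(|| (0..1_000_000u64).sum::<u64>());
/// println!("summed to {} in {:?}", sum, elapsed);
/// ```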
pub struct BenchmarkUtils;

impl BenchmarkUtils {
    /// Generates random input tensors for every (batch size, sequence length) pair.
    pub fn create_test_inputs(batch_sizes: &[usize], seq_lengths: &[usize]) -> Result<Vec<Tensor>> {
        let mut inputs = Vec::new();
        for &batch_size in batch_sizes {
            for &seq_len in seq_lengths {
                inputs.push(Tensor::randn(&[batch_size, seq_len])?);
            }
        }
        Ok(inputs)
    }

    /// Runs `f` once and returns its result together with the elapsed wall-clock time.
    pub fn measure_execution_time<F, T>(mut f: F) -> (T, Duration)
    where
        F: FnMut() -> T,
    {
        let start = Instant::now();
        let result = f();
        let duration = start.elapsed();
        (result, duration)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_benchmark_suite_creation() {
        let config = BenchmarkConfig::default();
        let suite = BenchmarkSuite::new(config);
        assert_eq!(suite.models.len(), 0);
    }

    #[test]
    fn test_benchmark_config() {
        let config = BenchmarkConfig::default();
        assert_eq!(config.warmup_iterations, 5);
        assert_eq!(config.benchmark_iterations, 20);
        assert!(!config.batch_sizes.is_empty());
        assert!(!config.sequence_lengths.is_empty());
    }

    #[test]
    fn test_benchmark_utils() -> Result<()> {
        let inputs = BenchmarkUtils::create_test_inputs(&[2, 4], &[10, 20])?;
        assert_eq!(inputs.len(), 4);
        Ok(())
    }

    #[test]
    fn test_execution_time_measurement() {
        let (result, duration) = BenchmarkUtils::measure_execution_time(|| {
            std::thread::sleep(Duration::from_millis(10));
            42
        });
        assert_eq!(result, 42);
        assert!(duration.as_millis() >= 10);
    }
}