use crate::error::{DbxError, DbxResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// Aggregated timing statistics for one named benchmark run.
///
/// Serializable so a map of results can be persisted as a JSON baseline
/// and reloaded for later regression comparison.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Benchmark name; used as the key in the baseline map.
    pub name: String,
    /// Mean duration of one timed iteration, in milliseconds.
    pub avg_time_ms: f64,
    /// Fastest observed iteration, in milliseconds.
    pub min_time_ms: f64,
    /// Slowest observed iteration, in milliseconds.
    pub max_time_ms: f64,
    /// Population standard deviation of iteration times, in milliseconds.
    pub std_dev_ms: f64,
    /// Number of timed iterations that produced these statistics.
    pub sample_count: usize,
    /// Unix timestamp (seconds) at which the result was computed.
    pub timestamp: i64,
}
impl BenchmarkResult {
    /// Computes aggregate statistics (mean, min, max, population std-dev)
    /// from raw timing samples and stamps the result with the current time.
    ///
    /// An empty `samples` slice yields an all-zero result. Previously it
    /// divided by zero, producing `NaN` mean/std-dev and infinite min/max,
    /// which corrupted regression ratios and did not round-trip through JSON.
    pub fn new(name: String, samples: &[Duration]) -> Self {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the UNIX epoch")
            .as_secs() as i64;
        let sample_count = samples.len();
        if sample_count == 0 {
            // Guard: no samples means no meaningful statistics.
            return Self {
                name,
                avg_time_ms: 0.0,
                min_time_ms: 0.0,
                max_time_ms: 0.0,
                std_dev_ms: 0.0,
                sample_count: 0,
                timestamp,
            };
        }
        let times_ms: Vec<f64> = samples.iter().map(|d| d.as_secs_f64() * 1000.0).collect();
        let avg_time_ms = times_ms.iter().sum::<f64>() / sample_count as f64;
        let min_time_ms = times_ms.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_time_ms = times_ms.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        // Population variance (divide by N, not N-1): these are all the
        // samples collected, not a subsample of a larger population.
        let variance = times_ms
            .iter()
            .map(|t| {
                let diff = t - avg_time_ms;
                diff * diff
            })
            .sum::<f64>()
            / sample_count as f64;
        Self {
            name,
            avg_time_ms,
            min_time_ms,
            max_time_ms,
            std_dev_ms: variance.sqrt(),
            sample_count,
            timestamp,
        }
    }
}
/// Runs micro-benchmarks, keeps a per-name baseline of results that can be
/// persisted to disk as JSON, and flags regressions against that baseline.
pub struct BenchmarkRunner {
    // Latest recorded result per benchmark name; behind Arc<RwLock<..>> so
    // the map can be shared and read concurrently.
    baseline: Arc<RwLock<HashMap<String, BenchmarkResult>>>,
    // Regression trigger: current/baseline mean-time ratio above this value
    // is reported as a regression (e.g. 1.1 = 10% slowdown allowed).
    threshold: f64,
    // File the baseline map is saved to / loaded from.
    baseline_path: PathBuf,
    // Number of timed iterations per benchmark run.
    sample_count: usize,
}
impl BenchmarkRunner {
pub fn new() -> Self {
Self {
baseline: Arc::new(RwLock::new(HashMap::new())),
threshold: 1.1, baseline_path: PathBuf::from("target/benchmark_baseline.json"),
sample_count: 100,
}
}
pub fn with_threshold(mut self, threshold: f64) -> Self {
self.threshold = threshold;
self
}
pub fn with_baseline_path(mut self, path: PathBuf) -> Self {
self.baseline_path = path;
self
}
pub fn with_sample_count(mut self, count: usize) -> Self {
self.sample_count = count;
self
}
pub fn run<F>(&self, name: &str, mut f: F) -> DbxResult<BenchmarkResult>
where
F: FnMut(),
{
let mut samples = Vec::with_capacity(self.sample_count);
for _ in 0..5 {
f();
}
for _ in 0..self.sample_count {
let start = Instant::now();
f();
let duration = start.elapsed();
samples.push(duration);
}
Ok(BenchmarkResult::new(name.to_string(), &samples))
}
pub fn save_baseline(&self) -> DbxResult<()> {
let baseline = self.baseline.read().unwrap();
let json = serde_json::to_string_pretty(&*baseline)?;
if let Some(parent) = self.baseline_path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(&self.baseline_path, json)?;
Ok(())
}
pub fn load_baseline(&self) -> DbxResult<()> {
if !self.baseline_path.exists() {
return Ok(()); }
let json = fs::read_to_string(&self.baseline_path)?;
let loaded: HashMap<String, BenchmarkResult> = serde_json::from_str(&json)?;
let mut baseline = self.baseline.write().unwrap();
*baseline = loaded;
Ok(())
}
pub fn update_baseline(&self, name: &str, result: &BenchmarkResult) {
self.baseline
.write()
.unwrap()
.insert(name.to_string(), result.clone());
}
pub fn check_regression(&self, name: &str, result: &BenchmarkResult) -> DbxResult<()> {
let baseline = self.baseline.read().unwrap();
if let Some(baseline_result) = baseline.get(name) {
let ratio = result.avg_time_ms / baseline_result.avg_time_ms;
if ratio > self.threshold {
return Err(DbxError::PerformanceRegression {
name: name.to_string(),
baseline: baseline_result.avg_time_ms,
current: result.avg_time_ms,
ratio,
});
}
}
Ok(())
}
pub fn run_and_check<F>(&self, name: &str, f: F) -> DbxResult<BenchmarkResult>
where
F: FnMut(),
{
let result = self.run(name, f)?;
self.check_regression(name, &result)?;
Ok(result)
}
}
impl Default for BenchmarkRunner {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // A 1ms sleep must show up as an average of at least 1ms.
    #[test]
    fn test_benchmark_runner_basic() {
        let runner = BenchmarkRunner::new();
        let result = runner
            .run("test_sleep", || thread::sleep(Duration::from_millis(1)))
            .unwrap();
        assert_eq!(result.name, "test_sleep");
        assert!(result.avg_time_ms >= 1.0);
        assert!(result.sample_count > 0);
    }

    // A saved baseline must be readable by a fresh runner instance.
    #[test]
    fn test_baseline_save_load() {
        let temp_path = PathBuf::from("target/test_baseline.json");

        let writer = BenchmarkRunner::new().with_baseline_path(temp_path.clone());
        let measured = writer
            .run("test_op", || {
                let _ = 1 + 1;
            })
            .unwrap();
        writer.update_baseline("test_op", &measured);
        writer.save_baseline().unwrap();

        let reader = BenchmarkRunner::new().with_baseline_path(temp_path.clone());
        reader.load_baseline().unwrap();
        assert!(reader.baseline.read().unwrap().contains_key("test_op"));

        let _ = fs::remove_file(temp_path);
    }

    // A result 2x slower than baseline must trip a 1.5x threshold.
    #[test]
    fn test_regression_detection() {
        let temp_path = PathBuf::from("target/test_regression.json");
        let runner = BenchmarkRunner::new()
            .with_baseline_path(temp_path.clone())
            .with_threshold(1.5);

        let baseline_result = runner
            .run("fast_op", || {
                let _ = 1 + 1;
            })
            .unwrap();
        runner.update_baseline("fast_op", &baseline_result);

        // Fabricate a result exactly twice as slow as the baseline.
        let slow_result = BenchmarkResult {
            name: "fast_op".to_string(),
            avg_time_ms: baseline_result.avg_time_ms * 2.0,
            min_time_ms: 0.0,
            max_time_ms: 0.0,
            std_dev_ms: 0.0,
            sample_count: 100,
            timestamp: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs() as i64,
        };
        assert!(runner.check_regression("fast_op", &slow_result).is_err());

        let _ = fs::remove_file(temp_path);
    }

    #[test]
    fn test_threshold_configuration() {
        let runner = BenchmarkRunner::new().with_threshold(2.0);
        assert_eq!(runner.threshold, 2.0);
    }

    // A sleeping benchmark must measure slower than pure arithmetic.
    #[test]
    fn test_benchmark_comparison() {
        let runner = BenchmarkRunner::new();
        let fast = runner
            .run("fast", || {
                let _ = 1 + 1;
            })
            .unwrap();
        let slow = runner
            .run("slow", || thread::sleep(Duration::from_micros(10)))
            .unwrap();
        assert!(slow.avg_time_ms > fast.avg_time_ms);
    }
}