use std::sync::OnceLock;
use std::time::{Duration, Instant};
use svod_device::device::Program;
use crate::Result;
/// Knobs controlling how [`benchmark_kernel`] measures a kernel.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// Number of untimed executions performed before any measurement.
    pub warmup_runs: usize,
    /// Upper bound on the number of timed executions.
    pub timing_runs: usize,
    /// Whether the fastest run (rather than the mean) should be reported.
    /// Not read by `benchmark_kernel` itself; presumably callers forward it to
    /// [`BenchmarkResult::timing`] — confirm at call sites.
    pub take_minimum: bool,
    /// If set, timing stops once the fastest run so far exceeds this threshold.
    pub early_stop: Option<Duration>,
    /// If set, `benchmark_kernel` tries to evict the CPU L2 cache ahead of timed runs.
    pub clear_l2: bool,
}

impl Default for BenchmarkConfig {
    /// No warmup, three timed runs, report the minimum, no early stop, no cache clearing.
    fn default() -> Self {
        BenchmarkConfig {
            warmup_runs: 0,
            timing_runs: 3,
            take_minimum: true,
            early_stop: None,
            clear_l2: false,
        }
    }
}
/// Timing statistics gathered by [`benchmark_kernel`].
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Fastest recorded run, or `Duration::ZERO` when no run was recorded.
    pub min: Duration,
    /// Arithmetic mean of the recorded runs, or `Duration::ZERO` when none were recorded.
    pub mean: Duration,
    /// Every recorded run duration, in execution order.
    pub runs: Vec<Duration>,
}

impl BenchmarkResult {
    /// Selects the statistic to report: `min` when `take_minimum` is `true`,
    /// otherwise `mean`.
    pub fn timing(&self, take_minimum: bool) -> Duration {
        match take_minimum {
            true => self.min,
            false => self.mean,
        }
    }
}
/// Times `kernel` on the given arguments and summarizes the measured runs.
///
/// First performs `config.warmup_runs` untimed executions, then up to
/// `config.timing_runs` timed ones. Each timed run measures wall-clock time
/// around a single `kernel.execute(...)` call using [`Instant`].
///
/// * When `config.clear_l2` is set, [`invalidate_l2`] is called before *every*
///   timed run (including the first), so no measurement benefits from a cache
///   left warm by the warmup runs or the previous timed run.
/// * When `config.early_stop` is `Some(threshold)`, timing stops as soon as the
///   fastest run observed so far is slower than `threshold`.
///
/// `min` and `mean` are `Duration::ZERO` when no timed run was recorded
/// (e.g. `timing_runs == 0`).
///
/// # Safety
///
/// `buffers`, `vals`, `global_size` and `local_size` are passed unchanged to
/// `Program::execute`, which is `unsafe`. The caller must uphold whatever
/// contract that method requires of them (NOTE(review): presumably every pointer
/// in `buffers` must be valid for the kernel's accesses for the duration of this
/// call — confirm against `Program::execute`'s docs).
///
/// # Errors
///
/// Returns the first error produced by `kernel.execute`; no statistics are
/// returned in that case.
pub unsafe fn benchmark_kernel(
    kernel: &dyn Program,
    buffers: &[*mut u8],
    vals: &[i64],
    global_size: Option<[usize; 3]>,
    local_size: Option<[usize; 3]>,
    config: &BenchmarkConfig,
) -> Result<BenchmarkResult> {
    for _ in 0..config.warmup_runs {
        // SAFETY: arguments are forwarded verbatim; validity is the caller's
        // obligation per this function's `# Safety` section.
        unsafe { kernel.execute(buffers, vals, global_size, local_size)? };
    }

    let mut runs = Vec::with_capacity(config.timing_runs);
    // Running minimum, so the early-stop check does not rescan `runs` every iteration.
    let mut fastest: Option<Duration> = None;
    for _ in 0..config.timing_runs {
        if config.clear_l2 {
            invalidate_l2();
        }
        let start = Instant::now();
        // SAFETY: arguments are forwarded verbatim; validity is the caller's
        // obligation per this function's `# Safety` section.
        unsafe { kernel.execute(buffers, vals, global_size, local_size)? };
        let elapsed = start.elapsed();
        runs.push(elapsed);

        let fastest_so_far = fastest.map_or(elapsed, |best| best.min(elapsed));
        fastest = Some(fastest_so_far);
        // Even the best run is already too slow: further runs cannot help.
        if let Some(threshold) = config.early_stop
            && fastest_so_far > threshold
        {
            break;
        }
    }

    let min = fastest.unwrap_or(Duration::ZERO);
    let total: Duration = runs.iter().sum();
    // `runs.len()` is bounded by `timing_runs`; a run count above `u32::MAX`
    // would need a multi-gigabyte `Vec<Duration>` and is not a realistic config.
    let mean = total / runs.len().max(1) as u32;
    Ok(BenchmarkResult { min, mean, runs })
}
/// Touches Rayon's global thread pool once so its lazy initialization happens
/// here rather than inside a timed region.
///
/// NOTE(review): a single trivial `rayon::join` presumably forces the global pool
/// to be built, but it does not guarantee that every worker thread has executed
/// a task — confirm whether per-thread warmup is needed.
pub fn warmup_thread_pool() {
    let noop = || {};
    rayon::join(noop, noop);
}
/// Best-effort eviction of previously cached data from the CPU's L2 cache.
///
/// Reads one byte per `STRIDE` bytes across a 16 MiB scratch buffer that is
/// allocated on first use and reused by later calls. Both the buffer and the
/// resulting byte sum go through [`std::hint::black_box`] so the reads cannot be
/// optimized away.
fn invalidate_l2() {
    const SCRATCH_BYTES: usize = 16 * 1024 * 1024;
    // Assumed to be no larger than the cache-line size. On targets with larger
    // lines this only touches each line more than once.
    const STRIDE: usize = 64;
    static SCRATCH: OnceLock<Vec<u8>> = OnceLock::new();

    // Fill with a NON-zero byte so every page is actually written and backed by
    // its own physical memory. `vec![0u8; N]` is served via `alloc_zeroed`. For
    // large sizes that is typically lazily zero-mapped memory whose untouched
    // pages can all alias a single shared zero page, so reading it would hit the
    // same few cache lines and evict nothing.
    let scratch = SCRATCH.get_or_init(|| vec![0xA5u8; SCRATCH_BYTES]);
    std::hint::black_box(sum_strided(std::hint::black_box(scratch), STRIDE));
}

/// Wrapping sum of every `stride`-th byte of `buf`, starting at index 0.
///
/// # Panics
///
/// Panics if `stride` is zero (a requirement of [`Iterator::step_by`]).
fn sum_strided(buf: &[u8], stride: usize) -> u8 {
    buf.iter()
        .step_by(stride)
        .fold(0u8, |acc, &byte| acc.wrapping_add(byte))
}
#[cfg(test)]
#[path = "test/unit/benchmark.rs"]
mod tests;