use std::{
collections::{BTreeSet, HashMap},
time::Instant,
fs::{OpenOptions, symlink_metadata},
path::{
Path,
Component::Normal,
},
io::{BufReader, BufRead , Result, ErrorKind, Error},
hint::black_box,
};
/// A duration in nanoseconds, the unit returned by [`std::time::Duration::as_nanos`].
pub type Ns = u128;

/// Timings collected for one closed benchmark run.
#[derive(Debug, Default)]
struct Run {
    /// Mean nanoseconds per call, keyed by measurement label.
    times: HashMap<&'static str, Ns>,
    /// Wall-clock nanoseconds from the start of the run until it was closed.
    total_ns: Ns,
}

/// Collects labelled micro-benchmark timings, grouped into runs, and exports them as CSV.
#[derive(Debug)]
pub struct Bench {
    /// When the current (still open) run started.
    cur_start: Instant,
    /// Per-label mean timings recorded so far in the current run.
    cur_times: HashMap<&'static str, Ns>,
    /// Runs that have already been closed, oldest first.
    runs: Vec<Run>,
}
impl Bench {
pub fn new() -> Self {
Self {
cur_start: Instant::now(),
cur_times: HashMap::new(),
runs: Vec::new(),
}
}
pub fn measure<R, F>(&mut self, label: &'static str, mut f: F) -> R
where
F: FnMut() -> R,
{
let warmup_runs = 200;
let concerned_runs = 1000;
for _ in 0..warmup_runs {
black_box(f());
}
let mut total_ns: u128 = 0;
let mut function_output: Option<R> = None;
for _ in 0..concerned_runs {
let t0 = Instant::now();
let out = black_box(f());
let ns = t0.elapsed().as_nanos();
total_ns = total_ns.saturating_add(ns); function_output = Some(out);
}
let average = total_ns / concerned_runs as u128;
self.cur_times.insert(label, average);
function_output.expect("no measured runs completed (closure panicked before first store?)")
}
pub fn measure_with_custom_runs_and_warmup<R, F>(
&mut self,
label: &'static str,
mut f: F,
runs: usize,
warmup: usize,
) -> R
where
F: FnMut() -> R,
{
assert!(runs > warmup, "runs must be > warmup");
let conerened_runs = runs - warmup;
assert!(conerened_runs > 0, "must have at least one timed run");
for _ in 0..warmup {
black_box(f());
}
let mut total_ns: u128 = 0;
let mut last_out: Option<R> = None;
for _ in 0..conerened_runs {
let t0 = Instant::now();
let out = std::hint::black_box(f());
let ns = t0.elapsed().as_nanos();
total_ns = total_ns.saturating_add(ns); last_out = Some(out);
}
let average = total_ns / conerened_runs as u128;
self.cur_times.insert(label, average);
last_out.expect("no measured runs completed (error running the given function?)")
}
pub fn next_run(&mut self) {
let total_ns = self.cur_start.elapsed().as_nanos();
self.runs.push(Run {
times: std::mem::take(&mut self.cur_times),
total_ns,
});
self.cur_start = Instant::now();
}
pub fn save_to_csv<P: AsRef<Path>>(&mut self, path: P) -> csv::Result<()> {
let path = path.as_ref();
ensure_cwd_csv(path).map_err(csv::Error::from)?;
self.next_run();
let mut labels: BTreeSet<&'static str> = BTreeSet::new();
for run in &self.runs {
labels.extend(run.times.keys());
}
let (mut wtr, start_idx) = if path.exists() {
let f = OpenOptions::new().read(true).open(path)?;
let mut rdr = BufReader::new(&f);
let mut last = 0;
let mut line = String::new();
while rdr.read_line(&mut line)? != 0 {
if let Some(first) = line.split(',').next() {
last = first.trim().parse::<usize>().unwrap_or(last);
}
line.clear();
}
let f = OpenOptions::new().append(true).open(path)?;
let w = csv::WriterBuilder::new().has_headers(false).from_writer(f);
(w, last)
} else {
let mut w = csv::Writer::from_path(path)?;
let mut header: Vec<String> = Vec::with_capacity(labels.len() + 2);
header.push("run".into());
for l in &labels {
header.push(format!("{l}_ns"));
}
header.push("total_ns".into());
w.write_record(&header)?;
(w, 0)
};
for (idx, run) in self.runs.iter().enumerate() {
let mut row: Vec<String> = Vec::with_capacity(labels.len() + 2);
row.push((start_idx + idx + 1).to_string()); for l in &labels {
row.push(
run.times
.get(l)
.map_or(String::new(), |v| v.to_string()),
);
}
row.push(run.total_ns.to_string());
wtr.write_record(&row)?;
}
wtr.flush()?;
Ok(())
}
}
/// Shorthand for `Bench::measure`: times `$code` under the label `$name` and evaluates
/// to the block's result.
///
/// `bench!(b, "parse", { parse(input) })` expands to
/// `b.measure("parse", || { parse(input) })`. The block therefore becomes an `FnMut`
/// closure that is invoked many times (warm-up plus timed calls).
#[macro_export]
macro_rules! bench {
    ($target:expr, $name:expr, $code:block) => {{
        $target.measure($name, || $code)
    }};
}
/// Checks that `path` names a `.csv` file directly inside the current working directory.
///
/// Two shapes are accepted: a bare file name (`results.csv`) or one prefixed with `./`
/// (`./results.csv`).
///
/// # Errors
///
/// Returns an error of kind [`ErrorKind::InvalidInput`] when `path`:
/// - is absolute;
/// - has any other shape, e.g. `../results.csv` or `dir/results.csv`;
/// - is an existing directory;
/// - has no extension, or an extension other than `csv` (case-sensitive);
/// - is itself a symbolic link.
fn ensure_cwd_csv(path: &Path) -> Result<()> {
    if path.is_absolute() {
        return Err(Error::new(ErrorKind::InvalidInput, "absolute paths are not allowed"));
    }
    // `components()` drops interior `.` segments but keeps a leading one as `CurDir`, so
    // `./name` arrives as [CurDir, Normal]. The first component must be exactly `CurDir`:
    // accepting any component there would also admit `../name` and `dir/name`.
    let mut comps = path.components();
    let in_cwd = matches!(
        (comps.next(), comps.next(), comps.next()),
        (Some(Normal(_)), None, None)
            | (Some(std::path::Component::CurDir), Some(Normal(_)), None)
    );
    if !in_cwd {
        return Err(Error::new(
            ErrorKind::InvalidInput,
            "only filenames in the current directory are allowed",
        ));
    }
    if path.is_dir() {
        return Err(Error::new(ErrorKind::InvalidInput, "path points to a directory"));
    }
    match path.extension() {
        Some(ext) if ext == "csv" => {}
        Some(_) => {
            return Err(Error::new(ErrorKind::InvalidInput, "file extension must be .csv"));
        }
        None => {
            return Err(Error::new(ErrorKind::InvalidInput, "file must have .csv extension"));
        }
    }
    // `symlink_metadata` does not follow links, so it reports on the link itself.
    // Metadata errors (e.g. the file does not exist yet) are deliberately ignored.
    if let Ok(meta) = symlink_metadata(path) {
        if meta.file_type().is_symlink() {
            return Err(Error::new(ErrorKind::InvalidInput, "symlinks are not allowed"));
        }
    }
    Ok(())
}