use std::{
any::{Any, TypeId},
borrow::Cow,
mem, slice,
sync::OnceLock,
};
use crate::Bencher;
/// Lazily-initialized storage for a benchmark's arguments and their display
/// names.
///
/// Nothing is stored until the first call to [`BenchArgs::runner`], which
/// produces the arguments and stores them in type-erased form; subsequent
/// calls reuse that stored data.
pub struct BenchArgs {
    // Set exactly once, by `runner`, via `OnceLock::get_or_init`.
    args: OnceLock<ErasedArgsSlice>,
}
/// Copyable, type-erased handle that runs a benchmark against one of the
/// arguments stored in a [`BenchArgs`].
#[derive(Clone, Copy)]
pub struct BenchArgsRunner {
    /// Erased arguments, borrowed from a `'static` [`BenchArgs`].
    args: &'static ErasedArgsSlice,
    /// Generic `bench` function instantiated for the concrete argument and
    /// closure types; it re-types `args` before invoking the closure.
    bench: fn(Bencher, &ErasedArgsSlice, arg_index: usize),
}
/// A `&'static [T]` of benchmark arguments plus one display name per argument,
/// with the element type `T` erased.
///
/// `arg_type` records `TypeId::of::<T>()` so the slice can be re-typed later
/// only when the requested type matches.
struct ErasedArgsSlice {
    /// Pointer to the first element of the `[T]` argument slice.
    args: *const (),
    /// Pointer to the first of `len` argument names.
    names: *const &'static str,
    /// Number of arguments, which is also the number of names.
    len: usize,
    /// `TypeId` of the erased element type `T`.
    arg_type: TypeId,
}

// SAFETY: Both pointers refer to leaked (`'static`) slices that are only ever
// read through shared slices, and `BenchArgs::runner` requires the erased
// argument type to be `Send + Sync`. Moving or sharing this struct across
// threads is therefore equivalent to doing so with `&'static [T]` and
// `&'static [&'static str]`.
unsafe impl Send for ErasedArgsSlice {}
unsafe impl Sync for ErasedArgsSlice {}
impl BenchArgs {
pub const fn new() -> Self {
Self { args: OnceLock::new() }
}
pub fn runner<I, B>(
&'static self,
make_args: impl FnOnce() -> I,
_bench_impl: B,
) -> BenchArgsRunner
where
I: IntoIterator,
I::Item: Any + ToString + Send + Sync,
B: FnOnce(Bencher, &I::Item) + Copy,
{
let args = self.args.get_or_init(|| {
let args: &'static [I::Item] = Box::leak(make_args().into_iter().collect());
let names: &'static [&str] = 'names: {
if let Some(args) = (&args as &dyn Any).downcast_ref::<&[&str]>() {
break 'names args;
}
Box::leak(
args.iter()
.map(|arg| -> &str {
if let Some(arg) = (arg as &dyn Any).downcast_ref::<String>() {
return arg;
}
if let Some(arg) = (arg as &dyn Any).downcast_ref::<Box<str>>() {
return arg;
}
if let Some(arg) = (arg as &dyn Any).downcast_ref::<Cow<str>>() {
return arg;
}
Box::leak(arg.to_string().into_boxed_str())
})
.collect(),
)
};
ErasedArgsSlice {
args: crate::black_box(args.as_ptr().cast()),
names: names.as_ptr(),
len: args.len(),
arg_type: TypeId::of::<I::Item>(),
}
});
BenchArgsRunner { args, bench: bench::<I::Item, B> }
}
}
impl BenchArgsRunner {
#[inline]
pub(crate) fn bench(&self, bencher: Bencher, index: usize) {
(self.bench)(bencher, self.args, index)
}
#[inline]
pub(crate) fn arg_names(&self) -> &'static [&'static str] {
self.args.names()
}
}
impl ErasedArgsSlice {
    /// Re-types the erased arguments as `&[T]`.
    ///
    /// Returns `None` when `T` is not the element type the slice was built
    /// from.
    #[inline]
    fn typed_args<T: Any>(&self) -> Option<&[T]> {
        let is_same_type = self.arg_type == TypeId::of::<T>();

        // SAFETY: `args` and `len` come from a leaked `&'static [U]` whose
        // `TypeId` was stored in `arg_type`. The closure only runs when that
        // `TypeId` equals `T`'s, so `U == T`.
        is_same_type.then(|| unsafe { slice::from_raw_parts(self.args.cast::<T>(), self.len) })
    }

    /// Display names of the arguments, one per argument.
    #[inline]
    fn names(&self) -> &'static [&str] {
        let (first, count) = (self.names, self.len);

        // SAFETY: `names` points to a leaked `[&'static str]` holding exactly
        // one entry per argument, i.e. `len` entries.
        unsafe { slice::from_raw_parts(first, count) }
    }
}
/// Runs the benchmark closure of type `B` on argument `arg_index` of
/// `erased_args`, after re-typing the erased slice as `&[T]`.
///
/// # Panics
///
/// Panics if `erased_args` was not built from a `[T]`, if `B` is not
/// zero-sized, or if `arg_index` is out of bounds.
fn bench<T, B>(bencher: Bencher, erased_args: &ErasedArgsSlice, arg_index: usize)
where
    T: Any,
    B: FnOnce(Bencher, &T) + Copy,
{
    let Some(typed_args) = erased_args.typed_args::<T>() else {
        type_mismatch::<T>();

        #[cold]
        #[inline(never)]
        fn type_mismatch<T>() -> ! {
            unreachable!("incorrect type '{}'", std::any::type_name::<T>())
        }
    };

    // Only the closure's *type* reaches this function (through the function
    // pointer stored in the runner), so the closure value must be recreated
    // from nothing. That is only possible when it captures no state.
    assert_eq!(mem::size_of::<B>(), 0, "benchmark closure expected to be zero-sized");

    // SAFETY: `B` is zero-sized (asserted above), so zero-initialization
    // writes no bytes. `B` is also inhabited, because `BenchArgs::runner` only
    // instantiates this function after being handed a value of type `B`.
    let bench_impl: B = unsafe { mem::zeroed() };

    bench_impl(bencher, &typed_args[arg_index]);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that argument names borrow the arguments' own string memory
    /// instead of allocating copies.
    mod optimizations {
        use std::borrow::Borrow;

        use super::*;

        /// Asserts element-wise that `a` and `b` hold equal strings backed by
        /// the same bytes.
        fn test_eq_ptr<A: Borrow<str>, B: Borrow<str>>(a: &[A], b: &[B]) {
            assert_eq!(a.len(), b.len());
            a.iter().zip(b).for_each(|(left, right)| {
                let left: &str = left.borrow();
                let right: &str = right.borrow();
                assert_eq!(left, right);
                assert_eq!(left.as_ptr(), right.as_ptr());
            });
        }

        /// Asserts that a runner built from the strings `"a"` and `"b"` (as
        /// values of type `T`) yields names that point into those values.
        fn assert_names_reuse_args<T: Any + Borrow<str>>(runner: BenchArgsRunner) {
            let names = runner.arg_names();
            let typed_args = runner.args.typed_args::<T>().unwrap();
            assert_eq!(names, ["a", "b"]);
            test_eq_ptr(names, typed_args);
        }

        #[test]
        fn str() {
            static ARGS: BenchArgs = BenchArgs::new();
            let runner = ARGS.runner(|| ["a", "b"], |_, _| {});

            let names = runner.arg_names();
            let typed_args = runner.args.typed_args::<&str>().unwrap();

            assert_eq!(names, ["a", "b"]);
            assert_eq!(names, typed_args);
            // The whole slice is reused, not just the individual strings.
            assert_eq!(names.as_ptr(), typed_args.as_ptr());
        }

        #[test]
        fn string() {
            static ARGS: BenchArgs = BenchArgs::new();
            let runner = ARGS.runner(|| ["a".to_owned(), "b".to_owned()], |_, _| {});
            assert_names_reuse_args::<String>(runner);
        }

        #[test]
        fn box_str() {
            static ARGS: BenchArgs = BenchArgs::new();
            let runner = ARGS.runner(
                || ["a".to_owned().into_boxed_str(), "b".to_owned().into_boxed_str()],
                |_, _| {},
            );
            assert_names_reuse_args::<Box<str>>(runner);
        }

        #[test]
        fn cow_str() {
            static ARGS: BenchArgs = BenchArgs::new();
            let runner =
                ARGS.runner(|| [Cow::Owned("a".to_owned()), Cow::Borrowed("b")], |_, _| {});
            assert_names_reuse_args::<Cow<'static, str>>(runner);
        }
    }
}