use alloc::boxed::Box;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::fmt::{Debug, Display};
use core::hash::Hash;
use alloc::format;
use super::{
AutotuneError, input_generator::InputGenerator, key_generator::KeyGenerator,
tune_inputs::TuneInputs,
};
use super::{Tunable, TunePlan};
/// Type-erased callable stored inside a [`TuneFn`].
///
/// It accepts the inputs borrowed for any lifetime (`for<'inp>`), returns either the output
/// or an [`AutotuneError`], and is `Send + Sync` so it can be shared across threads.
type TuneDelegate<I, Out> =
    dyn for<'inp> Fn(<I as TuneInputs>::At<'inp>) -> Result<Out, AutotuneError> + Send + Sync;

/// A named, boxed candidate implementation that the autotuner can execute.
#[derive(new)]
pub struct TuneFn<I: TuneInputs, Out> {
    /// Identifier of this candidate; `TunableSet::compute_checksum` hashes these names.
    pub(crate) name: String,
    /// The wrapped callable invoked by `execute`.
    func: Box<TuneDelegate<I, Out>>,
}

impl<I, Out> TuneFn<I, Out>
where
    I: TuneInputs,
    Out: 'static,
{
    /// Runs the wrapped callable on `inputs` and forwards its result unchanged.
    ///
    /// # Errors
    ///
    /// Returns whatever [`AutotuneError`] the wrapped callable produces.
    pub fn execute(&self, inputs: I::At<'_>) -> Result<Out, AutotuneError> {
        let delegate: &TuneDelegate<I, Out> = &*self.func;
        delegate(inputs)
    }
}
/// An ordered collection of candidate implementations ([`Tunable`]s).
///
/// The set also holds two generators. The key generator derives a key of type `K` from a set
/// of inputs. The input generator derives the inputs to use for a given key.
pub struct TunableSet<K, F, Output>
where
    K: AutotuneKey,
    F: TuneInputs,
    Output: 'static,
{
    /// Candidates, in the order they were added via `with`.
    tunables: Vec<Tunable<K, F, Output>>,
    /// Turns inputs into a key (used by `generate_key`).
    key_gen: Arc<dyn KeyGenerator<K, F> + Send + Sync>,
    /// Turns a key plus reference inputs into inputs (used by `generate_inputs`).
    input_gen: Arc<dyn InputGenerator<K, F> + Send + Sync>,
}
impl<K, F, Output> TunableSet<K, F, Output>
where
    K: AutotuneKey,
    F: TuneInputs,
    Output: 'static,
{
    /// Number of tunables registered in this set.
    pub fn len(&self) -> usize {
        self.tunables.len()
    }

    /// Returns `true` when no tunable has been registered.
    pub fn is_empty(&self) -> bool {
        self.tunables.is_empty()
    }

    /// Creates an empty set that uses `key_gen` to derive keys and `input_gen` to derive inputs.
    pub fn new(key_gen: impl KeyGenerator<K, F>, input_gen: impl InputGenerator<K, F>) -> Self {
        Self {
            key_gen: Arc::new(key_gen),
            input_gen: Arc::new(input_gen),
            tunables: Vec::new(),
        }
    }

    /// Creates an empty set whose input generator is `super::CloneInputGenerator`.
    pub fn new_cloning_inputs(key_gen: impl KeyGenerator<K, F>) -> Self {
        Self::new(key_gen, super::CloneInputGenerator)
    }

    /// Appends `tunable` after the already-registered candidates and returns the set
    /// (builder style).
    pub fn with(mut self, tunable: Tunable<K, F, Output>) -> Self {
        self.tunables.push(tunable);
        self
    }

    /// Iterates over the tune functions of every registered tunable, in registration order.
    pub fn autotunables(&self) -> impl Iterator<Item = &TuneFn<F, Output>> {
        self.tunables.iter().map(|entry| &entry.function)
    }

    /// Builds a [`TunePlan`] for `key` over the registered tunables.
    pub(crate) fn plan(&self, key: &K) -> TunePlan {
        TunePlan::new(key, &self.tunables)
    }

    /// Returns the tune function registered at position `fastest_index`.
    ///
    /// # Panics
    ///
    /// Panics if `fastest_index >= self.len()`.
    pub fn fastest(&self, fastest_index: usize) -> &TuneFn<F, Output> {
        &self.tunables[fastest_index].function
    }

    /// Returns the lowercase hex MD5 digest of all tunable names concatenated in
    /// registration order.
    ///
    /// The digest changes whenever a tunable is added or removed, renamed, or reordered.
    ///
    /// NOTE(review): names are joined without a delimiter, so e.g. `["ab", "c"]` and
    /// `["a", "bc"]` produce the same checksum. Adding a separator would fix this, but it
    /// would also change every previously computed checksum.
    pub fn compute_checksum(&self) -> String {
        let concatenated: String = self
            .tunables
            .iter()
            .map(|entry| entry.function.name.as_str())
            .collect();
        format!("{:x}", md5::compute(concatenated))
    }

    /// Derives the key for `inputs` by delegating to the configured key generator.
    pub fn generate_key(&self, inputs: &F::At<'_>) -> K {
        self.key_gen.generate(inputs)
    }

    /// Derives inputs for `key` from the reference `inputs` by delegating to the configured
    /// input generator.
    pub fn generate_inputs<'a>(&self, key: &K, inputs: &F::At<'a>) -> F::At<'a> {
        self.input_gen.generate(key, inputs)
    }
}
/// Bounds required of the key type `K` used by a `TunableSet`.
///
/// This variant applies when the `std_io` cfg is enabled. Keys must additionally be
/// serde-serializable and deserializable, presumably so tuning results can be stored and
/// reloaded. TODO: confirm against the cache implementation.
#[cfg(std_io)]
pub trait AutotuneKey
where
    Self: Clone
        + Debug
        + PartialEq
        + Eq
        + Hash
        + Display
        + serde::Serialize
        + serde::de::DeserializeOwned
        + Send
        + Sync
        + 'static,
{
}

/// Bounds required of the key type `K` used by a `TunableSet`.
///
/// This variant applies without the `std_io` cfg, and no serde bounds are required.
#[cfg(not(std_io))]
pub trait AutotuneKey
where
    Self: Clone + Debug + PartialEq + Eq + Hash + Display + Send + Sync + 'static,
{
}

/// Plain strings can be used directly as autotune keys.
impl AutotuneKey for String {}