use alloc::boxed::Box;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::fmt::{Debug, Display};
use core::hash::Hash;
#[cfg(std_io)]
use alloc::format;
use super::{
AutotuneError,
input_generator::{InputGenerator, IntoInputGenerator},
key_generator::{IntoKeyGenerator, KeyGenerator},
};
use super::{Tunable, TunePlan};
/// Computes a checksum identifying a collection of tunable functions.
///
/// The names of all `autotunables` are appended, in iteration order and with no
/// separator, into a single string which is then hashed with MD5. The digest is
/// returned formatted as lowercase hexadecimal.
///
/// Note: because no separator is inserted, two different name lists that
/// concatenate to the same string (e.g. `["ab", "c"]` and `["a", "bc"]`) produce
/// the same checksum.
#[cfg(std_io)]
pub fn compute_checksum<In: Clone + Send + 'static, Out: 'static>(
    autotunables: impl Iterator<Item = Arc<dyn TuneFn<Inputs = In, Output = Out>>>,
) -> String {
    let joined_names = autotunables.fold(String::new(), |mut acc, tunable| {
        acc.push_str(tunable.name());
        acc
    });
    let digest = md5::compute(joined_names);
    format!("{:x}", digest)
}
/// A collection of candidate implementations ("tunables") of one operation,
/// together with the generators that derive an autotune key and inputs from a
/// set of `Inputs`.
pub struct TunableSet<K: AutotuneKey, Inputs: Send + 'static, Output: 'static> {
    /// Registered tunables, in the order they were added with `with`.
    tunables: Vec<Tunable<K, Inputs, Output>>,
    /// Derives the autotune key of type `K` from a set of inputs (see `generate_key`).
    key_gen: Arc<dyn KeyGenerator<K, Inputs>>,
    /// Produces a fresh set of inputs from a key and reference inputs (see
    /// `inputs_generator`); presumably used to build benchmark inputs — confirm.
    input_gen: Arc<dyn InputGenerator<K, Inputs>>,
    // The closure type is spelled out inline, which clippy reports as overly complex.
    #[allow(clippy::type_complexity)]
    /// Optional replacement for the default name-based checksum, installed via
    /// `with_custom_checksum`.
    checksum_override: Option<Arc<dyn Fn(&Self) -> String + Send + Sync>>,
}
// SAFETY: NOTE(review): these impls assert `Send`/`Sync` manually instead of
// letting the compiler infer them from the fields. `Out` carries no `Send`/`Sync`
// bound here, and whether the `KeyGenerator`/`InputGenerator` trait objects held
// behind `Arc` are themselves `Send + Sync` is not visible from this file.
// Confirm that every field of `TunableSet` is genuinely safe to move and share
// across threads. If it is, prefer expressing that through trait bounds so these
// `unsafe` impls can be removed.
unsafe impl<K: AutotuneKey, In: Send, Out> Send for TunableSet<K, In, Out> {}
unsafe impl<K: AutotuneKey, In: Send, Out> Sync for TunableSet<K, In, Out> {}
impl<K: AutotuneKey, Inputs: Clone + Send + 'static, Output: 'static>
    TunableSet<K, Inputs, Output>
{
    /// Returns the number of tunables registered in this set.
    pub fn len(&self) -> usize {
        self.tunables.len()
    }

    /// Returns `true` when no tunable has been registered.
    pub fn is_empty(&self) -> bool {
        self.tunables.is_empty()
    }

    /// Creates an empty set from a key generator and an input generator.
    ///
    /// Tunables are added afterwards with [`Self::with`]. No custom checksum is
    /// installed initially.
    pub fn new<KMarker, IMarker>(
        key_gen: impl IntoKeyGenerator<K, Inputs, KMarker>,
        input_gen: impl IntoInputGenerator<K, Inputs, IMarker>,
    ) -> Self {
        Self {
            tunables: Default::default(),
            input_gen: Arc::new(input_gen.into_input_gen()),
            key_gen: Arc::new(key_gen.into_key_gen()),
            checksum_override: None,
        }
    }

    /// Appends `tunable` to the set and returns the updated set (builder style).
    #[must_use = "`with` consumes the set and returns the updated one"]
    pub fn with(mut self, tunable: Tunable<K, Inputs, Output>) -> Self {
        self.tunables.push(tunable);
        self
    }

    /// Installs `checksum` as a replacement for the default name-based checksum
    /// and returns the updated set (builder style).
    ///
    /// Within this impl, the override is only consulted by `compute_checksum`,
    /// which exists only when `std_io` is enabled.
    #[must_use = "`with_custom_checksum` consumes the set and returns the updated one"]
    pub fn with_custom_checksum(
        mut self,
        checksum: impl Fn(&Self) -> String + Send + Sync + 'static,
    ) -> Self {
        self.checksum_override = Some(Arc::new(checksum));
        self
    }

    /// Returns the function of every registered tunable, in registration order.
    ///
    /// Each entry is an `Arc` clone (a reference-count bump, not a deep copy).
    pub fn autotunables(&self) -> Vec<Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>> {
        self.tunables
            .iter()
            .map(|tunable| tunable.function.clone())
            .collect()
    }

    /// Builds the [`TunePlan`] for `key` from the registered tunables.
    pub(crate) fn plan(&self, key: &K) -> TunePlan {
        TunePlan::new(key, &self.tunables)
    }

    /// Returns the function of the tunable at `fastest_index`.
    ///
    /// # Panics
    ///
    /// Panics if `fastest_index >= self.len()`.
    pub fn fastest(
        &self,
        fastest_index: usize,
    ) -> Arc<dyn TuneFn<Inputs = Inputs, Output = Output>> {
        self.tunables[fastest_index].function.clone()
    }

    /// Returns the checksum identifying this set.
    ///
    /// Uses the closure installed with [`Self::with_custom_checksum`] when present.
    /// Otherwise it delegates to the module-level `compute_checksum`, which hashes
    /// the names of the registered tunables.
    #[cfg(std_io)]
    pub fn compute_checksum(&self) -> String {
        if let Some(checksum_override) = &self.checksum_override {
            checksum_override(self)
        } else {
            // Bare `compute_checksum` resolves to the free function, not this method.
            compute_checksum(self.tunables.iter().map(|tune| tune.function.clone()))
        }
    }

    /// Derives the autotune key for `inputs` using the configured key generator.
    pub fn generate_key(&self, inputs: &Inputs) -> K {
        self.key_gen.generate(inputs)
    }

    /// Returns a deferred generator of inputs for `key`.
    ///
    /// `key` and `inputs` are cloned immediately. The input generator itself only
    /// runs when the returned closure is invoked.
    pub fn inputs_generator(&self, key: &K, inputs: &Inputs) -> Box<dyn FnOnce() -> Inputs> {
        let generate = self.input_gen.clone();
        let key = key.clone();
        let inputs = inputs.clone();
        Box::new(move || generate.generate(&key, &inputs))
    }
}
/// A single candidate implementation that the autotuner can run and compare.
///
/// Implementors must be `Send + Sync + 'static` so they can be stored behind
/// `Arc<dyn TuneFn<..>>` and shared.
pub trait TuneFn: Send + Sync + 'static {
    /// Identifier of this implementation.
    fn name(&self) -> &str;

    /// Runs the implementation on `inputs`, consuming them.
    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;

    /// Value produced by a successful execution.
    type Output;

    /// Arguments consumed by [`TuneFn::execute`]; must be cloneable.
    type Inputs: Clone;
}
pub trait IntoTuneFn<In, Out, Marker> {
type Tunable: TuneFn<Inputs = In, Output = Out>;
fn into_tunable(self, name: String) -> Self::Tunable;
}
#[doc(hidden)]
pub struct IsIdentity;
impl<T: TuneFn> IntoTuneFn<T::Inputs, T::Output, IsIdentity> for T {
type Tunable = T;
fn into_tunable(self, _name: String) -> Self::Tunable {
self
}
}
/// Requirements on an autotune key when `std_io` is enabled.
///
/// On top of the base requirements (cloneable, comparable, hashable, printable,
/// thread-safe, `'static`), keys must be serde-serializable and deserializable.
/// Presumably this lets tuning results be persisted — confirm.
#[cfg(std_io)]
pub trait AutotuneKey:
    'static
    + Send
    + Sync
    + Clone
    + Eq
    + PartialEq
    + Hash
    + Debug
    + Display
    + serde::Serialize
    + serde::de::DeserializeOwned
{
}

/// Requirements on an autotune key when `std_io` is disabled: cloneable,
/// comparable, hashable, printable, thread-safe and `'static`.
#[cfg(not(std_io))]
pub trait AutotuneKey:
    'static + Send + Sync + Clone + Eq + PartialEq + Hash + Debug + Display
{
}

/// Plain strings can be used directly as autotune keys.
impl AutotuneKey for String {}