// cubecl_runtime/tune/operation.rs
use alloc::boxed::Box;
use alloc::string::String;
use alloc::vec::Vec;
use core::fmt::{Debug, Display};
use core::hash::Hash;
/// Computes a checksum identifying a list of autotune candidates.
///
/// The name of every operation in `autotunables` is concatenated in order,
/// with no separator. The resulting string is hashed with MD5, and the digest
/// is returned formatted as lowercase hexadecimal.
///
/// Only compiled when the `autotune_persistent_cache` cfg is enabled.
#[cfg(autotune_persistent_cache)]
pub fn compute_checksum<Out>(autotunables: &[Box<dyn AutotuneOperation<Out>>]) -> String {
    // Order matters: the same operations in a different order yield a different checksum.
    let joined_names: String = autotunables.iter().map(|operation| operation.name()).collect();
    format!("{:x}", md5::compute(joined_names))
}
/// A group of interchangeable candidate operations that the autotuner can
/// benchmark against each other.
///
/// `K` is the key type returned by [`AutotuneOperationSet::key`]. `Output` is
/// the value produced when one of the candidates is executed.
pub trait AutotuneOperationSet<K: Send + Sync + 'static, Output = ()>: Send + Sync {
    /// Returns the key associated with this set.
    ///
    /// NOTE(review): presumably used to look up cached tuning results; confirm with callers.
    fn key(&self) -> K;

    /// Returns a freshly built list of every candidate operation in the set.
    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation<Output>>>;

    /// Consumes the set and returns the operation selected by `fastest_index`.
    ///
    /// NOTE(review): `fastest_index` presumably indexes into the list returned by
    /// [`AutotuneOperationSet::autotunables`]; confirm against implementors.
    fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation<Output>>;

    /// Checksum of this set's candidates, computed by the free
    /// `compute_checksum` function over [`AutotuneOperationSet::autotunables`].
    ///
    /// Only available when the `autotune_persistent_cache` cfg is enabled.
    #[cfg(autotune_persistent_cache)]
    fn compute_checksum(&self) -> String {
        let candidates = self.autotunables();
        compute_checksum(&candidates)
    }

    /// Returns whether the candidate at `index` should be run for `key`.
    ///
    /// The default implementation accepts every candidate for every key.
    /// Implementors may override it to skip candidates.
    fn should_run(&self, _key: &K, _index: usize) -> bool {
        true
    }
}
/// A single candidate implementation that the autotuner can run.
///
/// `Output` is the value produced by [`AutotuneOperation::execute`]; it
/// defaults to `()`.
pub trait AutotuneOperation<Output = ()>: Debug {
    /// Runs the operation, consuming the boxed value, and returns its output.
    fn execute(self: Box<Self>) -> Output;

    /// Identifier for this operation.
    ///
    /// By default this is the implementing type's name as reported by
    /// [`core::any::type_name`]. The standard library does not guarantee that
    /// string is stable across compiler versions. Override this method if a
    /// stable name is needed, for example when names are hashed into a
    /// persistent-cache checksum.
    fn name(&self) -> &str {
        let implementor_name: &'static str = core::any::type_name::<Self>();
        implementor_name
    }

    /// Returns a boxed copy of this operation.
    ///
    /// This method shares its name with [`Clone::clone`]. An implementor that
    /// also implements `Clone` must disambiguate calls, e.g.
    /// `AutotuneOperation::clone(&op)`.
    fn clone(&self) -> Box<dyn AutotuneOperation<Output>>;
}
/// Bounds required of a key that identifies an autotuning problem.
///
/// This is a marker trait with no methods; it only bundles the required
/// bounds. With the `autotune_persistent_cache` cfg enabled, keys must also be
/// serde-serializable and deserializable.
#[cfg(autotune_persistent_cache)]
pub trait AutotuneKey
where
    Self: Clone
        + Debug
        + PartialEq
        + Eq
        + Hash
        + Display
        + serde::Serialize
        + serde::de::DeserializeOwned
        + Send
        + Sync
        + 'static,
{
}

/// Bounds required of a key that identifies an autotuning problem.
///
/// This variant is used when the `autotune_persistent_cache` cfg is disabled.
/// It has the same bounds as the persistent variant, without the serde requirements.
#[cfg(not(autotune_persistent_cache))]
pub trait AutotuneKey
where
    Self: Clone + Debug + PartialEq + Eq + Hash + Display + Send + Sync + 'static,
{
}

// Plain strings can be used directly as autotune keys.
impl AutotuneKey for String {}