cubecl_runtime/tune/
operation.rs1use alloc::boxed::Box;
2use alloc::string::String;
3use alloc::vec::Vec;
4use core::fmt::{Debug, Display};
5use core::hash::Hash;
6
7use super::AutotuneError;
8
/// Builds a checksum string identifying an ordered list of autotune candidates.
///
/// The `name()` of every candidate is appended, in order and with no separator,
/// into a single string. That string is hashed with `md5::compute`, and the digest
/// is rendered with the `{:x}` (lowercase hex) formatter.
///
/// Because the names are joined without a separator, two lists whose names
/// concatenate to the same text produce the same checksum.
#[cfg(autotune_persistent_cache)]
pub fn compute_checksum<Out: Send + 'static>(
    autotunables: &[Box<dyn AutotuneOperation<Out>>],
) -> String {
    let joined_names: String = autotunables
        .iter()
        .map(|candidate| candidate.name())
        .collect();
    let digest = md5::compute(joined_names);
    format!("{:x}", digest)
}
20
21pub trait AutotuneOperationSet<K: Send + 'static, Output: Send + 'static = ()>: Send {
23 fn key(&self) -> K;
25
26 fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation<Output>>>;
29
30 fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation<Output>>;
34
35 #[cfg(autotune_persistent_cache)]
37 fn compute_checksum(&self) -> String {
38 compute_checksum(&self.autotunables())
39 }
40
41 #[allow(unused)]
43 fn should_run(&self, key: &K, index: usize) -> bool {
44 true
45 }
46}
47
48pub trait AutotuneOperation<Output: Send + 'static = ()>: Send + core::fmt::Debug {
50 fn execute(self: Box<Self>) -> Result<Output, AutotuneError>;
52
53 fn name(&self) -> &str {
55 core::any::type_name::<Self>()
56 }
57
58 fn clone(&self) -> Box<dyn AutotuneOperation<Output>>;
60}
61
/// Marker trait for types usable as autotune keys.
///
/// The trait has no items of its own; it only bundles the required bounds, so any
/// type meeting them opts in with an empty `impl`. With the
/// `autotune_persistent_cache` cfg enabled, keys must also be serde-serializable
/// (presumably so results can be persisted; confirm against the cache code).
#[cfg(autotune_persistent_cache)]
pub trait AutotuneKey
where
    Self: Clone + Debug + PartialEq + Eq + Hash + Display,
    Self: serde::Serialize + serde::de::DeserializeOwned,
    Self: Send + Sync + 'static,
{
}

/// Marker trait for types usable as autotune keys.
///
/// This is the variant used without the `autotune_persistent_cache` cfg. It
/// requires the same bounds minus the serde ones.
#[cfg(not(autotune_persistent_cache))]
pub trait AutotuneKey
where
    Self: Clone + Debug + PartialEq + Eq + Hash + Display,
    Self: Send + Sync + 'static,
{
}

// `String` meets the bounds of whichever `AutotuneKey` variant is compiled,
// so it can be used directly as a key.
impl AutotuneKey for String {}