cubecl_runtime/tune/
operation.rs1use alloc::string::String;
2use alloc::sync::Arc;
3use alloc::vec::Vec;
4use core::fmt::{Debug, Display};
5use core::hash::Hash;
6
7use super::{
8 AutotuneError,
9 input_generator::{InputGenerator, IntoInputGenerator},
10 key_generator::{IntoKeyGenerator, KeyGenerator},
11};
12
/// Computes the default checksum for a list of tunables.
///
/// The names reported by [`Tunable::name`] are concatenated in slice order, without any
/// separator, and the MD5 digest of the resulting string is returned formatted with `{:x}`
/// (lowercase hexadecimal).
#[cfg(autotune_persistent_cache)]
pub fn compute_checksum<In: Clone + Send + 'static, Out: 'static>(
    autotunables: &[Arc<dyn Tunable<Inputs = In, Output = Out>>],
) -> String {
    // Collecting `&str` items into a `String` appends them back to back.
    let joined_names: String = autotunables.iter().map(|op| op.name()).collect();
    let digest = md5::compute(joined_names);
    format!("{:x}", digest)
}
24
25#[derive(Clone)]
27pub struct TunableSet<K: AutotuneKey, Inputs: Send + 'static, Output: 'static> {
28 tunables: Vec<Arc<dyn Tunable<Inputs = Inputs, Output = Output>>>,
29 key_gen: Arc<dyn KeyGenerator<K, Inputs>>,
30 input_gen: Arc<dyn InputGenerator<K, Inputs>>,
31 #[allow(clippy::type_complexity)]
32 checksum_override: Option<Arc<dyn Fn(&Self) -> String + Send + Sync>>,
33}
34
35impl<K: AutotuneKey, Inputs: Clone + Send + 'static, Output: 'static>
36 TunableSet<K, Inputs, Output>
37{
38 pub fn len(&self) -> usize {
40 self.tunables.len()
41 }
42 pub fn is_empty(&self) -> bool {
44 self.tunables.len() == 0
45 }
46 pub fn new<KMarker, IMarker>(
48 key_gen: impl IntoKeyGenerator<K, Inputs, KMarker>,
49 input_gen: impl IntoInputGenerator<K, Inputs, IMarker>,
50 ) -> Self {
51 Self {
52 tunables: Default::default(),
53 input_gen: Arc::new(input_gen.into_input_gen()),
54 key_gen: Arc::new(key_gen.into_key_gen()),
55 checksum_override: None,
56 }
57 }
58
59 pub fn with_tunable<Marker>(
61 mut self,
62 tunable: impl IntoTunable<Inputs, Output, Marker>,
63 ) -> Self {
64 self.tunables.push(Arc::new(tunable.into_tunable()));
65 self
66 }
67
68 pub fn with_custom_checksum(
70 mut self,
71 checksum: impl Fn(&Self) -> String + Send + Sync + 'static,
72 ) -> Self {
73 self.checksum_override = Some(Arc::new(checksum));
74 self
75 }
76
77 pub fn autotunables(&self) -> Vec<Arc<dyn Tunable<Inputs = Inputs, Output = Output>>> {
80 self.tunables.clone()
81 }
82
83 pub fn fastest(
87 &self,
88 fastest_index: usize,
89 ) -> Arc<dyn Tunable<Inputs = Inputs, Output = Output>> {
90 self.tunables[fastest_index].clone()
91 }
92
93 #[cfg(autotune_persistent_cache)]
95 pub fn compute_checksum(&self) -> String {
96 if let Some(checksum_override) = &self.checksum_override {
97 checksum_override(self)
98 } else {
99 compute_checksum(&self.tunables)
100 }
101 }
102
103 pub fn generate_key(&self, inputs: &Inputs) -> K {
105 self.key_gen.generate(inputs)
106 }
107
108 pub fn generate_inputs(&self, key: &K, inputs: &Inputs) -> Inputs {
110 self.input_gen.generate(key, inputs)
111 }
112}
113
114pub trait Tunable: Send + Sync + 'static {
116 type Inputs: Clone;
118 type Output;
120
121 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;
123
124 fn name(&self) -> &str {
126 core::any::type_name::<Self>()
127 }
128}
129
130pub trait IntoTunable<In, Out, Marker> {
136 type Tunable: Tunable<Inputs = In, Output = Out>;
138
139 fn into_tunable(self) -> Self::Tunable;
141}
142
/// Marker type used as the `Marker` parameter of the identity [`IntoTunable`]
/// implementation, where a type that already is a tunable converts into itself.
#[doc(hidden)]
pub struct IsIdentity;
146
147impl<T: Tunable> IntoTunable<T::Inputs, T::Output, IsIdentity> for T {
148 type Tunable = T;
149
150 fn into_tunable(self) -> Self::Tunable {
151 self
152 }
153}
154
/// Bounds required of a type used as an autotune key.
///
/// This variant is compiled when the `autotune_persistent_cache` cfg is set. In addition to
/// the comparison, hashing, formatting and thread-safety bounds, it requires the key to be
/// serializable and deserializable with serde.
#[cfg(autotune_persistent_cache)]
pub trait AutotuneKey:
    Clone
    + Debug
    + PartialEq
    + Eq
    + Hash
    + Display
    + serde::Serialize
    + serde::de::DeserializeOwned
    + Send
    + Sync
    + 'static
{
}
/// Bounds required of a type used as an autotune key.
///
/// This variant is compiled when the `autotune_persistent_cache` cfg is not set, so no
/// serialization bounds are required.
#[cfg(not(autotune_persistent_cache))]
pub trait AutotuneKey:
    Clone + Debug + PartialEq + Eq + Hash + Display + Send + Sync + 'static
{
}
177
178impl AutotuneKey for String {}