cubecl_runtime/tune/
operation.rs1use alloc::boxed::Box;
2use alloc::string::String;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::fmt::{Debug, Display};
6use core::hash::Hash;
7
8#[cfg(std_io)]
9use alloc::format;
10
11use super::{
12 AutotuneError,
13 input_generator::{InputGenerator, IntoInputGenerator},
14 key_generator::{IntoKeyGenerator, KeyGenerator},
15};
16use super::{Tunable, TunePlan};
17
/// Builds a checksum string for a sequence of tunable functions.
///
/// The names reported by [`TuneFn::name`] are concatenated in iteration order,
/// with no separator between them, and the MD5 digest of that concatenation is
/// returned formatted with `{:x}` (lowercase hexadecimal).
#[cfg(std_io)]
pub fn compute_checksum<In: Clone + Send + 'static, Out: 'static>(
    autotunables: impl Iterator<Item = Arc<dyn TuneFn<Inputs = In, Output = Out>>>,
) -> String {
    // Accumulate every name into a single buffer before hashing it once.
    let mut joined_names = String::new();
    for tunable in autotunables {
        joined_names.push_str(tunable.name());
    }

    let digest = md5::compute(joined_names);
    format!("{:x}", digest)
}
29
30pub struct TunableSet<K: AutotuneKey, Inputs: Send + 'static, Output: 'static> {
32 tunables: Vec<Tunable<K, Inputs, Output>>,
33 key_gen: Arc<dyn KeyGenerator<K, Inputs>>,
34 input_gen: Arc<dyn InputGenerator<K, Inputs>>,
35 #[allow(clippy::type_complexity)]
36 checksum_override: Option<Arc<dyn Fn(&Self) -> String + Send + Sync>>,
37}
38
39unsafe impl<K: AutotuneKey, In: Send, Out> Send for TunableSet<K, In, Out> {}
40unsafe impl<K: AutotuneKey, In: Send, Out> Sync for TunableSet<K, In, Out> {}
41
42impl<K: AutotuneKey, Inputs: Clone + Send + 'static, Output: 'static>
43 TunableSet<K, Inputs, Output>
44{
45 pub fn len(&self) -> usize {
47 self.tunables.len()
48 }
49 pub fn is_empty(&self) -> bool {
51 self.tunables.len() == 0
52 }
53 pub fn new<KMarker, IMarker>(
55 key_gen: impl IntoKeyGenerator<K, Inputs, KMarker>,
56 input_gen: impl IntoInputGenerator<K, Inputs, IMarker>,
57 ) -> Self {
58 Self {
59 tunables: Default::default(),
60 input_gen: Arc::new(input_gen.into_input_gen()),
61 key_gen: Arc::new(key_gen.into_key_gen()),
62 checksum_override: None,
63 }
64 }
65
66 pub fn with(mut self, tunable: Tunable<K, Inputs, Output>) -> Self {
68 self.tunables.push(tunable);
69 self
70 }
71
72 pub fn with_custom_checksum(
74 mut self,
75 checksum: impl Fn(&Self) -> String + Send + Sync + 'static,
76 ) -> Self {
77 self.checksum_override = Some(Arc::new(checksum));
78 self
79 }
80
81 pub fn autotunables(&self) -> Vec<Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>> {
86 self.tunables
87 .iter()
88 .map(|tunable| tunable.function.clone())
89 .collect()
90 }
91
92 pub(crate) fn plan(&self, key: &K) -> TunePlan {
94 TunePlan::new(key, &self.tunables)
95 }
96
97 pub fn fastest(
101 &self,
102 fastest_index: usize,
103 ) -> Arc<dyn TuneFn<Inputs = Inputs, Output = Output>> {
104 self.tunables[fastest_index].function.clone()
105 }
106
107 #[cfg(std_io)]
109 pub fn compute_checksum(&self) -> String {
110 if let Some(checksum_override) = &self.checksum_override {
111 checksum_override(self)
112 } else {
113 compute_checksum(self.tunables.iter().map(|tune| tune.function.clone()))
114 }
115 }
116
117 pub fn generate_key(&self, inputs: &Inputs) -> K {
119 self.key_gen.generate(inputs)
120 }
121
122 pub fn inputs_generator(&self, key: &K, inputs: &Inputs) -> Box<dyn FnOnce() -> Inputs> {
124 let generate = self.input_gen.clone();
125 let key = key.clone();
126 let inputs = inputs.clone();
127
128 Box::new(move || generate.generate(&key, &inputs))
129 }
130}
131
132pub trait TuneFn: Send + Sync + 'static {
134 type Inputs: Clone;
136 type Output;
138
139 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;
141
142 fn name(&self) -> &str;
144}
145
146pub trait IntoTuneFn<In, Out, Marker> {
152 type Tunable: TuneFn<Inputs = In, Output = Out>;
154
155 fn into_tunable(self, name: String) -> Self::Tunable;
157}
158
/// Marker type used as the `Marker` parameter of the `IntoTuneFn`
/// implementation that returns the value unchanged.
#[doc(hidden)]
pub struct IsIdentity;
162
163impl<T: TuneFn> IntoTuneFn<T::Inputs, T::Output, IsIdentity> for T {
164 type Tunable = T;
165
166 fn into_tunable(self, _name: String) -> Self::Tunable {
167 self
168 }
169}
170
/// Marker trait for types usable as keys by the tuning code.
///
/// This variant is compiled when the `std_io` cfg is set. On top of the
/// bounds required without `std_io`, the key must also be serializable and
/// deserializable through `serde`.
#[cfg(std_io)]
pub trait AutotuneKey:
    Clone
    + Debug
    + PartialEq
    + Eq
    + Hash
    + Display
    + serde::Serialize
    + serde::de::DeserializeOwned
    + Send
    + Sync
    + 'static
{
}

/// Marker trait for types usable as keys by the tuning code.
///
/// This variant is compiled when the `std_io` cfg is not set; it carries no
/// `serde` requirement.
#[cfg(not(std_io))]
pub trait AutotuneKey:
    Clone
    + Debug
    + PartialEq
    + Eq
    + Hash
    + Display
    + Send
    + Sync
    + 'static
{
}

/// Plain strings can be used directly as keys.
impl AutotuneKey for String {}