cubecl_runtime/tune/
operation.rs1use alloc::boxed::Box;
2use alloc::string::String;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::fmt::{Debug, Display};
6use core::hash::Hash;
7
8use super::{
9 AutotuneError,
10 input_generator::{InputGenerator, IntoInputGenerator},
11 key_generator::{IntoKeyGenerator, KeyGenerator},
12};
13use super::{Tunable, TunePlan};
14
/// Hashes the names of a sequence of tunable functions into a hex MD5 digest.
///
/// The names reported by [`TuneFn::name`] are concatenated in iteration order
/// (with no separator) and the resulting string is hashed; the digest is
/// returned as lowercase hexadecimal.
#[cfg(std_io)]
pub fn compute_checksum<In: Clone + Send + 'static, Out: 'static>(
    autotunables: impl Iterator<Item = Arc<dyn TuneFn<Inputs = In, Output = Out>>>,
) -> String {
    let joined_names = autotunables.fold(String::new(), |mut acc, tunable| {
        acc.push_str(tunable.name());
        acc
    });
    format!("{:x}", md5::compute(joined_names))
}
26
27pub struct TunableSet<K: AutotuneKey, Inputs: Send + 'static, Output: 'static> {
29 tunables: Vec<Tunable<K, Inputs, Output>>,
30 key_gen: Arc<dyn KeyGenerator<K, Inputs>>,
31 input_gen: Arc<dyn InputGenerator<K, Inputs>>,
32 #[allow(clippy::type_complexity)]
33 checksum_override: Option<Arc<dyn Fn(&Self) -> String + Send + Sync>>,
34}
35
36unsafe impl<K: AutotuneKey, In: Send, Out> Send for TunableSet<K, In, Out> {}
37unsafe impl<K: AutotuneKey, In: Send, Out> Sync for TunableSet<K, In, Out> {}
38
39impl<K: AutotuneKey, Inputs: Clone + Send + 'static, Output: 'static>
40 TunableSet<K, Inputs, Output>
41{
42 pub fn len(&self) -> usize {
44 self.tunables.len()
45 }
46 pub fn is_empty(&self) -> bool {
48 self.tunables.len() == 0
49 }
50 pub fn new<KMarker, IMarker>(
52 key_gen: impl IntoKeyGenerator<K, Inputs, KMarker>,
53 input_gen: impl IntoInputGenerator<K, Inputs, IMarker>,
54 ) -> Self {
55 Self {
56 tunables: Default::default(),
57 input_gen: Arc::new(input_gen.into_input_gen()),
58 key_gen: Arc::new(key_gen.into_key_gen()),
59 checksum_override: None,
60 }
61 }
62
63 pub fn with(mut self, tunable: Tunable<K, Inputs, Output>) -> Self {
65 self.tunables.push(tunable);
66 self
67 }
68
69 pub fn with_custom_checksum(
71 mut self,
72 checksum: impl Fn(&Self) -> String + Send + Sync + 'static,
73 ) -> Self {
74 self.checksum_override = Some(Arc::new(checksum));
75 self
76 }
77
78 pub fn autotunables(&self) -> Vec<Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>> {
81 self.tunables
82 .iter()
83 .map(|tunable| tunable.function.clone())
84 .collect()
85 }
86
87 pub(crate) fn plan(&self, key: &K) -> TunePlan {
89 TunePlan::new(key, &self.tunables)
90 }
91
92 pub fn fastest(
96 &self,
97 fastest_index: usize,
98 ) -> Arc<dyn TuneFn<Inputs = Inputs, Output = Output>> {
99 self.tunables[fastest_index].function.clone()
100 }
101
102 #[cfg(std_io)]
104 pub fn compute_checksum(&self) -> String {
105 if let Some(checksum_override) = &self.checksum_override {
106 checksum_override(self)
107 } else {
108 compute_checksum(self.tunables.iter().map(|tune| tune.function.clone()))
109 }
110 }
111
112 pub fn generate_key(&self, inputs: &Inputs) -> K {
114 self.key_gen.generate(inputs)
115 }
116
117 pub fn inputs_generator(&self, key: &K, inputs: &Inputs) -> Box<dyn FnOnce() -> Inputs> {
119 let generate = self.input_gen.clone();
120 let key = key.clone();
121 let inputs = inputs.clone();
122
123 Box::new(move || generate.generate(&key, &inputs))
124 }
125}
126
127pub trait TuneFn: Send + Sync + 'static {
129 type Inputs: Clone;
131 type Output;
133
134 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;
136
137 fn name(&self) -> &str {
139 core::any::type_name::<Self>()
140 }
141}
142
143pub trait IntoTuneFn<In, Out, Marker> {
149 type Tunable: TuneFn<Inputs = In, Output = Out>;
151
152 fn into_tunable(self) -> Self::Tunable;
154}
155
156#[doc(hidden)]
158pub struct IsIdentity;
159
160impl<T: TuneFn> IntoTuneFn<T::Inputs, T::Output, IsIdentity> for T {
161 type Tunable = T;
162
163 fn into_tunable(self) -> Self::Tunable {
164 self
165 }
166}
167
/// Bounds required of a key used to index autotune results.
///
/// With the `std_io` cfg enabled, keys must also be serde-serializable and
/// deserializable.
#[cfg(std_io)]
pub trait AutotuneKey
where
    Self: Clone
        + Debug
        + PartialEq
        + Eq
        + Hash
        + Display
        + serde::Serialize
        + serde::de::DeserializeOwned
        + Send
        + Sync
        + 'static,
{
}

/// Bounds required of a key used to index autotune results.
///
/// This variant applies when the `std_io` cfg is disabled; no serde bounds
/// are required.
#[cfg(not(std_io))]
pub trait AutotuneKey
where
    Self: Clone + Debug + PartialEq + Eq + Hash + Display + Send + Sync + 'static,
{
}

impl AutotuneKey for String {}