cubecl_runtime/tune/
local.rs1use super::{AutotuneKey, AutotuneOutput, TunableSet, Tuner};
2use crate::{client::ComputeClient, runtime::Runtime, tune::TuneCacheResult};
3use alloc::string::ToString;
4use alloc::sync::Arc;
5use core::{
6 any::{Any, TypeId},
7 fmt::Display,
8 hash::Hash,
9};
10use cubecl_common::map::SharedStateMap;
11use hashbrown::HashMap;
12
/// Autotuner that keeps one [`Tuner`] per `ID` and caches the tunable sets registered through
/// [`LocalTuner::init`].
pub struct LocalTuner<AK: AutotuneKey, ID> {
    // Per-ID tuner state; entries are created on demand by `execute` via `get_or_init`.
    state: SharedStateMap<ID, Tuner<AK>>,
    // Base name passed (with `::` replaced by `-`) to `Tuner::new` for every per-ID tuner.
    name: &'static str,
    // Tunable sets created by `init`, keyed by the `TypeId` of the init closure and stored
    // type-erased so sets with different input/output types can share one map.
    // `None` until the first set is registered.
    sets: spin::RwLock<Option<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
}
20
// SAFETY: NOTE(review): no invariant is documented for this impl. `sets` sits behind a
// `spin::RwLock` and `state` is a `SharedStateMap`; this impl asserts the whole struct may be
// shared across threads for *any* `ID` (there is no `ID: Send + Sync` bound), including `ID`
// types that are not thread-safe themselves. Confirm that `SharedStateMap` synchronizes every
// access to the stored `ID`s and `Tuner<AK>` values, or add the missing bounds.
unsafe impl<AK: AutotuneKey, ID> Sync for LocalTuner<AK, ID> {}
22
/// Creates a [`LocalTuner`] named after the calling module.
///
/// `local_tuner!("name")` expands to `LocalTuner::new(concat!(module_path!(), "-", "name"))`
/// and `local_tuner!()` expands to `LocalTuner::new(module_path!())`. Since `LocalTuner::new`
/// is a `const fn`, the expansion can initialize a `static`.
///
/// The path is not `$crate`-qualified, so `LocalTuner` must be in scope at the call site.
// The expansions are bare expressions (no trailing `;`): the macro is meant to be used in
// expression position, where a trailing semicolon triggers the
// `semicolon_in_expressions_from_macros` future-incompatibility lint.
#[macro_export]
macro_rules! local_tuner {
    ($name:expr) => {
        LocalTuner::new(concat!(module_path!(), "-", $name))
    };
    () => {
        LocalTuner::new(module_path!())
    };
}
33
34pub use local_tuner;
35
36impl<AK, ID> LocalTuner<AK, ID>
37where
38 AK: AutotuneKey + 'static,
39 ID: Hash + PartialEq + Eq + Clone + Display,
40{
41 pub const fn new(name: &'static str) -> Self {
43 Self {
44 state: SharedStateMap::new(),
45 name,
46 sets: spin::RwLock::new(None),
47 }
48 }
49
50 pub fn init<In, Out, F>(&self, init_set: F) -> Arc<TunableSet<AK, In, Out>>
52 where
53 F: Fn() -> TunableSet<AK, In, Out> + 'static + Send + Sync,
54 In: Clone + Send + 'static,
55 Out: AutotuneOutput,
56 {
57 let sets = self.sets.read();
58 let type_id = TypeId::of::<F>();
59
60 static DOWNCAST_ERROR: &str = "Local tuner only support one set of tunable that must work on the same input and output declared with the init function.";
61
62 if let Some(sets) = sets.as_ref()
63 && let Some(set) = sets.get(&type_id)
64 {
65 return set.clone().downcast().expect(DOWNCAST_ERROR);
66 };
67
68 core::mem::drop(sets);
69
70 let mut sets = self.sets.write();
71
72 if let Some(sets) = sets.as_ref()
73 && let Some(set) = sets.get(&type_id)
74 {
75 return set.clone().downcast().expect(DOWNCAST_ERROR);
76 };
77
78 let content = Arc::new(init_set());
79
80 if let Some(sets) = sets.as_mut() {
81 sets.insert(type_id, content.clone());
82 } else {
83 let mut map = HashMap::<TypeId, Arc<dyn Any + Send + Sync>>::new();
84 map.insert(type_id, content.clone());
85 *sets = Some(map);
86 };
87
88 content
89 }
90
91 pub fn clear(&self) {
93 self.state.clear()
94 }
95
96 #[cfg(feature = "autotune-checks")]
97 fn checks<In: Send + Clone + 'static, Out: AutotuneOutput>(
98 &self,
99 operations: &TunableSet<AK, In, Out>,
100 inputs: &In,
101 ) {
102 use alloc::vec::Vec;
103
104 let mut checks_outputs = Vec::new();
105 for i in 0..operations.len() {
106 let op = operations.fastest(i);
107 let result = op.execute(inputs.clone());
108 checks_outputs.push(result);
109 }
110 super::check_autotune_outputs(checks_outputs);
111 }
112
113 fn try_all_operations<In, Out>(operations: &TunableSet<AK, In, Out>, inputs: In) -> Out
118 where
119 In: Clone + Send + 'static,
120 Out: AutotuneOutput,
121 {
122 for i in 0..operations.len() {
123 if let Ok(output) = operations.fastest(i).execute(inputs.clone()) {
124 return output;
125 }
126 }
127 panic!("All autotune operations failed, no viable operation found.");
128 }
129
130 pub fn execute<R: Runtime, In, Out>(
132 &self,
133 id: &ID,
134 client: &ComputeClient<R>,
135 operations: Arc<TunableSet<AK, In, Out>>,
136 inputs: In,
137 ) -> Out
138 where
139 In: Clone + Send + 'static,
140 Out: AutotuneOutput,
141 {
142 let key = operations.generate_key(&inputs);
143
144 let tuner_state = self.state.get_or_init(id, move |id| {
146 let name = self.name.replace("::", "-");
147 Tuner::new(&name, &id.to_string())
148 });
149 let tuner = tuner_state.read();
150
151 let mut tuner = match tuner.fastest(&key) {
152 TuneCacheResult::Hit { fastest_index } => {
153 core::mem::drop(tuner);
154 core::mem::drop(tuner_state);
155
156 #[cfg(feature = "autotune-checks")]
157 self.checks(&operations, &inputs);
158
159 let op = operations.fastest(fastest_index);
160 let result = op
161 .execute(inputs)
162 .expect("Should run when selected by autotune.");
163 return result;
164 }
165 TuneCacheResult::Pending => {
166 core::mem::drop(tuner);
167 core::mem::drop(tuner_state);
168
169 #[cfg(feature = "autotune-checks")]
170 self.checks(&operations, &inputs);
171
172 return Self::try_all_operations(&operations, inputs);
173 }
174 #[cfg(std_io)]
175 TuneCacheResult::Unchecked => {
176 core::mem::drop(tuner);
177 let mut tuner = tuner_state.write();
178
179 let checksum = operations.compute_checksum();
181 tuner.validate_checksum(&key, &checksum);
182
183 if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) {
185 core::mem::drop(tuner);
186 core::mem::drop(tuner_state);
187
188 let op = operations.fastest(fastest_index);
189 let result = op
190 .execute(inputs)
191 .expect("Should run when selected by autotune.");
192 return result;
193 }
194
195 tuner
196 }
197
198 #[cfg(not(std_io))]
199 TuneCacheResult::Unchecked => {
200 core::mem::drop(tuner);
201 tuner_state.write()
202 }
203 TuneCacheResult::Miss => {
204 core::mem::drop(tuner);
205 tuner_state.write()
206 }
207 };
208
209 let job = if !tuner.autotuning.contains(&key) {
210 tuner.autotuning.insert(key.clone());
211 Some(tuner.prepare_autotune(key.clone(), &inputs, &operations, client))
212 } else {
213 None
214 };
215
216 core::mem::drop(tuner);
219 core::mem::drop(tuner_state);
220
221 if let Some(job) = job {
222 job();
223 }
224
225 let index_to_run = {
226 let tuner_state = self.state.get(id).unwrap();
227 let mut tuner = tuner_state.write();
228
229 tuner.handle_results();
230
231 match tuner.fastest(&key) {
232 TuneCacheResult::Hit { fastest_index } => {
233 fastest_index
235 }
236 TuneCacheResult::Pending | TuneCacheResult::Miss => {
237 return Self::try_all_operations(&operations, inputs);
241 }
242 TuneCacheResult::Unchecked => {
243 panic!("Should have checked the cache.")
244 }
245 }
246 };
247
248 operations
249 .fastest(index_to_run)
250 .execute(inputs)
251 .expect("Should run when selected by autotune.")
252 }
253}