cubecl_runtime/tune/
local.rs1use super::{AutotuneKey, AutotuneOutput, TunableSet, Tuner};
2use crate::{client::ComputeClient, runtime::Runtime, tune::TuneCacheResult};
3use alloc::boxed::Box;
4use alloc::sync::Arc;
5use core::{
6 any::{Any, TypeId},
7 fmt::Display,
8 hash::Hash,
9};
10use cubecl_common::map::SharedStateMap;
11use hashbrown::HashMap;
12
13#[cfg(not(feature = "std"))]
14use alloc::string::ToString;
15
16pub struct LocalTuner<AK: AutotuneKey, ID> {
19 state: SharedStateMap<ID, Tuner<AK>>,
20 name: &'static str,
21 sets: spin::RwLock<Option<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
22}
23
24unsafe impl<AK: AutotuneKey, ID> Sync for LocalTuner<AK, ID> {}
25
/// Builds a `LocalTuner` named after the module that invokes the macro.
///
/// * `local_tuner!()` uses `module_path!()` as the tuner name.
/// * `local_tuner!("suffix")` uses `<module path>-suffix`. The argument is fed
///   to `concat!`, so it has to be a literal.
///
/// `LocalTuner` is referenced unqualified and must therefore be in scope at the
/// call site.
///
/// Both arms expand to a plain expression (no trailing `;`) so the macro can be
/// used in expression position, e.g. as a `static` initializer, without
/// triggering the `semicolon_in_expressions_from_macros` lint.
#[macro_export]
macro_rules! local_tuner {
    ($name:expr) => {
        LocalTuner::new(concat!(module_path!(), "-", $name))
    };
    () => {
        LocalTuner::new(module_path!())
    };
}
36
37pub use local_tuner;
38
39impl<AK, ID> LocalTuner<AK, ID>
40where
41 AK: AutotuneKey + 'static,
42 ID: Hash + PartialEq + Eq + Clone + Display,
43{
44 pub const fn new(name: &'static str) -> Self {
46 Self {
47 state: SharedStateMap::new(),
48 name,
49 sets: spin::RwLock::new(None),
50 }
51 }
52
53 pub fn init<In, Out, F>(&self, init_set: F) -> Arc<TunableSet<AK, In, Out>>
55 where
56 F: Fn() -> TunableSet<AK, In, Out> + 'static + Send + Sync,
57 In: Clone + Send + 'static,
58 Out: AutotuneOutput,
59 {
60 let sets = self.sets.read();
61 let type_id = TypeId::of::<F>();
62
63 static DOWNCAST_ERROR: &str = "Local tuner only support one set of tunable that must work on the same input and output declared with the init function.";
64
65 if let Some(sets) = sets.as_ref()
66 && let Some(set) = sets.get(&type_id)
67 {
68 return set.clone().downcast().expect(DOWNCAST_ERROR);
69 };
70
71 core::mem::drop(sets);
72
73 let mut sets = self.sets.write();
74
75 if let Some(sets) = sets.as_ref()
76 && let Some(set) = sets.get(&type_id)
77 {
78 return set.clone().downcast().expect(DOWNCAST_ERROR);
79 };
80
81 let content = Arc::new(init_set());
82
83 if let Some(sets) = sets.as_mut() {
84 sets.insert(type_id, content.clone());
85 } else {
86 let mut map = HashMap::<TypeId, Arc<dyn Any + Send + Sync>>::new();
87 map.insert(type_id, content.clone());
88 *sets = Some(map);
89 };
90
91 content
92 }
93
94 pub fn clear(&self) {
96 self.state.clear()
97 }
98
/// Executes every operation of `operations` on a clone of `inputs`, in index
/// order, and hands all the results to `check_autotune_outputs`.
///
/// Only compiled with the `autotune-checks` feature.
#[cfg(feature = "autotune-checks")]
fn checks<In: Send + Clone + 'static, Out: AutotuneOutput>(
    &self,
    operations: &TunableSet<AK, In, Out>,
    inputs: &In,
) {
    let outputs: Vec<_> = (0..operations.len())
        .map(|index| operations.fastest(index).execute(inputs.clone()))
        .collect();

    super::check_autotune_outputs(outputs);
}
113
114 pub fn execute<R: Runtime, In, Out>(
116 &self,
117 id: &ID,
118 client: &ComputeClient<R>,
119 operations: Arc<TunableSet<AK, In, Out>>,
120 inputs: In,
121 ) -> Out
122 where
123 In: Clone + Send + 'static,
124 Out: AutotuneOutput,
125 {
126 let key = operations.generate_key(&inputs);
127
128 let autotune_job = {
130 let tuner_state = self.state.get_or_init(id, move |id| {
131 let name = self.name.replace("::", "-");
132 Tuner::new(&name, &id.to_string())
133 });
134 let tuner = tuner_state.read();
135
136 let mut tuner = match tuner.fastest(&key) {
137 TuneCacheResult::Hit { fastest_index } => {
138 core::mem::drop(tuner);
139 core::mem::drop(tuner_state);
140
141 #[cfg(feature = "autotune-checks")]
142 self.checks(&operations, &inputs);
143
144 let op = operations.fastest(fastest_index);
145 let result = op
146 .execute(inputs)
147 .expect("Should run when selected by autotune.");
148 return result;
149 }
150 TuneCacheResult::Pending => {
151 core::mem::drop(tuner);
152 core::mem::drop(tuner_state);
153
154 let op = operations.fastest(0);
155 let result = op
156 .execute(inputs)
157 .expect("Should run when selected by autotune.");
158 return result;
159 }
160 #[cfg(std_io)]
161 TuneCacheResult::Unchecked => {
162 core::mem::drop(tuner);
163 let mut tuner = tuner_state.write();
164
165 let checksum = operations.compute_checksum();
167 tuner.validate_checksum(&key, &checksum);
168
169 if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) {
171 core::mem::drop(tuner);
172 core::mem::drop(tuner_state);
173
174 let op = operations.fastest(fastest_index);
175 let result = op
176 .execute(inputs)
177 .expect("Should run when selected by autotune.");
178 return result;
179 }
180
181 tuner
182 }
183
184 #[cfg(not(std_io))]
185 TuneCacheResult::Unchecked => {
186 core::mem::drop(tuner);
187 tuner_state.write()
188 }
189 TuneCacheResult::Miss => {
190 core::mem::drop(tuner);
191 tuner_state.write()
192 }
193 };
194
195 if tuner.autotuning.contains(&key) {
196 Box::new(move || {})
197 } else {
198 tuner.autotuning.insert(key.clone());
199 tuner.prepare_autotune(key.clone(), &inputs, &operations, client)
200 }
201 };
202
203 autotune_job();
204
205 let index_to_run = {
206 let tuner_state = self.state.get(id).unwrap();
207 let mut tuner = tuner_state.write();
208
209 tuner.handle_results();
210
211 match tuner.fastest(&key) {
212 TuneCacheResult::Hit { fastest_index } => {
213 fastest_index
215 }
216 TuneCacheResult::Pending => {
217 0
219 }
220 TuneCacheResult::Miss => {
221 0
226 }
227 TuneCacheResult::Unchecked => {
228 panic!("Should have checked the cache.")
229 }
230 }
231 };
232
233 operations
234 .fastest(index_to_run)
235 .execute(inputs)
236 .expect("Should run when selected by autotune.")
237 }
238}