cubecl_runtime/tune/
local.rs1use super::{AutotuneKey, AutotuneOutput, TunableSet, Tuner};
2use crate::{client::ComputeClient, server::ComputeServer, tune::TuneCacheResult};
3use alloc::boxed::Box;
4use alloc::sync::Arc;
5use core::{
6 any::{Any, TypeId},
7 fmt::Display,
8 hash::Hash,
9};
10use cubecl_common::map::SharedStateMap;
11use hashbrown::HashMap;
12
13#[cfg(not(feature = "std"))]
14use alloc::string::ToString;
15
/// A named tuner that keeps one [`Tuner`] per `ID` together with a cache of the tunable sets
/// registered through `init`.
pub struct LocalTuner<AK: AutotuneKey, ID> {
    /// Per-`ID` tuner state; entries are created on demand by `execute`.
    state: SharedStateMap<ID, Tuner<AK>>,
    /// Name given at construction; `execute` replaces `::` with `-` before handing it to
    /// `Tuner::new`.
    name: &'static str,
    /// Type-erased tunable sets keyed by the `TypeId` of the closure that built them.
    /// Stays `None` until the first set is registered.
    sets: spin::RwLock<Option<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
}
23
// SAFETY: NOTE(review) — this asserts that a `LocalTuner` may be shared between threads for any
// `AK` and `ID`, with no `Send`/`Sync` bound on either. Soundness therefore rests on
// `SharedStateMap` synchronizing every access to its contents, which cannot be checked from
// this file — TODO confirm against `SharedStateMap`'s implementation.
unsafe impl<AK: AutotuneKey, ID> Sync for LocalTuner<AK, ID> {}
25
/// Creates a `LocalTuner` named after the module that invokes the macro.
///
/// * `local_tuner!()` names the tuner `module_path!()`.
/// * `local_tuner!("suffix")` names it `module_path!()`, followed by `-` and the given literal.
///
/// The expansion is a single expression (no trailing `;`), so it can be used directly as the
/// initializer of a `static` or `let` binding.
///
/// The expansion refers to `LocalTuner` unqualified, so that type must be in scope at the call
/// site.
#[macro_export]
macro_rules! local_tuner {
    ($name:expr) => {
        LocalTuner::new(concat!(module_path!(), "-", $name))
    };
    () => {
        LocalTuner::new(module_path!())
    };
}
36
37pub use local_tuner;
38
39impl<AK, ID> LocalTuner<AK, ID>
40where
41 AK: AutotuneKey + 'static,
42 ID: Hash + PartialEq + Eq + Clone + Display,
43{
44 pub const fn new(name: &'static str) -> Self {
46 Self {
47 state: SharedStateMap::new(),
48 name,
49 sets: spin::RwLock::new(None),
50 }
51 }
52
53 pub fn init<In, Out, F>(&self, init_set: F) -> Arc<TunableSet<AK, In, Out>>
55 where
56 F: Fn() -> TunableSet<AK, In, Out> + 'static + Send + Sync,
57 In: Clone + Send + 'static,
58 Out: AutotuneOutput,
59 {
60 let sets = self.sets.read();
61 let type_id = TypeId::of::<F>();
62
63 static DOWNCAST_ERROR: &str = "Local tuner only support one set of tunable that must work on the same input and output declared with the init function.";
64
65 if let Some(sets) = sets.as_ref()
66 && let Some(set) = sets.get(&type_id)
67 {
68 return set.clone().downcast().expect(DOWNCAST_ERROR);
69 };
70
71 core::mem::drop(sets);
72
73 let mut sets = self.sets.write();
74
75 if let Some(sets) = sets.as_ref()
76 && let Some(set) = sets.get(&type_id)
77 {
78 return set.clone().downcast().expect(DOWNCAST_ERROR);
79 };
80
81 let content = Arc::new(init_set());
82
83 if let Some(sets) = sets.as_mut() {
84 sets.insert(type_id, content.clone());
85 } else {
86 let mut map = HashMap::<TypeId, Arc<dyn Any + Send + Sync>>::new();
87 map.insert(type_id, content.clone());
88 *sets = Some(map);
89 };
90
91 content
92 }
93
94 pub fn clear(&self) {
96 self.state.clear()
97 }
98
99 #[cfg(feature = "autotune-checks")]
100 fn checks<In: Send + Clone + 'static, Out: AutotuneOutput>(
101 &self,
102 operations: &TunableSet<AK, In, Out>,
103 inputs: &In,
104 ) {
105 let mut checks_outputs = Vec::new();
106 for i in 0..operations.len() {
107 let op = operations.fastest(i);
108 let result = op.execute(inputs.clone());
109 checks_outputs.push(result);
110 }
111 super::check_autotune_outputs(checks_outputs);
112 }
113
114 pub fn execute<S, In, Out>(
116 &self,
117 id: &ID,
118 client: &ComputeClient<S>,
119 operations: Arc<TunableSet<AK, In, Out>>,
120 inputs: In,
121 ) -> Out
122 where
123 S: ComputeServer + 'static,
124 In: Clone + Send + 'static,
125 Out: AutotuneOutput,
126 {
127 let key = operations.generate_key(&inputs);
128
129 let autotune_job = {
131 let tuner_state = self.state.get_or_init(id, move |id| {
132 let name = self.name.replace("::", "-");
133 Tuner::new(&name, &id.to_string())
134 });
135 let tuner = tuner_state.read();
136
137 let mut tuner = match tuner.fastest(&key) {
138 TuneCacheResult::Hit { fastest_index } => {
139 core::mem::drop(tuner);
140 core::mem::drop(tuner_state);
141
142 #[cfg(feature = "autotune-checks")]
143 self.checks(&operations, &inputs);
144
145 let op = operations.fastest(fastest_index);
146 let result = op
147 .execute(inputs)
148 .expect("Should run when selected by autotune.");
149 return result;
150 }
151 TuneCacheResult::Pending => {
152 core::mem::drop(tuner);
153 core::mem::drop(tuner_state);
154
155 let op = operations.fastest(0);
156 let result = op
157 .execute(inputs)
158 .expect("Should run when selected by autotune.");
159 return result;
160 }
161 #[cfg(std_io)]
162 TuneCacheResult::Unchecked => {
163 core::mem::drop(tuner);
164 let mut tuner = tuner_state.write();
165
166 let checksum = operations.compute_checksum();
168 tuner.validate_checksum(&key, &checksum);
169
170 if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) {
172 core::mem::drop(tuner);
173 core::mem::drop(tuner_state);
174
175 let op = operations.fastest(fastest_index);
176 let result = op
177 .execute(inputs)
178 .expect("Should run when selected by autotune.");
179 return result;
180 }
181
182 tuner
183 }
184
185 #[cfg(not(std_io))]
186 TuneCacheResult::Unchecked => {
187 core::mem::drop(tuner);
188 tuner_state.write()
189 }
190 TuneCacheResult::Miss => {
191 core::mem::drop(tuner);
192 tuner_state.write()
193 }
194 };
195
196 if tuner.autotuning.contains(&key) {
197 Box::new(move || {})
198 } else {
199 tuner.autotuning.insert(key.clone());
200 tuner.prepare_autotune(key.clone(), &inputs, &operations, client)
201 }
202 };
203
204 autotune_job();
205
206 let index_to_run = {
207 let tuner_state = self.state.get(id).unwrap();
208 let mut tuner = tuner_state.write();
209
210 tuner.handle_results();
211
212 match tuner.fastest(&key) {
213 TuneCacheResult::Hit { fastest_index } => {
214 fastest_index
216 }
217 TuneCacheResult::Pending => {
218 0
220 }
221 TuneCacheResult::Miss => {
222 0
227 }
228 TuneCacheResult::Unchecked => {
229 panic!("Should have checked the cache.")
230 }
231 }
232 };
233
234 operations
235 .fastest(index_to_run)
236 .execute(inputs)
237 .expect("Should run when selected by autotune.")
238 }
239}