cubecl_runtime/tune/
local.rs1use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneInputs, Tuner};
2use crate::{client::ComputeClient, runtime::Runtime, tune::TuneCacheResult};
3use alloc::string::ToString;
4use alloc::sync::Arc;
5use core::{
6 any::{Any, TypeId},
7 fmt::Display,
8 hash::Hash,
9};
10use hashbrown::HashMap;
11use spin::Mutex;
12
13pub struct LocalTuner<AK: AutotuneKey, ID> {
16 state: Mutex<Option<HashMap<ID, Arc<Tuner<AK>>>>>,
17 name: &'static str,
18 sets: spin::RwLock<Option<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
19}
20
/// Builds a [`LocalTuner`] named after the module that invokes the macro.
///
/// * `local_tuner!()` uses `module_path!()` as the tuner name.
/// * `local_tuner!("suffix")` uses `module_path!()`, a `-`, and the given literal.
///
/// Both arms expand to a plain `LocalTuner::new(..)` expression (no trailing
/// semicolon), so the macro can be used wherever an expression is expected, for
/// example as the initializer of a `static`. `LocalTuner` is referenced by its
/// bare name and must therefore be in scope at the call site.
#[macro_export]
macro_rules! local_tuner {
    ($name:expr) => {
        LocalTuner::new(concat!(module_path!(), "-", $name))
    };
    () => {
        LocalTuner::new(module_path!())
    };
}
31
32pub use local_tuner;
33
34impl<AK, ID> LocalTuner<AK, ID>
35where
36 AK: AutotuneKey + 'static,
37 ID: Hash + PartialEq + Eq + Clone + Display,
38{
39 pub const fn new(name: &'static str) -> Self {
41 Self {
42 state: Mutex::new(None),
43 name,
44 sets: spin::RwLock::new(None),
45 }
46 }
47
48 pub fn init<I, Out, F>(&self, init_set: F) -> Arc<TunableSet<AK, I, Out>>
53 where
54 F: Fn() -> TunableSet<AK, I, Out> + 'static + Send + Sync,
55 I: TuneInputs,
56 Out: AutotuneOutput,
57 {
58 let sets = self.sets.read();
59 let type_id = TypeId::of::<F>();
60
61 static DOWNCAST_ERROR: &str = "Local tuner only support one set of tunable that must work on the same input and output declared with the init function.";
62
63 if let Some(sets) = sets.as_ref()
64 && let Some(set) = sets.get(&type_id)
65 {
66 return set.clone().downcast().expect(DOWNCAST_ERROR);
67 };
68
69 core::mem::drop(sets);
70
71 let mut sets = self.sets.write();
72
73 if let Some(sets) = sets.as_ref()
74 && let Some(set) = sets.get(&type_id)
75 {
76 return set.clone().downcast().expect(DOWNCAST_ERROR);
77 };
78
79 let content = Arc::new(init_set());
80
81 if let Some(sets) = sets.as_mut() {
82 sets.insert(type_id, content.clone());
83 } else {
84 let mut map = HashMap::<TypeId, Arc<dyn Any + Send + Sync>>::new();
85 map.insert(type_id, content.clone());
86 *sets = Some(map);
87 };
88
89 content
90 }
91
92 pub fn clear(&self) {
94 if let Some(s) = self.state.lock().as_mut() {
95 s.clear()
96 }
97 }
98
99 #[cfg(feature = "autotune-checks")]
100 fn checks<'a, I: TuneInputs, Out: AutotuneOutput>(
101 &self,
102 operations: &TunableSet<AK, I, Out>,
103 inputs: &<I as TuneInputs>::At<'a>,
104 ) where
105 <I as TuneInputs>::At<'a>: Clone + Send,
106 {
107 use alloc::vec::Vec;
108
109 let mut checks_outputs = Vec::new();
110 for i in 0..operations.len() {
111 let op = operations.fastest(i);
112 let result = op.execute(inputs.clone());
113 checks_outputs.push(result);
114 }
115 super::check_autotune_outputs(checks_outputs);
116 }
117
118 pub fn execute<'a, R: Runtime, I: TuneInputs, Out>(
121 &self,
122 id: &ID,
123 client: &ComputeClient<R>,
124 operations: Arc<TunableSet<AK, I, Out>>,
125 inputs: <I as TuneInputs>::At<'a>,
126 ) -> Out
127 where
128 <I as TuneInputs>::At<'a>: Clone + Send,
129 Out: AutotuneOutput,
130 {
131 let key = operations.generate_key(&inputs);
132
133 let tuner = {
134 let mut state_lock = self.state.lock();
135 let state_map = state_lock.get_or_insert_with(|| HashMap::new());
136 state_map
137 .entry(id.clone())
138 .or_insert_with(move || {
139 let name = self.name.replace("::", "-");
140 Arc::new(Tuner::new(&name, &id.to_string()))
141 })
142 .clone()
143 };
144
145 if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) {
147 #[cfg(feature = "autotune-checks")]
148 self.checks::<I, Out>(&operations, &inputs);
149 return operations
150 .fastest(fastest_index)
151 .execute(inputs)
152 .expect("Should run when selected by autotune.");
153 }
154
155 let fastest = tuner.check_tune::<R, I, Out>(
156 &key,
157 &inputs,
158 &operations,
159 || operations.compute_checksum(),
160 client,
161 );
162
163 match fastest {
165 TuneCacheResult::Hit { fastest_index } => {
166 #[cfg(feature = "autotune-checks")]
167 self.checks::<I, Out>(&operations, &inputs);
168
169 operations
170 .fastest(fastest_index)
171 .execute(inputs)
172 .expect("Should run when selected by autotune.")
173 }
174 TuneCacheResult::Unchecked | TuneCacheResult::Miss => {
175 panic!(
176 "Somehow we STILL didn't check a tuning checksum or start tuning, something has gone wrong."
177 )
178 }
179 TuneCacheResult::Pending => {
180 for i in 0..operations.len() {
182 if let Ok(output) = operations.fastest(i).execute(inputs.clone()) {
183 return output;
184 }
185 }
186 panic!("All autotune operations failed, no viable operation found.");
187 }
188 }
189 }
190}