Skip to main content

cubecl_runtime/tune/
local.rs

1use super::{AutotuneKey, AutotuneOutput, TunableSet, Tuner};
2use crate::{client::ComputeClient, runtime::Runtime, tune::TuneCacheResult};
3use alloc::string::ToString;
4use alloc::sync::Arc;
5use core::{
6    any::{Any, TypeId},
7    fmt::Display,
8    hash::Hash,
9};
10use cubecl_common::map::SharedStateMap;
11use hashbrown::HashMap;
12
/// A local tuner makes it possible to create a tuner for a specific key that can differ from the
/// server key.
pub struct LocalTuner<AK: AutotuneKey, ID> {
    /// One [`Tuner`] per `ID`, created on first use by `execute`.
    state: SharedStateMap<ID, Tuner<AK>>,
    /// Name given at construction. `execute` replaces every `::` with `-` before handing it to
    /// [`Tuner::new`].
    name: &'static str,
    /// Tunable sets built by `init`, keyed by the [`TypeId`] of the initializer closure and
    /// stored type-erased. Stays `None` until the first set is stored.
    sets: spin::RwLock<Option<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
}
20
// SAFETY: NOTE(review): this impl asserts `Sync` for every `AK` and `ID`, with no `Send`/`Sync`
// bounds on either. It presumably relies on `SharedStateMap` and `spin::RwLock` guarding every
// access to the shared fields — TODO confirm this holds for non-`Sync` `ID` / `Tuner<AK>`.
unsafe impl<AK: AutotuneKey, ID> Sync for LocalTuner<AK, ID> {}
22
/// Create a local tuner with the provided name.
///
/// The tuner name is the invoking module's path (`module_path!()`), optionally followed by `-`
/// and the given name. `$name` must be something `concat!` accepts, such as a string literal.
///
/// `LocalTuner` must be in scope at the call site, because the expansion refers to it
/// unqualified.
///
/// Both arms expand to a plain expression (no trailing semicolon), so the macro can be used
/// directly as an initializer, e.g. `static TUNER: LocalTuner<K, Id> = local_tuner!();`, without
/// triggering the `semicolon_in_expressions_from_macros` lint.
#[macro_export]
macro_rules! local_tuner {
    ($name:expr) => {
        LocalTuner::new(concat!(module_path!(), "-", $name))
    };
    () => {
        LocalTuner::new(module_path!())
    };
}
33
34pub use local_tuner;
35
36impl<AK, ID> LocalTuner<AK, ID>
37where
38    AK: AutotuneKey + 'static,
39    ID: Hash + PartialEq + Eq + Clone + Display,
40{
41    /// Create a new local tuner.
42    pub const fn new(name: &'static str) -> Self {
43        Self {
44            state: SharedStateMap::new(),
45            name,
46            sets: spin::RwLock::new(None),
47        }
48    }
49
50    /// Init the [tunable set](TunableSet)
51    pub fn init<In, Out, F>(&self, init_set: F) -> Arc<TunableSet<AK, In, Out>>
52    where
53        F: Fn() -> TunableSet<AK, In, Out> + 'static + Send + Sync,
54        In: Clone + Send + 'static,
55        Out: AutotuneOutput,
56    {
57        let sets = self.sets.read();
58        let type_id = TypeId::of::<F>();
59
60        static DOWNCAST_ERROR: &str = "Local tuner only support one set of tunable that must work on the same input and output declared with the init function.";
61
62        if let Some(sets) = sets.as_ref()
63            && let Some(set) = sets.get(&type_id)
64        {
65            return set.clone().downcast().expect(DOWNCAST_ERROR);
66        };
67
68        core::mem::drop(sets);
69
70        let mut sets = self.sets.write();
71
72        if let Some(sets) = sets.as_ref()
73            && let Some(set) = sets.get(&type_id)
74        {
75            return set.clone().downcast().expect(DOWNCAST_ERROR);
76        };
77
78        let content = Arc::new(init_set());
79
80        if let Some(sets) = sets.as_mut() {
81            sets.insert(type_id, content.clone());
82        } else {
83            let mut map = HashMap::<TypeId, Arc<dyn Any + Send + Sync>>::new();
84            map.insert(type_id, content.clone());
85            *sets = Some(map);
86        };
87
88        content
89    }
90
91    /// Clear the autotune state.
92    pub fn clear(&self) {
93        self.state.clear()
94    }
95
96    #[cfg(feature = "autotune-checks")]
97    fn checks<In: Send + Clone + 'static, Out: AutotuneOutput>(
98        &self,
99        operations: &TunableSet<AK, In, Out>,
100        inputs: &In,
101    ) {
102        use alloc::vec::Vec;
103
104        let mut checks_outputs = Vec::new();
105        for i in 0..operations.len() {
106            let op = operations.fastest(i);
107            let result = op.execute(inputs.clone());
108            checks_outputs.push(result);
109        }
110        super::check_autotune_outputs(checks_outputs);
111    }
112
113    /// Try every operation in order and return the first successful result.
114    ///
115    /// Used as a fallback when autotuning results aren't available yet
116    /// (e.g. on wasm where tuning is async).
117    fn try_all_operations<In, Out>(operations: &TunableSet<AK, In, Out>, inputs: In) -> Out
118    where
119        In: Clone + Send + 'static,
120        Out: AutotuneOutput,
121    {
122        for i in 0..operations.len() {
123            if let Ok(output) = operations.fastest(i).execute(inputs.clone()) {
124                return output;
125            }
126        }
127        panic!("All autotune operations failed, no viable operation found.");
128    }
129
130    /// Execute the best operation in the provided [tunable set](TunableSet)
131    pub fn execute<R: Runtime, In, Out>(
132        &self,
133        id: &ID,
134        client: &ComputeClient<R>,
135        operations: Arc<TunableSet<AK, In, Out>>,
136        inputs: In,
137    ) -> Out
138    where
139        In: Clone + Send + 'static,
140        Out: AutotuneOutput,
141    {
142        let key = operations.generate_key(&inputs);
143
144        // If this is cached and ready, use the operation.
145        let tuner_state = self.state.get_or_init(id, move |id| {
146            let name = self.name.replace("::", "-");
147            Tuner::new(&name, &id.to_string())
148        });
149        let tuner = tuner_state.read();
150
151        let mut tuner = match tuner.fastest(&key) {
152            TuneCacheResult::Hit { fastest_index } => {
153                core::mem::drop(tuner);
154                core::mem::drop(tuner_state);
155
156                #[cfg(feature = "autotune-checks")]
157                self.checks(&operations, &inputs);
158
159                let op = operations.fastest(fastest_index);
160                let result = op
161                    .execute(inputs)
162                    .expect("Should run when selected by autotune.");
163                return result;
164            }
165            TuneCacheResult::Pending => {
166                core::mem::drop(tuner);
167                core::mem::drop(tuner_state);
168
169                #[cfg(feature = "autotune-checks")]
170                self.checks(&operations, &inputs);
171
172                return Self::try_all_operations(&operations, inputs);
173            }
174            #[cfg(std_io)]
175            TuneCacheResult::Unchecked => {
176                core::mem::drop(tuner);
177                let mut tuner = tuner_state.write();
178
179                // If the cache checksum hasn't been checked, do so now, and retry.
180                let checksum = operations.compute_checksum();
181                tuner.validate_checksum(&key, &checksum);
182
183                // Check if with validation we can use its result
184                if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) {
185                    core::mem::drop(tuner);
186                    core::mem::drop(tuner_state);
187
188                    let op = operations.fastest(fastest_index);
189                    let result = op
190                        .execute(inputs)
191                        .expect("Should run when selected by autotune.");
192                    return result;
193                }
194
195                tuner
196            }
197
198            #[cfg(not(std_io))]
199            TuneCacheResult::Unchecked => {
200                core::mem::drop(tuner);
201                tuner_state.write()
202            }
203            TuneCacheResult::Miss => {
204                core::mem::drop(tuner);
205                tuner_state.write()
206            }
207        };
208
209        let job = if !tuner.autotuning.contains(&key) {
210            tuner.autotuning.insert(key.clone());
211            Some(tuner.prepare_autotune(key.clone(), &inputs, &operations, client))
212        } else {
213            None
214        };
215
216        // Drop the write lock before running the (potentially blocking) job
217        // and before re-acquiring the lock below.
218        core::mem::drop(tuner);
219        core::mem::drop(tuner_state);
220
221        if let Some(job) = job {
222            job();
223        }
224
225        let index_to_run = {
226            let tuner_state = self.state.get(id).unwrap();
227            let mut tuner = tuner_state.write();
228
229            tuner.handle_results();
230
231            match tuner.fastest(&key) {
232                TuneCacheResult::Hit { fastest_index } => {
233                    // There's a known good value - just run it.
234                    fastest_index
235                }
236                TuneCacheResult::Pending | TuneCacheResult::Miss => {
237                    // We're still waiting for the results of the autotune task.
238                    // This should only happen on wasm since we can't block waiting
239                    // on the results there. Try all options.
240                    return Self::try_all_operations(&operations, inputs);
241                }
242                TuneCacheResult::Unchecked => {
243                    panic!("Should have checked the cache.")
244                }
245            }
246        };
247
248        operations
249            .fastest(index_to_run)
250            .execute(inputs)
251            .expect("Should run when selected by autotune.")
252    }
253}