cubecl_runtime/tune/local.rs

1use super::{AutotuneKey, AutotuneOutput, TunableSet, Tuner};
2use crate::{client::ComputeClient, server::ComputeServer, tune::TuneCacheResult};
3use alloc::boxed::Box;
4use alloc::sync::Arc;
5use core::{
6    any::{Any, TypeId},
7    fmt::Display,
8    hash::Hash,
9};
10use cubecl_common::map::SharedStateMap;
11use hashbrown::HashMap;
12
13#[cfg(not(feature = "std"))]
14use alloc::string::ToString;
15
16/// A local tuner allows to create a tuner for a specific key that can be different from the server
17/// key.
18pub struct LocalTuner<AK: AutotuneKey, ID> {
19    state: SharedStateMap<ID, Tuner<AK>>,
20    name: &'static str,
21    sets: spin::RwLock<Option<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
22}
23
// SAFETY: NOTE(review): this asserts `Sync` for every `AK` and every `ID`, bypassing the
// compiler's auto-trait check on `SharedStateMap<ID, Tuner<AK>>` and on `ID` itself (which has
// no `Send`/`Sync` bound here). The `sets` field is already behind a `spin::RwLock` and only
// stores `Arc<dyn Any + Send + Sync>` values. Soundness of the remaining state relies on
// `SharedStateMap` synchronizing all access to its entries and on `ID` / `Tuner<AK>` being safe
// to share and drop across threads. TODO confirm, and consider adding `ID: Send + Sync` to this
// impl.
unsafe impl<AK: AutotuneKey, ID> Sync for LocalTuner<AK, ID> {}
25
/// Create a local tuner named after the calling module.
///
/// - `local_tuner!()` uses `module_path!()` of the call site as the tuner name.
/// - `local_tuner!("suffix")` uses `"<module_path>-suffix"`, so one module can declare several
///   independent tuners. The argument must be a literal accepted by `concat!`.
///
/// The expansion is a plain expression (no trailing `;`), so it can be used directly as a
/// `static`/`const` initializer or on the right-hand side of a `let`. `LocalTuner` must be in
/// scope at the call site.
#[macro_export]
macro_rules! local_tuner {
    ($name:expr) => {
        LocalTuner::new(concat!(module_path!(), "-", $name))
    };
    () => {
        LocalTuner::new(module_path!())
    };
}
36
37pub use local_tuner;
38
39impl<AK, ID> LocalTuner<AK, ID>
40where
41    AK: AutotuneKey + 'static,
42    ID: Hash + PartialEq + Eq + Clone + Display,
43{
44    /// Create a new local tuner.
45    pub const fn new(name: &'static str) -> Self {
46        Self {
47            state: SharedStateMap::new(),
48            name,
49            sets: spin::RwLock::new(None),
50        }
51    }
52
53    /// Init the [tunable set](TunableSet)
54    pub fn init<In, Out, F>(&self, init_set: F) -> Arc<TunableSet<AK, In, Out>>
55    where
56        F: Fn() -> TunableSet<AK, In, Out> + 'static + Send + Sync,
57        In: Clone + Send + 'static,
58        Out: AutotuneOutput,
59    {
60        let sets = self.sets.read();
61        let type_id = TypeId::of::<F>();
62
63        static DOWNCAST_ERROR: &str = "Local tuner only support one set of tunable that must work on the same input and output declared with the init function.";
64
65        if let Some(sets) = sets.as_ref()
66            && let Some(set) = sets.get(&type_id)
67        {
68            return set.clone().downcast().expect(DOWNCAST_ERROR);
69        };
70
71        core::mem::drop(sets);
72
73        let mut sets = self.sets.write();
74
75        if let Some(sets) = sets.as_ref()
76            && let Some(set) = sets.get(&type_id)
77        {
78            return set.clone().downcast().expect(DOWNCAST_ERROR);
79        };
80
81        let content = Arc::new(init_set());
82
83        if let Some(sets) = sets.as_mut() {
84            sets.insert(type_id, content.clone());
85        } else {
86            let mut map = HashMap::<TypeId, Arc<dyn Any + Send + Sync>>::new();
87            map.insert(type_id, content.clone());
88            *sets = Some(map);
89        };
90
91        content
92    }
93
94    /// Clear the autotune state.
95    pub fn clear(&self) {
96        self.state.clear()
97    }
98
99    #[cfg(feature = "autotune-checks")]
100    fn checks<In: Send + Clone + 'static, Out: AutotuneOutput>(
101        &self,
102        operations: &TunableSet<AK, In, Out>,
103        inputs: &In,
104    ) {
105        let mut checks_outputs = Vec::new();
106        for i in 0..operations.len() {
107            let op = operations.fastest(i);
108            let result = op.execute(inputs.clone());
109            checks_outputs.push(result);
110        }
111        super::check_autotune_outputs(checks_outputs);
112    }
113
114    /// Execute the best operation in the provided [tunable set](TunableSet)
115    pub fn execute<S, In, Out>(
116        &self,
117        id: &ID,
118        client: &ComputeClient<S>,
119        operations: Arc<TunableSet<AK, In, Out>>,
120        inputs: In,
121    ) -> Out
122    where
123        S: ComputeServer + 'static,
124        In: Clone + Send + 'static,
125        Out: AutotuneOutput,
126    {
127        let key = operations.generate_key(&inputs);
128
129        // If this is cached and ready, use the operation.
130        let autotune_job = {
131            let tuner_state = self.state.get_or_init(id, move |id| {
132                let name = self.name.replace("::", "-");
133                Tuner::new(&name, &id.to_string())
134            });
135            let tuner = tuner_state.read();
136
137            let mut tuner = match tuner.fastest(&key) {
138                TuneCacheResult::Hit { fastest_index } => {
139                    core::mem::drop(tuner);
140                    core::mem::drop(tuner_state);
141
142                    #[cfg(feature = "autotune-checks")]
143                    self.checks(&operations, &inputs);
144
145                    let op = operations.fastest(fastest_index);
146                    let result = op
147                        .execute(inputs)
148                        .expect("Should run when selected by autotune.");
149                    return result;
150                }
151                TuneCacheResult::Pending => {
152                    core::mem::drop(tuner);
153                    core::mem::drop(tuner_state);
154
155                    let op = operations.fastest(0);
156                    let result = op
157                        .execute(inputs)
158                        .expect("Should run when selected by autotune.");
159                    return result;
160                }
161                #[cfg(std_io)]
162                TuneCacheResult::Unchecked => {
163                    core::mem::drop(tuner);
164                    let mut tuner = tuner_state.write();
165
166                    // If the cache checksum hasn't been checked, do so now, and retry.
167                    let checksum = operations.compute_checksum();
168                    tuner.validate_checksum(&key, &checksum);
169
170                    // Check if with validation we can use its result
171                    if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) {
172                        core::mem::drop(tuner);
173                        core::mem::drop(tuner_state);
174
175                        let op = operations.fastest(fastest_index);
176                        let result = op
177                            .execute(inputs)
178                            .expect("Should run when selected by autotune.");
179                        return result;
180                    }
181
182                    tuner
183                }
184
185                #[cfg(not(std_io))]
186                TuneCacheResult::Unchecked => {
187                    core::mem::drop(tuner);
188                    tuner_state.write()
189                }
190                TuneCacheResult::Miss => {
191                    core::mem::drop(tuner);
192                    tuner_state.write()
193                }
194            };
195
196            if tuner.autotuning.contains(&key) {
197                Box::new(move || {})
198            } else {
199                tuner.autotuning.insert(key.clone());
200                tuner.prepare_autotune(key.clone(), &inputs, &operations, client)
201            }
202        };
203
204        autotune_job();
205
206        let index_to_run = {
207            let tuner_state = self.state.get(id).unwrap();
208            let mut tuner = tuner_state.write();
209
210            tuner.handle_results();
211
212            match tuner.fastest(&key) {
213                TuneCacheResult::Hit { fastest_index } => {
214                    // Theres a known good value - just run it.
215                    fastest_index
216                }
217                TuneCacheResult::Pending => {
218                    // If we still don't know, just execute a default index.
219                    0
220                }
221                TuneCacheResult::Miss => {
222                    // We're still waiting for the results of the autotune task.
223                    // Let's execute the default index while we wait.
224                    //
225                    // This should only happen on wasm since we can't block waiting on the results there.
226                    0
227                }
228                TuneCacheResult::Unchecked => {
229                    panic!("Should have checked the cache.")
230                }
231            }
232        };
233
234        operations
235            .fastest(index_to_run)
236            .execute(inputs)
237            .expect("Should run when selected by autotune.")
238    }
239}