Skip to main content

cubecl_runtime/tune/
local.rs

1use super::{AutotuneKey, AutotuneOutput, TunableSet, Tuner};
2use crate::{client::ComputeClient, runtime::Runtime, tune::TuneCacheResult};
3use alloc::boxed::Box;
4use alloc::sync::Arc;
5use core::{
6    any::{Any, TypeId},
7    fmt::Display,
8    hash::Hash,
9};
10use cubecl_common::map::SharedStateMap;
11use hashbrown::HashMap;
12
13use alloc::string::ToString;
14
15/// A local tuner allows to create a tuner for a specific key that can be different from the server
16/// key.
17pub struct LocalTuner<AK: AutotuneKey, ID> {
18    state: SharedStateMap<ID, Tuner<AK>>,
19    name: &'static str,
20    sets: spin::RwLock<Option<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
21}
22
23unsafe impl<AK: AutotuneKey, ID> Sync for LocalTuner<AK, ID> {}
24
/// Create a local tuner named after the module that invokes the macro.
///
/// * `local_tuner!()` passes `module_path!()` as the name.
/// * `local_tuner!("suffix")` passes `module_path!()`, a `-`, then the suffix. The suffix is
///   forwarded to `concat!`, so it has to be a literal.
///
/// Both arms expand to a single expression without a trailing `;`, so the macro can be used
/// directly as an initializer (for example on the right-hand side of a `static`). `LocalTuner`
/// is referenced unqualified and therefore must be in scope at the call site.
#[macro_export]
macro_rules! local_tuner {
    ($name:expr) => {
        LocalTuner::new(concat!(module_path!(), "-", $name))
    };
    () => {
        LocalTuner::new(module_path!())
    };
}
35
36pub use local_tuner;
37
38impl<AK, ID> LocalTuner<AK, ID>
39where
40    AK: AutotuneKey + 'static,
41    ID: Hash + PartialEq + Eq + Clone + Display,
42{
43    /// Create a new local tuner.
44    pub const fn new(name: &'static str) -> Self {
45        Self {
46            state: SharedStateMap::new(),
47            name,
48            sets: spin::RwLock::new(None),
49        }
50    }
51
52    /// Init the [tunable set](TunableSet)
53    pub fn init<In, Out, F>(&self, init_set: F) -> Arc<TunableSet<AK, In, Out>>
54    where
55        F: Fn() -> TunableSet<AK, In, Out> + 'static + Send + Sync,
56        In: Clone + Send + 'static,
57        Out: AutotuneOutput,
58    {
59        let sets = self.sets.read();
60        let type_id = TypeId::of::<F>();
61
62        static DOWNCAST_ERROR: &str = "Local tuner only support one set of tunable that must work on the same input and output declared with the init function.";
63
64        if let Some(sets) = sets.as_ref()
65            && let Some(set) = sets.get(&type_id)
66        {
67            return set.clone().downcast().expect(DOWNCAST_ERROR);
68        };
69
70        core::mem::drop(sets);
71
72        let mut sets = self.sets.write();
73
74        if let Some(sets) = sets.as_ref()
75            && let Some(set) = sets.get(&type_id)
76        {
77            return set.clone().downcast().expect(DOWNCAST_ERROR);
78        };
79
80        let content = Arc::new(init_set());
81
82        if let Some(sets) = sets.as_mut() {
83            sets.insert(type_id, content.clone());
84        } else {
85            let mut map = HashMap::<TypeId, Arc<dyn Any + Send + Sync>>::new();
86            map.insert(type_id, content.clone());
87            *sets = Some(map);
88        };
89
90        content
91    }
92
93    /// Clear the autotune state.
94    pub fn clear(&self) {
95        self.state.clear()
96    }
97
98    #[cfg(feature = "autotune-checks")]
99    fn checks<In: Send + Clone + 'static, Out: AutotuneOutput>(
100        &self,
101        operations: &TunableSet<AK, In, Out>,
102        inputs: &In,
103    ) {
104        use alloc::vec::Vec;
105
106        let mut checks_outputs = Vec::new();
107        for i in 0..operations.len() {
108            let op = operations.fastest(i);
109            let result = op.execute(inputs.clone());
110            checks_outputs.push(result);
111        }
112        super::check_autotune_outputs(checks_outputs);
113    }
114
115    /// Execute the best operation in the provided [tunable set](TunableSet)
116    pub fn execute<R: Runtime, In, Out>(
117        &self,
118        id: &ID,
119        client: &ComputeClient<R>,
120        operations: Arc<TunableSet<AK, In, Out>>,
121        inputs: In,
122    ) -> Out
123    where
124        In: Clone + Send + 'static,
125        Out: AutotuneOutput,
126    {
127        let key = operations.generate_key(&inputs);
128
129        // If this is cached and ready, use the operation.
130        let autotune_job = {
131            let tuner_state = self.state.get_or_init(id, move |id| {
132                let name = self.name.replace("::", "-");
133                Tuner::new(&name, &id.to_string())
134            });
135            let tuner = tuner_state.read();
136
137            let mut tuner = match tuner.fastest(&key) {
138                TuneCacheResult::Hit { fastest_index } => {
139                    core::mem::drop(tuner);
140                    core::mem::drop(tuner_state);
141
142                    #[cfg(feature = "autotune-checks")]
143                    self.checks(&operations, &inputs);
144
145                    let op = operations.fastest(fastest_index);
146                    let result = op
147                        .execute(inputs)
148                        .expect("Should run when selected by autotune.");
149                    return result;
150                }
151                TuneCacheResult::Pending => {
152                    core::mem::drop(tuner);
153                    core::mem::drop(tuner_state);
154
155                    let op = operations.fastest(0);
156                    let result = op
157                        .execute(inputs)
158                        .expect("Should run when selected by autotune.");
159                    return result;
160                }
161                #[cfg(std_io)]
162                TuneCacheResult::Unchecked => {
163                    core::mem::drop(tuner);
164                    let mut tuner = tuner_state.write();
165
166                    // If the cache checksum hasn't been checked, do so now, and retry.
167                    let checksum = operations.compute_checksum();
168                    tuner.validate_checksum(&key, &checksum);
169
170                    // Check if with validation we can use its result
171                    if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) {
172                        core::mem::drop(tuner);
173                        core::mem::drop(tuner_state);
174
175                        let op = operations.fastest(fastest_index);
176                        let result = op
177                            .execute(inputs)
178                            .expect("Should run when selected by autotune.");
179                        return result;
180                    }
181
182                    tuner
183                }
184
185                #[cfg(not(std_io))]
186                TuneCacheResult::Unchecked => {
187                    core::mem::drop(tuner);
188                    tuner_state.write()
189                }
190                TuneCacheResult::Miss => {
191                    core::mem::drop(tuner);
192                    tuner_state.write()
193                }
194            };
195
196            if tuner.autotuning.contains(&key) {
197                Box::new(move || {})
198            } else {
199                tuner.autotuning.insert(key.clone());
200                tuner.prepare_autotune(key.clone(), &inputs, &operations, client)
201            }
202        };
203
204        autotune_job();
205
206        let index_to_run = {
207            let tuner_state = self.state.get(id).unwrap();
208            let mut tuner = tuner_state.write();
209
210            tuner.handle_results();
211
212            match tuner.fastest(&key) {
213                TuneCacheResult::Hit { fastest_index } => {
214                    // Theres a known good value - just run it.
215                    fastest_index
216                }
217                TuneCacheResult::Pending => {
218                    // If we still don't know, just execute a default index.
219                    0
220                }
221                TuneCacheResult::Miss => {
222                    // We're still waiting for the results of the autotune task.
223                    // Let's execute the default index while we wait.
224                    //
225                    // This should only happen on wasm since we can't block waiting on the results there.
226                    0
227                }
228                TuneCacheResult::Unchecked => {
229                    panic!("Should have checked the cache.")
230                }
231            }
232        };
233
234        operations
235            .fastest(index_to_run)
236            .execute(inputs)
237            .expect("Should run when selected by autotune.")
238    }
239}