Skip to main content

cubecl_runtime/tune/
local.rs

1use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneInputs, Tuner};
2use crate::{client::ComputeClient, runtime::Runtime, tune::TuneCacheResult};
3use alloc::string::ToString;
4use alloc::sync::Arc;
5use core::{
6    any::{Any, TypeId},
7    fmt::Display,
8    hash::Hash,
9};
10use hashbrown::HashMap;
11use spin::Mutex;
12
13/// A local tuner allows to create a tuner for a specific key that can be different from the server
14/// key.
15pub struct LocalTuner<AK: AutotuneKey, ID> {
16    state: Mutex<Option<HashMap<ID, Arc<Tuner<AK>>>>>,
17    name: &'static str,
18    sets: spin::RwLock<Option<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
19}
20
/// Create a `LocalTuner` whose name is derived from the calling module's path.
///
/// - `local_tuner!()` names the tuner after `module_path!()`.
/// - `local_tuner!("name")` names it `"<module_path>-name"`. This tells apart several tuners
///   declared in the same module. The argument must be a string literal because it is passed
///   to `concat!`.
///
/// The expansion is a plain expression (`LocalTuner::new` is a `const fn`), so it can initialize
/// a `static`. `LocalTuner` must be in scope at the call site.
#[macro_export]
macro_rules! local_tuner {
    ($name:expr) => {
        // No trailing `;`: this expands in expression position.
        LocalTuner::new(concat!(module_path!(), "-", $name))
    };
    () => {
        LocalTuner::new(module_path!())
    };
}
31
32pub use local_tuner;
33
34impl<AK, ID> LocalTuner<AK, ID>
35where
36    AK: AutotuneKey + 'static,
37    ID: Hash + PartialEq + Eq + Clone + Display,
38{
39    /// Create a new local tuner.
40    pub const fn new(name: &'static str) -> Self {
41        Self {
42            state: Mutex::new(None),
43            name,
44            sets: spin::RwLock::new(None),
45        }
46    }
47
48    /// Get or initialize the [`TunableSet`] for this tuner.
49    ///
50    /// Returns a cached `Arc<TunableSet>` keyed by the `TypeId` of `init_set`. The
51    /// initializer runs at most once per process.
52    pub fn init<I, Out, F>(&self, init_set: F) -> Arc<TunableSet<AK, I, Out>>
53    where
54        F: Fn() -> TunableSet<AK, I, Out> + 'static + Send + Sync,
55        I: TuneInputs,
56        Out: AutotuneOutput,
57    {
58        let sets = self.sets.read();
59        let type_id = TypeId::of::<F>();
60
61        static DOWNCAST_ERROR: &str = "Local tuner only support one set of tunable that must work on the same input and output declared with the init function.";
62
63        if let Some(sets) = sets.as_ref()
64            && let Some(set) = sets.get(&type_id)
65        {
66            return set.clone().downcast().expect(DOWNCAST_ERROR);
67        };
68
69        core::mem::drop(sets);
70
71        let mut sets = self.sets.write();
72
73        if let Some(sets) = sets.as_ref()
74            && let Some(set) = sets.get(&type_id)
75        {
76            return set.clone().downcast().expect(DOWNCAST_ERROR);
77        };
78
79        let content = Arc::new(init_set());
80
81        if let Some(sets) = sets.as_mut() {
82            sets.insert(type_id, content.clone());
83        } else {
84            let mut map = HashMap::<TypeId, Arc<dyn Any + Send + Sync>>::new();
85            map.insert(type_id, content.clone());
86            *sets = Some(map);
87        };
88
89        content
90    }
91
92    /// Clear the autotune state.
93    pub fn clear(&self) {
94        if let Some(s) = self.state.lock().as_mut() {
95            s.clear()
96        }
97    }
98
99    #[cfg(feature = "autotune-checks")]
100    fn checks<'a, I: TuneInputs, Out: AutotuneOutput>(
101        &self,
102        operations: &TunableSet<AK, I, Out>,
103        inputs: &<I as TuneInputs>::At<'a>,
104    ) where
105        <I as TuneInputs>::At<'a>: Clone + Send,
106    {
107        use alloc::vec::Vec;
108
109        let mut checks_outputs = Vec::new();
110        for i in 0..operations.len() {
111            let op = operations.fastest(i);
112            let result = op.execute(inputs.clone());
113            checks_outputs.push(result);
114        }
115        super::check_autotune_outputs(checks_outputs);
116    }
117
118    /// Execute the fastest operation in a [`TunableSet`], triggering a tuning pass on
119    /// the first call for a given key.
120    pub fn execute<'a, R: Runtime, I: TuneInputs, Out>(
121        &self,
122        id: &ID,
123        client: &ComputeClient<R>,
124        operations: Arc<TunableSet<AK, I, Out>>,
125        inputs: <I as TuneInputs>::At<'a>,
126    ) -> Out
127    where
128        <I as TuneInputs>::At<'a>: Clone + Send,
129        Out: AutotuneOutput,
130    {
131        let key = operations.generate_key(&inputs);
132
133        let tuner = {
134            let mut state_lock = self.state.lock();
135            let state_map = state_lock.get_or_insert_with(|| HashMap::new());
136            state_map
137                .entry(id.clone())
138                .or_insert_with(move || {
139                    let name = self.name.replace("::", "-");
140                    Arc::new(Tuner::new(&name, &id.to_string()))
141                })
142                .clone()
143        };
144
145        // First, check for a cache hit under a read lock.
146        if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) {
147            #[cfg(feature = "autotune-checks")]
148            self.checks::<I, Out>(&operations, &inputs);
149            return operations
150                .fastest(fastest_index)
151                .execute(inputs)
152                .expect("Should run when selected by autotune.");
153        }
154
155        let fastest = tuner.check_tune::<R, I, Out>(
156            &key,
157            &inputs,
158            &operations,
159            || operations.compute_checksum(),
160            client,
161        );
162
163        // Run the execution depending on the cache state.
164        match fastest {
165            TuneCacheResult::Hit { fastest_index } => {
166                #[cfg(feature = "autotune-checks")]
167                self.checks::<I, Out>(&operations, &inputs);
168
169                operations
170                    .fastest(fastest_index)
171                    .execute(inputs)
172                    .expect("Should run when selected by autotune.")
173            }
174            TuneCacheResult::Unchecked | TuneCacheResult::Miss => {
175                panic!(
176                    "Somehow we STILL didn't check a tuning checksum or start tuning, something has gone wrong."
177                )
178            }
179            TuneCacheResult::Pending => {
180                // Still waiting (e.g. on wasm). Try all operations as a fallback.
181                for i in 0..operations.len() {
182                    if let Ok(output) = operations.fastest(i).execute(inputs.clone()) {
183                        return output;
184                    }
185                }
186                panic!("All autotune operations failed, no viable operation found.");
187            }
188        }
189    }
190}