// cubecl_runtime/tune/local.rs

1use super::{AutotuneKey, AutotuneOutput, TunableSet, Tuner};
2use crate::{
3    channel::ComputeChannel, client::ComputeClient, server::ComputeServer, tune::TuneCacheResult,
4};
5use core::{fmt::Display, hash::Hash};
6use hashbrown::HashMap;
7
8#[cfg(not(feature = "std"))]
9use alloc::string::ToString;
10
11/// A local tuner allows to create a tuner for a specific key that can be different from the server
12/// key.
13pub struct LocalTuner<AK: AutotuneKey, ID> {
14    state: spin::RwLock<Option<HashMap<ID, Tuner<AK>>>>,
15    name: &'static str,
16}
17
/// Create a local tuner named after the calling module.
///
/// - `local_tuner!()` uses the call site's `module_path!()` as the tuner name.
/// - `local_tuner!("name")` appends `-name` to that module path, so one module can host several
///   tuners. The argument is forwarded to `concat!`, so it must be a literal.
///
/// The expansion refers to `LocalTuner` by its bare name, so `LocalTuner` must be in scope
/// where the macro is invoked.
#[macro_export]
macro_rules! local_tuner {
    // The expansions intentionally have no trailing `;`: they are used in expression position
    // (e.g. a `static` initializer), where a trailing semicolon inside a macro body triggers the
    // `semicolon_in_expressions_from_macros` lint.
    ($name:expr) => {
        LocalTuner::new(concat!(module_path!(), "-", $name))
    };
    () => {
        LocalTuner::new(module_path!())
    };
}
28
29pub use local_tuner;
30
31impl<AK: AutotuneKey + 'static, ID: Hash + PartialEq + Eq + Clone + Display> LocalTuner<AK, ID> {
32    /// Create a new local tuner.
33    pub const fn new(name: &'static str) -> Self {
34        Self {
35            state: spin::RwLock::new(None),
36            name,
37        }
38    }
39
40    /// Clear the autotune state.
41    pub fn clear(&self) {
42        let mut state = self.state.write();
43        *state = None;
44    }
45
46    #[cfg(feature = "autotune-checks")]
47    fn checks<In: Send + Clone + 'static, Out: AutotuneOutput>(
48        &self,
49        operations: &TunableSet<AK, In, Out>,
50        inputs: &In,
51    ) {
52        let mut checks_outputs = Vec::new();
53        for i in 0..operations.len() {
54            let op = operations.fastest(i);
55            let result = op.execute(inputs.clone());
56            checks_outputs.push(result);
57        }
58        super::check_autotune_outputs(checks_outputs);
59    }
60
61    /// Execute the best operation in the provided [tunable set](TunableSet)
62    pub fn execute<S, C, In: Send + Clone + 'static, Out: AutotuneOutput>(
63        &self,
64        id: &ID,
65        client: &ComputeClient<S, C>,
66        operations: &TunableSet<AK, In, Out>,
67        inputs: In,
68    ) -> Out
69    where
70        S: ComputeServer + 'static,
71        C: ComputeChannel<S> + 'static,
72    {
73        let key = operations.generate_key(&inputs);
74
75        // If this is cached and ready, use the operation.
76        if let Some(map) = self.state.read().as_ref() {
77            if let Some(tuner) = map.get(id) {
78                if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) {
79                    #[cfg(feature = "autotune-checks")]
80                    self.checks(operations, &inputs);
81
82                    let op = operations.fastest(fastest_index);
83                    let result = op
84                        .execute(inputs)
85                        .expect("Should run when selected by autotune.");
86
87                    return result;
88                }
89            }
90        }
91
92        // Create the tuner if needed, and update some state like
93        // checksums if need be.
94        let (fastest, run_autotune) = {
95            let mut state = self.state.write();
96            let map = state.get_or_insert_with(Default::default);
97            let tuner = map.entry(id.clone()).or_insert_with(move || {
98                let name = self.name.replace("::", "-");
99                Tuner::new(&name, &id.to_string())
100            });
101
102            #[allow(unused_mut)]
103            let mut fastest = tuner.fastest(&key);
104
105            // If the cache checksum hasn't been checked, do so now, and retry.
106            #[cfg(autotune_persistent_cache)]
107            if matches!(fastest, TuneCacheResult::Unchecked) {
108                let checksum = operations.compute_checksum();
109                tuner.validate_checksum(&key, &checksum);
110                fastest = tuner.fastest(&key);
111            }
112            let mut run_autotune = false;
113
114            if matches!(fastest, TuneCacheResult::Miss) && !tuner.autotuning.contains(&key) {
115                tuner.autotuning.insert(key.clone());
116                run_autotune = true;
117            }
118            (fastest, run_autotune)
119        };
120
121        match fastest {
122            TuneCacheResult::Hit { fastest_index } => {
123                #[cfg(feature = "autotune-checks")]
124                self.checks(operations, &inputs);
125
126                return operations
127                    .fastest(fastest_index)
128                    .execute(inputs)
129                    .expect("Should run when selected by autotune.");
130            }
131            TuneCacheResult::Miss => {
132                if run_autotune {
133                    // We don't know the results yet, start autotuning.
134                    //
135                    // Running benchmarks should't lock the tuner, since an autotune operation can recursively use the
136                    // same tuner.
137                    //
138                    // # Example
139                    //
140                    // ```
141                    // - tune_1 start
142                    //   - tune_2 start
143                    //   - tune_2 save
144                    // - tune_1 save
145                    // ```
146                    let state = self.state.read();
147                    let tuner = state
148                        .as_ref()
149                        .and_then(|s| s.get(id))
150                        .expect("Should be initialized");
151                    tuner.execute_autotune(key.clone(), &inputs, operations, client);
152                } else {
153                    // We're waiting for results to come in.
154                }
155            }
156            TuneCacheResult::Pending => {
157                // We're waiting for results to come in.
158            }
159            TuneCacheResult::Unchecked => {
160                panic!("Should have checked the cache already.")
161            }
162        };
163
164        let fastest = {
165            let mut state = self.state.write();
166            let tuner = state
167                .as_mut()
168                .and_then(|s| s.get_mut(id))
169                .expect("Should be initialized");
170
171            // Read all results that have come in since.
172            tuner.handle_results();
173
174            // Check again what the fastest option is, new results might have come in.
175            match tuner.fastest(&key) {
176                TuneCacheResult::Hit { fastest_index } => {
177                    // Theres a known good value - just run it.
178                    fastest_index
179                }
180                TuneCacheResult::Pending => {
181                    // If we still don't know, just execute a default index.
182                    0
183                }
184                TuneCacheResult::Miss => {
185                    if run_autotune {
186                        panic!("Should have at least started autotuning");
187                    } else {
188                        // We're still waiting for the results of the autotune task.
189                        // Let's execute the default index while we wait.
190                        //
191                        // This should only happen on wasm since we can't block waiting on the results there.
192                        0
193                    }
194                }
195                TuneCacheResult::Unchecked => {
196                    panic!("Should have checked the cache.")
197                }
198            }
199        };
200
201        #[cfg(feature = "autotune-checks")]
202        self.checks(operations, &inputs);
203
204        operations
205            .fastest(fastest)
206            .execute(inputs)
207            .expect("Should run when selected by autotune.")
208    }
209}