cubecl_runtime/tune/
tuner.rs

1use async_channel::{Receiver, Sender};
2use cubecl_common::future;
3
4use core::any::Any;
5use core::future::Future;
6use core::mem::ManuallyDrop;
7use cubecl_common::stub::Duration;
8
9#[cfg(all(not(target_family = "wasm"), feature = "std"))]
10use std::panic::resume_unwind;
11
12use alloc::string::ToString;
13use alloc::vec::Vec;
14use cubecl_common::benchmark::BenchmarkComputations;
15
16use crate::channel::ComputeChannel;
17use crate::client::ComputeClient;
18use crate::server::ComputeServer;
19use crate::tune::{AutotuneOperationSet, TuneBenchmark, TuneCache};
20
21use super::{AutotuneKey, TuneCacheResult};
22
#[derive(Debug)]
/// Executes autotune benchmarking and caching
pub struct Tuner<K: AutotuneKey> {
    // Cached fastest-operation decisions, keyed by autotune key.
    tune_cache: TuneCache<K>,
    // (sender, receiver) pair over which benchmark tasks report progress and
    // results back to the tuner; drained non-blockingly in `resolve`.
    channel: (Sender<AutotuneMessage<K>>, Receiver<AutotuneMessage<K>>),
}
29
/// Result from running benchmarks.
enum AutotuneMessage<K> {
    /// Benchmarking finished: `fastest_index` is the winning operation for `key`.
    Done {
        key: K,
        fastest_index: usize,
        /// Checksum of the operation set, stored alongside the persistent cache entry.
        #[cfg(autotune_persistent_cache)]
        checksum: String,
    },
    /// Benchmarking for `key` has been launched but has not completed yet.
    Starting {
        key: K,
    },
}
42
/// Error from running autotune.
#[derive(Debug)]
pub enum AutotuneError {
    /// An unknown error happened.
    Unknown(String),
    /// An error caught with panic unwind.
    // NOTE(review): the payload is wrapped in `ManuallyDrop`, presumably so it
    // is not dropped before being handed to `resume_unwind` — confirm intent.
    PanicUnwind(ManuallyDrop<Box<dyn Any + Send>>),
}
51
52impl From<String> for AutotuneError {
53    fn from(value: String) -> Self {
54        Self::Unknown(value)
55    }
56}
57
#[allow(clippy::new_without_default)]
impl<K: AutotuneKey> Tuner<K> {
    /// Returns a tuner with cache initialized from persistent cache
    pub fn new(name: &str, device_id: &str) -> Self {
        // Unbounded so benchmark tasks can always report results without
        // blocking, even if `resolve` has not drained the channel recently.
        let channel = async_channel::unbounded();

        Self {
            tune_cache: TuneCache::new(name, device_id),
            channel,
        }
    }

    /// Fetch the fastest autotune operation index for an autotune key.
    pub fn fastest(&self, key: &K) -> TuneCacheResult {
        self.tune_cache.fastest(key)
    }

    /// Fetch the fastest autotune operation index for an autotune key and validate the checksum.
    #[cfg(autotune_persistent_cache)]
    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
        self.tune_cache.validate_checksum(key, checksum)
    }

    /// Wait for async results to come in.
    pub fn resolve(&mut self) {
        // Drain every pending message without blocking; benchmarks that are
        // still running will be picked up by a later call.
        while let Ok(msg) = self.channel.1.try_recv() {
            match msg {
                AutotuneMessage::Done {
                    key,
                    fastest_index,
                    #[cfg(autotune_persistent_cache)]
                    checksum,
                } => {
                    // Record the winner in the in-memory cache first, then
                    // (when enabled) persist it to disk with its checksum.
                    self.tune_cache.cache_insert(key.clone(), fastest_index);

                    #[cfg(autotune_persistent_cache)]
                    {
                        self.tune_cache
                            .persistent_cache_insert(key, checksum, fastest_index);
                        self.tune_cache.save();
                    }
                }
                AutotuneMessage::Starting { key } => {
                    // Benchmarks for this key are in flight; mark it pending so
                    // the cache reflects that tuning has started.
                    self.tune_cache.mark_pending(key);
                }
            }
        }
    }

    /// Execute benchmarks to find out what the fastest operation is.
    pub fn execute_autotune<
        S: ComputeServer + 'static,
        C: ComputeChannel<S> + 'static,
        Out: Send + 'static,
    >(
        &self,
        set: &dyn AutotuneOperationSet<K, Out>,
        client: &ComputeClient<S, C>,
    ) {
        let key = set.key();
        log::info!("Tuning {key}");

        // Keep only the candidates the set deems worth running for this key,
        // remembering each one's original index for reporting.
        let autotunables: Vec<_> = set
            .autotunables()
            .into_iter()
            .enumerate()
            .filter(|(index, _)| set.should_run(&key, *index))
            .collect();

        let client = client.clone();
        let sender = self.channel.0.clone();

        // Fast path: with a single candidate it wins by default — report it as
        // done immediately, no benchmarking needed.
        if autotunables.len() == 1 {
            sender
                .try_send(AutotuneMessage::Done {
                    key,
                    fastest_index: autotunables[0].0,
                    #[cfg(autotune_persistent_cache)]
                    checksum: set.compute_checksum(),
                })
                .expect("Autotune results channel closed");
            return;
        }

        // Announce that tuning has started before spawning the (possibly
        // detached) benchmark task, so `resolve` can mark the key pending.
        sender
            .try_send(AutotuneMessage::Starting { key: key.clone() })
            .expect("Autotune results channel closed");

        // Compute the checksum eagerly: `set` is a borrowed trait object and
        // cannot be moved into the spawned task.
        #[cfg(autotune_persistent_cache)]
        let checksum = set.compute_checksum();

        spawn_benchmark_task(async move {
            // Per-candidate outcome: operation name, its index in the set, and
            // the benchmark statistics.
            #[derive(new, Debug)]
            struct BenchResult {
                name: String,
                index: usize,
                computation: BenchmarkComputations,
            }

            let mut bench_results = Vec::with_capacity(autotunables.len());

            for (index, op) in autotunables.into_iter() {
                let name = op.name().to_string();
                let tuner = TuneBenchmark::new(op, client.clone());

                // Catch panics from individual candidates so one failing
                // kernel doesn't abort the whole tuning run.
                let sample_fut = tuner.sample_durations();
                let sample_fut = future::catch_unwind(sample_fut);
                let result = sample_fut.await;

                let result = match result {
                    Ok(result) => result,
                    Err(err) => {
                        log::warn!(
                            "Caught unknown error while benchmarking, falling back to next operation."
                        );
                        Err(AutotuneError::PanicUnwind(ManuallyDrop::new(err)))
                    }
                };

                let result = result.map(|durations| {
                    log::info!("Name: {name} => {}", durations);
                    BenchResult::new(name, index, BenchmarkComputations::new(&durations))
                });

                bench_results.push(result);
            }

            // Panic if all tuners panicked.
            #[cfg(all(feature = "std", not(target_family = "wasm")))]
            if bench_results.iter().all(|result| result.is_err()) {
                let first_error = bench_results.into_iter().next().unwrap().err().unwrap();

                match first_error {
                    AutotuneError::Unknown(reason) => panic!("{reason}"),
                    AutotuneError::PanicUnwind(err) => {
                        // Re-raise the original panic payload.
                        resume_unwind(ManuallyDrop::into_inner(err));
                    }
                }
            }

            // Finds the fastest operation (by the median time).
            // Errored candidates sort last via Duration::MAX.
            bench_results.sort_by(|a, b| {
                let a = a
                    .as_ref()
                    .map(|r| r.computation.median)
                    .unwrap_or(Duration::MAX);
                let b = b
                    .as_ref()
                    .map(|r| r.computation.median)
                    .unwrap_or(Duration::MAX);

                a.cmp(&b)
            });

            // Log & send results.
            let result = bench_results.first().expect("At least one kernel needed. ");

            // If even the best entry is an error (possible on targets where the
            // all-panicked check above is compiled out), fall back to index 0.
            let fastest_index = if let Ok(result) = result {
                let top_times = bench_results
                    .iter()
                    .map(|r| {
                        r.as_ref()
                            .map(|r| r.computation.median)
                            .unwrap_or(Duration::MAX)
                    })
                    .take(3)
                    .collect::<Vec<_>>();
                log::info!(
                    "Fastest result {}-{key}. \n Top 3 times: {top_times:?}",
                    result.name,
                );

                result.index
            } else {
                0
            };

            sender
                .send(AutotuneMessage::Done {
                    key,
                    fastest_index,
                    #[cfg(autotune_persistent_cache)]
                    checksum,
                })
                .await
                .expect("Autotune results channel closed");
        });
    }
}
247
248fn spawn_benchmark_task(future: impl Future<Output = ()> + Send + 'static) {
249    // On wasm, spawn the tuning as a detached task.
250    #[cfg(target_family = "wasm")]
251    wasm_bindgen_futures::spawn_local(future);
252
253    // On native, it is possible to run the tuning on a thread, which could help startup times,
254    // but might have two downsides:
255    // - Benchmarks would need a "warmup" time until a good kernel is selected.
256    // - Tuning could be less precise, as it's possible that other operations are
257    //   submitted while tuning, which might skew results.
258    //
259    // So, for now, just block on the future.
260    #[cfg(not(target_family = "wasm"))]
261    future::block_on(future);
262}