cubecl_runtime/tune/
tuner.rs

1use alloc::boxed::Box;
2use alloc::format;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use async_channel::{Receiver, Sender};
6use cubecl_common::format::format_debug;
7use cubecl_common::profile::ProfileDuration;
8use hashbrown::HashSet;
9
10use core::time::Duration;
11
12use alloc::string::{String, ToString};
13use cubecl_common::benchmark::{BenchmarkComputations, BenchmarkDurations};
14
15use crate::client::ComputeClient;
16use crate::config::{Logger, autotune::AutotuneLogLevel};
17use crate::server::ComputeServer;
18use crate::tune::{TuneBenchmark, TuneCache};
19
20use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult, TuneFn, TunePlan};
21
22#[derive(Debug)]
23/// Executes autotune benchmarking and caching
24pub struct Tuner<K: AutotuneKey> {
25    tune_cache: TuneCache<K>,
26    logger: Logger,
27    channel: (Sender<AutotuneMessage<K>>, Receiver<AutotuneMessage<K>>),
28    pub(crate) autotuning: HashSet<K>,
29}
30
31/// The measured outcome for a given autotune invocation.
32#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize, PartialEq, Eq))]
33#[derive(new, Debug, Clone)]
34pub struct AutotuneOutcome {
35    name: String,
36    index: usize,
37    computation: BenchmarkComputations,
38}
39
40impl core::fmt::Display for AutotuneOutcome {
41    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
42        write!(
43            f,
44            "Autotune[{}] name {} => {:?}",
45            self.index, self.name, self.computation
46        )
47    }
48}
49
50enum AutotuneMessage<K> {
51    Done {
52        key: K,
53        fastest_index: usize,
54        results: Vec<Result<AutotuneOutcome, AutotuneError>>,
55        #[cfg(std_io)]
56        checksum: String,
57    },
58    #[allow(dead_code)]
59    Pending(K),
60}
61
/// Error from running autotune.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum AutotuneError {
    /// An unknown error happened; the payload is a human-readable description.
    Unknown(String),
    /// All samples are invalid.
    InvalidSamples,
    /// No autotune was flagged as valid for the problem.
    ///
    /// # Warning
    ///
    /// This is an unrecoverable error and will cause a panic.
    NoValidKernelFound {
        /// Formatted (debug) context explaining why no valid kernel was found.
        context: String,
    },
    /// The tunable was not benchmarked.
    ///
    /// Besides manual skips, the tuner uses this as the initial value for every tunable, so
    /// tunables the tune plan never reaches keep this value.
    Skip,
}
82
83impl From<String> for AutotuneError {
84    fn from(value: String) -> Self {
85        Self::Unknown(value)
86    }
87}
88
89#[allow(clippy::new_without_default)]
90impl<K: AutotuneKey> Tuner<K> {
91    /// Returns a tuner with cache initialized from persistent cache
92    pub fn new(name: &str, device_id: &str) -> Self {
93        let channel = async_channel::unbounded();
94
95        Self {
96            tune_cache: TuneCache::new(name, device_id),
97            logger: Logger::new(),
98            channel,
99            autotuning: HashSet::new(),
100        }
101    }
102
103    /// Fetch the fastest autotune operation index for an autotune key.
104    pub fn fastest(&self, key: &K) -> TuneCacheResult {
105        self.tune_cache.fastest(key)
106    }
107
108    /// Fetch the fastest autotune operation index for an autotune key and validate the checksum.
109    #[cfg(std_io)]
110    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
111        if let AutotuneLogLevel::Full = self.logger.log_level_autotune() {
112            self.logger
113                .log_autotune(&format!("validate checksum key={key}, checksum={checksum}"));
114        }
115        self.tune_cache.validate_checksum(key, checksum)
116    }
117
118    /// Handle an autotune result message, see [`execute_autotune`]
119    fn handle_result(&mut self, msg: AutotuneMessage<K>) {
120        match msg {
121            AutotuneMessage::Pending(key) => {
122                self.tune_cache.mark_pending(key);
123            }
124            AutotuneMessage::Done {
125                key,
126                fastest_index,
127                results,
128                #[cfg(std_io)]
129                checksum,
130            } => {
131                match self.logger.log_level_autotune() {
132                    AutotuneLogLevel::Minimal => {
133                        let top_times = results
134                            .iter()
135                            .map(|r| {
136                                let time = r
137                                    .as_ref()
138                                    .map(|r| r.computation.median)
139                                    .unwrap_or(Duration::MAX);
140
141                                let index = r.as_ref().map(|r| r.index).unwrap_or_default();
142                                (index, time)
143                            })
144                            .take(3)
145                            .collect::<Vec<_>>();
146
147                        let result = results
148                            .first()
149                            .expect("At least one kernel needed.")
150                            .as_ref()
151                            .expect("At least one kernel has to succeed.");
152
153                        self.logger.log_autotune(&format!(
154                            "Fastest result {}-{key}. \n Top 3 times: {top_times:?}",
155                            result.name,
156                        ));
157                    }
158                    AutotuneLogLevel::Full => {
159                        let result = results
160                            .first()
161                            .expect("At least one kernel needed.")
162                            .as_ref()
163                            .expect("At least one kernel has to succeed.");
164
165                        self.logger
166                            .log_autotune(&format!("Fastest result {}-{key}.", result.name,));
167
168                        for result in results.iter() {
169                            match result {
170                                Ok(val) => {
171                                    self.logger.log_autotune(&format!("{val}"));
172                                }
173                                Err(err) => self.logger.log_autotune(&format!("{err:?}")),
174                            }
175                        }
176                    }
177                    AutotuneLogLevel::Disabled => {}
178                };
179
180                self.tune_cache.cache_insert(key.clone(), fastest_index);
181
182                #[cfg(std_io)]
183                {
184                    self.tune_cache
185                        .persistent_cache_insert(key, checksum, fastest_index, results);
186                }
187            }
188        }
189    }
190
191    /// Check if any autotuning results have come in asynchronously.
192    pub fn handle_results(&mut self) {
193        // Handle any results that have come in. Note that execute_autotune pushes results to the channel immediately if possible.
194        // Since this function takes an &mut we know we have exclusive access, and no other threads are currently still adding results.
195        while let Ok(msg) = self.channel.1.try_recv() {
196            self.handle_result(msg);
197        }
198    }
199
200    /// Execute benchmarks to find out what the fastest operation is.
201    pub fn prepare_autotune<
202        S: ComputeServer + 'static,
203        In: Clone + Send + 'static,
204        Out: AutotuneOutput,
205    >(
206        &self,
207        key: K,
208        inputs: &In,
209        tunables: &TunableSet<K, In, Out>,
210        client: &ComputeClient<S>,
211    ) -> Box<dyn FnOnce()> {
212        log::info!("Tuning {key}");
213
214        // Note that this message will be processed straight away by handle_results.
215        let sender = self.channel.0.clone();
216
217        let autotunables = tunables.autotunables();
218        let mut results = Vec::with_capacity(autotunables.len());
219
220        for _ in 0..autotunables.len() {
221            results.push(Err(AutotuneError::Skip));
222        }
223
224        if autotunables.len() == 1 {
225            let message = AutotuneMessage::Done {
226                key,
227                fastest_index: 0,
228                results,
229                #[cfg(std_io)]
230                checksum: tunables.compute_checksum(),
231            };
232
233            return Box::new(move || {
234                sender
235                    .try_send(message)
236                    .expect("Loss message channel somehow")
237            });
238        }
239
240        let client = client.clone();
241        let key_cloned = key.clone();
242        let plan = tunables.plan(&key);
243        let inputs_generator = tunables.inputs_generator(&key.clone(), inputs);
244
245        #[cfg(std_io)]
246        let checksum = tunables.compute_checksum();
247
248        let fut_result = async move {
249            let test_inputs = inputs_generator();
250
251            Self::generate_tune_message(
252                key_cloned,
253                &client,
254                plan,
255                autotunables,
256                test_inputs,
257                results,
258                #[cfg(std_io)]
259                checksum,
260            )
261            .await
262        };
263
264        Box::new(move || {
265            let message = {
266                cfg_if::cfg_if! {
267                    if #[cfg(target_family = "wasm")] {
268                        let sender = sender.clone();
269
270                        let send_fut = async move {
271                            // If the channel has been closed, ignore. Maybe the main app is exiting
272                            // before the tune results come in.
273                            let _ = sender.send(fut_result.await).await;
274                        };
275                        // On wasm, spawn the tuning as a detached task.
276                        wasm_bindgen_futures::spawn_local(send_fut);
277                        // Mark the current tuning as pending.
278                        AutotuneMessage::Pending(key)
279                    } else {
280                        cubecl_common::future::block_on(fut_result)
281                    }
282                }
283            };
284
285            // Note that this message will be processed straight away by handle_results.
286            sender
287                .try_send(message)
288                .expect("Loss message channel somehow");
289        })
290    }
291
292    async fn generate_tune_message<
293        In: Clone + Send + 'static,
294        Out: AutotuneOutput,
295        S: ComputeServer + 'static,
296    >(
297        key: K,
298        client: &ComputeClient<S>,
299        mut plan: TunePlan,
300        autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
301        test_inputs: In,
302        mut results: Vec<Result<AutotuneOutcome, AutotuneError>>,
303        #[cfg(std_io)] checksum: String,
304    ) -> AutotuneMessage<K> {
305        match Self::execute_tune_plan(client, &mut plan, autotunables, &test_inputs, &mut results)
306            .await
307        {
308            Ok(_) => {}
309            Err(err) => {
310                panic!("Can't execute the autotune plan for key: {key:?}\n - Error: {err:?}");
311            }
312        };
313
314        // Finds the fastest operation (by the median time).
315        results.sort_by(|a, b| {
316            let a = a
317                .as_ref()
318                .map(|r| r.computation.median)
319                .unwrap_or(Duration::MAX);
320            let b = b
321                .as_ref()
322                .map(|r| r.computation.median)
323                .unwrap_or(Duration::MAX);
324
325            a.cmp(&b)
326        });
327
328        // Log & send results.
329        let result = results
330            .first()
331            .expect("At least one kernel needed.")
332            .as_ref()
333            .expect("At least one kernel has to succeed.");
334
335        AutotuneMessage::Done {
336            key,
337            fastest_index: result.index,
338            results,
339            #[cfg(std_io)]
340            checksum,
341        }
342    }
343
344    async fn execute_tune_plan<
345        In: Clone + Send + 'static,
346        Out: AutotuneOutput,
347        S: ComputeServer + 'static,
348    >(
349        client: &ComputeClient<S>,
350        plan: &mut TunePlan,
351        autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
352        test_inputs: &In,
353        results: &mut [Result<AutotuneOutcome, AutotuneError>],
354    ) -> Result<(), AutotuneError> {
355        #[derive(Debug)]
356        #[allow(unused_variables, dead_code)] // Only use for debug
357        struct Context<'a> {
358            plan: &'a TunePlan,
359            results: &'a [Result<AutotuneOutcome, AutotuneError>],
360        }
361
362        loop {
363            let mut num_success = 0;
364            let tunable_indices = plan.next();
365
366            if tunable_indices.is_empty() {
367                return Err(AutotuneError::NoValidKernelFound {
368                    context: format_debug(&Context { plan, results }),
369                });
370            }
371
372            for index in tunable_indices {
373                let op = &autotunables[index];
374                let name = op.name().to_string();
375                let tuner = TuneBenchmark::new(op.clone(), test_inputs.clone(), client.clone());
376                let profiles = tuner.profile().map(|bench| (name, index, bench));
377
378                match profiles {
379                    Ok(result) => {
380                        // Wait for the results to come in, and determine the outcome.
381                        let (name, index, profiles) = result;
382                        let result = Self::process_autotune(name, index, profiles).await;
383                        match result {
384                            Ok(val) => {
385                                results[index] = Ok(val);
386                                num_success += 1;
387                            }
388                            Err(err) => {
389                                results[index] = Err(err);
390                            }
391                        }
392                    }
393                    Err(err) => {
394                        results[index] = Err(err);
395                    }
396                }
397            }
398
399            if num_success > 0 {
400                break;
401            }
402        }
403
404        Ok(())
405    }
406
407    async fn process_autotune(
408        name: String,
409        index: usize,
410        profiles: Vec<ProfileDuration>,
411    ) -> Result<AutotuneOutcome, AutotuneError> {
412        let mut durations = Vec::new();
413        if !profiles.is_empty() {
414            let timing_method = profiles.first().unwrap().timing_method();
415            for profile in profiles {
416                durations.push(profile.resolve().await.duration());
417            }
418            let bench_durations = BenchmarkDurations::from_durations(timing_method, durations);
419
420            Ok(AutotuneOutcome::new(
421                name,
422                index,
423                BenchmarkComputations::new(&bench_durations),
424            ))
425        } else {
426            Err(AutotuneError::Unknown(format!(
427                "Runtime error while profiling {name}."
428            )))
429        }
430    }
431}
432
/// Compare every successful autotune output against a reference output.
///
/// The reference is the last entry of `checks_outputs`. The function calls
/// `check_equivalence` on the reference for each other `Ok` output; `Err` outputs are ignored.
/// If the vector is empty or the reference itself is an error, nothing is checked.
#[cfg(feature = "autotune-checks")]
pub(crate) fn check_autotune_outputs<O: AutotuneOutput>(
    mut checks_outputs: Vec<Result<O, AutotuneError>>,
) {
    // `pop` takes the last element and, unlike `remove(len() - 1)`, does not underflow and
    // panic on an empty vector.
    if let Some(Ok(reference)) = checks_outputs.pop() {
        for other in checks_outputs.into_iter().flatten() {
            reference.check_equivalence(other);
        }
    }
}