Skip to main content

cubecl_runtime/tune/
tuner.rs

1use alloc::boxed::Box;
2use alloc::format;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use async_channel::{Receiver, Sender};
6use cubecl_common::profile::ProfileDuration;
7use hashbrown::HashSet;
8
9use core::time::Duration;
10
11use alloc::string::{String, ToString};
12use cubecl_common::benchmark::{BenchmarkComputations, BenchmarkDurations};
13
14use crate::config::{Logger, autotune::AutotuneLogLevel};
15use crate::server::LaunchError;
16use crate::tune::{AutotuneResult, TuneBenchmark, TuneCache};
17use crate::{client::ComputeClient, runtime::Runtime};
18
19use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult, TuneFn, TunePlan};
20
#[derive(Debug)]
/// Executes autotune benchmarking and caching.
///
/// Benchmark results travel through an internal channel and are applied to the cache by
/// [`Tuner::handle_results`].
pub struct Tuner<K: AutotuneKey> {
    /// Cache mapping each key to the index of its fastest tunable.
    tune_cache: TuneCache<K>,
    /// Logger used for autotune messages; its level decides how much is logged.
    logger: Logger,
    /// Sender/receiver pair for autotune messages. The sender is cloned into the closures
    /// returned by `prepare_autotune`; the receiver is drained by `handle_results`.
    channel: (Sender<AutotuneMessage<K>>, Receiver<AutotuneMessage<K>>),
    /// Not read in this file; presumably the set of keys currently being tuned — TODO confirm.
    pub(crate) autotuning: HashSet<K>,
}
29
/// The measured outcome for a given autotune invocation.
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
#[derive(new, Debug, Clone, PartialEq, Eq)]
pub struct AutotuneOutcome {
    /// Name of the benchmarked tunable.
    name: String,
    /// Index of the tunable in the tunable set's list of autotunables.
    index: usize,
    /// Statistics computed from the resolved benchmark durations.
    computation: BenchmarkComputations,
}
38
39impl core::fmt::Display for AutotuneOutcome {
40    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
41        write!(
42            f,
43            "Autotune[{}] name {} => {:?}",
44            self.index, self.name, self.computation
45        )
46    }
47}
48
49enum AutotuneMessage<K> {
50    Done {
51        key: K,
52        fastest_index: usize,
53        results: Vec<AutotuneResult>,
54        #[cfg(std_io)]
55        checksum: String,
56        context_logs: Option<String>,
57    },
58    #[allow(dead_code)]
59    Pending(K),
60}
61
/// Error from running autotune.
#[derive(Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum AutotuneError {
    /// An unknown error happened.
    Unknown {
        /// The name of the tunable.
        name: String,
        /// A description of the error.
        err: String,
    },
    /// All samples are invalid.
    InvalidSamples {
        /// The name of the tunable.
        name: String,
    },
    /// No autotune was flagged as valid for the problem.
    ///
    /// # Warning
    ///
    /// This is an unrecoverable error and will cause a panic.
    NoValidKernelFound {
        /// The formatted context on why no valid kernel was found.
        context: String,
    },
    /// The autotune is skipped manually.
    ///
    /// The tuner also uses this variant as the initial placeholder result for every
    /// tunable, until that tunable is benchmarked.
    Skip {
        /// The name of the skipped kernel.
        name: String,
    },

    /// An error happened when launching a kernel.
    Launch(LaunchError),
}
96
97impl From<LaunchError> for AutotuneError {
98    fn from(value: LaunchError) -> Self {
99        Self::Launch(value)
100    }
101}
102
103#[allow(clippy::new_without_default)]
104impl<K: AutotuneKey> Tuner<K> {
105    /// Returns a tuner with cache initialized from persistent cache
106    pub fn new(name: &str, device_id: &str) -> Self {
107        let channel = async_channel::unbounded();
108
109        Self {
110            tune_cache: TuneCache::new(name, device_id),
111            logger: Logger::new(),
112            channel,
113            autotuning: HashSet::new(),
114        }
115    }
116
117    /// Fetch the fastest autotune operation index for an autotune key.
118    pub fn fastest(&self, key: &K) -> TuneCacheResult {
119        self.tune_cache.fastest(key)
120    }
121
122    /// Fetch the fastest autotune operation index for an autotune key and validate the checksum.
123    #[cfg(std_io)]
124    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
125        if let AutotuneLogLevel::Full = self.logger.log_level_autotune() {
126            self.logger
127                .log_autotune(&format!("validate checksum key={key}, checksum={checksum}"));
128        }
129        self.tune_cache.validate_checksum(key, checksum)
130    }
131
132    /// Handle an autotune result message, see [`execute_autotune`]
133    fn handle_result(&mut self, msg: AutotuneMessage<K>) {
134        match msg {
135            AutotuneMessage::Pending(key) => {
136                self.tune_cache.mark_pending(key);
137            }
138            AutotuneMessage::Done {
139                key,
140                fastest_index,
141                results,
142                #[cfg(std_io)]
143                checksum,
144                context_logs,
145            } => {
146                match self.logger.log_level_autotune() {
147                    AutotuneLogLevel::Minimal => {
148                        let top_times = results
149                            .iter()
150                            .map(|r| {
151                                let time = r
152                                    .outcome
153                                    .as_ref()
154                                    .map(|r| r.computation.median)
155                                    .unwrap_or(Duration::MAX);
156
157                                let index = r.outcome.as_ref().map(|r| r.index).unwrap_or_default();
158                                (index, time)
159                            })
160                            .take(3)
161                            .collect::<Vec<_>>();
162
163                        let result = results
164                            .first()
165                            .expect("At least one kernel needed.")
166                            .outcome
167                            .as_ref()
168                            .expect("At least one kernel has to succeed.");
169
170                        let context = match &context_logs {
171                            Some(context) => context,
172                            None => "",
173                        };
174                        self.logger.log_autotune(&format!(
175                            "Fastest result {}-{key}. \n Top 3 times: {top_times:?}, context: {context}",
176                            result.name,
177                        ));
178                    }
179                    AutotuneLogLevel::Full => {
180                        let result = results
181                            .first()
182                            .expect("At least one kernel needed.")
183                            .outcome
184                            .as_ref()
185                            .expect("At least one kernel has to succeed.");
186
187                        let context = match &context_logs {
188                            Some(context) => context,
189                            None => "",
190                        };
191                        self.logger.log_autotune(&format!(
192                            "Fastest result {}-{key}. Context: {context}",
193                            result.name,
194                        ));
195
196                        for result in results.iter() {
197                            match &result.outcome {
198                                Ok(val) => {
199                                    self.logger.log_autotune(&format!("{val}"));
200                                }
201                                Err(err) => self.logger.log_autotune(&format!("{err:?}")),
202                            }
203                        }
204                    }
205                    AutotuneLogLevel::Disabled => {}
206                };
207
208                self.tune_cache.cache_insert(key.clone(), fastest_index);
209
210                #[cfg(std_io)]
211                {
212                    self.tune_cache
213                        .persistent_cache_insert(key, checksum, fastest_index, results);
214                }
215            }
216        }
217    }
218
219    /// Check if any autotuning results have come in asynchronously.
220    pub fn handle_results(&mut self) {
221        // Handle any results that have come in. Note that execute_autotune pushes results to the channel immediately if possible.
222        // Since this function takes an &mut we know we have exclusive access, and no other threads are currently still adding results.
223        while let Ok(msg) = self.channel.1.try_recv() {
224            self.handle_result(msg);
225        }
226    }
227
228    /// Execute benchmarks to find out what the fastest operation is.
229    pub fn prepare_autotune<R: Runtime, In: Clone + Send + 'static, Out: AutotuneOutput>(
230        &self,
231        key: K,
232        inputs: &In,
233        tunables: &TunableSet<K, In, Out>,
234        client: &ComputeClient<R>,
235    ) -> Box<dyn FnOnce()> {
236        log::info!("Tuning {key}");
237
238        // Note that this message will be processed straight away by handle_results.
239        let sender = self.channel.0.clone();
240
241        let autotunables = tunables.autotunables();
242        let mut results: Vec<AutotuneResult> = Vec::with_capacity(autotunables.len());
243
244        for a in autotunables.iter() {
245            results.push(AutotuneResult::error(AutotuneError::Skip {
246                name: a.name().to_string(),
247            }));
248        }
249
250        if autotunables.len() == 1 {
251            let message = AutotuneMessage::Done {
252                key,
253                fastest_index: 0,
254                results,
255                #[cfg(std_io)]
256                checksum: tunables.compute_checksum(),
257                context_logs: None,
258            };
259
260            return Box::new(move || {
261                sender
262                    .try_send(message)
263                    .expect("Loss message channel somehow")
264            });
265        }
266
267        let client = client.clone();
268        let key_cloned = key.clone();
269        let plan = tunables.plan(&key);
270        let inputs_generator = tunables.inputs_generator(&key.clone(), inputs);
271
272        #[cfg(std_io)]
273        let checksum = tunables.compute_checksum();
274        let context_logs = match self.logger.log_level_autotune() {
275            AutotuneLogLevel::Disabled => false,
276            AutotuneLogLevel::Minimal => false,
277            AutotuneLogLevel::Full => true,
278        };
279
280        let fut_result = async move {
281            let test_inputs = inputs_generator();
282
283            Self::generate_tune_message(
284                key_cloned,
285                &client,
286                plan,
287                autotunables,
288                test_inputs,
289                results,
290                #[cfg(std_io)]
291                checksum,
292                context_logs,
293            )
294            .await
295        };
296
297        Box::new(move || {
298            let message = {
299                cfg_if::cfg_if! {
300                    if #[cfg(target_family = "wasm")] {
301                        let sender = sender.clone();
302
303                        let send_fut = async move {
304                            // If the channel has been closed, ignore. Maybe the main app is exiting
305                            // before the tune results come in.
306                            let _ = sender.send(fut_result.await).await;
307                        };
308                        // On wasm, spawn the tuning as a detached task.
309                        wasm_bindgen_futures::spawn_local(send_fut);
310                        // Mark the current tuning as pending.
311                        AutotuneMessage::Pending(key)
312                    } else {
313                        cubecl_common::future::block_on(fut_result)
314                    }
315                }
316            };
317
318            // Note that this message will be processed straight away by handle_results.
319            sender
320                .try_send(message)
321                .expect("Loss message channel somehow");
322        })
323    }
324
325    #[allow(clippy::too_many_arguments)]
326    async fn generate_tune_message<In: Clone + Send + 'static, Out: AutotuneOutput, R: Runtime>(
327        key: K,
328        client: &ComputeClient<R>,
329        mut plan: TunePlan,
330        autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
331        test_inputs: In,
332        mut results: Vec<AutotuneResult>,
333        #[cfg(std_io)] checksum: String,
334        context_logs: bool,
335    ) -> AutotuneMessage<K> {
336        let context_logs = match Self::execute_tune_plan(
337            client,
338            &mut plan,
339            autotunables,
340            &test_inputs,
341            &mut results,
342            context_logs,
343        )
344        .await
345        {
346            Ok(context_logs) => context_logs,
347            Err(err) => {
348                panic!("Can't execute the autotune plan for key: {key:?}\n - Error: {err:?}");
349            }
350        };
351
352        // Finds the fastest operation.
353        results.sort_by(|a, b| {
354            let a = a
355                .outcome
356                .as_ref()
357                .map(|r| r.computation.score())
358                .unwrap_or(u64::MAX);
359            let b = b
360                .outcome
361                .as_ref()
362                .map(|r| r.computation.score())
363                .unwrap_or(u64::MAX);
364
365            a.cmp(&b)
366        });
367
368        // Log & send results.
369        let result = results
370            .first()
371            .expect("At least one kernel needed.")
372            .outcome
373            .as_ref()
374            .expect("At least one kernel has to succeed.");
375
376        AutotuneMessage::Done {
377            key,
378            fastest_index: result.index,
379            results,
380            #[cfg(std_io)]
381            checksum,
382            context_logs,
383        }
384    }
385
386    async fn execute_tune_plan<In: Clone + Send + 'static, Out: AutotuneOutput, R: Runtime>(
387        client: &ComputeClient<R>,
388        plan: &mut TunePlan,
389        autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
390        test_inputs: &In,
391        results: &mut [AutotuneResult],
392        context_logs: bool,
393    ) -> Result<Option<String>, AutotuneError> {
394        #[derive(Debug)]
395        #[allow(unused_variables, dead_code)] // Only use for debug
396        struct Context<'a> {
397            plan: &'a TunePlan,
398            results: &'a [AutotuneResult],
399        }
400
401        let mut context_logs = match context_logs {
402            true => Some("".to_string()),
403            false => None,
404        };
405
406        loop {
407            let mut num_success = 0;
408            let tunable_indices = plan.next(context_logs.as_mut());
409
410            if tunable_indices.is_empty() {
411                return Err(AutotuneError::NoValidKernelFound {
412                    context: format!("{:?}", &Context { plan, results }),
413                });
414            }
415
416            for index in tunable_indices {
417                let op = &autotunables[index];
418                let name = op.name().to_string();
419                let tuner = TuneBenchmark::new(op.clone(), test_inputs.clone(), client.clone());
420                let profiles = tuner.profile().map(|bench| (name, index, bench));
421
422                match profiles {
423                    Ok(result) => {
424                        // Wait for the results to come in, and determine the outcome.
425                        let (name, index, profiles) = result;
426                        let result = Self::process_autotune(name, index, profiles).await;
427                        match result {
428                            Ok(val) => {
429                                results[index] = AutotuneResult::success(val);
430                                num_success += 1;
431                            }
432                            Err(err) => {
433                                results[index] = AutotuneResult::error(err);
434                            }
435                        }
436                    }
437                    Err(err) => {
438                        results[index] = AutotuneResult::error(err);
439                    }
440                }
441            }
442
443            if num_success > 0 {
444                break;
445            }
446        }
447
448        Ok(context_logs)
449    }
450
451    async fn process_autotune(
452        name: String,
453        index: usize,
454        profiles: Vec<ProfileDuration>,
455    ) -> Result<AutotuneOutcome, AutotuneError> {
456        let mut durations = Vec::new();
457        if !profiles.is_empty() {
458            let timing_method = profiles.first().unwrap().timing_method();
459            for profile in profiles {
460                durations.push(profile.resolve().await.duration());
461            }
462            let bench_durations = BenchmarkDurations::from_durations(timing_method, durations);
463
464            Ok(AutotuneOutcome::new(
465                name,
466                index,
467                BenchmarkComputations::new(&bench_durations),
468            ))
469        } else {
470            Err(AutotuneError::Unknown {
471                name,
472                err: "No profiling available".to_string(),
473            })
474        }
475    }
476}
477
/// Checks that every successful autotune output is equivalent to the reference output.
///
/// The last entry of `checks_outputs` is the reference. It is compared, via
/// `AutotuneOutput::check_equivalence`, against every other entry that is `Ok`; failed
/// entries are skipped. Nothing is checked when the reference failed or when
/// `checks_outputs` is empty (previously an empty vector underflowed `len() - 1` and panicked).
#[cfg(feature = "autotune-checks")]
pub(crate) fn check_autotune_outputs<O: AutotuneOutput>(
    mut checks_outputs: Vec<Result<O, AutotuneError>>,
) {
    if let Some(Ok(reference)) = checks_outputs.pop() {
        for other in checks_outputs.into_iter().flatten() {
            reference.check_equivalence(other);
        }
    }
}