cubecl_runtime/tune/
tuner.rs

1use alloc::boxed::Box;
2use alloc::format;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use async_channel::{Receiver, Sender};
6use cubecl_common::format::format_debug;
7use cubecl_common::profile::ProfileDuration;
8use hashbrown::HashSet;
9
10use core::time::Duration;
11
12use alloc::string::{String, ToString};
13use cubecl_common::benchmark::{BenchmarkComputations, BenchmarkDurations};
14
15use crate::config::{Logger, autotune::AutotuneLogLevel};
16use crate::server::LaunchError;
17use crate::tune::{TuneBenchmark, TuneCache};
18use crate::{client::ComputeClient, runtime::Runtime};
19
20use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult, TuneFn, TunePlan};
21
22#[derive(Debug)]
23/// Executes autotune benchmarking and caching
24pub struct Tuner<K: AutotuneKey> {
25    tune_cache: TuneCache<K>,
26    logger: Logger,
27    channel: (Sender<AutotuneMessage<K>>, Receiver<AutotuneMessage<K>>),
28    pub(crate) autotuning: HashSet<K>,
29}
30
31/// The measured outcome for a given autotune invocation.
32#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize, PartialEq, Eq))]
33#[derive(new, Debug, Clone)]
34pub struct AutotuneOutcome {
35    name: String,
36    index: usize,
37    computation: BenchmarkComputations,
38}
39
40impl core::fmt::Display for AutotuneOutcome {
41    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
42        write!(
43            f,
44            "Autotune[{}] name {} => {:?}",
45            self.index, self.name, self.computation
46        )
47    }
48}
49
50enum AutotuneMessage<K> {
51    Done {
52        key: K,
53        fastest_index: usize,
54        results: Vec<Result<AutotuneOutcome, AutotuneError>>,
55        #[cfg(std_io)]
56        checksum: String,
57        context_logs: Option<String>,
58    },
59    #[allow(dead_code)]
60    Pending(K),
61}
62
63/// Error from running autotune.
64#[derive(Debug, PartialEq, Eq, Clone)]
65#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
66pub enum AutotuneError {
67    /// An unknown error happened.
68    Unknown {
69        /// The name of the tunable.
70        name: String,
71        /// The unknown error,
72        err: String,
73    },
74    /// All samples are invalid.
75    InvalidSamples {
76        /// The name of the tunable.
77        name: String,
78    },
79    /// No autotune was flagged as valid for the problem.
80    ///
81    /// # Warning
82    ///
83    /// This is an unrecoverable error and will cause a panic.
84    NoValidKernelFound {
85        /// The formatted context on why no valid kernel was found.
86        context: String,
87    },
88    /// The autotune is skipped manually.
89    Skip {
90        /// The name of the skipped kernel.
91        name: String,
92    },
93
94    /// An error happened when launching a kernel.
95    Launch(LaunchError),
96}
97
98impl From<LaunchError> for AutotuneError {
99    fn from(value: LaunchError) -> Self {
100        Self::Launch(value)
101    }
102}
103
104#[allow(clippy::new_without_default)]
105impl<K: AutotuneKey> Tuner<K> {
106    /// Returns a tuner with cache initialized from persistent cache
107    pub fn new(name: &str, device_id: &str) -> Self {
108        let channel = async_channel::unbounded();
109
110        Self {
111            tune_cache: TuneCache::new(name, device_id),
112            logger: Logger::new(),
113            channel,
114            autotuning: HashSet::new(),
115        }
116    }
117
118    /// Fetch the fastest autotune operation index for an autotune key.
119    pub fn fastest(&self, key: &K) -> TuneCacheResult {
120        self.tune_cache.fastest(key)
121    }
122
123    /// Fetch the fastest autotune operation index for an autotune key and validate the checksum.
124    #[cfg(std_io)]
125    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
126        if let AutotuneLogLevel::Full = self.logger.log_level_autotune() {
127            self.logger
128                .log_autotune(&format!("validate checksum key={key}, checksum={checksum}"));
129        }
130        self.tune_cache.validate_checksum(key, checksum)
131    }
132
133    /// Handle an autotune result message, see [`execute_autotune`]
134    fn handle_result(&mut self, msg: AutotuneMessage<K>) {
135        match msg {
136            AutotuneMessage::Pending(key) => {
137                self.tune_cache.mark_pending(key);
138            }
139            AutotuneMessage::Done {
140                key,
141                fastest_index,
142                results,
143                #[cfg(std_io)]
144                checksum,
145                context_logs,
146            } => {
147                match self.logger.log_level_autotune() {
148                    AutotuneLogLevel::Minimal => {
149                        let top_times = results
150                            .iter()
151                            .map(|r| {
152                                let time = r
153                                    .as_ref()
154                                    .map(|r| r.computation.median)
155                                    .unwrap_or(Duration::MAX);
156
157                                let index = r.as_ref().map(|r| r.index).unwrap_or_default();
158                                (index, time)
159                            })
160                            .take(3)
161                            .collect::<Vec<_>>();
162
163                        let result = results
164                            .first()
165                            .expect("At least one kernel needed.")
166                            .as_ref()
167                            .expect("At least one kernel has to succeed.");
168
169                        let context = match &context_logs {
170                            Some(context) => context,
171                            None => "",
172                        };
173                        self.logger.log_autotune(&format!(
174                            "Fastest result {}-{key}. \n Top 3 times: {top_times:?}, context: {context}",
175                            result.name,
176                        ));
177                    }
178                    AutotuneLogLevel::Full => {
179                        let result = results
180                            .first()
181                            .expect("At least one kernel needed.")
182                            .as_ref()
183                            .expect("At least one kernel has to succeed.");
184
185                        let context = match &context_logs {
186                            Some(context) => context,
187                            None => "",
188                        };
189                        self.logger.log_autotune(&format!(
190                            "Fastest result {}-{key}. Context: {context}",
191                            result.name,
192                        ));
193
194                        for result in results.iter() {
195                            match result {
196                                Ok(val) => {
197                                    self.logger.log_autotune(&format!("{val}"));
198                                }
199                                Err(err) => self.logger.log_autotune(&format!("{err:?}")),
200                            }
201                        }
202                    }
203                    AutotuneLogLevel::Disabled => {}
204                };
205
206                self.tune_cache.cache_insert(key.clone(), fastest_index);
207
208                #[cfg(std_io)]
209                {
210                    self.tune_cache
211                        .persistent_cache_insert(key, checksum, fastest_index, results);
212                }
213            }
214        }
215    }
216
217    /// Check if any autotuning results have come in asynchronously.
218    pub fn handle_results(&mut self) {
219        // Handle any results that have come in. Note that execute_autotune pushes results to the channel immediately if possible.
220        // Since this function takes an &mut we know we have exclusive access, and no other threads are currently still adding results.
221        while let Ok(msg) = self.channel.1.try_recv() {
222            self.handle_result(msg);
223        }
224    }
225
226    /// Execute benchmarks to find out what the fastest operation is.
227    pub fn prepare_autotune<R: Runtime, In: Clone + Send + 'static, Out: AutotuneOutput>(
228        &self,
229        key: K,
230        inputs: &In,
231        tunables: &TunableSet<K, In, Out>,
232        client: &ComputeClient<R>,
233    ) -> Box<dyn FnOnce()> {
234        log::info!("Tuning {key}");
235
236        // Note that this message will be processed straight away by handle_results.
237        let sender = self.channel.0.clone();
238
239        let autotunables = tunables.autotunables();
240        let mut results = Vec::with_capacity(autotunables.len());
241
242        for a in autotunables.iter() {
243            results.push(Err(AutotuneError::Skip {
244                name: a.name().to_string(),
245            }));
246        }
247
248        if autotunables.len() == 1 {
249            let message = AutotuneMessage::Done {
250                key,
251                fastest_index: 0,
252                results,
253                #[cfg(std_io)]
254                checksum: tunables.compute_checksum(),
255                context_logs: None,
256            };
257
258            return Box::new(move || {
259                sender
260                    .try_send(message)
261                    .expect("Loss message channel somehow")
262            });
263        }
264
265        let client = client.clone();
266        let key_cloned = key.clone();
267        let plan = tunables.plan(&key);
268        let inputs_generator = tunables.inputs_generator(&key.clone(), inputs);
269
270        #[cfg(std_io)]
271        let checksum = tunables.compute_checksum();
272        let context_logs = match self.logger.log_level_autotune() {
273            AutotuneLogLevel::Disabled => false,
274            AutotuneLogLevel::Minimal => false,
275            AutotuneLogLevel::Full => true,
276        };
277
278        let fut_result = async move {
279            let test_inputs = inputs_generator();
280
281            Self::generate_tune_message(
282                key_cloned,
283                &client,
284                plan,
285                autotunables,
286                test_inputs,
287                results,
288                #[cfg(std_io)]
289                checksum,
290                context_logs,
291            )
292            .await
293        };
294
295        Box::new(move || {
296            let message = {
297                cfg_if::cfg_if! {
298                    if #[cfg(target_family = "wasm")] {
299                        let sender = sender.clone();
300
301                        let send_fut = async move {
302                            // If the channel has been closed, ignore. Maybe the main app is exiting
303                            // before the tune results come in.
304                            let _ = sender.send(fut_result.await).await;
305                        };
306                        // On wasm, spawn the tuning as a detached task.
307                        wasm_bindgen_futures::spawn_local(send_fut);
308                        // Mark the current tuning as pending.
309                        AutotuneMessage::Pending(key)
310                    } else {
311                        cubecl_common::future::block_on(fut_result)
312                    }
313                }
314            };
315
316            // Note that this message will be processed straight away by handle_results.
317            sender
318                .try_send(message)
319                .expect("Loss message channel somehow");
320        })
321    }
322
323    #[allow(clippy::too_many_arguments)]
324    async fn generate_tune_message<In: Clone + Send + 'static, Out: AutotuneOutput, R: Runtime>(
325        key: K,
326        client: &ComputeClient<R>,
327        mut plan: TunePlan,
328        autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
329        test_inputs: In,
330        mut results: Vec<Result<AutotuneOutcome, AutotuneError>>,
331        #[cfg(std_io)] checksum: String,
332        context_logs: bool,
333    ) -> AutotuneMessage<K> {
334        let context_logs = match Self::execute_tune_plan(
335            client,
336            &mut plan,
337            autotunables,
338            &test_inputs,
339            &mut results,
340            context_logs,
341        )
342        .await
343        {
344            Ok(context_logs) => context_logs,
345            Err(err) => {
346                panic!("Can't execute the autotune plan for key: {key:?}\n - Error: {err:?}");
347            }
348        };
349
350        // Finds the fastest operation (by the median time).
351        results.sort_by(|a, b| {
352            let a = a
353                .as_ref()
354                .map(|r| r.computation.median)
355                .unwrap_or(Duration::MAX);
356            let b = b
357                .as_ref()
358                .map(|r| r.computation.median)
359                .unwrap_or(Duration::MAX);
360
361            a.cmp(&b)
362        });
363
364        // Log & send results.
365        let result = results
366            .first()
367            .expect("At least one kernel needed.")
368            .as_ref()
369            .expect("At least one kernel has to succeed.");
370
371        AutotuneMessage::Done {
372            key,
373            fastest_index: result.index,
374            results,
375            #[cfg(std_io)]
376            checksum,
377            context_logs,
378        }
379    }
380
381    async fn execute_tune_plan<In: Clone + Send + 'static, Out: AutotuneOutput, R: Runtime>(
382        client: &ComputeClient<R>,
383        plan: &mut TunePlan,
384        autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
385        test_inputs: &In,
386        results: &mut [Result<AutotuneOutcome, AutotuneError>],
387        context_logs: bool,
388    ) -> Result<Option<String>, AutotuneError> {
389        #[derive(Debug)]
390        #[allow(unused_variables, dead_code)] // Only use for debug
391        struct Context<'a> {
392            plan: &'a TunePlan,
393            results: &'a [Result<AutotuneOutcome, AutotuneError>],
394        }
395
396        let mut context_logs = match context_logs {
397            true => Some("".to_string()),
398            false => None,
399        };
400
401        loop {
402            let mut num_success = 0;
403            let tunable_indices = plan.next(context_logs.as_mut());
404
405            if tunable_indices.is_empty() {
406                return Err(AutotuneError::NoValidKernelFound {
407                    context: format_debug(&Context { plan, results }),
408                });
409            }
410
411            for index in tunable_indices {
412                let op = &autotunables[index];
413                let name = op.name().to_string();
414                let tuner = TuneBenchmark::new(op.clone(), test_inputs.clone(), client.clone());
415                let profiles = tuner.profile().map(|bench| (name, index, bench));
416
417                match profiles {
418                    Ok(result) => {
419                        // Wait for the results to come in, and determine the outcome.
420                        let (name, index, profiles) = result;
421                        let result = Self::process_autotune(name, index, profiles).await;
422                        match result {
423                            Ok(val) => {
424                                results[index] = Ok(val);
425                                num_success += 1;
426                            }
427                            Err(err) => {
428                                results[index] = Err(err);
429                            }
430                        }
431                    }
432                    Err(err) => {
433                        results[index] = Err(err);
434                    }
435                }
436            }
437
438            if num_success > 0 {
439                break;
440            }
441        }
442
443        Ok(context_logs)
444    }
445
446    async fn process_autotune(
447        name: String,
448        index: usize,
449        profiles: Vec<ProfileDuration>,
450    ) -> Result<AutotuneOutcome, AutotuneError> {
451        let mut durations = Vec::new();
452        if !profiles.is_empty() {
453            let timing_method = profiles.first().unwrap().timing_method();
454            for profile in profiles {
455                durations.push(profile.resolve().await.duration());
456            }
457            let bench_durations = BenchmarkDurations::from_durations(timing_method, durations);
458
459            Ok(AutotuneOutcome::new(
460                name,
461                index,
462                BenchmarkComputations::new(&bench_durations),
463            ))
464        } else {
465            Err(AutotuneError::Unknown {
466                name,
467                err: "No profiling available".to_string(),
468            })
469        }
470    }
471}
472
/// Compare every autotune output against a reference output.
///
/// The last entry of `checks_outputs` is the reference; every other successful output is
/// passed to the reference's `check_equivalence`. Failed outputs are ignored. Nothing is
/// checked when the reference itself failed or when `checks_outputs` is empty.
#[cfg(feature = "autotune-checks")]
pub(crate) fn check_autotune_outputs<O: AutotuneOutput>(
    mut checks_outputs: Vec<Result<O, AutotuneError>>,
) {
    // `pop` takes the last element (the reference) and, unlike `remove(len() - 1)`,
    // does not underflow and panic on an empty vector.
    if let Some(Ok(reference)) = checks_outputs.pop() {
        for other in checks_outputs.into_iter().flatten() {
            reference.check_equivalence(other);
        }
    }
}