// cubecl_runtime/tune/tuner.rs

1use alloc::boxed::Box;
2use alloc::format;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use async_channel::{Receiver, Sender};
6use cubecl_common::format::format_debug;
7use cubecl_common::profile::ProfileDuration;
8use hashbrown::HashSet;
9
10use core::time::Duration;
11
12use alloc::string::{String, ToString};
13use cubecl_common::benchmark::{BenchmarkComputations, BenchmarkDurations};
14
15use crate::config::{Logger, autotune::AutotuneLogLevel};
16use crate::server::LaunchError;
17use crate::tune::{AutotuneResult, TuneBenchmark, TuneCache};
18use crate::{client::ComputeClient, runtime::Runtime};
19
20use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult, TuneFn, TunePlan};
21
#[derive(Debug)]
/// Executes autotune benchmarking and caching
pub struct Tuner<K: AutotuneKey> {
    // Cache of the fastest tunable index per key (with optional persistent backing).
    tune_cache: TuneCache<K>,
    // Logger for autotune reporting, driven by the configured autotune log level.
    logger: Logger,
    // Result channel: benchmark tasks push `AutotuneMessage`s on the sender,
    // `handle_results` drains the receiver on the tuner side.
    channel: (Sender<AutotuneMessage<K>>, Receiver<AutotuneMessage<K>>),
    // Keys with autotuning in flight.
    // NOTE(review): not mutated in this file — maintained by crate-internal callers; confirm semantics there.
    pub(crate) autotuning: HashSet<K>,
}
30
/// The measured outcome for a given autotune invocation.
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
#[derive(new, Debug, Clone, PartialEq, Eq)]
pub struct AutotuneOutcome {
    // Name of the tunable that was benchmarked.
    name: String,
    // Index of the tunable within its tunable set.
    index: usize,
    // Benchmark statistics (median, etc.) computed from the profiled durations.
    computation: BenchmarkComputations,
}
39
40impl core::fmt::Display for AutotuneOutcome {
41    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
42        write!(
43            f,
44            "Autotune[{}] name {} => {:?}",
45            self.index, self.name, self.computation
46        )
47    }
48}
49
/// Message sent over the tuner's channel once an autotune run resolves.
enum AutotuneMessage<K> {
    /// Tuning finished: carries the fastest tunable and all recorded results.
    Done {
        // Autotune key the results belong to.
        key: K,
        // Index of the fastest tunable within the tunable set.
        fastest_index: usize,
        // Per-tunable result; sorted fastest-first when benchmarking actually ran.
        results: Vec<AutotuneResult>,
        // Checksum of the tunable set, used by the persistent cache.
        #[cfg(std_io)]
        checksum: String,
        // Formatted log context, collected only at the full autotune log level.
        context_logs: Option<String>,
    },
    /// Tuning is still running as a detached task (used on wasm); the key is marked pending.
    #[allow(dead_code)]
    Pending(K),
}
62
/// Error from running autotune.
#[derive(Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum AutotuneError {
    /// An unknown error happened.
    Unknown {
        /// The name of the tunable.
        name: String,
        /// The unknown error.
        err: String,
    },
    /// All samples are invalid.
    InvalidSamples {
        /// The name of the tunable.
        name: String,
    },
    /// No autotune was flagged as valid for the problem.
    ///
    /// # Warning
    ///
    /// This is an unrecoverable error and will cause a panic.
    NoValidKernelFound {
        /// The formatted context on why no valid kernel was found.
        context: String,
    },
    /// The autotune is skipped manually.
    Skip {
        /// The name of the skipped kernel.
        name: String,
    },

    /// An error happened when launching a kernel.
    Launch(LaunchError),
}
97
98impl From<LaunchError> for AutotuneError {
99    fn from(value: LaunchError) -> Self {
100        Self::Launch(value)
101    }
102}
103
104#[allow(clippy::new_without_default)]
105impl<K: AutotuneKey> Tuner<K> {
106    /// Returns a tuner with cache initialized from persistent cache
107    pub fn new(name: &str, device_id: &str) -> Self {
108        let channel = async_channel::unbounded();
109
110        Self {
111            tune_cache: TuneCache::new(name, device_id),
112            logger: Logger::new(),
113            channel,
114            autotuning: HashSet::new(),
115        }
116    }
117
118    /// Fetch the fastest autotune operation index for an autotune key.
119    pub fn fastest(&self, key: &K) -> TuneCacheResult {
120        self.tune_cache.fastest(key)
121    }
122
123    /// Fetch the fastest autotune operation index for an autotune key and validate the checksum.
124    #[cfg(std_io)]
125    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
126        if let AutotuneLogLevel::Full = self.logger.log_level_autotune() {
127            self.logger
128                .log_autotune(&format!("validate checksum key={key}, checksum={checksum}"));
129        }
130        self.tune_cache.validate_checksum(key, checksum)
131    }
132
133    /// Handle an autotune result message, see [`execute_autotune`]
134    fn handle_result(&mut self, msg: AutotuneMessage<K>) {
135        match msg {
136            AutotuneMessage::Pending(key) => {
137                self.tune_cache.mark_pending(key);
138            }
139            AutotuneMessage::Done {
140                key,
141                fastest_index,
142                results,
143                #[cfg(std_io)]
144                checksum,
145                context_logs,
146            } => {
147                match self.logger.log_level_autotune() {
148                    AutotuneLogLevel::Minimal => {
149                        let top_times = results
150                            .iter()
151                            .map(|r| {
152                                let time = r
153                                    .outcome
154                                    .as_ref()
155                                    .map(|r| r.computation.median)
156                                    .unwrap_or(Duration::MAX);
157
158                                let index = r.outcome.as_ref().map(|r| r.index).unwrap_or_default();
159                                (index, time)
160                            })
161                            .take(3)
162                            .collect::<Vec<_>>();
163
164                        let result = results
165                            .first()
166                            .expect("At least one kernel needed.")
167                            .outcome
168                            .as_ref()
169                            .expect("At least one kernel has to succeed.");
170
171                        let context = match &context_logs {
172                            Some(context) => context,
173                            None => "",
174                        };
175                        self.logger.log_autotune(&format!(
176                            "Fastest result {}-{key}. \n Top 3 times: {top_times:?}, context: {context}",
177                            result.name,
178                        ));
179                    }
180                    AutotuneLogLevel::Full => {
181                        let result = results
182                            .first()
183                            .expect("At least one kernel needed.")
184                            .outcome
185                            .as_ref()
186                            .expect("At least one kernel has to succeed.");
187
188                        let context = match &context_logs {
189                            Some(context) => context,
190                            None => "",
191                        };
192                        self.logger.log_autotune(&format!(
193                            "Fastest result {}-{key}. Context: {context}",
194                            result.name,
195                        ));
196
197                        for result in results.iter() {
198                            match &result.outcome {
199                                Ok(val) => {
200                                    self.logger.log_autotune(&format!("{val}"));
201                                }
202                                Err(err) => self.logger.log_autotune(&format!("{err:?}")),
203                            }
204                        }
205                    }
206                    AutotuneLogLevel::Disabled => {}
207                };
208
209                self.tune_cache.cache_insert(key.clone(), fastest_index);
210
211                #[cfg(std_io)]
212                {
213                    self.tune_cache
214                        .persistent_cache_insert(key, checksum, fastest_index, results);
215                }
216            }
217        }
218    }
219
220    /// Check if any autotuning results have come in asynchronously.
221    pub fn handle_results(&mut self) {
222        // Handle any results that have come in. Note that execute_autotune pushes results to the channel immediately if possible.
223        // Since this function takes an &mut we know we have exclusive access, and no other threads are currently still adding results.
224        while let Ok(msg) = self.channel.1.try_recv() {
225            self.handle_result(msg);
226        }
227    }
228
229    /// Execute benchmarks to find out what the fastest operation is.
230    pub fn prepare_autotune<R: Runtime, In: Clone + Send + 'static, Out: AutotuneOutput>(
231        &self,
232        key: K,
233        inputs: &In,
234        tunables: &TunableSet<K, In, Out>,
235        client: &ComputeClient<R>,
236    ) -> Box<dyn FnOnce()> {
237        log::info!("Tuning {key}");
238
239        // Note that this message will be processed straight away by handle_results.
240        let sender = self.channel.0.clone();
241
242        let autotunables = tunables.autotunables();
243        let mut results: Vec<AutotuneResult> = Vec::with_capacity(autotunables.len());
244
245        for a in autotunables.iter() {
246            results.push(AutotuneResult::error(AutotuneError::Skip {
247                name: a.name().to_string(),
248            }));
249        }
250
251        if autotunables.len() == 1 {
252            let message = AutotuneMessage::Done {
253                key,
254                fastest_index: 0,
255                results,
256                #[cfg(std_io)]
257                checksum: tunables.compute_checksum(),
258                context_logs: None,
259            };
260
261            return Box::new(move || {
262                sender
263                    .try_send(message)
264                    .expect("Loss message channel somehow")
265            });
266        }
267
268        let client = client.clone();
269        let key_cloned = key.clone();
270        let plan = tunables.plan(&key);
271        let inputs_generator = tunables.inputs_generator(&key.clone(), inputs);
272
273        #[cfg(std_io)]
274        let checksum = tunables.compute_checksum();
275        let context_logs = match self.logger.log_level_autotune() {
276            AutotuneLogLevel::Disabled => false,
277            AutotuneLogLevel::Minimal => false,
278            AutotuneLogLevel::Full => true,
279        };
280
281        let fut_result = async move {
282            let test_inputs = inputs_generator();
283
284            Self::generate_tune_message(
285                key_cloned,
286                &client,
287                plan,
288                autotunables,
289                test_inputs,
290                results,
291                #[cfg(std_io)]
292                checksum,
293                context_logs,
294            )
295            .await
296        };
297
298        Box::new(move || {
299            let message = {
300                cfg_if::cfg_if! {
301                    if #[cfg(target_family = "wasm")] {
302                        let sender = sender.clone();
303
304                        let send_fut = async move {
305                            // If the channel has been closed, ignore. Maybe the main app is exiting
306                            // before the tune results come in.
307                            let _ = sender.send(fut_result.await).await;
308                        };
309                        // On wasm, spawn the tuning as a detached task.
310                        wasm_bindgen_futures::spawn_local(send_fut);
311                        // Mark the current tuning as pending.
312                        AutotuneMessage::Pending(key)
313                    } else {
314                        cubecl_common::future::block_on(fut_result)
315                    }
316                }
317            };
318
319            // Note that this message will be processed straight away by handle_results.
320            sender
321                .try_send(message)
322                .expect("Loss message channel somehow");
323        })
324    }
325
326    #[allow(clippy::too_many_arguments)]
327    async fn generate_tune_message<In: Clone + Send + 'static, Out: AutotuneOutput, R: Runtime>(
328        key: K,
329        client: &ComputeClient<R>,
330        mut plan: TunePlan,
331        autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
332        test_inputs: In,
333        mut results: Vec<AutotuneResult>,
334        #[cfg(std_io)] checksum: String,
335        context_logs: bool,
336    ) -> AutotuneMessage<K> {
337        let context_logs = match Self::execute_tune_plan(
338            client,
339            &mut plan,
340            autotunables,
341            &test_inputs,
342            &mut results,
343            context_logs,
344        )
345        .await
346        {
347            Ok(context_logs) => context_logs,
348            Err(err) => {
349                panic!("Can't execute the autotune plan for key: {key:?}\n - Error: {err:?}");
350            }
351        };
352
353        // Finds the fastest operation (by the median time).
354        results.sort_by(|a, b| {
355            let a = a
356                .outcome
357                .as_ref()
358                .map(|r| r.computation.median)
359                .unwrap_or(Duration::MAX);
360            let b = b
361                .outcome
362                .as_ref()
363                .map(|r| r.computation.median)
364                .unwrap_or(Duration::MAX);
365
366            a.cmp(&b)
367        });
368
369        // Log & send results.
370        let result = results
371            .first()
372            .expect("At least one kernel needed.")
373            .outcome
374            .as_ref()
375            .expect("At least one kernel has to succeed.");
376
377        AutotuneMessage::Done {
378            key,
379            fastest_index: result.index,
380            results,
381            #[cfg(std_io)]
382            checksum,
383            context_logs,
384        }
385    }
386
387    async fn execute_tune_plan<In: Clone + Send + 'static, Out: AutotuneOutput, R: Runtime>(
388        client: &ComputeClient<R>,
389        plan: &mut TunePlan,
390        autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
391        test_inputs: &In,
392        results: &mut [AutotuneResult],
393        context_logs: bool,
394    ) -> Result<Option<String>, AutotuneError> {
395        #[derive(Debug)]
396        #[allow(unused_variables, dead_code)] // Only use for debug
397        struct Context<'a> {
398            plan: &'a TunePlan,
399            results: &'a [AutotuneResult],
400        }
401
402        let mut context_logs = match context_logs {
403            true => Some("".to_string()),
404            false => None,
405        };
406
407        loop {
408            let mut num_success = 0;
409            let tunable_indices = plan.next(context_logs.as_mut());
410
411            if tunable_indices.is_empty() {
412                return Err(AutotuneError::NoValidKernelFound {
413                    context: format_debug(&Context { plan, results }),
414                });
415            }
416
417            for index in tunable_indices {
418                let op = &autotunables[index];
419                let name = op.name().to_string();
420                let tuner = TuneBenchmark::new(op.clone(), test_inputs.clone(), client.clone());
421                let profiles = tuner.profile().map(|bench| (name, index, bench));
422
423                match profiles {
424                    Ok(result) => {
425                        // Wait for the results to come in, and determine the outcome.
426                        let (name, index, profiles) = result;
427                        let result = Self::process_autotune(name, index, profiles).await;
428                        match result {
429                            Ok(val) => {
430                                results[index] = AutotuneResult::success(val);
431                                num_success += 1;
432                            }
433                            Err(err) => {
434                                results[index] = AutotuneResult::error(err);
435                            }
436                        }
437                    }
438                    Err(err) => {
439                        results[index] = AutotuneResult::error(err);
440                    }
441                }
442            }
443
444            if num_success > 0 {
445                break;
446            }
447        }
448
449        Ok(context_logs)
450    }
451
452    async fn process_autotune(
453        name: String,
454        index: usize,
455        profiles: Vec<ProfileDuration>,
456    ) -> Result<AutotuneOutcome, AutotuneError> {
457        let mut durations = Vec::new();
458        if !profiles.is_empty() {
459            let timing_method = profiles.first().unwrap().timing_method();
460            for profile in profiles {
461                durations.push(profile.resolve().await.duration());
462            }
463            let bench_durations = BenchmarkDurations::from_durations(timing_method, durations);
464
465            Ok(AutotuneOutcome::new(
466                name,
467                index,
468                BenchmarkComputations::new(&bench_durations),
469            ))
470        } else {
471            Err(AutotuneError::Unknown {
472                name,
473                err: "No profiling available".to_string(),
474            })
475        }
476    }
477}
478
#[cfg(feature = "autotune-checks")]
/// Checks that every successful autotune output is equivalent to the last one.
///
/// The final entry is taken as the reference; failed tunables are skipped.
///
/// # Panics
///
/// Panics if `checks_outputs` is empty.
pub(crate) fn check_autotune_outputs<O: AutotuneOutput>(
    mut checks_outputs: Vec<Result<O, AutotuneError>>,
) {
    // `pop` is the idiomatic (and underflow-safe) way to take the last element.
    let reference = checks_outputs
        .pop()
        .expect("autotune output checks require at least one output");

    if let Ok(reference) = reference {
        // Only successful outputs can be compared against the reference.
        for other in checks_outputs.into_iter().flatten() {
            reference.check_equivalence(other);
        }
    }
}