cubecl_runtime/tune/tuner.rs

use alloc::boxed::Box;
use alloc::format;
use alloc::sync::Arc;
use alloc::vec::Vec;
use async_channel::{Receiver, Sender};
use cubecl_common::profile::ProfileDuration;
use hashbrown::HashSet;

use core::time::Duration;

use alloc::string::{String, ToString};
use cubecl_common::benchmark::{BenchmarkComputations, BenchmarkDurations};

use crate::client::ComputeClient;
use crate::config::{Logger, autotune::AutotuneLogLevel};
use crate::server::ComputeServer;
use crate::tune::{TuneBenchmark, TuneCache};

use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult, TuneFn, TunePlan};

#[derive(Debug)]
/// Executes autotune benchmarking and caching
pub struct Tuner<K: AutotuneKey> {
    tune_cache: TuneCache<K>,
    logger: Logger,
    channel: (Sender<AutotuneMessage<K>>, Receiver<AutotuneMessage<K>>),
    pub(crate) autotuning: HashSet<K>,
}

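// A rough sketch of the intended call flow, for orientation only; `MyKey`, `key`,
// `inputs`, `tunables`, and `client` are placeholder values supplied by the caller,
// and only methods defined in this file are used:
//
//     let mut tuner = Tuner::<MyKey>::new("matmul", "device-0");
//     let run = tuner.prepare_autotune(key.clone(), &inputs, &tunables, &client);
//     run();                  // benchmarks the tunables and sends an AutotuneMessage
//     tuner.handle_results(); // drains the channel and caches the fastest index
//     let fastest = tuner.fastest(&key);
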
/// The measured outcome for a given autotune invocation.
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize, PartialEq, Eq))]
#[derive(new, Debug, Clone)]
pub struct AutotuneOutcome {
    name: String,
    index: usize,
    computation: BenchmarkComputations,
}

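// For reference, the `Display` impl below renders an outcome roughly as
// `Autotune[3] name my_kernel => BenchmarkComputations { .. }`; the index, kernel
// name, and timings here are purely illustrative.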
impl core::fmt::Display for AutotuneOutcome {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(
            f,
            "Autotune[{}] name {} => {:?}",
            self.index, self.name, self.computation
        )
    }
}

/// Message sent from a tuning run back to the [`Tuner`].
enum AutotuneMessage<K> {
    /// Benchmarking finished: cache the fastest tunable for this key.
    Done {
        key: K,
        fastest_index: usize,
        results: Vec<Result<AutotuneOutcome, AutotuneError>>,
        #[cfg(std_io)]
        checksum: String,
    },
    /// Benchmarking is still running asynchronously (used on wasm, where tuning
    /// is spawned as a detached task).
    #[allow(dead_code)]
    Pending(K),
}

/// Error from running autotune.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum AutotuneError {
    /// An unknown error happened.
    Unknown(String),
    /// All samples are invalid.
    InvalidSamples,
    /// The autotune is skipped manually.
    Skip,
}

impl From<String> for AutotuneError {
    fn from(value: String) -> Self {
        Self::Unknown(value)
    }
}

#[allow(clippy::new_without_default)]
impl<K: AutotuneKey> Tuner<K> {
    /// Returns a tuner with its cache initialized from the persistent cache.
    pub fn new(name: &str, device_id: &str) -> Self {
        let channel = async_channel::unbounded();

        Self {
            tune_cache: TuneCache::new(name, device_id),
            logger: Logger::new(),
            channel,
            autotuning: HashSet::new(),
        }
    }

    /// Fetch the fastest autotune operation index for an autotune key.
    pub fn fastest(&self, key: &K) -> TuneCacheResult {
        self.tune_cache.fastest(key)
    }

    /// Validate the persistent cache checksum for an autotune key.
    #[cfg(std_io)]
    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
        if let AutotuneLogLevel::Full = self.logger.log_level_autotune() {
            self.logger
                .log_autotune(&format!("validate checksum key={key}, checksum={checksum}"));
        }
        self.tune_cache.validate_checksum(key, checksum)
    }

    /// Handle an autotune result message, see [`Tuner::prepare_autotune`].
    fn handle_result(&mut self, msg: AutotuneMessage<K>) {
        match msg {
            AutotuneMessage::Pending(key) => {
                self.tune_cache.mark_pending(key);
            }
            AutotuneMessage::Done {
                key,
                fastest_index,
                results,
                #[cfg(std_io)]
                checksum,
            } => {
                match self.logger.log_level_autotune() {
                    AutotuneLogLevel::Minimal => {
                        let top_times = results
                            .iter()
                            .map(|r| {
                                let time = r
                                    .as_ref()
                                    .map(|r| r.computation.median)
                                    .unwrap_or(Duration::MAX);

                                let index = r.as_ref().map(|r| r.index).unwrap_or_default();
                                (index, time)
                            })
                            .take(3)
                            .collect::<Vec<_>>();

                        let result = results
                            .first()
                            .expect("At least one kernel needed.")
                            .as_ref()
                            .expect("At least one kernel has to succeed.");

                        self.logger.log_autotune(&format!(
                            "Fastest result {}-{key}. \n Top 3 times: {top_times:?}",
                            result.name,
                        ));
                    }
                    AutotuneLogLevel::Full => {
                        let result = results
                            .first()
                            .expect("At least one kernel needed.")
                            .as_ref()
                            .expect("At least one kernel has to succeed.");

                        self.logger
                            .log_autotune(&format!("Fastest result {}-{key}.", result.name));

                        for result in results.iter() {
                            match result {
                                Ok(val) => {
                                    self.logger.log_autotune(&format!("{val}"));
                                }
                                Err(err) => self.logger.log_autotune(&format!("{err:?}")),
                            }
                        }
                    }
                    AutotuneLogLevel::Disabled => {}
                };

                self.tune_cache.cache_insert(key.clone(), fastest_index);

                #[cfg(std_io)]
                {
                    self.tune_cache
                        .persistent_cache_insert(key, checksum, fastest_index, results);
                }
            }
        }
    }

    /// Check if any autotuning results have come in asynchronously.
    pub fn handle_results(&mut self) {
        // Handle any results that have come in. Note that the closure returned by
        // `prepare_autotune` pushes results to the channel as soon as they are ready.
        // Since this function takes `&mut self`, we have exclusive access and no other
        // thread is still adding results.
        while let Ok(msg) = self.channel.1.try_recv() {
            self.handle_result(msg);
        }
    }

    /// Prepare a closure that executes benchmarks to find out what the fastest operation is.
    pub fn prepare_autotune<
        S: ComputeServer + 'static,
        In: Clone + Send + 'static,
        Out: AutotuneOutput,
    >(
        &self,
        key: K,
        inputs: &In,
        tunables: &TunableSet<K, In, Out>,
        client: &ComputeClient<S>,
    ) -> Box<dyn FnOnce()> {
        log::info!("Tuning {key}");

        // Results are sent on this channel and picked up by `handle_results`.
        let sender = self.channel.0.clone();

        let autotunables = tunables.autotunables();
        let mut results = Vec::with_capacity(autotunables.len());

        for _ in 0..autotunables.len() {
            results.push(Err(AutotuneError::Skip));
        }

        if autotunables.len() == 1 {
            let message = AutotuneMessage::Done {
                key,
                fastest_index: 0,
                results,
                #[cfg(std_io)]
                checksum: tunables.compute_checksum(),
            };

            return Box::new(move || {
                sender
                    .try_send(message)
                    .expect("Lost message channel somehow")
            });
        }

        let client = client.clone();
        let key_cloned = key.clone();
        let plan = tunables.plan(&key);
        let inputs_generator = tunables.inputs_generator(&key, inputs);

        #[cfg(std_io)]
        let checksum = tunables.compute_checksum();

        let fut_result = async move {
            let test_inputs = inputs_generator();

            Self::generate_tune_message(
                key_cloned,
                &client,
                plan,
                autotunables,
                test_inputs,
                results,
                #[cfg(std_io)]
                checksum,
            )
            .await
        };

        Box::new(move || {
            let message = {
                cfg_if::cfg_if! {
                    if #[cfg(target_family = "wasm")] {
                        let sender = sender.clone();

                        let send_fut = async move {
                            // If the channel has been closed, ignore it. Maybe the main app is
                            // exiting before the tune results come in.
                            let _ = sender.send(fut_result.await).await;
                        };
                        // On wasm, spawn the tuning as a detached task.
                        wasm_bindgen_futures::spawn_local(send_fut);
                        // Mark the current tuning as pending.
                        AutotuneMessage::Pending(key)
                    } else {
                        cubecl_common::future::block_on(fut_result)
                    }
                }
            };

            // Note that this message will be processed straight away by `handle_results`.
            sender
                .try_send(message)
                .expect("Lost message channel somehow");
        })
    }

    async fn generate_tune_message<
        In: Clone + Send + 'static,
        Out: AutotuneOutput,
        S: ComputeServer + 'static,
    >(
        key: K,
        client: &ComputeClient<S>,
        mut plan: TunePlan,
        autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
        test_inputs: In,
        mut results: Vec<Result<AutotuneOutcome, AutotuneError>>,
        #[cfg(std_io)] checksum: String,
    ) -> AutotuneMessage<K> {
        Self::execute_tune_plan(client, &mut plan, autotunables, &test_inputs, &mut results).await;

        // Find the fastest operation (by median time); failed runs sort as `Duration::MAX`.
        results.sort_by(|a, b| {
            let a = a
                .as_ref()
                .map(|r| r.computation.median)
                .unwrap_or(Duration::MAX);
            let b = b
                .as_ref()
                .map(|r| r.computation.median)
                .unwrap_or(Duration::MAX);

            a.cmp(&b)
        });

        // After sorting, the first entry is the fastest successful outcome.
        let result = results
            .first()
            .expect("At least one kernel needed.")
            .as_ref()
            .expect("At least one kernel has to succeed.");

        AutotuneMessage::Done {
            key,
            // `index` is the tunable's position within the `TunableSet`, so it remains
            // meaningful after `results` has been sorted.
            fastest_index: result.index,
            results,
            #[cfg(std_io)]
            checksum,
        }
    }

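    // `execute_tune_plan` consumes the plan in batches: each call to `plan.next()`
    // yields a set of tunable indices to benchmark, and the loop moves on to the next
    // batch only when every tunable in the current batch failed. This is an inferred
    // reading of `TunePlan` based on this file alone.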
    async fn execute_tune_plan<
        In: Clone + Send + 'static,
        Out: AutotuneOutput,
        S: ComputeServer + 'static,
    >(
        client: &ComputeClient<S>,
        plan: &mut TunePlan,
        autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
        test_inputs: &In,
        results: &mut [Result<AutotuneOutcome, AutotuneError>],
    ) {
        loop {
            let mut num_autotuned = 0;

            let tunable_indices = plan.next();

            if tunable_indices.is_empty() {
                panic!("No autotune was flagged as valid for the problem.")
            }

            for index in tunable_indices {
                let op = &autotunables[index];
                let name = op.name().to_string();
                let tuner = TuneBenchmark::new(op.clone(), test_inputs.clone(), client.clone());
                let profiles = tuner.profile().map(|bench| (name, index, bench));

                match profiles {
                    Ok(result) => {
                        // Wait for the results to come in, and determine the outcome.
                        let (name, index, profiles) = result;
                        let result = Self::process_autotune(name, index, profiles).await;
                        match result {
                            Ok(val) => {
                                results[index] = Ok(val);
                                num_autotuned += 1;
                            }
                            Err(err) => {
                                results[index] = Err(err);
                            }
                        }
                    }
                    Err(err) => {
                        results[index] = Err(err);
                    }
                }
            }

            if num_autotuned > 0 {
                break;
            }
        }
    }

    async fn process_autotune(
        name: String,
        index: usize,
        profiles: Vec<ProfileDuration>,
    ) -> Result<AutotuneOutcome, AutotuneError> {
        let mut durations = Vec::new();
        if !profiles.is_empty() {
            let timing_method = profiles.first().unwrap().timing_method();
            for profile in profiles {
                durations.push(profile.resolve().await.duration());
            }
            let bench_durations = BenchmarkDurations::from_durations(timing_method, durations);

            Ok(AutotuneOutcome::new(
                name,
                index,
                BenchmarkComputations::new(&bench_durations),
            ))
        } else {
            Err(AutotuneError::Unknown(format!(
                "Runtime error while profiling {name}."
            )))
        }
    }
}

#[cfg(feature = "autotune-checks")]
pub(crate) fn check_autotune_outputs<O: AutotuneOutput>(
    mut checks_outputs: Vec<Result<O, AutotuneError>>,
) {
    // The last output serves as the reference; every other successful output must
    // be equivalent to it.
    let reference = checks_outputs.remove(checks_outputs.len() - 1);

    if let Ok(reference) = reference {
        for other in checks_outputs.into_iter().flatten() {
            reference.check_equivalence(other);
        }
    }
}