cubecl_runtime/tune/
tune_benchmark.rs

use alloc::format;
use alloc::sync::Arc;
use alloc::vec::Vec;
use cubecl_common::profile::{ProfileDuration, TimingMethod};

use crate::client::ComputeClient;
use crate::server::ComputeServer;

use super::{AutotuneError, TuneFn};

/// A benchmark that runs on server handles.
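///
/// A rough construction sketch; `my_tune_fn`, `inputs`, and `client` are placeholders for a
/// concrete `TuneFn` implementation, its input value, and a `ComputeClient`, not items from
/// this module (`TuneBenchmark::new` is generated by `#[derive(new)]` and takes the fields in
/// declaration order):
///
/// ```ignore
/// let bench = TuneBenchmark::new(Arc::new(my_tune_fn), inputs, client);
/// let durations = bench.profile()?;
/// ```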
#[derive(new)]
pub struct TuneBenchmark<S: ComputeServer, In: Clone + Send + 'static, Out: Send + 'static> {
    operation: Arc<dyn TuneFn<Inputs = In, Output = Out>>,
    inputs: In,
    client: ComputeClient<S>,
}

/// The trait to be implemented by an autotune output.
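///
/// A hedged sketch of what an implementation could look like for a hypothetical output type;
/// `TensorOutput` and its `to_data` method are illustrative only, not items from this crate:
///
/// ```ignore
/// impl AutotuneOutput for TensorOutput {
///     #[cfg(feature = "autotune-checks")]
///     fn check_equivalence(&self, other: Self) {
///         // Compare the values produced for the same problem.
///         assert_eq!(self.to_data(), other.to_data());
///     }
/// }
/// ```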
pub trait AutotuneOutput: Send + 'static {
    #[cfg(feature = "autotune-checks")]
    /// Checks that the output of an autotune operation matches the output produced by another
    /// operation for the same problem.
    fn check_equivalence(&self, other: Self);
}

impl AutotuneOutput for () {
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, _other: Self) {
        // Nothing to compare for the unit output.
    }
}

impl<S: ComputeServer + 'static, In: Clone + Send + 'static, Out: AutotuneOutput>
    TuneBenchmark<S, In, Out>
{
    /// Benchmark how long this operation takes for a number of samples.
    ///
    /// Returns at least one duration; otherwise, an error is returned.
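    ///
    /// A rough usage sketch; `bench` is assumed to be a `TuneBenchmark` built elsewhere, and only
    /// the `InvalidSamples` variant is matched explicitly, with other errors left to a catch-all
    /// arm:
    ///
    /// ```ignore
    /// match bench.profile() {
    ///     Ok(samples) => { /* reduce the samples, e.g. keep the fastest or the median */ }
    ///     Err(AutotuneError::InvalidSamples) => { /* every sample failed to profile */ }
    ///     Err(other) => { /* propagate or log the error */ }
    /// }
    /// ```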
    pub fn profile(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
        // If the inner operation needs autotuning as well, we need to run it once beforehand.
        // This will recurse and keep calling operations until a leaf operation tunes, and so on,
        // effectively performing a depth-first traversal of the operation tree.

        // For now, the warmup is wrapped inside a profiling task when the system timing method is
        // used, since that is where we have basic error handling.
        match self.client.properties().timing_method {
            TimingMethod::System => self.warmup_full_error_handling(),
            TimingMethod::Device => self.warmup_minimal_error_handling(),
        }?;

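        // Run the operation a fixed number of times under the profiler; samples that fail are
        // logged and skipped rather than aborting the whole benchmark.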
        let operation = self.operation;
        let num_samples = 10;
        let durations: Vec<_> = (0..num_samples)
            .filter_map(|_| {
                let result: Result<ProfileDuration, crate::server::ProfileError> =
                    self.client.profile(
                        || {
                            // It is important to return the output, since otherwise dead-code
                            // elimination might optimize away code that needs to be profiled.
                            operation
                                .execute(self.inputs.clone())
                                .expect("Should not fail since it already succeeded during the warmup.")
                        },
                        operation.name(),
                    );

                match result {
                    Ok(val) => Some(val),
                    Err(err) => {
                        log::warn!("Error while autotuning {err:?}");
                        None
                    }
                }
            })
            .collect();

        if durations.is_empty() {
            Err(AutotuneError::InvalidSamples)
        } else {
            Ok(durations)
        }
    }

    /// Warm up by running the operation once inside a profiling task, so that both profiling
    /// errors and errors raised by the operation itself are caught and reported.
    fn warmup_full_error_handling(&self) -> Result<(), AutotuneError> {
        // The profiled closure cannot return the operation's error directly, so capture it here.
        let mut error = None;

        let result = self.client.profile(
            || {
                if let Err(err) = self.operation.execute(self.inputs.clone()) {
                    error = Some(err);
                }
            },
            self.operation.name(),
        );

        if let Err(err) = result {
            return Err(AutotuneError::Unknown(format!("{err:?}")));
        };

        if let Some(err) = error {
            return Err(err);
        };

        Ok(())
    }

    /// Warm up by running the operation once, without wrapping it in a profiling task.
    fn warmup_minimal_error_handling(&self) -> Result<(), AutotuneError> {
        self.operation.execute(self.inputs.clone())?;
        Ok(())
    }
}