// cubecl_runtime/tune/tune_benchmark.rs

1use alloc::sync::Arc;
2use alloc::vec::Vec;
3
4use crate::channel::ComputeChannel;
5use crate::client::ComputeClient;
6use crate::server::ComputeServer;
7use cubecl_common::benchmark::ProfileDuration;
8
9use super::{AutotuneError, Tunable};
10
11/// A benchmark that runs on server handles
12#[derive(new)]
13pub struct TuneBenchmark<S: ComputeServer, C, In: Clone + Send + 'static, Out: Send + 'static> {
14    operation: Arc<dyn Tunable<Inputs = In, Output = Out>>,
15    inputs: In,
16    client: ComputeClient<S, C>,
17}
18
/// Implemented by every value that an autotuned operation can produce.
pub trait AutotuneOutput: Send + 'static {
    /// Checks whether this output matches `other`, an output produced by another
    /// autotune candidate for the same problem.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self);
}
26
27impl AutotuneOutput for () {
28    #[cfg(feature = "autotune-checks")]
29    fn check_equivalence(&self, _other: Self) {
30        //
31    }
32}
33
34impl<
35    S: ComputeServer + 'static,
36    C: ComputeChannel<S> + 'static,
37    In: Clone + Send + 'static,
38    Out: AutotuneOutput,
39> TuneBenchmark<S, C, In, Out>
40{
41    #[cfg(feature = "autotune-checks")]
42    pub(crate) fn output_for_checks(&self) -> Result<Out, AutotuneError> {
43        self.operation.clone().execute(self.inputs.clone())
44    }
45
46    /// Benchmark how long this operation takes for a number of samples.
47    pub fn profile(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
48        let operation = self.operation;
49        // If the inner operation need autotuning as well, we need to call it before. This will
50        // recurse and keep calling operations until a leaf operation tunes, and so on. This effectively
51        // does a depth-first traversal of the operation tree. Without this, client.profile() would have to
52        // support profiling recursively.
53        operation.execute(self.inputs.clone())?;
54
55        let num_samples = 10;
56        let durations = (0..num_samples)
57            .map(|_| {
58                self.client.profile(|| {
59                    operation
60                        .execute(self.inputs.clone())
61                        .expect("Should not fail when previously tried during the warmup.");
62                })
63            })
64            .collect();
65
66        Ok(durations)
67    }
68}