// cubecl_runtime/tune/tune_benchmark.rs

1use super::{AutotuneError, TuneFn};
2use crate::{client::ComputeClient, runtime::Runtime};
3use alloc::format;
4use alloc::string::ToString;
5use alloc::sync::Arc;
6use alloc::vec::Vec;
7use cubecl_common::profile::{ProfileDuration, TimingMethod};
8
9/// A benchmark that runs on server handles
10#[derive(new)]
11pub struct TuneBenchmark<R: Runtime, In: Clone + Send + 'static, Out: Send + 'static> {
12    operation: Arc<dyn TuneFn<Inputs = In, Output = Out>>,
13    inputs: In,
14    client: ComputeClient<R>,
15}
16
/// The trait to be implemented by an autotune output.
///
/// Implementors must be `Send + 'static`. Without the `autotune-checks` feature the trait
/// has no methods and acts purely as a bound; with it, implementors must also provide
/// `check_equivalence`.
pub trait AutotuneOutput: Send + 'static {
    /// Checks if the output of an autotune operation is the same as another one on the same
    /// problem.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self);
}
24
25impl AutotuneOutput for () {
26    #[cfg(feature = "autotune-checks")]
27    fn check_equivalence(&self, _other: Self) {
28        //
29    }
30}
31
32impl<R: Runtime, In: Clone + Send + 'static, Out: AutotuneOutput> TuneBenchmark<R, In, Out> {
33    /// Benchmark how long this operation takes for a number of samples.
34    ///
35    /// Returns at least one duration, otherwise an error is returned.
36    pub fn profile(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
37        // If the inner operation need autotuning as well, we need to call it before. This will
38        // recurse and keep calling operations until a leaf operation tunes, and so on. This effectively
39        // does a depth-first traversal of the operation tree.
40
41        // For now we wrap the warmup operation inside a profiling task, since we have basic error
42        // handling for system timing methods.
43        match self.client.properties().timing_method {
44            TimingMethod::System => self.warmup_full_error_handling(),
45            TimingMethod::Device => self.warmup_minimal_error_handling(),
46        }?;
47
48        let operation = self.operation;
49        let num_samples = 10;
50        let durations: Vec<_> = (0..num_samples)
51            .filter_map(|_| {
52                let result: Result<
53                    (Result<Out, AutotuneError>, ProfileDuration),
54                    crate::server::ProfileError,
55                > = self.client.profile(
56                    || {
57                        // It is important to return the output since otherwise deadcode elimination
58                        // might optimize away code that needs to be profiled.
59                        operation.execute(self.inputs.clone())
60                    },
61                    operation.name(),
62                );
63
64                match result {
65                    Ok((out, duration)) => match out {
66                        Ok(_) => Some(duration),
67                        Err(err) => {
68                            log::warn!("Error while autotuning {err:?}");
69                            None
70                        }
71                    },
72                    Err(err) => {
73                        log::warn!("Error while autotuning {err:?}");
74                        None
75                    }
76                }
77            })
78            .collect();
79
80        if durations.is_empty() {
81            Err(AutotuneError::InvalidSamples {
82                name: operation.name().to_string(),
83            })
84        } else {
85            Ok(durations)
86        }
87    }
88
89    fn warmup_full_error_handling(&self) -> Result<(), AutotuneError> {
90        let mut error = None;
91
92        let result = self.client.profile(
93            || {
94                if let Err(err) = self.operation.execute(self.inputs.clone()) {
95                    error = Some(err);
96                }
97            },
98            self.operation.name(),
99        );
100
101        if let Err(err) = result {
102            return Err(AutotuneError::Unknown {
103                name: self.operation.name().to_string(),
104                err: format!("{err:?}"),
105            });
106        };
107
108        if let Some(err) = error {
109            return Err(err);
110        };
111
112        Ok(())
113    }
114    fn warmup_minimal_error_handling(&self) -> Result<(), AutotuneError> {
115        self.operation.execute(self.inputs.clone())?;
116        Ok(())
117    }
118}