Skip to main content

cubecl_runtime/tune/
tune_benchmark.rs

1use super::{AutotuneError, TuneFn, TuneInputs};
2use crate::{client::ComputeClient, runtime::Runtime};
3use alloc::string::ToString;
4use alloc::vec::Vec;
5use cubecl_common::profile::ProfileDuration;
6
/// Bound for every value an autotuned operation can produce.
///
/// Implementors must be `Send + 'static`. When the `autotune-checks` feature is
/// enabled they must additionally implement `check_equivalence`.
pub trait AutotuneOutput: Send + 'static {
    /// Verifies that this output matches `other`, which was produced by a
    /// different candidate for the same problem.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self);
}
14
15impl AutotuneOutput for () {
16    #[cfg(feature = "autotune-checks")]
17    fn check_equivalence(&self, _other: Self) {
18        //
19    }
20}
21
22/// Benchmark how long this operation takes for a number of samples.
23///
24/// Returns at least one duration, otherwise an error is returned.
25pub fn tune_benchmark<'a, R: Runtime, F: TuneInputs, Out: AutotuneOutput>(
26    operation: &TuneFn<F, Out>,
27    inputs: <F as TuneInputs>::At<'a>,
28    client: ComputeClient<R>,
29) -> Result<Vec<ProfileDuration>, AutotuneError> {
30    // `scoped` holds exclusive device access for the whole benchmark loop and
31    // accepts non-`'static` closures.
32    client
33        .clone()
34        .exclusive(move || profile_exclusive(operation, inputs, client))
35        .map_err(|err| AutotuneError::Unknown {
36            name: operation.name.to_string(),
37            err: err.to_string(),
38        })?
39}
40
41fn profile_exclusive<'a, R: Runtime, F: TuneInputs, Out: AutotuneOutput>(
42    operation: &TuneFn<F, Out>,
43    inputs: <F as TuneInputs>::At<'a>,
44    client: ComputeClient<R>,
45) -> Result<Vec<ProfileDuration>, AutotuneError> {
46    warmup(operation, inputs.clone(), client.clone())?;
47
48    let num_samples = 10;
49    let mut durations = Vec::new();
50
51    for _ in 0..num_samples {
52        let result: Result<
53            (Result<Out, AutotuneError>, ProfileDuration),
54            crate::server::ProfileError,
55        > = {
56            let inputs = inputs.clone();
57
58            client.profile(
59                move || {
60                    // It is important to return the output since otherwise deadcode elimination
61                    // might optimize away code that needs to be profiled.
62                    operation.execute(inputs)
63                },
64                &operation.name,
65            )
66        };
67
68        let result = match result {
69            Ok((out, duration)) => match out {
70                Ok(_) => Some(duration),
71                Err(err) => {
72                    log::trace!("Error while autotuning {err:?}");
73                    None
74                }
75            },
76            Err(err) => {
77                log::trace!("Error while autotuning {err:?}");
78                None
79            }
80        };
81
82        if let Some(item) = result {
83            durations.push(item);
84        }
85    }
86
87    if durations.is_empty() {
88        Err(AutotuneError::InvalidSamples {
89            name: operation.name.to_string(),
90        })
91    } else {
92        Ok(durations)
93    }
94}
95
96fn warmup<'a, R: Runtime, F: TuneInputs, Out: AutotuneOutput>(
97    operation: &TuneFn<F, Out>,
98    inputs: <F as TuneInputs>::At<'a>,
99    client: ComputeClient<R>,
100) -> Result<(), AutotuneError> {
101    let num_warmup = 3;
102
103    let mut errors = Vec::with_capacity(num_warmup);
104    // We make sure the server is in a correct state.
105    let _errs = client.flush();
106
107    for _ in 0..num_warmup {
108        let inputs = inputs.clone();
109        let profiled = client.profile(move || operation.execute(inputs), &operation.name);
110
111        match profiled {
112            Ok(_) => {}
113            Err(err) => errors.push(err),
114        }
115    }
116
117    if errors.len() < num_warmup {
118        Ok(())
119    } else {
120        let msg = alloc::format!("{:?}", errors.remove(num_warmup - 1));
121        Err(AutotuneError::Unknown {
122            name: operation.name.to_string(),
123            err: msg,
124        })
125    }
126}