Skip to main content

cubecl_runtime/tune/
tune_benchmark.rs

1use super::{AutotuneError, TuneFn};
2use crate::{client::ComputeClient, runtime::Runtime};
3use alloc::string::ToString;
4use alloc::sync::Arc;
5use alloc::vec::Vec;
6use cubecl_common::profile::ProfileDuration;
7
8/// A benchmark that runs on server handles
9#[derive(new)]
10pub struct TuneBenchmark<R: Runtime, In: Clone + Send + 'static, Out: Send + 'static> {
11    operation: Arc<dyn TuneFn<Inputs = In, Output = Out>>,
12    inputs: In,
13    client: ComputeClient<R>,
14}
15
/// Contract that the output type of an autotuned operation has to fulfil.
///
/// Outputs must be `Send + 'static`.
pub trait AutotuneOutput: Send + 'static {
    /// Verifies that this output matches `other`, the output of another autotune
    /// operation that was run on the same problem.
    ///
    /// The method only exists when the `autotune-checks` feature is enabled. It returns
    /// nothing, so how a mismatch is reported is left to the implementation.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self);
}
23
24impl AutotuneOutput for () {
25    #[cfg(feature = "autotune-checks")]
26    fn check_equivalence(&self, _other: Self) {
27        //
28    }
29}
30
31impl<R: Runtime, In: Clone + Send + 'static, Out: AutotuneOutput> TuneBenchmark<R, In, Out> {
32    /// Benchmark how long this operation takes for a number of samples.
33    ///
34    /// Returns at least one duration, otherwise an error is returned.
35    pub fn profile(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
36        let client = self.client.clone();
37        let name = self.operation.name().to_string();
38
39        client
40            .exclusive(move || self.profile_exclusive())
41            .map_err(|err| AutotuneError::Unknown {
42                name,
43                err: err.to_string(),
44            })?
45    }
46
47    fn profile_exclusive(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
48        self.warmup()?;
49
50        let operation = self.operation.clone();
51        let name = operation.name().to_string();
52        let num_samples = 10;
53        let mut durations = Vec::new();
54        for _ in 0..num_samples {
55            let result: Result<
56                (Result<Out, AutotuneError>, ProfileDuration),
57                crate::server::ProfileError,
58            > = {
59                let inputs = self.inputs.clone();
60                let operation = operation.clone();
61
62                self.client.profile(
63                    move || {
64                        // It is important to return the output since otherwise deadcode elimination
65                        // might optimize away code that needs to be profiled.
66                        operation.execute(inputs)
67                    },
68                    &name,
69                )
70            };
71
72            let result = match result {
73                Ok((out, duration)) => match out {
74                    Ok(_) => Some(duration),
75                    Err(err) => {
76                        log::trace!("Error while autotuning {err:?}");
77                        None
78                    }
79                },
80                Err(err) => {
81                    log::trace!("Error while autotuning {err:?}");
82                    None
83                }
84            };
85
86            if let Some(item) = result {
87                durations.push(item);
88            }
89        }
90
91        if durations.is_empty() {
92            Err(AutotuneError::InvalidSamples { name })
93        } else {
94            Ok(durations)
95        }
96    }
97
98    fn warmup(&self) -> Result<(), AutotuneError> {
99        let num_warmup = 3;
100
101        let mut errors = Vec::with_capacity(num_warmup);
102        // We make sure the server is in a correct state.
103        let _errs = self.client.flush();
104
105        for _ in 0..num_warmup {
106            let op = self.operation.clone();
107            let inputs = self.inputs.clone();
108            let profiled = self
109                .client
110                .profile(move || op.execute(inputs), self.operation.name());
111
112            match profiled {
113                Ok(_) => {}
114                Err(err) => errors.push(err),
115            }
116        }
117
118        if errors.len() < num_warmup {
119            Ok(())
120        } else {
121            let msg = alloc::format!("{:?}", errors.remove(num_warmup - 1));
122            Err(AutotuneError::Unknown {
123                name: self.operation.name().to_string(),
124                err: msg,
125            })
126        }
127    }
128}