cubecl_runtime/tune/
tune_benchmark.rs1use alloc::sync::Arc;
2use alloc::vec::Vec;
3
4use crate::channel::ComputeChannel;
5use crate::client::ComputeClient;
6use crate::server::ComputeServer;
7use cubecl_common::benchmark::ProfileDuration;
8
9use super::{AutotuneError, Tunable};
10
11#[derive(new)]
13pub struct TuneBenchmark<S: ComputeServer, C, In: Clone + Send + 'static, Out: Send + 'static> {
14 operation: Arc<dyn Tunable<Inputs = In, Output = Out>>,
15 inputs: In,
16 client: ComputeClient<S, C>,
17}
18
/// Trait implemented by the value an autotuned operation produces.
///
/// Outputs must be `Send + 'static`.
pub trait AutotuneOutput: Send + 'static {
    /// Compares this output against `other`, another output of the same type.
    ///
    /// Only compiled when the `autotune-checks` feature is enabled. The method returns
    /// nothing, so an implementation has to report a mismatch by some other means
    /// (presumably by panicking; confirm against the implementors).
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self);
}
26
27impl AutotuneOutput for () {
28 #[cfg(feature = "autotune-checks")]
29 fn check_equivalence(&self, _other: Self) {
30 }
32}
33
34impl<
35 S: ComputeServer + 'static,
36 C: ComputeChannel<S> + 'static,
37 In: Clone + Send + 'static,
38 Out: AutotuneOutput,
39> TuneBenchmark<S, C, In, Out>
40{
41 #[cfg(feature = "autotune-checks")]
42 pub(crate) fn output_for_checks(&self) -> Result<Out, AutotuneError> {
43 self.operation.clone().execute(self.inputs.clone())
44 }
45
46 pub fn profile(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
48 let operation = self.operation;
49 operation.execute(self.inputs.clone())?;
54
55 let num_samples = 10;
56 let durations = (0..num_samples)
57 .map(|_| {
58 self.client.profile(|| {
59 operation
60 .execute(self.inputs.clone())
61 .expect("Should not fail when previously tried during the warmup.");
62 })
63 })
64 .collect();
65
66 Ok(durations)
67 }
68}