cubecl_runtime/tune/
tune_benchmark.rs1use alloc::format;
2use alloc::sync::Arc;
3use alloc::vec::Vec;
4use cubecl_common::profile::{ProfileDuration, TimingMethod};
5
6use crate::client::ComputeClient;
7use crate::server::ComputeServer;
8
9use super::{AutotuneError, TuneFn};
10
11#[derive(new)]
13pub struct TuneBenchmark<S: ComputeServer, In: Clone + Send + 'static, Out: Send + 'static> {
14 operation: Arc<dyn TuneFn<Inputs = In, Output = Out>>,
15 inputs: In,
16 client: ComputeClient<S>,
17}
18
/// Bound required of the output type of an autotuned operation.
///
/// Without the `autotune-checks` feature this trait has no methods and acts purely as a
/// marker (`Send + 'static`). With the feature enabled it adds `check_equivalence`.
pub trait AutotuneOutput: 'static + Send {
    /// Compares this output with `other`.
    ///
    /// NOTE(review): presumably `other` is the output of a different candidate run on the
    /// same problem — confirm against the autotune-checks call site.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self);
}
26
27impl AutotuneOutput for () {
28 #[cfg(feature = "autotune-checks")]
29 fn check_equivalence(&self, _other: Self) {
30 }
32}
33
34impl<S: ComputeServer + 'static, In: Clone + Send + 'static, Out: AutotuneOutput>
35 TuneBenchmark<S, In, Out>
36{
37 pub fn profile(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
41 match self.client.properties().timing_method {
48 TimingMethod::System => self.warmup_full_error_handling(),
49 TimingMethod::Device => self.warmup_minimal_error_handling(),
50 }?;
51
52 let operation = self.operation;
53 let num_samples = 10;
54 let durations: Vec<_> = (0..num_samples)
55 .filter_map(|_| {
56 let result: Result<ProfileDuration, crate::server::ProfileError> =
57 self.client.profile(
58 || {
59 operation
62 .execute(self.inputs.clone())
63 .expect("Should not fail when previously tried during the warmup.")
64 },
65 operation.name(),
66 );
67
68 match result {
69 Ok(val) => Some(val),
70 Err(err) => {
71 log::warn!("Error while autotuning {err:?}");
72 None
73 }
74 }
75 })
76 .collect();
77
78 if durations.is_empty() {
79 Err(AutotuneError::InvalidSamples)
80 } else {
81 Ok(durations)
82 }
83 }
84
85 fn warmup_full_error_handling(&self) -> Result<(), AutotuneError> {
86 let mut error = None;
87
88 let result = self.client.profile(
89 || {
90 if let Err(err) = self.operation.execute(self.inputs.clone()) {
91 error = Some(err);
92 }
93 },
94 self.operation.name(),
95 );
96
97 if let Err(err) = result {
98 return Err(AutotuneError::Unknown(format!("{err:?}")));
99 };
100
101 if let Some(err) = error {
102 return Err(err);
103 };
104
105 Ok(())
106 }
107 fn warmup_minimal_error_handling(&self) -> Result<(), AutotuneError> {
108 self.operation.execute(self.inputs.clone())?;
109 Ok(())
110 }
111}