cubecl_runtime/tune/
tune_benchmark.rs1use super::{AutotuneError, TuneFn};
2use crate::{client::ComputeClient, runtime::Runtime};
3use alloc::format;
4use alloc::string::ToString;
5use alloc::sync::Arc;
6use alloc::vec::Vec;
7use cubecl_common::profile::{ProfileDuration, TimingMethod};
8
9#[derive(new)]
11pub struct TuneBenchmark<R: Runtime, In: Clone + Send + 'static, Out: Send + 'static> {
12 operation: Arc<dyn TuneFn<Inputs = In, Output = Out>>,
13 inputs: In,
14 client: ComputeClient<R>,
15}
16
/// Trait bound required of every value produced by an autotuned operation.
///
/// Implementors must be `Send + 'static`. Without the `autotune-checks` feature the trait
/// declares no methods; with it, implementors must also provide `check_equivalence`.
pub trait AutotuneOutput
where
    Self: Send + 'static,
{
    /// Compares this output against `other` (presumably the output of a different
    /// candidate run on the same problem — confirm against the autotune-checks callers).
    /// Only compiled when the `autotune-checks` feature is enabled.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self);
}
24
25impl AutotuneOutput for () {
26 #[cfg(feature = "autotune-checks")]
27 fn check_equivalence(&self, _other: Self) {
28 }
30}
31
32impl<R: Runtime, In: Clone + Send + 'static, Out: AutotuneOutput> TuneBenchmark<R, In, Out> {
33 pub fn profile(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
37 match self.client.properties().timing_method {
44 TimingMethod::System => self.warmup_full_error_handling(),
45 TimingMethod::Device => self.warmup_minimal_error_handling(),
46 }?;
47
48 let operation = self.operation;
49 let num_samples = 10;
50 let durations: Vec<_> = (0..num_samples)
51 .filter_map(|_| {
52 let result: Result<
53 (Result<Out, AutotuneError>, ProfileDuration),
54 crate::server::ProfileError,
55 > = self.client.profile(
56 || {
57 operation.execute(self.inputs.clone())
60 },
61 operation.name(),
62 );
63
64 match result {
65 Ok((out, duration)) => match out {
66 Ok(_) => Some(duration),
67 Err(err) => {
68 log::warn!("Error while autotuning {err:?}");
69 None
70 }
71 },
72 Err(err) => {
73 log::warn!("Error while autotuning {err:?}");
74 None
75 }
76 }
77 })
78 .collect();
79
80 if durations.is_empty() {
81 Err(AutotuneError::InvalidSamples {
82 name: operation.name().to_string(),
83 })
84 } else {
85 Ok(durations)
86 }
87 }
88
89 fn warmup_full_error_handling(&self) -> Result<(), AutotuneError> {
90 let mut error = None;
91
92 let result = self.client.profile(
93 || {
94 if let Err(err) = self.operation.execute(self.inputs.clone()) {
95 error = Some(err);
96 }
97 },
98 self.operation.name(),
99 );
100
101 if let Err(err) = result {
102 return Err(AutotuneError::Unknown {
103 name: self.operation.name().to_string(),
104 err: format!("{err:?}"),
105 });
106 };
107
108 if let Some(err) = error {
109 return Err(err);
110 };
111
112 Ok(())
113 }
114 fn warmup_minimal_error_handling(&self) -> Result<(), AutotuneError> {
115 self.operation.execute(self.inputs.clone())?;
116 Ok(())
117 }
118}