cubecl_runtime/tune/
tune_benchmark.rs1use super::{AutotuneError, TuneFn};
2use crate::{client::ComputeClient, runtime::Runtime};
3use alloc::string::ToString;
4use alloc::sync::Arc;
5use alloc::vec::Vec;
6use cubecl_common::profile::ProfileDuration;
7
8#[derive(new)]
10pub struct TuneBenchmark<R: Runtime, In: Clone + Send + 'static, Out: Send + 'static> {
11 operation: Arc<dyn TuneFn<Inputs = In, Output = Out>>,
12 inputs: In,
13 client: ComputeClient<R>,
14}
15
/// Marker for types that can be produced by an autotuned operation.
///
/// Every implementor must be `Send + 'static`. When the `autotune-checks` feature is enabled,
/// implementors must also provide `check_equivalence`.
pub trait AutotuneOutput
where
    Self: Send + 'static,
{
    /// Hook for checking `self` against `other`, another output of the same type.
    ///
    /// Only part of the trait when the `autotune-checks` feature is enabled; what counts as
    /// "equivalent" is left to each implementor.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self);
}
23
24impl AutotuneOutput for () {
25 #[cfg(feature = "autotune-checks")]
26 fn check_equivalence(&self, _other: Self) {
27 }
29}
30
31impl<R: Runtime, In: Clone + Send + 'static, Out: AutotuneOutput> TuneBenchmark<R, In, Out> {
32 pub fn profile(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
36 let client = self.client.clone();
37 let name = self.operation.name().to_string();
38
39 client
40 .exclusive(move || self.profile_exclusive())
41 .map_err(|err| AutotuneError::Unknown {
42 name,
43 err: err.to_string(),
44 })?
45 }
46
47 fn profile_exclusive(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
48 self.warmup()?;
49
50 let operation = self.operation.clone();
51 let name = operation.name().to_string();
52 let num_samples = 10;
53 let mut durations = Vec::new();
54 for _ in 0..num_samples {
55 let result: Result<
56 (Result<Out, AutotuneError>, ProfileDuration),
57 crate::server::ProfileError,
58 > = {
59 let inputs = self.inputs.clone();
60 let operation = operation.clone();
61
62 self.client.profile(
63 move || {
64 operation.execute(inputs)
67 },
68 &name,
69 )
70 };
71
72 let result = match result {
73 Ok((out, duration)) => match out {
74 Ok(_) => Some(duration),
75 Err(err) => {
76 log::trace!("Error while autotuning {err:?}");
77 None
78 }
79 },
80 Err(err) => {
81 log::trace!("Error while autotuning {err:?}");
82 None
83 }
84 };
85
86 if let Some(item) = result {
87 durations.push(item);
88 }
89 }
90
91 if durations.is_empty() {
92 Err(AutotuneError::InvalidSamples { name })
93 } else {
94 Ok(durations)
95 }
96 }
97
98 fn warmup(&self) -> Result<(), AutotuneError> {
99 let num_warmup = 3;
100
101 let mut errors = Vec::with_capacity(num_warmup);
102 let _errs = self.client.flush();
104
105 for _ in 0..num_warmup {
106 let op = self.operation.clone();
107 let inputs = self.inputs.clone();
108 let profiled = self
109 .client
110 .profile(move || op.execute(inputs), self.operation.name());
111
112 match profiled {
113 Ok(_) => {}
114 Err(err) => errors.push(err),
115 }
116 }
117
118 if errors.len() < num_warmup {
119 Ok(())
120 } else {
121 let msg = alloc::format!("{:?}", errors.remove(num_warmup - 1));
122 Err(AutotuneError::Unknown {
123 name: self.operation.name().to_string(),
124 err: msg,
125 })
126 }
127 }
128}