cubecl_runtime/tune/
tune_benchmark.rs1use super::{AutotuneError, TuneFn, TuneInputs};
2use crate::{client::ComputeClient, runtime::Runtime};
3use alloc::string::ToString;
4use alloc::vec::Vec;
5use cubecl_common::profile::ProfileDuration;
6
/// Output type produced by an autotuned operation.
///
/// Implementors must be `Send + 'static`.
pub trait AutotuneOutput
where
    Self: Send + 'static,
{
    /// Compares this output against `other`, produced by another candidate on the same
    /// problem. Only present when the `autotune-checks` feature is enabled.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self);
}
14
15impl AutotuneOutput for () {
16 #[cfg(feature = "autotune-checks")]
17 fn check_equivalence(&self, _other: Self) {
18 }
20}
21
22pub fn tune_benchmark<'a, R: Runtime, F: TuneInputs, Out: AutotuneOutput>(
26 operation: &TuneFn<F, Out>,
27 inputs: <F as TuneInputs>::At<'a>,
28 client: ComputeClient<R>,
29) -> Result<Vec<ProfileDuration>, AutotuneError> {
30 client
33 .clone()
34 .exclusive(move || profile_exclusive(operation, inputs, client))
35 .map_err(|err| AutotuneError::Unknown {
36 name: operation.name.to_string(),
37 err: err.to_string(),
38 })?
39}
40
41fn profile_exclusive<'a, R: Runtime, F: TuneInputs, Out: AutotuneOutput>(
42 operation: &TuneFn<F, Out>,
43 inputs: <F as TuneInputs>::At<'a>,
44 client: ComputeClient<R>,
45) -> Result<Vec<ProfileDuration>, AutotuneError> {
46 warmup(operation, inputs.clone(), client.clone())?;
47
48 let num_samples = 10;
49 let mut durations = Vec::new();
50
51 for _ in 0..num_samples {
52 let result: Result<
53 (Result<Out, AutotuneError>, ProfileDuration),
54 crate::server::ProfileError,
55 > = {
56 let inputs = inputs.clone();
57
58 client.profile(
59 move || {
60 operation.execute(inputs)
63 },
64 &operation.name,
65 )
66 };
67
68 let result = match result {
69 Ok((out, duration)) => match out {
70 Ok(_) => Some(duration),
71 Err(err) => {
72 log::trace!("Error while autotuning {err:?}");
73 None
74 }
75 },
76 Err(err) => {
77 log::trace!("Error while autotuning {err:?}");
78 None
79 }
80 };
81
82 if let Some(item) = result {
83 durations.push(item);
84 }
85 }
86
87 if durations.is_empty() {
88 Err(AutotuneError::InvalidSamples {
89 name: operation.name.to_string(),
90 })
91 } else {
92 Ok(durations)
93 }
94}
95
96fn warmup<'a, R: Runtime, F: TuneInputs, Out: AutotuneOutput>(
97 operation: &TuneFn<F, Out>,
98 inputs: <F as TuneInputs>::At<'a>,
99 client: ComputeClient<R>,
100) -> Result<(), AutotuneError> {
101 let num_warmup = 3;
102
103 let mut errors = Vec::with_capacity(num_warmup);
104 let _errs = client.flush();
106
107 for _ in 0..num_warmup {
108 let inputs = inputs.clone();
109 let profiled = client.profile(move || operation.execute(inputs), &operation.name);
110
111 match profiled {
112 Ok(_) => {}
113 Err(err) => errors.push(err),
114 }
115 }
116
117 if errors.len() < num_warmup {
118 Ok(())
119 } else {
120 let msg = alloc::format!("{:?}", errors.remove(num_warmup - 1));
121 Err(AutotuneError::Unknown {
122 name: operation.name.to_string(),
123 err: msg,
124 })
125 }
126}