1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
use super::{AutotuneError, TuneFn};
use crate::{client::ComputeClient, runtime::Runtime};
use alloc::format;
use alloc::string::ToString;
use alloc::sync::Arc;
use alloc::vec::Vec;
use cubecl_common::profile::{ProfileDuration, TimingMethod};
/// A benchmark that runs on server handles.
///
/// Bundles one autotune candidate (`operation`) with the `inputs` it is measured on and the
/// `client` used to execute and time it. `TuneBenchmark::profile` takes the benchmark by value
/// and returns the measured durations.
///
/// The constructor is generated by `#[derive(new)]`; its arguments follow the field order
/// below, so reordering the fields changes the constructor signature seen by callers.
#[derive(new)]
pub struct TuneBenchmark<R: Runtime, In: Clone + Send + 'static, Out: Send + 'static> {
    /// Candidate operation under test, shared behind an `Arc`.
    operation: Arc<dyn TuneFn<Inputs = In, Output = Out>>,
    /// Inputs handed to the operation; they are cloned for every execution.
    inputs: In,
    /// Client used to execute and profile the operation.
    client: ComputeClient<R>,
}
/// The trait to be implemented by an autotune output.
///
/// Implementors must be `Send + 'static`. The only method, `check_equivalence`, is compiled
/// only when the `autotune-checks` feature is enabled; without that feature the trait has no
/// methods at all.
pub trait AutotuneOutput: Send + 'static {
    #[cfg(feature = "autotune-checks")]
    /// Checks if the output of an autotune operation is the same as another one on the same
    /// problem.
    ///
    /// `other` is taken by value. The method returns nothing, so a mismatch must be reported
    /// some other way (presumably a panic or a log message — left to each implementor; not
    /// visible from this declaration).
    fn check_equivalence(&self, other: Self);
}
/// An operation whose output is `()` produces nothing to compare, so the equivalence check
/// trivially succeeds and does nothing.
impl AutotuneOutput for () {
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, _: Self) {}
}
impl<R: Runtime, In: Clone + Send + 'static, Out: AutotuneOutput> TuneBenchmark<R, In, Out> {
/// Benchmark how long this operation takes for a number of samples.
///
/// Returns at least one duration, otherwise an error is returned.
pub fn profile(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
// If the inner operation need autotuning as well, we need to call it before. This will
// recurse and keep calling operations until a leaf operation tunes, and so on. This effectively
// does a depth-first traversal of the operation tree.
// For now we wrap the warmup operation inside a profiling task, since we have basic error
// handling for system timing methods.
match self.client.properties().timing_method {
TimingMethod::System => self.warmup_full_error_handling(),
TimingMethod::Device => self.warmup_minimal_error_handling(),
}?;
let operation = self.operation;
let num_samples = 10;
let durations: Vec<_> = (0..num_samples)
.filter_map(|_| {
let result: Result<
(Result<Out, AutotuneError>, ProfileDuration),
crate::server::ProfileError,
> = self.client.profile(
|| {
// It is important to return the output since otherwise deadcode elimination
// might optimize away code that needs to be profiled.
operation.execute(self.inputs.clone())
},
operation.name(),
);
match result {
Ok((out, duration)) => match out {
Ok(_) => Some(duration),
Err(err) => {
log::warn!("Error while autotuning {err:?}");
None
}
},
Err(err) => {
log::warn!("Error while autotuning {err:?}");
None
}
}
})
.collect();
if durations.is_empty() {
Err(AutotuneError::InvalidSamples {
name: operation.name().to_string(),
})
} else {
Ok(durations)
}
}
fn warmup_full_error_handling(&self) -> Result<(), AutotuneError> {
let mut error = None;
let result = self.client.profile(
|| {
if let Err(err) = self.operation.execute(self.inputs.clone()) {
error = Some(err);
}
},
self.operation.name(),
);
if let Err(err) = result {
return Err(AutotuneError::Unknown {
name: self.operation.name().to_string(),
err: format!("{err:?}"),
});
};
if let Some(err) = error {
return Err(err);
};
Ok(())
}
fn warmup_minimal_error_handling(&self) -> Result<(), AutotuneError> {
self.operation.execute(self.inputs.clone())?;
Ok(())
}
}