cubecl_runtime/tune/tuner.rs

use async_channel::{Receiver, Sender};
use cubecl_common::future;

use core::any::Any;
use core::future::Future;
use core::mem::ManuallyDrop;
use cubecl_common::stub::Duration;

#[cfg(all(not(target_family = "wasm"), feature = "std"))]
use std::panic::resume_unwind;

use alloc::boxed::Box;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use cubecl_common::benchmark::BenchmarkComputations;

use crate::channel::ComputeChannel;
use crate::client::ComputeClient;
use crate::server::ComputeServer;
use crate::tune::{AutotuneOperationSet, TuneBenchmark, TuneCache};

use super::{AutotuneKey, TuneCacheResult};

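/// Tunes sets of operations by benchmarking the candidate implementations and caching the
/// fastest one for each autotune key.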
#[derive(Debug)]
pub struct Tuner<K: AutotuneKey> {
    tune_cache: TuneCache<K>,
    channel: (Sender<AutotuneMessage<K>>, Receiver<AutotuneMessage<K>>),
}

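/// Message sent over the tuner's channel to mark a key as pending or to record its fastest
/// operation once benchmarking is done.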
enum AutotuneMessage<K> {
    Done {
        key: K,
        fastest_index: usize,
        #[cfg(autotune_persistent_cache)]
        checksum: String,
    },
    Starting {
        key: K,
    },
}

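/// Error produced while benchmarking an autotune operation.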
#[derive(Debug)]
pub enum AutotuneError {
    Unknown(String),
    PanicUnwind(ManuallyDrop<Box<dyn Any + Send>>),
}

impl From<String> for AutotuneError {
    fn from(value: String) -> Self {
        Self::Unknown(value)
    }
}

#[allow(clippy::new_without_default)]
impl<K: AutotuneKey> Tuner<K> {
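    /// Create a new tuner identified by a name and device id.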
    pub fn new(name: &str, device_id: &str) -> Self {
        let channel = async_channel::unbounded();

        Self {
            tune_cache: TuneCache::new(name, device_id),
            channel,
        }
    }

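    /// Look up the autotune cache entry for the given key.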
    pub fn fastest(&self, key: &K) -> TuneCacheResult {
        self.tune_cache.fastest(key)
    }

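    /// Validate that the persistent cache entry for this key matches the given checksum.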
    #[cfg(autotune_persistent_cache)]
    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
        self.tune_cache.validate_checksum(key, checksum)
    }

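    /// Drain pending autotune messages from the channel and update the cache accordingly.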
    pub fn resolve(&mut self) {
        while let Ok(msg) = self.channel.1.try_recv() {
            match msg {
                AutotuneMessage::Done {
                    key,
                    fastest_index,
                    #[cfg(autotune_persistent_cache)]
                    checksum,
                } => {
                    self.tune_cache.cache_insert(key.clone(), fastest_index);

                    #[cfg(autotune_persistent_cache)]
                    {
                        self.tune_cache
                            .persistent_cache_insert(key, checksum, fastest_index);
                        self.tune_cache.save();
                    }
                }
                AutotuneMessage::Starting { key } => {
                    self.tune_cache.mark_pending(key);
                }
            }
        }
    }

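    /// Benchmark every operation of the set that should run and record the fastest one for
    /// the set's key.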
    pub fn execute_autotune<
        S: ComputeServer + 'static,
        C: ComputeChannel<S> + 'static,
        Out: Send + 'static,
    >(
        &self,
        set: &dyn AutotuneOperationSet<K, Out>,
        client: &ComputeClient<S, C>,
    ) {
        let key = set.key();
        log::info!("Tuning {key}");

        let autotunables: Vec<_> = set
            .autotunables()
            .into_iter()
            .enumerate()
            .filter(|(index, _)| set.should_run(&key, *index))
            .collect();

        let client = client.clone();
        let sender = self.channel.0.clone();

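        // With a single candidate there is nothing to benchmark: record it as the fastest
        // right away.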
        if autotunables.len() == 1 {
            sender
                .try_send(AutotuneMessage::Done {
                    key,
                    fastest_index: autotunables[0].0,
                    #[cfg(autotune_persistent_cache)]
                    checksum: set.compute_checksum(),
                })
                .expect("Autotune results channel closed");
            return;
        }

        sender
            .try_send(AutotuneMessage::Starting { key: key.clone() })
            .expect("Autotune results channel closed");

        #[cfg(autotune_persistent_cache)]
        let checksum = set.compute_checksum();

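        // Benchmark the remaining candidates in a task: spawned on wasm, run inline on native.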
        spawn_benchmark_task(async move {
            #[derive(new, Debug)]
            struct BenchResult {
                name: String,
                index: usize,
                computation: BenchmarkComputations,
            }

            let mut bench_results = Vec::with_capacity(autotunables.len());

            for (index, op) in autotunables.into_iter() {
                let name = op.name().to_string();
                let tuner = TuneBenchmark::new(op, client.clone());

                let sample_fut = tuner.sample_durations();
                let sample_fut = future::catch_unwind(sample_fut);
                let result = sample_fut.await;

                let result = match result {
                    Ok(result) => result,
                    Err(err) => {
                        log::warn!(
                            "Caught unknown error while benchmarking, falling back to next operation."
                        );
                        Err(AutotuneError::PanicUnwind(ManuallyDrop::new(err)))
                    }
                };

                let result = result.map(|durations| {
                    log::info!("Name: {name} => {}", durations);
                    BenchResult::new(name, index, BenchmarkComputations::new(&durations))
                });

                bench_results.push(result);
            }

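            // If every candidate failed, re-raise the first error instead of silently picking
            // one of the failures.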
            #[cfg(all(feature = "std", not(target_family = "wasm")))]
            if bench_results.iter().all(|result| result.is_err()) {
                let first_error = bench_results.into_iter().next().unwrap().err().unwrap();

                match first_error {
                    AutotuneError::Unknown(reason) => panic!("{reason}"),
                    AutotuneError::PanicUnwind(err) => {
                        resume_unwind(ManuallyDrop::into_inner(err));
                    }
                }
            }

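            // Sort by median duration; failed runs count as Duration::MAX and sort last.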
            bench_results.sort_by(|a, b| {
                let a = a
                    .as_ref()
                    .map(|r| r.computation.median)
                    .unwrap_or(Duration::MAX);
                let b = b
                    .as_ref()
                    .map(|r| r.computation.median)
                    .unwrap_or(Duration::MAX);

                a.cmp(&b)
            });

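            // The first entry is now the fastest; if even it failed, fall back to index 0.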
            let result = bench_results.first().expect("At least one kernel needed.");

            let fastest_index = if let Ok(result) = result {
                let top_times = bench_results
                    .iter()
                    .map(|r| {
                        r.as_ref()
                            .map(|r| r.computation.median)
                            .unwrap_or(Duration::MAX)
                    })
                    .take(3)
                    .collect::<Vec<_>>();
                log::info!(
                    "Fastest result {}-{key}. \n Top 3 times: {top_times:?}",
                    result.name,
                );

                result.index
            } else {
                0
            };

            sender
                .send(AutotuneMessage::Done {
                    key,
                    fastest_index,
                    #[cfg(autotune_persistent_cache)]
                    checksum,
                })
                .await
                .expect("Autotune results channel closed");
        });
    }
}

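/// Run a benchmarking future: spawned as a local task on wasm, blocked on inline elsewhere.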
fn spawn_benchmark_task(future: impl Future<Output = ()> + Send + 'static) {
    #[cfg(target_family = "wasm")]
    wasm_bindgen_futures::spawn_local(future);

    #[cfg(not(target_family = "wasm"))]
    future::block_on(future);
}