use alloc::boxed::Box;
use alloc::format;
use alloc::sync::Arc;
use alloc::vec::Vec;
use async_channel::{Receiver, Sender};
use cubecl_common::format::format_debug;
use cubecl_common::profile::ProfileDuration;
use hashbrown::HashSet;

use core::time::Duration;

use alloc::string::{String, ToString};
use cubecl_common::benchmark::{BenchmarkComputations, BenchmarkDurations};

use crate::client::ComputeClient;
use crate::config::{Logger, autotune::AutotuneLogLevel};
use crate::server::ComputeServer;
use crate::tune::{TuneBenchmark, TuneCache};

use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult, TuneFn, TunePlan};

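/// Orchestrates autotuning for a single device: benchmarks the registered
/// tunables for a given key, caches the index of the fastest one, and logs the
/// outcome according to the configured autotune log level.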
#[derive(Debug)]
pub struct Tuner<K: AutotuneKey> {
    tune_cache: TuneCache<K>,
    logger: Logger,
    channel: (Sender<AutotuneMessage<K>>, Receiver<AutotuneMessage<K>>),
    pub(crate) autotuning: HashSet<K>,
}

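/// Benchmark outcome for a single tunable: its name, its index in the tunable
/// set, and the statistics computed from the measured durations.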
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize, PartialEq, Eq))]
#[derive(new, Debug, Clone)]
pub struct AutotuneOutcome {
    name: String,
    index: usize,
    computation: BenchmarkComputations,
}

impl core::fmt::Display for AutotuneOutcome {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(
            f,
            "Autotune[{}] name {} => {:?}",
            self.index, self.name, self.computation
        )
    }
}

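/// Message sent back to the [`Tuner`] by a tuning task, either with the final
/// benchmark results or to signal that tuning for a key is still in flight.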
enum AutotuneMessage<K> {
    Done {
        key: K,
        fastest_index: usize,
        results: Vec<Result<AutotuneOutcome, AutotuneError>>,
        #[cfg(std_io)]
        checksum: String,
    },
    #[allow(dead_code)]
    Pending(K),
}

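/// Error that can occur while benchmarking a tunable.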
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum AutotuneError {
    Unknown(String),
    InvalidSamples,
    NoValidKernelFound {
        context: String,
    },
    Skip,
}

impl From<String> for AutotuneError {
    fn from(value: String) -> Self {
        Self::Unknown(value)
    }
}

#[allow(clippy::new_without_default)]
impl<K: AutotuneKey> Tuner<K> {
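    /// Creates a tuner whose cache is identified by the given runtime `name`
    /// and `device_id`.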
    pub fn new(name: &str, device_id: &str) -> Self {
        let channel = async_channel::unbounded();

        Self {
            tune_cache: TuneCache::new(name, device_id),
            logger: Logger::new(),
            channel,
            autotuning: HashSet::new(),
        }
    }

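    /// Looks up `key` in the tune cache and returns its current state.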
    pub fn fastest(&self, key: &K) -> TuneCacheResult {
        self.tune_cache.fastest(key)
    }

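    /// Validates the persistent cache entry for `key` against `checksum`,
    /// logging the check when full autotune logging is enabled.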
    #[cfg(std_io)]
    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
        if let AutotuneLogLevel::Full = self.logger.log_level_autotune() {
            self.logger
                .log_autotune(&format!("validate checksum key={key}, checksum={checksum}"));
        }
        self.tune_cache.validate_checksum(key, checksum)
    }

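    /// Applies one autotune message: either marks the key as pending or
    /// records the fastest variant in the cache(s) and logs the results.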
    fn handle_result(&mut self, msg: AutotuneMessage<K>) {
        match msg {
            AutotuneMessage::Pending(key) => {
                self.tune_cache.mark_pending(key);
            }
            AutotuneMessage::Done {
                key,
                fastest_index,
                results,
                #[cfg(std_io)]
                checksum,
            } => {
                match self.logger.log_level_autotune() {
                    AutotuneLogLevel::Minimal => {
                        let top_times = results
                            .iter()
                            .map(|r| {
                                let time = r
                                    .as_ref()
                                    .map(|r| r.computation.median)
                                    .unwrap_or(Duration::MAX);

                                let index = r.as_ref().map(|r| r.index).unwrap_or_default();
                                (index, time)
                            })
                            .take(3)
                            .collect::<Vec<_>>();

                        let result = results
                            .first()
                            .expect("At least one kernel needed.")
                            .as_ref()
                            .expect("At least one kernel has to succeed.");

                        self.logger.log_autotune(&format!(
                            "Fastest result {}-{key}. \n Top 3 times: {top_times:?}",
                            result.name,
                        ));
                    }
                    AutotuneLogLevel::Full => {
                        let result = results
                            .first()
                            .expect("At least one kernel needed.")
                            .as_ref()
                            .expect("At least one kernel has to succeed.");

                        self.logger
                            .log_autotune(&format!("Fastest result {}-{key}.", result.name));

                        for result in results.iter() {
                            match result {
                                Ok(val) => {
                                    self.logger.log_autotune(&format!("{val}"));
                                }
                                Err(err) => self.logger.log_autotune(&format!("{err:?}")),
                            }
                        }
                    }
                    AutotuneLogLevel::Disabled => {}
                };

                self.tune_cache.cache_insert(key.clone(), fastest_index);

                #[cfg(std_io)]
                {
                    self.tune_cache
                        .persistent_cache_insert(key, checksum, fastest_index, results);
                }
            }
        }
    }

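    /// Drains every message currently queued on the channel and applies it.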
    pub fn handle_results(&mut self) {
        while let Ok(msg) = self.channel.1.try_recv() {
            self.handle_result(msg);
        }
    }

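    /// Prepares the autotuning of `key` and returns a closure that runs the
    /// benchmarks when called. With a single tunable the result is trivially
    /// known and a `Done` message is sent right away; otherwise the tune plan
    /// is executed, blocking on native targets and spawned as a local future
    /// on wasm (where a `Pending` message is sent in the meantime).
    ///
    /// A rough usage sketch (the `tuner`, `key`, `inputs`, `tunables` and
    /// `client` bindings are illustrative and have to come from elsewhere):
    ///
    /// ```ignore
    /// let run = tuner.prepare_autotune(key, &inputs, &tunables, &client);
    /// run(); // benchmark the candidates and send the resulting message
    /// tuner.handle_results(); // fold the message(s) into the tune cache
    /// ```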
    pub fn prepare_autotune<
        S: ComputeServer + 'static,
        In: Clone + Send + 'static,
        Out: AutotuneOutput,
    >(
        &self,
        key: K,
        inputs: &In,
        tunables: &TunableSet<K, In, Out>,
        client: &ComputeClient<S>,
    ) -> Box<dyn FnOnce()> {
        log::info!("Tuning {key}");

        let sender = self.channel.0.clone();

        let autotunables = tunables.autotunables();
        let mut results = Vec::with_capacity(autotunables.len());

        for _ in 0..autotunables.len() {
            results.push(Err(AutotuneError::Skip));
        }

        if autotunables.len() == 1 {
            let message = AutotuneMessage::Done {
                key,
                fastest_index: 0,
                results,
                #[cfg(std_io)]
                checksum: tunables.compute_checksum(),
            };

            return Box::new(move || {
                sender
                    .try_send(message)
                    .expect("Lost message channel somehow")
            });
        }

        let client = client.clone();
        let key_cloned = key.clone();
        let plan = tunables.plan(&key);
        let inputs_generator = tunables.inputs_generator(&key, inputs);

        #[cfg(std_io)]
        let checksum = tunables.compute_checksum();

        let fut_result = async move {
            let test_inputs = inputs_generator();

            Self::generate_tune_message(
                key_cloned,
                &client,
                plan,
                autotunables,
                test_inputs,
                results,
                #[cfg(std_io)]
                checksum,
            )
            .await
        };

        Box::new(move || {
            let message = {
                cfg_if::cfg_if! {
                    if #[cfg(target_family = "wasm")] {
                        let sender = sender.clone();

                        let send_fut = async move {
                            let _ = sender.send(fut_result.await).await;
                        };
                        wasm_bindgen_futures::spawn_local(send_fut);
                        AutotuneMessage::Pending(key)
                    } else {
                        cubecl_common::future::block_on(fut_result)
                    }
                }
            };

            sender
                .try_send(message)
                .expect("Lost message channel somehow");
        })
    }

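    /// Executes the tune plan, sorts the outcomes by median duration and
    /// builds the `Done` message pointing at the fastest variant.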
    async fn generate_tune_message<
        In: Clone + Send + 'static,
        Out: AutotuneOutput,
        S: ComputeServer + 'static,
    >(
        key: K,
        client: &ComputeClient<S>,
        mut plan: TunePlan,
        autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
        test_inputs: In,
        mut results: Vec<Result<AutotuneOutcome, AutotuneError>>,
        #[cfg(std_io)] checksum: String,
    ) -> AutotuneMessage<K> {
        match Self::execute_tune_plan(client, &mut plan, autotunables, &test_inputs, &mut results)
            .await
        {
            Ok(_) => {}
            Err(err) => {
                panic!("Can't execute the autotune plan for key: {key:?}\n - Error: {err:?}");
            }
        };

        results.sort_by(|a, b| {
            let a = a
                .as_ref()
                .map(|r| r.computation.median)
                .unwrap_or(Duration::MAX);
            let b = b
                .as_ref()
                .map(|r| r.computation.median)
                .unwrap_or(Duration::MAX);

            a.cmp(&b)
        });

        let result = results
            .first()
            .expect("At least one kernel needed.")
            .as_ref()
            .expect("At least one kernel has to succeed.");

        AutotuneMessage::Done {
            key,
            fastest_index: result.index,
            results,
            #[cfg(std_io)]
            checksum,
        }
    }

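    /// Benchmarks the tunables proposed by the plan, group by group, stopping
    /// after the first group in which at least one tunable succeeds. Fails if
    /// the plan runs out of candidates before any success.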
    async fn execute_tune_plan<
        In: Clone + Send + 'static,
        Out: AutotuneOutput,
        S: ComputeServer + 'static,
    >(
        client: &ComputeClient<S>,
        plan: &mut TunePlan,
        autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
        test_inputs: &In,
        results: &mut [Result<AutotuneOutcome, AutotuneError>],
    ) -> Result<(), AutotuneError> {
        #[derive(Debug)]
        #[allow(unused_variables, dead_code)]
        struct Context<'a> {
            plan: &'a TunePlan,
            results: &'a [Result<AutotuneOutcome, AutotuneError>],
        }

        loop {
            let mut num_success = 0;
            let tunable_indices = plan.next();

            if tunable_indices.is_empty() {
                return Err(AutotuneError::NoValidKernelFound {
                    context: format_debug(&Context { plan, results }),
                });
            }

            for index in tunable_indices {
                let op = &autotunables[index];
                let name = op.name().to_string();
                let tuner = TuneBenchmark::new(op.clone(), test_inputs.clone(), client.clone());
                let profiles = tuner.profile().map(|bench| (name, index, bench));

                match profiles {
                    Ok(result) => {
                        let (name, index, profiles) = result;
                        let result = Self::process_autotune(name, index, profiles).await;
                        match result {
                            Ok(val) => {
                                results[index] = Ok(val);
                                num_success += 1;
                            }
                            Err(err) => {
                                results[index] = Err(err);
                            }
                        }
                    }
                    Err(err) => {
                        results[index] = Err(err);
                    }
                }
            }

            if num_success > 0 {
                break;
            }
        }

        Ok(())
    }

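    /// Resolves the collected profiles into durations and summarizes them as
    /// an [`AutotuneOutcome`].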
    async fn process_autotune(
        name: String,
        index: usize,
        profiles: Vec<ProfileDuration>,
    ) -> Result<AutotuneOutcome, AutotuneError> {
        let mut durations = Vec::new();
        if !profiles.is_empty() {
            let timing_method = profiles.first().unwrap().timing_method();
            for profile in profiles {
                durations.push(profile.resolve().await.duration());
            }
            let bench_durations = BenchmarkDurations::from_durations(timing_method, durations);

            Ok(AutotuneOutcome::new(
                name,
                index,
                BenchmarkComputations::new(&bench_durations),
            ))
        } else {
            Err(AutotuneError::Unknown(format!(
                "Runtime error while profiling {name}."
            )))
        }
    }
}

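/// Compares every successful autotune output against the last one to make sure
/// all tunables produce equivalent results.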
#[cfg(feature = "autotune-checks")]
pub(crate) fn check_autotune_outputs<O: AutotuneOutput>(
    mut checks_outputs: Vec<Result<O, AutotuneError>>,
) {
    let reference = checks_outputs.remove(checks_outputs.len() - 1);

    if let Ok(reference) = reference {
        for other in checks_outputs.into_iter().flatten() {
            reference.check_equivalence(other);
        }
    }
}