1use alloc::boxed::Box;
2use alloc::format;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use async_channel::{Receiver, Sender};
6use cubecl_common::format::format_debug;
7use cubecl_common::profile::ProfileDuration;
8use hashbrown::HashSet;
9
10use core::time::Duration;
11
12use alloc::string::{String, ToString};
13use cubecl_common::benchmark::{BenchmarkComputations, BenchmarkDurations};
14
15use crate::config::{Logger, autotune::AutotuneLogLevel};
16use crate::server::LaunchError;
17use crate::tune::{AutotuneResult, TuneBenchmark, TuneCache};
18use crate::{client::ComputeClient, runtime::Runtime};
19
20use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult, TuneFn, TunePlan};
21
22#[derive(Debug)]
23pub struct Tuner<K: AutotuneKey> {
25 tune_cache: TuneCache<K>,
26 logger: Logger,
27 channel: (Sender<AutotuneMessage<K>>, Receiver<AutotuneMessage<K>>),
28 pub(crate) autotuning: HashSet<K>,
29}
30
31#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
33#[derive(new, Debug, Clone, PartialEq, Eq)]
34pub struct AutotuneOutcome {
35 name: String,
36 index: usize,
37 computation: BenchmarkComputations,
38}
39
40impl core::fmt::Display for AutotuneOutcome {
41 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
42 write!(
43 f,
44 "Autotune[{}] name {} => {:?}",
45 self.index, self.name, self.computation
46 )
47 }
48}
49
50enum AutotuneMessage<K> {
51 Done {
52 key: K,
53 fastest_index: usize,
54 results: Vec<AutotuneResult>,
55 #[cfg(std_io)]
56 checksum: String,
57 context_logs: Option<String>,
58 },
59 #[allow(dead_code)]
60 Pending(K),
61}
62
63#[derive(Debug, Clone)]
65#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
66pub enum AutotuneError {
67 Unknown {
69 name: String,
71 err: String,
73 },
74 InvalidSamples {
76 name: String,
78 },
79 NoValidKernelFound {
85 context: String,
87 },
88 Skip {
90 name: String,
92 },
93
94 Launch(LaunchError),
96}
97
98impl From<LaunchError> for AutotuneError {
99 fn from(value: LaunchError) -> Self {
100 Self::Launch(value)
101 }
102}
103
104#[allow(clippy::new_without_default)]
105impl<K: AutotuneKey> Tuner<K> {
106 pub fn new(name: &str, device_id: &str) -> Self {
108 let channel = async_channel::unbounded();
109
110 Self {
111 tune_cache: TuneCache::new(name, device_id),
112 logger: Logger::new(),
113 channel,
114 autotuning: HashSet::new(),
115 }
116 }
117
118 pub fn fastest(&self, key: &K) -> TuneCacheResult {
120 self.tune_cache.fastest(key)
121 }
122
123 #[cfg(std_io)]
125 pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
126 if let AutotuneLogLevel::Full = self.logger.log_level_autotune() {
127 self.logger
128 .log_autotune(&format!("validate checksum key={key}, checksum={checksum}"));
129 }
130 self.tune_cache.validate_checksum(key, checksum)
131 }
132
133 fn handle_result(&mut self, msg: AutotuneMessage<K>) {
135 match msg {
136 AutotuneMessage::Pending(key) => {
137 self.tune_cache.mark_pending(key);
138 }
139 AutotuneMessage::Done {
140 key,
141 fastest_index,
142 results,
143 #[cfg(std_io)]
144 checksum,
145 context_logs,
146 } => {
147 match self.logger.log_level_autotune() {
148 AutotuneLogLevel::Minimal => {
149 let top_times = results
150 .iter()
151 .map(|r| {
152 let time = r
153 .outcome
154 .as_ref()
155 .map(|r| r.computation.median)
156 .unwrap_or(Duration::MAX);
157
158 let index = r.outcome.as_ref().map(|r| r.index).unwrap_or_default();
159 (index, time)
160 })
161 .take(3)
162 .collect::<Vec<_>>();
163
164 let result = results
165 .first()
166 .expect("At least one kernel needed.")
167 .outcome
168 .as_ref()
169 .expect("At least one kernel has to succeed.");
170
171 let context = match &context_logs {
172 Some(context) => context,
173 None => "",
174 };
175 self.logger.log_autotune(&format!(
176 "Fastest result {}-{key}. \n Top 3 times: {top_times:?}, context: {context}",
177 result.name,
178 ));
179 }
180 AutotuneLogLevel::Full => {
181 let result = results
182 .first()
183 .expect("At least one kernel needed.")
184 .outcome
185 .as_ref()
186 .expect("At least one kernel has to succeed.");
187
188 let context = match &context_logs {
189 Some(context) => context,
190 None => "",
191 };
192 self.logger.log_autotune(&format!(
193 "Fastest result {}-{key}. Context: {context}",
194 result.name,
195 ));
196
197 for result in results.iter() {
198 match &result.outcome {
199 Ok(val) => {
200 self.logger.log_autotune(&format!("{val}"));
201 }
202 Err(err) => self.logger.log_autotune(&format!("{err:?}")),
203 }
204 }
205 }
206 AutotuneLogLevel::Disabled => {}
207 };
208
209 self.tune_cache.cache_insert(key.clone(), fastest_index);
210
211 #[cfg(std_io)]
212 {
213 self.tune_cache
214 .persistent_cache_insert(key, checksum, fastest_index, results);
215 }
216 }
217 }
218 }
219
220 pub fn handle_results(&mut self) {
222 while let Ok(msg) = self.channel.1.try_recv() {
225 self.handle_result(msg);
226 }
227 }
228
229 pub fn prepare_autotune<R: Runtime, In: Clone + Send + 'static, Out: AutotuneOutput>(
231 &self,
232 key: K,
233 inputs: &In,
234 tunables: &TunableSet<K, In, Out>,
235 client: &ComputeClient<R>,
236 ) -> Box<dyn FnOnce()> {
237 log::info!("Tuning {key}");
238
239 let sender = self.channel.0.clone();
241
242 let autotunables = tunables.autotunables();
243 let mut results: Vec<AutotuneResult> = Vec::with_capacity(autotunables.len());
244
245 for a in autotunables.iter() {
246 results.push(AutotuneResult::error(AutotuneError::Skip {
247 name: a.name().to_string(),
248 }));
249 }
250
251 if autotunables.len() == 1 {
252 let message = AutotuneMessage::Done {
253 key,
254 fastest_index: 0,
255 results,
256 #[cfg(std_io)]
257 checksum: tunables.compute_checksum(),
258 context_logs: None,
259 };
260
261 return Box::new(move || {
262 sender
263 .try_send(message)
264 .expect("Loss message channel somehow")
265 });
266 }
267
268 let client = client.clone();
269 let key_cloned = key.clone();
270 let plan = tunables.plan(&key);
271 let inputs_generator = tunables.inputs_generator(&key.clone(), inputs);
272
273 #[cfg(std_io)]
274 let checksum = tunables.compute_checksum();
275 let context_logs = match self.logger.log_level_autotune() {
276 AutotuneLogLevel::Disabled => false,
277 AutotuneLogLevel::Minimal => false,
278 AutotuneLogLevel::Full => true,
279 };
280
281 let fut_result = async move {
282 let test_inputs = inputs_generator();
283
284 Self::generate_tune_message(
285 key_cloned,
286 &client,
287 plan,
288 autotunables,
289 test_inputs,
290 results,
291 #[cfg(std_io)]
292 checksum,
293 context_logs,
294 )
295 .await
296 };
297
298 Box::new(move || {
299 let message = {
300 cfg_if::cfg_if! {
301 if #[cfg(target_family = "wasm")] {
302 let sender = sender.clone();
303
304 let send_fut = async move {
305 let _ = sender.send(fut_result.await).await;
308 };
309 wasm_bindgen_futures::spawn_local(send_fut);
311 AutotuneMessage::Pending(key)
313 } else {
314 cubecl_common::future::block_on(fut_result)
315 }
316 }
317 };
318
319 sender
321 .try_send(message)
322 .expect("Loss message channel somehow");
323 })
324 }
325
326 #[allow(clippy::too_many_arguments)]
327 async fn generate_tune_message<In: Clone + Send + 'static, Out: AutotuneOutput, R: Runtime>(
328 key: K,
329 client: &ComputeClient<R>,
330 mut plan: TunePlan,
331 autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
332 test_inputs: In,
333 mut results: Vec<AutotuneResult>,
334 #[cfg(std_io)] checksum: String,
335 context_logs: bool,
336 ) -> AutotuneMessage<K> {
337 let context_logs = match Self::execute_tune_plan(
338 client,
339 &mut plan,
340 autotunables,
341 &test_inputs,
342 &mut results,
343 context_logs,
344 )
345 .await
346 {
347 Ok(context_logs) => context_logs,
348 Err(err) => {
349 panic!("Can't execute the autotune plan for key: {key:?}\n - Error: {err:?}");
350 }
351 };
352
353 results.sort_by(|a, b| {
355 let a = a
356 .outcome
357 .as_ref()
358 .map(|r| r.computation.median)
359 .unwrap_or(Duration::MAX);
360 let b = b
361 .outcome
362 .as_ref()
363 .map(|r| r.computation.median)
364 .unwrap_or(Duration::MAX);
365
366 a.cmp(&b)
367 });
368
369 let result = results
371 .first()
372 .expect("At least one kernel needed.")
373 .outcome
374 .as_ref()
375 .expect("At least one kernel has to succeed.");
376
377 AutotuneMessage::Done {
378 key,
379 fastest_index: result.index,
380 results,
381 #[cfg(std_io)]
382 checksum,
383 context_logs,
384 }
385 }
386
387 async fn execute_tune_plan<In: Clone + Send + 'static, Out: AutotuneOutput, R: Runtime>(
388 client: &ComputeClient<R>,
389 plan: &mut TunePlan,
390 autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
391 test_inputs: &In,
392 results: &mut [AutotuneResult],
393 context_logs: bool,
394 ) -> Result<Option<String>, AutotuneError> {
395 #[derive(Debug)]
396 #[allow(unused_variables, dead_code)] struct Context<'a> {
398 plan: &'a TunePlan,
399 results: &'a [AutotuneResult],
400 }
401
402 let mut context_logs = match context_logs {
403 true => Some("".to_string()),
404 false => None,
405 };
406
407 loop {
408 let mut num_success = 0;
409 let tunable_indices = plan.next(context_logs.as_mut());
410
411 if tunable_indices.is_empty() {
412 return Err(AutotuneError::NoValidKernelFound {
413 context: format_debug(&Context { plan, results }),
414 });
415 }
416
417 for index in tunable_indices {
418 let op = &autotunables[index];
419 let name = op.name().to_string();
420 let tuner = TuneBenchmark::new(op.clone(), test_inputs.clone(), client.clone());
421 let profiles = tuner.profile().map(|bench| (name, index, bench));
422
423 match profiles {
424 Ok(result) => {
425 let (name, index, profiles) = result;
427 let result = Self::process_autotune(name, index, profiles).await;
428 match result {
429 Ok(val) => {
430 results[index] = AutotuneResult::success(val);
431 num_success += 1;
432 }
433 Err(err) => {
434 results[index] = AutotuneResult::error(err);
435 }
436 }
437 }
438 Err(err) => {
439 results[index] = AutotuneResult::error(err);
440 }
441 }
442 }
443
444 if num_success > 0 {
445 break;
446 }
447 }
448
449 Ok(context_logs)
450 }
451
452 async fn process_autotune(
453 name: String,
454 index: usize,
455 profiles: Vec<ProfileDuration>,
456 ) -> Result<AutotuneOutcome, AutotuneError> {
457 let mut durations = Vec::new();
458 if !profiles.is_empty() {
459 let timing_method = profiles.first().unwrap().timing_method();
460 for profile in profiles {
461 durations.push(profile.resolve().await.duration());
462 }
463 let bench_durations = BenchmarkDurations::from_durations(timing_method, durations);
464
465 Ok(AutotuneOutcome::new(
466 name,
467 index,
468 BenchmarkComputations::new(&bench_durations),
469 ))
470 } else {
471 Err(AutotuneError::Unknown {
472 name,
473 err: "No profiling available".to_string(),
474 })
475 }
476 }
477}
478
/// Checks that every successful autotune output is equivalent to the reference output.
///
/// The last entry of `checks_outputs` is the reference. Nothing is checked when the vector is
/// empty or when the reference itself is an error; failed non-reference outputs are skipped.
#[cfg(feature = "autotune-checks")]
pub(crate) fn check_autotune_outputs<O: AutotuneOutput>(
    mut checks_outputs: Vec<Result<O, AutotuneError>>,
) {
    // `pop` avoids the index underflow that `remove(len - 1)` hits on an empty vector.
    if let Some(Ok(reference)) = checks_outputs.pop() {
        for other in checks_outputs.into_iter().flatten() {
            reference.check_equivalence(other);
        }
    }
}