1use alloc::boxed::Box;
2use alloc::format;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use async_channel::{Receiver, Sender};
6use cubecl_common::profile::ProfileDuration;
7use hashbrown::HashSet;
8
9use core::time::Duration;
10
11use alloc::string::{String, ToString};
12use cubecl_common::benchmark::{BenchmarkComputations, BenchmarkDurations};
13
14use crate::config::{Logger, autotune::AutotuneLogLevel};
15use crate::server::LaunchError;
16use crate::tune::{AutotuneResult, TuneBenchmark, TuneCache};
17use crate::{client::ComputeClient, runtime::Runtime};
18
19use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult, TuneFn, TunePlan};
20
/// Autotuner that benchmarks tunables for a key and caches the fastest one.
///
/// Tuning closures built by `prepare_autotune` report back through `channel`;
/// the messages are only applied to `tune_cache` when `handle_results` drains it.
#[derive(Debug)]
pub struct Tuner<K: AutotuneKey> {
    /// Per-key cache of the fastest tunable index (also written to the persistent
    /// cache when the `std_io` cfg is enabled).
    tune_cache: TuneCache<K>,
    /// Autotune logger; its log level controls how much is reported per tuning run.
    logger: Logger,
    /// Unbounded message channel: the sender is cloned into tuning closures, the
    /// receiver is drained by `handle_results`.
    channel: (Sender<AutotuneMessage<K>>, Receiver<AutotuneMessage<K>>),
    /// Set of keys tracked by the crate while autotuning; not read or written in the
    /// code visible here — NOTE(review): confirm the invariant with its users.
    pub(crate) autotuning: HashSet<K>,
}
29
/// Result of successfully benchmarking a single tunable.
///
/// Field order defines the argument order of the derived
/// `AutotuneOutcome::new(name, index, computation)` constructor.
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
#[derive(new, Debug, Clone, PartialEq, Eq)]
pub struct AutotuneOutcome {
    /// Name of the benchmarked tunable.
    name: String,
    /// Position of the tunable in its tunable set.
    index: usize,
    /// Benchmark statistics; the tuner ranks candidates by `score()` (lowest first)
    /// and logs `median`.
    computation: BenchmarkComputations,
}
38
39impl core::fmt::Display for AutotuneOutcome {
40 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
41 write!(
42 f,
43 "Autotune[{}] name {} => {:?}",
44 self.index, self.name, self.computation
45 )
46 }
47}
48
/// Message sent from a tuning closure back to the [`Tuner`] over its channel.
enum AutotuneMessage<K> {
    /// Tuning for `key` completed.
    Done {
        key: K,
        /// Index (within the tunable set) of the selected, fastest tunable.
        fastest_index: usize,
        /// One entry per tunable; sorted fastest-first when benchmarks were run,
        /// or only `Skip` placeholders when the set had a single tunable.
        results: Vec<AutotuneResult>,
        /// Checksum of the tunable set, stored with the persistent cache entry.
        #[cfg(std_io)]
        checksum: String,
        /// Tune-plan log, only collected at the `Full` autotune log level.
        context_logs: Option<String>,
    },
    /// Tuning for the key has started but has not finished yet. In the code visible
    /// here it is only constructed on wasm targets, hence the `dead_code` allowance.
    #[allow(dead_code)]
    Pending(K),
}
61
/// Errors that can occur while autotuning a key.
#[derive(Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum AutotuneError {
    /// A tunable failed for a reason described only by text.
    Unknown {
        /// Name of the failing tunable.
        name: String,
        /// Description of the failure.
        err: String,
    },
    /// Benchmark samples for a tunable were unusable — not constructed in this
    /// file; NOTE(review): confirm exact trigger at its construction site.
    InvalidSamples {
        /// Name of the affected tunable.
        name: String,
    },
    /// The tune plan ran out of candidates without any successful benchmark.
    NoValidKernelFound {
        /// `Debug` dump of the tune plan and per-tunable results at failure time.
        context: String,
    },
    /// Placeholder result for a tunable that was not benchmarked.
    Skip {
        /// Name of the skipped tunable.
        name: String,
    },

    /// Wrapped kernel launch error.
    Launch(LaunchError),
}
96
97impl From<LaunchError> for AutotuneError {
98 fn from(value: LaunchError) -> Self {
99 Self::Launch(value)
100 }
101}
102
103#[allow(clippy::new_without_default)]
104impl<K: AutotuneKey> Tuner<K> {
105 pub fn new(name: &str, device_id: &str) -> Self {
107 let channel = async_channel::unbounded();
108
109 Self {
110 tune_cache: TuneCache::new(name, device_id),
111 logger: Logger::new(),
112 channel,
113 autotuning: HashSet::new(),
114 }
115 }
116
117 pub fn fastest(&self, key: &K) -> TuneCacheResult {
119 self.tune_cache.fastest(key)
120 }
121
122 #[cfg(std_io)]
124 pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
125 if let AutotuneLogLevel::Full = self.logger.log_level_autotune() {
126 self.logger
127 .log_autotune(&format!("validate checksum key={key}, checksum={checksum}"));
128 }
129 self.tune_cache.validate_checksum(key, checksum)
130 }
131
132 fn handle_result(&mut self, msg: AutotuneMessage<K>) {
134 match msg {
135 AutotuneMessage::Pending(key) => {
136 self.tune_cache.mark_pending(key);
137 }
138 AutotuneMessage::Done {
139 key,
140 fastest_index,
141 results,
142 #[cfg(std_io)]
143 checksum,
144 context_logs,
145 } => {
146 match self.logger.log_level_autotune() {
147 AutotuneLogLevel::Minimal => {
148 let top_times = results
149 .iter()
150 .map(|r| {
151 let time = r
152 .outcome
153 .as_ref()
154 .map(|r| r.computation.median)
155 .unwrap_or(Duration::MAX);
156
157 let index = r.outcome.as_ref().map(|r| r.index).unwrap_or_default();
158 (index, time)
159 })
160 .take(3)
161 .collect::<Vec<_>>();
162
163 let result = results
164 .first()
165 .expect("At least one kernel needed.")
166 .outcome
167 .as_ref()
168 .expect("At least one kernel has to succeed.");
169
170 let context = match &context_logs {
171 Some(context) => context,
172 None => "",
173 };
174 self.logger.log_autotune(&format!(
175 "Fastest result {}-{key}. \n Top 3 times: {top_times:?}, context: {context}",
176 result.name,
177 ));
178 }
179 AutotuneLogLevel::Full => {
180 let result = results
181 .first()
182 .expect("At least one kernel needed.")
183 .outcome
184 .as_ref()
185 .expect("At least one kernel has to succeed.");
186
187 let context = match &context_logs {
188 Some(context) => context,
189 None => "",
190 };
191 self.logger.log_autotune(&format!(
192 "Fastest result {}-{key}. Context: {context}",
193 result.name,
194 ));
195
196 for result in results.iter() {
197 match &result.outcome {
198 Ok(val) => {
199 self.logger.log_autotune(&format!("{val}"));
200 }
201 Err(err) => self.logger.log_autotune(&format!("{err:?}")),
202 }
203 }
204 }
205 AutotuneLogLevel::Disabled => {}
206 };
207
208 self.tune_cache.cache_insert(key.clone(), fastest_index);
209
210 #[cfg(std_io)]
211 {
212 self.tune_cache
213 .persistent_cache_insert(key, checksum, fastest_index, results);
214 }
215 }
216 }
217 }
218
219 pub fn handle_results(&mut self) {
221 while let Ok(msg) = self.channel.1.try_recv() {
224 self.handle_result(msg);
225 }
226 }
227
228 pub fn prepare_autotune<R: Runtime, In: Clone + Send + 'static, Out: AutotuneOutput>(
230 &self,
231 key: K,
232 inputs: &In,
233 tunables: &TunableSet<K, In, Out>,
234 client: &ComputeClient<R>,
235 ) -> Box<dyn FnOnce()> {
236 log::info!("Tuning {key}");
237
238 let sender = self.channel.0.clone();
240
241 let autotunables = tunables.autotunables();
242 let mut results: Vec<AutotuneResult> = Vec::with_capacity(autotunables.len());
243
244 for a in autotunables.iter() {
245 results.push(AutotuneResult::error(AutotuneError::Skip {
246 name: a.name().to_string(),
247 }));
248 }
249
250 if autotunables.len() == 1 {
251 let message = AutotuneMessage::Done {
252 key,
253 fastest_index: 0,
254 results,
255 #[cfg(std_io)]
256 checksum: tunables.compute_checksum(),
257 context_logs: None,
258 };
259
260 return Box::new(move || {
261 sender
262 .try_send(message)
263 .expect("Loss message channel somehow")
264 });
265 }
266
267 let client = client.clone();
268 let key_cloned = key.clone();
269 let plan = tunables.plan(&key);
270 let inputs_generator = tunables.inputs_generator(&key.clone(), inputs);
271
272 #[cfg(std_io)]
273 let checksum = tunables.compute_checksum();
274 let context_logs = match self.logger.log_level_autotune() {
275 AutotuneLogLevel::Disabled => false,
276 AutotuneLogLevel::Minimal => false,
277 AutotuneLogLevel::Full => true,
278 };
279
280 let fut_result = async move {
281 let test_inputs = inputs_generator();
282
283 Self::generate_tune_message(
284 key_cloned,
285 &client,
286 plan,
287 autotunables,
288 test_inputs,
289 results,
290 #[cfg(std_io)]
291 checksum,
292 context_logs,
293 )
294 .await
295 };
296
297 Box::new(move || {
298 let message = {
299 cfg_if::cfg_if! {
300 if #[cfg(target_family = "wasm")] {
301 let sender = sender.clone();
302
303 let send_fut = async move {
304 let _ = sender.send(fut_result.await).await;
307 };
308 wasm_bindgen_futures::spawn_local(send_fut);
310 AutotuneMessage::Pending(key)
312 } else {
313 cubecl_common::future::block_on(fut_result)
314 }
315 }
316 };
317
318 sender
320 .try_send(message)
321 .expect("Loss message channel somehow");
322 })
323 }
324
325 #[allow(clippy::too_many_arguments)]
326 async fn generate_tune_message<In: Clone + Send + 'static, Out: AutotuneOutput, R: Runtime>(
327 key: K,
328 client: &ComputeClient<R>,
329 mut plan: TunePlan,
330 autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
331 test_inputs: In,
332 mut results: Vec<AutotuneResult>,
333 #[cfg(std_io)] checksum: String,
334 context_logs: bool,
335 ) -> AutotuneMessage<K> {
336 let context_logs = match Self::execute_tune_plan(
337 client,
338 &mut plan,
339 autotunables,
340 &test_inputs,
341 &mut results,
342 context_logs,
343 )
344 .await
345 {
346 Ok(context_logs) => context_logs,
347 Err(err) => {
348 panic!("Can't execute the autotune plan for key: {key:?}\n - Error: {err:?}");
349 }
350 };
351
352 results.sort_by(|a, b| {
354 let a = a
355 .outcome
356 .as_ref()
357 .map(|r| r.computation.score())
358 .unwrap_or(u64::MAX);
359 let b = b
360 .outcome
361 .as_ref()
362 .map(|r| r.computation.score())
363 .unwrap_or(u64::MAX);
364
365 a.cmp(&b)
366 });
367
368 let result = results
370 .first()
371 .expect("At least one kernel needed.")
372 .outcome
373 .as_ref()
374 .expect("At least one kernel has to succeed.");
375
376 AutotuneMessage::Done {
377 key,
378 fastest_index: result.index,
379 results,
380 #[cfg(std_io)]
381 checksum,
382 context_logs,
383 }
384 }
385
386 async fn execute_tune_plan<In: Clone + Send + 'static, Out: AutotuneOutput, R: Runtime>(
387 client: &ComputeClient<R>,
388 plan: &mut TunePlan,
389 autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
390 test_inputs: &In,
391 results: &mut [AutotuneResult],
392 context_logs: bool,
393 ) -> Result<Option<String>, AutotuneError> {
394 #[derive(Debug)]
395 #[allow(unused_variables, dead_code)] struct Context<'a> {
397 plan: &'a TunePlan,
398 results: &'a [AutotuneResult],
399 }
400
401 let mut context_logs = match context_logs {
402 true => Some("".to_string()),
403 false => None,
404 };
405
406 loop {
407 let mut num_success = 0;
408 let tunable_indices = plan.next(context_logs.as_mut());
409
410 if tunable_indices.is_empty() {
411 return Err(AutotuneError::NoValidKernelFound {
412 context: format!("{:?}", &Context { plan, results }),
413 });
414 }
415
416 for index in tunable_indices {
417 let op = &autotunables[index];
418 let name = op.name().to_string();
419 let tuner = TuneBenchmark::new(op.clone(), test_inputs.clone(), client.clone());
420 let profiles = tuner.profile().map(|bench| (name, index, bench));
421
422 match profiles {
423 Ok(result) => {
424 let (name, index, profiles) = result;
426 let result = Self::process_autotune(name, index, profiles).await;
427 match result {
428 Ok(val) => {
429 results[index] = AutotuneResult::success(val);
430 num_success += 1;
431 }
432 Err(err) => {
433 results[index] = AutotuneResult::error(err);
434 }
435 }
436 }
437 Err(err) => {
438 results[index] = AutotuneResult::error(err);
439 }
440 }
441 }
442
443 if num_success > 0 {
444 break;
445 }
446 }
447
448 Ok(context_logs)
449 }
450
451 async fn process_autotune(
452 name: String,
453 index: usize,
454 profiles: Vec<ProfileDuration>,
455 ) -> Result<AutotuneOutcome, AutotuneError> {
456 let mut durations = Vec::new();
457 if !profiles.is_empty() {
458 let timing_method = profiles.first().unwrap().timing_method();
459 for profile in profiles {
460 durations.push(profile.resolve().await.duration());
461 }
462 let bench_durations = BenchmarkDurations::from_durations(timing_method, durations);
463
464 Ok(AutotuneOutcome::new(
465 name,
466 index,
467 BenchmarkComputations::new(&bench_durations),
468 ))
469 } else {
470 Err(AutotuneError::Unknown {
471 name,
472 err: "No profiling available".to_string(),
473 })
474 }
475 }
476}
477
/// Cross-checks the outputs of autotune candidates (only with the
/// `autotune-checks` feature).
///
/// The last entry of `checks_outputs` is the reference; every other successful
/// output is passed to the reference's `check_equivalence`. Failed entries are
/// skipped, and nothing is checked when the list is empty or the reference failed.
#[cfg(feature = "autotune-checks")]
pub(crate) fn check_autotune_outputs<O: AutotuneOutput>(
    mut checks_outputs: Vec<Result<O, AutotuneError>>,
) {
    // `pop` instead of `remove(len - 1)`: an empty list used to underflow and panic.
    let Some(Ok(reference)) = checks_outputs.pop() else {
        return;
    };

    for other in checks_outputs.into_iter().flatten() {
        reference.check_equivalence(other);
    }
}
489}