1use alloc::boxed::Box;
2use alloc::format;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use async_channel::{Receiver, Sender};
6use cubecl_common::profile::ProfileDuration;
7use hashbrown::HashSet;
8
9use core::time::Duration;
10
11use alloc::string::{String, ToString};
12use cubecl_common::benchmark::{BenchmarkComputations, BenchmarkDurations};
13
14use crate::client::ComputeClient;
15use crate::config::{Logger, autotune::AutotuneLogLevel};
16use crate::server::ComputeServer;
17use crate::tune::{TuneBenchmark, TuneCache};
18
19use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult, TuneFn, TunePlan};
20
21#[derive(Debug)]
22pub struct Tuner<K: AutotuneKey> {
24 tune_cache: TuneCache<K>,
25 logger: Logger,
26 channel: (Sender<AutotuneMessage<K>>, Receiver<AutotuneMessage<K>>),
27 pub(crate) autotuning: HashSet<K>,
28}
29
30#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize, PartialEq, Eq))]
32#[derive(new, Debug, Clone)]
33pub struct AutotuneOutcome {
34 name: String,
35 index: usize,
36 computation: BenchmarkComputations,
37}
38
39impl core::fmt::Display for AutotuneOutcome {
40 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
41 write!(
42 f,
43 "Autotune[{}] name {} => {:?}",
44 self.index, self.name, self.computation
45 )
46 }
47}
48
49enum AutotuneMessage<K> {
50 Done {
51 key: K,
52 fastest_index: usize,
53 results: Vec<Result<AutotuneOutcome, AutotuneError>>,
54 #[cfg(std_io)]
55 checksum: String,
56 },
57 #[allow(dead_code)]
58 Pending(K),
59}
60
/// Reason why a tunable has no benchmark outcome.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum AutotuneError {
    /// Any failure described only by a message (see the `From<String>` conversion).
    Unknown(String),
    // NOTE(review): presumably reported when every collected sample is unusable;
    // this variant is not produced anywhere in this file.
    /// The collected samples are invalid.
    InvalidSamples,
    /// The tunable was not benchmarked; also the placeholder every slot starts with.
    Skip,
}
72
73impl From<String> for AutotuneError {
74 fn from(value: String) -> Self {
75 Self::Unknown(value)
76 }
77}
78
79#[allow(clippy::new_without_default)]
80impl<K: AutotuneKey> Tuner<K> {
81 pub fn new(name: &str, device_id: &str) -> Self {
83 let channel = async_channel::unbounded();
84
85 Self {
86 tune_cache: TuneCache::new(name, device_id),
87 logger: Logger::new(),
88 channel,
89 autotuning: HashSet::new(),
90 }
91 }
92
93 pub fn fastest(&self, key: &K) -> TuneCacheResult {
95 self.tune_cache.fastest(key)
96 }
97
98 #[cfg(std_io)]
100 pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
101 if let AutotuneLogLevel::Full = self.logger.log_level_autotune() {
102 self.logger
103 .log_autotune(&format!("validate checksum key={key}, checksum={checksum}"));
104 }
105 self.tune_cache.validate_checksum(key, checksum)
106 }
107
108 fn handle_result(&mut self, msg: AutotuneMessage<K>) {
110 match msg {
111 AutotuneMessage::Pending(key) => {
112 self.tune_cache.mark_pending(key);
113 }
114 AutotuneMessage::Done {
115 key,
116 fastest_index,
117 results,
118 #[cfg(std_io)]
119 checksum,
120 } => {
121 match self.logger.log_level_autotune() {
122 AutotuneLogLevel::Minimal => {
123 let top_times = results
124 .iter()
125 .map(|r| {
126 let time = r
127 .as_ref()
128 .map(|r| r.computation.median)
129 .unwrap_or(Duration::MAX);
130
131 let index = r.as_ref().map(|r| r.index).unwrap_or_default();
132 (index, time)
133 })
134 .take(3)
135 .collect::<Vec<_>>();
136
137 let result = results
138 .first()
139 .expect("At least one kernel needed.")
140 .as_ref()
141 .expect("At least one kernel has to succeed.");
142
143 self.logger.log_autotune(&format!(
144 "Fastest result {}-{key}. \n Top 3 times: {top_times:?}",
145 result.name,
146 ));
147 }
148 AutotuneLogLevel::Full => {
149 let result = results
150 .first()
151 .expect("At least one kernel needed.")
152 .as_ref()
153 .expect("At least one kernel has to succeed.");
154
155 self.logger
156 .log_autotune(&format!("Fastest result {}-{key}.", result.name,));
157
158 for result in results.iter() {
159 match result {
160 Ok(val) => {
161 self.logger.log_autotune(&format!("{val}"));
162 }
163 Err(err) => self.logger.log_autotune(&format!("{err:?}")),
164 }
165 }
166 }
167 AutotuneLogLevel::Disabled => {}
168 };
169
170 self.tune_cache.cache_insert(key.clone(), fastest_index);
171
172 #[cfg(std_io)]
173 {
174 self.tune_cache
175 .persistent_cache_insert(key, checksum, fastest_index, results);
176 }
177 }
178 }
179 }
180
181 pub fn handle_results(&mut self) {
183 while let Ok(msg) = self.channel.1.try_recv() {
186 self.handle_result(msg);
187 }
188 }
189
190 pub fn prepare_autotune<
192 S: ComputeServer + 'static,
193 In: Clone + Send + 'static,
194 Out: AutotuneOutput,
195 >(
196 &self,
197 key: K,
198 inputs: &In,
199 tunables: &TunableSet<K, In, Out>,
200 client: &ComputeClient<S>,
201 ) -> Box<dyn FnOnce()> {
202 log::info!("Tuning {key}");
203
204 let sender = self.channel.0.clone();
206
207 let autotunables = tunables.autotunables();
208 let mut results = Vec::with_capacity(autotunables.len());
209
210 for _ in 0..autotunables.len() {
211 results.push(Err(AutotuneError::Skip));
212 }
213
214 if autotunables.len() == 1 {
215 let message = AutotuneMessage::Done {
216 key,
217 fastest_index: 0,
218 results,
219 #[cfg(std_io)]
220 checksum: tunables.compute_checksum(),
221 };
222
223 return Box::new(move || {
224 sender
225 .try_send(message)
226 .expect("Loss message channel somehow")
227 });
228 }
229
230 let client = client.clone();
231 let key_cloned = key.clone();
232 let plan = tunables.plan(&key);
233 let inputs_generator = tunables.inputs_generator(&key.clone(), inputs);
234
235 #[cfg(std_io)]
236 let checksum = tunables.compute_checksum();
237
238 let fut_result = async move {
239 let test_inputs = inputs_generator();
240
241 Self::generate_tune_message(
242 key_cloned,
243 &client,
244 plan,
245 autotunables,
246 test_inputs,
247 results,
248 #[cfg(std_io)]
249 checksum,
250 )
251 .await
252 };
253
254 Box::new(move || {
255 let message = {
256 cfg_if::cfg_if! {
257 if #[cfg(target_family = "wasm")] {
258 let sender = sender.clone();
259
260 let send_fut = async move {
261 let _ = sender.send(fut_result.await).await;
264 };
265 wasm_bindgen_futures::spawn_local(send_fut);
267 AutotuneMessage::Pending(key)
269 } else {
270 cubecl_common::future::block_on(fut_result)
271 }
272 }
273 };
274
275 sender
277 .try_send(message)
278 .expect("Loss message channel somehow");
279 })
280 }
281
282 async fn generate_tune_message<
283 In: Clone + Send + 'static,
284 Out: AutotuneOutput,
285 S: ComputeServer + 'static,
286 >(
287 key: K,
288 client: &ComputeClient<S>,
289 mut plan: TunePlan,
290 autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
291 test_inputs: In,
292 mut results: Vec<Result<AutotuneOutcome, AutotuneError>>,
293 #[cfg(std_io)] checksum: String,
294 ) -> AutotuneMessage<K> {
295 Self::execute_tune_plan(client, &mut plan, autotunables, &test_inputs, &mut results).await;
296
297 results.sort_by(|a, b| {
299 let a = a
300 .as_ref()
301 .map(|r| r.computation.median)
302 .unwrap_or(Duration::MAX);
303 let b = b
304 .as_ref()
305 .map(|r| r.computation.median)
306 .unwrap_or(Duration::MAX);
307
308 a.cmp(&b)
309 });
310
311 let result = results
313 .first()
314 .expect("At least one kernel needed.")
315 .as_ref()
316 .expect("At least one kernel has to succeed.");
317
318 AutotuneMessage::Done {
319 key,
320 fastest_index: result.index,
321 results,
322 #[cfg(std_io)]
323 checksum,
324 }
325 }
326
327 async fn execute_tune_plan<
328 In: Clone + Send + 'static,
329 Out: AutotuneOutput,
330 S: ComputeServer + 'static,
331 >(
332 client: &ComputeClient<S>,
333 plan: &mut TunePlan,
334 autotunables: Vec<Arc<dyn TuneFn<Inputs = In, Output = Out> + 'static>>,
335 test_inputs: &In,
336 results: &mut [Result<AutotuneOutcome, AutotuneError>],
337 ) {
338 loop {
339 let mut num_autotuned = 0;
340
341 let tunable_indices = plan.next();
342
343 if tunable_indices.is_empty() {
344 panic!("No autotune was flagged as valid for the problem.")
345 }
346
347 for index in tunable_indices {
348 let op = &autotunables[index];
349 let name = op.name().to_string();
350 let tuner = TuneBenchmark::new(op.clone(), test_inputs.clone(), client.clone());
351 let profiles = tuner.profile().map(|bench| (name, index, bench));
352
353 match profiles {
354 Ok(result) => {
355 let (name, index, profiles) = result;
357 let result = Self::process_autotune(name, index, profiles).await;
358 match result {
359 Ok(val) => {
360 results[index] = Ok(val);
361 num_autotuned += 1;
362 }
363 Err(err) => {
364 results[index] = Err(err);
365 }
366 }
367 }
368 Err(err) => {
369 results[index] = Err(err);
370 }
371 }
372 }
373
374 if num_autotuned > 0 {
375 break;
376 }
377 }
378 }
379
380 async fn process_autotune(
381 name: String,
382 index: usize,
383 profiles: Vec<ProfileDuration>,
384 ) -> Result<AutotuneOutcome, AutotuneError> {
385 let mut durations = Vec::new();
386 if !profiles.is_empty() {
387 let timing_method = profiles.first().unwrap().timing_method();
388 for profile in profiles {
389 durations.push(profile.resolve().await.duration());
390 }
391 let bench_durations = BenchmarkDurations::from_durations(timing_method, durations);
392
393 Ok(AutotuneOutcome::new(
394 name,
395 index,
396 BenchmarkComputations::new(&bench_durations),
397 ))
398 } else {
399 Err(AutotuneError::Unknown(format!(
400 "Runtime error while profiling {name}."
401 )))
402 }
403 }
404}
405
/// Checks that the outputs produced by the tunables are equivalent to each other.
///
/// The last entry of `checks_outputs` is the reference. When it is `Ok`, every other
/// successful output is compared against it with `check_equivalence`; failed outputs
/// are ignored. Nothing is checked when the list is empty or the reference is an error.
#[cfg(feature = "autotune-checks")]
pub(crate) fn check_autotune_outputs<O: AutotuneOutput>(
    mut checks_outputs: Vec<Result<O, AutotuneError>>,
) {
    // `pop` returns `None` for an empty list, where `remove(len - 1)` would underflow.
    if let Some(Ok(reference)) = checks_outputs.pop() {
        for other in checks_outputs.into_iter().flatten() {
            reference.check_equivalence(other);
        }
    }
}