1use alloc::format;
2use alloc::sync::Arc;
3use alloc::vec::Vec;
4use cubecl_common::profile::ProfileDuration;
5
6use core::time::Duration;
7
8use alloc::string::{String, ToString};
9use cubecl_common::benchmark::{BenchmarkComputations, BenchmarkDurations};
10
11use crate::config::{Logger, autotune::AutotuneLogLevel};
12use crate::server::LaunchError;
13use crate::tune::{AutotuneResult, TuneCache, tune_benchmark};
14use crate::{client::ComputeClient, runtime::Runtime};
15
16use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult, TuneInputs};
17
/// Coordinates autotuning for a set of kernels keyed by `K`.
///
/// Both fields are behind `Arc<spin::Mutex<_>>` so that, on wasm, results can
/// be committed from a spawned task without borrowing the tuner itself.
#[derive(Debug)]
pub struct Tuner<K: AutotuneKey> {
    /// Fastest-kernel-per-key store (also persisted to disk under `std_io`).
    cache: Arc<spin::Mutex<TuneCache<K>>>,
    /// Logger for autotune reporting; its level controls verbosity.
    logger: Arc<spin::Mutex<Logger>>,
}
28
/// Result of successfully benchmarking one tunable: its name, its position in
/// the tunable set, and the aggregated benchmark statistics.
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
#[derive(new, Debug, Clone, PartialEq, Eq)]
pub struct AutotuneOutcome {
    /// Human-readable kernel name, used in logs and error reports.
    name: String,
    /// Index of the tunable within the `TunableSet`.
    index: usize,
    /// Statistics (median, etc.) computed from the resolved profile durations.
    computation: BenchmarkComputations,
}
37
38impl core::fmt::Display for AutotuneOutcome {
39 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
40 write!(
41 f,
42 "Autotune[{}] name {} => {:?}",
43 self.index, self.name, self.computation
44 )
45 }
46}
47
/// Reasons a tunable could not produce a usable benchmark result.
#[derive(Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum AutotuneError {
    /// Benchmarking failed for an otherwise unclassified reason.
    Unknown {
        /// Name of the tunable that failed.
        name: String,
        /// Description of the failure.
        err: String,
    },
    /// The collected benchmark samples were unusable.
    InvalidSamples {
        /// Name of the tunable that produced invalid samples.
        name: String,
    },
    /// No kernel in the tunable set could be run for the given key.
    NoValidKernelFound {
        /// Extra information about what was attempted.
        context: String,
    },
    /// The tunable was not benchmarked; used as the placeholder state before
    /// results arrive (see `check_tune`, which pre-fills result slots with it).
    Skip {
        /// Name of the skipped tunable.
        name: String,
    },

    /// The kernel launch itself failed.
    Launch(LaunchError),
}
82
83impl From<LaunchError> for AutotuneError {
84 fn from(value: LaunchError) -> Self {
85 Self::Launch(value)
86 }
87}
88
/// A tunable whose kernel launches succeeded but whose profiles have not yet
/// been resolved into durations.
struct PendingBench {
    /// Index of the tunable within the autotunables list.
    index: usize,
    /// Kernel name, carried along for logging and error reporting.
    name: String,
    /// Unresolved profile handles, one per benchmark sample.
    profiles: Vec<ProfileDuration>,
}
95
/// Everything needed to finish a tuning run once the kernels have been
/// launched; consumed by `process_request`.
struct TuneRequest<K: AutotuneKey> {
    /// Key being tuned.
    key: K,
    /// One slot per tunable, pre-filled with `Skip`/launch errors.
    results: Vec<AutotuneResult>,
    /// Checksum of the tunable set, persisted with the winner (`std_io` only).
    #[cfg(std_io)]
    checksum: String,
    /// Accumulated plan logs when full autotune logging is enabled.
    context_logs: Option<String>,
    /// Benchmarks whose profiles still need to be resolved.
    pending: Vec<PendingBench>,
}
106
#[allow(clippy::new_without_default)]
impl<K: AutotuneKey> Tuner<K> {
    /// Create a tuner with an empty cache identified by `name` / `device_id`
    /// and a freshly configured logger.
    pub fn new(name: &str, device_id: &str) -> Self {
        Self {
            cache: Arc::new(spin::Mutex::new(TuneCache::new(name, device_id))),
            logger: Arc::new(spin::Mutex::new(Logger::new())),
        }
    }

    /// Fetch the fastest tunable recorded for `key`, if any.
    pub fn fastest(&self, key: &K) -> TuneCacheResult {
        self.cache.lock().fastest(key)
    }

    /// Resolve the fastest tunable for `key`, benchmarking all candidates on
    /// a cache miss.
    ///
    /// Fast path: a cache `Hit` or an already-`Pending` entry is returned
    /// immediately. Otherwise the key is marked pending and the tunables are
    /// benchmarked. On wasm the profiles are resolved from a spawned task and
    /// `Pending` is returned; everywhere else this blocks until the winner is
    /// committed and returns the resulting `Hit`.
    ///
    /// `checksum` is only evaluated with the `std_io` feature, to validate
    /// cache entries that are still `Unchecked`.
    ///
    /// # Panics
    /// Panics if the autotune plan yields no tunables to execute.
    pub fn check_tune<'a, R: Runtime, F: TuneInputs, Out: AutotuneOutput>(
        &self,
        key: &K,
        inputs: &F::At<'a>,
        tunables: &TunableSet<K, F, Out>,
        #[cfg_attr(not(std_io), allow(unused))] checksum: impl FnOnce() -> String + Send + Sync,
        client: &ComputeClient<R>,
    ) -> TuneCacheResult
    where
        <F as TuneInputs>::At<'a>: Clone + Send,
    {
        {
            // Scope the cache lock so it is released before benchmarking.
            let mut cache = self.cache.lock();
            let cur = cache.fastest(key);

            // With persistence enabled, an `Unchecked` entry must first match
            // the tunable set's checksum before it can be trusted.
            #[cfg(std_io)]
            let cur = if matches!(cur, TuneCacheResult::Unchecked) {
                let mut log = self.logger.lock();
                let checksum = checksum();
                if let AutotuneLogLevel::Full = log.log_level_autotune() {
                    log.log_autotune(&format!("validate checksum key={key}, checksum={checksum}"));
                }
                cache.validate_checksum(key, &checksum)
            } else {
                cur
            };

            match cur {
                TuneCacheResult::Hit { .. } | TuneCacheResult::Pending => return cur,
                TuneCacheResult::Miss | TuneCacheResult::Unchecked => {
                    // Claim the key so concurrent callers observe `Pending`.
                    cache.mark_pending(key.clone())
                }
            }
        }

        log::info!("Tuning {key}");

        // Pre-fill every result slot with a `Skip` error; benchmarks that run
        // overwrite their slot later.
        let autotunables = tunables.autotunables().collect::<Vec<_>>();
        let mut results: Vec<AutotuneResult> = autotunables
            .iter()
            .map(|a| {
                AutotuneResult::error(AutotuneError::Skip {
                    name: a.name.to_string(),
                })
            })
            .collect();

        #[cfg(std_io)]
        let checksum = tunables.compute_checksum();

        // Single candidate: nothing to compare, cache index 0 directly.
        if results.len() == 1 {
            self.cache.lock().cache_insert(key.clone(), 0);
            return TuneCacheResult::Hit { fastest_index: 0 };
        }

        let test_inputs = tunables.generate_inputs(key, inputs);
        let mut plan = tunables.plan(key);
        // Only accumulate per-run context logs when full logging is enabled.
        let mut context_logs = match self.logger.lock().log_level_autotune() {
            AutotuneLogLevel::Full => Some(String::new()),
            _ => None,
        };

        // Launch the plan step by step until at least one tunable produced
        // profiles; launch failures are recorded in their result slot.
        let mut pending = Vec::<PendingBench>::new();
        loop {
            let tunable_indices = plan.next(context_logs.as_mut());

            if tunable_indices.is_empty() {
                panic!(
                    "Can't execute the autotune plan for key: {key:?}\n - plan: {plan:?}\n - results: {results:?}"
                );
            }

            for index in tunable_indices {
                let op = autotunables[index];

                match tune_benchmark(op, test_inputs.clone(), client.clone()) {
                    Ok(profiles) => pending.push(PendingBench {
                        index,
                        name: op.name.clone(),
                        profiles,
                    }),
                    Err(err) => {
                        results[index] = AutotuneResult::error(err);
                    }
                }
            }

            if !pending.is_empty() {
                break;
            }
        }

        let request = TuneRequest {
            key: key.clone(),
            results,
            #[cfg(std_io)]
            checksum,
            context_logs,
            pending,
        };

        // On wasm we cannot block on profile resolution: commit the results
        // from a spawned task and report `Pending` for now.
        #[cfg(target_family = "wasm")]
        {
            let cache = self.cache.clone();
            let logger = self.logger.clone();
            wasm_bindgen_futures::spawn_local(async move {
                process_request(request, &cache, &logger).await;
            });

            return TuneCacheResult::Pending;
        }

        // Everywhere else, resolve the profiles synchronously.
        #[cfg(not(target_family = "wasm"))]
        cubecl_common::future::block_on(process_request(request, &self.cache, &self.logger))
    }
}
250
/// Resolve all pending benchmark profiles, pick the fastest tunable, log the
/// outcome, and commit the winner to the cache.
///
/// Always returns `TuneCacheResult::Hit` with the winning tunable's index.
///
/// # Panics
/// Panics if `results` is empty or no tunable succeeded.
async fn process_request<K: AutotuneKey>(
    request: TuneRequest<K>,
    cache: &spin::Mutex<TuneCache<K>>,
    logger: &spin::Mutex<Logger>,
) -> TuneCacheResult {
    let TuneRequest {
        key,
        mut results,
        #[cfg(std_io)]
        checksum,
        context_logs,
        pending,
    } = request;

    // Turn each pending benchmark's profiles into an `AutotuneOutcome`,
    // overwriting the placeholder error in its result slot.
    for bench in pending {
        let PendingBench {
            index,
            name,
            profiles,
        } = bench;

        if profiles.is_empty() {
            results[index] = AutotuneResult::error(AutotuneError::Unknown {
                name: name.to_string(),
                err: "No profiling available".to_string(),
            });
            continue;
        }

        // All samples of one bench share the timing method of the first.
        let timing_method = profiles.first().unwrap().timing_method();
        let mut durations = Vec::with_capacity(profiles.len());
        for profile in profiles {
            // Awaits the (possibly asynchronous) profile measurement.
            durations.push(profile.resolve().await.duration());
        }

        results[index] = AutotuneResult::success(AutotuneOutcome::new(
            name.to_string(),
            index,
            BenchmarkComputations::new(&BenchmarkDurations::from_durations(
                timing_method,
                durations,
            )),
        ));
    }

    // Sort ascending by score; failures score `u64::MAX` so they end up last
    // and the first entry is the fastest successful tunable.
    results.sort_by(|a, b| {
        let a = a
            .outcome
            .as_ref()
            .map(|r| r.computation.score())
            .unwrap_or(u64::MAX);
        let b = b
            .outcome
            .as_ref()
            .map(|r| r.computation.score())
            .unwrap_or(u64::MAX);
        a.cmp(&b)
    });

    let fastest_index = results
        .first()
        .expect("At least one kernel needed.")
        .outcome
        .as_ref()
        .expect("At least one kernel has to succeed.")
        .index;

    {
        log_result(&mut logger.lock(), &key, &results, context_logs.as_deref());
        cache.lock().cache_insert(key.clone(), fastest_index);
        // With persistence enabled, also write the outcome to disk.
        #[cfg(std_io)]
        cache
            .lock()
            .persistent_cache_insert(key, checksum, fastest_index, results);
    }

    TuneCacheResult::Hit { fastest_index }
}
330
331fn log_result<K: AutotuneKey>(
333 logger: &mut Logger,
334 key: &K,
335 results: &[AutotuneResult],
336 context_logs: Option<&str>,
337) {
338 match logger.log_level_autotune() {
339 AutotuneLogLevel::Minimal => {
340 let top_times = results
341 .iter()
342 .map(|r| {
343 let time = r
344 .outcome
345 .as_ref()
346 .map(|r| r.computation.median)
347 .unwrap_or(Duration::MAX);
348
349 let index = r.outcome.as_ref().map(|r| r.index).unwrap_or_default();
350 (index, time)
351 })
352 .take(3)
353 .collect::<Vec<_>>();
354
355 let result = results
356 .first()
357 .expect("At least one kernel needed.")
358 .outcome
359 .as_ref()
360 .expect("At least one kernel has to succeed.");
361
362 let context = context_logs.unwrap_or("");
363 logger.log_autotune(&format!(
364 "Fastest result {}-{key}. \n Top 3 times: {top_times:?}, context: {context}",
365 result.name,
366 ));
367 }
368 AutotuneLogLevel::Full => {
369 let result = results
370 .first()
371 .expect("At least one kernel needed.")
372 .outcome
373 .as_ref()
374 .expect("At least one kernel has to succeed.");
375
376 let context = context_logs.unwrap_or("");
377 logger.log_autotune(&format!(
378 "Fastest result {}-{key}. Context: {context}",
379 result.name,
380 ));
381
382 for result in results.iter() {
383 match &result.outcome {
384 Ok(val) => {
385 logger.log_autotune(&format!("{val}"));
386 }
387 Err(err) => logger.log_autotune(&format!("{err:?}")),
388 }
389 }
390 }
391 AutotuneLogLevel::Disabled => {}
392 }
393}
394
#[cfg(feature = "autotune-checks")]
/// Check every successful output against the reference output (the last
/// entry) via `check_equivalence`.
///
/// Failed outputs are skipped; if the reference itself failed, or the input
/// is empty, nothing is checked. The original indexed `remove(len - 1)` would
/// panic (index underflow) on an empty vector; `pop` makes that case a no-op.
pub(crate) fn check_autotune_outputs<O: AutotuneOutput>(
    mut checks_outputs: Vec<Result<O, AutotuneError>>,
) {
    // The last entry is the reference; an empty input has nothing to verify.
    let Some(reference) = checks_outputs.pop() else {
        return;
    };

    if let Ok(reference) = reference {
        for other in checks_outputs.into_iter().flatten() {
            reference.check_equivalence(other);
        }
    }
}
406}