//! cubecl_runtime/tune/tuner.rs — per-device autotune runner and result cache.
1use alloc::format;
2use alloc::sync::Arc;
3use alloc::vec::Vec;
4use cubecl_common::profile::ProfileDuration;
5
6use core::time::Duration;
7
8use alloc::string::{String, ToString};
9use cubecl_common::benchmark::{BenchmarkComputations, BenchmarkDurations};
10
11use crate::config::{Logger, autotune::AutotuneLogLevel};
12use crate::server::LaunchError;
13use crate::tune::{AutotuneResult, TuneCache, tune_benchmark};
14use crate::{client::ComputeClient, runtime::Runtime};
15
16use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult, TuneInputs};
17
#[derive(Debug)]
/// Runs autotune benchmarks for a single device and caches the results.
///
/// On wasm, [`tune`](Self::tune) spawns its work on the browser event loop; elsewhere
/// it blocks inline. Either way the benchmarking itself is synchronous; only the
/// per-sample profile resolution is awaited.
pub struct Tuner<K: AutotuneKey> {
    /// Tuning results keyed by `K`; `Arc` so the wasm spawn path can share it
    /// with the async resolver.
    cache: Arc<spin::Mutex<TuneCache<K>>>,
    /// Autotune logger; the log level is re-read on every use, also shared
    /// with the async resolver.
    logger: Arc<spin::Mutex<Logger>>,
}
28
/// The measured outcome for a given autotune invocation.
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
#[derive(new, Debug, Clone, PartialEq, Eq)]
pub struct AutotuneOutcome {
    /// Human-readable tunable name, shown in `Display` output and logs.
    name: String,
    /// Position of the tunable in its set; committed to the cache as the
    /// fastest index when this outcome wins.
    index: usize,
    /// Aggregated benchmark statistics (median, etc.) over all samples.
    computation: BenchmarkComputations,
}
37
38impl core::fmt::Display for AutotuneOutcome {
39    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
40        write!(
41            f,
42            "Autotune[{}] name {} => {:?}",
43            self.index, self.name, self.computation
44        )
45    }
46}
47
/// Error from running autotune.
#[derive(Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum AutotuneError {
    /// An unknown error happened.
    Unknown {
        /// The name of the tunable.
        name: String,
        /// The unknown error.
        err: String,
    },
    /// All samples are invalid.
    InvalidSamples {
        /// The name of the tunable.
        name: String,
    },
    /// No autotune was flagged as valid for the problem.
    ///
    /// # Warning
    ///
    /// This is an unrecoverable error and will cause a panic.
    NoValidKernelFound {
        /// The formatted context on why no valid kernel was found.
        context: String,
    },
    /// The autotune is skipped manually.
    ///
    /// Also used as the initial placeholder for every tunable's result slot
    /// before benchmarking overwrites it.
    Skip {
        /// The name of the skipped kernel.
        name: String,
    },

    /// An error happened when launching a kernel.
    Launch(LaunchError),
}
82
83impl From<LaunchError> for AutotuneError {
84    fn from(value: LaunchError) -> Self {
85        Self::Launch(value)
86    }
87}
88
/// A successfully-queued benchmark: the profile futures for each sample, plus its metadata.
struct PendingBench {
    /// Index of the tunable in its set; also the slot overwritten in `results`.
    index: usize,
    /// Tunable name, used when reporting errors and outcomes.
    name: String,
    /// One profile per benchmark sample; resolved (awaited) in `process_request`.
    profiles: Vec<ProfileDuration>,
}
95
/// A queued tuning job: all data needed to resolve samples and commit the result.
/// Holds no references so it's trivially `Send + 'static` for the wasm spawn path.
struct TuneRequest<K: AutotuneKey> {
    /// Autotune key the final result is cached under.
    key: K,
    /// Per-tunable results, pre-filled with `Skip` errors and overwritten as
    /// launches fail or samples resolve.
    results: Vec<AutotuneResult>,
    /// Checksum of the tunable set, persisted alongside the result.
    #[cfg(std_io)]
    checksum: String,
    /// Accumulated plan logs when the autotune log level is `Full`.
    context_logs: Option<String>,
    /// Benchmarks whose launches succeeded and whose profiles await resolution.
    pending: Vec<PendingBench>,
}
106
#[allow(clippy::new_without_default)]
impl<K: AutotuneKey> Tuner<K> {
    /// Create a tuner. Its cache is seeded from the persistent on-disk cache when
    /// `std_io` is enabled.
    ///
    /// `name` and `device_id` identify the cache; see [`TuneCache::new`].
    pub fn new(name: &str, device_id: &str) -> Self {
        Self {
            cache: Arc::new(spin::Mutex::new(TuneCache::new(name, device_id))),
            logger: Arc::new(spin::Mutex::new(Logger::new())),
        }
    }

    /// Fetch the fastest autotune operation index for an autotune key.
    pub fn fastest(&self, key: &K) -> TuneCacheResult {
        self.cache.lock().fastest(key)
    }

    /// Check the cache, validate checksums if needed, and kick off a tuning job if the
    /// key is a miss. Returns the resolved cache state.
    ///
    /// # Arguments
    ///
    /// * `key` - Autotune key identifying the problem instance.
    /// * `inputs` - Real inputs, used to generate representative test inputs.
    /// * `tunables` - The set of candidate implementations to benchmark.
    /// * `checksum` - Lazily computes the tunable-set checksum (only used under `std_io`).
    /// * `client` - Compute client the benchmarks are launched on.
    pub fn check_tune<'a, R: Runtime, F: TuneInputs, Out: AutotuneOutput>(
        &self,
        key: &K,
        inputs: &F::At<'a>,
        tunables: &TunableSet<K, F, Out>,
        #[cfg_attr(not(std_io), allow(unused))] checksum: impl FnOnce() -> String + Send + Sync,
        client: &ComputeClient<R>,
    ) -> TuneCacheResult
    where
        <F as TuneInputs>::At<'a>: Clone + Send,
    {
        {
            let mut cache = self.cache.lock();
            let cur = cache.fastest(key);

            // Under `std_io`, an `Unchecked` entry came from the persistent cache;
            // validate its checksum before trusting it.
            #[cfg(std_io)]
            let cur = if matches!(cur, TuneCacheResult::Unchecked) {
                let mut log = self.logger.lock();
                let checksum = checksum();
                if let AutotuneLogLevel::Full = log.log_level_autotune() {
                    log.log_autotune(&format!("validate checksum key={key}, checksum={checksum}"));
                }
                cache.validate_checksum(key, &checksum)
            } else {
                cur
            };

            match cur {
                // A hit or an already-running tuning job: nothing to do.
                TuneCacheResult::Hit { .. } | TuneCacheResult::Pending => return cur,
                // Claim the key so concurrent callers see `Pending` while we tune.
                TuneCacheResult::Miss | TuneCacheResult::Unchecked => {
                    cache.mark_pending(key.clone())
                }
            }
            // Scope the guard: the rest of this function re-locks `self.cache` (fast
            // path insert, `process_request`), and `spin::Mutex` is non-reentrant.
        }

        log::info!("Tuning {key}");

        // Pre-fill every result slot with a `Skip` error so indices line up with
        // `autotunables`; launch errors and benchmark outcomes overwrite their slot.
        let autotunables = tunables.autotunables().collect::<Vec<_>>();
        let mut results: Vec<AutotuneResult> = autotunables
            .iter()
            .map(|a| {
                AutotuneResult::error(AutotuneError::Skip {
                    name: a.name.to_string(),
                })
            })
            .collect();

        #[cfg(std_io)]
        let checksum = tunables.compute_checksum();

        // Fast path: single tunable, no benchmarking needed.
        if results.len() == 1 {
            self.cache.lock().cache_insert(key.clone(), 0);
            return TuneCacheResult::Hit { fastest_index: 0 };
        }

        let test_inputs = tunables.generate_inputs(key, inputs);
        let mut plan = tunables.plan(key);
        let mut context_logs = match self.logger.lock().log_level_autotune() {
            AutotuneLogLevel::Full => Some(String::new()),
            _ => None,
        };

        // Walk the plan batch by batch, launching each benchmark synchronously. A
        // successful launch queues a `PendingBench` for the async resolver below;
        // launch errors go straight into `results`. Retry the next batch if a whole
        // batch failed to queue anything.
        let mut pending = Vec::<PendingBench>::new();
        loop {
            let tunable_indices = plan.next(context_logs.as_mut());

            // An empty batch means the plan is exhausted with nothing queued:
            // not a single tunable could be launched, which is unrecoverable.
            if tunable_indices.is_empty() {
                panic!(
                    "Can't execute the autotune plan for key: {key:?}\n - plan: {plan:?}\n - results: {results:?}"
                );
            }

            for index in tunable_indices {
                let op = autotunables[index];

                match tune_benchmark(op, test_inputs.clone(), client.clone()) {
                    Ok(profiles) => pending.push(PendingBench {
                        index,
                        name: op.name.clone(),
                        profiles,
                    }),
                    Err(err) => {
                        results[index] = AutotuneResult::error(err);
                    }
                }
            }

            if !pending.is_empty() {
                break;
            }
        }

        let request = TuneRequest {
            key: key.clone(),
            results,
            #[cfg(std_io)]
            checksum,
            context_logs,
            pending,
        };

        // Resolve samples and commit the result. On wasm this runs on the browser
        // event loop; elsewhere it blocks inline.
        #[cfg(target_family = "wasm")]
        {
            let cache = self.cache.clone();
            let logger = self.logger.clone();
            wasm_bindgen_futures::spawn_local(async move {
                process_request(request, &cache, &logger).await;
            });

            // Callers observe `Pending` until the spawned job commits the result.
            return TuneCacheResult::Pending;
        }

        #[cfg(not(target_family = "wasm"))]
        cubecl_common::future::block_on(process_request(request, &self.cache, &self.logger))
    }
}
250
251/// Await every profile sample, pick the fastest tunable, commit to the cache.
252async fn process_request<K: AutotuneKey>(
253    request: TuneRequest<K>,
254    cache: &spin::Mutex<TuneCache<K>>,
255    logger: &spin::Mutex<Logger>,
256) -> TuneCacheResult {
257    let TuneRequest {
258        key,
259        mut results,
260        #[cfg(std_io)]
261        checksum,
262        context_logs,
263        pending,
264    } = request;
265
266    for bench in pending {
267        let PendingBench {
268            index,
269            name,
270            profiles,
271        } = bench;
272
273        if profiles.is_empty() {
274            results[index] = AutotuneResult::error(AutotuneError::Unknown {
275                name: name.to_string(),
276                err: "No profiling available".to_string(),
277            });
278            continue;
279        }
280
281        let timing_method = profiles.first().unwrap().timing_method();
282        let mut durations = Vec::with_capacity(profiles.len());
283        for profile in profiles {
284            durations.push(profile.resolve().await.duration());
285        }
286
287        results[index] = AutotuneResult::success(AutotuneOutcome::new(
288            name.to_string(),
289            index,
290            BenchmarkComputations::new(&BenchmarkDurations::from_durations(
291                timing_method,
292                durations,
293            )),
294        ));
295    }
296
297    results.sort_by(|a, b| {
298        let a = a
299            .outcome
300            .as_ref()
301            .map(|r| r.computation.score())
302            .unwrap_or(u64::MAX);
303        let b = b
304            .outcome
305            .as_ref()
306            .map(|r| r.computation.score())
307            .unwrap_or(u64::MAX);
308        a.cmp(&b)
309    });
310
311    let fastest_index = results
312        .first()
313        .expect("At least one kernel needed.")
314        .outcome
315        .as_ref()
316        .expect("At least one kernel has to succeed.")
317        .index;
318
319    {
320        log_result(&mut logger.lock(), &key, &results, context_logs.as_deref());
321        cache.lock().cache_insert(key.clone(), fastest_index);
322        #[cfg(std_io)]
323        cache
324            .lock()
325            .persistent_cache_insert(key, checksum, fastest_index, results);
326    }
327
328    TuneCacheResult::Hit { fastest_index }
329}
330
331/// Emit the autotune result through the logger at the currently configured level.
332fn log_result<K: AutotuneKey>(
333    logger: &mut Logger,
334    key: &K,
335    results: &[AutotuneResult],
336    context_logs: Option<&str>,
337) {
338    match logger.log_level_autotune() {
339        AutotuneLogLevel::Minimal => {
340            let top_times = results
341                .iter()
342                .map(|r| {
343                    let time = r
344                        .outcome
345                        .as_ref()
346                        .map(|r| r.computation.median)
347                        .unwrap_or(Duration::MAX);
348
349                    let index = r.outcome.as_ref().map(|r| r.index).unwrap_or_default();
350                    (index, time)
351                })
352                .take(3)
353                .collect::<Vec<_>>();
354
355            let result = results
356                .first()
357                .expect("At least one kernel needed.")
358                .outcome
359                .as_ref()
360                .expect("At least one kernel has to succeed.");
361
362            let context = context_logs.unwrap_or("");
363            logger.log_autotune(&format!(
364                "Fastest result {}-{key}. \n Top 3 times: {top_times:?}, context: {context}",
365                result.name,
366            ));
367        }
368        AutotuneLogLevel::Full => {
369            let result = results
370                .first()
371                .expect("At least one kernel needed.")
372                .outcome
373                .as_ref()
374                .expect("At least one kernel has to succeed.");
375
376            let context = context_logs.unwrap_or("");
377            logger.log_autotune(&format!(
378                "Fastest result {}-{key}. Context: {context}",
379                result.name,
380            ));
381
382            for result in results.iter() {
383                match &result.outcome {
384                    Ok(val) => {
385                        logger.log_autotune(&format!("{val}"));
386                    }
387                    Err(err) => logger.log_autotune(&format!("{err:?}")),
388                }
389            }
390        }
391        AutotuneLogLevel::Disabled => {}
392    }
393}
394
/// Compare every tunable's output against a reference output for equivalence.
///
/// The reference is the *last* entry of `checks_outputs`. Errored outputs are
/// skipped, and an errored reference (or an empty list) makes the whole check
/// a no-op since there is nothing trustworthy to compare against.
#[cfg(feature = "autotune-checks")]
pub(crate) fn check_autotune_outputs<O: AutotuneOutput>(
    mut checks_outputs: Vec<Result<O, AutotuneError>>,
) {
    // `pop()` replaces the original `remove(len - 1)`: same element, O(1), and
    // it no longer panics (index underflow) when the vector is empty.
    if let Some(Ok(reference)) = checks_outputs.pop() {
        for other in checks_outputs.into_iter().flatten() {
            reference.check_equivalence(other);
        }
    }
}