// Source: cubecl_runtime/tune/base.rs

1use super::{AutotuneError, AutotuneKey, TuneFn, TuneInputs};
2use alloc::boxed::Box;
3use alloc::string::ToString;
4use alloc::{format, string::String, sync::Arc, vec, vec::Vec};
5use core::sync::atomic::{AtomicU32, Ordering};
6use hashbrown::HashMap;
7
8/// A single candidate for autotune: a named [`TuneFn`] plus the [groups](TuneGroup) it
9/// belongs to. A tunable is autotuned whenever any of its groups is prioritized.
10pub struct Tunable<K, F: TuneInputs, Output> {
11    pub(crate) function: TuneFn<F, Output>,
12    groups: Vec<(TuneGroup<K>, PriorityFunc<K>)>,
13}
14
15impl<K, F: TuneInputs, Output: 'static> Tunable<K, F, Output> {
16    /// Create a tunable from a closure.
17    ///
18    /// The `for<'a> Fn(F::At<'a>) -> _` bound is spelled out directly in the
19    /// `where`-clause (rather than hidden behind a helper trait) so that Rust closure
20    /// inference sees it: otherwise `move |input| …` picks a single concrete lifetime
21    /// and fails with `implementation of FnOnce is not general enough` whenever
22    /// `F::At<'a>` actually depends on `'a`.
23    ///
24    /// For multi-input kernels, destructure a tuple:
25    /// `Tunable::new("name", |(lhs, rhs, out)| body)`.
26    pub fn new<Func, Err>(name: &str, func: Func) -> Self
27    where
28        Err: Into<String> + 'static,
29        Func: for<'a> Fn(<F as TuneInputs>::At<'a>) -> Result<Output, Err> + Send + Sync + 'static,
30    {
31        let name: String = name.into();
32        let name_for_err = name.clone();
33        Self {
34            function: TuneFn::new(
35                name,
36                Box::new(move |inputs| {
37                    func(inputs).map_err(|err| AutotuneError::Unknown {
38                        name: name_for_err.to_string(),
39                        err: err.into(),
40                    })
41                }),
42            ),
43            groups: Vec::new(),
44        }
45    }
46
47    /// Add this tunable to a [`TuneGroup`] with the given intra-group priority.
48    ///
49    /// Groups are autotuned in order of their priority; within each group, tunables are
50    /// tried in order of `priority(key)`. A negative priority skips the tunable for this
51    /// key.
52    pub fn group(
53        mut self,
54        group: &TuneGroup<K>,
55        priority: impl Fn(&K) -> i8 + Send + Sync + 'static,
56    ) -> Self {
57        self.groups.push((group.clone(), Arc::new(priority)));
58        self
59    }
60}
61
/// A priority bucket for tunables, computed from the [autotune key](AutotuneKey).
///
/// Higher-priority groups are autotuned first; once any tunable in a group returns a
/// valid result, no later groups are tried.
pub struct TuneGroup<K> {
    /// Identifier taken from a process-wide counter when the group is created; clones
    /// of a group share the same id.
    id: u32,
    /// Human-readable name given at construction; shared between clones and attached
    /// to every tunable entry in the plan.
    name: Arc<String>,
    /// Computes the priority of the whole group for a given key.
    pub(crate) priority: PriorityFunc<K>,
}
71
72impl<K> core::fmt::Debug for TuneGroup<K> {
73    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
74        f.debug_struct("TuneGroup").field("id", &self.id).finish()
75    }
76}
77
78impl<K> Clone for TuneGroup<K> {
79    fn clone(&self) -> Self {
80        Self {
81            id: self.id,
82            name: self.name.clone(),
83            priority: self.priority.clone(),
84        }
85    }
86}
87
88impl<K> TuneGroup<K> {
89    /// Create a new group based on a priority function.
90    pub fn new(name: &str, f: impl Fn(&K) -> i8 + Send + Sync + 'static) -> Self {
91        let id = GROUP_COUNTER.fetch_add(1, Ordering::Relaxed);
92
93        Self {
94            id,
95            name: Arc::new(name.into()),
96            priority: Arc::new(f),
97        }
98    }
99}
100
#[derive(Debug)]
/// A group plan dictates which [tunables](Tunable) should be executed, and in what order.
pub(crate) struct TunePlan {
    /// Group priorities that still have pending work, sorted ascending so that the
    /// highest priority sits at the end.
    priorities: Vec<i8>,
    /// Indices of the tunables that belong to no group; handed out by the first call
    /// to `next`.
    no_groups: Vec<usize>,
    /// Pending work per group priority. Groups that evaluate to the same priority
    /// share one entry.
    groups: HashMap<i8, GroupPlan>,
    /// Indices that were already handed out by `next`, used to avoid returning a
    /// tunable twice.
    returned: Vec<usize>,
}
109
/// Pending work for all groups that share one group priority.
#[derive(Default, Debug)]
struct GroupPlan {
    /// Intra-group priorities that still have entries, sorted ascending so that the
    /// highest one can be popped from the end.
    priorities: Vec<i8>,
    /// For each intra-group priority, the tunables to try: their index in the tunable
    /// list and the name of the group they were registered through.
    indices: HashMap<i8, Vec<(usize, Arc<String>)>>,
}
115
/// Bookkeeping produced while taking a batch out of the plan: what became empty and
/// must be removed before the next batch is taken.
#[derive(Debug)]
struct Cleanup {
    /// Group priorities whose plan has no entries left.
    groups: Vec<i8>,
    /// `(group priority, intra-group priority)` pairs whose entry list became empty.
    tunables: Vec<(i8, i8)>,
    /// Within group priority is too low to even try.
    skipped: bool,
}
123
124impl TunePlan {
125    pub fn new<K: AutotuneKey, F: TuneInputs, Out>(
126        key: &K,
127        tunables: &[Tunable<K, F, Out>],
128    ) -> Self {
129        let mut priorities = Vec::<i8>::new();
130        let mut no_groups = Vec::new();
131        let mut groups = HashMap::<i8, GroupPlan>::new();
132
133        for (index, tunable) in tunables.iter().enumerate() {
134            if tunable.groups.is_empty() {
135                no_groups.push(index);
136            } else {
137                for (group, within_group_priority_fn) in tunable.groups.iter() {
138                    let priority_fn = &group.priority;
139                    let priority = priority_fn(key);
140                    if !priorities.contains(&priority) {
141                        priorities.push(priority);
142                    }
143
144                    let group_priorities = match groups.get_mut(&priority) {
145                        Some(val) => val,
146                        None => {
147                            groups.insert(priority, GroupPlan::default());
148                            groups.get_mut(&priority).unwrap()
149                        }
150                    };
151                    let priority = within_group_priority_fn(key);
152
153                    if group_priorities.priorities.contains(&priority) {
154                        group_priorities
155                            .indices
156                            .get_mut(&priority)
157                            .unwrap()
158                            .push((index, group.name.clone()));
159                    } else {
160                        group_priorities.priorities.push(priority);
161                        group_priorities
162                            .indices
163                            .insert(priority, vec![(index, group.name.clone())]);
164                    }
165                }
166            }
167        }
168
169        priorities.sort();
170
171        for group in groups.iter_mut() {
172            group.1.priorities.sort();
173        }
174
175        Self {
176            priorities,
177            no_groups,
178            groups,
179            returned: Vec::new(),
180        }
181    }
182
183    /// Get the next batch of [tunable](Tunable) index to be autotuned.
184    ///
185    /// Note that if the list is empty, it means no more autotuned entry can be executed.
186    pub(crate) fn next(&mut self, mut context_logs: Option<&mut String>) -> Vec<usize> {
187        let mut indices = core::mem::take(&mut self.no_groups);
188        let priority = self.priorities.last();
189
190        let priority = match priority {
191            Some(val) => *val,
192            None => return indices,
193        };
194
195        let (group_indices, cleanup) = self.group_plan_next(priority);
196        // Some entries are skipped for this round of prioritizing.
197        let skipped = cleanup.skipped || priority < 0;
198        let mut all_skip = true;
199
200        self.cleanup(cleanup);
201
202        if priority >= 0 {
203            if let Some(ctx) = context_logs.take() {
204                *ctx += format!("\n - Tuning: {group_indices:?}").as_str();
205                context_logs = Some(ctx);
206            }
207            for (index, _name) in group_indices {
208                if !self.returned.contains(&index) {
209                    all_skip = false;
210                    indices.push(index);
211                }
212            }
213        }
214
215        // The indices list is empty, but it doesn't mean we should stop
216        // autotuning, since some entries were skipped.
217
218        if indices.is_empty() && (skipped || all_skip) {
219            self.next(context_logs)
220        } else {
221            for i in indices.iter() {
222                self.returned.push(*i);
223            }
224            indices
225        }
226    }
227
228    fn cleanup(&mut self, cleanup: Cleanup) {
229        for group_p in cleanup.groups {
230            let index = self
231                .priorities
232                .iter()
233                .enumerate()
234                .find(|p| *p.1 == group_p)
235                .unwrap();
236
237            self.priorities.remove(index.0);
238            self.groups.remove(&group_p);
239        }
240
241        for (group_p, tunable_p) in cleanup.tunables {
242            if let Some(group) = self.groups.get_mut(&group_p) {
243                let index = group
244                    .priorities
245                    .iter()
246                    .enumerate()
247                    .find(|p| *p.1 == tunable_p)
248                    .unwrap();
249                group.priorities.remove(index.0);
250                group.indices.remove(&tunable_p);
251            }
252        }
253    }
254
255    fn group_plan_next(&mut self, priority: i8) -> (Vec<(usize, Arc<String>)>, Cleanup) {
256        let group_plan = self.groups.get_mut(&priority).expect("To be filled");
257        let within_group_prio = group_plan.priorities.pop().unwrap();
258        let mut next_indices = group_plan.indices.remove(&within_group_prio).unwrap();
259
260        let mut cleanup_groups = Vec::new();
261        let mut cleanup_tunables = Vec::new();
262
263        for (pg, group) in self.groups.iter_mut() {
264            let mut num_empty_tunables = 0;
265            let num_tunables = group.priorities.len();
266
267            for (pt, indices) in group.indices.iter_mut() {
268                for n in &next_indices {
269                    let entry = indices.iter().enumerate().find(|p| *p.1 == *n);
270                    if let Some(entry) = entry {
271                        indices.remove(entry.0);
272                    }
273                }
274
275                if indices.is_empty() {
276                    num_empty_tunables += 1;
277                    cleanup_tunables.push((*pg, *pt));
278                }
279            }
280
281            if num_empty_tunables == num_tunables {
282                cleanup_groups.push(*pg);
283            }
284        }
285
286        if within_group_prio < 0 {
287            // Discard algorithms with negative priority
288            next_indices.clear();
289        }
290
291        (
292            next_indices,
293            Cleanup {
294                groups: cleanup_groups,
295                tunables: cleanup_tunables,
296                skipped: within_group_prio < 0,
297            },
298        )
299    }
300}
301
/// Shared, thread-safe function that maps an autotune key to a priority.
///
/// Used both for the priority of a whole group and for the priority of a tunable
/// inside a group.
type PriorityFunc<K> = Arc<dyn Fn(&K) -> i8 + Send + Sync>;
303
/// Process-wide counter that hands out the `id` of each newly created [`TuneGroup`].
static GROUP_COUNTER: AtomicU32 = AtomicU32::new(0);
305
#[cfg(test)]
mod tests {
    use core::fmt::Display;

    use serde::{Deserialize, Serialize};

    use super::*;

    /// Unit key: every priority closure in these tests ignores it.
    #[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize, Debug)]
    struct FakeAutotuneKey;

    impl Display for FakeAutotuneKey {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            write!(f, "FakeAutotuneKey")
        }
    }

    impl AutotuneKey for FakeAutotuneKey {}

    type FakeTunable = Tunable<FakeAutotuneKey, (), ()>;
    type FakeGroup = TuneGroup<FakeAutotuneKey>;

    /// Kernel stand-in that always succeeds.
    fn fake_kernel(_: ()) -> Result<(), String> {
        Ok(())
    }

    /// A tunable that belongs to no group yet.
    fn make_tunable() -> FakeTunable {
        FakeTunable::new("fake", fake_kernel)
    }

    /// A group with a constant priority.
    fn make_group(name: &str, priority: i8) -> FakeGroup {
        FakeGroup::new(name, move |_| priority)
    }

    /// Call `next` until the plan is exhausted, concatenating all batches in order.
    fn drain(plan: &mut TunePlan) -> Vec<usize> {
        let mut emitted = Vec::new();
        loop {
            let batch = plan.next(None);
            if batch.is_empty() {
                return emitted;
            }
            emitted.extend(batch);
        }
    }

    #[test_log::test]
    fn test_plan_order() {
        let group0 = make_group("group0", 2);
        let group1 = make_group("group1", 1);

        let tunables = [
            make_tunable(),
            make_tunable().group(&group0, |_| 1),
            make_tunable().group(&group0, |_| 2),
            make_tunable().group(&group1, |_| 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(None), vec![0, 2]);
        assert_eq!(plan.next(None), vec![1]);
        assert_eq!(plan.next(None), vec![3]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_order_multi_groups_same_priority() {
        let group0 = make_group("group0", 2);
        let group1 = make_group("group1", 1);
        let group2 = make_group("group2", 1);

        let tunables = [
            make_tunable(),
            make_tunable().group(&group0, |_| 1),
            make_tunable().group(&group0, |_| 2),
            make_tunable().group(&group1, |_| 2),
            make_tunable().group(&group2, |_| 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(None), vec![0, 2]);
        assert_eq!(plan.next(None), vec![1]);
        // group1 and group2 share a priority, so their tunables come out together.
        assert_eq!(plan.next(None), vec![3, 4]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_order_tunable_multiple_groups() {
        let group0 = make_group("group0", 1);
        let group1 = make_group("group1", 2);

        let tunables = [
            make_tunable(),
            make_tunable().group(&group0, |_| 1).group(&group1, |_| 2),
            make_tunable().group(&group0, |_| 2),
            make_tunable().group(&group1, |_| 3),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(None), vec![0, 3]);
        assert_eq!(plan.next(None), vec![1]);
        assert_eq!(plan.next(None), vec![2]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_negative_priority() {
        let group0 = make_group("group0", 2);
        let group1 = make_group("group1", 1);

        let tunables = [
            make_tunable(),
            make_tunable().group(&group0, |_| -1),
            make_tunable().group(&group0, |_| 2),
            make_tunable().group(&group1, |_| 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        // Index 1 has a negative intra-group priority and is never emitted.
        assert_eq!(plan.next(None), vec![0, 2]);
        assert_eq!(plan.next(None), vec![3]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_no_group() {
        let tunables = [make_tunable(), make_tunable()];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(None), vec![0, 1]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_falls_through_when_all_group_tunables_fail() {
        // Each tunable sits in exactly one group. The caller behaves as if every batch
        // failed by calling next() again and again; the plan still has to hand out every
        // tunable, in priority order, before it reports being empty.
        let group0 = make_group("group0", 2);
        let group1 = make_group("group1", 1);

        let tunables = [
            make_tunable().group(&group0, |_| 1),
            make_tunable().group(&group0, |_| 2),
            make_tunable().group(&group1, |_| 1),
            make_tunable().group(&group1, |_| 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        // The priority-2 group is drained first, highest intra-group priority first,
        // followed by the priority-1 group.
        assert_eq!(drain(&mut plan), vec![1, 0, 3, 2]);
    }

    #[test_log::test]
    fn test_plan_single_group_exhausts_all_intra_priorities() {
        // One group holding several intra-group priorities yields one batch per
        // priority, so the caller can keep going after failures until nothing is left.
        let group0 = make_group("group0", 0);

        let tunables = [
            make_tunable().group(&group0, |_| 1),
            make_tunable().group(&group0, |_| 2),
            make_tunable().group(&group0, |_| 3),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(None), vec![2]);
        assert_eq!(plan.next(None), vec![1]);
        assert_eq!(plan.next(None), vec![0]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_all_negative_group_advances_to_next_group() {
        // When every tunable of a group has a negative intra-group priority the whole
        // group is skipped, and autotuning carries on with the next group.
        let group0 = make_group("group0", 2);
        let group1 = make_group("group1", 1);

        let tunables = [
            make_tunable().group(&group0, |_| -1),
            make_tunable().group(&group0, |_| -2),
            make_tunable().group(&group1, |_| 1),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(None), vec![2]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_no_group_tunables_only_emitted_once_even_on_failures() {
        // Ungrouped tunables travel with the first group batch. If the caller keeps
        // calling next() (as if that batch had failed) they must not show up again, and
        // later groups must still be reached.
        let group0 = make_group("group0", 2);
        let group1 = make_group("group1", 1);

        let tunables = [
            make_tunable(),
            make_tunable().group(&group0, |_| 1),
            make_tunable().group(&group1, |_| 1),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(None), vec![0, 1]);
        assert_eq!(plan.next(None), vec![2]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_multi_group_tunable_not_duplicated_across_failed_groups() {
        // Index 0 is registered in both group0 and group1. It must be handed out exactly
        // once, through the higher-priority group, even when the caller keeps iterating
        // after failures.
        let group0 = make_group("group0", 1);
        let group1 = make_group("group1", 2);

        let tunables = [
            make_tunable().group(&group0, |_| 1).group(&group1, |_| 1),
            make_tunable().group(&group0, |_| 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        // Index 0 is emitted via group1 (priority 2). Index 1 is the highest entry of
        // group0. The second occurrence of index 0 in group0 is dropped because it was
        // already returned, so there are no duplicates.
        assert_eq!(drain(&mut plan), vec![0, 1]);
    }

    #[test_log::test]
    fn test_plan_recurses_when_batch_is_fully_already_returned() {
        // Regression test. A tunable that lives in several groups has already been
        // emitted through its higher-priority group, so when the batch of its
        // lower-priority group comes up, the only index in it is already part of
        // `returned`. The plan must NOT answer with an empty batch in that situation
        // (the caller reads that as "no more work" and aborts with NoValidKernelFound);
        // it has to move on to the next intra-group priority and surface what is left.
        //
        // The cross-group dedup in group_plan_next compares (index, group name) pairs,
        // so an entry filed under group_lo is not removed when the same tunable is taken
        // from group_hi. The `returned` + `all_skip` path is the only guard.
        let group_hi = make_group("hi", 2);
        let group_lo = make_group("lo", 1);

        // Index 0 is in both groups; index 1 is only in group_lo, at a lower
        // intra-group priority.
        let tunables = [
            make_tunable().group(&group_hi, |_| 1).group(&group_lo, |_| 2),
            make_tunable().group(&group_lo, |_| 1),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        // First call: group_hi yields index 0.
        assert_eq!(plan.next(None), vec![0]);
        // Second call: the top batch of group_lo is just index 0, which was already
        // returned. The plan has to skip it and yield index 1 instead of an empty batch.
        assert_eq!(plan.next(None), vec![1]);
        assert!(plan.next(None).is_empty());
    }
}