cubecl_runtime/tune/
base.rs

1use super::{AutotuneKey, IntoTuneFn, TuneFn};
2use alloc::sync::Arc;
3use alloc::vec;
4use alloc::vec::Vec;
5use core::sync::atomic::{AtomicU32, Ordering};
6use hashbrown::HashMap;
7
8/// A tunable wraps a [function](TuneFn) that can be included in multiple [groups](TuneGroup).
9///
10/// When a tunable is part of multiple groups, it will be autotuned when one of those groups is
11/// prioritized.
12pub struct Tunable<K, Inputs, Output> {
13    pub(crate) function: Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>,
14    groups: Vec<(TuneGroup<K>, PriorityFunc<K>)>,
15}
16
17impl<K, Inputs, Output> Tunable<K, Inputs, Output> {
18    /// Create a new tunable based on a function.
19    pub fn new<Marker>(function: impl IntoTuneFn<Inputs, Output, Marker>) -> Self {
20        Self {
21            function: Arc::new(function.into_tunable()),
22            groups: Vec::new(),
23        }
24    }
25
26    /// Tag the current tunable as part of the given [group](TuneGroup).
27    /// `group` is a tuning group with a corresponding priority function.
28    /// `priority` is the intra-group priority, applied after the group priority to further sort entries
29    ///
30    /// Groups are tuned in order of priority, and then each entry in the group is tuned based on the
31    /// intra-group priority. Negative priorities ensure the entry is never tuned for this key.
32    pub fn group<F: Fn(&K) -> i8 + 'static>(mut self, group: &TuneGroup<K>, priority: F) -> Self {
33        self.groups.push((group.clone(), Arc::new(priority)));
34        self
35    }
36}
37
38/// A tune group encapsulates a priority that can be calculated based on an
39/// [autotune key](AutotuneKey).
40///
41/// During autotuning, the higher prioritized groups will be autotuned first, and if a tunable
42/// returns a valid result, no more groups will be autotuned afterward.
43///
44/// Note that tunables themselves have a priority dictating the order in which they are autotuned in
45/// each group.
46pub struct TuneGroup<K> {
47    id: u32,
48    pub(crate) priority: PriorityFunc<K>,
49}
50
51impl<K> core::fmt::Debug for TuneGroup<K> {
52    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
53        f.debug_struct("TuneGroup").field("id", &self.id).finish()
54    }
55}
56
57impl<K> Clone for TuneGroup<K> {
58    fn clone(&self) -> Self {
59        Self {
60            id: self.id,
61            priority: self.priority.clone(),
62        }
63    }
64}
65
66impl<K> TuneGroup<K> {
67    /// Create a new group based on a priority function.
68    pub fn new<F: Fn(&K) -> i8 + 'static>(f: F) -> Self {
69        let id = GROUP_COUNTER.fetch_add(1, Ordering::Relaxed);
70
71        Self {
72            id,
73            priority: Arc::new(f),
74        }
75    }
76}
77
78#[derive(Debug)]
79/// A group plan dictates which [tunables](Tunable) should be executed, and in what order.
80pub(crate) struct TunePlan {
81    priorities: Vec<i8>,
82    no_groups: Vec<usize>,
83    groups: HashMap<i8, GroupPlan>,
84}
85
/// Remaining work for all the groups that share one group priority.
#[derive(Default, Debug)]
struct GroupPlan {
    /// Distinct within-group priorities still pending.
    priorities: Vec<i8>,
    /// Tunable indices registered under each within-group priority.
    indices: HashMap<i8, Vec<usize>>,
}
91
/// Bookkeeping produced while selecting a batch, applied to the plan afterwards.
#[derive(Debug)]
struct Cleanup {
    /// Group priorities whose plan has nothing left and must be dropped.
    groups: Vec<i8>,
    /// `(group priority, within-group priority)` pairs whose index list became empty.
    tunables: Vec<(i8, i8)>,
    /// Within group priority is too low to even try.
    skipped: bool,
}
99
100impl TunePlan {
101    pub fn new<K: AutotuneKey, In, Out>(key: &K, tunables: &[Tunable<K, In, Out>]) -> Self {
102        let mut priorities = Vec::<i8>::new();
103        let mut no_groups = Vec::new();
104        let mut groups = HashMap::<i8, GroupPlan>::new();
105
106        for (index, tunable) in tunables.iter().enumerate() {
107            if tunable.groups.is_empty() {
108                no_groups.push(index);
109            } else {
110                for (group, within_group_priority_fn) in tunable.groups.iter() {
111                    let priority_fn = &group.priority;
112                    let priority = priority_fn(key);
113                    if !priorities.contains(&priority) {
114                        priorities.push(priority);
115                    }
116
117                    let group_priorities = match groups.get_mut(&priority) {
118                        Some(val) => val,
119                        None => {
120                            groups.insert(priority, GroupPlan::default());
121                            groups.get_mut(&priority).unwrap()
122                        }
123                    };
124                    let priority = within_group_priority_fn(key);
125
126                    if group_priorities.priorities.contains(&priority) {
127                        group_priorities
128                            .indices
129                            .get_mut(&priority)
130                            .unwrap()
131                            .push(index);
132                    } else {
133                        group_priorities.priorities.push(priority);
134                        group_priorities.indices.insert(priority, vec![index]);
135                    }
136                }
137            }
138        }
139
140        priorities.sort();
141
142        for group in groups.iter_mut() {
143            group.1.priorities.sort();
144        }
145
146        Self {
147            priorities,
148            no_groups,
149            groups,
150        }
151    }
152
153    /// Get the next batch of [tunable](Tunable) index to be autotuned.
154    ///
155    /// Note that if the list is empty, it means no more autotuned entry can be executed.
156    pub(crate) fn next(&mut self) -> Vec<usize> {
157        let mut indices = core::mem::take(&mut self.no_groups);
158        let priority = self.priorities.last();
159
160        let priority = match priority {
161            Some(val) => *val,
162            None => return indices,
163        };
164
165        let (mut group_indices, cleanup) = self.group_plan_next(priority);
166        // Some entries are skipped for this round of prioritizing.
167        let skipped = cleanup.skipped || priority < 0;
168
169        self.cleanup(cleanup);
170
171        if priority >= 0 {
172            indices.append(&mut group_indices);
173        }
174
175        // The indices list is empty, but it doesn't mean we should stop
176        // autotuning, since some entries were skipped.
177        if indices.is_empty() && skipped {
178            self.next()
179        } else {
180            indices
181        }
182    }
183
184    fn cleanup(&mut self, cleanup: Cleanup) {
185        for group_p in cleanup.groups {
186            let index = self
187                .priorities
188                .iter()
189                .enumerate()
190                .find(|p| *p.1 == group_p)
191                .unwrap();
192
193            self.priorities.remove(index.0);
194            self.groups.remove(&group_p);
195        }
196
197        for (group_p, tunable_p) in cleanup.tunables {
198            if let Some(group) = self.groups.get_mut(&group_p) {
199                let index = group
200                    .priorities
201                    .iter()
202                    .enumerate()
203                    .find(|p| *p.1 == tunable_p)
204                    .unwrap();
205                group.priorities.remove(index.0);
206                group.indices.remove(&tunable_p);
207            }
208        }
209    }
210
211    fn group_plan_next(&mut self, priority: i8) -> (Vec<usize>, Cleanup) {
212        let plan = self.groups.get_mut(&priority).expect("To be filled");
213        let within_group_prio = plan.priorities.pop().unwrap();
214        let mut next_indices = plan.indices.remove(&within_group_prio).unwrap();
215
216        let mut cleanup_groups = Vec::new();
217        let mut cleanup_tunables = Vec::new();
218
219        for (pg, group) in self.groups.iter_mut() {
220            let mut num_empty_tunables = 0;
221            let num_tunables = group.priorities.len();
222
223            for (pt, indices) in group.indices.iter_mut() {
224                for n in &next_indices {
225                    let entry = indices.iter().enumerate().find(|p| *p.1 == *n);
226                    if let Some(entry) = entry {
227                        indices.remove(entry.0);
228                    }
229                }
230
231                if indices.is_empty() {
232                    num_empty_tunables += 1;
233                    cleanup_tunables.push((*pg, *pt));
234                }
235            }
236
237            if num_empty_tunables == num_tunables {
238                cleanup_groups.push(*pg);
239            }
240        }
241
242        if within_group_prio < 0 {
243            // Discard algorithms with negative priority
244            next_indices.clear();
245        }
246
247        (
248            next_indices,
249            Cleanup {
250                groups: cleanup_groups,
251                tunables: cleanup_tunables,
252                skipped: within_group_prio < 0,
253            },
254        )
255    }
256}
257
/// Shared, dynamically dispatched function that maps a key to an `i8` priority.
type PriorityFunc<K> = Arc<dyn Fn(&K) -> i8>;
259
/// Global counter providing the `id` of each newly created [`TuneGroup`]; starts at zero.
static GROUP_COUNTER: AtomicU32 = AtomicU32::new(0);
261
#[cfg(test)]
mod tests {
    use core::fmt::Display;

    use serde::{Deserialize, Serialize};

    use super::*;

    /// Unit key; every priority closure in these tests ignores it.
    #[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize, Debug)]
    struct FakeAutotuneKey;

    impl Display for FakeAutotuneKey {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.write_str("FakeAutotuneKey")
        }
    }

    impl AutotuneKey for FakeAutotuneKey {}

    type Group = TuneGroup<FakeAutotuneKey>;
    type Kernel = Tunable<FakeAutotuneKey, (), ()>;

    /// Kernel stand-in that always succeeds.
    fn fake_kernel() -> Result<(), String> {
        Ok(())
    }

    /// A tunable wrapping [`fake_kernel`], not yet part of any group.
    fn kernel() -> Kernel {
        Kernel::new(fake_kernel)
    }

    /// Checks that the plan built from `tunables` yields exactly the `expected` batches, in
    /// order, and is exhausted afterwards.
    fn assert_batches(tunables: &[Kernel], expected: Vec<Vec<usize>>) {
        let key = FakeAutotuneKey;
        let mut plan = TunePlan::new(&key, tunables);

        for batch in expected {
            assert_eq!(plan.next(), batch);
        }
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_order() {
        let group0 = Group::new(|_| 2);
        let group1 = Group::new(|_| 1);

        let tunables = [
            kernel(),
            kernel().group(&group0, |_| 1),
            kernel().group(&group0, |_| 2),
            kernel().group(&group1, |_| 2),
        ];

        assert_batches(&tunables, vec![vec![0, 2], vec![1], vec![3]]);
    }

    #[test]
    fn test_plan_order_multi_groups_same_priority() {
        let group0 = Group::new(|_| 2);
        let group1 = Group::new(|_| 1);
        let group2 = Group::new(|_| 1);

        let tunables = [
            kernel(),
            kernel().group(&group0, |_| 1),
            kernel().group(&group0, |_| 2),
            kernel().group(&group1, |_| 2),
            kernel().group(&group2, |_| 2),
        ];

        assert_batches(&tunables, vec![vec![0, 2], vec![1], vec![3, 4]]);
    }

    #[test]
    fn test_plan_order_tunable_multiple_groups() {
        let group0 = Group::new(|_| 1);
        let group1 = Group::new(|_| 2);

        let tunables = [
            kernel(),
            kernel().group(&group0, |_| 1).group(&group1, |_| 2),
            kernel().group(&group0, |_| 2),
            kernel().group(&group1, |_| 3),
        ];

        assert_batches(&tunables, vec![vec![0, 3], vec![1], vec![2]]);
    }

    #[test]
    fn test_plan_negative_priority() {
        let group0 = Group::new(|_| 2);
        let group1 = Group::new(|_| 1);

        let tunables = [
            kernel(),
            kernel().group(&group0, |_| -1),
            kernel().group(&group0, |_| 2),
            kernel().group(&group1, |_| 2),
        ];

        assert_batches(&tunables, vec![vec![0, 2], vec![3]]);
    }

    #[test]
    fn test_plan_no_group() {
        let tunables = [kernel(), kernel()];

        assert_batches(&tunables, vec![vec![0, 1]]);
    }
}