cubecl_runtime/tune/base.rs

1use super::{AutotuneKey, IntoTuneFn, TuneFn};
2use alloc::sync::Arc;
3use alloc::vec;
4use alloc::vec::Vec;
5use core::sync::atomic::{AtomicU32, Ordering};
6use hashbrown::HashMap;
7
8/// A tunable wraps a [function](TuneFn) that can be included in multiple [groups](TuneGroup).
9///
10/// When a tunable is part of multiple groups, it will be autotuned when one of those groups is
11/// prioritized.
12pub struct Tunable<K, Inputs, Output> {
13    pub(crate) function: Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>,
14    groups: Vec<(TuneGroup<K>, PriorityFunc<K>)>,
15}
16
17impl<K, Inputs, Output> Tunable<K, Inputs, Output> {
18    /// Create a new tunable based on a function.
19    pub fn new<Marker>(function: impl IntoTuneFn<Inputs, Output, Marker>) -> Self {
20        Self {
21            function: Arc::new(function.into_tunable()),
22            groups: Vec::new(),
23        }
24    }
25
26    /// Tag the current tunable as part of the given [group](TuneGroup).
27    /// `group` is a tuning group with a corresponding priority function.
28    /// `priority` is the intra-group priority, applied after the group priority to further sort entries
29    ///
30    /// Groups are tuned in order of priority, and then each entry in the group is tuned based on the
31    /// intra-group priority. Negative priorities ensure the entry is never tuned for this key.
32    pub fn group<F: Fn(&K) -> i8 + 'static>(mut self, group: &TuneGroup<K>, priority: F) -> Self {
33        self.groups.push((group.clone(), Arc::new(priority)));
34        self
35    }
36}
37
38/// A tune group encapsulates a priority that can be calculated based on an
39/// [autotune key](AutotuneKey).
40///
41/// During autotuning, the higher prioritized groups will be autotuned first, and if a tunable
42/// returns a valid result, no more groups will be autotuned afterward.
43///
44/// Note that tunables themselves have a priority dictating the order in which they are autotuned in
45/// each group.
46pub struct TuneGroup<K> {
47    id: u32,
48    pub(crate) priority: PriorityFunc<K>,
49}
50
51impl<K> Clone for TuneGroup<K> {
52    fn clone(&self) -> Self {
53        Self {
54            id: self.id,
55            priority: self.priority.clone(),
56        }
57    }
58}
59
60impl<K> TuneGroup<K> {
61    /// Create a new group based on a priority function.
62    pub fn new<F: Fn(&K) -> i8 + 'static>(f: F) -> Self {
63        let id = GROUP_COUNTER.fetch_add(1, Ordering::Relaxed);
64
65        Self {
66            id,
67            priority: Arc::new(f),
68        }
69    }
70}
71
72#[derive(Debug)]
73/// A group plan dictates which [tunables](Tunable) should be executed, and in what order.
74pub(crate) struct TunePlan {
75    priorities: Vec<i8>,
76    no_groups: Vec<usize>,
77    groups: HashMap<i8, GroupPlan>,
78}
79
80#[derive(Default, Debug)]
81struct GroupPlan {
82    priorities: Vec<i8>,
83    indices: HashMap<i8, Vec<usize>>,
84}
85
86struct Cleanup {
87    groups: Vec<i8>,
88    tunables: Vec<(i8, i8)>,
89}
90
91impl TunePlan {
92    pub fn new<K: AutotuneKey, In, Out>(key: &K, tunables: &[Tunable<K, In, Out>]) -> Self {
93        let mut priorities = Vec::<i8>::new();
94        let mut no_groups = Vec::new();
95        let mut groups = HashMap::<i8, GroupPlan>::new();
96
97        for (index, tunable) in tunables.iter().enumerate() {
98            if tunable.groups.is_empty() {
99                no_groups.push(index);
100            } else {
101                for (group, within_group_priority_fn) in tunable.groups.iter() {
102                    let priority_fn = &group.priority;
103                    let priority = priority_fn(key);
104                    if !priorities.contains(&priority) {
105                        priorities.push(priority);
106                    }
107
108                    let group_priorities = match groups.get_mut(&priority) {
109                        Some(val) => val,
110                        None => {
111                            groups.insert(priority, GroupPlan::default());
112                            groups.get_mut(&priority).unwrap()
113                        }
114                    };
115                    let priority = within_group_priority_fn(key);
116
117                    if group_priorities.priorities.contains(&priority) {
118                        group_priorities
119                            .indices
120                            .get_mut(&priority)
121                            .unwrap()
122                            .push(index);
123                    } else {
124                        group_priorities.priorities.push(priority);
125                        group_priorities.indices.insert(priority, vec![index]);
126                    }
127                }
128            }
129        }
130
131        priorities.sort();
132
133        for group in groups.iter_mut() {
134            group.1.priorities.sort();
135        }
136
137        Self {
138            priorities,
139            no_groups,
140            groups,
141        }
142    }
143
144    /// Get the next batch of [tunable](Tunable) index to be autotuned.
145    ///
146    /// Note that if the list is empty, it means no more autotuned entry can be executed.
147    pub(crate) fn next(&mut self) -> Vec<usize> {
148        let mut indices = core::mem::take(&mut self.no_groups);
149        let priority = self.priorities.last();
150
151        let priority = match priority {
152            Some(val) => *val,
153            None => return indices,
154        };
155
156        let (mut group_indices, cleanup) = self.group_plan_next(priority);
157        self.cleanup(cleanup);
158
159        if priority >= 0 {
160            indices.append(&mut group_indices);
161        }
162
163        indices
164    }
165
166    fn cleanup(&mut self, cleanup: Cleanup) {
167        for group_p in cleanup.groups {
168            let index = self
169                .priorities
170                .iter()
171                .enumerate()
172                .find(|p| *p.1 == group_p)
173                .unwrap();
174
175            self.priorities.remove(index.0);
176            self.groups.remove(&group_p);
177        }
178
179        for (group_p, tunable_p) in cleanup.tunables {
180            if let Some(group) = self.groups.get_mut(&group_p) {
181                let index = group
182                    .priorities
183                    .iter()
184                    .enumerate()
185                    .find(|p| *p.1 == tunable_p)
186                    .unwrap();
187                group.priorities.remove(index.0);
188                group.indices.remove(&tunable_p);
189            }
190        }
191    }
192
193    fn group_plan_next(&mut self, priority: i8) -> (Vec<usize>, Cleanup) {
194        let plan = self.groups.get_mut(&priority).expect("To be filled");
195        let within_group_prio = plan.priorities.pop().unwrap();
196        let mut next_indices = plan.indices.remove(&within_group_prio).unwrap();
197
198        let mut cleanup_groups = Vec::new();
199        let mut cleanup_tunables = Vec::new();
200
201        for (pg, group) in self.groups.iter_mut() {
202            let mut num_empty_tunables = 0;
203            let num_tunables = group.priorities.len();
204
205            for (pt, indices) in group.indices.iter_mut() {
206                for n in &next_indices {
207                    let entry = indices.iter().enumerate().find(|p| *p.1 == *n);
208                    if let Some(entry) = entry {
209                        indices.remove(entry.0);
210                    }
211                }
212
213                if indices.is_empty() {
214                    num_empty_tunables += 1;
215                    cleanup_tunables.push((*pg, *pt));
216                }
217            }
218
219            if num_empty_tunables == num_tunables {
220                cleanup_groups.push(*pg);
221            }
222        }
223
224        if within_group_prio < 0 {
225            // Discard algorithms with negative priority
226            next_indices.clear();
227        }
228
229        (
230            next_indices,
231            Cleanup {
232                groups: cleanup_groups,
233                tunables: cleanup_tunables,
234            },
235        )
236    }
237}
238
/// Shared priority callback evaluated on an autotune key: higher values are tuned earlier, and
/// negative values are never tuned.
type PriorityFunc<K> = Arc<dyn for<'key> Fn(&'key K) -> i8>;

/// Global source of [`TuneGroup`] ids; each newly created group takes the next value.
static GROUP_COUNTER: AtomicU32 = AtomicU32::new(0);
242
#[cfg(test)]
mod tests {
    use core::fmt::Display;

    use serde::{Deserialize, Serialize};

    use super::*;

    #[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize, Debug)]
    struct FakeAutotuneKey;

    impl Display for FakeAutotuneKey {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.write_str("FakeAutotuneKey")
        }
    }

    impl AutotuneKey for FakeAutotuneKey {}

    #[test]
    fn test_plan_order() {
        let group0 = TuneGroup::<FakeAutotuneKey>::new(|_| 2);
        let group1 = TuneGroup::<FakeAutotuneKey>::new(|_| 1);

        let tunable0 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel);
        let tunable1 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel).group(&group0, |_| 1);
        let tunable2 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel).group(&group0, |_| 2);
        let tunable3 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel).group(&group1, |_| 2);

        let key = FakeAutotuneKey;
        let mut plan = TunePlan::new(&key, &[tunable0, tunable1, tunable2, tunable3]);

        assert_eq!(plan.next(), vec![0, 2]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![3]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_order_multi_groups_same_priority() {
        let group0 = TuneGroup::<FakeAutotuneKey>::new(|_| 2);
        let group1 = TuneGroup::<FakeAutotuneKey>::new(|_| 1);
        let group2 = TuneGroup::<FakeAutotuneKey>::new(|_| 1);

        let tunable0 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel);
        let tunable1 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel).group(&group0, |_| 1);
        let tunable2 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel).group(&group0, |_| 2);
        let tunable3 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel).group(&group1, |_| 2);
        let tunable4 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel).group(&group2, |_| 2);

        let key = FakeAutotuneKey;
        let mut plan = TunePlan::new(&key, &[tunable0, tunable1, tunable2, tunable3, tunable4]);

        assert_eq!(plan.next(), vec![0, 2]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![3, 4]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_order_tunable_multiple_groups() {
        let group0 = TuneGroup::<FakeAutotuneKey>::new(|_| 1);
        let group1 = TuneGroup::<FakeAutotuneKey>::new(|_| 2);

        let tunable0 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel);
        let tunable1 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel)
            .group(&group0, |_| 1)
            .group(&group1, |_| 2);
        let tunable2 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel).group(&group0, |_| 2);
        let tunable3 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel).group(&group1, |_| 3);

        let key = FakeAutotuneKey;
        let mut plan = TunePlan::new(&key, &[tunable0, tunable1, tunable2, tunable3]);

        assert_eq!(plan.next(), vec![0, 3]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![2]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_no_group() {
        let tunable0 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel);
        let tunable1 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel);

        let key = FakeAutotuneKey;
        let mut plan = TunePlan::new(&key, &[tunable0, tunable1]);

        assert_eq!(plan.next(), vec![0, 1]);
        assert!(plan.next().is_empty());
    }

    /// Negative priorities, whether a group's or an intra-group one, must never be scheduled.
    #[test]
    fn test_plan_negative_priorities_never_tuned() {
        let group0 = TuneGroup::<FakeAutotuneKey>::new(|_| 1);
        let group1 = TuneGroup::<FakeAutotuneKey>::new(|_| -1);

        let tunable0 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel)
            .group(&group0, |_| 2);
        let tunable1 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel)
            .group(&group0, |_| -1);
        let tunable2 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel)
            .group(&group1, |_| 3);

        let key = FakeAutotuneKey;
        let mut plan = TunePlan::new(&key, &[tunable0, tunable1, tunable2]);

        // Only the non-negative entry of the non-negative group is ever returned.
        assert_eq!(plan.next(), vec![0]);
        assert!(plan.next().is_empty());
        assert!(plan.next().is_empty());
    }

    /// Stand-in tune function that always succeeds.
    fn fake_kernel() -> Result<(), String> {
        Ok(())
    }
}