cubecl_runtime/tune/
base.rs

1use super::{AutotuneKey, IntoTuneFn, TuneFn};
2use alloc::sync::Arc;
3use alloc::vec;
4use alloc::vec::Vec;
5use core::sync::atomic::{AtomicU32, Ordering};
6use hashbrown::HashMap;
7
8/// A tunable wraps a [function](TuneFn) that can be included in multiple [groups](TuneGroup).
9///
10/// When a tunable is part of multiple groups, it will be autotuned when one of those groups is
11/// prioritized.
12pub struct Tunable<K, Inputs, Output> {
13    pub(crate) function: Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>,
14    groups: Vec<(TuneGroup<K>, PriorityFunc<K>)>,
15}
16
17impl<K, Inputs, Output> Tunable<K, Inputs, Output> {
18    /// Create a new tunable based on a function.
19    pub fn new<Marker>(function: impl IntoTuneFn<Inputs, Output, Marker>) -> Self {
20        Self {
21            function: Arc::new(function.into_tunable()),
22            groups: Vec::new(),
23        }
24    }
25
26    /// Tag the current tunable as part of the given [group](TuneGroup).
27    pub fn group<F: Fn(&K) -> u8 + 'static>(mut self, group: &TuneGroup<K>, priority: F) -> Self {
28        self.groups.push((group.clone(), Arc::new(priority)));
29        self
30    }
31}
32
33/// A tune group encapsulates a priority that can be calculated based on an
34/// [autotune key](AutotuneKey).
35///
36/// During autotuning, the higher prioritized groups will be autotuned first, and if a tunable
37/// returns a valid result, no more groups will be autotuned afterward.
38///
39/// Note that tunables themselves have a priority dictating the order in which they are autotuned in
40/// each group.
41pub struct TuneGroup<K> {
42    id: u32,
43    pub(crate) priority: PriorityFunc<K>,
44}
45
46impl<K> Clone for TuneGroup<K> {
47    fn clone(&self) -> Self {
48        Self {
49            id: self.id,
50            priority: self.priority.clone(),
51        }
52    }
53}
54
55impl<K> TuneGroup<K> {
56    /// Create a new group based on a priority function.
57    pub fn new<F: Fn(&K) -> u8 + 'static>(f: F) -> Self {
58        let id = GROUP_COUNTER.fetch_add(1, Ordering::Relaxed);
59
60        Self {
61            id,
62            priority: Arc::new(f),
63        }
64    }
65}
66
67#[derive(Debug)]
68/// A group plan dictates which [tunables](Tunable) should be executed, and in what order.
69pub(crate) struct TunePlan {
70    priorities: Vec<u8>,
71    no_groups: Vec<usize>,
72    groups: HashMap<u8, GroupPlan>,
73}
74
/// Pending tunables for one group priority, bucketed by their within-group priority.
#[derive(Debug, Default)]
struct GroupPlan {
    /// Within-group priorities that still have tunables, kept sorted ascending.
    priorities: Vec<u8>,
    /// Tunable indices registered under each within-group priority.
    indices: HashMap<u8, Vec<usize>>,
}
80
/// Entries that became empty after a batch was selected and must be pruned from the plan.
struct Cleanup {
    /// Group priorities whose groups have no tunables left.
    groups: Vec<u8>,
    /// `(group priority, within-group priority)` pairs whose index lists are now empty.
    tunables: Vec<(u8, u8)>,
}
85
86impl TunePlan {
87    pub fn new<K: AutotuneKey, In, Out>(key: &K, tunables: &[Tunable<K, In, Out>]) -> Self {
88        let mut priorities = Vec::<u8>::new();
89        let mut no_groups = Vec::new();
90        let mut groups = HashMap::<u8, GroupPlan>::new();
91
92        for (index, tunable) in tunables.iter().enumerate() {
93            if tunable.groups.is_empty() {
94                no_groups.push(index);
95            } else {
96                for (group, within_group_priority_fn) in tunable.groups.iter() {
97                    let priority_fn = &group.priority;
98                    let priority = priority_fn(key);
99                    if !priorities.contains(&priority) {
100                        priorities.push(priority);
101                    }
102
103                    let group_priorities = match groups.get_mut(&priority) {
104                        Some(val) => val,
105                        None => {
106                            groups.insert(priority, GroupPlan::default());
107                            groups.get_mut(&priority).unwrap()
108                        }
109                    };
110                    let priority = within_group_priority_fn(key);
111
112                    if group_priorities.priorities.contains(&priority) {
113                        group_priorities
114                            .indices
115                            .get_mut(&priority)
116                            .unwrap()
117                            .push(index);
118                    } else {
119                        group_priorities.priorities.push(priority);
120                        group_priorities.indices.insert(priority, vec![index]);
121                    }
122                }
123            }
124        }
125
126        priorities.sort();
127
128        for group in groups.iter_mut() {
129            group.1.priorities.sort();
130        }
131
132        Self {
133            priorities,
134            no_groups,
135            groups,
136        }
137    }
138
139    /// Get the next batch of [tunable](Tunable) index to be autotuned.
140    ///
141    /// Note that if the list is empty, it means no more autotuned entry can be executed.
142    pub(crate) fn next(&mut self) -> Vec<usize> {
143        let mut indices = core::mem::take(&mut self.no_groups);
144        let priority = self.priorities.last();
145
146        let priority = match priority {
147            Some(val) => val,
148            None => return indices,
149        };
150
151        let (mut group_indices, cleanup) = self.group_plan_next(*priority);
152        self.cleanup(cleanup);
153        indices.append(&mut group_indices);
154
155        indices
156    }
157
158    fn cleanup(&mut self, cleanup: Cleanup) {
159        for group_p in cleanup.groups {
160            let index = self
161                .priorities
162                .iter()
163                .enumerate()
164                .find(|p| *p.1 == group_p)
165                .unwrap();
166
167            self.priorities.remove(index.0);
168            self.groups.remove(&group_p);
169        }
170
171        for (group_p, tunable_p) in cleanup.tunables {
172            if let Some(group) = self.groups.get_mut(&group_p) {
173                let index = group
174                    .priorities
175                    .iter()
176                    .enumerate()
177                    .find(|p| *p.1 == tunable_p)
178                    .unwrap();
179                group.priorities.remove(index.0);
180                group.indices.remove(&tunable_p);
181            }
182        }
183    }
184
185    fn group_plan_next(&mut self, priority: u8) -> (Vec<usize>, Cleanup) {
186        let plan = self.groups.get_mut(&priority).expect("To be filled");
187        let within_group_prio = plan.priorities.pop().unwrap();
188        let next_indices = plan.indices.remove(&within_group_prio).unwrap();
189
190        let mut cleanup_groups = Vec::new();
191        let mut cleanup_tunables = Vec::new();
192
193        for (pg, group) in self.groups.iter_mut() {
194            let mut num_empty_tunables = 0;
195            let num_tunables = group.priorities.len();
196
197            for (pt, indices) in group.indices.iter_mut() {
198                for n in &next_indices {
199                    let entry = indices.iter().enumerate().find(|p| *p.1 == *n);
200                    if let Some(entry) = entry {
201                        indices.remove(entry.0);
202                    }
203                }
204
205                if indices.is_empty() {
206                    num_empty_tunables += 1;
207                    cleanup_tunables.push((*pg, *pt));
208                }
209            }
210
211            if num_empty_tunables == num_tunables {
212                cleanup_groups.push(*pg);
213            }
214        }
215
216        (
217            next_indices,
218            Cleanup {
219                groups: cleanup_groups,
220                tunables: cleanup_tunables,
221            },
222        )
223    }
224}
225
/// Shared function mapping an autotune key to a priority (higher runs earlier).
type PriorityFunc<K> = Arc<dyn for<'a> Fn(&'a K) -> u8>;
227
/// Global counter handing out [`TuneGroup`] identifiers, starting at zero.
static GROUP_COUNTER: AtomicU32 = AtomicU32::new(0);
229
#[cfg(test)]
mod tests {
    use core::fmt::Display;

    use serde::{Deserialize, Serialize};

    use super::*;

    #[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize, Debug)]
    struct FakeAutotuneKey;

    impl Display for FakeAutotuneKey {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.write_str("FakeAutotuneKey")
        }
    }

    impl AutotuneKey for FakeAutotuneKey {}

    type FakeTunable = Tunable<FakeAutotuneKey, (), ()>;

    /// A tunable wrapping `fake_kernel` that belongs to no group yet.
    fn new_tunable() -> FakeTunable {
        FakeTunable::new(fake_kernel)
    }

    /// A group whose priority is the constant `priority` for every key.
    fn new_group(priority: u8) -> TuneGroup<FakeAutotuneKey> {
        TuneGroup::<FakeAutotuneKey>::new(move |_| priority)
    }

    #[test]
    fn test_plan_order() {
        let group0 = new_group(2);
        let group1 = new_group(1);

        let tunables = [
            new_tunable(),
            new_tunable().group(&group0, |_| 1),
            new_tunable().group(&group0, |_| 2),
            new_tunable().group(&group1, |_| 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(), vec![0, 2]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![3]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_order_multi_groups_same_priority() {
        let group0 = new_group(2);
        let group1 = new_group(1);
        let group2 = new_group(1);

        let tunables = [
            new_tunable(),
            new_tunable().group(&group0, |_| 1),
            new_tunable().group(&group0, |_| 2),
            new_tunable().group(&group1, |_| 2),
            new_tunable().group(&group2, |_| 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(), vec![0, 2]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![3, 4]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_order_tunable_multiple_groups() {
        let group0 = new_group(1);
        let group1 = new_group(2);

        let tunables = [
            new_tunable(),
            new_tunable().group(&group0, |_| 1).group(&group1, |_| 2),
            new_tunable().group(&group0, |_| 2),
            new_tunable().group(&group1, |_| 3),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(), vec![0, 3]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![2]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_no_group() {
        let tunables = [new_tunable(), new_tunable()];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(), vec![0, 1]);
        assert!(plan.next().is_empty());
    }

    /// Trivial kernel used as the wrapped function of every test tunable.
    fn fake_kernel() -> Result<(), String> {
        Ok(())
    }
}