cubecl_runtime/tune/
base.rs

1use super::{AutotuneKey, IntoTuneFn, TuneFn};
2use alloc::format;
3use alloc::string::String;
4use alloc::sync::Arc;
5use alloc::vec;
6use alloc::vec::Vec;
7use core::sync::atomic::{AtomicU32, Ordering};
8use hashbrown::HashMap;
9
10/// A tunable wraps a [function](TuneFn) that can be included in multiple [groups](TuneGroup).
11///
12/// When a tunable is part of multiple groups, it will be autotuned when one of those groups is
13/// prioritized.
14pub struct Tunable<K, Inputs, Output> {
15    pub(crate) function: Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>,
16    groups: Vec<(TuneGroup<K>, PriorityFunc<K>)>,
17}
18
19impl<K, Inputs, Output> Tunable<K, Inputs, Output> {
20    /// Create a new tunable based on a function.
21    pub fn new<Marker, Name: Into<String>>(
22        name: Name,
23        function: impl IntoTuneFn<Inputs, Output, Marker>,
24    ) -> Self {
25        Self {
26            function: Arc::new(function.into_tunable(name.into())),
27            groups: Vec::new(),
28        }
29    }
30
31    /// Tag the current tunable as part of the given [group](TuneGroup).
32    /// `group` is a tuning group with a corresponding priority function.
33    /// `priority` is the intra-group priority, applied after the group priority to further sort entries
34    ///
35    /// Groups are tuned in order of priority, and then each entry in the group is tuned based on the
36    /// intra-group priority. Negative priorities ensure the entry is never tuned for this key.
37    pub fn group<F: Fn(&K) -> i8 + 'static>(mut self, group: &TuneGroup<K>, priority: F) -> Self {
38        self.groups.push((group.clone(), Arc::new(priority)));
39        self
40    }
41}
42
43/// A tune group encapsulates a priority that can be calculated based on an
44/// [autotune key](AutotuneKey).
45///
46/// During autotuning, the higher prioritized groups will be autotuned first, and if a tunable
47/// returns a valid result, no more groups will be autotuned afterward.
48///
49/// Note that tunables themselves have a priority dictating the order in which they are autotuned in
50/// each group.
51pub struct TuneGroup<K> {
52    id: u32,
53    name: Arc<String>,
54    pub(crate) priority: PriorityFunc<K>,
55}
56
57impl<K> core::fmt::Debug for TuneGroup<K> {
58    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
59        f.debug_struct("TuneGroup").field("id", &self.id).finish()
60    }
61}
62
63impl<K> Clone for TuneGroup<K> {
64    fn clone(&self) -> Self {
65        Self {
66            id: self.id,
67            name: self.name.clone(),
68            priority: self.priority.clone(),
69        }
70    }
71}
72
73impl<K> TuneGroup<K> {
74    /// Create a new group based on a priority function.
75    pub fn new<F: Fn(&K) -> i8 + 'static, S: Into<String>>(name: S, f: F) -> Self {
76        let id = GROUP_COUNTER.fetch_add(1, Ordering::Relaxed);
77
78        Self {
79            id,
80            name: Arc::new(name.into()),
81            priority: Arc::new(f),
82        }
83    }
84}
85
86#[derive(Debug)]
87/// A group plan dictates which [tunables](Tunable) should be executed, and in what order.
88pub(crate) struct TunePlan {
89    priorities: Vec<i8>,
90    no_groups: Vec<usize>,
91    groups: HashMap<i8, GroupPlan>,
92    returned: Vec<usize>,
93}
94
/// Pending entries of all groups that evaluate to the same group priority.
#[derive(Default, Debug)]
struct GroupPlan {
    /// Distinct intra-group priorities still pending, sorted in ascending order; the next
    /// batch is taken from the back.
    priorities: Vec<i8>,
    /// For each intra-group priority, the `(tunable index, group name)` pairs registered
    /// under it.
    indices: HashMap<i8, Vec<(usize, Arc<String>)>>,
}
100
/// Bookkeeping produced while extracting a batch, applied to the plan afterwards.
#[derive(Debug)]
struct Cleanup {
    /// Group priorities whose plan has no pending entry left and must be dropped.
    groups: Vec<i8>,
    /// `(group priority, intra-group priority)` pairs whose entry list became empty.
    tunables: Vec<(i8, i8)>,
    /// Set when the extracted batch had a negative intra-group priority, i.e. it was
    /// too low to even try.
    skipped: bool,
}
108
109impl TunePlan {
110    pub fn new<K: AutotuneKey, In, Out>(key: &K, tunables: &[Tunable<K, In, Out>]) -> Self {
111        let mut priorities = Vec::<i8>::new();
112        let mut no_groups = Vec::new();
113        let mut groups = HashMap::<i8, GroupPlan>::new();
114
115        for (index, tunable) in tunables.iter().enumerate() {
116            if tunable.groups.is_empty() {
117                no_groups.push(index);
118            } else {
119                for (group, within_group_priority_fn) in tunable.groups.iter() {
120                    let priority_fn = &group.priority;
121                    let priority = priority_fn(key);
122                    if !priorities.contains(&priority) {
123                        priorities.push(priority);
124                    }
125
126                    let group_priorities = match groups.get_mut(&priority) {
127                        Some(val) => val,
128                        None => {
129                            groups.insert(priority, GroupPlan::default());
130                            groups.get_mut(&priority).unwrap()
131                        }
132                    };
133                    let priority = within_group_priority_fn(key);
134
135                    if group_priorities.priorities.contains(&priority) {
136                        group_priorities
137                            .indices
138                            .get_mut(&priority)
139                            .unwrap()
140                            .push((index, group.name.clone()));
141                    } else {
142                        group_priorities.priorities.push(priority);
143                        group_priorities
144                            .indices
145                            .insert(priority, vec![(index, group.name.clone())]);
146                    }
147                }
148            }
149        }
150
151        priorities.sort();
152
153        for group in groups.iter_mut() {
154            group.1.priorities.sort();
155        }
156
157        Self {
158            priorities,
159            no_groups,
160            groups,
161            returned: Vec::new(),
162        }
163    }
164
165    /// Get the next batch of [tunable](Tunable) index to be autotuned.
166    ///
167    /// Note that if the list is empty, it means no more autotuned entry can be executed.
168    pub(crate) fn next(&mut self, mut context_logs: Option<&mut String>) -> Vec<usize> {
169        let mut indices = core::mem::take(&mut self.no_groups);
170        let priority = self.priorities.last();
171
172        let priority = match priority {
173            Some(val) => *val,
174            None => return indices,
175        };
176
177        let (group_indices, cleanup) = self.group_plan_next(priority);
178        // Some entries are skipped for this round of prioritizing.
179        let skipped = cleanup.skipped || priority < 0;
180
181        self.cleanup(cleanup);
182
183        if priority >= 0 {
184            if let Some(ctx) = context_logs.take() {
185                *ctx += format!("\n - Tuning: {group_indices:?}").as_str();
186                context_logs = Some(ctx);
187            }
188            for (index, _name) in group_indices {
189                if !self.returned.contains(&index) {
190                    indices.push(index);
191                }
192            }
193        }
194
195        // The indices list is empty, but it doesn't mean we should stop
196        // autotuning, since some entries were skipped.
197        if indices.is_empty() && skipped {
198            self.next(context_logs)
199        } else {
200            for i in indices.iter() {
201                self.returned.push(*i);
202            }
203            indices
204        }
205    }
206
207    fn cleanup(&mut self, cleanup: Cleanup) {
208        for group_p in cleanup.groups {
209            let index = self
210                .priorities
211                .iter()
212                .enumerate()
213                .find(|p| *p.1 == group_p)
214                .unwrap();
215
216            self.priorities.remove(index.0);
217            self.groups.remove(&group_p);
218        }
219
220        for (group_p, tunable_p) in cleanup.tunables {
221            if let Some(group) = self.groups.get_mut(&group_p) {
222                let index = group
223                    .priorities
224                    .iter()
225                    .enumerate()
226                    .find(|p| *p.1 == tunable_p)
227                    .unwrap();
228                group.priorities.remove(index.0);
229                group.indices.remove(&tunable_p);
230            }
231        }
232    }
233
234    fn group_plan_next(&mut self, priority: i8) -> (Vec<(usize, Arc<String>)>, Cleanup) {
235        let group_plan = self.groups.get_mut(&priority).expect("To be filled");
236        let within_group_prio = group_plan.priorities.pop().unwrap();
237        let mut next_indices = group_plan.indices.remove(&within_group_prio).unwrap();
238
239        let mut cleanup_groups = Vec::new();
240        let mut cleanup_tunables = Vec::new();
241
242        for (pg, group) in self.groups.iter_mut() {
243            let mut num_empty_tunables = 0;
244            let num_tunables = group.priorities.len();
245
246            for (pt, indices) in group.indices.iter_mut() {
247                for n in &next_indices {
248                    let entry = indices.iter().enumerate().find(|p| *p.1 == *n);
249                    if let Some(entry) = entry {
250                        indices.remove(entry.0);
251                    }
252                }
253
254                if indices.is_empty() {
255                    num_empty_tunables += 1;
256                    cleanup_tunables.push((*pg, *pt));
257                }
258            }
259
260            if num_empty_tunables == num_tunables {
261                cleanup_groups.push(*pg);
262            }
263        }
264
265        if within_group_prio < 0 {
266            // Discard algorithms with negative priority
267            next_indices.clear();
268        }
269
270        (
271            next_indices,
272            Cleanup {
273                groups: cleanup_groups,
274                tunables: cleanup_tunables,
275                skipped: within_group_prio < 0,
276            },
277        )
278    }
279}
280
/// Shared, type-erased closure that turns a key into a signed priority.
type PriorityFunc<K> = Arc<dyn Fn(&K) -> i8>;

/// Global counter from which each new [`TuneGroup`] takes its `id`.
static GROUP_COUNTER: AtomicU32 = AtomicU32::new(0);
284
#[cfg(test)]
mod tests {
    use core::fmt::Display;

    use serde::{Deserialize, Serialize};

    use super::*;

    /// Key without any content: every priority closure in these tests ignores it.
    #[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize, Debug)]
    struct FakeAutotuneKey;

    impl Display for FakeAutotuneKey {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            write!(f, "FakeAutotuneKey")
        }
    }

    impl AutotuneKey for FakeAutotuneKey {}

    type FakeGroup = TuneGroup<FakeAutotuneKey>;
    type FakeTunable = Tunable<FakeAutotuneKey, (), ()>;

    /// Kernel stand-in; the plan never calls it.
    fn fake_kernel() -> Result<(), String> {
        Ok(())
    }

    /// A tunable that belongs to no group.
    fn ungrouped() -> FakeTunable {
        FakeTunable::new("fake", fake_kernel)
    }

    /// A tunable registered in `group` with a constant intra-group priority.
    fn grouped(group: &FakeGroup, priority: i8) -> FakeTunable {
        ungrouped().group(group, move |_| priority)
    }

    /// Check that the plan yields exactly `expected`, batch by batch, then an empty batch.
    fn assert_batches(plan: &mut TunePlan, expected: &[&[usize]]) {
        for batch in expected {
            assert_eq!(plan.next(None), *batch);
        }
        assert!(plan.next(None).is_empty());
    }

    #[test]
    fn test_plan_order() {
        let group0 = FakeGroup::new("group0", |_| 2);
        let group1 = FakeGroup::new("group1", |_| 1);

        let tunables = [
            ungrouped(),
            grouped(&group0, 1),
            grouped(&group0, 2),
            grouped(&group1, 2),
        ];

        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_batches(&mut plan, &[&[0, 2], &[1], &[3]]);
    }

    #[test]
    fn test_plan_order_multi_groups_same_priority() {
        let group0 = FakeGroup::new("group0", |_| 2);
        let group1 = FakeGroup::new("group1", |_| 1);
        let group2 = FakeGroup::new("group2", |_| 1);

        let tunables = [
            ungrouped(),
            grouped(&group0, 1),
            grouped(&group0, 2),
            grouped(&group1, 2),
            grouped(&group2, 2),
        ];

        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_batches(&mut plan, &[&[0, 2], &[1], &[3, 4]]);
    }

    #[test]
    fn test_plan_order_tunable_multiple_groups() {
        let group0 = FakeGroup::new("group0", |_| 1);
        let group1 = FakeGroup::new("group1", |_| 2);

        let tunables = [
            ungrouped(),
            grouped(&group0, 1).group(&group1, |_| 2),
            grouped(&group0, 2),
            grouped(&group1, 3),
        ];

        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_batches(&mut plan, &[&[0, 3], &[1], &[2]]);
    }

    #[test]
    fn test_plan_negative_priority() {
        let group0 = FakeGroup::new("group0", |_| 2);
        let group1 = FakeGroup::new("group1", |_| 1);

        let tunables = [
            ungrouped(),
            grouped(&group0, -1),
            grouped(&group0, 2),
            grouped(&group1, 2),
        ];

        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        // Index 1 has a negative intra-group priority and must never show up.
        assert_batches(&mut plan, &[&[0, 2], &[3]]);
    }

    #[test]
    fn test_plan_no_group() {
        let tunables = [ungrouped(), ungrouped()];

        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_batches(&mut plan, &[&[0, 1]]);
    }
}