1use super::{AutotuneKey, IntoTuneFn, TuneFn};
2use alloc::sync::Arc;
3use alloc::vec;
4use alloc::vec::Vec;
5use core::sync::atomic::{AtomicU32, Ordering};
6use hashbrown::HashMap;
7
/// A tune function together with the [groups](TuneGroup) it has been registered in.
///
/// A tunable starts without any group (see [`Tunable::new`]); groups are attached one at a
/// time through [`Tunable::group`].
pub struct Tunable<K, Inputs, Output> {
    /// The type-erased function to autotune, shared behind an [`Arc`].
    pub(crate) function: Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>,
    /// Every group this tunable belongs to, paired with the function computing the tunable's
    /// priority inside that group.
    groups: Vec<(TuneGroup<K>, PriorityFunc<K>)>,
}
16
17impl<K, Inputs, Output> Tunable<K, Inputs, Output> {
18 pub fn new<Marker>(function: impl IntoTuneFn<Inputs, Output, Marker>) -> Self {
20 Self {
21 function: Arc::new(function.into_tunable()),
22 groups: Vec::new(),
23 }
24 }
25
26 pub fn group<F: Fn(&K) -> i8 + 'static>(mut self, group: &TuneGroup<K>, priority: F) -> Self {
33 self.groups.push((group.clone(), Arc::new(priority)));
34 self
35 }
36}
37
/// A group of tunables whose priority is computed from the autotune key.
///
/// Cloning a group keeps its `id` and shares the same priority function.
pub struct TuneGroup<K> {
    /// Identifier drawn from the global group counter when the group is created.
    id: u32,
    /// Function computing the priority of the whole group for a given key.
    pub(crate) priority: PriorityFunc<K>,
}
50
51impl<K> Clone for TuneGroup<K> {
52 fn clone(&self) -> Self {
53 Self {
54 id: self.id,
55 priority: self.priority.clone(),
56 }
57 }
58}
59
60impl<K> TuneGroup<K> {
61 pub fn new<F: Fn(&K) -> i8 + 'static>(f: F) -> Self {
63 let id = GROUP_COUNTER.fetch_add(1, Ordering::Relaxed);
64
65 Self {
66 id,
67 priority: Arc::new(f),
68 }
69 }
70}
71
/// The order in which tunables, identified by their index, are handed out for autotuning.
///
/// Built by `TunePlan::new` for one specific key and consumed batch by batch through
/// `TunePlan::next`.
#[derive(Debug)]
pub(crate) struct TunePlan {
    /// Distinct group priorities still to be processed, sorted in ascending order so that the
    /// highest one is the last element.
    priorities: Vec<i8>,
    /// Indices of the tunables that belong to no group; they are drained by the first call to
    /// `next`.
    no_groups: Vec<usize>,
    /// Remaining work, keyed by group priority. Groups yielding the same priority share an
    /// entry.
    groups: HashMap<i8, GroupPlan>,
}
79
/// The remaining tunables of all groups sharing one group priority.
#[derive(Default, Debug)]
struct GroupPlan {
    /// Distinct within-group priorities still to be processed, sorted in ascending order once
    /// the plan is built, so that the highest one is popped first.
    priorities: Vec<i8>,
    /// Tunable indices for each within-group priority listed in `priorities`.
    indices: HashMap<i8, Vec<usize>>,
}
85
/// Entries of a [`TunePlan`] that became empty while a batch was extracted and must be removed
/// from the plan afterwards.
struct Cleanup {
    /// Group priorities whose whole [`GroupPlan`] has nothing left.
    groups: Vec<i8>,
    /// `(group priority, within-group priority)` pairs whose index list became empty.
    tunables: Vec<(i8, i8)>,
}
90
91impl TunePlan {
92 pub fn new<K: AutotuneKey, In, Out>(key: &K, tunables: &[Tunable<K, In, Out>]) -> Self {
93 let mut priorities = Vec::<i8>::new();
94 let mut no_groups = Vec::new();
95 let mut groups = HashMap::<i8, GroupPlan>::new();
96
97 for (index, tunable) in tunables.iter().enumerate() {
98 if tunable.groups.is_empty() {
99 no_groups.push(index);
100 } else {
101 for (group, within_group_priority_fn) in tunable.groups.iter() {
102 let priority_fn = &group.priority;
103 let priority = priority_fn(key);
104 if !priorities.contains(&priority) {
105 priorities.push(priority);
106 }
107
108 let group_priorities = match groups.get_mut(&priority) {
109 Some(val) => val,
110 None => {
111 groups.insert(priority, GroupPlan::default());
112 groups.get_mut(&priority).unwrap()
113 }
114 };
115 let priority = within_group_priority_fn(key);
116
117 if group_priorities.priorities.contains(&priority) {
118 group_priorities
119 .indices
120 .get_mut(&priority)
121 .unwrap()
122 .push(index);
123 } else {
124 group_priorities.priorities.push(priority);
125 group_priorities.indices.insert(priority, vec![index]);
126 }
127 }
128 }
129 }
130
131 priorities.sort();
132
133 for group in groups.iter_mut() {
134 group.1.priorities.sort();
135 }
136
137 Self {
138 priorities,
139 no_groups,
140 groups,
141 }
142 }
143
144 pub(crate) fn next(&mut self) -> Vec<usize> {
148 let mut indices = core::mem::take(&mut self.no_groups);
149 let priority = self.priorities.last();
150
151 let priority = match priority {
152 Some(val) => *val,
153 None => return indices,
154 };
155
156 let (mut group_indices, cleanup) = self.group_plan_next(priority);
157 self.cleanup(cleanup);
158
159 if priority >= 0 {
160 indices.append(&mut group_indices);
161 }
162
163 indices
164 }
165
166 fn cleanup(&mut self, cleanup: Cleanup) {
167 for group_p in cleanup.groups {
168 let index = self
169 .priorities
170 .iter()
171 .enumerate()
172 .find(|p| *p.1 == group_p)
173 .unwrap();
174
175 self.priorities.remove(index.0);
176 self.groups.remove(&group_p);
177 }
178
179 for (group_p, tunable_p) in cleanup.tunables {
180 if let Some(group) = self.groups.get_mut(&group_p) {
181 let index = group
182 .priorities
183 .iter()
184 .enumerate()
185 .find(|p| *p.1 == tunable_p)
186 .unwrap();
187 group.priorities.remove(index.0);
188 group.indices.remove(&tunable_p);
189 }
190 }
191 }
192
193 fn group_plan_next(&mut self, priority: i8) -> (Vec<usize>, Cleanup) {
194 let plan = self.groups.get_mut(&priority).expect("To be filled");
195 let within_group_prio = plan.priorities.pop().unwrap();
196 let mut next_indices = plan.indices.remove(&within_group_prio).unwrap();
197
198 let mut cleanup_groups = Vec::new();
199 let mut cleanup_tunables = Vec::new();
200
201 for (pg, group) in self.groups.iter_mut() {
202 let mut num_empty_tunables = 0;
203 let num_tunables = group.priorities.len();
204
205 for (pt, indices) in group.indices.iter_mut() {
206 for n in &next_indices {
207 let entry = indices.iter().enumerate().find(|p| *p.1 == *n);
208 if let Some(entry) = entry {
209 indices.remove(entry.0);
210 }
211 }
212
213 if indices.is_empty() {
214 num_empty_tunables += 1;
215 cleanup_tunables.push((*pg, *pt));
216 }
217 }
218
219 if num_empty_tunables == num_tunables {
220 cleanup_groups.push(*pg);
221 }
222 }
223
224 if within_group_prio < 0 {
225 next_indices.clear();
227 }
228
229 (
230 next_indices,
231 Cleanup {
232 groups: cleanup_groups,
233 tunables: cleanup_tunables,
234 },
235 )
236 }
237}
238
/// Shared, type-erased function computing an `i8` priority from an autotune key.
type PriorityFunc<K> = Arc<dyn Fn(&K) -> i8>;
240
/// Global counter from which `TuneGroup::new` draws the id of each newly created group.
static GROUP_COUNTER: AtomicU32 = AtomicU32::new(0);
242
#[cfg(test)]
mod tests {
    use core::fmt::Display;

    use serde::{Deserialize, Serialize};

    use super::*;

    /// Key without any content: every priority function used below ignores it.
    #[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize, Debug)]
    struct FakeAutotuneKey;

    impl Display for FakeAutotuneKey {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.write_str("FakeAutotuneKey")
        }
    }

    impl AutotuneKey for FakeAutotuneKey {}

    type FakeGroup = TuneGroup<FakeAutotuneKey>;
    type FakeTunable = Tunable<FakeAutotuneKey, (), ()>;

    /// Kernel that does nothing; only its position in the tunable list matters.
    fn fake_kernel() -> Result<(), String> {
        Ok(())
    }

    /// Group with a constant priority.
    fn group(priority: i8) -> FakeGroup {
        FakeGroup::new(move |_| priority)
    }

    /// Tunable that belongs to no group.
    fn ungrouped() -> FakeTunable {
        FakeTunable::new(fake_kernel)
    }

    /// Tunable registered in `group` with a constant within-group priority.
    fn grouped(group: &FakeGroup, priority: i8) -> FakeTunable {
        ungrouped().group(group, move |_| priority)
    }

    #[test]
    fn test_plan_order() {
        let group0 = group(2);
        let group1 = group(1);

        let tunables = [
            ungrouped(),
            grouped(&group0, 1),
            grouped(&group0, 2),
            grouped(&group1, 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(), vec![0, 2]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![3]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_order_multi_groups_same_priority() {
        let group0 = group(2);
        let group1 = group(1);
        let group2 = group(1);

        let tunables = [
            ungrouped(),
            grouped(&group0, 1),
            grouped(&group0, 2),
            grouped(&group1, 2),
            grouped(&group2, 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(), vec![0, 2]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![3, 4]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_order_tunable_multiple_groups() {
        let group0 = group(1);
        let group1 = group(2);

        let tunables = [
            ungrouped(),
            grouped(&group0, 1).group(&group1, |_| 2),
            grouped(&group0, 2),
            grouped(&group1, 3),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(), vec![0, 3]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![2]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_no_group() {
        let tunables = [ungrouped(), ungrouped()];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(), vec![0, 1]);
        assert!(plan.next().is_empty());
    }
}