use super::{AutotuneKey, IntoTuneFn, TuneFn};
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::sync::atomic::{AtomicU32, Ordering};
use hashbrown::HashMap;
7
8pub struct Tunable<K, Inputs, Output> {
13 pub(crate) function: Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>,
14 groups: Vec<(TuneGroup<K>, PriorityFunc<K>)>,
15}
16
17impl<K, Inputs, Output> Tunable<K, Inputs, Output> {
18 pub fn new<Marker>(function: impl IntoTuneFn<Inputs, Output, Marker>) -> Self {
20 Self {
21 function: Arc::new(function.into_tunable()),
22 groups: Vec::new(),
23 }
24 }
25
26 pub fn group<F: Fn(&K) -> i8 + 'static>(mut self, group: &TuneGroup<K>, priority: F) -> Self {
33 self.groups.push((group.clone(), Arc::new(priority)));
34 self
35 }
36}
37
38pub struct TuneGroup<K> {
47 id: u32,
48 pub(crate) priority: PriorityFunc<K>,
49}
50
51impl<K> core::fmt::Debug for TuneGroup<K> {
52 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
53 f.debug_struct("TuneGroup").field("id", &self.id).finish()
54 }
55}
56
57impl<K> Clone for TuneGroup<K> {
58 fn clone(&self) -> Self {
59 Self {
60 id: self.id,
61 priority: self.priority.clone(),
62 }
63 }
64}
65
66impl<K> TuneGroup<K> {
67 pub fn new<F: Fn(&K) -> i8 + 'static>(f: F) -> Self {
69 let id = GROUP_COUNTER.fetch_add(1, Ordering::Relaxed);
70
71 Self {
72 id,
73 priority: Arc::new(f),
74 }
75 }
76}
77
78#[derive(Debug)]
79pub(crate) struct TunePlan {
81 priorities: Vec<i8>,
82 no_groups: Vec<usize>,
83 groups: HashMap<i8, GroupPlan>,
84}
85
/// Pending tunables of all groups that resolved to the same group priority.
#[derive(Default, Debug)]
struct GroupPlan {
    /// Distinct within-group priorities still pending, sorted ascending once the plan is built.
    priorities: Vec<i8>,
    /// Tunable indices for each within-group priority.
    indices: HashMap<i8, Vec<usize>>,
}
91
/// Bookkeeping produced while extracting a batch, applied to the plan afterwards.
#[derive(Debug)]
struct Cleanup {
    /// Group priorities whose plans have no tunable left and must be dropped entirely.
    groups: Vec<i8>,
    /// `(group priority, within-group priority)` buckets that became empty.
    tunables: Vec<(i8, i8)>,
    /// Whether the extracted batch was discarded because its within-group priority is negative.
    skipped: bool,
}
99
100impl TunePlan {
101 pub fn new<K: AutotuneKey, In, Out>(key: &K, tunables: &[Tunable<K, In, Out>]) -> Self {
102 let mut priorities = Vec::<i8>::new();
103 let mut no_groups = Vec::new();
104 let mut groups = HashMap::<i8, GroupPlan>::new();
105
106 for (index, tunable) in tunables.iter().enumerate() {
107 if tunable.groups.is_empty() {
108 no_groups.push(index);
109 } else {
110 for (group, within_group_priority_fn) in tunable.groups.iter() {
111 let priority_fn = &group.priority;
112 let priority = priority_fn(key);
113 if !priorities.contains(&priority) {
114 priorities.push(priority);
115 }
116
117 let group_priorities = match groups.get_mut(&priority) {
118 Some(val) => val,
119 None => {
120 groups.insert(priority, GroupPlan::default());
121 groups.get_mut(&priority).unwrap()
122 }
123 };
124 let priority = within_group_priority_fn(key);
125
126 if group_priorities.priorities.contains(&priority) {
127 group_priorities
128 .indices
129 .get_mut(&priority)
130 .unwrap()
131 .push(index);
132 } else {
133 group_priorities.priorities.push(priority);
134 group_priorities.indices.insert(priority, vec![index]);
135 }
136 }
137 }
138 }
139
140 priorities.sort();
141
142 for group in groups.iter_mut() {
143 group.1.priorities.sort();
144 }
145
146 Self {
147 priorities,
148 no_groups,
149 groups,
150 }
151 }
152
153 pub(crate) fn next(&mut self) -> Vec<usize> {
157 let mut indices = core::mem::take(&mut self.no_groups);
158 let priority = self.priorities.last();
159
160 let priority = match priority {
161 Some(val) => *val,
162 None => return indices,
163 };
164
165 let (mut group_indices, cleanup) = self.group_plan_next(priority);
166 let skipped = cleanup.skipped || priority < 0;
168
169 self.cleanup(cleanup);
170
171 if priority >= 0 {
172 indices.append(&mut group_indices);
173 }
174
175 if indices.is_empty() && skipped {
178 self.next()
179 } else {
180 indices
181 }
182 }
183
184 fn cleanup(&mut self, cleanup: Cleanup) {
185 for group_p in cleanup.groups {
186 let index = self
187 .priorities
188 .iter()
189 .enumerate()
190 .find(|p| *p.1 == group_p)
191 .unwrap();
192
193 self.priorities.remove(index.0);
194 self.groups.remove(&group_p);
195 }
196
197 for (group_p, tunable_p) in cleanup.tunables {
198 if let Some(group) = self.groups.get_mut(&group_p) {
199 let index = group
200 .priorities
201 .iter()
202 .enumerate()
203 .find(|p| *p.1 == tunable_p)
204 .unwrap();
205 group.priorities.remove(index.0);
206 group.indices.remove(&tunable_p);
207 }
208 }
209 }
210
211 fn group_plan_next(&mut self, priority: i8) -> (Vec<usize>, Cleanup) {
212 let plan = self.groups.get_mut(&priority).expect("To be filled");
213 let within_group_prio = plan.priorities.pop().unwrap();
214 let mut next_indices = plan.indices.remove(&within_group_prio).unwrap();
215
216 let mut cleanup_groups = Vec::new();
217 let mut cleanup_tunables = Vec::new();
218
219 for (pg, group) in self.groups.iter_mut() {
220 let mut num_empty_tunables = 0;
221 let num_tunables = group.priorities.len();
222
223 for (pt, indices) in group.indices.iter_mut() {
224 for n in &next_indices {
225 let entry = indices.iter().enumerate().find(|p| *p.1 == *n);
226 if let Some(entry) = entry {
227 indices.remove(entry.0);
228 }
229 }
230
231 if indices.is_empty() {
232 num_empty_tunables += 1;
233 cleanup_tunables.push((*pg, *pt));
234 }
235 }
236
237 if num_empty_tunables == num_tunables {
238 cleanup_groups.push(*pg);
239 }
240 }
241
242 if within_group_prio < 0 {
243 next_indices.clear();
245 }
246
247 (
248 next_indices,
249 Cleanup {
250 groups: cleanup_groups,
251 tunables: cleanup_tunables,
252 skipped: within_group_prio < 0,
253 },
254 )
255 }
256}
257
/// Shared closure that maps an autotune key to a priority.
type PriorityFunc<K> = Arc<dyn Fn(&K) -> i8>;
259
/// Global counter handing out the `id` of each newly created `TuneGroup`, starting at zero.
static GROUP_COUNTER: AtomicU32 = AtomicU32::new(0);
261
#[cfg(test)]
mod tests {
    use core::fmt::Display;

    use serde::{Deserialize, Serialize};

    use super::*;

    /// Unit key: every priority closure in these tests ignores it.
    #[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize, Debug)]
    struct FakeAutotuneKey;

    impl Display for FakeAutotuneKey {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.write_str("FakeAutotuneKey")
        }
    }

    impl AutotuneKey for FakeAutotuneKey {}

    type FakeGroup = TuneGroup<FakeAutotuneKey>;
    type FakeTunable = Tunable<FakeAutotuneKey, (), ()>;

    /// Kernel stand-in; the plan only cares about indices and priorities.
    fn fake_kernel() -> Result<(), String> {
        Ok(())
    }

    /// A tunable that is not registered in any group.
    fn ungrouped() -> FakeTunable {
        FakeTunable::new(fake_kernel)
    }

    /// Builds the plan for `tunables` with the fake key.
    fn plan_for(tunables: &[FakeTunable]) -> TunePlan {
        TunePlan::new(&FakeAutotuneKey, tunables)
    }

    #[test]
    fn test_plan_order() {
        let group0 = FakeGroup::new(|_| 2);
        let group1 = FakeGroup::new(|_| 1);

        let mut plan = plan_for(&[
            ungrouped(),
            ungrouped().group(&group0, |_| 1),
            ungrouped().group(&group0, |_| 2),
            ungrouped().group(&group1, |_| 2),
        ]);

        assert_eq!(plan.next(), vec![0, 2]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![3]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_order_multi_groups_same_priority() {
        let group0 = FakeGroup::new(|_| 2);
        let group1 = FakeGroup::new(|_| 1);
        let group2 = FakeGroup::new(|_| 1);

        let mut plan = plan_for(&[
            ungrouped(),
            ungrouped().group(&group0, |_| 1),
            ungrouped().group(&group0, |_| 2),
            ungrouped().group(&group1, |_| 2),
            ungrouped().group(&group2, |_| 2),
        ]);

        assert_eq!(plan.next(), vec![0, 2]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![3, 4]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_order_tunable_multiple_groups() {
        let group0 = FakeGroup::new(|_| 1);
        let group1 = FakeGroup::new(|_| 2);

        let mut plan = plan_for(&[
            ungrouped(),
            ungrouped()
                .group(&group0, |_| 1)
                .group(&group1, |_| 2),
            ungrouped().group(&group0, |_| 2),
            ungrouped().group(&group1, |_| 3),
        ]);

        assert_eq!(plan.next(), vec![0, 3]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![2]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_negative_priority() {
        let group0 = FakeGroup::new(|_| 2);
        let group1 = FakeGroup::new(|_| 1);

        let mut plan = plan_for(&[
            ungrouped(),
            ungrouped().group(&group0, |_| -1),
            ungrouped().group(&group0, |_| 2),
            ungrouped().group(&group1, |_| 2),
        ]);

        assert_eq!(plan.next(), vec![0, 2]);
        assert_eq!(plan.next(), vec![3]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_no_group() {
        let mut plan = plan_for(&[ungrouped(), ungrouped()]);

        assert_eq!(plan.next(), vec![0, 1]);
        assert!(plan.next().is_empty());
    }
}