1use super::{AutotuneError, AutotuneKey, TuneFn, TuneInputs};
2use alloc::boxed::Box;
3use alloc::string::ToString;
4use alloc::{format, string::String, sync::Arc, vec, vec::Vec};
5use core::sync::atomic::{AtomicU32, Ordering};
6use hashbrown::HashMap;
7
/// A candidate function for autotuning, together with the groups it belongs to.
///
/// Built with [`Tunable::new`] and optionally registered in one or more
/// [`TuneGroup`]s through [`Tunable::group`].
pub struct Tunable<K, F: TuneInputs, Output> {
    /// The wrapped function, stored under the name given to [`Tunable::new`].
    pub(crate) function: TuneFn<F, Output>,
    /// Every group this tunable was registered in, each paired with the
    /// function that computes the tunable's priority inside that group.
    groups: Vec<(TuneGroup<K>, PriorityFunc<K>)>,
}
14
15impl<K, F: TuneInputs, Output: 'static> Tunable<K, F, Output> {
16 pub fn new<Func, Err>(name: &str, func: Func) -> Self
27 where
28 Err: Into<String> + 'static,
29 Func: for<'a> Fn(<F as TuneInputs>::At<'a>) -> Result<Output, Err> + Send + Sync + 'static,
30 {
31 let name: String = name.into();
32 let name_for_err = name.clone();
33 Self {
34 function: TuneFn::new(
35 name,
36 Box::new(move |inputs| {
37 func(inputs).map_err(|err| AutotuneError::Unknown {
38 name: name_for_err.to_string(),
39 err: err.into(),
40 })
41 }),
42 ),
43 groups: Vec::new(),
44 }
45 }
46
47 pub fn group(
53 mut self,
54 group: &TuneGroup<K>,
55 priority: impl Fn(&K) -> i8 + Send + Sync + 'static,
56 ) -> Self {
57 self.groups.push((group.clone(), Arc::new(priority)));
58 self
59 }
60}
61
/// A named group of tunables that share a key-dependent priority.
///
/// Cloning a group keeps its `id` and shares its name and priority function.
pub struct TuneGroup<K> {
    /// Identifier drawn from `GROUP_COUNTER` when the group is created.
    id: u32,
    /// Name of the group; the plan stores it next to each tunable index and it
    /// shows up in the tuning context logs.
    name: Arc<String>,
    /// Computes the priority of the whole group for a given key.
    pub(crate) priority: PriorityFunc<K>,
}
71
72impl<K> core::fmt::Debug for TuneGroup<K> {
73 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
74 f.debug_struct("TuneGroup").field("id", &self.id).finish()
75 }
76}
77
78impl<K> Clone for TuneGroup<K> {
79 fn clone(&self) -> Self {
80 Self {
81 id: self.id,
82 name: self.name.clone(),
83 priority: self.priority.clone(),
84 }
85 }
86}
87
88impl<K> TuneGroup<K> {
89 pub fn new(name: &str, f: impl Fn(&K) -> i8 + Send + Sync + 'static) -> Self {
91 let id = GROUP_COUNTER.fetch_add(1, Ordering::Relaxed);
92
93 Self {
94 id,
95 name: Arc::new(name.into()),
96 priority: Arc::new(f),
97 }
98 }
99}
100
/// Decides in which order tunables are handed out for a given key.
///
/// Built by `TunePlan::new` and consumed batch by batch through `TunePlan::next`.
#[derive(Debug)]
pub(crate) struct TunePlan {
    /// Distinct group priorities still to be processed, sorted in ascending
    /// order; the last (highest) one is processed first.
    priorities: Vec<i8>,
    /// Indices of the tunables that belong to no group; they are taken out by
    /// the first call to `next`.
    no_groups: Vec<usize>,
    /// Remaining work per group priority.
    groups: HashMap<i8, GroupPlan>,
    /// Indices already handed out by `next`, used to avoid returning a tunable twice.
    returned: Vec<usize>,
}
109
/// Remaining tunables of all groups that share one group priority.
#[derive(Default, Debug)]
struct GroupPlan {
    /// Distinct within-group priorities still to be processed, sorted in
    /// ascending order; the last (highest) one is popped first.
    priorities: Vec<i8>,
    /// For each within-group priority, the tunable indices registered with it,
    /// each paired with the name of the group it was registered through.
    indices: HashMap<i8, Vec<(usize, Arc<String>)>>,
}
115
/// Bookkeeping produced by `TunePlan::group_plan_next` and applied by `TunePlan::cleanup`.
#[derive(Debug)]
struct Cleanup {
    /// Group priorities whose plan has nothing left and must be removed entirely.
    groups: Vec<i8>,
    /// `(group priority, within-group priority)` pairs whose index list became empty.
    tunables: Vec<(i8, i8)>,
    /// `true` when the within-group priority that was just popped is negative,
    /// i.e. its tunables were discarded instead of being returned.
    skipped: bool,
}
123
124impl TunePlan {
125 pub fn new<K: AutotuneKey, F: TuneInputs, Out>(
126 key: &K,
127 tunables: &[Tunable<K, F, Out>],
128 ) -> Self {
129 let mut priorities = Vec::<i8>::new();
130 let mut no_groups = Vec::new();
131 let mut groups = HashMap::<i8, GroupPlan>::new();
132
133 for (index, tunable) in tunables.iter().enumerate() {
134 if tunable.groups.is_empty() {
135 no_groups.push(index);
136 } else {
137 for (group, within_group_priority_fn) in tunable.groups.iter() {
138 let priority_fn = &group.priority;
139 let priority = priority_fn(key);
140 if !priorities.contains(&priority) {
141 priorities.push(priority);
142 }
143
144 let group_priorities = match groups.get_mut(&priority) {
145 Some(val) => val,
146 None => {
147 groups.insert(priority, GroupPlan::default());
148 groups.get_mut(&priority).unwrap()
149 }
150 };
151 let priority = within_group_priority_fn(key);
152
153 if group_priorities.priorities.contains(&priority) {
154 group_priorities
155 .indices
156 .get_mut(&priority)
157 .unwrap()
158 .push((index, group.name.clone()));
159 } else {
160 group_priorities.priorities.push(priority);
161 group_priorities
162 .indices
163 .insert(priority, vec![(index, group.name.clone())]);
164 }
165 }
166 }
167 }
168
169 priorities.sort();
170
171 for group in groups.iter_mut() {
172 group.1.priorities.sort();
173 }
174
175 Self {
176 priorities,
177 no_groups,
178 groups,
179 returned: Vec::new(),
180 }
181 }
182
183 pub(crate) fn next(&mut self, mut context_logs: Option<&mut String>) -> Vec<usize> {
187 let mut indices = core::mem::take(&mut self.no_groups);
188 let priority = self.priorities.last();
189
190 let priority = match priority {
191 Some(val) => *val,
192 None => return indices,
193 };
194
195 let (group_indices, cleanup) = self.group_plan_next(priority);
196 let skipped = cleanup.skipped || priority < 0;
198 let mut all_skip = true;
199
200 self.cleanup(cleanup);
201
202 if priority >= 0 {
203 if let Some(ctx) = context_logs.take() {
204 *ctx += format!("\n - Tuning: {group_indices:?}").as_str();
205 context_logs = Some(ctx);
206 }
207 for (index, _name) in group_indices {
208 if !self.returned.contains(&index) {
209 all_skip = false;
210 indices.push(index);
211 }
212 }
213 }
214
215 if indices.is_empty() && (skipped || all_skip) {
219 self.next(context_logs)
220 } else {
221 for i in indices.iter() {
222 self.returned.push(*i);
223 }
224 indices
225 }
226 }
227
228 fn cleanup(&mut self, cleanup: Cleanup) {
229 for group_p in cleanup.groups {
230 let index = self
231 .priorities
232 .iter()
233 .enumerate()
234 .find(|p| *p.1 == group_p)
235 .unwrap();
236
237 self.priorities.remove(index.0);
238 self.groups.remove(&group_p);
239 }
240
241 for (group_p, tunable_p) in cleanup.tunables {
242 if let Some(group) = self.groups.get_mut(&group_p) {
243 let index = group
244 .priorities
245 .iter()
246 .enumerate()
247 .find(|p| *p.1 == tunable_p)
248 .unwrap();
249 group.priorities.remove(index.0);
250 group.indices.remove(&tunable_p);
251 }
252 }
253 }
254
255 fn group_plan_next(&mut self, priority: i8) -> (Vec<(usize, Arc<String>)>, Cleanup) {
256 let group_plan = self.groups.get_mut(&priority).expect("To be filled");
257 let within_group_prio = group_plan.priorities.pop().unwrap();
258 let mut next_indices = group_plan.indices.remove(&within_group_prio).unwrap();
259
260 let mut cleanup_groups = Vec::new();
261 let mut cleanup_tunables = Vec::new();
262
263 for (pg, group) in self.groups.iter_mut() {
264 let mut num_empty_tunables = 0;
265 let num_tunables = group.priorities.len();
266
267 for (pt, indices) in group.indices.iter_mut() {
268 for n in &next_indices {
269 let entry = indices.iter().enumerate().find(|p| *p.1 == *n);
270 if let Some(entry) = entry {
271 indices.remove(entry.0);
272 }
273 }
274
275 if indices.is_empty() {
276 num_empty_tunables += 1;
277 cleanup_tunables.push((*pg, *pt));
278 }
279 }
280
281 if num_empty_tunables == num_tunables {
282 cleanup_groups.push(*pg);
283 }
284 }
285
286 if within_group_prio < 0 {
287 next_indices.clear();
289 }
290
291 (
292 next_indices,
293 Cleanup {
294 groups: cleanup_groups,
295 tunables: cleanup_tunables,
296 skipped: within_group_prio < 0,
297 },
298 )
299 }
300}
301
/// Shared, thread-safe closure that computes an `i8` priority from a key.
type PriorityFunc<K> = Arc<dyn Fn(&K) -> i8 + Send + Sync>;
303
/// Global counter that provides the id of each group created by `TuneGroup::new`.
static GROUP_COUNTER: AtomicU32 = AtomicU32::new(0);
305
#[cfg(test)]
mod tests {
    use core::fmt::Display;

    use serde::{Deserialize, Serialize};

    use super::*;

    /// Key type used by every test; the priorities below never look at it.
    #[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize, Debug)]
    struct FakeAutotuneKey;

    impl Display for FakeAutotuneKey {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.write_str("FakeAutotuneKey")
        }
    }

    impl AutotuneKey for FakeAutotuneKey {}

    type FakeGroup = TuneGroup<FakeAutotuneKey>;
    type FakeTunable = Tunable<FakeAutotuneKey, (), ()>;

    /// Kernel stand-in that always succeeds.
    fn fake_kernel(_: ()) -> Result<(), String> {
        Ok(())
    }

    /// A tunable that belongs to no group yet.
    fn tunable() -> FakeTunable {
        Tunable::new("fake", fake_kernel)
    }

    /// Builds the plan for `tunables` using the fake key.
    fn plan_for(tunables: &[FakeTunable]) -> TunePlan {
        TunePlan::new(&FakeAutotuneKey, tunables)
    }

    /// Concatenates every batch of `plan` until an empty one is returned.
    fn drain(plan: &mut TunePlan) -> Vec<usize> {
        core::iter::from_fn(|| {
            let batch = plan.next(None);
            (!batch.is_empty()).then_some(batch)
        })
        .flatten()
        .collect()
    }

    #[test_log::test]
    fn test_plan_order() {
        let group0 = FakeGroup::new("group0", |_| 2);
        let group1 = FakeGroup::new("group1", |_| 1);

        let mut plan = plan_for(&[
            tunable(),
            tunable().group(&group0, |_| 1),
            tunable().group(&group0, |_| 2),
            tunable().group(&group1, |_| 2),
        ]);

        assert_eq!(plan.next(None), vec![0, 2]);
        assert_eq!(plan.next(None), vec![1]);
        assert_eq!(plan.next(None), vec![3]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_order_multi_groups_same_priority() {
        let group0 = FakeGroup::new("group0", |_| 2);
        let group1 = FakeGroup::new("group1", |_| 1);
        let group2 = FakeGroup::new("group2", |_| 1);

        let mut plan = plan_for(&[
            tunable(),
            tunable().group(&group0, |_| 1),
            tunable().group(&group0, |_| 2),
            tunable().group(&group1, |_| 2),
            tunable().group(&group2, |_| 2),
        ]);

        assert_eq!(plan.next(None), vec![0, 2]);
        assert_eq!(plan.next(None), vec![1]);
        // Both groups have priority 1, so their tunables come out together.
        assert_eq!(plan.next(None), vec![3, 4]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_order_tunable_multiple_groups() {
        let group0 = FakeGroup::new("group0", |_| 1);
        let group1 = FakeGroup::new("group1", |_| 2);

        let mut plan = plan_for(&[
            tunable(),
            tunable().group(&group0, |_| 1).group(&group1, |_| 2),
            tunable().group(&group0, |_| 2),
            tunable().group(&group1, |_| 3),
        ]);

        assert_eq!(plan.next(None), vec![0, 3]);
        assert_eq!(plan.next(None), vec![1]);
        assert_eq!(plan.next(None), vec![2]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_negative_priority() {
        let group0 = FakeGroup::new("group0", |_| 2);
        let group1 = FakeGroup::new("group1", |_| 1);

        let mut plan = plan_for(&[
            tunable(),
            tunable().group(&group0, |_| -1),
            tunable().group(&group0, |_| 2),
            tunable().group(&group1, |_| 2),
        ]);

        assert_eq!(plan.next(None), vec![0, 2]);
        // Tunable 1 has a negative priority and is never returned.
        assert_eq!(plan.next(None), vec![3]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_no_group() {
        let mut plan = plan_for(&[tunable(), tunable()]);

        assert_eq!(plan.next(None), vec![0, 1]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_falls_through_when_all_group_tunables_fail() {
        let group0 = FakeGroup::new("group0", |_| 2);
        let group1 = FakeGroup::new("group1", |_| 1);

        let mut plan = plan_for(&[
            tunable().group(&group0, |_| 1),
            tunable().group(&group0, |_| 2),
            tunable().group(&group1, |_| 1),
            tunable().group(&group1, |_| 2),
        ]);

        assert_eq!(drain(&mut plan), vec![1, 0, 3, 2]);
    }

    #[test_log::test]
    fn test_plan_single_group_exhausts_all_intra_priorities() {
        let group0 = FakeGroup::new("group0", |_| 0);

        let mut plan = plan_for(&[
            tunable().group(&group0, |_| 1),
            tunable().group(&group0, |_| 2),
            tunable().group(&group0, |_| 3),
        ]);

        assert_eq!(plan.next(None), vec![2]);
        assert_eq!(plan.next(None), vec![1]);
        assert_eq!(plan.next(None), vec![0]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_all_negative_group_advances_to_next_group() {
        let group0 = FakeGroup::new("group0", |_| 2);
        let group1 = FakeGroup::new("group1", |_| 1);

        let mut plan = plan_for(&[
            tunable().group(&group0, |_| -1),
            tunable().group(&group0, |_| -2),
            tunable().group(&group1, |_| 1),
        ]);

        assert_eq!(plan.next(None), vec![2]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_no_group_tunables_only_emitted_once_even_on_failures() {
        let group0 = FakeGroup::new("group0", |_| 2);
        let group1 = FakeGroup::new("group1", |_| 1);

        let mut plan = plan_for(&[
            tunable(),
            tunable().group(&group0, |_| 1),
            tunable().group(&group1, |_| 1),
        ]);

        assert_eq!(plan.next(None), vec![0, 1]);
        assert_eq!(plan.next(None), vec![2]);
        assert!(plan.next(None).is_empty());
    }

    #[test_log::test]
    fn test_plan_multi_group_tunable_not_duplicated_across_failed_groups() {
        let group0 = FakeGroup::new("group0", |_| 1);
        let group1 = FakeGroup::new("group1", |_| 2);

        let mut plan = plan_for(&[
            tunable().group(&group0, |_| 1).group(&group1, |_| 1),
            tunable().group(&group0, |_| 2),
        ]);

        assert_eq!(drain(&mut plan), vec![0, 1]);
    }

    #[test_log::test]
    fn test_plan_recurses_when_batch_is_fully_already_returned() {
        let group_hi = FakeGroup::new("hi", |_| 2);
        let group_lo = FakeGroup::new("lo", |_| 1);

        let mut plan = plan_for(&[
            tunable().group(&group_hi, |_| 1).group(&group_lo, |_| 2),
            tunable().group(&group_lo, |_| 1),
        ]);

        assert_eq!(plan.next(None), vec![0]);
        // Tunable 0 comes up again through `lo`; that step is skipped.
        assert_eq!(plan.next(None), vec![1]);
        assert!(plan.next(None).is_empty());
    }
}