1use super::{AutotuneKey, IntoTuneFn, TuneFn};
2use alloc::format;
3use alloc::string::String;
4use alloc::sync::Arc;
5use alloc::vec;
6use alloc::vec::Vec;
7use core::sync::atomic::{AtomicU32, Ordering};
8use hashbrown::HashMap;
9
10pub struct Tunable<K, Inputs, Output> {
15 pub(crate) function: Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>,
16 groups: Vec<(TuneGroup<K>, PriorityFunc<K>)>,
17}
18
19impl<K, Inputs, Output> Tunable<K, Inputs, Output> {
20 pub fn new<Marker, Name: Into<String>>(
22 name: Name,
23 function: impl IntoTuneFn<Inputs, Output, Marker>,
24 ) -> Self {
25 Self {
26 function: Arc::new(function.into_tunable(name.into())),
27 groups: Vec::new(),
28 }
29 }
30
31 pub fn group<F: Fn(&K) -> i8 + 'static>(mut self, group: &TuneGroup<K>, priority: F) -> Self {
38 self.groups.push((group.clone(), Arc::new(priority)));
39 self
40 }
41}
42
43pub struct TuneGroup<K> {
52 id: u32,
53 name: Arc<String>,
54 pub(crate) priority: PriorityFunc<K>,
55}
56
57impl<K> core::fmt::Debug for TuneGroup<K> {
58 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
59 f.debug_struct("TuneGroup").field("id", &self.id).finish()
60 }
61}
62
63impl<K> Clone for TuneGroup<K> {
64 fn clone(&self) -> Self {
65 Self {
66 id: self.id,
67 name: self.name.clone(),
68 priority: self.priority.clone(),
69 }
70 }
71}
72
73impl<K> TuneGroup<K> {
74 pub fn new<F: Fn(&K) -> i8 + 'static, S: Into<String>>(name: S, f: F) -> Self {
76 let id = GROUP_COUNTER.fetch_add(1, Ordering::Relaxed);
77
78 Self {
79 id,
80 name: Arc::new(name.into()),
81 priority: Arc::new(f),
82 }
83 }
84}
85
86#[derive(Debug)]
87pub(crate) struct TunePlan {
89 priorities: Vec<i8>,
90 no_groups: Vec<usize>,
91 groups: HashMap<i8, GroupPlan>,
92 returned: Vec<usize>,
93}
94
95#[derive(Default, Debug)]
96struct GroupPlan {
97 priorities: Vec<i8>,
98 indices: HashMap<i8, Vec<(usize, Arc<String>)>>,
99}
100
101#[derive(Debug)]
102struct Cleanup {
103 groups: Vec<i8>,
104 tunables: Vec<(i8, i8)>,
105 skipped: bool,
107}
108
109impl TunePlan {
110 pub fn new<K: AutotuneKey, In, Out>(key: &K, tunables: &[Tunable<K, In, Out>]) -> Self {
111 let mut priorities = Vec::<i8>::new();
112 let mut no_groups = Vec::new();
113 let mut groups = HashMap::<i8, GroupPlan>::new();
114
115 for (index, tunable) in tunables.iter().enumerate() {
116 if tunable.groups.is_empty() {
117 no_groups.push(index);
118 } else {
119 for (group, within_group_priority_fn) in tunable.groups.iter() {
120 let priority_fn = &group.priority;
121 let priority = priority_fn(key);
122 if !priorities.contains(&priority) {
123 priorities.push(priority);
124 }
125
126 let group_priorities = match groups.get_mut(&priority) {
127 Some(val) => val,
128 None => {
129 groups.insert(priority, GroupPlan::default());
130 groups.get_mut(&priority).unwrap()
131 }
132 };
133 let priority = within_group_priority_fn(key);
134
135 if group_priorities.priorities.contains(&priority) {
136 group_priorities
137 .indices
138 .get_mut(&priority)
139 .unwrap()
140 .push((index, group.name.clone()));
141 } else {
142 group_priorities.priorities.push(priority);
143 group_priorities
144 .indices
145 .insert(priority, vec![(index, group.name.clone())]);
146 }
147 }
148 }
149 }
150
151 priorities.sort();
152
153 for group in groups.iter_mut() {
154 group.1.priorities.sort();
155 }
156
157 Self {
158 priorities,
159 no_groups,
160 groups,
161 returned: Vec::new(),
162 }
163 }
164
165 pub(crate) fn next(&mut self, mut context_logs: Option<&mut String>) -> Vec<usize> {
169 let mut indices = core::mem::take(&mut self.no_groups);
170 let priority = self.priorities.last();
171
172 let priority = match priority {
173 Some(val) => *val,
174 None => return indices,
175 };
176
177 let (group_indices, cleanup) = self.group_plan_next(priority);
178 let skipped = cleanup.skipped || priority < 0;
180
181 self.cleanup(cleanup);
182
183 if priority >= 0 {
184 if let Some(ctx) = context_logs.take() {
185 *ctx += format!("\n - Tuning: {group_indices:?}").as_str();
186 context_logs = Some(ctx);
187 }
188 for (index, _name) in group_indices {
189 if !self.returned.contains(&index) {
190 indices.push(index);
191 }
192 }
193 }
194
195 if indices.is_empty() && skipped {
198 self.next(context_logs)
199 } else {
200 for i in indices.iter() {
201 self.returned.push(*i);
202 }
203 indices
204 }
205 }
206
207 fn cleanup(&mut self, cleanup: Cleanup) {
208 for group_p in cleanup.groups {
209 let index = self
210 .priorities
211 .iter()
212 .enumerate()
213 .find(|p| *p.1 == group_p)
214 .unwrap();
215
216 self.priorities.remove(index.0);
217 self.groups.remove(&group_p);
218 }
219
220 for (group_p, tunable_p) in cleanup.tunables {
221 if let Some(group) = self.groups.get_mut(&group_p) {
222 let index = group
223 .priorities
224 .iter()
225 .enumerate()
226 .find(|p| *p.1 == tunable_p)
227 .unwrap();
228 group.priorities.remove(index.0);
229 group.indices.remove(&tunable_p);
230 }
231 }
232 }
233
234 fn group_plan_next(&mut self, priority: i8) -> (Vec<(usize, Arc<String>)>, Cleanup) {
235 let group_plan = self.groups.get_mut(&priority).expect("To be filled");
236 let within_group_prio = group_plan.priorities.pop().unwrap();
237 let mut next_indices = group_plan.indices.remove(&within_group_prio).unwrap();
238
239 let mut cleanup_groups = Vec::new();
240 let mut cleanup_tunables = Vec::new();
241
242 for (pg, group) in self.groups.iter_mut() {
243 let mut num_empty_tunables = 0;
244 let num_tunables = group.priorities.len();
245
246 for (pt, indices) in group.indices.iter_mut() {
247 for n in &next_indices {
248 let entry = indices.iter().enumerate().find(|p| *p.1 == *n);
249 if let Some(entry) = entry {
250 indices.remove(entry.0);
251 }
252 }
253
254 if indices.is_empty() {
255 num_empty_tunables += 1;
256 cleanup_tunables.push((*pg, *pt));
257 }
258 }
259
260 if num_empty_tunables == num_tunables {
261 cleanup_groups.push(*pg);
262 }
263 }
264
265 if within_group_prio < 0 {
266 next_indices.clear();
268 }
269
270 (
271 next_indices,
272 Cleanup {
273 groups: cleanup_groups,
274 tunables: cleanup_tunables,
275 skipped: within_group_prio < 0,
276 },
277 )
278 }
279}
280
/// Type-erased function mapping an autotune key to a priority.
///
/// Used both for a group's priority and for a tunable's priority within a group.
type PriorityFunc<K> = Arc<dyn Fn(&K) -> i8>;

/// Counter from which `TuneGroup::new` takes each new group's `id` (one increment per call).
static GROUP_COUNTER: AtomicU32 = AtomicU32::new(0);
284
#[cfg(test)]
mod tests {
    use core::fmt::Display;

    use serde::{Deserialize, Serialize};

    use super::*;

    /// Single-valued key; every priority closure in these tests ignores it.
    #[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize, Debug)]
    struct FakeAutotuneKey;

    impl Display for FakeAutotuneKey {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.write_str("FakeAutotuneKey")
        }
    }

    impl AutotuneKey for FakeAutotuneKey {}

    type FakeTunable = Tunable<FakeAutotuneKey, (), ()>;

    /// Builds a group named `name` whose priority is `priority` for any key.
    fn constant_group(name: &str, priority: i8) -> TuneGroup<FakeAutotuneKey> {
        TuneGroup::<FakeAutotuneKey>::new(name, move |_| priority)
    }

    /// Builds a `fake_kernel` tunable that belongs to no group.
    fn ungrouped() -> FakeTunable {
        FakeTunable::new("fake", fake_kernel)
    }

    /// Builds a `fake_kernel` tunable placed in `group` with a constant within-group priority.
    fn grouped(group: &TuneGroup<FakeAutotuneKey>, priority: i8) -> FakeTunable {
        ungrouped().group(group, move |_| priority)
    }

    #[test]
    fn test_plan_order() {
        let group0 = constant_group("group0", 2);
        let group1 = constant_group("group1", 1);

        let tunables = [
            ungrouped(),
            grouped(&group0, 1),
            grouped(&group0, 2),
            grouped(&group1, 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(None), vec![0, 2]);
        assert_eq!(plan.next(None), vec![1]);
        assert_eq!(plan.next(None), vec![3]);
        assert!(plan.next(None).is_empty());
    }

    #[test]
    fn test_plan_order_multi_groups_same_priority() {
        let group0 = constant_group("group0", 2);
        let group1 = constant_group("group1", 1);
        let group2 = constant_group("group2", 1);

        let tunables = [
            ungrouped(),
            grouped(&group0, 1),
            grouped(&group0, 2),
            grouped(&group1, 2),
            grouped(&group2, 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(None), vec![0, 2]);
        assert_eq!(plan.next(None), vec![1]);
        assert_eq!(plan.next(None), vec![3, 4]);
        assert!(plan.next(None).is_empty());
    }

    #[test]
    fn test_plan_order_tunable_multiple_groups() {
        let group0 = constant_group("group0", 1);
        let group1 = constant_group("group1", 2);

        let tunables = [
            ungrouped(),
            grouped(&group0, 1).group(&group1, |_| 2),
            grouped(&group0, 2),
            grouped(&group1, 3),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(None), vec![0, 3]);
        assert_eq!(plan.next(None), vec![1]);
        assert_eq!(plan.next(None), vec![2]);
        assert!(plan.next(None).is_empty());
    }

    #[test]
    fn test_plan_negative_priority() {
        let group0 = constant_group("group0", 2);
        let group1 = constant_group("group1", 1);

        let tunables = [
            ungrouped(),
            grouped(&group0, -1),
            grouped(&group0, 2),
            grouped(&group1, 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(None), vec![0, 2]);
        assert_eq!(plan.next(None), vec![3]);
        assert!(plan.next(None).is_empty());
    }

    #[test]
    fn test_plan_no_group() {
        let tunables = [ungrouped(), ungrouped()];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(None), vec![0, 1]);
        assert!(plan.next(None).is_empty());
    }

    /// Kernel stand-in that always succeeds.
    fn fake_kernel() -> Result<(), String> {
        Ok(())
    }
}