1use super::{AutotuneKey, IntoTuneFn, TuneFn};
2use alloc::sync::Arc;
3use alloc::vec;
4use alloc::vec::Vec;
5use core::sync::atomic::{AtomicU32, Ordering};
6use hashbrown::HashMap;
7
8pub struct Tunable<K, Inputs, Output> {
13 pub(crate) function: Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>,
14 groups: Vec<(TuneGroup<K>, PriorityFunc<K>)>,
15}
16
17impl<K, Inputs, Output> Tunable<K, Inputs, Output> {
18 pub fn new<Marker>(function: impl IntoTuneFn<Inputs, Output, Marker>) -> Self {
20 Self {
21 function: Arc::new(function.into_tunable()),
22 groups: Vec::new(),
23 }
24 }
25
26 pub fn group<F: Fn(&K) -> u8 + 'static>(mut self, group: &TuneGroup<K>, priority: F) -> Self {
28 self.groups.push((group.clone(), Arc::new(priority)));
29 self
30 }
31}
32
33pub struct TuneGroup<K> {
42 id: u32,
43 pub(crate) priority: PriorityFunc<K>,
44}
45
46impl<K> Clone for TuneGroup<K> {
47 fn clone(&self) -> Self {
48 Self {
49 id: self.id,
50 priority: self.priority.clone(),
51 }
52 }
53}
54
55impl<K> TuneGroup<K> {
56 pub fn new<F: Fn(&K) -> u8 + 'static>(f: F) -> Self {
58 let id = GROUP_COUNTER.fetch_add(1, Ordering::Relaxed);
59
60 Self {
61 id,
62 priority: Arc::new(f),
63 }
64 }
65}
66
67#[derive(Debug)]
68pub(crate) struct TunePlan {
70 priorities: Vec<u8>,
71 no_groups: Vec<usize>,
72 groups: HashMap<u8, GroupPlan>,
73}
74
/// Pending tunables of every group that evaluated to one group priority.
#[derive(Debug, Default)]
struct GroupPlan {
    /// Within-group priorities still pending, kept sorted ascending.
    priorities: Vec<u8>,
    /// Tunable indices keyed by their within-group priority.
    indices: HashMap<u8, Vec<usize>>,
}
80
/// Entries found empty during one `TunePlan` step, to be removed afterwards.
struct Cleanup {
    /// Group priorities whose buckets no longer hold any tunable.
    groups: Vec<u8>,
    /// `(group priority, within-group priority)` pairs whose index list is empty.
    tunables: Vec<(u8, u8)>,
}
85
86impl TunePlan {
87 pub fn new<K: AutotuneKey, In, Out>(key: &K, tunables: &[Tunable<K, In, Out>]) -> Self {
88 let mut priorities = Vec::<u8>::new();
89 let mut no_groups = Vec::new();
90 let mut groups = HashMap::<u8, GroupPlan>::new();
91
92 for (index, tunable) in tunables.iter().enumerate() {
93 if tunable.groups.is_empty() {
94 no_groups.push(index);
95 } else {
96 for (group, within_group_priority_fn) in tunable.groups.iter() {
97 let priority_fn = &group.priority;
98 let priority = priority_fn(key);
99 if !priorities.contains(&priority) {
100 priorities.push(priority);
101 }
102
103 let group_priorities = match groups.get_mut(&priority) {
104 Some(val) => val,
105 None => {
106 groups.insert(priority, GroupPlan::default());
107 groups.get_mut(&priority).unwrap()
108 }
109 };
110 let priority = within_group_priority_fn(key);
111
112 if group_priorities.priorities.contains(&priority) {
113 group_priorities
114 .indices
115 .get_mut(&priority)
116 .unwrap()
117 .push(index);
118 } else {
119 group_priorities.priorities.push(priority);
120 group_priorities.indices.insert(priority, vec![index]);
121 }
122 }
123 }
124 }
125
126 priorities.sort();
127
128 for group in groups.iter_mut() {
129 group.1.priorities.sort();
130 }
131
132 Self {
133 priorities,
134 no_groups,
135 groups,
136 }
137 }
138
139 pub(crate) fn next(&mut self) -> Vec<usize> {
143 let mut indices = core::mem::take(&mut self.no_groups);
144 let priority = self.priorities.last();
145
146 let priority = match priority {
147 Some(val) => val,
148 None => return indices,
149 };
150
151 let (mut group_indices, cleanup) = self.group_plan_next(*priority);
152 self.cleanup(cleanup);
153 indices.append(&mut group_indices);
154
155 indices
156 }
157
158 fn cleanup(&mut self, cleanup: Cleanup) {
159 for group_p in cleanup.groups {
160 let index = self
161 .priorities
162 .iter()
163 .enumerate()
164 .find(|p| *p.1 == group_p)
165 .unwrap();
166
167 self.priorities.remove(index.0);
168 self.groups.remove(&group_p);
169 }
170
171 for (group_p, tunable_p) in cleanup.tunables {
172 if let Some(group) = self.groups.get_mut(&group_p) {
173 let index = group
174 .priorities
175 .iter()
176 .enumerate()
177 .find(|p| *p.1 == tunable_p)
178 .unwrap();
179 group.priorities.remove(index.0);
180 group.indices.remove(&tunable_p);
181 }
182 }
183 }
184
185 fn group_plan_next(&mut self, priority: u8) -> (Vec<usize>, Cleanup) {
186 let plan = self.groups.get_mut(&priority).expect("To be filled");
187 let within_group_prio = plan.priorities.pop().unwrap();
188 let next_indices = plan.indices.remove(&within_group_prio).unwrap();
189
190 let mut cleanup_groups = Vec::new();
191 let mut cleanup_tunables = Vec::new();
192
193 for (pg, group) in self.groups.iter_mut() {
194 let mut num_empty_tunables = 0;
195 let num_tunables = group.priorities.len();
196
197 for (pt, indices) in group.indices.iter_mut() {
198 for n in &next_indices {
199 let entry = indices.iter().enumerate().find(|p| *p.1 == *n);
200 if let Some(entry) = entry {
201 indices.remove(entry.0);
202 }
203 }
204
205 if indices.is_empty() {
206 num_empty_tunables += 1;
207 cleanup_tunables.push((*pg, *pt));
208 }
209 }
210
211 if num_empty_tunables == num_tunables {
212 cleanup_groups.push(*pg);
213 }
214 }
215
216 (
217 next_indices,
218 Cleanup {
219 groups: cleanup_groups,
220 tunables: cleanup_tunables,
221 },
222 )
223 }
224}
225
/// Shared closure mapping an autotune key to a priority (higher runs earlier).
type PriorityFunc<K> = Arc<dyn Fn(&K) -> u8 + 'static>;
227
/// Source of `TuneGroup` ids; bumped once per `TuneGroup::new` call.
static GROUP_COUNTER: AtomicU32 = AtomicU32::new(0u32);
229
#[cfg(test)]
mod tests {
    use core::fmt::Display;

    use serde::{Deserialize, Serialize};

    use super::*;

    /// Key stub; every priority closure in these tests ignores it.
    #[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize, Debug)]
    struct FakeAutotuneKey;

    impl Display for FakeAutotuneKey {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.write_str("FakeAutotuneKey")
        }
    }

    impl AutotuneKey for FakeAutotuneKey {}

    /// The tunable type used throughout these tests.
    type FakeTunable = Tunable<FakeAutotuneKey, (), ()>;

    /// A tunable wrapping `fake_kernel` that belongs to no group.
    fn ungrouped() -> FakeTunable {
        FakeTunable::new(fake_kernel)
    }

    /// A tunable wrapping `fake_kernel` that belongs to `group` with a constant
    /// within-group priority.
    fn in_group(group: &TuneGroup<FakeAutotuneKey>, priority: u8) -> FakeTunable {
        ungrouped().group(group, move |_| priority)
    }

    #[test]
    fn test_plan_order() {
        let group0 = TuneGroup::<FakeAutotuneKey>::new(|_| 2);
        let group1 = TuneGroup::<FakeAutotuneKey>::new(|_| 1);

        let tunables = [
            ungrouped(),
            in_group(&group0, 1),
            in_group(&group0, 2),
            in_group(&group1, 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(), vec![0, 2]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![3]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_order_multi_groups_same_priority() {
        let group0 = TuneGroup::<FakeAutotuneKey>::new(|_| 2);
        let group1 = TuneGroup::<FakeAutotuneKey>::new(|_| 1);
        let group2 = TuneGroup::<FakeAutotuneKey>::new(|_| 1);

        let tunables = [
            ungrouped(),
            in_group(&group0, 1),
            in_group(&group0, 2),
            in_group(&group1, 2),
            in_group(&group2, 2),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(), vec![0, 2]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![3, 4]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_order_tunable_multiple_groups() {
        let group0 = TuneGroup::<FakeAutotuneKey>::new(|_| 1);
        let group1 = TuneGroup::<FakeAutotuneKey>::new(|_| 2);

        let tunables = [
            ungrouped(),
            in_group(&group0, 1).group(&group1, |_| 2),
            in_group(&group0, 2),
            in_group(&group1, 3),
        ];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(), vec![0, 3]);
        assert_eq!(plan.next(), vec![1]);
        assert_eq!(plan.next(), vec![2]);
        assert!(plan.next().is_empty());
    }

    #[test]
    fn test_plan_no_group() {
        let tunables = [ungrouped(), ungrouped()];
        let mut plan = TunePlan::new(&FakeAutotuneKey, &tunables);

        assert_eq!(plan.next(), vec![0, 1]);
        assert!(plan.next().is_empty());
    }

    fn fake_kernel() -> Result<(), String> {
        Ok(())
    }
}