cubecl_runtime/tune/
operation.rs1use alloc::boxed::Box;
2use alloc::string::String;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::fmt::{Debug, Display};
6use core::hash::Hash;
7
8use alloc::format;
9
10use super::{
11 AutotuneError, input_generator::InputGenerator, key_generator::KeyGenerator,
12 tune_inputs::TuneInputs,
13};
14use super::{Tunable, TunePlan};
15
16type TuneDelegate<I, Out> =
21 dyn for<'inp> Fn(<I as TuneInputs>::At<'inp>) -> Result<Out, AutotuneError> + Send + Sync;
22
23#[derive(new)]
26pub struct TuneFn<I: TuneInputs, Out> {
27 pub(crate) name: String,
28 func: Box<TuneDelegate<I, Out>>,
29}
30
31impl<I: TuneInputs, Out: 'static> TuneFn<I, Out> {
32 pub fn execute<'a>(&self, inputs: <I as TuneInputs>::At<'a>) -> Result<Out, AutotuneError> {
34 (self.func)(inputs)
35 }
36}
37
38pub struct TunableSet<K: AutotuneKey, F: TuneInputs, Output: 'static> {
41 tunables: Vec<Tunable<K, F, Output>>,
42 key_gen: Arc<dyn KeyGenerator<K, F> + Send + Sync>,
43 input_gen: Arc<dyn InputGenerator<K, F> + Send + Sync>,
44}
45
46impl<K: AutotuneKey, F: TuneInputs, Output: 'static> TunableSet<K, F, Output> {
47 pub fn len(&self) -> usize {
49 self.tunables.len()
50 }
51
52 pub fn is_empty(&self) -> bool {
54 self.tunables.is_empty()
55 }
56
57 pub fn new(key_gen: impl KeyGenerator<K, F>, input_gen: impl InputGenerator<K, F>) -> Self {
59 Self {
60 tunables: Default::default(),
61 input_gen: Arc::new(input_gen),
62 key_gen: Arc::new(key_gen),
63 }
64 }
65
66 pub fn new_cloning_inputs(key_gen: impl KeyGenerator<K, F>) -> Self {
69 Self::new(key_gen, super::CloneInputGenerator)
70 }
71
72 pub fn with(mut self, tunable: Tunable<K, F, Output>) -> Self {
74 self.tunables.push(tunable);
75 self
76 }
77
78 pub fn autotunables(&self) -> impl Iterator<Item = &TuneFn<F, Output>> {
80 self.tunables.iter().map(|tunable| &tunable.function)
81 }
82
83 pub(crate) fn plan(&self, key: &K) -> TunePlan {
85 TunePlan::new(key, &self.tunables)
86 }
87
88 pub fn fastest(&self, fastest_index: usize) -> &TuneFn<F, Output> {
91 &self.tunables[fastest_index].function
92 }
93
94 pub fn compute_checksum(&self) -> String {
97 let mut checksum = String::new();
98 for tune in &self.tunables {
99 checksum += &tune.function.name;
100 }
101 format!("{:x}", md5::compute(checksum))
102 }
103
104 pub fn generate_key<'a>(&self, inputs: &F::At<'a>) -> K {
106 self.key_gen.generate(inputs)
107 }
108
109 pub fn generate_inputs<'a>(&self, key: &K, inputs: &F::At<'a>) -> F::At<'a> {
111 self.input_gen.generate(key, inputs)
112 }
113}
114
/// Requirements for a type used as an autotune key (build with the `std_io` cfg).
///
/// On top of the core bounds, keys must be serde-serializable and deserializable.
/// Presumably this is so tuning results keyed by them can be persisted (TODO confirm).
#[cfg(std_io)]
pub trait AutotuneKey
where
    Self: Clone
        + Debug
        + PartialEq
        + Eq
        + Hash
        + Display
        + serde::Serialize
        + serde::de::DeserializeOwned
        + Send
        + Sync
        + 'static,
{
}

/// Requirements for a type used as an autotune key (build without the `std_io` cfg).
///
/// Same as the `std_io` variant, minus the serde (de)serialization bounds.
#[cfg(not(std_io))]
pub trait AutotuneKey
where
    Self: Clone + Debug + PartialEq + Eq + Hash + Display + Send + Sync + 'static,
{
}

/// Plain strings can be used directly as autotune keys.
impl AutotuneKey for String {}