cubecl_runtime/tune/
input_generator.rs

1use core::marker::PhantomData;
2
3use variadics_please::all_tuples;
4
/// Produces fresh inputs for autotuning passes.
///
/// An implementor is handed the tuning key together with the caller's
/// reference inputs (by shared reference) and returns a brand-new value of the
/// same `Inputs` type.
pub trait InputGenerator<K, Inputs>: 'static {
    /// Build a set of inputs for `key`, using `inputs` as the reference.
    fn generate(&self, key: &K, inputs: &Inputs) -> Inputs;
}
10
11/// Something that can be turned into an input generator
12pub trait IntoInputGenerator<K, Inputs, Marker> {
13    /// The concrete type of the input generator
14    type Generator: InputGenerator<K, Inputs>;
15
16    /// Convert this type into a concrete input generator
17    fn into_input_gen(self) -> Self::Generator;
18}
19
/// Adapter that exposes a plain function or closure as an [`InputGenerator`].
///
/// `Marker` stores no data; it only records which [`FunctionInputGen`]
/// implementation of `F` this generator goes through.
pub struct FunctionInputGenerator<F, Marker> {
    /// The wrapped function that does the actual generation.
    func: F,
    /// Zero-sized tag tying this generator to its `Marker` type.
    _marker: PhantomData<Marker>,
}
25
26impl<K, Inputs, Marker: 'static, F> InputGenerator<K, Inputs> for FunctionInputGenerator<F, Marker>
27where
28    F: FunctionInputGen<K, Inputs, Marker>,
29{
30    fn generate(&self, key: &K, inputs: &Inputs) -> Inputs {
31        self.func.execute(key, inputs)
32    }
33}
34
/// A callable that can act as the body of an input generator for `Inputs`.
///
/// This file implements it through `impl_input_gen!` for callables shaped like
/// `Fn(&K, &I0, ..) -> (I0, ..)`. There, `Marker` is the matching `fn`-pointer
/// type, which keeps the impls for different arities apart.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not a valid input generator",
    label = "invalid input generator"
)]
pub trait FunctionInputGen<K, Inputs, Marker>: 'static {
    /// Invoke the callable with `key` and the reference `inputs`, returning new inputs.
    fn execute(&self, key: &K, inputs: &Inputs) -> Inputs;
}
44
45impl<K, Inputs, Marker: 'static, F> IntoInputGenerator<K, Inputs, Marker> for F
46where
47    F: FunctionInputGen<K, Inputs, Marker>,
48{
49    type Generator = FunctionInputGenerator<F, Marker>;
50
51    fn into_input_gen(self) -> Self::Generator {
52        FunctionInputGenerator {
53            func: self,
54            _marker: PhantomData,
55        }
56    }
57}
58
59macro_rules! impl_input_gen {
60    ($($param:ident),*) => {
61        #[allow(unused_parens, clippy::unused_unit)]
62        impl<K: 'static, Func, $($param: Clone + Send + 'static,)*> FunctionInputGen<K, ($($param),*), fn(&K, $(&$param),*) -> ($($param),*)> for Func
63            where Func: Send + Sync + 'static,
64            for<'a> &'a Func: Fn(&K, $(&$param),*) -> ($($param),*)
65        {
66            #[allow(non_snake_case, clippy::too_many_arguments)]
67            #[inline]
68            fn execute(&self, key: &K, ($($param),*): &($($param),*)) -> ($($param),*) {
69                fn call_inner<K, $($param,)*>(
70                    f: impl Fn(&K, $(&$param,)*) -> ($($param),*),
71                    key: &K,
72                    $($param: &$param,)*
73                ) -> ($($param),*) {
74                    f(key, $($param,)*)
75                }
76                call_inner(self, key, $($param),*)
77            }
78        }
79    };
80}
81
82all_tuples!(impl_input_gen, 0, 13, I);