// cubecl_runtime/tune/key_generator.rs

1use core::marker::PhantomData;
2
3use variadics_please::all_tuples;
4
/// Builds a key of type `K` from a borrowed set of `Inputs`.
///
/// Implementors must be `'static`.
pub trait KeyGenerator<K, Inputs>: 'static {
    /// Derives the key for the given `inputs`.
    fn generate(&self, inputs: &Inputs) -> K;
}
10
11/// Something that can be turned into a key generator
12pub trait IntoKeyGenerator<K, Inputs, Marker> {
13    /// The concrete key generator type
14    type Generator: KeyGenerator<K, Inputs>;
15
16    /// Turn this type into a concrete key generator
17    fn into_key_gen(self) -> Self::Generator;
18}
19
/// A [`KeyGenerator`] backed by a plain function or closure.
///
/// `Marker` is a type-level tag only; for the function impls in this module it is the
/// function's `fn` pointer signature. No value of type `Marker` is ever stored.
pub struct FunctionKeyGenerator<F, Marker> {
    /// The wrapped function; it is what actually computes each key.
    func: F,
    // `fn() -> Marker` rather than `Marker`: the generator never owns a `Marker`, so the tag
    // must not influence auto traits (`Send`/`Sync`) or drop checking. Covariance in `Marker`
    // and the `Marker: 'static` requirement for `'static`-ness are unchanged.
    _marker: PhantomData<fn() -> Marker>,
}
25
26impl<K, Inputs, Marker: 'static, F> KeyGenerator<K, Inputs> for FunctionKeyGenerator<F, Marker>
27where
28    F: FunctionKeygen<K, Inputs, Marker>,
29{
30    fn generate(&self, inputs: &Inputs) -> K {
31        self.func.execute(inputs)
32    }
33}
34
/// A function-like value that computes a key of type `K` from borrowed `Inputs`.
///
/// `Marker` is a type-level tag. In this module's blanket impls, it is the function's `fn`
/// pointer signature, which keeps the impls for different arities distinct. Implementors must
/// be `'static`.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not a valid key generator",
    label = "invalid key generator"
)]
pub trait FunctionKeygen<K, Inputs, Marker>: 'static {
    /// Runs the function on `inputs` and returns the resulting key.
    fn execute(&self, inputs: &Inputs) -> K;
}
44
45impl<K, Inputs, Marker: 'static, F> IntoKeyGenerator<K, Inputs, Marker> for F
46where
47    F: FunctionKeygen<K, Inputs, Marker>,
48{
49    type Generator = FunctionKeyGenerator<F, Marker>;
50
51    fn into_key_gen(self) -> Self::Generator {
52        FunctionKeyGenerator {
53            func: self,
54            _marker: PhantomData,
55        }
56    }
57}
58
/// Implements [`FunctionKeygen`] for functions taking one particular number of inputs.
///
/// The macro is invoked with the type-parameter names to use, e.g. `impl_keygen!(I0, I1)`.
/// That invocation makes every `Func` that is callable through `&Func` as `Fn(&I0, &I1) -> K`
/// a key generator for inputs `(I0, I1)`. With a single name the parentheses collapse, so the
/// inputs are the bare `I0`; with no names the inputs are `()`.
///
/// The marker is the matching `fn` pointer signature, which keeps the impls for different
/// arities from overlapping. Input types must be `Clone + Send + 'static`, and the function
/// must be `Send + Sync + 'static`.
///
/// The names are also reused as bindings for the destructured inputs, hence the
/// `non_snake_case` allowance.
macro_rules! impl_keygen {
    ($($ty:ident),*) => {
        #[allow(unused_parens)]
        impl<K, Func, $($ty,)*> FunctionKeygen<K, ($($ty),*), fn($(&$ty),*) -> K> for Func
        where
            K: 'static,
            $($ty: Clone + Send + 'static,)*
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($(&$ty),*) -> K,
        {
            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($ty),*): &($($ty),*)) -> K {
                // NOTE(review): presumably rustc cannot resolve `self(..)` directly when only the
                // higher-ranked `&Func: Fn` bound is known. Handing `self` to a helper that takes
                // `impl Fn` applies that bound explicitly, so keep the indirection.
                fn invoke<R, $($ty,)*>(func: impl Fn($(&$ty,)*) -> R, $($ty: &$ty,)*) -> R {
                    func($($ty,)*)
                }
                invoke(self, $($ty),*)
            }
        }
    };
}
80
81all_tuples!(impl_keygen, 0, 12, I);