cubecl_runtime/tune/
key_generator.rs1use core::marker::PhantomData;
2
3use variadics_please::all_tuples;
4
/// Computes a key of type `K` from a borrowed bundle of inputs.
///
/// Implementors only receive `&Inputs`, so generating a key never consumes the inputs.
/// The `'static` supertrait means implementing types may not hold non-`'static` borrows.
///
/// NOTE(review): this lives under `tune/`, so the key is presumably used to look up
/// tuning results; confirm against the callers.
pub trait KeyGenerator<K, Inputs>: 'static {
    /// Builds the key for `inputs`.
    fn generate(&self, inputs: &Inputs) -> K;
}
10
11pub trait IntoKeyGenerator<K, Inputs, Marker> {
13 type Generator: KeyGenerator<K, Inputs>;
15
16 fn into_key_gen(self) -> Self::Generator;
18}
19
/// [`KeyGenerator`] adapter around a function-like value `F`.
///
/// `Marker` must appear in this type so that an `impl ... for FunctionKeyGenerator<F, Marker>`
/// can constrain it: Rust rejects impl type parameters that appear only in where-clauses.
/// It is stored as `PhantomData`, so it costs no space. As a consequence, auto traits
/// such as `Send`/`Sync` of this struct also depend on `Marker`.
pub struct FunctionKeyGenerator<F, Marker> {
    /// The wrapped function-like value that actually computes the key.
    func: F,
    /// Zero-sized; records `Marker` without storing a value of that type.
    _marker: PhantomData<Marker>,
}
25
26impl<K, Inputs, Marker: 'static, F> KeyGenerator<K, Inputs> for FunctionKeyGenerator<F, Marker>
27where
28 F: FunctionKeygen<K, Inputs, Marker>,
29{
30 fn generate(&self, inputs: &Inputs) -> K {
31 self.func.execute(inputs)
32 }
33}
34
/// A function-like value that computes a key `K` from borrowed `Inputs`.
///
/// `Marker` distinguishes implementations that would otherwise overlap. The
/// `impl_keygen!` macro in this module uses the function signature type
/// (`fn(&I0, ..) -> K`) as the marker.
///
/// The `on_unimplemented` attribute replaces the default compiler error with a clearer
/// message when a value that does not satisfy this trait is used as a key generator.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not a valid key generator",
    label = "invalid key generator"
)]
pub trait FunctionKeygen<K, Inputs, Marker>: 'static {
    /// Computes the key for `inputs`.
    fn execute(&self, inputs: &Inputs) -> K;
}
44
45impl<K, Inputs, Marker: 'static, F> IntoKeyGenerator<K, Inputs, Marker> for F
46where
47 F: FunctionKeygen<K, Inputs, Marker>,
48{
49 type Generator = FunctionKeyGenerator<F, Marker>;
50
51 fn into_key_gen(self) -> Self::Generator {
52 FunctionKeyGenerator {
53 func: self,
54 _marker: PhantomData,
55 }
56 }
57}
58
// Implements `FunctionKeygen` for every callable taking one reference per listed type
// parameter, e.g. `impl_keygen!(I0, I1)` covers `Fn(&I0, &I1) -> K`.
//
// - Each identifier serves as a generic type parameter and, inside `execute`, as the
//   name of the binding that holds a `&` to that input.
// - `Inputs` is `($($param),*)`. That is a tuple for zero or two-plus parameters, but
//   the bare type `I0` (not a 1-tuple) for exactly one parameter, which is why
//   `unused_parens` is allowed.
// - The marker is the signature `fn(&I0, ..) -> K`. This keeps the per-arity blanket
//   impls over `Func` from overlapping.
// - `for<'a> &'a Func: Fn(..)` requires the callable to be invocable through a shared
//   reference, which is how `execute(&self, ..)` invokes it.
macro_rules! impl_keygen {
    ($($param:ident),*) => {
        // Single-parameter "tuples" `(I0)` are intentionally plain parenthesised types.
        #[allow(unused_parens)]
        impl<K: 'static, Func, $($param: Clone + Send + 'static,)*>
            FunctionKeygen<K, ($($param),*), fn($(&$param),*) -> K> for Func
        where
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($(&$param),*) -> K,
        {
            // `non_snake_case`: value bindings reuse the type-parameter names (`I0`, ...).
            // `too_many_arguments`: `call_inner` takes `f` plus one argument per parameter.
            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($param),*): &($($param),*)) -> K {
                // NOTE(review): routing the call through a generic `impl Fn` helper
                // presumably helps the compiler resolve the higher-ranked
                // `&Func: Fn(..)` bound; confirm before inlining it.
                fn call_inner<Out, $($param,)*>(
                    f: impl Fn($(&$param,)*) -> Out,
                    $($param: &$param,)*
                ) -> Out {
                    f($($param,)*)
                }
                call_inner(self, $($param),*)
            }
        }
    };
}
80
81all_tuples!(impl_keygen, 0, 12, I);