// cubecl_runtime/tune/input_generator.rs

use core::marker::PhantomData;

use variadics_please::all_tuples;

/// Produces a new set of inputs from a key and a reference set of inputs.
///
/// Implementors must be `'static`.
pub trait InputGenerator<K, Inputs>: 'static {
    /// Generates fresh `Inputs` for `key`, using `inputs` as the reference values.
    fn generate(&self, key: &K, inputs: &Inputs) -> Inputs;
}

11pub trait IntoInputGenerator<K, Inputs, Marker> {
13 type Generator: InputGenerator<K, Inputs>;
15
16 fn into_input_gen(self) -> Self::Generator;
18}
19
/// An [`InputGenerator`] backed by a function or closure.
///
/// Values are created through [`IntoInputGenerator::into_input_gen`]. `Marker`
/// is a type-level tag (typically the matched `fn` signature) and stores no data.
pub struct FunctionInputGenerator<F, Marker> {
    /// The wrapped callable.
    func: F,
    /// Zero-sized field tying this value to its `Marker` type.
    _marker: PhantomData<Marker>,
}

26impl<K, Inputs, Marker: 'static, F> InputGenerator<K, Inputs> for FunctionInputGenerator<F, Marker>
27where
28 F: FunctionInputGen<K, Inputs, Marker>,
29{
30 fn generate(&self, key: &K, inputs: &Inputs) -> Inputs {
31 self.func.execute(key, inputs)
32 }
33}
34
/// A callable usable as an input generator under the signature tag `Marker`.
///
/// The custom diagnostic replaces rustc's generic "trait not implemented" error
/// when a value that does not fit any supported signature is used.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not a valid input generator",
    label = "invalid input generator"
)]
pub trait FunctionInputGen<K, Inputs, Marker>: 'static {
    /// Runs the callable on `key` and the reference `inputs`, returning new inputs.
    fn execute(&self, key: &K, inputs: &Inputs) -> Inputs;
}

45impl<K, Inputs, Marker: 'static, F> IntoInputGenerator<K, Inputs, Marker> for F
46where
47 F: FunctionInputGen<K, Inputs, Marker>,
48{
49 type Generator = FunctionInputGenerator<F, Marker>;
50
51 fn into_input_gen(self) -> Self::Generator {
52 FunctionInputGenerator {
53 func: self,
54 _marker: PhantomData,
55 }
56 }
57}
58
/// Implements [`FunctionInputGen`] for every `Func` such that `&Func` is
/// `Fn(&K, &I0, .., &In) -> (I0, .., In)`: a callable that receives the key and
/// each current input by reference and returns a new value for every input.
///
/// The `Marker` is the matching `fn` pointer signature, so the impls generated
/// for different arities do not overlap. With one parameter `Inputs` is the bare
/// `I0` (not a 1-tuple); with zero parameters it is `()`.
macro_rules! impl_input_gen {
    ($($param:ident),*) => {
        // `unused_parens`: single-parameter expansions produce `(I0)`.
        // `clippy::unused_unit`: the zero-parameter expansion produces `-> ()`.
        #[allow(unused_parens, clippy::unused_unit)]
        impl<K: 'static, Func, $($param: Clone + Send + 'static,)*>
            FunctionInputGen<K, ($($param),*), fn(&K, $(&$param),*) -> ($($param),*)> for Func
        where
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn(&K, $(&$param),*) -> ($($param),*),
        {
            // `non_snake_case`: the destructured inputs (and `call_inner`'s
            // parameters) reuse the type-parameter names `I0`, `I1`, ... as bindings.
            // `clippy::too_many_arguments`: `call_inner` takes the key plus one
            // argument per input.
            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, key: &K, ($($param),*): &($($param),*)) -> ($($param),*) {
                // NOTE(review): the call is routed through a generic helper,
                // presumably so rustc resolves the `Fn` impl of `&Func` (calling
                // `self(..)` directly under the higher-ranked bound is a known
                // inference pain point). Confirm before simplifying.
                fn call_inner<K, $($param,)*>(
                    f: impl Fn(&K, $(&$param,)*) -> ($($param),*),
                    key: &K,
                    $($param: &$param,)*
                ) -> ($($param),*) {
                    f(key, $($param,)*)
                }
                call_inner(self, key, $($param),*)
            }
        }
    };
}

82all_tuples!(impl_input_gen, 0, 13, I);