1use core::marker::PhantomData;
2
3use variadics_please::all_tuples;
4
5use super::{AutotuneError, IntoTuneFn, TuneFn};
6
7pub struct FunctionTunable<F: AsFunctionTunableResult<Marker>, Marker> {
13 func: F,
14 _marker: PhantomData<Marker>,
15}
16
17unsafe impl<F: AsFunctionTunableResult<Marker> + Send, Marker> Send for FunctionTunable<F, Marker> {}
18unsafe impl<F: AsFunctionTunableResult<Marker> + Sync, Marker> Sync for FunctionTunable<F, Marker> {}
19
20impl<F: AsFunctionTunableResult<Marker>, Marker: 'static> TuneFn for FunctionTunable<F, Marker> {
21 type Inputs = F::Inputs;
22 type Output = F::Output;
23
24 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
25 self.func.execute(inputs)
26 }
27}
28
29#[doc(hidden)]
31pub struct IsFunction;
32
33impl<F: AsFunctionTunableResult<Marker>, Marker: 'static>
34 IntoTuneFn<F::Inputs, F::Output, (Marker, IsFunction)> for F
35{
36 type Tunable = FunctionTunable<F, Marker>;
37
38 fn into_tunable(self) -> Self::Tunable {
39 FunctionTunable {
40 func: self,
41 _marker: PhantomData,
42 }
43 }
44}
45
46pub struct FunctionTunableResultMap<F: AsFunctionTunable<Marker>, Marker> {
52 func: F,
53 _marker: PhantomData<Marker>,
54}
55
56unsafe impl<F: AsFunctionTunable<Marker> + Send, Marker> Send
57 for FunctionTunableResultMap<F, Marker>
58{
59}
60unsafe impl<F: AsFunctionTunable<Marker> + Sync, Marker> Sync
61 for FunctionTunableResultMap<F, Marker>
62{
63}
64
65impl<F: AsFunctionTunable<Marker>, Marker: 'static> TuneFn for FunctionTunableResultMap<F, Marker> {
66 type Inputs = F::Inputs;
67 type Output = F::Output;
68
69 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
70 Ok(self.func.execute(inputs))
71 }
72}
73
74#[diagnostic::on_unimplemented(
80 message = "`{Self}` is not a valid tunable.",
81 label = "invalid tunable"
82)]
83pub trait AsFunctionTunable<Marker>: Sized + Send + Sync + 'static {
84 type Inputs: Clone;
86 type Output;
88
89 fn execute(&self, inputs: Self::Inputs) -> Self::Output;
91
92 fn ok(self) -> FunctionTunableResultMap<Self, Marker> {
94 FunctionTunableResultMap {
95 func: self,
96 _marker: PhantomData,
97 }
98 }
99
100 fn name(&self) -> &str {
102 core::any::type_name::<Self>()
103 }
104}
105
106#[diagnostic::on_unimplemented(
112 message = "`{Self}` is not a valid tunable. For infallible kernels, use `AsFunctionTunable::ok`",
113 label = "invalid tunable"
114)]
115pub trait AsFunctionTunableResult<Marker>: Send + Sync + 'static {
116 type Inputs: Clone;
118 type Output;
120
121 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;
123
124 fn name(&self) -> &str {
126 core::any::type_name::<Self>()
127 }
128}
129
// Implements `AsFunctionTunable` for every `Func` whose shared reference is callable with the
// given parameter types and returns `Out`.
//
// The macro takes optional attributes followed by a comma-separated list of identifiers. Each
// identifier is used both as a generic type parameter and as the name of the matching argument
// binding, which is why `non_snake_case` is allowed on `execute`. The marker type of the
// generated impl is the function pointer type `fn(params...) -> Out`.
macro_rules! impl_tunable {
    ($(#[$meta:meta])* $($params:ident),*) => {
        // With a single parameter the "tuple" `(P)` is just a parenthesised `P`.
        #[allow(unused_parens)]
        $(#[$meta])*
        impl<Out, Func, $($params,)*> AsFunctionTunable<fn($($params),*) -> Out> for Func
        where
            Out: 'static,
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($($params),*) -> Out,
            $($params: Clone + Send + 'static,)*
        {
            type Inputs = ($($params),*);
            type Output = Out;

            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($params),*): ($($params),*)) -> Out {
                // The `Fn` bound is on `&Func`, so the reference itself is handed to a helper
                // that accepts any `impl Fn` and performs the call.
                fn invoke<Out, $($params,)*>(
                    callee: impl Fn($($params,)*) -> Out,
                    $($params: $params,)*
                ) -> Out {
                    callee($($params,)*)
                }

                invoke(self, $($params),*)
            }
        }
    };
}
155
// Implements `AsFunctionTunableResult` for every `Func` whose shared reference is callable with
// the given parameter types and returns `Result<Out, Err>`, where `Err` converts into
// `AutotuneError`.
//
// The macro takes optional attributes followed by a comma-separated list of identifiers. Each
// identifier is used both as a generic type parameter and as the name of the matching argument
// binding, which is why `non_snake_case` is allowed on `execute`. The marker type of the
// generated impl is the function pointer type `fn(params...) -> Result<Out, Err>`.
macro_rules! impl_tunable_result {
    ($(#[$meta:meta])* $($params:ident),*) => {
        // With a single parameter the "tuple" `(P)` is just a parenthesised `P`.
        #[allow(unused_parens)]
        $(#[$meta])*
        impl<Out, Err, Func, $($params,)*>
            AsFunctionTunableResult<fn($($params),*) -> Result<Out, Err>> for Func
        where
            Out: 'static,
            Err: Into<AutotuneError>,
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($($params),*) -> Result<Out, Err>,
            $($params: Clone + Send + 'static,)*
        {
            type Inputs = ($($params),*);
            type Output = Out;

            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($params),*): ($($params),*)) -> Result<Out, AutotuneError> {
                // The `Fn` bound is on `&Func`, so the reference itself is handed to a helper
                // that accepts any `impl Fn` and performs the call.
                fn invoke<Out, Err, $($params,)*>(
                    callee: impl Fn($($params,)*) -> Result<Out, Err>,
                    $($params: $params,)*
                ) -> Result<Out, Err> {
                    callee($($params,)*)
                }

                // Convert the callable's own error type into `AutotuneError`.
                let outcome = invoke(self, $($params),*);
                outcome.map_err(Into::into)
            }
        }
    };
}
182
183all_tuples!(impl_tunable, 0, 12, I);
184all_tuples!(impl_tunable_result, 0, 12, I);