1use core::marker::PhantomData;
2
3use variadics_please::all_tuples;
4
5use super::{AutotuneError, IntoTunable, Tunable};
6
7pub struct FunctionTunable<F: AsFunctionTunableResult<Marker>, Marker> {
13 func: F,
14 _marker: PhantomData<Marker>,
15}
16
17unsafe impl<F: AsFunctionTunableResult<Marker> + Send, Marker> Send for FunctionTunable<F, Marker> {}
18unsafe impl<F: AsFunctionTunableResult<Marker> + Sync, Marker> Sync for FunctionTunable<F, Marker> {}
19
20impl<F: AsFunctionTunableResult<Marker>, Marker: 'static> Tunable for FunctionTunable<F, Marker> {
21 type Inputs = F::Inputs;
22 type Output = F::Output;
23
24 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
25 self.func.execute(inputs)
26 }
27}
28
/// Type-level tag placed in the `IntoTunable` marker tuple `(Marker, IsFunction)` of the
/// blanket impl for function tunables.
///
/// It is zero-sized and never inspected at runtime. The common derives only make it a
/// well-behaved public type (debuggable, copyable, default-constructible).
#[doc(hidden)]
#[derive(Debug, Clone, Copy, Default)]
pub struct IsFunction;
32
33impl<F: AsFunctionTunableResult<Marker>, Marker: 'static>
34 IntoTunable<F::Inputs, F::Output, (Marker, IsFunction)> for F
35{
36 type Tunable = FunctionTunable<F, Marker>;
37
38 fn into_tunable(self) -> Self::Tunable {
39 FunctionTunable {
40 func: self,
41 _marker: PhantomData,
42 }
43 }
44}
45
46pub struct FunctionTunableResultMap<F: AsFunctionTunable<Marker>, Marker> {
52 func: F,
53 _marker: PhantomData<Marker>,
54}
55
56unsafe impl<F: AsFunctionTunable<Marker> + Send, Marker> Send
57 for FunctionTunableResultMap<F, Marker>
58{
59}
60unsafe impl<F: AsFunctionTunable<Marker> + Sync, Marker> Sync
61 for FunctionTunableResultMap<F, Marker>
62{
63}
64
65impl<F: AsFunctionTunable<Marker>, Marker: 'static> Tunable
66 for FunctionTunableResultMap<F, Marker>
67{
68 type Inputs = F::Inputs;
69 type Output = F::Output;
70
71 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
72 Ok(self.func.execute(inputs))
73 }
74}
75
76#[diagnostic::on_unimplemented(
82 message = "`{Self}` is not a valid tunable.",
83 label = "invalid tunable"
84)]
85pub trait AsFunctionTunable<Marker>: Sized + Send + Sync + 'static {
86 type Inputs: Clone;
88 type Output;
90
91 fn execute(&self, inputs: Self::Inputs) -> Self::Output;
93
94 fn ok(self) -> FunctionTunableResultMap<Self, Marker> {
96 FunctionTunableResultMap {
97 func: self,
98 _marker: PhantomData,
99 }
100 }
101
102 fn name(&self) -> &str {
104 core::any::type_name::<Self>()
105 }
106}
107
108#[diagnostic::on_unimplemented(
114 message = "`{Self}` is not a valid tunable. For infallible kernels, use `AsFunctionTunable::ok`",
115 label = "invalid tunable"
116)]
117pub trait AsFunctionTunableResult<Marker>: Send + Sync + 'static {
118 type Inputs: Clone;
120 type Output;
122
123 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;
125
126 fn name(&self) -> &str {
128 core::any::type_name::<Self>()
129 }
130}
131
// Generates the blanket `AsFunctionTunable` impl for callables of one arity.
//
// `$params` name the argument types and are also reused as the argument binding names inside
// `execute`. The trait's `Marker` is the function-pointer type `fn($params...) -> Out`.
// Optional leading attributes are forwarded onto the generated impl.
macro_rules! impl_tunable {
    ($(#[$meta:meta])* $($params:ident),*) => {
        // With exactly one parameter, `($($params),*)` expands to a parenthesised type/pattern
        // such as `(P0)` rather than a tuple, which would otherwise warn as `unused_parens`.
        #[allow(unused_parens)]
        $(#[$meta])*
        impl<Func, Out, $($params,)*> AsFunctionTunable<fn($($params),*) -> Out> for Func
        where
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($($params),*) -> Out,
            Out: 'static,
            $($params: Clone + Send + 'static,)*
        {
            type Output = Out;
            type Inputs = ($($params),*);

            #[inline]
            #[allow(non_snake_case, clippy::too_many_arguments)]
            fn execute(&self, ($($params),*): ($($params),*)) -> Out {
                // `self` is a `&Func`, which the where clause makes callable. It is passed
                // through a generic `impl Fn` helper. NOTE(review): the helper presumably
                // exists so the call resolves via the higher-ranked `&Func: Fn` bound.
                fn invoke<R, $($params,)*>(
                    callable: impl Fn($($params,)*) -> R,
                    $($params: $params,)*
                ) -> R {
                    callable($($params,)*)
                }
                invoke(self, $($params),*)
            }
        }
    };
}
157
// Generates the blanket `AsFunctionTunableResult` impl for fallible callables of one arity.
//
// The callable must return `Result<Out, E>` with `E: Into<AutotuneError>`. Errors are
// converted into `AutotuneError` on the way out. The trait's `Marker` is the function-pointer
// type `fn($params...) -> Result<Out, E>`. Optional leading attributes are forwarded onto the
// generated impl.
macro_rules! impl_tunable_result {
    ($(#[$meta:meta])* $($params:ident),*) => {
        // With exactly one parameter, `($($params),*)` expands to a parenthesised type/pattern
        // such as `(P0)` rather than a tuple, which would otherwise warn as `unused_parens`.
        #[allow(unused_parens)]
        $(#[$meta])*
        impl<Func, Out, E, $($params,)*> AsFunctionTunableResult<fn($($params),*) -> Result<Out, E>> for Func
        where
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($($params),*) -> Result<Out, E>,
            E: Into<AutotuneError>,
            Out: 'static,
            $($params: Clone + Send + 'static,)*
        {
            type Output = Out;
            type Inputs = ($($params),*);

            #[inline]
            #[allow(non_snake_case, clippy::too_many_arguments)]
            fn execute(&self, ($($params),*): ($($params),*)) -> Result<Out, AutotuneError> {
                // `self` is a `&Func`, which the where clause makes callable. It is passed
                // through a generic `impl Fn` helper. NOTE(review): the helper presumably
                // exists so the call resolves via the higher-ranked `&Func: Fn` bound.
                fn invoke<R, X, $($params,)*>(
                    callable: impl Fn($($params,)*) -> Result<R, X>,
                    $($params: $params,)*
                ) -> Result<R, X> {
                    callable($($params,)*)
                }
                let result = invoke(self, $($params),*);
                result.map_err(Into::into)
            }
        }
    };
}
184
185all_tuples!(impl_tunable, 0, 12, I);
186all_tuples!(impl_tunable_result, 0, 12, I);