1use core::marker::PhantomData;
2
3use variadics_please::all_tuples;
4
5use super::{AutotuneError, IntoTuneFn, TuneFn};
6
7pub struct FunctionTunable<F: AsFunctionTunableResult<Marker>, Marker> {
13 func: F,
14 _marker: PhantomData<Marker>,
15}
16
17unsafe impl<F: AsFunctionTunableResult<Marker> + Send, Marker> Send for FunctionTunable<F, Marker> {}
18unsafe impl<F: AsFunctionTunableResult<Marker> + Sync, Marker> Sync for FunctionTunable<F, Marker> {}
19
20impl<F: AsFunctionTunableResult<Marker>, Marker: 'static> TuneFn for FunctionTunable<F, Marker> {
21 type Inputs = F::Inputs;
22 type Output = F::Output;
23
24 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
25 self.func.execute(inputs)
26 }
27
28 fn name(&self) -> &str {
29 self.func.name()
30 }
31}
32
33#[doc(hidden)]
35pub struct IsFunction;
36
37impl<F: AsFunctionTunableResult<Marker>, Marker: 'static>
38 IntoTuneFn<F::Inputs, F::Output, (Marker, IsFunction)> for F
39{
40 type Tunable = FunctionTunable<F, Marker>;
41
42 fn into_tunable(self) -> Self::Tunable {
43 FunctionTunable {
44 func: self,
45 _marker: PhantomData,
46 }
47 }
48}
49
50pub struct FunctionTunableResultMap<F: AsFunctionTunable<Marker>, Marker> {
56 func: F,
57 _marker: PhantomData<Marker>,
58}
59
60unsafe impl<F: AsFunctionTunable<Marker> + Send, Marker> Send
61 for FunctionTunableResultMap<F, Marker>
62{
63}
64unsafe impl<F: AsFunctionTunable<Marker> + Sync, Marker> Sync
65 for FunctionTunableResultMap<F, Marker>
66{
67}
68
69impl<F: AsFunctionTunable<Marker>, Marker: 'static> TuneFn for FunctionTunableResultMap<F, Marker> {
70 type Inputs = F::Inputs;
71 type Output = F::Output;
72
73 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
74 Ok(self.func.execute(inputs))
75 }
76
77 fn name(&self) -> &str {
78 self.func.name()
79 }
80}
81
82#[diagnostic::on_unimplemented(
88 message = "`{Self}` is not a valid tunable.",
89 label = "invalid tunable"
90)]
91pub trait AsFunctionTunable<Marker>: Sized + Send + Sync + 'static {
92 type Inputs: Clone;
94 type Output;
96
97 fn execute(&self, inputs: Self::Inputs) -> Self::Output;
99
100 fn ok(self) -> FunctionTunableResultMap<Self, Marker> {
102 FunctionTunableResultMap {
103 func: self,
104 _marker: PhantomData,
105 }
106 }
107
108 fn name(&self) -> &str {
110 core::any::type_name::<Self>()
111 }
112}
113
114#[diagnostic::on_unimplemented(
120 message = "`{Self}` is not a valid tunable. For infallible kernels, use `AsFunctionTunable::ok`",
121 label = "invalid tunable"
122)]
123pub trait AsFunctionTunableResult<Marker>: Send + Sync + 'static {
124 type Inputs: Clone;
126 type Output;
128
129 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;
131
132 fn name(&self) -> &str {
134 core::any::type_name::<Self>()
135 }
136}
137
// Implements `AsFunctionTunable` for every function or closure `Func` where `&Func`
// is callable with the listed parameter types (passed by value).
//
// The marker `fn(..) -> Out` encodes the signature, so impls generated for different
// arities do not overlap. Optional leading attributes are forwarded onto the impl.
macro_rules! impl_tunable {
    ($(#[$attr:meta])* $($arg:ident),*) => {
        #[allow(unused_parens)]
        $(#[$attr])*
        impl<Func, Out, $($arg,)*> AsFunctionTunable<fn($($arg),*) -> Out> for Func
        where
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($($arg),*) -> Out,
            Out: 'static,
            $($arg: Clone + Send + 'static,)*
        {
            type Inputs = ($($arg),*);
            type Output = Out;

            #[inline]
            #[allow(non_snake_case, clippy::too_many_arguments)]
            fn execute(&self, ($($arg),*): ($($arg),*)) -> Out {
                // Route the call through an `impl Fn` parameter rather than calling
                // `self` directly. This presumably sidesteps a rustc inference
                // limitation with the higher-ranked `&Func: Fn` bound (TODO confirm).
                fn invoke<R, $($arg,)*>(
                    callee: impl Fn($($arg,)*) -> R,
                    $($arg: $arg,)*
                ) -> R {
                    callee($($arg,)*)
                }
                invoke(self, $($arg),*)
            }
        }
    };
}
163
// Implements `AsFunctionTunableResult` for every function or closure `Func` where
// `&Func` is callable with the listed parameter types and returns `Result<Out, E>`,
// provided the error type `E` converts into `AutotuneError`.
//
// The marker `fn(..) -> Result<Out, E>` encodes the signature, so impls for
// different arities do not overlap. Optional leading attributes are forwarded.
macro_rules! impl_tunable_result {
    ($(#[$attr:meta])* $($arg:ident),*) => {
        #[allow(unused_parens)]
        $(#[$attr])*
        impl<Func, Out, E, $($arg,)*> AsFunctionTunableResult<fn($($arg),*) -> Result<Out, E>> for Func
        where
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($($arg),*) -> Result<Out, E>,
            E: Into<AutotuneError>,
            Out: 'static,
            $($arg: Clone + Send + 'static,)*
        {
            type Inputs = ($($arg),*);
            type Output = Out;

            #[inline]
            #[allow(non_snake_case, clippy::too_many_arguments)]
            fn execute(&self, ($($arg),*): ($($arg),*)) -> Result<Out, AutotuneError> {
                // Same `impl Fn` indirection as `impl_tunable`; it presumably works
                // around rustc inference with the higher-ranked bound (TODO confirm).
                fn invoke<R, Fail, $($arg,)*>(
                    callee: impl Fn($($arg,)*) -> Result<R, Fail>,
                    $($arg: $arg,)*
                ) -> Result<R, Fail> {
                    callee($($arg,)*)
                }
                // Convert the function's own error type into `AutotuneError`.
                invoke(self, $($arg),*).map_err(Into::into)
            }
        }
    };
}
190
// Generate the blanket function impls. `0, 12` is the arity range and `I` is the
// type-parameter name prefix passed to `variadics_please::all_tuples!`.
// The first line covers infallible functions (`AsFunctionTunable`); the second
// covers fallible ones (`AsFunctionTunableResult`).
all_tuples!(impl_tunable, 0, 12, I);
all_tuples!(impl_tunable_result, 0, 12, I);