1use super::{AutotuneError, IntoTuneFn, TuneFn};
2use alloc::string::String;
3use core::marker::PhantomData;
4use variadics_please::all_tuples;
5
/// A tunable that wraps a fallible function (any [`AsFunctionTunableResult`]) under a name.
///
/// Created by `IntoTuneFn::into_tunable`. Its `TuneFn::execute` forwards to the wrapped
/// function and turns a function error into `AutotuneError::Unknown` tagged with `name`.
pub struct FunctionTunable<F: AsFunctionTunableResult<Marker>, Marker> {
    // Name reported by `TuneFn::name` and attached to any execution error.
    name: String,
    // The wrapped function or closure.
    func: F,
    // Type-level tag selecting which `AsFunctionTunableResult<Marker>` impl is used
    // (in this file's impls, a `fn(..) -> Result<..>` signature type); no value is stored.
    _marker: PhantomData<Marker>,
}
16
// SAFETY: `FunctionTunable` stores only a `String` (`Send + Sync`), an `F` (required to be
// `Send` / `Sync` by these impls, and by the `AsFunctionTunableResult` supertraits), and a
// zero-sized `PhantomData<Marker>`. No `Marker` value ever exists inside the struct, so moving
// or sharing it across threads never touches a `Marker`. Its thread-safety is therefore
// irrelevant, and these impls only lift the auto-trait requirement on the type-level tag.
unsafe impl<F: AsFunctionTunableResult<Marker> + Send, Marker> Send for FunctionTunable<F, Marker> {}
unsafe impl<F: AsFunctionTunableResult<Marker> + Sync, Marker> Sync for FunctionTunable<F, Marker> {}
19
20impl<F: AsFunctionTunableResult<Marker>, Marker: 'static> TuneFn for FunctionTunable<F, Marker> {
21 type Inputs = F::Inputs;
22 type Output = F::Output;
23
24 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
25 match self.func.execute(inputs) {
26 Ok(val) => Ok(val),
27 Err(err) => Err(AutotuneError::Unknown {
28 name: self.name.clone(),
29 err,
30 }),
31 }
32 }
33
34 fn name(&self) -> &str {
35 &self.name
36 }
37}
38
#[doc(hidden)]
/// Marker type combined with a function's signature marker as `(Marker, IsFunction)` to select
/// the `IntoTuneFn` impl for `AsFunctionTunableResult` functions.
///
/// NOTE(review): presumably this exists to keep that blanket impl from overlapping other
/// `IntoTuneFn` impls defined elsewhere; confirm.
pub struct IsFunction;
42
43impl<F: AsFunctionTunableResult<Marker>, Marker: 'static>
44 IntoTuneFn<F::Inputs, F::Output, (Marker, IsFunction)> for F
45{
46 type Tunable = FunctionTunable<F, Marker>;
47
48 fn into_tunable(self, name: String) -> Self::Tunable {
49 FunctionTunable {
50 func: self,
51 name,
52 _marker: PhantomData,
53 }
54 }
55}
56
/// Adapter that lets an infallible function (any [`AsFunctionTunable`]) be used as a tunable.
///
/// Created by `AsFunctionTunable::ok`. Its `TuneFn::execute` always returns `Ok`, and its name
/// is whatever `AsFunctionTunable::name` reports for the wrapped function.
pub struct FunctionTunableResultMap<F: AsFunctionTunable<Marker>, Marker> {
    // The wrapped infallible function or closure.
    func: F,
    // Type-level tag selecting which `AsFunctionTunable<Marker>` impl is used
    // (in this file's impls, a `fn(..) -> Out` signature type); no value is stored.
    _marker: PhantomData<Marker>,
}
66
// SAFETY: `FunctionTunableResultMap` stores only an `F` (required to be `Send` / `Sync` by these
// impls, and by the `AsFunctionTunable` supertraits) and a zero-sized `PhantomData<Marker>`.
// No `Marker` value ever exists inside the struct, so its thread-safety is irrelevant and it is
// sound to ignore it when deciding `Send` / `Sync`.
unsafe impl<F: AsFunctionTunable<Marker> + Send, Marker> Send
    for FunctionTunableResultMap<F, Marker>
{
}
unsafe impl<F: AsFunctionTunable<Marker> + Sync, Marker> Sync
    for FunctionTunableResultMap<F, Marker>
{
}
75
76impl<F: AsFunctionTunable<Marker>, Marker: 'static> TuneFn for FunctionTunableResultMap<F, Marker> {
77 type Inputs = F::Inputs;
78 type Output = F::Output;
79
80 fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
81 Ok(self.func.execute(inputs))
82 }
83
84 fn name(&self) -> &str {
85 self.func.name()
86 }
87}
88
89#[diagnostic::on_unimplemented(
95 message = "`{Self}` is not a valid tunable.",
96 label = "invalid tunable"
97)]
98pub trait AsFunctionTunable<Marker>: Sized + Send + Sync + 'static {
99 type Inputs: Clone;
101 type Output;
103
104 fn execute(&self, inputs: Self::Inputs) -> Self::Output;
106
107 fn ok(self) -> FunctionTunableResultMap<Self, Marker> {
109 FunctionTunableResultMap {
110 func: self,
111 _marker: PhantomData,
112 }
113 }
114
115 fn name(&self) -> &str {
117 core::any::type_name::<Self>()
118 }
119}
120
/// A fallible function or closure that can be used as an autotune candidate.
///
/// This file implements it (via `impl_tunable_result!`) for every `Func` such that
/// `&Func: Fn(I0, .., In) -> Result<Out, Err>` with `Err: Into<String>`, using
/// `Marker = fn(I0, .., In) -> Result<Out, Err>`.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not a valid tunable. For infallible kernels, use `AsFunctionTunable::ok`",
    label = "invalid tunable"
)]
pub trait AsFunctionTunableResult<Marker>: Send + Sync + 'static {
    /// The function's arguments packed into one value. For this file's impls it is `()` for no
    /// arguments, the bare argument type for one, and a tuple otherwise.
    type Inputs: Clone;
    /// The success value of the function's `Result`.
    type Output;

    /// Calls the function with the unpacked `inputs`. Any error is returned as a `String` message.
    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, String>;
}
139
// Implements `AsFunctionTunable` for every `Func` with `&Func: Fn($($params),*) -> Out`, for the
// given list of argument type-parameter identifiers (one identifier per function argument).
//
// - The marker `fn($($params),*) -> Out` differs for each arity/signature, so the generated
//   impls are for distinct trait instantiations and can coexist.
// - `Inputs` is `($($params),*)`: `()` for zero arguments, the bare type for one argument (hence
//   `#[allow(unused_parens)]`), and a tuple otherwise.
// - Leading attributes (`$(#[$meta])*`) are forwarded onto the generated impl.
macro_rules! impl_tunable {
    ($(#[$meta:meta])* $($params:ident),*) => {
        #[allow(unused_parens)]
        $(#[$meta])*
        impl<Out: 'static, Func, $($params: Clone + Send + 'static,)*> AsFunctionTunable<fn($($params),*) -> Out> for Func
        where Func: Send + Sync + 'static,
              for<'a> &'a Func: Fn($($params),*) -> Out
        {
            type Inputs = ($($params),*);
            type Output = Out;

            // The type-parameter identifiers double as the argument bindings (`non_snake_case`).
            // `too_many_arguments` covers the nested `call_inner`, which takes the function plus
            // one parameter per argument.
            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($params),*): ($($params),*)) -> Out {
                // `self` is `&Func`, which satisfies `impl Fn(..)` through the `for<'a> &'a Func`
                // bound; the call goes through this generic helper rather than `self(..)`.
                // NOTE(review): presumably a workaround for rustc not resolving a direct call
                // through the higher-ranked `&Func: Fn` bound; confirm before simplifying.
                fn call_inner<Out, $($params,)*>(
                    f: impl Fn($($params,)*) -> Out,
                    $($params: $params,)*
                ) -> Out {
                    f($($params,)*)
                }
                call_inner(self, $($params),*)
            }
        }
    };
}
165
// Implements `AsFunctionTunableResult` for every `Func` with
// `&Func: Fn($($params),*) -> Result<Out, Err>` where `Err: Into<String>`, for the given list of
// argument type-parameter identifiers (one identifier per function argument).
//
// - The marker `fn($($params),*) -> Result<Out, Err>` differs per signature, so the generated
//   impls are for distinct trait instantiations and can coexist.
// - `Inputs` is `($($params),*)`: `()` for zero arguments, the bare type for one argument (hence
//   `#[allow(unused_parens)]`), and a tuple otherwise.
// - Leading attributes (`$(#[$meta])*`) are forwarded onto the generated impl.
macro_rules! impl_tunable_result {
    ($(#[$meta:meta])* $($params:ident),*) => {
        #[allow(unused_parens)]
        $(#[$meta])*
        impl<Out: 'static, Err, Func, $($params: Clone + Send + 'static,)*> AsFunctionTunableResult<fn($($params),*) -> Result<Out, Err>> for Func
        where Func: Send + Sync + 'static,
              for<'a> &'a Func: Fn($($params),*) -> Result<Out, Err>,
              Err: Into<String>
        {
            type Inputs = ($($params),*);
            type Output = Out;

            // The type-parameter identifiers double as the argument bindings (`non_snake_case`).
            // `too_many_arguments` covers the nested `call_inner`, which takes the function plus
            // one parameter per argument.
            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($params),*): ($($params),*)) -> Result<Out, String> {
                // Same helper-call pattern as `impl_tunable!`: `self` (`&Func`) is passed as an
                // `impl Fn(..)`. NOTE(review): presumably a rustc call-resolution workaround.
                fn call_inner<Out, Err, $($params,)*>(
                    f: impl Fn($($params,)*) -> Result<Out, Err>,
                    $($params: $params,)*
                ) -> Result<Out, Err> {
                    f($($params,)*)
                }
                // Convert the function's own error type into the `String` message via
                // `Err: Into<String>`.
                call_inner(self, $($params),*).map_err(Into::into)
            }
        }
    };
}
192
// Instantiate both impl macros for every function arity in the range 0..12. Type-parameter
// identifiers are generated from the base `I` (e.g. `I0`, `I1`, ...). NOTE(review): the
// upper bound is assumed inclusive per `variadics_please::all_tuples`; confirm.
all_tuples!(impl_tunable, 0, 12, I);
all_tuples!(impl_tunable_result, 0, 12, I);