cubecl_runtime/tune/
function_tunable.rs

1use core::marker::PhantomData;
2
3use variadics_please::all_tuples;
4
5use super::{AutotuneError, IntoTuneFn, TuneFn};
6
7/// Tunable implemented as a function or closure
8///
9/// # Marker
10/// The marker generic is used to work around limitations in the trait resolver that causes
11/// conflicting implementation errors.
12pub struct FunctionTunable<F: AsFunctionTunableResult<Marker>, Marker> {
13    func: F,
14    _marker: PhantomData<Marker>,
15}
16
17unsafe impl<F: AsFunctionTunableResult<Marker> + Send, Marker> Send for FunctionTunable<F, Marker> {}
18unsafe impl<F: AsFunctionTunableResult<Marker> + Sync, Marker> Sync for FunctionTunable<F, Marker> {}
19
20impl<F: AsFunctionTunableResult<Marker>, Marker: 'static> TuneFn for FunctionTunable<F, Marker> {
21    type Inputs = F::Inputs;
22    type Output = F::Output;
23
24    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
25        self.func.execute(inputs)
26    }
27}
28
/// Zero-sized marker type for function tunables.
///
/// It carries no data; it is only used at the type level, as the second element of the
/// `(Marker, IsFunction)` marker tuple on the function-based [`IntoTuneFn`] implementation.
#[doc(hidden)]
pub struct IsFunction;
32
33impl<F: AsFunctionTunableResult<Marker>, Marker: 'static>
34    IntoTuneFn<F::Inputs, F::Output, (Marker, IsFunction)> for F
35{
36    type Tunable = FunctionTunable<F, Marker>;
37
38    fn into_tunable(self) -> Self::Tunable {
39        FunctionTunable {
40            func: self,
41            _marker: PhantomData,
42        }
43    }
44}
45
46/// Tunable implemented as a function or closure that returns a plain value, wrapped in `Ok`.
47///
48/// # Marker
49/// The marker generic is used to work around limitations in the trait resolver that causes
50/// conflicting implementation errors.
51pub struct FunctionTunableResultMap<F: AsFunctionTunable<Marker>, Marker> {
52    func: F,
53    _marker: PhantomData<Marker>,
54}
55
56unsafe impl<F: AsFunctionTunable<Marker> + Send, Marker> Send
57    for FunctionTunableResultMap<F, Marker>
58{
59}
60unsafe impl<F: AsFunctionTunable<Marker> + Sync, Marker> Sync
61    for FunctionTunableResultMap<F, Marker>
62{
63}
64
65impl<F: AsFunctionTunable<Marker>, Marker: 'static> TuneFn for FunctionTunableResultMap<F, Marker> {
66    type Inputs = F::Inputs;
67    type Output = F::Output;
68
69    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
70        Ok(self.func.execute(inputs))
71    }
72}
73
74/// A function that can be turned into a tunable.
75///
76/// # Marker
77/// The marker generic is used to work around limitations in the trait resolver that causes
78/// conflicting implementation errors.
79#[diagnostic::on_unimplemented(
80    message = "`{Self}` is not a valid tunable.",
81    label = "invalid tunable"
82)]
83pub trait AsFunctionTunable<Marker>: Sized + Send + Sync + 'static {
84    /// Function inputs
85    type Inputs: Clone;
86    /// Function output
87    type Output;
88
89    /// Run a tuneable function
90    fn execute(&self, inputs: Self::Inputs) -> Self::Output;
91
92    /// Wrap infallible tunable with `Ok`
93    fn ok(self) -> FunctionTunableResultMap<Self, Marker> {
94        FunctionTunableResultMap {
95            func: self,
96            _marker: PhantomData,
97        }
98    }
99
100    /// The name of the tuneable function
101    fn name(&self) -> &str {
102        core::any::type_name::<Self>()
103    }
104}
105
106/// An infallible function that can be turned into a tunable.
107///
108/// # Marker
109/// The marker generic is used to work around limitations in the trait resolver that causes
110/// conflicting implementation errors.
111#[diagnostic::on_unimplemented(
112    message = "`{Self}` is not a valid tunable. For infallible kernels, use `AsFunctionTunable::ok`",
113    label = "invalid tunable"
114)]
115pub trait AsFunctionTunableResult<Marker>: Send + Sync + 'static {
116    /// Function inputs
117    type Inputs: Clone;
118    /// Function output
119    type Output;
120
121    /// Run a tuneable function
122    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;
123
124    /// The name of the tuneable function
125    fn name(&self) -> &str {
126        core::any::type_name::<Self>()
127    }
128}
129
/// Implements [`AsFunctionTunable`] for every `Func` whose shared reference can be called with
/// the given parameter types.
///
/// The macro takes optional attributes followed by a comma-separated list of identifiers; each
/// identifier is used both as a generic parameter type and as the name of the matching argument.
/// The marker is the function-pointer type `fn(params..) -> Out`, and `Inputs` is the
/// parenthesized parameter list (a tuple, or the bare type when there is exactly one parameter).
macro_rules! impl_tunable {
    ($(#[$meta:meta])* $($params:ident),*) => {
        #[allow(unused_parens)]
        $(#[$meta])*
        impl<Out, Func, $($params,)*> AsFunctionTunable<fn($($params),*) -> Out> for Func
        where
            Out: 'static,
            $($params: Clone + Send + 'static,)*
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($($params),*) -> Out,
        {
            type Inputs = ($($params),*);
            type Output = Out;

            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($params),*): ($($params),*)) -> Out {
                // `self` is a `&Func`; handing it to a helper that takes `impl Fn(..)` calls it
                // through the `for<'a> &'a Func: Fn(..)` bound declared on this impl.
                // NOTE(review): presumably a direct `self(..)` call does not resolve against
                // that bound - confirm before inlining the helper.
                fn invoke<Ret, $($params,)*>(
                    callee: impl Fn($($params,)*) -> Ret,
                    $($params: $params,)*
                ) -> Ret {
                    callee($($params,)*)
                }
                invoke(self, $($params),*)
            }
        }
    };
}
155
/// Implements [`AsFunctionTunableResult`] for every `Func` whose shared reference can be called
/// with the given parameter types and returns `Result<Out, Err>`.
///
/// The macro takes optional attributes followed by a comma-separated list of identifiers; each
/// identifier is used both as a generic parameter type and as the name of the matching argument.
/// The marker is the function-pointer type `fn(params..) -> Result<Out, Err>`, and the function's
/// error is converted into [`AutotuneError`] through `Into`.
macro_rules! impl_tunable_result {
    ($(#[$meta:meta])* $($params:ident),*) => {
        #[allow(unused_parens)]
        $(#[$meta])*
        impl<Out, Err, Func, $($params,)*>
            AsFunctionTunableResult<fn($($params),*) -> Result<Out, Err>> for Func
        where
            Out: 'static,
            Err: Into<AutotuneError>,
            $($params: Clone + Send + 'static,)*
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($($params),*) -> Result<Out, Err>,
        {
            type Inputs = ($($params),*);
            type Output = Out;

            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($params),*): ($($params),*)) -> Result<Out, AutotuneError> {
                // `self` is a `&Func`; handing it to a helper that takes `impl Fn(..)` calls it
                // through the `for<'a> &'a Func: Fn(..)` bound declared on this impl.
                // NOTE(review): presumably a direct `self(..)` call does not resolve against
                // that bound - confirm before inlining the helper.
                fn invoke<Ret, Failure, $($params,)*>(
                    callee: impl Fn($($params,)*) -> Result<Ret, Failure>,
                    $($params: $params,)*
                ) -> Result<Ret, Failure> {
                    callee($($params,)*)
                }
                let outcome = invoke(self, $($params),*);
                outcome.map_err(Into::<AutotuneError>::into)
            }
        }
    };
}
182
183all_tuples!(impl_tunable, 0, 12, I);
184all_tuples!(impl_tunable_result, 0, 12, I);