// cubecl_runtime/tune/function_tunable.rs

1use core::marker::PhantomData;
2
3use variadics_please::all_tuples;
4
5use super::{AutotuneError, IntoTunable, Tunable};
6
7/// Tunable implemented as a function or closure
8///
9/// # Marker
10/// The marker generic is used to work around limitations in the trait resolver that causes
11/// conflicting implementation errors.
12pub struct FunctionTunable<F: AsFunctionTunableResult<Marker>, Marker> {
13    func: F,
14    _marker: PhantomData<Marker>,
15}
16
17unsafe impl<F: AsFunctionTunableResult<Marker> + Send, Marker> Send for FunctionTunable<F, Marker> {}
18unsafe impl<F: AsFunctionTunableResult<Marker> + Sync, Marker> Sync for FunctionTunable<F, Marker> {}
19
20impl<F: AsFunctionTunableResult<Marker>, Marker: 'static> Tunable for FunctionTunable<F, Marker> {
21    type Inputs = F::Inputs;
22    type Output = F::Output;
23
24    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
25        self.func.execute(inputs)
26    }
27}
28
/// Zero-sized marker type that tags the function-based `IntoTunable` implementation.
///
/// It is only ever used as a type (as part of the `(Marker, IsFunction)` marker tuple) and
/// carries no data.
#[doc(hidden)]
pub struct IsFunction;
32
33impl<F: AsFunctionTunableResult<Marker>, Marker: 'static>
34    IntoTunable<F::Inputs, F::Output, (Marker, IsFunction)> for F
35{
36    type Tunable = FunctionTunable<F, Marker>;
37
38    fn into_tunable(self) -> Self::Tunable {
39        FunctionTunable {
40            func: self,
41            _marker: PhantomData,
42        }
43    }
44}
45
46/// Tunable implemented as a function or closure that returns a plain value, wrapped in `Ok`.
47///
48/// # Marker
49/// The marker generic is used to work around limitations in the trait resolver that causes
50/// conflicting implementation errors.
51pub struct FunctionTunableResultMap<F: AsFunctionTunable<Marker>, Marker> {
52    func: F,
53    _marker: PhantomData<Marker>,
54}
55
56unsafe impl<F: AsFunctionTunable<Marker> + Send, Marker> Send
57    for FunctionTunableResultMap<F, Marker>
58{
59}
60unsafe impl<F: AsFunctionTunable<Marker> + Sync, Marker> Sync
61    for FunctionTunableResultMap<F, Marker>
62{
63}
64
65impl<F: AsFunctionTunable<Marker>, Marker: 'static> Tunable
66    for FunctionTunableResultMap<F, Marker>
67{
68    type Inputs = F::Inputs;
69    type Output = F::Output;
70
71    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
72        Ok(self.func.execute(inputs))
73    }
74}
75
76/// A function that can be turned into a tunable.
77///
78/// # Marker
79/// The marker generic is used to work around limitations in the trait resolver that causes
80/// conflicting implementation errors.
81#[diagnostic::on_unimplemented(
82    message = "`{Self}` is not a valid tunable.",
83    label = "invalid tunable"
84)]
85pub trait AsFunctionTunable<Marker>: Sized + Send + Sync + 'static {
86    /// Function inputs
87    type Inputs: Clone;
88    /// Function output
89    type Output;
90
91    /// Run a tuneable function
92    fn execute(&self, inputs: Self::Inputs) -> Self::Output;
93
94    /// Wrap infallible tunable with `Ok`
95    fn ok(self) -> FunctionTunableResultMap<Self, Marker> {
96        FunctionTunableResultMap {
97            func: self,
98            _marker: PhantomData,
99        }
100    }
101
102    /// The name of the tuneable function
103    fn name(&self) -> &str {
104        core::any::type_name::<Self>()
105    }
106}
107
108/// An infallible function that can be turned into a tunable.
109///
110/// # Marker
111/// The marker generic is used to work around limitations in the trait resolver that causes
112/// conflicting implementation errors.
113#[diagnostic::on_unimplemented(
114    message = "`{Self}` is not a valid tunable. For infallible kernels, use `AsFunctionTunable::ok`",
115    label = "invalid tunable"
116)]
117pub trait AsFunctionTunableResult<Marker>: Send + Sync + 'static {
118    /// Function inputs
119    type Inputs: Clone;
120    /// Function output
121    type Output;
122
123    /// Run a tuneable function
124    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;
125
126    /// The name of the tuneable function
127    fn name(&self) -> &str {
128        core::any::type_name::<Self>()
129    }
130}
131
// Implements `AsFunctionTunable` for every `Func` whose shared reference is callable with the
// listed parameter types and returns `Out`.
//
// The trait's marker is the function-pointer type `fn(params...) -> Out`, so each arity (and
// each signature) gets its own, non-overlapping implementation. Any attributes placed before
// the parameter list are forwarded onto the generated `impl`.
macro_rules! impl_tunable {
    ($(#[$attr:meta])* $($arg:ident),*) => {
        // With a single parameter `($arg)` is just a parenthesised type/pattern, not a tuple.
        #[allow(unused_parens)]
        $(#[$attr])*
        impl<Out, Func, $($arg,)*> AsFunctionTunable<fn($($arg),*) -> Out> for Func
        where
            Out: 'static,
            $($arg: Clone + Send + 'static,)*
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($($arg),*) -> Out,
        {
            type Inputs = ($($arg),*);
            type Output = Out;

            // The parameter identifiers double as variable names, hence `non_snake_case`.
            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($arg),*): ($($arg),*)) -> Out {
                // `self` is a `&Func`; handing it by value to a helper that takes `impl Fn`
                // makes the call go through the `for<'a> &'a Func: Fn(..)` bound above.
                fn invoke<Out, $($arg,)*>(
                    callable: impl Fn($($arg,)*) -> Out,
                    $($arg: $arg,)*
                ) -> Out {
                    callable($($arg,)*)
                }
                invoke(self, $($arg),*)
            }
        }
    };
}
157
// Implements `AsFunctionTunableResult` for every `Func` whose shared reference is callable
// with the listed parameter types and returns `Result<Out, Err>`, where `Err` converts into
// `AutotuneError`.
//
// The trait's marker is the function-pointer type `fn(params...) -> Result<Out, Err>`, so each
// signature gets its own, non-overlapping implementation. Any attributes placed before the
// parameter list are forwarded onto the generated `impl`.
macro_rules! impl_tunable_result {
    ($(#[$attr:meta])* $($arg:ident),*) => {
        // With a single parameter `($arg)` is just a parenthesised type/pattern, not a tuple.
        #[allow(unused_parens)]
        $(#[$attr])*
        impl<Out, Err, Func, $($arg,)*>
            AsFunctionTunableResult<fn($($arg),*) -> Result<Out, Err>> for Func
        where
            Out: 'static,
            $($arg: Clone + Send + 'static,)*
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($($arg),*) -> Result<Out, Err>,
            Err: Into<AutotuneError>,
        {
            type Inputs = ($($arg),*);
            type Output = Out;

            // The parameter identifiers double as variable names, hence `non_snake_case`.
            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($arg),*): ($($arg),*)) -> Result<Out, AutotuneError> {
                // `self` is a `&Func`; handing it by value to a helper that takes `impl Fn`
                // makes the call go through the `for<'a> &'a Func: Fn(..)` bound above.
                fn invoke<Out, Err, $($arg,)*>(
                    callable: impl Fn($($arg,)*) -> Result<Out, Err>,
                    $($arg: $arg,)*
                ) -> Result<Out, Err> {
                    callable($($arg,)*)
                }
                let outcome = invoke(self, $($arg),*);
                // Convert the function's own error type into `AutotuneError`.
                outcome.map_err(Into::into)
            }
        }
    };
}
184
185all_tuples!(impl_tunable, 0, 12, I);
186all_tuples!(impl_tunable_result, 0, 12, I);