cubecl_runtime/tune/function_tunable.rs

1use core::marker::PhantomData;
2
3use variadics_please::all_tuples;
4
5use super::{AutotuneError, IntoTuneFn, TuneFn};
6
7/// Tunable implemented as a function or closure
8///
9/// # Marker
10/// The marker generic is used to work around limitations in the trait resolver that causes
11/// conflicting implementation errors.
12pub struct FunctionTunable<F: AsFunctionTunableResult<Marker>, Marker> {
13    func: F,
14    _marker: PhantomData<Marker>,
15}
16
17unsafe impl<F: AsFunctionTunableResult<Marker> + Send, Marker> Send for FunctionTunable<F, Marker> {}
18unsafe impl<F: AsFunctionTunableResult<Marker> + Sync, Marker> Sync for FunctionTunable<F, Marker> {}
19
20impl<F: AsFunctionTunableResult<Marker>, Marker: 'static> TuneFn for FunctionTunable<F, Marker> {
21    type Inputs = F::Inputs;
22    type Output = F::Output;
23
24    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
25        self.func.execute(inputs)
26    }
27
28    fn name(&self) -> &str {
29        self.func.name()
30    }
31}
32
33/// Dummy marker for function tunables
34#[doc(hidden)]
35pub struct IsFunction;
36
37impl<F: AsFunctionTunableResult<Marker>, Marker: 'static>
38    IntoTuneFn<F::Inputs, F::Output, (Marker, IsFunction)> for F
39{
40    type Tunable = FunctionTunable<F, Marker>;
41
42    fn into_tunable(self) -> Self::Tunable {
43        FunctionTunable {
44            func: self,
45            _marker: PhantomData,
46        }
47    }
48}
49
50/// Tunable implemented as a function or closure that returns a plain value, wrapped in `Ok`.
51///
52/// # Marker
53/// The marker generic is used to work around limitations in the trait resolver that causes
54/// conflicting implementation errors.
55pub struct FunctionTunableResultMap<F: AsFunctionTunable<Marker>, Marker> {
56    func: F,
57    _marker: PhantomData<Marker>,
58}
59
60unsafe impl<F: AsFunctionTunable<Marker> + Send, Marker> Send
61    for FunctionTunableResultMap<F, Marker>
62{
63}
64unsafe impl<F: AsFunctionTunable<Marker> + Sync, Marker> Sync
65    for FunctionTunableResultMap<F, Marker>
66{
67}
68
69impl<F: AsFunctionTunable<Marker>, Marker: 'static> TuneFn for FunctionTunableResultMap<F, Marker> {
70    type Inputs = F::Inputs;
71    type Output = F::Output;
72
73    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
74        Ok(self.func.execute(inputs))
75    }
76
77    fn name(&self) -> &str {
78        self.func.name()
79    }
80}
81
82/// A function that can be turned into a tunable.
83///
84/// # Marker
85/// The marker generic is used to work around limitations in the trait resolver that causes
86/// conflicting implementation errors.
87#[diagnostic::on_unimplemented(
88    message = "`{Self}` is not a valid tunable.",
89    label = "invalid tunable"
90)]
91pub trait AsFunctionTunable<Marker>: Sized + Send + Sync + 'static {
92    /// Function inputs
93    type Inputs: Clone;
94    /// Function output
95    type Output;
96
97    /// Run a tuneable function
98    fn execute(&self, inputs: Self::Inputs) -> Self::Output;
99
100    /// Wrap infallible tunable with `Ok`
101    fn ok(self) -> FunctionTunableResultMap<Self, Marker> {
102        FunctionTunableResultMap {
103            func: self,
104            _marker: PhantomData,
105        }
106    }
107
108    /// The name of the tuneable function
109    fn name(&self) -> &str {
110        core::any::type_name::<Self>()
111    }
112}
113
114/// An infallible function that can be turned into a tunable.
115///
116/// # Marker
117/// The marker generic is used to work around limitations in the trait resolver that causes
118/// conflicting implementation errors.
119#[diagnostic::on_unimplemented(
120    message = "`{Self}` is not a valid tunable. For infallible kernels, use `AsFunctionTunable::ok`",
121    label = "invalid tunable"
122)]
123pub trait AsFunctionTunableResult<Marker>: Send + Sync + 'static {
124    /// Function inputs
125    type Inputs: Clone;
126    /// Function output
127    type Output;
128
129    /// Run a tuneable function
130    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;
131
132    /// The name of the tuneable function
133    fn name(&self) -> &str {
134        core::any::type_name::<Self>()
135    }
136}
137
/// Implements [`AsFunctionTunable`] for every function or closure whose arguments have the
/// listed types (one identifier per argument) and which returns a plain value `R`.
///
/// The marker is the matching `fn(..) -> R` pointer type, so each arity gets its own
/// non-overlapping implementation. `Inputs` is the argument tuple (a bare type for one
/// argument, `()` for none).
macro_rules! impl_tunable {
    ($(#[$attr:meta])* $($arg:ident),*) => {
        #[allow(unused_parens)]
        $(#[$attr])*
        impl<R, Fun, $($arg,)*> AsFunctionTunable<fn($($arg),*) -> R> for Fun
        where
            R: 'static,
            $($arg: Clone + Send + 'static,)*
            Fun: Send + Sync + 'static,
            for<'a> &'a Fun: Fn($($arg),*) -> R,
        {
            type Inputs = ($($arg),*);
            type Output = R;

            // The argument bindings reuse the (uppercase) type identifiers as variable names.
            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($arg),*): ($($arg),*)) -> R {
                // NOTE(review): the call goes through a generic helper taking `impl Fn`
                // rather than `self(..)` directly; presumably rustc cannot resolve a direct
                // call through the higher-ranked `&Fun: Fn` bound — confirm before inlining.
                fn invoke<R, $($arg,)*>(
                    callee: impl Fn($($arg,)*) -> R,
                    $($arg: $arg,)*
                ) -> R {
                    callee($($arg,)*)
                }
                invoke(self, $($arg),*)
            }
        }
    };
}
163
/// Implements [`AsFunctionTunableResult`] for every function or closure whose arguments
/// have the listed types and which returns `Result<R, E>` with `E: Into<AutotuneError>`.
///
/// The marker is the matching `fn(..) -> Result<R, E>` pointer type, so each arity gets its
/// own non-overlapping implementation. Errors are converted into [`AutotuneError`].
macro_rules! impl_tunable_result {
    ($(#[$attr:meta])* $($arg:ident),*) => {
        #[allow(unused_parens)]
        $(#[$attr])*
        impl<R, E, Fun, $($arg,)*> AsFunctionTunableResult<fn($($arg),*) -> Result<R, E>> for Fun
        where
            R: 'static,
            E: Into<AutotuneError>,
            $($arg: Clone + Send + 'static,)*
            Fun: Send + Sync + 'static,
            for<'a> &'a Fun: Fn($($arg),*) -> Result<R, E>,
        {
            type Inputs = ($($arg),*);
            type Output = R;

            // The argument bindings reuse the (uppercase) type identifiers as variable names.
            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($arg),*): ($($arg),*)) -> Result<R, AutotuneError> {
                // NOTE(review): same indirection as in `impl_tunable!`; presumably required
                // for rustc to use the higher-ranked `&Fun: Fn` bound — confirm before inlining.
                fn invoke<R, E, $($arg,)*>(
                    callee: impl Fn($($arg,)*) -> Result<R, E>,
                    $($arg: $arg,)*
                ) -> Result<R, E> {
                    callee($($arg,)*)
                }
                match invoke(self, $($arg),*) {
                    Ok(value) => Ok(value),
                    Err(err) => Err(err.into()),
                }
            }
        }
    };
}
190
191all_tuples!(impl_tunable, 0, 12, I);
192all_tuples!(impl_tunable_result, 0, 12, I);