cubecl_runtime/tune/
function_tunable.rs

1use super::{AutotuneError, IntoTuneFn, TuneFn};
2use alloc::string::String;
3use core::marker::PhantomData;
4use variadics_please::all_tuples;
5
6/// Tunable implemented as a function or closure
7///
8/// # Marker
9/// The marker generic is used to work around limitations in the trait resolver that causes
10/// conflicting implementation errors.
11pub struct FunctionTunable<F: AsFunctionTunableResult<Marker>, Marker> {
12    name: String,
13    func: F,
14    _marker: PhantomData<Marker>,
15}
16
17unsafe impl<F: AsFunctionTunableResult<Marker> + Send, Marker> Send for FunctionTunable<F, Marker> {}
18unsafe impl<F: AsFunctionTunableResult<Marker> + Sync, Marker> Sync for FunctionTunable<F, Marker> {}
19
20impl<F: AsFunctionTunableResult<Marker>, Marker: 'static> TuneFn for FunctionTunable<F, Marker> {
21    type Inputs = F::Inputs;
22    type Output = F::Output;
23
24    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
25        match self.func.execute(inputs) {
26            Ok(val) => Ok(val),
27            Err(err) => Err(AutotuneError::Unknown {
28                name: self.name.clone(),
29                err,
30            }),
31        }
32    }
33
34    fn name(&self) -> &str {
35        &self.name
36    }
37}
38
/// Zero-sized tag used in the marker tuple `(Marker, IsFunction)` of the blanket
/// [`IntoTuneFn`] implementation for function tunables.
///
/// It carries no data; it only keeps that implementation's marker type distinct.
#[doc(hidden)]
pub struct IsFunction;
42
43impl<F: AsFunctionTunableResult<Marker>, Marker: 'static>
44    IntoTuneFn<F::Inputs, F::Output, (Marker, IsFunction)> for F
45{
46    type Tunable = FunctionTunable<F, Marker>;
47
48    fn into_tunable(self, name: String) -> Self::Tunable {
49        FunctionTunable {
50            func: self,
51            name,
52            _marker: PhantomData,
53        }
54    }
55}
56
57/// Tunable implemented as a function or closure that returns a plain value, wrapped in `Ok`.
58///
59/// # Marker
60/// The marker generic is used to work around limitations in the trait resolver that causes
61/// conflicting implementation errors.
62pub struct FunctionTunableResultMap<F: AsFunctionTunable<Marker>, Marker> {
63    func: F,
64    _marker: PhantomData<Marker>,
65}
66
67unsafe impl<F: AsFunctionTunable<Marker> + Send, Marker> Send
68    for FunctionTunableResultMap<F, Marker>
69{
70}
71unsafe impl<F: AsFunctionTunable<Marker> + Sync, Marker> Sync
72    for FunctionTunableResultMap<F, Marker>
73{
74}
75
76impl<F: AsFunctionTunable<Marker>, Marker: 'static> TuneFn for FunctionTunableResultMap<F, Marker> {
77    type Inputs = F::Inputs;
78    type Output = F::Output;
79
80    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
81        Ok(self.func.execute(inputs))
82    }
83
84    fn name(&self) -> &str {
85        self.func.name()
86    }
87}
88
89/// A function that can be turned into a tunable.
90///
91/// # Marker
92/// The marker generic is used to work around limitations in the trait resolver that causes
93/// conflicting implementation errors.
94#[diagnostic::on_unimplemented(
95    message = "`{Self}` is not a valid tunable.",
96    label = "invalid tunable"
97)]
98pub trait AsFunctionTunable<Marker>: Sized + Send + Sync + 'static {
99    /// Function inputs
100    type Inputs: Clone;
101    /// Function output
102    type Output;
103
104    /// Run a tuneable function
105    fn execute(&self, inputs: Self::Inputs) -> Self::Output;
106
107    /// Wrap infallible tunable with `Ok`
108    fn ok(self) -> FunctionTunableResultMap<Self, Marker> {
109        FunctionTunableResultMap {
110            func: self,
111            _marker: PhantomData,
112        }
113    }
114
115    /// The name of the tuneable function
116    fn name(&self) -> &str {
117        core::any::type_name::<Self>()
118    }
119}
120
/// A fallible function or closure that can be turned into a tunable.
///
/// [`execute`](AsFunctionTunableResult::execute) reports failure as an `Err` holding a message.
/// Infallible functions implement [`AsFunctionTunable`] instead and can be adapted with
/// [`AsFunctionTunable::ok`].
///
/// # Marker
/// The marker generic is used to work around limitations in the trait resolver that cause
/// conflicting implementation errors.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not a valid tunable. For infallible kernels, use `AsFunctionTunable::ok`",
    label = "invalid tunable"
)]
pub trait AsFunctionTunableResult<Marker>: Send + Sync + 'static {
    /// Function inputs.
    ///
    /// For the function impls in this module this is the argument list as a tuple (a single
    /// argument is used unwrapped).
    type Inputs: Clone;
    /// Function output on success.
    type Output;

    /// Run the tunable function on `inputs`.
    ///
    /// # Errors
    /// Returns `Err` with a description of the failure when the function fails.
    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, String>;
}
139
// Implements `AsFunctionTunable` for every `Func` whose shared reference is callable with the
// listed argument types. The marker is the matching function-pointer type, so the impls
// generated for different arities never overlap.
macro_rules! impl_tunable {
    ($(#[$attr:meta])* $($arg:ident),*) => {
        #[allow(unused_parens)]
        $(#[$attr])*
        impl<Ret, Func, $($arg,)*> AsFunctionTunable<fn($($arg),*) -> Ret> for Func
        where
            Ret: 'static,
            $($arg: Clone + Send + 'static,)*
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($($arg),*) -> Ret,
        {
            // One argument stays unwrapped (`(I0)` is just `I0`); several form a tuple.
            type Inputs = ($($arg),*);
            type Output = Ret;

            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($arg),*): ($($arg),*)) -> Ret {
                // The call goes through a generic helper instead of invoking `self(...)` directly.
                // NOTE(review): presumably required for rustc to treat `&Func` as a plain `Fn`
                // here; confirm before inlining.
                fn apply<R, $($arg,)*>(callee: impl Fn($($arg,)*) -> R, $($arg: $arg,)*) -> R {
                    callee($($arg,)*)
                }
                apply(self, $($arg),*)
            }
        }
    };
}
165
// Implements `AsFunctionTunableResult` for every `Func` whose shared reference is callable with
// the listed argument types and returns `Result<Ret, E>` with `E: Into<String>`. Errors are
// converted to `String`. The function-pointer marker keeps the arities from overlapping.
macro_rules! impl_tunable_result {
    ($(#[$attr:meta])* $($arg:ident),*) => {
        #[allow(unused_parens)]
        $(#[$attr])*
        impl<Ret, E, Func, $($arg,)*> AsFunctionTunableResult<fn($($arg),*) -> Result<Ret, E>> for Func
        where
            Ret: 'static,
            E: Into<String>,
            $($arg: Clone + Send + 'static,)*
            Func: Send + Sync + 'static,
            for<'a> &'a Func: Fn($($arg),*) -> Result<Ret, E>,
        {
            // One argument stays unwrapped (`(I0)` is just `I0`); several form a tuple.
            type Inputs = ($($arg),*);
            type Output = Ret;

            #[allow(non_snake_case, clippy::too_many_arguments)]
            #[inline]
            fn execute(&self, ($($arg),*): ($($arg),*)) -> Result<Ret, String> {
                // The call goes through a generic helper instead of invoking `self(...)` directly.
                // NOTE(review): presumably required for rustc to treat `&Func` as a plain `Fn`
                // here; confirm before inlining.
                fn apply<R, E, $($arg,)*>(
                    callee: impl Fn($($arg,)*) -> Result<R, E>,
                    $($arg: $arg,)*
                ) -> Result<R, E> {
                    callee($($arg,)*)
                }
                apply(self, $($arg),*).map_err(|err| err.into())
            }
        }
    };
}
192
193all_tuples!(impl_tunable, 0, 12, I);
194all_tuples!(impl_tunable_result, 0, 12, I);