use super::{AutotuneError, IntoTuneFn, TuneFn};
use alloc::string::String;
use core::marker::PhantomData;
use variadics_please::all_tuples;
pub struct FunctionTunable<F: AsFunctionTunableResult<Marker>, Marker> {
name: String,
func: F,
_marker: PhantomData<Marker>,
}
unsafe impl<F: AsFunctionTunableResult<Marker> + Send, Marker> Send for FunctionTunable<F, Marker> {}
unsafe impl<F: AsFunctionTunableResult<Marker> + Sync, Marker> Sync for FunctionTunable<F, Marker> {}
/// Runs a fallible function as a [`TuneFn`], labelling failures with the tunable's name.
impl<F: AsFunctionTunableResult<Marker>, Marker: 'static> TuneFn for FunctionTunable<F, Marker> {
    type Inputs = F::Inputs;
    type Output = F::Output;

    /// Calls the wrapped function. A `String` error is turned into
    /// [`AutotuneError::Unknown`] tagged with a copy of this tunable's name; the
    /// name is cloned only on the error path.
    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
        self.func
            .execute(inputs)
            .map_err(|err| AutotuneError::Unknown {
                name: self.name.clone(),
                err,
            })
    }

    /// Returns the name given when this tunable was created.
    fn name(&self) -> &str {
        &self.name
    }
}
// Type-level tag used in the `(Marker, IsFunction)` marker of the impl below;
// presumably it keeps this blanket impl apart from other `IntoTuneFn` impls in
// the crate.
#[doc(hidden)]
pub struct IsFunction;

/// Any fallible function (see [`AsFunctionTunableResult`]) can be converted into
/// a named [`FunctionTunable`].
impl<F: AsFunctionTunableResult<Marker>, Marker: 'static>
    IntoTuneFn<F::Inputs, F::Output, (Marker, IsFunction)> for F
{
    type Tunable = FunctionTunable<F, Marker>;

    /// Wraps `self` together with `name`.
    fn into_tunable(self, name: String) -> Self::Tunable {
        let func = self;
        FunctionTunable {
            name,
            func,
            _marker: PhantomData,
        }
    }
}
pub struct FunctionTunableResultMap<F: AsFunctionTunable<Marker>, Marker> {
func: F,
_marker: PhantomData<Marker>,
}
unsafe impl<F: AsFunctionTunable<Marker> + Send, Marker> Send
for FunctionTunableResultMap<F, Marker>
{
}
unsafe impl<F: AsFunctionTunable<Marker> + Sync, Marker> Sync
for FunctionTunableResultMap<F, Marker>
{
}
/// Presents an infallible function as a [`TuneFn`]: every execution returns `Ok`.
impl<F: AsFunctionTunable<Marker>, Marker: 'static> TuneFn for FunctionTunableResultMap<F, Marker> {
    type Inputs = F::Inputs;
    type Output = F::Output;

    /// Calls the wrapped function and wraps its result in `Ok`.
    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError> {
        let output = <F as AsFunctionTunable<Marker>>::execute(&self.func, inputs);
        Ok(output)
    }

    /// Delegates to the wrapped function's [`AsFunctionTunable::name`].
    fn name(&self) -> &str {
        <F as AsFunctionTunable<Marker>>::name(&self.func)
    }
}
/// An infallible function that can be used as an autotune candidate.
///
/// The macro-generated impls in this module cover every `Func` for which
/// `&Func: Fn(I0, .., In) -> Out` holds, for zero up to twelve arguments. In
/// those impls `Marker` is the function-pointer type `fn(I0, .., In) -> Out`. It
/// exists only so the blanket impls for different arities do not overlap.
///
/// To use an implementor where a [`TuneFn`] is expected, wrap it with
/// [`AsFunctionTunable::ok`].
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not a valid tunable.",
    label = "invalid tunable"
)]
pub trait AsFunctionTunable<Marker>: Sized + Send + Sync + 'static {
    /// Arguments of the function passed as a single value. In the
    /// macro-generated impls this is a tuple for several arguments, the bare
    /// type for one argument, and `()` for none.
    type Inputs: Clone;
    /// Value returned by the function.
    type Output;

    /// Calls the function with `inputs` unpacked as its arguments.
    fn execute(&self, inputs: Self::Inputs) -> Self::Output;

    /// Wraps `self` in a [`FunctionTunableResultMap`], whose execution always
    /// returns `Ok`.
    ///
    /// This consumes `self`, so discarding the returned wrapper drops the
    /// function; hence `#[must_use]`.
    #[must_use = "`ok` consumes the function and returns a wrapper; the function is dropped if the wrapper is unused"]
    fn ok(self) -> FunctionTunableResultMap<Self, Marker> {
        FunctionTunableResultMap {
            func: self,
            _marker: PhantomData,
        }
    }

    /// Name used to identify this tunable.
    ///
    /// Defaults to [`core::any::type_name`] of the implementing type. That string
    /// is intended for diagnostics, and Rust does not guarantee its exact format.
    fn name(&self) -> &str {
        core::any::type_name::<Self>()
    }
}
/// A fallible function that can be used as an autotune candidate.
///
/// The macro-generated impls in this module cover every `Func` for which
/// `&Func: Fn(I0, .., In) -> Result<Out, Err>` with `Err: Into<String>` holds,
/// for zero up to twelve arguments. In those impls `Marker` is the
/// function-pointer type `fn(I0, .., In) -> Result<Out, Err>`, used only to keep
/// the blanket impls for different arities from overlapping.
///
/// Infallible functions should go through [`AsFunctionTunable::ok`] instead,
/// as the diagnostic below suggests.
#[diagnostic::on_unimplemented(
message = "`{Self}` is not a valid tunable. For infallible kernels, use `AsFunctionTunable::ok`",
label = "invalid tunable"
)]
pub trait AsFunctionTunableResult<Marker>: Send + Sync + 'static {
/// Arguments of the function passed as a single value. In the
/// macro-generated impls this is a tuple for several arguments, the bare type
/// for one argument, and `()` for none.
type Inputs: Clone;
/// Value produced on success.
type Output;
/// Calls the function with `inputs` unpacked as its arguments, reporting any
/// failure as a `String` message.
fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, String>;
}
// Implements `AsFunctionTunable` for every `Func` that is callable through a
// shared reference with the argument types `$params`. `$params` is a list of type
// identifiers such as `I0, I1`, supplied by the `all_tuples!` invocation at the
// end of this file. Any `$meta` attributes are forwarded onto the generated impl.
macro_rules! impl_tunable {
($(#[$meta:meta])* $($params:ident),*) => {
// With a single parameter, `($($params),*)` expands to `(I0)`. Those
// parentheses are redundant, so the lint is silenced.
// `Marker = fn(..) -> Out` keeps the impls for different arities from
// overlapping. The `Fn` bound is placed on `&Func` so that `execute(&self, ..)`
// can call the function without taking ownership.
#[allow(unused_parens)]
$(#[$meta])*
impl<Out: 'static, Func, $($params: Clone + Send + 'static,)*> AsFunctionTunable<fn($($params),*) -> Out> for Func
where Func: Send + Sync + 'static,
for<'a> &'a Func: Fn($($params),*) -> Out
{
type Inputs = ($($params),*);
type Output = Out;
// The type identifiers double as binding names for the destructured
// arguments, hence `non_snake_case`. For larger arities `call_inner` exceeds
// clippy's argument-count limit, hence `too_many_arguments`.
#[allow(non_snake_case, clippy::too_many_arguments)]
#[inline]
fn execute(&self, ($($params),*): ($($params),*)) -> Out {
// NOTE(review): the call goes through this generic helper instead of
// `self(..)` directly. This is presumably a workaround for rustc's handling
// of the higher-ranked `&Func: Fn` bound; confirm before simplifying.
fn call_inner<Out, $($params,)*>(
f: impl Fn($($params,)*) -> Out,
$($params: $params,)*
) -> Out {
f($($params,)*)
}
call_inner(self, $($params),*)
}
}
};
}
// Implements `AsFunctionTunableResult` for every `Func` that is callable through
// a shared reference with the argument types `$params` and returns
// `Result<Out, Err>` where `Err: Into<String>`. `$params` is a list of type
// identifiers such as `I0, I1`, supplied by the `all_tuples!` invocation at the
// end of this file. Any `$meta` attributes are forwarded onto the generated impl.
macro_rules! impl_tunable_result {
($(#[$meta:meta])* $($params:ident),*) => {
// With a single parameter, `($($params),*)` expands to `(I0)`. Those
// parentheses are redundant, so the lint is silenced.
// `Marker = fn(..) -> Result<Out, Err>` keeps the impls for different
// arities from overlapping. The `Fn` bound is placed on `&Func` so that
// `execute(&self, ..)` can call the function without taking ownership.
#[allow(unused_parens)]
$(#[$meta])*
impl<Out: 'static, Err, Func, $($params: Clone + Send + 'static,)*> AsFunctionTunableResult<fn($($params),*) -> Result<Out, Err>> for Func
where Func: Send + Sync + 'static,
for<'a> &'a Func: Fn($($params),*) -> Result<Out, Err>,
Err: Into<String>
{
type Inputs = ($($params),*);
type Output = Out;
// The type identifiers double as binding names for the destructured
// arguments, hence `non_snake_case`. For larger arities `call_inner` exceeds
// clippy's argument-count limit, hence `too_many_arguments`.
#[allow(non_snake_case, clippy::too_many_arguments)]
#[inline]
fn execute(&self, ($($params),*): ($($params),*)) -> Result<Out, String> {
// NOTE(review): the call goes through this generic helper instead of
// `self(..)` directly. This is presumably a workaround for rustc's handling
// of the higher-ranked `&Func: Fn` bound; confirm before simplifying.
fn call_inner<Out, Err, $($params,)*>(
f: impl Fn($($params,)*) -> Result<Out, Err>,
$($params: $params,)*
) -> Result<Out, Err> {
f($($params,)*)
}
// The function's own error type is converted into the `String` the
// trait reports.
call_inner(self, $($params),*).map_err(Into::into)
}
}
};
}
// Generate the infallible and fallible tunable impls for functions taking zero
// through twelve arguments. The argument type parameters are named `I0`, `I1`, ...
all_tuples!(impl_tunable, 0, 12, I);
all_tuples!(impl_tunable_result, 0, 12, I);