use alloc::{borrow::Cow, boxed::Box};
use bevy_platform::sync::Arc;
use core::fmt::{Debug, Formatter};
use crate::func::{
args::{ArgCount, ArgList},
dynamic_function_internal::DynamicFunctionInternal,
DynamicFunction, FunctionInfo, FunctionOverloadError, FunctionResult, IntoFunctionMut,
};
/// A boxed, type-erased mutable closure used as the storage type for [`DynamicFunctionMut`].
///
/// The closure takes an [`ArgList`] and returns a [`FunctionResult`] tied to that same
/// (higher-ranked) lifetime `'a`. Anything it captures must outlive `'env`.
type BoxFnMut<'env> = Box<dyn for<'a> FnMut(ArgList<'a>) -> FunctionResult<'a> + 'env>;
/// A dynamic, type-erased function whose underlying closure(s) may mutate captured state.
///
/// The function is invoked through [`call`](Self::call), which takes `&mut self`, or through
/// [`call_once`](Self::call_once), which consumes it. Additional signatures can be merged in
/// with [`with_overload`](Self::with_overload) or
/// [`try_with_overload`](Self::try_with_overload).
///
/// `'env` bounds whatever the wrapped closures capture. This lets them borrow local state
/// mutably; the tests in this module capture a local `total` counter this way.
///
/// A [`DynamicFunction`] can be converted into this type via [`From`].
pub struct DynamicFunctionMut<'env> {
    // Boxed closure(s) plus the function's metadata. Argument validation and overload
    // selection are delegated to this value (see `call`).
    internal: DynamicFunctionInternal<BoxFnMut<'env>>,
}
impl<'env> DynamicFunctionMut<'env> {
pub fn new<F: for<'a> FnMut(ArgList<'a>) -> FunctionResult<'a> + 'env>(
func: F,
info: impl TryInto<FunctionInfo, Error: Debug>,
) -> Self {
Self {
internal: DynamicFunctionInternal::new(Box::new(func), info.try_into().unwrap()),
}
}
pub fn with_name(mut self, name: impl Into<Cow<'static, str>>) -> Self {
self.internal = self.internal.with_name(name);
self
}
pub fn with_overload<'a, F: IntoFunctionMut<'a, Marker>, Marker>(
self,
function: F,
) -> DynamicFunctionMut<'a>
where
'env: 'a,
{
self.try_with_overload(function).unwrap_or_else(|(_, err)| {
panic!("{}", err);
})
}
pub fn try_with_overload<F: IntoFunctionMut<'env, Marker>, Marker>(
mut self,
function: F,
) -> Result<Self, (Box<Self>, FunctionOverloadError)> {
let function = function.into_function_mut();
match self.internal.merge(function.internal) {
Ok(_) => Ok(self),
Err(err) => Err((Box::new(self), err)),
}
}
pub fn call<'a>(&mut self, args: ArgList<'a>) -> FunctionResult<'a> {
self.internal.validate_args(&args)?;
let func = self.internal.get_mut(&args)?;
func(args)
}
pub fn call_once(mut self, args: ArgList) -> FunctionResult {
self.call(args)
}
pub fn info(&self) -> &FunctionInfo {
self.internal.info()
}
pub fn name(&self) -> Option<&Cow<'static, str>> {
self.internal.name()
}
pub fn is_overloaded(&self) -> bool {
self.internal.is_overloaded()
}
pub fn arg_count(&self) -> ArgCount {
self.internal.arg_count()
}
}
/// Formats as `DynamicFunctionMut(<internal>)`, where `<internal>` is the `Debug` output of
/// the wrapped internal state.
impl Debug for DynamicFunctionMut<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        // `write!` formats `internal` with a fresh set of default options (the outer
        // formatter's flags are not forwarded). This matches the established output.
        let internal = &self.internal;
        write!(f, "DynamicFunctionMut({internal:?})")
    }
}
/// Converts a [`DynamicFunction`] into a [`DynamicFunctionMut`].
///
/// The shared `Fn` closure(s) of the source are re-wrapped as boxed `FnMut` closures (via
/// `map_functions` and `arc_to_box`). The function's metadata is carried over along with the
/// internal state.
impl<'env> From<DynamicFunction<'env>> for DynamicFunctionMut<'env> {
    #[inline]
    fn from(function: DynamicFunction<'env>) -> Self {
        let internal = function.internal.map_functions(arc_to_box);
        Self { internal }
    }
}
/// Adapts a shared, thread-safe `Fn` closure into a [`BoxFnMut`].
///
/// Every `Fn` can act as an `FnMut`, so a boxed closure that owns the `Arc` and forwards each
/// call is sufficient. The `Send + Sync` bounds are not carried into the result type.
fn arc_to_box<'env>(
    shared: Arc<dyn for<'a> Fn(ArgList<'a>) -> FunctionResult<'a> + Send + Sync + 'env>,
) -> BoxFnMut<'env> {
    // Kept inline in `Box::new` so the closure's higher-ranked signature is inferred from
    // the `BoxFnMut` return type.
    Box::new(move |arg_list| shared(arg_list))
}
/// Identity conversion.
///
/// An existing [`DynamicFunctionMut`] can be passed anywhere an [`IntoFunctionMut`] is
/// accepted, e.g. as the argument to `with_overload`.
impl<'env> IntoFunctionMut<'env, ()> for DynamicFunctionMut<'env> {
    #[inline]
    fn into_function_mut(self) -> Self {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::func::{FunctionError, IntoReturn, SignatureInfo};
    use alloc::vec;
    use core::ops::Add;

    /// `with_name` sets a name on a function that previously had none.
    #[test]
    fn should_overwrite_function_name() {
        let mut total = 0;
        // A closure converted via `IntoFunctionMut` starts out unnamed.
        let func = (|a: i32, b: i32| total = a + b).into_function_mut();
        assert!(func.name().is_none());
        let func = func.with_name("my_function");
        assert_eq!(func.name().unwrap(), "my_function");
    }

    /// Compile-time check: a `DynamicFunctionMut` satisfies `IntoFunctionMut` itself, so it
    /// can be passed to generic code expecting `IntoFunctionMut`.
    #[test]
    fn should_convert_dynamic_function_mut_with_into_function() {
        fn make_closure<'env, F: IntoFunctionMut<'env, M>, M>(f: F) -> DynamicFunctionMut<'env> {
            f.into_function_mut()
        }
        let mut total = 0;
        let closure: DynamicFunctionMut = make_closure(|a: i32, b: i32| total = a + b);
        let _: DynamicFunctionMut = make_closure(closure);
    }

    /// Both `call` and `call_once` report `ArgCountMismatch` when too few arguments are given.
    #[test]
    fn should_return_error_on_arg_count_mismatch() {
        let mut total = 0;
        let mut func = (|a: i32, b: i32| total = a + b).into_function_mut();
        // Only one of the two required arguments is supplied.
        let args = ArgList::default().with_owned(25_i32);
        let error = func.call(args).unwrap_err();
        assert_eq!(
            error,
            FunctionError::ArgCountMismatch {
                expected: ArgCount::new(2).unwrap(),
                received: 1
            }
        );
        // `call_once` consumes `func`, so it has to be the last use.
        let args = ArgList::default().with_owned(25_i32);
        let error = func.call_once(args).unwrap_err();
        assert_eq!(
            error,
            FunctionError::ArgCountMismatch {
                expected: ArgCount::new(2).unwrap(),
                received: 1
            }
        );
    }

    /// A hand-written closure can serve several signatures when created with `new` and a
    /// list of `SignatureInfo`s. It dispatches on the runtime type of its argument.
    #[test]
    fn should_allow_creating_manual_generic_dynamic_function_mut() {
        let mut total = 0_i32;
        let func = DynamicFunctionMut::new(
            |mut args| {
                let value = args.take_arg()?;
                // Branch on the concrete argument type: `i32` is added directly, anything
                // else is taken as `i16` and widened.
                if value.is::<i32>() {
                    let value = value.take::<i32>()?;
                    total += value;
                } else {
                    let value = value.take::<i16>()?;
                    total += value as i32;
                }
                Ok(().into_return())
            },
            vec![
                SignatureInfo::named("add::<i32>").with_arg::<i32>("value"),
                SignatureInfo::named("add::<i16>").with_arg::<i16>("value"),
            ],
        );
        // Without an explicit name, the first signature's name is reported.
        assert_eq!(func.name().unwrap(), "add::<i32>");
        let mut func = func.with_name("add");
        assert_eq!(func.name().unwrap(), "add");
        let args = ArgList::default().with_owned(25_i32);
        func.call(args).unwrap();
        let args = ArgList::default().with_owned(75_i16);
        func.call(args).unwrap();
        // The closure mutably borrows `total`; drop `func` to end that borrow before reading.
        drop(func);
        assert_eq!(total, 100);
    }

    /// `with_overload` merges a second signature, and `call` picks the overload whose
    /// argument types match.
    #[test]
    fn should_allow_function_overloading() {
        fn add<T: Add<Output = T>>(a: T, b: T) -> T {
            a + b
        }
        let mut func = add::<i32>.into_function_mut().with_overload(add::<f32>);
        // Dispatches to the `i32` overload.
        let args = ArgList::default().with_owned(25_i32).with_owned(75_i32);
        let result = func.call(args).unwrap().unwrap_owned();
        assert_eq!(result.try_take::<i32>().unwrap(), 100);
        // Dispatches to the `f32` overload.
        let args = ArgList::default().with_owned(25.0_f32).with_owned(75.0_f32);
        let result = func.call(args).unwrap().unwrap_owned();
        assert_eq!(result.try_take::<f32>().unwrap(), 100.0);
    }
}