use std::convert::{TryFrom, TryInto};
use std::{
os::raw::{c_int, c_void},
ptr, slice,
};
use super::{function::Result, Function};
use crate::errors::Error;
pub use tvm_sys::{ffi, ArgValue, RetValue};
/// Describes how a callable's input and output are converted from and to TVM values.
///
/// `I` is the typed input built from the raw argument list and `O` is the
/// typed output that is turned into a [`RetValue`].
pub trait Typed<I, O> {
/// Converts the raw argument list `i` into the typed input `I`, or reports an error.
fn args(i: Vec<ArgValue<'static>>) -> Result<I>;
/// Converts the typed output `o` into a [`RetValue`], or reports an error.
fn ret(o: O) -> Result<RetValue>;
}
pub trait ToFunction<I, O>: Sized {
type Handle;
fn into_raw(self) -> *mut Self::Handle;
fn call(handle: *mut Self::Handle, args: Vec<ArgValue<'static>>) -> Result<RetValue>
where
Self: Typed<I, O>;
fn drop(handle: *mut Self::Handle);
fn to_function(self) -> Function
where
Self: Typed<I, O>,
{
let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
let resource_handle = self.into_raw();
check_call!(ffi::TVMFuncCreateFromCFunc(
Some(Self::tvm_callback),
resource_handle as *mut _,
None, &mut fhandle as *mut ffi::TVMFunctionHandle,
));
Function::new(fhandle)
}
unsafe extern "C" fn tvm_callback(
args: *mut ffi::TVMValue,
type_codes: *mut c_int,
num_args: c_int,
ret: ffi::TVMRetValueHandle,
resource_handle: *mut c_void,
) -> c_int
where
Self: Typed<I, O>,
{
#![allow(unused_assignments, unused_unsafe)]
let result = std::panic::catch_unwind(|| {
let len = num_args as usize;
let args_list = slice::from_raw_parts_mut(args, len);
let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
let mut local_args: Vec<ArgValue> = Vec::new();
let mut value = ffi::TVMValue { v_int64: 0 };
let mut tcode = 0;
let resource_handle = resource_handle as *mut Self::Handle;
for i in 0..len {
value = args_list[i];
tcode = type_codes_list[i];
if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int
|| tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int
|| tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int
|| tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int
|| tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int
{
check_call!(ffi::TVMCbArgToReturn(
&mut value as *mut _,
&mut tcode as *mut _
));
}
let arg_value = ArgValue::from_tvm_value(value, tcode as u32);
local_args.push(arg_value);
}
let rv = match Self::call(resource_handle, local_args) {
Ok(v) => v,
Err(msg) => {
return Err(msg);
}
};
let (mut ret_val, ret_tcode) = rv.to_tvm_value();
let mut ret_type_code = ret_tcode as c_int;
check_call!(ffi::TVMCFuncSetReturn(
ret,
&mut ret_val as *mut _,
&mut ret_type_code as *mut _,
1 as c_int
));
Ok(())
});
match result {
Err(_) => {
crate::set_last_error(&Error::Panic);
return -1;
}
Ok(inner_res) => match inner_res {
Err(err) => {
crate::set_last_error(&err);
return -1;
}
Ok(()) => return 0,
},
}
}
unsafe extern "C" fn tvm_finalizer(fhandle: *mut c_void) {
let handle = std::mem::transmute(fhandle);
Self::drop(handle)
}
}
/// Pass-through [`Typed`] implementation for plain function pointers that
/// already take the raw argument list and return a [`RetValue`] themselves.
impl Typed<Vec<ArgValue<'static>>, RetValue> for fn(Vec<ArgValue<'static>>) -> Result<RetValue> {
    /// Returns the raw argument list unchanged.
    fn args(raw_args: Vec<ArgValue<'static>>) -> Result<Vec<ArgValue<'static>>> {
        let passthrough = raw_args;
        Ok(passthrough)
    }

    /// Returns the value unchanged; it is already a [`RetValue`].
    fn ret(value: RetValue) -> Result<RetValue> {
        let passthrough = value;
        Ok(passthrough)
    }
}
/// [`ToFunction`] implementation for plain function pointers that operate
/// directly on raw TVM arguments.
impl ToFunction<Vec<ArgValue<'static>>, RetValue>
    for fn(Vec<ArgValue<'static>>) -> Result<RetValue>
{
    type Handle = fn(Vec<ArgValue<'static>>) -> Result<RetValue>;

    /// Moves the function pointer into a heap allocation and returns the raw
    /// pointer to it.
    fn into_raw(self) -> *mut Self::Handle {
        Box::into_raw(Box::new(self))
    }

    /// Calls the function pointer stored behind `handle` with `args`.
    fn call(handle: *mut Self::Handle, args: Vec<ArgValue<'static>>) -> Result<RetValue> {
        // SAFETY: relies on `handle` being a live pointer produced by `into_raw`.
        unsafe { (*handle)(args) }
    }

    /// Frees the allocation created by `into_raw`.
    ///
    /// The function itself needs no clean-up, but `into_raw` places the
    /// pointer in a `Box`; that box has to be reclaimed or it leaks. A null
    /// `handle` is ignored.
    fn drop(handle: *mut Self::Handle) {
        if handle.is_null() {
            return;
        }
        // SAFETY: a non-null `handle` is expected to come from `into_raw`,
        // i.e. from `Box::into_raw`, and must not be used after this call.
        let boxed = unsafe { Box::from_raw(handle) };
        std::mem::drop(boxed)
    }
}
// Generates `Typed` and `ToFunction` implementations for every `Fn` whose
// parameters are the listed types.
//
// `$len` is the argument count the generated `args` function checks against.
// Each `$t` identifier is used both as a generic type parameter and as the
// name of the local variable holding the converted argument (hence the
// `non_snake_case` allowances below).
macro_rules! impl_typed_and_to_function {
($len:literal; $($t:ident),*) => {
// `Typed`: every argument type must be convertible from an `ArgValue` and the
// output convertible into a `RetValue`, with all conversion errors
// convertible into `Error`.
impl<F, Out, $($t),*> Typed<($($t,)*), Out> for F
where
F: Fn($($t),*) -> Out,
Out: TryInto<RetValue>,
Error: From<Out::Error>,
$( $t: TryFrom<ArgValue<'static>>,
Error: From<$t::Error>, )*
{
// `unused_variables` / `unused_mut` cover the zero-argument expansion, where
// nothing is pulled out of the iterator.
#[allow(non_snake_case, unused_variables, unused_mut)]
fn args(args: Vec<ArgValue<'static>>) -> Result<($($t,)*)> {
// Reject calls with the wrong number of arguments before converting any.
if args.len() != $len {
return Err(Error::CallFailed(format!("{} expected {} arguments, got {}.\n",
std::any::type_name::<Self>(),
$len, args.len())))
}
let mut args = args.into_iter();
// One conversion per declared type, in order. The `unwrap` cannot fail
// because the length was checked above; conversion failures propagate via `?`.
$(let $t = args.next().unwrap().try_into()?;)*
Ok(($($t,)*))
}
// Converts the callable's output into a `RetValue`, mapping the conversion
// error into `Error`.
fn ret(out: Out) -> Result<RetValue> {
out.try_into().map_err(|e| e.into())
}
}
// `ToFunction`: the callable is stored type-erased as a boxed `dyn Fn`.
impl<F, $($t,)* Out> ToFunction<($($t,)*), Out> for F
where
F: Fn($($t,)*) -> Out + 'static
{
type Handle = Box<dyn Fn($($t,)*) -> Out + 'static>;
// Boxes twice: the inner box erases `F` into the `Handle` trait object, the
// outer box yields the raw pointer to that handle.
fn into_raw(self) -> *mut Self::Handle {
let ptr: Box<Self::Handle> = Box::new(Box::new(self));
Box::into_raw(ptr)
}
// Converts the raw arguments with `F::args`, invokes the callable behind
// `handle`, and converts its output with `F::ret`.
#[allow(non_snake_case)]
fn call(handle: *mut Self::Handle, args: Vec<ArgValue<'static>>) -> Result<RetValue>
where
F: Typed<($($t,)*), Out>
{
let ($($t,)*) = F::args(args)?;
// Dereferences the raw handle; relies on it being a live pointer from `into_raw`.
let out = unsafe { (*handle)($($t),*) };
F::ret(out)
}
// Rebuilds the outer box from the raw pointer and drops it, which also drops
// the boxed callable inside.
fn drop(ptr: *mut Self::Handle) {
let bx = unsafe { Box::from_raw(ptr) };
std::mem::drop(bx)
}
}
}
}
// Instantiate the implementations for callables taking zero through six
// arguments. The sixth type parameter is `G` rather than `F` because the
// macro already uses `F` as the generic parameter for the callable itself.
impl_typed_and_to_function!(0;);
impl_typed_and_to_function!(1; A);
impl_typed_and_to_function!(2; A, B);
impl_typed_and_to_function!(3; A, B, C);
impl_typed_and_to_function!(4; A, B, C, D);
impl_typed_and_to_function!(5; A, B, C, D, E);
impl_typed_and_to_function!(6; A, B, C, D, E, G);
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns `f` into its raw handle and invokes it directly with `args`,
    /// without going through a TVM `Function`.
    fn call<F, I, O>(f: F, args: Vec<ArgValue<'static>>) -> Result<RetValue>
    where
        F: ToFunction<I, O> + Typed<I, O>,
    {
        let handle = f.into_raw();
        F::call(handle, args)
    }

    /// A zero-argument function converts to a `Function`, returns its value
    /// when called without arguments, and rejects a surplus argument.
    #[test]
    fn test_to_function0() {
        fn zero() -> i32 {
            10
        }

        let _ = zero.to_function();

        let returned = call(zero, vec![]).unwrap();
        let as_int = i32::try_from(returned).unwrap();
        assert_eq!(as_int, 10);

        let failure = call(zero, vec![1.into()]).unwrap_err();
        assert!(matches!(failure, Error::CallFailed(..)));
    }

    /// A two-argument function receives both converted arguments in order.
    #[test]
    fn test_to_function2() {
        fn two_arg(i: i32, j: i32) -> i32 {
            i + j
        }

        let returned = call(two_arg, vec![3.into(), 4.into()]).unwrap();
        let sum = i32::try_from(returned).unwrap();
        assert_eq!(sum, 7);
    }
}