use std::convert::{TryFrom, TryInto};
use std::{
ffi::CString,
os::raw::{c_char, c_int},
ptr, str,
};
use crate::errors::Error;
pub use super::to_function::{ToFunction, Typed};
pub use tvm_sys::{ffi, ArgValue, RetValue};
/// Result alias used throughout this module; the error type is the crate's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Wrapper around a raw TVM function handle.
///
/// Values are obtained from the global registry with [`Function::get`], built
/// from a raw handle inside the crate, or converted from `ArgValue` /
/// `RetValue` function-handle variants.
#[derive(Debug, Hash)]
pub struct Function {
/// The raw handle passed to the TVM C API. May be null (see `Function::null`).
pub(crate) handle: ffi::TVMFunctionHandle,
/// `true` only for functions returned by `Function::get`, i.e. looked up by
/// name in the global registry.
is_global: bool,
/// Set to `false` by every constructor in this file and to `true` by `Clone`.
// NOTE(review): this flag is never read in the visible source; its intended
// meaning should be confirmed against the rest of the crate.
from_rust: bool,
}
// SAFETY: NOTE(review): these impls assert that the raw handle may be moved to
// and shared between threads. Nothing in this file establishes that; confirm
// against the thread-safety guarantees of the TVM runtime for function handles.
unsafe impl Send for Function {}
unsafe impl Sync for Function {}
impl Function {
    /// Wraps a raw handle as a non-global `Function`.
    pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
        Function {
            handle,
            is_global: false,
            from_rust: false,
        }
    }

    /// Creates a `Function` whose handle is a null pointer.
    ///
    /// # Safety
    ///
    /// The returned value carries a null handle. Presumably it must not be
    /// passed to [`Function::invoke`] or any other API that hands the handle to
    /// the TVM runtime expecting a live function; the caller is responsible for
    /// upholding that.
    pub unsafe fn null() -> Self {
        Function {
            handle: std::ptr::null_mut(),
            is_global: false,
            from_rust: false,
        }
    }

    /// Looks up a function by name in the TVM global registry.
    ///
    /// Returns `None` when the runtime reports no function under `name` (the
    /// handle comes back null), or when `name` contains an interior NUL byte
    /// and therefore cannot be represented as a C string.
    pub fn get<S: AsRef<str>>(name: S) -> Option<Function> {
        // A name with an interior NUL cannot be passed through the C API, so it
        // cannot refer to a registered function; report "not found" rather
        // than panicking.
        let name = CString::new(name.as_ref()).ok()?;
        let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle;
        // `check_call!` is defined elsewhere in the crate; it wraps the FFI call
        // and handles a failing status code.
        check_call!(ffi::TVMFuncGetGlobal(
            name.as_ptr() as *const c_char,
            &mut handle as *mut _
        ));
        if handle.is_null() {
            None
        } else {
            Some(Function {
                handle,
                is_global: true,
                from_rust: false,
            })
        }
    }

    /// Looks up a global function by name and converts it into a boxed value
    /// of type `F` (typically a `dyn Fn(..) -> Result<_>` closure type) using
    /// the `Into<Box<F>>` conversion on `Function`.
    ///
    /// Returns `None` under the same conditions as [`Function::get`].
    pub fn get_boxed<F, S>(name: S) -> Option<Box<F>>
    where
        S: AsRef<str>,
        F: ?Sized,
        Self: Into<Box<F>>,
    {
        Self::get(name).map(|f| f.into())
    }

    /// Returns the underlying raw TVM function handle.
    pub fn handle(&self) -> ffi::TVMFunctionHandle {
        self.handle
    }

    /// Returns `true` if this function was obtained from the global registry.
    pub fn is_global(&self) -> bool {
        self.is_global
    }

    /// Calls the function with the given arguments.
    ///
    /// Each argument is lowered to a `(TVMValue, type code)` pair and passed to
    /// `TVMFuncCall` together with the handle.
    ///
    /// # Errors
    ///
    /// If the call returns a non-zero status, the last runtime error is fetched
    /// and converted with `Error::from_raw_tvm`; an `Error::Raw` result is
    /// reported as `Error::CallFailed`, any other variant is returned as is.
    pub fn invoke<'a>(&self, arg_buf: Vec<ArgValue<'a>>) -> Result<RetValue> {
        let num_args = arg_buf.len();
        let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMArgTypeCode>) =
            arg_buf.into_iter().map(|arg| arg.to_tvm_value()).unzip();

        // Out-parameters filled in by the callee.
        let mut ret_val = ffi::TVMValue { v_int64: 0 };
        let mut ret_type_code = 0i32;

        // SAFETY: `values` and `type_codes` both hold exactly `num_args`
        // elements and stay alive for the duration of the call, and the two
        // out-pointers refer to live locals. Validity of `self.handle` itself
        // is assumed from how the `Function` was constructed.
        let ret_code = unsafe {
            ffi::TVMFuncCall(
                self.handle,
                values.as_mut_ptr() as *mut ffi::TVMValue,
                type_codes.as_mut_ptr() as *mut c_int,
                num_args as c_int,
                &mut ret_val as *mut _,
                &mut ret_type_code as *mut _,
            )
        };

        if ret_code != 0 {
            let raw_error = crate::get_last_error();
            // Unclassified runtime errors are surfaced as call failures.
            let error = match Error::from_raw_tvm(raw_error) {
                Error::Raw(string) => Error::CallFailed(string),
                e => e,
            };
            return Err(error);
        }

        let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32);
        Ok(rv)
    }
}
// Generates `From<Function> for Box<dyn Fn(..) -> Result<Out>>` for a list of
// argument type parameters and, recursively, for every shorter suffix of that
// list down to the empty one.
macro_rules! impl_to_fn {
// Base case: emit the zero-argument impl.
() => { impl_to_fn!(@impl); };
// Recursive case: emit the impl for the full list, then recurse on the tail.
($t:ident, $($ts:ident,)*) => { impl_to_fn!(@impl $t, $($ts,)*); impl_to_fn!($($ts,)*); };
// Worker arm: emit a single impl for exactly the given type parameters.
(@impl $($t:ident,)*) => {
impl<Err, Out, $($t,)*> From<Function> for Box<dyn Fn($($t,)*) -> Result<Out>>
where
Error: From<Err>,
Out: TryFrom<RetValue, Error = Err>,
$($t: Into<ArgValue<'static>>),*
{
fn from(func: Function) -> Self {
// The closure parameters reuse the type-parameter identifiers (`T1: T1`),
// which are not snake_case; hence the lint allowance.
#[allow(non_snake_case)]
Box::new(move |$($t : $t),*| {
// Convert every argument to an `ArgValue`, call the function, then
// convert the returned `RetValue` into `Out`.
let args = vec![ $($t.into()),* ];
Ok(func.invoke(args)?.try_into()?)
})
}
}
};
}
// Instantiates the conversion for closures taking zero through six arguments.
impl_to_fn!(T1, T2, T3, T4, T5, T6,);
impl Clone for Function {
fn clone(&self) -> Function {
Self {
handle: self.handle,
is_global: self.is_global,
from_rust: true,
}
}
}
impl From<Function> for RetValue {
    /// Wraps the function's raw handle in `RetValue::FuncHandle`.
    fn from(func: Function) -> RetValue {
        let raw_handle = func.handle;
        RetValue::FuncHandle(raw_handle)
    }
}
impl TryFrom<RetValue> for Function {
type Error = Error;
fn try_from(ret_value: RetValue) -> Result<Function> {
match ret_value {
RetValue::FuncHandle(handle) => Ok(Function::new(handle)),
_ => Err(Error::downcast(
format!("{:?}", ret_value),
"FunctionHandle",
)),
}
}
}
impl<'a> From<Function> for ArgValue<'a> {
    /// Converts a function into an argument value: a null handle becomes
    /// `ArgValue::Null`, any other handle becomes `ArgValue::FuncHandle`.
    fn from(func: Function) -> ArgValue<'a> {
        let raw_handle = func.handle;
        match raw_handle.is_null() {
            true => ArgValue::Null,
            false => ArgValue::FuncHandle(raw_handle),
        }
    }
}
impl<'a> TryFrom<ArgValue<'a>> for Function {
type Error = Error;
fn try_from(arg_value: ArgValue<'a>) -> Result<Function> {
match arg_value {
ArgValue::FuncHandle(handle) => Ok(Function::new(handle)),
_ => Err(Error::downcast(
format!("{:?}", arg_value),
"FunctionHandle",
)),
}
}
}
impl<'a> TryFrom<&ArgValue<'a>> for Function {
type Error = Error;
fn try_from(arg_value: &ArgValue<'a>) -> Result<Function> {
match arg_value {
ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)),
_ => Err(Error::downcast(
format!("{:?}", arg_value),
"FunctionHandle",
)),
}
}
}
/// Registers `f` as a TVM global function under `name`.
///
/// This is a convenience wrapper around [`register_override`] that always
/// passes `false` for the override flag.
pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<()>
where
    F: ToFunction<I, O>,
    F: Typed<I, O>,
{
    let override_existing = false;
    register_override(f, name, override_existing)
}
/// Registers `f` as a TVM global function under `name`.
///
/// `f` is first converted to a [`Function`] via `ToFunction::to_function`, and
/// its handle is handed to `TVMFuncRegisterGlobal` together with `override_`
/// (passed through as a C integer flag).
///
/// # Errors
///
/// Returns an error if `name` contains an interior NUL byte and so cannot be
/// converted to a C string.
pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: bool) -> Result<()>
where
    F: ToFunction<I, O>,
    F: Typed<I, O>,
{
    let func = f.to_function();
    let name = name.into();
    let handle = func.handle();
    let name = CString::new(name)?;
    // Lend the C string for the duration of the call instead of giving up
    // ownership with `into_raw()`, which would leak the allocation because
    // nothing ever reclaims it. `name` stays alive until the end of this
    // function, so the pointer is valid for the whole call.
    check_call!(ffi::TVMFuncRegisterGlobal(
        name.as_ptr(),
        handle,
        override_ as c_int
    ));
    Ok(())
}
/// Registers an untyped function — one taking the raw argument vector and
/// returning a `RetValue` — as a TVM global function under `name`.
///
/// `f` is converted to a [`Function`] via `to_function`, and its handle is
/// handed to `TVMFuncRegisterGlobal` together with `override_` (passed through
/// as a C integer flag).
///
/// # Errors
///
/// Returns an error if `name` contains an interior NUL byte and so cannot be
/// converted to a C string.
pub fn register_untyped<S: Into<String>>(
    f: fn(Vec<ArgValue<'static>>) -> Result<RetValue>,
    name: S,
    override_: bool,
) -> Result<()> {
    let func = f.to_function();
    let name = name.into();
    let handle = func.handle();
    let name = CString::new(name)?;
    // Lend the C string for the duration of the call instead of giving up
    // ownership with `into_raw()`, which would leak the allocation because
    // nothing ever reclaims it. `name` stays alive until the end of this
    // function, so the pointer is valid for the whole call.
    check_call!(ffi::TVMFuncRegisterGlobal(
        name.as_ptr(),
        handle,
        override_ as c_int
    ));
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::Function;

    // Name of a global function these tests expect the runtime to provide.
    static CANARY: &str = "runtime.ModuleLoadFromFile";

    // A known global resolves; an unknown name does not.
    #[test]
    fn get_fn() {
        let present = Function::get(CANARY);
        assert!(present.is_some());

        let absent = Function::get("does not exists!");
        assert!(absent.is_none());
    }

    // Registers a zero-argument function and calls it back through the registry.
    #[test]
    fn register_and_call_closure0() {
        use crate::function;
        use function::Result;

        fn constfn() -> i64 {
            10
        }

        function::register_override(constfn, "constfn".to_owned(), true).unwrap();

        let boxed = Function::get_boxed::<dyn Fn() -> Result<i32>, _>("constfn").unwrap();
        let value = boxed().unwrap();
        assert_eq!(value, 10);
    }

    // Registers a one-argument identity function and checks the round trip.
    #[test]
    fn register_and_call_closure1() {
        use crate::function::{self};

        fn ident(x: i64) -> i64 {
            x
        }

        function::register_override(ident, "ident".to_owned(), true).unwrap();

        let boxed = Function::get_boxed::<dyn Fn(i32) -> Result<i32>, _>("ident").unwrap();
        let echoed = boxed(60).unwrap();
        assert_eq!(echoed, 60);
    }
}