use std::{
cell::RefCell,
ffi::{CStr, CString},
os::raw::c_char,
};
use ref_cast::RefCast;
use crate::{
expr::{Expr, Symbol},
rtl,
sys::{self, mint, mreal, MArgument},
wstp::Link,
DataStore, Image, NumericArray,
};
/// Conversion from a borrowed [`MArgument`] into a Rust value.
///
/// The lifetime `'a` is the lifetime of the borrowed argument; implementations for
/// reference types return data borrowed for `'a`.
pub trait FromArg<'a> {
/// Build `Self` from the contents of `arg`.
///
/// # Safety
///
/// NOTE(review): the implementations in this file dereference raw pointers stored in
/// `arg`; presumably the caller must guarantee that `arg` holds a valid value of the
/// kind the implementing type expects — confirm against the call sites.
#[allow(missing_docs)]
unsafe fn from_arg(arg: &'a MArgument) -> Self;
/// Expression describing this type when it is used as a function parameter.
///
/// Collected by [`NativeFunction::signature`] for typed function pointers.
fn parameter_type() -> Expr;
}
/// Conversion of a Rust value into a return value written through an [`MArgument`].
pub trait IntoArg {
/// Consume `self` and store it in the slot of `arg` that matches the implementing type.
///
/// # Safety
///
/// NOTE(review): implementations in this file write through raw pointers stored in
/// `arg`; presumably the caller must guarantee those pointers are valid for writes —
/// confirm against the call sites.
unsafe fn into_arg(self, arg: MArgument);
/// Expression describing this type when it is used as a function return type.
fn return_type() -> Expr;
}
/// A function that can be invoked with a slice of [`MArgument`] parameters and an
/// [`MArgument`] return slot.
pub trait NativeFunction<'a> {
/// Invoke the function with `args`, writing its result through `ret`.
///
/// # Safety
///
/// NOTE(review): typed implementations convert each argument with
/// [`FromArg::from_arg`] and write the result with [`IntoArg::into_arg`], so the
/// safety requirements of those calls presumably apply here — confirm.
unsafe fn call(&self, args: &'a [MArgument], ret: MArgument);
/// Return the parameter types and the return type of this function, or an error
/// message when they are not known.
fn signature(&self) -> Result<(Vec<Expr>, Expr), String>;
}
/// A function that is invoked with a WSTP [`Link`].
pub trait WstpFunction {
/// Invoke the function, giving it access to `link`.
///
/// The implementations in this file for `Vec<Expr>`-taking functions read the
/// argument list from `link` and write the result back to it, panicking on link errors.
unsafe fn call(&self, link: &mut Link);
}
impl FromArg<'_> for bool {
unsafe fn from_arg(arg: &MArgument) -> Self {
crate::bool_from_mbool(*arg.boolean)
}
fn parameter_type() -> Expr {
Expr::string("Boolean")
}
}
impl FromArg<'_> for mint {
    /// Read the integer slot of `arg`.
    unsafe fn from_arg(arg: &MArgument) -> mint {
        *arg.integer
    }

    /// The parameter type is the symbol ``System`Integer``.
    fn parameter_type() -> Expr {
        let head = Symbol::new("System`Integer");
        Expr::symbol(head)
    }
}
impl FromArg<'_> for mreal {
    /// Read the real slot of `arg`.
    unsafe fn from_arg(arg: &MArgument) -> mreal {
        *arg.real
    }

    /// The parameter type is the symbol ``System`Real``.
    fn parameter_type() -> Expr {
        let head = Symbol::new("System`Real");
        Expr::symbol(head)
    }
}
impl FromArg<'_> for sys::mcomplex {
    /// Read the complex slot of `arg`.
    unsafe fn from_arg(arg: &MArgument) -> sys::mcomplex {
        *arg.cmplex
    }

    /// The parameter type is the symbol ``System`Complex``.
    fn parameter_type() -> Expr {
        let head = Symbol::new("System`Complex");
        Expr::symbol(head)
    }
}
unsafe fn c_str_from_arg<'a>(arg: &'a MArgument) -> &'a CStr {
let cstr: *mut c_char = *arg.utf8string;
CStr::from_ptr(cstr)
}
impl<'a> FromArg<'a> for CString {
unsafe fn from_arg(arg: &'a MArgument) -> CString {
let owned = {
let cstr: &'a CStr = c_str_from_arg(arg);
CString::from(cstr)
};
rtl::UTF8String_disown(*arg.utf8string);
owned
}
fn parameter_type() -> Expr {
Expr::symbol(Symbol::new("System`String"))
}
}
impl<'a> FromArg<'a> for String {
unsafe fn from_arg(arg: &'a MArgument) -> String {
let owned = {
let cstr: &'a CStr = c_str_from_arg(arg);
let str: &'a str = cstr
.to_str()
.expect("FromArg for &str: string was not valid UTF-8");
str.to_owned()
};
rtl::UTF8String_disown(*arg.utf8string);
owned
}
fn parameter_type() -> Expr {
Expr::symbol(Symbol::new("System`String"))
}
}
impl<'a> FromArg<'a> for &'a CStr {
    /// Borrow the argument's C string without copying it.
    unsafe fn from_arg(arg: &'a MArgument) -> Self {
        c_str_from_arg(arg)
    }

    /// Always panics: no parameter type is provided for `&CStr`.
    fn parameter_type() -> Expr {
        panic!("&CStr cannot be used as a LibraryLink function parameter type")
    }
}
impl<'a> FromArg<'a> for &'a str {
    /// Borrow the argument's C string as a `&str`.
    ///
    /// # Panics
    ///
    /// Panics if the string is not valid UTF-8.
    unsafe fn from_arg(arg: &'a MArgument) -> Self {
        let borrowed = <&'a CStr as FromArg<'a>>::from_arg(arg);
        borrowed
            .to_str()
            .expect("FromArg for &str: string was not valid UTF-8")
    }

    /// Always panics: no parameter type is provided for `&str`.
    fn parameter_type() -> Expr {
        panic!("&str cannot be used as a LibraryLink function parameter type")
    }
}
impl<'a, T: crate::NumericArrayType> FromArg<'a> for &'a NumericArray<T> {
    /// Reinterpret the raw numeric array stored in `arg` as a borrowed
    /// `NumericArray<T>`.
    unsafe fn from_arg(arg: &'a MArgument) -> Self {
        let raw = &*arg.numeric;
        NumericArray::ref_cast(raw)
    }

    /// `{LibraryDataType[NumericArray, "<name of T::TYPE>"], "Constant"}`.
    fn parameter_type() -> Expr {
        let data_type = Expr::normal(
            Symbol::new("System`LibraryDataType"),
            vec![
                Expr::from(Symbol::new("System`NumericArray")),
                Expr::string(T::TYPE.name()),
            ],
        );
        let passing = Expr::string("Constant");
        Expr::normal(Symbol::new("System`List"), vec![data_type, passing])
    }
}
impl<'a, T: crate::NumericArrayType> FromArg<'a> for NumericArray<T> {
    /// Wrap the raw numeric array stored in `arg` as an owned `NumericArray<T>`.
    unsafe fn from_arg(arg: &'a MArgument) -> Self {
        let raw = *arg.numeric;
        NumericArray::from_raw(raw)
    }

    /// `{LibraryDataType[NumericArray, "<name of T::TYPE>"], "Shared"}`.
    fn parameter_type() -> Expr {
        let data_type = Expr::normal(
            Symbol::new("System`LibraryDataType"),
            vec![
                Expr::from(Symbol::new("System`NumericArray")),
                Expr::string(T::TYPE.name()),
            ],
        );
        let passing = Expr::string("Shared");
        Expr::normal(Symbol::new("System`List"), vec![data_type, passing])
    }
}
impl<'a> FromArg<'a> for &'a NumericArray<()> {
    /// Reinterpret the raw numeric array stored in `arg` as a borrowed
    /// `NumericArray<()>`.
    unsafe fn from_arg(arg: &'a MArgument) -> Self {
        let raw = &*arg.numeric;
        NumericArray::ref_cast(raw)
    }

    /// `{NumericArray, "Constant"}`.
    fn parameter_type() -> Expr {
        let array = Expr::from(Symbol::new("System`NumericArray"));
        let passing = Expr::string("Constant");
        Expr::normal(Symbol::new("System`List"), vec![array, passing])
    }
}
impl<'a> FromArg<'a> for NumericArray<()> {
    /// Wrap the raw numeric array stored in `arg` as an owned `NumericArray<()>`.
    unsafe fn from_arg(arg: &'a MArgument) -> Self {
        let raw = *arg.numeric;
        NumericArray::from_raw(raw)
    }

    /// `{NumericArray, "Shared"}`.
    fn parameter_type() -> Expr {
        let array = Expr::from(Symbol::new("System`NumericArray"));
        let passing = Expr::string("Shared");
        Expr::normal(Symbol::new("System`List"), vec![array, passing])
    }
}
impl<'a, T: crate::ImageData> FromArg<'a> for &'a Image<T> {
    /// Reinterpret the raw image stored in `arg` as a borrowed `Image<T>`.
    unsafe fn from_arg(arg: &'a MArgument) -> Self {
        let raw = &*arg.image;
        Image::ref_cast(raw)
    }

    /// `{LibraryDataType[Image | Image3D, "<name of T::TYPE>"], "Constant"}`.
    fn parameter_type() -> Expr {
        let image_kinds = Expr::normal(
            Symbol::new("System`Alternatives"),
            vec![
                Expr::from(Symbol::new("System`Image")),
                Expr::from(Symbol::new("System`Image3D")),
            ],
        );
        let data_type = Expr::normal(
            Symbol::new("System`LibraryDataType"),
            vec![image_kinds, Expr::string(T::TYPE.name())],
        );
        let passing = Expr::string("Constant");
        Expr::normal(Symbol::new("System`List"), vec![data_type, passing])
    }
}
impl<'a, T: crate::ImageData> FromArg<'a> for Image<T> {
    /// Wrap the raw image stored in `arg` as an owned `Image<T>`.
    unsafe fn from_arg(arg: &'a MArgument) -> Self {
        let raw = *arg.image;
        Image::from_raw(raw)
    }

    /// `{LibraryDataType[Image | Image3D, "<name of T::TYPE>"], "Shared"}`.
    fn parameter_type() -> Expr {
        let image_kinds = Expr::normal(
            Symbol::new("System`Alternatives"),
            vec![
                Expr::from(Symbol::new("System`Image")),
                Expr::from(Symbol::new("System`Image3D")),
            ],
        );
        let data_type = Expr::normal(
            Symbol::new("System`LibraryDataType"),
            vec![image_kinds, Expr::string(T::TYPE.name())],
        );
        let passing = Expr::string("Shared");
        Expr::normal(Symbol::new("System`List"), vec![data_type, passing])
    }
}
impl<'a> FromArg<'a> for &'a Image<()> {
    /// Reinterpret the raw image stored in `arg` as a borrowed `Image<()>`.
    unsafe fn from_arg(arg: &'a MArgument) -> Self {
        let raw = &*arg.image;
        Image::ref_cast(raw)
    }

    /// `{Image | Image3D, "Constant"}`.
    fn parameter_type() -> Expr {
        let image_kinds = Expr::normal(
            Symbol::new("System`Alternatives"),
            vec![
                Expr::from(Symbol::new("System`Image")),
                Expr::from(Symbol::new("System`Image3D")),
            ],
        );
        let passing = Expr::string("Constant");
        Expr::normal(Symbol::new("System`List"), vec![image_kinds, passing])
    }
}
impl<'a> FromArg<'a> for Image<()> {
    /// Wrap the raw image stored in `arg` as an owned `Image<()>`.
    unsafe fn from_arg(arg: &'a MArgument) -> Self {
        let raw = *arg.image;
        Image::from_raw(raw)
    }

    /// `{Image | Image3D, "Shared"}`.
    fn parameter_type() -> Expr {
        let image_kinds = Expr::normal(
            Symbol::new("System`Alternatives"),
            vec![
                Expr::from(Symbol::new("System`Image")),
                Expr::from(Symbol::new("System`Image3D")),
            ],
        );
        let passing = Expr::string("Shared");
        Expr::normal(Symbol::new("System`List"), vec![image_kinds, passing])
    }
}
impl FromArg<'_> for DataStore {
    /// Read the value in the `tensor` slot of `arg`, cast it to `sys::DataStore`, and
    /// wrap it as an owned `DataStore`.
    unsafe fn from_arg(arg: &MArgument) -> Self {
        let raw: sys::DataStore = *arg.tensor as sys::DataStore;
        DataStore::from_raw(raw)
    }

    /// The parameter type is the string `"DataStore"`.
    fn parameter_type() -> Expr {
        let type_name = "DataStore";
        Expr::string(type_name)
    }
}
impl<'a> FromArg<'a> for &'a DataStore {
    /// Treat the `tensor` slot of `arg` as a pointer to a `sys::DataStore` and borrow
    /// it as a `DataStore`.
    unsafe fn from_arg(arg: &MArgument) -> &'a DataStore {
        let slot = arg.tensor as *mut sys::DataStore;
        DataStore::ref_cast(&*slot)
    }

    /// Always panics: no parameter type is provided for `&DataStore`.
    fn parameter_type() -> Expr {
        panic!("&DataStore cannot be used as a LibraryLink function parameter type")
    }
}
impl IntoArg for () {
    /// Writes nothing: a unit return value leaves `_arg` untouched.
    unsafe fn into_arg(self, _arg: MArgument) {}

    /// The return type is the string `"Void"`.
    fn return_type() -> Expr {
        let type_name = "Void";
        Expr::string(type_name)
    }
}
impl IntoArg for bool {
    /// Store `sys::True` or `sys::False` (cast to `sys::mbool`) in the boolean slot.
    unsafe fn into_arg(self, arg: MArgument) {
        let raw: u32 = match self {
            true => sys::True,
            false => sys::False,
        };
        *arg.boolean = raw as sys::mbool;
    }

    /// The return type is the string `"Boolean"`.
    fn return_type() -> Expr {
        let type_name = "Boolean";
        Expr::string(type_name)
    }
}
impl IntoArg for mint {
    /// Store `self` in the integer slot of `arg`.
    unsafe fn into_arg(self, arg: MArgument) {
        let slot = arg.integer;
        *slot = self;
    }

    /// The return type is the symbol ``System`Integer``.
    fn return_type() -> Expr {
        let head = Symbol::new("System`Integer");
        Expr::symbol(head)
    }
}
impl IntoArg for mreal {
    /// Store `self` in the real slot of `arg`.
    unsafe fn into_arg(self, arg: MArgument) {
        let slot = arg.real;
        *slot = self;
    }

    /// The return type is the symbol ``System`Real``.
    fn return_type() -> Expr {
        let head = Symbol::new("System`Real");
        Expr::symbol(head)
    }
}
impl IntoArg for sys::mcomplex {
    /// Store `self` in the complex slot of `arg`.
    unsafe fn into_arg(self, arg: MArgument) {
        let slot = arg.cmplex;
        *slot = self;
    }

    /// The return type is the symbol ``System`Complex``.
    fn return_type() -> Expr {
        let head = Symbol::new("System`Complex");
        Expr::symbol(head)
    }
}
impl IntoArg for i8 {
unsafe fn into_arg(self, arg: MArgument) {
*arg.integer = mint::from(self);
}
fn return_type() -> Expr {
Expr::symbol(Symbol::new("System`Integer"))
}
}
impl IntoArg for i16 {
unsafe fn into_arg(self, arg: MArgument) {
*arg.integer = mint::from(self);
}
fn return_type() -> Expr {
Expr::symbol(Symbol::new("System`Integer"))
}
}
impl IntoArg for i32 {
unsafe fn into_arg(self, arg: MArgument) {
*arg.integer = mint::from(self);
}
fn return_type() -> Expr {
Expr::symbol(Symbol::new("System`Integer"))
}
}
impl IntoArg for u8 {
unsafe fn into_arg(self, arg: MArgument) {
*arg.integer = mint::from(self);
}
fn return_type() -> Expr {
Expr::symbol(Symbol::new("System`Integer"))
}
}
impl IntoArg for u16 {
unsafe fn into_arg(self, arg: MArgument) {
*arg.integer = mint::from(self);
}
fn return_type() -> Expr {
Expr::symbol(Symbol::new("System`Integer"))
}
}
#[cfg(target_pointer_width = "64")]
impl IntoArg for u32 {
unsafe fn into_arg(self, arg: MArgument) {
*arg.integer = mint::from(self);
}
fn return_type() -> Expr {
Expr::symbol(Symbol::new("System`Integer"))
}
}
// Per-thread slot holding the `CString` most recently passed to
// `<CString as IntoArg>::into_arg`. That function writes a raw pointer into the
// stored string's buffer to the return argument, so the string is kept here to keep
// that buffer alive; it is dropped when the next string is returned on this thread.
thread_local! {
static RETURNED_STRING: RefCell<Option<CString>> = RefCell::new(None);
}
impl IntoArg for CString {
    /// Park `self` in the thread-local `RETURNED_STRING` slot and write a pointer to
    /// its buffer into the `utf8string` slot of `arg`.
    ///
    /// Any string parked by an earlier call on this thread is dropped first.
    unsafe fn into_arg(self, arg: MArgument) {
        // The pointer refers to the string's heap buffer, which does not move when the
        // `CString` value itself is moved into the slot below.
        let raw: *const c_char = self.as_ptr();
        RETURNED_STRING.with(|slot| {
            drop(slot.take());
            *slot.borrow_mut() = Some(self);
        });
        *arg.utf8string = raw as *mut c_char;
    }

    /// The return type is the symbol ``System`String``.
    fn return_type() -> Expr {
        let head = Symbol::new("System`String");
        Expr::from(head)
    }
}
impl IntoArg for String {
    /// Convert `self` to a `CString` and return it through the `CString`
    /// implementation of `IntoArg`.
    ///
    /// # Panics
    ///
    /// Panics if `CString::new` rejects the string.
    unsafe fn into_arg(self, arg: MArgument) {
        let converted = CString::new(self);
        let c_string: CString =
            converted.expect("IntoArg for String: could not convert String to CString");
        <CString as IntoArg>::into_arg(c_string, arg)
    }

    /// The return type is the symbol ``System`String``.
    fn return_type() -> Expr {
        let head = Symbol::new("System`String");
        Expr::from(head)
    }
}
impl<T: crate::NumericArrayType> IntoArg for NumericArray<T> {
    /// Convert `self` into its raw form and store it in the numeric slot of `arg`.
    unsafe fn into_arg(self, arg: MArgument) {
        let raw = self.into_raw();
        *arg.numeric = raw;
    }

    /// `LibraryDataType[NumericArray, "<name of T::TYPE>"]`.
    fn return_type() -> Expr {
        let array = Expr::from(Symbol::new("System`NumericArray"));
        let element = Expr::string(T::TYPE.name());
        Expr::normal(Symbol::new("System`LibraryDataType"), vec![array, element])
    }
}
impl IntoArg for NumericArray<()> {
    /// Convert `self` into its raw form and store it in the numeric slot of `arg`.
    unsafe fn into_arg(self, arg: MArgument) {
        let raw = self.into_raw();
        *arg.numeric = raw;
    }

    /// The return type is the symbol ``System`NumericArray``.
    fn return_type() -> Expr {
        let head = Symbol::new("System`NumericArray");
        Expr::from(head)
    }
}
impl<T: crate::ImageData> IntoArg for Image<T> {
    /// Convert `self` into its raw form and store it in the image slot of `arg`.
    unsafe fn into_arg(self, arg: MArgument) {
        let raw = self.into_raw();
        *arg.image = raw;
    }

    /// `{LibraryDataType[Image | Image3D, "<name of T::TYPE>"], "Shared"}`.
    ///
    /// NOTE(review): unlike the `NumericArray<T>` return type, this is wrapped in a
    /// list with `"Shared"` — confirm that is intended for a return type.
    fn return_type() -> Expr {
        let image_kinds = Expr::normal(
            Symbol::new("System`Alternatives"),
            vec![
                Expr::from(Symbol::new("System`Image")),
                Expr::from(Symbol::new("System`Image3D")),
            ],
        );
        let data_type = Expr::normal(
            Symbol::new("System`LibraryDataType"),
            vec![image_kinds, Expr::string(T::TYPE.name())],
        );
        let passing = Expr::string("Shared");
        Expr::normal(Symbol::new("System`List"), vec![data_type, passing])
    }
}
impl IntoArg for DataStore {
    /// Convert `self` into its raw form and store it, cast to the slot's pointer type,
    /// in the `tensor` slot of `arg`.
    unsafe fn into_arg(self, arg: MArgument) {
        let raw = self.into_raw();
        *arg.tensor = raw as *mut _;
    }

    /// The return type is the string `"DataStore"`.
    fn return_type() -> Expr {
        let type_name = "DataStore";
        Expr::string(type_name)
    }
}
impl<'a: 'b, 'b> NativeFunction<'a> for fn(&'b [MArgument], MArgument) {
    /// Forward `args` and `ret` to the wrapped function pointer unchanged.
    unsafe fn call(&self, args: &'a [MArgument], ret: MArgument) {
        let function = *self;
        function(args, ret)
    }

    /// Always returns an error: a raw `MArgument` function carries no type information.
    fn signature(&self) -> Result<(Vec<Expr>, Expr), String> {
        let message = String::from(
            "fn(&[MArgument], MArgument) function cannot be loaded automatically: \
             parameter and return types are unknown.",
        );
        Err(message)
    }
}
// Implements `NativeFunction` for function pointers `fn(A1, ..., An) -> R` where every
// parameter type implements `FromArg` and the return type implements `IntoArg`.
//
// Each macro argument is used both as the generic type parameter name and as the name
// of the local variable holding the corresponding argument.
macro_rules! impl_NativeFunction {
    ($($type:ident),*) => {
        impl<'a, $($type,)* R> NativeFunction<'a> for fn($($type),*) -> R
        where
            R: IntoArg,
            $($type: FromArg<'a>),*
        {
            unsafe fn call(&self, args: &'a [MArgument], ret: MArgument) {
                // Destructure the slice into exactly one binding per parameter; any
                // other length is a caller error.
                #[allow(non_snake_case)]
                let ($($type,)*) = match args {
                    [$($type,)*] => ($($type,)*),
                    _ => panic!(
                        "LibraryLink function number of arguments ({}) does not match \
                        number of parameters",
                        args.len()
                    ),
                };

                // Convert each raw argument, shadowing the raw binding with the
                // converted value.
                $(
                    #[allow(non_snake_case)]
                    let $type: $type = <$type as FromArg<'a>>::from_arg($type);
                )*

                self($($type,)*).into_arg(ret);
            }

            fn signature(&self) -> Result<(Vec<Expr>, Expr), String> {
                let param_tys: Vec<Expr> =
                    vec![$(<$type as FromArg<'a>>::parameter_type()),*];

                Ok((param_tys, <R as IntoArg>::return_type()))
            }
        }
    }
}
impl<'a, R> NativeFunction<'a> for fn() -> R
where
    R: IntoArg,
{
    /// Call the wrapped zero-parameter function and write its result through `ret`.
    ///
    /// # Panics
    ///
    /// Panics if `args` is not empty.
    unsafe fn call(&self, args: &[MArgument], ret: MArgument) {
        if !args.is_empty() {
            panic!(
                "LibraryLink function number of arguments ({}) does not match number of \
                parameters",
                args.len()
            );
        }

        self().into_arg(ret);
    }

    /// No parameter types, plus the return type reported by `R`.
    fn signature(&self) -> Result<(Vec<Expr>, Expr), String> {
        let no_params: Vec<Expr> = Vec::new();
        Ok((no_params, R::return_type()))
    }
}
// Implement `NativeFunction` for function pointers taking 1 through 12 parameters.
// The zero-parameter case is implemented by hand for `fn() -> R`.
impl_NativeFunction!(A1);
impl_NativeFunction!(A1, A2);
impl_NativeFunction!(A1, A2, A3);
impl_NativeFunction!(A1, A2, A3, A4);
impl_NativeFunction!(A1, A2, A3, A4, A5);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6, A7);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6, A7, A8);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6, A7, A8, A9);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12);
impl WstpFunction for fn(&mut Link) {
    /// Hand `link` straight to the wrapped function pointer.
    unsafe fn call(&self, link: &mut Link) {
        let function = *self;
        function(link)
    }
}
impl WstpFunction for fn(Vec<Expr>) -> Expr {
    /// Read the argument list from `link`, call the wrapped function, and write the
    /// returned expression back to `link`.
    ///
    /// # Panics
    ///
    /// Panics if the argument list cannot be read or the result cannot be written.
    unsafe fn call(&self, link: &mut Link) {
        let args: Vec<Expr> = get_args_list(link)
            .unwrap_or_else(|message| panic!("WstpFunction: {}", message));

        let result: Expr = self(args);

        if let Err(err) = link.put_expr(&result) {
            panic!(
                "WstpFunction: WSTP error writing return expression to link: {}",
                err
            );
        }
    }
}
impl WstpFunction for fn(Vec<Expr>) {
    /// Read the argument list from `link`, call the wrapped function, and write the
    /// symbol ``System`Null`` back to `link` as the result.
    ///
    /// # Panics
    ///
    /// Panics if the argument list cannot be read or the result cannot be written.
    unsafe fn call(&self, link: &mut Link) {
        let args: Vec<Expr> = get_args_list(link)
            .unwrap_or_else(|message| panic!("WstpFunction: {}", message));

        self(args);

        if let Err(err) = link.put_symbol("System`Null") {
            panic!(
                "WstpFunction: WSTP error writing return Null expression to link: {}",
                err
            );
        }
    }
}
fn get_args_list(link: &mut Link) -> Result<Vec<Expr>, String> {
get_args_list_impl(link).map_err(|err: wstp::Error| {
format!("WSTP error reading argument List expression: {}", err)
})
}
/// Read a list expression from `link` and return its elements.
///
/// The head is first tested against `List`; if that fails with the `WSEGSEQ` error
/// code, the link error is cleared and the head is tested against ``System`List``
/// instead. Any other error is returned as-is.
fn get_args_list_impl(link: &mut Link) -> Result<Vec<Expr>, wstp::Error> {
    let first_attempt = link.test_head("List");

    let arg_count: usize = match first_attempt {
        Err(err) if err.code() == Some(wstp::sys::WSEGSEQ) => {
            link.clear_error();
            link.test_head("System`List")?
        },
        other => other?,
    };

    // Read exactly `arg_count` expressions, stopping at the first error.
    let mut elements: Vec<Expr> = Vec::new();
    let mut remaining = arg_count;
    while remaining > 0 {
        elements.push(link.get_expr()?);
        remaining -= 1;
    }

    Ok(elements)
}