use std::{
cell::RefCell,
ffi::{CStr, CString},
os::raw::c_char,
};
use ref_cast::RefCast;
#[cfg(feature = "wstp")]
use crate::wstp::Link;
use crate::{
expr::{expr, Expr},
rtl,
sys::{self, mint, mreal, MArgument},
DataStore, Image, NumericArray,
};
#[cfg(feature = "wstp")]
use crate::expr::Symbol;
/// Conversion from a borrowed [`MArgument`] slot into a Rust parameter value.
///
/// The implementations in this file read the `MArgument` field that corresponds to
/// `Self` (for example `boolean`, `integer`, `utf8string`, `numeric`, `image`).
pub trait FromArg<'a> {
/// Builds `Self` from the argument slot `arg`.
///
/// # Safety
///
/// NOTE(review): the implementations in this file dereference a pointer field of `arg`
/// without any check, so the caller presumably has to guarantee that `arg` holds a
/// valid value of the matching type — confirm against the call sites.
#[allow(missing_docs)]
unsafe fn from_arg(arg: &'a MArgument) -> Self;
/// Returns the expression describing this type when it is used as a parameter.
///
/// Some implementations in this file panic instead of returning a value, because the
/// type cannot be used as a LibraryLink function parameter type.
fn parameter_type() -> Expr;
}
/// Conversion of a Rust return value into an [`MArgument`] slot.
pub trait IntoArg {
/// Stores `self` into the slot `arg`.
///
/// # Safety
///
/// NOTE(review): the implementations in this file write through a pointer field of
/// `arg` without any check, so the caller presumably has to supply a slot that is
/// valid for the matching type — confirm against the call sites.
unsafe fn into_arg(self, arg: MArgument);
/// Returns the expression describing this type when it is used as a return type.
fn return_type() -> Expr;
}
/// A function that can be invoked with a slice of [`MArgument`] values and a return slot.
///
/// This file implements the trait for `fn` pointers of several arities.
pub trait NativeFunction<'a> {
/// Invokes the function with the arguments `args`, writing the result into `ret`.
///
/// # Safety
///
/// NOTE(review): the implementations in this file forward `args` and `ret` to
/// `FromArg::from_arg` / `IntoArg::into_arg`, which are themselves `unsafe`; the caller
/// presumably has to uphold those contracts — confirm.
unsafe fn call(&self, args: &'a [MArgument], ret: MArgument);
/// Returns `Ok((parameter_types, return_type))`, or `Err` with a message when the
/// parameter and return types are not known.
fn signature(&self) -> Result<(Vec<Expr>, Expr), String>;
}
/// A function that can be invoked with a WSTP [`Link`].
///
/// Only available when the `wstp` feature is enabled.
#[cfg(feature = "wstp")]
pub trait WstpFunction {
/// Invokes the function, giving it access to `link`.
///
/// # Safety
///
/// NOTE(review): the safety contract is not visible in this file — confirm what the
/// caller has to guarantee about the state of `link`.
unsafe fn call(&self, link: &mut Link);
}
/// `bool` parameters are read through the `boolean` field of the argument.
impl FromArg<'_> for bool {
    /// Dereferences `arg.boolean` and converts the value with `crate::bool_from_mbool`.
    unsafe fn from_arg(arg: &MArgument) -> Self {
        let raw = *arg.boolean;
        crate::bool_from_mbool(raw)
    }

    /// The parameter type is the string expression `"Boolean"`.
    fn parameter_type() -> Expr {
        let type_name = "Boolean";
        Expr::string(type_name)
    }
}
/// `mint` parameters are read through the `integer` field of the argument.
impl FromArg<'_> for mint {
    /// Returns the value that `arg.integer` points to.
    unsafe fn from_arg(arg: &MArgument) -> Self {
        let value: mint = *arg.integer;
        value
    }

    /// The parameter type is the `System::Integer` expression.
    fn parameter_type() -> Expr {
        let ty: Expr = expr!(System::Integer);
        ty
    }
}
/// `mreal` parameters are read through the `real` field of the argument.
impl FromArg<'_> for mreal {
    /// Returns the value that `arg.real` points to.
    unsafe fn from_arg(arg: &MArgument) -> Self {
        let value: mreal = *arg.real;
        value
    }

    /// The parameter type is the `System::Real` expression.
    fn parameter_type() -> Expr {
        let ty: Expr = expr!(System::Real);
        ty
    }
}
/// `mcomplex` parameters are read through the `cmplex` field of the argument.
impl FromArg<'_> for sys::mcomplex {
    /// Returns the value that `arg.cmplex` points to.
    unsafe fn from_arg(arg: &MArgument) -> Self {
        let value: sys::mcomplex = *arg.cmplex;
        value
    }

    /// The parameter type is the `System::Complex` expression.
    fn parameter_type() -> Expr {
        let ty: Expr = expr!(System::Complex);
        ty
    }
}
unsafe fn c_str_from_arg<'a>(arg: &'a MArgument) -> &'a CStr {
let cstr: *mut c_char = *arg.utf8string;
CStr::from_ptr(cstr)
}
/// Owned C-string parameters: the argument's string is copied, then disowned.
impl<'a> FromArg<'a> for CString {
    /// Copies the string behind `arg.utf8string` into a new [`CString`] and afterwards
    /// passes the original pointer to `rtl::UTF8String_disown`.
    unsafe fn from_arg(arg: &'a MArgument) -> CString {
        let borrowed: &'a CStr = c_str_from_arg(arg);
        let copy: CString = borrowed.to_owned();

        // The copy no longer depends on the original buffer, so it can be disowned now.
        rtl::UTF8String_disown(*arg.utf8string);

        copy
    }

    /// The parameter type is the `System::String` expression.
    fn parameter_type() -> Expr {
        let ty: Expr = expr!(System::String);
        ty
    }
}
/// Owned UTF-8 string parameters: the argument's string is validated, copied, and then
/// disowned.
impl<'a> FromArg<'a> for String {
    /// Copies the string behind `arg.utf8string` into a new [`String`] and passes the
    /// original pointer to `rtl::UTF8String_disown`.
    ///
    /// # Panics
    ///
    /// Panics if the argument is not valid UTF-8. The original string is disowned
    /// before the panic is raised, so that path no longer skips the disown call.
    unsafe fn from_arg(arg: &'a MArgument) -> String {
        // Validate and copy first; neither the owned `String` nor the `Utf8Error`
        // borrows from the original buffer, so it is safe to disown it afterwards.
        let converted: Result<String, std::str::Utf8Error> = {
            let cstr: &'a CStr = c_str_from_arg(arg);
            cstr.to_str().map(str::to_owned)
        };

        // Disown unconditionally, including when validation failed.
        rtl::UTF8String_disown(*arg.utf8string);

        converted.expect("FromArg for String: string was not valid UTF-8")
    }

    /// The parameter type is the `System::String` expression.
    fn parameter_type() -> Expr {
        expr!(System::String)
    }
}
/// Borrowed C-string parameters: the argument's string is borrowed, not copied.
impl<'a> FromArg<'a> for &'a CStr {
    /// Borrows the string behind `arg.utf8string` via `c_str_from_arg`.
    unsafe fn from_arg(arg: &'a MArgument) -> &'a CStr {
        let borrowed: &'a CStr = c_str_from_arg(arg);
        borrowed
    }

    /// Always panics: `&CStr` has no parameter type expression.
    fn parameter_type() -> Expr {
        panic!("&CStr cannot be used as a LibraryLink function parameter type")
    }
}
/// Borrowed UTF-8 string parameters: the argument's string is borrowed and validated.
impl<'a> FromArg<'a> for &'a str {
    /// Borrows the argument as a `&CStr` and reinterprets it as `&str`.
    ///
    /// # Panics
    ///
    /// Panics if the string is not valid UTF-8.
    unsafe fn from_arg(arg: &'a MArgument) -> &'a str {
        let borrowed = <&'a CStr as FromArg<'a>>::from_arg(arg);
        let validated = borrowed.to_str();
        validated.expect("FromArg for &str: string was not valid UTF-8")
    }

    /// Always panics: `&str` has no parameter type expression.
    fn parameter_type() -> Expr {
        panic!("&str cannot be used as a LibraryLink function parameter type")
    }
}
impl<'a, T: crate::NumericArrayType> FromArg<'a> for &'a NumericArray<T> {
unsafe fn from_arg(arg: &'a MArgument) -> &'a NumericArray<T> {
NumericArray::ref_cast(&*arg.numeric)
}
fn parameter_type() -> Expr {
let type_name = T::TYPE.name();
let ldt = crate::expr::expr!(System::LibraryDataType["NumericArray", type_name]);
crate::expr::expr!(System::List[ldt, "Constant"])
}
}
impl<'a, T: crate::NumericArrayType> FromArg<'a> for NumericArray<T> {
unsafe fn from_arg(arg: &'a MArgument) -> NumericArray<T> {
NumericArray::from_raw(*arg.numeric)
}
fn parameter_type() -> Expr {
let type_name = T::TYPE.name();
let ldt = crate::expr::expr!(System::LibraryDataType["NumericArray", type_name]);
crate::expr::expr!(System::List[ldt, "Shared"])
}
}
impl<'a> FromArg<'a> for &'a NumericArray<()> {
unsafe fn from_arg(arg: &'a MArgument) -> &'a NumericArray<()> {
NumericArray::ref_cast(&*arg.numeric)
}
fn parameter_type() -> Expr {
crate::expr::expr!(System::List["NumericArray", "Constant"])
}
}
impl<'a> FromArg<'a> for NumericArray<()> {
unsafe fn from_arg(arg: &'a MArgument) -> NumericArray<()> {
NumericArray::from_raw(*arg.numeric)
}
fn parameter_type() -> Expr {
crate::expr::expr!(System::List["NumericArray", "Shared"])
}
}
impl<'a, T: crate::ImageData> FromArg<'a> for &'a Image<T> {
unsafe fn from_arg(arg: &'a MArgument) -> &'a Image<T> {
Image::ref_cast(&*arg.image)
}
fn parameter_type() -> Expr {
let type_name = T::TYPE.name();
let alts = crate::expr::expr!(System::Alternatives["Image", "Image3D"]);
let ldt = crate::expr::expr!(System::LibraryDataType[alts, type_name]);
crate::expr::expr!(System::List[ldt, "Constant"])
}
}
impl<'a, T: crate::ImageData> FromArg<'a> for Image<T> {
unsafe fn from_arg(arg: &'a MArgument) -> Image<T> {
Image::from_raw(*arg.image)
}
fn parameter_type() -> Expr {
let type_name = T::TYPE.name();
let alts = crate::expr::expr!(System::Alternatives["Image", "Image3D"]);
let ldt = crate::expr::expr!(System::LibraryDataType[alts, type_name]);
crate::expr::expr!(System::List[ldt, "Shared"])
}
}
impl<'a> FromArg<'a> for &'a Image<()> {
unsafe fn from_arg(arg: &'a MArgument) -> &'a Image<()> {
Image::ref_cast(&*arg.image)
}
fn parameter_type() -> Expr {
let alts = crate::expr::expr!(System::Alternatives["Image", "Image3D"]);
crate::expr::expr!(System::List[alts, "Constant"])
}
}
impl<'a> FromArg<'a> for Image<()> {
unsafe fn from_arg(arg: &'a MArgument) -> Image<()> {
Image::from_raw(*arg.image)
}
fn parameter_type() -> Expr {
let alts = crate::expr::expr!(System::Alternatives["Image", "Image3D"]);
crate::expr::expr!(System::List[alts, "Shared"])
}
}
/// Owned `DataStore` parameters, read through the `tensor` field of the argument.
impl FromArg<'_> for DataStore {
    /// Casts the value behind `arg.tensor` to `sys::DataStore` and wraps it with
    /// `DataStore::from_raw`.
    unsafe fn from_arg(arg: &MArgument) -> DataStore {
        let raw = *arg.tensor as sys::DataStore;
        DataStore::from_raw(raw)
    }

    /// The parameter type is the string expression `"DataStore"`.
    fn parameter_type() -> Expr {
        let type_name = "DataStore";
        Expr::string(type_name)
    }
}
/// Borrowed `DataStore` parameters, read through the `tensor` field of the argument.
impl<'a> FromArg<'a> for &'a DataStore {
    /// Reinterprets the `arg.tensor` pointer as a pointer to `sys::DataStore` and
    /// reborrows the pointee as a `DataStore` reference.
    ///
    /// `arg` is borrowed for `'a`, matching the trait declaration and the other
    /// borrowing impls in this file, so the returned reference cannot outlive the
    /// argument slot it was read from.
    unsafe fn from_arg(arg: &'a MArgument) -> &'a DataStore {
        let raw: &'a sys::DataStore = &*(arg.tensor as *mut sys::DataStore);
        DataStore::ref_cast(raw)
    }

    /// Always panics: `&DataStore` has no parameter type expression.
    fn parameter_type() -> Expr {
        panic!("&DataStore cannot be used as a LibraryLink function parameter type")
    }
}
/// `()` return values: nothing is written to the return slot.
impl IntoArg for () {
    /// Leaves the return slot untouched.
    unsafe fn into_arg(self, _arg: MArgument) {
        // Intentionally empty: there is no value to store.
    }

    /// The return type is the string expression `"Void"`.
    fn return_type() -> Expr {
        let type_name = "Void";
        Expr::string(type_name)
    }
}
/// `bool` return values are written through the `boolean` field of the return slot.
impl IntoArg for bool {
    /// Maps `self` to `sys::True` / `sys::False` and stores it as `sys::mbool`.
    unsafe fn into_arg(self, arg: MArgument) {
        let flag: u32 = match self {
            true => sys::True,
            false => sys::False,
        };
        *arg.boolean = flag as sys::mbool;
    }

    /// The return type is the string expression `"Boolean"`.
    fn return_type() -> Expr {
        let type_name = "Boolean";
        Expr::string(type_name)
    }
}
/// `mint` return values are written through the `integer` field of the return slot.
impl IntoArg for mint {
    /// Stores `self` at the location `arg.integer` points to.
    unsafe fn into_arg(self, arg: MArgument) {
        let slot = arg.integer;
        *slot = self;
    }

    /// The return type is the `System::Integer` expression.
    fn return_type() -> Expr {
        let ty: Expr = expr!(System::Integer);
        ty
    }
}
/// `mreal` return values are written through the `real` field of the return slot.
impl IntoArg for mreal {
    /// Stores `self` at the location `arg.real` points to.
    unsafe fn into_arg(self, arg: MArgument) {
        let slot = arg.real;
        *slot = self;
    }

    /// The return type is the `System::Real` expression.
    fn return_type() -> Expr {
        let ty: Expr = expr!(System::Real);
        ty
    }
}
/// `mcomplex` return values are written through the `cmplex` field of the return slot.
impl IntoArg for sys::mcomplex {
    /// Stores `self` at the location `arg.cmplex` points to.
    unsafe fn into_arg(self, arg: MArgument) {
        let slot = arg.cmplex;
        *slot = self;
    }

    /// The return type is the `System::Complex` expression.
    fn return_type() -> Expr {
        let ty: Expr = expr!(System::Complex);
        ty
    }
}
impl IntoArg for i8 {
unsafe fn into_arg(self, arg: MArgument) {
*arg.integer = mint::from(self);
}
fn return_type() -> Expr {
expr!(System::Integer)
}
}
impl IntoArg for i16 {
unsafe fn into_arg(self, arg: MArgument) {
*arg.integer = mint::from(self);
}
fn return_type() -> Expr {
expr!(System::Integer)
}
}
impl IntoArg for i32 {
unsafe fn into_arg(self, arg: MArgument) {
*arg.integer = mint::from(self);
}
fn return_type() -> Expr {
expr!(System::Integer)
}
}
impl IntoArg for u8 {
unsafe fn into_arg(self, arg: MArgument) {
*arg.integer = mint::from(self);
}
fn return_type() -> Expr {
expr!(System::Integer)
}
}
impl IntoArg for u16 {
unsafe fn into_arg(self, arg: MArgument) {
*arg.integer = mint::from(self);
}
fn return_type() -> Expr {
expr!(System::Integer)
}
}
#[cfg(target_pointer_width = "64")]
impl IntoArg for u32 {
unsafe fn into_arg(self, arg: MArgument) {
*arg.integer = mint::from(self);
}
fn return_type() -> Expr {
expr!(System::Integer)
}
}
// Per-thread slot holding the most recently returned `CString`.
//
// `<CString as IntoArg>::into_arg` stores the string here and hands out a raw pointer to
// its buffer; the stored string is dropped the next time a string is returned on the
// same thread.
thread_local! {
static RETURNED_STRING: RefCell<Option<CString>> = RefCell::new(None);
}
/// `CString` return values are kept alive in `RETURNED_STRING` and returned by pointer.
impl IntoArg for CString {
    /// Moves `self` into the thread-local `RETURNED_STRING` slot (dropping whatever was
    /// stored there before) and writes a pointer to its buffer into `arg.utf8string`.
    unsafe fn into_arg(self, arg: MArgument) {
        let ptr: *const c_char = RETURNED_STRING.with(|slot| {
            // Release the string kept alive by the previous call, if there was one.
            drop(slot.take());

            // The heap buffer does not move when the `CString` itself is moved into
            // the slot, so the pointer taken here stays valid while it is stored.
            let ptr = self.as_ptr();
            *slot.borrow_mut() = Some(self);
            ptr
        });

        *arg.utf8string = ptr as *mut c_char;
    }

    /// The return type is the `System::String` expression.
    fn return_type() -> Expr {
        let ty: Expr = expr!(System::String);
        ty
    }
}
/// `String` return values are converted to `CString` and returned the same way.
impl IntoArg for String {
    /// Converts `self` into a [`CString`] and delegates to `<CString as IntoArg>`.
    ///
    /// # Panics
    ///
    /// Panics if `CString::new` rejects the string.
    unsafe fn into_arg(self, arg: MArgument) {
        let converted = CString::new(self);
        let cstring =
            converted.expect("IntoArg for String: could not convert String to CString");
        cstring.into_arg(arg)
    }

    /// The return type is the `System::String` expression.
    fn return_type() -> Expr {
        let ty: Expr = expr!(System::String);
        ty
    }
}
impl<T: crate::NumericArrayType> IntoArg for NumericArray<T> {
unsafe fn into_arg(self, arg: MArgument) {
*arg.numeric = self.into_raw();
}
fn return_type() -> Expr {
let type_name = T::TYPE.name();
crate::expr::expr!(System::LibraryDataType["NumericArray", type_name])
}
}
impl IntoArg for NumericArray<()> {
unsafe fn into_arg(self, arg: MArgument) {
*arg.numeric = self.into_raw();
}
fn return_type() -> Expr {
crate::expr::expr!("NumericArray")
}
}
impl<T: crate::ImageData> IntoArg for Image<T> {
unsafe fn into_arg(self, arg: MArgument) {
*arg.image = self.into_raw();
}
fn return_type() -> Expr {
let type_name = T::TYPE.name();
let alts = crate::expr::expr!(System::Alternatives["Image", "Image3D"]);
let ldt = crate::expr::expr!(System::LibraryDataType[alts, type_name]);
crate::expr::expr!(System::List[ldt, "Shared"])
}
}
/// `DataStore` return values are written through the `tensor` field of the return slot.
impl IntoArg for DataStore {
    /// Converts `self` into its raw form, casts it to the pointer type of the `tensor`
    /// field, and stores it at the location `arg.tensor` points to.
    unsafe fn into_arg(self, arg: MArgument) {
        let raw = self.into_raw();
        *arg.tensor = raw as *mut _;
    }

    /// The return type is the string expression `"DataStore"`.
    fn return_type() -> Expr {
        let type_name = "DataStore";
        Expr::string(type_name)
    }
}
/// Functions that work directly on the raw argument slice and return slot.
impl<'a: 'b, 'b> NativeFunction<'a> for fn(&'b [MArgument], MArgument) {
    /// Forwards `args` and `ret` to the wrapped function unchanged.
    unsafe fn call(&self, args: &'a [MArgument], ret: MArgument) {
        (*self)(args, ret)
    }

    /// Always returns `Err`: the function's signature carries no type information.
    fn signature(&self) -> Result<(Vec<Expr>, Expr), String> {
        let message = String::from(
            "fn(&[MArgument], MArgument) function cannot be loaded automatically: \
             parameter and return types are unknown.",
        );
        Err(message)
    }
}
// Implements `NativeFunction` for `fn` pointers whose parameters are the listed types.
//
// Every `$type` identifier is used twice: as a generic type parameter of the impl, and
// inside `call` as the name of the local variable bound to the corresponding argument.
// This works because types and values live in different namespaces.
macro_rules! impl_NativeFunction {
($($type:ident),*) => {
impl<'a, $($type,)* R> NativeFunction<'a> for fn($($type),*) -> R
where
R: IntoArg,
$($type: FromArg<'a>),*
{
unsafe fn call(&self, args: &'a [MArgument], ret: MArgument) {
// Destructure `args` into one `&MArgument` binding per parameter; the slice
// pattern only matches when the argument count equals the parameter count.
#[allow(non_snake_case)]
let [$($type,)*] = match args {
[$($type,)*] => [$($type,)*],
_ => panic!(
"LibraryLink function number of arguments ({}) does not match \
number of parameters",
args.len()
),
};
// Convert each raw argument into its Rust type, shadowing the raw binding.
$(
#[allow(non_snake_case)]
let $type: $type = $type::from_arg($type);
)*
// Call the wrapped function and store its result in the return slot.
let result: R = self($($type,)*);
result.into_arg(ret);
}
// Collects the parameter type of every `$type`, in order, plus the return type.
fn signature(&self) -> Result<(Vec<Expr>, Expr), String> {
let mut param_tys = Vec::new();
$(
param_tys.push($type::parameter_type());
)*
Ok((param_tys, R::return_type()))
}
}
}
}
/// Functions that take no parameters.
impl<'a, R> NativeFunction<'a> for fn() -> R
where
    R: IntoArg,
{
    /// Calls the wrapped function and stores its result in `ret`.
    ///
    /// # Panics
    ///
    /// Panics if `args` is not empty.
    unsafe fn call(&self, args: &[MArgument], ret: MArgument) {
        let supplied = args.len();
        if supplied > 0 {
            panic!(
                "LibraryLink function number of arguments ({}) does not match number of \
                 parameters",
                supplied
            );
        }

        let value: R = (*self)();
        value.into_arg(ret);
    }

    /// Returns an empty parameter list together with `R::return_type()`.
    fn signature(&self) -> Result<(Vec<Expr>, Expr), String> {
        let no_params: Vec<Expr> = Vec::new();
        Ok((no_params, R::return_type()))
    }
}
// Generate `NativeFunction` impls for `fn` pointers taking 1 through 12 parameters.
// The zero-parameter case is implemented by hand, separately from the macro.
impl_NativeFunction!(A1);
impl_NativeFunction!(A1, A2);
impl_NativeFunction!(A1, A2, A3);
impl_NativeFunction!(A1, A2, A3, A4);
impl_NativeFunction!(A1, A2, A3, A4, A5);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6, A7);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6, A7, A8);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6, A7, A8, A9);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11);
impl_NativeFunction!(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12);
/// `WstpFunction` implementations and the helpers they share.
#[cfg(feature = "wstp")]
mod wstp_impls {
    use super::*;

    /// Functions that drive the link themselves.
    impl WstpFunction for fn(&mut Link) {
        /// Hands `link` to the wrapped function unchanged.
        unsafe fn call(&self, link: &mut Link) {
            (*self)(link)
        }
    }

    /// Functions that take the argument list and return an expression.
    impl WstpFunction for fn(Vec<Expr>) -> Expr {
        /// Reads the argument list from `link`, calls the wrapped function, and writes
        /// its result back. Any failure is reported through `write_arg_failure`.
        unsafe fn call(&self, link: &mut Link) {
            let arguments = match get_args_list(link) {
                Ok(list) => list,
                Err(err) => {
                    write_arg_failure(link, err);
                    return;
                },
            };

            let returned: Expr = (*self)(arguments);

            match link.put_expr(&returned) {
                Ok(_) => {},
                Err(err) => write_arg_failure(link, err.into()),
            }
        }
    }

    /// Functions that take the argument list and return nothing.
    impl WstpFunction for fn(Vec<Expr>) {
        /// Reads the argument list from `link`, calls the wrapped function, and writes
        /// the symbol ``System`Null`` back. Any failure is reported through
        /// `write_arg_failure`.
        unsafe fn call(&self, link: &mut Link) {
            let arguments = match get_args_list(link) {
                Ok(list) => list,
                Err(err) => {
                    write_arg_failure(link, err);
                    return;
                },
            };

            (*self)(arguments);

            match link.put_symbol("System`Null") {
                Ok(_) => {},
                Err(err) => write_arg_failure(link, err.into()),
            }
        }
    }

    /// Writes `err` to `link` as a failure; an error while writing is ignored.
    fn write_arg_failure(link: &mut Link, err: crate::LibraryError) {
        let _ = crate::macro_utils::write_failure_to_link(link, &err);
    }

    /// Reads the argument list from `link`, converting a WSTP error into a
    /// `crate::LibraryError`.
    fn get_args_list(link: &mut Link) -> Result<Vec<Expr>, crate::LibraryError> {
        let arguments = get_args_list_impl(link)?;
        Ok(arguments)
    }

    /// Reads a `List[...]` expression from `link` and returns its elements.
    ///
    /// The head is first tested against `List`; if that fails with `WSEGSEQ`, the error
    /// is cleared and the head is tested against ``System`List`` instead.
    fn get_args_list_impl(link: &mut Link) -> Result<Vec<Expr>, wstp::Error> {
        let arg_count: usize = match link.test_head("List") {
            Err(err) if err.code() == Some(wstp::sys::WSEGSEQ) => {
                link.clear_error();
                link.test_head("System`List")?
            },
            other => other?,
        };

        let mut elements: Vec<Expr> = Vec::new();

        for _ in 0..arg_count {
            // Symbol names are resolved by prefixing them with the ``System` `` context.
            elements.push(link.get_expr_with_resolver(&mut |symbol_name| {
                Symbol::try_new(&format!("System`{symbol_name}"))
            })?);
        }

        Ok(elements)
    }
}