use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemFn, ItemForeignMod, ForeignItem, Expr, FnArg, Pat, ReturnType, Ident};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::Error;
/// Arguments accepted by the `mod_fn` attribute.
struct ModuleFnMacroArgs {
    /// Value of the `name = ...` argument, if one was supplied.
    name: Option<Expr>,
}

impl Parse for ModuleFnMacroArgs {
    /// Parses a comma-separated list of meta items and keeps the value of the
    /// `name = ...` entry. Entries with any other shape or key are ignored;
    /// when `name` is given more than once, the last occurrence wins.
    fn parse(input: ParseStream) -> Result<Self, Error> {
        let metas = Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated(input)?;
        let name = metas
            .into_iter()
            .filter_map(|meta| match meta {
                syn::Meta::NameValue(pair) if pair.path.is_ident("name") => Some(pair.value),
                _ => None,
            })
            .last();
        Ok(Self { name })
    }
}
/// Arguments accepted by the `host_fns` attribute.
struct HostFnMacroArgs {
    /// Value of the `namespace = ...` argument, if one was supplied.
    namespace: Option<Expr>,
}

impl Parse for HostFnMacroArgs {
    /// Parses a comma-separated list of meta items and keeps the value of the
    /// `namespace = ...` entry. Entries with any other shape or key are
    /// ignored; when `namespace` is given more than once, the last occurrence
    /// wins.
    fn parse(input: ParseStream) -> Result<Self, Error> {
        let metas = Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated(input)?;
        let namespace = metas
            .into_iter()
            .filter_map(|meta| match meta {
                syn::Meta::NameValue(pair) if pair.path.is_ident("namespace") => Some(pair.value),
                _ => None,
            })
            .last();
        Ok(Self { namespace })
    }
}
#[proc_macro_attribute]
pub fn mod_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
let input_fn = parse_macro_input!(item as ItemFn);
let original_fn_name = &input_fn.sig.ident;
let macro_args = match syn::parse::<ModuleFnMacroArgs>(attr) {
Ok(args) => args,
Err(e) => return e.to_compile_error().into(),
};
let generated_fn_name = match macro_args.name {
Some(Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. })) => {
Ident::new(&lit_str.value(), lit_str.span())
},
_ => original_fn_name.clone(),
};
let mut arg_idents = Vec::new();
let mut fn_args = Vec::new();
for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
if let FnArg::Typed(pat_type) = arg {
if let Pat::Ident(pat_ident) = &*pat_type.pat {
let arg_name = &pat_ident.ident;
let arg_type = &pat_type.ty;
arg_idents.push(arg_name.clone());
fn_args.push(quote! {
let #arg_name = match input.get_arg::<#arg_type>(#i, stringify!(#arg_name)) {
Ok(val) => val,
Err(e) => return Err(::binmod_mdk::ModuleFnErr {
message: e.to_string(),
error_type: "ArgumentError".into(),
}),
};
});
}
}
}
let fn_return_type = match &input_fn.sig.output {
ReturnType::Default => quote! { ::binmod_mdk::ModuleFnResult::Data(
::binmod_mdk::ModuleFnReturn::empty()
) },
_ => quote! { ::binmod_mdk::ModuleFnResult::Data(
::binmod_mdk::ModuleFnReturn::new_serialized(result).unwrap()
) },
};
let arg_idents_tokens = arg_idents
.iter()
.map(|ident| quote! { #ident });
let expanded = quote! {
#input_fn
#[unsafe(no_mangle)]
pub unsafe extern "C" fn #generated_fn_name(input_ptr: u32, input_len: u32) -> u64 {
let input: ::binmod_mdk::ModuleFnInput = match ::binmod_mdk::deserialize_from_ptr(input_ptr, input_len) {
Ok(input) => input,
Err(e) => {
return match ::binmod_mdk::serialize_to_ptr(::binmod_mdk::ModuleFnResult::<()>::Error(
::binmod_mdk::ModuleFnErr {
message: e.to_string(),
error_type: "DeserializationError".into(),
}
)) {
Ok(ptr) => ptr,
Err(e) => ::binmod_mdk::serialize_to_ptr(::binmod_mdk::ModuleFnResult::<()>::Error(
::binmod_mdk::ModuleFnErr {
message: e.to_string(),
error_type: "SerializationError".into(),
}
)).unwrap_or(0),
}
}
};
let result = std::panic::catch_unwind(|| -> ::binmod_mdk::FnResult<_> {
#(#fn_args)*
#original_fn_name(#(#arg_idents_tokens),*)
});
let response = match result {
Ok(Ok(result)) => #fn_return_type,
Ok(Err(e)) => ::binmod_mdk::ModuleFnResult::Error(e),
Err(_) => ::binmod_mdk::ModuleFnResult::Error(
::binmod_mdk::ModuleFnErr {
message: "Panic occurred".into(),
error_type: "PanicError".into(),
}
),
};
match ::binmod_mdk::serialize_to_ptr(response) {
Ok(ptr) => ptr,
Err(e) => {
::binmod_mdk::serialize_to_ptr(::binmod_mdk::ModuleFnResult::<()>::Error(
::binmod_mdk::ModuleFnErr {
message: e.to_string(),
error_type: "SerializationError".into(),
}
)).unwrap_or(0)
}
}
}
};
TokenStream::from(expanded)
}
#[proc_macro_attribute]
pub fn host_fns(attr: TokenStream, item: TokenStream) -> TokenStream {
let macro_args = match syn::parse::<HostFnMacroArgs>(attr) {
Ok(args) => args,
Err(e) => return e.to_compile_error().into(),
};
let namespace = match macro_args.namespace {
Some(Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. })) => lit_str,
_ => syn::LitStr::new("env", proc_macro2::Span::call_site()),
};
let item = parse_macro_input!(item as ItemForeignMod);
let functions = item.items;
if item.abi.name.is_none() || item.abi.name.unwrap().value() != "host" {
panic!("Host functions must be in a foreign module with the `host` ABI");
}
let mut generated = quote! {};
for function in functions {
if let ForeignItem::Fn(func) = function {
let func_name = &func.sig.ident;
let raw_func_name = Ident::new(&format!("{}_raw", func_name), func_name.span());
let link_name_lit = syn::LitStr::new(&func_name.to_string(), func_name.span());
let params = func
.sig
.inputs
.iter()
.cloned()
.collect::<Vec<_>>();
let param_names = params
.iter()
.map(|param| {
if let FnArg::Typed(pat_type) = param {
if let Pat::Ident(pat_ident) = &*pat_type.pat {
&pat_ident.ident
} else {
panic!("Expected identifier in function argument");
}
} else {
panic!("Expected typed argument in function signature");
}
})
.collect::<Vec<_>>();
let inner_return_type = match &func.sig.output {
ReturnType::Default => quote! { () },
ReturnType::Type(_, ty) => quote! { #ty },
};
let wrapper = quote! {
#[allow(unused_unsafe)]
pub unsafe fn #func_name(#(#params),*) -> ::binmod_mdk::FnResult<#inner_return_type> {
let mut input = ::binmod_mdk::ModuleFnInput::new();
#(
match input.add_arg(#param_names) {
Ok(_) => {},
Err(e) => {
return Err(::binmod_mdk::ModuleFnErr {
message: e.to_string(),
error_type: "ArgumentError".into(),
});
}
}
)*
let input_ptr = match ::binmod_mdk::serialize_to_ptr(input) {
Ok(ptr) => ptr,
Err(e) => {
return Err(::binmod_mdk::ModuleFnErr {
message: e.to_string(),
error_type: "SerializationError".into(),
});
}
};
let result = unsafe { #raw_func_name(input_ptr) };
let (result_ptr, result_len) = ::binmod_mdk::unpack_ptr(result);
let result: ::binmod_mdk::ModuleFnResult<#inner_return_type> = match ::binmod_mdk::deserialize_from_ptr(
result_ptr as u32,
result_len as u32,
) {
Ok(res) => res,
Err(e) => {
unsafe {
host_dealloc(result_ptr as *mut u8, result_len as usize);
}
return Err(::binmod_mdk::ModuleFnErr {
message: e.to_string(),
error_type: "DeserializationError".into(),
});
}
};
unsafe {
host_dealloc(result_ptr as *mut u8, result_len as usize);
}
match result {
::binmod_mdk::ModuleFnResult::Data(data) => {
match data.value {
Some(value) => Ok(value),
None => Ok(Default::default()),
}
},
::binmod_mdk::ModuleFnResult::Error(err) => Err(err),
}
}
};
generated.extend(wrapper);
generated.extend(quote! {
#[link(wasm_import_module = #namespace)]
unsafe extern "C" {
#[link_name = #link_name_lit]
pub fn #raw_func_name(input_ptr: u64) -> u64;
}
});
}
}
generated.extend(quote! {
#[link(wasm_import_module = #namespace)]
unsafe extern "C" {
pub fn host_alloc(len: usize) -> *mut u8;
pub fn host_dealloc(ptr: *mut u8, len: usize);
}
});
generated.into()
}