use crate::utils;
use proc_macro2::Ident;
use proc_macro2::Span;
use proc_macro2::TokenStream;
use quote::ToTokens;
use syn::parse::Error;
use syn::parse::Parse;
use syn::parse::ParseStream;
use syn::parse::Result;
use syn::parse_quote;
use syn::spanned::Spanned;
use syn::Attribute;
use syn::FnArg;
use syn::ItemFn;
use syn::ItemMod;
use syn::ItemTrait;
use syn::Pat;
use syn::Signature;
use syn::Token;
use syn::TraitItemMethod;
use syn::Type;
#[derive(Clone)]
pub struct Args {
imports: String,
}
impl Parse for Args {
fn parse(input: ParseStream<'_>) -> Result<Self> {
match try_parse(input) {
Ok(args) if input.is_empty() => Ok(args),
Ok(_) | Err(_) => Err(Error::new(
Span::call_site(),
"expected #[ark_bindgen(imports = \"<name>\"]",
)),
}
}
}
mod kw {
syn::custom_keyword!(imports);
}
fn try_parse(input: ParseStream<'_>) -> Result<Args> {
if input.peek(kw::imports) {
input.parse::<kw::imports>()?;
input.parse::<Token![=]>()?;
let name = input.parse::<proc_macro2::Literal>()?;
let name = name.to_string();
let name = &name[1..name.len() - 1];
Ok(Args {
imports: name.to_owned(),
})
} else {
Err(Error::new(Span::call_site(), "no imports"))
}
}
pub struct Item(ItemMod);
impl Parse for Item {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let attrs = input.call(Attribute::parse_outer)?;
let mut lookahead = input.lookahead1();
if lookahead.peek(Token![pub]) {
let ahead = input.fork();
ahead.parse::<Token![pub]>()?;
lookahead = ahead.lookahead1();
}
if lookahead.peek(Token![mod]) {
let mut item: ItemMod = input.parse()?;
item.attrs = attrs;
Ok(Self(item))
} else {
Err(lookahead.error())
}
}
}
impl ToTokens for Item {
fn to_tokens(&self, tokens: &mut TokenStream) {
self.0.to_tokens(tokens);
}
}
/// Data shared by the individual expansion steps.
struct Context {
    /// Identifier of the annotated module. It prefixes the raw import names and is
    /// the second element returned by the generated `HostShim::namespace`.
    mod_name: Ident,
}

/// Rewrites the annotated module in place.
///
/// The steps are:
/// 1. Validate the parameter and field types.
/// 2. Add derives to the enums.
/// 3. Rename and rewrite the extern functions into their raw FFI form.
/// 4. Append the safe wrappers, the `HostShim` trait and, when needed, the `use` items.
///
/// # Errors
/// Returns the first validation or translation error encountered.
pub fn expand(input: &mut Item, args: Args) -> Result<()> {
    // Validation runs before any mutation of the module.
    validate_input_and_output_types(&input.0)?;
    validate_types(&input.0)?;

    let ctx = Context {
        mod_name: input.0.ident.clone(),
    };

    // The struct items are collected but not needed here.
    let (foreign_mod, enum_items, _) = extract_items_from_input(input)?;
    let extern_enums = translate_enums(enum_items)?;
    let extern_funcs = expand_extern_functions(foreign_mod, &extern_enums, &ctx, &args)?;

    expand_safe_mod(&extern_funcs, &extern_enums, input);
    expand_host_shim(&extern_funcs, &extern_enums, input, &ctx, &args)?;
    expand_use(&extern_enums, input);
    Ok(())
}
fn validate_input_and_output_types(item_mod: &ItemMod) -> Result<()> {
use syn::visit;
struct FindSigs(FindTypes);
impl<'a> visit::Visit<'a> for FindSigs {
fn visit_signature(&mut self, node: &'a Signature) {
visit::visit_signature(&mut self.0, node);
}
}
struct FindTypes(Option<Error>);
impl<'a> visit::Visit<'a> for FindTypes {
fn visit_fn_arg(&mut self, node: &'a FnArg) {
visit::visit_fn_arg(self, node);
if let syn::FnArg::Typed(pt) = node {
if let syn::Type::Path(path) = &*pt.ty {
for (invalid_type, suggested_type) in [
("bool", "u32"),
("u16", "u32"),
("u8", "u32"),
("i16", "i32"),
("i8", "i32"),
] {
if path.path.is_ident(invalid_type) {
self.0 = Some(Error::new_spanned(
path,
format!("FFI parameters of type `{invalid_type}` are not supported at the moment. Pass it as a `{suggested_type}` instead.")
));
break; }
}
}
}
}
}
let mut sigs = FindSigs(FindTypes(None));
visit::visit_item_mod(&mut sigs, item_mod);
if let Some(error) = (sigs.0).0 {
Err(error)
} else {
Ok(())
}
}
fn validate_types(item_mod: &ItemMod) -> Result<()> {
use syn::visit;
use syn::visit::Visit;
const DISALLOWED_TYPES: &[&str] = &["usize", "isize", "u128", "i128"];
struct TypeVisitor(Option<Error>);
impl<'a> visit::Visit<'a> for TypeVisitor {
fn visit_type(&mut self, t: &'a Type) {
visit::visit_type(self, t);
if let Type::Path(p) = t {
for ident in DISALLOWED_TYPES {
if p.path.is_ident(ident) {
let e = Error::new_spanned(p, format!("`{ident}` is not FFI-safe"));
match &mut self.0 {
Some(e2) => e2.combine(e),
None => self.0 = Some(e),
}
}
}
}
}
}
let mut visitor = TypeVisitor(None);
visitor.visit_item_mod(item_mod);
match visitor.0 {
Some(e) => Err(e),
None => Ok(()),
}
}
/// Return types whose bytes are copied out of the host with
/// `core__take_host_return_vec` after the initial call reports their size.
enum PreciseType {
    /// `Vec<u8>`.
    ByteVec,
    /// `String`; the received bytes are decoded as UTF-8.
    String,
}

impl PreciseType {
    /// Builds a compile-time assertion that the declared return type is exactly the
    /// expected one.
    ///
    /// The expected type is `FFIResult<_>`-wrapped for fallible functions and the bare
    /// type otherwise.
    fn make_type_asserts(
        &self,
        fallible_mode: FallibleMode,
        original_result_type: &syn::Type,
    ) -> syn::Stmt {
        let inner = match self {
            Self::String => quote::quote!(::std::string::String),
            Self::ByteVec => quote::quote!(::std::vec::Vec<u8>),
        };
        let expected = if fallible_mode.is_fallible() {
            quote::quote!(crate::error_code::FFIResult<#inner>)
        } else {
            inner
        };
        parse_quote! {
            static_assertions::assert_type_eq_all!(
                #original_result_type,
                #expected
            );
        }
    }
}
/// What a wrapped extern function hands back to its caller.
enum ReturnType {
    /// `String` / `Vec<u8>`, fetched from the host through a byte-size out-parameter.
    PreciseType(PreciseType),
    /// Any other type path, written by the host through a pointer out-parameter.
    GenericPod(syn::TypePath),
    /// No value besides the call's success or error status.
    UnitType,
}

impl ReturnType {
    /// Builds the compile-time checks for the declared return type
    /// `original_return_type`.
    ///
    /// The checks per variant are:
    /// - `PreciseType`: delegated to [`PreciseType::make_type_asserts`].
    /// - `GenericPod(T)`: `T` must be `Copy + Clone`. For fallible functions the
    ///   declared type must also be exactly `FFIResult<T>`.
    /// - `UnitType`: the declared type must be `FFIResult<()>`.
    ///
    /// # Panics
    /// Panics for `UnitType` when `fallible_mode` is not fallible.
    fn make_type_asserts(
        &self,
        fallible_mode: FallibleMode,
        original_return_type: &syn::Type,
    ) -> syn::Stmt {
        let pod_ty = match self {
            Self::PreciseType(precise) => {
                return precise.make_type_asserts(fallible_mode, original_return_type);
            }
            Self::UnitType => {
                assert!(fallible_mode.is_fallible());
                return parse_quote!({
                    static_assertions::assert_type_eq_all!(
                        crate::error_code::FFIResult<()>,
                        #original_return_type
                    );
                });
            }
            Self::GenericPod(pod_ty) => pod_ty,
        };
        let result_assert = fallible_mode.is_fallible().then(|| {
            quote::quote!(
                static_assertions::assert_type_eq_all!(
                    crate::error_code::FFIResult<#pod_ty>,
                    #original_return_type
                );
            )
        });
        parse_quote!({
            trait ValidReturn {}
            impl<T: Copy + Clone> ValidReturn for T {}
            static_assertions::assert_impl_all!(#pod_ty: ValidReturn);
            #result_assert
        })
    }
}
/// How an extern function reports failure, derived from its declared return type.
#[derive(Clone, Copy)]
enum FallibleMode {
    /// The function returns `FFIResult<T>` / `Result<T, ErrorCode>`. The raw import
    /// returns an `ErrorCode`, which the wrapper propagates as `Err`.
    Fallible,
    /// The function is marked `#[deprecated_infallible]`. The raw import returns an
    /// `ErrorCode`, but the wrapper asserts success instead of propagating it.
    DeprecatedInfallible,
    /// The function has a plain return type, and the raw import returns nothing.
    Infallible,
}

impl FallibleMode {
    /// Panic message when a `#[deprecated_infallible]` function hits an error code.
    const DEPRECATED_INFALLIBLE_PANIC: &'static str =
        "unexpected error in deprecated infallible function";
    /// Panic message when a plain infallible function hits an error code.
    /// Previously this reused the misleading "deprecated" message.
    const INFALLIBLE_PANIC: &'static str = "unexpected error in infallible function";

    fn is_fallible(self) -> bool {
        matches!(self, Self::Fallible)
    }

    /// `if <res_code> != Success { return Err(<res_code>); }`
    fn propagate_error(res_code: &syn::Ident) -> syn::Stmt {
        parse_quote!(if #res_code != crate::ErrorCode::Success {
            return Err(#res_code);
        })
    }

    /// `assert_eq!(<res_code>, Success, "<message>");`
    fn assert_success(res_code: &syn::Ident, message: &str) -> syn::Stmt {
        parse_quote!(assert_eq!(
            #res_code,
            crate::ErrorCode::Success,
            #message
        );)
    }

    /// Check for the code returned by the raw import itself. Returns `None` for
    /// `Infallible`, whose raw import returns no code.
    fn check_ffi_return(self, res_code: syn::Ident) -> Option<syn::Stmt> {
        match self {
            Self::Fallible => Some(Self::propagate_error(&res_code)),
            Self::DeprecatedInfallible => Some(Self::assert_success(
                &res_code,
                Self::DEPRECATED_INFALLIBLE_PANIC,
            )),
            Self::Infallible => None,
        }
    }

    /// Check for a code that is always present, whatever the mode (e.g. the result
    /// of `core__take_host_return_vec`).
    fn ensure_ffi_success(self, res_code: syn::Ident) -> syn::Stmt {
        match self {
            Self::Fallible => Self::propagate_error(&res_code),
            Self::DeprecatedInfallible => {
                Self::assert_success(&res_code, Self::DEPRECATED_INFALLIBLE_PANIC)
            }
            Self::Infallible => Self::assert_success(&res_code, Self::INFALLIBLE_PANIC),
        }
    }

    /// `return Ok(<result>);` for fallible functions, `return <result>;` otherwise.
    fn return_result(self, result: syn::Ident) -> syn::Stmt {
        match self {
            Self::Fallible => parse_quote!(return Ok(#result);),
            Self::DeprecatedInfallible | Self::Infallible => {
                parse_quote!(return #result;)
            }
        }
    }

    /// `return Ok(());` for fallible functions, nothing otherwise.
    fn return_unit(self) -> Option<syn::Stmt> {
        match self {
            Self::Fallible => Some(parse_quote!(return Ok(());)),
            Self::DeprecatedInfallible | Self::Infallible => None,
        }
    }
}
/// A function declared in the module's `extern` block, with what is needed to
/// generate its safe wrapper and its `HostShim` methods.
struct ExternFn {
    /// The signature as written by the user, before the FFI rewrite.
    sig: Signature,
    /// Attributes remaining after `#[with_memory]` / `#[deprecated_infallible]`
    /// were stripped.
    attrs: Vec<Attribute>,
    /// Whether the function carried `#[with_memory]`.
    with_memory: bool,
    fallible_mode: FallibleMode,
    return_type: ReturnType,
    /// Parameters of the rewritten FFI signature, in order.
    params: Vec<FfiParam>,
    /// Name of the rewritten, module-prefixed raw import.
    ffi_ident: Ident,
}

impl ExternFn {
    /// Compile-time assertion that the declared return type matches `return_type`.
    ///
    /// Returns `None` when the function declares no return type.
    ///
    /// # Panics
    /// Panics if no return type is declared but `return_type` is not `UnitType`.
    fn make_type_asserts(&self) -> Option<syn::Stmt> {
        match &self.sig.output {
            syn::ReturnType::Type(_, declared) => Some(
                self.return_type
                    .make_type_asserts(self.fallible_mode, declared),
            ),
            syn::ReturnType::Default => {
                assert!(matches!(self.return_type, ReturnType::UnitType));
                None
            }
        }
    }
}
/// An enum declared inside the bindgen module, paired with its `#[repr(..)]`
/// primitive.
pub(crate) struct ExternEnum {
    /// The enum's own name.
    pub ident: Ident,
    /// The `#[repr(..)]` primitive type. By-value parameters of this enum cross the
    /// FFI boundary as this type.
    ffi_ty: Type,
}
fn extract_items_from_input(
input: &mut Item,
) -> Result<(
&mut syn::ItemForeignMod,
Vec<&mut syn::ItemEnum>,
Vec<&syn::ItemStruct>,
)> {
let foreign_item;
let mut enum_items = Vec::new();
let mut struct_items = Vec::new();
let cloned_input = input.0.clone();
match &mut input.0.content {
Some(items) => {
let mut foreign = None;
for item in &mut items.1 {
match item {
syn::Item::ForeignMod(fmod) => match foreign {
Some(_) => {
return Err(Error::new(
fmod.span(),
"an extern module was already declared in this module",
));
}
None => foreign = Some(fmod),
},
syn::Item::Enum(ienum) => enum_items.push(ienum),
syn::Item::Struct(istruct) => struct_items.push(istruct as &syn::ItemStruct),
_ => {}
}
}
if let Some(fmod) = foreign {
foreign_item = fmod;
} else {
return Err(Error::new(
cloned_input.span(),
"the module doesn't contain an extern module",
));
}
}
None => {
return Err(Error::new(
cloned_input.span(),
"can't bindgen an empty module",
))
}
}
Ok((foreign_item, enum_items, struct_items))
}
fn translate_enums(enum_items: Vec<&mut syn::ItemEnum>) -> Result<Vec<ExternEnum>> {
let mut extern_enums = Vec::with_capacity(enum_items.len());
for enum_item in enum_items {
let mut enum_ffi_ident = None;
for attribute in &enum_item.attrs {
let ident = utils::search_enum_attribute(
attribute,
"repr",
&[
"u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128",
],
);
if ident.is_some() {
if enum_ffi_ident.is_none() {
enum_ffi_ident = ident;
} else {
return Err(Error::new(
attribute.span(),
"#[repr(primitive_type)] was already declared for this enum",
));
}
}
if utils::search_enum_attribute(
attribute,
"derive",
&["TryFromPrimitive", "IntoPrimitive"],
)
.is_some()
{
return Err(Error::new(attribute.span(),
"IntoPrimitive and TryFromPrimitive are generated by the ark_bindgen macro, don't declare them yourself!"));
}
if utils::search_enum_attribute(attribute, "derive", &["NoUninit", "CheckedBitPattern"])
.is_some()
{
return Err(Error::new(attribute.span(),
"For enums, NoUninit and CheckedBitPattern are generated by the ark_bindgen macro, don't declare them yourself!"));
}
}
let Some(ffi_ident) = enum_ffi_ident else {
return Err(Error::new(
enum_item.span(),
format!(
"{} enum is missing the '#[repr(primitive_type)]' attribute",
enum_item.ident
),
));
};
enum_item.attrs.push(
parse_quote!(#[derive(IntoPrimitive, TryFromPrimitive, NoUninit, CheckedBitPattern)]),
);
extern_enums.push(ExternEnum {
ident: enum_item.ident.clone(),
ffi_ty: parse_quote!(#ffi_ident),
});
}
Ok(extern_enums)
}
/// Rewrites every function in the `extern` block into its raw FFI form and
/// collects what is needed to generate wrappers.
///
/// The block is tagged with `#[link(wasm_import_module = "<imports>")]`. For each
/// function:
/// - the bindgen marker attributes are consumed;
/// - the original signature is recorded in an [`ExternFn`];
/// - the signature is rewritten in place via `to_ffi_sig`.
///
/// # Errors
/// Propagates errors from `to_ffi_sig`.
fn expand_extern_functions(
    foreign_mod: &mut syn::ItemForeignMod,
    extern_enums: &[ExternEnum],
    ctx: &Context,
    args: &Args,
) -> Result<Vec<ExternFn>> {
    let imports = &args.imports;
    foreign_mod
        .attrs
        .push(parse_quote!(#[link(wasm_import_module = #imports)]));

    let with_memory_attr: Attribute = parse_quote!(#[with_memory]);
    let deprecated_infallible_attr: Attribute = parse_quote!(#[deprecated_infallible]);

    let mut externs = Vec::new();
    for item in foreign_mod.items.iter_mut() {
        let syn::ForeignItem::Fn(func) = item else {
            continue;
        };

        // Strip the bindgen marker attributes, remembering which ones were present.
        let mut with_memory = false;
        let mut deprecated_infallible = false;
        func.attrs.retain(|attr| {
            if utils::has_custom_attribute(attr, &with_memory_attr) {
                with_memory = true;
            } else if utils::has_custom_attribute(attr, &deprecated_infallible_attr) {
                deprecated_infallible = true;
            } else {
                return true;
            }
            false
        });

        let user_sig = func.sig.clone();
        let (params, return_type, fallible_mode) = to_ffi_sig(
            &ctx.mod_name,
            &mut func.sig,
            extern_enums,
            deprecated_infallible,
        )?;
        externs.push(ExternFn {
            sig: user_sig,
            attrs: func.attrs.clone(),
            with_memory,
            fallible_mode,
            return_type,
            params,
            ffi_ident: func.sig.ident.clone(),
        });
    }
    Ok(externs)
}
fn to_ffi_sig(
namespace: &Ident,
sig: &mut Signature,
extern_enums: &[ExternEnum],
deprecated_infallible: bool,
) -> Result<(Vec<FfiParam>, ReturnType, FallibleMode)> {
let mut params = Vec::new();
sig.ident = syn::Ident::new(&format!("{namespace}__{}", sig.ident), sig.ident.span());
for param in &sig.inputs {
match param {
FnArg::Typed(pt) => {
if let Pat::Ident(pat) = &*pt.pat {
convert_ffi_param(&mut params, &pat.ident, pt.ty.as_ref(), extern_enums)?;
} else {
return Err(Error::new(pt.span(), "parameter is missing an identifier"));
}
}
other @ FnArg::Receiver(_) => {
return Err(Error::new(
other.span(),
"you're trying to pass a self via FFI, this will not work",
));
}
}
}
let (return_type, infallible) = return_to_ffi(sig, &mut params)?;
if deprecated_infallible && !infallible {
return Err(Error::new(
sig.output.span(),
"#[deprecated_infallible] only applies to plain, non-Result return types",
));
}
let fallible_mode = if deprecated_infallible {
FallibleMode::DeprecatedInfallible
} else if infallible {
FallibleMode::Infallible
} else {
FallibleMode::Fallible
};
sig.output = match fallible_mode {
FallibleMode::Fallible | FallibleMode::DeprecatedInfallible => {
parse_quote!(-> crate::ErrorCode)
}
FallibleMode::Infallible => syn::ReturnType::Default,
};
sig.inputs = params.iter().map(|fp| fp.param.clone()).collect();
Ok((params, return_type, fallible_mode))
}
fn return_to_ffi(sig: &Signature, params: &mut Vec<FfiParam>) -> Result<(ReturnType, bool)> {
let ty = match &sig.output {
syn::ReturnType::Default => {
return Ok((ReturnType::UnitType, true));
}
syn::ReturnType::Type(_, ty) => match ty.as_ref() {
Type::Path(tp) => tp.clone(),
_ => {
return Err(Error::new(
ty.span(),
"unhandled return kind: can only return plain types",
));
}
},
};
if utils::type_path_ends_with(&ty, "ErrorCode") {
return Err(Error::new(
ty.span(),
"Returning ErrorCode is deprecated, return FFIResult<()> instead",
));
}
let (ret_type_path, infallible) = if utils::type_path_ends_with(&ty, "FFIResult") {
let ok_type = utils::extract_single_generic_type(&ty)?;
match ok_type {
Some(type_path) => (type_path, false),
None => {
return Ok((ReturnType::UnitType, false));
}
}
} else if utils::type_path_ends_with(&ty, "Result") {
match utils::extract_generic_type(1, Some(2), &ty)? {
Some(err_ty) if utils::type_path_ends_with(&err_ty, "ErrorCode") => {}
_ => return Err(Error::new(ty.span(), "Result error type must be ErrorCode")),
};
let ok_type = utils::extract_first_generic_type(&ty)?;
match ok_type {
Some(type_path) => (type_path, false),
None => {
return Ok((ReturnType::UnitType, false));
}
}
} else {
(ty, true)
};
let (out_param, ret_type, reserved_param_name) = convert_ffi_result(ret_type_path)?;
let should_rename = params.iter().find(|p| {
if let FnArg::Typed(pat) = &p.param {
if let Pat::Ident(pat) = &*pat.pat {
return pat.ident == reserved_param_name;
}
}
false
});
if let Some(should_rename) = should_rename {
return Err(Error::new(
should_rename.param.span(),
"this name is reserved, please rename!",
));
}
params.push(out_param);
Ok((ret_type, infallible))
}
/// One parameter of a rewritten raw FFI signature.
struct FfiParam {
    /// The parameter as it appears in the raw `extern` declaration.
    param: syn::FnArg,
    /// Expression the safe wrapper passes for this parameter.
    to_ffi: syn::Expr,
    /// Type used for this parameter in the generated `HostShim::*_export` method.
    /// When `None`, the type from `param` is used unchanged.
    export_type: Option<Type>,
}
/// Returns `true` when `ty` is the single-segment path `str`, as in `&str`.
fn is_str(ty: &syn::Type) -> bool {
    let Type::Path(tp) = ty else {
        return false;
    };
    tp.path.get_ident().map(ToString::to_string).as_deref() == Some("str")
}
/// Chooses the out-parameter used to return a value of type `tp`.
///
/// Returns the out-parameter, the return kind and the out-parameter's name, which is
/// reserved for user parameters. There are two cases:
/// - `String` / `Vec<u8>`: a `__ark_byte_size: *mut u32` size out-parameter.
/// - any other type: a `__ark_ffi_output: *mut T` out-parameter.
///
/// # Errors
/// Returns an error for a `Vec` whose element type is not `u8`.
fn convert_ffi_result(tp: syn::TypePath) -> Result<(FfiParam, ReturnType, &'static str)> {
    let precise = if utils::type_path_ends_with(&tp, "String") {
        PreciseType::String
    } else if utils::type_path_ends_with(&tp, "Vec") {
        let is_byte_vec = matches!(
            utils::extract_single_generic_type(&tp)?,
            Some(elem) if utils::type_path_ends_with(&elem, "u8")
        );
        if !is_byte_vec {
            return Err(Error::new(
                tp.span(),
                "only supported return Vec type is Vec<u8>",
            ));
        }
        PreciseType::ByteVec
    } else {
        return Ok((
            FfiParam {
                param: parse_quote!(__ark_ffi_output: *mut #tp),
                to_ffi: parse_quote!(&mut __ark_ffi_output),
                export_type: Some(parse_quote!(u32)),
            },
            ReturnType::GenericPod(parse_quote!(#tp)),
            "__ark_ffi_output",
        ));
    };

    // Both buffer-returning kinds share the same size out-parameter.
    Ok((
        FfiParam {
            param: parse_quote!(__ark_byte_size: *mut u32),
            to_ffi: parse_quote!(&mut __ark_byte_size),
            export_type: Some(parse_quote!(u32)),
        },
        ReturnType::PreciseType(precise),
        "__ark_byte_size",
    ))
}
fn convert_ffi_param(
params: &mut Vec<FfiParam>,
ident: &syn::Ident,
ty: &syn::Type,
extern_enums: &[ExternEnum],
) -> Result<()> {
match ty {
syn::Type::Path(tp) => {
let (param, to_ffi) = match utils::type_path_is_enum(tp, extern_enums) {
Some(ee) => {
let enum_ty = &ee.ffi_ty;
(parse_quote!(#ident: #enum_ty), parse_quote!(#ident.into()))
}
None => {
if utils::type_path_ends_with(tp, "Vec") {
return Err(syn::Error::new(
ty.span(),
"Vec not supported in function parameter position, use slices instead",
));
} else {
(parse_quote!(#ident: #ty), parse_quote!(#ident))
}
}
};
params.push(FfiParam {
param,
to_ffi,
export_type: None,
});
}
syn::Type::Reference(tr) => {
let is_mut = tr.mutability.is_some();
if is_str(tr.elem.as_ref()) {
if is_mut {
return Err(syn::Error::new(
tr.span(),
"&mut str is not allowed, consider returning a String instead!",
));
}
let ident_ptr = syn::Ident::new(&format!("{ident}_ptr"), ident.span());
let ident_len = syn::Ident::new(&format!("{ident}_len"), ident.span());
params.push(FfiParam {
param: parse_quote!(#ident_ptr: *const u8),
to_ffi: parse_quote!(#ident.as_ptr()),
export_type: Some(parse_quote!(u32)),
});
params.push(FfiParam {
param: parse_quote!(#ident_len: u32),
to_ffi: parse_quote!(#ident.len() as u32),
export_type: None,
});
} else if let syn::Type::Slice(inner) = tr.elem.as_ref() {
if let syn::Type::Path(tp) = inner.elem.as_ref() {
let ident_ptr = syn::Ident::new(&format!("{ident}_ptr"), ident.span());
let ident_len = syn::Ident::new(&format!("{ident}_len"), ident.span());
params.push(FfiParam {
param: if is_mut {
parse_quote!(#ident_ptr: *mut #tp)
} else {
parse_quote!(#ident_ptr: *const #tp)
},
to_ffi: if is_mut {
parse_quote!(#ident.as_mut_ptr())
} else {
parse_quote!(#ident.as_ptr())
},
export_type: Some(parse_quote!(u32)),
});
params.push(FfiParam {
param: parse_quote!(#ident_len: u32),
to_ffi: parse_quote!(#ident.len() as u32),
export_type: None,
});
} else {
return Err(Error::new(tr.elem.span(), "not a simple type path"));
}
} else if let syn::Type::Path(tp) = tr.elem.as_ref() {
let ident_ptr = syn::Ident::new(&format!("{ident}_ptr"), ident.span());
params.push(FfiParam {
param: if is_mut {
parse_quote!(#ident_ptr: *mut #tp)
} else {
parse_quote!(#ident_ptr: *const #tp)
},
to_ffi: parse_quote!(#ident),
export_type: Some(parse_quote!(u32)),
});
} else {
return Err(Error::new(tr.span(), "this type is not supported"));
}
}
_ => return Err(Error::new(ty.span(), "this type is not supported")),
}
Ok(())
}
/// Appends one safe wrapper function per extern function to the bindgen module.
///
/// Each wrapper:
/// - keeps the user's original signature and attributes, and adds `#[inline]` and `pub`;
/// - passes each FFI parameter's `to_ffi` expression to the renamed raw import inside
///   `unsafe`;
/// - turns the status code and out-parameter back into the declared return type
///   according to the function's `FallibleMode`.
///
/// # Panics
/// Panics if the module has no inline content (`unwrap` on `input.0.content`).
fn expand_safe_mod(functions: &[ExternFn], enums: &[ExternEnum], input: &mut Item) {
    let mut safe_funcs = Vec::with_capacity(functions.len());
    for func in functions {
        let sig = &func.sig;
        // Start with an empty body; `block` is replaced below depending on the return kind.
        let mut safe_func: ItemFn = parse_quote!(
            #[inline]
            pub #sig {
            }
        );
        safe_func.attrs.extend_from_slice(&func.attrs);
        // Arguments of the raw call: one `to_ffi` expression per FFI parameter, in the
        // same order as the rewritten extern signature.
        let args = {
            let mut args: syn::punctuated::Punctuated<syn::Expr, syn::token::Comma> =
                syn::punctuated::Punctuated::new();
            for param in &func.params {
                args.push(param.to_ffi.clone());
            }
            args
        };
        let ffi_ident = &func.ffi_ident;
        let type_asserts = func.make_type_asserts();
        safe_func.block = match &func.return_type {
            // `String` / `Vec<u8>` are fetched in two steps:
            // 1. The raw call writes the byte size into `__ark_byte_size`.
            // 2. `core__take_host_return_vec` copies that many bytes into a freshly
            //    zeroed buffer.
            ReturnType::PreciseType(inner) => {
                let convert_buffer: Option<syn::Stmt> = match inner {
                    PreciseType::ByteVec => None,
                    PreciseType::String => Some(if func.fallible_mode.is_fallible() {
                        parse_quote!(
                            let buffer = String::from_utf8(buffer)
                                .map_err(|_decode_err| crate::ErrorCode::InternalError)?;
                        )
                    } else {
                        parse_quote!(
                            let buffer = String::from_utf8(buffer)
                                .expect("invalid utf8 bytes returned by infallible host function");
                        )
                    }),
                };
                // Check of the raw call's status (absent for plain infallible functions,
                // whose raw import returns nothing).
                let check_ffi_return = func.fallible_mode.check_ffi_return(parse_quote!(res_code));
                // Check of the buffer-copy call's status (always present).
                let check_core_result = func
                    .fallible_mode
                    .ensure_ffi_success(parse_quote!(res_code));
                let return_expr = func.fallible_mode.return_result(parse_quote!(buffer));
                parse_quote!({
                    #type_asserts;
                    let mut __ark_byte_size = 0;
                    let res_code = unsafe { #ffi_ident(#args) };
                    #check_ffi_return
                    let mut buffer = vec![0; __ark_byte_size as usize];
                    let res_code = unsafe {
                        crate::core_v4::core__take_host_return_vec(
                            buffer.as_mut_slice().as_mut_ptr(),
                            __ark_byte_size
                        )
                    };
                    #check_core_result
                    #convert_buffer
                    #return_expr
                })
            }
            // Any other value is written by the host through `&mut __ark_ffi_output`.
            ReturnType::GenericPod(type_path) => {
                let mut output: syn::Expr = parse_quote!(#type_path::default());
                // NOTE(review): for enum return types the local is initialised with
                // `Enum::try_from(<repr>::default())` *before* the raw call. Presumably
                // this fails (or panics) whenever the enum has no variant with that
                // discriminant, even if the host would have written a valid value.
                // Also, `convert_ffi_result` declares the out-parameter as `*mut Enum`,
                // and the value written by the host is returned without re-validation.
                // Confirm this is intended.
                if let Some(ee) = utils::type_path_is_enum(type_path, enums) {
                    if let Type::Path(tp) = &ee.ffi_ty {
                        let ident = &ee.ident;
                        output = if func.fallible_mode.is_fallible() {
                            parse_quote!(
                                #ident::try_from(#tp::default())
                                    .map_err(|_| crate::ErrorCode::InvalidArguments)?
                            )
                        } else {
                            parse_quote!(
                                #ident::try_from(#tp::default())
                                    .expect("invalid enum value")
                            )
                        };
                    }
                }
                let check_ffi_return = func.fallible_mode.check_ffi_return(parse_quote!(res_code));
                let return_expr = func
                    .fallible_mode
                    .return_result(parse_quote!(__ark_ffi_output));
                parse_quote!({
                    #type_asserts;
                    let mut __ark_ffi_output = #output;
                    let res_code = unsafe { #ffi_ident(#args) };
                    #check_ffi_return
                    #return_expr
                })
            }
            // No value: only the status (if any) is checked.
            ReturnType::UnitType => {
                let check_ffi_return = func.fallible_mode.check_ffi_return(parse_quote!(res_code));
                let return_unit = func.fallible_mode.return_unit();
                parse_quote!({
                    #type_asserts;
                    let res_code = unsafe { #ffi_ident(#args) };
                    #check_ffi_return
                    #return_unit
                })
            }
        };
        safe_funcs.push(syn::Item::Fn(safe_func));
    }
    input.0.content.as_mut().unwrap().1.append(&mut safe_funcs);
}
/// Appends the `HostShim<'t>` trait (gated on `not(target_arch = "wasm32")`) to the
/// bindgen module.
///
/// For every extern function `foo` the trait declares two methods:
/// - `foo_shim(&mut self, ...)` takes the original parameters, preceded by
///   `memory: &mut Self::Memory` for `#[with_memory]` functions. It returns the Ok
///   type as `Result<_, Self::ApiError>` when fallible and `anyhow::Result<_>`
///   otherwise.
/// - `foo_export(memory, host_context, ...)` takes the FFI parameters, each typed
///   with its `export_type` when set.
///
/// The trait also declares `imports`, a defaulted `namespace` returning
/// `(<imports>, "<module name>")`, and one associated type `<Enum>_Repr` per enum.
///
/// # Errors
/// Returns an error if an FFI parameter is a receiver or lacks an identifier pattern.
///
/// # Panics
/// Panics if the module has no inline content.
fn expand_host_shim(
    functions: &[ExternFn],
    extern_enums: &[ExternEnum],
    input: &mut Item,
    ctx: &Context,
    args: &Args,
) -> Result<()> {
    /// Adds `#[allow(clippy::too_many_arguments)]` to methods with more than seven inputs.
    fn allow_many_arguments(method: &mut TraitItemMethod) {
        if method.sig.inputs.len() > 7 {
            let clippy_attr: Attribute = parse_quote!(#[allow(clippy::too_many_arguments)]);
            method.attrs.push(clippy_attr);
        }
    }

    let mut shim_trait: ItemTrait = parse_quote!(
        #[cfg(not(target_arch = "wasm32"))]
        pub trait HostShim<'t> {
            type Memory;
            type Context;
            type WasmLinker;
            type ImportError;
            type ApiError;
        }
    );

    for func in functions {
        let sig = &func.sig;

        // `<name>_shim`
        {
            let method_ident = Ident::new(&format!("{}_shim", sig.ident), sig.span());
            let ok_type: syn::Type = match &func.return_type {
                ReturnType::PreciseType(PreciseType::ByteVec) => parse_quote!(Vec<u8>),
                ReturnType::PreciseType(PreciseType::String) => parse_quote!(String),
                ReturnType::GenericPod(ty) => parse_quote!(#ty),
                ReturnType::UnitType => parse_quote!(()),
            };
            let shim_return: syn::Type = if func.fallible_mode.is_fallible() {
                parse_quote!(Result<#ok_type, Self::ApiError>)
            } else {
                parse_quote!(anyhow::Result<#ok_type>)
            };
            let mut method: TraitItemMethod = parse_quote!(
                fn #method_ident(&mut self) -> #shim_return;
            );
            method.attrs.extend_from_slice(&func.attrs);
            if func.with_memory {
                let memory_arg: FnArg = parse_quote!(memory: &mut Self::Memory);
                method.sig.inputs.push(memory_arg);
            }
            method.sig.inputs.extend(sig.inputs.iter().cloned());
            allow_many_arguments(&mut method);
            shim_trait.items.push(syn::TraitItem::Method(method));
        }

        // `<name>_export`
        {
            let method_ident = Ident::new(&format!("{}_export", sig.ident), sig.span());
            let res_type: Type = if func.fallible_mode.is_fallible() {
                parse_quote!(Result<(), Self::ApiError>)
            } else {
                parse_quote!(anyhow::Result<()>)
            };
            let mut method: TraitItemMethod = parse_quote!(
                fn #method_ident<'a>(
                    memory: &mut Self::Memory,
                    host_context: &mut Self::Context
                ) -> #res_type;
            );
            method.attrs.extend_from_slice(&func.attrs);
            for param in &func.params {
                let pat_type = match &param.param {
                    FnArg::Typed(pat_type) => pat_type,
                    other @ FnArg::Receiver(_) => {
                        return Err(Error::new(
                            other.span(),
                            "you're trying to pass a self via FFI, this will not work",
                        ));
                    }
                };
                let Pat::Ident(pat) = &*pat_type.pat else {
                    return Err(Error::new(
                        pat_type.span(),
                        "parameter is missing an identifier",
                    ));
                };
                // Prefer the export-facing type when one was assigned.
                let ty = param
                    .export_type
                    .as_ref()
                    .unwrap_or_else(|| pat_type.ty.as_ref());
                let ident = &pat.ident;
                method.sig.inputs.push(parse_quote!(#ident: #ty));
            }
            allow_many_arguments(&mut method);
            shim_trait.items.push(syn::TraitItem::Method(method));
        }
    }

    // The spans below are taken from the trait as built so far, in the original order.
    {
        let imports_ident = Ident::new("imports", shim_trait.span());
        let method: TraitItemMethod = parse_quote!(
            fn #imports_ident(linker: &mut Self::WasmLinker) -> Result<(), Self::ImportError>;
        );
        shim_trait.items.push(syn::TraitItem::Method(method));
    }
    {
        let ns_ident = Ident::new("namespace", shim_trait.span());
        let mut method: TraitItemMethod = parse_quote!(
            fn #ns_ident() -> (&'static str, &'static str);
        );
        let imports = &args.imports;
        let prefix = &ctx.mod_name;
        method.default = Some(parse_quote!({
            (#imports, stringify!(#prefix))
        }));
        shim_trait.items.push(syn::TraitItem::Method(method));
    }
    for extern_enum in extern_enums {
        let enum_ident = Ident::new(&format!("{}_Repr", extern_enum.ident), shim_trait.span());
        shim_trait
            .items
            .push(syn::TraitItem::Type(parse_quote!(type #enum_ident;)));
    }

    input
        .0
        .content
        .as_mut()
        .unwrap()
        .1
        .push(syn::Item::Trait(shim_trait));
    Ok(())
}
/// Brings the traits used by the enum derives and the generated `try_from` calls
/// into scope.
///
/// The `use` items are added only when the module declares at least one enum.
///
/// # Panics
/// Panics if the module has no inline content.
fn expand_use(extern_enums: &[ExternEnum], input: &mut Item) {
    if !extern_enums.is_empty() {
        let uses: [syn::ItemUse; 3] = [
            parse_quote!(use num_enum::{IntoPrimitive, TryFromPrimitive};),
            parse_quote!(use bytemuck::{NoUninit, CheckedBitPattern};),
            parse_quote!(use std::convert::TryFrom;),
        ];
        let module_items = &mut input.0.content.as_mut().unwrap().1;
        module_items.extend(uses.map(syn::Item::Use));
    }
}