use crate::utils;
use crate::utils::is_anyhow_result;
use proc_macro2::Span;
use proc_macro2::TokenStream;
use quote::quote;
use quote::ToTokens;
use syn::parse::Error;
use syn::parse::Parse;
use syn::parse::ParseStream;
use syn::parse::Result;
use syn::parse_quote;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::Attribute;
use syn::FnArg;
use syn::ImplItem;
use syn::ItemImpl;
use syn::Pat;
use syn::Receiver;
use syn::Signature;
use syn::Token;
use syn::Type;
impl ToTokens for Item {
    /// Emits the wrapped `impl` block (including any items added during expansion).
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let Self(item_impl) = self;
        item_impl.to_tokens(tokens);
    }
}
struct ExternType {
ident: syn::Ident,
ffi_ty: Type,
}
pub fn expand(input: &mut Item) -> Result<()> {
expand_types(&mut input.0.items);
let mut extern_enums = Vec::new();
for inner in &input.0.items {
if let ImplItem::Type(item_type) = inner {
if item_type.ident.to_string().ends_with("_Repr") {
extern_enums.push(ExternType {
ident: item_type.ident.clone(),
ffi_ty: item_type.ty.clone(),
});
}
}
}
let deprecated_infallible_attr: Attribute = parse_quote!(#[deprecated_infallible]);
let mut export_methods = Vec::new();
for inner in &mut input.0.items {
if let ImplItem::Method(method) = inner {
let mut deprecated_infallible = false;
method.attrs.retain(|attr| {
if utils::has_custom_attribute(attr, &deprecated_infallible_attr) {
deprecated_infallible = true;
false
} else {
true
}
});
let sig = &mut method.sig;
if let Some(exp) = get_export_function(sig, deprecated_infallible, &extern_enums)? {
export_methods.push(exp);
}
}
}
let import_method = expand_imports(&export_methods);
input.0.items.push(import_method);
input.0.items.extend(
export_methods
.into_iter()
.map(|func| syn::ImplItem::Method(func.item)),
);
Ok(())
}
pub struct Args;
impl Parse for Args {
fn parse(input: ParseStream<'_>) -> Result<Self> {
if input.is_empty() {
Ok(Self)
} else {
Err(Error::new(Span::call_site(), "expected #[host_exports]"))
}
}
}
/// A parser that consumes no tokens and always succeeds.
pub struct Nothing;

impl Parse for Nothing {
    /// Returns `Nothing` without reading from the stream; leftover tokens, if any, are left
    /// for the caller to deal with.
    fn parse(_: ParseStream<'_>) -> Result<Self> {
        Ok(Nothing)
    }
}
/// The `impl` block the attribute is applied to.
pub struct Item(ItemImpl);

impl Parse for Item {
    /// Parses outer attributes followed by an `impl` block, keeping those attributes on the
    /// resulting item. Any other leading token produces a lookahead error.
    fn parse(input: ParseStream<'_>) -> Result<Self> {
        let outer_attrs = input.call(Attribute::parse_outer)?;
        let lookahead = input.lookahead1();
        if !lookahead.peek(Token![impl]) {
            return Err(lookahead.error());
        }
        let mut item_impl = input.parse::<ItemImpl>()?;
        item_impl.attrs = outer_attrs;
        Ok(Item(item_impl))
    }
}
/// How a shim reports failure; selects the export's return type and logging hook.
enum FallibleMode {
    /// Shim returns a non-`anyhow` `Result`; the export returns `Result<(), Self::ApiError>`
    /// and the import closure logs via `Self::log_call`.
    Fallible,
    /// Shim is marked `#[deprecated_infallible]` and returns `anyhow::Result`; logged via
    /// `Self::log_deprecated_infallible`.
    DeprecatedInfallible,
    /// Shim returns `anyhow::Result`; logged via `Self::log_infallible_call`, and the import
    /// closure's `Ok` type is `()`.
    Infallible,
}
/// A generated `<base>_export` method plus what `expand_imports` needs to register it.
struct ExportFunction {
    /// Shim name with its `_shim` suffix removed; used for the import name and log calls.
    base_name: String,
    /// Error-reporting mode of the underlying shim.
    fallible_mode: FallibleMode,
    /// The generated export method itself.
    item: syn::ImplItemMethod,
}
fn get_export_function(
sig: &Signature,
deprecated_infallible: bool,
extern_enums: &[ExternType],
) -> Result<Option<ExportFunction>> {
let shim_ident = &sig.ident;
let shim_ident_str = shim_ident.to_string();
if !shim_ident_str.ends_with("_shim") {
return Ok(None);
}
let base_name = shim_ident_str.trim_end_matches("_shim");
let (res_type, is_unit, infallible) = match &sig.output {
syn::ReturnType::Default => {
return Err(Error::new(
sig.output.span(),
"unexpected return type: only Result is allowed!",
));
}
syn::ReturnType::Type(_, retty) => match retty.as_ref() {
syn::Type::Path(tp) => {
if let Some(parsed) = is_anyhow_result(tp) {
let parsed = parsed?;
match parsed {
Some(tp) => (Some(tp), false, true),
None => (Some(tp.clone()), true, true),
}
} else if utils::type_path_ends_with(tp, "Result") {
match utils::extract_first_generic_type(tp)? {
Some(tp) => (Some(tp), false, false),
None => (Some(tp.clone()), true, false),
}
} else {
return Err(Error::new(
retty.span(),
"unexpected return type: only Result is allowed!",
));
}
}
_ => return Err(Error::new(retty.span(), "unexpected return type")),
},
};
let fallible_mode = if deprecated_infallible {
if !infallible {
return Err(Error::new(
sig.output.span(),
"functions marked as #[deprecated_infallible] must return anyhow::Result<T>",
));
}
FallibleMode::DeprecatedInfallible
} else if infallible {
FallibleMode::Infallible
} else {
FallibleMode::Fallible
};
let mut is_self = None;
let mut args = Vec::new();
let mut optional_args = Vec::new();
for arg in &sig.inputs {
match arg {
FnArg::Receiver(Receiver {
reference: Some(_),
mutability,
..
}) => {
is_self = Some(mutability.is_some());
}
FnArg::Receiver(arg) => {
return Err(syn::Error::new(arg.span(), "must take self by reference"));
}
FnArg::Typed(arg) => {
if let Pat::Ident(pat) = &*arg.pat {
assert!(
pat.ident != "host_context",
"module context shouldn't be passed to shim function"
);
if pat.ident == "memory" {
optional_args.push(&pat.ident);
continue;
}
args.push((
&pat.ident,
convert_arg(&pat.ident, arg.ty.as_ref(), extern_enums)?,
));
} else {
return Err(syn::Error::new(
arg.span(),
"argument does not have an identifier",
));
}
}
}
}
let export_name = syn::Ident::new(&format!("{base_name}_export"), shim_ident.span());
let mut export_sig = {
let res_type = if matches!(fallible_mode, FallibleMode::Fallible) {
quote!(Result<(), Self::ApiError>)
} else {
quote!(anyhow::Result<()>)
};
let export_sig: ImplItem = parse_quote!(
fn #export_name(memory: &mut Self::Memory, host_context: &mut Self::Context) -> #res_type {}
);
if let ImplItem::Method(mut method) = export_sig {
for arg in args.iter().flat_map(|(_, inp)| inp.args.iter()) {
method.sig.inputs.push(arg.clone());
}
method.sig
} else {
unreachable!()
}
};
let params = {
let mut params: syn::punctuated::Punctuated<syn::Expr, syn::token::Comma> =
syn::punctuated::Punctuated::new();
for arg in optional_args {
params.push(parse_quote!(#arg));
}
for (ident, arg) in &args {
params.push(match &arg.from_wasm {
Some(block) => parse_quote!(#block),
None => parse_quote!(#ident),
});
}
params
};
let call = if is_self.is_some() {
quote!(Self::get(host_context)?.#shim_ident(#params))
} else {
quote!(Self::#shim_ident(#params))
};
let call = if let Some(tp) = res_type {
if is_unit {
quote!(#call)
} else {
export_sig
.inputs
.push(parse_quote!(__ark_ffi_output_ptr: u32));
let mut wrapped_call = None;
if let Some(last) = tp.path.segments.last() {
if last.ident == "Vec" || last.ident == "String" {
let ensure_bytes: Option<syn::Stmt> = if last.ident == "Vec" {
let elem_tp = utils::extract_single_generic_type(&tp)?;
if elem_tp.map_or(false, |tp| !utils::type_path_ends_with(&tp, "u8")) {
return Err(Error::new(
tp.span(),
"only Vec of u8 is allowed in return type position",
));
}
None
} else {
Some(parse_quote!(let res: Vec<u8> = res.into();))
};
wrapped_call = Some(quote!(#call.and_then(|res| {
#ensure_bytes
let output = memory.get_mut(__ark_ffi_output_ptr)?;
*output = res.len() as u32;
host_context.core.set_host_return_vec(res);
Ok(())
})));
}
}
match wrapped_call {
Some(w) => w,
None => {
quote!(#call.and_then(|res| {
let output = memory.get_mut(__ark_ffi_output_ptr)?;
*output = res;
Ok(())
}))
}
}
}
} else {
quote!(Ok(#call))
};
#[cfg(feature = "ffi_profiling")]
let profile = quote!(ark_profiler::function!(););
#[cfg(not(feature = "ffi_profiling"))]
let profile = quote!();
Ok(Some(ExportFunction {
item: parse_quote!(
#export_sig {
#profile
#call
}
),
fallible_mode,
base_name: base_name.to_string(),
}))
}
/// FFI-side representation of one shim argument.
struct ExportArg {
    /// Parameters that stand in for the argument in the generated export signature.
    args: Vec<syn::FnArg>,
    /// Block rebuilding the shim's argument from those parameters. `None` means the
    /// parameter is passed through unchanged under the argument's own name.
    from_wasm: Option<syn::Block>,
}
/// Returns `true` when `ty` is a type path consisting of the single identifier `str`.
fn is_str(ty: &syn::Type) -> bool {
    match ty {
        Type::Path(tp) => tp.path.get_ident().map_or(false, |id| id == "str"),
        _ => false,
    }
}
fn convert_arg(
ident: &syn::Ident,
ty: &syn::Type,
extern_enums: &[ExternType],
) -> Result<ExportArg> {
let export_arg = match ty {
syn::Type::Path(tp) => {
let mut param: syn::FnArg = parse_quote!(#ident: #ty);
let mut from_wasm = None;
let extern_enum = extern_enums.iter().find(|ee| {
if let Some(enum_str) = ee.ident.to_string().strip_suffix("_Repr") {
tp.path.segments.last().unwrap().ident == enum_str
} else {
false
}
});
if let Some(ee) = extern_enum {
let ffi_type = &ee.ffi_ty;
param = parse_quote!(#ident: #ffi_type);
from_wasm = Some(parse_quote!({
TryFrom::try_from(#ident).map_err(|_e| ApiError::invalid_arguments(""))? }));
}
ExportArg {
args: vec![param],
from_wasm,
}
}
syn::Type::Reference(tr) => {
let is_mut = tr.mutability.is_some();
if is_str(tr.elem.as_ref()) {
if is_mut {
return Err(syn::Error::new(ty.span(), "&mut str is not allowed!"));
}
let ident_ptr = syn::Ident::new(&format!("{ident}_ptr"), ident.span());
let ident_len = syn::Ident::new(&format!("{ident}_len"), ident.span());
ExportArg {
args: vec![
parse_quote!(#ident_ptr: u32),
parse_quote!(#ident_len: u32),
],
from_wasm: Some(parse_quote!({
memory.str(#ident_ptr, #ident_len)?
})),
}
} else if let syn::Type::Slice(inner) = tr.elem.as_ref() {
if let syn::Type::Path(_tp) = inner.elem.as_ref() {
let ident_ptr = syn::Ident::new(&format!("{ident}_ptr"), ident.span());
let ident_len = syn::Ident::new(&format!("{ident}_len"), ident.span());
ExportArg {
args: vec![
parse_quote!(#ident_ptr: u32),
parse_quote!(#ident_len: u32),
],
from_wasm: Some(if is_mut {
parse_quote!({
memory.slice_mut(#ident_ptr, #ident_len)?
})
} else {
parse_quote!({
memory.slice(#ident_ptr, #ident_len)?
})
}),
}
} else {
return Err(Error::new(tr.elem.span(), "not a simple type path"));
}
} else if let syn::Type::Path(_tp) = tr.elem.as_ref() {
let ident_ptr = syn::Ident::new(&format!("{ident}_ptr"), ident.span());
ExportArg {
args: vec![
parse_quote!(#ident_ptr: u32),
],
from_wasm: Some(if is_mut {
parse_quote!({
memory.get_mut(#ident_ptr)?
})
} else {
parse_quote!({
memory.get(#ident_ptr)?
})
}),
}
} else {
return Err(Error::new(tr.span(), "this type is not supported"));
}
}
_ => return Err(Error::new(ty.span(), "this type is not supported")),
};
Ok(export_arg)
}
/// Generates the `imports` method that registers every export with the wasmtime linker.
///
/// Each export is registered under `namespace` with the name `"{prefix}__{base_name}"`. The
/// registered closure:
/// * takes `mut caller: wasmtime::Caller<'_, ModuleContext>` followed by the export's FFI
///   parameters (every input after `memory` and `host_context`);
/// * obtains memory and host context from the caller and invokes the export;
/// * returns whatever the logging hook selected by the export's [`FallibleMode`] yields.
///
/// Registration errors are mapped into `InstantiationError::Import`.
fn expand_imports(export_functions: &[ExportFunction]) -> syn::ImplItem {
    let mut imports_method: syn::ImplItemMethod = parse_quote!(
        fn imports(wasmtime_linker: &mut Self::WasmLinker) -> Result<(), Self::ImportError> {}
    );
    let mut body: syn::Block = parse_quote!({
        let (namespace, prefix) = Self::namespace();
    });
    // Leading closure parameter shared by every registration.
    let caller_param: Punctuated<FnArg, Token![,]> =
        parse_quote!(mut caller: wasmtime::Caller<'_, ModuleContext>,);

    for export in export_functions {
        let sig = &export.item.sig;
        let export_ident = &sig.ident;
        // `memory` and `host_context` are the first two export inputs; the closure derives
        // them from `caller` instead of receiving them from the guest.
        let closure_inputs: Punctuated<&FnArg, &Token![,]> = caller_param
            .pairs()
            .chain(sig.inputs.pairs().skip(2))
            .collect();
        let actual_params: Punctuated<&syn::Ident, Token![,]> = sig
            .inputs
            .iter()
            .filter_map(|input| match input {
                FnArg::Typed(typed) => match &*typed.pat {
                    Pat::Ident(pat) => Some(&pat.ident),
                    _ => None,
                },
                FnArg::Receiver(_) => None,
            })
            .collect();
        let name = &export.base_name;
        let (log_call, ffi_ok_type) = match export.fallible_mode {
            FallibleMode::Fallible => (quote!(Self::log_call(#name, result)), quote!(u32)),
            FallibleMode::DeprecatedInfallible => (
                quote!(Self::log_deprecated_infallible(#name, result)),
                quote!(u32),
            ),
            FallibleMode::Infallible => (
                quote!(Self::log_infallible_call(#name, result)),
                quote!(()),
            ),
        };
        body.stmts.push(parse_quote!(
            wasmtime_linker.func_wrap(
                namespace,
                format!("{}__{}", prefix, #name).as_str(),
                move |#closure_inputs| -> anyhow::Result<#ffi_ok_type> {
                    let (mut memory, host_context) =
                        crate::wasm_util::get_host_context_from_caller(&mut caller);
                    let memory = &mut memory;
                    let result = Self::#export_ident(#actual_params);
                    #log_call
                },
            ).map_err(|err| InstantiationError::Import(err))?;
        ));
    }

    body.stmts.push(parse_quote!(return Ok(());));
    imports_method.block = body;
    syn::ImplItem::Method(imports_method)
}
/// Appends the associated types referenced by the generated code.
///
/// These are `ApiError`, `Memory`, `Context`, `WasmLinker` and `ImportError`, appended in
/// that order.
fn expand_types(input: &mut Vec<ImplItem>) {
    let associated_types: Vec<syn::ImplItemType> = vec![
        parse_quote!(type ApiError = ApiError;),
        parse_quote!(type Memory = crate::wasm_util::WasmMemoryHandle<'t>;),
        parse_quote!(type Context = ModuleContext;),
        parse_quote!(type WasmLinker = crate::host_api::WasmLinker;),
        parse_quote!(type ImportError = InstantiationError;),
    ];
    input.extend(associated_types.into_iter().map(ImplItem::Type));
}