use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{parse_macro_input, FnArg, ItemFn, PatType, ReturnType, Type};
#[proc_macro_attribute]
pub fn jit_export(_attr: TokenStream, item: TokenStream) -> TokenStream {
let mut input = parse_macro_input!(item as ItemFn);
if input.sig.abi.is_none() {
input.sig.abi = Some(syn::parse_quote!(extern "C"));
}
input.attrs.push(syn::parse_quote!(
#[allow(improper_ctypes_definitions)]
));
let fn_name = input.sig.ident.clone();
let fn_name_str = fn_name.to_string();
let helper_mod = format_ident!("{}_jit", fn_name);
let param_types: Vec<&Type> = input
.sig
.inputs
.iter()
.filter_map(|arg| match arg {
FnArg::Typed(PatType { ty, .. }) => Some(ty.as_ref()),
FnArg::Receiver(_) => None,
})
.collect();
let return_type: Option<&Type> = match &input.sig.output {
ReturnType::Default => None,
ReturnType::Type(_, ty) => match ty.as_ref() {
Type::Tuple(t) if t.elems.is_empty() => None,
other => Some(other),
},
};
let arg_idents: Vec<_> = (0..param_types.len())
.map(|i| format_ident!("p{}", i))
.collect();
let arg_generics: Vec<_> = (0..param_types.len())
.map(|i| format_ident!("A{}", i))
.collect();
let sig_param_pushes: Vec<TokenStream2> = param_types
.iter()
.map(|ty| {
quote! {
<#ty as ::lower_ir_utils::JitParam>::push_params(&mut sig.params, ptr_ty);
}
})
.collect();
let sig_return_pushes = match return_type {
Some(rt) => quote! {
<#rt as ::lower_ir_utils::JitParam>::push_params(&mut sig.returns, ptr_ty);
},
None => quote! {},
};
let arg_lowers: Vec<TokenStream2> = arg_idents
.iter()
.map(|id| {
quote! {
<_ as ::lower_ir_utils::JitArg>::lower(#id, bcx, ptr_ty, &mut args_buf);
}
})
.collect();
let (call_ret_ty, call_ret_expr) = if return_type.is_some() {
(
quote! { ::lower_ir_utils::__reexport::cranelift_codegen::ir::Value },
quote! { bcx.inst_results(__inst)[0] },
)
} else {
(
quote! { ::lower_ir_utils::__reexport::cranelift_codegen::ir::Inst },
quote! { __inst },
)
};
let expanded = quote! {
#input
#[allow(non_snake_case, non_camel_case_types, dead_code)]
pub mod #helper_mod {
use super::*;
pub const NAME: &'static str = #fn_name_str;
pub fn symbol_addr() -> *const u8 {
super::#fn_name as *const u8
}
pub fn register(jb: &mut ::lower_ir_utils::__reexport::cranelift_jit::JITBuilder) {
jb.symbol(NAME, symbol_addr());
}
pub fn signature<M: ::lower_ir_utils::__reexport::cranelift_module::Module>(
module: &M,
) -> ::lower_ir_utils::__reexport::cranelift_codegen::ir::Signature {
let mut sig = module.make_signature();
let ptr_ty = module.target_config().pointer_type();
#(#sig_param_pushes)*
#sig_return_pushes
sig
}
pub fn declare<M: ::lower_ir_utils::__reexport::cranelift_module::Module>(
module: &mut M,
) -> ::lower_ir_utils::__reexport::cranelift_module::FuncId {
let sig = signature(module);
module
.declare_function(NAME, ::lower_ir_utils::__reexport::cranelift_module::Linkage::Import, &sig)
.expect("declare_function failed")
}
#[allow(clippy::too_many_arguments)]
pub fn call<
M: ::lower_ir_utils::__reexport::cranelift_module::Module,
#(#arg_generics: ::lower_ir_utils::JitArg,)*
>(
bcx: &mut ::lower_ir_utils::__reexport::cranelift_frontend::FunctionBuilder<'_>,
module: &mut M,
id: ::lower_ir_utils::__reexport::cranelift_module::FuncId,
#(#arg_idents: #arg_generics,)*
) -> #call_ret_ty {
use ::lower_ir_utils::__reexport::cranelift_codegen::ir::InstBuilder as _;
let ptr_ty = module.target_config().pointer_type();
let local = module.declare_func_in_func(id, bcx.func);
let mut args_buf: ::lower_ir_utils::__reexport::smallvec::SmallVec<
[::lower_ir_utils::__reexport::cranelift_codegen::ir::Value; 8]
> = ::lower_ir_utils::__reexport::smallvec::SmallVec::new();
#(#arg_lowers)*
let __inst = bcx.ins().call(local, &args_buf);
#call_ret_expr
}
}
};
TokenStream::from(expanded)
}