Skip to main content

lower_ir_utils_macros/
lib.rs

//! Procedural macros for [lower-ir-utils](https://docs.rs/lower-ir-utils).
//!
//! Prefer depending on **`lower-ir-utils`** for the public API (`#[jit_export]` is
//! re-exported there). Keep this crate's version in sync with your `lower-ir-utils` dependency.
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{format_ident, quote};
9use syn::{parse_macro_input, FnArg, ItemFn, PatType, ReturnType, Type};
10
11/// Annotate a Rust function so it can be called from JIT-compiled Cranelift IR.
12///
13/// Generates a sibling module `<fn_name>_jit` exposing helpers that hide the
14/// boilerplate of (1) registering the symbol with `JITBuilder`, (2) building
15/// the Cranelift `Signature`, (3) declaring the import in the module, (4)
16/// declaring a local `FuncRef`, and (5) emitting the call.
17///
18/// If the function has no explicit ABI, `extern "C"` is added automatically.
19///
20/// # Generated API
21///
22/// For `fn foo(p1: T1, p2: T2) -> R` the macro emits, alongside the function:
23///
24/// ```ignore
25/// pub mod foo_jit {
26///     pub const NAME: &'static str;
27///     pub fn symbol_addr() -> *const u8;
28///     pub fn register(jb: &mut JITBuilder);
29///     pub fn signature<M: Module>(module: &M) -> Signature;
30///     pub fn declare<M: Module>(module: &mut M) -> FuncId;
31///     pub fn call<M, A1, A2>(
32///         bcx: &mut FunctionBuilder,
33///         module: &mut M,
34///         id: FuncId,
35///         p1: A1, p2: A2,
36///     ) -> Value;          // or Inst when R is unit
37/// }
38/// ```
39///
40/// Each `A_i: JitArg`, so users can pass either an already-lowered IR `Value`
41/// or a Rust constant (`&'static str`, `i64`, `*const T`, ...).
42///
43/// # Panics
44///
45/// The generated `declare` helper unwraps `declare_function` with `expect`. It will
46/// panic if the symbol is already declared under the same name or if the module rejects
47/// the declaration for another reason (use the module API directly if you need non-panicking error handling).
48///
49/// # Return value of `call`
50///
51/// When the annotated function returns a Rust value, `call` returns the callee's first
52/// SSA result (`cranelift_codegen::ir::Value`, via `inst_results`).
53/// When the return type is unit (no return / `-> ()`), `call` returns the
54/// `cranelift_codegen::ir::Inst` from the emitted `call` instead; you can discard it for
55/// side-effect-only calls or keep it if you need the instruction handle.
56#[proc_macro_attribute]
57pub fn jit_export(_attr: TokenStream, item: TokenStream) -> TokenStream {
58    let mut input = parse_macro_input!(item as ItemFn);
59
60    // Auto-inject `extern "C"` if no ABI was specified.
61    if input.sig.abi.is_none() {
62        input.sig.abi = Some(syn::parse_quote!(extern "C"));
63    }
64
65    // Allow idiomatic Rust types like `&str` in the signature without nagging
66    // the user about `improper_ctypes_definitions`. This is fine on platforms
67    // where the fat-pointer ABI matches separate (ptr, len) args (e.g. SystemV
68    // x86_64); users targeting platforms that disagree should use flat params.
69    input.attrs.push(syn::parse_quote!(
70        #[allow(improper_ctypes_definitions)]
71    ));
72
73    let fn_name = input.sig.ident.clone();
74    let fn_name_str = fn_name.to_string();
75    let helper_mod = format_ident!("{}_jit", fn_name);
76
77    // Collect param types (skip `self` — `extern "C"` fns don't have it but be defensive).
78    let param_types: Vec<&Type> = input
79        .sig
80        .inputs
81        .iter()
82        .filter_map(|arg| match arg {
83            FnArg::Typed(PatType { ty, .. }) => Some(ty.as_ref()),
84            FnArg::Receiver(_) => None,
85        })
86        .collect();
87
88    let return_type: Option<&Type> = match &input.sig.output {
89        ReturnType::Default => None,
90        ReturnType::Type(_, ty) => match ty.as_ref() {
91            // `-> ()` is the same as no return for our purposes.
92            Type::Tuple(t) if t.elems.is_empty() => None,
93            other => Some(other),
94        },
95    };
96
97    let arg_idents: Vec<_> = (0..param_types.len())
98        .map(|i| format_ident!("p{}", i))
99        .collect();
100    let arg_generics: Vec<_> = (0..param_types.len())
101        .map(|i| format_ident!("A{}", i))
102        .collect();
103
104    let sig_param_pushes: Vec<TokenStream2> = param_types
105        .iter()
106        .map(|ty| {
107            quote! {
108                <#ty as ::lower_ir_utils::JitParam>::push_params(&mut sig.params, ptr_ty);
109            }
110        })
111        .collect();
112
113    let sig_return_pushes = match return_type {
114        Some(rt) => quote! {
115            <#rt as ::lower_ir_utils::JitParam>::push_params(&mut sig.returns, ptr_ty);
116        },
117        None => quote! {},
118    };
119
120    let arg_lowers: Vec<TokenStream2> = arg_idents
121        .iter()
122        .map(|id| {
123            quote! {
124                <_ as ::lower_ir_utils::JitArg>::lower(#id, bcx, ptr_ty, &mut args_buf);
125            }
126        })
127        .collect();
128
129    let (call_ret_ty, call_ret_expr) = if return_type.is_some() {
130        (
131            quote! { ::lower_ir_utils::__reexport::cranelift_codegen::ir::Value },
132            quote! { bcx.inst_results(__inst)[0] },
133        )
134    } else {
135        (
136            quote! { ::lower_ir_utils::__reexport::cranelift_codegen::ir::Inst },
137            quote! { __inst },
138        )
139    };
140
141    let expanded = quote! {
142        #input
143
144        #[allow(non_snake_case, non_camel_case_types, dead_code)]
145        pub mod #helper_mod {
146            use super::*;
147
148            pub const NAME: &'static str = #fn_name_str;
149
150            pub fn symbol_addr() -> *const u8 {
151                super::#fn_name as *const u8
152            }
153
154            pub fn register(jb: &mut ::lower_ir_utils::__reexport::cranelift_jit::JITBuilder) {
155                jb.symbol(NAME, symbol_addr());
156            }
157
158            pub fn signature<M: ::lower_ir_utils::__reexport::cranelift_module::Module>(
159                module: &M,
160            ) -> ::lower_ir_utils::__reexport::cranelift_codegen::ir::Signature {
161                let mut sig = module.make_signature();
162                let ptr_ty = module.target_config().pointer_type();
163                #(#sig_param_pushes)*
164                #sig_return_pushes
165                sig
166            }
167
168            pub fn declare<M: ::lower_ir_utils::__reexport::cranelift_module::Module>(
169                module: &mut M,
170            ) -> ::lower_ir_utils::__reexport::cranelift_module::FuncId {
171                let sig = signature(module);
172                module
173                    .declare_function(NAME, ::lower_ir_utils::__reexport::cranelift_module::Linkage::Import, &sig)
174                    .expect("declare_function failed")
175            }
176
177            #[allow(clippy::too_many_arguments)]
178            pub fn call<
179                M: ::lower_ir_utils::__reexport::cranelift_module::Module,
180                #(#arg_generics: ::lower_ir_utils::JitArg,)*
181            >(
182                bcx: &mut ::lower_ir_utils::__reexport::cranelift_frontend::FunctionBuilder<'_>,
183                module: &mut M,
184                id: ::lower_ir_utils::__reexport::cranelift_module::FuncId,
185                #(#arg_idents: #arg_generics,)*
186            ) -> #call_ret_ty {
187                use ::lower_ir_utils::__reexport::cranelift_codegen::ir::InstBuilder as _;
188                let ptr_ty = module.target_config().pointer_type();
189                let local = module.declare_func_in_func(id, bcx.func);
190                let mut args_buf: ::lower_ir_utils::__reexport::smallvec::SmallVec<
191                    [::lower_ir_utils::__reexport::cranelift_codegen::ir::Value; 8]
192                > = ::lower_ir_utils::__reexport::smallvec::SmallVec::new();
193                #(#arg_lowers)*
194                let __inst = bcx.ins().call(local, &args_buf);
195                #call_ret_expr
196            }
197        }
198    };
199
200    TokenStream::from(expanded)
201}