Skip to main content

ir_assert_macro/
lib.rs

1use proc_macro::TokenStream as TokenStream1;
2use proc_macro2::{Span, TokenStream};
3use proc_macro_error2::{abort, abort_call_site, proc_macro_error};
4use syn::parse::Parser;
5use syn::punctuated::Punctuated;
6use syn::*;
7use template_quote::{quote, ToTokens};
8
9/// Generate a unique hash from the macro input tokens and the span of the predicate.
10///
11/// We use the predicate's token spans (source-file locations) rather than a global counter
12/// so the hash is identical in debug and release compilations of the same file.
13/// A counter would desync when `#[cfg]` attributes skip some macro call sites in release
14/// builds (e.g. a test function gated on `cfg(debug_assertions)`).
15fn unique_hash(input: &TokenStream, predicate: &Expr) -> u64 {
16    use std::collections::hash_map::DefaultHasher;
17    use std::hash::{Hash, Hasher};
18    let mut hasher = DefaultHasher::new();
19    let s = input.to_string();
20    let normalized: String = s.split_whitespace().collect::<Vec<_>>().join(" ");
21    normalized.hash(&mut hasher);
22    // The predicate tokens carry real source spans even inside macro expansions.
23    let span_dbg: String = predicate
24        .to_token_stream()
25        .into_iter()
26        .map(|t| format!("{:?}", t.span()))
27        .collect::<Vec<_>>()
28        .join(",");
29    span_dbg.hash(&mut hasher);
30    hasher.finish()
31}
32
33/// Classify a target expression for codegen.
34enum Target<'a> {
35    Closure {
36        coerce_ident: Ident,
37        arity: usize,
38        params: &'a Punctuated<Pat, Token![,]>,
39        body: &'a Expr,
40    },
41    Function(&'a Expr),
42}
43
44/// Parsed and prepared macro inputs used by both assert_ir! and debug_assert_ir! codegen.
45struct CodegenInput<'a> {
46    krate: TokenStream,
47    container_ident: Ident,
48    container_name: String,
49    pred_tokens: TokenStream,
50    pred_str: LitStr,
51    target_str_lits: Vec<LitStr>,
52    cargo_path: String,
53    rustup_path: String,
54    manifest_dir: String,
55    crate_name: String,
56    is_test: bool,
57    asm_tag: LitStr,
58    prepared: Vec<Target<'a>>,
59}
60
61impl<'a> CodegenInput<'a> {
62    fn parse(crate_expr: &Expr, predicate_expr: &'a Expr, targets: &[&'a Expr]) -> Self {
63        let krate: TokenStream = quote! { #crate_expr };
64
65        let hash_input: TokenStream = {
66            let pred_ts: TokenStream = quote! { #predicate_expr };
67            let targets_ts: Vec<TokenStream> = targets.iter().map(|t| quote! { #t }).collect();
68            quote! { #pred_ts #(#targets_ts)* }
69        };
70        let r = unique_hash(&hash_input, predicate_expr);
71
72        let container_name = format!("ir_assert_container_{}", r);
73        let container_ident = Ident::new(&container_name, Span::call_site());
74
75        let pred_str = LitStr::new(
76            &predicate_expr.to_token_stream().to_string(),
77            Span::call_site(),
78        );
79        let target_str_lits: Vec<LitStr> = targets
80            .iter()
81            .map(|t| LitStr::new(&quote!(#t).to_string(), Span::call_site()))
82            .collect();
83
84        let pred_tokens = quote! { #predicate_expr };
85
86        let cargo_path = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_string());
87        let rustup_path = std::env::var("RUSTUP").unwrap_or_else(|_| "rustup".to_string());
88        let manifest_dir =
89            std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
90
91        let args: Vec<String> = std::env::args().collect();
92        let is_test = args.iter().any(|a| a == "--test");
93        let crate_name = args
94            .iter()
95            .position(|a| a == "--crate-name")
96            .and_then(|i| args.get(i + 1))
97            .cloned()
98            .unwrap_or_else(|| "unknown".to_string());
99
100        let asm_tag = LitStr::new(&format!("/* ir_assert {} {{0}} */", r), Span::call_site());
101
102        let prepared: Vec<Target<'a>> = targets
103            .iter()
104            .enumerate()
105            .map(|(i, target)| {
106                if let Expr::Closure(closure) = target {
107                    Target::Closure {
108                        coerce_ident: Ident::new(
109                            &format!("__ir_assert_fn_{}", i),
110                            Span::call_site(),
111                        ),
112                        arity: closure.inputs.len(),
113                        params: &closure.inputs,
114                        body: &closure.body,
115                    }
116                } else {
117                    Target::Function(target)
118                }
119            })
120            .collect();
121
122        Self {
123            krate,
124            container_ident,
125            container_name,
126            pred_tokens,
127            pred_str,
128            target_str_lits,
129            cargo_path,
130            rustup_path,
131            manifest_dir,
132            crate_name,
133            is_test,
134            asm_tag,
135            prepared,
136        }
137    }
138
139    /// Inline-asm statements that pin each target symbol inside the container function.
140    fn target_stmts(&self) -> Vec<TokenStream> {
141        let asm_tag = &self.asm_tag;
142        self.prepared
143            .iter()
144            .map(|t| match t {
145                Target::Closure {
146                    coerce_ident,
147                    params,
148                    body,
149                    ..
150                } => {
151                    let container_arg_tys: Vec<TokenStream> = params
152                        .iter()
153                        .map(|p| {
154                            if matches!(p, Pat::Type(_)) {
155                                quote! { _ }
156                            } else {
157                                quote! { usize }
158                            }
159                        })
160                        .collect();
161                    let container_params: Vec<TokenStream> = params
162                        .iter()
163                        .map(|p| {
164                            if matches!(p, Pat::Type(_)) {
165                                quote! { #p }
166                            } else {
167                                quote! { #p: usize }
168                            }
169                        })
170                        .collect();
171                    quote! {
172                        let #coerce_ident: fn(#(#container_arg_tys),*) -> _ = |#(#container_params),*| #body;
173
174                        #[cfg(target_arch = "wasm32")]
175                        unsafe {
176                            core::arch::asm!(#asm_tag, in(local) #coerce_ident as usize,
177                                options(nostack, preserves_flags, readonly));
178                        }
179                        #[cfg(not(target_arch = "wasm32"))]
180                        unsafe {
181                            core::arch::asm!(#asm_tag, in(reg) #coerce_ident as usize,
182                                options(nostack, preserves_flags, readonly));
183                        }
184                    }
185                }
186                Target::Function(expr) => quote! {
187                    #[cfg(target_arch = "wasm32")]
188                    unsafe {
189                        core::arch::asm!(#asm_tag, in(local) #expr as usize,
190                            options(nostack, preserves_flags, readonly));
191                    }
192                    #[cfg(not(target_arch = "wasm32"))]
193                    unsafe {
194                        core::arch::asm!(#asm_tag, in(reg) #expr as usize,
195                            options(nostack, preserves_flags, readonly));
196                    }
197                },
198            })
199            .collect()
200    }
201
202    /// The `#[no_mangle]` container function that embeds target symbol references via asm.
203    ///
204    /// Always compiled (including in release) so the IR-generation pass can locate the
205    /// container and discover the referenced target symbols.
206    fn container_fn(&self) -> TokenStream {
207        let target_stmts = self.target_stmts();
208        let container_ident = &self.container_ident;
209        quote! {
210            #[no_mangle]
211            #[inline(never)]
212            #[allow(unused, dead_code)]
213            fn #container_ident() {
214                #(#target_stmts)*
215            }
216        }
217    }
218
219    /// The `__macro_internal(...)` call that drives the actual IR assertion at runtime.
220    fn macro_internal_call(&self) -> TokenStream {
221        let Self {
222            krate,
223            container_name,
224            pred_tokens,
225            pred_str,
226            target_str_lits,
227            cargo_path,
228            rustup_path,
229            manifest_dir,
230            crate_name,
231            is_test,
232            ..
233        } = self;
234        quote! {
235            #krate::__macro_internal(
236                #cargo_path,
237                #rustup_path,
238                #manifest_dir,
239                #crate_name,
240                #is_test,
241                #container_name,
242                &{ use #krate::predicate::*; #pred_tokens },
243                #pred_str,
244                &[#(#target_str_lits),*],
245            );
246        }
247    }
248
249    /// Return expression for the single-target case (closure coercion or fn expr).
250    /// Returns `None` for multi-target invocations (result type is `()`).
251    fn return_expr(&self) -> Option<TokenStream> {
252        if self.prepared.len() != 1 {
253            return None;
254        }
255        match &self.prepared[0] {
256            Target::Closure {
257                arity,
258                params,
259                body,
260                ..
261            } => {
262                let arg_tys: Vec<TokenStream> = (0..*arity).map(|_| quote! { _ }).collect();
263                Some(quote! {
264                    let __ir_assert_ret: fn(#(#arg_tys),*) -> _ = |#params| #body;
265                    __ir_assert_ret
266                })
267            }
268            Target::Function(expr) => Some(quote! { #expr }),
269        }
270    }
271}
272
273/// Shared code-generation entry point for both proc-macros.
274///
275/// `debug_only = false` → assert_ir!: container + assertion always emitted.
276/// `debug_only = true`  → debug_assert_ir!: assertion gated on cfg(debug_assertions);
277///                        multiple targets are a compile error outside debug mode.
278fn codegen(input: TokenStream, debug_only: bool) -> TokenStream {
279    let parsed: Punctuated<Expr, Token![,]> = match Punctuated::parse_terminated.parse2(input.clone()) {
280        Ok(p) => p,
281        Err(e) => abort!(e.span(), "ir-assert: parse error: {}", e),
282    };
283
284    let mut iter = parsed.iter();
285    let crate_expr = iter
286        .next()
287        .unwrap_or_else(|| abort_call_site!("ir-assert: expected crate path"));
288    let predicate_expr = iter
289        .next()
290        .unwrap_or_else(|| abort_call_site!("ir-assert: expected predicate expression"));
291    let targets: Vec<&Expr> = iter.collect();
292
293    if targets.is_empty() {
294        abort_call_site!("ir-assert: expected at least one target function/closure after the predicate");
295    }
296
297    // Abort at proc-macro time for multi-target debug_assert_ir! in non-debug builds.
298    // The IR-generation pass sets IR_ASSERT_IR_GEN to suppress this error.
299    if debug_only && targets.len() > 1 && !debug_assertions_active() && std::env::var("IR_ASSERT_IR_GEN").is_err() {
300        abort!(
301            quote! { #(#targets)* },
302            "debug_assert_ir! does not support multiple targets when debug_assertions is disabled"
303        );
304    }
305
306    let cg = CodegenInput::parse(crate_expr, predicate_expr, &targets);
307    let container_fn = cg.container_fn();
308    let call = cg.macro_internal_call();
309    let return_tokens = cg.return_expr().unwrap_or_default();
310
311    quote! {
312        {
313            #container_fn
314            #(if debug_only) {
315                #[cfg(debug_assertions)]
316                { #call }
317            }
318            #(else) {
319                #call
320            }
321            #return_tokens
322        }
323    }
324}
325
/// Returns true when `debug_assertions` is active for the crate being compiled.
///
/// `CARGO_CFG_DEBUG_ASSERTIONS` is only set for build scripts, not for proc-macros, so this
/// inspects rustc's own command-line arguments instead (the proc-macro runs inside rustc).
fn debug_assertions_active() -> bool {
    let args: Vec<String> = std::env::args().collect();
    debug_assertions_from_args(&args)
}

/// Decide whether `debug_assertions` is enabled, given a rustc argument list.
///
/// Resolution follows rustc:
/// 1. An explicit `-C debug-assertions[=VALUE]` wins. Cargo emits `=on` / `=off`. `y`, `yes`,
///    `on`, `true` and `1` count as enabled, and a bare flag (no value) also means enabled.
///    Any other value counts as disabled.
/// 2. Otherwise debug assertions are on exactly when no optimization is requested, i.e.
///    `opt-level` is unset or `0` and `-O` is absent.
///
/// Codegen options are recognised as `-C OPT`, `-COPT`, `--codegen OPT` and `--codegen=OPT`.
/// When an option is repeated, the last occurrence wins, as in rustc.
fn debug_assertions_from_args(args: &[String]) -> bool {
    let mut explicit: Option<bool> = None;
    let mut optimized = false;

    let mut iter = args.iter().map(String::as_str);
    while let Some(arg) = iter.next() {
        if arg == "-O" {
            // `-O` is shorthand for `-C opt-level=2`.
            optimized = true;
            continue;
        }

        let codegen_opt = match arg {
            "-C" | "--codegen" => iter.next(),
            _ => arg
                .strip_prefix("-C")
                .or_else(|| arg.strip_prefix("--codegen=")),
        };

        if let Some(opt) = codegen_opt {
            let (key, value) = match opt.split_once('=') {
                Some((key, value)) => (key, Some(value)),
                None => (opt, None),
            };
            match key {
                "debug-assertions" => {
                    explicit = Some(matches!(
                        value,
                        None | Some("y" | "yes" | "on" | "true" | "1")
                    ));
                }
                "opt-level" => optimized = matches!(value, Some(level) if level != "0"),
                _ => {}
            }
        }
    }

    explicit.unwrap_or(!optimized)
}
362
363#[proc_macro_error]
364#[proc_macro]
365pub fn __assert_ir_impl(input: TokenStream1) -> TokenStream1 {
366    codegen(input.into(), false).into()
367}
368
369#[proc_macro_error]
370#[proc_macro]
371pub fn __debug_assert_ir_impl(input: TokenStream1) -> TokenStream1 {
372    codegen(input.into(), true).into()
373}