Skip to main content

ir_assert_macro/
lib.rs

1use proc_macro::TokenStream as TokenStream1;
2use proc_macro2::{Span, TokenStream};
3use std::sync::atomic::{AtomicU64, Ordering};
4use syn::parse::Parser;
5use syn::punctuated::Punctuated;
6use syn::*;
7use template_quote::{quote, ToTokens};
8
9/// Global counter to ensure each macro invocation produces a unique container name,
10/// even when the same predicate+targets pair appears multiple times.
11/// This is deterministic across compilations because macro expansion order is stable.
12static COUNTER: AtomicU64 = AtomicU64::new(0);
13
14/// Generate a unique hash from the macro input tokens and an invocation counter.
15fn unique_hash(input: &TokenStream) -> u64 {
16    use std::collections::hash_map::DefaultHasher;
17    use std::hash::{Hash, Hasher};
18    let mut hasher = DefaultHasher::new();
19    let s = input.to_string();
20    let normalized: String = s.split_whitespace().collect::<Vec<_>>().join(" ");
21    normalized.hash(&mut hasher);
22    COUNTER.fetch_add(1, Ordering::Relaxed).hash(&mut hasher);
23    hasher.finish()
24}
25
26/// Prepare a target expression for codegen. Returns the statements to embed in the container.
27enum Target<'a> {
28    Closure {
29        coerce_ident: Ident,
30        arity: usize,
31        params: &'a Punctuated<Pat, Token![,]>,
32        body: &'a Expr,
33    },
34    Function(&'a Expr),
35}
36
37fn inner(input: TokenStream) -> TokenStream {
38    let parsed: Punctuated<Expr, Token![,]> = Punctuated::parse_terminated
39        .parse2(input.clone())
40        .unwrap_or_else(|e| panic!("ir-assert: parse error: {}", e));
41
42    let mut iter = parsed.iter();
43    let crate_expr = iter
44        .next()
45        .unwrap_or_else(|| panic!("ir-assert: expected crate path"));
46    let predicate_expr = iter
47        .next()
48        .unwrap_or_else(|| panic!("ir-assert: expected predicate expression"));
49    let targets: Vec<&Expr> = iter.collect();
50
51    if targets.is_empty() {
52        panic!("ir-assert: expected at least one target function/closure after the predicate");
53    }
54
55    let krate: TokenStream = quote! { #crate_expr };
56
57    // Hash predicate + targets for deterministic naming
58    let hash_input: TokenStream = {
59        let pred_ts: TokenStream = quote! { #predicate_expr };
60        let targets_ts: Vec<TokenStream> = targets.iter().map(|t| quote! { #t }).collect();
61        quote! { #pred_ts #(#targets_ts)* }
62    };
63    let r = unique_hash(&hash_input);
64
65    let container_name = format!("ir_assert_container_{}", r);
66    let container_ident = Ident::new(&container_name, Span::call_site());
67
68    // Stringify the predicate and targets for error messages (before transformation)
69    let pred_str = LitStr::new(
70        &predicate_expr.to_token_stream().to_string(),
71        Span::call_site(),
72    );
73    let target_str_lits: Vec<LitStr> = targets
74        .iter()
75        .map(|t| LitStr::new(&quote!(#t).to_string(), Span::call_site()))
76        .collect();
77
78    let pred_tokens = quote! { #predicate_expr };
79
80    let cargo_path = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_string());
81    let rustup_path = std::env::var("RUSTUP").unwrap_or_else(|_| "rustup".to_string());
82    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
83
84    let args: Vec<String> = std::env::args().collect();
85    let is_test = args.iter().any(|a| a == "--test");
86    let crate_name = args
87        .iter()
88        .position(|a| a == "--crate-name")
89        .and_then(|i| args.get(i + 1))
90        .cloned()
91        .unwrap_or_else(|| "unknown".to_string());
92
93    let asm_tag = LitStr::new(&format!("/* ir_assert {} {{0}} */", r), Span::call_site());
94
95    // Prepare targets as an enum for codegen
96    let prepared: Vec<Target> = targets
97        .iter()
98        .enumerate()
99        .map(|(i, target)| {
100            if let Expr::Closure(closure) = target {
101                Target::Closure {
102                    coerce_ident: Ident::new(&format!("__ir_assert_fn_{}", i), Span::call_site()),
103                    arity: closure.inputs.len(),
104                    params: &closure.inputs,
105                    body: &closure.body,
106                }
107            } else {
108                Target::Function(target)
109            }
110        })
111        .collect();
112
113    let target_stmts: Vec<TokenStream> = prepared
114        .iter()
115        .map(|t| match t {
116            Target::Closure {
117                coerce_ident,
118                arity: _,
119                params,
120                body,
121            } => {
122                // For the container function, un-annotated closure params need a concrete
123                // type so the container compiles in isolation. We use `usize` as the
124                // fallback; annotated params keep their original type via `_` inference.
125                let container_arg_tys: Vec<TokenStream> = params.iter().map(|p| {
126                    if matches!(p, Pat::Type(_)) {
127                        quote! { _ }
128                    } else {
129                        quote! { usize }
130                    }
131                }).collect();
132                let container_params: Vec<TokenStream> = params.iter().map(|p| {
133                    if matches!(p, Pat::Type(_)) {
134                        quote! { #p }
135                    } else {
136                        quote! { #p: usize }
137                    }
138                }).collect();
139                quote! {
140                    let #coerce_ident: fn(#(#container_arg_tys),*) -> _ = |#(#container_params),*| #body;
141
142                    #[cfg(target_arch = "wasm32")]
143                    unsafe {
144                        core::arch::asm!(#asm_tag, in(local) #coerce_ident as usize,
145                            options(nostack, preserves_flags, readonly));
146                    }
147                    #[cfg(not(target_arch = "wasm32"))]
148                    unsafe {
149                        core::arch::asm!(#asm_tag, in(reg) #coerce_ident as usize,
150                            options(nostack, preserves_flags, readonly));
151                    }
152                }
153            }
154            Target::Function(expr) => quote! {
155                #[cfg(target_arch = "wasm32")]
156                unsafe {
157                    core::arch::asm!(#asm_tag, in(local) #expr as usize,
158                        options(nostack, preserves_flags, readonly));
159                }
160                #[cfg(not(target_arch = "wasm32"))]
161                unsafe {
162                    core::arch::asm!(#asm_tag, in(reg) #expr as usize,
163                        options(nostack, preserves_flags, readonly));
164                }
165            },
166        })
167        .collect();
168
169    // When there's exactly one target, return it from the block
170    let return_expr: Option<TokenStream> = if prepared.len() == 1 {
171        match &prepared[0] {
172            Target::Closure {
173                arity,
174                params,
175                body,
176                ..
177            } => {
178                let arg_tys: Vec<TokenStream> = (0..*arity).map(|_| quote! { _ }).collect();
179                Some(quote! {
180                    let __ir_assert_ret: fn(#(#arg_tys),*) -> _ = |#params| #body;
181                    __ir_assert_ret
182                })
183            }
184            Target::Function(expr) => Some(quote! { #expr }),
185        }
186    } else {
187        None
188    };
189
190    let return_tokens = return_expr
191        .map(|expr| quote! { #expr })
192        .unwrap_or_else(|| quote! {});
193
194    quote! {
195        {
196            #[no_mangle]
197            #[inline(never)]
198            #[allow(unused, dead_code)]
199            fn #container_ident() {
200                #(#target_stmts)*
201            }
202
203            #krate::__macro_internal(
204                #cargo_path,
205                #rustup_path,
206                #manifest_dir,
207                #crate_name,
208                #is_test,
209                #container_name,
210                &{ use #krate::predicate::*; #pred_tokens },
211                #pred_str,
212                &[#(#target_str_lits),*],
213            );
214
215            #return_tokens
216        }
217    }
218}
219
220#[proc_macro]
221pub fn __assert_ir_impl(input: TokenStream1) -> TokenStream1 {
222    inner(input.into()).into()
223}