use proc_macro::TokenStream as TokenStream1;
use proc_macro2::{Span, TokenStream};
use proc_macro_error2::{abort, abort_call_site, proc_macro_error};
use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::*;
use template_quote::{quote, ToTokens};
/// Derives a per-invocation 64-bit fingerprint used to name the generated container
/// function and asm tag.
///
/// Two things are fed into the hasher, in this order:
/// 1. the textual form of `input` with every whitespace run collapsed to one space;
/// 2. the `Debug` rendering of the span of each top-level token of `predicate`,
///    joined with commas. This makes textually identical invocations at different
///    source locations hash differently.
fn unique_hash(input: &TokenStream, predicate: &Expr) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let rendered = input.to_string();
    let mut canonical = String::with_capacity(rendered.len());
    for (idx, word) in rendered.split_whitespace().enumerate() {
        if idx > 0 {
            canonical.push(' ');
        }
        canonical.push_str(word);
    }

    let mut span_fingerprint = String::new();
    for (idx, tree) in predicate.to_token_stream().into_iter().enumerate() {
        if idx > 0 {
            span_fingerprint.push(',');
        }
        span_fingerprint.push_str(&format!("{:?}", tree.span()));
    }

    let mut state = DefaultHasher::new();
    canonical.hash(&mut state);
    span_fingerprint.hash(&mut state);
    state.finish()
}
/// One macro target, classified by `CodegenInput::parse`.
enum Target<'a> {
    /// Any target expression that is not a closure literal (the error messages call
    /// these "target functions"); it is cast to `usize` as-is.
    Function(&'a Expr),
    /// A closure literal. It is re-emitted and bound to a `fn`-pointer-typed local
    /// named `coerce_ident`.
    Closure {
        /// Name of the local that holds the closure coerced to a `fn` pointer.
        coerce_ident: Ident,
        /// Number of closure parameters (`inputs.len()`).
        arity: usize,
        /// The closure's parameter patterns, including any type annotations.
        params: &'a Punctuated<Pat, Token![,]>,
        /// The closure body expression.
        body: &'a Expr,
    },
}
/// Everything the expansion needs, gathered once per macro invocation by
/// `CodegenInput::parse`.
struct CodegenInput<'a> {
    // Data derived from the macro arguments.
    /// Tokens of the runtime crate path (first macro argument).
    krate: TokenStream,
    /// Tokens of the predicate expression (second macro argument).
    pred_tokens: TokenStream,
    /// Source text of the predicate, forwarded to `__macro_internal`.
    pred_str: LitStr,
    /// Source text of each target expression, forwarded to `__macro_internal`.
    target_str_lits: Vec<LitStr>,
    /// The targets, classified, in argument order.
    prepared: Vec<Target<'a>>,

    // Names linking the generated container function to the runtime call.
    /// Identifier of the generated container function.
    container_ident: Ident,
    /// Same name as `container_ident`, as a string passed to `__macro_internal`.
    container_name: String,
    /// Template string for the `asm!` anchors (an asm comment holding the hash and `{0}`).
    asm_tag: LitStr,

    // Build environment observed while the macro runs.
    /// `$CARGO`, or `"cargo"` when unset.
    cargo_path: String,
    /// `$RUSTUP`, or `"rustup"` when unset.
    rustup_path: String,
    /// `$CARGO_MANIFEST_DIR`.
    manifest_dir: String,
    /// Value following `--crate-name` on the compiler command line, or `"unknown"`.
    crate_name: String,
    /// Whether `--test` appears on the compiler command line.
    is_test: bool,
}
impl<'a> CodegenInput<'a> {
fn parse(crate_expr: &Expr, predicate_expr: &'a Expr, targets: &[&'a Expr]) -> Self {
let krate: TokenStream = quote! { #crate_expr };
let hash_input: TokenStream = {
let pred_ts: TokenStream = quote! { #predicate_expr };
let targets_ts: Vec<TokenStream> = targets.iter().map(|t| quote! { #t }).collect();
quote! { #pred_ts #(#targets_ts)* }
};
let r = unique_hash(&hash_input, predicate_expr);
let container_name = format!("ir_assert_container_{}", r);
let container_ident = Ident::new(&container_name, Span::call_site());
let pred_str = LitStr::new(
&predicate_expr.to_token_stream().to_string(),
Span::call_site(),
);
let target_str_lits: Vec<LitStr> = targets
.iter()
.map(|t| LitStr::new("e!(#t).to_string(), Span::call_site()))
.collect();
let pred_tokens = quote! { #predicate_expr };
let cargo_path = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_string());
let rustup_path = std::env::var("RUSTUP").unwrap_or_else(|_| "rustup".to_string());
let manifest_dir =
std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
let args: Vec<String> = std::env::args().collect();
let is_test = args.iter().any(|a| a == "--test");
let crate_name = args
.iter()
.position(|a| a == "--crate-name")
.and_then(|i| args.get(i + 1))
.cloned()
.unwrap_or_else(|| "unknown".to_string());
let asm_tag = LitStr::new(&format!("/* ir_assert {} {{0}} */", r), Span::call_site());
let prepared: Vec<Target<'a>> = targets
.iter()
.enumerate()
.map(|(i, target)| {
if let Expr::Closure(closure) = target {
Target::Closure {
coerce_ident: Ident::new(
&format!("__ir_assert_fn_{}", i),
Span::call_site(),
),
arity: closure.inputs.len(),
params: &closure.inputs,
body: &closure.body,
}
} else {
Target::Function(target)
}
})
.collect();
Self {
krate,
container_ident,
container_name,
pred_tokens,
pred_str,
target_str_lits,
cargo_path,
rustup_path,
manifest_dir,
crate_name,
is_test,
asm_tag,
prepared,
}
}
fn target_stmts(&self) -> Vec<TokenStream> {
let asm_tag = &self.asm_tag;
self.prepared
.iter()
.map(|t| match t {
Target::Closure {
coerce_ident,
params,
body,
..
} => {
let container_arg_tys: Vec<TokenStream> = params
.iter()
.map(|p| {
if matches!(p, Pat::Type(_)) {
quote! { _ }
} else {
quote! { usize }
}
})
.collect();
let container_params: Vec<TokenStream> = params
.iter()
.map(|p| {
if matches!(p, Pat::Type(_)) {
quote! { #p }
} else {
quote! { #p: usize }
}
})
.collect();
quote! {
let #coerce_ident: fn(#(#container_arg_tys),*) -> _ = |#(#container_params),*| #body;
#[cfg(target_arch = "wasm32")]
unsafe {
core::arch::asm!(#asm_tag, in(local) #coerce_ident as usize,
options(nostack, preserves_flags, readonly));
}
#[cfg(not(target_arch = "wasm32"))]
unsafe {
core::arch::asm!(#asm_tag, in(reg) #coerce_ident as usize,
options(nostack, preserves_flags, readonly));
}
}
}
Target::Function(expr) => quote! {
#[cfg(target_arch = "wasm32")]
unsafe {
core::arch::asm!(#asm_tag, in(local) #expr as usize,
options(nostack, preserves_flags, readonly));
}
#[cfg(not(target_arch = "wasm32"))]
unsafe {
core::arch::asm!(#asm_tag, in(reg) #expr as usize,
options(nostack, preserves_flags, readonly));
}
},
})
.collect()
}
fn container_fn(&self) -> TokenStream {
let target_stmts = self.target_stmts();
let container_ident = &self.container_ident;
quote! {
#[no_mangle]
#[inline(never)]
#[allow(unused, dead_code)]
fn #container_ident() {
#(#target_stmts)*
}
}
}
fn macro_internal_call(&self) -> TokenStream {
let Self {
krate,
container_name,
pred_tokens,
pred_str,
target_str_lits,
cargo_path,
rustup_path,
manifest_dir,
crate_name,
is_test,
..
} = self;
quote! {
#krate::__macro_internal(
#cargo_path,
#rustup_path,
#manifest_dir,
#crate_name,
#is_test,
#container_name,
&{ use #krate::predicate::*; #pred_tokens },
#pred_str,
&[#(#target_str_lits),*],
);
}
}
fn return_expr(&self) -> Option<TokenStream> {
if self.prepared.len() != 1 {
return None;
}
match &self.prepared[0] {
Target::Closure {
arity,
params,
body,
..
} => {
let arg_tys: Vec<TokenStream> = (0..*arity).map(|_| quote! { _ }).collect();
Some(quote! {
let __ir_assert_ret: fn(#(#arg_tys),*) -> _ = |#params| #body;
__ir_assert_ret
})
}
Target::Function(expr) => Some(quote! { #expr }),
}
}
}
fn codegen(input: TokenStream, debug_only: bool) -> TokenStream {
let parsed: Punctuated<Expr, Token![,]> = match Punctuated::parse_terminated.parse2(input.clone()) {
Ok(p) => p,
Err(e) => abort!(e.span(), "ir-assert: parse error: {}", e),
};
let mut iter = parsed.iter();
let crate_expr = iter
.next()
.unwrap_or_else(|| abort_call_site!("ir-assert: expected crate path"));
let predicate_expr = iter
.next()
.unwrap_or_else(|| abort_call_site!("ir-assert: expected predicate expression"));
let targets: Vec<&Expr> = iter.collect();
if targets.is_empty() {
abort_call_site!("ir-assert: expected at least one target function/closure after the predicate");
}
if debug_only && targets.len() > 1 && !debug_assertions_active() && std::env::var("IR_ASSERT_IR_GEN").is_err() {
abort!(
quote! { #(#targets)* },
"debug_assert_ir! does not support multiple targets when debug_assertions is disabled"
);
}
let cg = CodegenInput::parse(crate_expr, predicate_expr, &targets);
let container_fn = cg.container_fn();
let call = cg.macro_internal_call();
let return_tokens = cg.return_expr().unwrap_or_default();
quote! {
{
#container_fn
#(if debug_only) {
#[cfg(debug_assertions)]
{ #call }
}
#(else) {
#call
}
#return_tokens
}
}
}
/// Best-effort guess whether `cfg(debug_assertions)` is enabled for the crate currently
/// being compiled.
///
/// This reads the command line of the running process via [`std::env::args`] (the
/// compiler, when invoked as a proc macro) and delegates to
/// [`debug_assertions_from_args`].
fn debug_assertions_active() -> bool {
    let args: Vec<String> = std::env::args().collect();
    debug_assertions_from_args(&args)
}

/// Decides from a rustc-style argument list whether debug assertions are enabled.
///
/// The following forms are recognised:
/// * `-C opt`, `-Copt`, `--codegen opt` and `--codegen=opt`;
/// * option names written with `-` or `_`;
/// * `-O`, which counts as a non-zero opt-level.
///
/// An explicit `debug-assertions` setting wins. A bare `-C debug-assertions` (no value)
/// means "on", as do the values `y`, `yes`, `on`, `true` and `1`; any other value means
/// "off". Without an explicit setting, debug assertions are on exactly when the
/// opt-level is absent or `0`. Later occurrences override earlier ones.
fn debug_assertions_from_args(args: &[String]) -> bool {
    let mut explicit: Option<bool> = None;
    let mut unoptimized = true;
    let mut iter = args.iter().map(String::as_str);
    while let Some(arg) = iter.next() {
        let option = match arg {
            "-O" => {
                unoptimized = false;
                continue;
            }
            // Separate form: the option is the next argument.
            "-C" | "--codegen" => match iter.next() {
                Some(next) => next,
                None => break,
            },
            // Joined form: `-Copt` / `--codegen=opt`.
            _ => match arg
                .strip_prefix("-C")
                .or_else(|| arg.strip_prefix("--codegen="))
            {
                Some(joined) => joined,
                None => continue,
            },
        };
        let (key, value) = match option.split_once('=') {
            Some((k, v)) => (k, Some(v)),
            None => (option, None),
        };
        match key.replace('_', "-").as_str() {
            "debug-assertions" => {
                // Cargo passes `on`/`off`; rustc also accepts y/yes/true and a bare flag.
                explicit = Some(matches!(
                    value,
                    None | Some("y" | "yes" | "on" | "true" | "1")
                ));
            }
            "opt-level" => unoptimized = value == Some("0"),
            _ => {}
        }
    }
    explicit.unwrap_or(unoptimized)
}
#[proc_macro_error]
#[proc_macro]
pub fn __assert_ir_impl(input: TokenStream1) -> TokenStream1 {
codegen(input.into(), false).into()
}
#[proc_macro_error]
#[proc_macro]
pub fn __debug_assert_ir_impl(input: TokenStream1) -> TokenStream1 {
codegen(input.into(), true).into()
}