#![deny(
missing_docs,
missing_debug_implementations,
trivial_casts,
trivial_numeric_casts,
unsafe_code,
unstable_features,
unused_import_braces,
unused_qualifications,
rustdoc::broken_intra_doc_links,
rustdoc::private_intra_doc_links,
rustdoc::missing_crate_level_docs,
rustdoc::invalid_codeblock_attributes,
rustdoc::bare_urls
)]
use std::{
borrow::Borrow,
fmt::{Display, Write},
};
use proc_macro::TokenStream as TokenStream1;
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
mod error;
mod format_message;
mod not;
mod utils;
mod variables;
use error::*;
use format_message::*;
use utils::*;
use variables::*;
struct Args {
expr: syn::Expr,
format: TokenStream,
}
impl syn::parse::Parse for Args {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
if input.is_empty() {
let msg = "missing condition to check";
return Err(syn::Error::new(Span::call_site(), msg)); }
let span_source: TokenStream = input.fork().parse().unwrap(); let expr = match input.parse() {
Ok(expr) => expr,
Err(e) => {
let err = if input.is_empty() {
let msg = format!("incomplete expression: {e}");
syn::Error::new_spanned(span_source, msg) } else if let Ok(comma) = input.parse::<syn::Token![,]>() {
let msg = format!("Expression before the comma is incomplete: {e}");
syn::Error::new_spanned(comma, msg) } else {
e
};
return Err(err);
}
};
let format;
if input.is_empty() {
format = TokenStream::new();
} else if let Err(e) = input.parse::<syn::Token![,]>() {
let msg = "condition has to be followed by a comma, if a message is provided";
return Err(syn::Error::new(e.span(), msg)); } else {
format = input.parse()?;
}
Ok(Args { expr, format })
}
}
/// Asserts that a boolean expression evaluates to `true` at runtime.
///
/// The invocation syntax mirrors `std::assert!`: a condition, optionally
/// followed by a comma and a format string with arguments. When the check
/// fails, the generated code panics with a message naming the failed
/// expression, with extra detail for some expression kinds such as
/// comparisons, calls, `&&`, `||`, `!` and `if`.
///
/// Conditions that cannot be boolean (assignments, loops, `let`, `break`,
/// `continue`, `return`, struct literals, `async` blocks) are rejected at
/// compile time.
///
/// Note that the literal `assert!(true)` is special-cased and, depending on the
/// line it appears on, usually panics with a joke message.
#[proc_macro]
pub fn assert(input: TokenStream1) -> TokenStream1 {
    let args = syn::parse_macro_input!(input as Args);
    let expansion = assert_internal(args);
    match expansion {
        Err(error) => error.into(),
        Ok(tokens) => tokens.into(),
    }
}
/// Expands one parsed `assert!` invocation into the code that performs the check.
///
/// The literal conditions `true` and `false` get dedicated expansions. Every
/// other condition is handed to `eval_expr` with a panic message that starts
/// with "assertion `<expr>` failed" and, if a custom message was supplied,
/// continues with `: <message>`.
///
/// # Errors
///
/// Propagates any error from `eval_expr`, e.g. when the condition is not a
/// boolean expression.
fn assert_internal(input: Args) -> Result<TokenStream> {
    let Args { expr, format } = input;
    let expr_str = printable_expr_string(&expr);

    if expr_str == "true" {
        return Ok(assert_true_flavor());
    }

    if expr_str == "false" {
        // Keep the caller's custom message (and evaluate its arguments) instead of
        // silently dropping it, consistent with the general path below and with
        // `std::assert!(false, ...)`.
        let panic_call = if format.is_empty() {
            quote! {
                ::std::panic!("surprisingly, `false` did not evaluate to true")
            }
        } else {
            quote! {
                ::std::panic!(
                    "surprisingly, `false` did not evaluate to true: {}",
                    ::std::format_args!(#format)
                )
            }
        };
        return Ok(panic_call);
    }

    let setup = quote! { struct __OneAssertWrapper<T>(T); };

    let mut format_message = FormatMessage::new();
    format_message.add_text(format!("assertion `{expr_str}` failed"));
    if !format.is_empty() {
        format_message.add_placeholder(": {}", quote! { ::std::format_args!(#format) });
    }

    eval_expr(expr, setup, format_message)
}
/// Expands an assertion condition into the final checking code.
///
/// Outer parentheses and invisible groups are peeled off `e` first. The kind of
/// the remaining expression then decides what happens:
/// - Kinds that cannot be a boolean condition produce a compile-time error:
///   assignments, compound assignments, `async` blocks, `break`, `continue`,
///   `return`, `for`/`while` loops, `let` and struct literals.
/// - `&&`, `||`, `if` and `!` are delegated to `resolve_and`, `resolve_or`,
///   `resolve_if` and `not::eval_not_expr`.
/// - For comparisons, other binary operators, calls with arguments, method
///   calls and non-literal indexing, the operands are registered with
///   `Variables` and the condition is rewritten to use what the registration
///   returns.
/// - Everything else is checked exactly as written.
///
/// `setup` is emitted before the check and `format_message` is the panic
/// message built so far. `variables.resolve_variables` updates both before the
/// final code is generated.
///
/// # Errors
///
/// Returns an error spanned at the offending token or expression for the
/// non-boolean kinds listed above, or whatever error a delegated helper returns.
// `match_same_arms` is allowed because many variants below deliberately share
// the empty `{}` arm ("check as written") and are listed one by one.
#[allow(clippy::match_same_arms)] fn eval_expr(
mut e: syn::Expr,
mut setup: TokenStream,
mut format_message: FormatMessage,
) -> Result<TokenStream> {
// Default condition: the original tokens, including any outer parentheses.
// Arms below may replace it with a rewritten condition.
let mut assert_condition = e.to_token_stream();
let mut variables = Variables::new();
// Peel off `(...)` and invisible groups so the match sees the inner expression.
// This is why the `Paren` and `Group` arms below are unreachable.
while let syn::Expr::Paren(syn::ExprParen { expr: inner, .. })
| syn::Expr::Group(syn::ExprGroup { expr: inner, .. }) = e
{
e = *inner;
}
match e {
syn::Expr::Array(_) => {}
syn::Expr::Assign(syn::ExprAssign { eq_token, .. }) => {
let msg = "Expected a boolean expression, found an assignment. Did you intend to compare with `==`?";
return Error::err_spanned(eq_token, msg); }
syn::Expr::Async(_) => {
let msg = "Expected a boolean expression, found an async block. Did you intend to await a future?";
return Error::err_spanned(e, msg); }
syn::Expr::Await(_) => {}
syn::Expr::Binary(syn::ExprBinary {
left,
op,
right,
attrs,
}) => {
match op {
// Short-circuiting operators get dedicated expansions.
syn::BinOp::And(_) => return Ok(resolve_and(setup, format_message, left, right)),
syn::BinOp::Or(_) => return Ok(resolve_or(setup, format_message, left, right)),
// Comparisons: operands are registered via `add_borrowed_var`.
syn::BinOp::Eq(_)
| syn::BinOp::Lt(_)
| syn::BinOp::Le(_)
| syn::BinOp::Ne(_)
| syn::BinOp::Ge(_)
| syn::BinOp::Gt(_) => {
let lhs = variables.add_borrowed_var(left, "lhs", "left");
let rhs = variables.add_borrowed_var(right, "rhs", "right");
assert_condition = quote! { #(#attrs)* #lhs #op #rhs };
}
// Arithmetic, bitwise and shift operators: operands are registered via `add_moving_var`.
syn::BinOp::Add(_)
| syn::BinOp::Sub(_)
| syn::BinOp::Mul(_)
| syn::BinOp::Div(_)
| syn::BinOp::Rem(_)
| syn::BinOp::BitXor(_)
| syn::BinOp::BitAnd(_)
| syn::BinOp::BitOr(_)
| syn::BinOp::Shl(_)
| syn::BinOp::Shr(_) => {
let lhs = variables.add_moving_var(left, "lhs", "left");
let rhs = variables.add_moving_var(right, "rhs", "right");
assert_condition = quote! { #(#attrs)* #lhs #op #rhs };
}
// Compound assignments are never boolean conditions.
syn::BinOp::AddAssign(_)
| syn::BinOp::SubAssign(_)
| syn::BinOp::MulAssign(_)
| syn::BinOp::DivAssign(_)
| syn::BinOp::RemAssign(_)
| syn::BinOp::BitXorAssign(_)
| syn::BinOp::BitAndAssign(_)
| syn::BinOp::BitOrAssign(_)
| syn::BinOp::ShlAssign(_)
| syn::BinOp::ShrAssign(_) => {
let msg = "Expected a boolean expression, found an assignment";
return Error::err_spanned(op, msg); }
// Any other operator is checked as written.
_ => {}
}
}
syn::Expr::Block(_) => {}
syn::Expr::Break(_) => {
let msg = "Expected a boolean expression, found a break statement";
return Error::err_spanned(e, msg); }
// Calls with at least one argument: every argument is registered, the callee is kept as written.
syn::Expr::Call(syn::ExprCall {
args,
func,
paren_token,
attrs,
}) if !args.is_empty() => {
// Digit count of the largest argument index, used to right-align the "arg N" labels.
let index_len = (args.len() - 1).to_string().len();
// Lazy iterator: `add_moving_var` only runs when `quote!` iterates it inside `surround`.
let out_args = args.iter().enumerate().map(|(i, arg)| {
variables.add_moving_var(
arg,
format_args!("arg{i}"),
format_args!("arg {i:>index_len$}"),
)
});
assert_condition = quote! { #(#attrs)* #func };
paren_token.surround(&mut assert_condition, |out| {
out.extend(quote! { #(#out_args),* })
});
}
// Zero-argument calls are checked as written.
syn::Expr::Call(_) => {}
syn::Expr::Cast(_) => {}
syn::Expr::Closure(_) => {}
syn::Expr::Const(_) => {}
syn::Expr::Continue(_) => {
let msg = "Expected a boolean expression, found a continue statement";
return Error::err_spanned(e, msg); }
syn::Expr::Field(_) => {}
syn::Expr::ForLoop(_) => {
let msg = "Expected a boolean expression, found a for loop";
return Error::err_spanned(e, msg); }
syn::Expr::Group(_) => unreachable!(),
syn::Expr::If(expr_if) => return resolve_if(setup, format_message, expr_if),
syn::Expr::Index(syn::ExprIndex {
index,
expr,
attrs,
bracket_token,
}) => {
// Literal indices are left as written. Any other index expression is
// registered, and the indexing is rewritten to use it.
if !matches!(*index, syn::Expr::Lit(_)) {
let index = variables.add_moving_var(index, "index", "index");
assert_condition = quote! { #(#attrs)* #expr };
bracket_token.surround(&mut assert_condition, |out| index.to_tokens(out));
}
}
syn::Expr::Infer(_) => {}
syn::Expr::Let(_) => {
let msg = "Expected a boolean expression, found a let statement";
return Error::err_spanned(e, msg); }
syn::Expr::Lit(_) => {}
syn::Expr::Loop(_) => {}
syn::Expr::Macro(_) => {}
syn::Expr::Match(_) => {}
// Method calls: the receiver and every argument are registered.
syn::Expr::MethodCall(syn::ExprMethodCall {
receiver,
method,
turbofish,
args,
attrs,
dot_token,
paren_token,
}) => {
let obj = variables.add_moving_var(receiver, "object", "self");
// `saturating_sub` keeps this valid for zero-argument method calls.
let index_len = (args.len().saturating_sub(1)).to_string().len();
let out_args = args.iter().enumerate().map(|(i, arg)| {
variables.add_moving_var(
arg,
format_args!("arg{i}"),
format_args!("arg {i:>index_len$}"),
)
});
assert_condition = quote! { #(#attrs)* #obj #dot_token #method #turbofish };
paren_token.surround(&mut assert_condition, |out| {
out.extend(quote! { #(#out_args),* })
});
}
syn::Expr::Paren(_) => unreachable!(),
syn::Expr::Path(_) => {}
syn::Expr::Range(_) => {}
syn::Expr::Reference(_) => {}
syn::Expr::Repeat(_) => {}
syn::Expr::Return(_) => {
let msg = "Expected a boolean expression, found a return statement";
return Error::err_spanned(e, msg); }
syn::Expr::Struct(_) => {
let msg = "Expected a boolean expression, found a struct literal";
return Error::err_spanned(e, msg);
}
syn::Expr::Try(_) => {}
syn::Expr::Tuple(_) => {}
// Negation is delegated to the `not` module.
syn::Expr::Unary(syn::ExprUnary {
expr,
op: syn::UnOp::Not(not_token),
attrs,
}) => {
return not::eval_not_expr(*expr, setup, format_message, not_token, attrs);
}
syn::Expr::Unary(_) => {}
syn::Expr::Unsafe(_) => {}
syn::Expr::Verbatim(_) => {}
syn::Expr::While(_) => {
let msg = "Expected a boolean expression, found a while loop";
return Error::err_spanned(e, msg);
}
// Any other expression kind is checked as written.
_ => {} }
// Let the registered variables contribute to the setup code and the panic message.
variables.resolve_variables(&mut setup, &mut format_message);
// Final shape: `{ setup; if condition {} else { panic!(message) } }`.
Ok(quote! { #[allow(unreachable_code)] {
#setup
if #assert_condition {
} else {
::std::panic!(#format_message);
}
}})
}
/// Expands `left && right` into two nested checks.
///
/// `right` is only evaluated when `left` is true, as with `&&` itself. Each
/// failing side gets its own panic message, derived from `format_message`.
fn resolve_and(
    setup: TokenStream,
    format_message: FormatMessage,
    left: impl Borrow<syn::Expr>,
    right: impl Borrow<syn::Expr>,
) -> TokenStream {
    // Both messages start from the same base; only the cause differs.
    let mut left_failed = format_message.clone();
    let mut right_failed = format_message;
    left_failed.add_cause("left side of `&&` evaluated to false");
    right_failed
        .add_cause("left side of `&&` evaluated to true, but right side evaluated to false");

    let (lhs, rhs) = (left.borrow(), right.borrow());

    quote! { #[allow(unreachable_code)] {
        #setup
        if #lhs {
            if #rhs {
            } else {
                ::std::panic!(#right_failed);
            }
        } else {
            ::std::panic!(#left_failed);
        }
    }}
}
/// Expands `left || right` into two nested checks.
///
/// `right` is only evaluated when `left` is false, as with `||` itself. The
/// panic fires only when both sides are false.
fn resolve_or(
    setup: TokenStream,
    format_message: FormatMessage,
    left: impl Borrow<syn::Expr>,
    right: impl Borrow<syn::Expr>,
) -> TokenStream {
    let mut both_failed = format_message;
    both_failed.add_cause("both sides of `||` evaluated to false");

    let (first, second) = (left.borrow(), right.borrow());

    quote! { #[allow(unreachable_code)] {
        #setup
        if #first {
        } else {
            if #second {
            } else {
                ::std::panic!(#both_failed);
            }
        }
    }}
}
fn resolve_if(
setup: TokenStream,
format_message: FormatMessage,
expr_if: syn::ExprIf,
) -> Result<TokenStream> {
let syn::ExprIf {
if_token,
cond,
then_branch,
mut else_branch,
..
} = expr_if;
let mut format_cond = format_message.clone();
format_cond.add_cause(format_args!(
"
- if condition `{}` was true
- then-block `{}` evaluated to false",
printable_expr_string(&cond),
printable_expr_string(&then_branch)
));
let mut out = quote! {
if #cond {
if #then_branch { } else { ::std::panic!(#format_cond); }
}
};
let mut cause_message = format!(
"
- if condition `{}` was false",
printable_expr_string(&cond)
);
loop {
let Some((_, else_expr)) = else_branch else {
let msg = "if-expression is missing a final else-block to handle the case where all conditions are false.
If you want a conditional assert, put the assert! inside the if block.";
return Err(Error::new_spanned(if_token, msg));
};
match *else_expr {
syn::Expr::If(nested_if) => {
let syn::ExprIf {
cond,
then_branch,
else_branch: inner_else_branch,
..
} = nested_if;
let mut format_cond = format_message.clone();
format_cond.add_cause(format_args!(
"{}
- else-if condition `{}` was true
- then-block `{}` evaluated to false",
cause_message,
printable_expr_string(&cond),
printable_expr_string(&then_branch)
));
out.extend(quote! {
else if #cond {
if #then_branch { } else { ::std::panic!(#format_cond); }
}
});
cause_message = format!(
"{}
- else-if condition `{}` was false",
cause_message,
printable_expr_string(&cond)
);
else_branch = inner_else_branch;
}
else_block => {
let mut format_else = format_message;
format_else.add_cause(format_args!(
"{}
- else-block `{}` evaluated to false",
cause_message,
printable_expr_string(&else_block)
));
out.extend(quote! {
else {
if #else_block { } else { ::std::panic!(#format_else); }
}
});
break;
}
}
}
Ok(quote! { #[allow(unreachable_code, unused_braces)] {
#setup
#out
}})
}
/// Expansion for the literal condition `assert!(true)`.
///
/// This is deliberately tongue-in-cheek. Depending on the line number of the
/// invocation, it either does nothing (line divisible by 10 but not by 100) or
/// panics with one of several joke messages.
///
/// The statements are wrapped in a block, like every other expansion in this
/// crate, for two reasons:
/// - the generated `line` binding (created with call-site hygiene) cannot leak
///   into, and shadow variables in, the caller's scope;
/// - the expansion is a single `()` expression, so `assert!(true)` is valid
///   anywhere an expression is allowed, e.g. `let _ = assert!(true);` or as a
///   match-arm body.
fn assert_true_flavor() -> TokenStream {
    quote! {{
        let line = ::std::line!();
        if line % 100 == 69 {
            ::std::panic!("You actually used `assert!(true)`? Nice.");
        } else if line % 100 == 0 {
            ::std::panic!("Congratulations! You are the {}th person to use `assert!(true)`! You win a free panic!", line);
        } else if line % 10 == 0 {
        } else {
            const MESSAGES: &[&'static ::std::primitive::str] = &[
                "Ha! Did you think `assert!(true)` would do nothing? Fool!",
                "assertion `true` failed:\n left: tr\n right: ue",
                "assertion `true` failed: `true` did not evaluate to true",
                "assertion `true` failed: `true` did not evaluate to true...? Huh? What? 🤔",
                "Undefined reference to `true`. Did you mean `false`?",
                "assertion `true` failed: `true` did not evaluate to true. What a surprise!",
            ];
            let msg = MESSAGES[line as usize % MESSAGES.len()];
            ::std::panic!("{}", msg);
        }
    }}
}