#![doc = include_str!("../README.md")]
use proc_macro::TokenStream;
use syn::{
parse_macro_input, AttributeArgs, GenericArgument, ItemFn, Lit, Meta, NestedMeta,
PathArguments, PathSegment, ReturnType, Type,
};
/// Attribute that pairs a `bool`/tuple/`Result`-returning function with an `assert_<name>!` macro.
///
/// The annotated function is re-emitted unchanged, preceded by a `macro_rules!` definition that
/// calls it and asserts on the returned value:
/// - `assert!(result ..)` for `bool` results;
/// - `assert_eq!(result.0, result.1 ..)` for tuple results;
/// - for `Result` return types, the assertion runs only inside `if let Ok(result) = ..`.
///
/// The macro evaluates to the function's result. For `async` functions the call is awaited
/// inside an `async` block, so the macro yields a future.
///
/// The macro has two arms. The first takes exactly the function's arguments (with an optional
/// trailing comma) and uses the default message from `message(..)`, if one was given. The
/// second accepts extra tokens after the arguments and forwards them as the assertion message.
///
/// Attribute arguments are interpreted by `get_macro_export` (`export`) and `get_message`
/// (`message("fmt", a, ..)`).
///
/// # Panics
/// Panics when the return type is unsupported, when message args are combined with a boolean
/// return type, or if the generated source fails to tokenize.
#[proc_macro_attribute]
pub fn assert_fn(args: TokenStream, item: TokenStream) -> TokenStream {
    // Stringify before parsing so the function is re-emitted exactly as written.
    let original_fn = item.to_string();
    let item = parse_macro_input!(item as ItemFn);
    let args = parse_macro_input!(args as AttributeArgs);

    let fn_name = item.sig.ident.to_string();
    // Classifying the return type panics on unsupported shapes; everything below depends on it.
    let return_type = get_return_type(&item);
    let macro_export = get_macro_export(&args);
    let assert_message = get_message(&args);
    let (params, values) = get_values_and_params(&item);
    let (async_block, dot_await) = get_async(&item);
    let tuple_destructure = get_tuple_destructure(&assert_message, &return_type);
    let (if_result_open, if_result_close) = get_result_block(&return_type);
    let assert_call = get_assert_call(&return_type);
    let message = assert_message.map_or_else(String::new, |msg| msg.message);

    // The first arm appends its own optional `$(,)?`, and the call site must not end in a
    // comma-only separator, so both lists lose their trailing `,` there.
    let params_trimmed = params.trim_end_matches(',');
    let values = values.trim_end_matches(',');

    let generated = format!(
        r#"
{macro_export}
macro_rules! assert_{fn_name} {{
({params_trimmed}$(,)?) => {{ {async_block} {{
let result = {fn_name}({values}){dot_await};
{if_result_open}
{tuple_destructure}
{assert_call}{message});
{if_result_close}
result
}}}};
({params}$($arg:tt)+) => {{ {async_block} {{
let result = {fn_name}({values}){dot_await};
{if_result_open}
{assert_call}, $($arg)*);
{if_result_close}
result
}}}};
}}
{original_fn}
"#,
        macro_export = macro_export,
        fn_name = fn_name,
        params = params,
        params_trimmed = params_trimmed,
        values = values,
        async_block = async_block,
        dot_await = dot_await,
        tuple_destructure = tuple_destructure,
        if_result_open = if_result_open,
        if_result_close = if_result_close,
        assert_call = assert_call,
        message = message,
        original_fn = original_fn,
    );
    generated.parse().expect("Generated invalid tokens")
}
/// Shape of the value returned by a function annotated with `#[assert_fn]`.
///
/// The shape decides how the generated `assert_*!` macro checks that value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AssertReturnType {
    /// `bool`: checked with `assert!(result ..)`.
    Bool,
    /// Tuple with the given number of elements: `.0` and `.1` are compared with `assert_eq!`.
    Tuple(u8),
    /// `Result<bool, _>`: only an `Ok` value is asserted on.
    ResultBool,
    /// `Result<(..), _>` whose `Ok` tuple has the given number of elements.
    ResultTuple(u8),
}
fn get_return_type(item: &ItemFn) -> AssertReturnType {
let fn_name = item.sig.ident.to_string();
let return_type = match &item.sig.output {
ReturnType::Default => panic!("{} does not return anything", fn_name),
ReturnType::Type(_, return_type) => *return_type.clone(),
};
match return_type {
Type::Path(path) => {
let last_segment = path
.path
.segments
.last()
.expect("{} returned an unexpected return type");
let path_ident = last_segment.ident.to_string();
if path_ident == "bool" {
AssertReturnType::Bool
} else if path_ident == "Result" {
get_return_result_type(&fn_name, last_segment)
} else {
panic!(
"{} must return a bool, tuple or a Result wrapping one of those types",
fn_name
)
}
}
Type::Tuple(tuple) => AssertReturnType::Tuple(tuple.elems.len() as u8),
_ => panic!(
"{} must return a bool, tuple or a Result wrapping one of those types",
fn_name
),
}
}
/// Classifies the `T` of a `Result<T, E>` return type.
///
/// `path_segment` is the final `Result<..>` segment of the return path; only its first generic
/// argument is inspected. `T` must be `bool` or a tuple with at least two elements.
///
/// # Panics
/// Panics with a message naming `fn_name` when the segment has no angle-bracketed type
/// argument, or when `T` has an unsupported shape, including `()` and one-element tuples.
fn get_return_result_type(fn_name: &str, path_segment: &PathSegment) -> AssertReturnType {
    let args = match &path_segment.arguments {
        PathArguments::AngleBracketed(args) => args,
        _ => panic!("{} returned an invalid Result type", fn_name),
    };
    // A missing first argument and a non-type first argument (e.g. a lifetime) are
    // both malformed for our purposes.
    let ok_type = match args.args.first() {
        Some(GenericArgument::Type(ok_type)) => ok_type,
        _ => panic!("{} returned an invalid Result type", fn_name),
    };
    match ok_type {
        Type::Path(ok_path) => {
            let last_segment = ok_path
                .path
                .segments
                .last()
                .unwrap_or_else(|| panic!("{} returned an invalid Result type", fn_name));
            if last_segment.ident == "bool" {
                AssertReturnType::ResultBool
            } else {
                panic!("{} must return a Result of type bool or tuple", fn_name)
            }
        }
        Type::Tuple(tuple) => {
            let arity = tuple.elems.len();
            // `Result<(), E>` / `Result<(T,), E>` have no `.1` to compare against.
            if arity < 2 {
                panic!(
                    "{} must return a Result wrapping a tuple with at least two elements",
                    fn_name
                );
            }
            // `as u8` would truncate past 255 elements, which no realistic tuple reaches.
            AssertReturnType::ResultTuple(arity as u8)
        }
        _ => panic!("{} must return a Result of type bool or tuple", fn_name),
    }
}
/// Builds the `macro_rules!` matcher fragment and the forwarded call arguments.
///
/// One entry is produced per function input. For a two-input function this returns
/// `("$arg_0:expr,$arg_1:expr,", "$arg_0,$arg_1,")`; every entry keeps its trailing comma.
fn get_values_and_params(item: &ItemFn) -> (String, String) {
    let arity = item.sig.inputs.len();
    let matchers: String = (0..arity).map(|i| format!("$arg_{}:expr,", i)).collect();
    let forwarded: String = (0..arity).map(|i| format!("$arg_{},", i)).collect();
    (matchers, forwarded)
}
/// Returns the `(block keyword, call suffix)` pair for the generated macro body.
///
/// For `async` functions this is `("async", ".await")`; otherwise both parts are empty.
fn get_async(item: &ItemFn) -> (String, String) {
    match item.sig.asyncness {
        Some(_) => (String::from("async"), String::from(".await")),
        None => (String::new(), String::new()),
    }
}
/// Builds the `let (..) = &result;` statement that binds tuple elements to the names
/// used as named arguments in the default assertion message.
///
/// Names come from `assert_message.args`, in order. Missing trailing positions are padded
/// with `_`. Returns an empty string when there is no message or it has no args.
///
/// # Panics
/// - when message args are used with a `bool`/`Result<bool, _>` return type;
/// - when there are more message args than tuple elements.
fn get_tuple_destructure(
    assert_message: &Option<AssertMessage>,
    return_type: &AssertReturnType,
) -> String {
    let args = match assert_message {
        Some(msg) if !msg.args.is_empty() => &msg.args,
        _ => return String::new(),
    };
    let tuple_size = match return_type {
        AssertReturnType::Bool | AssertReturnType::ResultBool => {
            panic!("Tried to use message args on function with boolean return type")
        }
        AssertReturnType::Tuple(n) | AssertReturnType::ResultTuple(n) => usize::from(*n),
    };
    if args.len() > tuple_size {
        panic!(
            "Tried to use {} message args on function returning a tuple of {} elements",
            args.len(),
            tuple_size
        );
    }
    let padding = std::iter::repeat("_").take(tuple_size - args.len());
    let pattern = args
        .iter()
        .map(String::as_str)
        .chain(padding)
        .collect::<Vec<_>>()
        .join(", ");
    // Destructure through a reference so elements are borrowed, not moved. The generated code
    // still reads `result.0`/`result.1` and returns `result`, which a move would break for
    // non-`Copy` elements. Match ergonomics also peel the extra `&` in the `Result` branch,
    // where `result` is already a reference.
    format!("let ({}) = &result;", pattern)
}
/// Returns the `(open, close)` wrapper placed around the assertion in the generated macro.
///
/// For `Result` return types, the assertion only runs when the value is `Ok`. Inside the
/// wrapper, `result` is rebound to a reference to the `Ok` payload. Other types get no
/// wrapper.
fn get_result_block(return_type: &AssertReturnType) -> (String, String) {
    let (open, close) = match return_type {
        AssertReturnType::ResultBool | AssertReturnType::ResultTuple(_) => {
            ("if let Ok(result) = result.as_ref() {", "}")
        }
        AssertReturnType::Bool | AssertReturnType::Tuple(_) => ("", ""),
    };
    (String::from(open), String::from(close))
}
/// Returns the opening of the assertion emitted into the generated macro.
///
/// The call is left unclosed so the caller can append an optional message and the `)`.
/// Tuple results compare their first two elements with `assert_eq!`; boolean results use
/// `assert!`.
fn get_assert_call(return_type: &AssertReturnType) -> String {
    let compares_pair = matches!(
        return_type,
        AssertReturnType::Tuple(_) | AssertReturnType::ResultTuple(_)
    );
    let opening = if compares_pair {
        "assert_eq!(result.0, result.1"
    } else {
        "assert!(result"
    };
    opening.to_owned()
}
/// Default assertion message parsed from `#[assert_fn(message("fmt", a, b, ..))]`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct AssertMessage {
    /// Pre-rendered tail appended to the assert macro call, e.g. `, "fmt", a=a`.
    message: String,
    /// Names bound to the returned tuple's elements, in order; `_` skips an element.
    args: Vec<String>,
}
/// Extracts the default assertion message from a `message("fmt", a, b, ..)` attribute argument.
///
/// Only the first `message(..)` list whose first item is a string literal is used. Each
/// following bare path contributes its last segment to `args`; any other nested items are
/// ignored. `_` entries stay in `args`, where they pad the tuple destructuring, but they are
/// not passed to the format string.
///
/// The returned `message` is the tail of the assert call, e.g. `, "a={a}", a=a`. The literal is
/// re-escaped with `{:?}`, so quotes and backslashes in the user's text survive the round trip
/// into generated source. Returns `None` when no usable `message(..)` is present.
fn get_message(args: &[NestedMeta]) -> Option<AssertMessage> {
    args.iter()
        .filter_map(|item| match item {
            NestedMeta::Meta(Meta::List(list)) => Some(list),
            _ => None,
        })
        .filter(|list| {
            list.path
                .segments
                .last()
                .map_or(false, |seg| seg.ident == "message")
        })
        .find_map(|list| {
            let mut nested = list.nested.iter();
            let text = match nested.next() {
                Some(NestedMeta::Lit(Lit::Str(lit))) => lit.value(),
                _ => return None,
            };
            let args = nested
                .filter_map(|nested_meta| match nested_meta {
                    NestedMeta::Meta(Meta::Path(path)) => path.segments.last(),
                    _ => None,
                })
                .map(|seg| seg.ident.to_string())
                .collect::<Vec<_>>();
            // Each real name becomes a `name=name` format argument bound by the tuple
            // destructuring; `_` placeholders only occupy a position.
            let named_args: String = args
                .iter()
                .filter(|arg| *arg != "_")
                .map(|arg| format!(", {}={}", arg, arg))
                .collect();
            // `{:?}` renders a valid, escaped Rust string literal. Embedding the raw value
            // inside `"..."` would break on `"` or `\` in the message.
            let message = format!(", {:?}{}", text, named_args);
            Some(AssertMessage { message, args })
        })
}
/// Returns `#[macro_export]` when the attribute arguments contain a bare `export` path, or an
/// empty string otherwise.
///
/// Every bare-path argument is checked (by its last segment), so `export` is honored
/// regardless of its position among the other arguments.
fn get_macro_export(args: &[NestedMeta]) -> String {
    let exported = args.iter().any(|item| match item {
        NestedMeta::Meta(Meta::Path(path)) => path
            .segments
            .last()
            .map_or(false, |seg| seg.ident == "export"),
        _ => false,
    });
    if exported {
        "#[macro_export]".to_string()
    } else {
        String::new()
    }
}