#![allow(clippy::default_trait_access)]
use darling::FromMeta;
use proc_macro::TokenStream;
use proc_macro2::{Literal, Span, TokenStream as TokenStream2};
use quote::{quote, ToTokens};
use std::{convert::identity, str::FromStr};
use syn::{
parse::Parser, parse_macro_input, parse_quote, punctuated::Punctuated, token, Attribute,
AttributeArgs, Block, Expr, FnArg, GenericArgument, GenericMethodArgument, GenericParam,
Generics, Ident, ImplItem, ImplItemMethod, ItemFn, ItemImpl, ItemMod, Pat, Path, PathArguments,
PathSegment, ReturnType, Signature, Stmt, Type, TypePath, TypeReference, Visibility,
};
use unzip_n::unzip_n;
mod util;
/// Options accepted by `#[test_fuzz_impl(...)]`. There are currently none; the struct exists so
/// that the attribute's arguments still go through darling's `FromMeta::from_list`.
#[derive(FromMeta)]
struct TestFuzzImplOpts {}
#[proc_macro_attribute]
pub fn test_fuzz_impl(args: TokenStream, item: TokenStream) -> TokenStream {
let attr_args = parse_macro_input!(args as AttributeArgs);
let _ = TestFuzzImplOpts::from_list(&attr_args).unwrap();
let item = parse_macro_input!(item as ItemImpl);
let ItemImpl {
attrs,
defaultness,
unsafety,
impl_token,
generics,
trait_,
self_ty,
brace_token: _,
items,
} = item;
let (trait_path, trait_) = trait_.map_or((None, None), |(bang, path, for_)| {
(Some(path.clone()), Some(quote! { #bang #path #for_ }))
});
let (impl_items, modules) = map_impl_items(&generics, &trait_path, &*self_ty, &items);
let result = quote! {
#(#attrs)* #defaultness #unsafety #impl_token #generics #trait_ #self_ty {
#(#impl_items)*
}
#(#modules)*
};
log(&result.to_token_stream());
result.into()
}
/// Runs every item of an `impl` block through `map_impl_item`.
///
/// Returns the (possibly rewritten) items in their original order, together with the modules
/// generated for the instrumented methods.
fn map_impl_items(
    generics: &Generics,
    trait_path: &Option<Path>,
    self_ty: &Type,
    items: &[ImplItem],
) -> (Vec<ImplItem>, Vec<ItemMod>) {
    let map_one = map_impl_item(generics, trait_path, self_ty);
    let mut impl_items = Vec::with_capacity(items.len());
    let mut maybe_modules = Vec::with_capacity(items.len());
    for item in items {
        let (impl_item, maybe_module) = map_one(item);
        impl_items.push(impl_item);
        maybe_modules.push(maybe_module);
    }
    // Only instrumented methods yield a module; keep just those.
    let modules = maybe_modules.into_iter().filter_map(identity).collect();
    (impl_items, modules)
}
/// Builds a closure that maps one `impl` item.
///
/// Methods are handed to `map_method`. Every other kind of item is cloned unchanged and yields
/// no module. The closure owns copies of the impl context so it can be reused for every item.
fn map_impl_item(
    generics: &Generics,
    trait_path: &Option<Path>,
    self_ty: &Type,
) -> impl Fn(&ImplItem) -> (ImplItem, Option<ItemMod>) {
    let (owned_generics, owned_trait_path, owned_self_ty) =
        (generics.clone(), trait_path.clone(), self_ty.clone());
    move |impl_item: &ImplItem| match impl_item {
        ImplItem::Method(method) => {
            map_method(&owned_generics, &owned_trait_path, &owned_self_ty, method)
        }
        other => (other.clone(), None),
    }
}
/// Rewrites a single method of a `#[test_fuzz_impl]` block.
///
/// If the method carries a `test_fuzz` attribute (per `is_test_fuzz`), the attribute is removed
/// and its options are parsed. The method is then instrumented by `map_method_or_fn`, which also
/// returns the generated module. Otherwise the method is returned as-is with no module.
fn map_method(
    generics: &Generics,
    trait_path: &Option<Path>,
    self_ty: &Type,
    method: &ImplItemMethod,
) -> (ImplItem, Option<ItemMod>) {
    let mut remaining_attrs = method.attrs.clone();
    let index = match remaining_attrs.iter().position(is_test_fuzz) {
        Some(index) => index,
        None => return (parse_quote!( #method ), None),
    };
    let test_fuzz_attr = remaining_attrs.remove(index);
    let opts = opts_from_attr(&test_fuzz_attr);
    let enclosing_self_ty = Some(self_ty.clone());
    let (tokens, module) = map_method_or_fn(
        generics,
        trait_path,
        &enclosing_self_ty,
        &opts,
        &remaining_attrs,
        &method.vis,
        &method.defaultness,
        &method.sig,
        &method.block,
    );
    (parse_quote!( #tokens ), module)
}
/// Options accepted by `#[test_fuzz(...)]`, parsed with darling. Every option may be omitted
/// (`#[darling(default)]`).
#[derive(Clone, Debug, Default, FromMeta)]
struct TestFuzzOpts {
/// When set, the instrumented function also records its arguments in non-`cfg(test)` builds
/// (when `test_fuzz::runtime::write_enabled()` is true), and the generated module is not
/// restricted to `#[cfg(test)]`.
#[darling(default)]
enable_in_production: bool,
/// Used instead of the function's own name when naming the generated `<name>_fuzz` module.
#[darling(default)]
rename: Option<Ident>,
/// When set, the function is emitted unchanged and no module is generated.
#[darling(default)]
skip: bool,
/// Comma-separated list of types used as the turbofish when the generated code calls the target
/// (i.e. for the function's own generic parameters).
#[darling(default)]
specialize: Option<String>,
/// Comma-separated list of types for the enclosing `impl`'s generic parameters. They are placed
/// before the `specialize` types when instantiating the generated `Args`.
#[darling(default)]
specialize_impl: Option<String>,
}
/// Attribute macro for free functions.
///
/// Parses the options into `TestFuzzOpts`. It then emits the instrumented function followed by
/// its generated `<name>_fuzz` module (see `map_method_or_fn`). Methods inside
/// `#[test_fuzz_impl]` blocks are handled by that macro instead, which strips this attribute.
///
/// Malformed options are reported as spanned compile errors rather than by panicking.
#[proc_macro_attribute]
pub fn test_fuzz(args: TokenStream, item: TokenStream) -> TokenStream {
    let attr_args = parse_macro_input!(args as AttributeArgs);
    let opts = match TestFuzzOpts::from_list(&attr_args) {
        Ok(opts) => opts,
        // Surface darling's diagnostics as `compile_error!` tokens at the offending span.
        Err(error) => return error.write_errors().into(),
    };
    let item = parse_macro_input!(item as ItemFn);
    let ItemFn {
        attrs,
        vis,
        sig,
        block,
    } = &item;
    // A free function has no enclosing impl: no impl generics, trait, self type or defaultness.
    let (item, module) = map_method_or_fn(
        &Generics::default(),
        &None,
        &None,
        &opts,
        attrs,
        vis,
        &None,
        sig,
        block,
    );
    let result = quote! {
        #item
        #module
    };
    log(&result);
    result.into()
}
/// Shared code generator behind `test_fuzz` (free functions) and `test_fuzz_impl` (methods).
///
/// `generics`, `trait_path` and `self_ty` describe the enclosing `impl`. They are empty/`None`
/// for free functions. The remaining parameters are the pieces of the function being
/// instrumented.
///
/// Returns the function's tokens and the generated module.
///
/// When `opts.skip` is set, the function is re-emitted unchanged and no module is produced.
/// Otherwise, the body is prefixed with code that passes clones or owned conversions of the
/// arguments to `<mod>::write_args`. The module `<rename or fn name>_fuzz` contains:
/// - `Args`, a tuple struct with one owned field per parameter;
/// - `write_args` and `UsingReader::read_args`, which (de)serialize `Args` through
///   `test_fuzz::runtime`;
/// - a `Debug` impl and a `write_default` helper for `Args`;
/// - the `#[test]` functions `default` and `entry`.
///
/// # Panics
///
/// Panics if `opts.specialize` or `opts.specialize_impl` cannot be parsed as a list of types.
// Lints allowed because the signature mirrors the fields of syn's `ItemFn`/`ImplItemMethod` that
// the callers destructure and pass through by reference.
#[allow(
clippy::ptr_arg,
clippy::too_many_arguments,
clippy::trivially_copy_pass_by_ref
)]
fn map_method_or_fn(
generics: &Generics,
trait_path: &Option<Path>,
self_ty: &Option<Type>,
opts: &TestFuzzOpts,
attrs: &Vec<Attribute>,
vis: &Visibility,
defaultness: &Option<token::Default>,
sig: &Signature,
block: &Block,
) -> (TokenStream2, Option<ItemMod>) {
// The original statements, re-emitted after the instrumentation.
let stmts = &block.stmts;
// Note: parsed eagerly, i.e. even when `opts.skip` is set.
let opts_specialize = opts
.specialize
.as_ref()
.map(|s| parse_generic_method_arguments(s));
let opts_specialize_impl = opts
.specialize_impl
.as_ref()
.map(|s| parse_generic_method_arguments(s));
// `skip`: emit the function as written, with no instrumentation and no module.
if opts.skip {
return (
parse_quote! {
#(#attrs)* #vis #defaultness #sig {
#(#stmts)*
}
},
None,
);
}
// Impl generics followed by method generics. The deserializable copy adds a
// `DeserializeOwned` bound to every type parameter, for `read_args`.
let combined_generics = combine_generics(generics, &sig.generics);
let combined_generics_deserializable = restrict_to_deserialize(&combined_generics);
let (impl_generics, ty_generics, where_clause) = combined_generics.split_for_impl();
let (impl_generics_deserializable, _, _) = combined_generics_deserializable.split_for_impl();
let ty_generics_as_turbofish = ty_generics.as_turbofish();
// Turbofish for calling the target itself (the function's own generics only).
let target_specialization = opts_specialize.as_ref().map(args_as_turbofish);
// Turbofish for `Args`/`read_args`: `specialize_impl` types followed by `specialize` types.
let combined_specialization =
combine_options(opts_specialize_impl, opts_specialize, |mut left, right| {
left.extend(right);
left
})
.as_ref()
.map(args_as_turbofish);
// Per parameter: stored type, `Debug` statement, serialize expr, deserialize expr.
let (receiver, arg_tys, fmt_args, ser_args, de_args) = map_args(self_ty, sig.inputs.iter());
// `args.0`, `args.1`, ... used to move fields between the module's different `Args` structs.
let args_is: Vec<TokenStream2> = (0..sig.inputs.len())
.map(|i| {
let i = Literal::usize_unsuffixed(i);
quote! { args . #i }
})
.collect();
let pub_arg_tys: Vec<TokenStream2> = arg_tys.iter().map(|ty| quote! { pub #ty }).collect();
// `TryDefault` expressions for `write_default`. The `?` propagates out of the closure there,
// which returns `Option`.
let def_args: Vec<Expr> = arg_tys
.iter()
.map(|ty| {
parse_quote! {
test_fuzz::runtime::TryDefault::<#ty>::default()?
}
})
.collect();
// Return type, used by `output_ret`. For methods it goes through `util::expand_self`
// (presumably substituting `Self`; see `util`).
let ret_ty = match &sig.output {
ReturnType::Type(_, ty) => self_ty.as_ref().map_or(*ty.clone(), |self_ty| {
util::expand_self(self_ty, trait_path, ty)
}),
ReturnType::Default => parse_quote! { () },
};
let target_ident = &sig.ident;
// Generated module name: `<rename or fn name>_fuzz`.
let renamed_target_ident = opts.rename.as_ref().unwrap_or(target_ident);
let mod_ident = Ident::new(&format!("{}_fuzz", renamed_target_ident), Span::call_site());
// `enable_in_production`: also record arguments in non-test builds when `write_enabled()`.
// Otherwise the whole module is `#[cfg(test)]`.
let (in_production_write_args, mod_attr) = if opts.enable_in_production {
(
quote! {
#[cfg(not(test))]
if test_fuzz::runtime::write_enabled() {
#mod_ident :: write_args(#mod_ident :: Args(
#(#ser_args),*
));
}
},
quote! {},
)
} else {
(
quote! {},
quote! {
#[cfg(test)]
},
)
};
// Without the `persistent` feature, `entry` reads one `Args` from stdin. With it, input arrives
// through `test_fuzz::afl::fuzz!` instead (see `call_with_deserialized_arguments`).
let input_args = {
#[cfg(feature = "persistent")]
quote! {}
#[cfg(not(feature = "persistent"))]
quote! {
let mut args = UsingReader::<_>::read_args #combined_specialization (std::io::stdin());
}
};
// Debug-prints the deserialized arguments (display mode). Empty under `persistent`.
let output_args = {
#[cfg(feature = "persistent")]
quote! {}
#[cfg(not(feature = "persistent"))]
quote! {
args.as_ref().map(|x| {
if test_fuzz::runtime::pretty_print_enabled() {
eprint!("{:#?}", x);
} else {
eprint!("{:?}", x);
};
});
eprintln!();
}
};
// The call under test, one of:
// - a method call on the deserialized receiver (the first deserialized argument);
// - an associated-function call on `self_ty`;
// - `super::<fn>` for a free function (the module is a child of the defining scope).
let call: Expr = if receiver {
let mut de_args = de_args.iter();
let self_arg = de_args
.next()
.expect("should have at least one deserialized argument");
parse_quote! {
#self_arg . #target_ident #target_specialization (
#(#de_args),*
)
}
} else if let Some(self_ty) = self_ty {
parse_quote! {
#self_ty :: #target_ident #target_specialization (
#(#de_args),*
)
}
} else {
parse_quote! {
super :: #target_ident #target_specialization (
#(#de_args),*
)
}
};
// Runs `call` on the deserialized arguments, binding the result to `ret`. Under `persistent`
// this happens once per input inside `test_fuzz::afl::fuzz!`.
let call_with_deserialized_arguments = {
#[cfg(feature = "persistent")]
quote! {
test_fuzz::afl::fuzz!(|data: &[u8]| {
let mut args = UsingReader::<_>::read_args #combined_specialization (data);
let ret = args.map(|mut args|
#call
);
});
}
#[cfg(not(feature = "persistent"))]
quote! {
let ret = args.map(|mut args|
#call
);
}
};
// Debug-prints `ret` (replay mode) through `TryDebug`. Under `persistent` it only names
// `ret_ty` (presumably to keep that type referenced).
let output_ret = {
#[cfg(feature = "persistent")]
quote! {
let _: Option<#ret_ty> = None;
}
#[cfg(not(feature = "persistent"))]
quote! {
struct Ret(#ret_ty);
impl std::fmt::Debug for Ret {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use test_fuzz::runtime::TryDebugFallback;
let mut debug_tuple = fmt.debug_tuple("Ret");
test_fuzz::runtime::TryDebug(&self.0).apply(&mut |value| {
debug_tuple.field(value);
});
debug_tuple.finish()
}
}
let ret = ret.map(Ret);
ret.map(|x| {
if test_fuzz::runtime::pretty_print_enabled() {
eprint!("{:#?}", x);
} else {
eprint!("{:?}", x);
};
});
eprintln!();
}
};
// 1) The instrumented function. Under `cfg(test)`, while `test_fuzz_enabled()` is false, it
//    records its arguments via `write_args`, then runs the original statements.
// 2) The generated module. Its `entry` test does nothing unless `test_fuzz_enabled()`. In
//    display/replay mode it prints and/or re-runs the call. Otherwise it installs an aborting
//    panic hook and runs the call.
(
parse_quote! {
#(#attrs)* #vis #defaultness #sig {
#[cfg(test)]
if !test_fuzz::runtime::test_fuzz_enabled() {
#mod_ident :: write_args(#mod_ident :: Args(
#(#ser_args),*
));
}
#in_production_write_args
#(#stmts)*
}
},
Some(parse_quote! {
#mod_attr
mod #mod_ident {
use super::*;
pub(super) struct Args #ty_generics (
#(#pub_arg_tys),*
);
pub(super) fn write_args #impl_generics (args: Args #ty_generics_as_turbofish) #where_clause {
#[derive(serde::Serialize)]
struct Args #ty_generics (
#(#pub_arg_tys),*
);
let args = Args(
#(#args_is),*
);
test_fuzz::runtime::write_args(&args);
}
struct UsingReader<R>(R);
impl<R: std::io::Read> UsingReader<R> {
pub fn read_args #impl_generics_deserializable (reader: R) -> Option<Args #ty_generics_as_turbofish> #where_clause {
#[derive(serde::Deserialize)]
struct Args #ty_generics (
#(#pub_arg_tys),*
);
let args = test_fuzz::runtime::read_args::<Args #ty_generics_as_turbofish, _>(reader);
args.map(|args| #mod_ident :: Args(
#(#args_is),*
))
}
}
impl #impl_generics std::fmt::Debug for Args #ty_generics #where_clause {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use test_fuzz::runtime::TryDebugFallback;
let mut debug_struct = fmt.debug_struct("Args");
#(#fmt_args)*
debug_struct.finish()
}
}
impl #impl_generics Args #ty_generics #where_clause {
fn write_default() {
if !test_fuzz::runtime::test_fuzz_enabled() {
use test_fuzz::runtime::TryDefaultFallback;
let args = (|| -> Option< #mod_ident :: Args #ty_generics_as_turbofish> {
Some(#mod_ident::Args(
#(#def_args),*
))
})();
args.map(|args| write_args(args));
}
}
}
#[test]
fn default() {
Args #combined_specialization :: write_default();
}
#[test]
fn entry() {
if test_fuzz::runtime::test_fuzz_enabled() {
if test_fuzz::runtime::display_enabled()
|| test_fuzz::runtime::replay_enabled()
{
#input_args
if test_fuzz::runtime::display_enabled() {
#output_args
}
if test_fuzz::runtime::replay_enabled() {
#call_with_deserialized_arguments
#output_ret
}
} else {
std::panic::set_hook(std::boxed::Box::new(|_| std::process::abort()));
#input_args
#call_with_deserialized_arguments
let _ = std::panic::take_hook();
}
}
}
}
}),
)
}
/// Maps every parameter of a signature through `map_arg`.
///
/// Returns, in order:
/// - whether the function takes a receiver (`self` in any form);
/// - the owned types stored in the generated `Args`;
/// - the `Debug` statements;
/// - the serialize expressions;
/// - the deserialize expressions.
fn map_args<'a, I>(
    self_ty: &Option<Type>,
    inputs: I,
) -> (bool, Vec<Type>, Vec<Stmt>, Vec<Expr>, Vec<Expr>)
where
    I: Iterator<Item = &'a FnArg>,
{
    unzip_n!(5);
    let per_arg = inputs.enumerate().map(map_arg(self_ty));
    let (receiver_flags, tys, fmts, sers, des): (
        Vec<bool>,
        Vec<Type>,
        Vec<Stmt>,
        Vec<Expr>,
        Vec<Expr>,
    ) = per_arg.unzip_n();
    // Only a leading parameter can be a receiver, so the first flag decides.
    let has_receiver = receiver_flags.first().copied().unwrap_or(false);
    (has_receiver, tys, fmts, sers, des)
}
/// Builds a closure mapping `(index, parameter)` to
/// `(is_receiver, stored type, Debug stmt, serialize expr, deserialize expr)`.
///
/// - A receiver is stored as `self_ty` and serialized with `self.clone()`.
/// - `Arc` parameters are delegated to `map_arc_arg`; if that returns `None`, the fallback below
///   applies.
/// - Reference parameters are delegated to `map_ref_arg`.
/// - Everything else is stored as its declared type, serialized with `pat.clone()`, and passed
///   back as `args.<index>`.
fn map_arg(self_ty: &Option<Type>) -> impl Fn((usize, &FnArg)) -> (bool, Type, Stmt, Expr, Expr) {
    let self_ty = self_ty.clone();
    move |(index, arg)| {
        let i = Literal::usize_unsuffixed(index);
        let pat_ty = match arg {
            FnArg::Receiver(_) => {
                return (
                    true,
                    parse_quote! { #self_ty },
                    parse_quote! {
                        test_fuzz::runtime::TryDebug(&self.#i).apply(&mut |value| {
                            debug_struct.field("self", value);
                        });
                    },
                    parse_quote! { self.clone() },
                    parse_quote! { args.#i },
                );
            }
            FnArg::Typed(pat_ty) => pat_ty,
        };
        let pat = &*pat_ty.pat;
        let ty = &*pat_ty.ty;
        // The parameter's pattern text doubles as the field name in `Args`' Debug output.
        let name = pat.to_token_stream().to_string();
        let fmt: Stmt = parse_quote! {
            test_fuzz::runtime::TryDebug(&self.#i).apply(&mut |value| {
                debug_struct.field(#name, value);
            });
        };
        let fallback: (Type, Expr, Expr) = (
            parse_quote! { #ty },
            parse_quote! { #pat.clone() },
            parse_quote! { args.#i },
        );
        let (stored_ty, ser, de) = match ty {
            Type::Path(path) => map_arc_arg(&i, pat, path).unwrap_or(fallback),
            Type::Reference(reference) => map_ref_arg(&i, pat, reference),
            _ => fallback,
        };
        (false, stored_ty, fmt, ser, de)
    }
}
/// Maps an `std::sync::Arc<T>` parameter (as matched by `util::match_type_path`) to
/// `(stored type, serialize expr, deserialize expr)`:
/// - the inner `T` is stored;
/// - the value is serialized as `(*pat).clone()`;
/// - `args.<i>` is wrapped in a fresh `Arc` for the call.
///
/// Returns `None` unless the type is an `Arc` with exactly one type argument.
fn map_arc_arg(i: &Literal, pat: &Pat, path: &TypePath) -> Option<(Type, Expr, Expr)> {
    let generic_args = match util::match_type_path(path, &["std", "sync", "Arc"]) {
        Some(PathArguments::AngleBracketed(generic_args)) => generic_args,
        _ => return None,
    };
    if generic_args.args.len() != 1 {
        return None;
    }
    match &generic_args.args[0] {
        GenericArgument::Type(inner) => Some((
            parse_quote! { #inner },
            parse_quote! { (*#pat).clone() },
            parse_quote! { std::sync::Arc::new(args.#i) },
        )),
        _ => None,
    }
}
/// Maps a reference parameter (`pat: &T` or `pat: &mut T`) to
/// `(stored type, serialize expr, deserialize expr)` for the generated `Args` types:
/// - `&str` / `&mut str` is stored as `String`;
/// - `&[T]` / `&mut [T]` is stored as `Vec<T>`;
/// - any other `&T` / `&mut T` is stored as `T` and serialized via `(*pat).clone()`.
///
/// The deserialize expression borrows the stored value back out of `args.<i>` with the same
/// mutability as the parameter. `&mut` parameters therefore receive a mutable borrow; the
/// generated call site binds `args` as `mut`.
fn map_ref_arg(i: &Literal, pat: &Pat, ty: &TypeReference) -> (Type, Expr, Expr) {
    let mutable = ty.mutability.is_some();
    match &*ty.elem {
        Type::Path(path) if util::match_type_path(path, &["str"]) == Some(PathArguments::None) => {
            // `String::as_str` yields `&str`, which cannot be passed where `&mut str` is expected.
            let de: Expr = if mutable {
                parse_quote! { args.#i.as_mut_str() }
            } else {
                parse_quote! { args.#i.as_str() }
            };
            (parse_quote! { String }, parse_quote! { #pat.to_owned() }, de)
        }
        Type::Slice(slice) => {
            let elem = &*slice.elem;
            // Likewise, `Vec::as_slice` yields `&[T]`; a `&mut [T]` parameter needs `as_mut_slice`.
            let de: Expr = if mutable {
                parse_quote! { args.#i.as_mut_slice() }
            } else {
                parse_quote! { args.#i.as_slice() }
            };
            (parse_quote! { Vec<#elem> }, parse_quote! { #pat.to_vec() }, de)
        }
        _ => {
            // `Option<Token![mut]>` prints as `mut` or nothing.
            let mutability = &ty.mutability;
            let elem = &*ty.elem;
            (
                parse_quote! { #elem },
                parse_quote! { (*#pat).clone() },
                parse_quote! { & #mutability args.#i },
            )
        }
    }
}
/// Parses the options of a `#[test_fuzz(...)]` attribute found on an impl method.
///
/// An attribute whose arguments cannot be read as a parenthesized token list (e.g. a bare
/// `#[test_fuzz]`) yields `TestFuzzOpts::default()`.
///
/// # Panics
///
/// Panics if the arguments are present but are not valid `TestFuzzOpts`.
fn opts_from_attr(attr: &Attribute) -> TestFuzzOpts {
    match attr.parse_args::<TokenStream2>() {
        Err(_) => TestFuzzOpts::default(),
        Ok(arg_tokens) => {
            let parsed = parse_macro_input::parse::<AttributeArgs>(arg_tokens.into()).unwrap();
            TestFuzzOpts::from_list(&parsed).unwrap()
        }
    }
}
/// Reports whether `attr` is a `test_fuzz` attribute, i.e. no segment of its path differs from
/// `test_fuzz` (matches `#[test_fuzz]` and `#[test_fuzz::test_fuzz]`).
fn is_test_fuzz(attr: &Attribute) -> bool {
    let has_foreign_segment = attr
        .path
        .segments
        .iter()
        .any(|PathSegment { ident, .. }| ident != "test_fuzz");
    !has_foreign_segment
}
/// Parses a comma-separated list of types (the value of a `specialize`/`specialize_impl`
/// option) into generic method arguments suitable for a turbofish.
///
/// # Panics
///
/// Panics if `s` does not tokenize, or is not a comma-separated list of types. The message
/// names the offending string, so the user can locate the bad option.
fn parse_generic_method_arguments(s: &str) -> Punctuated<GenericMethodArgument, token::Comma> {
    let tokens = TokenStream::from_str(s)
        .unwrap_or_else(|error| panic!("could not tokenize `{}`: {:?}", s, error));
    Parser::parse(Punctuated::<Type, token::Comma>::parse_terminated, tokens)
        .unwrap_or_else(|error| panic!("could not parse `{}` as a list of types: {}", s, error))
        .into_iter()
        .map(GenericMethodArgument::Type)
        .collect()
}
/// Concatenates two generic parameter lists, typically impl generics followed by method
/// generics. Their where-clause predicates are merged in the same order.
fn combine_generics(left: &Generics, right: &Generics) -> Generics {
    let mut combined = left.clone();
    combined.params.extend(right.params.iter().cloned());
    let right_where = right.where_clause.clone();
    combined.where_clause = combine_options(
        combined.where_clause.take(),
        right_where,
        |mut merged, extra| {
            merged.predicates.extend(extra.predicates);
            merged
        },
    );
    combined
}
/// Merges two optional values.
///
/// Returns `f(x, y)` when both are present, whichever one is present otherwise, and `None` when
/// neither is.
fn combine_options<T, F>(x: Option<T>, y: Option<T>, f: F) -> Option<T>
where
    F: FnOnce(T, T) -> T,
{
    match x {
        None => y,
        Some(first) => Some(match y {
            Some(second) => f(first, second),
            None => first,
        }),
    }
}
/// Returns a copy of `generics` in which every type parameter additionally requires
/// `serde::de::DeserializeOwned`. Lifetime and const parameters are left untouched.
fn restrict_to_deserialize(generics: &Generics) -> Generics {
    let mut restricted = generics.clone();
    for param in restricted.params.iter_mut() {
        if let GenericParam::Type(type_param) = param {
            type_param
                .bounds
                .push(parse_quote! { serde::de::DeserializeOwned });
        }
    }
    restricted
}
/// Wraps a generic argument list in turbofish syntax: `::<A, B, ...>`.
fn args_as_turbofish(args: &Punctuated<GenericMethodArgument, token::Comma>) -> TokenStream2 {
    let turbofish = quote!(::<#args>);
    turbofish
}
/// Prints the generated tokens to stdout when `log_enabled()` is true.
fn log(tokens: &TokenStream2) {
    if !log_enabled() {
        return;
    }
    println!("{}", tokens);
}
/// Reports whether `TEST_FUZZ_LOG` was set to something other than `"0"`.
///
/// `option_env!` reads the variable when this crate is compiled, not when the macro runs.
fn log_enabled() -> bool {
    match option_env!("TEST_FUZZ_LOG") {
        Some(value) => value != "0",
        None => false,
    }
}