derive_arbitrary 1.3.2

Derives arbitrary traits
Documentation
extern crate proc_macro;

use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::*;

mod container_attributes;
mod field_attributes;
use container_attributes::ContainerAttributes;
use field_attributes::{determine_field_constructor, FieldConstructor};

/// Name of the helper attribute (`#[arbitrary(...)]`); interpolated into the
/// "invalid attribute" error message produced when applying user-supplied bounds.
static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
/// Name of the lifetime parameter that is added to the generated `Arbitrary` impl.
static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";

/// Entry point of `#[derive(Arbitrary)]`.
///
/// Parses `tokens` as a `syn::DeriveInput` and expands it into an
/// `arbitrary::Arbitrary` implementation. Any expansion error is turned into
/// a `compile_error!` token stream instead of panicking.
#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let derive_input = syn::parse_macro_input!(tokens as syn::DeriveInput);
    let expanded = match expand_derive_arbitrary(derive_input) {
        Ok(generated) => generated,
        Err(error) => error.into_compile_error(),
    };
    proc_macro::TokenStream::from(expanded)
}

fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
    let container_attrs = ContainerAttributes::from_derive_input(&input)?;

    let (lifetime_without_bounds, lifetime_with_bounds) =
        build_arbitrary_lifetime(input.generics.clone());

    let recursive_count = syn::Ident::new(
        &format!("RECURSIVE_COUNT_{}", input.ident),
        Span::call_site(),
    );

    let arbitrary_method =
        gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
    let size_hint_method = gen_size_hint_method(&input)?;
    let name = input.ident;

    // Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
    let generics = apply_trait_bounds(
        input.generics,
        lifetime_without_bounds.clone(),
        &container_attrs,
    )?;

    // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
    let mut generics_with_lifetime = generics.clone();
    generics_with_lifetime
        .params
        .push(GenericParam::Lifetime(lifetime_with_bounds));
    let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();

    // Build TypeGenerics and WhereClause without a lifetime
    let (_, ty_generics, where_clause) = generics.split_for_impl();

    Ok(quote! {
        const _: () = {
            std::thread_local! {
                #[allow(non_upper_case_globals)]
                static #recursive_count: std::cell::Cell<u32> = std::cell::Cell::new(0);
            }

            #[automatically_derived]
            impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
                #arbitrary_method
                #size_hint_method
            }
        };
    })
}

// Builds the `'arbitrary` lifetime parameter in two flavours.
//
// Returns: (lifetime without bounds, lifetime with bounds)
// Example: ("'arbitrary", "'arbitrary: 'a + 'b")
//
// The bounded flavour lists every lifetime parameter declared in `generics`.
fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam) {
    let unbounded = LifetimeParam::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));

    let mut bounded = unbounded.clone();
    bounded
        .bounds
        .extend(generics.lifetimes().map(|declared| declared.lifetime.clone()));

    (unbounded, bounded)
}

fn apply_trait_bounds(
    mut generics: Generics,
    lifetime: LifetimeParam,
    container_attrs: &ContainerAttributes,
) -> Result<Generics> {
    // If user-supplied bounds exist, apply them to their matching type parameters.
    if let Some(config_bounds) = &container_attrs.bounds {
        let mut config_bounds_applied = 0;
        for param in generics.params.iter_mut() {
            if let GenericParam::Type(type_param) = param {
                if let Some(replacement) = config_bounds
                    .iter()
                    .flatten()
                    .find(|p| p.ident == type_param.ident)
                {
                    *type_param = replacement.clone();
                    config_bounds_applied += 1;
                } else {
                    // If no user-supplied bounds exist for this type, delete the original bounds.
                    // This mimics serde.
                    type_param.bounds = Default::default();
                    type_param.default = None;
                }
            }
        }
        let config_bounds_supplied = config_bounds
            .iter()
            .map(|bounds| bounds.len())
            .sum::<usize>();
        if config_bounds_applied != config_bounds_supplied {
            return Err(Error::new(
                Span::call_site(),
                format!(
                    "invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
                    ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
                ),
            ));
        }
        Ok(generics)
    } else {
        // Otherwise, inject a `T: Arbitrary` bound for every parameter.
        Ok(add_trait_bounds(generics, lifetime))
    }
}

// Add a bound `T: Arbitrary` to every type parameter T.
fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics {
    for param in generics.params.iter_mut() {
        if let GenericParam::Type(type_param) = param {
            type_param
                .bounds
                .push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
        }
    }
    generics
}

/// Wraps `expr` in generated code that guards against recursion once the
/// `Unstructured` named `u` in the generated method is empty.
///
/// The emitted code:
/// 1. reads `u.is_empty()`; only when it is `true` is the guard active,
/// 2. with the guard active, returns `Err(arbitrary::Error::NotEnoughData)` if the
///    thread-local `recursive_count` is already above zero, otherwise increments it,
/// 3. evaluates `expr` inside an immediately-invoked closure,
/// 4. with the guard active, decrements the counter again, and
/// 5. yields the closure's result.
fn with_recursive_count_guard(
    recursive_count: &syn::Ident,
    expr: impl quote::ToTokens,
) -> impl quote::ToTokens {
    quote! {
        let guard_against_recursion = u.is_empty();
        if guard_against_recursion {
            #recursive_count.with(|count| {
                if count.get() > 0 {
                    return Err(arbitrary::Error::NotEnoughData);
                }
                count.set(count.get() + 1);
                Ok(())
            })?;
        }

        // Running `expr` inside a closure keeps any `?` in it from leaving the
        // surrounding method before the counter is decremented below.
        let result = (|| { #expr })();

        if guard_against_recursion {
            #recursive_count.with(|count| {
                count.set(count.get() - 1);
            });
        }

        result
    }
}

fn gen_arbitrary_method(
    input: &DeriveInput,
    lifetime: LifetimeParam,
    recursive_count: &syn::Ident,
) -> Result<TokenStream> {
    fn arbitrary_structlike(
        fields: &Fields,
        ident: &syn::Ident,
        lifetime: LifetimeParam,
        recursive_count: &syn::Ident,
    ) -> Result<TokenStream> {
        let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
        let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });

        let arbitrary_take_rest = construct_take_rest(fields)?;
        let take_rest_body =
            with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });

        Ok(quote! {
            fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
                #body
            }

            fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
                #take_rest_body
            }
        })
    }

    let ident = &input.ident;
    let output = match &input.data {
        Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?,
        Data::Union(data) => arbitrary_structlike(
            &Fields::Named(data.fields.clone()),
            ident,
            lifetime,
            recursive_count,
        )?,
        Data::Enum(data) => {
            let variants: Vec<TokenStream> = data
                .variants
                .iter()
                .enumerate()
                .map(|(i, variant)| {
                    let idx = i as u64;
                    let variant_name = &variant.ident;
                    construct(&variant.fields, |_, field| gen_constructor_for_field(field))
                        .map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
                })
                .collect::<Result<_>>()?;

            let variants_take_rest: Vec<TokenStream> = data
                .variants
                .iter()
                .enumerate()
                .map(|(i, variant)| {
                    let idx = i as u64;
                    let variant_name = &variant.ident;
                    construct_take_rest(&variant.fields)
                        .map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
                })
                .collect::<Result<_>>()?;

            let count = data.variants.len() as u64;

            let arbitrary = with_recursive_count_guard(
                recursive_count,
                quote! {
                    // Use a multiply + shift to generate a ranged random number
                    // with slight bias. For details, see:
                    // https://lemire.me/blog/2016/06/30/fast-random-shuffling
                    Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 {
                        #(#variants,)*
                        _ => unreachable!()
                    })
                },
            );

            let arbitrary_take_rest = with_recursive_count_guard(
                recursive_count,
                quote! {
                    // Use a multiply + shift to generate a ranged random number
                    // with slight bias. For details, see:
                    // https://lemire.me/blog/2016/06/30/fast-random-shuffling
                    Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 {
                        #(#variants_take_rest,)*
                        _ => unreachable!()
                    })
                },
            );

            quote! {
                fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
                    #arbitrary
                }

                fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
                    #arbitrary_take_rest
                }
            }
        }
    };
    Ok(output)
}

/// Builds the field-list part of a constructor expression for `fields`.
///
/// `ctor` receives each field's index and definition and returns the expression
/// used to initialise that field. The result is `{ name: expr, ... }` for named
/// fields, `( expr, ... )` for unnamed fields and nothing for unit-like fields.
///
/// # Errors
///
/// Stops at and returns the first error produced by `ctor`.
fn construct(
    fields: &Fields,
    ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
) -> Result<TokenStream> {
    match fields {
        Fields::Unit => Ok(quote!()),
        Fields::Named(named_fields) => {
            let mut initializers: Vec<TokenStream> =
                Vec::with_capacity(named_fields.named.len());
            for (index, field) in named_fields.named.iter().enumerate() {
                let name = field.ident.as_ref().unwrap();
                let value = ctor(index, field)?;
                initializers.push(quote! { #name: #value });
            }
            Ok(quote! { { #(#initializers,)* } })
        }
        Fields::Unnamed(unnamed_fields) => {
            let mut values: Vec<TokenStream> = Vec::with_capacity(unnamed_fields.unnamed.len());
            for (index, field) in unnamed_fields.unnamed.iter().enumerate() {
                values.push(ctor(index, field)?);
            }
            Ok(quote! { ( #(#values),* ) })
        }
    }
}

/// Builds the field list used inside the generated `arbitrary_take_rest`, where
/// `u` is an owned `Unstructured`.
///
/// A field constructed through `Arbitrary` consumes `u` with `arbitrary_take_rest`
/// when it is the last field; every earlier such field borrows it with `&mut u`.
///
/// # Errors
///
/// Propagates errors from `determine_field_constructor`.
fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
    let field_count = fields.len();
    construct(fields, |idx, field| {
        let is_last_field = idx + 1 == field_count;
        let initializer = match determine_field_constructor(field)? {
            FieldConstructor::Default => quote!(Default::default()),
            FieldConstructor::Arbitrary if is_last_field => {
                quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
            }
            FieldConstructor::Arbitrary => quote! { arbitrary::Arbitrary::arbitrary(&mut u)? },
            FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?),
            FieldConstructor::Value(value) => quote!(#value),
        };
        Ok(initializer)
    })
}

/// Generates the `size_hint` method of the impl.
///
/// Structs and unions combine the hints of all fields with `and_all`. Enums
/// combine the hint of the `u32` discriminant with the `or_all` of the
/// per-variant hints.
///
/// # Errors
///
/// Propagates errors from `determine_field_constructor`.
fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
    // `and_all` over the size hints of every field in `fields`.
    fn fields_hint(fields: &Fields) -> Result<TokenStream> {
        let mut hints: Vec<TokenStream> = Vec::with_capacity(fields.len());
        for field in fields.iter() {
            let ty = &field.ty;
            let hint = match determine_field_constructor(field)? {
                FieldConstructor::Default | FieldConstructor::Value(_) => {
                    quote!((0, Some(0)))
                }
                FieldConstructor::Arbitrary => {
                    quote! { <#ty as arbitrary::Arbitrary>::size_hint(depth) }
                }
                // Note that in this case it's hard to determine what size_hint must be, so
                // size_of::<T>() is just an educated guess, although it's gonna be inaccurate
                // for dynamically allocated types (Vec, HashMap, etc.).
                FieldConstructor::With(_) => {
                    quote! { (::core::mem::size_of::<#ty>(), None) }
                }
            };
            hints.push(hint);
        }
        Ok(quote! {
            arbitrary::size_hint::and_all(&[
                #( #hints ),*
            ])
        })
    }

    // Complete `size_hint` method for a type with a single list of fields.
    fn structlike_hint(fields: &Fields) -> Result<TokenStream> {
        let hint = fields_hint(fields)?;
        Ok(quote! {
            #[inline]
            fn size_hint(depth: usize) -> (usize, Option<usize>) {
                arbitrary::size_hint::recursion_guard(depth, |depth| #hint)
            }
        })
    }

    match &input.data {
        Data::Struct(data) => structlike_hint(&data.fields),
        Data::Union(data) => structlike_hint(&Fields::Named(data.fields.clone())),
        Data::Enum(data) => {
            let variants = data
                .variants
                .iter()
                .map(|variant| fields_hint(&variant.fields))
                .collect::<Result<Vec<TokenStream>>>()?;
            Ok(quote! {
                #[inline]
                fn size_hint(depth: usize) -> (usize, Option<usize>) {
                    arbitrary::size_hint::and(
                        <u32 as arbitrary::Arbitrary>::size_hint(depth),
                        arbitrary::size_hint::recursion_guard(depth, |depth| {
                            arbitrary::size_hint::or_all(&[ #( #variants ),* ])
                        }),
                    )
                }
            })
        }
    }
}

/// Produces the expression that initialises `field` inside the generated
/// `arbitrary` method, where `u` is a `&mut Unstructured`.
///
/// # Errors
///
/// Propagates errors from `determine_field_constructor`.
fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
    determine_field_constructor(field).map(|constructor| match constructor {
        FieldConstructor::Default => quote!(Default::default()),
        FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
        FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
        FieldConstructor::Value(value) => quote!(#value),
    })
}