//! test-strategy 0.3.0
//!
//! Procedural macro to easily write higher-order strategies in proptest.
use proc_macro2::{Group, Spacing, Span, TokenStream, TokenTree};
use quote::{quote, ToTokens};
use std::{
    collections::{BTreeMap, HashSet},
    iter::once,
    ops::Deref,
};
use structmeta::{Parse, ToTokens};
use syn::{
    ext::IdentExt,
    parenthesized,
    parse::{Parse, ParseStream},
    parse2, parse_str,
    punctuated::Punctuated,
    spanned::Spanned,
    token::{Comma, Paren},
    visit::{visit_path, visit_type, Visit},
    Attribute, DeriveInput, Expr, Field, GenericParam, Generics, Ident, Lit, Path, Result, Token,
    Type, WherePredicate,
};

// Early-returns `Err(syn::Error)` from the enclosing function.
//
// Forms:
// - `bail!(span, "literal")`      — the literal is passed to `syn::Error::new` as-is
//                                   (not run through `format!`).
// - `bail!(span, message_expr)`   — any message value accepted by `syn::Error::new`.
// - `bail!(span, "fmt {}", args)` — the message is built with `std::format!`.
//
// `macro_rules!` tries arms top to bottom, so the arm order here is significant.
macro_rules! bail {
    ($span:expr, $message:literal $(,)?) => {
        return std::result::Result::Err(syn::Error::new($span, $message))
    };
    ($span:expr, $err:expr $(,)?) => {
        return std::result::Result::Err(syn::Error::new($span, $err))
    };
    ($span:expr, $fmt:expr, $($arg:tt)*) => {
        return std::result::Result::Err(syn::Error::new($span, std::format!($fmt, $($arg)*)))
    };
}

/// Converts a macro expansion result into the `proc_macro::TokenStream` returned to rustc.
///
/// On `Err`, the error is rendered with `syn::Error::to_compile_error`, so the failure
/// is reported as a compiler diagnostic instead of a panic.
pub fn into_macro_output(input: Result<TokenStream>) -> proc_macro::TokenStream {
    let expanded = input.unwrap_or_else(|err| err.to_compile_error());
    expanded.into()
}

/// A `T` that was written between a pair of parentheses: `( T )`.
///
/// When produced by the `Parse` impl in this block, `paren_token` is always `Some`.
pub struct Parenthesized<T> {
    pub paren_token: Option<Paren>,
    pub content: T,
}
impl<T: Parse> Parse for Parenthesized<T> {
    fn parse(input: ParseStream) -> Result<Self> {
        let inner;
        let paren = parenthesized!(inner in input);
        Ok(Self {
            paren_token: Some(paren),
            content: inner.parse()?,
        })
    }
}
impl<T> Deref for Parenthesized<T> {
    type Target = T;
    /// Gives direct access to the parenthesized value.
    fn deref(&self) -> &T {
        &self.content
    }
}

/// Parses an optional parenthesized argument list of the form `(arg, name = value, ...)`.
///
/// An empty stream yields an empty [`Args`]. Otherwise the entire stream must be one
/// parenthesized group; `parse2` rejects any trailing tokens.
pub fn parse_parenthesized_args(input: TokenStream) -> Result<Args> {
    if !input.is_empty() {
        let wrapped: Parenthesized<Args> = parse2(input)?;
        return Ok(wrapped.content);
    }
    Ok(Args::new())
}

/// Comma-separated list of [`Arg`]s; a trailing comma is accepted.
#[derive(Parse)]
pub struct Args(#[parse(terminated)] Punctuated<Arg, Comma>);

impl Args {
    /// An empty argument list.
    fn new() -> Self {
        Self(Punctuated::new())
    }
    /// Returns the sole argument, which must be an unnamed expression.
    ///
    /// # Errors
    /// Fails (reported at `span`) if there is not exactly one argument, or if that
    /// argument is of the `name = value` form.
    pub fn expect_single_value(&self, span: Span) -> Result<&Expr> {
        let supplied = self.len();
        if supplied != 1 {
            bail!(
                span,
                "expect 1 arguments, but supplied {} arguments.",
                supplied
            );
        }
        if let Arg::Value(expr) = &self[0] {
            Ok(expr)
        } else {
            bail!(span, "expected unnamed argument.")
        }
    }
}
impl Deref for Args {
    type Target = Punctuated<Arg, Comma>;

    fn deref(&self) -> &Self::Target {
        let Self(list) = self;
        list
    }
}
impl IntoIterator for Args {
    type Item = Arg;
    type IntoIter = <Punctuated<Arg, Comma> as IntoIterator>::IntoIter;

    fn into_iter(self) -> Self::IntoIter {
        let Self(list) = self;
        list.into_iter()
    }
}

/// A single macro argument: either `name = value` or a bare expression.
///
/// Parsing is derived by structmeta. `NameValue` is declared first and its `name` and
/// `eq_token` fields carry `#[parse(peek)]`, so presumably it is chosen only when the
/// input starts with an identifier followed by `=`; otherwise `Value` is parsed.
/// (`any` on `name` presumably allows keywords as names — confirm against structmeta docs.)
#[derive(ToTokens, Parse)]
pub enum Arg {
    // `name = value`
    NameValue {
        #[parse(peek, any)]
        name: Ident,
        #[parse(peek)]
        eq_token: Token![=],
        value: Expr,
    },
    // A positional (unnamed) expression.
    Value(Expr),
}

/// Rewrites `#field` / `#0` / `#self` references inside a token stream and records
/// which of them were used.
pub struct SharpVals {
    /// Whether `#field` (named or positional) references are permitted.
    allow_vals: bool,
    /// Whether `#self` is permitted.
    allow_self: bool,
    /// Span of the first occurrence of each referenced field.
    pub vals: BTreeMap<FieldKey, Span>,
    /// Span of the first `#self`, if one was seen.
    pub self_span: Option<Span>,
}
impl SharpVals {
    /// Creates an empty tracker. `allow_vals` and `allow_self` control whether
    /// `#field` and `#self` are accepted by [`SharpVals::expand`].
    pub fn new(allow_vals: bool, allow_self: bool) -> Self {
        Self {
            allow_vals,
            allow_self,
            vals: BTreeMap::new(),
            self_span: None,
        }
    }

    /// Recursively rewrites `input`, replacing each `#key` with `key`'s dummy identifier
    /// (see [`FieldKey::to_dummy_ident`]) spanned at `key`. Here `key` is an identifier or
    /// an unsuffixed integer literal.
    ///
    /// The first span seen for each key is recorded in `vals` (or in `self_span` for
    /// `#self`). This state persists across calls on the same value. Groups are rebuilt
    /// with their original delimiter and span.
    ///
    /// # Errors
    /// - `#field` when `allow_vals` is false, or `#self` when `allow_self` is false.
    /// - Any mix of `#self` and `#field` references.
    pub fn expand(&mut self, input: TokenStream) -> Result<TokenStream> {
        let mut tokens = Vec::new();
        let mut iter = input.into_iter().peekable();
        while let Some(t) = iter.next() {
            match &t {
                TokenTree::Group(g) => {
                    let mut group = Group::new(g.delimiter(), self.expand(g.stream())?);
                    // `Group::new` assigns a call-site span; restore the original delimiter
                    // span so diagnostics keep pointing at the user's code.
                    group.set_span(g.span());
                    tokens.push(TokenTree::Group(group));
                    continue;
                }
                // A lone `#` followed by a field key is a sharp-value reference.
                TokenTree::Punct(p) if p.as_char() == '#' && p.spacing() == Spacing::Alone => {
                    if let Some(token) = iter.peek() {
                        if let Some(key) = FieldKey::try_from_token(token) {
                            let span = token.span();
                            let allow = if &key == "self" {
                                self.self_span.get_or_insert(span);
                                self.allow_self
                            } else {
                                self.vals.entry(key.clone()).or_insert(span);
                                self.allow_vals
                            };
                            if !allow {
                                bail!(span, "cannot use `#{}` in this position.", key);
                            }
                            if self.self_span.is_some() {
                                if let Some(key) = self.vals.keys().next() {
                                    bail!(span, "cannot use both `#self` and `#{}`", key);
                                }
                            }
                            let mut ident = key.to_dummy_ident();
                            ident.set_span(span);
                            tokens.extend(ident.to_token_stream());
                            // Consume the key token that was only peeked at above.
                            iter.next();
                            continue;
                        }
                    }
                }
                _ => {}
            }
            tokens.extend(once(t));
        }
        Ok(tokens.into_iter().collect())
    }
}
/// Identifies a field by name (`Named`) or by tuple position (`Unnamed`).
///
/// Named keys store the identifier with any `r#` prefix removed. The derived `Ord`
/// follows variant declaration order, so every `Named` key sorts before any `Unnamed` key.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
pub enum FieldKey {
    Named(String),
    Unnamed(usize),
}

impl FieldKey {
    /// Builds a `Named` key from `ident`, stripping any raw-identifier `r#` prefix.
    pub fn from_ident(ident: &Ident) -> Self {
        let name = ident.unraw().to_string();
        Self::Named(name)
    }
    /// Key for the field at position `idx`. Named fields use their identifier;
    /// tuple fields use `idx`.
    pub fn from_field(idx: usize, field: &Field) -> Self {
        match &field.ident {
            Some(ident) => Self::from_ident(ident),
            None => Self::Unnamed(idx),
        }
    }
    /// Interprets a token (typically the one following `#`) as a field key.
    ///
    /// Identifiers become `Named`. Unsuffixed integer literals that parse as `usize`
    /// become `Unnamed`. Anything else yields `None`.
    pub fn try_from_token(token: &TokenTree) -> Option<Self> {
        match token {
            TokenTree::Ident(ident) => Some(Self::from_ident(ident)),
            TokenTree::Literal(literal) => match Lit::new(literal.clone()) {
                Lit::Int(int) if int.suffix().is_empty() => {
                    int.base10_parse::<usize>().ok().map(Self::Unnamed)
                }
                _ => None,
            },
            _ => None,
        }
    }

    /// Placeholder identifier for this key: its display form prefixed with `_`
    /// (e.g. `_foo`, `_0`), spanned at the call site.
    pub fn to_dummy_ident(&self) -> Ident {
        let text = format!("_{self}");
        Ident::new(&text, Span::call_site())
    }
    /// The field name as an identifier, in raw form if needed. Returns `None` for
    /// positional keys, or when the name cannot be made into an identifier.
    pub fn to_valid_ident(&self) -> Option<Ident> {
        if let Self::Named(name) = self {
            to_valid_ident(name).ok()
        } else {
            None
        }
    }
}
impl PartialEq<str> for FieldKey {
    /// A key equals a string only if it is `Named` with exactly that name.
    fn eq(&self, other: &str) -> bool {
        matches!(self, Self::Named(name) if name == other)
    }
}
impl std::fmt::Display for FieldKey {
    /// Formats a named key as its name and a positional key as its index.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Named(name) => std::fmt::Display::fmt(name, f),
            Self::Unnamed(idx) => std::fmt::Display::fmt(idx, f),
        }
    }
}

/// Names of the type and const generic parameters of a type; lifetimes are not collected.
///
/// Identifiers are stored and looked up with any `r#` prefix stripped.
pub struct GenericParamSet {
    idents: HashSet<Ident>,
}

impl GenericParamSet {
    /// Collects the type and const parameter names declared in `generics`.
    pub fn new(generics: &Generics) -> Self {
        let idents = generics
            .params
            .iter()
            .filter_map(|param| match param {
                GenericParam::Type(ty) => Some(ty.ident.unraw()),
                GenericParam::Const(konst) => Some(konst.ident.unraw()),
                _ => None,
            })
            .collect();
        Self { idents }
    }
    fn contains(&self, ident: &Ident) -> bool {
        let key = ident.unraw();
        self.idents.contains(&key)
    }

    /// Returns `true` if `ty` mentions any collected parameter.
    ///
    /// A mention is any path without a leading `::` whose first segment is a parameter
    /// name, anywhere inside `ty` (e.g. `T`, `T::Item`, or the `T` in `Vec<T>`).
    pub fn contains_in_type(&self, ty: &Type) -> bool {
        struct Finder<'s> {
            params: &'s GenericParamSet,
            found: bool,
        }
        impl<'s, 'ast> Visit<'ast> for Finder<'s> {
            fn visit_path(&mut self, path: &'ast syn::Path) {
                let rooted = path.leading_colon.is_some();
                if !rooted {
                    if let Some(first) = path.segments.first() {
                        if self.params.contains(&first.ident) {
                            self.found = true;
                        }
                    }
                }
                // Keep walking so generic arguments such as `Vec<T>` are inspected too.
                visit_path(self, path);
            }
        }
        let mut finder = Finder {
            params: self,
            found: false,
        };
        visit_type(&mut finder, ty);
        finder.found
    }
}

pub fn impl_trait(
    input: &DeriveInput,
    trait_path: &Path,
    wheres: &[WherePredicate],
    contents: TokenStream,
) -> TokenStream {
    let ty = &input.ident;
    let (impl_g, ty_g, where_clause) = input.generics.split_for_impl();
    let mut wheres = wheres.to_vec();
    if let Some(where_clause) = where_clause {
        wheres.extend(where_clause.predicates.iter().cloned());
    }
    let where_clause = if wheres.is_empty() {
        quote! {}
    } else {
        quote! { where #(#wheres,)*}
    };
    quote! {
        #[automatically_derived]
        impl #impl_g #trait_path for #ty #ty_g #where_clause {
            #contents
        }
    }
}
/// Like [`impl_trait`], but returns the generated impl wrapped in `Ok`.
///
/// # Panics
/// When `dump` is `true`, panics with the generated tokens in the message instead of
/// returning them (presumably a debugging aid for inspecting macro output).
pub fn impl_trait_result(
    input: &DeriveInput,
    trait_path: &Path,
    wheres: &[WherePredicate],
    contents: TokenStream,
    dump: bool,
) -> Result<TokenStream> {
    let ts = impl_trait(input, trait_path, wheres, contents);
    if !dump {
        return Ok(ts);
    }
    panic!("macro result: \n{ts}")
}

/// Converts `s` into an [`Ident`]. Falls back to the raw form `r#s` when `s` is not
/// accepted as a plain identifier (for example, a keyword such as `type`).
///
/// # Errors
/// Returns the parse error for `r#s` if neither form parses.
pub fn to_valid_ident(s: &str) -> Result<Ident> {
    parse_str::<Ident>(s).or_else(|_| parse_str::<Ident>(&format!("r#{s}")))
}

/// Parses the arguments of the single attribute named `name` in `attrs`.
///
/// The tokens inside the attribute's delimiters are parsed as `T` via
/// `Attribute::parse_args`. If no such attribute is present, `T::default()` is returned.
///
/// # Errors
/// - The attribute appears more than once (reported at the second occurrence).
/// - The attribute's arguments fail to parse as `T`.
pub fn parse_from_attrs<T: Parse + Default>(attrs: &[Attribute], name: &str) -> Result<T> {
    let mut found: Option<&Attribute> = None;
    for attr in attrs.iter().filter(|attr| attr.path.is_ident(name)) {
        if found.is_some() {
            bail!(
                attr.span(),
                "attribute `{}` can be specified only once",
                name
            );
        }
        found = Some(attr);
    }
    match found {
        Some(attr) => attr.parse_args(),
        None => Ok(T::default()),
    }
}