use proc_macro2::{Span, TokenStream};
use syn::spanned::Spanned;
use syn::{DeriveInput, Expr, Field, Ident, Path, Type, Variant};
use crate::ast::*;
use crate::attr::{self, ParamsMode, ParsedAttributes, StratMode};
use crate::error::{self, Context, Ctx, DeriveResult};
use crate::use_tracking::{UseMarkable, UseTracker};
use crate::util::{fields_to_vec, is_unit_type, self_ty};
use crate::void::IsUninhabited;
/// Entry point of `#[derive(Arbitrary)]`.
///
/// Runs the derivation and then checks the error context. If any
/// diagnostics were recorded, their tokens are returned. Otherwise the
/// derived implementation is returned.
///
/// # Panics
///
/// Panics when derivation failed but no error was recorded in the context.
/// That combination is an internal bug.
pub fn impl_proptest_arbitrary(ast: DeriveInput) -> TokenStream {
    let mut context = Context::default();
    let outcome = derive_proptest_arbitrary(&mut context, ast);
    let checked = context.check();
    match (outcome, checked) {
        // Recorded diagnostics win over whatever derivation produced.
        (_, Err(reported)) => reported,
        (Ok(tokens), Ok(())) => tokens,
        (Err(failure), Ok(())) => panic!(
            "[proptest_derive]: internal error, this is a bug! \
             result: {:?}",
            failure
        ),
    }
}
/// Everything needed to derive `Arbitrary` for one type.
///
/// `B` is the body representation. `derive_proptest_arbitrary` builds it as
/// `Vec<Field>` for structs and `Vec<Variant>` for enums.
struct DeriveData<B> {
    /// Name of the type being derived for.
    ident: Ident,
    /// Parsed top-level `#[proptest(..)]` attributes of the type.
    attrs: ParsedAttributes,
    /// Records uses of the type's generic parameters (see `mark_uses`).
    /// Tracking is disabled when `no_bound` is set.
    tracker: UseTracker,
    /// The struct's fields or the enum's variants.
    body: B,
}
/// Dispatches on the shape of `ast` (struct or enum), derives an `Impl`
/// for it and lowers that `Impl` into tokens.
///
/// Lifetimes on the type are reported through `error::if_has_lifetimes`.
/// Any data kind other than a struct or enum is rejected via
/// `error::not_struct_or_enum`.
fn derive_proptest_arbitrary(
    ctx: Ctx,
    ast: DeriveInput,
) -> DeriveResult<TokenStream> {
    error::if_has_lifetimes(ctx, &ast);

    let attrs = attr::parse_top_attributes(ctx, &ast.attrs)?;

    let mut tracker = UseTracker::new(ast.generics);
    if attrs.no_bound {
        // `no_bound` requested: don't record generic parameter uses.
        tracker.no_track();
    }

    let ident = ast.ident;
    let derived = match ast.data {
        syn::Data::Struct(data) => {
            let body = fields_to_vec(data.fields);
            derive_struct(
                ctx,
                DeriveData {
                    ident,
                    attrs,
                    tracker,
                    body,
                },
            )
        }
        syn::Data::Enum(data) => {
            let body = data.variants.into_iter().collect();
            derive_enum(
                ctx,
                DeriveData {
                    ident,
                    attrs,
                    tracker,
                    body,
                },
            )
        }
        _ => error::not_struct_or_enum(ctx)?,
    };

    let tokens = derived?.into_tokens(ctx)?;
    Ok(tokens)
}
fn derive_struct(
ctx: Ctx,
mut ast: DeriveData<Vec<Field>>,
) -> DeriveResult<Impl> {
error::if_enum_attrs_present(ctx, &ast.attrs, error::STRUCT);
error::if_strategy_present(ctx, &ast.attrs, error::STRUCT);
let v_path = ast.ident.clone().into();
let parts = if ast.body.is_empty() {
error::if_present_on_unit_struct(ctx, &ast.attrs);
let (strat, ctor) = pair_unit_self(&v_path);
(Params::empty(), strat, ctor)
} else {
if (&*ast.body).is_uninhabited() {
error::uninhabited_struct(ctx);
}
let closure = map_closure(v_path, &ast.body);
let parts = if let Some(param_ty) = ast.attrs.params.into_option() {
add_top_params(
param_ty,
derive_product_has_params(
ctx,
&mut ast.tracker,
error::STRUCT_FIELD,
closure,
ast.body,
)?,
)
} else {
derive_product_no_params(
ctx,
&mut ast.tracker,
ast.body,
error::STRUCT_FIELD,
)?
.finish(closure)
};
add_top_filter(ast.attrs.filter, parts)
};
Ok(Impl::new(ast.ident, ast.tracker, parts))
}
/// Applies the top-level `filter` expressions to the strategy in `parts`.
/// The parameter types pass through unchanged.
fn add_top_filter(filter: Vec<Expr>, parts: ImplParts) -> ImplParts {
    let filtered = add_filter_self(filter, (parts.1, parts.2));
    (parts.0, filtered.0, filtered.1)
}
/// Filters a strategy pair for `Self` with the given filter expressions.
fn add_filter_self(filter: Vec<Expr>, pair: StratPair) -> StratPair {
    let self_type = self_ty();
    pair_filter(filter, self_type, pair)
}
/// Attaches the top-level parameter type, if any, to a strategy pair.
///
/// - With `Some(ty)`, `ty` becomes the only parameter type, and the
///   constructor is rewritten by `extract_api` to read from the top-level
///   parameter.
/// - With `None`, the pair is returned unchanged with an empty parameter
///   list.
fn add_top_params(
    param_ty: Option<Type>,
    (strat, ctor): StratPair,
) -> ImplParts {
    match param_ty {
        None => (Params::empty(), strat, ctor),
        Some(params_ty) => {
            let params = Params::empty() + params_ty;
            (params, strat, extract_api(ctor, FromReg::Top))
        }
    }
}
/// Derives the combined strategy for a product (a struct's or a variant's
/// fields) when the enclosing item has set `params` itself.
///
/// - Fields may not declare enum-only attributes or their own `params`.
///   `item` names the container kind in those diagnostics.
/// - Each field's strategy comes from `product_handle_default_params`,
///   with its `filter` applied.
/// - All strategies are then combined via `closure`.
fn derive_product_has_params(
    ctx: Ctx,
    ut: &mut UseTracker,
    item: &str,
    closure: MapClosure,
    fields: Vec<Field>,
) -> DeriveResult<StratPair> {
    let mut strats = StratAcc::new(fields.len());
    for field in fields {
        let attrs = attr::parse_attributes(ctx, &field.attrs)?;
        error::if_enum_attrs_present(ctx, &attrs, item);
        error::if_specified_params(ctx, &attrs, item);

        let span = field.span();
        let base = product_handle_default_params(
            ut,
            field.ty.clone(),
            span,
            attrs.strategy,
        );
        strats = strats.add(pair_filter(attrs.filter, field.ty, base));
    }
    Ok(strats.finish(closure))
}
/// Chooses the strategy pair for one field of type `ty` without per-field
/// parameters.
///
/// - An explicit strategy, value or regex is used as given.
/// - Otherwise the type is recorded in the use tracker and generated
///   through `pair_any`.
fn product_handle_default_params(
    ut: &mut UseTracker,
    ty: Type,
    span: Span,
    strategy: StratMode,
) -> StratPair {
    match strategy {
        StratMode::Arbitrary => {
            // Record the type's uses of generic parameters before relying
            // on its `Arbitrary` impl.
            ty.mark_uses(ut);
            pair_any(ty, span)
        }
        StratMode::Regex(pattern) => pair_regex(ty, pattern),
        StratMode::Value(expr) => pair_value(ty, expr),
        StratMode::Strategy(expr) => pair_existential(ty, expr),
    }
}
/// Derives the per-field strategies for a product (a struct's or a
/// variant's fields) when the enclosing item has NOT set `params` itself.
///
/// Each field is handled according to its `params` mode:
/// - `Passthrough`: an explicit strategy/value/regex is used as given.
///   Otherwise the field type is marked as used, its `Arbitrary` parameter
///   type is added to the accumulator, and the field is generated with
///   `pair_any_with`.
/// - `Default`: delegated to `product_handle_default_params`.
/// - `Specified(ty)`: `ty` is registered as a numbered parameter through
///   `extract_nparam`. An explicit strategy or value is expected:
///   - combining `params` with `regex` is reported via
///     `error::cant_set_param_and_regex`;
///   - leaving the strategy unset goes through
///     `error::cant_set_param_but_not_strat`, whose result is propagated
///     with `?`.
///
/// Every field's `filter` is applied to its strategy.
///
/// `item` names the kind of container the fields belong to; it is used in
/// diagnostics.
///
/// Returns the accumulator of parameter types and per-field strategies.
/// Callers finish it into a single mapped strategy.
fn derive_product_no_params(
    ctx: Ctx,
    ut: &mut UseTracker,
    fields: Vec<Field>,
    item: &str,
) -> DeriveResult<PartsAcc<Ctor>> {
    let acc = PartsAcc::new(fields.len());
    fields.into_iter().try_fold(acc, |mut acc, field| {
        let attrs = attr::parse_attributes(ctx, &field.attrs)?;
        // Attributes that only make sense on enum variants are rejected.
        error::if_enum_attrs_present(ctx, &attrs, item);
        let span = field.span();
        let ty = field.ty;
        // `ty` is cloned up front because the match below moves it.
        let strat = pair_filter(
            attrs.filter,
            ty.clone(),
            match attrs.params {
                ParamsMode::Passthrough => match attrs.strategy {
                    StratMode::Strategy(strat) => pair_existential(ty, strat),
                    StratMode::Value(value) => pair_value(ty, value),
                    StratMode::Regex(regex) => pair_regex(ty, regex),
                    StratMode::Arbitrary => {
                        // Use the field's `Arbitrary` impl and thread its
                        // parameter type through a new numbered slot.
                        ty.mark_uses(ut);
                        let pref = acc.add_param(arbitrary_param(&ty));
                        pair_any_with(ty, pref, span)
                    }
                },
                ParamsMode::Default => {
                    product_handle_default_params(ut, ty, span, attrs.strategy)
                }
                ParamsMode::Specified(params_ty) =>
                {
                    // The user-given parameter type becomes a numbered slot
                    // that the constructor reads from.
                    extract_nparam(
                        &mut acc,
                        params_ty,
                        match attrs.strategy {
                            StratMode::Strategy(strat) => {
                                pair_existential(ty, strat)
                            }
                            StratMode::Value(value) => {
                                pair_value_exist(ty, value)
                            }
                            StratMode::Regex(regex) => {
                                error::cant_set_param_and_regex(ctx, item);
                                pair_regex(ty, regex)
                            }
                            StratMode::Arbitrary => {
                                error::cant_set_param_but_not_strat(
                                    ctx, &ty, item,
                                )?
                            }
                        },
                    )
                }
            },
        );
        Ok(acc.add_strat(strat))
    })
}
/// Registers `params_ty` as a new numbered parameter in `acc`. The
/// constructor is rewritten with `extract_api` to read from that slot.
fn extract_nparam<C>(
    acc: &mut PartsAcc<C>,
    params_ty: Type,
    (strat, ctor): StratPair,
) -> StratPair {
    let slot = acc.add_param(params_ty);
    let ctor = extract_api(ctor, FromReg::Num(slot));
    (strat, ctor)
}
fn derive_enum(
ctx: Ctx,
mut ast: DeriveData<Vec<Variant>>,
) -> DeriveResult<Impl> {
error::if_skip_present(ctx, &ast.attrs, error::ENUM);
error::if_strategy_present(ctx, &ast.attrs, error::ENUM);
error::if_weight_present(ctx, &ast.attrs, error::ENUM);
if ast.body.is_empty() {
error::uninhabited_enum_with_no_variants(ctx)?;
}
if (&*ast.body).is_uninhabited() {
error::uninhabited_enum_variants_uninhabited(ctx)?;
}
let parts = if let Some(sty) = ast.attrs.params.into_option() {
derive_enum_has_params(ctx, &mut ast.tracker, &ast.ident, ast.body, sty)
} else {
derive_enum_no_params(ctx, &mut ast.tracker, &ast.ident, ast.body)
}?;
let parts = add_top_filter(ast.attrs.filter, parts);
Ok(Impl::new(ast.ident, ast.tracker, parts))
}
/// Derives a weighted union over the kept (non-skipped, inhabited) variants
/// when the enum itself has not set `params`.
///
/// Parameter types needed by variants are collected into a shared
/// accumulator. Each variant's strategy is stored together with its weight.
fn derive_enum_no_params(
    ctx: Ctx,
    ut: &mut UseTracker,
    _self: &Ident,
    variants: Vec<Variant>,
) -> DeriveResult<ImplParts> {
    let mut acc = PartsAcc::new(variants.len());
    for variant in variants {
        let (weight, ident, fields, attrs) =
            match keep_inhabited_variant(ctx, _self, variant)? {
                Some(kept) => kept,
                // Skipped or uninhabited: contributes no strategy.
                None => continue,
            };
        let path = parse_quote!( #_self::#ident );
        let pair = if fields.is_empty() {
            pair_unit_variant(ctx, &attrs, path)
        } else {
            derive_variant_with_fields(ctx, ut, path, attrs, fields, &mut acc)?
        };
        let (strat, ctor) = pair;
        acc = acc.add_strat((strat, (weight, ctor)));
    }
    ensure_union_has_strategies(ctx, &acc.strats);
    Ok(acc.finish(ctx))
}
/// Reports an error when no variant contributed a strategy to the union.
fn ensure_union_has_strategies<C>(ctx: Ctx, strats: &StratAcc<C>) {
    if !strats.is_empty() {
        return;
    }
    error::uninhabited_enum_because_of_skipped_variants(ctx);
}
/// Derives the strategy pair for an enum variant that has fields. In this
/// file it is called from `derive_enum_no_params`, where the enum itself
/// has not set `params`.
///
/// Dispatches on the variant's `params` mode and strategy attribute:
/// - `Passthrough`:
///   - An explicit strategy/value/regex is used for `Self` directly, and
///     any attribute on the fields is reported.
///   - Otherwise the fields are derived as a product by
///     `variant_no_explicit_strategy`, which may add a parameter to `acc`.
/// - `Default`: delegated to `variant_handle_default_params`.
/// - `Specified(ty)`: `ty` is registered in `acc` through `extract_nparam`.
///   - An explicit strategy or value is expected.
///   - `regex` is reported via `error::cant_set_param_and_regex`.
///   - A missing strategy goes through `error::cant_set_param_but_not_strat`.
///
/// The variant's `filter` is applied to the resulting `Self` strategy.
fn derive_variant_with_fields<C>(
    ctx: Ctx,
    ut: &mut UseTracker,
    v_path: Path,
    attrs: ParsedAttributes,
    fields: Vec<Field>,
    acc: &mut PartsAcc<C>,
) -> DeriveResult<StratPair> {
    // Cloned because `attrs` is moved into `variant_handle_default_params`
    // in the `Default` arm.
    let filter = attrs.filter.clone();
    let pair = match attrs.params {
        ParamsMode::Passthrough => match attrs.strategy {
            StratMode::Strategy(strat) => {
                deny_all_attrs_on_fields(ctx, fields)?;
                pair_existential_self(strat)
            }
            StratMode::Value(value) => {
                deny_all_attrs_on_fields(ctx, fields)?;
                pair_value_self(value)
            }
            StratMode::Regex(regex) => {
                deny_all_attrs_on_fields(ctx, fields)?;
                pair_regex_self(regex)
            }
            StratMode::Arbitrary => {
                variant_no_explicit_strategy(ctx, ut, v_path, fields, acc)?
            }
        },
        ParamsMode::Default => {
            variant_handle_default_params(ctx, ut, v_path, attrs, fields)?
        }
        ParamsMode::Specified(params_ty) => extract_nparam(
            acc,
            params_ty,
            match attrs.strategy {
                StratMode::Strategy(strat) => {
                    deny_all_attrs_on_fields(ctx, fields)?;
                    pair_existential_self(strat)
                }
                StratMode::Value(value) => {
                    deny_all_attrs_on_fields(ctx, fields)?;
                    pair_value_exist_self(value)
                }
                StratMode::Regex(regex) => {
                    error::cant_set_param_and_regex(ctx, error::ENUM_VARIANT);
                    deny_all_attrs_on_fields(ctx, fields)?;
                    pair_regex_self(regex)
                }
                StratMode::Arbitrary => {
                    let ty = self_ty();
                    error::cant_set_param_but_not_strat(
                        ctx,
                        &ty,
                        error::ENUM_VARIANT,
                    )?
                }
            },
        ),
    };
    let pair = add_filter_self(filter, pair);
    Ok(pair)
}
fn variant_no_explicit_strategy<C>(
ctx: Ctx,
ut: &mut UseTracker,
v_path: Path,
fields: Vec<Field>,
acc: &mut PartsAcc<C>,
) -> DeriveResult<StratPair> {
let closure = map_closure(v_path, &fields);
let fields_acc =
derive_product_no_params(ctx, ut, fields, error::ENUM_VARIANT_FIELD)?;
let (params, count) = fields_acc.params.consume();
let (strat, ctor) = fields_acc.strats.finish(closure);
let params_ty = params.into();
Ok((
strat,
if is_unit_type(¶ms_ty) {
ctor
} else {
let pref = acc.add_param(params_ty);
extract_all(ctor, count, FromReg::Num(pref))
},
))
}
/// Derives the strategy pair for a variant with fields when no per-variant
/// parameters are in play.
///
/// - An explicit strategy, value or regex is used for `Self`, and any
///   attributes on the fields are reported.
/// - Otherwise the fields are derived by `derive_product_has_params`.
fn variant_handle_default_params(
    ctx: Ctx,
    ut: &mut UseTracker,
    v_path: Path,
    attrs: ParsedAttributes,
    fields: Vec<Field>,
) -> DeriveResult<StratPair> {
    match attrs.strategy {
        StratMode::Arbitrary => {
            let closure = map_closure(v_path, &fields);
            derive_product_has_params(
                ctx,
                ut,
                error::ENUM_VARIANT_FIELD,
                closure,
                fields,
            )
        }
        StratMode::Regex(regex) => {
            deny_all_attrs_on_fields(ctx, fields)?;
            Ok(pair_regex_self(regex))
        }
        StratMode::Value(value) => {
            deny_all_attrs_on_fields(ctx, fields)?;
            Ok(pair_value_self(value))
        }
        StratMode::Strategy(strat) => {
            deny_all_attrs_on_fields(ctx, fields)?;
            Ok(pair_existential_self(strat))
        }
    }
}
fn deny_all_attrs_on_fields(ctx: Ctx, fields: Vec<Field>) -> DeriveResult<()> {
fields.into_iter().try_for_each(|field| {
let f_attr = attr::parse_attributes(ctx, &field.attrs)?;
error::if_anything_specified(ctx, &f_attr, error::ENUM_VARIANT_FIELD);
Ok(())
})
}
fn derive_enum_has_params(
ctx: Ctx,
ut: &mut UseTracker,
_self: &Ident,
variants: Vec<Variant>,
sty: Option<Type>,
) -> DeriveResult<ImplParts> {
let mut acc = StratAcc::new(variants.len());
for variant in variants {
let parts = keep_inhabited_variant(ctx, _self, variant)?;
if let Some((weight, ident, fields, attrs)) = parts {
let path = parse_quote!( #_self::#ident );
let (strat, ctor) = if fields.is_empty() {
pair_unit_variant(ctx, &attrs, path)
} else {
let filter = attrs.filter.clone();
add_filter_self(
filter,
variant_handle_default_params(
ctx, ut, path, attrs, fields,
)?,
)
};
acc = acc.add((strat, (weight, ctor)));
}
}
ensure_union_has_strategies(ctx, &acc);
Ok(add_top_params(sty, acc.finish(ctx)))
}
fn keep_inhabited_variant(
ctx: Ctx,
_self: &Ident,
variant: Variant,
) -> DeriveResult<Option<(u32, Ident, Vec<Field>, ParsedAttributes)>> {
let attrs = attr::parse_attributes(ctx, &variant.attrs)?;
let fields = fields_to_vec(variant.fields);
if attrs.skip {
ensure_has_only_skip_attr(ctx, &attrs, error::ENUM_VARIANT);
fields.into_iter().try_for_each(|field| {
let f_attrs = attr::parse_attributes(ctx, &field.attrs)?;
error::if_skip_present(ctx, &f_attrs, error::ENUM_VARIANT_FIELD);
ensure_has_only_skip_attr(ctx, &f_attrs, error::ENUM_VARIANT_FIELD);
Ok(())
})?;
return Ok(None);
}
if (&*fields).is_uninhabited() {
return Ok(None);
}
let weight = attrs.weight.unwrap_or(1);
Ok(Some((weight, variant.ident, fields, attrs)))
}
/// Reports `params`, strategy, `weight` and `filter` attributes on a
/// skipped item, in that order.
fn ensure_has_only_skip_attr(ctx: Ctx, attrs: &ParsedAttributes, item: &str) {
    let has_params = attrs.params.is_set();
    let has_strategy = attrs.strategy.is_set();
    let has_weight = attrs.weight.is_some();
    let has_filter = !attrs.filter.is_empty();

    if has_params {
        error::skipped_variant_has_param(ctx, item);
    }
    if has_strategy {
        error::skipped_variant_has_strat(ctx, item);
    }
    if has_weight {
        error::skipped_variant_has_weight(ctx, item);
    }
    if has_filter {
        error::skipped_variant_has_filter(ctx, item);
    }
}
/// Strategy pair for a fieldless variant, built by `pair_unit_self` from
/// `v_path`.
///
/// Any attribute that `error::if_present_on_unit_variant` disallows on unit
/// variants is reported first.
fn pair_unit_variant(
    ctx: Ctx,
    attrs: &ParsedAttributes,
    v_path: Path,
) -> StratPair {
    error::if_present_on_unit_variant(ctx, attrs);
    let unit_path = &v_path;
    pair_unit_self(unit_path)
}
/// Accumulator for parameter types together with per-item strategies and
/// their constructors (`C`).
struct PartsAcc<C> {
    /// Parameter types collected so far.
    params: ParamAcc,
    /// Strategies and constructors collected so far.
    strats: StratAcc<C>,
}

impl<C> PartsAcc<C> {
    /// Creates an empty accumulator with room for `size` strategies.
    fn new(size: usize) -> Self {
        PartsAcc {
            params: ParamAcc::empty(),
            strats: StratAcc::new(size),
        }
    }

    /// Appends one strategy together with its constructor.
    fn add_strat(self, pair: (Strategy, C)) -> Self {
        let PartsAcc { params, strats } = self;
        PartsAcc {
            strats: strats.add(pair),
            params,
        }
    }

    /// Appends a parameter type and returns the index it was stored at.
    fn add_param(&mut self, ty: Type) -> usize {
        self.params.add(ty)
    }
}

impl PartsAcc<Ctor> {
    /// Finishes a product. The strategies are mapped through `closure`, and
    /// the constructor is rewritten by `extract_all` to pull all collected
    /// parameters from the top level.
    fn finish(self, closure: MapClosure) -> ImplParts {
        let PartsAcc { params, strats } = self;
        let (params, count) = params.consume();
        let (strat, ctor) = strats.finish(closure);
        (params, strat, extract_all(ctor, count, FromReg::Top))
    }
}

impl PartsAcc<(u32, Ctor)> {
    /// Finishes an enum. The weighted strategies become a single union, and
    /// the constructor is rewritten by `extract_all` to pull all collected
    /// parameters from the top level.
    fn finish(self, ctx: Ctx) -> ImplParts {
        let PartsAcc { params, strats } = self;
        let (params, count) = params.consume();
        let (strat, ctor) = strats.finish(ctx);
        (params, strat, extract_all(ctor, count, FromReg::Top))
    }
}
/// Ordered collection of parameter types. Each added type is identified by
/// the index it was inserted at.
struct ParamAcc {
    types: Params,
}

impl ParamAcc {
    /// Creates an accumulator holding no parameter types.
    fn empty() -> Self {
        ParamAcc {
            types: Params::empty(),
        }
    }

    /// Appends `ty` and returns its index, which is the length before the
    /// insertion.
    fn add(&mut self, ty: Type) -> usize {
        let index = self.types.len();
        self.types += ty;
        index
    }

    /// Returns the collected parameters together with their length.
    fn consume(self) -> (Params, usize) {
        let ParamAcc { types } = self;
        let count = types.len();
        (types, count)
    }
}
struct StratAcc<C> {
types: Vec<Strategy>,
ctors: Vec<C>,
}
impl<C> StratAcc<C> {
fn new(size: usize) -> Self {
Self {
types: Vec::with_capacity(size),
ctors: Vec::with_capacity(size),
}
}
fn add(mut self, (strat, ctor): (Strategy, C)) -> Self {
self.types.push(strat);
self.ctors.push(ctor);
self
}
fn consume(self) -> (Vec<Strategy>, Vec<C>) {
(self.types, self.ctors)
}
fn is_empty(&self) -> bool {
self.types.is_empty()
}
}
impl StratAcc<Ctor> {
fn finish(self, closure: MapClosure) -> StratPair {
pair_map(self.consume(), closure)
}
}
impl StratAcc<(u32, Ctor)> {
fn finish(self, ctx: Ctx) -> StratPair {
if self
.ctors
.iter()
.map(|&(w, _)| w)
.try_fold(0u32, |acc, w| acc.checked_add(w))
.is_none()
{
error::weight_overflowing(ctx)
}
pair_oneof(self.consume())
}
}