use proc_macro2::{TokenStream, Span};
use syn::{Type, Path, Expr, Field, Ident, Variant, DeriveInput};
use syn::spanned::Spanned;
use crate::util::{is_unit_type, self_ty, fields_to_vec};
use crate::void::IsUninhabited;
use crate::error::{self, Ctx, Context, DeriveResult};
use crate::attr::{self, ParamsMode, ParsedAttributes, StratMode};
use crate::use_tracking::{UseMarkable, UseTracker};
use crate::ast::*;
/// Entry point of the derive: expands `#[derive(Arbitrary)]` for `ast`.
///
/// Errors recorded in the derive context take priority over the generated
/// tokens. A failed derivation that left no recorded errors in the context
/// is an internal bug and panics.
pub fn impl_proptest_arbitrary(ast: DeriveInput) -> TokenStream {
    let mut ctx = Context::default();
    let derived = derive_proptest_arbitrary(&mut ctx, ast);
    let checked = ctx.check();
    match (derived, checked) {
        (_, Err(errors)) => errors,
        (Ok(tokens), Ok(())) => tokens,
        (Err(failure), Ok(())) => panic!(
            "[proptest_derive]: internal error, this is a bug! \
             result: {:?}",
            failure
        ),
    }
}
/// The parts of a `DeriveInput` that the struct and enum derivations need.
/// `B` is the item body: the fields of a struct or the variants of an enum.
struct DeriveData<B> {
    /// Name of the type `Arbitrary` is derived for.
    ident: Ident,
    /// Attributes parsed from the item itself.
    attrs: ParsedAttributes,
    /// Tracks uses of the item's generic parameters (built from its generics).
    tracker: UseTracker,
    /// The item body.
    body: B,
}
fn derive_proptest_arbitrary(ctx: Ctx, ast: DeriveInput)
-> DeriveResult<TokenStream>
{
use syn::Data::*;
error::if_has_lifetimes(ctx, &ast);
let attrs = attr::parse_top_attributes(ctx, &ast.attrs)?;
let mut tracker = UseTracker::new(ast.generics);
if attrs.no_bound {
tracker.no_track();
}
let the_impl = match ast.data {
Struct(data) => derive_struct(ctx, DeriveData {
tracker, attrs, ident: ast.ident,
body: fields_to_vec(data.fields),
}),
Enum(data) => derive_enum(ctx, DeriveData {
tracker, attrs, ident: ast.ident,
body: data.variants.into_iter().collect(),
}),
_ => { error::not_struct_or_enum(ctx)? }
}?;
let q = the_impl.into_tokens(ctx)?;
Ok(q)
}
fn derive_struct(ctx: Ctx, mut ast: DeriveData<Vec<Field>>) -> DeriveResult<Impl> {
error::if_enum_attrs_present(ctx, &ast.attrs, error::STRUCT);
error::if_strategy_present(ctx, &ast.attrs, error::STRUCT);
let v_path = ast.ident.clone().into();
let parts = if ast.body.is_empty() {
error::if_present_on_unit_struct(ctx, &ast.attrs);
let (strat, ctor) = pair_unit_self(&v_path);
(Params::empty(), strat, ctor)
} else {
if (&*ast.body).is_uninhabited() {
error::uninhabited_struct(ctx);
}
let closure = map_closure(v_path, &ast.body);
let parts = if let Some(param_ty) = ast.attrs.params.into_option() {
add_top_params(param_ty,
derive_product_has_params(ctx, &mut ast.tracker,
error::STRUCT_FIELD, closure, ast.body)?)
} else {
derive_product_no_params(
ctx, &mut ast.tracker,
ast.body, error::STRUCT_FIELD
)?.finish(closure)
};
add_top_filter(ast.attrs.filter, parts)
};
Ok(Impl::new(ast.ident, ast.tracker, parts))
}
/// Applies top-level `filter` expressions to the strategy/constructor pair in
/// `parts`. The parameter types are passed through unchanged.
fn add_top_filter(filter: Vec<Expr>, (params, strat, ctor): ImplParts) -> ImplParts {
    let (filtered_strat, filtered_ctor) = add_filter_self(filter, (strat, ctor));
    (params, filtered_strat, filtered_ctor)
}
/// Filters a strategy/constructor pair whose values have type `Self`.
fn add_filter_self(filter: Vec<Expr>, pair: StratPair) -> StratPair {
    let self_type = self_ty();
    pair_filter(filter, self_type, pair)
}
/// Turns a strategy/constructor pair into impl parts.
///
/// With a top-level parameter type, that type becomes the only parameter and
/// the constructor is rewritten to extract it (`FromReg::Top`). Without one,
/// the parameter list is empty and the pair is passed through unchanged.
fn add_top_params(param_ty: Option<Type>, (strat, ctor): StratPair) -> ImplParts {
    let params = Params::empty();
    match param_ty {
        Some(params_ty) => (params + params_ty, strat, extract_api(ctor, FromReg::Top)),
        None => (params, strat, ctor),
    }
}
/// Builds the strategy for a product (struct or variant fields) whose
/// enclosing item supplies its own `params` type.
///
/// Fields may not specify enum-only attributes or their own `params`; those
/// are reported via `ctx` under the label `item`. Each field's strategy is
/// built without registering parameter types, then its `filter`s are applied.
/// The per-field strategies are combined with `closure`. Attribute parse
/// failures abort early.
fn derive_product_has_params(
    ctx: Ctx, ut: &mut UseTracker,
    item: &str, closure: MapClosure, fields: Vec<Field>)
    -> DeriveResult<StratPair>
{
    let mut strats = StratAcc::new(fields.len());
    for field in fields {
        let attrs = attr::parse_attributes(ctx, &field.attrs)?;
        error::if_enum_attrs_present(ctx, &attrs, item);
        error::if_specified_params(ctx, &attrs, item);

        let span = field.span();
        let unfiltered = product_handle_default_params(
            ut, field.ty.clone(), span, attrs.strategy);
        strats = strats.add(pair_filter(attrs.filter, field.ty, unfiltered));
    }
    Ok(strats.finish(closure))
}
/// Builds the strategy/constructor pair for a single field of type `ty`
/// without registering any parameter type.
///
/// An explicit `strategy`, `value` or `regex` is used directly. Otherwise the
/// field type's generic uses are recorded in `ut` and `any` is used at `span`.
fn product_handle_default_params(
    ut: &mut UseTracker,
    ty: Type,
    span: Span,
    strategy: StratMode,
) -> StratPair {
    match strategy {
        StratMode::Arbitrary => {
            ty.mark_uses(ut);
            pair_any(ty, span)
        }
        StratMode::Regex(regex) => pair_regex(ty, regex),
        StratMode::Value(value) => pair_value(ty, value),
        StratMode::Strategy(strat) => pair_existential(ty, strat),
    }
}
/// Builds per-field strategies for a product (struct or enum variant with
/// fields) when the enclosing item has no `params` attribute of its own.
///
/// Each field's `params` mode selects how its strategy is built:
/// - `Passthrough`: an explicit `strategy`/`value`/`regex` is used directly.
///   Otherwise the field type's uses are recorded in `ut`, its `Arbitrary`
///   parameter type is registered in the accumulator, and `pair_any_with`
///   is used.
/// - `Default`: delegates to `product_handle_default_params`.
/// - `Specified(ty)`: `ty` is registered via `extract_nparam`. Combining it
///   with `regex` is reported as an error; combining it with no explicit
///   strategy is an error that aborts the derivation (`?`).
/// Field-level `filter`s are applied to the resulting pair.
///
/// `item` is the label used in error messages for this kind of field.
/// Returns the accumulated parameter types and strategies.
fn derive_product_no_params
(ctx: Ctx, ut: &mut UseTracker, fields: Vec<Field>, item: &str)
-> DeriveResult<PartsAcc<Ctor>>
{
let acc = PartsAcc::new(fields.len());
fields.into_iter().try_fold(acc, |mut acc, field| {
let attrs = attr::parse_attributes(ctx, &field.attrs)?;
error::if_enum_attrs_present(ctx, &attrs, item);
let span = field.span();
let ty = field.ty;
// `ty.clone()` is taken for the filter before the match below consumes `ty`.
let strat = pair_filter(attrs.filter, ty.clone(), match attrs.params {
ParamsMode::Passthrough => match attrs.strategy {
StratMode::Strategy(strat) => pair_existential(ty, strat),
StratMode::Value(value) => pair_value(ty, value),
StratMode::Regex(regex) => pair_regex(ty, regex),
StratMode::Arbitrary => {
ty.mark_uses(ut);
// Register the field's Arbitrary parameter type; `pref` is its slot.
let pref = acc.add_param(arbitrary_param(&ty));
pair_any_with(ty, pref, span)
},
},
ParamsMode::Default =>
product_handle_default_params(ut, ty, span, attrs.strategy),
ParamsMode::Specified(params_ty) =>
extract_nparam(&mut acc, params_ty, match attrs.strategy {
StratMode::Strategy(strat) => pair_existential(ty, strat),
StratMode::Value(value) => pair_value_exist(ty, value),
StratMode::Regex(regex) => {
// Reported, but derivation continues with the regex strategy.
error::cant_set_param_and_regex(ctx, item);
pair_regex(ty, regex)
},
// `?` here aborts the whole fold.
StratMode::Arbitrary =>
error::cant_set_param_but_not_strat(ctx, &ty, item)?,
}),
});
Ok(acc.add_strat(strat))
})
}
/// Registers `params_ty` in `acc` and rewrites the constructor of `pair` to
/// read its parameters from the newly allocated slot.
fn extract_nparam<C>(acc: &mut PartsAcc<C>, params_ty: Type, pair: StratPair) -> StratPair {
    let (strat, ctor) = pair;
    let slot = acc.add_param(params_ty);
    (strat, extract_api(ctor, FromReg::Num(slot)))
}
/// Derives `Arbitrary` for an enum whose variants are in `ast.body`.
///
/// `skip`, `strategy` and `weight` on the enum itself are reported as
/// errors. An enum with no variants, or with only uninhabited variants,
/// aborts the derivation. The union of variant strategies is built with or
/// without a top-level `params` type, and top-level `filter`s are then
/// applied.
fn derive_enum(ctx: Ctx, mut ast: DeriveData<Vec<Variant>>) -> DeriveResult<Impl> {
    // `IsUninhabited` is already in scope through the module-level import,
    // so no local `use` is needed here.
    error::if_skip_present(ctx, &ast.attrs, error::ENUM);
    error::if_strategy_present(ctx, &ast.attrs, error::ENUM);
    error::if_weight_present(ctx, &ast.attrs, error::ENUM);

    if ast.body.is_empty() {
        error::uninhabited_enum_with_no_variants(ctx)?;
    }
    if (&*ast.body).is_uninhabited() {
        error::uninhabited_enum_variants_uninhabited(ctx)?;
    }

    let parts = if let Some(sty) = ast.attrs.params.into_option() {
        derive_enum_has_params(ctx, &mut ast.tracker, &ast.ident, ast.body, sty)
    } else {
        derive_enum_no_params(ctx, &mut ast.tracker, &ast.ident, ast.body)
    }?;

    let parts = add_top_filter(ast.attrs.filter, parts);
    Ok(Impl::new(ast.ident, ast.tracker, parts))
}
/// Builds the weighted union of variant strategies for an enum without a
/// top-level `params` attribute.
///
/// Skipped and uninhabited variants are left out. Parameter types required
/// by individual variants are collected into the returned parts.
fn derive_enum_no_params(
    ctx: Ctx, ut: &mut UseTracker, _self: &Ident, variants: Vec<Variant>)
    -> DeriveResult<ImplParts>
{
    let mut acc = PartsAcc::new(variants.len());

    for variant in variants {
        let (weight, ident, fields, attrs) =
            match keep_inhabited_variant(ctx, _self, variant)? {
                Some(kept) => kept,
                // Skipped or uninhabited variants contribute nothing.
                None => continue,
            };

        let path = parse_quote!( #_self::#ident );
        let (strat, ctor) = if fields.is_empty() {
            pair_unit_variant(ctx, &attrs, path)
        } else {
            derive_variant_with_fields(ctx, ut, path, attrs, fields, &mut acc)?
        };
        let weighted = (strat, (weight, ctor));
        acc = acc.add_strat(weighted);
    }

    ensure_union_has_strategies(ctx, &acc.strats);
    Ok(acc.finish(ctx))
}
/// Reports an error when no variant strategy was collected, i.e. every
/// variant was skipped or uninhabited.
fn ensure_union_has_strategies<C>(ctx: Ctx, strats: &StratAcc<C>) {
    if !strats.is_empty() {
        return;
    }
    error::uninhabited_enum_because_of_skipped_variants(ctx);
}
/// Builds the strategy/constructor pair for an enum variant that has fields,
/// in an enum without a top-level `params` attribute.
///
/// The variant's `params` and strategy attributes select how it is built:
/// - An explicit `strategy`, `value` or `regex` on the variant produces the
///   whole variant, so attributes on its fields are rejected
///   (`deny_all_attrs_on_fields`).
/// - `Passthrough` without an explicit strategy derives the fields one by
///   one via `variant_no_explicit_strategy`, which may register parameter
///   types in `acc`.
/// - `Default` delegates to `variant_handle_default_params`.
/// - `Specified(ty)` registers `ty` in `acc` via `extract_nparam`. With
///   `regex` this is reported as an error; with no explicit strategy it is
///   an error that aborts the derivation (`?`).
///
/// Variant-level `filter`s are applied to the resulting pair last.
fn derive_variant_with_fields<C>
(ctx: Ctx, ut: &mut UseTracker, v_path: Path, attrs: ParsedAttributes,
fields: Vec<Field>, acc: &mut PartsAcc<C>)
-> DeriveResult<StratPair>
{
// Cloned up front because the `Default` arm below moves all of `attrs`.
let filter = attrs.filter.clone();
let pair = match attrs.params {
ParamsMode::Passthrough => match attrs.strategy {
StratMode::Strategy(strat) => {
deny_all_attrs_on_fields(ctx, fields)?;
pair_existential_self(strat)
},
StratMode::Value(value) => {
deny_all_attrs_on_fields(ctx, fields)?;
pair_value_self(value)
},
StratMode::Regex(regex) => {
deny_all_attrs_on_fields(ctx, fields)?;
pair_regex_self(regex)
},
StratMode::Arbitrary =>
variant_no_explicit_strategy(ctx, ut, v_path, fields, acc)?,
},
ParamsMode::Default =>
variant_handle_default_params(ctx, ut, v_path, attrs, fields)?,
ParamsMode::Specified(params_ty) =>
extract_nparam(acc, params_ty, match attrs.strategy {
StratMode::Strategy(strat) => {
deny_all_attrs_on_fields(ctx, fields)?;
pair_existential_self(strat)
},
StratMode::Value(value) => {
deny_all_attrs_on_fields(ctx, fields)?;
pair_value_exist_self(value)
},
StratMode::Regex(regex) => {
// Reported, but derivation continues with the regex strategy.
error::cant_set_param_and_regex(ctx, error::ENUM_VARIANT);
deny_all_attrs_on_fields(ctx, fields)?;
pair_regex_self(regex)
},
StratMode::Arbitrary => {
let ty = self_ty();
error::cant_set_param_but_not_strat(ctx, &ty, error::ENUM_VARIANT)?
},
}),
};
let pair = add_filter_self(filter, pair);
Ok(pair)
}
fn variant_no_explicit_strategy<C>
(ctx: Ctx, ut: &mut UseTracker, v_path: Path,
fields: Vec<Field>, acc: &mut PartsAcc<C>)
-> DeriveResult<StratPair>
{
let closure = map_closure(v_path, &fields);
let fields_acc = derive_product_no_params(ctx, ut, fields,
error::ENUM_VARIANT_FIELD)?;
let (params, count) = fields_acc.params.consume();
let (strat, ctor) = fields_acc.strats.finish(closure);
let params_ty = params.into();
Ok((strat, if is_unit_type(¶ms_ty) { ctor } else {
let pref = acc.add_param(params_ty);
if pref + 1 == count {
ctor
} else {
extract_all(ctor, count, FromReg::Num(pref))
}
}))
}
fn variant_handle_default_params(
ctx: Ctx, ut: &mut UseTracker,
v_path: Path, attrs: ParsedAttributes, fields: Vec<Field>)
-> DeriveResult<StratPair> {
let pair = match attrs.strategy {
StratMode::Strategy(strat) => {
deny_all_attrs_on_fields(ctx, fields)?;
pair_existential_self(strat)
},
StratMode::Value(value) => {
deny_all_attrs_on_fields(ctx, fields)?;
pair_value_self(value)
},
StratMode::Regex(regex) => {
deny_all_attrs_on_fields(ctx, fields)?;
pair_regex_self(regex)
},
StratMode::Arbitrary =>
derive_product_has_params(ctx, ut, error::ENUM_VARIANT_FIELD,
map_closure(v_path, &fields), fields)?,
};
Ok(pair)
}
fn deny_all_attrs_on_fields(ctx: Ctx, fields: Vec<Field>) -> DeriveResult<()> {
fields.into_iter().try_for_each(|field| {
let f_attr = attr::parse_attributes(ctx, &field.attrs)?;
error::if_anything_specified(ctx, &f_attr, error::ENUM_VARIANT_FIELD);
Ok(())
})
}
/// Builds the weighted union of variant strategies for an enum with a
/// top-level `params` attribute (`sty`).
///
/// Skipped and uninhabited variants are left out. Variants with fields go
/// through `variant_handle_default_params`, then have their `filter`s
/// applied. The finished union is wrapped with the top-level parameter type.
fn derive_enum_has_params(
    ctx: Ctx, ut: &mut UseTracker, _self: &Ident, variants: Vec<Variant>,
    sty: Option<Type>)
    -> DeriveResult<ImplParts>
{
    let mut strats = StratAcc::new(variants.len());

    for variant in variants {
        let (weight, ident, fields, attrs) =
            match keep_inhabited_variant(ctx, _self, variant)? {
                Some(kept) => kept,
                None => continue,
            };

        let path = parse_quote!( #_self::#ident );
        let (strat, ctor) = if fields.is_empty() {
            pair_unit_variant(ctx, &attrs, path)
        } else {
            // Taken before `attrs` is moved into the call below.
            let filter = attrs.filter.clone();
            let pair = variant_handle_default_params(ctx, ut, path, attrs, fields)?;
            add_filter_self(filter, pair)
        };
        strats = strats.add((strat, (weight, ctor)));
    }

    ensure_union_has_strategies(ctx, &strats);
    let union = strats.finish(ctx);
    Ok(add_top_params(sty, union))
}
/// Parses a variant's attributes and decides whether the variant takes part
/// in the generated union.
///
/// Returns `Ok(None)` in two cases:
/// - The variant is marked `skip`. Its own attributes and its fields'
///   attributes are first checked for settings that conflict with skipping.
/// - The variant's fields are uninhabited.
///
/// Otherwise returns `(weight, ident, fields, attrs)`. The weight defaults
/// to 1 when no `weight` attribute is given. Attribute parse failures abort
/// early.
fn keep_inhabited_variant(ctx: Ctx, _self: &Ident, variant: Variant)
    -> DeriveResult<Option<(u32, Ident, Vec<Field>, ParsedAttributes)>>
{
    // `IsUninhabited` is already in scope through the module-level import,
    // so no local `use` is needed here.
    let attrs = attr::parse_attributes(ctx, &variant.attrs)?;
    let fields = fields_to_vec(variant.fields);

    if attrs.skip {
        ensure_has_only_skip_attr(ctx, &attrs, error::ENUM_VARIANT);
        // Attributes on the fields of a skipped variant are checked as well.
        fields.into_iter().try_for_each(|field| {
            let f_attrs = attr::parse_attributes(ctx, &field.attrs)?;
            error::if_skip_present(ctx, &f_attrs, error::ENUM_VARIANT_FIELD);
            ensure_has_only_skip_attr(ctx, &f_attrs, error::ENUM_VARIANT_FIELD);
            Ok(())
        })?;
        return Ok(None);
    }

    if (&*fields).is_uninhabited() {
        return Ok(None);
    }

    let weight = attrs.weight.unwrap_or(1);
    Ok(Some((weight, variant.ident, fields, attrs)))
}
/// Reports an error for each of `params`, `strategy`, `weight` and `filter`
/// set on a skipped item. `item` labels the item in error messages.
fn ensure_has_only_skip_attr(ctx: Ctx, attrs: &ParsedAttributes, item: &str) {
    let has_params = attrs.params.is_set();
    let has_strategy = attrs.strategy.is_set();
    let has_weight = attrs.weight.is_some();
    let has_filter = !attrs.filter.is_empty();

    if has_params {
        error::skipped_variant_has_param(ctx, item);
    }
    if has_strategy {
        error::skipped_variant_has_strat(ctx, item);
    }
    if has_weight {
        error::skipped_variant_has_weight(ctx, item);
    }
    if has_filter {
        error::skipped_variant_has_filter(ctx, item);
    }
}
/// Builds the pair for a variant without fields, which always yields the
/// value at `v_path`. Attributes that cannot apply to a unit variant are
/// reported first.
fn pair_unit_variant(ctx: Ctx, attrs: &ParsedAttributes, v_path: Path) -> StratPair {
    error::if_present_on_unit_variant(ctx, attrs);
    let variant_path = &v_path;
    pair_unit_self(variant_path)
}
/// Accumulates parameter types and strategies while a product or union is
/// being derived. `C` is the constructor type stored per strategy.
struct PartsAcc<C> {
    /// Parameter types registered so far.
    params: ParamAcc,
    /// Strategies together with their constructors.
    strats: StratAcc<C>,
}

impl<C> PartsAcc<C> {
    /// Creates an empty accumulator with room for `size` strategies.
    fn new(size: usize) -> Self {
        let params = ParamAcc::empty();
        let strats = StratAcc::new(size);
        PartsAcc { params, strats }
    }

    /// Appends a strategy/constructor pair, keeping the parameters as they are.
    fn add_strat(self, pair: (Strategy, C)) -> Self {
        let PartsAcc { params, strats } = self;
        PartsAcc { params, strats: strats.add(pair) }
    }

    /// Registers a parameter type and returns its slot index.
    fn add_param(&mut self, ty: Type) -> usize {
        let params = &mut self.params;
        params.add(ty)
    }
}

impl PartsAcc<Ctor> {
    /// Finishes a product: combines the strategies with `closure` and makes
    /// the constructor extract all accumulated parameters (`FromReg::Top`).
    fn finish(self, closure: MapClosure) -> ImplParts {
        let PartsAcc { params, strats } = self;
        let (types, count) = params.consume();
        let (strat, ctor) = strats.finish(closure);
        (types, strat, extract_all(ctor, count, FromReg::Top))
    }
}

impl PartsAcc<(u32, Ctor)> {
    /// Finishes a weighted union (weight overflow is reported via `ctx`) and
    /// makes the constructor extract all accumulated parameters.
    fn finish(self, ctx: Ctx) -> ImplParts {
        let PartsAcc { params, strats } = self;
        let (types, count) = params.consume();
        let (strat, ctor) = strats.finish(ctx);
        (types, strat, extract_all(ctor, count, FromReg::Top))
    }
}
/// Accumulates parameter types, handing out an index for each one added.
struct ParamAcc {
    /// The parameter types, in the order they were added.
    types: Params,
}

impl ParamAcc {
    /// Creates an accumulator with no parameter types.
    fn empty() -> Self {
        ParamAcc { types: Params::empty() }
    }

    /// Appends `ty` and returns the index it was stored at (the previous length).
    fn add(&mut self, ty: Type) -> usize {
        let position = self.types.len();
        self.types += ty;
        position
    }

    /// Returns the accumulated types together with how many there are.
    fn consume(self) -> (Params, usize) {
        let ParamAcc { types } = self;
        let count = types.len();
        (types, count)
    }
}
struct StratAcc<C> {
types: Vec<Strategy>,
ctors: Vec<C>,
}
impl<C> StratAcc<C> {
fn new(size: usize) -> Self {
Self {
types: Vec::with_capacity(size),
ctors: Vec::with_capacity(size),
}
}
fn add(mut self, (strat, ctor): (Strategy, C)) -> Self {
self.types.push(strat);
self.ctors.push(ctor);
self
}
fn consume(self) -> (Vec<Strategy>, Vec<C>) {
(self.types, self.ctors)
}
fn is_empty(&self) -> bool {
self.types.is_empty()
}
}
impl StratAcc<Ctor> {
fn finish(self, closure: MapClosure) -> StratPair {
pair_map(self.consume(), closure)
}
}
impl StratAcc<(u32, Ctor)> {
fn finish(self, ctx: Ctx) -> StratPair {
if self.ctors.iter()
.map(|&(w, _)| w)
.try_fold(0u32, |acc, w| acc.checked_add(w))
.is_none() {
error::weight_overflowing(ctx)
}
pair_oneof(self.consume())
}
}