use std::iter;
use proc_macro2::Ident;
use syn::{
spanned::Spanned,
visit::{self, Visit},
Generics, Result, Type, TypePath,
};
use crate::utils;
/// Syntax-tree visitor that remembers whether it has seen any identifier from `targets`.
struct ContainIdents<'a> {
    /// Becomes `true` once a matching identifier is visited; never reset afterwards.
    found: bool,
    /// Identifiers to look for.
    targets: &'a [Ident],
}

impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> {
    fn visit_ident(&mut self, ident: &'ast Ident) {
        self.found |= self.targets.contains(ident);
    }
}

/// Returns `true` when some identifier occurring anywhere inside `ty` is one of `idents`.
fn type_contain_idents(ty: &Type, idents: &[Ident]) -> bool {
    let mut finder = ContainIdents {
        found: false,
        targets: idents,
    };
    finder.visit_type(ty);
    finder.found
}
/// Syntax-tree visitor that remembers whether a visited type path (or a type path nested in
/// one, e.g. inside generic arguments) has `ident` as its first segment.
struct TypePathStartsWithIdent<'a> {
    /// Set once a type path whose leading segment equals `ident` is found.
    found: bool,
    /// Identifier compared against the first segment of each type path.
    ident: &'a Ident,
}

impl<'a> TypePathStartsWithIdent<'a> {
    /// Creates a visitor that has not matched anything yet.
    fn new(ident: &'a Ident) -> Self {
        TypePathStartsWithIdent { found: false, ident }
    }
}

impl<'a, 'ast> Visit<'ast> for TypePathStartsWithIdent<'a> {
    fn visit_type_path(&mut self, path: &'ast TypePath) {
        match path.path.segments.first() {
            // A match on the leading segment settles it; no need to look deeper.
            Some(segment) if &segment.ident == self.ident => self.found = true,
            // Otherwise keep searching the qualified self type and generic arguments.
            _ => visit::visit_type_path(self, path),
        }
    }
}

/// Returns `true` if `ty`, or any type path nested within it, starts with `ident`.
fn type_path_or_sub_starts_with_ident(ty: &TypePath, ident: &Ident) -> bool {
    let mut checker = TypePathStartsWithIdent::new(ident);
    checker.visit_type_path(ty);
    checker.found
}

/// Returns `true` if any type path within `ty` starts with `ident`.
fn type_or_sub_type_path_starts_with_ident(ty: &Type, ident: &Ident) -> bool {
    let mut checker = TypePathStartsWithIdent::new(ident);
    checker.visit_type(ty);
    checker.found
}
/// Syntax-tree visitor collecting the outermost type paths that neither start with `ident`
/// nor contain a nested type path starting with it.
///
/// When a type path does involve `ident` (directly or in a nested type path), the visitor
/// descends into its children instead of recording it, so that smaller sub-paths which are
/// free of `ident` can still be collected.
struct FindTypePathsNotStartOrContainIdent<'a> {
    /// Type paths collected so far, in visiting order.
    result: Vec<TypePath>,
    /// Identifier that disqualifies a type path from being collected as a whole.
    ident: &'a Ident,
}

impl<'a, 'ast> Visit<'ast> for FindTypePathsNotStartOrContainIdent<'a> {
    fn visit_type_path(&mut self, i: &'ast TypePath) {
        if type_path_or_sub_starts_with_ident(i, self.ident) {
            // Somewhere inside involves `ident`: look at the parts individually.
            visit::visit_type_path(self, i);
        } else {
            // Entirely free of `ident`: record the whole path and stop descending.
            self.result.push(i.clone());
        }
    }
}

/// Returns the outermost type paths inside `ty` that do not start with `ident` and do not
/// contain a type path starting with `ident`.
fn find_type_paths_not_start_or_contain_ident(ty: &Type, ident: &Ident) -> Vec<TypePath> {
    let mut visitor = FindTypePathsNotStartOrContainIdent {
        result: Vec::new(),
        ident,
    };
    visitor.visit_type(ty);
    visitor.result
}
/// Adds the trait bounds required by the generated codec implementation to the where clause
/// of `generics`.
///
/// Nothing is added when `generics` declares no type parameters. Otherwise:
/// * the types selected by `get_types_to_add_trait_bound` get `#ty : #codec_bound`;
/// * `#[codec(compact)]`-style fields (per `utils::is_compact`) whose type mentions a type
///   parameter get `#ty : _parity_scale_codec::HasCompact`;
/// * when `codec_skip_bound` is `Some`, skipped fields (per `utils::should_skip`) whose type
///   mentions a type parameter get `#ty : #codec_skip_bound`.
///
/// The where clause is only created when at least one predicate is going to be pushed.
///
/// # Errors
///
/// Propagates errors from collecting field types, e.g. when `data` is a union.
pub fn add(
    input_ident: &Ident,
    generics: &mut Generics,
    data: &syn::Data,
    codec_bound: syn::Path,
    codec_skip_bound: Option<syn::Path>,
    dumb_trait_bounds: bool,
) -> Result<()> {
    let ty_params = generics
        .type_params()
        .map(|p| p.ident.clone())
        .collect::<Vec<_>>();
    if ty_params.is_empty() {
        return Ok(());
    }

    let codec_types =
        get_types_to_add_trait_bound(input_ident, data, &ty_params, dumb_trait_bounds)?;

    let compact_types = collect_types(data, utils::is_compact)?
        .into_iter()
        .filter(|ty| type_contain_idents(ty, &ty_params))
        .collect::<Vec<_>>();

    // Skipped fields are only bounded when the caller supplied a bound for them.
    let skip_types = match &codec_skip_bound {
        Some(_) => {
            let needs_default_bound = |f: &syn::Field| utils::should_skip(&f.attrs);
            collect_types(data, needs_default_bound)?
                .into_iter()
                .filter(|ty| type_contain_idents(ty, &ty_params))
                .collect::<Vec<_>>()
        }
        None => Vec::new(),
    };

    // Avoid materialising an empty `where` clause when there is nothing to add.
    if codec_types.is_empty() && compact_types.is_empty() && skip_types.is_empty() {
        return Ok(());
    }

    let where_clause = generics.make_where_clause();

    for ty in codec_types {
        where_clause.predicates.push(parse_quote!(#ty : #codec_bound));
    }

    let has_compact_bound: syn::Path = parse_quote!(_parity_scale_codec::HasCompact);
    for ty in compact_types {
        where_clause.predicates.push(parse_quote!(#ty : #has_compact_bound));
    }

    // `skip_types` is non-empty only when a skip bound was given, so this covers every entry.
    if let Some(skip_bound) = &codec_skip_bound {
        for ty in skip_types {
            where_clause.predicates.push(parse_quote!(#ty : #skip_bound));
        }
    }

    Ok(())
}
/// Computes the types that must receive the main codec trait bound.
///
/// With `dumb_trait_bound` set, every type parameter in `ty_params` is returned as-is.
/// Otherwise the candidates are the types of fields that are not compact, have no
/// `encoded_as` type and are not skipped, restricted to those mentioning a type parameter.
/// From each candidate:
/// * the outermost sub type paths not involving `input_ident` that still mention a type
///   parameter are yielded, followed by the candidate type itself;
/// * any yielded type containing a type path that starts with `input_ident` is dropped.
///
/// # Errors
///
/// Propagates errors from collecting field types, e.g. when `data` is a union.
fn get_types_to_add_trait_bound(
    input_ident: &Ident,
    data: &syn::Data,
    ty_params: &[Ident],
    dumb_trait_bound: bool,
) -> Result<Vec<Type>> {
    if dumb_trait_bound {
        return Ok(ty_params.iter().map(|t| parse_quote!( #t )).collect());
    }

    let needs_codec_bound = |f: &syn::Field| {
        !utils::is_compact(f)
            && utils::get_encoded_as_type(f).is_none()
            && !utils::should_skip(&f.attrs)
    };

    let res = collect_types(data, needs_codec_bound)?
        .into_iter()
        .filter(|ty| type_contain_idents(ty, ty_params))
        .flat_map(|ty| {
            // The returned `TypePath`s are already owned, so wrap them without cloning.
            find_type_paths_not_start_or_contain_ident(&ty, input_ident)
                .into_iter()
                .map(Type::Path)
                .filter(|sub_ty| type_contain_idents(sub_ty, ty_params))
                .chain(iter::once(ty))
        })
        .filter(|ty| !type_or_sub_type_path_starts_with_ident(ty, input_ident))
        .collect();
    Ok(res)
}
fn collect_types(
data: &syn::Data,
type_filter: fn(&syn::Field) -> bool,
) -> Result<Vec<syn::Type>> {
use syn::*;
let types = match *data {
Data::Struct(ref data) => match &data.fields {
| Fields::Named(FieldsNamed { named: fields , .. })
| Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => {
fields.iter()
.filter(|f| type_filter(f))
.map(|f| f.ty.clone())
.collect()
},
Fields::Unit => { Vec::new() },
},
Data::Enum(ref data) => data.variants.iter()
.filter(|variant| !utils::should_skip(&variant.attrs))
.flat_map(|variant| {
match &variant.fields {
| Fields::Named(FieldsNamed { named: fields , .. })
| Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => {
fields.iter()
.filter(|f| type_filter(f))
.map(|f| f.ty.clone())
.collect()
},
Fields::Unit => { Vec::new() },
}
}).collect(),
Data::Union(ref data) => return Err(Error::new(
data.union_token.span(),
"Union types are not supported."
)),
};
Ok(types)
}