use ast;
use attr;
use std::collections::HashSet;
use syn::{self, visit, GenericParam};
/// Returns a copy of `generics` with the defaults removed from every type and
/// const parameter, e.g. `<T = u8, const N: usize = 4>` becomes
/// `<T, const N: usize>`.
///
/// Lifetime parameters, bounds, and the where-clause are copied unchanged.
/// NOTE(review): presumably used so the parameters can be repeated on an
/// `impl`, where parameter defaults are not accepted — confirm with callers.
pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
    syn::Generics {
        params: generics
            .params
            .iter()
            .map(|generic_param| match *generic_param {
                // Clear `eq_token` together with `default` so the result never
                // carries an `=` token with no default expression behind it.
                GenericParam::Type(ref ty_param) => GenericParam::Type(syn::TypeParam {
                    eq_token: None,
                    default: None,
                    ..ty_param.clone()
                }),
                // Const parameters can carry defaults too (`const N: usize = 4`);
                // strip them for the same reason as type-parameter defaults.
                GenericParam::Const(ref const_param) => GenericParam::Const(syn::ConstParam {
                    eq_token: None,
                    default: None,
                    ..const_param.clone()
                }),
                ref param => param.clone(),
            })
            .collect(),
        ..generics.clone()
    }
}
/// Returns a copy of `generics` whose where-clause has `predicates` appended,
/// in order, after any predicates it already contained.
///
/// `make_where_clause` supplies an empty where-clause when `generics` had
/// none, so the result always has one, even if `predicates` is empty.
pub fn with_where_predicates(
    generics: &syn::Generics,
    predicates: &[syn::WherePredicate],
) -> syn::Generics {
    let mut extended = generics.clone();
    {
        let where_clause = extended.make_where_clause();
        for predicate in predicates {
            where_clause.predicates.push(predicate.clone());
        }
    }
    extended
}
pub fn with_where_predicates_from_fields<F>(
item: &ast::Input,
generics: &syn::Generics,
from_field: F,
) -> syn::Generics
where
F: Fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
{
let mut cloned = generics.clone();
{
let fields = item.body.all_fields();
let field_where_predicates = fields
.iter()
.flat_map(|field| from_field(&field.attrs))
.flat_map(|predicates| predicates.to_vec());
cloned
.make_where_clause()
.predicates
.extend(field_where_predicates);
}
cloned
}
/// Returns a copy of `generics` whose where-clause gains a `T: bound`
/// predicate for each type parameter `T` mentioned in the type of at least one
/// field of `item` accepted by `filter`.
///
/// Predicates are added in the parameters' declaration order. A field whose
/// type is a path ending in `PhantomData` is skipped entirely, and any
/// `PhantomData` path nested inside a field type is not descended into, so a
/// parameter that appears only inside `PhantomData` receives no bound.
pub fn with_bound<F>(
    item: &ast::Input,
    generics: &syn::Generics,
    filter: F,
    bound: &syn::Path,
) -> syn::Generics
where
    F: Fn(&attr::Field) -> bool,
{
    /// Records which of the item's type parameters occur in the visited types.
    #[derive(Debug)]
    struct FindTyParams {
        // Every type parameter declared on `generics`.
        all_ty_params: HashSet<syn::Ident>,
        // The subset of `all_ty_params` seen so far in a visited type.
        relevant_ty_params: HashSet<syn::Ident>,
    }
    impl<'ast> visit::Visit<'ast> for FindTyParams {
        fn visit_path(&mut self, path: &'ast syn::Path) {
            // Do not look inside `PhantomData<...>`: parameters used only
            // there should not be bounded.
            if is_phantom_data(path) {
                return;
            }
            // Only a bare single-segment path (no leading `::`) is compared
            // against the declared type parameter names.
            if path.leading_colon.is_none() && path.segments.len() == 1 {
                let id = &path.segments[0].ident;
                if self.all_ty_params.contains(id) {
                    self.relevant_ty_params.insert(id.clone());
                }
            }
            // Keep walking so parameters in generic arguments (e.g. `Vec<T>`)
            // are found too.
            visit::visit_path(self, path);
        }
    }

    let all_ty_params: HashSet<_> = generics
        .type_params()
        .map(|ty_param| ty_param.ident.clone())
        .collect();

    let relevant_tys = item
        .body
        .all_fields()
        .into_iter()
        .filter(|field| {
            if let syn::Type::Path(syn::TypePath { ref path, .. }) = *field.ty {
                !is_phantom_data(path)
            } else {
                true
            }
        })
        .filter(|field| filter(&field.attrs))
        .map(|field| &field.ty);

    let mut visitor = FindTyParams {
        all_ty_params,
        relevant_ty_params: HashSet::new(),
    };
    for ty in relevant_tys {
        visit::visit_type(&mut visitor, ty);
    }

    let mut cloned = generics.clone();
    {
        let relevant_where_predicates = generics
            .type_params()
            .map(|ty_param| &ty_param.ident)
            // `filter` hands the closure a `&&Ident`; dereference once so the
            // lookup key is `&Ident` (`Ident: Borrow<Ident>`). Passing the
            // double reference would require `Ident: Borrow<&Ident>`, which
            // does not exist.
            .filter(|id| visitor.relevant_ty_params.contains(*id))
            .map(|id| -> syn::WherePredicate { parse_quote!( #id : #bound ) });
        cloned
            .make_where_clause()
            .predicates
            .extend(relevant_where_predicates);
    }
    cloned
}
/// Returns `true` when the last segment of `path` is the identifier
/// `PhantomData`.
///
/// Only the final segment is compared, so `PhantomData`,
/// `marker::PhantomData` and `std::marker::PhantomData` all match. Any other
/// type that happens to be named `PhantomData` matches as well.
fn is_phantom_data(path: &syn::Path) -> bool {
    if let Some(last_segment) = path.segments.last() {
        last_segment.ident == "PhantomData"
    } else {
        false
    }
}