use proc_macro2::Ident;
use quote::format_ident;
use syn::punctuated::Punctuated;
use syn::{
AngleBracketedGenericArguments, GenericArgument, Path, PathArguments, ReturnType, Token, Type,
TypeParamBound, WhereClause, WherePredicate,
};
/// Produces a copy of `where_clause` in which occurrences of `param` have been
/// rewritten by [`map_type`] and [`map_type_param_bounds`].
///
/// Only type predicates (`WherePredicate::Type`) are visited: both the bounded
/// type and its bounds are rewritten. Every other predicate kind is carried
/// over unchanged.
///
/// Returns `Some(clause)` when `param` was encountered at least once during
/// the rewrite, and `None` otherwise.
pub fn map_where(where_clause: &WhereClause, param: &Ident) -> Option<WhereClause> {
    let mut found = false;
    let mut rewritten = where_clause.clone();

    for predicate in rewritten.predicates.iter_mut() {
        match predicate {
            WherePredicate::Type(type_pred) => {
                map_type(&mut type_pred.bounded_ty, param, &mut found);
                map_type_param_bounds(&mut type_pred.bounds, param, &mut found);
            }
            // Non-type predicates are not inspected.
            _ => {}
        }
    }

    if !found {
        return None;
    }
    Some(rewritten)
}
/// Rewrites occurrences of `param` inside every trait bound of `bounds`.
///
/// Each `TypeParamBound::Trait` has its path handed to [`map_path`]; all other
/// bound kinds are skipped. `contains_param` is passed through to `map_path`,
/// which sets it to `true` when `param` is found.
pub fn map_type_param_bounds(
    bounds: &mut Punctuated<TypeParamBound, Token![+]>,
    param: &Ident,
    contains_param: &mut bool,
) {
    // Pick out the paths of the trait bounds; other bounds carry no path.
    let trait_paths = bounds.iter_mut().filter_map(|bound| match bound {
        TypeParamBound::Trait(trait_bound) => Some(&mut trait_bound.path),
        _ => None,
    });

    for trait_path in trait_paths {
        map_path(trait_path, param, contains_param);
    }
}
fn map_angle_bracketed_generic_arguments(
args: &mut AngleBracketedGenericArguments,
param: &Ident,
contains_param: &mut bool,
) {
for arg in &mut args.args {
match arg {
GenericArgument::Type(t) => map_type(t, param, contains_param),
GenericArgument::AssocType(assoc) => {
map_type(&mut assoc.ty, param, contains_param);
if let Some(generics) = &mut assoc.generics {
map_angle_bracketed_generic_arguments(generics, param, contains_param);
}
}
GenericArgument::Constraint(_) => {}
_ => {}
}
}
}
/// Rewrites occurrences of `param` inside `path`.
///
/// If the first segment's identifier equals `param`, it is replaced with the
/// identifier `__B` and `contains_param` is set to `true`. Afterwards the
/// arguments of every segment are traversed: angle-bracketed arguments via
/// [`map_angle_bracketed_generic_arguments`], and parenthesized arguments
/// (inputs and, when present, the return type) via [`map_type`].
pub fn map_path(path: &mut Path, param: &Ident, contains_param: &mut bool) {
    // Only the leading segment is compared against the parameter name.
    if let Some(head) = path.segments.first_mut() {
        if head.ident == *param {
            head.ident = format_ident!("__B");
            *contains_param = true;
        }
    }

    for segment in path.segments.iter_mut() {
        match &mut segment.arguments {
            PathArguments::AngleBracketed(generic_args) => {
                map_angle_bracketed_generic_arguments(generic_args, param, contains_param);
            }
            PathArguments::Parenthesized(fn_args) => {
                for input_ty in fn_args.inputs.iter_mut() {
                    map_type(input_ty, param, contains_param);
                }
                match &mut fn_args.output {
                    ReturnType::Type(_, output_ty) => map_type(output_ty, param, contains_param),
                    ReturnType::Default => {}
                }
            }
            // Segments without arguments have nothing further to rewrite.
            _ => {}
        }
    }
}
/// Recursively rewrites occurrences of `param` inside `typ`.
///
/// The traversal descends into the element/inner types of arrays, groups,
/// parenthesized types, pointers, references, slices and tuples; into the
/// argument and return types of bare function types; into the bounds of
/// `impl Trait` and trait-object types; and into path types, including the
/// self type of a qualified path (`<T as Trait>::Assoc`). The actual renaming
/// happens in [`map_path`], which also sets `contains_param` to `true` when
/// `param` is found.
///
/// Any other kind of type is left untouched. Array length expressions are not
/// inspected.
fn map_type(typ: &mut Type, param: &Ident, contains_param: &mut bool) {
    match typ {
        Type::Array(array) => {
            map_type(&mut array.elem, param, contains_param);
        }
        Type::BareFn(fun) => {
            for input in &mut fun.inputs {
                map_type(&mut input.ty, param, contains_param);
            }
            match &mut fun.output {
                ReturnType::Default => {}
                ReturnType::Type(_, t) => map_type(t, param, contains_param),
            }
        }
        Type::Group(group) => map_type(&mut group.elem, param, contains_param),
        Type::ImplTrait(impl_trait) => {
            map_type_param_bounds(&mut impl_trait.bounds, param, contains_param);
        }
        Type::Paren(paren) => map_type(&mut paren.elem, param, contains_param),
        Type::Path(path) => {
            // In a qualified path such as `<T as Trait>::Assoc`, the self type
            // is stored separately from the path segments, so it has to be
            // visited explicitly or the parameter inside it would be missed.
            if let Some(qself) = &mut path.qself {
                map_type(&mut qself.ty, param, contains_param);
            }
            map_path(&mut path.path, param, contains_param);
        }
        Type::Ptr(ptr) => map_type(&mut ptr.elem, param, contains_param),
        Type::Reference(refer) => map_type(&mut refer.elem, param, contains_param),
        Type::Slice(slice) => map_type(&mut slice.elem, param, contains_param),
        Type::TraitObject(obj) => {
            map_type_param_bounds(&mut obj.bounds, param, contains_param);
        }
        Type::Tuple(tup) => {
            for elem in &mut tup.elems {
                map_type(elem, param, contains_param);
            }
        }
        _ => {}
    }
}