use proc_macro2::Ident;
use std::collections::{HashMap, HashSet};
use syn::{fold::Fold, GenericParam, Generics};
/// A [`Fold`] that renames identifiers using a fixed lookup table.
///
/// Every identifier the fold reaches is looked up in the table. If a mapping
/// exists, the mapped identifier is substituted; otherwise the identifier is
/// passed through untouched. Note that this rewrites *any* matching identifier
/// the fold visits, not only those appearing in type position.
pub struct TypeReplacer {
    /// Lookup table: identifier to rename -> its replacement.
    type_replacer: HashMap<Ident, Ident>,
}

impl TypeReplacer {
    /// Builds a replacer that owns a private copy of `type_replacer`.
    pub fn new(type_replacer: &HashMap<Ident, Ident>) -> Self {
        let table = type_replacer.clone();
        Self {
            type_replacer: table,
        }
    }
}

impl Fold for TypeReplacer {
    /// Returns the replacement for `key` if one is registered, else `key` itself.
    fn fold_ident(&mut self, key: Ident) -> Ident {
        self.type_replacer.get(&key).cloned().unwrap_or(key)
    }
}
/// A [`Fold`] that duplicates selected generic type parameters.
///
/// The fold visits each generic parameter list. For every type parameter whose
/// identifier is a key of the map, it inserts a renamed copy immediately after
/// the original. The copy is produced by [`TypeReplacer`], so identifiers in its
/// bounds and default are renamed too. Lifetime parameters, const parameters
/// and unmapped type parameters are kept unchanged and in their original order.
pub struct GenericDuplication {
    /// Maps an original type-parameter identifier to the identifier of its duplicate.
    hash_map: HashMap<Ident, Ident>,
}

impl GenericDuplication {
    /// Creates a duplicator that owns a private copy of `hash_map`.
    pub fn new(hash_map: &HashMap<Ident, Ident>) -> GenericDuplication {
        GenericDuplication {
            hash_map: hash_map.clone(),
        }
    }
}

impl Fold for GenericDuplication {
    /// Rewrites `i.params`, inserting a renamed duplicate after each mapped
    /// type parameter.
    ///
    /// Only `params` is rewritten; the where clause is returned unchanged. This
    /// override does not delegate to `syn::fold::fold_generics`, so the original
    /// parameters' nested nodes are not folded further.
    fn fold_generics(&mut self, mut i: Generics) -> Generics {
        let mut replacer = TypeReplacer::new(&self.hash_map);
        // Walk *all* params rather than `type_params()`. Rebuilding from
        // `type_params()` alone would silently drop every lifetime and const
        // generic parameter.
        i.params = std::mem::take(&mut i.params)
            .into_iter()
            .flat_map(|param| match param {
                GenericParam::Type(type_param)
                    if self.hash_map.contains_key(&type_param.ident) =>
                {
                    let duplicate = replacer.fold_type_param(type_param.clone());
                    vec![GenericParam::Type(type_param), GenericParam::Type(duplicate)]
                }
                other => vec![other],
            })
            .collect();
        i
    }
}
/// A [`Fold`] that deletes the named generic type parameters from every
/// generic parameter list it visits.
///
/// Lifetime and const parameters are never removed. Where-clause predicates
/// that mention a removed type parameter are NOT touched.
pub struct GenericRemoval {
    /// Identifiers of the type parameters to delete.
    set: HashSet<Ident>,
}

impl GenericRemoval {
    /// Creates a remover for the given type-parameter identifiers (they are cloned).
    pub fn new(set: &[Ident]) -> GenericRemoval {
        GenericRemoval {
            set: set.iter().cloned().collect(),
        }
    }
}

impl Fold for GenericRemoval {
    /// Rewrites `i.params`, dropping type parameters whose identifier is in the set.
    ///
    /// Only `params` is rewritten; the where clause is returned unchanged. This
    /// override does not delegate to `syn::fold::fold_generics`, so the remaining
    /// parameters' nested nodes are not folded further.
    fn fold_generics(&mut self, mut i: Generics) -> Generics {
        // Filter the full parameter list instead of rebuilding from
        // `type_params()`. The latter would also discard every lifetime and const
        // parameter, not just the requested type parameters.
        i.params = std::mem::take(&mut i.params)
            .into_iter()
            .filter(|param| match param {
                GenericParam::Type(type_param) => !self.set.contains(&type_param.ident),
                // Lifetime and const parameters are always kept.
                _ => true,
            })
            .collect();
        i
    }
}