use std::collections::{HashMap, HashSet};
use quote::ToTokens;
use scale_info::{form::PortableForm, PortableRegistry, Type};
use crate::{utils::syn_type_path, TypegenError};
/// Collects the derives and attributes to put on generated types.
///
/// Entries are kept in three places: ones applied to every type, ones applied
/// to a single type path, and recursive ones that apply to a type path and to
/// every type reachable from it.
#[derive(Clone, Debug, Default)]
pub struct DerivesRegistry {
    /// Derives/attributes applied to every type.
    default_derives: Derives,
    /// Derives/attributes registered for exactly one type path.
    specific_type_derives: HashMap<syn::TypePath, Derives>,
    /// Derives/attributes registered for a type path and everything reachable from it.
    recursive_type_derives: HashMap<syn::TypePath, Derives>,
}
impl DerivesRegistry {
    /// Creates an empty registry with no default, specific or recursive entries.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `derives` to the set applied to every type.
    pub fn add_derives_for_all(&mut self, derives: impl IntoIterator<Item = syn::Path>) {
        self.default_derives.derives.extend(derives);
    }

    /// Adds `attributes` to the set applied to every type.
    pub fn add_attributes_for_all(&mut self, attributes: impl IntoIterator<Item = syn::Attribute>) {
        self.default_derives.attributes.extend(attributes);
    }

    /// Adds `derives` for the type at path `ty`.
    ///
    /// When `recursive` is `true` the entry is stored as a recursive entry and is
    /// expanded to every reachable type by [`DerivesRegistry::flatten_recursive_derives`].
    pub fn add_derives_for(
        &mut self,
        ty: syn::TypePath,
        derives: impl IntoIterator<Item = syn::Path>,
        recursive: bool,
    ) {
        self.derives_entry_mut(ty, recursive).derives.extend(derives);
    }

    /// Adds `attributes` for the type at path `ty`.
    ///
    /// When `recursive` is `true` the entry is stored as a recursive entry and is
    /// expanded to every reachable type by [`DerivesRegistry::flatten_recursive_derives`].
    pub fn add_attributes_for(
        &mut self,
        ty: syn::TypePath,
        attributes: impl IntoIterator<Item = syn::Attribute>,
        recursive: bool,
    ) {
        self.derives_entry_mut(ty, recursive)
            .attributes
            .extend(attributes);
    }

    /// Returns the derives/attributes applied to every type.
    pub fn default_derives(&self) -> &Derives {
        &self.default_derives
    }

    /// Expands every recursive entry into type-specific entries and returns the
    /// resulting [`FlatDerivesRegistry`].
    ///
    /// For each type in `types` whose path matches a recursive entry, that entry's
    /// derives and attributes are added to the type itself and to every type
    /// reachable from it (as computed by `collect_type_ids`). Only types with a
    /// non-empty path can receive derives. Recursive entries whose path matches
    /// no type in `types` are dropped.
    ///
    /// # Errors
    ///
    /// Returns an error if converting the path of any type in `types` into a
    /// `syn::TypePath` fails.
    ///
    /// # Panics
    ///
    /// Panics if a type reachable from a matched type refers to an id that
    /// `types` does not contain.
    pub fn flatten_recursive_derives(
        self,
        types: &PortableRegistry,
    ) -> Result<FlatDerivesRegistry, TypegenError> {
        let DerivesRegistry {
            default_derives,
            mut specific_type_derives,
            recursive_type_derives,
        } = self;

        if recursive_type_derives.is_empty() {
            return Ok(FlatDerivesRegistry {
                default_derives,
                specific_type_derives,
            });
        }

        // Type id -> path, for every type that has a non-empty path.
        let mut syn_path_for_id: HashMap<u32, syn::TypePath> = types
            .types
            .iter()
            .filter(|t| !t.ty.path.is_empty())
            .map(|t| syn_type_path(&t.ty).map(|path| (t.id, path)))
            .collect::<Result<_, TypegenError>>()?;

        // Derives to add per type id, accumulated over every matching recursive root.
        let mut add_derives_for_id: HashMap<u32, Derives> = HashMap::new();
        for ty in types.types.iter() {
            let Some(path) = syn_path_for_id.get(&ty.id) else {
                continue;
            };
            // Look the entry up instead of removing it: several registry entries can
            // share one path (e.g. instantiations of one generic type share a
            // scale-info `path`), and the types reachable from *each* of them must
            // receive the derives, not only those of the first match.
            let Some(recursive_derives) = recursive_type_derives.get(path) else {
                continue;
            };
            // The collected set also contains `ty.id` itself.
            let mut collected_type_ids: HashSet<u32> = HashSet::new();
            collect_type_ids(ty.id, types, &mut collected_type_ids);
            for id in collected_type_ids {
                add_derives_for_id
                    .entry(id)
                    .or_default()
                    .extend_from(recursive_derives.clone());
            }
        }

        // Merge into the type-specific map; ids without a path cannot be addressed and are skipped.
        for (id, derives_to_add) in add_derives_for_id {
            if let Some(path) = syn_path_for_id.remove(&id) {
                specific_type_derives
                    .entry(path)
                    .or_default()
                    .extend_from(derives_to_add);
            }
        }

        Ok(FlatDerivesRegistry {
            default_derives,
            specific_type_derives,
        })
    }

    /// Iterates over all type-specific entries: first the non-recursive ones, then
    /// the recursive ones. A path registered both ways is yielded twice.
    pub fn derives_on_specific_types(&self) -> impl Iterator<Item = (&syn::TypePath, &Derives)> {
        self.specific_type_derives
            .iter()
            .chain(self.recursive_type_derives.iter())
    }

    /// Returns the entry for `ty` in the recursive (`recursive == true`) or
    /// non-recursive map, inserting an empty [`Derives`] if none exists yet.
    fn derives_entry_mut(&mut self, ty: syn::TypePath, recursive: bool) -> &mut Derives {
        let map = if recursive {
            &mut self.recursive_type_derives
        } else {
            &mut self.specific_type_derives
        };
        map.entry(ty).or_default()
    }
}
/// A set of derive paths plus extra attributes to emit on one type.
#[derive(Clone, Debug, Default)]
pub struct Derives {
    /// Paths placed inside a single `#[derive(...)]`.
    derives: HashSet<syn::Path>,
    /// Additional attributes emitted as-is after the derive.
    attributes: HashSet<syn::Attribute>,
}
impl FromIterator<syn::Path> for Derives {
    /// Collects derive paths into a [`Derives`] that carries no extra attributes.
    fn from_iter<T: IntoIterator<Item = syn::Path>>(iter: T) -> Self {
        Self {
            attributes: HashSet::new(),
            derives: HashSet::from_iter(iter),
        }
    }
}
impl Derives {
    /// Creates an empty set of derives and attributes.
    pub fn new() -> Self {
        Self::default()
    }

    /// Moves all derives and attributes of `other` into `self`.
    pub fn extend_from(&mut self, other: Derives) {
        self.derives.extend(other.derives);
        self.attributes.extend(other.attributes);
    }

    /// Adds every derive path yielded by `derives`.
    ///
    /// Accepts any `IntoIterator` (a `Vec`, an array, an iterator, ...), matching
    /// the registry's `add_derives_*` methods; plain iterators still work.
    pub fn extend(&mut self, derives: impl IntoIterator<Item = syn::Path>) {
        self.derives.extend(derives);
    }

    /// Adds a single derive path.
    pub fn insert_derive(&mut self, derive: syn::Path) {
        self.derives.insert(derive);
    }

    /// Adds a single attribute.
    pub fn insert_attribute(&mut self, attribute: syn::Attribute) {
        self.attributes.insert(attribute);
    }

    /// Returns the derive paths.
    pub fn derives(&self) -> &HashSet<syn::Path> {
        &self.derives
    }

    /// Returns the extra attributes.
    pub fn attributes(&self) -> &HashSet<syn::Attribute> {
        &self.attributes
    }
}
impl ToTokens for Derives {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
if !self.derives.is_empty() {
let mut sorted = self.derives.iter().cloned().collect::<Vec<_>>();
sorted.sort_by(|a, b| {
quote::quote!(#a)
.to_string()
.cmp("e::quote!(#b).to_string())
});
tokens.extend(quote::quote! {
#[derive(#( #sorted ),*)]
})
}
if !self.attributes.is_empty() {
let mut sorted = self.attributes.iter().cloned().collect::<Vec<_>>();
sorted.sort_by(|a, b| {
quote::quote!(#a)
.to_string()
.cmp("e::quote!(#b).to_string())
});
tokens.extend(quote::quote! {
#( #sorted )*
})
}
}
}
/// A derives registry with no recursive entries left: every entry applies either
/// to all types or to one specific type path.
#[derive(Clone, Debug, Default)]
pub struct FlatDerivesRegistry {
    /// Derives/attributes applied to every type.
    default_derives: Derives,
    /// Derives/attributes registered for exactly one type path.
    specific_type_derives: HashMap<syn::TypePath, Derives>,
}
impl FlatDerivesRegistry {
    /// Returns the derives for `ty`: the defaults, merged with the entry
    /// registered specifically for `ty` if there is one.
    pub fn resolve(&self, ty: &syn::TypePath) -> Derives {
        let defaults = self.default_derives.clone();
        let Some(specific) = self.specific_type_derives.get(ty) else {
            return defaults;
        };
        let mut merged = defaults;
        merged.extend_from(specific.clone());
        merged
    }

    /// Converts the scale-info type `ty` to a `syn::TypePath` and resolves its
    /// derives with [`FlatDerivesRegistry::resolve`].
    ///
    /// # Errors
    ///
    /// Returns the error produced by the path conversion.
    pub fn resolve_derives_for_type(
        &self,
        ty: &Type<PortableForm>,
    ) -> Result<Derives, TypegenError> {
        let type_path = syn_type_path(ty)?;
        Ok(self.resolve(&type_path))
    }
}
/// Inserts `id` and the id of every type reachable from it into `collected_types`.
///
/// Reachable means: generic type parameters, field types of composites and enum
/// variants, element types of sequences, arrays and compacts, and tuple members.
/// Ids already present in `collected_types` are not traversed again, so
/// self-referential types terminate.
///
/// # Panics
///
/// Panics if a reachable id is not present in `types`.
fn collect_type_ids(id: u32, types: &PortableRegistry, collected_types: &mut HashSet<u32>) {
    // Explicit work stack instead of recursion; the visited set is `collected_types`.
    let mut pending: Vec<u32> = vec![id];
    while let Some(current) = pending.pop() {
        // `insert` returns false for ids that were already collected.
        if !collected_types.insert(current) {
            continue;
        }
        let ty = types
            .resolve(current)
            .expect("Should contain this id, if Registry not corrupted");

        pending.extend(
            ty.type_params
                .iter()
                .filter_map(|param| param.ty.map(|symbol| symbol.id)),
        );

        match &ty.type_def {
            scale_info::TypeDef::Composite(def) => {
                pending.extend(def.fields.iter().map(|field| field.ty.id));
            }
            scale_info::TypeDef::Variant(def) => {
                pending.extend(
                    def.variants
                        .iter()
                        .flat_map(|variant| variant.fields.iter().map(|field| field.ty.id)),
                );
            }
            scale_info::TypeDef::Sequence(def) => pending.push(def.type_param.id),
            scale_info::TypeDef::Array(def) => pending.push(def.type_param.id),
            scale_info::TypeDef::Tuple(def) => {
                pending.extend(def.fields.iter().map(|member| member.id));
            }
            scale_info::TypeDef::Compact(def) => pending.push(def.type_param.id),
            scale_info::TypeDef::Primitive(_) | scale_info::TypeDef::BitSequence(_) => {}
        }
    }
}