use std::{
collections::{HashMap, HashSet},
rc::Rc,
};
use proc_macro2::Span;
use quote::quote;
use syn::{
spanned::Spanned, GenericArgument, Generics, Ident, PathArguments, TraitBound, Type,
TypeParamBound, TypePath, TypeReference, TypeTraitObject, WherePredicate,
};
use crate::{
filter_map_assoc_paths, match_assoc_type,
parse_assoc_type::{BoxType, DestType},
parse_attrs::Convert,
syn_utils::{iter_path, iter_type, type_arguments_mut},
trait_sig::{MethodError, TypeTransform},
};
/// State used to rewrite types that mention associated types of the trait `trait_ident`.
///
/// The rewriting itself is done by [`TypeConverter::convert_type`].
pub struct TypeConverter {
    /// Replacement for each associated type, keyed by the associated type's identifier
    /// (the `Assoc` in `Self::Assoc` / `<Self as Trait>::Assoc`).
    pub assoc_type_conversions: HashMap<Ident, DestType>,
    /// Extra collection types, mapped to how many of their leading type arguments may
    /// contain associated types. Entries here take precedence over the built-in list of
    /// standard collections.
    pub collections: HashMap<Ident, usize>,
    /// Identifier of the trait being processed; `<Self as X>::Assoc` is only accepted when
    /// `X` is this identifier.
    pub trait_ident: Ident,
    /// Types that are replaced wholesale; they are looked up before any associated-type
    /// analysis takes place.
    pub type_conversions: HashMap<Type, Rc<Convert>>,
    /// The `original_type` of every `Convert` from `type_conversions` that `convert_type`
    /// has applied so far.
    pub used_conversions: HashSet<Type>,
}
/// Reasons why [`TypeConverter::convert_type`] could not rewrite a type.
#[derive(Debug)]
pub enum TransformError {
    /// An associated type was referenced that has no entry in
    /// `TypeConverter::assoc_type_conversions`.
    AssocTypeWithoutDestType,
    /// The type mentions an associated type, but none of the supported shapes matched.
    UnsupportedType,
    /// A collection type was given fewer type arguments than the contained count requires.
    ExpectedAtLeastNTypes(usize),
    /// An associated type appeared after the first `.0` type arguments of the collection
    /// named by `.1`.
    AssocTypeAfterFirstNTypes(usize, Ident),
    /// A qualified path whose self type starts with `Self` but is not exactly `Self`
    /// (for example `<Self::A as X>::B`).
    QualifiedAssociatedType,
    /// A `<Self as ..>::..` path that is not of the form `<Self as Trait>::Assoc` for the
    /// trait being processed.
    SelfQualifiedAsOtherTrait,
}
impl TypeConverter {
    /// Returns how many leading type arguments of the single-identifier type `ident` may
    /// contain associated types, or `None` if `ident` is not treated as a collection.
    ///
    /// Entries in `self.collections` take precedence. Otherwise a fixed list of standard
    /// library collections is recognized: one type argument for the sequence/set types and
    /// two for `HashMap`/`BTreeMap`.
    #[rustfmt::skip]
    fn get_collection_type_count(&self, ident: &Ident) -> Option<usize> {
        // User-declared collections override the built-in list below.
        if let Some(count) = self.collections.get(ident) {
            return Some(*count);
        }
        if ident == "Vec" { return Some(1); }
        if ident == "VecDeque" { return Some(1); }
        if ident == "LinkedList" { return Some(1); }
        if ident == "HashSet" { return Some(1); }
        if ident == "BinaryHeap" { return Some(1); }
        if ident == "BTreeSet" { return Some(1); }
        if ident == "HashMap" { return Some(2); }
        if ident == "BTreeMap" { return Some(2); }
        None
    }

    /// Rewrites `type_` in place so that it no longer mentions the trait's associated types,
    /// returning a [`TypeTransform`] that describes the rewrite.
    ///
    /// The following shapes are recognized, tried in this order:
    ///
    /// 1. A type found in `type_conversions`. It is replaced by that entry's `dest_type`,
    ///    the entry's `original_type` is recorded in `used_conversions`, and
    ///    `TypeTransform::Verbatim` is returned.
    /// 2. A type that mentions no associated type. It is left unchanged (`TypeTransform::NoOp`).
    /// 3. A tuple. Each element is converted recursively (`TypeTransform::Tuple`).
    /// 4. `&mut dyn Iterator<Item = T>` with no explicit lifetime and exactly one bound, where
    ///    `T` mentions an associated type. `T` is converted, and the whole type is replaced by
    ///    the tokens of a `BoxType` built around the trait object (`TypeTransform::Iterator`).
    /// 5. `<Self as Trait>::Assoc` (for this trait only) or `Self::Assoc`. It is replaced by the
    ///    `DestType` registered for `Assoc` in `assoc_type_conversions`.
    /// 6. A single-segment collection type (see `get_collection_type_count`). Every type
    ///    argument is converted (`TypeTransform::IntoIterMapCollect`).
    /// 7. `Option<T>` (single segment), or any path ending in `Result` with exactly one type
    ///    argument, where `T` mentions an associated type. `T` is converted
    ///    (`TypeTransform::Map`).
    /// 8. `Result<T, E>` (single segment) where `T` or `E` mentions an associated type. Both
    ///    are converted (`TypeTransform::Result`).
    ///
    /// Nested conversions are applied in place as they succeed. If an error occurs, `type_`
    /// may therefore already be partially rewritten.
    ///
    /// # Errors
    ///
    /// Returns the offending span together with one of the following:
    /// * `AssocTypeWithoutDestType` if an associated type has no entry in
    ///   `assoc_type_conversions`.
    /// * `QualifiedAssociatedType` for a qualified self type that starts with `Self` but is
    ///   not exactly `Self`, e.g. `<Self::A as X>::B`.
    /// * `SelfQualifiedAsOtherTrait` for any other `<Self as ..>::..` shape.
    /// * `ExpectedAtLeastNTypes` / `AssocTypeAfterFirstNTypes` for a collection with too few
    ///   type arguments, or with associated types after the first N.
    /// * `UnsupportedType` for any other type that mentions an associated type.
    /// * Any error produced by a nested conversion.
    pub fn convert_type(
        &mut self,
        type_: &mut Type,
    ) -> Result<TypeTransform, (Span, TransformError)> {
        // Types listed in `type_conversions` are replaced wholesale, before anything else.
        if let Some(conv) = self.type_conversions.get(type_) {
            self.used_conversions.insert(conv.original_type.clone());
            *type_ = conv.dest_type.clone();
            return Ok(TypeTransform::Verbatim(conv.clone()));
        }
        // Types that do not mention an associated type need no rewriting.
        if !iter_type(type_).any(match_assoc_type) {
            return Ok(TypeTransform::NoOp);
        }
        if let Type::Tuple(tuple) = type_ {
            // Tuples are converted element by element.
            let mut types = Vec::new();
            for elem in &mut tuple.elems {
                types.push(self.convert_type(elem)?);
            }
            return Ok(TypeTransform::Tuple(types));
        } else if let Type::Reference(TypeReference {
            lifetime: None,
            mutability: Some(_),
            elem,
            ..
        }) = type_
        {
            // Only `&mut dyn Iterator<Item = T>` (single bound, single-segment path, one
            // generic argument that is the `Item` binding) is supported here.
            if let Type::TraitObject(TypeTraitObject {
                dyn_token: Some(_),
                bounds,
            }) = elem.as_mut()
            {
                if bounds.len() == 1 {
                    if let TypeParamBound::Trait(bound) = &mut bounds[0] {
                        if bound.path.segments.len() == 1 {
                            let first = &mut bound.path.segments[0];
                            if first.ident == "Iterator" {
                                if let PathArguments::AngleBracketed(args) = &mut first.arguments {
                                    if args.args.len() == 1 {
                                        if let GenericArgument::Binding(binding) = &mut args.args[0]
                                        {
                                            if binding.ident == "Item"
                                                && iter_type(&binding.ty).any(match_assoc_type)
                                            {
                                                let inner = self.convert_type(&mut binding.ty)?;
                                                // `elem` now contains the converted `Item`
                                                // type, so the box wraps the rewritten
                                                // trait object.
                                                let box_type = BoxType {
                                                    inner: quote! {#elem},
                                                    placeholder_lifetime: true,
                                                };
                                                *type_ = Type::Verbatim(quote! {#box_type});
                                                return Ok(TypeTransform::Iterator(
                                                    box_type,
                                                    inner.into(),
                                                ));
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
            // Any other reference is not a path either, so it ends up as `UnsupportedType`.
        }
        if let Type::Path(TypePath {
            path,
            qself: Some(qself),
        }) = type_
        {
            // Qualified path, e.g. `<Self as Trait>::Assoc`.
            if let Type::Path(self_path) = qself.ty.as_ref() {
                if self_path.path.segments[0].ident == "Self" {
                    // Reject self types such as `Self::A` in `<Self::A as X>::B`.
                    if !self_path.path.is_ident("Self") {
                        return Err((self_path.span(), TransformError::QualifiedAssociatedType));
                    }
                    // Accept exactly `<Self as Trait>::Assoc`, where `Trait` is the trait
                    // being processed and carries no generic arguments.
                    if qself.position == 1
                        && path.segments.len() == 2
                        && path.segments[0].arguments.is_empty()
                        && path.segments[0].ident == self.trait_ident
                    {
                        let ident = &path.segments[1].ident;
                        let dest_type =
                            self.assoc_type_conversions.get(ident).ok_or_else(|| {
                                (ident.span(), TransformError::AssocTypeWithoutDestType)
                            })?;
                        *type_ = dest_type.get_dest();
                        return Ok(dest_type.type_transformation());
                    }
                    return Err((type_.span(), TransformError::SelfQualifiedAsOtherTrait));
                }
            }
        } else if let Type::Path(TypePath { path, qself: None }) = type_ {
            if path.segments[0].ident == "Self" {
                // Unqualified `Self::Assoc`. Longer paths such as `Self::A::B` fall through
                // to `UnsupportedType`.
                if path.segments.len() == 2 {
                    let ident = &path.segments.last().unwrap().ident;
                    let dest_type = self
                        .assoc_type_conversions
                        .get(ident)
                        .ok_or_else(|| (ident.span(), TransformError::AssocTypeWithoutDestType))?;
                    *type_ = dest_type.get_dest();
                    return Ok(dest_type.type_transformation());
                }
            } else {
                // Generic type whose type arguments mention associated types, e.g.
                // `Vec<Self::A>`, `Option<Self::A>` or `Result<Self::A, E>`.
                let path_len = path.segments.len();
                let last_seg = path.segments.last_mut().unwrap();
                if let PathArguments::AngleBracketed(args) = &mut last_seg.arguments {
                    let mut args: Vec<_> = type_arguments_mut(&mut args.args).collect();
                    if path_len == 1 {
                        if let Some(type_count) = self.get_collection_type_count(&last_seg.ident) {
                            if args.len() < type_count {
                                return Err((
                                    last_seg.span(),
                                    TransformError::ExpectedAtLeastNTypes(type_count),
                                ));
                            }
                            // Associated types are only allowed in the first `type_count`
                            // type arguments. Later ones, such as `HashMap`'s hasher
                            // parameter, must not mention them.
                            for ty in args.iter().skip(type_count) {
                                if iter_type(ty).any(match_assoc_type) {
                                    return Err((
                                        ty.span(),
                                        TransformError::AssocTypeAfterFirstNTypes(
                                            type_count,
                                            last_seg.ident.clone(),
                                        ),
                                    ));
                                }
                            }
                            // Every type argument is passed through `convert_type`, including
                            // those past `type_count`: they may still hit `type_conversions`.
                            let mut transforms = Vec::new();
                            for arg in args {
                                transforms.push(self.convert_type(arg)?);
                            }
                            return Ok(TypeTransform::IntoIterMapCollect(transforms));
                        }
                    }
                    if args.len() == 1 {
                        // `Option` must be written unqualified. Any path ending in `Result`
                        // with a single type argument (e.g. `io::Result<T>`) is accepted.
                        if iter_type(args[0]).any(match_assoc_type)
                            && ((last_seg.ident == "Option" && path_len == 1)
                                || last_seg.ident == "Result")
                        {
                            return Ok(TypeTransform::Map(self.convert_type(args[0])?.into()));
                        }
                    } else if args.len() == 2
                        && path_len == 1
                        && (iter_type(args[0]).any(match_assoc_type)
                            || iter_type(args[1]).any(match_assoc_type))
                        && last_seg.ident == "Result"
                    {
                        return Ok(TypeTransform::Result(
                            self.convert_type(args[0])?.into(),
                            self.convert_type(args[1])?.into(),
                        ));
                    }
                }
            }
        }
        Err((type_.span(), TransformError::UnsupportedType))
    }
}
pub fn dynamize_function_bounds(
generics: &mut Generics,
type_converter: &mut TypeConverter,
) -> Result<HashMap<Ident, Vec<TypeTransform>>, (Span, MethodError)> {
let mut type_param_transforms = HashMap::new();
for type_param in generics.type_params_mut() {
for bound in &mut type_param.bounds {
if let TypeParamBound::Trait(bound) = bound {
dynamize_trait_bound(
bound,
type_converter,
&type_param.ident,
&mut type_param_transforms,
)?;
}
}
}
if let Some(where_clause) = &mut generics.where_clause {
for predicate in &mut where_clause.predicates {
if let WherePredicate::Type(predicate_type) = predicate {
if let Type::Path(path) = &mut predicate_type.bounded_ty {
if let Some(ident) = path.path.get_ident() {
for bound in &mut predicate_type.bounds {
if let TypeParamBound::Trait(bound) = bound {
dynamize_trait_bound(
bound,
type_converter,
ident,
&mut type_param_transforms,
)?;
}
}
continue;
}
}
if let Some(assoc_type) =
iter_type(&predicate_type.bounded_ty).find_map(filter_map_assoc_paths)
{
return Err((assoc_type.span(), MethodError::UnconvertedAssocType));
}
for bound in &mut predicate_type.bounds {
if let TypeParamBound::Trait(bound) = bound {
if let Some(assoc_type) =
iter_path(&bound.path).find_map(filter_map_assoc_paths)
{
return Err((assoc_type.span(), MethodError::UnconvertedAssocType));
}
}
}
}
}
}
Ok(type_param_transforms)
}
/// Converts associated types in the argument list of an `Fn`/`FnMut`/`FnOnce` bound.
///
/// When `bound` is a single-segment `Fn(..)`, `FnMut(..)` or `FnOnce(..)` bound, every input
/// type is passed through [`TypeConverter::convert_type`], which rewrites it in place. If at
/// least one input needed a conversion (any transform other than `TypeTransform::NoOp`),
/// the per-input transforms are stored in `type_param_transforms` under `type_ident`. Any
/// previous entry for that identifier is replaced.
///
/// # Errors
///
/// * Any error from `convert_type`, converted into a [`MethodError`].
/// * `MethodError::UnconvertedAssocType` if the bound's path still refers to an associated
///   type afterwards. This covers, for example, the return type of an `Fn` bound, or any
///   bound that is not one of the `Fn*` traits.
fn dynamize_trait_bound(
    bound: &mut TraitBound,
    type_converter: &mut TypeConverter,
    type_ident: &Ident,
    type_param_transforms: &mut HashMap<Ident, Vec<TypeTransform>>,
) -> Result<(), (Span, MethodError)> {
    if bound.path.segments.len() == 1 {
        let segment = &mut bound.path.segments[0];
        if let PathArguments::Parenthesized(args) = &mut segment.arguments {
            if segment.ident == "Fn" || segment.ident == "FnOnce" || segment.ident == "FnMut" {
                let mut transforms = Vec::with_capacity(args.inputs.len());
                for input_type in &mut args.inputs {
                    let transform = type_converter.convert_type(input_type).map_err(
                        |(span, err)| -> (Span, MethodError) { (span, err.into()) },
                    )?;
                    transforms.push(transform);
                }
                // Only record type parameters whose signature actually changed.
                if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) {
                    type_param_transforms.insert(type_ident.clone(), transforms);
                }
            }
        }
    }
    // Whatever still mentions an associated type at this point cannot be dynamized.
    if let Some(path) = iter_path(&bound.path).find_map(filter_map_assoc_paths) {
        return Err((path.span(), MethodError::UnconvertedAssocType));
    }
    Ok(())
}