use crate::fold::TypeReplacer;
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote, ToTokens};
use std::collections::HashMap;
use syn::{
fold::Fold,
spanned::Spanned,
Expr,
GenericArgument,
Lit,
LitInt,
PathArguments,
Type,
TypeArray,
TypePath,
TypeTuple,
};
/// Generates the token stream that walks a value of a given syntactic type and
/// routes every sub-value whose type mentions a tracked generic parameter to a
/// per-type helper method (`self.<fn_name>(value)`).
///
/// Implemented below for `Type` (a dispatcher) and for path, tuple and array
/// types. Every implementation returns `None` when the type shape is not
/// supported.
pub trait ExprDefaultImpl {
/// Builds the expression for the value `val_name` whose type is `self`.
///
/// * `generic_finder` — keyed by generic-parameter idents. In the `TypePath`
///   impl, a nested path whose first segment is a key is dispatched to a helper
///   method. It is also handed to `TypeReplacer` to decide whether a path type
///   mentions tracked generics.
/// * `fn_name_by_type` — name of the helper method to call for each generic
///   type; a missing entry makes the generation fail with `None`.
/// * `has_no_generics` — when `true`, the value is treated as generic-free and
///   passed through unchanged (see the `TypePath` impl).
/// * `top_level` — the `TypePath` impl only dispatches a generic path to a
///   helper method when this is `false`; tuple and array impls ignore it and
///   recurse with `false`.
/// * `is_apply` — `true` produces an expression yielding a transformed value
///   (e.g. `into_iter().map(..).collect()` for `Vec`); `false` produces code
///   that visits the value by reference (e.g. `iter().for_each(..)`).
/// * `val_name` — tokens of the expression that holds the value.
fn expr_default_impl(
&self,
generic_finder: &HashMap<Ident, Ident>,
fn_name_by_type: &HashMap<Type, Ident>,
has_no_generics: bool,
top_level: bool,
is_apply: bool,
val_name: TokenStream,
) -> Option<TokenStream>;
}
impl ExprDefaultImpl for TypePath {
    /// Generates code for a path type: a concrete type, a generic parameter, or
    /// `Vec<E>`.
    ///
    /// * A path without tracked generics is passed through unchanged:
    ///   `val_name` itself in apply mode, no tokens at all in visit mode.
    /// * A nested (`top_level == false`) path whose first segment is a key of
    ///   `generic_finder` becomes `self.<fn>(val_name)`, where `<fn>` is looked
    ///   up in `fn_name_by_type`. A missing entry yields `None`.
    /// * `Vec<E>` recurses into `E` with the element bound to `x`. Apply mode
    ///   rebuilds the vector with `into_iter().map(..).collect()`; visit mode
    ///   walks it with `iter().for_each(..)`.
    /// * Anything else, including an empty path or `Vec<>` without a type
    ///   argument, is unsupported and yields `None` instead of panicking.
    fn expr_default_impl(
        &self,
        generic_finder: &HashMap<Ident, Ident>,
        fn_name_by_type: &HashMap<Type, Ident>,
        has_no_generics: bool,
        top_level: bool,
        is_apply: bool,
        val_name: TokenStream,
    ) -> Option<TokenStream> {
        // `TypeReplacer` is presumably rewriting the tracked generic parameters;
        // if folding leaves the path unchanged, it is treated as generic-free.
        let has_no_generics = has_no_generics || {
            let mut replacer = TypeReplacer::new(generic_finder);
            let replaced = replacer.fold_type_path(self.clone());
            *self == replaced
        };
        if has_no_generics {
            return Some(if is_apply {
                quote! { #val_name }
            } else {
                quote! {}
            });
        }

        // `.first()` rather than `[0]`: an empty path is reported as
        // unsupported instead of panicking inside the macro.
        let first = self.path.segments.first()?;
        if !top_level && generic_finder.contains_key(&first.ident) {
            let fn_name = fn_name_by_type.get(&Type::Path(self.clone()))?;
            return Some(quote! {
                self.#fn_name(#val_name)
            });
        }
        if first.ident != "Vec" {
            return None;
        }

        let PathArguments::AngleBracketed(generics) = &first.arguments else {
            return None;
        };
        // `Vec<>` (no arguments) or a non-type first argument is unsupported.
        let Some(GenericArgument::Type(elem_ty)) = generics.args.first() else {
            return None;
        };
        let loop_var = format_ident!("x");
        let inner = elem_ty.expr_default_impl(
            generic_finder,
            fn_name_by_type,
            has_no_generics,
            false,
            is_apply,
            loop_var.to_token_stream(),
        )?;
        if is_apply {
            Some(quote! {
                #val_name.into_iter().map(|#loop_var| {
                    #inner
                }).collect()
            })
        } else {
            Some(quote! {
                #val_name.iter().for_each(|#loop_var| {
                    #inner;
                })
            })
        }
    }
}
impl ExprDefaultImpl for Type {
    /// Dispatches to the implementation for the concrete syntax variant.
    ///
    /// Path types keep the caller's `top_level` flag. Tuple and array types
    /// always recurse as nested (`top_level = false`). Any other variant is
    /// printed to stdout and reported as unsupported with `None`.
    fn expr_default_impl(
        &self,
        generic_finder: &HashMap<Ident, Ident>,
        fn_name_by_type: &HashMap<Type, Ident>,
        has_no_generics: bool,
        top_level: bool,
        is_apply: bool,
        val_name: TokenStream,
    ) -> Option<TokenStream> {
        // Choose the handler and the flag it receives first; the actual call is
        // identical for every supported variant.
        let (handler, child_top_level) = match self {
            Type::Path(path_ty) => (path_ty as &dyn ExprDefaultImpl, top_level),
            Type::Tuple(tuple_ty) => (tuple_ty as &dyn ExprDefaultImpl, false),
            Type::Array(array_ty) => (array_ty as &dyn ExprDefaultImpl, false),
            y => {
                println!("unknown type: {y:?}");
                return None;
            }
        };
        handler.expr_default_impl(
            generic_finder,
            fn_name_by_type,
            has_no_generics,
            child_top_level,
            is_apply,
            val_name,
        )
    }
}
impl ExprDefaultImpl for TypeTuple {
    /// Generates code for a tuple type by recursing into each element through
    /// field access (`val.0`, `val.1`, ...). Elements always recurse as nested
    /// (`top_level = false`); the incoming `top_level` is ignored.
    ///
    /// * Apply mode builds a new tuple `(e0, e1, ...,)`. A trailing comma
    ///   follows every element, so a 1-tuple `(T,)` stays a tuple instead of
    ///   collapsing into a parenthesized expression. `()` maps to `()`.
    /// * Visit mode passes `&val.i` to each element and emits a block
    ///   `{ e0; e1; }`. Elements that need no work produce empty tokens and
    ///   are skipped; comma-joining them would yield invalid syntax such as
    ///   `(, e1)`.
    ///
    /// Returns `None` if any element type is unsupported.
    fn expr_default_impl(
        &self,
        generic_finder: &HashMap<Ident, Ident>,
        fn_name_by_type: &HashMap<Type, Ident>,
        has_no_generics: bool,
        _top_level: bool,
        is_apply: bool,
        val_name: TokenStream,
    ) -> Option<TokenStream> {
        let parts = self
            .elems
            .iter()
            .enumerate()
            .map(|(i, elem)| {
                let idx = LitInt::new(&i.to_string(), val_name.span());
                let field = if is_apply {
                    quote! { #val_name.#idx }
                } else {
                    quote! { &#val_name.#idx }
                };
                elem.expr_default_impl(
                    generic_finder,
                    fn_name_by_type,
                    has_no_generics,
                    false,
                    is_apply,
                    field,
                )
            })
            .collect::<Option<Vec<TokenStream>>>()?;

        if is_apply {
            Some(quote! {
                ( #(#parts,)* )
            })
        } else {
            let stmts: Vec<&TokenStream> = parts.iter().filter(|p| !p.is_empty()).collect();
            Some(quote! {
                { #(#stmts;)* }
            })
        }
    }
}
impl ExprDefaultImpl for TypeArray {
    /// Generates code for a fixed-size array type `[E; LEN]`.
    ///
    /// The element conversion is generated once, with the element bound to `x`:
    ///
    /// * Apply mode produces `val.map(|x| { .. })`. Array `map` moves each
    ///   element out, so non-`Copy` elements work. The previous `[val[0], ..]`
    ///   form fails with E0508 for them.
    /// * Visit mode produces `val.iter().for_each(|x| { ..; })` as an
    ///   expression with no trailing `;`, like the `Vec` branch of the
    ///   `TypePath` impl. This lets it nest inside tuples.
    ///
    /// Returns `None` when `LEN` is not an integer literal or when the element
    /// type is unsupported. Previously, visit mode silently emitted an empty
    /// loop for an unsupported element type.
    fn expr_default_impl(
        &self,
        generic_finder: &HashMap<Ident, Ident>,
        fn_name_by_type: &HashMap<Type, Ident>,
        has_no_generics: bool,
        _top_level: bool,
        is_apply: bool,
        val_name: TokenStream,
    ) -> Option<TokenStream> {
        // Keep the existing contract: only arrays whose length is an integer
        // literal are supported, even though the generated code no longer
        // needs the length itself.
        let Expr::Lit(len) = &self.len else { None? };
        let Lit::Int(len) = &len.lit else { None? };
        len.base10_parse::<usize>().ok()?;

        let elem_var = quote! { x };
        let inner = self.elem.expr_default_impl(
            generic_finder,
            fn_name_by_type,
            has_no_generics,
            false,
            is_apply,
            elem_var.clone(),
        )?;
        if is_apply {
            Some(quote! {
                #val_name.map(|#elem_var| {
                    #inner
                })
            })
        } else {
            Some(quote! {
                #val_name.iter().for_each(|#elem_var| {
                    #inner;
                })
            })
        }
    }
}