use darling::ast::NestedMeta;
use darling::FromMeta;
use derive_syn_parse::Parse;
use proc_macro::TokenStream as TokenStream1;
use proc_macro2::Span;
use proc_macro2::TokenStream;
use proc_macro_error::{abort, proc_macro_error};
use std::collections::{HashMap, HashSet};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::*;
use template_quote::{quote, ToTokens};
use type_leak::*;
mod sumtrait_internal;
/// Returns a pseudo-random `u64` derived from a freshly seeded `RandomState`.
///
/// Only used to build unique identifier suffixes for generated items; it is not
/// meant to be cryptographically strong.
fn random() -> u64 {
    use std::hash::{BuildHasher, Hasher};
    let state = std::collections::hash_map::RandomState::new();
    // Finishing an empty hasher yields a value that depends only on the random keys.
    let hasher = state.build_hasher();
    hasher.finish()
}
/// Converts a generic parameter declaration into the argument that refers to it:
/// `'a: 'b` becomes `'a`, `T: Bound` becomes `T`, and `const N: usize` becomes `N`.
fn generic_param_to_arg(i: GenericParam) -> GenericArgument {
    match i {
        GenericParam::Lifetime(param) => GenericArgument::Lifetime(param.lifetime),
        GenericParam::Type(param) => {
            let name = param.ident;
            GenericArgument::Type(parse_quote!(#name))
        }
        GenericParam::Const(param) => {
            let name = param.ident;
            GenericArgument::Const(parse_quote!(#name))
        }
    }
}
/// Concatenates two generic parameter lists, grouped by kind.
///
/// All lifetimes come first, then const parameters, then type parameters. Within
/// each group, parameters from `args1` precede those from `args2`, and the relative
/// order inside each input is preserved.
fn merge_generic_params(
    args1: impl IntoIterator<Item = GenericParam, IntoIter: Clone>,
    args2: impl IntoIterator<Item = GenericParam, IntoIter: Clone>,
) -> impl Iterator<Item = GenericParam> {
    // Position of each parameter kind in the merged output.
    fn rank(param: &GenericParam) -> u8 {
        match param {
            GenericParam::Lifetime(_) => 0,
            GenericParam::Const(_) => 1,
            GenericParam::Type(_) => 2,
        }
    }
    let first = args1.into_iter();
    let second = args2.into_iter();
    (0u8..3).flat_map(move |group| {
        first
            .clone()
            .filter(move |param| rank(param) == group)
            .chain(second.clone().filter(move |param| rank(param) == group))
    })
}
/// Concatenates two generic argument lists, grouped by kind.
///
/// The output order is: lifetimes, consts, types, then associated items
/// (`AssocType`, `AssocConst`, `Constraint`). Within each group, arguments from
/// `args1` precede those from `args2`. Arguments of any other kind are dropped.
fn merge_generic_args(
    args1: impl IntoIterator<Item = GenericArgument, IntoIter: Clone>,
    args2: impl IntoIterator<Item = GenericArgument, IntoIter: Clone>,
) -> impl Iterator<Item = GenericArgument> {
    // Position of each argument kind in the merged output; `None` means "drop".
    fn rank(arg: &GenericArgument) -> Option<u8> {
        match arg {
            GenericArgument::Lifetime(_) => Some(0),
            GenericArgument::Const(_) => Some(1),
            GenericArgument::Type(_) => Some(2),
            GenericArgument::AssocType(_)
            | GenericArgument::AssocConst(_)
            | GenericArgument::Constraint(_) => Some(3),
            _ => None,
        }
    }
    let first = args1.into_iter();
    let second = args2.into_iter();
    (0u8..4).flat_map(move |group| {
        first
            .clone()
            .filter(move |arg| rank(arg) == Some(group))
            .chain(second.clone().filter(move |arg| rank(arg) == Some(group)))
    })
}
/// Builds the path `ident`, or `super::ident` when `is_super` is set.
///
/// `is_super` is used when the generated items are emitted outside the module being
/// processed, so code inside the module must reach them through `super::`.
fn path_of_ident(ident: Ident, is_super: bool) -> Path {
    let prefix = if is_super {
        Some(Ident::new("super", Span::call_site()))
    } else {
        None
    };
    Path {
        leading_colon: None,
        segments: prefix
            .into_iter()
            .chain(std::iter::once(ident))
            .map(|segment_ident| PathSegment {
                ident: segment_ident,
                arguments: PathArguments::None,
            })
            .collect(),
    }
}
/// Splits optional generics into owned pieces.
///
/// Returns the declared parameters, the matching type arguments, and the `where`
/// predicates. `None` yields three empty vectors.
fn split_for_impl(
    generics: Option<&Generics>,
) -> (Vec<GenericParam>, Vec<GenericArgument>, Vec<WherePredicate>) {
    let Some(generics) = generics else {
        return (vec![], vec![], vec![]);
    };
    let (_, ty_generics, where_clause) = generics.split_for_impl();
    // `TypeGenerics` prints as `<..>` (or nothing when there are no parameters).
    // Re-parse it to get the individual arguments; a failed parse means "none".
    let args: Vec<GenericArgument> =
        match parse2::<AngleBracketedGenericArguments>(ty_generics.into_token_stream()) {
            Ok(bracketed) => bracketed.args.into_iter().collect(),
            Err(_) => Vec::new(),
        };
    let predicates: Vec<WherePredicate> = match where_clause {
        Some(clause) => clause.predicates.iter().cloned().collect(),
        None => Vec::new(),
    };
    let params: Vec<GenericParam> = generics.params.iter().cloned().collect();
    (params, args, predicates)
}
/// Arguments of the `#[sumtype(..)]` attribute: a `+`-separated list of paths.
#[derive(Parse)]
struct Arguments {
/// Trait paths to implement for the generated enum. Each path is later invoked as a
/// macro (`path!(..)`, see `SumTypeImpl::gen`). `parse_terminated` accepts an empty
/// list and a trailing `+`.
#[call(Punctuated::parse_terminated)]
bounds: Punctuated<Path, Token![+]>,
}
/// Kind of implementation generated for the sum-type enum.
enum SumTypeImpl {
    /// Implement the trait at this path by invoking `path!(..)` with a description
    /// of the enum.
    Trait(Path),
}
impl SumTypeImpl {
    /// Emits the macro invocation that implements the trait for the enum.
    ///
    /// The invocation forwards, in order:
    /// 1. the constraint trait identifier,
    /// 2. the trait path,
    /// 3. the enum path,
    /// 4. the type parameters of variants without an explicit type,
    /// 5. the `variant: Type` list,
    /// 6. the impl generics,
    /// 7. the type arguments,
    /// 8. the `where` predicates.
    #[allow(clippy::too_many_arguments)]
    fn gen(
        &self,
        enum_path: &Path,
        unspecified_ty_params: &[Ident],
        variants: &[(Ident, Type)],
        impl_generics: Vec<GenericParam>,
        ty_generics: Vec<GenericArgument>,
        where_clause: Vec<WherePredicate>,
        constraint_expr_trait_ident: &Ident,
    ) -> TokenStream {
        // There is only one variant, so this pattern is irrefutable.
        let SumTypeImpl::Trait(trait_path) = self;
        quote! {
            #trait_path!(
                #constraint_expr_trait_ident,
                #trait_path,
                #enum_path,
                [#(#unspecified_ty_params),*],
                [#(for (id, ty) in variants),{#id:#ty}],
                [ #(#impl_generics),* ],
                [#(#ty_generics),*],
                { #(#where_clause),* },
            );
        }
    }
}
/// Data recorded for each expression-position `sumtype!(..)` invocation.
struct ExprMacroInfo {
/// Span of the macro invocation; used for error reporting.
span: Span,
/// Name of the enum variant that wraps this invocation's expression.
variant_ident: Ident,
/// Marker struct naming the explicitly given type, or `None` when the invocation
/// had no type (its variant then gets a fresh enum type parameter).
reftype_ident: Option<Ident>,
/// Lifetimes that each enclosing type parameter was found behind in the given type;
/// these are turned into `T: 'a` bounds.
analyzed_bounds: HashMap<Ident, HashSet<Lifetime>>,
/// Generics introduced by the invocation's `for<..>` list and `where` clause.
generics: Generics,
}
/// Data recorded for each type-position `sumtype!(..)` invocation.
struct TypeMacroInfo {
/// Span of the invocation (currently unused).
_span: Span,
/// Generic arguments written inside the macro. They are checked kind-by-kind against
/// the expression generics in `emit_items`.
generic_args: Punctuated<GenericArgument, Token![,]>,
}
/// A syntax tree that may contain inline `sumtype!` invocations.
///
/// Implementors only provide the traversal (`collect_inline_macro`). `emit_items`
/// builds the enum and its helper items from what the traversal found.
trait ProcessTree: Sized {
    /// Rewrites every `sumtype!` invocation inside `self` in place.
    ///
    /// Returns the expression-position and type-position invocations that were found.
    /// `enum_path`, `typeref_path` and `constraint_expr_trait_path` are the paths the
    /// rewritten code uses to name the generated enum and helper traits. `generics`
    /// are the generics of the enclosing item, if any.
    fn collect_inline_macro(
        &mut self,
        enum_path: &Path,
        typeref_path: &Path,
        constraint_expr_trait_path: &Path,
        generics: Option<&Generics>,
        is_module: bool,
    ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>);
    /// Expands the inline `sumtype!` invocations in `self` and generates the items
    /// they rely on.
    ///
    /// Returns `(generated_items, rewritten_self)`. The generated items are:
    /// - one marker struct per explicitly typed invocation,
    /// - the type-reference trait,
    /// - the sum-type enum: one variant per invocation, plus an uninhabited variant
    ///   that mentions every lifetime and type parameter,
    /// - the constraint trait with its blanket impl,
    /// - one trait-implementation macro call per bound in `args`.
    ///
    /// `is_module` selects `super::`-prefixed paths for the generated items.
    ///
    /// Aborts compilation when:
    /// - no `sumtype!` is found,
    /// - the invocations declare different generics,
    /// - a type-position `sumtype!` has generic arguments that don't match the
    ///   expression generics,
    /// - a type-position `sumtype!` is used while some expression invocation has no
    ///   explicit type.
    fn emit_items(
        mut self,
        args: &Arguments,
        generics_env: Option<&Generics>,
        is_module: bool,
        vis: Visibility,
    ) -> (TokenStream, Self) {
        // A shared random suffix keeps the generated names unique per expansion.
        let r = random();
        let enum_ident = Ident::new(&format!("__Sumtype_Enum_{}", r), Span::call_site());
        let typeref_ident =
            Ident::new(&format!("__Sumtype_TypeRef_Trait_{}", r), Span::call_site());
        let constraint_expr_trait_ident = Ident::new(
            &format!("__Sumtype_ConstraintExprTrait_{}", r),
            Span::call_site(),
        );
        let (found_exprs, type_emitted) = self.collect_inline_macro(
            &path_of_ident(enum_ident.clone(), is_module),
            &path_of_ident(typeref_ident.clone(), is_module),
            &path_of_ident(constraint_expr_trait_ident.clone(), is_module),
            generics_env,
            is_module,
        );
        let reftypes = found_exprs
            .iter()
            .filter_map(|info| info.reftype_ident.clone())
            .collect::<Vec<_>>();
        let (impl_generics_env, _, where_clause_env) = split_for_impl(generics_env);
        if found_exprs.is_empty() {
            abort!(Span::call_site(), "Cannot find any sumtype!() in expr");
        }
        // Every invocation must declare the same `for<..>` generics and `where` clause.
        let expr_generics_list = found_exprs.iter().fold(HashMap::new(), |mut acc, info| {
            *acc.entry(info.generics.clone()).or_insert(0usize) += 1;
            acc
        });
        if expr_generics_list.len() != 1 {
            // Report at the least frequent generics, presumably the odd one out.
            let mut expr_gparams = expr_generics_list.into_iter().collect::<Vec<_>>();
            expr_gparams.sort_by_key(|item| item.1);
            abort!(expr_gparams[0].0.span(), "Generic argument mismatch");
        }
        let expr_generics = expr_generics_list.into_iter().next().unwrap().0;
        // Extra bounds per type parameter: lifetimes inferred from the explicit types...
        let mut analyzed = found_exprs.iter().fold(
            HashMap::new(),
            |mut acc: HashMap<Ident, HashSet<TypeParamBound>>, info| {
                for (id, lts) in &info.analyzed_bounds {
                    acc.entry(id.clone())
                        .or_default()
                        .extend(lts.iter().map(|lt| TypeParamBound::Lifetime(lt.clone())));
                }
                acc
            },
        );
        // ...plus bounds written as `T: Bound` in the invocations' `where` clause.
        if let Some(where_clause) = &expr_generics.where_clause {
            for pred in &where_clause.predicates {
                if let WherePredicate::Type(PredicateType {
                    bounded_ty: Type::Path(path),
                    bounds,
                    ..
                }) = pred
                {
                    if path.qself.is_none() {
                        if let Some(id) = path.path.get_ident() {
                            analyzed
                                .entry(id.clone())
                                .or_default()
                                .extend(bounds.clone());
                        }
                    }
                }
            }
        }
        let expr_garg = expr_generics
            .params
            .iter()
            .cloned()
            .map(generic_param_to_arg)
            .collect::<Vec<_>>();
        // Type-position `sumtype!` arguments must match the expression generics
        // one-to-one, by kind.
        for info in &type_emitted {
            if info.generic_args.len() != expr_garg.len()
                || !expr_garg.iter().zip(&info.generic_args).all(|two| {
                    matches!(
                        two,
                        (GenericArgument::Lifetime(_), GenericArgument::Lifetime(_))
                            | (GenericArgument::Const(_), GenericArgument::Const(_))
                            | (GenericArgument::Type(_), GenericArgument::Type(_))
                    )
                })
            {
                abort!(
                    info.generic_args.span(),
                    "The generic arguments are incompatible with generic params in expression."
                )
            }
        }
        let mut impl_generics =
            merge_generic_params(impl_generics_env, expr_generics.params).collect::<Vec<_>>();
        for g in impl_generics.iter_mut() {
            if let GenericParam::Type(TypeParam { ident, bounds, .. }) = g {
                if let Some(bs) = analyzed.get(ident) {
                    for b in bs {
                        bounds.push(b.clone());
                    }
                }
            }
        }
        let ty_generics = impl_generics
            .iter()
            .cloned()
            .map(generic_param_to_arg)
            .collect::<Vec<_>>();
        let where_clause = expr_generics
            .where_clause
            .clone()
            .map(|wc| wc.predicates)
            .into_iter()
            .flatten()
            .chain(where_clause_env)
            .collect::<Vec<_>>();
        // Explicitly typed invocations resolve their variant type through the
        // type-reference trait; the others get a fresh enum type parameter.
        let (unspecified_ty_params, variants) = found_exprs.iter().enumerate().fold(
            (vec![], vec![]),
            |(mut ty_params, mut variants), (i, info)| {
                if let Some(reft) = &info.reftype_ident {
                    variants.push((
                        info.variant_ident.clone(),
                        parse_quote!(<#reft as #typeref_ident<#(#ty_generics),*>>::Type),
                    ));
                } else {
                    let tp_ident =
                        Ident::new(&format!("__Sumtype_TypeParam_{}", i), Span::call_site());
                    variants.push((info.variant_ident.clone(), parse_quote!(#tp_ident)));
                    ty_params.push(tp_ident);
                }
                (ty_params, variants)
            },
        );
        if let (Some(info), true) = (
            found_exprs.iter().find(|info| info.reftype_ident.is_none()),
            !type_emitted.is_empty(),
        ) {
            abort!(
                &info.span,
                r#"
To emit full type, you should specify the type.
Example: sumtype!(std::iter::empty(), std::iter::Empty<T>)
"#
            )
        } else {
            // Markers for the uninhabited variant, so every lifetime and type parameter
            // is used. Const parameters may stay unused and would otherwise produce
            // `PhantomData<>`, which does not compile, so they are skipped.
            let replaced_ty_generics: Vec<_> = ty_generics
                .iter()
                .filter_map(|ga| match ga {
                    GenericArgument::Lifetime(lt) => Some(quote!(& #lt ())),
                    GenericArgument::Const(_) => None,
                    o => Some(quote!(#o)),
                })
                .collect();
            // One fresh constraint trait per requested bound.
            let constraint_traits = (0..args.bounds.len())
                .map(|n| {
                    Ident::new(
                        &format!("__Sumtype_ConstraintExprTrait_{}_{}", n, random()),
                        Span::call_site(),
                    )
                })
                .collect::<Vec<_>>();
            let out = quote! {
                #(for reft in &reftypes) {
                    #[doc(hidden)]
                    #[allow(non_camel_case_types)]
                    struct #reft;
                }
                #[doc(hidden)]
                #[allow(non_camel_case_types)]
                trait #typeref_ident <#(#impl_generics),*> { type Type; }
                #[doc(hidden)]
                #[allow(non_camel_case_types)]
                #vis enum #enum_ident <
                    #(#impl_generics),*
                    #(if !impl_generics.is_empty() && !unspecified_ty_params.is_empty()) { , }
                    #(#unspecified_ty_params),*
                > {
                    #(for (ident, ty) in &variants) {
                        #ident ( #ty ),
                    }
                    __Uninhabited(
                        (
                            ::core::convert::Infallible,
                            #(::core::marker::PhantomData<#replaced_ty_generics>),*
                        )
                    ),
                }
                #[doc(hidden)]
                #[allow(non_camel_case_types)]
                trait #constraint_expr_trait_ident<#(#impl_generics),*> {}
                impl<#(#impl_generics,)*__Sumtype_TypeParam> #constraint_expr_trait_ident<#(#ty_generics),*> for __Sumtype_TypeParam
                where
                    #(for t in &constraint_traits) {
                        __Sumtype_TypeParam: #t<#(#ty_generics),*>,
                    }
                    #(#where_clause,)*
                {}
                #(for (trait_, constraint_trait) in args.bounds.iter().zip(&constraint_traits)) {
                    #{ SumTypeImpl::Trait(trait_.clone()).gen(
                        &path_of_ident(enum_ident.clone(), false),
                        unspecified_ty_params.as_slice(),
                        variants.as_slice(),
                        impl_generics.clone(),
                        ty_generics.clone(),
                        where_clause.clone(),
                        constraint_trait,
                    ) }
                }
            };
            (out, self)
        }
    }
}
const _: () = {
    use syn::visit_mut::VisitMut;

    /// Mutable visitor that replaces `sumtype!` invocations and records what it found.
    struct Visitor<'a> {
        // Paths the rewritten code uses for the generated enum and helper traits.
        enum_path: &'a Path,
        typeref_path: &'a Path,
        constraint_expr_trait_path: &'a Path,
        // Expression-position invocations, in visiting order.
        found_exprs: Vec<ExprMacroInfo>,
        // Type-position invocations, in visiting order.
        emit_type: Vec<TypeMacroInfo>,
        // Generics of the enclosing item, if any.
        generics: Option<&'a Generics>,
        // When set, marker-struct paths are built with a `super::` prefix.
        is_module: bool,
    }
    impl ProcessTree for Block {
        fn collect_inline_macro(
            &mut self,
            enum_path: &Path,
            typeref_path: &Path,
            constraint_expr_trait_path: &Path,
            generics: Option<&Generics>,
            is_module: bool,
        ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
            let mut visitor = Visitor::new(
                enum_path,
                typeref_path,
                constraint_expr_trait_path,
                generics,
                is_module,
            );
            visitor.visit_block_mut(self);
            (visitor.found_exprs, visitor.emit_type)
        }
    }
    impl ProcessTree for Item {
        fn collect_inline_macro(
            &mut self,
            enum_path: &Path,
            typeref_path: &Path,
            constraint_expr_trait_path: &Path,
            generics: Option<&Generics>,
            is_module: bool,
        ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
            let mut visitor = Visitor::new(
                enum_path,
                typeref_path,
                constraint_expr_trait_path,
                generics,
                is_module,
            );
            visitor.visit_item_mut(self);
            (visitor.found_exprs, visitor.emit_type)
        }
    }
    impl ProcessTree for Stmt {
        fn collect_inline_macro(
            &mut self,
            enum_path: &Path,
            typeref_path: &Path,
            constraint_expr_trait_path: &Path,
            generics: Option<&Generics>,
            is_module: bool,
        ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
            let mut visitor = Visitor::new(
                enum_path,
                typeref_path,
                constraint_expr_trait_path,
                generics,
                is_module,
            );
            visitor.visit_stmt_mut(self);
            (visitor.found_exprs, visitor.emit_type)
        }
    }
    impl<'a> Visitor<'a> {
        /// Creates a visitor with no recorded invocations.
        fn new(
            enum_path: &'a Path,
            typeref_path: &'a Path,
            constraint_expr_trait_path: &'a Path,
            generics: Option<&'a Generics>,
            is_module: bool,
        ) -> Self {
            Self {
                enum_path,
                typeref_path,
                constraint_expr_trait_path,
                found_exprs: Vec::new(),
                emit_type: Vec::new(),
                generics,
                is_module,
            }
        }
        /// Expands a type-position `sumtype!(args..)`.
        ///
        /// The result is the enum path applied to the enclosing item's generic
        /// arguments merged with `args`.
        fn do_type_macro(&mut self, mac: &Macro) -> TokenStream {
            #[derive(Parse)]
            struct Arg {
                #[call(Punctuated::parse_terminated)]
                generic_args: Punctuated<GenericArgument, Token![,]>,
            }
            let arg: Arg = mac
                .parse_body()
                .unwrap_or_else(|e| abort!(e.span(), &format!("{}", &e)));
            let ty_generics = merge_generic_args(
                self.generics
                    .iter()
                    .flat_map(|g| g.params.iter().cloned().map(generic_param_to_arg)),
                arg.generic_args.clone(),
            )
            .collect::<Vec<_>>();
            self.emit_type.push(TypeMacroInfo {
                _span: mac.span(),
                generic_args: arg.generic_args,
            });
            quote! {
                #{&self.enum_path}
                #(if !ty_generics.is_empty()){
                    <#(#ty_generics),*>
                }
            }
        }
        /// Finds the lifetime bounds each type parameter of `generics` needs because
        /// of where it appears in `ty`.
        ///
        /// A type parameter found behind a reference whose lifetime is one of
        /// `generics`' lifetime parameters is mapped to that lifetime. For example,
        /// `&'a T` yields `T -> {'a}`. Callers turn these into `T: 'a` bounds.
        fn analyze_lifetime_bounds(
            &self,
            generics: &Generics,
            ty: &Type,
        ) -> HashMap<Ident, HashSet<Lifetime>> {
            struct LifetimeVisitor {
                generic_lifetimes: HashSet<Lifetime>,
                generic_params: HashSet<Ident>,
                // Generic lifetimes of the references enclosing the current position.
                lifetime_stack: Vec<Lifetime>,
                result: HashMap<Ident, HashSet<Lifetime>>,
            }
            use syn::visit::Visit;
            impl syn::visit::Visit<'_> for LifetimeVisitor {
                fn visit_type_reference(&mut self, i: &TypeReference) {
                    if let Some(lt) = &i.lifetime {
                        if self.generic_lifetimes.contains(lt) {
                            self.lifetime_stack.push(lt.clone());
                            syn::visit::visit_type_reference(self, i);
                            self.lifetime_stack.pop();
                            return;
                        }
                    }
                    syn::visit::visit_type_reference(self, i);
                }
                fn visit_type_path(&mut self, i: &TypePath) {
                    if i.qself.is_none() {
                        if let Some(id) = i.path.get_ident() {
                            if self.generic_params.contains(id) {
                                self.result
                                    .entry(id.clone())
                                    .or_default()
                                    .extend(self.lifetime_stack.clone());
                            }
                            return;
                        }
                    }
                    syn::visit::visit_type_path(self, i);
                }
            }
            let mut visitor = LifetimeVisitor {
                generic_lifetimes: generics
                    .params
                    .iter()
                    .filter_map(|p| {
                        if let GenericParam::Lifetime(LifetimeParam { lifetime, .. }) = p {
                            Some(lifetime.clone())
                        } else {
                            None
                        }
                    })
                    .collect(),
                generic_params: generics
                    .params
                    .iter()
                    .filter_map(|p| {
                        if let GenericParam::Type(TypeParam { ident, .. }) = p {
                            Some(ident.clone())
                        } else {
                            None
                        }
                    })
                    .collect(),
                lifetime_stack: Vec::new(),
                result: HashMap::new(),
            };
            visitor.visit_type(ty);
            visitor.result
        }
        /// Expands an expression-position invocation.
        ///
        /// The accepted form is `sumtype!(expr [, [for<..>] Type [where ..]])`.
        /// The expression is wrapped in a fresh enum variant and passed through an
        /// identity function whose bound ties it to the constraint trait. When a type
        /// is given, an impl of the type-reference trait for a new marker struct
        /// records that type, so the enum can name the variant's type.
        fn do_expr_macro(&mut self, mac: &Macro) -> TokenStream {
            #[derive(Parse)]
            struct Arg {
                expr: Expr,
                _comma_token: Option<Token![,]>,
                _for_token: Option<Token![for]>,
                #[prefix(Option<Token![<]>)]
                #[postfix(Option<Token![>]>)]
                #[parse_if(_for_token.is_some())]
                #[call(Punctuated::parse_separated_nonempty)]
                for_generics: Option<Punctuated<GenericParam, Token![,]>>,
                #[parse_if(_comma_token.is_some())]
                ty: Option<Type>,
                #[parse_if(_comma_token.is_some())]
                where_clause: Option<Option<WhereClause>>,
            }
            let arg: Arg = mac
                .parse_body()
                .unwrap_or_else(|e| abort!(e.span(), &format!("{}", &e)));
            let n = self.found_exprs.len();
            let variant_ident = Ident::new(&format!("__SumType_Variant_{}", n), Span::call_site());
            let reftype_ident = Ident::new(
                &format!("__SumType_RefType_{}_{}", random(), n),
                Span::call_site(),
            );
            let reftype_path = path_of_ident(reftype_ident.clone(), self.is_module);
            let id_fn_ident =
                Ident::new(&format!("__sum_type_id_fn_{}", random()), Span::call_site());
            let (mut impl_generics, _, where_clause) = split_for_impl(self.generics);
            let analyzed =
                if let (Some(generics), Some(ty)) = (self.generics.as_ref(), arg.ty.as_ref()) {
                    self.analyze_lifetime_bounds(generics, ty)
                } else {
                    HashMap::new()
                };
            // These generics are compared for equality across invocations in
            // `emit_items`. Normalize "no `where` clause" (with or without a type) and
            // "empty `where` clause" to `None`, so `sumtype!(e)` agrees with
            // `sumtype!(e, Ty)`.
            let generics = Generics {
                params: arg.for_generics.clone().unwrap_or_default(),
                where_clause: arg
                    .where_clause
                    .flatten()
                    .filter(|wc| !wc.predicates.is_empty()),
                ..Default::default()
            };
            for g in impl_generics.iter_mut() {
                if let GenericParam::Type(TypeParam { ident, bounds, .. }) = g {
                    if let Some(lts) = analyzed.get(ident) {
                        for lt in lts {
                            bounds.push(TypeParamBound::Lifetime(lt.clone()));
                        }
                    }
                }
            }
            let impl_generics =
                merge_generic_params(impl_generics, generics.params.clone()).collect::<Vec<_>>();
            let ty_generics = impl_generics
                .iter()
                .cloned()
                .map(generic_param_to_arg)
                .collect::<Vec<_>>();
            let where_clause = generics
                .where_clause
                .clone()
                .map(|wc| wc.predicates)
                .into_iter()
                .flatten()
                .chain(where_clause)
                .collect::<Vec<_>>();
            self.found_exprs.push(ExprMacroInfo {
                span: mac.span(),
                variant_ident: variant_ident.clone(),
                reftype_ident: arg.ty.as_ref().map(|_| reftype_ident.clone()),
                analyzed_bounds: analyzed.clone(),
                generics,
            });
            quote! {
                {
                    #(if let Some(ty) = &arg.ty){
                        impl<#(#impl_generics,)*> #{&self.typeref_path} <#(#ty_generics),*> for #reftype_path
                        #(if !where_clause.is_empty()) {
                            where #(#where_clause,)*
                        }
                        {
                            type Type = #ty;
                        }
                    }
                    fn #id_fn_ident<
                        #(#impl_generics,)* __SumType_T: #{&self.constraint_expr_trait_path}<#(#ty_generics),*>
                    >(t: __SumType_T) -> __SumType_T
                    #(if !where_clause.is_empty()) {
                        where #(#where_clause,)*
                    }
                    { t }
                    #id_fn_ident::<#(#ty_generics,)*_>(#{&self.enum_path}::#variant_ident(#{&arg.expr}))
                }
            }
        }
    }
    impl VisitMut for Visitor<'_> {
        fn visit_type_mut(&mut self, ty: &mut Type) {
            if let Type::Macro(tm) = &*ty {
                if tm.mac.path.is_ident("sumtype") {
                    let out = self.do_type_macro(&tm.mac);
                    *ty = parse2(out).unwrap();
                    return;
                }
            }
            syn::visit_mut::visit_type_mut(self, ty);
        }
        fn visit_expr_mut(&mut self, expr: &mut Expr) {
            if let Expr::Macro(em) = &*expr {
                if em.mac.path.is_ident("sumtype") {
                    let out = self.do_expr_macro(&em.mac);
                    *expr = parse2(out).unwrap();
                    return;
                }
            }
            syn::visit_mut::visit_expr_mut(self, expr);
        }
        fn visit_stmt_mut(&mut self, stmt: &mut Stmt) {
            if let Stmt::Macro(sm) = &*stmt {
                if sm.mac.path.is_ident("sumtype") {
                    // The expansion is a block expression. Keep the statement's `;`
                    // (if any); without it, a mid-block `sumtype!(..);` would become an
                    // expression statement of non-`()` type.
                    let semi_token = sm.semi_token;
                    let out = self.do_expr_macro(&sm.mac);
                    *stmt = Stmt::Expr(parse2(out).unwrap(), semi_token);
                    return;
                }
            }
            syn::visit_mut::visit_stmt_mut(self, stmt);
        }
    }
};
/// Implements `#[sumtype(..)]`.
///
/// Detects what kind of syntax the attribute is attached to, expands its inline
/// `sumtype!` invocations, and places the generated items before the rewritten input.
///
/// Candidates are tried in order: block, trait, impl, function, module, any other
/// item, statement. The generated enum reuses the visibility of an annotated trait or
/// function and is `pub` otherwise. Any other input aborts.
fn inner(args: &Arguments, input: TokenStream) -> TokenStream {
    let public = Visibility::Public(Default::default());
    let (generated, rewritten): (TokenStream, TokenStream) =
        if let Ok(block) = parse2::<Block>(input.clone()) {
            let (generated, block) = block.emit_items(args, None, false, public);
            (generated, block.into_token_stream())
        } else if let Ok(item_trait) = parse2::<ItemTrait>(input.clone()) {
            let generics = item_trait.generics.clone();
            let vis = item_trait.vis.clone();
            let (generated, item) =
                Item::Trait(item_trait).emit_items(args, Some(&generics), false, vis);
            (generated, item.into_token_stream())
        } else if let Ok(item_impl) = parse2::<ItemImpl>(input.clone()) {
            let generics = item_impl.generics.clone();
            let (generated, item) =
                Item::Impl(item_impl).emit_items(args, Some(&generics), false, public);
            (generated, item.into_token_stream())
        } else if let Ok(item_fn) = parse2::<ItemFn>(input.clone()) {
            let generics = item_fn.sig.generics.clone();
            let vis = item_fn.vis.clone();
            let (generated, item) =
                Item::Fn(item_fn).emit_items(args, Some(&generics), false, vis);
            (generated, item.into_token_stream())
        } else if let Ok(item_mod) = parse2::<ItemMod>(input.clone()) {
            // The generated items sit next to the module, so inner paths use `super::`.
            let (generated, item) = Item::Mod(item_mod).emit_items(args, None, true, public);
            (generated, item.into_token_stream())
        } else if let Ok(item) = parse2::<Item>(input.clone()) {
            let (generated, item) = item.emit_items(args, None, false, public);
            (generated, item.into_token_stream())
        } else if let Ok(stmt) = parse2::<Stmt>(input.clone()) {
            let (generated, stmt) = stmt.emit_items(args, None, false, public);
            (generated, stmt.into_token_stream())
        } else {
            abort!(input.span(), "This element is not supported")
        };
    quote! { #generated #[allow(non_local_definitions)] #rewritten }
}
fn process_supported_supertraits<'a>(
traits: impl IntoIterator<Item = &'a TypeParamBound>,
krate: &Path,
) -> (Vec<Path>, Vec<Path>) {
let mut supertraits = Vec::new();
let mut derive_traits = Vec::new();
for tpb in traits.into_iter() {
if let TypeParamBound::Trait(tb) = tpb {
if let Some(ident) = tb.path.get_ident() {
match ident.to_string().as_str() {
"Copy" | "Clone" | "Hash" | "Eq" => {
supertraits.push(parse_quote!(#krate::traits::#ident))
}
"PartialEq" => derive_traits.push(parse_quote!(PartialEq)),
o if o.starts_with("__SumTrait_Sealed") => (),
_ => (),
}
} else {
supertraits.push(tb.path.clone())
}
} else {
abort!(tpb.span(), "Only path is supported");
}
}
(supertraits, derive_traits)
}
/// Returns the types `type_leak` reports for `input`.
///
/// `sumtrait_impl` generates one `TypeRef` impl per returned type. When
/// `Leaker::from_trait` fails, a leaker built from the trait's generics is used
/// instead. `self_ty_can_be_interned` is disabled before finishing.
fn collect_typeref_types(input: &ItemTrait) -> Vec<Type> {
    let mut leaker = match Leaker::from_trait(input) {
        Ok(leaker) => leaker,
        Err(_) => Leaker::with_generics(input.generics.clone()),
    };
    leaker.self_ty_can_be_interned = false;
    leaker.finish().iter().cloned().collect()
}
fn sumtrait_impl(
args: Option<Path>,
marker_path: &Path,
krate: &Path,
input: ItemTrait,
) -> TokenStream {
let (supertraits, derive_traits) = process_supported_supertraits(&input.supertraits, krate);
for item in &input.items {
match item {
TraitItem::Const(_) => abort!(item.span(), "associated const is not supported"),
TraitItem::Fn(tfn) => {
if tfn.sig.inputs.is_empty() || !matches!(&tfn.sig.inputs[0], FnArg::Receiver(_)) {
abort!(tfn.sig.span(), "requires receiver")
}
}
TraitItem::Type(tty) => {
if tty.default.is_some() {
abort!(tty.span(), "associated type defaults is not supported")
}
if !tty.generics.params.is_empty() || tty.generics.where_clause.is_some() {
abort!(
tty.generics.span(),
"generalized associated types is not supported"
)
}
}
o => abort!(o.span(), "Not supported"),
}
}
let temporary_mac_name =
Ident::new(&format!("__sumtype_macro_{}", random()), Span::call_site());
let typeref_types = collect_typeref_types(&input);
let (_, _, where_clause) = input.generics.split_for_impl();
let typeref_id = random() as usize;
quote! {
#input
#(for (i, ty) in typeref_types.iter().enumerate()) {
impl<#(for p in &input.generics.params),{#p}> #krate::TypeRef<#typeref_id, #i> for #marker_path #where_clause {
type Type = #ty;
}
}
#[doc(hidden)]
#[macro_export]
macro_rules! #temporary_mac_name {
($($t:tt)*) => {
#krate::_sumtrait_internal!(
{ $($t)* }
[#(#typeref_types),*],
{#input},
#typeref_id,
#krate,
#marker_path,
[#{args.map(|m| quote!(#m)).unwrap_or(quote!(_))}],
[#(#supertraits),*],
[#(#derive_traits),*],
);
};
}
#[doc(hidden)]
#{&input.vis} use #temporary_mac_name as #{&input.ident};
}
}
/// Internal entry point invoked by the `macro_rules!` helpers that `#[sumtrait]`
/// generates; not part of the public API.
#[doc(hidden)]
#[proc_macro_error]
#[proc_macro]
pub fn _sumtrait_internal(input: TokenStream1) -> TokenStream1 {
    let expanded = sumtrait_internal::sumtrait_internal(input.into());
    expanded.into()
}
#[proc_macro_error]
#[proc_macro_attribute]
pub fn sumtrait(attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
#[derive(FromMeta, Debug)]
struct SumtraitArgs {
implement: Option<Path>,
krate: Option<Path>,
marker: Path,
}
let args = SumtraitArgs::from_list(&NestedMeta::parse_meta_list(attr.into()).unwrap()).unwrap();
let krate = args.krate.unwrap_or(parse_quote!(::sumtype));
sumtrait_impl(
args.implement,
&args.marker,
&krate,
parse(input).unwrap_or_else(|_| abort!(Span::call_site(), "Requires trait definition")),
)
.into()
}
#[proc_macro_error]
#[proc_macro_attribute]
pub fn sumtype(attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
inner(&parse_macro_input!(attr as Arguments), input.into()).into()
}