use std::{borrow::Cow, mem};
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{punctuated::Punctuated, *};
/// Fallible counterpart of `syn::parse_quote!`.
///
/// Quotes the given tokens with `quote::quote!` and parses them with `syn::parse2`, so the
/// result is a `syn::Result<T>` whose target type `T` is inferred at the call site.
macro_rules! parse_quote {
    ($($tokens:tt)*) => {
        syn::parse2(quote::quote!($($tokens)*))
    };
}
/// Returns early with `Err(syn::Error)` spanned to the tokens of `$tokens`.
///
/// The expansion is a `return` expression, so it leaves the enclosing function or closure.
/// The second arm accepts `format!`-style arguments and forwards the formatted message to
/// the first arm.
macro_rules! error {
    ($tokens:expr, $message:expr) => {
        return Err(syn::Error::new_spanned(&$tokens, $message))
    };
    ($tokens:expr, $($fmt:tt)*) => {
        error!($tokens, format!($($fmt)*))
    };
}
/// Builds a bare type parameter (`ident` with the given attributes, but no bounds and no
/// default) wrapped as a `GenericParam`.
fn param_ident(attrs: Vec<Attribute>, ident: Ident) -> GenericParam {
    let bare = TypeParam {
        ident,
        attrs,
        bounds: Punctuated::new(),
        colon_token: None,
        default: None,
        eq_token: None,
    };
    GenericParam::Type(bare)
}
// Holds the `Sealed` supertrait of `MaybeEnum`. Because this module is private, `Sealed`
// cannot be named from outside, which restricts `MaybeEnum` to the types implemented here.
mod private_maybe_enum {
    use super::*;

    pub trait Sealed {}

    impl Sealed for DeriveInput {}
    impl Sealed for Item {}
    impl Sealed for ItemEnum {}
    impl Sealed for Stmt {}
}
/// Borrowed view of the parts of an enum definition, produced by [`MaybeEnum::elements`].
pub struct EnumElements<'a> {
    /// Attributes on the enum item.
    pub attrs: &'a [Attribute],
    /// Visibility of the enum.
    pub vis: &'a Visibility,
    /// Name of the enum.
    pub ident: &'a Ident,
    /// Generic parameters and where clause of the enum.
    pub generics: &'a Generics,
    /// The enum's variants, in declaration order.
    pub variants: &'a Punctuated<Variant, token::Comma>,
}
/// Syntax nodes that may contain an enum definition.
///
/// The trait is sealed: it is implemented in this file for `ItemEnum`, `Item`, `Stmt` and
/// `DeriveInput` only.
pub trait MaybeEnum: ToTokens + self::private_maybe_enum::Sealed {
    /// Borrows the enum's parts, or returns an error spanned to `self` if the node is not
    /// an enum.
    fn elements(&self) -> Result<EnumElements<'_>>;
}
impl MaybeEnum for ItemEnum {
fn elements(&self) -> Result<EnumElements<'_>> {
Ok(EnumElements {
attrs: &self.attrs,
vis: &self.vis,
ident: &self.ident,
generics: &self.generics,
variants: &self.variants,
})
}
}
impl MaybeEnum for Item {
fn elements(&self) -> Result<EnumElements<'_>> {
match self {
Item::Enum(item) => MaybeEnum::elements(item),
_ => error!(self, "may only be used on enums"),
}
}
}
impl MaybeEnum for Stmt {
fn elements(&self) -> Result<EnumElements<'_>> {
match self {
Stmt::Item(Item::Enum(item)) => MaybeEnum::elements(item),
_ => error!(self, "may only be used on enums"),
}
}
}
impl MaybeEnum for DeriveInput {
    /// Succeeds for `Data::Enum`. Structs and unions are rejected with a message naming the
    /// kind of input.
    fn elements(&self) -> Result<EnumElements<'_>> {
        let enum_data = match &self.data {
            Data::Enum(enum_data) => enum_data,
            Data::Struct(_) => error!(self, "cannot be implemented for structs"),
            Data::Union(_) => error!(self, "cannot be implemented for unions"),
        };
        Ok(EnumElements {
            variants: &enum_data.variants,
            generics: &self.generics,
            ident: &self.ident,
            vis: &self.vis,
            attrs: &self.attrs,
        })
    }
}
/// Owned description of an enum whose every variant wraps exactly one unnamed field.
pub struct EnumData {
    /// Visibility of the enum.
    vis: Visibility,
    /// Name of the enum.
    ident: Ident,
    /// Generic parameters and where clause of the enum.
    generics: Generics,
    /// Variant names, in declaration order.
    variants: Vec<Ident>,
    /// Type of each variant's single field; `fields[i]` belongs to `variants[i]`.
    fields: Vec<Type>,
}
impl EnumData {
pub fn new<E>(maybe_enum: &E) -> Result<Self>
where
E: MaybeEnum,
{
let elements = MaybeEnum::elements(maybe_enum)?;
if elements.variants.is_empty() {
error!(maybe_enum, "cannot be implemented for enums with no variants");
}
parse_variants(elements.variants).map(|(variants, fields)| Self {
vis: elements.vis.clone(),
ident: elements.ident.clone(),
generics: elements.generics.clone(),
variants,
fields,
})
}
pub fn make_impl(&self) -> Result<EnumImpl<'_>> {
EnumImpl::new(self, Vec::new())
}
pub fn impl_with_capacity(&self, capacity: usize) -> Result<EnumImpl<'_>> {
EnumImpl::new(self, Vec::with_capacity(capacity))
}
pub fn make_impl_trait<I>(
&self,
trait_path: Path,
supertraits_types: I,
item: ItemTrait,
) -> Result<EnumImpl<'_>>
where
I: IntoIterator<Item = Ident>,
I::IntoIter: ExactSizeIterator,
{
EnumImpl::from_trait(self, trait_path, Vec::new(), item, supertraits_types)
}
pub fn impl_trait_with_capacity<I>(
&self,
capacity: usize,
trait_path: Path,
supertraits_types: I,
item: ItemTrait,
) -> Result<EnumImpl<'_>>
where
I: IntoIterator<Item = Ident>,
I::IntoIter: ExactSizeIterator,
{
EnumImpl::from_trait(
self,
trait_path,
Vec::with_capacity(capacity),
item,
supertraits_types,
)
}
#[doc(hidden)]
pub fn vis(&self) -> &Visibility {
&self.vis
}
#[doc(hidden)]
pub fn ident(&self) -> &Ident {
&self.ident
}
#[doc(hidden)]
pub fn generics(&self) -> &Generics {
&self.generics
}
#[doc(hidden)]
pub fn variants(&self) -> &[Ident] {
&self.variants
}
#[doc(hidden)]
pub fn fields(&self) -> &[Type] {
&self.fields
}
}
fn parse_variants(variants: &Punctuated<Variant, token::Comma>) -> Result<(Vec<Ident>, Vec<Type>)> {
variants.iter().try_fold(
(Vec::with_capacity(variants.len()), Vec::with_capacity(variants.len())),
|(mut variants, mut fields), v| {
if let Some((_, e)) = &v.discriminant {
error!(e, "an enum with discriminants is not supported")
}
match &v.fields {
Fields::Unnamed(f) => match f.unnamed.len() {
1 => fields.push(f.unnamed.iter().next().unwrap().ty.clone()),
0 => error!(v.fields, "a variant with zero fields is not supported"),
_ => error!(v.fields, "a variant with multiple fields is not supported"),
},
Fields::Unit => error!(v, "an enum with units variant is not supported"),
Fields::Named(_) => error!(v, "an enum with named fields variant is not supported"),
}
variants.push(v.ident.clone());
Ok((variants, fields))
},
)
}
/// The trait an `EnumImpl` implements.
#[doc(hidden)]
pub struct Trait {
    /// Path used to call trait methods (`path::method(x, ..)`) in the generated arms.
    path: Path,
    /// Trait reference written in the `impl ... for` header and in associated-type
    /// projections (`<T as ty>::Assoc`).
    ty: Path,
}

impl Trait {
    #[doc(hidden)]
    pub fn new(path: Path, ty: Path) -> Self {
        Trait { path, ty }
    }
}
/// Builder for an `impl` block (inherent or trait) on an enum described by [`EnumData`].
pub struct EnumImpl<'a> {
    /// The enum the impl is generated for.
    data: &'a EnumData,
    /// When true, the impl is emitted with the `default` keyword.
    defaultness: bool,
    /// When true, the impl is emitted as `unsafe impl`; set from the trait's `unsafe` keyword.
    unsafety: bool,
    /// Generics of the impl (the enum's generics plus any added parameters/predicates).
    generics: Generics,
    /// Trait being implemented, or `None` for an inherent impl.
    trait_: Option<Trait>,
    /// The `Self` type of the impl.
    self_ty: Box<Type>,
    /// Items collected so far.
    items: Vec<ImplItem>,
    /// Set when a generated method uses unchecked `Pin` projection; adds
    /// `#[allow(unsafe_code)]` to the emitted impl.
    unsafe_code: bool,
}
#[doc(hidden)]
pub fn build(impls: EnumImpl<'_>) -> TokenStream {
impls.build()
}
#[doc(hidden)]
pub fn build_item(impls: EnumImpl<'_>) -> ItemImpl {
impls.build_item()
}
impl<'a> EnumImpl<'a> {
    /// Creates a builder for an inherent `impl` block on the enum described by `data`.
    ///
    /// The impl generics start as a clone of the enum's generics and no trait is set.
    /// `items` is used as the (possibly pre-allocated) item buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if the self type `#ident #ty_generics` fails to parse.
    fn new(data: &'a EnumData, items: Vec<ImplItem>) -> Result<Self> {
        let ident = &data.ident;
        // The self type needs the *type* form of the generics (`<'a, T>`). Quoting
        // `Generics` directly would also emit bounds and defaults (`<T: Bound = X>`),
        // which are not valid in type position.
        let (_, ty_generics, _) = data.generics.split_for_impl();
        parse_quote!(#ident #ty_generics).map(|self_ty| Self {
            data,
            defaultness: false,
            unsafety: false,
            generics: data.generics.clone(),
            trait_: None,
            self_ty: Box::new(self_ty),
            items,
            unsafe_code: false,
        })
    }

    /// Mutable access to the trait being implemented (`None` for an inherent impl).
    #[doc(hidden)]
    pub fn trait_(&mut self) -> &mut Option<Trait> {
        &mut self.trait_
    }

    /// Mutable access to the impl's `Self` type.
    #[doc(hidden)]
    pub fn self_ty(&mut self) -> &mut Type {
        &mut *self.self_ty
    }

    /// Appends a generic parameter to the impl generics.
    pub fn push_generic_param(&mut self, param: GenericParam) {
        self.generics.params.push(param);
    }

    /// Appends a bare type parameter `ident` (no bounds, no default) to the impl generics.
    pub fn push_generic_param_ident(&mut self, ident: Ident) {
        self.push_generic_param(param_ident(Vec::new(), ident));
    }

    /// Appends a predicate to the impl's where clause, creating the clause if needed.
    pub fn push_where_predicate(&mut self, predicate: WherePredicate) {
        self.generics.make_where_clause().predicates.push(predicate);
    }

    /// Appends an item to the impl body.
    pub fn push_item(&mut self, item: ImplItem) {
        self.items.push(item);
    }

    /// Builds one match arm per variant with `f`, each followed by a comma.
    fn arms(&self, f: impl FnMut(&Ident) -> TokenStream) -> TokenStream {
        let arms = self.data.variants.iter().map(f);
        quote!(#(#arms,)*)
    }

    /// Path used to call trait methods, or `None` for an inherent impl.
    fn trait_path(&self) -> Option<&Path> {
        self.trait_.as_ref().map(|t| &t.path)
    }

    /// Adds a method that forwards `item` to the field of whichever variant `self` is.
    ///
    /// The first argument must be a receiver (`FnArg::Receiver`) or `self: Pin<&Self>` /
    /// `self: Pin<&mut Self>`. Every other argument must be a typed pattern. Its pattern
    /// tokens are re-emitted verbatim as the call argument, so presumably plain identifiers
    /// are expected. For a trait impl the call is `Trait::method(x, ..)`; otherwise it is
    /// `x.method(..)`.
    ///
    /// # Errors
    ///
    /// Fails on unsupported first arguments, on non-typed extra arguments, or if the
    /// generated body does not parse.
    pub fn push_method(&mut self, item: TraitItemMethod) -> Result<()> {
        let self_ty = SelfType::parse(item.sig.inputs.iter().next())?;
        let mut args = Vec::with_capacity(item.sig.inputs.len());
        item.sig.inputs.iter().skip(1).try_for_each(|arg| match arg {
            FnArg::Typed(arg) => {
                args.push(&arg.pat);
                Ok(())
            }
            _ => error!(arg, "unsupported arguments type"),
        })?;
        let args = &args;
        let method = &item.sig.ident;
        let ident = &self.data.ident;
        let method = match self_ty {
            SelfType::None => {
                let trait_ = self.trait_path();
                let arms = if trait_.is_none() {
                    self.arms(|v| quote!(#ident::#v(x) => x.#method(#(#args),*)))
                } else {
                    self.arms(|v| quote!(#ident::#v(x) => #trait_::#method(x #(,#args)*)))
                };
                parse_quote!(match self { #arms })
            }
            SelfType::Pin(mode, pin) => {
                self.unsafe_code = true;
                let trait_ = self.trait_path();
                let arms = if trait_.is_none() {
                    self.arms(
                        |v| quote!(#ident::#v(x) => #pin::new_unchecked(x).#method(#(#args),*)),
                    )
                } else {
                    self.arms(|v| quote!(#ident::#v(x) => #trait_::#method(#pin::new_unchecked(x) #(,#args)*)))
                };
                // `unsafe impl` does NOT make method bodies an unsafe context; only an
                // `unsafe fn` does. The explicit `unsafe` block may therefore be omitted only
                // when the method itself is `unsafe` (presumably to avoid an `unused_unsafe`
                // warning there); otherwise the unchecked `Pin` calls would not compile.
                let in_unsafe_fn = item.sig.unsafety.is_some();
                match mode {
                    CaptureMode::Ref { mutability: false } => {
                        if in_unsafe_fn {
                            parse_quote!(match #pin::get_ref(self) { #arms })
                        } else {
                            parse_quote!(unsafe { match #pin::get_ref(self) { #arms } })
                        }
                    }
                    CaptureMode::Ref { mutability: true } => {
                        if in_unsafe_fn {
                            parse_quote!(match #pin::get_unchecked_mut(self) { #arms })
                        } else {
                            parse_quote!(unsafe { match #pin::get_unchecked_mut(self) { #arms } })
                        }
                    }
                }
            }
        };
        method.map(|method| {
            self.push_item(ImplItem::Method(ImplItemMethod {
                attrs: item.attrs,
                vis: Visibility::Inherited,
                defaultness: None,
                sig: item.sig,
                block: Block {
                    brace_token: token::Brace::default(),
                    stmts: vec![Stmt::Expr(method)],
                },
            }))
        })
    }

    /// Adds an impl item for every associated type and method of the trait definition `item`.
    ///
    /// Associated types are defined as `<FirstFieldType as Trait>::Name`. Methods are
    /// forwarded with [`EnumImpl::push_method`]. Other trait items are ignored.
    pub fn append_items_from_trait(&mut self, item: ItemTrait) -> Result<()> {
        let fst = self.data.fields.iter().next();
        item.items.into_iter().try_for_each(|item| match item {
            TraitItem::Type(TraitItemType { ident, .. }) => {
                let trait_ = self.trait_.as_ref().map(|t| &t.ty);
                parse_quote!(type #ident = <#fst as #trait_>::#ident;)
                    .map(|ty| self.push_item(ImplItem::Type(ty)))
            }
            TraitItem::Method(method) => self.push_method(method),
            _ => Ok(()),
        })
    }

    /// Creates a builder for `impl Trait for Enum`, where `item` is the trait's definition
    /// and `path` its path.
    ///
    /// The first variant's field type must implement the trait. Every other variant's field
    /// type must implement it with the same associated types as the first. Associated types
    /// named in `supertraits_types` are constrained the same way through the first
    /// supertrait. The trait's generics and where clause are merged into the impl's.
    ///
    /// # Errors
    ///
    /// Fails if any generated path, predicate or item does not parse, or if forwarding a
    /// trait method fails.
    fn from_trait<I>(
        data: &'a EnumData,
        path: Path,
        items: Vec<ImplItem>,
        mut item: ItemTrait,
        supertraits_types: I,
    ) -> Result<Self>
    where
        I: IntoIterator<Item = Ident>,
        I::IntoIter: ExactSizeIterator,
    {
        /// Converts the trait's generic parameter declarations into generic *arguments* for
        /// the trait path. Bounds, defaults and const types are dropped, so
        /// `trait Foo<'a, 'b: 'a, T: Clone, const N: usize>` is referenced as
        /// `Foo<'a, 'b, T, N>`.
        #[allow(single_use_lifetimes)]
        fn generics_params<'a>(
            iter: impl Iterator<Item = &'a GenericParam>,
        ) -> impl Iterator<Item = Cow<'a, GenericParam>> {
            iter.map(|param| match param {
                GenericParam::Type(ty) => {
                    Cow::Owned(param_ident(ty.attrs.clone(), ty.ident.clone()))
                }
                GenericParam::Lifetime(def) if !def.bounds.is_empty() => {
                    Cow::Owned(GenericParam::Lifetime(LifetimeDef {
                        attrs: def.attrs.clone(),
                        lifetime: def.lifetime.clone(),
                        colon_token: None,
                        bounds: Punctuated::new(),
                    }))
                }
                GenericParam::Const(c) => {
                    Cow::Owned(param_ident(c.attrs.clone(), c.ident.clone()))
                }
                param => Cow::Borrowed(param),
            })
        }

        let mut generics = data.generics.clone();
        let trait_ = {
            if item.generics.params.is_empty() {
                path.clone()
            } else {
                let generics = generics_params(item.generics.params.iter());
                parse_quote!(#path<#(#generics),*>)?
            }
        };
        let fst = data.fields.iter().next();
        // `(is_supertrait_type, name)` for every associated type to pin to the first field.
        let mut types: Vec<_> = item
            .items
            .iter()
            .filter_map(|item| match item {
                TraitItem::Type(ty) => Some((false, Cow::Borrowed(&ty.ident))),
                _ => None,
            })
            .collect();
        let supertraits_types = supertraits_types.into_iter();
        if supertraits_types.len() > 0 {
            if let Some(TypeParamBound::Trait(_)) = item.supertraits.iter().next() {
                types.extend(supertraits_types.map(|ident| (true, Cow::Owned(ident))));
            }
        }
        let where_clause = &mut generics.make_where_clause().predicates;
        where_clause.push(parse_quote!(#fst: #trait_)?);
        data.fields
            .iter()
            .skip(1)
            .map(|variant| {
                if types.is_empty() {
                    parse_quote!(#variant: #trait_)
                } else {
                    let types = types.iter().map(|(supertraits, ident)| {
                        match item.supertraits.iter().next() {
                            // Here `trait_` is the supertrait bound (it shadows the outer
                            // trait path), so supertrait types project through the supertrait.
                            Some(TypeParamBound::Trait(trait_)) if *supertraits => {
                                quote!(#ident = <#fst as #trait_>::#ident)
                            }
                            _ => quote!(#ident = <#fst as #trait_>::#ident),
                        }
                    });
                    if item.generics.params.is_empty() {
                        parse_quote!(#variant: #path<#(#types),*>)
                    } else {
                        let generics = generics_params(item.generics.params.iter());
                        parse_quote!(#variant: #path<#(#generics),*, #(#types),*>)
                    }
                }
            })
            .try_for_each(|res| res.map(|f| where_clause.push(f)))?;
        // Merge the trait's own generic parameters into the impl generics. Rust requires
        // lifetime parameters to precede type and const parameters, so trait lifetimes are
        // inserted right after the existing lifetimes instead of being appended.
        for param in mem::replace(&mut item.generics.params, Punctuated::new()) {
            if let GenericParam::Lifetime(_) = param {
                let index = generics
                    .params
                    .iter()
                    .take_while(|p| match p {
                        GenericParam::Lifetime(_) => true,
                        _ => false,
                    })
                    .count();
                generics.params.insert(index, param);
            } else {
                generics.params.push(param);
            }
        }
        if let Some(old) = item.generics.where_clause.as_mut() {
            if !old.predicates.is_empty() {
                generics
                    .make_where_clause()
                    .predicates
                    .extend(mem::replace(&mut old.predicates, Punctuated::new()));
            }
        }
        let ident = &data.ident;
        // Type form of the enum generics (no bounds/defaults); see `EnumImpl::new`.
        let (_, ty_generics, _) = data.generics.split_for_impl();
        parse_quote!(#ident #ty_generics)
            .map(|self_ty| Self {
                data,
                defaultness: false,
                unsafety: item.unsafety.is_some(),
                generics,
                trait_: Some(Trait::new(path, trait_)),
                self_ty: Box::new(self_ty),
                items,
                unsafe_code: false,
            })
            .and_then(|mut impls| impls.append_items_from_trait(item).map(|_| impls))
    }

    /// Builds the impl and converts it to tokens.
    pub fn build(self) -> TokenStream {
        self.build_item().into_token_stream()
    }

    /// Builds the `ItemImpl`, adding `#[allow(unsafe_code)]` when a generated method
    /// uses unchecked `Pin` projection.
    pub fn build_item(self) -> ItemImpl {
        ItemImpl {
            attrs: if self.unsafe_code {
                vec![syn::parse_quote!(#[allow(unsafe_code)])]
            } else {
                Vec::new()
            },
            defaultness: if self.defaultness { Some(token::Default::default()) } else { None },
            unsafety: if self.unsafety { Some(token::Unsafe::default()) } else { None },
            impl_token: token::Impl::default(),
            generics: self.generics,
            trait_: self.trait_.map(|Trait { ty, .. }| (None, ty, token::For::default())),
            self_ty: self.self_ty,
            brace_token: token::Brace::default(),
            items: self.items,
        }
    }
}
/// Kind of receiver accepted as a method's first argument.
enum SelfType {
    /// A plain receiver argument (`FnArg::Receiver`).
    None,
    /// `self: Pin<&Self>` or `self: Pin<&mut Self>`; holds the capture mode and the `Pin`
    /// path with its generic arguments removed.
    Pin(CaptureMode, Path),
}
/// How a pinned `self` is borrowed.
enum CaptureMode {
    /// `&Self` (`mutability: false`) or `&mut Self` (`mutability: true`).
    Ref { mutability: bool },
}
impl SelfType {
fn parse(arg: Option<&FnArg>) -> Result<Self> {
fn remove_last_path_args(mut path: Path) -> Path {
path.segments.last_mut().unwrap().arguments = PathArguments::None;
path
}
fn arg_to_string(arg: Option<&FnArg>) -> String {
arg.unwrap().clone().into_token_stream().to_string()
}
match arg {
Some(FnArg::Receiver(_)) => Ok(SelfType::None),
Some(FnArg::Typed(PatType { pat, ty, .. })) => match (&**pat, &**ty) {
(
Pat::Ident(PatIdent { ident, .. }),
Type::Path(TypePath { qself: None, path }),
) if ident == "self" => {
let ty = &path.segments[path.segments.len() - 1];
if let PathArguments::AngleBracketed(args) = &ty.arguments {
if args.args.len() == 1 && ty.ident == "Pin" {
if let GenericArgument::Type(Type::Reference(TypeReference {
mutability,
elem,
..
})) = &args.args[0]
{
match &**elem {
Type::Path(TypePath { path: p, qself: None })
if p.is_ident("Self") =>
{
return Ok(SelfType::Pin(
CaptureMode::Ref { mutability: mutability.is_some() },
remove_last_path_args(path.clone()),
));
}
_ => {}
}
}
}
}
error!(arg, "unsupported first argument type: {}", arg_to_string(arg))
}
_ => error!(
arg,
"methods that do not have `self` argument are not supported: {}",
arg_to_string(arg)
),
},
None => error!(arg, "methods without arguments are not supported"),
}
}
}