#![doc(test(
no_crate_inject,
attr(
deny(warnings, rust_2018_idioms, single_use_lifetimes),
allow(dead_code, unused_variables)
)
))]
#![forbid(unsafe_code)]
#![warn(rust_2018_idioms, unreachable_pub)]
#![cfg_attr(test, warn(single_use_lifetimes))]
#![warn(clippy::pedantic)]
#![allow(
clippy::cast_lossless,
clippy::derive_partial_eq_without_eq,
clippy::similar_names,
clippy::too_many_lines
)]
#[allow(unused_extern_crates)]
extern crate proc_macro;
#[macro_use]
mod error;
mod ast;
mod iter;
mod to_tokens;
use std::{collections::hash_map::DefaultHasher, hash::Hasher, iter::FromIterator, mem};
use proc_macro::{Delimiter, Group, Ident, Span, TokenStream, TokenTree};
use crate::{
ast::{
parsing, printing::punct, Attribute, AttributeKind, FnArg, GenericParam, Generics,
ImplItem, ItemImpl, ItemTrait, PredicateType, Signature, TraitItem, TraitItemConst,
TraitItemMethod, TraitItemType, TypeParam, Visibility, WherePredicate,
},
error::{Error, Result},
iter::TokenIter,
to_tokens::ToTokens,
};
#[proc_macro_attribute]
pub fn ext(args: TokenStream, input: TokenStream) -> TokenStream {
expand(args, input).unwrap_or_else(Error::into_compile_error)
}
/// Performs the full `#[ext]` expansion.
///
/// The output is the generated trait definition followed by the annotated
/// impl block, which `trait_from_impl` rewrites in place so that it
/// implements that trait.
///
/// When no trait name is given in `args`, a name of the form
/// `__ExtTrait<hash>` is derived from a hash of the impl's tokens.
///
/// # Errors
///
/// Propagates errors from argument parsing, impl parsing, and trait
/// construction.
fn expand(args: TokenStream, input: TokenStream) -> Result<TokenStream> {
    let trait_name = parse_args(args)?.unwrap_or_else(|| {
        Ident::new(&format!("__ExtTrait{}", hash(&input)), Span::call_site())
    });

    let mut item: ItemImpl = parsing::parse_impl(&mut TokenIter::new(input))?;
    let ext_trait = trait_from_impl(&mut item, trait_name)?;

    let mut output = ext_trait.to_token_stream();
    output.extend(item.to_token_stream());
    Ok(output)
}
/// Parses the arguments of the `#[ext(...)]` attribute.
///
/// The arguments may be empty or a single identifier naming the trait to
/// generate; the identifier, if any, is returned.
///
/// # Errors
///
/// - A non-inherited visibility is rejected, with a hint to write it as
///   `<vis> impl` instead.
/// - Any token left after the optional identifier is reported as unexpected.
fn parse_args(input: TokenStream) -> Result<Option<Ident>> {
    let tokens = &mut TokenIter::new(input);

    let vis = ast::parsing::parse_visibility(tokens)?;
    if !vis.is_inherited() {
        bail!(vis, "use `{} impl` instead", vis);
    }

    let trait_name = tokens.parse_ident_opt();
    if tokens.is_empty() {
        return Ok(trait_name);
    }

    // The stream was just checked to be non-empty, so a token is available.
    let unexpected = tokens.next().unwrap();
    bail!(unexpected, "unexpected token: `{}`", unexpected);
}
/// Detects an impl whose self type is one of its own type parameters, such as
/// `impl<T: Bound> T { ... }`.
///
/// In that case the parameter is removed from `generics` and its identifier
/// is returned. The caller can then rewrite the remaining references to it
/// as `Self`.
///
/// Bounds written inline on the removed parameter (e.g. `T: A + ?Sized`) are
/// handled as follows:
/// - Bounds that are not `?`-bounds are re-added to the where clause of
///   `generics` as a single `Self: A` predicate, which reuses the parameter's
///   `:` token.
/// - `?`-bounds (`is_maybe`) are dropped.
/// - If the where clause already has predicates, the last one is given a
///   trailing comma first.
///
/// Returns `None` and leaves `generics` untouched when the self type is not a
/// single identifier, or when that identifier is not a type parameter of the
/// impl.
fn determine_trait_generics<'a>(
    generics: &mut Generics,
    self_ty: &'a [TokenTree],
) -> Option<&'a Ident> {
    let ident = match self_ty {
        [TokenTree::Ident(ident)] => ident,
        _ => return None,
    };

    let self_name = ident.to_string();
    let index = generics.params.iter().position(|(param, _)| match param {
        GenericParam::Type(param) => param.ident.to_string() == self_name,
        GenericParam::Const(_) | GenericParam::Lifetime(_) => false,
    })?;

    let (removed, _) = generics.params.remove(index);
    if let GenericParam::Type(TypeParam { colon_token: Some(colon_token), bounds, .. }) = removed {
        let kept: Vec<_> = bounds.into_iter().filter(|(bound, _)| !bound.is_maybe).collect();
        if !kept.is_empty() {
            let where_clause = generics.make_where_clause();
            // Terminate the current last predicate so the new one can follow it.
            if let Some((_, comma)) = where_clause.predicates.last_mut() {
                comma.get_or_insert_with(|| punct(',', Span::call_site()));
            }
            let self_token = TokenTree::Ident(Ident::new("Self", ident.span()));
            where_clause.predicates.push((
                WherePredicate::Type(PredicateType {
                    lifetimes: None,
                    bounded_ty: Some(self_token).into_iter().collect(),
                    colon_token,
                    bounds: kept,
                }),
                None,
            ));
        }
    }

    Some(ident)
}
/// Builds the trait declared by an `#[ext]` impl block and rewrites `item` so
/// that it implements that trait.
///
/// Effects on `item`:
/// - `item.trait_` is set to `trait_name` plus the trait's type generics,
///   which turns the impl into `impl ... trait_name<..> for SelfTy`.
/// - Doc attributes are removed from the impl itself.
/// - Every associated item goes through `trait_item_from_impl_item`, which
///   resets its visibility and strips its doc attributes.
///
/// The returned trait:
/// - carries all of the impl's original attributes (docs included) plus
///   `#[allow(patterns_in_fns_without_body)]`;
/// - has the impl's own visibility if one was written (`pub impl`);
///   otherwise the visibility shared by the associated items; otherwise
///   inherited;
/// - starts from the impl's generics.
///
/// When the self type is a bare type parameter of the impl, the trait's
/// generics are adjusted further (see `determine_trait_generics`):
/// - that parameter is removed;
/// - identifiers with its name in the trait's generics and item signatures
///   are rewritten to `Self`.
///
/// # Errors
///
/// Returns the first error produced by `trait_item_from_impl_item`
/// (inconsistent associated-item visibilities).
fn trait_from_impl(item: &mut ItemImpl, trait_name: Ident) -> Result<ItemTrait> {
    // Visitor that rewrites identifiers equal to `self_ty` into `Self`.
    struct ReplaceParam {
        // Name of the impl type parameter that is used as the self type.
        self_ty: String,
        // Extra behaviour for `visit_generics_mut`. For a where predicate whose
        // bounded type is exactly `self_ty`, it drops the `?` bounds
        // (`is_maybe`). It drops the whole predicate if no bound remains.
        remove_maybe: bool,
    }
    impl ReplaceParam {
        // Recursively replaces matching identifiers inside `tokens`, descending
        // into delimited groups. Returns whether anything was replaced;
        // `tokens` is only overwritten in that case.
        fn visit_token_stream(&self, tokens: &mut TokenStream) -> bool {
            let mut out: Vec<TokenTree> = Vec::new();
            let mut modified = false;
            let iter = tokens.clone().into_iter();
            for tt in iter {
                match tt {
                    TokenTree::Ident(ident) => {
                        // Matches by spelling only. NOTE(review): any identifier
                        // spelled like the parameter (e.g. a path segment or a
                        // field name) is rewritten too.
                        if ident.to_string() == self.self_ty {
                            modified = true;
                            let self_ = Ident::new("Self", ident.span());
                            out.push(self_.into());
                        } else {
                            out.push(TokenTree::Ident(ident));
                        }
                    }
                    TokenTree::Group(group) => {
                        // Rebuild the group with the visited contents, keeping
                        // its delimiter and span.
                        let mut content = group.stream();
                        modified |= self.visit_token_stream(&mut content);
                        let mut new = Group::new(group.delimiter(), content);
                        new.set_span(group.span());
                        out.push(TokenTree::Group(new));
                    }
                    other => out.push(other),
                }
            }
            if modified {
                *tokens = TokenStream::from_iter(out);
            }
            modified
        }
        // Rewrites the parts of a trait item that can mention the parameter:
        // the type of a const, the signature of a method, and the generics of
        // an associated type.
        fn visit_trait_item_mut(&mut self, node: &mut TraitItem) {
            match node {
                TraitItem::Const(node) => {
                    self.visit_token_stream(&mut node.ty);
                }
                TraitItem::Method(node) => {
                    self.visit_signature_mut(&mut node.sig);
                }
                TraitItem::Type(node) => {
                    self.visit_generics_mut(&mut node.generics);
                }
            }
        }
        // Rewrites a method signature: generics, every argument, return type.
        fn visit_signature_mut(&mut self, node: &mut Signature) {
            self.visit_generics_mut(&mut node.generics);
            for arg in &mut node.inputs {
                self.visit_fn_arg_mut(arg);
            }
            if let Some(ty) = &mut node.output {
                self.visit_token_stream(ty);
            }
        }
        // Rewrites an argument. This covers the receiver's tokens, and both
        // the pattern and the type of a typed argument.
        fn visit_fn_arg_mut(&mut self, node: &mut FnArg) {
            match node {
                FnArg::Receiver(pat, _) => {
                    self.visit_token_stream(pat);
                }
                FnArg::Typed(pat, _, ty, _) => {
                    self.visit_token_stream(pat);
                    self.visit_token_stream(ty);
                }
            }
        }
        // Rewrites bounds on type parameters (not the parameter names) and
        // type predicates in the where clause; lifetime and const parameters
        // and lifetime predicates are left alone.
        fn visit_generics_mut(&mut self, generics: &mut Generics) {
            for (param, _) in &mut generics.params {
                match param {
                    GenericParam::Type(param) => {
                        for (bound, _) in &mut param.bounds {
                            self.visit_token_stream(&mut bound.tokens);
                        }
                    }
                    GenericParam::Const(_) | GenericParam::Lifetime(_) => {}
                }
            }
            if let Some(where_clause) = &mut generics.where_clause {
                // Take the predicates out and push back, in order, the ones that
                // are kept, so individual predicates can be dropped.
                let predicates = Vec::with_capacity(where_clause.predicates.len());
                for (mut predicate, p) in mem::replace(&mut where_clause.predicates, predicates) {
                    match &mut predicate {
                        WherePredicate::Type(pred) => {
                            if self.remove_maybe {
                                // Special case: the bounded type is exactly the
                                // single identifier `self_ty`.
                                let mut iter = pred.bounded_ty.clone().into_iter();
                                if let Some(TokenTree::Ident(i)) = iter.next() {
                                    if iter.next().is_none() && self.self_ty == i.to_string() {
                                        // Keep only the non-`?` bounds; drop the
                                        // predicate when none remain.
                                        let bounds = mem::replace(&mut pred.bounds, Vec::new())
                                            .into_iter()
                                            .filter(|(b, _)| !b.is_maybe)
                                            .collect::<Vec<_>>();
                                        if !bounds.is_empty() {
                                            self.visit_token_stream(&mut pred.bounded_ty);
                                            pred.bounds = bounds;
                                            for (bound, _) in &mut pred.bounds {
                                                self.visit_token_stream(&mut bound.tokens);
                                            }
                                            where_clause.predicates.push((predicate, p));
                                        }
                                        // Already pushed (or deliberately dropped):
                                        // skip the common push below.
                                        continue;
                                    }
                                }
                            }
                            self.visit_token_stream(&mut pred.bounded_ty);
                            for (bound, _) in &mut pred.bounds {
                                self.visit_token_stream(&mut bound.tokens);
                            }
                        }
                        WherePredicate::Lifetime(_) => {}
                    }
                    where_clause.predicates.push((predicate, p));
                }
            }
        }
    }
    // The trait's generics start as a copy of the impl's generics.
    let mut generics = item.generics.clone();
    // `Some` only when the self type is a type parameter of the impl. In that
    // case the parameter has just been removed from `generics`, and its
    // remaining mentions must become `Self`.
    let mut visitor = determine_trait_generics(&mut generics, &item.self_ty)
        .map(|self_ty| ReplaceParam { self_ty: self_ty.to_string(), remove_maybe: false });
    if let Some(visitor) = &mut visitor {
        // `?` bounds on the self type are stripped only from the trait's own
        // generics. NOTE(review): presumably this is because `?Sized`-style
        // bounds are only accepted on declared type parameters, not on
        // `Self`. Confirm.
        visitor.remove_maybe = true;
        visitor.visit_generics_mut(&mut generics);
        visitor.remove_maybe = false;
    }
    // Generic arguments for the trait path in the impl header.
    let ty_generics = generics.ty_generics();
    item.trait_ = Some((
        trait_name.clone(),
        ty_generics.to_token_stream(),
        Ident::new("for", Span::call_site()),
    ));
    // Visibility written on the impl itself (`pub impl`), if any.
    let impl_vis = if item.vis.is_inherited() { None } else { Some(item.vis.clone()) };
    // Visibility shared by the associated items; recorded by
    // `trait_item_from_impl_item` when the impl has no visibility of its own.
    let mut assoc_vis = None;
    let mut items = Vec::with_capacity(item.items.len());
    // Convert each impl item into a trait item and rewrite the self-type
    // parameter to `Self`. The first error aborts the whole conversion.
    item.items.iter_mut().try_for_each(|item| {
        trait_item_from_impl_item(item, &mut assoc_vis, &impl_vis).map(|mut item| {
            if let Some(visitor) = &mut visitor {
                visitor.visit_trait_item_mut(&mut item);
            }
            items.push(item);
        })
    })?;
    // The trait gets a copy of all the impl's attributes (docs included), and
    // the impl then loses its doc attributes. The trait additionally gets
    // `#[allow(patterns_in_fns_without_body)]`. NOTE(review): presumably this
    // covers receiver patterns such as `mut self`, which are copied unchanged
    // into the bodiless trait methods. Confirm.
    let mut attrs = item.attrs.clone();
    find_remove(&mut item.attrs, AttributeKind::Doc); attrs.push(Attribute::new(vec![
        TokenTree::Ident(Ident::new("allow", Span::call_site())),
        TokenTree::Group(Group::new(
            Delimiter::Parenthesis,
            Some(TokenTree::Ident(Ident::new("patterns_in_fns_without_body", Span::call_site())))
                .into_iter()
                .collect(),
        )),
    ]));
    Ok(ItemTrait {
        attrs,
        // Precedence: impl visibility, then the items' shared visibility,
        // then inherited.
        vis: impl_vis.unwrap_or_else(|| assoc_vis.unwrap_or(Visibility::Inherited)),
        unsafety: item.unsafety.clone(),
        // Spanned at the `impl` keyword.
        trait_token: Ident::new("trait", item.impl_token.span()),
        ident: trait_name,
        generics,
        brace_token: item.brace_token,
        items,
    })
}
fn trait_item_from_impl_item(
impl_item: &mut ImplItem,
prev_vis: &mut Option<Visibility>,
impl_vis: &Option<Visibility>,
) -> Result<TraitItem> {
fn check_visibility(
current: Visibility,
prev: &mut Option<Visibility>,
impl_vis: &Option<Visibility>,
span: &dyn ToTokens,
) -> Result<()> {
if impl_vis.is_some() {
if current.is_inherited() {
return Ok(());
}
bail!(current, "all associated items must have inherited visibility");
}
match prev {
None => *prev = Some(current),
Some(prev) if *prev == current => {}
Some(prev) => {
if prev.is_inherited() {
bail!(current, "all associated items must have inherited visibility");
}
bail!(
if current.is_inherited() { span } else { ¤t },
"all associated items must have a visibility of `{}`",
prev,
);
}
}
Ok(())
}
match impl_item {
ImplItem::Const(impl_const) => {
let vis = mem::replace(&mut impl_const.vis, Visibility::Inherited);
check_visibility(vis, prev_vis, impl_vis, &impl_const.ident)?;
let attrs = impl_const.attrs.clone();
find_remove(&mut impl_const.attrs, AttributeKind::Doc); Ok(TraitItem::Const(TraitItemConst {
attrs,
const_token: impl_const.const_token.clone(),
ident: impl_const.ident.clone(),
colon_token: impl_const.colon_token.clone(),
ty: impl_const.ty.clone(),
semi_token: impl_const.semi_token.clone(),
}))
}
ImplItem::Type(impl_type) => {
let vis = mem::replace(&mut impl_type.vis, Visibility::Inherited);
check_visibility(vis, prev_vis, impl_vis, &impl_type.ident)?;
let attrs = impl_type.attrs.clone();
find_remove(&mut impl_type.attrs, AttributeKind::Doc); Ok(TraitItem::Type(TraitItemType {
attrs,
type_token: impl_type.type_token.clone(),
ident: impl_type.ident.clone(),
generics: impl_type.generics.clone(),
semi_token: impl_type.semi_token.clone(),
}))
}
ImplItem::Method(impl_method) => {
let vis = mem::replace(&mut impl_method.vis, Visibility::Inherited);
check_visibility(vis, prev_vis, impl_vis, &impl_method.sig.ident)?;
let mut attrs = impl_method.attrs.clone();
find_remove(&mut impl_method.attrs, AttributeKind::Doc); find_remove(&mut attrs, AttributeKind::Inline); Ok(TraitItem::Method(TraitItemMethod {
attrs,
sig: {
let mut sig = impl_method.sig.clone();
for arg in &mut sig.inputs {
if let FnArg::Typed(pat, ..) = arg {
*pat = Some(TokenTree::Ident(Ident::new(
"_",
pat.clone().into_iter().next().unwrap().span(),
)))
.into_iter()
.collect();
}
}
sig
},
semi_token: punct(';', impl_method.body.span()),
}))
}
}
}
/// Removes every attribute of the given `kind` from `attrs`. The remaining
/// attributes keep their relative order.
fn find_remove(attrs: &mut Vec<Attribute>, kind: AttributeKind) {
    // A single in-place pass. Repeatedly searching from the start and calling
    // `Vec::remove` (which shifts the tail) was quadratic in the number of
    // attributes.
    attrs.retain(|attr| attr.kind != kind);
}
/// Hashes the textual rendering of `input` with the standard library's
/// `DefaultHasher`.
///
/// `expand` uses the result to build the fallback trait name
/// (`__ExtTrait<hash>`), so the generated name depends on the impl block's
/// tokens.
fn hash(input: &TokenStream) -> u64 {
    let rendered = input.to_string();
    let mut state = DefaultHasher::new();
    state.write(rendered.as_bytes());
    state.finish()
}