easy_ext/
lib.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3/*!
4<!-- tidy:crate-doc:start -->
5A lightweight attribute macro for easily writing the [extension trait pattern][rfc0445].
6
7```toml
8[dependencies]
9easy-ext = "1"
10```
11
12## Examples
13
14```rust
15use easy_ext::ext;
16
17#[ext(ResultExt)]
18pub impl<T, E> Result<T, E> {
19    fn err_into<U>(self) -> Result<T, U>
20    where
21        E: Into<U>,
22    {
23        self.map_err(Into::into)
24    }
25}
26```
27
28Code like this will be generated:
29
30```rust
31pub trait ResultExt<T, E> {
32    fn err_into<U>(self) -> Result<T, U>
33    where
34        E: Into<U>;
35}
36
37impl<T, E> ResultExt<T, E> for Result<T, E> {
38    fn err_into<U>(self) -> Result<T, U>
39    where
40        E: Into<U>,
41    {
42        self.map_err(Into::into)
43    }
44}
45```
46
47You can elide the trait name.
48
49```rust
50use easy_ext::ext;
51
52#[ext]
53impl<T, E> Result<T, E> {
54    fn err_into<U>(self) -> Result<T, U>
55    where
56        E: Into<U>,
57    {
58        self.map_err(Into::into)
59    }
60}
61```
62
63Note that in this case, `#[ext]` assigns an automatically generated name (derived from a hash of the input), so you cannot
64import/export the generated trait.
65
66### Visibility
67
68There are two ways to specify visibility.
69
70#### Impl-level visibility
71
72The first way is to specify visibility at the impl level. For example:
73
74```rust
75use easy_ext::ext;
76
77// unnamed
78#[ext]
79pub impl str {
80    fn foo(&self) {}
81}
82
83// named
84#[ext(StrExt)]
85pub impl str {
86    fn bar(&self) {}
87}
88```
89
90#### Associated-item-level visibility
91
92Another way is to specify visibility at the associated item level.
93
94For example, if the method is `pub` then the trait will also be `pub`:
95
96```rust
97use easy_ext::ext;
98
99#[ext(ResultExt)] // generate `pub trait ResultExt`
100impl<T, E> Result<T, E> {
101    pub fn err_into<U>(self) -> Result<T, U>
102    where
103        E: Into<U>,
104    {
105        self.map_err(Into::into)
106    }
107}
108```
109
110This is useful when migrating from an inherent impl to an extension trait.
111
112Note that the visibility of all the associated items in the `impl` must be identical.
113
114Note that you cannot specify impl-level visibility and associated-item-level visibility at the same time.
115
116### [Supertraits](https://doc.rust-lang.org/reference/items/traits.html#supertraits)
117
118If you want the extension trait to be a subtrait of another trait,
119add a `Self: SuperTrait` bound to the `where` clause, where `SuperTrait` is the trait to extend.
120
121```rust
122use easy_ext::ext;
123
124#[ext(Ext)]
125impl<T> T
126where
127    Self: Default,
128{
129    fn method(&self) {}
130}
131```
132
133### Supported items
134
135#### [Associated functions (methods)](https://doc.rust-lang.org/reference/items/associated-items.html#associated-functions-and-methods)
136
137```rust
138use easy_ext::ext;
139
140#[ext]
141impl<T> T {
142    fn method(&self) {}
143}
144```
145
146#### [Associated constants](https://doc.rust-lang.org/reference/items/associated-items.html#associated-constants)
147
148```rust
149use easy_ext::ext;
150
151#[ext]
152impl<T> T {
153    const MSG: &'static str = "Hello!";
154}
155```
156
157#### [Associated types](https://doc.rust-lang.org/reference/items/associated-items.html#associated-types)
158
159```rust
160use easy_ext::ext;
161
162#[ext]
163impl str {
164    type Owned = String;
165
166    fn method(&self) -> Self::Owned {
167        self.to_owned()
168    }
169}
170```
171
172[rfc0445]: https://github.com/rust-lang/rfcs/blob/HEAD/text/0445-extension-trait-conventions.md
173
174<!-- tidy:crate-doc:end -->
175*/
176
177#![doc(test(
178    no_crate_inject,
179    attr(
180        deny(warnings, rust_2018_idioms, single_use_lifetimes),
181        allow(dead_code, unused_variables)
182    )
183))]
184#![forbid(unsafe_code)]
185
186// older compilers require explicit `extern crate`.
187#[allow(unused_extern_crates)]
188extern crate proc_macro;
189
190#[macro_use]
191mod error;
192
193mod ast;
194mod iter;
195mod to_tokens;
196
197use std::{collections::hash_map::DefaultHasher, hash::Hasher, iter::FromIterator, mem};
198
199use proc_macro::{Delimiter, Group, Ident, Span, TokenStream, TokenTree};
200
201use crate::{
202    ast::{
203        parsing, printing::punct, Attribute, AttributeKind, FnArg, GenericParam, Generics,
204        ImplItem, ItemImpl, ItemTrait, PredicateType, Signature, TraitItem, TraitItemConst,
205        TraitItemMethod, TraitItemType, TypeParam, Visibility, WherePredicate,
206    },
207    error::{Error, Result},
208    iter::TokenIter,
209    to_tokens::ToTokens,
210};
211
212/// A lightweight attribute macro for easily writing [extension trait pattern][rfc0445].
213///
214/// See crate level documentation for details.
215///
216/// [rfc0445]: https://github.com/rust-lang/rfcs/blob/HEAD/text/0445-extension-trait-conventions.md
217#[proc_macro_attribute]
218pub fn ext(args: TokenStream, input: TokenStream) -> TokenStream {
219    expand(args, input).unwrap_or_else(Error::into_compile_error)
220}
221
222fn expand(args: TokenStream, input: TokenStream) -> Result<TokenStream> {
223    let trait_name = match parse_args(args)? {
224        None => Ident::new(&format!("__ExtTrait{}", hash(&input)), Span::call_site()),
225        Some(trait_name) => trait_name,
226    };
227
228    let mut item: ItemImpl = parsing::parse_impl(&mut TokenIter::new(input))?;
229
230    let mut tokens = trait_from_impl(&mut item, trait_name)?.to_token_stream();
231    tokens.extend(item.to_token_stream());
232    Ok(tokens)
233}
234
235fn parse_args(input: TokenStream) -> Result<Option<Ident>> {
236    let input = &mut TokenIter::new(input);
237    let vis = ast::parsing::parse_visibility(input)?;
238    if !vis.is_inherited() {
239        bail!(vis, "use `{} impl` instead", vis);
240    }
241    let trait_name = input.parse_ident_opt();
242    if !input.is_empty() {
243        let tt = input.next().unwrap();
244        bail!(tt, "unexpected token: `{}`", tt);
245    }
246    Ok(trait_name)
247}
248
249fn determine_trait_generics<'a>(
250    generics: &mut Generics,
251    self_ty: &'a [TokenTree],
252) -> Option<&'a Ident> {
253    if self_ty.len() != 1 {
254        return None;
255    }
256    if let TokenTree::Ident(self_ty) = &self_ty[0] {
257        let i = generics.params.iter().position(|(param, _)| {
258            if let GenericParam::Type(param) = param {
259                param.ident.to_string() == self_ty.to_string()
260            } else {
261                false
262            }
263        });
264        if let Some(i) = i {
265            let mut params = mem::replace(&mut generics.params, vec![]);
266            let (param, _) = params.remove(i);
267            generics.params = params;
268
269            if let GenericParam::Type(TypeParam {
270                colon_token: Some(colon_token), bounds, ..
271            }) = param
272            {
273                let bounds = bounds.into_iter().filter(|(b, _)| !b.is_maybe).collect::<Vec<_>>();
274                if !bounds.is_empty() {
275                    let where_clause = generics.make_where_clause();
276                    if let Some((_, p)) = where_clause.predicates.last_mut() {
277                        p.get_or_insert_with(|| punct(',', Span::call_site()));
278                    }
279                    where_clause.predicates.push((
280                        WherePredicate::Type(PredicateType {
281                            lifetimes: None,
282                            bounded_ty: vec![TokenTree::Ident(Ident::new("Self", self_ty.span()))]
283                                .into_iter()
284                                .collect(),
285                            colon_token,
286                            bounds,
287                        }),
288                        None,
289                    ));
290                }
291            }
292
293            return Some(self_ty);
294        }
295    }
296    None
297}
298
299fn trait_from_impl(item: &mut ItemImpl, trait_name: Ident) -> Result<ItemTrait> {
300    /// Replace `self_ty` with `Self`.
301    struct ReplaceParam {
302        self_ty: String,
303        // Restrict the scope for removing `?Trait` bounds, because `?Trait`
304        // bounds are only permitted at the point where a type parameter is
305        // declared.
306        remove_maybe: bool,
307    }
308
309    impl ReplaceParam {
310        fn visit_token_stream(&self, tokens: &mut TokenStream) -> bool {
311            let mut out: Vec<TokenTree> = vec![];
312            let mut modified = false;
313            let iter = tokens.clone().into_iter();
314            for tt in iter {
315                match tt {
316                    TokenTree::Ident(ident) => {
317                        if ident.to_string() == self.self_ty {
318                            modified = true;
319                            let self_ = Ident::new("Self", ident.span());
320                            out.push(self_.into());
321                        } else {
322                            out.push(TokenTree::Ident(ident));
323                        }
324                    }
325                    TokenTree::Group(group) => {
326                        let mut content = group.stream();
327                        modified |= self.visit_token_stream(&mut content);
328                        let mut new = Group::new(group.delimiter(), content);
329                        new.set_span(group.span());
330                        out.push(TokenTree::Group(new));
331                    }
332                    other => out.push(other),
333                }
334            }
335            if modified {
336                *tokens = TokenStream::from_iter(out);
337            }
338            modified
339        }
340
341        // Everything below is simply traversing the syntax tree.
342
343        fn visit_trait_item_mut(&mut self, node: &mut TraitItem) {
344            match node {
345                TraitItem::Const(node) => {
346                    self.visit_token_stream(&mut node.ty);
347                }
348                TraitItem::Method(node) => {
349                    self.visit_signature_mut(&mut node.sig);
350                }
351                TraitItem::Type(node) => {
352                    self.visit_generics_mut(&mut node.generics);
353                }
354            }
355        }
356
357        fn visit_signature_mut(&mut self, node: &mut Signature) {
358            self.visit_generics_mut(&mut node.generics);
359            for arg in &mut node.inputs {
360                self.visit_fn_arg_mut(arg);
361            }
362            if let Some(ty) = &mut node.output {
363                self.visit_token_stream(ty);
364            }
365        }
366
367        fn visit_fn_arg_mut(&mut self, node: &mut FnArg) {
368            match node {
369                FnArg::Receiver(pat, _) => {
370                    self.visit_token_stream(pat);
371                }
372                FnArg::Typed(pat, _, ty, _) => {
373                    self.visit_token_stream(pat);
374                    self.visit_token_stream(ty);
375                }
376            }
377        }
378
379        fn visit_generics_mut(&mut self, generics: &mut Generics) {
380            for (param, _) in &mut generics.params {
381                match param {
382                    GenericParam::Type(param) => {
383                        for (bound, _) in &mut param.bounds {
384                            self.visit_token_stream(&mut bound.tokens);
385                        }
386                    }
387                    GenericParam::Const(_) | GenericParam::Lifetime(_) => {}
388                }
389            }
390            if let Some(where_clause) = &mut generics.where_clause {
391                let predicates = Vec::with_capacity(where_clause.predicates.len());
392                for (mut predicate, p) in mem::replace(&mut where_clause.predicates, predicates) {
393                    match &mut predicate {
394                        WherePredicate::Type(pred) => {
395                            if self.remove_maybe {
396                                let mut iter = pred.bounded_ty.clone().into_iter();
397                                if let Some(TokenTree::Ident(i)) = iter.next() {
398                                    if iter.next().is_none() && self.self_ty == i.to_string() {
399                                        let bounds = mem::replace(&mut pred.bounds, vec![])
400                                            .into_iter()
401                                            .filter(|(b, _)| !b.is_maybe)
402                                            .collect::<Vec<_>>();
403                                        if !bounds.is_empty() {
404                                            self.visit_token_stream(&mut pred.bounded_ty);
405                                            pred.bounds = bounds;
406                                            for (bound, _) in &mut pred.bounds {
407                                                self.visit_token_stream(&mut bound.tokens);
408                                            }
409                                            where_clause.predicates.push((predicate, p));
410                                        }
411                                        continue;
412                                    }
413                                }
414                            }
415
416                            self.visit_token_stream(&mut pred.bounded_ty);
417                            for (bound, _) in &mut pred.bounds {
418                                self.visit_token_stream(&mut bound.tokens);
419                            }
420                        }
421                        WherePredicate::Lifetime(_) => {}
422                    }
423                    where_clause.predicates.push((predicate, p));
424                }
425            }
426        }
427    }
428
429    let mut generics = item.generics.clone();
430    let mut visitor = determine_trait_generics(&mut generics, &item.self_ty)
431        .map(|self_ty| ReplaceParam { self_ty: self_ty.to_string(), remove_maybe: false });
432
433    if let Some(visitor) = &mut visitor {
434        visitor.remove_maybe = true;
435        visitor.visit_generics_mut(&mut generics);
436        visitor.remove_maybe = false;
437    }
438    let ty_generics = generics.ty_generics();
439    item.trait_ = Some((
440        trait_name.clone(),
441        ty_generics.to_token_stream(),
442        Ident::new("for", Span::call_site()),
443    ));
444
445    // impl-level visibility
446    let impl_vis = if item.vis.is_inherited() { None } else { Some(item.vis.clone()) };
447    // assoc-item-level visibility
448    let mut assoc_vis = None;
449    let mut items = Vec::with_capacity(item.items.len());
450    item.items.iter_mut().try_for_each(|item| {
451        trait_item_from_impl_item(item, &mut assoc_vis, &impl_vis).map(|mut item| {
452            if let Some(visitor) = &mut visitor {
453                visitor.visit_trait_item_mut(&mut item);
454            }
455            items.push(item);
456        })
457    })?;
458
459    let mut attrs = item.attrs.clone();
460    find_remove(&mut item.attrs, AttributeKind::Doc); // https://github.com/taiki-e/easy-ext/issues/20
461    attrs.push(Attribute::new(vec![
462        TokenTree::Ident(Ident::new("allow", Span::call_site())),
463        TokenTree::Group(Group::new(
464            Delimiter::Parenthesis,
465            std::iter::once(TokenTree::Ident(Ident::new(
466                "patterns_in_fns_without_body",
467                Span::call_site(),
468            )))
469            .collect(),
470        )),
471    ])); // mut self
472
473    Ok(ItemTrait {
474        attrs,
475        // priority: impl-level visibility > assoc-item-level visibility > inherited visibility
476        vis: impl_vis.unwrap_or_else(|| assoc_vis.unwrap_or(Visibility::Inherited)),
477        unsafety: item.unsafety.clone(),
478        trait_token: Ident::new("trait", item.impl_token.span()),
479        ident: trait_name,
480        generics,
481        brace_token: item.brace_token,
482        items,
483    })
484}
485
486fn trait_item_from_impl_item(
487    impl_item: &mut ImplItem,
488    prev_vis: &mut Option<Visibility>,
489    impl_vis: &Option<Visibility>,
490) -> Result<TraitItem> {
491    fn check_visibility(
492        current: Visibility,
493        prev: &mut Option<Visibility>,
494        impl_vis: &Option<Visibility>,
495        span: &dyn ToTokens,
496    ) -> Result<()> {
497        if impl_vis.is_some() {
498            if current.is_inherited() {
499                return Ok(());
500            }
501            bail!(current, "all associated items must have inherited visibility");
502        }
503        match prev {
504            None => *prev = Some(current),
505            Some(prev) if *prev == current => {}
506            Some(prev) => {
507                if prev.is_inherited() {
508                    bail!(current, "all associated items must have inherited visibility");
509                }
510                bail!(
511                    if current.is_inherited() { span } else { &current },
512                    "all associated items must have a visibility of `{}`",
513                    prev,
514                );
515            }
516        }
517        Ok(())
518    }
519
520    match impl_item {
521        ImplItem::Const(impl_const) => {
522            let vis = mem::replace(&mut impl_const.vis, Visibility::Inherited);
523            check_visibility(vis, prev_vis, impl_vis, &impl_const.ident)?;
524
525            let attrs = impl_const.attrs.clone();
526            find_remove(&mut impl_const.attrs, AttributeKind::Doc); // https://github.com/taiki-e/easy-ext/issues/20
527            Ok(TraitItem::Const(TraitItemConst {
528                attrs,
529                const_token: impl_const.const_token.clone(),
530                ident: impl_const.ident.clone(),
531                colon_token: impl_const.colon_token.clone(),
532                ty: impl_const.ty.clone(),
533                semi_token: impl_const.semi_token.clone(),
534            }))
535        }
536        ImplItem::Type(impl_type) => {
537            let vis = mem::replace(&mut impl_type.vis, Visibility::Inherited);
538            check_visibility(vis, prev_vis, impl_vis, &impl_type.ident)?;
539
540            let attrs = impl_type.attrs.clone();
541            find_remove(&mut impl_type.attrs, AttributeKind::Doc); // https://github.com/taiki-e/easy-ext/issues/20
542            Ok(TraitItem::Type(TraitItemType {
543                attrs,
544                type_token: impl_type.type_token.clone(),
545                ident: impl_type.ident.clone(),
546                generics: impl_type.generics.clone(),
547                semi_token: impl_type.semi_token.clone(),
548            }))
549        }
550        ImplItem::Method(impl_method) => {
551            let vis = mem::replace(&mut impl_method.vis, Visibility::Inherited);
552            check_visibility(vis, prev_vis, impl_vis, &impl_method.sig.ident)?;
553
554            let mut attrs = impl_method.attrs.clone();
555            find_remove(&mut impl_method.attrs, AttributeKind::Doc); // https://github.com/taiki-e/easy-ext/issues/20
556            find_remove(&mut attrs, AttributeKind::Inline); // `#[inline]` is ignored on function prototypes
557            Ok(TraitItem::Method(TraitItemMethod {
558                attrs,
559                sig: {
560                    let mut sig = impl_method.sig.clone();
561                    for arg in &mut sig.inputs {
562                        if let FnArg::Typed(pat, ..) = arg {
563                            if pat.to_string() != "self" {
564                                *pat = std::iter::once(TokenTree::Ident(Ident::new(
565                                    "_",
566                                    pat.clone().into_iter().next().unwrap().span(),
567                                )))
568                                .collect();
569                            }
570                        }
571                    }
572                    sig
573                },
574                semi_token: punct(';', impl_method.body.span()),
575            }))
576        }
577    }
578}
579
580fn find_remove(attrs: &mut Vec<Attribute>, kind: AttributeKind) {
581    while let Some(i) = attrs.iter().position(|attr| attr.kind == kind) {
582        attrs.remove(i);
583    }
584}
585
/// Hashes the textual (`Display`) form of `input` with `DefaultHasher`.
///
/// `expand` uses the result to build a name for a trait the user did not name.
fn hash(input: &TokenStream) -> u64 {
    let source = input.to_string();
    let mut state = DefaultHasher::new();
    state.write(source.as_bytes());
    state.finish()
}