Skip to main content

easy_ext/
lib.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3/*!
4<!-- Note: Document from sync-markdown-to-rustdoc:start through sync-markdown-to-rustdoc:end
5     is synchronized from README.md. Any changes to that range are not preserved. -->
6<!-- tidy:sync-markdown-to-rustdoc:start -->
7
A lightweight attribute macro for easily writing the [extension trait pattern][rfc0445].
9
10```toml
11[dependencies]
12easy-ext = "1"
13```
14
15## Examples
16
17```
18use easy_ext::ext;
19
20#[ext(ResultExt)]
21pub impl<T, E> Result<T, E> {
22    fn err_into<U>(self) -> Result<T, U>
23    where
24        E: Into<U>,
25    {
26        self.map_err(Into::into)
27    }
28}
29```
30
31Code like this will be generated:
32
33```
34pub trait ResultExt<T, E> {
35    fn err_into<U>(self) -> Result<T, U>
36    where
37        E: Into<U>;
38}
39
40impl<T, E> ResultExt<T, E> for Result<T, E> {
41    fn err_into<U>(self) -> Result<T, U>
42    where
43        E: Into<U>,
44    {
45        self.map_err(Into::into)
46    }
47}
48```
49
50You can elide the trait name.
51
52```
53use easy_ext::ext;
54
55#[ext]
56impl<T, E> Result<T, E> {
57    fn err_into<U>(self) -> Result<T, U>
58    where
59        E: Into<U>,
60    {
61        self.map_err(Into::into)
62    }
63}
64```
65
Note that in this case, `#[ext]` assigns an automatically generated name, so you cannot
import/export the generated trait.
68
69### Visibility
70
71There are two ways to specify visibility.
72
73#### Impl-level visibility
74
75The first way is to specify visibility at the impl level. For example:
76
77```
78use easy_ext::ext;
79
80#[ext(StrExt)] // generate `pub trait StrExt`
81pub impl str {
82    fn foo(&self) {}
83}
84```
85
86#### Associated-item-level visibility
87
88Another way is to specify visibility at the associated item level.
89
90For example, if the method is `pub` then the trait will also be `pub`:
91
92```
93use easy_ext::ext;
94
95#[ext(ResultExt)] // generate `pub trait ResultExt`
96impl<T, E> Result<T, E> {
97    pub fn err_into<U>(self) -> Result<T, U>
98    where
99        E: Into<U>,
100    {
101        self.map_err(Into::into)
102    }
103}
104```
105
This is useful when migrating from an inherent impl to an extension trait.
107
108Note that the visibility of all the associated items in the `impl` must be identical.
109
110Note that you cannot specify impl-level visibility and associated-item-level visibility at the same time.
111
112### [Supertraits](https://doc.rust-lang.org/reference/items/traits.html#supertraits)
113
If you want the extension trait to be a subtrait of another trait,
add a `Self: SuperTrait` bound to the `where` clause.
116
117```
118use easy_ext::ext;
119
120#[ext(Ext)]
121impl<T> T
122where
123    Self: Default,
124{
125    fn method(&self) {}
126}
127```
128
129### Supported items
130
131#### [Associated functions (methods)](https://doc.rust-lang.org/reference/items/associated-items.html#associated-functions-and-methods)
132
133```
134use easy_ext::ext;
135
136#[ext]
137impl<T> T {
138    fn method(&self) {}
139}
140```
141
142#### [Associated constants](https://doc.rust-lang.org/reference/items/associated-items.html#associated-constants)
143
144```
145use easy_ext::ext;
146
147#[ext]
148impl<T> T {
149    const MSG: &'static str = "Hello!";
150}
151```
152
153#### [Associated types](https://doc.rust-lang.org/reference/items/associated-items.html#associated-types)
154
155```
156use easy_ext::ext;
157
158#[ext]
159impl str {
160    type Owned = String;
161
162    fn method(&self) -> Self::Owned {
163        self.to_owned()
164    }
165}
166```
167
168[rfc0445]: https://rust-lang.github.io/rfcs/0445-extension-trait-conventions.html
169
170<!-- tidy:sync-markdown-to-rustdoc:end -->
171*/
172
173#![doc(test(
174    no_crate_inject,
175    attr(allow(
176        dead_code,
177        unused_variables,
178        unreachable_pub,
179        clippy::undocumented_unsafe_blocks,
180        clippy::unused_trait_names,
181    ))
182))]
183#![forbid(unsafe_code)]
184
185// older compilers require explicit `extern crate`.
186#[allow(unused_extern_crates)]
187extern crate proc_macro;
188
189#[macro_use]
190mod error;
191
192mod ast;
193mod iter;
194mod to_tokens;
195
196use std::{collections::hash_map::DefaultHasher, hash::Hasher, iter::FromIterator, mem};
197
198use proc_macro::{Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree};
199
200use self::{
201    ast::{
202        Attribute, AttributeKind, FnArg, GenericParam, Generics, ImplItem, ItemImpl, ItemTrait,
203        PredicateType, Signature, TraitItem, TraitItemConst, TraitItemMethod, TraitItemType,
204        TypeParam, Visibility, WherePredicate, parsing,
205    },
206    error::{Error, Result},
207    iter::TokenIter,
208    to_tokens::ToTokens,
209};
210
211/// A lightweight attribute macro for easily writing [extension trait pattern][rfc0445].
212///
213/// See the [crate-level documentation](crate) for details.
214///
215/// [rfc0445]: https://rust-lang.github.io/rfcs/0445-extension-trait-conventions.html
216#[proc_macro_attribute]
217pub fn ext(args: TokenStream, input: TokenStream) -> TokenStream {
218    expand(args, input).unwrap_or_else(Error::into_compile_error)
219}
220
221fn expand(args: TokenStream, input: TokenStream) -> Result<TokenStream> {
222    let trait_name = match parse_args(args)? {
223        None => Ident::new(&format!("__ExtTrait{}", hash(&input)), Span::call_site()),
224        Some(trait_name) => trait_name,
225    };
226
227    let mut item: ItemImpl = parsing::parse_impl(&mut TokenIter::new(input))?;
228
229    let mut tokens = trait_from_impl(&mut item, trait_name)?.to_token_stream();
230    item.to_tokens(&mut tokens);
231    Ok(tokens)
232}
233
234fn parse_args(input: TokenStream) -> Result<Option<Ident>> {
235    let input = &mut TokenIter::new(input);
236    let vis = ast::parsing::parse_visibility(input)?;
237    if !vis.is_inherited() {
238        bail!(vis, "use `{} impl` instead", vis);
239    }
240    let trait_name = input.parse_ident_opt();
241    if !input.is_empty() {
242        let tt = input.next().unwrap();
243        bail!(tt, "unexpected token: `{}`", tt);
244    }
245    Ok(trait_name)
246}
247
248fn determine_trait_generics<'a>(
249    generics: &mut Generics,
250    self_ty: &'a [TokenTree],
251) -> Option<&'a Ident> {
252    if self_ty.len() != 1 {
253        return None;
254    }
255    if let TokenTree::Ident(self_ty) = &self_ty[0] {
256        let i = generics.params.iter().position(|(param, _)| {
257            if let GenericParam::Type(param) = param {
258                param.ident.to_string() == self_ty.to_string()
259            } else {
260                false
261            }
262        });
263        if let Some(i) = i {
264            let mut params = mem::replace(&mut generics.params, vec![]);
265            let (param, _) = params.remove(i);
266            generics.params = params;
267
268            if let GenericParam::Type(TypeParam {
269                colon_token: Some(colon_token), bounds, ..
270            }) = param
271            {
272                let bounds = bounds.into_iter().filter(|(b, _)| !b.is_maybe).collect::<Vec<_>>();
273                if !bounds.is_empty() {
274                    let where_clause = generics.make_where_clause();
275                    if let Some((_, p)) = where_clause.predicates.last_mut() {
276                        p.get_or_insert_with(|| Punct::new(',', Spacing::Alone));
277                    }
278                    where_clause.predicates.push((
279                        WherePredicate::Type(PredicateType {
280                            lifetimes: None,
281                            bounded_ty: std::iter::once(TokenTree::Ident(Ident::new(
282                                "Self",
283                                self_ty.span(),
284                            )))
285                            .collect(),
286                            colon_token,
287                            bounds,
288                        }),
289                        None,
290                    ));
291                }
292            }
293
294            return Some(self_ty);
295        }
296    }
297    None
298}
299
/// Builds the extension trait for `item` and turns `item` into an impl of it.
///
/// The trait's generics start as a copy of the impl's generics. When the impl's
/// self type is one of its own type parameters (see `determine_trait_generics`),
/// that parameter is dropped from the trait's generics. Every occurrence of its
/// name in the trait's where clause and item signatures is rewritten to `Self`.
///
/// Mutates `item`:
/// - `item.trait_` is set to `<trait_name><type generics> for`;
/// - doc attributes are removed from the impl (the trait keeps them; see
///   https://github.com/taiki-e/easy-ext/issues/20);
/// - each associated item is passed to `trait_item_from_impl_item`, which also
///   edits it in place.
///
/// Errors returned by `trait_item_from_impl_item` (e.g. inconsistent
/// associated-item visibility) are propagated.
fn trait_from_impl(item: &mut ItemImpl, trait_name: Ident) -> Result<ItemTrait> {
    /// Replace `self_ty` with `Self`.
    ///
    /// The replacement is purely token-based: any identifier spelled exactly
    /// like `self_ty` is rewritten, wherever it appears.
    struct ReplaceParam {
        // Name of the impl's self-type parameter (e.g. `T` in `impl<T> T`).
        self_ty: String,
        // Restrict the scope for removing `?Trait` bounds, because `?Trait`
        // bounds are only permitted at the point where a type parameter is
        // declared.
        remove_maybe: bool,
    }

    impl ReplaceParam {
        /// Recursively rewrites every identifier equal to `self.self_ty` into
        /// `Self` (keeping that identifier's span), descending into delimited
        /// groups. `tokens` is only reassigned when something was replaced,
        /// and the return value reports whether that happened.
        fn visit_token_stream(&self, tokens: &mut TokenStream) -> bool {
            let mut out: Vec<TokenTree> = vec![];
            let mut modified = false;
            let iter = tokens.clone().into_iter();
            for tt in iter {
                match tt {
                    TokenTree::Ident(ident) => {
                        if ident.to_string() == self.self_ty {
                            modified = true;
                            let self_ = Ident::new("Self", ident.span());
                            out.push(self_.into());
                        } else {
                            out.push(TokenTree::Ident(ident));
                        }
                    }
                    TokenTree::Group(group) => {
                        let mut content = group.stream();
                        modified |= self.visit_token_stream(&mut content);
                        // Rebuild the group so the (possibly) rewritten content keeps
                        // the original delimiter and span.
                        let mut new = Group::new(group.delimiter(), content);
                        new.set_span(group.span());
                        out.push(TokenTree::Group(new));
                    }
                    other => out.push(other),
                }
            }
            if modified {
                *tokens = TokenStream::from_iter(out);
            }
            modified
        }

        // Everything below is simply traversing the syntax tree.

        fn visit_trait_item_mut(&self, node: &mut TraitItem) {
            match node {
                TraitItem::Const(node) => {
                    self.visit_token_stream(&mut node.ty);
                }
                TraitItem::Method(node) => {
                    self.visit_signature_mut(&mut node.sig);
                }
                TraitItem::Type(node) => {
                    self.visit_generics_mut(&mut node.generics);
                }
            }
        }

        fn visit_signature_mut(&self, node: &mut Signature) {
            self.visit_generics_mut(&mut node.generics);
            for arg in &mut node.inputs {
                self.visit_fn_arg_mut(arg);
            }
            if let Some(ty) = &mut node.output {
                self.visit_token_stream(ty);
            }
        }

        fn visit_fn_arg_mut(&self, node: &mut FnArg) {
            match node {
                FnArg::Receiver(pat, _) => {
                    self.visit_token_stream(pat);
                }
                FnArg::Typed(pat, _, ty, _) => {
                    self.visit_token_stream(pat);
                    self.visit_token_stream(ty);
                }
            }
        }

        /// Rewrites `self_ty` to `Self` in type-parameter bounds and in
        /// where-clause type predicates. Const/lifetime parameters and
        /// lifetime predicates are left unchanged.
        ///
        /// When `remove_maybe` is set, a predicate whose bounded type is exactly
        /// `self_ty` loses its `?Trait` bounds. If no bounds remain, the
        /// whole predicate is dropped.
        fn visit_generics_mut(&self, generics: &mut Generics) {
            for (param, _) in &mut generics.params {
                match param {
                    GenericParam::Type(param) => {
                        for (bound, _) in &mut param.bounds {
                            self.visit_token_stream(&mut bound.tokens);
                        }
                    }
                    GenericParam::Const(_) | GenericParam::Lifetime(_) => {}
                }
            }
            if let Some(where_clause) = &mut generics.where_clause {
                // Take the predicates out and push each (possibly rewritten) one
                // back in its original order.
                let predicates = Vec::with_capacity(where_clause.predicates.len());
                for (mut predicate, p) in mem::replace(&mut where_clause.predicates, predicates) {
                    match &mut predicate {
                        WherePredicate::Type(pred) => {
                            if self.remove_maybe {
                                // Only a bounded type consisting of the single
                                // identifier `self_ty` qualifies.
                                let mut iter = pred.bounded_ty.clone().into_iter();
                                if let Some(TokenTree::Ident(i)) = iter.next() {
                                    if iter.next().is_none() && self.self_ty == i.to_string() {
                                        let bounds = mem::replace(&mut pred.bounds, vec![])
                                            .into_iter()
                                            .filter(|(b, _)| !b.is_maybe)
                                            .collect::<Vec<_>>();
                                        if !bounds.is_empty() {
                                            self.visit_token_stream(&mut pred.bounded_ty);
                                            pred.bounds = bounds;
                                            for (bound, _) in &mut pred.bounds {
                                                self.visit_token_stream(&mut bound.tokens);
                                            }
                                            where_clause.predicates.push((predicate, p));
                                        }
                                        // Either pushed above or dropped entirely
                                        // (it only had `?Trait` bounds).
                                        continue;
                                    }
                                }
                            }

                            self.visit_token_stream(&mut pred.bounded_ty);
                            for (bound, _) in &mut pred.bounds {
                                self.visit_token_stream(&mut bound.tokens);
                            }
                        }
                        WherePredicate::Lifetime(_) => {}
                    }
                    where_clause.predicates.push((predicate, p));
                }
            }
        }
    }

    // Work on a copy: it becomes the trait's generics, while `item.generics`
    // remains the impl's generics.
    let mut generics = item.generics.clone();
    // `Some` only when the impl's self type is one of its type parameters.
    let mut visitor = determine_trait_generics(&mut generics, &item.self_ty)
        .map(|self_ty| ReplaceParam { self_ty: self_ty.to_string(), remove_maybe: false });

    if let Some(visitor) = &mut visitor {
        // `?Trait` bounds are stripped only from the trait's own generics; the
        // associated items visited later keep them.
        visitor.remove_maybe = true;
        visitor.visit_generics_mut(&mut generics);
        visitor.remove_maybe = false;
    }
    // The impl now implements the generated trait, parameterized by the
    // trait's type generics.
    let ty_generics = generics.ty_generics();
    item.trait_ = Some((
        trait_name.clone(),
        ty_generics.to_token_stream(),
        Ident::new("for", Span::call_site()),
    ));

    // impl-level visibility
    let impl_vis = if item.vis.is_inherited() { None } else { Some(item.vis.clone()) };
    // assoc-item-level visibility
    let mut assoc_vis = None;
    // Convert each associated item into a trait item. `assoc_vis` is filled in
    // and checked by `trait_item_from_impl_item`. The first error aborts.
    let mut items = Vec::with_capacity(item.items.len());
    item.items.iter_mut().try_for_each(|item| {
        trait_item_from_impl_item(item, &mut assoc_vis, impl_vis.as_ref()).map(|mut item| {
            if let Some(visitor) = &mut visitor {
                visitor.visit_trait_item_mut(&mut item);
            }
            items.push(item);
        })
    })?;

    // The trait gets a copy of all impl attributes (docs included) before the
    // docs are stripped from the impl. It also gets
    // `#[allow(patterns_in_fns_without_body)]`; the trailing note ties this to
    // `mut self` remaining in method signatures.
    let mut attrs = item.attrs.clone();
    find_remove(&mut item.attrs, AttributeKind::Doc); // https://github.com/taiki-e/easy-ext/issues/20
    attrs.push(Attribute::new(vec![
        TokenTree::Ident(Ident::new("allow", Span::call_site())),
        TokenTree::Group(Group::new(
            Delimiter::Parenthesis,
            std::iter::once(TokenTree::Ident(Ident::new(
                "patterns_in_fns_without_body",
                Span::call_site(),
            )))
            .collect(),
        )),
    ])); // mut self

    Ok(ItemTrait {
        attrs,
        // priority: impl-level visibility > assoc-item-level visibility > inherited visibility
        vis: impl_vis.unwrap_or_else(|| assoc_vis.unwrap_or(Visibility::Inherited)),
        // `unsafe impl` produces an `unsafe trait`.
        unsafety: item.unsafety.clone(),
        // Span the `trait` keyword like the original `impl` keyword.
        trait_token: Ident::new("trait", item.impl_token.span()),
        ident: trait_name,
        generics,
        brace_token: item.brace_token,
        items,
    })
}
486
487fn trait_item_from_impl_item(
488    impl_item: &mut ImplItem,
489    prev_vis: &mut Option<Visibility>,
490    impl_vis: Option<&Visibility>,
491) -> Result<TraitItem> {
492    fn check_visibility(
493        current: Visibility,
494        prev: &mut Option<Visibility>,
495        impl_vis: Option<&Visibility>,
496        span: &dyn ToTokens,
497    ) -> Result<()> {
498        if impl_vis.is_some() {
499            if current.is_inherited() {
500                return Ok(());
501            }
502            bail!(current, "all associated items must have inherited visibility");
503        }
504        match prev {
505            None => *prev = Some(current),
506            Some(prev) if *prev == current => {}
507            Some(prev) => {
508                if prev.is_inherited() {
509                    bail!(current, "all associated items must have inherited visibility");
510                }
511                bail!(
512                    if current.is_inherited() { span } else { &current },
513                    "all associated items must have a visibility of `{}`",
514                    prev,
515                );
516            }
517        }
518        Ok(())
519    }
520
521    match impl_item {
522        ImplItem::Const(impl_const) => {
523            let vis = mem::replace(&mut impl_const.vis, Visibility::Inherited);
524            check_visibility(vis, prev_vis, impl_vis, &impl_const.ident)?;
525
526            let attrs = impl_const.attrs.clone();
527            find_remove(&mut impl_const.attrs, AttributeKind::Doc); // https://github.com/taiki-e/easy-ext/issues/20
528            Ok(TraitItem::Const(TraitItemConst {
529                attrs,
530                const_token: impl_const.const_token.clone(),
531                ident: impl_const.ident.clone(),
532                colon_token: impl_const.colon_token.clone(),
533                ty: impl_const.ty.clone(),
534                semi_token: impl_const.semi_token.clone(),
535            }))
536        }
537        ImplItem::Type(impl_type) => {
538            let vis = mem::replace(&mut impl_type.vis, Visibility::Inherited);
539            check_visibility(vis, prev_vis, impl_vis, &impl_type.ident)?;
540
541            let attrs = impl_type.attrs.clone();
542            find_remove(&mut impl_type.attrs, AttributeKind::Doc); // https://github.com/taiki-e/easy-ext/issues/20
543            Ok(TraitItem::Type(TraitItemType {
544                attrs,
545                type_token: impl_type.type_token.clone(),
546                ident: impl_type.ident.clone(),
547                generics: impl_type.generics.clone(),
548                semi_token: impl_type.semi_token.clone(),
549            }))
550        }
551        ImplItem::Method(impl_method) => {
552            let vis = mem::replace(&mut impl_method.vis, Visibility::Inherited);
553            check_visibility(vis, prev_vis, impl_vis, &impl_method.sig.ident)?;
554
555            let mut attrs = impl_method.attrs.clone();
556            find_remove(&mut impl_method.attrs, AttributeKind::Doc); // https://github.com/taiki-e/easy-ext/issues/20
557            find_remove(&mut attrs, AttributeKind::Inline); // `#[inline]` is ignored on function prototypes
558            Ok(TraitItem::Method(TraitItemMethod {
559                attrs,
560                sig: {
561                    let mut sig = impl_method.sig.clone();
562                    for arg in &mut sig.inputs {
563                        if let FnArg::Typed(pat, ..) = arg {
564                            if pat.to_string() != "self" {
565                                *pat = std::iter::once(TokenTree::Ident(Ident::new(
566                                    "_",
567                                    pat.clone().into_iter().next().unwrap().span(),
568                                )))
569                                .collect();
570                            }
571                        }
572                    }
573                    sig
574                },
575                semi_token: {
576                    let mut punct = Punct::new(';', Spacing::Alone);
577                    punct.set_span(impl_method.body.span());
578                    punct
579                },
580            }))
581        }
582    }
583}
584
585fn find_remove(attrs: &mut Vec<Attribute>, kind: AttributeKind) {
586    while let Some(i) = attrs.iter().position(|attr| attr.kind == kind) {
587        attrs.remove(i);
588    }
589}
590
/// Hashes the textual rendering of `input` into a `u64`.
///
/// Used to derive a trait name when `#[ext]` is given none. Every
/// `DefaultHasher::new()` instance behaves the same, so equal inputs produce
/// equal values within a build.
fn hash(input: &TokenStream) -> u64 {
    let rendered = input.to_string();
    let mut state = DefaultHasher::new();
    state.write(rendered.as_bytes());
    state.finish()
}