closure_it/
lib.rs

1#![doc = include_str!("../README.md")]
2#![allow(clippy::needless_doctest_main)]
3
4use proc_macro::{
5    Delimiter, Group, Ident, Literal, Punct, Spacing::*, Span, TokenStream,
6    TokenTree,
7};
8
/// Gather a sequence of token trees into a single [`TokenStream`].
#[must_use]
fn stream<I: IntoIterator<Item = TokenTree>>(iter: I) -> TokenStream {
    iter.into_iter().collect()
}
15
16fn err(msg: &str, span: Span) -> TokenStream {
17    let s = |mut t: TokenTree| {
18        t.set_span(span);
19        t
20    };
21    stream([
22        s(Punct::new(':', Joint).into()),
23        s(Punct::new(':', Joint).into()),
24        s(Ident::new("core", span).into()),
25        s(Punct::new(':', Joint).into()),
26        s(Punct::new(':', Joint).into()),
27        s(Ident::new("compile_error", span).into()),
28        s(Punct::new('!', Joint).into()),
29        s(Group::new(Delimiter::Brace, stream([
30            s(Literal::string(msg).into()),
31        ])).into()),
32    ])
33}
34
/// Rewriting state for a single closure scope.
#[derive(Clone, Default)]
struct Closure<'a> {
    /// Identifier treated as the implicit closure parameter (`it` unless the
    /// attribute names another one).
    catch_it: &'a str,
    /// The first matching identifier token seen in this scope, if any. It
    /// becomes the closure parameter and is reused for every occurrence.
    it: Option<TokenTree>,
}
40impl Closure<'_> {
41    fn make_closure(&mut self) -> TokenStream {
42        let Some(it) = self.it.take() else {
43            return TokenStream::new();
44        };
45        stream([
46            Punct::new('|', Joint).into(),
47            it,
48            Punct::new('|', Joint).into(),
49        ])
50    }
51
52    fn ext_proc_it(&mut self, input: TokenStream) -> TokenStream {
53        let ext = &mut Self { catch_it: self.catch_it, ..Default::default() };
54        let proc_it = ext.proc_it(input);
55        ext.make_closure().into_iter().chain(proc_it).collect()
56    }
57
58    fn proc_it(&mut self, input: TokenStream) -> TokenStream {
59        let iter = &mut input.into_iter().peekable();
60        let mut result = TokenStream::new();
61
62        while let Some(tt) = iter.next() {
63            match tt {
64                TokenTree::Group(group)
65                    if group.delimiter() == Delimiter::Parenthesis =>
66                {
67                    let grouped = self.ext_proc_it(group.stream());
68                    result.extend([
69                        Group::new(group.delimiter(), grouped).into(),
70                    ] as [TokenTree; 1]);
71                },
72                TokenTree::Group(group) => {
73                    let grouped = self.proc_it(group.stream());
74                    result.extend([
75                        Group::new(group.delimiter(), grouped).into(),
76                    ] as [TokenTree; 1]);
77                },
78                TokenTree::Ident(ref ident)
79                    if ident.to_string() == self.catch_it =>
80                {
81                    result.extend([self.it.get_or_insert(tt).clone()]);
82                },
83                TokenTree::Ident(_) | TokenTree::Literal(_) => {
84                    result.extend([tt]);
85                },
86                TokenTree::Punct(ref punct)
87                    if matches!(punct.as_char(), ',' | ';') =>
88                {
89                    result.extend([tt]);
90                    result.extend(self.ext_proc_it(iter.collect()));
91                },
92                TokenTree::Punct(ref punct)
93                    if punct.as_char() == '='
94                        && punct.spacing() == Joint
95                        && iter.peek().is_some_and(|p| {
96                            matches!(p, TokenTree::Punct(p)
97                                if p.as_char() == '>')
98                        }) =>
99                {
100                    result.extend([tt, iter.next().unwrap()]);
101                    result.extend(self.ext_proc_it(iter.collect()));
102                },
103                TokenTree::Punct(_) => {
104                    result.extend([tt]);
105                },
106            }
107        }
108
109        result
110    }
111}
112
113fn get_catch_it<F>(attr: TokenStream, f: F) -> TokenStream
114where F: FnOnce(&str) -> TokenStream,
115{
116    let iter = &mut attr.into_iter();
117    let catch_it = match iter.next() {
118        Some(TokenTree::Ident(ident)) => &*ident.to_string(),
119        Some(attr) => return err("invalid input", attr.span()),
120        _ => "it",
121    };
122    if let Some(extra) = iter.next() {
123        return err("invalid input", extra.span());
124    }
125    f(catch_it)
126}
127
128/// Replace `it` to closure body, expand the closure after `,` `;` `=>` and `(`
129///
130/// # Examples
131/// ```
132/// #[closure_it::closure_it]
133/// fn main() {
134///     assert_eq!([0i32, 1, 2].map(it+2), [2, 3, 4]);
135///     assert_eq!([0i32, -1, 2].map(it.abs()), [0, 1, 2]);
136///     assert_eq!(Some(2).map_or(3, it*2), 4);
137/// }
138/// ```
139///
140/// ```
141/// #[closure_it::closure_it(this)]
142/// fn main() {
143///     assert_eq!([0i32, 1, 2].map(this+2), [2, 3, 4]);
144///     assert_eq!([0i32, -1, 2].map(this.abs()), [0, 1, 2]);
145///     assert_eq!(Some(2).map_or(3, this*2), 4);
146/// }
147/// ```
148#[proc_macro_attribute]
149pub fn closure_it(attr: TokenStream, item: TokenStream) -> TokenStream {
150    get_catch_it(attr, |catch_it| {
151        Closure { catch_it, ..Default::default() }
152            .ext_proc_it(item)
153    })
154}