Skip to main content

closure_it/
lib.rs

1#![doc = include_str!("../README.md")]
2#![allow(clippy::needless_doctest_main)]
3
4use proc_macro::{
5    Delimiter, Group, Ident, Literal, Punct, Spacing::*, Span, TokenStream,
6    TokenTree,
7};
8
/// Returns a function that re-assigns the span of any token tree it is
/// given to `span` and hands the tree back.
fn span_setter(span: Span) -> impl Fn(TokenTree) -> TokenTree {
    move |tree: TokenTree| {
        let mut respanned = tree;
        respanned.set_span(span);
        respanned
    }
}
15
/// Concatenates the given token trees into one [`TokenStream`].
#[must_use]
fn stream<I>(iter: I) -> TokenStream
where I: IntoIterator<Item = TokenTree>,
{
    iter.into_iter().collect::<TokenStream>()
}
22
23fn err(msg: &str, span: Span) -> TokenStream {
24    let s = span_setter(span);
25    stream([
26        s(Punct::new(':', Joint).into()),
27        s(Punct::new(':', Joint).into()),
28        s(Ident::new("core", span).into()),
29        s(Punct::new(':', Joint).into()),
30        s(Punct::new(':', Joint).into()),
31        s(Ident::new("compile_error", span).into()),
32        s(Punct::new('!', Joint).into()),
33        s(Group::new(Delimiter::Brace, stream([
34            s(Literal::string(msg).into()),
35        ])).into()),
36    ])
37}
38
39#[derive(Default, Clone)]
40struct Closure<'a> {
41    it: Option<TokenTree>,
42    catch_it: &'a str,
43}
44impl Closure<'_> {
45    fn make_closure(&mut self) -> TokenStream {
46        let Some(it) = self.it.take() else {
47            return TokenStream::new();
48        };
49        let s = span_setter(it.span());
50        stream([
51            s(Punct::new('|', Joint).into()),
52            it,
53            s(Punct::new('|', Joint).into()),
54        ])
55    }
56
57    fn ext_proc_it(&self, input: TokenStream) -> TokenStream {
58        let ext = &mut Self { catch_it: self.catch_it, ..Default::default() };
59        let proc_it = ext.proc_it(input);
60        ext.make_closure().into_iter().chain(proc_it).collect()
61    }
62
63    fn proc_it(&mut self, input: TokenStream) -> TokenStream {
64        let iter = &mut input.into_iter().peekable();
65        let mut result = TokenStream::new();
66
67        while let Some(tt) = iter.next() {
68            let s = span_setter(tt.span());
69            match tt {
70                TokenTree::Group(group)
71                    if group.delimiter() == Delimiter::Parenthesis =>
72                {
73                    let grouped = self.ext_proc_it(group.stream());
74                    result.extend([
75                        s(Group::new(group.delimiter(), grouped).into()),
76                    ]);
77                },
78                TokenTree::Group(group) => {
79                    let grouped = self.proc_it(group.stream());
80                    result.extend([
81                        s(Group::new(group.delimiter(), grouped).into()),
82                    ]);
83                },
84                TokenTree::Ident(ref ident)
85                    if ident.to_string() == self.catch_it =>
86                {
87                    self.it.get_or_insert_with(|| tt.clone());
88                    result.extend([tt]);
89                },
90                TokenTree::Ident(_) | TokenTree::Literal(_) => {
91                    result.extend([tt]);
92                },
93                TokenTree::Punct(ref punct)
94                    if matches!(punct.as_char(), ',' | ';') =>
95                {
96                    result.extend([tt]);
97                    result.extend(self.ext_proc_it(iter.collect()));
98                },
99                TokenTree::Punct(ref punct)
100                    if punct.as_char() == '='
101                        && punct.spacing() == Joint
102                        && iter.peek().is_some_and(|p| {
103                            matches!(p, TokenTree::Punct(p)
104                                if p.as_char() == '>')
105                        }) =>
106                {
107                    result.extend([tt, iter.next().unwrap()]);
108                    result.extend(self.ext_proc_it(iter.collect()));
109                },
110                TokenTree::Punct(_) => {
111                    result.extend([tt]);
112                },
113            }
114        }
115
116        result
117    }
118}
119
120fn get_catch_it<F>(attr: TokenStream, f: F) -> TokenStream
121where F: FnOnce(&str) -> TokenStream,
122{
123    let iter = &mut attr.into_iter();
124    let catch_it = match iter.next() {
125        Some(TokenTree::Ident(ident)) => &*ident.to_string(),
126        Some(attr) => return err("invalid input", attr.span()),
127        _ => "it",
128    };
129    if let Some(extra) = iter.next() {
130        return err("invalid input", extra.span());
131    }
132    f(catch_it)
133}
134
135/// Replace `it` to closure body,
136/// expand the closure parameter after `,` `;` `=>` and `(`
137///
138/// # Examples
139/// ```
140/// #[closure_it::closure_it]
141/// fn main() {
142///     assert_eq!([0i32, 1, 2].map(it+2), [2, 3, 4]);
143///     assert_eq!([0i32, -1, 2].map(it.abs()), [0, 1, 2]);
144///     assert_eq!(Some(2).map_or(3, it*2), 4);
145/// }
146/// ```
147///
148/// ```
149/// #[closure_it::closure_it(this)]
150/// fn main() {
151///     assert_eq!([0i32, 1, 2].map(this+2), [2, 3, 4]);
152///     assert_eq!([0i32, -1, 2].map(this.abs()), [0, 1, 2]);
153///     assert_eq!(Some(2).map_or(3, this*2), 4);
154/// }
155/// ```
156#[proc_macro_attribute]
157pub fn closure_it(attr: TokenStream, item: TokenStream) -> TokenStream {
158    get_catch_it(attr, |catch_it| {
159        Closure { catch_it, ..Default::default() }
160            .ext_proc_it(item)
161    })
162}