proc_macro_tool/
func_utils.rs

1use proc_macro::{
2    Delimiter, Group, Ident, Literal, Punct, Spacing::*, Span, TokenStream,
3    TokenTree,
4};
5
6use crate::{
7    GetSpan, ParseIter, ParseIterExt as _, SetSpan, TokenStreamExt as _,
8    TokenTreeExt as _,
9};
10
/// Collect token trees into a [`TokenStream`].
///
/// Accepts any [`IntoIterator<Item = TokenTree>`](IntoIterator): an array,
/// a `Vec`, or an iterator adapter chain.
#[must_use]
pub fn stream<I>(iter: I) -> TokenStream
where I: IntoIterator<Item = TokenTree>,
{
    iter.into_iter().collect()
}
19
/// Concatenate several [`TokenStream`]s into one, in iteration order.
///
/// Accepts any [`IntoIterator<Item = TokenStream>`](IntoIterator).
#[must_use]
pub fn streams<I>(iter: I) -> TokenStream
where I: IntoIterator<Item = TokenStream>,
{
    iter.into_iter().collect()
}
28
29fn pfunc_predicate<I>(names: &[&str], p: &Punct, iter: &mut ParseIter<I>) -> bool
30where I: Iterator<Item = TokenTree>,
31{
32    p.as_char() == '#'
33        && iter.peek_is(|i| i.as_ident()
34            .is_some_and(|i| names.contains(&&*i.to_string())))
35        && iter.peek_i_is(1, |t| t.is_solid_group())
36}
37
38fn subtree_contain_pfunc(names: &[&str], stream: impl IntoIterator<Item = TokenTree>) -> bool {
39    let iter = &mut stream.parse_iter();
40
41    while let Some(tt) = iter.next() {
42        match tt {
43            TokenTree::Punct(p) if pfunc_predicate(names, &p, iter) => {
44                return true;
45            },
46            TokenTree::Group(g)
47                if subtree_contain_pfunc(names, g.stream()) =>
48            {
49                return true;
50            },
51            _ => {},
52        }
53    }
54
55    false
56}
57
58fn pfunc_impl<F, R>(
59    stream: impl IntoIterator<Item = TokenTree>,
60    proc_input: bool,
61    names: &[&str],
62    lossless: bool,
63    f: &mut F,
64) -> Result<TokenStream, R>
65where F: FnMut(Ident, Group) -> Result<TokenStream, R>,
66{
67    let iter = &mut stream.parse_iter();
68    let mut result = TokenStream::new();
69
70    while let Some(tt) = iter.next() {
71        match tt {
72            TokenTree::Punct(p) if pfunc_predicate(names, &p, iter) => {
73                let ident = iter.next().unwrap().into_ident().unwrap();
74                let mut group = iter.next().unwrap().into_group().unwrap();
75                if proc_input {
76                    let sub = pfunc_impl(
77                        group.stream(),
78                        proc_input,
79                        names,
80                        lossless,
81                        f,
82                    )?;
83                    group = sub
84                        .grouped(group.delimiter())
85                        .set_spaned(group.span());
86                }
87                result.add(f(ident, group)?);
88            },
89            TokenTree::Group(g)
90                if !lossless || subtree_contain_pfunc(names, g.stream()) =>
91            {
92                let sub = pfunc_impl(
93                    g.stream(),
94                    proc_input,
95                    names,
96                    lossless,
97                    f,
98                )?;
99                result.push(sub
100                    .grouped(g.delimiter())
101                    .set_spaned(g.span())
102                    .into());
103            },
104            _ => _ = result.push(tt),
105        }
106    }
107
108    Ok(result)
109}
110
111/// Call `f` on `#name(...)` `#name[...]` etc, exclude [`Delimiter::None`]
112///
113/// Apply pfunc for `(...)` when `proc_input` is `true`
114#[allow(clippy::missing_panics_doc)]
115pub fn pfunc<'a>(
116    stream: impl IntoIterator<Item = TokenTree>,
117    proc_input: bool,
118    names: impl AsRef<[&'a str]>,
119    mut f: impl FnMut(Ident, Group) -> TokenStream,
120) -> TokenStream {
121    let f = &mut |i, g| {
122        Ok::<_, ()>(f(i, g))
123    };
124    pfunc_impl(stream, proc_input, names.as_ref(), false, f).unwrap()
125}
126
127/// Call `f` on `#name(...)` `#name[...]` etc, exclude [`Delimiter::None`]
128///
129/// Apply pfunc for `(...)` when `proc_input` is `true`
130pub fn try_pfunc<'a, R>(
131    stream: impl IntoIterator<Item = TokenTree>,
132    proc_input: bool,
133    names: impl AsRef<[&'a str]>,
134    mut f: impl FnMut(Ident, Group) -> Result<TokenStream, R>,
135) -> Result<TokenStream, R> {
136    pfunc_impl(stream, proc_input, names.as_ref(), false, &mut f)
137}
138
139/// Like [`pfunc`], but it's lossless when no changes are made
140#[allow(clippy::missing_panics_doc)]
141pub fn pfunc_lossless<'a>(
142    stream: impl IntoIterator<Item = TokenTree>,
143    proc_input: bool,
144    names: impl AsRef<[&'a str]>,
145    mut f: impl FnMut(Ident, Group) -> TokenStream,
146) -> TokenStream {
147    let f = &mut |i, g| {
148        Ok::<_, ()>(f(i, g))
149    };
150    pfunc_impl(stream, proc_input, names.as_ref(), true, f).unwrap()
151}
152
153/// Like [`try_pfunc`], but it's lossless when no changes are made
154pub fn try_pfunc_lossless<'a, R>(
155    stream: impl IntoIterator<Item = TokenTree>,
156    proc_input: bool,
157    names: impl AsRef<[&'a str]>,
158    mut f: impl FnMut(Ident, Group) -> Result<TokenStream, R>,
159) -> Result<TokenStream, R> {
160    pfunc_impl(stream, proc_input, names.as_ref(), true, &mut f)
161}
162
163/// Make `compile_error! {"..."}`
164#[must_use]
165pub fn err(msg: &str, span: impl GetSpan) -> TokenStream {
166    let s = span_setter(span.span());
167
168    stream([
169        Punct::new(':', Joint).into(),
170        Punct::new(':', Joint).into(),
171        Ident::new("core", span.span()).into(),
172        Punct::new(':', Joint).into(),
173        Punct::new(':', Joint).into(),
174        Ident::new("compile_error", span.span()).into(),
175        Punct::new('!', Joint).into(),
176        Group::new(Delimiter::Brace, stream([
177            Literal::string(msg).into(),
178        ].map(s))).into(),
179    ].map(s))
180}
181
182/// Like [`err()`], but use [`Result`]
183///
184/// # Errors
185/// - always return [`Err`]
186pub fn rerr<T>(msg: &str, span: impl GetSpan) -> Result<T, TokenStream> {
187    Err(err(msg, span))
188}
189
190/// Make puncts, `spacing` is last punct spacing
191///
192/// - `"+-"` like `[Joint('+'), Joint('-')]`
193/// - `"+- "` like `[Joint('+'), Alone('-')]`
194/// - `"+ -"` like `[Alone('+'), Joint('-')]`
195pub fn puncts(puncts: impl AsRef<[u8]>) -> TokenStream {
196    puncts_spanned(puncts, Span::call_site())
197}
198
199/// Make puncts, `spacing` is last punct spacing
200///
201/// Like [`puncts`], but `.set_span(span)`
202pub fn puncts_spanned(puncts: impl AsRef<[u8]>, span: Span) -> TokenStream {
203    let puncts = puncts.as_ref().trim_ascii_start();
204    let iter = &mut puncts.iter().copied().peekable();
205    let mut result = TokenStream::new();
206
207    while let Some(ch) = iter.next() {
208        debug_assert!(! ch.is_ascii_whitespace());
209        let mut s = None;
210        while iter.next_if(u8::is_ascii_whitespace).is_some() {
211            s = Some(Alone)
212        }
213        let spacing = s.or(iter.peek().map(|_| Joint))
214            .unwrap_or(Joint);
215        let p = Punct::new(ch.into(), spacing);
216        result.push(p.set_spaned(span).into());
217    }
218
219    result
220}
221
222/// Generate a function, set input `TokenTree` span
223pub fn span_setter<T>(span: Span) -> impl Fn(T) -> T + Copy
224where T: SetSpan,
225{
226    move |tt| {
227        tt.set_spaned(span)
228    }
229}
230
231/// Like [`Group::new(Delimiter::Parenthesis, iter)`](Group::new)
232pub fn paren<I>(iter: I) -> Group
233where I: IntoIterator,
234      TokenStream: FromIterator<I::Item>,
235{
236    iter.into_iter()
237        .collect::<TokenStream>()
238        .grouped_paren()
239}
240
241/// Like [`Group::new(Delimiter::Brace, iter)`](Group::new)
242pub fn brace<I>(iter: I) -> Group
243where I: IntoIterator,
244      TokenStream: FromIterator<I::Item>,
245{
246    iter.into_iter()
247        .collect::<TokenStream>()
248        .grouped_brace()
249}
250
251/// Like [`Group::new(Delimiter::Bracket, iter)`](Group::new)
252pub fn bracket<I>(iter: I) -> Group
253where I: IntoIterator,
254      TokenStream: FromIterator<I::Item>,
255{
256    iter.into_iter()
257        .collect::<TokenStream>()
258        .grouped_bracket()
259}
260
261/// Like [`Group::new(Delimiter::None, iter)`](Group::new)
262pub fn none<I>(iter: I) -> Group
263where I: IntoIterator,
264      TokenStream: FromIterator<I::Item>,
265{
266    iter.into_iter()
267        .collect::<TokenStream>()
268        .grouped_none()
269}