// chumsky_proc_macro/lib.rs

1use chumsky::{
2    extra::ParserExtra, input::ValueInput, label::LabelError, util::MaybeRef, IterParser, Parser,
3};
4use proc_macro2::TokenTree;
5use std::{fmt, vec};
6
7pub trait LikeTokenTree {
8    fn as_tok(&self) -> &TokenTree;
9}
10
11impl LikeTokenTree for TokenTree {
12    fn as_tok(&self) -> &TokenTree {
13        self
14    }
15}
16
17impl LikeTokenTree for (TokenTree, proc_macro2::Span) {
18    fn as_tok(&self) -> &TokenTree {
19        &self.0
20    }
21}
22
23#[derive(Clone, Debug)]
24pub struct TokenTreeWrapper(pub TokenTree);
25
26impl fmt::Display for TokenTreeWrapper {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        write!(f, "{}", self.0)
29    }
30}
31
32impl PartialEq for TokenTreeWrapper {
33    fn eq(&self, _other: &Self) -> bool {
34        false
35    }
36}
37
38impl Eq for TokenTreeWrapper {}
39
40impl LikeTokenTree for TokenTreeWrapper {
41    fn as_tok(&self) -> &TokenTree {
42        &self.0
43    }
44}
45
46#[derive(Debug)]
47#[non_exhaustive]
48pub enum Expected {
49    StandAlonePunct(char),
50    PunctSeq(String),
51    Ident,
52    ExactIdent(String),
53}
54
55impl<'a, T> From<Expected> for chumsky::error::RichPattern<'a, T> {
56    fn from(value: Expected) -> Self {
57        use chumsky::error::RichPattern;
58
59        match value {
60            Expected::StandAlonePunct(p) => RichPattern::Label(format!("'{}'", p).into()),
61            Expected::PunctSeq(p) => RichPattern::Label(format!("'{}'", p).into()),
62            Expected::Ident => RichPattern::Label("identifier".into()),
63            Expected::ExactIdent(p) => RichPattern::Identifier(p),
64        }
65    }
66}
67
68/// A single punctuation character like `+`, `-` or `#`.
69pub fn punct<'src, I, E>(val: char) -> impl Parser<'src, I, (), E> + Clone
70where
71    I: ValueInput<'src>,
72    I::Token: LikeTokenTree + 'src,
73    E: ParserExtra<'src, I>,
74    E::Error: LabelError<'src, I, Expected>,
75{
76    punct_impl(val)
77}
78
79fn punct_impl<'src, I, E>(val: char) -> impl Parser<'src, I, (), E> + Clone
80where
81    I: ValueInput<'src>,
82    I::Token: LikeTokenTree + 'src,
83    E: ParserExtra<'src, I>,
84    E::Error: LabelError<'src, I, Expected>,
85{
86    use chumsky::prelude::*;
87    any().try_map(move |x: I::Token, span| {
88        match &x.as_tok() {
89            TokenTree::Punct(p) => {
90                // disabled spacing check because proc macro sucks
91                if
92                /*p.spacing() == spacing &&*/
93                p.as_char() == val {
94                    Some(())
95                } else {
96                    None
97                }
98            }
99
100            _ => None,
101        }
102        .ok_or_else(|| {
103            LabelError::expected_found(
104                [Expected::StandAlonePunct(val)],
105                Some(MaybeRef::Val(x)),
106                span,
107            )
108        })
109    })
110}
111
112/// A sequence of punctuation character like `+=`, '--', or even `#<##>`.
113pub fn punct_seq<'src, I, E, S: AsRef<str>>(seq: S) -> impl Parser<'src, I, (), E> + Clone
114where
115    I: ValueInput<'src>,
116    I::Token: LikeTokenTree + 'src,
117    E: ParserExtra<'src, I> + 'src,
118    E::Error: LabelError<'src, I, Expected>,
119{
120    use chumsky::prelude::*;
121
122    let seq = seq.as_ref().to_string();
123    assert!(!seq.len() >= 2);
124
125    seq.chars()
126        .map(|val| punct_impl(val).boxed())
127        .reduce(|a, b| a.then_ignore(b).boxed())
128        .unwrap_or(empty().boxed())
129        .map_err_with_state(move |_, span, _| {
130            LabelError::expected_found([Expected::PunctSeq(seq.clone())], None, span)
131        })
132}
133
134pub fn ident<'src, I, E>() -> impl Parser<'src, I, String, E> + Clone
135where
136    I: ValueInput<'src>,
137    I::Token: LikeTokenTree + 'src,
138    E: ParserExtra<'src, I>,
139    E::Error: LabelError<'src, I, Expected>,
140{
141    use chumsky::prelude::*;
142
143    any().try_map(move |x: I::Token, span| {
144        match &x.as_tok() {
145            TokenTree::Ident(i) => Some(i.to_string()),
146            _ => None,
147        }
148        .ok_or_else(|| LabelError::expected_found([Expected::Ident], Some(MaybeRef::Val(x)), span))
149    })
150}
151
152pub fn exact_ident<'src, I, E, S>(exact: S) -> impl Parser<'src, I, (), E> + Clone
153where
154    S: AsRef<str> + Clone,
155    I: ValueInput<'src>,
156    I::Token: LikeTokenTree + 'src,
157    E: ParserExtra<'src, I>,
158    E::Error: LabelError<'src, I, Expected>,
159{
160    use chumsky::prelude::*;
161
162    any().try_map(move |x: I::Token, span| {
163        match &x.as_tok() {
164            TokenTree::Ident(i) if i.to_string().as_str() == exact.as_ref() => Some(()),
165            _ => None,
166        }
167        .ok_or_else(|| {
168            LabelError::expected_found(
169                [Expected::ExactIdent(exact.as_ref().to_string())],
170                Some(MaybeRef::Val(x)),
171                span,
172            )
173        })
174    })
175}
176
177pub fn namespace_with_ident<'src, I, E>() -> impl IterParser<'src, I, String, E> + Clone
178where
179    I: ValueInput<'src>,
180    I::Token: LikeTokenTree + 'src,
181    E: ParserExtra<'src, I> + 'src,
182    E::Error: LabelError<'src, I, Expected>,
183{
184    use chumsky::prelude::*;
185
186    ident().separated_by(punct_seq("::")).at_least(1)
187}
188
189/// A literal character (`'a'`), string (`"hello"`), number (`2.3`), etc.
190pub fn literal<'src, I, E>() -> impl Parser<'src, I, String, E> + Clone
191where
192    I: ValueInput<'src>,
193    I::Token: LikeTokenTree + 'src,
194    E: ParserExtra<'src, I>,
195    E::Error: LabelError<'src, I, Expected>,
196{
197    use chumsky::prelude::*;
198
199    any().try_map(move |x: I::Token, span| {
200        match &x.as_tok() {
201            TokenTree::Literal(i) => Some(i.to_string()),
202            _ => None,
203        }
204        .ok_or_else(|| LabelError::expected_found([Expected::Ident], Some(MaybeRef::Val(x)), span))
205    })
206}
207
/// Delimiter of a token group, as accepted by the grouped-parser extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GroupDelim {
    /// `( ... )`
    Parenthesis,
    /// `{ ... }`
    Brace,
    /// `[ ... ]`
    Bracket,
}
216
217impl GroupDelim {
218    fn to_procmacro(&self) -> proc_macro2::Delimiter {
219        match self {
220            GroupDelim::Brace => proc_macro2::Delimiter::Brace,
221            GroupDelim::Bracket => proc_macro2::Delimiter::Bracket,
222            GroupDelim::Parenthesis => proc_macro2::Delimiter::Parenthesis,
223        }
224    }
225}
226
227pub trait GroupExtension<P, PE> {
228    fn grouped(self, delim: GroupDelim) -> P;
229}
230
231impl<'wholesrc, 'partsrc, 'b, WI, V, WE, PP, PE>
232    GroupExtension<chumsky::Boxed<'wholesrc, 'b, WI, V, WE>, PE> for PP
233where
234    WI: ValueInput<'wholesrc> + 'b,
235    WI::Token: LikeTokenTree + 'wholesrc + 'b,
236    WE: ParserExtra<'wholesrc, WI> + 'b,
237    WE::Error: LabelError<'wholesrc, WI, Expected>,
238    PP: Parser<'partsrc, chumsky::input::Stream<vec::IntoIter<TokenTreeWrapper>>, V, PE>
239        + 'b + 'wholesrc,
240    PE: ParserExtra<'partsrc, chumsky::input::Stream<vec::IntoIter<TokenTreeWrapper>>>,
241    PE::Context: Default,
242    PE::State: Default,
243    PE::Error:
244        LabelError<'partsrc, chumsky::input::Stream<vec::IntoIter<TokenTreeWrapper>>, Expected>,
245    WE::Error: From<PE::Error>,
246{
247    fn grouped(self, delim: GroupDelim) -> chumsky::Boxed<'wholesrc, 'b, WI, V, WE> {
248        use chumsky::prelude::*;
249        any()
250            .try_map(move |x: WI::Token, span: WI::Span| match &x.as_tok() {
251                TokenTree::Group(i) if i.delimiter() == delim.to_procmacro() => self
252                    .parse(chumsky::input::Stream::from_iter(
253                        i.stream()
254                            .into_iter()
255                            .map(|x| TokenTreeWrapper(x))
256                            .collect::<Vec<_>>()
257                            .into_iter(),
258                    ))
259                    .into_result()
260                    .map_err(|x| x.into_iter().reduce(|a, b| a.merge(b)).unwrap().into()),
261                _ => Err(LabelError::expected_found(
262                    [Expected::Ident],
263                    Some(MaybeRef::Val(x)),
264                    span,
265                )),
266            })
267            .boxed()
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use chumsky::{extra::Err, input::Stream, prelude::*};
    use quote::quote;

    #[test]
    fn test_punct() {
        let toks = quote! { + }.into_iter();

        let parser = &punct::<_, Err<Simple<_>>>('+');
        parser.parse(Stream::from_iter(toks)).into_result().unwrap();
    }

    /// `+=` is two punctuation tokens; a lone `punct('+')` leaves `=` unparsed.
    #[test]
    fn test_punct2() {
        let toks = quote! { += }.into_iter();

        let parser = &punct::<_, Err<Simple<_>>>('+');
        parser
            .parse(Stream::from_iter(toks))
            .into_result()
            .unwrap_err();
    }

    #[test]
    fn test_punct_seq() {
        let toks = quote! { --> }.into_iter();

        let parser = &punct_seq::<_, Err<Simple<_>>, _>("-->");
        parser.parse(Stream::from_iter(toks)).into_result().unwrap();
    }

    #[test]
    fn test_punct_seq2() {
        let toks = quote! { -> }.into_iter();

        let parser = &punct_seq::<_, Err<Simple<_>>, _>("-->");
        parser
            .parse(Stream::from_iter(toks))
            .into_result()
            .unwrap_err();
    }

    #[test]
    fn test_punct_seq3() {
        let toks = quote! { --># }.into_iter();

        let parser = &punct_seq::<_, Err<Simple<_>>, _>("-->");
        parser
            .parse(Stream::from_iter(toks))
            .into_result()
            .unwrap_err();
    }

    #[test]
    fn test_punct_seq4() {
        let toks = quote! { ---> }.into_iter();

        let parser = &punct_seq::<_, Err<Simple<_>>, _>("-->");
        parser
            .parse(Stream::from_iter(toks))
            .into_result()
            .unwrap_err();
    }

    #[test]
    fn test_punct_seq5() {
        let toks = quote! { -> }.into_iter();

        let parser = &punct_seq::<_, Err<Simple<_>>, _>("->");
        parser.parse(Stream::from_iter(toks)).into_result().unwrap();
    }

    #[test]
    fn test_punct_seq6() {
        let toks = quote! { ->< }.into_iter();

        let parser = &punct_seq::<_, Err<Simple<_>>, _>("->").then(punct('<'));
        parser.parse(Stream::from_iter(toks)).into_result().unwrap();
    }

    #[test]
    fn test_punct_seq7() {
        let toks = quote! { ---> }.into_iter();

        let parser = &punct_seq::<_, Err<Simple<_>>, _>("--->");
        parser.parse(Stream::from_iter(toks)).into_result().unwrap();
    }

    #[test]
    fn test_punct_seq8() {
        let toks = quote! { ++++ }.into_iter();

        let parser = &punct_seq::<_, Err<Simple<_>>, _>("++++");
        parser.parse(Stream::from_iter(toks)).into_result().unwrap();
    }

    #[test]
    fn test_ident0() {
        let toks = quote! { hello_world }.into_iter();

        let parser = &ident::<_, Err<Simple<_>>>();
        let v = parser.parse(Stream::from_iter(toks)).into_result().unwrap();
        assert_eq!(v, "hello_world");
    }

    #[test]
    fn test_ident1() {
        let toks = quote! { 1 2 3 }.into_iter();

        let parser = &ident::<_, Err<Simple<_>>>();
        parser
            .parse(Stream::from_iter(toks))
            .into_result()
            .unwrap_err();
    }

    #[test]
    fn test_exact_ident0() {
        let toks = quote! { hey }.into_iter();

        let parser = &exact_ident::<_, Err<Simple<_>>, _>("hey");
        parser.parse(Stream::from_iter(toks)).into_result().unwrap();
    }

    #[test]
    fn test_exact_ident1() {
        let toks = quote! { heey }.into_iter();

        let parser = &exact_ident::<_, Err<Simple<_>>, _>("hey");
        parser
            .parse(Stream::from_iter(toks))
            .into_result()
            .unwrap_err();
    }

    #[test]
    fn test_namespace_ident() {
        let toks = quote! { hello::world::test }.into_iter();

        let parser = &namespace_with_ident::<_, Err<Simple<_>>>().collect::<Vec<_>>();
        let v = parser.parse(Stream::from_iter(toks)).into_result().unwrap();
        assert_eq!(v, ["hello", "world", "test"]);
    }

    #[test]
    fn test_literal() {
        let toks = quote! { "hey" }.into_iter();

        let parser = &literal::<_, Err<Simple<_>>>();
        let v = parser.parse(Stream::from_iter(toks)).into_result().unwrap();
        // The literal is returned as written, quotes included.
        assert_eq!(v, "\"hey\"");
    }

    #[test]
    fn test_group() {
        let toks = quote! { hello (world::x) }.into_iter();

        let parser = &namespace_with_ident::<_, Err<Rich<_>>>()
            .collect::<Vec<_>>()
            .then(
                namespace_with_ident::<_, Err<Rich<_>>>()
                    .collect::<Vec<_>>()
                    .grouped(GroupDelim::Parenthesis),
            );
        let (outer, inner) = parser
            .parse(Stream::from_iter(toks.map(TokenTreeWrapper)))
            .into_result()
            .unwrap();
        assert_eq!(outer, ["hello"]);
        assert_eq!(inner, ["world", "x"]);
    }

    /// A bracketed group must not satisfy a parenthesis-grouped parser.
    #[test]
    fn test_group_wrong_delim() {
        let toks = quote! { hello [world] }.into_iter();

        let parser = &namespace_with_ident::<_, Err<Rich<_>>>()
            .collect::<Vec<_>>()
            .then(
                namespace_with_ident::<_, Err<Rich<_>>>()
                    .collect::<Vec<_>>()
                    .grouped(GroupDelim::Parenthesis),
            );
        parser
            .parse(Stream::from_iter(toks.map(TokenTreeWrapper)))
            .into_result()
            .unwrap_err();
    }
}