char_classes_proc_macro/
lib.rs

1use std::{convert::identity, iter::{once, Peekable}};
2
3use litrs::{ByteStringLit, StringLit};
4use proc_macro::{Delimiter, Group, Ident, Punct, Spacing::*, Span, TokenStream, TokenTree, Literal};
5
/// Build a `::core::compile_error! { "<msg>" }` invocation.
///
/// Every emitted token, including the message literal inside the braces,
/// carries `span`.
#[must_use]
fn err(msg: &str, span: Span) -> TokenStream {
    let mut message = Literal::string(msg);
    message.set_span(span);

    let tokens: [TokenTree; 8] = [
        Punct::new(':', Joint).into(),
        Punct::new(':', Joint).into(),
        Ident::new("core", span).into(),
        Punct::new(':', Joint).into(),
        Punct::new(':', Joint).into(),
        Ident::new("compile_error", span).into(),
        Punct::new('!', Joint).into(),
        Group::new(Delimiter::Brace, TokenTree::from(message).into()).into(),
    ];

    // Punctuation and the group are created with the default span; re-span them all
    tokens
        .into_iter()
        .map(|mut token| {
            token.set_span(span);
            token
        })
        .collect()
}
27
/// Wrap `stream` as the argument list of `::core::matches!( ... )`
fn matches(stream: TokenStream) -> TokenStream {
    let site = Span::call_site();
    let colon = || TokenTree::from(Punct::new(':', Joint));

    let tokens = [
        colon(),
        colon(),
        Ident::new("core", site).into(),
        colon(),
        colon(),
        Ident::new("matches", site).into(),
        Punct::new('!', Joint).into(),
        Group::new(Delimiter::Parenthesis, stream).into(),
    ];
    tokens.into_iter().collect()
}
40
/// Lift a single token (anything convertible to a `TokenTree`) into a `TokenStream`
fn tts(tt: impl Into<TokenTree>) -> TokenStream {
    Into::<TokenTree>::into(tt).into()
}
45
/// Collect a sequence of token trees into one `TokenStream`
fn stream(i: impl IntoIterator<Item = TokenTree>) -> TokenStream {
    TokenStream::from_iter(i)
}
/// Concatenate a sequence of token streams into one `TokenStream`
fn streams(i: impl IntoIterator<Item = TokenStream>) -> TokenStream {
    TokenStream::from_iter(i)
}
52
/// How the macro input is interpreted, selected by an optional leading punctuation token
enum Mode {
    /// No prefix: the expansion goes through `::core::matches!`
    Normal,
    /// `^` prefix: the expansion is a `match` yielding `false` for `None` or a listed element
    Exclude,
    /// `!` prefix: holds the `!` token, which is emitted in front of the `Normal` expansion
    Not(TokenTree),
    /// `@` prefix: only the pattern is emitted
    Pattern,
}
59
60use Mode::*;
61
62impl Mode {
63    fn resolve(iter: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Self {
64        match iter.peek() {
65            Some(TokenTree::Punct(p)) if p.as_char() == '!' => {
66                Not(iter.next().unwrap())
67            },
68            Some(TokenTree::Punct(p)) if p.as_char() == '^' => {
69                iter.next().unwrap();
70                Exclude
71            },
72            Some(TokenTree::Punct(p)) if p.as_char() == '@' => {
73                iter.next().unwrap();
74                Pattern
75            },
76            _ => Normal,
77        }
78    }
79
80    fn run(
81        self,
82        expr: TokenStream,
83        pat: TokenStream,
84        com: TokenTree,
85    ) -> TokenStream {
86        match self {
87            Pattern => unimplemented!(),
88            Normal => matches(streams([expr, tts(com), pat])),
89            Not(not) => once(not)
90                .chain(Normal.run(expr, pat, com))
91                .collect(),
92            Exclude => stream([
93                Ident::new("match", com.span()).into(),
94                Group::new(Delimiter::None, expr).into(),
95                Group::new(
96                    Delimiter::Brace,
97                    streams([
98                        none(),
99                        tts(Punct::new('|', Alone)),
100                        pat,
101                        stream([
102                            Punct::new('=', Joint).into(),
103                            Punct::new('>', Alone).into(),
104                            Ident::new("false", Span::call_site()).into(),
105                            Punct::new(',', Alone).into(),
106                        ]),
107                        some(tts(Ident::new("_", Span::call_site()))),
108                        stream([
109                            Punct::new('=', Joint).into(),
110                            Punct::new('>', Alone).into(),
111                            Ident::new("true", Span::call_site()).into(),
112                            Punct::new(',', Alone).into(),
113                        ]),
114                    ]),
115                ).into(),
116            ]),
117        }
118    }
119}
120
/// Emit `::char_classes::FirstElem::first_elem(&<stream>)`
fn first_elem(stream: TokenStream) -> TokenStream {
    let site = Span::call_site();

    // Argument list: the caller's tokens behind a `&`
    let mut arg = TokenStream::from(TokenTree::Punct(Punct::new('&', Joint)));
    arg.extend(stream);

    let mut call: Vec<TokenTree> = Vec::with_capacity(10);
    for segment in ["char_classes", "FirstElem", "first_elem"] {
        call.push(Punct::new(':', Joint).into());
        call.push(Punct::new(':', Joint).into());
        call.push(Ident::new(segment, site).into());
    }
    call.push(Group::new(Delimiter::Parenthesis, arg).into());

    call.into_iter().collect()
}
139
/// Decoded value of the class literal
enum Str {
    /// Value of a string literal (`"..."`)
    Norm(String),
    /// Value of a byte string literal (`b"..."`)
    Byte(Vec<u8>),
}
144
145fn lit_str(tt: &TokenTree) -> Result<Str, TokenStream> {
146    StringLit::try_from(tt)
147        .map(|s| Str::Norm(s.into_value()))
148        .map_err(|e| e.to_string())
149        .or_else(|e| ByteStringLit::try_from(tt)
150            .map(|b| Str::Byte(b.into_value()))
151            .map_err(|e2| format!("{e}\n{e2}")))
152        .map_err(|e| err(&e, tt.span()))
153}
154
/// By-value counterpart of `set_span`: takes the token, applies `span`, returns it
trait Spaned {
    fn spaned(self, span: Span) -> Self;
}

/// Implement `Spaned` for token types that provide a `set_span` method
macro_rules! impl_spaned {
    ($($token:ty),+ $(,)?) => {$(
        impl Spaned for $token {
            fn spaned(mut self, span: Span) -> Self {
                self.set_span(span);
                self
            }
        }
    )+};
}

impl_spaned!(TokenTree, Literal, Punct);
176
/// Conversion of a class element, or a pair of elements, into pattern tokens
trait ToPat: Sized {
    /// Emit the pattern tokens for `self`, with literals spanned at `span`
    fn to_pat(self, span: Span) -> TokenStream;
}
impl ToPat for u8 {
    /// A byte literal such as `b'a'`
    fn to_pat(self, span: Span) -> TokenStream {
        let mut lit = Literal::byte_character(self);
        lit.set_span(span);
        TokenStream::from(TokenTree::Literal(lit))
    }
}
impl ToPat for char {
    /// A character literal such as `'a'`
    fn to_pat(self, span: Span) -> TokenStream {
        let mut lit = Literal::character(self);
        lit.set_span(span);
        TokenStream::from(TokenTree::Literal(lit))
    }
}
impl<T: ToPat> ToPat for (T, T) {
    /// An inclusive range `from..=to`
    fn to_pat(self, span: Span) -> TokenStream {
        let (lo, hi) = self;

        let mut range = lo.to_pat(span);
        range.extend(['.', '.', '='].map(|op| TokenTree::from(Punct::new(op, Joint))));
        range.extend(hi.to_pat(span));
        range
    }
}
204
/// Recognises the ASCII hyphen-minus that separates the two ends of a range
trait IsDash {
    /// `true` when the value is `-`
    fn is_dash(&self) -> bool;
}
impl IsDash for u8 {
    fn is_dash(&self) -> bool {
        matches!(*self, b'-')
    }
}
impl IsDash for char {
    fn is_dash(&self) -> bool {
        matches!(*self, '-')
    }
}
218
219trait Expected: Iterator<Item = TokenTree> + Sized {
220    fn expected(&mut self, ty: &str) -> Result<TokenTree, TokenStream> {
221        self.next()
222            .ok_or_else(||
223        {
224            let msg = format!("unexpected end of input, expected a {ty}");
225            err(&msg, Span::call_site())
226        })
227    }
228}
229impl<T: Iterator<Item = TokenTree>> Expected for T { }
230
/// Emit the path `::core::option::Option::None`
fn none() -> TokenStream {
    let site = Span::call_site();

    // Each segment is emitted as `::<segment>`
    ["core", "option", "Option", "None"]
        .into_iter()
        .flat_map(|segment| [
            TokenTree::from(Punct::new(':', Joint)),
            TokenTree::from(Punct::new(':', Joint)),
            TokenTree::from(Ident::new(segment, site)),
        ])
        .collect()
}
247
/// Emit `::core::option::Option::Some(<input>)`
fn some(input: TokenStream) -> TokenStream {
    let site = Span::call_site();

    // Each segment is emitted as `::<segment>`, then the parenthesised payload
    ["core", "option", "Option", "Some"]
        .into_iter()
        .flat_map(|segment| [
            TokenTree::from(Punct::new(':', Joint)),
            TokenTree::from(Punct::new(':', Joint)),
            TokenTree::from(Ident::new(segment, site)),
        ])
        .chain([TokenTree::from(Group::new(Delimiter::Parenthesis, input))])
        .collect()
}
265
266fn to_pats<T, I>(iter: I, span: Span) -> Result<TokenStream, TokenStream>
267where T: ToPat + IsDash,
268      I: IntoIterator<Item = T>,
269{
270    let mut iter = iter.into_iter().peekable();
271    let Some(mut first) = iter.next() else {
272        return Err(err("not support empty pattern", span));
273    };
274    let mut result = TokenStream::new();
275    let mut sep: fn(&mut TokenStream) = |_| ();
276
277    while let Some(cur) = iter.next() {
278        sep(&mut result);
279
280        if let Some(to) = iter.next_if(|_| cur.is_dash()) {
281            result.extend([(first, to).to_pat(span)]);
282
283            if let Some(next) = iter.next() {
284                first = next;
285            } else {
286                return Ok(some(result));
287            }
288        } else {
289            result.extend([first.to_pat(span)]);
290            first = cur;
291        }
292
293        sep = |result| {
294            result.extend([TokenTree::from(Punct::new('|', Alone))]);
295        };
296    }
297
298    sep(&mut result);
299    result.extend([first.to_pat(span)]);
300    Ok(some(result))
301}
302
303/// Like `char_classes::any()`, expand into `match` for better performance (about 5x)
304///
305/// - `^"..."` is exclude pattern
306/// - `!"..."` like `!any!(...)`
307/// - `@"..."` expand to pattern only, can used for `match`
308///
309/// # Examples
310///
311/// ```ignore
312/// use char_classes::any;
313///
314/// assert!(any!("ab",      'a'));
315/// assert!(any!("ab",      'b'));
316/// assert!(any!("ab",      'b'));
317/// assert!(any!("a-c",     'a'));
318/// assert!(any!("a-c",     'b'));
319/// assert!(any!("a-c",     'c'));
320/// assert!(any!(b"ab",    b'a'));
321/// assert!(any!(b"ab",    b'b'));
322///
323/// assert!(! any!(^b"ab",   b'b'));
324/// assert!(! any!(^"ab",   ""));
325/// assert!(any!(!"ab",   ""));
326///
327/// assert!(any!(b"ab")(b'b'));
328/// ```
329///
330/// **predicate mode**:
331///
332/// ```ignore
333/// use char_classes::any;
334///
335/// assert!(any!(b"ab")(b"b"));
336/// assert!(any!(!b"ab")(b"c"));
337/// assert!(any!(^b"ab")(b"c"));
338///
339/// assert!(any!(!b"ab")(b""));
340/// assert!(! any!(^b"ab")(b""));
341/// ```
342///
343/// **pattern mode**:
344///
345/// ```ignore
346/// use char_classes::any;
347///
348/// match 'x' {
349///     any!(@"a-z") => (),
350///     _ => panic!(),
351/// }
352/// assert!(matches!('c', any!(@"a-z")));
353/// ```
354#[proc_macro]
355pub fn any(input: TokenStream) -> TokenStream {
356    any_impl(input).unwrap_or_else(identity)
357}
358
359fn any_impl(input: TokenStream) -> Result<TokenStream, TokenStream> {
360    let mut iter = input.into_iter().peekable();
361    let mode = Mode::resolve(&mut iter);
362    let first = iter.expected("literal")?;
363    let lit_str = lit_str(&first)?;
364
365    let pat = match lit_str {
366        Str::Norm(s) => to_pats(s.chars(), first.span()),
367        Str::Byte(bytes) => to_pats(bytes, first.span()),
368    }?;
369    let predicate_mode = iter.peek().is_none();
370    if let Pattern = mode {
371        if let Some(extra) = iter.peek() {
372            return Err(err("unexpected token, expect end of input", extra.span()));
373        }
374        return Ok(pat);
375    }
376
377    let com = iter.next()
378        .unwrap_or_else(|| Punct::new(',', Alone).into());
379
380    let output = if predicate_mode {
381        let name = TokenTree::from(Ident::new("input", first.span()));
382        let expr = first_elem(name.clone().into());
383
384        once(Punct::new('|', Joint).into())
385            .chain([name, Punct::new('|', Alone).into()])
386            .chain(mode.run(expr, pat, com))
387            .collect()
388    } else {
389        mode.run(first_elem(iter.collect()), pat, com)
390    };
391
392    Ok(tts(Group::new(Delimiter::None, output)))
393}