// char_classes_proc_macro/lib.rs

1use std::{convert::identity, iter::once};
2
3use litrs::{ByteStringLit, StringLit};
4use proc_macro::{Delimiter, Group, Ident, Punct, Spacing::*, Span, TokenStream, TokenTree, Literal};
5
/// Make `compile_error! {"..."}`
///
/// Every emitted token carries `span` so the diagnostic points at the
/// offending input.
#[must_use]
fn err(msg: &str, span: Span) -> TokenStream {
    // Re-span a token tree in place.
    let spanned = |mut tt: TokenTree| {
        tt.set_span(span);
        tt
    };

    let tokens: [TokenTree; 8] = [
        Punct::new(':', Joint).into(),
        Punct::new(':', Joint).into(),
        Ident::new("core", span).into(),
        Punct::new(':', Joint).into(),
        Punct::new(':', Joint).into(),
        Ident::new("compile_error", span).into(),
        Punct::new('!', Joint).into(),
        Group::new(
            Delimiter::Brace,
            TokenStream::from(spanned(Literal::string(msg).into())),
        ).into(),
    ];
    tokens.map(spanned).into_iter().collect()
}
27
/// Wrap `stream` in a `::core::matches!(...)` invocation.
fn matches(stream: TokenStream) -> TokenStream {
    let cs = Span::call_site();
    [
        TokenTree::from(Punct::new(':', Joint)),
        TokenTree::from(Punct::new(':', Joint)),
        TokenTree::from(Ident::new("core", cs)),
        TokenTree::from(Punct::new(':', Joint)),
        TokenTree::from(Punct::new(':', Joint)),
        TokenTree::from(Ident::new("matches", cs)),
        TokenTree::from(Punct::new('!', Joint)),
        TokenTree::from(Group::new(Delimiter::Parenthesis, stream)),
    ]
    .into_iter()
    .collect()
}
40
/// Wrap `stream` in a `::char_classes::FirstElem::first_elem(...)` call.
fn first_elem(stream: TokenStream) -> TokenStream {
    let cs = Span::call_site();
    // One `:` of a `::` separator.
    let colon = || TokenTree::from(Punct::new(':', Joint));

    let mut out = TokenStream::new();
    out.extend([colon(), colon(), Ident::new("char_classes", cs).into()]);
    out.extend([colon(), colon(), Ident::new("FirstElem", cs).into()]);
    out.extend([colon(), colon(), Ident::new("first_elem", cs).into()]);
    out.extend([TokenTree::from(Group::new(Delimiter::Parenthesis, stream))]);
    out
}
55
/// A parsed string-like literal: the pattern source of the macro.
enum Str {
    // Contents of a regular `"..."` literal.
    Norm(String),
    // Contents of a byte-string `b"..."` literal.
    Byte(Vec<u8>),
}
60
61fn lit_str(tt: &TokenTree) -> Result<Str, TokenStream> {
62    StringLit::try_from(tt)
63        .map(|s| Str::Norm(s.into_value().into_owned()))
64        .map_err(|e| e.to_string())
65        .or_else(|e| ByteStringLit::try_from(tt)
66            .map(|b| Str::Byte(b.into_value().into_owned()))
67            .map_err(|e2| format!("{e}\n{e2}")))
68        .map_err(|e| err(&e, tt.span()))
69}
70
/// Returns `self` with its span replaced.
///
/// NOTE(review): the name is a typo for "Spanned", kept for compatibility
/// with existing uses.
trait Spaned {
    fn spaned(self, span: Span) -> Self;
}

// Implement [`Spaned`] for any type exposing `fn set_span(&mut self, Span)`.
macro_rules! impl_spaned {
    ($($ty:ty),* $(,)?) => {$(
        impl Spaned for $ty {
            fn spaned(mut self, span: Span) -> Self {
                self.set_span(span);
                self
            }
        }
    )*};
}
impl_spaned!(TokenTree, Literal);
86
/// Render `self` as (part of) a `match` pattern token stream, with every
/// token carrying `span`.
trait ToPat: Sized {
    fn to_pat(self, span: Span) -> TokenStream;
}
90impl ToPat for u8 {
91    fn to_pat(self, span: Span) -> TokenStream {
92        TokenTree::from(Literal::byte_character(self).spaned(span)).into()
93    }
94}
95impl ToPat for char {
96    fn to_pat(self, span: Span) -> TokenStream {
97        TokenTree::from(Literal::character(self).spaned(span)).into()
98    }
99}
100impl<T: ToPat> ToPat for (T, T) {
101    fn to_pat(self, span: Span) -> TokenStream {
102        let (from, to) = self;
103        TokenStream::from_iter([
104            from.to_pat(span),
105            <TokenStream as FromIterator<TokenTree>>::from_iter([
106                Punct::new('.', Joint).into(),
107                Punct::new('.', Joint).into(),
108                Punct::new('=', Joint).into(),
109            ]),
110            to.to_pat(span),
111        ])
112    }
113}
114
/// Whether a pattern element is the `-` used as the range separator.
trait IsDash {
    fn is_dash(&self) -> bool;
}
impl IsDash for u8 {
    fn is_dash(&self) -> bool {
        matches!(*self, b'-')
    }
}
impl IsDash for char {
    fn is_dash(&self) -> bool {
        matches!(*self, '-')
    }
}
128
/// Wrap a pattern stream in `::core::option::Option::Some(...)`.
///
/// The path is fully qualified so the expansion still resolves when `Some`
/// is shadowed at the call site — consistent with the `::core::` paths this
/// file already emits for `matches!` and `compile_error!`.
fn some(stream: TokenStream, span: Span) -> TokenStream {
    let mut out = TokenStream::new();
    for seg in ["core", "option", "Option", "Some"] {
        out.extend([
            TokenTree::from(Punct::new(':', Joint)),
            TokenTree::from(Punct::new(':', Joint)),
            TokenTree::from(Ident::new(seg, span)),
        ]);
    }
    out.extend([TokenTree::from(Group::new(Delimiter::Parenthesis, stream))]);
    out
}
135
/// Translate a sequence of pattern elements (chars or bytes) into a `|`
/// alternation of literal and `x..=y` range patterns, wrapped via [`some`].
///
/// E.g. the elements of `"a-cx"` become `Some('a'..='c' | 'x')`. A dash
/// that is the first or last element is emitted as a literal `-` pattern.
/// Returns `Err` with a `compile_error!` stream when `iter` is empty.
fn to_pats<T, I>(iter: I, span: Span) -> Result<TokenStream, TokenStream>
where T: ToPat + IsDash,
      I: IntoIterator<Item = T>,
{
    let mut iter = iter.into_iter().peekable();
    // `first` is the element pending emission; emission is deferred one step
    // so a following `-y` can turn it into a range instead of a literal.
    let Some(mut first) = iter.next() else {
        return Err(err("cannot support empty pattern", span));
    };
    let mut result = TokenStream::new();
    // Separator emitted before each alternative: no-op before the first one,
    // rebound to emit `|` after anything has been written.
    let mut sep: fn(&mut TokenStream) = |_| ();

    while let Some(cur) = iter.next() {
        sep(&mut result);

        // `cur` is a dash AND another element follows: consume that element
        // as the range end and emit `first..=to`.
        if let Some(to) = iter.next_if(|_| cur.is_dash()) {
            result.extend([(first, to).to_pat(span)]);

            if let Some(next) = iter.next() {
                first = next;
            } else {
                // The range was the last item; nothing is left pending.
                return Ok(some(result, span));
            }
        } else {
            // Not a range: flush the pending element, `cur` becomes pending.
            result.extend([first.to_pat(span)]);
            first = cur;
        }

        sep = |result| {
            result.extend([TokenTree::from(Punct::new('|', Alone))]);
        };
    }

    // Flush the final pending element (also handles a trailing literal `-`).
    sep(&mut result);
    result.extend([first.to_pat(span)]);
    Ok(some(result, span))
}
172
/// Like `char_classes::any()`, expand into [`matches`] for better performance
///
/// # Examples
///
/// ```ignore
/// use char_classes::any;
///
/// assert!(any!("ab",      'a'));
/// assert!(any!("ab",      'b'));
/// assert!(any!("a-c",     'a'));
/// assert!(any!("a-c",     'b'));
/// assert!(any!("a-c",     'c'));
/// assert!(any!(b"ab",    b'a'));
/// assert!(any!(b"ab",    b'b'));
/// ```
#[proc_macro]
pub fn any(input: TokenStream) -> TokenStream {
    // Expected input shape: `<string-or-byte-string literal> , <expr ...>`.
    let mut iter = input.into_iter();
    let Some(first) = iter.next() else {
        return err("unexpected end of input, expected a literal", Span::call_site());
    };
    let Some(comma) = iter.next() else {
        return err("unexpected end of input, expected a comma", Span::call_site());
    };
    if !matches!(&comma, TokenTree::Punct(p) if p.as_char() == ',') {
        return err("unexpected token, expected a comma", comma.span());
    }
    // Parse the pattern literal (the binding shadows the `lit_str` fn).
    let lit_str = match lit_str(&first) {
        Ok(s) => s,
        Err(e) => return e,
    };
    match lit_str {
        Str::Norm(s) => to_pats(s.chars(), first.span()),
        Str::Byte(bytes) => to_pats(bytes, first.span()),
    }.map_or_else(identity, |pat| {
        // Everything after the comma is the scrutinee expression; a leading
        // `&` is prepended so the expansion calls `first_elem(&expr)`.
        let expr = once(Punct::new('&', Joint).into())
            .chain(iter);
        // Final expansion:
        // `::core::matches!(::char_classes::FirstElem::first_elem(&expr), <pat>)`
        // where `<pat>` is the `Some(...)` pattern built by `to_pats`.
        matches(first_elem(expr.collect()).into_iter()
            .chain([Punct::new(',', Alone).into()])
            .chain(pat)
            .collect())
    })
}