1use std::{convert::identity, iter::once};
2
3use litrs::{ByteStringLit, StringLit};
4use proc_macro::{Delimiter, Group, Ident, Punct, Spacing::*, Span, TokenStream, TokenTree, Literal};
5
/// Builds `::core::compile_error! { "<msg>" }`.
///
/// Every emitted token (path, `!`, the brace group and the message literal inside it) is
/// re-spanned to `span`, so the compiler reports the error at that location.
#[must_use]
fn err(msg: &str, span: Span) -> TokenStream {
    let respan = |mut tree: TokenTree| {
        tree.set_span(span);
        tree
    };

    let message = respan(Literal::string(msg).into());
    let body = Group::new(Delimiter::Brace, TokenStream::from(message));

    let tokens: [TokenTree; 8] = [
        Punct::new(':', Joint).into(),
        Punct::new(':', Joint).into(),
        Ident::new("core", span).into(),
        Punct::new(':', Joint).into(),
        Punct::new(':', Joint).into(),
        Ident::new("compile_error", span).into(),
        Punct::new('!', Joint).into(),
        body.into(),
    ];
    tokens.into_iter().map(respan).collect()
}
27
/// Wraps `stream` in an invocation of `::core::matches!( ... )`.
///
/// All generated tokens use the call-site span. The caller supplies the
/// `<scrutinee>, <pattern>` contents in `stream`.
fn matches(stream: TokenStream) -> TokenStream {
    let site = Span::call_site();
    let path_sep = || [Punct::new(':', Joint), Punct::new(':', Joint)].map(TokenTree::from);

    let mut out = TokenStream::new();
    out.extend(path_sep());
    out.extend([TokenTree::from(Ident::new("core", site))]);
    out.extend(path_sep());
    out.extend([
        TokenTree::from(Ident::new("matches", site)),
        TokenTree::from(Punct::new('!', Joint)),
        TokenTree::from(Group::new(Delimiter::Parenthesis, stream)),
    ]);
    out
}
40
/// Builds `::char_classes::FirstElem::first_elem(&(<stream>))` with call-site spans.
///
/// The user-supplied expression is wrapped in parentheses before it is borrowed. Without
/// them, operator precedence splits the expression: `x as T` would expand to `&x as T`,
/// i.e. `(&x) as T`; `a + b` to `(&a) + b`; and `a..b` to `(&a)..b`. Parentheses keep the
/// place-ness of the expression, so a plain variable is still borrowed in place.
fn first_elem(stream: TokenStream) -> TokenStream {
    let span = Span::call_site();
    // `&` is followed by a group, never by a punct it could fuse with, so it is `Alone`.
    let argument = TokenStream::from_iter([
        TokenTree::from(Punct::new('&', Alone)),
        TokenTree::from(Group::new(Delimiter::Parenthesis, stream)),
    ]);
    <TokenStream as FromIterator<TokenTree>>::from_iter([
        Punct::new(':', Joint).into(),
        Punct::new(':', Joint).into(),
        Ident::new("char_classes", span).into(),
        Punct::new(':', Joint).into(),
        Punct::new(':', Joint).into(),
        Ident::new("FirstElem", span).into(),
        Punct::new(':', Joint).into(),
        Punct::new(':', Joint).into(),
        Ident::new("first_elem", span).into(),
        Group::new(Delimiter::Parenthesis, argument).into(),
    ])
}
59
/// The decoded contents of the class literal given to the macro.
enum Str {
    /// From a string literal; its elements are matched as `char`s.
    Norm(String),
    /// From a byte-string literal; its elements are matched as `u8`s.
    Byte(Vec<u8>),
}
64
65fn lit_str(tt: &TokenTree) -> Result<Str, TokenStream> {
66 StringLit::try_from(tt)
67 .map(|s| Str::Norm(s.into_value().into_owned()))
68 .map_err(|e| e.to_string())
69 .or_else(|e| ByteStringLit::try_from(tt)
70 .map(|b| Str::Byte(b.into_value().into_owned()))
71 .map_err(|e2| format!("{e}\n{e2}")))
72 .map_err(|e| err(&e, tt.span()))
73}
74
/// Builder-style `set_span`: consumes a token, assigns it `span`, and returns it.
trait Spaned {
    fn spaned(self, span: Span) -> Self;
}

// Each implementor forwards to its own inherent `set_span`, so the impls are generated
// from one template instead of being written out three times.
macro_rules! impl_spaned {
    ($($ty:ty),* $(,)?) => {$(
        impl Spaned for $ty {
            fn spaned(mut self, span: Span) -> Self {
                self.set_span(span);
                self
            }
        }
    )*};
}

impl_spaned!(TokenTree, Literal, Punct);
96
/// Converts one element of a character class, or a `(lo, hi)` pair, into pattern tokens.
trait ToPat: Sized {
    fn to_pat(self, span: Span) -> TokenStream;
}

/// A single byte becomes a byte literal such as `b'a'`, spanned at `span`.
impl ToPat for u8 {
    fn to_pat(self, span: Span) -> TokenStream {
        let mut literal = Literal::byte_character(self);
        literal.set_span(span);
        TokenStream::from(TokenTree::Literal(literal))
    }
}

/// A single `char` becomes a character literal such as `'a'`, spanned at `span`.
impl ToPat for char {
    fn to_pat(self, span: Span) -> TokenStream {
        let mut literal = Literal::character(self);
        literal.set_span(span);
        TokenStream::from(TokenTree::Literal(literal))
    }
}

/// A pair becomes the inclusive range pattern `lo..=hi`.
/// The `..=` tokens keep their default (call-site) spans.
impl<T: ToPat> ToPat for (T, T) {
    fn to_pat(self, span: Span) -> TokenStream {
        let (lo, hi) = self;
        let mut out = lo.to_pat(span);
        out.extend(['.', '.', '='].map(|ch| TokenTree::from(Punct::new(ch, Joint))));
        out.extend([hi.to_pat(span)]);
        out
    }
}
124
/// Reports whether a class element is the range separator `-`.
trait IsDash {
    fn is_dash(&self) -> bool;
}

impl IsDash for u8 {
    fn is_dash(&self) -> bool {
        matches!(*self, b'-')
    }
}

impl IsDash for char {
    fn is_dash(&self) -> bool {
        matches!(*self, '-')
    }
}
138
/// Wraps a pattern as `::core::option::Option::Some(<stream>)`.
///
/// The path tokens are spanned at `span`, and the parenthesis group keeps its default span.
/// An absolute path is used, like the `::core::...` paths emitted by `err` and `matches`.
/// This keeps the expansion independent of what `Some` names at the call site. For
/// instance, a glob-imported enum variant called `Some` would otherwise shadow the
/// prelude's.
fn some(stream: TokenStream, span: Span) -> TokenStream {
    let mut tokens = TokenStream::new();
    for segment in ["core", "option", "Option", "Some"] {
        let mut colons = [Punct::new(':', Joint), Punct::new(':', Joint)];
        for colon in &mut colons {
            colon.set_span(span);
        }
        tokens.extend(colons.map(TokenTree::from));
        tokens.extend([TokenTree::from(Ident::new(segment, span))]);
    }
    tokens.extend([TokenTree::from(Group::new(Delimiter::Parenthesis, stream))]);
    tokens
}
145
146fn to_pats<T, I>(iter: I, span: Span) -> Result<TokenStream, TokenStream>
147where T: ToPat + IsDash,
148 I: IntoIterator<Item = T>,
149{
150 let mut iter = iter.into_iter().peekable();
151 let Some(mut first) = iter.next() else {
152 return Err(err("cannot support empty pattern", span));
153 };
154 let mut result = TokenStream::new();
155 let mut sep: fn(&mut TokenStream) = |_| ();
156
157 while let Some(cur) = iter.next() {
158 sep(&mut result);
159
160 if let Some(to) = iter.next_if(|_| cur.is_dash()) {
161 result.extend([(first, to).to_pat(span)]);
162
163 if let Some(next) = iter.next() {
164 first = next;
165 } else {
166 return Ok(some(result, span));
167 }
168 } else {
169 result.extend([first.to_pat(span)]);
170 first = cur;
171 }
172
173 sep = |result| {
174 result.extend([TokenTree::from(Punct::new('|', Alone))]);
175 };
176 }
177
178 sep(&mut result);
179 result.extend([first.to_pat(span)]);
180 Ok(some(result, span))
181}
182
183#[proc_macro]
204pub fn any(input: TokenStream) -> TokenStream {
205 let mut iter = input.into_iter().peekable();
206 let not = iter.next_if(|tt| {
207 matches!(tt, TokenTree::Punct(p) if p.as_char() == '^')
208 }).map(|tt| Punct::new('!', Joint).spaned(tt.span()).into());
209 let Some(first) = iter.next() else {
210 return err("unexpected end of input, expected a literal", Span::call_site());
211 };
212 let comma = iter.next();
213 if comma.as_ref().is_some_and(|comma| {
214 !matches!(&comma, TokenTree::Punct(p) if p.as_char() == ',')
215 }) {
216 return err("unexpected token, expected a comma", comma.unwrap().span());
217 }
218 let lit_str = match lit_str(&first) {
219 Ok(s) => s,
220 Err(e) => return e,
221 };
222 match lit_str {
223 Str::Norm(s) => to_pats(s.chars(), first.span()),
224 Str::Byte(bytes) => to_pats(bytes, first.span()),
225 }.map_or_else(identity, |pat| {
226 if let Some(comma) = comma {
227 let expr = not.into_iter().chain(matches(
228 first_elem(iter.collect())
229 .into_iter()
230 .chain([comma])
231 .chain(pat)
232 .collect(),
233 )).collect();
234
235 TokenTree::from(Group::new(Delimiter::None, expr)).into()
236 } else {
237 let name = TokenTree::from(Ident::new("input", first.span()));
238 let mut comma = Punct::new(',', Alone);
239 comma.set_span(first.span());
240
241 let expr = once(Punct::new('|', Joint).into())
242 .chain([name.clone(), Punct::new('|', Alone).into()])
243 .chain(not)
244 .chain(matches(first_elem(name.into())
245 .into_iter()
246 .chain([comma.into()])
247 .chain(pat)
248 .collect()))
249 .collect();
250 TokenTree::from(Group::new(Delimiter::None, expr)).into()
251 }
252 })
253}