1use proc_macro::{token_stream, Delimiter, Group, Literal, Spacing, Span, TokenStream, TokenTree};
6use std::borrow::Cow;
7
/// The marker that introduces a loop or a loop-index reference inside `repeat!`.
///
/// A sigil is a run of `len` copies of the punctuation character `char`
/// (for example `#` is `{ char: '#', len: 1 }` and `$$` is `{ char: '$', len: 2 }`).
#[derive(Copy, Clone)]
struct Sigil {
    /// The punctuation character the sigil is made of.
    char: char,
    /// How many consecutive copies of `char` form one sigil.
    len: usize,
}
13
14#[proc_macro]
99pub fn repeat(input: TokenStream) -> TokenStream {
100 let mut input = ts_iter_fix(input);
101 let mut next = input.next();
102
103 let mut need_colon = false;
104
105 let sigil = if let Some(TokenTree::Punct(p)) = next {
106 need_colon = true;
107 let char = p.as_char();
108 let mut len = 1;
109 next = input.next();
110 let mut spacing = p.spacing();
111 while spacing == Spacing::Joint {
112 if let Some(TokenTree::Punct(ref p2)) = next {
113 if p2.as_char() == char {
114 len += 1;
115 spacing = p2.spacing();
116 next = input.next();
117 } else {
118 break;
119 }
120 } else {
121 return Error::new(p.span(), "joint spaced punct wasn't followed by punct").into();
122 }
123 }
124 Sigil { char, len }
125 } else {
126 Sigil { char: '#', len: 1 }
127 };
128
129 let loop_var = if let Some(TokenTree::Ident(ident)) = next {
130 need_colon = true;
131 next = input.next();
132 Some(ident.to_string())
133 } else {
134 None
135 };
136
137 'colon: {
138 if need_colon {
139 if let Some(TokenTree::Punct(p)) = &next {
140 if p.spacing() == Spacing::Alone && p.as_char() == ':' {
141 next = input.next();
142 break 'colon;
143 }
144 }
145 return Error::new(
146 next.map(|t| t.span()).unwrap_or_else(Span::call_site),
147 "expected `:` after sigil/loop variable",
148 )
149 .into();
150 }
151 }
152
153 let Some(TokenTree::Literal(repeat_count)) = &next else {
154 return Error::new(
155 next.map(|t| t.span()).unwrap_or_else(Span::call_site),
156 "expected integer literal as repeat count",
157 )
158 .into();
159 };
160 let Ok(repeat_count) = repeat_count.to_string().parse::<usize>() else {
161 return Error::new(
162 next.unwrap().span(),
163 "expected integer literal as repeat count",
164 )
165 .into();
166 };
167
168 next = input.next();
169 let Some(TokenTree::Punct(p0)) = next else {
170 return Error::new(
171 next.map(|t| t.span()).unwrap_or_else(Span::call_site),
172 "expected `=>` after repeat count",
173 )
174 .into();
175 };
176 if p0.spacing() != Spacing::Joint || p0.as_char() != '=' {
177 return Error::new(p0.span(), "expected `=>` after repeat count").into();
178 }
179 next = input.next();
180 let Some(TokenTree::Punct(p1)) = next else {
181 return Error::new(
182 next.map(|t| t.span()).unwrap_or_else(Span::call_site),
183 "expected `=>` after repeat count",
184 )
185 .into();
186 };
187 if p1.spacing() != Spacing::Alone || p1.as_char() != '>' {
188 return Error::new(p1.span(), "expected `=>` after repeat count").into();
189 }
190
191 let mut output = TokenStream::new();
192
193 if let Err(e) = process(&mut output, &mut input, sigil, &|token, output, input| {
194 if let TokenTree::Group(group) = &token {
195 if group.delimiter() == Delimiter::Parenthesis {
196 let delim = input.next();
197 let Some(TokenTree::Punct(p)) = &delim else {
198 return Err(Error::new(
199 delim.map(|t| t.span()).unwrap_or_else(|| group.span()),
200 "expected delimiter or `*` after closing parenthesis for loop",
201 ));
202 };
203 let delim = if p.as_char() != '*' {
204 let Some(TokenTree::Punct(p)) = input.next() else {
205 return Err(Error::new(p.span(), "expected `*` after loop delimiter"));
206 };
207 if p.as_char() != '*' {
208 return Err(Error::new(p.span(), "expected `*` after loop delimiter"));
209 }
210 delim
211 } else {
212 None
213 };
214 let group = ts_iter_fix(group.stream());
215 for i in 0..repeat_count {
216 let mut group = group.clone();
217
218 if i == 0 {
219 } else if let Some(delim) = &delim {
220 output.extend([delim.clone()]);
221 }
222
223 process(output, &mut group, sigil, &|token, output, _input| {
224 if let TokenTree::Ident(ident) = token {
225 let ident_s = ident.to_string();
226 if Some(&ident_s) == loop_var.as_ref() {
227 output.extend([TokenTree::Literal(Literal::usize_unsuffixed(i))]);
228 } else {
229 return Err(Error::new(
230 ident.span(),
231 format!("{ident_s} isn't a loop index"),
232 ));
233 }
234 } else if let TokenTree::Group(_) = token {
235 return Err(Error::new(token.span(), "can't loop in loop"));
236 } else {
237 let s = String::from(sigil.char).repeat(sigil.len) + &token.to_string();
238 return Err(Error::new(
239 token.span(),
240 format!("invalid sigiled token: `{s}`"),
241 ));
242 }
243 Ok(())
244 })?;
245 }
246 return Ok(());
247 }
248 } else if let TokenTree::Ident(_) = token {
249 return Err(Error::new(
250 token.span(),
251 "can't access loop index outside loop",
252 ));
253 }
254 let s = String::from(sigil.char).repeat(sigil.len) + &token.to_string();
255 Err(Error::new(
256 token.span(),
257 format!("invalid sigiled token: `{s}`"),
258 ))
259 }) {
260 return e.into();
261 }
262
263 output
264}
265
/// A diagnostic produced while expanding the macro: the span to attach it to
/// and the message text (either a static string or an owned one).
struct Error(Span, Cow<'static, str>);

impl Error {
    /// Creates an error for `span`; `message` may be a `&'static str`, a
    /// `String`, or anything else convertible into `Cow<'static, str>`.
    pub fn new(span: Span, message: impl Into<Cow<'static, str>>) -> Self {
        let message = message.into();
        Error(span, message)
    }
}
273
274impl From<Error> for TokenStream {
275 fn from(value: Error) -> Self {
276 let tokens: TokenStream = format!("compile_error!({:?})", value.1).parse().unwrap();
277 let mut ts = TokenStream::new();
278 ts.extend(tokens.into_iter().map(|mut tt| {
279 tt.set_span(value.0);
280 tt
281 }));
282 ts
283 }
284}
285
286fn ts_iter_fix(ts: TokenStream) -> TsIter {
287 TsIter(ts.into_iter())
288}
289
290#[derive(Clone)]
291struct TsIter(token_stream::IntoIter);
292
293impl Iterator for TsIter {
294 type Item = TokenTree;
295
296 fn next(&mut self) -> Option<Self::Item> {
297 self.0.next().map(flatten_token_tree)
298 }
299}
300
/// Unwraps `Delimiter::None` groups that contain exactly one token.
///
/// Such a group is replaced by its single inner token, recursively, so
/// nested single-token invisible groups collapse to the innermost token.
/// Any other token, including a `None`-delimited group with zero or several
/// tokens, is returned unchanged.
fn flatten_token_tree(tt: TokenTree) -> TokenTree {
    let TokenTree::Group(group) = &tt else {
        return tt;
    };
    if group.delimiter() != Delimiter::None {
        return tt;
    }

    let mut inner = group.stream().into_iter();
    match inner.next() {
        // Only flatten when nothing follows the first token.
        Some(only) if inner.next().is_none() => flatten_token_tree(only),
        _ => tt,
    }
}
314
315fn process(
316 output: &mut TokenStream,
317 input: &mut TsIter,
318 sigil: Sigil,
319 handle: &impl Fn(TokenTree, &mut TokenStream, &mut TsIter) -> Result<(), Error>,
320) -> Result<(), Error> {
321 let mut accept_sigil = true;
322 let mut sigil_buf = Vec::with_capacity(sigil.len);
323
324 while let Some(token) = input.next() {
325 if let TokenTree::Punct(p) = &token {
326 if accept_sigil && p.as_char() == sigil.char {
327 accept_sigil = p.spacing() == Spacing::Joint;
328 sigil_buf.push(token);
329 continue;
330 }
331 }
332
333 accept_sigil = true;
334
335 if !sigil_buf.is_empty() {
336 if sigil_buf.len() == sigil.len {
337 sigil_buf.clear();
338 handle(token, output, input)?;
339 continue;
340 }
341 output.extend(sigil_buf.drain(..));
342 }
343
344 if let TokenTree::Group(group) = &token {
345 let mut group_output = TokenStream::new();
346 let mut input = ts_iter_fix(group.stream());
347 process(&mut group_output, &mut input, sigil, handle)?;
348 output.extend([TokenTree::Group(Group::new(
349 group.delimiter(),
350 group_output,
351 ))]);
352 } else {
353 output.extend([token]);
354 }
355 }
356
357 Ok(())
358}