repeater/
lib.rs

1//! # repeater
2//!
//! This crate provides the [`repeat!`] macro. See the macro's documentation for usage and examples.
4
5use proc_macro::{token_stream, Delimiter, Group, Literal, Spacing, Span, TokenStream, TokenTree};
6use std::borrow::Cow;
7
/// The prefix that marks repeat commands and loop variables inside a `repeat!` invocation.
///
/// A sigil is one punctuation character written `len` times in a row, e.g. `#` (length 1)
/// or `###` (length 3).
#[derive(Copy, Clone)]
struct Sigil {
    /// How many consecutive copies of `char` make up the sigil.
    len: usize,
    /// The punctuation character the sigil is built from.
    char: char,
}
13
14/// In its simplest form, `repeat` takes a repeat count (an unsigned integer) followed by `=>`
15/// and then the tokens to repeat. Wrap sections to be repeated in `#(` and `)*`.
16///
17/// ```rust
18/// # use repeater::repeat;
19/// let n = repeat!(5 => 0 #( + 1 )*);
20/// assert_eq!(n, 5);
21/// ```
22///
23/// You can have as many repeating sections as you want.
24///
25/// ```rust
26/// # use repeater::repeat;
27/// let n = repeat!(5 => 0 #( + 1 )* #( + 1 )*);
28/// assert_eq!(n, 10);
29/// ```
30///
31/// You can't nest loops within a single invocation of `repeat`, but you can nest invocations
32/// for the same effect (see below).
33///
34/// ```compile_fail
35/// # use repeater::repeat;
36/// let n = repeat!(5 => 0 #(#( + 1 )*)*);
37/// ```
38///
39/// You can change the sigil used for repeat commands and variables by specifying it before the
40/// repeat count, followed by a colon.
41///
42/// ```rust
43/// # use repeater::repeat;
44/// let n = repeat!($: 5 => 0 $( + 1 )*);
45/// assert_eq!(n, 5);
46/// ```
47///
48/// You can also change the amount of sigils required.
49///
50/// ```rust
51/// # use repeater::repeat;
52/// let n = repeat!(###: 5 => 0 ###( + 1 )*);
53/// assert_eq!(n, 5);
54/// ```
55///
56/// To get access to a loop variable counting from 0 to the specified repeat count, put its
57/// name before the repeat count, followed by a colon. You can access this variable in repeated
58/// sections using the provided name (including the invocation's sigil).
59///
60/// ```rust
61/// # use repeater::repeat;
62/// let n = repeat!(#i: 5 => 0 #(+ #i)*);
63/// assert_eq!(n, 10);
64/// ```
65///
66/// The `#(`...`)*` loop syntax accepts a single punctuation token between the `)` and `*` at
67/// the end to insert that token between repeats. The example below expands to `1+2+3+4+5`.
68///
69/// ```rust
70/// # use repeater::repeat;
71/// let n = repeat!(#i: 5 => #(#i)+*);
72/// assert_eq!(n, 10);
73/// ```
74///
75/// Any tokens without the current sigil prefix are inserted as is.
76///
77/// ```rust
78/// # use repeater::repeat;
79/// let i = -1;
80/// let n = repeat!(#i: 5 => [#(#i, i),*]);
81/// assert_eq!(n, [0, -1, 1, -1, 2, -1, 3, -1, 4, -1]);
82/// ```
83///
84/// Invocations can be nested. By using different prefixes for the different invocations,
85/// you can mix their variables freely.
86///
87/// ```rust
88/// # use repeater::repeat;
89/// let tuple = repeat!(#i: 2 => repeat!(##j: 2 => (#(##((#i, ##j),)*)*)));
90/// assert_eq!(tuple, ((0, 0), (0, 1), (1, 0), (1, 1)));
91/// ```
92///
93/// ```rust
94/// # use repeater::repeat;
95/// let tuple = repeat!(#i: 2 => (#(repeat!(##j: 2 => (##((#i, ##j)),*))),*));
96/// assert_eq!(tuple, (((0, 0), (0, 1)), ((1, 0), (1, 1))));
97/// ```
98#[proc_macro]
99pub fn repeat(input: TokenStream) -> TokenStream {
100    let mut input = ts_iter_fix(input);
101    let mut next = input.next();
102
103    let mut need_colon = false;
104
105    let sigil = if let Some(TokenTree::Punct(p)) = next {
106        need_colon = true;
107        let char = p.as_char();
108        let mut len = 1;
109        next = input.next();
110        let mut spacing = p.spacing();
111        while spacing == Spacing::Joint {
112            if let Some(TokenTree::Punct(ref p2)) = next {
113                if p2.as_char() == char {
114                    len += 1;
115                    spacing = p2.spacing();
116                    next = input.next();
117                } else {
118                    break;
119                }
120            } else {
121                return Error::new(p.span(), "joint spaced punct wasn't followed by punct").into();
122            }
123        }
124        Sigil { char, len }
125    } else {
126        Sigil { char: '#', len: 1 }
127    };
128
129    let loop_var = if let Some(TokenTree::Ident(ident)) = next {
130        need_colon = true;
131        next = input.next();
132        Some(ident.to_string())
133    } else {
134        None
135    };
136
137    'colon: {
138        if need_colon {
139            if let Some(TokenTree::Punct(p)) = &next {
140                if p.spacing() == Spacing::Alone && p.as_char() == ':' {
141                    next = input.next();
142                    break 'colon;
143                }
144            }
145            return Error::new(
146                next.map(|t| t.span()).unwrap_or_else(Span::call_site),
147                "expected `:` after sigil/loop variable",
148            )
149            .into();
150        }
151    }
152
153    let Some(TokenTree::Literal(repeat_count)) = &next else {
154        return Error::new(
155            next.map(|t| t.span()).unwrap_or_else(Span::call_site),
156            "expected integer literal as repeat count",
157        )
158        .into();
159    };
160    let Ok(repeat_count) = repeat_count.to_string().parse::<usize>() else {
161        return Error::new(
162            next.unwrap().span(),
163            "expected integer literal as repeat count",
164        )
165        .into();
166    };
167
168    next = input.next();
169    let Some(TokenTree::Punct(p0)) = next else {
170        return Error::new(
171            next.map(|t| t.span()).unwrap_or_else(Span::call_site),
172            "expected `=>` after repeat count",
173        )
174        .into();
175    };
176    if p0.spacing() != Spacing::Joint || p0.as_char() != '=' {
177        return Error::new(p0.span(), "expected `=>` after repeat count").into();
178    }
179    next = input.next();
180    let Some(TokenTree::Punct(p1)) = next else {
181        return Error::new(
182            next.map(|t| t.span()).unwrap_or_else(Span::call_site),
183            "expected `=>` after repeat count",
184        )
185        .into();
186    };
187    if p1.spacing() != Spacing::Alone || p1.as_char() != '>' {
188        return Error::new(p1.span(), "expected `=>` after repeat count").into();
189    }
190
191    let mut output = TokenStream::new();
192
193    if let Err(e) = process(&mut output, &mut input, sigil, &|token, output, input| {
194        if let TokenTree::Group(group) = &token {
195            if group.delimiter() == Delimiter::Parenthesis {
196                let delim = input.next();
197                let Some(TokenTree::Punct(p)) = &delim else {
198                    return Err(Error::new(
199                        delim.map(|t| t.span()).unwrap_or_else(|| group.span()),
200                        "expected delimiter or `*` after closing parenthesis for loop",
201                    ));
202                };
203                let delim = if p.as_char() != '*' {
204                    let Some(TokenTree::Punct(p)) = input.next() else {
205                        return Err(Error::new(p.span(), "expected `*` after loop delimiter"));
206                    };
207                    if p.as_char() != '*' {
208                        return Err(Error::new(p.span(), "expected `*` after loop delimiter"));
209                    }
210                    delim
211                } else {
212                    None
213                };
214                let group = ts_iter_fix(group.stream());
215                for i in 0..repeat_count {
216                    let mut group = group.clone();
217
218                    if i == 0 {
219                    } else if let Some(delim) = &delim {
220                        output.extend([delim.clone()]);
221                    }
222
223                    process(output, &mut group, sigil, &|token, output, _input| {
224                        if let TokenTree::Ident(ident) = token {
225                            let ident_s = ident.to_string();
226                            if Some(&ident_s) == loop_var.as_ref() {
227                                output.extend([TokenTree::Literal(Literal::usize_unsuffixed(i))]);
228                            } else {
229                                return Err(Error::new(
230                                    ident.span(),
231                                    format!("{ident_s} isn't a loop index"),
232                                ));
233                            }
234                        } else if let TokenTree::Group(_) = token {
235                            return Err(Error::new(token.span(), "can't loop in loop"));
236                        } else {
237                            let s = String::from(sigil.char).repeat(sigil.len) + &token.to_string();
238                            return Err(Error::new(
239                                token.span(),
240                                format!("invalid sigiled token: `{s}`"),
241                            ));
242                        }
243                        Ok(())
244                    })?;
245                }
246                return Ok(());
247            }
248        } else if let TokenTree::Ident(_) = token {
249            return Err(Error::new(
250                token.span(),
251                "can't access loop index outside loop",
252            ));
253        }
254        let s = String::from(sigil.char).repeat(sigil.len) + &token.to_string();
255        Err(Error::new(
256            token.span(),
257            format!("invalid sigiled token: `{s}`"),
258        ))
259    }) {
260        return e.into();
261    }
262
263    output
264}
265
/// A macro expansion error: the span to report it at (field 0) and its message (field 1).
///
/// Converting an `Error` into a `TokenStream` yields a `compile_error!` invocation.
struct Error(Span, Cow<'static, str>);
267
268impl Error {
269    pub fn new(span: Span, message: impl Into<Cow<'static, str>>) -> Self {
270        Self(span, message.into())
271    }
272}
273
274impl From<Error> for TokenStream {
275    fn from(value: Error) -> Self {
276        let tokens: TokenStream = format!("compile_error!({:?})", value.1).parse().unwrap();
277        let mut ts = TokenStream::new();
278        ts.extend(tokens.into_iter().map(|mut tt| {
279            tt.set_span(value.0);
280            tt
281        }));
282        ts
283    }
284}
285
286fn ts_iter_fix(ts: TokenStream) -> TsIter {
287    TsIter(ts.into_iter())
288}
289
/// Wrapper around a token stream iterator whose `Iterator` implementation passes every
/// yielded token through `flatten_token_tree`.
#[derive(Clone)]
struct TsIter(token_stream::IntoIter);
292
293impl Iterator for TsIter {
294    type Item = TokenTree;
295
296    fn next(&mut self) -> Option<Self::Item> {
297        self.0.next().map(flatten_token_tree)
298    }
299}
300
/// Unwraps groups with `Delimiter::None` that contain exactly one token.
///
/// Such a group is replaced by its single token, recursively. Every other token, including
/// a `Delimiter::None` group holding zero or several tokens, is returned unchanged.
fn flatten_token_tree(tt: TokenTree) -> TokenTree {
    let TokenTree::Group(wrapper) = &tt else {
        return tt;
    };
    if wrapper.delimiter() != Delimiter::None {
        return tt;
    }

    let mut inner = wrapper.stream().into_iter();
    match inner.next() {
        // Exactly one token inside: unwrap it, and keep unwrapping if it is wrapped again.
        Some(only) if inner.next().is_none() => flatten_token_tree(only),
        _ => tt,
    }
}
314
315fn process(
316    output: &mut TokenStream,
317    input: &mut TsIter,
318    sigil: Sigil,
319    handle: &impl Fn(TokenTree, &mut TokenStream, &mut TsIter) -> Result<(), Error>,
320) -> Result<(), Error> {
321    let mut accept_sigil = true;
322    let mut sigil_buf = Vec::with_capacity(sigil.len);
323
324    while let Some(token) = input.next() {
325        if let TokenTree::Punct(p) = &token {
326            if accept_sigil && p.as_char() == sigil.char {
327                accept_sigil = p.spacing() == Spacing::Joint;
328                sigil_buf.push(token);
329                continue;
330            }
331        }
332
333        accept_sigil = true;
334
335        if !sigil_buf.is_empty() {
336            if sigil_buf.len() == sigil.len {
337                sigil_buf.clear();
338                handle(token, output, input)?;
339                continue;
340            }
341            output.extend(sigil_buf.drain(..));
342        }
343
344        if let TokenTree::Group(group) = &token {
345            let mut group_output = TokenStream::new();
346            let mut input = ts_iter_fix(group.stream());
347            process(&mut group_output, &mut input, sigil, handle)?;
348            output.extend([TokenTree::Group(Group::new(
349                group.delimiter(),
350                group_output,
351            ))]);
352        } else {
353            output.extend([token]);
354        }
355    }
356
357    Ok(())
358}