derive_io_macros/
lib.rs

//! Support macros for `derive-io`. This is not intended to be used directly and
//! has no stable API.
3
4use proc_macro::*;
5
6/// `#[derive(Read)]`
7#[proc_macro_derive(Read, attributes(read))]
8pub fn derive_io_read(input: TokenStream) -> TokenStream {
9    generate("derive_io", "derive_io_read", input)
10}
11
12/// `#[derive(Write)]`
13#[proc_macro_derive(Write, attributes(write))]
14pub fn derive_io_write(input: TokenStream) -> TokenStream {
15    generate("derive_io", "derive_io_write", input)
16}
17
18/// `#[derive(AsyncRead)]`
19#[proc_macro_derive(AsyncRead, attributes(read))]
20pub fn derive_io_async_read(input: TokenStream) -> TokenStream {
21    generate("derive_io", "derive_io_async_read", input)
22}
23
24/// `#[derive(AsyncWrite)]`
25#[proc_macro_derive(AsyncWrite, attributes(write))]
26pub fn derive_io_async_write(input: TokenStream) -> TokenStream {
27    generate("derive_io", "derive_io_async_write", input)
28}
29
/// Generates the equivalent of this Rust code as a TokenStream:
///
/// ```nocompile
/// ::derive_io::__support::derive_io_read_parse!((item) (generics) (where_clause));
/// ```
///
/// `macro_crate` is the support crate (e.g. `derive_io`) and `macro_type` the macro
/// prefix (e.g. `derive_io_read`); `_parse` is appended to form the macro name.
/// The three parenthesized groups are:
///
///  - `item`: the input item with its `<...>` parameter list and `where` clause removed;
///  - `generics`: the parameter names only (bounds and defaults stripped), comma-separated;
///  - `where_clause`: the bounds written inline in the parameter list, followed by the
///    item's own `where` predicates, comma-separated.
#[allow(unknown_lints, tail_expr_drop_order)]
fn generate(macro_crate: &str, macro_type: &str, item: TokenStream) -> TokenStream {
    let (new_item, generics, where_clause) = split_generics(item);

    let inner = TokenStream::from_iter([
        TokenTree::Group(Group::new(Delimiter::Parenthesis, new_item)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, generics)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, where_clause)),
    ]);

    // `::macro_crate::__support::{macro_type}_parse`
    let parse_macro = format!("{}_parse", macro_type);
    let mut invoke = TokenStream::new();
    for segment in [macro_crate, "__support", parse_macro.as_str()] {
        invoke.extend([
            TokenTree::Punct(Punct::new(':', Spacing::Joint)),
            TokenTree::Punct(Punct::new(':', Spacing::Alone)),
            TokenTree::Ident(Ident::new(segment, Span::call_site())),
        ]);
    }
    invoke.extend([
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, inner)),
        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
    ]);
    invoke
}

/// Splits a derive input into `(item, generics, where_clause)` as described on `generate`.
///
/// Only the top-level tokens before the item body are rewritten; the first
/// brace-delimited group and everything after it is copied into `item` unchanged.
/// Angle brackets inside the parameter list are depth-tracked (the `>` of `->` is not
/// counted), so only `:`, `,` and `=` at the outermost level are structural. For example
/// `<T: Into<Vec<u8>>, F: Fn() -> u8, U = u8>` yields generics `T, F, U` and where
/// clause `T: Into<Vec<u8>>, F: Fn() -> u8`.
///
/// NOTE: const generic parameters (`const N: usize`) are not specially handled.
#[allow(unknown_lints, tail_expr_drop_order)]
fn split_generics(item: TokenStream) -> (TokenStream, TokenStream, TokenStream) {
    let comma = || TokenTree::Punct(Punct::new(',', Spacing::Alone));
    let mut generics = TokenStream::new();
    let mut where_clause = TokenStream::new();
    let mut new_item = TokenStream::new();
    let mut iterator = item.into_iter();

    // Angle-bracket nesting depth; non-zero only inside the item's `<...>` list.
    let mut depth = 0usize;
    // True while accumulating the current parameter's name (before its `:` or `=`).
    let mut generics_ident = true;
    let mut generics_accum = TokenStream::new();
    // True while skipping a parameter default (`T = u8`); `impl<...>` rejects defaults.
    let mut in_default = false;
    let mut in_where_clause = false;
    // True until the first token after `where` has been copied, so it can be
    // separated from bounds that were lifted out of the parameter list.
    let mut where_predicate_start = false;
    // True when the previous token was a joint `-`, i.e. the current `>` belongs to `->`.
    let mut prev_joint_minus = false;

    for token in iterator.by_ref() {
        let arrow_tail = prev_joint_minus;
        prev_joint_minus = matches!(
            &token,
            TokenTree::Punct(p) if p.as_char() == '-' && p.spacing() == Spacing::Joint
        );

        if depth > 0 {
            let punct = match &token {
                TokenTree::Punct(p) => Some(p.as_char()),
                _ => None,
            };
            match punct {
                // Nested `<`: part of a bound or default, copied through below.
                Some('<') => depth += 1,
                Some('>') if !arrow_tail => {
                    depth -= 1;
                    if depth == 0 {
                        // End of the parameter list: flush a trailing unbounded parameter.
                        if generics_ident {
                            generics.extend(std::mem::take(&mut generics_accum));
                        }
                        continue;
                    }
                }
                Some(':') if depth == 1 && generics_ident => {
                    // `T: Bound` -> `T` in generics, `T: Bound` in the where clause.
                    generics_ident = false;
                    if !where_clause.is_empty() {
                        where_clause.extend([comma()]);
                    }
                    generics.extend(generics_accum.clone());
                    where_clause.extend(std::mem::take(&mut generics_accum));
                    where_clause.extend([token]);
                    continue;
                }
                Some(',') if depth == 1 => {
                    // Parameter separator: always kept in generics, even after a bound.
                    if generics_ident {
                        generics.extend(std::mem::take(&mut generics_accum));
                    }
                    generics.extend([token]);
                    generics_ident = true;
                    in_default = false;
                    continue;
                }
                Some('=') if depth == 1 && !in_default => {
                    if generics_ident {
                        generics.extend(std::mem::take(&mut generics_accum));
                        generics_ident = false;
                    }
                    in_default = true;
                    continue;
                }
                _ => {}
            }
            // Ordinary token of the current parameter's name, bound, or default.
            if in_default {
                // Dropped: defaults are not permitted in `impl<...>`.
            } else if generics_ident {
                generics_accum.extend([token]);
            } else {
                where_clause.extend([token]);
            }
            continue;
        }

        match token {
            TokenTree::Punct(ref p) if p.as_char() == '<' && !in_where_clause => {
                depth = 1;
                generics_ident = true;
                in_default = false;
            }
            TokenTree::Punct(ref p) if p.as_char() == ';' && in_where_clause => {
                // `struct Foo<T>(T) where T: Read;` - the `;` terminates the item.
                in_where_clause = false;
                new_item.extend([token]);
            }
            TokenTree::Ident(ref ident) if !in_where_clause && ident.to_string() == "where" => {
                in_where_clause = true;
                where_predicate_start = true;
            }
            TokenTree::Group(ref group) if group.delimiter() == Delimiter::Brace => {
                new_item.extend([token]);
                break;
            }
            _ if in_where_clause => {
                if where_predicate_start && !where_clause.is_empty() {
                    where_clause.extend([comma()]);
                }
                where_predicate_start = false;
                where_clause.extend([token]);
            }
            _ => new_item.extend([token]),
        }
    }
    new_item.extend(iterator);

    (new_item, generics, where_clause)
}
144
/// Takes the next token of any kind from `iterator`.
///
/// # Panics
///
/// Panics if the stream is exhausted; `named` describes the expected token in the message.
fn expect_any(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> TokenTree {
    match iterator.next() {
        Some(token) => token,
        None => panic!("Expected {} token, got end of stream", named),
    }
}
152
/// Takes the next token from `iterator`, which must be a group (any delimiter).
///
/// # Panics
///
/// Panics if the next token is missing or is not a group.
fn expect_group(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> Group {
    match iterator.next() {
        Some(TokenTree::Group(group)) => group,
        other => panic!("Expected {} group, got {:?}", named, other),
    }
}
160
/// Takes the next token from `iterator`, which must be an identifier.
///
/// # Panics
///
/// Panics if the next token is missing or is not an identifier.
fn expect_ident(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> Ident {
    match iterator.next() {
        Some(TokenTree::Ident(ident)) => ident,
        other => panic!("Expected {} ident, got {:?}", named, other),
    }
}
168
/// Takes the next token from `iterator`, which must be a literal.
///
/// A literal wrapped in an invisible (`Delimiter::None`) group is unwrapped
/// recursively; only the first token of such a group is inspected.
///
/// # Panics
///
/// Panics if no literal is found.
fn expect_literal(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> Literal {
    match iterator.next() {
        Some(TokenTree::Literal(literal)) => literal,
        // Presumably produced when a literal is forwarded through a `macro_rules!`
        // fragment; the group is transparent, so look inside it.
        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::None => {
            expect_literal(named, &mut group.stream().into_iter())
        }
        other => panic!("Expected {} literal, got {:?}", named, other),
    }
}
182
/// Takes the next token from `iterator`, which must be the punctuation character `named`.
///
/// # Panics
///
/// Panics if the next token is missing, is not punctuation, or is a different character.
fn expect_punct(named: char, iterator: &mut impl Iterator<Item = TokenTree>) -> Punct {
    let punct = match iterator.next() {
        Some(TokenTree::Punct(punct)) => punct,
        other => panic!("Expected {} punct, got {:?}", named, other),
    };
    if punct.as_char() == named {
        punct
    } else {
        panic!("Expected {} punct, got {:?}", named, punct)
    }
}
193
/// Unwraps a grouped meta element to its final group.
///
/// Descends through the first token of each nested group until it reaches a group
/// whose first token is an identifier, and returns that group's contents re-wrapped
/// in square brackets.
///
/// # Panics
///
/// Panics if a group on the way is empty, or if a token that is neither a group nor
/// (as first element) an identifier is reached.
fn expect_is_meta(named: &str, attr: TokenTree) -> Group {
    let mut current = attr.clone();
    loop {
        let TokenTree::Group(group) = current else {
            panic!("Expected meta group {named}, got {attr}");
        };
        let contents = group.stream();
        let first = contents
            .clone()
            .into_iter()
            .next()
            .expect("Expected attr group to have one element");
        match first {
            TokenTree::Ident(_) => return Group::new(Delimiter::Bracket, contents),
            nested => current = nested,
        }
    }
}
209
210/// [
211///   (__next__) (args) expected_attr {on_error}
212///   ((attr attr) (item)) ((attr attr) (item))
213/// ] -> __next__!((args) (item))
214#[proc_macro]
215pub fn find_annotated(input: TokenStream) -> TokenStream {
216    let mut iterator = input.into_iter();
217
218    let next_macro = expect_group("__next__ macro", &mut iterator);
219    let args = expect_group("__next__ arguments", &mut iterator);
220    let expected_attr = expect_ident("expected_attr", &mut iterator);
221    let on_error = expect_group("on_error", &mut iterator);
222
223    while let Some(token) = iterator.next() {
224        let TokenTree::Group(check) = token else {
225            panic!("Expected check group");
226        };
227        let mut iter = check.stream().into_iter();
228        let attrs = expect_group("attrs", &mut iter);
229        let item = expect_any("item", &mut iter);
230        let mut index = 0;
231        for attr in attrs.stream().into_iter() {
232            let attr = expect_is_meta("attr", attr);
233            let first = expect_ident("first attr", &mut attr.clone().stream().into_iter());
234            if first.to_string() == expected_attr.to_string() {
235                let mut next = next_macro.stream();
236                next.extend([
237                    TokenTree::Punct(Punct::new('!', Spacing::Alone)),
238                    TokenTree::Group(Group::new(
239                        Delimiter::Parenthesis,
240                        TokenStream::from_iter([
241                            TokenTree::Group(args),
242                            TokenTree::Literal(Literal::usize_unsuffixed(index)),
243                            TokenTree::Group(attr),
244                            item,
245                        ]),
246                    )),
247                    TokenTree::Punct(Punct::new(';', Spacing::Alone)),
248                ]);
249                return next;
250            }
251            index += 1;
252        }
253    }
254
255    on_error.stream()
256}
257
258/// [(__next__) (args) expected_attr {on_error}
259///  ( (id) (([attr] [attr]) (item)) (([attr] [attr]) (item)) )
260///  ( (id) (([attr] [attr]) (item)) (([attr] [attr]) (item)) )
261/// ] -> __next__!((args) ((id) (item)))
262#[proc_macro]
263pub fn find_annotated_multi(input: TokenStream) -> TokenStream {
264    let mut iterator = input.into_iter();
265
266    let next_macro = expect_group("__next__ macro", &mut iterator);
267    let args = expect_group("__next__ arguments", &mut iterator);
268    let expected_attr = expect_ident("expected_attr", &mut iterator);
269    let on_error = expect_group("on_error", &mut iterator);
270    let mut output = TokenStream::new();
271
272    'outer: while let Some(token) = iterator.next() {
273        let TokenTree::Group(id) = token else {
274            panic!("Expected id group");
275        };
276        let mut iter = id.stream().into_iter();
277        let id = expect_group("id", &mut iter);
278        let mut index = 0;
279        while let Some(token) = iter.next() {
280            let TokenTree::Group(check) = token else {
281                panic!("Expected check group");
282            };
283            let mut iter = check.stream().into_iter();
284            let attrs = expect_group("attrs", &mut iter);
285            let item = expect_any("item", &mut iter);
286            for attr in attrs.stream().into_iter() {
287                let attr = expect_is_meta("attr", attr);
288                let first = expect_ident("first attr", &mut attr.clone().stream().into_iter());
289                if first.to_string() == expected_attr.to_string() {
290                    output.extend([TokenTree::Group(Group::new(
291                        Delimiter::Parenthesis,
292                        TokenStream::from_iter([
293                            TokenTree::Group(id.clone()),
294                            TokenTree::Literal(Literal::usize_unsuffixed(index)),
295                            TokenTree::Group(attr),
296                            item.clone(),
297                        ]),
298                    ))]);
299                    continue 'outer;
300                }
301            }
302            index += 1;
303        }
304        return on_error.stream();
305    }
306
307    let mut next = next_macro.stream();
308    next.extend([
309        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
310        TokenTree::Group(Group::new(
311            Delimiter::Parenthesis,
312            TokenStream::from_iter([
313                TokenTree::Group(args),
314                TokenTree::Group(Group::new(Delimiter::Parenthesis, output)),
315            ]),
316        )),
317        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
318    ]);
319    next
320}
321
322/// [prefix count repeated suffix] -> prefix (repeated*count suffix)
323#[proc_macro]
324pub fn repeat_in_parenthesis(input: TokenStream) -> TokenStream {
325    let mut iterator = input.into_iter();
326    let prefix = expect_group("prefix", &mut iterator);
327    let count = expect_literal("count", &mut iterator);
328    let count =
329        usize::from_str_radix(&count.to_string(), 10).expect("Expected count to be a number");
330    let repeated = expect_group("repeated", &mut iterator);
331    let suffix = expect_group("suffix", &mut iterator);
332    let mut repeat = TokenStream::new();
333    for _ in 0..count {
334        repeat.extend(repeated.clone().stream());
335    }
336    repeat.extend(suffix.stream());
337    let mut output = TokenStream::new();
338    output.extend(prefix.stream());
339    output.extend([TokenTree::Group(Group::new(Delimiter::Parenthesis, repeat))]);
340    output
341}
342
343// needle haystack(key=value,key=value) default -> extracted OR default
344#[proc_macro]
345pub fn extract_meta(input: TokenStream) -> TokenStream {
346    let mut iterator = input.into_iter();
347    let needle = expect_ident("needle", &mut iterator);
348    let haystack = expect_group("haystack", &mut iterator);
349    let default = expect_group("default", &mut iterator);
350
351    let mut haystack = haystack.stream().into_iter();
352
353    loop {
354        let attr = haystack.next();
355        if let Some(TokenTree::Group(ref group)) = attr {
356            haystack = group.stream().into_iter();
357            continue;
358        }
359        break;
360    }
361
362    loop {
363        let key = haystack.next();
364        if let Some(TokenTree::Group(ref group)) = key {
365            haystack = group.stream().into_iter();
366            continue;
367        }
368        let Some(key) = key else {
369            break;
370        };
371        expect_punct('=', &mut haystack);
372        let value = expect_ident("value", &mut haystack);
373
374        if key.to_string() == needle.to_string() {
375            return TokenStream::from_iter([TokenTree::Ident(value)]);
376        }
377
378        if haystack.next().is_none() {
379            break;
380        }
381    }
382
383    default.stream()
384}