derive_io_macros/
lib.rs

1//! Support macros for `derive-io`. This is not intended to be used directly and
2//! has no stable API.
3
4use std::collections::HashSet;
5
6use proc_macro::*;
7
8/// `#[derive(Read)]`
9#[proc_macro_derive(Read, attributes(read))]
10pub fn derive_io_read(input: TokenStream) -> TokenStream {
11    generate("derive_io", "derive_io_read", input)
12}
13
14/// `#[derive(Write)]`
15#[proc_macro_derive(Write, attributes(write))]
16pub fn derive_io_write(input: TokenStream) -> TokenStream {
17    generate("derive_io", "derive_io_write", input)
18}
19
20/// `#[derive(AsyncRead)]`: tokio::io::AsyncRead
21#[proc_macro_derive(AsyncRead, attributes(read))]
22pub fn derive_io_async_read(input: TokenStream) -> TokenStream {
23    generate("derive_io", "derive_io_async_read", input)
24}
25
26/// `#[derive(AsyncWrite)]`: tokio::io::AsyncWrite
27#[proc_macro_derive(AsyncWrite, attributes(write))]
28pub fn derive_io_async_write(input: TokenStream) -> TokenStream {
29    generate("derive_io", "derive_io_async_write", input)
30}
31
32/// `#[derive(AsFileDescriptor)]`: `std::os::{AsFd, AsRawFd}`, `std::os::windows::io::{AsHandle, AsRawHandle}`
33#[proc_macro_derive(AsFileDescriptor, attributes(descriptor))]
34pub fn derive_io_as_file_descriptor(input: TokenStream) -> TokenStream {
35    generate("derive_io", "derive_io_as_file_descriptor", input)
36}
37
38/// `#[derive(AsSocketDescriptor)]`: `std::os::{AsFd, AsRawFd}`, `std::os::{AsSocket, AsRawSocket}`
39#[proc_macro_derive(AsSocketDescriptor, attributes(descriptor))]
40pub fn derive_io_as_socket_descriptor(input: TokenStream) -> TokenStream {
41    generate("derive_io", "derive_io_as_socket_descriptor", input)
42}
43
/// Shared implementation of every `#[derive(...)]` entry point in this crate.
///
/// Splits the derive input into three parenthesised groups and hands them to
/// a `macro_rules!` parser in the `__support` module of `macro_crate`. It
/// generates the equivalent of, for example:
///
/// ```text
/// ::derive_io::__support::derive_io_read_parse!((item) (generics) (where_clause));
/// ```
///
/// - `item`: the input with its generic parameter list, the `where` keyword
///   and the `where` predicates removed. Attributes, visibility, name, body
///   and a tuple struct's trailing `;` pass through unchanged.
/// - `generics`: each generic parameter's name (`T`, `'a`) in its own
///   parenthesised group, comma-separated: `(T), ('a)`. Bounds and defaults
///   are stripped.
/// - `where_clause`: the inline bounds from the parameter list rewritten as
///   predicates (`T: Read`), followed by the item's own `where` predicates,
///   all comma-separated.
///
/// # Panics
///
/// Panics on const generic parameters (not yet supported) or on an
/// unterminated generic parameter list.
// NOTE(review): the allow presumably silences the `tail_expr_drop_order`
// edition-migration lint; `unknown_lints` keeps toolchains without it quiet.
#[allow(unknown_lints, tail_expr_drop_order)]
fn generate(macro_crate: &str, macro_type: &str, item: TokenStream) -> TokenStream {
    let mut new_item = TokenStream::new();
    let mut generics = TokenStream::new();
    let mut predicates: Vec<TokenStream> = Vec::new();
    let mut explicit_where = TokenStream::new();
    let mut in_where_clause = false;
    let mut iterator = item.into_iter();

    while let Some(token) = iterator.next() {
        match token {
            TokenTree::Punct(ref p) if !in_where_clause && p.as_char() == '<' => {
                split_generics(&mut iterator, &mut generics, &mut predicates);
            }
            TokenTree::Ident(ref ident) if ident.to_string() == "where" => {
                in_where_clause = true;
            }
            TokenTree::Group(ref group) if group.delimiter() == Delimiter::Brace => {
                new_item.extend([token]);
                break;
            }
            // A tuple struct's `where` clause is terminated by `;`, which
            // belongs to the item rather than to the clause.
            TokenTree::Punct(ref p) if in_where_clause && p.as_char() == ';' => {
                in_where_clause = false;
                new_item.extend([token]);
            }
            _ if in_where_clause => explicit_where.extend([token]),
            _ => new_item.extend([token]),
        }
    }
    new_item.extend(iterator);

    // Join inline-bound predicates and the explicit `where` predicates with
    // commas so that e.g. `<T: Read> where T: Write` yields `T: Read, T: Write`.
    let comma = || TokenTree::Punct(Punct::new(',', Spacing::Alone));
    let mut where_clause = TokenStream::new();
    for predicate in predicates {
        if !where_clause.is_empty() {
            where_clause.extend([comma()]);
        }
        where_clause.extend(predicate);
    }
    if !explicit_where.is_empty() {
        if !where_clause.is_empty() {
            where_clause.extend([comma()]);
        }
        where_clause.extend(explicit_where);
    }

    let mut inner = TokenStream::new();
    inner.extend([
        TokenTree::Group(Group::new(Delimiter::Parenthesis, new_item)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, generics)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, where_clause)),
    ]);

    // `::`
    let path_separator = || {
        [
            TokenTree::Punct(Punct::new(':', Spacing::Joint)),
            TokenTree::Punct(Punct::new(':', Spacing::Alone)),
        ]
    };

    let mut invoke = TokenStream::new();
    invoke.extend(path_separator());
    invoke.extend([TokenTree::Ident(Ident::new(macro_crate, Span::call_site()))]);
    invoke.extend(path_separator());
    invoke.extend([TokenTree::Ident(Ident::new("__support", Span::call_site()))]);
    invoke.extend(path_separator());
    invoke.extend([
        TokenTree::Ident(Ident::new(
            &format!("{}_parse", macro_type),
            Span::call_site(),
        )),
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, inner)),
        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
    ]);

    invoke
}

/// Consumes one generic parameter list from `iterator`, up to and including
/// its closing `>`. The opening `<` must already have been consumed.
///
/// Each parameter's name is appended to `generics` as a parenthesised group
/// (comma-separated). Its inline bounds, if any, are pushed to `predicates`
/// as a `Name: Bounds` predicate. Defaults (`T = u8`) are dropped.
///
/// Angle brackets nested inside bounds or defaults (`T: Into<Vec<u8>>`,
/// `I: Iterator<Item = u8>`) and the `->` of `Fn() -> R` bounds are tracked,
/// so they do not end the list early.
fn split_generics(
    iterator: &mut impl Iterator<Item = TokenTree>,
    generics: &mut TokenStream,
    predicates: &mut Vec<TokenStream>,
) {
    // Name tokens of the current parameter (`T`, or `'` `a`).
    let mut name = TokenStream::new();
    // `Some` once the parameter's `:` has been seen; holds `: Bounds...`.
    let mut bounds: Option<TokenStream> = None;
    // Whether we are inside the current parameter's `= Default`.
    let mut in_default = false;
    // Depth of `<...>` nesting inside the current parameter.
    let mut depth = 0usize;
    // Set when the previous token was a joint `-`; a following `>` is then
    // the tail of `->`, not a closing bracket.
    let mut after_joint_minus = false;

    loop {
        let token = iterator
            .next()
            .expect("Unterminated generic parameter list");
        let is_arrow_tail = std::mem::replace(&mut after_joint_minus, false);

        let punct = match &token {
            TokenTree::Punct(p) => Some((p.as_char(), p.spacing())),
            _ => None,
        };
        match punct {
            Some(('-', Spacing::Joint)) => after_joint_minus = true,
            Some(('<', _)) => depth += 1,
            Some(('>', _)) if !is_arrow_tail => {
                if depth == 0 {
                    finish_generic_param(
                        std::mem::take(&mut name),
                        bounds.take(),
                        generics,
                        predicates,
                    );
                    return;
                }
                depth -= 1;
            }
            Some((',', _)) if depth == 0 => {
                finish_generic_param(
                    std::mem::take(&mut name),
                    bounds.take(),
                    generics,
                    predicates,
                );
                in_default = false;
                continue;
            }
            // The first top-level `:` starts the bounds; later ones (`::`
            // in paths) are part of the bounds themselves.
            Some((':', _)) if depth == 0 && !in_default && bounds.is_none() => {
                bounds = Some(TokenStream::from_iter([token]));
                continue;
            }
            Some(('=', _)) if depth == 0 && !in_default => {
                in_default = true;
                continue;
            }
            _ => {}
        }

        if let TokenTree::Ident(ref ident) = token {
            let at_param_start = depth == 0 && name.is_empty() && bounds.is_none() && !in_default;
            if at_param_start && ident.to_string() == "const" {
                panic!("const not yet supported");
            }
        }

        if in_default {
            // Defaults are irrelevant to the generated impl.
        } else if let Some(bounds) = bounds.as_mut() {
            bounds.extend([token]);
        } else {
            name.extend([token]);
        }
    }
}

/// Records one generic parameter parsed by `split_generics`. An empty
/// parameter (e.g. after a trailing comma) is ignored.
fn finish_generic_param(
    name: TokenStream,
    bounds: Option<TokenStream>,
    generics: &mut TokenStream,
    predicates: &mut Vec<TokenStream>,
) {
    if name.is_empty() {
        return;
    }
    if let Some(bounds) = bounds {
        let mut predicate = name.clone();
        predicate.extend(bounds);
        predicates.push(predicate);
    }
    if !generics.is_empty() {
        generics.extend([TokenTree::Punct(Punct::new(',', Spacing::Alone))]);
    }
    generics.extend([TokenTree::Group(Group::new(Delimiter::Parenthesis, name))]);
}
200
/// Takes the next token, of any kind, from `iterator`.
///
/// # Panics
///
/// Panics if the stream is exhausted; `named` describes the expected token in
/// the message.
fn expect_any(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> TokenTree {
    match iterator.next() {
        Some(token) => token,
        None => panic!("Expected {} token, got end of stream", named),
    }
}
208
/// Takes the next token from `iterator` and requires it to be a group (of any
/// delimiter).
///
/// # Panics
///
/// Panics, naming `named` and the token actually found, if the next token is
/// missing or is not a group.
fn expect_group(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> Group {
    match iterator.next() {
        Some(TokenTree::Group(group)) => group,
        other => panic!("Expected {} group, got {:?}", named, other),
    }
}
216
/// Takes the next token from `iterator` and requires it to be an identifier.
///
/// # Panics
///
/// Panics, naming `named` and the token actually found, if the next token is
/// missing or is not an identifier.
fn expect_ident(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> Ident {
    match iterator.next() {
        Some(TokenTree::Ident(ident)) => ident,
        other => panic!("Expected {} ident, got {:?}", named, other),
    }
}
224
/// Takes the next token from `iterator` and requires it to be a literal.
///
/// Invisible (`Delimiter::None`) groups are looked through recursively. Such
/// groups appear, for example, around a forwarded `macro_rules!` variable.
/// Only the first token inside such a group is examined.
///
/// # Panics
///
/// Panics, naming `named` and the token actually found, if no literal is
/// found.
fn expect_literal(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> Literal {
    match iterator.next() {
        Some(TokenTree::Literal(literal)) => literal,
        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::None => {
            let mut inner = group.stream().into_iter();
            expect_literal(named, &mut inner)
        }
        other => panic!("Expected {} literal, got {:?}", named, other),
    }
}
238
/// Takes the next token from `iterator` and requires it to be the punctuation
/// character `named`.
///
/// # Panics
///
/// Panics if the next token is missing, is not punctuation, or is a different
/// punctuation character.
fn expect_punct(named: char, iterator: &mut impl Iterator<Item = TokenTree>) -> Punct {
    match iterator.next() {
        Some(TokenTree::Punct(punct)) if punct.as_char() == named => punct,
        Some(TokenTree::Punct(punct)) => panic!("Expected {} punct, got {:?}", named, punct),
        other => panic!("Expected {} punct, got {:?}", named, other),
    }
}
249
/// Unwraps nested groups around an attribute until reaching the group whose
/// first token is an identifier. Returns that group's contents re-wrapped in
/// brackets (`[ident ...]`).
///
/// # Panics
///
/// Panics if a group on the way is empty, or if a non-group, non-identifier
/// token is reached; `named` describes the expected element.
fn expect_is_meta(named: &str, attr: TokenTree) -> Group {
    let outer = attr.clone();
    let mut current = attr;
    loop {
        let TokenTree::Group(group) = current else {
            panic!("Expected meta group {named}, got {outer}");
        };
        let contents = group.stream();
        let first = contents
            .clone()
            .into_iter()
            .next()
            .expect("Expected attr group to have one element");
        if matches!(first, TokenTree::Ident(_)) {
            return Group::new(Delimiter::Bracket, contents);
        }
        current = first;
    }
}
265
266/// [
267///   (__next__) (args) expected_attr {on_error}
268///   ((attr attr) (item)) ((attr attr) (item))
269/// ] -> __next__!((args) (item))
270#[proc_macro]
271pub fn find_annotated(input: TokenStream) -> TokenStream {
272    let mut iterator = input.into_iter();
273
274    let next_macro = expect_group("__next__ macro", &mut iterator);
275    let args = expect_group("__next__ arguments", &mut iterator);
276    let expected_attr = expect_ident("expected_attr", &mut iterator);
277    let on_error = expect_group("on_error", &mut iterator);
278
279    while let Some(token) = iterator.next() {
280        let TokenTree::Group(check) = token else {
281            panic!("Expected check group");
282        };
283        let mut iter = check.stream().into_iter();
284        let attrs = expect_group("attrs", &mut iter);
285        let item = expect_any("item", &mut iter);
286        let mut index = 0;
287        for attr in attrs.stream().into_iter() {
288            let attr = expect_is_meta("attr", attr);
289            let first = expect_ident("first attr", &mut attr.clone().stream().into_iter());
290            if first.to_string() == expected_attr.to_string() {
291                let mut next = next_macro.stream();
292                next.extend([
293                    TokenTree::Punct(Punct::new('!', Spacing::Alone)),
294                    TokenTree::Group(Group::new(
295                        Delimiter::Parenthesis,
296                        TokenStream::from_iter([
297                            TokenTree::Group(args),
298                            TokenTree::Literal(Literal::usize_unsuffixed(index)),
299                            TokenTree::Group(attr),
300                            item,
301                        ]),
302                    )),
303                    TokenTree::Punct(Punct::new(';', Spacing::Alone)),
304                ]);
305                return next;
306            }
307            index += 1;
308        }
309    }
310
311    on_error.stream()
312}
313
314/// [(__next__) (args) expected_attr {on_error}
315///  ( (id) (([attr] [attr]) (item)) (([attr] [attr]) (item)) )
316///  ( (id) (([attr] [attr]) (item)) (([attr] [attr]) (item)) )
317/// ] -> __next__!((args) ((id) (item)))
318#[proc_macro]
319pub fn find_annotated_multi(input: TokenStream) -> TokenStream {
320    let mut iterator = input.into_iter();
321
322    let next_macro = expect_group("__next__ macro", &mut iterator);
323    let args = expect_group("__next__ arguments", &mut iterator);
324    let expected_attr = expect_ident("expected_attr", &mut iterator);
325    let on_error = expect_group("on_error", &mut iterator);
326    let mut output = TokenStream::new();
327
328    'outer: while let Some(token) = iterator.next() {
329        let TokenTree::Group(id) = token else {
330            panic!("Expected id group");
331        };
332        let mut iter = id.stream().into_iter();
333        let id = expect_group("id", &mut iter);
334        let mut index = 0;
335        while let Some(token) = iter.next() {
336            let TokenTree::Group(check) = token else {
337                panic!("Expected check group");
338            };
339            let mut iter = check.stream().into_iter();
340            let attrs = expect_group("attrs", &mut iter);
341            let item = expect_any("item", &mut iter);
342            for attr in attrs.stream().into_iter() {
343                let attr = expect_is_meta("attr", attr);
344                let first = expect_ident("first attr", &mut attr.clone().stream().into_iter());
345                if first.to_string() == expected_attr.to_string() {
346                    output.extend([TokenTree::Group(Group::new(
347                        Delimiter::Parenthesis,
348                        TokenStream::from_iter([
349                            TokenTree::Group(id.clone()),
350                            TokenTree::Literal(Literal::usize_unsuffixed(index)),
351                            TokenTree::Group(attr),
352                            item.clone(),
353                        ]),
354                    ))]);
355                    continue 'outer;
356                }
357            }
358            index += 1;
359        }
360        return on_error.stream();
361    }
362
363    let mut next = next_macro.stream();
364    next.extend([
365        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
366        TokenTree::Group(Group::new(
367            Delimiter::Parenthesis,
368            TokenStream::from_iter([
369                TokenTree::Group(args),
370                TokenTree::Group(Group::new(Delimiter::Parenthesis, output)),
371            ]),
372        )),
373        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
374    ]);
375    next
376}
377
378/// [prefix count repeated suffix] -> prefix (repeated*count suffix)
379#[proc_macro]
380pub fn repeat_in_parenthesis(input: TokenStream) -> TokenStream {
381    let mut iterator = input.into_iter();
382    let prefix = expect_group("prefix", &mut iterator);
383    let count = expect_literal("count", &mut iterator);
384    let count =
385        usize::from_str_radix(&count.to_string(), 10).expect("Expected count to be a number");
386    let repeated = expect_group("repeated", &mut iterator);
387    let suffix = expect_group("suffix", &mut iterator);
388    let mut repeat = TokenStream::new();
389    for _ in 0..count {
390        repeat.extend(repeated.clone().stream());
391    }
392    repeat.extend(suffix.stream());
393    let mut output = TokenStream::new();
394    output.extend(prefix.stream());
395    output.extend([TokenTree::Group(Group::new(Delimiter::Parenthesis, repeat))]);
396    output
397}
398
399// needle haystack(key=value,key=value) default -> extracted OR default
400#[proc_macro]
401pub fn extract_meta(input: TokenStream) -> TokenStream {
402    let mut iterator = input.into_iter();
403    let needle = expect_ident("needle", &mut iterator);
404    let haystack = expect_group("haystack", &mut iterator);
405    let default = expect_group("default", &mut iterator);
406
407    let mut haystack = haystack.stream().into_iter();
408
409    loop {
410        let attr = haystack.next();
411        if let Some(TokenTree::Group(ref group)) = attr {
412            haystack = group.stream().into_iter();
413            continue;
414        }
415        break;
416    }
417
418    loop {
419        let key = haystack.next();
420        if let Some(TokenTree::Group(ref group)) = key {
421            haystack = group.stream().into_iter();
422            continue;
423        }
424        let Some(key) = key else {
425            break;
426        };
427        expect_punct('=', &mut haystack);
428        let value = expect_ident("value", &mut haystack);
429
430        if key.to_string() == needle.to_string() {
431            return TokenStream::from_iter([TokenTree::Ident(value)]);
432        }
433
434        if haystack.next().is_none() {
435            break;
436        }
437    }
438
439    default.stream()
440}
441
442#[proc_macro]
443pub fn type_has_generic(input: TokenStream) -> TokenStream {
444    let mut iterator = input.into_iter();
445
446    let type_ = expect_group("type", &mut iterator);
447    let generic = expect_group("generics", &mut iterator);
448    let if_true = expect_group("if_true", &mut iterator);
449    let if_false = expect_group("if_false", &mut iterator);
450
451    fn recursive_collect_generics(generics: &mut HashSet<String>, type_tokens: TokenStream) {
452        let mut iterator = type_tokens.into_iter();
453        while let Some(token) = iterator.next() {
454            if let TokenTree::Ident(ident) = &token {
455                generics.insert(ident.to_string());
456            } else if let TokenTree::Group(group) = token {
457                recursive_collect_generics(generics, group.stream());
458            }
459        }
460    }
461
462    let mut generics = HashSet::new();
463    recursive_collect_generics(&mut generics, generic.stream());
464
465    fn recursive_check_generics(generics: &HashSet<String>, type_tokens: TokenStream) -> bool {
466        let mut iterator = type_tokens.into_iter();
467        while let Some(token) = iterator.next() {
468            if let TokenTree::Ident(ident) = &token {
469                if generics.contains(&ident.to_string()) {
470                    return true;
471                }
472            } else if let TokenTree::Group(group) = token {
473                if recursive_check_generics(generics, group.stream()) {
474                    return true;
475                }
476            }
477        }
478        false
479    }
480
481    if recursive_check_generics(&generics, type_.stream()) {
482        if_true.stream()
483    } else {
484        if_false.stream()
485    }
486}