derive_io_macros/
lib.rs

//! Support macros for `derive-io`. This crate is not intended to be used directly
//! and has no stable API.
3
4use std::collections::HashSet;
5
6use proc_macro::*;
7
8/// `#[derive(Read)]`
9///
10/// Derives `std::io::Read` for the given struct.
11///
12/// Supported attributes:
13///
14/// - `#[read]`: Marks the field as a read stream.
15/// - `#[read(as_ref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
16/// - `#[read(deref)]`: Delegates the field to the inner type using `Deref`/`DerefMut`.
17/// - `#[read(<function>=<override>)]`: Overrides the default `<function>` method with the given override function.
18#[proc_macro_derive(Read, attributes(read))]
19pub fn derive_io_read(input: TokenStream) -> TokenStream {
20    generate("derive_io", "derive_io_read", input)
21}
22
23/// `#[derive(Write)]`
24///
25/// Derives `std::io::Write` for the given struct.
26///
27/// Supported attributes:
28///
29/// - `#[write]`: Marks the field as a write stream.
30/// - `#[write(as_ref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
31/// - `#[write(deref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
32/// - `#[write(<function>=<override>)]`: Overrides the default `<function>` method with the given override function.
33#[proc_macro_derive(Write, attributes(write))]
34pub fn derive_io_write(input: TokenStream) -> TokenStream {
35    generate("derive_io", "derive_io_write", input)
36}
37
38/// `#[derive(AsyncRead)]`:
39///
40/// Derives `tokio::io::AsyncRead` for the given struct.
41///
42/// Supported attributes:
43///
44/// - `#[read]`: Marks the field as a read stream.
45/// - `#[read(as_ref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
46/// - `#[read(deref)]`: Delegates the field to the inner type using `Deref`/`DerefMut`.
47/// - `#[read(<function>=<override>)]`: Overrides the default `<function>` method with the given override function.
48#[proc_macro_derive(AsyncRead, attributes(read))]
49pub fn derive_io_async_read(input: TokenStream) -> TokenStream {
50    generate("derive_io", "derive_io_async_read", input)
51}
52
53/// `#[derive(AsyncWrite)]`:
54///
55/// Derives `tokio::io::AsyncWrite` for the given struct.
56///
57/// Supported attributes:
58///
59/// - `#[write]`: Marks the field as a write stream.
60/// - `#[write(as_ref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
61/// - `#[write(deref)]`: Delegates the field to the inner type using `Deref`/`DerefMut`.
62/// - `#[write(<function>=<override>)]`: Overrides the default `<function>` method with the given override function.
63#[proc_macro_derive(AsyncWrite, attributes(write))]
64pub fn derive_io_async_write(input: TokenStream) -> TokenStream {
65    generate("derive_io", "derive_io_async_write", input)
66}
67
68/// `#[derive(AsFileDescriptor)]`
69///
70/// Derives `std::os::fd::{AsFd, AsRawFd}` and `std::os::windows::io::{AsHandle, AsRawHandle}` for the given struct.
71///
72/// Supported attributes:
73///
74/// - `#[descriptor]`: Marks the field as a file descriptor.
75/// - `#[descriptor(as_ref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
76/// - `#[descriptor(deref)]`: Delegates the field to the inner type using `Deref`/`DerefMut`.
77/// - `#[descriptor(<function>=<override>)]`: Overrides the default `<function>` method with the given override function.
78#[proc_macro_derive(AsFileDescriptor, attributes(descriptor))]
79pub fn derive_io_as_file_descriptor(input: TokenStream) -> TokenStream {
80    generate("derive_io", "derive_io_as_file_descriptor", input)
81}
82
83/// `#[derive(AsSocketDescriptor)]`
84///
85/// Derives `std::os::fd::{AsFd, AsRawFd}` and `std::os::windows::io::{AsSocket, AsRawSocket}` for the given struct.
86///
87/// Supported attributes:
88///
89/// - `#[descriptor]`: Marks the field as a socket descriptor.
90/// - `#[descriptor(as_ref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
91/// - `#[descriptor(deref)]`: Delegates the field to the inner type using `Deref`/`DerefMut`.
92/// - `#[descriptor(<function>=<override>)]`: Overrides the default `<function>` method with the given override function.
93#[proc_macro_derive(AsSocketDescriptor, attributes(descriptor))]
94pub fn derive_io_as_socket_descriptor(input: TokenStream) -> TokenStream {
95    generate("derive_io", "derive_io_as_socket_descriptor", input)
96}
97
/// Splits a `#[derive]` input into pieces that are easy for `macro_rules!` to
/// match, and forwards them to the support parser of `macro_crate`:
///
/// ```nocompile
/// ::<macro_crate>::__support::<macro_type>_parse!((item) (generics) (where_clause));
/// ```
///
/// - `item`: the input with its generic parameter list (`<...>`) and `where`
///   clause removed. Attributes, visibility, keyword, name, tuple fields, a
///   trailing `;` and the `{ ... }` body are kept in their original order.
/// - `generics`: each generic parameter name (lifetime or type) in its own
///   parenthesised group, comma-separated, with bounds and defaults stripped.
///   For example, `<'a, T: Read = File>` becomes `('a) , (T)`.
/// - `where_clause`: the bounds from the parameter list rewritten as predicates
///   (`T : Read`), followed by the predicates of an explicit `where` clause. The
///   predicates are comma-separated and the `where` keyword is omitted.
///   Parameters without bounds contribute nothing.
///
/// Angle brackets nested inside bounds or defaults (`T: AsRef<[u8]>`,
/// `F: Fn() -> u8`, `I: Iterator<Item = u8>`) are tracked, so their `<`, `>`,
/// `,`, `:` and `=` tokens do not end the parameter list early.
///
/// # Panics
///
/// Panics if the item header contains `const`, because const generics are not
/// yet supported.
// NOTE(review): presumably silences the Rust 2024 `tail_expr_drop_order`
// migration lint; `unknown_lints` keeps toolchains without that lint quiet.
#[allow(unknown_lints, tail_expr_drop_order)]
fn generate(macro_crate: &str, macro_type: &str, item: TokenStream) -> TokenStream {
    /// Role of the tokens of the generic parameter currently being parsed.
    #[derive(Clone, Copy)]
    enum Param {
        /// The parameter name, including a lifetime tick or attributes.
        Name,
        /// Bounds after `:`; copied into the where clause.
        Bounds,
        /// A default after `=`; dropped. `bounded` records whether bounds came
        /// first, i.e. whether a where-clause predicate still needs its `,`.
        Default { bounded: bool },
    }

    /// True if `token` is the punctuation character `ch`.
    fn is_punct(token: &TokenTree, ch: char) -> bool {
        matches!(token, TokenTree::Punct(p) if p.as_char() == ch)
    }

    /// Wraps `tokens` in a `( ... )` group.
    fn parenthesized(tokens: Vec<TokenTree>) -> TokenTree {
        TokenTree::Group(Group::new(
            Delimiter::Parenthesis,
            tokens.into_iter().collect(),
        ))
    }

    /// The `::` path separator.
    fn path_sep() -> [TokenTree; 2] {
        [
            TokenTree::Punct(Punct::new(':', Spacing::Joint)),
            TokenTree::Punct(Punct::new(':', Spacing::Alone)),
        ]
    }

    let mut new_item = TokenStream::new();
    let mut generics: Vec<TokenTree> = Vec::new();
    let mut where_clause: Vec<TokenTree> = Vec::new();
    let mut name: Vec<TokenTree> = Vec::new();

    let mut in_generics = false;
    let mut in_where_clause = false;
    let mut param = Param::Name;
    // Depth of `<...>` nested inside the parameter list.
    let mut depth = 0usize;
    // Set when the previous token was a joint `-`, so a following `>` is `->`.
    let mut prev_joint_dash = false;

    let mut iterator = item.into_iter();
    for token in iterator.by_ref() {
        let after_dash = prev_joint_dash;
        prev_joint_dash = matches!(
            &token,
            TokenTree::Punct(p) if p.as_char() == '-' && p.spacing() == Spacing::Joint
        );

        if matches!(&token, TokenTree::Ident(ident) if ident.to_string() == "const") {
            panic!("const not yet supported");
        }
        if matches!(&token, TokenTree::Group(group) if group.delimiter() == Delimiter::Brace) {
            // The braced body ends the header; the rest is copied verbatim below.
            new_item.extend([token]);
            break;
        }

        if in_generics {
            let closes_angle = is_punct(&token, '>') && !after_dash;
            if depth == 0 {
                if closes_angle {
                    // End of the parameter list. An unflushed name is emitted even
                    // when empty (e.g. after a trailing comma), as before.
                    if let Param::Name = param {
                        generics.push(parenthesized(std::mem::take(&mut name)));
                    }
                    in_generics = false;
                    continue;
                }
                if is_punct(&token, ',') {
                    // End of one parameter; a bounded one also ends a predicate.
                    match param {
                        Param::Name => generics.push(parenthesized(std::mem::take(&mut name))),
                        Param::Bounds | Param::Default { bounded: true } => {
                            where_clause.push(token.clone());
                        }
                        Param::Default { bounded: false } => {}
                    }
                    generics.push(token);
                    param = Param::Name;
                    continue;
                }
                if is_punct(&token, ':') && matches!(param, Param::Name) {
                    // `T: Bound` -> emit `(T)` and start the predicate `T :`.
                    generics.push(parenthesized(name.clone()));
                    where_clause.append(&mut name);
                    where_clause.push(token);
                    param = Param::Bounds;
                    continue;
                }
                if is_punct(&token, '=') {
                    // Start of a default; flush the name if there were no bounds.
                    param = match param {
                        Param::Name => {
                            generics.push(parenthesized(std::mem::take(&mut name)));
                            Param::Default { bounded: false }
                        }
                        Param::Bounds => Param::Default { bounded: true },
                        other => other,
                    };
                    continue;
                }
            }

            // Depth only reaches here positive when closing (depth 0 `>` returned above).
            if is_punct(&token, '<') {
                depth += 1;
            } else if closes_angle {
                depth -= 1;
            }
            match param {
                Param::Name => name.push(token),
                Param::Bounds => where_clause.push(token),
                Param::Default { .. } => {}
            }
        } else if in_where_clause {
            if is_punct(&token, ';') {
                // `struct Foo<T>(T) where T: Read;` -- the `;` belongs to the item.
                in_where_clause = false;
                new_item.extend([token]);
            } else {
                where_clause.push(token);
            }
        } else if is_punct(&token, '<') {
            in_generics = true;
        } else if matches!(&token, TokenTree::Ident(ident) if ident.to_string() == "where") {
            in_where_clause = true;
            // Separate explicit predicates from those derived from the parameter
            // list, e.g. `<T: Read> where T: Debug` -> `T : Read , T : Debug`.
            if matches!(where_clause.last(), Some(last) if !is_punct(last, ',')) {
                where_clause.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
            }
        } else {
            new_item.extend([token]);
        }
    }
    new_item.extend(iterator);

    let inner = TokenStream::from_iter([
        TokenTree::Group(Group::new(Delimiter::Parenthesis, new_item)),
        parenthesized(generics),
        parenthesized(where_clause),
    ]);

    let mut invoke = TokenStream::new();
    invoke.extend(path_sep());
    invoke.extend([TokenTree::Ident(Ident::new(macro_crate, Span::call_site()))]);
    invoke.extend(path_sep());
    invoke.extend([TokenTree::Ident(Ident::new("__support", Span::call_site()))]);
    invoke.extend(path_sep());
    invoke.extend([
        TokenTree::Ident(Ident::new(
            &format!("{macro_type}_parse"),
            Span::call_site(),
        )),
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, inner)),
        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
    ]);
    invoke
}
254
/// Takes the next token from `iterator`, whatever its kind.
///
/// # Panics
///
/// Panics if the iterator is exhausted; `named` describes the expected token
/// in the message.
fn expect_any(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> TokenTree {
    match iterator.next() {
        Some(token) => token,
        None => panic!("Expected {} token, got end of stream", named),
    }
}
262
/// Takes the next token from `iterator`, which must be a delimited group.
///
/// # Panics
///
/// Panics if the iterator is exhausted or the next token is not a group;
/// `named` describes the expected group in the message.
fn expect_group(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> Group {
    match iterator.next() {
        Some(TokenTree::Group(group)) => group,
        other => panic!("Expected {} group, got {:?}", named, other),
    }
}
270
/// Takes the next token from `iterator`, which must be an identifier.
///
/// # Panics
///
/// Panics if the iterator is exhausted or the next token is not an identifier;
/// `named` describes the expected identifier in the message.
fn expect_ident(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> Ident {
    let candidate = iterator.next();
    if let Some(TokenTree::Ident(ident)) = candidate {
        ident
    } else {
        panic!("Expected {} ident, got {:?}", named, candidate)
    }
}
278
/// Takes the next token from `iterator`, which must be a literal.
///
/// A group with `Delimiter::None` is transparently unwrapped and its first
/// token is used instead. Such invisible groups are presumably produced when a
/// literal is passed through a `macro_rules!` fragment.
///
/// # Panics
///
/// Panics if no literal is found; `named` describes the expected literal in
/// the message.
fn expect_literal(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> Literal {
    match iterator.next() {
        Some(TokenTree::Literal(literal)) => literal,
        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::None => {
            expect_literal(named, &mut group.stream().into_iter())
        }
        other => panic!("Expected {} literal, got {:?}", named, other),
    }
}
292
/// Takes the next token from `iterator`, which must be the punctuation
/// character `named`.
///
/// # Panics
///
/// Panics if the iterator is exhausted, the next token is not punctuation, or
/// it is a different punctuation character.
fn expect_punct(named: char, iterator: &mut impl Iterator<Item = TokenTree>) -> Punct {
    match iterator.next() {
        Some(TokenTree::Punct(punct)) if punct.as_char() == named => punct,
        Some(TokenTree::Punct(punct)) => panic!("Expected {} punct, got {:?}", named, punct),
        other => panic!("Expected {} punct, got {:?}", named, other),
    }
}
303
/// Unwraps nested groups until one whose first token is an identifier, and
/// returns that group's contents re-wrapped in brackets (`[name args...]`).
///
/// # Panics
///
/// Panics if a group along the way is empty, or if a non-group token is
/// reached before an identifier-led group is found. `named` is used in the
/// latter message.
fn expect_is_meta(named: &str, attr: TokenTree) -> Group {
    let mut current = attr.clone();
    loop {
        let TokenTree::Group(group) = current else {
            panic!("Expected meta group {named}, got {attr}");
        };
        let head = group
            .stream()
            .into_iter()
            .next()
            .expect("Expected attr group to have one element");
        if matches!(head, TokenTree::Ident(_)) {
            return Group::new(Delimiter::Bracket, group.stream());
        }
        current = head;
    }
}
319
320/// [
321///   (__next__) (args) expected_attr {on_error}
322///   ((attr attr) (item)) ((attr attr) (item))
323/// ] -> __next__!((args) (item))
324#[proc_macro]
325pub fn find_annotated(input: TokenStream) -> TokenStream {
326    let mut iterator = input.into_iter();
327
328    let next_macro = expect_group("__next__ macro", &mut iterator);
329    let args = expect_group("__next__ arguments", &mut iterator);
330    let expected_attr = expect_ident("expected_attr", &mut iterator);
331    let on_error = expect_group("on_error", &mut iterator);
332
333    while let Some(token) = iterator.next() {
334        let TokenTree::Group(check) = token else {
335            panic!("Expected check group");
336        };
337        let mut iter = check.stream().into_iter();
338        let attrs = expect_group("attrs", &mut iter);
339        let item = expect_any("item", &mut iter);
340        let mut index = 0;
341        for attr in attrs.stream().into_iter() {
342            let attr = expect_is_meta("attr", attr);
343            let first = expect_ident("first attr", &mut attr.clone().stream().into_iter());
344            if first.to_string() == expected_attr.to_string() {
345                let mut next = next_macro.stream();
346                next.extend([
347                    TokenTree::Punct(Punct::new('!', Spacing::Alone)),
348                    TokenTree::Group(Group::new(
349                        Delimiter::Parenthesis,
350                        TokenStream::from_iter([
351                            TokenTree::Group(args),
352                            TokenTree::Literal(Literal::usize_unsuffixed(index)),
353                            TokenTree::Group(attr),
354                            item,
355                        ]),
356                    )),
357                    TokenTree::Punct(Punct::new(';', Spacing::Alone)),
358                ]);
359                return next;
360            }
361            index += 1;
362        }
363    }
364
365    on_error.stream()
366}
367
368/// [(__next__) (args) expected_attr {on_error}
369///  ( (id) (([attr] [attr]) (item)) (([attr] [attr]) (item)) )
370///  ( (id) (([attr] [attr]) (item)) (([attr] [attr]) (item)) )
371/// ] -> __next__!((args) ((id) (item)))
372#[proc_macro]
373pub fn find_annotated_multi(input: TokenStream) -> TokenStream {
374    let mut iterator = input.into_iter();
375
376    let next_macro = expect_group("__next__ macro", &mut iterator);
377    let args = expect_group("__next__ arguments", &mut iterator);
378    let expected_attr = expect_ident("expected_attr", &mut iterator);
379    let on_error = expect_group("on_error", &mut iterator);
380    let mut output = TokenStream::new();
381
382    'outer: while let Some(token) = iterator.next() {
383        let TokenTree::Group(id) = token else {
384            panic!("Expected id group");
385        };
386        let mut iter = id.stream().into_iter();
387        let id = expect_group("id", &mut iter);
388        let mut index = 0;
389        while let Some(token) = iter.next() {
390            let TokenTree::Group(check) = token else {
391                panic!("Expected check group");
392            };
393            let mut iter = check.stream().into_iter();
394            let attrs = expect_group("attrs", &mut iter);
395            let item = expect_any("item", &mut iter);
396            for attr in attrs.stream().into_iter() {
397                let attr = expect_is_meta("attr", attr);
398                let first = expect_ident("first attr", &mut attr.clone().stream().into_iter());
399                if first.to_string() == expected_attr.to_string() {
400                    output.extend([TokenTree::Group(Group::new(
401                        Delimiter::Parenthesis,
402                        TokenStream::from_iter([
403                            TokenTree::Group(id.clone()),
404                            TokenTree::Literal(Literal::usize_unsuffixed(index)),
405                            TokenTree::Group(attr),
406                            item.clone(),
407                        ]),
408                    ))]);
409                    continue 'outer;
410                }
411            }
412            index += 1;
413        }
414        return on_error.stream();
415    }
416
417    let mut next = next_macro.stream();
418    next.extend([
419        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
420        TokenTree::Group(Group::new(
421            Delimiter::Parenthesis,
422            TokenStream::from_iter([
423                TokenTree::Group(args),
424                TokenTree::Group(Group::new(Delimiter::Parenthesis, output)),
425            ]),
426        )),
427        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
428    ]);
429    next
430}
431
432/// [prefix count repeated suffix] -> prefix (repeated*count suffix)
433#[proc_macro]
434pub fn repeat_in_parenthesis(input: TokenStream) -> TokenStream {
435    let mut iterator = input.into_iter();
436    let prefix = expect_group("prefix", &mut iterator);
437    let count = expect_literal("count", &mut iterator);
438    let count =
439        usize::from_str_radix(&count.to_string(), 10).expect("Expected count to be a number");
440    let repeated = expect_group("repeated", &mut iterator);
441    let suffix = expect_group("suffix", &mut iterator);
442    let mut repeat = TokenStream::new();
443    for _ in 0..count {
444        repeat.extend(repeated.clone().stream());
445    }
446    repeat.extend(suffix.stream());
447    let mut output = TokenStream::new();
448    output.extend(prefix.stream());
449    output.extend([TokenTree::Group(Group::new(Delimiter::Parenthesis, repeat))]);
450    output
451}
452
453// needle haystack(key=value,key=value) default -> extracted OR default
454#[proc_macro]
455pub fn extract_meta(input: TokenStream) -> TokenStream {
456    let mut iterator = input.into_iter();
457    let needle = expect_ident("needle", &mut iterator);
458    let haystack = expect_group("haystack", &mut iterator);
459    let default = expect_group("default", &mut iterator);
460
461    let mut haystack = haystack.stream().into_iter();
462
463    loop {
464        let attr = haystack.next();
465        if let Some(TokenTree::Group(ref group)) = attr {
466            haystack = group.stream().into_iter();
467            continue;
468        }
469        break;
470    }
471
472    loop {
473        let key = haystack.next();
474        if let Some(TokenTree::Group(ref group)) = key {
475            haystack = group.stream().into_iter();
476            continue;
477        }
478        let Some(key) = key else {
479            break;
480        };
481        let Some(next) = haystack.next() else {
482            break;
483        };
484        // Ignore simple attributes
485        let TokenTree::Punct(punct) = next else {
486            panic!("Expected = after key, got {:?}", next);
487        };
488
489        if punct.as_char() == ',' {
490            continue;
491        }
492
493        let value = expect_ident("value", &mut haystack);
494
495        if key.to_string() == needle.to_string() {
496            return TokenStream::from_iter([TokenTree::Ident(value)]);
497        }
498
499        if haystack.next().is_none() {
500            break;
501        }
502    }
503
504    default.stream()
505}
506
507#[proc_macro]
508pub fn if_meta(input: TokenStream) -> TokenStream {
509    let mut iterator = input.into_iter();
510    let needle = expect_ident("needle", &mut iterator);
511    let haystack = expect_group("haystack", &mut iterator);
512    let if_true = expect_group("if_true", &mut iterator);
513    let if_false = expect_group("if_false", &mut iterator);
514
515    let mut haystack = haystack.stream().into_iter();
516
517    loop {
518        let attr = haystack.next();
519        if let Some(TokenTree::Group(ref group)) = attr {
520            haystack = group.stream().into_iter();
521            continue;
522        }
523        break;
524    }
525
526    loop {
527        let key = haystack.next();
528        if let Some(TokenTree::Group(ref group)) = key {
529            haystack = group.stream().into_iter();
530            continue;
531        }
532        let Some(key) = key else {
533            break;
534        };
535        if key.to_string() == needle.to_string() {
536            return if_true.stream();
537        }
538    }
539
540    if_false.stream()
541}
542
543#[proc_macro]
544pub fn type_has_generic(input: TokenStream) -> TokenStream {
545    let mut iterator = input.into_iter();
546
547    let type_ = expect_group("type", &mut iterator);
548    let generic = expect_group("generics", &mut iterator);
549    let if_true = expect_group("if_true", &mut iterator);
550    let if_false = expect_group("if_false", &mut iterator);
551
552    fn recursive_collect_generics(generics: &mut HashSet<String>, type_tokens: TokenStream) {
553        let mut iterator = type_tokens.into_iter();
554        while let Some(token) = iterator.next() {
555            if let TokenTree::Ident(ident) = &token {
556                generics.insert(ident.to_string());
557            } else if let TokenTree::Group(group) = token {
558                recursive_collect_generics(generics, group.stream());
559            }
560        }
561    }
562
563    let mut generics = HashSet::new();
564    recursive_collect_generics(&mut generics, generic.stream());
565
566    fn recursive_check_generics(generics: &HashSet<String>, type_tokens: TokenStream) -> bool {
567        let mut iterator = type_tokens.into_iter();
568        while let Some(token) = iterator.next() {
569            if let TokenTree::Ident(ident) = &token {
570                if generics.contains(&ident.to_string()) {
571                    return true;
572                }
573            } else if let TokenTree::Group(group) = token {
574                if recursive_check_generics(generics, group.stream()) {
575                    return true;
576                }
577            }
578        }
579        false
580    }
581
582    if recursive_check_generics(&generics, type_.stream()) {
583        if_true.stream()
584    } else {
585        if_false.stream()
586    }
587}