derive_io_macros/
lib.rs

1//! Support macros for `derive-io`. This is not intended to be used directly and
2//! has no stable API.
3
4use std::collections::HashSet;
5
6use proc_macro::*;
7
8/// `#[derive(Read)]`
9///
10/// Derives `std::io::Read` for the given struct.
11///
12/// Supported attributes:
13///
14/// - `#[read]`: Marks the field as a read stream.
15/// - `#[read(as_ref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
16/// - `#[read(deref)]`: Delegates the field to the inner type using `Deref`/`DerefMut`.
17/// - `#[read(<function>=<override>)]`: Overrides the default `<function>` method with the given override function.
18#[proc_macro_derive(Read, attributes(read, write, descriptor))]
19pub fn derive_io_read(input: TokenStream) -> TokenStream {
20    generate("derive_io", "derive_io_read", input)
21}
22
23/// `#[derive(BufRead)]`
24///
25/// Derives `std::io::BufRead` for the given struct. `std::io::Read` must also
26/// be implemented.
27///
28/// Unsupported methods:
29///
30/// - `split` (std-internal implementation)
31/// - `lines` (std-internal implementation)
32/// - `has_data_left` (unstable feature)
33///
34/// Supported attributes:
35///
36/// - `#[read]`: Marks the field as a read stream.
37/// - `#[read(as_ref)]`: Delegates the field to the inner type using
38///   `AsRef`/`AsMut`.
39/// - `#[read(deref)]`: Delegates the field to the inner type using
40///   `Deref`/`DerefMut`.
41/// - `#[read(<function>=<override>)]`: Overrides the default `<function>`
42///   method with the given override function.
43#[proc_macro_derive(BufRead, attributes(read, write, descriptor))]
44pub fn derive_io_bufread(input: TokenStream) -> TokenStream {
45    generate("derive_io", "derive_io_bufread", input)
46}
47
48/// `#[derive(Write)]`
49///
50/// Derives `std::io::Write` for the given struct.
51///
52/// Supported attributes:
53///
54/// - `#[write]`: Marks the field as a write stream.
55/// - `#[write(as_ref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
56/// - `#[write(deref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
57/// - `#[write(<function>=<override>)]`: Overrides the default `<function>` method with the given override function.
58#[proc_macro_derive(Write, attributes(read, write, descriptor))]
59pub fn derive_io_write(input: TokenStream) -> TokenStream {
60    generate("derive_io", "derive_io_write", input)
61}
62
63/// `#[derive(AsyncRead)]`:
64///
65/// Derives `tokio::io::AsyncRead` for the given struct.
66///
67/// Supported attributes:
68///
69/// - `#[read]`: Marks the field as a read stream.
70/// - `#[read(as_ref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
71/// - `#[read(deref)]`: Delegates the field to the inner type using `Deref`/`DerefMut`.
72/// - `#[read(<function>=<override>)]`: Overrides the default `<function>` method with the given override function.
73#[proc_macro_derive(AsyncRead, attributes(read, write, descriptor, duck))]
74pub fn derive_io_async_read(input: TokenStream) -> TokenStream {
75    generate("derive_io", "derive_io_async_read", input)
76}
77
78/// `#[derive(AsyncWrite)]`:
79///
80/// Derives `tokio::io::AsyncWrite` for the given struct.
81///
82/// Supported attributes:
83///
84/// - `#[write]`: Marks the field as a write stream.
85/// - `#[write(as_ref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
86/// - `#[write(deref)]`: Delegates the field to the inner type using `Deref`/`DerefMut`.
87/// - `#[write(<function>=<override>)]`: Overrides the default `<function>` method with the given override function.
88#[proc_macro_derive(AsyncWrite, attributes(read, write, descriptor, duck))]
89pub fn derive_io_async_write(input: TokenStream) -> TokenStream {
90    generate("derive_io", "derive_io_async_write", input)
91}
92
93/// `#[derive(AsFileDescriptor)]`
94///
95/// Derives `std::os::fd::{AsFd, AsRawFd}` and `std::os::windows::io::{AsHandle, AsRawHandle}` for the given struct.
96///
97/// Supported attributes:
98///
99/// - `#[descriptor]`: Marks the field as a file descriptor.
100/// - `#[descriptor(as_ref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
101/// - `#[descriptor(deref)]`: Delegates the field to the inner type using `Deref`/`DerefMut`.
102/// - `#[descriptor(<function>=<override>)]`: Overrides the default `<function>` method with the given override function.
103#[proc_macro_derive(AsFileDescriptor, attributes(read, write, descriptor, duck))]
104pub fn derive_io_as_file_descriptor(input: TokenStream) -> TokenStream {
105    generate("derive_io", "derive_io_as_file_descriptor", input)
106}
107
108/// `#[derive(AsSocketDescriptor)]`
109///
110/// Derives `std::os::fd::{AsFd, AsRawFd}` and `std::os::windows::io::{AsSocket, AsRawSocket}` for the given struct.
111///
112/// Supported attributes:
113///
114/// - `#[descriptor]`: Marks the field as a socket descriptor.
115/// - `#[descriptor(as_ref)]`: Delegates the field to the inner type using `AsRef`/`AsMut`.
116/// - `#[descriptor(deref)]`: Delegates the field to the inner type using `Deref`/`DerefMut`.
117/// - `#[descriptor(<function>=<override>)]`: Overrides the default `<function>` method with the given override function.
118#[proc_macro_derive(AsSocketDescriptor, attributes(read, write, descriptor, duck))]
119pub fn derive_io_as_socket_descriptor(input: TokenStream) -> TokenStream {
120    generate("derive_io", "derive_io_as_socket_descriptor", input)
121}
122
/// Expands a derive into a call of the matching support macro:
///
/// ```text
/// ::<macro_crate>::__support::<macro_type>_parse!((item) (generics) (where_clause));
/// ```
///
/// The derive input is pre-digested so that `macro_rules!` code can consume it:
///
/// - `item`: the input with its generic parameter list and `where` clause
///   removed (attributes, visibility, keyword, name and body are kept).
/// - `generics`: one parenthesized group per generic parameter holding only its
///   name, comma-separated, e.g. `('a) , (T)`. Bounds and defaults are stripped.
/// - `where_clause`: the bounds moved out of the parameter list (e.g.
///   `T : Read`) followed by the predicates of the item's own `where` clause,
///   comma-separated. Parameters without bounds contribute nothing here.
///
/// # Panics
///
/// Panics if the item has a `const` generic parameter (not yet supported).
// `tail_expr_drop_order` is a Rust 2024 migration lint; `unknown_lints` keeps
// compilers that predate it from warning about the name.
#[allow(unknown_lints, tail_expr_drop_order)]
fn generate(macro_crate: &str, macro_type: &str, item: TokenStream) -> TokenStream {
    let (new_item, generic_tokens, explicit_where) = split_item(item);
    let (generics, mut where_clause) = split_generics(generic_tokens);

    // Bounds hoisted from `<...>` and the item's own predicates must be
    // comma-separated, otherwise `T: A` + `U: B` would fuse into `T: A U: B`.
    if !explicit_where.is_empty() {
        if !where_clause.is_empty() {
            where_clause.extend([comma()]);
        }
        where_clause.extend(explicit_where);
    }

    let inner = TokenStream::from_iter([
        TokenTree::Group(Group::new(Delimiter::Parenthesis, new_item)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, generics)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, where_clause)),
    ]);

    // `::<macro_crate>::__support::<macro_type>_parse`
    let parse_macro = format!("{macro_type}_parse");
    let mut invoke = TokenStream::new();
    for segment in [macro_crate, "__support", parse_macro.as_str()] {
        invoke.extend([
            TokenTree::Punct(Punct::new(':', Spacing::Joint)),
            TokenTree::Punct(Punct::new(':', Spacing::Alone)),
            TokenTree::Ident(Ident::new(segment, Span::call_site())),
        ]);
    }
    invoke.extend([
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, inner)),
        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
    ]);

    invoke
}

/// A fresh, unattached `,` token.
fn comma() -> TokenTree {
    TokenTree::Punct(Punct::new(',', Spacing::Alone))
}

/// Returns `true` if `token` is the punctuation character `ch`.
fn is_punct(token: &TokenTree, ch: char) -> bool {
    matches!(token, TokenTree::Punct(p) if p.as_char() == ch)
}

/// Returns `true` if `token` is the identifier `name`.
fn is_ident(token: &TokenTree, name: &str) -> bool {
    matches!(token, TokenTree::Ident(i) if i.to_string() == name)
}

/// Tracks `<`/`>` nesting so that separators inside nested generic arguments
/// (the `,` in `HashMap<K, V>`, the `=` in `Iterator<Item = u8>`) are not
/// mistaken for top-level ones. The `>` of a `->` arrow is not a bracket.
#[derive(Default)]
struct AngleDepth {
    depth: usize,
    prev_joint_minus: bool,
}

impl AngleDepth {
    /// Feeds the next token and returns `true` if it is a non-bracket token at
    /// nesting depth zero.
    fn feed(&mut self, token: &TokenTree) -> bool {
        let after_minus = self.prev_joint_minus;
        // `->` is tokenized as a joint `-` followed by `>`.
        self.prev_joint_minus = matches!(
            token,
            TokenTree::Punct(p) if p.as_char() == '-' && p.spacing() == Spacing::Joint
        );
        if is_punct(token, '<') {
            self.depth += 1;
            false
        } else if is_punct(token, '>') && !after_minus {
            self.depth = self.depth.saturating_sub(1);
            false
        } else {
            self.depth == 0
        }
    }
}

/// Splits the derive input into:
/// the item with its generic parameter list and `where` clause removed, the
/// raw tokens between the generics' `<` and matching `>`, and the predicates of
/// the item's own `where` clause.
///
/// Scanning stops at the first brace-delimited group outside the generics (the
/// item body); anything after it is appended to the item unchanged.
fn split_item(item: TokenStream) -> (TokenStream, Vec<TokenTree>, TokenStream) {
    let mut new_item = TokenStream::new();
    let mut generic_tokens = Vec::new();
    let mut where_clause = TokenStream::new();
    let mut angles = AngleDepth::default();
    let mut in_generics = false;
    let mut in_where_clause = false;
    let mut iterator = item.into_iter();

    for token in iterator.by_ref() {
        if in_generics {
            angles.feed(&token);
            if angles.depth == 0 {
                // The `>` that closes the parameter list is dropped.
                in_generics = false;
            } else {
                generic_tokens.push(token);
            }
        } else if matches!(&token, TokenTree::Group(g) if g.delimiter() == Delimiter::Brace) {
            new_item.extend([token]);
            break;
        } else if in_where_clause {
            if is_punct(&token, ';') {
                // `struct Foo<T>(T) where T: Bound;` — the `;` ends the item.
                in_where_clause = false;
                new_item.extend([token]);
            } else {
                where_clause.extend([token]);
            }
        } else if is_punct(&token, '<') {
            in_generics = true;
            angles = AngleDepth {
                depth: 1,
                prev_joint_minus: false,
            };
        } else if is_ident(&token, "where") {
            in_where_clause = true;
        } else {
            new_item.extend([token]);
        }
    }
    new_item.extend(iterator);

    (new_item, generic_tokens, where_clause)
}

/// Turns the tokens between `<` and `>` into the `generics` stream
/// (`(name) , (name)`) and the hoisted bounds (`name : bounds , ...`).
///
/// # Panics
///
/// Panics on `const` generic parameters, which are not yet supported.
fn split_generics(tokens: Vec<TokenTree>) -> (TokenStream, TokenStream) {
    let mut generics = TokenStream::new();
    let mut where_clause = TokenStream::new();

    for param in split_top_level_commas(tokens) {
        let (name, bounds) = split_generic_param(param);
        if name.iter().any(|token| is_ident(token, "const")) {
            panic!("const not yet supported");
        }

        if !generics.is_empty() {
            generics.extend([comma()]);
        }
        generics.extend([TokenTree::Group(Group::new(
            Delimiter::Parenthesis,
            name.iter().cloned().collect(),
        ))]);

        if !bounds.is_empty() {
            if !where_clause.is_empty() {
                where_clause.extend([comma()]);
            }
            where_clause.extend(name);
            where_clause.extend(bounds);
        }
    }

    (generics, where_clause)
}

/// Splits generic-parameter tokens on commas that are not nested inside
/// `<...>`. A trailing comma does not produce an empty parameter.
fn split_top_level_commas(tokens: Vec<TokenTree>) -> Vec<Vec<TokenTree>> {
    let mut params = Vec::new();
    let mut current = Vec::new();
    let mut angles = AngleDepth::default();

    for token in tokens {
        if angles.feed(&token) && is_punct(&token, ',') {
            params.push(std::mem::take(&mut current));
        } else {
            current.push(token);
        }
    }
    if !current.is_empty() {
        params.push(current);
    }

    params
}

/// Splits one generic parameter into its name tokens (including any attributes
/// on it, e.g. `'a` or `T`) and its bound tokens starting at the `:` (empty if
/// the parameter is unbounded). Everything from a top-level `=` onwards is a
/// default value and is discarded.
fn split_generic_param(param: Vec<TokenTree>) -> (Vec<TokenTree>, Vec<TokenTree>) {
    let mut name = Vec::new();
    let mut bounds = Vec::new();
    let mut angles = AngleDepth::default();

    for token in param {
        let top_level = angles.feed(&token);
        if top_level && is_punct(&token, '=') {
            break;
        }
        if !bounds.is_empty() || (top_level && is_punct(&token, ':')) {
            bounds.push(token);
        } else {
            name.push(token);
        }
    }

    (name, bounds)
}
279
/// Takes the next token, whatever its kind, from `iterator`.
///
/// # Panics
///
/// Panics if the stream is exhausted; `named` describes the expected token in
/// the message.
fn expect_any(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> TokenTree {
    match iterator.next() {
        Some(token) => token,
        None => panic!("Expected {named} token, got end of stream"),
    }
}
287
/// Takes the next token from `iterator`, requiring it to be a delimited group.
///
/// # Panics
///
/// Panics if the next token is missing or is not a group; `named` describes
/// the expected group in the message.
fn expect_group(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> Group {
    match iterator.next() {
        Some(TokenTree::Group(group)) => group,
        other => panic!("Expected {named} group, got {other:?}"),
    }
}
295
/// Takes the next token from `iterator`, requiring it to be an identifier.
///
/// # Panics
///
/// Panics if the next token is missing or is not an identifier; `named`
/// describes the expected identifier in the message.
fn expect_ident(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> Ident {
    match iterator.next() {
        Some(TokenTree::Ident(ident)) => ident,
        other => panic!("Expected {named} ident, got {other:?}"),
    }
}
303
/// Takes the next token from `iterator`, requiring it to be a literal.
///
/// A literal wrapped in an invisible (`Delimiter::None`) group, as produced
/// when a `macro_rules!` fragment is forwarded, is unwrapped first.
///
/// # Panics
///
/// Panics if no literal is found; `named` describes the expected literal in
/// the message.
fn expect_literal(named: &str, iterator: &mut impl Iterator<Item = TokenTree>) -> Literal {
    match iterator.next() {
        Some(TokenTree::Literal(literal)) => literal,
        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::None => {
            expect_literal(named, &mut group.stream().into_iter())
        }
        other => panic!("Expected {named} literal, got {other:?}"),
    }
}
317
/// Unwraps nested groups until reaching one whose first token is an
/// identifier, and returns that group's contents re-wrapped in brackets.
///
/// For example, `(([read(as_ref)]))` yields `[read(as_ref)]`.
///
/// # Panics
///
/// Panics if a group is empty, or if unwrapping reaches a token that is not a
/// group before finding an identifier.
fn expect_is_meta(named: &str, attr: TokenTree) -> Group {
    let mut current = attr.clone();
    loop {
        let TokenTree::Group(group) = current else {
            panic!("Expected meta group {named}, got {attr}");
        };
        let first = group
            .stream()
            .into_iter()
            .next()
            .expect("Expected attr group to have one element");
        if matches!(first, TokenTree::Ident(_)) {
            return Group::new(Delimiter::Bracket, group.stream());
        }
        current = first;
    }
}
333
334/// [
335///   (__next__) (args) expected_attr {on_error}
336///   ((attr attr) (item)) ((attr attr) (item))
337/// ] -> __next__!((args) (item))
338#[proc_macro]
339pub fn find_annotated(input: TokenStream) -> TokenStream {
340    let mut iterator = input.into_iter();
341
342    let next_macro = expect_group("__next__ macro", &mut iterator);
343    let args = expect_group("__next__ arguments", &mut iterator);
344    let expected_attr = expect_ident("expected_attr", &mut iterator);
345    let on_error = expect_group("on_error", &mut iterator);
346
347    for token in iterator {
348        let TokenTree::Group(check) = token else {
349            panic!("Expected check group");
350        };
351        let mut iter = check.stream().into_iter();
352        let attrs = expect_group("attrs", &mut iter);
353        let item = expect_any("item", &mut iter);
354        for (index, attr) in attrs.stream().into_iter().enumerate() {
355            let attr = expect_is_meta("attr", attr);
356            let first = expect_ident("first attr", &mut attr.clone().stream().into_iter());
357            if first.to_string() == expected_attr.to_string() {
358                let mut next = next_macro.stream();
359                next.extend([
360                    TokenTree::Punct(Punct::new('!', Spacing::Alone)),
361                    TokenTree::Group(Group::new(
362                        Delimiter::Parenthesis,
363                        TokenStream::from_iter([
364                            TokenTree::Group(args),
365                            TokenTree::Literal(Literal::usize_unsuffixed(index)),
366                            TokenTree::Group(attr),
367                            item,
368                        ]),
369                    )),
370                    TokenTree::Punct(Punct::new(';', Spacing::Alone)),
371                ]);
372                return next;
373            }
374        }
375    }
376
377    on_error.stream()
378}
379
380/// [(__next__) (args) expected_attr {on_error}
381///  ( (id) (([attr] [attr]) (item)) (([attr] [attr]) (item)) )
382///  ( (id) (([attr] [attr]) (item)) (([attr] [attr]) (item)) )
383/// ] -> __next__!((args) ((id) (item)))
384#[proc_macro]
385pub fn find_annotated_multi(input: TokenStream) -> TokenStream {
386    let mut iterator = input.into_iter();
387
388    let next_macro = expect_group("__next__ macro", &mut iterator);
389    let args = expect_group("__next__ arguments", &mut iterator);
390    let expected_attr = expect_ident("expected_attr", &mut iterator);
391    let on_error = expect_group("on_error", &mut iterator);
392    let mut output = TokenStream::new();
393
394    'outer: for token in iterator {
395        let TokenTree::Group(id) = token else {
396            panic!("Expected id group");
397        };
398        let mut iter = id.stream().into_iter();
399        let id = expect_group("id", &mut iter);
400        let mut index = 0;
401        for token in iter {
402            let TokenTree::Group(check) = token else {
403                panic!("Expected check group");
404            };
405            let mut iter = check.stream().into_iter();
406            let attrs = expect_group("attrs", &mut iter);
407            let item = expect_any("item", &mut iter);
408            for attr in attrs.stream().into_iter() {
409                let attr = expect_is_meta("attr", attr);
410                let first = expect_ident("first attr", &mut attr.clone().stream().into_iter());
411                if first.to_string() == expected_attr.to_string() {
412                    output.extend([TokenTree::Group(Group::new(
413                        Delimiter::Parenthesis,
414                        TokenStream::from_iter([
415                            TokenTree::Group(id.clone()),
416                            TokenTree::Literal(Literal::usize_unsuffixed(index)),
417                            TokenTree::Group(attr),
418                            item.clone(),
419                        ]),
420                    ))]);
421                    continue 'outer;
422                }
423            }
424            index += 1;
425        }
426        return on_error.stream();
427    }
428
429    let mut next = next_macro.stream();
430    next.extend([
431        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
432        TokenTree::Group(Group::new(
433            Delimiter::Parenthesis,
434            TokenStream::from_iter([
435                TokenTree::Group(args),
436                TokenTree::Group(Group::new(Delimiter::Parenthesis, output)),
437            ]),
438        )),
439        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
440    ]);
441    next
442}
443
444/// [prefix count repeated suffix] -> prefix (repeated*count suffix)
445#[proc_macro]
446pub fn repeat_in_parenthesis(input: TokenStream) -> TokenStream {
447    let mut iterator = input.into_iter();
448    let prefix = expect_group("prefix", &mut iterator);
449    let count = expect_literal("count", &mut iterator);
450    let count: usize = str::parse(&count.to_string()).expect("Expected count to be a number");
451    let repeated = expect_group("repeated", &mut iterator);
452    let suffix = expect_group("suffix", &mut iterator);
453    let mut repeat = TokenStream::new();
454    for _ in 0..count {
455        repeat.extend(repeated.clone().stream());
456    }
457    repeat.extend(suffix.stream());
458    let mut output = TokenStream::new();
459    output.extend(prefix.stream());
460    output.extend([TokenTree::Group(Group::new(Delimiter::Parenthesis, repeat))]);
461    output
462}
463
464// needle haystack(key=value,key=value) default -> extracted OR default
465#[proc_macro]
466pub fn extract_meta(input: TokenStream) -> TokenStream {
467    let mut iterator = input.into_iter();
468    let needle = expect_ident("needle", &mut iterator);
469    let haystack = expect_group("haystack", &mut iterator);
470    let default = expect_group("default", &mut iterator);
471
472    let mut haystack = haystack.stream().into_iter();
473
474    loop {
475        let attr = haystack.next();
476        if let Some(TokenTree::Group(ref group)) = attr {
477            haystack = group.stream().into_iter();
478            continue;
479        }
480        break;
481    }
482
483    loop {
484        let key = haystack.next();
485        if let Some(TokenTree::Group(ref group)) = key {
486            haystack = group.stream().into_iter();
487            continue;
488        }
489        let Some(key) = key else {
490            break;
491        };
492        let Some(next) = haystack.next() else {
493            break;
494        };
495        // Ignore simple attributes
496        let TokenTree::Punct(punct) = next else {
497            panic!("Expected = after key, got {next:?}");
498        };
499
500        if punct.as_char() == ',' {
501            continue;
502        }
503
504        let value = expect_ident("value", &mut haystack);
505
506        if key.to_string() == needle.to_string() {
507            return TokenStream::from_iter([TokenTree::Ident(value)]);
508        }
509
510        if haystack.next().is_none() {
511            break;
512        }
513    }
514
515    default.stream()
516}
517
518#[proc_macro]
519pub fn if_meta(input: TokenStream) -> TokenStream {
520    let mut iterator = input.into_iter();
521    let needle = expect_ident("needle", &mut iterator);
522    let haystack = expect_group("haystack", &mut iterator);
523    let if_true = expect_group("if_true", &mut iterator);
524    let if_false = expect_group("if_false", &mut iterator);
525
526    let mut haystack = haystack.stream().into_iter();
527
528    loop {
529        let attr = haystack.next();
530        if let Some(TokenTree::Group(ref group)) = attr {
531            haystack = group.stream().into_iter();
532            continue;
533        }
534        break;
535    }
536
537    loop {
538        let key = haystack.next();
539        if let Some(TokenTree::Group(ref group)) = key {
540            haystack = group.stream().into_iter();
541            continue;
542        }
543        let Some(key) = key else {
544            break;
545        };
546        if key.to_string() == needle.to_string() {
547            return if_true.stream();
548        }
549    }
550
551    if_false.stream()
552}
553
554#[proc_macro]
555pub fn type_has_generic(input: TokenStream) -> TokenStream {
556    let mut iterator = input.into_iter();
557
558    let type_ = expect_group("type", &mut iterator);
559    let generic = expect_group("generics", &mut iterator);
560    let if_true = expect_group("if_true", &mut iterator);
561    let if_false = expect_group("if_false", &mut iterator);
562
563    fn recursive_collect_generics(generics: &mut HashSet<String>, type_tokens: TokenStream) {
564        let iterator = type_tokens.into_iter();
565        for token in iterator {
566            if let TokenTree::Ident(ident) = &token {
567                generics.insert(ident.to_string());
568            } else if let TokenTree::Group(group) = token {
569                recursive_collect_generics(generics, group.stream());
570            }
571        }
572    }
573
574    let mut generics = HashSet::new();
575    recursive_collect_generics(&mut generics, generic.stream());
576
577    fn recursive_check_generics(generics: &HashSet<String>, type_tokens: TokenStream) -> bool {
578        let iterator = type_tokens.into_iter();
579        for token in iterator {
580            if let TokenTree::Ident(ident) = &token {
581                if generics.contains(&ident.to_string()) {
582                    return true;
583                }
584            } else if let TokenTree::Group(group) = token {
585                if recursive_check_generics(generics, group.stream()) {
586                    return true;
587                }
588            }
589        }
590        false
591    }
592
593    if recursive_check_generics(&generics, type_.stream()) {
594        if_true.stream()
595    } else {
596        if_false.stream()
597    }
598}