Skip to main content

async_select_proc_macros/
lib.rs

//! Procedural macros implementing `select!`, in its unbiased (`select_default`) and biased (`select_biased`) forms.
2
3use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote, ToTokens};
5use syn::parse::{Parse, ParseStream};
6use syn::{Expr, Ident, Index, Pat, Result, Token};
7
8mod kw {
9    syn::custom_keyword!(complete);
10}
11
12struct Clause {
13    expr: Expr,
14}
15
16impl Parse for Clause {
17    fn parse(input: ParseStream<'_>) -> Result<Self> {
18        input.parse::<Token![=>]>()?;
19        let expr = Expr::parse_with_earlier_boundary_rule(input)?;
20        if matches!(expr, Expr::Block(_)) {
21            input.parse::<Option<Token![,]>>()?;
22        } else if !input.is_empty() {
23            input.parse::<Token![,]>()?;
24        }
25        Ok(Clause { expr })
26    }
27}
28
29impl ToTokens for Clause {
30    fn to_tokens(&self, tokens: &mut TokenStream) {
31        self.expr.to_tokens(tokens)
32    }
33}
34
35struct Condition {
36    expr: Expr,
37}
38
39impl Parse for Condition {
40    fn parse(input: ParseStream<'_>) -> Result<Self> {
41        input.parse::<Token![,]>()?;
42        input.parse::<Token![if]>()?;
43        let expr = Expr::parse_without_eager_brace(input)?;
44        Ok(Condition { expr })
45    }
46}
47
48impl ToTokens for Condition {
49    fn to_tokens(&self, tokens: &mut TokenStream) {
50        self.expr.to_tokens(tokens)
51    }
52}
53
54struct Branch {
55    bind: Pat,
56    check: Pat,
57    future: Expr,
58    condition: Option<Condition>,
59    clause: Clause,
60}
61
62impl Branch {
63    fn conditional_future(&self) -> ConditionalFuture<'_> {
64        ConditionalFuture { future: &self.future, condition: self.condition.as_ref() }
65    }
66}
67
68struct ConditionalFuture<'a> {
69    future: &'a Expr,
70    condition: Option<&'a Condition>,
71}
72
73impl ToTokens for ConditionalFuture<'_> {
74    fn to_tokens(&self, tokens: &mut TokenStream) {
75        let future = self.future;
76        match self.condition {
77            None => quote! { ::core::option::Option::Some(#future) },
78            Some(condition) => quote! { if #condition { ::core::option::Option::Some(#future) } else { None } },
79        }
80        .to_tokens(tokens);
81    }
82}
83
84#[derive(Default)]
85struct Select {
86    default_clause: Option<Clause>,
87    complete_clause: Option<Clause>,
88    branches: Vec<Branch>,
89}
90
91// This is mainly copied from https://github.com/tokio-rs/tokio/blob/tokio-1.46.1/tokio-macros/src/select.rs#L58
92//
93// See the LICENSE: https://github.com/tokio-rs/tokio/blob/tokio-1.46.1/LICENSE
94fn clean_pattern(pat: &mut Pat) {
95    match pat {
96        syn::Pat::Ident(ident) => {
97            ident.by_ref = None;
98            ident.mutability = None;
99            if let Some((_at, pat)) = &mut ident.subpat {
100                clean_pattern(&mut *pat);
101            }
102        },
103        syn::Pat::Or(or) => {
104            for case in &mut or.cases {
105                clean_pattern(case);
106            }
107        },
108        syn::Pat::Slice(slice) => {
109            for elem in &mut slice.elems {
110                clean_pattern(elem);
111            }
112        },
113        syn::Pat::Struct(struct_pat) => {
114            for field in &mut struct_pat.fields {
115                clean_pattern(&mut field.pat);
116            }
117        },
118        syn::Pat::Tuple(tuple) => {
119            for elem in &mut tuple.elems {
120                clean_pattern(elem);
121            }
122        },
123        syn::Pat::TupleStruct(tuple) => {
124            for elem in &mut tuple.elems {
125                clean_pattern(elem);
126            }
127        },
128        syn::Pat::Reference(reference) => {
129            reference.mutability = None;
130            clean_pattern(&mut reference.pat);
131        },
132        syn::Pat::Type(type_pat) => {
133            clean_pattern(&mut type_pat.pat);
134        },
135        _ => {},
136    };
137}
138
139fn to_check_pat(pat: &Pat) -> Pat {
140    let mut pat = pat.clone();
141    clean_pattern(&mut pat);
142    pat
143}
144
145impl Parse for Select {
146    fn parse(input: ParseStream<'_>) -> Result<Self> {
147        let mut select = Select::default();
148        while !input.is_empty() {
149            if input.peek(Token![default]) && input.peek2(Token![=>]) {
150                if select.default_clause.is_some() {
151                    return Err(input.error("`select!`: more than one `default` clauses"));
152                }
153                input.parse::<Token![default]>()?;
154                let clause = Clause::parse(input)?;
155                select.default_clause = Some(clause);
156            } else if input.peek(kw::complete) && input.peek2(Token![=>]) {
157                if select.complete_clause.is_some() {
158                    return Err(input.error("`select!`: more than one `complete` clauses"));
159                }
160                input.parse::<kw::complete>()?;
161                let clause = Clause::parse(input)?;
162                select.complete_clause = Some(clause);
163            } else {
164                let bind = Pat::parse_multi(input)?;
165                input.parse::<Token![=]>()?;
166                let future = input.parse::<Expr>()?;
167                let condition = if input.peek(Token![,]) { Some(input.parse::<Condition>()?) } else { None };
168                let clause = Clause::parse(input)?;
169                let check = to_check_pat(&bind);
170                select.branches.push(Branch { bind, check, future, condition, clause });
171            }
172        }
173        match (select.branches.is_empty(), select.complete_clause.is_some(), select.default_clause.is_some()) {
174            (true, false, false) => return Err(input.error("`select!`: no branch")),
175            (true, false, true) => return Err(input.error("`select!`: no branch except `default`")),
176            (true, true, false) => return Err(input.error("`select!`: no branch except `complete`")),
177            (true, true, true) => return Err(input.error("`select!`: no branch except `default` and `complete`")),
178            (_, _, _) => {},
179        };
180        Ok(select)
181    }
182}
183
184fn define_output_enum(ident: &Ident, branches: usize, span: Span) -> (Vec<Ident>, TokenStream) {
185    let type_names: Vec<_> = (0..branches).map(|i| format_ident!("T{i}", span = span)).collect();
186    let branch_names: Vec<_> = (0..branches).map(|i| format_ident!("_{i}", span = span)).collect();
187    let output_enum = quote! {
188        enum #ident<#(#type_names,)*> {
189            Completed,
190            WouldBlock,
191            #(
192                #branch_names(#type_names),
193            )*
194        };
195    };
196    (branch_names, output_enum)
197}
198
199fn select_internal(input: proc_macro::TokenStream, biased: bool) -> proc_macro::TokenStream {
200    let select = syn::parse_macro_input!(input as Select);
201    let span = Span::call_site();
202    let output_ident = Ident::new("__SelectOutput", span);
203    let (branch_names, output_enum) = define_output_enum(&output_ident, select.branches.len(), span);
204
205    let branch_futures = select.branches.iter().map(|branch| branch.conditional_future());
206
207    let select_futures_declartion = quote! {
208        let mut __select_futures = (#(#branch_futures,)*);
209        // Shadow it so it won't be moved accidentally.
210        let mut __select_futures = &mut __select_futures;
211    };
212
213    let default_handler = match select.default_clause.as_ref() {
214        None => quote! { ::core::unreachable!("not in unblocking mode") },
215        Some(clause) => quote! { #clause },
216    };
217
218    let complete_handler = match select.complete_clause.as_ref() {
219        None => quote! {
220            ::core::panic!("all branches are disabled or completed and there is no `default` nor `complete`")
221        },
222        Some(clause) => quote! { #clause },
223    };
224
225    let (pending_declaration, pending_assignment, pending_check) =
226        match select.complete_clause.is_some() || select.default_clause.is_none() {
227            true => (
228                quote! {
229                    let mut any_pending = false;
230                },
231                quote! {
232                    any_pending = true;
233                },
234                quote! {
235                    if !any_pending {
236                        return ::core::task::Poll::Ready(__SelectOutput::Completed);
237                    }
238                },
239            ),
240            false => (quote! {}, quote! {}, quote! {}),
241        };
242    let default_clause = match select.default_clause.is_some() {
243        true => quote! { ::core::task::Poll::Ready(__SelectOutput::WouldBlock) },
244        false => quote! { ::core::task::Poll::Pending },
245    };
246
247    let (biased_start, biased_branch) = match biased {
248        true => (quote! {}, quote! { let branch = i; }),
249        false => (
250            quote! {
251                let start = (&__select_futures as *const _ as usize) >> 3;
252            },
253            quote! {
254                #[allow(clippy::modulo_one)]
255                let branch = (start +i ) % BRANCHES;
256            },
257        ),
258    };
259
260    let branch_handlers = select.branches.iter().map(|branch| &branch.clause);
261    let branch_bindings = select.branches.iter().map(|branch| &branch.bind);
262    let branch_binding_checks = select.branches.iter().map(|branch| &branch.check);
263
264    let n_branches = select.branches.len();
265    let branch_indices = (0..n_branches).map(Index::from);
266
267    quote! {{
268        #output_enum
269        const BRANCHES: usize = #n_branches;
270        let mut output = {
271            #select_futures_declartion
272            ::core::future::poll_fn(|cx| {
273                #biased_start
274                #pending_declaration
275                for i in 0..BRANCHES {
276                    #biased_branch
277                    match branch {
278                        #(
279                            #branch_indices => {
280                                let ::core::option::Option::Some(future) = __select_futures.#branch_indices.as_mut() else {
281                                    continue;
282                                };
283                                #[allow(unused_unsafe)]
284                                let future = unsafe {
285                                    ::core::pin::Pin::new_unchecked(future)
286                                };
287                                let mut output = match ::core::future::Future::poll(
288                                    future,
289                                    cx,
290                                ) {
291                                    ::core::task::Poll::Ready(output) => output,
292                                    ::core::task::Poll::Pending => {
293                                        #pending_assignment
294                                        continue;
295                                    },
296                                };
297                                __select_futures.#branch_indices = ::core::option::Option::None;
298                                #[allow(unreachable_patterns)]
299                                #[allow(unused_variables)]
300                                match &output {
301                                    #branch_binding_checks => {},
302                                    _ => continue,
303                                };
304                                return ::core::task::Poll::Ready(__SelectOutput::#branch_names(output));
305                            }
306                        )*
307                            _ => ::core::unreachable!("select! encounter mismatch branch in polling"),
308                    }
309                }
310                #pending_check
311                #default_clause
312            }).await
313        };
314        match output {
315            __SelectOutput::WouldBlock => #default_handler,
316            __SelectOutput::Completed => #complete_handler,
317            #(
318                __SelectOutput::#branch_names(#branch_bindings) => #branch_handlers,
319            )*
320            #[allow(unreachable_patterns)] // In case of refutable patterns in branches
321            _ => ::core::unreachable!("select! fail to pattern match"),
322        }
323    }}.into()
324}
325
326#[proc_macro]
327pub fn select_default(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
328    select_internal(input, false)
329}
330
331#[proc_macro]
332pub fn select_biased(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
333    select_internal(input, true)
334}