async_select_proc_macros/
lib.rs1use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote, ToTokens};
5use syn::parse::{Parse, ParseStream};
6use syn::{Expr, Ident, Index, Pat, Result, Token};
7
8mod kw {
9 syn::custom_keyword!(complete);
10}
11
12struct Clause {
13 expr: Expr,
14}
15
16impl Parse for Clause {
17 fn parse(input: ParseStream<'_>) -> Result<Self> {
18 input.parse::<Token![=>]>()?;
19 let expr = Expr::parse_with_earlier_boundary_rule(input)?;
20 if matches!(expr, Expr::Block(_)) {
21 input.parse::<Option<Token![,]>>()?;
22 } else if !input.is_empty() {
23 input.parse::<Token![,]>()?;
24 }
25 Ok(Clause { expr })
26 }
27}
28
29impl ToTokens for Clause {
30 fn to_tokens(&self, tokens: &mut TokenStream) {
31 self.expr.to_tokens(tokens)
32 }
33}
34
35struct Condition {
36 expr: Expr,
37}
38
39impl Parse for Condition {
40 fn parse(input: ParseStream<'_>) -> Result<Self> {
41 input.parse::<Token![,]>()?;
42 input.parse::<Token![if]>()?;
43 let expr = Expr::parse_without_eager_brace(input)?;
44 Ok(Condition { expr })
45 }
46}
47
48impl ToTokens for Condition {
49 fn to_tokens(&self, tokens: &mut TokenStream) {
50 self.expr.to_tokens(tokens)
51 }
52}
53
54struct Branch {
55 bind: Pat,
56 check: Pat,
57 future: Expr,
58 condition: Option<Condition>,
59 clause: Clause,
60}
61
62impl Branch {
63 fn conditional_future(&self) -> ConditionalFuture<'_> {
64 ConditionalFuture { future: &self.future, condition: self.condition.as_ref() }
65 }
66}
67
68struct ConditionalFuture<'a> {
69 future: &'a Expr,
70 condition: Option<&'a Condition>,
71}
72
73impl ToTokens for ConditionalFuture<'_> {
74 fn to_tokens(&self, tokens: &mut TokenStream) {
75 let future = self.future;
76 match self.condition {
77 None => quote! { ::core::option::Option::Some(#future) },
78 Some(condition) => quote! { if #condition { ::core::option::Option::Some(#future) } else { None } },
79 }
80 .to_tokens(tokens);
81 }
82}
83
84#[derive(Default)]
85struct Select {
86 default_clause: Option<Clause>,
87 complete_clause: Option<Clause>,
88 branches: Vec<Branch>,
89}
90
91fn clean_pattern(pat: &mut Pat) {
95 match pat {
96 syn::Pat::Ident(ident) => {
97 ident.by_ref = None;
98 ident.mutability = None;
99 if let Some((_at, pat)) = &mut ident.subpat {
100 clean_pattern(&mut *pat);
101 }
102 },
103 syn::Pat::Or(or) => {
104 for case in &mut or.cases {
105 clean_pattern(case);
106 }
107 },
108 syn::Pat::Slice(slice) => {
109 for elem in &mut slice.elems {
110 clean_pattern(elem);
111 }
112 },
113 syn::Pat::Struct(struct_pat) => {
114 for field in &mut struct_pat.fields {
115 clean_pattern(&mut field.pat);
116 }
117 },
118 syn::Pat::Tuple(tuple) => {
119 for elem in &mut tuple.elems {
120 clean_pattern(elem);
121 }
122 },
123 syn::Pat::TupleStruct(tuple) => {
124 for elem in &mut tuple.elems {
125 clean_pattern(elem);
126 }
127 },
128 syn::Pat::Reference(reference) => {
129 reference.mutability = None;
130 clean_pattern(&mut reference.pat);
131 },
132 syn::Pat::Type(type_pat) => {
133 clean_pattern(&mut type_pat.pat);
134 },
135 _ => {},
136 };
137}
138
139fn to_check_pat(pat: &Pat) -> Pat {
140 let mut pat = pat.clone();
141 clean_pattern(&mut pat);
142 pat
143}
144
145impl Parse for Select {
146 fn parse(input: ParseStream<'_>) -> Result<Self> {
147 let mut select = Select::default();
148 while !input.is_empty() {
149 if input.peek(Token![default]) && input.peek2(Token![=>]) {
150 if select.default_clause.is_some() {
151 return Err(input.error("`select!`: more than one `default` clauses"));
152 }
153 input.parse::<Token![default]>()?;
154 let clause = Clause::parse(input)?;
155 select.default_clause = Some(clause);
156 } else if input.peek(kw::complete) && input.peek2(Token![=>]) {
157 if select.complete_clause.is_some() {
158 return Err(input.error("`select!`: more than one `complete` clauses"));
159 }
160 input.parse::<kw::complete>()?;
161 let clause = Clause::parse(input)?;
162 select.complete_clause = Some(clause);
163 } else {
164 let bind = Pat::parse_multi(input)?;
165 input.parse::<Token![=]>()?;
166 let future = input.parse::<Expr>()?;
167 let condition = if input.peek(Token![,]) { Some(input.parse::<Condition>()?) } else { None };
168 let clause = Clause::parse(input)?;
169 let check = to_check_pat(&bind);
170 select.branches.push(Branch { bind, check, future, condition, clause });
171 }
172 }
173 match (select.branches.is_empty(), select.complete_clause.is_some(), select.default_clause.is_some()) {
174 (true, false, false) => return Err(input.error("`select!`: no branch")),
175 (true, false, true) => return Err(input.error("`select!`: no branch except `default`")),
176 (true, true, false) => return Err(input.error("`select!`: no branch except `complete`")),
177 (true, true, true) => return Err(input.error("`select!`: no branch except `default` and `complete`")),
178 (_, _, _) => {},
179 };
180 Ok(select)
181 }
182}
183
184fn define_output_enum(ident: &Ident, branches: usize, span: Span) -> (Vec<Ident>, TokenStream) {
185 let type_names: Vec<_> = (0..branches).map(|i| format_ident!("T{i}", span = span)).collect();
186 let branch_names: Vec<_> = (0..branches).map(|i| format_ident!("_{i}", span = span)).collect();
187 let output_enum = quote! {
188 enum #ident<#(#type_names,)*> {
189 Completed,
190 WouldBlock,
191 #(
192 #branch_names(#type_names),
193 )*
194 };
195 };
196 (branch_names, output_enum)
197}
198
199fn select_internal(input: proc_macro::TokenStream, biased: bool) -> proc_macro::TokenStream {
200 let select = syn::parse_macro_input!(input as Select);
201 let span = Span::call_site();
202 let output_ident = Ident::new("__SelectOutput", span);
203 let (branch_names, output_enum) = define_output_enum(&output_ident, select.branches.len(), span);
204
205 let branch_futures = select.branches.iter().map(|branch| branch.conditional_future());
206
207 let select_futures_declartion = quote! {
208 let mut __select_futures = (#(#branch_futures,)*);
209 let mut __select_futures = &mut __select_futures;
211 };
212
213 let default_handler = match select.default_clause.as_ref() {
214 None => quote! { ::core::unreachable!("not in unblocking mode") },
215 Some(clause) => quote! { #clause },
216 };
217
218 let complete_handler = match select.complete_clause.as_ref() {
219 None => quote! {
220 ::core::panic!("all branches are disabled or completed and there is no `default` nor `complete`")
221 },
222 Some(clause) => quote! { #clause },
223 };
224
225 let (pending_declaration, pending_assignment, pending_check) =
226 match select.complete_clause.is_some() || select.default_clause.is_none() {
227 true => (
228 quote! {
229 let mut any_pending = false;
230 },
231 quote! {
232 any_pending = true;
233 },
234 quote! {
235 if !any_pending {
236 return ::core::task::Poll::Ready(__SelectOutput::Completed);
237 }
238 },
239 ),
240 false => (quote! {}, quote! {}, quote! {}),
241 };
242 let default_clause = match select.default_clause.is_some() {
243 true => quote! { ::core::task::Poll::Ready(__SelectOutput::WouldBlock) },
244 false => quote! { ::core::task::Poll::Pending },
245 };
246
247 let (biased_start, biased_branch) = match biased {
248 true => (quote! {}, quote! { let branch = i; }),
249 false => (
250 quote! {
251 let start = (&__select_futures as *const _ as usize) >> 3;
252 },
253 quote! {
254 #[allow(clippy::modulo_one)]
255 let branch = (start +i ) % BRANCHES;
256 },
257 ),
258 };
259
260 let branch_handlers = select.branches.iter().map(|branch| &branch.clause);
261 let branch_bindings = select.branches.iter().map(|branch| &branch.bind);
262 let branch_binding_checks = select.branches.iter().map(|branch| &branch.check);
263
264 let n_branches = select.branches.len();
265 let branch_indices = (0..n_branches).map(Index::from);
266
267 quote! {{
268 #output_enum
269 const BRANCHES: usize = #n_branches;
270 let mut output = {
271 #select_futures_declartion
272 ::core::future::poll_fn(|cx| {
273 #biased_start
274 #pending_declaration
275 for i in 0..BRANCHES {
276 #biased_branch
277 match branch {
278 #(
279 #branch_indices => {
280 let ::core::option::Option::Some(future) = __select_futures.#branch_indices.as_mut() else {
281 continue;
282 };
283 #[allow(unused_unsafe)]
284 let future = unsafe {
285 ::core::pin::Pin::new_unchecked(future)
286 };
287 let mut output = match ::core::future::Future::poll(
288 future,
289 cx,
290 ) {
291 ::core::task::Poll::Ready(output) => output,
292 ::core::task::Poll::Pending => {
293 #pending_assignment
294 continue;
295 },
296 };
297 __select_futures.#branch_indices = ::core::option::Option::None;
298 #[allow(unreachable_patterns)]
299 #[allow(unused_variables)]
300 match &output {
301 #branch_binding_checks => {},
302 _ => continue,
303 };
304 return ::core::task::Poll::Ready(__SelectOutput::#branch_names(output));
305 }
306 )*
307 _ => ::core::unreachable!("select! encounter mismatch branch in polling"),
308 }
309 }
310 #pending_check
311 #default_clause
312 }).await
313 };
314 match output {
315 __SelectOutput::WouldBlock => #default_handler,
316 __SelectOutput::Completed => #complete_handler,
317 #(
318 __SelectOutput::#branch_names(#branch_bindings) => #branch_handlers,
319 )*
320 #[allow(unreachable_patterns)] _ => ::core::unreachable!("select! fail to pattern match"),
322 }
323 }}.into()
324}
325
326#[proc_macro]
327pub fn select_default(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
328 select_internal(input, false)
329}
330
331#[proc_macro]
332pub fn select_biased(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
333 select_internal(input, true)
334}