futures_join_macro/
lib.rs

//! The futures-rs `join!` and `try_join!` macro implementations.
2
3#![recursion_limit = "128"]
4#![warn(rust_2018_idioms, unreachable_pub)]
// `single_use_lifetimes` is only enabled for tests: it cannot be included in the published code because this lint has false positives in the minimum required version.
6#![cfg_attr(test, warn(single_use_lifetimes))]
7#![warn(clippy::all)]
8
9extern crate proc_macro;
10
11use proc_macro::TokenStream;
12use proc_macro2::{Span, TokenStream as TokenStream2};
13use proc_macro_hack::proc_macro_hack;
14use quote::{format_ident, quote};
15use syn::parse::{Parse, ParseStream};
16use syn::{parenthesized, parse_quote, Expr, Ident, Token};
17
// Custom keywords recognized by the `Join` parser.
mod kw {
    // Defines the `futures_crate_path` keyword token, which `Join::parse`
    // peeks for and consumes before the optional parenthesized crate path.
    syn::custom_keyword!(futures_crate_path);
}
21
22#[derive(Default)]
23struct Join {
24    futures_crate_path: Option<syn::Path>,
25    fut_exprs: Vec<Expr>,
26}
27
28impl Parse for Join {
29    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
30        let mut join = Join::default();
31
32        // When `futures_crate_path(::path::to::futures::lib)` is provided,
33        // it sets the path through which futures library functions will be
34        // accessed.
35        if input.peek(kw::futures_crate_path) {
36            input.parse::<kw::futures_crate_path>()?;
37            let content;
38            parenthesized!(content in input);
39            join.futures_crate_path = Some(content.parse()?);
40        }
41
42        while !input.is_empty() {
43            join.fut_exprs.push(input.parse::<Expr>()?);
44
45            if !input.is_empty() {
46                input.parse::<Token![,]>()?;
47            }
48        }
49
50        Ok(join)
51    }
52}
53
54fn bind_futures(
55    futures_crate: &syn::Path,
56    fut_exprs: Vec<Expr>,
57    span: Span,
58) -> (Vec<TokenStream2>, Vec<Ident>) {
59    let mut future_let_bindings = Vec::with_capacity(fut_exprs.len());
60    let future_names: Vec<_> = fut_exprs
61        .into_iter()
62        .enumerate()
63        .map(|(i, expr)| {
64            let name = format_ident!("_fut{}", i, span = span);
65            future_let_bindings.push(quote! {
66                // Move future into a local so that it is pinned in one place and
67                // is no longer accessible by the end user.
68                let mut #name = #futures_crate::future::maybe_done(#expr);
69            });
70            name
71        })
72        .collect();
73
74    (future_let_bindings, future_names)
75}
76
77/// The `join!` macro.
78#[proc_macro_hack]
79pub fn join(input: TokenStream) -> TokenStream {
80    let parsed = syn::parse_macro_input!(input as Join);
81
82    let futures_crate = parsed
83        .futures_crate_path
84        .unwrap_or_else(|| parse_quote!(::futures_util));
85
86    // should be def_site, but that's unstable
87    let span = Span::call_site();
88
89    let (future_let_bindings, future_names) = bind_futures(&futures_crate, parsed.fut_exprs, span);
90
91    let poll_futures = future_names.iter().map(|fut| {
92        quote! {
93            __all_done &= #futures_crate::core_reexport::future::Future::poll(
94                unsafe { #futures_crate::core_reexport::pin::Pin::new_unchecked(&mut #fut) }, __cx).is_ready();
95        }
96    });
97    let take_outputs = future_names.iter().map(|fut| {
98        quote! {
99            unsafe { #futures_crate::core_reexport::pin::Pin::new_unchecked(&mut #fut) }.take_output().unwrap(),
100        }
101    });
102
103    TokenStream::from(quote! { {
104        #( #future_let_bindings )*
105
106        #futures_crate::future::poll_fn(move |__cx: &mut #futures_crate::task::Context<'_>| {
107            let mut __all_done = true;
108            #( #poll_futures )*
109            if __all_done {
110                #futures_crate::core_reexport::task::Poll::Ready((
111                    #( #take_outputs )*
112                ))
113            } else {
114                #futures_crate::core_reexport::task::Poll::Pending
115            }
116        }).await
117    } })
118}
119
120/// The `try_join!` macro.
121#[proc_macro_hack]
122pub fn try_join(input: TokenStream) -> TokenStream {
123    let parsed = syn::parse_macro_input!(input as Join);
124
125    let futures_crate = parsed
126        .futures_crate_path
127        .unwrap_or_else(|| parse_quote!(::futures_util));
128
129    // should be def_site, but that's unstable
130    let span = Span::call_site();
131
132    let (future_let_bindings, future_names) = bind_futures(&futures_crate, parsed.fut_exprs, span);
133
134    let poll_futures = future_names.iter().map(|fut| {
135        quote! {
136            if #futures_crate::core_reexport::future::Future::poll(
137                unsafe { #futures_crate::core_reexport::pin::Pin::new_unchecked(&mut #fut) }, __cx).is_pending()
138            {
139                __all_done = false;
140            } else if unsafe { #futures_crate::core_reexport::pin::Pin::new_unchecked(&mut #fut) }.output_mut().unwrap().is_err() {
141                // `.err().unwrap()` rather than `.unwrap_err()` so that we don't introduce
142                // a `T: Debug` bound.
143                return #futures_crate::core_reexport::task::Poll::Ready(
144                    #futures_crate::core_reexport::result::Result::Err(
145                        unsafe { #futures_crate::core_reexport::pin::Pin::new_unchecked(&mut #fut) }.take_output().unwrap().err().unwrap()
146                    )
147                );
148            }
149        }
150    });
151    let take_outputs = future_names.iter().map(|fut| {
152        quote! {
153            // `.ok().unwrap()` rather than `.unwrap()` so that we don't introduce
154            // an `E: Debug` bound.
155            unsafe { #futures_crate::core_reexport::pin::Pin::new_unchecked(&mut #fut) }.take_output().unwrap().ok().unwrap(),
156        }
157    });
158
159    TokenStream::from(quote! { {
160        #( #future_let_bindings )*
161
162        #futures_crate::future::poll_fn(move |__cx: &mut #futures_crate::task::Context<'_>| {
163            let mut __all_done = true;
164            #( #poll_futures )*
165            if __all_done {
166                #futures_crate::core_reexport::task::Poll::Ready(
167                    #futures_crate::core_reexport::result::Result::Ok((
168                        #( #take_outputs )*
169                    ))
170                )
171            } else {
172                #futures_crate::core_reexport::task::Poll::Pending
173            }
174        }).await
175    } })
176}