use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
Expr, Token,
};
struct JoinInput {
cx: Option<Expr>,
futures: Punctuated<Expr, Token![,]>,
}
impl Parse for JoinInput {
    /// Parses `[cx;] fut_0, fut_1, ...`: an optional context expression ended
    /// by `;`, followed by zero or more comma-separated future expressions
    /// (a trailing comma is accepted).
    ///
    /// # Errors
    ///
    /// Returns the `syn` parse error if the context or any future is not a
    /// valid expression, or if the futures are not comma-separated.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // Decide whether a context prefix is present by speculatively parsing
        // one expression on a fork (which does not consume from `input`) and
        // checking whether a `;` follows it. The previous `peek2(Token![;])`
        // test only recognized a context made of a single token tree such as
        // `cx`. Inputs like `self.cx; a, b` or `&mut cx; a, b` were therefore
        // handed to `parse_terminated` and rejected with a parse error.
        let speculative = input.fork();
        let has_cx = speculative.parse::<Expr>().is_ok() && speculative.peek(Token![;]);
        let cx = if has_cx {
            let cx_expr: Expr = input.parse()?;
            input.parse::<Token![;]>()?;
            Some(cx_expr)
        } else {
            None
        };
        let futures = Punctuated::parse_terminated(input)?;
        Ok(Self { cx, futures })
    }
}
/// Implementation of the join proc macro: parses `input` as a `JoinInput`
/// and expands it with `generate_join`.
///
/// If `input` fails to parse, `parse_macro_input!` returns the error as a
/// `compile_error!` token stream instead of expanding.
pub fn join_impl(input: TokenStream) -> TokenStream {
    let parsed = parse_macro_input!(input as JoinInput);
    generate_join(parsed.cx.as_ref(), &parsed.futures).into()
}
/// Builds the expansion of the join macro.
///
/// The generated block does three things in order:
/// 1. Binds every future expression to a fresh local.
/// 2. `.await`s each local in argument order.
/// 3. Returns the outputs as a tuple in that same order.
///
/// The futures are therefore awaited one after another; they are not polled
/// concurrently. NOTE(review): if concurrent polling is the intended
/// semantics of this macro, it needs a poll-based expansion instead.
///
/// If `cx` is given, it is evaluated and borrowed once (`let _ = &cx;`)
/// before any future expression, for every arity. The expansion does not
/// otherwise use it.
///
/// Result shapes: no futures give `()`, one future gives the 1-tuple
/// `(out,)`, and `n` futures give an `n`-tuple.
fn generate_join(cx: Option<&Expr>, futures: &Punctuated<Expr, Token![,]>) -> TokenStream2 {
    // A single code path serves every arity. The former one-future shortcut
    // spliced `#fut.await` directly, so `&mut fut` expanded to
    // `&mut fut.await` (awaiting `fut` by value) and `a + b` to `a + b.await`.
    // Binding each future to a local first avoids that precedence hazard.
    //
    // `mixed_site` makes the generated locals hygienic: they cannot clash
    // with, or be captured by, identifiers written in the user's expressions.
    let span = proc_macro2::Span::mixed_site();
    let fut_idents: Vec<syn::Ident> = (0..futures.len())
        .map(|i| syn::Ident::new(&format!("__join_fut_{i}"), span))
        .collect();
    let result_idents: Vec<syn::Ident> = (0..futures.len())
        .map(|i| syn::Ident::new(&format!("__join_result_{i}"), span))
        .collect();
    let future_exprs = futures.iter();
    let cx_stmt = cx.map(|cx| quote! { let _ = &#cx; });
    // `#(#result_idents,)*` emits a comma after every element. A single
    // result therefore still forms the 1-tuple `(x,)`, and zero results
    // form `()`.
    quote! {
        {
            #cx_stmt
            #(let #fut_idents = #future_exprs;)*
            #(let #result_idents = #fut_idents.await;)*
            (#(#result_idents,)*)
        }
    }
}
/// Implementation of the join-all proc macro: parses `input` as a
/// `JoinInput` and expands it with `generate_join_all`.
///
/// If `input` fails to parse, `parse_macro_input!` returns the error as a
/// `compile_error!` token stream instead of expanding.
pub fn join_all_impl(input: TokenStream) -> TokenStream {
    let parsed = parse_macro_input!(input as JoinInput);
    generate_join_all(parsed.cx.as_ref(), &parsed.futures).into()
}
/// Builds the expansion of the join-all macro.
///
/// The generated block does three things in order:
/// 1. Binds every future expression to a fresh local.
/// 2. `.await`s each local in argument order.
/// 3. Returns the outputs as an array in that same order.
///
/// Because the result is an array, all futures must have the same output
/// type. The futures are awaited one after another; they are not polled
/// concurrently. NOTE(review): if concurrent polling is the intended
/// semantics of this macro, it needs a poll-based expansion instead.
///
/// If `cx` is given, it is evaluated and borrowed once (`let _ = &cx;`)
/// before any future expression, for every arity. The expansion does not
/// otherwise use it.
///
/// With no futures the result is the empty array `[]`.
fn generate_join_all(cx: Option<&Expr>, futures: &Punctuated<Expr, Token![,]>) -> TokenStream2 {
    // A single code path serves every arity. The former one-future shortcut
    // spliced `[#fut.await]` directly, so `&mut fut` expanded to
    // `[&mut fut.await]` (awaiting `fut` by value) and `a + b` to
    // `[a + b.await]`. Binding each future to a local first avoids that
    // precedence hazard.
    //
    // `mixed_site` makes the generated locals hygienic: they cannot clash
    // with, or be captured by, identifiers written in the user's expressions.
    let span = proc_macro2::Span::mixed_site();
    let fut_idents: Vec<syn::Ident> = (0..futures.len())
        .map(|i| syn::Ident::new(&format!("__join_fut_{i}"), span))
        .collect();
    let result_idents: Vec<syn::Ident> = (0..futures.len())
        .map(|i| syn::Ident::new(&format!("__join_result_{i}"), span))
        .collect();
    let future_exprs = futures.iter();
    let cx_stmt = cx.map(|cx| quote! { let _ = &#cx; });
    quote! {
        {
            #cx_stmt
            #(let #fut_idents = #future_exprs;)*
            #(let #result_idents = #fut_idents.await;)*
            [#(#result_idents),*]
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `tokens` as macro input, panicking with a clear message on a
    /// syntax error.
    fn parse_input(tokens: proc_macro2::TokenStream) -> JoinInput {
        syn::parse2(tokens).expect("tokens should parse as JoinInput")
    }

    #[test]
    fn test_parse_empty() {
        let parsed = parse_input(quote! {});
        assert!(parsed.cx.is_none());
        assert!(parsed.futures.is_empty());
    }

    #[test]
    fn test_parse_single_future() {
        let parsed = parse_input(quote! { future_a });
        assert!(parsed.cx.is_none());
        assert_eq!(parsed.futures.len(), 1);
    }

    #[test]
    fn test_parse_multiple_futures() {
        let parsed = parse_input(quote! { future_a, future_b, future_c });
        assert!(parsed.cx.is_none());
        assert_eq!(parsed.futures.len(), 3);
    }

    #[test]
    fn test_parse_trailing_comma() {
        let parsed = parse_input(quote! { future_a, future_b, });
        assert!(parsed.cx.is_none());
        assert_eq!(parsed.futures.len(), 2);
    }

    #[test]
    fn test_parse_with_cx() {
        let parsed = parse_input(quote! { cx; future_a, future_b });
        assert!(parsed.cx.is_some());
        assert_eq!(parsed.futures.len(), 2);
    }

    #[test]
    fn test_parse_with_cx_and_trailing_comma() {
        let parsed = parse_input(quote! { cx; future_a, future_b, });
        assert!(parsed.cx.is_some());
        assert_eq!(parsed.futures.len(), 2);
    }
}