use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{
Expr, Token,
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
};
/// Parsed macro input: an optional context expression followed by a list of
/// future expressions.
///
/// Accepted shape: `[cx_expr ;] fut_expr, fut_expr, ...` (a trailing comma is
/// allowed).
struct JoinInput {
/// Expression that appeared before a `;`, if the input had one.
cx: Option<Expr>,
/// Comma-separated future expressions, in source order.
futures: Punctuated<Expr, Token![,]>,
}
impl Parse for JoinInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
let fork = input.fork();
let cx = if let Ok(cx_expr) = fork.parse::<Expr>() {
if fork.peek(Token![;]) {
let _ = input.parse::<Expr>()?;
let _semi: Token![;] = input.parse()?;
Some(cx_expr)
} else {
None
}
} else {
None
};
let futures = Punctuated::parse_terminated(input)?;
Ok(Self { cx, futures })
}
}
/// Macro entry point for the tuple-producing join.
///
/// Parses `input` as a [`JoinInput`] and returns the expansion built by
/// `generate_join`. If parsing fails, `parse_macro_input!` returns early from
/// this function with the corresponding compile error tokens.
pub fn join_impl(input: TokenStream) -> TokenStream {
    let parsed = parse_macro_input!(input as JoinInput);
    generate_join(parsed.cx.as_ref(), &parsed.futures).into()
}
/// Builds the expansion for a join that evaluates to a tuple of results.
///
/// The generated block starts with the statement produced by
/// `generate_cx_ack` for the optional `cx` expression, and then:
///
/// * with no futures, evaluates to `()`;
/// * with one future, awaits it and evaluates to a one-element tuple;
/// * with several futures, binds each future expression to a local in source
///   order, awaits those locals one after another in the same order, and
///   evaluates to a tuple of the results.
fn generate_join(cx: Option<&Expr>, futures: &Punctuated<Expr, Token![,]>) -> TokenStream2 {
    let cx_ack = generate_cx_ack(cx);
    let future_count = futures.len();

    if future_count == 0 {
        return quote! {
            {
                #cx_ack
                ()
            }
        };
    }

    if future_count == 1 {
        let fut = futures
            .first()
            .expect("future_count == 1 guarantees first element exists");
        // The expression is parenthesized because `.await` binds tighter than
        // prefix and binary operators: pasting `&mut fut` or `a | b` directly in
        // front of `.await` would expand to `&mut (fut.await)` / `a | (b.await)`.
        return quote! {
            {
                #cx_ack
                ((#fut).await,)
            }
        };
    }

    // Generates `<prefix>0`, `<prefix>1`, ... — one identifier per future.
    let numbered_idents = |prefix: &str| -> Vec<syn::Ident> {
        (0..future_count)
            .map(|i| syn::Ident::new(&format!("{prefix}{i}"), proc_macro2::Span::call_site()))
            .collect()
    };
    let fut_idents = numbered_idents("__join_fut_");
    let result_idents = numbered_idents("__join_result_");
    let future_exprs = futures.iter();

    quote! {
        {
            #cx_ack
            #(let #fut_idents = #future_exprs;)*
            #(let #result_idents = #fut_idents.await;)*
            (#(#result_idents),*)
        }
    }
}
/// Macro entry point for the array-producing join.
///
/// Parses `input` as a [`JoinInput`] and returns the expansion built by
/// `generate_join_all`. If parsing fails, `parse_macro_input!` returns early
/// from this function with the corresponding compile error tokens.
pub fn join_all_impl(input: TokenStream) -> TokenStream {
    let parsed = parse_macro_input!(input as JoinInput);
    generate_join_all(parsed.cx.as_ref(), &parsed.futures).into()
}
/// Builds the expansion for a join that evaluates to an array of results.
///
/// The generated block starts with the statement produced by
/// `generate_cx_ack` for the optional `cx` expression, and then:
///
/// * with no futures, evaluates to `[]`;
/// * with one future, awaits it and evaluates to a one-element array;
/// * with several futures, binds each future expression to a local in source
///   order, awaits those locals one after another in the same order, and
///   evaluates to an array of the results.
fn generate_join_all(cx: Option<&Expr>, futures: &Punctuated<Expr, Token![,]>) -> TokenStream2 {
    let cx_ack = generate_cx_ack(cx);
    let future_count = futures.len();

    if future_count == 0 {
        return quote! {
            {
                #cx_ack
                []
            }
        };
    }

    if future_count == 1 {
        let fut = futures
            .first()
            .expect("future_count == 1 guarantees first element exists");
        // The expression is parenthesized because `.await` binds tighter than
        // prefix and binary operators: pasting `&mut fut` or `a | b` directly in
        // front of `.await` would expand to `&mut (fut.await)` / `a | (b.await)`.
        return quote! {
            {
                #cx_ack
                [(#fut).await]
            }
        };
    }

    // Generates `<prefix>0`, `<prefix>1`, ... — one identifier per future.
    let numbered_idents = |prefix: &str| -> Vec<syn::Ident> {
        (0..future_count)
            .map(|i| syn::Ident::new(&format!("{prefix}{i}"), proc_macro2::Span::call_site()))
            .collect()
    };
    let fut_idents = numbered_idents("__join_fut_");
    let result_idents = numbered_idents("__join_result_");
    let future_exprs = futures.iter();

    quote! {
        {
            #cx_ack
            #(let #fut_idents = #future_exprs;)*
            #(let #result_idents = #fut_idents.await;)*
            [#(#result_idents),*]
        }
    }
}
fn generate_cx_ack(cx: Option<&Expr>) -> TokenStream2 {
if cx.is_some() {
quote! {
let _ = &#cx;
}
} else {
quote! {}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `tokens` as macro input, panicking on a parse error.
    fn parse_input(tokens: proc_macro2::TokenStream) -> JoinInput {
        syn::parse2(tokens).unwrap()
    }

    /// Asserts that the rendered `expansion` still mentions the `make_cx` call.
    fn assert_keeps_make_cx(expansion: &proc_macro2::TokenStream, message: &str) {
        assert!(expansion.to_string().contains("make_cx"), "{message}");
    }

    #[test]
    fn test_parse_single_future() {
        let parsed = parse_input(quote! { future_a });
        assert_eq!(parsed.futures.len(), 1);
    }

    #[test]
    fn test_parse_multiple_futures() {
        let parsed = parse_input(quote! { future_a, future_b, future_c });
        assert_eq!(parsed.futures.len(), 3);
    }

    #[test]
    fn test_parse_trailing_comma() {
        let parsed = parse_input(quote! { future_a, future_b, });
        assert_eq!(parsed.futures.len(), 2);
    }

    #[test]
    fn test_parse_with_cx() {
        let parsed = parse_input(quote! { cx; future_a, future_b });
        assert!(parsed.cx.is_some());
        assert_eq!(parsed.futures.len(), 2);
    }

    #[test]
    fn test_join_single_future_keeps_cx_expr() {
        let JoinInput { cx, futures } = parse_input(quote! { make_cx(); future_a });
        assert_keeps_make_cx(
            &generate_join(cx.as_ref(), &futures),
            "single-future join must still typecheck the cx expression",
        );
    }

    #[test]
    fn test_join_empty_keeps_cx_expr() {
        let JoinInput { cx, futures } = parse_input(quote! { make_cx(); });
        assert_keeps_make_cx(
            &generate_join(cx.as_ref(), &futures),
            "empty join must still typecheck the cx expression",
        );
    }

    #[test]
    fn test_join_all_single_future_keeps_cx_expr() {
        let JoinInput { cx, futures } = parse_input(quote! { make_cx(); future_a });
        assert_keeps_make_cx(
            &generate_join_all(cx.as_ref(), &futures),
            "single-future join_all must still typecheck the cx expression",
        );
    }

    #[test]
    fn test_join_all_empty_keeps_cx_expr() {
        let JoinInput { cx, futures } = parse_input(quote! { make_cx(); });
        assert_keeps_make_cx(
            &generate_join_all(cx.as_ref(), &futures),
            "empty join_all must still typecheck the cx expression",
        );
    }
}