1use proc_macro::TokenStream;
2use quote::{quote, quote_spanned};
3use syn::parse::{Parse, ParseStream, Result};
4use syn::spanned::Spanned;
5use syn::{parse_macro_input, Expr, Ident, Token, Type, Visibility};
6
7struct AsyncStatic {
8 visibility: Visibility,
9 name: Ident,
10 ty: Type,
11 init: Expr,
12}
13
14impl Parse for AsyncStatic {
15 fn parse(input: ParseStream) -> Result<Self> {
16 let visibility: Visibility = input.parse()?;
17 input.parse::<Token![static]>()?;
18 input.parse::<Token![ref]>()?;
19 let name: Ident = input.parse()?;
20 input.parse::<Token![:]>()?;
21 let ty: Type = input.parse()?;
22 input.parse::<Token![=]>()?;
23 let init: Expr = input.parse()?;
24 input.parse::<Token![;]>()?;
25 Ok(AsyncStatic {
26 visibility,
27 name,
28 ty,
29 init,
30 })
31 }
32}
33
34#[proc_macro]
35pub fn async_static(input: TokenStream) -> TokenStream {
36 let AsyncStatic {
37 visibility,
38 name,
39 ty,
40 init,
41 } = parse_macro_input!(input as AsyncStatic);
42
43 let init_future = quote_spanned! {init.span()=>
44 once_cell::sync::Lazy::new(|| std::sync::Mutex::new(Box::pin(async { #init })))
45 };
46
47 let expanded = quote! {
48 #[allow(non_camel_case_types)]
49 #visibility struct #name;
50
51 impl std::future::Future for #name {
52 type Output = &'static #ty;
53 #[inline(always)]
54 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context) -> std::task::Poll<Self::Output> {
55 static ONCE: once_cell::sync::OnceCell<#ty> = once_cell::sync::OnceCell::new();
56 static FUT: once_cell::sync::Lazy<std::sync::Mutex<std::pin::Pin<Box<dyn Send + std::future::Future<Output = #ty>>>>> = #init_future;
57
58 if let Some(v) = ONCE.get() {
60 return std::task::Poll::Ready(v);
61 }
62 if let Ok(mut fut) = FUT.try_lock() {
63 match fut.as_mut().poll(cx) {
64 std::task::Poll::Ready(value) => {
65 if ONCE.set(value).is_err() {
66 cx.waker().wake_by_ref();
67 return std::task::Poll::Pending;
68 }
69 }
70 std::task::Poll::Pending => {
71 cx.waker().wake_by_ref();
72 return std::task::Poll::Pending;
73 }
74 };
75 std::task::Poll::Ready(ONCE.get().unwrap())
76 } else {
77 cx.waker().wake_by_ref();
78 std::task::Poll::Pending
79 }
80 }
81 }
82 };
83
84 TokenStream::from(expanded)
85}