async_static/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, quote_spanned};
3use syn::parse::{Parse, ParseStream, Result};
4use syn::spanned::Spanned;
5use syn::{parse_macro_input, Expr, Ident, Token, Type, Visibility};
6
7struct AsyncStatic {
8    visibility: Visibility,
9    name: Ident,
10    ty: Type,
11    init: Expr,
12}
13
14impl Parse for AsyncStatic {
15    fn parse(input: ParseStream) -> Result<Self> {
16        let visibility: Visibility = input.parse()?;
17        input.parse::<Token![static]>()?;
18        input.parse::<Token![ref]>()?;
19        let name: Ident = input.parse()?;
20        input.parse::<Token![:]>()?;
21        let ty: Type = input.parse()?;
22        input.parse::<Token![=]>()?;
23        let init: Expr = input.parse()?;
24        input.parse::<Token![;]>()?;
25        Ok(AsyncStatic {
26            visibility,
27            name,
28            ty,
29            init,
30        })
31    }
32}
33
34#[proc_macro]
35pub fn async_static(input: TokenStream) -> TokenStream {
36    let AsyncStatic {
37        visibility,
38        name,
39        ty,
40        init,
41    } = parse_macro_input!(input as AsyncStatic);
42
43    let init_future = quote_spanned! {init.span()=>
44        once_cell::sync::Lazy::new(|| std::sync::Mutex::new(Box::pin(async { #init })))
45    };
46
47    let expanded = quote! {
48        #[allow(non_camel_case_types)]
49        #visibility struct #name;
50
51        impl std::future::Future for #name {
52            type Output = &'static #ty;
53            #[inline(always)]
54            fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context) -> std::task::Poll<Self::Output> {
55                static ONCE: once_cell::sync::OnceCell<#ty> = once_cell::sync::OnceCell::new();
56                static FUT: once_cell::sync::Lazy<std::sync::Mutex<std::pin::Pin<Box<dyn Send + std::future::Future<Output = #ty>>>>> = #init_future;
57
58                // this is racy, but that's OK: it's just a fast case
59                if let Some(v) = ONCE.get() {
60                    return std::task::Poll::Ready(v);
61                }
62                if let Ok(mut fut) = FUT.try_lock() {
63                    match fut.as_mut().poll(cx) {
64                        std::task::Poll::Ready(value) => {
65                            if ONCE.set(value).is_err() {
66                                cx.waker().wake_by_ref();
67                                return std::task::Poll::Pending;
68                            }
69                        }
70                        std::task::Poll::Pending => {
71                            cx.waker().wake_by_ref();
72                            return std::task::Poll::Pending;
73                        }
74                    };
75                    std::task::Poll::Ready(ONCE.get().unwrap())
76                } else {
77                    cx.waker().wake_by_ref();
78                    std::task::Poll::Pending
79                }
80            }
81        }
82    };
83
84    TokenStream::from(expanded)
85}