Skip to main content

spargio_macros/
lib.rs

1//! Procedural macros for `spargio`.
2//!
3//! This crate currently exposes the [`main`] attribute macro.
4#![deny(missing_docs)]
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::parse::Parser;
9use syn::parse_macro_input;
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::{Expr, ExprLit, FnArg, ItemFn, Lit, MetaNameValue, Pat, PatIdent, Token};
13
14#[derive(Default)]
15struct MainArgs {
16    shards: Option<Expr>,
17    backend: Option<BackendArg>,
18}
19
/// Backends that may be selected with the `backend = "..."` macro option.
///
/// This is a fieldless, `Copy` enum, so it can be passed around by value.
#[derive(Copy, Clone)]
enum BackendArg {
    /// Selected by the string literal `"io_uring"`.
    IoUring,
}
24
25impl BackendArg {
26    fn parse(value: &Expr) -> syn::Result<Self> {
27        let Expr::Lit(ExprLit {
28            lit: Lit::Str(lit), ..
29        }) = value
30        else {
31            return Err(syn::Error::new(
32                value.span(),
33                "backend must be a string literal: \"io_uring\"",
34            ));
35        };
36
37        match lit.value().as_str() {
38            "io_uring" => Ok(Self::IoUring),
39            other => Err(syn::Error::new(
40                lit.span(),
41                format!("unsupported backend '{other}'; expected \"io_uring\""),
42            )),
43        }
44    }
45
46    fn as_tokens(self) -> proc_macro2::TokenStream {
47        match self {
48            Self::IoUring => quote!(::spargio::BackendKind::IoUring),
49        }
50    }
51}
52
53impl MainArgs {
54    fn parse(args: TokenStream) -> syn::Result<Self> {
55        let mut out = Self::default();
56        let parser = Punctuated::<MetaNameValue, Token![,]>::parse_terminated;
57        let args = parser.parse(args)?;
58        for arg in args {
59            if arg.path.is_ident("shards") {
60                if out.shards.is_some() {
61                    return Err(syn::Error::new(
62                        arg.path.span(),
63                        "duplicate 'shards' option",
64                    ));
65                }
66                out.shards = Some(arg.value);
67                continue;
68            }
69            if arg.path.is_ident("backend") {
70                if out.backend.is_some() {
71                    return Err(syn::Error::new(
72                        arg.path.span(),
73                        "duplicate 'backend' option",
74                    ));
75                }
76                out.backend = Some(BackendArg::parse(&arg.value)?);
77                continue;
78            }
79            return Err(syn::Error::new(
80                arg.path.span(),
81                "unsupported option; expected one of: shards, backend",
82            ));
83        }
84        Ok(out)
85    }
86}
87
88#[proc_macro_attribute]
89/// Attribute macro for defining a Spargio runtime entrypoint.
90///
91/// Supported options:
92/// - `shards = <expr>`: sets runtime shard count.
93/// - `backend = "io_uring"`: selects runtime backend.
94///
95/// The attributed function must be `async` and may take at most one parameter
96/// (a `spargio::RuntimeHandle` binding).
97pub fn main(args: TokenStream, item: TokenStream) -> TokenStream {
98    let args = match MainArgs::parse(args) {
99        Ok(args) => args,
100        Err(err) => return err.to_compile_error().into(),
101    };
102
103    let input = parse_macro_input!(item as ItemFn);
104    if input.sig.asyncness.is_none() {
105        return syn::Error::new(
106            input.sig.fn_token.span(),
107            "#[spargio::main] can only be used on async functions",
108        )
109        .to_compile_error()
110        .into();
111    }
112    let inject_handle = match input.sig.inputs.len() {
113        0 => None,
114        1 => {
115            let Some(arg) = input.sig.inputs.first() else {
116                return syn::Error::new(input.sig.inputs.span(), "missing function parameter")
117                    .to_compile_error()
118                    .into();
119            };
120            match arg {
121                FnArg::Typed(pat_type) => match pat_type.pat.as_ref() {
122                    Pat::Ident(PatIdent { .. }) => Some(()),
123                    _ => {
124                        return syn::Error::new(
125                            pat_type.pat.span(),
126                            "#[spargio::main] parameter must be an identifier binding",
127                        )
128                        .to_compile_error()
129                        .into();
130                    }
131                },
132                FnArg::Receiver(receiver) => {
133                    return syn::Error::new(
134                        receiver.span(),
135                        "#[spargio::main] does not support method receivers",
136                    )
137                    .to_compile_error()
138                    .into();
139                }
140            }
141        }
142        _ => {
143            return syn::Error::new(
144                input.sig.inputs.span(),
145                "#[spargio::main] supports at most one function parameter (RuntimeHandle)",
146            )
147            .to_compile_error()
148            .into();
149        }
150    };
151    if !input.sig.generics.params.is_empty() {
152        return syn::Error::new(
153            input.sig.generics.span(),
154            "#[spargio::main] does not support generic parameters",
155        )
156        .to_compile_error()
157        .into();
158    }
159
160    let attrs = input.attrs;
161    let vis = input.vis;
162    let name = input.sig.ident;
163    let inputs = input.sig.inputs;
164    let output = input.sig.output;
165    let block = input.block;
166    let inner_name = syn::Ident::new(&format!("__spargio_async_{}", name), name.span());
167
168    let shards_builder = args
169        .shards
170        .map(|expr| quote!(.shards(#expr)))
171        .unwrap_or_default();
172    let backend_builder = args
173        .backend
174        .map(|backend| {
175            let backend = backend.as_tokens();
176            quote!(.backend(#backend))
177        })
178        .unwrap_or_default();
179
180    let call_inner = if inject_handle.is_some() {
181        quote!(#inner_name(__spargio_handle).await)
182    } else {
183        quote!(#inner_name().await)
184    };
185
186    quote! {
187            #(#attrs)*
188            #vis fn #name() #output {
189                let __spargio_builder = ::spargio::Runtime::builder()
190                    #shards_builder
191                    #backend_builder;
192                match ::spargio::__private::block_on(::spargio::run_with(__spargio_builder, |__spargio_handle| async move { #call_inner })) {
193                    Ok(__spargio_out) => __spargio_out,
194                    Err(::spargio::RuntimeError::UnsupportedBackend(__spargio_msg)) => {
195                        panic!(
196                            "spargio::main backend is not supported on this platform: {}",
197                            __spargio_msg
198                        )
199                    }
200                    Err(__spargio_err) => {
201                        panic!("spargio::main runtime startup failed: {:?}", __spargio_err)
202                    }
203                }
204            }
205
206            async fn #inner_name(#inputs) #output {
207                #block
208            }
209        }
210        .into()
211}