photonio_macros/
lib.rs

1//! Procedural macros for PhotonIO.
2
3#![warn(missing_docs, unreachable_pub)]
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::parse::Parser;
8
9/// Marks a function to be run on a runtime.
10///
11/// # Examples
12///
13/// ```ignore
14/// use photonio::{
15///     fs::File,
16///     io::{Write, WriteAt},
17/// };
18///
19/// #[photonio::main(num_threads = 4)]
20/// async fn main() -> std::io::Result<()> {
21///     let mut file = File::create("hello.txt").await?;
22///     file.write(b"hello").await?;
23///     file.write_at(b"world", 5).await?;
24///     Ok(())
25/// }
26/// ```
27///
28/// This is equivalent to:
29///
30/// ```ignore
31/// use photonio::{fs::File, io::Write, runtime::Builder};
32///
33/// fn main() -> std::io::Result<()> {
34///     let rt = Builder::new().num_threads(4).build()?;
35///     rt.block_on(async {
36///         let mut file = File::create("hello.txt").await?;
37///         file.write(b"hello").await?;
38///         Ok(())
39///     })
40/// }
41/// ```
42#[proc_macro_attribute]
43pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
44    transform(attr, item, false)
45}
46
47/// This is similar to [`macro@main`], but for tests.
48#[proc_macro_attribute]
49pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
50    transform(attr, item, true)
51}
52
53fn transform(attr: TokenStream, item: TokenStream, is_test: bool) -> TokenStream {
54    let opts = match Options::parse(attr.clone()) {
55        Ok(opts) => opts,
56        Err(e) => return token_stream_with_error(attr, e),
57    };
58    let mut func: syn::ItemFn = match syn::parse(item.clone()) {
59        Ok(func) => func,
60        Err(e) => return token_stream_with_error(item, e),
61    };
62
63    let head = if is_test {
64        quote! { #[::std::prelude::v1::test] }
65    } else {
66        quote! {}
67    };
68
69    let init = if is_test && opts.env_logger {
70        quote! { let _ = env_logger::builder().is_test(true).try_init(); }
71    } else {
72        quote! {}
73    };
74
75    let mut rt = quote! {
76        photonio::runtime::Builder::new()
77    };
78    if let Some(v) = opts.num_threads {
79        rt = quote! { #rt.num_threads(#v) }
80    }
81
82    func.sig.asyncness = None;
83    let block = func.block;
84    func.block = syn::parse2(quote! {
85        {
86            #init;
87            let block = async #block;
88            #rt.build().expect("failed to build runtime").block_on(block)
89        }
90    })
91    .unwrap();
92
93    quote! {
94        #head
95        #func
96    }
97    .into()
98}
99
/// Arguments accepted by the `main` and `test` attribute macros.
///
/// The default value has no thread count and the logger setup disabled.
#[derive(Default)]
struct Options {
    /// Thread count passed to the runtime builder's `num_threads`, if given.
    num_threads: Option<usize>,
    /// Internal option for tests: emit an `env_logger` initialization.
    env_logger: bool,
}
106
107type Attributes = syn::punctuated::Punctuated<syn::MetaNameValue, syn::Token![,]>;
108
109impl Options {
110    fn parse(input: TokenStream) -> Result<Self, syn::Error> {
111        let mut opts = Options::default();
112        let attrs = Attributes::parse_terminated.parse(input)?;
113        for attr in attrs {
114            let name = attr
115                .path
116                .get_ident()
117                .ok_or_else(|| syn::Error::new_spanned(&attr, "missing attribute name"))?
118                .to_string();
119            match name.as_str() {
120                "num_threads" => {
121                    opts.num_threads = Some(parse_int(&attr.lit)?);
122                }
123                "env_logger" => {
124                    opts.env_logger = true;
125                }
126                _ => return Err(syn::Error::new_spanned(&attr, "unknown attribute name")),
127            }
128        }
129        Ok(opts)
130    }
131}
132
133fn parse_int(lit: &syn::Lit) -> Result<usize, syn::Error> {
134    if let syn::Lit::Int(i) = lit {
135        if let Ok(v) = i.base10_parse() {
136            return Ok(v);
137        }
138    }
139    Err(syn::Error::new(lit.span(), "failed to parse int"))
140}
141
142fn token_stream_with_error(mut item: TokenStream, error: syn::Error) -> TokenStream {
143    item.extend(TokenStream::from(error.into_compile_error()));
144    item
145}