1#![warn(missing_docs, unreachable_pub)]
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::parse::Parser;
8
9#[proc_macro_attribute]
43pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
44 transform(attr, item, false)
45}
46
47#[proc_macro_attribute]
49pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
50 transform(attr, item, true)
51}
52
53fn transform(attr: TokenStream, item: TokenStream, is_test: bool) -> TokenStream {
54 let opts = match Options::parse(attr.clone()) {
55 Ok(opts) => opts,
56 Err(e) => return token_stream_with_error(attr, e),
57 };
58 let mut func: syn::ItemFn = match syn::parse(item.clone()) {
59 Ok(func) => func,
60 Err(e) => return token_stream_with_error(item, e),
61 };
62
63 let head = if is_test {
64 quote! { #[::std::prelude::v1::test] }
65 } else {
66 quote! {}
67 };
68
69 let init = if is_test && opts.env_logger {
70 quote! { let _ = env_logger::builder().is_test(true).try_init(); }
71 } else {
72 quote! {}
73 };
74
75 let mut rt = quote! {
76 photonio::runtime::Builder::new()
77 };
78 if let Some(v) = opts.num_threads {
79 rt = quote! { #rt.num_threads(#v) }
80 }
81
82 func.sig.asyncness = None;
83 let block = func.block;
84 func.block = syn::parse2(quote! {
85 {
86 #init;
87 let block = async #block;
88 #rt.build().expect("failed to build runtime").block_on(block)
89 }
90 })
91 .unwrap();
92
93 quote! {
94 #head
95 #func
96 }
97 .into()
98}
99
/// Arguments accepted by the `main` and `test` attribute macros.
///
/// Every field defaults to "not requested".
#[derive(Default)]
struct Options {
    /// Whether to emit `env_logger` initialisation. Only honoured for tests.
    env_logger: bool,
    /// Thread count forwarded to the runtime builder's `num_threads`, if given.
    num_threads: Option<usize>,
}
106
107type Attributes = syn::punctuated::Punctuated<syn::MetaNameValue, syn::Token![,]>;
108
109impl Options {
110 fn parse(input: TokenStream) -> Result<Self, syn::Error> {
111 let mut opts = Options::default();
112 let attrs = Attributes::parse_terminated.parse(input)?;
113 for attr in attrs {
114 let name = attr
115 .path
116 .get_ident()
117 .ok_or_else(|| syn::Error::new_spanned(&attr, "missing attribute name"))?
118 .to_string();
119 match name.as_str() {
120 "num_threads" => {
121 opts.num_threads = Some(parse_int(&attr.lit)?);
122 }
123 "env_logger" => {
124 opts.env_logger = true;
125 }
126 _ => return Err(syn::Error::new_spanned(&attr, "unknown attribute name")),
127 }
128 }
129 Ok(opts)
130 }
131}
132
133fn parse_int(lit: &syn::Lit) -> Result<usize, syn::Error> {
134 if let syn::Lit::Int(i) = lit {
135 if let Ok(v) = i.base10_parse() {
136 return Ok(v);
137 }
138 }
139 Err(syn::Error::new(lit.span(), "failed to parse int"))
140}
141
142fn token_stream_with_error(mut item: TokenStream, error: syn::Error) -> TokenStream {
143 item.extend(TokenStream::from(error.into_compile_error()));
144 item
145}