prest_init_macro/
lib.rs

1#![allow(dead_code)]
2// fork of the tokio main macro
3
4use proc_macro2::{Span, TokenStream, TokenTree};
5use quote::{quote, quote_spanned, ToTokens};
6use syn::parse::{Parse, ParseStream, Parser};
7use syn::{braced, Attribute, Ident, Path, Signature, Visibility};
8
9// syn::AttributeArgs does not implement syn::Parse
10type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
11
12#[derive(Debug, Default)]
13struct Config {
14    log_filters: Vec<(String, String)>,
15    manifest: Manifest,
16    tables: Vec<Ident>,
17}
18
19#[proc_macro_attribute]
20pub fn init(
21    args: proc_macro::TokenStream,
22    item: proc_macro::TokenStream,
23) -> proc_macro::TokenStream {
24    init_pc2(args.into(), item.into()).into()
25}
26
27pub(crate) fn init_pc2(args: TokenStream, item: TokenStream) -> TokenStream {
28    // If any of the steps for this macro fail, we still want to expand to an item that is as close
29    // to the expected output as possible. This helps out IDEs such that completions and other
30    // related features keep working.
31    let input: ItemFn = match syn::parse2(item.clone()) {
32        Ok(it) => it,
33        Err(e) => return token_stream_with_error(item, e),
34    };
35
36    if input.sig.ident != "main" || !input.sig.inputs.is_empty() {
37        let msg = "init macro should be only used on the main function without arguments";
38        let e = syn::Error::new_spanned(&input.sig.ident, msg);
39        return token_stream_with_error(expand(input, Default::default()), e);
40    }
41
42    let config = AttributeArgs::parse_terminated
43        .parse2(args)
44        .and_then(|args| build_config(&input, args));
45
46    match config {
47        Ok(config) => expand(input, config),
48        Err(e) => token_stream_with_error(expand(input, Default::default()), e),
49    }
50}
51
52fn build_config(input: &ItemFn, args: AttributeArgs) -> Result<Config, syn::Error> {
53    if input.sig.asyncness.is_none() {
54        let msg = "the `async` keyword is missing from the function declaration";
55        return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
56    }
57
58    // parse all source files in search for Storage derivations
59
60    let mut log_filters = vec![];
61
62    for arg in args {
63        match arg {
64            syn::Meta::NameValue(namevalue) => {
65                let ident = namevalue
66                    .path
67                    .get_ident()
68                    .ok_or_else(|| {
69                        syn::Error::new_spanned(&namevalue, "Must have specified ident")
70                    })?
71                    .to_string()
72                    .to_lowercase();
73                match ident.as_str() {
74                    "log_filters" => {
75                        let args = match &namevalue.value {
76                            syn::Expr::Array(arr) => arr,
77                            expr => {
78                                return Err(syn::Error::new_spanned(
79                                    expr,
80                                    "Must be an array of tuples",
81                                ))
82                            }
83                        };
84                        for arg in args.elems.iter() {
85                            let tuple = match arg {
86                                syn::Expr::Tuple(tuple) => tuple,
87                                arg => return Err(syn::Error::new_spanned(arg, "Must be a tuple")),
88                            };
89                            let mut tuple = tuple.elems.iter();
90                            let filter = match tuple.next() {
91                                Some(syn::Expr::Lit(syn::ExprLit { lit, .. })) => lit,
92                                Some(v) => {
93                                    return Err(syn::Error::new_spanned(v, "Must be a literal"))
94                                }
95                                None => {
96                                    return Err(syn::Error::new_spanned(arg, "Missing log value"))
97                                }
98                            };
99                            let filter = parse_string(
100                                filter.clone(),
101                                syn::spanned::Spanned::span(filter),
102                                "log",
103                            )?;
104
105                            let level = match tuple.next() {
106                                Some(syn::Expr::Lit(syn::ExprLit { lit, .. })) => lit,
107                                Some(v) => {
108                                    return Err(syn::Error::new_spanned(v, "Must be a literal"))
109                                }
110                                None => {
111                                    return Err(syn::Error::new_spanned(arg, "Missing log value"))
112                                }
113                            };
114                            let level = parse_string(
115                                level.clone(),
116                                syn::spanned::Spanned::span(level),
117                                "filter",
118                            )?;
119
120                            if tuple.next().is_some() {
121                                return Err(syn::Error::new_spanned(
122                                    arg,
123                                    "Unexpected 3rd tuple item",
124                                ));
125                            }
126
127                            log_filters.push((filter, level));
128                        }
129                    }
130                    name => {
131                        let msg = format!(
132                            "Unknown attribute {name} is specified; expected `log_filters`",
133                        );
134                        return Err(syn::Error::new_spanned(namevalue, msg));
135                    }
136                }
137            }
138            other => {
139                return Err(syn::Error::new_spanned(
140                    other,
141                    "Unknown attribute inside the macro",
142                ));
143            }
144        }
145    }
146
147    let manifest = get_manifest();
148
149    use std::{fs, io};
150    fn find_tables(dir: fs::ReadDir, tables: &mut Vec<String>) -> io::Result<()> {
151        for file in dir {
152            let file = file?;
153            if file.file_name().to_string_lossy() == "target" {
154                continue;
155            }
156            match file.metadata()? {
157                data if data.is_dir() => find_tables(fs::read_dir(file.path())?, tables)?,
158                _ => {
159                    let content = std::fs::read_to_string(file.path())?;
160                    let mut expecting = false;
161                    for line in content.lines() {
162                        if expecting
163                            && (line.starts_with("pub") || line.starts_with("struct"))
164                            && line.contains("struct")
165                        {
166                            let struct_to_end = line.split("struct ").nth(1).unwrap();
167                            let struct_name = struct_to_end.split(" ").nth(0).unwrap();
168                            tables.push(struct_name.to_owned());
169                            expecting = false;
170                        }
171                        if line.starts_with("#[derive(") && line.contains("Storage") {
172                            expecting = true;
173                        }
174                    }
175                }
176            };
177        }
178        Ok(())
179    }
180
181    let mut tables = vec![];
182    find_tables(fs::read_dir(&manifest.manifest_dir).unwrap(), &mut tables)
183        .expect("Tables search must succeed");
184    let tables = tables.into_iter().map(|t| ident(&t)).collect();
185
186    Ok(Config {
187        log_filters,
188        manifest,
189        tables,
190    })
191}
192
193fn expand(mut input: ItemFn, config: Config) -> TokenStream {
194    input.sig.asyncness = None;
195
196    // If type mismatch occurs, the current rustc points to the last statement.
197    // let (last_stmt_start_span, last_stmt_end_span) = {
198    let last_stmt_start_span = {
199        let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter();
200
201        // `Span` on stable Rust has a limitation that only points to the first
202        // token, not the whole tokens. We can work around this limitation by
203        // using the first/last span of the tokens like
204        // `syn::Error::new_spanned` does.
205        let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
206        // let end = last_stmt.last().map_or(start, |t| t.span());
207        // (start, end)
208        start
209    };
210
211    let body_ident = quote! { body };
212
213    let rt = quote_spanned! {last_stmt_start_span=>
214        #[allow(clippy::expect_used, clippy::diverging_sub_expression, clippy::needless_return)]
215        return prest::RT.block_on(#body_ident);
216    };
217
218    let Manifest {
219        name,
220        version,
221        manifest_dir,
222        persistent,
223        domain,
224    } = config.manifest;
225
226    let domain = match domain {
227        Some(v) => quote!( Some(#v) ),
228        None => quote!(None),
229    };
230    let init_config = quote!(
231        prest::APP_CONFIG._init(#manifest_dir, #name, #version, #persistent, #domain)
232    );
233
234    let filters = config.log_filters.into_iter().map(|(filter, level)| {
235        let level = ident(&level.to_ascii_uppercase());
236        quote!((#filter, prest::logs::Level::#level))
237    });
238
239    let init_tracing = quote!(
240        let __________ = std::thread::spawn(|| prest::logs::init_tracing_subscriber(&[ #(#filters ,)* ]))
241    );
242
243    let register_tables = config
244        .tables
245        .into_iter()
246        .map(|table| quote!( prest::DB._register_schema(#table::schema()); ));
247
248    let body = input.body();
249    let body = quote! {
250        let _start = std::time::Instant::now();
251        #init_config;
252        #init_tracing;
253        prest::Lazy::force(&prest::RT);
254        let _ = prest::dotenv();
255        std::thread::spawn(|| {
256            prest::Lazy::force(&prest::SYSTEM_INFO);
257        });
258        std::thread::spawn(|| {
259            prest::Lazy::force(&prest::DB);
260            #(#register_tables)*
261        });
262        prest::RT.block_on(async {
263            prest::DB.migrate().await.expect("DB migration should be successful");
264        });
265        prest::info!(target: "prest", "Initialized {} v{} in {}ms", APP_CONFIG.name, &APP_CONFIG.version, _start.elapsed().as_millis());
266        prest::RT.set_ready();
267        let body = async #body;
268    };
269
270    input.into_tokens(body, rt)
271}
272
273fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
274    match int {
275        syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
276            Ok(value) => Ok(value),
277            Err(e) => Err(syn::Error::new(
278                span,
279                format!("Failed to parse value of `{field}` as integer: {e}"),
280            )),
281        },
282        _ => Err(syn::Error::new(
283            span,
284            format!("Failed to parse value of `{field}` as integer."),
285        )),
286    }
287}
288
289fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
290    match int {
291        syn::Lit::Str(s) => Ok(s.value()),
292        syn::Lit::Verbatim(s) => Ok(s.to_string()),
293        _ => Err(syn::Error::new(
294            span,
295            format!("Failed to parse value of `{field}` as string."),
296        )),
297    }
298}
299
300fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error> {
301    match lit {
302        syn::Lit::Str(s) => {
303            let err = syn::Error::new(
304                span,
305                format!(
306                    "Failed to parse value of `{}` as path: \"{}\"",
307                    field,
308                    s.value()
309                ),
310            );
311            s.parse::<syn::Path>().map_err(|_| err.clone())
312        }
313        _ => Err(syn::Error::new(
314            span,
315            format!("Failed to parse value of `{field}` as path."),
316        )),
317    }
318}
319
320fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
321    match bool {
322        syn::Lit::Bool(b) => Ok(b.value),
323        _ => Err(syn::Error::new(
324            span,
325            format!("Failed to parse value of `{field}` as bool."),
326        )),
327    }
328}
329
330fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
331    tokens.extend(error.into_compile_error());
332    tokens
333}
334
/// Crate metadata passed to `prest::APP_CONFIG._init` by the generated `main`.
#[derive(Default, Debug)]
struct Manifest {
    /// Value of `CARGO_PKG_NAME`.
    name: String,
    /// Value of `CARGO_PKG_VERSION`.
    version: String,
    /// Value of `CARGO_MANIFEST_DIR`; also the root of the `Storage` table search.
    manifest_dir: String,
    /// `package.metadata.persistent` from `Cargo.toml`.
    /// `get_manifest` falls back to `true` when it is missing.
    persistent: bool,
    /// `package.metadata.domain` from `Cargo.toml`, if present.
    domain: Option<String>,
}
343
344fn get_manifest() -> Manifest {
345    let name = std::env::var("CARGO_PKG_NAME").unwrap();
346    let version = std::env::var("CARGO_PKG_VERSION").unwrap();
347
348    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
349    let manifest = std::fs::read_to_string(format!("{manifest_dir}/Cargo.toml")).unwrap();
350    let parsed = manifest.parse::<toml::Table>().unwrap();
351    let metadata = parsed.get("package").map(|t| t.get("metadata")).flatten();
352
353    let persistent = metadata
354        .map(|cfgs| cfgs.get("persistent").map(|v| v.as_bool()))
355        .flatten()
356        .flatten()
357        .unwrap_or(true);
358
359    let domain = metadata
360        .map(|cfgs| {
361            cfgs.get("domain")
362                .map(|v| v.as_str().map(ToString::to_string))
363        })
364        .flatten()
365        .flatten();
366
367    Manifest {
368        name,
369        version,
370        manifest_dir,
371        persistent,
372        domain,
373    }
374}
375
376struct ItemFn {
377    outer_attrs: Vec<Attribute>,
378    vis: Visibility,
379    sig: Signature,
380    brace_token: syn::token::Brace,
381    inner_attrs: Vec<Attribute>,
382    stmts: Vec<proc_macro2::TokenStream>,
383}
384
385impl ItemFn {
386    /// Get the body of the function item in a manner so that it can be
387    /// conveniently used with the `quote!` macro.
388    fn body(&self) -> Body<'_> {
389        Body {
390            brace_token: self.brace_token,
391            stmts: &self.stmts,
392        }
393    }
394
395    /// Convert our local function item into a token stream.
396    fn into_tokens(
397        self,
398        body: proc_macro2::TokenStream,
399        last_block: proc_macro2::TokenStream,
400    ) -> TokenStream {
401        let mut tokens = proc_macro2::TokenStream::new();
402        // Outer attributes are simply streamed as-is.
403        for attr in self.outer_attrs {
404            attr.to_tokens(&mut tokens);
405        }
406
407        // Inner attributes require extra care, since they're not supported on
408        // blocks (which is what we're expanded into) we instead lift them
409        // outside of the function. This matches the behavior of `syn`.
410        for mut attr in self.inner_attrs {
411            attr.style = syn::AttrStyle::Outer;
412            attr.to_tokens(&mut tokens);
413        }
414
415        self.vis.to_tokens(&mut tokens);
416        self.sig.to_tokens(&mut tokens);
417
418        self.brace_token.surround(&mut tokens, |tokens| {
419            body.to_tokens(tokens);
420            last_block.to_tokens(tokens);
421        });
422
423        tokens
424    }
425}
426
427impl Parse for ItemFn {
428    #[inline]
429    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
430        // This parse implementation has been largely lifted from `syn`, with
431        // the exception of:
432        // * We don't have access to the plumbing necessary to parse inner
433        //   attributes in-place.
434        // * We do our own statements parsing to avoid recursively parsing
435        //   entire statements and only look for the parts we're interested in.
436
437        let outer_attrs = input.call(Attribute::parse_outer)?;
438        let vis: Visibility = input.parse()?;
439        let sig: Signature = input.parse()?;
440
441        let content;
442        let brace_token = braced!(content in input);
443        let inner_attrs = Attribute::parse_inner(&content)?;
444
445        let mut buf = proc_macro2::TokenStream::new();
446        let mut stmts = Vec::new();
447
448        while !content.is_empty() {
449            if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? {
450                semi.to_tokens(&mut buf);
451                stmts.push(buf);
452                buf = proc_macro2::TokenStream::new();
453                continue;
454            }
455
456            // Parse a single token tree and extend our current buffer with it.
457            // This avoids parsing the entire content of the sub-tree.
458            buf.extend([content.parse::<TokenTree>()?]);
459        }
460
461        if !buf.is_empty() {
462            stmts.push(buf);
463        }
464
465        Ok(Self {
466            outer_attrs,
467            vis,
468            sig,
469            brace_token,
470            inner_attrs,
471            stmts,
472        })
473    }
474}
475
476struct Body<'a> {
477    brace_token: syn::token::Brace,
478    // Statements, with terminating `;`.
479    stmts: &'a [TokenStream],
480}
481
482impl ToTokens for Body<'_> {
483    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
484        self.brace_token.surround(tokens, |tokens| {
485            for stmt in self.stmts {
486                stmt.to_tokens(tokens);
487            }
488        });
489    }
490}
491
492fn ident(name: &str) -> Ident {
493    Ident::new(name, Span::call_site())
494}