prest_init_macro/
lib.rs

1#![allow(dead_code)]
2// fork of the tokio main macro
3
4use proc_macro2::{Span, TokenStream, TokenTree};
5use quote::{quote, quote_spanned, ToTokens};
6use syn::parse::{Parse, ParseStream, Parser};
7use syn::{braced, Attribute, Ident, Path, Signature, Visibility};
8
9// syn::AttributeArgs does not implement syn::Parse
10type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
11
12#[derive(Debug, Default)]
13struct Config {
14    log_filters: Vec<(String, String)>,
15    manifest: Manifest,
16    tables: Vec<Ident>,
17}
18
19#[proc_macro_attribute]
20pub fn init(
21    args: proc_macro::TokenStream,
22    item: proc_macro::TokenStream,
23) -> proc_macro::TokenStream {
24    init_pc2(args.into(), item.into()).into()
25}
26
27pub(crate) fn init_pc2(args: TokenStream, item: TokenStream) -> TokenStream {
28    // If any of the steps for this macro fail, we still want to expand to an item that is as close
29    // to the expected output as possible. This helps out IDEs such that completions and other
30    // related features keep working.
31    let input: ItemFn = match syn::parse2(item.clone()) {
32        Ok(it) => it,
33        Err(e) => return token_stream_with_error(item, e),
34    };
35
36    if input.sig.ident != "main" || !input.sig.inputs.is_empty() {
37        let msg = "init macro should be only used on the main function without arguments";
38        let e = syn::Error::new_spanned(&input.sig.ident, msg);
39        return token_stream_with_error(expand(input, Default::default()), e);
40    }
41
42    let config = AttributeArgs::parse_terminated
43        .parse2(args)
44        .and_then(|args| build_config(&input, args));
45
46    match config {
47        Ok(config) => expand(input, config),
48        Err(e) => token_stream_with_error(expand(input, Default::default()), e),
49    }
50}
51
52fn build_config(input: &ItemFn, args: AttributeArgs) -> Result<Config, syn::Error> {
53    if input.sig.asyncness.is_none() {
54        let msg = "the `async` keyword is missing from the function declaration";
55        return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
56    }
57
58    // parse all source files in search for Table derivations
59
60    let mut log_filters = vec![];
61
62    for arg in args {
63        match arg {
64            syn::Meta::NameValue(namevalue) => {
65                let ident = namevalue
66                    .path
67                    .get_ident()
68                    .ok_or_else(|| {
69                        syn::Error::new_spanned(&namevalue, "Must have specified ident")
70                    })?
71                    .to_string()
72                    .to_lowercase();
73                match ident.as_str() {
74                    "log_filters" => {
75                        let args = match &namevalue.value {
76                            syn::Expr::Array(arr) => arr,
77                            expr => {
78                                return Err(syn::Error::new_spanned(
79                                    expr,
80                                    "Must be an array of tuples",
81                                ))
82                            }
83                        };
84                        for arg in args.elems.iter() {
85                            let tuple = match arg {
86                                syn::Expr::Tuple(tuple) => tuple,
87                                arg => return Err(syn::Error::new_spanned(arg, "Must be a tuple")),
88                            };
89                            let mut tuple = tuple.elems.iter();
90                            let filter = match tuple.next() {
91                                Some(syn::Expr::Lit(syn::ExprLit { lit, .. })) => lit,
92                                Some(v) => {
93                                    return Err(syn::Error::new_spanned(v, "Must be a literal"))
94                                }
95                                None => {
96                                    return Err(syn::Error::new_spanned(arg, "Missing log value"))
97                                }
98                            };
99                            let filter = parse_string(
100                                filter.clone(),
101                                syn::spanned::Spanned::span(filter),
102                                "log",
103                            )?;
104
105                            let level = match tuple.next() {
106                                Some(syn::Expr::Lit(syn::ExprLit { lit, .. })) => lit,
107                                Some(v) => {
108                                    return Err(syn::Error::new_spanned(v, "Must be a literal"))
109                                }
110                                None => {
111                                    return Err(syn::Error::new_spanned(arg, "Missing log value"))
112                                }
113                            };
114                            let level = parse_string(
115                                level.clone(),
116                                syn::spanned::Spanned::span(level),
117                                "filter",
118                            )?;
119
120                            if tuple.next().is_some() {
121                                return Err(syn::Error::new_spanned(
122                                    arg,
123                                    "Unexpected 3rd tuple item",
124                                ));
125                            }
126
127                            log_filters.push((filter, level));
128                        }
129                    }
130                    name => {
131                        let msg = format!(
132                            "Unknown attribute {name} is specified; expected `log_filters`",
133                        );
134                        return Err(syn::Error::new_spanned(namevalue, msg));
135                    }
136                }
137            }
138            other => {
139                return Err(syn::Error::new_spanned(
140                    other,
141                    "Unknown attribute inside the macro",
142                ));
143            }
144        }
145    }
146
147    let manifest = get_manifest();
148
149    use std::{fs, io};
150    fn find_tables(dir: fs::ReadDir, tables: &mut Vec<String>) -> io::Result<()> {
151        for file in dir {
152            let file = file?;
153            if file.file_name().to_string_lossy() == "target" {
154                continue;
155            }
156            match file.metadata()? {
157                data if data.is_dir() => find_tables(fs::read_dir(file.path())?, tables)?,
158                _ => {
159                    let content = std::fs::read_to_string(file.path())?;
160                    let mut expecting = false;
161                    for line in content.lines() {
162                        if expecting
163                            && (line.starts_with("pub") || line.starts_with("struct"))
164                            && line.contains("struct")
165                        {
166                            let struct_to_end = line.split("struct ").nth(1).unwrap();
167                            let struct_name = struct_to_end.split(" ").nth(0).unwrap();
168                            tables.push(struct_name.to_owned());
169                            expecting = false;
170                        }
171                        if line.starts_with("#[derive(") && line.contains("Table") {
172                            expecting = true;
173                        }
174                    }
175                }
176            };
177        }
178        Ok(())
179    }
180
181    let mut tables = vec![];
182    find_tables(fs::read_dir(&manifest.manifest_dir).unwrap(), &mut tables)
183        .expect("Tables search must succeed");
184    let tables = tables.into_iter().map(|t| ident(&t)).collect();
185
186    Ok(Config {
187        log_filters,
188        manifest,
189        tables,
190    })
191}
192
193fn expand(mut input: ItemFn, config: Config) -> TokenStream {
194    input.sig.asyncness = None;
195
196    // If type mismatch occurs, the current rustc points to the last statement.
197    // let (last_stmt_start_span, last_stmt_end_span) = {
198    let last_stmt_start_span = {
199        let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter();
200
201        // `Span` on stable Rust has a limitation that only points to the first
202        // token, not the whole tokens. We can work around this limitation by
203        // using the first/last span of the tokens like
204        // `syn::Error::new_spanned` does.
205        let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
206        // let end = last_stmt.last().map_or(start, |t| t.span());
207        // (start, end)
208        start
209    };
210
211    let body_ident = quote! { body };
212
213    let rt = quote_spanned! {last_stmt_start_span=>
214        #[allow(clippy::expect_used, clippy::diverging_sub_expression, clippy::needless_return)]
215        return prest::RT.block_on(#body_ident);
216    };
217
218    let Manifest {
219        name,
220        version,
221        manifest_dir,
222        persistent,
223        domain,
224    } = config.manifest;
225
226    let domain = match domain {
227        Some(v) => quote!( Some(#v) ),
228        None => quote!(None),
229    };
230    let init_config = quote!(
231        prest::APP_CONFIG._init(#manifest_dir, #name, #version, #persistent, #domain)
232    );
233
234    let filters = config.log_filters.into_iter().map(|(filter, level)| {
235        let level = ident(&level.to_ascii_uppercase());
236        quote!((#filter, prest::logs::Level::#level))
237    });
238
239    let init_tracing = quote!(
240        let __________ = prest::logs::init_tracing_subscriber(&[ #(#filters ,)* ])
241    );
242
243    let register_tables = config
244        .tables
245        .into_iter()
246        .map(|table| quote!( prest::DB._register_table(#table::schema()); ));
247
248    let body = input.body();
249    let body = quote! {
250        let _start = std::time::Instant::now();
251        #init_config;
252        #init_tracing;
253        prest::Lazy::force(&prest::RT);
254        let _ = prest::dotenv();
255        prest::Lazy::force(&prest::SYSTEM_INFO);
256        prest::Lazy::force(&prest::DB);
257        #(#register_tables)*
258        prest::RT.block_on(async {
259            prest::DB.migrate().await.expect("DB migration should be successful");
260        });
261        prest::info!(target: "prest", "Initialized {} v{} in {}ms", APP_CONFIG.name, &APP_CONFIG.version, _start.elapsed().as_millis());
262        prest::RT.set_ready();
263        let body = async #body;
264    };
265
266    input.into_tokens(body, rt)
267}
268
269fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
270    match int {
271        syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
272            Ok(value) => Ok(value),
273            Err(e) => Err(syn::Error::new(
274                span,
275                format!("Failed to parse value of `{field}` as integer: {e}"),
276            )),
277        },
278        _ => Err(syn::Error::new(
279            span,
280            format!("Failed to parse value of `{field}` as integer."),
281        )),
282    }
283}
284
285fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
286    match int {
287        syn::Lit::Str(s) => Ok(s.value()),
288        syn::Lit::Verbatim(s) => Ok(s.to_string()),
289        _ => Err(syn::Error::new(
290            span,
291            format!("Failed to parse value of `{field}` as string."),
292        )),
293    }
294}
295
296fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error> {
297    match lit {
298        syn::Lit::Str(s) => {
299            let err = syn::Error::new(
300                span,
301                format!(
302                    "Failed to parse value of `{}` as path: \"{}\"",
303                    field,
304                    s.value()
305                ),
306            );
307            s.parse::<syn::Path>().map_err(|_| err.clone())
308        }
309        _ => Err(syn::Error::new(
310            span,
311            format!("Failed to parse value of `{field}` as path."),
312        )),
313    }
314}
315
316fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
317    match bool {
318        syn::Lit::Bool(b) => Ok(b.value),
319        _ => Err(syn::Error::new(
320            span,
321            format!("Failed to parse value of `{field}` as bool."),
322        )),
323    }
324}
325
326fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
327    tokens.extend(error.into_compile_error());
328    tokens
329}
330
/// Package data collected at expansion time by `get_manifest`.
///
/// Note: the derived `Default` has `persistent: false`, whereas `get_manifest` falls back to
/// `true` when the key is absent.
#[derive(Default, Debug)]
struct Manifest {
    /// Value of the `CARGO_PKG_NAME` environment variable.
    name: String,
    /// Value of the `CARGO_PKG_VERSION` environment variable.
    version: String,
    /// Value of the `CARGO_MANIFEST_DIR` environment variable (the directory of `Cargo.toml`).
    manifest_dir: String,
    /// `persistent` key of `[package.metadata]` in `Cargo.toml`.
    persistent: bool,
    /// `domain` key of `[package.metadata]` in `Cargo.toml`, if present and a string.
    domain: Option<String>,
}
339
340fn get_manifest() -> Manifest {
341    let name = std::env::var("CARGO_PKG_NAME").unwrap();
342    let version = std::env::var("CARGO_PKG_VERSION").unwrap();
343
344    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
345    let manifest = std::fs::read_to_string(format!("{manifest_dir}/Cargo.toml")).unwrap();
346    let parsed = manifest.parse::<toml::Table>().unwrap();
347    let metadata = parsed.get("package").map(|t| t.get("metadata")).flatten();
348
349    let persistent = metadata
350        .map(|cfgs| cfgs.get("persistent").map(|v| v.as_bool()))
351        .flatten()
352        .flatten()
353        .unwrap_or(true);
354
355    let domain = metadata
356        .map(|cfgs| {
357            cfgs.get("domain")
358                .map(|v| v.as_str().map(ToString::to_string))
359        })
360        .flatten()
361        .flatten();
362
363    Manifest {
364        name,
365        version,
366        manifest_dir,
367        persistent,
368        domain,
369    }
370}
371
372struct ItemFn {
373    outer_attrs: Vec<Attribute>,
374    vis: Visibility,
375    sig: Signature,
376    brace_token: syn::token::Brace,
377    inner_attrs: Vec<Attribute>,
378    stmts: Vec<proc_macro2::TokenStream>,
379}
380
381impl ItemFn {
382    /// Get the body of the function item in a manner so that it can be
383    /// conveniently used with the `quote!` macro.
384    fn body(&self) -> Body<'_> {
385        Body {
386            brace_token: self.brace_token,
387            stmts: &self.stmts,
388        }
389    }
390
391    /// Convert our local function item into a token stream.
392    fn into_tokens(
393        self,
394        body: proc_macro2::TokenStream,
395        last_block: proc_macro2::TokenStream,
396    ) -> TokenStream {
397        let mut tokens = proc_macro2::TokenStream::new();
398        // Outer attributes are simply streamed as-is.
399        for attr in self.outer_attrs {
400            attr.to_tokens(&mut tokens);
401        }
402
403        // Inner attributes require extra care, since they're not supported on
404        // blocks (which is what we're expanded into) we instead lift them
405        // outside of the function. This matches the behavior of `syn`.
406        for mut attr in self.inner_attrs {
407            attr.style = syn::AttrStyle::Outer;
408            attr.to_tokens(&mut tokens);
409        }
410
411        self.vis.to_tokens(&mut tokens);
412        self.sig.to_tokens(&mut tokens);
413
414        self.brace_token.surround(&mut tokens, |tokens| {
415            body.to_tokens(tokens);
416            last_block.to_tokens(tokens);
417        });
418
419        tokens
420    }
421}
422
423impl Parse for ItemFn {
424    #[inline]
425    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
426        // This parse implementation has been largely lifted from `syn`, with
427        // the exception of:
428        // * We don't have access to the plumbing necessary to parse inner
429        //   attributes in-place.
430        // * We do our own statements parsing to avoid recursively parsing
431        //   entire statements and only look for the parts we're interested in.
432
433        let outer_attrs = input.call(Attribute::parse_outer)?;
434        let vis: Visibility = input.parse()?;
435        let sig: Signature = input.parse()?;
436
437        let content;
438        let brace_token = braced!(content in input);
439        let inner_attrs = Attribute::parse_inner(&content)?;
440
441        let mut buf = proc_macro2::TokenStream::new();
442        let mut stmts = Vec::new();
443
444        while !content.is_empty() {
445            if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? {
446                semi.to_tokens(&mut buf);
447                stmts.push(buf);
448                buf = proc_macro2::TokenStream::new();
449                continue;
450            }
451
452            // Parse a single token tree and extend our current buffer with it.
453            // This avoids parsing the entire content of the sub-tree.
454            buf.extend([content.parse::<TokenTree>()?]);
455        }
456
457        if !buf.is_empty() {
458            stmts.push(buf);
459        }
460
461        Ok(Self {
462            outer_attrs,
463            vis,
464            sig,
465            brace_token,
466            inner_attrs,
467            stmts,
468        })
469    }
470}
471
472struct Body<'a> {
473    brace_token: syn::token::Brace,
474    // Statements, with terminating `;`.
475    stmts: &'a [TokenStream],
476}
477
478impl ToTokens for Body<'_> {
479    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
480        self.brace_token.surround(tokens, |tokens| {
481            for stmt in self.stmts {
482                stmt.to_tokens(tokens);
483            }
484        });
485    }
486}
487
488fn ident(name: &str) -> Ident {
489    Ident::new(name, Span::call_site())
490}