1#![allow(dead_code)]
2use proc_macro2::{Span, TokenStream, TokenTree};
5use quote::{quote, quote_spanned, ToTokens};
6use syn::parse::{Parse, ParseStream, Parser};
7use syn::{braced, Attribute, Ident, Path, Signature, Visibility};
8
/// Comma-separated `syn::Meta` items written inside `#[init(...)]`.
type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
11
/// Everything `build_config` collects and `expand` turns into the generated `main`.
#[derive(Debug, Default)]
struct Config {
    /// `(filter, level)` pairs from the `log_filters = [...]` attribute argument.
    log_filters: Vec<(String, String)>,
    /// Package information read by `get_manifest`.
    manifest: Manifest,
    /// Structs found deriving `Table`; each one is registered via `_register_table(T::schema())`.
    tables: Vec<Ident>,
}
18
19#[proc_macro_attribute]
20pub fn init(
21 args: proc_macro::TokenStream,
22 item: proc_macro::TokenStream,
23) -> proc_macro::TokenStream {
24 init_pc2(args.into(), item.into()).into()
25}
26
27pub(crate) fn init_pc2(args: TokenStream, item: TokenStream) -> TokenStream {
28 let input: ItemFn = match syn::parse2(item.clone()) {
32 Ok(it) => it,
33 Err(e) => return token_stream_with_error(item, e),
34 };
35
36 if input.sig.ident != "main" || !input.sig.inputs.is_empty() {
37 let msg = "init macro should be only used on the main function without arguments";
38 let e = syn::Error::new_spanned(&input.sig.ident, msg);
39 return token_stream_with_error(expand(input, Default::default()), e);
40 }
41
42 let config = AttributeArgs::parse_terminated
43 .parse2(args)
44 .and_then(|args| build_config(&input, args));
45
46 match config {
47 Ok(config) => expand(input, config),
48 Err(e) => token_stream_with_error(expand(input, Default::default()), e),
49 }
50}
51
52fn build_config(input: &ItemFn, args: AttributeArgs) -> Result<Config, syn::Error> {
53 if input.sig.asyncness.is_none() {
54 let msg = "the `async` keyword is missing from the function declaration";
55 return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
56 }
57
58 let mut log_filters = vec![];
61
62 for arg in args {
63 match arg {
64 syn::Meta::NameValue(namevalue) => {
65 let ident = namevalue
66 .path
67 .get_ident()
68 .ok_or_else(|| {
69 syn::Error::new_spanned(&namevalue, "Must have specified ident")
70 })?
71 .to_string()
72 .to_lowercase();
73 match ident.as_str() {
74 "log_filters" => {
75 let args = match &namevalue.value {
76 syn::Expr::Array(arr) => arr,
77 expr => {
78 return Err(syn::Error::new_spanned(
79 expr,
80 "Must be an array of tuples",
81 ))
82 }
83 };
84 for arg in args.elems.iter() {
85 let tuple = match arg {
86 syn::Expr::Tuple(tuple) => tuple,
87 arg => return Err(syn::Error::new_spanned(arg, "Must be a tuple")),
88 };
89 let mut tuple = tuple.elems.iter();
90 let filter = match tuple.next() {
91 Some(syn::Expr::Lit(syn::ExprLit { lit, .. })) => lit,
92 Some(v) => {
93 return Err(syn::Error::new_spanned(v, "Must be a literal"))
94 }
95 None => {
96 return Err(syn::Error::new_spanned(arg, "Missing log value"))
97 }
98 };
99 let filter = parse_string(
100 filter.clone(),
101 syn::spanned::Spanned::span(filter),
102 "log",
103 )?;
104
105 let level = match tuple.next() {
106 Some(syn::Expr::Lit(syn::ExprLit { lit, .. })) => lit,
107 Some(v) => {
108 return Err(syn::Error::new_spanned(v, "Must be a literal"))
109 }
110 None => {
111 return Err(syn::Error::new_spanned(arg, "Missing log value"))
112 }
113 };
114 let level = parse_string(
115 level.clone(),
116 syn::spanned::Spanned::span(level),
117 "filter",
118 )?;
119
120 if tuple.next().is_some() {
121 return Err(syn::Error::new_spanned(
122 arg,
123 "Unexpected 3rd tuple item",
124 ));
125 }
126
127 log_filters.push((filter, level));
128 }
129 }
130 name => {
131 let msg = format!(
132 "Unknown attribute {name} is specified; expected `log_filters`",
133 );
134 return Err(syn::Error::new_spanned(namevalue, msg));
135 }
136 }
137 }
138 other => {
139 return Err(syn::Error::new_spanned(
140 other,
141 "Unknown attribute inside the macro",
142 ));
143 }
144 }
145 }
146
147 let manifest = get_manifest();
148
149 use std::{fs, io};
150 fn find_tables(dir: fs::ReadDir, tables: &mut Vec<String>) -> io::Result<()> {
151 for file in dir {
152 let file = file?;
153 if file.file_name().to_string_lossy() == "target" {
154 continue;
155 }
156 match file.metadata()? {
157 data if data.is_dir() => find_tables(fs::read_dir(file.path())?, tables)?,
158 _ => {
159 let content = std::fs::read_to_string(file.path())?;
160 let mut expecting = false;
161 for line in content.lines() {
162 if expecting
163 && (line.starts_with("pub") || line.starts_with("struct"))
164 && line.contains("struct")
165 {
166 let struct_to_end = line.split("struct ").nth(1).unwrap();
167 let struct_name = struct_to_end.split(" ").nth(0).unwrap();
168 tables.push(struct_name.to_owned());
169 expecting = false;
170 }
171 if line.starts_with("#[derive(") && line.contains("Table") {
172 expecting = true;
173 }
174 }
175 }
176 };
177 }
178 Ok(())
179 }
180
181 let mut tables = vec![];
182 find_tables(fs::read_dir(&manifest.manifest_dir).unwrap(), &mut tables)
183 .expect("Tables search must succeed");
184 let tables = tables.into_iter().map(|t| ident(&t)).collect();
185
186 Ok(Config {
187 log_filters,
188 manifest,
189 tables,
190 })
191}
192
193fn expand(mut input: ItemFn, config: Config) -> TokenStream {
194 input.sig.asyncness = None;
195
196 let last_stmt_start_span = {
199 let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter();
200
201 let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
206 start
209 };
210
211 let body_ident = quote! { body };
212
213 let rt = quote_spanned! {last_stmt_start_span=>
214 #[allow(clippy::expect_used, clippy::diverging_sub_expression, clippy::needless_return)]
215 return prest::RT.block_on(#body_ident);
216 };
217
218 let Manifest {
219 name,
220 version,
221 manifest_dir,
222 persistent,
223 domain,
224 } = config.manifest;
225
226 let domain = match domain {
227 Some(v) => quote!( Some(#v) ),
228 None => quote!(None),
229 };
230 let init_config = quote!(
231 prest::APP_CONFIG._init(#manifest_dir, #name, #version, #persistent, #domain)
232 );
233
234 let filters = config.log_filters.into_iter().map(|(filter, level)| {
235 let level = ident(&level.to_ascii_uppercase());
236 quote!((#filter, prest::logs::Level::#level))
237 });
238
239 let init_tracing = quote!(
240 let __________ = prest::logs::init_tracing_subscriber(&[ #(#filters ,)* ])
241 );
242
243 let register_tables = config
244 .tables
245 .into_iter()
246 .map(|table| quote!( prest::DB._register_table(#table::schema()); ));
247
248 let body = input.body();
249 let body = quote! {
250 let _start = std::time::Instant::now();
251 #init_config;
252 #init_tracing;
253 prest::Lazy::force(&prest::RT);
254 let _ = prest::dotenv();
255 prest::Lazy::force(&prest::SYSTEM_INFO);
256 prest::Lazy::force(&prest::DB);
257 #(#register_tables)*
258 prest::RT.block_on(async {
259 prest::DB.migrate().await.expect("DB migration should be successful");
260 });
261 prest::info!(target: "prest", "Initialized {} v{} in {}ms", APP_CONFIG.name, &APP_CONFIG.version, _start.elapsed().as_millis());
262 prest::RT.set_ready();
263 let body = async #body;
264 };
265
266 input.into_tokens(body, rt)
267}
268
269fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
270 match int {
271 syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
272 Ok(value) => Ok(value),
273 Err(e) => Err(syn::Error::new(
274 span,
275 format!("Failed to parse value of `{field}` as integer: {e}"),
276 )),
277 },
278 _ => Err(syn::Error::new(
279 span,
280 format!("Failed to parse value of `{field}` as integer."),
281 )),
282 }
283}
284
285fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
286 match int {
287 syn::Lit::Str(s) => Ok(s.value()),
288 syn::Lit::Verbatim(s) => Ok(s.to_string()),
289 _ => Err(syn::Error::new(
290 span,
291 format!("Failed to parse value of `{field}` as string."),
292 )),
293 }
294}
295
296fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error> {
297 match lit {
298 syn::Lit::Str(s) => {
299 let err = syn::Error::new(
300 span,
301 format!(
302 "Failed to parse value of `{}` as path: \"{}\"",
303 field,
304 s.value()
305 ),
306 );
307 s.parse::<syn::Path>().map_err(|_| err.clone())
308 }
309 _ => Err(syn::Error::new(
310 span,
311 format!("Failed to parse value of `{field}` as path."),
312 )),
313 }
314}
315
316fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
317 match bool {
318 syn::Lit::Bool(b) => Ok(b.value),
319 _ => Err(syn::Error::new(
320 span,
321 format!("Failed to parse value of `{field}` as bool."),
322 )),
323 }
324}
325
326fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
327 tokens.extend(error.into_compile_error());
328 tokens
329}
330
/// Package information about the crate using the macro, baked into the generated `main`.
#[derive(Debug, Default)]
struct Manifest {
    /// Value of `CARGO_PKG_NAME`.
    name: String,
    /// Value of `CARGO_PKG_VERSION`.
    version: String,
    /// Value of `CARGO_MANIFEST_DIR`; also the root directory of the table scan.
    manifest_dir: String,
    /// `package.metadata.persistent` from `Cargo.toml`; `get_manifest` falls back to `true`
    /// when the key is absent or not a boolean.
    persistent: bool,
    /// `package.metadata.domain` from `Cargo.toml`, if present and a string.
    domain: Option<String>,
}
339
340fn get_manifest() -> Manifest {
341 let name = std::env::var("CARGO_PKG_NAME").unwrap();
342 let version = std::env::var("CARGO_PKG_VERSION").unwrap();
343
344 let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
345 let manifest = std::fs::read_to_string(format!("{manifest_dir}/Cargo.toml")).unwrap();
346 let parsed = manifest.parse::<toml::Table>().unwrap();
347 let metadata = parsed.get("package").map(|t| t.get("metadata")).flatten();
348
349 let persistent = metadata
350 .map(|cfgs| cfgs.get("persistent").map(|v| v.as_bool()))
351 .flatten()
352 .flatten()
353 .unwrap_or(true);
354
355 let domain = metadata
356 .map(|cfgs| {
357 cfgs.get("domain")
358 .map(|v| v.as_str().map(ToString::to_string))
359 })
360 .flatten()
361 .flatten();
362
363 Manifest {
364 name,
365 version,
366 manifest_dir,
367 persistent,
368 domain,
369 }
370}
371
/// A function item parsed loosely: the body is kept as raw token streams split at
/// top-level `;` tokens rather than parsed into `syn::Stmt`s.
struct ItemFn {
    /// `#[...]` attributes written before the function.
    outer_attrs: Vec<Attribute>,
    vis: Visibility,
    sig: Signature,
    /// Braces of the original body, reused when the function is re-emitted.
    brace_token: syn::token::Brace,
    /// `#![...]` attributes at the top of the body; `into_tokens` re-emits them as outer attributes.
    inner_attrs: Vec<Attribute>,
    /// Body chunks, each ending in a top-level `;`, plus any trailing chunk without one.
    stmts: Vec<proc_macro2::TokenStream>,
}
380
381impl ItemFn {
382 fn body(&self) -> Body<'_> {
385 Body {
386 brace_token: self.brace_token,
387 stmts: &self.stmts,
388 }
389 }
390
391 fn into_tokens(
393 self,
394 body: proc_macro2::TokenStream,
395 last_block: proc_macro2::TokenStream,
396 ) -> TokenStream {
397 let mut tokens = proc_macro2::TokenStream::new();
398 for attr in self.outer_attrs {
400 attr.to_tokens(&mut tokens);
401 }
402
403 for mut attr in self.inner_attrs {
407 attr.style = syn::AttrStyle::Outer;
408 attr.to_tokens(&mut tokens);
409 }
410
411 self.vis.to_tokens(&mut tokens);
412 self.sig.to_tokens(&mut tokens);
413
414 self.brace_token.surround(&mut tokens, |tokens| {
415 body.to_tokens(tokens);
416 last_block.to_tokens(tokens);
417 });
418
419 tokens
420 }
421}
422
423impl Parse for ItemFn {
424 #[inline]
425 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
426 let outer_attrs = input.call(Attribute::parse_outer)?;
434 let vis: Visibility = input.parse()?;
435 let sig: Signature = input.parse()?;
436
437 let content;
438 let brace_token = braced!(content in input);
439 let inner_attrs = Attribute::parse_inner(&content)?;
440
441 let mut buf = proc_macro2::TokenStream::new();
442 let mut stmts = Vec::new();
443
444 while !content.is_empty() {
445 if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? {
446 semi.to_tokens(&mut buf);
447 stmts.push(buf);
448 buf = proc_macro2::TokenStream::new();
449 continue;
450 }
451
452 buf.extend([content.parse::<TokenTree>()?]);
455 }
456
457 if !buf.is_empty() {
458 stmts.push(buf);
459 }
460
461 Ok(Self {
462 outer_attrs,
463 vis,
464 sig,
465 brace_token,
466 inner_attrs,
467 stmts,
468 })
469 }
470}
471
/// Borrowed view of an `ItemFn` body: its statement chunks plus the braces they came from.
struct Body<'a> {
    brace_token: syn::token::Brace,
    stmts: &'a [TokenStream],
}
477
478impl ToTokens for Body<'_> {
479 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
480 self.brace_token.surround(tokens, |tokens| {
481 for stmt in self.stmts {
482 stmt.to_tokens(tokens);
483 }
484 });
485 }
486}
487
488fn ident(name: &str) -> Ident {
489 Ident::new(name, Span::call_site())
490}