1#![allow(dead_code)]
2use proc_macro2::{Span, TokenStream, TokenTree};
5use quote::{quote, quote_spanned, ToTokens};
6use syn::parse::{Parse, ParseStream, Parser};
7use syn::{braced, Attribute, Ident, Path, Signature, Visibility};
8
9type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
11
/// Settings produced by `build_config` from the attribute arguments and consumed by `expand`.
///
/// `Default` is what `init_pc2` falls back to when validation fails, so an expansion (plus the
/// error) is still emitted.
#[derive(Debug, Default)]
struct Config {
    /// `(filter, level)` string pairs from the `log_filters` attribute argument.
    log_filters: Vec<(String, String)>,
    /// Crate metadata taken from cargo's environment and `Cargo.toml`.
    manifest: Manifest,
    /// Structs found deriving `Storage`; each one gets a
    /// `prest::DB._register_schema(<Table>::schema())` call in the generated `main`.
    tables: Vec<Ident>,
}
18
19#[proc_macro_attribute]
20pub fn init(
21 args: proc_macro::TokenStream,
22 item: proc_macro::TokenStream,
23) -> proc_macro::TokenStream {
24 init_pc2(args.into(), item.into()).into()
25}
26
27pub(crate) fn init_pc2(args: TokenStream, item: TokenStream) -> TokenStream {
28 let input: ItemFn = match syn::parse2(item.clone()) {
32 Ok(it) => it,
33 Err(e) => return token_stream_with_error(item, e),
34 };
35
36 if input.sig.ident != "main" || !input.sig.inputs.is_empty() {
37 let msg = "init macro should be only used on the main function without arguments";
38 let e = syn::Error::new_spanned(&input.sig.ident, msg);
39 return token_stream_with_error(expand(input, Default::default()), e);
40 }
41
42 let config = AttributeArgs::parse_terminated
43 .parse2(args)
44 .and_then(|args| build_config(&input, args));
45
46 match config {
47 Ok(config) => expand(input, config),
48 Err(e) => token_stream_with_error(expand(input, Default::default()), e),
49 }
50}
51
52fn build_config(input: &ItemFn, args: AttributeArgs) -> Result<Config, syn::Error> {
53 if input.sig.asyncness.is_none() {
54 let msg = "the `async` keyword is missing from the function declaration";
55 return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
56 }
57
58 let mut log_filters = vec![];
61
62 for arg in args {
63 match arg {
64 syn::Meta::NameValue(namevalue) => {
65 let ident = namevalue
66 .path
67 .get_ident()
68 .ok_or_else(|| {
69 syn::Error::new_spanned(&namevalue, "Must have specified ident")
70 })?
71 .to_string()
72 .to_lowercase();
73 match ident.as_str() {
74 "log_filters" => {
75 let args = match &namevalue.value {
76 syn::Expr::Array(arr) => arr,
77 expr => {
78 return Err(syn::Error::new_spanned(
79 expr,
80 "Must be an array of tuples",
81 ))
82 }
83 };
84 for arg in args.elems.iter() {
85 let tuple = match arg {
86 syn::Expr::Tuple(tuple) => tuple,
87 arg => return Err(syn::Error::new_spanned(arg, "Must be a tuple")),
88 };
89 let mut tuple = tuple.elems.iter();
90 let filter = match tuple.next() {
91 Some(syn::Expr::Lit(syn::ExprLit { lit, .. })) => lit,
92 Some(v) => {
93 return Err(syn::Error::new_spanned(v, "Must be a literal"))
94 }
95 None => {
96 return Err(syn::Error::new_spanned(arg, "Missing log value"))
97 }
98 };
99 let filter = parse_string(
100 filter.clone(),
101 syn::spanned::Spanned::span(filter),
102 "log",
103 )?;
104
105 let level = match tuple.next() {
106 Some(syn::Expr::Lit(syn::ExprLit { lit, .. })) => lit,
107 Some(v) => {
108 return Err(syn::Error::new_spanned(v, "Must be a literal"))
109 }
110 None => {
111 return Err(syn::Error::new_spanned(arg, "Missing log value"))
112 }
113 };
114 let level = parse_string(
115 level.clone(),
116 syn::spanned::Spanned::span(level),
117 "filter",
118 )?;
119
120 if tuple.next().is_some() {
121 return Err(syn::Error::new_spanned(
122 arg,
123 "Unexpected 3rd tuple item",
124 ));
125 }
126
127 log_filters.push((filter, level));
128 }
129 }
130 name => {
131 let msg = format!(
132 "Unknown attribute {name} is specified; expected `log_filters`",
133 );
134 return Err(syn::Error::new_spanned(namevalue, msg));
135 }
136 }
137 }
138 other => {
139 return Err(syn::Error::new_spanned(
140 other,
141 "Unknown attribute inside the macro",
142 ));
143 }
144 }
145 }
146
147 let manifest = get_manifest();
148
149 use std::{fs, io};
150 fn find_tables(dir: fs::ReadDir, tables: &mut Vec<String>) -> io::Result<()> {
151 for file in dir {
152 let file = file?;
153 if file.file_name().to_string_lossy() == "target" {
154 continue;
155 }
156 match file.metadata()? {
157 data if data.is_dir() => find_tables(fs::read_dir(file.path())?, tables)?,
158 _ => {
159 let content = std::fs::read_to_string(file.path())?;
160 let mut expecting = false;
161 for line in content.lines() {
162 if expecting
163 && (line.starts_with("pub") || line.starts_with("struct"))
164 && line.contains("struct")
165 {
166 let struct_to_end = line.split("struct ").nth(1).unwrap();
167 let struct_name = struct_to_end.split(" ").nth(0).unwrap();
168 tables.push(struct_name.to_owned());
169 expecting = false;
170 }
171 if line.starts_with("#[derive(") && line.contains("Storage") {
172 expecting = true;
173 }
174 }
175 }
176 };
177 }
178 Ok(())
179 }
180
181 let mut tables = vec![];
182 find_tables(fs::read_dir(&manifest.manifest_dir).unwrap(), &mut tables)
183 .expect("Tables search must succeed");
184 let tables = tables.into_iter().map(|t| ident(&t)).collect();
185
186 Ok(Config {
187 log_filters,
188 manifest,
189 tables,
190 })
191}
192
193fn expand(mut input: ItemFn, config: Config) -> TokenStream {
194 input.sig.asyncness = None;
195
196 let last_stmt_start_span = {
199 let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter();
200
201 let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
206 start
209 };
210
211 let body_ident = quote! { body };
212
213 let rt = quote_spanned! {last_stmt_start_span=>
214 #[allow(clippy::expect_used, clippy::diverging_sub_expression, clippy::needless_return)]
215 return prest::RT.block_on(#body_ident);
216 };
217
218 let Manifest {
219 name,
220 version,
221 manifest_dir,
222 persistent,
223 domain,
224 } = config.manifest;
225
226 let domain = match domain {
227 Some(v) => quote!( Some(#v) ),
228 None => quote!(None),
229 };
230 let init_config = quote!(
231 prest::APP_CONFIG._init(#manifest_dir, #name, #version, #persistent, #domain)
232 );
233
234 let filters = config.log_filters.into_iter().map(|(filter, level)| {
235 let level = ident(&level.to_ascii_uppercase());
236 quote!((#filter, prest::logs::Level::#level))
237 });
238
239 let init_tracing = quote!(
240 let __________ = std::thread::spawn(|| prest::logs::init_tracing_subscriber(&[ #(#filters ,)* ]))
241 );
242
243 let register_tables = config
244 .tables
245 .into_iter()
246 .map(|table| quote!( prest::DB._register_schema(#table::schema()); ));
247
248 let body = input.body();
249 let body = quote! {
250 let _start = std::time::Instant::now();
251 #init_config;
252 #init_tracing;
253 prest::Lazy::force(&prest::RT);
254 let _ = prest::dotenv();
255 std::thread::spawn(|| {
256 prest::Lazy::force(&prest::SYSTEM_INFO);
257 });
258 std::thread::spawn(|| {
259 prest::Lazy::force(&prest::DB);
260 #(#register_tables)*
261 });
262 prest::RT.block_on(async {
263 prest::DB.migrate().await.expect("DB migration should be successful");
264 });
265 prest::info!(target: "prest", "Initialized {} v{} in {}ms", APP_CONFIG.name, &APP_CONFIG.version, _start.elapsed().as_millis());
266 prest::RT.set_ready();
267 let body = async #body;
268 };
269
270 input.into_tokens(body, rt)
271}
272
273fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
274 match int {
275 syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
276 Ok(value) => Ok(value),
277 Err(e) => Err(syn::Error::new(
278 span,
279 format!("Failed to parse value of `{field}` as integer: {e}"),
280 )),
281 },
282 _ => Err(syn::Error::new(
283 span,
284 format!("Failed to parse value of `{field}` as integer."),
285 )),
286 }
287}
288
289fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
290 match int {
291 syn::Lit::Str(s) => Ok(s.value()),
292 syn::Lit::Verbatim(s) => Ok(s.to_string()),
293 _ => Err(syn::Error::new(
294 span,
295 format!("Failed to parse value of `{field}` as string."),
296 )),
297 }
298}
299
300fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error> {
301 match lit {
302 syn::Lit::Str(s) => {
303 let err = syn::Error::new(
304 span,
305 format!(
306 "Failed to parse value of `{}` as path: \"{}\"",
307 field,
308 s.value()
309 ),
310 );
311 s.parse::<syn::Path>().map_err(|_| err.clone())
312 }
313 _ => Err(syn::Error::new(
314 span,
315 format!("Failed to parse value of `{field}` as path."),
316 )),
317 }
318}
319
320fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
321 match bool {
322 syn::Lit::Bool(b) => Ok(b.value),
323 _ => Err(syn::Error::new(
324 span,
325 format!("Failed to parse value of `{field}` as bool."),
326 )),
327 }
328}
329
330fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
331 tokens.extend(error.into_compile_error());
332 tokens
333}
334
/// Crate metadata baked into the generated `prest::APP_CONFIG._init(...)` call.
#[derive(Debug, Default)]
struct Manifest {
    /// Value of `CARGO_PKG_NAME`.
    name: String,
    /// Value of `CARGO_PKG_VERSION`.
    version: String,
    /// Value of `CARGO_MANIFEST_DIR`; also the root of the `Storage` table scan.
    manifest_dir: String,
    /// `package.metadata.persistent` from `Cargo.toml`; `get_manifest` uses `true` when it is
    /// absent or not a bool. Note that the derived `Default` gives `false`.
    persistent: bool,
    /// `package.metadata.domain` from `Cargo.toml`, when present as a string.
    domain: Option<String>,
}
343
344fn get_manifest() -> Manifest {
345 let name = std::env::var("CARGO_PKG_NAME").unwrap();
346 let version = std::env::var("CARGO_PKG_VERSION").unwrap();
347
348 let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
349 let manifest = std::fs::read_to_string(format!("{manifest_dir}/Cargo.toml")).unwrap();
350 let parsed = manifest.parse::<toml::Table>().unwrap();
351 let metadata = parsed.get("package").map(|t| t.get("metadata")).flatten();
352
353 let persistent = metadata
354 .map(|cfgs| cfgs.get("persistent").map(|v| v.as_bool()))
355 .flatten()
356 .flatten()
357 .unwrap_or(true);
358
359 let domain = metadata
360 .map(|cfgs| {
361 cfgs.get("domain")
362 .map(|v| v.as_str().map(ToString::to_string))
363 })
364 .flatten()
365 .flatten();
366
367 Manifest {
368 name,
369 version,
370 manifest_dir,
371 persistent,
372 domain,
373 }
374}
375
/// A function item parsed loosely.
///
/// Attributes, visibility and signature are fully parsed. The body is kept as raw token chunks
/// (see the `Parse` impl below) so it can be re-emitted without being parsed as statements.
struct ItemFn {
    /// `#[...]` attributes written before the function.
    outer_attrs: Vec<Attribute>,
    vis: Visibility,
    sig: Signature,
    /// The body's braces, reused when the function is re-emitted.
    brace_token: syn::token::Brace,
    /// `#![...]` attributes at the top of the body.
    inner_attrs: Vec<Attribute>,
    /// Body tokens cut after each top-level `;`; trailing tokens without a `;` form the last
    /// chunk.
    stmts: Vec<proc_macro2::TokenStream>,
}
384
385impl ItemFn {
386 fn body(&self) -> Body<'_> {
389 Body {
390 brace_token: self.brace_token,
391 stmts: &self.stmts,
392 }
393 }
394
395 fn into_tokens(
397 self,
398 body: proc_macro2::TokenStream,
399 last_block: proc_macro2::TokenStream,
400 ) -> TokenStream {
401 let mut tokens = proc_macro2::TokenStream::new();
402 for attr in self.outer_attrs {
404 attr.to_tokens(&mut tokens);
405 }
406
407 for mut attr in self.inner_attrs {
411 attr.style = syn::AttrStyle::Outer;
412 attr.to_tokens(&mut tokens);
413 }
414
415 self.vis.to_tokens(&mut tokens);
416 self.sig.to_tokens(&mut tokens);
417
418 self.brace_token.surround(&mut tokens, |tokens| {
419 body.to_tokens(tokens);
420 last_block.to_tokens(tokens);
421 });
422
423 tokens
424 }
425}
426
427impl Parse for ItemFn {
428 #[inline]
429 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
430 let outer_attrs = input.call(Attribute::parse_outer)?;
438 let vis: Visibility = input.parse()?;
439 let sig: Signature = input.parse()?;
440
441 let content;
442 let brace_token = braced!(content in input);
443 let inner_attrs = Attribute::parse_inner(&content)?;
444
445 let mut buf = proc_macro2::TokenStream::new();
446 let mut stmts = Vec::new();
447
448 while !content.is_empty() {
449 if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? {
450 semi.to_tokens(&mut buf);
451 stmts.push(buf);
452 buf = proc_macro2::TokenStream::new();
453 continue;
454 }
455
456 buf.extend([content.parse::<TokenTree>()?]);
459 }
460
461 if !buf.is_empty() {
462 stmts.push(buf);
463 }
464
465 Ok(Self {
466 outer_attrs,
467 vis,
468 sig,
469 brace_token,
470 inner_attrs,
471 stmts,
472 })
473 }
474}
475
/// Borrowed view of an `ItemFn` body; renders as `{ stmts... }` through `ToTokens`.
struct Body<'a> {
    /// The original body braces.
    brace_token: syn::token::Brace,
    // Statement chunks, each including its terminating `;` (except possibly the last one).
    stmts: &'a [TokenStream],
}
481
482impl ToTokens for Body<'_> {
483 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
484 self.brace_token.surround(tokens, |tokens| {
485 for stmt in self.stmts {
486 stmt.to_tokens(tokens);
487 }
488 });
489 }
490}
491
492fn ident(name: &str) -> Ident {
493 Ident::new(name, Span::call_site())
494}