rkt_codegen 0.6.0

Procedural macros for the rkt web framework.
Documentation
use devise::ext::SpanDiagnosticExt;
use devise::{Result, Spanned};
use proc_macro2::{Span, TokenStream};

use super::EntryAttr;
use crate::attribute::suppress::Lint;
use crate::exports::{_ExitCode, _error, mixed};

/// `#[rkt::launch]`: generates a `main` function that calls the attributed
/// function to generate a `Rocket` instance. Then calls `.launch()` on the
/// returned instance inside of an `rkt::async_main`.
pub struct Launch;

/// Determines if `f` likely spawns an async task, returning the spawn call.
fn likely_spawns(f: &syn::ItemFn) -> Option<&syn::ExprCall> {
    use syn::visit::{self, Visit};

    struct SpawnFinder<'a>(Option<&'a syn::ExprCall>);

    impl<'ast> Visit<'ast> for SpawnFinder<'ast> {
        fn visit_expr_call(&mut self, i: &'ast syn::ExprCall) {
            if self.0.is_some() {
                return;
            }

            if let syn::Expr::Path(ref e) = *i.func {
                let mut segments = e.path.segments.clone();
                if let Some(last) = segments.pop() {
                    if last.value().ident != "spawn" {
                        return visit::visit_expr_call(self, i);
                    }

                    if let Some(prefix) = segments.pop() {
                        if prefix.value().ident == "tokio" {
                            self.0 = Some(i);
                            return;
                        }
                    }

                    if let Some(syn::Expr::Async(_)) = i.args.first() {
                        self.0 = Some(i);
                        return;
                    }
                }
            };

            visit::visit_expr_call(self, i);
        }
    }

    let mut v = SpawnFinder(None);
    v.visit_item_fn(f);
    v.0
}

impl EntryAttr for Launch {
    const REQUIRES_ASYNC: bool = false;

    fn function(f: &mut syn::ItemFn) -> Result<TokenStream> {
        if f.sig.ident == "main" {
            return Err(Span::call_site()
                .error("attribute cannot be applied to `main` function")
                .note("this attribute generates a `main` function")
                .span_note(f.sig.ident.span(), "this function cannot be `main`"));
        }

        // Always infer the type as `Rocket<Build>`.
        if let syn::ReturnType::Type(_, ref mut ty) = &mut f.sig.output {
            if let syn::Type::Infer(_) = &mut **ty {
                *ty = syn::parse_quote_spanned!(ty.span() => ::rkt::Rocket<::rkt::Build>);
            }
        }

        let ty = match &f.sig.output {
            syn::ReturnType::Type(_, ty) => ty,
            _ => {
                return Err(Span::call_site()
                    .error("attribute can only be applied to functions that return a value")
                    .span_note(f.sig.span(), "this function must return a value"))
            }
        };

        let block = &f.block;
        let rocket = quote_spanned!(mixed(ty.span()) => {
            let ___rocket: #ty = #block;
            let ___rocket: ::rkt::Rocket<::rkt::Build> = ___rocket;
            ___rocket
        });

        let launch = match f.sig.asyncness {
            Some(_) => quote_spanned!(ty.span() => async move { #rocket.launch().await }),
            None => quote_spanned!(ty.span() => #rocket.launch()),
        };

        let lint = Lint::SyncSpawn;
        if f.sig.asyncness.is_none() && lint.enabled(f.sig.asyncness.span()) {
            if let Some(call) = likely_spawns(f) {
                call.span()
                    .warning("task is being spawned outside an async context")
                    .span_help(
                        f.sig.span(),
                        "declare this function as `async fn` \
                                              to require async execution",
                    )
                    .span_note(Span::call_site(), "`#[launch]` call is here")
                    .note(lint.how_to_suppress())
                    .emit_as_expr_tokens();
            }
        }

        let (vis, mut sig) = (&f.vis, f.sig.clone());
        sig.ident = syn::Ident::new("main", f.sig.ident.span());
        let ret_ty = _ExitCode.respanned(ty.span());
        sig.output = syn::parse_quote_spanned!(ty.span() => -> #ret_ty);
        sig.asyncness = None;

        Ok(quote_spanned!(block.span() =>
            #[allow(dead_code)] #f

            #vis #sig {
                #_error::Error::report(::rkt::async_main(#launch))
            }
        ))
    }
}