tailcall-impl 2.0.0

The procedural macro implementation for the tailcall crate
Documentation
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
    fold::{self, Fold},
    parse2, parse_quote, Error, Expr, ExprBlock, ExprCall, ExprIf, ExprMacro, ExprMatch,
    ExprMethodCall, ExprReturn, ExprTry, ImplItemMethod, ItemFn, Signature, Stmt,
};

use super::helpers::{
    function_argument_exprs, helper_method_call_tokens, helper_path_from_call, helper_signature,
    is_tailcall_macro, method_helper_signature,
};

pub fn apply_fn_tailcall_transform(item_fn: ItemFn) -> TokenStream {
    match TailcallTransform::new(item_fn).expand() {
        Ok(output) => output,
        Err(error) => error.to_compile_error(),
    }
}

pub fn expand_call_macro(tokens: TokenStream) -> TokenStream {
    if let Ok(expr_call) = parse2::<ExprCall>(tokens.clone()) {
        return match helper_path_from_call(&expr_call) {
            Ok(func) => {
                let args = expr_call.args;
                quote! { #func(#args) }
            }
            Err(error) => error.to_compile_error(),
        };
    }

    if let Ok(expr_method_call) = parse2::<ExprMethodCall>(tokens.clone()) {
        return match helper_method_call_tokens(&expr_method_call) {
            Ok(tokens) => tokens,
            Err(error) => error.to_compile_error(),
        };
    }

    Error::new(
        Span::call_site(),
        "tailcall::call! expects either `path(args...)` or `self.method(args...)`",
    )
    .to_compile_error()
}

pub fn apply_method_tailcall_transform(method: ImplItemMethod) -> TokenStream {
    match TailcallMethodTransform::new(method).expand() {
        Ok(output) => output,
        Err(error) => error.to_compile_error(),
    }
}

struct TailcallTransform {
    item_fn: ItemFn,
}

struct TailcallMethodTransform {
    method: ImplItemMethod,
}

impl TailcallMethodTransform {
    fn new(method: ImplItemMethod) -> Self {
        Self { method }
    }

    fn expand(self) -> Result<TokenStream, Error> {
        let ImplItemMethod {
            attrs,
            vis,
            defaultness,
            sig,
            block,
        } = self.method;

        reject_unsupported_signature(&sig)?;

        let helper_sig = method_helper_signature(&sig)?;
        let helper_fn_ident = &helper_sig.ident;
        let helper_args = function_argument_exprs(&sig)?;
        let helper_block = TailPositionRewriter::rewrite(block)?;

        Ok(quote! {
            #(#attrs)*
            #defaultness #vis #sig {
                tailcall::trampoline::run(Self::#helper_fn_ident(#(#helper_args),*))
            }

            #[doc(hidden)]
            #[inline(always)]
            #helper_sig {
                tailcall::trampoline::call(move || #helper_block)
            }
        })
    }
}

impl TailcallTransform {
    fn new(item_fn: ItemFn) -> Self {
        Self { item_fn }
    }

    fn expand(self) -> Result<TokenStream, Error> {
        let ItemFn {
            attrs,
            vis,
            sig,
            block,
        } = self.item_fn;

        reject_unsupported_signature(&sig)?;

        let helper_sig = helper_signature(&sig);
        let helper_fn_ident = &helper_sig.ident;
        let helper_args = function_argument_exprs(&sig)?;
        let helper_block = TailPositionRewriter::rewrite(*block)?;

        Ok(quote! {
            #(#attrs)*
            #vis #sig {
                tailcall::trampoline::run(#helper_fn_ident(#(#helper_args),*))
            }

            #[doc(hidden)]
            #[inline(always)]
            #helper_sig {
                tailcall::trampoline::call(move || #helper_block)
            }
        })
    }
}

fn reject_unsupported_signature(sig: &Signature) -> Result<(), Error> {
    if sig.constness.is_some() {
        return Err(Error::new_spanned(
            sig.constness,
            "#[tailcall] does not support const functions",
        ));
    }

    if sig.asyncness.is_some() {
        return Err(Error::new_spanned(
            sig.asyncness,
            "#[tailcall] does not support async functions",
        ));
    }

    Ok(())
}

struct TailPositionRewriter {
    error: Option<Error>,
}

impl TailPositionRewriter {
    fn rewrite(block: syn::Block) -> Result<syn::Block, Error> {
        let mut rewriter = Self { error: None };
        let block = rewriter.rewrite_tail_block(block);

        match rewriter.error {
            Some(error) => Err(error),
            None => Ok(block),
        }
    }

    fn rewrite_tail_block(&mut self, mut block: syn::Block) -> syn::Block {
        let last_stmt = block.stmts.pop();
        block.stmts = block
            .stmts
            .into_iter()
            .map(|stmt| self.fold_stmt(stmt))
            .collect();

        if let Some(stmt) = last_stmt {
            block.stmts.push(match stmt {
                Stmt::Expr(expr) => Stmt::Expr(self.rewrite_tail_expr(expr)),
                Stmt::Semi(expr, semi) => Stmt::Semi(self.fold_expr(expr), semi),
                Stmt::Local(local) => Stmt::Local(self.fold_local(local)),
                Stmt::Item(item) => Stmt::Item(item),
            });
        }

        block
    }

    fn rewrite_tail_expr(&mut self, expr: Expr) -> Expr {
        match expr {
            Expr::Block(ExprBlock {
                attrs,
                label,
                block,
            }) => Expr::Block(ExprBlock {
                attrs,
                label,
                block: self.rewrite_tail_block(block),
            }),
            Expr::If(ExprIf {
                attrs,
                if_token,
                cond,
                then_branch,
                else_branch,
            }) => Expr::If(ExprIf {
                attrs,
                if_token,
                cond: Box::new(self.fold_expr(*cond)),
                then_branch: self.rewrite_tail_block(then_branch),
                else_branch: else_branch.map(|(else_token, expr)| {
                    (else_token, Box::new(self.rewrite_tail_expr(*expr)))
                }),
            }),
            Expr::Match(ExprMatch {
                attrs,
                match_token,
                expr,
                brace_token,
                arms,
            }) => Expr::Match(ExprMatch {
                attrs,
                match_token,
                expr: Box::new(self.fold_expr(*expr)),
                brace_token,
                arms: arms
                    .into_iter()
                    .map(|mut arm| {
                        if let Some((if_token, guard)) = arm.guard.take() {
                            arm.guard = Some((if_token, Box::new(self.fold_expr(*guard))));
                        }
                        arm.body = Box::new(self.rewrite_tail_expr(*arm.body));
                        arm
                    })
                    .collect(),
            }),
            Expr::Macro(expr_macro) if is_tailcall_macro(&expr_macro.mac.path) => {
                expand_call_expr(expr_macro)
            }
            expr => {
                let expr = self.fold_expr(expr);
                parse_quote! { tailcall::trampoline::done(#expr) }
            }
        }
    }

    fn reject(&mut self, error: Error) {
        if let Some(existing) = &mut self.error {
            existing.combine(error);
        } else {
            self.error = Some(error);
        }
    }
}

impl Fold for TailPositionRewriter {
    fn fold_expr(&mut self, expr: Expr) -> Expr {
        match expr {
            Expr::Return(ExprReturn {
                attrs,
                return_token,
                expr: Some(expr),
            }) => Expr::Return(ExprReturn {
                attrs,
                return_token,
                expr: Some(Box::new(self.rewrite_tail_expr(*expr))),
            }),
            Expr::Try(ExprTry {
                attrs,
                expr,
                question_token,
            }) => {
                self.reject(Error::new_spanned(
                    question_token,
                    "the `?` operator is not supported inside #[tailcall] functions on stable Rust; use `match` or explicit early returns instead",
                ));

                Expr::Try(ExprTry {
                    attrs,
                    expr: Box::new(self.fold_expr(*expr)),
                    question_token,
                })
            }
            Expr::Macro(expr_macro) if is_tailcall_macro(&expr_macro.mac.path) => {
                self.reject(Error::new_spanned(
                    &expr_macro,
                    "tailcall::call! must be used in tail position",
                ));
                expand_call_expr(expr_macro)
            }
            expr => fold::fold_expr(self, expr),
        }
    }

    fn fold_expr_closure(&mut self, expr: syn::ExprClosure) -> syn::ExprClosure {
        expr
    }

    fn fold_item_fn(&mut self, item_fn: ItemFn) -> ItemFn {
        item_fn
    }
}

fn expand_call_expr(expr_macro: ExprMacro) -> Expr {
    parse2(expand_call_macro(expr_macro.mac.tokens))
        .expect("tailcall::call! should expand to an expression")
}