// tailcall-impl 2.2.0
//
// The procedural macro implementation for the tailcall crate
// Documentation
use syn::{
    fold::{self, Fold},
    parse2, parse_quote, Error, Expr, ExprBlock, ExprIf, ExprMacro, ExprMatch, ExprReturn, ExprTry,
    ItemFn, Stmt,
};

use crate::call_syntax::{expand_call_macro, is_tailcall_macro};

/// Syntax-tree folder that rewrites the tail positions of a function body.
///
/// Tail expressions are either expanded (when they are a tailcall macro
/// invocation) or wrapped in `tailcall::runtime::Thunk::value(...)`.
/// Problems found during the traversal are accumulated rather than raised
/// immediately, so a single pass can report several of them.
pub struct TailPositionRewriter {
    /// Every error rejected so far, combined into one `syn::Error`;
    /// `None` while the traversal has not rejected anything.
    error: Option<Error>,
}

impl TailPositionRewriter {
    pub fn rewrite(block: syn::Block) -> Result<syn::Block, Error> {
        let mut rewriter = Self { error: None };
        let block = rewriter.rewrite_tail_block(block);

        match rewriter.error {
            Some(error) => Err(error),
            None => Ok(block),
        }
    }

    fn rewrite_tail_block(&mut self, mut block: syn::Block) -> syn::Block {
        let last_stmt = block.stmts.pop();
        block.stmts = block
            .stmts
            .into_iter()
            .map(|stmt| self.fold_stmt(stmt))
            .collect();

        if let Some(stmt) = last_stmt {
            block.stmts.push(match stmt {
                Stmt::Expr(expr) => Stmt::Expr(self.rewrite_tail_expr(expr)),
                Stmt::Semi(expr, semi) => Stmt::Semi(self.fold_expr(expr), semi),
                Stmt::Local(local) => Stmt::Local(self.fold_local(local)),
                Stmt::Item(item) => Stmt::Item(item),
            });
        }

        block
    }

    fn rewrite_tail_expr(&mut self, expr: Expr) -> Expr {
        match expr {
            Expr::Block(ExprBlock {
                attrs,
                label,
                block,
            }) => Expr::Block(ExprBlock {
                attrs,
                label,
                block: self.rewrite_tail_block(block),
            }),
            Expr::If(ExprIf {
                attrs,
                if_token,
                cond,
                then_branch,
                else_branch,
            }) => Expr::If(ExprIf {
                attrs,
                if_token,
                cond: Box::new(self.fold_expr(*cond)),
                then_branch: self.rewrite_tail_block(then_branch),
                else_branch: else_branch.map(|(else_token, expr)| {
                    (else_token, Box::new(self.rewrite_tail_expr(*expr)))
                }),
            }),
            Expr::Match(ExprMatch {
                attrs,
                match_token,
                expr,
                brace_token,
                arms,
            }) => Expr::Match(ExprMatch {
                attrs,
                match_token,
                expr: Box::new(self.fold_expr(*expr)),
                brace_token,
                arms: arms
                    .into_iter()
                    .map(|mut arm| {
                        if let Some((if_token, guard)) = arm.guard.take() {
                            arm.guard = Some((if_token, Box::new(self.fold_expr(*guard))));
                        }
                        arm.body = Box::new(self.rewrite_tail_expr(*arm.body));
                        arm
                    })
                    .collect(),
            }),
            Expr::Macro(expr_macro) if is_tailcall_macro(&expr_macro.mac.path) => {
                expand_call_expr(expr_macro)
            }
            expr => {
                let expr = self.fold_expr(expr);
                parse_quote! { tailcall::runtime::Thunk::value(#expr) }
            }
        }
    }

    fn reject(&mut self, error: Error) {
        if let Some(existing) = &mut self.error {
            existing.combine(error);
        } else {
            self.error = Some(error);
        }
    }
}

/// Folding rules for expressions that are *not* in tail position.
///
/// Closures, `async` blocks and nested items are scope boundaries: a `return`
/// or `?` inside them belongs to the inner body, so they are passed through
/// without being visited.
impl Fold for TailPositionRewriter {
    /// Folds a non-tail expression.
    ///
    /// * `return <expr>` — the returned expression is rewritten as a tail
    ///   expression.
    /// * `<expr>?` — rejected with an error; the operand is still folded so
    ///   further problems are reported in the same pass.
    /// * a tailcall macro invocation — rejected because it is not in tail
    ///   position, then expanded anyway.
    /// * anything else — folded recursively with the default traversal.
    fn fold_expr(&mut self, expr: Expr) -> Expr {
        match expr {
            Expr::Return(ExprReturn {
                attrs,
                return_token,
                expr: Some(expr),
            }) => Expr::Return(ExprReturn {
                attrs,
                return_token,
                expr: Some(Box::new(self.rewrite_tail_expr(*expr))),
            }),
            Expr::Try(ExprTry {
                attrs,
                expr,
                question_token,
            }) => {
                self.reject(Error::new_spanned(
                    question_token,
                    "the `?` operator is not supported inside #[tailcall] functions on stable Rust; use `match` or explicit early returns instead",
                ));

                Expr::Try(ExprTry {
                    attrs,
                    expr: Box::new(self.fold_expr(*expr)),
                    question_token,
                })
            }
            Expr::Macro(expr_macro) if is_tailcall_macro(&expr_macro.mac.path) => {
                self.reject(Error::new_spanned(
                    &expr_macro,
                    "tailcall::call! must be used in tail position",
                ));
                expand_call_expr(expr_macro)
            }
            expr => fold::fold_expr(self, expr),
        }
    }

    /// Leaves closures untouched: `return` and `?` inside a closure apply to
    /// the closure, not to the function being rewritten.
    fn fold_expr_closure(&mut self, expr: syn::ExprClosure) -> syn::ExprClosure {
        expr
    }

    /// Leaves `async` blocks untouched: like closures they act as a function
    /// boundary, so `return` and `?` inside them affect the block's future
    /// rather than the function being rewritten.
    fn fold_expr_async(&mut self, expr: syn::ExprAsync) -> syn::ExprAsync {
        expr
    }

    /// Leaves every nested item (`impl`, `trait`, `mod`, `const`, ...)
    /// untouched. Function bodies inside an item are independent of the
    /// function being rewritten, and a trailing item statement is already
    /// passed through unchanged by the tail-block rewrite.
    fn fold_item(&mut self, item: syn::Item) -> syn::Item {
        item
    }

    /// Leaves nested function items untouched; kept alongside `fold_item` for
    /// callers that fold an `ItemFn` directly.
    fn fold_item_fn(&mut self, item_fn: ItemFn) -> ItemFn {
        item_fn
    }
}

fn expand_call_expr(expr_macro: ExprMacro) -> Expr {
    parse2(expand_call_macro(expr_macro.mac.tokens))
        .expect("tailcall::call! should expand to an expression")
}