use std::fmt::{self, Display, Write};
use crate::ast::*;
use crate::token::{CommentKind, NumberLiteral, StringLiteralKind, TokenValue};
// Generates, for every listed printer function:
//   * an inherent method on `AstPrinter<'_, W>` (for any `W: Write`) with the given body, and
//   * a `Display` impl for the matching AST node type that formats the node by running
//     that method on a fresh `AstPrinter` wrapping the formatter.
//
// The receiver is captured as an identifier (always `self` at the invocation) so that the
// `self` used inside each body refers to the receiver of the generated method.
macro_rules! impl_display {
    ($( fn $method:ident(&mut $recv:ident, $node:ident: &$node_ty:ident<'_>) -> fmt::Result $body:block )+) => {
        impl<W> AstPrinter<'_, W>
        where
            W: Write,
        {
            $(
                fn $method(&mut $recv, $node: &$node_ty<'_>) -> fmt::Result $body
            )+
        }

        $(
            impl Display for $node_ty<'_> {
                fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                    let mut printer = AstPrinter::new(f);
                    printer.$method(self)
                }
            }
        )+
    };
}
/// Returns how many bytes the UTF-8 sequence introduced by `first_byte` spans.
///
/// Returns `None` when `first_byte` cannot begin a sequence, i.e. it is a
/// continuation byte (`0b10xx_xxxx`) or one of `0xF8..=0xFF`. Only the leading
/// byte is classified; the caller still has to validate the whole sequence.
fn get_utf8_length(first_byte: u8) -> Option<usize> {
    if first_byte & 0b1000_0000 == 0b0000_0000 {
        Some(1)
    } else if first_byte & 0b1110_0000 == 0b1100_0000 {
        Some(2)
    } else if first_byte & 0b1111_0000 == 0b1110_0000 {
        Some(3)
    } else if first_byte & 0b1111_1000 == 0b1111_0000 {
        Some(4)
    } else {
        None
    }
}
/// Pretty-printer state: renders AST nodes as source text into `writer`.
///
/// The `fmt::Write` bound is declared on the `impl` blocks that need it rather
/// than on the struct definition itself.
struct AstPrinter<'w, W> {
    /// Current block-nesting depth, used when starting a new output line.
    indent_level: usize,
    /// Destination for the generated source text.
    writer: &'w mut W,
}
impl<'w, W: Write> AstPrinter<'w, W> {
    /// Creates a printer that writes into `writer`, starting at indentation level 0.
    fn new(writer: &'w mut W) -> Self {
        Self {
            indent_level: 0,
            writer,
        }
    }

    /// Writes a line break followed by one space per current indentation level.
    fn print_new_line(&mut self) -> fmt::Result {
        writeln!(self.writer)?;
        for _ in 0..self.indent_level {
            write!(self.writer, " ")?;
        }
        Ok(())
    }

    /// Writes the source text of a single token.
    ///
    /// # Panics
    ///
    /// Panics (`unreachable!`) when handed a whitespace or end-of-file token.
    fn print_token(&mut self, token: &TokenReference<'_>) -> fmt::Result {
        match &token.token.value {
            TokenValue::Symbol(symbol) => write!(self.writer, "{}", symbol.as_str()),
            TokenValue::Number(num_lit) => self.print_number_token(*num_lit),
            TokenValue::String { value, kind } => self.print_string_token(value.as_ref(), *kind),
            TokenValue::Comment { value, kind } => self.print_comment(value, *kind),
            TokenValue::Whitespace(_) => unreachable!(),
            TokenValue::Ident(s) => write!(self.writer, "{}", s),
            TokenValue::Eof => unreachable!(),
        }
    }

    /// Writes a numeric literal so that it lexes back to the same value and subtype.
    ///
    /// Floats are written with `Debug` formatting, which always keeps a decimal
    /// point or an exponent (`1.0`, `1e300`). `Display` would print `1.0` as `1`,
    /// which Lua 5.3+ (where integer and float are distinct subtypes) reads back
    /// as an integer. Non-finite floats have no literal spelling. Infinity is
    /// written as the overflowing literal `1e999`, and NaN as the expression
    /// `(0/0)`, instead of `inf` / `NaN`, which would parse as variable names.
    fn print_number_token(&mut self, num_lit: NumberLiteral) -> fmt::Result {
        match num_lit {
            NumberLiteral::Float(v) if v.is_nan() => write!(self.writer, "(0/0)"),
            NumberLiteral::Float(v) if v.is_infinite() => {
                let sign = if v.is_sign_negative() { "-" } else { "" };
                write!(self.writer, "{}1e999", sign)
            }
            NumberLiteral::Float(v) => write!(self.writer, "{:?}", v),
            NumberLiteral::Integer(v) => write!(self.writer, "{}", v),
        }
    }

    /// Writes a string literal with the delimiters selected by `kind`.
    ///
    /// `value` is walked one UTF-8 character at a time. A byte that does not start a
    /// valid UTF-8 sequence is treated as a single raw byte. In quoted literals,
    /// control characters, quotes and backslashes are escaped. Raw bytes and ASCII
    /// control characters are written as `\xNN`, reproducing the original byte.
    /// Decoded multi-byte characters are written verbatim, so their UTF-8 bytes
    /// are preserved.
    ///
    /// # Panics
    ///
    /// Panics if a bracketed literal contains `\r`, which long brackets cannot represent.
    fn print_string_token(&mut self, value: &[u8], kind: StringLiteralKind) -> fmt::Result {
        match kind {
            StringLiteralKind::SingleQuoted => write!(self.writer, "'")?,
            StringLiteralKind::DoubleQuoted => write!(self.writer, "\"")?,
            StringLiteralKind::Bracketed { level } => {
                if level == 0 {
                    write!(self.writer, "[[")?
                } else {
                    // Right-aligning "=" in a field of width `level`, filled with '=',
                    // yields exactly `level` equals signs.
                    write!(self.writer, "[{:=>level$}[", "=", level = level)?
                }
            }
        }
        let quoted = !matches!(kind, StringLiteralKind::Bracketed { .. });
        let mut rest = value;
        let mut first = true;
        while let Some(&b) = rest.first() {
            // Decode one whole character when `rest` starts with a valid UTF-8 sequence;
            // otherwise fall back to the single raw byte `b`.
            let decoded = get_utf8_length(b)
                .and_then(|len| rest.get(..len))
                .and_then(|slice| std::str::from_utf8(slice).ok())
                .and_then(|s| s.chars().next());
            let (c, width, raw) = match decoded {
                Some(c) => (c, c.len_utf8(), false),
                None => (char::from(b), 1, true),
            };
            match c {
                '\x07' if quoted => write!(self.writer, "\\a"),
                '\x08' if quoted => write!(self.writer, "\\b"),
                '\x0c' if quoted => write!(self.writer, "\\f"),
                '\n' if quoted => write!(self.writer, "\\n"),
                // Lua drops a newline that immediately follows an opening long bracket,
                // so a leading newline has to be doubled to survive.
                '\n' if first => write!(self.writer, "\n\n"),
                '\r' if quoted => write!(self.writer, "\\r"),
                '\r' => panic!("\\r cannot be represented in bracketed string literals"),
                '\t' if quoted => write!(self.writer, "\\t"),
                '\x0b' if quoted => write!(self.writer, "\\v"),
                '\\' if quoted => write!(self.writer, "\\\\"),
                '\'' if matches!(kind, StringLiteralKind::SingleQuoted) => {
                    write!(self.writer, "\\'")
                }
                '"' if matches!(kind, StringLiteralKind::DoubleQuoted) => {
                    write!(self.writer, "\\\"")
                }
                // `b` is the exact source byte in both cases: for raw bytes by
                // construction, and for ASCII controls because they are one byte wide.
                _ if quoted && (raw || c.is_ascii_control()) => {
                    write!(self.writer, "\\x{:02x}", b)
                }
                // NOTE(review): inside long brackets a non-UTF-8 byte cannot be emitted
                // through a `str`-only `fmt::Write` sink. It is written as the Unicode
                // scalar with the same value, which changes the byte sequence.
                _ => write!(self.writer, "{}", c),
            }?;
            first = false;
            rest = &rest[width..];
        }
        match kind {
            StringLiteralKind::SingleQuoted => write!(self.writer, "'"),
            StringLiteralKind::DoubleQuoted => write!(self.writer, "\""),
            StringLiteralKind::Bracketed { level } => {
                if level == 0 {
                    write!(self.writer, "]]")
                } else {
                    write!(self.writer, "]{:=>level$}]", "=", level = level)
                }
            }
        }
    }

    /// Writes a comment token.
    ///
    /// Unbracketed (`--`) comments run to the end of the line, so a line break
    /// (with current indentation) is always emitted after them.
    fn print_comment(&mut self, value: &str, kind: CommentKind) -> fmt::Result {
        match kind {
            CommentKind::Unbracketed => {
                write!(self.writer, "--{}", value)?;
                self.print_new_line()
            }
            CommentKind::Bracketed { level: 0 } => {
                write!(self.writer, "--[[{}]]", value)
            }
            CommentKind::Bracketed { level } => {
                // `{1:=>level$}` renders argument 1 ("=") as `level` equals signs;
                // the bare `{}` still consumes argument 0 (`value`).
                write!(self.writer, "--[{1:=>level$}[{}]{1:=>level$}]", value, "=", level = level)
            }
        }
    }

    /// Writes each item of `punc` with `f`, followed by its separator token (if any).
    ///
    /// When `put_space` is true, a single space is written after every separator.
    fn print_punctuated<T, R, F>(&mut self, punc: &Punctuated<'_, T>, f: F, put_space: bool) -> fmt::Result
    where
        F: Fn(&mut Self, &R) -> fmt::Result,
        T: std::borrow::Borrow<R>,
    {
        for (v, sep) in &punc.pairs {
            f(self, v.borrow())?;
            if let Some(sep) = sep {
                self.print_token(sep)?;
                if put_space {
                    write!(self.writer, " ")?;
                }
            }
        }
        Ok(())
    }

    /// Writes a bracketed, comma-separated list: opening token, items, closing token.
    fn print_parenthesized_list<T, R, F>(&mut self, plist: &ParenthesizedList<'_, T>, f: F) -> fmt::Result
    where
        F: Fn(&mut Self, &R) -> fmt::Result,
        T: std::borrow::Borrow<R>,
    {
        self.print_token(&plist.brackets.0)?;
        self.print_punctuated(&plist.list, f, true)?;
        self.print_token(&plist.brackets.1)
    }

    /// Writes `block` one indentation level deeper than the current one.
    ///
    /// A non-empty block starts on a fresh line. No line break is written after
    /// the block; callers emit one before their closing keyword.
    fn print_block_indent(&mut self, block: &Block<'_>) -> fmt::Result {
        self.indent_level += 1;
        if !block.statements.is_empty() {
            self.print_new_line()?;
        }
        self.print_block(block)?;
        self.indent_level -= 1;
        Ok(())
    }
}
impl_display![
fn print_nil_lit(&mut self, nil_lit: &NilLit<'_>) -> fmt::Result {
self.print_token(&nil_lit.0)
}
fn print_boolean_lit(&mut self, lit: &BooleanLit<'_>) -> fmt::Result {
    // `true` / `false` are keyword tokens printed as-is.
    let keyword = &lit.0;
    self.print_token(keyword)
}
fn print_number_lit(&mut self, number_lit: &NumberLit<'_>) -> fmt::Result {
self.print_token(&number_lit.0)
}
fn print_string_lit(&mut self, string_lit: &StringLit<'_>) -> fmt::Result {
self.print_token(&string_lit.0)
}
fn print_name(&mut self, ident: &Name<'_>) -> fmt::Result {
    // A name is exactly one identifier token.
    let token = &ident.0;
    self.print_token(token)
}
fn print_vararg(&mut self, vararg: &Vararg<'_>) -> fmt::Result {
self.print_token(&vararg.0)
}
fn print_bin_op_expr(&mut self, expr: &BinOpExpr<'_>) -> fmt::Result {
    // Left operand, the operator symbol padded by one space on each side, right operand.
    let (lhs, rhs) = (&expr.left, &expr.right);
    self.print_expr(lhs)?;
    write!(self.writer, " {} ", expr.op.kind().as_symbol().as_str())?;
    self.print_expr(rhs)
}
fn print_un_op_expr(&mut self, un_op: &UnOpExpr<'_>) -> fmt::Result {
write!(self.writer, "{}", un_op.op.kind().as_symbol().as_str())?;
self.print_expr(&un_op.right)
}
fn print_parenthesized_expr(&mut self, paren: &ParenthesizedExpr<'_>) -> fmt::Result {
    // `(`, the inner expression, `)`, with no padding inside the parentheses.
    let (open, close) = (&paren.brackets.0, &paren.brackets.1);
    self.print_token(open)?;
    self.print_expr(&paren.expr)?;
    self.print_token(close)
}
fn print_var_field(&mut self, field: &VarField<'_>) -> fmt::Result {
    // Suffix of an indexed variable: `.name` member access or `[expr]` indexing.
    match field {
        VarField::Name { period, key } => {
            self.print_token(period)?;
            self.print_name(key)
        }
        VarField::Expr { brackets, key } => {
            let (open, close) = (&brackets.0, &brackets.1);
            self.print_token(open)?;
            self.print_expr(key)?;
            self.print_token(close)
        }
    }
}
fn print_var(&mut self, var: &Var<'_>) -> fmt::Result {
    // A bare name, or a prefix expression followed by `.name` / `[expr]`.
    match var {
        Var::Field(base, suffix) => {
            self.print_prefix_expr(base)?;
            self.print_var_field(suffix)
        }
        Var::Name(ident) => self.print_name(ident),
    }
}
fn print_function_callee(&mut self, callee: &FunctionCallee<'_>) -> fmt::Result {
    // Either `object:method` or a plain prefix expression.
    match callee {
        FunctionCallee::Method { object, colon, name } => {
            self.print_prefix_expr(object)?;
            self.print_token(colon)?;
            self.print_name(name)
        }
        FunctionCallee::Expr(target) => self.print_prefix_expr(target),
    }
}
fn print_function_args(&mut self, args: &FunctionArgs<'_>) -> fmt::Result {
    // A parenthesized list attaches directly to the callee; the sugar forms
    // `f {...}` and `f "..."` are separated from it by one space.
    match args {
        FunctionArgs::ParenthesizedList(list) => {
            self.print_parenthesized_list(list, Self::print_expr)
        }
        FunctionArgs::StringLit(lit) => {
            self.writer.write_char(' ')?;
            self.print_string_lit(lit)
        }
        FunctionArgs::TableConstructor(table) => {
            self.writer.write_char(' ')?;
            self.print_table_cons(table)
        }
    }
}
fn print_function_call(&mut self, function_call: &FunctionCall<'_>) -> fmt::Result {
self.print_function_callee(&function_call.callee)?;
self.print_function_args(&function_call.args)
}
fn print_prefix_expr(&mut self, prefix: &PrefixExpr<'_>) -> fmt::Result {
    // Prefix expressions are calls, variables, or parenthesized expressions.
    match prefix {
        PrefixExpr::Call(c) => self.print_function_call(c),
        PrefixExpr::Var(v) => self.print_var(v),
        PrefixExpr::Parenthesized(p) => self.print_parenthesized_expr(p),
    }
}
fn print_expr(&mut self, e: &Expr<'_>) -> fmt::Result {
    // Dispatch on the expression variant to its dedicated printer.
    match e {
        Expr::Prefix(p) => self.print_prefix_expr(p),
        Expr::BinOp(b) => self.print_bin_op_expr(b),
        Expr::UnOp(u) => self.print_un_op_expr(u),
        Expr::Function(f) => self.print_function_expr(f),
        Expr::TableConstructor(t) => self.print_table_cons(t),
        Expr::String(s) => self.print_string_lit(s),
        Expr::Number(n) => self.print_number_lit(n),
        Expr::Boolean(b) => self.print_boolean_lit(b),
        Expr::Nil(n) => self.print_nil_lit(n),
        Expr::Vararg(v) => self.print_vararg(v),
    }
}
fn print_table_cons(&mut self, table: &TableConstructor<'_>) -> fmt::Result {
    // Opening brace, every field in order (each field prints its own separator),
    // closing brace.
    let (open, close) = (&table.brackets.0, &table.brackets.1);
    self.print_token(open)?;
    for entry in &table.fields {
        self.print_table_field(entry)?;
    }
    self.print_token(close)
}
fn print_table_key(&mut self, key: &TableKey<'_>) -> fmt::Result {
    // Key part of a keyed table field: `[expr]` or a bare name.
    // The bracketed form must emit both brackets (as `VarField::Expr` does);
    // omitting the closing `]` yields unparseable output such as `{[1 = 2}`.
    match key {
        TableKey::Expr { brackets, key } => {
            self.print_token(&brackets.0)?;
            self.print_expr(key)?;
            self.print_token(&brackets.1)
        }
        TableKey::Name { key } => self.print_name(key),
    }
}
fn print_table_field(&mut self, field: &TableField<'_>) -> fmt::Result {
    // Optional `key = ` prefix, the value, then the optional `,`/`;` separator
    // followed by a space.
    match &field.key {
        Some((key, assign)) => {
            self.print_table_key(key)?;
            self.writer.write_char(' ')?;
            self.print_token(assign)?;
            self.writer.write_char(' ')?;
        }
        None => {}
    }
    self.print_expr(&field.value)?;
    match &field.separator {
        Some(sep) => {
            self.print_token(sep)?;
            self.writer.write_char(' ')
        }
        None => Ok(()),
    }
}
fn print_block(&mut self, block: &Block<'_>) -> fmt::Result {
    // Statements are separated by line breaks. The first statement, and any empty
    // statement (`;`), continue on the current line instead.
    let mut statements = block.statements.iter();
    if let Some(head) = statements.next() {
        self.print_stmt(head)?;
    }
    for stmt in statements {
        if !matches!(stmt, Statement::Empty(_)) {
            self.print_new_line()?;
        }
        self.print_stmt(stmt)?;
    }
    Ok(())
}
fn print_stmt(&mut self, stmt: &Statement<'_>) -> fmt::Result {
    // Dispatch on the statement variant to its dedicated printer.
    match stmt {
        Statement::If(s) => self.print_if_stmt(s),
        Statement::Goto(s) => self.print_goto_stmt(s),
        Statement::Label(s) => self.print_label_stmt(s),
        Statement::FunctionDeclaration(s) => self.print_function_decl(s),
        Statement::LocalDeclaration(s) => self.print_local_decl(s),
        Statement::Repeat(s) => self.print_repeat_stmt(s),
        Statement::For(s) => self.print_for_stmt(s),
        Statement::While(s) => self.print_while_stmt(s),
        Statement::FunctionCall(s) => self.print_function_call(s),
        Statement::Assignment(s) => self.print_assignment_stmt(s),
        Statement::Break(s) => self.print_break_stmt(s),
        Statement::Return(s) => self.print_return_stmt(s),
        Statement::Block(s) => self.print_block_stmt(s),
        Statement::Empty(s) => self.print_empty_stmt(s),
    }
}
fn print_empty_stmt(&mut self, stmt: &EmptyStat<'_>) -> fmt::Result {
self.print_token(&stmt.0)
}
fn print_block_stmt(&mut self, stmt: &BlockStat<'_>) -> fmt::Result {
    // `do`, the indented body, then `end` on its own line.
    // The line break before `end` matches the other block statements (`while`,
    // `for`, `repeat`, function bodies). Without it `end` is glued to the last body
    // token (`x = 1end`), or to `do` itself for an empty body (`doend`).
    self.print_token(&stmt.do_)?;
    self.print_block_indent(&stmt.block)?;
    self.print_new_line()?;
    self.print_token(&stmt.end)
}
fn print_return_stmt(&mut self, stmt: &ReturnStat<'_>) -> fmt::Result {
    // `return`, a space, the comma-separated values, then the optional `;`.
    self.print_token(&stmt.return_)?;
    self.writer.write_char(' ')?;
    self.print_punctuated(&stmt.exprs, Self::print_expr, true)?;
    match &stmt.semi {
        Some(semi) => self.print_token(semi),
        None => Ok(()),
    }
}
fn print_break_stmt(&mut self, stmt: &BreakStat<'_>) -> fmt::Result {
    // `break` is a lone keyword token.
    let keyword = &stmt.0;
    self.print_token(keyword)
}
fn print_assignment_stmt(&mut self, stmt: &AssignmentStat<'_>) -> fmt::Result {
self.print_punctuated(&stmt.vars, Self::print_var, true)?;
write!(self.writer, " ")?;
self.print_token(&stmt.assign)?;
write!(self.writer, " ")?;
self.print_punctuated(&stmt.exprs, Self::print_expr, true)
}
fn print_while_stmt(&mut self, stmt: &WhileStat<'_>) -> fmt::Result {
self.print_token(&stmt.while_)?;
write!(self.writer, " ")?;
self.print_expr(&stmt.condition)?;
write!(self.writer, " ")?;
self.print_token(&stmt.do_)?;
self.print_block_indent(&stmt.block)?;
self.print_new_line()?;
self.print_token(&stmt.end)
}
fn print_for_stmt(&mut self, stmt: &ForStat<'_>) -> fmt::Result {
    // Dispatch to the numeric (`for i = a, b`) or generic (`for k, v in ...`) form.
    match stmt {
        ForStat::Numerical(numeric) => self.print_numerical_for(numeric),
        ForStat::Generic(generic) => self.print_generic_for(generic),
    }
}
fn print_generic_for(&mut self, generic_for: &GenericFor<'_>) -> fmt::Result {
self.print_token(&generic_for.for_)?;
write!(self.writer, " ")?;
self.print_punctuated(&generic_for.names, Self::print_name, true)?;
write!(self.writer, " ")?;
self.print_token(&generic_for.in_)?;
write!(self.writer, " ")?;
self.print_punctuated(&generic_for.exprs, Self::print_expr, true)?;
write!(self.writer, " ")?;
self.print_token(&generic_for.do_)?;
self.print_block_indent(&generic_for.block)?;
self.print_new_line()?;
self.print_token(&generic_for.end)
}
fn print_numerical_for(&mut self, numerical_for: &NumericalFor<'_>) -> fmt::Result {
self.print_token(&numerical_for.for_)?;
write!(self.writer, " ")?;
self.print_name(&numerical_for.name)?;
write!(self.writer, " ")?;
self.print_token(&numerical_for.assign)?;
write!(self.writer, " ")?;
self.print_expr(&numerical_for.from)?;
self.print_token(&numerical_for.comma)?;
write!(self.writer, " ")?;
self.print_expr(&numerical_for.to)?;
if let Some((comma, step)) = &numerical_for.step {
self.print_token(comma)?;
write!(self.writer, " ")?;
self.print_expr(step)?;
}
write!(self.writer, " ")?;
self.print_token(&numerical_for.do_)?;
self.print_block_indent(&numerical_for.block)?;
self.print_new_line()?;
self.print_token(&numerical_for.end)
}
fn print_repeat_stmt(&mut self, stmt: &RepeatStat<'_>) -> fmt::Result {
    // `repeat`, the indented body, then `until cond` on its own line.
    self.print_token(&stmt.repeat)?;
    self.print_block_indent(&stmt.block)?;
    self.print_new_line()?;
    let until = &stmt.until;
    self.print_token(until)?;
    self.writer.write_char(' ')?;
    self.print_expr(&stmt.condition)
}
fn print_local_decl(&mut self, stmt: &LocalDeclarationStat<'_>) -> fmt::Result {
    // `local a, b`, optionally followed by the ` = values` definition.
    self.print_token(&stmt.local)?;
    self.writer.write_char(' ')?;
    self.print_punctuated(&stmt.names, Self::print_name, true)?;
    match &stmt.definition {
        Some(def) => self.print_local_def(def),
        None => Ok(()),
    }
}
fn print_local_def(&mut self, def: &LocalDefinition<'_>) -> fmt::Result {
write!(self.writer, " ")?;
self.print_token(&def.assign)?;
write!(self.writer, " ")?;
self.print_punctuated(&def.exprs, Self::print_expr, true)
}
fn print_function_decl(&mut self, stmt: &FunctionDeclarationStat<'_>) -> fmt::Result {
match stmt {
FunctionDeclarationStat::Local { local, function, name, body } => {
self.print_token(local)?;
write!(self.writer, " ")?;
self.print_token(function)?;
write!(self.writer, " ")?;
self.print_name(name)?;
self.print_function_body(body)
}
FunctionDeclarationStat::Nonlocal { function, name, body } => {
self.print_token(function)?;
write!(self.writer, " ")?;
self.print_function_name(name)?;
self.print_function_body(body)
}
}
}
fn print_function_body(&mut self, body: &FunctionBody<'_>) -> fmt::Result {
    // Parameter list (with the optional `...` placed inside the parentheses),
    // the indented body, then `end` on its own line.
    let (open, close) = (&body.params.brackets.0, &body.params.brackets.1);
    self.print_token(open)?;
    self.print_punctuated(&body.params.list, Self::print_name, true)?;
    if let Some(dots) = &body.vararg {
        self.print_vararg(dots)?;
    }
    self.print_token(close)?;
    self.print_block_indent(&body.block)?;
    self.print_new_line()?;
    self.print_token(&body.end)
}
fn print_function_name(&mut self, name: &FunctionName<'_>) -> fmt::Result {
    // `f`, `a.b.c`, or `a.b:m`; dotted paths are printed without spaces.
    match name {
        FunctionName::Method { receiver, colon, method } => {
            self.print_punctuated(receiver, Self::print_name, false)?;
            self.print_token(colon)?;
            self.print_name(method)
        }
        FunctionName::Indexed(path) => self.print_punctuated(path, Self::print_name, false),
        FunctionName::PlainName(plain) => self.print_name(plain),
    }
}
fn print_function_expr(&mut self, expr: &FunctionExpr<'_>) -> fmt::Result {
self.print_token(&expr.function)?;
self.print_function_body(&expr.body)
}
fn print_label_stmt(&mut self, label: &LabelStat<'_>) -> fmt::Result {
    // `::name::` with no inner spacing.
    let (open, close) = (&label.preceding, &label.following);
    self.print_token(open)?;
    self.print_name(&label.name)?;
    self.print_token(close)
}
fn print_goto_stmt(&mut self, stmt: &GotoStat<'_>) -> fmt::Result {
self.print_token(&stmt.goto)?;
write!(self.writer, " ")?;
self.print_name(&stmt.label)
}
fn print_if_stmt(&mut self, stmt: &IfStat<'_>) -> fmt::Result {
    // `if cond then` and its indented body, then every `elseif` / `else` clause and
    // the closing `end`, each starting on its own line.
    // `print_else_if` and `print_else` end right after their body without a line
    // break, so one is emitted here after each clause. Otherwise the following
    // keyword would be glued to the clause's last token (e.g. `x = 1end`).
    self.print_token(&stmt.if_)?;
    write!(self.writer, " ")?;
    self.print_expr(&stmt.condition)?;
    write!(self.writer, " ")?;
    self.print_token(&stmt.then)?;
    self.print_block_indent(&stmt.block)?;
    self.print_new_line()?;
    for elseif in &stmt.elseifs {
        self.print_else_if(elseif)?;
        self.print_new_line()?;
    }
    if let Some(else_) = &stmt.else_ {
        self.print_else(else_)?;
        self.print_new_line()?;
    }
    self.print_token(&stmt.end)
}
fn print_else_if(&mut self, elseif: &ElseIf<'_>) -> fmt::Result {
self.print_token(&elseif.elseif)?;
write!(self.writer, " ")?;
self.print_expr(&elseif.condition)?;
write!(self.writer, " ")?;
self.print_token(&elseif.then)?;
self.print_block_indent(&elseif.block)
}
fn print_else(&mut self, clause: &Else<'_>) -> fmt::Result {
    // `else` plus the indented body; no trailing line break.
    self.print_token(&clause.else_)?;
    let body = &clause.block;
    self.print_block_indent(body)
}
];