use crate::formatters::{
trivia_formatter::{self, FormatTriviaType},
trivia_util, CodeFormatter, Range,
};
use full_moon::ast::{Block, Expression, LastStmt, Prefix, Return, Stmt, UnOp, Value, Var};
#[cfg(feature = "luau")]
use full_moon::tokenizer::TokenType;
use full_moon::tokenizer::{Token, TokenReference};
use std::borrow::Cow;
impl CodeFormatter {
/// Returns the byte range covered by `token`.
///
/// The range is the pair `(start, end)`, where each component is the byte
/// offset reported by the token's start and end positions respectively.
pub fn get_token_range<'ast>(token: &Token<'ast>) -> Range {
    let start_byte = token.start_position().bytes();
    let end_byte = token.end_position().bytes();
    (start_byte, end_byte)
}
pub fn get_range_in_expression<'ast>(expression: &Expression<'ast>) -> Range {
match expression {
Expression::Parentheses { contained, .. } => {
CodeFormatter::get_token_range(contained.tokens().0)
}
Expression::UnaryOperator { unop, .. } => match unop {
UnOp::Minus(token_reference) => {
CodeFormatter::get_token_range(token_reference.token())
}
UnOp::Not(token_reference) => {
CodeFormatter::get_token_range(token_reference.token())
}
UnOp::Hash(token_reference) => {
CodeFormatter::get_token_range(token_reference.token())
}
},
Expression::Value { value, .. } => {
let value = &**value;
match value {
Value::Function((token_ref, _)) => {
CodeFormatter::get_token_range(token_ref.token())
}
Value::FunctionCall(function_call) => {
CodeFormatter::get_range_in_prefix(function_call.prefix())
}
Value::TableConstructor(table_constructor) => CodeFormatter::get_token_range(
table_constructor.braces().tokens().0.token(),
),
Value::Number(token_ref) => CodeFormatter::get_token_range(token_ref.token()),
Value::ParseExpression(expr) => CodeFormatter::get_range_in_expression(&expr),
Value::String(token_ref) => CodeFormatter::get_token_range(token_ref.token()),
Value::Symbol(token_ref) => CodeFormatter::get_token_range(token_ref.token()),
Value::Var(var) => match var {
Var::Name(token_ref) => CodeFormatter::get_token_range(token_ref.token()),
Var::Expression(var_expr) => {
CodeFormatter::get_range_in_prefix(var_expr.prefix())
}
},
}
}
}
}
pub fn get_range_in_prefix(prefix: &Prefix) -> Range {
match prefix {
Prefix::Name(token) => CodeFormatter::get_token_range(token.token()),
Prefix::Expression(expression) => CodeFormatter::get_range_in_expression(expression),
}
}
fn get_range_in_stmt(stmt: &Stmt) -> Range {
match stmt {
Stmt::Assignment(assignment) => {
CodeFormatter::get_token_range(assignment.equal_token().token())
}
Stmt::Do(do_block) => CodeFormatter::get_token_range(do_block.do_token().token()),
Stmt::FunctionCall(function_call) => {
CodeFormatter::get_range_in_prefix(function_call.prefix())
}
Stmt::FunctionDeclaration(function_declaration) => {
CodeFormatter::get_token_range(function_declaration.function_token().token())
}
Stmt::GenericFor(generic_for) => {
CodeFormatter::get_token_range(generic_for.for_token().token())
}
Stmt::If(if_block) => CodeFormatter::get_token_range(if_block.if_token().token()),
Stmt::LocalAssignment(local_assignment) => {
CodeFormatter::get_token_range(local_assignment.local_token().token())
}
Stmt::LocalFunction(local_function) => {
CodeFormatter::get_token_range(local_function.local_token().token())
}
Stmt::NumericFor(numeric_for) => {
CodeFormatter::get_token_range(numeric_for.for_token().token())
}
Stmt::Repeat(repeat_block) => {
CodeFormatter::get_token_range(repeat_block.repeat_token().token())
}
Stmt::While(while_block) => {
CodeFormatter::get_token_range(while_block.while_token().token())
}
#[cfg(feature = "luau")]
Stmt::CompoundAssignment(compound_assignment) => {
CodeFormatter::get_range_in_expression(compound_assignment.rhs())
}
#[cfg(feature = "luau")]
Stmt::ExportedTypeDeclaration(exported_type_declaration) => {
CodeFormatter::get_token_range(exported_type_declaration.export_token().token())
}
#[cfg(feature = "luau")]
Stmt::TypeDeclaration(type_declaration) => {
CodeFormatter::get_token_range(type_declaration.type_token().token())
}
}
}
pub fn format_return<'ast>(&mut self, return_node: &Return<'ast>) -> Return<'ast> {
let (mut formatted_returns, comments_buf) =
self.format_punctuated(return_node.returns(), &CodeFormatter::format_expression);
let wanted_token: TokenReference<'ast> = if formatted_returns.is_empty() {
TokenReference::symbol("return").unwrap()
} else {
if let Some(pair) = formatted_returns.pop() {
let pair = pair.map(|expr| {
trivia_formatter::expression_add_trailing_trivia(
expr,
FormatTriviaType::Append(comments_buf),
)
});
formatted_returns.push(pair);
}
TokenReference::symbol("return ").unwrap()
};
let formatted_token = self.format_symbol(return_node.token(), &wanted_token);
Return::new()
.with_token(formatted_token)
.with_returns(formatted_returns)
}
/// Formats the final statement of a block.
///
/// `break` is formatted as the `break` symbol, `return` is delegated to
/// [`CodeFormatter::format_return`], and (with the `luau` feature) `continue`
/// is formatted against a trivia-free identifier token spelled `continue`.
pub fn format_last_stmt<'ast>(&mut self, last_stmt: LastStmt<'ast>) -> LastStmt<'ast> {
    match last_stmt {
        LastStmt::Return(return_node) => LastStmt::Return(self.format_return(&return_node)),
        LastStmt::Break(token) => LastStmt::Break(crate::fmt_symbol!(self, &token, "break")),
        #[cfg(feature = "luau")]
        LastStmt::Continue(token) => {
            // The wanted token is built by hand as an identifier here,
            // rather than through `TokenReference::symbol`.
            let continue_identifier = Token::new(TokenType::Identifier {
                identifier: Cow::Owned(String::from("continue")),
            });
            let wanted_token = TokenReference::new(vec![], continue_identifier, vec![]);
            LastStmt::Continue(self.format_symbol(&token, &wanted_token))
        }
    }
}
fn get_range_in_last_stmt<'ast>(last_stmt: &LastStmt<'ast>) -> Range {
match last_stmt {
LastStmt::Break(token_ref) => CodeFormatter::get_token_range(token_ref.token()),
LastStmt::Return(return_node) => {
CodeFormatter::get_token_range(return_node.token().token())
}
#[cfg(feature = "luau")]
LastStmt::Continue(token_ref) => CodeFormatter::get_token_range(token_ref.token()),
}
}
pub fn last_stmt_add_trivia<'ast>(
&self,
last_stmt: LastStmt<'ast>,
additional_indent_level: Option<usize>,
) -> LastStmt<'ast> {
match last_stmt {
LastStmt::Break(break_node) => {
LastStmt::Break(Cow::Owned(trivia_formatter::token_reference_add_trivia(
break_node.into_owned(),
FormatTriviaType::Append(vec![
self.create_indent_trivia(additional_indent_level)
]),
FormatTriviaType::Append(vec![self.create_newline_trivia()]),
)))
}
LastStmt::Return(return_node) => {
LastStmt::Return(self.return_add_trivia(return_node, additional_indent_level))
}
#[cfg(feature = "luau")]
LastStmt::Continue(continue_node) => {
LastStmt::Continue(Cow::Owned(trivia_formatter::token_reference_add_trivia(
continue_node.into_owned(),
FormatTriviaType::Append(vec![
self.create_indent_trivia(additional_indent_level)
]),
FormatTriviaType::Append(vec![self.create_newline_trivia()]),
)))
}
}
}
pub fn format_block<'ast>(&mut self, block: Block<'ast>) -> Block<'ast> {
let mut formatted_statements: Vec<(Stmt<'ast>, Option<Cow<'ast, TokenReference<'ast>>>)> =
Vec::new();
let mut stmt_iterator = block.iter_stmts().peekable();
while let Some(stmt) = stmt_iterator.next() {
let range_in_stmt = CodeFormatter::get_range_in_stmt(stmt);
let additional_indent_level = self.get_range_indent_increase(range_in_stmt);
let stmt = self.format_stmt(stmt);
let require_semicolon = match stmt {
Stmt::Assignment(_)
| Stmt::LocalAssignment(_)
| Stmt::FunctionCall(_)
| Stmt::Repeat(_) => {
let next_stmt = stmt_iterator.peek();
match next_stmt {
Some(next_stmt) => match next_stmt {
Stmt::FunctionCall(function_call) => match function_call.prefix() {
Prefix::Expression(expr) => match expr {
Expression::Parentheses { .. } => true,
_ => false,
},
_ => false,
},
_ => false,
},
None => false,
}
}
_ => false,
};
let mut trivia_stmt = self.stmt_add_trivia(stmt, additional_indent_level);
let semicolon = match require_semicolon {
true => {
let (updated_stmt, trivia) = trivia_util::get_stmt_trailing_trivia(trivia_stmt);
trivia_stmt = updated_stmt;
Some(Cow::Owned(trivia_formatter::token_reference_add_trivia(
TokenReference::symbol(";").expect("could not make semicolon"),
FormatTriviaType::NoChange,
FormatTriviaType::Append(trivia),
)))
}
false => None,
};
formatted_statements.push((trivia_stmt, semicolon))
}
drop(stmt_iterator);
let formatted_last_stmt = match block.last_stmt() {
Some(last_stmt) => {
let range_in_last_stmt = CodeFormatter::get_range_in_last_stmt(last_stmt);
let additional_indent_level = self.get_range_indent_increase(range_in_last_stmt);
let last_stmt = self.format_last_stmt(last_stmt.to_owned());
Some((
self.last_stmt_add_trivia(last_stmt, additional_indent_level),
None,
))
}
None => None,
};
block
.with_stmts(formatted_statements)
.with_last_stmt(formatted_last_stmt)
}
}