use crate::formatters::{
trivia_formatter::{self, FormatTriviaType},
CodeFormatter, Range,
};
use full_moon::ast::{
punctuated::Pair, Block, Expression, LastStmt, Prefix, Return, Stmt, UnOp, Value, Var,
};
#[cfg(feature = "luau")]
use full_moon::tokenizer::TokenType;
use full_moon::tokenizer::{Token, TokenReference};
use std::borrow::Cow;
impl CodeFormatter {
pub fn get_token_range<'ast>(token: &Token<'ast>) -> Range {
(token.start_position().bytes(), token.end_position().bytes())
}
pub fn get_range_in_expression<'ast>(expression: &Expression<'ast>) -> Range {
match expression {
Expression::Parentheses { contained, .. } => {
CodeFormatter::get_token_range(contained.tokens().0)
}
Expression::UnaryOperator { unop, .. } => match unop {
UnOp::Minus(token_reference) => {
CodeFormatter::get_token_range(token_reference.token())
}
UnOp::Not(token_reference) => {
CodeFormatter::get_token_range(token_reference.token())
}
UnOp::Hash(token_reference) => {
CodeFormatter::get_token_range(token_reference.token())
}
},
Expression::Value { value, .. } => {
let value = &**value;
match value {
Value::Function((token_ref, _)) => {
CodeFormatter::get_token_range(token_ref.token())
}
Value::FunctionCall(function_call) => {
CodeFormatter::get_range_in_prefix(function_call.prefix())
}
Value::TableConstructor(table_constructor) => CodeFormatter::get_token_range(
table_constructor.braces().tokens().0.token(),
),
Value::Number(token_ref) => CodeFormatter::get_token_range(token_ref.token()),
Value::ParseExpression(expr) => CodeFormatter::get_range_in_expression(&expr),
Value::String(token_ref) => CodeFormatter::get_token_range(token_ref.token()),
Value::Symbol(token_ref) => CodeFormatter::get_token_range(token_ref.token()),
Value::Var(var) => match var {
Var::Name(token_ref) => CodeFormatter::get_token_range(token_ref.token()),
Var::Expression(var_expr) => {
CodeFormatter::get_range_in_prefix(var_expr.prefix())
}
},
}
}
}
}
pub fn get_range_in_prefix(prefix: &Prefix) -> Range {
match prefix {
Prefix::Name(token) => CodeFormatter::get_token_range(token.token()),
Prefix::Expression(expression) => CodeFormatter::get_range_in_expression(expression),
}
}
fn get_range_in_stmt(stmt: Stmt) -> Range {
match stmt {
Stmt::Assignment(assignment) => {
CodeFormatter::get_token_range(assignment.equal_token().token())
}
Stmt::Do(do_block) => CodeFormatter::get_token_range(do_block.do_token().token()),
Stmt::FunctionCall(function_call) => {
CodeFormatter::get_range_in_prefix(function_call.prefix())
}
Stmt::FunctionDeclaration(function_declaration) => {
CodeFormatter::get_token_range(function_declaration.function_token().token())
}
Stmt::GenericFor(generic_for) => {
CodeFormatter::get_token_range(generic_for.for_token().token())
}
Stmt::If(if_block) => CodeFormatter::get_token_range(if_block.if_token().token()),
Stmt::LocalAssignment(local_assignment) => {
CodeFormatter::get_token_range(local_assignment.local_token().token())
}
Stmt::LocalFunction(local_function) => {
CodeFormatter::get_token_range(local_function.local_token().token())
}
Stmt::NumericFor(numeric_for) => {
CodeFormatter::get_token_range(numeric_for.for_token().token())
}
Stmt::Repeat(repeat_block) => {
CodeFormatter::get_token_range(repeat_block.repeat_token().token())
}
Stmt::While(while_block) => {
CodeFormatter::get_token_range(while_block.while_token().token())
}
#[cfg(feature = "luau")]
Stmt::CompoundAssignment(compound_assignment) => {
CodeFormatter::get_range_in_expression(compound_assignment.rhs())
}
#[cfg(feature = "luau")]
Stmt::ExportedTypeDeclaration(exported_type_declaration) => {
CodeFormatter::get_token_range(exported_type_declaration.export_token().token())
}
#[cfg(feature = "luau")]
Stmt::TypeDeclaration(type_declaration) => {
CodeFormatter::get_token_range(type_declaration.type_token().token())
}
}
}
pub fn format_return<'ast>(&mut self, return_node: Return<'ast>) -> Return<'ast> {
let formatted_returns = self.format_punctuated(
return_node.returns().to_owned(),
&CodeFormatter::format_expression,
);
let wanted_token: TokenReference<'ast> = if formatted_returns.is_empty() {
TokenReference::symbol("return").unwrap()
} else {
TokenReference::symbol("return ").unwrap()
};
let formatted_token = self.format_symbol(return_node.token().to_owned(), wanted_token);
return_node
.with_token(formatted_token)
.with_returns(formatted_returns)
}
pub fn format_last_stmt<'ast>(&mut self, last_stmt: LastStmt<'ast>) -> LastStmt<'ast> {
match last_stmt {
LastStmt::Break(token) => LastStmt::Break(
self.format_symbol(token.into_owned(), TokenReference::symbol("break").unwrap()),
),
LastStmt::Return(return_node) => LastStmt::Return(self.format_return(return_node)),
#[cfg(feature = "luau")]
LastStmt::Continue(token) => LastStmt::Continue(self.format_symbol(
token.into_owned(),
TokenReference::new(
vec![],
Token::new(TokenType::Identifier {
identifier: Cow::Owned(String::from("continue")),
}),
vec![],
),
)),
}
}
fn get_range_in_last_stmt<'ast>(last_stmt: &LastStmt<'ast>) -> Range {
match last_stmt {
LastStmt::Break(token_ref) => CodeFormatter::get_token_range(token_ref.token()),
LastStmt::Return(return_node) => {
CodeFormatter::get_token_range(return_node.token().token())
}
#[cfg(feature = "luau")]
LastStmt::Continue(token_ref) => CodeFormatter::get_token_range(token_ref.token()),
}
}
pub fn last_stmt_add_trivia<'ast>(
&self,
last_stmt: LastStmt<'ast>,
additional_indent_level: Option<usize>,
) -> LastStmt<'ast> {
match last_stmt {
LastStmt::Break(break_node) => {
LastStmt::Break(Cow::Owned(trivia_formatter::token_reference_add_trivia(
break_node.into_owned(),
FormatTriviaType::Append(vec![
self.create_indent_trivia(additional_indent_level)
]),
FormatTriviaType::Append(vec![self.create_newline_trivia()]),
)))
}
LastStmt::Return(return_node) => {
let mut token = return_node.token().to_owned();
let mut returns = return_node.returns().to_owned();
if return_node.returns().is_empty() {
token = trivia_formatter::token_reference_add_trivia(
token,
FormatTriviaType::Append(vec![
self.create_indent_trivia(additional_indent_level)
]),
FormatTriviaType::Append(vec![self.create_newline_trivia()]),
);
} else {
token = trivia_formatter::token_reference_add_trivia(
token,
FormatTriviaType::Append(vec![
self.create_indent_trivia(additional_indent_level)
]),
FormatTriviaType::NoChange,
);
if let Some(last_pair) = returns.pop() {
match last_pair {
Pair::End(value) => {
let expression = trivia_formatter::expression_add_trailing_trivia(
value,
FormatTriviaType::Append(vec![self.create_newline_trivia()]),
);
returns.push(Pair::End(expression));
}
Pair::Punctuated(_, _) => (), }
}
}
LastStmt::Return(
return_node
.with_token(Cow::Owned(token))
.with_returns(returns),
)
}
#[cfg(feature = "luau")]
LastStmt::Continue(continue_node) => {
LastStmt::Continue(Cow::Owned(trivia_formatter::token_reference_add_trivia(
continue_node.into_owned(),
FormatTriviaType::Append(vec![
self.create_indent_trivia(additional_indent_level)
]),
FormatTriviaType::Append(vec![self.create_newline_trivia()]),
)))
}
}
}
pub fn format_block<'ast>(&mut self, block: Block<'ast>) -> Block<'ast> {
let formatted_statements: Vec<(Stmt<'ast>, Option<Cow<'ast, TokenReference<'ast>>>)> =
block
.iter_stmts()
.map(|stmt| {
let range_in_stmt = CodeFormatter::get_range_in_stmt(stmt.to_owned());
let additional_indent_level = self.get_range_indent_increase(range_in_stmt);
let stmt = self.format_stmt(stmt.to_owned());
(
self.stmt_add_trivia(stmt, additional_indent_level),
None, )
})
.collect();
let formatted_last_stmt = match block.last_stmt() {
Some(last_stmt) => {
let range_in_last_stmt = CodeFormatter::get_range_in_last_stmt(last_stmt);
let additional_indent_level = self.get_range_indent_increase(range_in_last_stmt);
let last_stmt = self.format_last_stmt(last_stmt.to_owned());
Some((
self.last_stmt_add_trivia(last_stmt, additional_indent_level),
None,
))
}
None => None,
};
block
.with_stmts(formatted_statements)
.with_last_stmt(formatted_last_stmt)
}
}