use full_moon::{
ast::{Ast, Block, Call, Expression, Prefix, Stmt, Suffix},
node::Node,
tokenizer::{TokenReference, TokenType},
};
use crate::{
context::{Context, FormatNode},
formatters::trivia::{FormatTriviaType, UpdateLeadingTrivia},
};
/// Returns the text of `token` as an owned `String` when the token is an
/// identifier, and `None` for every other token type.
fn extract_identifier_from_token(token: &TokenReference) -> Option<String> {
    if let TokenType::Identifier { identifier } = token.token_type() {
        Some(identifier.to_string())
    } else {
        None
    }
}
/// Classifies the right-hand side of an assignment as a sortable call.
///
/// Returns:
/// * `Some(GroupKind::Require)` for a function call whose prefix is the bare
///   name `require`;
/// * `Some(GroupKind::GetService)` for a function call whose prefix is the
///   bare name `game` and whose first suffix is the method call `:GetService`;
/// * `None` for everything else.
///
/// With the `luau` feature enabled, a type assertion (`expr :: T`) is
/// classified by its inner expression.
fn get_expression_kind(expression: &Expression) -> Option<GroupKind> {
    match expression {
        Expression::FunctionCall(call) => {
            // Only calls rooted at a plain name are candidates; parenthesised
            // prefixes such as `('game')` are rejected here.
            let callee = match call.prefix() {
                Prefix::Name(token) => extract_identifier_from_token(token)?,
                _ => return None,
            };

            match callee.as_str() {
                "require" => Some(GroupKind::Require),
                // `game` only counts when it is immediately followed by `:GetService`.
                "game" => match call.suffixes().next() {
                    Some(Suffix::Call(Call::MethodCall(method))) => {
                        let method_name = extract_identifier_from_token(method.name())?;
                        (method_name == "GetService").then_some(GroupKind::GetService)
                    }
                    _ => None,
                },
                _ => None,
            }
        }
        #[cfg(feature = "luau")]
        Expression::TypeAssertion { expression, .. } => get_expression_kind(expression),
        _ => None,
    }
}
/// A statement paired with an optional trailing token, in the shape yielded by
/// `Block::stmts_with_semicolon` (presumably the statement's `;`, if any).
type StmtSemicolon = (Stmt, Option<TokenReference>);
/// The kind of sortable statement a group is made of. Statements of different
/// kinds are never placed in the same group.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum GroupKind {
/// A call whose prefix is the bare name `require`.
Require,
/// A `game:GetService(...)` call.
GetService,
}
/// One contiguous slice of a block's statements, as produced by
/// `partition_nodes_into_groups`.
enum BlockPartition {
/// A run of sortable statements of a single `GroupKind`. Each entry pairs the
/// assigned variable name (used as the sort key) with the statement itself.
RequiresGroup(GroupKind, Vec<(String, StmtSemicolon)>),
/// A run of statements that are not sortable; kept in their original order.
Other(Vec<StmtSemicolon>),
}
/// Splits the statements of `block` into alternating runs of sortable
/// statements (`BlockPartition::RequiresGroup`) and everything else
/// (`BlockPartition::Other`), preserving source order.
///
/// A statement is sortable when it is a local assignment with exactly one
/// name and one expression, and `get_expression_kind` recognises that
/// expression. A sortable statement joins the group at the end of the list
/// only if that group has the same kind and the statement's name starts at
/// most one line after the end of the group's last statement; otherwise it
/// opens a new group.
///
/// # Panics
///
/// Panics if a candidate statement's name token has no start position, if a
/// sortable statement's name is not an identifier token, or if the previous
/// statement in a group has no end position.
fn partition_nodes_into_groups(block: &Block) -> Vec<BlockPartition> {
    let mut parts: Vec<BlockPartition> = Vec::new();

    for stmt in block.stmts_with_semicolon() {
        // Work out whether this statement is sortable, and if so capture what
        // is needed to place it: the name token, its kind and its line.
        let sortable = match &stmt.0 {
            Stmt::LocalAssignment(assignment)
                if assignment.names().len() == 1 && assignment.expressions().len() == 1 =>
            {
                let name_token = assignment.names().iter().next().unwrap();
                let value = assignment.expressions().iter().next().unwrap();
                let line = name_token.start_position().unwrap().line();
                get_expression_kind(value).map(|kind| (name_token, kind, line))
            }
            _ => None,
        };

        if let Some((name_token, kind, line)) = sortable {
            let variable_name = extract_identifier_from_token(name_token)
                .expect("require is stored as non-identifier");

            // The statement extends the trailing group only when that group is
            // of the same kind and no blank/comment line separates the two.
            let continues_group = match parts.last() {
                Some(BlockPartition::RequiresGroup(previous_kind, entries))
                    if *previous_kind == kind =>
                {
                    let previous_entry =
                        entries.last().expect("unreachable!: empty require group");
                    let previous_line = previous_entry
                        .1
                        .end_position()
                        .expect("unreachable!: previous require stmt has no end position")
                        .line();
                    line - previous_line <= 1
                }
                _ => false,
            };

            if !continues_group {
                parts.push(BlockPartition::RequiresGroup(kind, Vec::new()));
            }

            match parts.last_mut() {
                Some(BlockPartition::RequiresGroup(_, entries)) => {
                    entries.push((variable_name, stmt.clone()));
                }
                _ => unreachable!(),
            }
            continue;
        }

        // Not sortable: append to the trailing `Other` run, opening one first
        // if the list is empty or currently ends with a sortable group.
        if !matches!(parts.last(), Some(BlockPartition::Other(_))) {
            parts.push(BlockPartition::Other(Vec::new()));
        }

        match parts.last_mut() {
            Some(BlockPartition::Other(others)) => others.push(stmt.clone()),
            _ => unreachable!(),
        }
    }

    parts
}
/// Sorts each run of sortable `local` statements in the top-level block of
/// `input_ast` by variable name and returns the resulting AST.
///
/// Runs are discovered by `partition_nodes_into_groups`. Within a run the
/// statements are ordered by `String` comparison of the assigned variable
/// name (case-sensitive); the sort is stable, so equal names keep their
/// relative order. Statements outside a run are left exactly where they were.
///
/// The leading trivia of the `local` token of the run's original first
/// statement is detached before sorting and re-attached to whichever statement
/// ends up first, replacing that statement's own leading trivia.
///
/// A run is left untouched if `ctx.should_format_node` reports anything other
/// than `FormatNode::Normal` for any of its statements. If the block contains
/// no sortable run at all, `input_ast` is returned unchanged.
pub(crate) fn sort_requires(ctx: &Context, input_ast: Ast) -> Ast {
    let block = input_ast.nodes();
    let parts = partition_nodes_into_groups(block);

    // Nothing to reorder unless at least one sortable group exists. The
    // partitioner never emits two adjacent `Other` parts, so this covers both
    // the empty list and the "single `Other`" case.
    let has_sortable_group = parts
        .iter()
        .any(|part| matches!(part, BlockPartition::RequiresGroup(..)));
    if !has_sortable_group {
        return input_ast;
    }

    let mut stmts: Vec<StmtSemicolon> = Vec::new();

    for part in parts {
        match part {
            BlockPartition::RequiresGroup(_, mut list) => {
                // Only reorder when every statement in the group may be formatted.
                let may_reorder = list
                    .iter()
                    .all(|(_, stmt)| matches!(ctx.should_format_node(stmt), FormatNode::Normal));

                if may_reorder {
                    // Detach the leading trivia from the current head so it does
                    // not travel with that statement when the group is sorted.
                    let leading_trivia: Vec<_> = match list.first_mut() {
                        Some((_, (Stmt::LocalAssignment(local_assignment), _))) => {
                            let trivia = local_assignment
                                .local_token()
                                .leading_trivia()
                                .cloned()
                                .collect();
                            *local_assignment = local_assignment
                                .update_leading_trivia(FormatTriviaType::Replace(vec![]));
                            trivia
                        }
                        _ => unreachable!(),
                    };

                    // Stable sort on borrowed keys: same ordering as sorting by a
                    // cloned key, without allocating a `String` per comparison.
                    list.sort_by(|a, b| a.0.cmp(&b.0));

                    // Re-attach the saved trivia to the new head of the group.
                    match list.first_mut() {
                        Some((_, (Stmt::LocalAssignment(local_assignment), _))) => {
                            *local_assignment = local_assignment
                                .update_leading_trivia(FormatTriviaType::Replace(leading_trivia))
                        }
                        _ => unreachable!(),
                    };
                }

                // The group is owned, so move the statements out instead of cloning.
                stmts.extend(list.into_iter().map(|(_, stmt)| stmt));
            }
            BlockPartition::Other(list) => stmts.extend(list),
        }
    }

    let block = block.clone().with_stmts(stmts);
    input_ast.with_nodes(block).update_positions()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source`, takes the first expression of its first statement
    /// (which must be a local assignment) and classifies it.
    fn classify_first_local(source: &str) -> Option<GroupKind> {
        let ast = full_moon::parse(source).unwrap();
        let first_stmt = ast.nodes().stmts().next().unwrap();
        let Stmt::LocalAssignment(assignment) = first_stmt else {
            unreachable!()
        };
        let expression = assignment.expressions().iter().next().unwrap();
        get_expression_kind(expression)
    }

    #[test]
    fn fail_extracting_non_identifier_token() {
        // A keyword symbol is not an identifier token.
        let keyword = TokenReference::symbol("return").unwrap();
        assert!(extract_identifier_from_token(&keyword).is_none());
    }

    #[test]
    fn get_expression_kind_require_stmt() {
        assert_eq!(
            classify_first_local("local NAME = require(EXPR)"),
            Some(GroupKind::Require)
        );
    }

    #[test]
    fn get_expression_kind_get_service() {
        assert_eq!(
            classify_first_local("local NAME = game:GetService(EXPR)"),
            Some(GroupKind::GetService)
        );
    }

    #[test]
    fn get_expression_kind_other_1() {
        // A plain variable reference.
        assert_eq!(classify_first_local("local NAME = testing"), None);
    }

    #[test]
    fn get_expression_kind_other_2() {
        // `game` on its own is not a call.
        assert_eq!(classify_first_local("local NAME = game"), None);
    }

    #[test]
    fn get_expression_kind_other_3() {
        // A method call on `game` other than `GetService`.
        assert_eq!(
            classify_first_local("local NAME = game:FindFirstChild(EXPR)"),
            None
        );
    }

    #[test]
    fn get_expression_kind_other_4() {
        // Field access on `game`.
        assert_eq!(classify_first_local("local NAME = game.Name"), None);
    }

    #[test]
    fn get_expression_kind_other_5() {
        // The prefix is a parenthesised string, not the bare name `game`.
        assert_eq!(
            classify_first_local("local NAME = ('game'):GetService(EXPR)"),
            None
        );
    }

    #[test]
    fn get_expression_kind_other_6() {
        // A binary expression.
        assert_eq!(classify_first_local("local NAME = 1 + 2"), None);
    }

    #[test]
    #[cfg(feature = "luau")]
    fn get_expression_kind_require_type_assertion() {
        // The type assertion is looked through to the inner `require` call.
        assert_eq!(
            classify_first_local("local NAME = require(path) :: any"),
            Some(GroupKind::Require)
        );
    }
}