use full_moon::{
ast::{
punctuated::{Pair, Punctuated},
span::ContainedSpan,
Ast, Block, Expression, FunctionArgs, TableConstructor,
},
node::Node,
tokenizer::{StringLiteralQuoteType, Token, TokenReference, TokenType},
visitors::VisitorMut,
};
#[cfg(feature = "luau")]
use full_moon::ast::luau::TypeInfo;
/// Decides whether two ASTs are equivalent once formatting-only differences are normalised away.
///
/// The normalisation (applied through the `VisitorMut` implementation for this type) ignores
/// statement semicolons, table field separators, call-argument sugar (`f"x"` / `f{...}`),
/// redundant parentheses, the spelling of number literals and string quoting/escaping.
#[derive(Debug, Default)]
pub struct AstVerifier {}

impl AstVerifier {
    /// Creates a new verifier; equivalent to [`AstVerifier::default`].
    pub fn new() -> Self {
        Self {}
    }

    /// Returns `true` if `input_ast` and `reparsed_output` are equivalent after both have been
    /// normalised by this verifier.
    ///
    /// Both trees are consumed. The normalised trees are compared with [`Node::similar`].
    pub fn compare(&mut self, input_ast: Ast, reparsed_output: Ast) -> bool {
        let massaged_input = self.visit_ast(input_ast);
        let massaged_output = self.visit_ast(reparsed_output);
        massaged_input.similar(&massaged_output)
    }
}
/// Strips every layer of parentheses around `expression`, so `x`, `(x)` and `((x))` all
/// normalise to the same node.
///
/// Stripping only the outermost layer is not enough: the node returned here is not passed back
/// through `visit_expression`, so `((x))` would otherwise survive as `(x)`.
fn remove_parentheses(expression: Expression) -> Expression {
    match expression {
        // Recurse: the nesting depth of literal parentheses in source code is tiny.
        Expression::Parentheses { expression, .. } => remove_parentheses(*expression),
        _ => expression,
    }
}
#[cfg(feature = "luau")]
/// Unwraps single-element type tuples, i.e. parenthesised types, at any nesting depth, so
/// `number`, `(number)` and `((number))` all normalise to `number`.
///
/// Tuples with zero or several elements are returned unchanged.
fn remove_type_parentheses(type_info: TypeInfo) -> TypeInfo {
    match type_info {
        TypeInfo::Tuple { ref types, .. } if types.len() == 1 => {
            let inner = types
                .into_iter()
                .next()
                .expect("tuple was just checked to have exactly one element")
                .clone();
            // Keep unwrapping in case the inner type is itself parenthesised.
            remove_type_parentheses(inner)
        }
        _ => type_info,
    }
}
impl VisitorMut for AstVerifier {
fn visit_block(&mut self, node: Block) -> Block {
let stmts = node
.stmts_with_semicolon()
.map(|(stmt, _semicolon)| (stmt.to_owned(), None))
.collect();
let last_stmt = node
.last_stmt_with_semicolon()
.map(|(last_stmt, _semicolon)| (last_stmt.to_owned(), None));
node.with_stmts(stmts).with_last_stmt(last_stmt)
}
fn visit_table_constructor(&mut self, node: TableConstructor) -> TableConstructor {
let current_fields = node.fields();
let mut fields = Punctuated::new();
for field in current_fields.to_owned().into_pairs() {
let pair = match field {
Pair::Punctuated(field, _) | Pair::End(field) => {
Pair::Punctuated(field, TokenReference::symbol(",").unwrap())
}
};
fields.push(pair)
}
node.with_fields(fields)
}
fn visit_function_args(
&mut self,
node: full_moon::ast::FunctionArgs,
) -> full_moon::ast::FunctionArgs {
match node {
FunctionArgs::String(string) => FunctionArgs::Parentheses {
parentheses: ContainedSpan::new(
TokenReference::symbol("(").unwrap(),
TokenReference::symbol(")").unwrap(),
),
arguments: std::iter::once(Pair::End(Expression::String(string))).collect(),
},
FunctionArgs::TableConstructor(table) => FunctionArgs::Parentheses {
parentheses: ContainedSpan::new(
TokenReference::symbol("(").unwrap(),
TokenReference::symbol(")").unwrap(),
),
arguments: std::iter::once(Pair::End(Expression::TableConstructor(table)))
.collect(),
},
_ => node,
}
}
fn visit_expression(&mut self, node: Expression) -> Expression {
remove_parentheses(node)
}
fn visit_number(&mut self, token: Token) -> Token {
let token_type = match token.token_type() {
TokenType::Number { text } => {
#[cfg(feature = "luau")]
let text = text.replace('_', "");
#[cfg(feature = "luajit")]
let text = text
.trim_end_matches("ULL")
.trim_end_matches("LL")
.to_string();
let number = match text.as_str().parse::<f64>() {
Ok(num) => num.to_string(),
Err(_) => match i64::from_str_radix(&text.as_str()[2..], 16) {
Ok(num) => num.to_string(),
#[cfg(feature = "luau")]
Err(_) => match i64::from_str_radix(&text.as_str()[2..], 2) {
Ok(num) => num.to_string(),
Err(_) => unreachable!(),
},
#[cfg(not(feature = "luau"))]
Err(_) => unreachable!(),
},
};
TokenType::Number {
text: number.into(),
}
}
_ => unreachable!(),
};
Token::new(token_type)
}
fn visit_string_literal(&mut self, token: Token) -> Token {
let token_type = match token.token_type() {
TokenType::StringLiteral {
literal,
multi_line_depth,
..
} => TokenType::StringLiteral {
literal: literal.to_owned().replace('\\', "").into(),
multi_line_depth: multi_line_depth.to_owned(),
quote_type: StringLiteralQuoteType::Brackets,
},
_ => unreachable!(),
};
Token::new(token_type)
}
#[cfg(feature = "luau")]
fn visit_type_info(&mut self, type_info: TypeInfo) -> TypeInfo {
remove_type_parentheses(type_info)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses both sources with the default parser and asks a fresh verifier whether the
    /// resulting ASTs are equivalent.
    fn asts_match(input: &str, output: &str) -> bool {
        let input_ast = full_moon::parse(input).unwrap();
        let output_ast = full_moon::parse(output).unwrap();
        AstVerifier::new().compare(input_ast, output_ast)
    }

    /// Like `asts_match`, but parses with `parse_fallible` under the Lua version built by
    /// `version`.
    #[cfg(any(feature = "luau", feature = "luajit"))]
    fn asts_match_with(version: fn() -> full_moon::LuaVersion, input: &str, output: &str) -> bool {
        let input_ast = full_moon::parse_fallible(input, version())
            .into_result()
            .unwrap();
        let output_ast = full_moon::parse_fallible(output, version())
            .into_result()
            .unwrap();
        AstVerifier::new().compare(input_ast, output_ast)
    }

    #[test]
    fn test_equivalent_asts() {
        assert!(asts_match("local x = 1", "local x = 1"));
    }

    #[test]
    fn test_different_asts() {
        assert!(!asts_match("local x = 1", "local x = 2"));
    }

    #[test]
    fn test_equivalent_stmt_semicolons() {
        assert!(asts_match("local x = 1;", "local x = 1"));
    }

    #[test]
    fn test_equivalent_string_quote_types() {
        assert!(asts_match("local x = '1'", "local x = \"1\""));
    }

    #[test]
    fn test_equivalent_string_escapes() {
        assert!(asts_match("local x = '\\q'", "local x = 'q'"));
    }

    #[test]
    fn test_equivalent_numbers() {
        assert!(asts_match("local x = .1", "local x = 0.1"));
    }

    #[test]
    fn test_equivalent_numbers_2() {
        assert!(asts_match("local x = -.1", "local x = -0.1"));
    }

    #[test]
    fn test_equivalent_hex_numbers() {
        assert!(asts_match("local x = 0XFFFF", "local x = 0xFFFF"));
    }

    #[test]
    #[cfg(feature = "luau")]
    fn test_equivalent_hex_numbers_with_separators() {
        assert!(asts_match("local x = 0xffff_ffc0", "local x = 0xffffffc0"));
    }

    #[test]
    fn test_equivalent_hex_numbers_very_large_number() {
        assert!(asts_match(
            "max = max or 0xffffffff",
            "max = max or 0xffffffff"
        ));
    }

    #[test]
    fn test_different_hex_numbers() {
        assert!(!asts_match("local x = 0xFFAA", "local x = 0xFFFF"));
    }

    #[test]
    #[cfg(feature = "luau")]
    fn test_equivalent_binary_numbers() {
        assert!(asts_match_with(
            full_moon::LuaVersion::luau,
            "local x = 0B10101",
            "local x = 0b10101"
        ));
    }

    #[test]
    #[cfg(feature = "luau")]
    fn test_different_binary_numbers() {
        assert!(!asts_match_with(
            full_moon::LuaVersion::luau,
            "local x = 0b1111",
            "local x = 0b1110"
        ));
    }

    #[test]
    #[cfg(feature = "luajit")]
    fn test_equivalent_luajit_numbers() {
        assert!(asts_match_with(
            full_moon::LuaVersion::luajit,
            "local x = 2 ^ 63LL",
            "local x = 2 ^ 63"
        ));
    }

    #[test]
    fn test_equivalent_table_separators() {
        assert!(asts_match(
            "local x = {'a'; 'b'; 'c';}",
            "local x = {'a', 'b', 'c'}"
        ));
    }

    #[test]
    fn test_equivalent_function_calls() {
        assert!(asts_match("local x = call'foo'", "local x = call('foo')"));
    }

    #[test]
    fn test_equivalent_function_calls_2() {
        assert!(asts_match("local x = call{'foo'}", "local x = call({'foo'})"));
    }

    #[test]
    fn test_equivalent_conditions() {
        assert!(asts_match(
            "if (true) then return end",
            "if true then return end"
        ));
    }

    #[test]
    #[cfg(feature = "luau")]
    fn test_equivalent_types_removed_parentheses() {
        assert!(asts_match("type Foo = (number)", "type Foo = number"));
    }
}