use full_moon::{
ast::{
punctuated::{Pair, Punctuated},
span::ContainedSpan,
Ast, Block, Expression, FunctionArgs, TableConstructor, Value,
},
node::Node,
tokenizer::{StringLiteralQuoteType, Token, TokenReference, TokenType},
visitors::VisitorMut,
};
/// Checks whether two Lua ASTs are equivalent once purely syntactic differences
/// are normalised away.
///
/// Each tree is rewritten by this type's [`VisitorMut`] implementation before
/// comparison:
/// - statement semicolons are dropped;
/// - table separators are unified;
/// - parentheses and string/table call-argument sugar are removed;
/// - number and string literal spellings are canonicalised.
///
/// The two rewritten trees are then compared with `Node::similar`.
#[derive(Debug, Default)]
pub struct AstVerifier {}

impl AstVerifier {
    /// Creates a new verifier. Equivalent to [`AstVerifier::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` if `input_ast` and `reparsed_output` are similar after both
    /// have been normalised by this verifier's visitor.
    ///
    /// Both trees are consumed; the normalised copies are only used for the
    /// comparison and then dropped.
    pub fn compare(&mut self, input_ast: Ast, reparsed_output: Ast) -> bool {
        let massaged_input = self.visit_ast(input_ast);
        let massaged_output = self.visit_ast(reparsed_output);
        massaged_input.similar(&massaged_output)
    }
}
/// Strips every layer of parentheses wrapped directly around `expression`.
///
/// Two parenthesis representations are handled:
/// - `Expression::Parentheses`;
/// - an `Expression::Value` whose value is a `Value::ParenthesesExpression`.
///
/// Both are unwrapped recursively, so `x`, `(x)` and `((x))` all normalise to the
/// same node. Any other expression is returned unchanged.
///
/// NOTE(review): under the `luau` feature a rebuilt `Expression::Value` always
/// gets `type_assertion: None`, so type assertions on values are discarded and
/// are not part of the comparison — confirm this is intended.
fn remove_parentheses(expression: Expression) -> Expression {
    match expression {
        // Recurse (rather than unwrapping a single layer) so nested wrappers such
        // as `((x))` collapse completely, matching the recursive handling of
        // `Value::ParenthesesExpression` below.
        Expression::Parentheses { expression, .. } => remove_parentheses(*expression),
        Expression::Value { value, .. } => Expression::Value {
            value: match *value {
                Value::ParenthesesExpression(expression) => return remove_parentheses(expression),
                _ => value,
            },
            #[cfg(feature = "luau")]
            type_assertion: None,
        },
        _ => expression,
    }
}
impl VisitorMut for AstVerifier {
fn visit_block(&mut self, node: Block) -> Block {
let stmts = node
.stmts_with_semicolon()
.map(|(stmt, _semicolon)| (stmt.to_owned(), None))
.collect();
let last_stmt = node
.last_stmt_with_semicolon()
.map(|(last_stmt, _semicolon)| (last_stmt.to_owned(), None));
node.with_stmts(stmts).with_last_stmt(last_stmt)
}
fn visit_table_constructor(&mut self, node: TableConstructor) -> TableConstructor {
let current_fields = node.fields();
let mut fields = Punctuated::new();
for field in current_fields.to_owned().into_pairs() {
let pair = match field {
Pair::Punctuated(field, _) | Pair::End(field) => {
Pair::Punctuated(field, TokenReference::symbol(",").unwrap())
}
};
fields.push(pair)
}
node.with_fields(fields)
}
fn visit_function_args(
&mut self,
node: full_moon::ast::FunctionArgs,
) -> full_moon::ast::FunctionArgs {
match node {
FunctionArgs::String(string) => FunctionArgs::Parentheses {
parentheses: ContainedSpan::new(
TokenReference::symbol("(").unwrap(),
TokenReference::symbol(")").unwrap(),
),
arguments: std::iter::once(Pair::End(Expression::Value {
value: Box::new(Value::String(string)),
#[cfg(feature = "luau")]
type_assertion: None,
}))
.collect(),
},
FunctionArgs::TableConstructor(table) => FunctionArgs::Parentheses {
parentheses: ContainedSpan::new(
TokenReference::symbol("(").unwrap(),
TokenReference::symbol(")").unwrap(),
),
arguments: std::iter::once(Pair::End(Expression::Value {
value: Box::new(Value::TableConstructor(table)),
#[cfg(feature = "luau")]
type_assertion: None,
}))
.collect(),
},
_ => node,
}
}
fn visit_expression(&mut self, node: Expression) -> Expression {
remove_parentheses(node)
}
fn visit_number(&mut self, token: Token) -> Token {
let token_type = match token.token_type() {
TokenType::Number { text } => {
#[cfg(feature = "luau")]
let text = text.replace('_', "");
let number = match text.as_str().parse::<f64>() {
Ok(num) => num,
Err(_) => match i32::from_str_radix(&text.as_str()[2..], 16) {
Ok(num) => num.into(),
#[cfg(feature = "luau")]
Err(_) => match i32::from_str_radix(&text.as_str()[2..], 2) {
Ok(num) => num.into(),
Err(_) => unreachable!(),
},
#[cfg(not(feature = "luau"))]
Err(_) => unreachable!(),
},
};
TokenType::Number {
text: number.to_string().into(),
}
}
_ => unreachable!(),
};
Token::new(token_type)
}
fn visit_string_literal(&mut self, token: Token) -> Token {
let token_type = match token.token_type() {
TokenType::StringLiteral {
literal,
multi_line,
..
} => TokenType::StringLiteral {
literal: literal.to_owned().replace('\\', "").into(),
multi_line: multi_line.to_owned(),
quote_type: StringLiteralQuoteType::Brackets,
},
_ => unreachable!(),
};
Token::new(token_type)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses both snippets (panicking on a parse error) and reports whether a
    /// fresh `AstVerifier` considers the resulting trees equivalent.
    fn equivalent(input: &str, output: &str) -> bool {
        let input_ast = full_moon::parse(input).unwrap();
        let output_ast = full_moon::parse(output).unwrap();
        AstVerifier::new().compare(input_ast, output_ast)
    }

    #[test]
    fn test_equivalent_asts() {
        assert!(equivalent("local x = 1", "local x = 1"));
    }

    #[test]
    fn test_different_asts() {
        assert!(!equivalent("local x = 1", "local x = 2"));
    }

    #[test]
    fn test_equivalent_stmt_semicolons() {
        assert!(equivalent("local x = 1;", "local x = 1"));
    }

    #[test]
    fn test_equivalent_string_quote_types() {
        assert!(equivalent("local x = '1'", "local x = \"1\""));
    }

    #[test]
    fn test_equivalent_string_escapes() {
        assert!(equivalent("local x = '\\q'", "local x = 'q'"));
    }

    #[test]
    fn test_equivalent_numbers() {
        assert!(equivalent("local x = .1", "local x = 0.1"));
    }

    #[test]
    fn test_equivalent_numbers_2() {
        assert!(equivalent("local x = -.1", "local x = -0.1"));
    }

    #[test]
    fn test_equivalent_hex_numbers() {
        assert!(equivalent("local x = 0XFFFF", "local x = 0xFFFF"));
    }

    #[test]
    fn test_different_hex_numbers() {
        assert!(!equivalent("local x = 0xFFAA", "local x = 0xFFFF"));
    }

    #[test]
    #[cfg(feature = "luau")]
    fn test_equivalent_binary_numbers() {
        assert!(equivalent("local x = 0B10101", "local x = 0b10101"));
    }

    #[test]
    #[cfg(feature = "luau")]
    fn test_different_binary_numbers() {
        assert!(!equivalent("local x = 0b1111", "local x = 0b1110"));
    }

    #[test]
    fn test_equivalent_table_separators() {
        assert!(equivalent(
            "local x = {'a'; 'b'; 'c';}",
            "local x = {'a', 'b', 'c'}"
        ));
    }

    #[test]
    fn test_equivalent_function_calls() {
        assert!(equivalent("local x = call'foo'", "local x = call('foo')"));
    }

    #[test]
    fn test_equivalent_function_calls_2() {
        assert!(equivalent("local x = call{'foo'}", "local x = call({'foo'})"));
    }

    #[test]
    fn test_equivalent_conditions() {
        assert!(equivalent("if (true) then return end", "if true then return end"));
    }
}