use super::super::common::{
AstBinaryExpr, AstBinaryOpKind, AstBlock, AstCallKind, AstExpr, AstFunctionExpr, AstLValue,
AstLogicalExpr, AstModule, AstStmt, AstTableField, AstTableKey, AstUnaryExpr, AstUnaryOpKind,
};
use super::ReadabilityContext;
use crate::hir::{
HirBinaryExpr, HirBinaryOpKind, HirExpr, HirLogicalExpr, HirUnaryExpr, HirUnaryOpKind,
synthesize_readable_pure_logical_expr,
};
/// Entry point of this readability pass: walks the whole module body and
/// replaces pure logical expressions with the form proposed by
/// `synthesize_readable_pure_logical_expr`.
///
/// Returns `true` when at least one expression in the module was replaced.
pub(super) fn apply(module: &mut AstModule, context: ReadabilityContext) -> bool {
    // The pass does not consult the context. This no-op read appears to exist
    // only to keep the parameter from being reported as unused.
    let _ = context.target;
    let body = &mut module.body;
    rewrite_block(body)
}
/// Rewrites every statement of `block` in order and reports whether any of
/// them changed.
fn rewrite_block(block: &mut AstBlock) -> bool {
    // `rewrite_stmt(..) | acc` (not `||`/`any`) so that every statement is
    // visited even after an earlier one has already reported a change.
    block
        .stmts
        .iter_mut()
        .fold(false, |acc, stmt| rewrite_stmt(stmt) | acc)
}
/// Rewrites every expression reachable from `stmt`, recursing into nested
/// blocks, call statements, assignment targets and function bodies.
///
/// Returns `true` if anything underneath `stmt` was replaced. Every child is
/// visited even when an earlier one already reported a change.
fn rewrite_stmt(stmt: &mut AstStmt) -> bool {
    match stmt {
        AstStmt::If(if_stmt) => {
            let cond_changed = rewrite_expr(&mut if_stmt.cond);
            let then_changed = rewrite_block(&mut if_stmt.then_block);
            let else_changed = if_stmt
                .else_block
                .as_mut()
                .is_some_and(|else_block| rewrite_block(else_block));
            cond_changed | then_changed | else_changed
        }
        AstStmt::While(while_stmt) => {
            let cond_changed = rewrite_expr(&mut while_stmt.cond);
            let body_changed = rewrite_block(&mut while_stmt.body);
            cond_changed | body_changed
        }
        AstStmt::Repeat(repeat_stmt) => {
            let body_changed = rewrite_block(&mut repeat_stmt.body);
            let cond_changed = rewrite_expr(&mut repeat_stmt.cond);
            body_changed | cond_changed
        }
        // The array literal evaluates all four rewrites before the check.
        AstStmt::NumericFor(numeric_for) => [
            rewrite_expr(&mut numeric_for.start),
            rewrite_expr(&mut numeric_for.limit),
            rewrite_expr(&mut numeric_for.step),
            rewrite_block(&mut numeric_for.body),
        ]
        .contains(&true),
        AstStmt::GenericFor(generic_for) => {
            let iterator_changed = generic_for
                .iterator
                .iter_mut()
                .fold(false, |acc, expr| rewrite_expr(expr) | acc);
            iterator_changed | rewrite_block(&mut generic_for.body)
        }
        AstStmt::DoBlock(block) => rewrite_block(block),
        AstStmt::FunctionDecl(function_decl) => rewrite_function(&mut function_decl.func),
        AstStmt::LocalFunctionDecl(local_function_decl) => {
            rewrite_function(&mut local_function_decl.func)
        }
        AstStmt::LocalDecl(local_decl) => local_decl
            .values
            .iter_mut()
            .fold(false, |acc, value| rewrite_expr(value) | acc),
        AstStmt::GlobalDecl(global_decl) => global_decl
            .values
            .iter_mut()
            .fold(false, |acc, value| rewrite_expr(value) | acc),
        AstStmt::Assign(assign) => {
            let targets_changed = assign
                .targets
                .iter_mut()
                .fold(false, |acc, target| rewrite_lvalue(target) | acc);
            let values_changed = assign
                .values
                .iter_mut()
                .fold(false, |acc, value| rewrite_expr(value) | acc);
            targets_changed | values_changed
        }
        AstStmt::CallStmt(call_stmt) => rewrite_call(&mut call_stmt.call),
        AstStmt::Return(ret) => ret
            .values
            .iter_mut()
            .fold(false, |acc, value| rewrite_expr(value) | acc),
        // These statements carry no expressions to visit.
        AstStmt::Break
        | AstStmt::Continue
        | AstStmt::Goto(_)
        | AstStmt::Label(_)
        | AstStmt::Error(_) => false,
    }
}
fn rewrite_function(function: &mut AstFunctionExpr) -> bool {
rewrite_block(&mut function.body)
}
/// Rewrites the callee (or method receiver) and every argument of a call.
///
/// Returns `true` if any of those expressions was replaced. All arguments are
/// visited regardless of earlier results.
fn rewrite_call(call: &mut AstCallKind) -> bool {
    match call {
        AstCallKind::Call(plain) => {
            let callee_changed = rewrite_expr(&mut plain.callee);
            plain
                .args
                .iter_mut()
                .fold(callee_changed, |acc, arg| rewrite_expr(arg) | acc)
        }
        AstCallKind::MethodCall(method) => {
            let receiver_changed = rewrite_expr(&mut method.receiver);
            method
                .args
                .iter_mut()
                .fold(receiver_changed, |acc, arg| rewrite_expr(arg) | acc)
        }
    }
}
/// Rewrites the expressions embedded in an assignment target (the base of a
/// field access, or the base and key of an index access).
fn rewrite_lvalue(target: &mut AstLValue) -> bool {
    match target {
        AstLValue::FieldAccess(field) => rewrite_expr(&mut field.base),
        AstLValue::IndexAccess(indexed) => {
            let base_changed = rewrite_expr(&mut indexed.base);
            let index_changed = rewrite_expr(&mut indexed.index);
            base_changed | index_changed
        }
        // A bare name has no sub-expressions.
        AstLValue::Name(_) => false,
    }
}
fn rewrite_expr(expr: &mut AstExpr) -> bool {
let mut changed = match expr {
AstExpr::FieldAccess(access) => rewrite_expr(&mut access.base),
AstExpr::IndexAccess(access) => {
rewrite_expr(&mut access.base) | rewrite_expr(&mut access.index)
}
AstExpr::Unary(unary) => rewrite_expr(&mut unary.expr),
AstExpr::Binary(binary) => rewrite_expr(&mut binary.lhs) | rewrite_expr(&mut binary.rhs),
AstExpr::LogicalAnd(logical) | AstExpr::LogicalOr(logical) => {
rewrite_expr(&mut logical.lhs) | rewrite_expr(&mut logical.rhs)
}
AstExpr::Call(call) => {
let mut changed = rewrite_expr(&mut call.callee);
for arg in &mut call.args {
changed |= rewrite_expr(arg);
}
changed
}
AstExpr::MethodCall(call) => {
let mut changed = rewrite_expr(&mut call.receiver);
for arg in &mut call.args {
changed |= rewrite_expr(arg);
}
changed
}
AstExpr::SingleValue(expr) => rewrite_expr(expr),
AstExpr::TableConstructor(table) => {
let mut changed = false;
for field in &mut table.fields {
match field {
AstTableField::Array(value) => changed |= rewrite_expr(value),
AstTableField::Record(record) => {
if let AstTableKey::Expr(key) = &mut record.key {
changed |= rewrite_expr(key);
}
changed |= rewrite_expr(&mut record.value);
}
}
}
changed
}
AstExpr::FunctionExpr(function) => rewrite_function(function),
AstExpr::Nil
| AstExpr::Boolean(_)
| AstExpr::Integer(_)
| AstExpr::Number(_)
| AstExpr::String(_)
| AstExpr::Int64(_)
| AstExpr::UInt64(_)
| AstExpr::Complex { .. }
| AstExpr::Var(_)
| AstExpr::VarArg | AstExpr::Error(_) => false,
};
if let Some(hir_expr) = hir_from_ast_expr(expr)
&& let Some(pretty_hir) = synthesize_readable_pure_logical_expr(&hir_expr)
&& pretty_hir != hir_expr
&& let Some(pretty_ast) = ast_from_hir_expr(&pretty_hir)
{
*expr = pretty_ast;
changed = true;
}
changed
}
/// Lowers the subset of AST expressions handled by this pass into HIR.
///
/// Supported forms:
/// - literals;
/// - parameter, local, temp and upvalue references;
/// - `not`, `==`, `and` and `or`;
/// - `SingleValue` wrappers, which are dropped and lower to their inner
///   expression.
///
/// Dropping `SingleValue` is presumably safe because calls and `...` never
/// convert, so the operand already yields one value.
///
/// Anything else makes the whole conversion return `None`. That includes
/// globals, synthetic locals, other unary/binary operators, field/index
/// access, calls, varargs, table constructors, function expressions and error
/// nodes.
fn hir_from_ast_expr(expr: &AstExpr) -> Option<HirExpr> {
    // Lowers both operands of a two-operand node, failing if either one fails.
    fn lower_operands(lhs: &AstExpr, rhs: &AstExpr) -> Option<(HirExpr, HirExpr)> {
        Some((hir_from_ast_expr(lhs)?, hir_from_ast_expr(rhs)?))
    }

    let lowered = match expr {
        AstExpr::Nil => HirExpr::Nil,
        AstExpr::Boolean(value) => HirExpr::Boolean(*value),
        AstExpr::Integer(value) => HirExpr::Integer(*value),
        AstExpr::Number(value) => HirExpr::Number(*value),
        AstExpr::String(value) => HirExpr::String(value.clone()),
        AstExpr::Int64(value) => HirExpr::Int64(*value),
        AstExpr::UInt64(value) => HirExpr::UInt64(*value),
        AstExpr::Complex { real, imag } => HirExpr::Complex {
            real: *real,
            imag: *imag,
        },
        AstExpr::Var(name) => match name {
            super::super::common::AstNameRef::Param(param) => HirExpr::ParamRef(*param),
            super::super::common::AstNameRef::Local(local) => HirExpr::LocalRef(*local),
            super::super::common::AstNameRef::Temp(temp) => HirExpr::TempRef(*temp),
            super::super::common::AstNameRef::Upvalue(upvalue) => {
                HirExpr::UpvalueRef(*upvalue)
            }
            super::super::common::AstNameRef::SyntheticLocal(_)
            | super::super::common::AstNameRef::Global(_) => return None,
        },
        AstExpr::Unary(unary) => {
            if unary.op != AstUnaryOpKind::Not {
                return None;
            }
            HirExpr::Unary(Box::new(HirUnaryExpr {
                op: HirUnaryOpKind::Not,
                expr: hir_from_ast_expr(&unary.expr)?,
            }))
        }
        AstExpr::Binary(binary) => {
            if binary.op != AstBinaryOpKind::Eq {
                return None;
            }
            let (lhs, rhs) = lower_operands(&binary.lhs, &binary.rhs)?;
            HirExpr::Binary(Box::new(HirBinaryExpr {
                op: HirBinaryOpKind::Eq,
                lhs,
                rhs,
            }))
        }
        AstExpr::LogicalAnd(logical) => {
            let (lhs, rhs) = lower_operands(&logical.lhs, &logical.rhs)?;
            HirExpr::LogicalAnd(Box::new(HirLogicalExpr { lhs, rhs }))
        }
        AstExpr::LogicalOr(logical) => {
            let (lhs, rhs) = lower_operands(&logical.lhs, &logical.rhs)?;
            HirExpr::LogicalOr(Box::new(HirLogicalExpr { lhs, rhs }))
        }
        AstExpr::SingleValue(inner) => return hir_from_ast_expr(inner),
        AstExpr::FieldAccess(_)
        | AstExpr::IndexAccess(_)
        | AstExpr::Call(_)
        | AstExpr::MethodCall(_)
        | AstExpr::VarArg
        | AstExpr::TableConstructor(_)
        | AstExpr::FunctionExpr(_)
        | AstExpr::Error(_) => return None,
    };
    Some(lowered)
}
fn ast_from_hir_expr(expr: &HirExpr) -> Option<AstExpr> {
match expr {
HirExpr::Nil => Some(AstExpr::Nil),
HirExpr::Boolean(value) => Some(AstExpr::Boolean(*value)),
HirExpr::Integer(value) => Some(AstExpr::Integer(*value)),
HirExpr::Number(value) => Some(AstExpr::Number(*value)),
HirExpr::String(value) => Some(AstExpr::String(value.clone())),
HirExpr::Int64(value) => Some(AstExpr::Int64(*value)),
HirExpr::UInt64(value) => Some(AstExpr::UInt64(*value)),
HirExpr::Complex { real, imag } => Some(AstExpr::Complex {
real: *real,
imag: *imag,
}),
HirExpr::ParamRef(param) => Some(AstExpr::Var(super::super::common::AstNameRef::Param(
*param,
))),
HirExpr::LocalRef(local) => Some(AstExpr::Var(super::super::common::AstNameRef::Local(
*local,
))),
HirExpr::TempRef(temp) => Some(AstExpr::Var(super::super::common::AstNameRef::Temp(*temp))),
HirExpr::UpvalueRef(upvalue) => Some(AstExpr::Var(
super::super::common::AstNameRef::Upvalue(*upvalue),
)),
HirExpr::Unary(unary) if unary.op == HirUnaryOpKind::Not => {
Some(AstExpr::Unary(Box::new(AstUnaryExpr {
op: AstUnaryOpKind::Not,
expr: ast_from_hir_expr(&unary.expr)?,
})))
}
HirExpr::Binary(binary) if binary.op == HirBinaryOpKind::Eq => {
Some(AstExpr::Binary(Box::new(AstBinaryExpr {
op: AstBinaryOpKind::Eq,
lhs: ast_from_hir_expr(&binary.lhs)?,
rhs: ast_from_hir_expr(&binary.rhs)?,
})))
}
HirExpr::LogicalAnd(logical) => Some(AstExpr::LogicalAnd(Box::new(AstLogicalExpr {
lhs: ast_from_hir_expr(&logical.lhs)?,
rhs: ast_from_hir_expr(&logical.rhs)?,
}))),
HirExpr::LogicalOr(logical) => Some(AstExpr::LogicalOr(Box::new(AstLogicalExpr {
lhs: ast_from_hir_expr(&logical.lhs)?,
rhs: ast_from_hir_expr(&logical.rhs)?,
}))),
HirExpr::Decision(_)
| HirExpr::GlobalRef(_)
| HirExpr::TableAccess(_)
| HirExpr::Unary(_)
| HirExpr::Binary(_)
| HirExpr::Call(_)
| HirExpr::VarArg
| HirExpr::TableConstructor(_)
| HirExpr::Closure(_)
| HirExpr::Unresolved(_) => None,
}
}
#[cfg(test)]
mod tests;