unluac 1.1.1

Multi-dialect Lua decompiler written in Rust.
Documentation
//! 这个文件承载 HIR 的保守逻辑表达式整理。
//!
//! Lua 的 `and/or` 返回的是原始操作数,不是布尔值,所以很多看似显然的布尔代数
//! 恒等式其实并不安全。这里故意只实现一小撮在 Lua 值语义下也严格成立的规则,
//! 用来压掉短路 DAG 恢复后最机械的重复,而不越权重写控制流结构。
//!
//! 它依赖前面的 short-circuit / decision 恢复已经把候选逻辑表达式保守落成 HIR,
//! 这里仅做“值语义严格不变”的局部整理,不重新分析 CFG,也不替前层兜底修坏掉的
//! 短路结构。
//!
//! 例子:
//! - `x and x` 会折成 `x`
//! - `(a and b) or (a and c)` 会整理成 `a and (b or c)`,但前提是共享 guard `a`
//!   没有副作用
//! - `x or x` 会折成 `x`
//! - 它不会把一般 `if/branch` 结构强行改写成逻辑表达式,那仍然属于更前面的结构恢复职责

use super::expr_facts::{
    expr_is_side_effect_free, expr_truthiness, fold_associative_duplicate_and,
    fold_associative_duplicate_or,
};
use super::walk::{ExprRewritePass, rewrite_proto_exprs};
use crate::hir::common::{HirExpr, HirLogicalExpr, HirProto};

#[cfg(test)]
use crate::hir::common::HirBlock;

#[cfg(test)]
use crate::hir::common::HirStmt;

/// Runs the conservative logical-expression cleanup over `proto`.
///
/// The traversal itself is delegated to `rewrite_proto_exprs`; this function
/// only supplies the pass and forwards the traversal's "changed" flag.
pub(super) fn simplify_logical_exprs_in_proto(proto: &mut HirProto) -> bool {
    let mut pass = LogicalExprPass;
    rewrite_proto_exprs(proto, &mut pass)
}

/// Stateless expression-rewrite pass that applies this module's logical
/// simplifications.
struct LogicalExprPass;

impl ExprRewritePass for LogicalExprPass {
    /// Tries to simplify `expr` in place and reports whether it was replaced.
    ///
    /// The structural rules run first; `naturalize_pure_logical_expr` then
    /// runs on the (possibly already rewritten) expression.
    fn rewrite_expr(&mut self, expr: &mut HirExpr) -> bool {
        let reshaped = match simplify_logical_shape(expr) {
            Some(replacement) => {
                *expr = replacement;
                true
            }
            None => false,
        };

        let naturalized = match super::decision::naturalize_pure_logical_expr(expr) {
            Some(replacement) => {
                *expr = replacement;
                true
            }
            None => false,
        };

        reshaped || naturalized
    }
}

/// Dispatches `and` / `or` nodes to their dedicated simplifier.
///
/// Returns `None` for every other expression kind, and also when no rule
/// applies to the logical node.
fn simplify_logical_shape(expr: &HirExpr) -> Option<HirExpr> {
    if let HirExpr::LogicalAnd(node) = expr {
        return simplify_logical_and(&node.lhs, &node.rhs);
    }
    if let HirExpr::LogicalOr(node) = expr {
        return simplify_logical_or(&node.lhs, &node.rhs);
    }
    None
}

/// Tries to simplify `lhs and rhs`, returning the replacement when a rule fires.
///
/// Rules, in the order they are tried:
/// 1. `x and x` -> `x`
/// 2. whatever `fold_associative_duplicate_and` produces
/// 3. whatever `fold_constant_short_circuit_and` produces
/// 4. absorption:
///    - `x and (x or y)` -> `x`
///    - `(x or y) and y` -> `y`
///
/// `(x or y) and x` is deliberately NOT folded to `x`. Lua's `and`/`or` return
/// operands, not booleans, so with `x = nil, y = false` the original evaluates
/// to `false and nil == false` while the folded form would yield `nil`.
///
/// NOTE(review): the structural folds here compare operands with `==` and do
/// not consult `expr_is_side_effect_free`; confirm that structurally equal
/// operands reaching this pass cannot carry side effects.
fn simplify_logical_and(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    if lhs == rhs {
        return Some(lhs.clone());
    }

    let folded = fold_associative_duplicate_and(lhs, rhs)
        .or_else(|| fold_constant_short_circuit_and(lhs, rhs));
    if folded.is_some() {
        return folded;
    }

    match (lhs, rhs) {
        // `x and (x or y)`: falsy `x` short-circuits to `x`; truthy `x` makes
        // the inner `or` return `x` as well.
        (_, HirExpr::LogicalOr(inner)) if lhs == &inner.lhs => Some(lhs.clone()),
        // `(x or y) and y`: truthy `x` gives `x and y == y`; falsy `x` gives
        // `y and y == y`.
        (HirExpr::LogicalOr(inner), _) if rhs == &inner.rhs => Some(rhs.clone()),
        _ => None,
    }
}

/// Tries to simplify `lhs or rhs`, returning the replacement when a rule fires.
///
/// Rules, in the order they are tried:
/// 1. `x or x` -> `x`
/// 2. `fold_associative_duplicate_or`
/// 3. `fold_constant_short_circuit_or`
/// 4. `factor_shared_and_guards`
/// 5. `pull_shared_or_tail`
/// 6. `simplify_or_chain`
/// 7. `fold_shared_fallback_or`
/// 8. absorption:
///    - `x or (x and y)` -> `x`
///    - `(x and y) or y` -> `y`
///
/// `(x and y) or x` is deliberately NOT folded to `x`. Lua's `and`/`or` return
/// operands, not booleans, so with `x = 1, y = 2` the original evaluates to
/// `2 or 1 == 2` while the folded form would yield `1`.
///
/// NOTE(review): the structural folds in rules 1 and 8 compare operands with
/// `==` and do not consult `expr_is_side_effect_free`; confirm that
/// structurally equal operands reaching this pass cannot carry side effects.
fn simplify_logical_or(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    if lhs == rhs {
        return Some(lhs.clone());
    }

    // Each helper is only consulted when every earlier one declined.
    let folded = fold_associative_duplicate_or(lhs, rhs)
        .or_else(|| fold_constant_short_circuit_or(lhs, rhs))
        .or_else(|| factor_shared_and_guards(lhs, rhs))
        .or_else(|| pull_shared_or_tail(lhs, rhs))
        .or_else(|| simplify_or_chain(lhs, rhs))
        .or_else(|| fold_shared_fallback_or(lhs, rhs));
    if folded.is_some() {
        return folded;
    }

    match (lhs, rhs) {
        // `x or (x and y)`: truthy `x` short-circuits to `x`; falsy `x` makes
        // the inner `and` return `x` as well.
        (_, HirExpr::LogicalAnd(inner)) if lhs == &inner.lhs => Some(lhs.clone()),
        // `(x and y) or y`: falsy `x` gives `x or y == y`; truthy `x` gives
        // `y or y == y`.
        (HirExpr::LogicalAnd(inner), _) if rhs == &inner.rhs => Some(rhs.clone()),
        _ => None,
    }
}

/// Factors a shared operand out of `lhs or rhs` when both sides are `and`
/// nodes, trying the operands in the given order first and swapped second.
///
/// NOTE(review): the conditions checked by
/// `factor_shared_and_guards_one_side` look symmetric in its two arguments, so
/// the swapped retry is probably never reached with a different outcome —
/// confirm before relying on it.
fn factor_shared_and_guards(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    match factor_shared_and_guards_one_side(lhs, rhs) {
        Some(factored) => Some(factored),
        None => factor_shared_and_guards_one_side(rhs, lhs),
    }
}

fn factor_shared_and_guards_one_side(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    let HirExpr::LogicalAnd(lhs_and) = lhs else {
        return None;
    };
    let HirExpr::LogicalAnd(rhs_and) = rhs else {
        return None;
    };

    if lhs_and.lhs == rhs_and.lhs && expr_is_side_effect_free(&lhs_and.lhs) {
        return Some(HirExpr::LogicalAnd(Box::new(HirLogicalExpr {
            lhs: lhs_and.lhs.clone(),
            rhs: HirExpr::LogicalOr(Box::new(HirLogicalExpr {
                lhs: lhs_and.rhs.clone(),
                rhs: rhs_and.rhs.clone(),
            })),
        })));
    }

    if lhs_and.rhs == rhs_and.rhs && expr_is_side_effect_free(&lhs_and.rhs) {
        return Some(HirExpr::LogicalAnd(Box::new(HirLogicalExpr {
            lhs: HirExpr::LogicalOr(Box::new(HirLogicalExpr {
                lhs: lhs_and.lhs.clone(),
                rhs: rhs_and.lhs.clone(),
            })),
            rhs: lhs_and.rhs.clone(),
        })));
    }

    None
}

/// Rewrites `(a and (b or c)) or c` into `(a and b) or c`.
///
/// Only the operand order as written is matched. The mirrored shape
/// `c or (a and (b or c))` must not be rewritten to `(a and b) or c`: Lua's
/// `or` returns its first truthy operand, so with `c = 1, a = true, b = 2` the
/// original yields `1` while the rewritten form would yield `2`.
fn pull_shared_or_tail(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    pull_shared_or_tail_one_side(lhs, rhs)
}

/// Matches `lhs = a and (b or c)` together with `rhs = c` and produces
/// `(a and b) or c`.
///
/// Returns `None` unless `lhs` has that exact shape, `rhs` equals the inner
/// `or`'s right operand, and `rhs` is side-effect free.
fn pull_shared_or_tail_one_side(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    let HirExpr::LogicalAnd(guarded) = lhs else {
        return None;
    };
    let HirExpr::LogicalOr(alternatives) = &guarded.rhs else {
        return None;
    };

    let tail_is_shared = rhs == &alternatives.rhs;
    if !(tail_is_shared && expr_is_side_effect_free(rhs)) {
        return None;
    }

    // The inner fallback to `c` is dropped; the outer `or c` still supplies it.
    let narrowed = HirExpr::LogicalAnd(Box::new(HirLogicalExpr {
        lhs: guarded.lhs.clone(),
        rhs: alternatives.lhs.clone(),
    }));
    Some(HirExpr::LogicalOr(Box::new(HirLogicalExpr {
        lhs: narrowed,
        rhs: rhs.clone(),
    })))
}

/// Looks for a cheaper form of a flattened `or` chain with at least three terms.
///
/// `lhs or rhs` is flattened into its `or` operands. Each pair of ADJACENT
/// terms, in their original order, is offered to
/// `factor_shared_and_guards_one_side` and then to
/// `pull_shared_or_tail_one_side`. A successful merge replaces the pair in
/// place, the chain is rebuilt, and the cheapest rebuilt chain (by
/// `expr_cost`) is returned if it is strictly cheaper than `lhs or rhs`. Ties
/// keep the earliest candidate.
///
/// Only adjacent, in-order pairs are merged because `or` returns its first
/// truthy operand, so operand order is observable. Merging non-adjacent terms
/// would move the later term ahead of everything in between: rewriting
/// `(a and b) or d or (a and c)` to `(a and (b or c)) or d` changes the result
/// from `1` to `2` for `a = true, b = false, d = 1, c = 2`. Regrouping
/// neighbours is safe because `or` is associative.
fn simplify_or_chain(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    let terms = flatten_or_chain_exprs(lhs, rhs);
    if terms.len() < 3 {
        return None;
    }

    // Cost of the expression as it currently stands; loop-invariant.
    let baseline_cost = expr_cost(&HirExpr::LogicalOr(Box::new(HirLogicalExpr {
        lhs: lhs.clone(),
        rhs: rhs.clone(),
    })));

    let mut best: Option<(usize, HirExpr)> = None;
    for (left, pair) in terms.windows(2).enumerate() {
        let Some(merged) = factor_shared_and_guards_one_side(&pair[0], &pair[1])
            .or_else(|| pull_shared_or_tail_one_side(&pair[0], &pair[1]))
        else {
            continue;
        };

        // Replace terms[left] and terms[left + 1] with the merged expression,
        // keeping every other term in its original position.
        let rebuilt: Vec<HirExpr> = terms[..left]
            .iter()
            .cloned()
            .chain(std::iter::once(merged))
            .chain(terms[left + 2..].iter().cloned())
            .collect();
        let candidate = rebuild_or_chain(rebuilt);
        let candidate_cost = expr_cost(&candidate);

        let beats_best = best
            .as_ref()
            .map_or(true, |(best_cost, _)| candidate_cost < *best_cost);
        if candidate_cost < baseline_cost && beats_best {
            best = Some((candidate_cost, candidate));
        }
    }

    best.map(|(_, expr)| expr)
}

/// Collects the `or` operands of `lhs or rhs` in left-to-right order, looking
/// through nested `or` nodes on both sides.
fn flatten_or_chain_exprs(lhs: &HirExpr, rhs: &HirExpr) -> Vec<HirExpr> {
    let mut terms = Vec::new();
    for side in [lhs, rhs] {
        collect_or_chain_exprs(side, &mut terms);
    }
    terms
}

/// Appends the `or` operands of `expr` to `out`, left to right.
///
/// An `or` node contributes the operands of its children recursively; any
/// other expression is cloned and appended as a single term.
fn collect_or_chain_exprs(expr: &HirExpr, out: &mut Vec<HirExpr>) {
    if let HirExpr::LogicalOr(node) = expr {
        collect_or_chain_exprs(&node.lhs, out);
        collect_or_chain_exprs(&node.rhs, out);
    } else {
        out.push(expr.clone());
    }
}

/// Folds `terms` into a left-associated `or` chain: `((t0 or t1) or t2) ...`.
///
/// A single term is returned unchanged.
///
/// # Panics
///
/// Panics if `terms` is empty. The previous `drain(..1)` form panicked with a
/// slice-range error before its `expect` message could be shown; `reduce`
/// makes the intended message the one that is reported.
fn rebuild_or_chain(terms: Vec<HirExpr>) -> HirExpr {
    terms
        .into_iter()
        .reduce(|lhs, rhs| HirExpr::LogicalOr(Box::new(HirLogicalExpr { lhs, rhs })))
        .expect("or chain rebuild requires at least one term")
}

/// Size metric used to compare candidate rewrites.
///
/// Every node counts as 1. Logical, unary and binary nodes additionally add
/// the cost of their operands; all other expression kinds are treated as
/// leaves.
fn expr_cost(expr: &HirExpr) -> usize {
    let operands_cost = match expr {
        HirExpr::LogicalAnd(node) | HirExpr::LogicalOr(node) => {
            expr_cost(&node.lhs) + expr_cost(&node.rhs)
        }
        HirExpr::Unary(node) => expr_cost(&node.expr),
        HirExpr::Binary(node) => expr_cost(&node.lhs) + expr_cost(&node.rhs),
        _ => 0,
    };
    1 + operands_cost
}

/// Folds `lhs and rhs` when `expr_truthiness` knows the truthiness of `lhs`.
///
/// * `lhs` known truthy: the result is `rhs`.
/// * `lhs` known falsy: the result is `lhs`, but only if `rhs` is side-effect
///   free; otherwise the expression is left alone.
/// * truthiness unknown: no rewrite.
fn fold_constant_short_circuit_and(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    match expr_truthiness(lhs) {
        Some(true) => Some(rhs.clone()),
        Some(false) => expr_is_side_effect_free(rhs).then(|| lhs.clone()),
        None => None,
    }
}

/// Dual of `fold_constant_short_circuit_and` for `lhs or rhs`.
///
/// * `lhs` known truthy: the result is `lhs`, but only if `rhs` is side-effect
///   free; otherwise the expression is left alone.
/// * `lhs` known falsy: the result is `rhs`.
/// * truthiness unknown: no rewrite.
fn fold_constant_short_circuit_or(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    match expr_truthiness(lhs) {
        Some(true) => expr_is_side_effect_free(rhs).then(|| lhs.clone()),
        Some(false) => Some(rhs.clone()),
        None => None,
    }
}

/// Collapses a mechanically duplicated fallback:
///
/// `((not x) and y) or (x or y)` becomes `x or y`, as long as `y` is
/// side-effect free. The two operands of the outer `or` are accepted in either
/// order; in both cases the result is the `x or y` operand.
fn fold_shared_fallback_or(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    match shared_fallback_or_one_side(lhs, rhs) {
        Some(folded) => Some(folded),
        None => shared_fallback_or_one_side(rhs, lhs),
    }
}

/// Matches `lhs = (not x) and y` against `rhs = x or y` and returns a clone of
/// `rhs` when both `x` and `y` line up and `y` is side-effect free.
fn shared_fallback_or_one_side(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    let (HirExpr::LogicalAnd(guarded), HirExpr::LogicalOr(fallback)) = (lhs, rhs) else {
        return None;
    };

    // The guard must be a negation; its operand is the `x` we compare against.
    let negated_operand = strip_negation(&guarded.lhs)?;

    let shapes_match = negated_operand == fallback.lhs && guarded.rhs == fallback.rhs;
    if shapes_match && expr_is_side_effect_free(&guarded.rhs) {
        Some(rhs.clone())
    } else {
        None
    }
}

/// Returns a clone of the operand of a `not` expression, or `None` when `expr`
/// is not a unary `not`.
fn strip_negation(expr: &HirExpr) -> Option<HirExpr> {
    let HirExpr::Unary(unary) = expr else {
        return None;
    };

    if matches!(unary.op, crate::hir::common::HirUnaryOpKind::Not) {
        Some(unary.expr.clone())
    } else {
        None
    }
}

#[cfg(test)]
mod tests;