use crate::ir::{Expr, MethodCall, Statement};
use super::collect_expr_var_names;
/// Runs the i128 clean-up pipeline over a list of statements.
///
/// Every statement first has its expressions algebraically simplified
/// (`simplify_stmt_exprs`), then deep i128-looking arithmetic is replaced by a
/// placeholder (`collapse_deep_i128_stmt`). Once all statements have been
/// rewritten, overflow-guard `if`s are stripped by `eliminate_overflow_guards`.
pub fn collapse_i128_patterns(stmts: Vec<Statement>) -> Vec<Statement> {
    let mut rewritten = Vec::with_capacity(stmts.len());
    for stmt in stmts {
        // Both rewrites are per-statement, so they can be applied back to back.
        rewritten.push(collapse_deep_i128_stmt(simplify_stmt_exprs(stmt)));
    }
    eliminate_overflow_guards(rewritten)
}
/// Removes `if` statements that only guard against overflow.
///
/// An `if` whose condition is recognised by `is_overflow_check` and whose
/// `else` branch is empty is replaced by its (recursively processed) `then`
/// body, spliced directly into the surrounding statement list. Every other
/// `if`, as well as `while`, `loop`, `for-each` and `for-range` statements,
/// is kept and has its nested bodies processed recursively. All remaining
/// statement kinds are passed through untouched.
///
/// The statements are consumed by value, so nested bodies are moved rather
/// than cloned.
fn eliminate_overflow_guards(stmts: Vec<Statement>) -> Vec<Statement> {
    // Inlined guards may add more statements than this, but it is a good lower bound.
    let mut result = Vec::with_capacity(stmts.len());
    for stmt in stmts {
        match stmt {
            Statement::If { condition, then_body, else_body } => {
                if is_overflow_check(&condition) && else_body.is_empty() {
                    // Drop the guard itself and hoist the guarded statements.
                    result.extend(eliminate_overflow_guards(then_body));
                } else {
                    result.push(Statement::If {
                        condition,
                        then_body: eliminate_overflow_guards(then_body),
                        else_body: eliminate_overflow_guards(else_body),
                    });
                }
            }
            Statement::While { condition, body } => {
                result.push(Statement::While {
                    condition,
                    body: eliminate_overflow_guards(body),
                });
            }
            Statement::Loop { body } => {
                result.push(Statement::Loop {
                    body: eliminate_overflow_guards(body),
                });
            }
            Statement::ForEach { var_name, collection, body } => {
                result.push(Statement::ForEach {
                    var_name,
                    collection,
                    body: eliminate_overflow_guards(body),
                });
            }
            Statement::ForRange { var_name, bound, body } => {
                result.push(Statement::ForRange {
                    var_name,
                    bound,
                    body: eliminate_overflow_guards(body),
                });
            }
            other => result.push(other),
        }
    }
    result
}
/// Returns `true` when `expr` has the shape of an overflow check.
///
/// Only comparisons of the form `<lhs> >= 0` are considered. Two shapes of
/// `<lhs>` are accepted:
///
/// * it contains at least two `>> 63` shifts and at least one XOR, or
/// * it is a bitwise AND whose two operands each contain at least one
///   `>> 63` shift.
fn is_overflow_check(expr: &Expr) -> bool {
    let (lhs, rhs): (&Expr, &Expr) = match expr {
        Expr::BinOp { left, op: crate::ir::BinOp::Ge, right } => (left.as_ref(), right.as_ref()),
        _ => return false,
    };
    if !is_literal_zero(rhs) {
        return false;
    }

    // Shape 1: XOR-based sign comparison.
    if count_shr63(lhs) >= 2 && has_xor(lhs) {
        return true;
    }

    // Shape 2: AND of two sign extractions.
    matches!(
        lhs,
        Expr::BinOp { op: crate::ir::BinOp::BitAnd, left: and_l, right: and_r }
            if count_shr63(and_l) >= 1 && count_shr63(and_r) >= 1
    )
}
/// Counts the `<x> >> 63` nodes inside `expr`.
///
/// The search descends through binary and unary operators only; any other
/// expression kind contributes zero. A shift is counted when its right operand
/// is exactly the literal `I64(63)`.
///
/// NOTE(review): an `I32(63)` shift amount is not counted here, although
/// `match_sign_extension` accepts both widths — confirm whether that is intended.
fn count_shr63(expr: &Expr) -> usize {
    match expr {
        Expr::BinOp { left, op, right } => {
            let nested = count_shr63(left) + count_shr63(right);
            let is_shr63 = matches!(op, crate::ir::BinOp::Shr)
                && matches!(right.as_ref(), Expr::Literal(crate::ir::Literal::I64(63)));
            nested + usize::from(is_shr63)
        }
        Expr::UnOp { operand, .. } => count_shr63(operand),
        _ => 0,
    }
}
/// Reports whether `expr` contains a bitwise XOR.
///
/// Only binary and unary operator nodes are traversed; every other expression
/// kind is treated as not containing one.
fn has_xor(expr: &Expr) -> bool {
    match expr {
        Expr::BinOp { left, op, right } => {
            matches!(op, crate::ir::BinOp::BitXor) || has_xor(left) || has_xor(right)
        }
        Expr::UnOp { operand, .. } => has_xor(operand),
        _ => false,
    }
}
/// Rebuilds `stmt` with every expression it holds passed through `simplify_expr`.
///
/// Nested statement bodies (`if` branches and all loop bodies) are processed
/// recursively. Statement structure is never changed, only the expressions.
fn simplify_stmt_exprs(stmt: Statement) -> Statement {
    /// Simplifies every statement of a nested body.
    fn simplify_block(block: Vec<Statement>) -> Vec<Statement> {
        block.into_iter().map(simplify_stmt_exprs).collect()
    }

    match stmt {
        Statement::Let { name, mutable, value } => {
            let value = simplify_expr(&value);
            Statement::Let { name, mutable, value }
        }
        Statement::Assign { target, value } => {
            let target = simplify_expr(&target);
            let value = simplify_expr(&value);
            Statement::Assign { target, value }
        }
        Statement::Expr(inner) => Statement::Expr(simplify_expr(&inner)),
        Statement::Return(result) => Statement::Return(result.map(|e| simplify_expr(&e))),
        Statement::If { condition, then_body, else_body } => {
            let condition = simplify_expr(&condition);
            Statement::If {
                condition,
                then_body: simplify_block(then_body),
                else_body: simplify_block(else_body),
            }
        }
        Statement::While { condition, body } => {
            let condition = simplify_expr(&condition);
            Statement::While { condition, body: simplify_block(body) }
        }
        Statement::Loop { body } => Statement::Loop { body: simplify_block(body) },
        Statement::ForEach { var_name, collection, body } => {
            let collection = simplify_expr(&collection);
            Statement::ForEach { var_name, collection, body: simplify_block(body) }
        }
        Statement::ForRange { var_name, bound, body } => {
            let bound = simplify_expr(&bound);
            Statement::ForRange { var_name, bound, body: simplify_block(body) }
        }
    }
}
/// Simplifies `expr` bottom-up.
///
/// All sub-expressions are simplified first and the node is rebuilt from the
/// results; the rebuilt node is then handed to `apply_algebraic_rules`, which
/// inspects it for an applicable rewrite. Literals, variables and raw
/// fragments have no children and are cloned before that final step.
fn simplify_expr(expr: &Expr) -> Expr {
    let rebuilt = match expr {
        // Leaves: nothing to recurse into.
        Expr::Literal(_) | Expr::Var(_) | Expr::Raw(_) => expr.clone(),
        Expr::Ref(target) => Expr::Ref(Box::new(simplify_expr(target))),
        Expr::UnOp { op, operand } => Expr::UnOp {
            op: *op,
            operand: Box::new(simplify_expr(operand)),
        },
        Expr::BinOp { left, op, right } => {
            let lhs = Box::new(simplify_expr(left));
            let rhs = Box::new(simplify_expr(right));
            Expr::BinOp { left: lhs, op: *op, right: rhs }
        }
        Expr::MethodChain { receiver, calls } => {
            let receiver = Box::new(simplify_expr(receiver));
            let calls = calls
                .iter()
                .map(|call| MethodCall {
                    name: call.name.clone(),
                    args: call.args.iter().map(|arg| simplify_expr(arg)).collect(),
                })
                .collect();
            Expr::MethodChain { receiver, calls }
        }
        Expr::HostCall { module, name, args } => Expr::HostCall {
            module: module.clone(),
            name: name.clone(),
            args: args.iter().map(|arg| simplify_expr(arg)).collect(),
        },
        Expr::MacroCall { name, args } => Expr::MacroCall {
            name: name.clone(),
            args: args.iter().map(|arg| simplify_expr(arg)).collect(),
        },
        Expr::StructLiteral { name, fields } => Expr::StructLiteral {
            name: name.clone(),
            fields: fields
                .iter()
                .map(|(field, value)| (field.clone(), simplify_expr(value)))
                .collect(),
        },
        Expr::EnumVariant { enum_name, variant_name, fields } => Expr::EnumVariant {
            enum_name: enum_name.clone(),
            variant_name: variant_name.clone(),
            fields: fields.iter().map(|value| simplify_expr(value)).collect(),
        },
    };
    apply_algebraic_rules(&rebuilt)
}
fn apply_algebraic_rules(expr: &Expr) -> Expr {
use crate::ir::BinOp as Op;
match expr {
Expr::BinOp { left, op, right } => {
let l = left.as_ref();
let r = right.as_ref();
if *op == Op::Sub && is_literal_zero(l) {
if let Expr::BinOp { left: inner_l, op: Op::Sub, right: inner_r } = r {
if is_literal_zero(inner_l) {
return apply_algebraic_rules(inner_r);
}
}
}
if *op == Op::Mul {
if is_literal_zero(l) || is_literal_zero(r) {
return Expr::Literal(crate::ir::Literal::I64(0));
}
}
if *op == Op::Add {
if is_literal_zero(r) {
return l.clone();
}
if is_literal_zero(l) {
return r.clone();
}
}
if *op == Op::Sub && is_literal_zero(r) {
return l.clone();
}
if *op == Op::BitOr {
if is_literal_zero(r) {
return l.clone();
}
if is_literal_zero(l) {
return r.clone();
}
}
if *op == Op::BitAnd {
if is_literal_zero(l) || is_literal_zero(r) {
return Expr::Literal(crate::ir::Literal::I64(0));
}
}
if *op == Op::Shl {
if is_literal_zero(l) {
return Expr::Literal(crate::ir::Literal::I64(0));
}
if is_literal_zero(r) {
return l.clone();
}
}
if *op == Op::Shr {
if is_literal_zero(l) {
return Expr::Literal(crate::ir::Literal::I64(0));
}
if is_literal_zero(r) {
return l.clone();
}
}
if *op == Op::Mul {
if is_literal_one(r) {
return l.clone();
}
if is_literal_one(l) {
return r.clone();
}
}
if *op == Op::Ne && is_literal_zero(l) && is_literal_zero(r) {
return Expr::Literal(crate::ir::Literal::I64(0));
}
if *op == Op::Lt && is_literal_zero(l) && is_literal_zero(r) {
return Expr::Literal(crate::ir::Literal::I64(0));
}
if *op == Op::BitOr {
if let Some(base) = match_sign_extension(l, r) {
return base;
}
if let Some(base) = match_sign_extension(r, l) {
return base;
}
}
if is_literal_zero(l) && is_literal_zero(r) {
match op {
Op::Eq => return bool_literal(true), Op::Le => return bool_literal(true), Op::Ge => return bool_literal(true), _ => {}
}
}
expr.clone()
}
Expr::UnOp { op: crate::ir::UnOp::Neg, operand } => {
if let Expr::UnOp { op: crate::ir::UnOp::Neg, operand: inner } = operand.as_ref() {
return inner.as_ref().clone();
}
expr.clone()
}
Expr::UnOp { op: crate::ir::UnOp::Not, operand } => {
if let Expr::UnOp { op: crate::ir::UnOp::Not, operand: inner } = operand.as_ref() {
return inner.as_ref().clone();
}
if let Expr::BinOp { left, op, right } = operand.as_ref() {
let flipped = match op {
Op::Ne => Some(Op::Eq),
Op::Eq => Some(Op::Ne),
Op::Lt => Some(Op::Ge),
Op::Ge => Some(Op::Lt),
Op::Gt => Some(Op::Le),
Op::Le => Some(Op::Gt),
_ => None,
};
if let Some(new_op) = flipped {
return Expr::BinOp {
left: left.clone(),
op: new_op,
right: right.clone(),
};
}
}
if is_literal_zero(operand) {
return Expr::Literal(crate::ir::Literal::I32(1));
}
if is_nonzero_literal(operand) {
return Expr::Literal(crate::ir::Literal::I32(0));
}
expr.clone()
}
_ => expr.clone(),
}
}
/// Returns `true` when `expr` is the integer literal zero, as either an `I32`
/// or an `I64` literal.
fn is_literal_zero(expr: &Expr) -> bool {
    match expr {
        Expr::Literal(crate::ir::Literal::I32(n)) => *n == 0,
        Expr::Literal(crate::ir::Literal::I64(n)) => *n == 0,
        _ => false,
    }
}
/// Returns `true` when `expr` is the integer literal one, as either an `I32`
/// or an `I64` literal.
fn is_literal_one(expr: &Expr) -> bool {
    match expr {
        Expr::Literal(crate::ir::Literal::I32(n)) => *n == 1,
        Expr::Literal(crate::ir::Literal::I64(n)) => *n == 1,
        _ => false,
    }
}
/// Returns `true` when `expr` is an `I32` or `I64` literal whose value is not
/// zero. Non-literal expressions and other literal kinds yield `false`.
fn is_nonzero_literal(expr: &Expr) -> bool {
    matches!(expr, Expr::Literal(crate::ir::Literal::I32(n)) if *n != 0)
        || matches!(expr, Expr::Literal(crate::ir::Literal::I64(n)) if *n != 0)
}
/// Encodes a boolean as an `I32` literal expression: `1` for `true`, `0` for
/// `false`.
fn bool_literal(val: bool) -> Expr {
    let encoded = match val {
        true => 1,
        false => 0,
    };
    Expr::Literal(crate::ir::Literal::I32(encoded))
}
/// Recognises `shift_half` as the sign-extension companion of `base`.
///
/// `shift_half` must have the shape `(base >> 63) << N` where `N` is the
/// literal 32 or 64 and the shift amounts may be `I32` or `I64` literals.
/// The shifted operand has to compare equal to `base`. On a match a clone of
/// `base` is returned, otherwise `None`.
fn match_sign_extension(shift_half: &Expr, base: &Expr) -> Option<Expr> {
    // Outer node: `<shifted> << <shl_amount>`.
    let (shifted, shl_amount): (&Expr, &Expr) = match shift_half {
        Expr::BinOp { left, op: crate::ir::BinOp::Shl, right } => (left.as_ref(), right.as_ref()),
        _ => return None,
    };
    let widening_shift = matches!(
        shl_amount,
        Expr::Literal(crate::ir::Literal::I64(32 | 64))
            | Expr::Literal(crate::ir::Literal::I32(32 | 64))
    );
    if !widening_shift {
        return None;
    }

    // Inner node: `<shr_operand> >> 63`.
    let (shr_operand, shr_amount): (&Expr, &Expr) = match shifted {
        Expr::BinOp { left, op: crate::ir::BinOp::Shr, right } => (left.as_ref(), right.as_ref()),
        _ => return None,
    };
    let extracts_sign = matches!(
        shr_amount,
        Expr::Literal(crate::ir::Literal::I64(63)) | Expr::Literal(crate::ir::Literal::I32(63))
    );

    if extracts_sign && *shr_operand == *base {
        Some(base.clone())
    } else {
        None
    }
}
/// Rebuilds `stmt` with every expression it holds passed through
/// `collapse_deep_i128_expr`.
///
/// Nested statement bodies (`if` branches and all loop bodies) are processed
/// recursively. Statement structure is never changed, only the expressions.
fn collapse_deep_i128_stmt(stmt: Statement) -> Statement {
    /// Collapses every statement of a nested body.
    fn collapse_block(block: Vec<Statement>) -> Vec<Statement> {
        block.into_iter().map(collapse_deep_i128_stmt).collect()
    }

    match stmt {
        Statement::Let { name, mutable, value } => {
            let value = collapse_deep_i128_expr(&value);
            Statement::Let { name, mutable, value }
        }
        Statement::Assign { target, value } => {
            let target = collapse_deep_i128_expr(&target);
            let value = collapse_deep_i128_expr(&value);
            Statement::Assign { target, value }
        }
        Statement::Expr(inner) => Statement::Expr(collapse_deep_i128_expr(&inner)),
        Statement::Return(result) => {
            Statement::Return(result.map(|e| collapse_deep_i128_expr(&e)))
        }
        Statement::If { condition, then_body, else_body } => {
            let condition = collapse_deep_i128_expr(&condition);
            Statement::If {
                condition,
                then_body: collapse_block(then_body),
                else_body: collapse_block(else_body),
            }
        }
        Statement::While { condition, body } => {
            let condition = collapse_deep_i128_expr(&condition);
            Statement::While { condition, body: collapse_block(body) }
        }
        Statement::Loop { body } => Statement::Loop { body: collapse_block(body) },
        Statement::ForEach { var_name, collection, body } => {
            let collection = collapse_deep_i128_expr(&collection);
            Statement::ForEach { var_name, collection, body: collapse_block(body) }
        }
        Statement::ForRange { var_name, bound, body } => {
            let bound = collapse_deep_i128_expr(&bound);
            Statement::ForRange { var_name, bound, body: collapse_block(body) }
        }
    }
}
/// Replaces deep i128-looking arithmetic inside `expr` with a placeholder.
///
/// A binary operation whose `expr_depth` exceeds 8 and for which
/// `has_i128_constants` holds is replaced wholesale by a raw
/// `/* i128 arithmetic ... */` comment. The comment lists the variable names
/// gathered by `collect_expr_var_names`, when there are any. Every other
/// expression is rebuilt with its children processed recursively; literals,
/// variables and raw fragments are cloned unchanged.
fn collapse_deep_i128_expr(expr: &Expr) -> Expr {
    match expr {
        Expr::BinOp { left, op, right } => {
            if expr_depth(expr) > 8 && has_i128_constants(expr) {
                // BTreeSet keeps the reported names sorted and de-duplicated.
                let mut names = std::collections::BTreeSet::new();
                collect_expr_var_names(expr, &mut names);
                let var_list = if names.is_empty() {
                    String::new()
                } else {
                    let joined = names.into_iter().collect::<Vec<_>>().join(", ");
                    format!(" on {}", joined)
                };
                return Expr::Raw(format!("/* i128 arithmetic{var_list} */"));
            }
            // Not collapsible as a whole: look for collapsible sub-trees instead.
            Expr::BinOp {
                left: Box::new(collapse_deep_i128_expr(left)),
                op: *op,
                right: Box::new(collapse_deep_i128_expr(right)),
            }
        }
        Expr::UnOp { op, operand } => Expr::UnOp {
            op: *op,
            operand: Box::new(collapse_deep_i128_expr(operand)),
        },
        Expr::Ref(target) => Expr::Ref(Box::new(collapse_deep_i128_expr(target))),
        Expr::MethodChain { receiver, calls } => {
            let receiver = Box::new(collapse_deep_i128_expr(receiver));
            let calls = calls
                .iter()
                .map(|call| MethodCall {
                    name: call.name.clone(),
                    args: call.args.iter().map(|arg| collapse_deep_i128_expr(arg)).collect(),
                })
                .collect();
            Expr::MethodChain { receiver, calls }
        }
        Expr::HostCall { module, name, args } => Expr::HostCall {
            module: module.clone(),
            name: name.clone(),
            args: args.iter().map(|arg| collapse_deep_i128_expr(arg)).collect(),
        },
        Expr::MacroCall { name, args } => Expr::MacroCall {
            name: name.clone(),
            args: args.iter().map(|arg| collapse_deep_i128_expr(arg)).collect(),
        },
        Expr::StructLiteral { name, fields } => Expr::StructLiteral {
            name: name.clone(),
            fields: fields
                .iter()
                .map(|(field, value)| (field.clone(), collapse_deep_i128_expr(value)))
                .collect(),
        },
        Expr::EnumVariant { enum_name, variant_name, fields } => Expr::EnumVariant {
            enum_name: enum_name.clone(),
            variant_name: variant_name.clone(),
            fields: fields.iter().map(|value| collapse_deep_i128_expr(value)).collect(),
        },
        Expr::Literal(_) | Expr::Var(_) | Expr::Raw(_) => expr.clone(),
    }
}
/// Measures the operator nesting depth of `expr`.
///
/// Binary and unary operators each add one level on top of their deepest
/// operand. A reference is transparent (it adds no level), and every other
/// expression kind has depth zero.
fn expr_depth(expr: &Expr) -> usize {
    match expr {
        Expr::BinOp { left, right, .. } => {
            let deepest_operand = std::cmp::max(expr_depth(left), expr_depth(right));
            deepest_operand + 1
        }
        Expr::UnOp { operand, .. } => expr_depth(operand) + 1,
        Expr::Ref(target) => expr_depth(target),
        _ => 0,
    }
}
/// Heuristically detects constants that hint at lowered i128 arithmetic.
///
/// Returns `true` when `expr` contains, anywhere below binary operators, unary
/// operators or references:
///
/// * an `I64` literal equal to 32, 63, 64, 4294967295, 4294966296 or
///   4294966299, or
/// * an `I32` literal equal to 32, 63 or 64.
///
/// Shifts by a literal 32, 63 or 64 are checked explicitly as well; that
/// literal would also be found by the general operand search.
///
/// NOTE(review): 4294967295 is `u32::MAX`; the origin of 4294966296 and
/// 4294966299 is not evident from this file — confirm they are intentional.
fn has_i128_constants(expr: &Expr) -> bool {
    match expr {
        Expr::Literal(crate::ir::Literal::I64(value)) => {
            matches!(*value, 32 | 63 | 64 | 4294966296 | 4294966299 | 4294967295)
        }
        Expr::Literal(crate::ir::Literal::I32(value)) => matches!(*value, 32 | 63 | 64),
        Expr::BinOp { left, op, right } => {
            let is_shift = matches!(op, crate::ir::BinOp::Shl | crate::ir::BinOp::Shr);
            let wide_amount = matches!(
                right.as_ref(),
                Expr::Literal(crate::ir::Literal::I64(32 | 63 | 64))
                    | Expr::Literal(crate::ir::Literal::I32(32 | 63 | 64))
            );
            (is_shift && wide_amount) || has_i128_constants(left) || has_i128_constants(right)
        }
        Expr::UnOp { operand, .. } => has_i128_constants(operand),
        Expr::Ref(target) => has_i128_constants(target),
        _ => false,
    }
}