use crate::{MirBody, MirExpr, MirExprKind, MirFunction, MirProgram, MirStmt};
use cjc_ast::{BinOp, UnaryOp};
use std::collections::{BTreeMap, BTreeSet};
pub fn optimize_program(program: &MirProgram) -> MirProgram {
let mut optimized = program.clone();
for func in &mut optimized.functions {
constant_fold_fn(func);
}
for func in &mut optimized.functions {
strength_reduce_fn(func);
}
for func in &mut optimized.functions {
dce_fn(func);
}
for func in &mut optimized.functions {
cse_fn(func);
}
for func in &mut optimized.functions {
licm_fn(func);
}
for func in &mut optimized.functions {
constant_fold_fn(func);
}
optimized
}
fn constant_fold_fn(func: &mut MirFunction) {
constant_fold_body(&mut func.body);
}
fn constant_fold_body(body: &mut MirBody) {
for stmt in &mut body.stmts {
constant_fold_stmt(stmt);
}
if let Some(ref mut expr) = body.result {
constant_fold_expr(expr);
}
}
/// Constant-folds the expressions and nested bodies contained in one statement.
fn constant_fold_stmt(stmt: &mut MirStmt) {
    match stmt {
        // Nothing to fold in bare control-flow statements.
        MirStmt::Break | MirStmt::Continue => {}
        MirStmt::Let { init, .. } => constant_fold_expr(init),
        MirStmt::Expr(inner) => constant_fold_expr(inner),
        MirStmt::Return(value) => {
            if let Some(inner) = value {
                constant_fold_expr(inner);
            }
        }
        MirStmt::NoGcBlock(block) => constant_fold_body(block),
        MirStmt::While { cond, body } => {
            constant_fold_expr(cond);
            constant_fold_body(body);
        }
        MirStmt::If {
            cond,
            then_body,
            else_body,
        } => {
            constant_fold_expr(cond);
            constant_fold_body(then_body);
            if let Some(otherwise) = else_body {
                constant_fold_body(otherwise);
            }
        }
    }
}
/// Constant-folds `expr` in place, bottom-up.
///
/// All sub-expressions are folded first; afterwards `try_fold` is given the
/// chance to replace the (now simplified) node itself.
fn constant_fold_expr(expr: &mut MirExpr) {
    match &mut expr.kind {
        // Leaves: nothing below them to fold.
        MirExprKind::IntLit(_)
        | MirExprKind::FloatLit(_)
        | MirExprKind::BoolLit(_)
        | MirExprKind::NaLit
        | MirExprKind::StringLit(_)
        | MirExprKind::ByteStringLit(_)
        | MirExprKind::ByteCharLit(_)
        | MirExprKind::RawStringLit(_)
        | MirExprKind::RawByteStringLit(_)
        | MirExprKind::RegexLit { .. }
        | MirExprKind::Var(_)
        | MirExprKind::Col(_)
        | MirExprKind::Void => {}

        // Single-operand nodes.
        MirExprKind::Unary { operand, .. } => constant_fold_expr(operand),
        MirExprKind::Field { object, .. } => constant_fold_expr(object),
        MirExprKind::Lambda { body, .. } => constant_fold_expr(body),
        MirExprKind::LinalgLU { operand }
        | MirExprKind::LinalgQR { operand }
        | MirExprKind::LinalgCholesky { operand }
        | MirExprKind::LinalgInv { operand } => constant_fold_expr(operand),

        // Two-operand nodes.
        MirExprKind::Binary {
            left: lhs,
            right: rhs,
            ..
        } => {
            constant_fold_expr(lhs);
            constant_fold_expr(rhs);
        }
        MirExprKind::Index { object, index } => {
            constant_fold_expr(object);
            constant_fold_expr(index);
        }
        MirExprKind::Assign { target, value } => {
            constant_fold_expr(target);
            constant_fold_expr(value);
        }

        // Nodes carrying a list of sub-expressions.
        MirExprKind::Call { callee, args } => {
            constant_fold_expr(callee);
            args.iter_mut().for_each(|arg| constant_fold_expr(arg));
        }
        MirExprKind::MultiIndex { object, indices } => {
            constant_fold_expr(object);
            indices.iter_mut().for_each(|idx| constant_fold_expr(idx));
        }
        MirExprKind::Broadcast {
            operand,
            target_shape,
        } => {
            constant_fold_expr(operand);
            target_shape
                .iter_mut()
                .for_each(|dim| constant_fold_expr(dim));
        }
        MirExprKind::StructLit { fields, .. } => {
            fields
                .iter_mut()
                .for_each(|(_, value)| constant_fold_expr(value));
        }
        MirExprKind::VariantLit { fields, .. } => {
            fields.iter_mut().for_each(|value| constant_fold_expr(value));
        }
        MirExprKind::ArrayLit(items) | MirExprKind::TupleLit(items) => {
            items.iter_mut().for_each(|item| constant_fold_expr(item));
        }
        MirExprKind::TensorLit { rows } => {
            rows.iter_mut()
                .flat_map(|row| row.iter_mut())
                .for_each(|cell| constant_fold_expr(cell));
        }
        MirExprKind::MakeClosure { captures, .. } => {
            captures
                .iter_mut()
                .for_each(|captured| constant_fold_expr(captured));
        }

        // Nodes carrying nested bodies.
        MirExprKind::Block(block) => constant_fold_body(block),
        MirExprKind::If {
            cond,
            then_body,
            else_body,
        } => {
            constant_fold_expr(cond);
            constant_fold_body(then_body);
            if let Some(otherwise) = else_body {
                constant_fold_body(otherwise);
            }
        }
        MirExprKind::Match { scrutinee, arms } => {
            constant_fold_expr(scrutinee);
            for arm in arms.iter_mut() {
                constant_fold_body(&mut arm.body);
            }
        }
    }

    // With the children simplified, try to collapse this node itself.
    if let Some(replacement) = try_fold(expr) {
        *expr = replacement;
    }
}
/// Attempts to fold a single node whose children are already folded.
///
/// Handles binary and unary operators (delegating to the helpers) and `if`
/// expressions whose condition is a boolean literal: a `true` condition
/// yields the then-branch as a block, a `false` condition yields the
/// else-branch as a block, or `Void` when there is no else-branch.
/// Returns `None` when the node cannot be folded.
fn try_fold(expr: &MirExpr) -> Option<MirExpr> {
    match &expr.kind {
        MirExprKind::Unary { op, operand } => try_fold_unary(*op, operand),
        MirExprKind::Binary { op, left, right } => try_fold_binary(*op, left, right),
        MirExprKind::If {
            cond,
            then_body,
            else_body,
        } => {
            let kind = match (&cond.kind, else_body) {
                (MirExprKind::BoolLit(true), _) => MirExprKind::Block(then_body.clone()),
                (MirExprKind::BoolLit(false), Some(otherwise)) => {
                    MirExprKind::Block(otherwise.clone())
                }
                (MirExprKind::BoolLit(false), None) => MirExprKind::Void,
                // Condition is not a literal: the branch is only known at run time.
                _ => return None,
            };
            Some(MirExpr { kind })
        }
        _ => None,
    }
}
fn try_fold_binary(op: BinOp, left: &MirExpr, right: &MirExpr) -> Option<MirExpr> {
match (&left.kind, &right.kind) {
(MirExprKind::IntLit(a), MirExprKind::IntLit(b)) => {
fold_int_binop(op, *a, *b).map(|kind| MirExpr { kind })
}
(MirExprKind::FloatLit(a), MirExprKind::FloatLit(b)) => {
fold_float_binop(op, *a, *b).map(|kind| MirExpr { kind })
}
(MirExprKind::BoolLit(a), MirExprKind::BoolLit(b)) => match op {
BinOp::Eq => Some(MirExpr {
kind: MirExprKind::BoolLit(a == b),
}),
BinOp::Ne => Some(MirExpr {
kind: MirExprKind::BoolLit(a != b),
}),
BinOp::And => Some(MirExpr {
kind: MirExprKind::BoolLit(*a && *b),
}),
BinOp::Or => Some(MirExpr {
kind: MirExprKind::BoolLit(*a || *b),
}),
BinOp::BitAnd => Some(MirExpr {
kind: MirExprKind::BoolLit(*a & *b),
}),
BinOp::BitOr => Some(MirExpr {
kind: MirExprKind::BoolLit(*a | *b),
}),
BinOp::BitXor => Some(MirExpr {
kind: MirExprKind::BoolLit(*a ^ *b),
}),
_ => None,
},
(MirExprKind::StringLit(a), MirExprKind::StringLit(b)) => match op {
BinOp::Add => Some(MirExpr {
kind: MirExprKind::StringLit(format!("{a}{b}")),
}),
BinOp::Eq => Some(MirExpr {
kind: MirExprKind::BoolLit(a == b),
}),
BinOp::Ne => Some(MirExpr {
kind: MirExprKind::BoolLit(a != b),
}),
_ => None,
},
_ => None,
}
}
/// Folds `a op b` for two integer literals.
///
/// Returns the folded literal kind, or `None` when the operation must be
/// left for run time:
/// * division / remainder by zero,
/// * `i64::MIN / -1` and `i64::MIN % -1`, which overflow (evaluating them
///   here with `/` or `%` would panic inside the optimizer),
/// * `Pow` with a negative exponent or one that does not fit in `u32`
///   (casting it with `as u32` would silently fold a different exponent),
/// * logical / pattern-match operators, which are not defined on integers here.
///
/// `Add`, `Sub`, `Mul` and `Pow` use wrapping arithmetic.
fn fold_int_binop(op: BinOp, a: i64, b: i64) -> Option<MirExprKind> {
    match op {
        BinOp::Add => Some(MirExprKind::IntLit(a.wrapping_add(b))),
        BinOp::Sub => Some(MirExprKind::IntLit(a.wrapping_sub(b))),
        BinOp::Mul => Some(MirExprKind::IntLit(a.wrapping_mul(b))),
        // `checked_*` yields `None` both for a zero divisor and for the
        // `i64::MIN / -1` overflow case, so neither can panic here.
        BinOp::Div => a.checked_div(b).map(MirExprKind::IntLit),
        BinOp::Mod => a.checked_rem(b).map(MirExprKind::IntLit),
        BinOp::Pow => {
            if (0..=i64::from(u32::MAX)).contains(&b) {
                // The range check above makes this cast lossless.
                Some(MirExprKind::IntLit(a.wrapping_pow(b as u32)))
            } else {
                None
            }
        }
        BinOp::BitAnd => Some(MirExprKind::IntLit(a & b)),
        BinOp::BitOr => Some(MirExprKind::IntLit(a | b)),
        BinOp::BitXor => Some(MirExprKind::IntLit(a ^ b)),
        // NOTE(review): shift counts are truncated to u32 and then masked by
        // `wrapping_shl`/`wrapping_shr`; presumably this mirrors the runtime's
        // shift semantics — confirm against the interpreter.
        BinOp::Shl => Some(MirExprKind::IntLit(a.wrapping_shl(b as u32))),
        BinOp::Shr => Some(MirExprKind::IntLit(a.wrapping_shr(b as u32))),
        BinOp::Eq => Some(MirExprKind::BoolLit(a == b)),
        BinOp::Ne => Some(MirExprKind::BoolLit(a != b)),
        BinOp::Lt => Some(MirExprKind::BoolLit(a < b)),
        BinOp::Gt => Some(MirExprKind::BoolLit(a > b)),
        BinOp::Le => Some(MirExprKind::BoolLit(a <= b)),
        BinOp::Ge => Some(MirExprKind::BoolLit(a >= b)),
        BinOp::And | BinOp::Or | BinOp::Match | BinOp::NotMatch => None,
    }
}
fn fold_float_binop(op: BinOp, a: f64, b: f64) -> Option<MirExprKind> {
match op {
BinOp::Add => Some(MirExprKind::FloatLit(a + b)),
BinOp::Sub => Some(MirExprKind::FloatLit(a - b)),
BinOp::Mul => Some(MirExprKind::FloatLit(a * b)),
BinOp::Div => Some(MirExprKind::FloatLit(a / b)), BinOp::Mod => Some(MirExprKind::FloatLit(a % b)),
BinOp::Pow => Some(MirExprKind::FloatLit(a.powf(b))),
BinOp::Eq => Some(MirExprKind::BoolLit(a == b)),
BinOp::Ne => Some(MirExprKind::BoolLit(a != b)),
BinOp::Lt => Some(MirExprKind::BoolLit(a < b)),
BinOp::Gt => Some(MirExprKind::BoolLit(a > b)),
BinOp::Le => Some(MirExprKind::BoolLit(a <= b)),
BinOp::Ge => Some(MirExprKind::BoolLit(a >= b)),
BinOp::And | BinOp::Or | BinOp::Match | BinOp::NotMatch => None,
BinOp::BitAnd | BinOp::BitOr | BinOp::BitXor | BinOp::Shl | BinOp::Shr => None,
}
}
/// Folds a unary operator applied to a literal operand.
///
/// Supports negation of integer and float literals, logical `Not` on
/// booleans and `BitNot` on integers. Returns `None` for any other
/// combination, and for the negation of `i64::MIN`, which overflows: a plain
/// `-v` would panic in a debug build of the optimizer, so that case is left
/// for run time.
fn try_fold_unary(op: UnaryOp, operand: &MirExpr) -> Option<MirExpr> {
    let kind = match (&op, &operand.kind) {
        (UnaryOp::Neg, MirExprKind::IntLit(v)) => MirExprKind::IntLit(v.checked_neg()?),
        (UnaryOp::Neg, MirExprKind::FloatLit(v)) => MirExprKind::FloatLit(-v),
        (UnaryOp::Not, MirExprKind::BoolLit(b)) => MirExprKind::BoolLit(!b),
        (UnaryOp::BitNot, MirExprKind::IntLit(v)) => MirExprKind::IntLit(!v),
        _ => return None,
    };
    Some(MirExpr { kind })
}
fn dce_fn(func: &mut MirFunction) {
dce_body(&mut func.body);
}
fn dce_body(body: &mut MirBody) {
let mut used_vars = BTreeSet::new();
for stmt in &body.stmts {
collect_used_vars_stmt(stmt, &mut used_vars);
}
if let Some(ref expr) = body.result {
collect_used_vars_expr(expr, &mut used_vars);
}
body.stmts.retain(|stmt| {
match stmt {
MirStmt::Let { name, init, .. } => {
if !used_vars.contains(name.as_str()) && is_pure_expr(init) {
return false; }
true
}
_ => true,
}
});
for stmt in &mut body.stmts {
dce_stmt(stmt);
}
let mut new_stmts = Vec::new();
for stmt in std::mem::take(&mut body.stmts) {
match stmt {
MirStmt::If {
cond,
then_body,
else_body,
} => {
if let MirExprKind::BoolLit(b) = &cond.kind {
if *b {
new_stmts.extend(then_body.stmts);
if let Some(result_expr) = then_body.result {
new_stmts.push(MirStmt::Expr(*result_expr));
}
} else if let Some(eb) = else_body {
new_stmts.extend(eb.stmts);
if let Some(result_expr) = eb.result {
new_stmts.push(MirStmt::Expr(*result_expr));
}
}
} else {
new_stmts.push(MirStmt::If {
cond,
then_body,
else_body,
});
}
}
MirStmt::While { ref cond, .. } => {
if let MirExprKind::BoolLit(false) = &cond.kind {
} else {
new_stmts.push(stmt);
}
}
other => new_stmts.push(other),
}
}
body.stmts = new_stmts;
}
/// Applies dead-code elimination to the bodies nested inside `stmt`
/// (`if` branches, `while` bodies and no-GC blocks). Other statements are
/// left untouched.
fn dce_stmt(stmt: &mut MirStmt) {
    match stmt {
        MirStmt::NoGcBlock(block) => dce_body(block),
        MirStmt::While { body: loop_body, .. } => dce_body(loop_body),
        MirStmt::If {
            then_body,
            else_body,
            ..
        } => {
            dce_body(then_body);
            if let Some(otherwise) = else_body {
                dce_body(otherwise);
            }
        }
        _ => {}
    }
}
fn is_pure_expr(expr: &MirExpr) -> bool {
match &expr.kind {
MirExprKind::IntLit(_)
| MirExprKind::FloatLit(_)
| MirExprKind::BoolLit(_)
| MirExprKind::NaLit
| MirExprKind::StringLit(_)
| MirExprKind::ByteStringLit(_)
| MirExprKind::ByteCharLit(_)
| MirExprKind::RawStringLit(_)
| MirExprKind::RawByteStringLit(_)
| MirExprKind::RegexLit { .. }
| MirExprKind::Var(_)
| MirExprKind::Void => true,
MirExprKind::Binary { left, right, .. } => is_pure_expr(left) && is_pure_expr(right),
MirExprKind::Unary { operand, .. } => is_pure_expr(operand),
MirExprKind::TupleLit(elems) | MirExprKind::ArrayLit(elems) => {
elems.iter().all(is_pure_expr)
}
MirExprKind::TensorLit { rows } => {
rows.iter().all(|row| row.iter().all(is_pure_expr))
}
MirExprKind::StructLit { fields, .. } => fields.iter().all(|(_, e)| is_pure_expr(e)),
MirExprKind::VariantLit { fields, .. } => fields.iter().all(is_pure_expr),
MirExprKind::Call { .. } => false,
MirExprKind::Field { object, .. } => is_pure_expr(object),
MirExprKind::Assign { .. } => false,
MirExprKind::Index { .. } | MirExprKind::MultiIndex { .. } => false,
MirExprKind::Block(_)
| MirExprKind::If { .. }
| MirExprKind::Match { .. }
| MirExprKind::Lambda { .. }
| MirExprKind::MakeClosure { .. }
| MirExprKind::Col(_) => false,
MirExprKind::LinalgLU { .. }
| MirExprKind::LinalgQR { .. }
| MirExprKind::LinalgCholesky { .. }
| MirExprKind::LinalgInv { .. }
| MirExprKind::Broadcast { .. } => false,
}
}
/// Adds to `used` every variable name mentioned by `stmt`, including names
/// mentioned inside nested bodies. The name bound by a `let` is not added;
/// only the names appearing in its initializer are.
fn collect_used_vars_stmt(stmt: &MirStmt, used: &mut BTreeSet<String>) {
    match stmt {
        MirStmt::Break | MirStmt::Continue => {}
        MirStmt::Let { init, .. } => collect_used_vars_expr(init, used),
        MirStmt::Expr(inner) => collect_used_vars_expr(inner, used),
        MirStmt::Return(value) => {
            if let Some(inner) = value {
                collect_used_vars_expr(inner, used);
            }
        }
        MirStmt::NoGcBlock(block) => collect_used_vars_body(block, used),
        MirStmt::While { cond, body } => {
            collect_used_vars_expr(cond, used);
            collect_used_vars_body(body, used);
        }
        MirStmt::If {
            cond,
            then_body,
            else_body,
        } => {
            collect_used_vars_expr(cond, used);
            collect_used_vars_body(then_body, used);
            if let Some(otherwise) = else_body {
                collect_used_vars_body(otherwise, used);
            }
        }
    }
}
/// Adds to `used` every variable name mentioned by the statements of `body`
/// and by its result expression, if any.
fn collect_used_vars_body(body: &MirBody, used: &mut BTreeSet<String>) {
    body.stmts
        .iter()
        .for_each(|stmt| collect_used_vars_stmt(stmt, used));
    if let Some(tail) = body.result.as_deref() {
        collect_used_vars_expr(tail, used);
    }
}
/// Adds to `used` the name of every `Var` node reachable from `expr`.
///
/// The walk descends into all sub-expressions and nested bodies, including
/// assignment targets, closure captures, lambda bodies and match arms.
fn collect_used_vars_expr(expr: &MirExpr, used: &mut BTreeSet<String>) {
    match &expr.kind {
        MirExprKind::Var(name) => {
            used.insert(name.to_owned());
        }

        // Leaves that mention no variable.
        MirExprKind::IntLit(_)
        | MirExprKind::FloatLit(_)
        | MirExprKind::BoolLit(_)
        | MirExprKind::NaLit
        | MirExprKind::StringLit(_)
        | MirExprKind::ByteStringLit(_)
        | MirExprKind::ByteCharLit(_)
        | MirExprKind::RawStringLit(_)
        | MirExprKind::RawByteStringLit(_)
        | MirExprKind::RegexLit { .. }
        | MirExprKind::Col(_)
        | MirExprKind::Void => {}

        // Single-operand nodes.
        MirExprKind::Unary { operand, .. } => collect_used_vars_expr(operand, used),
        MirExprKind::Field { object, .. } => collect_used_vars_expr(object, used),
        MirExprKind::Lambda { body, .. } => collect_used_vars_expr(body, used),
        MirExprKind::LinalgLU { operand }
        | MirExprKind::LinalgQR { operand }
        | MirExprKind::LinalgCholesky { operand }
        | MirExprKind::LinalgInv { operand } => collect_used_vars_expr(operand, used),

        // Two-operand nodes.
        MirExprKind::Binary {
            left: lhs,
            right: rhs,
            ..
        } => {
            collect_used_vars_expr(lhs, used);
            collect_used_vars_expr(rhs, used);
        }
        MirExprKind::Index { object, index } => {
            collect_used_vars_expr(object, used);
            collect_used_vars_expr(index, used);
        }
        MirExprKind::Assign { target, value } => {
            collect_used_vars_expr(target, used);
            collect_used_vars_expr(value, used);
        }

        // Nodes carrying a list of sub-expressions.
        MirExprKind::Call { callee, args } => {
            collect_used_vars_expr(callee, used);
            args.iter().for_each(|arg| collect_used_vars_expr(arg, used));
        }
        MirExprKind::MultiIndex { object, indices } => {
            collect_used_vars_expr(object, used);
            indices
                .iter()
                .for_each(|idx| collect_used_vars_expr(idx, used));
        }
        MirExprKind::Broadcast {
            operand,
            target_shape,
        } => {
            collect_used_vars_expr(operand, used);
            target_shape
                .iter()
                .for_each(|dim| collect_used_vars_expr(dim, used));
        }
        MirExprKind::StructLit { fields, .. } => {
            fields
                .iter()
                .for_each(|(_, value)| collect_used_vars_expr(value, used));
        }
        MirExprKind::VariantLit { fields, .. } => {
            fields
                .iter()
                .for_each(|value| collect_used_vars_expr(value, used));
        }
        MirExprKind::ArrayLit(items) | MirExprKind::TupleLit(items) => {
            items
                .iter()
                .for_each(|item| collect_used_vars_expr(item, used));
        }
        MirExprKind::TensorLit { rows } => {
            rows.iter()
                .flat_map(|row| row.iter())
                .for_each(|cell| collect_used_vars_expr(cell, used));
        }
        MirExprKind::MakeClosure { captures, .. } => {
            captures
                .iter()
                .for_each(|captured| collect_used_vars_expr(captured, used));
        }

        // Nodes carrying nested bodies.
        MirExprKind::Block(block) => collect_used_vars_body(block, used),
        MirExprKind::If {
            cond,
            then_body,
            else_body,
        } => {
            collect_used_vars_expr(cond, used);
            collect_used_vars_body(then_body, used);
            if let Some(otherwise) = else_body {
                collect_used_vars_body(otherwise, used);
            }
        }
        MirExprKind::Match { scrutinee, arms } => {
            collect_used_vars_expr(scrutinee, used);
            for arm in arms.iter() {
                collect_used_vars_body(&arm.body, used);
            }
        }
    }
}
fn strength_reduce_fn(func: &mut MirFunction) {
strength_reduce_body(&mut func.body);
}
fn strength_reduce_body(body: &mut MirBody) {
for stmt in &mut body.stmts {
strength_reduce_stmt(stmt);
}
if let Some(ref mut expr) = body.result {
strength_reduce_expr(expr);
}
}
/// Strength-reduces the expressions and nested bodies contained in one
/// statement.
fn strength_reduce_stmt(stmt: &mut MirStmt) {
    match stmt {
        MirStmt::Break | MirStmt::Continue => {}
        MirStmt::Let { init, .. } => strength_reduce_expr(init),
        MirStmt::Expr(inner) => strength_reduce_expr(inner),
        MirStmt::Return(value) => {
            if let Some(inner) = value {
                strength_reduce_expr(inner);
            }
        }
        MirStmt::NoGcBlock(block) => strength_reduce_body(block),
        MirStmt::While { cond, body } => {
            strength_reduce_expr(cond);
            strength_reduce_body(body);
        }
        MirStmt::If {
            cond,
            then_body,
            else_body,
        } => {
            strength_reduce_expr(cond);
            strength_reduce_body(then_body);
            if let Some(otherwise) = else_body {
                strength_reduce_body(otherwise);
            }
        }
    }
}
fn strength_reduce_expr(expr: &mut MirExpr) {
match &mut expr.kind {
MirExprKind::Binary { left, right, .. } => {
strength_reduce_expr(left);
strength_reduce_expr(right);
}
MirExprKind::Unary { operand, .. } => strength_reduce_expr(operand),
MirExprKind::Call { callee, args } => {
strength_reduce_expr(callee);
for arg in args { strength_reduce_expr(arg); }
}
MirExprKind::Block(body) => strength_reduce_body(body),
MirExprKind::If { cond, then_body, else_body } => {
strength_reduce_expr(cond);
strength_reduce_body(then_body);
if let Some(eb) = else_body { strength_reduce_body(eb); }
}
MirExprKind::Lambda { body, .. } => strength_reduce_expr(body),
_ => {}
}
if let Some(reduced) = try_strength_reduce(expr) {
*expr = reduced;
}
}
fn try_strength_reduce(expr: &MirExpr) -> Option<MirExpr> {
match &expr.kind {
MirExprKind::Binary { op, left, right } => {
match op {
BinOp::Mul => {
if matches!(right.kind, MirExprKind::IntLit(0)) {
return Some(MirExpr { kind: MirExprKind::IntLit(0) });
}
if matches!(left.kind, MirExprKind::IntLit(0)) {
return Some(MirExpr { kind: MirExprKind::IntLit(0) });
}
if matches!(right.kind, MirExprKind::IntLit(1)) {
return Some(*left.clone());
}
if matches!(left.kind, MirExprKind::IntLit(1)) {
return Some(*right.clone());
}
if matches!(right.kind, MirExprKind::IntLit(2)) {
return Some(MirExpr {
kind: MirExprKind::Binary {
op: BinOp::Add,
left: left.clone(),
right: left.clone(),
},
});
}
None
}
BinOp::Add => {
if matches!(right.kind, MirExprKind::IntLit(0)) {
return Some(*left.clone());
}
if matches!(left.kind, MirExprKind::IntLit(0)) {
return Some(*right.clone());
}
None
}
BinOp::Sub => {
if matches!(right.kind, MirExprKind::IntLit(0)) {
return Some(*left.clone());
}
None
}
BinOp::Div => {
if matches!(right.kind, MirExprKind::IntLit(1)) {
return Some(*left.clone());
}
None
}
_ => None,
}
}
_ => None,
}
}
fn cse_fn(func: &mut MirFunction) {
cse_body(&mut func.body);
}
fn cse_body(body: &mut MirBody) {
let mut expr_to_var: BTreeMap<String, String> = BTreeMap::new();
let mut replacements: BTreeMap<String, String> = BTreeMap::new();
for stmt in &body.stmts {
if let MirStmt::Let { name, init, mutable, .. } = stmt {
if !mutable && is_pure_expr(init) {
let key = expr_key(init);
if let Some(existing) = expr_to_var.get(&key) {
replacements.insert(name.clone(), existing.clone());
} else {
expr_to_var.insert(key, name.clone());
}
}
}
}
if !replacements.is_empty() {
for stmt in &mut body.stmts {
apply_cse_replacements_stmt(stmt, &replacements);
}
if let Some(ref mut result) = body.result {
apply_cse_replacements_expr(result, &replacements);
}
}
for stmt in &mut body.stmts {
match stmt {
MirStmt::If { then_body, else_body, .. } => {
cse_body(then_body);
if let Some(eb) = else_body { cse_body(eb); }
}
MirStmt::While { body: wb, .. } => cse_body(wb),
MirStmt::NoGcBlock(b) => cse_body(b),
_ => {}
}
}
}
/// Builds a textual key identifying the structure of `expr`, used to detect
/// equal expressions during CSE.
///
/// Literals, variables, binary / unary operators and field accesses get a
/// structural key. Every other expression kind gets a key derived from the
/// node's address, so it only ever equals itself.
///
/// String literals are written in their escaped, quoted `Debug` form. With
/// the raw contents, a string containing the key syntax itself would make
/// keys ambiguous: `"x,str:y" < "z"` and `"x" < "y,str:z"` would both map to
/// `bin:Lt(str:x,str:y,str:z)` and be wrongly treated as the same expression.
fn expr_key(expr: &MirExpr) -> String {
    match &expr.kind {
        MirExprKind::IntLit(v) => format!("int:{v}"),
        // Bit pattern rather than value, so 0.0 / -0.0 and NaN payloads stay distinct.
        MirExprKind::FloatLit(v) => format!("float:{}", v.to_bits()),
        MirExprKind::BoolLit(v) => format!("bool:{v}"),
        MirExprKind::NaLit => "na".to_string(),
        MirExprKind::StringLit(s) => format!("str:{s:?}"),
        MirExprKind::Var(name) => format!("var:{name}"),
        MirExprKind::Binary { op, left, right } => {
            format!("bin:{:?}({},{})", op, expr_key(left), expr_key(right))
        }
        MirExprKind::Unary { op, operand } => {
            format!("un:{:?}({})", op, expr_key(operand))
        }
        MirExprKind::Field { object, name } => {
            format!("field:{}:{}", expr_key(object), name)
        }
        _ => format!("opaque:{:p}", expr),
    }
}
fn apply_cse_replacements_stmt(stmt: &mut MirStmt, replacements: &BTreeMap<String, String>) {
match stmt {
MirStmt::Let { init, .. } => apply_cse_replacements_expr(init, replacements),
MirStmt::Expr(expr) => apply_cse_replacements_expr(expr, replacements),
MirStmt::If { cond, then_body, else_body } => {
apply_cse_replacements_expr(cond, replacements);
for s in &mut then_body.stmts { apply_cse_replacements_stmt(s, replacements); }
if let Some(ref mut r) = then_body.result { apply_cse_replacements_expr(r, replacements); }
if let Some(eb) = else_body {
for s in &mut eb.stmts { apply_cse_replacements_stmt(s, replacements); }
if let Some(ref mut r) = eb.result { apply_cse_replacements_expr(r, replacements); }
}
}
MirStmt::While { cond, body } => {
apply_cse_replacements_expr(cond, replacements);
for s in &mut body.stmts { apply_cse_replacements_stmt(s, replacements); }
if let Some(ref mut r) = body.result { apply_cse_replacements_expr(r, replacements); }
}
MirStmt::Return(opt_expr) => {
if let Some(expr) = opt_expr { apply_cse_replacements_expr(expr, replacements); }
}
MirStmt::Break | MirStmt::Continue => {}
MirStmt::NoGcBlock(body) => {
for s in &mut body.stmts { apply_cse_replacements_stmt(s, replacements); }
if let Some(ref mut r) = body.result { apply_cse_replacements_expr(r, replacements); }
}
}
}
/// Renames `Var` nodes inside `expr` according to `replacements`
/// (old name -> new name).
///
/// Only variables, binary / unary operators, calls, field accesses, single
/// indexing and blocks are traversed; every other expression kind is left
/// untouched, so mentions inside it keep their original name.
fn apply_cse_replacements_expr(expr: &mut MirExpr, replacements: &BTreeMap<String, String>) {
    match &mut expr.kind {
        MirExprKind::Var(name) => {
            if let Some(canonical) = replacements.get(name.as_str()) {
                *name = canonical.clone();
            }
        }
        MirExprKind::Unary { operand, .. } => apply_cse_replacements_expr(operand, replacements),
        MirExprKind::Field { object, .. } => apply_cse_replacements_expr(object, replacements),
        MirExprKind::Binary {
            left: lhs,
            right: rhs,
            ..
        } => {
            apply_cse_replacements_expr(lhs, replacements);
            apply_cse_replacements_expr(rhs, replacements);
        }
        MirExprKind::Index { object, index } => {
            apply_cse_replacements_expr(object, replacements);
            apply_cse_replacements_expr(index, replacements);
        }
        MirExprKind::Call { callee, args } => {
            apply_cse_replacements_expr(callee, replacements);
            args.iter_mut()
                .for_each(|arg| apply_cse_replacements_expr(arg, replacements));
        }
        MirExprKind::Block(block) => {
            block
                .stmts
                .iter_mut()
                .for_each(|inner| apply_cse_replacements_stmt(inner, replacements));
            if let Some(tail) = block.result.as_deref_mut() {
                apply_cse_replacements_expr(tail, replacements);
            }
        }
        _ => {}
    }
}
fn licm_fn(func: &mut MirFunction) {
licm_body(&mut func.body);
}
/// Loop-invariant code motion for one body.
///
/// Nested bodies are processed first (innermost loops before outer ones).
/// Then, for every `while` statement directly in this body, the statements
/// selected by `hoist_invariants` are placed immediately in front of the
/// loop and the loop keeps the remaining statements.
///
/// NOTE(review): a hoisted `let` ends up in the scope surrounding the loop;
/// this assumes its name does not clash with a name used by the loop
/// condition or by code after the loop — confirm names are unique in MIR.
fn licm_body(body: &mut MirBody) {
    for stmt in body.stmts.iter_mut() {
        match stmt {
            MirStmt::NoGcBlock(block) => licm_body(block),
            MirStmt::While { body: loop_body, .. } => licm_body(loop_body),
            MirStmt::If {
                then_body,
                else_body,
                ..
            } => {
                licm_body(then_body);
                if let Some(otherwise) = else_body {
                    licm_body(otherwise);
                }
            }
            _ => {}
        }
    }

    let mut rebuilt = Vec::new();
    for stmt in std::mem::take(&mut body.stmts) {
        match stmt {
            MirStmt::While {
                cond,
                body: loop_body,
            } => {
                let (mut hoisted, trimmed) = hoist_invariants(loop_body);
                rebuilt.append(&mut hoisted);
                rebuilt.push(MirStmt::While {
                    cond,
                    body: trimmed,
                });
            }
            other => rebuilt.push(other),
        }
    }
    body.stmts = rebuilt;
}
fn hoist_invariants(loop_body: MirBody) -> (Vec<MirStmt>, MirBody) {
let mut modified_vars = BTreeSet::new();
collect_modified_vars_body(&loop_body, &mut modified_vars);
let mut hoisted = Vec::new();
let mut remaining = Vec::new();
for stmt in loop_body.stmts {
if let MirStmt::Let { ref name, ref init, mutable, alloc_hint } = stmt {
if is_pure_expr(init) && !references_any(init, &modified_vars) {
hoisted.push(MirStmt::Let {
name: name.clone(),
mutable,
init: init.clone(),
alloc_hint,
});
continue;
}
}
remaining.push(stmt);
}
(hoisted, MirBody { stmts: remaining, result: loop_body.result })
}
fn collect_modified_vars_body(body: &MirBody, modified: &mut BTreeSet<String>) {
for stmt in &body.stmts {
collect_modified_vars_stmt(stmt, modified);
}
}
/// Adds to `modified` every variable that `stmt` binds or assigns.
///
/// A `let` counts as modifying the name it binds. Nested bodies and
/// conditions are searched as well. `return`, `break` and `continue`
/// contribute nothing.
fn collect_modified_vars_stmt(stmt: &MirStmt, modified: &mut BTreeSet<String>) {
    match stmt {
        MirStmt::Return(_) | MirStmt::Break | MirStmt::Continue => {}
        MirStmt::Expr(inner) => collect_modified_vars_expr(inner, modified),
        MirStmt::NoGcBlock(block) => collect_modified_vars_body(block, modified),
        MirStmt::Let { name, init, .. } => {
            modified.insert(name.to_owned());
            collect_modified_vars_expr(init, modified);
        }
        MirStmt::While { cond, body } => {
            collect_modified_vars_expr(cond, modified);
            collect_modified_vars_body(body, modified);
        }
        MirStmt::If {
            cond,
            then_body,
            else_body,
        } => {
            collect_modified_vars_expr(cond, modified);
            collect_modified_vars_body(then_body, modified);
            if let Some(otherwise) = else_body {
                collect_modified_vars_body(otherwise, modified);
            }
        }
    }
}
fn collect_modified_vars_expr(expr: &MirExpr, modified: &mut BTreeSet<String>) {
match &expr.kind {
MirExprKind::Assign { target, value } => {
if let MirExprKind::Var(name) = &target.kind {
modified.insert(name.clone());
}
collect_modified_vars_expr(value, modified);
}
MirExprKind::Binary { left, right, .. } => {
collect_modified_vars_expr(left, modified);
collect_modified_vars_expr(right, modified);
}
MirExprKind::Call { callee, args } => {
collect_modified_vars_expr(callee, modified);
for arg in args { collect_modified_vars_expr(arg, modified); }
}
_ => {}
}
}
fn references_any(expr: &MirExpr, vars: &BTreeSet<String>) -> bool {
match &expr.kind {
MirExprKind::Var(name) => vars.contains(name.as_str()),
MirExprKind::Binary { left, right, .. } => {
references_any(left, vars) || references_any(right, vars)
}
MirExprKind::Unary { operand, .. } => references_any(operand, vars),
MirExprKind::Field { object, .. } => references_any(object, vars),
MirExprKind::ArrayLit(elems) | MirExprKind::TupleLit(elems) => {
elems.iter().any(|e| references_any(e, vars))
}
MirExprKind::StructLit { fields, .. } => {
fields.iter().any(|(_, e)| references_any(e, vars))
}
_ => false, }
}
// Unit tests for the individual optimization passes and for the full
// `optimize_program` pipeline. Each test builds a small MIR fragment by hand.
#[cfg(test)]
mod tests {
use super::*;
use crate::*;
// ---- Construction helpers -------------------------------------------------
// Wraps an expression kind in a `MirExpr`.
fn mk_expr(kind: MirExprKind) -> MirExpr {
MirExpr { kind }
}
// Integer literal expression.
fn mk_int(v: i64) -> MirExpr {
mk_expr(MirExprKind::IntLit(v))
}
// Float literal expression.
fn mk_float(v: f64) -> MirExpr {
mk_expr(MirExprKind::FloatLit(v))
}
// Boolean literal expression.
fn mk_bool(v: bool) -> MirExpr {
mk_expr(MirExprKind::BoolLit(v))
}
// Binary operation `left op right`.
fn mk_binary(op: BinOp, left: MirExpr, right: MirExpr) -> MirExpr {
mk_expr(MirExprKind::Binary {
op,
left: Box::new(left),
right: Box::new(right),
})
}
// Unary operation `op operand`.
fn mk_unary(op: UnaryOp, operand: MirExpr) -> MirExpr {
mk_expr(MirExprKind::Unary {
op,
operand: Box::new(operand),
})
}
// Parameterless, private function with the given statements and optional
// result expression; all other fields get empty / default-like values.
fn mk_fn(name: &str, stmts: Vec<MirStmt>, result: Option<MirExpr>) -> MirFunction {
MirFunction {
id: MirFnId(0),
name: name.to_string(),
type_params: vec![],
params: vec![],
return_type: None,
body: MirBody {
stmts,
result: result.map(Box::new),
},
is_nogc: false,
cfg_body: None,
decorators: vec![],
vis: cjc_ast::Visibility::Private,
}
}
// Program consisting only of `functions`, with function id 0 as entry.
fn mk_program(functions: Vec<MirFunction>) -> MirProgram {
MirProgram {
functions,
struct_defs: vec![],
enum_defs: vec![],
entry: MirFnId(0),
}
}
// ---- Constant folding -----------------------------------------------------
// 2 + 3 folds to 5.
#[test]
fn test_fold_int_add() {
let mut expr = mk_binary(BinOp::Add, mk_int(2), mk_int(3));
constant_fold_expr(&mut expr);
assert!(matches!(expr.kind, MirExprKind::IntLit(5)));
}
// 4 * 5 folds to 20.
#[test]
fn test_fold_int_mul() {
let mut expr = mk_binary(BinOp::Mul, mk_int(4), mk_int(5));
constant_fold_expr(&mut expr);
assert!(matches!(expr.kind, MirExprKind::IntLit(20)));
}
// Integer division by a literal zero must stay a binary expression.
#[test]
fn test_fold_int_div_by_zero_not_folded() {
let mut expr = mk_binary(BinOp::Div, mk_int(10), mk_int(0));
constant_fold_expr(&mut expr);
assert!(matches!(expr.kind, MirExprKind::Binary { .. }));
}
// 1.5 + 2.5 folds to 4.0.
#[test]
fn test_fold_float_add() {
let mut expr = mk_binary(BinOp::Add, mk_float(1.5), mk_float(2.5));
constant_fold_expr(&mut expr);
match expr.kind {
MirExprKind::FloatLit(v) => assert_eq!(v, 4.0),
_ => panic!("expected FloatLit"),
}
}
// (2 + 3) * 4 folds bottom-up to 20.
#[test]
fn test_fold_nested() {
let inner = mk_binary(BinOp::Add, mk_int(2), mk_int(3));
let mut expr = mk_binary(BinOp::Mul, inner, mk_int(4));
constant_fold_expr(&mut expr);
assert!(matches!(expr.kind, MirExprKind::IntLit(20)));
}
// 1 < 2 folds to `true`.
#[test]
fn test_fold_comparison() {
let mut expr = mk_binary(BinOp::Lt, mk_int(1), mk_int(2));
constant_fold_expr(&mut expr);
assert!(matches!(expr.kind, MirExprKind::BoolLit(true)));
}
// -(42) folds to the literal -42.
#[test]
fn test_fold_unary_neg() {
let mut expr = mk_unary(UnaryOp::Neg, mk_int(42));
constant_fold_expr(&mut expr);
assert!(matches!(expr.kind, MirExprKind::IntLit(-42)));
}
// !true folds to `false`.
#[test]
fn test_fold_unary_not() {
let mut expr = mk_unary(UnaryOp::Not, mk_bool(true));
constant_fold_expr(&mut expr);
assert!(matches!(expr.kind, MirExprKind::BoolLit(false)));
}
// ---- Dead-code elimination ------------------------------------------------
// An unused `let` with a pure initializer is removed; the call statement stays.
#[test]
fn test_dce_removes_unused_pure_let() {
let mut body = MirBody {
stmts: vec![
MirStmt::Let {
name: "unused".to_string(),
mutable: false,
init: mk_int(42),
alloc_hint: None,
},
MirStmt::Expr(mk_expr(MirExprKind::Call {
callee: Box::new(mk_expr(MirExprKind::Var("print".to_string()))),
args: vec![mk_expr(MirExprKind::StringLit("hi".to_string()))],
})),
],
result: None,
};
dce_body(&mut body);
assert_eq!(body.stmts.len(), 1);
assert!(matches!(body.stmts[0], MirStmt::Expr(_)));
}
// A `let` whose name is used by the result expression is kept.
#[test]
fn test_dce_keeps_used_let() {
let mut body = MirBody {
stmts: vec![
MirStmt::Let {
name: "x".to_string(),
mutable: false,
init: mk_int(42),
alloc_hint: None,
},
],
result: Some(Box::new(mk_expr(MirExprKind::Var("x".to_string())))),
};
dce_body(&mut body);
assert_eq!(body.stmts.len(), 1);
}
// `if false { .. }` without an else-branch disappears entirely.
#[test]
fn test_dce_removes_dead_if_false() {
let mut body = MirBody {
stmts: vec![MirStmt::If {
cond: mk_bool(false),
then_body: MirBody {
stmts: vec![MirStmt::Expr(mk_int(1))],
result: None,
},
else_body: None,
}],
result: None,
};
dce_body(&mut body);
assert!(body.stmts.is_empty());
}
// `if true { .. }` is replaced by the statements of its then-branch.
#[test]
fn test_dce_inlines_if_true() {
let mut body = MirBody {
stmts: vec![MirStmt::If {
cond: mk_bool(true),
then_body: MirBody {
stmts: vec![MirStmt::Expr(mk_int(1))],
result: None,
},
else_body: None,
}],
result: None,
};
dce_body(&mut body);
assert_eq!(body.stmts.len(), 1);
assert!(matches!(body.stmts[0], MirStmt::Expr(_)));
}
// `while false { .. }` is removed.
#[test]
fn test_dce_removes_dead_while_false() {
let mut body = MirBody {
stmts: vec![MirStmt::While {
cond: mk_bool(false),
body: MirBody {
stmts: vec![MirStmt::Expr(mk_int(1))],
result: None,
},
}],
result: None,
};
dce_body(&mut body);
assert!(body.stmts.is_empty());
}
// ---- Whole pipeline -------------------------------------------------------
// `let x = 2 + 3; x` keeps the binding and folds its initializer to 5.
#[test]
fn test_optimize_program_preserves_semantics() {
let program = mk_program(vec![mk_fn(
"__main",
vec![
MirStmt::Let {
name: "x".to_string(),
mutable: false,
init: mk_binary(BinOp::Add, mk_int(2), mk_int(3)),
alloc_hint: None,
},
],
Some(mk_expr(MirExprKind::Var("x".to_string()))),
)]);
let optimized = optimize_program(&program);
let main = &optimized.functions[0];
match &main.body.stmts[0] {
MirStmt::Let { init, .. } => {
assert!(matches!(init.kind, MirExprKind::IntLit(5)));
}
_ => panic!("expected Let"),
}
}
// Variable reference expression.
fn mk_var(name: &str) -> MirExpr {
mk_expr(MirExprKind::Var(name.to_string()))
}
// ---- Strength reduction ---------------------------------------------------
// x * 0 becomes the literal 0.
#[test]
fn test_sr_mul_by_zero() {
let mut expr = mk_binary(BinOp::Mul, mk_var("x"), mk_int(0));
strength_reduce_expr(&mut expr);
assert!(matches!(expr.kind, MirExprKind::IntLit(0)));
}
// x * 1 becomes x.
#[test]
fn test_sr_mul_by_one() {
let mut expr = mk_binary(BinOp::Mul, mk_var("x"), mk_int(1));
strength_reduce_expr(&mut expr);
assert!(matches!(expr.kind, MirExprKind::Var(ref n) if n == "x"));
}
// x * 2 becomes x + x.
#[test]
fn test_sr_mul_by_two() {
let mut expr = mk_binary(BinOp::Mul, mk_var("x"), mk_int(2));
strength_reduce_expr(&mut expr);
match &expr.kind {
MirExprKind::Binary { op, left, right } => {
assert_eq!(*op, BinOp::Add);
assert!(matches!(left.kind, MirExprKind::Var(ref n) if n == "x"));
assert!(matches!(right.kind, MirExprKind::Var(ref n) if n == "x"));
}
_ => panic!("expected Binary Add"),
}
}
// x + 0 becomes x.
#[test]
fn test_sr_add_zero() {
let mut expr = mk_binary(BinOp::Add, mk_var("x"), mk_int(0));
strength_reduce_expr(&mut expr);
assert!(matches!(expr.kind, MirExprKind::Var(ref n) if n == "x"));
}
// x - 0 becomes x.
#[test]
fn test_sr_sub_zero() {
let mut expr = mk_binary(BinOp::Sub, mk_var("x"), mk_int(0));
strength_reduce_expr(&mut expr);
assert!(matches!(expr.kind, MirExprKind::Var(ref n) if n == "x"));
}
// x / 1 becomes x.
#[test]
fn test_sr_div_by_one() {
let mut expr = mk_binary(BinOp::Div, mk_var("x"), mk_int(1));
strength_reduce_expr(&mut expr);
assert!(matches!(expr.kind, MirExprKind::Var(ref n) if n == "x"));
}
// ---- Common-subexpression elimination -------------------------------------
// `x` and `y` are both `a + b`; the result expression `y` is rewritten to `x`.
#[test]
fn test_cse_eliminates_duplicate_pure_let() {
let mut body = MirBody {
stmts: vec![
MirStmt::Let {
name: "a".to_string(),
mutable: false,
init: mk_int(10),
alloc_hint: None,
},
MirStmt::Let {
name: "b".to_string(),
mutable: false,
init: mk_int(20),
alloc_hint: None,
},
MirStmt::Let {
name: "x".to_string(),
mutable: false,
init: mk_binary(BinOp::Add, mk_var("a"), mk_var("b")),
alloc_hint: None,
},
MirStmt::Let {
name: "y".to_string(),
mutable: false,
init: mk_binary(BinOp::Add, mk_var("a"), mk_var("b")),
alloc_hint: None,
},
],
result: Some(Box::new(mk_var("y"))),
};
cse_body(&mut body);
match &body.result {
Some(expr) => {
assert!(matches!(expr.kind, MirExprKind::Var(ref n) if n == "x"));
}
None => panic!("expected result"),
}
}
// ---- Loop-invariant code motion -------------------------------------------
// `let inv = 1 + 2` does not depend on the loop and is moved in front of it;
// `x` (which reads `i`) and the assignment to `i` stay inside the loop.
#[test]
fn test_licm_hoists_invariant_let() {
let mut body = MirBody {
stmts: vec![MirStmt::While {
cond: mk_var("cond"),
body: MirBody {
stmts: vec![
MirStmt::Let {
name: "inv".to_string(),
mutable: false,
init: mk_binary(BinOp::Add, mk_int(1), mk_int(2)),
alloc_hint: None,
},
MirStmt::Let {
name: "x".to_string(),
mutable: false,
init: mk_binary(BinOp::Add, mk_var("inv"), mk_var("i")),
alloc_hint: None,
},
MirStmt::Expr(mk_expr(MirExprKind::Assign {
target: Box::new(mk_var("i")),
value: Box::new(mk_binary(BinOp::Add, mk_var("i"), mk_int(1))),
})),
],
result: None,
},
}],
result: None,
};
licm_body(&mut body);
assert_eq!(body.stmts.len(), 2);
match &body.stmts[0] {
MirStmt::Let { name, .. } => assert_eq!(name, "inv"),
_ => panic!("expected hoisted Let"),
}
assert!(matches!(body.stmts[1], MirStmt::While { .. }));
match &body.stmts[1] {
MirStmt::While { body: wb, .. } => {
assert_eq!(wb.stmts.len(), 2);
}
_ => panic!("expected While"),
}
}
// `let x = i * 2` reads `i`, which the loop assigns, so nothing is hoisted.
#[test]
fn test_licm_does_not_hoist_dependent() {
let mut body = MirBody {
stmts: vec![MirStmt::While {
cond: mk_var("cond"),
body: MirBody {
stmts: vec![
MirStmt::Expr(mk_expr(MirExprKind::Assign {
target: Box::new(mk_var("i")),
value: Box::new(mk_binary(BinOp::Add, mk_var("i"), mk_int(1))),
})),
MirStmt::Let {
name: "x".to_string(),
mutable: false,
init: mk_binary(BinOp::Mul, mk_var("i"), mk_int(2)),
alloc_hint: None,
},
],
result: None,
},
}],
result: None,
};
licm_body(&mut body);
assert_eq!(body.stmts.len(), 1);
match &body.stmts[0] {
MirStmt::While { body: wb, .. } => {
assert_eq!(wb.stmts.len(), 2); }
_ => panic!("expected While"),
}
}
// Full pipeline: `let x = y * 1` must end up as `let x = y`.
#[test]
fn test_full_optimize_with_strength_reduction() {
let program = mk_program(vec![mk_fn(
"__main",
vec![
MirStmt::Let {
name: "y".to_string(),
mutable: false,
init: mk_int(42),
alloc_hint: None,
},
MirStmt::Let {
name: "x".to_string(),
mutable: false,
init: mk_binary(BinOp::Mul, mk_var("y"), mk_int(1)),
alloc_hint: None,
},
MirStmt::Let {
name: "z".to_string(),
mutable: false,
init: mk_binary(BinOp::Add, mk_var("y"), mk_int(0)),
alloc_hint: None,
},
],
Some(mk_binary(BinOp::Add, mk_var("x"), mk_var("z"))),
)]);
let optimized = optimize_program(&program);
let main = &optimized.functions[0];
match &main.body.stmts[1] {
MirStmt::Let { init, .. } => {
assert!(
matches!(init.kind, MirExprKind::Var(ref n) if n == "y"),
"expected Var(y) after strength reduction, got {:?}",
init.kind
);
}
_ => panic!("expected Let"),
}
}
}