#[cfg(test)]
use crate::frontend::ast::Span;
use crate::frontend::ast::{BinaryOp, Expr, ExprKind, Literal};
use std::collections::{HashMap, HashSet};
/// Recursively folds constant sub-expressions of `expr`.
///
/// * `Binary` nodes whose operands both fold to literals are replaced by the
///   literal produced by [`fold_binary_op`] (when it returns `Some`).
/// * `If` nodes whose condition folds to a boolean literal are replaced by a
///   `Block` holding the selected branch, or by an empty `Block` when the
///   condition is `false` and there is no `else`.
/// * `Let` and `Block` nodes are rebuilt with their children folded.
///
/// Every other expression kind is returned unchanged (its children are not
/// visited).
pub fn fold_constants(expr: Expr) -> Expr {
    match expr.kind {
        ExprKind::Binary { left, op, right } => {
            // Move operands out of their boxes rather than deep-cloning them;
            // cloning at every level made folding O(size * depth).
            let left = fold_constants(*left);
            let right = fold_constants(*right);
            if let (ExprKind::Literal(l), ExprKind::Literal(r)) = (&left.kind, &right.kind) {
                if let Some(result) = fold_binary_op(l, op, r) {
                    return Expr::new(ExprKind::Literal(result), expr.span);
                }
            }
            Expr::new(
                ExprKind::Binary {
                    left: Box::new(left),
                    op,
                    right: Box::new(right),
                },
                expr.span,
            )
        }
        ExprKind::Let {
            name,
            type_annotation,
            value,
            body,
            is_mutable,
            else_block,
        } => Expr::new(
            ExprKind::Let {
                name,
                type_annotation,
                value: Box::new(fold_constants(*value)),
                body: Box::new(fold_constants(*body)),
                is_mutable,
                else_block: else_block.map(|e| Box::new(fold_constants(*e))),
            },
            expr.span,
        ),
        ExprKind::Block(exprs) => {
            let folded_exprs = exprs.into_iter().map(fold_constants).collect();
            Expr::new(ExprKind::Block(folded_exprs), expr.span)
        }
        ExprKind::If {
            condition,
            then_branch,
            else_branch,
        } => {
            let condition = fold_constants(*condition);
            let then_branch = fold_constants(*then_branch);
            let else_branch = else_branch.map(|e| fold_constants(*e));
            if let ExprKind::Literal(Literal::Bool(taken)) = condition.kind {
                // Statically decided: keep only the branch that would run
                // (none at all for `if false` without an else).
                let chosen = if taken { Some(then_branch) } else { else_branch };
                return Expr::new(ExprKind::Block(chosen.into_iter().collect()), expr.span);
            }
            Expr::new(
                ExprKind::If {
                    condition: Box::new(condition),
                    then_branch: Box::new(then_branch),
                    else_branch: else_branch.map(Box::new),
                },
                expr.span,
            )
        }
        _ => expr,
    }
}
/// Evaluates `left op right` at compile time when possible.
///
/// Only untyped integer literals (`Integer(_, None)`) are folded. Comparison
/// operators are tried first, then arithmetic. Any other operand combination,
/// or an operator neither helper supports, yields `None`.
fn fold_binary_op(left: &Literal, op: BinaryOp, right: &Literal) -> Option<Literal> {
    if let (Literal::Integer(lhs, None), Literal::Integer(rhs, None)) = (left, right) {
        let (lhs, rhs) = (*lhs, *rhs);
        return fold_integer_comparison(lhs, op, rhs)
            .or_else(|| fold_integer_arithmetic(lhs, op, rhs));
    }
    None
}
/// Folds `a op b` for the arithmetic operators `+`, `-`, `*`, `/`.
///
/// Returns `None` when the operation overflows `i64`, when dividing by zero
/// (or `i64::MIN / -1`), or when `op` is not one of those four operators.
fn fold_integer_arithmetic(a: i64, op: BinaryOp, b: i64) -> Option<Literal> {
    let computed = match op {
        BinaryOp::Add => a.checked_add(b),
        BinaryOp::Subtract => a.checked_sub(b),
        BinaryOp::Multiply => a.checked_mul(b),
        // `checked_div` already returns `None` for a zero divisor.
        BinaryOp::Divide => a.checked_div(b),
        _ => None,
    };
    computed.map(|value| Literal::Integer(value, None))
}
/// Folds the six integer comparison operators into a boolean literal.
///
/// Returns `None` for any operator that is not a comparison.
fn fold_integer_comparison(a: i64, op: BinaryOp, b: i64) -> Option<Literal> {
    let verdict = match op {
        BinaryOp::Equal => Some(a == b),
        BinaryOp::NotEqual => Some(a != b),
        BinaryOp::Less => Some(a < b),
        BinaryOp::LessEqual => Some(a <= b),
        BinaryOp::Greater => Some(a > b),
        BinaryOp::GreaterEqual => Some(a >= b),
        _ => None,
    };
    verdict.map(Literal::Bool)
}
/// Returns the names of every function called by identifier anywhere in
/// `expr` (as far as [`collect_used_functions_rec`] descends).
fn collect_used_functions(expr: &Expr) -> HashSet<String> {
    let mut called_names: HashSet<String> = HashSet::new();
    collect_used_functions_rec(expr, &mut called_names);
    called_names
}
/// Recursive worker for [`collect_used_functions`].
///
/// Adds to `used` the name of every `Call` whose callee is a plain
/// `Identifier`, descending through calls, blocks, function bodies, `if`,
/// `while`, binary operators, `let` (value, body and else block),
/// assignments, `return`, `await`, async blocks and `spawn`. Variants not
/// listed here are not descended into.
///
/// The result drives function removal, so a missed call site can delete a
/// function that is still called. Loops, returns, assignments and let-else
/// blocks are therefore traversed as well.
fn collect_used_functions_rec(expr: &Expr, used: &mut HashSet<String>) {
    match &expr.kind {
        ExprKind::Call { func, args } => {
            if let ExprKind::Identifier(func_name) = &func.kind {
                used.insert(func_name.clone());
            }
            collect_used_functions_rec(func, used);
            for arg in args {
                collect_used_functions_rec(arg, used);
            }
        }
        ExprKind::Block(exprs) => {
            for e in exprs {
                collect_used_functions_rec(e, used);
            }
        }
        ExprKind::Function { body, .. } => {
            collect_used_functions_rec(body, used);
        }
        ExprKind::If {
            condition,
            then_branch,
            else_branch,
        } => {
            collect_used_functions_rec(condition, used);
            collect_used_functions_rec(then_branch, used);
            if let Some(else_expr) = else_branch {
                collect_used_functions_rec(else_expr, used);
            }
        }
        ExprKind::While {
            condition, body, ..
        } => {
            collect_used_functions_rec(condition, used);
            collect_used_functions_rec(body, used);
        }
        ExprKind::Binary { left, right, .. } => {
            collect_used_functions_rec(left, used);
            collect_used_functions_rec(right, used);
        }
        ExprKind::Let {
            value,
            body,
            else_block,
            ..
        } => {
            collect_used_functions_rec(value, used);
            collect_used_functions_rec(body, used);
            if let Some(else_expr) = else_block {
                collect_used_functions_rec(else_expr, used);
            }
        }
        ExprKind::Assign { target, value, .. } => {
            collect_used_functions_rec(target, used);
            collect_used_functions_rec(value, used);
        }
        ExprKind::Return { value: Some(value) } => {
            collect_used_functions_rec(value, used);
        }
        ExprKind::Await { expr: inner } => {
            collect_used_functions_rec(inner, used);
        }
        ExprKind::AsyncBlock { body } => {
            collect_used_functions_rec(body, used);
        }
        ExprKind::Spawn { actor } => {
            collect_used_functions_rec(actor, used);
        }
        _ => {}
    }
}
/// Returns the names of `let`-bound variables that are referenced within
/// their binding's scope somewhere in `expr`. The scan starts with no names
/// in scope, so free identifiers (globals, parameters) are never reported.
fn collect_used_variables(expr: &Expr) -> HashSet<String> {
    let outermost_scope: HashSet<String> = HashSet::new();
    let mut referenced = HashSet::new();
    collect_used_variables_rec(expr, &mut referenced, &outermost_scope);
    referenced
}
/// Recursive worker for [`collect_used_variables`].
///
/// `bound` holds the names introduced by enclosing `let`s. An `Identifier`
/// is recorded in `used` only when its name is in `bound`. Variants not
/// matched below are not descended into.
///
/// The result decides which `let`s may be dropped, so every place a
/// variable can be read must be visited. That includes assignments,
/// `await`, async blocks, `spawn` and nested function bodies.
fn collect_used_variables_rec(expr: &Expr, used: &mut HashSet<String>, bound: &HashSet<String>) {
    match &expr.kind {
        ExprKind::Identifier(name) => {
            if bound.contains(name) {
                used.insert(name.clone());
            }
        }
        ExprKind::Let {
            name,
            value,
            body,
            else_block,
            ..
        } => {
            // The value and the else block are outside the new binding's scope.
            collect_used_variables_rec(value, used, bound);
            let mut inner_scope = bound.clone();
            inner_scope.insert(name.clone());
            collect_used_variables_rec(body, used, &inner_scope);
            if let Some(else_expr) = else_block {
                collect_used_variables_rec(else_expr, used, bound);
            }
        }
        ExprKind::Block(exprs) => {
            for e in exprs {
                collect_used_variables_rec(e, used, bound);
            }
        }
        ExprKind::Function { body, .. } => {
            // Keep the enclosing bindings in scope. If the body refers to an
            // outer `let` (e.g. by capture), that binding must not be reported
            // unused. At worst a same-named parameter keeps a dead `let` alive.
            collect_used_variables_rec(body, used, bound);
        }
        ExprKind::If {
            condition,
            then_branch,
            else_branch,
        } => {
            collect_used_variables_rec(condition, used, bound);
            collect_used_variables_rec(then_branch, used, bound);
            if let Some(else_expr) = else_branch {
                collect_used_variables_rec(else_expr, used, bound);
            }
        }
        ExprKind::While {
            condition, body, ..
        } => {
            collect_used_variables_rec(condition, used, bound);
            collect_used_variables_rec(body, used, bound);
        }
        ExprKind::Binary { left, right, .. } => {
            collect_used_variables_rec(left, used, bound);
            collect_used_variables_rec(right, used, bound);
        }
        ExprKind::Call { func, args } => {
            collect_used_variables_rec(func, used, bound);
            for arg in args {
                collect_used_variables_rec(arg, used, bound);
            }
        }
        ExprKind::Assign { target, value, .. } => {
            collect_used_variables_rec(target, used, bound);
            collect_used_variables_rec(value, used, bound);
        }
        ExprKind::Return { value } => {
            if let Some(val) = value {
                collect_used_variables_rec(val, used, bound);
            }
        }
        ExprKind::Await { expr: inner } => {
            collect_used_variables_rec(inner, used, bound);
        }
        ExprKind::AsyncBlock { body } => {
            collect_used_variables_rec(body, used, bound);
        }
        ExprKind::Spawn { actor } => {
            collect_used_variables_rec(actor, used, bound);
        }
        _ => {}
    }
}
pub fn eliminate_dead_code(
expr: Expr,
inlined_functions: std::collections::HashSet<String>,
) -> Expr {
match expr.kind {
ExprKind::Block(exprs) => {
let used_functions = {
let temp_expr = Expr::new(ExprKind::Block(exprs.clone()), expr.span);
collect_used_functions(&temp_expr)
};
let used_variables = {
let temp_expr = Expr::new(ExprKind::Block(exprs.clone()), expr.span);
collect_used_variables(&temp_expr)
};
let cleaned = remove_dead_statements_and_unused_functions_and_variables(
exprs,
&used_functions,
&inlined_functions,
&used_variables,
);
Expr::new(ExprKind::Block(cleaned), expr.span)
}
ExprKind::Function {
name,
type_params,
params,
return_type,
body,
is_async,
is_pub,
} => {
let cleaned_body = Box::new(eliminate_dead_code(
(*body).clone(),
std::collections::HashSet::new(),
));
Expr::new(
ExprKind::Function {
name,
type_params,
params,
return_type,
body: cleaned_body,
is_async,
is_pub,
},
expr.span,
)
}
ExprKind::If {
condition,
then_branch,
else_branch,
} => {
let cleaned_then = Box::new(eliminate_dead_code(
(*then_branch).clone(),
std::collections::HashSet::new(),
));
let cleaned_else = else_branch.map(|e| {
Box::new(eliminate_dead_code(
(*e).clone(),
std::collections::HashSet::new(),
))
});
Expr::new(
ExprKind::If {
condition,
then_branch: cleaned_then,
else_branch: cleaned_else,
},
expr.span,
)
}
ExprKind::While {
condition,
body,
label,
} => {
let cleaned_body = Box::new(eliminate_dead_code(
(*body).clone(),
std::collections::HashSet::new(),
));
Expr::new(
ExprKind::While {
condition,
body: cleaned_body,
label,
},
expr.span,
)
}
ExprKind::Call { func, args } => {
let cleaned_func = Box::new(eliminate_dead_code(
(*func).clone(),
inlined_functions.clone(),
));
let cleaned_args: Vec<Expr> = args
.into_iter()
.map(|arg| eliminate_dead_code(arg, inlined_functions.clone()))
.collect();
Expr::new(
ExprKind::Call {
func: cleaned_func,
args: cleaned_args,
},
expr.span,
)
}
_ => expr, }
}
fn remove_dead_statements(exprs: Vec<Expr>) -> Vec<Expr> {
let mut result = Vec::new();
for expr in exprs {
let cleaned = eliminate_dead_code(expr, std::collections::HashSet::new());
result.push(cleaned.clone());
if has_early_exit(&cleaned) {
break;
}
}
result
}
/// Whether a function definition named `name` can be deleted.
///
/// It can be deleted only if it was inlined (`inlined_functions`) and no
/// remaining call refers to it (`used_functions`). `main` is always kept.
fn should_remove_function(
    name: &str,
    used_functions: &HashSet<String>,
    inlined_functions: &std::collections::HashSet<String>,
) -> bool {
    if name == "main" {
        return false;
    }
    let was_inlined = inlined_functions.contains(name);
    let still_called = used_functions.contains(name);
    was_inlined && !still_called
}
/// Decides whether `let name = value in body` can be replaced by its body.
///
/// Returns `None` (keep the `let`) if `name` is in `used_variables`, if
/// `value` has side effects, or if `body` is the unit literal. Otherwise
/// returns the dead-code-eliminated body: a `Block` body is flattened into
/// its statements, and any other body becomes a single statement. The let's
/// else block, if any, is not consulted here.
fn process_let_elimination(
    name: &str,
    value: &Expr,
    body: &Expr,
    used_variables: &HashSet<String>,
) -> Option<Vec<Expr>> {
    let must_keep = used_variables.contains(name)
        || has_side_effects(value)
        || matches!(body.kind, ExprKind::Literal(Literal::Unit));
    if must_keep {
        return None;
    }
    let cleaned = eliminate_dead_code(body.clone(), HashSet::new());
    let statements = match cleaned.kind {
        ExprKind::Block(inner) => inner,
        _ => vec![cleaned],
    };
    Some(statements)
}
/// Cleans a block's statement list.
///
/// - Drops `Function` statements that [`should_remove_function`] accepts.
/// - Replaces `let`s that [`process_let_elimination`] accepts with their
///   (already cleaned) body statements.
/// - Runs [`eliminate_dead_code`] on every other statement and stops after
///   the first statement that is itself a `return`, `break` or `continue`.
fn remove_dead_statements_and_unused_functions_and_variables(
    exprs: Vec<Expr>,
    used_functions: &HashSet<String>,
    inlined_functions: &std::collections::HashSet<String>,
    used_variables: &HashSet<String>,
) -> Vec<Expr> {
    let mut result = Vec::with_capacity(exprs.len());
    for expr in exprs {
        match &expr.kind {
            ExprKind::Function { name, .. }
                if should_remove_function(name, used_functions, inlined_functions) =>
            {
                continue;
            }
            ExprKind::Let {
                name, value, body, ..
            } => {
                if let Some(replacement) =
                    process_let_elimination(name, value, body, used_variables)
                {
                    result.extend(replacement);
                    continue;
                }
            }
            _ => {}
        }
        let cleaned = eliminate_dead_code(expr, HashSet::new());
        // Decide before pushing so the statement can be moved, not deep-cloned.
        let exits = has_early_exit(&cleaned);
        result.push(cleaned);
        if exits {
            break;
        }
    }
    result
}
/// Conservatively reports whether evaluating `expr` may have an observable
/// effect, in which case a binding of it must not be dropped.
///
/// Calls and assignments count as effects wherever they appear inside
/// `Binary`, `Block`, `If` or `Let` sub-expressions. Previously only the
/// top-level node was inspected, so `let x = f() + 1` was treated as pure and
/// the call was deleted. Control transfer (`return`/`break`/`continue`),
/// `await`, `spawn` and `while` loops (which may not terminate) are also
/// treated as effectful.
///
/// NOTE(review): any other AST variant is still treated as pure, as before.
/// If the AST has further effectful nodes (e.g. a method-call variant), add
/// them here.
fn has_side_effects(expr: &Expr) -> bool {
    match &expr.kind {
        ExprKind::Call { .. } | ExprKind::Assign { .. } => true,
        ExprKind::Return { .. }
        | ExprKind::Break { .. }
        | ExprKind::Continue { .. }
        | ExprKind::Await { .. }
        | ExprKind::Spawn { .. }
        | ExprKind::While { .. } => true,
        ExprKind::Binary { left, right, .. } => has_side_effects(left) || has_side_effects(right),
        ExprKind::Block(exprs) => exprs.iter().any(has_side_effects),
        ExprKind::If {
            condition,
            then_branch,
            else_branch,
        } => {
            has_side_effects(condition)
                || has_side_effects(then_branch)
                || matches!(else_branch, Some(e) if has_side_effects(e))
        }
        ExprKind::Let {
            value,
            body,
            else_block,
            ..
        } => {
            has_side_effects(value)
                || has_side_effects(body)
                || matches!(else_block, Some(e) if has_side_effects(e))
        }
        _ => false,
    }
}
/// True when `expr` itself is a `return`, `break` or `continue`.
///
/// Only the node itself is inspected: a block or `if` that exits on every
/// path is not detected.
fn has_early_exit(expr: &Expr) -> bool {
    let kind = &expr.kind;
    matches!(
        kind,
        ExprKind::Return { .. } | ExprKind::Continue { .. } | ExprKind::Break { .. }
    )
}
/// Substitutes constant `let` bindings into their uses and folds the result,
/// starting from an empty environment.
pub fn propagate_constants(expr: Expr) -> Expr {
    propagate_with_env(expr, &mut HashMap::new())
}
/// Substitutes known constant bindings from `env` into `expr` and folds the
/// resulting constant sub-expressions.
///
/// `env` maps variable names to literal values.
/// - An immutable `let` whose value reduces to a non-string literal records
///   its binding in `env`.
/// - Any other `let` removes a same-named entry. Otherwise an outer constant
///   would be substituted into code that refers to the shadowing binding
///   (e.g. `let x = 5 in let mut x = 10 in x` used to yield `5`).
/// - Entries added here are not removed when the `let` ends, matching the
///   previous behaviour. `Block`s and `if` branches work on their own copies
///   of `env`, which confines them.
///
/// Only `Let`, `Identifier`, `Binary`, `If`, `Block` and `Call` are descended
/// into; every other expression is returned unchanged.
fn propagate_with_env(expr: Expr, env: &mut HashMap<String, Literal>) -> Expr {
    // No up-front `fold_constants(expr)`. The node kinds it rewrites (Binary,
    // If, Let, Block) are all handled below with folded children, so folding
    // the whole subtree again at every recursion level only added O(n * depth) work.
    match expr.kind {
        ExprKind::Let {
            name,
            type_annotation,
            value,
            body,
            is_mutable,
            else_block,
        } => {
            let value = propagate_with_env(*value, env);
            // A let-else's else block cannot see the new binding, so process it
            // against a snapshot taken before the binding is recorded.
            let else_block = else_block.map(|e| {
                let mut else_env = env.clone();
                Box::new(propagate_with_env(*e, &mut else_env))
            });
            match &value.kind {
                ExprKind::Literal(lit) if !is_mutable && !matches!(lit, Literal::String(_)) => {
                    env.insert(name.clone(), lit.clone());
                }
                // A mutable or non-constant binding shadows any outer constant.
                _ => {
                    env.remove(&name);
                }
            }
            let body = propagate_with_env(*body, env);
            Expr::new(
                ExprKind::Let {
                    name,
                    type_annotation,
                    value: Box::new(value),
                    body: Box::new(body),
                    is_mutable,
                    else_block,
                },
                expr.span,
            )
        }
        ExprKind::Identifier(ref name) => match env.get(name) {
            Some(lit) => Expr::new(ExprKind::Literal(lit.clone()), expr.span),
            None => expr,
        },
        ExprKind::Binary { left, op, right } => {
            let left = propagate_with_env(*left, env);
            let right = propagate_with_env(*right, env);
            fold_constants(Expr::new(
                ExprKind::Binary {
                    left: Box::new(left),
                    op,
                    right: Box::new(right),
                },
                expr.span,
            ))
        }
        ExprKind::If {
            condition,
            then_branch,
            else_branch,
        } => {
            let condition = propagate_with_env(*condition, env);
            // Each branch works on its own copy of the environment, so bindings
            // made inside a branch stay local to it.
            let then_branch = propagate_with_env(*then_branch, &mut env.clone());
            let else_branch =
                else_branch.map(|e| Box::new(propagate_with_env(*e, &mut env.clone())));
            fold_constants(Expr::new(
                ExprKind::If {
                    condition: Box::new(condition),
                    then_branch: Box::new(then_branch),
                    else_branch,
                },
                expr.span,
            ))
        }
        ExprKind::Block(exprs) => {
            let mut block_env = env.clone();
            let folded_exprs = exprs
                .into_iter()
                .map(|e| propagate_with_env(e, &mut block_env))
                .collect();
            Expr::new(ExprKind::Block(folded_exprs), expr.span)
        }
        ExprKind::Call { func, args } => {
            let func_prop = Box::new(propagate_with_env(*func, env));
            let args_prop = args
                .into_iter()
                .map(|a| propagate_with_env(a, env))
                .collect();
            Expr::new(
                ExprKind::Call {
                    func: func_prop,
                    args: args_prop,
                },
                expr.span,
            )
        }
        _ => expr,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- AST construction helpers (every node uses a dummy span) ----

    fn mk(kind: ExprKind) -> Expr {
        Expr::new(kind, Span::new(0, 0))
    }

    fn int_lit(n: i64) -> Expr {
        mk(ExprKind::Literal(Literal::Integer(n, None)))
    }

    fn bool_lit(b: bool) -> Expr {
        mk(ExprKind::Literal(Literal::Bool(b)))
    }

    fn unit() -> Expr {
        mk(ExprKind::Literal(Literal::Unit))
    }

    fn ident(name: &str) -> Expr {
        mk(ExprKind::Identifier(name.to_string()))
    }

    fn call(name: &str, args: Vec<Expr>) -> Expr {
        mk(ExprKind::Call {
            func: Box::new(ident(name)),
            args,
        })
    }

    fn binary_expr(left: Expr, op: BinaryOp, right: Expr) -> Expr {
        mk(ExprKind::Binary {
            left: Box::new(left),
            op,
            right: Box::new(right),
        })
    }

    fn binary(left: i64, op: BinaryOp, right: i64) -> Expr {
        binary_expr(int_lit(left), op, int_lit(right))
    }

    fn let_in(name: &str, value: Expr, body: Expr, is_mutable: bool) -> Expr {
        mk(ExprKind::Let {
            name: name.to_string(),
            type_annotation: None,
            value: Box::new(value),
            body: Box::new(body),
            is_mutable,
            else_block: None,
        })
    }

    fn if_expr(condition: Expr, then_branch: Expr, else_branch: Option<Expr>) -> Expr {
        mk(ExprKind::If {
            condition: Box::new(condition),
            then_branch: Box::new(then_branch),
            else_branch: else_branch.map(Box::new),
        })
    }

    fn block(exprs: Vec<Expr>) -> Expr {
        mk(ExprKind::Block(exprs))
    }

    fn ret(n: i64) -> Expr {
        mk(ExprKind::Return {
            value: Some(Box::new(int_lit(n))),
        })
    }

    fn brk() -> Expr {
        mk(ExprKind::Break {
            label: None,
            value: None,
        })
    }

    fn cont() -> Expr {
        mk(ExprKind::Continue { label: None })
    }

    fn function(name: &str, body: Expr) -> Expr {
        mk(ExprKind::Function {
            name: name.to_string(),
            type_params: vec![],
            params: vec![],
            return_type: None,
            body: Box::new(body),
            is_async: false,
            is_pub: false,
        })
    }

    fn names(list: &[&str]) -> HashSet<String> {
        list.iter().map(|&s| s.to_string()).collect()
    }

    // ---- Shape extractors: panic (instead of silently passing) on a wrong shape ----

    fn block_stmts(expr: Expr) -> Vec<Expr> {
        match expr.kind {
            ExprKind::Block(exprs) => exprs,
            other => panic!("Expected Block expression, got {:?}", other),
        }
    }

    fn let_body(expr: Expr) -> Expr {
        match expr.kind {
            ExprKind::Let { body, .. } => *body,
            other => panic!("Expected Let expression, got {:?}", other),
        }
    }

    // ---- Constant folding ----

    #[test]
    fn test_fold_simple_add() {
        let folded = fold_constants(binary(2, BinaryOp::Add, 3));
        assert!(matches!(
            folded.kind,
            ExprKind::Literal(Literal::Integer(5, None))
        ));
    }

    #[test]
    fn test_fold_comparison() {
        let folded = fold_constants(binary(10, BinaryOp::Greater, 5));
        assert!(matches!(folded.kind, ExprKind::Literal(Literal::Bool(true))));
    }

    #[test]
    fn test_fold_integer_arithmetic_subtract() {
        let result = fold_integer_arithmetic(10, BinaryOp::Subtract, 3);
        assert!(matches!(result, Some(Literal::Integer(7, None))));
    }

    #[test]
    fn test_fold_integer_arithmetic_multiply() {
        let result = fold_integer_arithmetic(4, BinaryOp::Multiply, 5);
        assert!(matches!(result, Some(Literal::Integer(20, None))));
    }

    #[test]
    fn test_fold_integer_arithmetic_divide() {
        let result = fold_integer_arithmetic(20, BinaryOp::Divide, 4);
        assert!(matches!(result, Some(Literal::Integer(5, None))));
    }

    #[test]
    fn test_fold_integer_arithmetic_divide_by_zero() {
        assert!(fold_integer_arithmetic(20, BinaryOp::Divide, 0).is_none());
    }

    #[test]
    fn test_fold_integer_arithmetic_unsupported() {
        assert!(fold_integer_arithmetic(10, BinaryOp::Equal, 5).is_none());
    }

    #[test]
    fn test_fold_integer_comparison_equal() {
        let result = fold_integer_comparison(5, BinaryOp::Equal, 5);
        assert!(matches!(result, Some(Literal::Bool(true))));
    }

    #[test]
    fn test_fold_integer_comparison_not_equal() {
        let result = fold_integer_comparison(5, BinaryOp::NotEqual, 3);
        assert!(matches!(result, Some(Literal::Bool(true))));
    }

    #[test]
    fn test_fold_integer_comparison_less() {
        let result = fold_integer_comparison(3, BinaryOp::Less, 5);
        assert!(matches!(result, Some(Literal::Bool(true))));
    }

    #[test]
    fn test_fold_integer_comparison_less_equal() {
        let result = fold_integer_comparison(5, BinaryOp::LessEqual, 5);
        assert!(matches!(result, Some(Literal::Bool(true))));
    }

    #[test]
    fn test_fold_integer_comparison_greater_equal() {
        let result = fold_integer_comparison(5, BinaryOp::GreaterEqual, 5);
        assert!(matches!(result, Some(Literal::Bool(true))));
    }

    #[test]
    fn test_fold_integer_comparison_unsupported() {
        assert!(fold_integer_comparison(5, BinaryOp::Add, 3).is_none());
    }

    // ---- Side effects / early exits ----

    #[test]
    fn test_has_side_effects_call() {
        assert!(has_side_effects(&call("foo", vec![])));
    }

    #[test]
    fn test_has_side_effects_assign() {
        let expr = mk(ExprKind::Assign {
            target: Box::new(ident("x")),
            value: Box::new(int_lit(5)),
        });
        assert!(has_side_effects(&expr));
    }

    #[test]
    fn test_has_side_effects_literal() {
        assert!(!has_side_effects(&int_lit(42)));
    }

    #[test]
    fn test_has_early_exit_return() {
        assert!(has_early_exit(&ret(5)));
    }

    #[test]
    fn test_has_early_exit_break() {
        assert!(has_early_exit(&brk()));
    }

    #[test]
    fn test_has_early_exit_continue() {
        assert!(has_early_exit(&cont()));
    }

    #[test]
    fn test_has_early_exit_literal() {
        assert!(!has_early_exit(&int_lit(42)));
    }

    // ---- Used-function collection ----

    #[test]
    fn test_collect_used_functions_basic() {
        let used = collect_used_functions(&call("foo", vec![]));
        assert!(used.contains("foo"));
        assert_eq!(used.len(), 1);
    }

    #[test]
    fn test_collect_used_functions_nested_if() {
        let expr = if_expr(call("check", vec![]), call("action", vec![]), None);
        let used = collect_used_functions(&expr);
        assert!(used.contains("check"));
        assert!(used.contains("action"));
        assert_eq!(used.len(), 2);
    }

    #[test]
    fn test_collect_used_functions_await() {
        let mut used = HashSet::new();
        let expr = mk(ExprKind::Await {
            expr: Box::new(call("async_fn", vec![])),
        });
        collect_used_functions_rec(&expr, &mut used);
        assert!(used.contains("async_fn"));
    }

    #[test]
    fn test_collect_used_functions_async_block() {
        let mut used = HashSet::new();
        let expr = mk(ExprKind::AsyncBlock {
            body: Box::new(call("work", vec![])),
        });
        collect_used_functions_rec(&expr, &mut used);
        assert!(used.contains("work"));
    }

    #[test]
    fn test_collect_used_functions_spawn() {
        let mut used = HashSet::new();
        let expr = mk(ExprKind::Spawn {
            actor: Box::new(call("actor_fn", vec![])),
        });
        collect_used_functions_rec(&expr, &mut used);
        assert!(used.contains("actor_fn"));
    }

    // ---- Used-variable collection ----

    #[test]
    fn test_collect_used_variables_simple() {
        let expr = let_in("x", int_lit(5), ident("x"), false);
        let used = collect_used_variables(&expr);
        assert!(used.contains("x"));
        assert_eq!(used.len(), 1);
    }

    #[test]
    fn test_collect_used_variables_while_loop() {
        let mut used = HashSet::new();
        let bound = names(&["counter"]);
        let expr = mk(ExprKind::While {
            condition: Box::new(ident("counter")),
            body: Box::new(int_lit(1)),
            label: None,
        });
        collect_used_variables_rec(&expr, &mut used, &bound);
        assert!(used.contains("counter"));
    }

    #[test]
    fn test_collect_used_variables_return() {
        let mut used = HashSet::new();
        let bound = names(&["result"]);
        let expr = mk(ExprKind::Return {
            value: Some(Box::new(ident("result"))),
        });
        collect_used_variables_rec(&expr, &mut used, &bound);
        assert!(used.contains("result"));
    }

    // ---- Dead statement removal ----

    #[test]
    fn test_remove_dead_statements_after_return() {
        let result = remove_dead_statements(vec![ret(5), int_lit(10), int_lit(20)]);
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_remove_dead_statements_no_early_exit() {
        let result = remove_dead_statements(vec![int_lit(1), int_lit(2), int_lit(3)]);
        assert_eq!(result.len(), 3);
    }

    // ---- Constant propagation ----

    #[test]
    fn test_propagate_constants_simple_let() {
        let expr = let_in(
            "x",
            int_lit(5),
            binary_expr(ident("x"), BinaryOp::Add, int_lit(1)),
            false,
        );
        let body = let_body(propagate_constants(expr));
        assert!(matches!(
            body.kind,
            ExprKind::Literal(Literal::Integer(6, None))
        ));
    }

    #[test]
    fn test_propagate_constants_variable_substitution() {
        let expr = let_in("x", int_lit(10), ident("x"), false);
        let body = let_body(propagate_constants(expr));
        assert!(matches!(
            body.kind,
            ExprKind::Literal(Literal::Integer(10, None))
        ));
    }

    #[test]
    fn test_propagate_with_env_mutable() {
        let mut env = HashMap::new();
        let expr = let_in("x", int_lit(5), ident("x"), true);
        let body = let_body(propagate_with_env(expr, &mut env));
        assert!(matches!(body.kind, ExprKind::Identifier(_)));
    }

    #[test]
    fn test_propagate_with_env_call() {
        let mut env = HashMap::new();
        env.insert("x".to_string(), Literal::Integer(5, None));
        let result = propagate_with_env(call("foo", vec![ident("x")]), &mut env);
        match result.kind {
            ExprKind::Call { args, .. } => {
                assert_eq!(args.len(), 1);
                assert!(matches!(
                    args[0].kind,
                    ExprKind::Literal(Literal::Integer(5, None))
                ));
            }
            other => panic!("Expected Call expression, got {:?}", other),
        }
    }

    // ---- Folding of `if`, nested, `let` and block expressions ----

    #[test]
    fn test_fold_constants_if_const_true() {
        let expr = if_expr(bool_lit(true), int_lit(42), Some(int_lit(99)));
        let exprs = block_stmts(fold_constants(expr));
        assert_eq!(
            exprs.len(),
            1,
            "Block should contain exactly one expression"
        );
        assert!(matches!(
            exprs[0].kind,
            ExprKind::Literal(Literal::Integer(42, None))
        ));
    }

    #[test]
    fn test_fold_constants_if_const_false_with_else() {
        let expr = if_expr(bool_lit(false), int_lit(42), Some(int_lit(99)));
        let exprs = block_stmts(fold_constants(expr));
        assert_eq!(exprs.len(), 1);
        assert!(matches!(
            exprs[0].kind,
            ExprKind::Literal(Literal::Integer(99, None))
        ));
    }

    #[test]
    fn test_fold_constants_if_const_false_no_else() {
        let expr = if_expr(bool_lit(false), int_lit(42), None);
        assert!(block_stmts(fold_constants(expr)).is_empty());
    }

    #[test]
    fn test_fold_constants_nested_binary() {
        let expr = binary_expr(binary(2, BinaryOp::Add, 3), BinaryOp::Multiply, int_lit(4));
        let folded = fold_constants(expr);
        assert!(matches!(
            folded.kind,
            ExprKind::Literal(Literal::Integer(20, None))
        ));
    }

    #[test]
    fn test_fold_constants_let_with_folded_value() {
        let expr = let_in("x", binary(2, BinaryOp::Add, 3), ident("x"), false);
        match fold_constants(expr).kind {
            ExprKind::Let { value, .. } => assert!(matches!(
                value.kind,
                ExprKind::Literal(Literal::Integer(5, None))
            )),
            other => panic!("Expected Let expression, got {:?}", other),
        }
    }

    #[test]
    fn test_fold_constants_block() {
        let expr = block(vec![
            binary(1, BinaryOp::Add, 2),
            binary(3, BinaryOp::Multiply, 4),
        ]);
        let exprs = block_stmts(fold_constants(expr));
        assert_eq!(exprs.len(), 2);
        assert!(matches!(
            exprs[0].kind,
            ExprKind::Literal(Literal::Integer(3, None))
        ));
        assert!(matches!(
            exprs[1].kind,
            ExprKind::Literal(Literal::Integer(12, None))
        ));
    }

    #[test]
    fn test_fold_integer_comparison_false_cases() {
        assert!(matches!(
            fold_integer_comparison(5, BinaryOp::Equal, 3),
            Some(Literal::Bool(false))
        ));
        assert!(matches!(
            fold_integer_comparison(5, BinaryOp::NotEqual, 5),
            Some(Literal::Bool(false))
        ));
        assert!(matches!(
            fold_integer_comparison(5, BinaryOp::Less, 3),
            Some(Literal::Bool(false))
        ));
        assert!(matches!(
            fold_integer_comparison(5, BinaryOp::Greater, 10),
            Some(Literal::Bool(false))
        ));
    }

    #[test]
    fn test_fold_integer_arithmetic_overflow() {
        assert!(fold_integer_arithmetic(i64::MAX, BinaryOp::Add, 1).is_none());
    }

    #[test]
    fn test_fold_integer_arithmetic_underflow() {
        assert!(fold_integer_arithmetic(i64::MIN, BinaryOp::Subtract, 1).is_none());
    }

    #[test]
    fn test_fold_integer_arithmetic_multiply_overflow() {
        assert!(fold_integer_arithmetic(i64::MAX, BinaryOp::Multiply, 2).is_none());
    }

    // ---- Dead-code elimination ----

    #[test]
    fn test_eliminate_dead_code_after_break() {
        let result = remove_dead_statements(vec![brk(), int_lit(10)]);
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_eliminate_dead_code_after_continue() {
        let result = remove_dead_statements(vec![cont(), int_lit(10)]);
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_eliminate_dead_code_function_body() {
        let expr = function("test", block(vec![ret(5), int_lit(10)]));
        match eliminate_dead_code(expr, HashSet::new()).kind {
            ExprKind::Function { body, .. } => assert_eq!(block_stmts(*body).len(), 1),
            other => panic!("Expected Function expression, got {:?}", other),
        }
    }

    #[test]
    fn test_eliminate_dead_code_while_body() {
        let expr = mk(ExprKind::While {
            condition: Box::new(bool_lit(true)),
            body: Box::new(block(vec![brk(), int_lit(10)])),
            label: None,
        });
        match eliminate_dead_code(expr, HashSet::new()).kind {
            ExprKind::While { body, .. } => assert_eq!(block_stmts(*body).len(), 1),
            other => panic!("Expected While expression, got {:?}", other),
        }
    }

    #[test]
    fn test_eliminate_dead_code_if_branches() {
        let expr = if_expr(
            ident("cond"),
            block(vec![ret(1), int_lit(2)]),
            Some(block(vec![ret(3), int_lit(4)])),
        );
        match eliminate_dead_code(expr, HashSet::new()).kind {
            ExprKind::If {
                then_branch,
                else_branch,
                ..
            } => {
                assert_eq!(block_stmts(*then_branch).len(), 1);
                let else_branch = else_branch.expect("else branch should be preserved");
                assert_eq!(block_stmts(*else_branch).len(), 1);
            }
            other => panic!("Expected If expression, got {:?}", other),
        }
    }

    #[test]
    fn test_eliminate_dead_code_call_args() {
        let expr = call("foo", vec![binary(1, BinaryOp::Add, 2)]);
        let result = eliminate_dead_code(expr, HashSet::new());
        assert!(matches!(result.kind, ExprKind::Call { .. }));
    }

    // ---- Propagation scoping ----

    #[test]
    fn test_propagate_with_env_block_scope() {
        let mut env = HashMap::new();
        env.insert("outer".to_string(), Literal::Integer(10, None));
        let expr = block(vec![
            let_in("inner", int_lit(20), unit(), false),
            ident("outer"),
        ]);
        let exprs = block_stmts(propagate_with_env(expr, &mut env));
        assert_eq!(exprs.len(), 2);
        assert!(matches!(exprs[0].kind, ExprKind::Let { .. }));
        assert!(matches!(
            exprs[1].kind,
            ExprKind::Literal(Literal::Integer(10, None))
        ));
    }

    #[test]
    fn test_propagate_constants_in_if_condition() {
        let expr = let_in(
            "x",
            int_lit(5),
            if_expr(
                binary_expr(ident("x"), BinaryOp::Greater, int_lit(3)),
                int_lit(1),
                Some(int_lit(0)),
            ),
            false,
        );
        let exprs = block_stmts(let_body(propagate_constants(expr)));
        assert_eq!(exprs.len(), 1);
        assert!(matches!(
            exprs[0].kind,
            ExprKind::Literal(Literal::Integer(1, None))
        ));
    }

    // ---- More collector coverage ----

    #[test]
    fn test_collect_used_functions_in_block() {
        let expr = block(vec![call("foo", vec![]), call("bar", vec![])]);
        let used = collect_used_functions(&expr);
        assert!(used.contains("foo"));
        assert!(used.contains("bar"));
        assert_eq!(used.len(), 2);
    }

    #[test]
    fn test_collect_used_functions_in_binary() {
        let expr = binary_expr(
            call("left_fn", vec![]),
            BinaryOp::Add,
            call("right_fn", vec![]),
        );
        let used = collect_used_functions(&expr);
        assert!(used.contains("left_fn"));
        assert!(used.contains("right_fn"));
    }

    #[test]
    fn test_collect_used_functions_in_function_def() {
        let expr = function("outer", call("inner", vec![]));
        let used = collect_used_functions(&expr);
        assert!(used.contains("inner"));
    }

    #[test]
    fn test_collect_used_variables_in_block() {
        let expr = let_in(
            "x",
            int_lit(1),
            block(vec![let_in("y", ident("x"), ident("y"), false)]),
            false,
        );
        let used = collect_used_variables(&expr);
        assert!(used.contains("x"));
        assert!(used.contains("y"));
    }

    #[test]
    fn test_collect_used_variables_let_with_else() {
        let expr = mk(ExprKind::Let {
            name: "x".to_string(),
            type_annotation: None,
            value: Box::new(int_lit(1)),
            body: Box::new(ident("x")),
            is_mutable: false,
            else_block: Some(Box::new(int_lit(0))),
        });
        let used = collect_used_variables(&expr);
        assert!(used.contains("x"));
    }

    // ---- Removal predicates ----

    #[test]
    fn test_should_remove_function_main() {
        let used = HashSet::new();
        let inlined = names(&["main"]);
        assert!(!should_remove_function("main", &used, &inlined));
    }

    #[test]
    fn test_should_remove_function_inlined_and_unused() {
        let used = HashSet::new();
        let inlined = names(&["helper"]);
        assert!(should_remove_function("helper", &used, &inlined));
    }

    #[test]
    fn test_should_remove_function_inlined_but_used() {
        let used = names(&["helper"]);
        let inlined = names(&["helper"]);
        assert!(!should_remove_function("helper", &used, &inlined));
    }

    #[test]
    fn test_process_let_elimination_used_variable() {
        let used = names(&["x"]);
        let result = process_let_elimination("x", &int_lit(5), &ident("x"), &used);
        assert!(result.is_none());
    }

    #[test]
    fn test_process_let_elimination_with_side_effects() {
        let used = HashSet::new();
        let value = call("side_effect", vec![]);
        let result = process_let_elimination("unused", &value, &int_lit(1), &used);
        assert!(result.is_none());
    }
}