use super::*;
use crate::ir::{BufferDecl, DataType, Expr, Node};
use crate::optimizer::passes::const_fold::ConstFold;
use crate::optimizer::{PassScheduler, ProgramPassKind};
/// `x * 2` stored to a `u32` buffer must lower to `x << 1` once const-folding
/// and strength reduction have been scheduled to convergence.
#[test]
fn optimizer_strength_reduce_multiplies_by_two() {
    let program = crate::optimizer::passes::cleanup::region_inline_engine::run(Program::wrapped(
        vec![BufferDecl::read_write("out", 0, DataType::U32)],
        [1, 1, 1],
        vec![Node::store(
            "out",
            Expr::u32(0),
            Expr::mul(Expr::var("x"), Expr::u32(2)),
        )],
    ));
    let optimized = PassScheduler::with_passes(vec![
        ProgramPassKind::new(ConstFold),
        ProgramPassKind::new(StrengthReduce),
    ])
    .run(program)
    .expect("Fix: strength reduce should converge");
    let body = crate::test_util::region_body(&optimized);
    // Print the whole region body on failure (as the mul-by-three test does) so a
    // reshaped or missing store is diagnosable without re-running under a debugger.
    assert!(
        matches!(
            &body[0],
            Node::Store {
                value: Expr::BinOp {
                    op: BinOp::Shl,
                    right,
                    ..
                },
                ..
            } if matches!(right.as_ref(), Expr::LitU32(1))
        ),
        "x * 2 must lower to x << 1: {body:?}"
    );
}
/// `input * 3` has no single-shift form. After the scheduled passes, the bound
/// value must be a shift combined with an add or a sub, not a multiply.
#[test]
fn optimizer_strength_reduce_decomposes_mul_by_three() {
    let times_three = Expr::mul(Expr::var("input"), Expr::u32(3));
    let program = Program::wrapped(Vec::new(), [1, 1, 1], vec![Node::let_bind("x", times_three)]);
    let optimized = PassScheduler::with_passes(vec![
        ProgramPassKind::new(ConstFold),
        ProgramPassKind::new(StrengthReduce),
    ])
    .run(program)
    .expect("Fix: strength reduce should converge");
    let body = crate::test_util::region_body(&optimized);
    let decomposed = matches!(
        &body[0],
        Node::Let {
            value: Expr::BinOp {
                op: BinOp::Add | BinOp::Sub,
                ..
            },
            ..
        }
    );
    assert!(
        decomposed,
        "x * 3 must decompose to a shift/add/sub chain: {body:?}"
    );
}
/// `x * 2.0` must be rewritten into an addition (`x + x`).
#[test]
fn float_mul_by_two_becomes_add() {
    // `expect` replaces the `is_some()` + `unwrap()` pair and names the failure.
    let reduced = reduce_expr(&Expr::mul(Expr::var("x"), Expr::f32(2.0)))
        .expect("x * 2.0 must be strength-reduced");
    assert!(
        matches!(&reduced, Expr::BinOp { op: BinOp::Add, .. }),
        "x * 2.0 must become an add: {reduced:?}"
    );
}
/// Multiplying a float by `1.0` must collapse to the operand itself.
#[test]
fn float_mul_by_one_becomes_identity() {
    let product = Expr::mul(Expr::var("x"), Expr::f32(1.0));
    let expected = Some(Expr::var("x"));
    assert_eq!(reduce_expr(&product), expected);
}
/// `x * 0.0` is not always `0.0` at runtime: a NaN or infinite `x` yields NaN.
/// The multiply must therefore survive strength reduction.
#[test]
fn float_mul_by_zero_does_not_hide_runtime_nan() {
    let product = Expr::mul(Expr::var("x"), Expr::f32(0.0));
    assert!(reduce_expr(&product).is_none());
}
/// `x / 2.0` must become a multiply by the reciprocal literal `0.5`. The
/// reciprocal sits on the right, as in the general reciprocal-multiply test.
#[test]
fn float_div_by_two_becomes_mul_half() {
    let reduced = reduce_expr(&Expr::div(Expr::var("x"), Expr::f32(2.0)))
        .expect("x / 2.0 must be strength-reduced");
    match &reduced {
        // Pin the factor too: a Mul by anything other than 0.5 would be a miscompile.
        Expr::BinOp {
            op: BinOp::Mul,
            right,
            ..
        } => assert!(
            matches!(right.as_ref(), Expr::LitF32(v) if *v == 0.5),
            "x / 2.0 must multiply by 0.5: {reduced:?}"
        ),
        other => panic!("x / 2.0 must become a multiply, got {other:?}"),
    }
}
/// Asserts that `x + 0.0` folds to `x`.
///
/// NOTE(review): under IEEE-754, `-0.0 + 0.0` evaluates to `+0.0`, so this rewrite
/// flips the sign of a negative-zero `x`. (`x - 0.0` and `x + -0.0` are exact
/// identities.) Confirm the pass's float contract accepts this, given that
/// `x * 0.0` is deliberately left alone for NaN safety.
#[test]
fn float_add_zero_becomes_identity() {
let result = reduce_expr(&Expr::add(Expr::var("x"), Expr::f32(0.0)));
assert_eq!(result, Some(Expr::var("x")));
}
/// Subtracting `0.0` from a float must collapse to the operand itself.
#[test]
fn float_sub_zero_becomes_identity() {
    let difference = Expr::sub(Expr::var("x"), Expr::f32(0.0));
    let expected = Some(Expr::var("x"));
    assert_eq!(reduce_expr(&difference), expected);
}
/// Dividing by the power of two `8` (`u32` literal divisor) must become a right shift.
#[test]
fn int_div_by_power_of_two_becomes_shr() {
    // `expect` replaces the `is_some()` + `unwrap()` pair and names the failure.
    let reduced = reduce_expr(&Expr::div(Expr::var("x"), Expr::u32(8)))
        .expect("x / 8 must be strength-reduced");
    assert!(
        matches!(&reduced, Expr::BinOp { op: BinOp::Shr, .. }),
        "x / 8 must become a right shift: {reduced:?}"
    );
}
/// A non-power-of-two `u32` divisor (`3`) must use the multiply-high form
/// `Shr(MulHigh(x, magic), shift)`.
#[test]
fn int_div_by_constant_becomes_mulhi() {
    let reduced = match reduce_expr(&Expr::div(Expr::var("x"), Expr::u32(3))) {
        Some(expr) => expr,
        None => panic!("x/3 must be strength-reduced"),
    };
    // Peel the outer shift, then inspect what it shifts.
    let shifted = match &reduced {
        Expr::BinOp {
            op: BinOp::Shr,
            left,
            ..
        } => left.as_ref(),
        other => panic!("x/3 must reduce to Shr(MulHigh(...)), got {other:?}"),
    };
    let is_mulhi = matches!(
        shifted,
        Expr::BinOp {
            op: BinOp::MulHigh,
            ..
        }
    );
    assert!(is_mulhi, "inner must be MulHigh: {shifted:?}");
}
/// Dividing by `7` must take the fixup path: the outer shift wraps an `Add`
/// instead of a bare multiply-high. (Presumably this is because 7's magic
/// constant does not fit the plain form.)
#[test]
fn int_div_by_seven_uses_fixup() {
    let reduced = match reduce_expr(&Expr::div(Expr::var("x"), Expr::u32(7))) {
        Some(expr) => expr,
        None => panic!("x/7 must be strength-reduced"),
    };
    if let Expr::BinOp {
        op: BinOp::Shr,
        left,
        ..
    } = &reduced
    {
        let shifted = left.as_ref();
        assert!(
            matches!(shifted, Expr::BinOp { op: BinOp::Add, .. }),
            "fixup must produce Add at top: {shifted:?}"
        );
    } else {
        panic!("x/7 must reduce to Shr(Add(...)), got {reduced:?}");
    }
}
/// `x % 16` with a power-of-two `u32` modulus must become a bitwise AND mask.
#[test]
fn int_mod_by_power_of_two_becomes_bitand() {
    let remainder = Expr::BinOp {
        op: BinOp::Mod,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(16)),
    };
    let reduced = reduce_expr(&remainder);
    assert!(matches!(
        reduced,
        Some(Expr::BinOp {
            op: BinOp::BitAnd,
            ..
        })
    ));
}
/// `x / 3.0` must become `x * r`, where `r` is an `f32` literal close to `1/3`.
///
/// NOTE(review): `x * (1/3)` is not always bit-identical to `x / 3` in IEEE
/// arithmetic. Confirm this is intended, since float Horner rewriting is
/// rejected elsewhere for rounding reasons.
#[test]
fn float_div_by_constant_becomes_reciprocal_mul() {
    let result = reduce_expr(&Expr::div(Expr::var("x"), Expr::f32(3.0)));
    assert!(result.is_some());
    let factor = match result.unwrap() {
        Expr::BinOp {
            op: BinOp::Mul,
            right,
            ..
        } => *right,
        other => panic!("expected Mul, got {other:?}"),
    };
    match &factor {
        Expr::LitF32(v) => {
            assert!((v - 1.0 / 3.0).abs() < 1e-7, "reciprocal should be ~0.333");
        }
        other => panic!("expected LitF32 reciprocal, got {other:?}"),
    }
}
/// `1.0 / x` must map onto the dedicated reciprocal unary expression.
#[test]
fn float_one_div_variable_becomes_reciprocal_unop() {
    let expected = Expr::reciprocal(Expr::var("x"));
    let quotient = Expr::div(Expr::f32(1.0), Expr::var("x"));
    assert_eq!(reduce_expr(&quotient), Some(expected));
}
/// A NaN literal divisor must block the reciprocal rewrite.
#[test]
fn float_div_by_nan_does_not_reduce() {
    let quotient = Expr::div(Expr::var("x"), Expr::f32(f32::NAN));
    let outcome = reduce_expr(&quotient);
    assert!(outcome.is_none(), "NaN divisor must not fold");
}
/// A zero literal divisor must block the reciprocal rewrite.
#[test]
fn float_div_by_zero_does_not_reduce() {
    let quotient = Expr::div(Expr::var("x"), Expr::f32(0.0));
    let outcome = reduce_expr(&quotient);
    assert!(outcome.is_none(), "zero divisor must not fold");
}
/// `x * 3` must decompose into a chain whose outermost op is a subtraction
/// (presumably `(x << 2) - x`).
#[test]
fn shift_add_mul_by_3() {
    let product = Expr::mul(Expr::var("x"), Expr::u32(3));
    match reduce_expr(&product) {
        Some(r) => assert!(
            matches!(&r, Expr::BinOp { op: BinOp::Sub, .. }),
            "must be sub: {r:?}"
        ),
        None => panic!("x*3 must decompose"),
    }
}

/// `x * 5` must decompose into a chain whose outermost op is an addition
/// (presumably `(x << 2) + x`).
#[test]
fn shift_add_mul_by_5() {
    let product = Expr::mul(Expr::var("x"), Expr::u32(5));
    match reduce_expr(&product) {
        Some(r) => assert!(
            matches!(&r, Expr::BinOp { op: BinOp::Add, .. }),
            "must be add: {r:?}"
        ),
        None => panic!("x*5 must decompose"),
    }
}

/// `x * 7` must decompose into a chain whose outermost op is a subtraction
/// (presumably `(x << 3) - x`).
#[test]
fn shift_add_mul_by_7() {
    let product = Expr::mul(Expr::var("x"), Expr::u32(7));
    match reduce_expr(&product) {
        Some(r) => assert!(
            matches!(&r, Expr::BinOp { op: BinOp::Sub, .. }),
            "must be sub: {r:?}"
        ),
        None => panic!("x*7 must decompose"),
    }
}
/// `x * 9` (9 = 8 + 1) must decompose into a chain whose outermost op is an
/// addition, like the `x * 5` case.
#[test]
fn shift_add_mul_by_9() {
    let result = reduce_expr(&Expr::mul(Expr::var("x"), Expr::u32(9)));
    assert!(result.is_some(), "x*9 must decompose");
    // Previously only `is_some` was checked. Pin the shape like every sibling test
    // so a reduction that keeps a multiply (or picks a wrong chain) is caught.
    let r = result.unwrap();
    assert!(
        matches!(&r, Expr::BinOp { op: BinOp::Add, .. }),
        "must be add: {r:?}"
    );
}
/// `x * 15` must decompose into a chain whose outermost op is a subtraction
/// (presumably `(x << 4) - x`).
#[test]
fn shift_add_mul_by_15() {
    let product = Expr::mul(Expr::var("x"), Expr::u32(15));
    match reduce_expr(&product) {
        Some(r) => assert!(
            matches!(&r, Expr::BinOp { op: BinOp::Sub, .. }),
            "must be sub: {r:?}"
        ),
        None => panic!("x*15 must decompose"),
    }
}
/// The prime `11` has no two-term shift form. The bounded NAF chain must still
/// reduce it, with a subtraction as the outermost op.
#[test]
fn shift_add_decomposes_prime_11_with_naf() {
    let product = Expr::mul(Expr::var("x"), Expr::u32(11));
    match reduce_expr(&product) {
        Some(r) => assert!(
            matches!(&r, Expr::BinOp { op: BinOp::Sub, .. }),
            "must be a subtractive chain: {r:?}"
        ),
        None => panic!("x*11 must use the bounded NAF chain"),
    }
}
/// A shift/add chain repeats its operand. When the operand is a load plus an
/// add rather than a plain variable, the multiply must be left untouched.
#[test]
fn shift_add_skips_expensive_operands_to_avoid_duplication() {
    let loaded = Expr::load("input", Expr::gid_x());
    let expensive = Expr::add(loaded, Expr::u32(1));
    let product = Expr::mul(expensive, Expr::u32(11));
    assert!(
        reduce_expr(&product).is_none(),
        "bounded shift/add chains must not duplicate non-trivial operands"
    );
}
/// Integer multiplies by `0` and `1` fold. This holds whichever side the
/// constant is on.
#[test]
fn integer_mul_zero_and_one_fold() {
    let x = || Expr::var("x");
    // `x * 0` folds to the literal zero.
    let by_zero = Expr::mul(x(), Expr::u32(0));
    assert_eq!(reduce_expr(&by_zero), Some(Expr::u32(0)));
    // `1 * x` folds to the operand, with the constant on the left.
    let by_one = Expr::mul(Expr::u32(1), x());
    assert_eq!(reduce_expr(&by_one), Some(x()));
}
/// The shift/add decomposer must reject a float constant factor.
#[test]
fn shift_add_does_not_fire_for_floats() {
    let (operand, factor) = (Expr::var("x"), Expr::f32(3.0));
    let decomposed = shift_add_decompose(&operand, &factor);
    assert!(decomposed.is_none());
}
/// `(3*x)*x + 5*x + 7` over `u32` must be refactored to Horner form
/// `(3*x + 5)*x + 7`.
#[test]
fn horner_rewrites_expanded_u32_quadratic() {
    let var = || Expr::var("x");
    let expanded = Expr::add(
        Expr::add(
            Expr::mul(Expr::mul(Expr::u32(3), var()), var()),
            Expr::mul(Expr::u32(5), var()),
        ),
        Expr::u32(7),
    );
    let horner = Expr::add(
        Expr::mul(
            Expr::add(Expr::mul(Expr::u32(3), var()), Expr::u32(5)),
            var(),
        ),
        Expr::u32(7),
    );
    let result = reduce_expr(&expanded).expect("Fix: u32 quadratic must rewrite to Horner form");
    assert_eq!(result, horner);
}
/// Term order must not matter, and missing coefficients count as `1`:
/// `9 + (x + x*x)` must become `(1*x + 1)*x + 9`.
#[test]
fn horner_accepts_commuted_terms_and_implicit_coefficients() {
    let x = || Expr::var("x");
    let square = Expr::mul(x(), x());
    let expanded = Expr::add(Expr::u32(9), Expr::add(x(), square));
    let expected = Expr::add(
        Expr::mul(
            Expr::add(Expr::mul(Expr::u32(1), x()), Expr::u32(1)),
            x(),
        ),
        Expr::u32(9),
    );
    let result = reduce_expr(&expanded).expect("Fix: x*x + x + c must rewrite");
    assert_eq!(result, expected);
}
/// The `u32` Horner matcher must refuse a quadratic with `f32` coefficients.
/// Reassociating float arithmetic changes rounding.
#[test]
fn horner_rejects_float_quadratic_to_preserve_rounding_contract() {
    let x = Expr::var("x");
    let linear = Expr::mul(Expr::f32(5.0), x.clone());
    let quadratic = Expr::mul(Expr::mul(Expr::f32(3.0), x.clone()), x);
    let expanded = Expr::add(Expr::add(quadratic, linear), Expr::f32(7.0));
    let rewritten = horner_quadratic_u32(&expanded);
    assert!(
        rewritten.is_none(),
        "float polynomial reassociation changes rounding and must stay untouched"
    );
}
/// A left shift by zero must collapse to the operand.
#[test]
fn shift_by_zero_is_identity() {
    let shifted = Expr::shl(Expr::var("x"), Expr::u32(0));
    let expected = Some(Expr::var("x"));
    assert_eq!(reduce_expr(&shifted), expected);
}

/// A right shift by zero must collapse to the operand.
#[test]
fn shr_by_zero_is_identity() {
    let shifted = Expr::shr(Expr::var("x"), Expr::u32(0));
    let expected = Some(Expr::var("x"));
    assert_eq!(reduce_expr(&shifted), expected);
}
/// `(x << 3) << 4` must fuse into the single shift `x << 7`.
#[test]
fn chained_shl_fuses() {
    let nested = Expr::shl(Expr::shl(Expr::var("x"), Expr::u32(3)), Expr::u32(4));
    let result = reduce_expr(&nested);
    assert!(result.is_some());
    let r = result.unwrap();
    let fused = matches!(
        &r,
        Expr::BinOp {
            op: BinOp::Shl,
            right,
            ..
        } if matches!(right.as_ref(), Expr::LitU32(7))
    );
    assert!(fused, "must fuse to x<<7: {r:?}");
}

/// `(x >> 2) >> 5` must fuse into the single shift `x >> 7`.
#[test]
fn chained_shr_fuses() {
    let nested = Expr::shr(Expr::shr(Expr::var("x"), Expr::u32(2)), Expr::u32(5));
    let result = reduce_expr(&nested);
    assert!(result.is_some());
    let r = result.unwrap();
    let fused = matches!(
        &r,
        Expr::BinOp {
            op: BinOp::Shr,
            right,
            ..
        } if matches!(right.as_ref(), Expr::LitU32(7))
    );
    assert!(fused, "must fuse to x>>7: {r:?}");
}
/// Opposite-direction shifts such as `(x << 3) >> 4` must not be merged into
/// a single shift.
#[test]
fn mixed_shift_does_not_fuse() {
    let left_then_right = Expr::shr(Expr::shl(Expr::var("x"), Expr::u32(3)), Expr::u32(4));
    let outcome = reduce_expr(&left_then_right);
    assert!(outcome.is_none(), "mixed-direction shifts must not fuse");
}
/// `x + (-y)` must become `x - y`.
#[test]
fn add_neg_becomes_sub() {
    let sum = Expr::add(Expr::var("x"), Expr::negate(Expr::var("y")));
    let expected = Some(Expr::sub(Expr::var("x"), Expr::var("y")));
    assert_eq!(reduce_expr(&sum), expected);
}

/// `(-x) + y` must become `y - x`.
#[test]
fn neg_add_becomes_sub() {
    let sum = Expr::add(Expr::negate(Expr::var("x")), Expr::var("y"));
    let expected = Some(Expr::sub(Expr::var("y"), Expr::var("x")));
    assert_eq!(reduce_expr(&sum), expected);
}

/// `x - (-y)` must become `x + y`.
#[test]
fn sub_neg_becomes_add() {
    let difference = Expr::sub(Expr::var("x"), Expr::negate(Expr::var("y")));
    let expected = Some(Expr::add(Expr::var("x"), Expr::var("y")));
    assert_eq!(reduce_expr(&difference), expected);
}
/// `c - a * b` on floats must be synthesized into `fma(-a, b, c)`.
///
/// NOTE(review): an FMA rounds once where the separate mul/sub rounds twice, so
/// results can differ in the last bit. Confirm this matches the pass's float
/// contract.
#[test]
fn reverse_float_mul_sub_becomes_fma() {
    let result = reduce_expr(&Expr::sub(
        Expr::f32(1.0),
        Expr::mul(Expr::var("a"), Expr::var("b")),
    ));
    let reduced = result.expect("Fix: reverse multiply-subtract must synthesize FMA");
    // Check the operand under the negation as well. Without this, fma(-b, b, c)
    // or fma(-1.0, b, c) would also pass.
    assert!(
        matches!(
            &reduced,
            Expr::Fma { a, b, c }
                if matches!(
                    a.as_ref(),
                    Expr::UnOp { op: UnOp::Negate, operand }
                        if matches!(operand.as_ref(), Expr::Var(name) if name == "a")
                )
                && matches!(b.as_ref(), Expr::Var(name) if name == "b")
                && matches!(c.as_ref(), Expr::LitF32(v) if *v == 1.0)
        ),
        "expected fma(-a, b, c), got {reduced:?}"
    );
}
/// No bit is set in both `x` and `!x`, so `x & !x` must fold to `0`.
#[test]
fn bitand_complement_is_zero() {
    let x = Expr::var("x");
    let complement = Expr::UnOp {
        op: UnOp::BitNot,
        operand: Box::new(x.clone()),
    };
    let conjunction = Expr::bitand(x, complement);
    assert_eq!(reduce_expr(&conjunction), Some(Expr::u32(0)));
}

/// Operand order must not matter: `!x & x` must also fold to `0`.
#[test]
fn bitand_complement_reversed() {
    let x = Expr::var("x");
    let complement = Expr::UnOp {
        op: UnOp::BitNot,
        operand: Box::new(x.clone()),
    };
    let conjunction = Expr::bitand(complement, x);
    assert_eq!(reduce_expr(&conjunction), Some(Expr::u32(0)));
}
/// Every bit is set in `x` or `!x`, so `x | !x` must fold to `u32::MAX`.
#[test]
fn bitor_complement_is_all_ones() {
    let x = Expr::var("x");
    let complement = Expr::UnOp {
        op: UnOp::BitNot,
        operand: Box::new(x.clone()),
    };
    let disjunction = Expr::bitor(x, complement);
    assert_eq!(reduce_expr(&disjunction), Some(Expr::u32(u32::MAX)));
}

/// Every bit differs between `x` and `!x`, so `x ^ !x` must fold to `u32::MAX`.
#[test]
fn bitxor_complement_is_all_ones() {
    let x = Expr::var("x");
    let complement = Expr::UnOp {
        op: UnOp::BitNot,
        operand: Box::new(x.clone()),
    };
    let exclusive = Expr::bitxor(x, complement);
    assert_eq!(reduce_expr(&exclusive), Some(Expr::u32(u32::MAX)));
}
/// Rotating left by `0` must collapse to the operand.
#[test]
fn rotate_left_zero_is_identity() {
    let value = Expr::var("x");
    let amount = Expr::u32(0);
    let rotated = Expr::BinOp {
        op: BinOp::RotateLeft,
        left: Box::new(value.clone()),
        right: Box::new(amount),
    };
    assert_eq!(reduce_expr(&rotated), Some(value));
}

/// Rotating right by `0` must collapse to the operand.
#[test]
fn rotate_right_zero_is_identity() {
    let value = Expr::var("x");
    let amount = Expr::u32(0);
    let rotated = Expr::BinOp {
        op: BinOp::RotateRight,
        left: Box::new(value.clone()),
        right: Box::new(amount),
    };
    assert_eq!(reduce_expr(&rotated), Some(value));
}

/// Rotating left by `32` (a full turn for a 32-bit value) must collapse to the operand.
#[test]
fn rotate_left_32_is_identity() {
    let value = Expr::var("x");
    let amount = Expr::u32(32);
    let rotated = Expr::BinOp {
        op: BinOp::RotateLeft,
        left: Box::new(value.clone()),
        right: Box::new(amount),
    };
    assert_eq!(reduce_expr(&rotated), Some(value));
}
/// The absolute difference of a value with itself must fold to `0`.
#[test]
fn absdiff_self_is_zero() {
    let operand = Expr::var("x");
    let distance = Expr::BinOp {
        op: BinOp::AbsDiff,
        left: Box::new(operand.clone()),
        right: Box::new(operand),
    };
    assert_eq!(reduce_expr(&distance), Some(Expr::u32(0)));
}
/// Nothing unsigned is below `0`, so `min(x, 0)` must fold to `0`.
#[test]
fn min_zero_unsigned_is_zero() {
    let clamp = Expr::BinOp {
        op: BinOp::Min,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(0)),
    };
    assert_eq!(reduce_expr(&clamp), Some(Expr::u32(0)));
}

/// Nothing unsigned is below `0`, so `max(x, 0)` must fold to `x`.
#[test]
fn max_zero_unsigned_is_x() {
    let operand = Expr::var("x");
    let clamp = Expr::BinOp {
        op: BinOp::Max,
        left: Box::new(operand.clone()),
        right: Box::new(Expr::u32(0)),
    };
    assert_eq!(reduce_expr(&clamp), Some(operand));
}

/// Nothing `u32` exceeds `u32::MAX`, so `min(x, MAX)` must fold to `x`.
#[test]
fn min_max_unsigned_is_x() {
    let operand = Expr::var("x");
    let clamp = Expr::BinOp {
        op: BinOp::Min,
        left: Box::new(operand.clone()),
        right: Box::new(Expr::u32(u32::MAX)),
    };
    assert_eq!(reduce_expr(&clamp), Some(operand));
}

/// Nothing `u32` exceeds `u32::MAX`, so `max(x, MAX)` must fold to `MAX`.
#[test]
fn max_max_unsigned_is_max() {
    let clamp = Expr::BinOp {
        op: BinOp::Max,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(u32::MAX)),
    };
    assert_eq!(reduce_expr(&clamp), Some(Expr::u32(u32::MAX)));
}
/// An unsigned value is never below `0`: `x < 0` must fold to `false`.
#[test]
fn lt_zero_unsigned_is_false() {
    let comparison = Expr::BinOp {
        op: BinOp::Lt,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(0)),
    };
    assert_eq!(reduce_expr(&comparison), Some(Expr::bool(false)));
}

/// An unsigned value is always at least `0`: `x >= 0` must fold to `true`.
#[test]
fn ge_zero_unsigned_is_true() {
    let comparison = Expr::BinOp {
        op: BinOp::Ge,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(0)),
    };
    assert_eq!(reduce_expr(&comparison), Some(Expr::bool(true)));
}

/// With the zero on the left, `0 > x` must fold to `false`.
#[test]
fn zero_gt_x_unsigned_is_false() {
    let comparison = Expr::BinOp {
        op: BinOp::Gt,
        left: Box::new(Expr::u32(0)),
        right: Box::new(Expr::var("x")),
    };
    assert_eq!(reduce_expr(&comparison), Some(Expr::bool(false)));
}

/// With the zero on the left, `0 <= x` must fold to `true`.
#[test]
fn zero_le_x_unsigned_is_true() {
    let comparison = Expr::BinOp {
        op: BinOp::Le,
        left: Box::new(Expr::u32(0)),
        right: Box::new(Expr::var("x")),
    };
    assert_eq!(reduce_expr(&comparison), Some(Expr::bool(true)));
}
/// Two stacked negations must cancel back to the operand.
#[test]
fn negate_negate_cancels_to_identity() {
    let x = Expr::var("x");
    let once = Expr::UnOp {
        op: UnOp::Negate,
        operand: Box::new(x.clone()),
    };
    let twice = Expr::UnOp {
        op: UnOp::Negate,
        operand: Box::new(once),
    };
    assert_eq!(
        reduce_expr_extra(&twice),
        Some(x),
        "Negate(Negate(x)) must reduce to x; without this every double-negation chain pays \
         two extra ops at runtime"
    );
}

/// Two stacked bitwise complements must cancel back to the operand.
#[test]
fn bitnot_bitnot_cancels_to_identity() {
    let x = Expr::var("x");
    let once = Expr::UnOp {
        op: UnOp::BitNot,
        operand: Box::new(x.clone()),
    };
    let twice = Expr::UnOp {
        op: UnOp::BitNot,
        operand: Box::new(once),
    };
    assert_eq!(
        reduce_expr_extra(&twice),
        Some(x),
        "BitNot(BitNot(x)) must reduce to x"
    );
}

/// Bit reversal is its own inverse, so applying it twice must return the operand.
#[test]
fn reverse_bits_self_inverse() {
    let x = Expr::var("x");
    let once = Expr::UnOp {
        op: UnOp::ReverseBits,
        operand: Box::new(x.clone()),
    };
    let twice = Expr::UnOp {
        op: UnOp::ReverseBits,
        operand: Box::new(once),
    };
    assert_eq!(
        reduce_expr_extra(&twice),
        Some(x),
        "ReverseBits(ReverseBits(x)) must reduce to x"
    );
}
/// A select whose arms are identical must collapse to that arm, whatever the condition.
#[test]
fn select_with_identical_arms_collapses_to_arm() {
    let arm = Expr::var("x");
    let choice = Expr::Select {
        cond: Box::new(Expr::var("c")),
        true_val: Box::new(arm.clone()),
        false_val: Box::new(arm.clone()),
    };
    assert_eq!(
        reduce_expr_extra(&choice),
        Some(arm),
        "select(c, x, x) must collapse to x — the condition is dead. Without this, \
         post-CSE merges that collapse both arms still pay the branch."
    );
}

/// A select on the literal `true` must collapse to its true arm.
#[test]
fn select_with_constant_true_collapses_to_true_arm() {
    let (taken, skipped) = (Expr::u32(42), Expr::u32(99));
    let choice = Expr::Select {
        cond: Box::new(Expr::bool(true)),
        true_val: Box::new(taken.clone()),
        false_val: Box::new(skipped),
    };
    assert_eq!(
        reduce_expr_extra(&choice),
        Some(taken),
        "select(true, a, b) must collapse to a"
    );
}

/// A select on the literal `false` must collapse to its false arm.
#[test]
fn select_with_constant_false_collapses_to_false_arm() {
    let (skipped, taken) = (Expr::u32(42), Expr::u32(99));
    let choice = Expr::Select {
        cond: Box::new(Expr::bool(false)),
        true_val: Box::new(skipped),
        false_val: Box::new(taken.clone()),
    };
    assert_eq!(
        reduce_expr_extra(&choice),
        Some(taken),
        "select(false, a, b) must collapse to b"
    );
}
/// A lone negation must be left alone by the extra peepholes.
#[test]
fn negate_single_does_not_collapse() {
    let negated = Expr::UnOp {
        op: UnOp::Negate,
        operand: Box::new(Expr::var("x")),
    };
    assert_eq!(
        reduce_expr_extra(&negated),
        None,
        "Negate(x) on its own must not be rewritten — the peephole only fires on \
         Negate(Negate(x))"
    );
}

/// A select with a runtime condition and distinct arms must be left alone.
#[test]
fn select_with_distinct_arms_does_not_collapse() {
    let choice = Expr::Select {
        cond: Box::new(Expr::var("c")),
        true_val: Box::new(Expr::var("x")),
        false_val: Box::new(Expr::var("y")),
    };
    assert_eq!(
        reduce_expr_extra(&choice),
        None,
        "select(c, x, y) with distinct arms must NOT collapse — without this contract a \
         legitimate branch would be silently rewritten away"
    );
}
/// `1.0 / 4.0` has a constant result and must fold straight to the literal `0.25`.
#[test]
fn div_one_by_constant_folds_to_reciprocal_literal() {
    let quotient = Expr::div(Expr::f32(1.0), Expr::f32(4.0));
    let folded = reduce_expr(&quotient);
    assert_eq!(
        folded,
        Some(Expr::f32(0.25)),
        "Div(1.0, 4.0) must fold to LitF32(0.25)"
    );
}
/// `1.0 / 0.0` must not be folded. A zero divisor has no finite reciprocal,
/// matching the rule that `x / 0.0` is never strength-reduced.
#[test]
fn div_one_by_zero_does_not_fold() {
    let result = reduce_expr(&Expr::div(Expr::f32(1.0), Expr::f32(0.0)));
    // The old message called this "a trap". Float division by zero does not trap,
    // so the diagnostic now states the actual rule instead.
    assert!(
        result.is_none(),
        "Div(1.0, 0.0) must NOT fold — a zero divisor has no finite reciprocal, so \
         zero divisors are never rewritten"
    );
}