use super::super::*;
use crate::ir::Expr;
/// A variable compared for equality with itself is expected to fold to `true`.
#[test]
fn eq_self_is_true() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::Eq,
        left: Box::new(operand.clone()),
        right: Box::new(operand),
    };
    assert_eq!(fold_expr(&input), Some(Expr::bool(true)));
}
/// A variable compared for inequality with itself is expected to fold to `false`.
#[test]
fn ne_self_is_false() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::Ne,
        left: Box::new(operand.clone()),
        right: Box::new(operand),
    };
    assert_eq!(fold_expr(&input), Some(Expr::bool(false)));
}
/// `x < x` is expected to fold to `false`.
#[test]
fn lt_self_is_false() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::Lt,
        left: Box::new(operand.clone()),
        right: Box::new(operand),
    };
    assert_eq!(fold_expr(&input), Some(Expr::bool(false)));
}
/// `x > x` is expected to fold to `false`.
#[test]
fn gt_self_is_false() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::Gt,
        left: Box::new(operand.clone()),
        right: Box::new(operand),
    };
    assert_eq!(fold_expr(&input), Some(Expr::bool(false)));
}
/// `x <= x` is expected to fold to `true`.
#[test]
fn le_self_is_true() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::Le,
        left: Box::new(operand.clone()),
        right: Box::new(operand),
    };
    assert_eq!(fold_expr(&input), Some(Expr::bool(true)));
}
/// `x >= x` is expected to fold to `true`.
#[test]
fn ge_self_is_true() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::Ge,
        left: Box::new(operand.clone()),
        right: Box::new(operand),
    };
    assert_eq!(fold_expr(&input), Some(Expr::bool(true)));
}
/// `x % 1` is expected to fold to the `u32` literal `0`.
#[test]
fn mod_one_is_zero() {
    use crate::ir::BinOp;
    let input = Expr::BinOp {
        op: BinOp::Mod,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(1)),
    };
    let expected = Expr::u32(0);
    assert_eq!(fold_expr(&input), Some(expected));
}
/// `x % x` is expected to fold to the `u32` literal `0`.
#[test]
fn mod_self_is_zero() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::Mod,
        left: Box::new(operand.clone()),
        right: Box::new(operand),
    };
    let expected = Expr::u32(0);
    assert_eq!(fold_expr(&input), Some(expected));
}
/// `min(x, x)` is expected to fold to `x`.
#[test]
fn min_self_is_self() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::Min,
        left: Box::new(operand.clone()),
        right: Box::new(operand.clone()),
    };
    assert_eq!(fold_expr(&input), Some(operand));
}
/// `max(x, x)` is expected to fold to `x`.
#[test]
fn max_self_is_self() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::Max,
        left: Box::new(operand.clone()),
        right: Box::new(operand.clone()),
    };
    assert_eq!(fold_expr(&input), Some(operand));
}
/// `x / x`, built with the `Expr::div` helper, is expected to fold to the `u32` literal `1`.
#[test]
fn div_self_is_one() {
    let operand = Expr::var("x");
    let quotient = Expr::div(operand.clone(), operand);
    let folded = fold_expr(&quotient);
    assert_eq!(folded, Some(Expr::u32(1)));
}
/// Wrapping-adding the `u32` literal `0` to `x` is expected to fold to `x`.
#[test]
fn wrapping_add_zero() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::WrappingAdd,
        left: Box::new(operand.clone()),
        right: Box::new(Expr::u32(0)),
    };
    assert_eq!(fold_expr(&input), Some(operand));
}
/// Wrapping-subtracting `x` from itself is expected to fold to the `u32` literal `0`.
#[test]
fn wrapping_sub_self() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::WrappingSub,
        left: Box::new(operand.clone()),
        right: Box::new(operand),
    };
    assert_eq!(fold_expr(&input), Some(Expr::u32(0)));
}
/// Saturating-adding the `u32` literal `0` to `x` is expected to fold to `x`.
#[test]
fn saturating_add_zero() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::SaturatingAdd,
        left: Box::new(operand.clone()),
        right: Box::new(Expr::u32(0)),
    };
    assert_eq!(fold_expr(&input), Some(operand));
}
/// Saturating-subtracting `x` from itself is expected to fold to the `u32` literal `0`.
#[test]
fn saturating_sub_self() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::SaturatingSub,
        left: Box::new(operand.clone()),
        right: Box::new(operand),
    };
    assert_eq!(fold_expr(&input), Some(Expr::u32(0)));
}
/// Saturating-multiplying `x` by the `u32` literal `1` is expected to fold to `x`.
#[test]
fn saturating_mul_one() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::SaturatingMul,
        left: Box::new(operand.clone()),
        right: Box::new(Expr::u32(1)),
    };
    assert_eq!(fold_expr(&input), Some(operand));
}
/// Saturating-multiplying `x` by the `u32` literal `0` is expected to fold to `0`.
#[test]
fn saturating_mul_zero() {
    use crate::ir::BinOp;
    let input = Expr::BinOp {
        op: BinOp::SaturatingMul,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(0)),
    };
    let expected = Expr::u32(0);
    assert_eq!(fold_expr(&input), Some(expected));
}
/// `true && x` is expected to fold to `x` (`true` is the identity of logical AND).
#[test]
fn and_true_id() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::And,
        left: Box::new(Expr::bool(true)),
        right: Box::new(operand.clone()),
    };
    assert_eq!(fold_expr(&input), Some(operand));
}
/// `false && x` is expected to fold to `false` (`false` annihilates logical AND).
#[test]
fn and_false_ann() {
    use crate::ir::BinOp;
    let input = Expr::BinOp {
        op: BinOp::And,
        left: Box::new(Expr::bool(false)),
        right: Box::new(Expr::var("x")),
    };
    let expected = Expr::bool(false);
    assert_eq!(fold_expr(&input), Some(expected));
}
/// `true || x` is expected to fold to `true` (`true` annihilates logical OR).
#[test]
fn or_true_ann() {
    use crate::ir::BinOp;
    let input = Expr::BinOp {
        op: BinOp::Or,
        left: Box::new(Expr::bool(true)),
        right: Box::new(Expr::var("x")),
    };
    let expected = Expr::bool(true);
    assert_eq!(fold_expr(&input), Some(expected));
}
/// `false || x` is expected to fold to `x` (`false` is the identity of logical OR).
#[test]
fn or_false_id() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let input = Expr::BinOp {
        op: BinOp::Or,
        left: Box::new(Expr::bool(false)),
        right: Box::new(operand.clone()),
    };
    assert_eq!(fold_expr(&input), Some(operand));
}
/// `x & u32::MAX` is expected to fold to `x` (all-ones is the identity of bitwise AND).
#[test]
fn bitand_all_ones() {
    let operand = Expr::var("x");
    let all_ones = Expr::u32(u32::MAX);
    let folded = fold_expr(&Expr::bitand(operand.clone(), all_ones));
    assert_eq!(folded, Some(operand));
}
/// `x | u32::MAX` is expected to fold to `u32::MAX` (all-ones annihilates bitwise OR).
#[test]
fn bitor_all_ones() {
    let all_ones = Expr::u32(u32::MAX);
    let folded = fold_expr(&Expr::bitor(Expr::var("x"), all_ones.clone()));
    assert_eq!(folded, Some(all_ones));
}
/// `c && !c` is a contradiction and is expected to fold to `false`.
#[test]
fn and_x_not_x_is_false_contradiction() {
    use crate::ir::{BinOp, UnOp};
    let cond = Expr::var("c");
    let negated = Expr::UnOp {
        op: UnOp::LogicalNot,
        operand: Box::new(cond.clone()),
    };
    let input = Expr::BinOp {
        op: BinOp::And,
        left: Box::new(cond),
        right: Box::new(negated),
    };
    assert_eq!(fold_expr(&input), Some(Expr::bool(false)));
}
/// `!c && c` (negation on the left) is expected to fold to `false` as well.
#[test]
fn and_not_x_x_is_false_contradiction_left_not() {
    use crate::ir::{BinOp, UnOp};
    let cond = Expr::var("c");
    let negated = Expr::UnOp {
        op: UnOp::LogicalNot,
        operand: Box::new(cond.clone()),
    };
    let input = Expr::BinOp {
        op: BinOp::And,
        left: Box::new(negated),
        right: Box::new(cond),
    };
    assert_eq!(fold_expr(&input), Some(Expr::bool(false)));
}
/// `c || !c` is a tautology and is expected to fold to `true`.
#[test]
fn or_x_not_x_is_true_tautology() {
    use crate::ir::{BinOp, UnOp};
    let cond = Expr::var("c");
    let negated = Expr::UnOp {
        op: UnOp::LogicalNot,
        operand: Box::new(cond.clone()),
    };
    let input = Expr::BinOp {
        op: BinOp::Or,
        left: Box::new(cond),
        right: Box::new(negated),
    };
    assert_eq!(fold_expr(&input), Some(Expr::bool(true)));
}
/// `!c || c` (negation on the left) is expected to fold to `true` as well.
#[test]
fn or_not_x_x_is_true_tautology_left_not() {
    use crate::ir::{BinOp, UnOp};
    let cond = Expr::var("c");
    let negated = Expr::UnOp {
        op: UnOp::LogicalNot,
        operand: Box::new(cond.clone()),
    };
    let input = Expr::BinOp {
        op: BinOp::Or,
        left: Box::new(negated),
        right: Box::new(cond),
    };
    assert_eq!(fold_expr(&input), Some(Expr::bool(true)));
}
/// Absorption law: `x && (x || y)` is expected to fold to `x`.
#[test]
fn absorption_and_over_or() {
    use crate::ir::BinOp;
    let lhs = Expr::var("x");
    let disjunction = Expr::BinOp {
        op: BinOp::Or,
        left: Box::new(lhs.clone()),
        right: Box::new(Expr::var("y")),
    };
    let input = Expr::BinOp {
        op: BinOp::And,
        left: Box::new(lhs.clone()),
        right: Box::new(disjunction),
    };
    assert_eq!(fold_expr(&input), Some(lhs));
}
/// Absorption law: `x || (x && y)` is expected to fold to `x`.
#[test]
fn absorption_or_over_and() {
    use crate::ir::BinOp;
    let lhs = Expr::var("x");
    let conjunction = Expr::BinOp {
        op: BinOp::And,
        left: Box::new(lhs.clone()),
        right: Box::new(Expr::var("y")),
    };
    let input = Expr::BinOp {
        op: BinOp::Or,
        left: Box::new(lhs.clone()),
        right: Box::new(conjunction),
    };
    assert_eq!(fold_expr(&input), Some(lhs));
}
/// Unlike a plain variable, a buffer load compared with itself must NOT be
/// folded: `fold_expr` is expected to return `None` for `Eq(Load, Load)`.
#[test]
fn reflexive_eq_on_load_does_not_fold() {
    use crate::ir::BinOp;
    let read = Expr::load("buf", Expr::u32(0));
    let input = Expr::BinOp {
        op: BinOp::Eq,
        left: Box::new(read.clone()),
        right: Box::new(read),
    };
    let folded = fold_expr(&input);
    assert_eq!(folded, None, "Eq(Load, Load) must not fold");
}
/// `min(x, u32::MAX)` and `min(u32::MAX, x)` are both expected to fold to `x`.
#[test]
fn min_with_u32_max_is_identity() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let ceiling = Expr::u32(u32::MAX);
    // Check the literal on the right first, then on the left.
    let orders = [
        (operand.clone(), ceiling.clone()),
        (ceiling, operand.clone()),
    ];
    for (lhs, rhs) in orders {
        let input = Expr::BinOp {
            op: BinOp::Min,
            left: Box::new(lhs),
            right: Box::new(rhs),
        };
        assert_eq!(fold_expr(&input), Some(operand.clone()));
    }
}
/// `max(x, 0)` and `max(0, x)` are both expected to fold to `x`.
#[test]
fn max_with_zero_is_identity() {
    use crate::ir::BinOp;
    let operand = Expr::var("x");
    let floor = Expr::u32(0);
    // Check the literal on the right first, then on the left.
    let orders = [(operand.clone(), floor.clone()), (floor, operand.clone())];
    for (lhs, rhs) in orders {
        let input = Expr::BinOp {
            op: BinOp::Max,
            left: Box::new(lhs),
            right: Box::new(rhs),
        };
        assert_eq!(fold_expr(&input), Some(operand.clone()));
    }
}
/// `min(x, 0)` is expected to fold to the `u32` literal `0`.
#[test]
fn min_with_zero_is_zero() {
    use crate::ir::BinOp;
    let input = Expr::BinOp {
        op: BinOp::Min,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(0)),
    };
    let expected = Expr::u32(0);
    assert_eq!(fold_expr(&input), Some(expected));
}
/// `max(x, u32::MAX)` is expected to fold to the literal `u32::MAX`.
#[test]
fn max_with_u32_max_is_u32_max() {
    use crate::ir::BinOp;
    let ceiling = Expr::u32(u32::MAX);
    let input = Expr::BinOp {
        op: BinOp::Max,
        left: Box::new(Expr::var("x")),
        right: Box::new(ceiling.clone()),
    };
    assert_eq!(fold_expr(&input), Some(ceiling));
}
/// `x < 0` against a `u32` zero literal is expected to fold to `false`.
#[test]
fn lt_zero_for_u32_is_false() {
    use crate::ir::BinOp;
    let input = Expr::BinOp {
        op: BinOp::Lt,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(0)),
    };
    let expected = Expr::bool(false);
    assert_eq!(fold_expr(&input), Some(expected));
}
/// `x >= 0` against a `u32` zero literal is expected to fold to `true`.
#[test]
fn ge_zero_for_u32_is_true() {
    use crate::ir::BinOp;
    let input = Expr::BinOp {
        op: BinOp::Ge,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(0)),
    };
    let expected = Expr::bool(true);
    assert_eq!(fold_expr(&input), Some(expected));
}
/// `x <= u32::MAX` is expected to fold to `true`.
#[test]
fn le_u32_max_is_true() {
    use crate::ir::BinOp;
    let input = Expr::BinOp {
        op: BinOp::Le,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(u32::MAX)),
    };
    let expected = Expr::bool(true);
    assert_eq!(fold_expr(&input), Some(expected));
}
/// `x > u32::MAX` is expected to fold to `false`.
#[test]
fn gt_u32_max_is_false() {
    use crate::ir::BinOp;
    let input = Expr::BinOp {
        op: BinOp::Gt,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(u32::MAX)),
    };
    let expected = Expr::bool(false);
    assert_eq!(fold_expr(&input), Some(expected));
}
/// `3 * (x + 7)` is expected to distribute into `3 * x + 3 * 7`.
#[test]
fn distributes_mul_lit_over_add_when_one_arm_is_literal() {
    use crate::ir::BinOp;
    let scalar = Expr::u32(3);
    let sum = Expr::BinOp {
        op: BinOp::Add,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(7)),
    };
    let product = Expr::BinOp {
        op: BinOp::Mul,
        left: Box::new(scalar.clone()),
        right: Box::new(sum),
    };
    let expected = Expr::add(
        Expr::mul(scalar.clone(), Expr::var("x")),
        Expr::mul(scalar, Expr::u32(7)),
    );
    assert_eq!(fold_expr(&product), Some(expected));
}
/// `(5 + y) * 4` (scalar on the right) is expected to distribute into `5 * 4 + y * 4`.
#[test]
fn distributes_add_lit_times_mul_lit_when_one_arm_is_literal() {
    use crate::ir::BinOp;
    let scalar = Expr::u32(4);
    let sum = Expr::BinOp {
        op: BinOp::Add,
        left: Box::new(Expr::u32(5)),
        right: Box::new(Expr::var("y")),
    };
    let product = Expr::BinOp {
        op: BinOp::Mul,
        left: Box::new(sum),
        right: Box::new(scalar.clone()),
    };
    let expected = Expr::add(
        Expr::mul(Expr::u32(5), scalar.clone()),
        Expr::mul(Expr::var("y"), scalar),
    );
    assert_eq!(fold_expr(&product), Some(expected));
}
/// Same distribution as the `u32` case, with `i32` literals: `3 * (x + 7)`
/// becomes `3 * x + 3 * 7`.
#[test]
fn distributes_mul_lit_i32_over_add_when_one_arm_is_literal() {
    use crate::ir::BinOp;
    let scalar = Expr::i32(3);
    let sum = Expr::BinOp {
        op: BinOp::Add,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::i32(7)),
    };
    let product = Expr::BinOp {
        op: BinOp::Mul,
        left: Box::new(scalar.clone()),
        right: Box::new(sum),
    };
    let expected = Expr::add(
        Expr::mul(scalar.clone(), Expr::var("x")),
        Expr::mul(scalar, Expr::i32(7)),
    );
    assert_eq!(fold_expr(&product), Some(expected));
}
/// `3 * (x + y)` is expected to be left alone because neither addend is a literal.
#[test]
fn does_not_distribute_when_neither_addend_is_literal() {
    use crate::ir::BinOp;
    let sum = Expr::BinOp {
        op: BinOp::Add,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::var("y")),
    };
    let product = Expr::BinOp {
        op: BinOp::Mul,
        left: Box::new(Expr::u32(3)),
        right: Box::new(sum),
    };
    assert_eq!(fold_expr(&product), None);
}
/// `c * (x + 7)` is expected to be left alone because the multiplier is not a literal.
#[test]
fn does_not_distribute_when_scalar_is_not_literal() {
    use crate::ir::BinOp;
    let sum = Expr::BinOp {
        op: BinOp::Add,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(7)),
    };
    let product = Expr::BinOp {
        op: BinOp::Mul,
        left: Box::new(Expr::var("c")),
        right: Box::new(sum),
    };
    assert_eq!(fold_expr(&product), None);
}
/// `3.0 * (x + 7.0)` with `f32` literals is expected NOT to be distributed.
#[test]
fn does_not_distribute_for_float_operands() {
    use crate::ir::BinOp;
    let sum = Expr::BinOp {
        op: BinOp::Add,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::f32(7.0)),
    };
    let product = Expr::BinOp {
        op: BinOp::Mul,
        left: Box::new(Expr::f32(3.0)),
        right: Box::new(sum),
    };
    assert_eq!(fold_expr(&product), None);
}
/// `3 * (x - 7)` is expected to distribute into `3 * x - 3 * 7`.
#[test]
fn distributes_mul_lit_over_sub_when_one_arm_is_literal() {
    use crate::ir::BinOp;
    let scalar = Expr::u32(3);
    let difference = Expr::BinOp {
        op: BinOp::Sub,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(7)),
    };
    let product = Expr::BinOp {
        op: BinOp::Mul,
        left: Box::new(scalar.clone()),
        right: Box::new(difference),
    };
    let expected = Expr::BinOp {
        op: BinOp::Sub,
        left: Box::new(Expr::mul(scalar.clone(), Expr::var("x"))),
        right: Box::new(Expr::mul(scalar, Expr::u32(7))),
    };
    assert_eq!(fold_expr(&product), Some(expected));
}
/// `(7 - x) * 3` (scalar on the right) is expected to distribute into `7 * 3 - x * 3`.
#[test]
fn distributes_sub_lit_times_mul_lit_when_one_arm_is_literal() {
    use crate::ir::BinOp;
    let scalar = Expr::u32(3);
    let difference = Expr::BinOp {
        op: BinOp::Sub,
        left: Box::new(Expr::u32(7)),
        right: Box::new(Expr::var("x")),
    };
    let product = Expr::BinOp {
        op: BinOp::Mul,
        left: Box::new(difference),
        right: Box::new(scalar.clone()),
    };
    let expected = Expr::BinOp {
        op: BinOp::Sub,
        left: Box::new(Expr::mul(Expr::u32(7), scalar.clone())),
        right: Box::new(Expr::mul(Expr::var("x"), scalar)),
    };
    assert_eq!(fold_expr(&product), Some(expected));
}
/// Builds a program whose entry binds `x = c`, binds `y = x % n`, and stores
/// `y` into `out[0]`, wraps it with `Program::wrapped(.., [1, 1, 1], ..)`,
/// and returns the result of running `ConstFold::transform` on it.
fn test_mod_program(c: u32, n: u32) -> crate::optimizer::PassResult {
    use crate::ir::{BinOp, BufferDecl, DataType, Node, Program};
    use crate::optimizer::passes::algebraic::const_fold::ConstFold;

    let remainder = Expr::BinOp {
        op: BinOp::Mod,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(n)),
    };
    let body = vec![
        Node::let_bind("x", Expr::u32(c)),
        Node::let_bind("y", remainder),
        Node::store("out", Expr::u32(0), Expr::var("y")),
    ];
    let buffers = vec![BufferDecl::output("out", 0, DataType::U32).with_count(1)];
    ConstFold::transform(Program::wrapped(buffers, [1, 1, 1], body))
}
/// Returns the value of the first `let y = ...` binding found in `nodes`,
/// descending into `Node::Region` bodies in order; `None` if there is none.
fn extract_let_y_value(nodes: &[crate::ir::Node]) -> Option<Expr> {
    nodes.iter().find_map(|node| match node {
        crate::ir::Node::Let { name, value } if name.as_str() == "y" => Some(value.clone()),
        // A region without a `y` binding yields `None`, so the search continues
        // with the next sibling node.
        crate::ir::Node::Region { body, .. } => extract_let_y_value(body),
        _ => None,
    })
}
/// With `x = 5` and modulus 10, ConstFold is expected to report a change and
/// rewrite `let y = x % 10` into `let y = x`.
#[test]
fn stronger_range_fold_mod_positive() {
    let pass_result = test_mod_program(5, 10);
    assert!(pass_result.changed);
    let y_value = extract_let_y_value(pass_result.program.entry());
    assert_eq!(y_value, Some(Expr::var("x")));
}
/// With `x = 15` and modulus 10, `x` is not below the modulus, so
/// `let y = x % 10` must not be rewritten into `let y = x`.
#[test]
fn stronger_range_fold_mod_negative_c_ge_n() {
    let result = test_mod_program(15, 10);
    // Require the binding to exist before comparing. Comparing the raw
    // `Option` would pass vacuously on `None`, i.e. when the lookup found no
    // `let y` at all.
    let y = extract_let_y_value(result.program.entry())
        .expect("`let y` binding should still be present after ConstFold");
    assert_ne!(y, Expr::var("x"));
}
/// When `x` is bound to a non-literal (`z + 1`), ConstFold is expected to leave
/// `let y = x % 10` untouched and report no change.
#[test]
fn stronger_range_fold_mod_negative_not_literal() {
    use crate::ir::{BufferDecl, DataType, Node, Program};
    use crate::optimizer::passes::algebraic::const_fold::ConstFold;
    let y_value = Expr::BinOp {
        op: crate::ir::BinOp::Mod,
        left: Box::new(Expr::var("x")),
        right: Box::new(Expr::u32(10)),
    };
    // Mirror `test_mod_program`, including the store of `y` into `out[0]`, so
    // the only difference from the positive case is the non-literal `x`.
    let entry = vec![
        Node::let_bind("x", Expr::add(Expr::var("z"), Expr::u32(1))),
        Node::let_bind("y", y_value.clone()),
        Node::store("out", Expr::u32(0), Expr::var("y")),
    ];
    let program = Program::wrapped(
        vec![BufferDecl::output("out", 0, DataType::U32).with_count(1)],
        [1, 1, 1],
        entry,
    );
    let result = ConstFold::transform(program);
    assert!(!result.changed);
    // Also confirm the binding is still present and still holds the original
    // `x % 10` expression.
    assert_eq!(extract_let_y_value(result.program.entry()), Some(y_value));
}