use super::super::*;
use super::helpers::{program_contains_literal, simple_program};
use crate::ir::{DataType, Expr, UnOp};
/// `select(!c, a, b)` must fold to `select(c, b, a)`: the logical-not on the
/// condition is dropped and the two arms swap places.
#[test]
fn select_not_cond_flips() {
    let cond = Expr::var("c");
    let first_arm = Expr::var("a");
    let second_arm = Expr::var("b");

    let negated = Expr::UnOp {
        op: UnOp::LogicalNot,
        operand: Box::new(cond.clone()),
    };
    let input = Expr::select(negated, first_arm.clone(), second_arm.clone());

    let folded = fold_expr(&input);
    assert_eq!(folded, Some(Expr::select(cond, second_arm, first_arm)));
}
/// `select(c, 1u32, 0u32)` must fold to a `U32` cast of the condition itself.
#[test]
fn select_one_zero_becomes_cast() {
    let cond = Expr::var("c");
    let input = Expr::select(cond.clone(), Expr::u32(1), Expr::u32(0));

    let folded = fold_expr(&input);
    let expected = Expr::Cast {
        target: DataType::U32,
        value: Box::new(cond),
    };
    assert_eq!(folded, Some(expected));
}
/// `fma(1.0, b, c)` must fold to `b + c`.
#[test]
fn fma_one_b_c() {
    let factor = Expr::var("b");
    let addend = Expr::var("c");
    let input = Expr::fma(Expr::f32(1.0), factor.clone(), addend.clone());

    let folded = fold_expr(&input);
    assert_eq!(folded, Some(Expr::add(factor, addend)));
}
/// `fma(a, 1.0, c)` must fold to `a + c`.
#[test]
fn fma_a_one_c() {
    let factor = Expr::var("a");
    let addend = Expr::var("c");
    let input = Expr::fma(factor.clone(), Expr::f32(1.0), addend.clone());

    let folded = fold_expr(&input);
    assert_eq!(folded, Some(Expr::add(factor, addend)));
}
/// `fma(0.0, 2.0, c)` (zero first factor, literal second factor) must fold to
/// the addend `c` alone.
#[test]
fn fma_zero_b_c() {
    let addend = Expr::var("c");
    let input = Expr::fma(Expr::f32(0.0), Expr::f32(2.0), addend.clone());

    let folded = fold_expr(&input);
    assert_eq!(folded, Some(addend));
}
/// `fma(2.0, 0.0, c)` (literal first factor, zero second factor) must fold to
/// the addend `c` alone.
#[test]
fn fma_a_zero_c() {
    let addend = Expr::var("c");
    let input = Expr::fma(Expr::f32(2.0), Expr::f32(0.0), addend.clone());

    let folded = fold_expr(&input);
    assert_eq!(folded, Some(addend));
}
/// `fma(a, b, 0.0)` must NOT be rewritten: `fold_expr` is expected to return
/// `None` for it.
///
/// NOTE(review): presumably the fold is withheld because dropping a `+ 0.0`
/// addend is not value-preserving for every float input (e.g. signed zero) —
/// confirm against the `fold_expr` implementation.
#[test]
fn fma_a_b_zero() {
    // The operands are used exactly once, so they are built in place rather
    // than bound to locals and cloned.
    let input = Expr::fma(Expr::var("a"), Expr::var("b"), Expr::f32(0.0));
    assert_eq!(fold_expr(&input), None);
}
/// Multiplying by minus one must fold to a negation, for both an `i32` literal
/// on the right and an `f32` literal on the left.
#[test]
fn mul_by_neg_one_folds_to_negate() {
    let operand = Expr::var("x");

    // x * -1 (integer literal, right-hand side)
    let int_on_right = Expr::mul(operand.clone(), Expr::i32(-1));
    assert_eq!(
        fold_expr(&int_on_right),
        Some(Expr::negate(operand.clone()))
    );

    // -1.0 * x (float literal, left-hand side)
    let float_on_left = Expr::mul(Expr::f32(-1.0), operand.clone());
    assert_eq!(fold_expr(&float_on_left), Some(Expr::negate(operand)));
}
/// A `U32` cast wrapped around an `I32` cast of `x` must fold to a single
/// `U32` cast of `x`.
#[test]
fn double_cast_elimination() {
    // Small local constructor so the three casts below read uniformly.
    let cast_to = |target: DataType, value: Expr| Expr::Cast {
        target,
        value: Box::new(value),
    };

    let source = Expr::var("x");
    let to_i32 = cast_to(DataType::I32, source.clone());
    let to_u32_of_i32 = cast_to(DataType::U32, to_i32);

    let folded = fold_expr(&to_u32_of_i32);
    assert_eq!(folded, Some(cast_to(DataType::U32, source)));
}
/// `select(c, select(c, a, b), d)` must fold to `select(c, a, d)`: the inner
/// select in the true arm shares the outer condition, so only its true arm
/// survives.
#[test]
fn select_of_select_fusion() {
    let shared_cond = Expr::var("c");
    let inner_then = Expr::var("a");
    let inner_else = Expr::var("b");
    let outer_else = Expr::var("d");

    let nested = Expr::select(shared_cond.clone(), inner_then.clone(), inner_else.clone());
    let input = Expr::select(shared_cond.clone(), nested, outer_else.clone());

    let folded = fold_expr(&input);
    assert_eq!(
        folded,
        Some(Expr::select(shared_cond, inner_then, outer_else))
    );
}
/// `select(c, 0u32, 1u32)` must fold to a `U32` cast of `!c`.
#[test]
fn select_bool_canonicalization() {
    let cond = Expr::var("c");
    let input = Expr::select(cond.clone(), Expr::u32(0), Expr::u32(1));

    let folded = fold_expr(&input);

    let negated = Expr::UnOp {
        op: UnOp::LogicalNot,
        operand: Box::new(cond),
    };
    let expected = Expr::Cast {
        target: DataType::U32,
        value: Box::new(negated),
    };
    assert_eq!(folded, Some(expected));
}
/// `1.0 / sqrt(x)` must fold to the single `InverseSqrt` unary op on `x`.
#[test]
fn reciprocal_sqrt_fusion() {
    let radicand = Expr::var("x");
    let root = Expr::UnOp {
        op: UnOp::Sqrt,
        operand: Box::new(radicand.clone()),
    };
    let input = Expr::div(Expr::f32(1.0), root);

    let folded = fold_expr(&input);
    let expected = Expr::UnOp {
        op: UnOp::InverseSqrt,
        operand: Box::new(radicand),
    };
    assert_eq!(folded, Some(expected));
}
/// `sin(x) / cos(x)` must fold to `tan(x)`.
#[test]
fn trig_division_peephole() {
    // Local constructor for the three unary ops used below.
    let unary = |op: UnOp, arg: Expr| Expr::UnOp {
        op,
        operand: Box::new(arg),
    };

    let angle = Expr::var("x");
    let numerator = unary(UnOp::Sin, angle.clone());
    let denominator = unary(UnOp::Cos, angle.clone());
    let input = Expr::div(numerator, denominator);

    let folded = fold_expr(&input);
    assert_eq!(folded, Some(unary(UnOp::Tan, angle)));
}
/// Dividing an expression by itself:
///
/// * a plain variable, `x / x`, must fold to the `u32` literal `1`;
/// * an `fma(y, z, w)` expression divided by itself must be left alone
///   (`fold_expr` returns `None`).
///
/// NOTE(review): presumably the fold only cancels operands it considers
/// trivially safe to deduplicate — confirm against `fold_expr`.
#[test]
fn div_self_identity() {
    // Case 1: variable over itself collapses to 1.
    let x = Expr::var("x");
    let var_over_self = Expr::div(x.clone(), x);
    assert_eq!(
        fold_expr(&var_over_self),
        Some(Expr::u32(1)),
        "x / x must fold to the literal 1"
    );

    // Case 2: a compound (fma) operand over itself is not folded.
    let f = Expr::fma(Expr::var("y"), Expr::var("z"), Expr::var("w"));
    let fma_over_self = Expr::div(f.clone(), f);
    assert_eq!(
        fold_expr(&fma_over_self),
        None,
        "fma(y, z, w) / fma(y, z, w) must not be folded"
    );
}
/// Two literal addends around a variable must be merged into one:
/// both `(x + 3) + 4` and `(3 + x) + 4` fold to `x + 7`.
#[test]
fn algebraic_reassociation() {
    let var = Expr::var("x");

    // (x + 3) + 4
    let var_first = Expr::add(Expr::add(var.clone(), Expr::u32(3)), Expr::u32(4));
    assert_eq!(
        fold_expr(&var_first),
        Some(Expr::add(var.clone(), Expr::u32(7)))
    );

    // (3 + x) + 4
    let literal_first = Expr::add(Expr::add(Expr::u32(3), var.clone()), Expr::u32(4));
    assert_eq!(
        fold_expr(&literal_first),
        Some(Expr::add(var, Expr::u32(7)))
    );
}
/// `x * 3 + x * 4` must fold to `x * 7`.
#[test]
fn distributive_law_add() {
    let common = Expr::var("x");
    let times_three = Expr::mul(common.clone(), Expr::u32(3));
    let times_four = Expr::mul(common.clone(), Expr::u32(4));
    let sum = Expr::add(times_three, times_four);

    let folded = fold_expr(&sum);
    assert_eq!(folded, Some(Expr::mul(common, Expr::u32(7))));
}
/// `x * 5 - x * 2` must fold to `x * 3`.
#[test]
fn distributive_law_sub() {
    let common = Expr::var("x");
    let times_five = Expr::mul(common.clone(), Expr::u32(5));
    let times_two = Expr::mul(common.clone(), Expr::u32(2));
    let difference = Expr::sub(times_five, times_two);

    let folded = fold_expr(&difference);
    assert_eq!(folded, Some(Expr::mul(common, Expr::u32(3))));
}
/// Running `ConstFold` over `Popcount(0xF0F0_F0F0)` must leave the literal 16
/// in the resulting program (four `0xF` nibbles, four set bits each).
#[test]
fn popcount_of_nonzero_literal_folds_to_count() {
    let literal = Box::new(Expr::u32(0xF0F0_F0F0));
    let popcount = Expr::UnOp {
        op: UnOp::Popcount,
        operand: literal,
    };

    let result = ConstFold::transform(simple_program(popcount));
    let has_count = program_contains_literal(&result.program, 16);
    assert!(
        has_count,
        "popcount(0xF0F0_F0F0) must fold to literal 16; got program: {:?}",
        result.program
    );
}
/// Running `ConstFold` over `Clz(0x0000_0008)` must leave the literal 28 in
/// the resulting program (bit 3 is the highest set bit of a 32-bit value, so
/// 28 zero bits precede it).
///
/// The failure message dumps the folded program, matching the popcount test,
/// so a regression shows what the pass actually produced.
#[test]
fn clz_of_nonzero_literal_folds_to_count() {
    let expr = Expr::UnOp {
        op: UnOp::Clz,
        operand: Box::new(Expr::u32(0x0000_0008)),
    };
    let folded = ConstFold::transform(simple_program(expr));
    assert!(
        program_contains_literal(&folded.program, 28),
        "clz(0x00000008) must fold to literal 28; got program: {:?}",
        folded.program
    );
}
/// Running `ConstFold` over `Ctz(0x0000_0008)` must leave the literal 3 in the
/// resulting program (the lowest set bit of 8 is bit 3).
///
/// The failure message dumps the folded program, matching the popcount test,
/// so a regression shows what the pass actually produced.
#[test]
fn ctz_of_nonzero_literal_folds_to_count() {
    let expr = Expr::UnOp {
        op: UnOp::Ctz,
        operand: Box::new(Expr::u32(0x0000_0008)),
    };
    let folded = ConstFold::transform(simple_program(expr));
    assert!(
        program_contains_literal(&folded.program, 3),
        "ctz(0x00000008) must fold to literal 3; got program: {:?}",
        folded.program
    );
}
/// Running `ConstFold` over `ReverseBits(1)` must leave the literal
/// `0x8000_0000` in the resulting program (bit 0 mirrored to bit 31).
///
/// The failure message dumps the folded program, matching the popcount test,
/// so a regression shows what the pass actually produced.
#[test]
fn reverse_bits_of_literal_folds_to_reversed() {
    let expr = Expr::UnOp {
        op: UnOp::ReverseBits,
        operand: Box::new(Expr::u32(1)),
    };
    let folded = ConstFold::transform(simple_program(expr));
    assert!(
        program_contains_literal(&folded.program, 0x8000_0000),
        "ReverseBits(1) must fold to 0x80000000; got program: {:?}",
        folded.program
    );
}
/// Running `ConstFold` over `BitNot(0)` must leave the literal `u32::MAX` in
/// the resulting program (the complement of all-zero bits is all-one bits).
///
/// The failure message dumps the folded program, matching the popcount test,
/// so a regression shows what the pass actually produced.
#[test]
fn bitnot_of_literal_folds_to_complement() {
    let expr = Expr::UnOp {
        op: UnOp::BitNot,
        operand: Box::new(Expr::u32(0)),
    };
    let folded = ConstFold::transform(simple_program(expr));
    assert!(
        program_contains_literal(&folded.program, u32::MAX),
        "BitNot(0) must fold to u32::MAX; got program: {:?}",
        folded.program
    );
}