use crate::{
DType, Map,
kernel::{BOp, Kernel, Op, OpId},
shape::Dim,
};
impl Kernel {
/// Runs the algebraic clean-up passes over this kernel.
///
/// In order: `unfuse_mad`, shift round-trip removal, bitwise identity removal, then one
/// walk over the op list that hands every `Div`/`Mod` whose right operand is a constant
/// convertible through `as_dim` to `simplify_div` / `simplify_mod`. The pass ends with
/// `dead_code_elimination` followed by `verify`.
pub fn algebraic_simplification(&mut self) {
    #[cfg(feature = "time")]
    let _timer = crate::Timer::new("algebraic_simplification");

    self.unfuse_mad();
    self.simplify_shl_shr_roundtrips();
    self.simplify_bitwise_identities();

    // Computed once; the per-op rewrites below only read from it.
    let bounds = self.compute_bounds();

    let mut cursor = self.head;
    while !cursor.is_null() {
        let op_id = cursor;
        // Step the cursor before touching `op_id`, so the walk never depends on the
        // state of an op that has just been rewritten.
        cursor = self.next_op(op_id);

        let &Op::Binary { x, y, bop } = self.at(op_id) else { continue };
        if !matches!(bop, BOp::Div | BOp::Mod) {
            continue;
        }
        let Op::Const(constant) = self.at(y) else { continue };
        let dtype = constant.dtype();
        let Some(divisor) = constant.as_dim() else { continue };

        match bop {
            // `simplify_div` takes the divisor value, `simplify_mod` the id of the constant op.
            BOp::Div => self.simplify_div(op_id, x, divisor, dtype, &bounds),
            BOp::Mod => self.simplify_mod(op_id, x, y, dtype, &bounds),
            _ => {}
        }
    }

    self.dead_code_elimination();
    self.verify();
}
fn simplify_shl_shr_roundtrips(&mut self) {
let mut op_id = self.head;
while !op_id.is_null() {
let next = self.next_op(op_id);
if let Some(y) = self.match_shl_shr_roundtrip(op_id) {
self.remap(op_id, y);
}
op_id = next;
}
self.dead_code_elimination();
}
fn match_shl_shr_roundtrip(&self, op_id: OpId) -> Option<OpId> {
let Op::Binary { x: add_op, y: shift_amount, bop: BOp::BitShiftRight } = self.at(op_id) else {
return None;
};
let Op::Const(cst) = self.at(*shift_amount) else { return None };
let n = cst.as_dim()?;
if n >= 64 {
return None;
}
let Op::Binary { x: add_x, y: add_y, bop: BOp::Add } = self.at(*add_op) else {
return None;
};
for candidate in [add_x, add_y] {
if let Op::Binary { x: y, y: s, bop: BOp::BitShiftLeft } = self.at(*candidate) {
if let Op::Const(c) = self.at(*s) {
if c.as_dim() == Some(n) {
return Some(*y);
}
}
}
}
None
}
/// Walks the op list and remaps every op that `match_bitwise_identity` recognises onto
/// the operand it returns, then runs `dead_code_elimination`.
fn simplify_bitwise_identities(&mut self) {
    let mut cursor = self.head;
    while !cursor.is_null() {
        let current = cursor;
        // Advance before rewriting so the walk is independent of the remap.
        cursor = self.next_op(current);
        let Some(replacement) = self.match_bitwise_identity(current) else {
            continue;
        };
        self.remap(current, replacement);
    }
    self.dead_code_elimination();
}
/// Detects a bitwise op that one constant operand turns into a no-op.
///
/// Returns the non-constant operand for:
/// * `BitAnd` where one operand is a constant for which `is_max()` holds,
/// * `BitOr` where one operand is a constant whose `as_dim()` is `Some(0)`.
///
/// Both operand orders are tried, left operand first. Anything else yields `None`.
// NOTE(review): `x & c == x` needs `c` to be all ones; confirm `is_max()` means that for signed dtypes.
fn match_bitwise_identity(&self, op_id: OpId) -> Option<OpId> {
    let &Op::Binary { x, y, bop } = self.at(op_id) else {
        return None;
    };
    if !matches!(bop, BOp::BitAnd | BOp::BitOr) {
        return None;
    }
    for (maybe_const, other) in [(x, y), (y, x)] {
        let Op::Const(c) = self.at(maybe_const) else { continue };
        let is_neutral = match bop {
            BOp::BitAnd => c.is_max(),
            BOp::BitOr => c.as_dim() == Some(0),
            _ => false,
        };
        if is_neutral {
            return Some(other);
        }
    }
    None
}
/// Returns the `as_dim` value of `op_id` when that op is a constant, `None` for any other op.
#[allow(unused)]
fn const_dim(&self, op_id: OpId) -> Option<Dim> {
    match self.ops[op_id].op {
        Op::Const(value) => value.as_dim(),
        _ => None,
    }
}
/// Unimplemented stub: takes an op id and is declared to return a list of op ids,
/// but the body is `todo!()`, so any call panics.
///
/// NOTE(review): judging by the name this is meant to collect the operands of a chain
/// of additions/subtractions rooted at `op_id` — confirm the intended contract before
/// implementing.
#[allow(unused)]
fn get_add_sub_chain(&self, op_id: OpId) -> Vec<OpId> {
todo!()
}
/// Simplifies the division `x / divisor` stored at `op_id`.
///
/// * `op_id` - the `Div` op being simplified.
/// * `x` - its dividend.
/// * `divisor` - value of the constant divisor.
/// * `dtype` - dtype used to build the zero constant.
/// * `bounds` - `(min, max)` per op; `max` is read as an inclusive upper bound.
///
/// Rewrites applied, first match wins:
/// 1. `(a * divisor + b) / divisor` becomes `a`, but only when `b` is known to stay below
///    `divisor`. Without that bound the quotient is `a + b / divisor`, so the op is left
///    alone.
/// 2. `x / divisor` becomes the zero constant when the upper bound of `x` is below `divisor`.
fn simplify_div(&mut self, op_id: OpId, x: OpId, divisor: Dim, dtype: DType, bounds: &Map<OpId, (Dim, Dim)>) {
    // `mul_add` already tries the fused `Mad` form first, so one lookup covers both shapes.
    if let Some((a, c, b)) = mul_add(self, x) {
        if c == divisor {
            // Largest value of the addend: its own value if constant, else its tracked bound.
            let max_b = match self.at(b) {
                Op::Const(k) => k.as_dim(),
                _ => bounds.get(&b).map(|&(_, max)| max),
            };
            if max_b.is_some_and(|max| max < divisor) {
                self.remap(op_id, a);
                return;
            }
        }
    }
    let Some(&(_, xu)) = bounds.get(&x) else { return };
    if xu < divisor {
        self.ops[op_id].op = Op::Const(dtype.zero_constant());
    }
}
fn simplify_mod(&mut self, op_id: OpId, x: OpId, divisor_const: OpId, _dtype: DType, bounds: &Map<OpId, (Dim, Dim)>) {
let Op::Const(divisor) = self.ops[divisor_const].op else { return };
let Some(divisor) = divisor.as_dim() else { return };
if let Some(&(_, max_x)) = bounds.get(&x) {
if max_x < divisor {
self.remap(op_id, x);
return;
}
}
if let Some((a, c, b)) = mul_add(self, x) {
if c == divisor {
self.ops[op_id].op = Op::Binary { x: b, y: divisor_const, bop: BOp::Mod };
if let Some(&(_, max_b)) = bounds.get(&b) {
if max_b < divisor {
self.remap(op_id, b);
}
}
return;
}
if c % divisor == 1 {
let a_plus_b = self.insert_before(op_id, Op::Binary { x: a, y: b, bop: BOp::Add });
self.ops[op_id].op = Op::Binary { x: a_plus_b, y: divisor_const, bop: BOp::Mod };
if let Some(&(_, max_a)) = bounds.get(&a)
&& let Some(&(_, max_b)) = bounds.get(&b)
{
if max_a.saturating_add(max_b) < divisor {
self.remap(op_id, a_plus_b);
}
}
return;
}
if let Some(&(_min_a, max_a)) = bounds.get(&a) {
let max_a_c = max_a.saturating_mul(c);
if let Some(&(min_b, max_b)) = bounds.get(&b) {
if min_b == 0 && max_a_c.saturating_add(max_b) < divisor {
self.ops[op_id].op = Op::Binary { x: b, y: divisor_const, bop: BOp::Mod };
if max_b < divisor {
self.remap(op_id, b);
}
return;
}
}
}
if divisor > c && divisor.is_multiple_of(c) {
if let Some(&(_min_a, max_a)) = bounds.get(&a)
&& let Some(&(min_b, max_b)) = bounds.get(&b)
{
let max_ac = max_a.saturating_mul(c);
if min_b == 0 && max_ac.saturating_add(max_b) < divisor {
self.remap(op_id, b);
return;
}
}
}
}
if let Op::Binary { x: a, y: b, bop: BOp::Add } = self.ops[x].op {
if let Some(&(min_a, max_a)) = bounds.get(&a) {
if let Some(&(min_b, max_b)) = bounds.get(&b) {
if min_a > 0 && min_b > 0 {
let sum = max_a.saturating_add(max_b);
if sum < divisor && sum > 0 {
self.remap(op_id, x);
return;
}
}
}
}
}
if let Op::Binary { x: a, y: c, bop: BOp::Mul } = self.ops[x].op {
if let Op::Const(y) = self.ops[c].op {
if let Some(c) = y.as_dim() {
let c_reduced = c % divisor;
if c_reduced != c && c_reduced > 0 {
if let Some(&(min_a, max_a)) = bounds.get(&a) {
if min_a > 0 {
let prod = max_a.saturating_mul(c_reduced);
if prod < divisor && prod > 0 {
self.remap(op_id, x);
return;
}
}
}
}
}
}
}
if let Op::Binary { x: a, y: b, bop: BOp::Add } = self.ops[x].op {
if let Op::Const(y) = self.ops[b].op {
if let Some(y) = y.as_dim() {
if let Some(&(_, max_a)) = bounds.get(&a) {
if max_a + y < divisor {
self.remap(op_id, x);
return;
}
}
}
}
}
}
}
/// Recognises `a * c + b` with a constant factor and returns `(a, c, b)`.
///
/// A fused `Mad` accepted by `mad` takes priority. Otherwise `x` must be an `Add` with
/// one operand accepted by `match_mul_or_shl`; the left operand is tried as the product
/// first, then the right. Returns `None` when no form matches.
fn mul_add(k: &Kernel, x: OpId) -> Option<(OpId, u64, OpId)> {
    mad(k, x).or_else(|| {
        let Op::Binary { x: lhs, y: rhs, bop: BOp::Add } = k.at(x) else {
            return None;
        };
        for (product, addend) in [(*lhs, *rhs), (*rhs, *lhs)] {
            if let Some((a, factor)) = match_mul_or_shl(k, product) {
                return Some((a, factor, addend));
            }
        }
        None
    })
}
/// Reads `op` as "operand times constant factor".
///
/// * `a * c` with `c` a constant as the right operand yields `(a, c)`.
/// * `a << s` with `s` a constant right operand below 64 yields `(a, 1 << s)`.
///
/// The constant must have an `as_dim` value; every other shape yields `None`.
fn match_mul_or_shl(k: &Kernel, op: OpId) -> Option<(OpId, u64)> {
    let Op::Binary { x: operand, y: rhs, bop } = k.at(op) else {
        return None;
    };
    let Op::Const(cst) = k.at(*rhs) else { return None };
    match bop {
        BOp::Mul => cst.as_dim().map(|factor| (*operand, factor)),
        BOp::BitShiftLeft => cst
            .as_dim()
            // A shift of 64 or more cannot be expressed as a `u64` power of two.
            .filter(|&shift| shift < 64)
            .map(|shift| (*operand, 1u64 << shift)),
        _ => None,
    }
}
/// Decomposes a fused `Mad { x: a, y: c, z: b }` whose `y` operand is a constant with an
/// `as_dim` value, returning `(a, c, b)`. Any other op yields `None`.
fn mad(k: &Kernel, x: OpId) -> Option<(OpId, u64, OpId)> {
    match k.at(x) {
        Op::Mad { x: a, y: c, z: b } => match k.at(*c) {
            Op::Const(cst) => cst.as_dim().map(|factor| (*a, factor, *b)),
            _ => None,
        },
        _ => None,
    }
}