use crate::{
DType, Map, Set,
dtype::Constant,
kernel::{BOp, IDX_T, Kernel, Op, OpId, Scope, UOp},
};
use std::hash::BuildHasherDefault;
impl Kernel {
/// Folds constants and applies algebraic simplifications in a single forward
/// pass over the kernel's op list, then calls `verify`.
///
/// Handled ops:
/// - `Store` of a value that was loaded from the same buffer and index is removed.
/// - `Cast`, `Unary`, `Binary` and `Mad` with constant operands are evaluated.
/// - `Binary`/`Mad` with one constant operand are simplified using identities
///   (`x + 0`, `x * 1`, `x * 2^k` on index types, `x % 2^k` on index types, ...).
///
/// # Panics
/// - On division or modulo by a constant zero right-hand side.
/// - Via `todo!()` on `Move`, `ConstView`, `LoadView`, `StoreView` and `Reduce`
///   ops, and on `Sub`/`Div` whose two operands are structurally equal,
///   non-constant ops.
#[allow(clippy::match_same_arms)]
pub fn constant_folding(&mut self) {
    #[cfg(feature = "time")]
    let _timer = crate::Timer::new("constant folding");
    let mut op_id = self.head;
    while !op_id.is_null() {
        // Fetch the successor first: the current op may be removed or remapped below.
        let next = self.next_op(op_id);
        match *self.at(op_id) {
            Op::Move { .. }
            | Op::ConstView { .. }
            | Op::LoadView { .. }
            | Op::StoreView { .. }
            | Op::Reduce { .. } => todo!(),
            Op::Wmma { .. } | Op::Barrier { .. } | Op::If { .. } | Op::EndIf => {}
            Op::Vectorize { .. }
            | Op::Devectorize { .. }
            | Op::Const(_)
            | Op::Define { .. }
            | Op::Load { .. }
            | Op::Index { .. }
            | Op::Loop { .. }
            | Op::EndLoop => {}
            Op::Store { dst, x, index, .. } => {
                // Writing back a value loaded from the very same location is a no-op.
                // The index must match too, otherwise this is a copy between two
                // elements of the same buffer and has to stay.
                // NOTE(review): an intervening store to the same location between the
                // load and this store is not detected here.
                if let Op::Load { src, index: load_index, .. } = *self.at(x) {
                    if src == dst && load_index == index {
                        self.remove_op(op_id);
                    }
                }
            }
            Op::Cast { x, dtype } => {
                // cast(const) -> const
                if let Op::Const(cx) = self.at(x) {
                    self.ops[op_id].op = Op::Const(cx.cast(dtype));
                }
                // cast(cast(a)) -> cast(a)
                // NOTE(review): this skips the intermediate dtype; confirm that narrowing
                // intermediate casts never reach this pass.
                if let Op::Cast { x: inner_x, .. } = *self.at(x) {
                    self.ops[op_id].op = Op::Cast { x: inner_x, dtype };
                }
                // cast((a + c) - c) and cast(cast(a + c) - c) -> cast(a)
                if let Op::Binary { x: sub_x, y: sub_y, bop: BOp::Sub } = *self.at(x) {
                    let add_x = if let Op::Cast { x: inner_cast_x, .. } = *self.at(sub_x) {
                        inner_cast_x
                    } else {
                        sub_x
                    };
                    if let Op::Binary { x: inner_add_x, y: add_y, bop: BOp::Add } = *self.at(add_x) {
                        if self.constants_equal(add_y, sub_y) {
                            self.ops[op_id].op = Op::Cast { x: inner_add_x, dtype };
                        }
                    }
                }
            }
            Op::Unary { x, uop } => {
                if let Op::Const(cx) = self.at(x) {
                    self.ops[op_id].op = Op::Const(cx.unary(uop));
                }
            }
            Op::Binary { x, y, bop } => match (self.at(x).clone(), self.at(y).clone()) {
                (Op::Const(cx), Op::Const(cy)) => {
                    self.ops[op_id].op = Op::Const(Constant::binary(cx, cy, bop));
                }
                // Constant on the left-hand side.
                (Op::Const(cx), _) => match bop {
                    BOp::And if cx.dtype() == DType::Bool && cx.is_zero() => self.remap(op_id, x),
                    BOp::And if cx.dtype() == DType::Bool && cx.is_one() => self.remap(op_id, y),
                    BOp::Add if cx.is_zero() => self.remap(op_id, y),
                    BOp::Sub if cx.is_zero() => self.ops[op_id].op = Op::Unary { x: y, uop: UOp::Neg },
                    BOp::Mul | BOp::Div if cx.is_zero() => self.ops[op_id].op = Op::Const(cx),
                    BOp::Mul if cx.is_one() => self.remap(op_id, y),
                    BOp::Mul if cx.is_two() => self.ops[op_id].op = Op::Binary { x: y, y, bop: BOp::Add },
                    BOp::Mul if cx.is_power_of_two() && cx.dtype() == IDX_T => {
                        let c = self.insert_before(op_id, Op::Const(cx.unary(UOp::Log2)));
                        self.ops[op_id].op = Op::Binary { x: y, y: c, bop: BOp::BitShiftLeft };
                    }
                    BOp::Div if cx.is_one() => self.ops[op_id].op = Op::Unary { x: y, uop: UOp::Reciprocal },
                    BOp::Pow if cx.is_one() => self.ops[op_id].op = Op::Const(cx),
                    BOp::Max if cx.is_minimum() => self.remap(op_id, y),
                    // 0 << y == 0 and 0 >> y == 0: the result is the zero constant `x`,
                    // not the shift amount `y`.
                    BOp::BitShiftLeft | BOp::BitShiftRight if cx.is_zero() => self.remap(op_id, x),
                    _ => {}
                },
                // Constant on the right-hand side.
                (_, Op::Const(cy)) => match bop {
                    BOp::And if cy.dtype() == DType::Bool && cy.is_zero() => self.remap(op_id, y),
                    BOp::And if cy.dtype() == DType::Bool && cy.is_one() => self.remap(op_id, x),
                    BOp::Add | BOp::Sub if cy.is_zero() => self.remap(op_id, x),
                    BOp::Mul if cy.is_zero() => self.ops[op_id].op = Op::Const(cy),
                    BOp::Mul if cy.is_one() => self.remap(op_id, x),
                    BOp::Mul if cy.is_two() => self.ops[op_id].op = Op::Binary { x, y: x, bop: BOp::Add },
                    BOp::Mul if cy.is_power_of_two() && cy.dtype() == IDX_T => {
                        let c = self.insert_before(op_id, Op::Const(cy.unary(UOp::Log2)));
                        self.ops[op_id].op = Op::Binary { x, y: c, bop: BOp::BitShiftLeft };
                    }
                    BOp::Div if cy.is_zero() => panic!("Division by constant zero"),
                    BOp::Div if cy.is_one() => self.remap(op_id, x),
                    BOp::Div if cy.is_power_of_two() && cy.dtype() == IDX_T => {
                        let y = self.insert_before(op_id, Op::Const(cy.unary(UOp::Log2)));
                        self.ops[op_id].op = Op::Binary { x, y, bop: BOp::BitShiftRight };
                    }
                    BOp::Mod if cy.is_zero() => panic!("Modulo by constant zero"),
                    // x % 2^k == x & (2^k - 1) for index values.
                    BOp::Mod if cy.is_power_of_two() && cy.dtype() == IDX_T => {
                        let mask = Constant::binary(cy, Constant::idx(1), BOp::Sub);
                        let y = self.insert_before(op_id, Op::Const(mask));
                        self.ops[op_id].op = Op::Binary { x, y, bop: BOp::BitAnd };
                    }
                    // Nested modulo by constants: (xi % ciy) % cy.
                    BOp::Mod if cy.dtype() == IDX_T => {
                        if let Op::Binary { bop: BOp::Mod, x: xi, y: yi } = self.ops[x].op {
                            if let Op::Const(ciy) = self.ops[yi].op {
                                if ciy > cy {
                                    // Only valid when cy divides ciy, e.g. (x % 8) % 4 == x % 4,
                                    // but (x % 6) % 4 != x % 4 in general.
                                    if Constant::binary(ciy, cy, BOp::Mod).is_zero() {
                                        self.ops[op_id].op = Op::Binary { x: xi, y, bop: BOp::Mod };
                                    }
                                } else {
                                    // The inner result is already smaller than ciy <= cy,
                                    // so the outer modulo changes nothing.
                                    self.ops[op_id].op = Op::Binary { x: xi, y: yi, bop: BOp::Mod };
                                }
                            }
                        }
                    }
                    BOp::Pow if cy.is_zero() => self.ops[op_id].op = Op::Const(cy.dtype().one_constant()),
                    BOp::Pow if cy.is_one() => self.remap(op_id, x),
                    BOp::Pow if cy.is_two() => self.ops[op_id].op = Op::Binary { x, y: x, bop: BOp::Mul },
                    BOp::BitShiftLeft if cy.is_zero() => self.remap(op_id, x),
                    BOp::BitShiftRight if cy.is_zero() => self.remap(op_id, x),
                    _ => {}
                },
                // Both operands are structurally equal, non-constant ops.
                (x_op, y_op) if x_op == y_op => match bop {
                    BOp::Div => todo!(),
                    BOp::Sub => todo!(),
                    _ => {}
                },
                _ => {}
            },
            // mad(x, y, z) = x * y + z
            Op::Mad { x, y, z } => {
                match (self.at(x).clone(), self.at(y).clone(), self.at(z).clone()) {
                    (Op::Const(cx), Op::Const(cy), Op::Const(cz)) => {
                        let mul = Constant::binary(cx, cy, BOp::Mul);
                        self.ops[op_id].op = Op::Const(Constant::binary(mul, cz, BOp::Add));
                    }
                    (Op::Const(cx), Op::Const(cy), _) => {
                        let mul = Constant::binary(cx, cy, BOp::Mul);
                        let x = self.insert_before(op_id, Op::Const(mul));
                        self.ops[op_id].op = Op::Binary { x, y: z, bop: BOp::Add };
                    }
                    (Op::Const(cx), _, _) if cx.is_zero() => {
                        self.remap(op_id, z);
                    }
                    (Op::Const(cx), _, _) if cx.is_one() => {
                        self.ops[op_id].op = Op::Binary { x: y, y: z, bop: BOp::Add };
                    }
                    (_, Op::Const(cy), _) if cy.is_zero() => {
                        self.remap(op_id, z);
                    }
                    (_, Op::Const(cy), _) if cy.is_one() => {
                        self.ops[op_id].op = Op::Binary { x, y: z, bop: BOp::Add };
                    }
                    (_, _, Op::Const(cz)) if cz.is_zero() => {
                        self.ops[op_id].op = Op::Binary { x, y, bop: BOp::Mul };
                    }
                    _ => {}
                }
            }
        }
        op_id = next;
    }
    self.verify();
}
pub fn fold_accs(&mut self) {
#[cfg(feature = "time")]
let _timer = crate::Timer::new("fold_accs");
self.constant_folding();
let mut defines = Map::default();
let mut loop_level = 0u32;
let mut op_id = self.head;
while !op_id.is_null() {
match *self.at(op_id) {
Op::Define { scope: Scope::Register, .. } => {
defines.insert(op_id, loop_level);
}
Op::Store { dst, .. } => {
if let Some(level) = defines.get(&dst) {
if loop_level > *level {
defines.remove(&dst);
}
}
}
Op::Loop { .. } => {
loop_level += 1;
}
Op::EndLoop => {
loop_level -= 1;
}
_ => {}
}
op_id = self.next_op(op_id);
}
for (define, _) in defines {
self.fold_acc(define);
}
}
/// Removes the define `define_id` together with every load from and store to
/// it, forwarding each load to the value most recently stored at the same
/// constant index. Finishes by calling `verify`.
///
/// # Panics
/// - `unreachable!()` if `define_id` is not a `Define`, or if a folded
///   load/store index is not a `Constant::U32` constant op.
/// - `todo!()` on any store with `vlen > 1`.
pub fn fold_acc(&mut self, define_id: OpId) {
    let Op::Define { len, .. } = self.ops[define_id].op else { unreachable!() };
    self.remove_op(define_id);
    // latest_value[i] is the op last stored into element i (NULL until the first store).
    let mut latest_value = vec![OpId::NULL; len as usize];
    let mut remaps = Map::default();
    let mut cursor = self.head;
    while !cursor.is_null() {
        let following = self.next_op(cursor);
        // `folded` is true when the op belonged to the accumulator and was removed.
        let folded = match *self.at(cursor) {
            Op::Store { dst, x, index, vlen } => {
                if vlen > 1 {
                    todo!()
                }
                if dst == define_id {
                    self.remove_op(cursor);
                    // Values that no longer exist (already-removed ops) are not recorded.
                    if self.ops.contains_key(x) {
                        let Op::Const(Constant::U32(slot)) = self.ops[index].op else { unreachable!() };
                        latest_value[slot as usize] = x;
                    }
                    true
                } else {
                    false
                }
            }
            Op::Load { src, index, .. } if src == define_id => {
                self.remove_op(cursor);
                let Op::Const(Constant::U32(slot)) = self.ops[index].op else { unreachable!() };
                remaps.insert(cursor, latest_value[slot as usize]);
                true
            }
            _ => false,
        };
        if !folded {
            // Surviving ops get their parameters redirected away from removed loads.
            self.ops[cursor].op.remap_params(&remaps);
        }
        cursor = following;
    }
    self.verify();
}
/// Removes every `Loop`/`If` region that contains no store to a non-read-only
/// define declared in a scope enclosing that region, then calls `verify`.
///
/// Stores whose destination is defined in the innermost open scope do not
/// keep any region alive.
pub fn delete_empty_loops(&mut self) {
    #[cfg(feature = "time")]
    let _timer = crate::Timer::new("delete_empty_loops");
    let mut dead = Set::default();
    // One entry per open scope; index 0 is the top level, which is never deleted.
    let mut scope_defines: Vec<Set<OpId>> = vec![Set::default()];
    let mut scope_ops: Vec<Set<OpId>> = vec![Set::default()];
    let mut deletable: Vec<bool> = vec![false];
    let mut cursor = self.head;
    while !cursor.is_null() {
        match self.at(cursor) {
            // Defines are tracked per scope but never recorded as region members.
            Op::Define { ro, .. } => {
                if !ro {
                    scope_defines.last_mut().unwrap().insert(cursor);
                }
            }
            other => {
                match other {
                    Op::Loop { .. } | Op::If { .. } => {
                        scope_ops.push(Set::default());
                        scope_defines.push(Set::default());
                        deletable.push(true);
                    }
                    Op::Store { dst, .. } => {
                        // A store to a define from scope `level` keeps every region
                        // nested below that scope alive.
                        let owner = scope_defines
                            .iter()
                            .take(scope_defines.len() - 1)
                            .position(|defs| defs.contains(dst));
                        if let Some(level) = owner {
                            for flag in deletable.iter_mut().skip(level + 1) {
                                *flag = false;
                            }
                        }
                    }
                    _ => {}
                }
                // The op belongs to every currently open region.
                for members in &mut scope_ops {
                    members.insert(cursor);
                }
                if matches!(other, Op::EndLoop | Op::EndIf) {
                    scope_defines.pop();
                    let finished = scope_ops.pop();
                    if deletable.pop() == Some(true) {
                        dead.extend(finished.unwrap());
                    }
                }
            }
        }
        cursor = self.next_op(cursor);
    }
    for op_id in dead {
        self.remove_op(op_id);
    }
    self.verify();
}
/// Removes every op that is not reachable, through op parameters, from a root
/// op (`Store`, `Define`, `Wmma`, `Barrier`, `If`, `EndIf`, `Loop`, `EndLoop`,
/// `StoreView`), then calls `verify`.
pub fn dead_code_elimination(&mut self) {
    #[cfg(feature = "time")]
    let _timer = crate::Timer::new("dead_code_elimination");
    // Seed the worklist with the ops that are always kept.
    let mut worklist: Vec<_> = self
        .iter_unordered()
        .into_iter()
        .filter_map(|(id, op)| {
            matches!(
                op,
                Op::Store { .. }
                    | Op::Define { .. }
                    | Op::Wmma { .. }
                    | Op::Barrier { .. }
                    | Op::If { .. }
                    | Op::EndIf
                    | Op::Loop { .. }
                    | Op::EndLoop
                    | Op::StoreView { .. }
            )
            .then_some(id)
        })
        .collect();
    // Everything transitively used by a root is live.
    let mut live = Set::default();
    while let Some(id) = worklist.pop() {
        if live.insert(id) {
            worklist.extend(self.at(id).parameters());
        }
    }
    // Snapshot the ids first: removing ops mutates `self.ops` while we walk it.
    #[allow(clippy::needless_collect)]
    let all_ids = self.ops.ids().collect::<Vec<_>>();
    for id in all_ids {
        if !live.contains(&id) {
            self.remove_op(id);
        }
    }
    self.verify();
}
/// Redirects parameters of later ops from duplicate ops to the first
/// structurally equal op seen in the current or an enclosing scope, then
/// calls `verify`.
///
/// Duplicate ops themselves are left in place; only the parameters of ops
/// handled by the generic arm are rewritten (not those of `Store`, `If`,
/// `Loop`, `Barrier` or `Define`). A `Load` whose source was stored to since
/// the last load of that source is never merged with an earlier op.
pub fn common_subexpression_elimination(&mut self) {
    #[cfg(feature = "time")]
    let _timer = crate::Timer::new("common_subexpression_elimination");
    // One table of already-seen ops per open scope.
    let mut seen: Vec<Map<Op, OpId>> = Vec::with_capacity(10);
    seen.push(Map::with_capacity_and_hasher(20, BuildHasherDefault::default()));
    // Per open scope: buffers stored to and not yet re-loaded.
    let mut written: Vec<Set<OpId>> = Vec::with_capacity(10);
    written.push(Set::with_capacity_and_hasher(10, BuildHasherDefault::default()));
    let mut remaps: Map<OpId, OpId> = Map::with_capacity_and_hasher(10, BuildHasherDefault::default());
    let mut cursor = self.head;
    while !cursor.is_null() {
        match &mut self.ops[cursor].op {
            Op::Barrier { .. } | Op::Define { .. } => {}
            Op::If { .. } | Op::Loop { .. } => {
                seen.push(Map::with_capacity_and_hasher(20, BuildHasherDefault::default()));
                written.push(Set::with_capacity_and_hasher(10, BuildHasherDefault::default()));
            }
            Op::EndIf | Op::EndLoop => {
                seen.pop();
                written.pop();
            }
            &mut Op::Store { dst, .. } => {
                written.last_mut().unwrap().insert(dst);
            }
            op => {
                // A load from a buffer with a pending store must not be merged;
                // visiting it clears the pending mark in every scope.
                let mut mergeable = true;
                if let Op::Load { src, .. } = op {
                    let mut clobbered = false;
                    for dsts in &mut written {
                        clobbered |= dsts.remove(src);
                    }
                    mergeable = !clobbered;
                }
                // Outermost scope wins when several scopes hold an equal op.
                let earlier = if mergeable {
                    seen.iter().find_map(|scope| scope.get(&*op).copied())
                } else {
                    None
                };
                match earlier {
                    Some(previous) => {
                        remaps.insert(cursor, previous);
                    }
                    None => {
                        for param in op.parameters_mut() {
                            if let Some(&new_id) = remaps.get(param) {
                                *param = new_id;
                            }
                        }
                        seen.last_mut().unwrap().insert(op.clone(), cursor);
                    }
                }
            }
        }
        cursor = self.next_op(cursor);
    }
    self.verify();
}
/// Hoists ops towards the head of the kernel in two passes, keeping the
/// relative order of the moved ops:
/// 1. every `Const` is moved directly behind the leading run of `Define` ops;
/// 2. every `Index` is moved directly behind the leading run of
///    `Define`/`Const` ops.
///
/// `verify` is only called in builds with debug assertions.
pub fn move_constants_to_beginning(&mut self) {
    #[cfg(feature = "time")]
    let _timer = crate::Timer::new("move_constants_to_beginning");

    // Pass 1: constants go right after the leading defines.
    let mut cursor = self.head;
    while matches!(self.at(cursor), Op::Define { .. }) {
        cursor = self.next_op(cursor);
    }
    // `anchor` is the op behind which the next hoisted op is placed.
    let mut anchor = self.prev_op(cursor);
    while !cursor.is_null() {
        let following = self.next_op(cursor);
        if matches!(self.at(cursor), Op::Const(_)) {
            self.move_op_after(cursor, anchor);
            anchor = cursor;
        }
        cursor = following;
    }

    // Pass 2: index ops go right after the leading defines and constants.
    let mut cursor = self.head;
    while matches!(self.at(cursor), Op::Define { .. } | Op::Const(_)) {
        cursor = self.next_op(cursor);
    }
    let mut anchor = self.prev_op(cursor);
    while !cursor.is_null() {
        let following = self.next_op(cursor);
        if matches!(self.at(cursor), Op::Index { .. }) {
            self.move_op_after(cursor, anchor);
            anchor = cursor;
        }
        cursor = following;
    }

    #[cfg(debug_assertions)]
    self.verify();
}
/// Returns `true` when both ops are integer constants (`U32`, `I32`, `U64`
/// or `I64`) whose values are equal after conversion to `i64`.
///
/// Any other op or constant kind on either side yields `false`.
fn constants_equal(&self, a: OpId, b: OpId) -> bool {
    // Converts an integer constant op to `i64`; everything else maps to `None`.
    let widen = |id: OpId| -> Option<i64> {
        match self.at(id) {
            Op::Const(Constant::U32(v)) => Some(*v as i64),
            Op::Const(Constant::I32(v)) => Some(*v as i64),
            // 64-bit payloads are reinterpreted from their little-endian bytes.
            Op::Const(Constant::U64(bytes) | Constant::I64(bytes)) => Some(i64::from_le_bytes(*bytes)),
            _ => None,
        }
    };
    let lhs = widen(a);
    let rhs = widen(b);
    matches!((lhs, rhs), (Some(l), Some(r)) if l == r)
}
/// Rewrites every `x pow y` into `exp2(log2(x) * y)` by inserting the `Log2`
/// and `Mul` ops in front of the original op and turning the original op
/// into the `Exp2`. Calls `verify` afterwards.
pub fn unfold_pows(&mut self) {
    let mut cursor = self.head;
    while !cursor.is_null() {
        if let &Op::Binary { x: base, y: exponent, bop: BOp::Pow } = self.at(cursor) {
            let log = self.insert_before(cursor, Op::Unary { x: base, uop: UOp::Log2 });
            let scaled = self.insert_before(cursor, Op::Binary { x: log, y: exponent, bop: BOp::Mul });
            self.ops[cursor].op = Op::Unary { x: scaled, uop: UOp::Exp2 };
        }
        cursor = self.next_op(cursor);
    }
    self.verify();
}
}