use super::autotune::Optimization;
use crate::kernel::{Kernel, Op, OpId};
use crate::{Map, Set};
impl Kernel {
/// Returns the [`Optimization::Licm`] tag paired with the constant `1`.
///
/// The kernel argument is ignored.
// NOTE(review): the `usize` is presumably a count consumed by the autotuner
// (e.g. number of selectable configurations) — confirm against `autotune`.
pub const fn opt_licm(_kernel: &Kernel) -> (Optimization, usize) {
    let tag = Optimization::Licm;
    (tag, 1)
}
/// Returns the [`Optimization::ReassociateCommutative`] tag paired with the constant `1`.
///
/// The kernel argument is ignored.
// NOTE(review): the `usize` is presumably a count consumed by the autotuner
// (e.g. number of selectable configurations) — confirm against `autotune`.
pub const fn opt_reassociate_commutative(_kernel: &Kernel) -> (Optimization, usize) {
    let tag = Optimization::ReassociateCommutative;
    (tag, 1)
}
/// Puts the operands of commutative binary ops into a canonical order.
///
/// Walks the op list from `self.head` and records a depth for every op:
/// * `If` / `Loop` raise the nesting counter and record the new value,
/// * `EndIf` / `EndLoop` lower it and record the new value,
/// * `Unary` / `Cast` inherit the depth recorded for their operand,
/// * `Binary` / `Mad` take the maximum over their operands,
/// * every other accepted op records the current nesting counter.
///
/// A commutative `Binary` whose `x` is not a constant gets `x` and `y`
/// exchanged when `x` has the greater recorded depth, when `y` is a constant,
/// or when `x` is a load. `self.verify()` is called once the walk is done.
///
/// # Panics
///
/// Reaches `unreachable!()` if the kernel still contains `Move`, `ConstView`,
/// `LoadView`, `StoreView` or `Reduce`.
pub fn swap_commutative(&mut self) {
    // Depth recorded for each op visited so far.
    let mut depth_of: Map<OpId, usize> = Map::default();
    let mut nesting = 0;
    let mut cursor = self.head;
    while !cursor.is_null() {
        let recorded = match self.at(cursor) {
            &Op::Binary { x: lhs, y: rhs, bop } => {
                // Operands are copied out of the op, so it can be mutated below.
                let wants_swap = bop.is_commutative()
                    && !self.ops[lhs].op.is_const()
                    && (depth_of[&lhs] > depth_of[&rhs]
                        || self.ops[rhs].op.is_const()
                        || self.ops[lhs].op.is_load());
                if wants_swap {
                    if let Op::Binary { x, y, .. } = &mut self.ops[cursor].op {
                        std::mem::swap(x, y);
                    }
                }
                // `max` is symmetric, so the pre-swap operands give the same depth.
                depth_of[&lhs].max(depth_of[&rhs])
            }
            Op::Mad { x, y, z } => depth_of[x].max(depth_of[y]).max(depth_of[z]),
            Op::Cast { x, .. } | Op::Unary { x, .. } => depth_of[x],
            Op::Loop { .. } | Op::If { .. } => {
                nesting += 1;
                nesting
            }
            Op::EndLoop | Op::EndIf => {
                nesting -= 1;
                nesting
            }
            Op::Const(_)
            | Op::Define { .. }
            | Op::Index { .. }
            | Op::Load { .. }
            | Op::Store { .. }
            | Op::Barrier { .. }
            | Op::Vectorize { .. }
            | Op::Devectorize { .. }
            | Op::Wmma { .. } => nesting,
            Op::Reduce { .. }
            | Op::StoreView { .. }
            | Op::LoadView { .. }
            | Op::ConstView { .. }
            | Op::Move { .. } => unreachable!(),
        };
        depth_of.insert(cursor, recorded);
        cursor = self.next_op(cursor);
    }
    self.verify();
}
/// Rebuilds chains of a single commutative and associative binary operator so
/// that the operands with the smallest recorded loop depth are combined first.
///
/// Pass 1 walks the op list from `self.head` and records a depth per op:
/// `Loop`/`If` raise the nesting counter, `EndLoop`/`EndIf` lower it,
/// `Unary`/`Cast`/`Binary`/`Mad`/`Vectorize` take the maximum over their
/// operands, and all remaining accepted ops record the current counter.
///
/// Pass 2 visits every `Binary` whose operator is both commutative and
/// associative, flattens the tree of same-operator `Binary` ops below it into
/// a list of leaf operands, sorts that list by recorded depth (stable sort),
/// and rewrites the op as a left-to-right fold over the sorted list. The
/// intermediate ops of the fold are inserted via `insert_before`; the visited
/// op itself is overwritten with the final step. Chains with more than
/// `MAX_CHAIN_LEN` leaves are left untouched. `self.verify()` runs at the end.
///
/// # Panics
///
/// Reaches `unreachable!()` if the kernel still contains `Move`, `ConstView`,
/// `LoadView`, `StoreView` or `Reduce`.
pub fn reassociate_commutative(&mut self) {
    /// Largest number of leaf operands a chain may have and still be rewritten.
    const MAX_CHAIN_LEN: usize = 20;

    let mut loop_dep: Map<OpId, usize> = Map::default();
    let mut loop_depth = 0;
    let mut op_id = self.head;
    while !op_id.is_null() {
        let depth = match self.at(op_id) {
            Op::Move { .. } | Op::ConstView { .. } | Op::LoadView { .. } | Op::StoreView { .. } | Op::Reduce { .. } => {
                unreachable!()
            }
            Op::Vectorize { ops } => {
                let mut max = 0;
                for op in ops {
                    max = max.max(loop_dep[op]);
                }
                max
            }
            Op::Mad { x, y, z } => loop_dep[x].max(loop_dep[y]).max(loop_dep[z]),
            Op::Loop { .. } | Op::If { .. } => {
                loop_depth += 1;
                loop_depth
            }
            Op::EndLoop | Op::EndIf => {
                loop_depth -= 1;
                loop_depth
            }
            Op::Unary { x, .. } | Op::Cast { x, .. } => loop_dep[x],
            Op::Binary { x, y, .. } => loop_dep[x].max(loop_dep[y]),
            // `Devectorize` used to be `todo!()`, which aborted the whole pass.
            // It now records the current nesting depth, the same value
            // `swap_commutative` records for it. The depth is only used as a
            // sort key below, so this choice affects operand order only.
            Op::Devectorize { .. }
            | Op::Index { .. }
            | Op::Barrier { .. }
            | Op::Load { .. }
            | Op::Store { .. }
            | Op::Const(_)
            | Op::Define { .. }
            | Op::Wmma { .. } => loop_depth,
        };
        loop_dep.insert(op_id, depth);
        op_id = self.next_op(op_id);
    }

    let mut op_id = self.head;
    'ops: while !op_id.is_null() {
        // Successor is read before any rewrite of the current op.
        let next = self.next_op(op_id);
        if let &Op::Binary { bop, .. } = self.at(op_id) {
            if !bop.is_commutative() || !bop.is_associative() {
                op_id = next;
                continue 'ops;
            }

            // Depth-first flattening: same-operator `Binary` ops are expanded,
            // everything else becomes a leaf of the chain.
            let mut pending = vec![op_id];
            let mut chain = Vec::new();
            while let Some(param) = pending.pop() {
                if let &Op::Binary { x, y, bop: inner_bop } = self.at(param) {
                    if inner_bop == bop {
                        pending.push(x);
                        pending.push(y);
                        continue;
                    }
                }
                chain.push(param);
                if chain.len() > MAX_CHAIN_LEN {
                    op_id = next;
                    continue 'ops;
                }
            }
            // Defensive guard; also keeps the slicing below in bounds.
            if chain.len() < 2 {
                op_id = next;
                continue 'ops;
            }

            chain.sort_by_key(|id| loop_dep[id]);

            // Fold the sorted leaves: all but the last step become new ops
            // placed in front of `op_id`; the last step replaces `op_id`.
            let last = chain.len() - 1;
            let mut acc = chain[0];
            for &operand in &chain[1..last] {
                acc = self.insert_before(op_id, Op::Binary { x: operand, y: acc, bop });
            }
            self.ops[op_id].op = Op::Binary { x: chain[last], y: acc, bop };
        }
        op_id = next;
    }
    self.verify();
}
/// Moves ops that do not depend on anything inside a loop to just before
/// that loop.
///
/// The op list is walked backwards from `self.tail`. Each `EndLoop` is pushed
/// on a stack; each `Loop` pops its matching `EndLoop` and then scans forward
/// from the `Loop` op up to (not including) that `EndLoop`. During the scan an
/// op is moved in front of the `Loop` via `move_op_before` when
/// * it is not one of the pinned kinds (`Store`, `Load`, `Define`, `Loop`,
///   `EndLoop`, `If`, `EndIf`, `Barrier`), and
/// * none of its `parameters()` is an op that stayed inside the loop.
/// Otherwise it is recorded as staying inside the loop. Because the walk is
/// backwards, an inner loop is processed before the loop that encloses it.
/// `self.verify()` runs at the end.
///
/// # Panics
///
/// Panics if a `Loop` is encountered with no `EndLoop` after it.
pub fn loop_invariant_code_motion(&mut self) {
    #[cfg(feature = "time")]
    let _timer = crate::Timer::new("loop_invariant_code_motion");
    // `EndLoop` ids seen so far (walking backwards) that have not been matched yet.
    let mut endloop_is = Vec::new();
    let mut loop_id = self.tail;
    while !loop_id.is_null() {
        if *self.at(loop_id) == Op::EndLoop {
            endloop_is.push(loop_id);
        }
        if let Op::Loop { .. } = self.at(loop_id) {
            let endloop_id = endloop_is
                .pop()
                .expect("every `Loop` must be followed by a matching `EndLoop`");
            // Ops that remain inside this loop; anything reading one of them
            // cannot be hoisted.
            let mut op_ids_in_loop = Set::default();
            op_ids_in_loop.insert(loop_id);
            let mut op_id = loop_id;
            while op_id != endloop_id {
                let op = self.at(op_id);
                // Read the successor before the op is possibly moved.
                let next_op_id = self.next_op(op_id);
                // Memory ops, definitions and structural / synchronization
                // markers must keep their position. `EndIf` carries no operands,
                // so without being pinned it would always pass the operand test
                // below and be moved out of the loop, away from its `If`.
                let pinned = matches!(
                    op,
                    Op::Store { .. }
                        | Op::Load { .. }
                        | Op::Define { .. }
                        | Op::Loop { .. }
                        | Op::EndLoop
                        | Op::If { .. }
                        | Op::EndIf
                        | Op::Barrier { .. }
                );
                if !pinned && op.parameters().all(|param| !op_ids_in_loop.contains(&param)) {
                    self.move_op_before(op_id, loop_id);
                } else {
                    op_ids_in_loop.insert(op_id);
                }
                op_id = next_op_id;
            }
        }
        loop_id = self.prev_op(loop_id);
    }
    self.verify();
}
}