use super::autotune::Optimization;
use crate::{
backend::DeviceInfo,
dtype::Constant,
kernel::{BOp, Kernel, Op, OpId, Scope},
};
impl Kernel {
    /// Enumerates tiled-reduce candidates for this kernel.
    ///
    /// Each candidate is an `(op_id, factor, tree_branch)` triple meant for
    /// [`Kernel::tiled_reduce`], where `op_id` is a `Loop` op. Candidates are only produced when
    /// the kernel contains no `Barrier`, has exactly one `Loop`, does not already use local
    /// axis 2 (so `tiled_reduce` still has a free local axis), and that loop has at least 256
    /// iterations.
    ///
    /// `factor` is tried from `[32, 16, 8, 64, 128]` (in that order). It is kept only if it:
    /// - divides the loop length;
    /// - leaves at least 4 iterations per new local index;
    /// - fits in the local-thread budget left over by the kernel's existing local indices
    ///   (`dev_info.max_local_threads` divided by the product of their lengths).
    ///
    /// Every accepted `factor` is paired with each `tree_branch` in `[2, 4]`.
    ///
    /// Returns the `TiledReduce` optimization and the number of candidates it holds.
    pub fn opt_tiled_reduce(&self, dev_info: &DeviceInfo) -> (Optimization, usize) {
        #[cfg(feature = "time")]
        let _timer = crate::Timer::new("opt_tiled_reduce");
        // Kernels that already contain a barrier are left alone.
        if self.ops.values().any(|node| matches!(node.op, Op::Barrier { .. })) {
            return (Optimization::TiledReduce { factors: Vec::new() }, 0);
        }
        let n_loops = self.ops.values().filter(|node| matches!(node.op, Op::Loop { .. })).count();
        if n_loops != 1 {
            return (Optimization::TiledReduce { factors: Vec::new() }, 0);
        }
        // Record the length of each local axis. Also record the axis `tiled_reduce` would
        // allocate next, which is one past the highest local axis already in use.
        let mut local_axis_sizes: crate::Map<u32, u64> = crate::Map::default();
        let mut next_local_axis: u32 = 0;
        for op in self.ops.values() {
            if let Op::Index { scope: Scope::Local, axis, len } = op.op {
                next_local_axis = next_local_axis.max(axis + 1);
                if let Some(&existing) = local_axis_sizes.get(&axis) {
                    debug_assert_eq!(existing, len);
                } else {
                    local_axis_sizes.insert(axis, len);
                }
            }
        }
        // `tiled_reduce` gives up when the new local axis would exceed 2. Candidates here would
        // therefore be no-ops, so don't spend autotuning trials on them.
        if next_local_axis > 2 {
            return (Optimization::TiledReduce { factors: Vec::new() }, 0);
        }
        // The empty product is 1, so a kernel without local indices gets the whole budget.
        // `checked_div` guards against a zero-length local axis.
        let used_threads: u64 = local_axis_sizes.values().product();
        let remaining_threads = dev_info.max_local_threads.checked_div(used_threads).unwrap_or(0);
        // Candidate order is preserved because it determines the order of `factors`.
        let factor_candidates = [32, 16, 8, 64, 128];
        let tree_branch_candidates = [2, 4];
        let mut factors = Vec::new();
        let mut op_id = self.head;
        while !op_id.is_null() {
            if let Op::Loop { len } = self.ops[op_id].op {
                if len >= 256 {
                    for &factor in &factor_candidates {
                        // The candidate must split the loop evenly and leave at least 4
                        // iterations per new local index. There must also be enough spare
                        // local threads for `factor` more.
                        if len.is_multiple_of(factor)
                            && len / factor >= 4
                            && remaining_threads >= factor
                        {
                            for &tree_branch in &tree_branch_candidates {
                                factors.push((op_id, factor, tree_branch));
                            }
                        }
                    }
                }
            }
            op_id = self.next_op(op_id);
        }
        let n = factors.len();
        (Optimization::TiledReduce { factors }, n)
    }

    /// Splits the reduction loop at `loop_start` across `factor` local indices. Their partial
    /// results are then combined with a tree reduction through a `Scope::Local` buffer.
    ///
    /// **Loop split.** A loop of length `L` becomes a new loop `ridx` of `L / factor`
    /// iterations plus a new local index `lidx` of length `factor`. The original loop op is
    /// replaced in place by `ridx * factor + lidx` (`Op::Mad`), so ops that referenced the loop's
    /// id now read that combined index.
    ///
    /// **Reduction.** After the loop, every `lidx` stores its register accumulator into a
    /// `factor`-element local buffer. The buffer is reduced in place with fan-in `tree_branch`,
    /// falling back to fan-in 2 once fewer than `tree_branch` partials remain. A local barrier
    /// follows each step. The partials are combined in a different order than the original
    /// loop, so this presumably relies on the accumulating `bop` being associative and
    /// commutative.
    ///
    /// **Write-back.** The copy of `buffer[0]` back into the register accumulator, and every op
    /// from there to the end of the kernel, is wrapped in `if lidx == 0`.
    ///
    /// The kernel is left unchanged when any of these holds:
    /// - `loop_start` is not a `Loop`.
    /// - The parameters fail the exact-division requirement of the loop split and tree
    ///   reduction: `factor` is not a power of two, `tree_branch` is not a power of two `>= 2`,
    ///   or the loop length is not a multiple of `factor`.
    /// - The kernel already uses local axis 2 or higher.
    /// - The nearest `Define` before the loop is not a writable, single-element `Register`
    ///   variable (the accumulator).
    /// - No load of the accumulator follows the end of the loop.
    /// - Nothing is stored into the accumulator before that load, or the last stored value is
    ///   not a `Binary` op.
    pub fn tiled_reduce(&mut self, loop_start: OpId, factor: u64, tree_branch: u64) {
        let Op::Loop { len } = self.at(loop_start) else {
            return;
        };
        let loop_len = *len;
        // Without these checks:
        // - the loop split would drop iterations;
        // - the tree reduction would drop partial sums;
        // - `factor == 0` or `tree_branch == 0` would divide by zero;
        // - `tree_branch == 1` would panic.
        // Each of those would happen after the kernel was already partly rewritten.
        if !factor.is_power_of_two()
            || tree_branch < 2
            || !tree_branch.is_power_of_two()
            || !loop_len.is_multiple_of(factor)
        {
            return;
        }
        // The new local index goes on the axis one past the highest local axis in use.
        let laxis = self
            .ops
            .values()
            .filter_map(|node| {
                if let Op::Index { scope: Scope::Local, axis, .. } = node.op {
                    Some(axis + 1)
                } else {
                    None
                }
            })
            .max()
            .unwrap_or(0);
        if laxis > 2 {
            return;
        }

        // The accumulator is the nearest `Define` preceding the loop. It must be a writable
        // single-element register variable.
        let mut op_id = self.prev_op(loop_start);
        let (reg_acc, acc_dtype) = loop {
            if op_id.is_null() {
                return;
            }
            if let Op::Define { dtype, scope, ro, len } = self.ops[op_id].op {
                if scope != Scope::Register || ro || len != 1 {
                    return;
                }
                break (op_id, dtype);
            }
            op_id = self.prev_op(op_id);
        };

        // Walk forward from the loop body, with `depth` counting open loops (starting inside
        // `loop_start`). Stop at the first load of the accumulator at depth 0, i.e. after the
        // loop has closed. The last value stored into the accumulator before that load is
        // taken as the reduction op.
        let mut reduce_bop_id = OpId::NULL;
        let mut acc_load_id = None;
        let mut depth = 1;
        let mut op_id = self.next_op(loop_start);
        while !op_id.is_null() {
            match self.ops[op_id].op {
                Op::Store { dst, x, vlen, .. } => {
                    debug_assert_eq!(vlen, 1);
                    if dst == reg_acc {
                        reduce_bop_id = x;
                    }
                }
                Op::Load { src, vlen, .. } if depth == 0 && src == reg_acc => {
                    debug_assert_eq!(vlen, 1);
                    acc_load_id = Some(op_id);
                    break;
                }
                Op::Loop { .. } => depth += 1,
                Op::EndLoop => depth -= 1,
                _ => {}
            }
            op_id = self.next_op(op_id);
        }
        let Some(acc_load_id) = acc_load_id else {
            return;
        };
        if reduce_bop_id.is_null() {
            return;
        }
        let Op::Binary { bop, .. } = self.ops[reduce_bop_id].op else {
            return;
        };

        // All precondition checks are above; from here on the kernel is mutated.

        // Split the loop: `loop_start` now computes `ridx * factor + lidx`.
        let lidx = self.insert_before(reg_acc, Op::Index { len: factor, scope: Scope::Local, axis: laxis });
        let factor_const = self.insert_before(loop_start, Op::Const(Constant::idx(factor)));
        let ridx = self.insert_before(loop_start, Op::Loop { len: loop_len / factor });
        self.ops[loop_start].op = Op::Mad { x: ridx, y: factor_const, z: lidx };

        // Publish each `lidx`'s partial result into a `factor`-element local buffer.
        let loc_acc = self.insert_before(
            acc_load_id,
            Op::Define { dtype: acc_dtype, scope: Scope::Local, ro: false, len: factor },
        );
        let const_zero = self.insert_before(acc_load_id, Op::Const(Constant::idx(0)));
        let x = self.insert_before(acc_load_id, Op::Load { src: reg_acc, index: const_zero, vlen: 1 });
        self.insert_before(acc_load_id, Op::Store { dst: loc_acc, x, index: lidx, vlen: 1 });
        self.insert_before(acc_load_id, Op::Barrier { scope: Scope::Local });

        // Tree reduction. At each step, every `lidx < active_threads` folds in the partials at
        // `lidx + i * active_threads` for `i in 1..branch`. This shrinks the live range from
        // `stride` to `stride / branch` until only index 0 remains. Exact because `stride` and
        // `branch` are both powers of two.
        let mut stride = factor;
        while stride > 1 {
            // Fall back to pairwise combining once fewer than `tree_branch` partials remain.
            let branch = if stride >= tree_branch { tree_branch } else { 2 };
            let active_threads = stride / branch;
            let limit_const = self.insert_before(acc_load_id, Op::Const(Constant::idx(active_threads)));
            let condition = self.insert_before(acc_load_id, Op::Binary { x: lidx, y: limit_const, bop: BOp::Cmplt });
            self.insert_before(acc_load_id, Op::If { condition });
            let mut sum_x = None;
            for i in 1..branch {
                let offset = i * active_threads;
                let offset_const = self.insert_before(acc_load_id, Op::Const(Constant::idx(offset)));
                let offset_idx = self.insert_before(acc_load_id, Op::Binary { x: lidx, y: offset_const, bop: BOp::Add });
                let local_load = self.insert_before(acc_load_id, Op::Load { src: loc_acc, index: offset_idx, vlen: 1 });
                if let Some(prev_sum) = sum_x {
                    sum_x = Some(self.insert_before(acc_load_id, Op::Binary { x: prev_sum, y: local_load, bop }));
                } else {
                    let current_val = self.insert_before(acc_load_id, Op::Load { src: loc_acc, index: lidx, vlen: 1 });
                    sum_x = Some(self.insert_before(acc_load_id, Op::Binary { x: current_val, y: local_load, bop }));
                }
            }
            // `branch >= 2`, so the loop above ran at least once.
            let bop_id = sum_x.unwrap();
            self.insert_before(acc_load_id, Op::Store { dst: loc_acc, x: bop_id, index: lidx, vlen: 1 });
            self.insert_before(acc_load_id, Op::EndIf);
            self.insert_before(acc_load_id, Op::Barrier { scope: Scope::Local });
            stride = active_threads;
        }

        // `lidx == 0` copies the reduced value back into the register accumulator. Everything
        // from here to the kernel's last op is guarded by that same condition.
        let condition = self.insert_before(acc_load_id, Op::Binary { x: lidx, y: const_zero, bop: BOp::Eq });
        self.insert_before(acc_load_id, Op::If { condition });
        let final_val = self.insert_before(acc_load_id, Op::Load { src: loc_acc, index: const_zero, vlen: 1 });
        self.insert_before(
            acc_load_id,
            Op::Store { dst: reg_acc, x: final_val, index: const_zero, vlen: 1 },
        );
        self.insert_after(self.tail, Op::EndIf);
        self.verify();
    }
}