use std::{
cell::RefCell,
collections::{HashMap, HashSet, VecDeque},
ops::{Deref, DerefMut},
rc::Rc,
sync::atomic::{AtomicUsize, Ordering},
};
use cubecl_core::{
ir::{self as core, Branch, Operator, Variable},
CubeDim,
};
use cubecl_core::{
ir::{Item, Operation, Scope},
ExecutionMode,
};
use gvn::GvnPass;
use passes::{
CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform,
EliminateConstBranches, EliminateDeadBlocks, EliminateUnusedVariables, EmptyBranchToSelect,
FindConstSliceLen, InBoundsToUnchecked, InlineAssignments, IntegerRangeAnalysis, MergeBlocks,
MergeSameExpressions, OptimizerPass, ReduceStrength, RemoveIndexScalar,
};
use petgraph::{prelude::StableDiGraph, visit::EdgeRef, Direction};
mod block;
mod control_flow;
mod debug;
mod gvn;
mod instructions;
mod passes;
mod phi_frontiers;
mod version;
pub use block::*;
pub use control_flow::*;
pub use petgraph::graph::{EdgeIndex, NodeIndex};
pub use version::PhiInstruction;
/// A cheaply clonable, shared monotonic counter.
///
/// Clones share the same underlying count (via `Rc`), which lets optimizer
/// passes report how many changes they made through a shared handle while
/// only needing `&self`.
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Creates a counter starting at `val`.
    pub fn new(val: usize) -> Self {
        Self {
            inner: AtomicUsize::new(val).into(),
        }
    }

    /// Increments the counter by one and returns the value it held *before*
    /// the increment.
    pub fn inc(&self) -> usize {
        self.inner.fetch_add(1, Ordering::AcqRel)
    }

    /// Returns the current value of the counter.
    pub fn get(&self) -> usize {
        self.inner.load(Ordering::Acquire)
    }
}
/// Metadata tracked for a slice variable, used by slice-length analysis passes.
#[derive(Debug, Clone)]
pub(crate) struct Slice {
    /// Start index of the slice.
    pub(crate) start: Variable,
    /// End index of the slice.
    pub(crate) end: Variable,
    /// Operation that produced `end`, if known. Set to `None` when the slice is
    /// first registered in `parse_scope`; presumably filled in later by the
    /// `FindConstSliceLen` pass — confirm.
    pub(crate) end_op: Option<Operation>,
    /// Statically known length (`end - start`), available only when both bounds
    /// are compile-time constants.
    pub(crate) const_len: Option<u32>,
}
/// The program being optimized: variable/slice metadata plus the control flow
/// graph of basic blocks.
#[derive(Default, Debug, Clone)]
struct Program {
    /// Item (type) of each local variable, keyed by `(id, depth)`.
    /// Populated in `parse_scope`; cleared after the SSA transform.
    pub variables: HashMap<(u16, u8), Item>,
    /// Slice metadata keyed by the slice variable's `(id, depth)`.
    pub(crate) slices: HashMap<(u16, u8), Slice>,
    /// The control flow graph. `StableDiGraph` keeps node indices valid across
    /// node removals, so passes may delete blocks without invalidating ids.
    pub graph: StableDiGraph<BasicBlock, ()>,
    /// Entry block of the CFG.
    root: NodeIndex,
    /// Integer value ranges per versioned variable; presumably produced by
    /// `IntegerRangeAnalysis` — confirm.
    int_ranges: HashMap<VarId, Range>,
    /// Counter used by `create_temporary` to allocate fresh temporary ids.
    temp_id: AtomicCounter,
}
// Deref to the underlying graph so graph methods (`add_node`, indexing, …) can
// be called directly on `Program`.
impl Deref for Program {
    type Target = StableDiGraph<BasicBlock, ()>;

    fn deref(&self) -> &Self::Target {
        &self.graph
    }
}
// Mutable counterpart of the `Deref` impl above: exposes the graph's mutating
// API directly on `Program`.
impl DerefMut for Program {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.graph
    }
}
/// Identifier of a versioned variable: `(id, depth, version)` — the third
/// component is presumably the SSA version number (see the `version` module);
/// confirm.
type VarId = (u16, u8, u16);
/// Possible integer value range of a variable, as determined by range analysis.
/// A `None` bound means that bound is unknown/unbounded.
#[derive(Default, Clone, Copy, PartialEq, Eq, Debug)]
struct Range {
    lower_bound: Option<i64>,
    upper_bound: Option<i64>,
}
/// An SSA-based IR optimizer. Constructing one with [`Optimizer::new`] parses a
/// scope into a CFG, converts it to SSA form and runs all optimization passes.
#[derive(Debug, Clone)]
pub struct Optimizer {
    /// The program being optimized: variables, slices and the CFG.
    program: Program,
    /// Postorder of the CFG, cached by `determine_postorder`.
    post_order: Vec<NodeIndex>,
    /// Block currently being filled while parsing; `None` once the block has
    /// been terminated by control flow.
    current_block: Option<NodeIndex>,
    /// Break targets of the loops currently being parsed — presumably pushed
    /// and popped by the control-flow parsing code; confirm in `control_flow`.
    loop_break: VecDeque<NodeIndex>,
    /// The single return block of the CFG (see `ret()` for the merge-block
    /// special case).
    pub ret: NodeIndex,
    /// The root scope the optimizer was built from.
    root_scope: Scope,
    /// Cube dimensions of the kernel being optimized.
    pub(crate) cube_dim: CubeDim,
    /// Execution mode; `Checked` enables the bounds-related pass set.
    pub(crate) mode: ExecutionMode,
    /// Shared global-value-numbering pass state.
    pub(crate) gvn: Rc<RefCell<GvnPass>>,
}
impl Default for Optimizer {
    /// Builds an empty optimizer. Only `root_scope` needs an explicit value
    /// (`Scope` has no `Default` impl usable here); every other field takes its
    /// type's default. Field order mirrors the struct declaration.
    fn default() -> Self {
        Self {
            program: Default::default(),
            post_order: Default::default(),
            current_block: Default::default(),
            loop_break: Default::default(),
            ret: Default::default(),
            root_scope: Scope::root(),
            cube_dim: Default::default(),
            mode: Default::default(),
            gvn: Default::default(),
        }
    }
}
impl Optimizer {
    /// Creates an optimizer from an expanded scope and immediately runs the
    /// full optimization pipeline on it.
    pub fn new(expand: Scope, cube_dim: CubeDim, mode: ExecutionMode) -> Self {
        let mut opt = Self {
            root_scope: expand.clone(),
            cube_dim,
            mode,
            ..Default::default()
        };
        opt.run_opt(expand);
        opt
    }

    /// Runs the whole pipeline: CFG construction, liveness, SSA transform and
    /// the pre-/post-SSA pass fixpoints. The ordering of these steps is
    /// significant — do not reorder.
    fn run_opt(&mut self, expand: Scope) {
        self.parse_graph(expand);
        self.split_critical_edges();
        self.determine_postorder(self.entry(), &mut HashSet::new());
        self.analyze_liveness();
        self.apply_pre_ssa_passes();
        self.exempt_index_assign_locals();
        self.ssa_transform();
        self.apply_post_ssa_passes();

        // Array copy propagation can introduce new local variables, so when it
        // reports any change, liveness and the SSA transform must be redone
        // before running the post-SSA fixpoint again.
        let arrays_prop = AtomicCounter::new(0);
        CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone());
        if arrays_prop.get() > 0 {
            self.analyze_liveness();
            self.ssa_transform();
            self.apply_post_ssa_passes();
        }

        // Global value numbering, strength reduction and copy transforms share
        // one change counter; if any of them did something, run the post-SSA
        // fixpoint once more to clean up after them.
        let gvn_count = AtomicCounter::new(0);
        let gvn = self.gvn.clone();
        gvn.borrow_mut().apply_post_ssa(self, gvn_count.clone());
        ReduceStrength.apply_post_ssa(self, gvn_count.clone());
        CopyTransform.apply_post_ssa(self, gvn_count.clone());
        if gvn_count.get() > 0 {
            self.apply_post_ssa_passes();
        }

        // Final cleanup: merge straight-line block chains.
        MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
    }

    /// Entry block of the control flow graph.
    pub fn entry(&self) -> NodeIndex {
        self.program.root
    }

    /// Builds the CFG from the root scope: a fresh entry block, a dedicated
    /// return block, and whatever blocks `parse_scope` creates in between.
    fn parse_graph(&mut self, scope: Scope) {
        let entry = self.program.add_node(BasicBlock::default());
        self.program.root = entry;
        self.current_block = Some(entry);
        self.ret = self.program.add_node(BasicBlock::default());
        *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
        self.parse_scope(scope);
        // If parsing did not end with a terminator, fall through to the return
        // block.
        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, self.ret, ());
        }
    }

    /// Depth-first walk recording blocks in postorder (successors before the
    /// block itself).
    /// NOTE(review): the initial `block` is never inserted into `visited`, so
    /// this assumes the entry block is not reachable as a successor of any
    /// other block — confirm against how back edges are built.
    fn determine_postorder(&mut self, block: NodeIndex, visited: &mut HashSet<NodeIndex>) {
        for successor in self.successors(block) {
            if !visited.contains(&successor) {
                visited.insert(successor);
                self.determine_postorder(successor, visited);
            }
        }
        self.post_order.push(block);
    }

    /// Cached postorder of the CFG (cloned).
    pub fn post_order(&self) -> Vec<NodeIndex> {
        self.post_order.clone()
    }

    /// Reverse postorder of the CFG — the usual iteration order for forward
    /// dataflow analyses.
    pub fn reverse_post_order(&self) -> Vec<NodeIndex> {
        self.post_order.iter().rev().copied().collect()
    }

    /// Runs the pre-SSA passes to a fixpoint: repeat until a full sweep
    /// reports zero changes through the shared counter.
    fn apply_pre_ssa_passes(&mut self) {
        let mut passes = vec![CompositeMerge];
        loop {
            let counter = AtomicCounter::default();
            for pass in &mut passes {
                pass.apply_pre_ssa(self, counter.clone());
            }
            if counter.get() == 0 {
                break;
            }
        }
    }

    /// Runs the post-SSA passes to a fixpoint. In `Checked` execution mode the
    /// bounds-analysis passes are appended so provably in-bounds accesses can
    /// be rewritten to unchecked ones.
    fn apply_post_ssa_passes(&mut self) {
        let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
            Box::new(InlineAssignments),
            Box::new(EliminateUnusedVariables),
            Box::new(ConstOperandSimplify),
            Box::new(MergeSameExpressions),
            Box::new(ConstEval),
            Box::new(RemoveIndexScalar),
            Box::new(EliminateConstBranches),
            Box::new(EmptyBranchToSelect),
            Box::new(EliminateDeadBlocks),
            Box::new(MergeBlocks),
        ];
        let checked_passes: Vec<Box<dyn OptimizerPass>> = vec![
            Box::new(IntegerRangeAnalysis),
            Box::new(FindConstSliceLen),
            Box::new(InBoundsToUnchecked),
        ];
        if matches!(self.mode, ExecutionMode::Checked) {
            passes.extend(checked_passes);
        }
        loop {
            let counter = AtomicCounter::default();
            for pass in &mut passes {
                pass.apply_post_ssa(self, counter.clone());
            }
            if counter.get() == 0 {
                break;
            }
        }
    }

    /// Removes locals that are written through `IndexAssign` from the SSA
    /// variable set, so they are not versioned by the SSA transform —
    /// presumably because an indexed write is a partial write that SSA
    /// renaming can't model; confirm.
    fn exempt_index_assign_locals(&mut self) {
        for node in self.node_ids() {
            let ops = self.program[node].ops.clone();
            for op in ops.borrow().values() {
                if let Operation::Operator(Operator::IndexAssign(binop)) = op {
                    if let Variable::Local { id, depth, .. } = &binop.out {
                        self.program.variables.remove(&(*id, *depth));
                    }
                }
            }
        }
    }

    /// All node indices currently in the CFG (collected so callers can mutate
    /// the graph while iterating).
    fn node_ids(&self) -> Vec<NodeIndex> {
        self.program.node_indices().collect()
    }

    /// Converts the program to SSA form: dominance frontiers, phi placement,
    /// then variable versioning. Afterwards the per-block bookkeeping used by
    /// the transform is cleared.
    fn ssa_transform(&mut self) {
        self.program.fill_dom_frontiers();
        self.program.place_phi_nodes();
        self.version_program();
        self.program.variables.clear();
        for block in self.node_ids() {
            self.program[block].writes.clear();
            self.program[block].dom_frontiers.clear();
        }
    }

    /// Mutable access to the block currently being parsed.
    /// Panics if no block is active (`current_block` is `None`).
    pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
        &mut self.program[self.current_block.unwrap()]
    }

    /// Predecessors of `block` in the CFG.
    pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
        self.program
            .edges_directed(block, Direction::Incoming)
            .map(|it| it.source())
            .collect()
    }

    /// Successors of `block` in the CFG.
    pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
        self.program
            .edges_directed(block, Direction::Outgoing)
            .map(|it| it.target())
            .collect()
    }

    /// Shared reference to `block`. Panics if the index is stale.
    #[track_caller]
    pub fn block(&self, block: NodeIndex) -> &BasicBlock {
        &self.program[block]
    }

    /// Mutable reference to `block`. Panics if the index is stale.
    #[track_caller]
    pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
        &mut self.program[block]
    }

    /// Parses a scope's operations into the current block, dispatching control
    /// flow to `parse_control_flow` and registering slice metadata. Returns
    /// `true` if the scope contains a `Break` branch.
    pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
        let processed = scope.process();

        // Record the item (type) of every local so the SSA transform knows
        // which variables to version.
        for var in processed.variables {
            if let Variable::Local { id, item, depth } = var {
                self.program.variables.insert((id, depth), item);
            }
        }

        let is_break = processed
            .operations
            .contains(&Operation::Branch(Branch::Break));

        for instruction in processed.operations {
            match instruction {
                Operation::Branch(branch) => self.parse_control_flow(branch),
                Operation::Operator(Operator::Slice(slice_op)) => {
                    let out_id = match &slice_op.out {
                        Variable::Slice { id, depth, .. } => (*id, *depth),
                        _ => unreachable!(),
                    };
                    // Length is statically known only when both bounds are
                    // constants. NOTE(review): `end - start` underflows if
                    // `start > end`; assumes well-formed slices — confirm.
                    let const_len = slice_op.start.as_const().zip(slice_op.end.as_const());
                    let const_len = const_len.map(|(start, end)| end.as_u32() - start.as_u32());
                    self.program.slices.insert(
                        out_id,
                        Slice {
                            start: slice_op.start,
                            end: slice_op.end,
                            end_op: None,
                            const_len,
                        },
                    );
                    let mut op = Operation::Operator(Operator::Slice(slice_op));
                    self.visit_operation(&mut op, |_, _| {}, |opt, var| opt.write_var(var));
                    self.current_block_mut().ops.borrow_mut().push(op);
                }
                mut other => {
                    // Record writes for liveness/SSA, then append the op.
                    self.visit_operation(&mut other, |_, _| {}, |opt, var| opt.write_var(var));
                    self.current_block_mut().ops.borrow_mut().push(other);
                }
            }
        }

        is_break
    }

    /// Returns the `(id, depth)` key of a non-atomic local variable, or `None`
    /// for anything else (atomics are excluded from SSA versioning).
    pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<(u16, u8)> {
        match variable {
            core::Variable::Local { id, depth, item } if !item.elem.is_atomic() => {
                Some((*id, *depth))
            }
            _ => None,
        }
    }

    /// Creates a fresh temporary local binding. Ids are allocated downward
    /// from `u16::MAX` at depth `u8::MAX` so they cannot collide with ids
    /// assigned during scope expansion.
    pub fn create_temporary(&self, item: Item) -> Variable {
        let next_id = self.program.temp_id.inc() as u16;
        Variable::LocalBinding {
            id: u16::MAX - next_id,
            item,
            depth: u8::MAX,
        }
    }

    /// Returns the block to jump to for a `return`. If the current return
    /// block is already used as a merge block, a fresh block is spliced in
    /// front of it (and becomes the new `ret`) so the two roles stay separate.
    pub(crate) fn ret(&mut self) -> NodeIndex {
        if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
            let new_ret = self.program.add_node(BasicBlock::default());
            self.program.add_edge(new_ret, self.ret, ());
            self.ret = new_ret;
            new_ret
        } else {
            self.ret
        }
    }
}
/// No-op variable visitor, for `visit_operation` callers that only care about
/// reads or only about writes.
pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}
#[cfg(test)]
mod test {
    use cubecl_core::{
        self as cubecl,
        ir::{Elem, HybridAllocator, Item, Variable},
        prelude::{Array, CubeContext, ExpandElement},
    };
    use cubecl_core::{cube, CubeDim, ExecutionMode};

    use crate::Optimizer;

    /// Kernel with a conditionally assigned variable; `x + 4` appears on two
    /// paths, giving the optimizer redundant expressions and a phi to resolve.
    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    /// Expands `pre_kernel` and runs the optimizer on it, printing the result.
    /// Only a smoke test: there is currently no way to assert that a given
    /// optimization was actually applied, hence the `#[ignore]`.
    #[test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        let mut ctx = CubeContext::root(HybridAllocator::default());
        // Hand-built kernel arguments: two scalar inputs and one output array.
        let x = ExpandElement::Plain(Variable::GlobalScalar {
            id: 0,
            elem: Elem::UInt,
        });
        let cond = ExpandElement::Plain(Variable::GlobalScalar {
            id: 1,
            elem: Elem::UInt,
        });
        let arr = ExpandElement::Plain(Variable::GlobalOutputArray {
            id: 0,
            item: Item::new(Elem::UInt),
        });

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        let scope = ctx.into_scope();
        let opt = Optimizer::new(scope, CubeDim::default(), ExecutionMode::Checked);
        println!("{opt}")
    }
}