#![allow(unknown_lints, unnecessary_transmutes)]
use std::{
collections::{HashMap, VecDeque},
ops::{Deref, DerefMut},
rc::Rc,
sync::atomic::{AtomicUsize, Ordering},
};
use analyses::{AnalysisCache, dominance::DomFrontiers, liveness::Liveness, writes::Writes};
use cubecl_core::CubeDim;
use cubecl_ir::{
self as core, Allocator, Branch, Id, Operation, Operator, Processor, Scope, Type, Variable,
VariableKind,
};
use gvn::GvnPass;
use passes::{
CompositeMerge, ConstEval, ConstOperandSimplify, CopyTransform, DisaggregateArray,
EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
EmptyBranchToSelect, InlineAssignments, MergeBlocks, MergeSameExpressions, OptimizerPass,
ReduceStrength, RemoveIndexScalar,
};
use petgraph::{
Direction,
dot::{Config, Dot},
prelude::StableDiGraph,
visit::EdgeRef,
};
mod analyses;
mod block;
mod control_flow;
mod debug;
mod gvn;
mod instructions;
mod passes;
mod phi_frontiers;
mod transformers;
mod version;
pub use analyses::uniformity::Uniformity;
pub use block::*;
pub use control_flow::*;
pub use petgraph::graph::{EdgeIndex, NodeIndex};
pub use transformers::*;
pub use version::PhiInstruction;
pub use crate::analyses::liveness::shared::{SharedLiveness, SharedMemory};
/// A counter handle whose clones all read and update the same underlying value.
///
/// The value lives in an `Rc<AtomicUsize>`, so cloning the handle is cheap and
/// every clone observes increments made through any other clone. Because the
/// storage is behind an `Rc`, the handle itself is neither `Send` nor `Sync`.
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Creates a counter whose initial value is `val`.
    pub fn new(val: usize) -> Self {
        let inner = Rc::new(AtomicUsize::new(val));
        Self { inner }
    }

    /// Adds one to the counter and returns the value it held *before* the
    /// increment.
    pub fn inc(&self) -> usize {
        let cell: &AtomicUsize = &self.inner;
        cell.fetch_add(1, Ordering::AcqRel)
    }

    /// Returns the current value of the counter.
    pub fn get(&self) -> usize {
        let cell: &AtomicUsize = &self.inner;
        cell.load(Ordering::Acquire)
    }
}
/// A constant array collected from a scope while it is parsed into the graph
/// (entries are created in `Optimizer::parse_scope`).
#[derive(Debug, Clone)]
pub struct ConstArray {
/// Id taken from the `VariableKind::ConstantArray` variable that declares the array.
pub id: Id,
/// Element count. `parse_scope` stores the declared length multiplied by the
/// variable's unroll factor.
pub length: usize,
/// Type of the array variable.
pub item: Type,
/// Values paired with the array variable in the scope's `const_arrays` list.
pub values: Vec<core::Variable>,
}
/// The program being optimized: a graph of basic blocks plus the side tables
/// that are filled in while scopes are parsed.
#[derive(Default, Debug, Clone)]
pub struct Program {
/// Constant arrays registered by `Optimizer::parse_scope`.
pub const_arrays: Vec<ConstArray>,
/// Types of the `LocalMut` variables reported while parsing, keyed by id.
/// Entries are removed by `exempt_index_assign_locals` and the whole map is
/// cleared at the end of `ssa_transform`.
pub variables: HashMap<Id, Type>,
/// Block graph; nodes are `BasicBlock`s and edges carry a `u32` weight.
pub graph: StableDiGraph<BasicBlock, u32>,
// Entry block, set by `Optimizer::parse_graph` and exposed through `Optimizer::entry`.
root: NodeIndex,
}
// Lets graph methods (`add_node`, `edges_directed`, indexing by `NodeIndex`, ...)
// be called directly on a `Program`.
impl Deref for Program {
type Target = StableDiGraph<BasicBlock, u32>;
/// Borrows the underlying block graph.
fn deref(&self) -> &Self::Target {
&self.graph
}
}
// Mutable counterpart of the `Deref` impl: graph-mutating methods can be
// called directly on a `Program`.
impl DerefMut for Program {
/// Mutably borrows the underlying block graph.
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.graph
}
}
// NOTE(review): pairs an `Id` with a `u16`; presumably a variable id plus a
// version number — confirm against the modules that use this alias.
type VarId = (Id, u16);
/// Holds the program graph being optimized together with the state needed to
/// build it from a `Scope` and to run optimization passes over it.
#[derive(Debug, Clone)]
pub struct Optimizer {
/// The block graph and its side tables.
pub program: Program,
/// Allocator handle; the constructors clone it from the input scope.
pub allocator: Allocator,
// Cache of analyses; shared between clones of the optimizer through the `Rc`.
analysis_cache: Rc<AnalysisCache>,
// Block that parsed instructions are currently appended to, if any.
current_block: Option<NodeIndex>,
// NOTE(review): not used in this file; presumably the stack of loop-exit
// blocks targeted by `break` during control-flow parsing — confirm.
loop_break: VecDeque<NodeIndex>,
/// Current return block. `parse_graph` creates it with `ControlFlow::Return`;
/// the `ret()` method may replace it with a fresh block.
pub ret: NodeIndex,
/// The scope the optimizer was constructed from.
pub root_scope: Scope,
// Cube dimensions supplied to the constructor.
pub(crate) cube_dim: CubeDim,
// Transformers given a chance to replace or remove each instruction in `parse_scope`.
pub(crate) transformers: Vec<Rc<dyn IrTransformer>>,
// Processors handed to `Scope::process` in `parse_scope`.
pub(crate) processors: Rc<Vec<Box<dyn Processor>>>,
}
// NOTE(review): `Optimizer` stores `Rc` handles (`analysis_cache`, `transformers`,
// `processors`), and `Rc` is neither `Send` nor `Sync`. These impls override the
// compiler's conclusion by hand, and nothing in this file states the invariant
// that makes that sound (e.g. that no `Rc` clone is ever reachable from two
// threads at once). Confirm the invariant and record it here as a `SAFETY:` note.
unsafe impl Send for Optimizer {}
unsafe impl Sync for Optimizer {}
impl Default for Optimizer {
fn default() -> Self {
Self {
program: Default::default(),
allocator: Default::default(),
current_block: Default::default(),
loop_break: Default::default(),
ret: Default::default(),
root_scope: Scope::root(false),
cube_dim: CubeDim::new_1d(1),
analysis_cache: Default::default(),
transformers: Default::default(),
processors: Default::default(),
}
}
}
impl Optimizer {
/// Builds an optimizer for `expand` and immediately runs the full
/// optimization pipeline (`run_opt`) on it.
///
/// * `expand` - scope to optimize; it becomes the optimizer's `root_scope`.
/// * `cube_dim` - cube dimensions stored on the optimizer.
/// * `transformers` - instruction transformers consulted while parsing.
/// * `processors` - scope processors applied while parsing.
pub fn new(
    expand: Scope,
    cube_dim: CubeDim,
    transformers: Vec<Rc<dyn IrTransformer>>,
    processors: Vec<Box<dyn Processor>>,
) -> Self {
    // Take the allocator handle first so the owned scope can then be moved
    // into `root_scope` instead of being cloned wholesale and dropped.
    let allocator = expand.allocator.clone();
    let mut opt = Self {
        root_scope: expand,
        cube_dim,
        allocator,
        transformers,
        processors: Rc::new(processors),
        ..Default::default()
    };
    opt.run_opt();
    opt
}
/// Builds an optimizer for `expand` without transformers or processors and
/// runs only the reduced pipeline in `run_shared_only`, which ends with the
/// `SharedLiveness` analysis instead of the optimization passes.
///
/// * `expand` - scope to analyze; it becomes the optimizer's `root_scope`.
/// * `cube_dim` - cube dimensions stored on the optimizer.
pub fn shared_only(expand: Scope, cube_dim: CubeDim) -> Self {
    // Take the allocator handle first so the owned scope can then be moved
    // into `root_scope` instead of being cloned wholesale and dropped.
    let allocator = expand.allocator.clone();
    let mut opt = Self {
        root_scope: expand,
        cube_dim,
        allocator,
        transformers: Vec::new(),
        processors: Rc::new(Vec::new()),
        ..Default::default()
    };
    opt.run_shared_only();
    opt
}
/// Runs the full pipeline: graph construction, SSA construction, the iterated
/// post-SSA passes, array disaggregation, GVN / strength reduction / copy
/// transform, and finally block merging. The order of the steps matters.
fn run_opt(&mut self) {
// Build the block graph from a copy of the root scope, then prepare it for SSA.
self.parse_graph(self.root_scope.clone());
self.split_critical_edges();
self.transform_ssa_and_merge_composites();
self.apply_post_ssa_passes();
// Array disaggregation reports how much it changed through `arrays_prop`.
let arrays_prop = AtomicCounter::new(0);
log::debug!("Applying {}", DisaggregateArray.name());
DisaggregateArray.apply_post_ssa(self, arrays_prop.clone());
// If anything was disaggregated, liveness is invalidated and SSA construction
// plus the post-SSA passes are run again.
if arrays_prop.get() > 0 {
self.invalidate_analysis::<Liveness>();
self.ssa_transform();
self.apply_post_ssa_passes();
}
// The next three passes share one change counter.
let gvn_count = AtomicCounter::new(0);
log::debug!("Applying {}", GvnPass.name());
GvnPass.apply_post_ssa(self, gvn_count.clone());
log::debug!("Applying {}", ReduceStrength.name());
ReduceStrength.apply_post_ssa(self, gvn_count.clone());
log::debug!("Applying {}", CopyTransform.name());
CopyTransform.apply_post_ssa(self, gvn_count.clone());
// Any change made by those passes triggers another round of post-SSA passes.
if gvn_count.get() > 0 {
self.apply_post_ssa_passes();
}
self.split_free();
// The analysis result is discarded here.
// NOTE(review): presumably called so the result is stored in the analysis cache — confirm.
self.analysis::<SharedLiveness>();
log::debug!("Applying {}", MergeBlocks.name());
// The change count of the final block merge is not inspected.
MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
}
/// Reduced pipeline used by `shared_only`: builds the graph, converts it to
/// SSA form and requests the `SharedLiveness` analysis, without running any
/// of the optimization passes.
fn run_shared_only(&mut self) {
    // Parse from a copy so `root_scope` stays intact on the optimizer.
    let scope = self.root_scope.clone();
    self.parse_graph(scope);

    self.split_critical_edges();
    self.transform_ssa_and_merge_composites();
    self.split_free();

    // The analysis result itself is not used here.
    self.analysis::<SharedLiveness>();
}
/// Returns the entry block of the program (the first node created by `parse_graph`).
pub fn entry(&self) -> NodeIndex {
self.program.root
}
/// Builds the block graph for `scope`.
///
/// Creates the entry block (which becomes both the program root and the
/// current block) and a return block marked `ControlFlow::Return`, parses the
/// scope, and, if a block is still open afterwards, links it to the return
/// block. Structure-dependent analyses are invalidated at the end.
fn parse_graph(&mut self, scope: Scope) {
    let root = self.program.add_node(BasicBlock::default());
    self.program.root = root;
    self.current_block = Some(root);

    let ret = self.program.add_node(BasicBlock::default());
    *self.program[ret].control_flow.borrow_mut() = ControlFlow::Return;
    self.ret = ret;

    self.parse_scope(scope);

    // `self.ret` is read only now because parsing may have replaced it.
    if let Some(open_block) = self.current_block {
        let target = self.ret;
        self.program.add_edge(open_block, target, 0);
    }
    self.invalidate_structure();
}
/// Runs the fixed list of post-SSA passes over the program repeatedly until a
/// complete round finishes without any pass reporting a change.
fn apply_post_ssa_passes(&mut self) {
    let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
        Box::new(InlineAssignments),
        Box::new(EliminateUnusedVariables),
        Box::new(ConstOperandSimplify),
        Box::new(MergeSameExpressions),
        Box::new(ConstEval),
        Box::new(RemoveIndexScalar),
        Box::new(EliminateConstBranches),
        Box::new(EmptyBranchToSelect),
        Box::new(EliminateDeadBlocks),
        Box::new(EliminateDeadPhi),
    ];

    log::debug!("Applying post-SSA passes");

    // At least one round always runs; another follows while changes are reported.
    let mut changed = true;
    while changed {
        // Fresh counter per round, shared by every pass in that round.
        let round_changes = AtomicCounter::default();
        for pass in passes.iter_mut() {
            log::debug!("Applying {}", pass.name());
            pass.apply_post_ssa(self, round_changes.clone());
        }
        changed = round_changes.get() != 0;
    }
}
fn exempt_index_assign_locals(&mut self) {
for node in self.node_ids() {
let ops = self.program[node].ops.clone();
for op in ops.borrow().values() {
if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation
&& let VariableKind::LocalMut { id } = &op.out().kind
{
self.program.variables.remove(id);
}
}
}
}
/// Returns the indices of all nodes currently in the program graph, collected
/// into a `Vec` so the graph can be mutated while the result is iterated.
pub fn node_ids(&self) -> Vec<NodeIndex> {
self.program.node_indices().collect()
}
/// Converts the program to SSA form, then alternates the `CompositeMerge`
/// pass with a fresh SSA conversion until the pass reports no more changes.
fn transform_ssa_and_merge_composites(&mut self) {
    self.exempt_index_assign_locals();
    self.ssa_transform();

    loop {
        let merged = AtomicCounter::new(0);
        CompositeMerge.apply_post_ssa(self, merged.clone());
        if merged.get() == 0 {
            break;
        }
        // The merge changed something: redo the SSA conversion before trying again.
        self.exempt_index_assign_locals();
        self.ssa_transform();
    }
}
/// Performs one SSA conversion: places phi nodes, versions the program, then
/// clears the table of tracked variables and invalidates the `Writes` and
/// `DomFrontiers` analyses.
fn ssa_transform(&mut self) {
self.place_phi_nodes();
self.version_program();
// The tracked-variable table is emptied once versioning is done.
self.program.variables.clear();
self.invalidate_analysis::<Writes>();
self.invalidate_analysis::<DomFrontiers>();
}
/// Returns a mutable reference to the block instructions are currently being
/// appended to.
///
/// # Panics
///
/// Panics if there is no current block (`current_block` is `None`).
pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
&mut self.program[self.current_block.unwrap()]
}
/// Returns the source block of every edge entering `block`, skipping sources
/// whose control flow is `ControlFlow::Unreachable`.
pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
    let incoming = self.program.edges_directed(block, Direction::Incoming);
    incoming
        .filter_map(|edge| {
            let source = edge.source();
            (!self.is_unreachable(source)).then_some(source)
        })
        .collect()
}
/// Returns the target block of every edge leaving `block`. Unlike
/// `predecessors`, no filtering on reachability is applied.
pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
    let outgoing = self.program.edges_directed(block, Direction::Outgoing);
    outgoing.map(|edge| edge.target()).collect()
}
/// Returns the basic block stored at `block`.
///
/// Indexes the graph directly; `#[track_caller]` is applied so that a panic
/// raised by the lookup is attributed to the caller of this method.
#[track_caller]
pub fn block(&self, block: NodeIndex) -> &BasicBlock {
&self.program[block]
}
/// Returns a mutable reference to the basic block stored at `block`.
///
/// Indexes the graph directly; `#[track_caller]` is applied so that a panic
/// raised by the lookup is attributed to the caller of this method.
#[track_caller]
pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
&mut self.program[block]
}
/// Returns `true` when the control flow of `block` is `ControlFlow::Unreachable`.
pub fn is_unreachable(&self, block: NodeIndex) -> bool {
    matches!(
        *self.program[block].control_flow.borrow(),
        ControlFlow::Unreachable
    )
}
/// Lowers `scope` into the block graph, appending to the current block.
///
/// Runs the configured processors over the scope, records every `LocalMut`
/// variable they report in `program.variables`, registers the scope's constant
/// arrays, and then handles the processed instructions one by one: each is
/// first offered to the transformers, and if none replaces or removes it, it is
/// either passed to `parse_control_flow` (branches) or pushed onto the current
/// block (everything else).
///
/// Returns `true` when the processed instruction list contains a
/// `Branch::Break`. The check is made on the processed list before any
/// transformer or control-flow handling runs.
///
/// # Panics
///
/// Panics if an entry of `scope.const_arrays` is not a
/// `VariableKind::ConstantArray`, or if an instruction has to be appended while
/// there is no current block.
pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
let processed = scope.process(self.processors.iter().map(|it| &**it));
// Track the type of every mutable local reported by processing.
for var in processed.variables {
if let VariableKind::LocalMut { id } = var.kind {
self.program.variables.insert(id, var.ty);
}
}
for (var, values) in scope.const_arrays.clone() {
let VariableKind::ConstantArray {
id,
length,
unroll_factor,
} = var.kind
else {
unreachable!()
};
self.program.const_arrays.push(ConstArray {
id,
// Stored length is the declared length scaled by the unroll factor.
length: length * unroll_factor,
item: var.ty,
values,
});
}
let is_break = processed.instructions.contains(&Branch::Break.into());
for mut instruction in processed.instructions {
let mut removed = false;
// The first transformer that does not answer `Ignore` decides the
// instruction's fate. The `break` right after `current_block_mut()` ends the
// iteration over `self.transformers` on that path, which is what lets the
// mutable use of `self` coexist with the loop.
for transform in self.transformers.iter() {
match transform.maybe_transform(&mut scope, &instruction) {
TransformAction::Ignore => {}
TransformAction::Replace(replacement) => {
self.current_block_mut()
.ops
.borrow_mut()
.extend(replacement);
removed = true;
break;
}
TransformAction::Remove => {
removed = true;
break;
}
}
}
// Replaced or removed instructions are not processed any further.
if removed {
continue;
}
match &mut instruction.operation {
Operation::Branch(branch) => match self.parse_control_flow(branch.clone()) {
ControlFlowAction::None => {}
// The remaining instructions of this scope are skipped.
ControlFlowAction::AbortBlock => {
break;
}
},
_ => {
self.current_block_mut().ops.borrow_mut().push(instruction);
}
}
}
is_break
}
/// Returns the id of `variable` when it is a `LocalMut` whose type is not
/// atomic, and `None` for every other variable.
pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
    let core::VariableKind::LocalMut { id } = variable.kind else {
        return None;
    };
    // The atomic check is only reached for mutable locals.
    (!variable.ty.is_atomic()).then_some(id)
}
/// Returns the block to use as the return target.
///
/// If the current return block is marked `BlockUse::Merge`, a new empty block
/// is created with an edge to the old return block, becomes the new `self.ret`
/// (invalidating structure-dependent analyses), and is returned. Otherwise the
/// existing return block is returned unchanged.
pub(crate) fn ret(&mut self) -> NodeIndex {
    let used_as_merge = self.program[self.ret]
        .block_use
        .contains(&BlockUse::Merge);
    if !used_as_merge {
        return self.ret;
    }

    let fresh = self.program.add_node(BasicBlock::default());
    let previous = self.ret;
    self.program.add_edge(fresh, previous, 0);
    self.ret = fresh;
    self.invalidate_structure();
    fresh
}
/// Returns a copy of the constant arrays registered on the program.
pub fn const_arrays(&self) -> Vec<ConstArray> {
self.program.const_arrays.clone()
}
/// Returns a petgraph `Dot` wrapper over the program graph, configured with
/// `Config::EdgeNoLabel` so edge weights are left out of the output.
pub fn dot_viz(&self) -> Dot<'_, &StableDiGraph<BasicBlock, u32>> {
Dot::with_config(&self.program, &[Config::EdgeNoLabel])
}
}
/// Variable visitor that does nothing; both arguments are ignored.
pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}
#[cfg(test)]
mod test {
// `cubecl_core` is aliased as `cubecl`.
// NOTE(review): presumably because the `#[cube]` macro expands to `cubecl::…` paths — confirm.
use cubecl_core as cubecl;
use cubecl_core::cube;
use cubecl_core::prelude::*;
use cubecl_ir::{ElemType, ManagedVariable, Type, UIntKind, Variable, VariableKind};
use crate::Optimizer;
// Kernel in which `x + 4` is computed both inside the conditional and
// unconditionally after it.
#[allow(unused)]
#[cube(launch)]
fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
let mut y = 0;
let mut z = 0;
if cond == 0 {
y = x + 4;
}
z = x + 4;
out[0] = y;
out[1] = z;
}
// Expands `pre_kernel` into a root scope, runs the optimizer on it and prints
// the result. Ignored by default: it makes no assertions.
#[test_log::test]
#[ignore = "no good way to assert opt is applied"]
fn test_pre() {
let mut ctx = Scope::root(false);
// Two `u32` global scalars (`x`, `cond`) and one `u32` global output array.
let x = ManagedVariable::Plain(Variable::new(
VariableKind::GlobalScalar(0),
Type::scalar(ElemType::UInt(UIntKind::U32)),
));
let cond = ManagedVariable::Plain(Variable::new(
VariableKind::GlobalScalar(1),
Type::scalar(ElemType::UInt(UIntKind::U32)),
));
let arr = ManagedVariable::Plain(Variable::new(
VariableKind::GlobalOutputArray(0),
Type::scalar(ElemType::UInt(UIntKind::U32)),
));
pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
let opt = Optimizer::new(ctx, CubeDim::new_1d(1), vec![], vec![]);
println!("{opt}")
}
}