use std::{collections::HashMap, mem::take};
use cubecl_ir::{
Id, IndexAssignOperator, IndexOperator, Instruction, Operation, Operator, Type, Variable,
VariableKind, VectorInitOperator,
};
use stable_vec::StableVec;
use crate::{AtomicCounter, Optimizer};
use super::OptimizerPass;
/// Optimizer pass that folds lane-by-lane writes to a mutable local into whole-value writes.
///
/// Within each block, constant-index `IndexAssign` instructions that target the same
/// `LocalMut` variable and store immutable values are collected. Once every lane of the
/// vector has a pending value, the collected instructions are replaced by one `InitVector`
/// instruction. A constant-index assignment into a value whose vector size is not greater
/// than 1 is rewritten as a plain `Copy`.
pub struct CompositeMerge;

impl OptimizerPass for CompositeMerge {
    /// Runs the merge over every block of `opt.program`.
    ///
    /// `changes` is incremented once for each rewrite (merge or scalar copy) performed.
    ///
    /// # Panics
    ///
    /// Panics if a constant index other than 0 is assigned into a scalar-sized variable.
    fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) {
        let blocks = opt.node_ids();

        for block in blocks {
            // Pending lane assignments per variable: (op slot, lane index, stored value).
            let mut assigns = HashMap::<Id, Vec<(usize, u32, Variable)>>::new();

            let ops = opt.program[block].ops.clone();
            // Snapshot the slots up front: the op list is mutated while we walk it.
            let indices = { ops.borrow().indices().collect::<Vec<_>>() };

            for idx in indices {
                // Drop pending assignments for every local variable the visitor reports for
                // this instruction (presumably the variables it reads; confirm against
                // `visit_operation`). The borrow is scoped so it ends before any rewrite.
                {
                    let op = &mut ops.borrow_mut()[idx];
                    opt.visit_operation(&mut op.operation, &mut op.out, |opt, var| {
                        if let Some(id) = opt.local_variable_id(var) {
                            assigns.remove(&id);
                        }
                    });
                }

                let op = { ops.borrow()[idx].clone() };
                let Some(out) = op.out else { continue };
                let VariableKind::LocalMut { id } = out.kind else {
                    continue;
                };
                let item = out.ty;

                // Only constant-index assignments of immutable values can be merged.
                let candidate = match op.operation {
                    Operation::Operator(Operator::IndexAssign(IndexAssignOperator {
                        index,
                        value,
                        ..
                    })) if value.is_immutable() => {
                        index.as_const().map(|lane| (lane.as_u32(), value))
                    }
                    _ => None,
                };
                let Some((index, value)) = candidate else {
                    // Any other write to this variable makes the pending lane values stale:
                    // merging them later would override this write.
                    assigns.remove(&id);
                    continue;
                };

                let vector_size = item.vector_size();
                if vector_size > 1 {
                    // An out-of-range lane must not count towards a complete vector.
                    if !usize::try_from(index).is_ok_and(|lane| lane < vector_size) {
                        assigns.remove(&id);
                        continue;
                    }

                    let pending = assigns.entry(id).or_default();
                    // Keep one entry per lane, latest write wins. The superseded instruction
                    // stays in place as a dead store ahead of the eventual merge.
                    if let Some(pos) = pending.iter().position(|it| it.1 == index) {
                        pending[pos] = (idx, index, value);
                    } else {
                        pending.push((idx, index, value));
                    }

                    // Lanes are distinct and in range, so a full list covers every lane.
                    if pending.len() == vector_size {
                        merge_assigns(
                            &mut opt.program[block].ops.borrow_mut(),
                            take(pending),
                            id,
                            item,
                        );
                        opt.program.variables.insert(id, item);
                        changes.inc();
                    }
                } else {
                    assert_eq!(index, 0, "Can't index into scalar {}", out);
                    opt.program[block].ops.borrow_mut()[idx] = Instruction::new(
                        Operation::Copy(value),
                        Variable::new(VariableKind::LocalMut { id }, item),
                    );
                    opt.program.variables.insert(id, item);
                    changes.inc();
                }
            }
        }
    }
}
/// Replaces a set of per-lane index assignments with one `InitVector` instruction.
///
/// Each entry of `assigns` is `(op slot, lane index, value)`. Every referenced slot is
/// removed from `ops`, and the new instruction is stored in the highest of those slots.
/// Its inputs are the values ordered by lane index, and its output is the `LocalMut`
/// variable `id` with type `item`.
///
/// # Panics
///
/// Panics if `assigns` is empty.
fn merge_assigns(
    ops: &mut StableVec<Instruction>,
    mut assigns: Vec<(usize, u32, Variable)>,
    id: Id,
    item: Type,
) {
    // The merged instruction takes the place of the latest assignment.
    let target_slot = assigns.iter().map(|&(slot, _, _)| slot).max().unwrap();

    for &(slot, _, _) in &assigns {
        ops.remove(slot);
    }

    // Stable sort, so entries sharing a lane keep their insertion order.
    assigns.sort_by_key(|&(_, lane, _)| lane);
    let inputs: Vec<_> = assigns.into_iter().map(|(_, _, value)| value).collect();

    let init = Operator::InitVector(VectorInitOperator { inputs });
    let destination = Variable::new(VariableKind::LocalMut { id }, item);
    ops.insert(target_slot, Instruction::new(init, destination));
}
/// Optimizer pass that turns constant-index reads of scalar values into plain copies.
///
/// An `Index` instruction is rewritten to `Copy` of its source when the source is not an
/// array, the index is a constant, and the source type has a vector size of 1.
pub struct RemoveIndexScalar;

impl OptimizerPass for RemoveIndexScalar {
    /// Applies the rewrite to every instruction of every block in `opt.program`.
    ///
    /// `changes` is incremented once per rewritten instruction.
    ///
    /// # Panics
    ///
    /// Panics if a scalar source is indexed with a constant other than 0.
    fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) {
        for block in opt.node_ids() {
            let shared_ops = opt.program[block].ops.clone();
            let mut block_ops = shared_ops.borrow_mut();

            for inst in block_ops.values_mut() {
                let Operation::Operator(Operator::Index(IndexOperator { list, index, .. })) =
                    &mut inst.operation
                else {
                    continue;
                };
                if list.is_array() {
                    continue;
                }
                let Some(position) = index.as_const() else {
                    continue;
                };

                let position = position.as_u32();
                if list.ty.vector_size() != 1 {
                    continue;
                }

                assert_eq!(position, 0, "Can't index into scalar");
                // Copy the source out first so the borrow of the operation ends before it
                // is overwritten.
                let scalar = *list;
                inst.operation = Operation::Copy(scalar);
                changes.inc();
            }
        }
    }
}