use std::collections::HashMap;
use cubecl_ir::{
Id, Instruction, Operation, Operator, Type, UnaryOperator, Variable, VariableKind,
};
use crate::{AtomicCounter, Optimizer, analyses::writes::Writes};
use super::OptimizerPass;
pub struct DisaggregateArray;
impl OptimizerPass for DisaggregateArray {
fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) {
let arrays = find_const_arrays(opt);
for Array { id, length, item } in arrays {
let block = opt.entry();
let old_insts = opt.program[block].ops.take();
let arr_id = id;
let vars = (0..length)
.map(|_| *opt.root_scope.create_local_restricted(item))
.collect::<Vec<_>>();
for var in &vars {
let local_id = opt.local_variable_id(var).unwrap();
opt.program.variables.insert(local_id, var.ty);
let assign =
Instruction::new(Operator::Cast(UnaryOperator { input: 0u32.into() }), *var);
opt.program[block].ops.borrow_mut().push(assign);
}
opt.program[block]
.ops
.borrow_mut()
.extend(old_insts.into_iter().map(|it| it.1));
replace_const_arrays(opt, arr_id, &vars);
changes.inc();
}
}
}
/// A local array selected for disaggregation.
#[derive(Debug, Clone, Copy)]
struct Array {
    /// Id of the `VariableKind::LocalArray` being split.
    id: Id,
    /// Element count: the declared length multiplied by the unroll factor.
    length: usize,
    /// Type of each element (the array variable's type).
    item: Type,
}
fn find_const_arrays(opt: &mut Optimizer) -> Vec<Array> {
let mut track_consts = HashMap::new();
let mut arrays = HashMap::new();
for block in opt.node_ids() {
let ops = opt.program[block].ops.clone();
for op in ops.borrow().values() {
match &op.operation {
Operation::Operator(Operator::Index(index) | Operator::UncheckedIndex(index)) => {
if let VariableKind::LocalArray {
id,
length,
unroll_factor,
} = index.list.kind
{
let item = index.list.ty;
arrays.insert(
id,
Array {
id,
length: length * unroll_factor,
item,
},
);
let is_const = index.index.as_const().is_some();
*track_consts.entry(id).or_insert(is_const) &= is_const;
}
}
Operation::Operator(
Operator::IndexAssign(assign) | Operator::UncheckedIndexAssign(assign),
) => {
if let VariableKind::LocalArray {
id,
length,
unroll_factor,
} = op.out().kind
{
let item = op.out().ty;
arrays.insert(
id,
Array {
id,
length: length * unroll_factor,
item,
},
);
let is_const = assign.index.as_const().is_some();
*track_consts.entry(id).or_insert(is_const) &= is_const;
}
}
_ => {}
}
}
}
track_consts
.iter()
.filter(|(_, is_const)| **is_const)
.map(|(id, _)| arrays[id])
.collect()
}
/// Rewrites every indexed access to local array `arr_id` as a plain copy
/// from or to the matching element variable in `vars`.
///
/// - `Index` / `UncheckedIndex` reads become `Copy(vars[i])`. The instruction's
///   output and other fields are kept.
/// - `IndexAssign` / `UncheckedIndexAssign` writes are replaced by a new
///   `Copy(value)` instruction whose output is `vars[i]`.
///
/// Rewritten writes now target different variables, so the cached `Writes`
/// analysis is invalidated once if any write was rewritten.
///
/// # Panics
///
/// Panics if an access to `arr_id` does not use a constant index that lies
/// inside `vars`.
fn replace_const_arrays(opt: &mut Optimizer, arr_id: Id, vars: &[Variable]) {
    let is_target =
        |var: Variable| matches!(var.kind, VariableKind::LocalArray { id, .. } if id == arr_id);
    let element = |index: Variable| -> Variable {
        let position = index
            .as_const()
            .expect("disaggregated array accessed with a non-constant index")
            .as_i64();
        usize::try_from(position)
            .ok()
            .and_then(|i| vars.get(i).copied())
            .expect("disaggregated array accessed with an out-of-bounds constant index")
    };

    let mut writes_changed = false;
    for block in opt.node_ids() {
        // Clone the shared handle so mutating ops doesn't keep `opt` borrowed.
        let ops = opt.program[block].ops.clone();
        for op in ops.borrow_mut().values_mut() {
            // Match by reference; the operation is only replaced when it
            // actually targets `arr_id`, so no per-instruction clone is needed.
            match &op.operation {
                Operation::Operator(Operator::Index(index) | Operator::UncheckedIndex(index)) => {
                    if is_target(index.list) {
                        let source = element(index.index);
                        op.operation = Operation::Copy(source);
                    }
                }
                Operation::Operator(
                    Operator::IndexAssign(assign) | Operator::UncheckedIndexAssign(assign),
                ) => {
                    if is_target(op.out()) {
                        let value = assign.value;
                        let target = element(assign.index);
                        *op = Instruction::new(Operation::Copy(value), target);
                        writes_changed = true;
                    }
                }
                _ => {}
            }
        }
    }

    if writes_changed {
        opt.invalidate_analysis::<Writes>();
    }
}