use std::collections::{HashMap, HashSet};
use cubecl_ir::{CopyMemoryOperator, Id, Instruction, Operation, Operator, Variable, VariableKind};
use crate::{AtomicCounter, Optimizer};
use super::OptimizerPass;
/// Optimizer pass that fuses an `Index` read feeding straight into an `IndexAssign` write
/// into a single `CopyMemory` instruction.
pub struct CopyTransform;

impl OptimizerPass for CopyTransform {
    /// Rewrites, block by block, `x = input[in_index]; out[out_index] = x` into a single
    /// `CopyMemory { input, in_index, out_index }` into `out`, placed at the write's position.
    ///
    /// A read/write pair in the same block is rewritten only when:
    /// - the read indexes a memory variable whose type equals the read's result type,
    /// - the write targets a memory variable whose type equals the stored value's type,
    /// - the stored value is the read's result (matched by `as_versioned` identity),
    /// - that result is not reused elsewhere (see `is_reused`),
    /// - the read precedes the write, and no instruction from the read up to (but excluding)
    ///   the write has `input` or `out` as its `out`.
    ///
    /// Each rewrite increments `changes`.
    fn apply_post_ssa(&mut self, opt: &mut Optimizer, changes: AtomicCounter) {
        for block in opt.node_ids() {
            // Both maps are keyed by the SSA identity of the value carried from read to write.
            let mut reads = HashMap::new();
            let mut writes = HashMap::new();
            let ops = opt.program[block].ops.clone();
            let indices = ops.borrow().indices().collect::<Vec<_>>();
            for idx in indices {
                let inst = ops.borrow()[idx].clone();
                match &inst.operation {
                    Operation::Operator(Operator::Index(op))
                        if op.list.is_memory() && op.list.ty == inst.ty() =>
                    {
                        if let Some(id) = as_versioned(&inst.out()) {
                            reads.insert(id, (idx, inst.out, op.list, op.index));
                        }
                    }
                    Operation::Operator(Operator::IndexAssign(op))
                        if inst.out().is_memory() && inst.ty() == op.value.ty =>
                    {
                        if let Some(id) = as_versioned(&op.value) {
                            writes.insert(id, (idx, inst.out(), op.index));
                        }
                    }
                    _ => {}
                }
            }

            // `is_reused` walks the whole program. Run it only for reads that actually pair with
            // a write, not for every memory read, which was quadratic. Nothing in this block
            // has been mutated yet, so the answers match checking each read eagerly.
            let read_ids: HashSet<_> = reads.keys().collect();
            let write_ids: HashSet<_> = writes.keys().collect();
            let mut pairs = Vec::new();
            for &id in read_ids.intersection(&write_ids) {
                let (read_idx, read_out, input, in_index) = reads[id];
                if is_reused(opt, &read_out) {
                    continue;
                }
                let (write_idx, out, out_index) = writes[id];
                pairs.push((read_idx, input, in_index, write_idx, out, out_index));
            }

            for (read_idx, input, in_index, write_idx, out, out_index) in pairs {
                // The copy executes at the write's position. The read must therefore come first,
                // and nothing in between may overwrite the source or destination memory.
                // NOTE(review): only each instruction's `out` is inspected. Operations that
                // modify memory without naming it as their `out` (e.g. atomics, or other units'
                // writes across a barrier) are not treated as clobbers. Confirm this is safe.
                let unclobbered = read_idx < write_idx
                    && (read_idx..write_idx)
                        .filter_map(|idx| ops.borrow().get(idx).and_then(|it| it.out))
                        .all(|written| written != input && written != out);
                if !unclobbered {
                    continue;
                }
                ops.borrow_mut().remove(read_idx);
                let copy = Operator::CopyMemory(CopyMemoryOperator {
                    out_index,
                    input,
                    in_index,
                });
                ops.borrow_mut()[write_idx] = Instruction::new(copy, out);
                changes.inc();
            }
        }
    }
}
/// Returns the SSA identity `(id, version)` of a local value, or `None` for any other
/// variable kind.
///
/// A `LocalConst` carries no version and is reported as version `0`.
/// NOTE(review): this shares a key space with `Versioned { id, version: 0 }`. Presumably ids
/// are unique across both kinds. Confirm.
fn as_versioned(var: &Variable) -> Option<(Id, u16)> {
    let (id, version) = match var.kind {
        VariableKind::Versioned { id, version } => (id, version),
        VariableKind::LocalConst { id } => (id, 0),
        _ => return None,
    };
    Some((id, version))
}
/// Reports whether `var` is seen more than once by the first visitor of
/// `Optimizer::visit_all`, presumably its read/use visitor (confirm).
///
/// `None` is never considered reused. The second (write) visitor is a no-op.
fn is_reused(opt: &mut Optimizer, var: &Option<Variable>) -> bool {
    match var {
        None => false,
        Some(var) => {
            let hits = AtomicCounter::new(0);
            opt.visit_all(
                |_, candidate| {
                    if candidate == var {
                        hits.inc();
                    }
                },
                |_, _| {},
            );
            hits.get() > 1
        }
    }
}