use std::{collections::BTreeSet, ops};
use cranelift_entity::SecondaryMap;
use crate::cfg::ControlFlowGraph;
use sonatina_ir::{
func_cursor::{CursorLocation, FuncCursor, InsnInserter},
insn::{BinaryOp, CastOp, InsnData, UnaryOp},
Block, Function, Immediate, Insn, Type, Value,
};
/// State for sparse conditional constant propagation (SCCP) over a single `Function`.
///
/// The solver is reusable: `run` clears all state before solving a new function.
#[derive(Debug)]
pub struct SccpSolver {
/// Lattice cell for every SSA value. Values that were never written read as the map's
/// default cell (presumably `LatticeCell::default()` via `SecondaryMap` indexing — confirm).
lattice: SecondaryMap<Value, LatticeCell>,
/// CFG edges (terminator instruction + destination) proven executable so far.
reachable_edges: BTreeSet<FlowEdge>,
/// Blocks proven reachable so far.
reachable_blocks: BTreeSet<Block>,
/// Worklist of edges that have been discovered executable but not yet processed.
flow_work: Vec<FlowEdge>,
/// Worklist of values whose lattice cell changed; their users must be re-evaluated.
ssa_work: Vec<Value>,
}
impl SccpSolver {
    /// Creates a solver with an empty lattice, no reachable blocks/edges and empty worklists.
    pub fn new() -> Self {
        Self {
            lattice: SecondaryMap::default(),
            reachable_edges: BTreeSet::default(),
            reachable_blocks: BTreeSet::default(),
            flow_work: Vec::default(),
            ssa_work: Vec::default(),
        }
    }

    /// Runs sparse conditional constant propagation over `func`.
    ///
    /// Starting from the entry block, the solver computes which CFG edges can execute and a
    /// lattice cell for every SSA value. It then removes unreachable blocks and never-taken
    /// branch destinations, recomputes `cfg`, and finally replaces instructions whose result
    /// is a known constant with an immediate value.
    ///
    /// Does nothing if `func` has no entry block.
    pub fn run(&mut self, func: &mut Function, cfg: &mut ControlFlowGraph) {
        self.clear();

        let entry_block = match func.layout.entry_block() {
            Some(block) => block,
            _ => return,
        };

        // Function arguments are not known at compile time.
        for arg in &func.arg_values {
            self.lattice[*arg] = LatticeCell::Top;
        }

        self.reachable_blocks.insert(entry_block);
        self.eval_insns_in(func, entry_block);

        // Alternate between the two worklists until neither produces new work.
        let mut changed = true;
        while changed {
            changed = false;

            while let Some(edge) = self.flow_work.pop() {
                changed = true;
                self.eval_edge(func, edge);
            }

            while let Some(value) = self.ssa_work.pop() {
                changed = true;
                // Only users in already-reachable blocks are revisited; users elsewhere are
                // evaluated when their block first becomes reachable.
                for &user in func.dfg.users(value) {
                    let user_block = func.layout.insn_block(user);
                    if self.reachable_blocks.contains(&user_block) {
                        if func.dfg.is_phi(user) {
                            self.eval_phi(func, user);
                        } else {
                            self.eval_insn(func, user);
                        }
                    }
                }
            }
        }

        self.remove_unreachable_edges(func);
        cfg.compute(func);
        self.fold_insns(func, cfg);
    }

    /// Resets all solver state so the solver can be reused.
    pub fn clear(&mut self) {
        self.lattice.clear();
        self.reachable_edges.clear();
        self.reachable_blocks.clear();
        self.flow_work.clear();
        self.ssa_work.clear();
    }

    /// Marks `edge` as executable and evaluates its destination block.
    ///
    /// The first time a block becomes reachable, all of its instructions are evaluated. Later
    /// edges into an already reachable block only require re-evaluating its phis, because only
    /// phis depend on which incoming edges are executable.
    fn eval_edge(&mut self, func: &Function, edge: FlowEdge) {
        let dest = edge.to;
        // `insert` returns `false` when the edge was already known to be executable.
        if !self.reachable_edges.insert(edge) {
            return;
        }

        // `insert` returns `true` only on the first visit of `dest`.
        if self.reachable_blocks.insert(dest) {
            self.eval_insns_in(func, dest);
        } else {
            self.eval_phis_in(func, dest);
        }

        // A terminator with a single destination makes that edge unconditionally executable.
        // Jump terminators already queue their edge in `eval_insn`; a duplicate is discarded
        // by the `insert` check above.
        if let Some(last_insn) = func.layout.last_insn_of(dest) {
            let branch_info = func.dfg.analyze_branch(last_insn);
            if branch_info.dests_num() == 1 {
                self.flow_work.push(FlowEdge::new(
                    last_insn,
                    branch_info.iter_dests().next().unwrap(),
                ))
            }
        }
    }

    /// Re-evaluates every phi instruction in `block`.
    fn eval_phis_in(&mut self, func: &Function, block: Block) {
        for insn in func.layout.iter_insn(block) {
            if func.dfg.is_phi(insn) {
                self.eval_phi(func, insn);
            }
        }
    }

    /// Evaluates the phi `insn` by joining the cells of the arguments whose incoming edge is
    /// executable; arguments arriving over non-executable edges are ignored.
    fn eval_phi(&mut self, func: &Function, insn: Insn) {
        debug_assert!(func.dfg.is_phi(insn));
        debug_assert!(self
            .reachable_blocks
            .contains(&func.layout.insn_block(insn)));

        // Immediate operands are constants by definition.
        for &arg in func.dfg.insn_args(insn) {
            if let Some(imm) = func.dfg.value_imm(arg) {
                self.set_lattice_cell(arg, LatticeCell::Const(imm));
            }
        }

        let block = func.layout.insn_block(insn);
        let mut eval_result = LatticeCell::Bot;
        for (i, from) in func.dfg.phi_blocks(insn).iter().enumerate() {
            if self.is_reachable(func, *from, block) {
                let phi_arg = func.dfg.insn_arg(insn, i);
                eval_result = eval_result.join(self.lattice[phi_arg]);
            }
        }

        let phi_value = func.dfg.insn_result(insn).unwrap();
        self.set_lattice_cell(phi_value, eval_result);
    }

    /// Evaluates every instruction of a newly reachable `block` in layout order.
    fn eval_insns_in(&mut self, func: &Function, block: Block) {
        for insn in func.layout.iter_insn(block) {
            if func.dfg.is_phi(insn) {
                self.eval_phi(func, insn);
            } else {
                self.eval_insn(func, insn);
            }
        }
    }

    /// Evaluates the non-phi instruction `insn`.
    ///
    /// Value-producing instructions update their result's lattice cell. Terminators instead
    /// queue the outgoing edges that may execute given the current lattice.
    fn eval_insn(&mut self, func: &Function, insn: Insn) {
        debug_assert!(!func.dfg.is_phi(insn));

        // Immediate operands are constants by definition.
        for &arg in func.dfg.insn_args(insn) {
            if let Some(imm) = func.dfg.value_imm(arg) {
                self.set_lattice_cell(arg, LatticeCell::Const(imm));
            }
        }

        let cell = match func.dfg.insn_data(insn) {
            InsnData::Unary { code, args } => {
                let arg_cell = self.lattice[args[0]];
                match *code {
                    UnaryOp::Not => arg_cell.not(),
                    UnaryOp::Neg => arg_cell.neg(),
                }
            }

            InsnData::Binary { code, args } => {
                let lhs = self.lattice[args[0]];
                let rhs = self.lattice[args[1]];
                match *code {
                    BinaryOp::Add => lhs.add(rhs),
                    BinaryOp::Sub => lhs.sub(rhs),
                    BinaryOp::Mul => lhs.mul(rhs),
                    BinaryOp::Udiv => lhs.udiv(rhs),
                    BinaryOp::Sdiv => lhs.sdiv(rhs),
                    BinaryOp::Lt => lhs.lt(rhs),
                    BinaryOp::Gt => lhs.gt(rhs),
                    BinaryOp::Slt => lhs.slt(rhs),
                    BinaryOp::Sgt => lhs.sgt(rhs),
                    BinaryOp::Le => lhs.le(rhs),
                    BinaryOp::Ge => lhs.ge(rhs),
                    BinaryOp::Sle => lhs.sle(rhs),
                    BinaryOp::Sge => lhs.sge(rhs),
                    BinaryOp::Eq => lhs.eq(rhs),
                    BinaryOp::Ne => lhs.ne(rhs),
                    BinaryOp::And => lhs.and(rhs),
                    BinaryOp::Or => lhs.or(rhs),
                    BinaryOp::Xor => lhs.xor(rhs),
                }
            }

            InsnData::Cast { code, args, ty } => {
                let arg_cell = self.lattice[args[0]];
                match code {
                    CastOp::Sext => arg_cell.sext(*ty),
                    CastOp::Zext => arg_cell.zext(*ty),
                    CastOp::Trunc => arg_cell.trunc(*ty),
                    CastOp::BitCast => LatticeCell::Top,
                }
            }

            // Memory and call results are never treated as constants.
            InsnData::Load { .. } => LatticeCell::Top,
            InsnData::Call { .. } => LatticeCell::Top,

            InsnData::Jump { dests, .. } => {
                self.flow_work.push(FlowEdge::new(insn, dests[0]));
                return;
            }

            // A non-zero condition takes `dests[0]`, zero takes `dests[1]`; an unknown
            // condition makes both edges executable.
            InsnData::Branch { args, dests } => {
                let v_cell = self.lattice[args[0]];
                if v_cell.is_top() {
                    self.flow_work.push(FlowEdge::new(insn, dests[0]));
                    self.flow_work.push(FlowEdge::new(insn, dests[1]));
                } else if v_cell.is_bot() {
                    unreachable!();
                } else if v_cell.is_zero() {
                    self.flow_work.push(FlowEdge::new(insn, dests[1]));
                } else {
                    self.flow_work.push(FlowEdge::new(insn, dests[0]));
                }
                return;
            }

            // `args[0]` is the scrutinee and `args[1..]` are the keys paired with `table`.
            InsnData::BrTable {
                args,
                default,
                table,
            } => {
                let mut add_all_dest = || {
                    if let Some(default) = default {
                        self.flow_work.push(FlowEdge::new(insn, *default));
                    }
                    for dest in table {
                        self.flow_work.push(FlowEdge::new(insn, *dest));
                    }
                };

                let v_cell = self.lattice[args[0]];
                if v_cell.is_top() {
                    add_all_dest();
                    return;
                }
                if v_cell.is_bot() {
                    unreachable!()
                }

                // A key whose value is unknown (`Top`) might equal the scrutinee at runtime, so
                // once one is seen we cannot prove that a single destination is taken (keys
                // are not assumed to be pairwise distinct).
                let mut contains_top = false;
                for (value, dest) in args[1..].iter().zip(table.iter()) {
                    let key_cell = self.lattice[*value];
                    if key_cell.is_top() {
                        contains_top = true;
                    } else if key_cell == v_cell {
                        if contains_top {
                            add_all_dest();
                        } else {
                            self.flow_work.push(FlowEdge::new(insn, *dest));
                        }
                        return;
                    }
                }

                if contains_top {
                    add_all_dest();
                } else if let Some(default) = default {
                    self.flow_work.push(FlowEdge::new(insn, *default));
                }
                return;
            }

            InsnData::Alloca { .. } | InsnData::Gep { .. } => LatticeCell::Top,

            InsnData::Store { .. } | InsnData::Return { .. } => {
                return;
            }

            InsnData::Phi { .. } => unreachable!(),
        };

        let insn_result = func.dfg.insn_result(insn).unwrap();
        self.set_lattice_cell(insn_result, cell);
    }

    /// Removes blocks that were never reached and branch destinations whose edge was never
    /// found executable.
    fn remove_unreachable_edges(&self, func: &mut Function) {
        let entry_block = func.layout.entry_block().unwrap();
        let mut inserter = InsnInserter::new(func, CursorLocation::BlockTop(entry_block));

        loop {
            match inserter.loc() {
                CursorLocation::BlockTop(block) => {
                    if !self.reachable_blocks.contains(&block) {
                        inserter.remove_block();
                    } else {
                        inserter.proceed();
                    }
                }

                CursorLocation::BlockBottom(..) => inserter.proceed(),

                CursorLocation::At(insn) => {
                    if inserter.func().dfg.is_branch(insn) {
                        let branch_info = inserter.func().dfg.analyze_branch(insn);
                        // Collect first: the destinations are mutated while being visited.
                        for dest in branch_info.iter_dests().collect::<Vec<_>>() {
                            if !self.is_reachable_edge(insn, dest) {
                                inserter.func_mut().dfg.remove_branch_dest(insn, dest);
                            }
                        }
                    }
                    inserter.proceed();
                }

                CursorLocation::NoWhere => break,
            }
        }
    }

    /// Returns `true` if the edge from terminator `insn` to `dest` was found executable.
    fn is_reachable_edge(&self, insn: Insn, dest: Block) -> bool {
        self.reachable_edges.contains(&FlowEdge::new(insn, dest))
    }

    /// Folds instructions of every block in reverse post-order of `cfg`.
    fn fold_insns(&self, func: &mut Function, cfg: &ControlFlowGraph) {
        let mut rpo: Vec<_> = cfg.post_order().collect();
        rpo.reverse();

        for block in rpo {
            let mut next_insn = func.layout.first_insn_of(block);
            while let Some(insn) = next_insn {
                // Fetch the successor first: `fold` may remove `insn` from the layout.
                next_insn = func.layout.next_insn_of(insn);
                self.fold(func, insn);
            }
        }
    }

    /// Replaces `insn` by an immediate if its result is a known constant; otherwise tries to
    /// simplify it when it is a phi. Instructions without a result are left untouched.
    fn fold(&self, func: &mut Function, insn: Insn) {
        let insn_result = match func.dfg.insn_result(insn) {
            Some(result) => result,
            None => return,
        };

        match self.lattice[insn_result].to_imm() {
            Some(imm) => {
                InsnInserter::new(func, CursorLocation::At(insn)).remove_insn();
                let new_value = func.dfg.make_imm_value(imm);
                func.dfg.change_to_alias(insn_result, new_value);
            }
            None => {
                if func.dfg.is_phi(insn) {
                    self.try_fold_phi(func, insn)
                }
            }
        }
    }

    /// Drops phi arguments that arrive over non-executable edges and, if a single argument
    /// remains, replaces the phi with that argument.
    ///
    /// Must run after `remove_unreachable_edges`, so that non-executable branch destinations
    /// have already been removed from predecessor terminators.
    fn try_fold_phi(&self, func: &mut Function, insn: Insn) {
        debug_assert!(func.dfg.is_phi(insn));

        let phi_block = func.layout.insn_block(insn);
        // An incoming block stays a phi predecessor only if it is reachable AND its edge into
        // `phi_block` was executable. A reachable block whose branch to `phi_block` was never
        // taken is no longer a predecessor once its dead destination has been removed.
        // `is_reachable` is only queried for reachable blocks, which are still in the layout.
        let dead_preds: Vec<Block> = {
            let func: &Function = func;
            func.dfg
                .phi_blocks(insn)
                .iter()
                .copied()
                .filter(|&from| {
                    !self.reachable_blocks.contains(&from)
                        || !self.is_reachable(func, from, phi_block)
                })
                .collect()
        };
        for block in dead_preds {
            func.dfg.remove_phi_arg(insn, block);
        }

        if func.dfg.insn_args_num(insn) == 1 {
            let phi_value = func.dfg.insn_result(insn).unwrap();
            let only_arg = func.dfg.insn_arg(insn, 0);
            func.dfg.change_to_alias(phi_value, only_arg);
            InsnInserter::new(func, CursorLocation::At(insn)).remove_insn();
        }
    }

    /// Returns `true` if the terminator of `from` branches to `to` over an executable edge.
    /// Returns `false` for a block with no instructions.
    fn is_reachable(&self, func: &Function, from: Block, to: Block) -> bool {
        let last_insn = match func.layout.last_insn_of(from) {
            Some(insn) => insn,
            None => return false,
        };

        func.dfg
            .analyze_branch(last_insn)
            .iter_dests()
            .any(|dest| dest == to && self.is_reachable_edge(last_insn, dest))
    }

    /// Updates the cell of `value`, queueing the value for user re-evaluation if it changed.
    fn set_lattice_cell(&mut self, value: Value, cell: LatticeCell) {
        if self.lattice[value] != cell {
            self.lattice[value] = cell;
            self.ssa_work.push(value);
        }
    }
}
impl Default for SccpSolver {
fn default() -> Self {
Self::new()
}
}
/// A CFG edge, identified by the terminator instruction it leaves from and its destination.
///
/// Field order matters: the derived `Ord` (which orders the `BTreeSet<FlowEdge>` of reachable
/// edges) compares `insn` first, then `to`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct FlowEdge {
/// The terminator instruction the edge leaves from.
insn: Insn,
/// The destination block.
to: Block,
}
impl FlowEdge {
fn new(insn: Insn, to: Block) -> Self {
Self { insn, to }
}
}
/// Abstract value tracked for each SSA value during SCCP.
///
/// `Bot` is the initial "nothing known yet" state and the identity of `join`; `Top` means the
/// value is not a single compile-time constant and absorbs everything in `join`.
#[derive(Debug, Clone, Copy)]
enum LatticeCell {
/// Not a single known constant (e.g. function arguments, loads, call results).
Top,
/// Known to always equal this immediate.
Const(Immediate),
/// Nothing known yet; this is also the default cell.
Bot,
}
impl PartialEq for LatticeCell {
    /// Cells are equal when they are the same variant and, for `Const`, hold equal immediates.
    fn eq(&self, rhs: &Self) -> bool {
        match self {
            Self::Top => matches!(rhs, Self::Top),
            Self::Bot => matches!(rhs, Self::Bot),
            Self::Const(lhs_imm) => match rhs {
                Self::Const(rhs_imm) => lhs_imm == rhs_imm,
                _ => false,
            },
        }
    }
}
impl LatticeCell {
    /// Returns the constant held by this cell, or `None` for `Top` and `Bot`.
    fn to_imm(self) -> Option<Immediate> {
        if let Self::Const(imm) = self {
            Some(imm)
        } else {
            None
        }
    }

    /// `true` only for a `Const` cell whose immediate is zero.
    fn is_zero(self) -> bool {
        matches!(self, Self::Const(imm) if imm.is_zero())
    }

    /// `true` if the value is known not to be a single constant.
    fn is_top(self) -> bool {
        matches!(self, Self::Top)
    }

    /// `true` if nothing is known about the value yet.
    fn is_bot(self) -> bool {
        matches!(self, Self::Bot)
    }

    /// Least upper bound of two cells.
    ///
    /// `Bot` is the identity. Two equal constants stay constant. Anything else
    /// (`Top` on either side, or two different constants) becomes `Top`.
    fn join(self, rhs: Self) -> Self {
        match (self, rhs) {
            (Self::Bot, other) | (other, Self::Bot) => other,
            (Self::Const(lhs_imm), Self::Const(rhs_imm)) if lhs_imm == rhs_imm => self,
            _ => Self::Top,
        }
    }

    /// Applies `f` to a constant cell; `Top` and `Bot` pass through unchanged.
    fn apply_unop<F>(self, f: F) -> Self
    where
        F: FnOnce(Immediate) -> Immediate,
    {
        if let Self::Const(operand) = self {
            Self::Const(f(operand))
        } else {
            self
        }
    }

    /// Applies `f` when both cells are constants.
    ///
    /// Otherwise the result is `Top` if either side is `Top`, and `Bot` in the remaining cases.
    fn apply_binop<F>(self, rhs: Self, f: F) -> Self
    where
        F: FnOnce(Immediate, Immediate) -> Immediate,
    {
        match (self, rhs) {
            (Self::Const(lhs_imm), Self::Const(rhs_imm)) => Self::Const(f(lhs_imm, rhs_imm)),
            (Self::Top, _) | (_, Self::Top) => Self::Top,
            _ => Self::Bot,
        }
    }

    // Unary operators.

    fn not(self) -> Self {
        self.apply_unop(ops::Not::not)
    }

    fn neg(self) -> Self {
        self.apply_unop(ops::Neg::neg)
    }

    // Arithmetic.

    fn add(self, rhs: Self) -> Self {
        self.apply_binop(rhs, ops::Add::add)
    }

    fn sub(self, rhs: Self) -> Self {
        self.apply_binop(rhs, ops::Sub::sub)
    }

    fn mul(self, rhs: Self) -> Self {
        self.apply_binop(rhs, ops::Mul::mul)
    }

    // NOTE(review): division of two constants is delegated to `Immediate::udiv`/`sdiv`;
    // confirm those handle a zero divisor without panicking.
    fn udiv(self, rhs: Self) -> Self {
        self.apply_binop(rhs, Immediate::udiv)
    }

    fn sdiv(self, rhs: Self) -> Self {
        self.apply_binop(rhs, Immediate::sdiv)
    }

    // Comparisons (unsigned, then signed).

    fn lt(self, rhs: Self) -> Self {
        self.apply_binop(rhs, Immediate::lt)
    }

    fn gt(self, rhs: Self) -> Self {
        self.apply_binop(rhs, Immediate::gt)
    }

    fn slt(self, rhs: Self) -> Self {
        self.apply_binop(rhs, Immediate::slt)
    }

    fn sgt(self, rhs: Self) -> Self {
        self.apply_binop(rhs, Immediate::sgt)
    }

    fn le(self, rhs: Self) -> Self {
        self.apply_binop(rhs, Immediate::le)
    }

    fn ge(self, rhs: Self) -> Self {
        self.apply_binop(rhs, Immediate::ge)
    }

    fn sle(self, rhs: Self) -> Self {
        self.apply_binop(rhs, Immediate::sle)
    }

    fn sge(self, rhs: Self) -> Self {
        self.apply_binop(rhs, Immediate::sge)
    }

    // Equality on the IR values, producing an IR immediate. These inherent methods shadow
    // the `PartialEq` method names but do not affect the `==`/`!=` operators.

    fn eq(self, rhs: Self) -> Self {
        self.apply_binop(rhs, Immediate::imm_eq)
    }

    fn ne(self, rhs: Self) -> Self {
        self.apply_binop(rhs, Immediate::imm_ne)
    }

    // Bitwise operators.

    fn and(self, rhs: Self) -> Self {
        self.apply_binop(rhs, ops::BitAnd::bitand)
    }

    fn or(self, rhs: Self) -> Self {
        self.apply_binop(rhs, ops::BitOr::bitor)
    }

    fn xor(self, rhs: Self) -> Self {
        self.apply_binop(rhs, ops::BitXor::bitxor)
    }

    // Width conversions to `ty`.

    fn sext(self, ty: Type) -> Self {
        self.apply_unop(|operand| Immediate::sext(operand, ty))
    }

    fn zext(self, ty: Type) -> Self {
        self.apply_unop(|operand| Immediate::zext(operand, ty))
    }

    fn trunc(self, ty: Type) -> Self {
        self.apply_unop(|operand| Immediate::trunc(operand, ty))
    }
}
impl Default for LatticeCell {
fn default() -> Self {
Self::Bot
}
}