use std::collections::{HashMap, HashSet, VecDeque};
use crate::{
analysis::{
dataflow::lattice::MeetSemiLattice, ConstValue, PhiNode, SsaBlock, SsaFunction, SsaOp,
SsaVarId,
},
metadata::typesystem::PointerSize,
utils::graph::{NodeId, RootedGraph, Successors},
};
/// Worklist-driven sparse conditional constant propagation over an SSA function.
///
/// The analysis keeps one lattice value per SSA variable together with the set
/// of CFG edges and blocks that have been found executable so far. Two
/// worklists drive it: one of CFG edges still to be processed and one of SSA
/// variables whose value changed and whose uses must be re-evaluated.
pub struct ConstantPropagation {
// Current lattice value of each SSA variable; seeded in `initialize` and
// only ever lowered through `update_value`.
values: HashMap<SsaVarId, ScalarValue>,
// `(from, to)` CFG edges that have already been taken off the CFG worklist
// and processed.
executable_edges: HashSet<(usize, usize)>,
// Blocks that have been marked executable.
executable_blocks: HashSet<usize>,
// Variables whose lattice value changed; their use sites are re-evaluated.
ssa_worklist: VecDeque<SsaVarId>,
// `(from, to)` CFG edges waiting to be processed.
cfg_worklist: VecDeque<(usize, usize)>,
// Edges whose target block was already executable when the edge was first
// processed; phi operands arriving over these edges are treated as `Bottom`.
back_edges: HashSet<(usize, usize)>,
// Pointer width forwarded to the `ConstValue` arithmetic helpers.
pointer_size: PointerSize,
}
impl ConstantPropagation {
/// Creates an analysis with empty state for the given pointer width.
///
/// All maps, sets and worklists start out empty; `ptr_size` is stored and
/// later forwarded to the constant-folding helpers.
#[must_use]
pub fn new(ptr_size: PointerSize) -> Self {
    Self {
        pointer_size: ptr_size,
        values: HashMap::default(),
        executable_edges: HashSet::default(),
        executable_blocks: HashSet::default(),
        back_edges: HashSet::default(),
        ssa_worklist: VecDeque::default(),
        cfg_worklist: VecDeque::default(),
    }
}
/// Runs the analysis on `ssa` using `cfg` for control-flow successors.
///
/// State is reset and seeded first, then propagated until both worklists
/// are empty. The returned [`SccpResult`] holds copies of the final
/// per-variable lattice values and of the set of executable blocks.
pub fn analyze<G>(&mut self, ssa: &SsaFunction, cfg: &G) -> SccpResult
where
    G: RootedGraph + Successors,
{
    self.initialize(ssa, cfg);
    self.propagate(ssa, cfg);

    let values = self.values.clone();
    let executable_blocks = self.executable_blocks.clone();
    SccpResult {
        values,
        executable_blocks,
    }
}
/// Resets all analysis state and seeds it for a fresh run.
///
/// Every SSA variable receives an initial lattice value: a variable that
/// originates from an argument, has version 0 and has no defining
/// instruction starts at `Bottom`; every other variable starts at `Top`.
/// The entry block is marked executable, each of its CFG successors is
/// queued as an edge, and the entry block's definitions are evaluated.
fn initialize<G>(&mut self, ssa: &SsaFunction, cfg: &G)
where
    G: RootedGraph + Successors,
{
    // Drop anything left over from a previous run.
    self.back_edges.clear();
    self.cfg_worklist.clear();
    self.ssa_worklist.clear();
    self.executable_blocks.clear();
    self.executable_edges.clear();
    self.values.clear();

    for var in ssa.variables() {
        let is_incoming_argument = var.origin().is_argument()
            && var.version() == 0
            && var.def_site().instruction.is_none();
        let seed = if is_incoming_argument {
            ScalarValue::Bottom
        } else {
            ScalarValue::Top
        };
        self.values.insert(var.id(), seed);
    }

    let entry = cfg.entry().index();
    self.executable_blocks.insert(entry);
    for succ in cfg.successors(cfg.entry()) {
        let edge = (entry, succ.index());
        self.cfg_worklist.push_back(edge);
    }

    let Some(entry_block) = ssa.block(entry) else {
        return;
    };
    self.process_block_definitions(entry_block);
}
/// Drains both worklists until a fixed point is reached.
///
/// CFG edges take priority: the edge worklist is emptied completely before
/// a single SSA variable is taken from the variable worklist. An edge is
/// processed only the first time it is seen; if its target block is already
/// executable at that moment the edge is additionally recorded in
/// `back_edges`.
fn propagate<G>(&mut self, ssa: &SsaFunction, cfg: &G)
where
    G: RootedGraph + Successors,
{
    loop {
        while let Some(edge) = self.cfg_worklist.pop_front() {
            // `insert` returns false when the edge was handled before.
            if !self.executable_edges.insert(edge) {
                continue;
            }
            let (from, to) = edge;
            if self.executable_blocks.contains(&to) {
                self.back_edges.insert(edge);
            }
            self.process_edge(from, to, ssa, cfg);
        }

        let Some(var) = self.ssa_worklist.pop_front() else {
            break;
        };
        self.process_variable_uses(var, ssa, cfg);
    }
}
/// Handles a newly executable CFG edge `from -> to`.
///
/// The target block is marked executable. On its first visit its
/// instruction definitions are evaluated before the phi nodes and its
/// outgoing edges are queued after them. On every visit, each phi node
/// that has an operand for `from` is re-evaluated.
fn process_edge<G>(&mut self, from: usize, to: usize, ssa: &SsaFunction, cfg: &G)
where
    G: RootedGraph + Successors,
{
    // `insert` reports whether the block was not yet known to be executable.
    let first_visit = self.executable_blocks.insert(to);

    let Some(block) = ssa.block(to) else {
        return;
    };

    if first_visit {
        self.process_block_definitions(block);
    }

    for phi in block.phi_nodes() {
        if phi.operand_from(from).is_none() {
            continue;
        }
        let merged = self.evaluate_phi(phi, to);
        self.update_value(phi.result(), &merged);
    }

    if first_visit {
        self.propagate_outgoing_edges(to, block, cfg);
    }
}
/// Evaluates every defining instruction of `block` in order and merges the
/// result into the value recorded for the defined variable.
///
/// Instructions without a definition are skipped.
fn process_block_definitions(&mut self, block: &SsaBlock) {
    for instr in block.instructions() {
        let Some(defined) = instr.def() else {
            continue;
        };
        let computed = self.evaluate_instruction(instr.op());
        self.update_value(defined, &computed);
    }
}
/// Re-evaluates every use of `var` that lies in an executable block.
///
/// A phi-operand use re-evaluates the phi node. Any other use re-evaluates
/// the using instruction's definition (if it has one) and, when that
/// instruction is a terminator, re-queues the block's outgoing edges.
/// Unknown variables, blocks, phis or instructions are silently skipped.
fn process_variable_uses<G>(&mut self, var: SsaVarId, ssa: &SsaFunction, cfg: &G)
where
    G: RootedGraph + Successors,
{
    let Some(ssa_var) = ssa.variable(var) else {
        return;
    };

    for use_site in ssa_var.uses() {
        let block_id = use_site.block;
        if !self.executable_blocks.contains(&block_id) {
            continue;
        }
        let Some(block) = ssa.block(block_id) else {
            continue;
        };

        if use_site.is_phi_operand {
            if let Some(phi) = block.phi(use_site.instruction) {
                let merged = self.evaluate_phi(phi, block_id);
                self.update_value(phi.result(), &merged);
            }
            continue;
        }

        let Some(instr) = block.instruction(use_site.instruction) else {
            continue;
        };
        if let Some(defined) = instr.def() {
            let computed = self.evaluate_instruction(instr.op());
            self.update_value(defined, &computed);
        }
        if instr.is_terminator() {
            self.propagate_outgoing_edges(block_id, block, cfg);
        }
    }
}
/// Queues the CFG edges leaving `block_id` that may execute, based on what
/// is currently known about the operand of the block's terminator.
///
/// * `Branch` / `Switch` on a constant: only the selected target is queued.
/// * `Branch` / `Switch` on `Top`: nothing is queued. The operand has not
///   been resolved yet; once it lowers, `process_variable_uses` calls this
///   function again for the terminator. Queuing edges here would mark
///   targets executable that a later constant could have excluded, and
///   executable edges are never retracted.
/// * `Branch` / `Switch` on `Bottom`: every target is queued.
/// * `Jump`: the single target is queued.
/// * `Return` / `Throw` / `Rethrow`: nothing is queued.
/// * Anything else (including a missing terminator): every successor the
///   CFG reports for the block is queued.
fn propagate_outgoing_edges<G>(&mut self, block_id: usize, block: &SsaBlock, cfg: &G)
where
    G: RootedGraph + Successors,
{
    match block.terminator_op() {
        Some(SsaOp::Branch {
            condition,
            true_target,
            false_target,
        }) => match self.get_value(*condition) {
            ScalarValue::Constant(c) => {
                // NOTE(review): a constant for which `as_bool()` yields `None`
                // falls through to the false target — confirm this is intended.
                let target = if c.as_bool() == Some(true) {
                    *true_target
                } else {
                    *false_target
                };
                self.add_cfg_edge(block_id, target);
            }
            ScalarValue::Top => {
                // Not resolved yet: stay optimistic and add no edges.
            }
            ScalarValue::Bottom => {
                self.add_cfg_edge(block_id, *true_target);
                self.add_cfg_edge(block_id, *false_target);
            }
        },
        Some(SsaOp::Switch {
            value,
            targets,
            default,
        }) => match self.get_value(*value) {
            ScalarValue::Constant(c) => {
                // A selector that is not an `i32`, is negative, or is past the
                // end of the table selects the default target.
                let target = c
                    .as_i32()
                    .and_then(|i| usize::try_from(i).ok())
                    .and_then(|idx| targets.get(idx).copied())
                    .unwrap_or(*default);
                self.add_cfg_edge(block_id, target);
            }
            ScalarValue::Top => {
                // Not resolved yet: stay optimistic and add no edges, exactly
                // like the `Branch` case above.
            }
            ScalarValue::Bottom => {
                for &target in targets {
                    self.add_cfg_edge(block_id, target);
                }
                self.add_cfg_edge(block_id, *default);
            }
        },
        Some(SsaOp::Jump { target }) => {
            self.add_cfg_edge(block_id, *target);
        }
        Some(SsaOp::Return { .. } | SsaOp::Throw { .. } | SsaOp::Rethrow) => {
            // These terminators queue no edges.
        }
        _ => {
            // No recognised terminator: fall back to the CFG's successor list.
            let node = NodeId::new(block_id);
            for succ in cfg.successors(node) {
                self.add_cfg_edge(block_id, succ.index());
            }
        }
    }
}
/// Queues the edge `from -> to` unless it has already been processed.
///
/// Only `executable_edges` is consulted, so an edge that is still waiting
/// in the worklist may be queued again; `propagate` ignores duplicates.
fn add_cfg_edge(&mut self, from: usize, to: usize) {
    let edge = (from, to);
    if self.executable_edges.contains(&edge) {
        return;
    }
    self.cfg_worklist.push_back(edge);
}
/// Computes the meet of all phi operands that arrive over executable edges
/// into `block_id`.
///
/// Operands on non-executable edges are ignored. An operand arriving over
/// an edge recorded in `back_edges` contributes `Bottom` regardless of its
/// variable's value. With no executable operand the result is `Top`.
fn evaluate_phi(&self, phi: &PhiNode, block_id: usize) -> ScalarValue {
    // Starts at `Top`, so it is still `Top` if no operand is live.
    let mut merged = ScalarValue::Top;

    for operand in phi.operands() {
        let edge = (operand.predecessor(), block_id);
        if !self.executable_edges.contains(&edge) {
            continue;
        }

        let incoming = if self.back_edges.contains(&edge) {
            ScalarValue::Bottom
        } else {
            self.get_value(operand.value())
        };

        merged = merged.meet(&incoming);
        if merged.is_bottom() {
            // Nothing can raise the value again; stop early.
            break;
        }
    }

    merged
}
fn evaluate_instruction(&self, op: &SsaOp) -> ScalarValue {
match op {
SsaOp::Const { value, .. } => ScalarValue::Constant(value.clone()),
SsaOp::Copy { src, .. } => self.get_value(*src),
SsaOp::Add { left, right, .. } => {
self.evaluate_binary(*left, *right, |a, b| a.add(b, self.pointer_size))
}
SsaOp::Sub { left, right, .. } => {
self.evaluate_binary(*left, *right, |a, b| a.sub(b, self.pointer_size))
}
SsaOp::Mul { left, right, .. } => {
self.evaluate_binary(*left, *right, |a, b| a.mul(b, self.pointer_size))
}
SsaOp::Div { left, right, .. } => {
self.evaluate_binary(*left, *right, |a, b| a.div(b, self.pointer_size))
}
SsaOp::Rem { left, right, .. } => {
self.evaluate_binary(*left, *right, |a, b| a.rem(b, self.pointer_size))
}
SsaOp::And { left, right, .. } => {
self.evaluate_binary(*left, *right, |a, b| a.bitwise_and(b, self.pointer_size))
}
SsaOp::Or { left, right, .. } => {
self.evaluate_binary(*left, *right, |a, b| a.bitwise_or(b, self.pointer_size))
}
SsaOp::Xor { left, right, .. } => {
self.evaluate_binary(*left, *right, |a, b| a.bitwise_xor(b, self.pointer_size))
}
SsaOp::Shl { value, amount, .. } => {
self.evaluate_binary(*value, *amount, |a, b| a.shl(b, self.pointer_size))
}
SsaOp::Shr {
value,
amount,
unsigned,
..
} => {
let unsigned = *unsigned;
self.evaluate_binary(*value, *amount, |l, r| {
l.shr(r, unsigned, self.pointer_size)
})
}
SsaOp::Ceq { left, right, .. } => self.evaluate_binary(*left, *right, ConstValue::ceq),
SsaOp::Clt { left, right, .. } => self.evaluate_binary(*left, *right, ConstValue::clt),
SsaOp::Cgt { left, right, .. } => self.evaluate_binary(*left, *right, ConstValue::cgt),
SsaOp::Neg { operand, .. } => match self.get_value(*operand) {
ScalarValue::Top => ScalarValue::Top,
ScalarValue::Constant(c) => c
.negate(self.pointer_size)
.map_or(ScalarValue::Bottom, ScalarValue::Constant),
ScalarValue::Bottom => ScalarValue::Bottom,
},
SsaOp::Not { operand, .. } => match self.get_value(*operand) {
ScalarValue::Top => ScalarValue::Top,
ScalarValue::Constant(c) => c
.bitwise_not(self.pointer_size)
.map_or(ScalarValue::Bottom, ScalarValue::Constant),
ScalarValue::Bottom => ScalarValue::Bottom,
},
_ => ScalarValue::Bottom,
}
}
/// Lifts a binary constant fold `f` to lattice values.
///
/// If either operand is `Top` the result is `Top`. If both are constants
/// the result is `f`'s output, or `Bottom` when `f` returns `None`. In all
/// remaining cases the result is `Bottom`.
fn evaluate_binary<F>(&self, left: SsaVarId, right: SsaVarId, f: F) -> ScalarValue
where
    F: FnOnce(&ConstValue, &ConstValue) -> Option<ConstValue>,
{
    let lhs = self.get_value(left);
    let rhs = self.get_value(right);

    // `Top` dominates: an unresolved operand keeps the result unresolved.
    if lhs.is_top() || rhs.is_top() {
        return ScalarValue::Top;
    }

    match (lhs.as_constant(), rhs.as_constant()) {
        (Some(l), Some(r)) => f(l, r).map_or(ScalarValue::Bottom, ScalarValue::Constant),
        _ => ScalarValue::Bottom,
    }
}
fn get_value(&self, var: SsaVarId) -> ScalarValue {
self.values.get(&var).cloned().unwrap_or_default()
}
/// Merges `new_value` into the value recorded for `var`.
///
/// The stored value becomes the meet of the old and new values. Only when
/// that differs from the old value is it written back and `var` pushed onto
/// the SSA worklist so its uses get re-evaluated.
fn update_value(&mut self, var: SsaVarId, new_value: &ScalarValue) {
    let current = self.get_value(var);
    let lowered = current.meet(new_value);
    if lowered == current {
        return;
    }
    self.values.insert(var, lowered);
    self.ssa_worklist.push_back(var);
}
}
/// Lattice element tracked for each SSA variable.
///
/// The meet implemented for this type treats `Top` as the identity, keeps
/// two equal constants as that constant, and sends everything else to
/// `Bottom`.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum ScalarValue {
/// No information yet; this is the default value and the identity of meet.
#[default]
Top,
/// Known to hold exactly this constant.
Constant(ConstValue),
/// Not a single known constant; absorbing element of meet.
Bottom,
}
impl ScalarValue {
    /// Returns `true` for the `Top` variant.
    #[must_use]
    pub const fn is_top(&self) -> bool {
        matches!(self, Self::Top)
    }

    /// Returns `true` for the `Bottom` variant.
    #[must_use]
    pub const fn is_bottom(&self) -> bool {
        matches!(self, Self::Bottom)
    }

    /// Returns `true` when a constant is held.
    #[must_use]
    pub const fn is_constant(&self) -> bool {
        matches!(self, Self::Constant(_))
    }

    /// Borrows the held constant, or returns `None` for `Top` and `Bottom`.
    #[must_use]
    pub const fn as_constant(&self) -> Option<&ConstValue> {
        if let Self::Constant(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
impl MeetSemiLattice for ScalarValue {
    /// Meet of two lattice values.
    ///
    /// `Top` is the identity, two equal constants meet to that constant, and
    /// every other combination is `Bottom`.
    fn meet(&self, other: &Self) -> Self {
        match (self, other) {
            (Self::Top, rhs) => rhs.clone(),
            (lhs, Self::Top) => lhs.clone(),
            (Self::Constant(a), Self::Constant(b)) if a == b => self.clone(),
            _ => Self::Bottom,
        }
    }

    /// Returns `true` for the `Bottom` variant.
    fn is_bottom(&self) -> bool {
        matches!(self, Self::Bottom)
    }
}
/// Snapshot of a finished constant-propagation run.
#[derive(Debug, Clone)]
pub struct SccpResult {
// Final lattice value of each SSA variable.
values: HashMap<SsaVarId, ScalarValue>,
// Blocks that were found executable.
executable_blocks: HashSet<usize>,
}
impl SccpResult {
    /// Creates a result with no variable values and no executable blocks.
    #[must_use]
    pub fn empty() -> Self {
        Self {
            values: HashMap::default(),
            executable_blocks: HashSet::default(),
        }
    }

    /// Looks up the lattice value recorded for `var`, if any.
    #[must_use]
    pub fn get_value(&self, var: SsaVarId) -> Option<&ScalarValue> {
        self.values.get(&var)
    }

    /// Returns `true` when `var` has a recorded value and it is a constant.
    #[must_use]
    pub fn is_constant(&self, var: SsaVarId) -> bool {
        matches!(self.values.get(&var), Some(ScalarValue::Constant(_)))
    }

    /// Returns the constant recorded for `var`, or `None` when the variable
    /// is unknown or not constant.
    #[must_use]
    pub fn constant_value(&self, var: SsaVarId) -> Option<&ConstValue> {
        self.values.get(&var).and_then(ScalarValue::as_constant)
    }

    /// Returns `true` when `block` was found executable.
    #[must_use]
    pub fn is_block_executable(&self, block: usize) -> bool {
        self.executable_blocks.contains(&block)
    }

    /// Iterates over every variable that holds a constant, paired with that
    /// constant. Iteration order is that of the underlying hash map.
    pub fn constants(&self) -> impl Iterator<Item = (SsaVarId, &ConstValue)> {
        self.values
            .iter()
            .filter_map(|(var, val)| val.as_constant().map(|c| (*var, c)))
    }

    /// Iterates over the ids of all executable blocks, in hash-set order.
    pub fn executable_blocks(&self) -> impl Iterator<Item = usize> + '_ {
        self.executable_blocks.iter().copied()
    }

    /// Number of variables whose recorded value is a constant.
    #[must_use]
    pub fn constant_count(&self) -> usize {
        self.constants().count()
    }

    /// Number of executable blocks.
    #[must_use]
    pub fn executable_block_count(&self) -> usize {
        self.executable_blocks.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the lattice element holding a constant `i32`.
    fn int(value: i32) -> ScalarValue {
        ScalarValue::Constant(ConstValue::I32(value))
    }

    /// Meet: `Top` is the identity, equal constants survive, everything else
    /// collapses to `Bottom`.
    #[test]
    fn test_scalar_value_meet() {
        // (left operand, right operand, expected meet)
        let cases = [
            (ScalarValue::Top, int(5), int(5)),
            (int(5), int(5), int(5)),
            (int(5), int(10), ScalarValue::Bottom),
            (ScalarValue::Bottom, int(5), ScalarValue::Bottom),
        ];
        for (lhs, rhs, expected) in &cases {
            assert_eq!(lhs.meet(rhs), *expected);
        }
    }

    /// Each variant answers `true` to exactly one of the three predicates.
    #[test]
    fn test_scalar_value_accessors() {
        // (value, [is_top, is_constant, is_bottom])
        let probes = [
            (ScalarValue::Top, [true, false, false]),
            (int(42), [false, true, false]),
            (ScalarValue::Bottom, [false, false, true]),
        ];
        for (value, [top, constant, bottom]) in &probes {
            assert_eq!(value.is_top(), *top);
            assert_eq!(value.is_constant(), *constant);
            assert_eq!(value.is_bottom(), *bottom);
        }

        assert_eq!(int(42).as_constant(), Some(&ConstValue::I32(42)));
    }

    /// Query helpers on a hand-built result.
    #[test]
    fn test_sccp_result() {
        let constant_var = SsaVarId::new();
        let bottom_var = SsaVarId::new();
        let top_var = SsaVarId::new();

        let result = SccpResult {
            values: HashMap::from([
                (constant_var, int(42)),
                (bottom_var, ScalarValue::Bottom),
                (top_var, ScalarValue::Top),
            ]),
            executable_blocks: HashSet::from([0, 1]),
        };

        assert!(result.is_constant(constant_var));
        assert!(!result.is_constant(bottom_var));
        assert!(!result.is_constant(top_var));

        assert_eq!(
            result.constant_value(constant_var),
            Some(&ConstValue::I32(42))
        );
        assert_eq!(result.constant_value(bottom_var), None);

        assert!(result.is_block_executable(0));
        assert!(result.is_block_executable(1));
        assert!(!result.is_block_executable(2));

        assert_eq!(result.constant_count(), 1);
        assert_eq!(result.executable_block_count(), 2);
    }
}