use std::collections::{BTreeMap, BTreeSet};
use crate::cfg::{CfgStmt, MirCfg};
use crate::dominators::DominatorTree;
use crate::BlockId;
use crate::MirExprKind;
/// A source-level variable name paired with an SSA version number.
///
/// Ordering and equality are derived, so values compare by `name` first and
/// then by `version`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct SsaVar {
    /// Name of the underlying variable.
    pub name: String,
    /// Version distinguishing this definition from other definitions of `name`.
    pub version: u32,
}

impl std::fmt::Display for SsaVar {
    /// Renders the variable as `<name>_<version>`, for example `x_3`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let SsaVar { name, version } = self;
        f.write_fmt(format_args!("{}_{}", name, version))
    }
}
/// A phi node merging the versions of one variable that reach a block.
#[derive(Debug, Clone)]
pub struct PhiNode {
/// SSA variable defined by this phi.
pub target: SsaVar,
/// Incoming `(predecessor block, variable version)` pairs.
///
/// `SsaForm::construct` allocates one slot per predecessor, each initially
/// holding a `BlockId(u32::MAX)` placeholder, fills slots during renaming and
/// finally sorts them by block id.
pub sources: Vec<(BlockId, SsaVar)>,
}
/// Result of SSA construction over a `MirCfg`: phi placement plus version numbering.
#[derive(Debug, Clone)]
pub struct SsaForm {
/// Phi nodes per block, indexed by `BlockId.0`.
pub phis: Vec<Vec<PhiNode>>,
/// Number of basic blocks in the CFG this form was built from.
pub num_blocks: usize,
/// Entry block, copied from the CFG.
pub entry: BlockId,
/// Per variable, how many versions were handed out (equivalently, the next unused version).
pub version_counts: BTreeMap<String, u32>,
/// Version assigned to each definition, keyed by
/// `(block id, variable name, statement index within the block)`.
pub def_versions: BTreeMap<(u32, String, usize), u32>,
/// Parameter names passed to `SsaForm::construct`, in the order given.
pub params: Vec<String>,
}
/// Returns the names of all variables defined anywhere in `cfg`.
///
/// A statement counts as a definition when it is a `let`, or an expression
/// statement whose top-level expression assigns to a plain variable.
fn collect_defined_vars(cfg: &MirCfg) -> BTreeSet<String> {
    cfg.basic_blocks
        .iter()
        .flat_map(|block| block.statements.iter())
        .filter_map(|stmt| match stmt {
            CfgStmt::Let { name, .. } => Some(name),
            CfgStmt::Expr(expr) => match &expr.kind {
                MirExprKind::Assign { target, .. } => match &target.kind {
                    MirExprKind::Var(name) => Some(name),
                    _ => None,
                },
                _ => None,
            },
        })
        .cloned()
        .collect()
}
/// Maps every name in `variables` to the set of blocks that define it.
///
/// Each variable gets an entry even when no block defines it. Definitions of
/// names that are not in `variables` are ignored. What counts as a definition
/// is the same as in `collect_defined_vars`: a `let`, or an expression
/// statement assigning to a plain variable.
fn collect_def_blocks(
    cfg: &MirCfg,
    variables: &BTreeSet<String>,
) -> BTreeMap<String, BTreeSet<BlockId>> {
    let mut def_blocks: BTreeMap<String, BTreeSet<BlockId>> = variables
        .iter()
        .map(|var| (var.clone(), BTreeSet::new()))
        .collect();

    for block in &cfg.basic_blocks {
        for stmt in &block.statements {
            let defined = match stmt {
                CfgStmt::Let { name, .. } => Some(name),
                CfgStmt::Expr(expr) => match &expr.kind {
                    MirExprKind::Assign { target, .. } => match &target.kind {
                        MirExprKind::Var(name) => Some(name),
                        _ => None,
                    },
                    _ => None,
                },
            };
            if let Some(blocks) = defined.and_then(|name| def_blocks.get_mut(name)) {
                blocks.insert(block.id);
            }
        }
    }

    def_blocks
}
/// Allocates the next unused version of `name` and makes it the current one.
///
/// The counter for `name` is read and then incremented; the value read is
/// pushed onto `name`'s version stack and returned.
///
/// # Panics
///
/// Panics if `name` has no entry in `counters` or in `stacks`.
fn fresh_version(
    name: &str,
    counters: &mut BTreeMap<String, u32>,
    stacks: &mut BTreeMap<String, Vec<u32>>,
) -> u32 {
    let assigned = {
        let next = counters.get_mut(name).unwrap();
        let taken = *next;
        *next = taken + 1;
        taken
    };
    let stack = stacks.get_mut(name).unwrap();
    stack.push(assigned);
    assigned
}
/// Returns the version on top of `name`'s stack.
///
/// Yields `None` when `name` has no stack or its stack is empty.
fn current_version(name: &str, stacks: &BTreeMap<String, Vec<u32>>) -> Option<u32> {
    let stack = stacks.get(name)?;
    stack.last().copied()
}
/// Renames definitions in `block_id` and, recursively, in every block it
/// dominates.
///
/// For the visited block this:
/// 1. gives each phi target a fresh version,
/// 2. gives each `let` / variable assignment a fresh version and records it in
///    `def_versions` under `(block id, name, statement index)`,
/// 3. writes the currently visible version of each variable into the matching
///    source slot of the phis of every CFG successor,
/// 4. recurses into the dominator-tree children,
/// 5. pops every version pushed in steps 1 and 2, restoring `stacks`.
///
/// A successor phi slot is left untouched when the variable has no visible
/// version on this path.
fn rename_dfs(
    block_id: BlockId,
    cfg: &MirCfg,
    domtree: &DominatorTree,
    preds: &[Vec<BlockId>],
    counters: &mut BTreeMap<String, u32>,
    stacks: &mut BTreeMap<String, Vec<u32>>,
    result_phis: &mut [Vec<PhiNode>],
    def_versions: &mut BTreeMap<(u32, String, usize), u32>,
) {
    let index = block_id.0 as usize;
    // One entry per version pushed while processing this block; each entry is
    // undone by exactly one pop before returning.
    let mut pushed: Vec<String> = Vec::new();

    for phi in result_phis[index].iter_mut() {
        phi.target.version = fresh_version(&phi.target.name, counters, stacks);
        pushed.push(phi.target.name.clone());
    }

    for (idx, stmt) in cfg.basic_blocks[index].statements.iter().enumerate() {
        let defined = match stmt {
            CfgStmt::Let { name, .. } => Some(name),
            CfgStmt::Expr(expr) => match &expr.kind {
                MirExprKind::Assign { target, .. } => match &target.kind {
                    MirExprKind::Var(name) => Some(name),
                    _ => None,
                },
                _ => None,
            },
        };
        if let Some(name) = defined {
            let version = fresh_version(name, counters, stacks);
            def_versions.insert((block_id.0, name.clone(), idx), version);
            pushed.push(name.clone());
        }
    }

    for succ_id in cfg.successors(block_id) {
        let succ = succ_id.0 as usize;
        // Phi sources are laid out in predecessor order, so the slot to fill
        // is this block's position in the successor's predecessor list.
        let slot = match preds[succ].iter().position(|p| *p == block_id) {
            Some(slot) => slot,
            None => continue,
        };
        for phi in result_phis[succ].iter_mut() {
            let reaching = current_version(&phi.target.name, stacks);
            if let (Some(version), Some(source)) = (reaching, phi.sources.get_mut(slot)) {
                *source = (
                    block_id,
                    SsaVar {
                        name: phi.target.name.clone(),
                        version,
                    },
                );
            }
        }
    }

    for child in domtree.children(block_id) {
        rename_dfs(
            child,
            cfg,
            domtree,
            preds,
            counters,
            stacks,
            result_phis,
            def_versions,
        );
    }

    for name in &pushed {
        stacks.get_mut(name).unwrap().pop();
    }
}
impl SsaForm {
    /// Builds SSA information for `cfg`, treating `params` as variables that
    /// are already defined on entry.
    ///
    /// Phases:
    /// 1. collect every defined variable (plus `params`) and the blocks that
    ///    define each one,
    /// 2. place phis by iterating dominance frontiers from the defining blocks
    ///    (and from the entry block for parameters),
    /// 3. allocate placeholder phi nodes with one source slot per predecessor,
    /// 4. give each parameter its initial version, then rename along the
    ///    dominator tree starting at the entry block,
    /// 5. sort every phi's sources by block id.
    ///
    /// A CFG without blocks yields an empty form.
    pub fn construct(cfg: &MirCfg, params: &[String]) -> Self {
        let block_count = cfg.basic_blocks.len();
        if block_count == 0 {
            return SsaForm {
                phis: Vec::new(),
                num_blocks: 0,
                entry: cfg.entry,
                version_counts: BTreeMap::new(),
                def_versions: BTreeMap::new(),
                params: params.to_vec(),
            };
        }

        let domtree = DominatorTree::compute(cfg);
        let frontiers = domtree.dominance_frontiers(cfg);
        let preds = cfg.predecessors();

        let mut variables = collect_defined_vars(cfg);
        variables.extend(params.iter().cloned());
        let def_blocks = collect_def_blocks(cfg, &variables);

        // Phi placement: which variables need a phi at which block.
        let mut phi_vars: Vec<BTreeSet<String>> = vec![BTreeSet::new(); block_count];
        for (var, defs) in &def_blocks {
            let mut pending: Vec<BlockId> = defs.iter().copied().collect();
            if params.contains(var) {
                // Parameters behave as if defined in the entry block.
                pending.push(cfg.entry);
            }
            let mut queued: BTreeSet<u32> = pending.iter().map(|b| b.0).collect();
            let mut placed: BTreeSet<u32> = BTreeSet::new();

            while let Some(block) = pending.pop() {
                for &frontier in &frontiers[block.0 as usize] {
                    if !placed.insert(frontier.0) {
                        continue;
                    }
                    phi_vars[frontier.0 as usize].insert(var.clone());
                    // A phi is itself a definition, so its block must be
                    // processed too, but only once.
                    if queued.insert(frontier.0) {
                        pending.push(frontier);
                    }
                }
            }
        }

        // Placeholder phis; targets and sources are filled in by renaming.
        let mut result_phis: Vec<Vec<PhiNode>> = phi_vars
            .iter()
            .enumerate()
            .map(|(b, vars)| {
                vars.iter()
                    .map(|var| {
                        let placeholder = SsaVar {
                            name: var.clone(),
                            version: 0,
                        };
                        PhiNode {
                            target: placeholder.clone(),
                            sources: vec![(BlockId(u32::MAX), placeholder); preds[b].len()],
                        }
                    })
                    .collect()
            })
            .collect();

        let mut counters: BTreeMap<String, u32> =
            variables.iter().map(|var| (var.clone(), 0)).collect();
        let mut stacks: BTreeMap<String, Vec<u32>> = variables
            .iter()
            .map(|var| (var.clone(), Vec::new()))
            .collect();
        for param in params {
            // The returned version is not needed here; the call is made for
            // its effect on `counters` and `stacks`.
            fresh_version(param, &mut counters, &mut stacks);
        }

        let mut def_versions: BTreeMap<(u32, String, usize), u32> = BTreeMap::new();
        rename_dfs(
            cfg.entry,
            cfg,
            &domtree,
            &preds,
            &mut counters,
            &mut stacks,
            &mut result_phis,
            &mut def_versions,
        );

        for phi in result_phis.iter_mut().flatten() {
            phi.sources.sort_by_key(|(bid, _)| bid.0);
        }

        SsaForm {
            phis: result_phis,
            num_blocks: block_count,
            entry: cfg.entry,
            version_counts: counters,
            def_versions,
            params: params.to_vec(),
        }
    }

    /// Total number of phi nodes across all blocks.
    pub fn phi_count(&self) -> usize {
        self.phis.iter().map(Vec::len).sum()
    }

    /// Phi nodes of `block`.
    ///
    /// # Panics
    ///
    /// Panics if `block` is not a valid index into `phis`.
    pub fn block_phis(&self, block: BlockId) -> &[PhiNode] {
        let index = block.0 as usize;
        self.phis[index].as_slice()
    }

    /// Version assigned to the definition of `var` made by statement
    /// `stmt_idx` of `block`, if that statement defines `var`.
    pub fn def_version(&self, block: BlockId, var: &str, stmt_idx: usize) -> Option<u32> {
        let key = (block.0, var.to_string(), stmt_idx);
        self.def_versions.get(&key).copied()
    }

    /// Sum of the version counts of all variables.
    pub fn total_versions(&self) -> u32 {
        self.version_counts.values().copied().sum()
    }
}
/// A violation reported by `verify_ssa`.
#[derive(Debug, Clone)]
pub enum SsaError {
/// The same SSA variable (name and version) is defined more than once.
/// `block1` is where it was first seen, `block2` where it was seen again.
DuplicateDefinition {
var: SsaVar,
block1: BlockId,
block2: BlockId,
},
/// A phi in `block` has `got` sources but the block has `expected` predecessors.
PhiSourceCount {
block: BlockId,
var: SsaVar,
expected: usize,
got: usize,
},
/// A phi in `block` names `pred` as a source, but `pred` is not a
/// predecessor of `block`.
PhiInvalidPredecessor {
block: BlockId,
var: SsaVar,
pred: BlockId,
},
/// The entry block contains a phi defining `var`.
EntryHasPhi {
var: SsaVar,
},
}
impl std::fmt::Display for SsaError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SsaError::DuplicateDefinition { var, block1, block2 } => write!(
f,
"SSA: {} defined in both block {} and block {}",
var, block1.0, block2.0
),
SsaError::PhiSourceCount {
block,
var,
expected,
got,
} => write!(
f,
"SSA: phi for {} in block {} has {} sources, expected {}",
var, block.0, got, expected
),
SsaError::PhiInvalidPredecessor { block, var, pred } => write!(
f,
"SSA: phi for {} in block {} references non-predecessor block {}",
var, block.0, pred.0
),
SsaError::EntryHasPhi { var } => {
write!(f, "SSA: entry block has phi for {}", var)
}
}
}
}
/// Checks structural SSA invariants of `ssa` against `cfg`.
///
/// Checks performed:
/// * the entry block has no phi nodes,
/// * every phi has exactly one source per CFG predecessor of its block,
/// * every phi source names an actual predecessor of its block,
/// * no SSA variable (name plus version) is defined more than once, counting
///   phi targets, entries of `def_versions`, and version 0 of each parameter
///   (which is treated as defined in the entry block).
///
/// A form with no blocks (as returned by `SsaForm::construct` for an empty
/// CFG) verifies successfully instead of panicking.
///
/// # Errors
///
/// Returns every violation found, in the order the checks above are run.
pub fn verify_ssa(ssa: &SsaForm, cfg: &MirCfg) -> Result<(), Vec<SsaError>> {
    let mut errors = Vec::new();
    let preds = cfg.predecessors();

    // `get` instead of indexing: an empty form has no phi list for the entry
    // block, and that must not be a panic.
    if let Some(entry_phis) = ssa.phis.get(ssa.entry.0 as usize) {
        for phi in entry_phis {
            errors.push(SsaError::EntryHasPhi {
                var: phi.target.clone(),
            });
        }
    }

    for (b, block_phis) in ssa.phis.iter().enumerate() {
        // A block the CFG does not know about is treated as having no
        // predecessors, so any phi in it is reported rather than panicking.
        let block_preds: &[BlockId] = preds.get(b).map(|p| p.as_slice()).unwrap_or(&[]);
        let pred_count = block_preds.len();
        for phi in block_phis {
            if phi.sources.len() != pred_count {
                errors.push(SsaError::PhiSourceCount {
                    block: BlockId(b as u32),
                    var: phi.target.clone(),
                    expected: pred_count,
                    got: phi.sources.len(),
                });
            }
            for (src_block, _) in &phi.sources {
                if !block_preds.contains(src_block) {
                    errors.push(SsaError::PhiInvalidPredecessor {
                        block: BlockId(b as u32),
                        var: phi.target.clone(),
                        pred: *src_block,
                    });
                }
            }
        }
    }

    let mut def_locations: BTreeMap<SsaVar, BlockId> = BTreeMap::new();

    // Parameters own version 0 of their name at the entry block. They are
    // registered first so that a phi or statement claiming the same version is
    // reported below. `or_insert` keeps a parameter listed twice from being
    // flagged against itself.
    for p in &ssa.params {
        let var = SsaVar {
            name: p.clone(),
            version: 0,
        };
        def_locations.entry(var).or_insert(ssa.entry);
    }

    for (b, block_phis) in ssa.phis.iter().enumerate() {
        let block_id = BlockId(b as u32);
        for phi in block_phis {
            if let Some(&prev_block) = def_locations.get(&phi.target) {
                errors.push(SsaError::DuplicateDefinition {
                    var: phi.target.clone(),
                    block1: prev_block,
                    block2: block_id,
                });
            } else {
                def_locations.insert(phi.target.clone(), block_id);
            }
        }
    }

    for ((block_num, name, _idx), ver) in &ssa.def_versions {
        let var = SsaVar {
            name: name.clone(),
            version: *ver,
        };
        let block_id = BlockId(*block_num);
        if let Some(&prev_block) = def_locations.get(&var) {
            errors.push(SsaError::DuplicateDefinition {
                var,
                block1: prev_block,
                block2: block_id,
            });
        } else {
            def_locations.insert(var, block_id);
        }
    }

    if errors.is_empty() {
        Ok(())
    } else {
        Err(errors)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cfg::{BasicBlock, CfgBuilder, Terminator};
    use crate::{MirBody, MirExpr, MirExprKind, MirStmt};

    // ---- expression helpers -------------------------------------------------

    /// Integer literal expression.
    fn int_expr(v: i64) -> MirExpr {
        MirExpr {
            kind: MirExprKind::IntLit(v),
        }
    }

    /// Boolean literal expression.
    fn bool_expr(b: bool) -> MirExpr {
        MirExpr {
            kind: MirExprKind::BoolLit(b),
        }
    }

    /// Reference to the variable `name`.
    fn var_expr(name: &str) -> MirExpr {
        MirExpr {
            kind: MirExprKind::Var(name.to_string()),
        }
    }

    /// Assignment `name = value`.
    fn assign_expr(name: &str, value: MirExpr) -> MirExpr {
        MirExpr {
            kind: MirExprKind::Assign {
                target: Box::new(var_expr(name)),
                value: Box::new(value),
            },
        }
    }

    // ---- CFG fixture helpers ------------------------------------------------

    /// Basic block with the given numeric id.
    fn block(id: u32, statements: Vec<CfgStmt>, terminator: Terminator) -> BasicBlock {
        BasicBlock {
            id: BlockId(id),
            statements,
            terminator,
        }
    }

    /// `let name = value` statement with an integer initializer.
    fn let_stmt(name: &str, mutable: bool, value: i64) -> CfgStmt {
        CfgStmt::Let {
            name: name.into(),
            mutable,
            init: int_expr(value),
        }
    }

    /// `name = value` expression statement with an integer right-hand side.
    fn assign_stmt(name: &str, value: i64) -> CfgStmt {
        CfgStmt::Expr(assign_expr(name, int_expr(value)))
    }

    /// Two-way branch on the literal `true`.
    fn branch(then_block: u32, else_block: u32) -> Terminator {
        Terminator::Branch {
            cond: bool_expr(true),
            then_block: BlockId(then_block),
            else_block: BlockId(else_block),
        }
    }

    /// CFG over `basic_blocks` whose entry is block 0.
    fn cfg_from(basic_blocks: Vec<BasicBlock>) -> MirCfg {
        MirCfg {
            basic_blocks,
            entry: BlockId(0),
        }
    }

    /// Diamond: block 0 defines `x` and branches to 1 and 2, both of which
    /// reassign `x` and jump to the merge block 3.
    fn diamond_cfg() -> MirCfg {
        cfg_from(vec![
            block(0, vec![let_stmt("x", true, 0)], branch(1, 2)),
            block(1, vec![assign_stmt("x", 10)], Terminator::Goto(BlockId(3))),
            block(2, vec![assign_stmt("x", 20)], Terminator::Goto(BlockId(3))),
            block(3, vec![], Terminator::Return(None)),
        ])
    }

    /// MIR body consisting of the single statement `name = value`.
    fn assign_body(name: &str, value: i64) -> MirBody {
        MirBody {
            stmts: vec![MirStmt::Expr(assign_expr(name, int_expr(value)))],
            result: None,
        }
    }

    // ---- tests --------------------------------------------------------------

    /// A single block never needs a phi.
    #[test]
    fn test_ssa_single_block_no_phis() {
        let cfg = cfg_from(vec![block(
            0,
            vec![let_stmt("x", false, 42)],
            Terminator::Return(None),
        )]);
        let ssa = SsaForm::construct(&cfg, &[]);
        assert_eq!(ssa.phi_count(), 0, "single block should have no phis");
        assert!(verify_ssa(&ssa, &cfg).is_ok());
    }

    /// Straight-line code has no join points, hence no phis.
    #[test]
    fn test_ssa_linear_chain_no_phis() {
        let cfg = cfg_from(vec![
            block(0, vec![let_stmt("x", true, 1)], Terminator::Goto(BlockId(1))),
            block(1, vec![assign_stmt("x", 2)], Terminator::Return(None)),
        ]);
        let ssa = SsaForm::construct(&cfg, &[]);
        assert_eq!(ssa.phi_count(), 0);
        assert!(verify_ssa(&ssa, &cfg).is_ok());
    }

    /// Both arms of a diamond redefine `x`, so the merge block gets one phi.
    #[test]
    fn test_ssa_diamond_has_phi_at_merge() {
        let cfg = diamond_cfg();
        let ssa = SsaForm::construct(&cfg, &[]);
        let merge_phis = ssa.block_phis(BlockId(3));
        assert_eq!(merge_phis.len(), 1, "merge block should have 1 phi");
        assert_eq!(merge_phis[0].target.name, "x");
        assert_eq!(merge_phis[0].sources.len(), 2, "phi should have 2 sources");
        assert!(verify_ssa(&ssa, &cfg).is_ok());
    }

    /// Loop: 0 -> 1 (header) -> {2 (body) -> 1, 3 (exit)}. The body reassigns
    /// `i`, so the header needs a phi merging the entry and back-edge values.
    #[test]
    fn test_ssa_while_loop_phi_at_header() {
        let cfg = cfg_from(vec![
            block(0, vec![let_stmt("i", true, 0)], Terminator::Goto(BlockId(1))),
            block(1, vec![], branch(2, 3)),
            block(2, vec![assign_stmt("i", 1)], Terminator::Goto(BlockId(1))),
            block(3, vec![], Terminator::Return(None)),
        ]);
        let ssa = SsaForm::construct(&cfg, &[]);
        let header_phis = ssa.block_phis(BlockId(1));
        assert_eq!(header_phis.len(), 1, "loop header should have 1 phi");
        assert_eq!(header_phis[0].target.name, "i");
        assert_eq!(header_phis[0].sources.len(), 2);
        assert!(verify_ssa(&ssa, &cfg).is_ok());
    }

    /// Every parameter receives an initial version even if the body never
    /// mentions it.
    #[test]
    fn test_ssa_params_version_zero() {
        let cfg = cfg_from(vec![block(0, vec![], Terminator::Return(None))]);
        let params = vec!["a".to_string(), "b".to_string()];
        let ssa = SsaForm::construct(&cfg, &params);
        assert!(ssa.version_counts["a"] >= 1);
        assert!(ssa.version_counts["b"] >= 1);
        assert!(verify_ssa(&ssa, &cfg).is_ok());
    }

    /// Two variables redefined on both arms of a diamond yield two phis.
    #[test]
    fn test_ssa_multiple_vars_diamond() {
        let cfg = cfg_from(vec![
            block(
                0,
                vec![let_stmt("x", true, 0), let_stmt("y", true, 0)],
                branch(1, 2),
            ),
            block(
                1,
                vec![assign_stmt("x", 10), assign_stmt("y", 100)],
                Terminator::Goto(BlockId(3)),
            ),
            block(
                2,
                vec![assign_stmt("x", 20), assign_stmt("y", 200)],
                Terminator::Goto(BlockId(3)),
            ),
            block(3, vec![], Terminator::Return(None)),
        ]);
        let ssa = SsaForm::construct(&cfg, &[]);
        let merge_phis = ssa.block_phis(BlockId(3));
        assert_eq!(merge_phis.len(), 2, "merge should have 2 phis");
        let phi_names: Vec<&str> = merge_phis.iter().map(|p| p.target.name.as_str()).collect();
        assert!(phi_names.contains(&"x"));
        assert!(phi_names.contains(&"y"));
        assert!(verify_ssa(&ssa, &cfg).is_ok());
    }

    /// End-to-end: an if/else lowered by `CfgBuilder` produces a phi.
    #[test]
    fn test_ssa_from_mir_body_if_else() {
        let body = MirBody {
            stmts: vec![
                MirStmt::Let {
                    name: "x".into(),
                    mutable: true,
                    init: int_expr(0),
                    alloc_hint: None,
                },
                MirStmt::If {
                    cond: bool_expr(true),
                    then_body: assign_body("x", 1),
                    else_body: Some(assign_body("x", 2)),
                },
            ],
            result: Some(Box::new(var_expr("x"))),
        };
        let cfg = CfgBuilder::build(&body);
        let ssa = SsaForm::construct(&cfg, &[]);
        assert!(ssa.phi_count() >= 1, "should have at least 1 phi");
        assert!(verify_ssa(&ssa, &cfg).is_ok());
    }

    /// End-to-end: a while loop lowered by `CfgBuilder` produces a phi.
    #[test]
    fn test_ssa_from_mir_body_while() {
        let body = MirBody {
            stmts: vec![
                MirStmt::Let {
                    name: "i".into(),
                    mutable: true,
                    init: int_expr(0),
                    alloc_hint: None,
                },
                MirStmt::While {
                    cond: bool_expr(true),
                    body: assign_body("i", 1),
                },
            ],
            result: None,
        };
        let cfg = CfgBuilder::build(&body);
        let ssa = SsaForm::construct(&cfg, &[]);
        assert!(ssa.phi_count() >= 1, "while loop should produce at least 1 phi");
        assert!(verify_ssa(&ssa, &cfg).is_ok());
    }

    /// A CFG whose only block is empty defines nothing.
    #[test]
    fn test_ssa_empty_cfg() {
        let cfg = cfg_from(vec![block(0, vec![], Terminator::Return(None))]);
        let ssa = SsaForm::construct(&cfg, &[]);
        assert_eq!(ssa.phi_count(), 0);
        assert_eq!(ssa.total_versions(), 0);
        assert!(verify_ssa(&ssa, &cfg).is_ok());
    }

    /// A CFG with no blocks at all takes the early-return path of `construct`.
    #[test]
    fn test_ssa_zero_block_cfg() {
        let cfg = cfg_from(vec![]);
        let ssa = SsaForm::construct(&cfg, &[]);
        assert_eq!(ssa.num_blocks, 0);
        assert_eq!(ssa.phi_count(), 0);
        assert_eq!(ssa.total_versions(), 0);
    }

    /// Versions are handed out in visiting order, starting from 0.
    #[test]
    fn test_ssa_version_numbering() {
        let cfg = cfg_from(vec![
            block(0, vec![let_stmt("x", true, 0)], Terminator::Goto(BlockId(1))),
            block(1, vec![assign_stmt("x", 1)], Terminator::Return(None)),
        ]);
        let ssa = SsaForm::construct(&cfg, &[]);
        assert_eq!(ssa.version_counts["x"], 2);
        assert_eq!(ssa.def_version(BlockId(0), "x", 0), Some(0));
        assert_eq!(ssa.def_version(BlockId(1), "x", 0), Some(1));
        assert!(verify_ssa(&ssa, &cfg).is_ok());
    }

    /// The merge phi of a diamond has one source per arm, carrying different
    /// versions.
    #[test]
    fn test_ssa_phi_sources_match_predecessors() {
        let cfg = diamond_cfg();
        let ssa = SsaForm::construct(&cfg, &[]);
        let phi = &ssa.block_phis(BlockId(3))[0];
        let src_blocks: Vec<u32> = phi.sources.iter().map(|(bid, _)| bid.0).collect();
        assert!(src_blocks.contains(&1));
        assert!(src_blocks.contains(&2));
        let src_versions: Vec<u32> = phi.sources.iter().map(|(_, v)| v.version).collect();
        assert_ne!(src_versions[0], src_versions[1]);
    }

    /// Constructing twice from the same CFG gives identical results.
    #[test]
    fn test_ssa_deterministic() {
        let body = MirBody {
            stmts: vec![
                MirStmt::Let {
                    name: "a".into(),
                    mutable: true,
                    init: int_expr(0),
                    alloc_hint: None,
                },
                MirStmt::If {
                    cond: bool_expr(true),
                    then_body: assign_body("a", 1),
                    else_body: Some(assign_body("a", 2)),
                },
            ],
            result: None,
        };
        let cfg = CfgBuilder::build(&body);
        let ssa1 = SsaForm::construct(&cfg, &[]);
        let ssa2 = SsaForm::construct(&cfg, &[]);
        assert_eq!(ssa1.phi_count(), ssa2.phi_count());
        assert_eq!(ssa1.version_counts, ssa2.version_counts);
        assert_eq!(ssa1.def_versions, ssa2.def_versions);
        for (b, (p1, p2)) in ssa1.phis.iter().zip(ssa2.phis.iter()).enumerate() {
            assert_eq!(p1.len(), p2.len(), "phi count mismatch at block {}", b);
            for (phi1, phi2) in p1.iter().zip(p2.iter()) {
                assert_eq!(phi1.target, phi2.target);
                // Compare the sources themselves, not just how many there are.
                assert_eq!(phi1.sources, phi2.sources);
            }
        }
    }
}