use crate::{BlockId, MirBody, MirExpr, MirStmt};
/// Control-flow graph of one MIR body, as built by [`CfgBuilder::build`].
#[derive(Debug, Clone)]
pub struct MirCfg {
    /// All blocks of the graph. The accessors on `MirCfg` look blocks up by
    /// `BlockId.0`, so block `i` is expected to live at position `i`.
    pub basic_blocks: Vec<BasicBlock>,
    /// The block execution starts in.
    pub entry: BlockId,
}
impl MirCfg {
    /// Returns the block execution starts in.
    ///
    /// # Panics
    ///
    /// Panics if `self.entry` is not a valid index into `basic_blocks`.
    pub fn entry_block(&self) -> &BasicBlock {
        self.block(self.entry)
    }

    /// Returns the block with the given id, looked up by index `id.0`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range.
    pub fn block(&self, id: BlockId) -> &BasicBlock {
        &self.basic_blocks[id.0 as usize]
    }

    /// Returns the blocks that `id`'s terminator may jump to.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range.
    pub fn successors(&self, id: BlockId) -> Vec<BlockId> {
        self.block(id).terminator.successors()
    }

    /// Computes, for every block index, the blocks that jump to it.
    ///
    /// A predecessor is listed once per edge, so a `Branch` whose two targets
    /// coincide contributes twice.
    ///
    /// # Panics
    ///
    /// Panics if a terminator names a block outside `basic_blocks`.
    pub fn predecessors(&self) -> Vec<Vec<BlockId>> {
        let mut preds: Vec<Vec<BlockId>> = vec![Vec::new(); self.basic_blocks.len()];
        for bb in &self.basic_blocks {
            for succ in bb.terminator.successors() {
                preds[succ.0 as usize].push(bb.id);
            }
        }
        preds
    }

    /// Returns `true` if `id` is the target of a back edge.
    ///
    /// A back edge is an edge to a block that is still on the depth-first
    /// search stack when the edge is followed. The search starts at `entry`
    /// and then sweeps blocks unreachable from it (e.g. code after `return`),
    /// so loops in dead code are detected too. For the structured graphs
    /// `CfgBuilder` produces, these targets are exactly the loop headers.
    ///
    /// Block ids cannot be used as a shortcut ("some predecessor has a larger
    /// id"). `CfgBuilder` allocates an `if`'s merge block before the `else`
    /// block and before any block nested in the arms, so forward edges into a
    /// merge block can come from higher ids.
    ///
    /// Runs in O(blocks + edges) per call.
    ///
    /// # Panics
    ///
    /// Panics if `id` or `entry` is out of range, or if a terminator names a
    /// block outside `basic_blocks`.
    pub fn is_loop_header(&self, id: BlockId) -> bool {
        #[derive(Clone, Copy)]
        enum Visit {
            New,
            Active,
            Done,
        }

        let n = self.basic_blocks.len();
        assert!((id.0 as usize) < n, "block {:?} is out of range", id);

        let mut state = vec![Visit::New; n];
        let roots = std::iter::once(self.entry).chain(self.basic_blocks.iter().map(|bb| bb.id));
        for root in roots {
            if !matches!(state[root.0 as usize], Visit::New) {
                continue;
            }
            state[root.0 as usize] = Visit::Active;
            // Each frame: (block, its successors, index of the next successor to visit).
            let mut stack = vec![(root, self.successors(root), 0usize)];
            while let Some((node, succs, next)) = stack.last_mut() {
                match succs.get(*next).copied() {
                    Some(succ) => {
                        *next += 1;
                        match state[succ.0 as usize] {
                            Visit::Active => {
                                // `succ` is an ancestor on the DFS stack: back edge.
                                if succ == id {
                                    return true;
                                }
                            }
                            Visit::New => {
                                state[succ.0 as usize] = Visit::Active;
                                stack.push((succ, self.successors(succ), 0));
                            }
                            Visit::Done => {}
                        }
                    }
                    None => {
                        state[node.0 as usize] = Visit::Done;
                        stack.pop();
                    }
                }
            }
        }
        false
    }
}
/// A straight-line run of statements that ends in a single [`Terminator`].
#[derive(Debug, Clone)]
pub struct BasicBlock {
    /// This block's id. `CfgBuilder` hands out ids in allocation order, so the
    /// id equals the block's index in [`MirCfg::basic_blocks`].
    pub id: BlockId,
    /// Statements run in order; they contain no control flow of their own.
    pub statements: Vec<CfgStmt>,
    /// How control leaves the block after the statements have run.
    pub terminator: Terminator,
}
/// A statement stored inside a [`BasicBlock`] (control flow lives in the
/// block's [`Terminator`] instead).
#[derive(Debug, Clone)]
pub enum CfgStmt {
    /// A local binding copied from `MirStmt::Let`; `CfgBuilder` does not carry
    /// over the MIR statement's remaining fields (such as `alloc_hint`).
    Let {
        /// Name of the bound local.
        name: String,
        /// Whether the binding was declared mutable.
        mutable: bool,
        /// Initializer expression.
        init: MirExpr,
    },
    /// An expression evaluated as a statement; the CFG does not record its value.
    Expr(MirExpr),
}
/// How control leaves a [`BasicBlock`].
#[derive(Debug, Clone)]
pub enum Terminator {
    /// Unconditional jump to the given block.
    Goto(BlockId),
    /// Two-way branch on `cond`. `CfgBuilder` places an `if`'s then-arm and a
    /// loop's body in `then_block`, and the else-arm / loop exit in `else_block`.
    Branch {
        /// Branch condition.
        cond: MirExpr,
        /// Target taken when `cond` holds.
        then_block: BlockId,
        /// Target taken otherwise.
        else_block: BlockId,
    },
    /// Leaves the function, optionally yielding a value.
    Return(Option<MirExpr>),
    /// Has no successors. `CfgBuilder` also uses it as the placeholder for
    /// blocks it has not terminated yet.
    Unreachable,
}
impl Terminator {
    /// Lists the blocks control may move to once this terminator executes.
    ///
    /// `Goto` yields its single target. `Branch` yields `then_block` followed
    /// by `else_block`. `Return` and `Unreachable` yield an empty list.
    pub fn successors(&self) -> Vec<BlockId> {
        match *self {
            Terminator::Goto(target) => vec![target],
            Terminator::Branch {
                then_block,
                else_block,
                ..
            } => vec![then_block, else_block],
            Terminator::Return(_) | Terminator::Unreachable => Vec::new(),
        }
    }
}
pub struct CfgBuilder {
blocks: Vec<BasicBlock>,
next_block: u32,
}
impl CfgBuilder {
pub fn build(body: &MirBody) -> MirCfg {
let mut builder = CfgBuilder {
blocks: Vec::new(),
next_block: 0,
};
let entry = builder.new_block();
let (current, stmts, result_expr) = builder.lower_body(body, entry);
builder.blocks[current.0 as usize].statements = stmts;
let terminator = if let Some(expr) = result_expr {
Terminator::Return(Some(expr))
} else {
Terminator::Return(None)
};
builder.blocks[current.0 as usize].terminator = terminator;
MirCfg {
basic_blocks: builder.blocks,
entry,
}
}
fn new_block(&mut self) -> BlockId {
let id = BlockId(self.next_block);
self.next_block += 1;
self.blocks.push(BasicBlock {
id,
statements: Vec::new(),
terminator: Terminator::Unreachable, });
id
}
fn lower_body(
&mut self,
body: &MirBody,
start: BlockId,
) -> (BlockId, Vec<CfgStmt>, Option<MirExpr>) {
let mut current = start;
let mut stmts: Vec<CfgStmt> = Vec::new();
for stmt in &body.stmts {
match stmt {
MirStmt::Let { name, mutable, init, .. } => {
stmts.push(CfgStmt::Let {
name: name.clone(),
mutable: *mutable,
init: init.clone(),
});
}
MirStmt::Expr(expr) => {
stmts.push(CfgStmt::Expr(expr.clone()));
}
MirStmt::Return(opt_expr) => {
self.blocks[current.0 as usize].statements =
std::mem::take(&mut stmts);
self.blocks[current.0 as usize].terminator =
Terminator::Return(opt_expr.clone());
current = self.new_block();
}
MirStmt::If { cond, then_body, else_body } => {
let then_id = self.new_block();
let merge_id = self.new_block();
let else_id = if else_body.is_some() {
self.new_block()
} else {
merge_id
};
self.blocks[current.0 as usize].statements =
std::mem::take(&mut stmts);
self.blocks[current.0 as usize].terminator =
Terminator::Branch {
cond: cond.clone(),
then_block: then_id,
else_block: else_id,
};
let (then_end, then_stmts, then_result) =
self.lower_body(then_body, then_id);
self.blocks[then_end.0 as usize].statements = then_stmts;
if matches!(
self.blocks[then_end.0 as usize].terminator,
Terminator::Unreachable
) {
if let Some(expr) = then_result {
self.blocks[then_end.0 as usize]
.statements
.push(CfgStmt::Expr(expr));
}
self.blocks[then_end.0 as usize].terminator =
Terminator::Goto(merge_id);
}
if let Some(else_b) = else_body {
let (else_end, else_stmts, else_result) =
self.lower_body(else_b, else_id);
self.blocks[else_end.0 as usize].statements = else_stmts;
if matches!(
self.blocks[else_end.0 as usize].terminator,
Terminator::Unreachable
) {
if let Some(expr) = else_result {
self.blocks[else_end.0 as usize]
.statements
.push(CfgStmt::Expr(expr));
}
self.blocks[else_end.0 as usize].terminator =
Terminator::Goto(merge_id);
}
}
current = merge_id;
}
MirStmt::While { cond, body: loop_body } => {
let header_id = self.new_block();
let body_id = self.new_block();
let exit_id = self.new_block();
self.blocks[current.0 as usize].statements =
std::mem::take(&mut stmts);
self.blocks[current.0 as usize].terminator =
Terminator::Goto(header_id);
self.blocks[header_id.0 as usize].terminator =
Terminator::Branch {
cond: cond.clone(),
then_block: body_id,
else_block: exit_id,
};
let (body_end, body_stmts, _body_result) =
self.lower_body(loop_body, body_id);
self.blocks[body_end.0 as usize].statements = body_stmts;
if matches!(
self.blocks[body_end.0 as usize].terminator,
Terminator::Unreachable
) {
self.blocks[body_end.0 as usize].terminator =
Terminator::Goto(header_id); }
current = exit_id;
}
MirStmt::Break | MirStmt::Continue => {}
MirStmt::NoGcBlock(inner_body) => {
let (next_current, inner_stmts, inner_result) =
self.lower_body(inner_body, current);
if next_current != current {
let mut combined = stmts;
combined.extend(inner_stmts);
stmts = combined;
current = next_current;
} else {
stmts.extend(inner_stmts);
}
if let Some(expr) = inner_result {
stmts.push(CfgStmt::Expr(expr));
}
}
}
}
let tail = body.result.as_ref().map(|e| *e.clone());
(current, stmts, tail)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MirExpr, MirExprKind, MirStmt};

    /// Integer literal expression.
    fn int_expr(v: i64) -> MirExpr {
        MirExpr {
            kind: MirExprKind::IntLit(v),
        }
    }

    /// Boolean literal expression.
    fn bool_expr(b: bool) -> MirExpr {
        MirExpr {
            kind: MirExprKind::BoolLit(b),
        }
    }

    /// Body with no statements and no tail expression.
    fn empty_body() -> MirBody {
        MirBody {
            stmts: vec![],
            result: None,
        }
    }

    #[test]
    fn test_cfg_straight_line_entry_block() {
        let body = MirBody {
            stmts: vec![MirStmt::Let {
                name: "x".into(),
                mutable: false,
                init: int_expr(42),
                alloc_hint: None,
            }],
            result: Some(Box::new(int_expr(42))),
        };
        let cfg = CfgBuilder::build(&body);
        assert!(!cfg.basic_blocks.is_empty(), "should have at least entry block");
        assert_eq!(cfg.entry, BlockId(0));
        // With no control flow, the `let` stays in the entry block.
        assert_eq!(cfg.entry_block().statements.len(), 1);
        match &cfg.entry_block().terminator {
            Terminator::Return(_) => {}
            other => panic!("expected Return, got {:?}", other),
        }
    }

    #[test]
    fn test_cfg_if_creates_branch_terminator() {
        let body = MirBody {
            stmts: vec![MirStmt::If {
                cond: bool_expr(true),
                then_body: MirBody {
                    stmts: vec![],
                    result: Some(Box::new(int_expr(1))),
                },
                else_body: Some(MirBody {
                    stmts: vec![],
                    result: Some(Box::new(int_expr(2))),
                }),
            }],
            result: None,
        };
        let cfg = CfgBuilder::build(&body);
        assert!(cfg.basic_blocks.len() >= 3, "if/else should produce >= 3 blocks");
        match &cfg.entry_block().terminator {
            Terminator::Branch { then_block, else_block, .. } => {
                assert_ne!(then_block, else_block, "then and else blocks must be distinct");
            }
            other => panic!("entry block should have Branch terminator, got {:?}", other),
        }
    }

    #[test]
    fn test_cfg_if_without_else_branches_to_merge() {
        let body = MirBody {
            stmts: vec![MirStmt::If {
                cond: bool_expr(true),
                then_body: empty_body(),
                else_body: None,
            }],
            result: None,
        };
        let cfg = CfgBuilder::build(&body);
        let (then_block, merge_block) = match &cfg.entry_block().terminator {
            Terminator::Branch { then_block, else_block, .. } => (*then_block, *else_block),
            other => panic!("entry block should have Branch terminator, got {:?}", other),
        };
        assert_ne!(then_block, merge_block);
        assert!(
            matches!(&cfg.block(then_block).terminator, Terminator::Goto(t) if *t == merge_block),
            "then block should fall through to the merge block"
        );
        assert!(matches!(
            cfg.block(merge_block).terminator,
            Terminator::Return(None)
        ));
    }

    #[test]
    fn test_cfg_while_creates_back_edge() {
        let body = MirBody {
            stmts: vec![MirStmt::While {
                cond: bool_expr(false),
                body: empty_body(),
            }],
            result: None,
        };
        let cfg = CfgBuilder::build(&body);
        assert!(cfg.basic_blocks.len() >= 3, "while should produce >= 3 blocks");
        // Derive the header from the graph instead of assuming its numbering.
        let header = match &cfg.entry_block().terminator {
            Terminator::Goto(h) => *h,
            other => panic!("entry should jump to the loop header, got {:?}", other),
        };
        let loop_body = match &cfg.block(header).terminator {
            Terminator::Branch { then_block, .. } => *then_block,
            other => panic!("loop header should branch, got {:?}", other),
        };
        assert!(
            matches!(&cfg.block(loop_body).terminator, Terminator::Goto(t) if *t == header),
            "loop body should jump back to the header"
        );
        assert!(
            cfg.is_loop_header(header),
            "block {:?} should be detected as a loop header",
            header
        );
    }

    #[test]
    fn test_cfg_return_terminates_block() {
        let body = MirBody {
            stmts: vec![MirStmt::Return(Some(int_expr(99)))],
            result: None,
        };
        let cfg = CfgBuilder::build(&body);
        match &cfg.entry_block().terminator {
            Terminator::Return(Some(expr)) => {
                assert!(matches!(expr.kind, MirExprKind::IntLit(99)));
            }
            other => panic!("expected Return(99), got {:?}", other),
        }
    }

    #[test]
    fn test_cfg_predecessors_entry_has_no_preds() {
        let body = MirBody {
            stmts: vec![],
            result: Some(Box::new(int_expr(0))),
        };
        let cfg = CfgBuilder::build(&body);
        let preds = cfg.predecessors();
        assert_eq!(
            preds[cfg.entry.0 as usize].len(),
            0,
            "entry block should have no predecessors"
        );
    }

    #[test]
    fn test_cfg_goto_terminator_successors() {
        let term = Terminator::Goto(BlockId(5));
        let succs = term.successors();
        assert_eq!(succs, vec![BlockId(5)]);
    }

    #[test]
    fn test_cfg_return_has_no_successors() {
        let term = Terminator::Return(None);
        assert!(term.successors().is_empty());
    }

    #[test]
    fn test_cfg_branch_has_two_successors() {
        let term = Terminator::Branch {
            cond: bool_expr(true),
            then_block: BlockId(2),
            else_block: BlockId(3),
        };
        let succs = term.successors();
        assert_eq!(succs.len(), 2);
        assert!(succs.contains(&BlockId(2)));
        assert!(succs.contains(&BlockId(3)));
    }
}