use crate::transform::{InnerInPlaceTransform as _, Transformer};
use crate::{
AttrSet, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeDef, ControlNodeKind,
ControlNodeOutputDecl, ControlRegion, ControlRegionDef, EntityOrientedDenseMap, FuncDefBody,
FxIndexMap, FxIndexSet, SelectionKind, Type, TypeKind, Value, spv,
};
use itertools::{Either, Itertools};
use smallvec::SmallVec;
use std::mem;
use std::rc::Rc;
/// Control-flow graph (CFG) of a function body, in the form of a [`ControlInst`]
/// executed on exit from each [`ControlRegion`] that has left structured
/// control-flow.
#[derive(Clone, Default)]
pub struct ControlFlowGraph {
    /// The control-flow instruction that runs when leaving each region, and
    /// decides where execution continues.
    pub control_inst_on_exit_from: EntityOrientedDenseMap<ControlRegion, ControlInst>,

    /// Loop merge region -> loop header region (presumably derived from SPIR-V
    /// `OpLoopMerge` - TODO confirm). `LoopFinder` treats reaching a loop's
    /// merge, while its header is still pending, as also reaching the header.
    pub loop_merge_to_loop_header: FxIndexMap<ControlRegion, ControlRegion>,
}

/// A control-flow instruction, executed on exit from a [`ControlRegion`].
#[derive(Clone)]
pub struct ControlInst {
    pub attrs: AttrSet,

    pub kind: ControlInstKind,

    /// Operands of the instruction itself (e.g. the returned values for
    /// [`ControlInstKind::Return`], or the single scrutinee for
    /// [`ControlInstKind::SelectBranch`]).
    pub inputs: SmallVec<[Value; 2]>,

    /// Successor regions, in order (for `SelectBranch`, one per case).
    pub targets: SmallVec<[ControlRegion; 4]>,

    /// Values passed to each target region's inputs (the structurizer treats
    /// a target missing from this map as receiving no inputs).
    pub target_inputs: FxIndexMap<ControlRegion, SmallVec<[Value; 2]>>,
}

/// The kind of a [`ControlInst`].
#[derive(Clone)]
pub enum ControlInstKind {
    /// Execution can't reach this point (no inputs, no targets).
    Unreachable,

    /// Return from the function, with the `inputs` as the returned values.
    Return,

    /// Leave the invocation entirely (no targets), see [`ExitInvocationKind`].
    ExitInvocation(ExitInvocationKind),

    /// Unconditional branch to the only target (no inputs).
    Branch,

    /// Branch to one of the targets, selected by the only input (the scrutinee).
    SelectBranch(SelectionKind),
}

/// How a [`ControlInstKind::ExitInvocation`] leaves the invocation.
#[derive(Clone)]
pub enum ExitInvocationKind {
    /// A SPIR-V instruction (presumably e.g. `OpKill` - TODO confirm).
    SpvInst(spv::Inst),
}
impl ControlFlowGraph {
pub fn rev_post_order(
&self,
func_def_body: &FuncDefBody,
) -> impl DoubleEndedIterator<Item = ControlRegion> {
let mut post_order = SmallVec::<[_; 8]>::new();
self.traverse_whole_func(func_def_body, &mut TraversalState {
incoming_edge_counts: EntityOrientedDenseMap::new(),
pre_order_visit: |_| {},
post_order_visit: |region| post_order.push(region),
reverse_targets: true,
});
post_order.into_iter().rev()
}
}
mod sealed {
    /// Number of incoming edges of a region in the CFG.
    ///
    /// The inner `usize` is private to this module, so outside code can only
    /// obtain values via `ONE`/`Default` and combine them with `+`/`+=`.
    #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
    pub(super) struct IncomingEdgeCount(usize);

    impl IncomingEdgeCount {
        /// A single incoming edge.
        pub(super) const ONE: Self = IncomingEdgeCount(1);
    }

    impl std::ops::Add for IncomingEdgeCount {
        type Output = Self;

        fn add(self, rhs: Self) -> Self {
            let (IncomingEdgeCount(lhs_edges), IncomingEdgeCount(rhs_edges)) = (self, rhs);
            IncomingEdgeCount(lhs_edges + rhs_edges)
        }
    }

    impl std::ops::AddAssign for IncomingEdgeCount {
        fn add_assign(&mut self, rhs: Self) {
            self.0 += rhs.0;
        }
    }
}
use sealed::IncomingEdgeCount;
/// Mutable state (and visitor callbacks) for a depth-first traversal of a
/// [`ControlFlowGraph`] (see `ControlFlowGraph::traverse`).
struct TraversalState<PreVisit: FnMut(ControlRegion), PostVisit: FnMut(ControlRegion)> {
    /// How many times each region was reached so far (the edge that first
    /// reaches a region counts as one); also acts as the "visited" set.
    incoming_edge_counts: EntityOrientedDenseMap<ControlRegion, IncomingEdgeCount>,

    /// Called once per region, when it's first reached (before its targets).
    pre_order_visit: PreVisit,

    /// Called once per region, after all of its targets have been traversed.
    post_order_visit: PostVisit,

    /// Whether each region's targets are traversed in reverse order.
    reverse_targets: bool,
}
impl ControlFlowGraph {
    /// Depth-first traversal of every region reachable from `func_def_body`'s
    /// entry, with `self` required to be that function's own CFG.
    fn traverse_whole_func(
        &self,
        func_def_body: &FuncDefBody,
        state: &mut TraversalState<impl FnMut(ControlRegion), impl FnMut(ControlRegion)>,
    ) {
        let func_at_body = func_def_body.at_body();

        // Sanity checks: `self` must be this function's CFG, and the body is
        // expected to have no outputs yet (`structurize_func` only sets them
        // once the body is fully structured).
        assert!(std::ptr::eq(func_def_body.unstructured_cfg.as_ref().unwrap(), self));
        assert!(func_at_body.def().outputs.is_empty());

        // The function entry itself counts as the body's first incoming edge.
        self.traverse(func_def_body.body, state);
    }

    /// Records one incoming edge into `region` and, on the first visit only,
    /// recurses into its targets, invoking the pre/post-order callbacks.
    fn traverse(
        &self,
        region: ControlRegion,
        state: &mut TraversalState<impl FnMut(ControlRegion), impl FnMut(ControlRegion)>,
    ) {
        match state.incoming_edge_counts.get_mut(region) {
            // Seen before: only count the extra edge, don't visit again.
            Some(existing_count) => {
                *existing_count += IncomingEdgeCount::ONE;
                return;
            }
            // First time here: this edge is its first incoming edge.
            None => {
                state.incoming_edge_counts.insert(region, IncomingEdgeCount::ONE);
            }
        }

        (state.pre_order_visit)(region);

        let control_inst = self
            .control_inst_on_exit_from
            .get(region)
            .expect("cfg: missing `ControlInst`, despite having left structured control-flow");

        if state.reverse_targets {
            for &target in control_inst.targets.iter().rev() {
                self.traverse(target, state);
            }
        } else {
            for &target in control_inst.targets.iter() {
                self.traverse(target, state);
            }
        }

        (state.post_order_visit)(region);
    }
}
/// Finds loops in a [`ControlFlowGraph`] as strongly connected components
/// (SCCs), using a stack of pending nodes (resembling Tarjan's SCC algorithm),
/// and records the exit targets of each loop, keyed by its header.
struct LoopFinder<'a> {
    cfg: &'a ControlFlowGraph,

    /// For each loop header, the targets of edges leaving its loop (see
    /// `find_earliest_scc_root_of` for the exact criterion).
    loop_header_to_exit_targets: FxIndexMap<ControlRegion, FxIndexSet<ControlRegion>>,

    /// Nodes visited but not yet completed, in visitation order.
    scc_stack: Vec<ControlRegion>,

    /// Per-node state: absent (not visited), `Pending` (on `scc_stack`),
    /// or `Complete`.
    scc_state: EntityOrientedDenseMap<ControlRegion, SccState>,
}

/// Index into `LoopFinder::scc_stack` (lower means pushed earlier).
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SccStackIdx(u32);

#[derive(PartialEq, Eq)]
enum SccState {
    /// Visited, and still on `scc_stack` at this index.
    Pending(SccStackIdx),

    /// Done, with the exits found to be eventually reachable from this node.
    Complete(EventualCfgExits),
}
/// Which ways of leaving the function were found to be eventually reachable
/// from a CFG node (as accumulated by `LoopFinder`).
#[derive(Copy, Clone, Default, PartialEq, Eq)]
struct EventualCfgExits {
    /// Whether a `Return` was found reachable.
    may_return_from_func: bool,
}

/// Union of two sets of eventual exits.
impl std::ops::BitOr for EventualCfgExits {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self {
        let may_return_from_func = self.may_return_from_func || rhs.may_return_from_func;
        EventualCfgExits { may_return_from_func }
    }
}

/// In-place union of two sets of eventual exits.
impl std::ops::BitOrAssign for EventualCfgExits {
    fn bitor_assign(&mut self, rhs: Self) {
        self.may_return_from_func |= rhs.may_return_from_func;
    }
}
impl<'a> LoopFinder<'a> {
    fn new(cfg: &'a ControlFlowGraph) -> Self {
        Self {
            cfg,
            loop_header_to_exit_targets: FxIndexMap::default(),
            scc_stack: vec![],
            scc_state: EntityOrientedDenseMap::new(),
        }
    }

    /// Depth-first visit of `node` (and, recursively, its targets), returning:
    /// - the earliest (lowest) `scc_stack` index among still-pending nodes that
    ///   `node` leads back to, or `None` once `node` has been completed
    /// - the `EventualCfgExits` accumulated for `node`
    ///
    /// Whenever a node turns out to be the root of its own SCC (a loop header),
    /// the loop's exit targets are recorded in `loop_header_to_exit_targets`,
    /// and the rest of the loop is then re-visited, to discover nested loops.
    fn find_earliest_scc_root_of(
        &mut self,
        node: ControlRegion,
    ) -> (Option<SccStackIdx>, EventualCfgExits) {
        let state_entry = self.scc_state.entry(node);
        // Already visited: a pending node reports its own stack index (its
        // exits aren't known yet, so the empty/default set is reported), while
        // a complete node reports its cached exits.
        if let Some(state) = &state_entry {
            return match *state {
                SccState::Pending(scc_stack_idx) => {
                    (Some(scc_stack_idx), EventualCfgExits::default())
                }
                SccState::Complete(eventual_cfg_exits) => (None, eventual_cfg_exits),
            };
        }
        // First visit: push `node` onto the SCC stack and mark it pending.
        let scc_stack_idx = SccStackIdx(self.scc_stack.len().try_into().unwrap());
        self.scc_stack.push(node);
        *state_entry = Some(SccState::Pending(scc_stack_idx));
        let control_inst = self
            .cfg
            .control_inst_on_exit_from
            .get(node)
            .expect("cfg: missing `ControlInst`, despite having left structured control-flow");
        let mut eventual_cfg_exits = EventualCfgExits::default();
        if let ControlInstKind::Return = control_inst.kind {
            eventual_cfg_exits.may_return_from_func = true;
        }
        // Visit all targets (merging their exits into `eventual_cfg_exits`), and
        // take the minimum over all candidate SCC roots: each target's earliest
        // root, plus, for a target that is a loop merge whose loop header is
        // still pending, that header's stack index (i.e. reaching a loop's
        // merge is treated as if it also reached back to its header).
        let earliest_scc_root = control_inst
            .targets
            .iter()
            .flat_map(|&target| {
                let (earliest_scc_root_of_target, eventual_cfg_exits_of_target) =
                    self.find_earliest_scc_root_of(target);
                eventual_cfg_exits |= eventual_cfg_exits_of_target;
                let root_candidate_from_loop_merge =
                    self.cfg.loop_merge_to_loop_header.get(&target).and_then(|&loop_header| {
                        match self.scc_state.get(loop_header) {
                            Some(&SccState::Pending(scc_stack_idx)) => Some(scc_stack_idx),
                            _ => None,
                        }
                    });
                earliest_scc_root_of_target.into_iter().chain(root_candidate_from_loop_merge)
            })
            .min();
        // `node` is the root of its SCC (a loop header): the loop consists of
        // everything on the stack from `scc_start` onwards.
        if earliest_scc_root == Some(scc_stack_idx) {
            let scc_start = scc_stack_idx.0 as usize;
            // A target of a loop node counts as a loop exit only if it's already
            // complete (pending targets must be inside this SCC, as asserted),
            // and its `may_return_from_func` equals the one accumulated so far
            // for the loop header (in `eventual_cfg_exits`).
            let target_is_exit = |target| {
                match self.scc_state[target] {
                    SccState::Pending(i) => {
                        assert!(i >= scc_stack_idx);
                        false
                    }
                    SccState::Complete(eventual_cfg_exits_of_target) => {
                        let EventualCfgExits { may_return_from_func: loop_may_reconverge } =
                            eventual_cfg_exits;
                        let EventualCfgExits { may_return_from_func: target_may_reconverge } =
                            eventual_cfg_exits_of_target;
                        target_may_reconverge == loop_may_reconverge
                    }
                }
            };
            self.loop_header_to_exit_targets.insert(
                node,
                self.scc_stack[scc_start..]
                    .iter()
                    .flat_map(|&scc_node| {
                        self.cfg.control_inst_on_exit_from[scc_node].targets.iter().copied()
                    })
                    .filter(|&target| target_is_exit(target))
                    .collect(),
            );
            // Find nested loops: mark *only* the header as complete, forget the
            // state of the rest of the loop, then re-visit those nodes. With the
            // header (and everything outside the loop) already complete, the
            // re-visit can only discover SCCs nested inside the loop body.
            self.scc_state[node] = SccState::Complete(eventual_cfg_exits);
            let loop_body_range = scc_start + 1..self.scc_stack.len();
            for &scc_node in &self.scc_stack[loop_body_range.clone()] {
                self.scc_state.remove(scc_node);
            }
            for i in loop_body_range.clone() {
                self.find_earliest_scc_root_of(self.scc_stack[i]);
            }
            // The re-visit must not leave any extra nodes on the stack.
            assert_eq!(self.scc_stack.len(), loop_body_range.end);
            // Pop the whole loop (header included) off the stack at once.
            self.scc_stack.truncate(scc_start);
            return (None, eventual_cfg_exits);
        }
        // Not an SCC root. If `node` doesn't lead back to any pending node, it
        // must be on top of the stack, and is now complete; otherwise it stays
        // pending, to be removed together with its SCC's root.
        if earliest_scc_root.is_none() {
            assert!(self.scc_stack.pop() == Some(node));
            self.scc_state[node] = SccState::Complete(eventual_cfg_exits);
        }
        (earliest_scc_root, eventual_cfg_exits)
    }
}
/// Structurizes a single function body in place, replacing as much as
/// possible of its unstructured [`ControlFlowGraph`] with structured
/// `Select`/`Loop` [`ControlNode`]s. If the whole body can't be structurized,
/// a CFG is rebuilt from the regions that were left unclaimed.
///
/// Loops are found ahead of time (via [`LoopFinder`]), and a region is only
/// "claimed" (absorbed into the structured control-flow of its claimant) once
/// all of its incoming edges (see [`IncomingEdgeCount`]) are accounted for.
// NOTE(review): the `allow` below is needed because the docs above link to
// private items (`LoopFinder`, `IncomingEdgeCount`).
#[allow(rustdoc::private_intra_doc_links)]
pub struct Structurizer<'a> {
    cx: &'a Context,

    /// `OpTypeBool`, and its `OpConstantTrue`/`OpConstantFalse` constants,
    /// interned once in `new`.
    type_bool: Type,
    const_true: Const,
    const_false: Const,

    /// The function body being structurized (in place).
    func_def_body: &'a mut FuncDefBody,

    /// Loop exit targets, per loop header, as found by `LoopFinder`.
    loop_header_to_exit_targets: FxIndexMap<ControlRegion, FxIndexSet<ControlRegion>>,

    /// Incoming edge count of every reachable region, plus one extra "false
    /// edge" per loop exit (added in `new`, credited back to the exit when its
    /// loop gets formed in `try_claim_edge_bundle`).
    incoming_edge_counts_including_loop_exits:
        EntityOrientedDenseMap<ControlRegion, IncomingEdgeCount>,

    /// Structurization progress of each region visited so far.
    structurize_region_state: FxIndexMap<ControlRegion, StructurizeRegionState>,

    /// Rewrites of `Value::ControlRegionInput` uses, applied to some values
    /// during structurization, and to the whole function body at the end of
    /// `structurize_func`.
    control_region_input_rewrites:
        EntityOrientedDenseMap<ControlRegion, ControlRegionInputRewrites>,
}
/// How uses of a region's inputs (`Value::ControlRegionInput`) must be rewritten,
/// indexed by the original input index.
enum ControlRegionInputRewrites {
    /// Input `i` is replaced by the `i`th value.
    ReplaceWith(SmallVec<[Value; 2]>),

    /// Input `i` is either renumbered (`Ok(new_idx)`: same region, input
    /// `new_idx`), or replaced (`Err(v)`: by `v`).
    RenumberOrReplaceWith(SmallVec<[Result<u32, Value>; 2]>),
}
impl ControlRegionInputRewrites {
    /// Builds a `Transformer` that rewrites every `Value::ControlRegionInput`
    /// use according to `rewrites`.
    ///
    /// Replacements are followed repeatedly (a replacement may itself be the
    /// input of a region with rewrites), stopping right after a renumbering,
    /// or at the first value without an applicable rewrite.
    fn rewrite_all(
        rewrites: &EntityOrientedDenseMap<ControlRegion, Self>,
    ) -> impl crate::transform::Transformer + '_ {
        use crate::transform::*;

        struct ReplaceValueWith<F>(F);
        impl<F: Fn(Value) -> Option<Value>> Transformer for ReplaceValueWith<F> {
            fn transform_value_use(&mut self, v: &Value) -> Transformed<Value> {
                match (self.0)(*v) {
                    Some(replacement) => Transformed::Changed(replacement),
                    None => Transformed::Unchanged,
                }
            }
        }

        ReplaceValueWith(move |original: Value| {
            let mut current = original;
            loop {
                let Value::ControlRegionInput { region, input_idx } = current else {
                    break;
                };
                let idx = input_idx as usize;
                match rewrites.get(region) {
                    None => break,
                    Some(ControlRegionInputRewrites::ReplaceWith(replacements)) => {
                        current = replacements[idx];
                    }
                    Some(ControlRegionInputRewrites::RenumberOrReplaceWith(
                        renumbering_and_replacements,
                    )) => match renumbering_and_replacements[idx] {
                        Err(replacement) => current = replacement,
                        Ok(new_idx) => {
                            current = Value::ControlRegionInput { region, input_idx: new_idx };
                            break;
                        }
                    },
                }
            }
            if current != original { Some(current) } else { None }
        })
    }
}
/// Structurization progress of a single region.
enum StructurizeRegionState {
    /// `structurize_region` started on this region but hasn't finished yet
    /// (reaching it again means a cycle is being followed).
    InProgress,

    /// Structurized, but not (yet) claimed.
    Ready {
        /// Total `accumulated_count` of this region's deferred edges back to
        /// itself (when non-zero, claiming it splits out that backedge to form
        /// a loop).
        accumulated_backedge_count: IncomingEdgeCount,

        /// The ways of leaving this region, still to be handled by its claimant.
        region_deferred_edges: DeferredEdgeBundleSet,
    },

    /// Claimed by `try_claim_edge_bundle`, which took its deferred edges.
    Claimed,
}

/// One or more edges (merged) to the same `target`, with the values passed to
/// the target's inputs.
struct IncomingEdgeBundle<T> {
    target: T,

    /// How many CFG edges this bundle stands for.
    accumulated_count: IncomingEdgeCount,

    /// Values for the target's inputs.
    target_inputs: SmallVec<[Value; 2]>,
}
impl<T> IncomingEdgeBundle<T> {
    /// Re-targets this bundle (possibly changing the target's type), keeping
    /// its edge count and target inputs as-is.
    fn with_target<U>(self, target: U) -> IncomingEdgeBundle<U> {
        IncomingEdgeBundle {
            target,
            accumulated_count: self.accumulated_count,
            target_inputs: self.target_inputs,
        }
    }
}
/// An edge bundle which is only taken when `condition` holds.
struct DeferredEdgeBundle<T = DeferredTarget> {
    condition: LazyCond,
    edge_bundle: IncomingEdgeBundle<T>,
}
impl<T> DeferredEdgeBundle<T> {
    /// Re-targets the inner edge bundle, keeping the condition unchanged.
    fn with_target<U>(self, target: U) -> DeferredEdgeBundle<U> {
        let condition = self.condition;
        let edge_bundle = self.edge_bundle.with_target(target);
        DeferredEdgeBundle { condition, edge_bundle }
    }
}
/// A boolean condition, not necessarily materialized as a `Value` yet.
#[derive(Clone)]
enum LazyCond {
    /// Any value is acceptable (used for `Select` cases whose deferred edges
    /// are `Unreachable`, i.e. which never continue past the `Select`).
    Undef,
    False,
    True,

    /// A condition that depends on which case of a `Select` was taken.
    Merge(Rc<LazyCondMerge>),
}

/// A condition chosen per-case of a `Select`.
enum LazyCondMerge {
    Select {
        /// Presumably the `Select` node whose taken case picks the condition
        /// - TODO confirm against `materialize_lazy_cond`.
        control_node: ControlNode,

        /// One condition per case of `control_node`.
        per_case_conds: SmallVec<[LazyCond; 4]>,
    },
}

/// Where a deferred edge leads: to a region, or out of the function.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
enum DeferredTarget {
    Region(ControlRegion),

    /// Return from the function (the edge's `target_inputs` are the returned
    /// values).
    Return,
}

/// All the ways a structurized region may be left, i.e. its deferred edges.
enum DeferredEdgeBundleSet {
    /// Never left normally (no deferred edges).
    Unreachable,

    /// Always left through a single edge bundle (implicitly `LazyCond::True`).
    Always {
        target: DeferredTarget,
        edge_bundle: IncomingEdgeBundle<()>,
    },

    /// Left through one of several edge bundles, each with its own condition
    /// (the `From`/`FromIterator` conversions only produce this for 2+ targets).
    Choice {
        target_to_deferred: FxIndexMap<DeferredTarget, DeferredEdgeBundle<()>>,
    },
}
/// Normalizing constructor: no bundles -> `Unreachable`, exactly one ->
/// `Always` (its condition is dropped), two or more -> `Choice`.
impl FromIterator<DeferredEdgeBundle> for DeferredEdgeBundleSet {
    fn from_iter<T: IntoIterator<Item = DeferredEdgeBundle>>(iter: T) -> Self {
        let mut deferreds = iter.into_iter();

        let Some(first) = deferreds.next() else {
            return Self::Unreachable;
        };

        let Some(second) = deferreds.next() else {
            let target = first.edge_bundle.target;
            return Self::Always { target, edge_bundle: first.edge_bundle.with_target(()) };
        };

        let target_to_deferred = [first, second]
            .into_iter()
            .chain(deferreds)
            .map(|deferred| (deferred.edge_bundle.target, deferred.with_target(())))
            .collect();
        Self::Choice { target_to_deferred }
    }
}
/// Wraps a target map as `Choice`, unless it has fewer than two entries, in
/// which case it's normalized (see the `FromIterator` impl).
impl From<FxIndexMap<DeferredTarget, DeferredEdgeBundle<()>>> for DeferredEdgeBundleSet {
    fn from(target_to_deferred: FxIndexMap<DeferredTarget, DeferredEdgeBundle<()>>) -> Self {
        match target_to_deferred.len() {
            0 | 1 => target_to_deferred
                .into_iter()
                .map(|(target, deferred)| deferred.with_target(target))
                .collect(),
            _ => Self::Choice { target_to_deferred },
        }
    }
}
impl DeferredEdgeBundleSet {
    /// Returns the edge bundle going to `search_target`, if there is one.
    fn get_edge_bundle_by_target(
        &self,
        search_target: DeferredTarget,
    ) -> Option<&IncomingEdgeBundle<()>> {
        match self {
            DeferredEdgeBundleSet::Unreachable => None,
            DeferredEdgeBundleSet::Always { target, edge_bundle } => {
                (*target == search_target).then_some(edge_bundle)
            }
            DeferredEdgeBundleSet::Choice { target_to_deferred } => {
                Some(&target_to_deferred.get(&search_target)?.edge_bundle)
            }
        }
    }

    /// Mutable version of `get_edge_bundle_by_target`.
    fn get_edge_bundle_mut_by_target(
        &mut self,
        search_target: DeferredTarget,
    ) -> Option<&mut IncomingEdgeBundle<()>> {
        match self {
            DeferredEdgeBundleSet::Unreachable => None,
            DeferredEdgeBundleSet::Always { target, edge_bundle } => {
                (*target == search_target).then_some(edge_bundle)
            }
            DeferredEdgeBundleSet::Choice { target_to_deferred } => {
                Some(&mut target_to_deferred.get_mut(&search_target)?.edge_bundle)
            }
        }
    }

    /// Iterates over every (target, edge bundle) pair (for `Choice`, in the
    /// map's iteration order).
    fn iter_targets_with_edge_bundle(
        &self,
    ) -> impl Iterator<Item = (DeferredTarget, &IncomingEdgeBundle<()>)> {
        match self {
            DeferredEdgeBundleSet::Unreachable => Either::Left(None.into_iter()),
            DeferredEdgeBundleSet::Always { target, edge_bundle } => {
                Either::Left(Some((*target, edge_bundle)).into_iter())
            }
            DeferredEdgeBundleSet::Choice { target_to_deferred } => Either::Right(
                target_to_deferred
                    .iter()
                    .map(|(&target, deferred)| (target, &deferred.edge_bundle)),
            ),
        }
    }

    /// Mutable version of `iter_targets_with_edge_bundle`.
    fn iter_targets_with_edge_bundle_mut(
        &mut self,
    ) -> impl Iterator<Item = (DeferredTarget, &mut IncomingEdgeBundle<()>)> {
        match self {
            DeferredEdgeBundleSet::Unreachable => Either::Left(None.into_iter()),
            DeferredEdgeBundleSet::Always { target, edge_bundle } => {
                Either::Left(Some((*target, edge_bundle)).into_iter())
            }
            DeferredEdgeBundleSet::Choice { target_to_deferred } => Either::Right(
                target_to_deferred
                    .iter_mut()
                    .map(|(&target, deferred)| (target, &mut deferred.edge_bundle)),
            ),
        }
    }

    /// Moves the condition and `target_inputs` for `search_target` (if present)
    /// out into a new `DeferredEdgeBundle` (whose `accumulated_count` is only a
    /// copy). The entry itself is left in place, with empty `target_inputs`
    /// and, for `Choice`, its condition replaced by `LazyCond::False`
    /// (`Always` is reported with a `LazyCond::True` condition).
    fn steal_deferred_by_target_without_removal(
        &mut self,
        search_target: DeferredTarget,
    ) -> Option<DeferredEdgeBundle<()>> {
        let steal_edge_bundle = |edge_bundle: &mut IncomingEdgeBundle<()>| IncomingEdgeBundle {
            target: (),
            accumulated_count: edge_bundle.accumulated_count,
            target_inputs: mem::take(&mut edge_bundle.target_inputs),
        };
        match self {
            DeferredEdgeBundleSet::Unreachable => None,
            DeferredEdgeBundleSet::Always { target, edge_bundle } => (*target == search_target)
                .then(|| DeferredEdgeBundle {
                    condition: LazyCond::True,
                    edge_bundle: steal_edge_bundle(edge_bundle),
                }),
            DeferredEdgeBundleSet::Choice { target_to_deferred } => {
                let DeferredEdgeBundle { condition, edge_bundle } =
                    target_to_deferred.get_mut(&search_target)?;
                Some(DeferredEdgeBundle {
                    condition: mem::replace(condition, LazyCond::False),
                    edge_bundle: steal_edge_bundle(edge_bundle),
                })
            }
        }
    }

    /// Removes the deferred edge bundle for `split_target` (if present),
    /// returning it along with the rest of the set. For `Choice`, the rest is
    /// re-normalized via `From` (so fewer than two remaining targets become
    /// `Always`/`Unreachable`); note that `swap_remove` doesn't preserve the
    /// order of the remaining entries.
    fn split_out_target(self, split_target: DeferredTarget) -> (Option<DeferredEdgeBundle>, Self) {
        match self {
            DeferredEdgeBundleSet::Unreachable => (None, DeferredEdgeBundleSet::Unreachable),
            DeferredEdgeBundleSet::Always { target, edge_bundle } => {
                if target == split_target {
                    (
                        Some(DeferredEdgeBundle {
                            condition: LazyCond::True,
                            edge_bundle: edge_bundle.with_target(target),
                        }),
                        DeferredEdgeBundleSet::Unreachable,
                    )
                } else {
                    (None, DeferredEdgeBundleSet::Always { target, edge_bundle })
                }
            }
            DeferredEdgeBundleSet::Choice { mut target_to_deferred } => {
                (
                    target_to_deferred
                        .swap_remove(&split_target)
                        .map(|d| d.with_target(split_target)),
                    Self::from(target_to_deferred),
                )
            }
        }
    }

    /// Offers each deferred edge bundle, in order, to `matches`, stopping at
    /// the first `Ok` result, which is returned (and its entry removed, keeping
    /// the order of the rest). Bundles rejected by `matches` (via `Err`) are
    /// kept as returned, which must keep the same target (and, for `Always`,
    /// a `LazyCond::True` condition), as asserted.
    fn split_out_matching<T>(
        self,
        mut matches: impl FnMut(DeferredEdgeBundle) -> Result<T, DeferredEdgeBundle>,
    ) -> (Option<T>, Self) {
        match self {
            DeferredEdgeBundleSet::Unreachable => (None, DeferredEdgeBundleSet::Unreachable),
            DeferredEdgeBundleSet::Always { target, edge_bundle } => {
                match matches(DeferredEdgeBundle {
                    condition: LazyCond::True,
                    edge_bundle: edge_bundle.with_target(target),
                }) {
                    Ok(x) => (Some(x), DeferredEdgeBundleSet::Unreachable),
                    Err(new_deferred) => {
                        assert!(new_deferred.edge_bundle.target == target);
                        assert!(matches!(new_deferred.condition, LazyCond::True));
                        (None, DeferredEdgeBundleSet::Always {
                            target,
                            edge_bundle: new_deferred.edge_bundle.with_target(()),
                        })
                    }
                }
            }
            DeferredEdgeBundleSet::Choice { mut target_to_deferred } => {
                let mut result = None;
                for (i, (&target, deferred)) in target_to_deferred.iter_mut().enumerate() {
                    // Temporarily take `deferred` out (leaving a dummy behind),
                    // so it can be passed by value to `matches`, and put back
                    // on `Err`.
                    let taken_deferred = mem::replace(deferred, DeferredEdgeBundle {
                        condition: LazyCond::False,
                        edge_bundle: IncomingEdgeBundle {
                            target: Default::default(),
                            accumulated_count: Default::default(),
                            target_inputs: Default::default(),
                        },
                    });
                    match matches(taken_deferred.with_target(target)) {
                        Ok(x) => {
                            result = Some(x);
                            // Only the dummy is removed here, and the `iter_mut`
                            // borrow isn't used again (we `break` right away).
                            target_to_deferred.shift_remove_index(i).unwrap();
                            break;
                        }
                        Err(new_deferred) => {
                            assert!(new_deferred.edge_bundle.target == target);
                            *deferred = new_deferred.with_target(());
                        }
                    }
                }
                (result, Self::from(target_to_deferred))
            }
        }
    }
}
/// Result of successfully claiming a region (see `try_claim_edge_bundle`).
struct ClaimedRegion {
    /// Either the claimed region itself, or (if it had backedges) a new
    /// wrapper region containing a `Loop` whose body is the claimed region.
    structured_body: ControlRegion,

    /// The claiming edge bundle's `target_inputs`, i.e. the values for
    /// `structured_body`'s inputs.
    structured_body_inputs: SmallVec<[Value; 2]>,

    /// The remaining ways of leaving `structured_body`.
    deferred_edges: DeferredEdgeBundleSet,
}
impl<'a> Structurizer<'a> {
/// Prepares the structurization of `func_def_body`.
///
/// Interns the `OpTypeBool` type and its `true`/`false` constants. If the
/// function has an unstructured CFG, also finds its loops (and their exits)
/// and counts the incoming edges of every region reachable from the entry.
pub fn new(cx: &'a Context, func_def_body: &'a mut FuncDefBody) -> Self {
    let wk = &spv::spec::Spec::get().well_known;

    let type_bool = cx.intern(TypeKind::SpvInst {
        spv_inst: wk.OpTypeBool.into(),
        type_and_const_inputs: [].into_iter().collect(),
    });

    // Both boolean constants only differ in their SPIR-V instruction.
    let intern_bool_const = |spv_inst| {
        cx.intern(ConstDef {
            attrs: AttrSet::default(),
            ty: type_bool,
            kind: ConstKind::SpvInst {
                spv_inst_and_const_inputs: Rc::new((spv_inst, [].into_iter().collect())),
            },
        })
    };
    let const_true = intern_bool_const(wk.OpConstantTrue.into());
    let const_false = intern_bool_const(wk.OpConstantFalse.into());

    let (loop_header_to_exit_targets, incoming_edge_counts_including_loop_exits) =
        match &func_def_body.unstructured_cfg {
            // No unstructured CFG, nothing to analyze.
            None => Default::default(),

            Some(cfg) => {
                let mut loop_finder = LoopFinder::new(cfg);
                loop_finder.find_earliest_scc_root_of(func_def_body.body);
                let loop_header_to_exit_targets = loop_finder.loop_header_to_exit_targets;

                let mut traversal = TraversalState {
                    incoming_edge_counts: EntityOrientedDenseMap::new(),
                    pre_order_visit: |_| {},
                    post_order_visit: |_| {},
                    reverse_targets: false,
                };
                cfg.traverse_whole_func(func_def_body, &mut traversal);
                let mut incoming_edge_counts = traversal.incoming_edge_counts;

                // HACK: give every loop exit one extra ("false") incoming edge,
                // only credited back (in `try_claim_edge_bundle`) once its loop
                // has been formed, so a loop exit can't be claimed before that.
                for &exit_target in loop_header_to_exit_targets.values().flatten() {
                    let exit_count = incoming_edge_counts
                        .entry(exit_target)
                        .get_or_insert(Default::default());
                    *exit_count += IncomingEdgeCount::ONE;
                }

                (loop_header_to_exit_targets, incoming_edge_counts)
            }
        };

    Self {
        cx,
        type_bool,
        const_true,
        const_false,
        func_def_body,
        loop_header_to_exit_targets,
        incoming_edge_counts_including_loop_exits,
        structurize_region_state: FxIndexMap::default(),
        control_region_input_rewrites: EntityOrientedDenseMap::new(),
    }
}
/// Structurizes the whole function body, if it has an unstructured CFG.
///
/// Does nothing if the body region has incoming edges besides the function
/// entry. Otherwise, either the body becomes fully structured (the CFG is
/// removed, and the returned values become the body's outputs), or a CFG is
/// rebuilt from whatever couldn't be claimed. Finally, all recorded region
/// input rewrites are applied to the function body.
pub fn structurize_func(mut self) {
    // Only functions with an unstructured CFG need structurization.
    if self.func_def_body.unstructured_cfg.is_none() {
        return;
    }

    // The function entry, modeled as a single input-less edge into the body.
    let entry_region = self.func_def_body.body;
    let func_entry_pseudo_edge = move || IncomingEdgeBundle {
        target: entry_region,
        accumulated_count: IncomingEdgeCount::ONE,
        target_inputs: [].into_iter().collect(),
    };

    // HACK: give up (without changing anything) unless the entry is the only
    // way into the body, i.e. nothing loops back to it.
    let body_incoming_edges = self.incoming_edge_counts_including_loop_exits[entry_region];
    if body_incoming_edges != func_entry_pseudo_edge().accumulated_count {
        return;
    }

    let claimed_body = self.try_claim_edge_bundle(func_entry_pseudo_edge()).ok().unwrap();
    let ClaimedRegion { structured_body, structured_body_inputs, deferred_edges } = claimed_body;
    assert!(structured_body == func_entry_pseudo_edge().target);
    assert!(structured_body_inputs == func_entry_pseudo_edge().target_inputs);
    let func_body_deferred_edges = deferred_edges;

    match func_body_deferred_edges {
        // Fully structurized: whatever is returned becomes the body's outputs.
        DeferredEdgeBundleSet::Always { target: DeferredTarget::Return, edge_bundle } => {
            self.func_def_body.at_mut_body().def().outputs = edge_bundle.target_inputs;
            self.func_def_body.unstructured_cfg = None;
        }

        // The body never exits normally: replace the CFG with one that only
        // has an `Unreachable` exit from the body.
        DeferredEdgeBundleSet::Unreachable => {
            let unreachable_exit = ControlInst {
                attrs: AttrSet::default(),
                kind: ControlInstKind::Unreachable,
                inputs: [].into_iter().collect(),
                targets: [].into_iter().collect(),
                target_inputs: FxIndexMap::default(),
            };
            let mut control_inst_on_exit_from = EntityOrientedDenseMap::new();
            control_inst_on_exit_from.insert(self.func_def_body.body, unreachable_exit);
            self.func_def_body.unstructured_cfg = Some(ControlFlowGraph {
                control_inst_on_exit_from,
                loop_merge_to_loop_header: Default::default(),
            });
        }

        // Partially structurized: rebuild the CFG from every region left
        // unclaimed, including the function body itself (last).
        remaining_deferred_edges => {
            let body_state = StructurizeRegionState::Ready {
                accumulated_backedge_count: IncomingEdgeCount::default(),
                region_deferred_edges: remaining_deferred_edges,
            };
            let all_region_states = mem::take(&mut self.structurize_region_state)
                .into_iter()
                .chain([(self.func_def_body.body, body_state)]);
            for (region, state) in all_region_states {
                let StructurizeRegionState::Ready { region_deferred_edges, .. } = state else {
                    continue;
                };
                self.rebuild_cfg_from_unclaimed_region_deferred_edges(
                    region,
                    region_deferred_edges,
                );
            }
        }
    }

    self.func_def_body.inner_in_place_transform_with(
        &mut ControlRegionInputRewrites::rewrite_all(&self.control_region_input_rewrites),
    );
}
/// Attempts to claim `edge_bundle.target`, i.e. to have `edge_bundle` be the
/// only way into it (besides its own backedges). The target is structurized
/// first, if that hasn't happened yet.
///
/// Claiming only succeeds when `edge_bundle.accumulated_count` plus the
/// target's own backedge count equals its total incoming edge count (which
/// includes the extra loop-exit "false edges"). Otherwise `edge_bundle` is
/// handed back, unchanged, as `Err`. On success, a target with a backedge is
/// turned into the body of a `Loop` (placed in a new wrapper region).
///
/// Panics if the target was already claimed.
fn try_claim_edge_bundle(
    &mut self,
    edge_bundle: IncomingEdgeBundle<ControlRegion>,
) -> Result<ClaimedRegion, IncomingEdgeBundle<ControlRegion>> {
    let target = edge_bundle.target;
    // Structurize first: the backedge count is only known once `Ready`.
    if self.structurize_region_state.get(&target).is_none() {
        self.structurize_region(target);
    }
    let backedge_count = match self.structurize_region_state[&target] {
        // Still being structurized (reached through a cycle): no backedges are
        // known yet, and the count check below is expected to fail (see the
        // `unreachable!` on `InProgress` after it).
        StructurizeRegionState::InProgress => IncomingEdgeCount::default(),
        StructurizeRegionState::Ready { accumulated_backedge_count, .. } => {
            accumulated_backedge_count
        }
        StructurizeRegionState::Claimed => {
            unreachable!("cfg::Structurizer::try_claim_edge_bundle: already claimed");
        }
    };
    // Some incoming edges aren't part of this bundle (nor backedges): the
    // target can't be claimed (yet).
    if self.incoming_edge_counts_including_loop_exits[target]
        != edge_bundle.accumulated_count + backedge_count
    {
        return Err(edge_bundle);
    }
    let state =
        self.structurize_region_state.insert(target, StructurizeRegionState::Claimed).unwrap();
    let mut deferred_edges = match state {
        StructurizeRegionState::InProgress => unreachable!(
            "cfg::Structurizer::try_claim_edge_bundle: cyclic calls \
             should not get this far"
        ),
        StructurizeRegionState::Ready { region_deferred_edges, .. } => region_deferred_edges,
        StructurizeRegionState::Claimed => {
            unreachable!()
        }
    };
    // Separate the deferred edge back to `target` itself, if any.
    let mut backedge = None;
    if backedge_count != IncomingEdgeCount::default() {
        (backedge, deferred_edges) =
            deferred_edges.split_out_target(DeferredTarget::Region(target));
    }
    // With a backedge, `target` becomes the body of a `Loop`, entered once via
    // `edge_bundle`, repeated via the backedge (under its condition), and left
    // through the remaining `deferred_edges`.
    let structured_body = if let Some(backedge) = backedge {
        let DeferredEdgeBundle { condition: repeat_condition, edge_bundle: backedge } =
            backedge;
        let body = target;
        // The `Loop` node needs a region to live in, which takes over the
        // body's original input declarations.
        let wrapper_region =
            self.func_def_body.control_regions.define(self.cx, ControlRegionDef::default());
        let body_def = self.func_def_body.at_mut(body).def();
        let original_input_decls = mem::take(&mut body_def.inputs);
        assert!(body_def.outputs.is_empty());
        // For each original body input (after applying the rewrites recorded so
        // far to its backedge value):
        // - if the backedge passes it back unchanged, it's loop-invariant, and
        //   gets replaced with the wrapper's corresponding input (`Err`)
        // - otherwise, it stays a body input (renumbered, `Ok`), initialized
        //   from the wrapper's input, with the backedge value as body output
        let mut initial_inputs = SmallVec::<[_; 2]>::new();
        let body_input_rewrites = ControlRegionInputRewrites::RenumberOrReplaceWith(
            backedge
                .target_inputs
                .into_iter()
                .enumerate()
                .map(|(original_idx, mut backedge_value)| {
                    ControlRegionInputRewrites::rewrite_all(
                        &self.control_region_input_rewrites,
                    )
                    .transform_value_use(&backedge_value)
                    .apply_to(&mut backedge_value);
                    let original_idx = u32::try_from(original_idx).unwrap();
                    if backedge_value
                        == (Value::ControlRegionInput { region: body, input_idx: original_idx })
                    {
                        Err(Value::ControlRegionInput {
                            region: wrapper_region,
                            input_idx: original_idx,
                        })
                    } else {
                        let renumbered_idx = u32::try_from(body_def.inputs.len()).unwrap();
                        initial_inputs.push(Value::ControlRegionInput {
                            region: wrapper_region,
                            input_idx: original_idx,
                        });
                        body_def.inputs.push(original_input_decls[original_idx as usize]);
                        body_def.outputs.push(backedge_value);
                        Ok(renumbered_idx)
                    }
                })
                .collect(),
        );
        self.control_region_input_rewrites.insert(body, body_input_rewrites);
        assert_eq!(initial_inputs.len(), body_def.inputs.len());
        assert_eq!(body_def.outputs.len(), body_def.inputs.len());
        let repeat_condition = self.materialize_lazy_cond(&repeat_condition);
        let loop_node = self.func_def_body.control_nodes.define(
            self.cx,
            ControlNodeDef {
                kind: ControlNodeKind::Loop { initial_inputs, body, repeat_condition },
                outputs: [].into_iter().collect(),
            }
            .into(),
        );
        let wrapper_region_def = &mut self.func_def_body.control_regions[wrapper_region];
        wrapper_region_def.inputs = original_input_decls;
        wrapper_region_def
            .children
            .insert_last(loop_node, &mut self.func_def_body.control_nodes);
        // Loop exits were counted with an extra "false edge" (see `new`): now
        // that the loop exists, credit that edge to each exit that is still a
        // deferred edge, potentially letting the caller claim it. Exits not
        // found in `deferred_edges` are skipped.
        if let Some(exit_targets) = self.loop_header_to_exit_targets.get(&target) {
            for &exit_target in exit_targets {
                if let Some(exit_edge_bundle) = deferred_edges
                    .get_edge_bundle_mut_by_target(DeferredTarget::Region(exit_target))
                {
                    exit_edge_bundle.accumulated_count += IncomingEdgeCount::ONE;
                }
            }
        }
        wrapper_region
    } else {
        target
    };
    Ok(ClaimedRegion {
        structured_body,
        structured_body_inputs: edge_bundle.target_inputs,
        deferred_edges,
    })
}
/// Structurizes `region` in place. This consumes its `ControlInst` (removing it
/// from the unstructured CFG) and tries to claim each of its targets. The
/// resulting structured control-flow is added to `region` (via
/// `append_maybe_claimed_region`/`structurize_select_into`), and any deferred
/// edges that become claimable are absorbed repeatedly. The deferred edges
/// that remain are stored as `StructurizeRegionState::Ready`, for whoever
/// claims `region` later.
///
/// Panics if `region` is already in progress, completed or claimed, or if it
/// has no `ControlInst`.
fn structurize_region(&mut self, region: ControlRegion) {
    // Mark `region` as in progress, refusing to restart any earlier state.
    {
        let old_state =
            self.structurize_region_state.insert(region, StructurizeRegionState::InProgress);
        if let Some(old_state) = old_state {
            unreachable!(
                "cfg::Structurizer::structurize_region: \
                 already {}, when attempting to start structurization",
                match old_state {
                    StructurizeRegionState::InProgress => "in progress (cycle detected)",
                    StructurizeRegionState::Ready { .. } => "completed",
                    StructurizeRegionState::Claimed => "claimed",
                }
            );
        }
    }
    // Take the `ControlInst` out of the CFG: from now on, the ways of leaving
    // `region` are described by its deferred edges instead.
    let control_inst_on_exit = self
        .func_def_body
        .unstructured_cfg
        .as_mut()
        .unwrap()
        .control_inst_on_exit_from
        .remove(region)
        .expect(
            "cfg::Structurizer::structurize_region: missing \
             `ControlInst` (CFG wasn't unstructured in the first place?)",
        );
    let mut deferred_edges = {
        let ControlInst { attrs, kind, inputs, targets, target_inputs } = control_inst_on_exit;
        // NOTE(review): the `ControlInst`'s `attrs` are discarded here.
        let _ = attrs;
        // Try to claim every target (as a single edge, with its inputs). An
        // unclaimable target becomes an `Always` deferred edge to it, except
        // for a target already structurized as an empty region that exits
        // via `Unreachable`, which is turned into `Unreachable` directly.
        let target_regions: SmallVec<[_; 8]> = targets
            .iter()
            .map(|&target| {
                self.try_claim_edge_bundle(IncomingEdgeBundle {
                    target,
                    accumulated_count: IncomingEdgeCount::ONE,
                    target_inputs: target_inputs.get(&target).cloned().unwrap_or_default(),
                })
                .map_err(|edge_bundle| {
                    let target_is_trivial_unreachable =
                        match self.structurize_region_state.get(&edge_bundle.target) {
                            Some(StructurizeRegionState::Ready {
                                region_deferred_edges: DeferredEdgeBundleSet::Unreachable,
                                ..
                            }) => {
                                self.func_def_body
                                    .at(edge_bundle.target)
                                    .at_children()
                                    .into_iter()
                                    .next()
                                    .is_none()
                            }
                            _ => false,
                        };
                    if target_is_trivial_unreachable {
                        DeferredEdgeBundleSet::Unreachable
                    } else {
                        DeferredEdgeBundleSet::Always {
                            target: DeferredTarget::Region(edge_bundle.target),
                            edge_bundle: edge_bundle.with_target(()),
                        }
                    }
                })
            })
            .collect();
        match kind {
            ControlInstKind::Unreachable => {
                assert_eq!((inputs.len(), target_regions.len()), (0, 0));
                DeferredEdgeBundleSet::Unreachable
            }
            // Becomes an `ExitInvocation` node at the end of `region`, which
            // is then never left normally.
            ControlInstKind::ExitInvocation(kind) => {
                assert_eq!(target_regions.len(), 0);
                let control_node = self.func_def_body.control_nodes.define(
                    self.cx,
                    ControlNodeDef {
                        kind: ControlNodeKind::ExitInvocation { kind, inputs },
                        outputs: [].into_iter().collect(),
                    }
                    .into(),
                );
                self.func_def_body.control_regions[region]
                    .children
                    .insert_last(control_node, &mut self.func_def_body.control_nodes);
                DeferredEdgeBundleSet::Unreachable
            }
            // Always leaves via `Return`, with the returned values as inputs.
            ControlInstKind::Return => {
                assert_eq!(target_regions.len(), 0);
                DeferredEdgeBundleSet::Always {
                    target: DeferredTarget::Return,
                    edge_bundle: IncomingEdgeBundle {
                        accumulated_count: IncomingEdgeCount::default(),
                        target: (),
                        target_inputs: inputs,
                    },
                }
            }
            ControlInstKind::Branch => {
                assert_eq!((inputs.len(), target_regions.len()), (0, 1));
                self.append_maybe_claimed_region(
                    region,
                    target_regions.into_iter().next().unwrap(),
                )
            }
            ControlInstKind::SelectBranch(kind) => {
                assert_eq!(inputs.len(), 1);
                let scrutinee = inputs[0];
                self.structurize_select_into(region, kind, Ok(scrutinee), target_regions)
            }
        }
    };
    // Keep claiming deferred edges until none can be claimed: each claimed one
    // becomes the "then" case of a boolean `Select` on its condition, with the
    // remaining deferred edges as the "else" case. Edges to `Return` are never
    // claimable.
    loop {
        let (claimed, else_deferred_edges) = deferred_edges.split_out_matching(|deferred| {
            let deferred_target = deferred.edge_bundle.target;
            let DeferredEdgeBundle { condition, edge_bundle } = match deferred_target {
                DeferredTarget::Region(target) => deferred.with_target(target),
                DeferredTarget::Return => return Err(deferred),
            };
            match self.try_claim_edge_bundle(edge_bundle) {
                Ok(claimed_region) => Ok((condition, claimed_region)),
                Err(new_edge_bundle) => {
                    let new_target = DeferredTarget::Region(new_edge_bundle.target);
                    Err(DeferredEdgeBundle {
                        condition,
                        edge_bundle: new_edge_bundle.with_target(new_target),
                    })
                }
            }
        });
        let Some((condition, then_region)) = claimed else {
            deferred_edges = else_deferred_edges;
            break;
        };
        deferred_edges = self.structurize_select_into(
            region,
            SelectionKind::BoolCond,
            Err(&condition),
            [Ok(then_region), Err(else_deferred_edges)].into_iter().collect(),
        );
    }
    // Cache the count of edges back to `region` itself, which a later claim
    // will turn into a loop.
    let accumulated_backedge_count = deferred_edges
        .get_edge_bundle_by_target(DeferredTarget::Region(region))
        .map(|backedge| backedge.accumulated_count)
        .unwrap_or_default();
    let old_state =
        self.structurize_region_state.insert(region, StructurizeRegionState::Ready {
            accumulated_backedge_count,
            region_deferred_edges: deferred_edges,
        });
    if !matches!(old_state, Some(StructurizeRegionState::InProgress)) {
        unreachable!(
            "cfg::Structurizer::structurize_region: \
             already {}, when attempting to store structurization result",
            match old_state {
                None => "reverted to missing (removed from the map?)",
                Some(StructurizeRegionState::InProgress) => unreachable!(),
                Some(StructurizeRegionState::Ready { .. }) => "completed",
                Some(StructurizeRegionState::Claimed) => "claimed",
            }
        );
    }
}
fn structurize_select_into(
&mut self,
parent_region: ControlRegion,
kind: SelectionKind,
scrutinee: Result<Value, &LazyCond>,
mut cases: SmallVec<[Result<ClaimedRegion, DeferredEdgeBundleSet>; 8]>,
) -> DeferredEdgeBundleSet {
let convergent_cases = cases.iter_mut().filter(|case| match case {
Ok(ClaimedRegion { deferred_edges, .. }) | Err(deferred_edges) => {
!matches!(deferred_edges, DeferredEdgeBundleSet::Unreachable)
}
});
if let Ok(convergent_case) = convergent_cases.exactly_one() {
let convergent_case =
mem::replace(convergent_case, Err(DeferredEdgeBundleSet::Unreachable));
let deferred_edges =
self.structurize_select_into(parent_region, kind, scrutinee, cases);
assert!(matches!(deferred_edges, DeferredEdgeBundleSet::Unreachable));
return self.append_maybe_claimed_region(parent_region, convergent_case);
}
let mut cached_select_node = None;
let mut non_move_kind = Some(kind);
let mut get_or_define_select_node = |this: &mut Self, cases: &[_]| {
*cached_select_node.get_or_insert_with(|| {
let kind = non_move_kind.take().unwrap();
let cases = cases
.iter()
.map(|case| {
let case_region = match case {
&Ok(ClaimedRegion { structured_body, .. }) => structured_body,
Err(_) => this
.func_def_body
.control_regions
.define(this.cx, ControlRegionDef::default()),
};
let case_region_def = this.func_def_body.at_mut(case_region).def();
case_region_def.outputs.clear();
case_region
})
.collect();
let scrutinee =
scrutinee.unwrap_or_else(|lazy_cond| this.materialize_lazy_cond(lazy_cond));
let select_node = this.func_def_body.control_nodes.define(
this.cx,
ControlNodeDef {
kind: ControlNodeKind::Select { kind, scrutinee, cases },
outputs: [].into_iter().collect(),
}
.into(),
);
this.func_def_body.control_regions[parent_region]
.children
.insert_last(select_node, &mut this.func_def_body.control_nodes);
select_node
})
};
let any_non_empty_case = cases.iter().any(|case| {
case.as_ref().is_ok_and(|&ClaimedRegion { structured_body, .. }| {
self.func_def_body.at(structured_body).at_children().into_iter().next().is_some()
})
});
if any_non_empty_case {
get_or_define_select_node(self, &cases);
}
struct DeferredTargetSummary {
input_count: usize,
total_edge_count: IncomingEdgeCount,
}
let mut deferred_targets = FxIndexMap::default();
for case in &cases {
let case_deferred_edges = match case {
Ok(ClaimedRegion { deferred_edges, .. }) | Err(deferred_edges) => deferred_edges,
};
for (target, edge_bundle) in case_deferred_edges.iter_targets_with_edge_bundle() {
let input_count = edge_bundle.target_inputs.len();
let summary = deferred_targets.entry(target).or_insert(DeferredTargetSummary {
input_count,
total_edge_count: IncomingEdgeCount::default(),
});
assert_eq!(summary.input_count, input_count);
summary.total_edge_count += edge_bundle.accumulated_count;
}
}
for case in &mut cases {
let (case_structured_body_inputs, case_deferred_edges) = match case {
Ok(ClaimedRegion { structured_body_inputs, deferred_edges, .. }) => {
(&mut structured_body_inputs[..], deferred_edges)
}
Err(deferred_edges) => (&mut [][..], deferred_edges),
};
let all_values = case_structured_body_inputs.iter_mut().chain(
case_deferred_edges
.iter_targets_with_edge_bundle_mut()
.flat_map(|(_, edge_bundle)| &mut edge_bundle.target_inputs),
);
for v in all_values {
ControlRegionInputRewrites::rewrite_all(&self.control_region_input_rewrites)
.transform_value_use(v)
.apply_to(v);
}
}
let deferred_edges = deferred_targets.into_iter().map(|(target, target_summary)| {
let DeferredTargetSummary { input_count, total_edge_count } = target_summary;
let per_case_deferred: SmallVec<[Result<DeferredEdgeBundle<()>, LazyCond>; 8]> = cases
.iter_mut()
.map(|case| match case {
Ok(ClaimedRegion { deferred_edges, .. }) | Err(deferred_edges) => {
if let DeferredEdgeBundleSet::Unreachable = deferred_edges {
Err(LazyCond::Undef)
} else {
deferred_edges
.steal_deferred_by_target_without_removal(target)
.ok_or(LazyCond::False)
}
}
})
.collect();
let target_inputs = (0..input_count)
.map(|target_input_idx| {
let per_case_target_input = per_case_deferred.iter().map(|per_case_deferred| {
per_case_deferred.as_ref().ok().map(
|DeferredEdgeBundle { edge_bundle, .. }| {
edge_bundle.target_inputs[target_input_idx]
},
)
});
let unique_target_input_value = per_case_target_input
.clone()
.zip_eq(&cases)
.filter_map(|(v, case)| Some((v?, case)))
.map(|(v, case)| {
match case {
Ok(ClaimedRegion {
structured_body,
structured_body_inputs,
..
}) => match v {
Value::Const(_) => Ok(v),
Value::ControlRegionInput { region, input_idx }
if region == *structured_body =>
{
Ok(structured_body_inputs[input_idx as usize])
}
_ => Err(()),
},
Err(_) => Ok(v),
}
})
.dedup()
.exactly_one();
if let Ok(Ok(v)) = unique_target_input_value {
return v;
}
let ty = match target {
DeferredTarget::Region(target) => {
self.func_def_body.at(target).def().inputs[target_input_idx].ty
}
DeferredTarget::Return => per_case_target_input
.clone()
.flatten()
.map(|v| self.func_def_body.at(v).type_of(self.cx))
.dedup()
.exactly_one()
.ok()
.expect("mismatched `return`ed value types"),
};
let select_node = get_or_define_select_node(self, &cases);
let output_decls = &mut self.func_def_body.at_mut(select_node).def().outputs;
let output_idx = output_decls.len();
output_decls.push(ControlNodeOutputDecl { attrs: AttrSet::default(), ty });
for (case_idx, v) in per_case_target_input.enumerate() {
let v = v.unwrap_or_else(|| Value::Const(self.const_undef(ty)));
let case_region = match &self.func_def_body.at(select_node).def().kind {
ControlNodeKind::Select { cases, .. } => cases[case_idx],
_ => unreachable!(),
};
let outputs = &mut self.func_def_body.at_mut(case_region).def().outputs;
assert_eq!(outputs.len(), output_idx);
outputs.push(v);
}
Value::ControlNodeOutput {
control_node: select_node,
output_idx: output_idx.try_into().unwrap(),
}
})
.collect();
let per_case_conds =
per_case_deferred.iter().map(|per_case_deferred| match per_case_deferred {
Ok(DeferredEdgeBundle { condition, .. }) => condition,
Err(undef_or_false) => undef_or_false,
});
let condition = if per_case_conds
.clone()
.all(|cond| matches!(cond, LazyCond::Undef | LazyCond::True))
{
LazyCond::True
} else {
LazyCond::Merge(Rc::new(LazyCondMerge::Select {
control_node: get_or_define_select_node(self, &cases),
per_case_conds: per_case_conds.cloned().collect(),
}))
};
DeferredEdgeBundle {
condition,
edge_bundle: IncomingEdgeBundle {
target,
accumulated_count: total_edge_count,
target_inputs,
},
}
});
let deferred_edges = deferred_edges.collect();
#[allow(clippy::manual_flatten)]
for case in cases {
if let Ok(ClaimedRegion { structured_body, structured_body_inputs, .. }) = case {
if !structured_body_inputs.is_empty() {
self.control_region_input_rewrites.insert(
structured_body,
ControlRegionInputRewrites::ReplaceWith(structured_body_inputs),
);
self.func_def_body.at_mut(structured_body).def().inputs.clear();
}
}
}
deferred_edges
}
/// Turns a `LazyCond` into a concrete boolean `Value`.
///
/// - `Undef` / `False` / `True` map directly to constants: an `OpUndef` of
///   `self.type_bool`, `self.const_false` and `self.const_true` respectively.
/// - `Merge(Select { control_node, per_case_conds })` is materialized by:
///   1. materializing every per-case condition (recursively);
///   2. appending a new `bool` output declaration to the `Select` node `control_node`;
///   3. pushing each case's materialized condition onto that case region's `outputs`.
///
///   The result is the new `ControlNodeOutput`.
///
/// Special case: for a `SelectionKind::BoolCond` select whose per-case
/// conditions are exactly `[true, false]`, the select's own scrutinee is
/// returned, and no output is added.
///
/// NOTE(review): every `Merge` materialization appends a fresh output. Nothing
/// here caches the result, so materializing the same `Rc<LazyCondMerge>` twice
/// yields two equivalent outputs.
///
/// # Panics
/// Panics in any of these cases:
/// - `control_node` is not a `Select`;
/// - its case count differs from the number of per-case conditions;
/// - the output count does not fit in `u32`;
/// - a case region's `outputs` length falls out of sync with the node's output declarations.
fn materialize_lazy_cond(&mut self, cond: &LazyCond) -> Value {
match cond {
LazyCond::Undef => Value::Const(self.const_undef(self.type_bool)),
LazyCond::False => Value::Const(self.const_false),
LazyCond::True => Value::Const(self.const_true),
LazyCond::Merge(merge) => {
let LazyCondMerge::Select { control_node, ref per_case_conds } = **merge;
// Inner conditions are materialized up front. The recursive calls need
// `&mut self`, while the node definition below stays mutably borrowed.
let per_case_conds: SmallVec<[_; 8]> = per_case_conds
.into_iter()
.map(|cond| self.materialize_lazy_cond(cond))
.collect();
let ControlNodeDef { kind, outputs: output_decls } =
&mut *self.func_def_body.control_nodes[control_node];
let cases = match kind {
ControlNodeKind::Select { kind, scrutinee, cases } => {
assert_eq!(cases.len(), per_case_conds.len());
if let SelectionKind::BoolCond = kind {
let [val_false, val_true] =
[self.const_false, self.const_true].map(Value::Const);
// `[true, false]` per case means the merged value is the scrutinee
// itself. This relies on case 0 being the "true" case of `BoolCond`.
if per_case_conds[..] == [val_true, val_false] {
return *scrutinee;
} else if per_case_conds[..] == [val_false, val_true] {
// FIXME: the inverted case (`!scrutinee`) is detected but not
// special-cased. This binding is a no-op, and execution falls
// through to the general path below.
let _not_cond = *scrutinee;
}
}
cases
}
_ => unreachable!(),
};
// General path: add one `bool` output to the select node, and have every
// case region yield its own materialized condition for it.
let output_idx = u32::try_from(output_decls.len()).unwrap();
output_decls
.push(ControlNodeOutputDecl { attrs: AttrSet::default(), ty: self.type_bool });
for (&case, cond) in cases.iter().zip_eq(per_case_conds) {
let ControlRegionDef { outputs, .. } =
&mut self.func_def_body.control_regions[case];
outputs.push(cond);
// Each case region must provide exactly one value per declared output.
assert_eq!(outputs.len(), output_decls.len());
}
Value::ControlNodeOutput { control_node, output_idx }
}
}
}
/// Splices a possibly-claimed region onto the end of `parent_region`.
///
/// For `Ok(claimed)`:
/// 1. If `structured_body_inputs` is non-empty, a
///    `ControlRegionInputRewrites::ReplaceWith(structured_body_inputs)` entry is
///    recorded for `structured_body`.
/// 2. Every child of `structured_body` is moved out of it, leaving it with an
///    empty child list, and appended to `parent_region`'s children.
/// 3. The claimed region's deferred edges are returned.
///
/// For `Err(deferred_edges)` nothing is appended, and the edges are returned unchanged.
fn append_maybe_claimed_region(
    &mut self,
    parent_region: ControlRegion,
    maybe_claimed_region: Result<ClaimedRegion, DeferredEdgeBundleSet>,
) -> DeferredEdgeBundleSet {
    let ClaimedRegion { structured_body, structured_body_inputs, deferred_edges } =
        match maybe_claimed_region {
            Ok(claimed) => claimed,
            Err(unclaimed_edges) => return unclaimed_edges,
        };

    if !structured_body_inputs.is_empty() {
        let rewrite = ControlRegionInputRewrites::ReplaceWith(structured_body_inputs);
        self.control_region_input_rewrites.insert(structured_body, rewrite);
    }

    // Take the body's children wholesale; the body region keeps an empty list.
    let moved_children =
        mem::take(&mut self.func_def_body.at_mut(structured_body).def().children);
    let parent_def = &mut self.func_def_body.control_regions[parent_region];
    parent_def.children.append(moved_children, &mut self.func_def_body.control_nodes);

    deferred_edges
}
fn rebuild_cfg_from_unclaimed_region_deferred_edges(
&mut self,
region: ControlRegion,
mut deferred_edges: DeferredEdgeBundleSet,
) {
assert!(
self.structurize_region_state.is_empty(),
"cfg::Structurizer::rebuild_cfg_from_unclaimed_region_deferred_edges:
must only be called from `structurize_func`, \
after it takes `structurize_region_state`"
);
let mut control_source = Some(region);
loop {
let taken_then;
(taken_then, deferred_edges) =
deferred_edges.split_out_matching(|deferred| match deferred.edge_bundle.target {
DeferredTarget::Region(target) => {
Ok((deferred.condition, (target, deferred.edge_bundle.target_inputs)))
}
DeferredTarget::Return => Err(deferred),
});
let Some((condition, then_target_and_inputs)) = taken_then else {
break;
};
let branch_source = control_source.take().unwrap();
let else_target_and_inputs = match deferred_edges {
DeferredEdgeBundleSet::Unreachable => None,
DeferredEdgeBundleSet::Always {
target: DeferredTarget::Region(else_target),
edge_bundle,
} => {
deferred_edges = DeferredEdgeBundleSet::Unreachable;
Some((else_target, edge_bundle.target_inputs))
}
_ => {
let new_empty_region = self
.func_def_body
.control_regions
.define(self.cx, ControlRegionDef::default());
control_source = Some(new_empty_region);
Some((new_empty_region, [].into_iter().collect()))
}
};
let condition = Some(condition)
.filter(|_| else_target_and_inputs.is_some())
.map(|cond| self.materialize_lazy_cond(&cond));
let branch_control_inst = ControlInst {
attrs: AttrSet::default(),
kind: if condition.is_some() {
ControlInstKind::SelectBranch(SelectionKind::BoolCond)
} else {
ControlInstKind::Branch
},
inputs: condition.into_iter().collect(),
targets: [&then_target_and_inputs]
.into_iter()
.chain(&else_target_and_inputs)
.map(|&(target, _)| target)
.collect(),
target_inputs: [then_target_and_inputs]
.into_iter()
.chain(else_target_and_inputs)
.filter(|(_, inputs)| !inputs.is_empty())
.collect(),
};
assert!(
self.func_def_body
.unstructured_cfg
.as_mut()
.unwrap()
.control_inst_on_exit_from
.insert(branch_source, branch_control_inst)
.is_none()
);
}
let deferred_return = match deferred_edges {
DeferredEdgeBundleSet::Unreachable => None,
DeferredEdgeBundleSet::Always { target: DeferredTarget::Return, edge_bundle } => {
Some(edge_bundle.target_inputs)
}
_ => unreachable!(),
};
let final_source = match control_source {
Some(region) => region,
None => {
assert!(deferred_return.is_none());
return;
}
};
let final_control_inst = {
let (kind, inputs) = match deferred_return {
Some(return_values) => (ControlInstKind::Return, return_values),
None => (ControlInstKind::Unreachable, [].into_iter().collect()),
};
ControlInst {
attrs: AttrSet::default(),
kind,
inputs,
targets: [].into_iter().collect(),
target_inputs: FxIndexMap::default(),
}
};
assert!(
self.func_def_body
.unstructured_cfg
.as_mut()
.unwrap()
.control_inst_on_exit_from
.insert(final_source, final_control_inst)
.is_none()
);
}
/// Returns the `Const` interned in `self.cx` for an `OpUndef` of type `ty`.
///
/// The constant is a `ConstKind::SpvInst` built from the well-known `OpUndef`
/// opcode with no const inputs and default attributes. Because it goes through
/// `intern`, repeated calls with the same `ty` presumably return the same
/// `Const` (TODO confirm against `Context::intern`).
fn const_undef(&self, ty: Type) -> Const {
    let well_known = &spv::spec::Spec::get().well_known;
    let undef_def = ConstDef {
        attrs: AttrSet::default(),
        ty,
        kind: ConstKind::SpvInst {
            spv_inst_and_const_inputs: Rc::new((
                well_known.OpUndef.into(),
                [].into_iter().collect(),
            )),
        },
    };
    self.cx.intern(undef_def)
}
}