use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
use crate::{
analysis::cfg::SsaCfg,
graph::{
algorithms::{compute_dominance_frontiers, compute_dominators},
GraphBase, NodeId, RootedGraph, Successors,
},
ir::{function::SsaFunction, ops::SsaOp, variable::SsaVarId},
target::Target,
};
/// An abstract memory location that a load or store can touch.
///
/// Locations are identified by the SSA variables and field references used to
/// reach them; two locations are equal only when those components are equal.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum MemoryLocation<T: Target> {
    /// A field accessed through the object held in the given SSA variable.
    InstanceField(SsaVarId, T::FieldRef),
    /// A static field.
    StaticField(T::FieldRef),
    /// An element of the array held in the given SSA variable.
    ArrayElement(SsaVarId, ArrayIndex),
    /// Memory accessed through the address held in the given SSA variable.
    Indirect(SsaVarId),
    /// A location about which nothing is known.
    Unknown,
}

impl<T: Target> MemoryLocation<T> {
    /// Returns the SSA variable this location is addressed through.
    ///
    /// Static fields and unknown locations have no base variable and yield `None`.
    #[must_use]
    pub fn base_object(&self) -> Option<SsaVarId> {
        match self {
            Self::StaticField(_) | Self::Unknown => None,
            Self::InstanceField(base, _) | Self::ArrayElement(base, _) | Self::Indirect(base) => {
                Some(*base)
            }
        }
    }

    /// Returns `true` when the two locations might refer to the same memory.
    ///
    /// Rules, in order:
    /// - `Unknown` may alias anything.
    /// - Two `Indirect` locations may alias only when they use the same variable.
    /// - An `Indirect` location may alias any field or array element.
    /// - Locations of the same kind may alias when their components match
    ///   (array indices are compared with [`ArrayIndex::may_overlap`]).
    /// - Locations of different field/array kinds never alias.
    ///
    /// NOTE(review): distinct SSA variables are treated as distinct
    /// objects/arrays/pointers here. If two variables can hold the same
    /// reference, this answers `false` for them — confirm that is intended.
    #[must_use]
    pub fn may_alias(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Unknown, _) | (_, Self::Unknown) => true,
            (Self::Indirect(lhs_ptr), Self::Indirect(rhs_ptr)) => lhs_ptr == rhs_ptr,
            // `Unknown` and `Indirect`/`Indirect` are handled above, so the
            // other side here is always a field or an array element.
            (Self::Indirect(_), _) | (_, Self::Indirect(_)) => true,
            (Self::StaticField(lhs_field), Self::StaticField(rhs_field)) => lhs_field == rhs_field,
            (Self::InstanceField(lhs_obj, lhs_field), Self::InstanceField(rhs_obj, rhs_field)) => {
                lhs_obj == rhs_obj && lhs_field == rhs_field
            }
            (Self::ArrayElement(lhs_arr, lhs_idx), Self::ArrayElement(rhs_arr, rhs_idx)) => {
                lhs_arr == rhs_arr && lhs_idx.may_overlap(rhs_idx)
            }
            (Self::StaticField(_), Self::InstanceField(..) | Self::ArrayElement(..))
            | (Self::InstanceField(..), Self::StaticField(_) | Self::ArrayElement(..))
            | (Self::ArrayElement(..), Self::StaticField(_) | Self::InstanceField(..)) => false,
        }
    }

    /// Returns `true` when the two locations definitely refer to the same memory.
    ///
    /// Only locations of the same kind with matching components qualify; array
    /// indices are compared with [`ArrayIndex::must_equal`]. Anything involving
    /// `Unknown`, or mixing kinds, yields `false`.
    #[must_use]
    pub fn must_alias(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Indirect(lhs_ptr), Self::Indirect(rhs_ptr)) => lhs_ptr == rhs_ptr,
            (Self::ArrayElement(lhs_arr, lhs_idx), Self::ArrayElement(rhs_arr, rhs_idx)) => {
                lhs_arr == rhs_arr && lhs_idx.must_equal(rhs_idx)
            }
            (Self::InstanceField(lhs_obj, lhs_field), Self::InstanceField(rhs_obj, rhs_field)) => {
                lhs_obj == rhs_obj && lhs_field == rhs_field
            }
            (Self::StaticField(lhs_field), Self::StaticField(rhs_field)) => lhs_field == rhs_field,
            _ => false,
        }
    }
}
/// The index part of an array-element memory location.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum ArrayIndex {
    /// A compile-time known index.
    Constant(i64),
    /// An index held in an SSA variable.
    Variable(SsaVarId),
    /// An index about which nothing is known.
    Unknown,
}

impl ArrayIndex {
    /// Returns `true` when the two indices might select the same element.
    ///
    /// Only two constants can be proven distinct; any pairing that involves a
    /// variable or unknown index is reported as possibly overlapping.
    #[must_use]
    pub fn may_overlap(&self, other: &Self) -> bool {
        if let (Self::Constant(lhs), Self::Constant(rhs)) = (self, other) {
            lhs == rhs
        } else {
            true
        }
    }

    /// Returns `true` when the two indices definitely select the same element.
    ///
    /// That holds for equal constants and for the same SSA variable; an
    /// `Unknown` index is never equal to anything, including another `Unknown`.
    #[must_use]
    pub fn must_equal(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Variable(lhs), Self::Variable(rhs)) => lhs == rhs,
            (Self::Constant(lhs), Self::Constant(rhs)) => lhs == rhs,
            (Self::Unknown, _)
            | (_, Self::Unknown)
            | (Self::Constant(_), Self::Variable(_))
            | (Self::Variable(_), Self::Constant(_)) => false,
        }
    }
}
/// A single memory access found in the function: either a load or a store.
#[derive(Debug, Clone)]
pub enum MemoryOp<T: Target> {
    /// A read from `location` into the SSA variable `dest`.
    Load {
        /// The location being read.
        location: MemoryLocation<T>,
        /// The SSA variable receiving the loaded value.
        dest: SsaVarId,
        /// Index of the block containing the access.
        block: usize,
        /// Index of the instruction within that block.
        instr: usize,
    },
    /// A write of the SSA variable `value` into `location`.
    Store {
        /// The location being written.
        location: MemoryLocation<T>,
        /// The SSA variable holding the stored value.
        value: SsaVarId,
        /// Index of the block containing the access.
        block: usize,
        /// Index of the instruction within that block.
        instr: usize,
    },
}

impl<T: Target> MemoryOp<T> {
    /// Returns the memory location this operation accesses.
    #[must_use]
    pub fn location(&self) -> &MemoryLocation<T> {
        let (Self::Load { location, .. } | Self::Store { location, .. }) = self;
        location
    }

    /// Returns the index of the block containing this operation.
    #[must_use]
    pub fn block(&self) -> usize {
        let (Self::Load { block, .. } | Self::Store { block, .. }) = self;
        *block
    }

    /// Returns the index of this operation's instruction within its block.
    #[must_use]
    pub fn instr(&self) -> usize {
        let (Self::Load { instr, .. } | Self::Store { instr, .. }) = self;
        *instr
    }

    /// Returns `true` if this operation writes memory.
    #[must_use]
    pub fn is_store(&self) -> bool {
        match self {
            Self::Store { .. } => true,
            Self::Load { .. } => false,
        }
    }

    /// Returns `true` if this operation reads memory.
    #[must_use]
    pub fn is_load(&self) -> bool {
        match self {
            Self::Load { .. } => true,
            Self::Store { .. } => false,
        }
    }
}
/// A phi node merging the memory versions of one location at a block.
#[derive(Debug, Clone)]
pub struct MemoryPhi<T: Target> {
    /// The location whose versions are merged.
    pub location: MemoryLocation<T>,
    /// The version this phi defines.
    pub result_version: u32,
    /// One incoming version per recorded predecessor.
    pub operands: Vec<MemoryPhiOperand>,
}

impl<T: Target> MemoryPhi<T> {
    /// Creates a phi for `location` that defines `result_version` and has no operands yet.
    #[must_use]
    pub fn new(location: MemoryLocation<T>, result_version: u32) -> Self {
        let operands = Vec::new();
        Self {
            location,
            result_version,
            operands,
        }
    }

    /// Appends an operand: `version` flows in from block `predecessor`.
    ///
    /// No de-duplication is performed; calling this twice for the same
    /// predecessor records two operands.
    pub fn add_operand(&mut self, predecessor: usize, version: u32) {
        let operand = MemoryPhiOperand {
            predecessor,
            version,
        };
        self.operands.push(operand);
    }

    /// Returns the first operand recorded for `predecessor`, if any.
    #[must_use]
    pub fn operand_from(&self, predecessor: usize) -> Option<&MemoryPhiOperand> {
        for operand in &self.operands {
            if operand.predecessor == predecessor {
                return Some(operand);
            }
        }
        None
    }
}
/// One incoming value of a memory phi: a memory version paired with the
/// predecessor block it flows in from.
#[derive(Debug, Clone)]
pub struct MemoryPhiOperand {
/// Index of the predecessor block this operand comes from.
pub predecessor: usize,
/// The memory version flowing in from that predecessor.
pub version: u32,
}
/// A memory location paired with one of its version numbers.
///
/// Used as the key when looking up where a version was defined.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct MemoryVersion<T: Target> {
    /// The location being versioned.
    pub location: MemoryLocation<T>,
    /// The version number of that location.
    pub version: u32,
}

impl<T: Target> MemoryVersion<T> {
    /// Pairs `location` with `version`.
    #[must_use]
    pub fn new(location: MemoryLocation<T>, version: u32) -> Self {
        MemoryVersion { location, version }
    }
}
/// Where a particular memory version was defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryDefSite {
/// The initial version that is live when the function is entered.
Entry,
/// The version was created by a store instruction.
Store {
/// Index of the block containing the store.
block: usize,
/// Index of the store instruction within that block.
instr: usize,
},
/// The version was created by a memory phi.
Phi {
/// Index of the block holding the phi.
block: usize,
},
}
/// Memory SSA information for one function: memory operations, memory phis,
/// and per-location version numbers with their definition sites.
#[derive(Debug)]
pub struct MemorySsa<T: Target> {
/// Next unallocated version number for each location.
next_version: HashMap<MemoryLocation<T>, u32>,
/// Memory phis keyed by the index of the block they belong to.
memory_phis: BTreeMap<usize, Vec<MemoryPhi<T>>>,
/// Definition site of every allocated (location, version) pair.
definitions: HashMap<MemoryVersion<T>, MemoryDefSite>,
/// Version of each location recorded on entry to each block, keyed by (location, block index).
entry_versions: HashMap<(MemoryLocation<T>, usize), u32>,
/// Version of each location recorded on exit from each block, keyed by (location, block index).
exit_versions: HashMap<(MemoryLocation<T>, usize), u32>,
/// Every load and store found in the function, in instruction iteration order.
operations: Vec<MemoryOp<T>>,
/// Every distinct location touched by an entry in `operations`.
locations: HashSet<MemoryLocation<T>>,
}
impl<T: Target> MemorySsa<T> {
/// Creates an empty `MemorySsa` with no operations, phis, or versions.
#[must_use]
pub fn new() -> Self {
    Self {
        next_version: HashMap::default(),
        memory_phis: BTreeMap::default(),
        definitions: HashMap::default(),
        entry_versions: HashMap::default(),
        exit_versions: HashMap::default(),
        operations: Vec::default(),
        locations: HashSet::default(),
    }
}
/// Builds memory SSA for `ssa` using the control-flow graph `cfg`.
///
/// Runs three passes in order: collect the loads and stores, place memory
/// phis, then assign versions to every definition.
#[must_use]
pub fn build(ssa: &SsaFunction<T>, cfg: &SsaCfg<'_, T>) -> Self {
    let mut result = Self::new();
    // Phi placement needs the collected stores; renaming needs the placed phis.
    result.identify_memory_operations(ssa);
    result.place_memory_phis(cfg);
    result.rename_memory_versions(ssa, cfg);
    result
}
/// Returns the memory phis placed in `block`, or an empty slice if it has none.
#[must_use]
pub fn memory_phis(&self, block: usize) -> &[MemoryPhi<T>] {
    match self.memory_phis.get(&block) {
        Some(phis) => phis.as_slice(),
        None => &[],
    }
}
/// Returns every recorded load and store, in the order they were collected.
#[must_use]
pub fn operations(&self) -> &[MemoryOp<T>] {
    self.operations.as_slice()
}
/// Returns every distinct memory location touched by a recorded load or store.
#[must_use]
pub fn locations(&self) -> &HashSet<MemoryLocation<T>> {
&self.locations
}
/// Returns the version of `location` recorded on entry to `block`, if any.
#[must_use]
pub fn version_at_entry(&self, location: &MemoryLocation<T>, block: usize) -> Option<u32> {
    // The map is keyed by an owned tuple, so the location has to be cloned for the lookup.
    let key = (location.clone(), block);
    let version = self.entry_versions.get(&key)?;
    Some(*version)
}
/// Returns the version of `location` recorded on exit from `block`, if any.
#[must_use]
pub fn version_at_exit(&self, location: &MemoryLocation<T>, block: usize) -> Option<u32> {
    // The map is keyed by an owned tuple, so the location has to be cloned for the lookup.
    let key = (location.clone(), block);
    let version = self.exit_versions.get(&key)?;
    Some(*version)
}
/// Returns where `version` was defined, or `None` if that version was never allocated.
#[must_use]
pub fn definition(&self, version: &MemoryVersion<T>) -> Option<MemoryDefSite> {
    let site = self.definitions.get(version)?;
    Some(*site)
}
/// Hands out the next unused version number for `location`.
///
/// Numbering starts at 0 per location. The counter saturates at `u32::MAX`
/// instead of wrapping.
fn allocate_version(&mut self, location: &MemoryLocation<T>) -> u32 {
    let counter = self.next_version.entry(location.clone()).or_insert(0);
    let allocated = *counter;
    *counter = allocated.saturating_add(1);
    allocated
}
/// Scans every instruction of `ssa` and records the ones that access memory.
///
/// Each recognised load or store is appended to `operations`, and its
/// location is added to `locations`.
fn identify_memory_operations(&mut self, ssa: &SsaFunction<T>) {
    for (block_idx, instr_idx, instr) in ssa.iter_instructions() {
        let Some(mem_op) = Self::classify_memory_operation(instr.op(), block_idx, instr_idx)
        else {
            // Not a memory access.
            continue;
        };
        self.locations.insert(mem_op.location().clone());
        self.operations.push(mem_op);
    }
}
fn classify_memory_operation(op: &SsaOp<T>, block: usize, instr: usize) -> Option<MemoryOp<T>> {
match op {
SsaOp::LoadField {
dest,
object,
field,
} => {
let location = MemoryLocation::InstanceField(*object, field.clone());
Some(MemoryOp::Load {
location,
dest: *dest,
block,
instr,
})
}
SsaOp::StoreField {
object,
field,
value,
} => {
let location = MemoryLocation::InstanceField(*object, field.clone());
Some(MemoryOp::Store {
location,
value: *value,
block,
instr,
})
}
SsaOp::LoadStaticField { dest, field } => {
let location = MemoryLocation::StaticField(field.clone());
Some(MemoryOp::Load {
location,
dest: *dest,
block,
instr,
})
}
SsaOp::StoreStaticField { field, value } => {
let location = MemoryLocation::StaticField(field.clone());
Some(MemoryOp::Store {
location,
value: *value,
block,
instr,
})
}
SsaOp::LoadElement {
dest, array, index, ..
} => {
let idx = Self::resolve_array_index(*index);
let location = MemoryLocation::ArrayElement(*array, idx);
Some(MemoryOp::Load {
location,
dest: *dest,
block,
instr,
})
}
SsaOp::StoreElement {
array,
index,
value,
..
} => {
let idx = Self::resolve_array_index(*index);
let location = MemoryLocation::ArrayElement(*array, idx);
Some(MemoryOp::Store {
location,
value: *value,
block,
instr,
})
}
SsaOp::LoadIndirect { dest, addr, .. } => {
let location = MemoryLocation::Indirect(*addr);
Some(MemoryOp::Load {
location,
dest: *dest,
block,
instr,
})
}
SsaOp::StoreIndirect { addr, value, .. } => {
let location = MemoryLocation::Indirect(*addr);
Some(MemoryOp::Store {
location,
value: *value,
block,
instr,
})
}
_ => None,
}
}
/// Converts the SSA variable used as an array index into an [`ArrayIndex`].
///
/// Currently this always returns `ArrayIndex::Variable`; no attempt is made
/// to resolve the variable to a constant.
fn resolve_array_index(index_var: SsaVarId) -> ArrayIndex {
ArrayIndex::Variable(index_var)
}
/// Places memory phis for every location that is stored to.
///
/// For each stored location, the blocks containing its stores seed a
/// worklist; a phi is placed in every block of the iterated dominance
/// frontier of those blocks. Each phi gets a freshly allocated version and a
/// `MemoryDefSite::Phi` entry in `definitions`. Operands are not filled here.
///
/// Locations are processed in the order their first store appears in
/// `operations`, so the order of phis inside a block does not depend on
/// `HashMap` iteration order.
fn place_memory_phis(&mut self, cfg: &SsaCfg<'_, T>) {
    let block_count = cfg.node_count();
    if block_count == 0 {
        return;
    }
    let dom_tree = compute_dominators(cfg, cfg.entry());
    let frontiers = compute_dominance_frontiers(cfg, &dom_tree);

    // Group the defining blocks per location, remembering first-seen order.
    let mut store_order: Vec<MemoryLocation<T>> = Vec::new();
    let mut def_blocks: HashMap<MemoryLocation<T>, BTreeSet<usize>> = HashMap::new();
    for op in &self.operations {
        if !op.is_store() {
            continue;
        }
        let location = op.location();
        if !def_blocks.contains_key(location) {
            store_order.push(location.clone());
        }
        def_blocks
            .entry(location.clone())
            .or_default()
            .insert(op.block());
    }

    for location in store_order {
        let Some(defs) = def_blocks.remove(&location) else {
            continue;
        };
        let mut phi_blocks: BTreeSet<usize> = BTreeSet::new();
        let mut worklist: VecDeque<usize> = defs.iter().copied().collect();
        let mut processed: BTreeSet<usize> = BTreeSet::new();
        while let Some(block) = worklist.pop_front() {
            if !processed.insert(block) {
                continue;
            }
            // `get` already yields `None` for blocks without a frontier entry.
            let Some(frontier_set) = frontiers.get(NodeId::new(block).index()) else {
                continue;
            };
            for frontier_block in frontier_set.iter() {
                if !phi_blocks.insert(frontier_block) {
                    continue;
                }
                let version = self.allocate_version(&location);
                self.memory_phis
                    .entry(frontier_block)
                    .or_default()
                    .push(MemoryPhi::new(location.clone(), version));
                self.definitions.insert(
                    MemoryVersion::new(location.clone(), version),
                    MemoryDefSite::Phi {
                        block: frontier_block,
                    },
                );
                // A phi is itself a definition, so its frontier must be visited too.
                worklist.push_back(frontier_block);
            }
        }
    }
}
/// Assigns versions to every memory definition by walking the dominator tree.
///
/// Every known location first receives an entry version
/// (`MemoryDefSite::Entry`). Blocks are then visited depth-first over the
/// dominator tree starting at the CFG entry, and `rename_block` is called for
/// each one with the current per-location version stacks.
///
/// When the walk leaves a block's dominator subtree, the stacks are cut back
/// to the depth they had when the block was entered. Without that step a
/// block would see versions defined in a sibling subtree that does not
/// dominate it.
fn rename_memory_versions(&mut self, ssa: &SsaFunction<T>, cfg: &SsaCfg<'_, T>) {
    let block_count = cfg.node_count();
    if block_count == 0 {
        return;
    }
    let dom_tree = compute_dominators(cfg, cfg.entry());

    let mut version_stacks: HashMap<MemoryLocation<T>, Vec<u32>> = HashMap::new();
    let locations: Vec<_> = self.locations.iter().cloned().collect();
    for location in locations {
        let entry_version = self.allocate_version(&location);
        version_stacks
            .entry(location.clone())
            .or_default()
            .push(entry_version);
        self.definitions.insert(
            MemoryVersion::new(location, entry_version),
            MemoryDefSite::Entry,
        );
    }

    let mut visited = vec![false; block_count];
    // `(block, None)` means "enter block". `(block, Some(depths))` is pushed
    // beneath the block's children and means "subtree finished: restore the
    // stack depths that were current when the block was entered".
    let mut worklist: Vec<(usize, Option<HashMap<MemoryLocation<T>, usize>>)> =
        vec![(cfg.entry().index(), None)];
    while let Some((block_idx, restore)) = worklist.pop() {
        if let Some(depths) = restore {
            for (location, stack) in &mut version_stacks {
                // A stack that did not exist on entry is emptied again.
                stack.truncate(depths.get(location).copied().unwrap_or(0));
            }
            continue;
        }

        // Skip out-of-range indices and blocks that were already renamed.
        let Some(seen) = visited.get_mut(block_idx) else {
            continue;
        };
        if *seen {
            continue;
        }
        *seen = true;

        let depths = version_stacks
            .iter()
            .map(|(location, stack)| (location.clone(), stack.len()))
            .collect();
        worklist.push((block_idx, Some(depths)));

        self.rename_block(block_idx, ssa, cfg, &mut version_stacks);

        for child in dom_tree.children(NodeId::new(block_idx)) {
            if visited.get(child.index()).copied() == Some(false) {
                worklist.push((child.index(), None));
            }
        }
    }
}
fn rename_block(
&mut self,
block_idx: usize,
ssa: &SsaFunction<T>,
cfg: &SsaCfg<'_, T>,
version_stacks: &mut HashMap<MemoryLocation<T>, Vec<u32>>,
) {
for location in self.locations.clone() {
if let Some(&version) = version_stacks.get(&location).and_then(|s| s.last()) {
self.entry_versions
.insert((location.clone(), block_idx), version);
}
}
if let Some(phis) = self.memory_phis.get(&block_idx).cloned() {
for phi in phis {
version_stacks
.entry(phi.location.clone())
.or_default()
.push(phi.result_version);
}
}
let Some(block) = ssa.block(block_idx) else {
return;
};
for (instr_idx, instr) in block.instructions().iter().enumerate() {
if let Some(mem_op) = Self::classify_memory_operation(instr.op(), block_idx, instr_idx)
{
if mem_op.is_store() {
let location = mem_op.location().clone();
let new_version = self.allocate_version(&location);
version_stacks
.entry(location.clone())
.or_default()
.push(new_version);
self.definitions.insert(
MemoryVersion::new(location, new_version),
MemoryDefSite::Store {
block: block_idx,
instr: instr_idx,
},
);
}
}
}
for location in self.locations.clone() {
if let Some(&version) = version_stacks.get(&location).and_then(|s| s.last()) {
self.exit_versions
.insert((location.clone(), block_idx), version);
}
}
for succ_id in cfg.successors(NodeId::new(block_idx)) {
let succ_idx = succ_id.index();
if let Some(phis) = self.memory_phis.get_mut(&succ_idx) {
for phi in phis {
if let Some(&version) = version_stacks.get(&phi.location).and_then(|s| s.last())
{
phi.add_operand(block_idx, version);
}
}
}
}
}
/// Summarises the size of this memory SSA as a set of counters.
#[must_use]
pub fn stats(&self) -> MemorySsaStats {
    let mut store_count = 0usize;
    let mut load_count = 0usize;
    for op in &self.operations {
        if op.is_store() {
            store_count += 1;
        }
        if op.is_load() {
            load_count += 1;
        }
    }
    let memory_phi_count = self
        .memory_phis
        .values()
        .fold(0usize, |total, phis| total + phis.len());
    MemorySsaStats {
        location_count: self.locations.len(),
        memory_phi_count,
        store_count,
        load_count,
        version_count: self.definitions.len(),
    }
}
}
/// The default value is an empty `MemorySsa`, identical to `MemorySsa::new()`.
impl<T: Target> Default for MemorySsa<T> {
fn default() -> Self {
Self::new()
}
}
/// Size counters describing a `MemorySsa`.
#[derive(Debug, Clone, Copy)]
pub struct MemorySsaStats {
/// Number of distinct memory locations.
pub location_count: usize,
/// Total number of memory phis across all blocks.
pub memory_phi_count: usize,
/// Number of recorded store operations.
pub store_count: usize,
/// Number of recorded load operations.
pub load_count: usize,
/// Number of (location, version) pairs with a recorded definition site.
pub version_count: usize,
}
/// A map from memory locations to the value last stored there, together with
/// the memory version of that store.
#[derive(Debug, Clone)]
pub struct MemoryState<T: Target> {
    /// `(version, value)` for each tracked location.
    values: HashMap<MemoryLocation<T>, (u32, SsaVarId)>,
}

impl<T: Target> MemoryState<T> {
    /// Creates a state that tracks no locations.
    #[must_use]
    pub fn new() -> Self {
        let values = HashMap::new();
        Self { values }
    }

    /// Records that `location` now holds `value` at memory version `version`,
    /// replacing any previous entry for exactly that location.
    ///
    /// NOTE(review): entries for other locations that may alias `location`
    /// are left untouched; callers presumably invalidate them via
    /// `has_may_alias` and `clear` — confirm.
    pub fn store(&mut self, location: MemoryLocation<T>, value: SsaVarId, version: u32) {
        let entry = (version, value);
        self.values.insert(location, entry);
    }

    /// Returns the value known to be stored at `location`, if any.
    ///
    /// An exact key match is tried first; otherwise the first tracked
    /// location that must-aliases `location` is used.
    #[must_use]
    pub fn load(&self, location: &MemoryLocation<T>) -> Option<SsaVarId> {
        match self.values.get(location) {
            Some(&(_, value)) => Some(value),
            None => self
                .values
                .iter()
                .find(|(known, _)| location.must_alias(known))
                .map(|(_, &(_, value))| value),
        }
    }

    /// Returns the memory version recorded for exactly `location`, if tracked.
    #[must_use]
    pub fn version(&self, location: &MemoryLocation<T>) -> Option<u32> {
        let (version, _) = self.values.get(location)?;
        Some(*version)
    }

    /// Returns `true` if any tracked location may alias `location`.
    #[must_use]
    pub fn has_may_alias(&self, location: &MemoryLocation<T>) -> bool {
        for known in self.values.keys() {
            if known.may_alias(location) {
                return true;
            }
        }
        false
    }

    /// Forgets every tracked location.
    pub fn clear(&mut self) {
        self.values.clear();
    }

    /// Returns the number of tracked locations.
    #[must_use]
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Returns `true` if no location is tracked.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }
}

impl<T: Target> Default for MemoryState<T> {
    /// Equivalent to [`MemoryState::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Outcome of comparing two memory locations for aliasing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(clippy::enum_variant_names)]
pub enum AliasResult {
/// The locations were not found to possibly overlap.
NoAlias,
/// The locations might refer to the same memory.
MayAlias,
/// The locations definitely refer to the same memory.
MustAlias,
}
/// Classifies the aliasing relationship between `loc1` and `loc2`.
///
/// `must_alias` is checked first, so a pair that satisfies both predicates is
/// reported as `MustAlias`. A pair that satisfies neither is `NoAlias`.
#[must_use]
pub fn analyze_alias<T: Target>(loc1: &MemoryLocation<T>, loc2: &MemoryLocation<T>) -> AliasResult {
    if loc1.must_alias(loc2) {
        return AliasResult::MustAlias;
    }
    if !loc1.may_alias(loc2) {
        return AliasResult::NoAlias;
    }
    AliasResult::MayAlias
}