use std::collections::{HashMap, HashSet, VecDeque};
use crate::analysis::ssa::{FieldRef, SsaCfg, SsaFunction, SsaOp, SsaVarId};
use crate::utils::graph::{
algorithms::{compute_dominance_frontiers, compute_dominators},
GraphBase, NodeId, RootedGraph, Successors,
};
/// An abstract memory location read or written by a load or store.
///
/// Locations are keyed on SSA values: two locations built from different base variables
/// are different keys even if those variables happen to reference the same runtime object.
/// NOTE(review): `may_alias` treats distinct base variables of `InstanceField` and
/// `ArrayElement` as non-aliasing. This is an optimistic assumption; confirm it is
/// acceptable for every client of this analysis.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum MemoryLocation {
    /// The given field of the object referenced by the SSA value.
    InstanceField(SsaVarId, FieldRef),
    /// A static field, not tied to any object instance.
    StaticField(FieldRef),
    /// An element of the array referenced by the SSA value, at the given index.
    ArrayElement(SsaVarId, ArrayIndex),
    /// Memory reached through the address held in the SSA value
    /// (produced for `LoadIndirect` / `StoreIndirect`).
    Indirect(SsaVarId),
    /// A location that could not be characterised; alias queries treat it as aliasing
    /// everything.
    Unknown,
}
impl MemoryLocation {
    /// Returns the SSA value that anchors this location, if it has one.
    ///
    /// That is the object for instance fields, the array for array elements and the
    /// address for indirect accesses. Static fields and unknown locations have no base.
    #[must_use]
    pub fn base_object(&self) -> Option<SsaVarId> {
        match self {
            Self::InstanceField(obj, _) => Some(*obj),
            Self::ArrayElement(arr, _) => Some(*arr),
            Self::Indirect(ptr) => Some(*ptr),
            Self::StaticField(_) | Self::Unknown => None,
        }
    }

    /// Returns `true` if `self` and `other` might refer to the same memory.
    ///
    /// Rules, in order:
    /// - `Unknown` and `Indirect` locations may alias anything. That includes another
    ///   `Indirect` location built from a *different* address variable, because two
    ///   address values can point at the same cell.
    /// - Static fields alias only the same static field.
    /// - Instance fields alias only when both the base SSA value and the field match.
    ///   Array elements alias only when the base SSA value matches and the indices may
    ///   overlap. Distinct base SSA values are assumed to be distinct objects.
    /// - Locations of different kinds (static / instance / array) never alias.
    #[must_use]
    pub fn may_alias(&self, other: &Self) -> bool {
        match (self, other) {
            // An unknown location or an arbitrary address may point anywhere.
            (Self::Unknown | Self::Indirect(_), _) | (_, Self::Unknown | Self::Indirect(_)) => {
                true
            }
            (Self::StaticField(f1), Self::StaticField(f2)) => f1 == f2,
            (Self::InstanceField(obj1, f1), Self::InstanceField(obj2, f2)) => {
                obj1 == obj2 && f1 == f2
            }
            (Self::ArrayElement(arr1, idx1), Self::ArrayElement(arr2, idx2)) => {
                arr1 == arr2 && idx1.may_overlap(idx2)
            }
            (Self::StaticField(_), Self::InstanceField(..) | Self::ArrayElement(..))
            | (Self::InstanceField(..), Self::StaticField(_) | Self::ArrayElement(..))
            | (Self::ArrayElement(..), Self::StaticField(_) | Self::InstanceField(..)) => false,
        }
    }

    /// Returns `true` only if `self` and `other` definitely denote the same memory cell.
    ///
    /// - `Unknown` never must-aliases anything, not even itself.
    /// - `Indirect` locations must-alias only when they use the same address variable.
    /// - Array elements need the same array and provably equal indices.
    #[must_use]
    pub fn must_alias(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::StaticField(f1), Self::StaticField(f2)) => f1 == f2,
            (Self::InstanceField(obj1, f1), Self::InstanceField(obj2, f2)) => {
                obj1 == obj2 && f1 == f2
            }
            (Self::ArrayElement(arr1, idx1), Self::ArrayElement(arr2, idx2)) => {
                arr1 == arr2 && idx1.must_equal(idx2)
            }
            (Self::Indirect(p1), Self::Indirect(p2)) => p1 == p2,
            _ => false,
        }
    }
}
/// Symbolic index of an array element access.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum ArrayIndex {
    /// A statically known index.
    Constant(i64),
    /// An index held in an SSA variable whose value is not known statically.
    Variable(SsaVarId),
    /// An index about which nothing is known.
    Unknown,
}
impl ArrayIndex {
    /// Returns `true` when the two indices could select the same array slot at runtime.
    ///
    /// Only two *different constants* are known to be disjoint. A variable or unknown
    /// index on either side is conservatively assumed to overlap.
    #[must_use]
    pub fn may_overlap(&self, other: &Self) -> bool {
        if let (Self::Constant(lhs), Self::Constant(rhs)) = (self, other) {
            lhs == rhs
        } else {
            true
        }
    }

    /// Returns `true` only when both indices provably hold the same value.
    ///
    /// That means equal constants, or the same SSA variable. `Unknown` is never equal to
    /// anything, including another `Unknown`.
    #[must_use]
    pub fn must_equal(&self, other: &Self) -> bool {
        let pair = (self, other);
        matches!(pair, (Self::Constant(lhs), Self::Constant(rhs)) if lhs == rhs)
            || matches!(pair, (Self::Variable(lhs), Self::Variable(rhs)) if lhs == rhs)
    }
}
/// A single load or store recognised in the SSA function, together with its position.
#[derive(Debug, Clone)]
pub enum MemoryOp {
    /// A read of `location` into the SSA variable `dest`.
    Load {
        /// The location being read.
        location: MemoryLocation,
        /// SSA variable receiving the loaded value.
        dest: SsaVarId,
        /// Index of the block containing the load.
        block: usize,
        /// Instruction index the load was recorded with.
        instr: usize,
    },
    /// A write of the SSA variable `value` to `location`.
    Store {
        /// The location being written.
        location: MemoryLocation,
        /// SSA variable holding the stored value.
        value: SsaVarId,
        /// Index of the block containing the store.
        block: usize,
        /// Instruction index the store was recorded with.
        instr: usize,
    },
}
impl MemoryOp {
    /// The memory location this operation reads or writes.
    #[must_use]
    pub fn location(&self) -> &MemoryLocation {
        let (Self::Load { location, .. } | Self::Store { location, .. }) = self;
        location
    }

    /// Index of the block that contains this operation.
    #[must_use]
    pub fn block(&self) -> usize {
        let (Self::Load { block, .. } | Self::Store { block, .. }) = self;
        *block
    }

    /// Instruction index this operation was recorded with.
    #[must_use]
    pub fn instr(&self) -> usize {
        let (Self::Load { instr, .. } | Self::Store { instr, .. }) = self;
        *instr
    }

    /// `true` for [`MemoryOp::Store`].
    #[must_use]
    pub fn is_store(&self) -> bool {
        match self {
            Self::Store { .. } => true,
            Self::Load { .. } => false,
        }
    }

    /// `true` for [`MemoryOp::Load`].
    #[must_use]
    pub fn is_load(&self) -> bool {
        match self {
            Self::Load { .. } => true,
            Self::Store { .. } => false,
        }
    }
}
/// Merges the versions of one memory location at a control-flow join.
#[derive(Debug, Clone)]
pub struct MemoryPhi {
    /// The location whose versions are merged.
    pub location: MemoryLocation,
    /// Version number defined by this phi.
    pub result_version: u32,
    /// Incoming versions, appended as predecessor blocks are renamed.
    pub operands: Vec<MemoryPhiOperand>,
}
impl MemoryPhi {
    /// Creates a phi for `location` defining `result_version`, with no operands yet.
    #[must_use]
    pub fn new(location: MemoryLocation, result_version: u32) -> Self {
        let operands = Vec::new();
        Self {
            location,
            result_version,
            operands,
        }
    }

    /// Appends an incoming `version` flowing in from block `predecessor`.
    pub fn add_operand(&mut self, predecessor: usize, version: u32) {
        let operand = MemoryPhiOperand {
            predecessor,
            version,
        };
        self.operands.push(operand);
    }

    /// Returns the first operand recorded for block `predecessor`, if any.
    #[must_use]
    pub fn operand_from(&self, predecessor: usize) -> Option<&MemoryPhiOperand> {
        for operand in &self.operands {
            if operand.predecessor == predecessor {
                return Some(operand);
            }
        }
        None
    }
}
/// One incoming value of a [`MemoryPhi`].
#[derive(Debug, Clone)]
pub struct MemoryPhiOperand {
    /// Index of the predecessor block this value flows in from.
    pub predecessor: usize,
    /// Version of the location current at the end of that predecessor.
    pub version: u32,
}
/// A `(location, version)` pair naming one definition of a memory location.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct MemoryVersion {
    /// The memory location.
    pub location: MemoryLocation,
    /// Version number of that location.
    pub version: u32,
}
impl MemoryVersion {
#[must_use]
pub fn new(location: MemoryLocation, version: u32) -> Self {
Self { location, version }
}
}
/// Where a particular memory version is defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryDefSite {
    /// The implicit initial version assigned at function entry.
    Entry,
    /// A store instruction at `instr` in `block`.
    Store {
        block: usize,
        instr: usize,
    },
    /// A memory phi placed in `block`.
    Phi {
        block: usize,
    },
}
/// Memory SSA form for one function.
///
/// It records a version for every location touched by a load or store, plus the phis that
/// merge those versions at control-flow joins.
#[derive(Debug)]
pub struct MemorySsa {
    /// Next version number to hand out, per location.
    next_version: HashMap<MemoryLocation, u32>,
    /// Memory phis, keyed by the index of the block they are placed in.
    memory_phis: HashMap<usize, Vec<MemoryPhi>>,
    /// Definition site of every allocated version.
    definitions: HashMap<MemoryVersion, MemoryDefSite>,
    /// Version of a location recorded on entry to a block, keyed by `(location, block)`.
    entry_versions: HashMap<(MemoryLocation, usize), u32>,
    /// Version of a location recorded at the end of a block, keyed by `(location, block)`.
    exit_versions: HashMap<(MemoryLocation, usize), u32>,
    /// Every load and store found, in discovery order.
    operations: Vec<MemoryOp>,
    /// Every location touched by at least one load or store.
    locations: HashSet<MemoryLocation>,
}
impl MemorySsa {
    /// Creates an empty analysis with no locations, operations, phis or versions.
    ///
    /// Use [`MemorySsa::build`] to run the analysis over a function.
    #[must_use]
    pub fn new() -> Self {
        Self {
            next_version: HashMap::new(),
            memory_phis: HashMap::new(),
            definitions: HashMap::new(),
            entry_versions: HashMap::new(),
            exit_versions: HashMap::new(),
            operations: Vec::new(),
            locations: HashSet::new(),
        }
    }

    /// Builds memory SSA for `ssa` over its control-flow graph `cfg`.
    ///
    /// The construction has three phases:
    /// 1. Collect every load and store, and the locations they touch.
    /// 2. Place memory phis at the iterated dominance frontier of each location's store
    ///    blocks.
    /// 3. Walk the dominator tree, assigning versions to stores and filling in phi
    ///    operands.
    #[must_use]
    pub fn build(ssa: &SsaFunction, cfg: &SsaCfg<'_>) -> Self {
        let mut mem_ssa = Self::new();
        mem_ssa.identify_memory_operations(ssa);
        mem_ssa.place_memory_phis(cfg);
        mem_ssa.rename_memory_versions(ssa, cfg);
        mem_ssa
    }

    /// Returns the memory phis placed at the start of `block` (empty if there are none).
    #[must_use]
    pub fn memory_phis(&self, block: usize) -> &[MemoryPhi] {
        self.memory_phis.get(&block).map_or(&[], Vec::as_slice)
    }

    /// All recognised loads and stores, in discovery order.
    #[must_use]
    pub fn operations(&self) -> &[MemoryOp] {
        &self.operations
    }

    /// Every location touched by at least one load or store.
    #[must_use]
    pub fn locations(&self) -> &HashSet<MemoryLocation> {
        &self.locations
    }

    /// Version of `location` reaching the first instruction of `block`.
    ///
    /// The lookup reflects the state after the block's memory phis have been applied.
    /// Returns `None` if the block was never renamed (for example, it is unreachable from
    /// the entry) or the location is not tracked.
    #[must_use]
    pub fn version_at_entry(&self, location: &MemoryLocation, block: usize) -> Option<u32> {
        self.entry_versions.get(&(location.clone(), block)).copied()
    }

    /// Version of `location` current at the end of `block`, after all of its stores.
    #[must_use]
    pub fn version_at_exit(&self, location: &MemoryLocation, block: usize) -> Option<u32> {
        self.exit_versions.get(&(location.clone(), block)).copied()
    }

    /// Where `version` was defined (entry, a store, or a phi), if it exists.
    #[must_use]
    pub fn definition(&self, version: &MemoryVersion) -> Option<MemoryDefSite> {
        self.definitions.get(version).copied()
    }

    /// Hands out the next unused version number for `location`, starting at 0.
    fn allocate_version(&mut self, location: &MemoryLocation) -> u32 {
        let version = self.next_version.entry(location.clone()).or_insert(0);
        let result = *version;
        *version += 1;
        result
    }

    /// Most recently allocated version of `location`.
    ///
    /// Returns 0 if no version has been allocated yet.
    fn current_version(&self, location: &MemoryLocation) -> u32 {
        self.next_version
            .get(location)
            .copied()
            .unwrap_or(0)
            .saturating_sub(1)
    }

    /// Records every load and store of `ssa`, and the set of locations they touch.
    fn identify_memory_operations(&mut self, ssa: &SsaFunction) {
        for (block_idx, instr_idx, instr) in ssa.iter_instructions() {
            if let Some(mem_op) = Self::classify_memory_operation(instr.op(), block_idx, instr_idx)
            {
                self.locations.insert(mem_op.location().clone());
                self.operations.push(mem_op);
            }
        }
    }

    /// Maps one SSA operation to the memory operation it performs, if any.
    ///
    /// Instance-field, static-field, array-element and indirect loads and stores are
    /// recognised. Every other operation yields `None`.
    /// NOTE(review): operations that touch memory implicitly (e.g. calls, if `SsaOp` has
    /// them) are not modelled here.
    fn classify_memory_operation(op: &SsaOp, block: usize, instr: usize) -> Option<MemoryOp> {
        match op {
            SsaOp::LoadField {
                dest,
                object,
                field,
            } => {
                let location = MemoryLocation::InstanceField(*object, *field);
                Some(MemoryOp::Load {
                    location,
                    dest: *dest,
                    block,
                    instr,
                })
            }
            SsaOp::StoreField {
                object,
                field,
                value,
            } => {
                let location = MemoryLocation::InstanceField(*object, *field);
                Some(MemoryOp::Store {
                    location,
                    value: *value,
                    block,
                    instr,
                })
            }
            SsaOp::LoadStaticField { dest, field } => {
                let location = MemoryLocation::StaticField(*field);
                Some(MemoryOp::Load {
                    location,
                    dest: *dest,
                    block,
                    instr,
                })
            }
            SsaOp::StoreStaticField { field, value } => {
                let location = MemoryLocation::StaticField(*field);
                Some(MemoryOp::Store {
                    location,
                    value: *value,
                    block,
                    instr,
                })
            }
            SsaOp::LoadElement {
                dest, array, index, ..
            } => {
                let idx = Self::resolve_array_index(*index);
                let location = MemoryLocation::ArrayElement(*array, idx);
                Some(MemoryOp::Load {
                    location,
                    dest: *dest,
                    block,
                    instr,
                })
            }
            SsaOp::StoreElement {
                array,
                index,
                value,
                ..
            } => {
                let idx = Self::resolve_array_index(*index);
                let location = MemoryLocation::ArrayElement(*array, idx);
                Some(MemoryOp::Store {
                    location,
                    value: *value,
                    block,
                    instr,
                })
            }
            SsaOp::LoadIndirect { dest, addr, .. } => {
                let location = MemoryLocation::Indirect(*addr);
                Some(MemoryOp::Load {
                    location,
                    dest: *dest,
                    block,
                    instr,
                })
            }
            SsaOp::StoreIndirect { addr, value, .. } => {
                let location = MemoryLocation::Indirect(*addr);
                Some(MemoryOp::Store {
                    location,
                    value: *value,
                    block,
                    instr,
                })
            }
            _ => None,
        }
    }

    /// Converts an index SSA variable into an [`ArrayIndex`].
    ///
    /// Constant indices are currently not resolved: the result is always
    /// [`ArrayIndex::Variable`].
    fn resolve_array_index(index_var: SsaVarId) -> ArrayIndex {
        ArrayIndex::Variable(index_var)
    }

    /// Places memory phis at the iterated dominance frontier of each location's store
    /// blocks. Each phi gets a fresh version and is recorded as a `Phi` definition.
    ///
    /// Locations are processed in the order of their first store, and seed blocks in
    /// ascending order. Without this, version numbering would depend on `HashMap`
    /// iteration order.
    fn place_memory_phis(&mut self, cfg: &SsaCfg<'_>) {
        if cfg.node_count() == 0 {
            return;
        }
        let dom_tree = compute_dominators(cfg, cfg.entry());
        let frontiers = compute_dominance_frontiers(cfg, &dom_tree);

        // Blocks containing a store, per location, plus a deterministic location order.
        let mut location_order: Vec<MemoryLocation> = Vec::new();
        let mut def_blocks: HashMap<MemoryLocation, HashSet<usize>> = HashMap::new();
        for op in self.operations.iter().filter(|op| op.is_store()) {
            def_blocks
                .entry(op.location().clone())
                .or_insert_with(|| {
                    location_order.push(op.location().clone());
                    HashSet::new()
                })
                .insert(op.block());
        }

        for location in location_order {
            let mut seeds: Vec<usize> = def_blocks
                .remove(&location)
                .map(|blocks| blocks.into_iter().collect())
                .unwrap_or_default();
            seeds.sort_unstable();

            let mut worklist: VecDeque<usize> = VecDeque::from(seeds);
            let mut phi_blocks: HashSet<usize> = HashSet::new();
            let mut processed: HashSet<usize> = HashSet::new();
            while let Some(block) = worklist.pop_front() {
                if !processed.insert(block) {
                    continue;
                }
                let frontier_idx = NodeId::new(block).index();
                if frontier_idx >= frontiers.len() {
                    continue;
                }
                for &frontier_node in &frontiers[frontier_idx] {
                    let frontier_block = frontier_node.index();
                    if !phi_blocks.insert(frontier_block) {
                        continue;
                    }
                    let version = self.allocate_version(&location);
                    self.memory_phis
                        .entry(frontier_block)
                        .or_default()
                        .push(MemoryPhi::new(location.clone(), version));
                    self.definitions.insert(
                        MemoryVersion::new(location.clone(), version),
                        MemoryDefSite::Phi {
                            block: frontier_block,
                        },
                    );
                    // A phi is itself a definition, so its frontier needs phis too.
                    worklist.push_back(frontier_block);
                }
            }
        }
    }

    /// Assigns versions to stores and fills in phi operands with a dominator-tree walk.
    ///
    /// Every tracked location first gets an `Entry` version. Each block then pushes the
    /// versions it defines onto per-location stacks. Those versions are popped again once
    /// the block's whole dominator subtree has been renamed, so sibling subtrees never see
    /// each other's definitions. The walk uses an explicit stack instead of recursion to
    /// stay safe on deep dominator trees.
    fn rename_memory_versions(&mut self, ssa: &SsaFunction, cfg: &SsaCfg<'_>) {
        // One step of the dominator-tree walk.
        enum Step {
            // Rename this block, then schedule its dominator-tree children.
            Enter(usize),
            // The block's subtree is done: pop the versions it pushed.
            Leave(Vec<MemoryLocation>),
        }

        let block_count = cfg.node_count();
        if block_count == 0 {
            return;
        }
        let dom_tree = compute_dominators(cfg, cfg.entry());

        let mut version_stacks: HashMap<MemoryLocation, Vec<u32>> = HashMap::new();
        let locations: Vec<MemoryLocation> = self.locations.iter().cloned().collect();
        for location in locations {
            let entry_version = self.allocate_version(&location);
            version_stacks
                .entry(location.clone())
                .or_default()
                .push(entry_version);
            self.definitions.insert(
                MemoryVersion::new(location, entry_version),
                MemoryDefSite::Entry,
            );
        }

        let mut visited = vec![false; block_count];
        let mut work = vec![Step::Enter(cfg.entry().index())];
        while let Some(step) = work.pop() {
            match step {
                Step::Enter(block_idx) => {
                    if visited.get(block_idx).copied().unwrap_or(true) {
                        continue;
                    }
                    visited[block_idx] = true;
                    let pushed = self.rename_block(block_idx, ssa, cfg, &mut version_stacks);
                    // Scheduled below the children, so it runs only after their subtrees.
                    work.push(Step::Leave(pushed));
                    for child in dom_tree.children(NodeId::new(block_idx)) {
                        let child_idx = child.index();
                        if !visited.get(child_idx).copied().unwrap_or(true) {
                            work.push(Step::Enter(child_idx));
                        }
                    }
                }
                Step::Leave(pushed) => {
                    for location in pushed {
                        if let Some(stack) = version_stacks.get_mut(&location) {
                            stack.pop();
                        }
                    }
                }
            }
        }
    }

    /// Renames a single block.
    ///
    /// The steps are:
    /// 1. Apply the block's phis.
    /// 2. Record entry versions.
    /// 3. Version each store.
    /// 4. Record exit versions.
    /// 5. Add this block's exit versions as operands to successor phis.
    ///
    /// A block with no SSA body is treated as empty: versions pass through it unchanged.
    /// Returns one entry per version pushed onto `version_stacks`, so the caller can pop
    /// them once the block's dominator subtree is finished.
    fn rename_block(
        &mut self,
        block_idx: usize,
        ssa: &SsaFunction,
        cfg: &SsaCfg<'_>,
        version_stacks: &mut HashMap<MemoryLocation, Vec<u32>>,
    ) -> Vec<MemoryLocation> {
        let mut pushed: Vec<MemoryLocation> = Vec::new();

        // Phis execute at block entry, before any instruction of the block.
        if let Some(phis) = self.memory_phis.get(&block_idx) {
            for phi in phis {
                version_stacks
                    .entry(phi.location.clone())
                    .or_default()
                    .push(phi.result_version);
                pushed.push(phi.location.clone());
            }
        }

        for (location, stack) in version_stacks.iter() {
            if let Some(&version) = stack.last() {
                self.entry_versions
                    .insert((location.clone(), block_idx), version);
            }
        }

        if let Some(block) = ssa.block(block_idx) {
            for (instr_idx, instr) in block.instructions().iter().enumerate() {
                let Some(mem_op) =
                    Self::classify_memory_operation(instr.op(), block_idx, instr_idx)
                else {
                    continue;
                };
                if !mem_op.is_store() {
                    continue;
                }
                let location = mem_op.location().clone();
                let new_version = self.allocate_version(&location);
                version_stacks
                    .entry(location.clone())
                    .or_default()
                    .push(new_version);
                pushed.push(location.clone());
                self.definitions.insert(
                    MemoryVersion::new(location, new_version),
                    MemoryDefSite::Store {
                        block: block_idx,
                        instr: instr_idx,
                    },
                );
            }
        }

        for (location, stack) in version_stacks.iter() {
            if let Some(&version) = stack.last() {
                self.exit_versions
                    .insert((location.clone(), block_idx), version);
            }
        }

        for succ_id in cfg.successors(NodeId::new(block_idx)) {
            let Some(phis) = self.memory_phis.get_mut(&succ_id.index()) else {
                continue;
            };
            for phi in phis.iter_mut() {
                if let Some(&version) = version_stacks.get(&phi.location).and_then(|s| s.last())
                {
                    phi.add_operand(block_idx, version);
                }
            }
        }

        pushed
    }

    /// Summary counts: locations, phis, stores, loads and defined versions.
    #[must_use]
    pub fn stats(&self) -> MemorySsaStats {
        let total_phis = self.memory_phis.values().map(Vec::len).sum();
        let store_count = self.operations.iter().filter(|op| op.is_store()).count();
        let load_count = self.operations.iter().filter(|op| op.is_load()).count();
        MemorySsaStats {
            location_count: self.locations.len(),
            memory_phi_count: total_phis,
            store_count,
            load_count,
            version_count: self.definitions.len(),
        }
    }
}
impl Default for MemorySsa {
fn default() -> Self {
Self::new()
}
}
/// Summary counts reported by [`MemorySsa::stats`].
#[derive(Debug, Clone, Copy)]
pub struct MemorySsaStats {
    /// Number of distinct locations touched by loads or stores.
    pub location_count: usize,
    /// Total number of memory phis across all blocks.
    pub memory_phi_count: usize,
    /// Number of store operations.
    pub store_count: usize,
    /// Number of load operations.
    pub load_count: usize,
    /// Number of versions with a recorded definition site (entry, store or phi).
    pub version_count: usize,
}
/// Maps memory locations to the value (and memory version) last recorded for them.
#[derive(Debug, Clone)]
pub struct MemoryState {
    /// Known contents per location: `(memory version, stored SSA value)`.
    values: HashMap<MemoryLocation, (u32, SsaVarId)>,
    /// Optional shared memory-SSA handle.
    /// NOTE(review): stored by `with_mem_ssa` but not read by any method in this file.
    mem_ssa: Option<std::sync::Arc<MemorySsa>>,
}
impl MemoryState {
    /// Creates an empty state with no attached memory SSA.
    #[must_use]
    pub fn new() -> Self {
        Self {
            values: HashMap::new(),
            mem_ssa: None,
        }
    }

    /// Creates an empty state that keeps a shared handle to `mem_ssa`.
    #[must_use]
    pub fn with_mem_ssa(mem_ssa: std::sync::Arc<MemorySsa>) -> Self {
        Self {
            values: HashMap::new(),
            mem_ssa: Some(mem_ssa),
        }
    }

    /// Records that `location` now holds `value`, tagged with memory version `version`.
    ///
    /// Every tracked entry whose location *may* alias `location` is discarded first.
    /// The store could have overwritten that memory, so the remembered value is no longer
    /// reliable. A store to [`MemoryLocation::Unknown`] therefore forgets everything.
    pub fn store(&mut self, location: MemoryLocation, value: SsaVarId, version: u32) {
        self.values.retain(|tracked, _| !tracked.may_alias(&location));
        self.values.insert(location, (version, value));
    }

    /// Returns the value known to be stored at `location`, if any.
    ///
    /// A value is returned only for a tracked location that must-alias `location`.
    /// Locations that do not denote one definite cell never yield a value, even if an
    /// entry with a structurally identical key was stored. Examples are
    /// [`MemoryLocation::Unknown`] and array elements with an [`ArrayIndex::Unknown`]
    /// index.
    #[must_use]
    pub fn load(&self, location: &MemoryLocation) -> Option<SsaVarId> {
        // Fast path: an exact key, accepted only if the location denotes a definite cell.
        if let Some(&(_, value)) = self.values.get(location) {
            if location.must_alias(location) {
                return Some(value);
            }
        }
        // Slow path: any other tracked location that definitely denotes the same cell.
        self.values
            .iter()
            .find(|(tracked, _)| location.must_alias(tracked))
            .map(|(_, &(_, value))| value)
    }

    /// Memory version recorded for exactly `location`, if tracked.
    #[must_use]
    pub fn version(&self, location: &MemoryLocation) -> Option<u32> {
        self.values.get(location).map(|(v, _)| *v)
    }

    /// Returns `true` if any tracked location may alias `location`.
    #[must_use]
    pub fn has_may_alias(&self, location: &MemoryLocation) -> bool {
        self.values.keys().any(|loc| loc.may_alias(location))
    }

    /// Forgets every tracked value.
    pub fn clear(&mut self) {
        self.values.clear();
    }

    /// Number of tracked locations.
    #[must_use]
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Returns `true` if no location is tracked.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }
}
impl Default for MemoryState {
fn default() -> Self {
Self::new()
}
}
/// Three-valued answer to an alias query, as returned by [`analyze_alias`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AliasResult {
    /// `may_alias` is false: the locations never refer to the same memory.
    NoAlias,
    /// The locations might refer to the same memory, but this is not certain.
    MayAlias,
    /// `must_alias` is true: the locations always refer to the same memory.
    MustAlias,
}
/// Classifies how `loc1` and `loc2` relate.
///
/// A must-alias answer takes priority over may-alias. If neither holds, the locations do
/// not alias.
#[must_use]
pub fn analyze_alias(loc1: &MemoryLocation, loc2: &MemoryLocation) -> AliasResult {
    match (loc1.must_alias(loc2), loc1.may_alias(loc2)) {
        (true, _) => AliasResult::MustAlias,
        (false, true) => AliasResult::MayAlias,
        (false, false) => AliasResult::NoAlias,
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for alias queries, the memory-state tracker and the phi/op helper types.
    use super::*;

    // The same static field must-alias; different static fields never alias.
    #[test]
    fn test_memory_location_static_field_alias() {
        let field1 = FieldRef::new(crate::metadata::token::Token::new(0x04000001));
        let field2 = FieldRef::new(crate::metadata::token::Token::new(0x04000002));
        let loc1 = MemoryLocation::StaticField(field1);
        let loc2 = MemoryLocation::StaticField(field1);
        let loc3 = MemoryLocation::StaticField(field2);
        assert!(loc1.must_alias(&loc2));
        assert!(loc1.may_alias(&loc2));
        assert!(!loc1.may_alias(&loc3));
    }

    // The same (object, field) pair must-alias; a different base SSA value is treated
    // as a different object.
    #[test]
    fn test_memory_location_instance_field_alias() {
        let field = FieldRef::new(crate::metadata::token::Token::new(0x04000001));
        let obj1 = SsaVarId::new();
        let obj2 = SsaVarId::new();
        let loc1 = MemoryLocation::InstanceField(obj1, field);
        let loc2 = MemoryLocation::InstanceField(obj1, field);
        let loc3 = MemoryLocation::InstanceField(obj2, field);
        assert!(loc1.must_alias(&loc2));
        assert!(loc1.may_alias(&loc2));
        assert!(!loc1.may_alias(&loc3));
    }

    // Equal constants overlap and are equal; distinct constants are disjoint; an unknown
    // index overlaps everything.
    #[test]
    fn test_array_index_overlap() {
        let idx1 = ArrayIndex::Constant(5);
        let idx2 = ArrayIndex::Constant(5);
        let idx3 = ArrayIndex::Constant(10);
        let idx4 = ArrayIndex::Unknown;
        assert!(idx1.may_overlap(&idx2));
        assert!(idx1.must_equal(&idx2));
        assert!(!idx1.may_overlap(&idx3));
        assert!(idx1.may_overlap(&idx4));
    }

    // Same array: equal constant indices must-alias; different constants do not alias.
    #[test]
    fn test_memory_location_array_element_alias() {
        let arr = SsaVarId::new();
        let idx1 = ArrayIndex::Constant(5);
        let idx2 = ArrayIndex::Constant(5);
        let idx3 = ArrayIndex::Constant(10);
        let loc1 = MemoryLocation::ArrayElement(arr, idx1);
        let loc2 = MemoryLocation::ArrayElement(arr, idx2);
        let loc3 = MemoryLocation::ArrayElement(arr, idx3);
        assert!(loc1.must_alias(&loc2));
        assert!(!loc1.may_alias(&loc3));
    }

    // Unknown may alias anything but never must-aliases.
    #[test]
    fn test_memory_location_unknown_alias() {
        let field = FieldRef::new(crate::metadata::token::Token::new(0x04000001));
        let loc1 = MemoryLocation::Unknown;
        let loc2 = MemoryLocation::StaticField(field);
        assert!(loc1.may_alias(&loc2));
        assert!(!loc1.must_alias(&loc2));
    }

    // analyze_alias: identical static fields are MustAlias; elements of different arrays
    // are NoAlias.
    #[test]
    fn test_alias_result() {
        let field = FieldRef::new(crate::metadata::token::Token::new(0x04000001));
        let loc1 = MemoryLocation::StaticField(field);
        let loc2 = MemoryLocation::StaticField(field);
        assert_eq!(analyze_alias(&loc1, &loc2), AliasResult::MustAlias);
        let arr1 = SsaVarId::new();
        let arr2 = SsaVarId::new();
        let loc3 = MemoryLocation::ArrayElement(arr1, ArrayIndex::Constant(0));
        let loc4 = MemoryLocation::ArrayElement(arr2, ArrayIndex::Constant(0));
        assert_eq!(analyze_alias(&loc3, &loc4), AliasResult::NoAlias);
    }

    // Store/load/version round trip, then clear.
    #[test]
    fn test_memory_state() {
        let mut state = MemoryState::new();
        let field = FieldRef::new(crate::metadata::token::Token::new(0x04000001));
        let loc = MemoryLocation::StaticField(field);
        let value = SsaVarId::new();
        state.store(loc.clone(), value, 1);
        assert_eq!(state.load(&loc), Some(value));
        assert_eq!(state.version(&loc), Some(1));
        assert_eq!(state.len(), 1);
        state.clear();
        assert!(state.is_empty());
    }

    // Phi operands are looked up by predecessor block.
    #[test]
    fn test_memory_phi() {
        let field = FieldRef::new(crate::metadata::token::Token::new(0x04000001));
        let loc = MemoryLocation::StaticField(field);
        let mut phi = MemoryPhi::new(loc.clone(), 2);
        phi.add_operand(0, 0);
        phi.add_operand(1, 1);
        assert_eq!(phi.result_version, 2);
        assert_eq!(phi.operands.len(), 2);
        assert_eq!(phi.operand_from(0).unwrap().version, 0);
        assert_eq!(phi.operand_from(1).unwrap().version, 1);
        assert!(phi.operand_from(2).is_none());
    }

    // MemoryOp accessors and load/store predicates.
    #[test]
    fn test_memory_op() {
        let field = FieldRef::new(crate::metadata::token::Token::new(0x04000001));
        let loc = MemoryLocation::StaticField(field);
        let dest = SsaVarId::new();
        let value = SsaVarId::new();
        let load = MemoryOp::Load {
            location: loc.clone(),
            dest,
            block: 0,
            instr: 5,
        };
        assert!(load.is_load());
        assert!(!load.is_store());
        assert_eq!(load.block(), 0);
        assert_eq!(load.instr(), 5);
        let store = MemoryOp::Store {
            location: loc,
            value,
            block: 1,
            instr: 3,
        };
        assert!(!store.is_load());
        assert!(store.is_store());
    }
}