use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use crate::types::CfgInfo;
use crate::TldrResult;
use super::dominators::{build_dominator_tree, compute_dominance_frontier, DominanceFrontier};
use super::types::{SsaFunction, SsaInstructionKind};
/// Identifier of one version of the memory state.
///
/// The builder allocates versions sequentially starting at 1. Version 0
/// (`MemoryVersion::default()`) stands for the state before any memory definition.
/// It is also used as the fallback when no version is recorded. Displays as `mem_<n>`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MemoryVersion(pub u32);
impl std::fmt::Display for MemoryVersion {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "mem_{}", self.0)
}
}
/// Memory SSA form of a single function, as produced by `build_memory_ssa`.
///
/// Memory is modelled as one state. Stores, calls and allocations produce new versions.
/// Loads and calls record the version they read. Memory phis merge versions where
/// control flow joins.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemorySsa {
/// Name of the analyzed function (copied from the SSA function).
pub function: String,
/// Source file path, lossily converted to UTF-8; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub file: Option<String>,
/// Memory phis, one per block in the iterated dominance frontier of the defining blocks.
pub memory_phis: Vec<MemoryPhi>,
/// Memory definitions (stores, calls, allocations).
pub memory_defs: Vec<MemoryDef>,
/// Memory uses (loads and the read side of calls).
pub memory_uses: Vec<MemoryUse>,
/// For each version produced by a def or phi, the versions recorded as consuming it.
pub def_use: HashMap<MemoryVersion, Vec<MemoryVersion>>,
/// Summary counts.
pub stats: MemorySsaStats,
}
/// Summary counts for a [`MemorySsa`]; all zero when the function has no memory operations.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MemorySsaStats {
/// Number of memory definitions.
pub defs: usize,
/// Number of memory uses.
pub uses: usize,
/// Number of memory phis.
pub phis: usize,
/// Highest version number allocated (0 when nothing beyond the initial version exists).
pub max_version: u32,
}
/// A memory phi: a new version at the start of `block` that merges the versions
/// flowing in from its CFG predecessors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryPhi {
/// Version defined by this phi.
pub result: MemoryVersion,
/// Block the phi is placed in.
pub block: usize,
/// Incoming versions; as built by this module, one entry per CFG edge into `block`, in edge order.
pub sources: Vec<MemoryPhiSource>,
}
/// Incoming value of a [`MemoryPhi`] along one predecessor edge.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryPhiSource {
/// Predecessor block the value flows from.
pub block: usize,
/// Version current at the end of that predecessor (`mem_0` if it was never renamed).
pub version: MemoryVersion,
}
/// A memory definition: an operation that produces a new memory version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryDef {
/// Version produced by this definition.
pub version: MemoryVersion,
/// Version that was current immediately before this definition (the one it supersedes).
pub clobbers: MemoryVersion,
/// Block containing the operation.
pub block: usize,
/// Source line of the operation.
pub line: u32,
/// Textual description of the accessed location, e.g. `obj.field`, a callee name, or `new Foo`.
pub access: String,
/// Kind of operation; always `Some` for defs built by this module. Omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub kind: Option<MemoryDefKind>,
}
/// Kind of operation behind a [`MemoryDef`]; serialized in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MemoryDefKind {
/// Assignment to an attribute or subscript target (e.g. `obj.field = ...`).
Store,
/// The write side of a call instruction.
Call,
/// Heuristically detected object allocation.
Alloc,
}
/// A memory use: an operation that reads the current memory version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryUse {
/// Version current at the point of the read.
pub version: MemoryVersion,
/// Block containing the operation.
pub block: usize,
/// Source line of the operation.
pub line: u32,
/// Textual description of the accessed location, e.g. `obj.field` or a callee name.
pub access: String,
/// Kind of operation; always `Some` for uses built by this module. Omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub kind: Option<MemoryUseKind>,
}
/// Kind of operation behind a [`MemoryUse`]; serialized in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MemoryUseKind {
/// Assignment reading an attribute or subscript (e.g. `x = obj.field`).
Load,
/// The read side of a call instruction.
Call,
}
struct MemorySsaBuilder {
next_version: u32,
block_out_version: HashMap<usize, MemoryVersion>,
version_stack: Vec<MemoryVersion>,
memory_defs: Vec<MemoryDef>,
memory_uses: Vec<MemoryUse>,
memory_phis: Vec<MemoryPhi>,
def_blocks: HashSet<usize>,
}
impl MemorySsaBuilder {
fn new() -> Self {
MemorySsaBuilder {
next_version: 1, block_out_version: HashMap::new(),
version_stack: vec![MemoryVersion(0)], memory_defs: Vec::new(),
memory_uses: Vec::new(),
memory_phis: Vec::new(),
def_blocks: HashSet::new(),
}
}
fn new_version(&mut self) -> MemoryVersion {
let version = MemoryVersion(self.next_version);
self.next_version += 1;
version
}
fn current_version(&self) -> MemoryVersion {
*self.version_stack.last().unwrap_or(&MemoryVersion(0))
}
fn push_version(&mut self, version: MemoryVersion) {
self.version_stack.push(version);
}
fn pop_version(&mut self) {
if self.version_stack.len() > 1 {
self.version_stack.pop();
}
}
fn add_def(&mut self, block: usize, line: u32, access: String, kind: MemoryDefKind) {
let clobbers = self.current_version();
let version = self.new_version();
self.memory_defs.push(MemoryDef {
version,
clobbers,
block,
line,
access,
kind: Some(kind),
});
self.push_version(version);
self.def_blocks.insert(block);
}
fn add_use(&mut self, block: usize, line: u32, access: String, kind: MemoryUseKind) {
let version = self.current_version();
self.memory_uses.push(MemoryUse {
version,
block,
line,
access,
kind: Some(kind),
});
}
fn add_phi(&mut self, block: usize) -> MemoryVersion {
let version = self.new_version();
self.memory_phis.push(MemoryPhi {
result: version,
block,
sources: Vec::new(), });
version
}
}
/// Builds memory SSA for `ssa`, using `cfg` for dominance information.
///
/// Memory is modelled as a single state:
/// - every store, call and allocation produces a new version;
/// - loads and calls record the version they read;
/// - memory phis are placed on the iterated dominance frontier of the blocks that
///   contain definitions.
///
/// A function with no memory operations gets an empty result without building any
/// dominator information.
///
/// # Errors
///
/// Propagates failures from building the dominator tree or the dominance frontier.
pub fn build_memory_ssa(cfg: &CfgInfo, ssa: &SsaFunction) -> TldrResult<MemorySsa> {
    let file = Some(ssa.file.to_string_lossy().to_string());
    let memory_ops = extract_memory_operations(ssa);
    if memory_ops.is_empty() {
        return Ok(MemorySsa {
            function: ssa.function.clone(),
            file,
            memory_phis: Vec::new(),
            memory_defs: Vec::new(),
            memory_uses: Vec::new(),
            def_use: HashMap::new(),
            stats: MemorySsaStats::default(),
        });
    }

    let dom_tree = build_dominator_tree(cfg)?;
    let dom_frontier = compute_dominance_frontier(cfg, &dom_tree)?;

    let def_blocks: HashSet<usize> = memory_ops
        .iter()
        .filter(|op| op.is_def)
        .map(|op| op.block)
        .collect();

    // `place_memory_phis` returns a `HashSet`, whose iteration order is unspecified and
    // changes from run to run. Sort the blocks so phi versions, and every version
    // allocated after them, are deterministic for a given input.
    let mut phi_blocks: Vec<usize> = place_memory_phis(&def_blocks, &dom_frontier)
        .into_iter()
        .collect();
    phi_blocks.sort_unstable();

    let mut builder = MemorySsaBuilder::new();
    let mut phi_versions: HashMap<usize, MemoryVersion> =
        HashMap::with_capacity(phi_blocks.len());
    for block in phi_blocks {
        let version = builder.add_phi(block);
        phi_versions.insert(block, version);
    }

    rename_memory_versions(
        cfg.entry_block,
        cfg,
        &memory_ops,
        &phi_versions,
        &dom_tree,
        &mut builder,
    );
    fill_memory_phi_sources(cfg, &mut builder);
    let def_use = build_memory_def_use_chains(&builder);

    let stats = MemorySsaStats {
        defs: builder.memory_defs.len(),
        uses: builder.memory_uses.len(),
        phis: builder.memory_phis.len(),
        max_version: builder.next_version - 1,
    };
    Ok(MemorySsa {
        function: ssa.function.clone(),
        file,
        memory_phis: builder.memory_phis,
        memory_defs: builder.memory_defs,
        memory_uses: builder.memory_uses,
        def_use,
        stats,
    })
}
/// One memory-relevant instruction found by `extract_memory_operations`, before versioning.
struct MemoryOp {
/// Id of the SSA block containing the instruction.
block: usize,
/// Source line of the instruction.
line: u32,
/// Textual description of the accessed location.
access: String,
/// Whether the operation produces a new memory version (stores, calls, allocations).
is_def: bool,
/// Which kind of memory operation this is.
kind: MemoryOpKind,
}
/// Classification of a [`MemoryOp`].
enum MemoryOpKind {
/// Assignment to an attribute or subscript target.
Store,
/// Assignment reading an attribute or subscript.
Load,
/// Call instruction; renamed as a use followed by a def.
Call,
/// Heuristically detected allocation.
Alloc,
}
/// Scans every instruction of `ssa`, in block and instruction order, for memory operations.
///
/// - Call instructions become defs named after the callee (or `"call"` without source text).
/// - Assignments touching an attribute or subscript become stores (target side) or loads.
/// - Assignments that look like allocations become allocation defs.
/// - All other instructions are ignored.
fn extract_memory_operations(ssa: &SsaFunction) -> Vec<MemoryOp> {
    let mut ops = Vec::new();
    for block in &ssa.blocks {
        for instr in &block.instructions {
            let classified = match instr.kind {
                SsaInstructionKind::Call => {
                    let access = match instr.source_text.as_ref() {
                        Some(text) => extract_call_name(text),
                        None => "call".to_string(),
                    };
                    Some((access, true, MemoryOpKind::Call))
                }
                SsaInstructionKind::Assign => instr.source_text.as_ref().and_then(|text| {
                    if is_attribute_access(text) {
                        let (access, is_store) = parse_attribute_assignment(text);
                        let kind = if is_store {
                            MemoryOpKind::Store
                        } else {
                            MemoryOpKind::Load
                        };
                        Some((access, is_store, kind))
                    } else if is_allocation(text) {
                        Some((extract_allocation(text), true, MemoryOpKind::Alloc))
                    } else {
                        None
                    }
                }),
                _ => None,
            };
            if let Some((access, is_def, kind)) = classified {
                ops.push(MemoryOp {
                    block: block.id,
                    line: instr.line,
                    access,
                    is_def,
                    kind,
                });
            }
        }
    }
    ops
}
/// Returns `true` when `source` contains a field (`.`) or subscript (`[`) access anywhere.
fn is_attribute_access(source: &str) -> bool {
    source.chars().any(|c| matches!(c, '.' | '['))
}
/// Heuristically decides whether `source` allocates a new object.
///
/// Returns `true` in two cases:
/// - the text contains `new ` (e.g. `new Foo(...)`);
/// - it contains a call whose callee starts with an uppercase letter, as in
///   `x = Foo(...)` or `Foo(...)`. The callee is the text before the first `(`,
///   after any `x =` assignment target.
///
/// Lines starting with `def ` or `fn ` are never treated as constructor calls.
fn is_allocation(source: &str) -> bool {
    if source.contains("new ") {
        return true;
    }
    if !source.contains(')') || source.starts_with("def ") || source.starts_with("fn ") {
        return false;
    }
    let paren_pos = match source.find('(') {
        Some(pos) => pos,
        None => return false,
    };
    let head = &source[..paren_pos];
    // Inspect the callee, not the first character of the line: in `x = Foo()` that
    // character belongs to the assignment target, not the constructor.
    let callee = match head.rfind('=') {
        Some(eq_pos) => &head[eq_pos + 1..],
        None => head,
    };
    callee
        .trim_start()
        .chars()
        .next()
        .map_or(false, char::is_uppercase)
}
/// Splits an assignment and reports the memory location it touches as `(access, is_store)`.
///
/// - If the target left of the first `=` contains `.` or `[`, it is a store to that
///   target. A trailing compound-assignment operator such as `+` in `+=` or `//` in
///   `//=` is removed from the reported target.
/// - Otherwise, if the value right of the first `=` contains `.` or `[`, it is a load
///   of the location returned by `extract_access`.
/// - In every other case (including sources without `=`), the whole source is
///   returned as a load.
fn parse_attribute_assignment(source: &str) -> (String, bool) {
    // Operator prefixes of compound assignments. Two-character forms come first so
    // `**=` and `//=` are stripped completely.
    const COMPOUND_OPS: [&str; 12] = [
        "**", "//", "+", "-", "*", "/", "%", "&", "|", "^", "@", ":",
    ];

    if let Some(eq_pos) = source.find('=') {
        let raw_lhs = source[..eq_pos].trim();
        let lhs = COMPOUND_OPS
            .iter()
            .find_map(|op| raw_lhs.strip_suffix(*op))
            .unwrap_or(raw_lhs)
            .trim();
        let rhs = source[eq_pos + 1..].trim();
        if lhs.contains('.') || lhs.contains('[') {
            return (lhs.to_string(), true);
        }
        if rhs.contains('.') || rhs.contains('[') {
            return (extract_access(rhs), false);
        }
    }
    (source.to_string(), false)
}

/// Extracts the accessed location from an expression such as `obj.field` or `arr[i]`.
///
/// - With a `.`: returns the identifier run just before the first `.` through the
///   identifier run just after it (`f(a.b.c)` gives `a.b`).
/// - Otherwise, with a `[`: returns the identifier run before the first `[` through the
///   first `]` that follows it, or to the end when no `]` follows.
/// - Otherwise: returns the trimmed expression.
///
/// Identifier characters are Unicode alphanumerics and `_`. All offsets fall on
/// character boundaries, so non-ASCII punctuation next to an access cannot make the
/// slicing panic.
fn extract_access(expr: &str) -> String {
    // `true` for characters that can appear inside an identifier.
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }
    // Byte offset at which the trailing identifier run of `prefix` begins.
    fn ident_run_start(prefix: &str) -> usize {
        prefix
            .char_indices()
            .rev()
            .find(|&(_, c)| !is_ident_char(c))
            .map_or(0, |(i, c)| i + c.len_utf8())
    }

    let trimmed = expr.trim();
    if let Some(dot_pos) = trimmed.find('.') {
        let start = ident_run_start(&trimmed[..dot_pos]);
        let after_dot = dot_pos + 1;
        let end = trimmed[after_dot..]
            .find(|c: char| !is_ident_char(c))
            .map_or(trimmed.len(), |i| after_dot + i);
        return trimmed[start..end].to_string();
    }
    if let Some(bracket_pos) = trimmed.find('[') {
        let start = ident_run_start(&trimmed[..bracket_pos]);
        // Only a `]` after the opening bracket closes it. An earlier stray `]` (e.g.
        // inside a string literal) would otherwise give `end < start` and panic.
        let end = trimmed[bracket_pos..]
            .find(']')
            .map_or(trimmed.len(), |i| bracket_pos + i + 1);
        return trimmed[start..end].to_string();
    }
    trimmed.to_string()
}
/// Returns the callee name of the first call in `source`.
///
/// The text before the first `(` is reduced in two steps:
/// 1. keep the part after the last `=`, dropping any assignment target;
/// 2. keep the part after the last `.`, dropping any receiver.
///
/// So `x = obj.method(1)` gives `method` and `self.total = compute(x)` gives `compute`.
/// Returns `"call"` when there is no `(` or the callee would be empty.
fn extract_call_name(source: &str) -> String {
    let paren_pos = match source.find('(') {
        Some(pos) => pos,
        None => return "call".to_string(),
    };
    let mut callee = source[..paren_pos].trim();
    // Strip the assignment target before looking for a receiver; otherwise a dotted
    // target such as `self.total = compute(x)` would produce `total = compute`.
    if let Some(eq_pos) = callee.rfind('=') {
        callee = callee[eq_pos + 1..].trim();
    }
    if let Some(dot_pos) = callee.rfind('.') {
        callee = callee[dot_pos + 1..].trim();
    }
    if callee.is_empty() {
        "call".to_string()
    } else {
        callee.to_string()
    }
}
/// Describes an allocation as `new <Type>`, falling back to `"alloc"`.
///
/// The rules are tried in order:
/// 1. `new Type(...)` anywhere in the text gives `new Type`.
/// 2. Otherwise, `x = Type(...)` gives `new Type`, using the text between the last `=`
///    and the first `(`.
/// 3. Otherwise the result is `"alloc"`.
fn extract_allocation(source: &str) -> String {
    let explicit_new = source.find("new ").and_then(|idx| {
        let rest = &source[idx + 4..];
        rest.find('(')
            .map(|paren| format!("new {}", rest[..paren].trim()))
    });
    if let Some(described) = explicit_new {
        return described;
    }
    let assigned_class = source.find('(').and_then(|paren| {
        let head = source[..paren].trim();
        head.rfind('=')
            .map(|eq| format!("new {}", head[eq + 1..].trim()))
    });
    assigned_class.unwrap_or_else(|| "alloc".to_string())
}
fn place_memory_phis(
def_blocks: &HashSet<usize>,
dom_frontier: &DominanceFrontier,
) -> HashSet<usize> {
dom_frontier.iterated(def_blocks)
}
/// Renames memory operations in `block_id` and, recursively, in the blocks it dominates.
///
/// Steps, in order:
/// 1. If the block has a memory phi, the phi's version becomes current.
/// 2. The block's operations are recorded in order: loads as uses, stores and
///    allocations as defs, and calls as a use followed by a def on the same line.
/// 3. The version current at the end of the block is saved in
///    `builder.block_out_version`.
/// 4. The dominator-tree children are visited.
///
/// Every version pushed here is popped again before returning.
// `cfg` is not read here except to forward it to the recursive calls; the lint is
// silenced rather than dropping the parameter.
#[allow(clippy::only_used_in_recursion)]
fn rename_memory_versions(
    block_id: usize,
    cfg: &CfgInfo,
    memory_ops: &[MemoryOp],
    phi_versions: &HashMap<usize, MemoryVersion>,
    dom_tree: &super::dominators::DominatorTree,
    builder: &mut MemorySsaBuilder,
) {
    let saved_depth = builder.version_stack.len();

    if let Some(&phi) = phi_versions.get(&block_id) {
        builder.push_version(phi);
    }

    for op in memory_ops.iter().filter(|op| op.block == block_id) {
        let line = op.line;
        match op.kind {
            MemoryOpKind::Load => {
                builder.add_use(block_id, line, op.access.clone(), MemoryUseKind::Load)
            }
            MemoryOpKind::Store => {
                builder.add_def(block_id, line, op.access.clone(), MemoryDefKind::Store)
            }
            MemoryOpKind::Alloc => {
                builder.add_def(block_id, line, op.access.clone(), MemoryDefKind::Alloc)
            }
            MemoryOpKind::Call => {
                // A call reads the incoming state, then produces a new one.
                builder.add_use(block_id, line, op.access.clone(), MemoryUseKind::Call);
                builder.add_def(block_id, line, op.access.clone(), MemoryDefKind::Call);
            }
        }
    }

    let out_version = builder.current_version();
    builder.block_out_version.insert(block_id, out_version);

    if let Some(node) = dom_tree.nodes.get(&block_id) {
        for &child in &node.children {
            rename_memory_versions(child, cfg, memory_ops, phi_versions, dom_tree, builder);
        }
    }

    // Unwind everything this block pushed so siblings see the dominator's version.
    while builder.version_stack.len() > saved_depth {
        builder.pop_version();
    }
}
/// Appends one source per incoming CFG edge to every memory phi.
///
/// Each source carries the version current at the end of the predecessor block, or
/// `mem_0` when that block was never renamed. Sources follow the order of `cfg.edges`,
/// so duplicate edges yield duplicate sources.
fn fill_memory_phi_sources(cfg: &CfgInfo, builder: &mut MemorySsaBuilder) {
    let mut preds_of: HashMap<usize, Vec<usize>> = cfg
        .blocks
        .iter()
        .map(|block| (block.id, Vec::new()))
        .collect();
    for edge in &cfg.edges {
        preds_of.entry(edge.to).or_default().push(edge.from);
    }

    let out_versions = &builder.block_out_version;
    for phi in builder.memory_phis.iter_mut() {
        if let Some(preds) = preds_of.get(&phi.block) {
            phi.sources.extend(preds.iter().map(|&pred| MemoryPhiSource {
                block: pred,
                version: out_versions.get(&pred).copied().unwrap_or_default(),
            }));
        }
    }
}
/// Builds the version-level def-use map stored in `MemorySsa::def_use`.
///
/// Every version produced by a memory def or phi gets an entry. Entries are filled in
/// two passes:
/// 1. each memory use that reads such a version appends that same version to its entry;
/// 2. each phi source appends the phi's result version to the entry of the incoming
///    version.
// NOTE(review): a def's link to the version it clobbers is not recorded here, and the
// initial version `mem_0` has no entry, so uses or phi sources reading it are dropped.
// Confirm whether either is intended.
fn build_memory_def_use_chains(
    builder: &MemorySsaBuilder,
) -> HashMap<MemoryVersion, Vec<MemoryVersion>> {
    let produced = builder
        .memory_defs
        .iter()
        .map(|def| def.version)
        .chain(builder.memory_phis.iter().map(|phi| phi.result));
    let mut chains: HashMap<MemoryVersion, Vec<MemoryVersion>> =
        produced.map(|version| (version, Vec::new())).collect();

    for read in builder.memory_uses.iter().map(|use_| use_.version) {
        if let Some(consumers) = chains.get_mut(&read) {
            consumers.push(read);
        }
    }

    for phi in &builder.memory_phis {
        for incoming in phi.sources.iter().map(|source| source.version) {
            if let Some(consumers) = chains.get_mut(&incoming) {
                consumers.push(phi.result);
            }
        }
    }

    chains
}
/// Returns the memory version visible just before `line` inside `block`, if known
/// locally.
///
/// The latest def in `block` on an earlier line wins; among defs on the same line,
/// the later one in `memory_defs` wins. With no such def, the first phi placed in
/// `block` is used.
// NOTE(review): returns `None` when the block has neither an earlier def nor a phi; the
// version inherited from a dominating block is not looked up.
pub fn get_reaching_memory_version(
    memory_ssa: &MemorySsa,
    block: usize,
    line: u32,
) -> Option<MemoryVersion> {
    let from_def = memory_ssa
        .memory_defs
        .iter()
        .filter(|def| def.block == block && def.line < line)
        .max_by_key(|def| def.line)
        .map(|def| def.version);

    from_def.or_else(|| {
        memory_ssa
            .memory_phis
            .iter()
            .find(|phi| phi.block == block)
            .map(|phi| phi.result)
    })
}
/// Conservative may-alias query between a memory def and a memory use.
///
/// No alias analysis is performed: this always returns `true`, i.e. every def is
/// assumed to possibly write whatever location any use reads.
pub fn may_alias(_store: &MemoryDef, _load: &MemoryUse) -> bool {
true
}
/// Returns the memory defs whose produced version is the one `use_` reads.
///
/// Only defs are searched. As built by this module, a use that reads a phi result or
/// the initial version `mem_0` therefore yields an empty vector.
pub fn get_reaching_defs_for_use<'a>(
    memory_ssa: &'a MemorySsa,
    use_: &MemoryUse,
) -> Vec<&'a MemoryDef> {
    let mut reaching = Vec::new();
    for def in &memory_ssa.memory_defs {
        if def.version == use_.version {
            reaching.push(def);
        }
    }
    reaching
}
/// Returns the memory uses that read exactly the version produced by `def`.
pub fn get_uses_for_def<'a>(memory_ssa: &'a MemorySsa, def: &MemoryDef) -> Vec<&'a MemoryUse> {
    let mut readers = Vec::new();
    for candidate in &memory_ssa.memory_uses {
        if candidate.version == def.version {
            readers.push(candidate);
        }
    }
    readers
}
/// A memory definition together with the locations of the uses that read its version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryDefUseChain {
/// Version produced by the definition.
pub def: MemoryVersion,
/// Source line of the definition.
pub def_line: u32,
/// Block containing the definition.
pub def_block: usize,
/// Locations of the uses whose version equals `def`.
pub uses: Vec<MemoryUseLocation>,
}
/// Position of a memory use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryUseLocation {
/// Source line of the use.
pub line: u32,
/// Block containing the use.
pub block: usize,
}
pub fn build_explicit_def_use_chains(memory_ssa: &MemorySsa) -> Vec<MemoryDefUseChain> {
let mut chains = Vec::new();
for def in &memory_ssa.memory_defs {
let uses: Vec<MemoryUseLocation> = memory_ssa
.memory_uses
.iter()
.filter(|u| u.version == def.version)
.map(|u| MemoryUseLocation {
line: u.line,
block: u.block,
})
.collect();
chains.push(MemoryDefUseChain {
def: def.version,
def_line: def.line,
def_block: def.block,
uses,
});
}
chains
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_version_display() {
        assert_eq!(MemoryVersion(42).to_string(), "mem_42");
    }

    #[test]
    fn test_memory_version_default() {
        assert_eq!(MemoryVersion::default(), MemoryVersion(0));
    }

    #[test]
    fn test_is_attribute_access() {
        let positives = ["obj.field = 1", "x = obj.field", "arr[0] = 1", "x = arr[i]"];
        for src in positives.iter() {
            assert!(
                is_attribute_access(src),
                "expected an attribute access in {:?}",
                src
            );
        }
        assert!(!is_attribute_access("x = 1"));
    }

    #[test]
    fn test_parse_attribute_assignment_store() {
        assert_eq!(
            parse_attribute_assignment("obj.field = 1"),
            ("obj.field".to_string(), true)
        );
    }

    #[test]
    fn test_parse_attribute_assignment_load() {
        let (access, is_store) = parse_attribute_assignment("x = obj.field");
        assert!(!is_store);
        assert!(["obj.field", "obj"]
            .iter()
            .any(|needle| access.contains(*needle)));
    }

    #[test]
    fn test_extract_call_name() {
        let cases = [
            ("x = foo()", "foo"),
            ("obj.method()", "method"),
            ("result = bar(1, 2)", "bar"),
        ];
        for (source, expected) in cases.iter() {
            assert_eq!(extract_call_name(source), *expected);
        }
    }
}