use crate::{
analysis::{
dataflow::{
framework::{DataFlowAnalysis, Direction},
lattice::MeetSemiLattice,
},
SsaBlock, SsaFunction, SsaVarId,
},
utils::BitSet,
};
/// Forward reaching-definitions analysis over an [`SsaFunction`].
///
/// Holds one precomputed GEN set per block. The transfer function only ever adds
/// definitions (there are no KILL sets), relying on each SSA variable having a
/// single definition site.
#[derive(Debug, Clone)]
pub struct ReachingDefinitions {
    /// Number of bits in every set: the function's variable count at construction time.
    num_vars: usize,
    /// GEN set per block, stored in the order produced by `SsaFunction::blocks`
    /// and looked up by the framework-supplied `block_id`.
    gen_sets: Vec<BitSet>,
}
impl ReachingDefinitions {
    /// Builds the analysis for `ssa`, precomputing one GEN set per block.
    ///
    /// A block's GEN set contains the index (as returned by `SsaFunction::var_index`)
    /// of every variable it defines, through a phi node or through an instruction's
    /// `def()`. Definitions for which `var_index` returns `None` are skipped.
    #[must_use]
    pub fn new(ssa: &SsaFunction) -> Self {
        let num_vars = ssa.variable_count();
        let mut gen_sets = Vec::with_capacity(ssa.block_count());
        for block in ssa.blocks() {
            // Not named `gen`: that identifier is a reserved keyword in the Rust 2024 edition.
            let mut block_defs = BitSet::new(num_vars);
            for phi in block.phi_nodes() {
                if let Some(idx) = ssa.var_index(phi.result()) {
                    block_defs.insert(idx);
                }
            }
            for instr in block.instructions() {
                if let Some(idx) = instr.def().and_then(|def| ssa.var_index(def)) {
                    block_defs.insert(idx);
                }
            }
            gen_sets.push(block_defs);
        }
        Self { num_vars, gen_sets }
    }

    /// Number of variables tracked, i.e. the size of every reaching-definitions set.
    #[must_use]
    pub const fn num_variables(&self) -> usize {
        self.num_vars
    }
}
impl DataFlowAnalysis for ReachingDefinitions {
type Lattice = ReachingDefsResult;
const DIRECTION: Direction = Direction::Forward;
fn boundary(&self, ssa: &SsaFunction) -> Self::Lattice {
let mut defs = BitSet::new(self.num_vars);
for (idx, var) in ssa.variables().iter().enumerate() {
if var.version() == 0 && (var.origin().is_argument() || var.origin().is_local()) {
defs.insert(idx);
}
}
ReachingDefsResult { defs }
}
fn initial(&self, _ssa: &SsaFunction) -> Self::Lattice {
ReachingDefsResult {
defs: BitSet::new(self.num_vars),
}
}
fn transfer(
&self,
block_id: usize,
_block: &SsaBlock,
input: &Self::Lattice,
_ssa: &SsaFunction,
) -> Self::Lattice {
let mut result = input.defs.clone();
result.union_with(&self.gen_sets[block_id]);
ReachingDefsResult { defs: result }
}
}
/// Set of SSA variables whose definitions reach a program point.
///
/// Bit `i` is set when the variable with index `i` is in the set.
#[derive(Clone, Debug, PartialEq)]
pub struct ReachingDefsResult {
    /// Membership bits, one per tracked variable.
    defs: BitSet,
}
impl ReachingDefsResult {
    /// Creates an empty set able to hold up to `num_vars` variables.
    #[must_use]
    pub fn new(num_vars: usize) -> Self {
        let defs = BitSet::new(num_vars);
        Self { defs }
    }

    /// Returns `true` if the definition of `var` is in the set.
    ///
    /// Ids whose index is outside the tracked range always report `false`.
    #[must_use]
    pub fn reaches(&self, var: SsaVarId) -> bool {
        let bit = var.index();
        if bit >= self.defs.len() {
            return false;
        }
        self.defs.contains(bit)
    }

    /// Iterates over the ids of all definitions in the set.
    pub fn definitions(&self) -> impl Iterator<Item = SsaVarId> + '_ {
        let bits = self.defs.iter();
        bits.map(SsaVarId::from_index)
    }

    /// Number of definitions in the set.
    #[must_use]
    pub fn count(&self) -> usize {
        let defs = &self.defs;
        defs.count()
    }

    /// Returns `true` when no definition is in the set.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let defs = &self.defs;
        defs.is_empty()
    }

    /// Inserts `var`; an id outside the tracked range is silently ignored.
    pub fn add(&mut self, var: SsaVarId) {
        let bit = var.index();
        if bit >= self.defs.len() {
            return;
        }
        self.defs.insert(bit);
    }

    /// Removes `var`; an id outside the tracked range is silently ignored.
    pub fn remove(&mut self, var: SsaVarId) {
        let bit = var.index();
        if bit >= self.defs.len() {
            return;
        }
        self.defs.remove(bit);
    }
}
impl MeetSemiLattice for ReachingDefsResult {
    /// Union: a definition reaches a join point if it reaches along any incoming edge.
    fn meet(&self, other: &Self) -> Self {
        let mut merged = self.clone();
        merged.defs.union_with(&other.defs);
        merged
    }

    /// With union as the meet, the least element is the set holding every tracked variable.
    fn is_bottom(&self) -> bool {
        self.defs.len() == self.defs.count()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reaching_defs_result() {
        let mut result = ReachingDefsResult::new(10);
        assert!(result.is_empty());
        result.add(SsaVarId::from_index(0));
        result.add(SsaVarId::from_index(5));
        assert!(!result.is_empty());
        assert_eq!(result.count(), 2);
        assert!(result.reaches(SsaVarId::from_index(0)));
        assert!(result.reaches(SsaVarId::from_index(5)));
        assert!(!result.reaches(SsaVarId::from_index(1)));
        result.remove(SsaVarId::from_index(0));
        assert!(!result.reaches(SsaVarId::from_index(0)));
        assert_eq!(result.count(), 1);
    }

    #[test]
    fn test_reaching_defs_meet() {
        let mut a = ReachingDefsResult::new(10);
        let mut b = ReachingDefsResult::new(10);
        a.add(SsaVarId::from_index(0));
        a.add(SsaVarId::from_index(1));
        b.add(SsaVarId::from_index(1));
        b.add(SsaVarId::from_index(2));
        let result = a.meet(&b);
        assert!(result.reaches(SsaVarId::from_index(0)));
        assert!(result.reaches(SsaVarId::from_index(1)));
        assert!(result.reaches(SsaVarId::from_index(2)));
        assert_eq!(result.count(), 3);
    }

    #[test]
    fn test_reaching_defs_iterator() {
        let mut result = ReachingDefsResult::new(100);
        result.add(SsaVarId::from_index(5));
        result.add(SsaVarId::from_index(42));
        result.add(SsaVarId::from_index(99));
        let defs: Vec<_> = result.definitions().collect();
        assert_eq!(defs.len(), 3);
        assert!(defs.contains(&SsaVarId::from_index(5)));
        assert!(defs.contains(&SsaVarId::from_index(42)));
        assert!(defs.contains(&SsaVarId::from_index(99)));
    }

    /// `add`/`remove`/`reaches` guard against ids beyond the tracked range
    /// instead of touching the bit set; pin that behavior.
    #[test]
    fn test_reaching_defs_out_of_range_ids_are_ignored() {
        let mut result = ReachingDefsResult::new(4);
        result.add(SsaVarId::from_index(1_000));
        assert!(result.is_empty());
        assert!(!result.reaches(SsaVarId::from_index(1_000)));
        result.add(SsaVarId::from_index(3));
        result.remove(SsaVarId::from_index(1_000));
        assert_eq!(result.count(), 1);
        assert!(result.reaches(SsaVarId::from_index(3)));
    }

    /// Under union-as-meet the bottom element is the full set.
    #[test]
    fn test_reaching_defs_is_bottom() {
        let mut result = ReachingDefsResult::new(3);
        assert!(!result.is_bottom());
        result.add(SsaVarId::from_index(0));
        result.add(SsaVarId::from_index(1));
        assert!(!result.is_bottom());
        result.add(SsaVarId::from_index(2));
        assert!(result.is_bottom());
    }

    /// Meet must be commutative and idempotent for the fixpoint iteration to be well defined.
    #[test]
    fn test_reaching_defs_meet_is_commutative_and_idempotent() {
        let mut a = ReachingDefsResult::new(8);
        let mut b = ReachingDefsResult::new(8);
        a.add(SsaVarId::from_index(1));
        b.add(SsaVarId::from_index(6));
        assert_eq!(a.meet(&b), b.meet(&a));
        assert_eq!(a.meet(&a), a);
    }
}