use std::collections::VecDeque;
use crate::{
analysis::{
dataflow::{
framework::{AnalysisResults, DataFlowAnalysis, DataFlowCfg, Direction},
lattice::MeetSemiLattice,
},
SsaFunction,
},
utils::graph::NodeId,
};
/// Worklist-driven solver for a single [`DataFlowAnalysis`].
///
/// The solver owns the analysis together with all per-block state. Every
/// vector below is indexed by block index and is sized to the function's
/// block count when solving starts.
pub struct DataFlowSolver<A: DataFlowAnalysis> {
/// The analysis supplying initial/boundary values and the transfer function.
analysis: A,
/// Lattice value on the input side of each block: the transfer input for
/// forward analyses, the transfer result for backward analyses.
in_states: Vec<A::Lattice>,
/// Lattice value on the output side of each block: the transfer result for
/// forward analyses, the transfer input for backward analyses.
out_states: Vec<A::Lattice>,
/// FIFO queue of block indices that still have to be (re)processed.
worklist: VecDeque<usize>,
/// `in_worklist[i]` is `true` while block `i` is queued in `worklist`.
in_worklist: Vec<bool>,
/// Number of blocks popped from the worklist so far.
iterations: usize,
}
impl<A: DataFlowAnalysis> DataFlowSolver<A> {
/// Creates a solver that owns `analysis` and holds no per-block state yet.
///
/// The state vectors, the worklist and its membership flags all start out
/// empty, and the iteration counter starts at zero.
#[must_use]
pub fn new(analysis: A) -> Self {
    let in_states = Vec::new();
    let out_states = Vec::new();
    Self {
        iterations: 0,
        in_worklist: Vec::new(),
        worklist: VecDeque::new(),
        out_states,
        in_states,
        analysis,
    }
}
/// Runs the analysis over `ssa` using `cfg` and returns the per-block states.
///
/// The solver is consumed. For a function without blocks the result holds
/// two empty vectors and the analysis is never consulted. Otherwise the
/// states are initialised, the worklist is drained, the final states are
/// handed to the analysis' `finalize` hook, and the `in`/`out` vectors are
/// moved into the returned [`AnalysisResults`].
pub fn solve<C: DataFlowCfg>(
    mut self,
    ssa: &SsaFunction,
    cfg: &C,
) -> AnalysisResults<A::Lattice>
where
    A::Lattice: Clone,
{
    if ssa.block_count() == 0 {
        // Nothing to analyse: no initialisation, iteration or finalisation.
        AnalysisResults::new(Vec::new(), Vec::new())
    } else {
        self.initialize(ssa, cfg);
        self.iterate(ssa, cfg);
        self.analysis
            .finalize(&self.in_states, &self.out_states, ssa);
        AnalysisResults::new(self.in_states, self.out_states)
    }
}
/// Returns how many blocks have been popped from the worklist so far.
///
/// Note that `solve` takes the solver by value, so through this accessor the
/// counter can only be read before solving (where it is `0`).
#[must_use]
pub const fn iterations(&self) -> usize {
self.iterations
}
/// Sizes the per-block state and seeds the worklist.
///
/// Every block's `in` and `out` state starts as the analysis' `initial`
/// value. The `boundary` value is then written to the entry block's `in`
/// state (forward) or to every exit block's `out` state (backward). Finally
/// the blocks are queued in reverse postorder (forward) or postorder
/// (backward). CFG node indices that do not name an SSA block are ignored.
fn initialize<C: DataFlowCfg>(&mut self, ssa: &SsaFunction, cfg: &C)
where
    A::Lattice: Clone,
{
    let block_total = ssa.block_count();
    let initial = self.analysis.initial(ssa);
    let boundary = self.analysis.boundary(ssa);

    self.in_states = vec![initial.clone(); block_total];
    self.out_states = vec![initial; block_total];
    self.in_worklist = vec![false; block_total];

    // Seed the boundary value and pick the traversal order in one place,
    // since both depend only on the analysis direction.
    let order = match A::DIRECTION {
        Direction::Forward => {
            let entry = cfg.entry().index();
            if let Some(slot) = self.in_states.get_mut(entry) {
                *slot = boundary;
            }
            cfg.reverse_postorder()
        }
        Direction::Backward => {
            for exit in cfg.exits() {
                let exit_idx = exit.index();
                if let Some(slot) = self.out_states.get_mut(exit_idx) {
                    *slot = boundary.clone();
                }
            }
            cfg.postorder()
        }
    };

    for node in order {
        let idx = node.index();
        // `get_mut` doubles as the "is this a known block?" check.
        if let Some(queued) = self.in_worklist.get_mut(idx) {
            *queued = true;
            self.worklist.push_back(idx);
        }
    }
}
/// Drains the worklist until no block is left to process.
///
/// Each popped block is unmarked in `in_worklist`, counted in `iterations`
/// and processed in the analysis' direction. When processing reports a
/// change, the blocks depending on it are queued again.
fn iterate<C: DataFlowCfg>(&mut self, ssa: &SsaFunction, cfg: &C)
where
    A::Lattice: Clone,
{
    while let Some(current) = self.worklist.pop_front() {
        self.in_worklist[current] = false;
        self.iterations += 1;

        let state_changed = match A::DIRECTION {
            Direction::Forward => self.process_forward(current, ssa, cfg),
            Direction::Backward => self.process_backward(current, ssa, cfg),
        };
        if !state_changed {
            continue;
        }
        self.add_affected_to_worklist(current, cfg);
    }
}
/// Recomputes the states of block `block_idx` for a forward analysis.
///
/// The block's `in` state becomes the meet of all predecessors' `out`
/// states. The entry block and blocks without predecessors keep the `in`
/// state they already hold (for the entry that is the boundary value written
/// during initialisation). The transfer function is then applied and its
/// result stored as the block's `out` state.
///
/// Returns `true` when the `out` state differs from the previous one.
///
/// # Panics
///
/// Panics if `ssa` has no block at `block_idx`, or if a predecessor index is
/// outside the state vectors.
fn process_forward<C: DataFlowCfg>(
    &mut self,
    block_idx: usize,
    ssa: &SsaFunction,
    cfg: &C,
) -> bool
where
    A::Lattice: Clone,
{
    let node = NodeId::new(block_idx);

    // The entry check comes first so that no meet is computed just to be
    // thrown away.
    // NOTE(review): an entry block that has predecessors (e.g. a back edge to
    // the entry) never merges them - confirm the CFG guarantees it has none.
    if node != cfg.entry() {
        let mut preds = cfg.predecessors(node);
        if let Some(first) = preds.next() {
            // Seed with the first predecessor, then fold in the rest; the
            // predecessor list is walked exactly once.
            let seed = self.out_states[first.index()].clone();
            let merged = preds.fold(seed, |acc, pred| {
                acc.meet(&self.out_states[pred.index()])
            });
            self.in_states[block_idx] = merged;
        }
    }

    let block = ssa.block(block_idx).expect("block should exist");
    // Borrow the stored `in` state instead of cloning it for the transfer.
    let output = self
        .analysis
        .transfer(block_idx, block, &self.in_states[block_idx], ssa);
    let changed = output != self.out_states[block_idx];
    self.out_states[block_idx] = output;
    changed
}
/// Recomputes the states of block `block_idx` for a backward analysis.
///
/// The block's `out` state becomes the meet of all successors' `in` states.
/// Exit blocks and blocks without successors keep the `out` state they
/// already hold (for exits that is the boundary value written during
/// initialisation). The transfer function is then applied and its result
/// stored as the block's `in` state.
///
/// Returns `true` when the `in` state differs from the previous one.
///
/// # Panics
///
/// Panics if `ssa` has no block at `block_idx`, or if a successor index is
/// outside the state vectors.
fn process_backward<C: DataFlowCfg>(
    &mut self,
    block_idx: usize,
    ssa: &SsaFunction,
    cfg: &C,
) -> bool
where
    A::Lattice: Clone,
{
    let node = NodeId::new(block_idx);

    // The exit check comes first so that no meet is computed just to be
    // thrown away.
    // NOTE(review): an exit block that still has successors never merges
    // them - confirm this is the intended treatment of exits.
    if !cfg.exits().contains(&node) {
        let mut succs = cfg.successors(node);
        if let Some(first) = succs.next() {
            // Seed with the first successor, then fold in the rest; the
            // successor list is walked exactly once.
            let seed = self.in_states[first.index()].clone();
            let merged = succs.fold(seed, |acc, succ| {
                acc.meet(&self.in_states[succ.index()])
            });
            self.out_states[block_idx] = merged;
        }
    }

    let block = ssa.block(block_idx).expect("block should exist");
    // Borrow the stored `out` state instead of cloning it for the transfer.
    let input = self
        .analysis
        .transfer(block_idx, block, &self.out_states[block_idx], ssa);
    let changed = input != self.in_states[block_idx];
    self.in_states[block_idx] = input;
    changed
}
fn add_affected_to_worklist<C: DataFlowCfg>(&mut self, block_idx: usize, cfg: &C) {
let node = NodeId::new(block_idx);
match A::DIRECTION {
Direction::Forward => {
for succ in cfg.successors(node) {
let idx = succ.index();
if idx < self.in_worklist.len() && !self.in_worklist[idx] {
self.worklist.push_back(idx);
self.in_worklist[idx] = true;
}
}
}
Direction::Backward => {
for pred in cfg.predecessors(node) {
let idx = pred.index();
if idx < self.in_worklist.len() && !self.in_worklist[idx] {
self.worklist.push_back(idx);
self.in_worklist[idx] = true;
}
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::SsaBlock;

    /// Flat test lattice: `Top` above every `Value`, `Bottom` below them.
    #[derive(Debug, Clone, PartialEq)]
    enum TestLattice {
        Top,
        Value(i32),
        Bottom,
    }

    impl MeetSemiLattice for TestLattice {
        /// `Top` is the identity, equal values meet to themselves, and
        /// everything else collapses to `Bottom`.
        fn meet(&self, other: &Self) -> Self {
            match (self, other) {
                (Self::Top, rhs) => rhs.clone(),
                (lhs, Self::Top) => lhs.clone(),
                (Self::Value(a), Self::Value(b)) if a == b => Self::Value(*a),
                _ => Self::Bottom,
            }
        }

        fn is_bottom(&self) -> bool {
            *self == Self::Bottom
        }
    }

    /// Forward analysis whose transfer function is the identity.
    struct TrivialAnalysis;

    impl DataFlowAnalysis for TrivialAnalysis {
        type Lattice = TestLattice;
        const DIRECTION: Direction = Direction::Forward;

        fn boundary(&self, _ssa: &SsaFunction) -> Self::Lattice {
            TestLattice::Value(42)
        }

        fn initial(&self, _ssa: &SsaFunction) -> Self::Lattice {
            TestLattice::Top
        }

        fn transfer(
            &self,
            _block_id: usize,
            _block: &SsaBlock,
            input: &Self::Lattice,
            _ssa: &SsaFunction,
        ) -> Self::Lattice {
            input.clone()
        }
    }

    /// A freshly constructed solver has not processed any block yet.
    #[test]
    fn test_solver_iterations() {
        let fresh = DataFlowSolver::new(TrivialAnalysis);
        assert_eq!(fresh.iterations(), 0);
    }
}