use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use crate::types::{CfgInfo, Language};
use crate::TldrResult;
use super::types::{SsaFunction, SsaNameId};
/// Per-block live-variable sets for a single function, as produced by
/// `compute_live_variables`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LiveVariables {
    /// Function name, copied from the CFG the analysis ran on.
    pub function: String,
    /// Live-in / live-out sets keyed by CFG block id.
    pub blocks: HashMap<usize, LiveSets>,
}
/// The variable names live at the two boundaries of one basic block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LiveSets {
    /// Names live on entry to the block.
    pub live_in: HashSet<String>,
    /// Names live on exit from the block.
    pub live_out: HashSet<String>,
}
/// Runs backward live-variable analysis over `cfg`.
///
/// Every reference in `refs` is attributed to the block whose line range
/// contains its line; references on lines outside every block are ignored.
/// Per block, a name lands in the *use* set when it is read (`Use` or
/// `Update`) before any write in that block, and in the *def* set when it is
/// written (`Definition` or `Update`). The sets are then propagated with
///
/// ```text
/// live_out[B] = union of live_in[S] for every successor S of B
/// live_in[B]  = use[B] ∪ (live_out[B] − def[B])
/// ```
///
/// sweeping the blocks in reverse declaration order until nothing changes or
/// `2 * blocks + 10` sweeps have run.
///
/// This function itself never returns `Err`.
pub fn compute_live_variables(
    cfg: &CfgInfo,
    refs: &[crate::types::VarRef],
) -> TldrResult<LiveVariables> {
    use crate::types::RefType;

    // Line -> owning block. When ranges overlap, the block declared last wins.
    let mut line_to_block: HashMap<u32, usize> = HashMap::new();
    for block in &cfg.blocks {
        for line in block.lines.0..=block.lines.1 {
            line_to_block.insert(line, block.id);
        }
    }

    // Successor lists; every block gets an entry even without outgoing edges.
    let mut successors: HashMap<usize, Vec<usize>> =
        cfg.blocks.iter().map(|b| (b.id, Vec::new())).collect();
    for edge in &cfg.edges {
        successors.entry(edge.from).or_default().push(edge.to);
    }

    let mut use_sets: HashMap<usize, HashSet<String>> =
        cfg.blocks.iter().map(|b| (b.id, HashSet::new())).collect();
    let mut def_sets: HashMap<usize, HashSet<String>> =
        cfg.blocks.iter().map(|b| (b.id, HashSet::new())).collect();

    // A stable sort by line keeps same-line references in input order, so
    // each block sees its references in the order it would after a per-block
    // stable sort.
    let mut ordered: Vec<&crate::types::VarRef> = refs
        .iter()
        .filter(|r| line_to_block.contains_key(&r.line))
        .collect();
    ordered.sort_by_key(|r| r.line);

    for var_ref in ordered {
        let block_id = line_to_block[&var_ref.line];
        let uses = use_sets.entry(block_id).or_default();
        let defs = def_sets.entry(block_id).or_default();

        let reads = matches!(var_ref.ref_type, RefType::Use | RefType::Update);
        let writes = matches!(var_ref.ref_type, RefType::Definition | RefType::Update);

        // Only reads that are not preceded by a write in this block are
        // upward-exposed.
        if reads && !defs.contains(&var_ref.name) {
            uses.insert(var_ref.name.clone());
        }
        if writes {
            defs.insert(var_ref.name.clone());
        }
    }

    let mut live_in: HashMap<usize, HashSet<String>> =
        cfg.blocks.iter().map(|b| (b.id, HashSet::new())).collect();
    let mut live_out: HashMap<usize, HashSet<String>> =
        cfg.blocks.iter().map(|b| (b.id, HashSet::new())).collect();

    let block_ids: Vec<usize> = cfg.blocks.iter().map(|b| b.id).collect();
    let sweep_limit = block_ids.len() * 2 + 10;
    let nothing: HashSet<String> = HashSet::new();

    for _ in 0..sweep_limit {
        let mut changed = false;

        for &block_id in block_ids.iter().rev() {
            let out_now: HashSet<String> = successors
                .get(&block_id)
                .into_iter()
                .flatten()
                .filter_map(|succ| live_in.get(succ))
                .flatten()
                .cloned()
                .collect();

            let upward_exposed = use_sets.get(&block_id).unwrap_or(&nothing);
            let written = def_sets.get(&block_id).unwrap_or(&nothing);
            let in_now: HashSet<String> = upward_exposed
                .iter()
                .chain(out_now.difference(written))
                .cloned()
                .collect();

            if live_in.get(&block_id) != Some(&in_now) {
                changed = true;
                live_in.insert(block_id, in_now);
            }
            if live_out.get(&block_id) != Some(&out_now) {
                changed = true;
                live_out.insert(block_id, out_now);
            }
        }

        if !changed {
            break;
        }
    }

    let blocks = cfg
        .blocks
        .iter()
        .map(|block| {
            let sets = LiveSets {
                live_in: live_in.get(&block.id).cloned().unwrap_or_default(),
                live_out: live_out.get(&block.id).cloned().unwrap_or_default(),
            };
            (block.id, sets)
        })
        .collect();

    Ok(LiveVariables {
        function: cfg.function.clone(),
        blocks,
    })
}
impl LiveVariables {
    /// Returns `true` when `var` is in the live-in set of `block`.
    ///
    /// Unknown block ids yield `false`.
    pub fn is_live_in(&self, block: usize, var: &str) -> bool {
        match self.blocks.get(&block) {
            Some(sets) => sets.live_in.contains(var),
            None => false,
        }
    }

    /// Returns `true` when `var` is in the live-out set of `block`.
    ///
    /// Unknown block ids yield `false`.
    pub fn is_live_out(&self, block: usize, var: &str) -> bool {
        match self.blocks.get(&block) {
            Some(sets) => sets.live_out.contains(var),
            None => false,
        }
    }
}
/// Result of value numbering over one SSA function, as produced by
/// `compute_value_numbers`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValueNumbering {
    /// Function name, copied from the SSA form.
    pub function: String,
    /// Value number assigned to each numbered SSA name.
    pub value_numbers: HashMap<SsaNameId, u32>,
    /// Inverse of `value_numbers`: every SSA name sharing a given value number.
    pub equivalences: HashMap<u32, Vec<SsaNameId>>,
}
/// Hashable, canonical form of an SSA right-hand side; two names receive the
/// same value number when their canonical expressions compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum CanonicalExpr {
    /// An opaque value identified only by a string key.
    Leaf(String),
    /// A binary operation over two already-numbered operands.
    BinaryOp {
        /// Operator spelling used as part of the key.
        op: String,
        /// Value numbers of the two operands.
        operands: (u32, u32),
    },
    /// A unary operation over one already-numbered operand.
    UnaryOp {
        /// Operator spelling used as part of the key.
        op: String,
        /// Value number of the operand.
        operand: u32,
    },
    /// A phi node, keyed by the value numbers of its sources.
    Phi(Vec<u32>),
    /// A call, keyed by a caller-supplied counter.
    Call(u32),
}
/// Assigns a value number to every SSA name defined in `ssa`.
///
/// Blocks, phis and instructions are visited in declaration order. Each
/// definition is reduced to a [`CanonicalExpr`]; structurally equal
/// expressions share a number, so names with the same number are treated as
/// holding the same value.
///
/// * Single-operand `Param`/`Assign` instructions are copies and inherit the
///   operand's number.
/// * `+` and `*` are treated as commutative (operands are ordered).
/// * Every call gets a fresh number.
/// * A phi whose sources are not all numbered yet (e.g. a loop back edge)
///   gets a fresh number instead of being keyed on the partial source list,
///   which could wrongly merge unrelated loop variables.
///
/// Operators are recovered from `source_text` by substring search.
/// Multi-character operators are checked before their single-character
/// prefixes, so `**` is not mistaken for the commutative `*`, `//` is not
/// mistaken for `/`, and comparison / bitwise operators no longer collapse
/// into one shared key. When no operator is recognised the generic
/// `"binop"` / `"unary"` key is still used.
///
/// This function itself never returns `Err`.
pub fn compute_value_numbers(ssa: &SsaFunction) -> TldrResult<ValueNumbering> {
    use super::types::SsaInstructionKind;

    /// Best-effort binary operator lookup; longest spellings first.
    fn binary_operator(text: &str) -> &'static str {
        const TOKENS: [&str; 20] = [
            "**", "//", "<<", ">>", "<=", ">=", "==", "!=", "&&", "||", "+", "*", "-", "/", "%",
            "<", ">", "&", "|", "^",
        ];
        TOKENS
            .iter()
            .copied()
            .find(|token| text.contains(*token))
            .unwrap_or("binop")
    }

    /// Best-effort unary operator lookup, using the same textual cues as the
    /// constant evaluator in this file.
    fn unary_operator(text: &str) -> &'static str {
        if text.contains("not") || text.contains('!') {
            "not"
        } else if text.contains('-') {
            "-"
        } else if text.contains('~') {
            "~"
        } else {
            "unary"
        }
    }

    /// Returns the number already assigned to `expr`, or allocates the next one.
    fn intern(
        expr: CanonicalExpr,
        table: &mut HashMap<CanonicalExpr, u32>,
        next: &mut u32,
    ) -> u32 {
        if let Some(&existing) = table.get(&expr) {
            return existing;
        }
        let fresh = *next;
        *next += 1;
        table.insert(expr, fresh);
        fresh
    }

    let mut value_numbers: HashMap<SsaNameId, u32> = HashMap::new();
    let mut expr_to_number: HashMap<CanonicalExpr, u32> = HashMap::new();
    let mut next_number: u32 = 0;
    let mut call_counter: u32 = 0;

    for block in &ssa.blocks {
        for phi in &block.phi_functions {
            let source_numbers: Vec<u32> = phi
                .sources
                .iter()
                .filter_map(|s| value_numbers.get(&s.name).copied())
                .collect();
            let expr = if source_numbers.len() == phi.sources.len() {
                CanonicalExpr::Phi(source_numbers)
            } else {
                // Some source has no number yet; keying on the partial list
                // would be unsound, so keep this phi distinct.
                CanonicalExpr::Leaf(format!("phi_{}", phi.target.0))
            };
            let vn = intern(expr, &mut expr_to_number, &mut next_number);
            value_numbers.insert(phi.target, vn);
        }

        for instr in &block.instructions {
            let Some(target) = instr.target else {
                continue;
            };
            let use_numbers: Vec<u32> = instr
                .uses
                .iter()
                .filter_map(|u| value_numbers.get(u).copied())
                .collect();

            let expr = match &instr.kind {
                SsaInstructionKind::Param | SsaInstructionKind::Assign => {
                    match use_numbers.as_slice() {
                        [] => {
                            let key = instr
                                .source_text
                                .clone()
                                .unwrap_or_else(|| format!("const_{}", target.0));
                            CanonicalExpr::Leaf(key)
                        }
                        [source_vn] => {
                            // Plain copy: share the operand's number.
                            value_numbers.insert(target, *source_vn);
                            continue;
                        }
                        _ => CanonicalExpr::Leaf(format!("assign_{}", target.0)),
                    }
                }
                SsaInstructionKind::BinaryOp => match use_numbers.as_slice() {
                    [left, right, ..] => {
                        let op_name = instr
                            .source_text
                            .as_deref()
                            .map_or("binop", binary_operator);
                        let is_commutative = op_name == "+" || op_name == "*";
                        let operands = if is_commutative && left > right {
                            (*right, *left)
                        } else {
                            (*left, *right)
                        };
                        CanonicalExpr::BinaryOp {
                            op: op_name.to_string(),
                            operands,
                        }
                    }
                    _ => CanonicalExpr::Leaf(format!("binop_{}", target.0)),
                },
                SsaInstructionKind::UnaryOp => match use_numbers.first() {
                    Some(&operand) => CanonicalExpr::UnaryOp {
                        op: instr
                            .source_text
                            .as_deref()
                            .map_or("unary", unary_operator)
                            .to_string(),
                        operand,
                    },
                    None => CanonicalExpr::Leaf(format!("unary_{}", target.0)),
                },
                SsaInstructionKind::Call => {
                    // Calls may have side effects, so no two calls are equal.
                    call_counter += 1;
                    CanonicalExpr::Call(call_counter)
                }
                SsaInstructionKind::Return | SsaInstructionKind::Branch => {
                    continue;
                }
            };

            let vn = intern(expr, &mut expr_to_number, &mut next_number);
            value_numbers.insert(target, vn);
        }
    }

    let mut equivalences: HashMap<u32, Vec<SsaNameId>> = HashMap::new();
    for (&name, &vn) in &value_numbers {
        equivalences.entry(vn).or_default().push(name);
    }

    Ok(ValueNumbering {
        function: ssa.function.clone(),
        value_numbers,
        equivalences,
    })
}
/// Three-level constant-propagation lattice element.
///
/// `lattice_meet` treats `Top` as the identity and `Bottom` as absorbing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum LatticeValue {
    /// No information yet; the initial value of every SSA name in `run_sccp`.
    Top,
    /// Known to hold exactly this constant.
    Constant(ConstantValue),
    /// Not a (known) constant.
    Bottom,
}
/// A literal value recognised by the constant propagator.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConstantValue {
    /// Integer literal.
    Int(i64),
    /// Floating-point literal, kept as its source spelling.
    Float(String),
    /// String literal contents without the surrounding quotes.
    String(String),
    /// Boolean literal.
    Bool(bool),
    /// The null-like literal (`None`, `null`, `nil`).
    None,
}
impl std::fmt::Display for ConstantValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConstantValue::Int(i) => write!(f, "{}", i),
ConstantValue::Float(fl) => write!(f, "{}", fl),
ConstantValue::String(s) => write!(f, "\"{}\"", s),
ConstantValue::Bool(b) => write!(f, "{}", b),
ConstantValue::None => write!(f, "None"),
}
}
}
/// Output of `run_sccp` for one function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SccpResult {
    /// Function name, copied from the SSA form.
    pub function: String,
    /// SSA names whose final lattice value is a constant.
    pub constants: HashMap<SsaNameId, ConstantValue>,
    /// Ids of blocks that were never marked executable.
    pub unreachable_blocks: HashSet<usize>,
    /// SSA names that are not read by any instruction or phi in an executable block.
    pub dead_names: HashSet<SsaNameId>,
}
/// Meet of two lattice values.
///
/// `Top` is the identity, two equal constants meet to that constant, and
/// every other combination falls to `Bottom`.
fn lattice_meet(a: &LatticeValue, b: &LatticeValue) -> LatticeValue {
    if *a == LatticeValue::Top {
        return b.clone();
    }
    if *b == LatticeValue::Top {
        return a.clone();
    }
    // Neither side is Top here: equal values (same constant, or both Bottom)
    // meet to themselves, anything else is a conflict.
    if a == b {
        a.clone()
    } else {
        LatticeValue::Bottom
    }
}
/// Tries to read a literal constant out of a piece of source text.
///
/// If the text contains `=`, only the part after the first `=` is examined
/// (so both `5` and `x = 5` work). Recognised forms, in order: booleans
/// (`True`/`true`/`False`/`false`), null-likes (`None`/`null`/`nil`), `i64`
/// integers, floats containing a `.`, and strings wrapped in matching single
/// or double quotes.
///
/// Returns `None` when the text is not one of those forms. A lone quote
/// character is not a string literal; it previously caused an out-of-range
/// slice panic and now yields `None`.
fn parse_constant(source_text: &str) -> Option<ConstantValue> {
    let trimmed = source_text.trim();
    let literal = match trimmed.find('=') {
        Some(idx) => trimmed[idx + 1..].trim(),
        None => trimmed,
    };

    match literal {
        "True" | "true" => return Some(ConstantValue::Bool(true)),
        "False" | "false" => return Some(ConstantValue::Bool(false)),
        "None" | "null" | "nil" => return Some(ConstantValue::None),
        _ => {}
    }

    if let Ok(int_value) = literal.parse::<i64>() {
        return Some(ConstantValue::Int(int_value));
    }
    // Floats keep their source spelling so the value stays hashable.
    if literal.contains('.') && literal.parse::<f64>().is_ok() {
        return Some(ConstantValue::Float(literal.to_string()));
    }

    // Stripping prefix and suffix separately requires two distinct quote
    // characters, so a single `"` or `'` can no longer match both ends.
    let unquote = |quote: char| {
        literal
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
    };
    unquote('"')
        .or_else(|| unquote('\''))
        .map(|inner| ConstantValue::String(inner.to_string()))
}
/// Folds a binary operation over two constants.
///
/// Supported combinations:
/// * `Int` with `Int`: `+ - * / %` using checked arithmetic (overflow and
///   division by zero give `None`), and the comparisons `< > <= >= == !=`.
/// * `Bool` with `Bool`: `and`/`&&`, `or`/`||`, `==`, `!=`.
///
/// Any other operand types or operator spelling returns `None`.
fn eval_binary_op(op: &str, left: &ConstantValue, right: &ConstantValue) -> Option<ConstantValue> {
    match (left, right) {
        (&ConstantValue::Int(l), &ConstantValue::Int(r)) => match op {
            "+" => l.checked_add(r).map(ConstantValue::Int),
            "-" => l.checked_sub(r).map(ConstantValue::Int),
            "*" => l.checked_mul(r).map(ConstantValue::Int),
            // checked_div / checked_rem already return None for a zero divisor.
            "/" => l.checked_div(r).map(ConstantValue::Int),
            "%" => l.checked_rem(r).map(ConstantValue::Int),
            "<" => Some(ConstantValue::Bool(l < r)),
            ">" => Some(ConstantValue::Bool(l > r)),
            "<=" => Some(ConstantValue::Bool(l <= r)),
            ">=" => Some(ConstantValue::Bool(l >= r)),
            "==" => Some(ConstantValue::Bool(l == r)),
            "!=" => Some(ConstantValue::Bool(l != r)),
            _ => None,
        },
        (&ConstantValue::Bool(l), &ConstantValue::Bool(r)) => match op {
            "and" | "&&" => Some(ConstantValue::Bool(l && r)),
            "or" | "||" => Some(ConstantValue::Bool(l || r)),
            "==" => Some(ConstantValue::Bool(l == r)),
            "!=" => Some(ConstantValue::Bool(l != r)),
            _ => None,
        },
        _ => None,
    }
}
pub fn run_sccp(ssa: &SsaFunction) -> TldrResult<SccpResult> {
use super::types::SsaInstructionKind;
use std::collections::VecDeque;
let mut lattice_values: HashMap<SsaNameId, LatticeValue> = HashMap::new();
for name in &ssa.ssa_names {
lattice_values.insert(name.id, LatticeValue::Top);
}
let mut executable_blocks: HashSet<usize> = HashSet::new();
let mut executable_edges: HashSet<(usize, usize)> = HashSet::new();
let mut cfg_worklist: VecDeque<usize> = VecDeque::new();
let mut ssa_worklist: VecDeque<SsaNameId> = VecDeque::new();
let entry_block = ssa.blocks.first().map(|b| b.id).unwrap_or(0);
cfg_worklist.push_back(entry_block);
let block_index: HashMap<usize, usize> = ssa
.blocks
.iter()
.enumerate()
.map(|(i, b)| (b.id, i))
.collect();
let mut use_map: HashMap<SsaNameId, Vec<(usize, usize)>> = HashMap::new(); for block in &ssa.blocks {
for (instr_idx, instr) in block.instructions.iter().enumerate() {
for &use_name in &instr.uses {
use_map
.entry(use_name)
.or_default()
.push((block.id, instr_idx));
}
}
}
let max_iterations = ssa.blocks.len() * ssa.ssa_names.len() + 100;
let mut iterations = 0;
while (!cfg_worklist.is_empty() || !ssa_worklist.is_empty()) && iterations < max_iterations {
iterations += 1;
while let Some(block_id) = cfg_worklist.pop_front() {
if !executable_blocks.insert(block_id) {
continue; }
let Some(&block_idx) = block_index.get(&block_id) else {
continue;
};
let block = &ssa.blocks[block_idx];
for phi in &block.phi_functions {
let mut result = LatticeValue::Top;
for source in &phi.sources {
if executable_edges.contains(&(source.block, block_id))
|| executable_blocks.contains(&source.block)
{
if let Some(source_val) = lattice_values.get(&source.name) {
result = lattice_meet(&result, source_val);
}
}
}
if let Some(current) = lattice_values.get(&phi.target) {
if *current != result {
lattice_values.insert(phi.target, result);
ssa_worklist.push_back(phi.target);
}
}
}
let mut has_branch = false;
for instr in &block.instructions {
if let Some(target) = instr.target {
let new_value = evaluate_instruction(instr, &lattice_values);
if let Some(current) = lattice_values.get(&target) {
if *current != new_value {
lattice_values.insert(target, new_value);
ssa_worklist.push_back(target);
}
} else {
lattice_values.insert(target, new_value);
ssa_worklist.push_back(target);
}
}
if instr.kind == SsaInstructionKind::Branch {
has_branch = true;
if !instr.uses.is_empty() {
let cond_name = instr.uses[0];
let cond_value = lattice_values.get(&cond_name);
match cond_value {
Some(LatticeValue::Constant(ConstantValue::Bool(true))) => {
if let Some(&succ) = block.successors.first() {
executable_edges.insert((block_id, succ));
cfg_worklist.push_back(succ);
}
}
Some(LatticeValue::Constant(ConstantValue::Bool(false))) => {
if let Some(&succ) = block.successors.get(1) {
executable_edges.insert((block_id, succ));
cfg_worklist.push_back(succ);
}
}
_ => {
for &succ in &block.successors {
executable_edges.insert((block_id, succ));
cfg_worklist.push_back(succ);
}
}
}
} else {
for &succ in &block.successors {
executable_edges.insert((block_id, succ));
cfg_worklist.push_back(succ);
}
}
}
}
if !has_branch {
for &succ in &block.successors {
if !executable_blocks.contains(&succ) {
executable_edges.insert((block_id, succ));
cfg_worklist.push_back(succ);
}
}
}
}
while let Some(name) = ssa_worklist.pop_front() {
if let Some(uses) = use_map.get(&name) {
for &(block_id, _instr_idx) in uses {
if executable_blocks.contains(&block_id) {
cfg_worklist.push_back(block_id);
}
}
}
}
}
let mut constants: HashMap<SsaNameId, ConstantValue> = HashMap::new();
for (name, value) in &lattice_values {
if let LatticeValue::Constant(c) = value {
constants.insert(*name, c.clone());
}
}
let all_blocks: HashSet<usize> = ssa.blocks.iter().map(|b| b.id).collect();
let unreachable_blocks: HashSet<usize> = all_blocks
.difference(&executable_blocks)
.copied()
.collect();
let mut used_names: HashSet<SsaNameId> = HashSet::new();
for block in &ssa.blocks {
if executable_blocks.contains(&block.id) {
for instr in &block.instructions {
used_names.extend(instr.uses.iter().copied());
}
for phi in &block.phi_functions {
for source in &phi.sources {
used_names.insert(source.name);
}
}
}
}
let all_names: HashSet<SsaNameId> = ssa.ssa_names.iter().map(|n| n.id).collect();
let dead_names: HashSet<SsaNameId> = all_names.difference(&used_names).copied().collect();
Ok(SccpResult {
function: ssa.function.clone(),
constants,
unreachable_blocks,
dead_names,
})
}
fn evaluate_instruction(
instr: &super::types::SsaInstruction,
lattice_values: &HashMap<SsaNameId, LatticeValue>,
) -> LatticeValue {
use super::types::SsaInstructionKind;
match &instr.kind {
SsaInstructionKind::Param => {
LatticeValue::Bottom
}
SsaInstructionKind::Assign => {
if instr.uses.is_empty() {
if let Some(ref src) = instr.source_text {
if let Some(c) = parse_constant(src) {
return LatticeValue::Constant(c);
}
}
LatticeValue::Bottom
} else if instr.uses.len() == 1 {
lattice_values
.get(&instr.uses[0])
.cloned()
.unwrap_or(LatticeValue::Top)
} else {
LatticeValue::Bottom
}
}
SsaInstructionKind::BinaryOp => {
if instr.uses.len() >= 2 {
let left = lattice_values
.get(&instr.uses[0])
.cloned()
.unwrap_or(LatticeValue::Top);
let right = lattice_values
.get(&instr.uses[1])
.cloned()
.unwrap_or(LatticeValue::Top);
match (&left, &right) {
(LatticeValue::Bottom, _) | (_, LatticeValue::Bottom) => LatticeValue::Bottom,
(LatticeValue::Top, _) | (_, LatticeValue::Top) => LatticeValue::Top,
(LatticeValue::Constant(l), LatticeValue::Constant(r)) => {
let op = instr.source_text.as_ref().and_then(|s| {
if s.contains('+') {
Some("+")
} else if s.contains('-') && !s.starts_with('-') {
Some("-")
} else if s.contains('*') {
Some("*")
} else if s.contains('/') {
Some("/")
} else if s.contains('<') && s.contains('=') {
Some("<=")
} else if s.contains('>') && s.contains('=') {
Some(">=")
} else if s.contains('<') {
Some("<")
} else if s.contains('>') {
Some(">")
} else if s.contains("==") {
Some("==")
} else if s.contains("!=") {
Some("!=")
} else {
None
}
});
if let Some(op) = op {
if let Some(result) = eval_binary_op(op, l, r) {
return LatticeValue::Constant(result);
}
}
LatticeValue::Bottom
}
}
} else {
LatticeValue::Bottom
}
}
SsaInstructionKind::UnaryOp => {
if !instr.uses.is_empty() {
let operand = lattice_values
.get(&instr.uses[0])
.cloned()
.unwrap_or(LatticeValue::Top);
match operand {
LatticeValue::Bottom => LatticeValue::Bottom,
LatticeValue::Top => LatticeValue::Top,
LatticeValue::Constant(c) => {
if let Some(ref src) = instr.source_text {
if src.contains("not") || src.contains('!') {
if let ConstantValue::Bool(b) = c {
return LatticeValue::Constant(ConstantValue::Bool(!b));
}
}
if src.contains('-') {
if let ConstantValue::Int(i) = c {
if let Some(neg) = i.checked_neg() {
return LatticeValue::Constant(ConstantValue::Int(neg));
}
}
}
}
LatticeValue::Bottom
}
}
} else {
LatticeValue::Bottom
}
}
SsaInstructionKind::Call => {
LatticeValue::Bottom
}
SsaInstructionKind::Return | SsaInstructionKind::Branch => {
LatticeValue::Bottom
}
}
}
/// Finds SSA names whose definitions are dead.
///
/// A name is dead when nothing reads it and its defining instruction is not
/// flagged by `has_side_effects`. Marking a name dead releases the reads made
/// by its own definition, which can make further names dead; the scan repeats
/// until a pass finds nothing new or `names + 10` passes have run.
///
/// Returns the dead names sorted by their numeric id. This function itself
/// never returns `Err`.
pub fn find_dead_code(ssa: &SsaFunction) -> TldrResult<Vec<SsaNameId>> {
    // Every declared name starts at zero reads.
    let mut use_count: HashMap<SsaNameId, usize> =
        ssa.ssa_names.iter().map(|name| (name.id, 0)).collect();
    for block in &ssa.blocks {
        let instr_reads = block
            .instructions
            .iter()
            .flat_map(|instr| instr.uses.iter().copied());
        let phi_reads = block
            .phi_functions
            .iter()
            .flat_map(|phi| phi.sources.iter().map(|source| source.name));
        for read in instr_reads.chain(phi_reads) {
            *use_count.entry(read).or_insert(0) += 1;
        }
    }

    // Name -> (defining block, defining instruction); phis carry no instruction.
    let mut def_info: HashMap<SsaNameId, (usize, Option<&super::types::SsaInstruction>)> =
        HashMap::new();
    for block in &ssa.blocks {
        for instr in &block.instructions {
            if let Some(target) = instr.target {
                def_info.insert(target, (block.id, Some(instr)));
            }
        }
        for phi in &block.phi_functions {
            def_info.insert(phi.target, (block.id, None));
        }
    }

    let mut dead: HashSet<SsaNameId> = HashSet::new();
    let pass_limit = ssa.ssa_names.len() + 10;

    for _ in 0..pass_limit {
        let mut changed = false;

        for name in &ssa.ssa_names {
            let still_read = use_count.get(&name.id).copied().unwrap_or(0) != 0;
            if dead.contains(&name.id) || still_read {
                continue;
            }

            let definition = def_info.get(&name.id);
            if let Some((_, Some(instr))) = definition {
                if has_side_effects(&instr.kind) {
                    continue;
                }
            }

            dead.insert(name.id);
            changed = true;

            // Release the operands read by the now-dead definition.
            let Some((def_block, _)) = definition else {
                continue;
            };
            let Some(block) = ssa.blocks.iter().find(|b| b.id == *def_block) else {
                continue;
            };
            let instr_operands = block
                .instructions
                .iter()
                .filter(|instr| instr.target == Some(name.id))
                .flat_map(|instr| instr.uses.iter().copied());
            let phi_operands = block
                .phi_functions
                .iter()
                .filter(|phi| phi.target == name.id)
                .flat_map(|phi| phi.sources.iter().map(|source| source.name));
            for operand in instr_operands.chain(phi_operands) {
                if let Some(count) = use_count.get_mut(&operand) {
                    *count = count.saturating_sub(1);
                }
            }
        }

        if !changed {
            break;
        }
    }

    let mut result: Vec<SsaNameId> = dead.into_iter().collect();
    result.sort_by_key(|id| id.0);
    Ok(result)
}
/// Reports whether an instruction kind must be kept even when its result is
/// unused. Only calls and returns are treated that way.
pub fn has_side_effects(kind: &super::types::SsaInstructionKind) -> bool {
    use super::types::SsaInstructionKind as Kind;
    match kind {
        Kind::Call | Kind::Return => true,
        Kind::Param | Kind::Assign | Kind::BinaryOp | Kind::UnaryOp | Kind::Branch => false,
    }
}
/// Pre-computed index used by `is_live_at` to answer liveness queries on SSA
/// names; built by `build_ssa_liveness`.
#[derive(Debug, Clone)]
pub struct SsaLiveness {
    /// Function name, copied from the SSA form.
    pub function: String,
    /// Per block id: (position of the block in the SSA block list, total block count).
    #[allow(dead_code)]
    dom_intervals: HashMap<usize, (u32, u32)>,
    /// Per SSA name: (block id, instruction index) of each read; phi reads use index 0.
    #[allow(dead_code)]
    use_positions: HashMap<SsaNameId, Vec<(usize, u32)>>,
}
/// A program point inside a basic block, used as the query position for
/// `is_live_at`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockPosition {
    /// Start of the block.
    Entry,
    /// Immediately after the instruction with the given index.
    After(usize),
    /// End of the block.
    Exit,
}
/// Builds the lookup tables behind `is_live_at`.
///
/// Each block is recorded with its position in `ssa.blocks` plus the total
/// block count, and every read of an SSA name is recorded as
/// `(block id, instruction index)`. Reads made by phi sources are recorded at
/// index 0 of the block that contains the phi.
///
/// This function itself never returns `Err`.
pub fn build_ssa_liveness(ssa: &SsaFunction) -> TldrResult<SsaLiveness> {
    let block_count = ssa.blocks.len() as u32;

    let block_order: HashMap<usize, u32> = ssa
        .blocks
        .iter()
        .enumerate()
        .map(|(position, block)| (block.id, position as u32))
        .collect();

    let dom_intervals: HashMap<usize, (u32, u32)> = ssa
        .blocks
        .iter()
        .map(|block| {
            let start = block_order.get(&block.id).copied().unwrap_or(0);
            (block.id, (start, block_count))
        })
        .collect();

    let mut use_positions: HashMap<SsaNameId, Vec<(usize, u32)>> = HashMap::new();
    for block in &ssa.blocks {
        for (instr_idx, instr) in block.instructions.iter().enumerate() {
            for &read in &instr.uses {
                use_positions
                    .entry(read)
                    .or_default()
                    .push((block.id, instr_idx as u32));
            }
        }
        for source in block.phi_functions.iter().flat_map(|phi| phi.sources.iter()) {
            use_positions
                .entry(source.name)
                .or_default()
                .push((block.id, 0));
        }
    }

    Ok(SsaLiveness {
        function: ssa.function.clone(),
        dom_intervals,
        use_positions,
    })
}
/// Approximate liveness query: is `name` still read at or after `position`
/// in `block`?
///
/// `name` is reported live when
/// * some read of it sits in a block ordered after `block`, or
/// * some read of it sits in `block` itself at or after the query point.
///
/// The query point is instruction 0 for `Entry` and instruction `idx + 1` for
/// `After(idx)`; at `Exit` no read inside the block counts. Blocks missing
/// from the index are treated as having order 0. Names with no recorded reads
/// are never live.
///
/// The same-block comparison is inclusive. It used to be strict, so a value
/// read by the very next instruction (or by instruction 0 when asking at
/// `Entry`) was reported dead. Because phi reads are recorded at index 0,
/// they count as live at `Entry` of the phi's block, which errs on the side
/// of "live". The `idx + 1` conversion now saturates instead of wrapping.
pub fn is_live_at(
    liveness: &SsaLiveness,
    name: SsaNameId,
    block: usize,
    position: BlockPosition,
) -> bool {
    let Some(uses) = liveness.use_positions.get(&name) else {
        return false;
    };

    // First instruction index that still counts as "after" the query point;
    // `None` means nothing inside the block does.
    let first_live_index: Option<u32> = match position {
        BlockPosition::Entry => Some(0),
        BlockPosition::After(idx) => {
            if idx >= u32::MAX as usize {
                Some(u32::MAX)
            } else {
                Some(idx as u32 + 1)
            }
        }
        BlockPosition::Exit => None,
    };

    let order_of = |block_id: usize| {
        liveness
            .dom_intervals
            .get(&block_id)
            .map_or(0, |&(start, _)| start)
    };
    let query_order = order_of(block);

    uses.iter().any(|&(use_block, use_pos)| {
        if order_of(use_block) > query_order {
            return true;
        }
        use_block == block && first_live_index.map_or(false, |first| use_pos >= first)
    })
}
/// Result of available-expressions analysis for one function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AvailableExpressions {
    /// Function name, copied from the CFG.
    pub function: String,
    /// IN / OUT sets keyed by CFG block id.
    pub blocks: HashMap<usize, ExpressionSets>,
    /// All tracked expressions; the sets in `blocks` hold indices into this list.
    pub expressions: Vec<Expression>,
}
/// Expression indices available at the boundaries of one block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpressionSets {
    /// Indices of expressions available on entry to the block.
    pub in_set: HashSet<usize>,
    /// Indices of expressions available on exit from the block.
    pub out_set: HashSet<usize>,
}
/// One tracked expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Expression {
    /// Canonical text of the expression; used as its identity.
    pub text: String,
    /// Names of the variables the expression reads.
    pub uses: Vec<String>,
    /// Line on which the expression was first seen.
    pub first_line: u32,
}
impl AvailableExpressions {
    /// Returns `true` when the expression with exactly this text is in the
    /// IN set of `block_id`.
    ///
    /// Unknown expressions and unknown blocks yield `false`.
    pub fn is_available(&self, block_id: usize, expr_text: &str) -> bool {
        let Some(idx) = self.expressions.iter().position(|e| e.text == expr_text) else {
            return false;
        };
        match self.blocks.get(&block_id) {
            Some(sets) => sets.in_set.contains(&idx),
            None => false,
        }
    }

    /// Lists the expressions in the IN set of `block_id`, in unspecified
    /// order. Unknown blocks yield an empty list; indices that do not refer
    /// to a tracked expression are skipped.
    pub fn available_at(&self, block_id: usize) -> Vec<&Expression> {
        let Some(sets) = self.blocks.get(&block_id) else {
            return Vec::new();
        };
        let mut available = Vec::new();
        for idx in &sets.in_set {
            if let Some(expr) = self.expressions.get(*idx) {
                available.push(expr);
            }
        }
        available
    }
}
impl LiveVariables {
    /// Returns the live-in set of the first block in `cfg` whose line range
    /// contains `line`.
    ///
    /// The result is block-granular: every line of a block reports the same
    /// set. Lines outside all blocks, and blocks without recorded sets, yield
    /// an empty set.
    pub fn live_at_line(&self, cfg: &CfgInfo, line: u32) -> HashSet<String> {
        let Some(block) = cfg
            .blocks
            .iter()
            .find(|b| b.lines.0 <= line && line <= b.lines.1)
        else {
            return HashSet::new();
        };
        match self.blocks.get(&block.id) {
            Some(sets) => sets.live_in.clone(),
            None => HashSet::new(),
        }
    }

    /// Returns the line range of every block, in CFG order, where `var` is
    /// live on entry or on exit. Ranges are reported per block and are not
    /// merged.
    pub fn live_range(&self, var: &str, cfg: &CfgInfo) -> Vec<(u32, u32)> {
        cfg.blocks
            .iter()
            .filter(|block| self.is_live_in(block.id, var) || self.is_live_out(block.id, var))
            .map(|block| block.lines)
            .collect()
    }
}
pub fn compute_available_expressions(
cfg: &CfgInfo,
source: &str,
language: Language,
) -> TldrResult<AvailableExpressions> {
let mut predecessors: HashMap<usize, Vec<usize>> = HashMap::new();
for block in &cfg.blocks {
predecessors.insert(block.id, Vec::new());
}
for edge in &cfg.edges {
predecessors.entry(edge.to).or_default().push(edge.from);
}
let line_to_block: HashMap<u32, usize> = cfg
.blocks
.iter()
.flat_map(|block| (block.lines.0..=block.lines.1).map(move |line| (line, block.id)))
.collect();
let extraction = extract_expressions_and_defs(source, language, &line_to_block, cfg);
let all_expressions = extraction.all_expressions;
let expr_to_index = extraction.expr_to_index;
let block_defs = extraction.block_defs;
let block_exprs = extraction.block_exprs;
let defs_by_line = extraction.defs_by_line;
if all_expressions.is_empty() {
let mut blocks_result = HashMap::new();
for block in &cfg.blocks {
blocks_result.insert(
block.id,
ExpressionSets {
in_set: HashSet::new(),
out_set: HashSet::new(),
},
);
}
return Ok(AvailableExpressions {
function: cfg.function.clone(),
blocks: blocks_result,
expressions: all_expressions,
});
}
let mut gen_sets: HashMap<usize, HashSet<usize>> = HashMap::new();
let mut kill_sets: HashMap<usize, HashSet<usize>> = HashMap::new();
for block in &cfg.blocks {
let mut gen = HashSet::new();
let mut kill = HashSet::new();
let defs = block_defs.get(&block.id).cloned().unwrap_or_default();
for (idx, expr) in all_expressions.iter().enumerate() {
if expr.uses.iter().any(|v| defs.contains(v)) {
kill.insert(idx);
}
}
if let Some(exprs) = block_exprs.get(&block.id) {
for (line_num, expr) in exprs {
if let Some(&idx) = expr_to_index.get(&expr.text) {
let killed_after = defs_by_line.iter().any(|(&def_line, def_var)| {
def_line > *line_num
&& def_line <= block.lines.1
&& expr.uses.contains(def_var)
});
if !killed_after {
gen.insert(idx);
}
}
}
}
gen_sets.insert(block.id, gen);
kill_sets.insert(block.id, kill);
}
let all_expr_indices: HashSet<usize> = (0..all_expressions.len()).collect();
let mut avail_in: HashMap<usize, HashSet<usize>> = HashMap::new();
let mut avail_out: HashMap<usize, HashSet<usize>> = HashMap::new();
for block in &cfg.blocks {
if block.id == cfg.entry_block {
avail_in.insert(block.id, HashSet::new());
} else {
avail_in.insert(block.id, all_expr_indices.clone());
}
avail_out.insert(
block.id,
gen_sets.get(&block.id).cloned().unwrap_or_default(),
);
}
let block_ids: Vec<usize> = cfg.blocks.iter().map(|b| b.id).collect();
let max_iterations = block_ids.len() * 2 + 10;
let mut changed = true;
let mut _iterations = 0;
while changed && _iterations < max_iterations {
changed = false;
_iterations += 1;
for &block_id in &block_ids {
let preds = predecessors.get(&block_id).cloned().unwrap_or_default();
let new_in = if preds.is_empty() {
HashSet::new()
} else {
let mut intersection = all_expr_indices.clone();
for pred in &preds {
if let Some(pred_out) = avail_out.get(pred) {
intersection = intersection.intersection(pred_out).copied().collect();
}
}
intersection
};
let gen = gen_sets.get(&block_id).cloned().unwrap_or_default();
let kill = kill_sets.get(&block_id).cloned().unwrap_or_default();
let in_minus_kill: HashSet<usize> = new_in.difference(&kill).copied().collect();
let new_out: HashSet<usize> = gen.union(&in_minus_kill).copied().collect();
if new_in != *avail_in.get(&block_id).unwrap_or(&HashSet::new()) {
changed = true;
avail_in.insert(block_id, new_in);
}
if new_out != *avail_out.get(&block_id).unwrap_or(&HashSet::new()) {
changed = true;
avail_out.insert(block_id, new_out);
}
}
}
let mut blocks_result = HashMap::new();
for block in &cfg.blocks {
blocks_result.insert(
block.id,
ExpressionSets {
in_set: avail_in.get(&block.id).cloned().unwrap_or_default(),
out_set: avail_out.get(&block.id).cloned().unwrap_or_default(),
},
);
}
Ok(AvailableExpressions {
function: cfg.function.clone(),
blocks: blocks_result,
expressions: all_expressions,
})
}
/// Everything the available-expressions analysis needs from the source text.
struct ExpressionExtraction {
    /// Each distinct expression, in first-seen order.
    all_expressions: Vec<Expression>,
    /// Expression text -> index into `all_expressions`.
    expr_to_index: HashMap<String, usize>,
    /// Block id -> names of variables defined in that block.
    block_defs: HashMap<usize, HashSet<String>>,
    /// Block id -> (line, expression) for each occurrence in that block.
    block_exprs: HashMap<usize, Vec<(u32, Expression)>>,
    /// Line -> name of a variable defined on that line (one entry per line).
    defs_by_line: HashMap<u32, String>,
}
/// Extracts expressions and definitions from `source`, preferring the
/// AST-based extractor and falling back to the regex-based one when the AST
/// extractor returns `None`.
fn extract_expressions_and_defs(
    source: &str,
    language: Language,
    line_to_block: &HashMap<u32, usize>,
    cfg: &CfgInfo,
) -> ExpressionExtraction {
    match extract_expressions_ast(source, language, line_to_block, cfg) {
        Some(extraction) => extraction,
        None => extract_expressions_regex(source, line_to_block, cfg),
    }
}
/// AST-based extraction of definitions and simple binary expressions.
///
/// Parses `source`, then makes two passes over all descendant nodes:
///
/// 1. every node for which `extract_def_from_node` yields a name is recorded
///    as a definition of that name on the node's (1-based) start line,
///    provided the line belongs to a block;
/// 2. every binary-expression node whose operator passes `is_arithmetic_op`
///    and whose operands are both simple identifiers is recorded as an
///    expression. Operands of `+` and `*` are put in lexicographic order so
///    `b + a` and `a + b` share one entry.
///
/// Returns `None` when parsing fails.
fn extract_expressions_ast(
    source: &str,
    language: Language,
    line_to_block: &HashMap<u32, usize>,
    cfg: &CfgInfo,
) -> Option<ExpressionExtraction> {
    use crate::ast::parser::ParserPool;
    use crate::security::ast_utils::{
        binary_expression_node_kinds, walk_descendants,
    };
    use std::collections::hash_map::Entry;

    let pool = ParserPool::new();
    let tree = pool.parse(source, language).ok()?;
    let src_bytes = source.as_bytes();
    let binop_kinds = binary_expression_node_kinds(language);

    let mut all_expressions: Vec<Expression> = Vec::new();
    let mut expr_to_index: HashMap<String, usize> = HashMap::new();
    let mut defs_by_line: HashMap<u32, String> = HashMap::new();

    // Every block gets an entry up front, even if nothing is found in it.
    let mut block_defs: HashMap<usize, HashSet<String>> = HashMap::new();
    let mut block_exprs: HashMap<usize, Vec<(u32, Expression)>> = HashMap::new();
    for block in &cfg.blocks {
        block_defs.insert(block.id, HashSet::new());
        block_exprs.insert(block.id, Vec::new());
    }

    let descendants = walk_descendants(tree.root_node());

    // Pass 1: definitions.
    for node in &descendants {
        let Some(var_name) = extract_def_from_node(node, src_bytes, language) else {
            continue;
        };
        let line_num = node.start_position().row as u32 + 1;
        let Some(&block_id) = line_to_block.get(&line_num) else {
            continue;
        };
        block_defs
            .entry(block_id)
            .or_default()
            .insert(var_name.clone());
        defs_by_line.insert(line_num, var_name);
    }

    // Pass 2: binary expressions over two plain identifiers.
    for node in &descendants {
        if !binop_kinds.contains(&node.kind()) {
            continue;
        }
        let Some((left, op, right)) = extract_binop_operands(node, src_bytes, language) else {
            continue;
        };
        if !is_arithmetic_op(&op) || !is_simple_identifier(&left) || !is_simple_identifier(&right)
        {
            continue;
        }

        let line_num = node.start_position().row as u32 + 1;

        let commutative = op == "+" || op == "*";
        let (first, second) = if commutative && left >= right {
            (right, left)
        } else {
            (left, right)
        };
        let expr_text = format!("{} {} {}", first, op, second);

        let expr_idx = match expr_to_index.entry(expr_text) {
            Entry::Occupied(slot) => *slot.get(),
            Entry::Vacant(slot) => {
                let idx = all_expressions.len();
                all_expressions.push(Expression {
                    text: slot.key().clone(),
                    uses: vec![first, second],
                    first_line: line_num,
                });
                slot.insert(idx);
                idx
            }
        };

        if let Some(&block_id) = line_to_block.get(&line_num) {
            block_exprs
                .entry(block_id)
                .or_default()
                .push((line_num, all_expressions[expr_idx].clone()));
        }
    }

    Some(ExpressionExtraction {
        all_expressions,
        expr_to_index,
        block_defs,
        block_exprs,
        defs_by_line,
    })
}
/// Returns `true` for the arithmetic, shift and bitwise operator spellings
/// tracked by the expression extractor.
fn is_arithmetic_op(op: &str) -> bool {
    const TRACKED_OPS: [&str; 12] = [
        "+", "-", "*", "/", "%", "**", "//", "<<", ">>", "&", "|", "^",
    ];
    TRACKED_OPS.contains(&op)
}
/// Returns `true` when `s` is a plain identifier: optional leading `$`
/// characters, then a letter or underscore, then only letters, digits and
/// underscores.
fn is_simple_identifier(s: &str) -> bool {
    // Leading sigils are allowed but must be followed by a real name.
    let mut chars = s.trim_start_matches('$').chars();
    match chars.next() {
        Some(first) if first.is_alphabetic() || first == '_' => {
            chars.all(|c| c.is_alphanumeric() || c == '_')
        }
        _ => false,
    }
}
fn extract_def_from_node(
node: &tree_sitter::Node,
source: &[u8],
language: Language,
) -> Option<String> {
use crate::security::ast_utils::{assignment_node_kinds, node_text};
let assign_kinds = assignment_node_kinds(language);
let kind = node.kind();
if assign_kinds.contains(&kind) {
return extract_lhs_from_assignment(node, source, language);
}
match language {
Language::TypeScript | Language::JavaScript => {
if kind == "lexical_declaration" {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == "variable_declarator" {
if let Some(name) = child.child_by_field_name("name") {
return Some(node_text(&name, source).to_string());
}
}
}
}
}
}
Language::Elixir => {
if kind == "binary_operator" {
if let Some(op) = node.child_by_field_name("operator") {
if node_text(&op, source) == "=" {
if let Some(left) = node.child_by_field_name("left") {
let text = node_text(&left, source).to_string();
if is_simple_identifier(&text) {
return Some(text);
}
}
}
}
}
}
_ => {}
}
None
}
/// Returns the assigned name of an assignment-like node.
///
/// Tries, in order: the `left` field (only the text before the first comma,
/// trimmed), the `pattern` field (full text), then the language-specific
/// fallback.
fn extract_lhs_from_assignment(
    node: &tree_sitter::Node,
    source: &[u8],
    language: Language,
) -> Option<String> {
    use crate::security::ast_utils::node_text;

    if let Some(left) = node.child_by_field_name("left") {
        let text = node_text(&left, source);
        // `split` always yields at least one piece, so the default is never used.
        let first_target = text.split(',').next().unwrap_or_default();
        return Some(first_target.trim().to_string());
    }

    node.child_by_field_name("pattern")
        .map(|pattern| node_text(&pattern, source).to_string())
        .or_else(|| extract_lhs_language_fallback(node, source, language))
}
/// Dispatches to the per-language extractor used when an assignment node has
/// neither a `left` nor a `pattern` field.
///
/// Languages without a dedicated extractor yield `None`.
fn extract_lhs_language_fallback(
    node: &tree_sitter::Node,
    source: &[u8],
    language: Language,
) -> Option<String> {
    match language {
        Language::TypeScript | Language::JavaScript => {
            extract_name_from_variable_declarator(node, source)
        }
        // Java only declares names through `local_variable_declaration` here.
        Language::Java if node.kind() == "local_variable_declaration" => {
            extract_name_from_variable_declarator(node, source)
        }
        Language::Kotlin => extract_lhs_kotlin(node, source),
        Language::C | Language::Cpp => extract_lhs_c_family(node, source),
        Language::Rust => extract_lhs_rust(node, source),
        Language::CSharp => extract_lhs_csharp(node, source),
        Language::Swift => extract_lhs_swift(node, source),
        Language::Lua | Language::Luau => extract_lhs_lua(node, source, language),
        Language::Elixir | Language::Ocaml => extract_first_child_text(node, source),
        // Other Java nodes, Scala, and every remaining language: no fallback.
        _ => None,
    }
}
/// Returns the text of the `name` field of the first `variable_declarator`
/// child of `node` that has such a field, or `None` when there is none.
fn extract_name_from_variable_declarator(
    node: &tree_sitter::Node,
    source: &[u8],
) -> Option<String> {
    use crate::security::ast_utils::node_text;

    for slot in (0..node.child_count()).map(|idx| node.child(idx)) {
        let declarator = slot?;
        if declarator.kind() == "variable_declarator" {
            if let Some(ident) = declarator.child_by_field_name("name") {
                return Some(node_text(&ident, source).to_string());
            }
        }
    }
    None
}
/// Returns the text of the first child of `node` whose kind is
/// `variable_declaration` or `simple_identifier`.
fn extract_lhs_kotlin(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
    use crate::security::ast_utils::node_text;

    for slot in (0..node.child_count()).map(|idx| node.child(idx)) {
        let candidate = slot?;
        if matches!(candidate.kind(), "variable_declaration" | "simple_identifier") {
            return Some(node_text(&candidate, source).to_string());
        }
    }
    None
}
/// Extracts the declared name from a C / C++ `declaration` node.
///
/// The first `init_declarator` child that has a `declarator` field supplies
/// the name. Leading pointer (`*`) and reference (`&`) sigils are stripped
/// together with any whitespace around them, so `*p`, `* p`, `**p` and `&r`
/// all yield the bare variable name. Nodes of any other kind yield `None`.
fn extract_lhs_c_family(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
    use crate::security::ast_utils::node_text;

    if node.kind() != "declaration" {
        return None;
    }
    for i in 0..node.child_count() {
        let child = node.child(i)?;
        if child.kind() != "init_declarator" {
            continue;
        }
        if let Some(decl) = child.child_by_field_name("declarator") {
            let text = node_text(&decl, source);
            // A declarator written as `* p` must not come back as " p":
            // the result is compared against plain identifier names.
            let name = text
                .trim_start_matches(|c: char| c == '*' || c == '&' || c.is_whitespace())
                .trim_end();
            return Some(name.to_string());
        }
    }
    None
}
/// Returns the text of the `pattern` field of a Rust `let_declaration` node;
/// any other node kind yields `None`.
fn extract_lhs_rust(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
    use crate::security::ast_utils::node_text;

    match node.kind() {
        "let_declaration" => node
            .child_by_field_name("pattern")
            .map(|binding| node_text(&binding, source).to_string()),
        _ => None,
    }
}
/// Extracts the declared name from a C# `variable_declaration` node.
///
/// For each `variable_declarator` child, the `name` field is preferred; when
/// it is absent, the first `identifier` child of the declarator is used.
/// Nodes of any other kind yield `None`.
fn extract_lhs_csharp(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
    use crate::security::ast_utils::node_text;

    if node.kind() != "variable_declaration" {
        return None;
    }
    for slot in (0..node.child_count()).map(|idx| node.child(idx)) {
        let declarator = slot?;
        if declarator.kind() == "variable_declarator" {
            if let Some(ident) = declarator.child_by_field_name("name") {
                return Some(node_text(&ident, source).to_string());
            }
            // No `name` field on this declarator: look for a bare identifier.
            for inner in 0..declarator.child_count() {
                let part = declarator.child(inner)?;
                if part.kind() == "identifier" {
                    return Some(node_text(&part, source).to_string());
                }
            }
        }
    }
    None
}
/// Returns the text of the first `pattern` child of a Swift
/// `property_declaration` node; any other node kind yields `None`.
fn extract_lhs_swift(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
    use crate::security::ast_utils::node_text;

    match node.kind() {
        "property_declaration" => {
            for slot in (0..node.child_count()).map(|idx| node.child(idx)) {
                let candidate = slot?;
                if candidate.kind() == "pattern" {
                    return Some(node_text(&candidate, source).to_string());
                }
            }
            None
        }
        _ => None,
    }
}
/// Extracts the assigned name from a Lua / Luau node.
///
/// Strategies, tried in order:
/// 1. for a `variable_declaration`, delegate to
///    [`extract_lhs_from_assignment`] on its first `assignment_statement`
///    child;
/// 2. take the first child of the first `variable_list` /
///    `assignment_variable_list` child that has one;
/// 3. take the node's first child when it is named and not of kind `local`.
fn extract_lhs_lua(
    node: &tree_sitter::Node,
    source: &[u8],
    language: Language,
) -> Option<String> {
    use crate::security::ast_utils::node_text;

    let child_total = node.child_count();

    if node.kind() == "variable_declaration" {
        for idx in 0..child_total {
            let inner = node.child(idx)?;
            if inner.kind() == "assignment_statement" {
                return extract_lhs_from_assignment(&inner, source, language);
            }
        }
    }

    for idx in 0..child_total {
        let list = node.child(idx)?;
        if !matches!(list.kind(), "variable_list" | "assignment_variable_list") {
            continue;
        }
        if let Some(head) = list.child(0) {
            return Some(node_text(&head, source).to_string());
        }
    }

    node.child(0)
        .filter(|first| first.is_named() && first.kind() != "local")
        .map(|first| node_text(&first, source).to_string())
}
/// Returns the source text of `node`'s first child, or `None` when the node
/// has no children.
fn extract_first_child_text(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
    use crate::security::ast_utils::node_text;

    node.child(0)
        .map(|first| node_text(&first, source).to_string())
}
fn extract_binop_operands(
node: &tree_sitter::Node,
source: &[u8],
_language: Language,
) -> Option<(String, String, String)> {
use crate::security::ast_utils::node_text;
if let (Some(left), Some(right), Some(op)) = (
node.child_by_field_name("left"),
node.child_by_field_name("right"),
node.child_by_field_name("operator"),
) {
return Some((
node_text(&left, source).trim().to_string(),
node_text(&op, source).trim().to_string(),
node_text(&right, source).trim().to_string(),
));
}
if node.child_count() >= 3 {
let left = node.child(0)?;
let op = node.child(1)?;
let right = node.child(2)?;
let op_text = node_text(&op, source).trim().to_string();
if !op_text.is_empty() {
return Some((
node_text(&left, source).trim().to_string(),
op_text,
node_text(&right, source).trim().to_string(),
));
}
}
if let (Some(left), Some(right)) = (
node.child_by_field_name("left"),
node.child_by_field_name("right"),
) {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if !child.is_named() {
let text = node_text(&child, source).trim().to_string();
if is_arithmetic_op(&text) {
return Some((
node_text(&left, source).trim().to_string(),
text,
node_text(&right, source).trim().to_string(),
));
}
}
}
}
}
None
}
/// Line-oriented, regex-based fallback for collecting definitions and binary
/// expressions from `source`.
///
/// For every non-empty line that does not start with `#`:
///
/// * `name = rhs` and augmented forms (`name += rhs`, `name <<= rhs`, …) record
///   `name` as a definition of the block containing the line. Augmented
///   assignments redefine their target, so they must kill expressions that
///   read it.
/// * Lines of the form `name == rhs` are comparisons and are skipped.
/// * The first `a <op> b` found in `rhs` (with `<op>` one of `+ - * /`) is
///   recorded as an expression. Operands of the commutative operators `+` and
///   `*` are ordered lexicographically so `a + b` and `b + a` share one entry.
///
/// Line numbers are 1-based. Definitions and per-block expression occurrences
/// are only recorded for lines present in `line_to_block`; the global
/// expression table is filled regardless.
fn extract_expressions_regex(
    source: &str,
    line_to_block: &HashMap<u32, usize>,
    cfg: &CfgInfo,
) -> ExpressionExtraction {
    use regex::Regex;

    // Group 1: assigned name. Group 2: right-hand side. The optional operator
    // prefix mirrors the operators accepted by `is_arithmetic_op`.
    let assign_re =
        Regex::new(r"^\s*(\w+)\s*(?:\*\*|//|<<|>>|[+\-*/%&|^])?=\s*(.+)$").unwrap();
    let binop_re = Regex::new(r"(\w+)\s*([+\-*/])\s*(\w+)").unwrap();

    let mut all_expressions: Vec<Expression> = Vec::new();
    let mut expr_to_index: HashMap<String, usize> = HashMap::new();
    let mut defs_by_line: HashMap<u32, String> = HashMap::new();
    let mut block_defs: HashMap<usize, HashSet<String>> = HashMap::new();
    let mut block_exprs: HashMap<usize, Vec<(u32, Expression)>> = HashMap::new();
    for block in &cfg.blocks {
        block_defs.insert(block.id, HashSet::new());
        block_exprs.insert(block.id, Vec::new());
    }

    for (line_idx, line) in source.lines().enumerate() {
        let line_num = line_idx as u32 + 1;
        let trimmed = line.trim();
        if trimmed.is_empty() || trimmed.starts_with('#') {
            continue;
        }
        let caps = match assign_re.captures(trimmed) {
            Some(caps) => caps,
            None => continue,
        };
        let rhs = caps.get(2).unwrap().as_str();
        // `x == y` matches the assignment shape with a right-hand side that
        // starts with `=`; it is a comparison, not a definition.
        if rhs.starts_with('=') {
            continue;
        }
        let var_name = caps.get(1).unwrap().as_str().to_string();
        let block_id = line_to_block.get(&line_num).copied();

        if let Some(block_id) = block_id {
            block_defs.entry(block_id).or_default().insert(var_name.clone());
            defs_by_line.insert(line_num, var_name);
        }

        let binop_caps = match binop_re.captures(rhs) {
            Some(binop_caps) => binop_caps,
            None => continue,
        };
        let left = binop_caps.get(1).unwrap().as_str().to_string();
        let op = binop_caps.get(2).unwrap().as_str().to_string();
        let right = binop_caps.get(3).unwrap().as_str().to_string();

        // Canonical operand order for commutative operators.
        let (canonical_left, canonical_right) = if (op == "+" || op == "*") && right < left {
            (right, left)
        } else {
            (left, right)
        };
        let expr_text = format!("{} {} {}", canonical_left, op, canonical_right);

        let expr_idx = match expr_to_index.get(&expr_text).copied() {
            Some(idx) => idx,
            None => {
                let idx = all_expressions.len();
                all_expressions.push(Expression {
                    text: expr_text.clone(),
                    uses: vec![canonical_left, canonical_right],
                    first_line: line_num,
                });
                expr_to_index.insert(expr_text, idx);
                idx
            }
        };

        if let Some(block_id) = block_id {
            let occurrence = all_expressions[expr_idx].clone();
            block_exprs
                .entry(block_id)
                .or_default()
                .push((line_num, occurrence));
        }
    }

    ExpressionExtraction {
        all_expressions,
        expr_to_index,
        block_defs,
        block_exprs,
        defs_by_line,
    }
}
/// Approximates available-expressions analysis for a function using only its
/// variable references (no syntax tree).
///
/// Expressions are synthesized from `Use` references:
///
/// * every line with two or more uses yields one compound expression whose
///   text is the sorted operand names joined by `" op "`;
/// * every used variable yields a `use_<name>` expression.
///
/// Expressions are numbered by walking `cfg.blocks` in order and, within a
/// block, lines in ascending order, so the indices stored in the result are
/// the same on every run.
///
/// A forward "must" dataflow problem is then iterated to a fixpoint (bounded
/// by `2 * blocks + 10` passes):
///
/// * `IN[b]`  = intersection of `OUT[p]` over the predecessors `p` of `b`;
///   it is always empty for the entry block and for blocks without
///   predecessors;
/// * `OUT[b]` = `GEN[b] ∪ (IN[b] − KILL[b])`.
///
/// `KILL[b]` holds every expression with an operand defined or updated in
/// `b`. `GEN[b]` holds expressions whose first occurrence lies in `b` and
/// whose operands are not defined or updated on that line or later in the
/// block.
///
/// # Errors
///
/// This implementation always returns `Ok`.
pub fn compute_available_expressions_from_refs(
    cfg: &CfgInfo,
    refs: &[crate::types::VarRef],
) -> TldrResult<AvailableExpressions> {
    use crate::types::RefType;
    use std::collections::BTreeMap;

    let is_def = |r: &crate::types::VarRef| -> bool {
        matches!(r.ref_type, RefType::Definition | RefType::Update)
    };

    // --- CFG shape -------------------------------------------------------
    let mut predecessors: HashMap<usize, Vec<usize>> = HashMap::new();
    for block in &cfg.blocks {
        predecessors.insert(block.id, Vec::new());
    }
    for edge in &cfg.edges {
        predecessors.entry(edge.to).or_default().push(edge.from);
    }
    let line_to_block: HashMap<u32, usize> = cfg
        .blocks
        .iter()
        .flat_map(|block| (block.lines.0..=block.lines.1).map(move |line| (line, block.id)))
        .collect();

    // --- Per-block definitions and uses -----------------------------------
    let mut block_defs: HashMap<usize, HashSet<String>> = HashMap::new();
    for block in &cfg.blocks {
        block_defs.insert(block.id, HashSet::new());
    }
    let mut block_uses: HashMap<usize, Vec<&crate::types::VarRef>> = HashMap::new();
    for var_ref in refs {
        let block_id = match line_to_block.get(&var_ref.line) {
            Some(&block_id) => block_id,
            None => continue,
        };
        if is_def(var_ref) {
            block_defs
                .entry(block_id)
                .or_default()
                .insert(var_ref.name.clone());
        } else if matches!(var_ref.ref_type, RefType::Use) {
            block_uses.entry(block_id).or_default().push(var_ref);
        }
    }

    // --- Synthesize expressions in a deterministic order -------------------
    let mut all_expressions: Vec<Expression> = Vec::new();
    let mut expr_to_index: HashMap<String, usize> = HashMap::new();
    {
        let mut register = |text: String, uses: Vec<String>, first_line: u32| {
            if !expr_to_index.contains_key(&text) {
                expr_to_index.insert(text.clone(), all_expressions.len());
                all_expressions.push(Expression {
                    text,
                    uses,
                    first_line,
                });
            }
        };

        // Walking `cfg.blocks` (not the hash map) keeps numbering and
        // `first_line` independent of hash iteration order.
        for block in &cfg.blocks {
            let uses = match block_uses.get(&block.id) {
                Some(uses) => uses,
                None => continue,
            };

            let mut uses_by_line: BTreeMap<u32, Vec<&crate::types::VarRef>> = BTreeMap::new();
            for &var_ref in uses {
                uses_by_line.entry(var_ref.line).or_default().push(var_ref);
            }
            for (line, line_uses) in uses_by_line {
                if line_uses.len() < 2 {
                    continue;
                }
                let mut operands: Vec<String> =
                    line_uses.iter().map(|r| r.name.clone()).collect();
                operands.sort();
                let expr_text = operands.join(" op ");
                register(expr_text, operands, line);
            }

            for &var_ref in uses {
                register(
                    format!("use_{}", var_ref.name),
                    vec![var_ref.name.clone()],
                    var_ref.line,
                );
            }
        }
    }

    if all_expressions.is_empty() {
        let blocks_result = cfg
            .blocks
            .iter()
            .map(|block| {
                (
                    block.id,
                    ExpressionSets {
                        in_set: HashSet::new(),
                        out_set: HashSet::new(),
                    },
                )
            })
            .collect();
        return Ok(AvailableExpressions {
            function: cfg.function.clone(),
            blocks: blocks_result,
            expressions: all_expressions,
        });
    }

    // --- GEN / KILL ---------------------------------------------------------
    let mut gen_sets: HashMap<usize, HashSet<usize>> = HashMap::new();
    let mut kill_sets: HashMap<usize, HashSet<usize>> = HashMap::new();
    for block in &cfg.blocks {
        let defs = block_defs.get(&block.id);

        let killed: HashSet<usize> = all_expressions
            .iter()
            .enumerate()
            .filter(|(_, expr)| {
                defs.map_or(false, |defs| expr.uses.iter().any(|v| defs.contains(v)))
            })
            .map(|(idx, _)| idx)
            .collect();

        let generated: HashSet<usize> = all_expressions
            .iter()
            .enumerate()
            .filter(|(_, expr)| {
                let first_line = expr.first_line;
                if first_line < block.lines.0 || first_line > block.lines.1 {
                    return false;
                }
                // A write to an operand on the expression's own line (e.g.
                // `a = a + b`) or later in the block invalidates the value,
                // so the expression does not reach the block exit.
                !refs.iter().any(|r| {
                    r.line >= first_line
                        && r.line <= block.lines.1
                        && is_def(r)
                        && expr.uses.contains(&r.name)
                })
            })
            .map(|(idx, _)| idx)
            .collect();

        gen_sets.insert(block.id, generated);
        kill_sets.insert(block.id, killed);
    }

    // --- Fixpoint iteration ---------------------------------------------------
    let all_expr_indices: HashSet<usize> = (0..all_expressions.len()).collect();
    let empty: HashSet<usize> = HashSet::new();
    let mut avail_in: HashMap<usize, HashSet<usize>> = HashMap::new();
    let mut avail_out: HashMap<usize, HashSet<usize>> = HashMap::new();
    for block in &cfg.blocks {
        let initial_in = if block.id == cfg.entry_block {
            HashSet::new()
        } else {
            all_expr_indices.clone()
        };
        avail_in.insert(block.id, initial_in);
        avail_out.insert(
            block.id,
            gen_sets.get(&block.id).cloned().unwrap_or_default(),
        );
    }

    let block_ids: Vec<usize> = cfg.blocks.iter().map(|b| b.id).collect();
    let max_iterations = block_ids.len() * 2 + 10;
    let mut changed = true;
    let mut iterations = 0;
    while changed && iterations < max_iterations {
        changed = false;
        iterations += 1;

        for &block_id in &block_ids {
            let preds: &[usize] = predecessors
                .get(&block_id)
                .map(|p| p.as_slice())
                .unwrap_or(&[]);

            // Nothing is available on function entry, even when a back edge
            // targets the entry block.
            let new_in: HashSet<usize> = if block_id == cfg.entry_block || preds.is_empty() {
                HashSet::new()
            } else {
                let mut meet = all_expr_indices.clone();
                for pred_out in preds.iter().filter_map(|pred| avail_out.get(pred)) {
                    meet.retain(|idx| pred_out.contains(idx));
                }
                meet
            };

            let generated = gen_sets.get(&block_id).unwrap_or(&empty);
            let killed = kill_sets.get(&block_id).unwrap_or(&empty);
            let mut new_out: HashSet<usize> = new_in.difference(killed).copied().collect();
            new_out.extend(generated.iter().copied());

            if avail_in.get(&block_id) != Some(&new_in) {
                changed = true;
                avail_in.insert(block_id, new_in);
            }
            if avail_out.get(&block_id) != Some(&new_out) {
                changed = true;
                avail_out.insert(block_id, new_out);
            }
        }
    }

    // --- Result ---------------------------------------------------------------
    let blocks_result = cfg
        .blocks
        .iter()
        .map(|block| {
            (
                block.id,
                ExpressionSets {
                    in_set: avail_in.get(&block.id).cloned().unwrap_or_default(),
                    out_set: avail_out.get(&block.id).cloned().unwrap_or_default(),
                },
            )
        })
        .collect();

    Ok(AvailableExpressions {
        function: cfg.function.clone(),
        blocks: blocks_result,
        expressions: all_expressions,
    })
}