use std::collections::HashMap;
use crate::{
analysis::{evaluator::SsaEvaluator, symbolic::SymbolicExpr},
ir::{function::SsaFunction, ops::SsaOp, value::ConstValue, variable::SsaVarId},
target::Target,
BitSet, PointerSize,
};
/// Detects control-flow patterns in a single SSA function.
///
/// The detector only borrows the function. Its methods look for
/// switch-terminated dispatcher blocks, the blocks that jump or branch into
/// such a dispatcher, and branches whose condition evaluates to a constant.
#[derive(Debug)]
pub struct PatternDetector<'a, T: Target> {
/// The SSA function being inspected.
ssa: &'a SsaFunction<T>,
/// Pointer size handed to every `SsaEvaluator` the detector creates and to
/// named expression evaluation.
pointer_size: PointerSize,
}
impl<'a, T: Target> PatternDetector<'a, T> {
#[must_use]
pub fn new(ssa: &'a SsaFunction<T>, pointer_size: PointerSize) -> Self {
Self { ssa, pointer_size }
}
#[must_use]
pub fn ssa(&self) -> &SsaFunction<T> {
self.ssa
}
#[must_use]
pub fn find_dispatchers(&self) -> Vec<DispatcherPattern<T>> {
let mut dispatchers: Vec<_> = (0..self.ssa.block_count())
.filter_map(|block_idx| self.analyze_potential_dispatcher(block_idx))
.collect();
dispatchers.sort_by_key(|d| d.block);
dispatchers
}
/// Checks whether block `block_idx` ends in a switch that acts as a dispatcher.
///
/// A block qualifies when:
/// * its terminator is a `Switch` with at least two case targets, and
/// * at least one case target, or the default target, reaches the block
///   again according to `reaches_block`.
///
/// Returns `None` when the block does not exist, has no terminator, or fails
/// one of the checks above. On success the pattern carries the symbolic
/// expression computed for the switch variable (if any) together with the
/// variables that expression mentions.
fn analyze_potential_dispatcher(&self, block_idx: usize) -> Option<DispatcherPattern<T>> {
    let block = self.ssa.block(block_idx)?;
    let terminator = block.terminator()?;
    // Borrow the switch operands; the target list is cloned only once the
    // block has actually been accepted as a dispatcher.
    let SsaOp::Switch {
        value,
        targets,
        default,
    } = terminator.op()
    else {
        return None;
    };
    if targets.len() < 2 {
        return None;
    }

    let has_loopback = targets
        .iter()
        .any(|&target| self.reaches_block(target, block_idx))
        || self.reaches_block(*default, block_idx);
    if !has_loopback {
        return None;
    }

    let switch_var = *value;
    let dispatch_expr = self.build_dispatch_expression(block_idx, switch_var);
    // Every variable mentioned by the dispatch expression is treated as a
    // state variable; without an expression there are none.
    let state_vars = dispatch_expr
        .as_ref()
        .map(|e| e.variables().into_iter().collect())
        .unwrap_or_default();

    Some(DispatcherPattern {
        block: block_idx,
        switch_var,
        targets: targets.clone(),
        default: *default,
        dispatch_expr,
        state_vars,
    })
}
/// Reports whether `target_block` can be reached from `from_block` by
/// following successor edges.
///
/// The search proceeds level by level and gives up after `MAX_DEPTH` levels,
/// so a target that is only reachable through a longer path is reported as
/// unreachable. `from_block == target_block` counts as reached.
fn reaches_block(&self, from_block: usize, target_block: usize) -> bool {
    /// Maximum number of breadth-first levels that are explored.
    const MAX_DEPTH: u32 = 50;

    let limit = self.ssa.block_count().max(1);
    let mut seen = BitSet::new(limit);
    let mut frontier = vec![from_block];

    for _ in 0..MAX_DEPTH {
        if frontier.is_empty() {
            break;
        }
        let mut upcoming = Vec::new();
        for current in frontier {
            if current == target_block {
                return true;
            }
            // Only expand in-range blocks that have not been expanded yet;
            // the range check must come first so `insert` never sees an
            // out-of-range index.
            if current < limit && seen.insert(current) {
                if let Some(successors) = self.block_successors(current) {
                    upcoming.extend(successors);
                }
            }
        }
        frontier = upcoming;
    }
    false
}
/// Returns the successor block indices of `block_idx`.
///
/// Yields `None` when the block does not exist or has no terminator.
fn block_successors(&self, block_idx: usize) -> Option<Vec<usize>> {
    let block = self.ssa.block(block_idx)?;
    block.terminator().is_some().then(|| block.successors())
}
/// Computes the symbolic expression the evaluator holds for `switch_var`
/// after evaluating block `block_idx`.
///
/// Each phi result of the block is first registered as a named symbol
/// `phi_<index>`, then the block is evaluated. Returns `None` when the
/// evaluator has no value for `switch_var`.
fn build_dispatch_expression(
    &self,
    block_idx: usize,
    switch_var: SsaVarId,
) -> Option<SymbolicExpr<T>> {
    let mut evaluator = SsaEvaluator::new(self.ssa, self.pointer_size);

    if let Some(dispatcher_block) = self.ssa.block(block_idx) {
        for phi in dispatcher_block.phi_nodes() {
            let result = phi.result();
            let symbol = format!("phi_{}", result.index());
            evaluator.set_symbolic(result, symbol);
        }
    }

    evaluator.evaluate_block(block_idx);
    let expr = evaluator.get(switch_var);
    expr.cloned()
}
/// Collects the blocks that feed control into `dispatcher`.
///
/// Every block from which the dispatcher is reachable is a candidate, except
/// the dispatcher block itself; a candidate is kept when
/// `analyze_source_block` accepts it.
#[must_use]
pub fn find_sources(&self, dispatcher: &DispatcherPattern<T>) -> Vec<SourceBlock<T>> {
    let dispatcher_block = dispatcher.block;
    let candidates = self.find_reaching_blocks(dispatcher_block);

    let mut sources = Vec::new();
    for candidate in candidates.iter() {
        if candidate == dispatcher_block {
            continue;
        }
        if let Some(source) = self.analyze_source_block(candidate, dispatcher) {
            sources.push(source);
        }
    }
    sources
}
/// Computes the set of blocks from which `dispatcher_block` is reachable.
///
/// The successor relation is inverted into a predecessor map, which is then
/// walked backwards starting at `dispatcher_block`. The dispatcher block
/// itself is part of the result when its index is in range.
fn find_reaching_blocks(&self, dispatcher_block: usize) -> BitSet {
    let total_blocks = self.ssa.block_count();
    // The bit set is sized for at least one block even for an empty function.
    let capacity = total_blocks.max(1);

    // successor -> list of blocks that have it as a successor
    let mut preds_of: HashMap<usize, Vec<usize>> = HashMap::new();
    for pred in 0..total_blocks {
        for succ in self.block_successors(pred).unwrap_or_default() {
            preds_of.entry(succ).or_default().push(pred);
        }
    }

    let mut reaching = BitSet::new(capacity);
    let mut worklist = vec![dispatcher_block];
    while let Some(current) = worklist.pop() {
        // Skip out-of-range indices and blocks that were already recorded.
        if current < capacity && reaching.insert(current) {
            if let Some(preds) = preds_of.get(&current) {
                worklist.extend_from_slice(preds);
            }
        }
    }
    reaching
}
/// Decides whether `block_idx` transfers control directly to the dispatcher.
///
/// Only `Jump` and `Branch` terminators are considered: a jump must target
/// the dispatcher block, a branch must have it as its true or false target.
/// Any other terminator, a missing block or a missing terminator yields
/// `None`. For an accepted block the state value it passes on and the switch
/// case that value selects are computed as well.
fn analyze_source_block(
    &self,
    block_idx: usize,
    dispatcher: &DispatcherPattern<T>,
) -> Option<SourceBlock<T>> {
    let dispatcher_block = dispatcher.block;
    let source = self.ssa.block(block_idx)?;
    let terminator = source.terminator()?;

    // The match yields whether the transfer is conditional; anything that
    // does not lead straight into the dispatcher falls through to `None`.
    let is_conditional = match terminator.op() {
        SsaOp::Jump { target } if *target == dispatcher_block => false,
        SsaOp::Branch {
            true_target,
            false_target,
            ..
        } if *true_target == dispatcher_block || *false_target == dispatcher_block => true,
        _ => return None,
    };

    let state_value = self.compute_state_value(block_idx, dispatcher);
    let target_case = self.compute_target_case(state_value.as_ref(), dispatcher);

    Some(SourceBlock {
        block: block_idx,
        state_value,
        target_case,
        is_conditional,
    })
}
/// Computes the value that block `block_idx` feeds into the dispatcher's
/// state variable.
///
/// The first entry of `dispatcher.state_vars` is taken as the state variable.
/// The phi node in the dispatcher block that defines it is searched for an
/// operand whose predecessor is `block_idx`; the value of that operand is
/// then looked up after symbolically evaluating `block_idx` (with the block's
/// own phi results registered as `phi_<index>` symbols).
///
/// Returns `None` when there is no state variable, the dispatcher block does
/// not exist, no such phi operand exists, or the evaluator has no value for
/// the operand.
fn compute_state_value(
    &self,
    block_idx: usize,
    dispatcher: &DispatcherPattern<T>,
) -> Option<SymbolicExpr<T>> {
    let state_var = dispatcher.state_vars.first()?;
    let dispatcher_block = self.ssa.block(dispatcher.block)?;

    // Locate the incoming operand first: without one there is nothing to
    // look up, so the evaluation below can be skipped entirely.
    let mut incoming = None;
    'search: for phi in dispatcher_block.phi_nodes() {
        if phi.result() != *state_var {
            continue;
        }
        for operand in phi.operands() {
            if operand.predecessor() == block_idx {
                incoming = Some(operand.value());
                break 'search;
            }
        }
    }
    let incoming = incoming?;

    let mut eval = SsaEvaluator::new(self.ssa, self.pointer_size);
    if let Some(block) = self.ssa.block(block_idx) {
        for phi in block.phi_nodes() {
            let name = format!("phi_{}", phi.result().index());
            eval.set_symbolic(phi.result(), name);
        }
    }
    eval.evaluate_block(block_idx);
    eval.get(incoming).cloned()
}
/// Determines which switch case a constant state value selects.
///
/// The dispatcher's expression is evaluated with every state variable
/// (named `phi_<index>`) and the name `state` bound to the constant. The
/// result is returned only when it converts to a `usize` that is a valid
/// index into `dispatcher.targets`.
///
/// Returns `None` when `state_value` is absent or not constant, the
/// dispatcher has no expression, evaluation fails, or the index is out of
/// range.
fn compute_target_case(
    &self,
    state_value: Option<&SymbolicExpr<T>>,
    dispatcher: &DispatcherPattern<T>,
) -> Option<usize> {
    let constant = state_value.and_then(SymbolicExpr::as_constant)?;
    let expr = dispatcher.dispatch_expr.as_ref()?;

    // The names must outlive the map, which only borrows them.
    let symbol_names: Vec<String> = dispatcher
        .state_vars
        .iter()
        .map(|var| format!("phi_{}", var.index()))
        .collect();
    let mut bindings: HashMap<&str, ConstValue<T>> = symbol_names
        .iter()
        .map(|name| (name.as_str(), constant.clone()))
        .collect();
    bindings.insert("state", constant.clone());

    let selected = expr.evaluate_named(&bindings, self.pointer_size)?;
    let case = selected.as_i64().and_then(|raw| usize::try_from(raw).ok())?;
    (case < dispatcher.targets.len()).then_some(case)
}
/// Scans every block and returns the branches whose condition was resolved
/// to a constant, in ascending block order.
#[must_use]
pub fn find_opaque_predicates(&self) -> Vec<OpaquePredicatePattern<T>> {
    let mut predicates = Vec::new();
    for block_idx in 0..self.ssa.block_count() {
        if let Some(pattern) = self.analyze_opaque_predicate(block_idx) {
            predicates.push(pattern);
        }
    }
    predicates
}
/// Checks whether the branch terminating `block_idx` has a condition that
/// evaluates to a constant.
///
/// The block's phi results are registered as `phi_<index>` symbols and the
/// block is evaluated symbolically. When the evaluator reports a constant
/// for the condition, the branch is classified as `AlwaysFalse` if that
/// constant is zero and as `AlwaysTrue` otherwise.
///
/// Returns `None` when the block does not exist, does not end in a `Branch`,
/// or the condition has no value or a non-constant value.
fn analyze_opaque_predicate(&self, block_idx: usize) -> Option<OpaquePredicatePattern<T>> {
    let block = self.ssa.block(block_idx)?;
    let terminator = block.terminator()?;
    let SsaOp::Branch {
        condition,
        true_target,
        false_target,
    } = terminator.op()
    else {
        return None;
    };
    let condition_var = *condition;

    let mut eval = SsaEvaluator::new(self.ssa, self.pointer_size);
    for phi in block.phi_nodes() {
        let name = format!("phi_{}", phi.result().index());
        eval.set_symbolic(phi.result(), name);
    }
    eval.evaluate_block(block_idx);

    // Only constant conditions produce a pattern, so bail out early instead
    // of building (and cloning into) a resolution that would be discarded.
    let expr = eval.get(condition_var)?;
    if !expr.is_constant() {
        return None;
    }
    let resolution = if expr.as_constant().is_some_and(ConstValue::is_zero) {
        PredicateResolution::AlwaysFalse
    } else {
        PredicateResolution::AlwaysTrue
    };

    Some(OpaquePredicatePattern {
        block: block_idx,
        condition_var,
        true_target: *true_target,
        false_target: *false_target,
        resolution,
    })
}
}
/// A switch-terminated block recognised as a dispatcher: it has at least two
/// case targets and at least one target (or the default) leads back to it.
#[derive(Debug, Clone)]
pub struct DispatcherPattern<T: Target> {
/// Index of the block that ends in the switch.
pub block: usize,
/// Variable the switch selects on.
pub switch_var: SsaVarId,
/// Case target block indices, copied from the switch.
pub targets: Vec<usize>,
/// Default target block index of the switch.
pub default: usize,
/// Symbolic expression obtained for `switch_var` by evaluating the
/// dispatcher block, or `None` when the evaluator had no value for it.
pub dispatch_expr: Option<SymbolicExpr<T>>,
/// Variables mentioned by `dispatch_expr`; empty when there is no expression.
pub state_vars: Vec<SsaVarId>,
}
impl<T: Target> DispatcherPattern<T> {
    /// Number of case targets of the switch (the default target is not counted).
    #[must_use]
    pub fn case_count(&self) -> usize {
        let Self { targets, .. } = self;
        targets.len()
    }

    /// Block index selected by case `case_idx`; falls back to the default
    /// target when `case_idx` is not a valid case index.
    #[must_use]
    pub fn target_for_case(&self, case_idx: usize) -> usize {
        match self.targets.get(case_idx) {
            Some(&target) => target,
            None => self.default,
        }
    }
}
/// A block whose jump or branch leads directly into a dispatcher block.
#[derive(Debug, Clone)]
pub struct SourceBlock<T: Target> {
/// Index of the source block.
pub block: usize,
/// Value this block contributes to the dispatcher's state variable through
/// the dispatcher's phi node, when one could be determined.
pub state_value: Option<SymbolicExpr<T>>,
/// Index into the dispatcher's `targets` selected by `state_value`; only set
/// when the state value is constant and the computed index is in range.
pub target_case: Option<usize>,
/// `true` when the block ends in a `Branch`, `false` when it ends in a `Jump`.
pub is_conditional: bool,
}
/// A branch together with what is known about its condition.
#[derive(Debug, Clone)]
pub struct OpaquePredicatePattern<T: Target> {
/// Index of the block that ends in the branch.
pub block: usize,
/// Variable used as the branch condition.
pub condition_var: SsaVarId,
/// Block index taken when the condition holds.
pub true_target: usize,
/// Block index taken when the condition does not hold.
pub false_target: usize,
/// What evaluation determined about the condition.
pub resolution: PredicateResolution<T>,
}
impl<T: Target> OpaquePredicatePattern<T> {
    /// Target the branch always takes, or `None` when the resolution is
    /// neither `AlwaysTrue` nor `AlwaysFalse`.
    #[must_use]
    pub fn actual_target(&self) -> Option<usize> {
        match &self.resolution {
            PredicateResolution::AlwaysTrue => Some(self.true_target),
            PredicateResolution::AlwaysFalse => Some(self.false_target),
            PredicateResolution::Symbolic(_) | PredicateResolution::Unknown => None,
        }
    }

    /// Target the branch never takes, or `None` when the resolution is
    /// neither `AlwaysTrue` nor `AlwaysFalse`.
    #[must_use]
    pub fn dead_target(&self) -> Option<usize> {
        match &self.resolution {
            PredicateResolution::AlwaysTrue => Some(self.false_target),
            PredicateResolution::AlwaysFalse => Some(self.true_target),
            PredicateResolution::Symbolic(_) | PredicateResolution::Unknown => None,
        }
    }
}
/// Outcome of evaluating a branch condition.
#[derive(Debug, Clone)]
pub enum PredicateResolution<T: Target> {
/// The condition is a constant that is not zero.
AlwaysTrue,
/// The condition is the constant zero.
AlwaysFalse,
/// The condition has a value that is not a constant; carries that expression.
Symbolic(SymbolicExpr<T>),
/// No value is available for the condition.
Unknown,
}