use std::collections::{HashMap, HashSet};
use crate::{
analysis::ssa::{
evaluator::SsaEvaluator, symbolic::SymbolicExpr, ConstValue, SsaFunction, SsaOp, SsaVarId,
},
metadata::typesystem::PointerSize,
};
/// Scans one [`SsaFunction`] for `Switch`-based dispatcher loops and for
/// `Branch` terminators whose condition folds to a constant.
#[derive(Debug)]
pub struct PatternDetector<'a> {
    /// The function being analyzed.
    ssa: &'a SsaFunction,
    /// Pointer width handed to the evaluator when folding expressions.
    pointer_size: PointerSize,
}
impl<'a> PatternDetector<'a> {
#[must_use]
pub fn new(ssa: &'a SsaFunction, pointer_size: PointerSize) -> Self {
Self { ssa, pointer_size }
}
#[must_use]
pub fn ssa(&self) -> &SsaFunction {
self.ssa
}
#[must_use]
pub fn find_dispatchers(&self) -> Vec<DispatcherPattern> {
let mut dispatchers: Vec<_> = (0..self.ssa.block_count())
.filter_map(|block_idx| self.analyze_potential_dispatcher(block_idx))
.collect();
dispatchers.sort_by_key(|d| d.block);
dispatchers
}
fn analyze_potential_dispatcher(&self, block_idx: usize) -> Option<DispatcherPattern> {
let block = self.ssa.block(block_idx)?;
let terminator = block.terminator()?;
let (switch_var, targets, default) = match terminator.op() {
SsaOp::Switch {
value,
targets,
default,
} => (*value, targets.clone(), *default),
_ => return None,
};
if targets.len() < 2 {
return None;
}
let has_loopback = targets
.iter()
.any(|&target| self.reaches_block(target, block_idx))
|| self.reaches_block(default, block_idx);
if !has_loopback {
return None;
}
let dispatch_expr = self.build_dispatch_expression(block_idx, switch_var);
let state_vars = dispatch_expr
.as_ref()
.map(|e| e.variables().into_iter().collect())
.unwrap_or_default();
Some(DispatcherPattern {
block: block_idx,
switch_var,
targets,
default,
dispatch_expr,
state_vars,
})
}
fn reaches_block(&self, from_block: usize, target_block: usize) -> bool {
let mut visited = HashSet::new();
let mut queue = vec![from_block];
let max_depth = 50; let mut depth = 0;
while !queue.is_empty() && depth < max_depth {
let mut next_queue = Vec::new();
for block_idx in queue {
if block_idx == target_block {
return true;
}
if !visited.insert(block_idx) {
continue;
}
if let Some(successors) = self.block_successors(block_idx) {
next_queue.extend(successors);
}
}
queue = next_queue;
depth += 1;
}
false
}
fn block_successors(&self, block_idx: usize) -> Option<Vec<usize>> {
let block = self.ssa.block(block_idx)?;
let terminator = block.terminator()?;
let successors = match terminator.op() {
SsaOp::Jump { target } => vec![*target],
SsaOp::Branch {
true_target,
false_target,
..
} => vec![*true_target, *false_target],
SsaOp::Switch {
targets, default, ..
} => {
let mut succs: Vec<_> = targets.clone();
succs.push(*default);
succs
}
SsaOp::Return { .. } | SsaOp::Throw { .. } => vec![],
_ => return None,
};
Some(successors)
}
fn build_dispatch_expression(
&self,
block_idx: usize,
switch_var: SsaVarId,
) -> Option<SymbolicExpr> {
let mut eval = SsaEvaluator::new(self.ssa, self.pointer_size);
if let Some(block) = self.ssa.block(block_idx) {
for phi in block.phi_nodes() {
let name = format!("phi_{}", phi.result().index());
eval.set_symbolic(phi.result(), name);
}
}
eval.evaluate_block(block_idx);
eval.get(switch_var).cloned()
}
#[must_use]
pub fn find_sources(&self, dispatcher: &DispatcherPattern) -> Vec<SourceBlock> {
let reaching_blocks = self.find_reaching_blocks(dispatcher.block);
reaching_blocks
.into_iter()
.filter(|&block_idx| block_idx != dispatcher.block)
.filter_map(|block_idx| self.analyze_source_block(block_idx, dispatcher))
.collect()
}
fn find_reaching_blocks(&self, dispatcher_block: usize) -> HashSet<usize> {
let mut reaching = HashSet::new();
let mut predecessors: HashMap<usize, Vec<usize>> = HashMap::new();
for block_idx in 0..self.ssa.block_count() {
if let Some(succs) = self.block_successors(block_idx) {
for succ in succs {
predecessors.entry(succ).or_default().push(block_idx);
}
}
}
let mut queue = vec![dispatcher_block];
while let Some(block_idx) = queue.pop() {
if !reaching.insert(block_idx) {
continue;
}
if let Some(preds) = predecessors.get(&block_idx) {
queue.extend(preds.iter().copied());
}
}
reaching
}
fn analyze_source_block(
&self,
block_idx: usize,
dispatcher: &DispatcherPattern,
) -> Option<SourceBlock> {
let block = self.ssa.block(block_idx)?;
let terminator = block.terminator()?;
let (leads_to_dispatcher, is_conditional) = match terminator.op() {
SsaOp::Jump { target } => (*target == dispatcher.block, false),
SsaOp::Branch {
true_target,
false_target,
..
} => {
let leads = *true_target == dispatcher.block || *false_target == dispatcher.block;
(leads, true)
}
_ => return None,
};
if !leads_to_dispatcher {
return None;
}
let state_value = self.compute_state_value(block_idx, dispatcher);
let target_case = self.compute_target_case(state_value.as_ref(), dispatcher);
Some(SourceBlock {
block: block_idx,
state_value,
target_case,
is_conditional,
})
}
fn compute_state_value(
&self,
block_idx: usize,
dispatcher: &DispatcherPattern,
) -> Option<SymbolicExpr> {
let mut eval = SsaEvaluator::new(self.ssa, self.pointer_size);
if let Some(block) = self.ssa.block(block_idx) {
for phi in block.phi_nodes() {
let name = format!("phi_{}", phi.result().index());
eval.set_symbolic(phi.result(), name);
}
}
eval.evaluate_block(block_idx);
if let Some(state_var) = dispatcher.state_vars.first() {
if let Some(disp_block) = self.ssa.block(dispatcher.block) {
for phi in disp_block.phi_nodes() {
if phi.result() == *state_var {
for operand in phi.operands() {
if operand.predecessor() == block_idx {
return eval.get(operand.value()).cloned();
}
}
}
}
}
}
None
}
fn compute_target_case(
&self,
state_value: Option<&SymbolicExpr>,
dispatcher: &DispatcherPattern,
) -> Option<usize> {
let concrete_state = state_value.and_then(SymbolicExpr::as_constant)?;
let dispatch_expr = dispatcher.dispatch_expr.as_ref()?;
let state_var_names: Vec<String> = dispatcher
.state_vars
.iter()
.map(|v| format!("phi_{}", v.index()))
.collect();
let mut bindings: HashMap<&str, ConstValue> = HashMap::new();
for name in &state_var_names {
bindings.insert(name.as_str(), concrete_state.clone());
}
bindings.insert("state", concrete_state.clone());
let case_idx = dispatch_expr.evaluate_named(&bindings, self.pointer_size)?;
let idx = case_idx.as_i64().and_then(|v| usize::try_from(v).ok())?;
if idx < dispatcher.targets.len() {
Some(idx)
} else {
None }
}
#[must_use]
pub fn find_opaque_predicates(&self) -> Vec<OpaquePredicatePattern> {
(0..self.ssa.block_count())
.filter_map(|block_idx| self.analyze_opaque_predicate(block_idx))
.collect()
}
fn analyze_opaque_predicate(&self, block_idx: usize) -> Option<OpaquePredicatePattern> {
let block = self.ssa.block(block_idx)?;
let terminator = block.terminator()?;
let (condition_var, true_target, false_target) = match terminator.op() {
SsaOp::Branch {
condition,
true_target,
false_target,
} => (*condition, *true_target, *false_target),
_ => return None,
};
let mut eval = SsaEvaluator::new(self.ssa, self.pointer_size);
for phi in block.phi_nodes() {
let name = format!("phi_{}", phi.result().index());
eval.set_symbolic(phi.result(), name);
}
eval.evaluate_block(block_idx);
let condition_value = eval.get(condition_var);
let resolution = match condition_value {
Some(expr) if expr.is_constant() => {
if expr.as_constant().is_some_and(ConstValue::is_zero) {
PredicateResolution::AlwaysFalse
} else {
PredicateResolution::AlwaysTrue
}
}
Some(expr) => PredicateResolution::Symbolic(expr.clone()),
None => PredicateResolution::Unknown,
};
if matches!(
resolution,
PredicateResolution::AlwaysTrue | PredicateResolution::AlwaysFalse
) {
Some(OpaquePredicatePattern {
block: block_idx,
condition_var,
true_target,
false_target,
resolution,
})
} else {
None
}
}
}
/// A `Switch`-terminated block whose successors can flow back to it, as
/// reported by [`PatternDetector::find_dispatchers`].
#[derive(Debug, Clone)]
pub struct DispatcherPattern {
    /// Index of the block holding the `Switch`.
    pub block: usize,
    /// SSA variable the `Switch` dispatches on.
    pub switch_var: SsaVarId,
    /// Case targets: `targets[i]` is the block taken for case `i`.
    pub targets: Vec<usize>,
    /// Fallback block, used for any case index past the end of `targets`.
    pub default: usize,
    /// Symbolic value of `switch_var` after evaluating the dispatcher block,
    /// with its phi results treated as symbols; `None` if it was unavailable.
    pub dispatch_expr: Option<SymbolicExpr>,
    /// Variables referenced by `dispatch_expr` (empty when it is `None`).
    pub state_vars: Vec<SsaVarId>,
}

impl DispatcherPattern {
    /// Number of explicit case targets; the default target is not counted.
    #[must_use]
    pub fn case_count(&self) -> usize {
        self.targets.len()
    }

    /// Block reached for `case_idx`, or `default` when the index is out of range.
    #[must_use]
    pub fn target_for_case(&self, case_idx: usize) -> usize {
        self.targets.get(case_idx).copied().unwrap_or(self.default)
    }
}
/// A block that transfers control directly into a dispatcher, as reported by
/// [`PatternDetector::find_sources`].
#[derive(Debug, Clone)]
pub struct SourceBlock {
    /// Index of the block whose terminator targets the dispatcher.
    pub block: usize,
    /// Value this block feeds into the dispatcher's state phi, if resolved.
    pub state_value: Option<SymbolicExpr>,
    /// In-range case index selected by `state_value`, if it could be computed.
    pub target_case: Option<usize>,
    /// `true` when the edge comes from a `Branch` rather than a `Jump`.
    pub is_conditional: bool,
}
/// A `Branch` whose condition was resolved by
/// [`PatternDetector::find_opaque_predicates`].
#[derive(Debug, Clone)]
pub struct OpaquePredicatePattern {
    /// Index of the block ending in the `Branch`.
    pub block: usize,
    /// SSA variable used as the branch condition.
    pub condition_var: SsaVarId,
    /// Block taken when the condition is true.
    pub true_target: usize,
    /// Block taken when the condition is false.
    pub false_target: usize,
    /// How the condition was resolved.
    pub resolution: PredicateResolution,
}

impl OpaquePredicatePattern {
    /// Splits the branch into `(taken, never_taken)` for a constant resolution.
    /// Returns `None` for symbolic or unknown resolutions.
    fn split_targets(&self) -> Option<(usize, usize)> {
        match self.resolution {
            PredicateResolution::AlwaysTrue => Some((self.true_target, self.false_target)),
            PredicateResolution::AlwaysFalse => Some((self.false_target, self.true_target)),
            PredicateResolution::Symbolic(_) | PredicateResolution::Unknown => None,
        }
    }

    /// Block that control always flows to, if the predicate is constant.
    #[must_use]
    pub fn actual_target(&self) -> Option<usize> {
        self.split_targets().map(|(taken, _)| taken)
    }

    /// Block that can never be reached via this branch, if the predicate is constant.
    #[must_use]
    pub fn dead_target(&self) -> Option<usize> {
        self.split_targets().map(|(_, never_taken)| never_taken)
    }
}
/// Outcome of evaluating a branch condition within its own block.
#[derive(Debug, Clone)]
pub enum PredicateResolution {
    /// The condition folded to a non-zero constant.
    AlwaysTrue,
    /// The condition folded to zero.
    AlwaysFalse,
    /// The condition depends on a non-constant expression, carried here.
    Symbolic(SymbolicExpr),
    /// The evaluator produced no value for the condition.
    Unknown,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::SsaFunctionBuilder;

    /// Builds a pattern in block 0 that branches to block 1 (true) or 2 (false).
    fn opaque_pattern(
        condition_var: SsaVarId,
        resolution: PredicateResolution,
    ) -> OpaquePredicatePattern {
        OpaquePredicatePattern {
            block: 0,
            condition_var,
            true_target: 1,
            false_target: 2,
            resolution,
        }
    }

    #[test]
    fn test_pattern_detector_creation() {
        let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
            f.block(0, |b| b.ret());
        });
        let detector = PatternDetector::new(&ssa, PointerSize::Bit32);
        assert_eq!(detector.ssa().block_count(), 1);
    }

    #[test]
    fn test_find_no_dispatchers_in_simple_function() {
        let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
            f.block(0, |b| {
                let value = b.const_i32(42);
                b.ret_val(value);
            });
        });
        let found = PatternDetector::new(&ssa, PointerSize::Bit32).find_dispatchers();
        assert!(found.is_empty());
    }

    #[test]
    fn test_dispatcher_pattern_methods() {
        let pattern = DispatcherPattern {
            block: 0,
            switch_var: SsaVarId::new(),
            targets: vec![1, 2, 3],
            default: 4,
            dispatch_expr: None,
            state_vars: vec![],
        };
        assert_eq!(pattern.case_count(), 3);
        // In-range cases map to `targets`; anything past the end falls back to `default`.
        for (case_idx, expected) in [(0, 1), (1, 2), (2, 3), (10, 4)] {
            assert_eq!(pattern.target_for_case(case_idx), expected);
        }
    }

    #[test]
    fn test_opaque_predicate_methods() {
        let cond = SsaVarId::new();

        let always_true = opaque_pattern(cond, PredicateResolution::AlwaysTrue);
        assert_eq!(always_true.actual_target(), Some(1));
        assert_eq!(always_true.dead_target(), Some(2));

        let always_false = opaque_pattern(cond, PredicateResolution::AlwaysFalse);
        assert_eq!(always_false.actual_target(), Some(2));
        assert_eq!(always_false.dead_target(), Some(1));
    }
}