use std::collections::HashSet;
use crate::{
ir::{
function::SsaFunction,
ops::SsaOp,
variable::{SsaVarId, VariableOrigin},
},
target::Target,
};
/// Policy controlling how taint crosses SSA phi nodes during propagation.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum PhiTaintMode {
    /// When a phi's result is tainted, taint every incoming operand and mark
    /// the phi. A tainted operand alone does not taint the result in this mode.
    TaintAllOperands,
    /// When any incoming operand is tainted, taint the phi's result and mark
    /// the phi.
    TaintIfAnyOperand,
    /// When a phi's result is tainted, taint only the operands arriving from
    /// the listed predecessor blocks, and mark the phi.
    TaintFromPredecessors(HashSet<usize>),
    /// Like `TaintFromPredecessors`, but only applied to phis whose origin
    /// equals `origin_filter` (`None` accepts every phi).
    SelectivePhi {
        /// Predecessor blocks whose operands receive taint.
        predecessors: HashSet<usize>,
        /// Optional origin a phi must have to be followed.
        origin_filter: Option<VariableOrigin>,
    },
    /// Phi nodes are ignored by propagation.
    NoPropagation,
}
/// Settings that drive a [`TaintAnalysis`] run.
#[derive(Debug)]
#[derive(Clone)]
pub struct TaintConfig {
    /// Propagate from an instruction's used variables to the variable it defines.
    pub forward: bool,
    /// Propagate from a defined variable back to the instruction's used
    /// variables (and from a tainted `StoreElement` array to the store's operands).
    pub backward: bool,
    /// How taint is carried across phi nodes.
    pub phi_mode: PhiTaintMode,
    /// Upper bound on the number of propagation passes `propagate` performs.
    pub max_iterations: usize,
}
impl Default for TaintConfig {
fn default() -> Self {
Self {
forward: true,
backward: false,
phi_mode: PhiTaintMode::TaintIfAnyOperand,
max_iterations: 100,
}
}
}
impl TaintConfig {
    /// Forward-only configuration.
    ///
    /// Taint flows from used variables to defined variables. A phi result
    /// becomes tainted when any incoming operand is tainted. At most 100
    /// propagation passes are run.
    #[must_use]
    pub fn forward_only() -> Self {
        let phi_mode = PhiTaintMode::TaintIfAnyOperand;
        Self {
            max_iterations: 100,
            phi_mode,
            backward: false,
            forward: true,
        }
    }

    /// Bidirectional configuration.
    ///
    /// Same as [`TaintConfig::forward_only`], but also propagates from defined
    /// variables back to their uses. Phis use `TaintAllOperands`, so a tainted
    /// phi result taints all of its operands.
    #[must_use]
    pub fn bidirectional() -> Self {
        Self {
            backward: true,
            phi_mode: PhiTaintMode::TaintAllOperands,
            ..Self::forward_only()
        }
    }
}
/// Summary of the most recent propagation run.
#[derive(Debug, Clone)]
#[derive(Default)]
pub struct TaintStats {
    /// Number of propagation passes executed.
    pub iterations: usize,
    /// Number of tainted variables after propagation.
    pub tainted_vars: usize,
    /// Number of tainted `(block, instruction)` positions after propagation.
    pub tainted_instrs: usize,
    /// Number of tainted `(block, phi)` positions after propagation.
    pub tainted_phis: usize,
}
/// Taint-propagation state over an SSA function: seeded items plus everything
/// reached by `propagate`.
#[derive(Debug)]
#[derive(Clone)]
pub struct TaintAnalysis {
    /// Tainted SSA variables.
    tainted_vars: HashSet<SsaVarId>,
    /// Tainted instructions, keyed by `(block index, instruction index)`.
    tainted_instrs: HashSet<(usize, usize)>,
    /// Tainted phi nodes, keyed by `(block index, phi index)`.
    tainted_phis: HashSet<(usize, usize)>,
    /// Settings controlling propagation.
    config: TaintConfig,
    /// Statistics from the last propagation run.
    stats: TaintStats,
}
impl TaintAnalysis {
    /// Creates an analysis driven by `config`, with no tainted items and
    /// zeroed statistics.
    #[must_use]
    pub fn new(config: TaintConfig) -> Self {
        Self {
            tainted_vars: HashSet::new(),
            tainted_instrs: HashSet::new(),
            tainted_phis: HashSet::new(),
            config,
            stats: TaintStats::default(),
        }
    }

    /// Creates an empty analysis using [`TaintConfig::forward_only`].
    #[must_use]
    pub fn forward_only() -> Self {
        Self::new(TaintConfig::forward_only())
    }

    /// Creates an empty analysis using [`TaintConfig::bidirectional`].
    #[must_use]
    pub fn bidirectional() -> Self {
        Self::new(TaintConfig::bidirectional())
    }

    /// Seeds a single variable as tainted.
    pub fn add_tainted_var(&mut self, var: SsaVarId) {
        self.tainted_vars.insert(var);
    }

    /// Seeds every variable yielded by `vars` as tainted.
    pub fn add_tainted_vars(&mut self, vars: impl IntoIterator<Item = SsaVarId>) {
        self.tainted_vars.extend(vars);
    }

    /// Seeds instruction `instr` of block `block` as tainted.
    ///
    /// The `(block, instr)` pair is recorded even if it does not exist in
    /// `ssa`. When it does exist, the variable it defines (if any) is tainted.
    /// In backward mode, every variable it uses is tainted as well.
    pub fn add_tainted_instr<T: Target>(
        &mut self,
        block: usize,
        instr: usize,
        ssa: &SsaFunction<T>,
    ) {
        self.tainted_instrs.insert((block, instr));
        if let Some(block_data) = ssa.block(block) {
            if let Some(instruction) = block_data.instructions().get(instr) {
                if let Some(def) = instruction.def() {
                    self.tainted_vars.insert(def);
                }
                if self.config.backward {
                    for use_var in instruction.uses() {
                        self.tainted_vars.insert(use_var);
                    }
                }
            }
        }
    }

    /// Seeds the phi node at index `phi_idx` of block `block` as tainted.
    ///
    /// The `(block, phi_idx)` pair is recorded even if it does not exist in
    /// `ssa`. When it does exist, the phi's result variable is tainted too.
    /// Operands are not touched here; [`Self::propagate`] handles them
    /// according to the configured [`PhiTaintMode`].
    pub fn add_tainted_phi<T: Target>(
        &mut self,
        block: usize,
        phi_idx: usize,
        ssa: &SsaFunction<T>,
    ) {
        self.tainted_phis.insert((block, phi_idx));
        if let Some(block_data) = ssa.block(block) {
            if let Some(phi) = block_data.phi_nodes().get(phi_idx) {
                self.tainted_vars.insert(phi.result());
            }
        }
    }

    /// Propagates taint through `ssa` and then refreshes [`Self::stats`].
    ///
    /// Propagation stops at a fixpoint or after `config.max_iterations` passes,
    /// whichever comes first. Each pass handles phi nodes first, then ordinary
    /// instructions. Hitting the pass cap is not reported separately: when
    /// `stats().iterations` equals `max_iterations`, convergence is not
    /// guaranteed.
    pub fn propagate<T: Target>(&mut self, ssa: &SsaFunction<T>) {
        let mut iterations: usize = 0;
        loop {
            if iterations >= self.config.max_iterations {
                break;
            }
            iterations = iterations.saturating_add(1);
            let mut changed = false;
            changed |= self.propagate_phis(ssa);
            changed |= self.propagate_instructions(ssa);
            if !changed {
                break;
            }
        }
        self.stats = TaintStats {
            iterations,
            tainted_vars: self.tainted_vars.len(),
            tainted_instrs: self.tainted_instrs.len(),
            tainted_phis: self.tainted_phis.len(),
        };
    }

    /// Runs one pass of phi-node propagation under the configured
    /// [`PhiTaintMode`] and returns `true` if any variable or phi became
    /// tainted.
    fn propagate_phis<T: Target>(&mut self, ssa: &SsaFunction<T>) -> bool {
        let mut changed = false;
        for (block_idx, block) in ssa.blocks().iter().enumerate() {
            for (phi_idx, phi) in block.phi_nodes().iter().enumerate() {
                let result = phi.result();
                let result_tainted = self.tainted_vars.contains(&result);
                // Borrowing `self.config` while mutating the sibling taint sets
                // is fine: these are disjoint field borrows.
                match &self.config.phi_mode {
                    PhiTaintMode::TaintAllOperands => {
                        // Result -> every operand.
                        if result_tainted {
                            for operand in phi.operands() {
                                if self.tainted_vars.insert(operand.value()) {
                                    changed = true;
                                }
                            }
                            if self.tainted_phis.insert((block_idx, phi_idx)) {
                                changed = true;
                            }
                        }
                    }
                    PhiTaintMode::TaintIfAnyOperand => {
                        // Any operand -> result.
                        let any_operand_tainted = phi
                            .operands()
                            .iter()
                            .any(|op| self.tainted_vars.contains(&op.value()));
                        if any_operand_tainted {
                            if self.tainted_vars.insert(result) {
                                changed = true;
                            }
                            if self.tainted_phis.insert((block_idx, phi_idx)) {
                                changed = true;
                            }
                        }
                    }
                    PhiTaintMode::TaintFromPredecessors(preds) => {
                        // Result -> operands arriving from the selected predecessors.
                        if result_tainted {
                            for operand in phi.operands() {
                                if preds.contains(&operand.predecessor())
                                    && self.tainted_vars.insert(operand.value())
                                {
                                    changed = true;
                                }
                            }
                            if self.tainted_phis.insert((block_idx, phi_idx)) {
                                changed = true;
                            }
                        }
                    }
                    PhiTaintMode::SelectivePhi {
                        predecessors,
                        origin_filter,
                    } => {
                        // Like TaintFromPredecessors, restricted to phis whose
                        // origin matches the filter (no filter = every phi).
                        if result_tainted {
                            let should_follow = origin_filter
                                .as_ref()
                                .is_none_or(|filter| phi.origin() == *filter);
                            if should_follow {
                                for operand in phi.operands() {
                                    if predecessors.contains(&operand.predecessor())
                                        && self.tainted_vars.insert(operand.value())
                                    {
                                        changed = true;
                                    }
                                }
                                if self.tainted_phis.insert((block_idx, phi_idx)) {
                                    changed = true;
                                }
                            }
                        }
                    }
                    PhiTaintMode::NoPropagation => {}
                }
            }
        }
        changed
    }

    /// Runs one pass of instruction-level propagation and returns `true` if
    /// any variable or instruction became tainted.
    ///
    /// Rules:
    /// * Any instruction that reads a tainted variable is marked tainted,
    ///   whatever the direction settings.
    /// * Forward: such an instruction also taints the variable it defines.
    /// * Backward: an instruction is marked and taints every variable it uses
    ///   when either its defined variable is tainted, or it is a
    ///   `StoreElement` into a tainted array.
    fn propagate_instructions<T: Target>(&mut self, ssa: &SsaFunction<T>) -> bool {
        let mut changed = false;
        for (block_idx, instr_idx, instr) in ssa.iter_instructions() {
            let key = (block_idx, instr_idx);
            let def = instr.def();
            let uses = instr.uses();
            let uses_tainted = uses.iter().any(|u| self.tainted_vars.contains(u));

            if uses_tainted {
                changed |= self.tainted_instrs.insert(key);
                if self.config.forward {
                    if let Some(def_var) = def {
                        changed |= self.tainted_vars.insert(def_var);
                    }
                }
            }

            if self.config.backward {
                // Checked after the forward step so a definition tainted just
                // above is propagated back to the uses in this same pass.
                let def_tainted = def.is_some_and(|d| self.tainted_vars.contains(&d));
                let stores_into_tainted_array = matches!(
                    instr.op(),
                    SsaOp::StoreElement { array, .. } if self.tainted_vars.contains(array)
                );
                // Decide purely from taint state, never from whether the
                // instruction was *newly* marked. A store may already be marked
                // (because a stored value was tainted, or it was seeded) before
                // its array becomes tainted, and its operands must still
                // receive taint at that point.
                if def_tainted || stores_into_tainted_array {
                    changed |= self.tainted_instrs.insert(key);
                    for use_var in &uses {
                        changed |= self.tainted_vars.insert(*use_var);
                    }
                }
            }
        }
        changed
    }

    /// Returns `true` if `var` is tainted.
    #[must_use]
    pub fn is_var_tainted(&self, var: SsaVarId) -> bool {
        self.tainted_vars.contains(&var)
    }

    /// Returns `true` if instruction `instr` of block `block` is tainted.
    #[must_use]
    pub fn is_instr_tainted(&self, block: usize, instr: usize) -> bool {
        self.tainted_instrs.contains(&(block, instr))
    }

    /// Returns `true` if phi `phi_idx` of block `block` is tainted.
    #[must_use]
    pub fn is_phi_tainted(&self, block: usize, phi_idx: usize) -> bool {
        self.tainted_phis.contains(&(block, phi_idx))
    }

    /// All tainted variables.
    #[must_use]
    pub fn tainted_variables(&self) -> &HashSet<SsaVarId> {
        &self.tainted_vars
    }

    /// All tainted instructions as `(block index, instruction index)` pairs.
    #[must_use]
    pub fn tainted_instructions(&self) -> &HashSet<(usize, usize)> {
        &self.tainted_instrs
    }

    /// All tainted phi nodes as `(block index, phi index)` pairs.
    #[must_use]
    pub fn tainted_phis(&self) -> &HashSet<(usize, usize)> {
        &self.tainted_phis
    }

    /// Statistics from the most recent [`Self::propagate`] call. They are
    /// zeroed by `new` and `clear`.
    #[must_use]
    pub fn stats(&self) -> &TaintStats {
        &self.stats
    }

    /// Current number of tainted variables.
    #[must_use]
    pub fn tainted_var_count(&self) -> usize {
        self.tainted_vars.len()
    }

    /// Current number of tainted instructions.
    #[must_use]
    pub fn tainted_instr_count(&self) -> usize {
        self.tainted_instrs.len()
    }

    /// Removes all taint and resets statistics. The configuration is kept.
    pub fn clear(&mut self) {
        self.tainted_vars.clear();
        self.tainted_instrs.clear();
        self.tainted_phis.clear();
        self.stats = TaintStats::default();
    }
}
/// Collects the ids of every block whose last instruction can transfer
/// control to block `target`.
///
/// The recognised terminators are `Jump`, `Leave`, `Branch`, `BranchCmp`
/// (either arm), and `Switch` (any case or the default). Blocks with no
/// instructions, or whose last instruction is anything else, are skipped.
/// NOTE(review): results are `block.id()` values, which callers compare with
/// phi operand predecessors. This presumes ids and block indices agree; confirm.
#[must_use]
pub fn find_blocks_jumping_to<T: Target>(ssa: &SsaFunction<T>, target: usize) -> HashSet<usize> {
    ssa.blocks()
        .iter()
        .filter(|block| {
            block
                .instructions()
                .last()
                .is_some_and(|terminator| match terminator.op() {
                    SsaOp::Jump { target: dest } | SsaOp::Leave { target: dest } => {
                        *dest == target
                    }
                    SsaOp::Branch {
                        true_target,
                        false_target,
                        ..
                    }
                    | SsaOp::BranchCmp {
                        true_target,
                        false_target,
                        ..
                    } => *true_target == target || *false_target == target,
                    SsaOp::Switch {
                        targets, default, ..
                    } => *default == target || targets.contains(&target),
                    _ => false,
                })
        })
        .map(|block| block.id())
        .collect()
}
/// Builds a forward-only taint configuration focused on a dispatcher block
/// (presumably for control-flow-flattening analysis, as the name suggests).
///
/// Phi propagation uses `SelectivePhi`. Only operands arriving from blocks
/// that jump to `dispatcher_block` are followed, and only for phis whose
/// origin equals `state_origin` (when `Some`). At most 100 passes are run.
#[must_use]
pub fn cff_taint_config<T: Target>(
    ssa: &SsaFunction<T>,
    dispatcher_block: usize,
    state_origin: Option<VariableOrigin>,
) -> TaintConfig {
    let phi_mode = PhiTaintMode::SelectivePhi {
        predecessors: find_blocks_jumping_to(ssa, dispatcher_block),
        origin_filter: state_origin,
    };
    TaintConfig {
        max_iterations: 100,
        phi_mode,
        backward: false,
        forward: true,
    }
}