use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::collections::btree_map::Entry;
use droidsaw_hermes::decompile::ssa::{Phase, Raw, Resolved, SsaFunction, SsaOperand, VarId};
use droidsaw_hermes::opcodes::OpCode;
use droidsaw_common::analysis::{TaintFinding, TaintSink, TaintSource};
use droidsaw_common::cross_layer_taint::{NativeModuleMethodName, NativeModuleName};
use droidsaw_common::finding::Layer;
/// Why a tainted call site could not be resolved to a
/// `<module>.<method>` native-module sink by `emit_bridge_sinks`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackwalkFailureReason {
/// Operand 1 of the call op (the slot read as the callee) is missing or is
/// not an `SsaOperand::Var`.
CalleeNotVar,
/// One of the first two hops of the chain could not be read as a
/// get-by-id op with a variable receiver and a resolved string name, or
/// the module/method string was rejected by the corresponding `try_new`.
ChainExtractionFailed,
/// The variable reached after two hops is itself defined by a get-by-id op
/// whose last operand is not the resolved string `"NativeModules"`.
TerminalHopIsGetById,
}
/// Location and cause of one failed callee back-walk recorded by
/// `emit_bridge_sinks`.
#[derive(Debug, Clone, Copy)]
pub struct BackwalkFailureSite {
/// Id of the function being analysed, as passed in by the caller.
pub func_id: u32,
/// Zero-based position of the call op when all ops of the function are
/// counted in block order (not an index within a single block).
pub op_index: usize,
/// Why the back-walk from this call op failed.
pub reason: BackwalkFailureReason,
}
/// Result of running the taint analysis over a single SSA function.
pub struct HbcTaintAnalysis {
/// Taint findings emitted for the function (eval sinks, plus native-module
/// argument sinks when produced by `run_full`).
pub findings: Vec<TaintFinding>,
/// Call sites with tainted arguments whose callee could not be resolved;
/// always empty for `run_eval_only`.
pub backwalk_failures: Vec<BackwalkFailureSite>,
}
impl HbcTaintAnalysis {
    /// Propagates `seeds` through `ssa` and reports only `DirectEval` sinks.
    ///
    /// No bridge back-walk is attempted, so `backwalk_failures` is always
    /// empty in the returned value.
    pub fn run_eval_only(
        ssa: &SsaFunction<Raw>,
        func_id: u32,
        layer: Layer,
        seeds: BTreeMap<VarId, TaintSource>,
    ) -> Self {
        let reached = propagate_to_fixed_point(ssa, seeds);
        Self {
            findings: emit_eval_sinks(ssa, &reached, layer, func_id),
            backwalk_failures: Vec::new(),
        }
    }

    /// Propagates `seeds` through a resolved `ssa` and reports `DirectEval`
    /// sinks, followed by native-module argument sinks when `func_id` is a
    /// member of `bridge_func_ids`.
    ///
    /// Call sites whose callee chain could not be resolved during the bridge
    /// pass are returned in `backwalk_failures`.
    pub fn run_full(
        ssa: &SsaFunction<Resolved>,
        func_id: u32,
        layer: Layer,
        seeds: BTreeMap<VarId, TaintSource>,
        bridge_func_ids: &BTreeSet<u32>,
    ) -> Self {
        let reached = propagate_to_fixed_point(ssa, seeds);
        let mut backwalk_failures = Vec::new();
        let mut findings = emit_eval_sinks(ssa, &reached, layer, func_id);

        // The bridge pass only runs for functions the caller flagged.
        if bridge_func_ids.contains(&func_id) {
            findings.extend(emit_bridge_sinks(
                ssa,
                &reached,
                layer,
                func_id,
                &mut backwalk_failures,
            ));
        }

        Self {
            findings,
            backwalk_failures,
        }
    }
}
/// Propagates taint forward through `ssa` until a full pass adds nothing new.
///
/// Starting from `seeds`, a phi destination inherits the source of its first
/// tainted argument, and an op destination inherits the source of its first
/// tainted `SsaOperand::Var` operand. A variable keeps the first source it
/// receives; later candidates never overwrite it.
///
/// Returns the map of every tainted variable to its recorded source.
fn propagate_to_fixed_point<P: Phase>(
    ssa: &SsaFunction<P>,
    seeds: BTreeMap<VarId, TaintSource>,
) -> BTreeMap<VarId, TaintSource> {
    let mut taints = seeds;
    let mut changed = true;
    while changed {
        changed = false;
        for block in &ssa.blocks {
            for phi in &block.phis {
                // Already tainted: nothing can change, so skip before cloning
                // any source.
                if taints.contains_key(&phi.dst) {
                    continue;
                }
                for (_pred, arg) in &phi.args {
                    let Some(src) = taints.get(arg).cloned() else {
                        continue;
                    };
                    // First-writer-wins insert.
                    if let Entry::Vacant(slot) = taints.entry(phi.dst) {
                        slot.insert(src);
                        changed = true;
                    }
                    // The destination is now tainted; remaining arguments
                    // cannot affect it.
                    break;
                }
            }
            for op in &block.ops {
                let Some(dst) = op.dst else { continue };
                if taints.contains_key(&dst) {
                    continue;
                }
                // Only the first tainted variable operand matters, so clone
                // exactly one source per newly tainted destination.
                let inherited = op.operands.iter().find_map(|operand| match operand {
                    SsaOperand::Var(v) => taints.get(v).cloned(),
                    _ => None,
                });
                let Some(src) = inherited else { continue };
                if let Entry::Vacant(slot) = taints.entry(dst) {
                    slot.insert(src);
                    changed = true;
                }
            }
        }
    }
    taints
}
/// Emits one `TaintSink::Eval` finding for every tainted variable operand of
/// every `DirectEval` op in `ssa`.
///
/// Findings are produced in block order, then op order, then operand order.
/// `layer` and `func_id` are copied into each finding; the descriptor,
/// signature and offset fields are left as `None`.
fn emit_eval_sinks<P: Phase>(
    ssa: &SsaFunction<P>,
    taints: &BTreeMap<VarId, TaintSource>,
    layer: Layer,
    func_id: u32,
) -> Vec<TaintFinding> {
    let mut findings = Vec::new();

    let eval_ops = ssa
        .blocks
        .iter()
        .flat_map(|block| block.ops.iter())
        .filter(|op| op.op == OpCode::DirectEval);

    for eval_op in eval_ops {
        // Sources of the operands that are tainted variables, in operand order.
        let tainted_sources = eval_op.operands.iter().filter_map(|operand| match operand {
            SsaOperand::Var(v) => taints.get(v),
            _ => None,
        });

        findings.extend(tainted_sources.map(|src| TaintFinding {
            source: src.clone(),
            sink: TaintSink::Eval,
            layer,
            func_id,
            class_descriptor: None,
            method_signature: None,
            source_offset: None,
            sink_offset: None,
        }));
    }

    findings
}
/// Emits `TaintSink::NativeModuleArg` findings for call ops that receive
/// tainted arguments and whose callee resolves to a module/method chain.
///
/// For each call op (per `is_call_op`), operands from index 3 onward are
/// treated as arguments; the positions recorded in the finding are relative
/// to that offset. Call ops with no tainted argument are skipped silently.
/// The finding's source is that of the first tainted argument.
///
/// When operand 1 is not a variable, or the chain back-walk fails, a
/// `BackwalkFailureSite` is appended to `failures` instead of a finding. Its
/// `op_index` counts ops across all blocks in order.
fn emit_bridge_sinks(
    ssa: &SsaFunction<Resolved>,
    taints: &BTreeMap<VarId, TaintSource>,
    layer: Layer,
    func_id: u32,
    failures: &mut Vec<BackwalkFailureSite>,
) -> Vec<TaintFinding> {
    // Variable -> (block index, op index within the block) of its defining op.
    let mut def_map: BTreeMap<VarId, (usize, usize)> = BTreeMap::new();
    for (block_idx, block) in ssa.blocks.iter().enumerate() {
        for (op_idx, op) in block.ops.iter().enumerate() {
            if let Some(dst) = op.dst {
                def_map.insert(dst, (block_idx, op_idx));
            }
        }
    }

    let mut findings = Vec::new();
    let all_ops = ssa.blocks.iter().flat_map(|block| block.ops.iter());

    for (op_index, op) in all_ops.enumerate() {
        if !is_call_op(op.op) {
            continue;
        }

        // (argument position, source) for every tainted variable argument.
        let tainted_args: Vec<(usize, &TaintSource)> = op
            .operands
            .iter()
            .skip(3)
            .enumerate()
            .filter_map(|(pos, operand)| match operand {
                SsaOperand::Var(v) => taints.get(v).map(|src| (pos, src)),
                _ => None,
            })
            .collect();

        let Some(&(_, first_src)) = tainted_args.first() else {
            continue;
        };
        let source: TaintSource = first_src.clone();
        let arg_positions: BTreeSet<usize> = tainted_args.iter().map(|&(pos, _)| pos).collect();

        let Some(SsaOperand::Var(callee_var)) = op.operands.get(1) else {
            failures.push(BackwalkFailureSite {
                func_id,
                op_index,
                reason: BackwalkFailureReason::CalleeNotVar,
            });
            continue;
        };

        match extract_native_modules_chain(*callee_var, &def_map, ssa) {
            ChainResult::Failed(reason) => failures.push(BackwalkFailureSite {
                func_id,
                op_index,
                reason,
            }),
            ChainResult::Resolved(module, method) => findings.push(TaintFinding {
                source,
                sink: TaintSink::NativeModuleArg {
                    module,
                    method,
                    arg_positions,
                },
                layer,
                func_id,
                class_descriptor: None,
                method_signature: None,
                source_offset: None,
                sink_offset: None,
            }),
        }
    }

    findings
}
/// Outcome of `extract_native_modules_chain`.
enum ChainResult {
/// The callee resolved to a validated module name and method name.
Resolved(NativeModuleName, NativeModuleMethodName),
/// The back-walk failed for the given reason.
Failed(BackwalkFailureReason),
}
/// Reads one hop of a property-access chain ending at `var`.
///
/// Looks up the op that defines `var` through `def_map`. If it is a get-by-id
/// op (per `is_get_by_id_op`), returns the variable in operand 1 together
/// with a copy of the resolved string in the last operand.
///
/// Returns `None` when `var` has no recorded definition, the recorded indices
/// do not exist in `ssa`, the defining op is not a get-by-id op, operand 1 is
/// not a variable, or the last operand is not a resolved string.
fn read_get_by_id_chain_step(
    var: VarId,
    def_map: &BTreeMap<VarId, (usize, usize)>,
    ssa: &SsaFunction<Resolved>,
) -> Option<(VarId, String)> {
    let &(block_idx, op_idx) = def_map.get(&var)?;
    let def_op = ssa
        .blocks
        .get(block_idx)
        .and_then(|block| block.ops.get(op_idx))?;

    if !is_get_by_id_op(def_op.op) {
        return None;
    }

    let Some(SsaOperand::Var(receiver)) = def_op.operands.get(1) else {
        return None;
    };
    let Some(SsaOperand::ResolvedString(prop)) = def_op.operands.last() else {
        return None;
    };

    Some((*receiver, prop.clone()))
}
/// Walks back from `callee_var` through two get-by-id hops to recover a
/// `(module, method)` pair.
///
/// The first hop yields the method name and the second the module name. If
/// the variable reached after the second hop is itself defined by a get-by-id
/// op, that op's last operand must be the resolved string `"NativeModules"`;
/// otherwise the result is `Failed(TerminalHopIsGetById)`. When that variable
/// is not defined by a get-by-id op, no further check is made.
///
/// Any hop that cannot be read, or a name rejected by `try_new`, yields
/// `Failed(ChainExtractionFailed)`.
fn extract_native_modules_chain(
    callee_var: VarId,
    def_map: &BTreeMap<VarId, (usize, usize)>,
    ssa: &SsaFunction<Resolved>,
) -> ChainResult {
    // Hop 1: callee -> (module var, method name); hop 2: module var -> (outer var, module name).
    let hops = read_get_by_id_chain_step(callee_var, def_map, ssa).and_then(
        |(module_var, method_str)| {
            read_get_by_id_chain_step(module_var, def_map, ssa)
                .map(|(outer_var, module_str)| (outer_var, module_str, method_str))
        },
    );
    let Some((outer_var, module_str, method_str)) = hops else {
        return ChainResult::Failed(BackwalkFailureReason::ChainExtractionFailed);
    };

    // Optional hop 3: only constrained when the outer variable comes from a get-by-id op.
    let third_hop = def_map
        .get(&outer_var)
        .and_then(|&(block_idx, op_idx)| ssa.blocks.get(block_idx)?.ops.get(op_idx))
        .filter(|op| is_get_by_id_op(op.op));
    if let Some(op) = third_hop {
        let names_native_modules = matches!(
            op.operands.last(),
            Some(SsaOperand::ResolvedString(s)) if s.as_str() == "NativeModules"
        );
        if !names_native_modules {
            return ChainResult::Failed(BackwalkFailureReason::TerminalHopIsGetById);
        }
    }

    // Validate the module name first, then the method name.
    NativeModuleName::try_new(module_str)
        .and_then(|module| {
            NativeModuleMethodName::try_new(method_str).map(|method| (module, method))
        })
        .map_or(
            ChainResult::Failed(BackwalkFailureReason::ChainExtractionFailed),
            |(module, method)| ChainResult::Resolved(module, method),
        )
}
/// Returns `true` when `op` is one of the five `GetById`-family opcodes
/// listed in the table below.
fn is_get_by_id_op(op: OpCode) -> bool {
    const GET_BY_ID_OPS: [OpCode; 5] = [
        OpCode::GetById,
        OpCode::GetByIdShort,
        OpCode::GetByIdLong,
        OpCode::TryGetById,
        OpCode::TryGetByIdLong,
    ];
    GET_BY_ID_OPS.contains(&op)
}
/// Returns `true` when `op` is one of the call opcodes listed in the table
/// below; these are the ops `emit_bridge_sinks` inspects for tainted arguments.
fn is_call_op(op: OpCode) -> bool {
    const CALL_OPS: [OpCode; 10] = [
        OpCode::Call,
        OpCode::CallLong,
        OpCode::Call1,
        OpCode::Call2,
        OpCode::Call3,
        OpCode::Call4,
        OpCode::CallDirect,
        OpCode::CallDirectLongIndex,
        OpCode::CallWithNewTarget,
        OpCode::CallWithNewTargetLong,
    ];
    CALL_OPS.contains(&op)
}