use std::collections::{BTreeMap, BTreeSet};
use std::num::NonZeroU8;
use droidsaw_common::analysis::{TaintFinding, TaintSink};
use droidsaw_common::cross_layer_taint::{
AmbiguousCause, BridgeEdge, CrossLayerTaintFinding, JsBridgeKey,
NativeModuleMethodName, NativeModuleName, StitchOutcome,
};
use droidsaw_common::finding::{Finding, Layer, Severity};
use droidsaw_dex::ids::MethodIdx;
use crate::analysis::bridge::BridgeResolver;
/// Finding id emitted by [`make_ambiguous_finding`] when an HBC bridge call
/// could not be stitched to exactly one native method (chain extraction
/// failed, zero resolver matches, multiple resolver matches, or a legacy
/// method-name-only match).
pub const BRIDGE_RESOLUTION_AMBIGUOUS: &str = "BRIDGE_RESOLUTION_AMBIGUOUS";
/// Finding id for a stitched JS → native taint flow.
/// NOTE(review): not referenced in this file; presumably used where
/// `CrossLayerTaintFinding`s are rendered into findings — confirm.
pub const CROSS_LAYER_TAINT_FLOW: &str = "CROSS_LAYER_TAINT_FLOW";
/// One JS-layer (HBC) taint finding offered to [`stitch_cross_layer_taint`].
///
/// Only findings whose sink is `TaintSink::NativeModuleArg` take part in
/// bridge resolution; any other sink is passed straight through to
/// `unjoined_hbc`.
#[derive(Debug, Clone)]
pub struct HbcStitchPayload {
    /// The HBC taint finding; its source and func id are copied into any
    /// composite built from it.
    pub tf: TaintFinding,
}
/// One native-layer (DEX) taint finding rooted at a bridge-exposed method,
/// offered to [`stitch_cross_layer_taint`].
///
/// Payloads are joined on `(dex_idx, method_idx)`; a payload that no HBC
/// finding claims is reported in `unjoined_dex`.
#[derive(Debug, Clone)]
pub struct DexBridgeStitchPayload {
    /// The DEX taint finding; its sink, func id, class descriptor and method
    /// signature are copied into any composite built from this payload.
    pub tf: TaintFinding,
    /// JS-visible method name.
    /// NOTE(review): not read by the stitcher in this file — joining keys on
    /// `(dex_idx, method_idx)` only.
    pub js_method: String,
    /// Which DEX the bridge method lives in (presumably an index into the
    /// app's DEX files — confirm against the producer).
    pub dex_idx: usize,
    /// Method index of the bridge method within that DEX.
    pub method_idx: MethodIdx,
    /// Seed positions whose taint reaches a sink inside the native method.
    /// A join requires a non-empty intersection with the HBC side's
    /// `arg_positions` (presumably both are 0-based argument indices — confirm).
    pub sink_reachable_seed_positions: BTreeSet<usize>,
}
/// A failed HBC backwalk, reported by [`stitch_cross_layer_taint`] as a
/// `BRIDGE_RESOLUTION_AMBIGUOUS` finding with cause `ChainExtractionFailed`.
#[derive(Debug, Clone)]
pub struct HbcBackwalkFailurePayload {
    /// HBC function id, attached to the emitted finding.
    pub func_id: u32,
    /// Copied verbatim into `AmbiguousCause::ChainExtractionFailed`
    /// (presumably how many hops the backwalk got through — confirm).
    pub hop_count: u8,
}
/// Stitches JS-layer (HBC) taint findings to native-layer (DEX) bridge
/// findings across the React Native bridge.
///
/// Routing:
/// * Every entry of `backwalk_failures` becomes a `BRIDGE_RESOLUTION_AMBIGUOUS`
///   finding with cause `ChainExtractionFailed`.
/// * An HBC finding whose sink is not `TaintSink::NativeModuleArg` is moved to
///   `unjoined_hbc` unchanged.
/// * A `NativeModuleArg` finding is resolved through `bridge.mappings` on its
///   `(module, method)` key:
///   - no mapping entry → ambiguous, `LegacyNoReactModule` if `bridge.by_method`
///     knows the method name, otherwise `ResolverZeroMatch`;
///   - an empty target list → ambiguous, `ResolverZeroMatch`;
///   - two or more targets → ambiguous, `ResolverMultiMatch` with the target
///     count saturated to `u8::MAX`;
///   - exactly one target → joined (below).
/// * Joining: every DEX payload registered at that single target which is not
///   yet claimed, and whose `sink_reachable_seed_positions` intersect the HBC
///   `arg_positions`, is claimed and yields one composite. An HBC finding may
///   join several DEX payloads, but a DEX payload joins at most one HBC finding
///   (earliest in `hbc_payloads` wins). An HBC finding that joins nothing is
///   moved to `unjoined_hbc`.
/// * DEX payloads never claimed are moved to `unjoined_dex`, in input order.
///
/// HBC findings reported as ambiguous are not also placed in `unjoined_hbc`;
/// only their func id survives, on the ambiguous finding.
pub fn stitch_cross_layer_taint(
    hbc_payloads: Vec<HbcStitchPayload>,
    dex_bridge_payloads: Vec<DexBridgeStitchPayload>,
    backwalk_failures: Vec<HbcBackwalkFailurePayload>,
    bridge: &BridgeResolver,
) -> StitchOutcome<MethodIdx> {
    let mut outcome: StitchOutcome<MethodIdx> = StitchOutcome {
        composites: Vec::new(),
        unjoined_hbc: Vec::new(),
        unjoined_dex: Vec::new(),
        ambiguous: Vec::new(),
    };

    // Backwalk failures carry no module/method, so they are reported up front,
    // independently of bridge resolution.
    outcome.ambiguous.extend(backwalk_failures.into_iter().map(|fail| {
        make_ambiguous_finding(
            None,
            None,
            &AmbiguousCause::ChainExtractionFailed { hop_count: fail.hop_count },
            Some(fail.func_id),
        )
    }));

    // Index DEX payloads by bridge target so each HBC finding scans only the
    // payloads of the native method it resolves to.
    let mut dex_index: BTreeMap<(usize, MethodIdx), Vec<usize>> = BTreeMap::new();
    for (i, p) in dex_bridge_payloads.iter().enumerate() {
        dex_index.entry((p.dex_idx, p.method_idx)).or_default().push(i);
    }
    // Indices of DEX payloads already joined; enforces one HBC finding per
    // DEX payload (no N-to-1 fan-in).
    let mut dex_claimed: BTreeSet<usize> = BTreeSet::new();

    for HbcStitchPayload { tf } in hbc_payloads {
        // Clone the call-site fields so `tf` can later be moved into
        // `unjoined_hbc` without an outstanding borrow of `tf.sink`.
        let (module, method, arg_positions) = match &tf.sink {
            TaintSink::NativeModuleArg { module, method, arg_positions } => {
                (module.clone(), method.clone(), arg_positions.clone())
            }
            _ => {
                outcome.unjoined_hbc.push(tf);
                continue;
            }
        };
        let key = JsBridgeKey::new(module.clone(), method.clone());

        // Slice patterns make the zero / one / many split exhaustive, so the
        // single-target arm binds its target directly and cannot silently
        // drop `tf` on an "impossible" lookup failure.
        match bridge.mappings.get(&key).map(|targets| &targets[..]) {
            None => {
                // The method-name-only index is consulted only on a keyed miss.
                let cause = if bridge.by_method.contains_key(method.as_str()) {
                    AmbiguousCause::LegacyNoReactModule
                } else {
                    AmbiguousCause::ResolverZeroMatch
                };
                outcome.ambiguous.push(make_ambiguous_finding(
                    Some(&module),
                    Some(&method),
                    &cause,
                    Some(tf.func_id),
                ));
            }
            Some([]) => {
                outcome.ambiguous.push(make_ambiguous_finding(
                    Some(&module),
                    Some(&method),
                    &AmbiguousCause::ResolverZeroMatch,
                    Some(tf.func_id),
                ));
            }
            Some(&[(dex_idx, method_idx)]) => {
                let mut emitted_any = false;
                if let Some(indices) = dex_index.get(&(dex_idx, method_idx)) {
                    for &dex_p_idx in indices {
                        // `indices` were produced by enumerating this same
                        // vector, so the lookup always succeeds.
                        let Some(dex_p) = dex_bridge_payloads.get(dex_p_idx) else {
                            continue;
                        };
                        let overlaps =
                            !arg_positions.is_disjoint(&dex_p.sink_reachable_seed_positions);
                        // `insert` returns false when an earlier HBC finding
                        // already claimed this payload.
                        if !overlaps || !dex_claimed.insert(dex_p_idx) {
                            continue;
                        }
                        outcome.composites.push(CrossLayerTaintFinding {
                            js_source: tf.source.clone(),
                            js_func_id: tf.func_id,
                            bridge: BridgeEdge {
                                js_module: module.clone(),
                                js_method: method.clone(),
                                dex_idx,
                                method_idx,
                            },
                            native_sink: dex_p.tf.sink.clone(),
                            native_func_id: dex_p.tf.func_id,
                            native_class_descriptor: dex_p.tf.class_descriptor.clone(),
                            native_method_signature: dex_p.tf.method_signature.clone(),
                            severity: severity_for_sink(&dex_p.tf.sink),
                            cwe: cwe_for_sink(&dex_p.tf.sink),
                        });
                        emitted_any = true;
                    }
                }
                if !emitted_any {
                    outcome.unjoined_hbc.push(tf);
                }
            }
            Some(targets) => {
                // At least two targets here; counts above 255 saturate.
                let saturated = u8::try_from(targets.len()).unwrap_or(u8::MAX);
                let candidates = NonZeroU8::new(saturated).unwrap_or(NonZeroU8::MIN);
                outcome.ambiguous.push(make_ambiguous_finding(
                    Some(&module),
                    Some(&method),
                    &AmbiguousCause::ResolverMultiMatch { candidates },
                    Some(tf.func_id),
                ));
            }
        }
    }

    outcome.unjoined_dex.extend(
        dex_bridge_payloads
            .into_iter()
            .enumerate()
            .filter(|(i, _)| !dex_claimed.contains(i))
            .map(|(_, p)| p.tf),
    );
    outcome
}
/// Builds an Info-severity `BRIDGE_RESOLUTION_AMBIGUOUS` finding on the HBC
/// layer, tagged CWE-1023.
///
/// * `module` / `method` — the JS bridge call site. The detail text names it
///   as `Module::method` only when both are present, else `<unresolved>`.
/// * `cause` — rendered as a short tag in the detail text, and serialised as
///   JSON into the finding's `extra` (`"{}"` if serialisation fails).
/// * `func_id` — attached with `with_func` when present.
pub fn make_ambiguous_finding(
    module: Option<&NativeModuleName>,
    method: Option<&NativeModuleMethodName>,
    cause: &AmbiguousCause,
    func_id: Option<u32>,
) -> Finding {
    let site = match module.zip(method) {
        Some((module_name, method_name)) => {
            format!("{}::{}", module_name.as_str(), method_name.as_str())
        }
        None => String::from("<unresolved>"),
    };
    let reason: String = match cause {
        AmbiguousCause::ResolverZeroMatch => String::from("resolver-zero-match"),
        AmbiguousCause::LegacyNoReactModule => String::from("legacy-no-react-module"),
        AmbiguousCause::ResolverMultiMatch { candidates } => {
            format!("resolver-multi-match (candidates={candidates})")
        }
        AmbiguousCause::ChainExtractionFailed { hop_count } => {
            format!("chain-extraction-failed (hop_count={hop_count})")
        }
    };
    // Best-effort machine-readable copy of the cause; never fails the finding.
    let cause_json = serde_json::to_string(cause).unwrap_or_else(|_| String::from("{}"));
    let finding = Finding::new(
        BRIDGE_RESOLUTION_AMBIGUOUS,
        Layer::Hbc,
        Severity::Info,
        format!("bridge resolution ambiguous for {site} — {reason}"),
    )
    .with_extra(cause_json)
    .with_cwe(1023);
    match func_id {
        Some(id) => finding.with_func(id),
        None => finding,
    }
}
/// Maps a native taint sink to the severity assigned to a cross-layer
/// composite.
///
/// Execution/eval and SQL sinks are Critical; WebView URL loads, reflection
/// and file-system sinks are High; network sinks are Low; the remaining listed
/// sinks, and any sink kind not listed, are Medium.
pub fn severity_for_sink(sink: &TaintSink) -> Severity {
    match sink {
        TaintSink::RuntimeExec | TaintSink::Eval | TaintSink::SqlExecute => Severity::Critical,
        TaintSink::WebViewLoadUrl
        | TaintSink::ReflectionInvoke { .. }
        | TaintSink::FilePathTraversal { .. }
        | TaintSink::FileWrite { .. } => Severity::High,
        TaintSink::NetworkFetch | TaintSink::HttpRequest { .. } => Severity::Low,
        TaintSink::LogOutput
        | TaintSink::ContentProviderInsert { .. }
        | TaintSink::NativeModuleArg { .. }
        | TaintSink::CryptoInput { .. } => Severity::Medium,
        // Fallback for sink kinds without an explicit tier.
        _ => Severity::Medium,
    }
}
/// Maps a native taint sink to the CWE id reported on a cross-layer composite,
/// or `None` for sink kinds with no mapping.
pub fn cwe_for_sink(sink: &TaintSink) -> Option<u16> {
    let cwe: u16 = match sink {
        // CWE-78: OS command injection (also applied to eval-style sinks).
        TaintSink::RuntimeExec | TaintSink::Eval => 78,
        // CWE-89: SQL injection.
        TaintSink::SqlExecute => 89,
        // CWE-79: cross-site scripting.
        TaintSink::WebViewLoadUrl => 79,
        // CWE-470: unsafe reflection.
        TaintSink::ReflectionInvoke { .. } => 470,
        // CWE-22: path traversal (applied to both file sinks).
        TaintSink::FilePathTraversal { .. } | TaintSink::FileWrite { .. } => 22,
        // CWE-532: sensitive information in log output.
        TaintSink::LogOutput => 532,
        // CWE-862: missing authorization.
        TaintSink::ContentProviderInsert { .. } => 862,
        // CWE-20: improper input validation.
        TaintSink::NativeModuleArg { .. } => 20,
        // CWE-327: broken or risky cryptography.
        TaintSink::CryptoInput { .. } => 327,
        // CWE-918: server-side request forgery.
        TaintSink::NetworkFetch | TaintSink::HttpRequest { .. } => 918,
        _ => return None,
    };
    Some(cwe)
}
#[cfg(test)]
mod tests {
    use super::*;
    use droidsaw_common::analysis::TaintSource;

    /// Builds a `NativeModuleName` fixture; panics if `try_new` rejects `s`.
    fn nm(s: &str) -> NativeModuleName {
        NativeModuleName::try_new(s.into()).expect("non-empty test fixture")
    }

    /// Builds a `NativeModuleMethodName` fixture; panics if `try_new` rejects `s`.
    fn nmm(s: &str) -> NativeModuleMethodName {
        NativeModuleMethodName::try_new(s.into()).expect("non-empty test fixture")
    }

    /// HBC payload whose sink is a `NativeModuleArg` call to `module.method`
    /// with the given tainted argument `positions`.
    fn hbc(module: &str, method: &str, positions: &[usize], func_id: u32) -> HbcStitchPayload {
        HbcStitchPayload {
            tf: TaintFinding {
                source: TaintSource::UserInput,
                sink: TaintSink::NativeModuleArg {
                    module: nm(module),
                    method: nmm(method),
                    arg_positions: positions.iter().copied().collect(),
                },
                layer: Layer::Hbc,
                func_id,
                class_descriptor: None,
                method_signature: None,
                source_offset: None,
                sink_offset: None,
            },
        }
    }

    /// DEX bridge payload at `(dex_idx, MethodIdx(m_idx_raw))` whose seeds at
    /// `positions` reach `sink`.
    fn dex(
        dex_idx: usize,
        m_idx_raw: u32,
        positions: &[usize],
        js_method: &str,
        sink: TaintSink,
    ) -> DexBridgeStitchPayload {
        DexBridgeStitchPayload {
            tf: TaintFinding {
                source: TaintSource::ReactBridgeParam { method: js_method.into() },
                sink,
                layer: Layer::Dex,
                func_id: 0xdef,
                class_descriptor: Some("Lcom/example/Mod;".into()),
                method_signature: Some(format!("{js_method}(Ljava/lang/String;)V")),
                source_offset: None,
                sink_offset: None,
            },
            js_method: js_method.into(),
            dex_idx,
            method_idx: MethodIdx(m_idx_raw),
            sink_reachable_seed_positions: positions.iter().copied().collect(),
        }
    }

    /// Resolver with no mappings at all.
    fn empty_bridge() -> BridgeResolver {
        BridgeResolver {
            mappings: BTreeMap::new(),
            by_method: BTreeMap::new(),
        }
    }

    /// Registers `module.method -> (dex_idx, m_idx_raw)` in both the keyed
    /// and the method-name-only indices.
    fn with_mapping(b: &mut BridgeResolver, module: &str, method: &str, dex_idx: usize, m_idx_raw: u32) {
        b.mappings
            .entry(JsBridgeKey::new(nm(module), nmm(method)))
            .or_default()
            .push((dex_idx, MethodIdx(m_idx_raw)));
        b.by_method
            .entry(method.into())
            .or_default()
            .push((dex_idx, MethodIdx(m_idx_raw)));
    }

    #[test]
    fn single_composite_emits_once() {
        let mut b = empty_bridge();
        with_mapping(&mut b, "Crypto", "encrypt", 0, 12345);
        let h = vec![hbc("Crypto", "encrypt", &[0], 1)];
        let d = vec![dex(0, 12345, &[0], "encrypt", TaintSink::RuntimeExec)];
        let out = stitch_cross_layer_taint(h, d, vec![], &b);
        assert_eq!(out.composites.len(), 1);
        assert!(out.unjoined_hbc.is_empty());
        assert!(out.unjoined_dex.is_empty());
        assert!(out.ambiguous.is_empty());
        let c = &out.composites[0];
        assert_eq!(c.bridge.js_module.as_str(), "Crypto");
        assert_eq!(c.bridge.js_method.as_str(), "encrypt");
        assert_eq!(c.bridge.dex_idx, 0);
        assert_eq!(c.bridge.method_idx, MethodIdx(12345));
        assert_eq!(c.severity, Severity::Critical);
        assert_eq!(c.cwe, Some(78));
    }

    #[test]
    fn zero_match_emits_resolver_zero_match() {
        let b = empty_bridge();
        let h = vec![hbc("Crypto", "encrypt", &[0], 1)];
        let out = stitch_cross_layer_taint(h, vec![], vec![], &b);
        assert_eq!(out.composites.len(), 0);
        assert_eq!(out.ambiguous.len(), 1);
        let f = &out.ambiguous[0];
        assert_eq!(f.id, BRIDGE_RESOLUTION_AMBIGUOUS);
        let cause: AmbiguousCause = serde_json::from_str(
            f.extra.as_deref().expect("extra populated"),
        )
        .expect("AmbiguousCause roundtrips");
        assert!(matches!(cause, AmbiguousCause::ResolverZeroMatch));
        assert_eq!(f.cwe, Some(1023));
        assert_eq!(f.func_id, Some(1));
    }

    #[test]
    fn legacy_no_react_module_emits_legacy_cause() {
        let mut b = empty_bridge();
        b.by_method.entry("getName".into()).or_default().push((0, MethodIdx(99)));
        let h = vec![hbc("Crypto", "getName", &[0], 1)];
        let out = stitch_cross_layer_taint(h, vec![], vec![], &b);
        assert_eq!(out.ambiguous.len(), 1);
        let cause: AmbiguousCause = serde_json::from_str(
            out.ambiguous[0].extra.as_deref().expect("extra populated"),
        )
        .expect("AmbiguousCause roundtrips");
        assert!(matches!(cause, AmbiguousCause::LegacyNoReactModule));
    }

    #[test]
    fn multi_match_emits_ambiguous_with_candidates() {
        let mut b = empty_bridge();
        with_mapping(&mut b, "M", "x", 0, 1);
        with_mapping(&mut b, "M", "x", 0, 2);
        with_mapping(&mut b, "M", "x", 0, 3);
        let h = vec![hbc("M", "x", &[0], 1)];
        let out = stitch_cross_layer_taint(h, vec![], vec![], &b);
        assert_eq!(out.composites.len(), 0);
        assert_eq!(out.ambiguous.len(), 1);
        let cause: AmbiguousCause = serde_json::from_str(
            out.ambiguous[0].extra.as_deref().expect("extra populated"),
        )
        .expect("AmbiguousCause roundtrips");
        match cause {
            AmbiguousCause::ResolverMultiMatch { candidates } => {
                assert_eq!(candidates.get(), 3);
            }
            other => panic!("expected ResolverMultiMatch, got {other:?}"),
        }
    }

    #[test]
    fn chain_extraction_failed_emits_ambiguous() {
        let b = empty_bridge();
        let out = stitch_cross_layer_taint(
            vec![],
            vec![],
            vec![HbcBackwalkFailurePayload { func_id: 42, hop_count: 1 }],
            &b,
        );
        assert_eq!(out.ambiguous.len(), 1);
        let cause: AmbiguousCause = serde_json::from_str(
            out.ambiguous[0].extra.as_deref().expect("extra populated"),
        )
        .expect("AmbiguousCause roundtrips");
        match cause {
            AmbiguousCause::ChainExtractionFailed { hop_count } => {
                assert_eq!(hop_count, 1);
            }
            other => panic!("expected ChainExtractionFailed, got {other:?}"),
        }
    }

    #[test]
    fn two_function_no_cross_function_contamination() {
        let mut b = empty_bridge();
        with_mapping(&mut b, "Crypto", "encrypt", 0, 100);
        with_mapping(&mut b, "Storage", "write", 0, 200);
        let h = vec![
            hbc("Crypto", "encrypt", &[0], 1),
            hbc("Storage", "write", &[0], 2),
        ];
        let d = vec![dex(0, 100, &[0], "encrypt", TaintSink::RuntimeExec)];
        let out = stitch_cross_layer_taint(h, d, vec![], &b);
        assert_eq!(out.composites.len(), 1, "exactly one composite (on Crypto.encrypt)");
        assert_eq!(out.composites[0].bridge.js_module.as_str(), "Crypto");
        assert_eq!(out.unjoined_hbc.len(), 1);
    }

    #[test]
    fn dex_claimed_prevents_n_to_1_fanin() {
        let mut b = empty_bridge();
        with_mapping(&mut b, "M", "x", 0, 1);
        let h = vec![
            hbc("M", "x", &[0], 1),
            hbc("M", "x", &[0], 2),
        ];
        let d = vec![dex(0, 1, &[0], "x", TaintSink::Eval)];
        let out = stitch_cross_layer_taint(h, d, vec![], &b);
        assert_eq!(out.composites.len(), 1);
        assert_eq!(out.unjoined_hbc.len(), 1);
        assert_eq!(out.unjoined_dex.len(), 0);
    }

    #[test]
    fn unjoined_dex_when_no_hbc() {
        let mut b = empty_bridge();
        with_mapping(&mut b, "M", "x", 0, 1);
        let d = vec![dex(0, 1, &[0], "x", TaintSink::RuntimeExec)];
        let out = stitch_cross_layer_taint(vec![], d, vec![], &b);
        assert_eq!(out.composites.len(), 0);
        assert_eq!(out.unjoined_dex.len(), 1);
    }

    #[test]
    fn empty_overlap_routes_hbc_to_unjoined() {
        let mut b = empty_bridge();
        with_mapping(&mut b, "M", "x", 0, 1);
        let h = vec![hbc("M", "x", &[1], 7)];
        let d = vec![dex(0, 1, &[0], "x", TaintSink::RuntimeExec)];
        let out = stitch_cross_layer_taint(h, d, vec![], &b);
        assert_eq!(out.composites.len(), 0);
        assert_eq!(out.unjoined_hbc.len(), 1);
        assert_eq!(out.unjoined_dex.len(), 1);
    }

    #[test]
    fn collision_two_modules_share_method_name_resolves_each() {
        let mut b = empty_bridge();
        with_mapping(&mut b, "ModuleA", "exec", 0, 10);
        with_mapping(&mut b, "ModuleB", "exec", 0, 20);
        let h = vec![
            hbc("ModuleA", "exec", &[0], 1),
            hbc("ModuleB", "exec", &[0], 2),
        ];
        let d = vec![
            dex(0, 10, &[0], "exec", TaintSink::RuntimeExec),
            dex(0, 20, &[0], "exec", TaintSink::Eval),
        ];
        let out = stitch_cross_layer_taint(h, d, vec![], &b);
        assert_eq!(out.composites.len(), 2);
        let mut by_module: BTreeMap<String, &CrossLayerTaintFinding<MethodIdx>> =
            out.composites.iter().map(|c| (c.bridge.js_module.as_str().to_string(), c)).collect();
        let a = by_module.remove("ModuleA").expect("ModuleA composite present");
        let b_comp = by_module.remove("ModuleB").expect("ModuleB composite present");
        assert_eq!(a.bridge.method_idx, MethodIdx(10));
        assert!(matches!(a.native_sink, TaintSink::RuntimeExec));
        assert_eq!(b_comp.bridge.method_idx, MethodIdx(20));
        assert!(matches!(b_comp.native_sink, TaintSink::Eval));
    }

    #[test]
    fn ambiguous_finding_has_func_id_when_provided() {
        let f = make_ambiguous_finding(
            None,
            None,
            &AmbiguousCause::ChainExtractionFailed { hop_count: 2 },
            Some(0xcafe),
        );
        assert_eq!(f.id, BRIDGE_RESOLUTION_AMBIGUOUS);
        assert_eq!(f.func_id, Some(0xcafe));
        assert_eq!(f.cwe, Some(1023));
        assert_eq!(f.severity, Severity::Info);
        assert_eq!(f.layer, Layer::Hbc);
    }

    // HBC findings whose sink is not a bridge call bypass resolution entirely
    // and are passed through to `unjoined_hbc`.
    #[test]
    fn non_native_module_sink_routes_to_unjoined_hbc() {
        let mut b = empty_bridge();
        with_mapping(&mut b, "M", "x", 0, 1);
        let mut p = hbc("M", "x", &[0], 5);
        p.tf.sink = TaintSink::Eval;
        let d = vec![dex(0, 1, &[0], "x", TaintSink::RuntimeExec)];
        let out = stitch_cross_layer_taint(vec![p], d, vec![], &b);
        assert!(out.composites.is_empty());
        assert!(out.ambiguous.is_empty());
        assert_eq!(out.unjoined_hbc.len(), 1);
        assert_eq!(out.unjoined_hbc[0].func_id, 5);
        assert_eq!(out.unjoined_dex.len(), 1);
    }

    // A present-but-empty target list is a zero match, even when the
    // method-name-only index would have suggested a legacy module.
    #[test]
    fn empty_target_list_emits_zero_match_even_with_by_method_hit() {
        let mut b = empty_bridge();
        b.mappings.entry(JsBridgeKey::new(nm("M"), nmm("x"))).or_default();
        b.by_method.entry("x".into()).or_default().push((0, MethodIdx(1)));
        let h = vec![hbc("M", "x", &[0], 3)];
        let out = stitch_cross_layer_taint(h, vec![], vec![], &b);
        assert!(out.composites.is_empty());
        assert!(out.unjoined_hbc.is_empty());
        assert_eq!(out.ambiguous.len(), 1);
        assert_eq!(out.ambiguous[0].func_id, Some(3));
        let cause: AmbiguousCause = serde_json::from_str(
            out.ambiguous[0].extra.as_deref().expect("extra populated"),
        )
        .expect("AmbiguousCause roundtrips");
        assert!(matches!(cause, AmbiguousCause::ResolverZeroMatch));
    }

    // One HBC finding joins every overlapping DEX payload at its target, in
    // DEX input order; non-overlapping payloads stay unjoined.
    #[test]
    fn single_hbc_fans_out_to_every_overlapping_dex_payload() {
        let mut b = empty_bridge();
        with_mapping(&mut b, "M", "x", 0, 1);
        let h = vec![hbc("M", "x", &[0, 1], 1)];
        let d = vec![
            dex(0, 1, &[0], "x", TaintSink::RuntimeExec),
            dex(0, 1, &[1], "x", TaintSink::SqlExecute),
            dex(0, 1, &[2], "x", TaintSink::Eval),
        ];
        let out = stitch_cross_layer_taint(h, d, vec![], &b);
        assert_eq!(out.composites.len(), 2);
        assert_eq!(out.composites[0].cwe, Some(78));
        assert_eq!(out.composites[1].cwe, Some(89));
        assert!(out.unjoined_hbc.is_empty());
        assert_eq!(out.unjoined_dex.len(), 1, "position-2 payload has no overlap");
        assert!(matches!(out.unjoined_dex[0].sink, TaintSink::Eval));
    }
}