use crate::semantics::LanguageSemantics;
use rma_common::Language;
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
/// A call to a recognised callback-taking API (e.g. `map`, `then`, `on`) whose
/// arguments include a function literal, as collected by `CallbackAnalyzer`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CallbackSite {
/// Path of the analysed file.
pub file: PathBuf,
/// 1-based line of the call expression (tree-sitter row + 1).
pub line: usize,
/// Column of the call expression's start position exactly as tree-sitter
/// reports it (not shifted to 1-based, unlike `line`).
pub column: usize,
/// Name of the invoked method or function, e.g. `"map"` or `"setTimeout"`.
pub hof_name: String,
/// Full source text of the object the method is called on
/// (e.g. `"req.body.items"`); `None` for plain function calls.
pub receiver: Option<String>,
/// Category assigned by `CallbackPatterns::classify`.
pub kind: CallbackKind,
/// Parameter names of the callback argument, in declaration order.
pub callback_params: Vec<String>,
/// Name of the nearest enclosing named function definition, if any.
pub containing_function: Option<String>,
/// tree-sitter node id of the call expression (presumably only unique
/// within the tree it was taken from).
pub node_id: usize,
}
/// Category of callback-taking API; decides which callback parameter is
/// considered to receive data from the receiver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CallbackKind {
/// Element-wise iteration such as `map`, `filter`, `forEach`.
ArrayIterator,
/// Folds such as `reduce` / `fold`; the element is read from parameter index 1.
ArrayReducer,
/// Promise/future continuations such as `then` / `catch`.
PromiseChain,
/// Listener registration such as `on` / `addEventListener`.
EventHandler,
/// Deferred execution such as `setTimeout`; no parameter is considered tainted.
TimerCallback,
/// Generic higher-order function. `CallbackPatterns::classify` never returns
/// this variant; presumably constructed by other callers.
HigherOrderFunction,
}
impl CallbackKind {
    /// Position of the callback parameter that receives data flowing from the
    /// call's receiver, or `None` when no parameter does (timer callbacks).
    ///
    /// Prefer this over [`CallbackKind::tainted_param_index`], which encodes
    /// the "no parameter" case as the `usize::MAX` sentinel.
    pub fn tainted_param_position(&self) -> Option<usize> {
        match self {
            // e.g. `arr.map((item, index) => ...)`: the element comes first.
            CallbackKind::ArrayIterator => Some(0),
            // e.g. `arr.reduce((acc, item) => ...)`: the element is second.
            CallbackKind::ArrayReducer => Some(1),
            CallbackKind::PromiseChain => Some(0),
            CallbackKind::EventHandler => Some(0),
            // Timer callbacks are invoked without data from the call site.
            CallbackKind::TimerCallback => None,
            CallbackKind::HigherOrderFunction => Some(0),
        }
    }

    /// Index of the tainted callback parameter.
    ///
    /// Returns `usize::MAX` when no parameter is tainted (timer callbacks);
    /// kept for existing callers — new code should use
    /// [`CallbackKind::tainted_param_position`].
    pub fn tainted_param_index(&self) -> usize {
        self.tainted_param_position().unwrap_or(usize::MAX)
    }
}
/// A computed flow of tainted data into one parameter of a registered callback.
#[derive(Debug, Clone)]
pub struct CallbackTaintFlow {
/// The site this flow was computed for (a clone of the registered site).
pub callback_site: CallbackSite,
/// Where the tainted value originates.
pub taint_source: TaintSource,
/// Name of the callback parameter that receives the tainted value.
pub target_param: String,
/// Position of `target_param` in the callback's parameter list.
pub target_param_index: usize,
/// How certain the flow is.
pub confidence: TaintConfidence,
}
/// Origin of the tainted data in a `CallbackTaintFlow`. String payloads hold
/// the receiver text of the callback site.
#[derive(Debug, Clone)]
pub enum TaintSource {
/// Elements of an iterated/reduced collection.
ArrayElements(String),
/// Resolved value of a promise/future.
PromiseResolution(String),
/// Data delivered to an event listener.
/// NOTE(review): `CallbackRegistry` fills `event_name` with the registration
/// method name (e.g. `"on"`), not the event string — the site does not record it.
EventData { event_name: String, emitter: String },
/// Reducer accumulator; not constructed by `CallbackRegistry` in the code visible here.
Accumulator(String),
/// A tainted variable passed to a generic higher-order function.
Variable(String),
}
/// Certainty attached to a `CallbackTaintFlow`.
/// `CallbackRegistry::tainted_callback_params` ignores `Speculative` flows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaintConfidence {
/// The receiver itself is a tainted variable.
Definite,
/// The receiver is only related to a tainted variable, or the site is an
/// event handler / generic higher-order function.
Possible,
/// A promise chain on a receiver not known to be tainted.
Speculative,
}
/// Callback sites found during analysis, lookup indexes over them, the current
/// taint set, and the taint flows last computed from those.
#[derive(Debug, Default)]
pub struct CallbackRegistry {
/// All registered sites, in registration order.
callback_sites: Vec<CallbackSite>,
/// Receiver text -> indices into `callback_sites`.
by_receiver: HashMap<String, Vec<usize>>,
/// HOF name -> indices into `callback_sites`.
by_hof_name: HashMap<String, Vec<usize>>,
/// Output of the most recent `compute_taint_flows` call.
taint_flows: Vec<CallbackTaintFlow>,
/// Names (or receiver expressions) treated as tainted.
tainted_vars: HashSet<String>,
}
impl CallbackRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn with_tainted_vars(tainted_vars: HashSet<String>) -> Self {
Self {
tainted_vars,
..Default::default()
}
}
pub fn register_callback(&mut self, site: CallbackSite) {
let index = self.callback_sites.len();
if let Some(ref receiver) = site.receiver {
self.by_receiver
.entry(receiver.clone())
.or_default()
.push(index);
}
self.by_hof_name
.entry(site.hof_name.clone())
.or_default()
.push(index);
self.callback_sites.push(site);
}
pub fn add_tainted_var(&mut self, var: String) {
self.tainted_vars.insert(var);
}
pub fn is_tainted(&self, var: &str) -> bool {
self.tainted_vars.contains(var)
}
pub fn compute_taint_flows(&mut self) {
self.taint_flows.clear();
for site in &self.callback_sites {
if let Some(flow) = self.compute_flow_for_site(site) {
self.taint_flows.push(flow);
}
}
}
fn compute_flow_for_site(&self, site: &CallbackSite) -> Option<CallbackTaintFlow> {
let tainted_param_idx = site.kind.tainted_param_index();
if tainted_param_idx == usize::MAX {
return None; }
let target_param = site.callback_params.get(tainted_param_idx)?;
let (taint_source, confidence) = match site.kind {
CallbackKind::ArrayIterator | CallbackKind::ArrayReducer => {
if let Some(ref receiver) = site.receiver {
if self.tainted_vars.contains(receiver) {
(
TaintSource::ArrayElements(receiver.clone()),
TaintConfidence::Definite,
)
} else {
let possibly_tainted = self.tainted_vars.iter().any(|t| {
receiver.starts_with(t) || receiver.contains(&format!(".{}", t))
});
if possibly_tainted {
(
TaintSource::ArrayElements(receiver.clone()),
TaintConfidence::Possible,
)
} else {
return None;
}
}
} else {
return None;
}
}
CallbackKind::PromiseChain => {
if let Some(ref receiver) = site.receiver {
if self.tainted_vars.contains(receiver) {
(
TaintSource::PromiseResolution(receiver.clone()),
TaintConfidence::Definite,
)
} else {
(
TaintSource::PromiseResolution(receiver.clone()),
TaintConfidence::Speculative,
)
}
} else {
return None;
}
}
CallbackKind::EventHandler => {
if let Some(ref receiver) = site.receiver {
(
TaintSource::EventData {
event_name: site.hof_name.clone(),
emitter: receiver.clone(),
},
TaintConfidence::Possible,
)
} else {
return None;
}
}
CallbackKind::TimerCallback => {
return None;
}
CallbackKind::HigherOrderFunction => {
if let Some(ref receiver) = site.receiver {
if self.tainted_vars.contains(receiver) {
(
TaintSource::Variable(receiver.clone()),
TaintConfidence::Possible,
)
} else {
return None;
}
} else {
return None;
}
}
};
Some(CallbackTaintFlow {
callback_site: site.clone(),
taint_source,
target_param: target_param.clone(),
target_param_index: tainted_param_idx,
confidence,
})
}
pub fn taint_flows(&self) -> &[CallbackTaintFlow] {
&self.taint_flows
}
pub fn callbacks_for_receiver(&self, receiver: &str) -> Vec<&CallbackSite> {
self.by_receiver
.get(receiver)
.map(|indices| indices.iter().map(|&i| &self.callback_sites[i]).collect())
.unwrap_or_default()
}
pub fn callbacks_for_hof(&self, hof_name: &str) -> Vec<&CallbackSite> {
self.by_hof_name
.get(hof_name)
.map(|indices| indices.iter().map(|&i| &self.callback_sites[i]).collect())
.unwrap_or_default()
}
pub fn all_callbacks(&self) -> &[CallbackSite] {
&self.callback_sites
}
pub fn tainted_callback_params(&self) -> HashSet<String> {
self.taint_flows
.iter()
.filter(|f| f.confidence != TaintConfidence::Speculative)
.map(|f| f.target_param.clone())
.collect()
}
}
/// Per-language tables of callback-taking API names, one list per
/// `CallbackKind` category. Names are compared for exact equality with the
/// callee name extracted from a call expression.
pub struct CallbackPatterns {
/// Element-wise iteration APIs (`CallbackKind::ArrayIterator`).
pub array_iterators: &'static [&'static str],
/// Fold/reduce APIs (`CallbackKind::ArrayReducer`).
pub array_reducers: &'static [&'static str],
/// Promise/future continuation APIs (`CallbackKind::PromiseChain`).
pub promise_methods: &'static [&'static str],
/// Listener-registration APIs (`CallbackKind::EventHandler`).
pub event_handlers: &'static [&'static str],
/// Deferred-execution APIs (`CallbackKind::TimerCallback`).
pub timer_functions: &'static [&'static str],
}
impl CallbackPatterns {
    /// A table with every category empty; languages that only populate a
    /// few categories fill the rest from this via struct-update syntax.
    const EMPTY: Self = Self {
        array_iterators: &[],
        array_reducers: &[],
        promise_methods: &[],
        event_handlers: &[],
        timer_functions: &[],
    };

    /// Returns the callback API names recognised for `language`.
    ///
    /// Languages without a dedicated table get all-empty lists, so
    /// [`Self::classify`] returns `None` for every name.
    pub fn for_language(language: Language) -> Self {
        match language {
            Language::JavaScript | Language::TypeScript => Self {
                array_iterators: &[
                    "map", "filter", "forEach", "find", "findIndex", "some", "every", "flatMap",
                ],
                array_reducers: &["reduce", "reduceRight"],
                promise_methods: &["then", "catch", "finally"],
                event_handlers: &["on", "once", "addEventListener", "addListener", "subscribe"],
                timer_functions: &[
                    "setTimeout",
                    "setInterval",
                    "setImmediate",
                    "requestAnimationFrame",
                    "queueMicrotask",
                ],
            },
            Language::Python => Self {
                array_iterators: &["map", "filter"],
                array_reducers: &["reduce"],
                event_handlers: &["connect", "on"],
                timer_functions: &["call_later", "call_at", "call_soon"],
                ..Self::EMPTY
            },
            Language::Java => Self {
                array_iterators: &[
                    "map", "filter", "forEach", "findFirst", "findAny", "anyMatch", "allMatch",
                    "noneMatch",
                ],
                array_reducers: &["reduce", "collect"],
                promise_methods: &["thenApply", "thenAccept", "thenCompose", "exceptionally"],
                event_handlers: &["addListener", "subscribe", "on"],
                timer_functions: &["schedule", "scheduleAtFixedRate"],
            },
            Language::Go => Self {
                timer_functions: &["AfterFunc"],
                ..Self::EMPTY
            },
            Language::Rust => Self {
                array_iterators: &["map", "filter", "for_each", "find", "any", "all", "flat_map"],
                array_reducers: &["fold", "reduce"],
                promise_methods: &["and_then", "map", "map_err", "or_else"],
                ..Self::EMPTY
            },
            Language::Php => Self {
                array_iterators: &[
                    "array_map", "array_filter", "array_walk", "array_walk_recursive",
                ],
                array_reducers: &["array_reduce"],
                ..Self::EMPTY
            },
            Language::CSharp => Self {
                array_iterators: &["Select", "Where", "ForEach", "Any", "All", "SelectMany"],
                array_reducers: &["Aggregate"],
                promise_methods: &["ContinueWith", "Then"],
                event_handlers: &["Subscribe", "AddHandler"],
                ..Self::EMPTY
            },
            Language::Kotlin => Self {
                array_iterators: &["map", "filter", "forEach", "find", "any", "all", "flatMap"],
                array_reducers: &["fold", "reduce"],
                event_handlers: &["collect", "onEach"],
                ..Self::EMPTY
            },
            Language::Scala => Self {
                array_iterators: &["map", "filter", "foreach", "find", "exists", "forall", "flatMap"],
                array_reducers: &["fold", "foldLeft", "foldRight", "reduce", "collect"],
                promise_methods: &["map", "flatMap", "onComplete", "recover"],
                ..Self::EMPTY
            },
            Language::Swift => Self {
                array_iterators: &[
                    "map", "filter", "forEach", "first", "contains", "compactMap", "flatMap",
                ],
                array_reducers: &["reduce"],
                event_handlers: &["sink", "receive"],
                ..Self::EMPTY
            },
            Language::Elixir => Self {
                array_iterators: &[
                    "Enum.map", "Enum.filter", "Enum.each", "Enum.find", "Enum.any?", "Enum.all?",
                    "Enum.flat_map",
                ],
                array_reducers: &["Enum.reduce", "Enum.reduce_while"],
                timer_functions: &["Task.async", "Task.start"],
                ..Self::EMPTY
            },
            _ => Self::EMPTY,
        }
    }

    /// Maps an API name to its [`CallbackKind`], or `None` if unrecognised.
    ///
    /// Categories are consulted in a fixed order — iterators, reducers,
    /// promise methods, event handlers, timers — so a name listed in several
    /// (e.g. Rust's `map`, also a promise method) resolves to the first.
    pub fn classify(&self, name: &str) -> Option<CallbackKind> {
        let lookup = [
            (self.array_iterators, CallbackKind::ArrayIterator),
            (self.array_reducers, CallbackKind::ArrayReducer),
            (self.promise_methods, CallbackKind::PromiseChain),
            (self.event_handlers, CallbackKind::EventHandler),
            (self.timer_functions, CallbackKind::TimerCallback),
        ];
        lookup
            .iter()
            .find_map(|&(names, kind)| names.contains(&name).then_some(kind))
    }

    /// `true` when [`Self::classify`] recognises `name`.
    pub fn is_callback_pattern(&self, name: &str) -> bool {
        self.classify(name).is_some()
    }
}
/// Walks a tree-sitter syntax tree and collects recognised callback sites
/// into a `CallbackRegistry`.
pub struct CallbackAnalyzer<'a> {
/// Language-specific node-kind predicates and field names.
semantics: &'static LanguageSemantics,
/// Callback API names for the analysed language.
patterns: CallbackPatterns,
/// Source bytes the tree was parsed from; node text is read from here.
source: &'a [u8],
/// Initial taint set, cloned into every registry `analyze` produces.
tainted_vars: HashSet<String>,
/// Path recorded on every `CallbackSite`.
file_path: PathBuf,
}
impl<'a> CallbackAnalyzer<'a> {
    /// Creates an analyzer for `source` with an empty initial taint set.
    pub fn new(
        semantics: &'static LanguageSemantics,
        source: &'a [u8],
        file_path: PathBuf,
    ) -> Self {
        Self::with_tainted_vars(semantics, source, file_path, HashSet::new())
    }

    /// Creates an analyzer for `source` whose registries start with `tainted_vars`.
    ///
    /// The callback pattern table is selected from `semantics.language_enum()`.
    pub fn with_tainted_vars(
        semantics: &'static LanguageSemantics,
        source: &'a [u8],
        file_path: PathBuf,
        tainted_vars: HashSet<String>,
    ) -> Self {
        Self {
            semantics,
            patterns: CallbackPatterns::for_language(semantics.language_enum()),
            source,
            tainted_vars,
            file_path,
        }
    }

    /// Collects every callback site in `tree`, then computes taint flows.
    ///
    /// `tree` is assumed to have been parsed from `self.source` (node text is
    /// sliced from those bytes).
    pub fn analyze(&self, tree: &tree_sitter::Tree) -> CallbackRegistry {
        let mut registry = CallbackRegistry::with_tainted_vars(self.tainted_vars.clone());
        self.walk_for_callbacks(tree.root_node(), &mut registry, None);
        registry.compute_taint_flows();
        registry
    }

    /// Pre-order traversal that offers every call node to
    /// [`Self::extract_callback_site`], tracking the nearest enclosing named
    /// function. NOTE(review): recursion depth equals tree depth.
    fn walk_for_callbacks(
        &self,
        node: tree_sitter::Node,
        registry: &mut CallbackRegistry,
        current_function: Option<String>,
    ) {
        // Anonymous definitions (no name field) keep the enclosing function's name.
        let func_context = if self.semantics.is_function_def(node.kind()) {
            self.extract_function_name(node).or(current_function)
        } else {
            current_function
        };
        if self.semantics.is_call(node.kind())
            && let Some(callback_site) = self.extract_callback_site(node, &func_context)
        {
            registry.register_callback(callback_site);
        }
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            self.walk_for_callbacks(child, registry, func_context.clone());
        }
    }

    /// Builds a [`CallbackSite`] when the callee of call node `node` is a known
    /// callback API and one of its arguments is a function literal with
    /// extractable parameters; `None` otherwise.
    fn extract_callback_site(
        &self,
        node: tree_sitter::Node,
        containing_function: &Option<String>,
    ) -> Option<CallbackSite> {
        // Grammars without a function field: fall back to the first child.
        let func_node = node
            .child_by_field_name(self.semantics.function_field)
            .or_else(|| node.child(0))?;
        let (hof_name, receiver) = self.extract_hof_and_receiver(func_node)?;
        let kind = self.patterns.classify(&hof_name)?;
        let callback_params = self.extract_callback_params(node)?;
        let start = node.start_position();
        Some(CallbackSite {
            file: self.file_path.clone(),
            line: start.row + 1,
            column: start.column,
            hof_name,
            receiver,
            kind,
            callback_params,
            containing_function: containing_function.clone(),
            node_id: node.id(),
        })
    }

    /// Splits a callee into `(method_name, receiver_text)`.
    ///
    /// Member-access callees yield the property name plus the full source text
    /// of the object expression (e.g. `req.body.items`). Bare identifiers yield
    /// the name with no receiver. Other callee shapes yield `None`.
    fn extract_hof_and_receiver(
        &self,
        func_node: tree_sitter::Node,
    ) -> Option<(String, Option<String>)> {
        match func_node.kind() {
            "member_expression" | "field_expression" | "attribute" | "selector_expression" => {
                // Method name: the property field, else the last named child.
                let method_node = func_node
                    .child_by_field_name(self.semantics.property_field)
                    .or_else(|| {
                        func_node
                            .named_child_count()
                            .checked_sub(1)
                            .and_then(|last| func_node.named_child(last))
                    })?;
                let method_name = self.node_text(method_node)?.to_string();
                // Receiver: the object field, else the first named child.
                let receiver = func_node
                    .child_by_field_name(self.semantics.object_field)
                    .or_else(|| func_node.named_child(0))
                    .and_then(|n| self.node_text(n))
                    .map(String::from);
                Some((method_name, receiver))
            }
            "identifier" => Some((self.node_text(func_node)?.to_string(), None)),
            _ => None,
        }
    }

    /// Parameter names of the first function-literal argument of `call_node`.
    /// Later function arguments are ignored.
    fn extract_callback_params(&self, call_node: tree_sitter::Node) -> Option<Vec<String>> {
        let args_node = call_node.child_by_field_name(self.semantics.arguments_field)?;
        let mut cursor = args_node.walk();
        let callback = args_node.named_children(&mut cursor).find(|arg| {
            matches!(
                arg.kind(),
                "arrow_function" | "function_expression" | "function" | "lambda"
            )
        })?;
        self.extract_function_params(callback)
    }

    /// Positional parameter names of a function literal; `None` when none are found.
    ///
    /// Entries stay aligned with parameter positions because callers index
    /// into the result with [`CallbackKind::tainted_param_index`].
    fn extract_function_params(&self, func_node: tree_sitter::Node) -> Option<Vec<String>> {
        let Some(params_node) = func_node.child_by_field_name(self.semantics.parameters_field)
        else {
            // tree-sitter-javascript/typescript store the lone parameter of
            // `x => ...` under the singular `parameter` field.
            let single = func_node.child_by_field_name("parameter")?;
            return Some(vec![self.node_text(single)?.to_string()]);
        };
        let mut cursor = params_node.walk();
        let params: Vec<String> = params_node
            .named_children(&mut cursor)
            // Comments and Python's `*` / `/` markers are not positional parameters.
            .filter(|param| {
                !matches!(
                    param.kind(),
                    "comment" | "keyword_separator" | "positional_separator"
                )
            })
            .filter_map(|param| self.param_name(param))
            .collect();
        (!params.is_empty()).then_some(params)
    }

    /// Best-effort name for one parameter node.
    ///
    /// Simple parameters yield their binding name. Patterns without a single
    /// name (destructuring, rest/splat, ...) yield their trimmed source text as
    /// a placeholder so later parameters keep their positions. Returns `None`
    /// only when the node text is not valid UTF-8.
    fn param_name(&self, param: tree_sitter::Node) -> Option<String> {
        let name = match param.kind() {
            "identifier" => self.node_text(param)?,
            "formal_parameter" | "required_parameter" | "optional_parameter" | "parameter" => {
                match param.child_by_field_name(self.semantics.name_field) {
                    Some(name_node) => self.node_text(name_node)?,
                    // No name field: keep the text before any type annotation
                    // or default value, without an optional-marker `?`.
                    None => {
                        let raw = self.node_text(param)?;
                        raw.split(|c: char| c == ':' || c == '=')
                            .next()
                            .unwrap_or(raw)
                            .trim()
                            .trim_end_matches('?')
                    }
                }
            }
            "assignment_pattern" | "default_parameter" => {
                // JS `assignment_pattern` uses `left`; Python `default_parameter`
                // uses `name`.
                let binding = param
                    .child_by_field_name(self.semantics.left_field)
                    .or_else(|| param.child_by_field_name(self.semantics.name_field));
                match binding {
                    Some(node) => self.node_text(node)?,
                    None => self.node_text(param)?.trim(),
                }
            }
            _ => self.node_text(param)?.trim(),
        };
        Some(name.to_string())
    }

    /// Text of a definition's name field, if it has one.
    fn extract_function_name(&self, node: tree_sitter::Node) -> Option<String> {
        node.child_by_field_name(self.semantics.name_field)
            .and_then(|n| self.node_text(n))
            .map(String::from)
    }

    /// Source text spanned by `node`, or `None` if it is not valid UTF-8.
    fn node_text(&self, node: tree_sitter::Node<'_>) -> Option<&'a str> {
        node.utf8_text(self.source).ok()
    }
}
/// Returns the callback parameter names that receive taint from `tainted_vars`.
///
/// Registers clones of `callback_sites` in a temporary [`CallbackRegistry`]
/// seeded with `tainted_vars`, computes its flows, and reports the targeted
/// parameter names from [`CallbackRegistry::tainted_callback_params`].
pub fn propagate_callback_taint(
    tainted_vars: &HashSet<String>,
    callback_sites: &[CallbackSite],
) -> HashSet<String> {
    let mut scratch = CallbackRegistry::with_tainted_vars(tainted_vars.clone());
    callback_sites
        .iter()
        .cloned()
        .for_each(|site| scratch.register_callback(site));
    scratch.compute_taint_flows();
    scratch.tainted_callback_params()
}
/// Convenience entry point: runs a [`CallbackAnalyzer`] seeded with
/// `tainted_vars` over `tree` and returns the resulting registry.
pub fn analyze_callback_taint(
    tree: &tree_sitter::Tree,
    source: &[u8],
    file_path: PathBuf,
    tainted_vars: HashSet<String>,
    semantics: &'static LanguageSemantics,
) -> CallbackRegistry {
    CallbackAnalyzer::with_tainted_vars(semantics, source, file_path, tainted_vars).analyze(tree)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `test.js` site at column 0 with a receiver and the given attributes.
    fn make_site(
        line: usize,
        hof_name: &str,
        receiver: &str,
        kind: CallbackKind,
        params: &[&str],
        containing_function: Option<&str>,
        node_id: usize,
    ) -> CallbackSite {
        CallbackSite {
            file: PathBuf::from("test.js"),
            line,
            column: 0,
            hof_name: hof_name.to_string(),
            receiver: Some(receiver.to_string()),
            kind,
            callback_params: params.iter().map(|p| p.to_string()).collect(),
            containing_function: containing_function.map(str::to_string),
            node_id,
        }
    }

    #[test]
    fn test_callback_patterns_js() {
        let patterns = CallbackPatterns::for_language(Language::JavaScript);
        let expectations = [
            ("map", Some(CallbackKind::ArrayIterator)),
            ("filter", Some(CallbackKind::ArrayIterator)),
            ("forEach", Some(CallbackKind::ArrayIterator)),
            ("reduce", Some(CallbackKind::ArrayReducer)),
            ("then", Some(CallbackKind::PromiseChain)),
            ("catch", Some(CallbackKind::PromiseChain)),
            ("on", Some(CallbackKind::EventHandler)),
            ("setTimeout", Some(CallbackKind::TimerCallback)),
            ("unknownMethod", None),
        ];
        for (name, expected) in expectations {
            assert_eq!(patterns.classify(name), expected, "classify({name:?})");
        }
    }

    #[test]
    fn test_callback_kind_tainted_param() {
        let expectations = [
            (CallbackKind::ArrayIterator, 0),
            (CallbackKind::ArrayReducer, 1),
            (CallbackKind::PromiseChain, 0),
            (CallbackKind::EventHandler, 0),
            (CallbackKind::TimerCallback, usize::MAX),
        ];
        for (kind, expected) in expectations {
            assert_eq!(kind.tainted_param_index(), expected, "{kind:?}");
        }
    }

    #[test]
    fn test_callback_registry_basic() {
        let mut registry = CallbackRegistry::new();
        registry.add_tainted_var("userInputs".to_string());
        registry.register_callback(make_site(
            10,
            "map",
            "userInputs",
            CallbackKind::ArrayIterator,
            &["item"],
            Some("handler"),
            100,
        ));
        registry.compute_taint_flows();

        let flows = registry.taint_flows();
        assert_eq!(flows.len(), 1);
        assert_eq!(flows[0].target_param, "item");
        assert_eq!(flows[0].confidence, TaintConfidence::Definite);
    }

    #[test]
    fn test_callback_registry_promise_chain() {
        let mut registry = CallbackRegistry::new();
        registry.add_tainted_var("fetchResult".to_string());
        registry.register_callback(make_site(
            15,
            "then",
            "fetchResult",
            CallbackKind::PromiseChain,
            &["response"],
            None,
            200,
        ));
        registry.compute_taint_flows();

        let flows = registry.taint_flows();
        assert_eq!(flows.len(), 1);
        assert_eq!(flows[0].target_param, "response");
        let TaintSource::PromiseResolution(receiver) = &flows[0].taint_source else {
            panic!("Expected PromiseResolution taint source");
        };
        assert_eq!(receiver, "fetchResult");
    }

    #[test]
    fn test_callback_registry_no_taint() {
        let mut registry = CallbackRegistry::new();
        registry.register_callback(make_site(
            10,
            "map",
            "safeArray",
            CallbackKind::ArrayIterator,
            &["item"],
            None,
            100,
        ));
        registry.compute_taint_flows();

        assert!(registry.taint_flows().is_empty());
    }

    #[test]
    fn test_tainted_callback_params() {
        let mut registry = CallbackRegistry::new();
        for var in ["taintedArray", "taintedPromise"] {
            registry.add_tainted_var(var.to_string());
        }
        registry.register_callback(make_site(
            10,
            "forEach",
            "taintedArray",
            CallbackKind::ArrayIterator,
            &["item", "index"],
            None,
            100,
        ));
        registry.register_callback(make_site(
            20,
            "then",
            "taintedPromise",
            CallbackKind::PromiseChain,
            &["result"],
            None,
            200,
        ));
        registry.compute_taint_flows();

        let tainted_params = registry.tainted_callback_params();
        assert!(tainted_params.contains("item"));
        assert!(tainted_params.contains("result"));
        assert!(!tainted_params.contains("index"));
    }

    #[test]
    fn test_propagate_callback_taint() {
        let tainted: HashSet<String> = std::iter::once("userInputs".to_string()).collect();
        let callbacks = [make_site(
            10,
            "map",
            "userInputs",
            CallbackKind::ArrayIterator,
            &["x"],
            None,
            100,
        )];

        let tainted_params = propagate_callback_taint(&tainted, &callbacks);
        assert!(tainted_params.contains("x"));
    }

    #[test]
    fn test_callback_site_indexing() {
        let mut registry = CallbackRegistry::new();
        for site in [
            make_site(10, "map", "arr1", CallbackKind::ArrayIterator, &["x"], None, 100),
            make_site(20, "filter", "arr1", CallbackKind::ArrayIterator, &["y"], None, 200),
            make_site(30, "map", "arr2", CallbackKind::ArrayIterator, &["z"], None, 300),
        ] {
            registry.register_callback(site);
        }

        assert_eq!(registry.callbacks_for_receiver("arr1").len(), 2);
        assert_eq!(registry.callbacks_for_receiver("arr2").len(), 1);
        assert_eq!(registry.callbacks_for_hof("map").len(), 2);
        assert_eq!(registry.callbacks_for_hof("filter").len(), 1);
    }
}