use crate::flow::symbol_table::{SymbolInfo, SymbolTable, ValueOrigin};
use crate::semantics::LanguageSemantics;
use std::collections::{HashMap, HashSet, VecDeque};
/// Opaque handle identifying one abstract location stored in a [`PointsToGraph`].
///
/// The wrapped index is public, so `LocationId(n)` and `LocationId::new(n)`
/// denote the same handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LocationId(pub usize);

impl LocationId {
    /// Wraps the raw index `id` in a `LocationId`.
    pub fn new(id: usize) -> Self {
        LocationId(id)
    }
}
/// An abstract memory location that a variable may point to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Location {
/// A value attributed to a specific allocation site in the source.
Alloc(AllocationSite),
/// The `index`-th parameter of the function named `func_name`.
Parameter { func_name: String, index: usize },
/// The value returned by a call to `func_name`.
ReturnValue { func_name: String },
/// A location about which nothing further is known.
Unknown,
/// The field named `field` of the object stored at location `base`.
Field { base: LocationId, field: String },
}
/// Identifies the place in the source where a value was created.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AllocationSite {
/// Id of the syntax node associated with the allocation.
pub node_id: usize,
/// Source line of the allocation.
// NOTE(review): whether lines are 0- or 1-based is not visible here; confirm against the symbol table.
pub line: usize,
/// What kind of expression produced the value.
pub kind: AllocKind,
}
/// Classifies the expression that produced an [`AllocationSite`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AllocKind {
/// An object literal. The analyzer in this file also uses this kind for
/// every literal initializer, not only object literals.
ObjectLiteral,
/// An array literal.
ArrayLiteral,
/// A constructor invocation.
Constructor,
/// The result of a function call.
FunctionCall,
/// A value brought in by an import.
Import,
/// The producing expression could not be classified.
Unknown,
}
/// A group of variable names that may refer to the same value, together with
/// the abstract locations associated with those variables.
#[derive(Debug, Clone, Default)]
pub struct AliasSet {
    /// Names of the variables that belong to this set.
    variables: HashSet<String>,
    /// Abstract locations recorded for the set.
    locations: HashSet<LocationId>,
    /// A designated member of the set: the first variable added to it, or the
    /// representative inherited from the first non-empty set merged into it.
    /// `None` only while the set has never received a variable.
    representative: Option<String>,
}

impl AliasSet {
    /// Creates an empty alias set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a set containing only `var`, which also becomes its representative.
    pub fn singleton(var: impl Into<String>) -> Self {
        let mut set = Self::new();
        set.add_variable(var);
        set
    }

    /// Adds `var` to the set. The first variable ever added becomes the
    /// representative; later additions leave the representative untouched.
    pub fn add_variable(&mut self, var: impl Into<String>) {
        let var = var.into();
        if self.representative.is_none() {
            self.representative = Some(var.clone());
        }
        self.variables.insert(var);
    }

    /// Records `loc` as a location associated with this set.
    pub fn add_location(&mut self, loc: LocationId) {
        self.locations.insert(loc);
    }

    /// Returns `true` when `var` is a member of the set.
    pub fn contains(&self, var: &str) -> bool {
        self.variables.contains(var)
    }

    /// The member variables.
    pub fn variables(&self) -> &HashSet<String> {
        &self.variables
    }

    /// The locations recorded for the set.
    pub fn locations(&self) -> &HashSet<LocationId> {
        &self.locations
    }

    /// Number of member variables (locations are not counted).
    pub fn len(&self) -> usize {
        self.variables.len()
    }

    /// Returns `true` when the set has no member variables.
    pub fn is_empty(&self) -> bool {
        self.variables.is_empty()
    }

    /// Adds every variable and location of `other` to this set.
    ///
    /// When this set has no representative yet it adopts `other`'s, so a
    /// non-empty set always has one, matching what `add_variable` guarantees.
    /// An existing representative is never replaced.
    pub fn merge(&mut self, other: &AliasSet) {
        if self.representative.is_none() {
            self.representative = other.representative.clone();
        }
        self.variables.extend(other.variables.iter().cloned());
        self.locations.extend(other.locations.iter().copied());
    }

    /// Iterates over the member variables in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = &String> {
        self.variables.iter()
    }
}
/// Points-to graph relating variable names to the abstract locations they may
/// reference, plus the direct `alias = original` assignments between variables.
#[derive(Debug, Clone, Default)]
pub struct PointsToGraph {
    /// Variable name -> locations it may point to.
    points_to: HashMap<String, HashSet<LocationId>>,
    /// Location id -> description of that location.
    locations: HashMap<LocationId, Location>,
    /// Inverse of `points_to`: location -> variables that may point to it.
    reverse_points_to: HashMap<LocationId, HashSet<String>>,
    /// Id that the next `create_location` call will hand out.
    next_location_id: usize,
    /// Alias -> the variables it was directly assigned from.
    direct_aliases: HashMap<String, HashSet<String>>,
    /// Parameter name -> `(call-site variable, argument index)` pairs.
    /// Only written by `add_param_alias`; nothing in this type reads it back.
    param_aliases: HashMap<String, HashSet<(String, usize)>>,
}

impl PointsToGraph {
    /// Creates an empty graph.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `loc` and returns its freshly allocated id. Ids are handed
    /// out sequentially starting from zero.
    pub fn create_location(&mut self, loc: Location) -> LocationId {
        let id = LocationId::new(self.next_location_id);
        self.next_location_id += 1;
        self.locations.insert(id, loc);
        id
    }

    /// Records that `var` may point to `loc`, keeping the reverse index in sync.
    pub fn add_points_to(&mut self, var: impl Into<String>, loc: LocationId) {
        let var = var.into();
        self.points_to.entry(var.clone()).or_default().insert(loc);
        self.reverse_points_to.entry(loc).or_default().insert(var);
    }

    /// Records the assignment `alias = original`.
    ///
    /// Every location `original` is already known to point to is copied to
    /// `alias` immediately; locations `original` gains later are not copied
    /// by this call.
    pub fn add_direct_alias(&mut self, alias: impl Into<String>, original: impl Into<String>) {
        let alias = alias.into();
        let original = original.into();
        self.direct_aliases
            .entry(alias.clone())
            .or_default()
            .insert(original.clone());
        // Clone the set first: `add_points_to` needs `&mut self`.
        if let Some(locs) = self.points_to.get(&original).cloned() {
            for loc in locs {
                self.add_points_to(alias.clone(), loc);
            }
        }
    }

    /// Records that `call_site_var` was passed as argument number `arg_index`
    /// for the parameter named `param`.
    pub fn add_param_alias(
        &mut self,
        param: impl Into<String>,
        call_site_var: impl Into<String>,
        arg_index: usize,
    ) {
        self.param_aliases
            .entry(param.into())
            .or_default()
            .insert((call_site_var.into(), arg_index));
    }

    /// Returns a copy of the locations `var` may point to (empty when unknown).
    pub fn points_to_set(&self, var: &str) -> HashSet<LocationId> {
        self.points_to.get(var).cloned().unwrap_or_default()
    }

    /// Returns a copy of the variables that may point to `loc` (empty when unknown).
    pub fn variables_pointing_to(&self, loc: LocationId) -> HashSet<String> {
        self.reverse_points_to
            .get(&loc)
            .cloned()
            .unwrap_or_default()
    }

    /// Looks up the description of location `id`.
    pub fn get_location(&self, id: LocationId) -> Option<&Location> {
        self.locations.get(&id)
    }

    /// Returns `true` when `var1` and `var2` may refer to the same value.
    ///
    /// A variable always aliases itself, and a direct assignment in either
    /// direction is sufficient. Otherwise, when both variables have a
    /// non-empty points-to set, the answer is whether those sets overlap;
    /// when either set is empty the direct-alias relation is searched
    /// transitively instead.
    pub fn may_alias(&self, var1: &str, var2: &str) -> bool {
        if var1 == var2 || self.are_directly_aliased(var1, var2) {
            return true;
        }
        match (self.points_to.get(var1), self.points_to.get(var2)) {
            (Some(first), Some(second)) if !first.is_empty() && !second.is_empty() => {
                !first.is_disjoint(second)
            }
            _ => self.transitive_alias_check(var1, var2),
        }
    }

    /// Returns `true` when one variable was directly assigned from the other.
    fn are_directly_aliased(&self, var1: &str, var2: &str) -> bool {
        let assigned_from = |alias: &str, original: &str| {
            self.direct_aliases
                .get(alias)
                .is_some_and(|originals| originals.contains(original))
        };
        assigned_from(var1, var2) || assigned_from(var2, var1)
    }

    /// Variables linked to `var` by a single direct assignment, in either
    /// direction: first the variables `var` was assigned from, then the
    /// variables that were assigned from `var`.
    fn direct_neighbors<'g>(&'g self, var: &'g str) -> impl Iterator<Item = &'g String> + 'g {
        let assigned_from = self.direct_aliases.get(var).into_iter().flatten();
        let assigned_to = self
            .direct_aliases
            .iter()
            .filter(move |(_, originals)| originals.contains(var))
            .map(|(alias, _)| alias);
        assigned_from.chain(assigned_to)
    }

    /// Breadth-first search over the undirected direct-alias relation,
    /// starting at `var1`; returns `true` when `var2` is reachable.
    fn transitive_alias_check(&self, var1: &str, var2: &str) -> bool {
        let mut visited: HashSet<String> = HashSet::new();
        let mut queue: VecDeque<String> = VecDeque::new();
        visited.insert(var1.to_string());
        queue.push_back(var1.to_string());

        while let Some(current) = queue.pop_front() {
            if current == var2 {
                return true;
            }
            for neighbor in self.direct_neighbors(&current) {
                if visited.insert(neighbor.clone()) {
                    queue.push_back(neighbor.clone());
                }
            }
        }
        false
    }

    /// Returns every variable that may alias `var`, excluding `var` itself.
    ///
    /// The result is the union of (a) everything reachable from `var` through
    /// direct assignments in either direction and (b) every other variable
    /// that shares at least one location with `var`.
    pub fn aliases_of(&self, var: &str) -> HashSet<String> {
        let mut aliases: HashSet<String> = HashSet::new();
        let mut visited: HashSet<String> = HashSet::new();
        let mut queue: VecDeque<String> = VecDeque::new();
        visited.insert(var.to_string());
        queue.push_back(var.to_string());

        while let Some(current) = queue.pop_front() {
            for neighbor in self.direct_neighbors(&current) {
                if visited.insert(neighbor.clone()) {
                    aliases.insert(neighbor.clone());
                    queue.push_back(neighbor.clone());
                }
            }
        }

        if let Some(locs) = self.points_to.get(var) {
            for loc in locs {
                for sharer in self.reverse_points_to.get(loc).into_iter().flatten() {
                    if sharer.as_str() != var {
                        aliases.insert(sharer.clone());
                    }
                }
            }
        }
        aliases
    }

    /// Partitions every known variable into alias sets.
    ///
    /// Two variables land in the same set when they are connected through
    /// any chain of shared locations and direct assignments. Each set also
    /// collects the locations of its members. The order of the returned sets
    /// follows hash-set iteration order and is therefore unspecified.
    pub fn compute_alias_sets(&self) -> Vec<AliasSet> {
        let all_vars: HashSet<&String> = self
            .points_to
            .keys()
            .chain(self.direct_aliases.keys())
            .chain(self.direct_aliases.values().flatten())
            .collect();

        let mut visited: HashSet<String> = HashSet::new();
        let mut sets = Vec::new();

        for seed in all_vars {
            if !visited.insert(seed.clone()) {
                // Already absorbed into an earlier set.
                continue;
            }
            let mut set = AliasSet::new();
            set.add_variable(seed.clone());
            let mut queue: VecDeque<String> = VecDeque::new();
            queue.push_back(seed.clone());

            while let Some(current) = queue.pop_front() {
                // Variables that share a location with `current`.
                if let Some(locs) = self.points_to.get(&current) {
                    for &loc in locs {
                        set.add_location(loc);
                        for sharer in self.reverse_points_to.get(&loc).into_iter().flatten() {
                            if visited.insert(sharer.clone()) {
                                set.add_variable(sharer.clone());
                                queue.push_back(sharer.clone());
                            }
                        }
                    }
                }
                // Variables linked to `current` by a direct assignment.
                for neighbor in self.direct_neighbors(&current) {
                    if visited.insert(neighbor.clone()) {
                        set.add_variable(neighbor.clone());
                        queue.push_back(neighbor.clone());
                    }
                }
            }
            // The seed is always a member, so the set is never empty here.
            sets.push(set);
        }
        sets
    }

    /// Number of distinct variable names mentioned anywhere in the graph
    /// (points-to keys and both sides of direct assignments).
    pub fn variable_count(&self) -> usize {
        let names: HashSet<&String> = self
            .points_to
            .keys()
            .chain(self.direct_aliases.keys())
            .chain(self.direct_aliases.values().flatten())
            .collect();
        names.len()
    }

    /// Number of locations created so far.
    pub fn location_count(&self) -> usize {
        self.locations.len()
    }
}
#[derive(Debug, Clone)]
pub struct AliasResult {
pub graph: PointsToGraph,
pub alias_sets: Vec<AliasSet>,
pub var_to_set: HashMap<String, usize>,
pub iterations: usize,
}
impl AliasResult {
pub fn may_alias(&self, var1: &str, var2: &str) -> bool {
if var1 == var2 {
return true;
}
if let (Some(&set1), Some(&set2)) = (self.var_to_set.get(var1), self.var_to_set.get(var2))
&& set1 == set2
{
return true;
}
self.graph.may_alias(var1, var2)
}
pub fn aliases_of(&self, var: &str) -> HashSet<String> {
if let Some(&set_idx) = self.var_to_set.get(var)
&& let Some(set) = self.alias_sets.get(set_idx)
{
return set.variables().clone();
}
self.graph.aliases_of(var)
}
pub fn get_alias_set(&self, var: &str) -> Option<&AliasSet> {
self.var_to_set
.get(var)
.and_then(|&idx| self.alias_sets.get(idx))
}
pub fn all_alias_sets(&self) -> &[AliasSet] {
&self.alias_sets
}
}
pub struct AliasAnalyzer<'a> {
symbols: &'a SymbolTable,
semantics: &'static LanguageSemantics,
source: &'a [u8],
tree: &'a tree_sitter::Tree,
}
impl<'a> AliasAnalyzer<'a> {
pub fn new(
symbols: &'a SymbolTable,
semantics: &'static LanguageSemantics,
source: &'a [u8],
tree: &'a tree_sitter::Tree,
) -> Self {
Self {
symbols,
semantics,
source,
tree,
}
}
pub fn analyze(&self) -> AliasResult {
let mut graph = PointsToGraph::new();
let mut iterations = 0;
self.process_symbols(&mut graph);
self.process_calls(&mut graph);
let max_iterations = 100;
loop {
iterations += 1;
if iterations > max_iterations {
break;
}
let changed = self.propagate_aliases(&mut graph);
if !changed {
break;
}
}
let alias_sets = graph.compute_alias_sets();
let mut var_to_set = HashMap::new();
for (idx, set) in alias_sets.iter().enumerate() {
for var in set.variables() {
var_to_set.insert(var.clone(), idx);
}
}
AliasResult {
graph,
alias_sets,
var_to_set,
iterations,
}
}
fn process_symbols(&self, graph: &mut PointsToGraph) {
for (name, info) in self.symbols.iter() {
self.process_symbol(name, info, graph);
}
}
fn process_symbol(&self, name: &str, info: &SymbolInfo, graph: &mut PointsToGraph) {
match &info.initializer {
ValueOrigin::Variable(source_var) => {
graph.add_direct_alias(name, source_var);
}
ValueOrigin::Parameter(idx) => {
let loc = graph.create_location(Location::Parameter {
func_name: String::new(), index: *idx,
});
graph.add_points_to(name, loc);
}
ValueOrigin::FunctionCall(func_name) => {
let loc = graph.create_location(Location::ReturnValue {
func_name: func_name.clone(),
});
graph.add_points_to(name, loc);
}
ValueOrigin::MemberAccess(_path) => {
let loc = graph.create_location(Location::Alloc(AllocationSite {
node_id: info.declaration_node_id,
line: info.line,
kind: AllocKind::Unknown,
}));
graph.add_points_to(name, loc);
}
ValueOrigin::StringConcat(_vars) | ValueOrigin::TemplateLiteral(_vars) => {
let loc = graph.create_location(Location::Alloc(AllocationSite {
node_id: info.declaration_node_id,
line: info.line,
kind: AllocKind::Unknown,
}));
graph.add_points_to(name, loc);
}
ValueOrigin::MethodCall {
method,
receiver,
arguments: _,
} => {
if Self::returns_receiver(method) {
if let Some(recv) = receiver {
graph.add_direct_alias(name, recv);
}
} else {
let loc = graph.create_location(Location::ReturnValue {
func_name: method.clone(),
});
graph.add_points_to(name, loc);
}
}
ValueOrigin::Literal(_) => {
let loc = graph.create_location(Location::Alloc(AllocationSite {
node_id: info.declaration_node_id,
line: info.line,
kind: AllocKind::ObjectLiteral,
}));
graph.add_points_to(name, loc);
}
ValueOrigin::BinaryExpression => {
let loc = graph.create_location(Location::Alloc(AllocationSite {
node_id: info.declaration_node_id,
line: info.line,
kind: AllocKind::Unknown,
}));
graph.add_points_to(name, loc);
}
ValueOrigin::Unknown => {
let loc = graph.create_location(Location::Unknown);
graph.add_points_to(name, loc);
}
}
for reassign in &info.reassignments {
self.process_reassignment(name, reassign, graph);
}
}
fn process_reassignment(&self, name: &str, origin: &ValueOrigin, graph: &mut PointsToGraph) {
match origin {
ValueOrigin::Variable(source_var) => {
graph.add_direct_alias(name, source_var);
}
_ => {}
}
}
fn returns_receiver(method: &str) -> bool {
matches!(
method.to_lowercase().as_str(),
"concat"
| "slice"
| "map"
| "filter"
| "reduce"
| "trim"
| "tolowercase"
| "touppercase"
| "replace"
| "split"
| "join"
| "push"
| "pop"
| "shift"
| "unshift"
| "sort"
| "reverse"
| "fill"
| "copywithin"
)
}
fn process_calls(&self, graph: &mut PointsToGraph) {
let root = self.tree.root_node();
self.walk_for_calls(root, graph);
}
fn walk_for_calls(&self, node: tree_sitter::Node, graph: &mut PointsToGraph) {
if self.semantics.is_call(node.kind()) {
self.process_call_site(node, graph);
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.walk_for_calls(child, graph);
}
}
fn process_call_site(&self, node: tree_sitter::Node, graph: &mut PointsToGraph) {
if let Some(args) = node.child_by_field_name("arguments") {
let mut cursor = args.walk();
for (idx, arg) in args.named_children(&mut cursor).enumerate() {
if (self.semantics.is_identifier(arg.kind()) || arg.kind() == "identifier")
&& let Ok(var_name) = arg.utf8_text(self.source)
{
let func_name = self.extract_callee_name(node).unwrap_or_default();
let param_name = format!("{}$param{}", func_name, idx);
graph.add_param_alias(¶m_name, var_name, idx);
}
}
}
}
fn extract_callee_name(&self, call_node: tree_sitter::Node) -> Option<String> {
let func = call_node
.child_by_field_name("function")
.or_else(|| call_node.child(0))?;
func.utf8_text(self.source).ok().map(String::from)
}
fn propagate_aliases(&self, graph: &mut PointsToGraph) -> bool {
let mut changed = false;
let aliases: Vec<_> = graph
.direct_aliases
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
for (alias, sources) in aliases {
for source in sources {
let source_pts = graph.points_to_set(&source);
for loc in source_pts {
if !graph
.points_to
.get(&alias)
.is_some_and(|s| s.contains(&loc))
{
graph.add_points_to(alias.clone(), loc);
changed = true;
}
}
}
}
changed
}
}
/// Convenience entry point: builds an [`AliasAnalyzer`] from the given
/// inputs and returns the result of running it.
///
/// Note the argument order differs from `AliasAnalyzer::new`: here `tree`
/// and `source` come before `semantics`.
pub fn analyze_aliases(
    symbols: &SymbolTable,
    tree: &tree_sitter::Tree,
    source: &[u8],
    semantics: &'static LanguageSemantics,
) -> AliasResult {
    AliasAnalyzer::new(symbols, semantics, source, tree).analyze()
}
/// Expands a set of tainted variables with everything that aliases them.
///
/// Returns the union of `tainted_vars` and, for each tainted variable, the
/// aliases reported by `alias_result`. The expansion is a single step: aliases
/// of the newly added variables are not looked up again.
pub fn propagate_taint_through_aliases(
    tainted_vars: &HashSet<String>,
    alias_result: &AliasResult,
) -> HashSet<String> {
    tainted_vars
        .iter()
        .flat_map(|var| alias_result.aliases_of(var))
        .chain(tainted_vars.iter().cloned())
        .collect()
}
/// Returns `true` when any variable in `vars` is tainted, either directly or
/// through one of the aliases reported by `alias_result`.
///
/// An empty `vars` slice yields `false`.
pub fn any_tainted_with_aliases(
    vars: &[&str],
    tainted_vars: &HashSet<String>,
    alias_result: &AliasResult,
) -> bool {
    vars.iter().any(|var| {
        tainted_vars.contains(*var) || !alias_result.aliases_of(var).is_disjoint(tainted_vars)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use rma_common::Language;
    use rma_parser::ParserEngine;
    use std::path::Path;

    /// Parses `code` as a JavaScript file named `test.js`.
    fn parse_js(code: &str) -> rma_parser::ParsedFile {
        let parser = ParserEngine::new(rma_common::RmaConfig::default());
        parser
            .parse_file(Path::new("test.js"), code)
            .expect("parse failed")
    }

    /// Parses `code`, builds its symbol table and runs the alias analysis.
    fn analyze_js(code: &str) -> AliasResult {
        let parsed = parse_js(code);
        let symbols = SymbolTable::build(&parsed, Language::JavaScript);
        let semantics = LanguageSemantics::for_language(Language::JavaScript);
        analyze_aliases(&symbols, &parsed.tree, code.as_bytes(), semantics)
    }

    #[test]
    fn test_direct_assignment_alias() {
        let result = analyze_js("\nconst x = getValue();\nconst y = x;\n");

        assert!(result.may_alias("x", "y"), "y = x should create alias");
        assert!(result.may_alias("y", "x"), "alias should be symmetric");
    }

    #[test]
    fn test_no_alias_different_values() {
        let result = analyze_js("\nconst x = getValue1();\nconst y = getValue2();\n");

        assert!(
            !result.may_alias("x", "y"),
            "different values should not alias"
        );
    }

    #[test]
    fn test_transitive_alias() {
        let result = analyze_js("\nconst x = getValue();\nconst y = x;\nconst z = y;\n");

        assert!(result.may_alias("x", "y"));
        assert!(result.may_alias("y", "z"));
        assert!(result.may_alias("x", "z"), "aliasing should be transitive");
    }

    #[test]
    fn test_shared_origin_alias() {
        let result = analyze_js("\nconst obj = getObject();\nconst a = obj;\nconst b = obj;\n");

        assert!(result.may_alias("a", "obj"));
        assert!(result.may_alias("b", "obj"));
        assert!(
            result.may_alias("a", "b"),
            "variables from same origin should alias"
        );
    }

    #[test]
    fn test_alias_set_computation() {
        let result = analyze_js(
            "\nconst x = getValue();\nconst y = x;\nconst a = getOther();\nconst b = a;\n",
        );

        let x_set = result
            .all_alias_sets()
            .iter()
            .find(|set| set.contains("x"));
        assert!(x_set.is_some());
        let x_set = x_set.unwrap();
        assert!(x_set.contains("y"));
        assert!(!x_set.contains("a"));
        assert!(!x_set.contains("b"));
    }

    #[test]
    fn test_taint_propagation_through_aliases() {
        let tainted: HashSet<String> = std::iter::once("x".to_string()).collect();
        let alias_result = analyze_js("\nconst x = userInput;\nconst y = x;\n");

        let expanded_taint = propagate_taint_through_aliases(&tainted, &alias_result);

        assert!(expanded_taint.contains("x"));
        assert!(
            expanded_taint.contains("y"),
            "alias should be tainted when original is tainted"
        );
    }

    #[test]
    fn test_literal_no_alias() {
        let result = analyze_js("\nconst a = \"hello\";\nconst b = \"hello\";\n");

        assert!(!result.may_alias("a", "b"));
    }

    #[test]
    fn test_aliases_of_query() {
        let result = analyze_js("\nconst x = getValue();\nconst y = x;\nconst z = y;\n");

        let x_aliases = result.aliases_of("x");
        assert!(x_aliases.contains("y"));
        assert!(x_aliases.contains("z"));
    }

    #[test]
    fn test_points_to_graph_basics() {
        let mut graph = PointsToGraph::new();
        let shared = graph.create_location(Location::Unknown);
        let separate = graph.create_location(Location::Unknown);
        graph.add_points_to("x", shared);
        graph.add_points_to("y", shared);
        graph.add_points_to("z", separate);

        assert!(graph.may_alias("x", "y"));
        assert!(!graph.may_alias("x", "z"));
        assert!(!graph.may_alias("y", "z"));
    }

    #[test]
    fn test_self_alias() {
        let mut graph = PointsToGraph::new();
        let loc = graph.create_location(Location::Unknown);
        graph.add_points_to("x", loc);

        assert!(graph.may_alias("x", "x"));
    }
}