use crate::flow::cfg::{BasicBlock, BlockId, CFG, Terminator};
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt::Debug;
use std::hash::Hash;
/// Direction in which facts are propagated through the control-flow graph.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Direction {
    /// Facts flow from a block's predecessors into the block.
    Forward,
    /// Facts flow from a block's successors into the block.
    Backward,
}
/// Marker trait for values that can be tracked as dataflow facts.
///
/// A fact must be cloneable, hashable and comparable (it is stored in
/// `HashSet`s), printable for diagnostics, and shareable across threads.
pub trait Fact: Clone + Eq + Hash + Debug + Send + Sync {}

/// Every type that satisfies the bounds is automatically a [`Fact`].
impl<T> Fact for T where T: Clone + Eq + Hash + Debug + Send + Sync {}
/// Outcome of a dataflow analysis: the fact sets computed for every block.
#[derive(Debug)]
pub struct DataflowResult<F>
where
    F: Fact,
{
    /// Facts holding at the start of each block, keyed by block id.
    pub block_entry: HashMap<BlockId, HashSet<F>>,
    /// Facts holding at the end of each block, keyed by block id.
    pub block_exit: HashMap<BlockId, HashSet<F>>,
    /// Number of worklist entries the solver popped while computing the result.
    pub iterations: usize,
}
/// An empty result: no per-block facts recorded and zero iterations.
///
/// Implemented by hand because deriving `Default` would add an unnecessary
/// `F: Default` bound.
impl<F> Default for DataflowResult<F>
where
    F: Fact,
{
    fn default() -> Self {
        let (block_entry, block_exit) = (HashMap::new(), HashMap::new());
        Self {
            block_entry,
            block_exit,
            iterations: 0,
        }
    }
}
impl<F> DataflowResult<F>
where
    F: Fact,
{
    /// Borrows the facts recorded at the entry of `block_id`, if the block is known.
    pub fn at_entry(&self, block_id: BlockId) -> Option<&HashSet<F>> {
        self.block_entry.get(&block_id)
    }

    /// Borrows the facts recorded at the exit of `block_id`, if the block is known.
    pub fn at_exit(&self, block_id: BlockId) -> Option<&HashSet<F>> {
        self.block_exit.get(&block_id)
    }

    /// Returns `true` when `fact` is recorded at the entry of `block_id`.
    ///
    /// Unknown blocks never contain any fact.
    pub fn contains_at_entry(&self, block_id: BlockId, fact: &F) -> bool {
        match self.at_entry(block_id) {
            Some(facts) => facts.contains(fact),
            None => false,
        }
    }

    /// Returns `true` when `fact` is recorded at the exit of `block_id`.
    ///
    /// Unknown blocks never contain any fact.
    pub fn contains_at_exit(&self, block_id: BlockId, fact: &F) -> bool {
        match self.at_exit(block_id) {
            Some(facts) => facts.contains(fact),
            None => false,
        }
    }

    /// Returns `true` when `fact` holds at the entry of the block that `cfg`
    /// maps `node_id` to.
    ///
    /// The lookup is block-granular: the entry set of the enclosing block is
    /// consulted, not a per-node set. Nodes without a block yield `false`.
    pub fn contains_at_node(&self, node_id: usize, fact: &F, cfg: &CFG) -> bool {
        let Some(&block_id) = cfg.node_to_block.get(&node_id) else {
            return false;
        };
        self.contains_at_entry(block_id, fact)
    }

    /// Returns a copy of the entry facts of the block that `cfg` maps
    /// `node_id` to, or an empty set when the node or block is unknown.
    pub fn facts_at_node(&self, node_id: usize, cfg: &CFG) -> HashSet<F> {
        match cfg.node_to_block.get(&node_id) {
            Some(&block_id) => self.entry_facts(block_id),
            None => HashSet::new(),
        }
    }

    /// Returns a copy of the entry facts of `block_id` (empty if unknown).
    pub fn entry_facts(&self, block_id: BlockId) -> HashSet<F> {
        match self.at_entry(block_id) {
            Some(facts) => facts.clone(),
            None => HashSet::new(),
        }
    }

    /// Returns a copy of the exit facts of `block_id` (empty if unknown).
    pub fn exit_facts(&self, block_id: BlockId) -> HashSet<F> {
        match self.at_exit(block_id) {
            Some(facts) => facts.clone(),
            None => HashSet::new(),
        }
    }
}
/// Per-block transfer function used by [`solve`].
///
/// Given the facts flowing into a block, an implementation computes the facts
/// flowing out of it. For a forward analysis the input is the union of the
/// predecessors' exit sets and the output becomes the block's exit set; for a
/// backward analysis the input is the union of the successors' entry sets and
/// the output becomes the block's entry set.
pub trait TransferFunction<F>: Send + Sync
where
    F: Fact,
{
    /// Computes the outgoing facts of `block` from the incoming `input`.
    ///
    /// `cfg`, `source` and `tree` are handed through from the solver unchanged
    /// so the implementation can inspect the surrounding graph and syntax.
    fn transfer(
        &self,
        block: &BasicBlock,
        input: &HashSet<F>,
        cfg: &CFG,
        source: &[u8],
        tree: &tree_sitter::Tree,
    ) -> HashSet<F>;
}
/// Runs a worklist dataflow analysis over `cfg` and returns the per-block facts.
///
/// Facts are merged with set union at control-flow joins and every set starts
/// out empty.
///
/// * `cfg` - control-flow graph to analyse.
/// * `direction` - [`Direction::Forward`] merges predecessor exit sets into a
///   block's entry; [`Direction::Backward`] merges successor entry sets into a
///   block's exit.
/// * `transfer` - per-block transfer function mapping incoming to outgoing facts.
/// * `source`, `tree` - passed through to `transfer` untouched.
///
/// Every block flagged reachable is processed at least once, so facts generated
/// in the middle of the graph are found even when the entry (or exit) blocks
/// themselves produce nothing. Blocks not flagged reachable keep empty sets.
///
/// The loop is capped at `20 * number_of_blocks` worklist pops; when the cap is
/// exceeded a warning is logged and the state computed so far is returned.
/// `iterations` in the result counts worklist pops, including skipped ones.
pub fn solve<F: Fact, T: TransferFunction<F>>(
    cfg: &CFG,
    direction: Direction,
    transfer: &T,
    source: &[u8],
    tree: &tree_sitter::Tree,
) -> DataflowResult<F> {
    // Appends `id` to the worklist unless it is already queued.
    fn enqueue(worklist: &mut VecDeque<BlockId>, queued: &mut HashSet<BlockId>, id: BlockId) {
        if queued.insert(id) {
            worklist.push_back(id);
        }
    }

    let num_blocks = cfg.blocks.len();
    if num_blocks == 0 {
        return DataflowResult::default();
    }

    let mut block_entry: HashMap<BlockId, HashSet<F>> = HashMap::with_capacity(num_blocks);
    let mut block_exit: HashMap<BlockId, HashSet<F>> = HashMap::with_capacity(num_blocks);
    for block in &cfg.blocks {
        block_entry.insert(block.id, HashSet::new());
        block_exit.insert(block.id, HashSet::new());
    }

    let mut worklist: VecDeque<BlockId> = VecDeque::with_capacity(num_blocks);
    let mut in_worklist: HashSet<BlockId> = HashSet::with_capacity(num_blocks);

    // Start from the natural boundary of the analysis so the first pass
    // already sees useful inputs.
    match direction {
        Direction::Forward => enqueue(&mut worklist, &mut in_worklist, cfg.entry),
        Direction::Backward => {
            for block in &cfg.blocks {
                let is_exit = matches!(
                    block.terminator,
                    Terminator::Return | Terminator::Unreachable
                );
                if block.reachable && is_exit {
                    enqueue(&mut worklist, &mut in_worklist, block.id);
                }
            }
        }
    }
    // Then queue every remaining reachable block. Neighbours are only
    // re-queued when a block's output changes, and all sets start empty, so a
    // boundary block producing the empty set would otherwise stop the
    // analysis before any other block's transfer function ran.
    for block in &cfg.blocks {
        if block.reachable {
            enqueue(&mut worklist, &mut in_worklist, block.id);
        }
    }

    let mut iterations = 0;
    let max_iterations = num_blocks.saturating_mul(20);
    while let Some(block_id) = worklist.pop_front() {
        in_worklist.remove(&block_id);
        iterations += 1;
        if iterations > max_iterations {
            tracing::warn!(
                "Dataflow analysis did not converge after {} iterations",
                max_iterations
            );
            break;
        }

        // Ignore ids that do not name a block, and blocks flagged unreachable.
        let Some(block) = cfg.blocks.get(block_id) else {
            continue;
        };
        if !block.reachable {
            continue;
        }

        match direction {
            Direction::Forward => {
                let mut new_entry = HashSet::new();
                for &pred in &block.predecessors {
                    if let Some(pred_exit) = block_exit.get(&pred) {
                        new_entry.extend(pred_exit.iter().cloned());
                    }
                }
                let new_exit = transfer.transfer(block, &new_entry, cfg, source, tree);
                block_entry.insert(block_id, new_entry);

                // Successors only need another look when this block's output moved.
                if block_exit.get(&block_id) != Some(&new_exit) {
                    block_exit.insert(block_id, new_exit);
                    for succ in cfg.successors(block_id) {
                        enqueue(&mut worklist, &mut in_worklist, succ);
                    }
                }
            }
            Direction::Backward => {
                let mut new_exit = HashSet::new();
                for succ in cfg.successors(block_id) {
                    if let Some(succ_entry) = block_entry.get(&succ) {
                        new_exit.extend(succ_entry.iter().cloned());
                    }
                }
                let new_entry = transfer.transfer(block, &new_exit, cfg, source, tree);
                block_exit.insert(block_id, new_exit);

                // Predecessors only need another look when this block's output moved.
                if block_entry.get(&block_id) != Some(&new_entry) {
                    block_entry.insert(block_id, new_entry);
                    for &pred in &block.predecessors {
                        enqueue(&mut worklist, &mut in_worklist, pred);
                    }
                }
            }
        }
    }

    DataflowResult {
        block_entry,
        block_exit,
        iterations,
    }
}
/// Lookup table from tree-sitter node id to the node itself, borrowed from a
/// tree that must outlive the index.
pub struct NodeIndex<'tree> {
    // Keyed by the value returned from `Node::id()`.
    nodes: HashMap<usize, tree_sitter::Node<'tree>>,
}
impl<'tree> NodeIndex<'tree> {
    /// Indexes every node of `tree` by its id.
    ///
    /// The traversal uses an explicit heap-allocated stack instead of
    /// recursion, so very deeply nested syntax trees cannot exhaust the call
    /// stack.
    pub fn build(tree: &'tree tree_sitter::Tree) -> Self {
        let mut nodes = HashMap::new();
        let mut pending = vec![tree.root_node()];
        while let Some(node) = pending.pop() {
            nodes.insert(node.id(), node);
            let mut cursor = node.walk();
            pending.extend(node.children(&mut cursor));
        }
        Self { nodes }
    }

    /// Returns the node with the given id, or `None` if the tree has no such node.
    pub fn get(&self, id: usize) -> Option<tree_sitter::Node<'tree>> {
        self.nodes.get(&id).copied()
    }

    /// Number of indexed nodes.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` when no node has been indexed.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }
}
/// Searches `tree` for the node whose id equals `target_id`.
///
/// This walks the tree on every call and stops at the first match; returns
/// `None` when no node carries that id. The search keeps its own
/// heap-allocated stack rather than recursing, so very deeply nested syntax
/// trees cannot exhaust the call stack.
pub fn find_node_by_id(
    tree: &tree_sitter::Tree,
    target_id: usize,
) -> Option<tree_sitter::Node<'_>> {
    let mut pending = vec![tree.root_node()];
    while let Some(node) = pending.pop() {
        if node.id() == target_id {
            return Some(node);
        }
        let mut cursor = node.walk();
        pending.extend(node.children(&mut cursor));
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use rma_common::Language;
    use rma_parser::ParserEngine;
    use std::path::Path;

    /// Parses `code` as JavaScript with the default configuration.
    fn parse_js(code: &str) -> rma_parser::ParsedFile {
        let config = rma_common::RmaConfig::default();
        let parser = ParserEngine::new(config);
        parser
            .parse_file(Path::new("test.js"), code)
            .expect("parse failed")
    }

    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestFact(String);

    /// Transfer function that passes its input through unchanged.
    struct IdentityTransfer;

    impl TransferFunction<TestFact> for IdentityTransfer {
        fn transfer(
            &self,
            _block: &BasicBlock,
            input: &HashSet<TestFact>,
            _cfg: &CFG,
            _source: &[u8],
            _tree: &tree_sitter::Tree,
        ) -> HashSet<TestFact> {
            input.clone()
        }
    }

    #[test]
    fn test_empty_cfg() {
        let parsed = parse_js("");
        let cfg = CFG::build(&parsed, Language::JavaScript);
        let transfer = IdentityTransfer;
        let result = solve(&cfg, Direction::Forward, &transfer, b"", &parsed.tree);
        // With no facts to propagate nothing ever changes, so no block needs a
        // second visit. Bounding by the block count avoids hard-coding how
        // many blocks the CFG builder emits for an empty program.
        assert!(result.iterations <= cfg.block_count());
    }

    #[test]
    fn test_simple_forward_propagation() {
        let code = "const x = 1; const y = 2;";
        let parsed = parse_js(code);
        let cfg = CFG::build(&parsed, Language::JavaScript);
        let transfer = IdentityTransfer;
        let result = solve(
            &cfg,
            Direction::Forward,
            &transfer,
            code.as_bytes(),
            &parsed.tree,
        );
        assert!(result.iterations < cfg.block_count() * 5);
    }

    #[test]
    fn test_backward_direction() {
        let code = "function f() { return 1; }";
        let parsed = parse_js(code);
        let cfg = CFG::build(&parsed, Language::JavaScript);
        let transfer = IdentityTransfer;
        let result = solve(
            &cfg,
            Direction::Backward,
            &transfer,
            code.as_bytes(),
            &parsed.tree,
        );
        assert!(result.iterations < cfg.block_count() * 5);
    }

    #[test]
    fn test_node_index() {
        let code = "const x = 1;";
        let parsed = parse_js(code);
        let index = NodeIndex::build(&parsed.tree);
        assert!(!index.is_empty());
        let root_id = parsed.tree.root_node().id();
        assert!(index.get(root_id).is_some());
    }

    #[test]
    fn test_find_node_by_id_fallback() {
        let code = "const x = 1;";
        let parsed = parse_js(code);
        let root_id = parsed.tree.root_node().id();
        let found = find_node_by_id(&parsed.tree, root_id);
        assert!(found.is_some());
        let not_found = find_node_by_id(&parsed.tree, usize::MAX);
        assert!(not_found.is_none());
    }

    #[test]
    fn test_dataflow_result_queries() {
        let mut result: DataflowResult<TestFact> = DataflowResult::default();
        result
            .block_entry
            .insert(0, HashSet::from([TestFact("x".to_string())]));
        result
            .block_exit
            .insert(0, HashSet::from([TestFact("y".to_string())]));
        assert!(result.contains_at_entry(0, &TestFact("x".to_string())));
        assert!(!result.contains_at_entry(0, &TestFact("y".to_string())));
        assert!(result.contains_at_exit(0, &TestFact("y".to_string())));
        assert!(!result.contains_at_exit(0, &TestFact("x".to_string())));
        // Unknown blocks never report a fact.
        assert!(!result.contains_at_entry(999, &TestFact("x".to_string())));
    }
}