use std::collections::hash_map::Entry;
use rustc_hash::FxHashMap;
use oxc_allocator::{Address, GetAddress, Vec as ArenaVec};
use oxc_ast::{AstBuilder, ast::*};
/// Store of statements that are waiting to be inserted next to other statements.
///
/// Pending insertions are grouped by the [`Address`] of the statement they are attached to.
pub struct StatementInjectorStore<'a> {
/// Pending insertions for each target statement, keyed by the target's [`Address`].
insertions: FxHashMap<Address, Vec<AdjacentStatement<'a>>>,
}
/// Side of the target statement on which a pending statement is to be placed.
#[derive(Debug)]
enum Direction {
/// Place the statement ahead of the target statement.
Before,
/// Place the statement behind the target statement.
After,
}
/// A statement queued for insertion, together with the side of the target it belongs on.
#[derive(Debug)]
struct AdjacentStatement<'a> {
/// The statement to insert.
stmt: Statement<'a>,
/// Whether `stmt` goes before or after the target statement.
direction: Direction,
}
impl StatementInjectorStore<'_> {
    /// Create a store that holds no pending insertions.
    pub fn new() -> Self {
        let insertions = FxHashMap::default();
        Self { insertions }
    }
}
impl<'a> StatementInjectorStore<'a> {
#[inline]
pub fn insert_before<A: GetAddress>(&mut self, target: &A, stmt: Statement<'a>) {
self.insert_before_address(target.address(), stmt);
}
fn insert_before_address(&mut self, target: Address, stmt: Statement<'a>) {
let adjacent_stmts = self.insertions.entry(target).or_default();
let index = adjacent_stmts
.iter()
.position(|s| matches!(s.direction, Direction::After))
.unwrap_or(adjacent_stmts.len());
adjacent_stmts.insert(index, AdjacentStatement { stmt, direction: Direction::Before });
}
#[inline]
pub fn insert_after<A: GetAddress>(&mut self, target: &A, stmt: Statement<'a>) {
self.insert_after_address(target.address(), stmt);
}
fn insert_after_address(&mut self, target: Address, stmt: Statement<'a>) {
let adjacent_stmts = self.insertions.entry(target).or_default();
adjacent_stmts.push(AdjacentStatement { stmt, direction: Direction::After });
}
#[inline]
pub fn insert_many_before<A, S>(&mut self, target: &A, stmts: S)
where
A: GetAddress,
S: IntoIterator<Item = Statement<'a>>,
{
self.insert_many_before_address(target.address(), stmts);
}
fn insert_many_before_address<S>(&mut self, target: Address, stmts: S)
where
S: IntoIterator<Item = Statement<'a>>,
{
let adjacent_stmts = self.insertions.entry(target).or_default();
adjacent_stmts.splice(
0..0,
stmts.into_iter().map(|stmt| AdjacentStatement { stmt, direction: Direction::Before }),
);
}
#[inline]
pub fn insert_many_after<A, S>(&mut self, target: &A, stmts: S)
where
A: GetAddress,
S: IntoIterator<Item = Statement<'a>>,
{
self.insert_many_after_address(target.address(), stmts);
}
fn insert_many_after_address<S>(&mut self, target: Address, stmts: S)
where
S: IntoIterator<Item = Statement<'a>>,
{
let adjacent_stmts = self.insertions.entry(target).or_default();
adjacent_stmts.extend(
stmts.into_iter().map(|stmt| AdjacentStatement { stmt, direction: Direction::After }),
);
}
#[inline]
pub fn move_insertions<A1: GetAddress, A2: GetAddress>(
&mut self,
old_target: &A1,
new_target: &A2,
) {
self.move_insertions_address(old_target.address(), new_target.address());
}
fn move_insertions_address(&mut self, old_address: Address, new_address: Address) {
let Some(mut adjacent_stmts) = self.insertions.remove(&old_address) else { return };
match self.insertions.entry(new_address) {
Entry::Occupied(entry) => {
entry.into_mut().append(&mut adjacent_stmts);
}
Entry::Vacant(entry) => {
entry.insert(adjacent_stmts);
}
}
}
}
impl<'a> StatementInjectorStore<'a> {
    /// Splice the pending insertions that target statements in `statements` into that list.
    ///
    /// Every statement in `statements` that has queued insertions is replaced by its
    /// "before" statements, the statement itself, then its "after" statements. Consumed
    /// insertions are removed from the store. `ast` is used to allocate the rebuilt list.
    pub(crate) fn insert_into_statements(
        &mut self,
        statements: &mut ArenaVec<'a, Statement<'a>>,
        ast: AstBuilder<'a>,
    ) {
        if self.insertions.is_empty() {
            return;
        }

        // Count the statements to be added, so the new list can be allocated once.
        let extra: usize = statements
            .iter()
            .map(|stmt| self.insertions.get(&stmt.address()).map_or(0, Vec::len))
            .sum();
        if extra == 0 {
            return;
        }

        let mut merged = ast.vec_with_capacity(statements.len() + extra);
        for stmt in statements.drain(..) {
            let Some(adjacent) = self.insertions.remove(&stmt.address()) else {
                merged.push(stmt);
                continue;
            };

            // Entries ahead of the first `After` entry go before the statement;
            // the first `After` entry and everything behind it go after the statement.
            let leading = adjacent
                .iter()
                .position(|s| matches!(s.direction, Direction::After))
                .unwrap_or(adjacent.len());
            let mut adjacent = adjacent.into_iter();
            merged.extend(adjacent.by_ref().take(leading).map(|s| s.stmt));
            merged.push(stmt);
            merged.extend(adjacent.map(|s| s.stmt));
        }

        *statements = merged;
    }

    /// Assert, in builds with debug assertions enabled, that no insertions are left in the store.
    // `debug_assert!` is compiled out when debug assertions are disabled, so this is
    // always inlined.
    #[expect(clippy::inline_always)]
    #[inline(always)]
    pub(crate) fn assert_no_insertions_remaining(&self) {
        debug_assert!(self.insertions.is_empty());
    }
}