use crate::ir::model::expr::Expr;
use crate::ir::model::program::BufferDecl;
use rustc_hash::FxHashMap;
use std::cell::RefCell;
/// Opaque handle to an expression stored in an `ExprArena`.
///
/// A handle is only the position of the expression in the arena's storage,
/// so it is cheap to copy, compare for equality, and hash.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ExprRef {
    index: usize,
}

impl ExprRef {
    /// Returns the zero-based storage slot this handle refers to.
    #[must_use]
    #[inline]
    pub fn index(self) -> usize {
        let Self { index } = self;
        index
    }
}
/// Growable store of `Expr` nodes addressed by `ExprRef` handles.
///
/// The expression list sits behind a `RefCell`, so new expressions can be
/// appended through a shared `&ExprArena`. Clearing requires `&mut self`.
#[derive(Default)]
pub struct ExprArena {
    /// Bump allocator; it is reset together with the expression list.
    bump: bumpalo::Bump,
    /// Stored expressions. An `ExprRef`'s index is a position in this vector.
    exprs: RefCell<Vec<Expr>>,
}
impl ExprArena {
    /// Creates an empty arena.
    #[must_use]
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `expr` to the arena and returns a handle to it.
    ///
    /// Handles are issued sequentially: the returned `ExprRef::index` equals
    /// the number of expressions stored before this call (counting from the
    /// last `reset`).
    #[must_use]
    pub fn alloc(&self, expr: Expr) -> ExprRef {
        // Take one mutable borrow for both the length read and the push,
        // rather than a shared borrow followed by a separate mutable one.
        let mut exprs = self.exprs.borrow_mut();
        let index = exprs.len();
        exprs.push(expr);
        ExprRef { index }
    }

    /// Returns a clone of the expression behind `expr_ref`.
    ///
    /// Returns `None` if the handle's index is past the end of the arena.
    /// For example, this happens for a handle issued before a `reset` when
    /// fewer expressions have been allocated since.
    #[must_use]
    pub fn get(&self, expr_ref: ExprRef) -> Option<Expr> {
        self.exprs.borrow().get(expr_ref.index).cloned()
    }

    /// Removes every stored expression and resets the bump allocator.
    ///
    /// This takes `&mut self`, so no shared borrow of the arena (such as an
    /// `ArenaProgram`) can be alive during the reset.
    ///
    /// Handles issued before the reset are not invalidated by the type
    /// system. Their indices are reused by later `alloc` calls, so a stale
    /// `ExprRef` may resolve to a different expression.
    pub fn reset(&mut self) {
        self.exprs.get_mut().clear();
        self.bump.reset();
    }

    /// Number of expressions currently stored.
    #[must_use]
    #[inline]
    pub fn len(&self) -> usize {
        self.exprs.borrow().len()
    }

    /// Returns `true` when the arena holds no expressions.
    #[must_use]
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// A program whose entry expressions are stored in a borrowed `ExprArena`.
///
/// Buffer declarations are owned by the program and indexed by name. Entry
/// expressions are kept as arena handles, in the order they were pushed.
pub struct ArenaProgram<'a> {
    /// Arena that owns every expression referenced by `entry`.
    arena: &'a ExprArena,
    /// Buffer declarations, in the order given to `new`.
    buffers: Vec<BufferDecl>,
    /// Buffer name -> position in `buffers` (first declaration wins).
    buffer_index: FxHashMap<String, usize>,
    /// Workgroup dimensions as passed to `new`.
    workgroup_size: [u32; 3],
    /// Handles of the entry expressions, in push order.
    entry: Vec<ExprRef>,
}
impl<'a> ArenaProgram<'a> {
    /// Builds a program over `arena` with an empty entry list.
    ///
    /// Buffers are indexed by `name` for `buffer` lookups. If several
    /// buffers share a name, the one that appears first in `buffers` is the
    /// one returned by lookups.
    pub(crate) fn new(
        arena: &'a ExprArena,
        buffers: Vec<BufferDecl>,
        workgroup_size: [u32; 3],
    ) -> Self {
        let mut buffer_index =
            FxHashMap::with_capacity_and_hasher(buffers.len(), Default::default());
        for (position, decl) in buffers.iter().enumerate() {
            // `or_insert` leaves an existing entry alone, keeping the earliest
            // position for a duplicated name.
            buffer_index.entry(decl.name.clone()).or_insert(position);
        }
        Self {
            arena,
            buffers,
            buffer_index,
            workgroup_size,
            entry: Vec::new(),
        }
    }

    /// Allocates `expr` in the arena, appends its handle to the entry list,
    /// and returns that handle.
    pub fn push_expr(&mut self, expr: Expr) -> ExprRef {
        let handle = self.arena.alloc(expr);
        self.entry.push(handle);
        handle
    }

    /// Returns a clone of the expression behind `expr_ref`, or `None` if the
    /// arena has no such slot.
    #[must_use]
    pub fn expr(&self, expr_ref: ExprRef) -> Option<Expr> {
        self.arena.get(expr_ref)
    }

    /// All buffer declarations, in declaration order.
    #[must_use]
    pub fn buffers(&self) -> &[BufferDecl] {
        self.buffers.as_slice()
    }

    /// Looks up a buffer declaration by name.
    #[must_use]
    pub fn buffer(&self, name: &str) -> Option<&BufferDecl> {
        let &position = self.buffer_index.get(name)?;
        self.buffers.get(position)
    }

    /// Workgroup dimensions this program was built with.
    #[must_use]
    pub fn workgroup_size(&self) -> [u32; 3] {
        self.workgroup_size
    }

    /// Handles of the entry expressions, in the order they were pushed.
    #[must_use]
    pub fn entry(&self) -> &[ExprRef] {
        self.entry.as_slice()
    }
}
#[cfg(test)]
mod tests {
    use super::{ArenaProgram, ExprArena};
    use crate::ir::model::expr::Expr;
    use crate::ir::model::program::BufferDecl;
    use crate::ir::model::types::DataType;

    /// Handles come out in allocation order and resolve to the stored values.
    #[test]
    fn arena_allocates_stable_expression_refs() {
        let arena = ExprArena::new();
        let literal = arena.alloc(Expr::u32(7));
        let variable = arena.alloc(Expr::var("x"));

        assert_eq!((literal.index(), variable.index()), (0, 1));
        assert_eq!(arena.get(literal), Some(Expr::u32(7)));
        assert_eq!(arena.get(variable), Some(Expr::var("x")));
    }

    /// A program records pushed handles as entries and finds buffers by name.
    #[test]
    fn arena_program_keeps_buffers_and_expression_handles() {
        let arena = ExprArena::new();
        let buffers = vec![BufferDecl::read("input", 0, DataType::U32)];
        let mut program = ArenaProgram::new(&arena, buffers, [64, 1, 1]);

        let load = program.push_expr(Expr::load("input", Expr::u32(0)));

        assert_eq!(program.entry(), &[load]);
        let binding = program.buffer("input").map(BufferDecl::binding);
        assert_eq!(binding, Some(0));
        let expected = Expr::load("input", Expr::u32(0));
        assert_eq!(program.expr(load), Some(expected));
    }
}