#![allow(unsafe_code)]
use crate::ir_inner::model::expr::Expr;
use crate::ir_inner::model::program::BufferDecl;
use bumpalo::Bump;
use rustc_hash::FxHashMap;
use std::cell::{Cell, UnsafeCell};
use std::sync::Arc;
/// Copyable handle that names one expression stored in an `ExprArena`.
///
/// A handle is nothing more than a zero-based slot number; resolving it back
/// to an expression is the arena's job.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ExprRef {
    /// Zero-based slot number of the referenced expression.
    index: usize,
}

impl ExprRef {
    /// Returns the zero-based slot number carried by this handle.
    #[inline]
    #[must_use]
    pub fn index(self) -> usize {
        let Self { index } = self;
        index
    }
}
/// Append-only arena that owns [`Expr`] values and hands out [`ExprRef`] handles.
///
/// Expressions are placed in a [`Bump`] allocator and can be added through a
/// shared reference. `Bump` never runs destructors for the values it stores,
/// so the arena keeps a pointer to every expression and drops each one itself
/// in [`ExprArena::reset`] and when the arena is dropped. Without that, any
/// heap data owned by an `Expr` would leak.
///
/// The interior mutability (`UnsafeCell`, `Cell`) makes this type `!Sync`, so
/// all access through `&self` happens on a single thread.
#[derive(Default)]
pub struct ExprArena {
    /// Backing storage for every expression allocated so far.
    bump: Bump,
    /// Pointer to each live expression; `ExprRef::index` indexes this list.
    ///
    /// Wrapped in `UnsafeCell` so that `alloc` can append through `&self`.
    exprs: UnsafeCell<Vec<std::ptr::NonNull<Expr>>>,
    /// Number of live expressions; kept equal to the length of `exprs`.
    len: Cell<usize>,
}

impl ExprArena {
    /// Creates an empty arena.
    #[must_use]
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Moves `expr` into the arena and returns a handle to it.
    ///
    /// Handles are numbered in allocation order, starting at zero (and again
    /// from zero after a [`reset`](Self::reset)).
    #[must_use]
    pub fn alloc(&self, expr: Expr) -> ExprRef {
        let index = self.len.get();
        // Convert the `&mut Expr` into a raw pointer right away so that no
        // unique reference outlives this statement.
        let ptr = std::ptr::NonNull::from(self.bump.alloc(expr));
        // SAFETY: the arena is `!Sync`, no reference to the pointer list ever
        // escapes a method of this type, and `Vec::push` does not call back
        // into the arena, so this is the only live reference to the list.
        let exprs = unsafe { &mut *self.exprs.get() };
        exprs.push(ptr);
        self.len.set(index + 1);
        ExprRef { index }
    }

    /// Resolves a handle to the expression it names.
    ///
    /// Returns `None` when the handle's index is not currently occupied, for
    /// example because the arena was reset after the handle was issued.
    #[must_use]
    pub fn get(&self, expr_ref: ExprRef) -> Option<&Expr> {
        // SAFETY: same argument as in `alloc`; this shared borrow of the list
        // ends before the function returns.
        let exprs = unsafe { &*self.exprs.get() };
        let ptr = *exprs.get(expr_ref.index)?;
        // SAFETY: every tracked pointer came from `Bump::alloc` on `self.bump`
        // and stays valid until `reset` or `Drop`, both of which need
        // exclusive access and therefore cannot overlap the `&self` borrow
        // that the returned reference is tied to.
        Some(unsafe { &*ptr.as_ptr() })
    }

    /// Drops every stored expression and releases the arena's memory for reuse.
    ///
    /// Handles issued before the reset must not be relied upon afterwards:
    /// they either resolve to `None` or to an unrelated, newer expression.
    pub fn reset(&mut self) {
        // Destructors must run before the bump storage is recycled.
        self.drop_exprs();
        self.bump.reset();
    }

    /// Returns the number of expressions currently stored.
    #[must_use]
    #[inline]
    pub fn len(&self) -> usize {
        self.len.get()
    }

    /// Returns `true` when the arena holds no expressions.
    #[must_use]
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Runs the destructor of every tracked expression and forgets them all.
    fn drop_exprs(&mut self) {
        self.len.set(0);
        // `drain` removes each pointer from the list before it is dropped, so
        // a panicking destructor can never lead to a double drop later on.
        for ptr in self.exprs.get_mut().drain(..) {
            // SAFETY: `ptr` was produced by `Bump::alloc`, has not been
            // dropped yet (the list is the only record of it), its backing
            // memory has not been reset, and `&mut self` guarantees that no
            // reference handed out by `get` is still alive.
            unsafe { std::ptr::drop_in_place(ptr.as_ptr()) };
        }
    }
}

impl Drop for ExprArena {
    /// Drops the stored expressions; the bump memory is freed afterwards when
    /// the `bump` field itself is dropped.
    fn drop(&mut self) {
        self.drop_exprs();
    }
}
pub struct ArenaProgram<'a> {
arena: &'a ExprArena,
buffers: Vec<BufferDecl>,
buffer_index: FxHashMap<Arc<str>, usize>,
workgroup_size: [u32; 3],
entry: Vec<ExprRef>,
}
impl<'a> ArenaProgram<'a> {
pub(crate) fn new(
arena: &'a ExprArena,
buffers: Vec<BufferDecl>,
workgroup_size: [u32; 3],
) -> Self {
let mut buffer_index = FxHashMap::default();
buffer_index.reserve(buffers.len());
for (index, buffer) in buffers.iter().enumerate() {
buffer_index
.entry(Arc::clone(&buffer.name))
.or_insert(index);
}
Self {
arena,
buffers,
buffer_index,
workgroup_size,
entry: Vec::new(),
}
}
#[must_use]
pub fn push_expr(&mut self, expr: Expr) -> ExprRef {
let expr_ref = self.arena.alloc(expr);
self.entry.push(expr_ref);
expr_ref
}
#[must_use]
pub fn expr(&self, expr_ref: ExprRef) -> Option<&Expr> {
self.arena.get(expr_ref)
}
#[must_use]
pub fn buffers(&self) -> &[BufferDecl] {
&self.buffers
}
#[must_use]
pub fn buffer(&self, name: &str) -> Option<&BufferDecl> {
self.buffer_index
.get(name)
.and_then(|&index| self.buffers.get(index))
}
#[must_use]
pub fn workgroup_size(&self) -> [u32; 3] {
self.workgroup_size
}
#[must_use]
pub fn entry(&self) -> &[ExprRef] {
&self.entry
}
}
#[cfg(test)]
mod tests {
    use super::{ArenaProgram, ExprArena};
    use crate::ir_inner::model::expr::Expr;
    use crate::ir_inner::model::program::BufferDecl;
    use crate::ir_inner::model::types::DataType;

    /// Handles are numbered in allocation order and resolve to the
    /// expression they were issued for.
    #[test]
    fn arena_allocates_stable_expression_refs() {
        let arena = ExprArena::new();
        let literal = arena.alloc(Expr::u32(7));
        let variable = arena.alloc(Expr::var("x"));

        // Indices follow allocation order, starting at zero.
        assert_eq!((literal.index(), variable.index()), (0, 1));

        // Each handle still resolves to its own expression.
        assert_eq!(arena.get(literal), Some(&Expr::u32(7)));
        assert_eq!(arena.get(variable), Some(&Expr::var("x")));
    }

    /// A program exposes its buffers by name and its pushed expressions by
    /// handle.
    #[test]
    fn arena_program_keeps_buffers_and_expression_handles() {
        let arena = ExprArena::new();
        let declared = vec![BufferDecl::read("input", 0, DataType::U32)];
        let mut program = ArenaProgram::new(&arena, declared, [64, 1, 1]);

        let handle = program.push_expr(Expr::load("input", Expr::u32(0)));

        // The pushed handle is the only entry.
        assert_eq!(program.entry(), &[handle]);
        // The buffer is reachable by name.
        assert_eq!(program.buffer("input").map(BufferDecl::binding), Some(0));
        // The handle resolves to an equal expression.
        assert_eq!(
            program.expr(handle),
            Some(&Expr::load("input", Expr::u32(0)))
        );
    }
}