use crate::ir::model::arena::{ArenaProgram, ExprArena};
use crate::ir::model::node::Node;
use crate::ir::model::types::{BufferAccess, DataType};
use rustc_hash::FxHashMap;
use std::collections::HashMap;
use std::hash::Hash;
use std::rc::Rc;
use std::sync::Arc;
/// A complete program: its buffer declarations, workgroup size and entry body.
///
/// Prefer [`Program::new`] / [`Program::empty`] over a struct literal:
/// `buffer_index` is derived from `buffers` (computed in `Program::new`) and the
/// two must agree for [`Program::buffer`] / [`Program::has_buffer`] to be correct.
#[derive(Debug, Clone, PartialEq)]
pub struct Program {
/// Optional entry op identifier, attached via `with_entry_op_id`; `None` after
/// `Program::new` / `Program::empty`.
pub entry_op_id: Option<String>,
/// Buffer declarations, in the order they were passed to `Program::new`.
pub buffers: Arc<[BufferDecl]>,
/// Lookup table from buffer name to its position in `buffers`.
/// When a name is declared more than once, the first declaration's position is kept.
/// This is derived data: if `buffers` is replaced directly, rebuild this to match.
#[doc(hidden)]
pub buffer_index: Arc<FxHashMap<String, usize>>,
/// Three-component workgroup size; `Program::empty` uses `[1, 1, 1]`.
/// NOTE(review): presumably the (x, y, z) dispatch dimensions — confirm with the backend.
pub workgroup_size: [u32; 3],
/// Top-level nodes of the entry body. Shared behind an `Arc`; `entry_mut`
/// goes through `Arc::make_mut`, which clones the vector if it is shared.
pub entry: Arc<Vec<Node>>,
}
/// Declaration of a single buffer used by a [`Program`].
///
/// Usually built through the `BufferDecl` constructors (`storage`, `read`,
/// `read_write`, `output`, `uniform`, `workgroup`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BufferDecl {
/// Buffer name; `Program::buffer` looks declarations up by this name.
pub name: String,
/// Binding slot; the `workgroup` constructor always sets it to 0.
pub binding: u32,
/// Access mode (e.g. read-only, read-write, uniform, workgroup).
pub access: BufferAccess,
/// Type of the buffer's elements.
pub element: DataType,
/// Element count. The storage-style constructors set 0; `workgroup` sets the requested count.
/// NOTE(review): 0 presumably means "size not fixed by the declaration" — confirm with consumers.
pub count: u32,
/// Whether this buffer is a program output; only `BufferDecl::output` sets it to `true`.
pub is_output: bool,
}
/// A mapping of bindings from `K` to `V`, held behind an [`Rc`].
///
/// Cloning a `Scope` bumps the reference count and shares the same underlying
/// map rather than copying it.
/// NOTE(review): the scope's methods presumably live in the `impl_scope` submodule.
#[derive(Debug, Clone)]
pub struct Scope<K, V> {
    bindings: Rc<HashMap<K, V>>,
}

impl<K, V> Default for Scope<K, V> {
    /// Returns a scope with no bindings, owning a fresh, unshared map.
    fn default() -> Self {
        let bindings = Rc::default();
        Self { bindings }
    }
}
mod impl_scope;
impl BufferDecl {
    /// Declares a storage buffer named `name` at `binding`.
    ///
    /// The declaration starts with `count == 0` and `is_output == false`; every
    /// other storage-style constructor below funnels through this one.
    #[inline]
    #[must_use]
    pub fn storage(name: &str, binding: u32, access: BufferAccess, element: DataType) -> Self {
        let name = String::from(name);
        Self {
            is_output: false,
            count: 0,
            element,
            access,
            binding,
            name,
        }
    }

    /// Storage buffer with `BufferAccess::ReadOnly` access.
    #[inline]
    #[must_use]
    pub fn read(name: &str, binding: u32, element: DataType) -> Self {
        let access = BufferAccess::ReadOnly;
        Self::storage(name, binding, access, element)
    }

    /// Storage buffer with `BufferAccess::ReadWrite` access.
    #[inline]
    #[must_use]
    pub fn read_write(name: &str, binding: u32, element: DataType) -> Self {
        let access = BufferAccess::ReadWrite;
        Self::storage(name, binding, access, element)
    }

    /// Read-write storage buffer that is additionally flagged as a program output.
    #[inline]
    #[must_use]
    pub fn output(name: &str, binding: u32, element: DataType) -> Self {
        let mut decl = Self::read_write(name, binding, element);
        decl.is_output = true;
        decl
    }

    /// Buffer with `BufferAccess::Uniform` access.
    #[inline]
    #[must_use]
    pub fn uniform(name: &str, binding: u32, element: DataType) -> Self {
        let access = BufferAccess::Uniform;
        Self::storage(name, binding, access, element)
    }

    /// Buffer with `BufferAccess::Workgroup` access holding `count` elements.
    ///
    /// The binding is always 0 and the buffer is never flagged as an output.
    #[inline]
    #[must_use]
    pub fn workgroup(name: &str, count: u32, element: DataType) -> Self {
        Self {
            count,
            ..Self::storage(name, 0, BufferAccess::Workgroup, element)
        }
    }

    /// The buffer's name.
    #[inline]
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// The binding slot.
    #[inline]
    #[must_use]
    pub fn binding(&self) -> u32 {
        self.binding
    }

    /// A copy of the access mode.
    #[inline]
    #[must_use]
    pub fn access(&self) -> BufferAccess {
        self.access.clone()
    }

    /// A copy of the element type.
    #[inline]
    #[must_use]
    pub fn element(&self) -> DataType {
        self.element.clone()
    }

    /// The declared element count (0 for the storage-style constructors).
    #[inline]
    #[must_use]
    pub fn count(&self) -> u32 {
        self.count
    }

    /// Whether the buffer is flagged as a program output.
    #[inline]
    #[must_use]
    pub fn is_output(&self) -> bool {
        self.is_output
    }
}
impl Program {
    /// Creates a program from its buffer declarations, workgroup size and entry body.
    ///
    /// The name → position table used by [`Program::buffer`] and
    /// [`Program::has_buffer`] is computed here; if a name is declared more than
    /// once, lookups resolve to its first declaration. No entry op id is attached.
    #[inline]
    #[must_use]
    pub fn new(buffers: Vec<BufferDecl>, workgroup_size: [u32; 3], entry: Vec<Node>) -> Self {
        let lookup = Self::build_buffer_index(&buffers);
        Self {
            entry_op_id: None,
            buffers: buffers.into(),
            buffer_index: Arc::new(lookup),
            workgroup_size,
            entry: Arc::new(entry),
        }
    }

    /// Starts an [`ArenaProgram`] that borrows `arena` for its lifetime.
    #[inline]
    #[must_use]
    pub fn with_arena<'a>(
        arena: &'a ExprArena,
        buffers: Vec<BufferDecl>,
        workgroup_size: [u32; 3],
    ) -> ArenaProgram<'a> {
        ArenaProgram::new(arena, buffers, workgroup_size)
    }

    /// A program with no buffers, an empty entry body and workgroup size `[1, 1, 1]`.
    #[inline]
    #[must_use]
    pub fn empty() -> Self {
        Self::new(Vec::new(), [1, 1, 1], Vec::new())
    }

    /// Builder-style setter that attaches an entry op id.
    #[inline]
    #[must_use]
    pub fn with_entry_op_id(self, op_id: impl Into<String>) -> Self {
        self.with_optional_entry_op_id(Some(op_id.into()))
    }

    /// The attached entry op id, if any.
    #[inline]
    #[must_use]
    pub fn entry_op_id(&self) -> Option<&str> {
        self.entry_op_id.as_deref()
    }

    /// Replaces the entry op id with `op_id`, clearing it when `None`.
    #[inline]
    #[must_use]
    pub(crate) fn with_optional_entry_op_id(mut self, op_id: Option<String>) -> Self {
        self.entry_op_id = op_id;
        self
    }

    /// Looks up a buffer declaration by name.
    #[inline]
    #[must_use]
    pub fn buffer(&self, name: &str) -> Option<&BufferDecl> {
        let position = *self.buffer_index.get(name)?;
        self.buffers.get(position)
    }

    /// All buffer declarations, in declaration order.
    #[inline]
    #[must_use]
    pub fn buffers(&self) -> &[BufferDecl] {
        &self.buffers
    }

    /// The three-component workgroup size.
    #[inline]
    #[must_use]
    pub fn workgroup_size(&self) -> [u32; 3] {
        self.workgroup_size
    }

    /// Overwrites the workgroup size.
    #[inline]
    pub fn set_workgroup_size(&mut self, workgroup_size: [u32; 3]) {
        self.workgroup_size = workgroup_size;
    }

    /// The top-level nodes of the entry body.
    #[inline]
    #[must_use]
    pub fn entry(&self) -> &[Node] {
        &self.entry
    }

    /// Mutable access to the entry body.
    ///
    /// Goes through `Arc::make_mut`, so if the body is shared with another
    /// `Program` clone, it is cloned first and only this program sees the edits.
    #[inline]
    #[must_use]
    pub fn entry_mut(&mut self) -> &mut Vec<Node> {
        Arc::make_mut(&mut self.entry)
    }

    /// Whether a buffer with this name is declared (checked via the lookup table).
    #[inline]
    #[must_use]
    pub fn has_buffer(&self, name: &str) -> bool {
        self.buffer_index.contains_key(name)
    }

    /// Number of buffer declarations, duplicates included.
    #[inline]
    #[must_use]
    pub fn buffer_count(&self) -> usize {
        self.buffers().len()
    }

    /// Maps every buffer name to the position of its first declaration.
    #[inline]
    fn build_buffer_index(buffers: &[BufferDecl]) -> FxHashMap<String, usize> {
        // Walk the declarations back to front: a later insert overwrites an
        // earlier one, so the smallest position of a duplicated name is written
        // last and wins, i.e. the first declaration takes precedence.
        buffers
            .iter()
            .enumerate()
            .rev()
            .map(|(position, decl)| (decl.name.clone(), position))
            .collect()
    }
}