mod attrs;
mod constant;
mod error;
#[allow(clippy::module_inception)]
mod eval;
mod exec;
mod lower;
mod prelude;
mod to_expr;
mod ty;
pub use attrs::*;
pub(crate) use constant::*;
pub use error::*;
pub use eval::*;
pub use exec::*;
pub use lower::*;
pub use prelude::*;
pub use to_expr::*;
pub use ty::*;
pub use wgsl_types::{ShaderStage, builtin::*, conv::*, inst::*, tplt::*, ty::*};
use derive_more::Display;
use std::{collections::HashMap, rc::Rc};
use wgsl_parse::{
span::{Span, Spanned},
syntax::*,
};
use crate::syntax_util::IteratorExt;
/// One level of a scope chain: the bindings declared at this level plus a
/// shared link to the enclosing level.
#[derive(Debug)]
struct ScopeInner<T> {
/// Bindings declared directly at this level, keyed by name.
local: HashMap<String, T>,
/// Enclosing level, or `None` for the outermost one. `get`, `contains`,
/// `depth` and `iter_keys` follow this link.
parent: Option<Rc<ScopeInner<T>>>,
/// When `true`, `Scope::local_contains` also looks at the bindings declared
/// directly in `parent`.
transparent: bool,
}
/// Handle to the innermost level of a chain of nested scopes mapping names to
/// values of type `T`.
#[derive(Debug)]
pub struct Scope<T> {
/// Innermost level; `push`/`pop` replace this pointer, and mutation goes
/// through `Rc::get_mut`, so it has to be uniquely owned when edited.
inner: Rc<ScopeInner<T>>,
}
impl<T> Default for Scope<T> {
fn default() -> Self {
Self {
inner: Rc::new(ScopeInner {
local: Default::default(),
parent: Default::default(),
transparent: false,
}),
}
}
}
impl<T> ScopeInner<T> {
    /// Number of enclosing levels above this one; the outermost level has
    /// depth 0.
    pub fn depth(&self) -> u32 {
        let mut levels = 0;
        let mut ancestor = self.parent.as_deref();
        while let Some(scope) = ancestor {
            levels += 1;
            ancestor = scope.parent.as_deref();
        }
        levels
    }

    /// Looks `name` up at this level first, then in each enclosing level from
    /// the inside out, and returns the first binding found.
    pub fn get(&self, name: &str) -> Option<&T> {
        let mut scope = self;
        loop {
            if let Some(value) = scope.local.get(name) {
                return Some(value);
            }
            // Running out of parents means the name is not bound anywhere.
            scope = scope.parent.as_deref()?;
        }
    }

    /// Returns `true` if `name` is bound at this level or at any enclosing
    /// level.
    pub fn contains(&self, name: &str) -> bool {
        let mut scope = self;
        loop {
            if scope.local.contains_key(name) {
                return true;
            }
            match scope.parent.as_deref() {
                Some(outer) => scope = outer,
                None => return false,
            }
        }
    }

    /// Iterates over the names bound at this level, followed by the names of
    /// every enclosing level. A name bound at several levels is yielded once
    /// per level.
    pub fn iter_keys(&self) -> impl Iterator<Item = &str> {
        // The parent's iterator is boxed so the returned type does not have
        // to contain itself.
        let inherited = self
            .parent
            .iter()
            .flat_map(|outer| outer.iter_keys())
            .boxed();
        self.local.keys().map(String::as_str).chain(inherited)
    }
}
impl<T> Scope<T> {
pub fn new() -> Self {
Self {
inner: Rc::new(ScopeInner {
local: Default::default(),
parent: None,
transparent: false,
}),
}
}
pub fn is_root(&self) -> bool {
self.inner.parent.is_none()
}
pub fn depth(&self) -> u32 {
self.inner.depth()
}
pub fn make_transparent(&mut self) {
Rc::get_mut(&mut self.inner)
.expect("cannot edit a parent scope")
.transparent = true;
}
pub fn push(&mut self) {
self.inner = Rc::new(ScopeInner {
local: Default::default(),
parent: Some(self.inner.clone()),
transparent: false,
});
}
pub fn push_transparent(&mut self) {
self.inner = Rc::new(ScopeInner {
local: Default::default(),
parent: Some(self.inner.clone()),
transparent: true,
});
}
pub fn pop(&mut self) {
self.inner = self.inner.parent.clone().expect("failed to pop scope");
}
pub fn add(&mut self, name: String, value: T) -> bool {
if self.local_contains(&name) {
false
} else {
Rc::get_mut(&mut self.inner)
.expect("cannot edit a parent scope")
.local
.insert(name, value);
true
}
}
pub fn local_get_mut(&mut self, name: &str) -> Option<&mut T> {
Rc::get_mut(&mut self.inner)
.expect("cannot edit a parent scope")
.local
.get_mut(name)
}
pub fn get(&self, name: &str) -> Option<&T> {
self.inner.get(name)
}
pub fn local_contains(&self, name: &str) -> bool {
self.inner.local.contains_key(name)
|| self.inner.transparent
&& self
.inner
.parent
.as_ref()
.expect("transparent scope must have a parent")
.local
.contains_key(name)
}
pub fn contains(&self, name: &str) -> bool {
self.inner.contains(name)
}
pub fn iter_keys(&self) -> impl Iterator<Item = &str> {
self.inner.iter_keys()
}
}
/// Kind of scope recorded in `Context::kind`. Each variant's `Display` output
/// is the lowercase word given in its `#[display(..)]` attribute.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Display)]
pub enum ScopeKind {
/// Displayed as `module`.
#[display("module")]
Module,
/// Displayed as `function`. This is the kind `Context::new` starts with.
#[display("function")]
Function,
}
/// Category of a shader resource.
///
/// NOTE(review): nothing else in this file refers to this enum; the variant
/// names presumably mirror WGSL resource kinds — confirm against its users.
#[derive(Clone, Copy, Debug)]
pub enum ResourceKind {
UniformBuffer,
StorageBuffer,
Texture,
Sampler,
}
/// State carried through evaluation of a translation unit: the source, the
/// current variable scope, externally supplied resources and overrides, and
/// the error context recorded so far.
pub struct Context<'s> {
/// Translation unit passed to `Context::new`; returned by `Context::source`.
pub(crate) source: &'s TranslationUnit,
/// Chain of nested scopes mapping names to instances.
pub(crate) scope: Scope<Instance>,
/// Bound resources keyed by `(group, binding)`; filled by `add_binding` /
/// `add_bindings`.
pub(crate) resources: HashMap<(u32, u32), RefInstance>,
/// Override values keyed by name; filled by `add_overridable` /
/// `add_overrides`.
pub(crate) overrides: HashMap<String, Instance>,
/// Current scope kind; `Context::new` starts with `ScopeKind::Function`.
pub(crate) kind: ScopeKind,
/// Current shader stage; `Context::new` starts with `ShaderStage::Const`,
/// `set_stage` replaces it.
pub(crate) stage: ShaderStage,
/// Declaration name attached to an error; only the first value set is kept.
pub(crate) err_decl: Option<String>,
/// Source span attached to an error; only the first value set is kept.
pub(crate) err_span: Option<Span>,
}
impl<'s> Context<'s> {
pub fn new(source: &'s TranslationUnit) -> Self {
Self {
source,
scope: Default::default(),
resources: Default::default(),
overrides: Default::default(),
kind: ScopeKind::Function,
stage: ShaderStage::Const,
err_span: None,
err_decl: None,
}
}
pub fn source(&self) -> &TranslationUnit {
self.source
}
fn set_err_decl_ctx(&mut self, decl: String) {
if self.err_decl.is_none() {
self.err_decl = Some(decl)
}
}
fn set_err_span_ctx(&mut self, expr: Span) {
if self.err_span.is_none() {
self.err_span = Some(expr)
}
}
pub fn err_ctx(&self) -> (Option<String>, Option<Span>) {
(self.err_decl.clone(), self.err_span)
}
pub fn set_stage(&mut self, stage: ShaderStage) {
self.stage = stage;
}
pub fn add_bindings(&mut self, resources: impl IntoIterator<Item = ((u32, u32), RefInstance)>) {
for ((group, binding), inst) in resources.into_iter() {
self.add_binding(group, binding, inst);
}
}
pub fn add_binding(&mut self, group: u32, binding: u32, inst: RefInstance) {
self.resources.insert((group, binding), inst);
}
pub fn resource(&self, group: u32, binding: u32) -> Option<&RefInstance> {
self.resources.get(&(group, binding))
}
pub fn add_overrides(&mut self, overrides: impl IntoIterator<Item = (String, Instance)>) {
self.overrides.extend(overrides);
}
pub fn add_overridable(&mut self, name: String, inst: Instance) {
self.overrides.insert(name, inst);
}
pub fn overridable(&self, name: &str) -> Option<&Instance> {
self.overrides.get(name)
}
}
/// Name-based lookup helpers over global declarations. The notes below
/// describe the implementation for `TranslationUnit` in this file.
pub trait SyntaxUtil {
/// First declaration, type alias, struct or function named `name` among the
/// unit's own global declarations.
fn user_decl(&self, name: &str) -> Option<&GlobalDeclaration>;
/// Like `user_decl`, but continues into the `PRELUDE` declarations when the
/// unit itself has no match.
fn decl(&self, name: &str) -> Option<&GlobalDeclaration>;
/// Result of `decl` if it is a `GlobalDeclaration::Declaration`, else `None`.
fn decl_decl(&self, name: &str) -> Option<&Declaration>;
/// Result of `decl` if it is a `GlobalDeclaration::TypeAlias`, else `None`.
fn decl_alias(&self, name: &str) -> Option<&TypeAlias>;
/// Result of `decl` if it is a `GlobalDeclaration::Struct`, else `None`.
fn decl_struct(&self, name: &str) -> Option<&Struct>;
/// Result of `decl` if it is a `GlobalDeclaration::Function`, else `None`.
fn decl_function(&self, name: &str) -> Option<&Function>;
/// Follows `name` through chained type aliases and returns the aliased type
/// of the last alias reached. The walk stops early at an alias whose type has
/// template arguments. Returns `None` when `name` is not a type alias.
fn resolve_alias(&self, name: &str) -> Option<&TypeExpression>;
/// Resolves `ty` through `resolve_alias` using its identifier name, falling
/// back to `ty` itself when that name is not an alias.
fn resolve_ty<'a>(&'a self, ty: &'a TypeExpression) -> &'a TypeExpression;
}
/// Returns `true` when `decl` is a declaration, type alias, struct or function
/// whose identifier is `name`. Every other kind of global declaration never
/// matches.
fn global_decl_has_name(decl: &GlobalDeclaration, name: &str) -> bool {
    match decl {
        GlobalDeclaration::Declaration(d) => *d.ident.name() == name,
        GlobalDeclaration::TypeAlias(a) => *a.ident.name() == name,
        GlobalDeclaration::Struct(s) => *s.ident.name() == name,
        GlobalDeclaration::Function(f) => *f.ident.name() == name,
        _ => false,
    }
}

impl SyntaxUtil for TranslationUnit {
    /// First matching global declaration written in this unit; the prelude is
    /// not consulted.
    fn user_decl(&self, name: &str) -> Option<&GlobalDeclaration> {
        self.global_declarations
            .iter()
            .map(Spanned::node)
            .find(|decl| global_decl_has_name(decl, name))
    }

    /// First matching global declaration, searching this unit before the
    /// `PRELUDE` declarations.
    fn decl(&self, name: &str) -> Option<&GlobalDeclaration> {
        self.global_declarations
            .iter()
            .chain(PRELUDE.global_declarations.iter())
            .map(Spanned::node)
            .find(|decl| global_decl_has_name(decl, name))
    }

    /// The declaration named `name`, if `decl` finds a
    /// `GlobalDeclaration::Declaration`.
    fn decl_decl(&self, name: &str) -> Option<&Declaration> {
        match self.decl(name)? {
            GlobalDeclaration::Declaration(decl) => Some(decl),
            _ => None,
        }
    }

    /// The type alias named `name`, if `decl` finds a
    /// `GlobalDeclaration::TypeAlias`.
    fn decl_alias(&self, name: &str) -> Option<&TypeAlias> {
        match self.decl(name)? {
            GlobalDeclaration::TypeAlias(alias) => Some(alias),
            _ => None,
        }
    }

    /// The struct named `name`, if `decl` finds a `GlobalDeclaration::Struct`.
    fn decl_struct(&self, name: &str) -> Option<&Struct> {
        match self.decl(name)? {
            GlobalDeclaration::Struct(strukt) => Some(strukt),
            _ => None,
        }
    }

    /// The function named `name`, if `decl` finds a
    /// `GlobalDeclaration::Function`.
    fn decl_function(&self, name: &str) -> Option<&Function> {
        match self.decl(name)? {
            GlobalDeclaration::Function(func) => Some(func),
            _ => None,
        }
    }

    /// Follows `name` through chained type aliases and returns the aliased
    /// type of the last alias reached, or `None` when `name` is not an alias.
    ///
    /// The walk stops at the first alias whose type carries template
    /// arguments, or whose target name is not itself an alias.
    ///
    /// The walk is iterative and bounded. A chain without a cycle can visit
    /// each declaration at most once, so after as many hops as there are
    /// global declarations the aliases must be cyclic. In that case the alias
    /// reached so far is returned instead of recursing until the stack
    /// overflows.
    fn resolve_alias(&self, name: &str) -> Option<&TypeExpression> {
        let mut alias = self.decl_alias(name)?;
        let max_hops = self.global_declarations.len() + PRELUDE.global_declarations.len();
        for _ in 0..max_hops {
            if alias.ty.template_args.is_some() {
                break;
            }
            // The temporary name guard is released at the end of this
            // statement; `next` borrows only from the declarations.
            let next = self.decl_alias(&alias.ty.ident.name());
            match next {
                Some(target) => alias = target,
                None => break,
            }
        }
        Some(&alias.ty)
    }

    /// Resolves `ty` through `resolve_alias` by its identifier name; returns
    /// `ty` unchanged when that name is not an alias.
    fn resolve_ty<'a>(&'a self, ty: &'a TypeExpression) -> &'a TypeExpression {
        self.resolve_alias(&ty.ident.name()).unwrap_or(ty)
    }
}