#![cfg_attr(docsrs, feature(doc_cfg))]
#![doc = include_str!("../README.md")]
#[cfg(feature = "eval")]
pub mod eval;
#[cfg(feature = "generics")]
mod generics;
#[cfg(feature = "package")]
mod package;
mod condcomp;
mod error;
mod idents;
mod import;
mod lower;
mod mangle;
mod resolve;
mod sourcemap;
mod strip;
mod syntax_util;
mod validate;
mod visit;
#[cfg(feature = "eval")]
pub use eval::{Eval, EvalError, Exec, Inputs, exec_entrypoint};
#[cfg(feature = "generics")]
pub use generics::GenericsError;
#[cfg(feature = "package")]
pub use package::{Module, Pkg, PkgBuilder};
pub use condcomp::{CondCompError, Feature, Features};
pub use error::{Diagnostic, Error};
pub use import::ImportError;
pub use lower::lower;
pub use mangle::{CacheMangler, EscapeMangler, HashMangler, Mangler, NoMangler, UnicodeMangler};
pub use resolve::{
CodegenModule, CodegenPkg, FileResolver, NoResolver, PkgResolver, Preprocessor, ResolveError,
Resolver, Router, StandardResolver, VirtualResolver, emit_rerun_if_changed,
};
pub use sourcemap::{BasicSourceMap, NoSourceMap, SourceMap, SourceMapper};
pub use syntax_util::SyntaxUtil;
pub use validate::{ValidateError, validate_wesl, validate_wgsl};
pub use wesl_macros::*;
pub use wgsl_parse::syntax;
pub use wgsl_parse::syntax::ModulePath;
#[cfg(feature = "eval")]
use std::collections::HashMap;
use std::{collections::HashSet, fmt::Display, path::Path};
use strip::strip_except;
use wgsl_parse::{
SyntaxNode,
syntax::{Ident, TranslationUnit},
};
/// Options controlling which passes [`Wesl::compile`] runs.
///
/// The [`Default`] options enable imports, conditional compilation, stripping, validation and
/// lazy import resolution; generics, lowering, `mangle_root` and `keep_root` are off.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CompileOptions {
/// Resolve imports. When `false`, only the root module is compiled.
pub imports: bool,
/// Wrap the resolver so that every module it resolves goes through conditional compilation
/// with [`Self::features`].
pub condcomp: bool,
/// Run the generics passes. Ignored unless the crate is built with the `generics` feature.
pub generics: bool,
/// Strip declarations from the output except those kept (see [`Self::keep`] and
/// [`Self::keep_root`]).
pub strip: bool,
/// Run the [`lower`] pass on the assembled module.
pub lower: bool,
/// Validate every WESL module before assembly and the assembled WGSL afterwards.
pub validate: bool,
/// Together with [`Self::strip`], resolve imports lazily starting from the kept declarations
/// instead of eagerly resolving every import.
pub lazy: bool,
/// Forwarded to the mangling step. NOTE(review): presumably controls whether the root
/// module's declarations are mangled too; confirm in `import::Resolutions::mangle`.
pub mangle_root: bool,
/// Names of root-module declarations to keep when stripping. `None` keeps the root module's
/// entry points.
pub keep: Option<Vec<String>>,
/// When stripping, keep every root-module declaration (takes precedence over [`Self::keep`]).
pub keep_root: bool,
/// Feature flags, and the fallback for unset flags, used by conditional compilation.
pub features: Features,
}
impl Default for CompileOptions {
fn default() -> Self {
Self {
imports: true,
condcomp: true,
generics: false,
strip: true,
lower: false,
validate: true,
lazy: true,
mangle_root: false,
keep: Default::default(),
keep_root: false,
features: Default::default(),
}
}
}
/// Built-in name manglers selectable through [`Wesl::set_mangler`].
///
/// Each variant maps to the mangler type of the same name: [`EscapeMangler`] (the default),
/// [`HashMangler`], [`UnicodeMangler`] and [`NoMangler`].
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
pub enum ManglerKind {
/// Use [`EscapeMangler`].
#[default]
Escape,
/// Use [`HashMangler`].
Hash,
/// Use [`UnicodeMangler`].
Unicode,
/// Use [`NoMangler`].
None,
}
/// Include a WGSL shader produced in a build script, as a `&'static str`.
///
/// Expands to `include_str!("$OUT_DIR/<root>.wgsl")`. `root` must be the same string literal
/// that was passed as `artifact_name` to [`Wesl::build_artifact`] or
/// [`CompileResult::write_artifact`], which write to that path.
#[macro_export]
macro_rules! include_wesl {
($root:literal) => {
include_str!(concat!(env!("OUT_DIR"), "/", $root, ".wgsl"))
};
}
/// Declare a module `$pkg_name` whose body is a Rust file from `OUT_DIR` (build-script output).
///
/// The file name defaults to `<pkg_name>.rs`; pass a second argument to use another file.
/// Attributes and visibility given before the name are applied to the generated `mod`.
/// Inside the module, [`CodegenModule`] and [`CodegenPkg`] are in scope and `non_snake_case`
/// is allowed.
#[macro_export]
macro_rules! wesl_pkg {
// Default form: forward with `<pkg_name>.rs` as the file name.
($(#[$attr:meta])* $vis:vis $pkg_name:ident) => {
$crate::wesl_pkg!($(#[$attr])* $vis $pkg_name, concat!(stringify!($pkg_name), ".rs"));
};
($(#[$attr:meta])* $vis:vis $pkg_name:ident, $source:expr) => {
$(#[$attr])* $vis mod $pkg_name {
#![allow(non_snake_case)]
use $crate::{CodegenModule, CodegenPkg};
include!(concat!(env!("OUT_DIR"), "/", $source));
}
};
}
/// The WESL compiler: compile options, a module resolver and a name mangler.
///
/// Configure it with the builder-style `&mut self` setters, then call [`Wesl::compile`] or
/// [`Wesl::build_artifact`].
pub struct Wesl<R: Resolver> {
/// Which passes to run and their settings.
options: CompileOptions,
/// When `true`, [`Wesl::compile`] builds a source map and uses it to map and unmangle
/// diagnostics.
use_sourcemap: bool,
/// Locates and loads modules by path.
resolver: R,
/// Renames declarations when modules are assembled. It is `Send + Sync` so that `Wesl` is too.
mangler: Box<dyn Mangler + Send + Sync + 'static>,
}
impl Wesl<StandardResolver> {
    /// Create a compiler that resolves modules with a [`StandardResolver`] rooted at `base`.
    ///
    /// It uses the default [`CompileOptions`], the [`EscapeMangler`], and source maps.
    pub fn new(base: impl AsRef<Path>) -> Self {
        let resolver = StandardResolver::new(base);
        Self {
            resolver,
            mangler: Box::new(EscapeMangler),
            use_sourcemap: true,
            options: CompileOptions::default(),
        }
    }

    /// Like [`Wesl::new`], but with the experimental generics and lowering passes enabled.
    pub fn new_experimental(base: impl AsRef<Path>) -> Self {
        let mut compiler = Self::new(base);
        compiler.options.generics = true;
        compiler.options.lower = true;
        compiler
    }

    /// Make a code-generated package available to imports.
    pub fn add_package(&mut self, pkg: &'static CodegenPkg) -> &mut Self {
        self.resolver.add_package(pkg);
        self
    }

    /// Register several code-generated packages; see [`Wesl::add_package`].
    pub fn add_packages(
        &mut self,
        pkgs: impl IntoIterator<Item = &'static CodegenPkg>,
    ) -> &mut Self {
        for package in pkgs {
            self.add_package(package);
        }
        self
    }

    /// Register a named numeric constant with the resolver.
    pub fn add_constant(&mut self, name: impl ToString, value: f64) -> &mut Self {
        self.resolver.add_constant(name, value);
        self
    }

    /// Register several constants; see [`Wesl::add_constant`].
    pub fn add_constants(
        &mut self,
        constants: impl IntoIterator<Item = (impl ToString, f64)>,
    ) -> &mut Self {
        for (constant_name, constant_value) in constants {
            self.add_constant(constant_name, constant_value);
        }
        self
    }
}
impl Wesl<NoResolver> {
    /// Create a compiler that does nothing beyond parsing and re-printing.
    ///
    /// Every pass is disabled, imports cannot be resolved ([`NoResolver`]), names are not
    /// mangled ([`NoMangler`]), and no source map is built.
    pub fn new_barebones() -> Self {
        let options = CompileOptions {
            features: Default::default(),
            keep_root: false,
            keep: None,
            mangle_root: false,
            lazy: false,
            validate: false,
            lower: false,
            strip: false,
            generics: false,
            condcomp: false,
            imports: false,
        };
        Self {
            resolver: NoResolver,
            mangler: Box::new(NoMangler),
            use_sourcemap: false,
            options,
        }
    }
}
impl<R: Resolver> Wesl<R> {
    /// Replace all compile options at once.
    pub fn set_options(&mut self, options: CompileOptions) -> &mut Self {
        self.options = options;
        self
    }

    /// Select one of the built-in manglers.
    pub fn set_mangler(&mut self, kind: ManglerKind) -> &mut Self {
        match kind {
            ManglerKind::Escape => self.set_custom_mangler(EscapeMangler),
            ManglerKind::Hash => self.set_custom_mangler(HashMangler),
            ManglerKind::Unicode => self.set_custom_mangler(UnicodeMangler),
            ManglerKind::None => self.set_custom_mangler(NoMangler),
        }
    }

    /// Use a user-provided mangler.
    pub fn set_custom_mangler(
        &mut self,
        mangler: impl Mangler + Send + Sync + 'static,
    ) -> &mut Self {
        let boxed: Box<dyn Mangler + Send + Sync + 'static> = Box::new(mangler);
        self.mangler = boxed;
        self
    }

    /// Swap the resolver for `resolver`, keeping the options, source-map flag and mangler.
    pub fn set_custom_resolver<CustomResolver: Resolver>(
        self,
        resolver: CustomResolver,
    ) -> Wesl<CustomResolver> {
        let Wesl {
            options,
            use_sourcemap,
            mangler,
            ..
        } = self;
        Wesl {
            options,
            use_sourcemap,
            resolver,
            mangler,
        }
    }

    /// Enable or disable source maps, which map and unmangle diagnostics.
    pub fn use_sourcemap(&mut self, val: bool) -> &mut Self {
        self.use_sourcemap = val;
        self
    }

    /// Enable or disable import resolution (see [`CompileOptions::imports`]).
    pub fn use_imports(&mut self, val: bool) -> &mut Self {
        self.options.imports = val;
        self
    }

    /// Enable or disable conditional compilation (see [`CompileOptions::condcomp`]).
    pub fn use_condcomp(&mut self, val: bool) -> &mut Self {
        self.options.condcomp = val;
        self
    }

    /// Enable or disable the generics passes (see [`CompileOptions::generics`]).
    #[cfg(feature = "generics")]
    pub fn use_generics(&mut self, val: bool) -> &mut Self {
        self.options.generics = val;
        self
    }

    /// Set one conditional-compilation feature flag.
    pub fn set_feature(&mut self, feat: &str, val: impl Into<Feature>) -> &mut Self {
        let key = feat.to_string();
        let value = val.into();
        self.options.features.flags.insert(key, value);
        self
    }

    /// Set several feature flags; later entries override earlier ones with the same name.
    pub fn set_features(
        &mut self,
        feats: impl IntoIterator<Item = (impl ToString, impl Into<Feature>)>,
    ) -> &mut Self {
        let flags = &mut self.options.features.flags;
        for (name, value) in feats {
            flags.insert(name.to_string(), value.into());
        }
        self
    }

    /// Remove a feature flag so that it falls back to the missing-feature behavior.
    pub fn unset_feature(&mut self, feat: &str) -> &mut Self {
        let _ = self.options.features.flags.remove(feat);
        self
    }

    /// Set how feature flags that were never set are treated.
    pub fn set_missing_feature_behavior(&mut self, val: impl Into<Feature>) -> &mut Self {
        let fallback = val.into();
        self.options.features.default = fallback;
        self
    }

    /// Enable or disable stripping (see [`CompileOptions::strip`]).
    pub fn use_stripping(&mut self, val: bool) -> &mut Self {
        self.options.strip = val;
        self
    }

    /// Enable or disable the lowering pass (see [`CompileOptions::lower`]).
    pub fn use_lower(&mut self, val: bool) -> &mut Self {
        self.options.lower = val;
        self
    }

    /// When stripping, keep only these root-module declarations.
    pub fn keep_declarations(&mut self, keep: Vec<String>) -> &mut Self {
        self.options.keep = Some(keep);
        self
    }

    /// When stripping, keep the root module's entry points (the default).
    pub fn keep_all_entrypoints(&mut self) -> &mut Self {
        self.options.keep = None;
        self
    }

    /// Borrow the resolver.
    pub fn resolver(&self) -> &R {
        let Wesl { resolver, .. } = self;
        resolver
    }
}
/// Output of a successful compilation.
#[derive(Clone, Default)]
pub struct CompileResult {
/// The assembled (and, depending on options, stripped/lowered) WGSL.
pub syntax: TranslationUnit,
/// The source map, present when compiled with source maps enabled.
pub sourcemap: Option<BasicSourceMap>,
/// Paths of the modules that were part of the compilation, in resolution order. They are
/// used by [`Wesl::build_artifact`] to emit `rerun-if-changed` directives.
pub modules: Vec<ModulePath>,
}
impl CompileResult {
    /// Write the compiled WGSL to `path`, creating or truncating the file.
    ///
    /// # Errors
    /// Returns any I/O error reported by [`std::fs::write`].
    pub fn write_to_file(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
        std::fs::write(path, self.to_string())
    }

    /// Write the compiled WGSL to `$OUT_DIR/<artifact_name>.wgsl`. This is meant for build
    /// scripts; the file can then be embedded with [`include_wesl!`].
    ///
    /// # Panics
    /// Panics in any of these cases:
    /// - `OUT_DIR` is not set.
    /// - `artifact_name` is not a plain file name (empty, `.`/`..`, contains a path separator
    ///   or has an extension).
    /// - Writing the file fails.
    pub fn write_artifact(&self, artifact_name: &str) {
        let dirname = std::env::var("OUT_DIR")
            .expect("`OUT_DIR` is not set: `write_artifact` must be called from a build script");
        let out_name = Path::new(artifact_name);
        // `file_name()` is `None` for "", ".", ".." and roots, and differs from the whole path
        // whenever it contains a separator (including a trailing one). Only a plain single
        // name passes. The previous component-count check let "." through, and
        // `set_extension` turned `OUT_DIR/.` into `<OUT_DIR>.wgsl`, outside `OUT_DIR`.
        let is_plain_name = out_name.file_name() == Some(out_name.as_os_str());
        if !is_plain_name || out_name.extension().is_some() {
            panic!(
                "`artifact_name` cannot contain path separators or file extension \
                 (got `{artifact_name}`)"
            );
        }
        let mut output = Path::new(&dirname).join(out_name);
        output.set_extension("wgsl");
        self.write_to_file(output)
            .expect("failed to write output shader");
    }
}
impl Display for CompileResult {
    /// Print the compiled WGSL source. Formatter flags are forwarded unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.syntax, f)
    }
}
/// Result of executing an entry point with [`CompileResult::exec`].
#[cfg(feature = "eval")]
pub struct ExecResult<'a> {
/// The entry point's return value, if any.
pub inst: Option<eval::Instance>,
/// The evaluation context after execution; holds the bound resources.
pub ctx: eval::Context<'a>,
}
#[cfg(feature = "eval")]
impl ExecResult<'_> {
    /// The value returned by the entry point, if any.
    pub fn return_value(&self) -> Option<&eval::Instance> {
        let ExecResult { inst, .. } = self;
        inst.as_ref()
    }

    /// Look up the resource bound at `@group(group) @binding(binding)` after execution.
    pub fn resource(&self, group: u32, binding: u32) -> Option<&eval::RefInstance> {
        let ExecResult { ctx, .. } = self;
        ctx.resource(group, binding)
    }
}
/// Result of evaluating an expression with [`CompileResult::eval`].
#[cfg(feature = "eval")]
pub struct EvalResult<'a> {
/// The value of the expression.
pub inst: eval::Instance,
/// The evaluation context the expression was evaluated in.
pub ctx: eval::Context<'a>,
}
#[cfg(feature = "eval")]
impl EvalResult<'_> {
    /// Serialize the evaluated value into bytes. Returns `None` when
    /// `Instance::to_buffer` cannot represent it.
    pub fn to_buffer(&mut self) -> Option<Vec<u8>> {
        let value = &mut self.inst;
        value.to_buffer()
    }
}
#[cfg(feature = "eval")]
impl Display for EvalResult<'_> {
    /// Print the evaluated value. Formatter flags are forwarded unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.inst, f)
    }
}
#[cfg(feature = "eval")]
impl CompileResult {
    /// Evaluate the WGSL expression `source` against the compiled module.
    ///
    /// Module-scope declarations are executed first, then the expression.
    ///
    /// # Errors
    /// Parse and evaluation errors are returned as [`Error::Error`] diagnostics that quote
    /// `source`. When a source map is present, evaluation errors are mapped through it.
    pub fn eval<'a>(&'a self, source: &str) -> Result<EvalResult<'a>, Error> {
        let expr = source
            .parse::<syntax::Expression>()
            .map_err(|e| Error::Error(Diagnostic::from(e).with_source(source.to_string())))?;
        let (inst, ctx) = eval(&expr, &self.syntax);
        let inst = inst.map_err(|e| {
            Diagnostic::from(e)
                .with_source(source.to_string())
                .with_ctx(&ctx)
        });
        let inst = if let Some(sourcemap) = &self.sourcemap {
            inst.map_err(|e| Error::Error(e.with_sourcemap(sourcemap)))
        } else {
            inst.map_err(Error::Error)
        }?;
        let res = EvalResult { inst, ctx };
        Ok(res)
    }

    /// Execute the function `entrypoint` in the `Exec` shader stage.
    ///
    /// `bindings` are indexed by `(group, binding)`. `overrides` are keyed by override name.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - No function named `entrypoint` exists.
    /// - Executing the module-scope declarations fails.
    /// - Executing the entry point fails.
    ///
    /// Errors from the last two cases carry the evaluation context and, when a source map is
    /// present, are mapped through it.
    pub fn exec<'a>(
        &'a self,
        entrypoint: &str,
        inputs: Inputs,
        bindings: HashMap<(u32, u32), eval::RefInstance>,
        overrides: HashMap<String, eval::Instance>,
    ) -> Result<ExecResult<'a>, Error> {
        let mut ctx = eval::Context::new(&self.syntax);
        ctx.add_bindings(bindings);
        ctx.add_overrides(overrides);
        ctx.set_stage(eval::ShaderStage::Exec);
        let entry_fn = eval::SyntaxUtil::decl_function(ctx.source, entrypoint)
            .ok_or_else(|| EvalError::UnknownFunction(entrypoint.to_string()))?;
        // Module-scope errors used to be propagated raw. Give them the same context and
        // source mapping as entry-point errors, so both kinds of failure point at the WESL
        // source.
        let _ = self.syntax.exec(&mut ctx).map_err(|e| {
            if let Some(sourcemap) = &self.sourcemap {
                Diagnostic::from(e).with_ctx(&ctx).with_sourcemap(sourcemap)
            } else {
                Diagnostic::from(e).with_ctx(&ctx)
            }
        })?;
        let inst = exec_entrypoint(entry_fn, inputs, &mut ctx).map_err(|e| {
            if let Some(sourcemap) = &self.sourcemap {
                Diagnostic::from(e).with_ctx(&ctx).with_sourcemap(sourcemap)
            } else {
                Diagnostic::from(e).with_ctx(&ctx)
            }
        })?;
        Ok(ExecResult { inst, ctx })
    }
}
impl<R: Resolver> Wesl<R> {
    /// Compile the module at `root` and everything it imports into a single WGSL module.
    ///
    /// When source maps are enabled, the result carries a source map, and diagnostics are
    /// mapped through it and unmangled.
    ///
    /// # Errors
    /// Returns resolution, import, validation or pass errors.
    pub fn compile(&self, root: &ModulePath) -> Result<CompileResult, Error> {
        if self.use_sourcemap {
            compile_sourcemap(root, &self.resolver, &self.mangler, &self.options)
        } else {
            compile(root, &self.resolver, &self.mangler, &self.options)
        }
    }

    /// Build-script helper that does three things:
    /// 1. Compiles `root`.
    /// 2. Emits `rerun-if-changed` directives for the modules involved.
    /// 3. Writes the shader to `$OUT_DIR/<artifact_name>.wgsl`, for [`include_wesl!`].
    ///
    /// # Panics
    /// Panics with the compiler diagnostic if compilation fails, and in every case where
    /// [`CompileResult::write_artifact`] panics.
    pub fn build_artifact(&self, root: &ModulePath, artifact_name: &str) {
        // Panic with the diagnostic itself. The old `eprintln!` + bare `panic!()` printed it,
        // then printed a second, uninformative "explicit panic".
        let compiled = self
            .compile(root)
            .unwrap_or_else(|e| panic!("failed to build WESL shader `{root}`.\n{e}"));
        emit_rerun_if_changed(&compiled.modules, &self.resolver);
        compiled.write_artifact(artifact_name);
    }
}
/// Collect the root-module declarations that stripping must preserve.
///
/// - When not stripping, or when `keep_root` is set, this is every named global declaration.
/// - Otherwise it is the declarations named in `keep`.
/// - When `keep` is `None`, it is the module's entry points.
///
/// Names in `keep` that match no declaration are ignored.
fn keep_idents(
    wesl: &TranslationUnit,
    keep: &Option<Vec<String>>,
    keep_root: bool,
    strip: bool,
) -> HashSet<Ident> {
    let named_decls = || wesl.global_declarations.iter().filter_map(|decl| decl.ident());

    if !strip || keep_root {
        return named_decls().collect();
    }

    match keep {
        Some(wanted) => named_decls()
            .filter(|ident| wanted.iter().any(|name| name == &*ident.name()))
            .collect(),
        None => wesl.entry_points().cloned().collect(),
    }
}
/// Resolve the root module and its imports, and optionally validate each module.
///
/// Returns the import resolutions together with the set of root declarations to keep when
/// stripping.
fn compile_pre_assembly(
    root: &ModulePath,
    resolver: &impl Resolver,
    opts: &CompileOptions,
) -> Result<(import::Resolutions, HashSet<Ident>), Error> {
    // With condcomp enabled, wrap the resolver in a `Preprocessor` that runs
    // `condcomp::run` with the configured features on each module it yields.
    let resolver: Box<dyn Resolver> = if opts.condcomp {
        Box::new(Preprocessor::new(resolver, |wesl| {
            condcomp::run(wesl, &opts.features)?;
            Ok(())
        }))
    } else {
        Box::new(resolver)
    };

    let root_module = resolver.resolve_module(root)?;
    let keep = keep_idents(&root_module, &opts.keep, opts.keep_root, opts.strip);

    let resolutions = match (opts.imports, opts.strip && opts.lazy) {
        (false, _) => import::Resolutions::new(root_module, root.clone()),
        (true, true) => import::resolve_lazy(&keep, root_module, root, &resolver)?,
        (true, false) => import::resolve_eager(root_module, root, &resolver)?,
    };

    if opts.validate {
        for module in resolutions.modules() {
            let module = module.borrow();
            validate_wesl(&module.source).map_err(|d| {
                d.with_module_path(module.path.clone(), resolver.display_name(&module.path))
            })?;
        }
    }

    Ok((resolutions, keep))
}
fn compile_post_assembly(
wesl: &mut TranslationUnit,
options: &CompileOptions,
keep: &HashSet<Ident>,
) -> Result<(), Error> {
#[cfg(feature = "generics")]
if options.generics {
generics::generate_variants(wesl)?;
generics::replace_calls(wesl)?;
};
if options.validate {
validate_wgsl(wesl)?;
}
if options.lower {
lower(wesl)?;
}
if options.strip {
strip_except(wesl, keep);
}
Ok(())
}
/// Compile `root` without a source map.
///
/// The steps are: resolve and validate modules, mangle, assemble into one module, then run
/// the post-assembly passes.
///
/// # Errors
/// Returns the first error from any stage.
pub fn compile(
    root: &ModulePath,
    resolver: &impl Resolver,
    mangler: &impl Mangler,
    options: &CompileOptions,
) -> Result<CompileResult, Error> {
    let (mut resolutions, keep) = compile_pre_assembly(root, resolver, options)?;
    resolutions.mangle(mangler, options.mangle_root);

    let strip_lazily = options.strip && options.lazy;
    let mut syntax = resolutions.assemble(strip_lazily);
    let modules = resolutions.into_module_order();

    compile_post_assembly(&mut syntax, options, &keep)?;

    Ok(CompileResult {
        sourcemap: None,
        syntax,
        modules,
    })
}
/// Compile `root` like [`compile`], additionally building a source map.
///
/// # Errors
/// Every error is mapped through the source map and unmangled. Post-assembly errors also
/// carry the assembled output they refer to.
pub fn compile_sourcemap(
    root: &ModulePath,
    resolver: &impl Resolver,
    mangler: &impl Mangler,
    options: &CompileOptions,
) -> Result<CompileResult, Error> {
    let sourcemapper = SourceMapper::new(root, resolver, mangler);

    let (mut resolutions, keep) = match compile_pre_assembly(root, &sourcemapper, options) {
        Ok(pre_assembly) => pre_assembly,
        Err(e) => {
            // Finish the map anyway so the diagnostic can still be mapped and unmangled.
            let sourcemap = sourcemapper.finish();
            return Err(Diagnostic::from(e)
                .with_sourcemap(&sourcemap)
                .unmangle(Some(&sourcemap), Some(&mangler))
                .into());
        }
    };

    // The source mapper wraps `mangler`, so mangling goes through it rather than through
    // `mangler` directly.
    resolutions.mangle(&sourcemapper, options.mangle_root);
    let sourcemap = sourcemapper.finish();
    let mut syntax = resolutions.assemble(options.strip && options.lazy);
    let modules = resolutions.into_module_order();

    if let Err(e) = compile_post_assembly(&mut syntax, options, &keep) {
        return Err(Diagnostic::from(e)
            .with_output(syntax.to_string())
            .with_sourcemap(&sourcemap)
            .unmangle(Some(&sourcemap), Some(&mangler))
            .into());
    }

    Ok(CompileResult {
        syntax,
        sourcemap: Some(sourcemap),
        modules,
    })
}
/// Evaluate a standalone WGSL expression (no surrounding module).
///
/// # Errors
/// Parse and evaluation errors are returned as [`Error::Error`] diagnostics that quote `expr`.
#[cfg(feature = "eval")]
pub fn eval_str(expr: &str) -> Result<eval::Instance, Error> {
    // Bind the parsed expression to a new name. The previous `let expr = …` shadowed the
    // source string, so the evaluation-error diagnostic below quoted the re-printed
    // `Expression` instead of the caller's text (unlike `CompileResult::eval`).
    let parsed = expr
        .parse::<syntax::Expression>()
        .map_err(|e| Error::Error(Diagnostic::from(e).with_source(expr.to_string())))?;
    let wgsl = TranslationUnit::default();
    let (inst, ctx) = eval(&parsed, &wgsl);
    inst.map_err(|e| {
        Error::Error(
            Diagnostic::from(e)
                .with_source(expr.to_string())
                .with_ctx(&ctx),
        )
    })
}
/// Evaluate `expr` in the context of `wgsl`.
///
/// The module's declarations are executed first. The context is returned alongside the
/// result so that callers can build diagnostics from it.
#[cfg(feature = "eval")]
pub fn eval<'s>(
    expr: &syntax::Expression,
    wgsl: &'s TranslationUnit,
) -> (Result<eval::Instance, EvalError>, eval::Context<'s>) {
    let mut ctx = eval::Context::new(wgsl);
    let outcome = match wgsl.exec(&mut ctx) {
        Ok(_) => expr.eval(&mut ctx),
        Err(err) => Err(err),
    };
    (outcome, ctx)
}
/// Execute `expr` in the `Exec` shader stage of `wgsl`, with the given bindings and overrides.
///
/// The module's declarations are executed first. A `Void` evaluation error from `expr` is
/// reported as `Ok(None)`. The context is returned alongside the result.
#[cfg(feature = "eval")]
pub fn exec<'s>(
    expr: &impl Eval,
    wgsl: &'s TranslationUnit,
    bindings: HashMap<(u32, u32), eval::RefInstance>,
    overrides: HashMap<String, eval::Instance>,
) -> (Result<Option<eval::Instance>, EvalError>, eval::Context<'s>) {
    let mut ctx = eval::Context::new(wgsl);
    ctx.add_bindings(bindings);
    ctx.add_overrides(overrides);
    ctx.set_stage(eval::ShaderStage::Exec);

    let outcome = match wgsl.exec(&mut ctx) {
        Err(err) => Err(err),
        Ok(_) => match expr.eval(&mut ctx) {
            Err(eval::EvalError::Void(_)) => Ok(None),
            Err(err) => Err(err),
            Ok(value) => Ok(Some(value)),
        },
    };
    (outcome, ctx)
}
#[test]
fn test_send_sync() {
    // Compile-time check: this test only builds if `Wesl<StandardResolver>` is `Send + Sync`.
    fn require_send_and_sync<T: Send + Sync>() {}
    require_send_and_sync::<Wesl<StandardResolver>>();
}