use clap::{Args, Parser, Subcommand, ValueEnum};
use std::{
convert::Infallible,
error::Error,
fs::{self, File},
io::{IsTerminal, Read, Write},
path::PathBuf,
str::FromStr,
};
use wesl::{
CompileOptions, CompileResult, Diagnostic, Feature, Features, Inputs, ManglerKind, ModulePath,
PkgBuilder, Router, StandardResolver, SyntaxUtil, VirtualResolver, Wesl,
eval::{Eval, EvalAttrs, Instance, RefInstance, Ty, ty_eval_ty},
syntax::{self, AccessMode, AddressSpace, PathOrigin, TranslationUnit},
};
/// Parses a `KEY=VALUE` command-line argument into a typed pair.
///
/// The string is split at the first `=`: everything before it is parsed as `T`, everything
/// after it (which may contain further `=`) as `U`. Without any `=`, the whole string is parsed
/// as `T` and the value is `U::default()`.
///
/// # Errors
/// Returns the key's parse error, or else the value's parse error (the key is parsed first).
fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn Error + Send + Sync + 'static>>
where
    T: FromStr,
    T::Err: Error + Send + Sync + 'static,
    U: FromStr + Default,
    U::Err: Error + Send + Sync + 'static,
{
    match s.split_once('=') {
        Some((key, value)) => Ok((key.parse()?, value.parse()?)),
        None => Ok((s.parse()?, U::default())),
    }
}
// Top-level command-line interface.
//
// Plain `//` comments are used on this clap-derived type on purpose: `///` doc comments would
// be picked up by clap as `--help` text.
// `version`, `author` and `about` without values are filled in by clap from the crate's Cargo
// metadata; `propagate_version` makes `--version` available on every subcommand as well.
#[derive(Parser)]
#[command(version, author, about)]
#[command(propagate_version = true)]
struct Cli {
    // The selected subcommand, dispatched by `run`.
    #[command(subcommand)]
    command: Command,
}
// Subcommands; clap derives their command-line names from the variant names
// (`check`, `compile`, `eval`, `exec`, `package`). Each one is handled in `run`.
#[derive(Subcommand, Clone, Debug)]
enum Command {
    // Parse and validate a WGSL or WESL source, printing `OK` on success.
    Check(CheckArgs),
    // Compile a module and print the compiled result.
    Compile(CompileArgs),
    // Compile, then evaluate an expression against the result and print its value.
    Eval(EvalArgs),
    // Compile, then execute an entry point with the given resources and overrides.
    Exec(ExecArgs),
    // Scan a directory of WESL files, validate them and print the generated package code.
    Package(PkgArgs),
}
// CLI-facing mirror of `wesl::ManglerKind`, selected with `--mangler` (clap default: `escape`).
// Presumably a separate type so that clap's `ValueEnum` can be derived locally; it is converted
// into `ManglerKind` for `Wesl::set_mangler`.
#[derive(Default, Clone, Copy, Debug, ValueEnum)]
pub enum ClapManglerKind {
    // Maps to `ManglerKind::Escape`; also the `Default`.
    #[default]
    Escape,
    // Maps to `ManglerKind::Hash`.
    Hash,
    // Maps to `ManglerKind::Unicode`.
    Unicode,
    // Maps to `ManglerKind::None`.
    None,
}
impl From<ClapManglerKind> for ManglerKind {
fn from(value: ClapManglerKind) -> Self {
match value {
ClapManglerKind::Escape => Self::Escape,
ClapManglerKind::Hash => Self::Hash,
ClapManglerKind::Unicode => Self::Unicode,
ClapManglerKind::None => Self::None,
}
}
}
// CLI-facing mirror of `wesl::Feature`, used for the values of `-D NAME=VALUE` and for
// `--feature-default`. `true` / `false` are accepted as aliases of `enable` / `disable`, both by
// clap's `ValueEnum` parsing and by the `FromStr` impl used for `-D`.
#[derive(Default, Clone, Copy, Debug, ValueEnum)]
pub enum ClapFeature {
    // Maps to `Feature::Enable`. It is the `Default`, which `parse_key_val` uses for a bare
    // `-D NAME` without `=VALUE`.
    #[default]
    #[value(alias("true"))]
    Enable,
    // Maps to `Feature::Disable`.
    #[value(alias("false"))]
    Disable,
    // Maps to `Feature::Keep`.
    Keep,
    // Maps to `Feature::Error`.
    Error,
}
impl FromStr for ClapFeature {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"enable" | "true" => Ok(Self::Enable),
"disable" | "false" => Ok(Self::Disable),
"keep" => Ok(Self::Keep),
"error" => Ok(Self::Error),
_ => {
panic!("not a valid feature value, expected `enable`, `disable`, `keep` or `error`")
}
}
}
}
impl From<ClapFeature> for Feature {
fn from(value: ClapFeature) -> Self {
match value {
ClapFeature::Enable => Self::Enable,
ClapFeature::Disable => Self::Disable,
ClapFeature::Keep => Self::Keep,
ClapFeature::Error => Self::Error,
}
}
}
// Compilation flags shared by `compile`, `eval` and `exec` (flattened into their arguments).
// Converted into `wesl::CompileOptions` by the `From<&CompOptsArgs>` impl, where most `--no-*`
// flags are stored negated.
#[derive(Args, Clone, Debug)]
struct CompOptsArgs {
    // `--mangler`: name-mangling scheme passed to `Wesl::set_mangler` (default `escape`).
    #[arg(long, default_value = "escape")]
    mangler: ClapManglerKind,
    // `--no-sourcemap`: `run_compile` calls `Wesl::use_sourcemap(false)`.
    #[arg(long)]
    no_sourcemap: bool,
    // `--no-imports`: sets `CompileOptions::imports` to false.
    #[arg(long)]
    no_imports: bool,
    // `--no-cond-comp`: sets `CompileOptions::condcomp` to false.
    #[arg(long)]
    no_cond_comp: bool,
    // `--generics`: sets `CompileOptions::generics`.
    #[arg(long)]
    generics: bool,
    // `--no-strip`: sets `CompileOptions::strip` to false and discards any `--keep` list.
    #[arg(long)]
    no_strip: bool,
    // `--lower`: sets `CompileOptions::lower`.
    #[arg(long)]
    lower: bool,
    // `--no-validate`: sets `CompileOptions::validate` to false.
    #[arg(long)]
    no_validate: bool,
    // `--eager`: sets `CompileOptions::lazy` to false.
    #[arg(long)]
    eager: bool,
    // `--mangle-root`: sets `CompileOptions::mangle_root`.
    #[arg(long)]
    mangle_root: bool,
    // `--no-naga` (only with the `naga` feature): skips naga validation of `compile` output.
    #[cfg(feature = "naga")]
    #[arg(long)]
    no_naga: bool,
    // `--keep NAME` (repeatable): forwarded as `CompileOptions::keep` unless `--no-strip`.
    // NOTE(review): presumably the declarations to preserve when stripping — confirm in `wesl`.
    #[arg(long)]
    keep: Option<Vec<String>>,
    // `--keep-root`: sets `CompileOptions::keep_root`.
    #[arg(long)]
    keep_root: bool,
    // `-D NAME` / `-D NAME=VALUE` (repeatable): per-feature settings stored in `Features::flags`,
    // parsed by `parse_key_val`. A bare `NAME` gets `ClapFeature::default()` (`enable`).
    #[arg(short='D', long, value_name="NAME | NAME=[enable, disable, keep, error]", value_parser = parse_key_val::<String, ClapFeature>)]
    feature: Vec<(String, ClapFeature)>,
    // `--feature-default`: becomes `Features::default` (default `disable`); presumably applied
    // to features that are not set with `-D`.
    #[arg(long, default_value = "disable")]
    feature_default: ClapFeature,
    // `--base DIR`: directory the `StandardResolver` resolves modules from for a file input.
    // When absent, `run_compile` uses the input file's parent directory. Not used for stdin input.
    #[arg(long)]
    base: Option<PathBuf>,
}
impl From<&CompOptsArgs> for CompileOptions {
    /// Builds the `wesl` compile options from the CLI flags.
    ///
    /// Each `--no-*` flag is negated into the corresponding positive option. `--eager` is
    /// negated into `lazy`. The `--keep` list is dropped entirely when `--no-strip` is given.
    fn from(args: &CompOptsArgs) -> Self {
        let keep = (!args.no_strip).then(|| args.keep.clone()).flatten();
        let features = Features {
            default: args.feature_default.into(),
            flags: args
                .feature
                .iter()
                .map(|(name, value)| (name.clone(), (*value).into()))
                .collect(),
        };
        Self {
            features,
            keep,
            keep_root: args.keep_root,
            mangle_root: args.mangle_root,
            lazy: !args.eager,
            validate: !args.no_validate,
            lower: args.lower,
            strip: !args.no_strip,
            generics: args.generics,
            condcomp: !args.no_cond_comp,
            imports: !args.no_imports,
        }
    }
}
// Arguments of `compile`.
#[derive(Args, Clone, Debug)]
struct CompileArgs {
    // Shared compilation flags.
    #[command(flatten)]
    options: CompOptsArgs,
    // Input file; when omitted, piped stdin is used (see `file_or_source`).
    file: Option<PathBuf>,
}
// Arguments of `check`.
#[derive(Args, Clone, Debug)]
struct CheckArgs {
    // `--kind`: check the source as WGSL or as WESL (default `wesl`).
    #[arg(long, default_value = "wesl")]
    kind: CheckKind,
    // `--naga` (only with the `naga` feature): additionally validate with naga. `run` only
    // honors it for `--kind wgsl`.
    #[cfg(feature = "naga")]
    #[arg(long)]
    naga: bool,
    // Input file; when omitted, the source is read from stdin.
    file: Option<PathBuf>,
}
// Language the `check` subcommand validates against.
#[derive(ValueEnum, Clone, Debug, Default)]
enum CheckKind {
    // Parse with `wgsl_parse` and validate with `wesl::validate_wgsl`.
    Wgsl,
    // Parse as a `TranslationUnit` and validate with `wesl::validate_wesl`.
    #[default]
    Wesl,
}
/// How a `--resource` is bound, given as the third `:`-separated field of the argument.
///
/// The accepted spellings are listed in the [`FromStr`] impl. The names match WebGPU binding
/// types; only `uniform`, `storage` and `read-only-storage` are currently handled by `exec`.
#[derive(Clone, Copy, Debug)]
enum BindingType {
    Uniform,
    Storage,
    ReadOnlyStorage,
    Filtering,
    NonFiltering,
    Comparison,
    Float,
    UnfilterableFloat,
    Sint,
    Uint,
    Depth,
    WriteOnly,
    ReadWrite,
    ReadOnly,
}

impl FromStr for BindingType {
    type Err = ();

    /// Parses the kebab-case name of a binding type; any other string yields `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const NAMES: [(&str, BindingType); 14] = [
            ("uniform", BindingType::Uniform),
            ("storage", BindingType::Storage),
            ("read-only-storage", BindingType::ReadOnlyStorage),
            ("filtering", BindingType::Filtering),
            ("non-filtering", BindingType::NonFiltering),
            ("comparison", BindingType::Comparison),
            ("float", BindingType::Float),
            ("unfilterable-float", BindingType::UnfilterableFloat),
            ("sint", BindingType::Sint),
            ("uint", BindingType::Uint),
            ("depth", BindingType::Depth),
            ("write-only", BindingType::WriteOnly),
            ("read-write", BindingType::ReadWrite),
            ("read-only", BindingType::ReadOnly),
        ];
        NAMES
            .iter()
            .find(|&&(name, _)| name == s)
            .map(|&(_, kind)| kind)
            .ok_or(())
    }
}

/// A resource passed to `exec` as `--resource GROUP:BINDING:KIND:PATH`.
///
/// The contents of the file at `PATH` are read eagerly, when the argument is parsed.
#[derive(Clone, Debug)]
struct Binding {
    /// The `@group` number of the targeted resource.
    group: u32,
    /// The `@binding` number of the targeted resource.
    binding: u32,
    /// How the resource is bound.
    kind: BindingType,
    /// Raw bytes of the data file.
    data: Box<[u8]>,
}

impl Binding {
    /// Parses `GROUP:BINDING:KIND:PATH` and reads the data file at `PATH`.
    ///
    /// The string is split into at most four fields, so `PATH` may itself contain `:`
    /// (e.g. a Windows path such as `C:\data.bin`).
    ///
    /// # Errors
    /// A plain message describing the first missing or malformed field, or the I/O failure
    /// encountered while opening/reading the data file.
    fn parse_spec(s: &str) -> Result<Self, String> {
        let mut fields = s.splitn(4, ':');
        let group: u32 = fields
            .next()
            .ok_or("missing @group number")?
            .parse()
            .map_err(|e| format!("failed to parse group: {e}"))?;
        let binding: u32 = fields
            .next()
            .ok_or("missing @binding number")?
            .parse()
            .map_err(|e| format!("failed to parse binding number: {e}"))?;
        let kind: BindingType = fields
            .next()
            .ok_or("missing resource binding type")?
            .parse()
            .map_err(|()| "invalid resource binding type".to_string())?;
        let path = PathBuf::from(fields.next().ok_or("missing data")?);
        // Report I/O problems as argument errors instead of panicking on user input.
        let mut file = File::open(&path)
            .map_err(|e| format!("failed to open binding file `{}`: {e}", path.display()))?;
        let mut buf = Vec::new();
        file.read_to_end(&mut buf)
            .map_err(|e| format!("failed to read binding file `{}`: {e}", path.display()))?;
        Ok(Self {
            group,
            binding,
            kind,
            data: buf.into_boxed_slice(),
        })
    }
}

impl FromStr for Binding {
    type Err = Box<dyn Error + Send + Sync + 'static>;

    /// Parses a `--resource` argument (used as clap's value parser).
    ///
    /// # Errors
    /// Every failure is returned, prefixed with `failed to parse binding:`, so clap can report
    /// it as an invalid argument.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse_spec(s).map_err(|e| format!("failed to parse binding: {e}").into())
    }
}
// Arguments of `eval`.
#[derive(Args, Clone, Debug)]
struct EvalArgs {
    // Shared compilation flags.
    #[command(flatten)]
    options: CompOptsArgs,
    // `-b` / `--binary`: write the value's buffer representation to stdout instead of text.
    #[arg(short, long)]
    binary: bool,
    // `--file`: module to compile; when omitted, piped stdin is used (see `file_or_source`).
    // It is a flag here because the positional argument is the expression.
    #[arg(long)]
    file: Option<PathBuf>,
    // Expression evaluated against the compiled module.
    expr: String,
}
// Arguments of `exec`.
#[derive(Args, Clone, Debug)]
struct ExecArgs {
    // Shared compilation flags.
    #[command(flatten)]
    options: CompOptsArgs,
    // `--resource GROUP:BINDING:KIND:PATH` (repeatable), parsed by `Binding::from_str`.
    #[arg(long = "resource", value_parser = Binding::from_str, verbatim_doc_comment)]
    resources: Vec<Binding>,
    // `--override NAME=EXPRESSION` (repeatable); each expression is evaluated by `parse_override`.
    #[arg(long = "override", value_name="NAME=EXPRESSION", value_parser = parse_key_val::<String, String>)]
    overrides: Vec<(String, String)>,
    // `-b` / `--out-binary`: write the return value and resources as raw buffers to stdout.
    #[arg(short, long = "out-binary")]
    binary: bool,
    // `--entrypoint`: name of the function to execute (default `main`).
    #[arg(long, default_value = "main")]
    entrypoint: String,
    // Input file; when omitted, piped stdin is used (see `file_or_source`).
    file: Option<PathBuf>,
}
// Arguments of `package`.
#[derive(Args, Clone, Debug)]
struct PkgArgs {
    // Package name passed to `PkgBuilder::new`.
    name: String,
    // Root directory scanned for WESL files (`PkgBuilder::scan_root`).
    dir: PathBuf,
}
/// Errors reported by the CLI; `main` prints them to stderr.
#[derive(Clone, Debug, thiserror::Error)]
enum CliError {
    /// The input could not be located or read (also used when stdin cannot be read).
    #[error("input file not found")]
    FileNotFound,
    /// No global declaration has the requested `@group` / `@binding` pair.
    #[error("resource `@group({0}) @binding({1})` not found")]
    ResourceNotFound(u32, u32),
    /// The resource data cannot be decoded as the declared type.
    /// Fields: group, binding, data size in bytes, declared type, declared type size in bytes.
    #[error(
        "resource `@group({0}) @binding({1})` ({2} bytes) is incompatible with type `{3}` ({4} bytes)"
    )]
    ResourceIncompatible(u32, u32, u32, wesl::eval::Type, u32),
    /// A value was requested in binary form but its type has no buffer representation.
    #[error("Could not convert instance to buffer (type `{0}` is not storable)")]
    NotStorable(wesl::eval::Type),
    /// Error returned by the `wesl` compiler.
    #[error("{0}")]
    WeslError(#[from] wesl::Error),
    /// `wesl` error with attached diagnostic context (e.g. the source text).
    #[error("{0}")]
    WeslDiagnostic(#[from] wesl::Diagnostic<wesl::Error>),
    /// naga failed to parse the WGSL; the `String` is the source used to render the report.
    #[cfg(feature = "naga")]
    #[error("naga parse error: {}", .0.emit_to_string(.1))]
    NagaParse(naga::front::wgsl::ParseError, String),
    /// naga rejected the module during validation; the `String` is the source used to render
    /// the report.
    #[cfg(feature = "naga")]
    #[error("naga validation error: {}", .0.emit_to_string(.1))]
    NagaValid(Box<naga::WithSpan<naga::valid::ValidationError>>, String),
}
/// Input to compile: either a path on disk or source text read from stdin.
enum FileOrSource {
    /// Path to the root module file.
    File(PathBuf),
    /// Source text of the root module.
    Source(String),
}
fn run_compile(
options: &CompOptsArgs,
file_or_source: FileOrSource,
) -> Result<CompileResult, CliError> {
let compile_options = CompileOptions::from(options);
let mut compiler = Wesl::new_barebones();
compiler
.set_options(compile_options)
.use_sourcemap(!options.no_sourcemap)
.set_mangler(options.mangler.into());
match file_or_source {
FileOrSource::File(path) => {
let base = options
.base
.as_deref()
.or(path.parent())
.ok_or(CliError::FileNotFound)?;
let name = path
.file_stem()
.ok_or(CliError::FileNotFound)?
.to_string_lossy()
.to_string(); let path = ModulePath::new(PathOrigin::Absolute, vec![name]);
let resolver = StandardResolver::new(base);
let res = compiler.set_custom_resolver(resolver).compile(&path)?;
Ok(res)
}
FileOrSource::Source(source) => {
let base = std::env::current_dir().unwrap();
let name = "stdin";
let mut router = Router::new();
let mut resolver = VirtualResolver::new();
let path = ModulePath::new(PathOrigin::Absolute, vec![name.to_string()]);
resolver.add_module(ModulePath::new_root(), source.into());
router.mount_resolver(path.clone(), resolver);
router.mount_fallback_resolver(StandardResolver::new(base));
let res = compiler.set_custom_resolver(router).compile(&path)?;
Ok(res)
}
}
}
/// Turns a `--resource` argument into a resource instance for `CompileResult::exec`.
///
/// The resource type is taken from the global declaration in `wgsl` whose
/// `@group` / `@binding` attributes match `b`. `b.data` is then decoded as a value of that type.
/// Returns the `(group, binding)` key together with the instance. The instance is wrapped with
/// the address space and access mode implied by `b.kind`.
///
/// # Errors
/// * [`CliError::ResourceNotFound`] if no declaration has the requested group/binding.
/// * A diagnostic if the declared type cannot be evaluated.
/// * [`CliError::ResourceIncompatible`] if the data cannot be decoded as the declared type.
///
/// # Panics
/// For binding kinds other than `uniform`, `storage` and `read-only-storage`, which are not
/// supported yet.
fn parse_binding(
    b: &Binding,
    wgsl: &TranslationUnit,
) -> Result<((u32, u32), RefInstance), CliError> {
    let mut ctx = wesl::eval::Context::new(wgsl);
    let ty_expr = wgsl
        .global_declarations
        .iter()
        .find_map(|d| match d.node() {
            syntax::GlobalDeclaration::Declaration(d) => {
                let (group, binding) = d.attr_group_binding(&mut ctx).ok()?;
                if group == b.group && binding == b.binding {
                    d.ty.clone()
                } else {
                    None
                }
            }
            _ => None,
        })
        .ok_or(CliError::ResourceNotFound(b.group, b.binding))?;
    let ty = ty_eval_ty(&ty_expr, &mut ctx).map_err(|e| {
        Diagnostic::from(e)
            .with_ctx(&ctx)
            .with_source(ty_expr.to_string())
    })?;
    let (storage, access) = match b.kind {
        BindingType::Uniform => (AddressSpace::Uniform, AccessMode::Read),
        BindingType::Storage => (AddressSpace::Storage, AccessMode::ReadWrite),
        BindingType::ReadOnlyStorage => (AddressSpace::Storage, AccessMode::Read),
        // The remaining kinds (their names match WebGPU sampler, texture and storage-texture
        // binding types) are not implemented. Name the offending kind in the panic message
        // instead of a bare "not yet implemented".
        // TODO: report this through a dedicated `CliError` variant instead of panicking.
        kind @ (BindingType::Filtering
        | BindingType::NonFiltering
        | BindingType::Comparison
        | BindingType::Float
        | BindingType::UnfilterableFloat
        | BindingType::Sint
        | BindingType::Uint
        | BindingType::Depth
        | BindingType::WriteOnly
        | BindingType::ReadWrite
        | BindingType::ReadOnly) => unimplemented!(
            "resource binding type `{kind:?}` is not supported yet \
             (supported: uniform, storage, read-only-storage)"
        ),
    };
    // The data length is only used in the error message; saturate rather than silently
    // truncating sizes that do not fit in `u32`.
    let data_len = u32::try_from(b.data.len()).unwrap_or(u32::MAX);
    let inst = Instance::from_buffer(&b.data, &ty).ok_or_else(|| {
        CliError::ResourceIncompatible(
            b.group,
            b.binding,
            data_len,
            ty.clone(),
            ty.size_of().unwrap_or_default(),
        )
    })?;
    Ok((
        (b.group, b.binding),
        RefInstance::new(inst, storage, access),
    ))
}
/// Parses `src` as an expression and evaluates it in the context of `wgsl`.
///
/// Used for `--override NAME=EXPRESSION`; the resulting value is passed to `exec`.
///
/// # Errors
/// A diagnostic carrying `src` as its source if parsing or evaluation fails. The evaluation
/// diagnostic also carries the evaluation context.
fn parse_override(src: &str, wgsl: &TranslationUnit) -> Result<Instance, CliError> {
    let expr: syntax::Expression = src
        .parse()
        .map_err(|err| Diagnostic::from(err).with_source(src.to_string()))?;
    let mut ctx = wesl::eval::Context::new(wgsl);
    let value = expr.eval_value(&mut ctx).map_err(|err| {
        Diagnostic::from(err)
            .with_ctx(&ctx)
            .with_source(src.to_string())
    })?;
    Ok(value)
}
/// CLI entry point.
///
/// Exit status:
/// * `0` on success. This also covers `--help` / `--version`, which clap prints to stdout.
/// * `1` when argument parsing fails (reported as `invalid arguments: …`) or when the command
///   fails (the error is printed to stderr). Previously, command failures exited with `0`.
fn main() {
    let cli = match Cli::try_parse() {
        Ok(cli) => cli,
        // `--help` and `--version` come back from `try_parse` as "errors" whose output belongs
        // on stdout; let clap print them and exit successfully.
        Err(e) if !e.use_stderr() => e.exit(),
        Err(e) => {
            eprintln!("invalid arguments: {e}");
            std::process::exit(1)
        }
    };
    if let Err(e) = run(cli) {
        eprintln!("{e}");
        std::process::exit(1);
    }
}
/// Chooses the compilation input: the given path, or else the text piped into stdin.
///
/// Returns `None` when no path is given and stdin is an interactive terminal (so nothing is
/// read). It also returns `None` when reading stdin fails (e.g. the input is not valid UTF-8).
fn file_or_source(path: Option<PathBuf>) -> Option<FileOrSource> {
    if let Some(path) = path {
        return Some(FileOrSource::File(path));
    }
    let mut stdin = std::io::stdin();
    if stdin.is_terminal() {
        return None;
    }
    let mut text = String::new();
    stdin.read_to_string(&mut text).ok()?;
    Some(FileOrSource::Source(text))
}
/// Parses `source` as WGSL with naga and validates it with all flags and capabilities enabled.
///
/// # Errors
/// [`CliError::NagaParse`] or [`CliError::NagaValid`]. Both keep a copy of `source` so the
/// report can be rendered with context.
#[cfg(feature = "naga")]
fn naga_validate(source: &str) -> Result<(), CliError> {
    use naga::valid::{Capabilities, ValidationFlags, Validator};

    let module = match naga::front::wgsl::parse_str(source) {
        Ok(module) => module,
        Err(err) => return Err(CliError::NagaParse(err, source.to_string())),
    };
    Validator::new(ValidationFlags::all(), Capabilities::all())
        .validate(&module)
        .map(|_info| ())
        .map_err(|err| CliError::NagaValid(Box::new(err), source.to_string()))
}
fn run(cli: Cli) -> Result<(), CliError> {
match cli.command {
Command::Check(args) => {
let source = if let Some(file) = &args.file {
fs::read_to_string(file).map_err(|_| CliError::FileNotFound)?
} else {
let mut source = String::new();
std::io::stdin()
.read_to_string(&mut source)
.map_err(|_| CliError::FileNotFound)?;
source
};
match &args.kind {
CheckKind::Wgsl => {
wgsl_parse::recognize_str(&source)
.map_err(|e| Diagnostic::from(e).with_source(source.clone()))?;
let mut wgsl = wgsl_parse::parse_str(&source)
.map_err(|e| Diagnostic::from(e).with_source(source.clone()))?;
wgsl.retarget_idents();
wesl::validate_wgsl(&wgsl)?;
#[cfg(feature = "naga")]
if args.naga {
naga_validate(&source)?;
}
}
CheckKind::Wesl => {
let mut wesl = TranslationUnit::from_str(&source)
.map_err(|e| Diagnostic::from(e).with_source(source))?;
wesl.retarget_idents();
wesl::validate_wesl(&wesl)?;
}
}
println!("OK");
}
Command::Compile(args) => {
let comp = file_or_source(args.file)
.map(|input| run_compile(&args.options, input))
.unwrap_or_else(|| Ok(CompileResult::default()))?;
#[cfg(feature = "naga")]
if !args.options.no_naga {
naga_validate(&comp.to_string())?;
}
println!("{comp}");
}
Command::Eval(args) => {
let comp = file_or_source(args.file)
.map(|input| run_compile(&args.options, input))
.unwrap_or_else(|| Ok(CompileResult::default()))?;
let eval = comp.eval(&args.expr)?;
if args.binary {
let buf = eval
.inst
.to_buffer()
.ok_or_else(|| CliError::NotStorable(eval.inst.ty()))?;
std::io::stdout().write_all(buf.as_slice()).unwrap();
} else {
println!("{}", eval.inst)
}
}
Command::Exec(args) => {
let comp = file_or_source(args.file)
.map(|input| run_compile(&args.options, input))
.unwrap_or_else(|| Ok(CompileResult::default()))?;
let inputs = Inputs::new_zero_initialized();
let resources = args
.resources
.iter()
.map(|b| parse_binding(b, &comp.syntax))
.collect::<Result<_, _>>()?;
let overrides = args
.overrides
.iter()
.map(|(name, expr)| -> Result<(String, Instance), CliError> {
Ok((name.to_string(), parse_override(expr, &comp.syntax)?))
})
.collect::<Result<_, _>>()?;
let exec = comp.exec(&args.entrypoint, inputs, resources, overrides)?;
if let Some(inst) = &exec.inst {
if args.binary {
let buf = inst
.to_buffer()
.ok_or_else(|| CliError::NotStorable(inst.ty()))?;
std::io::stdout().write_all(buf.as_slice()).unwrap();
} else {
println!("return: {inst}")
}
} else if !args.binary {
println!("return: void")
}
let resources = args
.resources
.iter()
.filter_map(|r| {
let inst = exec.resource(r.group, r.binding)?.clone();
let inst = inst.read().ok()?.to_owned();
Some((r.group, r.binding, inst))
})
.collect::<Vec<_>>();
for (group, binding, inst) in resources {
if args.binary {
let buf = inst
.to_buffer()
.ok_or_else(|| CliError::NotStorable(inst.ty()))?;
std::io::stdout().write_all(buf.as_slice()).unwrap();
} else {
println!("resource: group={group} binding={binding} value={inst}")
}
}
}
Command::Package(args) => {
let code = PkgBuilder::new(&args.name)
.scan_root(args.dir)
.expect("failed to scan WESL files")
.validate()
.map_err(|e| {
eprintln!("{e}");
panic!()
})
.unwrap()
.codegen();
println!("{code}");
}
};
Ok(())
}