mod bindings;
mod types;
use std::collections::HashMap;
use std::path::PathBuf;
pub use bindings::*;
use derive_builder::Builder;
use derive_more::IsVariant;
use enumflags2::{bitflags, BitFlags};
pub use naga::valid::Capabilities as WgslShaderIrCapabilities;
use proc_macro2::TokenStream;
use regex::Regex;
pub use types::*;
use crate::{
FastIndexMap, WGSLBindgen, WgslBindgenError, WgslType, WgslTypeSerializeStrategy,
};
/// Selects how shader source is made available to the generated bindings.
///
/// This is an `enumflags2` bitflags enum, so variants can be combined in a
/// `BitFlags<WgslShaderSourceType>`. The `#[bitflags(default = EmbedSource)]`
/// attribute makes `EmbedSource` the value returned by `BitFlags::default()`.
///
/// The variants carry no explicit discriminants, so the bit values come from
/// `enumflags2` and their declaration order.
/// NOTE(review): reordering variants may change those bits — confirm before doing so.
#[bitflags(default = EmbedSource)]
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, IsVariant)]
pub enum WgslShaderSourceType {
    /// NOTE(review): presumably embeds the WGSL source text directly — confirm.
    EmbedSource,
    /// NOTE(review): presumably embeds sources and composes them with naga_oil — confirm.
    EmbedWithNagaOilComposer,
    /// NOTE(review): presumably composes via naga_oil using relative file paths — confirm.
    ComposerWithRelativePath,
}
/// An extra directory to scan, optionally paired with a module import root.
///
/// NOTE(review): presumably `module_import_root` is the import prefix used for
/// modules discovered under `directory` — confirm against the scanner.
#[derive(Clone, Debug, Default)]
pub struct AdditionalScanDirectory {
    pub module_import_root: Option<String>,
    pub directory: String,
}

/// Builds a scan directory from a `(module_import_root, directory)` pair,
/// copying both string slices into owned `String`s.
impl From<(Option<&str>, &str)> for AdditionalScanDirectory {
    fn from(pair: (Option<&str>, &str)) -> Self {
        let (root, dir) = pair;
        AdditionalScanDirectory {
            module_import_root: root.map(str::to_owned),
            directory: dir.to_owned(),
        }
    }
}
/// Something that can produce a [`WgslTypeMap`] for a given serialization strategy.
///
/// `WgslBindgenOptionBuilder::type_map` accepts any implementor and passes it
/// the builder's configured strategy.
pub trait WgslTypeMapBuild {
    /// Returns the type map to use under `strategy`.
    fn build(&self, strategy: WgslTypeSerializeStrategy) -> WgslTypeMap;
}

/// A concrete map is strategy-independent: the strategy is ignored and a copy
/// of the map is returned.
impl WgslTypeMapBuild for WgslTypeMap {
    fn build(&self, _strategy: WgslTypeSerializeStrategy) -> WgslTypeMap {
        Clone::clone(self)
    }
}
/// Replaces a WGSL struct with a caller-supplied Rust type.
///
/// `from` is used as the fully qualified WGSL struct name. `to` (the Rust type
/// as tokens) and `alignment` become the replacement type's info.
#[derive(Debug, Clone)]
pub struct OverrideStruct {
    pub from: String,
    pub to: TokenStream,
    pub alignment: usize,
}

/// Builds an override from `(wgsl_struct_name, rust_type_tokens, alignment)`.
impl From<(&str, TokenStream, usize)> for OverrideStruct {
    fn from(value: (&str, TokenStream, usize)) -> Self {
        let (name, tokens, align) = value;
        Self {
            from: String::from(name),
            to: tokens,
            alignment: align,
        }
    }
}
/// Overrides the Rust type emitted for selected struct fields.
///
/// NOTE(review): presumably applies `override_type` to fields whose struct name
/// matches `struct_regex` and whose field name matches `field_regex`. The
/// matching logic is not in this file view — confirm.
#[derive(Debug, Clone)]
pub struct OverrideStructFieldType {
    pub struct_regex: Regex,
    pub field_regex: Regex,
    pub override_type: TokenStream,
}

/// Builds an override from already-compiled regexes.
impl From<(Regex, Regex, TokenStream)> for OverrideStructFieldType {
    fn from(parts: (Regex, Regex, TokenStream)) -> Self {
        let (struct_regex, field_regex, override_type) = parts;
        Self {
            struct_regex,
            field_regex,
            override_type,
        }
    }
}

/// Builds an override by compiling the two patterns.
///
/// # Panics
///
/// Panics if either pattern is not a valid regex. The struct pattern is
/// compiled first.
impl From<(&str, &str, TokenStream)> for OverrideStructFieldType {
    fn from(parts: (&str, &str, TokenStream)) -> Self {
        let (struct_pattern, field_pattern, override_type) = parts;
        let struct_regex = Regex::new(struct_pattern).expect("Failed to create struct regex");
        let field_regex = Regex::new(field_pattern).expect("Failed to create field regex");
        Self {
            struct_regex,
            field_regex,
            override_type,
        }
    }
}
/// Overrides the alignment used for structs selected by `struct_regex`.
///
/// NOTE(review): presumably `alignment` is in bytes and the regex is matched
/// against the struct name — confirm against where this is consumed.
#[derive(Debug, Clone)]
pub struct OverrideStructAlignment {
    pub struct_regex: Regex,
    pub alignment: u16,
}

/// Builds an alignment override from an already-compiled regex.
impl From<(Regex, u16)> for OverrideStructAlignment {
    fn from(value: (Regex, u16)) -> Self {
        let (struct_regex, alignment) = value;
        OverrideStructAlignment {
            struct_regex,
            alignment,
        }
    }
}

/// Builds an alignment override by compiling the pattern.
///
/// # Panics
///
/// Panics if the pattern is not a valid regex.
impl From<(&str, u16)> for OverrideStructAlignment {
    fn from(value: (&str, u16)) -> Self {
        let (pattern, alignment) = value;
        let struct_regex = Regex::new(pattern).expect("Failed to create struct regex");
        OverrideStructAlignment {
            struct_regex,
            alignment,
        }
    }
}
/// Redirects bind-group entries selected by `bind_group_entry_regex` to the
/// module path `target_path`.
///
/// NOTE(review): the regex matching and path usage live outside this file
/// view — confirm what the regex is matched against.
#[derive(Debug, Clone)]
pub struct OverrideBindGroupEntryModulePath {
    pub bind_group_entry_regex: Regex,
    pub target_path: String,
}

/// Builds an override from an already-compiled regex and a target path.
impl From<(Regex, &str)> for OverrideBindGroupEntryModulePath {
    fn from(value: (Regex, &str)) -> Self {
        let (bind_group_entry_regex, path) = value;
        Self {
            bind_group_entry_regex,
            target_path: String::from(path),
        }
    }
}

/// Builds an override by compiling the pattern.
///
/// # Panics
///
/// Panics if the pattern is not a valid regex.
impl From<(&str, &str)> for OverrideBindGroupEntryModulePath {
    fn from(value: (&str, &str)) -> Self {
        let (pattern, path) = value;
        let bind_group_entry_regex =
            Regex::new(pattern).expect("Failed to create bind group entry regex");
        Self {
            bind_group_entry_regex,
            target_path: String::from(path),
        }
    }
}
/// Forces the filterability flag for texture bindings selected by `binding_regex`.
///
/// NOTE(review): presumably the regex is matched against the binding name — confirm.
#[derive(Debug, Clone)]
pub struct OverrideTextureFilterability {
    pub binding_regex: Regex,
    pub filterable: bool,
}

/// Builds an override from an already-compiled regex.
impl From<(Regex, bool)> for OverrideTextureFilterability {
    fn from(value: (Regex, bool)) -> Self {
        let (binding_regex, filterable) = value;
        OverrideTextureFilterability {
            binding_regex,
            filterable,
        }
    }
}

/// Builds an override by compiling the pattern.
///
/// # Panics
///
/// Panics if the pattern is not a valid regex.
impl From<(&str, bool)> for OverrideTextureFilterability {
    fn from(value: (&str, bool)) -> Self {
        let (pattern, filterable) = value;
        let binding_regex = Regex::new(pattern).expect("Failed to create binding regex");
        OverrideTextureFilterability {
            binding_regex,
            filterable,
        }
    }
}
/// Kind of sampler binding to emit when overriding a sampler's type.
///
/// NOTE(review): the variants appear to mirror wgpu's `SamplerBindingType`
/// (filtering / non-filtering / comparison) — confirm.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SamplerType {
    Filtering,
    NonFiltering,
    Comparison,
}
/// Forces the [`SamplerType`] for sampler bindings selected by `binding_regex`.
///
/// NOTE(review): presumably the regex is matched against the binding name — confirm.
#[derive(Debug, Clone)]
pub struct OverrideSamplerType {
    pub binding_regex: Regex,
    pub sampler_type: SamplerType,
}

/// Builds an override from an already-compiled regex.
impl From<(Regex, SamplerType)> for OverrideSamplerType {
    fn from(value: (Regex, SamplerType)) -> Self {
        let (binding_regex, sampler_type) = value;
        Self {
            binding_regex,
            sampler_type,
        }
    }
}

/// Builds an override by compiling the pattern.
///
/// # Panics
///
/// Panics if the pattern is not a valid regex.
impl From<(&str, SamplerType)> for OverrideSamplerType {
    fn from(value: (&str, SamplerType)) -> Self {
        let (pattern, sampler_type) = value;
        let binding_regex = Regex::new(pattern).expect("Failed to create binding regex");
        Self {
            binding_regex,
            sampler_type,
        }
    }
}
/// Visibility applied to generated types; defaults to `Public`.
///
/// NOTE(review): presumably maps to `pub`, `pub(crate)` and `pub(super)`
/// respectively — confirm against the code generator.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum WgslTypeVisibility {
    #[default]
    Public,
    RestrictedCrate,
    RestrictedSuper,
}
/// Configuration for a `WGSLBindgen` run.
///
/// Constructed through the `derive_builder`-generated `WgslBindgenOptionBuilder`.
/// The generated build function is private and renamed to `fallible_build`;
/// callers finish with `WgslBindgenOptionBuilder::build`, which returns a
/// `WGSLBindgen`. The struct-level `setter(into)` makes every generated setter
/// accept `impl Into<FieldType>`.
///
/// Fields declared without `#[builder(default)]` must be set, otherwise
/// `fallible_build` reports them as uninitialized.
#[derive(Debug, Default, Builder)]
#[builder(
    setter(into),
    field(private),
    build_fn(private, name = "fallible_build")
)]
pub struct WgslBindgenOption {
    /// Entry points to generate bindings for. Required.
    /// Values can be appended one at a time with `add_entry_point`.
    /// NOTE(review): presumably WGSL file paths — confirm.
    #[builder(setter(each(name = "add_entry_point", into)))]
    pub entry_points: Vec<String>,
    /// Optional module import root. The setter takes the inner value directly
    /// (`strip_option`).
    #[builder(default, setter(strip_option, into))]
    pub module_import_root: Option<String>,
    /// Workspace root path. Required.
    #[builder(setter(into))]
    pub workspace_root: PathBuf,
    /// Defaults to `true`.
    /// NOTE(review): presumably controls emitting cargo `rerun-if-changed` directives — confirm.
    #[builder(default = "true")]
    pub emit_rerun_if_change: bool,
    /// Defaults to `false`.
    #[builder(default = "false")]
    pub skip_header_comments: bool,
    /// Defaults to `false`.
    #[builder(default = "false")]
    pub skip_hash_check: bool,
    /// Defaults to `WgslTypeSerializeStrategy::default()`. The builder's
    /// `type_map` setter reads this value, so set it before calling `type_map`.
    #[builder(default)]
    pub serialization_strategy: WgslTypeSerializeStrategy,
    /// Defaults to `false`.
    #[builder(default = "false")]
    pub derive_serde: bool,
    /// Defaults to `BitFlags::default()` of `WgslShaderSourceType`.
    #[builder(default)]
    pub shader_source_type: BitFlags<WgslShaderSourceType>,
    /// Optional output path. The setter takes the inner value directly
    /// (`strip_option`).
    #[builder(default, setter(strip_option, into))]
    pub output: Option<PathBuf>,
    /// Extra directories to scan. Defaults to empty.
    /// `additional_scan_dir` appends one entry and accepts anything convertible
    /// into `AdditionalScanDirectory`.
    #[builder(default, setter(into, each(name = "additional_scan_dir", into)))]
    pub additional_scan_dirs: Vec<AdditionalScanDirectory>,
    /// Optional naga validator capabilities (`naga::valid::Capabilities`).
    #[builder(default, setter(strip_option))]
    pub ir_capabilities: Option<WgslShaderIrCapabilities>,
    /// Optional; the setter takes the inner value directly (`strip_option`).
    /// NOTE(review): semantics of this integer are not visible here — confirm.
    #[builder(default, setter(strip_option, into))]
    pub short_constructor: Option<i32>,
    /// Visibility of generated types. Defaults to `WgslTypeVisibility::Public`.
    #[builder(default)]
    pub type_visibility: WgslTypeVisibility,
    /// WGSL-to-Rust type mappings. Set through the builder's custom `type_map`
    /// method, which extends any previously registered map.
    #[builder(setter(custom))]
    pub type_map: WgslTypeMap,
    /// Whole-struct replacements. Defaults to empty; appended one at a time
    /// with `add_override_struct_mapping`.
    #[builder(default, setter(each(name = "add_override_struct_mapping", into)))]
    pub override_struct: Vec<OverrideStruct>,
    /// Per-field type overrides. Defaults to empty.
    #[builder(default, setter(into))]
    pub override_struct_field_type: Vec<OverrideStructFieldType>,
    /// Struct alignment overrides. Defaults to empty.
    #[builder(default, setter(into))]
    pub override_struct_alignment: Vec<OverrideStructAlignment>,
    /// Bind-group entry module path overrides. Defaults to empty.
    #[builder(default, setter(into))]
    pub override_bind_group_entry_module_path: Vec<OverrideBindGroupEntryModulePath>,
    /// Regexes for custom padding fields. Defaults to empty; appended one at a
    /// time with `add_custom_padding_field_regexp`.
    #[builder(default, setter(each(name = "add_custom_padding_field_regexp", into)))]
    pub custom_padding_field_regexps: Vec<Regex>,
    /// Defaults to `false`.
    #[builder(default = "false")]
    pub always_generate_init_struct: bool,
    /// Optional extra binding generator. Set through the builder's custom
    /// `extra_binding_generator` method.
    #[builder(default, setter(custom))]
    pub extra_binding_generator: Option<BindingGenerator>,
    /// Uses `BindingGenerator::default()` unless set.
    /// NOTE(review): its custom setter is not in the impl shown here —
    /// presumably defined elsewhere.
    #[builder(default, setter(custom))]
    pub wgpu_binding_generator: BindingGenerator,
    /// Texture filterability overrides. Defaults to empty.
    #[builder(default, setter(into))]
    pub override_texture_filterability: Vec<OverrideTextureFilterability>,
    /// Sampler type overrides. Defaults to empty.
    #[builder(default, setter(into))]
    pub override_sampler_type: Vec<OverrideSamplerType>,
    /// naga_oil shader definitions as `(name, value)` pairs. Defaults to empty;
    /// also appendable via the builder's `add_shader_def` / `add_shader_defs`.
    #[builder(default, setter(into))]
    pub shader_defs: Vec<(String, naga_oil::compose::ShaderDefValue)>,
}
impl WgslBindgenOptionBuilder {
pub fn build(&mut self) -> Result<WGSLBindgen, WgslBindgenError> {
self.merge_struct_type_overrides();
let options = self.fallible_build()?;
WGSLBindgen::new(options)
}
pub fn type_map(&mut self, map_build: impl WgslTypeMapBuild) -> &mut Self {
let serialization_strategy = self
.serialization_strategy
.expect("Serialization strategy must be set before `wgs_type_map`");
let map = map_build.build(serialization_strategy);
match self.type_map.as_mut() {
Some(m) => m.extend(map),
None => self.type_map = Some(map),
}
self
}
pub fn add_shader_def(
&mut self,
name: impl Into<String>,
value: naga_oil::compose::ShaderDefValue,
) -> &mut Self {
if self.shader_defs.is_none() {
self.shader_defs = Some(Vec::new());
}
self
.shader_defs
.as_mut()
.unwrap()
.push((name.into(), value));
self
}
pub fn add_shader_defs(
&mut self,
defs: Vec<(String, naga_oil::compose::ShaderDefValue)>,
) -> &mut Self {
match self.shader_defs.as_mut() {
Some(existing) => existing.extend(defs),
None => self.shader_defs = Some(defs),
}
self
}
fn merge_struct_type_overrides(&mut self) {
let struct_mappings = self
.override_struct
.iter()
.flatten()
.map(
|OverrideStruct {
from,
to,
alignment,
}| {
let wgsl_type = WgslType::Struct {
fully_qualified_name: from.clone(),
};
(wgsl_type, WgslTypeInfo::new(to.clone(), *alignment))
},
)
.collect::<FastIndexMap<_, _>>();
self.type_map(struct_mappings);
}
pub fn extra_binding_generator(
&mut self,
config: impl GetBindingsGeneratorConfig,
) -> &mut Self {
let generator = Some(config.get_generator_config());
self.extra_binding_generator = Some(generator);
self
}
}