#![allow(dead_code, unused)]
extern crate wgpu_types as wgpu;
use crate::quote_gen::{custom_vector_matrix_assertions, MOD_STRUCT_ASSERTIONS};
use bevy_util::SourceWithFullDependenciesResult;
use case::CaseExt;
use derive_more::IsVariant;
use generate::bind_group::RawShadersBindGroups;
use generate::entry::{self, entry_point_constants, vertex_struct_impls};
use generate::{bind_group, consts, pipeline, shader_module, shader_registry};
use heck::ToPascalCase;
use proc_macro2::{Span, TokenStream};
use qs::{format_ident, quote, Ident, Index};
use quote_gen::RustModBuilder;
use smallvec::SmallVec;
use thiserror::Error;
pub mod bevy_util;
mod bindgen;
mod generate;
mod naga_util;
mod quote_gen;
mod structs;
pub mod test_helper;
mod types;
mod wgsl;
mod wgsl_type;
/// Re-exports of the token-building items (`proc_macro2::TokenStream`,
/// `quote::{format_ident, quote}`, `syn::{Ident, Index}`) under a single path.
pub mod qs {
pub use proc_macro2::TokenStream;
pub use quote::{format_ident, quote};
pub use syn::{Ident, Index};
}
pub use bindgen::*;
pub use naga::FastIndexMap;
pub use regex::Regex;
pub use types::*;
pub use wgsl_type::*;
pub use naga_oil::compose::ShaderDefValue;
/// Selects the serialization approach the generated struct code targets.
///
/// NOTE(review): the variant names presumably refer to the `encase` and
/// `bytemuck` crates — confirm what each emits in the `structs` generator.
/// `IsVariant` provides `is_encase()` / `is_bytemuck()` predicates.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, IsVariant)]
pub enum WgslTypeSerializeStrategy {
/// Default strategy.
#[default]
Encase,
Bytemuck,
}
/// Errors produced while generating the Rust bindings module.
#[derive(Debug, Error)]
pub enum CreateModuleError {
/// The bind-group indices used by a shader do not form a consecutive run
/// starting at 0 (for example groups 0, 1 and 3).
#[error("bind groups are non-consecutive or do not start from 0")]
NonConsecutiveBindGroups,
/// Two resources were declared with the same binding index (for example two
/// `@group(0) @binding(2)` variables); `binding` is the repeated index.
#[error("duplicate binding found with index `{binding}`")]
DuplicateBinding { binding: u32 },
/// Wraps a `quote_gen::RustModuleBuilderError`. `#[from]` lets `?` convert it
/// automatically, e.g. when `RustModBuilder::add_items` fails.
#[error("duplicate content found `{0}`")]
RustModuleBuilderError(#[from] quote_gen::RustModuleBuilderError),
}
/// One shader entry to generate bindings for.
#[derive(Debug)]
pub(crate) struct WgslEntryResult<'a> {
/// `::`-separated module path of the entry (e.g. `lines::segment`). It is
/// split on `::` by `get_mod_path` and flattened to PascalCase by
/// `get_shader_variant`.
mod_name: String,
/// Parsed naga IR of the shader.
naga_module: naga::Module,
/// Borrowed source file together with its full dependency set.
source_including_deps: SourceWithFullDependenciesResult<'a>,
}
impl<'a> WgslEntryResult<'a> {
pub fn get_shader_variant(&self) -> TokenStream {
let mod_name = sanitize_and_pascal_case(&self.mod_name);
let enum_variant = format_ident!("{}", mod_name);
quote! { #enum_variant }
}
pub fn get_mod_path(&self) -> TokenStream {
let mod_path_parts = self
.mod_name
.split("::")
.map(|part| format_ident!("{}", part))
.collect::<SmallVec<[_; 4]>>();
quote! {
#(#mod_path_parts)::*
}
}
}
/// Generates the complete Rust source of the bindings for all `entries`.
///
/// Works in two passes over `entries`:
/// 1. For each entry, it adds the following to the module builder:
///    - structs and consts;
///    - pipeline-overridable constants;
///    - the compute module and entry-point constants;
///    - vertex and fragment states.
///
///    It also collects the entry's vertex inputs and bind-group data.
/// 2. It generates the bind groups and vertex-input impls from the collected
///    data. It then adds each entry's pipeline-layout function and its shader
///    module. The pipeline-layout function is added only when
///    `entrypoint_bindgroups` has data for that entry.
///
/// The shader registry is emitted ahead of the module tree. The combined output
/// is formatted with [`pretty_print`].
///
/// # Errors
/// Propagates errors from `RustModBuilder::add_items` and from
/// `bind_group::get_bind_group_data_for_entry` (see [`CreateModuleError`]).
fn create_rust_bindings(
entries: Vec<WgslEntryResult<'_>>,
options: &WgslBindgenOption,
) -> Result<String, CreateModuleError> {
// NOTE(review): the meaning of the two `true` flags is defined in `quote_gen`.
let mut mod_builder = RustModBuilder::new(true, true);
// Only emitted when `custom_vector_matrix_assertions` yields `Some` for these options.
if let Some(custom_wgsl_type_asserts) = custom_vector_matrix_assertions(options) {
mod_builder.add(MOD_STRUCT_ASSERTIONS, custom_wgsl_type_asserts);
}
// Accumulators filled across all entries and consumed after the first pass.
let mut all_shader_bind_groups = RawShadersBindGroups::new(options);
let mut all_shader_vertex_inputs =
generate::vertex_input_collector::RawShadersVertexInputs::new();
// Pass 1: per-entry items.
// NOTE(review): the order of `mod_builder` calls presumably fixes the order of
// items in the generated code, so keep it stable.
for entry in entries.iter() {
let WgslEntryResult {
mod_name,
naga_module,
..
} = entry;
mod_builder.add_items(structs::structs_items(mod_name, naga_module, options))?;
mod_builder.add_items(consts::consts_items(mod_name, naga_module))?;
mod_builder
.add(mod_name, consts::pipeline_overridable_constants(naga_module, options));
let shader_vertex_inputs =
generate::vertex_input_collector::RawShadersVertexInputs::from_module(
mod_name,
naga_module,
);
all_shader_vertex_inputs.add(shader_vertex_inputs);
mod_builder.add(
mod_name,
shader_module::compute_module(naga_module, options.shader_source_type),
);
mod_builder.add(mod_name, entry_point_constants(naga_module));
mod_builder.add(mod_name, entry::vertex_states(mod_name, naga_module));
mod_builder.add(mod_name, entry::fragment_states(naga_module));
let shader_stages = wgsl::shader_stages(naga_module);
// Fails on invalid bind-group layouts (see the tests below for the cases covered).
let shader_bind_groups = bind_group::get_bind_group_data_for_entry(
naga_module,
shader_stages,
options,
mod_name,
)?;
all_shader_bind_groups.add(shader_bind_groups);
}
// Build the cross-entry bind-group set from everything collected in pass 1.
let reusable_bind_groups = all_shader_bind_groups.create_reusable_shader_bind_groups();
let bind_groups = reusable_bind_groups.generate_bind_groups(options);
mod_builder.add_items(bind_groups)?;
let vertex_input_impls = all_shader_vertex_inputs.generate_vertex_input_impls();
mod_builder.add_items(vertex_input_impls)?;
// Pass 2: needs the complete `reusable_bind_groups`, so it cannot be folded
// into the first loop.
for entry in entries.iter() {
let WgslEntryResult {
mod_name,
naga_module,
..
} = entry;
let entry_name = sanitize_and_pascal_case(mod_name);
if let Some(shader_bind_groups) = reusable_bind_groups
.entrypoint_bindgroups
.get(mod_name.as_str())
{
let create_pipeline_layout = pipeline::create_pipeline_layout_fn(
&entry_name,
naga_module,
shader_bind_groups,
options,
);
mod_builder.add(mod_name, create_pipeline_layout);
}
mod_builder.add(mod_name, shader_module::shader_module(entry, options));
}
let mod_token_stream = mod_builder.generate();
let shader_registry = shader_registry::build_shader_registry(&entries, options);
// The file-level `allow` covers naming and unused-item lints in generated code.
let output = quote! {
#![allow(unused, non_snake_case, non_camel_case_types, non_upper_case_globals)]
#shader_registry
#mod_token_stream
};
Ok(pretty_print(&output))
}
/// Builds an identifier by appending `index` to `name` (`("Foo", 3)` -> `Foo3`).
///
/// Panics (via `format_ident!`) if the result is not a valid identifier.
fn indexed_name_ident(name: &str, index: u32) -> Ident {
    format_ident!("{base}{suffix}", base = name, suffix = index)
}
/// Turns a `::`-separated module path into a PascalCase name
/// (e.g. `lines::segment` -> `LinesSegment`).
///
/// `::` is treated as a word separator. Every character that is neither
/// alphanumeric nor `_` is dropped before the case conversion.
fn sanitize_and_pascal_case(v: &str) -> String {
    let mut cleaned = String::with_capacity(v.len());
    for ch in v.replace("::", "_").chars() {
        if ch == '_' || ch.is_alphanumeric() {
            cleaned.push(ch);
        }
    }
    cleaned.to_pascal_case()
}
/// Converts a `::`-separated module path into an UPPER_SNAKE_CASE fragment.
///
/// The conversion works as follows:
/// - `::` becomes `_`;
/// - characters other than alphanumerics and `_` are dropped;
/// - the text is snake-cased and upper-cased;
/// - runs of underscores, including leading and trailing ones, are collapsed
///   so that words are joined by exactly one `_`.
fn sanitized_upper_snake_case(v: &str) -> String {
    let cleaned: String = v
        .replace("::", "_")
        .chars()
        .filter(|&ch| ch == '_' || ch.is_alphanumeric())
        .collect();
    let upper = cleaned.to_snake().to_uppercase();

    let mut joined = String::with_capacity(upper.len());
    for word in upper.split('_').filter(|word| !word.is_empty()) {
        if !joined.is_empty() {
            joined.push('_');
        }
        joined.push_str(word);
    }
    joined
}
/// Renders `tokens` as formatted Rust source.
///
/// Formatting is first attempted with the external `rustfmt`. If that fails
/// (for example because it is not installed), a warning is printed to stderr
/// and the code is formatted with `prettyplease` instead.
///
/// # Panics
/// Panics if `tokens` is not a syntactically valid Rust file. This is only
/// checked in the `prettyplease` fallback. Such a failure means the code
/// generator produced invalid output, so the panic message includes both the
/// parse error and the offending code to make that bug diagnosable.
pub fn pretty_print(tokens: &TokenStream) -> String {
    let code = tokens.to_string();
    match format_with_rustfmt(&code) {
        Ok(formatted) => formatted,
        Err(error) => {
            eprintln!(
                "Warning: rustfmt formatting failed ({error}), falling back to prettyplease",
            );
            let file = syn::parse_file(&code).unwrap_or_else(|parse_error| {
                panic!("generated code is not valid Rust ({parse_error}):\n{code}")
            });
            prettyplease::unparse(&file)
        }
    }
}
/// Formats `code` by piping it through the `rustfmt` executable.
///
/// `rustfmt` is passed `--edition 2021`. When it is run directly rather than
/// through `cargo fmt`, it otherwise parses its input as Rust 2015 and rejects
/// newer syntax such as `async fn`.
///
/// Stdin is written from a helper thread while `wait_with_output` drains the
/// child's stdout and stderr. This prevents a deadlock on full pipe buffers.
/// The child is also always waited on, even when writing its input fails, so
/// it is never left unreaped.
///
/// # Errors
/// Returns an error if any of the following happens:
/// - `rustfmt` cannot be spawned (e.g. it is not installed);
/// - `rustfmt` exits unsuccessfully (the message includes its stderr);
/// - the input could not be written;
/// - the output is not valid UTF-8.
fn format_with_rustfmt(code: &str) -> Result<String, Box<dyn std::error::Error>> {
    use std::io::Write;
    use std::process::{Command, Stdio};

    let mut child = Command::new("rustfmt")
        .args(["--edition", "2021", "--emit", "stdout", "--quiet"])
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()?;

    let mut stdin = match child.stdin.take() {
        Some(stdin) => stdin,
        None => {
            // Not expected with `Stdio::piped()`, but still reap the child.
            let _ = child.kill();
            let _ = child.wait();
            return Err("Failed to open stdin".into());
        }
    };

    // The writer owns its copy of the input. Dropping `stdin` when the closure
    // returns closes the pipe, so rustfmt sees EOF.
    let input = code.to_owned();
    let writer = std::thread::spawn(move || stdin.write_all(input.as_bytes()));

    // Drains stdout and stderr concurrently, then waits for the child to exit.
    let output = child.wait_with_output()?;
    let write_result = writer
        .join()
        .map_err(|_| "rustfmt stdin writer thread panicked")?;

    // A failing exit status is more informative than the resulting broken-pipe
    // write error, so report it first.
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(format!("rustfmt failed: {stderr}").into());
    }
    write_result?;
    Ok(String::from_utf8(output.stdout)?)
}
/// Unit tests for name sanitization and for `create_rust_bindings` on small
/// inline WGSL shaders.
#[cfg(test)]
mod test {
use indoc::indoc;
use self::bevy_util::source_file::SourceFile;
use super::*;
/// `sanitize_and_pascal_case` flattens `::` paths and PascalCases the result.
#[test]
fn test_sanitize_and_pascal_case() {
assert_eq!(sanitize_and_pascal_case("segment"), "Segment");
assert_eq!(sanitize_and_pascal_case("lines::segment"), "LinesSegment");
assert_eq!(
sanitize_and_pascal_case("compute_demo::particle_physics"),
"ComputeDemoParticlePhysics"
);
assert_eq!(
sanitize_and_pascal_case("bevy_pbr::mesh_view_bindings"),
"BevyPbrMeshViewBindings"
);
assert_eq!(sanitize_and_pascal_case("a::b::c::d"), "ABCD");
assert_eq!(sanitize_and_pascal_case("simple_name"), "SimpleName");
}
/// Parses `source` as WGSL and runs `create_rust_bindings` on it as a single
/// entry named `test`.
///
/// Panics if `source` does not parse as WGSL.
fn create_shader_module(
source: &str,
options: WgslBindgenOption,
) -> Result<String, CreateModuleError> {
let naga_module = naga::front::wgsl::parse_str(source).unwrap();
// Empty placeholder file: only needed to fill the borrowed `source_file` field.
let dummy_source = SourceFile::create(SourceFilePath::new(""), None, "".into());
let entry = WgslEntryResult {
mod_name: "test".into(),
naga_module,
source_including_deps: SourceWithFullDependenciesResult {
full_dependencies: Default::default(),
source_file: &dummy_source,
},
};
create_rust_bindings(vec![entry], &options)
}
/// Snapshot of the generated output with `WgslShaderSourceType::EmbedSource`,
/// for a shader with a `var<immediate>` global and an empty fragment entry point.
#[test]
fn create_shader_module_embed_source() {
let source = indoc! {r#"
var<immediate> consts: vec4<f32>;
@fragment
fn fs_main() {}
"#};
let actual = create_shader_module(
source,
WgslBindgenOption {
shader_source_type: [WgslShaderSourceType::EmbedSource].into_iter().collect(),
..Default::default()
},
)
.unwrap();
insta::assert_snapshot!(actual);
}
/// Bind groups 0 and 1 are consecutive from 0, so generation must succeed.
#[test]
fn create_shader_module_consecutive_bind_groups() {
let source = indoc! {r#"
struct A {
f: vec4<f32>
};
@group(0) @binding(0) var<uniform> a: A;
@group(1) @binding(0) var<uniform> b: A;
@vertex
fn vs_main() -> @builtin(position) vec4<f32> {
return vec4<f32>(0.0, 0.0, 0.0, 1.0);
}
@fragment
fn fs_main() {}
"#};
create_shader_module(source, WgslBindgenOption::default()).unwrap();
}
/// Groups 0, 1 and 3 skip group 2, so generation must fail with
/// `NonConsecutiveBindGroups`.
#[test]
fn create_shader_module_non_consecutive_bind_groups() {
let source = indoc! {r#"
@group(0) @binding(0) var<uniform> a: vec4<f32>;
@group(1) @binding(0) var<uniform> b: vec4<f32>;
@group(3) @binding(0) var<uniform> c: vec4<f32>;
@fragment
fn main() {}
"#};
let result = create_shader_module(source, WgslBindgenOption::default());
assert!(matches!(result, Err(CreateModuleError::NonConsecutiveBindGroups)));
}
/// Two variables at `@group(0) @binding(2)`, so generation must fail with
/// `DuplicateBinding { binding: 2 }`.
#[test]
fn create_shader_module_repeated_bindings() {
let source = indoc! {r#"
struct A {
f: vec4<f32>
};
@group(0) @binding(2) var<uniform> a: A;
@group(0) @binding(2) var<uniform> b: A;
@fragment
fn main() {}
"#};
let result = create_shader_module(source, WgslBindgenOption::default());
assert!(matches!(result, Err(CreateModuleError::DuplicateBinding { binding: 2 })));
}
}