extern crate wgpu_types as wgpu;
use bindgroup::{bind_groups_module, get_bind_group_data};
use consts::pipeline_overridable_constants;
use entry::{entry_point_constants, fragment_states, vertex_states, vertex_struct_methods};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{Ident, Index};
use thiserror::Error;
mod bindgroup;
mod consts;
mod entry;
mod structs;
mod wgsl;
#[derive(Debug, PartialEq, Eq, Error)]
pub enum CreateModuleError {
#[error("bind groups are non-consecutive or do not start from 0")]
NonConsecutiveBindGroups,
#[error("duplicate binding found with index `{binding}`")]
DuplicateBinding { binding: u32 },
}
/// Knobs controlling what the generated Rust code derives and which math types it uses.
///
/// The whole value is forwarded to struct generation (`structs::structs`). The default has every
/// `derive_*` flag set to `false` and uses [`MatrixVectorTypes::default()`]. The per-field notes
/// below are inferred from the field names; confirm them against the `structs` module.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct WriteOptions {
    /// Presumably derives `bytemuck` traits on vertex input structs.
    pub derive_bytemuck_vertex: bool,
    /// Presumably derives `bytemuck` traits on structs shared with the host (uniform/storage).
    pub derive_bytemuck_host_shareable: bool,
    /// Presumably derives `encase` traits on structs shared with the host.
    pub derive_encase_host_shareable: bool,
    /// Presumably derives `serde` serialization traits on generated structs.
    pub derive_serde: bool,
    /// Which family of Rust types represents WGSL vectors and matrices.
    pub matrix_vector_types: MatrixVectorTypes,
}
/// Selects which family of Rust types the generated code uses for WGSL vectors and matrices.
///
/// The actual type mapping lives in the `structs` module. The variant notes below are inferred
/// from the variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MatrixVectorTypes {
    /// Plain Rust types (presumably fixed-size arrays). This is the default.
    #[default]
    Rust,
    /// Types from the `glam` crate.
    Glam,
    /// Types from the `nalgebra` crate.
    Nalgebra,
}
/// Generates Rust bindings for `wgsl_source`.
///
/// The emitted `create_shader_module` function loads the shader text with
/// `include_str!(wgsl_include_path)`. The path is pasted verbatim, so it is resolved relative to
/// the file that contains the generated code.
///
/// # Errors
///
/// Returns a [`CreateModuleError`] if the shader's bind groups or bindings are invalid.
///
/// # Panics
///
/// Panics if `wgsl_source` cannot be parsed as WGSL.
pub fn create_shader_module(
    wgsl_source: &str,
    wgsl_include_path: &str,
    options: WriteOptions,
) -> Result<String, CreateModuleError> {
    let include_path = Some(wgsl_include_path);
    create_shader_module_inner(wgsl_source, include_path, options)
}
/// Generates Rust bindings for `wgsl_source`, embedding the shader text directly in the output.
///
/// The emitted `create_shader_module` function contains `wgsl_source` as a string literal instead
/// of referencing a file with `include_str!`.
///
/// # Errors
///
/// Returns a [`CreateModuleError`] if the shader's bind groups or bindings are invalid.
///
/// # Panics
///
/// Panics if `wgsl_source` cannot be parsed as WGSL.
pub fn create_shader_module_embedded(
    wgsl_source: &str,
    options: WriteOptions,
) -> Result<String, CreateModuleError> {
    let no_include_path: Option<&str> = None;
    create_shader_module_inner(wgsl_source, no_include_path, options)
}
fn create_shader_module_inner(
wgsl_source: &str,
wgsl_include_path: Option<&str>,
options: WriteOptions,
) -> Result<String, CreateModuleError> {
let module = naga::front::wgsl::parse_str(wgsl_source).unwrap();
let bind_group_data = get_bind_group_data(&module)?;
let shader_stages = wgsl::shader_stages(&module);
let structs = structs::structs(&module, options);
let consts = consts::consts(&module);
let bind_groups_module = bind_groups_module(&bind_group_data, shader_stages);
let vertex_module = vertex_struct_methods(&module);
let compute_module = compute_module(&module);
let entry_point_constants = entry_point_constants(&module);
let vertex_states = vertex_states(&module);
let fragment_states = fragment_states(&module);
let included_source = wgsl_include_path
.map(|p| quote!(include_str!(#p)))
.unwrap_or_else(|| quote!(#wgsl_source));
let create_shader_module = quote! {
pub fn create_shader_module(device: &wgpu::Device) -> wgpu::ShaderModule {
let source = std::borrow::Cow::Borrowed(#included_source);
device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(source)
})
}
};
let bind_group_layouts: Vec<_> = bind_group_data
.keys()
.map(|group_no| {
let group = indexed_name_to_ident("BindGroup", *group_no);
quote!(bind_groups::#group::get_bind_group_layout(device))
})
.collect();
let create_pipeline_layout = quote! {
pub fn create_pipeline_layout(device: &wgpu::Device) -> wgpu::PipelineLayout {
device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[
#(&#bind_group_layouts),*
],
push_constant_ranges: &[],
})
}
};
let override_constants = pipeline_overridable_constants(&module);
let output = quote! {
#(#structs)*
#(#consts)*
#override_constants
#bind_groups_module
#vertex_module
#compute_module
#entry_point_constants
#vertex_states
#fragment_states
#create_shader_module
#create_pipeline_layout
};
Ok(pretty_print(&output))
}
/// Formats a token stream as a Rust source file using `prettyplease`.
///
/// # Panics
///
/// Panics if `tokens` do not parse as a Rust file. Callers in this crate pass either generated
/// code or tokens written with `quote!` in tests, so a failure here indicates a code-generation
/// bug rather than bad user input.
fn pretty_print(tokens: &TokenStream) -> String {
    let file = syn::parse_file(&tokens.to_string())
        .expect("generated tokens should always parse as a valid Rust file");
    prettyplease::unparse(&file)
}
/// Builds an identifier by appending `index` to `name`, e.g. `("BindGroup", 0)` → `BindGroup0`.
fn indexed_name_to_ident(name: &str, index: u32) -> Ident {
    let joined = format!("{}{}", name, index);
    Ident::new(&joined, Span::call_site())
}
/// Generates a `pub mod compute` with a workgroup-size constant and a pipeline constructor for
/// every compute entry point in `module`.
///
/// Returns an empty token stream when the module has no compute entry points, so no empty
/// `compute` module is emitted.
fn compute_module(module: &naga::Module) -> TokenStream {
    let pipelines: Vec<TokenStream> = module
        .entry_points
        .iter()
        .filter(|entry| entry.stage == naga::ShaderStage::Compute)
        .map(|entry| {
            let size_const = workgroup_size(entry);
            let pipeline_fn = create_compute_pipeline(entry);
            quote! {
                #size_const
                #pipeline_fn
            }
        })
        .collect();

    if pipelines.is_empty() {
        return quote!();
    }

    quote! {
        pub mod compute {
            #(#pipelines)*
        }
    }
}
/// Emits `pub fn create_<entry>_pipeline(device)`, which builds a compute pipeline for entry
/// point `e`.
///
/// The generated function lives inside `pub mod compute`. It reaches the file-level
/// `create_shader_module` and `create_pipeline_layout` through `super::`, calling both on every
/// invocation.
fn create_compute_pipeline(e: &naga::EntryPoint) -> TokenStream {
    let entry_name = e.name.as_str();
    let fn_ident = Ident::new(&format!("create_{entry_name}_pipeline"), Span::call_site());
    let label_text = format!("Compute Pipeline {entry_name}");

    quote! {
        pub fn #fn_ident(device: &wgpu::Device) -> wgpu::ComputePipeline {
            let module = super::create_shader_module(device);
            let layout = super::create_pipeline_layout(device);
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some(#label_text),
                layout: Some(&layout),
                module: &module,
                entry_point: #entry_name,
                compilation_options: Default::default(),
            })
        }
    }
}
fn workgroup_size(e: &naga::EntryPoint) -> TokenStream {
let name = Ident::new(
&format!("{}_WORKGROUP_SIZE", e.name.to_uppercase()),
Span::call_site(),
);
let [x, y, z] = e.workgroup_size.map(|s| Index::from(s as usize));
quote!(pub const #name: [u32; 3] = [#x, #y, #z];)
}
/// Test helper: asserts that two token streams are equal after both are pretty-printed.
///
/// Comparing the formatted text ignores incidental token spacing and gives a readable
/// `pretty_assertions` diff on failure.
#[cfg(test)]
#[macro_export]
macro_rules! assert_tokens_eq {
    ($expected:expr, $actual:expr) => {
        pretty_assertions::assert_eq!(
            crate::pretty_print(&$expected),
            crate::pretty_print(&$actual)
        )
    };
}
#[cfg(test)]
mod test {
    use super::*;
    use indoc::indoc;

    #[test]
    fn create_shader_module_include_source() {
        let source = indoc! {r#"
            @fragment
            fn fs_main() {}
        "#};

        let actual = create_shader_module(source, "shader.wgsl", WriteOptions::default()).unwrap();

        pretty_assertions::assert_eq!(
            indoc! {r#"
                pub const ENTRY_FS_MAIN: &str = "fs_main";
                #[derive(Debug)]
                pub struct FragmentEntry<const N: usize> {
                    pub entry_point: &'static str,
                    pub targets: [Option<wgpu::ColorTargetState>; N],
                    pub constants: std::collections::HashMap<String, f64>,
                }
                pub fn fragment_state<'a, const N: usize>(
                    module: &'a wgpu::ShaderModule,
                    entry: &'a FragmentEntry<N>,
                ) -> wgpu::FragmentState<'a> {
                    wgpu::FragmentState {
                        module,
                        entry_point: entry.entry_point,
                        targets: &entry.targets,
                        compilation_options: wgpu::PipelineCompilationOptions {
                            constants: &entry.constants,
                            ..Default::default()
                        },
                    }
                }
                pub fn fs_main_entry(targets: [Option<wgpu::ColorTargetState>; 0]) -> FragmentEntry<0> {
                    FragmentEntry {
                        entry_point: ENTRY_FS_MAIN,
                        targets,
                        constants: Default::default(),
                    }
                }
                pub fn create_shader_module(device: &wgpu::Device) -> wgpu::ShaderModule {
                    let source = std::borrow::Cow::Borrowed(include_str!("shader.wgsl"));
                    device
                        .create_shader_module(wgpu::ShaderModuleDescriptor {
                            label: None,
                            source: wgpu::ShaderSource::Wgsl(source),
                        })
                }
                pub fn create_pipeline_layout(device: &wgpu::Device) -> wgpu::PipelineLayout {
                    device
                        .create_pipeline_layout(
                            &wgpu::PipelineLayoutDescriptor {
                                label: None,
                                bind_group_layouts: &[],
                                push_constant_ranges: &[],
                            },
                        )
                }
            "#},
            actual
        );
    }

    #[test]
    fn create_shader_module_embed_source() {
        let source = indoc! {r#"
            @fragment
            fn fs_main() {}
        "#};

        let actual = create_shader_module_embedded(source, WriteOptions::default()).unwrap();

        pretty_assertions::assert_eq!(
            indoc! {r#"
                pub const ENTRY_FS_MAIN: &str = "fs_main";
                #[derive(Debug)]
                pub struct FragmentEntry<const N: usize> {
                    pub entry_point: &'static str,
                    pub targets: [Option<wgpu::ColorTargetState>; N],
                    pub constants: std::collections::HashMap<String, f64>,
                }
                pub fn fragment_state<'a, const N: usize>(
                    module: &'a wgpu::ShaderModule,
                    entry: &'a FragmentEntry<N>,
                ) -> wgpu::FragmentState<'a> {
                    wgpu::FragmentState {
                        module,
                        entry_point: entry.entry_point,
                        targets: &entry.targets,
                        compilation_options: wgpu::PipelineCompilationOptions {
                            constants: &entry.constants,
                            ..Default::default()
                        },
                    }
                }
                pub fn fs_main_entry(targets: [Option<wgpu::ColorTargetState>; 0]) -> FragmentEntry<0> {
                    FragmentEntry {
                        entry_point: ENTRY_FS_MAIN,
                        targets,
                        constants: Default::default(),
                    }
                }
                pub fn create_shader_module(device: &wgpu::Device) -> wgpu::ShaderModule {
                    let source = std::borrow::Cow::Borrowed("@fragment\nfn fs_main() {}\n");
                    device
                        .create_shader_module(wgpu::ShaderModuleDescriptor {
                            label: None,
                            source: wgpu::ShaderSource::Wgsl(source),
                        })
                }
                pub fn create_pipeline_layout(device: &wgpu::Device) -> wgpu::PipelineLayout {
                    device
                        .create_pipeline_layout(
                            &wgpu::PipelineLayoutDescriptor {
                                label: None,
                                bind_group_layouts: &[],
                                push_constant_ranges: &[],
                            },
                        )
                }
            "#},
            actual
        );
    }

    #[test]
    fn create_shader_module_consecutive_bind_groups() {
        let source = indoc! {r#"
            struct A {
                f: vec4<f32>
            };
            @group(0) @binding(0) var<uniform> a: A;
            @group(1) @binding(0) var<uniform> b: A;
            @vertex
            fn vs_main() {}
            @fragment
            fn fs_main() {}
        "#};

        create_shader_module(source, "shader.wgsl", WriteOptions::default()).unwrap();
    }

    #[test]
    fn create_shader_module_non_consecutive_bind_groups() {
        let source = indoc! {r#"
            @group(0) @binding(0) var<uniform> a: vec4<f32>;
            @group(1) @binding(0) var<uniform> b: vec4<f32>;
            @group(3) @binding(0) var<uniform> c: vec4<f32>;
            @fragment
            fn main() {}
        "#};

        let result = create_shader_module(source, "shader.wgsl", WriteOptions::default());
        assert!(matches!(
            result,
            Err(CreateModuleError::NonConsecutiveBindGroups)
        ));
    }

    #[test]
    fn create_shader_module_repeated_bindings() {
        let source = indoc! {r#"
            struct A {
                f: vec4<f32>
            };
            @group(0) @binding(2) var<uniform> a: A;
            @group(0) @binding(2) var<uniform> b: A;
            @fragment
            fn main() {}
        "#};

        let result = create_shader_module(source, "shader.wgsl", WriteOptions::default());
        assert!(matches!(
            result,
            Err(CreateModuleError::DuplicateBinding { binding: 2 })
        ));
    }

    #[test]
    fn write_vertex_module_empty() {
        let source = indoc! {r#"
            @vertex
            fn main() {}
        "#};

        let module = naga::front::wgsl::parse_str(source).unwrap();
        let actual = vertex_struct_methods(&module);

        assert_tokens_eq!(quote!(), actual);
    }

    #[test]
    fn write_vertex_module_single_input_float32() {
        let source = indoc! {r#"
            struct VertexInput0 {
                @location(0) a: f32,
                @location(1) b: vec2<f32>,
                @location(2) c: vec3<f32>,
                @location(3) d: vec4<f32>,
            };
            @vertex
            fn main(in0: VertexInput0) {}
        "#};

        let module = naga::front::wgsl::parse_str(source).unwrap();
        let actual = vertex_struct_methods(&module);

        assert_tokens_eq!(
            quote! {
                impl VertexInput0 {
                    pub const VERTEX_ATTRIBUTES: [wgpu::VertexAttribute; 4] = [
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Float32,
                            offset: std::mem::offset_of!(VertexInput0, a) as u64,
                            shader_location: 0,
                        },
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Float32x2,
                            offset: std::mem::offset_of!(VertexInput0, b) as u64,
                            shader_location: 1,
                        },
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Float32x3,
                            offset: std::mem::offset_of!(VertexInput0, c) as u64,
                            shader_location: 2,
                        },
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Float32x4,
                            offset: std::mem::offset_of!(VertexInput0, d) as u64,
                            shader_location: 3,
                        },
                    ];
                    pub const fn vertex_buffer_layout(
                        step_mode: wgpu::VertexStepMode,
                    ) -> wgpu::VertexBufferLayout<'static> {
                        wgpu::VertexBufferLayout {
                            array_stride: std::mem::size_of::<VertexInput0>() as u64,
                            step_mode,
                            attributes: &VertexInput0::VERTEX_ATTRIBUTES,
                        }
                    }
                }
            },
            actual
        );
    }

    #[test]
    fn write_vertex_module_single_input_float64() {
        let source = indoc! {r#"
            struct VertexInput0 {
                @location(0) a: f64,
                @location(1) b: vec2<f64>,
                @location(2) c: vec3<f64>,
                @location(3) d: vec4<f64>,
            };
            @vertex
            fn main(in0: VertexInput0) {}
        "#};

        let module = naga::front::wgsl::parse_str(source).unwrap();
        let actual = vertex_struct_methods(&module);

        assert_tokens_eq!(
            quote! {
                impl VertexInput0 {
                    pub const VERTEX_ATTRIBUTES: [wgpu::VertexAttribute; 4] = [
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Float64,
                            offset: std::mem::offset_of!(VertexInput0, a) as u64,
                            shader_location: 0,
                        },
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Float64x2,
                            offset: std::mem::offset_of!(VertexInput0, b) as u64,
                            shader_location: 1,
                        },
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Float64x3,
                            offset: std::mem::offset_of!(VertexInput0, c) as u64,
                            shader_location: 2,
                        },
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Float64x4,
                            offset: std::mem::offset_of!(VertexInput0, d) as u64,
                            shader_location: 3,
                        },
                    ];
                    pub const fn vertex_buffer_layout(
                        step_mode: wgpu::VertexStepMode,
                    ) -> wgpu::VertexBufferLayout<'static> {
                        wgpu::VertexBufferLayout {
                            array_stride: std::mem::size_of::<VertexInput0>() as u64,
                            step_mode,
                            attributes: &VertexInput0::VERTEX_ATTRIBUTES,
                        }
                    }
                }
            },
            actual
        );
    }

    #[test]
    fn write_vertex_module_single_input_sint32() {
        // Each attribute needs its own member name: duplicate struct member names are invalid
        // WGSL, and the generated Rust struct (and its `offset_of!` calls) could not compile.
        let source = indoc! {r#"
            struct VertexInput0 {
                @location(0) a: i32,
                @location(1) b: vec2<i32>,
                @location(2) c: vec3<i32>,
                @location(3) d: vec4<i32>,
            };
            @vertex
            fn main(in0: VertexInput0) {}
        "#};

        let module = naga::front::wgsl::parse_str(source).unwrap();
        let actual = vertex_struct_methods(&module);

        assert_tokens_eq!(
            quote! {
                impl VertexInput0 {
                    pub const VERTEX_ATTRIBUTES: [wgpu::VertexAttribute; 4] = [
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Sint32,
                            offset: std::mem::offset_of!(VertexInput0, a) as u64,
                            shader_location: 0,
                        },
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Sint32x2,
                            offset: std::mem::offset_of!(VertexInput0, b) as u64,
                            shader_location: 1,
                        },
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Sint32x3,
                            offset: std::mem::offset_of!(VertexInput0, c) as u64,
                            shader_location: 2,
                        },
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Sint32x4,
                            offset: std::mem::offset_of!(VertexInput0, d) as u64,
                            shader_location: 3,
                        },
                    ];
                    pub const fn vertex_buffer_layout(
                        step_mode: wgpu::VertexStepMode,
                    ) -> wgpu::VertexBufferLayout<'static> {
                        wgpu::VertexBufferLayout {
                            array_stride: std::mem::size_of::<VertexInput0>() as u64,
                            step_mode,
                            attributes: &VertexInput0::VERTEX_ATTRIBUTES,
                        }
                    }
                }
            },
            actual
        );
    }

    #[test]
    fn write_vertex_module_single_input_uint32() {
        let source = indoc! {r#"
            struct VertexInput0 {
                @location(0) a: u32,
                @location(1) b: vec2<u32>,
                @location(2) c: vec3<u32>,
                @location(3) d: vec4<u32>,
            };
            @vertex
            fn main(in0: VertexInput0) {}
        "#};

        let module = naga::front::wgsl::parse_str(source).unwrap();
        let actual = vertex_struct_methods(&module);

        assert_tokens_eq!(
            quote! {
                impl VertexInput0 {
                    pub const VERTEX_ATTRIBUTES: [wgpu::VertexAttribute; 4] = [
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Uint32,
                            offset: std::mem::offset_of!(VertexInput0, a) as u64,
                            shader_location: 0,
                        },
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Uint32x2,
                            offset: std::mem::offset_of!(VertexInput0, b) as u64,
                            shader_location: 1,
                        },
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Uint32x3,
                            offset: std::mem::offset_of!(VertexInput0, c) as u64,
                            shader_location: 2,
                        },
                        wgpu::VertexAttribute {
                            format: wgpu::VertexFormat::Uint32x4,
                            offset: std::mem::offset_of!(VertexInput0, d) as u64,
                            shader_location: 3,
                        },
                    ];
                    pub const fn vertex_buffer_layout(
                        step_mode: wgpu::VertexStepMode,
                    ) -> wgpu::VertexBufferLayout<'static> {
                        wgpu::VertexBufferLayout {
                            array_stride: std::mem::size_of::<VertexInput0>() as u64,
                            step_mode,
                            attributes: &VertexInput0::VERTEX_ATTRIBUTES,
                        }
                    }
                }
            },
            actual
        );
    }

    #[test]
    fn write_compute_module_empty() {
        let source = indoc! {r#"
            @vertex
            fn main() {}
        "#};

        let module = naga::front::wgsl::parse_str(source).unwrap();
        let actual = compute_module(&module);

        assert_tokens_eq!(quote!(), actual);
    }

    #[test]
    fn write_compute_module_multiple_entries() {
        let source = indoc! {r#"
            @compute
            @workgroup_size(1,2,3)
            fn main1() {}
            @compute
            @workgroup_size(256)
            fn main2() {}
            "#
        };

        let module = naga::front::wgsl::parse_str(source).unwrap();
        let actual = compute_module(&module);

        assert_tokens_eq!(
            quote! {
                pub mod compute {
                    pub const MAIN1_WORKGROUP_SIZE: [u32; 3] = [1, 2, 3];
                    pub fn create_main1_pipeline(device: &wgpu::Device) -> wgpu::ComputePipeline {
                        let module = super::create_shader_module(device);
                        let layout = super::create_pipeline_layout(device);
                        device
                            .create_compute_pipeline(
                                &wgpu::ComputePipelineDescriptor {
                                    label: Some("Compute Pipeline main1"),
                                    layout: Some(&layout),
                                    module: &module,
                                    entry_point: "main1",
                                    compilation_options: Default::default(),
                                },
                            )
                    }
                    pub const MAIN2_WORKGROUP_SIZE: [u32; 3] = [256, 1, 1];
                    pub fn create_main2_pipeline(device: &wgpu::Device) -> wgpu::ComputePipeline {
                        let module = super::create_shader_module(device);
                        let layout = super::create_pipeline_layout(device);
                        device
                            .create_compute_pipeline(
                                &wgpu::ComputePipelineDescriptor {
                                    label: Some("Compute Pipeline main2"),
                                    layout: Some(&layout),
                                    module: &module,
                                    entry_point: "main2",
                                    compilation_options: Default::default(),
                                },
                            )
                    }
                }
            },
            actual
        );
    }
}