use naga::StructMember;
use proc_macro2::TokenStream;
use quote::quote;
use crate::quote_gen::RustSourceItemPath;
/// Returns the union of the pipeline stages used by the entry points of `module`.
///
/// Each entry point contributes the `wgpu::ShaderStages` flag matching its
/// `naga::ShaderStage`. A module without entry points yields the empty flag set.
pub fn shader_stages(module: &naga::Module) -> wgpu::ShaderStages {
    let mut stages = wgpu::ShaderStages::empty();
    for entry_point in &module.entry_points {
        // The match has no wildcard arm on purpose: a new naga stage must be
        // mapped explicitly before this compiles again.
        stages |= match entry_point.stage {
            naga::ShaderStage::Vertex => wgpu::ShaderStages::VERTEX,
            naga::ShaderStage::Fragment => wgpu::ShaderStages::FRAGMENT,
            naga::ShaderStage::Compute => wgpu::ShaderStages::COMPUTE,
            naga::ShaderStage::Task => wgpu::ShaderStages::TASK,
            naga::ShaderStage::Mesh => wgpu::ShaderStages::MESH,
            naga::ShaderStage::RayGeneration => wgpu::ShaderStages::RAY_GENERATION,
            naga::ShaderStage::Miss => wgpu::ShaderStages::MISS,
            naga::ShaderStage::AnyHit => wgpu::ShaderStages::ANY_HIT,
            naga::ShaderStage::ClosestHit => wgpu::ShaderStages::CLOSEST_HIT,
        };
    }
    stages
}
/// Produces the tokens of the `wgpu::BufferBindingType` expression for a
/// buffer declared in the given naga address space.
///
/// * `Uniform` maps to `wgpu::BufferBindingType::Uniform`.
/// * `Storage` maps to `wgpu::BufferBindingType::Storage`, with
///   `read_only: false` when the access flags contain `STORE` and
///   `read_only: true` otherwise.
/// * Every other address space falls back to `Uniform`.
pub fn buffer_binding_type(storage: naga::AddressSpace) -> TokenStream {
    match storage {
        naga::AddressSpace::Uniform => quote!(wgpu::BufferBindingType::Uniform),
        naga::AddressSpace::Storage { access } => {
            // Only the STORE flag decides writability; the LOAD flag is not consulted.
            if access.contains(naga::StorageAccess::STORE) {
                quote!(wgpu::BufferBindingType::Storage { read_only: false })
            } else {
                quote!(wgpu::BufferBindingType::Storage { read_only: true })
            }
        }
        // NOTE(review): non-buffer address spaces are silently reported as
        // uniform buffers here; confirm callers never pass them.
        _ => quote!(wgpu::BufferBindingType::Uniform),
    }
}
pub fn vertex_format(ty: &naga::Type) -> wgpu::VertexFormat {
match &ty.inner {
naga::TypeInner::Scalar(scalar) => match (scalar.kind, scalar.width) {
(naga::ScalarKind::Sint, 4) => wgpu::VertexFormat::Sint32,
(naga::ScalarKind::Uint, 4) => wgpu::VertexFormat::Uint32,
(naga::ScalarKind::Float, 4) => wgpu::VertexFormat::Float32,
(naga::ScalarKind::Float, 8) => wgpu::VertexFormat::Float64,
_ => todo!(),
},
naga::TypeInner::Vector { size, scalar } => match size {
naga::VectorSize::Bi => match (scalar.kind, scalar.width) {
(naga::ScalarKind::Sint, 1) => wgpu::VertexFormat::Sint8x2,
(naga::ScalarKind::Uint, 1) => wgpu::VertexFormat::Uint8x2,
(naga::ScalarKind::Sint, 2) => wgpu::VertexFormat::Sint16x2,
(naga::ScalarKind::Uint, 2) => wgpu::VertexFormat::Uint16x2,
(naga::ScalarKind::Uint, 4) => wgpu::VertexFormat::Uint32x2,
(naga::ScalarKind::Sint, 4) => wgpu::VertexFormat::Sint32x2,
(naga::ScalarKind::Float, 4) => wgpu::VertexFormat::Float32x2,
(naga::ScalarKind::Float, 8) => wgpu::VertexFormat::Float64x2,
_ => todo!(),
},
naga::VectorSize::Tri => match (scalar.kind, scalar.width) {
(naga::ScalarKind::Uint, 4) => wgpu::VertexFormat::Uint32x3,
(naga::ScalarKind::Sint, 4) => wgpu::VertexFormat::Sint32x3,
(naga::ScalarKind::Float, 4) => wgpu::VertexFormat::Float32x3,
(naga::ScalarKind::Float, 8) => wgpu::VertexFormat::Float64x3,
_ => todo!(),
},
naga::VectorSize::Quad => match (scalar.kind, scalar.width) {
(naga::ScalarKind::Sint, 1) => wgpu::VertexFormat::Sint8x4,
(naga::ScalarKind::Uint, 1) => wgpu::VertexFormat::Uint8x4,
(naga::ScalarKind::Sint, 2) => wgpu::VertexFormat::Sint16x4,
(naga::ScalarKind::Uint, 2) => wgpu::VertexFormat::Uint16x4,
(naga::ScalarKind::Uint, 4) => wgpu::VertexFormat::Uint32x4,
(naga::ScalarKind::Sint, 4) => wgpu::VertexFormat::Sint32x4,
(naga::ScalarKind::Float, 4) => wgpu::VertexFormat::Float32x4,
(naga::ScalarKind::Float, 8) => wgpu::VertexFormat::Float64x4,
_ => todo!(),
},
},
_ => todo!(), }
}
/// A struct-typed argument of a vertex entry point, as collected by the
/// `get_vertex_input_structs*` functions in this file.
pub struct VertexInput<'a> {
/// Item path built by `RustSourceItemPath::from_mangled` from the struct's
/// naga type name and the invoking entry module.
pub item_path: RustSourceItemPath,
/// `(location, member)` pairs for the members bound to a location;
/// members bound to a built-in are left out.
pub fields: Vec<(u32, StructMember)>,
/// The naga module the struct was read from.
pub naga_module: &'a naga::Module,
}
pub fn get_vertex_input_structs<'a>(
invoking_entry_module: &str,
module: &'a naga::Module,
) -> Vec<VertexInput<'a>> {
module
.entry_points
.iter()
.find(|e| e.stage == naga::ShaderStage::Vertex)
.map(|vertex_entry| {
vertex_entry
.function
.arguments
.iter()
.filter(|a| a.binding.is_none())
.filter_map(|argument| {
let arg_type = &module.types[argument.ty];
match &arg_type.inner {
naga::TypeInner::Struct { members, span: _ } => {
let item_path = RustSourceItemPath::from_mangled(
arg_type.name.as_ref().unwrap(),
invoking_entry_module,
);
let input = VertexInput {
item_path,
fields: members
.iter()
.filter_map(|member| {
let location = match member.binding.as_ref().unwrap() {
naga::Binding::BuiltIn(_) => None,
naga::Binding::Location { location, .. } => Some(*location),
}?;
Some((location, member.clone()))
})
.collect(),
naga_module: module,
};
Some(input)
}
_ => None,
}
})
.collect()
})
.unwrap_or_default()
}
/// An entry-point argument that carries its own location binding, as
/// collected by `get_vertex_input_primitives_for_entry_point`.
pub struct VertexPrimitiveInput<'a> {
/// The argument's name, or `loc` followed by the location number when the
/// argument has no name.
pub name: String,
/// The naga type of the argument, borrowed from the module.
pub ty: &'a naga::Type,
/// The location number taken from the argument's binding.
pub location: u32,
}
/// Lists the arguments of `entry_point` that are bound directly to a location.
///
/// Arguments without a binding, or bound to a built-in, are skipped. An
/// unnamed argument is given the name `loc` followed by its location number.
pub fn get_vertex_input_primitives_for_entry_point<'a>(
    module: &'a naga::Module,
    entry_point: &naga::EntryPoint,
) -> Vec<VertexPrimitiveInput<'a>> {
    entry_point
        .function
        .arguments
        .iter()
        .filter_map(|argument| match argument.binding {
            Some(naga::Binding::Location { location, .. }) => {
                let name = match &argument.name {
                    Some(given) => given.clone(),
                    None => format!("loc{}", location),
                };
                Some(VertexPrimitiveInput {
                    name,
                    ty: &module.types[argument.ty],
                    location,
                })
            }
            _ => None,
        })
        .collect()
}
/// Collects the struct-typed inputs of a single entry point.
///
/// Only arguments that have no binding of their own and whose type is a
/// struct are considered. For each such struct, the members bound to a
/// location are recorded together with that location; members bound to a
/// built-in are skipped.
///
/// # Panics
///
/// Panics if a considered struct type has no name, or if one of its members
/// has no binding.
pub fn get_vertex_input_structs_for_entry_point<'a>(
    invoking_entry_module: &str,
    module: &'a naga::Module,
    entry_point: &naga::EntryPoint,
) -> Vec<VertexInput<'a>> {
    let mut inputs = Vec::new();

    for argument in &entry_point.function.arguments {
        // Arguments with their own binding are not struct inputs.
        if argument.binding.is_some() {
            continue;
        }

        let arg_type = &module.types[argument.ty];
        if let naga::TypeInner::Struct { members, .. } = &arg_type.inner {
            let item_path = RustSourceItemPath::from_mangled(
                arg_type.name.as_ref().unwrap(),
                invoking_entry_module,
            );

            let mut fields = Vec::new();
            for member in members {
                match member.binding.as_ref().unwrap() {
                    naga::Binding::Location { location, .. } => {
                        fields.push((*location, member.clone()));
                    }
                    naga::Binding::BuiltIn(_) => {}
                }
            }

            inputs.push(VertexInput {
                item_path,
                fields,
                naga_module: module,
            });
        }
    }

    inputs
}
/// Collects the struct-typed inputs of every vertex entry point in `module`.
///
/// Entry points are visited in module order. A struct is reported once: when
/// several inputs share the same `item_path.name`, only the first one
/// encountered is kept.
pub fn get_all_vertex_input_structs<'a>(
    invoking_entry_module: &str,
    module: &'a naga::Module,
) -> Vec<VertexInput<'a>> {
    // Names already emitted; `insert` returns false for a repeat.
    let mut emitted_names = std::collections::HashSet::new();

    module
        .entry_points
        .iter()
        .filter(|entry_point| entry_point.stage == naga::ShaderStage::Vertex)
        .flat_map(|entry_point| {
            get_vertex_input_structs_for_entry_point(invoking_entry_module, module, entry_point)
        })
        .filter(|input| emitted_names.insert(input.item_path.name.clone()))
        .collect()
}
// Unit tests: each one parses a small WGSL snippet with naga's WGSL front end
// and checks the result of a helper from the parent module.
#[cfg(test)]
mod tests {
use indoc::indoc;
use pretty_assertions::assert_eq;
use super::*;
// A module with no entry points reports no stages.
#[test]
fn shader_stages_none() {
let source = indoc! {r#"
"#};
let module = naga::front::wgsl::parse_str(source).unwrap();
assert_eq!(wgpu::ShaderStages::NONE, shader_stages(&module));
}
// A single vertex entry point reports only the vertex stage.
#[test]
fn shader_stages_vertex() {
let source = indoc! {r#"
@vertex
fn main() {}
"#};
let module = naga::front::wgsl::parse_str(source).unwrap();
assert_eq!(wgpu::ShaderStages::VERTEX, shader_stages(&module));
}
// A single fragment entry point reports only the fragment stage.
#[test]
fn shader_stages_fragment() {
let source = indoc! {r#"
@fragment
fn main() {}
"#};
let module = naga::front::wgsl::parse_str(source).unwrap();
assert_eq!(wgpu::ShaderStages::FRAGMENT, shader_stages(&module));
}
// Vertex and fragment entry points are combined into one flag set.
#[test]
fn shader_stages_vertex_fragment() {
let source = indoc! {r#"
@vertex
fn vs_main() {}
@fragment
fn fs_main() {}
"#};
let module = naga::front::wgsl::parse_str(source).unwrap();
assert_eq!(wgpu::ShaderStages::VERTEX_FRAGMENT, shader_stages(&module));
}
// A single compute entry point reports only the compute stage.
#[test]
fn shader_stages_compute() {
let source = indoc! {r#"
@compute
@workgroup_size(64)
fn main() {}
"#};
let module = naga::front::wgsl::parse_str(source).unwrap();
assert_eq!(wgpu::ShaderStages::COMPUTE, shader_stages(&module));
}
// Vertex, fragment and compute entry points together report all three stages.
#[test]
fn shader_stages_all() {
let source = indoc! {r#"
@vertex
fn vs_main() {}
@fragment
fn fs_main() {}
@compute
@workgroup_size(64)
fn cs_main() {}
"#};
let module = naga::front::wgsl::parse_str(source).unwrap();
assert_eq!(
wgpu::ShaderStages::COMPUTE | wgpu::ShaderStages::VERTEX_FRAGMENT,
shader_stages(&module)
);
}
// Two struct arguments are returned in argument order; the built-in member
// and the directly bound arguments (`in2`, `in3`) are not part of the result.
#[test]
fn vertex_input_structs_two_structs() {
let source = indoc! {r#"
struct VertexInput0 {
@location(0) in0: vec4<f32>,
@location(1) in1: vec4<f32>,
@location(2) in2: vec4<f32>,
};
struct VertexInput1 {
@location(3) in3: vec4<f32>,
@location(4) in4: vec4<f32>,
@builtin(vertex_index) index: u32,
@location(5) in5: vec4<f32>,
@location(6) in6: vec4<u32>,
};
@vertex
fn main(
in0: VertexInput0,
in1: VertexInput1,
@builtin(front_facing) in2: bool,
@location(7) in3: vec4<f32>,
) -> @builtin(position) vec4<f32> {
return vec4<f32>(0.0);
}
"#};
let module = naga::front::wgsl::parse_str(source).unwrap();
let vertex_inputs = get_vertex_input_structs("", &module);
assert_eq!(2, vertex_inputs.len());
assert_eq!("VertexInput0", vertex_inputs[0].item_path.name);
assert_eq!(3, vertex_inputs[0].fields.len());
assert_eq!("in1", vertex_inputs[0].fields[1].1.name.as_ref().unwrap());
assert_eq!(1, vertex_inputs[0].fields[1].0);
assert_eq!("VertexInput1", vertex_inputs[1].item_path.name);
// Four fields, not five: the `vertex_index` built-in is skipped.
assert_eq!(4, vertex_inputs[1].fields.len());
assert_eq!("in5", vertex_inputs[1].fields[2].1.name.as_ref().unwrap());
assert_eq!(5, vertex_inputs[1].fields[2].0);
}
}