use crate::error::{GpuError, Result};
/// Result of [`compile_wgsl`]: the generated SPIR-V together with the naga
/// module it was generated from.
#[derive(Debug)]
pub struct CompiledShader {
    /// SPIR-V binary as a sequence of 32-bit words.
    pub spirv: Vec<u32>,
    /// The naga IR the SPIR-V was generated from (parsed and validated by
    /// `compile_wgsl`), kept so callers can run reflection queries on it.
    pub module: naga::Module,
    /// Name of the entry point that was requested at compile time.
    // NOTE(review): the leading underscore presumably silences an unused-field
    // lint; it is part of the public interface, so it is left as-is.
    pub _entry_point: String,
}
/// Compiles WGSL source to SPIR-V, parsing and validating it with naga first.
///
/// Steps, in order:
/// 1. Parse `source` with naga's WGSL front end.
/// 2. Validate the module with all validation flags and all capabilities enabled.
/// 3. Check that the module declares an entry point named `entry_point`.
/// 4. Emit SPIR-V (language version 1.3) for the module.
///
/// The parsed [`naga::Module`] is returned alongside the SPIR-V so callers can
/// reflect on it.
///
/// # Errors
///
/// - [`GpuError::ShaderCompilation`] if parsing, validation or SPIR-V
///   generation fails. Parse errors are rendered with their source location;
///   validation and back-end errors include every error in their cause chain.
/// - [`GpuError::MissingEntryPoint`] if the module validates but has no entry
///   point named `entry_point`.
pub fn compile_wgsl(source: &str, entry_point: &str) -> Result<CompiledShader> {
    // `emit_to_string` renders the error together with the offending source
    // location, which the plain `Display` output omits.
    let module = naga::front::wgsl::parse_str(source)
        .map_err(|e| GpuError::ShaderCompilation(e.emit_to_string(source)))?;

    // NOTE(review): `Capabilities::all()` accepts every capability naga knows
    // about; the target device may support fewer — confirm this is intended.
    let info = naga::valid::Validator::new(
        naga::valid::ValidationFlags::all(),
        naga::valid::Capabilities::all(),
    )
    .validate(&module)
    .map_err(|e| GpuError::ShaderCompilation(error_chain(&e)))?;

    if !module.entry_points.iter().any(|ep| ep.name == entry_point) {
        return Err(GpuError::MissingEntryPoint {
            name: entry_point.to_string(),
        });
    }

    // `None` pipeline options: presumably the whole module (every entry point)
    // is written, not just `entry_point` — confirm against the naga version.
    let spirv = naga::back::spv::write_vec(
        &module,
        &info,
        &naga::back::spv::Options {
            lang_version: (1, 3),
            ..Default::default()
        },
        None,
    )
    .map_err(|e| GpuError::ShaderCompilation(error_chain(&e)))?;

    Ok(CompiledShader {
        spirv,
        module,
        _entry_point: entry_point.to_string(),
    })
}

/// Renders `err` followed by each error in its `source()` chain, joined by `": "`.
///
/// Errors may keep their underlying cause only in `source()`. Formatting just the
/// outermost error with `Display` would drop that cause from the message.
fn error_chain(err: &dyn std::error::Error) -> String {
    let mut message = err.to_string();
    let mut cause = err.source();
    while let Some(inner) = cause {
        message.push_str(": ");
        message.push_str(&inner.to_string());
        cause = inner.source();
    }
    message
}
/// Returns the size of the push-constant block declared in `module`.
///
/// Looks through the module's global variables for the first one in the
/// `PushConstant` address space whose type is a struct. Returns that struct's
/// `span` (naga's size for the struct).
///
/// Push-constant globals whose type is not a struct are skipped. Returns `0`
/// when no qualifying global exists.
pub fn push_constant_size(module: &naga::Module) -> u32 {
    module
        .global_variables
        .iter()
        .map(|(_handle, global)| global)
        .filter(|global| global.space == naga::AddressSpace::PushConstant)
        .find_map(|global| match module.types[global.ty].inner {
            naga::TypeInner::Struct { span, .. } => Some(span),
            _ => None,
        })
        .unwrap_or(0)
}
/// Counts the bound buffer globals in `module`.
///
/// A global variable is counted when both conditions hold:
/// - it has a resource binding (its `binding` is `Some`);
/// - it lives in the `Uniform` address space or any `Storage` address space.
///
/// Globals in other address spaces are not counted. The count covers every
/// global in the module, not just those used by one particular entry point.
pub fn binding_count(module: &naga::Module) -> usize {
    let mut count = 0;
    for (_handle, global) in module.global_variables.iter() {
        let is_buffer = matches!(
            global.space,
            naga::AddressSpace::Uniform | naga::AddressSpace::Storage { .. }
        );
        if is_buffer && global.binding.is_some() {
            count += 1;
        }
    }
    count
}