use crate::core::Device;
use ash::vk;
use thiserror::Error;
/// Errors produced while compiling, creating, or reflecting shaders.
#[derive(Debug, Error)]
pub enum ShaderError {
    /// Front-end compilation (parse, validation, or SPIR-V emission) failed.
    ///
    /// `line` carries the source line when known; it is not part of the `Display` text.
    #[error("Shader compilation failed: {message}")]
    CompilationFailed { message: String, line: Option<u32> },

    /// The supplied words are not a usable SPIR-V module.
    #[error("Invalid SPIR-V bytecode")]
    InvalidSpirv,

    /// `vkCreateShaderModule` returned an error code.
    #[error("Shader module creation failed: {0}")]
    ModuleCreationFailed(vk::Result),

    /// The requested entry point cannot be used with this shader.
    #[error("Invalid entry point")]
    InvalidEntryPoint,

    /// SPIR-V reflection failed; the payload is the underlying error rendered as text.
    #[error("Reflection failed: {0}")]
    ReflectionFailed(String),

    /// The requested pipeline stage is not supported by this code path.
    #[error("Unsupported shader stage")]
    UnsupportedStage,
}
/// Programmable pipeline stage a shader module is compiled for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShaderStage {
    /// Per-vertex processing.
    Vertex,
    /// Per-fragment processing.
    Fragment,
    /// Compute dispatch.
    Compute,
    /// Geometry shading.
    Geometry,
    /// Tessellation control (hull) stage.
    TessellationControl,
    /// Tessellation evaluation (domain) stage.
    TessellationEvaluation,
}
impl ShaderStage {
pub fn to_vk_flags(self) -> vk::ShaderStageFlags {
match self {
ShaderStage::Vertex => vk::ShaderStageFlags::VERTEX,
ShaderStage::Fragment => vk::ShaderStageFlags::FRAGMENT,
ShaderStage::Compute => vk::ShaderStageFlags::COMPUTE,
ShaderStage::Geometry => vk::ShaderStageFlags::GEOMETRY,
ShaderStage::TessellationControl => vk::ShaderStageFlags::TESSELLATION_CONTROL,
ShaderStage::TessellationEvaluation => vk::ShaderStageFlags::TESSELLATION_EVALUATION,
}
}
fn to_naga_stage(self) -> naga::ShaderStage {
match self {
ShaderStage::Vertex => naga::ShaderStage::Vertex,
ShaderStage::Fragment => naga::ShaderStage::Fragment,
ShaderStage::Compute => naga::ShaderStage::Compute,
_ => naga::ShaderStage::Vertex, }
}
}
impl From<ShaderStage> for naga::ShaderStage {
fn from(stage: ShaderStage) -> Self {
stage.to_naga_stage()
}
}
/// A Vulkan shader module together with the pipeline stage and entry-point name it is
/// meant to be used with.
///
/// The module is not released automatically: call [`Shader::destroy`] with the device that
/// created it before the value is dropped. Dropping a shader that was never destroyed
/// prints a warning to stderr.
pub struct Shader {
    /// Raw handle returned by `create_shader_module`.
    module: vk::ShaderModule,
    /// Set by the first `destroy` call. It makes `destroy` idempotent (no double free),
    /// lets `handle()` stop handing out the stale handle, and tells `Drop` the module was
    /// released. It is atomic so `destroy(&self)` can flip it while `Shader` stays
    /// `Send + Sync`.
    destroyed: std::sync::atomic::AtomicBool,
    stage: ShaderStage,
    entry_point: String,
}

impl Shader {
    /// First word of every SPIR-V module.
    const SPIRV_MAGIC: u32 = 0x0723_0203;

    /// Creates a shader module from SPIR-V words.
    ///
    /// # Errors
    /// - [`ShaderError::InvalidSpirv`] if `spirv` is empty or does not begin with the
    ///   SPIR-V magic number. Vulkan requires a non-zero code size and valid SPIR-V, and
    ///   passing anything else is undefined behaviour without validation layers.
    /// - [`ShaderError::InvalidEntryPoint`] if `entry_point` contains a NUL byte, since such
    ///   a name cannot be handed to Vulkan as a C string.
    /// - [`ShaderError::ModuleCreationFailed`] if `vkCreateShaderModule` fails.
    pub fn from_spirv(
        device: &Device,
        spirv: &[u32],
        stage: ShaderStage,
        entry_point: &str,
    ) -> Result<Self, ShaderError> {
        if spirv.first() != Some(&Self::SPIRV_MAGIC) {
            return Err(ShaderError::InvalidSpirv);
        }
        if entry_point.contains('\0') {
            return Err(ShaderError::InvalidEntryPoint);
        }

        let create_info = vk::ShaderModuleCreateInfo {
            // `code_size` is expressed in bytes, not words.
            code_size: std::mem::size_of_val(spirv),
            p_code: spirv.as_ptr(),
            ..Default::default()
        };
        // SAFETY: `create_info.p_code` points into `spirv`, which is borrowed for the whole
        // call, and `code_size` is exactly the byte length of that slice.
        let module = unsafe {
            device
                .handle()
                .create_shader_module(&create_info, None)
                .map_err(ShaderError::ModuleCreationFailed)?
        };

        Ok(Self {
            module,
            destroyed: std::sync::atomic::AtomicBool::new(false),
            stage,
            entry_point: entry_point.to_string(),
        })
    }

    /// Compiles GLSL with [`ShaderCompiler::compile_glsl`] and creates a module from the
    /// resulting SPIR-V.
    ///
    /// # Errors
    /// - [`ShaderError::InvalidEntryPoint`] if `entry_point` is not `"main"`. The GLSL path
    ///   always emits its entry point as `main`, so any other name would not exist in the
    ///   generated module.
    /// - Any error from [`ShaderCompiler::compile_glsl`] or [`Shader::from_spirv`].
    pub fn from_glsl(
        device: &Device,
        source: &str,
        stage: ShaderStage,
        entry_point: &str,
    ) -> Result<Self, ShaderError> {
        if entry_point != "main" {
            return Err(ShaderError::InvalidEntryPoint);
        }
        let spirv = ShaderCompiler::compile_glsl(source, stage)?;
        Self::from_spirv(device, &spirv, stage, entry_point)
    }

    /// Returns the raw module handle.
    ///
    /// Returns `vk::ShaderModule::null()` once [`Shader::destroy`] has run, so a destroyed
    /// handle is never handed out.
    #[inline]
    pub fn handle(&self) -> vk::ShaderModule {
        if self.is_destroyed() {
            vk::ShaderModule::null()
        } else {
            self.module
        }
    }

    /// Pipeline stage this shader was created for.
    #[inline]
    pub fn stage(&self) -> ShaderStage {
        self.stage
    }

    /// Entry-point name recorded at creation.
    #[inline]
    pub fn entry_point(&self) -> &str {
        &self.entry_point
    }

    /// Destroys the underlying Vulkan module.
    ///
    /// Call this with the device that created the shader. Calls after the first one are
    /// no-ops, so the handle is never destroyed twice.
    pub fn destroy(&self, device: &Device) {
        // `swap` returns the previous value, so only the first caller observes `false`.
        if self
            .destroyed
            .swap(true, std::sync::atomic::Ordering::AcqRel)
        {
            return;
        }
        // SAFETY: the flag above guarantees this handle is destroyed at most once. The
        // caller is responsible for passing the creating device and for not having
        // pipeline creation using this module in flight.
        unsafe {
            device.handle().destroy_shader_module(self.module, None);
        }
    }

    /// Whether [`Shader::destroy`] has already released the module.
    fn is_destroyed(&self) -> bool {
        self.destroyed.load(std::sync::atomic::Ordering::Acquire)
    }
}

impl Drop for Shader {
    /// Warns when the shader is dropped while its Vulkan module is still alive.
    fn drop(&mut self) {
        // `get_mut` needs no atomic op: `&mut self` already proves exclusive access.
        let destroyed = *self.destroyed.get_mut();
        if !destroyed && self.module != vk::ShaderModule::null() {
            eprintln!("WARNING: Shader dropped without calling .destroy() - potential memory leak");
        }
    }
}
/// Stateless namespace for shader compilation helpers.
///
/// It currently offers GLSL → SPIR-V compilation through naga.
pub struct ShaderCompiler;
impl ShaderCompiler {
pub fn compile_glsl(source: &str, stage: ShaderStage) -> Result<Vec<u32>, ShaderError> {
use naga::back::spv;
use naga::front::glsl;
use naga::valid::{ValidationFlags, Validator};
let mut parser = glsl::Frontend::default();
let options = glsl::Options {
stage: stage.to_naga_stage(),
defines: Default::default(),
};
let module = parser.parse(&options, source).map_err(|errors| {
let message = format!("{:?}", errors);
ShaderError::CompilationFailed {
message,
line: None,
}
})?;
let mut validator = Validator::new(ValidationFlags::all(), Default::default());
let info = validator
.validate(&module)
.map_err(|e| ShaderError::CompilationFailed {
message: format!("Validation failed: {:?}", e),
line: None,
})?;
let options = spv::Options {
lang_version: (1, 0),
flags: spv::WriterFlags::empty(),
capabilities: None,
bounds_check_policies: Default::default(),
zero_initialize_workgroup_memory:
naga::back::spv::ZeroInitializeWorkgroupMemoryMode::Polyfill,
..Default::default()
};
let mut words = Vec::new();
let mut writer =
spv::Writer::new(&options).map_err(|e| ShaderError::CompilationFailed {
message: format!("SPIR-V writer creation failed: {:?}", e),
line: None,
})?;
let pipeline_options = naga::back::spv::PipelineOptions {
shader_stage: stage.into(),
entry_point: "main".to_string(),
};
writer
.write(&module, &info, Some(&pipeline_options), &None, &mut words)
.map_err(|e| ShaderError::CompilationFailed {
message: format!("SPIR-V generation failed: {:?}", e),
line: None,
})?;
Ok(words)
}
}
/// Descriptor-set and push-constant layout information reflected from a SPIR-V module.
///
/// Build it with [`ShaderReflection::from_spirv`] and read it through
/// [`ShaderReflection::descriptor_sets`] and [`ShaderReflection::push_constants`].
pub struct ShaderReflection {
    /// One entry per descriptor set reported by reflection.
    descriptor_sets: Vec<DescriptorSetReflection>,
    /// One entry per push-constant block reported by reflection.
    push_constants: Vec<PushConstantReflection>,
}
/// A reflected descriptor set: its set index and the bindings declared in it.
#[derive(Debug, Clone)]
pub struct DescriptorSetReflection {
    /// Descriptor set index (`layout(set = N)`).
    pub set: u32,
    /// Bindings declared in this set.
    pub bindings: Vec<DescriptorBindingReflection>,
}
/// A single reflected descriptor binding, shaped like a Vulkan
/// descriptor-set-layout binding.
#[derive(Debug, Clone)]
pub struct DescriptorBindingReflection {
    /// Binding index within its set (`layout(binding = N)`).
    pub binding: u32,
    /// Vulkan descriptor type of the binding.
    pub descriptor_type: vk::DescriptorType,
    /// Number of descriptors (array length) at this binding.
    pub descriptor_count: u32,
    /// Stages the binding is visible to.
    pub stage_flags: vk::ShaderStageFlags,
}
/// A reflected push-constant block, shaped like a Vulkan push-constant range.
#[derive(Debug, Clone)]
pub struct PushConstantReflection {
    /// Offset of the block, as reported by reflection.
    pub offset: u32,
    /// Size of the block, as reported by reflection.
    pub size: u32,
    /// Stages the range is visible to.
    pub stage_flags: vk::ShaderStageFlags,
}
impl ShaderReflection {
pub fn from_spirv(spirv: &[u32]) -> Result<Self, ShaderError> {
use spirv_reflect::types::ReflectDescriptorType;
let module = spirv_reflect::ShaderModule::load_u32_data(spirv)
.map_err(|e| ShaderError::ReflectionFailed(format!("{:?}", e)))?;
let descriptor_sets = module
.enumerate_descriptor_sets(None)
.map_err(|e| ShaderError::ReflectionFailed(format!("{:?}", e)))?
.into_iter()
.map(|set| {
let bindings = set
.bindings
.iter()
.map(|binding| {
let descriptor_type = match binding.descriptor_type {
ReflectDescriptorType::Sampler => vk::DescriptorType::SAMPLER,
ReflectDescriptorType::CombinedImageSampler => {
vk::DescriptorType::COMBINED_IMAGE_SAMPLER
}
ReflectDescriptorType::SampledImage => {
vk::DescriptorType::SAMPLED_IMAGE
}
ReflectDescriptorType::StorageImage => {
vk::DescriptorType::STORAGE_IMAGE
}
ReflectDescriptorType::UniformBuffer => {
vk::DescriptorType::UNIFORM_BUFFER
}
ReflectDescriptorType::StorageBuffer => {
vk::DescriptorType::STORAGE_BUFFER
}
ReflectDescriptorType::UniformBufferDynamic => {
vk::DescriptorType::UNIFORM_BUFFER_DYNAMIC
}
ReflectDescriptorType::StorageBufferDynamic => {
vk::DescriptorType::STORAGE_BUFFER_DYNAMIC
}
ReflectDescriptorType::InputAttachment => {
vk::DescriptorType::INPUT_ATTACHMENT
}
_ => vk::DescriptorType::UNIFORM_BUFFER,
};
DescriptorBindingReflection {
binding: binding.binding,
descriptor_type,
descriptor_count: binding.count,
stage_flags: vk::ShaderStageFlags::ALL,
}
})
.collect();
DescriptorSetReflection {
set: set.set,
bindings,
}
})
.collect();
let push_constants = module
.enumerate_push_constant_blocks(None)
.map_err(|e| ShaderError::ReflectionFailed(format!("{:?}", e)))?
.into_iter()
.map(|block| PushConstantReflection {
offset: block.offset,
size: block.size,
stage_flags: vk::ShaderStageFlags::ALL,
})
.collect();
Ok(Self {
descriptor_sets,
push_constants,
})
}
pub fn descriptor_sets(&self) -> &[DescriptorSetReflection] {
&self.descriptor_sets
}
pub fn push_constants(&self) -> &[PushConstantReflection] {
&self.push_constants
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Device, DeviceCreateInfo, Instance, InstanceCreateInfo, QueueCreateInfo};

    /// Creates a Vulkan instance (validation disabled) and a logical device with one
    /// graphics queue on the first enumerated physical device.
    ///
    /// Panics with a descriptive message if the instance, a physical device, or a
    /// graphics-capable queue family is unavailable. Tests that call this therefore need
    /// a working Vulkan driver.
    fn create_test_device() -> (Instance, Device) {
        let instance = Instance::new(InstanceCreateInfo {
            enable_validation: false,
            ..Default::default()
        })
        .expect("failed to create Vulkan instance");
        let physical_devices = instance
            .enumerate_physical_devices()
            .expect("failed to enumerate physical devices");
        let physical_device = *physical_devices
            .first()
            .expect("no Vulkan physical device available");
        // SAFETY: `physical_device` was just enumerated from `instance`, which is still alive.
        let graphics_family = unsafe {
            instance
                .get_physical_device_queue_family_properties(physical_device)
                .iter()
                .enumerate()
                .find(|(_, qf)| qf.queue_flags.contains(vk::QueueFlags::GRAPHICS))
                .map(|(i, _)| i as u32)
                .expect("no graphics-capable queue family")
        };
        let device = Device::new(
            &instance,
            physical_device,
            DeviceCreateInfo {
                queue_create_infos: vec![QueueCreateInfo {
                    queue_family_index: graphics_family,
                    queue_count: 1,
                    queue_priorities: vec![1.0],
                }],
                ..Default::default()
            },
        )
        .expect("failed to create logical device");
        (instance, device)
    }

    /// Every stage must map to its own Vulkan flag bit.
    #[test]
    fn test_shader_stage_to_vk_flags() {
        let cases = [
            (ShaderStage::Vertex, vk::ShaderStageFlags::VERTEX),
            (ShaderStage::Fragment, vk::ShaderStageFlags::FRAGMENT),
            (ShaderStage::Compute, vk::ShaderStageFlags::COMPUTE),
            (ShaderStage::Geometry, vk::ShaderStageFlags::GEOMETRY),
            (
                ShaderStage::TessellationControl,
                vk::ShaderStageFlags::TESSELLATION_CONTROL,
            ),
            (
                ShaderStage::TessellationEvaluation,
                vk::ShaderStageFlags::TESSELLATION_EVALUATION,
            ),
        ];
        for &(stage, expected) in &cases {
            assert_eq!(stage.to_vk_flags(), expected, "wrong flags for {:?}", stage);
        }
    }

    #[test]
    fn test_shader_from_spirv() {
        let (_instance, device) = create_test_device();
        // Header-only module: magic, version 1.0, generator 0, id bound 1, schema 0.
        let spirv = vec![0x07230203, 0x00010000, 0x00000000, 0x00000001, 0x00000000];
        let shader = Shader::from_spirv(&device, &spirv, ShaderStage::Vertex, "main").unwrap();
        assert_ne!(shader.handle(), vk::ShaderModule::null());
        assert_eq!(shader.stage(), ShaderStage::Vertex);
        assert_eq!(shader.entry_point(), "main");
        shader.destroy(&device);
    }

    /// Smoke test for the naga GLSL path. A compile failure is reported but tolerated;
    /// a successful compile must start with the SPIR-V magic number.
    #[test]
    fn test_glsl_compilation() {
        let glsl = r#"
#version 450
void main() {
gl_Position = vec4(0.0, 0.0, 0.0, 1.0);
}
"#;
        match ShaderCompiler::compile_glsl(glsl, ShaderStage::Vertex) {
            Ok(spirv) => {
                assert!(!spirv.is_empty());
                assert_eq!(spirv[0], 0x07230203);
            }
            Err(e) => {
                println!("GLSL compilation failed (expected with naga): {}", e);
            }
        }
    }
}