// vulk-ext 0.1.18

// Vulkan bindings for Rust, covering only the bleeding-edge features.
// Documentation
use super::*;

/// Kind of shader handled by this module.
///
/// The Vulkan stage bit for each variant is given by
/// [`ShaderType::shader_stage`], and the stage that follows it (if any) by
/// [`ShaderType::next_shader_stage`].
///
/// The enum is a plain field-less tag, so it also implements `PartialEq`,
/// `Eq` and `Hash`; this lets callers compare shader types and use them as
/// keys in maps and sets.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ShaderType {
    /// Task shader; followed by the mesh stage.
    Task,
    /// Mesh shader; followed by the fragment stage.
    Mesh,
    /// Fragment shader.
    Fragment,
    /// Compute shader.
    Compute,
    /// Ray generation shader.
    Raygen,
    /// Ray miss shader.
    Miss,
    /// Ray closest-hit shader.
    ClosestHit,
}

impl ShaderType {
    #[must_use]
    pub fn shader_stage(self) -> vk::ShaderStageFlagBits {
        match self {
            ShaderType::Task => vk::ShaderStageFlagBits::TaskEXT,
            ShaderType::Mesh => vk::ShaderStageFlagBits::MeshEXT,
            ShaderType::Fragment => vk::ShaderStageFlagBits::Fragment,
            ShaderType::Compute => vk::ShaderStageFlagBits::Compute,
            ShaderType::Raygen => vk::ShaderStageFlagBits::RaygenKHR,
            ShaderType::Miss => vk::ShaderStageFlagBits::MissKHR,
            ShaderType::ClosestHit => vk::ShaderStageFlagBits::ClosestHitKHR,
        }
    }

    #[must_use]
    pub fn next_shader_stage(self) -> Option<vk::ShaderStageFlagBits> {
        match self {
            ShaderType::Task => Some(vk::ShaderStageFlagBits::MeshEXT),
            ShaderType::Mesh => Some(vk::ShaderStageFlagBits::Fragment),
            ShaderType::Fragment
            | ShaderType::Compute
            | ShaderType::Raygen
            | ShaderType::Miss
            | ShaderType::ClosestHit => None,
        }
    }
}

/// Compiled shader code together with the shader type and entry point it was
/// built for.
///
/// Values are produced by `ShaderCompiler::compile`. The fields are private;
/// read them through the accessor methods on this type.
#[derive(Clone, Debug)]
pub struct ShaderBinary {
    /// Shader type the code was compiled as.
    ty: ShaderType,
    /// Compiled binary, stored as raw bytes.
    code: Vec<u8>,
    /// Entry point name as a NUL-terminated C string.
    entry_point: CString,
}

impl ShaderBinary {
    /// Shader type this binary was compiled as.
    #[must_use]
    pub fn shader_type(&self) -> ShaderType {
        let Self { ty, .. } = self;
        *ty
    }

    /// Size of the compiled code in bytes.
    #[must_use]
    pub fn code_size(&self) -> usize {
        self.code.as_slice().len()
    }

    /// Raw pointer to the first byte of the compiled code.
    ///
    /// The pointer refers to memory owned by `self`, so it must not be used
    /// after this `ShaderBinary` has been dropped.
    #[must_use]
    pub fn p_code(&self) -> *const std::ffi::c_void {
        let bytes: *const u8 = self.code.as_ptr();
        bytes.cast::<std::ffi::c_void>()
    }

    /// Raw pointer to the NUL-terminated entry point name.
    ///
    /// The pointer refers to memory owned by `self`, so it must not be used
    /// after this `ShaderBinary` has been dropped.
    #[must_use]
    pub fn entry_point(&self) -> *const std::ffi::c_char {
        self.entry_point.as_c_str().as_ptr()
    }

    /// Builds a `vk::ShaderModuleCreateInfo` that points at this binary's code.
    ///
    /// The returned struct stores a raw pointer into `self`; keep `self` alive
    /// for as long as the struct is in use.
    ///
    /// NOTE(review): the code lives in a `Vec<u8>`, whose type does not
    /// guarantee the 4-byte alignment Vulkan expects for SPIR-V code pointers
    /// — confirm this holds for the allocator in use.
    #[must_use]
    pub fn shader_module_create_info(&self) -> vk::ShaderModuleCreateInfo {
        let code_size = self.code_size();
        // Re-typed from `*const c_void` to whatever pointer type `p_code` expects.
        let p_code = self.p_code().cast();
        vk::ShaderModuleCreateInfo {
            s_type: vk::StructureType::ShaderModuleCreateInfo,
            p_next: null(),
            flags: vk::ShaderModuleCreateFlags::empty(),
            code_size,
            p_code,
        }
    }
}

/// Wrapper around a `shaderc` compiler plus a table of in-memory include files.
pub struct ShaderCompiler {
    /// Underlying `shaderc` compiler instance.
    compiler: shaderc::Compiler,
    /// Include-file contents keyed by the name used to request them.
    includes: HashMap<String, String>,
}

impl ShaderCompiler {
    pub fn new() -> Result<Self> {
        use shaderc::Compiler;
        let compiler = Compiler::new().context("Creating shader compiler")?;
        Ok(Self {
            compiler,
            includes: HashMap::new(),
        })
    }

    pub fn include(&mut self, source_name: impl AsRef<str>, content: impl AsRef<str>) {
        self.includes.insert(
            source_name.as_ref().to_string(),
            content.as_ref().to_string(),
        );
    }

    pub fn compile(
        &self,
        shader_type: ShaderType,
        input_file_name: impl AsRef<str>,
        entry_point_name: impl AsRef<str>,
        code: impl AsRef<str>,
    ) -> Result<ShaderBinary> {
        use shaderc::CompileOptions;
        use shaderc::OptimizationLevel;
        use shaderc::ResolvedInclude;
        use shaderc::SourceLanguage;
        use shaderc::SpirvVersion;
        use shaderc::TargetEnv;

        let shader_kind = match shader_type {
            ShaderType::Task => shaderc::ShaderKind::Task,
            ShaderType::Mesh => shaderc::ShaderKind::Mesh,
            ShaderType::Fragment => shaderc::ShaderKind::Fragment,
            ShaderType::Compute => shaderc::ShaderKind::Compute,
            ShaderType::Raygen => shaderc::ShaderKind::RayGeneration,
            ShaderType::Miss => shaderc::ShaderKind::Miss,
            ShaderType::ClosestHit => shaderc::ShaderKind::ClosestHit,
        };
        let mut options = CompileOptions::new().context("Creating shader compiler options")?;
        options.set_target_env(TargetEnv::Vulkan, vulk::REQUIRED_VULKAN_VERSION);
        options.set_optimization_level(OptimizationLevel::Performance);
        options.set_target_spirv(SpirvVersion::V1_6);
        options.set_source_language(SourceLanguage::GLSL);
        options.set_warnings_as_errors();
        options.set_include_callback(|source_name, _, _, _| -> shaderc::IncludeCallbackResult {
            let content = self.includes.get(source_name);
            if let Some(content) = content {
                Ok(ResolvedInclude {
                    resolved_name: source_name.to_owned(),
                    content: (*content).clone(),
                })
            } else {
                Err(format!("Unable to include source file {source_name}."))
            }
        });
        let shader = self.compiler.compile_into_spirv(
            code.as_ref(),
            shader_kind,
            input_file_name.as_ref(),
            entry_point_name.as_ref(),
            Some(&options),
        )?;
        if shader.get_num_warnings() > 0 {
            error!("{}", shader.get_warning_messages());
        }
        Ok(ShaderBinary {
            ty: shader_type,
            code: shader.as_binary_u8().to_owned(),
            entry_point: CString::new(entry_point_name.as_ref())?,
        })
    }
}

/// Inputs for `Shader::create`.
///
/// One shader object is requested per entry in `shader_binaries`; the
/// remaining fields are applied identically to every one of them.
#[derive(Debug)]
pub struct ShaderCreateInfo<'a> {
    /// Compiled shaders to create shader objects from.
    pub shader_binaries: &'a [ShaderBinary],
    /// Descriptor set layouts passed to every shader.
    pub set_layouts: &'a [vk::DescriptorSetLayout],
    /// Push constant ranges passed to every shader.
    pub push_constant_ranges: &'a [vk::PushConstantRange],
    /// Optional specialization info passed to every shader; `None` becomes a
    /// null pointer.
    pub specialization_info: Option<&'a vk::SpecializationInfo>,
}

/// A group of shader objects created together, with the stage of each one.
///
/// `Shader::create` fills both vectors from the same list of binaries, so they
/// have the same length.
/// NOTE(review): presumably `stages[i]` is the stage of `shaders[i]`; this
/// relies on the create call returning handles in request order — confirm.
#[derive(Debug)]
pub struct Shader {
    /// Stage of each shader, in the order the binaries were supplied.
    pub(super) stages: Vec<vk::ShaderStageFlagBits>,
    /// Shader object handles written by the create call.
    pub(super) shaders: Vec<vk::ShaderEXT>,
}

impl Shader {
    /// Creates one shader object per binary in `create_info.shader_binaries`.
    ///
    /// Every shader is created from SPIR-V code with the link-stage flag set,
    /// and all of them share the same descriptor set layouts, push constant
    /// ranges and specialization info. The stage of each binary is recorded
    /// alongside the resulting handles.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `Device::create_shaders_ext`.
    ///
    /// # Safety
    ///
    /// Raw pointers into `create_info` are handed to
    /// `Device::create_shaders_ext`, and the output vector's length is set to
    /// the number of binaries once that call succeeds.
    /// NOTE(review): presumably the caller must satisfy the Vulkan valid-usage
    /// rules of that call (valid device, non-empty binary list, ...) — confirm.
    pub unsafe fn create(device: &Device, create_info: &ShaderCreateInfo<'_>) -> Result<Self> {
        let binaries = create_info.shader_binaries;
        let set_layouts = create_info.set_layouts;
        let push_constant_ranges = create_info.push_constant_ranges;
        // Identical for every shader, so resolve the optional reference once.
        let p_specialization_info = match create_info.specialization_info {
            Some(info) => std::ptr::from_ref(info),
            None => null(),
        };

        let mut create_infos = Vec::with_capacity(binaries.len());
        for binary in binaries {
            // A shader type without a following stage gets an all-zero value.
            let next_stage = match binary.ty.next_shader_stage() {
                Some(stage) => stage.into(),
                None => zeroed(),
            };
            create_infos.push(vk::ShaderCreateInfoEXT {
                s_type: vk::StructureType::ShaderCreateInfoEXT,
                p_next: null(),
                flags: vk::ShaderCreateFlagBitsEXT::LinkStageEXT.into(),
                stage: binary.ty.shader_stage(),
                next_stage,
                code_type: vk::ShaderCodeTypeEXT::SpirvEXT,
                code_size: binary.code_size(),
                p_code: binary.p_code(),
                p_name: binary.entry_point(),
                set_layout_count: set_layouts.len() as _,
                p_set_layouts: set_layouts.as_ptr(),
                push_constant_range_count: push_constant_ranges.len() as _,
                p_push_constant_ranges: push_constant_ranges.as_ptr(),
                p_specialization_info,
            });
        }

        // Reserve space for the handles, let the create call write them, and
        // only then mark the elements as initialized.
        let shader_count = binaries.len();
        let mut shaders = Vec::with_capacity(shader_count);
        device.create_shaders_ext(
            create_infos.len() as _,
            create_infos.as_ptr(),
            shaders.as_mut_ptr(),
        )?;
        shaders.set_len(shader_count);

        let stages = binaries
            .iter()
            .map(|binary| binary.ty.shader_stage())
            .collect();

        Ok(Self { stages, shaders })
    }

    /// Destroys every shader object held by this value, consuming it.
    ///
    /// # Safety
    ///
    /// Each handle is passed to `Device::destroy_shader_ext`.
    /// NOTE(review): presumably `device` must be the device the shaders were
    /// created on and the shaders must no longer be in use — confirm.
    pub unsafe fn destroy(self, device: &Device) {
        for shader in self.shaders.iter().copied() {
            device.destroy_shader_ext(shader);
        }
    }
}

impl Shader {
    /// Records a command that binds all shaders of this group to `cmd`.
    ///
    /// The stage count is taken from `stages`, and the `stages` and `shaders`
    /// arrays are passed to `Device::cmd_bind_shaders_ext` as raw pointers.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `cmd` must be a command buffer of `device` in
    /// the recording state, and the shaders must still be alive — confirm.
    pub unsafe fn bind(&self, device: &Device, cmd: vk::CommandBuffer) {
        let Self { stages, shaders } = self;
        let stage_count = stages.len();
        device.cmd_bind_shaders_ext(cmd, stage_count as _, stages.as_ptr(), shaders.as_ptr());
    }
}