use std::{
collections::HashMap,
ffi::{c_void, CString},
fs::{read_dir, File},
io::Read,
path::{Path, PathBuf},
ptr,
};
use ash::{
vk::{
PipelineShaderStageCreateFlags, PipelineShaderStageCreateInfo, ShaderModule,
ShaderModuleCreateFlags, ShaderModuleCreateInfo, ShaderStageFlags, SpecializationInfo,
StructureType,
},
Device,
};
pub struct ShaderStage<'a> {
pub device: &'a Device,
pub dir_path: &'a Path,
pub shader_flags: ShaderModuleCreateFlags,
pub shader_p_next: *const c_void,
pub main_function_name: CString,
pub shader_stage_flags: PipelineShaderStageCreateFlags,
pub shader_stage_p_next: *const c_void,
pub spec_info: *const SpecializationInfo,
}
impl<'a> ShaderStage<'a> {
pub fn new(device: &'a ash::Device, dir_path: &'a Path) -> Self {
Self {
device,
dir_path,
shader_flags: ShaderModuleCreateFlags::empty(),
shader_p_next: ptr::null(),
main_function_name: CString::new("main").unwrap(),
shader_stage_flags: PipelineShaderStageCreateFlags::empty(),
shader_stage_p_next: ptr::null(),
spec_info: ptr::null(),
}
}
pub fn with_shader_flags(&mut self, shader_flags: ShaderModuleCreateFlags) {
self.shader_flags = shader_flags;
}
pub fn with_shader_p_next(&mut self, p_next: *const c_void) {
self.shader_p_next = p_next;
}
pub fn with_shader_stage_flags(&mut self, shader_stage_flags: PipelineShaderStageCreateFlags) {
self.shader_stage_flags = shader_stage_flags;
}
pub fn with_shader_stage_p_next(&mut self, p_next: *const c_void) {
self.shader_stage_p_next = p_next;
}
pub fn with_spec_info(&mut self, spec_info: *const SpecializationInfo) {
self.spec_info = spec_info;
}
pub fn build(self) -> Vec<PipelineShaderStageCreateInfo> {
let shader_modules = create_shader_modules(
self.device,
self.dir_path,
self.shader_flags,
self.shader_p_next,
);
let file_paths = read_dir(self.dir_path)
.unwrap()
.into_iter()
.filter(|file_name| {
file_name
.as_ref()
.unwrap()
.path()
.to_str()
.unwrap()
.contains(".spv")
})
.map(|path| path.unwrap().path());
let shader_path: HashMap<&ShaderModule, PathBuf> =
shader_modules.iter().zip(file_paths.into_iter()).collect();
shader_modules
.iter()
.map(|module| PipelineShaderStageCreateInfo {
s_type: StructureType::PIPELINE_SHADER_STAGE_CREATE_INFO,
p_next: ptr::null(),
flags: PipelineShaderStageCreateFlags::empty(),
stage: if shader_path
.get(&module)
.unwrap()
.to_str()
.unwrap()
.contains("vert.spv")
{
ShaderStageFlags::VERTEX
} else if shader_path
.get(&module)
.unwrap()
.to_str()
.unwrap()
.contains("frag.spv")
{
ShaderStageFlags::FRAGMENT
} else {
panic!("Failed to define shader type!")
},
module: *module,
p_name: self.main_function_name.as_ptr(),
p_specialization_info: ptr::null(),
})
.collect()
}
}
/// Creates one [`ShaderModule`] for every entry of `dir_path` whose path
/// contains `.spv`, in `read_dir` order.
///
/// Every module is created with the given `flags` and `p_next` chain. The
/// caller owns the returned handles.
///
/// # Panics
/// Panics in any of these cases:
/// - the directory or a file cannot be read;
/// - a file is not a whole number of 32-bit words;
/// - a file lacks the SPIR-V magic number;
/// - `create_shader_module` fails.
fn create_shader_modules(
    device: &Device,
    dir_path: &Path,
    flags: ShaderModuleCreateFlags,
    p_next: *const c_void,
) -> Vec<ShaderModule> {
    /// Reads a SPIR-V binary into host-endian 32-bit words.
    ///
    /// Collecting into a `Vec<u32>` guarantees the 4-byte alignment Vulkan
    /// requires for `p_code`, which a `Vec<u8>` does not.
    fn read_spirv_words(path: &Path) -> Vec<u32> {
        const SPIRV_MAGIC: u32 = 0x0723_0203;

        let mut bytes = Vec::new();
        // Whole-file read: propagates I/O errors instead of silently
        // dropping bytes, and avoids one syscall per byte.
        File::open(path)
            .and_then(|mut file| file.read_to_end(&mut bytes))
            .unwrap_or_else(|err| panic!("Failed to read spv file at {:?}: {}", path, err));

        assert!(
            bytes.len() % 4 == 0,
            "spv file at {:?} is {} bytes, not a multiple of 4",
            path,
            bytes.len()
        );

        let mut words: Vec<u32> = bytes
            .chunks_exact(4)
            .map(|w| u32::from_ne_bytes([w[0], w[1], w[2], w[3]]))
            .collect();

        // A byte-swapped magic number means the module was written with the
        // opposite endianness; convert it to host order.
        if words.first() == Some(&SPIRV_MAGIC.swap_bytes()) {
            words.iter_mut().for_each(|word| *word = word.swap_bytes());
        }
        assert!(
            words.first() == Some(&SPIRV_MAGIC),
            "spv file at {:?} is missing the SPIR-V magic number",
            path
        );
        words
    }

    let spv_paths: Vec<PathBuf> = read_dir(dir_path)
        .unwrap_or_else(|_| panic!("Failed to find spv file at {:?}", dir_path))
        .map(|entry| {
            entry
                .unwrap_or_else(|err| panic!("Failed to read entry of {:?}: {}", dir_path, err))
                .path()
        })
        .filter(|path| path.to_string_lossy().contains(".spv"))
        .collect();

    spv_paths
        .iter()
        .map(|path| {
            let code = read_spirv_words(path);
            let shader_module_create_info = ShaderModuleCreateInfo {
                s_type: StructureType::SHADER_MODULE_CREATE_INFO,
                p_next,
                flags,
                // `code_size` is measured in bytes; `p_code` points at words.
                code_size: code.len() * std::mem::size_of::<u32>(),
                p_code: code.as_ptr(),
            };
            // SAFETY: `shader_module_create_info` and the `code` buffer it
            // points into both outlive this call, and `p_code` is 4-byte
            // aligned because it points into a `Vec<u32>`.
            unsafe {
                device
                    .create_shader_module(&shader_module_create_info, None)
                    .expect("Failed to create shader module!")
            }
        })
        .collect()
}