1use crate::{Device, DeviceOwned, ALLOCATION_CALLBACK_NONE};
2use ash::{
3 util::read_spv,
4 vk::{self, Handle},
5};
6use std::{
7 error,
8 ffi::CString,
9 fmt, fs,
10 io::{self, Cursor},
11 sync::Arc,
12};
13
14pub struct ShaderModule {
15 handle: vk::ShaderModule,
16
17 device: Arc<Device>,
19}
20
21impl ShaderModule {
22 pub fn new_from_file(device: Arc<Device>, file_path: &str) -> Result<Self, ShaderError> {
23 let bytes = fs::read(file_path).map_err(|e| ShaderError::FileRead {
24 e,
25 path: file_path.to_string(),
26 })?;
27 let mut cursor = Cursor::new(bytes);
28
29 Self::new_from_spirv(device, &mut cursor)
30 }
31
32 pub fn new_from_spirv<R: io::Read + io::Seek>(
33 device: Arc<Device>,
34 spirv: &mut R,
35 ) -> Result<Self, ShaderError> {
36 let code = read_spv(spirv).map_err(|e| ShaderError::SpirVDecode(e))?;
37 let create_info = vk::ShaderModuleCreateInfo::builder().code(&code);
38
39 unsafe { Self::new_from_create_info(device, create_info) }
40 }
41
42 pub unsafe fn new_from_create_info(
43 device: Arc<Device>,
44 create_info_builder: vk::ShaderModuleCreateInfoBuilder,
45 ) -> Result<Self, ShaderError> {
46 let handle = unsafe {
47 device
48 .inner()
49 .create_shader_module(&create_info_builder, ALLOCATION_CALLBACK_NONE)
50 }
51 .map_err(|e| ShaderError::Creation(e))?;
52
53 Ok(Self { handle, device })
54 }
55
56 #[inline]
59 pub fn handle(&self) -> vk::ShaderModule {
60 self.handle
61 }
62}
63
64impl DeviceOwned for ShaderModule {
65 #[inline]
66 fn device(&self) -> &Arc<Device> {
67 &self.device
68 }
69
70 #[inline]
71 fn handle_raw(&self) -> u64 {
72 self.handle.as_raw()
73 }
74}
75
76impl Drop for ShaderModule {
77 fn drop(&mut self) {
78 unsafe {
79 self.device
80 .inner()
81 .destroy_shader_module(self.handle, ALLOCATION_CALLBACK_NONE);
82 }
83 }
84}
85
86#[derive(Clone)]
91pub struct ShaderStage {
92 pub flags: vk::PipelineShaderStageCreateFlags,
93 pub stage: vk::ShaderStageFlags,
94 pub module: Arc<ShaderModule>,
95 pub entry_point: CString,
96 pub write_specialization_info: bool,
97 pub specialization_info: vk::SpecializationInfo,
98}
99
100impl ShaderStage {
101 pub fn new(
102 stage: vk::ShaderStageFlags,
103 module: Arc<ShaderModule>,
104 entry_point: CString,
105 specialization_info: Option<vk::SpecializationInfo>,
106 ) -> Self {
107 Self {
108 flags: vk::PipelineShaderStageCreateFlags::empty(),
109 stage,
110 module,
111 entry_point,
112 write_specialization_info: specialization_info.is_some(),
113 specialization_info: specialization_info.unwrap_or_default(),
114 }
115 }
116
117 pub fn write_create_info_builder<'a>(
118 &'a self,
119 builder: vk::PipelineShaderStageCreateInfoBuilder<'a>,
120 ) -> vk::PipelineShaderStageCreateInfoBuilder {
121 let builder = builder
122 .flags(self.flags)
123 .module(self.module.handle())
124 .stage(self.stage)
125 .name(self.entry_point.as_c_str());
126 if self.write_specialization_info {
127 builder.specialization_info(&self.specialization_info)
128 } else {
129 builder
130 }
131 }
132
133 pub fn create_info_builder(&self) -> vk::PipelineShaderStageCreateInfoBuilder {
134 self.write_create_info_builder(vk::PipelineShaderStageCreateInfo::builder())
135 }
136}
137
138#[derive(Debug)]
141pub enum ShaderError {
142 FileRead { e: io::Error, path: String },
143 SpirVDecode(io::Error),
144 Creation(vk::Result),
145}
146
147impl fmt::Display for ShaderError {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 match self {
150 Self::FileRead { e, path } => {
151 write!(f, "failed to read file {} due to: {}", path, e)
152 }
153 Self::SpirVDecode(e) => write!(f, "failed to decode spirv: {}", e),
154 Self::Creation(e) => write!(f, "shader module creation failed: {}", e),
155 }
156 }
157}
158
159impl error::Error for ShaderError {
160 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
161 match self {
162 Self::FileRead { e, .. } => Some(e),
163 Self::SpirVDecode(e) => Some(e),
164 Self::Creation(e) => Some(e),
165 }
166 }
167}