// bort_vk/shader_module.rs

1use crate::{Device, DeviceOwned, ALLOCATION_CALLBACK_NONE};
2use ash::{
3    util::read_spv,
4    vk::{self, Handle},
5};
6use std::{
7    error,
8    ffi::CString,
9    fmt, fs,
10    io::{self, Cursor},
11    sync::Arc,
12};
13
14pub struct ShaderModule {
15    handle: vk::ShaderModule,
16
17    // dependencies
18    device: Arc<Device>,
19}
20
21impl ShaderModule {
22    pub fn new_from_file(device: Arc<Device>, file_path: &str) -> Result<Self, ShaderError> {
23        let bytes = fs::read(file_path).map_err(|e| ShaderError::FileRead {
24            e,
25            path: file_path.to_string(),
26        })?;
27        let mut cursor = Cursor::new(bytes);
28
29        Self::new_from_spirv(device, &mut cursor)
30    }
31
32    pub fn new_from_spirv<R: io::Read + io::Seek>(
33        device: Arc<Device>,
34        spirv: &mut R,
35    ) -> Result<Self, ShaderError> {
36        let code = read_spv(spirv).map_err(|e| ShaderError::SpirVDecode(e))?;
37        let create_info = vk::ShaderModuleCreateInfo::builder().code(&code);
38
39        unsafe { Self::new_from_create_info(device, create_info) }
40    }
41
42    pub unsafe fn new_from_create_info(
43        device: Arc<Device>,
44        create_info_builder: vk::ShaderModuleCreateInfoBuilder,
45    ) -> Result<Self, ShaderError> {
46        let handle = unsafe {
47            device
48                .inner()
49                .create_shader_module(&create_info_builder, ALLOCATION_CALLBACK_NONE)
50        }
51        .map_err(|e| ShaderError::Creation(e))?;
52
53        Ok(Self { handle, device })
54    }
55
56    // Getters
57
58    #[inline]
59    pub fn handle(&self) -> vk::ShaderModule {
60        self.handle
61    }
62}
63
64impl DeviceOwned for ShaderModule {
65    #[inline]
66    fn device(&self) -> &Arc<Device> {
67        &self.device
68    }
69
70    #[inline]
71    fn handle_raw(&self) -> u64 {
72        self.handle.as_raw()
73    }
74}
75
76impl Drop for ShaderModule {
77    fn drop(&mut self) {
78        unsafe {
79            self.device
80                .inner()
81                .destroy_shader_module(self.handle, ALLOCATION_CALLBACK_NONE);
82        }
83    }
84}
85
// Shader Stage
87
88// Note: this isn't a member of `GraphicsPipelineProperties` because we only need to ensure
89// the `ShaderModule` lifetime lasts during pipeline creation. Not needed after that.
90#[derive(Clone)]
91pub struct ShaderStage {
92    pub flags: vk::PipelineShaderStageCreateFlags,
93    pub stage: vk::ShaderStageFlags,
94    pub module: Arc<ShaderModule>,
95    pub entry_point: CString,
96    pub write_specialization_info: bool,
97    pub specialization_info: vk::SpecializationInfo,
98}
99
100impl ShaderStage {
101    pub fn new(
102        stage: vk::ShaderStageFlags,
103        module: Arc<ShaderModule>,
104        entry_point: CString,
105        specialization_info: Option<vk::SpecializationInfo>,
106    ) -> Self {
107        Self {
108            flags: vk::PipelineShaderStageCreateFlags::empty(),
109            stage,
110            module,
111            entry_point,
112            write_specialization_info: specialization_info.is_some(),
113            specialization_info: specialization_info.unwrap_or_default(),
114        }
115    }
116
117    pub fn write_create_info_builder<'a>(
118        &'a self,
119        builder: vk::PipelineShaderStageCreateInfoBuilder<'a>,
120    ) -> vk::PipelineShaderStageCreateInfoBuilder {
121        let builder = builder
122            .flags(self.flags)
123            .module(self.module.handle())
124            .stage(self.stage)
125            .name(self.entry_point.as_c_str());
126        if self.write_specialization_info {
127            builder.specialization_info(&self.specialization_info)
128        } else {
129            builder
130        }
131    }
132
133    pub fn create_info_builder(&self) -> vk::PipelineShaderStageCreateInfoBuilder {
134        self.write_create_info_builder(vk::PipelineShaderStageCreateInfo::builder())
135    }
136}
137
// Errors
139
140#[derive(Debug)]
141pub enum ShaderError {
142    FileRead { e: io::Error, path: String },
143    SpirVDecode(io::Error),
144    Creation(vk::Result),
145}
146
147impl fmt::Display for ShaderError {
148    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149        match self {
150            Self::FileRead { e, path } => {
151                write!(f, "failed to read file {} due to: {}", path, e)
152            }
153            Self::SpirVDecode(e) => write!(f, "failed to decode spirv: {}", e),
154            Self::Creation(e) => write!(f, "shader module creation failed: {}", e),
155        }
156    }
157}
158
159impl error::Error for ShaderError {
160    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
161        match self {
162            Self::FileRead { e, .. } => Some(e),
163            Self::SpirVDecode(e) => Some(e),
164            Self::Creation(e) => Some(e),
165        }
166    }
167}