lambda_platform/gfx/
shader.rs1use gfx_hal::{
5 device::Device,
6 pso::Specialization as ShaderSpecializations,
7};
8#[cfg(test)]
9use mockall::automock;
10
11use super::gpu;
12
/// The pipeline stage a [`ShaderModule`] was built for.
///
/// The value is supplied to [`ShaderModuleBuilder::build`] and stored on the
/// resulting module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderModuleType {
    /// Vertex shader stage.
    Vertex,
    /// Fragment shader stage.
    Fragment,
    /// Compute shader stage.
    Compute,
}
20
21pub struct ShaderModuleBuilder {
23 entry_name: String,
24 specializations: ShaderSpecializations<'static>,
25}
26
27#[cfg_attr(test, automock)]
28impl ShaderModuleBuilder {
29 pub fn new() -> Self {
30 return Self {
31 entry_name: "main".to_string(),
32 specializations: ShaderSpecializations::EMPTY,
33 };
34 }
35
36 pub fn with_entry_name(mut self, entry_name: &str) -> Self {
38 self.entry_name = entry_name.to_string();
39 return self;
40 }
41
42 pub fn with_specializations(
44 mut self,
45 specializations: ShaderSpecializations<'static>,
46 ) -> Self {
47 self.specializations = specializations;
48 return self;
49 }
50
51 pub fn build<RenderBackend: gfx_hal::Backend>(
55 self,
56 gpu: &mut gpu::Gpu<RenderBackend>,
57 shader_binary: &Vec<u32>,
58 shader_type: ShaderModuleType,
59 ) -> ShaderModule<RenderBackend> {
60 let shader_module = unsafe {
61 gpu
62 .internal_logical_device()
63 .create_shader_module(&shader_binary)
64 .expect("Failed to create a shader module.")
65 };
66
67 return ShaderModule {
68 entry_name: self.entry_name,
69 shader_module,
70 specializations: self.specializations,
71 shader_type,
72 };
73 }
74}
75
76pub struct ShaderModule<RenderBackend: gfx_hal::Backend> {
78 entry_name: String,
79 shader_module: RenderBackend::ShaderModule,
80 specializations: ShaderSpecializations<'static>,
81 shader_type: ShaderModuleType,
82}
83
84#[cfg_attr(test, automock)]
85impl<RenderBackend: gfx_hal::Backend> ShaderModule<RenderBackend> {
86 pub fn destroy(self, gpu: &mut gpu::Gpu<RenderBackend>) {
88 unsafe {
89 gpu
90 .internal_logical_device()
91 .destroy_shader_module(self.shader_module)
92 }
93 }
94
95 pub fn entry(&self) -> &str {
97 return self.entry_name.as_str();
98 }
99
100 pub fn specializations(&self) -> &ShaderSpecializations<'static> {
102 return &self.specializations;
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::{ShaderModuleBuilder, ShaderSpecializations};

    /// A fresh builder targets `"main"` and carries no specialization data.
    #[test]
    fn shader_builder_initial_state() {
        let shader_builder = ShaderModuleBuilder::new();
        assert_eq!(shader_builder.entry_name, "main");
        assert_eq!(shader_builder.specializations.data.len(), 0);
    }

    /// Setters store the values they are given.
    #[test]
    fn shader_builder_with_properties() {
        let shader_builder = ShaderModuleBuilder::new()
            .with_entry_name("test")
            .with_specializations(ShaderSpecializations::default());
        assert_eq!(shader_builder.entry_name, "test");
        assert_eq!(
            shader_builder.specializations.data,
            ShaderSpecializations::default().data
        );
    }

    /// Checks the state `build` will move into the `ShaderModule`. Calling
    /// `build` itself needs a real GPU, so it is not exercised here. A later
    /// setter call must override an earlier one.
    #[test]
    fn shader_builder_builds_correctly() {
        let shader_builder = ShaderModuleBuilder::new()
            .with_entry_name("first")
            .with_specializations(ShaderSpecializations::default())
            .with_entry_name("test");
        assert_eq!(shader_builder.entry_name, "test");
        assert_eq!(
            shader_builder.specializations.data,
            ShaderSpecializations::default().data
        );
    }
}
139
140pub mod internal {
143 use super::ShaderModule;
144
145 #[inline]
149 pub fn module_for<RenderBackend: gfx_hal::Backend>(
150 shader_module: &ShaderModule<RenderBackend>,
151 ) -> &RenderBackend::ShaderModule {
152 return &shader_module.shader_module;
153 }
154}