lambda_platform/gfx/
shader.rs

1//! Low level shader implementations used by the lambda-platform crate to load
2//! SPIR-V compiled shaders into the GPU.
3
4use gfx_hal::{
5  device::Device,
6  pso::Specialization as ShaderSpecializations,
7};
8#[cfg(test)]
9use mockall::automock;
10
11use super::gpu;
12
/// The type of shader that a shader module represents. Different shader types
/// are used for different operations in the rendering pipeline.
///
/// This is a plain field-less enum, so it is cheap to copy, compare, and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderModuleType {
  /// A vertex shader.
  Vertex,
  /// A fragment shader.
  Fragment,
  /// A compute shader.
  Compute,
}
20
21/// Builder class for creating a shader module.
22pub struct ShaderModuleBuilder {
23  entry_name: String,
24  specializations: ShaderSpecializations<'static>,
25}
26
27#[cfg_attr(test, automock)]
28impl ShaderModuleBuilder {
29  pub fn new() -> Self {
30    return Self {
31      entry_name: "main".to_string(),
32      specializations: ShaderSpecializations::EMPTY,
33    };
34  }
35
36  /// Define the shader entry point (Defaults to main)
37  pub fn with_entry_name(mut self, entry_name: &str) -> Self {
38    self.entry_name = entry_name.to_string();
39    return self;
40  }
41
42  /// Attach specializations to the shader.
43  pub fn with_specializations(
44    mut self,
45    specializations: ShaderSpecializations<'static>,
46  ) -> Self {
47    self.specializations = specializations;
48    return self;
49  }
50
51  /// Builds the shader binary into a shader module located on the GPU.
52  /// ShaderModules are specific to gfx-hal and can be used for building
53  /// RenderPipelines
54  pub fn build<RenderBackend: gfx_hal::Backend>(
55    self,
56    gpu: &mut gpu::Gpu<RenderBackend>,
57    shader_binary: &Vec<u32>,
58    shader_type: ShaderModuleType,
59  ) -> ShaderModule<RenderBackend> {
60    let shader_module = unsafe {
61      gpu
62        .internal_logical_device()
63        .create_shader_module(&shader_binary)
64        .expect("Failed to create a shader module.")
65    };
66
67    return ShaderModule {
68      entry_name: self.entry_name,
69      shader_module,
70      specializations: self.specializations,
71      shader_type,
72    };
73  }
74}
75
76/// Shader modules are used for uploading shaders into the render pipeline.
77pub struct ShaderModule<RenderBackend: gfx_hal::Backend> {
78  entry_name: String,
79  shader_module: RenderBackend::ShaderModule,
80  specializations: ShaderSpecializations<'static>,
81  shader_type: ShaderModuleType,
82}
83
84#[cfg_attr(test, automock)]
85impl<RenderBackend: gfx_hal::Backend> ShaderModule<RenderBackend> {
86  /// Destroy the shader module and free the memory on the GPU.
87  pub fn destroy(self, gpu: &mut gpu::Gpu<RenderBackend>) {
88    unsafe {
89      gpu
90        .internal_logical_device()
91        .destroy_shader_module(self.shader_module)
92    }
93  }
94
95  /// Get the entry point that this shader module is using.
96  pub fn entry(&self) -> &str {
97    return self.entry_name.as_str();
98  }
99
100  /// Get the specializations being applied to the current shader module.
101  pub fn specializations(&self) -> &ShaderSpecializations<'static> {
102    return &self.specializations;
103  }
104}
105
#[cfg(test)]
mod tests {

  /// Test that we can create a shader module builder and it has the correct
  /// defaults.
  #[test]
  fn shader_builder_initial_state() {
    let shader_builder = super::ShaderModuleBuilder::new();
    assert_eq!(shader_builder.entry_name, "main");
    assert_eq!(shader_builder.specializations.data.len(), 0);
    assert_eq!(shader_builder.specializations.constants.len(), 0);
  }

  /// Test that we can create a shader module builder with a custom entry point
  /// & default specializations.
  #[test]
  fn shader_builder_with_properties() {
    let shader_builder = super::ShaderModuleBuilder::new()
      .with_entry_name("test")
      .with_specializations(super::ShaderSpecializations::default());
    assert_eq!(shader_builder.entry_name, "test");
    assert_eq!(
      shader_builder.specializations.data,
      super::ShaderSpecializations::default().data
    );
  }

  /// Test that chained builder calls are applied in order, with the last call
  /// for a property winning. `build` itself needs a live GPU, so this only
  /// checks the state the builder would hand to it.
  #[test]
  fn shader_builder_builds_correctly() {
    let shader_builder = super::ShaderModuleBuilder::new()
      .with_entry_name("test")
      .with_specializations(super::ShaderSpecializations::default())
      .with_entry_name("override");
    assert_eq!(shader_builder.entry_name, "override");
    assert_eq!(shader_builder.specializations.data.len(), 0);
    assert_eq!(shader_builder.specializations.constants.len(), 0);
  }
}
139
140/// Internal functions for the shader module. User applications most likely
141/// should not use these functions directly nor should they need to.
142pub mod internal {
143  use super::ShaderModule;
144
145  /// Retrieve the underlying gfx-hal shader module given the lambda-platform
146  /// implemented shader module. Useful for creating gfx-hal entry points and
147  /// attaching the shader to rendering pipelines.
148  #[inline]
149  pub fn module_for<RenderBackend: gfx_hal::Backend>(
150    shader_module: &ShaderModule<RenderBackend>,
151  ) -> &RenderBackend::ShaderModule {
152    return &shader_module.shader_module;
153  }
154}