1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
//! Low-level shader implementations used by the lambda-platform crate to load
//! SPIR-V compiled shaders into the GPU.

use gfx_hal::{
  device::Device,
  pso::Specialization as ShaderSpecializations,
};
#[cfg(test)]
use mockall::automock;

use super::gpu;

/// The type of shader that a shader module represents. Different shader types
/// are used for different operations in the rendering pipeline.
///
/// The enum is a plain tag with no payload, so it is cheap to copy, compare
/// and print for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShaderModuleType {
  /// The module is a vertex shader.
  Vertex,
  /// The module is a fragment shader.
  Fragment,
  /// The module is a compute shader.
  Compute,
}

/// Builder class for creating a shader module.
///
/// Collects the entry point name and the specializations that are handed to
/// the `ShaderModule` produced by `build`.
pub struct ShaderModuleBuilder {
  /// Name of the shader entry point; `new` initializes it to `"main"`.
  entry_name: String,
  /// Specializations passed on to the built shader module; `new` initializes
  /// them to `ShaderSpecializations::EMPTY`.
  specializations: ShaderSpecializations<'static>,
}

#[cfg_attr(test, automock)]
impl ShaderModuleBuilder {
  /// Creates a builder with the entry point set to `"main"` and with empty
  /// specializations.
  pub fn new() -> Self {
    Self {
      entry_name: "main".to_string(),
      specializations: ShaderSpecializations::EMPTY,
    }
  }

  /// Define the shader entry point (defaults to `main`).
  pub fn with_entry_name(mut self, entry_name: &str) -> Self {
    self.entry_name = entry_name.to_string();
    self
  }

  /// Attach specializations to the shader, replacing any that were attached
  /// before.
  pub fn with_specializations(
    mut self,
    specializations: ShaderSpecializations<'static>,
  ) -> Self {
    self.specializations = specializations;
    self
  }

  /// Builds the shader binary into a shader module located on the GPU.
  /// ShaderModules are specific to gfx-hal and can be used for building
  /// RenderPipelines.
  ///
  /// * `gpu` - GPU whose logical device creates the module.
  /// * `shader_binary` - The shader binary as 32-bit words. Any slice is
  ///   accepted, so existing callers passing `&Vec<u32>` keep working.
  /// * `shader_type` - Stored on the returned module unchanged.
  ///
  /// # Panics
  ///
  /// Panics with "Failed to create a shader module." if the logical device
  /// reports an error while creating the module.
  pub fn build<RenderBackend: gfx_hal::Backend>(
    self,
    gpu: &mut gpu::Gpu<RenderBackend>,
    shader_binary: &[u32],
    shader_type: ShaderModuleType,
  ) -> ShaderModule<RenderBackend> {
    // SAFETY: `create_shader_module` is an unsafe gfx-hal call. The binary is
    // forwarded as-is and is not validated here.
    // NOTE(review): confirm gfx-hal's exact requirements for the input words.
    let shader_module = unsafe {
      gpu
        .internal_logical_device()
        .create_shader_module(shader_binary)
        .expect("Failed to create a shader module.")
    };

    ShaderModule {
      entry_name: self.entry_name,
      shader_module,
      specializations: self.specializations,
      shader_type,
    }
  }
}

impl Default for ShaderModuleBuilder {
  /// Equivalent to [`ShaderModuleBuilder::new`].
  fn default() -> Self {
    Self::new()
  }
}

/// Shader modules are used for uploading shaders into the render pipeline.
///
/// Instances are produced by `ShaderModuleBuilder::build` and pair the
/// backend shader module with the settings chosen on the builder.
pub struct ShaderModule<RenderBackend: gfx_hal::Backend> {
  /// Name of the shader entry point, exposed through `entry`.
  entry_name: String,
  /// Backend-specific shader module created by the logical device.
  shader_module: RenderBackend::ShaderModule,
  /// Specializations attached to this module, exposed through
  /// `specializations`.
  specializations: ShaderSpecializations<'static>,
  /// The kind of shader this module was built as.
  shader_type: ShaderModuleType,
}

#[cfg_attr(test, automock)]
impl<RenderBackend: gfx_hal::Backend> ShaderModule<RenderBackend> {
  /// Consumes the shader module and hands the underlying backend module to
  /// the GPU's logical device for destruction.
  pub fn destroy(self, gpu: &mut gpu::Gpu<RenderBackend>) {
    let Self { shader_module, .. } = self;

    // SAFETY: `destroy_shader_module` is an unsafe gfx-hal call. Taking
    // `self` by value means this wrapper cannot be used again afterwards.
    // NOTE(review): confirm that `gpu` is the device that created the module.
    unsafe {
      gpu
        .internal_logical_device()
        .destroy_shader_module(shader_module);
    }
  }

  /// The name of the entry point used by this shader module.
  pub fn entry(&self) -> &str {
    &self.entry_name
  }

  /// The specializations applied to this shader module.
  pub fn specializations(&self) -> &ShaderSpecializations<'static> {
    &self.specializations
  }
}

#[cfg(test)]
mod tests {
  use super::{ShaderModuleBuilder, ShaderSpecializations};

  /// Test that we can create a shader module builder and it has the correct
  /// defaults.
  #[test]
  fn shader_builder_initial_state() {
    let shader_builder = ShaderModuleBuilder::new();
    assert_eq!(shader_builder.entry_name, "main");
    assert!(shader_builder.specializations.data.is_empty());
  }

  /// Test that we can create a shader module builder with a custom entry point
  /// & default specializations.
  #[test]
  fn shader_builder_with_properties() {
    let shader_builder = ShaderModuleBuilder::new()
      .with_entry_name("test")
      .with_specializations(ShaderSpecializations::default());
    assert_eq!(shader_builder.entry_name, "test");
    assert_eq!(
      shader_builder.specializations.data,
      ShaderSpecializations::default().data
    );
  }

  /// Test that a fully configured builder keeps its settings up to the point
  /// where it would be built.
  ///
  /// `build` itself is not called: it needs a `gpu::Gpu`, which this test does
  /// not construct. TODO: exercise `build` once a test GPU is available.
  #[test]
  fn shader_builder_builds_correctly() {
    let shader_builder = ShaderModuleBuilder::new()
      .with_entry_name("test")
      .with_specializations(ShaderSpecializations::default());
    assert_eq!(shader_builder.entry_name, "test");
  }
}

/// Internal helpers for the shader module. These are meant for use inside
/// lambda-platform; user applications most likely should not call them
/// directly, nor should they need to.
pub mod internal {
  use super::ShaderModule;

  /// Borrow the gfx-hal shader module wrapped by a lambda-platform
  /// `ShaderModule`. Useful for creating gfx-hal entry points and attaching
  /// the shader to rendering pipelines.
  #[inline]
  pub fn module_for<RenderBackend: gfx_hal::Backend>(
    shader_module: &ShaderModule<RenderBackend>,
  ) -> &RenderBackend::ShaderModule {
    &shader_module.shader_module
  }
}