1use std::{
3 collections::HashMap,
4 ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not},
5 path::PathBuf,
6 str::FromStr,
7};
8
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
13pub enum ShadingLanguage {
14 Wgsl,
15 Hlsl,
16 Glsl,
17}
18
19#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
20pub struct ShaderStageMask(u32);
21
22impl ShaderStageMask {
23 pub const VERTEX: Self = Self(1 << 0);
24 pub const FRAGMENT: Self = Self(1 << 1);
25 pub const COMPUTE: Self = Self(1 << 2);
26 pub const TESSELATION_CONTROL: Self = Self(1 << 3);
27 pub const TESSELATION_EVALUATION: Self = Self(1 << 4);
28 pub const MESH: Self = Self(1 << 5);
29 pub const TASK: Self = Self(1 << 6);
30 pub const GEOMETRY: Self = Self(1 << 7);
31 pub const RAY_GENERATION: Self = Self(1 << 8);
32 pub const CLOSEST_HIT: Self = Self(1 << 9);
33 pub const ANY_HIT: Self = Self(1 << 10);
34 pub const CALLABLE: Self = Self(1 << 11);
35 pub const MISS: Self = Self(1 << 12);
36 pub const INTERSECT: Self = Self(1 << 13);
37}
38
39impl Default for ShaderStageMask {
40 fn default() -> Self {
41 Self(0)
42 }
43}
44impl ShaderStageMask {
45 pub const fn from_u32(x: u32) -> Self {
46 Self(x)
47 }
48 pub const fn as_u32(self) -> u32 {
49 self.0
50 }
51 pub const fn is_empty(self) -> bool {
52 self.0 == 0
53 }
54 pub const fn contains(self, other: &ShaderStage) -> bool {
55 let mask = other.as_mask();
56 self.0 & mask.0 == mask.0
57 }
58}
59impl BitOr for ShaderStageMask {
60 type Output = Self;
61 #[inline]
62 fn bitor(self, rhs: Self) -> Self {
63 Self(self.0 | rhs.0)
64 }
65}
66impl BitOrAssign for ShaderStageMask {
67 #[inline]
68 fn bitor_assign(&mut self, rhs: Self) {
69 *self = *self | rhs
70 }
71}
72impl BitAnd for ShaderStageMask {
73 type Output = Self;
74 #[inline]
75 fn bitand(self, rhs: Self) -> Self {
76 Self(self.0 & rhs.0)
77 }
78}
79impl BitAndAssign for ShaderStageMask {
80 #[inline]
81 fn bitand_assign(&mut self, rhs: Self) {
82 *self = *self & rhs
83 }
84}
85impl BitXor for ShaderStageMask {
86 type Output = Self;
87 #[inline]
88 fn bitxor(self, rhs: Self) -> Self {
89 Self(self.0 ^ rhs.0)
90 }
91}
92impl BitXorAssign for ShaderStageMask {
93 #[inline]
94 fn bitxor_assign(&mut self, rhs: Self) {
95 *self = *self ^ rhs
96 }
97}
98impl Not for ShaderStageMask {
99 type Output = Self;
100 #[inline]
101 fn not(self) -> Self {
102 Self(!self.0)
103 }
104}
105
106#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
108#[serde(rename_all = "camelCase")]
109pub enum ShaderStage {
110 Vertex,
111 Fragment, Compute,
113 TesselationControl, TesselationEvaluation, Mesh,
116 Task, Geometry,
118 RayGeneration,
119 ClosestHit,
120 AnyHit,
121 Callable,
122 Miss,
123 Intersect,
124}
125
126impl ShaderStage {
127 pub fn from_file_name(file_name: &String) -> Option<ShaderStage> {
129 let paths = HashMap::from([
131 ("vert", ShaderStage::Vertex),
132 ("frag", ShaderStage::Fragment),
133 ("comp", ShaderStage::Compute),
134 ("task", ShaderStage::Task),
135 ("mesh", ShaderStage::Mesh),
136 ("tesc", ShaderStage::TesselationControl),
137 ("tese", ShaderStage::TesselationEvaluation),
138 ("geom", ShaderStage::Geometry),
139 ("rgen", ShaderStage::RayGeneration),
140 ("rchit", ShaderStage::ClosestHit),
141 ("rahit", ShaderStage::AnyHit),
142 ("rcall", ShaderStage::Callable),
143 ("rmiss", ShaderStage::Miss),
144 ("rint", ShaderStage::Intersect),
145 ]);
146 let extension_list = file_name.rsplit(".");
147 for extension in extension_list {
148 if let Some(stage) = paths.get(extension) {
149 return Some(stage.clone());
150 } else {
151 continue;
152 }
153 }
154 None
156 }
157 pub const fn as_mask(&self) -> ShaderStageMask {
158 match self {
159 ShaderStage::Vertex => ShaderStageMask::VERTEX,
160 ShaderStage::Fragment => ShaderStageMask::FRAGMENT,
161 ShaderStage::Compute => ShaderStageMask::COMPUTE,
162 ShaderStage::TesselationControl => ShaderStageMask::TESSELATION_CONTROL,
163 ShaderStage::TesselationEvaluation => ShaderStageMask::TESSELATION_EVALUATION,
164 ShaderStage::Mesh => ShaderStageMask::MESH,
165 ShaderStage::Task => ShaderStageMask::TASK,
166 ShaderStage::Geometry => ShaderStageMask::GEOMETRY,
167 ShaderStage::RayGeneration => ShaderStageMask::RAY_GENERATION,
168 ShaderStage::ClosestHit => ShaderStageMask::CLOSEST_HIT,
169 ShaderStage::AnyHit => ShaderStageMask::ANY_HIT,
170 ShaderStage::Callable => ShaderStageMask::CALLABLE,
171 ShaderStage::Miss => ShaderStageMask::MISS,
172 ShaderStage::Intersect => ShaderStageMask::INTERSECT,
173 }
174 }
175 pub fn graphics() -> ShaderStageMask {
177 ShaderStageMask::VERTEX
178 | ShaderStageMask::FRAGMENT
179 | ShaderStageMask::GEOMETRY
180 | ShaderStageMask::TESSELATION_CONTROL
181 | ShaderStageMask::TESSELATION_EVALUATION
182 | ShaderStageMask::TASK
183 | ShaderStageMask::MESH
184 }
185 pub fn compute() -> ShaderStageMask {
187 ShaderStageMask::COMPUTE
188 }
189 pub fn raytracing() -> ShaderStageMask {
191 ShaderStageMask::RAY_GENERATION
192 | ShaderStageMask::INTERSECT
193 | ShaderStageMask::CLOSEST_HIT
194 | ShaderStageMask::ANY_HIT
195 | ShaderStageMask::MISS
196 | ShaderStageMask::INTERSECT
197 }
198}
199
200impl FromStr for ShaderStage {
201 type Err = ();
202
203 fn from_str(input: &str) -> Result<ShaderStage, Self::Err> {
204 let lower_input = input.to_lowercase();
206 match lower_input.as_str() {
207 "vertex" => Ok(ShaderStage::Vertex),
208 "fragment" | "pixel" => Ok(ShaderStage::Fragment),
209 "compute" => Ok(ShaderStage::Compute),
210 "tesselationcontrol" | "hull" => Ok(ShaderStage::TesselationControl),
211 "tesselationevaluation" | "domain" => Ok(ShaderStage::TesselationEvaluation),
212 "mesh" => Ok(ShaderStage::Mesh),
213 "task" | "amplification" => Ok(ShaderStage::Task),
214 "geometry" => Ok(ShaderStage::Geometry),
215 "raygeneration" => Ok(ShaderStage::RayGeneration),
216 "closesthit" => Ok(ShaderStage::ClosestHit),
217 "anyhit" => Ok(ShaderStage::AnyHit),
218 "callable" => Ok(ShaderStage::Callable),
219 "miss" => Ok(ShaderStage::Miss),
220 "intersect" => Ok(ShaderStage::Intersect),
221 _ => Err(()),
222 }
223 }
224}
225impl ToString for ShaderStage {
226 fn to_string(&self) -> String {
227 match self {
228 ShaderStage::Vertex => "vertex".to_string(),
229 ShaderStage::Fragment => "fragment".to_string(),
230 ShaderStage::Compute => "compute".to_string(),
231 ShaderStage::TesselationControl => "tesselationcontrol".to_string(),
232 ShaderStage::TesselationEvaluation => "tesselationevaluation".to_string(),
233 ShaderStage::Mesh => "mesh".to_string(),
234 ShaderStage::Task => "task".to_string(),
235 ShaderStage::Geometry => "geometry".to_string(),
236 ShaderStage::RayGeneration => "raygeneration".to_string(),
237 ShaderStage::ClosestHit => "closesthit".to_string(),
238 ShaderStage::AnyHit => "anyhit".to_string(),
239 ShaderStage::Callable => "callable".to_string(),
240 ShaderStage::Miss => "miss".to_string(),
241 ShaderStage::Intersect => "intersect".to_string(),
242 }
243 }
244}
245
246impl FromStr for ShadingLanguage {
247 type Err = ();
248
249 fn from_str(input: &str) -> Result<ShadingLanguage, Self::Err> {
250 match input {
251 "wgsl" => Ok(ShadingLanguage::Wgsl),
252 "hlsl" => Ok(ShadingLanguage::Hlsl),
253 "glsl" => Ok(ShadingLanguage::Glsl),
254 _ => Err(()),
255 }
256 }
257}
258impl ToString for ShadingLanguage {
259 fn to_string(&self) -> String {
260 String::from(match &self {
261 ShadingLanguage::Wgsl => "wgsl",
262 ShadingLanguage::Hlsl => "hlsl",
263 ShadingLanguage::Glsl => "glsl",
264 })
265 }
266}
267
268pub trait ShadingLanguageTag {
270 fn get_language() -> ShadingLanguage;
272}
273
274pub struct HlslShadingLanguageTag {}
276impl ShadingLanguageTag for HlslShadingLanguageTag {
277 fn get_language() -> ShadingLanguage {
278 ShadingLanguage::Hlsl
279 }
280}
281pub struct GlslShadingLanguageTag {}
283impl ShadingLanguageTag for GlslShadingLanguageTag {
284 fn get_language() -> ShadingLanguage {
285 ShadingLanguage::Glsl
286 }
287}
288pub struct WgslShadingLanguageTag {}
290impl ShadingLanguageTag for WgslShadingLanguageTag {
291 fn get_language() -> ShadingLanguage {
292 ShadingLanguage::Wgsl
293 }
294}
295
296#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
301pub enum HlslShaderModel {
302 ShaderModel1,
303 ShaderModel1_1,
304 ShaderModel1_2,
305 ShaderModel1_3,
306 ShaderModel1_4,
307 ShaderModel2,
308 ShaderModel3,
309 ShaderModel4,
310 ShaderModel4_1,
311 ShaderModel5,
312 ShaderModel5_1,
313 ShaderModel6,
314 ShaderModel6_1,
315 ShaderModel6_2,
316 ShaderModel6_3,
317 ShaderModel6_4,
318 ShaderModel6_5,
319 ShaderModel6_6,
320 ShaderModel6_7,
321 #[default]
322 ShaderModel6_8,
323}
324
325impl HlslShaderModel {
326 pub fn earliest() -> HlslShaderModel {
328 HlslShaderModel::ShaderModel1
329 }
330 pub fn latest() -> HlslShaderModel {
332 HlslShaderModel::ShaderModel6_8
333 }
334}
335
336#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
338pub enum HlslVersion {
339 V2016,
340 V2017,
341 V2018,
342 #[default]
343 V2021,
344}
345
346#[derive(Default, Debug, Clone, PartialEq, Eq)]
348pub struct HlslCompilationParams {
349 pub shader_model: HlslShaderModel,
350 pub version: HlslVersion,
351 pub enable16bit_types: bool,
352 pub spirv: bool,
353}
354
355#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
357pub enum GlslTargetClient {
358 Vulkan1_0,
359 Vulkan1_1,
360 Vulkan1_2,
361 #[default]
362 Vulkan1_3,
363 OpenGL450,
364}
365
366impl GlslTargetClient {
367 pub fn is_opengl(&self) -> bool {
369 match *self {
370 GlslTargetClient::OpenGL450 => true,
371 _ => false,
372 }
373 }
374}
375
376#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
378pub enum GlslSpirvVersion {
379 SPIRV1_0,
380 SPIRV1_1,
381 SPIRV1_2,
382 SPIRV1_3,
383 SPIRV1_4,
384 SPIRV1_5,
385 #[default]
386 SPIRV1_6,
387}
388#[derive(Default, Debug, Clone, PartialEq, Eq)]
390pub struct GlslCompilationParams {
391 pub client: GlslTargetClient,
392 pub spirv: GlslSpirvVersion,
393}
394
/// WGSL-specific compilation settings. Currently there are none; the
/// type exists so every language has a parameter slot in `ShaderCompilationParams`.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct WgslCompilationParams {}
398
/// Environment a shader is processed in; all collections are empty by default.
#[derive(Default, Debug, Clone)]
pub struct ShaderContextParams {
    /// Macro definitions as name -> value (presumably preprocessor defines; confirm at use site).
    pub defines: HashMap<String, String>,
    /// Additional include directories/paths.
    pub includes: Vec<PathBuf>,
    /// Path substitutions as from -> to (semantics applied by the consumer; confirm at use site).
    pub path_remapping: HashMap<PathBuf, PathBuf>,
}
406
407#[derive(Default, Debug, Clone)]
409pub struct ShaderCompilationParams {
410 pub entry_point: Option<String>,
411 pub shader_stage: Option<ShaderStage>,
412 pub hlsl: HlslCompilationParams,
413 pub glsl: GlslCompilationParams,
414 pub wgsl: WgslCompilationParams,
415}
416
417#[derive(Default, Debug, Clone)]
419pub struct ShaderParams {
420 pub context: ShaderContextParams,
421 pub compilation: ShaderCompilationParams,
422}