1use std::{any, collections::HashMap, fmt, fs, path::Path, str, sync::Arc};
2
/// File name that `serve` writes the preprocessed WGSL source to when shader
/// creation fails. It is a relative path, so it lands in the process's
/// current working directory.
const FAILURE_DUMP_NAME: &str = "_failure.wgsl";
4
/// Cooked form of a shader asset: the fully preprocessed WGSL text as raw bytes.
///
/// NOTE(review): `blade_macros::Flat` presumably provides the flat
/// (de)serialization that `blade_asset` uses to cache cooked data. Confirm.
#[derive(blade_macros::Flat)]
pub struct CookedShader<'a> {
    /// UTF-8 WGSL text produced by `parse_shader` in `cook`, decoded again in `serve`.
    data: &'a [u8],
}
9
/// Per-asset cooking options for shaders.
///
/// Shaders currently take no options, so this is a unit type whose textual
/// form is the empty string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Meta;

impl fmt::Display for Meta {
    fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result {
        // No fields, so there is nothing to write.
        Ok(())
    }
}
17
/// Runtime output of the shader baker.
pub struct Shader {
    /// The created GPU shader, or the error returned by
    /// `Context::try_create_shader` when compilation failed.
    pub raw: Result<blade_graphics::Shader, &'static str>,
}
21
22pub enum Expansion {
23 Values(HashMap<String, u32>),
24 Bool(bool),
25}
26impl Expansion {
27 pub fn from_enum<E: strum::IntoEnumIterator + fmt::Debug + Into<u32>>() -> Self {
28 Self::Values(
29 E::iter()
30 .map(|variant| (format!("{variant:?}"), variant.into()))
31 .collect(),
32 )
33 }
34 pub fn from_bitflags<F: bitflags::Flags<Bits = u32>>() -> Self {
35 Self::Values(
36 F::FLAGS
37 .iter()
38 .map(|flag| (flag.name().to_string(), flag.value().bits()))
39 .collect(),
40 )
41 }
42}
43
44pub struct Baker {
45 gpu_context: Arc<blade_graphics::Context>,
46 expansions: HashMap<String, Expansion>,
47}
48
49impl Baker {
50 pub fn new(gpu_context: &Arc<blade_graphics::Context>) -> Self {
51 Self {
52 gpu_context: Arc::clone(gpu_context),
53 expansions: HashMap::default(),
54 }
55 }
56
57 fn register<T>(&mut self, expansion: Expansion) {
58 let full_name = any::type_name::<T>();
59 let short_name = full_name.split("::").last().unwrap().to_string();
60 self.expansions.insert(short_name, expansion);
61 }
62
63 pub fn register_enum<E: strum::IntoEnumIterator + fmt::Debug + Into<u32>>(&mut self) {
64 self.register::<E>(Expansion::from_enum::<E>());
65 }
66
67 pub fn register_bitflags<F: bitflags::Flags<Bits = u32>>(&mut self) {
68 self.register::<F>(Expansion::from_bitflags::<F>());
69 }
70
71 pub fn register_bool(&mut self, name: &str, value: bool) {
72 self.expansions
73 .insert(name.to_string(), Expansion::Bool(value));
74 }
75}
76
77fn parse_impl(
78 text_raw: &[u8],
79 base_path: &Path,
80 text_out: &mut String,
81 cooker: &blade_asset::Cooker<Baker>,
82 expansions: &HashMap<String, Expansion>,
83) {
84 use std::fmt::Write as _;
85
86 let text_in = str::from_utf8(text_raw).unwrap();
87 for line in text_in.lines() {
88 if line.starts_with("#include") {
89 let include_path = match line.split('"').nth(1) {
90 Some(include) => base_path.join(include),
91 None => panic!("Unable to extract the include path from: {line}"),
92 };
93 let include = cooker.add_dependency(&include_path);
94 writeln!(text_out, "//{}", line).unwrap();
95 parse_impl(
96 &include,
97 include_path.parent().unwrap(),
98 text_out,
99 cooker,
100 expansions,
101 );
102 } else if line.starts_with("#use") {
103 let type_name = line.split_whitespace().last().unwrap();
104 match expansions[type_name] {
105 Expansion::Values(ref map) => {
106 for (key, value) in map.iter() {
107 writeln!(text_out, "const {}_{}: u32 = {}u;", type_name, key, value)
108 .unwrap();
109 }
110 }
111 Expansion::Bool(value) => {
112 writeln!(text_out, "const {}: bool = {};", type_name, value).unwrap();
113 }
114 }
115 } else {
116 *text_out += line;
117 }
118 *text_out += "\n";
119 }
120}
121
122pub fn parse_shader(
123 text_raw: &[u8],
124 cooker: &blade_asset::Cooker<Baker>,
125 expansions: &HashMap<String, Expansion>,
126) -> String {
127 let mut text_out = String::new();
128 parse_impl(text_raw, ".".as_ref(), &mut text_out, cooker, expansions);
129 text_out
130}
131
132impl blade_asset::Baker for Baker {
133 type Meta = Meta;
134 type Data<'a> = CookedShader<'a>;
135 type Output = Shader;
136 fn cook(
137 &self,
138 source: &[u8],
139 extension: &str,
140 _meta: Meta,
141 cooker: Arc<blade_asset::Cooker<Self>>,
142 _exe_context: &choir::ExecutionContext,
143 ) {
144 assert_eq!(extension, "wgsl");
145 let text_out = parse_shader(source, &cooker, &self.expansions);
146 cooker.finish(CookedShader {
147 data: text_out.as_bytes(),
148 });
149 }
150 fn serve(&self, cooked: CookedShader, _exe_context: &choir::ExecutionContext) -> Shader {
151 let source = str::from_utf8(cooked.data).unwrap();
152 let raw = self
153 .gpu_context
154 .try_create_shader(blade_graphics::ShaderDesc { source });
155 if let Err(e) = raw {
156 let _ = fs::write(FAILURE_DUMP_NAME, source);
157 log::warn!("Shader compilation failed: {e:?}, source dumped as '{FAILURE_DUMP_NAME}'.")
158 }
159 Shader { raw }
160 }
161 fn delete(&self, _output: Shader) {}
162}