blade_render/
shader.rs

1use std::{any, collections::HashMap, fmt, fs, path::Path, str, sync::Arc};
2
/// File name that receives the preprocessed WGSL text whenever shader creation
/// fails in `serve`. Being a relative path, it is written relative to the
/// process's current working directory.
const FAILURE_DUMP_NAME: &str = "_failure.wgsl";
4
5#[derive(blade_macros::Flat)]
6pub struct CookedShader<'a> {
7    data: &'a [u8],
8}
9
/// Cooking metadata for shader assets.
///
/// It carries no settings; it exists to fill the `Meta` slot of `blade_asset::Baker`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Meta;

impl fmt::Display for Meta {
    /// Renders as an empty string, since there are no settings to show.
    fn fmt(&self, _formatter: &mut fmt::Formatter) -> fmt::Result {
        // Intentionally writes nothing.
        Ok(())
    }
}
17
18pub struct Shader {
19    pub raw: Result<blade_graphics::Shader, &'static str>,
20}
21
22pub enum Expansion {
23    Values(HashMap<String, u32>),
24    Bool(bool),
25}
26impl Expansion {
27    pub fn from_enum<E: strum::IntoEnumIterator + fmt::Debug + Into<u32>>() -> Self {
28        Self::Values(
29            E::iter()
30                .map(|variant| (format!("{variant:?}"), variant.into()))
31                .collect(),
32        )
33    }
34    pub fn from_bitflags<F: bitflags::Flags<Bits = u32>>() -> Self {
35        Self::Values(
36            F::FLAGS
37                .iter()
38                .map(|flag| (flag.name().to_string(), flag.value().bits()))
39                .collect(),
40        )
41    }
42}
43
44pub struct Baker {
45    gpu_context: Arc<blade_graphics::Context>,
46    expansions: HashMap<String, Expansion>,
47}
48
49impl Baker {
50    pub fn new(gpu_context: &Arc<blade_graphics::Context>) -> Self {
51        Self {
52            gpu_context: Arc::clone(gpu_context),
53            expansions: HashMap::default(),
54        }
55    }
56
57    fn register<T>(&mut self, expansion: Expansion) {
58        let full_name = any::type_name::<T>();
59        let short_name = full_name.split("::").last().unwrap().to_string();
60        self.expansions.insert(short_name, expansion);
61    }
62
63    pub fn register_enum<E: strum::IntoEnumIterator + fmt::Debug + Into<u32>>(&mut self) {
64        self.register::<E>(Expansion::from_enum::<E>());
65    }
66
67    pub fn register_bitflags<F: bitflags::Flags<Bits = u32>>(&mut self) {
68        self.register::<F>(Expansion::from_bitflags::<F>());
69    }
70
71    pub fn register_bool(&mut self, name: &str, value: bool) {
72        self.expansions
73            .insert(name.to_string(), Expansion::Bool(value));
74    }
75}
76
77fn parse_impl(
78    text_raw: &[u8],
79    base_path: &Path,
80    text_out: &mut String,
81    cooker: &blade_asset::Cooker<Baker>,
82    expansions: &HashMap<String, Expansion>,
83) {
84    use std::fmt::Write as _;
85
86    let text_in = str::from_utf8(text_raw).unwrap();
87    for line in text_in.lines() {
88        if line.starts_with("#include") {
89            let include_path = match line.split('"').nth(1) {
90                Some(include) => base_path.join(include),
91                None => panic!("Unable to extract the include path from: {line}"),
92            };
93            let include = cooker.add_dependency(&include_path);
94            writeln!(text_out, "//{}", line).unwrap();
95            parse_impl(
96                &include,
97                include_path.parent().unwrap(),
98                text_out,
99                cooker,
100                expansions,
101            );
102        } else if line.starts_with("#use") {
103            let type_name = line.split_whitespace().last().unwrap();
104            match expansions[type_name] {
105                Expansion::Values(ref map) => {
106                    for (key, value) in map.iter() {
107                        writeln!(text_out, "const {}_{}: u32 = {}u;", type_name, key, value)
108                            .unwrap();
109                    }
110                }
111                Expansion::Bool(value) => {
112                    writeln!(text_out, "const {}: bool = {};", type_name, value).unwrap();
113                }
114            }
115        } else {
116            *text_out += line;
117        }
118        *text_out += "\n";
119    }
120}
121
122pub fn parse_shader(
123    text_raw: &[u8],
124    cooker: &blade_asset::Cooker<Baker>,
125    expansions: &HashMap<String, Expansion>,
126) -> String {
127    let mut text_out = String::new();
128    parse_impl(text_raw, ".".as_ref(), &mut text_out, cooker, expansions);
129    text_out
130}
131
132impl blade_asset::Baker for Baker {
133    type Meta = Meta;
134    type Data<'a> = CookedShader<'a>;
135    type Output = Shader;
136    fn cook(
137        &self,
138        source: &[u8],
139        extension: &str,
140        _meta: Meta,
141        cooker: Arc<blade_asset::Cooker<Self>>,
142        _exe_context: &choir::ExecutionContext,
143    ) {
144        assert_eq!(extension, "wgsl");
145        let text_out = parse_shader(source, &cooker, &self.expansions);
146        cooker.finish(CookedShader {
147            data: text_out.as_bytes(),
148        });
149    }
150    fn serve(&self, cooked: CookedShader, _exe_context: &choir::ExecutionContext) -> Shader {
151        let source = str::from_utf8(cooked.data).unwrap();
152        let raw = self
153            .gpu_context
154            .try_create_shader(blade_graphics::ShaderDesc { source });
155        if let Err(e) = raw {
156            let _ = fs::write(FAILURE_DUMP_NAME, source);
157            log::warn!("Shader compilation failed: {e:?}, source dumped as '{FAILURE_DUMP_NAME}'.")
158        }
159        Shader { raw }
160    }
161    fn delete(&self, _output: Shader) {}
162}