// librashader_reflect/back/wgsl.rs

use crate::back::targets::WGSL;
use crate::back::{CompileReflectShader, CompilerBackend, FromCompilation};
use crate::error::ShaderReflectError;
use crate::front::SpirvCompilation;
use crate::reflect::naga::{Naga, NagaLoweringOptions, NagaReflect};
use naga::Module;

8/// The context for a WGSL compilation via Naga
9pub struct NagaWgslContext {
10    pub fragment: Module,
11    pub vertex: Module,
12}
13
14#[cfg(not(feature = "stable"))]
15impl FromCompilation<SpirvCompilation, Naga> for WGSL {
16    type Target = WGSL;
17    type Options = NagaLoweringOptions;
18    type Context = NagaWgslContext;
19    type Output = impl CompileReflectShader<Self::Target, SpirvCompilation, Naga>;
20
21    fn from_compilation(
22        compile: SpirvCompilation,
23    ) -> Result<CompilerBackend<Self::Output>, ShaderReflectError> {
24        Ok(CompilerBackend {
25            backend: NagaReflect::try_from(&compile)?,
26        })
27    }
28}
29
/// Compiles a [`SpirvCompilation`] to WGSL by reflecting it through Naga.
///
/// This variant is used when the `stable` feature is enabled. It mirrors the
/// non-`stable` implementation but type-erases the backend behind a boxed,
/// `Send` trait object. Unlike `impl Trait` in associated types, that compiles
/// on stable Rust.
#[cfg(feature = "stable")]
impl FromCompilation<SpirvCompilation, Naga> for WGSL {
    type Target = WGSL;
    type Options = NagaLoweringOptions;
    type Context = NagaWgslContext;
    type Output = Box<dyn CompileReflectShader<Self::Target, SpirvCompilation, Naga> + Send>;

    /// Builds a boxed Naga reflection backend from the given SPIR-V compilation.
    ///
    /// # Errors
    /// Propagates the error from `NagaReflect::try_from` (converted into a
    /// [`ShaderReflectError`]) if the compilation cannot be loaded by Naga.
    fn from_compilation(
        compile: SpirvCompilation,
    ) -> Result<CompilerBackend<Self::Output>, ShaderReflectError> {
        let reflect = NagaReflect::try_from(&compile)?;
        Ok(CompilerBackend {
            backend: Box::new(reflect),
        })
    }
}

#[cfg(test)]
mod test {
    use crate::back::targets::WGSL;
    use crate::back::{CompileShader, FromCompilation};
    use crate::reflect::naga::NagaLoweringOptions;
    use crate::reflect::semantics::{Semantic, ShaderSemantics, UniformSemantic, UniqueSemantics};
    use crate::reflect::ReflectShader;
    use librashader_common::map::{FastHashMap, ShortString};
    use librashader_preprocess::ShaderSource;

    /// End-to-end smoke test for the WGSL backend.
    ///
    /// The test runs these steps in order:
    /// 1. Loads a slang shader.
    /// 2. Compiles it to SPIR-V.
    /// 3. Reflects it with every shader parameter bound as a float-parameter
    ///    uniform.
    /// 4. Lowers it to WGSL through Naga.
    #[test]
    pub fn test_into() {
        let result = ShaderSource::load("../test/shaders_slang/crt/shaders/slotmask.slang")
            .expect("failed to load test shader");

        // Register every parameter the shader declares as a float-parameter
        // uniform so reflection can resolve them by id.
        let mut uniform_semantics: FastHashMap<ShortString, UniformSemantic> = Default::default();
        for (_, param) in result.parameters.iter() {
            uniform_semantics.insert(
                param.id.clone(),
                UniformSemantic::Unique(Semantic {
                    semantics: UniqueSemantics::FloatParameter,
                    index: (),
                }),
            );
        }

        let compilation = crate::front::SpirvCompilation::try_from(&result)
            .expect("failed to compile shader to SPIR-V");

        let mut wgsl =
            WGSL::from_compilation(compilation).expect("failed to create Naga WGSL backend");

        wgsl.reflect(
            0,
            &ShaderSemantics {
                uniform_semantics,
                texture_semantics: Default::default(),
            },
        )
        .expect("failed to reflect shader");

        let compiled = wgsl
            .compile(NagaLoweringOptions {
                write_pcb_as_ubo: true,
                sampler_bind_group: 1,
            })
            .expect("failed to lower shader to WGSL");

        let fragment = compiled.fragment.to_string();
        println!("{fragment}");
        assert!(!fragment.is_empty(), "Naga emitted an empty fragment shader");
    }
}