librashader_reflect/back/
hlsl.rs

1use crate::back::targets::HLSL;
2use crate::back::{CompileReflectShader, CompilerBackend, FromCompilation};
3use crate::error::ShaderReflectError;
4use crate::front::SpirvCompilation;
5use crate::reflect::cross::hlsl::HlslReflect;
6use crate::reflect::cross::{CompiledProgram, SpirvCross};
7
8/// The HLSL shader model version to target.
9pub use spirv_cross2::compile::hlsl::HlslShaderModel;
10
/// Assignment information for a single buffer: the name and numeric id that
/// are matched against mangled uniform names.
#[derive(Debug, Clone)]
pub struct HlslBufferAssignment {
    /// The name of the buffer.
    ///
    /// `HlslBufferAssignments::contains_uniform` uses it to recognise mangled
    /// names in which a buffer name is prepended to the uniform name with a
    /// `_` separator.
    pub name: String,
    /// The id of the buffer.
    ///
    /// `HlslBufferAssignments::contains_uniform` compares it against the
    /// number parsed from mangled names of the form `_<id>_<rest>`.
    // NOTE(review): presumably the SPIR-V id of the buffer variable — confirm
    // against the code that fills this struct in.
    pub id: u32,
}
19
/// The optional UBO and push-constant buffer assignments that mangled uniform
/// names are matched against.
///
/// Both fields default to `None`.
#[derive(Debug, Clone, Default)]
pub struct HlslBufferAssignments {
    /// Buffer assignment information for the UBO, if any.
    pub ubo: Option<HlslBufferAssignment>,
    /// Buffer assignment information for the push constant buffer, if any.
    pub push: Option<HlslBufferAssignment>,
}
28
29impl HlslBufferAssignments {
30    fn find_mangled_id(mangled_name: &str) -> Option<u32> {
31        if !mangled_name.starts_with("_") {
32            return None;
33        }
34
35        let Some(next_underscore) = mangled_name[1..].find("_") else {
36            return None;
37        };
38
39        mangled_name[1..next_underscore + 1].parse().ok()
40    }
41
42    fn find_mangled_name(buffer_name: &str, uniform_name: &str, mangled_name: &str) -> bool {
43        // name prependded
44        if mangled_name[buffer_name.len()..].starts_with("_")
45            && &mangled_name[buffer_name.len() + 1..] == uniform_name
46        {
47            return true;
48        }
49        false
50    }
51
52    // Check if the mangled name matches.
53    pub fn contains_uniform(&self, uniform_name: &str, mangled_name: &str) -> bool {
54        let is_likely_id_mangled = mangled_name.starts_with("_");
55        if !mangled_name.ends_with(uniform_name) {
56            return false;
57        }
58
59        if let Some(ubo) = &self.ubo {
60            if is_likely_id_mangled {
61                if let Some(id) = Self::find_mangled_id(mangled_name) {
62                    if id == ubo.id {
63                        return true;
64                    }
65                }
66            }
67
68            // name prependded
69            if Self::find_mangled_name(&ubo.name, uniform_name, mangled_name) {
70                return true;
71            }
72        }
73
74        if let Some(push) = &self.push {
75            if is_likely_id_mangled {
76                if let Some(id) = Self::find_mangled_id(mangled_name) {
77                    if id == push.id {
78                        return true;
79                    }
80                }
81            }
82
83            // name prependded
84            if Self::find_mangled_name(&push.name, uniform_name, mangled_name) {
85                return true;
86            }
87        }
88
89        // Sometimes SPIRV-cross will assign variables to "global"
90        if Self::find_mangled_name("global", uniform_name, mangled_name) {
91            return true;
92        }
93
94        false
95    }
96}
97
/// The context for a HLSL compilation via spirv-cross.
///
/// Bundles the compiled program with the buffer assignments for two sets of
/// buffers.
pub struct CrossHlslContext {
    /// The compiled HLSL program.
    pub artifact: CompiledProgram<spirv_cross2::targets::Hlsl>,
    /// Buffer assignment information.
    // NOTE(review): presumably for the vertex shader stage, going by the
    // field name — confirm against the code that builds this context.
    pub vertex_buffers: HlslBufferAssignments,
    /// Buffer assignment information.
    // NOTE(review): presumably for the fragment shader stage, going by the
    // field name — confirm against the code that builds this context.
    pub fragment_buffers: HlslBufferAssignments,
}
105
106#[cfg(not(feature = "stable"))]
107impl FromCompilation<SpirvCompilation, SpirvCross> for HLSL {
108    type Target = HLSL;
109    type Options = Option<HlslShaderModel>;
110    type Context = CrossHlslContext;
111    type Output = impl CompileReflectShader<Self::Target, SpirvCompilation, SpirvCross>;
112
113    fn from_compilation(
114        compile: SpirvCompilation,
115    ) -> Result<CompilerBackend<Self::Output>, ShaderReflectError> {
116        Ok(CompilerBackend {
117            backend: HlslReflect::try_from(&compile)?,
118        })
119    }
120}
121
#[cfg(feature = "stable")]
impl FromCompilation<SpirvCompilation, SpirvCross> for HLSL {
    type Target = HLSL;
    type Options = Option<HlslShaderModel>;
    type Context = CrossHlslContext;
    type Output = Box<dyn CompileReflectShader<Self::Target, SpirvCompilation, SpirvCross> + Send>;

    /// Builds a [`CompilerBackend`] around a boxed [`HlslReflect`] created
    /// from `compile`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `HlslReflect::try_from` if the reflection
    /// object cannot be created.
    fn from_compilation(
        compile: SpirvCompilation,
    ) -> Result<CompilerBackend<Self::Output>, ShaderReflectError> {
        let reflect = HlslReflect::try_from(&compile)?;
        Ok(CompilerBackend {
            backend: Box::new(reflect),
        })
    }
}
137
#[cfg(test)]
mod test {
    use crate::back::hlsl::HlslBufferAssignments;

    /// Id-mangled names yield their leading id only when it is delimited by
    /// underscores on both sides.
    #[test]
    pub fn mangled_id_test() {
        let cases = [
            ("_19_MVP", Some(19)),
            ("_19", None),
            ("_19_", Some(19)),
            ("19_", None),
            ("_19MVP", None),
            ("_19_29_MVP", Some(19)),
        ];

        for (mangled, expected) in cases {
            assert_eq!(HlslBufferAssignments::find_mangled_id(mangled), expected);
        }
    }

    /// A uniform name with the buffer name prepended is recognised.
    #[test]
    pub fn mangled_name_test() {
        let matched = HlslBufferAssignments::find_mangled_name("params", "MVP", "params_MVP");
        assert!(matched);
    }
}