librashader_reflect/back/
hlsl.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
use crate::back::targets::HLSL;
use crate::back::{CompileReflectShader, CompilerBackend, FromCompilation};
use crate::error::ShaderReflectError;
use crate::front::SpirvCompilation;
use crate::reflect::cross::hlsl::HlslReflect;
use crate::reflect::cross::{CompiledProgram, SpirvCross};

/// The HLSL shader model version to target.
pub use spirv_cross2::compile::hlsl::HlslShaderModel;

/// Buffer assignment information
#[derive(Debug, Clone)]
pub struct HlslBufferAssignment {
    /// The name of the buffer
    pub name: String,
    /// The id of the buffer
    pub id: u32,
}

/// Buffer assignment information for the UBO and push-constant buffers of a shader.
#[derive(Debug, Clone, Default)]
pub struct HlslBufferAssignments {
    /// Buffer assignment information for UBO
    pub ubo: Option<HlslBufferAssignment>,
    /// Buffer assignment information for Push
    pub push: Option<HlslBufferAssignment>,
}

impl HlslBufferAssignments {
    /// Extracts the numeric id from an id-mangled name of the form `_<id>_<rest>`.
    ///
    /// Returns `None` in three cases: the name does not start with `_`, there is no
    /// second `_`, or the text between the two underscores does not parse as a `u32`.
    fn find_mangled_id(mangled_name: &str) -> Option<u32> {
        let after_prefix = mangled_name.strip_prefix('_')?;
        let (id, _) = after_prefix.split_once('_')?;
        id.parse().ok()
    }

    /// Returns `true` if `mangled_name` is exactly `<buffer_name>_<uniform_name>`.
    fn find_mangled_name(buffer_name: &str, uniform_name: &str, mangled_name: &str) -> bool {
        // Match the buffer name as a real prefix. Slicing at `buffer_name.len()` would accept
        // any name with an `_` at that offset, and would panic on names shorter than
        // `buffer_name`.
        mangled_name
            .strip_prefix(buffer_name)
            .and_then(|rest| rest.strip_prefix('_'))
            == Some(uniform_name)
    }

    /// Checks whether `mangled_name` is a mangled form of `uniform_name` that belongs to the
    /// UBO or push-constant buffer recorded here.
    ///
    /// `mangled_name` must end with `uniform_name`. It must then match one of these forms:
    /// - `_<id>_...`, where `<id>` is the id of the UBO or push buffer;
    /// - `<buffer name>_<uniform_name>`, for the UBO or push buffer;
    /// - `global_<uniform_name>`.
    pub fn contains_uniform(&self, uniform_name: &str, mangled_name: &str) -> bool {
        if !mangled_name.ends_with(uniform_name) {
            return false;
        }

        // `None` unless the name is id-mangled (`_<id>_...`).
        let mangled_id = Self::find_mangled_id(mangled_name);
        let matches_buffer = |buffer: &HlslBufferAssignment| {
            mangled_id == Some(buffer.id)
                // name prepended
                || Self::find_mangled_name(&buffer.name, uniform_name, mangled_name)
        };

        let in_stage_buffer = self.ubo.iter().chain(self.push.iter()).any(matches_buffer);

        // Sometimes SPIRV-Cross will assign variables to "global"
        in_stage_buffer || Self::find_mangled_name("global", uniform_name, mangled_name)
    }
}

/// The context for a HLSL compilation via spirv-cross.
pub struct CrossHlslContext {
    /// The compiled HLSL program.
    pub artifact: CompiledProgram<spirv_cross2::targets::Hlsl>,
    /// UBO and push-constant buffer assignments for the vertex shader.
    pub vertex_buffers: HlslBufferAssignments,
    /// UBO and push-constant buffer assignments for the fragment shader.
    pub fragment_buffers: HlslBufferAssignments,
}

#[cfg(not(feature = "stable"))]
impl FromCompilation<SpirvCompilation, SpirvCross> for HLSL {
    type Target = HLSL;
    type Options = Option<HlslShaderModel>;
    type Context = CrossHlslContext;
    type Output = impl CompileReflectShader<Self::Target, SpirvCompilation, SpirvCross>;

    /// Creates an HLSL compiler backend from a SPIR-V compilation.
    ///
    /// The concrete reflector type (`HlslReflect`) is hidden behind the opaque `Output` type.
    ///
    /// # Errors
    /// Propagates any error returned while constructing the `HlslReflect`.
    fn from_compilation(
        compile: SpirvCompilation,
    ) -> Result<CompilerBackend<Self::Output>, ShaderReflectError> {
        let reflect = HlslReflect::try_from(&compile)?;
        Ok(CompilerBackend { backend: reflect })
    }
}

#[cfg(feature = "stable")]
impl FromCompilation<SpirvCompilation, SpirvCross> for HLSL {
    type Target = HLSL;
    type Options = Option<HlslShaderModel>;
    type Context = CrossHlslContext;
    type Output = Box<dyn CompileReflectShader<Self::Target, SpirvCompilation, SpirvCross> + Send>;

    /// Creates an HLSL compiler backend from a SPIR-V compilation.
    ///
    /// On the stable feature the reflector is type-erased into a boxed trait object.
    ///
    /// # Errors
    /// Propagates any error returned while constructing the `HlslReflect`.
    fn from_compilation(
        compile: SpirvCompilation,
    ) -> Result<CompilerBackend<Self::Output>, ShaderReflectError> {
        let reflect = HlslReflect::try_from(&compile)?;
        Ok(CompilerBackend {
            backend: Box::new(reflect),
        })
    }
}

#[cfg(test)]
mod test {
    use crate::back::hlsl::HlslBufferAssignments;

    /// `find_mangled_id` only yields an id for names shaped like `_<id>_...`.
    #[test]
    pub fn mangled_id_test() {
        let cases: [(&str, Option<u32>); 6] = [
            ("_19_MVP", Some(19)),
            ("_19", None),
            ("_19_", Some(19)),
            ("19_", None),
            ("_19MVP", None),
            ("_19_29_MVP", Some(19)),
        ];

        for (mangled, expected) in cases {
            assert_eq!(
                HlslBufferAssignments::find_mangled_id(mangled),
                expected,
                "mangled name: {mangled:?}"
            );
        }
    }

    /// `find_mangled_name` accepts `<buffer>_<uniform>`.
    #[test]
    pub fn mangled_name_test() {
        let (buffer, uniform, mangled) = ("params", "MVP", "params_MVP");
        assert!(HlslBufferAssignments::find_mangled_name(
            buffer, uniform, mangled
        ));
    }
}