spirv_to_dxil/
lib.rs

1//! Rust bindings for [spirv-to-dxil](https://gitlab.freedesktop.org/mesa/mesa/-/tree/main/src/microsoft/spirv_to_dxil).
2//!
3//! For the lower-level C interface, see the [spirv-to-dxil-sys](https://docs.rs/spirv-to-dxil-sys/) crate.
4//!
5//! ## Push Constant Buffers
//! When compiling SPIR-V shaders that use push constants, you must choose register assignments with
//! [PushConstantBufferConfig](crate::PushConstantBufferConfig) that correspond to the Root Descriptor
//! of the constant buffer the compiled shader will use to store its push constants.
8//!
9//! ## Runtime Data
10//! For some vertex and compute shaders, a constant buffer provided at runtime is required to be bound.
11//!
12//! You can check if the compiled shader requires runtime data with
13//! [`DxilObject::requires_runtime_data`](crate::DxilObject::requires_runtime_data).
14//!
15//! If your shader requires runtime data, then register assignments must be chosen in
16//! [RuntimeDataBufferConfig](crate::RuntimeDataBufferConfig).
17//!
18//! See the [`runtime`](crate::runtime) module for how to construct the expected runtime data to be bound in a constant buffer.
19mod ctypes;
20mod error;
21mod logger;
22mod object;
23pub mod runtime;
24mod specialization;
25
26pub use crate::error::SpirvToDxilError;
27pub use ctypes::*;
28pub use object::*;
29pub use specialization::*;
30pub use spirv_to_dxil_sys::DXIL_SPIRV_MAX_VIEWPORT;
31
32use crate::logger::Logger;
33use spirv_to_dxil_sys::dxil_spirv_object;
34use std::mem::MaybeUninit;
35
36fn spirv_to_dxil_inner(
37    spirv_words: &[u32],
38    specializations: Option<&[Specialization]>,
39    entry_point: &[u8],
40    stage: ShaderStage,
41    validator_version_max: ValidatorVersion,
42    runtime_conf: &RuntimeConfig,
43    dump_nir: bool,
44    logger: &spirv_to_dxil_sys::dxil_spirv_logger,
45    out: &mut MaybeUninit<dxil_spirv_object>,
46) -> Result<bool, SpirvToDxilError> {
47    if runtime_conf.push_constant_cbv.register_space > 31
48        || runtime_conf.runtime_data_cbv.register_space > 31
49    {
50        return Err(SpirvToDxilError::RegisterSpaceOverflow(std::cmp::max(
51            runtime_conf.push_constant_cbv.register_space,
52            runtime_conf.runtime_data_cbv.register_space,
53        )));
54    }
55    let num_specializations = specializations.map(|o| o.len()).unwrap_or(0) as u32;
56    let mut specializations: Option<Vec<spirv_to_dxil_sys::dxil_spirv_specialization>> =
57        specializations.map(|o| o.into_iter().map(|o| (*o).into()).collect());
58
59    let debug = spirv_to_dxil_sys::dxil_spirv_debug_options { dump_nir };
60
61    unsafe {
62        Ok(spirv_to_dxil_sys::spirv_to_dxil(
63            spirv_words.as_ptr(),
64            spirv_words.len(),
65            specializations
66                .as_mut()
67                .map_or(std::ptr::null_mut(), |x| x.as_mut_ptr()),
68            num_specializations,
69            stage,
70            entry_point.as_ptr().cast(),
71            validator_version_max,
72            &debug,
73            runtime_conf,
74            logger,
75            out.as_mut_ptr(),
76        ))
77    }
78}
79
80/// Dump the parsed NIR output of the SPIR-V to stdout.
81pub fn dump_nir(
82    spirv_words: &[u32],
83    specializations: Option<&[Specialization]>,
84    entry_point: impl AsRef<str>,
85    stage: ShaderStage,
86    validator_version_max: ValidatorVersion,
87    runtime_conf: &RuntimeConfig,
88) -> Result<bool, SpirvToDxilError> {
89    let entry_point = entry_point.as_ref();
90    let mut entry_point = String::from(entry_point).into_bytes();
91    entry_point.push(0);
92
93    let mut out = MaybeUninit::uninit();
94    spirv_to_dxil_inner(
95        spirv_words,
96        specializations,
97        &entry_point,
98        stage,
99        validator_version_max,
100        runtime_conf,
101        true,
102        &logger::DEBUG_LOGGER,
103        &mut out,
104    )
105}
106
107/// Compile SPIR-V words to a DXIL blob.
108///
109/// If `validator_version` is not `None`, then `dxil.dll` must be in the load path to output
110/// a valid blob,
111///
112/// If `validator_version` is none, validation will be skipped and the resulting blobs will
113/// be fakesigned.
114pub fn spirv_to_dxil(
115    spirv_words: &[u32],
116    specializations: Option<&[Specialization]>,
117    entry_point: impl AsRef<str>,
118    stage: ShaderStage,
119    validator_version_max: ValidatorVersion,
120    runtime_conf: &RuntimeConfig,
121) -> Result<DxilObject, SpirvToDxilError> {
122    let entry_point = entry_point.as_ref();
123    let mut entry_point = String::from(entry_point).into_bytes();
124    entry_point.push(0);
125
126    let logger = Logger::new();
127    let logger = logger.into_logger();
128    let mut out = MaybeUninit::uninit();
129
130    let result = spirv_to_dxil_inner(
131        spirv_words,
132        specializations,
133        &entry_point,
134        stage,
135        validator_version_max,
136        runtime_conf,
137        false,
138        &logger,
139        &mut out,
140    )?;
141
142    let logger = unsafe { Logger::finalize(logger) };
143
144    if result {
145        let out = unsafe { out.assume_init() };
146
147        if validator_version_max == ValidatorVersion::None {
148            let size = out.binary.size;
149            let blob = unsafe { ::core::slice::from_raw_parts_mut(out.binary.buffer as *mut u8, size) };
150            mach_siegbert_vogt_dxcsa::sign_in_place(blob);
151        }
152
153        Ok(DxilObject::new(out))
154    } else {
155        Err(SpirvToDxilError::CompilerError(logger))
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Reassemble an embedded SPIR-V byte blob into 32-bit words.
    ///
    /// Words are rebuilt from native-endian byte groups, which yields the same values as
    /// reinterpreting the bytes in place but does not depend on the byte buffer happening to
    /// be 4-byte aligned.
    ///
    /// Panics if the blob length is not a multiple of four bytes.
    fn load_words(bytes: &[u8]) -> Vec<u32> {
        assert_eq!(
            bytes.len() % 4,
            0,
            "SPIR-V blob length must be a multiple of 4 bytes"
        );
        bytes
            .chunks_exact(4)
            .map(|word| u32::from_ne_bytes([word[0], word[1], word[2], word[3]]))
            .collect()
    }

    /// Dumping NIR for the fragment shader with the default runtime configuration succeeds.
    #[test]
    fn dump_nir() {
        let fragment = load_words(include_bytes!("../test/fragment.spv"));

        super::dump_nir(
            &fragment,
            None,
            "main",
            ShaderStage::Fragment,
            ValidatorVersion::None,
            &RuntimeConfig::default(),
        )
        .expect("failed to compile");
    }

    /// The fragment shader compiles with explicit buffer register assignments, including the
    /// highest accepted register space (31).
    #[test]
    fn test_compile() {
        let fragment = load_words(include_bytes!("../test/fragment.spv"));

        super::spirv_to_dxil(
            &fragment,
            None,
            "main",
            ShaderStage::Fragment,
            ValidatorVersion::None,
            &RuntimeConfig {
                runtime_data_cbv: RuntimeDataBufferConfig {
                    register_space: 0,
                    base_shader_register: 0,
                },
                push_constant_cbv: PushConstantBufferConfig {
                    register_space: 31,
                    base_shader_register: 1,
                },
                shader_model_max: ShaderModel::ShaderModel6_1,
                ..RuntimeConfig::default()
            },
        )
        .expect("failed to compile");
    }

    /// The vertex shader compiles with the default runtime configuration.
    #[test]
    fn test_validate() {
        let vertex = load_words(include_bytes!("../test/vertex.spv"));

        super::spirv_to_dxil(
            &vertex,
            None,
            "main",
            ShaderStage::Vertex,
            ValidatorVersion::None,
            &RuntimeConfig::default(),
        )
        .expect("failed to compile");
    }
}