1mod ctypes;
20mod error;
21mod logger;
22mod object;
23pub mod runtime;
24mod specialization;
25
26pub use crate::error::SpirvToDxilError;
27pub use ctypes::*;
28pub use object::*;
29pub use specialization::*;
30pub use spirv_to_dxil_sys::DXIL_SPIRV_MAX_VIEWPORT;
31
32use crate::logger::Logger;
33use spirv_to_dxil_sys::dxil_spirv_object;
34use std::mem::MaybeUninit;
35
36fn spirv_to_dxil_inner(
37 spirv_words: &[u32],
38 specializations: Option<&[Specialization]>,
39 entry_point: &[u8],
40 stage: ShaderStage,
41 validator_version_max: ValidatorVersion,
42 runtime_conf: &RuntimeConfig,
43 dump_nir: bool,
44 logger: &spirv_to_dxil_sys::dxil_spirv_logger,
45 out: &mut MaybeUninit<dxil_spirv_object>,
46) -> Result<bool, SpirvToDxilError> {
47 if runtime_conf.push_constant_cbv.register_space > 31
48 || runtime_conf.runtime_data_cbv.register_space > 31
49 {
50 return Err(SpirvToDxilError::RegisterSpaceOverflow(std::cmp::max(
51 runtime_conf.push_constant_cbv.register_space,
52 runtime_conf.runtime_data_cbv.register_space,
53 )));
54 }
55 let num_specializations = specializations.map(|o| o.len()).unwrap_or(0) as u32;
56 let mut specializations: Option<Vec<spirv_to_dxil_sys::dxil_spirv_specialization>> =
57 specializations.map(|o| o.into_iter().map(|o| (*o).into()).collect());
58
59 let debug = spirv_to_dxil_sys::dxil_spirv_debug_options { dump_nir };
60
61 unsafe {
62 Ok(spirv_to_dxil_sys::spirv_to_dxil(
63 spirv_words.as_ptr(),
64 spirv_words.len(),
65 specializations
66 .as_mut()
67 .map_or(std::ptr::null_mut(), |x| x.as_mut_ptr()),
68 num_specializations,
69 stage,
70 entry_point.as_ptr().cast(),
71 validator_version_max,
72 &debug,
73 runtime_conf,
74 logger,
75 out.as_mut_ptr(),
76 ))
77 }
78}
79
80pub fn dump_nir(
82 spirv_words: &[u32],
83 specializations: Option<&[Specialization]>,
84 entry_point: impl AsRef<str>,
85 stage: ShaderStage,
86 validator_version_max: ValidatorVersion,
87 runtime_conf: &RuntimeConfig,
88) -> Result<bool, SpirvToDxilError> {
89 let entry_point = entry_point.as_ref();
90 let mut entry_point = String::from(entry_point).into_bytes();
91 entry_point.push(0);
92
93 let mut out = MaybeUninit::uninit();
94 spirv_to_dxil_inner(
95 spirv_words,
96 specializations,
97 &entry_point,
98 stage,
99 validator_version_max,
100 runtime_conf,
101 true,
102 &logger::DEBUG_LOGGER,
103 &mut out,
104 )
105}
106
107pub fn spirv_to_dxil(
115 spirv_words: &[u32],
116 specializations: Option<&[Specialization]>,
117 entry_point: impl AsRef<str>,
118 stage: ShaderStage,
119 validator_version_max: ValidatorVersion,
120 runtime_conf: &RuntimeConfig,
121) -> Result<DxilObject, SpirvToDxilError> {
122 let entry_point = entry_point.as_ref();
123 let mut entry_point = String::from(entry_point).into_bytes();
124 entry_point.push(0);
125
126 let logger = Logger::new();
127 let logger = logger.into_logger();
128 let mut out = MaybeUninit::uninit();
129
130 let result = spirv_to_dxil_inner(
131 spirv_words,
132 specializations,
133 &entry_point,
134 stage,
135 validator_version_max,
136 runtime_conf,
137 false,
138 &logger,
139 &mut out,
140 )?;
141
142 let logger = unsafe { Logger::finalize(logger) };
143
144 if result {
145 let out = unsafe { out.assume_init() };
146
147 if validator_version_max == ValidatorVersion::None {
148 let size = out.binary.size;
149 let blob = unsafe { ::core::slice::from_raw_parts_mut(out.binary.buffer as *mut u8, size) };
150 mach_siegbert_vogt_dxcsa::sign_in_place(blob);
151 }
152
153 Ok(DxilObject::new(out))
154 } else {
155 Err(SpirvToDxilError::CompilerError(logger))
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a SPIR-V binary as native-endian 32-bit words.
    ///
    /// The words are copied into a fresh `Vec<u32>` instead of reinterpreting the bytes in
    /// place, because neither `include_bytes!` nor a `Vec<u8>` guarantees 4-byte alignment.
    fn load_spirv(bytes: &[u8]) -> Vec<u32> {
        assert_eq!(bytes.len() % 4, 0, "SPIR-V binary length must be a multiple of 4 bytes");
        bytes
            .chunks_exact(4)
            .map(|word| u32::from_ne_bytes([word[0], word[1], word[2], word[3]]))
            .collect()
    }

    #[test]
    fn dump_nir() {
        let fragment = load_spirv(include_bytes!("../test/fragment.spv"));

        // `dump_nir` reports compiler failure as `Ok(false)`, so the flag must be checked too.
        let compiled = super::dump_nir(
            &fragment,
            None,
            "main",
            ShaderStage::Fragment,
            ValidatorVersion::None,
            &RuntimeConfig::default(),
        )
        .expect("failed to compile");
        assert!(compiled, "failed to compile");
    }

    #[test]
    fn test_compile() {
        let fragment = load_spirv(include_bytes!("../test/fragment.spv"));

        super::spirv_to_dxil(
            &fragment,
            None,
            "main",
            ShaderStage::Fragment,
            ValidatorVersion::None,
            &RuntimeConfig {
                runtime_data_cbv: RuntimeDataBufferConfig {
                    register_space: 0,
                    base_shader_register: 0,
                },
                push_constant_cbv: PushConstantBufferConfig {
                    register_space: 31,
                    base_shader_register: 1,
                },
                shader_model_max: ShaderModel::ShaderModel6_1,
                ..RuntimeConfig::default()
            },
        )
        .expect("failed to compile");
    }

    #[test]
    fn test_validate() {
        let vertex = load_spirv(include_bytes!("../test/vertex.spv"));

        super::spirv_to_dxil(
            &vertex,
            None,
            "main",
            ShaderStage::Vertex,
            ValidatorVersion::None,
            &RuntimeConfig::default(),
        )
        .expect("failed to compile");
    }
}