// Module-wide clippy exceptions:
// - too_many_arguments: the DXC wrappers mirror the C API shape; e.g.
//   `DxcCompiler::compile` takes eight parameters including `self`.
// - new_without_default / type_complexity: presumably triggered by FFI-shaped
//   signatures such as `Result<Symbol<DxcCreateInstanceProc<T>>>` or
//   macro-generated constructors — verify before removing.
#![allow(
    clippy::too_many_arguments,
    clippy::new_without_default,
    clippy::type_complexity
)]
6
7use crate::ffi::*;
8use crate::os::{HRESULT, LPCWSTR, LPWSTR, WCHAR};
9use crate::utils::{from_wide, to_wide, HassleError, Result};
10use com::{class, interfaces::IUnknown, production::Class, production::ClassAllocation, Interface};
11use libloading::{library_filename, Library, Symbol};
12use std::cell::RefCell;
13use std::ops::Deref;
14use std::path::PathBuf;
15use std::pin::Pin;
16
17pub struct DxcBlob {
18 inner: IDxcBlob,
19}
20
21impl DxcBlob {
22 fn new(inner: IDxcBlob) -> Self {
23 Self { inner }
24 }
25
26 pub fn as_slice<T>(&self) -> &[T] {
27 unsafe {
28 std::slice::from_raw_parts(
29 self.inner.get_buffer_pointer().cast(),
30 self.inner.get_buffer_size() / std::mem::size_of::<T>(),
31 )
32 }
33 }
34
35 pub fn as_mut_slice<T>(&mut self) -> &mut [T] {
36 unsafe {
37 std::slice::from_raw_parts_mut(
38 self.inner.get_buffer_pointer().cast(),
39 self.inner.get_buffer_size() / std::mem::size_of::<T>(),
40 )
41 }
42 }
43
44 pub fn to_vec<T>(&self) -> Vec<T>
45 where
46 T: Clone,
47 {
48 self.as_slice().to_vec()
49 }
50}
51
52impl AsRef<[u8]> for DxcBlob {
53 fn as_ref(&self) -> &[u8] {
54 self.as_slice()
55 }
56}
57
58impl AsMut<[u8]> for DxcBlob {
59 fn as_mut(&mut self) -> &mut [u8] {
60 self.as_mut_slice()
61 }
62}
63
64pub struct DxcBlobEncoding {
65 inner: IDxcBlobEncoding,
66}
67
68impl DxcBlobEncoding {
69 fn new(inner: IDxcBlobEncoding) -> Self {
70 Self { inner }
71 }
72}
73
74impl From<DxcBlobEncoding> for DxcBlob {
75 fn from(encoded_blob: DxcBlobEncoding) -> Self {
76 DxcBlob::new(encoded_blob.inner.query_interface::<IDxcBlob>().unwrap())
77 }
78}
79
80pub struct DxcOperationResult {
81 inner: IDxcOperationResult,
82}
83
84impl DxcOperationResult {
85 fn new(inner: IDxcOperationResult) -> Self {
86 Self { inner }
87 }
88
89 pub fn get_status(&self) -> Result<u32> {
90 let mut status: u32 = 0;
91 unsafe { self.inner.get_status(&mut status) }.result_with_success(status)
92 }
93
94 pub fn get_result(&self) -> Result<DxcBlob> {
95 let mut blob = None;
96 unsafe { self.inner.get_result(&mut blob) }.result()?;
97 Ok(DxcBlob::new(blob.unwrap()))
98 }
99
100 pub fn get_error_buffer(&self) -> Result<DxcBlobEncoding> {
101 let mut blob = None;
102
103 unsafe { self.inner.get_error_buffer(&mut blob) }.result()?;
104 Ok(DxcBlobEncoding::new(blob.unwrap()))
105 }
106}
107
/// User-provided resolver for include requests made by DXC while compiling or
/// preprocessing (see the `include_handler` parameter of [`DxcCompiler::compile`]).
pub trait DxcIncludeHandler {
    /// Returns the contents of `filename`, or `None` if it cannot be provided.
    ///
    /// `filename` is DXC's wide-string request converted to a `String`. Returning
    /// `None` makes the wrapper report a file-not-found HRESULT (0x80070002) to DXC.
    fn load_source(&mut self, filename: String) -> Option<String>;
}
111
112class! {
113 #[no_class_factory]
114 class DxcIncludeHandlerWrapper: IDxcIncludeHandler {
115 library: &'static DxcLibrary,
124 handler: RefCell<&'static mut dyn DxcIncludeHandler>,
125
126 pinned: RefCell<Vec<Pin<String>>>,
127 }
128
129 impl IDxcIncludeHandler for DxcIncludeHandlerWrapper {
130 fn load_source(&self, filename: LPCWSTR, include_source: *mut Option<IDxcBlob>) -> HRESULT {
131 let filename = crate::utils::from_wide(filename);
132
133 let mut handler = self.handler.borrow_mut();
134 let source = handler.load_source(filename);
135
136 if let Some(source) = source {
137 let source = Pin::new(source);
138 let blob = self.library
139 .create_blob_with_encoding_from_str(&source)
140 .unwrap();
141
142 unsafe { *include_source = Some(DxcBlob::from(blob).inner) };
143 self.pinned.borrow_mut().push(source);
144
145 0
147 } else {
148 -2_147_024_894 }
150 .into()
151 }
152 }
153}
154
155struct LocalClassAllocation<T: Class>(core::pin::Pin<Box<T>>);
165
166impl<T: Class> LocalClassAllocation<T> {
167 fn new(allocation: ClassAllocation<T>) -> Self {
168 let inner: core::mem::ManuallyDrop<core::pin::Pin<Box<T>>> =
173 unsafe { std::mem::transmute(allocation) };
174
175 Self(core::mem::ManuallyDrop::into_inner(inner))
176 }
177
178 }
184
185impl<T: Class> Deref for LocalClassAllocation<T> {
186 type Target = core::pin::Pin<Box<T>>;
187
188 fn deref(&self) -> &Self::Target {
189 &self.0
190 }
191}
192
193impl<T: Class> Drop for LocalClassAllocation<T> {
194 fn drop(&mut self) {
195 assert_eq!(
197 unsafe { self.0.dec_ref_count() },
198 0,
199 "COM object is still referenced"
200 );
201 }
204}
205
206impl DxcIncludeHandlerWrapper {
207 unsafe fn create_include_handler(
210 library: &'_ DxcLibrary,
211 include_handler: &'_ mut dyn DxcIncludeHandler,
212 ) -> LocalClassAllocation<DxcIncludeHandlerWrapper> {
213 #[allow(clippy::missing_transmute_annotations)]
214 LocalClassAllocation::new(Self::allocate(
215 std::mem::transmute(library),
216 RefCell::new(std::mem::transmute(include_handler)),
217 RefCell::new(vec![]),
218 ))
219 }
220}
221
222pub struct DxcCompiler {
223 inner: IDxcCompiler2,
224 library: DxcLibrary,
225}
226
227impl DxcCompiler {
228 fn new(inner: IDxcCompiler2, library: DxcLibrary) -> Self {
229 Self { inner, library }
230 }
231
232 fn prep_defines(
233 defines: &[(&str, Option<&str>)],
234 wide_defines: &mut Vec<(Vec<WCHAR>, Vec<WCHAR>)>,
235 dxc_defines: &mut Vec<DxcDefine>,
236 ) {
237 for (name, value) in defines {
238 if value.is_none() {
239 wide_defines.push((to_wide(name), to_wide("1")));
240 } else {
241 wide_defines.push((to_wide(name), to_wide(value.unwrap())));
242 }
243 }
244
245 for (ref name, ref value) in wide_defines {
246 dxc_defines.push(DxcDefine {
247 name: name.as_ptr(),
248 value: value.as_ptr(),
249 });
250 }
251 }
252
253 fn prep_args(args: &[&str], wide_args: &mut Vec<Vec<WCHAR>>, dxc_args: &mut Vec<LPCWSTR>) {
254 for a in args {
255 wide_args.push(to_wide(a));
256 }
257
258 for a in wide_args {
259 dxc_args.push(a.as_ptr());
260 }
261 }
262
263 pub fn compile(
264 &self,
265 blob: &DxcBlobEncoding,
266 source_name: &str,
267 entry_point: &str,
268 target_profile: &str,
269 args: &[&str],
270 include_handler: Option<&mut dyn DxcIncludeHandler>,
271 defines: &[(&str, Option<&str>)],
272 ) -> Result<DxcOperationResult, (DxcOperationResult, HRESULT)> {
273 let mut wide_args = vec![];
274 let mut dxc_args = vec![];
275 Self::prep_args(args, &mut wide_args, &mut dxc_args);
276
277 let mut wide_defines = vec![];
278 let mut dxc_defines = vec![];
279 Self::prep_defines(defines, &mut wide_defines, &mut dxc_defines);
280
281 let include_handler = include_handler.map(|include_handler| unsafe {
283 DxcIncludeHandlerWrapper::create_include_handler(&self.library, include_handler)
284 });
285 let include_handler = include_handler
287 .as_ref()
288 .map(|i| i.query_interface().unwrap());
289
290 let mut result = None;
291 let result_hr = unsafe {
292 self.inner.compile(
293 &blob.inner,
294 to_wide(source_name).as_ptr(),
295 to_wide(entry_point).as_ptr(),
296 to_wide(target_profile).as_ptr(),
297 dxc_args.as_ptr(),
298 dxc_args.len() as u32,
299 dxc_defines.as_ptr(),
300 dxc_defines.len() as u32,
301 &include_handler,
302 &mut result,
303 )
304 };
305
306 let result = result.unwrap();
307
308 let mut compile_error = 0u32;
309 let status_hr = unsafe { result.get_status(&mut compile_error) };
310
311 if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
312 Ok(DxcOperationResult::new(result))
313 } else {
314 Err((DxcOperationResult::new(result), result_hr))
315 }
316 }
317
318 pub fn compile_with_debug(
319 &self,
320 blob: &DxcBlobEncoding,
321 source_name: &str,
322 entry_point: &str,
323 target_profile: &str,
324 args: &[&str],
325 include_handler: Option<&mut dyn DxcIncludeHandler>,
326 defines: &[(&str, Option<&str>)],
327 ) -> Result<(DxcOperationResult, String, DxcBlob), (DxcOperationResult, HRESULT)> {
328 let mut wide_args = vec![];
329 let mut dxc_args = vec![];
330 Self::prep_args(args, &mut wide_args, &mut dxc_args);
331
332 let mut wide_defines = vec![];
333 let mut dxc_defines = vec![];
334 Self::prep_defines(defines, &mut wide_defines, &mut dxc_defines);
335
336 let include_handler = include_handler.map(|include_handler| unsafe {
338 DxcIncludeHandlerWrapper::create_include_handler(&self.library, include_handler)
339 });
340 let include_handler = include_handler
341 .as_ref()
342 .map(|i| i.query_interface().unwrap());
343
344 let mut result = None;
345 let mut debug_blob = None;
346 let mut debug_filename: LPWSTR = std::ptr::null_mut();
347
348 let result_hr = unsafe {
349 self.inner.compile_with_debug(
350 &blob.inner,
351 to_wide(source_name).as_ptr(),
352 to_wide(entry_point).as_ptr(),
353 to_wide(target_profile).as_ptr(),
354 dxc_args.as_ptr(),
355 dxc_args.len() as u32,
356 dxc_defines.as_ptr(),
357 dxc_defines.len() as u32,
358 include_handler,
359 &mut result,
360 &mut debug_filename,
361 &mut debug_blob,
362 )
363 };
364 let result = result.unwrap();
365 let debug_blob = debug_blob.unwrap();
366
367 let mut compile_error = 0u32;
368 let status_hr = unsafe { result.get_status(&mut compile_error) };
369
370 if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
371 Ok((
372 DxcOperationResult::new(result),
373 from_wide(debug_filename),
374 DxcBlob::new(debug_blob),
375 ))
376 } else {
377 Err((DxcOperationResult::new(result), result_hr))
378 }
379 }
380
381 pub fn preprocess(
382 &self,
383 blob: &DxcBlobEncoding,
384 source_name: &str,
385 args: &[&str],
386 include_handler: Option<&mut dyn DxcIncludeHandler>,
387 defines: &[(&str, Option<&str>)],
388 ) -> Result<DxcOperationResult, (DxcOperationResult, HRESULT)> {
389 let mut wide_args = vec![];
390 let mut dxc_args = vec![];
391 Self::prep_args(args, &mut wide_args, &mut dxc_args);
392
393 let mut wide_defines = vec![];
394 let mut dxc_defines = vec![];
395 Self::prep_defines(defines, &mut wide_defines, &mut dxc_defines);
396
397 let include_handler = include_handler.map(|include_handler| unsafe {
399 DxcIncludeHandlerWrapper::create_include_handler(&self.library, include_handler)
400 });
401 let include_handler = include_handler
402 .as_ref()
403 .map(|i| i.query_interface().unwrap());
404
405 let mut result = None;
406 let result_hr = unsafe {
407 self.inner.preprocess(
408 &blob.inner,
409 to_wide(source_name).as_ptr(),
410 dxc_args.as_ptr(),
411 dxc_args.len() as u32,
412 dxc_defines.as_ptr(),
413 dxc_defines.len() as u32,
414 include_handler,
415 &mut result,
416 )
417 };
418
419 let result = result.unwrap();
420
421 let mut compile_error = 0u32;
422 let status_hr = unsafe { result.get_status(&mut compile_error) };
423
424 if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
425 Ok(DxcOperationResult::new(result))
426 } else {
427 Err((DxcOperationResult::new(result), result_hr))
428 }
429 }
430
431 pub fn disassemble(&self, blob: &DxcBlob) -> Result<DxcBlobEncoding> {
432 let mut result_blob = None;
433 unsafe { self.inner.disassemble(&blob.inner, &mut result_blob) }.result()?;
434 Ok(DxcBlobEncoding::new(result_blob.unwrap()))
435 }
436}
437
438#[derive(Clone)]
439pub struct DxcLibrary {
440 inner: IDxcLibrary,
441}
442
443impl DxcLibrary {
444 fn new(inner: IDxcLibrary) -> Self {
445 Self { inner }
446 }
447
448 pub fn create_blob_with_encoding(&self, data: &[u8]) -> Result<DxcBlobEncoding> {
449 let mut blob = None;
450
451 unsafe {
452 self.inner.create_blob_with_encoding_from_pinned(
453 data.as_ptr().cast(),
454 data.len() as u32,
455 0, &mut blob,
457 )
458 }
459 .result()?;
460 Ok(DxcBlobEncoding::new(blob.unwrap()))
461 }
462
463 pub fn create_blob_with_encoding_from_str(&self, text: &str) -> Result<DxcBlobEncoding> {
464 let mut blob = None;
465 const CP_UTF8: u32 = 65001; unsafe {
468 self.inner.create_blob_with_encoding_from_pinned(
469 text.as_ptr().cast(),
470 text.len() as u32,
471 CP_UTF8,
472 &mut blob,
473 )
474 }
475 .result()?;
476 Ok(DxcBlobEncoding::new(blob.unwrap()))
477 }
478
479 pub fn get_blob_as_string(&self, blob: &DxcBlob) -> Result<String> {
480 let mut blob_utf8 = None;
481
482 unsafe { self.inner.get_blob_as_utf8(&blob.inner, &mut blob_utf8) }.result()?;
483
484 let blob_utf8 = blob_utf8.unwrap();
485
486 Ok(String::from_utf8(DxcBlob::new(blob_utf8.query_interface().unwrap()).to_vec()).unwrap())
487 }
488}
489
490#[derive(Debug)]
491pub struct Dxc {
492 dxc_lib: Library,
493}
494
495impl Dxc {
496 pub fn new(lib_path: Option<PathBuf>) -> Result<Self> {
499 let lib_path = lib_path.unwrap_or_else(|| PathBuf::from(library_filename("dxcompiler")));
500
501 let dxc_lib =
502 unsafe { Library::new(&lib_path) }.map_err(|e| HassleError::LoadLibraryError {
503 filename: lib_path,
504 inner: e,
505 })?;
506
507 Ok(Self { dxc_lib })
508 }
509
510 pub(crate) fn get_dxc_create_instance<T>(&self) -> Result<Symbol<DxcCreateInstanceProc<T>>> {
511 Ok(unsafe { self.dxc_lib.get(b"DxcCreateInstance\0")? })
512 }
513
514 pub fn create_compiler(&self) -> Result<DxcCompiler> {
515 let mut compiler = None;
516
517 self.get_dxc_create_instance()?(&CLSID_DxcCompiler, &IDxcCompiler2::IID, &mut compiler)
518 .result()?;
519 Ok(DxcCompiler::new(
520 compiler.unwrap(),
521 self.create_library().unwrap(),
522 ))
523 }
524
525 pub fn create_library(&self) -> Result<DxcLibrary> {
526 let mut library = None;
527 self.get_dxc_create_instance()?(&CLSID_DxcLibrary, &IDxcLibrary::IID, &mut library)
528 .result()?;
529 Ok(DxcLibrary::new(library.unwrap()))
530 }
531
532 pub fn create_reflector(&self) -> Result<DxcReflector> {
533 let mut reflector = None;
534
535 self.get_dxc_create_instance()?(
536 &CLSID_DxcContainerReflection,
537 &IDxcContainerReflection::IID,
538 &mut reflector,
539 )
540 .result()?;
541 Ok(DxcReflector::new(reflector.unwrap()))
542 }
543}
544
545pub struct DxcValidator {
546 inner: IDxcValidator,
547}
548
549pub type DxcValidatorVersion = (u32, u32);
550
551impl DxcValidator {
552 fn new(inner: IDxcValidator) -> Self {
553 Self { inner }
554 }
555
556 pub fn version(&self) -> Result<DxcValidatorVersion> {
557 let version = self
558 .inner
559 .query_interface::<IDxcVersionInfo>()
560 .ok_or(HassleError::Win32Error(HRESULT(com::sys::E_NOINTERFACE)))?;
561
562 let mut major = 0;
563 let mut minor = 0;
564
565 unsafe { version.get_version(&mut major, &mut minor) }.result_with_success((major, minor))
566 }
567
568 pub fn validate(&self, blob: DxcBlob) -> Result<DxcBlob, (DxcOperationResult, HassleError)> {
569 let mut result = None;
570 let result_hr = unsafe {
571 self.inner
572 .validate(&blob.inner, DXC_VALIDATOR_FLAGS_IN_PLACE_EDIT, &mut result)
573 };
574
575 let result = result.unwrap();
576
577 let mut validate_status = 0u32;
578 let status_hr = unsafe { result.get_status(&mut validate_status) };
579
580 if !result_hr.is_err() && !status_hr.is_err() && validate_status == 0 {
581 Ok(blob)
582 } else {
583 Err((
584 DxcOperationResult::new(result),
585 HassleError::Win32Error(result_hr),
586 ))
587 }
588 }
589}
590
591pub struct Reflection {
592 inner: ID3D12ShaderReflection,
593}
594impl Reflection {
595 fn new(inner: ID3D12ShaderReflection) -> Self {
596 Self { inner }
597 }
598
599 pub fn thread_group_size(&self) -> [u32; 3] {
600 let (mut size_x, mut size_y, mut size_z) = (0u32, 0u32, 0u32);
601 unsafe {
602 self.inner
603 .get_thread_group_size(&mut size_x, &mut size_y, &mut size_z)
604 };
605 [size_x, size_y, size_z]
606 }
607}
608
609pub struct DxcReflector {
610 inner: IDxcContainerReflection,
611}
612impl DxcReflector {
613 fn new(inner: IDxcContainerReflection) -> Self {
614 Self { inner }
615 }
616
617 pub fn reflect(&self, blob: DxcBlob) -> Result<Reflection> {
618 let result_hr = unsafe { self.inner.load(blob.inner) };
619 if result_hr.is_err() {
620 return Err(HassleError::Win32Error(result_hr));
621 }
622
623 let mut shader_idx = 0;
624 let result_hr = unsafe { self.inner.find_first_part_kind(DFCC_DXIL, &mut shader_idx) };
625 if result_hr.is_err() {
626 return Err(HassleError::Win32Error(result_hr));
627 }
628
629 let mut reflection = None::<IUnknown>;
630 let result_hr = unsafe {
631 self.inner.get_part_reflection(
632 shader_idx,
633 &ID3D12ShaderReflection::IID,
634 &mut reflection,
635 )
636 };
637 if result_hr.is_err() {
638 return Err(HassleError::Win32Error(result_hr));
639 }
640
641 Ok(Reflection::new(
642 reflection.unwrap().query_interface().unwrap(),
643 ))
644 }
645}
646
647#[derive(Debug)]
648pub struct Dxil {
649 dxil_lib: Library,
650}
651
652impl Dxil {
653 pub fn new(lib_path: Option<PathBuf>) -> Result<Self> {
656 let lib_path = lib_path.unwrap_or_else(|| PathBuf::from(library_filename("dxil")));
657
658 let dxil_lib =
659 unsafe { Library::new(&lib_path) }.map_err(|e| HassleError::LoadLibraryError {
660 filename: lib_path.to_owned(),
661 inner: e,
662 })?;
663
664 Ok(Self { dxil_lib })
665 }
666
667 fn get_dxc_create_instance<T>(&self) -> Result<Symbol<DxcCreateInstanceProc<T>>> {
668 Ok(unsafe { self.dxil_lib.get(b"DxcCreateInstance\0")? })
669 }
670
671 pub fn create_validator(&self) -> Result<DxcValidator> {
672 let mut validator = None;
673 self.get_dxc_create_instance()?(&CLSID_DxcValidator, &IDxcValidator::IID, &mut validator)
674 .result()?;
675 Ok(DxcValidator::new(validator.unwrap()))
676 }
677}