1#![allow(
2 clippy::too_many_arguments,
3 clippy::new_without_default,
4 clippy::type_complexity
5)]
6
7use crate::ffi::*;
8use crate::os::{HRESULT, LPCWSTR, LPWSTR, WCHAR};
9use crate::utils::{from_wide, to_wide, HassleError, Result};
10use com::{class, interfaces::IUnknown, production::Class, production::ClassAllocation, Interface};
11use libloading::{Library, Symbol};
12use std::cell::RefCell;
13use std::ops::Deref;
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16
17pub struct DxcBlob {
18 inner: IDxcBlob,
19}
20
21impl DxcBlob {
22 fn new(inner: IDxcBlob) -> Self {
23 Self { inner }
24 }
25
26 pub fn as_slice<T>(&self) -> &[T] {
27 unsafe {
28 std::slice::from_raw_parts(
29 self.inner.get_buffer_pointer().cast(),
30 self.inner.get_buffer_size() / std::mem::size_of::<T>(),
31 )
32 }
33 }
34
35 pub fn as_mut_slice<T>(&mut self) -> &mut [T] {
36 unsafe {
37 std::slice::from_raw_parts_mut(
38 self.inner.get_buffer_pointer().cast(),
39 self.inner.get_buffer_size() / std::mem::size_of::<T>(),
40 )
41 }
42 }
43
44 pub fn to_vec<T>(&self) -> Vec<T>
45 where
46 T: Clone,
47 {
48 self.as_slice().to_vec()
49 }
50}
51
52impl AsRef<[u8]> for DxcBlob {
53 fn as_ref(&self) -> &[u8] {
54 self.as_slice()
55 }
56}
57
58impl AsMut<[u8]> for DxcBlob {
59 fn as_mut(&mut self) -> &mut [u8] {
60 self.as_mut_slice()
61 }
62}
63
64pub struct DxcBlobEncoding {
65 inner: IDxcBlobEncoding,
66}
67
68impl DxcBlobEncoding {
69 fn new(inner: IDxcBlobEncoding) -> Self {
70 Self { inner }
71 }
72}
73
74impl From<DxcBlobEncoding> for DxcBlob {
75 fn from(encoded_blob: DxcBlobEncoding) -> Self {
76 DxcBlob::new(encoded_blob.inner.query_interface::<IDxcBlob>().unwrap())
77 }
78}
79
80pub struct DxcOperationResult {
81 inner: IDxcOperationResult,
82}
83
84impl DxcOperationResult {
85 fn new(inner: IDxcOperationResult) -> Self {
86 Self { inner }
87 }
88
89 pub fn get_status(&self) -> Result<u32> {
90 let mut status: u32 = 0;
91 unsafe { self.inner.get_status(&mut status) }.result_with_success(status)
92 }
93
94 pub fn get_result(&self) -> Result<DxcBlob> {
95 let mut blob = None;
96 unsafe { self.inner.get_result(&mut blob) }.result()?;
97 Ok(DxcBlob::new(blob.unwrap()))
98 }
99
100 pub fn get_error_buffer(&self) -> Result<DxcBlobEncoding> {
101 let mut blob = None;
102
103 unsafe { self.inner.get_error_buffer(&mut blob) }.result()?;
104 Ok(DxcBlobEncoding::new(blob.unwrap()))
105 }
106}
107
/// User callback that resolves `#include` directives for `DxcCompiler::compile`,
/// `DxcCompiler::compile_with_debug` and `DxcCompiler::preprocess`.
pub trait DxcIncludeHandler {
    /// Returns the contents of the file named `filename`, or `None` if it cannot be
    /// provided, in which case DXC is told the file was not found.
    fn load_source(&mut self, filename: String) -> Option<String>;
}
111
112class! {
113 #[no_class_factory]
114 class DxcIncludeHandlerWrapper: IDxcIncludeHandler {
115 library: &'static DxcLibrary,
124 handler: RefCell<&'static mut dyn DxcIncludeHandler>,
125
126 pinned: RefCell<Vec<Pin<String>>>,
127 }
128
129 impl IDxcIncludeHandler for DxcIncludeHandlerWrapper {
130 fn load_source(&self, filename: LPCWSTR, include_source: *mut Option<IDxcBlob>) -> HRESULT {
131 let filename = crate::utils::from_wide(filename);
132
133 let mut handler = self.handler.borrow_mut();
134 let source = handler.load_source(filename);
135
136 if let Some(source) = source {
137 let source = Pin::new(source);
138 let blob = self.library
139 .create_blob_with_encoding_from_str(&source)
140 .unwrap();
141
142 unsafe { *include_source = Some(DxcBlob::from(blob).inner) };
143 self.pinned.borrow_mut().push(source);
144
145 0
147 } else {
148 -2_147_024_894 }
150 .into()
151 }
152 }
153}
154
/// Owns a COM class instance created on the Rust side with `Class::allocate`, as a
/// plain pinned box instead of the `com` crate's `ClassAllocation`.
struct LocalClassAllocation<T: Class>(core::pin::Pin<Box<T>>);

impl<T: Class> LocalClassAllocation<T> {
    /// Takes ownership of `allocation`.
    ///
    /// The `ClassAllocation<T>` is reinterpreted as `ManuallyDrop<Pin<Box<T>>>`, so the
    /// `ClassAllocation` destructor never runs and the box is owned directly.
    /// NOTE(review): this relies on `ClassAllocation<T>` having exactly that layout;
    /// re-check whenever the `com` crate version changes.
    fn new(allocation: ClassAllocation<T>) -> Self {
        let inner: core::mem::ManuallyDrop<core::pin::Pin<Box<T>>> =
            unsafe { std::mem::transmute(allocation) };

        Self(core::mem::ManuallyDrop::into_inner(inner))
    }

}

impl<T: Class> Deref for LocalClassAllocation<T> {
    type Target = core::pin::Pin<Box<T>>;

    /// Exposes the pinned box so the class's methods (e.g. `query_interface`) can be
    /// called on it.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T: Class> Drop for LocalClassAllocation<T> {
    /// Releases this wrapper's own reference and asserts that it was the last one.
    ///
    /// Every interface obtained from the object must therefore be dropped before this
    /// wrapper. Afterwards the `Pin<Box<T>>` field is dropped normally, which frees
    /// the object.
    fn drop(&mut self) {
        assert_eq!(
            unsafe { self.0.dec_ref_count() },
            0,
            "COM object is still referenced"
        );
    }
}
205
impl DxcIncludeHandlerWrapper {
    /// Allocates a `DxcIncludeHandlerWrapper` COM object that forwards DXC's
    /// `LoadSource` requests to `include_handler`, creating blobs via `library`.
    ///
    /// # Safety
    ///
    /// `library` and `include_handler` are transmuted to `'static` references to fit the
    /// class's field types. The caller must drop the returned allocation, and every COM
    /// interface obtained from it, before either borrow ends.
    unsafe fn create_include_handler(
        library: &'_ DxcLibrary,
        include_handler: &'_ mut dyn DxcIncludeHandler,
    ) -> LocalClassAllocation<DxcIncludeHandlerWrapper> {
        LocalClassAllocation::new(Self::allocate(
            std::mem::transmute(library),
            RefCell::new(std::mem::transmute(include_handler)),
            RefCell::new(vec![]),
        ))
    }
}
220
221pub struct DxcCompiler {
222 inner: IDxcCompiler2,
223 library: DxcLibrary,
224}
225
226impl DxcCompiler {
227 fn new(inner: IDxcCompiler2, library: DxcLibrary) -> Self {
228 Self { inner, library }
229 }
230
231 fn prep_defines(
232 defines: &[(&str, Option<&str>)],
233 wide_defines: &mut Vec<(Vec<WCHAR>, Vec<WCHAR>)>,
234 dxc_defines: &mut Vec<DxcDefine>,
235 ) {
236 for (name, value) in defines {
237 if value.is_none() {
238 wide_defines.push((to_wide(name), to_wide("1")));
239 } else {
240 wide_defines.push((to_wide(name), to_wide(value.unwrap())));
241 }
242 }
243
244 for (ref name, ref value) in wide_defines {
245 dxc_defines.push(DxcDefine {
246 name: name.as_ptr(),
247 value: value.as_ptr(),
248 });
249 }
250 }
251
252 fn prep_args(args: &[&str], wide_args: &mut Vec<Vec<WCHAR>>, dxc_args: &mut Vec<LPCWSTR>) {
253 for a in args {
254 wide_args.push(to_wide(a));
255 }
256
257 for a in wide_args {
258 dxc_args.push(a.as_ptr());
259 }
260 }
261
262 pub fn compile(
263 &self,
264 blob: &DxcBlobEncoding,
265 source_name: &str,
266 entry_point: &str,
267 target_profile: &str,
268 args: &[&str],
269 include_handler: Option<&mut dyn DxcIncludeHandler>,
270 defines: &[(&str, Option<&str>)],
271 ) -> Result<DxcOperationResult, (DxcOperationResult, HRESULT)> {
272 let mut wide_args = vec![];
273 let mut dxc_args = vec![];
274 Self::prep_args(args, &mut wide_args, &mut dxc_args);
275
276 let mut wide_defines = vec![];
277 let mut dxc_defines = vec![];
278 Self::prep_defines(defines, &mut wide_defines, &mut dxc_defines);
279
280 let include_handler = include_handler.map(|include_handler| unsafe {
282 DxcIncludeHandlerWrapper::create_include_handler(&self.library, include_handler)
283 });
284 let include_handler = include_handler
286 .as_ref()
287 .map(|i| i.query_interface().unwrap());
288
289 let mut result = None;
290 let result_hr = unsafe {
291 self.inner.compile(
292 &blob.inner,
293 to_wide(source_name).as_ptr(),
294 to_wide(entry_point).as_ptr(),
295 to_wide(target_profile).as_ptr(),
296 dxc_args.as_ptr(),
297 dxc_args.len() as u32,
298 dxc_defines.as_ptr(),
299 dxc_defines.len() as u32,
300 &include_handler,
301 &mut result,
302 )
303 };
304
305 let result = result.unwrap();
306
307 let mut compile_error = 0u32;
308 let status_hr = unsafe { result.get_status(&mut compile_error) };
309
310 if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
311 Ok(DxcOperationResult::new(result))
312 } else {
313 Err((DxcOperationResult::new(result), result_hr))
314 }
315 }
316
317 pub fn compile_with_debug(
318 &self,
319 blob: &DxcBlobEncoding,
320 source_name: &str,
321 entry_point: &str,
322 target_profile: &str,
323 args: &[&str],
324 include_handler: Option<&mut dyn DxcIncludeHandler>,
325 defines: &[(&str, Option<&str>)],
326 ) -> Result<(DxcOperationResult, String, DxcBlob), (DxcOperationResult, HRESULT)> {
327 let mut wide_args = vec![];
328 let mut dxc_args = vec![];
329 Self::prep_args(args, &mut wide_args, &mut dxc_args);
330
331 let mut wide_defines = vec![];
332 let mut dxc_defines = vec![];
333 Self::prep_defines(defines, &mut wide_defines, &mut dxc_defines);
334
335 let include_handler = include_handler.map(|include_handler| unsafe {
337 DxcIncludeHandlerWrapper::create_include_handler(&self.library, include_handler)
338 });
339 let include_handler = include_handler
340 .as_ref()
341 .map(|i| i.query_interface().unwrap());
342
343 let mut result = None;
344 let mut debug_blob = None;
345 let mut debug_filename: LPWSTR = std::ptr::null_mut();
346
347 let result_hr = unsafe {
348 self.inner.compile_with_debug(
349 &blob.inner,
350 to_wide(source_name).as_ptr(),
351 to_wide(entry_point).as_ptr(),
352 to_wide(target_profile).as_ptr(),
353 dxc_args.as_ptr(),
354 dxc_args.len() as u32,
355 dxc_defines.as_ptr(),
356 dxc_defines.len() as u32,
357 include_handler,
358 &mut result,
359 &mut debug_filename,
360 &mut debug_blob,
361 )
362 };
363 let result = result.unwrap();
364 let debug_blob = debug_blob.unwrap();
365
366 let mut compile_error = 0u32;
367 let status_hr = unsafe { result.get_status(&mut compile_error) };
368
369 if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
370 Ok((
371 DxcOperationResult::new(result),
372 from_wide(debug_filename),
373 DxcBlob::new(debug_blob),
374 ))
375 } else {
376 Err((DxcOperationResult::new(result), result_hr))
377 }
378 }
379
380 pub fn preprocess(
381 &self,
382 blob: &DxcBlobEncoding,
383 source_name: &str,
384 args: &[&str],
385 include_handler: Option<&mut dyn DxcIncludeHandler>,
386 defines: &[(&str, Option<&str>)],
387 ) -> Result<DxcOperationResult, (DxcOperationResult, HRESULT)> {
388 let mut wide_args = vec![];
389 let mut dxc_args = vec![];
390 Self::prep_args(args, &mut wide_args, &mut dxc_args);
391
392 let mut wide_defines = vec![];
393 let mut dxc_defines = vec![];
394 Self::prep_defines(defines, &mut wide_defines, &mut dxc_defines);
395
396 let include_handler = include_handler.map(|include_handler| unsafe {
398 DxcIncludeHandlerWrapper::create_include_handler(&self.library, include_handler)
399 });
400 let include_handler = include_handler
401 .as_ref()
402 .map(|i| i.query_interface().unwrap());
403
404 let mut result = None;
405 let result_hr = unsafe {
406 self.inner.preprocess(
407 &blob.inner,
408 to_wide(source_name).as_ptr(),
409 dxc_args.as_ptr(),
410 dxc_args.len() as u32,
411 dxc_defines.as_ptr(),
412 dxc_defines.len() as u32,
413 include_handler,
414 &mut result,
415 )
416 };
417
418 let result = result.unwrap();
419
420 let mut compile_error = 0u32;
421 let status_hr = unsafe { result.get_status(&mut compile_error) };
422
423 if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
424 Ok(DxcOperationResult::new(result))
425 } else {
426 Err((DxcOperationResult::new(result), result_hr))
427 }
428 }
429
430 pub fn disassemble(&self, blob: &DxcBlob) -> Result<DxcBlobEncoding> {
431 let mut result_blob = None;
432 unsafe { self.inner.disassemble(&blob.inner, &mut result_blob) }.result()?;
433 Ok(DxcBlobEncoding::new(result_blob.unwrap()))
434 }
435}
436
437#[derive(Clone)]
438pub struct DxcLibrary {
439 inner: IDxcLibrary,
440}
441
442impl DxcLibrary {
443 fn new(inner: IDxcLibrary) -> Self {
444 Self { inner }
445 }
446
447 pub fn create_blob_with_encoding(&self, data: &[u8]) -> Result<DxcBlobEncoding> {
448 let mut blob = None;
449
450 unsafe {
451 self.inner.create_blob_with_encoding_from_pinned(
452 data.as_ptr().cast(),
453 data.len() as u32,
454 0, &mut blob,
456 )
457 }
458 .result()?;
459 Ok(DxcBlobEncoding::new(blob.unwrap()))
460 }
461
462 pub fn create_blob_with_encoding_from_str(&self, text: &str) -> Result<DxcBlobEncoding> {
463 let mut blob = None;
464 const CP_UTF8: u32 = 65001; unsafe {
467 self.inner.create_blob_with_encoding_from_pinned(
468 text.as_ptr().cast(),
469 text.len() as u32,
470 CP_UTF8,
471 &mut blob,
472 )
473 }
474 .result()?;
475 Ok(DxcBlobEncoding::new(blob.unwrap()))
476 }
477
478 pub fn get_blob_as_string(&self, blob: &DxcBlob) -> Result<String> {
479 let mut blob_utf8 = None;
480
481 unsafe { self.inner.get_blob_as_utf8(&blob.inner, &mut blob_utf8) }.result()?;
482
483 let blob_utf8 = blob_utf8.unwrap();
484
485 Ok(String::from_utf8(DxcBlob::new(blob_utf8.query_interface().unwrap()).to_vec()).unwrap())
486 }
487}
488
/// Handle to the dynamically loaded DXC compiler library; create one with `Dxc::new`.
#[derive(Debug)]
pub struct Dxc {
    dxc_lib: Library,
}
493
/// Default DXC compiler library name on Windows.
#[cfg(target_os = "windows")]
fn dxcompiler_lib_name() -> &'static Path {
    Path::new("dxcompiler.dll")
}

/// Default DXC compiler library path on Linux and Android (relative to the current
/// working directory).
#[cfg(any(target_os = "linux", target_os = "android"))]
fn dxcompiler_lib_name() -> &'static Path {
    Path::new("./libdxcompiler.so")
}

/// Default DXC compiler library path on macOS (relative to the current working
/// directory).
#[cfg(target_os = "macos")]
fn dxcompiler_lib_name() -> &'static Path {
    Path::new("./libdxcompiler.dylib")
}

/// Fallback for any other target (e.g. the BSDs). Without it, those targets have no
/// definition of this function and the crate fails to build.
/// NOTE(review): assumes a `.so` shared-library name; confirm for each such target.
#[cfg(not(any(
    target_os = "windows",
    target_os = "linux",
    target_os = "android",
    target_os = "macos"
)))]
fn dxcompiler_lib_name() -> &'static Path {
    Path::new("./libdxcompiler.so")
}
508
509impl Dxc {
510 pub fn new(lib_path: Option<PathBuf>) -> Result<Self> {
513 let lib_path = if let Some(lib_path) = lib_path {
514 if lib_path.is_file() {
515 lib_path
516 } else {
517 lib_path.join(dxcompiler_lib_name())
518 }
519 } else {
520 dxcompiler_lib_name().to_owned()
521 };
522 let dxc_lib =
523 unsafe { Library::new(&lib_path) }.map_err(|e| HassleError::LoadLibraryError {
524 filename: lib_path,
525 inner: e,
526 })?;
527
528 Ok(Self { dxc_lib })
529 }
530
531 pub(crate) fn get_dxc_create_instance<T>(&self) -> Result<Symbol<DxcCreateInstanceProc<T>>> {
532 Ok(unsafe { self.dxc_lib.get(b"DxcCreateInstance\0")? })
533 }
534
535 pub fn create_compiler(&self) -> Result<DxcCompiler> {
536 let mut compiler = None;
537
538 self.get_dxc_create_instance()?(&CLSID_DxcCompiler, &IDxcCompiler2::IID, &mut compiler)
539 .result()?;
540 Ok(DxcCompiler::new(
541 compiler.unwrap(),
542 self.create_library().unwrap(),
543 ))
544 }
545
546 pub fn create_library(&self) -> Result<DxcLibrary> {
547 let mut library = None;
548 self.get_dxc_create_instance()?(&CLSID_DxcLibrary, &IDxcLibrary::IID, &mut library)
549 .result()?;
550 Ok(DxcLibrary::new(library.unwrap()))
551 }
552
553 pub fn create_reflector(&self) -> Result<DxcReflector> {
554 let mut reflector = None;
555
556 self.get_dxc_create_instance()?(
557 &CLSID_DxcContainerReflection,
558 &IDxcContainerReflection::IID,
559 &mut reflector,
560 )
561 .result()?;
562 Ok(DxcReflector::new(reflector.unwrap()))
563 }
564}
565
566pub struct DxcValidator {
567 inner: IDxcValidator,
568}
569
570pub type DxcValidatorVersion = (u32, u32);
571
572impl DxcValidator {
573 fn new(inner: IDxcValidator) -> Self {
574 Self { inner }
575 }
576
577 pub fn version(&self) -> Result<DxcValidatorVersion> {
578 let version = self
579 .inner
580 .query_interface::<IDxcVersionInfo>()
581 .ok_or(HassleError::Win32Error(HRESULT(com::sys::E_NOINTERFACE)))?;
582
583 let mut major = 0;
584 let mut minor = 0;
585
586 unsafe { version.get_version(&mut major, &mut minor) }.result_with_success((major, minor))
587 }
588
589 pub fn validate(&self, blob: DxcBlob) -> Result<DxcBlob, (DxcOperationResult, HassleError)> {
590 let mut result = None;
591 let result_hr = unsafe {
592 self.inner
593 .validate(&blob.inner, DXC_VALIDATOR_FLAGS_IN_PLACE_EDIT, &mut result)
594 };
595
596 let result = result.unwrap();
597
598 let mut validate_status = 0u32;
599 let status_hr = unsafe { result.get_status(&mut validate_status) };
600
601 if !result_hr.is_err() && !status_hr.is_err() && validate_status == 0 {
602 Ok(blob)
603 } else {
604 Err((
605 DxcOperationResult::new(result),
606 HassleError::Win32Error(result_hr),
607 ))
608 }
609 }
610}
611
612pub struct Reflection {
613 inner: ID3D12ShaderReflection,
614}
615impl Reflection {
616 fn new(inner: ID3D12ShaderReflection) -> Self {
617 Self { inner }
618 }
619
620 pub fn thread_group_size(&self) -> [u32; 3] {
621 let (mut size_x, mut size_y, mut size_z) = (0u32, 0u32, 0u32);
622 unsafe {
623 self.inner
624 .get_thread_group_size(&mut size_x, &mut size_y, &mut size_z)
625 };
626 [size_x, size_y, size_z]
627 }
628}
629
630pub struct DxcReflector {
631 inner: IDxcContainerReflection,
632}
633impl DxcReflector {
634 fn new(inner: IDxcContainerReflection) -> Self {
635 Self { inner }
636 }
637
638 pub fn reflect(&self, blob: DxcBlob) -> Result<Reflection> {
639 let result_hr = unsafe { self.inner.load(blob.inner) };
640 if result_hr.is_err() {
641 return Err(HassleError::Win32Error(result_hr));
642 }
643
644 let mut shader_idx = 0;
645 let result_hr = unsafe { self.inner.find_first_part_kind(DFCC_DXIL, &mut shader_idx) };
646 if result_hr.is_err() {
647 return Err(HassleError::Win32Error(result_hr));
648 }
649
650 let mut reflection = None::<IUnknown>;
651 let result_hr = unsafe {
652 self.inner.get_part_reflection(
653 shader_idx,
654 &ID3D12ShaderReflection::IID,
655 &mut reflection,
656 )
657 };
658 if result_hr.is_err() {
659 return Err(HassleError::Win32Error(result_hr));
660 }
661
662 Ok(Reflection::new(
663 reflection.unwrap().query_interface().unwrap(),
664 ))
665 }
666}
667
668#[derive(Debug)]
669pub struct Dxil {
670 dxil_lib: Library,
671}
672
673impl Dxil {
674 #[cfg(not(windows))]
675 pub fn new(_: Option<PathBuf>) -> Result<Self> {
676 Err(HassleError::WindowsOnly(
677 "DXIL Signing is only supported on Windows".to_string(),
678 ))
679 }
680
681 #[cfg(windows)]
684 pub fn new(lib_path: Option<PathBuf>) -> Result<Self> {
685 let lib_path = if let Some(lib_path) = lib_path {
686 if lib_path.is_file() {
687 lib_path
688 } else {
689 lib_path.join("dxil.dll")
690 }
691 } else {
692 PathBuf::from("dxil.dll")
693 };
694
695 let dxil_lib =
696 unsafe { Library::new(&lib_path) }.map_err(|e| HassleError::LoadLibraryError {
697 filename: lib_path.to_owned(),
698 inner: e,
699 })?;
700
701 Ok(Self { dxil_lib })
702 }
703
704 fn get_dxc_create_instance<T>(&self) -> Result<Symbol<DxcCreateInstanceProc<T>>> {
705 Ok(unsafe { self.dxil_lib.get(b"DxcCreateInstance\0")? })
706 }
707
708 pub fn create_validator(&self) -> Result<DxcValidator> {
709 let mut validator = None;
710 self.get_dxc_create_instance()?(&CLSID_DxcValidator, &IDxcValidator::IID, &mut validator)
711 .result()?;
712 Ok(DxcValidator::new(validator.unwrap()))
713 }
714}