hassle_rs/wrapper.rs

1#![allow(
2    clippy::too_many_arguments,
3    clippy::new_without_default,
4    clippy::type_complexity
5)]
6
7use crate::ffi::*;
8use crate::os::{HRESULT, LPCWSTR, LPWSTR, WCHAR};
9use crate::utils::{from_wide, to_wide, HassleError, Result};
10use com::{class, interfaces::IUnknown, production::Class, production::ClassAllocation, Interface};
11use libloading::{Library, Symbol};
12use std::cell::RefCell;
13use std::ops::Deref;
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16
17pub struct DxcBlob {
18    inner: IDxcBlob,
19}
20
21impl DxcBlob {
22    fn new(inner: IDxcBlob) -> Self {
23        Self { inner }
24    }
25
26    pub fn as_slice<T>(&self) -> &[T] {
27        unsafe {
28            std::slice::from_raw_parts(
29                self.inner.get_buffer_pointer().cast(),
30                self.inner.get_buffer_size() / std::mem::size_of::<T>(),
31            )
32        }
33    }
34
35    pub fn as_mut_slice<T>(&mut self) -> &mut [T] {
36        unsafe {
37            std::slice::from_raw_parts_mut(
38                self.inner.get_buffer_pointer().cast(),
39                self.inner.get_buffer_size() / std::mem::size_of::<T>(),
40            )
41        }
42    }
43
44    pub fn to_vec<T>(&self) -> Vec<T>
45    where
46        T: Clone,
47    {
48        self.as_slice().to_vec()
49    }
50}
51
52impl AsRef<[u8]> for DxcBlob {
53    fn as_ref(&self) -> &[u8] {
54        self.as_slice()
55    }
56}
57
58impl AsMut<[u8]> for DxcBlob {
59    fn as_mut(&mut self) -> &mut [u8] {
60        self.as_mut_slice()
61    }
62}
63
64pub struct DxcBlobEncoding {
65    inner: IDxcBlobEncoding,
66}
67
68impl DxcBlobEncoding {
69    fn new(inner: IDxcBlobEncoding) -> Self {
70        Self { inner }
71    }
72}
73
74impl From<DxcBlobEncoding> for DxcBlob {
75    fn from(encoded_blob: DxcBlobEncoding) -> Self {
76        DxcBlob::new(encoded_blob.inner.query_interface::<IDxcBlob>().unwrap())
77    }
78}
79
80pub struct DxcOperationResult {
81    inner: IDxcOperationResult,
82}
83
84impl DxcOperationResult {
85    fn new(inner: IDxcOperationResult) -> Self {
86        Self { inner }
87    }
88
89    pub fn get_status(&self) -> Result<u32> {
90        let mut status: u32 = 0;
91        unsafe { self.inner.get_status(&mut status) }.result_with_success(status)
92    }
93
94    pub fn get_result(&self) -> Result<DxcBlob> {
95        let mut blob = None;
96        unsafe { self.inner.get_result(&mut blob) }.result()?;
97        Ok(DxcBlob::new(blob.unwrap()))
98    }
99
100    pub fn get_error_buffer(&self) -> Result<DxcBlobEncoding> {
101        let mut blob = None;
102
103        unsafe { self.inner.get_error_buffer(&mut blob) }.result()?;
104        Ok(DxcBlobEncoding::new(blob.unwrap()))
105    }
106}
107
/// User-supplied callback that resolves include requests during compilation and
/// preprocessing.
pub trait DxcIncludeHandler {
    /// Returns the text of `filename`, or `None` if it cannot be provided.
    ///
    /// `None` is reported back to DXC as a file-not-found error.
    fn load_source(&mut self, filename: String) -> Option<String>;
}
111
112class! {
113    #[no_class_factory]
114    class DxcIncludeHandlerWrapper: IDxcIncludeHandler {
115        // Com-rs intentionally does not support lifetimes in its class structs
116        // since they live on the heap and their lifetime can be prolonged for
117        // as long as someone keeps a reference through `add_ref()`.
118        // The only way for us to access the library and handler implementation,
119        // which are now intentionally behind a borrow to signify our promise
120        // regarding lifetime, is by transmuting them away and "ensuring" the
121        // class object is discarded at the end of our function call.
122
123        library: &'static DxcLibrary,
124        handler: RefCell<&'static mut dyn DxcIncludeHandler>,
125
126        pinned: RefCell<Vec<Pin<String>>>,
127    }
128
129    impl IDxcIncludeHandler for DxcIncludeHandlerWrapper {
130        fn load_source(&self, filename: LPCWSTR, include_source: *mut Option<IDxcBlob>) -> HRESULT {
131            let filename = crate::utils::from_wide(filename);
132
133            let mut handler = self.handler.borrow_mut();
134            let source = handler.load_source(filename);
135
136            if let Some(source) = source {
137                let source = Pin::new(source);
138                let blob = self.library
139                    .create_blob_with_encoding_from_str(&source)
140                    .unwrap();
141
142                unsafe { *include_source = Some(DxcBlob::from(blob).inner) };
143                self.pinned.borrow_mut().push(source);
144
145                // NOERROR
146                0
147            } else {
148                -2_147_024_894 // ERROR_FILE_NOT_FOUND / 0x80070002
149            }
150            .into()
151        }
152    }
153}
154
155/// Represents a reference to a COM object that should only live as long as itself
156///
157/// In other words, on [`drop()`] we assert that the refcount is decremented to zero,
158/// rather than allowing it to be referenced externally (i.e. [`Class::dec_ref_count()`]
159/// returning `> 0`).
160/// This object functions a lot like [`ClassAllocation`]: see its similar [`drop()`]
161/// implementation for details.
162///
163/// Note that COM objects live on the heap by design, because of this refcount system.
164struct LocalClassAllocation<T: Class>(core::pin::Pin<Box<T>>);
165
impl<T: Class> LocalClassAllocation<T> {
    /// Takes over the pinned box held by a com-rs [`ClassAllocation`].
    ///
    /// `allocation` is consumed by the transmute below rather than dropped, so its own
    /// `drop()` never runs. The reference it held now belongs to the returned wrapper and
    /// is given up in the wrapper's `Drop` impl.
    fn new(allocation: ClassAllocation<T>) -> Self {
        // TODO: There is no way to take the internal, owned box out of com-rs's
        // allocation wrapper.
        // https://github.com/microsoft/com-rs/issues/236 covers this issue as a whole,
        // including lifetime support and this `LocalClassAllocation` upstream.
        //
        // NOTE(review): this transmute assumes `ClassAllocation<T>` is laid out exactly
        // like `ManuallyDrop<Pin<Box<T>>>`. That is a com-rs implementation detail, and
        // transmute only checks that the sizes match, so re-verify it on every com-rs
        // upgrade.
        let inner: core::mem::ManuallyDrop<core::pin::Pin<Box<T>>> =
            unsafe { std::mem::transmute(allocation) };

        Self(core::mem::ManuallyDrop::into_inner(inner))
    }

    // TODO: Return a borrow of this interface?
    // query_interface() is not behind one of the traits
    // fn query_interface<T>(&self) -> Option<T> {
    //     self.0.query_interface::<T>().unwrap()
    // }
}
184
185impl<T: Class> Deref for LocalClassAllocation<T> {
186    type Target = core::pin::Pin<Box<T>>;
187
188    fn deref(&self) -> &Self::Target {
189        &self.0
190    }
191}
192
193impl<T: Class> Drop for LocalClassAllocation<T> {
194    fn drop(&mut self) {
195        // Check if we are the only remaining reference to this object
196        assert_eq!(
197            unsafe { self.0.dec_ref_count() },
198            0,
199            "COM object is still referenced"
200        );
201        // Now that we're the last one to give up our refcount, it is safe
202        // for the internal object to get dropped.
203    }
204}
205
206impl DxcIncludeHandlerWrapper {
207    /// SAFETY: Make sure the returned object does _not_ outlive the lifetime
208    /// of either `library` nor `include_handler`
209    unsafe fn create_include_handler(
210        library: &'_ DxcLibrary,
211        include_handler: &'_ mut dyn DxcIncludeHandler,
212    ) -> LocalClassAllocation<DxcIncludeHandlerWrapper> {
213        LocalClassAllocation::new(Self::allocate(
214            std::mem::transmute(library),
215            RefCell::new(std::mem::transmute(include_handler)),
216            RefCell::new(vec![]),
217        ))
218    }
219}
220
221pub struct DxcCompiler {
222    inner: IDxcCompiler2,
223    library: DxcLibrary,
224}
225
226impl DxcCompiler {
227    fn new(inner: IDxcCompiler2, library: DxcLibrary) -> Self {
228        Self { inner, library }
229    }
230
231    fn prep_defines(
232        defines: &[(&str, Option<&str>)],
233        wide_defines: &mut Vec<(Vec<WCHAR>, Vec<WCHAR>)>,
234        dxc_defines: &mut Vec<DxcDefine>,
235    ) {
236        for (name, value) in defines {
237            if value.is_none() {
238                wide_defines.push((to_wide(name), to_wide("1")));
239            } else {
240                wide_defines.push((to_wide(name), to_wide(value.unwrap())));
241            }
242        }
243
244        for (ref name, ref value) in wide_defines {
245            dxc_defines.push(DxcDefine {
246                name: name.as_ptr(),
247                value: value.as_ptr(),
248            });
249        }
250    }
251
252    fn prep_args(args: &[&str], wide_args: &mut Vec<Vec<WCHAR>>, dxc_args: &mut Vec<LPCWSTR>) {
253        for a in args {
254            wide_args.push(to_wide(a));
255        }
256
257        for a in wide_args {
258            dxc_args.push(a.as_ptr());
259        }
260    }
261
262    pub fn compile(
263        &self,
264        blob: &DxcBlobEncoding,
265        source_name: &str,
266        entry_point: &str,
267        target_profile: &str,
268        args: &[&str],
269        include_handler: Option<&mut dyn DxcIncludeHandler>,
270        defines: &[(&str, Option<&str>)],
271    ) -> Result<DxcOperationResult, (DxcOperationResult, HRESULT)> {
272        let mut wide_args = vec![];
273        let mut dxc_args = vec![];
274        Self::prep_args(args, &mut wide_args, &mut dxc_args);
275
276        let mut wide_defines = vec![];
277        let mut dxc_defines = vec![];
278        Self::prep_defines(defines, &mut wide_defines, &mut dxc_defines);
279
280        // Keep alive on the stack
281        let include_handler = include_handler.map(|include_handler| unsafe {
282            DxcIncludeHandlerWrapper::create_include_handler(&self.library, include_handler)
283        });
284        // TODO: query_interface() should have a borrow on LocalClassAllocation to prevent things going kaboom
285        let include_handler = include_handler
286            .as_ref()
287            .map(|i| i.query_interface().unwrap());
288
289        let mut result = None;
290        let result_hr = unsafe {
291            self.inner.compile(
292                &blob.inner,
293                to_wide(source_name).as_ptr(),
294                to_wide(entry_point).as_ptr(),
295                to_wide(target_profile).as_ptr(),
296                dxc_args.as_ptr(),
297                dxc_args.len() as u32,
298                dxc_defines.as_ptr(),
299                dxc_defines.len() as u32,
300                &include_handler,
301                &mut result,
302            )
303        };
304
305        let result = result.unwrap();
306
307        let mut compile_error = 0u32;
308        let status_hr = unsafe { result.get_status(&mut compile_error) };
309
310        if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
311            Ok(DxcOperationResult::new(result))
312        } else {
313            Err((DxcOperationResult::new(result), result_hr))
314        }
315    }
316
317    pub fn compile_with_debug(
318        &self,
319        blob: &DxcBlobEncoding,
320        source_name: &str,
321        entry_point: &str,
322        target_profile: &str,
323        args: &[&str],
324        include_handler: Option<&mut dyn DxcIncludeHandler>,
325        defines: &[(&str, Option<&str>)],
326    ) -> Result<(DxcOperationResult, String, DxcBlob), (DxcOperationResult, HRESULT)> {
327        let mut wide_args = vec![];
328        let mut dxc_args = vec![];
329        Self::prep_args(args, &mut wide_args, &mut dxc_args);
330
331        let mut wide_defines = vec![];
332        let mut dxc_defines = vec![];
333        Self::prep_defines(defines, &mut wide_defines, &mut dxc_defines);
334
335        // Keep alive on the stack
336        let include_handler = include_handler.map(|include_handler| unsafe {
337            DxcIncludeHandlerWrapper::create_include_handler(&self.library, include_handler)
338        });
339        let include_handler = include_handler
340            .as_ref()
341            .map(|i| i.query_interface().unwrap());
342
343        let mut result = None;
344        let mut debug_blob = None;
345        let mut debug_filename: LPWSTR = std::ptr::null_mut();
346
347        let result_hr = unsafe {
348            self.inner.compile_with_debug(
349                &blob.inner,
350                to_wide(source_name).as_ptr(),
351                to_wide(entry_point).as_ptr(),
352                to_wide(target_profile).as_ptr(),
353                dxc_args.as_ptr(),
354                dxc_args.len() as u32,
355                dxc_defines.as_ptr(),
356                dxc_defines.len() as u32,
357                include_handler,
358                &mut result,
359                &mut debug_filename,
360                &mut debug_blob,
361            )
362        };
363        let result = result.unwrap();
364        let debug_blob = debug_blob.unwrap();
365
366        let mut compile_error = 0u32;
367        let status_hr = unsafe { result.get_status(&mut compile_error) };
368
369        if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
370            Ok((
371                DxcOperationResult::new(result),
372                from_wide(debug_filename),
373                DxcBlob::new(debug_blob),
374            ))
375        } else {
376            Err((DxcOperationResult::new(result), result_hr))
377        }
378    }
379
380    pub fn preprocess(
381        &self,
382        blob: &DxcBlobEncoding,
383        source_name: &str,
384        args: &[&str],
385        include_handler: Option<&mut dyn DxcIncludeHandler>,
386        defines: &[(&str, Option<&str>)],
387    ) -> Result<DxcOperationResult, (DxcOperationResult, HRESULT)> {
388        let mut wide_args = vec![];
389        let mut dxc_args = vec![];
390        Self::prep_args(args, &mut wide_args, &mut dxc_args);
391
392        let mut wide_defines = vec![];
393        let mut dxc_defines = vec![];
394        Self::prep_defines(defines, &mut wide_defines, &mut dxc_defines);
395
396        // Keep alive on the stack
397        let include_handler = include_handler.map(|include_handler| unsafe {
398            DxcIncludeHandlerWrapper::create_include_handler(&self.library, include_handler)
399        });
400        let include_handler = include_handler
401            .as_ref()
402            .map(|i| i.query_interface().unwrap());
403
404        let mut result = None;
405        let result_hr = unsafe {
406            self.inner.preprocess(
407                &blob.inner,
408                to_wide(source_name).as_ptr(),
409                dxc_args.as_ptr(),
410                dxc_args.len() as u32,
411                dxc_defines.as_ptr(),
412                dxc_defines.len() as u32,
413                include_handler,
414                &mut result,
415            )
416        };
417
418        let result = result.unwrap();
419
420        let mut compile_error = 0u32;
421        let status_hr = unsafe { result.get_status(&mut compile_error) };
422
423        if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
424            Ok(DxcOperationResult::new(result))
425        } else {
426            Err((DxcOperationResult::new(result), result_hr))
427        }
428    }
429
430    pub fn disassemble(&self, blob: &DxcBlob) -> Result<DxcBlobEncoding> {
431        let mut result_blob = None;
432        unsafe { self.inner.disassemble(&blob.inner, &mut result_blob) }.result()?;
433        Ok(DxcBlobEncoding::new(result_blob.unwrap()))
434    }
435}
436
437#[derive(Clone)]
438pub struct DxcLibrary {
439    inner: IDxcLibrary,
440}
441
442impl DxcLibrary {
443    fn new(inner: IDxcLibrary) -> Self {
444        Self { inner }
445    }
446
447    pub fn create_blob_with_encoding(&self, data: &[u8]) -> Result<DxcBlobEncoding> {
448        let mut blob = None;
449
450        unsafe {
451            self.inner.create_blob_with_encoding_from_pinned(
452                data.as_ptr().cast(),
453                data.len() as u32,
454                0, // Binary; no code page
455                &mut blob,
456            )
457        }
458        .result()?;
459        Ok(DxcBlobEncoding::new(blob.unwrap()))
460    }
461
462    pub fn create_blob_with_encoding_from_str(&self, text: &str) -> Result<DxcBlobEncoding> {
463        let mut blob = None;
464        const CP_UTF8: u32 = 65001; // UTF-8 translation
465
466        unsafe {
467            self.inner.create_blob_with_encoding_from_pinned(
468                text.as_ptr().cast(),
469                text.len() as u32,
470                CP_UTF8,
471                &mut blob,
472            )
473        }
474        .result()?;
475        Ok(DxcBlobEncoding::new(blob.unwrap()))
476    }
477
478    pub fn get_blob_as_string(&self, blob: &DxcBlob) -> Result<String> {
479        let mut blob_utf8 = None;
480
481        unsafe { self.inner.get_blob_as_utf8(&blob.inner, &mut blob_utf8) }.result()?;
482
483        let blob_utf8 = blob_utf8.unwrap();
484
485        Ok(String::from_utf8(DxcBlob::new(blob_utf8.query_interface().unwrap()).to_vec()).unwrap())
486    }
487}
488
489#[derive(Debug)]
490pub struct Dxc {
491    dxc_lib: Library,
492}
493
/// Default file name of the DXC compiler library on the current platform.
///
/// The Linux/Android and macOS names are relative to the current directory (`./`); the
/// Windows name carries no directory. Other platforms have no default, and this
/// function fails to compile there.
fn dxcompiler_lib_name() -> &'static Path {
    #[cfg(target_os = "windows")]
    let file_name = "dxcompiler.dll";
    #[cfg(any(target_os = "linux", target_os = "android"))]
    let file_name = "./libdxcompiler.so";
    #[cfg(target_os = "macos")]
    let file_name = "./libdxcompiler.dylib";

    Path::new(file_name)
}
508
509impl Dxc {
510    /// `dxc_path` can point to a library directly or the directory containing the library,
511    /// in which case the appended filename depends on the platform.
512    pub fn new(lib_path: Option<PathBuf>) -> Result<Self> {
513        let lib_path = if let Some(lib_path) = lib_path {
514            if lib_path.is_file() {
515                lib_path
516            } else {
517                lib_path.join(dxcompiler_lib_name())
518            }
519        } else {
520            dxcompiler_lib_name().to_owned()
521        };
522        let dxc_lib =
523            unsafe { Library::new(&lib_path) }.map_err(|e| HassleError::LoadLibraryError {
524                filename: lib_path,
525                inner: e,
526            })?;
527
528        Ok(Self { dxc_lib })
529    }
530
531    pub(crate) fn get_dxc_create_instance<T>(&self) -> Result<Symbol<DxcCreateInstanceProc<T>>> {
532        Ok(unsafe { self.dxc_lib.get(b"DxcCreateInstance\0")? })
533    }
534
535    pub fn create_compiler(&self) -> Result<DxcCompiler> {
536        let mut compiler = None;
537
538        self.get_dxc_create_instance()?(&CLSID_DxcCompiler, &IDxcCompiler2::IID, &mut compiler)
539            .result()?;
540        Ok(DxcCompiler::new(
541            compiler.unwrap(),
542            self.create_library().unwrap(),
543        ))
544    }
545
546    pub fn create_library(&self) -> Result<DxcLibrary> {
547        let mut library = None;
548        self.get_dxc_create_instance()?(&CLSID_DxcLibrary, &IDxcLibrary::IID, &mut library)
549            .result()?;
550        Ok(DxcLibrary::new(library.unwrap()))
551    }
552
553    pub fn create_reflector(&self) -> Result<DxcReflector> {
554        let mut reflector = None;
555
556        self.get_dxc_create_instance()?(
557            &CLSID_DxcContainerReflection,
558            &IDxcContainerReflection::IID,
559            &mut reflector,
560        )
561        .result()?;
562        Ok(DxcReflector::new(reflector.unwrap()))
563    }
564}
565
566pub struct DxcValidator {
567    inner: IDxcValidator,
568}
569
570pub type DxcValidatorVersion = (u32, u32);
571
572impl DxcValidator {
573    fn new(inner: IDxcValidator) -> Self {
574        Self { inner }
575    }
576
577    pub fn version(&self) -> Result<DxcValidatorVersion> {
578        let version = self
579            .inner
580            .query_interface::<IDxcVersionInfo>()
581            .ok_or(HassleError::Win32Error(HRESULT(com::sys::E_NOINTERFACE)))?;
582
583        let mut major = 0;
584        let mut minor = 0;
585
586        unsafe { version.get_version(&mut major, &mut minor) }.result_with_success((major, minor))
587    }
588
589    pub fn validate(&self, blob: DxcBlob) -> Result<DxcBlob, (DxcOperationResult, HassleError)> {
590        let mut result = None;
591        let result_hr = unsafe {
592            self.inner
593                .validate(&blob.inner, DXC_VALIDATOR_FLAGS_IN_PLACE_EDIT, &mut result)
594        };
595
596        let result = result.unwrap();
597
598        let mut validate_status = 0u32;
599        let status_hr = unsafe { result.get_status(&mut validate_status) };
600
601        if !result_hr.is_err() && !status_hr.is_err() && validate_status == 0 {
602            Ok(blob)
603        } else {
604            Err((
605                DxcOperationResult::new(result),
606                HassleError::Win32Error(result_hr),
607            ))
608        }
609    }
610}
611
612pub struct Reflection {
613    inner: ID3D12ShaderReflection,
614}
615impl Reflection {
616    fn new(inner: ID3D12ShaderReflection) -> Self {
617        Self { inner }
618    }
619
620    pub fn thread_group_size(&self) -> [u32; 3] {
621        let (mut size_x, mut size_y, mut size_z) = (0u32, 0u32, 0u32);
622        unsafe {
623            self.inner
624                .get_thread_group_size(&mut size_x, &mut size_y, &mut size_z)
625        };
626        [size_x, size_y, size_z]
627    }
628}
629
630pub struct DxcReflector {
631    inner: IDxcContainerReflection,
632}
633impl DxcReflector {
634    fn new(inner: IDxcContainerReflection) -> Self {
635        Self { inner }
636    }
637
638    pub fn reflect(&self, blob: DxcBlob) -> Result<Reflection> {
639        let result_hr = unsafe { self.inner.load(blob.inner) };
640        if result_hr.is_err() {
641            return Err(HassleError::Win32Error(result_hr));
642        }
643
644        let mut shader_idx = 0;
645        let result_hr = unsafe { self.inner.find_first_part_kind(DFCC_DXIL, &mut shader_idx) };
646        if result_hr.is_err() {
647            return Err(HassleError::Win32Error(result_hr));
648        }
649
650        let mut reflection = None::<IUnknown>;
651        let result_hr = unsafe {
652            self.inner.get_part_reflection(
653                shader_idx,
654                &ID3D12ShaderReflection::IID,
655                &mut reflection,
656            )
657        };
658        if result_hr.is_err() {
659            return Err(HassleError::Win32Error(result_hr));
660        }
661
662        Ok(Reflection::new(
663            reflection.unwrap().query_interface().unwrap(),
664        ))
665    }
666}
667
668#[derive(Debug)]
669pub struct Dxil {
670    dxil_lib: Library,
671}
672
673impl Dxil {
674    #[cfg(not(windows))]
675    pub fn new(_: Option<PathBuf>) -> Result<Self> {
676        Err(HassleError::WindowsOnly(
677            "DXIL Signing is only supported on Windows".to_string(),
678        ))
679    }
680
681    /// `dxil_path` can point to a library directly or the directory containing the library,
682    /// in which case `dxil.dll` is appended.
683    #[cfg(windows)]
684    pub fn new(lib_path: Option<PathBuf>) -> Result<Self> {
685        let lib_path = if let Some(lib_path) = lib_path {
686            if lib_path.is_file() {
687                lib_path
688            } else {
689                lib_path.join("dxil.dll")
690            }
691        } else {
692            PathBuf::from("dxil.dll")
693        };
694
695        let dxil_lib =
696            unsafe { Library::new(&lib_path) }.map_err(|e| HassleError::LoadLibraryError {
697                filename: lib_path.to_owned(),
698                inner: e,
699            })?;
700
701        Ok(Self { dxil_lib })
702    }
703
704    fn get_dxc_create_instance<T>(&self) -> Result<Symbol<DxcCreateInstanceProc<T>>> {
705        Ok(unsafe { self.dxil_lib.get(b"DxcCreateInstance\0")? })
706    }
707
708    pub fn create_validator(&self) -> Result<DxcValidator> {
709        let mut validator = None;
710        self.get_dxc_create_instance()?(&CLSID_DxcValidator, &IDxcValidator::IID, &mut validator)
711            .result()?;
712        Ok(DxcValidator::new(validator.unwrap()))
713    }
714}