hassle_rs/
wrapper.rs

1#![allow(
2    clippy::too_many_arguments,
3    clippy::new_without_default,
4    clippy::type_complexity
5)]
6
7use crate::ffi::*;
8use crate::os::{HRESULT, LPCWSTR, LPWSTR, WCHAR};
9use crate::utils::{from_wide, to_wide, HassleError, Result};
10use com::{class, interfaces::IUnknown, production::Class, production::ClassAllocation, Interface};
11use libloading::{library_filename, Library, Symbol};
12use std::cell::RefCell;
13use std::ops::Deref;
14use std::path::PathBuf;
15use std::pin::Pin;
16
17pub struct DxcBlob {
18    inner: IDxcBlob,
19}
20
21impl DxcBlob {
22    fn new(inner: IDxcBlob) -> Self {
23        Self { inner }
24    }
25
26    pub fn as_slice<T>(&self) -> &[T] {
27        unsafe {
28            std::slice::from_raw_parts(
29                self.inner.get_buffer_pointer().cast(),
30                self.inner.get_buffer_size() / std::mem::size_of::<T>(),
31            )
32        }
33    }
34
35    pub fn as_mut_slice<T>(&mut self) -> &mut [T] {
36        unsafe {
37            std::slice::from_raw_parts_mut(
38                self.inner.get_buffer_pointer().cast(),
39                self.inner.get_buffer_size() / std::mem::size_of::<T>(),
40            )
41        }
42    }
43
44    pub fn to_vec<T>(&self) -> Vec<T>
45    where
46        T: Clone,
47    {
48        self.as_slice().to_vec()
49    }
50}
51
52impl AsRef<[u8]> for DxcBlob {
53    fn as_ref(&self) -> &[u8] {
54        self.as_slice()
55    }
56}
57
58impl AsMut<[u8]> for DxcBlob {
59    fn as_mut(&mut self) -> &mut [u8] {
60        self.as_mut_slice()
61    }
62}
63
64pub struct DxcBlobEncoding {
65    inner: IDxcBlobEncoding,
66}
67
68impl DxcBlobEncoding {
69    fn new(inner: IDxcBlobEncoding) -> Self {
70        Self { inner }
71    }
72}
73
74impl From<DxcBlobEncoding> for DxcBlob {
75    fn from(encoded_blob: DxcBlobEncoding) -> Self {
76        DxcBlob::new(encoded_blob.inner.query_interface::<IDxcBlob>().unwrap())
77    }
78}
79
80pub struct DxcOperationResult {
81    inner: IDxcOperationResult,
82}
83
84impl DxcOperationResult {
85    fn new(inner: IDxcOperationResult) -> Self {
86        Self { inner }
87    }
88
89    pub fn get_status(&self) -> Result<u32> {
90        let mut status: u32 = 0;
91        unsafe { self.inner.get_status(&mut status) }.result_with_success(status)
92    }
93
94    pub fn get_result(&self) -> Result<DxcBlob> {
95        let mut blob = None;
96        unsafe { self.inner.get_result(&mut blob) }.result()?;
97        Ok(DxcBlob::new(blob.unwrap()))
98    }
99
100    pub fn get_error_buffer(&self) -> Result<DxcBlobEncoding> {
101        let mut blob = None;
102
103        unsafe { self.inner.get_error_buffer(&mut blob) }.result()?;
104        Ok(DxcBlobEncoding::new(blob.unwrap()))
105    }
106}
107
/// Callback used to resolve include files while compiling or preprocessing.
///
/// Pass an implementation to the [`DxcCompiler`] methods that accept an
/// `include_handler`.
pub trait DxcIncludeHandler {
    /// Returns the contents of the file called `filename`, or `None` when it
    /// cannot be provided (reported to the compiler as "file not found").
    fn load_source(&mut self, filename: String) -> Option<String>;
}
111
112class! {
113    #[no_class_factory]
114    class DxcIncludeHandlerWrapper: IDxcIncludeHandler {
115        // Com-rs intentionally does not support lifetimes in its class structs
116        // since they live on the heap and their lifetime can be prolonged for
117        // as long as someone keeps a reference through `add_ref()`.
118        // The only way for us to access the library and handler implementation,
119        // which are now intentionally behind a borrow to signify our promise
120        // regarding lifetime, is by transmuting them away and "ensuring" the
121        // class object is discarded at the end of our function call.
122
123        library: &'static DxcLibrary,
124        handler: RefCell<&'static mut dyn DxcIncludeHandler>,
125
126        pinned: RefCell<Vec<Pin<String>>>,
127    }
128
129    impl IDxcIncludeHandler for DxcIncludeHandlerWrapper {
130        fn load_source(&self, filename: LPCWSTR, include_source: *mut Option<IDxcBlob>) -> HRESULT {
131            let filename = crate::utils::from_wide(filename);
132
133            let mut handler = self.handler.borrow_mut();
134            let source = handler.load_source(filename);
135
136            if let Some(source) = source {
137                let source = Pin::new(source);
138                let blob = self.library
139                    .create_blob_with_encoding_from_str(&source)
140                    .unwrap();
141
142                unsafe { *include_source = Some(DxcBlob::from(blob).inner) };
143                self.pinned.borrow_mut().push(source);
144
145                // NOERROR
146                0
147            } else {
148                -2_147_024_894 // ERROR_FILE_NOT_FOUND / 0x80070002
149            }
150            .into()
151        }
152    }
153}
154
155/// Represents a reference to a COM object that should only live as long as itself
156///
157/// In other words, on [`drop()`] we assert that the refcount is decremented to zero,
158/// rather than allowing it to be referenced externally (i.e. [`Class::dec_ref_count()`]
159/// returning `> 0`).
160/// This object functions a lot like [`ClassAllocation`]: see its similar [`drop()`]
161/// implementation for details.
162///
163/// Note that COM objects live on the heap by design, because of this refcount system.
164struct LocalClassAllocation<T: Class>(core::pin::Pin<Box<T>>);
165
166impl<T: Class> LocalClassAllocation<T> {
167    fn new(allocation: ClassAllocation<T>) -> Self {
168        // TODO: There is no way to take the internal, owned box out of com-rs's
169        // allocation wrapper.
170        // https://github.com/microsoft/com-rs/issues/236 covers this issue as a whole,
171        // including lifetime support and this `LocalClassAllocation` upstream.
172        let inner: core::mem::ManuallyDrop<core::pin::Pin<Box<T>>> =
173            unsafe { std::mem::transmute(allocation) };
174
175        Self(core::mem::ManuallyDrop::into_inner(inner))
176    }
177
178    // TODO: Return a borrow of this interface?
179    // query_interface() is not behind one of the traits
180    // fn query_interface<T>(&self) -> Option<T> {
181    //     self.0.query_interface::<T>().unwrap()
182    // }
183}
184
185impl<T: Class> Deref for LocalClassAllocation<T> {
186    type Target = core::pin::Pin<Box<T>>;
187
188    fn deref(&self) -> &Self::Target {
189        &self.0
190    }
191}
192
193impl<T: Class> Drop for LocalClassAllocation<T> {
194    fn drop(&mut self) {
195        // Check if we are the only remaining reference to this object
196        assert_eq!(
197            unsafe { self.0.dec_ref_count() },
198            0,
199            "COM object is still referenced"
200        );
201        // Now that we're the last one to give up our refcount, it is safe
202        // for the internal object to get dropped.
203    }
204}
205
206impl DxcIncludeHandlerWrapper {
207    /// SAFETY: Make sure the returned object does _not_ outlive the lifetime
208    /// of either `library` nor `include_handler`
209    unsafe fn create_include_handler(
210        library: &'_ DxcLibrary,
211        include_handler: &'_ mut dyn DxcIncludeHandler,
212    ) -> LocalClassAllocation<DxcIncludeHandlerWrapper> {
213        #[allow(clippy::missing_transmute_annotations)]
214        LocalClassAllocation::new(Self::allocate(
215            std::mem::transmute(library),
216            RefCell::new(std::mem::transmute(include_handler)),
217            RefCell::new(vec![]),
218        ))
219    }
220}
221
222pub struct DxcCompiler {
223    inner: IDxcCompiler2,
224    library: DxcLibrary,
225}
226
227impl DxcCompiler {
228    fn new(inner: IDxcCompiler2, library: DxcLibrary) -> Self {
229        Self { inner, library }
230    }
231
232    fn prep_defines(
233        defines: &[(&str, Option<&str>)],
234        wide_defines: &mut Vec<(Vec<WCHAR>, Vec<WCHAR>)>,
235        dxc_defines: &mut Vec<DxcDefine>,
236    ) {
237        for (name, value) in defines {
238            if value.is_none() {
239                wide_defines.push((to_wide(name), to_wide("1")));
240            } else {
241                wide_defines.push((to_wide(name), to_wide(value.unwrap())));
242            }
243        }
244
245        for (ref name, ref value) in wide_defines {
246            dxc_defines.push(DxcDefine {
247                name: name.as_ptr(),
248                value: value.as_ptr(),
249            });
250        }
251    }
252
253    fn prep_args(args: &[&str], wide_args: &mut Vec<Vec<WCHAR>>, dxc_args: &mut Vec<LPCWSTR>) {
254        for a in args {
255            wide_args.push(to_wide(a));
256        }
257
258        for a in wide_args {
259            dxc_args.push(a.as_ptr());
260        }
261    }
262
263    pub fn compile(
264        &self,
265        blob: &DxcBlobEncoding,
266        source_name: &str,
267        entry_point: &str,
268        target_profile: &str,
269        args: &[&str],
270        include_handler: Option<&mut dyn DxcIncludeHandler>,
271        defines: &[(&str, Option<&str>)],
272    ) -> Result<DxcOperationResult, (DxcOperationResult, HRESULT)> {
273        let mut wide_args = vec![];
274        let mut dxc_args = vec![];
275        Self::prep_args(args, &mut wide_args, &mut dxc_args);
276
277        let mut wide_defines = vec![];
278        let mut dxc_defines = vec![];
279        Self::prep_defines(defines, &mut wide_defines, &mut dxc_defines);
280
281        // Keep alive on the stack
282        let include_handler = include_handler.map(|include_handler| unsafe {
283            DxcIncludeHandlerWrapper::create_include_handler(&self.library, include_handler)
284        });
285        // TODO: query_interface() should have a borrow on LocalClassAllocation to prevent things going kaboom
286        let include_handler = include_handler
287            .as_ref()
288            .map(|i| i.query_interface().unwrap());
289
290        let mut result = None;
291        let result_hr = unsafe {
292            self.inner.compile(
293                &blob.inner,
294                to_wide(source_name).as_ptr(),
295                to_wide(entry_point).as_ptr(),
296                to_wide(target_profile).as_ptr(),
297                dxc_args.as_ptr(),
298                dxc_args.len() as u32,
299                dxc_defines.as_ptr(),
300                dxc_defines.len() as u32,
301                &include_handler,
302                &mut result,
303            )
304        };
305
306        let result = result.unwrap();
307
308        let mut compile_error = 0u32;
309        let status_hr = unsafe { result.get_status(&mut compile_error) };
310
311        if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
312            Ok(DxcOperationResult::new(result))
313        } else {
314            Err((DxcOperationResult::new(result), result_hr))
315        }
316    }
317
318    pub fn compile_with_debug(
319        &self,
320        blob: &DxcBlobEncoding,
321        source_name: &str,
322        entry_point: &str,
323        target_profile: &str,
324        args: &[&str],
325        include_handler: Option<&mut dyn DxcIncludeHandler>,
326        defines: &[(&str, Option<&str>)],
327    ) -> Result<(DxcOperationResult, String, DxcBlob), (DxcOperationResult, HRESULT)> {
328        let mut wide_args = vec![];
329        let mut dxc_args = vec![];
330        Self::prep_args(args, &mut wide_args, &mut dxc_args);
331
332        let mut wide_defines = vec![];
333        let mut dxc_defines = vec![];
334        Self::prep_defines(defines, &mut wide_defines, &mut dxc_defines);
335
336        // Keep alive on the stack
337        let include_handler = include_handler.map(|include_handler| unsafe {
338            DxcIncludeHandlerWrapper::create_include_handler(&self.library, include_handler)
339        });
340        let include_handler = include_handler
341            .as_ref()
342            .map(|i| i.query_interface().unwrap());
343
344        let mut result = None;
345        let mut debug_blob = None;
346        let mut debug_filename: LPWSTR = std::ptr::null_mut();
347
348        let result_hr = unsafe {
349            self.inner.compile_with_debug(
350                &blob.inner,
351                to_wide(source_name).as_ptr(),
352                to_wide(entry_point).as_ptr(),
353                to_wide(target_profile).as_ptr(),
354                dxc_args.as_ptr(),
355                dxc_args.len() as u32,
356                dxc_defines.as_ptr(),
357                dxc_defines.len() as u32,
358                include_handler,
359                &mut result,
360                &mut debug_filename,
361                &mut debug_blob,
362            )
363        };
364        let result = result.unwrap();
365        let debug_blob = debug_blob.unwrap();
366
367        let mut compile_error = 0u32;
368        let status_hr = unsafe { result.get_status(&mut compile_error) };
369
370        if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
371            Ok((
372                DxcOperationResult::new(result),
373                from_wide(debug_filename),
374                DxcBlob::new(debug_blob),
375            ))
376        } else {
377            Err((DxcOperationResult::new(result), result_hr))
378        }
379    }
380
381    pub fn preprocess(
382        &self,
383        blob: &DxcBlobEncoding,
384        source_name: &str,
385        args: &[&str],
386        include_handler: Option<&mut dyn DxcIncludeHandler>,
387        defines: &[(&str, Option<&str>)],
388    ) -> Result<DxcOperationResult, (DxcOperationResult, HRESULT)> {
389        let mut wide_args = vec![];
390        let mut dxc_args = vec![];
391        Self::prep_args(args, &mut wide_args, &mut dxc_args);
392
393        let mut wide_defines = vec![];
394        let mut dxc_defines = vec![];
395        Self::prep_defines(defines, &mut wide_defines, &mut dxc_defines);
396
397        // Keep alive on the stack
398        let include_handler = include_handler.map(|include_handler| unsafe {
399            DxcIncludeHandlerWrapper::create_include_handler(&self.library, include_handler)
400        });
401        let include_handler = include_handler
402            .as_ref()
403            .map(|i| i.query_interface().unwrap());
404
405        let mut result = None;
406        let result_hr = unsafe {
407            self.inner.preprocess(
408                &blob.inner,
409                to_wide(source_name).as_ptr(),
410                dxc_args.as_ptr(),
411                dxc_args.len() as u32,
412                dxc_defines.as_ptr(),
413                dxc_defines.len() as u32,
414                include_handler,
415                &mut result,
416            )
417        };
418
419        let result = result.unwrap();
420
421        let mut compile_error = 0u32;
422        let status_hr = unsafe { result.get_status(&mut compile_error) };
423
424        if !result_hr.is_err() && !status_hr.is_err() && compile_error == 0 {
425            Ok(DxcOperationResult::new(result))
426        } else {
427            Err((DxcOperationResult::new(result), result_hr))
428        }
429    }
430
431    pub fn disassemble(&self, blob: &DxcBlob) -> Result<DxcBlobEncoding> {
432        let mut result_blob = None;
433        unsafe { self.inner.disassemble(&blob.inner, &mut result_blob) }.result()?;
434        Ok(DxcBlobEncoding::new(result_blob.unwrap()))
435    }
436}
437
438#[derive(Clone)]
439pub struct DxcLibrary {
440    inner: IDxcLibrary,
441}
442
443impl DxcLibrary {
444    fn new(inner: IDxcLibrary) -> Self {
445        Self { inner }
446    }
447
448    pub fn create_blob_with_encoding(&self, data: &[u8]) -> Result<DxcBlobEncoding> {
449        let mut blob = None;
450
451        unsafe {
452            self.inner.create_blob_with_encoding_from_pinned(
453                data.as_ptr().cast(),
454                data.len() as u32,
455                0, // Binary; no code page
456                &mut blob,
457            )
458        }
459        .result()?;
460        Ok(DxcBlobEncoding::new(blob.unwrap()))
461    }
462
463    pub fn create_blob_with_encoding_from_str(&self, text: &str) -> Result<DxcBlobEncoding> {
464        let mut blob = None;
465        const CP_UTF8: u32 = 65001; // UTF-8 translation
466
467        unsafe {
468            self.inner.create_blob_with_encoding_from_pinned(
469                text.as_ptr().cast(),
470                text.len() as u32,
471                CP_UTF8,
472                &mut blob,
473            )
474        }
475        .result()?;
476        Ok(DxcBlobEncoding::new(blob.unwrap()))
477    }
478
479    pub fn get_blob_as_string(&self, blob: &DxcBlob) -> Result<String> {
480        let mut blob_utf8 = None;
481
482        unsafe { self.inner.get_blob_as_utf8(&blob.inner, &mut blob_utf8) }.result()?;
483
484        let blob_utf8 = blob_utf8.unwrap();
485
486        Ok(String::from_utf8(DxcBlob::new(blob_utf8.query_interface().unwrap()).to_vec()).unwrap())
487    }
488}
489
490#[derive(Debug)]
491pub struct Dxc {
492    dxc_lib: Library,
493}
494
495impl Dxc {
496    /// `lib_path` is an optional path to the library.  Otherwise
497    /// [`libloading::library_filename("dxcompiler")`] is used.
498    pub fn new(lib_path: Option<PathBuf>) -> Result<Self> {
499        let lib_path = lib_path.unwrap_or_else(|| PathBuf::from(library_filename("dxcompiler")));
500
501        let dxc_lib =
502            unsafe { Library::new(&lib_path) }.map_err(|e| HassleError::LoadLibraryError {
503                filename: lib_path,
504                inner: e,
505            })?;
506
507        Ok(Self { dxc_lib })
508    }
509
510    pub(crate) fn get_dxc_create_instance<T>(&self) -> Result<Symbol<DxcCreateInstanceProc<T>>> {
511        Ok(unsafe { self.dxc_lib.get(b"DxcCreateInstance\0")? })
512    }
513
514    pub fn create_compiler(&self) -> Result<DxcCompiler> {
515        let mut compiler = None;
516
517        self.get_dxc_create_instance()?(&CLSID_DxcCompiler, &IDxcCompiler2::IID, &mut compiler)
518            .result()?;
519        Ok(DxcCompiler::new(
520            compiler.unwrap(),
521            self.create_library().unwrap(),
522        ))
523    }
524
525    pub fn create_library(&self) -> Result<DxcLibrary> {
526        let mut library = None;
527        self.get_dxc_create_instance()?(&CLSID_DxcLibrary, &IDxcLibrary::IID, &mut library)
528            .result()?;
529        Ok(DxcLibrary::new(library.unwrap()))
530    }
531
532    pub fn create_reflector(&self) -> Result<DxcReflector> {
533        let mut reflector = None;
534
535        self.get_dxc_create_instance()?(
536            &CLSID_DxcContainerReflection,
537            &IDxcContainerReflection::IID,
538            &mut reflector,
539        )
540        .result()?;
541        Ok(DxcReflector::new(reflector.unwrap()))
542    }
543}
544
545pub struct DxcValidator {
546    inner: IDxcValidator,
547}
548
549pub type DxcValidatorVersion = (u32, u32);
550
551impl DxcValidator {
552    fn new(inner: IDxcValidator) -> Self {
553        Self { inner }
554    }
555
556    pub fn version(&self) -> Result<DxcValidatorVersion> {
557        let version = self
558            .inner
559            .query_interface::<IDxcVersionInfo>()
560            .ok_or(HassleError::Win32Error(HRESULT(com::sys::E_NOINTERFACE)))?;
561
562        let mut major = 0;
563        let mut minor = 0;
564
565        unsafe { version.get_version(&mut major, &mut minor) }.result_with_success((major, minor))
566    }
567
568    pub fn validate(&self, blob: DxcBlob) -> Result<DxcBlob, (DxcOperationResult, HassleError)> {
569        let mut result = None;
570        let result_hr = unsafe {
571            self.inner
572                .validate(&blob.inner, DXC_VALIDATOR_FLAGS_IN_PLACE_EDIT, &mut result)
573        };
574
575        let result = result.unwrap();
576
577        let mut validate_status = 0u32;
578        let status_hr = unsafe { result.get_status(&mut validate_status) };
579
580        if !result_hr.is_err() && !status_hr.is_err() && validate_status == 0 {
581            Ok(blob)
582        } else {
583            Err((
584                DxcOperationResult::new(result),
585                HassleError::Win32Error(result_hr),
586            ))
587        }
588    }
589}
590
591pub struct Reflection {
592    inner: ID3D12ShaderReflection,
593}
594impl Reflection {
595    fn new(inner: ID3D12ShaderReflection) -> Self {
596        Self { inner }
597    }
598
599    pub fn thread_group_size(&self) -> [u32; 3] {
600        let (mut size_x, mut size_y, mut size_z) = (0u32, 0u32, 0u32);
601        unsafe {
602            self.inner
603                .get_thread_group_size(&mut size_x, &mut size_y, &mut size_z)
604        };
605        [size_x, size_y, size_z]
606    }
607}
608
609pub struct DxcReflector {
610    inner: IDxcContainerReflection,
611}
612impl DxcReflector {
613    fn new(inner: IDxcContainerReflection) -> Self {
614        Self { inner }
615    }
616
617    pub fn reflect(&self, blob: DxcBlob) -> Result<Reflection> {
618        let result_hr = unsafe { self.inner.load(blob.inner) };
619        if result_hr.is_err() {
620            return Err(HassleError::Win32Error(result_hr));
621        }
622
623        let mut shader_idx = 0;
624        let result_hr = unsafe { self.inner.find_first_part_kind(DFCC_DXIL, &mut shader_idx) };
625        if result_hr.is_err() {
626            return Err(HassleError::Win32Error(result_hr));
627        }
628
629        let mut reflection = None::<IUnknown>;
630        let result_hr = unsafe {
631            self.inner.get_part_reflection(
632                shader_idx,
633                &ID3D12ShaderReflection::IID,
634                &mut reflection,
635            )
636        };
637        if result_hr.is_err() {
638            return Err(HassleError::Win32Error(result_hr));
639        }
640
641        Ok(Reflection::new(
642            reflection.unwrap().query_interface().unwrap(),
643        ))
644    }
645}
646
647#[derive(Debug)]
648pub struct Dxil {
649    dxil_lib: Library,
650}
651
652impl Dxil {
653    /// `lib_path` is an optional path to the library.  Otherwise
654    /// [`libloading::library_filename("dxil")`] is used.
655    pub fn new(lib_path: Option<PathBuf>) -> Result<Self> {
656        let lib_path = lib_path.unwrap_or_else(|| PathBuf::from(library_filename("dxil")));
657
658        let dxil_lib =
659            unsafe { Library::new(&lib_path) }.map_err(|e| HassleError::LoadLibraryError {
660                filename: lib_path.to_owned(),
661                inner: e,
662            })?;
663
664        Ok(Self { dxil_lib })
665    }
666
667    fn get_dxc_create_instance<T>(&self) -> Result<Symbol<DxcCreateInstanceProc<T>>> {
668        Ok(unsafe { self.dxil_lib.get(b"DxcCreateInstance\0")? })
669    }
670
671    pub fn create_validator(&self) -> Result<DxcValidator> {
672        let mut validator = None;
673        self.get_dxc_create_instance()?(&CLSID_DxcValidator, &IDxcValidator::IID, &mut validator)
674            .result()?;
675        Ok(DxcValidator::new(validator.unwrap()))
676    }
677}