hassle_rs/
utils.rs

1use std::ffi::CStr;
2use std::path::PathBuf;
3
4use crate::os::{SysFreeString, SysStringLen, BSTR, HRESULT, LPCSTR, LPCWSTR, WCHAR};
5use crate::wrapper::*;
6use thiserror::Error;
7
8pub(crate) fn to_wide(msg: &str) -> Vec<WCHAR> {
9    widestring::WideCString::from_str(msg)
10        .unwrap()
11        .into_vec_with_nul()
12}
13
14pub(crate) fn from_wide(wide: LPCWSTR) -> String {
15    unsafe { widestring::WideCStr::from_ptr_str(wide) }
16        .to_string()
17        .expect("widestring decode failed")
18}
19
20pub(crate) fn from_bstr(string: BSTR) -> String {
21    let len = unsafe { SysStringLen(string) } as usize;
22
23    let result = unsafe { widestring::WideStr::from_ptr(string, len) }
24        .to_string()
25        .expect("widestring decode failed");
26
27    unsafe { SysFreeString(string) };
28    result
29}
30
31pub(crate) fn from_lpstr(string: LPCSTR) -> String {
32    unsafe { CStr::from_ptr(string) }
33        .to_str()
34        .unwrap()
35        .to_owned()
36}
37
38struct DefaultIncludeHandler {}
39
40impl DxcIncludeHandler for DefaultIncludeHandler {
41    fn load_source(&mut self, filename: String) -> Option<String> {
42        use std::io::Read;
43        match std::fs::File::open(filename) {
44            Ok(mut f) => {
45                let mut content = String::new();
46                f.read_to_string(&mut content).ok()?;
47                Some(content)
48            }
49            Err(_) => None,
50        }
51    }
52}
53
/// Error type returned by the helpers in this module; it is the default
/// error of the [`Result`] alias below.
#[derive(Error, Debug)]
pub enum HassleError {
    /// A failing [`HRESULT`], as produced by `HRESULT::result` and
    /// `HRESULT::result_with_success`.  Displayed in lowercase hexadecimal.
    #[error("Win32 error: {0:x}")]
    Win32Error(HRESULT),
    /// Compilation failed; the payload is the error-buffer text and is
    /// displayed verbatim (see [`compile_hlsl`]).
    #[error("{0}")]
    CompileError(String),
    /// Validation failed; the payload is the error-buffer text (see
    /// [`validate_dxil`]).
    #[error("Validation error: {0}")]
    ValidationError(String),
    /// A library could not be loaded from a specific path.
    #[error("Failed to load library {filename:?}: {inner:?}")]
    LoadLibraryError {
        /// Path of the library that failed to load.
        filename: PathBuf,
        /// Underlying `libloading` error, exposed as this error's source.
        #[source]
        inner: libloading::Error,
    },
    /// Any other `libloading` error; `#[from]` lets `?` convert it
    /// automatically.
    #[error("LibLoading error: {0:?}")]
    LibLoadingError(#[from] libloading::Error),
}
71
/// Shorthand for [`std::result::Result`] whose error type defaults to
/// [`HassleError`]; the error type can still be overridden explicitly.
pub type Result<T, E = HassleError> = std::result::Result<T, E>;
73
74impl HRESULT {
75    /// Turns an [`HRESULT`] from the COM [`crate::ffi`] API declaration
76    /// into a [`Result`] containing [`HassleError`].
77    pub fn result(self) -> Result<()> {
78        self.result_with_success(())
79    }
80
81    /// Turns an [`HRESULT`] from the COM [`crate::ffi`] API declaration
82    /// into a [`Result`] containing [`HassleError`], with the desired value.
83    ///
84    /// Note that `v` is passed by value and is not a closure that is executed
85    /// lazily.  Use the short-circuiting `?` operator for such cases:
86    /// ```no_run
87    /// let mut blob: ComPtr<IDxcBlob> = ComPtr::new();
88    /// unsafe { self.inner.get_result(blob.as_mut_ptr()) }.result()?;
89    /// Ok(DxcBlob::new(blob))
90    /// ```
91    pub fn result_with_success<T>(self, v: T) -> Result<T> {
92        if self.is_err() {
93            Err(HassleError::Win32Error(self))
94        } else {
95            Ok(v)
96        }
97    }
98}
99
100/// Helper function to directly compile a HLSL shader to an intermediate language,
101/// this function expects `dxcompiler.dll` to be available in the current
102/// executable environment.
103///
104/// Specify -spirv as one of the `args` to compile to SPIR-V
105/// `dxc_path` can point to a library directly or the directory containing the library,
106/// in which case the appended filename depends on the platform.
107pub fn compile_hlsl(
108    source_name: &str,
109    shader_text: &str,
110    entry_point: &str,
111    target_profile: &str,
112    args: &[&str],
113    defines: &[(&str, Option<&str>)],
114) -> Result<Vec<u8>> {
115    let dxc = Dxc::new(None)?;
116
117    let compiler = dxc.create_compiler()?;
118    let library = dxc.create_library()?;
119
120    let blob = library.create_blob_with_encoding_from_str(shader_text)?;
121
122    let result = compiler.compile(
123        &blob,
124        source_name,
125        entry_point,
126        target_profile,
127        args,
128        Some(&mut DefaultIncludeHandler {}),
129        defines,
130    );
131
132    match result {
133        Err(result) => {
134            let error_blob = result.0.get_error_buffer()?;
135            Err(HassleError::CompileError(
136                library.get_blob_as_string(&error_blob.into())?,
137            ))
138        }
139        Ok(result) => {
140            let result_blob = result.get_result()?;
141
142            Ok(result_blob.to_vec())
143        }
144    }
145}
146
147/// Helper function to validate a DXIL binary independent from the compilation process,
148/// this function expects `dxcompiler.dll` and `dxil.dll` to be available in the current
149/// execution environment.
150///
151/// `dxil.dll` is only available on Windows.
152pub fn validate_dxil(data: &[u8]) -> Result<Vec<u8>, HassleError> {
153    let dxc = Dxc::new(None)?;
154    let dxil = Dxil::new(None)?;
155
156    let validator = dxil.create_validator()?;
157    let library = dxc.create_library()?;
158
159    let blob_encoding = library.create_blob_with_encoding(data)?;
160
161    match validator.validate(blob_encoding.into()) {
162        Ok(blob) => Ok(blob.to_vec()),
163        Err(result) => {
164            let error_blob = result.0.get_error_buffer()?;
165            Err(HassleError::ValidationError(
166                library.get_blob_as_string(&error_blob.into())?,
167            ))
168        }
169    }
170}
171
172pub use crate::fake_sign::fake_sign_dxil_in_place;