hassle_rs/
utils.rs

1use std::ffi::CStr;
2use std::path::PathBuf;
3
4use crate::os::{SysFreeString, SysStringLen, BSTR, HRESULT, LPCSTR, LPCWSTR, WCHAR};
5use crate::wrapper::*;
6use thiserror::Error;
7
8pub(crate) fn to_wide(msg: &str) -> Vec<WCHAR> {
9    widestring::WideCString::from_str(msg)
10        .unwrap()
11        .into_vec_with_nul()
12}
13
14pub(crate) fn from_wide(wide: LPCWSTR) -> String {
15    unsafe { widestring::WideCStr::from_ptr_str(wide) }
16        .to_string()
17        .expect("widestring decode failed")
18}
19
20pub(crate) fn from_bstr(string: BSTR) -> String {
21    let len = unsafe { SysStringLen(string) } as usize;
22
23    let result = unsafe { widestring::WideStr::from_ptr(string, len) }
24        .to_string()
25        .expect("widestring decode failed");
26
27    unsafe { SysFreeString(string) };
28    result
29}
30
31pub(crate) fn from_lpstr(string: LPCSTR) -> String {
32    unsafe { CStr::from_ptr(string) }
33        .to_str()
34        .unwrap()
35        .to_owned()
36}
37
38struct DefaultIncludeHandler {}
39
40impl DxcIncludeHandler for DefaultIncludeHandler {
41    fn load_source(&mut self, filename: String) -> Option<String> {
42        use std::io::Read;
43        match std::fs::File::open(filename) {
44            Ok(mut f) => {
45                let mut content = String::new();
46                f.read_to_string(&mut content).ok()?;
47                Some(content)
48            }
49            Err(_) => None,
50        }
51    }
52}
53
54#[derive(Error, Debug)]
55pub enum HassleError {
56    #[error("Win32 error: {0:x}")]
57    Win32Error(HRESULT),
58    #[error("{0}")]
59    CompileError(String),
60    #[error("Validation error: {0}")]
61    ValidationError(String),
62    #[error("Failed to load library {filename:?}: {inner:?}")]
63    LoadLibraryError {
64        filename: PathBuf,
65        #[source]
66        inner: libloading::Error,
67    },
68    #[error("LibLoading error: {0:?}")]
69    LibLoadingError(#[from] libloading::Error),
70    #[error("Windows only")]
71    WindowsOnly(String),
72}
73
74pub type Result<T, E = HassleError> = std::result::Result<T, E>;
75
76impl HRESULT {
77    /// Turns an [`HRESULT`] from the COM [`crate::ffi`] API declaration
78    /// into a [`Result`] containing [`HassleError`].
79    pub fn result(self) -> Result<()> {
80        self.result_with_success(())
81    }
82
83    /// Turns an [`HRESULT`] from the COM [`crate::ffi`] API declaration
84    /// into a [`Result`] containing [`HassleError`], with the desired value.
85    ///
86    /// Note that `v` is passed by value and is not a closure that is executed
87    /// lazily.  Use the short-circuiting `?` operator for such cases:
88    /// ```no_run
89    /// let mut blob: ComPtr<IDxcBlob> = ComPtr::new();
90    /// unsafe { self.inner.get_result(blob.as_mut_ptr()) }.result()?;
91    /// Ok(DxcBlob::new(blob))
92    /// ```
93    pub fn result_with_success<T>(self, v: T) -> Result<T> {
94        if self.is_err() {
95            Err(HassleError::Win32Error(self))
96        } else {
97            Ok(v)
98        }
99    }
100}
101
102/// Helper function to directly compile a HLSL shader to an intermediate language,
103/// this function expects `dxcompiler.dll` to be available in the current
104/// executable environment.
105///
106/// Specify -spirv as one of the `args` to compile to SPIR-V
107/// `dxc_path` can point to a library directly or the directory containing the library,
108/// in which case the appended filename depends on the platform.
109pub fn compile_hlsl(
110    source_name: &str,
111    shader_text: &str,
112    entry_point: &str,
113    target_profile: &str,
114    args: &[&str],
115    defines: &[(&str, Option<&str>)],
116) -> Result<Vec<u8>> {
117    let dxc = Dxc::new(None)?;
118
119    let compiler = dxc.create_compiler()?;
120    let library = dxc.create_library()?;
121
122    let blob = library.create_blob_with_encoding_from_str(shader_text)?;
123
124    let result = compiler.compile(
125        &blob,
126        source_name,
127        entry_point,
128        target_profile,
129        args,
130        Some(&mut DefaultIncludeHandler {}),
131        defines,
132    );
133
134    match result {
135        Err(result) => {
136            let error_blob = result.0.get_error_buffer()?;
137            Err(HassleError::CompileError(
138                library.get_blob_as_string(&error_blob.into())?,
139            ))
140        }
141        Ok(result) => {
142            let result_blob = result.get_result()?;
143
144            Ok(result_blob.to_vec())
145        }
146    }
147}
148
149/// Helper function to validate a DXIL binary independent from the compilation process,
150/// this function expects `dxcompiler.dll` and `dxil.dll` to be available in the current
151/// execution environment.
152///
153/// `dxil.dll` is only available on Windows.
154pub fn validate_dxil(data: &[u8]) -> Result<Vec<u8>, HassleError> {
155    let dxc = Dxc::new(None)?;
156    let dxil = Dxil::new(None)?;
157
158    let validator = dxil.create_validator()?;
159    let library = dxc.create_library()?;
160
161    let blob_encoding = library.create_blob_with_encoding(data)?;
162
163    match validator.validate(blob_encoding.into()) {
164        Ok(blob) => Ok(blob.to_vec()),
165        Err(result) => {
166            let error_blob = result.0.get_error_buffer()?;
167            Err(HassleError::ValidationError(
168                library.get_blob_as_string(&error_blob.into())?,
169            ))
170        }
171    }
172}
173
174pub use crate::fake_sign::fake_sign_dxil_in_place;