clroxide/primitives/
iassembly.rs

use crate::primitives::{
    itype::_Type, IUnknown, IUnknownVtbl, Interface, _MethodInfo, wrap_method_arguments,
    wrap_strings_in_array, GUID, HRESULT,
};
use std::{
    ffi::{c_long, c_void},
    ops::Deref,
    ptr,
};
use windows::{
    core::BSTR,
    Win32::System::{
        Com::{SAFEARRAY, VARIANT, VT_UNKNOWN},
        Ole::{
            SafeArrayCreateVector, SafeArrayDestroy, SafeArrayGetElement, SafeArrayGetLBound,
            SafeArrayGetUBound,
        },
    },
};
17
/// Raw mirror of a COM object implementing the `_Assembly` interface.
///
/// The only field is the vtable pointer, so a `*mut _Assembly` obtained from
/// COM can be dereferenced and its methods dispatched through `_AssemblyVtbl`.
#[repr(C)]
pub struct _Assembly {
    /// Pointer to the interface's function table (first field of every COM object).
    pub vtable: *const _AssemblyVtbl,
}
22
/// Function table for [`_Assembly`].
///
/// COM dispatches by slot position, so the field order here must match the
/// interface definition exactly. Slots this crate calls are typed as
/// `extern "system"` function pointers; every other slot is an opaque
/// `*const c_void` placeholder that exists only to keep later slots aligned.
#[repr(C)]
pub struct _AssemblyVtbl {
    // Base `IUnknown` slots, via the embedded parent vtable.
    pub parent: IUnknownVtbl,
    // `IDispatch` slots (not called by this crate).
    pub GetTypeInfoCount: *const c_void,
    pub GetTypeInfo: *const c_void,
    pub GetIDsOfNames: *const c_void,
    pub Invoke: *const c_void,
    // Writes a BSTR into `pRetVal`; `_Assembly::to_string` wraps it in `BSTR`.
    pub ToString: unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
    pub Equals: *const c_void,
    pub GetHashCode: unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut c_long) -> HRESULT,
    pub GetType: unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut _Type) -> HRESULT,
    pub get_CodeBase:
        unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
    pub get_EscapedCodeBase:
        unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
    // NOTE(review): result left untyped; presumably an `AssemblyName` object pointer — confirm.
    pub GetName: unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut c_void) -> HRESULT,
    pub GetName_2: *const c_void,
    pub get_FullName:
        unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
    // Writes a `_MethodInfo` pointer; used by `_Assembly::get_entrypoint`.
    pub get_EntryPoint:
        unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut _MethodInfo) -> HRESULT,
    // Looks up a type by BSTR `name`; used by `_Assembly::get_type`.
    pub GetType_2: unsafe extern "system" fn(
        this: *mut c_void,
        name: *mut u16,
        pRetVal: *mut *mut _Type,
    ) -> HRESULT,
    pub GetType_3: *const c_void,
    pub GetExportedTypes: *const c_void,
    // Writes a SAFEARRAY of type objects; used by `_Assembly::get_types`.
    pub GetTypes:
        unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut SAFEARRAY) -> HRESULT,
    pub GetManifestResourceStream: *const c_void,
    pub GetManifestResourceStream_2: *const c_void,
    pub GetFile: *const c_void,
    pub GetFiles: *const c_void,
    pub GetFiles_2: *const c_void,
    pub GetManifestResourceNames: *const c_void,
    pub GetManifestResourceInfo: *const c_void,
    pub get_Location:
        unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
    pub get_Evidence: *const c_void,
    pub GetCustomAttributes: *const c_void,
    pub GetCustomAttributes_2: *const c_void,
    pub IsDefined: *const c_void,
    pub GetObjectData: *const c_void,
    pub add_ModuleResolve: *const c_void,
    pub remove_ModuleResolve: *const c_void,
    pub GetType_4: *const c_void,
    pub GetSatelliteAssembly: *const c_void,
    pub GetSatelliteAssembly_2: *const c_void,
    pub LoadModule: *const c_void,
    pub LoadModule_2: *const c_void,
    // Creates an object from a BSTR type name into a VARIANT; used by `_Assembly::create_instance`.
    pub CreateInstance: unsafe extern "system" fn(
        this: *mut c_void,
        typeName: *mut u16,
        pRetVal: *mut VARIANT,
    ) -> HRESULT,
    pub CreateInstance_2: *const c_void,
    pub CreateInstance_3: *const c_void,
    pub GetLoadedModules: *const c_void,
    pub GetLoadedModules_2: *const c_void,
    pub GetModules: *const c_void,
    pub GetModules_2: *const c_void,
    pub GetModule: *const c_void,
    pub GetReferencedAssemblies: *const c_void,
    pub get_GlobalAssemblyCache: *const c_void,
}
89
90impl _Assembly {
91    pub fn run_entrypoint(&self, args: &[String]) -> Result<VARIANT, String> {
92        let entrypoint = (*self).get_entrypoint()?;
93        let signature = unsafe { (*entrypoint).to_string()? };
94
95        if signature.ends_with("Main()") {
96            return unsafe { (*entrypoint).invoke_without_args(None) };
97        }
98
99        if signature.ends_with("Main(System.String[])") {
100            let args_variant = wrap_strings_in_array(args)?;
101            let method_args = wrap_method_arguments(vec![args_variant])?;
102
103            return unsafe { (*entrypoint).invoke(method_args, None) };
104        }
105
106        Err(format!(
107            "Cannot handle an entrypoint with this method signature: {}",
108            signature
109        ))
110    }
111
112    pub fn get_entrypoint(&self) -> Result<*mut _MethodInfo, String> {
113        let mut method_info_ptr: *mut _MethodInfo = ptr::null_mut();
114
115        let hr = unsafe { (*self).get_EntryPoint(&mut method_info_ptr) };
116
117        if hr.is_err() {
118            return Err(format!("Could not retrieve entrypoint: {:?}", hr));
119        }
120
121        if method_info_ptr.is_null() {
122            return Err("Could not retrieve entrypoint".into());
123        }
124
125        Ok(method_info_ptr)
126    }
127
128    pub fn to_string(&self) -> Result<String, String> {
129        let mut buffer = BSTR::new();
130
131        let hr = unsafe { (*self).ToString(&mut buffer as *mut _ as *mut *mut u16) };
132
133        if hr.is_err() {
134            return Err(format!("Failed while running `ToString`: {:?}", hr));
135        }
136
137        Ok(buffer.to_string())
138    }
139
140    #[inline]
141    pub unsafe fn ToString(&self, pRetVal: *mut *mut u16) -> HRESULT {
142        ((*self.vtable).ToString)(self as *const _ as *mut _, pRetVal)
143    }
144
145    #[inline]
146    pub unsafe fn GetHashCode(&self, pRetVal: *mut c_long) -> HRESULT {
147        ((*self.vtable).GetHashCode)(self as *const _ as *mut _, pRetVal)
148    }
149
150    #[inline]
151    pub unsafe fn GetType(&self, pRetVal: *mut *mut _Type) -> HRESULT {
152        ((*self.vtable).GetType)(self as *const _ as *mut _, pRetVal)
153    }
154
155    #[inline]
156    pub unsafe fn get_CodeBase(&self, pRetVal: *mut *mut u16) -> HRESULT {
157        ((*self.vtable).get_CodeBase)(self as *const _ as *mut _, pRetVal)
158    }
159
160    #[inline]
161    pub unsafe fn get_EscapedCodeBase(&self, pRetVal: *mut *mut u16) -> HRESULT {
162        ((*self.vtable).get_EscapedCodeBase)(self as *const _ as *mut _, pRetVal)
163    }
164
165    #[inline]
166    pub unsafe fn GetName(&self, pRetVal: *mut *mut c_void) -> HRESULT {
167        ((*self.vtable).GetName)(self as *const _ as *mut _, pRetVal)
168    }
169
170    #[inline]
171    pub unsafe fn get_FullName(&self, pRetVal: *mut *mut u16) -> HRESULT {
172        ((*self.vtable).get_FullName)(self as *const _ as *mut _, pRetVal)
173    }
174
175    #[inline]
176    pub unsafe fn get_EntryPoint(&self, pRetVal: *mut *mut _MethodInfo) -> HRESULT {
177        ((*self.vtable).get_EntryPoint)(self as *const _ as *mut _, pRetVal)
178    }
179
180    #[inline]
181    pub unsafe fn GetType_2(&self, name: *mut u16, pRetVal: *mut *mut _Type) -> HRESULT {
182        ((*self.vtable).GetType_2)(self as *const _ as *mut _, name, pRetVal)
183    }
184
185    #[inline]
186    pub unsafe fn GetTypes(&self, pRetVal: *mut *mut SAFEARRAY) -> HRESULT {
187        ((*self.vtable).GetTypes)(self as *const _ as *mut _, pRetVal)
188    }
189
190    #[inline]
191    pub unsafe fn get_Location(&self, pRetVal: *mut *mut u16) -> HRESULT {
192        ((*self.vtable).get_Location)(self as *const _ as *mut _, pRetVal)
193    }
194
195    #[inline]
196    pub unsafe fn CreateInstance(&self, typeName: *mut u16, pRetVal: *mut VARIANT) -> HRESULT {
197        ((*self.vtable).CreateInstance)(self as *const _ as *mut _, typeName, pRetVal)
198    }
199
200    pub fn create_instance(&self, name: &str) -> Result<VARIANT, String> {
201        let dw = BSTR::from(name);
202
203        let mut instance: VARIANT = VARIANT::default();
204        let hr = unsafe { (*self).CreateInstance(dw.into_raw() as *mut _, &mut instance) };
205
206        if hr.is_err() {
207            return Err(format!(
208                "Error while creating instance of `{}`: 0x{:x}",
209                name, hr.0
210            ));
211        }
212
213        Ok(instance)
214    }
215
216    pub fn get_type(&self, name: &str) -> Result<*mut _Type, String> {
217        let dw = BSTR::from(name);
218
219        let mut type_ptr: *mut _Type = ptr::null_mut();
220        let hr = unsafe { (*self).GetType_2(dw.into_raw() as *mut _, &mut type_ptr) };
221
222        if hr.is_err() {
223            return Err(format!(
224                "Error while retrieving type `{}`: 0x{:x}",
225                name, hr.0
226            ));
227        }
228
229        if type_ptr.is_null() {
230            return Err(format!("Could not retrieve type `{}`", name));
231        }
232
233        Ok(type_ptr)
234    }
235
236    pub fn get_types(&self) -> Result<Vec<*mut _Type>, String> {
237        let mut results: Vec<*mut _Type> = vec![];
238
239        let mut safe_array_ptr: *mut SAFEARRAY = unsafe { SafeArrayCreateVector(VT_UNKNOWN, 0, 0) };
240
241        let hr = unsafe { (*self).GetTypes(&mut safe_array_ptr) };
242
243        if hr.is_err() {
244            return Err(format!("Error while retrieving types: 0x{:x}", hr.0));
245        }
246
247        let ubound = unsafe { SafeArrayGetUBound(safe_array_ptr, 1) }.unwrap_or(0);
248
249        for i in 0..ubound {
250            let indices: [i32; 1] = [i as _];
251            let mut variant: *mut _Type = ptr::null_mut();
252            let pv = &mut variant as *mut _ as *mut c_void;
253
254            match unsafe { SafeArrayGetElement(safe_array_ptr, indices.as_ptr(), pv) } {
255                Ok(_) => {},
256                Err(e) => return Err(format!("Could not access safe array: {:?}", e.code())),
257            }
258
259            if !pv.is_null() {
260                results.push(variant)
261            }
262        }
263
264        Ok(results)
265    }
266}
267
268impl Interface for _Assembly {
269    const IID: GUID = GUID::from_values(
270        0x17156360,
271        0x2f1a,
272        0x384a,
273        [0xbc, 0x52, 0xfd, 0xe9, 0x3c, 0x21, 0x5c, 0x5b],
274    );
275
276    fn vtable(&self) -> *const c_void {
277        self.vtable as *const _ as *const c_void
278    }
279}
280
281impl Deref for _Assembly {
282    type Target = IUnknown;
283
284    #[inline]
285    fn deref(&self) -> &IUnknown {
286        unsafe { &*(self as *const _Assembly as *const IUnknown) }
287    }
288}