1use crate::primitives::{
2 itype::_Type, IUnknown, IUnknownVtbl, Interface, _MethodInfo, wrap_method_arguments,
3 wrap_strings_in_array, GUID, HRESULT,
4};
5use std::{
6 ffi::{c_long, c_void},
7 ops::Deref,
8 ptr,
9};
10use windows::{
11 core::BSTR,
12 Win32::System::{
13 Com::{SAFEARRAY, VARIANT, VT_UNKNOWN},
14 Ole::{SafeArrayCreateVector, SafeArrayGetElement, SafeArrayGetUBound},
15 },
16};
17
18#[repr(C)]
19pub struct _Assembly {
20 pub vtable: *const _AssemblyVtbl,
21}
22
23#[repr(C)]
24pub struct _AssemblyVtbl {
25 pub parent: IUnknownVtbl,
26 pub GetTypeInfoCount: *const c_void,
27 pub GetTypeInfo: *const c_void,
28 pub GetIDsOfNames: *const c_void,
29 pub Invoke: *const c_void,
30 pub ToString: unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
31 pub Equals: *const c_void,
32 pub GetHashCode: unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut c_long) -> HRESULT,
33 pub GetType: unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut _Type) -> HRESULT,
34 pub get_CodeBase:
35 unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
36 pub get_EscapedCodeBase:
37 unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
38 pub GetName: unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut c_void) -> HRESULT,
39 pub GetName_2: *const c_void,
40 pub get_FullName:
41 unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
42 pub get_EntryPoint:
43 unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut _MethodInfo) -> HRESULT,
44 pub GetType_2: unsafe extern "system" fn(
45 this: *mut c_void,
46 name: *mut u16,
47 pRetVal: *mut *mut _Type,
48 ) -> HRESULT,
49 pub GetType_3: *const c_void,
50 pub GetExportedTypes: *const c_void,
51 pub GetTypes:
52 unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut SAFEARRAY) -> HRESULT,
53 pub GetManifestResourceStream: *const c_void,
54 pub GetManifestResourceStream_2: *const c_void,
55 pub GetFile: *const c_void,
56 pub GetFiles: *const c_void,
57 pub GetFiles_2: *const c_void,
58 pub GetManifestResourceNames: *const c_void,
59 pub GetManifestResourceInfo: *const c_void,
60 pub get_Location:
61 unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
62 pub get_Evidence: *const c_void,
63 pub GetCustomAttributes: *const c_void,
64 pub GetCustomAttributes_2: *const c_void,
65 pub IsDefined: *const c_void,
66 pub GetObjectData: *const c_void,
67 pub add_ModuleResolve: *const c_void,
68 pub remove_ModuleResolve: *const c_void,
69 pub GetType_4: *const c_void,
70 pub GetSatelliteAssembly: *const c_void,
71 pub GetSatelliteAssembly_2: *const c_void,
72 pub LoadModule: *const c_void,
73 pub LoadModule_2: *const c_void,
74 pub CreateInstance: unsafe extern "system" fn(
75 this: *mut c_void,
76 typeName: *mut u16,
77 pRetVal: *mut VARIANT,
78 ) -> HRESULT,
79 pub CreateInstance_2: *const c_void,
80 pub CreateInstance_3: *const c_void,
81 pub GetLoadedModules: *const c_void,
82 pub GetLoadedModules_2: *const c_void,
83 pub GetModules: *const c_void,
84 pub GetModules_2: *const c_void,
85 pub GetModule: *const c_void,
86 pub GetReferencedAssemblies: *const c_void,
87 pub get_GlobalAssemblyCache: *const c_void,
88}
89
90impl _Assembly {
91 pub fn run_entrypoint(&self, args: &[String]) -> Result<VARIANT, String> {
92 let entrypoint = (*self).get_entrypoint()?;
93 let signature = unsafe { (*entrypoint).to_string()? };
94
95 if signature.ends_with("Main()") {
96 return unsafe { (*entrypoint).invoke_without_args(None) };
97 }
98
99 if signature.ends_with("Main(System.String[])") {
100 let args_variant = wrap_strings_in_array(args)?;
101 let method_args = wrap_method_arguments(vec![args_variant])?;
102
103 return unsafe { (*entrypoint).invoke(method_args, None) };
104 }
105
106 Err(format!(
107 "Cannot handle an entrypoint with this method signature: {}",
108 signature
109 ))
110 }
111
112 pub fn get_entrypoint(&self) -> Result<*mut _MethodInfo, String> {
113 let mut method_info_ptr: *mut _MethodInfo = ptr::null_mut();
114
115 let hr = unsafe { (*self).get_EntryPoint(&mut method_info_ptr) };
116
117 if hr.is_err() {
118 return Err(format!("Could not retrieve entrypoint: {:?}", hr));
119 }
120
121 if method_info_ptr.is_null() {
122 return Err("Could not retrieve entrypoint".into());
123 }
124
125 Ok(method_info_ptr)
126 }
127
128 pub fn to_string(&self) -> Result<String, String> {
129 let mut buffer = BSTR::new();
130
131 let hr = unsafe { (*self).ToString(&mut buffer as *mut _ as *mut *mut u16) };
132
133 if hr.is_err() {
134 return Err(format!("Failed while running `ToString`: {:?}", hr));
135 }
136
137 Ok(buffer.to_string())
138 }
139
140 #[inline]
141 pub unsafe fn ToString(&self, pRetVal: *mut *mut u16) -> HRESULT {
142 ((*self.vtable).ToString)(self as *const _ as *mut _, pRetVal)
143 }
144
145 #[inline]
146 pub unsafe fn GetHashCode(&self, pRetVal: *mut c_long) -> HRESULT {
147 ((*self.vtable).GetHashCode)(self as *const _ as *mut _, pRetVal)
148 }
149
150 #[inline]
151 pub unsafe fn GetType(&self, pRetVal: *mut *mut _Type) -> HRESULT {
152 ((*self.vtable).GetType)(self as *const _ as *mut _, pRetVal)
153 }
154
155 #[inline]
156 pub unsafe fn get_CodeBase(&self, pRetVal: *mut *mut u16) -> HRESULT {
157 ((*self.vtable).get_CodeBase)(self as *const _ as *mut _, pRetVal)
158 }
159
160 #[inline]
161 pub unsafe fn get_EscapedCodeBase(&self, pRetVal: *mut *mut u16) -> HRESULT {
162 ((*self.vtable).get_EscapedCodeBase)(self as *const _ as *mut _, pRetVal)
163 }
164
165 #[inline]
166 pub unsafe fn GetName(&self, pRetVal: *mut *mut c_void) -> HRESULT {
167 ((*self.vtable).GetName)(self as *const _ as *mut _, pRetVal)
168 }
169
170 #[inline]
171 pub unsafe fn get_FullName(&self, pRetVal: *mut *mut u16) -> HRESULT {
172 ((*self.vtable).get_FullName)(self as *const _ as *mut _, pRetVal)
173 }
174
175 #[inline]
176 pub unsafe fn get_EntryPoint(&self, pRetVal: *mut *mut _MethodInfo) -> HRESULT {
177 ((*self.vtable).get_EntryPoint)(self as *const _ as *mut _, pRetVal)
178 }
179
180 #[inline]
181 pub unsafe fn GetType_2(&self, name: *mut u16, pRetVal: *mut *mut _Type) -> HRESULT {
182 ((*self.vtable).GetType_2)(self as *const _ as *mut _, name, pRetVal)
183 }
184
185 #[inline]
186 pub unsafe fn GetTypes(&self, pRetVal: *mut *mut SAFEARRAY) -> HRESULT {
187 ((*self.vtable).GetTypes)(self as *const _ as *mut _, pRetVal)
188 }
189
190 #[inline]
191 pub unsafe fn get_Location(&self, pRetVal: *mut *mut u16) -> HRESULT {
192 ((*self.vtable).get_Location)(self as *const _ as *mut _, pRetVal)
193 }
194
195 #[inline]
196 pub unsafe fn CreateInstance(&self, typeName: *mut u16, pRetVal: *mut VARIANT) -> HRESULT {
197 ((*self.vtable).CreateInstance)(self as *const _ as *mut _, typeName, pRetVal)
198 }
199
200 pub fn create_instance(&self, name: &str) -> Result<VARIANT, String> {
201 let dw = BSTR::from(name);
202
203 let mut instance: VARIANT = VARIANT::default();
204 let hr = unsafe { (*self).CreateInstance(dw.into_raw() as *mut _, &mut instance) };
205
206 if hr.is_err() {
207 return Err(format!(
208 "Error while creating instance of `{}`: 0x{:x}",
209 name, hr.0
210 ));
211 }
212
213 Ok(instance)
214 }
215
216 pub fn get_type(&self, name: &str) -> Result<*mut _Type, String> {
217 let dw = BSTR::from(name);
218
219 let mut type_ptr: *mut _Type = ptr::null_mut();
220 let hr = unsafe { (*self).GetType_2(dw.into_raw() as *mut _, &mut type_ptr) };
221
222 if hr.is_err() {
223 return Err(format!(
224 "Error while retrieving type `{}`: 0x{:x}",
225 name, hr.0
226 ));
227 }
228
229 if type_ptr.is_null() {
230 return Err(format!("Could not retrieve type `{}`", name));
231 }
232
233 Ok(type_ptr)
234 }
235
236 pub fn get_types(&self) -> Result<Vec<*mut _Type>, String> {
237 let mut results: Vec<*mut _Type> = vec![];
238
239 let mut safe_array_ptr: *mut SAFEARRAY = unsafe { SafeArrayCreateVector(VT_UNKNOWN, 0, 0) };
240
241 let hr = unsafe { (*self).GetTypes(&mut safe_array_ptr) };
242
243 if hr.is_err() {
244 return Err(format!("Error while retrieving types: 0x{:x}", hr.0));
245 }
246
247 let ubound = unsafe { SafeArrayGetUBound(safe_array_ptr, 1) }.unwrap_or(0);
248
249 for i in 0..ubound {
250 let indices: [i32; 1] = [i as _];
251 let mut variant: *mut _Type = ptr::null_mut();
252 let pv = &mut variant as *mut _ as *mut c_void;
253
254 match unsafe { SafeArrayGetElement(safe_array_ptr, indices.as_ptr(), pv) } {
255 Ok(_) => {},
256 Err(e) => return Err(format!("Could not access safe array: {:?}", e.code())),
257 }
258
259 if !pv.is_null() {
260 results.push(variant)
261 }
262 }
263
264 Ok(results)
265 }
266}
267
268impl Interface for _Assembly {
269 const IID: GUID = GUID::from_values(
270 0x17156360,
271 0x2f1a,
272 0x384a,
273 [0xbc, 0x52, 0xfd, 0xe9, 0x3c, 0x21, 0x5c, 0x5b],
274 );
275
276 fn vtable(&self) -> *const c_void {
277 self.vtable as *const _ as *const c_void
278 }
279}
280
281impl Deref for _Assembly {
282 type Target = IUnknown;
283
284 #[inline]
285 fn deref(&self) -> &IUnknown {
286 unsafe { &*(self as *const _Assembly as *const IUnknown) }
287 }
288}