clroxide 1.1.1

A library that allows you to host the CLR and execute dotnet binaries.
Documentation
use crate::primitives::{
    itype::_Type, IUnknown, IUnknownVtbl, Interface, _MethodInfo, wrap_method_arguments,
    wrap_strings_in_array, GUID, HRESULT,
};
use std::{
    ffi::{c_long, c_void},
    ops::Deref,
    ptr,
};
use windows::{
    core::BSTR,
    Win32::System::{
        Com::{SAFEARRAY, VARIANT, VT_UNKNOWN},
        Ole::{SafeArrayCreateVector, SafeArrayGetElement, SafeArrayGetUBound},
    },
};

/// Raw, `#[repr(C)]` view of an `_Assembly` interface object.
///
/// The only field is a pointer to the method table ([`_AssemblyVtbl`]); every
/// method on this type dispatches through that table, passing `self` as the
/// `this` argument.
#[repr(C)]
pub struct _Assembly {
    /// Pointer to the table of function pointers used to dispatch calls.
    pub vtable: *const _AssemblyVtbl,
}

/// Method table for [`_Assembly`].
///
/// The table starts with the `IUnknownVtbl` entries and is followed by one
/// slot per interface method. Slots typed as a function pointer are the ones
/// this crate calls; slots typed as `*const c_void` carry no signature here and
/// only occupy their position in the table.
///
/// NOTE(review): the slot order presumably mirrors the COM `_Assembly`
/// interface exposed by the .NET Framework — confirm against the interface
/// definition before reordering, inserting or removing any field, since the
/// struct is `#[repr(C)]` and calls are dispatched by position.
#[repr(C)]
pub struct _AssemblyVtbl {
    /// Leading `IUnknown` portion of the table.
    pub parent: IUnknownVtbl,
    // The next four slots are untyped placeholders.
    pub GetTypeInfoCount: *const c_void,
    pub GetTypeInfo: *const c_void,
    pub GetIDsOfNames: *const c_void,
    pub Invoke: *const c_void,
    /// Writes a string pointer into `pRetVal`.
    pub ToString: unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
    pub Equals: *const c_void,
    /// Writes a `c_long` hash code into `pRetVal`.
    pub GetHashCode: unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut c_long) -> HRESULT,
    /// Writes a `_Type` pointer into `pRetVal`.
    pub GetType: unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut _Type) -> HRESULT,
    pub get_CodeBase:
        unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
    pub get_EscapedCodeBase:
        unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
    /// Writes an untyped object pointer into `pRetVal`.
    pub GetName: unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut c_void) -> HRESULT,
    pub GetName_2: *const c_void,
    pub get_FullName:
        unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
    /// Writes a `_MethodInfo` pointer into `pRetVal`.
    pub get_EntryPoint:
        unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut _MethodInfo) -> HRESULT,
    /// Takes a type name string and writes a `_Type` pointer into `pRetVal`.
    pub GetType_2: unsafe extern "system" fn(
        this: *mut c_void,
        name: *mut u16,
        pRetVal: *mut *mut _Type,
    ) -> HRESULT,
    pub GetType_3: *const c_void,
    pub GetExportedTypes: *const c_void,
    /// Writes a `SAFEARRAY` pointer into `pRetVal`.
    pub GetTypes:
        unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut SAFEARRAY) -> HRESULT,
    pub GetManifestResourceStream: *const c_void,
    pub GetManifestResourceStream_2: *const c_void,
    pub GetFile: *const c_void,
    pub GetFiles: *const c_void,
    pub GetFiles_2: *const c_void,
    pub GetManifestResourceNames: *const c_void,
    pub GetManifestResourceInfo: *const c_void,
    pub get_Location:
        unsafe extern "system" fn(this: *mut c_void, pRetVal: *mut *mut u16) -> HRESULT,
    pub get_Evidence: *const c_void,
    pub GetCustomAttributes: *const c_void,
    pub GetCustomAttributes_2: *const c_void,
    pub IsDefined: *const c_void,
    pub GetObjectData: *const c_void,
    pub add_ModuleResolve: *const c_void,
    pub remove_ModuleResolve: *const c_void,
    pub GetType_4: *const c_void,
    pub GetSatelliteAssembly: *const c_void,
    pub GetSatelliteAssembly_2: *const c_void,
    pub LoadModule: *const c_void,
    pub LoadModule_2: *const c_void,
    /// Takes a type name string and writes the result into the `VARIANT` at `pRetVal`.
    pub CreateInstance: unsafe extern "system" fn(
        this: *mut c_void,
        typeName: *mut u16,
        pRetVal: *mut VARIANT,
    ) -> HRESULT,
    pub CreateInstance_2: *const c_void,
    pub CreateInstance_3: *const c_void,
    pub GetLoadedModules: *const c_void,
    pub GetLoadedModules_2: *const c_void,
    pub GetModules: *const c_void,
    pub GetModules_2: *const c_void,
    pub GetModule: *const c_void,
    pub GetReferencedAssemblies: *const c_void,
    pub get_GlobalAssemblyCache: *const c_void,
}

impl _Assembly {
    /// Runs the assembly's entry point and returns the value it produced.
    ///
    /// The entry point's string signature decides how it is invoked:
    /// * ending in `Main()` — invoked without arguments (`args` is ignored);
    /// * ending in `Main(System.String[])` — `args` is wrapped into a string
    ///   array and passed as the single argument.
    ///
    /// # Errors
    ///
    /// Returns an error if the entry point cannot be retrieved, if its
    /// signature cannot be read, if argument wrapping or invocation fails, or
    /// if the signature matches neither supported form.
    pub fn run_entrypoint(&self, args: &[String]) -> Result<VARIANT, String> {
        let entrypoint = self.get_entrypoint()?;
        // SAFETY: `get_entrypoint` never returns a null pointer on success.
        let signature = unsafe { (*entrypoint).to_string()? };

        if signature.ends_with("Main()") {
            return unsafe { (*entrypoint).invoke_without_args(None) };
        }

        if signature.ends_with("Main(System.String[])") {
            let args_variant = wrap_strings_in_array(args)?;
            let method_args = wrap_method_arguments(vec![args_variant])?;

            return unsafe { (*entrypoint).invoke(method_args, None) };
        }

        Err(format!(
            "Cannot handle an entrypoint with this method signature: {}",
            signature
        ))
    }

    /// Retrieves the assembly's entry point as a raw `_MethodInfo` pointer.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying call fails or yields a null pointer.
    pub fn get_entrypoint(&self) -> Result<*mut _MethodInfo, String> {
        let mut method_info_ptr: *mut _MethodInfo = ptr::null_mut();

        let hr = unsafe { self.get_EntryPoint(&mut method_info_ptr) };

        if hr.is_err() {
            return Err(format!("Could not retrieve entrypoint: {:?}", hr));
        }

        if method_info_ptr.is_null() {
            return Err("Could not retrieve entrypoint".into());
        }

        Ok(method_info_ptr)
    }

    /// Calls the object's `ToString` method and converts the result to a `String`.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying call reports a failure `HRESULT`.
    pub fn to_string(&self) -> Result<String, String> {
        // The returned string is written straight into `buffer`, which then
        // owns it and releases it when dropped.
        let mut buffer = BSTR::new();

        let hr = unsafe { self.ToString(&mut buffer as *mut _ as *mut *mut u16) };

        if hr.is_err() {
            return Err(format!("Failed while running `ToString`: {:?}", hr));
        }

        Ok(buffer.to_string())
    }

    /// Raw call through the `ToString` vtable slot.
    ///
    /// # Safety
    ///
    /// `self.vtable` must point to a valid [`_AssemblyVtbl`] and `pRetVal`
    /// must be valid for writes.
    #[inline]
    pub unsafe fn ToString(&self, pRetVal: *mut *mut u16) -> HRESULT {
        ((*self.vtable).ToString)(self as *const _ as *mut _, pRetVal)
    }

    /// Raw call through the `GetHashCode` vtable slot.
    ///
    /// # Safety
    ///
    /// `self.vtable` must point to a valid [`_AssemblyVtbl`] and `pRetVal`
    /// must be valid for writes.
    #[inline]
    pub unsafe fn GetHashCode(&self, pRetVal: *mut c_long) -> HRESULT {
        ((*self.vtable).GetHashCode)(self as *const _ as *mut _, pRetVal)
    }

    /// Raw call through the `GetType` vtable slot.
    ///
    /// # Safety
    ///
    /// `self.vtable` must point to a valid [`_AssemblyVtbl`] and `pRetVal`
    /// must be valid for writes.
    #[inline]
    pub unsafe fn GetType(&self, pRetVal: *mut *mut _Type) -> HRESULT {
        ((*self.vtable).GetType)(self as *const _ as *mut _, pRetVal)
    }

    /// Raw call through the `get_CodeBase` vtable slot.
    ///
    /// # Safety
    ///
    /// `self.vtable` must point to a valid [`_AssemblyVtbl`] and `pRetVal`
    /// must be valid for writes.
    #[inline]
    pub unsafe fn get_CodeBase(&self, pRetVal: *mut *mut u16) -> HRESULT {
        ((*self.vtable).get_CodeBase)(self as *const _ as *mut _, pRetVal)
    }

    /// Raw call through the `get_EscapedCodeBase` vtable slot.
    ///
    /// # Safety
    ///
    /// `self.vtable` must point to a valid [`_AssemblyVtbl`] and `pRetVal`
    /// must be valid for writes.
    #[inline]
    pub unsafe fn get_EscapedCodeBase(&self, pRetVal: *mut *mut u16) -> HRESULT {
        ((*self.vtable).get_EscapedCodeBase)(self as *const _ as *mut _, pRetVal)
    }

    /// Raw call through the `GetName` vtable slot.
    ///
    /// # Safety
    ///
    /// `self.vtable` must point to a valid [`_AssemblyVtbl`] and `pRetVal`
    /// must be valid for writes.
    #[inline]
    pub unsafe fn GetName(&self, pRetVal: *mut *mut c_void) -> HRESULT {
        ((*self.vtable).GetName)(self as *const _ as *mut _, pRetVal)
    }

    /// Raw call through the `get_FullName` vtable slot.
    ///
    /// # Safety
    ///
    /// `self.vtable` must point to a valid [`_AssemblyVtbl`] and `pRetVal`
    /// must be valid for writes.
    #[inline]
    pub unsafe fn get_FullName(&self, pRetVal: *mut *mut u16) -> HRESULT {
        ((*self.vtable).get_FullName)(self as *const _ as *mut _, pRetVal)
    }

    /// Raw call through the `get_EntryPoint` vtable slot.
    ///
    /// # Safety
    ///
    /// `self.vtable` must point to a valid [`_AssemblyVtbl`] and `pRetVal`
    /// must be valid for writes.
    #[inline]
    pub unsafe fn get_EntryPoint(&self, pRetVal: *mut *mut _MethodInfo) -> HRESULT {
        ((*self.vtable).get_EntryPoint)(self as *const _ as *mut _, pRetVal)
    }

    /// Raw call through the `GetType_2` vtable slot (lookup by name).
    ///
    /// # Safety
    ///
    /// `self.vtable` must point to a valid [`_AssemblyVtbl`], `name` must be a
    /// valid string pointer for the callee, and `pRetVal` must be valid for
    /// writes.
    #[inline]
    pub unsafe fn GetType_2(&self, name: *mut u16, pRetVal: *mut *mut _Type) -> HRESULT {
        ((*self.vtable).GetType_2)(self as *const _ as *mut _, name, pRetVal)
    }

    /// Raw call through the `GetTypes` vtable slot.
    ///
    /// # Safety
    ///
    /// `self.vtable` must point to a valid [`_AssemblyVtbl`] and `pRetVal`
    /// must be valid for writes.
    #[inline]
    pub unsafe fn GetTypes(&self, pRetVal: *mut *mut SAFEARRAY) -> HRESULT {
        ((*self.vtable).GetTypes)(self as *const _ as *mut _, pRetVal)
    }

    /// Raw call through the `get_Location` vtable slot.
    ///
    /// # Safety
    ///
    /// `self.vtable` must point to a valid [`_AssemblyVtbl`] and `pRetVal`
    /// must be valid for writes.
    #[inline]
    pub unsafe fn get_Location(&self, pRetVal: *mut *mut u16) -> HRESULT {
        ((*self.vtable).get_Location)(self as *const _ as *mut _, pRetVal)
    }

    /// Raw call through the `CreateInstance` vtable slot.
    ///
    /// # Safety
    ///
    /// `self.vtable` must point to a valid [`_AssemblyVtbl`], `typeName` must
    /// be a valid string pointer for the callee, and `pRetVal` must be valid
    /// for writes.
    #[inline]
    pub unsafe fn CreateInstance(&self, typeName: *mut u16, pRetVal: *mut VARIANT) -> HRESULT {
        ((*self.vtable).CreateInstance)(self as *const _ as *mut _, typeName, pRetVal)
    }

    /// Creates an instance of the type called `name` and returns it as a `VARIANT`.
    ///
    /// # Errors
    ///
    /// Returns an error carrying the failing `HRESULT` if the call fails.
    pub fn create_instance(&self, name: &str) -> Result<VARIANT, String> {
        let type_name = BSTR::from(name).into_raw();

        let mut instance: VARIANT = VARIANT::default();
        let hr = unsafe { self.CreateInstance(type_name as *mut _, &mut instance) };

        // An input BSTR stays owned by the caller. Retake ownership of the raw
        // pointer so the string is freed instead of leaking on every call.
        // SAFETY: `type_name` came from `BSTR::into_raw` just above.
        drop(unsafe { BSTR::from_raw(type_name) });

        if hr.is_err() {
            return Err(format!(
                "Error while creating instance of `{}`: 0x{:x}",
                name, hr.0
            ));
        }

        Ok(instance)
    }

    /// Looks up the type called `name` and returns a raw `_Type` pointer to it.
    ///
    /// # Errors
    ///
    /// Returns an error if the call fails or yields a null pointer.
    pub fn get_type(&self, name: &str) -> Result<*mut _Type, String> {
        let type_name = BSTR::from(name).into_raw();

        let mut type_ptr: *mut _Type = ptr::null_mut();
        let hr = unsafe { self.GetType_2(type_name as *mut _, &mut type_ptr) };

        // An input BSTR stays owned by the caller. Retake ownership of the raw
        // pointer so the string is freed instead of leaking on every call.
        // SAFETY: `type_name` came from `BSTR::into_raw` just above.
        drop(unsafe { BSTR::from_raw(type_name) });

        if hr.is_err() {
            return Err(format!(
                "Error while retrieving type `{}`: 0x{:x}",
                name, hr.0
            ));
        }

        if type_ptr.is_null() {
            return Err(format!("Could not retrieve type `{}`", name));
        }

        Ok(type_ptr)
    }

    /// Returns raw `_Type` pointers for every non-null element of the array
    /// produced by `GetTypes`.
    ///
    /// # Errors
    ///
    /// Returns an error if `GetTypes` fails, if the array bounds cannot be
    /// read, or if an element cannot be fetched.
    pub fn get_types(&self) -> Result<Vec<*mut _Type>, String> {
        // `GetTypes` overwrites this pointer with its own array, so start from
        // null: an array allocated here beforehand would simply be leaked.
        let mut safe_array_ptr: *mut SAFEARRAY = ptr::null_mut();

        let hr = unsafe { self.GetTypes(&mut safe_array_ptr) };

        if hr.is_err() {
            return Err(format!("Error while retrieving types: 0x{:x}", hr.0));
        }

        if safe_array_ptr.is_null() {
            return Ok(Vec::new());
        }

        // NOTE(review): the array returned by `GetTypes` is not destroyed here;
        // releasing it would need `SafeArrayDestroy`, which this file does not
        // import.
        let ubound = unsafe { SafeArrayGetUBound(safe_array_ptr, 1) }
            .map_err(|e| format!("Could not read safe array bounds: {:?}", e.code()))?;

        let mut results: Vec<*mut _Type> = Vec::new();

        // The upper bound is the index of the LAST element, so the range must
        // be inclusive; an empty array reports an upper bound of -1.
        // Assumes a lower bound of 0 — TODO confirm.
        for i in 0..=ubound {
            let indices: [i32; 1] = [i as _];
            let mut type_ptr: *mut _Type = ptr::null_mut();
            let pv = &mut type_ptr as *mut _ as *mut c_void;

            unsafe { SafeArrayGetElement(safe_array_ptr, indices.as_ptr(), pv) }
                .map_err(|e| format!("Could not access safe array: {:?}", e.code()))?;

            // Check the fetched element, not `pv` (the address of the local,
            // which is never null).
            if !type_ptr.is_null() {
                results.push(type_ptr);
            }
        }

        Ok(results)
    }
}

/// Hooks `_Assembly` into the crate's `Interface` trait by supplying its
/// interface identifier and a type-erased pointer to its method table.
impl Interface for _Assembly {
    /// Interface identifier `17156360-2f1a-384a-bc52-fde93c215c5b`.
    const IID: GUID = GUID::from_values(
        0x1715_6360,
        0x2f1a,
        0x384a,
        [0xbc, 0x52, 0xfd, 0xe9, 0x3c, 0x21, 0x5c, 0x5b],
    );

    /// Returns the stored vtable pointer with its pointee type erased.
    fn vtable(&self) -> *const c_void {
        self.vtable.cast::<c_void>()
    }
}

/// Lets an `_Assembly` reference be used wherever an `IUnknown` reference is
/// expected by reinterpreting the same object in place.
impl Deref for _Assembly {
    type Target = IUnknown;

    #[inline]
    fn deref(&self) -> &Self::Target {
        let as_unknown = (self as *const _Assembly).cast::<IUnknown>();
        // SAFETY: the pointer comes from a live `&self`, so it is non-null and
        // valid for the returned borrow. Reading it as `IUnknown` relies on the
        // two types sharing a compatible leading layout — NOTE(review): confirm
        // against the `IUnknown` definition.
        unsafe { &*as_unknown }
    }
}