use core::{
mem::transmute,
ffi::c_void,
ptr::null_mut
};
use alloc::{
format,
string::{String, ToString},
vec::Vec,
};
use obfstr::obfstr as s;
use dinvk::winapis::{
NtCurrentProcess,
NtProtectVirtualMemory,
NT_SUCCESS
};
use windows_core::{IUnknown, Interface, PCWSTR};
use windows_sys::Win32::{
UI::Shell::SHCreateMemStream,
System::Memory::PAGE_EXECUTE_READWRITE
};
use super::hosting::RustClrControl;
use crate::{com::*, variant::Variant};
use crate::error::{ClrError, Result};
/// State for hosting the CLR in-process and loading an assembly from memory.
///
/// Construct with [`RustClrRuntime::new`], optionally set `runtime_version` and/or
/// `domain_name`, then call [`RustClrRuntime::prepare`] to populate the remaining
/// COM-related fields.
#[derive(Default, Debug, Clone)]
pub struct RustClrRuntime<'a> {
    /// Raw bytes of the assembly to host (presumably a .NET PE image — confirm with
    /// callers). `prepare` wraps them in a memory stream to compute the assembly
    /// identity and also hands them to the host control.
    pub buffer: &'a [u8],
    /// Assembly identity string computed from `buffer` by `prepare`; empty until then.
    pub identity_assembly: String,
    /// CLR version to load; `None` means `RuntimeVersion::V4` is used.
    pub runtime_version: Option<RuntimeVersion>,
    /// Name of the AppDomain to create; `None` means a generated UUID string is used.
    pub domain_name: Option<String>,
    /// AppDomain created by `prepare`; `None` before `prepare` succeeds.
    pub app_domain: Option<_AppDomain>,
    /// `ICorRuntimeHost` obtained by `prepare`; used by `unload_domain`.
    pub cor_runtime_host: Option<ICorRuntimeHost>,
}
impl<'a> RustClrRuntime<'a> {
    /// Creates a runtime wrapper around `buffer`.
    ///
    /// No CLR state is touched here. Call [`RustClrRuntime::prepare`] before use.
    pub fn new(buffer: &'a [u8]) -> Self {
        Self {
            buffer,
            identity_assembly: String::new(),
            runtime_version: None,
            domain_name: None,
            app_domain: None,
            cor_runtime_host: None,
        }
    }

    /// Loads the CLR, computes the assembly identity and creates the AppDomain.
    ///
    /// Steps:
    /// 1. Obtain `ICLRMetaHost` and the `ICLRRuntimeInfo` for the selected version.
    /// 2. Resolve `GetCLRIdentityManager` and compute `identity_assembly` from `buffer`.
    /// 3. If the runtime is loadable and not yet started, install a `RustClrControl`
    ///    host control and start it; otherwise the already-running runtime is reused.
    /// 4. Obtain `ICorRuntimeHost`, create the AppDomain, and store both.
    ///
    /// # Errors
    ///
    /// Propagates any COM/runtime error from the steps above. Also fails if `buffer`
    /// is longer than `u32::MAX` bytes or if the memory stream cannot be created.
    pub fn prepare(&mut self) -> Result<()> {
        let meta_host = self.create_meta_host()?;
        let runtime_info = self.get_runtime_info(&meta_host)?;

        let addr = runtime_info.GetProcAddress(s!("GetCLRIdentityManager"))?;
        // SAFETY: `addr` was returned by the runtime's GetProcAddress for
        // `GetCLRIdentityManager`; `CLRIdentityManagerType` is assumed to match that
        // export's signature and calling convention.
        let get_identity_manager =
            unsafe { transmute::<*mut c_void, CLRIdentityManagerType>(addr) };
        let mut ptr = null_mut();
        // NOTE(review): the return value (presumably an HRESULT) is ignored; failure is
        // only caught if `from_raw` rejects a null `ptr` — confirm.
        get_identity_manager(&ICLRAssemblyIdentityManager::IID, &mut ptr);
        let iclr_assembly = ICLRAssemblyIdentityManager::from_raw(ptr)?;

        // SHCreateMemStream takes a u32 length; refuse rather than silently truncate.
        if self.buffer.len() > u32::MAX as usize {
            return Err(ClrError::Msg("assembly buffer is larger than u32::MAX bytes"));
        }
        // SAFETY: pointer and length describe `self.buffer`, which is valid for reads
        // for the whole call; the length fits in u32 (checked above).
        let stream = unsafe { SHCreateMemStream(self.buffer.as_ptr(), self.buffer.len() as u32) };
        if stream.is_null() {
            return Err(ClrError::Msg("SHCreateMemStream failed to create a stream"));
        }
        // NOTE(review): the IStream reference returned above is never released here;
        // confirm whether `get_identity_stream` takes ownership of it.
        self.identity_assembly = iclr_assembly.get_identity_stream(stream, 0)?;

        let iclr_runtime_host = self.get_clr_runtime_host(&runtime_info)?;
        if runtime_info.IsLoadable().is_ok() && !runtime_info.is_started() {
            let host_control: IHostControl =
                RustClrControl::new(self.buffer, &self.identity_assembly).into();
            iclr_runtime_host.SetHostControl(&host_control)?;
            self.start_runtime(&iclr_runtime_host)?;
        }

        // Reuse the same host for domain creation and for later `unload_domain`
        // instead of querying the interface a second time.
        let cor_runtime_host = self.get_icor_runtime_host(&runtime_info)?;
        self.init_app_domain(&cor_runtime_host)?;
        self.cor_runtime_host = Some(cor_runtime_host);
        Ok(())
    }

    /// Returns a clone of the AppDomain created by `prepare`.
    ///
    /// # Errors
    ///
    /// `ClrError::NoDomainAvailable` if no AppDomain has been created yet.
    pub fn get_app_domain(&mut self) -> Result<_AppDomain> {
        self.app_domain
            .clone()
            .ok_or(ClrError::NoDomainAvailable)
    }

    /// Creates the `ICLRMetaHost` COM object.
    fn create_meta_host(&self) -> Result<ICLRMetaHost> {
        CLRCreateInstance::<ICLRMetaHost>(&CLSID_CLRMETAHOST)
            .map_err(|e| ClrError::MetaHostCreationError(format!("{e}")))
    }

    /// Looks up the runtime for `runtime_version`, defaulting to `RuntimeVersion::V4`.
    fn get_runtime_info(&self, meta_host: &ICLRMetaHost) -> Result<ICLRRuntimeInfo> {
        // The wide string must outlive the GetRuntime call because PCWSTR borrows it.
        let version_wide = self.runtime_version.unwrap_or(RuntimeVersion::V4).to_vec();
        meta_host
            .GetRuntime::<ICLRRuntimeInfo>(PCWSTR(version_wide.as_ptr()))
            .map_err(|error| ClrError::RuntimeInfoError(format!("{error}")))
    }

    /// Retrieves the `ICorRuntimeHost` interface from the runtime.
    fn get_icor_runtime_host(&self, runtime_info: &ICLRRuntimeInfo) -> Result<ICorRuntimeHost> {
        runtime_info
            .GetInterface::<ICorRuntimeHost>(&CLSID_COR_RUNTIME_HOST)
            .map_err(|error| ClrError::RuntimeHostError(format!("{error}")))
    }

    /// Retrieves the `ICLRuntimeHost` interface from the runtime.
    fn get_clr_runtime_host(&self, runtime_info: &ICLRRuntimeInfo) -> Result<ICLRuntimeHost> {
        runtime_info
            .GetInterface::<ICLRuntimeHost>(&CLSID_ICLR_RUNTIME_HOST)
            .map_err(|error| ClrError::RuntimeHostError(format!("{error}")))
    }

    /// Starts the runtime.
    ///
    /// Any non-zero status from `Start` becomes `ClrError::RuntimeStartError`.
    fn start_runtime(&self, iclr_runtime_host: &ICLRuntimeHost) -> Result<()> {
        if iclr_runtime_host.Start() != 0 {
            return Err(ClrError::RuntimeStartError);
        }
        Ok(())
    }

    /// Creates an AppDomain and stores it in `self.app_domain`.
    ///
    /// The domain is named `domain_name`, or a freshly generated UUID when that is `None`.
    fn init_app_domain(&mut self, cor_runtime_host: &ICorRuntimeHost) -> Result<()> {
        // NUL-terminated UTF-16, as required by PCWSTR.
        let to_wide = |text: &str| text.encode_utf16().chain(Some(0)).collect::<Vec<u16>>();
        let wide_name = match &self.domain_name {
            Some(domain_name) => to_wide(domain_name),
            None => to_wide(&uuid().to_string()),
        };
        let app_domain = cor_runtime_host.CreateDomain(PCWSTR(wide_name.as_ptr()), null_mut())?;
        self.app_domain = Some(app_domain);
        Ok(())
    }

    /// Unloads the AppDomain created by `prepare`.
    ///
    /// This is a no-op if `prepare` has not populated the host and domain.
    ///
    /// # Errors
    ///
    /// Fails if the domain cannot be cast to `IUnknown` or if `UnloadDomain` fails.
    pub fn unload_domain(&self) -> Result<()> {
        if let (Some(cor_runtime_host), Some(app_domain)) =
            (&self.cor_runtime_host, &self.app_domain)
        {
            // Keep `unknown` alive across the call so the raw pointer stays owned,
            // and report a failed cast instead of passing a null pointer.
            let unknown = app_domain
                .cast::<IUnknown>()
                .map_err(|_| ClrError::Msg("Failed to cast to IUnknown"))?;
            cor_runtime_host.UnloadDomain(unknown.as_raw().cast())?;
        }
        Ok(())
    }
}
/// CLR versions that can be requested from `ICLRMetaHost::GetRuntime`.
#[derive(Debug, Clone, Copy)]
pub enum RuntimeVersion {
    /// `"v2.0.50727"`
    V2,
    /// `"v3.0"`
    V3,
    /// `"v4.0.30319"`
    V4,
    /// The literal string `"UNKNOWN"`; presumably not a loadable runtime — confirm.
    UNKNOWN,
}

impl RuntimeVersion {
    /// Returns the version string as NUL-terminated UTF-16, suitable for a `PCWSTR`.
    pub fn to_vec(self) -> Vec<u16> {
        let label: &str = match self {
            Self::V2 => "v2.0.50727",
            Self::V3 => "v3.0",
            Self::V4 => "v4.0.30319",
            Self::UNKNOWN => "UNKNOWN",
        };
        let mut wide: Vec<u16> = label.encode_utf16().collect();
        wide.push(0);
        wide
    }
}
/// Generates a pseudo-unique version-4 UUID from the CPU time-stamp counter.
///
/// The only entropy source is `rdtsc`, so this is NOT suitable wherever
/// unpredictability matters. It is only meant to produce a name unlikely to collide.
/// The full 64-bit counter is passed through a SplitMix64 finaliser so all 16 bytes
/// vary; the original copied only the low 32 bits of back-to-back reads. The
/// RFC 4122 version (4) and variant (`10xx`) bits are set so the result is a
/// well-formed v4 UUID.
///
/// Only compiles for x86_64 because it uses `core::arch::x86_64::_rdtsc`.
pub fn uuid() -> uuid::Uuid {
    /// SplitMix64 output function: spreads small input differences across all bits.
    fn mix(mut z: u64) -> u64 {
        z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    let mut bytes = [0u8; 16];
    for chunk in bytes.chunks_exact_mut(8) {
        // SAFETY: `_rdtsc` only reads the time-stamp counter and has no
        // memory-safety preconditions on x86_64.
        let ticks = unsafe { core::arch::x86_64::_rdtsc() };
        chunk.copy_from_slice(&mix(ticks).to_le_bytes());
    }
    // RFC 4122 §4.4: version 4 in the high nibble of octet 6, variant 10xx in octet 8.
    bytes[6] = (bytes[6] & 0x0F) | 0x40;
    bytes[8] = (bytes[8] & 0x3F) | 0x80;
    uuid::Uuid::from_bytes(bytes)
}
/// Overwrites the first byte of `System.Environment.Exit`'s native entry point with
/// `0xC3` (the x86/x64 `ret` opcode). Calls to `Exit` then return immediately
/// instead of terminating the host process.
///
/// The entry point is found via reflection on `mscorlib`:
/// `Environment.Exit` → `MethodInfo.MethodHandle` → `RuntimeMethodHandle.GetFunctionPointer`.
///
/// # Errors
///
/// Propagates reflection/invocation errors. Returns `ClrError::Msg` if the IUnknown
/// cast fails or if either memory-protection change fails. If only the restore
/// fails, the byte has already been patched.
///
/// NOTE(review): this writes to executable memory behind a safe signature; callers
/// must ensure no thread is executing `Exit` concurrently — confirm the intended contract.
pub fn patch_exit(mscorlib: &_Assembly) -> Result<()> {
    // Resolve System.Environment.Exit via reflection.
    let env = mscorlib.resolve_type(s!("System.Environment"))?;
    let exit = env.method(s!("Exit"))?;

    // Read MethodInfo.MethodHandle on the Exit method to get its RuntimeMethodHandle.
    let method_info = mscorlib.resolve_type(s!("System.Reflection.MethodInfo"))?;
    let method_handle = method_info.property(s!("MethodHandle"))?;
    let instance = exit
        .cast::<IUnknown>()
        .map_err(|_| ClrError::Msg("Failed to cast to IUnknown"))?;
    let method_handle_exit = method_handle.value(Some(instance.to_variant()), None)?;

    // RuntimeMethodHandle.GetFunctionPointer yields the native code address.
    let runtime_method = mscorlib.resolve_type(s!("System.RuntimeMethodHandle"))?;
    let get_function_pointer = runtime_method.method(s!("GetFunctionPointer"))?;
    let ptr = get_function_pointer.invoke(Some(method_handle_exit), None)?;

    // The returned VARIANT's pointer-sized payload is read through the `byref` union
    // field (presumably the IntPtr value — confirm the VARIANT type tag).
    // NtProtectVirtualMemory may update `addr_exit`/`size` in place (page rounding),
    // so the write below uses the original pointer from `ptr`, not `addr_exit`.
    let mut addr_exit = unsafe { ptr.Anonymous.Anonymous.Anonymous.byref };
    let mut old = 0;
    let mut size = 1;
    // Make the page writable while keeping it executable (presumably so other
    // threads running code on the same page do not fault).
    if !NT_SUCCESS(NtProtectVirtualMemory(
        NtCurrentProcess(),
        &mut addr_exit,
        &mut size,
        PAGE_EXECUTE_READWRITE,
        &mut old,
    )) {
        return Err(ClrError::Msg(
            "failed to change memory protection to RWX",
        ));
    }

    // Patch in `ret` (0xC3) at the entry point.
    unsafe { *(ptr.Anonymous.Anonymous.Anonymous.byref as *mut u8) = 0xC3 };

    // Restore the original protection saved in `old`.
    if !NT_SUCCESS(NtProtectVirtualMemory(
        NtCurrentProcess(),
        &mut addr_exit,
        &mut size,
        old,
        &mut old,
    )) {
        return Err(ClrError::Msg(
            "failed to restore memory protection",
        ));
    }
    Ok(())
}