use std::ffi::OsStr;
use std::os::windows::ffi::OsStrExt;
use std::sync::{Arc, Mutex, Weak};
use tracing::{instrument, Span};
use windows::core::PCWSTR;
use windows::Win32::Foundation::HMODULE;
use windows::Win32::System::LibraryLoader::LoadLibraryW;
use windows_sys::Win32::Foundation::FreeLibrary;
use super::ptr::RawPtr;
use crate::hypervisor::wrappers::HModuleWrapper;
use crate::{log_then_return, Result};
/// A guest binary loaded into the current process.
///
/// Cloning is cheap: every clone shares the same underlying module handle
/// through an [`Arc`], and the module is released only when the last clone
/// is dropped.
#[derive(Clone)]
pub struct LoadedLib {
    /// Shared ownership of the loaded module.
    inner: Arc<LoadedLibInner>,
}

/// Records the library currently loaded by [`LoadedLib::load`], if any.
///
/// Only a [`Weak`] reference is kept, so this tracker never keeps the module
/// alive by itself; `load` refuses to proceed while it can still be upgraded.
static LOADED_LIB: Mutex<Weak<LoadedLibInner>> = Mutex::new(Weak::new());
impl LoadedLib {
#[instrument(err(Debug), parent = Span::current(), level= "Trace")]
pub fn load(path: impl AsRef<OsStr> + std::fmt::Debug) -> Result<Self> {
let mut lock = LOADED_LIB.lock().unwrap();
if lock.upgrade().is_some() {
log_then_return!("LoadedLib: Only one guest binary can be loaded at any single time");
}
let inner = Arc::new(LoadedLibInner::load(path)?);
*lock = Arc::downgrade(&inner);
Ok(Self { inner })
}
#[instrument(skip_all, parent = Span::current(), level= "Trace")]
pub(super) fn base_addr(&self) -> RawPtr {
self.inner.base_addr()
}
}
/// Sole owner of a module handle obtained from `LoadLibraryW`.
///
/// The handle is released with `FreeLibrary` when this value is dropped.
struct LoadedLibInner {
    /// Wrapped `HMODULE` returned by `LoadLibraryW`.
    handle: HModuleWrapper,
}
impl LoadedLibInner {
fn load(path: impl AsRef<OsStr>) -> Result<Self> {
let path: Vec<u16> = path.as_ref().encode_wide().chain([0]).collect();
let pcwstr = PCWSTR::from_raw(path.as_ptr());
let handle = unsafe { LoadLibraryW(pcwstr) }?;
Ok(Self {
handle: handle.into(),
})
}
fn base_addr(&self) -> RawPtr {
RawPtr::from(<HModuleWrapper as Into<HMODULE>>::into(self.handle).0 as u64)
}
}
impl Drop for LoadedLibInner {
fn drop(&mut self) {
unsafe { FreeLibrary(<HModuleWrapper as Into<HMODULE>>::into(self.handle).0) };
}
}
#[cfg(test)]
mod tests {
    use hyperlight_testing::{rust_guest_as_pathbuf, simple_guest_exe_as_string};
    use serial_test::serial;

    use super::LoadedLib;

    /// Exercises the single-instance rule and the clone semantics of
    /// [`LoadedLib`]. It runs serially because loading is process-global.
    #[test]
    #[serial]
    fn test_universal() {
        // A guest can be loaded and then explicitly released.
        {
            let guest = simple_guest_exe_as_string().unwrap();
            let loaded = LoadedLib::load(guest).unwrap();
            drop(loaded);
        }

        // While one guest is alive, a second load is refused. Once the first
        // is released, loading succeeds again.
        {
            let guest = simple_guest_exe_as_string().unwrap();
            let first = LoadedLib::load(&guest);
            assert!(first.is_ok());
            let second = LoadedLib::load(&guest);
            assert!(second.is_err());
            drop(first);
            let third = LoadedLib::load(&guest);
            assert!(third.is_ok());
        }

        // Every clone shares the same module, so all report one base address.
        {
            let guest = rust_guest_as_pathbuf("simpleguest.exe");
            let original = LoadedLib::load(guest).unwrap();
            (0..9)
                .map(|_| original.clone())
                .for_each(|copy| assert_eq!(original.base_addr(), copy.base_addr()));
        }
    }
}