use core::ffi::c_void;
use std::fs::File;
use std::io::Write;
use std::mem::size_of;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};

use crossbeam_channel::{Receiver, Sender, TryRecvError, unbounded};
use rust_embed::RustEmbed;
use tracing::{Span, error, info, instrument, warn};
use windows::Win32::Foundation::{CloseHandle, HANDLE};
use windows::Win32::Security::SECURITY_ATTRIBUTES;
use windows::Win32::System::JobObjects::{
    AssignProcessToJobObject, CreateJobObjectA, JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
    JOBOBJECT_BASIC_LIMIT_INFORMATION, JOBOBJECT_EXTENDED_LIMIT_INFORMATION,
    JobObjectExtendedLimitInformation, SetInformationJobObject, TerminateJobObject,
};
use windows::Win32::System::Threading::{
    CREATE_SUSPENDED, CreateProcessA, PROCESS_INFORMATION, STARTUPINFOA, TerminateProcess,
};
use windows::core::PCSTR;

use super::surrogate_process::SurrogateProcess;
use super::wrappers::{HandleWrapper, PSTRWrapper};
use crate::HyperlightError::WindowsAPIError;
use crate::{Result, log_then_return, new_error};
/// The surrogate executable, embedded into this binary by `rust_embed`.
///
/// The source folder is named by `$HYPERLIGHT_SURROGATE_DIR` (presumably set by
/// the build script — confirm). Only `hyperlight_surrogate.exe` is included.
#[derive(RustEmbed)]
#[folder = "$HYPERLIGHT_SURROGATE_DIR"]
#[include = "hyperlight_surrogate.exe"]
struct Asset;
/// Name of the surrogate binary inside the embedded [`Asset`] folder.
const EMBEDDED_SURROGATE_NAME: &str = "hyperlight_surrogate.exe";
/// Absolute upper bound on the number of surrogate processes.
///
/// `compute_surrogate_counts` clamps both the initial and maximum pool sizes
/// to this value. It is also what both sizes default to when neither is
/// configured.
const HARD_MAX_SURROGATE_PROCESSES: usize = 512;
/// Environment variable holding the number of surrogates to pre-create at startup.
const INITIAL_SURROGATES_ENV_VAR: &str = "HYPERLIGHT_INITIAL_SURROGATES";
/// Environment variable holding the maximum number of surrogates that may ever be created.
const MAX_SURROGATES_ENV_VAR: &str = "HYPERLIGHT_MAX_SURROGATES";
/// Returns the file name the embedded surrogate binary is extracted under.
///
/// The name carries the first 8 hex digits of the binary's BLAKE3 hash. As a
/// result, different surrogate builds never overwrite each other on disk.
///
/// # Errors
///
/// Fails if the surrogate binary was not embedded into this build.
fn surrogate_binary_name() -> Result<String> {
    let Some(embedded) = Asset::get(EMBEDDED_SURROGATE_NAME) else {
        return Err(new_error!("could not find embedded surrogate binary"));
    };
    let digest = blake3::hash(embedded.data.as_ref()).to_hex();
    let prefix = &digest[..8];
    Ok(format!("hyperlight_surrogate_{prefix}.exe"))
}
/// Resolves the `(initial, max)` surrogate pool sizes from optional requested values.
///
/// `max` is clamped to `1..=HARD_MAX_SURROGATE_PROCESSES` and defaults to
/// `HARD_MAX_SURROGATE_PROCESSES` when absent. `initial` is clamped to
/// `1..=max` and defaults to `max` when absent, so it can never exceed the cap.
fn compute_surrogate_counts(raw_initial: Option<usize>, raw_max: Option<usize>) -> (usize, usize) {
    let max = match raw_max {
        Some(requested) => requested.clamp(1, HARD_MAX_SURROGATE_PROCESSES),
        None => HARD_MAX_SURROGATE_PROCESSES,
    };
    let initial = match raw_initial {
        Some(requested) => requested.clamp(1, max),
        None => max,
    };
    (initial, max)
}
/// Reads the surrogate pool sizes from the environment and returns `(initial, max)`.
///
/// `HYPERLIGHT_INITIAL_SURROGATES` sets how many surrogates are pre-created.
/// `HYPERLIGHT_MAX_SURROGATES` sets the overall cap. Defaults and clamping are
/// applied by `compute_surrogate_counts`.
///
/// An out-of-range value is clamped and reported with a warning. A value that
/// is set but is not a non-negative integer is ignored (the default is used)
/// and also reported, rather than being silently dropped.
fn surrogate_process_counts() -> (usize, usize) {
    let raw_initial = read_count_env_var(INITIAL_SURROGATES_ENV_VAR);
    let raw_max = read_count_env_var(MAX_SURROGATES_ENV_VAR);
    let (initial, max) = compute_surrogate_counts(raw_initial, raw_max);
    if let Some(n) = raw_initial
        && n != initial
    {
        warn!("{INITIAL_SURROGATES_ENV_VAR}={n} was clamped to {initial}");
    }
    if let Some(n) = raw_max
        && n != max
    {
        warn!("{MAX_SURROGATES_ENV_VAR}={n} was clamped to {max}");
    }
    (initial, max)
}

/// Parses environment variable `name` as a `usize`, ignoring surrounding whitespace.
///
/// Returns `None` when the variable is unset. It also returns `None`, after
/// logging a warning, when the value is not valid Unicode or not a
/// non-negative integer, so a typo falls back to the default visibly.
fn read_count_env_var(name: &str) -> Option<usize> {
    let value = match std::env::var(name) {
        Ok(value) => value,
        Err(std::env::VarError::NotPresent) => return None,
        Err(e) => {
            warn!("ignoring {name}: {e}");
            return None;
        }
    };
    match value.trim().parse::<usize>() {
        Ok(n) => Some(n),
        Err(e) => {
            warn!("ignoring {name}={value:?}: not a valid non-negative integer ({e})");
            None
        }
    }
}
/// Pool of surrogate processes (created with `CREATE_SUSPENDED`) shared by the
/// whole host process through `get_surrogate_process_manager`.
///
/// Idle surrogates wait in an unbounded channel. `get_surrogate_process` takes
/// one out (creating one on demand up to `max_processes`), and
/// `return_surrogate_process` puts one back. Every surrogate is assigned to
/// `job_handle`.
pub(crate) struct SurrogateProcessManager {
    /// Job object every surrogate is assigned to. `create_job_object`
    /// configures it with `JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE`.
    job_handle: HandleWrapper,
    /// Receiving end of the idle-surrogate queue.
    process_receiver: Receiver<HandleWrapper>,
    /// Sending end of the idle-surrogate queue. Holding it here also means the
    /// channel cannot become disconnected while the manager is alive.
    process_sender: Sender<HandleWrapper>,
    /// Full path of the extracted surrogate executable.
    surrogate_process_path: PathBuf,
    /// Upper bound on the number of surrogates this manager will create.
    max_processes: usize,
    /// Number of surrogates created so far (pre-created plus on-demand).
    created_count: AtomicUsize,
}
impl SurrogateProcessManager {
    /// Builds the process-wide surrogate pool.
    ///
    /// Steps, in order:
    /// 1. Extract the embedded surrogate binary next to the current executable, if needed.
    /// 2. Read the pool sizes from the environment.
    /// 3. Create the job object every surrogate is assigned to.
    /// 4. Pre-create the initial surrogates.
    ///
    /// # Errors
    ///
    /// Propagates any failure from extraction, job-object creation or the
    /// initial surrogate launches.
    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
    fn new() -> Result<Self> {
        let binary_name = surrogate_binary_name()?;
        ensure_surrogate_process_exe(&binary_name)?;
        let surrogate_process_path = get_surrogate_process_dir()?.join(&binary_name);
        let (initial_count, max_processes) = surrogate_process_counts();
        let (process_sender, process_receiver) = unbounded();
        let job_handle = create_job_object()?;

        let manager = Self {
            job_handle,
            process_receiver,
            process_sender,
            surrogate_process_path,
            max_processes,
            created_count: AtomicUsize::new(0),
        };
        manager.create_initial_surrogate_processes(initial_count)?;
        Ok(manager)
    }

    /// Hands out a surrogate process.
    ///
    /// Sources are tried in order:
    /// 1. An idle surrogate already queued in the pool.
    /// 2. A newly launched one, if fewer than `max_processes` have been created.
    /// 3. Otherwise, the call blocks until a surrogate is sent back through
    ///    [`Self::return_surrogate_process`].
    ///
    /// # Errors
    ///
    /// Fails if the pool channel is disconnected or an on-demand launch fails.
    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
    pub(super) fn get_surrogate_process(&self) -> Result<SurrogateProcess> {
        let handle = match self.process_receiver.try_recv() {
            Ok(idle) => idle,
            Err(TryRecvError::Empty) => self.spawn_or_wait_for_surrogate()?,
            Err(TryRecvError::Disconnected) => {
                return Err(new_error!("surrogate process channel disconnected"));
            }
        };
        let surrogate_process_handle: HANDLE = handle.into();
        Ok(SurrogateProcess::new(surrogate_process_handle))
    }

    /// Slow path of [`Self::get_surrogate_process`], used when no surrogate is idle.
    ///
    /// While `created_count` is below `max_processes`, this atomically reserves
    /// a slot and launches a new surrogate into it. If the launch fails, the
    /// slot is released again. Once the pool is at capacity, it blocks on the
    /// channel until a surrogate is returned.
    fn spawn_or_wait_for_surrogate(&self) -> Result<HandleWrapper> {
        let cap = self.max_processes;
        let reserved = self.created_count.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |created| (created < cap).then_some(created + 1),
        );
        match reserved {
            Ok(previously_created) => {
                info!(
                    "on-demand surrogate process creation ({}/{})",
                    previously_created + 1,
                    cap
                );
                create_surrogate_process(&self.surrogate_process_path, self.job_handle)
                    .inspect_err(|_| {
                        // Give the reserved slot back so a later call can retry.
                        self.created_count.fetch_sub(1, Ordering::AcqRel);
                    })
            }
            Err(_) => Ok(self.process_receiver.recv()?),
        }
    }

    /// Puts a surrogate process handle back into the idle queue.
    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
    pub(super) fn return_surrogate_process(&self, proc_handle: HandleWrapper) -> Result<()> {
        self.process_sender.send(proc_handle)?;
        Ok(())
    }

    /// Launches `initial_count` surrogates up front and queues them as idle,
    /// counting each one in `created_count`.
    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
    fn create_initial_surrogate_processes(&self, initial_count: usize) -> Result<()> {
        info!(
            "pre-creating {} surrogate processes ({}={:?}, {}={:?})",
            initial_count,
            INITIAL_SURROGATES_ENV_VAR,
            std::env::var(INITIAL_SURROGATES_ENV_VAR).ok(),
            MAX_SURROGATES_ENV_VAR,
            std::env::var(MAX_SURROGATES_ENV_VAR).ok(),
        );
        for _slot in 0..initial_count {
            let fresh = create_surrogate_process(&self.surrogate_process_path, self.job_handle)?;
            self.process_sender.send(fresh)?;
            self.created_count.fetch_add(1, Ordering::AcqRel);
        }
        Ok(())
    }
}
impl Drop for SurrogateProcessManager {
    /// Terminates every process still associated with the surrogate job object.
    ///
    /// A failure is only logged, because `drop` cannot report errors. The
    /// manager is normally held in a process-wide static, so this may never
    /// actually run.
    #[instrument(skip_all, parent = Span::current(), level= "Trace")]
    fn drop(&mut self) {
        let job: HANDLE = self.job_handle.into();
        // SAFETY: `job_handle` was produced by `create_job_object` in `new` and is
        // owned by this manager.
        let terminated = unsafe { TerminateJobObject(job, 0) }.is_ok();
        if !terminated {
            error!("surrogate job objects were not all terminated");
        }
    }
}
lazy_static::lazy_static! {
static ref SURROGATE_PROCESSES_MANAGER: std::result::Result<SurrogateProcessManager, &'static str> =
match SurrogateProcessManager::new() {
Ok(manager) => Ok(manager),
Err(e) => {
error!("Failed to create SurrogateProcessManager: {:?}", e);
Err("Failed to create SurrogateProcessManager")
}
};
}
#[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
pub(crate) fn get_surrogate_process_manager() -> Result<&'static SurrogateProcessManager> {
match &*SURROGATE_PROCESSES_MANAGER {
Ok(manager) => Ok(manager),
Err(e) => {
error!("Failed to get SurrogateProcessManager: {:?}", e);
Err(new_error!("Failed to get SurrogateProcessManager {}", e))
}
}
}
#[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
fn create_job_object() -> Result<HandleWrapper> {
let security_attributes: SECURITY_ATTRIBUTES = Default::default();
let job_object = unsafe { CreateJobObjectA(Some(&security_attributes), PCSTR::null())? };
let mut job_object_information = JOBOBJECT_EXTENDED_LIMIT_INFORMATION {
BasicLimitInformation: JOBOBJECT_BASIC_LIMIT_INFORMATION {
LimitFlags: JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
..Default::default()
},
..Default::default()
};
let job_object_information_ptr: *mut c_void =
&mut job_object_information as *mut _ as *mut c_void;
if let Err(e) = unsafe {
SetInformationJobObject(
job_object,
JobObjectExtendedLimitInformation,
job_object_information_ptr,
size_of::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() as u32,
)
} {
log_then_return!(WindowsAPIError(e.clone()));
}
Ok(job_object.into())
}
#[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
fn get_surrogate_process_dir() -> Result<PathBuf> {
let binding = std::env::current_exe()?;
let path = binding
.parent()
.ok_or_else(|| new_error!("could not get parent directory of current executable"))?;
Ok(path.to_path_buf())
}
#[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
fn ensure_surrogate_process_exe(binary_name: &str) -> Result<()> {
let dir = get_surrogate_process_dir()?;
let surrogate_process_path = dir.join(binary_name);
let exe = Asset::get(EMBEDDED_SURROGATE_NAME)
.ok_or_else(|| new_error!("could not find embedded surrogate binary"))?;
match File::create_new(&surrogate_process_path) {
Ok(mut f) => {
info!(
"{} does not exist, extracting to {}",
binary_name,
&surrogate_process_path.display()
);
if let Err(e) = f.write_all(exe.data.as_ref()) {
drop(f);
let _ = std::fs::remove_file(&surrogate_process_path);
return Err(e.into());
}
}
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
}
Err(e) => return Err(e.into()),
}
Ok(())
}
#[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
fn create_surrogate_process(
surrogate_process_path: &Path,
job_handle: HandleWrapper,
) -> Result<HandleWrapper> {
let mut process_information: PROCESS_INFORMATION = unsafe { std::mem::zeroed() };
let mut startup_info: STARTUPINFOA = unsafe { std::mem::zeroed() };
let process_attributes: SECURITY_ATTRIBUTES = Default::default();
let thread_attributes: SECURITY_ATTRIBUTES = Default::default();
startup_info.cb = std::mem::size_of::<STARTUPINFOA>() as u32;
let cmd_line = surrogate_process_path.to_str().ok_or(new_error!(
"failed to convert surrogate process path to a string"
))?;
let p_cmd_line = &PSTRWrapper::try_from(cmd_line)?;
if let Err(e) = unsafe {
CreateProcessA(
PCSTR::null(),
Some(p_cmd_line.into()),
Some(&process_attributes),
Some(&thread_attributes),
false,
CREATE_SUSPENDED,
None,
None,
&startup_info,
&mut process_information,
)
} {
log_then_return!(WindowsAPIError(e.clone()));
}
let job_handle: HANDLE = job_handle.into();
let process_handle: HANDLE = process_information.hProcess;
unsafe {
if let Err(e) = AssignProcessToJobObject(job_handle, process_handle) {
log_then_return!(WindowsAPIError(e.clone()));
}
}
Ok(process_handle.into())
}
#[cfg(test)]
mod tests {
    use std::ffi::CStr;
    use std::thread;
    use std::time::{Duration, Instant};

    use rand::{RngExt, rng};
    use windows::Win32::Foundation::{CloseHandle, HANDLE};
    use windows::Win32::System::Diagnostics::ToolHelp::{
        CreateToolhelp32Snapshot, PROCESSENTRY32, Process32First, Process32Next, TH32CS_SNAPPROCESS,
    };
    use windows::Win32::System::JobObjects::IsProcessInJob;
    use windows_result::BOOL;

    use super::*;
    use crate::mem::shared_mem::{ExclusiveSharedMemory, SharedMemory};

    /// Stress test of the pool.
    ///
    /// `2 * max_processes` threads each check out a surrogate `max_processes`
    /// times. For every surrogate handed out, the test checks that it belongs
    /// to the manager's job object, then drops it. Afterwards exactly
    /// `max_processes` surrogate processes must exist.
    #[test]
    fn test_surrogate_process_manager() {
        let mut threads = Vec::new();
        let surrogate_process_manager = get_surrogate_process_manager().unwrap();
        let max_processes = surrogate_process_manager.max_processes;
        for t in 0..max_processes * 2 {
            let thread_handle = thread::spawn(move || -> Result<()> {
                let surrogate_process_manager_res = get_surrogate_process_manager();
                let mut rng = rng();
                assert!(surrogate_process_manager_res.is_ok());
                let surrogate_process_manager = surrogate_process_manager_res.unwrap();
                let job_handle = surrogate_process_manager.job_handle;
                for p in 0..max_processes {
                    let timer = Instant::now();
                    let surrogate_process = {
                        let res = surrogate_process_manager.get_surrogate_process()?;
                        let elapsed = timer.elapsed();
                        if (elapsed.as_millis() as u64) > 150 {
                            println!("Get Process Time Thread {} Process {}: {:?}", t, p, elapsed);
                        }
                        res
                    };
                    let mut result: BOOL = Default::default();
                    let process_handle: HANDLE = surrogate_process.process_handle.into();
                    let job_handle: HANDLE = job_handle.into();
                    unsafe {
                        assert!(
                            IsProcessInJob(process_handle, Some(job_handle), &mut result).is_ok()
                        );
                        assert!(result.as_bool());
                    }
                    let n: u64 = rng.random_range(1..16);
                    thread::sleep(Duration::from_millis(n));
                    drop(surrogate_process);
                }
                Ok(())
            });
            threads.push(thread_handle);
        }
        for thread_handle in threads {
            // `join` only reports panics. The inner `Result` carries any error
            // from `get_surrogate_process`, which must fail the test as well.
            thread_handle
                .join()
                .expect("surrogate worker thread panicked")
                .expect("surrogate worker thread failed");
        }
        assert_number_of_surrogate_processes(max_processes);
    }

    /// Counts running processes whose image name matches
    /// `hyperlight_surrogate_*.exe` and asserts the count equals `expected_count`.
    ///
    /// When `expected_count` is 0, it retries once per second (up to
    /// `MAX_RETRIES` attempts) to give terminated processes time to disappear.
    #[track_caller]
    fn assert_number_of_surrogate_processes(expected_count: usize) {
        const MAX_RETRIES: u32 = 30;
        let mut attempt = 0;
        loop {
            let snapshot_handle = unsafe { CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) };
            assert!(snapshot_handle.is_ok());
            let snapshot_handle = snapshot_handle.unwrap();
            let mut process_entry = PROCESSENTRY32 {
                dwSize: size_of::<PROCESSENTRY32>() as u32,
                ..Default::default()
            };
            let mut result = unsafe { Process32First(snapshot_handle, &mut process_entry).is_ok() };
            let mut count = 0;
            while result {
                if let Ok(process_name) =
                    unsafe { CStr::from_ptr(process_entry.szExeFile.as_ptr()).to_str() }
                    && process_name.starts_with("hyperlight_surrogate_")
                    && process_name.ends_with(".exe")
                {
                    count += 1;
                }
                unsafe {
                    result = Process32Next(snapshot_handle, &mut process_entry).is_ok();
                }
            }
            // SAFETY: `snapshot_handle` came from the successful
            // `CreateToolhelp32Snapshot` call above and is not used after this.
            unsafe {
                let _ = CloseHandle(snapshot_handle);
            }
            attempt += 1;
            if expected_count == 0 && count > 0 && attempt < MAX_RETRIES {
                thread::sleep(Duration::from_secs(1));
            } else {
                assert_eq!(count, expected_count);
                break;
            }
        }
    }

    /// A `SandboxMemory` mapping of one page must be framed by guard pages.
    ///
    /// Reading the first and third pages at the surrogate address must fail,
    /// and reading the middle page must succeed.
    #[test]
    fn windows_guard_page() {
        const SIZE: usize = 4096;
        let mgr = get_surrogate_process_manager().unwrap();
        let mem = ExclusiveSharedMemory::new(SIZE).unwrap();
        let mut process = mgr.get_surrogate_process().unwrap();
        let surrogate_address = process
            .map(
                HandleWrapper::from(mem.get_mmap_file_handle()),
                mem.raw_ptr() as usize,
                mem.raw_mem_size(),
                &crate::mem::memory_region::SurrogateMapping::SandboxMemory,
            )
            .unwrap();
        // Mutable: ReadProcessMemory writes into this buffer, which must not be
        // done through a pointer obtained from `Vec::as_ptr`.
        let mut buffer = vec![0u8; SIZE];
        let bytes_read: Option<*mut usize> = None;
        let process_handle: HANDLE = process.process_handle.into();
        unsafe {
            let success = windows::Win32::System::Diagnostics::Debug::ReadProcessMemory(
                process_handle,
                surrogate_address,
                buffer.as_mut_ptr() as *mut c_void,
                SIZE,
                bytes_read,
            );
            assert!(success.is_err());
            let success = windows::Win32::System::Diagnostics::Debug::ReadProcessMemory(
                process_handle,
                surrogate_address.wrapping_add(SIZE),
                buffer.as_mut_ptr() as *mut c_void,
                SIZE,
                bytes_read,
            );
            assert!(success.is_ok());
            let success = windows::Win32::System::Diagnostics::Debug::ReadProcessMemory(
                process_handle,
                surrogate_address.wrapping_add(2 * SIZE),
                buffer.as_mut_ptr() as *mut c_void,
                SIZE,
                bytes_read,
            );
            assert!(success.is_err());
        }
    }

    /// A `ReadOnlyFile` mapping must not get guard pages: all three pages at
    /// the surrogate address must be readable.
    #[test]
    fn readonly_file_mapping_skips_guard_pages() {
        const SIZE: usize = 4096;
        let mgr = get_surrogate_process_manager().unwrap();
        let mem = ExclusiveSharedMemory::new(SIZE).unwrap();
        let mut process = mgr.get_surrogate_process().unwrap();
        let surrogate_address = process
            .map(
                HandleWrapper::from(mem.get_mmap_file_handle()),
                mem.raw_ptr() as usize,
                mem.raw_mem_size(),
                &crate::mem::memory_region::SurrogateMapping::ReadOnlyFile,
            )
            .unwrap();
        // Mutable: ReadProcessMemory writes into this buffer.
        let mut buffer = vec![0u8; SIZE];
        let bytes_read: Option<*mut usize> = None;
        let process_handle: HANDLE = process.process_handle.into();
        unsafe {
            let success = windows::Win32::System::Diagnostics::Debug::ReadProcessMemory(
                process_handle,
                surrogate_address,
                buffer.as_mut_ptr() as *mut c_void,
                SIZE,
                bytes_read,
            );
            assert!(
                success.is_ok(),
                "First page should be readable with ReadOnlyFile (no guard page)"
            );
            let success = windows::Win32::System::Diagnostics::Debug::ReadProcessMemory(
                process_handle,
                surrogate_address.wrapping_add(SIZE),
                buffer.as_mut_ptr() as *mut c_void,
                SIZE,
                bytes_read,
            );
            assert!(
                success.is_ok(),
                "Middle page should be readable with ReadOnlyFile"
            );
            let success = windows::Win32::System::Diagnostics::Debug::ReadProcessMemory(
                process_handle,
                surrogate_address.wrapping_add(2 * SIZE),
                buffer.as_mut_ptr() as *mut c_void,
                SIZE,
                bytes_read,
            );
            assert!(
                success.is_ok(),
                "Last page should be readable with ReadOnlyFile (no guard page)"
            );
        }
    }

    /// Mapping the same host region twice returns the same surrogate address.
    /// The mapping must survive the first `unmap` and be removed only after
    /// the second.
    #[test]
    fn surrogate_map_ref_counting() {
        let mgr = get_surrogate_process_manager().unwrap();
        let mem = ExclusiveSharedMemory::new(4096).unwrap();
        let mut process = mgr.get_surrogate_process().unwrap();
        let handle = HandleWrapper::from(mem.get_mmap_file_handle());
        let host_base = mem.raw_ptr() as usize;
        let host_size = mem.raw_mem_size();
        let addr1 = process
            .map(
                handle,
                host_base,
                host_size,
                &crate::mem::memory_region::SurrogateMapping::SandboxMemory,
            )
            .unwrap();
        let addr2 = process
            .map(
                handle,
                host_base,
                host_size,
                &crate::mem::memory_region::SurrogateMapping::SandboxMemory,
            )
            .unwrap();
        assert_eq!(
            addr1, addr2,
            "Repeated map should return the same surrogate address"
        );
        process.unmap(host_base);
        assert!(
            process.mappings.contains_key(&host_base),
            "Mapping should still exist after first unmap (ref count > 0)"
        );
        process.unmap(host_base);
        assert!(
            !process.mappings.contains_key(&host_base),
            "Mapping should be removed after ref count reaches 0"
        );
    }

    /// Checks extraction of the embedded binary under a test-only name:
    /// - the extracted contents match the embedded asset;
    /// - a repeated call succeeds;
    /// - deleting the file and calling again re-extracts it;
    /// - `surrogate_binary_name` is deterministic.
    #[test]
    fn test_ensure_surrogate_exe() {
        let test_binary_name = "hyperlight_surrogate_test_extraction.exe";
        let dir = get_surrogate_process_dir().expect("should get surrogate dir");
        let path = dir.join(test_binary_name);
        let _ = std::fs::remove_file(&path);
        ensure_surrogate_process_exe(test_binary_name).expect("first call should succeed");
        assert!(path.exists(), "binary should exist after extraction");
        let on_disk = std::fs::read(&path).expect("should read extracted file");
        let embedded = Asset::get(EMBEDDED_SURROGATE_NAME).expect("embedded asset should exist");
        assert_eq!(
            on_disk,
            embedded.data.as_ref(),
            "extracted file content should match embedded binary"
        );
        ensure_surrogate_process_exe(test_binary_name)
            .expect("second call should succeed when file already exists");
        std::fs::remove_file(&path).expect("should be able to delete test binary");
        assert!(!path.exists(), "binary should be gone after deletion");
        ensure_surrogate_process_exe(test_binary_name)
            .expect("should succeed re-extracting after deletion");
        assert!(path.exists(), "binary should be re-created after deletion");
        let _ = std::fs::remove_file(&path);
        let binary_name = surrogate_binary_name().expect("should succeed");
        let binary_name_2 = surrogate_binary_name().expect("second call should also succeed");
        assert_eq!(
            binary_name, binary_name_2,
            "surrogate_binary_name should be deterministic"
        );
    }

    /// Table of defaulting and clamping cases for `compute_surrogate_counts`.
    #[test]
    fn test_compute_surrogate_counts() {
        let (initial, max) = compute_surrogate_counts(None, None);
        assert_eq!(
            initial, HARD_MAX_SURROGATE_PROCESSES,
            "default initial should be {HARD_MAX_SURROGATE_PROCESSES}"
        );
        assert_eq!(
            max, HARD_MAX_SURROGATE_PROCESSES,
            "default max should be {HARD_MAX_SURROGATE_PROCESSES}"
        );
        let (initial, max) = compute_surrogate_counts(Some(32), None);
        assert_eq!(initial, 32, "initial should honour provided value");
        assert_eq!(
            max, HARD_MAX_SURROGATE_PROCESSES,
            "max should default when unset"
        );
        let (initial, max) = compute_surrogate_counts(Some(8), Some(64));
        assert_eq!(initial, 8);
        assert_eq!(max, 64);
        let (initial, max) = compute_surrogate_counts(Some(100), Some(10));
        assert_eq!(max, 10, "max is authoritative and should not be inflated");
        assert_eq!(
            initial, 10,
            "initial should be clamped down to max when it exceeds it"
        );
        let (initial, max) = compute_surrogate_counts(Some(0), None);
        assert_eq!(initial, 1, "initial should be clamped to minimum of 1");
        assert_eq!(
            max, HARD_MAX_SURROGATE_PROCESSES,
            "max should default when unset"
        );
        let (initial, max) = compute_surrogate_counts(Some(9999), None);
        assert_eq!(
            initial, HARD_MAX_SURROGATE_PROCESSES,
            "initial should be clamped to {HARD_MAX_SURROGATE_PROCESSES}"
        );
        assert_eq!(max, HARD_MAX_SURROGATE_PROCESSES);
        let (initial, max) = compute_surrogate_counts(None, Some(256));
        assert_eq!(max, 256, "max should honour provided value");
        assert_eq!(
            initial, 256,
            "initial should be clamped down to max when it defaults above it"
        );
        let (initial, max) = compute_surrogate_counts(None, Some(0));
        assert_eq!(max, 1, "max should be clamped to minimum of 1");
        assert_eq!(initial, 1, "initial should be clamped down to max");
        let (initial, max) = compute_surrogate_counts(None, Some(9999));
        assert_eq!(
            max, HARD_MAX_SURROGATE_PROCESSES,
            "max should be clamped to {HARD_MAX_SURROGATE_PROCESSES}"
        );
        assert_eq!(initial, HARD_MAX_SURROGATE_PROCESSES);
        let (initial, max) = compute_surrogate_counts(Some(1), Some(1));
        assert_eq!(initial, 1);
        assert_eq!(max, 1);
        let (initial, max) = compute_surrogate_counts(
            Some(HARD_MAX_SURROGATE_PROCESSES),
            Some(HARD_MAX_SURROGATE_PROCESSES),
        );
        assert_eq!(initial, HARD_MAX_SURROGATE_PROCESSES);
        assert_eq!(max, HARD_MAX_SURROGATE_PROCESSES);
    }

    /// Whatever the current environment holds, `surrogate_process_counts` must
    /// return in-range values with `initial <= max`.
    #[test]
    fn test_surrogate_process_counts_defaults() {
        let (initial, max) = surrogate_process_counts();
        assert!(
            (1..=HARD_MAX_SURROGATE_PROCESSES).contains(&initial),
            "initial {initial} should be in 1..={HARD_MAX_SURROGATE_PROCESSES}"
        );
        assert!(
            (1..=HARD_MAX_SURROGATE_PROCESSES).contains(&max),
            "max {max} should be in 1..={HARD_MAX_SURROGATE_PROCESSES}"
        );
        assert!(initial <= max, "initial ({initial}) must be <= max ({max})");
    }
}