use std::path::Path;
use std::sync::Arc;
#[cfg(target_os = "linux")]
use hyperlight_host::mem::memory_region::{MemoryRegion, MemoryRegionFlags, MemoryRegionType};
use hyperlight_host::sandbox::snapshot::Snapshot;
use hyperlight_host::{MultiUseSandbox, Result, new_error};
use super::loaded_wasm_sandbox::LoadedWasmSandbox;
use crate::sandbox::metrics::{
METRIC_ACTIVE_WASM_SANDBOXES, METRIC_SANDBOX_LOADS, METRIC_TOTAL_WASM_SANDBOXES,
};
/// Lifecycle state machine for the [`MultiUseSandbox`] that backs a `WasmSandbox`.
///
/// Every transition moves the sandbox out of `self` with [`std::mem::replace`]
/// and writes the new state back, so the sandbox is never cloned. Failed
/// transitions write a state back as well, so the sandbox is not silently
/// dropped.
mod backing_sandbox {
    use super::*;

    /// The backing sandbox, tagged with what is known about its contents.
    #[derive(Debug)]
    pub(super) enum BackingSandbox {
        /// Matches the snapshot it was taken from or restored to; a module may
        /// be loaded into it with [`BackingSandbox::load_via_fn`].
        Clean(MultiUseSandbox),
        /// A module has been loaded; the sandbox is waiting to be taken by
        /// [`BackingSandbox::get_loaded`].
        Loaded(MultiUseSandbox),
        /// May still contain state from a previously loaded module or from a
        /// failed operation; it must be restored before it can be reused.
        Dirty(MultiUseSandbox),
        /// The sandbox has been moved out by [`BackingSandbox::get_loaded`].
        Missing,
    }

    impl BackingSandbox {
        /// Ensures the sandbox is `Clean`, restoring `snapshot` if it is `Dirty`.
        ///
        /// # Errors
        /// Returns the restore error (the sandbox then stays `Dirty`), or an
        /// invariant-violation error if the sandbox is `Loaded` (left unchanged)
        /// or `Missing`.
        pub(super) fn clean(&mut self, snapshot: Arc<Snapshot>) -> Result<()> {
            match std::mem::replace(self, BackingSandbox::Missing) {
                BackingSandbox::Clean(x) => {
                    *self = BackingSandbox::Clean(x);
                    Ok(())
                }
                BackingSandbox::Dirty(mut x) => {
                    if let Err(e) = x.restore(snapshot) {
                        // Keep the sandbox rather than dropping it; it is still dirty.
                        *self = BackingSandbox::Dirty(x);
                        return Err(e);
                    }
                    *self = BackingSandbox::Clean(x);
                    Ok(())
                }
                loaded @ BackingSandbox::Loaded(_) => {
                    *self = loaded;
                    Err(new_error!(
                        "internal invariant violation: cleaning loaded backing sandbox"
                    ))
                }
                BackingSandbox::Missing => Err(new_error!(
                    "internal invariant violation: cleaning missing backing sandbox"
                )),
            }
        }

        /// Transitions a `Clean` or `Dirty` sandbox to `Loaded` by restoring a
        /// snapshot that already contains a loaded module.
        ///
        /// # Errors
        /// Returns the restore error (the sandbox is then marked `Dirty`), or an
        /// invariant-violation error if the sandbox is `Loaded` (left unchanged)
        /// or `Missing`.
        pub(super) fn load_via_restore(&mut self, snapshot: Arc<Snapshot>) -> Result<()> {
            match std::mem::replace(self, BackingSandbox::Missing) {
                BackingSandbox::Clean(mut x) | BackingSandbox::Dirty(mut x) => {
                    if let Err(e) = x.restore(snapshot) {
                        // We cannot tell what a failed restore left behind, so the
                        // only safe classification is `Dirty`.
                        *self = BackingSandbox::Dirty(x);
                        return Err(e);
                    }
                    *self = BackingSandbox::Loaded(x);
                    Ok(())
                }
                loaded @ BackingSandbox::Loaded(_) => {
                    *self = loaded;
                    Err(new_error!(
                        "internal invariant violation: loading loaded backing sandbox"
                    ))
                }
                BackingSandbox::Missing => Err(new_error!(
                    "internal invariant violation: loading missing backing sandbox"
                )),
            }
        }

        /// Runs `load` against a `Clean` sandbox and marks it `Loaded` on success.
        ///
        /// # Errors
        /// Returns `load`'s error (the sandbox is then marked `Dirty`, since
        /// `load` may have modified it), or an invariant-violation error if the
        /// sandbox is not `Clean` (the state is left unchanged).
        pub(super) fn load_via_fn(
            &mut self,
            load: impl FnOnce(&mut MultiUseSandbox) -> Result<()>,
        ) -> Result<()> {
            match std::mem::replace(self, BackingSandbox::Missing) {
                BackingSandbox::Clean(mut x) => {
                    if let Err(e) = load(&mut x) {
                        *self = BackingSandbox::Dirty(x);
                        return Err(e);
                    }
                    *self = BackingSandbox::Loaded(x);
                    Ok(())
                }
                other @ (BackingSandbox::Loaded(_)
                | BackingSandbox::Dirty(_)
                | BackingSandbox::Missing) => {
                    *self = other;
                    Err(new_error!(
                        "internal invariant violation: loading non-clean backing sandbox"
                    ))
                }
            }
        }

        /// Moves a `Loaded` sandbox out, leaving `Missing` behind.
        ///
        /// # Errors
        /// Returns an invariant-violation error (leaving the state unchanged) if
        /// the sandbox is not `Loaded`.
        pub(super) fn get_loaded(&mut self) -> Result<MultiUseSandbox> {
            match std::mem::replace(self, BackingSandbox::Missing) {
                BackingSandbox::Loaded(x) => Ok(x),
                other @ (BackingSandbox::Clean(_)
                | BackingSandbox::Dirty(_)
                | BackingSandbox::Missing) => {
                    *self = other;
                    Err(new_error!(
                        "internal invariant violation: encountered non-loaded backing sandbox"
                    ))
                }
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::super::tests::*;
        use super::*;

        #[test]
        fn test_backing_sandbox_use_marks_dirty() -> Result<()> {
            let mut sb = SandboxBuilder::new().build()?;
            sb.register(
                "GetTimeSinceBootMicrosecond",
                get_time_since_boot_microsecond,
            )?;
            let sb = sb.load_runtime()?;
            let lb = sb.load_module(get_test_file_path("RunWasm.aot")?)?;
            let sb = lb.unload_module()?;
            assert!(matches!(sb.inner, super::BackingSandbox::Dirty(_)));
            Ok(())
        }

        #[test]
        fn test_dirty_backing_sandbox_cannot_be_loaded_via_fn() -> Result<()> {
            let mut sb = SandboxBuilder::new().build()?;
            sb.register(
                "GetTimeSinceBootMicrosecond",
                get_time_since_boot_microsecond,
            )?;
            let sb = sb.load_runtime()?;
            let lb = sb.load_module(get_test_file_path("RunWasm.aot")?)?;
            let mut sb = lb.unload_module()?;
            assert!(sb.inner.load_via_fn(|_| Ok(())).is_err());
            // The rejected transition must not drop the backing sandbox.
            assert!(matches!(sb.inner, super::BackingSandbox::Dirty(_)));
            Ok(())
        }

        #[test]
        fn test_dirty_backing_sandbox_cannot_be_gotten_as_loaded() -> Result<()> {
            let mut sb = SandboxBuilder::new().build()?;
            sb.register(
                "GetTimeSinceBootMicrosecond",
                get_time_since_boot_microsecond,
            )?;
            let sb = sb.load_runtime()?;
            let lb = sb.load_module(get_test_file_path("RunWasm.aot")?)?;
            let mut sb = lb.unload_module()?;
            assert!(sb.inner.get_loaded().is_err());
            // The rejected transition must not drop the backing sandbox.
            assert!(matches!(sb.inner, super::BackingSandbox::Dirty(_)));
            Ok(())
        }
    }
}
use backing_sandbox::*;
/// A sandbox with the Wasm runtime loaded, ready to have a guest module
/// loaded into it.
///
/// Every `load_*` method consumes this value and returns a
/// [`LoadedWasmSandbox`].
pub struct WasmSandbox {
    // The backing Hyperlight sandbox, tagged with its lifecycle state
    // (clean / loaded / dirty / missing).
    inner: BackingSandbox,
    // Snapshot restored into `inner` before a module is loaded (see
    // `clean_inner`). It is moved into the `LoadedWasmSandbox` on a successful
    // load, which is why it is an `Option`.
    snapshot: Option<Arc<Snapshot>>,
}
/// Guest address (4 GiB) at which module binaries are mapped before the guest's
/// `LoadWasmModulePhys` function is asked to load them.
const MAPPED_BINARY_VA: u64 = 1 << 32;
impl WasmSandbox {
    /// Wraps `inner` and snapshots its current state, which is treated as the
    /// clean state restored before each module load.
    ///
    /// # Errors
    /// Returns an error if the snapshot cannot be taken.
    pub(super) fn new(mut inner: MultiUseSandbox) -> Result<Self> {
        let snapshot = inner.snapshot()?;
        metrics::gauge!(METRIC_ACTIVE_WASM_SANDBOXES).increment(1);
        metrics::counter!(METRIC_TOTAL_WASM_SANDBOXES).increment(1);
        Ok(WasmSandbox {
            inner: BackingSandbox::Clean(inner),
            snapshot: Some(snapshot),
        })
    }

    /// Wraps a sandbox that previously had a module loaded.
    ///
    /// The sandbox is marked dirty and is restored to `snapshot` before the
    /// next module is loaded into it.
    pub(super) fn new_from_loaded(
        loaded: MultiUseSandbox,
        snapshot: Arc<Snapshot>,
    ) -> Result<Self> {
        metrics::gauge!(METRIC_ACTIVE_WASM_SANDBOXES).increment(1);
        metrics::counter!(METRIC_TOTAL_WASM_SANDBOXES).increment(1);
        Ok(WasmSandbox {
            inner: BackingSandbox::Dirty(loaded),
            snapshot: Some(snapshot),
        })
    }

    /// Brings the backing sandbox into the clean state by restoring the stored
    /// snapshot if needed.
    fn clean_inner(&mut self) -> Result<()> {
        // `ok_or_else` so the error is only built on the (unexpected) failure path.
        let snapshot = self.snapshot.as_ref().ok_or_else(|| {
            new_error!("internal invariant violation: Snapshot is missing")
        })?;
        self.inner.clean(Arc::clone(snapshot))
    }

    /// Loads the Wasm/AOT module at `file` into this sandbox.
    ///
    /// The file is first mapped into guest memory at `MAPPED_BINARY_VA` via
    /// `map_file_cow`. If mapping fails for any reason, the file is read on the
    /// host and its bytes are copied into the guest instead.
    ///
    /// # Errors
    /// Returns an error if the sandbox cannot be reset, the file cannot be read
    /// on the fallback path, or the guest fails to load the module.
    pub fn load_module(mut self, file: impl AsRef<Path>) -> Result<LoadedWasmSandbox> {
        self.clean_inner()?;
        self.inner.load_via_fn(|inner| {
            // The mapping error is deliberately discarded: the copy path below
            // is a complete fallback.
            if let Ok(len) = inner.map_file_cow(file.as_ref(), MAPPED_BINARY_VA, None) {
                inner.call::<()>("LoadWasmModulePhys", (MAPPED_BINARY_VA, len))?;
            } else {
                let wasm_bytes = std::fs::read(file)?;
                load_wasm_module_from_bytes(inner, wasm_bytes)?;
            }
            Ok(())
        })?;
        self.finalize_module_load()
    }

    /// Loads a module by restoring `snapshot`, a snapshot taken from a
    /// [`LoadedWasmSandbox`] with a module already loaded.
    ///
    /// # Errors
    /// Returns an error if the snapshot cannot be restored into this sandbox.
    pub fn load_from_snapshot(mut self, snapshot: Arc<Snapshot>) -> Result<LoadedWasmSandbox> {
        self.inner.load_via_restore(snapshot)?;
        self.finalize_module_load()
    }

    /// Loads a module whose bytes live in host memory at `base..base + len`.
    ///
    /// The host range is first mapped read/execute into guest memory at
    /// `MAPPED_BINARY_VA`. If that mapping fails, the bytes are copied into the
    /// guest instead.
    ///
    /// # Safety
    /// `base` must be non-null and valid for reads of `len` bytes. Because the
    /// range may be mapped into the guest, it must also stay valid and
    /// unmodified for as long as the resulting sandbox can access it.
    /// NOTE(review): confirm the exact lifetime and alignment requirements
    /// against `MultiUseSandbox::map_region`.
    #[cfg(target_os = "linux")]
    pub unsafe fn load_module_by_mapping(
        mut self,
        base: *mut libc::c_void,
        len: usize,
    ) -> Result<LoadedWasmSandbox> {
        self.clean_inner()?;
        self.inner.load_via_fn(|inner| {
            let guest_base: usize = MAPPED_BINARY_VA as usize;
            let rgn = MemoryRegion {
                host_region: base as usize..base.wrapping_add(len) as usize,
                guest_region: guest_base..guest_base + len,
                flags: MemoryRegionFlags::READ | MemoryRegionFlags::EXECUTE,
                region_type: MemoryRegionType::Heap,
            };
            // SAFETY: the caller guarantees (see `# Safety`) that
            // `base..base + len` stays valid for as long as the guest may access
            // the mapping.
            if let Ok(()) = unsafe { inner.map_region(&rgn) } {
                inner.call::<()>("LoadWasmModulePhys", (MAPPED_BINARY_VA, len as u64))?;
            } else {
                // SAFETY: the caller guarantees `base` is non-null and valid for
                // reads of `len` bytes; `u8` has no alignment requirement.
                let wasm_bytes =
                    unsafe { std::slice::from_raw_parts(base as *const u8, len).to_vec() };
                load_wasm_module_from_bytes(inner, wasm_bytes)?;
            }
            Ok(())
        })?;
        self.finalize_module_load()
    }

    /// Loads a module from an in-memory buffer by copying it into the guest.
    ///
    /// # Errors
    /// Returns an error if the sandbox cannot be reset or the guest fails to
    /// load the module.
    pub fn load_module_from_buffer(mut self, buffer: &[u8]) -> Result<LoadedWasmSandbox> {
        self.clean_inner()?;
        self.inner
            .load_via_fn(|inner| load_wasm_module_from_bytes(inner, buffer.to_vec()))?;
        self.finalize_module_load()
    }

    /// Moves the loaded backing sandbox and the snapshot into a
    /// [`LoadedWasmSandbox`].
    fn finalize_module_load(mut self) -> Result<LoadedWasmSandbox> {
        let sandbox = self.inner.get_loaded()?;
        let snapshot = self.snapshot.take().ok_or_else(|| {
            new_error!("internal invariant violation: Snapshot is missing")
        })?;
        let loaded = LoadedWasmSandbox::new(sandbox, snapshot)?;
        // Count only loads that actually produced a usable sandbox.
        metrics::counter!(METRIC_SANDBOX_LOADS).increment(1);
        Ok(loaded)
    }
}
/// Copies `wasm_bytes` into the guest and asks its `LoadWasmModule` function to
/// load them as a module.
///
/// # Errors
/// Returns an error if the module is longer than `i32::MAX` bytes (the length
/// is passed to the guest as an `i32`), if the guest call fails, or if the guest
/// returns a non-zero status code.
fn load_wasm_module_from_bytes(inner: &mut MultiUseSandbox, wasm_bytes: Vec<u8>) -> Result<()> {
    // Reject oversized modules rather than silently truncating the length.
    let len = i32::try_from(wasm_bytes.len()).map_err(|_| {
        new_error!(
            "wasm module too large to load: {} bytes",
            wasm_bytes.len()
        )
    })?;
    // `wasm_bytes` is already owned, so move it into the call instead of
    // copying the whole module.
    let res: i32 = inner.call("LoadWasmModule", (wasm_bytes, len))?;
    if res != 0 {
        return Err(new_error!(
            "LoadWasmModule Failed with error code {:?}",
            res
        ));
    }
    Ok(())
}
impl std::fmt::Debug for WasmSandbox {
    /// Writes only the type name; the backing sandbox and snapshot are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WasmSandbox")
    }
}
impl Drop for WasmSandbox {
    /// Decrements the active-sandbox gauge that the constructors increment.
    fn drop(&mut self) {
        let active_sandboxes = metrics::gauge!(METRIC_ACTIVE_WASM_SANDBOXES);
        active_sandboxes.decrement(1);
    }
}
#[cfg(test)]
mod tests {
    use std::env;
    use std::path::Path;

    use hyperlight_host::{HyperlightError, is_hypervisor_present};

    use super::*;
    pub(super) use crate::sandbox::sandbox_builder::SandboxBuilder;

    #[test]
    fn test_new_sandbox() -> Result<()> {
        let _sandbox = SandboxBuilder::new().build()?;
        Ok(())
    }

    /// Host function registered as `GetTimeSinceBootMicrosecond` by the tests.
    ///
    /// Despite its name, it returns the number of microseconds since the Unix
    /// epoch.
    pub(super) fn get_time_since_boot_microsecond() -> Result<i64> {
        let res = std::time::SystemTime::now()
            .duration_since(std::time::SystemTime::UNIX_EPOCH)?
            .as_micros();
        i64::try_from(res).map_err(HyperlightError::IntConversionFailure)
    }

    #[test]
    fn test_termination() -> Result<()> {
        let mut sandbox = SandboxBuilder::new().build()?;
        sandbox.register(
            "GetTimeSinceBootMicrosecond",
            get_time_since_boot_microsecond,
        )?;
        let loaded = sandbox.load_runtime()?;
        let run_wasm = get_test_file_path("RunWasm.aot")?;
        let mut loaded = loaded.load_module(run_wasm)?;
        let interrupt = loaded.interrupt_handle()?;
        std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(1000));
            interrupt.kill();
        });
        let result = loaded.call_guest_function::<i32>("KeepCPUBusy", 10000i32);
        match result {
            Ok(_) => panic!("Expected error"),
            Err(e) => match e {
                HyperlightError::ExecutionCanceledByHost() => {}
                _ => panic!("Unexpected error: {:?}", e),
            },
        }
        assert!(
            loaded.is_poisoned()?,
            "Sandbox should be poisoned after interruption"
        );
        Ok(())
    }

    #[test]
    fn test_sandbox_is_poisoned_after_interruption() -> Result<()> {
        let mut sandbox = SandboxBuilder::new().build()?;
        sandbox.register(
            "GetTimeSinceBootMicrosecond",
            get_time_since_boot_microsecond,
        )?;
        let loaded = sandbox.load_runtime()?;
        let run_wasm = get_test_file_path("RunWasm.aot")?;
        let mut loaded = loaded.load_module(run_wasm)?;
        assert!(
            !loaded.is_poisoned()?,
            "Sandbox should not be poisoned initially"
        );
        let interrupt = loaded.interrupt_handle()?;
        std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(500));
            interrupt.kill();
        });
        let _ = loaded.call_guest_function::<i32>("KeepCPUBusy", 100000i32);
        assert!(
            loaded.is_poisoned()?,
            "Sandbox should be poisoned after interruption"
        );
        Ok(())
    }

    #[test]
    fn test_call_guest_function_fails_when_poisoned() -> Result<()> {
        let mut sandbox = SandboxBuilder::new().build()?;
        sandbox.register(
            "GetTimeSinceBootMicrosecond",
            get_time_since_boot_microsecond,
        )?;
        let loaded = sandbox.load_runtime()?;
        let run_wasm = get_test_file_path("RunWasm.aot")?;
        let mut loaded = loaded.load_module(run_wasm)?;
        let interrupt = loaded.interrupt_handle()?;
        std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(500));
            interrupt.kill();
        });
        let _ = loaded.call_guest_function::<i32>("KeepCPUBusy", 100000i32);
        let result = loaded.call_guest_function::<i32>("PrintOutput", 42i32);
        match result {
            Ok(_) => panic!("Expected PoisonedSandbox error"),
            // Expected: a poisoned sandbox refuses further guest calls.
            Err(HyperlightError::PoisonedSandbox) => {}
            Err(e) => panic!("Unexpected error: {:?}", e),
        }
        Ok(())
    }

    #[test]
    fn test_snapshot_fails_when_poisoned() -> Result<()> {
        let mut sandbox = SandboxBuilder::new().build()?;
        sandbox.register(
            "GetTimeSinceBootMicrosecond",
            get_time_since_boot_microsecond,
        )?;
        let loaded = sandbox.load_runtime()?;
        let run_wasm = get_test_file_path("RunWasm.aot")?;
        let mut loaded = loaded.load_module(run_wasm)?;
        let interrupt = loaded.interrupt_handle()?;
        std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(500));
            interrupt.kill();
        });
        let _ = loaded.call_guest_function::<i32>("KeepCPUBusy", 100000i32);
        let result = loaded.snapshot();
        match result {
            Ok(_) => panic!("Expected PoisonedSandbox error"),
            // Expected: a poisoned sandbox cannot be snapshotted.
            Err(HyperlightError::PoisonedSandbox) => {}
            Err(e) => panic!("Unexpected error: {:?}", e),
        }
        Ok(())
    }

    #[test]
    fn test_restore_recovers_poisoned_sandbox() -> Result<()> {
        let mut sandbox = SandboxBuilder::new().build()?;
        sandbox.register(
            "GetTimeSinceBootMicrosecond",
            get_time_since_boot_microsecond,
        )?;
        let loaded = sandbox.load_runtime()?;
        let run_wasm = get_test_file_path("RunWasm.aot")?;
        let mut loaded = loaded.load_module(run_wasm)?;
        let snapshot = loaded.snapshot()?;
        let interrupt = loaded.interrupt_handle()?;
        std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(500));
            interrupt.kill();
        });
        let _ = loaded.call_guest_function::<i32>("KeepCPUBusy", 100000i32);
        assert!(loaded.is_poisoned()?, "Sandbox should be poisoned");
        loaded.restore(snapshot)?;
        assert!(
            !loaded.is_poisoned()?,
            "Sandbox should not be poisoned after restore"
        );
        let result: i32 = loaded.call_guest_function("CalcFib", 10i32)?;
        assert_eq!(result, 55);
        Ok(())
    }

    #[test]
    fn test_unload_module_recovers_poisoned_sandbox() -> Result<()> {
        let mut sandbox = SandboxBuilder::new().build()?;
        sandbox.register(
            "GetTimeSinceBootMicrosecond",
            get_time_since_boot_microsecond,
        )?;
        let loaded = sandbox.load_runtime()?;
        let run_wasm = get_test_file_path("RunWasm.aot")?;
        let mut loaded = loaded.load_module(run_wasm)?;
        let interrupt = loaded.interrupt_handle()?;
        std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(500));
            interrupt.kill();
        });
        let _ = loaded.call_guest_function::<i32>("KeepCPUBusy", 100000i32);
        assert!(loaded.is_poisoned()?, "Sandbox should be poisoned");
        let wasm_sandbox = loaded.unload_module()?;
        let helloworld_wasm = get_test_file_path("HelloWorld.aot")?;
        let mut new_loaded = wasm_sandbox.load_module(helloworld_wasm)?;
        assert!(
            !new_loaded.is_poisoned()?,
            "New sandbox should not be poisoned"
        );
        let result: i32 = new_loaded.call_guest_function("HelloWorld", "Test".to_string())?;
        assert_eq!(result, 0);
        Ok(())
    }

    #[test]
    fn test_load_module_file() {
        let sandboxes = get_test_wasm_sandboxes().unwrap();
        for sbox_test in sandboxes {
            let name = sbox_test.name;
            println!("test_load_module: {name}");
            let wasm_sandbox = sbox_test.sbox;
            let helloworld_wasm = get_test_file_path("HelloWorld.aot").unwrap();
            let mut loaded_wasm_sandbox = wasm_sandbox.load_module(helloworld_wasm).unwrap();
            let result: i32 = loaded_wasm_sandbox
                .call_guest_function("HelloWorld", "Message from Rust Test".to_string())
                .unwrap();
            println!("({name}) Result {:?}", result);
            assert_eq!(result, 0, "({name}) HelloWorld returned a non-zero status");
        }
    }

    #[test]
    fn test_load_from_snapshot() {
        let mut sandbox = SandboxBuilder::new().build().unwrap();
        sandbox
            .register(
                "GetTimeSinceBootMicrosecond",
                get_time_since_boot_microsecond,
            )
            .unwrap();
        let sb = sandbox.load_runtime().unwrap();
        let helloworld_wasm = get_test_file_path("HelloWorld.aot").unwrap();
        let runwasm_wasm = get_test_file_path("RunWasm.aot").unwrap();
        let mut lb1 = sb.load_module(helloworld_wasm).unwrap();
        let result: i32 = lb1
            .call_guest_function("HelloWorld", "Message from Rust Test".to_string())
            .unwrap();
        assert_eq!(result, 0);
        let snapshot = lb1.snapshot().unwrap();
        let sb = lb1.unload_module().unwrap();
        let mut lb2 = sb.load_module(runwasm_wasm).unwrap();
        let result: i32 = lb2.call_guest_function("CalcFib", 10i32).unwrap();
        assert_eq!(result, 55);
        let sb = lb2.unload_module().unwrap();
        let mut lb3 = sb.load_from_snapshot(snapshot).unwrap();
        let result: i32 = lb3
            .call_guest_function("HelloWorld", "Message from Rust Test".to_string())
            .unwrap();
        assert_eq!(result, 0);
    }

    #[test]
    fn test_load_module_buffer() {
        let sandboxes = get_test_wasm_sandboxes().unwrap();
        for sbox_test in sandboxes {
            let name = sbox_test.name;
            println!("test_load_module: {name}");
            let wasm_sandbox = sbox_test.sbox;
            let wasm_module_buffer: Vec<u8> =
                std::fs::read(get_test_file_path("HelloWorld.aot").unwrap()).unwrap();
            let mut loaded_wasm_sandbox = wasm_sandbox
                .load_module_from_buffer(&wasm_module_buffer)
                .unwrap();
            let result: i32 = loaded_wasm_sandbox
                .call_guest_function("HelloWorld", "Message from Rust Test".to_string())
                .unwrap();
            println!("({name}) Result {:?}", result);
            assert_eq!(result, 0, "({name}) HelloWorld returned a non-zero status");
        }
    }

    /// Returns the canonical path of a prebuilt test guest, located at
    /// `<manifest dir>/../../x64/<debug|release>/<filename>`.
    ///
    /// # Errors
    /// Returns an error naming the path if it cannot be canonicalized (for
    /// example because it does not exist) or is not valid UTF-8.
    ///
    /// # Panics
    /// Panics if neither `CARGO_MANIFEST_DIR` nor `RUST_DIR_FOR_DEBUGGING_TESTS`
    /// is set.
    pub(super) fn get_test_file_path(filename: &str) -> Result<String> {
        #[cfg(debug_assertions)]
        let config = "debug";
        #[cfg(not(debug_assertions))]
        let config = "release";
        let proj_dir = env::var_os("CARGO_MANIFEST_DIR").unwrap_or_else(|| {
            env::var_os("RUST_DIR_FOR_DEBUGGING_TESTS")
                .expect("Failed to get CARGO_MANIFEST_DIR or RUST_DIR_FOR_DEBUGGING_TESTS env var")
        });
        let relative_path = "../../x64";
        let filename_path = Path::new(&proj_dir)
            .join(relative_path)
            .join(config)
            .join(filename);
        // Report the offending path instead of panicking with a bare io::Error.
        let full_path = filename_path.canonicalize().map_err(|e| {
            new_error!(
                "Failed to resolve test file path {}: {}",
                filename_path.display(),
                e
            )
        })?;
        full_path
            .to_str()
            .map(str::to_string)
            .ok_or_else(|| {
                new_error!(
                    "Test file path {} is not valid UTF-8",
                    full_path.display()
                )
            })
    }

    /// A runtime-loaded sandbox paired with a human-readable label for logs.
    struct SandboxTest {
        sbox: WasmSandbox,
        name: String,
    }

    /// Builds the runtime-loaded sandboxes to run the load tests against.
    ///
    /// Returns an empty list when no hypervisor is present.
    fn get_test_wasm_sandboxes() -> Result<Vec<SandboxTest>> {
        let builder = SandboxBuilder::new()
            .with_guest_input_buffer_size(0x8000)
            .with_guest_output_buffer_size(0x8000)
            .with_guest_scratch_size(0x2000)
            .with_guest_heap_size(0x100000);
        let mut sandboxes: Vec<SandboxTest> = Vec::new();
        if is_hypervisor_present() {
            sandboxes.push(SandboxTest {
                sbox: builder.clone().build()?.load_runtime()?,
                name: "regular in-hypervisor".to_string(),
            });
        }
        Ok(sandboxes)
    }
}