use hyperlight_common::flatbuffer_wrappers::function_types::{
ParameterValue, ReturnType, ReturnValue,
};
use tracing::{instrument, Span};
use super::guest_dispatch::call_function_on_guest;
use crate::{MultiUseSandbox, Result, SingleUseSandbox};
/// A guest-call context that owns a `MultiUseSandbox`.
///
/// The context is created with `start`, guest functions are invoked through
/// `call` (which only borrows the context, so it can be used repeatedly), and
/// the sandbox is handed back to the caller by `finish`.
#[derive(Debug)]
pub struct MultiUseGuestCallContext {
/// The sandbox every guest call made through this context is dispatched to.
sbox: MultiUseSandbox,
}
impl MultiUseGuestCallContext {
#[instrument(skip_all, parent = Span::current())]
pub fn start(sbox: MultiUseSandbox) -> Self {
Self { sbox }
}
#[instrument(err(Debug),skip(self, args),parent = Span::current())]
pub fn call(
&mut self,
func_name: &str,
func_ret_type: ReturnType,
args: Option<Vec<ParameterValue>>,
) -> Result<ReturnValue> {
call_function_on_guest(&mut self.sbox, func_name, func_ret_type, args)
}
#[instrument(err(Debug), skip(self), parent = Span::current())]
pub fn finish(mut self) -> Result<MultiUseSandbox> {
self.sbox.restore_state()?;
Ok(self.sbox)
}
pub(crate) fn finish_no_reset(self) -> MultiUseSandbox {
self.sbox
}
}
/// A guest-call context that owns a `SingleUseSandbox`.
///
/// The context is consumed either by `call`, which makes exactly one guest
/// call, or by `call_from_func`, which lends a
/// `SingleUseMultiGuestCallContext` to a caller-supplied closure so that
/// several guest calls can be made before the sandbox is dropped.
#[derive(Debug)]
pub struct SingleUseGuestCallContext {
/// The sandbox every guest call made through this context is dispatched to.
sbox: SingleUseSandbox,
}
impl SingleUseGuestCallContext {
    /// Take ownership of `sbox` and wrap it in a new call context.
    #[instrument(skip_all, parent = Span::current())]
    pub(crate) fn start(sbox: SingleUseSandbox) -> Self {
        Self { sbox }
    }

    /// Invoke the guest function named `func_name` and consume the context.
    ///
    /// `func_ret_type` is the return type expected from the guest and `args`
    /// are the optional parameters passed along. Because `self` is taken by
    /// value, the context (and the sandbox inside it) is dropped when this
    /// method returns, so only one call can be made this way.
    #[instrument(err(Debug), skip(self, args), parent = Span::current())]
    pub(crate) fn call(
        mut self,
        func_name: &str,
        func_ret_type: ReturnType,
        args: Option<Vec<ParameterValue>>,
    ) -> Result<ReturnValue> {
        self.call_internal(func_name, func_ret_type, args)
    }

    /// Borrowing variant of `call`, shared by `call` and by
    /// `SingleUseMultiGuestCallContext::call`.
    ///
    /// Forwards the arguments unchanged to `call_function_on_guest` with the
    /// wrapped sandbox and returns its result as-is.
    #[instrument(skip_all, parent = Span::current())]
    fn call_internal(
        &mut self,
        func_name: &str,
        func_ret_type: ReturnType,
        args: Option<Vec<ParameterValue>>,
    ) -> Result<ReturnValue> {
        call_function_on_guest(&mut self.sbox, func_name, func_ret_type, args)
    }

    /// Make several guest calls against the single-use sandbox from inside
    /// the closure `f`.
    ///
    /// The context is moved into a `SingleUseMultiGuestCallContext`, which is
    /// lent mutably to `f`. Whatever `f` returns is returned to the caller,
    /// and the sandbox is dropped once this method returns.
    // The type parameter is deliberately named `F`: calling it `Fn` would
    // shadow the prelude's `Fn` trait within this signature.
    pub fn call_from_func<F>(self, f: F) -> Result<ReturnValue>
    where
        F: FnOnce(&mut SingleUseMultiGuestCallContext) -> Result<ReturnValue>,
    {
        let mut ctx = SingleUseMultiGuestCallContext::new(self);
        f(&mut ctx)
    }
}
/// A wrapper around a `SingleUseGuestCallContext` that allows several guest
/// calls to be made against the same single-use sandbox.
///
/// Values of this type are only constructed inside
/// `SingleUseGuestCallContext::call_from_func`, which lends one to the
/// caller-supplied closure.
// `Debug` is derived so this public type is inspectable in the same way as
// the other call-context types in this module.
#[derive(Debug)]
pub struct SingleUseMultiGuestCallContext {
    /// The wrapped single-use context that calls are forwarded to.
    call_context: SingleUseGuestCallContext,
}
impl SingleUseMultiGuestCallContext {
fn new(call_context: SingleUseGuestCallContext) -> Self {
Self { call_context }
}
pub fn call(
&mut self,
func_name: &str,
func_ret_type: ReturnType,
args: Option<Vec<ParameterValue>>,
) -> Result<ReturnValue> {
self.call_context
.call_internal(func_name, func_ret_type, args)
}
}
#[cfg(test)]
mod tests {
    use std::sync::mpsc::sync_channel;
    use std::thread::{self, JoinHandle};

    use hyperlight_common::flatbuffer_wrappers::function_types::{
        ParameterValue, ReturnType, ReturnValue,
    };
    use hyperlight_testing::simple_guest_as_string;

    use crate::func::call_ctx::SingleUseMultiGuestCallContext;
    use crate::sandbox_state::sandbox::EvolvableSandbox;
    use crate::sandbox_state::transition::Noop;
    use crate::{
        GuestBinary, HyperlightError, MultiUseSandbox, Result, SingleUseSandbox,
        UninitializedSandbox,
    };

    /// One guest call to make: function name, return type, arguments and the
    /// value the guest is expected to return.
    type GuestCallCase = (
        &'static str,
        ReturnType,
        Option<Vec<ParameterValue>>,
        ReturnValue,
    );

    /// A single guest call description that is sent across threads in
    /// `test_multi_call_multi_thread`.
    struct TestFuncCall {
        func_name: String,
        ret_type: ReturnType,
        params: Option<Vec<ParameterValue>>,
        expected_ret: ReturnValue,
    }

    /// Create an `UninitializedSandbox` from the simple guest binary path
    /// reported by `simple_guest_as_string`.
    fn new_uninit() -> Result<UninitializedSandbox> {
        let located = simple_guest_as_string();
        let guest_path = located.map_err(|e| {
            HyperlightError::Error(format!("failed to get simple guest path ({e:?})"))
        })?;
        let binary = GuestBinary::FilePath(guest_path);
        UninitializedSandbox::new(binary, None, None, None)
    }

    /// The two guest calls shared by the single-use sandbox tests.
    fn sample_guest_calls() -> [GuestCallCase; 2] {
        [
            (
                "StackAllocate",
                ReturnType::Int,
                Some(vec![ParameterValue::Int(1)]),
                ReturnValue::Int(1),
            ),
            (
                "CallMalloc",
                ReturnType::Int,
                Some(vec![ParameterValue::Int(200)]),
                ReturnValue::Int(200),
            ),
        ]
    }

    /// Each call is made on its own freshly created `SingleUseSandbox`.
    #[test]
    fn singleusesandbox_single_call() {
        let calls = sample_guest_calls();
        for (name, ret_type, params, expected) in &calls {
            let uninit = new_uninit().unwrap();
            let sbox: SingleUseSandbox = uninit.evolve(Noop::default()).unwrap();
            let res = sbox
                .new_call_context()
                .call(*name, *ret_type, params.clone())
                .unwrap();
            assert_eq!(*expected, res);
        }
    }

    /// All calls are made on one `SingleUseSandbox` via `call_from_func`.
    #[test]
    fn singleusesandbox_multi_call() {
        let calls = sample_guest_calls();
        let sbox: SingleUseSandbox = new_uninit().unwrap().evolve(Noop::default()).unwrap();
        let ctx = sbox.new_call_context();
        let res = ctx
            .call_from_func(|multi_ctx: &mut SingleUseMultiGuestCallContext| {
                let mut last: ReturnValue = ReturnValue::Int(0);
                for (name, ret_type, params, expected) in &calls {
                    last = multi_ctx
                        .call(*name, *ret_type, params.clone())
                        .expect("failed to call guest function");
                    assert_eq!(*expected, last);
                }
                Ok(last)
            })
            .unwrap();
        // The closure returns the result of the final call.
        assert_eq!(calls.last().unwrap().3, res);
    }

    /// Ten sender threads each submit a batch of calls over a rendezvous
    /// channel; one receiver thread owns the sandbox and runs every batch in
    /// its own call context.
    #[test]
    fn test_multi_call_multi_thread() {
        let (snd, recv) = sync_channel::<Vec<TestFuncCall>>(0);

        // Receiver: runs until every sender (and `snd` itself) is dropped.
        let recv_hdl = thread::spawn(move || {
            let mut sbox: MultiUseSandbox = new_uninit().unwrap().evolve(Noop::default()).unwrap();
            while let Ok(batch) = recv.recv() {
                let mut ctx = sbox.new_call_context();
                for TestFuncCall {
                    func_name,
                    ret_type,
                    params,
                    expected_ret,
                } in batch
                {
                    let res = ctx.call(&func_name, ret_type, params).unwrap();
                    assert_eq!(expected_ret, res);
                }
                sbox = ctx.finish().unwrap();
            }
        });

        let mut send_handles: Vec<JoinHandle<()>> = Vec::new();
        for i in 0..10 {
            let sender = snd.clone();
            let handle = thread::spawn(move || {
                let stack_call = TestFuncCall {
                    func_name: "StackAllocate".to_string(),
                    ret_type: ReturnType::Int,
                    params: Some(vec![ParameterValue::Int(i + 1)]),
                    expected_ret: ReturnValue::Int(i + 1),
                };
                let malloc_call = TestFuncCall {
                    func_name: "CallMalloc".to_string(),
                    ret_type: ReturnType::Int,
                    params: Some(vec![ParameterValue::Int(i + 2)]),
                    expected_ret: ReturnValue::Int(i + 2),
                };
                let calls: Vec<TestFuncCall> = vec![stack_call, malloc_call];
                sender.send(calls).unwrap();
            });
            send_handles.push(handle);
        }

        send_handles
            .into_iter()
            .for_each(|hdl| hdl.join().unwrap());

        // Dropping the last sender closes the channel so the receiver's loop
        // ends and it can be joined.
        drop(snd);
        recv_hdl.join().unwrap();
    }

    /// Test helper that owns a `MultiUseSandbox` built from the simple guest.
    pub struct TestSandbox {
        sandbox: MultiUseSandbox,
    }

    impl TestSandbox {
        /// Create the helper with a freshly evolved `MultiUseSandbox`.
        pub fn new() -> Self {
            let uninit = new_uninit().unwrap();
            let sandbox: MultiUseSandbox = uninit.evolve(Noop::default()).unwrap();
            Self { sandbox }
        }

        /// Call `AddToStatic` with `0..i` inside one call context and assert
        /// that each result equals the running sum of the arguments so far.
        pub fn call_add_to_static_multiple_times(self, i: i32) -> Result<TestSandbox> {
            let mut ctx = self.sandbox.new_call_context();
            let mut running_total: i32 = 0;
            for n in 0..i {
                let outcome = ctx.call(
                    "AddToStatic",
                    ReturnType::Int,
                    Some(vec![ParameterValue::Int(n)]),
                );
                running_total += n;
                println!("{:?}", outcome);
                assert_eq!(outcome.unwrap(), ReturnValue::Int(running_total));
            }
            let finished = ctx.finish();
            assert!(finished.is_ok());
            Ok(TestSandbox {
                sandbox: finished.unwrap(),
            })
        }

        /// Call `AddToStatic` with `0..i` through
        /// `call_guest_function_by_name` and assert that each result equals
        /// just the argument passed to that call.
        pub fn call_add_to_static(mut self, i: i32) -> Result<()> {
            for n in 0..i {
                let outcome = self.sandbox.call_guest_function_by_name(
                    "AddToStatic",
                    ReturnType::Int,
                    Some(vec![ParameterValue::Int(n)]),
                );
                println!("{:?}", outcome);
                assert_eq!(outcome.unwrap(), ReturnValue::Int(n));
            }
            Ok(())
        }
    }

    #[test]
    fn ensure_multiusesandbox_multi_calls_dont_reset_state() {
        let result = TestSandbox::new().call_add_to_static_multiple_times(5);
        assert!(result.is_ok());
    }

    #[test]
    fn ensure_multiusesandbox_single_calls_do_reset_state() {
        let result = TestSandbox::new().call_add_to_static(5);
        assert!(result.is_ok());
    }
}