use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall;
use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterType, ReturnType};
use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode;
use hyperlight_common::flatbuffer_wrappers::util::get_flatbuffer_result;
use hyperlight_common::for_each_tuple;
use hyperlight_common::func::{
Function, ParameterTuple, ResultType, ReturnValue, SupportedReturnType,
};
use hyperlight_guest::error::{HyperlightGuestError, Result};
/// Type-erased signature of a guest function: it receives a `FunctionCall`
/// and returns either a byte buffer or a guest error.
pub type GuestFunc = fn(FunctionCall) -> Result<Vec<u8>>;
/// Describes a guest function: its name, its declared parameter and return
/// types, and the `Copy` value of type `F` used to invoke it.
#[derive(Debug, Clone)]
pub struct GuestFunctionDefinition<F: Copy> {
/// Name of the function.
pub function_name: String,
/// Declared parameter types, in positional order.
pub parameter_types: Vec<ParameterType>,
/// Declared return type.
pub return_type: ReturnType,
/// The stored callable; this is a `GuestFunc` when the definition is built
/// through `as_guest_function_definition`.
pub function_pointer: F,
}
/// Conversion from a typed, `Copy + 'static` function into a plain
/// `fn(FunctionCall) -> Result<Vec<u8>>` pointer.
#[doc(hidden)]
pub trait IntoGuestFunction<Output, Args>:
    Function<Output, Args, HyperlightGuestError> + Copy + 'static
where
    Output: SupportedReturnType,
    Args: ParameterTuple,
{
    /// Compile-time check slot; implementations in this file use it to assert
    /// that `Self` is zero-sized.
    #[doc(hidden)]
    const ASSERT_ZERO_SIZED: ();

    /// Consumes `self` and returns the type-erased function pointer.
    fn into_guest_function(self) -> fn(FunctionCall) -> Result<Vec<u8>>;
}
/// Functions that can describe themselves as a
/// `GuestFunctionDefinition<GuestFunc>` under a caller-supplied name.
pub trait AsGuestFunctionDefinition<Output, Args>:
    Function<Output, Args, HyperlightGuestError> + IntoGuestFunction<Output, Args>
where
    Output: SupportedReturnType,
    Args: ParameterTuple,
{
    /// Builds the definition for this function, registering it as `name`.
    fn as_guest_function_definition(
        &self,
        name: impl Into<String>,
    ) -> GuestFunctionDefinition<GuestFunc>;
}
/// Converts a `ReturnValue` into bytes by passing the payload of whichever
/// variant is present to `get_flatbuffer_result`.
///
/// Owned `String` and byte-vector payloads are handed over as borrowed
/// `&str` / `&[u8]` views.
fn into_flatbuffer_result(value: ReturnValue) -> Vec<u8> {
    match value {
        ReturnValue::Void(()) => get_flatbuffer_result(()),
        ReturnValue::Int(int) => get_flatbuffer_result(int),
        ReturnValue::UInt(uint) => get_flatbuffer_result(uint),
        ReturnValue::Long(long) => get_flatbuffer_result(long),
        ReturnValue::ULong(ulong) => get_flatbuffer_result(ulong),
        ReturnValue::Float(float) => get_flatbuffer_result(float),
        ReturnValue::Double(double) => get_flatbuffer_result(double),
        ReturnValue::Bool(flag) => get_flatbuffer_result(flag),
        ReturnValue::String(text) => get_flatbuffer_result(text.as_str()),
        ReturnValue::VecBytes(bytes) => get_flatbuffer_result(bytes.as_slice()),
    }
}
// Implements `IntoGuestFunction` for every `Copy + 'static` callable `F` whose
// parameters are the identifiers supplied to the macro (`$P`...).
//
// The generated `into_guest_function` returns a non-capturing closure (so it
// can coerce to a plain `fn` pointer). Because the closure cannot capture
// `self`, it re-creates a value of `F` on every call; that is only sound when
// `F` is zero-sized, which `ASSERT_ZERO_SIZED` checks at compile time.
macro_rules! impl_host_function {
    ([$N:expr] ($($p:ident: $P:ident),*)) => {
        impl<F, R, $($P),*> IntoGuestFunction<R::ReturnType, ($($P,)*)> for F
        where
            F: Fn($($P),*) -> R,
            F: Function<R::ReturnType, ($($P,)*), HyperlightGuestError>,
            F: Copy + 'static,
            ($($P,)*): ParameterTuple,
            R: ResultType<HyperlightGuestError>,
        {
            #[doc(hidden)]
            const ASSERT_ZERO_SIZED: () = const {
                assert!(core::mem::size_of::<Self>() == 0)
            };

            fn into_guest_function(self) -> fn(FunctionCall) -> Result<Vec<u8>> {
                // An associated const of a generic impl is only evaluated when
                // it is referenced. Name it here so that the zero-size
                // assertion actually runs for every concrete `F`; without this
                // a capturing closure or a `fn` pointer would reach the
                // `zeroed` call below. The fully-qualified path pins the exact
                // impl, since `F` may implement the trait for other arities.
                let () = <F as IntoGuestFunction<R::ReturnType, ($($P,)*)>>::ASSERT_ZERO_SIZED;

                |fc: FunctionCall| {
                    // SAFETY:
                    //  1. `F` is zero-sized (enforced by `ASSERT_ZERO_SIZED`,
                    //     evaluated above), so there are no bytes that could
                    //     hold an invalid bit pattern.
                    //  2. `F: Copy`, so it has no `Drop` impl and conjuring an
                    //     extra instance cannot cause a double drop.
                    let this = unsafe { core::mem::zeroed::<F>() };
                    // A call without a parameter list is treated as an empty one.
                    let params = fc.parameters.unwrap_or_default();
                    let params = <($($P,)*) as ParameterTuple>::from_value(params)?;
                    let result = Function::<R::ReturnType, ($($P,)*), HyperlightGuestError>::call(&this, params)?;
                    Ok(into_flatbuffer_result(result.into_value()))
                }
            }
        }
    };
}
/// Blanket implementation: anything convertible through `IntoGuestFunction`
/// can also produce its own `GuestFunctionDefinition`.
impl<F, Args, Output> AsGuestFunctionDefinition<Output, Args> for F
where
    F: IntoGuestFunction<Output, Args>,
    Args: ParameterTuple,
    Output: SupportedReturnType,
{
    /// Builds a definition whose parameter and return types come from
    /// `Args::TYPE` and `Output::TYPE`, and whose function pointer is the
    /// type-erased form of `self`.
    fn as_guest_function_definition(
        &self,
        name: impl Into<String>,
    ) -> GuestFunctionDefinition<GuestFunc> {
        // Struct-literal fields are evaluated in the order written; the name
        // conversion stays last, after the function pointer has been produced.
        GuestFunctionDefinition {
            parameter_types: Args::TYPE.to_vec(),
            return_type: Output::TYPE,
            function_pointer: self.into_guest_function(),
            function_name: name.into(),
        }
    }
}
// Expand `impl_host_function!` via `for_each_tuple!` (imported from
// `hyperlight_common`). NOTE(review): presumably this invokes the macro once
// per supported parameter-tuple arity — confirm against `for_each_tuple!`.
for_each_tuple!(impl_host_function);
impl<F: Copy> GuestFunctionDefinition<F> {
    /// Creates a definition from its four components, stored as given.
    pub fn new(
        function_name: String,
        parameter_types: Vec<ParameterType>,
        return_type: ReturnType,
        function_pointer: F,
    ) -> Self {
        Self {
            function_name,
            parameter_types,
            return_type,
            function_pointer,
        }
    }

    /// Creates a `GuestFunctionDefinition<GuestFunc>` named `function_name`
    /// from a typed function by delegating to
    /// `AsGuestFunctionDefinition::as_guest_function_definition`.
    pub fn from_fn<Output, Args>(
        function_name: String,
        function: impl AsGuestFunctionDefinition<Output, Args>,
    ) -> GuestFunctionDefinition<GuestFunc>
    where
        Args: ParameterTuple,
        Output: SupportedReturnType,
    {
        function.as_guest_function_definition(function_name)
    }

    /// Verifies that `parameter_types` matches this definition's signature.
    ///
    /// # Errors
    ///
    /// * `ErrorCode::GuestError` if more than 11 parameter types are supplied.
    /// * `ErrorCode::GuestFunctionIncorrecNoOfParameters` if the number of
    ///   supplied types differs from the number this function declares.
    /// * `ErrorCode::GuestFunctionParameterTypeMismatch` for the first
    ///   position at which the supplied type differs from the declared one.
    pub fn verify_parameters(&self, parameter_types: &[ParameterType]) -> Result<()> {
        // Upper bound on the number of parameters accepted for a call.
        const MAX_PARAMETERS: usize = 11;

        if parameter_types.len() > MAX_PARAMETERS {
            return Err(HyperlightGuestError::new(
                ErrorCode::GuestError,
                format!(
                    "Function {} has too many parameters: {} (max allowed is {}).",
                    self.function_name,
                    parameter_types.len(),
                    MAX_PARAMETERS
                ),
            ));
        }

        if self.parameter_types.len() != parameter_types.len() {
            return Err(HyperlightGuestError::new(
                ErrorCode::GuestFunctionIncorrecNoOfParameters,
                format!(
                    "Called function {} with {} parameters but it takes {}.",
                    self.function_name,
                    parameter_types.len(),
                    self.parameter_types.len()
                ),
            ));
        }

        // Lengths are equal at this point, so `zip` visits every position;
        // pairing the two slices avoids indexing and its panic path.
        for (i, (expected, actual)) in self
            .parameter_types
            .iter()
            .zip(parameter_types)
            .enumerate()
        {
            if expected != actual {
                return Err(HyperlightGuestError::new(
                    ErrorCode::GuestFunctionParameterTypeMismatch,
                    format!(
                        "Expected parameter type {:?} for parameter index {} of function {} but got {:?}.",
                        expected, i, self.function_name, actual
                    ),
                ));
            }
        }

        Ok(())
    }
}