use alloc::format;
use alloc::vec::Vec;
use flatbuffers::FlatBufferBuilder;
use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType};
use hyperlight_common::flatbuffer_wrappers::function_types::{FunctionCallResult, ParameterType};
use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError};
use hyperlight_guest::error::{HyperlightGuestError, Result};
use tracing::instrument;
use crate::{GUEST_HANDLE, REGISTERED_GUEST_FUNCTIONS};
/// Routes a [`FunctionCall`] to the guest function it names and returns that function's result.
///
/// Calls whose type is not [`FunctionCallType::Guest`] are rejected up front. Otherwise the
/// function name is looked up in `REGISTERED_GUEST_FUNCTIONS`:
///
/// * a registered function has the call's parameter types checked against its definition via
///   `verify_parameters` and is then invoked through its stored function pointer;
/// * an unregistered name is forwarded to the externally defined `guest_dispatch_function`.
///
/// # Errors
///
/// Returns an [`ErrorCode::GuestError`] error when the call type is not `Guest`, propagates any
/// error reported by `verify_parameters`, and otherwise returns whatever the invoked function
/// returns.
#[instrument(skip_all, level = "Info")]
pub(crate) fn call_guest_function(function_call: FunctionCall) -> Result<Vec<u8>> {
    // Only calls addressed to the guest may be dispatched from here.
    let call_type = function_call.function_call_type();
    if call_type != FunctionCallType::Guest {
        return Err(HyperlightGuestError::new(
            ErrorCode::GuestError,
            format!(
                "Invalid function call type: {:#?}, should be Guest.",
                call_type
            ),
        ));
    }

    // `&raw const` takes the address of the static without first forming a reference to it;
    // dereferencing that pointer straight away is what trips `clippy::deref_addrof`, hence
    // the allow.
    // SAFETY: NOTE(review): presumably sound because the registry is not mutated while a
    // dispatch is in flight — confirm against the registration code.
    #[allow(clippy::deref_addrof)]
    let lookup =
        unsafe { (*(&raw const REGISTERED_GUEST_FUNCTIONS)).get(&function_call.function_name) };

    match lookup {
        Some(definition) => {
            // Gather the type of every supplied parameter so it can be checked against the
            // registered definition before the function is called.
            let supplied_types: Vec<ParameterType> = function_call
                .parameters
                .iter()
                .flatten()
                .map(|param| param.into())
                .collect();
            definition.verify_parameters(&supplied_types)?;

            (definition.function_pointer)(function_call)
        }
        None => {
            // Nothing is registered under this name: hand the call to the fallback
            // dispatcher, which is declared here and defined outside this function.
            unsafe extern "Rust" {
                fn guest_dispatch_function(function_call: FunctionCall) -> Result<Vec<u8>>;
            }

            // SAFETY: NOTE(review): relies on the final binary providing a
            // `guest_dispatch_function` symbol with exactly this signature — confirm.
            unsafe { guest_dispatch_function(function_call) }
        }
    }
}
/// Services one function call: pops it from the shared input data, runs it through
/// [`call_guest_function`], and pushes the outcome to the shared output data.
///
/// On success the bytes returned by the guest function are pushed unchanged. On failure the
/// error's kind and message are wrapped in a [`GuestError`], placed in a
/// [`FunctionCallResult`], encoded with a [`FlatBufferBuilder`], and pushed instead.
///
/// When the `trace_guest` feature is enabled on `x86_64`, the function additionally reads the
/// TSC and passes it to `hyperlight_guest_tracing::new_call`, runs inside an INFO-level
/// `internal_dispatch_function` span, and exits that span and flushes the tracing state
/// before returning.
///
/// # Panics
///
/// Panics if the function call cannot be popped from the shared input data, or if the result
/// cannot be pushed to the shared output data.
pub(crate) fn internal_dispatch_function() {
    #[cfg(all(feature = "trace_guest", target_arch = "x86_64"))]
    let _dispatch_span = {
        let start_tsc = hyperlight_guest_tracing::invariant_tsc::read_tsc();
        hyperlight_guest_tracing::new_call(start_tsc);
        tracing::span!(tracing::Level::INFO, "internal_dispatch_function").entered()
    };

    // SAFETY: NOTE(review): presumably `GUEST_HANDLE` has been initialised before any
    // dispatch can happen and is not written concurrently — confirm.
    let guest = unsafe { GUEST_HANDLE };

    let request = guest
        .try_pop_shared_input_data_into::<FunctionCall>()
        .expect("Function call deserialization failed");

    match call_guest_function(request) {
        Err(failure) => {
            // Report the failure through the same output channel, as an encoded
            // `FunctionCallResult` carrying a `GuestError`.
            let error_result =
                FunctionCallResult::new(Err(GuestError::new(failure.kind, failure.message)));
            let mut fbb = FlatBufferBuilder::new();
            let encoded = error_result.encode(&mut fbb);
            guest
                .push_shared_output_data(encoded)
                .expect("Failed to serialize function call result");
        }
        Ok(payload) => {
            // The returned bytes are forwarded as they are.
            guest
                .push_shared_output_data(payload.as_slice())
                .expect("Failed to serialize function call result");
        }
    }

    #[cfg(all(feature = "trace_guest", target_arch = "x86_64"))]
    {
        // Close the span explicitly so that it has ended before the flush.
        _dispatch_span.exit();
        hyperlight_guest_tracing::flush();
    }
}