// hyperlight_guest_bin/guest_function/definition.rs

use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;

use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall;
use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterType, ReturnType};
use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode;
use hyperlight_common::flatbuffer_wrappers::util::get_flatbuffer_result;
use hyperlight_common::for_each_tuple;
use hyperlight_common::func::{
    Function, ParameterTuple, ResultType, ReturnValue, SupportedReturnType,
};
use hyperlight_guest::error::{HyperlightGuestError, Result};

31pub type GuestFunc = fn(FunctionCall) -> Result<Vec<u8>>;
33
34#[derive(Debug, Clone)]
39pub struct GuestFunctionDefinition<F: Copy> {
40 pub function_name: String,
42 pub parameter_types: Vec<ParameterType>,
44 pub return_type: ReturnType,
46 pub function_pointer: F,
48}
49
50#[doc(hidden)]
52pub trait IntoGuestFunction<Output, Args>
53where
54 Self: Function<Output, Args, HyperlightGuestError>,
55 Self: Copy + 'static,
56 Output: SupportedReturnType,
57 Args: ParameterTuple,
58{
59 #[doc(hidden)]
60 const ASSERT_ZERO_SIZED: ();
61
62 fn into_guest_function(self) -> fn(FunctionCall) -> Result<Vec<u8>>;
64}
65
66pub trait AsGuestFunctionDefinition<Output, Args>
68where
69 Self: Function<Output, Args, HyperlightGuestError>,
70 Self: IntoGuestFunction<Output, Args>,
71 Output: SupportedReturnType,
72 Args: ParameterTuple,
73{
74 fn as_guest_function_definition(
76 &self,
77 name: impl Into<String>,
78 ) -> GuestFunctionDefinition<GuestFunc>;
79}
80
81fn into_flatbuffer_result(value: ReturnValue) -> Vec<u8> {
82 match value {
83 ReturnValue::Void(()) => get_flatbuffer_result(()),
84 ReturnValue::Int(i) => get_flatbuffer_result(i),
85 ReturnValue::UInt(u) => get_flatbuffer_result(u),
86 ReturnValue::Long(l) => get_flatbuffer_result(l),
87 ReturnValue::ULong(ul) => get_flatbuffer_result(ul),
88 ReturnValue::Float(f) => get_flatbuffer_result(f),
89 ReturnValue::Double(d) => get_flatbuffer_result(d),
90 ReturnValue::Bool(b) => get_flatbuffer_result(b),
91 ReturnValue::String(s) => get_flatbuffer_result(s.as_str()),
92 ReturnValue::VecBytes(v) => get_flatbuffer_result(v.as_slice()),
93 }
94}
95
/// Implements [`IntoGuestFunction`] for every `Copy + 'static` callable of one
/// parameter arity. `for_each_tuple!` invokes it once per arity; the `$N` and
/// `$p` metavariables match that invocation shape but are not used here.
macro_rules! impl_host_function {
    ([$N:expr] ($($p:ident: $P:ident),*)) => {
        impl<F, R, $($P),*> IntoGuestFunction<R::ReturnType, ($($P,)*)> for F
        where
            F: Fn($($P),*) -> R,
            F: Function<R::ReturnType, ($($P,)*), HyperlightGuestError>,
            // `Copy` implies `F` has no `Drop` impl, so fabricating an `F` with
            // `mem::zeroed` below cannot skip or duplicate a destructor.
            F: Copy + 'static,
            ($($P,)*): ParameterTuple,
            R: ResultType<HyperlightGuestError>,
        {
            /// Fails compilation when `F` is not zero-sized (e.g. a `fn` pointer or a
            /// closure that captures data). Associated consts are only evaluated
            /// where they are referenced, so `into_guest_function` must mention it.
            #[doc(hidden)]
            const ASSERT_ZERO_SIZED: () = const {
                assert!(core::mem::size_of::<Self>() == 0)
            };

            fn into_guest_function(self) -> fn(FunctionCall) -> Result<Vec<u8>> {
                // Force the zero-size assertion to be evaluated for this `F`.
                // Without this reference it never runs, and a sized `F` would reach
                // `mem::zeroed` below (a zeroed `fn` pointer is undefined behavior).
                let () = Self::ASSERT_ZERO_SIZED;

                |fc: FunctionCall| {
                    // SAFETY: `F` is zero-sized (enforced at compile time above), so
                    // `zeroed` produces a value with no bytes, and `self` having been
                    // passed in proves `F` is inhabited. For the function items and
                    // data-less closures this targets, that value is interchangeable
                    // with `self`. `F: Copy` means there is no destructor involved.
                    let this = unsafe { core::mem::zeroed::<F>() };
                    // `parameters` may be absent; fall back to its default value.
                    let params = fc.parameters.unwrap_or_default();
                    let params = <($($P,)*) as ParameterTuple>::from_value(params)?;
                    let result = Function::<R::ReturnType, ($($P,)*), HyperlightGuestError>::call(&this, params)?;
                    Ok(into_flatbuffer_result(result.into_value()))
                }
            }
        }
    };
}

153impl<F, Args, Output> AsGuestFunctionDefinition<Output, Args> for F
154where
155 F: IntoGuestFunction<Output, Args>,
156 Args: ParameterTuple,
157 Output: SupportedReturnType,
158{
159 fn as_guest_function_definition(
160 &self,
161 name: impl Into<String>,
162 ) -> GuestFunctionDefinition<GuestFunc> {
163 let parameter_types = Args::TYPE.to_vec();
164 let return_type = Output::TYPE;
165 let function_pointer = self.into_guest_function();
166
167 GuestFunctionDefinition {
168 function_name: name.into(),
169 parameter_types,
170 return_type,
171 function_pointer,
172 }
173 }
174}
175
176for_each_tuple!(impl_host_function);
177
178impl<F: Copy> GuestFunctionDefinition<F> {
179 pub fn new(
181 function_name: String,
182 parameter_types: Vec<ParameterType>,
183 return_type: ReturnType,
184 function_pointer: F,
185 ) -> Self {
186 Self {
187 function_name,
188 parameter_types,
189 return_type,
190 function_pointer,
191 }
192 }
193
194 pub fn from_fn<Output, Args>(
197 function_name: String,
198 function: impl AsGuestFunctionDefinition<Output, Args>,
199 ) -> GuestFunctionDefinition<GuestFunc>
200 where
201 Args: ParameterTuple,
202 Output: SupportedReturnType,
203 {
204 function.as_guest_function_definition(function_name)
205 }
206
207 pub fn verify_parameters(&self, parameter_types: &[ParameterType]) -> Result<()> {
209 const MAX_PARAMETERS: usize = 11;
211 if parameter_types.len() > MAX_PARAMETERS {
212 return Err(HyperlightGuestError::new(
213 ErrorCode::GuestError,
214 format!(
215 "Function {} has too many parameters: {} (max allowed is {}).",
216 self.function_name,
217 parameter_types.len(),
218 MAX_PARAMETERS
219 ),
220 ));
221 }
222
223 if self.parameter_types.len() != parameter_types.len() {
224 return Err(HyperlightGuestError::new(
225 ErrorCode::GuestFunctionIncorrecNoOfParameters,
226 format!(
227 "Called function {} with {} parameters but it takes {}.",
228 self.function_name,
229 parameter_types.len(),
230 self.parameter_types.len()
231 ),
232 ));
233 }
234
235 for (i, parameter_type) in self.parameter_types.iter().enumerate() {
236 if parameter_type != ¶meter_types[i] {
237 return Err(HyperlightGuestError::new(
238 ErrorCode::GuestFunctionParameterTypeMismatch,
239 format!(
240 "Expected parameter type {:?} for parameter index {} of function {} but got {:?}.",
241 parameter_type, i, self.function_name, parameter_types[i]
242 ),
243 ));
244 }
245 }
246
247 Ok(())
248 }
249}