1use std::sync::{Arc, Mutex};
18
19use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnValue};
20
21use super::utils::for_each_tuple;
22use super::{ParameterTuple, ResultType, SupportedReturnType};
23use crate::sandbox::{ExtraAllowedSyscall, UninitializedSandbox};
24use crate::{Result, log_then_return, new_error};
25
26pub trait Registerable {
29 fn register_host_function<Args: ParameterTuple, Output: SupportedReturnType>(
31 &mut self,
32 name: &str,
33 hf: impl Into<HostFunction<Output, Args>>,
34 ) -> Result<()>;
35 #[cfg(all(feature = "seccomp", target_os = "linux"))]
38 fn register_host_function_with_syscalls<Args: ParameterTuple, Output: SupportedReturnType>(
39 &mut self,
40 name: &str,
41 hf: impl Into<HostFunction<Output, Args>>,
42 eas: Vec<ExtraAllowedSyscall>,
43 ) -> Result<()>;
44}
45impl Registerable for UninitializedSandbox {
46 fn register_host_function<Args: ParameterTuple, Output: SupportedReturnType>(
47 &mut self,
48 name: &str,
49 hf: impl Into<HostFunction<Output, Args>>,
50 ) -> Result<()> {
51 let mut hfs = self
52 .host_funcs
53 .try_lock()
54 .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?;
55 (*hfs).register_host_function(name.to_string(), hf.into().into())
56 }
57 #[cfg(all(feature = "seccomp", target_os = "linux"))]
58 fn register_host_function_with_syscalls<Args: ParameterTuple, Output: SupportedReturnType>(
59 &mut self,
60 name: &str,
61 hf: impl Into<HostFunction<Output, Args>>,
62 eas: Vec<ExtraAllowedSyscall>,
63 ) -> Result<()> {
64 let mut hfs = self
65 .host_funcs
66 .try_lock()
67 .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?;
68 (*hfs).register_host_function_with_syscalls(name.to_string(), hf.into().into(), eas)
69 }
70}
71
72#[derive(Clone)]
75pub struct HostFunction<Output, Args>
76where
77 Args: ParameterTuple,
78 Output: SupportedReturnType,
79{
80 func: Arc<dyn Fn(Args) -> Result<Output> + Send + Sync + 'static>,
105}
106
107pub(crate) struct TypeErasedHostFunction {
108 func: Box<dyn Fn(Vec<ParameterValue>) -> Result<ReturnValue> + Send + Sync + 'static>,
109}
110
111impl<Args, Output> HostFunction<Output, Args>
112where
113 Args: ParameterTuple,
114 Output: SupportedReturnType,
115{
116 pub fn call(&self, args: Args) -> Result<Output> {
118 (self.func)(args)
119 }
120}
121
122impl TypeErasedHostFunction {
123 pub(crate) fn call(&self, args: Vec<ParameterValue>) -> Result<ReturnValue> {
124 (self.func)(args)
125 }
126}
127
128impl<Args, Output> From<HostFunction<Output, Args>> for TypeErasedHostFunction
129where
130 Args: ParameterTuple,
131 Output: SupportedReturnType,
132{
133 fn from(func: HostFunction<Output, Args>) -> TypeErasedHostFunction {
134 TypeErasedHostFunction {
135 func: Box::new(move |args: Vec<ParameterValue>| {
136 let args = Args::from_value(args)?;
137 Ok(func.call(args)?.into_value())
138 }),
139 }
140 }
141}
142
/// Generates `From<F> for HostFunction<R::ReturnType, (P1, .., Pn)>` for any
/// `FnMut(P1, .., Pn) -> R` of a given arity.
///
/// Invoked once per arity; each invocation supplies `[$N]` plus the parameter
/// name/type pairs.
macro_rules! impl_host_function {
    ([$N:expr] ($($p:ident: $P:ident),*)) => {
        // `FnMut` is the most permissive bound (every `Fn` is also `FnMut`), but the
        // stored callable must be `Fn + Sync`. The closure is therefore wrapped in a
        // `Mutex`, which provides the exclusive access `FnMut` needs.
        impl<F, R, $($P),*> From<F> for HostFunction<R::ReturnType, ($($P,)*)>
        where
            F: FnMut($($P),*) -> R + Send + 'static,
            ($($P,)*): ParameterTuple,
            R: ResultType,
        {
            fn from(mut func: F) -> HostFunction<R::ReturnType, ($($P,)*)> {
                // Adapt the N-ary function to take one tuple argument and normalize
                // its return value into a `Result`.
                let func = move |($($p,)*): ($($P,)*)| -> Result<R::ReturnType> {
                    func($($p),*).into_result()
                };
                let func = Mutex::new(func);
                HostFunction {
                    func: Arc::new(move |args: ($($P,)*)| {
                        // `lock` rather than `try_lock`: a `HostFunction` is `Clone`
                        // (clones share this mutex) and `Send + Sync`, so it may be
                        // invoked from several threads at once. Concurrent callers now
                        // wait their turn instead of failing spuriously with
                        // `WouldBlock`. The remaining error case is a poisoned mutex,
                        // i.e. an earlier call panicked. NOTE: a call that re-enters
                        // this same function would block rather than error.
                        func.lock()
                            .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?
                            (args)
                    })
                }
            }
        }
    };
}
174
175for_each_tuple!(impl_host_function);
176
177pub(crate) fn register_host_function<Args: ParameterTuple, Output: SupportedReturnType>(
178 func: impl Into<HostFunction<Output, Args>>,
179 sandbox: &mut UninitializedSandbox,
180 name: &str,
181 extra_allowed_syscalls: Option<Vec<ExtraAllowedSyscall>>,
182) -> Result<()> {
183 let func = func.into().into();
184
185 if let Some(_eas) = extra_allowed_syscalls {
186 if cfg!(all(feature = "seccomp", target_os = "linux")) {
187 #[cfg(all(feature = "seccomp", target_os = "linux"))]
189 {
190 sandbox
191 .host_funcs
192 .try_lock()
193 .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?
194 .register_host_function_with_syscalls(name.to_string(), func, _eas)?;
195 }
196 } else {
197 log_then_return!(
199 "Extra allowed syscalls are only supported on Linux with seccomp enabled"
200 );
201 }
202 } else {
203 sandbox
205 .host_funcs
206 .try_lock()
207 .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?
208 .register_host_function(name.to_string(), func)?;
209 }
210
211 Ok(())
212}