// hyperlight_host/func/host_functions.rs

/*
Copyright 2024 The Hyperlight Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
16
17#![allow(non_snake_case)]
18use std::sync::{Arc, Mutex};
19
20use hyperlight_common::flatbuffer_wrappers::function_types::ParameterValue;
21use hyperlight_common::flatbuffer_wrappers::host_function_definition::HostFunctionDefinition;
22use paste::paste;
23use tracing::{instrument, Span};
24
25use super::{HyperlightFunction, SupportedParameterType, SupportedReturnType};
26use crate::sandbox::{ExtraAllowedSyscall, UninitializedSandbox};
27use crate::HyperlightError::UnexpectedNoOfArguments;
28use crate::{log_then_return, new_error, Result};
29
/// Generates the registration machinery for one arity `N`:
///
/// - the public `HostFunctionN` trait;
/// - its implementation for `Arc<Mutex<T>>`, where `T` is a closure taking the listed parameter types;
/// - a private `register_host_function_N` helper.
///
/// The helper wraps the closure so it can be invoked with an untyped `Vec<ParameterValue>`. On each
/// call it:
///
/// 1. checks the argument count;
/// 2. converts each argument with `SupportedParameterType::get_inner`;
/// 3. calls the closure;
/// 4. converts the result with `SupportedReturnType::get_hyperlight_value`.
///
/// Usage: `host_function!(0)` for the zero-parameter variant, otherwise
/// `host_function!(N, P1, ..., PN)`.
macro_rules! host_function {
    ($N:expr $(, $P:ident)*) => {
        paste! {
            // A plain `///` comment here would be turned into a string literal in which `$N` is
            // never substituted, so the arity is spliced in explicitly.
            #[doc = concat!(
                "Trait for registering a host function with ",
                stringify!($N),
                " parameters."
            )]
            pub trait [<HostFunction $N>]<'a, $($P,)* R>
            where
                $($P: SupportedParameterType<$P> + Clone + 'a,)*
                R: SupportedReturnType<R>,
            {
                /// Register the host function with the given name in the sandbox.
                fn register(
                    &self,
                    sandbox: &mut UninitializedSandbox,
                    name: &str,
                ) -> Result<()>;

                /// Register the host function with the given name in the sandbox, allowing extra syscalls.
                #[cfg(all(feature = "seccomp", target_os = "linux"))]
                fn register_with_extra_allowed_syscalls(
                    &self,
                    sandbox: &mut UninitializedSandbox,
                    name: &str,
                    extra_allowed_syscalls: Vec<ExtraAllowedSyscall>,
                ) -> Result<()>;
            }

            impl<'a, T, $($P,)* R> [<HostFunction $N>]<'a, $($P,)* R> for Arc<Mutex<T>>
            where
                T: FnMut($($P),*) -> Result<R> + Send + 'static,
                $($P: SupportedParameterType<$P> + Clone + 'a,)*
                R: SupportedReturnType<R>,
            {
                #[instrument(
                    err(Debug), skip(self, sandbox), parent = Span::current(), level = "Trace"
                )]
                fn register(
                    &self,
                    sandbox: &mut UninitializedSandbox,
                    name: &str,
                ) -> Result<()> {
                    [<register_host_function_ $N>](self.clone(), sandbox, name, None)
                }

                #[cfg(all(feature = "seccomp", target_os = "linux"))]
                #[instrument(
                    err(Debug), skip(self, sandbox, extra_allowed_syscalls),
                    parent = Span::current(), level = "Trace"
                )]
                fn register_with_extra_allowed_syscalls(
                    &self,
                    sandbox: &mut UninitializedSandbox,
                    name: &str,
                    extra_allowed_syscalls: Vec<ExtraAllowedSyscall>,
                ) -> Result<()> {
                    [<register_host_function_ $N>](
                        self.clone(),
                        sandbox,
                        name,
                        Some(extra_allowed_syscalls),
                    )
                }
            }

            /// Wraps the closure in `self_` as an untyped host function and registers it with
            /// `sandbox` under `name`.
            ///
            /// When `extra_allowed_syscalls` is `Some`, registration goes through
            /// `register_host_function_with_syscalls`. That path only exists on Linux builds with
            /// the `seccomp` feature; on any other build the request is logged and rejected.
            ///
            /// # Errors
            ///
            /// Returns an error in any of these cases:
            ///
            /// - the sandbox's host-function table cannot be locked via `try_lock`;
            /// - the registration call itself fails;
            /// - extra syscalls are requested on a build without seccomp support.
            ///
            /// Once registered, each call of the wrapped function fails on:
            ///
            /// - a wrong argument count;
            /// - a failed argument conversion;
            /// - a failed `try_lock` of the closure's mutex;
            /// - an error returned by the closure.
            fn [<register_host_function_ $N>]<T, $($P,)* R>(
                self_: Arc<Mutex<T>>,
                sandbox: &mut UninitializedSandbox,
                name: &str,
                extra_allowed_syscalls: Option<Vec<ExtraAllowedSyscall>>,
            ) -> Result<()>
            where
                T: FnMut($($P),*) -> Result<R> + Send + 'static,
                $($P: SupportedParameterType<$P> + Clone,)*
                R: SupportedReturnType<R>,
            {
                // `self_` is already an owned `Arc`, so it is moved straight into the closure
                // instead of being cloned a second time.
                let func = Box::new(move |args: Vec<ParameterValue>| {
                    // Converting into a fixed-size array validates the argument count (for every
                    // arity, including zero). It also yields each argument by value, with no
                    // indexing or `unwrap`.
                    let result = match <[ParameterValue; $N]>::try_from(args) {
                        Ok([$($P,)*]) => {
                            // Each binding starts as a raw `ParameterValue` and is then shadowed
                            // by its converted Rust value (hence `non_snake_case` is allowed).
                            $(let $P = $P::get_inner($P)?;)*
                            let mut user_fn = self_.try_lock().map_err(|e| {
                                new_error!("Error locking at {}:{}: {}", file!(), line!(), e)
                            })?;
                            (*user_fn)($($P),*)?
                        }
                        Err(args) => {
                            log_then_return!(UnexpectedNoOfArguments(args.len(), $N));
                        }
                    };
                    Ok(result.get_hyperlight_value())
                });

                // Zero-parameter functions are registered with no parameter list at all (`None`)
                // rather than an empty one.
                let parameter_types = vec![$($P::get_hyperlight_type()),*];
                let parameter_types = if parameter_types.is_empty() {
                    None
                } else {
                    Some(parameter_types)
                };

                let definition = HostFunctionDefinition::new(
                    name.to_string(),
                    parameter_types,
                    R::get_hyperlight_type(),
                );
                let function = HyperlightFunction::new(func);

                match extra_allowed_syscalls {
                    #[cfg(all(feature = "seccomp", target_os = "linux"))]
                    Some(extra_allowed_syscalls) => {
                        sandbox
                            .host_funcs
                            .try_lock()
                            .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?
                            .register_host_function_with_syscalls(
                                sandbox.mgr.as_mut(),
                                &definition,
                                function,
                                extra_allowed_syscalls,
                            )?;
                    }
                    #[cfg(not(all(feature = "seccomp", target_os = "linux")))]
                    Some(_) => {
                        log_then_return!("Extra allowed syscalls are only supported on Linux with seccomp enabled");
                    }
                    None => {
                        sandbox
                            .host_funcs
                            .try_lock()
                            .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?
                            .register_host_function(sandbox.mgr.as_mut(), &definition, function)?;
                    }
                }

                Ok(())
            }
        }
    };
}
275
276host_function!(0);
277host_function!(1, P1);
278host_function!(2, P1, P2);
279host_function!(3, P1, P2, P3);
280host_function!(4, P1, P2, P3, P4);
281host_function!(5, P1, P2, P3, P4, P5);
282host_function!(6, P1, P2, P3, P4, P5, P6);
283host_function!(7, P1, P2, P3, P4, P5, P6, P7);
284host_function!(8, P1, P2, P3, P4, P5, P6, P7, P8);
285host_function!(9, P1, P2, P3, P4, P5, P6, P7, P8, P9);
286host_function!(10, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10);