// hyperlight_host/func/call_ctx.rs

/*
Copyright 2024 The Hyperlight Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
16
17use hyperlight_common::flatbuffer_wrappers::function_types::{
18    ParameterValue, ReturnType, ReturnValue,
19};
20use tracing::{instrument, Span};
21
22use super::guest_dispatch::call_function_on_guest;
23use crate::{MultiUseSandbox, Result};
24/// A context for calling guest functions.
25///
26/// Takes ownership of an existing `MultiUseSandbox`.
27/// Once created, guest function calls may be made through this and only this context
28/// until it is converted back to the `MultiUseSandbox` from which it originated.
29///
30/// Upon this conversion,the memory associated with the `MultiUseSandbox` it owns will be reset to the state it was in before
31/// this context was created.
32///
33/// Calls made through this context will cause state to be retained across calls, until such time as the context
34/// is converted back to a `MultiUseSandbox`
35///
36/// If dropped, the `MultiUseSandbox` from which it came will be also be dropped as it is owned by the
37/// `MultiUseGuestCallContext` until it is converted back to a `MultiUseSandbox`
38///
39#[derive(Debug)]
40pub struct MultiUseGuestCallContext {
41    sbox: MultiUseSandbox,
42}
43
44impl MultiUseGuestCallContext {
45    /// Take ownership  of a `MultiUseSandbox` and
46    /// return a new `MultiUseGuestCallContext` instance.
47    ///     
48    #[instrument(skip_all, parent = Span::current())]
49    pub fn start(sbox: MultiUseSandbox) -> Self {
50        Self { sbox }
51    }
52
53    /// Call the guest function called `func_name` with the given arguments
54    /// `args`, and expect the return value have the same type as
55    /// `func_ret_type`.
56    ///
57    /// Every call to a guest function through this method will be made with the same "context"
58    /// meaning that the guest state resulting from any previous call will be present/osbservable
59    /// by the guest function called.
60    ///
61    /// If you want  to reset state, call `finish()` on this `MultiUseGuestCallContext`
62    /// and get a new one from the resulting `MultiUseSandbox`
63    #[instrument(err(Debug),skip(self, args),parent = Span::current())]
64    pub fn call(
65        &mut self,
66        func_name: &str,
67        func_ret_type: ReturnType,
68        args: Option<Vec<ParameterValue>>,
69    ) -> Result<ReturnValue> {
70        // we are guaranteed to be holding a lock now, since `self` can't
71        // exist without doing so. Since GuestCallContext is effectively
72        // !Send (and !Sync), we also don't need to worry about
73        // synchronization
74
75        call_function_on_guest(&mut self.sbox, func_name, func_ret_type, args)
76    }
77
78    /// Close out the context and get back the internally-stored
79    /// `MultiUseSandbox`. Future contexts opened by the returned sandbox
80    /// will have guest state restored.
81    #[instrument(err(Debug), skip(self), parent = Span::current())]
82    pub fn finish(mut self) -> Result<MultiUseSandbox> {
83        self.sbox.restore_state()?;
84        Ok(self.sbox)
85    }
86    /// Close out the context and get back the internally-stored
87    /// `MultiUseSandbox`.
88    ///
89    /// Note that this method is pub(crate) and does not reset the state of the
90    /// sandbox.
91    ///
92    /// It is intended to be used when evolving a MultiUseSandbox to a new state
93    /// and is not intended to be called publicly. It allows the state of the guest to be altered
94    /// during the evolution of one sandbox state to another, enabling the new state created
95    /// to be captured and stored in the Sandboxes state stack.
96    ///
97    pub(crate) fn finish_no_reset(self) -> MultiUseSandbox {
98        self.sbox
99    }
100}
101
#[cfg(test)]
mod tests {
    use std::sync::mpsc::sync_channel;
    use std::thread::{self, JoinHandle};

    use hyperlight_common::flatbuffer_wrappers::function_types::{
        ParameterValue, ReturnType, ReturnValue,
    };
    use hyperlight_testing::simple_guest_as_string;

    use crate::sandbox_state::sandbox::EvolvableSandbox;
    use crate::sandbox_state::transition::Noop;
    use crate::{GuestBinary, HyperlightError, MultiUseSandbox, Result, UninitializedSandbox};

    /// Create an `UninitializedSandbox` for the simple test guest binary.
    fn new_uninit() -> Result<UninitializedSandbox> {
        let path = simple_guest_as_string().map_err(|e| {
            HyperlightError::Error(format!("failed to get simple guest path ({e:?})"))
        })?;
        UninitializedSandbox::new(GuestBinary::FilePath(path), None, None, None)
    }

    /// Test to create a `MultiUseSandbox`, then call several guest functions
    /// on it across different threads.
    ///
    /// This test works by passing messages between threads using Rust's
    /// [`std::sync::mpsc`] module. Details of this interaction are as follows.
    ///
    /// One thread acts as the receiver (AKA: consumer) and owns the
    /// `MultiUseSandbox`. This receiver fields requests from N senders
    /// (AKA: producers) to make batches of calls.
    ///
    /// Upon receipt of a message to execute a batch, a new
    /// `MultiUseGuestCallContext` is created in the receiver thread from the
    /// existing `MultiUseSandbox`, and the batch is executed.
    ///
    /// After the batch is complete, the `MultiUseGuestCallContext` is done
    /// and it is converted back to the underlying `MultiUseSandbox`.
    #[test]
    fn test_multi_call_multi_thread() {
        // Capacity 0 makes this a rendezvous channel: each `send` blocks until
        // the receiver thread takes the batch.
        let (snd, recv) = sync_channel::<Vec<TestFuncCall>>(0);

        // Receiver thread: owns the sandbox and runs each incoming batch inside
        // a fresh call context. The loop ends once every sender has been dropped.
        let recv_hdl = thread::spawn(move || {
            let mut sbox: MultiUseSandbox = new_uninit().unwrap().evolve(Noop::default()).unwrap();
            for calls in recv {
                let mut ctx = sbox.new_call_context();
                for call in calls {
                    let res = ctx
                        .call(call.func_name.as_str(), call.ret_type, call.params)
                        .unwrap();
                    assert_eq!(call.expected_ret, res);
                }
                sbox = ctx.finish().unwrap();
            }
        });

        // Sender threads: each submits one batch parameterized by its index.
        let send_handles: Vec<JoinHandle<()>> = (0..10)
            .map(|i| {
                let sender = snd.clone();
                thread::spawn(move || {
                    let greeting = format!("Hello {i}");
                    let calls = vec![
                        TestFuncCall {
                            func_name: "Echo".to_string(),
                            ret_type: ReturnType::String,
                            params: Some(vec![ParameterValue::String(greeting.clone())]),
                            expected_ret: ReturnValue::String(greeting),
                        },
                        TestFuncCall {
                            func_name: "CallMalloc".to_string(),
                            ret_type: ReturnType::Int,
                            params: Some(vec![ParameterValue::Int(i + 2)]),
                            expected_ret: ReturnValue::Int(i + 2),
                        },
                    ];
                    sender.send(calls).unwrap();
                })
            })
            .collect();

        for hdl in send_handles {
            hdl.join().unwrap();
        }
        // After all sender threads are done, drop the original sender so the
        // receiver's loop terminates, then ensure the receiver thread has exited.
        drop(snd);
        recv_hdl.join().unwrap();
    }

    /// Thin wrapper around a `MultiUseSandbox` used by the state-retention tests.
    pub struct TestSandbox {
        sandbox: MultiUseSandbox,
    }

    impl TestSandbox {
        pub fn new() -> Self {
            let sbox: MultiUseSandbox = new_uninit().unwrap().evolve(Noop::default()).unwrap();
            Self { sandbox: sbox }
        }

        /// Call `AddToStatic` with each of `0..i` through a single call context
        /// and assert every result is the running sum, i.e. guest state persists
        /// across calls made within one context.
        pub fn call_add_to_static_multiple_times(mut self, i: i32) -> Result<TestSandbox> {
            let mut ctx = self.sandbox.new_call_context();
            let mut sum: i32 = 0;
            for n in 0..i {
                sum += n;
                let result = ctx
                    .call("AddToStatic", ReturnType::Int, Some(vec![ParameterValue::Int(n)]))
                    .unwrap_or_else(|e| panic!("AddToStatic({n}) failed: {e:?}"));
                assert_eq!(result, ReturnValue::Int(sum));
            }
            self.sandbox = ctx.finish()?;
            Ok(self)
        }

        /// Call `AddToStatic` with each of `0..i` directly on the sandbox and
        /// assert every result equals its argument, i.e. guest state is reset
        /// between calls made outside a call context.
        pub fn call_add_to_static(mut self, i: i32) -> Result<()> {
            for n in 0..i {
                let result = self
                    .sandbox
                    .call_guest_function_by_name(
                        "AddToStatic",
                        ReturnType::Int,
                        Some(vec![ParameterValue::Int(n)]),
                    )
                    .unwrap_or_else(|e| panic!("AddToStatic({n}) failed: {e:?}"));
                assert_eq!(result, ReturnValue::Int(n));
            }
            Ok(())
        }
    }

    #[test]
    fn ensure_multiusesandbox_multi_calls_dont_reset_state() {
        let sandbox = TestSandbox::new();
        sandbox
            .call_add_to_static_multiple_times(5)
            .expect("calls within one context should retain guest state");
    }

    #[test]
    fn ensure_multiusesandbox_single_calls_do_reset_state() {
        let sandbox = TestSandbox::new();
        sandbox
            .call_add_to_static(5)
            .expect("single calls should reset guest state");
    }

    /// A single guest call to make and the value it is expected to return.
    struct TestFuncCall {
        func_name: String,
        ret_type: ReturnType,
        params: Option<Vec<ParameterValue>>,
        expected_ret: ReturnValue,
    }
}