hyperlight_common/flatbuffer_wrappers/function_call.rs

/*
Copyright 2025  The Hyperlight Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
16
17use alloc::string::{String, ToString};
18use alloc::vec::Vec;
19
20use anyhow::{Error, Result, bail};
21use flatbuffers::{WIPOffset, size_prefixed_root};
22#[cfg(feature = "tracing")]
23use tracing::{Span, instrument};
24
25use super::function_types::{ParameterValue, ReturnType};
26use crate::flatbuffers::hyperlight::generated::{
27    FunctionCall as FbFunctionCall, FunctionCallArgs as FbFunctionCallArgs,
28    FunctionCallType as FbFunctionCallType, Parameter, ParameterArgs,
29    ParameterValue as FbParameterValue, hlbool, hlboolArgs, hldouble, hldoubleArgs, hlfloat,
30    hlfloatArgs, hlint, hlintArgs, hllong, hllongArgs, hlstring, hlstringArgs, hluint, hluintArgs,
31    hlulong, hlulongArgs, hlvecbytes, hlvecbytesArgs,
32};
33
34/// The type of function call.
35#[derive(Debug, Clone, PartialEq, Eq)]
36pub enum FunctionCallType {
37    /// The function call is to a guest function.
38    Guest,
39    /// The function call is to a host function.
40    Host,
41}
42
43/// `Functioncall` represents a call to a function in the guest or host.
44#[derive(Clone)]
45pub struct FunctionCall {
46    /// The function name
47    pub function_name: String,
48    /// The parameters for the function call.
49    pub parameters: Option<Vec<ParameterValue>>,
50    function_call_type: FunctionCallType,
51    /// The return type of the function call
52    pub expected_return_type: ReturnType,
53}
54
55impl FunctionCall {
56    #[cfg_attr(feature = "tracing", instrument(skip_all, parent = Span::current(), level= "Trace"))]
57    pub fn new(
58        function_name: String,
59        parameters: Option<Vec<ParameterValue>>,
60        function_call_type: FunctionCallType,
61        expected_return_type: ReturnType,
62    ) -> Self {
63        Self {
64            function_name,
65            parameters,
66            function_call_type,
67            expected_return_type,
68        }
69    }
70
71    /// The type of the function call.
72    pub fn function_call_type(&self) -> FunctionCallType {
73        self.function_call_type.clone()
74    }
75}
76
77#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
78pub fn validate_guest_function_call_buffer(function_call_buffer: &[u8]) -> Result<()> {
79    let guest_function_call_fb = size_prefixed_root::<FbFunctionCall>(function_call_buffer)
80        .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?;
81    match guest_function_call_fb.function_call_type() {
82        FbFunctionCallType::guest => Ok(()),
83        other => {
84            bail!("Invalid function call type: {:?}", other);
85        }
86    }
87}
88
89#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
90pub fn validate_host_function_call_buffer(function_call_buffer: &[u8]) -> Result<()> {
91    let host_function_call_fb = size_prefixed_root::<FbFunctionCall>(function_call_buffer)
92        .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?;
93    match host_function_call_fb.function_call_type() {
94        FbFunctionCallType::host => Ok(()),
95        other => {
96            bail!("Invalid function call type: {:?}", other);
97        }
98    }
99}
100
101impl TryFrom<&[u8]> for FunctionCall {
102    type Error = Error;
103    #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
104    fn try_from(value: &[u8]) -> Result<Self> {
105        let function_call_fb = size_prefixed_root::<FbFunctionCall>(value)
106            .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?;
107        let function_name = function_call_fb.function_name();
108        let function_call_type = match function_call_fb.function_call_type() {
109            FbFunctionCallType::guest => FunctionCallType::Guest,
110            FbFunctionCallType::host => FunctionCallType::Host,
111            other => {
112                bail!("Invalid function call type: {:?}", other);
113            }
114        };
115        let expected_return_type = function_call_fb.expected_return_type().try_into()?;
116
117        let parameters = function_call_fb
118            .parameters()
119            .map(|v| {
120                v.iter()
121                    .map(|p| p.try_into())
122                    .collect::<Result<Vec<ParameterValue>>>()
123            })
124            .transpose()?;
125
126        Ok(Self {
127            function_name: function_name.to_string(),
128            parameters,
129            function_call_type,
130            expected_return_type,
131        })
132    }
133}
134
135impl TryFrom<FunctionCall> for Vec<u8> {
136    type Error = Error;
137    #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
138    fn try_from(value: FunctionCall) -> Result<Vec<u8>> {
139        let mut builder = flatbuffers::FlatBufferBuilder::new();
140        let function_name = builder.create_string(&value.function_name);
141
142        let function_call_type = match value.function_call_type {
143            FunctionCallType::Guest => FbFunctionCallType::guest,
144            FunctionCallType::Host => FbFunctionCallType::host,
145        };
146
147        let expected_return_type = value.expected_return_type.into();
148
149        let parameters = match &value.parameters {
150            Some(p) => {
151                let num_items = p.len();
152                let mut parameters: Vec<WIPOffset<Parameter>> = Vec::with_capacity(num_items);
153
154                for param in p {
155                    match param {
156                        ParameterValue::Int(i) => {
157                            let hlint = hlint::create(&mut builder, &hlintArgs { value: *i });
158                            let parameter = Parameter::create(
159                                &mut builder,
160                                &ParameterArgs {
161                                    value_type: FbParameterValue::hlint,
162                                    value: Some(hlint.as_union_value()),
163                                },
164                            );
165                            parameters.push(parameter);
166                        }
167                        ParameterValue::UInt(ui) => {
168                            let hluint = hluint::create(&mut builder, &hluintArgs { value: *ui });
169                            let parameter = Parameter::create(
170                                &mut builder,
171                                &ParameterArgs {
172                                    value_type: FbParameterValue::hluint,
173                                    value: Some(hluint.as_union_value()),
174                                },
175                            );
176                            parameters.push(parameter);
177                        }
178                        ParameterValue::Long(l) => {
179                            let hllong = hllong::create(&mut builder, &hllongArgs { value: *l });
180                            let parameter = Parameter::create(
181                                &mut builder,
182                                &ParameterArgs {
183                                    value_type: FbParameterValue::hllong,
184                                    value: Some(hllong.as_union_value()),
185                                },
186                            );
187                            parameters.push(parameter);
188                        }
189                        ParameterValue::ULong(ul) => {
190                            let hlulong =
191                                hlulong::create(&mut builder, &hlulongArgs { value: *ul });
192                            let parameter = Parameter::create(
193                                &mut builder,
194                                &ParameterArgs {
195                                    value_type: FbParameterValue::hlulong,
196                                    value: Some(hlulong.as_union_value()),
197                                },
198                            );
199                            parameters.push(parameter);
200                        }
201                        ParameterValue::Float(f) => {
202                            let hlfloat = hlfloat::create(&mut builder, &hlfloatArgs { value: *f });
203                            let parameter = Parameter::create(
204                                &mut builder,
205                                &ParameterArgs {
206                                    value_type: FbParameterValue::hlfloat,
207                                    value: Some(hlfloat.as_union_value()),
208                                },
209                            );
210                            parameters.push(parameter);
211                        }
212                        ParameterValue::Double(d) => {
213                            let hldouble =
214                                hldouble::create(&mut builder, &hldoubleArgs { value: *d });
215                            let parameter = Parameter::create(
216                                &mut builder,
217                                &ParameterArgs {
218                                    value_type: FbParameterValue::hldouble,
219                                    value: Some(hldouble.as_union_value()),
220                                },
221                            );
222                            parameters.push(parameter);
223                        }
224                        ParameterValue::Bool(b) => {
225                            let hlbool: WIPOffset<hlbool<'_>> =
226                                hlbool::create(&mut builder, &hlboolArgs { value: *b });
227                            let parameter = Parameter::create(
228                                &mut builder,
229                                &ParameterArgs {
230                                    value_type: FbParameterValue::hlbool,
231                                    value: Some(hlbool.as_union_value()),
232                                },
233                            );
234                            parameters.push(parameter);
235                        }
236                        ParameterValue::String(s) => {
237                            let hlstring = {
238                                let val = builder.create_string(s.as_str());
239                                hlstring::create(&mut builder, &hlstringArgs { value: Some(val) })
240                            };
241                            let parameter = Parameter::create(
242                                &mut builder,
243                                &ParameterArgs {
244                                    value_type: FbParameterValue::hlstring,
245                                    value: Some(hlstring.as_union_value()),
246                                },
247                            );
248                            parameters.push(parameter);
249                        }
250                        ParameterValue::VecBytes(v) => {
251                            let vec_bytes = builder.create_vector(v);
252
253                            let hlvecbytes = hlvecbytes::create(
254                                &mut builder,
255                                &hlvecbytesArgs {
256                                    value: Some(vec_bytes),
257                                },
258                            );
259                            let parameter = Parameter::create(
260                                &mut builder,
261                                &ParameterArgs {
262                                    value_type: FbParameterValue::hlvecbytes,
263                                    value: Some(hlvecbytes.as_union_value()),
264                                },
265                            );
266                            parameters.push(parameter);
267                        }
268                    }
269                }
270                parameters
271            }
272            None => Vec::new(),
273        };
274
275        let parameters = if !parameters.is_empty() {
276            Some(builder.create_vector(&parameters))
277        } else {
278            None
279        };
280
281        let function_call = FbFunctionCall::create(
282            &mut builder,
283            &FbFunctionCallArgs {
284                function_name: Some(function_name),
285                parameters,
286                function_call_type,
287                expected_return_type,
288            },
289        );
290        builder.finish_size_prefixed(function_call, None);
291        let res = builder.finished_data().to_vec();
292
293        Ok(res)
294    }
295}
296
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::flatbuffer_wrappers::function_types::ReturnType;

    /// Serializes `call` and decodes the resulting bytes again, exercising
    /// both `TryFrom` conversions in this module.
    fn round_trip(call: FunctionCall) -> Result<FunctionCall> {
        let bytes: Vec<u8> = call.try_into()?;
        FunctionCall::try_from(bytes.as_slice())
    }

    #[test]
    fn read_from_flatbuffer() -> Result<()> {
        let expected_parameters = vec![
            ParameterValue::String("1".to_string()),
            ParameterValue::Int(2),
            ParameterValue::Long(3),
            ParameterValue::String("4".to_string()),
            ParameterValue::String("5".to_string()),
            ParameterValue::Bool(true),
            ParameterValue::Bool(false),
            ParameterValue::UInt(8),
            ParameterValue::ULong(9),
            ParameterValue::Int(10),
            ParameterValue::Float(3.123),
            ParameterValue::Double(0.01),
        ];
        let test_data: Vec<u8> = FunctionCall::new(
            "PrintTwelveArgs".to_string(),
            Some(expected_parameters.clone()),
            FunctionCallType::Guest,
            ReturnType::Int,
        )
        .try_into()?;

        let function_call = FunctionCall::try_from(test_data.as_slice())?;
        assert_eq!(function_call.function_name, "PrintTwelveArgs");
        assert_eq!(function_call.function_call_type, FunctionCallType::Guest);
        assert!(matches!(function_call.expected_return_type, ReturnType::Int));
        let parameters = function_call
            .parameters
            .expect("a non-empty parameter list should round-trip as Some");
        assert_eq!(parameters.len(), 12);
        assert!(expected_parameters == parameters);

        Ok(())
    }

    #[test]
    fn round_trips_vec_bytes_parameter() -> Result<()> {
        let payload = vec![0u8, 1, 2, 255];
        let expected = vec![ParameterValue::VecBytes(payload.clone()), ParameterValue::Int(-1)];
        let call = round_trip(FunctionCall::new(
            "TakesBytes".to_string(),
            Some(expected.clone()),
            FunctionCallType::Guest,
            ReturnType::Int,
        ))?;
        let parameters = call
            .parameters
            .expect("a non-empty parameter list should round-trip as Some");
        assert!(parameters == expected);
        Ok(())
    }

    #[test]
    fn round_trips_host_call_without_parameters() -> Result<()> {
        let call = round_trip(FunctionCall::new(
            "HostPrint".to_string(),
            None,
            FunctionCallType::Host,
            ReturnType::Int,
        ))?;
        assert_eq!(call.function_name, "HostPrint");
        assert!(call.parameters.is_none());
        assert_eq!(call.function_call_type(), FunctionCallType::Host);
        Ok(())
    }

    #[test]
    fn empty_parameter_list_decodes_as_none() -> Result<()> {
        // The encoder omits an empty parameters vector, so it decodes as `None`.
        let call = round_trip(FunctionCall::new(
            "NoArgs".to_string(),
            Some(vec![]),
            FunctionCallType::Guest,
            ReturnType::Int,
        ))?;
        assert!(call.parameters.is_none());
        Ok(())
    }

    #[test]
    fn validators_check_call_type() -> Result<()> {
        let guest: Vec<u8> = FunctionCall::new(
            "GuestFn".to_string(),
            None,
            FunctionCallType::Guest,
            ReturnType::Int,
        )
        .try_into()?;
        let host: Vec<u8> = FunctionCall::new(
            "HostFn".to_string(),
            None,
            FunctionCallType::Host,
            ReturnType::Int,
        )
        .try_into()?;

        assert!(validate_guest_function_call_buffer(&guest).is_ok());
        assert!(validate_guest_function_call_buffer(&host).is_err());
        assert!(validate_host_function_call_buffer(&host).is_ok());
        assert!(validate_host_function_call_buffer(&guest).is_err());
        Ok(())
    }

    #[test]
    fn rejects_malformed_buffer() {
        let garbage = [0u8; 3];
        assert!(FunctionCall::try_from(&garbage[..]).is_err());
        assert!(validate_guest_function_call_buffer(&garbage).is_err());
        assert!(validate_host_function_call_buffer(&garbage).is_err());
    }
}