hyperlight_common/flatbuffer_wrappers/function_call.rs

1/*
2Copyright 2025  The Hyperlight Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17use alloc::string::{String, ToString};
18use alloc::vec::Vec;
19
20use anyhow::{Error, Result, bail};
21use flatbuffers::{FlatBufferBuilder, WIPOffset, size_prefixed_root};
22#[cfg(feature = "tracing")]
23use tracing::{Span, instrument};
24
25use super::function_types::{ParameterValue, ReturnType};
26use crate::flatbuffers::hyperlight::generated::{
27    FunctionCall as FbFunctionCall, FunctionCallArgs as FbFunctionCallArgs,
28    FunctionCallType as FbFunctionCallType, Parameter, ParameterArgs,
29    ParameterValue as FbParameterValue, hlbool, hlboolArgs, hldouble, hldoubleArgs, hlfloat,
30    hlfloatArgs, hlint, hlintArgs, hllong, hllongArgs, hlstring, hlstringArgs, hluint, hluintArgs,
31    hlulong, hlulongArgs, hlvecbytes, hlvecbytesArgs,
32};
33
/// Identifies which side of the guest/host boundary a function call targets.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum FunctionCallType {
    /// Targets a function that lives in the guest.
    Guest,
    /// Targets a function that lives in the host.
    Host,
}
42
43/// `Functioncall` represents a call to a function in the guest or host.
44#[derive(Clone)]
45pub struct FunctionCall {
46    /// The function name
47    pub function_name: String,
48    /// The parameters for the function call.
49    pub parameters: Option<Vec<ParameterValue>>,
50    function_call_type: FunctionCallType,
51    /// The return type of the function call
52    pub expected_return_type: ReturnType,
53}
54
55impl FunctionCall {
56    #[cfg_attr(feature = "tracing", instrument(skip_all, parent = Span::current(), level= "Trace"))]
57    pub fn new(
58        function_name: String,
59        parameters: Option<Vec<ParameterValue>>,
60        function_call_type: FunctionCallType,
61        expected_return_type: ReturnType,
62    ) -> Self {
63        Self {
64            function_name,
65            parameters,
66            function_call_type,
67            expected_return_type,
68        }
69    }
70
71    /// The type of the function call.
72    pub fn function_call_type(&self) -> FunctionCallType {
73        self.function_call_type.clone()
74    }
75
76    /// Encodes self into the given builder and returns the encoded data.
77    ///
78    /// # Notes
79    ///
80    /// The builder should not be reused after a call to encode, since this function
81    /// does not reset the state of the builder. If you want to reuse the builder,
82    /// you'll need to reset it first.
83    pub fn encode<'a>(&self, builder: &'a mut FlatBufferBuilder) -> &'a [u8] {
84        let function_name = builder.create_string(&self.function_name);
85
86        let function_call_type = match self.function_call_type {
87            FunctionCallType::Guest => FbFunctionCallType::guest,
88            FunctionCallType::Host => FbFunctionCallType::host,
89        };
90
91        let expected_return_type = self.expected_return_type.into();
92
93        let parameters = match &self.parameters {
94            Some(p) if !p.is_empty() => {
95                let parameter_offsets: Vec<WIPOffset<Parameter>> = p
96                    .iter()
97                    .map(|param| match param {
98                        ParameterValue::Int(i) => {
99                            let hlint = hlint::create(builder, &hlintArgs { value: *i });
100                            Parameter::create(
101                                builder,
102                                &ParameterArgs {
103                                    value_type: FbParameterValue::hlint,
104                                    value: Some(hlint.as_union_value()),
105                                },
106                            )
107                        }
108                        ParameterValue::UInt(ui) => {
109                            let hluint = hluint::create(builder, &hluintArgs { value: *ui });
110                            Parameter::create(
111                                builder,
112                                &ParameterArgs {
113                                    value_type: FbParameterValue::hluint,
114                                    value: Some(hluint.as_union_value()),
115                                },
116                            )
117                        }
118                        ParameterValue::Long(l) => {
119                            let hllong = hllong::create(builder, &hllongArgs { value: *l });
120                            Parameter::create(
121                                builder,
122                                &ParameterArgs {
123                                    value_type: FbParameterValue::hllong,
124                                    value: Some(hllong.as_union_value()),
125                                },
126                            )
127                        }
128                        ParameterValue::ULong(ul) => {
129                            let hlulong = hlulong::create(builder, &hlulongArgs { value: *ul });
130                            Parameter::create(
131                                builder,
132                                &ParameterArgs {
133                                    value_type: FbParameterValue::hlulong,
134                                    value: Some(hlulong.as_union_value()),
135                                },
136                            )
137                        }
138                        ParameterValue::Float(f) => {
139                            let hlfloat = hlfloat::create(builder, &hlfloatArgs { value: *f });
140                            Parameter::create(
141                                builder,
142                                &ParameterArgs {
143                                    value_type: FbParameterValue::hlfloat,
144                                    value: Some(hlfloat.as_union_value()),
145                                },
146                            )
147                        }
148                        ParameterValue::Double(d) => {
149                            let hldouble = hldouble::create(builder, &hldoubleArgs { value: *d });
150                            Parameter::create(
151                                builder,
152                                &ParameterArgs {
153                                    value_type: FbParameterValue::hldouble,
154                                    value: Some(hldouble.as_union_value()),
155                                },
156                            )
157                        }
158                        ParameterValue::Bool(b) => {
159                            let hlbool = hlbool::create(builder, &hlboolArgs { value: *b });
160                            Parameter::create(
161                                builder,
162                                &ParameterArgs {
163                                    value_type: FbParameterValue::hlbool,
164                                    value: Some(hlbool.as_union_value()),
165                                },
166                            )
167                        }
168                        ParameterValue::String(s) => {
169                            let val = builder.create_string(s.as_str());
170                            let hlstring =
171                                hlstring::create(builder, &hlstringArgs { value: Some(val) });
172                            Parameter::create(
173                                builder,
174                                &ParameterArgs {
175                                    value_type: FbParameterValue::hlstring,
176                                    value: Some(hlstring.as_union_value()),
177                                },
178                            )
179                        }
180                        ParameterValue::VecBytes(v) => {
181                            let vec_bytes = builder.create_vector(v);
182                            let hlvecbytes = hlvecbytes::create(
183                                builder,
184                                &hlvecbytesArgs {
185                                    value: Some(vec_bytes),
186                                },
187                            );
188                            Parameter::create(
189                                builder,
190                                &ParameterArgs {
191                                    value_type: FbParameterValue::hlvecbytes,
192                                    value: Some(hlvecbytes.as_union_value()),
193                                },
194                            )
195                        }
196                    })
197                    .collect();
198                Some(builder.create_vector(&parameter_offsets))
199            }
200            _ => None,
201        };
202
203        let function_call = FbFunctionCall::create(
204            builder,
205            &FbFunctionCallArgs {
206                function_name: Some(function_name),
207                parameters,
208                function_call_type,
209                expected_return_type,
210            },
211        );
212        builder.finish_size_prefixed(function_call, None);
213        builder.finished_data()
214    }
215}
216
217#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
218pub fn validate_guest_function_call_buffer(function_call_buffer: &[u8]) -> Result<()> {
219    let guest_function_call_fb = size_prefixed_root::<FbFunctionCall>(function_call_buffer)
220        .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?;
221    match guest_function_call_fb.function_call_type() {
222        FbFunctionCallType::guest => Ok(()),
223        other => {
224            bail!("Invalid function call type: {:?}", other);
225        }
226    }
227}
228
229#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
230pub fn validate_host_function_call_buffer(function_call_buffer: &[u8]) -> Result<()> {
231    let host_function_call_fb = size_prefixed_root::<FbFunctionCall>(function_call_buffer)
232        .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?;
233    match host_function_call_fb.function_call_type() {
234        FbFunctionCallType::host => Ok(()),
235        other => {
236            bail!("Invalid function call type: {:?}", other);
237        }
238    }
239}
240
241impl TryFrom<&[u8]> for FunctionCall {
242    type Error = Error;
243    #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
244    fn try_from(value: &[u8]) -> Result<Self> {
245        let function_call_fb = size_prefixed_root::<FbFunctionCall>(value)
246            .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?;
247        let function_name = function_call_fb.function_name();
248        let function_call_type = match function_call_fb.function_call_type() {
249            FbFunctionCallType::guest => FunctionCallType::Guest,
250            FbFunctionCallType::host => FunctionCallType::Host,
251            other => {
252                bail!("Invalid function call type: {:?}", other);
253            }
254        };
255        let expected_return_type = function_call_fb.expected_return_type().try_into()?;
256
257        let parameters = function_call_fb
258            .parameters()
259            .map(|v| {
260                v.iter()
261                    .map(|p| p.try_into())
262                    .collect::<Result<Vec<ParameterValue>>>()
263            })
264            .transpose()?;
265
266        Ok(Self {
267            function_name: function_name.to_string(),
268            parameters,
269            function_call_type,
270            expected_return_type,
271        })
272    }
273}
274
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::flatbuffer_wrappers::function_types::ReturnType;

    /// The twelve mixed-type arguments used by the round-trip test.
    fn twelve_args_parameters() -> Vec<ParameterValue> {
        vec![
            ParameterValue::String("1".to_string()),
            ParameterValue::Int(2),
            ParameterValue::Long(3),
            ParameterValue::String("4".to_string()),
            ParameterValue::String("5".to_string()),
            ParameterValue::Bool(true),
            ParameterValue::Bool(false),
            ParameterValue::UInt(8),
            ParameterValue::ULong(9),
            ParameterValue::Int(10),
            ParameterValue::Float(3.123),
            ParameterValue::Double(0.01),
        ]
    }

    #[test]
    fn read_from_flatbuffer() -> Result<()> {
        let mut builder = FlatBufferBuilder::new();
        let test_data = FunctionCall::new(
            "PrintTwelveArgs".to_string(),
            Some(twelve_args_parameters()),
            FunctionCallType::Guest,
            ReturnType::Int,
        )
        .encode(&mut builder);

        let function_call = FunctionCall::try_from(test_data)?;
        assert_eq!(function_call.function_name, "PrintTwelveArgs");
        assert!(function_call.parameters.is_some());
        let parameters = function_call.parameters.unwrap();
        assert_eq!(parameters.len(), 12);
        assert!(twelve_args_parameters() == parameters);
        assert_eq!(function_call.function_call_type, FunctionCallType::Guest);
        assert!(matches!(function_call.expected_return_type, ReturnType::Int));

        Ok(())
    }

    #[test]
    fn round_trip_host_call_without_parameters() -> Result<()> {
        let mut builder = FlatBufferBuilder::new();
        let buffer = FunctionCall::new(
            "HostFunction".to_string(),
            None,
            FunctionCallType::Host,
            ReturnType::Int,
        )
        .encode(&mut builder);

        let function_call = FunctionCall::try_from(buffer)?;
        assert_eq!(function_call.function_name, "HostFunction");
        assert!(function_call.parameters.is_none());
        assert_eq!(function_call.function_call_type(), FunctionCallType::Host);

        Ok(())
    }

    #[test]
    fn empty_parameter_list_decodes_as_none() -> Result<()> {
        // `encode` omits the parameters vector for an empty list, so the distinction
        // between `Some(vec![])` and `None` is not preserved across a round trip.
        let mut builder = FlatBufferBuilder::new();
        let buffer = FunctionCall::new(
            "NoArgs".to_string(),
            Some(Vec::new()),
            FunctionCallType::Guest,
            ReturnType::Int,
        )
        .encode(&mut builder);

        let function_call = FunctionCall::try_from(buffer)?;
        assert!(function_call.parameters.is_none());

        Ok(())
    }

    #[test]
    fn round_trip_vec_bytes_parameter() -> Result<()> {
        let mut builder = FlatBufferBuilder::new();
        let buffer = FunctionCall::new(
            "TakesBytes".to_string(),
            Some(vec![ParameterValue::VecBytes(vec![1, 2, 3])]),
            FunctionCallType::Guest,
            ReturnType::Int,
        )
        .encode(&mut builder);

        let function_call = FunctionCall::try_from(buffer)?;
        assert!(
            function_call.parameters == Some(vec![ParameterValue::VecBytes(vec![1, 2, 3])])
        );

        Ok(())
    }

    #[test]
    fn validators_accept_only_matching_call_type() {
        let mut guest_builder = FlatBufferBuilder::new();
        let guest = FunctionCall::new(
            "GuestFunction".to_string(),
            None,
            FunctionCallType::Guest,
            ReturnType::Int,
        )
        .encode(&mut guest_builder);
        assert!(validate_guest_function_call_buffer(guest).is_ok());
        assert!(validate_host_function_call_buffer(guest).is_err());

        let mut host_builder = FlatBufferBuilder::new();
        let host = FunctionCall::new(
            "HostFunction".to_string(),
            None,
            FunctionCallType::Host,
            ReturnType::Int,
        )
        .encode(&mut host_builder);
        assert!(validate_host_function_call_buffer(host).is_ok());
        assert!(validate_guest_function_call_buffer(host).is_err());
    }

    #[test]
    fn malformed_buffers_are_rejected() {
        let empty: &[u8] = &[];
        let truncated = [0xFFu8; 3];
        for buffer in [empty, &truncated[..]] {
            assert!(FunctionCall::try_from(buffer).is_err());
            assert!(validate_guest_function_call_buffer(buffer).is_err());
            assert!(validate_host_function_call_buffer(buffer).is_err());
        }
    }
}