hyperlight_common/flatbuffer_wrappers/host_function_definition.rs

/*
Copyright 2024 The Hyperlight Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
16
17use alloc::string::{String, ToString};
18use alloc::vec::Vec;
19
20use anyhow::{anyhow, Error, Result};
21use flatbuffers::{FlatBufferBuilder, WIPOffset};
22#[cfg(feature = "tracing")]
23use tracing::{instrument, Span};
24
25use super::function_types::{ParameterType, ReturnType};
26use crate::flatbuffers::hyperlight::generated::{
27    HostFunctionDefinition as FbHostFunctionDefinition,
28    HostFunctionDefinitionArgs as FbHostFunctionDefinitionArgs, ParameterType as FbParameterType,
29};
30
31/// The definition of a function exposed from the host to the guest
32#[derive(Debug, Default, Clone, PartialEq, Eq)]
33pub struct HostFunctionDefinition {
34    /// The function name
35    pub function_name: String,
36    /// The type of the parameter values for the host function call.
37    pub parameter_types: Option<Vec<ParameterType>>,
38    /// The type of the return value from the host function call
39    pub return_type: ReturnType,
40}
41
42impl HostFunctionDefinition {
43    /// Create a new `HostFunctionDefinition`.
44    #[cfg_attr(feature = "tracing", instrument(skip_all, parent = Span::current(), level= "Trace"))]
45    pub fn new(
46        function_name: String,
47        parameter_types: Option<Vec<ParameterType>>,
48        return_type: ReturnType,
49    ) -> Self {
50        Self {
51            function_name,
52            parameter_types,
53            return_type,
54        }
55    }
56
57    /// Convert this `HostFunctionDefinition` into a `WIPOffset<FbHostFunctionDefinition>`.
58    #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
59    pub(crate) fn convert_to_flatbuffer_def<'a>(
60        &self,
61        builder: &mut FlatBufferBuilder<'a>,
62    ) -> Result<WIPOffset<FbHostFunctionDefinition<'a>>> {
63        let host_function_name = builder.create_string(&self.function_name);
64        let return_value_type = self.return_type.into();
65        let vec_parameters = match &self.parameter_types {
66            Some(vec_pvt) => {
67                let num_items = vec_pvt.len();
68                let mut parameters: Vec<FbParameterType> = Vec::with_capacity(num_items);
69                for pvt in vec_pvt {
70                    let fb_pvt = pvt.clone().into();
71                    parameters.push(fb_pvt);
72                }
73                Some(builder.create_vector(&parameters))
74            }
75            None => None,
76        };
77
78        let fb_host_function_definition: WIPOffset<FbHostFunctionDefinition> =
79            FbHostFunctionDefinition::create(
80                builder,
81                &FbHostFunctionDefinitionArgs {
82                    function_name: Some(host_function_name),
83                    return_type: return_value_type,
84                    parameters: vec_parameters,
85                },
86            );
87
88        Ok(fb_host_function_definition)
89    }
90
91    /// Verify that the function call has the correct parameter types.
92    #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
93    pub fn verify_equal_parameter_types(
94        &self,
95        function_call_parameter_types: &[ParameterType],
96    ) -> Result<()> {
97        if let Some(parameter_types) = &self.parameter_types {
98            for (i, parameter_type) in parameter_types.iter().enumerate() {
99                if parameter_type != &function_call_parameter_types[i] {
100                    return Err(anyhow!("Incorrect parameter type for parameter {}", i + 1));
101                }
102            }
103        }
104        Ok(())
105    }
106}
107
108impl TryFrom<&FbHostFunctionDefinition<'_>> for HostFunctionDefinition {
109    type Error = Error;
110    #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
111    fn try_from(value: &FbHostFunctionDefinition) -> Result<Self> {
112        let function_name = value.function_name().to_string();
113        let return_type = value.return_type().try_into().map_err(|_| {
114            anyhow!(
115                "Failed to convert return type for function {}",
116                function_name
117            )
118        })?;
119        let parameter_types = match value.parameters() {
120            Some(pvt) => {
121                let len = pvt.len();
122                let mut pv: Vec<ParameterType> = Vec::with_capacity(len);
123                for fb_pvt in pvt {
124                    let pvt: ParameterType = fb_pvt.try_into().map_err(|_| {
125                        anyhow!(
126                            "Failed to convert parameter type for function {}",
127                            function_name
128                        )
129                    })?;
130                    pv.push(pvt);
131                }
132                Some(pv)
133            }
134            None => None,
135        };
136
137        Ok(Self::new(function_name, parameter_types, return_type))
138    }
139}
140
141impl TryFrom<&[u8]> for HostFunctionDefinition {
142    type Error = Error;
143    #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
144    fn try_from(value: &[u8]) -> Result<Self> {
145        let fb_host_function_definition = flatbuffers::root::<FbHostFunctionDefinition<'_>>(value)
146            .map_err(|e| anyhow!("Error while reading HostFunctionDefinition: {:?}", e))?;
147        Self::try_from(&fb_host_function_definition)
148    }
149}
150
151impl TryFrom<&HostFunctionDefinition> for Vec<u8> {
152    type Error = Error;
153    #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
154    fn try_from(hfd: &HostFunctionDefinition) -> Result<Vec<u8>> {
155        let mut builder = flatbuffers::FlatBufferBuilder::new();
156        let host_function_definition = hfd.convert_to_flatbuffer_def(&mut builder)?;
157        builder.finish_size_prefixed(host_function_definition, None);
158        Ok(builder.finished_data().to_vec())
159    }
160}