1use alloc::string::{String, ToString};
18use alloc::vec::Vec;
19
20use anyhow::{Error, Result, bail};
21use flatbuffers::{FlatBufferBuilder, WIPOffset, size_prefixed_root};
22#[cfg(feature = "tracing")]
23use tracing::{Span, instrument};
24
25use super::function_types::{ParameterValue, ReturnType};
26use crate::flatbuffers::hyperlight::generated::{
27 FunctionCall as FbFunctionCall, FunctionCallArgs as FbFunctionCallArgs,
28 FunctionCallType as FbFunctionCallType, Parameter, ParameterArgs,
29 ParameterValue as FbParameterValue, hlbool, hlboolArgs, hldouble, hldoubleArgs, hlfloat,
30 hlfloatArgs, hlint, hlintArgs, hllong, hllongArgs, hlstring, hlstringArgs, hluint, hluintArgs,
31 hlulong, hlulongArgs, hlvecbytes, hlvecbytesArgs,
32};
33
/// Which side of the host/guest boundary a [`FunctionCall`] is addressed to.
///
/// Encoded on the wire as the flatbuffer `FunctionCallType` (`guest` / `host`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FunctionCallType {
    /// A call to a guest function.
    Guest,
    /// A call to a host function.
    Host,
}
42
43#[derive(Clone)]
45pub struct FunctionCall {
46 pub function_name: String,
48 pub parameters: Option<Vec<ParameterValue>>,
50 function_call_type: FunctionCallType,
51 pub expected_return_type: ReturnType,
53}
54
55impl FunctionCall {
56 #[cfg_attr(feature = "tracing", instrument(skip_all, parent = Span::current(), level= "Trace"))]
57 pub fn new(
58 function_name: String,
59 parameters: Option<Vec<ParameterValue>>,
60 function_call_type: FunctionCallType,
61 expected_return_type: ReturnType,
62 ) -> Self {
63 Self {
64 function_name,
65 parameters,
66 function_call_type,
67 expected_return_type,
68 }
69 }
70
71 pub fn function_call_type(&self) -> FunctionCallType {
73 self.function_call_type.clone()
74 }
75
76 pub fn encode<'a>(&self, builder: &'a mut FlatBufferBuilder) -> &'a [u8] {
84 let function_name = builder.create_string(&self.function_name);
85
86 let function_call_type = match self.function_call_type {
87 FunctionCallType::Guest => FbFunctionCallType::guest,
88 FunctionCallType::Host => FbFunctionCallType::host,
89 };
90
91 let expected_return_type = self.expected_return_type.into();
92
93 let parameters = match &self.parameters {
94 Some(p) if !p.is_empty() => {
95 let parameter_offsets: Vec<WIPOffset<Parameter>> = p
96 .iter()
97 .map(|param| match param {
98 ParameterValue::Int(i) => {
99 let hlint = hlint::create(builder, &hlintArgs { value: *i });
100 Parameter::create(
101 builder,
102 &ParameterArgs {
103 value_type: FbParameterValue::hlint,
104 value: Some(hlint.as_union_value()),
105 },
106 )
107 }
108 ParameterValue::UInt(ui) => {
109 let hluint = hluint::create(builder, &hluintArgs { value: *ui });
110 Parameter::create(
111 builder,
112 &ParameterArgs {
113 value_type: FbParameterValue::hluint,
114 value: Some(hluint.as_union_value()),
115 },
116 )
117 }
118 ParameterValue::Long(l) => {
119 let hllong = hllong::create(builder, &hllongArgs { value: *l });
120 Parameter::create(
121 builder,
122 &ParameterArgs {
123 value_type: FbParameterValue::hllong,
124 value: Some(hllong.as_union_value()),
125 },
126 )
127 }
128 ParameterValue::ULong(ul) => {
129 let hlulong = hlulong::create(builder, &hlulongArgs { value: *ul });
130 Parameter::create(
131 builder,
132 &ParameterArgs {
133 value_type: FbParameterValue::hlulong,
134 value: Some(hlulong.as_union_value()),
135 },
136 )
137 }
138 ParameterValue::Float(f) => {
139 let hlfloat = hlfloat::create(builder, &hlfloatArgs { value: *f });
140 Parameter::create(
141 builder,
142 &ParameterArgs {
143 value_type: FbParameterValue::hlfloat,
144 value: Some(hlfloat.as_union_value()),
145 },
146 )
147 }
148 ParameterValue::Double(d) => {
149 let hldouble = hldouble::create(builder, &hldoubleArgs { value: *d });
150 Parameter::create(
151 builder,
152 &ParameterArgs {
153 value_type: FbParameterValue::hldouble,
154 value: Some(hldouble.as_union_value()),
155 },
156 )
157 }
158 ParameterValue::Bool(b) => {
159 let hlbool = hlbool::create(builder, &hlboolArgs { value: *b });
160 Parameter::create(
161 builder,
162 &ParameterArgs {
163 value_type: FbParameterValue::hlbool,
164 value: Some(hlbool.as_union_value()),
165 },
166 )
167 }
168 ParameterValue::String(s) => {
169 let val = builder.create_string(s.as_str());
170 let hlstring =
171 hlstring::create(builder, &hlstringArgs { value: Some(val) });
172 Parameter::create(
173 builder,
174 &ParameterArgs {
175 value_type: FbParameterValue::hlstring,
176 value: Some(hlstring.as_union_value()),
177 },
178 )
179 }
180 ParameterValue::VecBytes(v) => {
181 let vec_bytes = builder.create_vector(v);
182 let hlvecbytes = hlvecbytes::create(
183 builder,
184 &hlvecbytesArgs {
185 value: Some(vec_bytes),
186 },
187 );
188 Parameter::create(
189 builder,
190 &ParameterArgs {
191 value_type: FbParameterValue::hlvecbytes,
192 value: Some(hlvecbytes.as_union_value()),
193 },
194 )
195 }
196 })
197 .collect();
198 Some(builder.create_vector(¶meter_offsets))
199 }
200 _ => None,
201 };
202
203 let function_call = FbFunctionCall::create(
204 builder,
205 &FbFunctionCallArgs {
206 function_name: Some(function_name),
207 parameters,
208 function_call_type,
209 expected_return_type,
210 },
211 );
212 builder.finish_size_prefixed(function_call, None);
213 builder.finished_data()
214 }
215}
216
217#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
218pub fn validate_guest_function_call_buffer(function_call_buffer: &[u8]) -> Result<()> {
219 let guest_function_call_fb = size_prefixed_root::<FbFunctionCall>(function_call_buffer)
220 .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?;
221 match guest_function_call_fb.function_call_type() {
222 FbFunctionCallType::guest => Ok(()),
223 other => {
224 bail!("Invalid function call type: {:?}", other);
225 }
226 }
227}
228
229#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
230pub fn validate_host_function_call_buffer(function_call_buffer: &[u8]) -> Result<()> {
231 let host_function_call_fb = size_prefixed_root::<FbFunctionCall>(function_call_buffer)
232 .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?;
233 match host_function_call_fb.function_call_type() {
234 FbFunctionCallType::host => Ok(()),
235 other => {
236 bail!("Invalid function call type: {:?}", other);
237 }
238 }
239}
240
241impl TryFrom<&[u8]> for FunctionCall {
242 type Error = Error;
243 #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
244 fn try_from(value: &[u8]) -> Result<Self> {
245 let function_call_fb = size_prefixed_root::<FbFunctionCall>(value)
246 .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?;
247 let function_name = function_call_fb.function_name();
248 let function_call_type = match function_call_fb.function_call_type() {
249 FbFunctionCallType::guest => FunctionCallType::Guest,
250 FbFunctionCallType::host => FunctionCallType::Host,
251 other => {
252 bail!("Invalid function call type: {:?}", other);
253 }
254 };
255 let expected_return_type = function_call_fb.expected_return_type().try_into()?;
256
257 let parameters = function_call_fb
258 .parameters()
259 .map(|v| {
260 v.iter()
261 .map(|p| p.try_into())
262 .collect::<Result<Vec<ParameterValue>>>()
263 })
264 .transpose()?;
265
266 Ok(Self {
267 function_name: function_name.to_string(),
268 parameters,
269 function_call_type,
270 expected_return_type,
271 })
272 }
273}
274
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::flatbuffer_wrappers::function_types::ReturnType;

    /// Round-trips a call carrying one argument of each scalar/string kind and
    /// checks that the name, parameters and call type survive decoding.
    #[test]
    fn read_from_flatbuffer() -> Result<()> {
        let expected_parameters = vec![
            ParameterValue::String("1".to_string()),
            ParameterValue::Int(2),
            ParameterValue::Long(3),
            ParameterValue::String("4".to_string()),
            ParameterValue::String("5".to_string()),
            ParameterValue::Bool(true),
            ParameterValue::Bool(false),
            ParameterValue::UInt(8),
            ParameterValue::ULong(9),
            ParameterValue::Int(10),
            ParameterValue::Float(3.123),
            ParameterValue::Double(0.01),
        ];

        let mut builder = FlatBufferBuilder::new();
        let test_data = FunctionCall::new(
            "PrintTwelveArgs".to_string(),
            Some(expected_parameters.clone()),
            FunctionCallType::Guest,
            ReturnType::Int,
        )
        .encode(&mut builder);

        let function_call = FunctionCall::try_from(test_data)?;
        assert_eq!(function_call.function_name, "PrintTwelveArgs");
        let Some(parameters) = function_call.parameters else {
            panic!("expected decoded parameters to be present");
        };
        assert_eq!(parameters.len(), 12);
        assert!(expected_parameters == parameters);
        assert_eq!(function_call.function_call_type, FunctionCallType::Guest);

        Ok(())
    }

    /// Covers the byte-vector parameter arm, which the test above does not exercise.
    #[test]
    fn vec_bytes_round_trip() -> Result<()> {
        let expected_parameters = vec![ParameterValue::VecBytes(vec![1, 2, 3])];

        let mut builder = FlatBufferBuilder::new();
        let buffer = FunctionCall::new(
            "TakesBytes".to_string(),
            Some(expected_parameters.clone()),
            FunctionCallType::Host,
            ReturnType::Int,
        )
        .encode(&mut builder);

        let function_call = FunctionCall::try_from(buffer)?;
        let Some(parameters) = function_call.parameters else {
            panic!("expected decoded parameters to be present");
        };
        assert!(expected_parameters == parameters);
        assert_eq!(function_call.function_call_type, FunctionCallType::Host);

        Ok(())
    }

    /// `encode` writes an empty parameter list as an absent field, so it decodes as `None`.
    #[test]
    fn empty_parameters_decode_as_none() -> Result<()> {
        let mut builder = FlatBufferBuilder::new();
        let buffer = FunctionCall::new(
            "NoArgs".to_string(),
            Some(vec![]),
            FunctionCallType::Guest,
            ReturnType::Int,
        )
        .encode(&mut builder);

        let function_call = FunctionCall::try_from(buffer)?;
        assert_eq!(function_call.function_name, "NoArgs");
        assert!(function_call.parameters.is_none());
        assert_eq!(function_call.function_call_type(), FunctionCallType::Guest);

        Ok(())
    }

    /// The guest/host validators accept only their own call type and reject unreadable buffers.
    #[test]
    fn validators_check_call_type() -> Result<()> {
        let mut host_builder = FlatBufferBuilder::new();
        let host_call = FunctionCall::new(
            "HostFunc".to_string(),
            None,
            FunctionCallType::Host,
            ReturnType::Int,
        )
        .encode(&mut host_builder);
        validate_host_function_call_buffer(host_call)?;
        assert!(validate_guest_function_call_buffer(host_call).is_err());

        let mut guest_builder = FlatBufferBuilder::new();
        let guest_call = FunctionCall::new(
            "GuestFunc".to_string(),
            None,
            FunctionCallType::Guest,
            ReturnType::Int,
        )
        .encode(&mut guest_builder);
        validate_guest_function_call_buffer(guest_call)?;
        assert!(validate_host_function_call_buffer(guest_call).is_err());

        assert!(validate_guest_function_call_buffer(&[]).is_err());
        assert!(validate_host_function_call_buffer(&[]).is_err());

        Ok(())
    }
}