use alloc::string::{String, ToString};
use alloc::vec::Vec;
use anyhow::{Error, Result, bail};
use flatbuffers::{FlatBufferBuilder, WIPOffset, size_prefixed_root};
#[cfg(feature = "tracing")]
use tracing::{Span, instrument};
use super::function_types::{ParameterValue, ReturnType};
use crate::flatbuffers::hyperlight::generated::{
FunctionCall as FbFunctionCall, FunctionCallArgs as FbFunctionCallArgs,
FunctionCallType as FbFunctionCallType, Parameter, ParameterArgs,
ParameterValue as FbParameterValue, hlbool, hlboolArgs, hldouble, hldoubleArgs, hlfloat,
hlfloatArgs, hlint, hlintArgs, hllong, hllongArgs, hlstring, hlstringArgs, hluint, hluintArgs,
hlulong, hlulongArgs, hlvecbytes, hlvecbytesArgs,
};
/// Identifies which side of the sandbox boundary a [`FunctionCall`] targets.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FunctionCallType {
    /// The call is dispatched to a function exported by the guest.
    Guest,
    /// The call is dispatched to a function registered by the host.
    Host,
}
/// An owned, decoded description of a single guest or host function call.
///
/// Serialized with [`FunctionCall::encode`] and deserialized with
/// `FunctionCall::try_from(&[u8])`.
#[derive(Clone)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub function_name: String,
    /// Call arguments; `None` when the call carries no parameters.
    pub parameters: Option<Vec<ParameterValue>>,
    /// Whether the call targets the guest or the host; read it through
    /// `FunctionCall::function_call_type()`.
    function_call_type: FunctionCallType,
    /// The return type the caller expects the callee to produce.
    pub expected_return_type: ReturnType,
}
impl FunctionCall {
    /// Creates a new `FunctionCall` from its constituent parts.
    ///
    /// `parameters` is `None` for a call without arguments. Note that
    /// [`FunctionCall::encode`] also serializes `Some(vec![])` as "no parameters",
    /// so such a value decodes back as `None`.
    #[cfg_attr(feature = "tracing", instrument(skip_all, parent = Span::current(), level= "Trace"))]
    pub fn new(
        function_name: String,
        parameters: Option<Vec<ParameterValue>>,
        function_call_type: FunctionCallType,
        expected_return_type: ReturnType,
    ) -> Self {
        Self {
            function_name,
            parameters,
            function_call_type,
            expected_return_type,
        }
    }

    /// Returns whether this call targets the guest or the host.
    pub fn function_call_type(&self) -> FunctionCallType {
        self.function_call_type.clone()
    }

    /// Serializes this call into `builder` as a size-prefixed `FunctionCall`
    /// flatbuffer and returns the finished bytes, which borrow from `builder`.
    ///
    /// An absent or empty parameter list is encoded by omitting the
    /// `parameters` field entirely.
    ///
    /// NOTE(review): this calls `finish_size_prefixed`, so `builder` presumably
    /// must be fresh (or `reset()`) rather than already holding a finished
    /// buffer. Confirm against the flatbuffers version in use.
    pub fn encode<'a>(&self, builder: &'a mut FlatBufferBuilder) -> &'a [u8] {
        // Flatbuffers are built bottom-up. Every child object (string, vector,
        // union table) must be created before the table that references it.
        let function_name = builder.create_string(&self.function_name);
        let function_call_type = match self.function_call_type {
            FunctionCallType::Guest => FbFunctionCallType::guest,
            FunctionCallType::Host => FbFunctionCallType::host,
        };
        let expected_return_type = self.expected_return_type.into();
        let parameters = match &self.parameters {
            Some(p) if !p.is_empty() => {
                // Each parameter is a `Parameter` table holding a union of the
                // typed `hl*` value tables.
                let parameter_offsets: Vec<WIPOffset<Parameter>> = p
                    .iter()
                    .map(|param| match param {
                        ParameterValue::Int(i) => {
                            let hlint = hlint::create(builder, &hlintArgs { value: *i });
                            Parameter::create(
                                builder,
                                &ParameterArgs {
                                    value_type: FbParameterValue::hlint,
                                    value: Some(hlint.as_union_value()),
                                },
                            )
                        }
                        ParameterValue::UInt(ui) => {
                            let hluint = hluint::create(builder, &hluintArgs { value: *ui });
                            Parameter::create(
                                builder,
                                &ParameterArgs {
                                    value_type: FbParameterValue::hluint,
                                    value: Some(hluint.as_union_value()),
                                },
                            )
                        }
                        ParameterValue::Long(l) => {
                            let hllong = hllong::create(builder, &hllongArgs { value: *l });
                            Parameter::create(
                                builder,
                                &ParameterArgs {
                                    value_type: FbParameterValue::hllong,
                                    value: Some(hllong.as_union_value()),
                                },
                            )
                        }
                        ParameterValue::ULong(ul) => {
                            let hlulong = hlulong::create(builder, &hlulongArgs { value: *ul });
                            Parameter::create(
                                builder,
                                &ParameterArgs {
                                    value_type: FbParameterValue::hlulong,
                                    value: Some(hlulong.as_union_value()),
                                },
                            )
                        }
                        ParameterValue::Float(f) => {
                            let hlfloat = hlfloat::create(builder, &hlfloatArgs { value: *f });
                            Parameter::create(
                                builder,
                                &ParameterArgs {
                                    value_type: FbParameterValue::hlfloat,
                                    value: Some(hlfloat.as_union_value()),
                                },
                            )
                        }
                        ParameterValue::Double(d) => {
                            let hldouble = hldouble::create(builder, &hldoubleArgs { value: *d });
                            Parameter::create(
                                builder,
                                &ParameterArgs {
                                    value_type: FbParameterValue::hldouble,
                                    value: Some(hldouble.as_union_value()),
                                },
                            )
                        }
                        ParameterValue::Bool(b) => {
                            let hlbool = hlbool::create(builder, &hlboolArgs { value: *b });
                            Parameter::create(
                                builder,
                                &ParameterArgs {
                                    value_type: FbParameterValue::hlbool,
                                    value: Some(hlbool.as_union_value()),
                                },
                            )
                        }
                        ParameterValue::String(s) => {
                            // The string payload must exist before its wrapper table.
                            let val = builder.create_string(s.as_str());
                            let hlstring =
                                hlstring::create(builder, &hlstringArgs { value: Some(val) });
                            Parameter::create(
                                builder,
                                &ParameterArgs {
                                    value_type: FbParameterValue::hlstring,
                                    value: Some(hlstring.as_union_value()),
                                },
                            )
                        }
                        ParameterValue::VecBytes(v) => {
                            // The byte vector must exist before its wrapper table.
                            let vec_bytes = builder.create_vector(v);
                            let hlvecbytes = hlvecbytes::create(
                                builder,
                                &hlvecbytesArgs {
                                    value: Some(vec_bytes),
                                },
                            );
                            Parameter::create(
                                builder,
                                &ParameterArgs {
                                    value_type: FbParameterValue::hlvecbytes,
                                    value: Some(hlvecbytes.as_union_value()),
                                },
                            )
                        }
                    })
                    .collect();
                Some(builder.create_vector(&parameter_offsets))
            }
            // `None` and an empty list both omit the field.
            _ => None,
        };
        let function_call = FbFunctionCall::create(
            builder,
            &FbFunctionCallArgs {
                function_name: Some(function_name),
                parameters,
                function_call_type,
                expected_return_type,
            },
        );
        builder.finish_size_prefixed(function_call, None);
        builder.finished_data()
    }
}
#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
pub fn validate_guest_function_call_buffer(function_call_buffer: &[u8]) -> Result<()> {
let guest_function_call_fb = size_prefixed_root::<FbFunctionCall>(function_call_buffer)
.map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?;
match guest_function_call_fb.function_call_type() {
FbFunctionCallType::guest => Ok(()),
other => {
bail!("Invalid function call type: {:?}", other);
}
}
}
#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
pub fn validate_host_function_call_buffer(function_call_buffer: &[u8]) -> Result<()> {
let host_function_call_fb = size_prefixed_root::<FbFunctionCall>(function_call_buffer)
.map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?;
match host_function_call_fb.function_call_type() {
FbFunctionCallType::host => Ok(()),
other => {
bail!("Invalid function call type: {:?}", other);
}
}
}
impl TryFrom<&[u8]> for FunctionCall {
type Error = Error;
#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))]
fn try_from(value: &[u8]) -> Result<Self> {
let function_call_fb = size_prefixed_root::<FbFunctionCall>(value)
.map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?;
let function_name = function_call_fb.function_name();
let function_call_type = match function_call_fb.function_call_type() {
FbFunctionCallType::guest => FunctionCallType::Guest,
FbFunctionCallType::host => FunctionCallType::Host,
other => {
bail!("Invalid function call type: {:?}", other);
}
};
let expected_return_type = function_call_fb.expected_return_type().try_into()?;
let parameters = function_call_fb
.parameters()
.map(|v| {
v.iter()
.map(|p| p.try_into())
.collect::<Result<Vec<ParameterValue>>>()
})
.transpose()?;
Ok(Self {
function_name: function_name.to_string(),
parameters,
function_call_type,
expected_return_type,
})
}
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::flatbuffer_wrappers::function_types::ReturnType;

    /// Round-trips a guest call carrying every scalar/string parameter kind
    /// through `encode` and `TryFrom<&[u8]>` and checks that every field survives.
    #[test]
    fn read_from_flatbuffer() -> Result<()> {
        let expected_parameters = vec![
            ParameterValue::String("1".to_string()),
            ParameterValue::Int(2),
            ParameterValue::Long(3),
            ParameterValue::String("4".to_string()),
            ParameterValue::String("5".to_string()),
            ParameterValue::Bool(true),
            ParameterValue::Bool(false),
            ParameterValue::UInt(8),
            ParameterValue::ULong(9),
            ParameterValue::Int(10),
            ParameterValue::Float(3.123),
            ParameterValue::Double(0.01),
        ];
        let mut builder = FlatBufferBuilder::new();
        let test_data = FunctionCall::new(
            "PrintTwelveArgs".to_string(),
            Some(expected_parameters.clone()),
            FunctionCallType::Guest,
            ReturnType::Int,
        )
        .encode(&mut builder);

        let function_call = FunctionCall::try_from(test_data)?;
        assert_eq!(function_call.function_name, "PrintTwelveArgs");
        let parameters = function_call
            .parameters
            .expect("encoded call had parameters, decoded call must too");
        assert_eq!(parameters.len(), 12);
        assert!(expected_parameters == parameters);
        assert_eq!(function_call.function_call_type, FunctionCallType::Guest);
        // The return type was previously never checked after the round trip.
        assert!(matches!(function_call.expected_return_type, ReturnType::Int));
        Ok(())
    }

    /// A host call must pass host validation, fail guest validation, and decode
    /// with `parameters == None` when it was encoded without parameters.
    #[test]
    fn validate_and_decode_host_call() -> Result<()> {
        let mut builder = FlatBufferBuilder::new();
        let buffer = FunctionCall::new(
            "HostFunction".to_string(),
            None,
            FunctionCallType::Host,
            ReturnType::Int,
        )
        .encode(&mut builder);

        assert!(validate_host_function_call_buffer(buffer).is_ok());
        assert!(validate_guest_function_call_buffer(buffer).is_err());

        let decoded = FunctionCall::try_from(buffer)?;
        assert_eq!(decoded.function_name, "HostFunction");
        assert_eq!(decoded.function_call_type(), FunctionCallType::Host);
        assert!(decoded.parameters.is_none());
        assert!(matches!(decoded.expected_return_type, ReturnType::Int));
        Ok(())
    }
}