use std::collections::HashMap;
use std::io::{IsTerminal, Write};
use hyperlight_common::flatbuffer_wrappers::function_types::{
ParameterType, ParameterValue, ReturnType, ReturnValue,
};
use hyperlight_common::flatbuffer_wrappers::host_function_definition::HostFunctionDefinition;
use hyperlight_common::flatbuffer_wrappers::host_function_details::HostFunctionDetails;
use termcolor::{Color, ColorChoice, ColorSpec, StandardStream, WriteColor};
use tracing::{Span, instrument};
use crate::HyperlightError::HostFunctionNotFound;
use crate::Result;
use crate::func::host_functions::TypeErasedHostFunction;
/// Registry of host functions, keyed by the name they were registered under.
///
/// Entries are added with `register_host_function` and looked up by name
/// when a host function is called. The derived `Default` yields a registry
/// with an empty map.
#[derive(Default)]
pub struct FunctionRegistry {
// Maps a host function's registered name to its callable and signature.
functions_map: HashMap<String, FunctionEntry>,
}
impl From<&mut FunctionRegistry> for HostFunctionDetails {
fn from(registry: &mut FunctionRegistry) -> Self {
let host_functions = registry
.functions_map
.iter()
.map(|(name, entry)| HostFunctionDefinition {
function_name: name.clone(),
parameter_types: Some(entry.parameter_types.to_vec()),
return_type: entry.return_type,
})
.collect();
HostFunctionDetails {
host_functions: Some(host_functions),
}
}
}
/// A single registered host function together with its signature.
pub struct FunctionEntry {
/// The callable that is invoked with the call's arguments.
pub function: TypeErasedHostFunction,
/// Parameter types reported for this function in its `HostFunctionDefinition`.
pub parameter_types: &'static [ParameterType],
/// Return type reported for this function in its `HostFunctionDefinition`.
pub return_type: ReturnType,
}
impl FunctionRegistry {
    /// Stores `func` in the registry under `name`.
    ///
    /// An entry already registered under the same name is replaced. This
    /// currently always returns `Ok(())`.
    #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
    pub(crate) fn register_host_function(
        &mut self,
        name: String,
        func: FunctionEntry,
    ) -> Result<()> {
        // The previously registered entry (if any) is intentionally discarded.
        let _ = self.functions_map.insert(name, func);
        Ok(())
    }

    /// Calls the host function registered as `"HostPrint"` with `msg` as its
    /// only argument and converts the returned value to an `i32`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the call itself. If the returned value cannot
    /// be converted to `i32`, the failure is reported as
    /// `HostFunctionNotFound("HostPrint")`.
    #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
    #[allow(dead_code)]
    pub(super) fn host_print(&mut self, msg: String) -> Result<i32> {
        let print_args = vec![ParameterValue::String(msg)];
        let reply = self.call_host_func_impl("HostPrint", print_args)?;

        let converted: std::result::Result<i32, _> = reply.try_into();
        match converted {
            Ok(value) => Ok(value),
            // NOTE(review): a conversion failure is surfaced as "not found";
            // kept as-is to preserve the existing error contract.
            Err(_) => Err(HostFunctionNotFound("HostPrint".to_string())),
        }
    }

    /// Calls the host function registered under `name` with `args`.
    ///
    /// # Errors
    ///
    /// Returns `HostFunctionNotFound` when nothing is registered under `name`,
    /// or whatever error the invoked function produces.
    #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
    pub(super) fn call_host_function(
        &self,
        name: &str,
        args: Vec<ParameterValue>,
    ) -> Result<ReturnValue> {
        Self::call_host_func_impl(self, name, args)
    }

    /// Shared implementation behind `call_host_function` and `host_print`:
    /// looks up `name` and invokes the stored callable through
    /// `crate::metrics::maybe_time_and_emit_host_call`.
    #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
    fn call_host_func_impl(&self, name: &str, args: Vec<ParameterValue>) -> Result<ReturnValue> {
        let entry = match self.functions_map.get(name) {
            Some(found) => found,
            None => return Err(HostFunctionNotFound(name.to_string())),
        };

        // Only the callable is needed here; the recorded signature is unused.
        let target = &entry.function;
        crate::metrics::maybe_time_and_emit_host_call(name, || target.call(args))
    }
}
/// Default writer used for host printing: writes `s` to standard output.
///
/// When stdout is a terminal the text is written with a green foreground
/// (subject to `ColorChoice::Auto`) and the color is reset afterwards;
/// otherwise the bytes are written to stdout unchanged.
///
/// # Returns
///
/// The length of `s` in bytes, saturated at `i32::MAX` for strings whose
/// length does not fit in an `i32`.
///
/// # Errors
///
/// Returns an error if setting the color, writing the text, or resetting the
/// color fails. Write failures on a non-terminal stdout (for example a closed
/// pipe) are reported as errors rather than panicking.
#[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
pub(super) fn default_writer_func(s: String) -> Result<i32> {
    if std::io::stdout().is_terminal() {
        let mut stdout = StandardStream::stdout(ColorChoice::Auto);
        let mut color_spec = ColorSpec::new();
        color_spec.set_fg(Some(Color::Green));
        stdout.set_color(&color_spec)?;
        stdout.write_all(s.as_bytes())?;
        stdout.reset()?;
    } else {
        // `print!` panics when the write to stdout fails; use a fallible
        // write so the failure is propagated like in the terminal branch.
        std::io::stdout().write_all(s.as_bytes())?;
    }

    // A plain `as i32` cast would wrap (possibly to a negative count) for
    // very large strings; saturate instead.
    Ok(i32::try_from(s.len()).unwrap_or(i32::MAX))
}