use std::convert::TryInto;
use crate::runtime::Function;
use crate::{runtime::function::Result, runtime::ByteArray, Context, Module, NDArray};
/// A handle to a TVM graph runtime, backed by a single runtime [`Module`].
///
/// Build one with [`GraphRt::from_module`] or [`GraphRt::create_from_parts`].
pub struct GraphRt {
    /// Runtime module on which the packed functions `load_params`, `set_input`,
    /// `run`, and `get_output` are looked up.
    module: Module,
}
impl GraphRt {
    /// Builds a graph runtime from a module that exposes a `"default"` factory function.
    ///
    /// The `"default"` function is fetched from `module`, converted into a
    /// `Fn(Context) -> Result<Module>`, and called with `ctx`; the module it returns
    /// becomes this runtime's module.
    ///
    /// # Errors
    ///
    /// Returns an error if `"default"` cannot be retrieved from `module`, or if
    /// invoking it with `ctx` fails.
    pub fn from_module(module: Module, ctx: Context) -> Result<GraphRt> {
        let default: Box<dyn Fn(Context) -> Result<Module>> =
            module.get_function("default", false)?.into();
        Ok(Self {
            module: default(ctx)?,
        })
    }

    /// Creates a graph runtime from a graph description, a compiled library module,
    /// and a target context.
    ///
    /// Calls the global packed function `tvm.graph_runtime.create` with, in order:
    /// `graph`, `lib`, the context's device type, and its device id, then converts the
    /// returned value into the runtime [`Module`].
    ///
    /// # Panics
    ///
    /// Panics if the global function `tvm.graph_runtime.create` cannot be found.
    ///
    /// # Errors
    ///
    /// Returns an error if the create call fails or its result cannot be converted
    /// into a [`Module`].
    pub fn create_from_parts(graph: &str, lib: Module, ctx: Context) -> Result<Self> {
        let runtime_create_fn = Function::get("tvm.graph_runtime.create")
            .expect("global function `tvm.graph_runtime.create` is not registered");
        let module: Module = runtime_create_fn
            .invoke(vec![
                graph.into(),
                lib.into(),
                (&ctx.device_type).into(),
                // NOTE(review): `as i32` silently truncates if `device_id` does not fit
                // in an i32; consider a checked conversion once the error type allows it.
                (ctx.device_id as i32).into(),
            ])?
            .try_into()?;
        Ok(Self { module })
    }

    /// Loads parameters into the runtime via its `load_params` packed function.
    ///
    /// `params` is converted into a [`ByteArray`] and passed by reference.
    ///
    /// # Errors
    ///
    /// Returns an error if `load_params` cannot be retrieved or its invocation fails.
    pub fn load_params<P>(&mut self, params: P) -> Result<()>
    where
        P: Into<ByteArray>,
    {
        let load_param_fn = self.module.get_function("load_params", false)?;
        let params: ByteArray = params.into();
        load_param_fn.invoke(vec![(&params).into()])?;
        Ok(())
    }

    /// Binds `input` to the graph input called `name` via the `set_input` packed function.
    ///
    /// # Errors
    ///
    /// Returns an error if `set_input` cannot be retrieved or its invocation fails.
    pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> {
        let set_input_fn = self.module.get_function("set_input", false)?;
        set_input_fn.invoke(vec![name.into(), input.into()])?;
        Ok(())
    }

    /// Executes the graph by calling the `run` packed function with no arguments.
    ///
    /// # Errors
    ///
    /// Returns an error if `run` cannot be retrieved or its invocation fails.
    pub fn run(&mut self) -> Result<()> {
        let run_fn = self.module.get_function("run", false)?;
        run_fn.invoke(vec![])?;
        Ok(())
    }

    /// Fetches output number `i` as a new [`NDArray`] via the `get_output` packed function.
    ///
    /// # Errors
    ///
    /// Returns an error if `get_output` cannot be retrieved, its invocation fails, or
    /// its return value cannot be converted into an [`NDArray`].
    pub fn get_output(&mut self, i: i64) -> Result<NDArray> {
        let get_output_fn = self.module.get_function("get_output", false)?;
        get_output_fn.invoke(vec![i.into()])?.try_into()
    }

    /// Calls `get_output` with output index `i` and the caller-supplied `output` array.
    ///
    /// The return value of the packed call is discarded. Presumably the runtime copies
    /// output `i` into `output`; confirm against the graph runtime's `get_output`.
    ///
    /// # Errors
    ///
    /// Returns an error if `get_output` cannot be retrieved or its invocation fails.
    pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> {
        let get_output_fn = self.module.get_function("get_output", false)?;
        get_output_fn.invoke(vec![i.into(), output.into()])?;
        Ok(())
    }
}