#![feature(fs_read_write, try_from)]
extern crate serde;
extern crate serde_json;
extern crate tvm;
use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
use tvm::runtime::Graph;
/// Loads the checked-in `tests/graph.params` and `tests/graph.json` fixtures
/// and spot-checks the structure of the deserialized graph.
#[test]
fn test_load_graph() {
    // `fs::read` replaces the manual `File::open` + `read_to_end` sequence.
    let params_bytes = fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
        .expect("failed to read tests/graph.params");
    // NOTE(review): the return value is never inspected. If `load_param_dict`
    // returns a `Result`, consider unwrapping it so a malformed params file
    // fails this test — confirm the signature.
    let _params = tvm::runtime::load_param_dict(&params_bytes);

    let graph_json =
        fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json"))
            .expect("failed to read tests/graph.json");
    let graph = Graph::try_from(&graph_json).expect("failed to parse tests/graph.json");

    // Node 3 is expected to be an operator node carrying a `func_name` attribute.
    assert_eq!(graph.nodes[3].op, "tvm_op");
    assert_eq!(
        graph.nodes[3]
            .attrs
            .as_ref()
            .expect("node 3 has no attrs")
            .get("func_name")
            .expect("node 3 has no func_name attr"),
        "fuse_dense"
    );
    // Nodes 5 and 6 read outputs 0 and 1 of their first input.
    assert_eq!(graph.nodes[5].inputs[0].index, 0);
    assert_eq!(graph.nodes[6].inputs[0].index, 1);
    assert_eq!(graph.heads.len(), 2);
}