use std::collections::HashMap;
use std::sync::Arc;
use std::{fs, path};
use model::{Model, Node, OutletId, RawModel};
use ops;
use tfpb;
use Result;
/// Loads a `Model` from the file located at `p`.
///
/// The file is opened and handed to `for_reader`.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, or if `for_reader` fails.
pub fn for_path<P: AsRef<path::Path>>(p: P) -> Result<Model> {
    let file = fs::File::open(p)?;
    for_reader(file)
}
/// Loads a `Model` from any `Read` source.
///
/// The reader is first decoded into a `GraphDef` by `graphdef_for_reader`,
/// and the resulting graph is then converted by `from_tf`.
///
/// # Errors
///
/// Propagates any error raised by either of those two steps.
pub fn for_reader<R: ::std::io::Read>(r: R) -> Result<Model> {
    let graph = graphdef_for_reader(r)?;
    from_tf(graph)
}
/// Parses a protobuf-encoded `GraphDef` message out of `r`.
///
/// # Errors
///
/// Returns an error if `protobuf::parse_from_reader` fails.
pub fn graphdef_for_reader<R: ::std::io::Read>(mut r: R) -> Result<::tfpb::graph::GraphDef> {
    let graph = ::protobuf::parse_from_reader::<::tfpb::graph::GraphDef>(&mut r)?;
    Ok(graph)
}
/// Parses a protobuf-encoded `GraphDef` message from the file located at `p`.
///
/// The file is opened and handed to `graphdef_for_reader`.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, or if
/// `graphdef_for_reader` fails.
pub fn graphdef_for_path<P: AsRef<path::Path>>(p: P) -> Result<::tfpb::graph::GraphDef> {
    let file = fs::File::open(p)?;
    graphdef_for_reader(file)
}
/// Converts a parsed `GraphDef` into a `Model`.
///
/// Nodes are visited in the order in which they appear in `graph`; each one
/// receives an id equal to its position in that order. Input references are
/// resolved against the nodes registered so far, so an input must name a node
/// that appears earlier in the graph.
///
/// An input reference takes one of three forms:
///
/// * `^name` — resolved to outlet 0 of node `name`;
/// * `name:N` — outlet `N` of node `name`;
/// * `name` — outlet 0 of node `name`.
///
/// # Errors
///
/// Returns an error, prefixed with `While building node <name>, `, if:
///
/// * an input names a node that has not been registered yet;
/// * the outlet index after `:` does not parse as a `usize`;
/// * the op builder fails to build the node's op.
pub fn from_tf(graph: tfpb::graph::GraphDef) -> Result<Model> {
    let pbnodes = graph.get_node();
    // The final sizes are known up front: one entry per protobuf node.
    let mut nodes = Vec::with_capacity(pbnodes.len());
    let mut nodes_by_name: HashMap<String, usize> = HashMap::with_capacity(pbnodes.len());
    let op_builder = ops::OpBuilder::new();
    for pbnode in pbnodes.iter() {
        let name = pbnode.get_name().to_string();
        let inputs: Vec<OutletId> = pbnode
            .get_input()
            .iter()
            .map(|i| {
                // Split the reference into the node name and an optional,
                // still unparsed, outlet index.
                let (node_name, slot) = if i.starts_with('^') {
                    // Drop only the leading marker, without allocating.
                    (&i[1..], None)
                } else {
                    let mut parts = i.splitn(2, ':');
                    // `splitn` always yields at least one item, so the
                    // fallback is never actually used.
                    let node_name = parts.next().unwrap_or("");
                    (node_name, parts.next())
                };
                // Resolve the node before parsing the outlet index so that an
                // unknown node is reported ahead of a malformed index. The
                // error message is only built when the lookup fails.
                let node_id = *nodes_by_name
                    .get(node_name)
                    .ok_or_else(|| format!("No node {} found", i))?;
                let outlet = match slot {
                    Some(s) => s.parse::<usize>()?,
                    None => 0,
                };
                Ok(OutletId::new(node_id, outlet))
            })
            .collect::<Result<Vec<_>>>()
            .map_err(|e| format!("While building node {}, {}", name, e.description()))?;
        let node = Node {
            id: nodes.len(),
            name: name.clone(),
            op_name: pbnode.get_op().to_string(),
            inputs,
            op: op_builder
                .build(&pbnode)
                .map_err(|e| format!("While building node {}, {}", name, e.description()))?,
        };
        // Register the node only once it is fully built; its id is its index
        // in `nodes`.
        nodes_by_name.insert(name, nodes.len());
        nodes.push(node);
    }
    Ok(Model(Arc::new(RawModel {
        nodes,
        nodes_by_name,
    })))
}