use crate::tfpb::tensorflow::{GraphDef, NodeDef, SavedModel};
use prost::Message;
use std::{fs, path};
use tract_hir::internal::*;
/// State made available to TensorFlow op builders while a graph is parsed.
#[derive(Debug, Default)]
pub struct ParsingContext {
    /// For each referenced node name, the number of outputs the graph needs it
    /// to expose, as filled in by `Tensorflow::parse_graph`: one more than the
    /// highest output slot any input reference uses, and at least 1.
    pub node_output_arities: HashMap<String, usize>,
}
/// Registry mapping TensorFlow op type names (the `op` field of a `NodeDef`)
/// to the builder functions that turn such a node into a tract inference op.
#[derive(Clone, Default)]
pub struct TfOpRegister(
    pub HashMap<String, fn(&ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>>>,
);
impl TfOpRegister {
    /// Register `builder` as the constructor for TensorFlow op type `s`.
    ///
    /// `s` accepts anything convertible into a `String`: string literals keep
    /// working, and op names computed at runtime can be registered as well.
    /// A builder previously registered under the same name is replaced.
    pub fn insert(
        &mut self,
        s: impl Into<String>,
        builder: fn(&ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>>,
    ) {
        self.0.insert(s.into(), builder);
    }
}
/// TensorFlow front-end: decodes TF protobufs and converts them to tract models.
pub struct Tensorflow {
    /// Builders used by `parse_graph`, keyed by TF op type. Op types missing
    /// from this register are turned into `UnimplementedOp` placeholders.
    pub op_register: TfOpRegister,
}
/// TensorFlow-specific information collected by `Tensorflow::parse_graph`
/// alongside the inference model.
pub struct TfModelExtensions {
    /// Control-dependency pairs `(node_id, dependency_node_id)`: the first node
    /// had a `^name` input naming the second in the original graph.
    pub control_inputs: Vec<(usize, usize)>,
    /// Ids of nodes whose output 0 `preproc` evaluates to initialize variables.
    /// `parse_graph` leaves this empty.
    pub initializing_nodes: Vec<usize>,
}
impl TfModelExtensions {
pub fn preproc(&self, mut original: InferenceModel) -> TractResult<InferenceModel> {
if self.initializing_nodes.len() > 0 {
let as_outlets =
self.initializing_nodes.iter().map(|n| OutletId::new(*n, 0)).collect::<Vec<_>>();
let plan =
SimplePlan::new_for_outputs_and_deps(&original, &as_outlets, &self.control_inputs)?;
let mut state = SimpleState::new(plan)?;
let _outputs = state.run(tvec!())?;
let tensors = state.session_state.tensors;
for node in &mut original.nodes {
if let Some(var) = node.op_as_mut::<crate::ops::vars::VariableV2>() {
if let Some(value) = tensors.get(&var.id) {
var.initializer = Some(value.clone().into_arc_tensor());
}
}
}
}
Ok(original)
}
}
/// Output of `Tensorflow::parse_graph`: the inference model (`.0`) and the
/// TensorFlow-specific extensions gathered while parsing (`.1`).
pub struct TfModelAndExtensions(pub InferenceModel, pub TfModelExtensions);
impl Tensorflow {
    /// Split a TensorFlow input reference into `(node_name, output_slot)`.
    ///
    /// * `"name"` gives `("name", 0)`;
    /// * `"name:3"` gives `("name", 3)` (only the first `:` separates name and slot);
    /// * a control-dependency reference `"^name"` gives `("name", 0)`, with no
    ///   slot parsing.
    ///
    /// # Errors
    /// Fails when the text after the first `:` is not a valid `usize`.
    fn parse_input(i: &str) -> TractResult<(&str, usize)> {
        if let Some(name) = i.strip_prefix('^') {
            return Ok((name, 0));
        }
        match i.split_once(':') {
            Some((name, slot)) => Ok((name, slot.parse::<usize>()?)),
            None => Ok((i, 0)),
        }
    }

    /// Pin the seeds of unseeded `RandomUniform` nodes.
    ///
    /// Any `RandomUniform` node whose `seed` and `seed2` attributes are both 0
    /// gets both set to 1. NOTE(review): in TensorFlow a 0/0 seed pair asks for
    /// a non-deterministic seed, so this presumably makes outputs reproducible.
    ///
    /// # Errors
    /// Propagates failures from reading the `seed` / `seed2` attributes.
    pub fn determinize(model: &mut GraphDef) -> TractResult<()> {
        for pbnode in &mut model.node {
            if pbnode.op == "RandomUniform"
                && pbnode.get_attr_int::<i64>("seed")? == 0
                && pbnode.get_attr_int::<i64>("seed2")? == 0
            {
                pbnode.attr.insert("seed".to_string(), 1.into());
                pbnode.attr.insert("seed2".to_string(), 1.into());
            }
        }
        Ok(())
    }

    /// Decode a frozen `GraphDef` from the file at `p`.
    ///
    /// On non-wasm targets the file is memory-mapped; on wasm32 it is read
    /// into memory instead.
    pub fn read_frozen_from_path(&self, p: impl AsRef<path::Path>) -> TractResult<GraphDef> {
        // SAFETY: a memory map is only sound if no other process modifies or
        // truncates the file while `map` is alive; this is assumed, not enforced.
        #[cfg(not(target_arch = "wasm32"))]
        let map = unsafe { mapr::Mmap::map(&fs::File::open(p)?)? };
        #[cfg(target_arch = "wasm32")]
        let map = fs::read(p)?;
        Ok(GraphDef::decode(&*map)?)
    }

    /// Read `r` to the end and decode its contents as a frozen `GraphDef`.
    pub fn read_frozen_model(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
        let mut v = vec![];
        r.read_to_end(&mut v)?;
        let b = bytes::Bytes::from(v);
        Ok(GraphDef::decode(b)?)
    }

    /// Read `r` to the end and decode its contents as a `SavedModel`.
    pub fn open_saved_model(&self, r: &mut dyn std::io::Read) -> TractResult<SavedModel> {
        let mut v = vec![];
        r.read_to_end(&mut v)?;
        let b = bytes::Bytes::from(v);
        Ok(SavedModel::decode(b)?)
    }

    /// Decode a `SavedModel` from `r` and return the `GraphDef` of its first
    /// meta graph.
    ///
    /// # Errors
    /// Fails if decoding fails, if the SavedModel holds no meta graph, or if
    /// the first meta graph carries no `graph_def`. These cases used to panic.
    pub fn read_saved_model(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
        let saved = self.open_saved_model(r)?;
        let meta_graph = match saved.meta_graphs.into_iter().next() {
            Some(meta_graph) => meta_graph,
            None => bail!("SavedModel contains no meta graph"),
        };
        match meta_graph.graph_def {
            Some(graph) => Ok(graph),
            None => bail!("First meta graph of SavedModel has no graph_def"),
        }
    }

    /// Convert a TensorFlow `GraphDef` into an `InferenceModel` plus
    /// TF-specific extensions.
    ///
    /// Steps:
    /// 1. compute, for each referenced node, how many outputs it must expose;
    /// 2. create one tract node per TF node. `NextIteration` is split into a
    ///    source node (TF name) and a sink node (`"<name>-Sink"`). `Placeholder`
    ///    nodes become model inputs with their dtype/shape facts;
    /// 3. wire data edges and collect control (`^name`) dependencies;
    /// 4. link every `Assign` to the id of the `VariableV2` feeding its first input.
    ///
    /// The returned extensions have empty `initializing_nodes`.
    ///
    /// # Errors
    /// Fails on malformed input references, unknown node names, builder
    /// errors, or an `Assign` whose first input is not a `VariableV2`.
    pub fn parse_graph(&self, graph: &GraphDef) -> TractResult<TfModelAndExtensions> {
        use crate::ops::control_flow as cf;
        let mut model = InferenceModel::default();
        let mut inputs = tvec!();
        let mut context = ParsingContext::default();
        let mut control_inputs = vec![];

        // Step 1: output arity of each node = highest referenced slot + 1 (min 1).
        for pbnode in &graph.node {
            for i in &pbnode.input {
                let (node, slot) = Self::parse_input(i)?;
                let arity = context.node_output_arities.entry(node.to_string()).or_insert(1);
                *arity = (*arity).max(slot + 1);
            }
        }

        // Step 2: create the nodes.
        for pbnode in &graph.node {
            let name = &pbnode.name;
            if pbnode.op == "NextIteration" {
                // Loop back-edge: the source (keeps the TF name, one output) feeds
                // the next iteration; the sink (no outputs) receives the data inputs.
                let source_op = cf::NextIteration::new(name.clone(), cf::NextIterationRole::Source);
                let sink_op = cf::NextIteration::new(name.clone(), cf::NextIterationRole::Sink);
                let _source =
                    model.add_node(name.clone(), source_op, tvec!(InferenceFact::default()))?;
                let _sink = model.add_node(format!("{}-Sink", name), sink_op, tvec!())?;
                continue;
            }
            let op = match self.op_register.0.get(&pbnode.op) {
                Some(builder) => (builder)(&context, pbnode)?,
                None => tract_hir::ops::unimpl::UnimplementedOp::new(
                    context.node_output_arities.get(name).cloned().unwrap_or(1),
                    &pbnode.op,
                    format!("{:?}", pbnode),
                )
                .into(),
            };
            // Expose at least as many outputs as the graph references.
            let noutputs =
                op.nboutputs()?.max(context.node_output_arities.get(name).cloned().unwrap_or(1));
            let facts = tvec!(InferenceFact::default(); noutputs);
            let node_id = model.add_node(name.clone(), op, facts)?;
            if pbnode.op == "Placeholder" {
                let dt = pbnode.get_attr_datum_type("dtype")?;
                let mut fact = InferenceFact::dt(dt);
                if let Some(shape) = pbnode.get_attr_opt_shape("shape")? {
                    // A -1 dimension is unknown; any other value is fixed.
                    let shape_factoid = ShapeFactoid::closed(
                        shape
                            .iter()
                            .map(|d| {
                                if *d == -1 {
                                    GenericFactoid::Any
                                } else {
                                    GenericFactoid::Only(d.to_dim())
                                }
                            })
                            .collect(),
                    );
                    fact = fact.with_shape(shape_factoid);
                }
                inputs.push(OutletId::new(node_id, 0));
                model.set_outlet_fact(OutletId::new(node_id, 0), fact)?;
            }
        }

        // Step 3: wire data edges (in input order) and collect control dependencies.
        for pbnode in &graph.node {
            // Data inputs of a NextIteration go to its sink node.
            let node_id = if pbnode.op == "NextIteration" {
                model.node_by_name(&*format!("{}-Sink", &pbnode.name))?.id
            } else {
                model.node_by_name(&pbnode.name)?.id
            };
            for (ix, i) in pbnode.input.iter().filter(|n| !n.starts_with("^")).enumerate() {
                let input = Self::parse_input(i)?;
                let prec = model.node_by_name(input.0)?.id;
                let outlet = OutletId::new(prec, input.1);
                let inlet = InletId::new(node_id, ix);
                model.add_edge(outlet, inlet)?;
                // Label the outlet with the original TF reference (e.g. "name:1").
                model.set_outlet_label(outlet, i.to_string())?;
            }
            for i in pbnode.input.iter().filter(|n| n.starts_with("^")) {
                let input = Self::parse_input(i)?;
                let prec = model.node_by_name(input.0)?.id;
                // NOTE(review): for NextIteration this resolves to the source node,
                // not the sink that received the data edges; confirm this is intended.
                control_inputs.push((model.node_id_by_name(&pbnode.name)?, prec));
            }
        }

        // Step 4: each Assign learns the id of the variable feeding its first input.
        for id in 0..model.nodes().len() {
            use crate::ops::vars::*;
            if model.node(id).op_is::<Assign>() {
                let prec = model.node(id).inputs[0];
                let var_id = model.node(prec.node).op_as::<VariableV2>().map(|v| v.id.clone());
                if let (Some(var_id), Some(assign)) =
                    (var_id, model.node_mut(id).op_as_mut::<Assign>())
                {
                    assign.var_id = Some(var_id);
                } else {
                    bail!("Model contains unlinked Assign/Variable2");
                }
            }
        }

        model.set_input_outlets(&*inputs)?;
        model.auto_outputs()?;
        let extensions = TfModelExtensions { control_inputs, initializing_nodes: vec![] };
        Ok(TfModelAndExtensions(model, extensions))
    }
}
impl Framework<GraphDef, InferenceModel> for Tensorflow {
fn proto_model_for_path(&self, r: impl AsRef<path::Path>) -> TractResult<GraphDef> {
self.read_frozen_model(&mut fs::File::open(r.as_ref())?)
.or_else(|_| self.read_saved_model(&mut fs::File::open(r.as_ref())?))
}
fn proto_model_for_read(&self, r: &mut dyn std::io::Read) -> TractResult<GraphDef> {
self.read_frozen_model(r)
}
fn model_for_proto_model(&self, graph: &GraphDef) -> TractResult<InferenceModel> {
Ok(self.parse_graph(graph)?.0)
}
}