tract-tensorflow 0.3.3

Tiny, no-nonsense, self contained, TensorFlow and ONNX inference
Documentation
#![allow(unused)]
#![allow(deprecated)]
#![allow(non_snake_case)]

error_chain! {
    foreign_links {
        Io(::std::io::Error);
        NdarrayShape(::ndarray::ShapeError);
        Protobuf(::protobuf::ProtobufError);
        StrUtf8(::std::str::Utf8Error);
    }
    errors {
        TFString {}
    }
}

impl ::std::convert::From<::tensorflow::Status> for Error {
    fn from(tfs: ::tensorflow::Status) -> Error {
        format!("Tensorflow error: {:?}", tfs).into()
    }
}

pub mod tf;

pub use protobuf::Message;

use crate::tfpb;
use crate::tfpb::tensor_shape::{TensorShapeProto, TensorShapeProto_Dim};
use crate::tfpb::types::DataType;
use tract_core::internal::*;

pub fn placeholder<Shape: Into<Option<TensorShapeProto>>>(
    name: &str,
    t: DataType,
    shape: Shape,
) -> tfpb::node_def::NodeDef {
    let mut node = tfpb::node().name(name).op("Placeholder").attr("dtype", t);
    if let Some(shape) = shape.into() {
        node = node.attr("shape", shape)
    }
    node
}

pub fn tensor_shape(dims: &[usize]) -> TensorShapeProto {
    let mut shape = TensorShapeProto::new();
    shape.set_dim(
        dims.iter()
            .map(|&d| {
                let mut dim = TensorShapeProto_Dim::new();
                dim.set_size(d as i64);
                dim
            })
            .collect(),
    );
    shape
}

pub fn const_f32(name: &str, t: &Tensor) -> tfpb::node_def::NodeDef {
    let mut tf = crate::tfpb::tensor_f32(
        t.shape().iter().cloned().collect(),
        t.to_array_view::<f32>().unwrap().iter().cloned().collect(),
    );
    tfpb::node().name(name).op("Const").attr("dtype", DataType::DT_FLOAT).attr("value", tf)
}

pub fn placeholder_f32(name: &str) -> tfpb::node_def::NodeDef {
    placeholder(name, DataType::DT_FLOAT, None)
}

pub fn const_i32(name: &str, t: &Tensor) -> tfpb::node_def::NodeDef {
    let mut tf = crate::tfpb::tensor_i32(
        t.shape().iter().cloned().collect(),
        t.to_array_view::<i32>().unwrap().iter().cloned().collect(),
    );
    tfpb::node().name(name).op("Const").attr("dtype", DataType::DT_INT32).attr("value", tf)
}

pub fn placeholder_i32(name: &str) -> tfpb::node_def::NodeDef {
    placeholder(name, DataType::DT_INT32, None)
}