tract-tensorflow 0.3.3

Tiny, no-nonsense, self contained, TensorFlow and ONNX inference
Documentation
use crate::tfpb::tensor::TensorProto;
use crate::tfpb::tensor_shape::{TensorShapeProto, TensorShapeProto_Dim};
use crate::tfpb::types::DataType;
use std::convert::{TryFrom, TryInto};
use tract_core::internal::*;

impl TryFrom<DataType> for DatumType {
    type Error = TractError;
    fn try_from(t: DataType) -> TractResult<DatumType> {
        match t {
            DataType::DT_BOOL => Ok(DatumType::Bool),
            DataType::DT_UINT8 => Ok(DatumType::U8),
            DataType::DT_UINT16 => Ok(DatumType::U16),
            DataType::DT_INT8 => Ok(DatumType::I8),
            DataType::DT_INT16 => Ok(DatumType::I16),
            DataType::DT_INT32 => Ok(DatumType::I32),
            DataType::DT_INT64 => Ok(DatumType::I64),
            DataType::DT_HALF => Ok(DatumType::F16),
            DataType::DT_FLOAT => Ok(DatumType::F32),
            DataType::DT_DOUBLE => Ok(DatumType::F64),
            DataType::DT_STRING => Ok(DatumType::String),
            _ => Err(format!("Unknown DatumType {:?}", t))?,
        }
    }
}

impl<'a> TryFrom<&'a TensorShapeProto> for TVec<usize> {
    type Error = TractError;
    fn try_from(t: &'a TensorShapeProto) -> TractResult<TVec<usize>> {
        Ok(t.get_dim().iter().map(|d| d.size as usize).collect::<TVec<_>>())
    }
}

impl TryFrom<DatumType> for DataType  {
    type Error = TractError;
    fn try_from(dt: DatumType) -> TractResult<DataType> {
        match dt {
            DatumType::Bool => Ok(DataType::DT_BOOL),
            DatumType::U8 => Ok(DataType::DT_UINT8),
            DatumType::U16 => Ok(DataType::DT_UINT16),
            DatumType::I8 => Ok(DataType::DT_INT8),
            DatumType::I16 => Ok(DataType::DT_INT16),
            DatumType::I32 => Ok(DataType::DT_INT32),
            DatumType::I64 => Ok(DataType::DT_INT64),
            DatumType::F16 => Ok(DataType::DT_HALF),
            DatumType::F32 => Ok(DataType::DT_FLOAT),
            DatumType::F64 => Ok(DataType::DT_DOUBLE),
            DatumType::String => Ok(DataType::DT_STRING),
            DatumType::TDim => bail!("Dimension is not translatable in protobuf"),
        }
    }
}

impl<'a> TryFrom<&'a TensorProto> for Tensor {
    type Error = TractError;
    fn try_from(t: &TensorProto) -> TractResult<Tensor> {
        let dtype = t.get_dtype();
        let dims: TVec<usize> = t.get_tensor_shape().try_into()?;
        let rank = dims.len();
        let content = t.get_tensor_content();
        let mat: Tensor = if content.len() != 0 {
            unsafe {
                match dtype {
                    DataType::DT_FLOAT => Self::from_raw::<f32>(&dims, content)?,
                    DataType::DT_INT32 => Self::from_raw::<i32>(&dims, content)?,
                    DataType::DT_INT64 => Self::from_raw::<i64>(&dims, content)?,
                    _ => unimplemented!("missing type {:?}", dtype),
                }
            }
        } else {
            use ndarray::Array;
            match dtype {
                DataType::DT_INT32 => {
                    Array::from_shape_vec(&*dims, t.get_int_val().to_vec())?.into()
                }
                DataType::DT_INT64 => {
                    Array::from_shape_vec(&*dims, t.get_int64_val().to_vec())?.into()
                }
                DataType::DT_FLOAT => {
                    Array::from_shape_vec(&*dims, t.get_float_val().to_vec())?.into()
                }
                _ => unimplemented!("missing type {:?}", dtype),
            }
        };
        assert_eq!(rank, mat.shape().len());
        Ok(mat)
    }
}

impl<'a> TryFrom<&'a Tensor> for TensorProto {
    type Error = TractError;
    fn try_from(from: &Tensor) -> TractResult<TensorProto> {
        let mut shape = TensorShapeProto::new();
        let dims = from
            .shape()
            .iter()
            .map(|d| {
                let mut dim = TensorShapeProto_Dim::new();
                dim.size = *d as _;
                dim
            })
            .collect();
        shape.set_dim(::protobuf::RepeatedField::from_vec(dims));
        let mut tensor = TensorProto::new();
        tensor.set_tensor_shape(shape);
        match from.datum_type() {
            DatumType::F32 => {
                tensor.set_dtype(DatumType::F32.try_into()?);
                tensor.set_float_val(from.to_array_view::<f32>()?.iter().cloned().collect());
            }
            DatumType::F64 => {
                tensor.set_dtype(DatumType::F64.try_into()?);
                tensor.set_double_val(from.to_array_view::<f64>()?.iter().cloned().collect());
            }
            DatumType::I32 => {
                tensor.set_dtype(DatumType::I32.try_into()?);
                tensor.set_int_val(from.to_array_view::<i32>()?.iter().cloned().collect());
            }
            DatumType::I64 => {
                tensor.set_dtype(DatumType::I64.try_into()?);
                tensor.set_int64_val(from.to_array_view::<i64>()?.iter().cloned().collect());
            }
            _ => unimplemented!("missing type {:?}", from.datum_type()),
        }
        Ok(tensor)
    }
}