tflite 0.9.8

Rust bindings for TensorFlow Lite
use std::ffi::CStr;
use std::fmt;

use crate::bindings;

/// Element type tag of a tensor; an alias for the raw `TfLiteType` binding.
pub type ElementKind = bindings::TfLiteType;
/// Quantization parameters of a tensor; an alias for the raw
/// `TfLiteQuantizationParams` binding.
pub type QuantizationParams = bindings::TfLiteQuantizationParams;

/// Associates a Rust type with the [`ElementKind`] tag that describes it.
pub trait ElemKindOf {
    /// Returns the [`ElementKind`] corresponding to the implementing type.
    fn elem_kind_of() -> ElementKind;
}

impl ElemKindOf for f32 {
    fn elem_kind_of() -> ElementKind {
        bindings::TfLiteType::kTfLiteFloat32
    }
}

impl ElemKindOf for u8 {
    fn elem_kind_of() -> ElementKind {
        bindings::TfLiteType::kTfLiteUInt8
    }
}

impl ElemKindOf for i32 {
    fn elem_kind_of() -> ElementKind {
        bindings::TfLiteType::kTfLiteInt32
    }
}

pub struct TensorInfo {
    pub name: String,
    pub element_kind: ElementKind,
    pub dims: Vec<usize>,
}

impl fmt::Debug for TensorInfo {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TensorInfo")
            .field("name", &self.name)
            .field("element_kind", &self.element_kind)
            .field("dims", &self.dims)
            .finish()
    }
}

/// Builds an owned [`TensorInfo`] snapshot from a borrowed raw tensor.
///
/// * `name` is copied out of the tensor's C string. A null name pointer
///   yields an empty string, and bytes that are not valid UTF-8 are replaced
///   with U+FFFD instead of panicking.
/// * `element_kind` is copied from the tensor's `type_` field.
/// * `dims` is copied from the tensor's dimension array. A null `dims`
///   pointer or a negative `size` yields an empty vector.
impl<'a> From<&'a bindings::TfLiteTensor> for TensorInfo {
    fn from(t: &'a bindings::TfLiteTensor) -> Self {
        let name = if t.name.is_null() {
            // `CStr::from_ptr` on a null pointer is undefined behaviour.
            String::new()
        } else {
            // SAFETY: `t.name` is non-null (checked above) and is assumed to
            // point to a NUL-terminated string that stays valid for as long as
            // `t` is borrowed.
            unsafe { CStr::from_ptr(t.name) }
                .to_string_lossy()
                .into_owned()
        };

        let dims = if t.dims.is_null() {
            // Dereferencing a null `dims` pointer would be undefined behaviour.
            Vec::new()
        } else {
            // SAFETY: `t.dims` is non-null (checked above) and is assumed to
            // point to an int array whose `data` holds at least `size`
            // elements, valid for as long as `t` is borrowed.
            let extents = unsafe {
                let raw = &*t.dims;
                // Clamp so a negative `size` cannot wrap into a huge length.
                raw.data.as_slice(raw.size.max(0) as usize)
            };
            // Same element conversion as before; each entry is expected to be
            // non-negative.
            extents.iter().map(|&n| n as usize).collect()
        };

        Self {
            name,
            element_kind: t.type_,
            dims,
        }
    }
}