use core::convert::TryInto;
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use crate::micro_error_reporter::MicroErrorReporter;
use crate::micro_op_resolver::OpResolverRepr;
use crate::tensor::{ElemTypeOf, Tensor, TensorInfo};
use crate::Error;
use crate::{model::Model, Status};
use managed::ManagedSlice;
use crate::bindings;
use crate::bindings::tflite;
cpp! {{
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
}}
static mut ERROR_REPORTER: MaybeUninit<MicroErrorReporter> =
MaybeUninit::uninit();
pub struct MicroInterpreter<'a> {
micro_interpreter: tflite::MicroInterpreter,
_phantom: PhantomData<&'a ()>,
}
impl<'a> MicroInterpreter<'a> {
pub fn new<'m: 'a, 't: 'a, TArena, OpResolver>(
model: &'m Model,
resolver: OpResolver,
tensor_arena: TArena,
) -> Result<Self, Error>
where
OpResolver: OpResolverRepr,
TArena: Into<ManagedSlice<'t, u8>>,
{
let resolver = resolver.to_inner();
let mut tensor_arena = tensor_arena.into();
let tensor_arena_size = tensor_arena.len();
let tensor_arena = tensor_arena.as_mut_ptr();
let micro_error_reporter_ref = unsafe {
let micro_error_reporter = MicroErrorReporter::new();
ERROR_REPORTER = MaybeUninit::new(micro_error_reporter);
&ERROR_REPORTER };
let mut status = bindings::TfLiteStatus::kTfLiteError;
let mut micro_interpreter = unsafe {
let status_ref = &mut status;
cpp! ([
model as "const tflite::Model*",
resolver as "tflite::MicroMutableOpResolver<128>",
tensor_arena as "uint8_t*",
tensor_arena_size as "size_t",
micro_error_reporter_ref as "tflite::MicroErrorReporter*",
status_ref as "TfLiteStatus*"
] -> tflite::MicroInterpreter as "tflite::MicroInterpreter"
{
tflite::ErrorReporter* error_reporter = micro_error_reporter_ref;
tflite::MicroInterpreter interpreter(model,
resolver,
tensor_arena,
tensor_arena_size,
error_reporter);
*status_ref = interpreter.initialization_status();
return interpreter;
})
};
if status != bindings::TfLiteStatus::kTfLiteOk {
return Err(Error::InterpreterInitError);
}
let allocate_tensors_status = unsafe {
let interpreter_ref = &mut micro_interpreter;
cpp! ([interpreter_ref as "tflite::MicroInterpreter*"]
-> bindings::TfLiteStatus as "TfLiteStatus" {
return interpreter_ref->AllocateTensors();
})
};
if allocate_tensors_status != bindings::TfLiteStatus::kTfLiteOk {
return Err(Error::AllocateTensorsError);
}
Ok(Self {
micro_interpreter,
_phantom: PhantomData,
})
}
pub fn input_info(&self, n: usize) -> TensorInfo {
let interpreter = &self.micro_interpreter;
let input_tensor: &'a Tensor = unsafe {
let inp = cpp!([
interpreter as "tflite::MicroInterpreter*",
n as "size_t"]
-> *mut bindings::TfLiteTensor as "TfLiteTensor*" {
return interpreter->input(n);
});
assert!(!inp.is_null(), "Obtained nullptr from TensorFlow");
inp.into()
};
input_tensor.info()
}
pub fn input<T: ElemTypeOf + core::clone::Clone>(
&mut self,
n: usize,
data: &[T],
) -> Result<(), Error> {
let interpreter = &self.micro_interpreter;
let input_tensor: &mut Tensor = unsafe {
let inp = cpp!([
interpreter as "tflite::MicroInterpreter*",
n as "size_t"]
-> *mut bindings::TfLiteTensor as "TfLiteTensor*" {
return interpreter->input(n);
});
assert!(!inp.is_null(), "Obtained nullptr from TensorFlow");
inp.into()
};
let tensor_info: TensorInfo = input_tensor.inner().try_into()?;
let tensor_len = tensor_info.dims.iter().product::<i32>();
if tensor_len != data.len().try_into().unwrap() {
Err(Error::InputDataLenMismatch)
} else {
input_tensor.as_data_mut().clone_from_slice(data);
Ok(())
}
}
pub fn invoke(&mut self) -> Result<(), Status> {
let interpreter = &self.micro_interpreter;
let status = unsafe {
cpp!([interpreter as "tflite::MicroInterpreter*"]
-> bindings::TfLiteStatus as "TfLiteStatus" {
return interpreter->Invoke();
})
};
match status.into() {
Status::Ok => Ok(()),
e => Err(e),
}
}
pub fn output(&self, n: usize) -> &'a Tensor {
let interpreter = &self.micro_interpreter;
unsafe {
let out = cpp!([
interpreter as "tflite::MicroInterpreter*",
n as "size_t"]
-> *mut bindings::TfLiteTensor as "TfLiteTensor*" {
return interpreter->output(n);
});
assert!(!out.is_null(), "Obtained nullptr from Tensorflow!");
out.into()
}
}
pub fn arena_used_bytes(&self) -> usize {
let interpreter = &self.micro_interpreter;
unsafe {
cpp!([interpreter as "tflite::MicroInterpreter*"]
-> usize as "size_t" {
return interpreter->arena_used_bytes();
})
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::micro_op_resolver::AllOpResolver;
    use crate::tensor::ElementType;

    #[test]
    fn new_interpreter_static_arena() {
        let model = include_bytes!("../examples/models/hello_world.tflite");
        let model = Model::from_buffer(&model[..]).unwrap();
        let all_op_resolver = AllOpResolver::new();
        const TENSOR_ARENA_SIZE: usize = 4 * 1024;
        let mut tensor_arena: [u8; TENSOR_ARENA_SIZE] = [0; TENSOR_ARENA_SIZE];
        let _ = MicroInterpreter::new(
            &model,
            all_op_resolver,
            &mut tensor_arena[..],
        )
        .unwrap();
    }

    #[cfg(feature = "alloc")]
    extern crate alloc;
    #[cfg(feature = "alloc")]
    use alloc::{vec, vec::Vec};

    #[test]
    #[cfg(any(feature = "std", feature = "alloc"))]
    fn new_interpreter_alloc_arena() {
        let model = include_bytes!("../examples/models/hello_world.tflite");
        let model = Model::from_buffer(&model[..]).unwrap();
        let all_op_resolver = AllOpResolver::new();
        let tensor_arena: Vec<u8> = vec![0u8; 4 * 1024];
        let _ = MicroInterpreter::new(&model, all_op_resolver, tensor_arena)
            .unwrap();
    }

    /// Runs a full input -> invoke -> output cycle on a heap-allocated arena,
    /// so the arena is used after `new` has returned and the caller no longer
    /// holds the `Vec`.
    #[test]
    #[cfg(any(feature = "std", feature = "alloc"))]
    fn invoke_with_alloc_arena() {
        let model = include_bytes!("../examples/models/hello_world.tflite");
        let model = Model::from_buffer(&model[..]).unwrap();
        let all_op_resolver = AllOpResolver::new();
        let tensor_arena: Vec<u8> = vec![0u8; 4 * 1024];
        let mut interpreter =
            MicroInterpreter::new(&model, all_op_resolver, tensor_arena)
                .unwrap();
        interpreter.input(0, &[0.5f32]).unwrap();
        assert!(interpreter.invoke().is_ok());
        assert_eq!(interpreter.output(0).info().dims, [1, 1]);
    }

    #[test]
    fn input_info() {
        let model = include_bytes!("../examples/models/hello_world.tflite");
        let model = Model::from_buffer(&model[..]).unwrap();
        let all_op_resolver = AllOpResolver::new();
        const TENSOR_ARENA_SIZE: usize = 4 * 1024;
        let mut tensor_arena: [u8; TENSOR_ARENA_SIZE] = [0; TENSOR_ARENA_SIZE];
        let interpreter = MicroInterpreter::new(
            &model,
            all_op_resolver,
            &mut tensor_arena[..],
        )
        .unwrap();
        let info = interpreter.input_info(0);
        assert_eq!(info.name, "dense_2_input");
        assert_eq!(info.element_type, ElementType::Float32);
        assert_eq!(info.dims, [1, 1]);
    }

    /// The hello_world input tensor has dims [1, 1] (one element): `input`
    /// must reject any other slice length and accept exactly one value.
    #[test]
    fn input_rejects_length_mismatch() {
        let model = include_bytes!("../examples/models/hello_world.tflite");
        let model = Model::from_buffer(&model[..]).unwrap();
        let all_op_resolver = AllOpResolver::new();
        const TENSOR_ARENA_SIZE: usize = 4 * 1024;
        let mut tensor_arena: [u8; TENSOR_ARENA_SIZE] = [0; TENSOR_ARENA_SIZE];
        let mut interpreter = MicroInterpreter::new(
            &model,
            all_op_resolver,
            &mut tensor_arena[..],
        )
        .unwrap();
        match interpreter.input(0, &[0.0f32, 1.0f32]) {
            Err(Error::InputDataLenMismatch) => {}
            _ => panic!("expected Error::InputDataLenMismatch"),
        }
        assert!(interpreter.input(0, &[0.0f32]).is_ok());
    }
}