use core::ptr::null_mut;
use crate::error::error_reporter;
use crate::tensor::InnerTensorData;
use crate::{CoralDevice, Error, Input, Model, Output, Tensor};
use core::marker::PhantomData;
use crate::ffi::*;
/// Owning handle to a native TensorFlow Lite interpreter.
///
/// The wrapped pointer is handed to `TfLiteInterpreterDelete` when this value
/// is dropped, so each pointer must be held by exactly one `Interpreter`.
pub struct Interpreter {
/// Raw interpreter pointer as returned by `TfLiteInterpreterCreate`.
ptr: *mut TfLiteInterpreter,
}
impl Interpreter {
pub fn new(model: Model, dev: CoralDevice) -> Result<Self, Error> {
unsafe {
let opts: *mut TfLiteInterpreterOptions = TfLiteInterpreterOptionsCreate();
TfLiteInterpreterOptionsSetErrorReporter(opts, Some(error_reporter), null_mut());
TfLiteInterpreterOptionsAddDelegate(opts, dev.create_delegate());
let ptr = TfLiteInterpreterCreate(model.ptr, opts);
if ptr.is_null() {
return Err(Error::FailedToCreateInterpreter);
}
Self::allocate_tensors(&mut Self { ptr })?;
Ok(Self { ptr })
}
}
fn allocate_tensors(&mut self) -> Result<(), Error> {
unsafe {
let ret = TfLiteInterpreterAllocateTensors(self.ptr);
Error::from(ret)
}
}
pub fn input_tensor<T: InnerTensorData>(&self, id: u32) -> Tensor<Input, T> {
unsafe {
let ptr = TfLiteInterpreterGetInputTensor(self.ptr, id as i32);
Tensor::<Input, T> {
ptr,
_marker: PhantomData,
}
}
}
pub fn output_tensor<T: InnerTensorData>(&self, id: u32) -> Tensor<Output, T> {
unsafe {
let ptr = TfLiteInterpreterGetOutputTensor(self.ptr, id as i32);
Tensor::<Output, T> {
ptr: ptr as *mut _,
_marker: PhantomData,
}
}
}
pub fn invoke(&mut self) -> Result<(), Error> {
unsafe {
let ret = TfLiteInterpreterInvoke(self.ptr);
Error::from(ret)
}
}
}
impl Drop for Interpreter {
    /// Releases the native interpreter by passing the wrapped pointer to
    /// `TfLiteInterpreterDelete`.
    fn drop(&mut self) {
        // SAFETY: sound as long as this pointer is owned by exactly one
        // `Interpreter` value, so it is deleted only once.
        unsafe { TfLiteInterpreterDelete(self.ptr) };
    }
}