use std::ptr::NonNull;
use edgefirst_tflite_sys::{TfLiteInterpreter, TfLiteInterpreterOptions};
use crate::delegate::Delegate;
use crate::error::{self, Error, Result};
use crate::model::Model;
use crate::profiler::Profiler;
use crate::tensor::{Tensor, TensorMut};
use crate::Library;
/// Builder that collects interpreter options and delegates before creating an
/// [`Interpreter`].
///
/// Owns a `TfLiteInterpreterOptions` handle, which is deleted in `Drop`. It also owns
/// every [`Delegate`] registered through it until `build` moves them into the interpreter.
pub struct InterpreterBuilder<'lib> {
    /// Non-null options handle obtained from `TfLiteInterpreterOptionsCreate`.
    options: NonNull<TfLiteInterpreterOptions>,
    /// Delegates already added to `options`. They are kept here so they stay alive while
    /// TFLite holds their raw pointers.
    delegates: Vec<Delegate>,
    /// Loaded TFLite library used for every C API call.
    lib: &'lib Library,
}
impl<'lib> InterpreterBuilder<'lib> {
#[must_use]
pub fn num_threads(self, n: i32) -> Self {
unsafe {
self.lib
.as_sys()
.TfLiteInterpreterOptionsSetNumThreads(self.options.as_ptr(), n);
}
self
}
#[must_use]
pub fn delegate(mut self, d: Delegate) -> Self {
unsafe {
self.lib
.as_sys()
.TfLiteInterpreterOptionsAddDelegate(self.options.as_ptr(), d.as_ptr());
}
self.delegates.push(d);
self
}
pub fn profiler(self, profiler: &Profiler) -> Result<Self> {
let tflite_lib = self.lib.reopen()?;
let set_profiler: libloading::Symbol<
'_,
unsafe extern "C" fn(*mut TfLiteInterpreterOptions, *mut std::ffi::c_void),
> = unsafe { tflite_lib.get(b"TfLiteInterpreterOptionsSetTelemetryProfiler\0") }.map_err(
|_| {
Error::invalid_argument(
"TfLiteInterpreterOptionsSetTelemetryProfiler symbol not found — \
the TFLite library may not support the telemetry profiler API",
)
},
)?;
unsafe {
set_profiler(self.options.as_ptr(), profiler.as_ptr());
}
Ok(self)
}
pub fn build(mut self, model: &Model<'lib>) -> Result<Interpreter<'lib>> {
let raw = unsafe {
self.lib
.as_sys()
.TfLiteInterpreterCreate(model.as_ptr(), self.options.as_ptr())
};
let interp_ptr = NonNull::new(raw)
.ok_or_else(|| Error::null_pointer("TfLiteInterpreterCreate returned null"))?;
let interpreter = Interpreter {
ptr: interp_ptr,
delegates: std::mem::take(&mut self.delegates),
lib: self.lib,
};
let status = unsafe {
self.lib
.as_sys()
.TfLiteInterpreterAllocateTensors(interpreter.ptr.as_ptr())
};
error::status_to_result(status)
.map_err(|e| e.with_context("TfLiteInterpreterAllocateTensors"))?;
Ok(interpreter)
}
}
impl Drop for InterpreterBuilder<'_> {
    /// Deletes the options handle owned by this builder.
    ///
    /// This also runs at the end of `build`, after the interpreter has been created.
    /// TFLite's C API permits deleting options once the interpreter exists; confirm this
    /// against the pinned TFLite version.
    fn drop(&mut self) {
        let options = self.options.as_ptr();
        // SAFETY: `options` came from `TfLiteInterpreterOptionsCreate`, and this is the
        // only place it is deleted.
        unsafe {
            self.lib.as_sys().TfLiteInterpreterOptionsDelete(options);
        }
    }
}
impl std::fmt::Debug for InterpreterBuilder<'_> {
    /// Shows how many delegates have been registered. The raw options handle is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let delegate_count = self.delegates.len();
        let mut out = f.debug_struct("InterpreterBuilder");
        out.field("delegates", &delegate_count);
        out.finish()
    }
}
/// A TFLite interpreter created by [`InterpreterBuilder::build`].
///
/// Owns the underlying `TfLiteInterpreter` handle, which is deleted in `Drop`. It also owns
/// the delegates the interpreter was built with.
pub struct Interpreter<'lib> {
    /// Non-null handle returned by `TfLiteInterpreterCreate`.
    ptr: NonNull<TfLiteInterpreter>,
    /// Delegates moved in from the builder. Struct fields are dropped after `Drop::drop`
    /// runs, so these are released only after the interpreter handle has been deleted.
    delegates: Vec<Delegate>,
    /// Loaded TFLite library used for every C API call.
    lib: &'lib Library,
}
impl<'lib> Interpreter<'lib> {
pub fn builder(lib: &'lib Library) -> Result<InterpreterBuilder<'lib>> {
let options = NonNull::new(unsafe { lib.as_sys().TfLiteInterpreterOptionsCreate() })
.ok_or_else(|| Error::null_pointer("TfLiteInterpreterOptionsCreate returned null"))?;
Ok(InterpreterBuilder {
options,
delegates: Vec::new(),
lib,
})
}
pub fn allocate_tensors(&mut self) -> Result<()> {
let status = unsafe {
self.lib
.as_sys()
.TfLiteInterpreterAllocateTensors(self.ptr.as_ptr())
};
error::status_to_result(status)
.map_err(|e| e.with_context("TfLiteInterpreterAllocateTensors"))
}
pub fn resize_input(&mut self, input_index: usize, shape: &[i32]) -> Result<()> {
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
let index = input_index as i32;
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
let dims_size = shape.len() as i32;
let status = unsafe {
self.lib.as_sys().TfLiteInterpreterResizeInputTensor(
self.ptr.as_ptr(),
index,
shape.as_ptr(),
dims_size,
)
};
error::status_to_result(status)
.map_err(|e| e.with_context("TfLiteInterpreterResizeInputTensor"))
}
pub fn invoke(&mut self) -> Result<()> {
let status = unsafe { self.lib.as_sys().TfLiteInterpreterInvoke(self.ptr.as_ptr()) };
error::status_to_result(status).map_err(|e| e.with_context("TfLiteInterpreterInvoke"))
}
pub fn inputs(&self) -> Result<Vec<Tensor<'_>>> {
let count = self.input_count();
let mut inputs = Vec::with_capacity(count);
for i in 0..count {
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
let raw = unsafe {
self.lib
.as_sys()
.TfLiteInterpreterGetInputTensor(self.ptr.as_ptr(), i as i32)
};
if raw.is_null() {
return Err(Error::null_pointer(format!(
"TfLiteInterpreterGetInputTensor returned null for index {i}"
)));
}
inputs.push(Tensor {
ptr: raw,
lib: self.lib.as_sys(),
});
}
Ok(inputs)
}
pub fn inputs_mut(&mut self) -> Result<Vec<TensorMut<'_>>> {
let count = self.input_count();
let mut inputs = Vec::with_capacity(count);
for i in 0..count {
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
let raw = unsafe {
self.lib
.as_sys()
.TfLiteInterpreterGetInputTensor(self.ptr.as_ptr(), i as i32)
};
let ptr = NonNull::new(raw).ok_or_else(|| {
Error::null_pointer(format!(
"TfLiteInterpreterGetInputTensor returned null for index {i}"
))
})?;
inputs.push(TensorMut {
ptr,
lib: self.lib.as_sys(),
});
}
Ok(inputs)
}
pub fn outputs(&self) -> Result<Vec<Tensor<'_>>> {
let count = self.output_count();
let mut outputs = Vec::with_capacity(count);
for i in 0..count {
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
let raw = unsafe {
self.lib
.as_sys()
.TfLiteInterpreterGetOutputTensor(self.ptr.as_ptr(), i as i32)
};
if raw.is_null() {
return Err(Error::null_pointer(format!(
"TfLiteInterpreterGetOutputTensor returned null for index {i}"
)));
}
outputs.push(Tensor {
ptr: raw,
lib: self.lib.as_sys(),
});
}
Ok(outputs)
}
#[must_use]
pub fn input_count(&self) -> usize {
#[allow(clippy::cast_sign_loss)]
let count = unsafe {
self.lib
.as_sys()
.TfLiteInterpreterGetInputTensorCount(self.ptr.as_ptr())
} as usize;
count
}
#[must_use]
pub fn output_count(&self) -> usize {
#[allow(clippy::cast_sign_loss)]
let count = unsafe {
self.lib
.as_sys()
.TfLiteInterpreterGetOutputTensorCount(self.ptr.as_ptr())
} as usize;
count
}
#[must_use]
pub fn delegates(&self) -> &[Delegate] {
&self.delegates
}
#[must_use]
pub fn delegate(&self, index: usize) -> Option<&Delegate> {
self.delegates.get(index)
}
}
impl std::fmt::Debug for Interpreter<'_> {
    /// Shows the raw interpreter handle and how many delegates are attached.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let delegate_count = self.delegates.len();
        let mut out = f.debug_struct("Interpreter");
        out.field("ptr", &self.ptr);
        out.field("delegates", &delegate_count);
        out.finish()
    }
}
impl Drop for Interpreter<'_> {
    /// Deletes the TFLite interpreter handle.
    ///
    /// The `delegates` field is dropped only after this method returns, so delegates
    /// outlive the interpreter that references them.
    fn drop(&mut self) {
        let handle = self.ptr.as_ptr();
        // SAFETY: `handle` came from `TfLiteInterpreterCreate` and is deleted only here.
        unsafe {
            self.lib.as_sys().TfLiteInterpreterDelete(handle);
        }
    }
}