// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2025 Au-Zone Technologies. All Rights Reserved.

//! Interpreter and builder for `TFLite` model inference.
//!
//! The [`Interpreter`] is created through a builder pattern:
//!
//! ```no_run
//! use edgefirst_tflite::{Library, Model, Interpreter};
//!
//! let lib = Library::new()?;
//! let model = Model::from_file(&lib, "model.tflite")?;
//!
//! let mut interpreter = Interpreter::builder(&lib)?
//!     .num_threads(4)
//!     .build(&model)?;
//!
//! interpreter.invoke()?;
//! # Ok::<(), edgefirst_tflite::Error>(())
//! ```
22use std::ptr::NonNull;
23
24use edgefirst_tflite_sys::{TfLiteInterpreter, TfLiteInterpreterOptions};
25
26use crate::delegate::Delegate;
27use crate::error::{self, Error, Result};
28use crate::model::Model;
29use crate::tensor::{Tensor, TensorMut};
30use crate::Library;
31
32// ---------------------------------------------------------------------------
33// InterpreterBuilder
34// ---------------------------------------------------------------------------
35
/// Builder for configuring and creating a `TFLite` [`Interpreter`].
///
/// Created via [`Interpreter::builder`]. Each builder method applies its
/// setting to the underlying C options object immediately; the options
/// object is deleted when the builder is dropped.
pub struct InterpreterBuilder<'lib> {
    // Owned `TfLiteInterpreterOptions` handle; deleted in `Drop`.
    options: NonNull<TfLiteInterpreterOptions>,
    // Delegates added via `delegate()`; kept alive here and moved into the
    // built `Interpreter` so they outlive the interpreter that uses them.
    delegates: Vec<Delegate>,
    // Loaded TFLite C library providing the FFI function pointers.
    lib: &'lib Library,
}
44
45impl<'lib> InterpreterBuilder<'lib> {
46 /// Set the number of threads for inference.
47 ///
48 /// A value of -1 lets `TFLite` choose based on the platform.
49 #[must_use]
50 pub fn num_threads(self, n: i32) -> Self {
51 // SAFETY: `self.options` is a valid non-null options pointer created by
52 // `TfLiteInterpreterOptionsCreate`.
53 unsafe {
54 self.lib
55 .as_sys()
56 .TfLiteInterpreterOptionsSetNumThreads(self.options.as_ptr(), n);
57 }
58 self
59 }
60
61 /// Add a delegate for hardware acceleration.
62 ///
63 /// The delegate is moved into the builder and will be owned by the
64 /// resulting [`Interpreter`].
65 #[must_use]
66 pub fn delegate(mut self, d: Delegate) -> Self {
67 // SAFETY: `self.options` and the delegate pointer are both valid. The
68 // delegate is stored in `self.delegates` to keep it alive.
69 unsafe {
70 self.lib
71 .as_sys()
72 .TfLiteInterpreterOptionsAddDelegate(self.options.as_ptr(), d.as_ptr());
73 }
74 self.delegates.push(d);
75 self
76 }
77
78 /// Build the interpreter for the given model.
79 ///
80 /// This creates the interpreter and allocates tensors. After this call,
81 /// input tensors can be populated and inference can be run.
82 ///
83 /// # Errors
84 ///
85 /// Returns an error if interpreter creation fails or tensor allocation
86 /// returns a non-OK status.
87 pub fn build(mut self, model: &Model<'lib>) -> Result<Interpreter<'lib>> {
88 // SAFETY: `model.as_ptr()` and `self.options` are both valid non-null
89 // pointers. The library is loaded and the function pointer is valid.
90 let raw = unsafe {
91 self.lib
92 .as_sys()
93 .TfLiteInterpreterCreate(model.as_ptr(), self.options.as_ptr())
94 };
95
96 let interp_ptr = NonNull::new(raw)
97 .ok_or_else(|| Error::null_pointer("TfLiteInterpreterCreate returned null"))?;
98
99 let interpreter = Interpreter {
100 ptr: interp_ptr,
101 delegates: std::mem::take(&mut self.delegates),
102 lib: self.lib,
103 };
104
105 // SAFETY: `interpreter.ptr` is a valid interpreter pointer just created above.
106 let status = unsafe {
107 self.lib
108 .as_sys()
109 .TfLiteInterpreterAllocateTensors(interpreter.ptr.as_ptr())
110 };
111 error::status_to_result(status)
112 .map_err(|e| e.with_context("TfLiteInterpreterAllocateTensors"))?;
113
114 Ok(interpreter)
115 }
116}
117
118impl Drop for InterpreterBuilder<'_> {
119 fn drop(&mut self) {
120 // SAFETY: `self.options` was created by `TfLiteInterpreterOptionsCreate`
121 // and has not been deleted yet.
122 unsafe {
123 self.lib
124 .as_sys()
125 .TfLiteInterpreterOptionsDelete(self.options.as_ptr());
126 }
127 }
128}
129
130impl std::fmt::Debug for InterpreterBuilder<'_> {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 f.debug_struct("InterpreterBuilder")
133 .field("delegates", &self.delegates.len())
134 .finish()
135 }
136}
137
138// ---------------------------------------------------------------------------
139// Interpreter
140// ---------------------------------------------------------------------------
141
/// `TFLite` inference engine.
///
/// Owns its delegates and provides access to input/output tensors.
/// Created via [`Interpreter::builder`].
pub struct Interpreter<'lib> {
    // Owned `TfLiteInterpreter` handle; deleted in `Drop`.
    ptr: NonNull<TfLiteInterpreter>,
    // Delegates registered at build time. Declared after `ptr` so they are
    // dropped after the interpreter is deleted (Rust drops fields in
    // declaration order).
    delegates: Vec<Delegate>,
    // Loaded TFLite C library providing the FFI function pointers.
    lib: &'lib Library,
}
151
152impl<'lib> Interpreter<'lib> {
153 /// Create a new [`InterpreterBuilder`] for configuring an interpreter.
154 ///
155 /// # Errors
156 ///
157 /// Returns an error if `TfLiteInterpreterOptionsCreate` returns null.
158 pub fn builder(lib: &'lib Library) -> Result<InterpreterBuilder<'lib>> {
159 // SAFETY: The library is loaded and the function pointer is valid.
160 let options = NonNull::new(unsafe { lib.as_sys().TfLiteInterpreterOptionsCreate() })
161 .ok_or_else(|| Error::null_pointer("TfLiteInterpreterOptionsCreate returned null"))?;
162
163 Ok(InterpreterBuilder {
164 options,
165 delegates: Vec::new(),
166 lib,
167 })
168 }
169
170 /// Re-allocate tensors after an input resize.
171 ///
172 /// This must be called after [`Interpreter::resize_input`] and before
173 /// [`Interpreter::invoke`]. Any previously obtained tensor slices or
174 /// pointers are invalidated.
175 ///
176 /// # Errors
177 ///
178 /// Returns an error if the C API returns a non-OK status.
179 pub fn allocate_tensors(&mut self) -> Result<()> {
180 // SAFETY: `self.ptr` is a valid interpreter pointer.
181 let status = unsafe {
182 self.lib
183 .as_sys()
184 .TfLiteInterpreterAllocateTensors(self.ptr.as_ptr())
185 };
186 error::status_to_result(status)
187 .map_err(|e| e.with_context("TfLiteInterpreterAllocateTensors"))
188 }
189
190 /// Resize an input tensor's dimensions.
191 ///
192 /// After resizing, [`Interpreter::allocate_tensors`] must be called
193 /// before inference can proceed.
194 ///
195 /// # Errors
196 ///
197 /// Returns an error if the C API returns a non-OK status (e.g., the
198 /// input index is out of range).
199 pub fn resize_input(&mut self, input_index: usize, shape: &[i32]) -> Result<()> {
200 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
201 let index = input_index as i32;
202 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
203 let dims_size = shape.len() as i32;
204 // SAFETY: `self.ptr` is a valid interpreter pointer. `shape` is a
205 // valid slice and `dims_size` is its length. The C API copies the
206 // shape data, so the slice only needs to be valid for this call.
207 let status = unsafe {
208 self.lib.as_sys().TfLiteInterpreterResizeInputTensor(
209 self.ptr.as_ptr(),
210 index,
211 shape.as_ptr(),
212 dims_size,
213 )
214 };
215 error::status_to_result(status)
216 .map_err(|e| e.with_context("TfLiteInterpreterResizeInputTensor"))
217 }
218
219 /// Run model inference.
220 ///
221 /// # Errors
222 ///
223 /// Returns an error if the C API returns a non-OK status.
224 pub fn invoke(&mut self) -> Result<()> {
225 // SAFETY: `self.ptr` is a valid interpreter pointer with tensors allocated.
226 let status = unsafe { self.lib.as_sys().TfLiteInterpreterInvoke(self.ptr.as_ptr()) };
227 error::status_to_result(status).map_err(|e| e.with_context("TfLiteInterpreterInvoke"))
228 }
229
230 /// Get immutable views of all input tensors.
231 ///
232 /// # Errors
233 ///
234 /// Returns an error if any input tensor pointer is null.
235 pub fn inputs(&self) -> Result<Vec<Tensor<'_>>> {
236 let count = self.input_count();
237 let mut inputs = Vec::with_capacity(count);
238 for i in 0..count {
239 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
240 // SAFETY: `self.ptr` is a valid interpreter and `i` is in bounds
241 // (below `input_count`).
242 let raw = unsafe {
243 self.lib
244 .as_sys()
245 .TfLiteInterpreterGetInputTensor(self.ptr.as_ptr(), i as i32)
246 };
247 if raw.is_null() {
248 return Err(Error::null_pointer(format!(
249 "TfLiteInterpreterGetInputTensor returned null for index {i}"
250 )));
251 }
252 inputs.push(Tensor {
253 ptr: raw,
254 lib: self.lib.as_sys(),
255 });
256 }
257 Ok(inputs)
258 }
259
260 /// Get mutable views of all input tensors.
261 ///
262 /// # Errors
263 ///
264 /// Returns an error if any input tensor pointer is null.
265 pub fn inputs_mut(&mut self) -> Result<Vec<TensorMut<'_>>> {
266 let count = self.input_count();
267 let mut inputs = Vec::with_capacity(count);
268 for i in 0..count {
269 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
270 // SAFETY: `self.ptr` is a valid interpreter and `i` is in bounds.
271 // We hold `&mut self` ensuring exclusive access to the tensor data.
272 let raw = unsafe {
273 self.lib
274 .as_sys()
275 .TfLiteInterpreterGetInputTensor(self.ptr.as_ptr(), i as i32)
276 };
277 let ptr = NonNull::new(raw).ok_or_else(|| {
278 Error::null_pointer(format!(
279 "TfLiteInterpreterGetInputTensor returned null for index {i}"
280 ))
281 })?;
282 inputs.push(TensorMut {
283 ptr,
284 lib: self.lib.as_sys(),
285 });
286 }
287 Ok(inputs)
288 }
289
290 /// Get immutable views of all output tensors.
291 ///
292 /// # Errors
293 ///
294 /// Returns an error if any output tensor pointer is null.
295 pub fn outputs(&self) -> Result<Vec<Tensor<'_>>> {
296 let count = self.output_count();
297 let mut outputs = Vec::with_capacity(count);
298 for i in 0..count {
299 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
300 // SAFETY: `self.ptr` is a valid interpreter and `i` is in bounds
301 // (below `output_count`).
302 let raw = unsafe {
303 self.lib
304 .as_sys()
305 .TfLiteInterpreterGetOutputTensor(self.ptr.as_ptr(), i as i32)
306 };
307 if raw.is_null() {
308 return Err(Error::null_pointer(format!(
309 "TfLiteInterpreterGetOutputTensor returned null for index {i}"
310 )));
311 }
312 outputs.push(Tensor {
313 ptr: raw,
314 lib: self.lib.as_sys(),
315 });
316 }
317 Ok(outputs)
318 }
319
320 /// Returns the number of input tensors.
321 #[must_use]
322 pub fn input_count(&self) -> usize {
323 // SAFETY: `self.ptr` is a valid interpreter pointer.
324 #[allow(clippy::cast_sign_loss)]
325 let count = unsafe {
326 self.lib
327 .as_sys()
328 .TfLiteInterpreterGetInputTensorCount(self.ptr.as_ptr())
329 } as usize;
330 count
331 }
332
333 /// Returns the number of output tensors.
334 #[must_use]
335 pub fn output_count(&self) -> usize {
336 // SAFETY: `self.ptr` is a valid interpreter pointer.
337 #[allow(clippy::cast_sign_loss)]
338 let count = unsafe {
339 self.lib
340 .as_sys()
341 .TfLiteInterpreterGetOutputTensorCount(self.ptr.as_ptr())
342 } as usize;
343 count
344 }
345
346 /// Access all delegates owned by this interpreter.
347 #[must_use]
348 pub fn delegates(&self) -> &[Delegate] {
349 &self.delegates
350 }
351
352 /// Access a specific delegate by index.
353 #[must_use]
354 pub fn delegate(&self, index: usize) -> Option<&Delegate> {
355 self.delegates.get(index)
356 }
357}
358
359impl std::fmt::Debug for Interpreter<'_> {
360 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361 f.debug_struct("Interpreter")
362 .field("ptr", &self.ptr)
363 .field("delegates", &self.delegates.len())
364 .finish()
365 }
366}
367
368impl Drop for Interpreter<'_> {
369 fn drop(&mut self) {
370 // SAFETY: The interpreter was created by `TfLiteInterpreterCreate` and
371 // has not been deleted. Delegates are dropped after the interpreter
372 // since they are stored in the same struct and Rust drops fields in
373 // declaration order.
374 unsafe {
375 self.lib.as_sys().TfLiteInterpreterDelete(self.ptr.as_ptr());
376 }
377 }
378}