Skip to main content

edgefirst_tflite/
interpreter.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Interpreter and builder for `TFLite` model inference.
5//!
6//! The [`Interpreter`] is created through a builder pattern:
7//!
8//! ```no_run
9//! use edgefirst_tflite::{Library, Model, Interpreter};
10//!
11//! let lib = Library::new()?;
12//! let model = Model::from_file(&lib, "model.tflite")?;
13//!
14//! let mut interpreter = Interpreter::builder(&lib)?
15//!     .num_threads(4)
16//!     .build(&model)?;
17//!
18//! interpreter.invoke()?;
19//! # Ok::<(), edgefirst_tflite::Error>(())
20//! ```
21
22use std::ptr::NonNull;
23
24use edgefirst_tflite_sys::{TfLiteInterpreter, TfLiteInterpreterOptions};
25
26use crate::delegate::Delegate;
27use crate::error::{self, Error, Result};
28use crate::model::Model;
29use crate::tensor::{Tensor, TensorMut};
30use crate::Library;
31
32// ---------------------------------------------------------------------------
33// InterpreterBuilder
34// ---------------------------------------------------------------------------
35
36/// Builder for configuring and creating a `TFLite` [`Interpreter`].
37///
38/// Created via [`Interpreter::builder`].
39pub struct InterpreterBuilder<'lib> {
40    options: NonNull<TfLiteInterpreterOptions>,
41    delegates: Vec<Delegate>,
42    lib: &'lib Library,
43}
44
45impl<'lib> InterpreterBuilder<'lib> {
46    /// Set the number of threads for inference.
47    ///
48    /// A value of -1 lets `TFLite` choose based on the platform.
49    #[must_use]
50    pub fn num_threads(self, n: i32) -> Self {
51        // SAFETY: `self.options` is a valid non-null options pointer created by
52        // `TfLiteInterpreterOptionsCreate`.
53        unsafe {
54            self.lib
55                .as_sys()
56                .TfLiteInterpreterOptionsSetNumThreads(self.options.as_ptr(), n);
57        }
58        self
59    }
60
61    /// Add a delegate for hardware acceleration.
62    ///
63    /// The delegate is moved into the builder and will be owned by the
64    /// resulting [`Interpreter`].
65    #[must_use]
66    pub fn delegate(mut self, d: Delegate) -> Self {
67        // SAFETY: `self.options` and the delegate pointer are both valid. The
68        // delegate is stored in `self.delegates` to keep it alive.
69        unsafe {
70            self.lib
71                .as_sys()
72                .TfLiteInterpreterOptionsAddDelegate(self.options.as_ptr(), d.as_ptr());
73        }
74        self.delegates.push(d);
75        self
76    }
77
78    /// Build the interpreter for the given model.
79    ///
80    /// This creates the interpreter and allocates tensors. After this call,
81    /// input tensors can be populated and inference can be run.
82    ///
83    /// # Errors
84    ///
85    /// Returns an error if interpreter creation fails or tensor allocation
86    /// returns a non-OK status.
87    pub fn build(mut self, model: &Model<'lib>) -> Result<Interpreter<'lib>> {
88        // SAFETY: `model.as_ptr()` and `self.options` are both valid non-null
89        // pointers. The library is loaded and the function pointer is valid.
90        let raw = unsafe {
91            self.lib
92                .as_sys()
93                .TfLiteInterpreterCreate(model.as_ptr(), self.options.as_ptr())
94        };
95
96        let interp_ptr = NonNull::new(raw)
97            .ok_or_else(|| Error::null_pointer("TfLiteInterpreterCreate returned null"))?;
98
99        let interpreter = Interpreter {
100            ptr: interp_ptr,
101            delegates: std::mem::take(&mut self.delegates),
102            lib: self.lib,
103        };
104
105        // SAFETY: `interpreter.ptr` is a valid interpreter pointer just created above.
106        let status = unsafe {
107            self.lib
108                .as_sys()
109                .TfLiteInterpreterAllocateTensors(interpreter.ptr.as_ptr())
110        };
111        error::status_to_result(status)
112            .map_err(|e| e.with_context("TfLiteInterpreterAllocateTensors"))?;
113
114        Ok(interpreter)
115    }
116}
117
118impl Drop for InterpreterBuilder<'_> {
119    fn drop(&mut self) {
120        // SAFETY: `self.options` was created by `TfLiteInterpreterOptionsCreate`
121        // and has not been deleted yet.
122        unsafe {
123            self.lib
124                .as_sys()
125                .TfLiteInterpreterOptionsDelete(self.options.as_ptr());
126        }
127    }
128}
129
130impl std::fmt::Debug for InterpreterBuilder<'_> {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        f.debug_struct("InterpreterBuilder")
133            .field("delegates", &self.delegates.len())
134            .finish()
135    }
136}
137
138// ---------------------------------------------------------------------------
139// Interpreter
140// ---------------------------------------------------------------------------
141
142/// `TFLite` inference engine.
143///
144/// Owns its delegates and provides access to input/output tensors.
145/// Created via [`Interpreter::builder`].
146pub struct Interpreter<'lib> {
147    ptr: NonNull<TfLiteInterpreter>,
148    delegates: Vec<Delegate>,
149    lib: &'lib Library,
150}
151
152impl<'lib> Interpreter<'lib> {
153    /// Create a new [`InterpreterBuilder`] for configuring an interpreter.
154    ///
155    /// # Errors
156    ///
157    /// Returns an error if `TfLiteInterpreterOptionsCreate` returns null.
158    pub fn builder(lib: &'lib Library) -> Result<InterpreterBuilder<'lib>> {
159        // SAFETY: The library is loaded and the function pointer is valid.
160        let options = NonNull::new(unsafe { lib.as_sys().TfLiteInterpreterOptionsCreate() })
161            .ok_or_else(|| Error::null_pointer("TfLiteInterpreterOptionsCreate returned null"))?;
162
163        Ok(InterpreterBuilder {
164            options,
165            delegates: Vec::new(),
166            lib,
167        })
168    }
169
170    /// Re-allocate tensors after an input resize.
171    ///
172    /// This must be called after [`Interpreter::resize_input`] and before
173    /// [`Interpreter::invoke`]. Any previously obtained tensor slices or
174    /// pointers are invalidated.
175    ///
176    /// # Errors
177    ///
178    /// Returns an error if the C API returns a non-OK status.
179    pub fn allocate_tensors(&mut self) -> Result<()> {
180        // SAFETY: `self.ptr` is a valid interpreter pointer.
181        let status = unsafe {
182            self.lib
183                .as_sys()
184                .TfLiteInterpreterAllocateTensors(self.ptr.as_ptr())
185        };
186        error::status_to_result(status)
187            .map_err(|e| e.with_context("TfLiteInterpreterAllocateTensors"))
188    }
189
190    /// Resize an input tensor's dimensions.
191    ///
192    /// After resizing, [`Interpreter::allocate_tensors`] must be called
193    /// before inference can proceed.
194    ///
195    /// # Errors
196    ///
197    /// Returns an error if the C API returns a non-OK status (e.g., the
198    /// input index is out of range).
199    pub fn resize_input(&mut self, input_index: usize, shape: &[i32]) -> Result<()> {
200        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
201        let index = input_index as i32;
202        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
203        let dims_size = shape.len() as i32;
204        // SAFETY: `self.ptr` is a valid interpreter pointer. `shape` is a
205        // valid slice and `dims_size` is its length. The C API copies the
206        // shape data, so the slice only needs to be valid for this call.
207        let status = unsafe {
208            self.lib.as_sys().TfLiteInterpreterResizeInputTensor(
209                self.ptr.as_ptr(),
210                index,
211                shape.as_ptr(),
212                dims_size,
213            )
214        };
215        error::status_to_result(status)
216            .map_err(|e| e.with_context("TfLiteInterpreterResizeInputTensor"))
217    }
218
219    /// Run model inference.
220    ///
221    /// # Errors
222    ///
223    /// Returns an error if the C API returns a non-OK status.
224    pub fn invoke(&mut self) -> Result<()> {
225        // SAFETY: `self.ptr` is a valid interpreter pointer with tensors allocated.
226        let status = unsafe { self.lib.as_sys().TfLiteInterpreterInvoke(self.ptr.as_ptr()) };
227        error::status_to_result(status).map_err(|e| e.with_context("TfLiteInterpreterInvoke"))
228    }
229
230    /// Get immutable views of all input tensors.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if any input tensor pointer is null.
235    pub fn inputs(&self) -> Result<Vec<Tensor<'_>>> {
236        let count = self.input_count();
237        let mut inputs = Vec::with_capacity(count);
238        for i in 0..count {
239            #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
240            // SAFETY: `self.ptr` is a valid interpreter and `i` is in bounds
241            // (below `input_count`).
242            let raw = unsafe {
243                self.lib
244                    .as_sys()
245                    .TfLiteInterpreterGetInputTensor(self.ptr.as_ptr(), i as i32)
246            };
247            if raw.is_null() {
248                return Err(Error::null_pointer(format!(
249                    "TfLiteInterpreterGetInputTensor returned null for index {i}"
250                )));
251            }
252            inputs.push(Tensor {
253                ptr: raw,
254                lib: self.lib.as_sys(),
255            });
256        }
257        Ok(inputs)
258    }
259
260    /// Get mutable views of all input tensors.
261    ///
262    /// # Errors
263    ///
264    /// Returns an error if any input tensor pointer is null.
265    pub fn inputs_mut(&mut self) -> Result<Vec<TensorMut<'_>>> {
266        let count = self.input_count();
267        let mut inputs = Vec::with_capacity(count);
268        for i in 0..count {
269            #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
270            // SAFETY: `self.ptr` is a valid interpreter and `i` is in bounds.
271            // We hold `&mut self` ensuring exclusive access to the tensor data.
272            let raw = unsafe {
273                self.lib
274                    .as_sys()
275                    .TfLiteInterpreterGetInputTensor(self.ptr.as_ptr(), i as i32)
276            };
277            let ptr = NonNull::new(raw).ok_or_else(|| {
278                Error::null_pointer(format!(
279                    "TfLiteInterpreterGetInputTensor returned null for index {i}"
280                ))
281            })?;
282            inputs.push(TensorMut {
283                ptr,
284                lib: self.lib.as_sys(),
285            });
286        }
287        Ok(inputs)
288    }
289
290    /// Get immutable views of all output tensors.
291    ///
292    /// # Errors
293    ///
294    /// Returns an error if any output tensor pointer is null.
295    pub fn outputs(&self) -> Result<Vec<Tensor<'_>>> {
296        let count = self.output_count();
297        let mut outputs = Vec::with_capacity(count);
298        for i in 0..count {
299            #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
300            // SAFETY: `self.ptr` is a valid interpreter and `i` is in bounds
301            // (below `output_count`).
302            let raw = unsafe {
303                self.lib
304                    .as_sys()
305                    .TfLiteInterpreterGetOutputTensor(self.ptr.as_ptr(), i as i32)
306            };
307            if raw.is_null() {
308                return Err(Error::null_pointer(format!(
309                    "TfLiteInterpreterGetOutputTensor returned null for index {i}"
310                )));
311            }
312            outputs.push(Tensor {
313                ptr: raw,
314                lib: self.lib.as_sys(),
315            });
316        }
317        Ok(outputs)
318    }
319
320    /// Returns the number of input tensors.
321    #[must_use]
322    pub fn input_count(&self) -> usize {
323        // SAFETY: `self.ptr` is a valid interpreter pointer.
324        #[allow(clippy::cast_sign_loss)]
325        let count = unsafe {
326            self.lib
327                .as_sys()
328                .TfLiteInterpreterGetInputTensorCount(self.ptr.as_ptr())
329        } as usize;
330        count
331    }
332
333    /// Returns the number of output tensors.
334    #[must_use]
335    pub fn output_count(&self) -> usize {
336        // SAFETY: `self.ptr` is a valid interpreter pointer.
337        #[allow(clippy::cast_sign_loss)]
338        let count = unsafe {
339            self.lib
340                .as_sys()
341                .TfLiteInterpreterGetOutputTensorCount(self.ptr.as_ptr())
342        } as usize;
343        count
344    }
345
346    /// Access all delegates owned by this interpreter.
347    #[must_use]
348    pub fn delegates(&self) -> &[Delegate] {
349        &self.delegates
350    }
351
352    /// Access a specific delegate by index.
353    #[must_use]
354    pub fn delegate(&self, index: usize) -> Option<&Delegate> {
355        self.delegates.get(index)
356    }
357}
358
359impl std::fmt::Debug for Interpreter<'_> {
360    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361        f.debug_struct("Interpreter")
362            .field("ptr", &self.ptr)
363            .field("delegates", &self.delegates.len())
364            .finish()
365    }
366}
367
368impl Drop for Interpreter<'_> {
369    fn drop(&mut self) {
370        // SAFETY: The interpreter was created by `TfLiteInterpreterCreate` and
371        // has not been deleted. Delegates are dropped after the interpreter
372        // since they are stored in the same struct and Rust drops fields in
373        // declaration order.
374        unsafe {
375            self.lib.as_sys().TfLiteInterpreterDelete(self.ptr.as_ptr());
376        }
377    }
378}