Skip to main content

edgefirst_tflite/
model.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Model loading for `TFLite` inference.
5
6use std::os::raw::c_void;
7use std::path::Path;
8use std::ptr::NonNull;
9
10use edgefirst_tflite_sys::TfLiteModel;
11
12use crate::error::{Error, Result};
13use crate::Library;
14
15/// A loaded `TFLite` model.
16///
17/// Models can be created from in-memory bytes or from a file. The model
18/// data is kept alive for the lifetime of the `Model`.
19#[derive(Debug)]
20#[allow(clippy::struct_field_names)]
21pub struct Model<'lib> {
22    ptr: NonNull<TfLiteModel>,
23    model_mem: Vec<u8>,
24    lib: &'lib Library,
25}
26
27impl<'lib> Model<'lib> {
28    /// Create a `Model` from raw bytes.
29    ///
30    /// Takes ownership of the provided byte buffer and passes a pointer to
31    /// the underlying `TFLite` C API. The data is kept alive for the
32    /// lifetime of the returned `Model`.
33    pub fn from_bytes(lib: &'lib Library, data: impl Into<Vec<u8>>) -> Result<Self> {
34        let model_mem: Vec<u8> = data.into();
35        // SAFETY: We pass a valid pointer and length from the owned Vec.
36        // The Vec is stored in `model_mem` and lives as long as the Model,
37        // satisfying TFLite's requirement that the buffer outlives the model.
38        let raw = unsafe {
39            lib.as_sys()
40                .TfLiteModelCreate(model_mem.as_ptr().cast::<c_void>(), model_mem.len())
41        };
42        let ptr = NonNull::new(raw)
43            .ok_or_else(|| Error::null_pointer("TfLiteModelCreate returned null"))?;
44        Ok(Self {
45            ptr,
46            model_mem,
47            lib,
48        })
49    }
50
51    /// Create a `Model` by reading a file from disk.
52    ///
53    /// Reads the entire file into memory, then delegates to
54    /// [`Model::from_bytes`].
55    ///
56    /// # Errors
57    ///
58    /// Returns an error if the file cannot be read (I/O error) or if the
59    /// `TFLite` C API fails to parse the model bytes (returns null).
60    pub fn from_file(lib: &'lib Library, path: impl AsRef<Path>) -> Result<Self> {
61        let data = std::fs::read(path.as_ref())
62            .map_err(|e| Error::invalid_argument(format!("{}: {e}", path.as_ref().display())))?;
63        Self::from_bytes(lib, data)
64    }
65
66    /// Returns the raw model data bytes.
67    #[must_use]
68    pub fn data(&self) -> &[u8] {
69        &self.model_mem
70    }
71
72    /// Returns the raw `TfLiteModel` pointer for use by the interpreter.
73    pub(crate) fn as_ptr(&self) -> *mut TfLiteModel {
74        self.ptr.as_ptr()
75    }
76}
77
78impl Drop for Model<'_> {
79    fn drop(&mut self) {
80        // SAFETY: `self.ptr` was created by `TfLiteModelCreate` and has not
81        // been deleted yet. The matching `TfLiteModelDelete` releases the
82        // model resources allocated by the C library.
83        unsafe {
84            self.lib.as_sys().TfLiteModelDelete(self.ptr.as_ptr());
85        }
86    }
87}