edgefirst_tflite/model.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Model loading for `TFLite` inference.
5
6use std::os::raw::c_void;
7use std::path::Path;
8use std::ptr::NonNull;
9
10use edgefirst_tflite_sys::TfLiteModel;
11
12use crate::error::{Error, Result};
13use crate::Library;
14
15/// A loaded `TFLite` model.
16///
17/// Models can be created from in-memory bytes or from a file. The model
18/// data is kept alive for the lifetime of the `Model`.
19#[derive(Debug)]
20#[allow(clippy::struct_field_names)]
21pub struct Model<'lib> {
22 ptr: NonNull<TfLiteModel>,
23 model_mem: Vec<u8>,
24 lib: &'lib Library,
25}
26
27impl<'lib> Model<'lib> {
28 /// Create a `Model` from raw bytes.
29 ///
30 /// Takes ownership of the provided byte buffer and passes a pointer to
31 /// the underlying `TFLite` C API. The data is kept alive for the
32 /// lifetime of the returned `Model`.
33 pub fn from_bytes(lib: &'lib Library, data: impl Into<Vec<u8>>) -> Result<Self> {
34 let model_mem: Vec<u8> = data.into();
35 // SAFETY: We pass a valid pointer and length from the owned Vec.
36 // The Vec is stored in `model_mem` and lives as long as the Model,
37 // satisfying TFLite's requirement that the buffer outlives the model.
38 let raw = unsafe {
39 lib.as_sys()
40 .TfLiteModelCreate(model_mem.as_ptr().cast::<c_void>(), model_mem.len())
41 };
42 let ptr = NonNull::new(raw)
43 .ok_or_else(|| Error::null_pointer("TfLiteModelCreate returned null"))?;
44 Ok(Self {
45 ptr,
46 model_mem,
47 lib,
48 })
49 }
50
51 /// Create a `Model` by reading a file from disk.
52 ///
53 /// Reads the entire file into memory, then delegates to
54 /// [`Model::from_bytes`].
55 ///
56 /// # Errors
57 ///
58 /// Returns an error if the file cannot be read (I/O error) or if the
59 /// `TFLite` C API fails to parse the model bytes (returns null).
60 pub fn from_file(lib: &'lib Library, path: impl AsRef<Path>) -> Result<Self> {
61 let data = std::fs::read(path.as_ref())
62 .map_err(|e| Error::invalid_argument(format!("{}: {e}", path.as_ref().display())))?;
63 Self::from_bytes(lib, data)
64 }
65
66 /// Returns the raw model data bytes.
67 #[must_use]
68 pub fn data(&self) -> &[u8] {
69 &self.model_mem
70 }
71
72 /// Returns the raw `TfLiteModel` pointer for use by the interpreter.
73 pub(crate) fn as_ptr(&self) -> *mut TfLiteModel {
74 self.ptr.as_ptr()
75 }
76}
77
78impl Drop for Model<'_> {
79 fn drop(&mut self) {
80 // SAFETY: `self.ptr` was created by `TfLiteModelCreate` and has not
81 // been deleted yet. The matching `TfLiteModelDelete` releases the
82 // model resources allocated by the C library.
83 unsafe {
84 self.lib.as_sys().TfLiteModelDelete(self.ptr.as_ptr());
85 }
86 }
87}