tflite_dyn/
tflite.rs

1use std::{
2    ffi::{CStr, OsStr},
3    fmt::{self, Debug, Formatter},
4    marker::PhantomData,
5    sync::Arc,
6};
7
8use libloading::Library;
9
10use crate::{sys, Error, Type, XnnPackDelegateOptions};
11
12pub struct TfLite {
13    vt: Arc<sys::TfLiteVt>,
14    xnnpack_vt: Option<Arc<sys::XnnPackVt>>,
15}
16
17impl TfLite {
18    pub fn load<P: AsRef<OsStr>>(path: P) -> Result<Self, Error> {
19        let library = Arc::new(unsafe { Library::new(path).unwrap() });
20
21        let vt = sys::TfLiteVt::load(library.clone())?;
22        let xnnpack_vt = sys::XnnPackVt::load(library);
23
24        Ok(Self {
25            vt: Arc::new(vt),
26            xnnpack_vt: xnnpack_vt.map(Arc::new),
27        })
28    }
29
30    pub fn version(&self) -> &CStr {
31        unsafe { CStr::from_ptr((self.vt.version)()) }
32    }
33
34    pub fn model_create(&self, data: Vec<u8>) -> Result<Model, Error> {
35        let model = unsafe { (self.vt.model_create)(data.as_ptr() as *const _, data.len()) };
36
37        if model.is_null() {
38            Err(Error::Generic)
39        } else {
40            Ok(Model {
41                vt: self.vt.clone(),
42                model,
43                _data: data,
44            })
45        }
46    }
47
48    pub fn interpreter_options_create(&self) -> InterpreterOptions {
49        InterpreterOptions {
50            vt: self.vt.clone(),
51            options: unsafe { (self.vt.interpreter_options_create)() },
52            delegates: Vec::new(),
53        }
54    }
55
56    pub fn interpreter_create(&self, model: Model, mut options: InterpreterOptions) -> Interpreter {
57        let interpreter = unsafe { (self.vt.interpreter_create)(model.model, options.options) };
58
59        Interpreter {
60            vt: self.vt.clone(),
61            interpreter,
62            _model: model,
63            _delegates: std::mem::take(&mut options.delegates),
64        }
65    }
66
67    pub fn xnnpack_delegate_options_default(&self) -> XnnPackDelegateOptions {
68        let xnnpack_vt = self.xnnpack_vt.as_ref().expect("xnnpack not available");
69        unsafe { (xnnpack_vt.delegate_options_default)() }
70    }
71
72    pub fn xnnpack_delegate_create(&self, options: &XnnPackDelegateOptions) -> Delegate {
73        let xnnpack_vt = self.xnnpack_vt.as_ref().expect("xnnpack not available");
74        let delegate = unsafe { (xnnpack_vt.delegate_create)(options) };
75
76        Delegate {
77            _vt: self.vt.clone(),
78            delegate,
79            destructor: xnnpack_vt.delegate_delete,
80        }
81    }
82}
83
84pub struct Model {
85    vt: Arc<sys::TfLiteVt>,
86    model: *mut sys::TfLiteModel,
87    _data: Vec<u8>,
88}
89
90impl Drop for Model {
91    fn drop(&mut self) {
92        unsafe { (self.vt.model_delete)(self.model) };
93    }
94}
95
96pub struct InterpreterOptions {
97    vt: Arc<sys::TfLiteVt>,
98    options: *mut sys::TfLiteInterpreterOptions,
99    delegates: Vec<Delegate>,
100}
101
102impl InterpreterOptions {
103    pub fn set_num_threads(&mut self, num_threads: i32) {
104        unsafe { (self.vt.interpreter_options_set_num_threads)(self.options, num_threads) };
105    }
106
107    pub fn add_delegate(&mut self, delegate: Delegate) {
108        unsafe { (self.vt.interpreter_options_add_delegate)(self.options, delegate.delegate) };
109        self.delegates.push(delegate);
110    }
111}
112
113impl Drop for InterpreterOptions {
114    fn drop(&mut self) {
115        unsafe { (self.vt.interpreter_options_delete)(self.options) };
116    }
117}
118
119pub struct Delegate {
120    _vt: Arc<sys::TfLiteVt>,
121    delegate: *mut sys::TfLiteDelegate,
122    destructor: unsafe extern "C" fn(*mut sys::TfLiteDelegate),
123}
124
125impl Drop for Delegate {
126    fn drop(&mut self) {
127        unsafe { (self.destructor)(self.delegate) };
128    }
129}
130
131pub struct Interpreter {
132    vt: Arc<sys::TfLiteVt>,
133    interpreter: *mut sys::TfLiteInterpreter,
134    _model: Model,
135    _delegates: Vec<Delegate>,
136}
137
138impl Interpreter {
139    pub fn input_tensor_count(&self) -> i32 {
140        unsafe { (self.vt.interpreter_get_input_tensor_count)(self.interpreter) }
141    }
142
143    pub fn input_tensor(&self, index: i32) -> Option<Tensor> {
144        let tensor = unsafe { (self.vt.interpreter_get_input_tensor)(self.interpreter, index) };
145
146        if tensor.is_null() {
147            None
148        } else {
149            Some(Tensor {
150                vt: self.vt.clone(),
151                tensor,
152                _p: PhantomData,
153            })
154        }
155    }
156
157    pub fn allocate_tensors(&mut self) -> Result<(), Error> {
158        let status = unsafe { (self.vt.interpreter_allocate_tensors)(self.interpreter) };
159
160        if status == sys::TfLiteStatus::Ok {
161            Ok(())
162        } else {
163            Err(Error::ErrorStatus(status))
164        }
165    }
166
167    pub fn invoke(&mut self) -> Result<(), Error> {
168        let status = unsafe { (self.vt.interpreter_invoke)(self.interpreter) };
169
170        if status == sys::TfLiteStatus::Ok {
171            Ok(())
172        } else {
173            Err(Error::ErrorStatus(status))
174        }
175    }
176
177    pub fn output_tensor_count(&self) -> i32 {
178        unsafe { (self.vt.interpreter_get_output_tensor_count)(self.interpreter) }
179    }
180
181    pub fn output_tensor(&self, index: i32) -> Option<Tensor> {
182        let tensor = unsafe { (self.vt.interpreter_get_output_tensor)(self.interpreter, index) };
183
184        if tensor.is_null() {
185            None
186        } else {
187            Some(Tensor {
188                vt: self.vt.clone(),
189                tensor,
190                _p: PhantomData,
191            })
192        }
193    }
194}
195
196impl Drop for Interpreter {
197    fn drop(&mut self) {
198        unsafe { (self.vt.interpreter_delete)(self.interpreter) };
199    }
200}
201
202pub struct Tensor<'a> {
203    vt: Arc<sys::TfLiteVt>,
204    tensor: *mut sys::TfLiteTensor,
205    _p: PhantomData<&'a ()>,
206}
207
208impl<'a> Tensor<'a> {
209    pub fn type_(&self) -> Type {
210        unsafe { (self.vt.tensor_type)(self.tensor) }
211    }
212
213    pub fn num_dims(&self) -> i32 {
214        unsafe { (self.vt.tensor_num_dims)(self.tensor) }
215    }
216
217    pub fn dim(&self, index: i32) -> i32 {
218        unsafe { (self.vt.tensor_dim)(self.tensor, index) }
219    }
220
221    pub fn data(&self) -> Option<&'a [u8]> {
222        unsafe {
223            let data = (self.vt.tensor_data)(self.tensor) as *const u8;
224            if data.is_null() {
225                return None;
226            }
227
228            let len = (self.vt.tensor_byte_size)(self.tensor);
229            Some(std::slice::from_raw_parts(data, len))
230        }
231    }
232
233    pub fn data_mut(&mut self) -> Option<&'a mut [u8]> {
234        unsafe {
235            let data = (self.vt.tensor_data)(self.tensor) as *mut u8;
236            if data.is_null() {
237                return None;
238            }
239
240            let len = (self.vt.tensor_byte_size)(self.tensor);
241            Some(std::slice::from_raw_parts_mut(data, len))
242        }
243    }
244
245    pub fn name(&self) -> &CStr {
246        unsafe { CStr::from_ptr((self.vt.tensor_name)(self.tensor)) }
247    }
248}
249
250impl<'a> Debug for Tensor<'a> {
251    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
252        write!(
253            f,
254            "Tensor {{ name: {:?}, type: {:?}, dims: [",
255            self.name(),
256            self.type_()
257        )?;
258
259        for i in 0..self.num_dims() {
260            write!(f, "{}", self.dim(i))?;
261            if i < self.num_dims() - 1 {
262                write!(f, ", ")?;
263            }
264        }
265
266        write!(f, "] }}")
267    }
268}