1use std::{
2 ffi::{CStr, OsStr},
3 fmt::{self, Debug, Formatter},
4 marker::PhantomData,
5 sync::Arc,
6};
7
8use libloading::Library;
9
10use crate::{sys, Error, Type, XnnPackDelegateOptions};
11
12pub struct TfLite {
13 vt: Arc<sys::TfLiteVt>,
14 xnnpack_vt: Option<Arc<sys::XnnPackVt>>,
15}
16
17impl TfLite {
18 pub fn load<P: AsRef<OsStr>>(path: P) -> Result<Self, Error> {
19 let library = Arc::new(unsafe { Library::new(path).unwrap() });
20
21 let vt = sys::TfLiteVt::load(library.clone())?;
22 let xnnpack_vt = sys::XnnPackVt::load(library);
23
24 Ok(Self {
25 vt: Arc::new(vt),
26 xnnpack_vt: xnnpack_vt.map(Arc::new),
27 })
28 }
29
30 pub fn version(&self) -> &CStr {
31 unsafe { CStr::from_ptr((self.vt.version)()) }
32 }
33
34 pub fn model_create(&self, data: Vec<u8>) -> Result<Model, Error> {
35 let model = unsafe { (self.vt.model_create)(data.as_ptr() as *const _, data.len()) };
36
37 if model.is_null() {
38 Err(Error::Generic)
39 } else {
40 Ok(Model {
41 vt: self.vt.clone(),
42 model,
43 _data: data,
44 })
45 }
46 }
47
48 pub fn interpreter_options_create(&self) -> InterpreterOptions {
49 InterpreterOptions {
50 vt: self.vt.clone(),
51 options: unsafe { (self.vt.interpreter_options_create)() },
52 delegates: Vec::new(),
53 }
54 }
55
56 pub fn interpreter_create(&self, model: Model, mut options: InterpreterOptions) -> Interpreter {
57 let interpreter = unsafe { (self.vt.interpreter_create)(model.model, options.options) };
58
59 Interpreter {
60 vt: self.vt.clone(),
61 interpreter,
62 _model: model,
63 _delegates: std::mem::take(&mut options.delegates),
64 }
65 }
66
67 pub fn xnnpack_delegate_options_default(&self) -> XnnPackDelegateOptions {
68 let xnnpack_vt = self.xnnpack_vt.as_ref().expect("xnnpack not available");
69 unsafe { (xnnpack_vt.delegate_options_default)() }
70 }
71
72 pub fn xnnpack_delegate_create(&self, options: &XnnPackDelegateOptions) -> Delegate {
73 let xnnpack_vt = self.xnnpack_vt.as_ref().expect("xnnpack not available");
74 let delegate = unsafe { (xnnpack_vt.delegate_create)(options) };
75
76 Delegate {
77 _vt: self.vt.clone(),
78 delegate,
79 destructor: xnnpack_vt.delegate_delete,
80 }
81 }
82}
83
84pub struct Model {
85 vt: Arc<sys::TfLiteVt>,
86 model: *mut sys::TfLiteModel,
87 _data: Vec<u8>,
88}
89
90impl Drop for Model {
91 fn drop(&mut self) {
92 unsafe { (self.vt.model_delete)(self.model) };
93 }
94}
95
96pub struct InterpreterOptions {
97 vt: Arc<sys::TfLiteVt>,
98 options: *mut sys::TfLiteInterpreterOptions,
99 delegates: Vec<Delegate>,
100}
101
102impl InterpreterOptions {
103 pub fn set_num_threads(&mut self, num_threads: i32) {
104 unsafe { (self.vt.interpreter_options_set_num_threads)(self.options, num_threads) };
105 }
106
107 pub fn add_delegate(&mut self, delegate: Delegate) {
108 unsafe { (self.vt.interpreter_options_add_delegate)(self.options, delegate.delegate) };
109 self.delegates.push(delegate);
110 }
111}
112
113impl Drop for InterpreterOptions {
114 fn drop(&mut self) {
115 unsafe { (self.vt.interpreter_options_delete)(self.options) };
116 }
117}
118
119pub struct Delegate {
120 _vt: Arc<sys::TfLiteVt>,
121 delegate: *mut sys::TfLiteDelegate,
122 destructor: unsafe extern "C" fn(*mut sys::TfLiteDelegate),
123}
124
125impl Drop for Delegate {
126 fn drop(&mut self) {
127 unsafe { (self.destructor)(self.delegate) };
128 }
129}
130
131pub struct Interpreter {
132 vt: Arc<sys::TfLiteVt>,
133 interpreter: *mut sys::TfLiteInterpreter,
134 _model: Model,
135 _delegates: Vec<Delegate>,
136}
137
138impl Interpreter {
139 pub fn input_tensor_count(&self) -> i32 {
140 unsafe { (self.vt.interpreter_get_input_tensor_count)(self.interpreter) }
141 }
142
143 pub fn input_tensor(&self, index: i32) -> Option<Tensor> {
144 let tensor = unsafe { (self.vt.interpreter_get_input_tensor)(self.interpreter, index) };
145
146 if tensor.is_null() {
147 None
148 } else {
149 Some(Tensor {
150 vt: self.vt.clone(),
151 tensor,
152 _p: PhantomData,
153 })
154 }
155 }
156
157 pub fn allocate_tensors(&mut self) -> Result<(), Error> {
158 let status = unsafe { (self.vt.interpreter_allocate_tensors)(self.interpreter) };
159
160 if status == sys::TfLiteStatus::Ok {
161 Ok(())
162 } else {
163 Err(Error::ErrorStatus(status))
164 }
165 }
166
167 pub fn invoke(&mut self) -> Result<(), Error> {
168 let status = unsafe { (self.vt.interpreter_invoke)(self.interpreter) };
169
170 if status == sys::TfLiteStatus::Ok {
171 Ok(())
172 } else {
173 Err(Error::ErrorStatus(status))
174 }
175 }
176
177 pub fn output_tensor_count(&self) -> i32 {
178 unsafe { (self.vt.interpreter_get_output_tensor_count)(self.interpreter) }
179 }
180
181 pub fn output_tensor(&self, index: i32) -> Option<Tensor> {
182 let tensor = unsafe { (self.vt.interpreter_get_output_tensor)(self.interpreter, index) };
183
184 if tensor.is_null() {
185 None
186 } else {
187 Some(Tensor {
188 vt: self.vt.clone(),
189 tensor,
190 _p: PhantomData,
191 })
192 }
193 }
194}
195
196impl Drop for Interpreter {
197 fn drop(&mut self) {
198 unsafe { (self.vt.interpreter_delete)(self.interpreter) };
199 }
200}
201
202pub struct Tensor<'a> {
203 vt: Arc<sys::TfLiteVt>,
204 tensor: *mut sys::TfLiteTensor,
205 _p: PhantomData<&'a ()>,
206}
207
208impl<'a> Tensor<'a> {
209 pub fn type_(&self) -> Type {
210 unsafe { (self.vt.tensor_type)(self.tensor) }
211 }
212
213 pub fn num_dims(&self) -> i32 {
214 unsafe { (self.vt.tensor_num_dims)(self.tensor) }
215 }
216
217 pub fn dim(&self, index: i32) -> i32 {
218 unsafe { (self.vt.tensor_dim)(self.tensor, index) }
219 }
220
221 pub fn data(&self) -> Option<&'a [u8]> {
222 unsafe {
223 let data = (self.vt.tensor_data)(self.tensor) as *const u8;
224 if data.is_null() {
225 return None;
226 }
227
228 let len = (self.vt.tensor_byte_size)(self.tensor);
229 Some(std::slice::from_raw_parts(data, len))
230 }
231 }
232
233 pub fn data_mut(&mut self) -> Option<&'a mut [u8]> {
234 unsafe {
235 let data = (self.vt.tensor_data)(self.tensor) as *mut u8;
236 if data.is_null() {
237 return None;
238 }
239
240 let len = (self.vt.tensor_byte_size)(self.tensor);
241 Some(std::slice::from_raw_parts_mut(data, len))
242 }
243 }
244
245 pub fn name(&self) -> &CStr {
246 unsafe { CStr::from_ptr((self.vt.tensor_name)(self.tensor)) }
247 }
248}
249
250impl<'a> Debug for Tensor<'a> {
251 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
252 write!(
253 f,
254 "Tensor {{ name: {:?}, type: {:?}, dims: [",
255 self.name(),
256 self.type_()
257 )?;
258
259 for i in 0..self.num_dims() {
260 write!(f, "{}", self.dim(i))?;
261 if i < self.num_dims() - 1 {
262 write!(f, ", ")?;
263 }
264 }
265
266 write!(f, "] }}")
267 }
268}