1#![allow(non_upper_case_globals)]
16
17use crate::bindings::*;
18use crate::call_check_status;
19use crate::error::{Error, ErrorCause};
20use crate::helper_funs::c_str_to_str;
21use crate::ElementType;
22use std::ffi::{c_char, c_void, CString};
23use std::marker::PhantomData;
24
25pub struct Model {
28 pub(crate) raw_model: LiteRtModel,
29}
30
31pub struct Subgraph<'a> {
34 raw_subgraph: LiteRtSubgraph,
35 _phantom: PhantomData<&'a LiteRtSubgraph>,
36}
37
38pub struct Signature<'a> {
41 raw_signature: LiteRtSignature,
42 _phantom: PhantomData<&'a LiteRtSignature>,
43}
44
/// Selects which list of names an `InputOutputNamesIterator` walks.
enum InputOutputNamesIteratorKind {
    /// Iterate over the signature's input names.
    Input,
    /// Iterate over the signature's output names.
    Output,
}
49
50pub struct InputOutputNamesIterator<'a> {
52 signature: &'a Signature<'a>,
53 index: LiteRtParamIndex,
54 total_num_names: LiteRtParamIndex,
55 kind: InputOutputNamesIteratorKind,
56}
57
58impl Signature<'_> {
59 pub fn key(&self) -> Result<&str, Error> {
61 let mut key: *const c_char = std::ptr::null_mut();
62 call_check_status!(
63 unsafe { LiteRtGetSignatureKey(self.raw_signature, &mut key) },
66 ErrorCause::GetSignatureKey
67 );
68 unsafe { c_str_to_str(key) }
70 }
71
72 pub fn subgraph(&self) -> Result<Subgraph<'_>, Error> {
74 let mut raw_subgraph_ptr: LiteRtSubgraph = std::ptr::null_mut();
75 call_check_status!(
76 unsafe { LiteRtGetSignatureSubgraph(self.raw_signature, &mut raw_subgraph_ptr) },
79 ErrorCause::GetSignatureSubgraph
80 );
81 Ok(Subgraph { raw_subgraph: raw_subgraph_ptr, _phantom: PhantomData {} })
82 }
83
84 pub fn num_inputs(&self) -> Result<LiteRtParamIndex, Error> {
86 let mut num_inputs: LiteRtParamIndex = 0;
87 call_check_status!(
88 unsafe { LiteRtGetNumSignatureInputs(self.raw_signature, &mut num_inputs) },
90 ErrorCause::GetNumSignatureInputs
91 );
92 Ok(num_inputs)
93 }
94
95 pub fn input_names(&self) -> Result<InputOutputNamesIterator<'_>, Error> {
97 let num_inputs = self.num_inputs()?;
98 Ok(InputOutputNamesIterator {
99 signature: self,
100 index: 0,
101 total_num_names: num_inputs,
102 kind: InputOutputNamesIteratorKind::Input,
103 })
104 }
105
106 pub fn num_outputs(&self) -> Result<LiteRtParamIndex, Error> {
108 let mut num_outputs: LiteRtParamIndex = 0;
109 call_check_status!(
110 unsafe { LiteRtGetNumSignatureOutputs(self.raw_signature, &mut num_outputs) },
112 ErrorCause::GetNumSignatureOutputs
113 );
114 Ok(num_outputs)
115 }
116
117 pub fn output_names(&self) -> Result<InputOutputNamesIterator<'_>, Error> {
119 let num_outputs = self.num_outputs()?;
120 Ok(InputOutputNamesIterator {
121 signature: self,
122 index: 0,
123 total_num_names: num_outputs,
124 kind: InputOutputNamesIteratorKind::Output,
125 })
126 }
127}
128
129impl<'a> Iterator for InputOutputNamesIterator<'a> {
130 type Item = Result<&'a str, Error>;
131 fn next(&mut self) -> Option<Self::Item> {
132 if self.index >= self.total_num_names {
133 return None;
134 }
135 let mut name: *const c_char = std::ptr::null_mut();
136 unsafe {
139 let status = match self.kind {
140 InputOutputNamesIteratorKind::Input => {
141 LiteRtGetSignatureInputName(self.signature.raw_signature, self.index, &mut name)
142 }
143 InputOutputNamesIteratorKind::Output => LiteRtGetSignatureOutputName(
144 self.signature.raw_signature,
145 self.index,
146 &mut name,
147 ),
148 };
149 if status != LiteRtStatus_kLiteRtStatusOk {
150 return Some(Err(Error::new(ErrorCause::GetSignatureInputName, status)));
151 }
152 }
153 self.index += 1;
154 let name_str = unsafe { c_str_to_str(name) };
156 match name_str {
157 Ok(name) => Some(Ok(name)),
158 Err(e) => Some(Err(e)),
159 }
160 }
161 fn size_hint(&self) -> (usize, Option<usize>) {
162 (self.total_num_names, Some(self.total_num_names))
163 }
164}
165
166pub struct SignatureIterator<'a> {
168 model: &'a Model,
169 index: LiteRtParamIndex,
170 total_num_signatures: LiteRtParamIndex,
171}
172
173impl<'a> Iterator for SignatureIterator<'a> {
174 type Item = Result<Signature<'a>, Error>;
175 fn next(&mut self) -> Option<Self::Item> {
176 if self.index >= self.total_num_signatures {
177 return None;
178 }
179 let mut raw_signature_ptr: LiteRtSignature = std::ptr::null_mut();
180 unsafe {
184 let status =
185 LiteRtGetModelSignature(self.model.raw_model, self.index, &mut raw_signature_ptr);
186 self.index += 1;
187 if status != LiteRtStatus_kLiteRtStatusOk {
188 return Some(Err(Error::new(ErrorCause::GetSignature, status)));
189 }
190 }
191 Some(Ok(Signature { raw_signature: raw_signature_ptr, _phantom: PhantomData {} }))
192 }
193 fn size_hint(&self) -> (usize, Option<usize>) {
194 (self.total_num_signatures, Some(self.total_num_signatures))
195 }
196}
197
198pub struct Tensor<'a> {
201 raw_tensor: LiteRtTensor,
202 _phantom: PhantomData<&'a LiteRtTensor>,
203}
204
205impl<'a> Tensor<'a> {
206 fn type_id(&self) -> Result<LiteRtTensorTypeId, Error> {
207 let mut raw_tensor_type = LiteRtTensorTypeId_kLiteRtRankedTensorType;
208 call_check_status!(
209 unsafe { LiteRtGetTensorTypeId(self.raw_tensor, &mut raw_tensor_type) },
212 ErrorCause::GetTensorTypeId
213 );
214 Ok(raw_tensor_type)
215 }
216
217 pub fn unranked_tensor_type(&self) -> Result<LiteRtUnrankedTensorType, Error> {
219 let mut raw_tensor_type = LiteRtUnrankedTensorType { element_type: 0 };
220 call_check_status!(
221 unsafe { LiteRtGetUnrankedTensorType(self.raw_tensor, &mut raw_tensor_type) },
224 ErrorCause::GetUnrankedTensorType
225 );
226 Ok(raw_tensor_type)
227 }
228
229 pub fn ranked_tensor_type(&self) -> Result<LiteRtRankedTensorType, Error> {
231 let mut raw_tensor_type = LiteRtRankedTensorType::default();
232 call_check_status!(
233 unsafe { LiteRtGetRankedTensorType(self.raw_tensor, &mut raw_tensor_type) },
236 ErrorCause::GetRankedTensorType
237 );
238 Ok(raw_tensor_type)
239 }
240
241 pub fn element_type(&self) -> Result<ElementType, Error> {
243 match self.type_id()? {
244 LiteRtTensorTypeId_kLiteRtRankedTensorType => {
245 let rtt = self.ranked_tensor_type()?;
246 let rtt_element_type = ElementType::from_c_enum(rtt.element_type)?;
247 Ok(rtt_element_type)
248 }
249 LiteRtTensorTypeId_kLiteRtUnrankedTensorType => {
250 let utt = self.unranked_tensor_type()?;
252 let utt_element_type = ElementType::from_c_enum(utt.element_type)?;
253 Ok(utt_element_type)
254 }
255 _ => Err(Error::new(
256 ErrorCause::InvalidTensorTypeId,
257 LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
258 )),
259 }
260 }
261
262 pub fn name(&self) -> Result<&str, Error> {
264 let mut name: *const c_char = std::ptr::null_mut();
265 call_check_status!(
266 unsafe { LiteRtGetTensorName(self.raw_tensor, &mut name) },
269 ErrorCause::GetTensorName
270 );
271 unsafe { c_str_to_str(name) }
273 }
274}
275
276impl<'a> Subgraph<'a> {
277 pub fn num_inputs(&self) -> Result<LiteRtParamIndex, Error> {
279 let mut num_inputs: LiteRtParamIndex = 0;
280 call_check_status!(
281 unsafe { LiteRtGetNumSubgraphInputs(self.raw_subgraph, &mut num_inputs) },
283 ErrorCause::GetNumSubgraphInputs
284 );
285 Ok(num_inputs)
286 }
287
288 pub fn num_outputs(&self) -> Result<LiteRtParamIndex, Error> {
290 let mut num_outputs: LiteRtParamIndex = 0;
291 call_check_status!(
292 unsafe { LiteRtGetNumSubgraphOutputs(self.raw_subgraph, &mut num_outputs) },
294 ErrorCause::GetNumSubgraphOutputs
295 );
296 Ok(num_outputs)
297 }
298
299 fn input_tensor(&self, tensor_index: LiteRtParamIndex) -> Result<Tensor<'_>, Error> {
300 let mut raw_tensor_ptr: LiteRtTensor = std::ptr::null_mut();
301 call_check_status!(
302 unsafe { LiteRtGetSubgraphInput(self.raw_subgraph, tensor_index, &mut raw_tensor_ptr) },
305 ErrorCause::GetSubgraphInput
306 );
307 Ok(Tensor { raw_tensor: raw_tensor_ptr, _phantom: PhantomData {} })
308 }
309
310 pub fn input_tensor_by_name(&self, tensor_name: &str) -> Result<Tensor<'_>, Error> {
312 let num_inputs = self.num_inputs()?;
313 for i in 0..num_inputs {
314 let tensor = self.input_tensor(i)?;
315 if tensor.name()? == tensor_name {
316 return Ok(tensor);
317 }
318 }
319 return Err(Error::new(
320 ErrorCause::SubgraphInputTensorByNameNotFound,
321 LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
322 ));
323 }
324
325 fn output_tensor(&self, tensor_index: LiteRtParamIndex) -> Result<Tensor<'_>, Error> {
326 let mut raw_tensor_ptr: LiteRtTensor = std::ptr::null_mut();
327 call_check_status!(
328 unsafe {
331 LiteRtGetSubgraphOutput(self.raw_subgraph, tensor_index, &mut raw_tensor_ptr)
332 },
333 ErrorCause::GetSubgraphOutput
334 );
335 Ok(Tensor { raw_tensor: raw_tensor_ptr, _phantom: PhantomData {} })
336 }
337
338 pub fn output_tensor_by_name(&self, tensor_name: &str) -> Result<Tensor<'_>, Error> {
340 let num_outputs = self.num_outputs()?;
341 for i in 0..num_outputs {
342 let tensor = self.output_tensor(i)?;
343 if tensor.name()? == tensor_name {
344 return Ok(tensor);
345 }
346 }
347 return Err(Error::new(
348 ErrorCause::SubgraphOutputTensorByNameNotFound,
349 LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
350 ));
351 }
352}
353
354impl Model {
355 pub fn create_model_from_file(path: &str) -> Result<Self, Error> {
357 let path_c_string =
358 CString::new(path).expect("CString::new failed: string contains null bytes");
359 let c_ptr: *const c_char = path_c_string.as_ptr();
360 let mut raw_model_ptr: LiteRtModel = std::ptr::null_mut();
361 call_check_status!(
362 unsafe { LiteRtCreateModelFromFile(c_ptr, &mut raw_model_ptr) },
364 ErrorCause::CreateModelFromFile
365 );
366 Ok(Model { raw_model: raw_model_ptr })
367 }
368
369 pub fn create_model_from_buffer(buffer: &mut [u8]) -> Result<Self, Error> {
371 let mut raw_model_ptr: LiteRtModel = std::ptr::null_mut();
372 call_check_status!(
373 unsafe {
375 LiteRtCreateModelFromBuffer(
376 buffer.as_ptr() as *const c_void,
377 buffer.len(),
378 &mut raw_model_ptr,
379 )
380 },
381 ErrorCause::CreateModelFromBuffer
382 );
383 Ok(Model { raw_model: raw_model_ptr })
384 }
385
386 pub fn num_subgraphs(&self) -> Result<LiteRtParamIndex, Error> {
388 let mut num_subgraphs: LiteRtParamIndex = 0;
389 call_check_status!(
390 unsafe { LiteRtGetNumModelSubgraphs(self.raw_model, &mut num_subgraphs) },
392 ErrorCause::GetNumModelSubgraphs
393 );
394 Ok(num_subgraphs)
395 }
396
397 pub fn num_signatures(&self) -> Result<LiteRtParamIndex, Error> {
399 let mut num_signatures: LiteRtParamIndex = 0;
400 call_check_status!(
401 unsafe { LiteRtGetNumModelSignatures(self.raw_model, &mut num_signatures) },
404 ErrorCause::GetNumModelSignatures
405 );
406 Ok(num_signatures)
407 }
408
409 pub fn signatures(&self) -> Result<SignatureIterator<'_>, Error> {
411 Ok(SignatureIterator {
412 model: self,
413 index: 0,
414 total_num_signatures: self.num_signatures()?,
415 })
416 }
417
418 pub fn signature(&self, index: LiteRtParamIndex) -> Result<Signature<'_>, Error> {
420 let mut raw_signature_ptr: LiteRtSignature = std::ptr::null_mut();
421 call_check_status!(
422 unsafe { LiteRtGetModelSignature(self.raw_model, index, &mut raw_signature_ptr) },
425 ErrorCause::GetModelSignature
426 );
427 Ok(Signature { raw_signature: raw_signature_ptr, _phantom: PhantomData{} })
428 }
429}
430
431impl Drop for Model {
432 fn drop(&mut self) {
433 unsafe {
436 LiteRtDestroyModel(self.raw_model);
437 }
438 }
439}