Skip to main content

litert/
model.rs

1// Copyright 2026 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![allow(non_upper_case_globals)]
16
17use crate::bindings::*;
18use crate::call_check_status;
19use crate::error::{Error, ErrorCause};
20use crate::helper_funs::c_str_to_str;
21use crate::ElementType;
22use std::ffi::{c_char, c_void, CString};
23use std::marker::PhantomData;
24
25/// `Model` is a wrapper around the LiteRtModel C struct.
26/// Usually represents a model loaded from a file.
27pub struct Model {
28    pub(crate) raw_model: LiteRtModel,
29}
30
31/// `Subgraph` is a wrapper around the LiteRtSubgraph C struct.
32/// It represents a subgraph of a model.
33pub struct Subgraph<'a> {
34    raw_subgraph: LiteRtSubgraph,
35    _phantom: PhantomData<&'a LiteRtSubgraph>,
36}
37
38/// `Signature` is a wrapper around the LiteRtSignature C struct.
39/// It represents a signature of a model.
40pub struct Signature<'a> {
41    raw_signature: LiteRtSignature,
42    _phantom: PhantomData<&'a LiteRtSignature>,
43}
44
/// Selects which list of names an [`InputOutputNamesIterator`] walks.
enum InputOutputNamesIteratorKind {
    /// Walk the signature's input names.
    Input,
    /// Walk the signature's output names.
    Output,
}
49
50/// An iterator over the input or output names of a signature.
51pub struct InputOutputNamesIterator<'a> {
52    signature: &'a Signature<'a>,
53    index: LiteRtParamIndex,
54    total_num_names: LiteRtParamIndex,
55    kind: InputOutputNamesIteratorKind,
56}
57
58impl Signature<'_> {
59    /// Returns the key of the signature.
60    pub fn key(&self) -> Result<&str, Error> {
61        let mut key: *const c_char = std::ptr::null_mut();
62        call_check_status!(
63            // SAFETY: self.raw_signature is always valid as it's initialized by a wrapper function.
64            // We assume that the output is valid if the return status is OK or don't use the output pointers.
65            unsafe { LiteRtGetSignatureKey(self.raw_signature, &mut key) },
66            ErrorCause::GetSignatureKey
67        );
68        // SAFETY: We assume that if C API returns OK then the output is valid.
69        unsafe { c_str_to_str(key) }
70    }
71
72    /// Returns the subgraph associated with the signature.
73    pub fn subgraph(&self) -> Result<Subgraph<'_>, Error> {
74        let mut raw_subgraph_ptr: LiteRtSubgraph = std::ptr::null_mut();
75        call_check_status!(
76            // SAFETY: self.raw_signature is always valid as it's initialized by a wrapper function.
77            // We assume that the output is valid if the return status is OK or don't use the output pointers.
78            unsafe { LiteRtGetSignatureSubgraph(self.raw_signature, &mut raw_subgraph_ptr) },
79            ErrorCause::GetSignatureSubgraph
80        );
81        Ok(Subgraph { raw_subgraph: raw_subgraph_ptr, _phantom: PhantomData {} })
82    }
83
84    /// Returns the number of inputs of the signature.
85    pub fn num_inputs(&self) -> Result<LiteRtParamIndex, Error> {
86        let mut num_inputs: LiteRtParamIndex = 0;
87        call_check_status!(
88            // SAFETY: self.raw_signature is always valid as it's initialized by a wrapper function.
89            unsafe { LiteRtGetNumSignatureInputs(self.raw_signature, &mut num_inputs) },
90            ErrorCause::GetNumSignatureInputs
91        );
92        Ok(num_inputs)
93    }
94
95    /// Returns an iterator over the input names of the signature.
96    pub fn input_names(&self) -> Result<InputOutputNamesIterator<'_>, Error> {
97        let num_inputs = self.num_inputs()?;
98        Ok(InputOutputNamesIterator {
99            signature: self,
100            index: 0,
101            total_num_names: num_inputs,
102            kind: InputOutputNamesIteratorKind::Input,
103        })
104    }
105
106    /// Returns the number of outputs of the signature.
107    pub fn num_outputs(&self) -> Result<LiteRtParamIndex, Error> {
108        let mut num_outputs: LiteRtParamIndex = 0;
109        call_check_status!(
110            // SAFETY: self.raw_signature is always valid as it's initialized by a wrapper function.
111            unsafe { LiteRtGetNumSignatureOutputs(self.raw_signature, &mut num_outputs) },
112            ErrorCause::GetNumSignatureOutputs
113        );
114        Ok(num_outputs)
115    }
116
117    /// Returns an iterator over the output names of the signature.
118    pub fn output_names(&self) -> Result<InputOutputNamesIterator<'_>, Error> {
119        let num_outputs = self.num_outputs()?;
120        Ok(InputOutputNamesIterator {
121            signature: self,
122            index: 0,
123            total_num_names: num_outputs,
124            kind: InputOutputNamesIteratorKind::Output,
125        })
126    }
127}
128
129impl<'a> Iterator for InputOutputNamesIterator<'a> {
130    type Item = Result<&'a str, Error>;
131    fn next(&mut self) -> Option<Self::Item> {
132        if self.index >= self.total_num_names {
133            return None;
134        }
135        let mut name: *const c_char = std::ptr::null_mut();
136        // SAFETY: self.raw_signature is always valid as it's initialized by a wrapper function.
137        // We assume that the output is valid if the return status is OK or don't use the output pointers.
138        unsafe {
139            let status = match self.kind {
140                InputOutputNamesIteratorKind::Input => {
141                    LiteRtGetSignatureInputName(self.signature.raw_signature, self.index, &mut name)
142                }
143                InputOutputNamesIteratorKind::Output => LiteRtGetSignatureOutputName(
144                    self.signature.raw_signature,
145                    self.index,
146                    &mut name,
147                ),
148            };
149            if status != LiteRtStatus_kLiteRtStatusOk {
150                return Some(Err(Error::new(ErrorCause::GetSignatureInputName, status)));
151            }
152        }
153        self.index += 1;
154        // SAFETY: We assume that if C API returns OK then the output is valid.
155        let name_str = unsafe { c_str_to_str(name) };
156        match name_str {
157            Ok(name) => Some(Ok(name)),
158            Err(e) => Some(Err(e)),
159        }
160    }
161    fn size_hint(&self) -> (usize, Option<usize>) {
162        (self.total_num_names, Some(self.total_num_names))
163    }
164}
165
166/// An iterator over the signatures of a model.
167pub struct SignatureIterator<'a> {
168    model: &'a Model,
169    index: LiteRtParamIndex,
170    total_num_signatures: LiteRtParamIndex,
171}
172
173impl<'a> Iterator for SignatureIterator<'a> {
174    type Item = Result<Signature<'a>, Error>;
175    fn next(&mut self) -> Option<Self::Item> {
176        if self.index >= self.total_num_signatures {
177            return None;
178        }
179        let mut raw_signature_ptr: LiteRtSignature = std::ptr::null_mut();
180        // SAFETY: self.model.raw_model is always valid as it's initialized by a wrapper function.
181        // self.index is always valid, it is explicitly limited to the valid range.
182        // We assume that the output is valid if the return status is OK or don't use the output pointers.
183        unsafe {
184            let status =
185                LiteRtGetModelSignature(self.model.raw_model, self.index, &mut raw_signature_ptr);
186            self.index += 1;
187            if status != LiteRtStatus_kLiteRtStatusOk {
188                return Some(Err(Error::new(ErrorCause::GetSignature, status)));
189            }
190        }
191        Some(Ok(Signature { raw_signature: raw_signature_ptr, _phantom: PhantomData {} }))
192    }
193    fn size_hint(&self) -> (usize, Option<usize>) {
194        (self.total_num_signatures, Some(self.total_num_signatures))
195    }
196}
197
198/// `Tensor` is a wrapper around the LiteRtTensor C struct.
199/// It represents a tensor in a model.
200pub struct Tensor<'a> {
201    raw_tensor: LiteRtTensor,
202    _phantom: PhantomData<&'a LiteRtTensor>,
203}
204
205impl<'a> Tensor<'a> {
206    fn type_id(&self) -> Result<LiteRtTensorTypeId, Error> {
207        let mut raw_tensor_type = LiteRtTensorTypeId_kLiteRtRankedTensorType;
208        call_check_status!(
209            // SAFETY: self.raw_tensor is always valid as it's initialized by a wrapper function.
210            // We assume that the output is valid if the return status is OK or don't use the output pointer.
211            unsafe { LiteRtGetTensorTypeId(self.raw_tensor, &mut raw_tensor_type) },
212            ErrorCause::GetTensorTypeId
213        );
214        Ok(raw_tensor_type)
215    }
216
217    /// Returns the unranked tensor type.
218    pub fn unranked_tensor_type(&self) -> Result<LiteRtUnrankedTensorType, Error> {
219        let mut raw_tensor_type = LiteRtUnrankedTensorType { element_type: 0 };
220        call_check_status!(
221            // SAFETY: self.raw_tensor is always valid as it's initialized by a wrapper function.
222            // We assume that the output is valid if the return status is OK or don't use the output pointer.
223            unsafe { LiteRtGetUnrankedTensorType(self.raw_tensor, &mut raw_tensor_type) },
224            ErrorCause::GetUnrankedTensorType
225        );
226        Ok(raw_tensor_type)
227    }
228
229    /// Returns the ranked tensor type.
230    pub fn ranked_tensor_type(&self) -> Result<LiteRtRankedTensorType, Error> {
231        let mut raw_tensor_type = LiteRtRankedTensorType::default();
232        call_check_status!(
233            // SAFETY: self.raw_tensor is always valid as it's initialized by a wrapper function.
234            // We assume that the output is valid if the return status is OK or don't use the output pointer.
235            unsafe { LiteRtGetRankedTensorType(self.raw_tensor, &mut raw_tensor_type) },
236            ErrorCause::GetRankedTensorType
237        );
238        Ok(raw_tensor_type)
239    }
240
241    /// Returns the element type of the tensor.
242    pub fn element_type(&self) -> Result<ElementType, Error> {
243        match self.type_id()? {
244            LiteRtTensorTypeId_kLiteRtRankedTensorType => {
245                let rtt = self.ranked_tensor_type()?;
246                let rtt_element_type = ElementType::from_c_enum(rtt.element_type)?;
247                Ok(rtt_element_type)
248            }
249            LiteRtTensorTypeId_kLiteRtUnrankedTensorType => {
250                // TODO(mgubin): Add support for layout.
251                let utt = self.unranked_tensor_type()?;
252                let utt_element_type = ElementType::from_c_enum(utt.element_type)?;
253                Ok(utt_element_type)
254            }
255            _ => Err(Error::new(
256                ErrorCause::InvalidTensorTypeId,
257                LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
258            )),
259        }
260    }
261
262    /// Returns the name of the tensor.
263    pub fn name(&self) -> Result<&str, Error> {
264        let mut name: *const c_char = std::ptr::null_mut();
265        call_check_status!(
266            // SAFETY: self.raw_tensor is always valid as it's initialized by a wrapper function.
267            // We assume that the output is valid if the return status is OK or don't use the output pointer.
268            unsafe { LiteRtGetTensorName(self.raw_tensor, &mut name) },
269            ErrorCause::GetTensorName
270        );
271        // SAFETY: We assume that if C API returns OK then the output is valid.
272        unsafe { c_str_to_str(name) }
273    }
274}
275
276impl<'a> Subgraph<'a> {
277    /// Returns the number of inputs of the subgraph.
278    pub fn num_inputs(&self) -> Result<LiteRtParamIndex, Error> {
279        let mut num_inputs: LiteRtParamIndex = 0;
280        call_check_status!(
281            // SAFETY: self.raw_subgraph is always valid as it's initialized by a wrapper function.
282            unsafe { LiteRtGetNumSubgraphInputs(self.raw_subgraph, &mut num_inputs) },
283            ErrorCause::GetNumSubgraphInputs
284        );
285        Ok(num_inputs)
286    }
287
288    /// Returns the number of outputs of the subgraph.
289    pub fn num_outputs(&self) -> Result<LiteRtParamIndex, Error> {
290        let mut num_outputs: LiteRtParamIndex = 0;
291        call_check_status!(
292            // SAFETY: self.raw_subgraph is always valid as it's initialized by a wrapper function.
293            unsafe { LiteRtGetNumSubgraphOutputs(self.raw_subgraph, &mut num_outputs) },
294            ErrorCause::GetNumSubgraphOutputs
295        );
296        Ok(num_outputs)
297    }
298
299    fn input_tensor(&self, tensor_index: LiteRtParamIndex) -> Result<Tensor<'_>, Error> {
300        let mut raw_tensor_ptr: LiteRtTensor = std::ptr::null_mut();
301        call_check_status!(
302            // SAFETY: self.raw_subgraph is always valid as it's initialized by a wrapper function.
303            // We assume that the output is valid if the return status is OK or don't use the output pointers.
304            unsafe { LiteRtGetSubgraphInput(self.raw_subgraph, tensor_index, &mut raw_tensor_ptr) },
305            ErrorCause::GetSubgraphInput
306        );
307        Ok(Tensor { raw_tensor: raw_tensor_ptr, _phantom: PhantomData {} })
308    }
309
310    /// Returns the input tensor with the given name.
311    pub fn input_tensor_by_name(&self, tensor_name: &str) -> Result<Tensor<'_>, Error> {
312        let num_inputs = self.num_inputs()?;
313        for i in 0..num_inputs {
314            let tensor = self.input_tensor(i)?;
315            if tensor.name()? == tensor_name {
316                return Ok(tensor);
317            }
318        }
319        return Err(Error::new(
320            ErrorCause::SubgraphInputTensorByNameNotFound,
321            LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
322        ));
323    }
324
325    fn output_tensor(&self, tensor_index: LiteRtParamIndex) -> Result<Tensor<'_>, Error> {
326        let mut raw_tensor_ptr: LiteRtTensor = std::ptr::null_mut();
327        call_check_status!(
328            // SAFETY: self.raw_subgraph is always valid as it's initialized by a wrapper function.
329            // We assume that the output is valid if the return status is OK or don't use the output pointer.
330            unsafe {
331                LiteRtGetSubgraphOutput(self.raw_subgraph, tensor_index, &mut raw_tensor_ptr)
332            },
333            ErrorCause::GetSubgraphOutput
334        );
335        Ok(Tensor { raw_tensor: raw_tensor_ptr, _phantom: PhantomData {} })
336    }
337
338    /// Returns the output tensor with the given name.
339    pub fn output_tensor_by_name(&self, tensor_name: &str) -> Result<Tensor<'_>, Error> {
340        let num_outputs = self.num_outputs()?;
341        for i in 0..num_outputs {
342            let tensor = self.output_tensor(i)?;
343            if tensor.name()? == tensor_name {
344                return Ok(tensor);
345            }
346        }
347        return Err(Error::new(
348            ErrorCause::SubgraphOutputTensorByNameNotFound,
349            LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
350        ));
351    }
352}
353
354impl Model {
355    /// Creates a model from a file path.
356    pub fn create_model_from_file(path: &str) -> Result<Self, Error> {
357        let path_c_string =
358            CString::new(path).expect("CString::new failed: string contains null bytes");
359        let c_ptr: *const c_char = path_c_string.as_ptr();
360        let mut raw_model_ptr: LiteRtModel = std::ptr::null_mut();
361        call_check_status!(
362            // SAFETY: c_ptr is a valid pointer to the memory buffer provided by safe Rust code.
363            unsafe { LiteRtCreateModelFromFile(c_ptr, &mut raw_model_ptr) },
364            ErrorCause::CreateModelFromFile
365        );
366        Ok(Model { raw_model: raw_model_ptr })
367    }
368
369    /// Creates a model from a memory buffer.
370    pub fn create_model_from_buffer(buffer: &mut [u8]) -> Result<Self, Error> {
371        let mut raw_model_ptr: LiteRtModel = std::ptr::null_mut();
372        call_check_status!(
373            // SAFETY: buffer is a valid pointer to the memory buffer provided by safe Rust code.
374            unsafe {
375                LiteRtCreateModelFromBuffer(
376                    buffer.as_ptr() as *const c_void,
377                    buffer.len(),
378                    &mut raw_model_ptr,
379                )
380            },
381            ErrorCause::CreateModelFromBuffer
382        );
383        Ok(Model { raw_model: raw_model_ptr })
384    }
385
386    /// Returns the number of subgraphs in the model.
387    pub fn num_subgraphs(&self) -> Result<LiteRtParamIndex, Error> {
388        let mut num_subgraphs: LiteRtParamIndex = 0;
389        call_check_status!(
390            // SAFETY: self.raw_model is always valid as it's initialized by a wrapper function.
391            unsafe { LiteRtGetNumModelSubgraphs(self.raw_model, &mut num_subgraphs) },
392            ErrorCause::GetNumModelSubgraphs
393        );
394        Ok(num_subgraphs)
395    }
396
397    /// Returns the number of signatures in the model.
398    pub fn num_signatures(&self) -> Result<LiteRtParamIndex, Error> {
399        let mut num_signatures: LiteRtParamIndex = 0;
400        call_check_status!(
401            // SAFETY: self.raw_model is always valid as it's initialized by a wrapper function.
402            // We assume that the output is valid if the return status is OK or don't use the output pointer.
403            unsafe { LiteRtGetNumModelSignatures(self.raw_model, &mut num_signatures) },
404            ErrorCause::GetNumModelSignatures
405        );
406        Ok(num_signatures)
407    }
408
409    /// Returns an iterator over the signatures of the model.
410    pub fn signatures(&self) -> Result<SignatureIterator<'_>, Error> {
411        Ok(SignatureIterator {
412            model: self,
413            index: 0,
414            total_num_signatures: self.num_signatures()?,
415        })
416    }
417
418    /// Returns the signature at the given index.
419    pub fn signature(&self, index: LiteRtParamIndex) -> Result<Signature<'_>, Error> {
420        let mut raw_signature_ptr: LiteRtSignature = std::ptr::null_mut();
421        call_check_status!(
422            // SAFETY: self.raw_model is always valid as it's initialized by a wrapper function.
423            // We assume that the output is valid if the return status is OK or don't use the output pointers.
424            unsafe { LiteRtGetModelSignature(self.raw_model, index, &mut raw_signature_ptr) },
425            ErrorCause::GetModelSignature
426        );
427        Ok(Signature { raw_signature: raw_signature_ptr, _phantom: PhantomData{} })
428    }
429}
430
431impl Drop for Model {
432    fn drop(&mut self) {
433        // SAFETY: self.raw_model is always valid, it's guaranteed to be initialized by
434        // create* function.
435        unsafe {
436            LiteRtDestroyModel(self.raw_model);
437        }
438    }
439}