Skip to main content

litert/
compiled_model.rs

// Copyright 2026 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! A compiled model is the result of compiling a model with a specific set of options.
//! It is the object used to run inference on that model.
use crate::bindings::*;
use crate::call_check_status;
use crate::environment::Environment;
use crate::error::{Error, ErrorCause};
use crate::model::{Model, Tensor};
use crate::tensor_buffer::{TensorBuffer, TensorBufferRequirements};

24/// Options for compiling a model.
25pub struct Options {
26    raw_options: LiteRtOptions,
27}
28
/// Hardware accelerators that can be used for inference.
///
/// Use [`LiteRtHwAccelerator::to_c_enum`] to obtain the value expected by the
/// LiteRT C API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LiteRtHwAccelerator {
    /// Maps to `kLiteRtHwAcceleratorNone`.
    None,
    /// Maps to `kLiteRtHwAcceleratorCpu`.
    Cpu,
    /// Maps to `kLiteRtHwAcceleratorGpu`.
    Gpu,
    /// Maps to `kLiteRtHwAcceleratorNpu`.
    Npu,
}

37impl LiteRtHwAccelerator {
38    pub fn to_c_enum(&self) -> LiteRtHwAccelerators {
39        match self {
40            Self::None => LiteRtHwAccelerators_kLiteRtHwAcceleratorNone,
41            Self::Cpu => LiteRtHwAccelerators_kLiteRtHwAcceleratorCpu,
42            Self::Gpu => LiteRtHwAccelerators_kLiteRtHwAcceleratorGpu,
43            Self::Npu => LiteRtHwAccelerators_kLiteRtHwAcceleratorNpu,
44        }
45    }
46}
47
48impl Options {
49    /// Creates a new set of options with default values.
50    pub fn default() -> Result<Self, Error> {
51        let mut raw_options_ptr: *mut LiteRtOptionsT = std::ptr::null_mut();
52        call_check_status!(
53            // SAFETY: The raw_options_ptr is initialized to null_mut before and points to a valid object
54            // after the function call, if it's successful.
55            unsafe { LiteRtCreateOptions(&mut raw_options_ptr) },
56            ErrorCause::CreateOptions
57        );
58        Ok(Self {
59            raw_options: raw_options_ptr,
60        })
61    }
62
63    /// Creates a new set of options with the specified hardware accelerator.
64    pub fn create_with_accelerator(accelerator: LiteRtHwAccelerator) -> Result<Self, Error> {
65        let accelerator_c_enum = accelerator.to_c_enum();
66        let options = Self::default()?;
67        call_check_status!(
68            // SAFETY: options.raw_options is valid because it's created by calling the default() function.
69            // accelerator_c_enum is valid because it's created by calling the to_c_enum() function.
70            unsafe {
71                LiteRtSetOptionsHardwareAccelerators(options.raw_options, accelerator_c_enum)
72            },
73            ErrorCause::SetOptionsHardwareAccelerators
74        );
75        Ok(options)
76    }
77}
78
79impl Drop for Options {
80    fn drop(&mut self) {
81        // SAFETY: self.raw_options is valid because it's created by calling the default() function.
82        unsafe {
83            LiteRtDestroyOptions(self.raw_options);
84        }
85    }
86}
87
88/// A compiled model that can be used to run inference.
89pub struct CompiledModel {
90    pub(crate) raw_compiled_model: LiteRtCompiledModel,
91}
92
93impl CompiledModel {
94    /// Creates a new compiled model.
95    pub fn create(
96        environment: &Environment,
97        model: &Model,
98        options: &Options,
99    ) -> Result<Self, Error> {
100        let mut raw_compiled_model_ptr: *mut LiteRtCompiledModelT = std::ptr::null_mut();
101        call_check_status!(
102            // SAFETY: All input pointers are initalized before and point to valid objects.
103            unsafe {
104                LiteRtCreateCompiledModel(
105                    environment.raw_environment,
106                    model.raw_model,
107                    options.raw_options,
108                    &mut raw_compiled_model_ptr,
109                )
110            },
111            ErrorCause::CreateCompiledModel
112        );
113        Ok(CompiledModel {
114            raw_compiled_model: raw_compiled_model_ptr,
115        })
116    }
117
118    fn input_buffer_requirements(
119        &self,
120        signature_index: LiteRtParamIndex,
121        input_index: LiteRtParamIndex,
122    ) -> Result<TensorBufferRequirements<'_>, Error> {
123        let mut requirements_ptr: *mut LiteRtTensorBufferRequirementsT = std::ptr::null_mut();
124        call_check_status!(
125            // SAFETY: self.raw_compiled_model is valid because it's created by calling the create() function.
126            unsafe {
127                LiteRtGetCompiledModelInputBufferRequirements(
128                    self.raw_compiled_model,
129                    signature_index,
130                    input_index,
131                    &mut requirements_ptr,
132                )
133            },
134            ErrorCause::GetCompiledModelInputBufferRequirements
135        );
136        Ok(TensorBufferRequirements::new(requirements_ptr))
137    }
138
139    fn output_buffer_requirements(
140        &self,
141        signature_index: LiteRtParamIndex,
142        output_index: LiteRtParamIndex,
143    ) -> Result<TensorBufferRequirements<'_>, Error> {
144        let mut requirements_ptr: *mut LiteRtTensorBufferRequirementsT = std::ptr::null_mut();
145        call_check_status!(
146            // SAFETY: self.raw_compiled_model is valid because it's created by calling the create() function.
147            unsafe {
148                LiteRtGetCompiledModelOutputBufferRequirements(
149                    self.raw_compiled_model,
150                    signature_index,
151                    output_index,
152                    &mut requirements_ptr,
153                )
154            },
155            ErrorCause::GetCompiledModelOutputBufferRequirements
156        );
157        Ok(TensorBufferRequirements::new(requirements_ptr))
158    }
159
160    /// Creates a set of input tensor buffers for the specified signature.
161    pub fn create_input_tensor_buffers(
162        &self,
163        environment: &Environment,
164        model: &Model,
165        signature_index: LiteRtParamIndex,
166    ) -> Result<Vec<TensorBuffer<'_>>, Error> {
167        let signature = model.signature(signature_index)?;
168        let subgraph = signature.subgraph()?;
169        let mut result = Vec::with_capacity(signature.num_inputs()?);
170        for (i, input_name) in signature.input_names()?.enumerate() {
171            let input_requirements = self.input_buffer_requirements(signature_index, i)?;
172            let tensor = subgraph.input_tensor_by_name(input_name?)?;
173            let buffer =
174                CompiledModel::create_buffer_impl(environment, &input_requirements, &tensor)?;
175            result.push(buffer);
176        }
177        Ok(result)
178    }
179
180    /// Creates a set of output tensor buffers for the specified signature.
181    pub fn create_output_tensor_buffers(
182        &self,
183        environment: &Environment,
184        model: &Model,
185        signature_index: LiteRtParamIndex,
186    ) -> Result<Vec<TensorBuffer<'_>>, Error> {
187        let signature = model.signature(signature_index)?;
188        let subgraph = signature.subgraph()?;
189        let mut result = Vec::with_capacity(signature.num_outputs()?);
190        for (i, output_name) in signature.output_names()?.enumerate() {
191            let output_requirements = self.output_buffer_requirements(signature_index, i)?;
192            let tensor = subgraph.output_tensor_by_name(output_name?)?;
193            let buffer =
194                CompiledModel::create_buffer_impl(environment, &output_requirements, &tensor)?;
195            result.push(buffer);
196        }
197        Ok(result)
198    }
199
200    fn create_buffer_impl<'a>(
201        environment: &Environment,
202        requirements: &TensorBufferRequirements,
203        tensor: &Tensor,
204    ) -> Result<TensorBuffer<'a>, Error> {
205        let supported_types = requirements.supported_types()?;
206        // For simplicity we just pick the first supported tensor buffer type.
207        let Some(buffer_type) = supported_types.get(0) else {
208            return Err(Error::new(
209                ErrorCause::InputDoesntSupportAnyTensorBufferTypes,
210                LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
211            ));
212        };
213        let tensor_type = tensor.ranked_tensor_type()?;
214        let element_type = tensor.element_type()?;
215        let buffer_size = requirements.buffer_size()?;
216        TensorBuffer::new(
217            environment,
218            &tensor_type,
219            buffer_type,
220            buffer_size,
221            element_type,
222        )
223    }
224
225    /// Runs inference on the compiled model.
226    pub fn run(
227        &self,
228        signature_index: LiteRtParamIndex,
229        input: &[TensorBuffer<'_>],
230        output: &[TensorBuffer<'_>],
231    ) -> Result<(), Error> {
232        let mut input_ptrs: Vec<_> = input
233            .iter()
234            .map(|tensor| tensor.raw_tensor_buffer)
235            .collect();
236        let mut output_ptrs: Vec<_> = output
237            .iter()
238            .map(|tensor| tensor.raw_tensor_buffer)
239            .collect();
240        call_check_status!(
241            // SAFETY: self.raw_compiled_model is valid because it's created by calling the create() function.
242            // input_ptrs and output_ptrs are valid because they are created in the function.
243            unsafe {
244                LiteRtRunCompiledModel(
245                    self.raw_compiled_model,
246                    signature_index,
247                    input_ptrs.len(),
248                    input_ptrs.as_mut_ptr(),
249                    output_ptrs.len(),
250                    output_ptrs.as_mut_ptr(),
251                )
252            },
253            ErrorCause::RunCompiledModel
254        );
255        Ok(())
256    }
257}