1use crate::bindings::*;
18use crate::call_check_status;
19use crate::environment::Environment;
20use crate::error::{Error, ErrorCause};
21use crate::model::{Model, Tensor};
22use crate::tensor_buffer::{TensorBuffer, TensorBufferRequirements};
23
24pub struct Options {
26 raw_options: LiteRtOptions,
27}
28
/// Hardware accelerator selection that can be applied to [`Options`].
///
/// Each variant is translated to one `LiteRtHwAccelerators` constant of the
/// C API by `LiteRtHwAccelerator::to_c_enum`.
///
/// The type is a plain fieldless enum, so it derives the common value traits
/// to let callers print, copy and compare selections.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LiteRtHwAccelerator {
    /// Translated to `kLiteRtHwAcceleratorNone`.
    None,
    /// Translated to `kLiteRtHwAcceleratorCpu`.
    Cpu,
    /// Translated to `kLiteRtHwAcceleratorGpu`.
    Gpu,
    /// Translated to `kLiteRtHwAcceleratorNpu`.
    Npu,
}
36
37impl LiteRtHwAccelerator {
38 pub fn to_c_enum(&self) -> LiteRtHwAccelerators {
39 match self {
40 Self::None => LiteRtHwAccelerators_kLiteRtHwAcceleratorNone,
41 Self::Cpu => LiteRtHwAccelerators_kLiteRtHwAcceleratorCpu,
42 Self::Gpu => LiteRtHwAccelerators_kLiteRtHwAcceleratorGpu,
43 Self::Npu => LiteRtHwAccelerators_kLiteRtHwAcceleratorNpu,
44 }
45 }
46}
47
48impl Options {
49 pub fn default() -> Result<Self, Error> {
51 let mut raw_options_ptr: *mut LiteRtOptionsT = std::ptr::null_mut();
52 call_check_status!(
53 unsafe { LiteRtCreateOptions(&mut raw_options_ptr) },
56 ErrorCause::CreateOptions
57 );
58 Ok(Self {
59 raw_options: raw_options_ptr,
60 })
61 }
62
63 pub fn create_with_accelerator(accelerator: LiteRtHwAccelerator) -> Result<Self, Error> {
65 let accelerator_c_enum = accelerator.to_c_enum();
66 let options = Self::default()?;
67 call_check_status!(
68 unsafe {
71 LiteRtSetOptionsHardwareAccelerators(options.raw_options, accelerator_c_enum)
72 },
73 ErrorCause::SetOptionsHardwareAccelerators
74 );
75 Ok(options)
76 }
77}
78
79impl Drop for Options {
80 fn drop(&mut self) {
81 unsafe {
83 LiteRtDestroyOptions(self.raw_options);
84 }
85 }
86}
87
88pub struct CompiledModel {
90 pub(crate) raw_compiled_model: LiteRtCompiledModel,
91}
92
93impl CompiledModel {
94 pub fn create(
96 environment: &Environment,
97 model: &Model,
98 options: &Options,
99 ) -> Result<Self, Error> {
100 let mut raw_compiled_model_ptr: *mut LiteRtCompiledModelT = std::ptr::null_mut();
101 call_check_status!(
102 unsafe {
104 LiteRtCreateCompiledModel(
105 environment.raw_environment,
106 model.raw_model,
107 options.raw_options,
108 &mut raw_compiled_model_ptr,
109 )
110 },
111 ErrorCause::CreateCompiledModel
112 );
113 Ok(CompiledModel {
114 raw_compiled_model: raw_compiled_model_ptr,
115 })
116 }
117
118 fn input_buffer_requirements(
119 &self,
120 signature_index: LiteRtParamIndex,
121 input_index: LiteRtParamIndex,
122 ) -> Result<TensorBufferRequirements<'_>, Error> {
123 let mut requirements_ptr: *mut LiteRtTensorBufferRequirementsT = std::ptr::null_mut();
124 call_check_status!(
125 unsafe {
127 LiteRtGetCompiledModelInputBufferRequirements(
128 self.raw_compiled_model,
129 signature_index,
130 input_index,
131 &mut requirements_ptr,
132 )
133 },
134 ErrorCause::GetCompiledModelInputBufferRequirements
135 );
136 Ok(TensorBufferRequirements::new(requirements_ptr))
137 }
138
139 fn output_buffer_requirements(
140 &self,
141 signature_index: LiteRtParamIndex,
142 output_index: LiteRtParamIndex,
143 ) -> Result<TensorBufferRequirements<'_>, Error> {
144 let mut requirements_ptr: *mut LiteRtTensorBufferRequirementsT = std::ptr::null_mut();
145 call_check_status!(
146 unsafe {
148 LiteRtGetCompiledModelOutputBufferRequirements(
149 self.raw_compiled_model,
150 signature_index,
151 output_index,
152 &mut requirements_ptr,
153 )
154 },
155 ErrorCause::GetCompiledModelOutputBufferRequirements
156 );
157 Ok(TensorBufferRequirements::new(requirements_ptr))
158 }
159
160 pub fn create_input_tensor_buffers(
162 &self,
163 environment: &Environment,
164 model: &Model,
165 signature_index: LiteRtParamIndex,
166 ) -> Result<Vec<TensorBuffer<'_>>, Error> {
167 let signature = model.signature(signature_index)?;
168 let subgraph = signature.subgraph()?;
169 let mut result = Vec::with_capacity(signature.num_inputs()?);
170 for (i, input_name) in signature.input_names()?.enumerate() {
171 let input_requirements = self.input_buffer_requirements(signature_index, i)?;
172 let tensor = subgraph.input_tensor_by_name(input_name?)?;
173 let buffer =
174 CompiledModel::create_buffer_impl(environment, &input_requirements, &tensor)?;
175 result.push(buffer);
176 }
177 Ok(result)
178 }
179
180 pub fn create_output_tensor_buffers(
182 &self,
183 environment: &Environment,
184 model: &Model,
185 signature_index: LiteRtParamIndex,
186 ) -> Result<Vec<TensorBuffer<'_>>, Error> {
187 let signature = model.signature(signature_index)?;
188 let subgraph = signature.subgraph()?;
189 let mut result = Vec::with_capacity(signature.num_outputs()?);
190 for (i, output_name) in signature.output_names()?.enumerate() {
191 let output_requirements = self.output_buffer_requirements(signature_index, i)?;
192 let tensor = subgraph.output_tensor_by_name(output_name?)?;
193 let buffer =
194 CompiledModel::create_buffer_impl(environment, &output_requirements, &tensor)?;
195 result.push(buffer);
196 }
197 Ok(result)
198 }
199
200 fn create_buffer_impl<'a>(
201 environment: &Environment,
202 requirements: &TensorBufferRequirements,
203 tensor: &Tensor,
204 ) -> Result<TensorBuffer<'a>, Error> {
205 let supported_types = requirements.supported_types()?;
206 let Some(buffer_type) = supported_types.get(0) else {
208 return Err(Error::new(
209 ErrorCause::InputDoesntSupportAnyTensorBufferTypes,
210 LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
211 ));
212 };
213 let tensor_type = tensor.ranked_tensor_type()?;
214 let element_type = tensor.element_type()?;
215 let buffer_size = requirements.buffer_size()?;
216 TensorBuffer::new(
217 environment,
218 &tensor_type,
219 buffer_type,
220 buffer_size,
221 element_type,
222 )
223 }
224
225 pub fn run(
227 &self,
228 signature_index: LiteRtParamIndex,
229 input: &[TensorBuffer<'_>],
230 output: &[TensorBuffer<'_>],
231 ) -> Result<(), Error> {
232 let mut input_ptrs: Vec<_> = input
233 .iter()
234 .map(|tensor| tensor.raw_tensor_buffer)
235 .collect();
236 let mut output_ptrs: Vec<_> = output
237 .iter()
238 .map(|tensor| tensor.raw_tensor_buffer)
239 .collect();
240 call_check_status!(
241 unsafe {
244 LiteRtRunCompiledModel(
245 self.raw_compiled_model,
246 signature_index,
247 input_ptrs.len(),
248 input_ptrs.as_mut_ptr(),
249 output_ptrs.len(),
250 output_ptrs.as_mut_ptr(),
251 )
252 },
253 ErrorCause::RunCompiledModel
254 );
255 Ok(())
256 }
257}