1#![allow(non_upper_case_globals)]
16
17use std::any::TypeId;
18use std::ffi::c_void;
19use std::marker::PhantomData;
20use std::mem;
21
22use crate::bindings::*;
23use crate::call_check_status;
24use crate::environment::Environment;
25use crate::error::{Error, ErrorCause};
26
27pub struct TensorBufferRequirements<'a> {
32 raw_requirements: LiteRtTensorBufferRequirements,
33 _phantom: PhantomData<&'a LiteRtTensorBufferRequirements>,
34}
35
36impl<'a> TensorBufferRequirements<'a> {
37 pub(crate) fn new(raw_requirements: LiteRtTensorBufferRequirements) -> Self {
38 Self { raw_requirements: raw_requirements, _phantom: PhantomData {} }
39 }
40
41 pub fn buffer_size(&self) -> Result<usize, Error> {
43 let mut buffer_size: usize = 0;
44 call_check_status!(
45 unsafe {
48 LiteRtGetTensorBufferRequirementsBufferSize(self.raw_requirements, &mut buffer_size)
49 },
50 ErrorCause::GetTensorBufferRequirementsBufferSize
51 );
52 Ok(buffer_size)
53 }
54
55 pub fn supported_types(&self) -> Result<Vec<TensorBufferType>, Error> {
57 let mut num_supported_types: i32 = 0;
58 call_check_status!(
59 unsafe {
62 LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes(
63 self.raw_requirements,
64 &mut num_supported_types,
65 )
66 },
67 ErrorCause::GetNumTensorBufferRequirementsSupportedBufferTypes
68 );
69 let mut result = Vec::with_capacity(num_supported_types as usize);
70 for i in 0..num_supported_types {
71 let mut ttype = LiteRtTensorBufferType_kLiteRtTensorBufferTypeUnknown;
72 call_check_status!(
73 unsafe {
76 LiteRtGetTensorBufferRequirementsSupportedTensorBufferType(
77 self.raw_requirements,
78 i,
79 &mut ttype,
80 )
81 },
82 ErrorCause::GetTensorBufferRequirementsSupportedTensorBufferType
83 );
84 result.push(TensorBufferType::from_c_enum(ttype)?);
85 }
86 Ok(result)
87 }
88}
89
/// Element type of a tensor, mirroring the `LiteRtElementType` C enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElementType {
    None,
    Bool,
    Int4,
    Int8,
    Int16,
    Int32,
    Int64,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    Float16,
    BFloat16,
    Float32,
    Float64,
    Complex64,
    Complex128,
    TfResource,
    TfString,
    TfVariant,
}
114
115impl ElementType {
116 #[allow(dead_code)]
117 pub(crate) fn to_c_enum(&self) -> LiteRtElementType {
118 match self {
119 Self::None => LiteRtElementType_kLiteRtElementTypeNone,
120 Self::Bool => LiteRtElementType_kLiteRtElementTypeBool,
121 Self::Int4 => LiteRtElementType_kLiteRtElementTypeInt4,
122 Self::Int8 => LiteRtElementType_kLiteRtElementTypeInt8,
123 Self::Int16 => LiteRtElementType_kLiteRtElementTypeInt16,
124 Self::Int32 => LiteRtElementType_kLiteRtElementTypeInt32,
125 Self::Int64 => LiteRtElementType_kLiteRtElementTypeInt64,
126 Self::UInt8 => LiteRtElementType_kLiteRtElementTypeUInt8,
127 Self::UInt16 => LiteRtElementType_kLiteRtElementTypeUInt16,
128 Self::UInt32 => LiteRtElementType_kLiteRtElementTypeUInt32,
129 Self::UInt64 => LiteRtElementType_kLiteRtElementTypeUInt64,
130 Self::Float16 => LiteRtElementType_kLiteRtElementTypeFloat16,
131 Self::BFloat16 => LiteRtElementType_kLiteRtElementTypeBFloat16,
132 Self::Float32 => LiteRtElementType_kLiteRtElementTypeFloat32,
133 Self::Float64 => LiteRtElementType_kLiteRtElementTypeFloat64,
134 Self::Complex64 => LiteRtElementType_kLiteRtElementTypeComplex64,
135 Self::Complex128 => LiteRtElementType_kLiteRtElementTypeComplex128,
136 Self::TfResource => LiteRtElementType_kLiteRtElementTypeTfResource,
137 Self::TfString => LiteRtElementType_kLiteRtElementTypeTfString,
138 Self::TfVariant => LiteRtElementType_kLiteRtElementTypeTfVariant,
139 }
140 }
141
142 pub(crate) fn from_c_enum(enum_value: LiteRtElementType) -> Result<ElementType, Error> {
143 match enum_value {
144 LiteRtElementType_kLiteRtElementTypeNone => Ok(Self::None),
145 LiteRtElementType_kLiteRtElementTypeBool => Ok(Self::Bool),
146 LiteRtElementType_kLiteRtElementTypeInt4 => Ok(Self::Int4),
147 LiteRtElementType_kLiteRtElementTypeInt8 => Ok(Self::Int8),
148 LiteRtElementType_kLiteRtElementTypeInt16 => Ok(Self::Int16),
149 LiteRtElementType_kLiteRtElementTypeInt32 => Ok(Self::Int32),
150 LiteRtElementType_kLiteRtElementTypeInt64 => Ok(Self::Int64),
151 LiteRtElementType_kLiteRtElementTypeUInt8 => Ok(Self::UInt8),
152 LiteRtElementType_kLiteRtElementTypeUInt16 => Ok(Self::UInt16),
153 LiteRtElementType_kLiteRtElementTypeUInt32 => Ok(Self::UInt32),
154 LiteRtElementType_kLiteRtElementTypeUInt64 => Ok(Self::UInt64),
155 LiteRtElementType_kLiteRtElementTypeFloat16 => Ok(Self::Float16),
156 LiteRtElementType_kLiteRtElementTypeBFloat16 => Ok(Self::BFloat16),
157 LiteRtElementType_kLiteRtElementTypeFloat32 => Ok(Self::Float32),
158 LiteRtElementType_kLiteRtElementTypeFloat64 => Ok(Self::Float64),
159 LiteRtElementType_kLiteRtElementTypeComplex64 => Ok(Self::Complex64),
160 LiteRtElementType_kLiteRtElementTypeComplex128 => Ok(Self::Complex128),
161 LiteRtElementType_kLiteRtElementTypeTfResource => Ok(Self::TfResource),
162 LiteRtElementType_kLiteRtElementTypeTfString => Ok(Self::TfString),
163 LiteRtElementType_kLiteRtElementTypeTfVariant => Ok(Self::TfVariant),
164 _ => Err(Error::new(
165 ErrorCause::InvalidElementTypeEnumValue,
166 LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
167 )),
168 }
169 }
170
171 fn is_compatible<T: 'static>(self) -> bool {
172 let type_id = TypeId::of::<T>();
173 if TypeId::of::<bool>() == type_id {
174 self == Self::Bool
175 } else if TypeId::of::<i8>() == type_id || TypeId::of::<i8>() == type_id {
176 self == Self::Int8 || self == Self::UInt8
177 } else if TypeId::of::<i16>() == type_id || TypeId::of::<i16>() == type_id {
178 self == Self::Int16 || self == Self::UInt16
179 } else if TypeId::of::<i32>() == type_id || TypeId::of::<i32>() == type_id {
180 self == Self::Int32 || self == Self::UInt32
181 } else if TypeId::of::<f32>() == type_id {
182 self == Self::Float32
183 } else if TypeId::of::<f64>() == type_id {
184 self == Self::Float64
185 } else {
186 false
188 }
189 }
190}
191
/// Backing storage kind of a tensor buffer, mirroring the `LiteRtTensorBufferType` C enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorBufferType {
    Unknown,
    HostMemory,
    Ahwb,
    Ion,
    DmaBuf,
    FastRpc,
    GlBuffer,
    GlTexture,
    OpenClBuffer,
    OpenClBufferFp16,
    OpenClTexture,
    OpenClTextureFp16,
    OpenClBufferPacked,
}
207
208impl TensorBufferType {
209 pub fn to_c_enum(&self) -> LiteRtTensorBufferType {
210 match self {
211 Self::Unknown => LiteRtTensorBufferType_kLiteRtTensorBufferTypeUnknown,
212 Self::HostMemory => LiteRtTensorBufferType_kLiteRtTensorBufferTypeHostMemory,
213 Self::Ahwb => LiteRtTensorBufferType_kLiteRtTensorBufferTypeAhwb,
214 Self::Ion => LiteRtTensorBufferType_kLiteRtTensorBufferTypeIon,
215 Self::DmaBuf => LiteRtTensorBufferType_kLiteRtTensorBufferTypeDmaBuf,
216 Self::FastRpc => LiteRtTensorBufferType_kLiteRtTensorBufferTypeFastRpc,
217 Self::GlBuffer => LiteRtTensorBufferType_kLiteRtTensorBufferTypeGlBuffer,
218 Self::GlTexture => LiteRtTensorBufferType_kLiteRtTensorBufferTypeGlTexture,
219 Self::OpenClBuffer => LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClBuffer,
220 Self::OpenClBufferFp16 => {
221 LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClBufferFp16
222 }
223 Self::OpenClTexture => LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClTexture,
224 Self::OpenClTextureFp16 => {
225 LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClTextureFp16
226 }
227 Self::OpenClBufferPacked => {
228 LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClBufferPacked
229 }
230 }
231 }
232 pub fn from_c_enum(enum_value: LiteRtTensorBufferType) -> Result<TensorBufferType, Error> {
233 match enum_value {
234 LiteRtTensorBufferType_kLiteRtTensorBufferTypeUnknown => Ok(Self::Unknown),
235 LiteRtTensorBufferType_kLiteRtTensorBufferTypeHostMemory => Ok(Self::HostMemory),
236 LiteRtTensorBufferType_kLiteRtTensorBufferTypeAhwb => Ok(Self::Ahwb),
237 LiteRtTensorBufferType_kLiteRtTensorBufferTypeIon => Ok(Self::Ion),
238 LiteRtTensorBufferType_kLiteRtTensorBufferTypeDmaBuf => Ok(Self::DmaBuf),
239 LiteRtTensorBufferType_kLiteRtTensorBufferTypeFastRpc => Ok(Self::FastRpc),
240 LiteRtTensorBufferType_kLiteRtTensorBufferTypeGlBuffer => Ok(Self::GlBuffer),
241 LiteRtTensorBufferType_kLiteRtTensorBufferTypeGlTexture => Ok(Self::GlTexture),
242 LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClBuffer => Ok(Self::OpenClBuffer),
243 LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClBufferFp16 => {
244 Ok(Self::OpenClBufferFp16)
245 }
246 LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClTexture => Ok(Self::OpenClTexture),
247 LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClTextureFp16 => {
248 Ok(Self::OpenClTextureFp16)
249 }
250 LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClBufferPacked => {
251 Ok(Self::OpenClBufferPacked)
252 }
253 _ => Err(Error::new(
254 ErrorCause::InvalidTensorBufferTypeEnumValue,
255 LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
256 )),
257 }
258 }
259}
260
261pub struct TensorBuffer<'a> {
262 pub(crate) raw_tensor_buffer: LiteRtTensorBuffer,
263 element_type: ElementType,
264 _phantom: PhantomData<&'a LiteRtTensorBuffer>,
265}
266
267struct TensorBufferLock<'a, T> {
268 buffer: &'a TensorBuffer<'a>,
269 raw_data: *mut T,
270}
271
272impl<T> Drop for TensorBufferLock<'_, T> {
273 fn drop(&mut self) {
274 unsafe {
277 LiteRtUnlockTensorBuffer(self.buffer.raw_tensor_buffer);
278 }
279 }
280}
281
282impl<'a> TensorBuffer<'a> {
283 pub(crate) fn new(
284 environment: &Environment,
285 tensor_type: *const LiteRtRankedTensorType,
286 buffer_type: &TensorBufferType,
287 buffer_size: usize,
288 element_type: ElementType,
289 ) -> Result<TensorBuffer<'a>, Error> {
290 let mut buffer_ptr: *mut LiteRtTensorBufferT = std::ptr::null_mut();
291 call_check_status!(
292 unsafe {
295 LiteRtCreateManagedTensorBuffer(
296 environment.raw_environment,
297 buffer_type.to_c_enum(),
298 tensor_type,
299 buffer_size,
300 &mut buffer_ptr,
301 )
302 },
303 ErrorCause::CreateManagedTensorBuffer
304 );
305 Ok(TensorBuffer { raw_tensor_buffer: buffer_ptr, element_type, _phantom: PhantomData {} })
306 }
307
308 pub fn element_type(&self) -> ElementType {
310 self.element_type
311 }
312
313 fn lock_read<T>(&'a self) -> Result<TensorBufferLock<'a, T>, Error> {
314 let mut data: *mut c_void = std::ptr::null_mut();
315 call_check_status!(
316 unsafe {
320 LiteRtLockTensorBuffer(
321 self.raw_tensor_buffer,
322 &mut data,
323 LiteRtTensorBufferLockMode_kLiteRtTensorBufferLockModeRead,
324 )
325 },
326 ErrorCause::LockTensorBufferRead
327 );
328 Ok(TensorBufferLock { buffer: self, raw_data: data as *mut T })
329 }
330
331 fn lock_write<T>(&'a self) -> Result<TensorBufferLock<'a, T>, Error> {
332 let mut data: *mut c_void = std::ptr::null_mut();
333 call_check_status!(
334 unsafe {
338 LiteRtLockTensorBuffer(
339 self.raw_tensor_buffer,
340 &mut data,
341 LiteRtTensorBufferLockMode_kLiteRtTensorBufferLockModeWrite,
342 )
343 },
344 ErrorCause::LockTensorBufferWrite
345 );
346 Ok(TensorBufferLock { buffer: self, raw_data: data as *mut T })
347 }
348
349 pub fn packed_size(&self) -> Result<usize, Error> {
351 let mut size: usize = 0;
352 call_check_status!(
353 unsafe { LiteRtGetTensorBufferPackedSize(self.raw_tensor_buffer, &mut size) },
356 ErrorCause::GetTensorBufferPackedSize
357 );
358 Ok(size)
359 }
360
361 pub fn write<T: 'static>(&self, data: &[T]) -> Result<usize, Error> {
368 if !self.element_type.is_compatible::<T>() {
369 return Err(Error::new(
370 ErrorCause::IncompatibleWriteType,
371 LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
372 ));
373 }
374 let lock = self.lock_write()?;
375 let dst_size = self.packed_size()?;
376 let src_size = mem::size_of_val(data);
377 if dst_size < src_size {
378 return Err(Error::new(
379 ErrorCause::TensorBufferTooSmall,
380 LiteRtStatus_kLiteRtStatusErrorRuntimeFailure,
381 ));
382 }
383 unsafe {
391 std::ptr::copy(data.as_ptr(), lock.raw_data, src_size / std::mem::size_of::<T>());
392 }
393
394 Ok(src_size)
395 }
396
397 pub fn read<T: 'static>(&self, data: &mut [T]) -> Result<usize, Error> {
404 if !self.element_type.is_compatible::<T>() {
405 return Err(Error::new(
406 ErrorCause::IncompatibleReadType,
407 LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
408 ));
409 }
410 let lock = self.lock_read()?;
411 let src_size = self.packed_size()?;
412 let dst_size = mem::size_of_val(data);
413 if dst_size < src_size {
414 return Err(Error::new(
415 ErrorCause::TensorBufferTooSmall,
416 LiteRtStatus_kLiteRtStatusErrorRuntimeFailure,
417 ));
418 }
419 let to_copy = std::cmp::min(src_size, dst_size) / std::mem::size_of::<T>();
420 unsafe {
427 std::ptr::copy(lock.raw_data, data.as_mut_ptr(), to_copy);
428 }
429 Ok(to_copy)
430 }
431}
432
433impl Drop for TensorBuffer<'_> {
434 fn drop(&mut self) {
435 unsafe {
438 LiteRtDestroyTensorBuffer(self.raw_tensor_buffer);
439 }
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_element_type_compatibility() {
        assert!(ElementType::Bool.is_compatible::<bool>());
        assert!(!ElementType::Bool.is_compatible::<u32>());
        assert!(!ElementType::Float32.is_compatible::<u32>());
    }

    #[test]
    fn test_float_element_types_match_only_their_width() {
        assert!(ElementType::Float32.is_compatible::<f32>());
        assert!(ElementType::Float64.is_compatible::<f64>());
        assert!(!ElementType::Float32.is_compatible::<f64>());
        assert!(!ElementType::Float64.is_compatible::<f32>());
    }

    #[test]
    fn test_signed_integer_element_types() {
        assert!(ElementType::Int8.is_compatible::<i8>());
        assert!(ElementType::Int16.is_compatible::<i16>());
        assert!(ElementType::Int32.is_compatible::<i32>());
        // `is_compatible` does not enforce signedness.
        assert!(ElementType::UInt32.is_compatible::<i32>());
        // A width mismatch is always rejected.
        assert!(!ElementType::Int32.is_compatible::<i16>());
    }

    #[test]
    fn test_unsupported_rust_types_are_rejected() {
        assert!(!ElementType::TfString.is_compatible::<String>());
        assert!(!ElementType::None.is_compatible::<()>());
    }
}