Skip to main content

litert/
tensor_buffer.rs

1// Copyright 2026 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![allow(non_upper_case_globals)]
16
17use std::any::TypeId;
18use std::ffi::c_void;
19use std::marker::PhantomData;
20use std::mem;
21
22use crate::bindings::*;
23use crate::call_check_status;
24use crate::environment::Environment;
25use crate::error::{Error, ErrorCause};
26
27/// Requirements for a tensor buffer.
28///
29/// This struct represents the requirements for a tensor buffer. It is used to determine the
30/// supported buffer types and the buffer size.
31pub struct TensorBufferRequirements<'a> {
32    raw_requirements: LiteRtTensorBufferRequirements,
33    _phantom: PhantomData<&'a LiteRtTensorBufferRequirements>,
34}
35
36impl<'a> TensorBufferRequirements<'a> {
37    pub(crate) fn new(raw_requirements: LiteRtTensorBufferRequirements) -> Self {
38        Self { raw_requirements: raw_requirements, _phantom: PhantomData {} }
39    }
40
41    /// Returns the size of the buffer in bytes.
42    pub fn buffer_size(&self) -> Result<usize, Error> {
43        let mut buffer_size: usize = 0;
44        call_check_status!(
45            // SAFETY: self.raw_requirements is always valid, it's guaranteed to be initialized by
46            // a wrapper function.
47            unsafe {
48                LiteRtGetTensorBufferRequirementsBufferSize(self.raw_requirements, &mut buffer_size)
49            },
50            ErrorCause::GetTensorBufferRequirementsBufferSize
51        );
52        Ok(buffer_size)
53    }
54
55    /// Returns the supported tensor buffer types.
56    pub fn supported_types(&self) -> Result<Vec<TensorBufferType>, Error> {
57        let mut num_supported_types: i32 = 0;
58        call_check_status!(
59            // SAFETY: self.raw_requirements is always valid, it's guaranteed to be initialized by
60            // a wrapper function.
61            unsafe {
62                LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes(
63                    self.raw_requirements,
64                    &mut num_supported_types,
65                )
66            },
67            ErrorCause::GetNumTensorBufferRequirementsSupportedBufferTypes
68        );
69        let mut result = Vec::with_capacity(num_supported_types as usize);
70        for i in 0..num_supported_types {
71            let mut ttype = LiteRtTensorBufferType_kLiteRtTensorBufferTypeUnknown;
72            call_check_status!(
73                // SAFETY: self.raw_requirements is always valid, it's guaranteed to be initialized by
74                // a wrapper function.
75                unsafe {
76                    LiteRtGetTensorBufferRequirementsSupportedTensorBufferType(
77                        self.raw_requirements,
78                        i,
79                        &mut ttype,
80                    )
81                },
82                ErrorCause::GetTensorBufferRequirementsSupportedTensorBufferType
83            );
84            result.push(TensorBufferType::from_c_enum(ttype)?);
85        }
86        Ok(result)
87    }
88}
89
/// The element type of a tensor buffer.
///
/// Mirrors LiteRT's `LiteRtElementType` C enum. The crate-internal `to_c_enum` and
/// `from_c_enum` convert between the two.
// `Eq` and `Hash` are derived so element types can be used as set members or map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElementType {
    None,
    Bool,
    Int4,
    Int8,
    Int16,
    Int32,
    Int64,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    Float16,
    BFloat16,
    Float32,
    Float64,
    Complex64,
    Complex128,
    TfResource,
    TfString,
    TfVariant,
}
114
115impl ElementType {
116    #[allow(dead_code)]
117    pub(crate) fn to_c_enum(&self) -> LiteRtElementType {
118        match self {
119            Self::None => LiteRtElementType_kLiteRtElementTypeNone,
120            Self::Bool => LiteRtElementType_kLiteRtElementTypeBool,
121            Self::Int4 => LiteRtElementType_kLiteRtElementTypeInt4,
122            Self::Int8 => LiteRtElementType_kLiteRtElementTypeInt8,
123            Self::Int16 => LiteRtElementType_kLiteRtElementTypeInt16,
124            Self::Int32 => LiteRtElementType_kLiteRtElementTypeInt32,
125            Self::Int64 => LiteRtElementType_kLiteRtElementTypeInt64,
126            Self::UInt8 => LiteRtElementType_kLiteRtElementTypeUInt8,
127            Self::UInt16 => LiteRtElementType_kLiteRtElementTypeUInt16,
128            Self::UInt32 => LiteRtElementType_kLiteRtElementTypeUInt32,
129            Self::UInt64 => LiteRtElementType_kLiteRtElementTypeUInt64,
130            Self::Float16 => LiteRtElementType_kLiteRtElementTypeFloat16,
131            Self::BFloat16 => LiteRtElementType_kLiteRtElementTypeBFloat16,
132            Self::Float32 => LiteRtElementType_kLiteRtElementTypeFloat32,
133            Self::Float64 => LiteRtElementType_kLiteRtElementTypeFloat64,
134            Self::Complex64 => LiteRtElementType_kLiteRtElementTypeComplex64,
135            Self::Complex128 => LiteRtElementType_kLiteRtElementTypeComplex128,
136            Self::TfResource => LiteRtElementType_kLiteRtElementTypeTfResource,
137            Self::TfString => LiteRtElementType_kLiteRtElementTypeTfString,
138            Self::TfVariant => LiteRtElementType_kLiteRtElementTypeTfVariant,
139        }
140    }
141
142    pub(crate) fn from_c_enum(enum_value: LiteRtElementType) -> Result<ElementType, Error> {
143        match enum_value {
144            LiteRtElementType_kLiteRtElementTypeNone => Ok(Self::None),
145            LiteRtElementType_kLiteRtElementTypeBool => Ok(Self::Bool),
146            LiteRtElementType_kLiteRtElementTypeInt4 => Ok(Self::Int4),
147            LiteRtElementType_kLiteRtElementTypeInt8 => Ok(Self::Int8),
148            LiteRtElementType_kLiteRtElementTypeInt16 => Ok(Self::Int16),
149            LiteRtElementType_kLiteRtElementTypeInt32 => Ok(Self::Int32),
150            LiteRtElementType_kLiteRtElementTypeInt64 => Ok(Self::Int64),
151            LiteRtElementType_kLiteRtElementTypeUInt8 => Ok(Self::UInt8),
152            LiteRtElementType_kLiteRtElementTypeUInt16 => Ok(Self::UInt16),
153            LiteRtElementType_kLiteRtElementTypeUInt32 => Ok(Self::UInt32),
154            LiteRtElementType_kLiteRtElementTypeUInt64 => Ok(Self::UInt64),
155            LiteRtElementType_kLiteRtElementTypeFloat16 => Ok(Self::Float16),
156            LiteRtElementType_kLiteRtElementTypeBFloat16 => Ok(Self::BFloat16),
157            LiteRtElementType_kLiteRtElementTypeFloat32 => Ok(Self::Float32),
158            LiteRtElementType_kLiteRtElementTypeFloat64 => Ok(Self::Float64),
159            LiteRtElementType_kLiteRtElementTypeComplex64 => Ok(Self::Complex64),
160            LiteRtElementType_kLiteRtElementTypeComplex128 => Ok(Self::Complex128),
161            LiteRtElementType_kLiteRtElementTypeTfResource => Ok(Self::TfResource),
162            LiteRtElementType_kLiteRtElementTypeTfString => Ok(Self::TfString),
163            LiteRtElementType_kLiteRtElementTypeTfVariant => Ok(Self::TfVariant),
164            _ => Err(Error::new(
165                ErrorCause::InvalidElementTypeEnumValue,
166                LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
167            )),
168        }
169    }
170
171    fn is_compatible<T: 'static>(self) -> bool {
172        let type_id = TypeId::of::<T>();
173        if TypeId::of::<bool>() == type_id {
174            self == Self::Bool
175        } else if TypeId::of::<i8>() == type_id || TypeId::of::<i8>() == type_id {
176            self == Self::Int8 || self == Self::UInt8
177        } else if TypeId::of::<i16>() == type_id || TypeId::of::<i16>() == type_id {
178            self == Self::Int16 || self == Self::UInt16
179        } else if TypeId::of::<i32>() == type_id || TypeId::of::<i32>() == type_id {
180            self == Self::Int32 || self == Self::UInt32
181        } else if TypeId::of::<f32>() == type_id {
182            self == Self::Float32
183        } else if TypeId::of::<f64>() == type_id {
184            self == Self::Float64
185        } else {
186            // TODO: Add support for other types.
187            false
188        }
189    }
190}
191
/// The type of a tensor buffer.
///
/// Passed to LiteRT when a managed tensor buffer is created, and reported back by
/// `TensorBufferRequirements::supported_types`. Each variant maps onto one
/// `LiteRtTensorBufferType` C value via [`TensorBufferType::to_c_enum`] and
/// [`TensorBufferType::from_c_enum`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorBufferType {
    Unknown,
    HostMemory,
    Ahwb,
    Ion,
    DmaBuf,
    FastRpc,
    GlBuffer,
    GlTexture,
    OpenClBuffer,
    OpenClBufferFp16,
    OpenClTexture,
    OpenClTextureFp16,
    OpenClBufferPacked,
}
207
208impl TensorBufferType {
209    pub fn to_c_enum(&self) -> LiteRtTensorBufferType {
210        match self {
211            Self::Unknown => LiteRtTensorBufferType_kLiteRtTensorBufferTypeUnknown,
212            Self::HostMemory => LiteRtTensorBufferType_kLiteRtTensorBufferTypeHostMemory,
213            Self::Ahwb => LiteRtTensorBufferType_kLiteRtTensorBufferTypeAhwb,
214            Self::Ion => LiteRtTensorBufferType_kLiteRtTensorBufferTypeIon,
215            Self::DmaBuf => LiteRtTensorBufferType_kLiteRtTensorBufferTypeDmaBuf,
216            Self::FastRpc => LiteRtTensorBufferType_kLiteRtTensorBufferTypeFastRpc,
217            Self::GlBuffer => LiteRtTensorBufferType_kLiteRtTensorBufferTypeGlBuffer,
218            Self::GlTexture => LiteRtTensorBufferType_kLiteRtTensorBufferTypeGlTexture,
219            Self::OpenClBuffer => LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClBuffer,
220            Self::OpenClBufferFp16 => {
221                LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClBufferFp16
222            }
223            Self::OpenClTexture => LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClTexture,
224            Self::OpenClTextureFp16 => {
225                LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClTextureFp16
226            }
227            Self::OpenClBufferPacked => {
228                LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClBufferPacked
229            }
230        }
231    }
232    pub fn from_c_enum(enum_value: LiteRtTensorBufferType) -> Result<TensorBufferType, Error> {
233        match enum_value {
234            LiteRtTensorBufferType_kLiteRtTensorBufferTypeUnknown => Ok(Self::Unknown),
235            LiteRtTensorBufferType_kLiteRtTensorBufferTypeHostMemory => Ok(Self::HostMemory),
236            LiteRtTensorBufferType_kLiteRtTensorBufferTypeAhwb => Ok(Self::Ahwb),
237            LiteRtTensorBufferType_kLiteRtTensorBufferTypeIon => Ok(Self::Ion),
238            LiteRtTensorBufferType_kLiteRtTensorBufferTypeDmaBuf => Ok(Self::DmaBuf),
239            LiteRtTensorBufferType_kLiteRtTensorBufferTypeFastRpc => Ok(Self::FastRpc),
240            LiteRtTensorBufferType_kLiteRtTensorBufferTypeGlBuffer => Ok(Self::GlBuffer),
241            LiteRtTensorBufferType_kLiteRtTensorBufferTypeGlTexture => Ok(Self::GlTexture),
242            LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClBuffer => Ok(Self::OpenClBuffer),
243            LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClBufferFp16 => {
244                Ok(Self::OpenClBufferFp16)
245            }
246            LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClTexture => Ok(Self::OpenClTexture),
247            LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClTextureFp16 => {
248                Ok(Self::OpenClTextureFp16)
249            }
250            LiteRtTensorBufferType_kLiteRtTensorBufferTypeOpenClBufferPacked => {
251                Ok(Self::OpenClBufferPacked)
252            }
253            _ => Err(Error::new(
254                ErrorCause::InvalidTensorBufferTypeEnumValue,
255                LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
256            )),
257        }
258    }
259}
260
261pub struct TensorBuffer<'a> {
262    pub(crate) raw_tensor_buffer: LiteRtTensorBuffer,
263    element_type: ElementType,
264    _phantom: PhantomData<&'a LiteRtTensorBuffer>,
265}
266
267struct TensorBufferLock<'a, T> {
268    buffer: &'a TensorBuffer<'a>,
269    raw_data: *mut T,
270}
271
272impl<T> Drop for TensorBufferLock<'_, T> {
273    fn drop(&mut self) {
274        // SAFETY: self.buffer.raw_tensor_buffer is always valid, it's guaranteed to be initialized by
275        // a wrapper function.
276        unsafe {
277            LiteRtUnlockTensorBuffer(self.buffer.raw_tensor_buffer);
278        }
279    }
280}
281
282impl<'a> TensorBuffer<'a> {
283    pub(crate) fn new(
284        environment: &Environment,
285        tensor_type: *const LiteRtRankedTensorType,
286        buffer_type: &TensorBufferType,
287        buffer_size: usize,
288        element_type: ElementType,
289    ) -> Result<TensorBuffer<'a>, Error> {
290        let mut buffer_ptr: *mut LiteRtTensorBufferT = std::ptr::null_mut();
291        call_check_status!(
292            // SAFETY: environment.raw_environment is always valid, it's guaranteed to be initialized by
293            // a wrapper function.
294            unsafe {
295                LiteRtCreateManagedTensorBuffer(
296                    environment.raw_environment,
297                    buffer_type.to_c_enum(),
298                    tensor_type,
299                    buffer_size,
300                    &mut buffer_ptr,
301                )
302            },
303            ErrorCause::CreateManagedTensorBuffer
304        );
305        Ok(TensorBuffer { raw_tensor_buffer: buffer_ptr, element_type, _phantom: PhantomData {} })
306    }
307
308    /// Returns the element type of the tensor buffer.
309    pub fn element_type(&self) -> ElementType {
310        self.element_type
311    }
312
313    fn lock_read<T>(&'a self) -> Result<TensorBufferLock<'a, T>, Error> {
314        let mut data: *mut c_void = std::ptr::null_mut();
315        call_check_status!(
316            // SAFETY: self.raw_tensor_buffer is always valid, it's guaranteed to be initialized by
317            // a wrapper function.
318            // We assume that the output is valid if the return status is OK or don't use the output pointer.
319            unsafe {
320                LiteRtLockTensorBuffer(
321                    self.raw_tensor_buffer,
322                    &mut data,
323                    LiteRtTensorBufferLockMode_kLiteRtTensorBufferLockModeRead,
324                )
325            },
326            ErrorCause::LockTensorBufferRead
327        );
328        Ok(TensorBufferLock { buffer: self, raw_data: data as *mut T })
329    }
330
331    fn lock_write<T>(&'a self) -> Result<TensorBufferLock<'a, T>, Error> {
332        let mut data: *mut c_void = std::ptr::null_mut();
333        call_check_status!(
334            // SAFETY: self.raw_tensor_buffer is always valid, it's guaranteed to be initialized by
335            // a wrapper function.
336            // We assume that the output is valid if the return status is OK or don't use the output pointer.
337            unsafe {
338                LiteRtLockTensorBuffer(
339                    self.raw_tensor_buffer,
340                    &mut data,
341                    LiteRtTensorBufferLockMode_kLiteRtTensorBufferLockModeWrite,
342                )
343            },
344            ErrorCause::LockTensorBufferWrite
345        );
346        Ok(TensorBufferLock { buffer: self, raw_data: data as *mut T })
347    }
348
349    /// Returns the size of the tensor buffer in bytes.
350    pub fn packed_size(&self) -> Result<usize, Error> {
351        let mut size: usize = 0;
352        call_check_status!(
353            // SAFETY: self.raw_tensor_buffer is always valid, it's guaranteed to be initialized by
354            // a wrapper function.
355            unsafe { LiteRtGetTensorBufferPackedSize(self.raw_tensor_buffer, &mut size) },
356            ErrorCause::GetTensorBufferPackedSize
357        );
358        Ok(size)
359    }
360
361    /// Writes data to the tensor buffer.
362    ///
363    /// The data must be compatible with the element type of the tensor buffer.
364    /// The data must be big enough to fill the tensor buffer.
365    ///
366    /// Returns the number of bytes written to the tensor buffer.
367    pub fn write<T: 'static>(&self, data: &[T]) -> Result<usize, Error> {
368        if !self.element_type.is_compatible::<T>() {
369            return Err(Error::new(
370                ErrorCause::IncompatibleWriteType,
371                LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
372            ));
373        }
374        let lock = self.lock_write()?;
375        let dst_size = self.packed_size()?;
376        let src_size = mem::size_of_val(data);
377        if dst_size < src_size {
378            return Err(Error::new(
379                ErrorCause::TensorBufferTooSmall,
380                LiteRtStatus_kLiteRtStatusErrorRuntimeFailure,
381            ));
382        }
383        // TODO(mgubin): Do something when input data is smaller that the tensor buffer.
384        // SAFETY: lock.raw_data is always valid, it's guaranteed to be initialized by
385        // lock_write function.
386        // data is a pointer to the start of the data buffer, it's valid as provided by safe
387        // Rust code.
388        // src_size / std::mem::size_of::<T>() is the number of elements to copy, it's
389        // guaranteed that it won't overwrite the output buffer of read after the end of the input data.
390        unsafe {
391            std::ptr::copy(data.as_ptr(), lock.raw_data, src_size / std::mem::size_of::<T>());
392        }
393
394        Ok(src_size)
395    }
396
397    /// Reads data from the tensor buffer.
398    ///
399    /// The data must be compatible with the element type of the tensor buffer.
400    /// The data must be big enough.
401    ///
402    /// Returns the number of bytes read from the tensor buffer.
403    pub fn read<T: 'static>(&self, data: &mut [T]) -> Result<usize, Error> {
404        if !self.element_type.is_compatible::<T>() {
405            return Err(Error::new(
406                ErrorCause::IncompatibleReadType,
407                LiteRtStatus_kLiteRtStatusErrorInvalidArgument,
408            ));
409        }
410        let lock = self.lock_read()?;
411        let src_size = self.packed_size()?;
412        let dst_size = mem::size_of_val(data);
413        if dst_size < src_size {
414            return Err(Error::new(
415                ErrorCause::TensorBufferTooSmall,
416                LiteRtStatus_kLiteRtStatusErrorRuntimeFailure,
417            ));
418        }
419        let to_copy = std::cmp::min(src_size, dst_size) / std::mem::size_of::<T>();
420        // SAFETY: lock.raw_data is always valid, it's guaranteed to be initialized by
421        // lock_read function.
422        // data is a pointer to the start of the data buffer, it's valid as provided by safe
423        // Rust code.
424        // to_copy is the number of elements to copy, it's guaranteed that it won't overwrite
425        // the output buffer of read after the end of the input data.
426        unsafe {
427            std::ptr::copy(lock.raw_data, data.as_mut_ptr(), to_copy);
428        }
429        Ok(to_copy)
430    }
431}
432
433impl Drop for TensorBuffer<'_> {
434    fn drop(&mut self) {
435        // SAFETY: self.raw_tensor_buffer is always valid, it's guaranteed to be initialized by
436        // a wrapper function.
437        unsafe {
438            LiteRtDestroyTensorBuffer(self.raw_tensor_buffer);
439        }
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ElementType` variant, used for exhaustive round-trip checks.
    const ALL_ELEMENT_TYPES: [ElementType; 20] = [
        ElementType::None,
        ElementType::Bool,
        ElementType::Int4,
        ElementType::Int8,
        ElementType::Int16,
        ElementType::Int32,
        ElementType::Int64,
        ElementType::UInt8,
        ElementType::UInt16,
        ElementType::UInt32,
        ElementType::UInt64,
        ElementType::Float16,
        ElementType::BFloat16,
        ElementType::Float32,
        ElementType::Float64,
        ElementType::Complex64,
        ElementType::Complex128,
        ElementType::TfResource,
        ElementType::TfString,
        ElementType::TfVariant,
    ];

    #[test]
    fn test_element_type_compatibility() {
        assert!(ElementType::Bool.is_compatible::<bool>());
        assert!(!ElementType::Bool.is_compatible::<u32>());
        assert!(!ElementType::Float32.is_compatible::<u32>());
    }

    #[test]
    fn test_element_type_compatibility_by_width() {
        assert!(ElementType::Int8.is_compatible::<i8>());
        assert!(ElementType::UInt8.is_compatible::<i8>());
        assert!(ElementType::Int16.is_compatible::<i16>());
        assert!(ElementType::Int32.is_compatible::<i32>());
        assert!(ElementType::Float32.is_compatible::<f32>());
        assert!(ElementType::Float64.is_compatible::<f64>());
        assert!(!ElementType::Float32.is_compatible::<f64>());
        assert!(!ElementType::Int8.is_compatible::<i32>());
        assert!(!ElementType::Float32.is_compatible::<String>());
    }

    #[test]
    fn test_element_type_c_enum_round_trip() {
        for element_type in ALL_ELEMENT_TYPES.iter().copied() {
            assert_eq!(
                ElementType::from_c_enum(element_type.to_c_enum()).ok(),
                Some(element_type)
            );
        }
    }

    #[test]
    fn test_tensor_buffer_type_c_enum_round_trip() {
        let all = [
            TensorBufferType::Unknown,
            TensorBufferType::HostMemory,
            TensorBufferType::Ahwb,
            TensorBufferType::Ion,
            TensorBufferType::DmaBuf,
            TensorBufferType::FastRpc,
            TensorBufferType::GlBuffer,
            TensorBufferType::GlTexture,
            TensorBufferType::OpenClBuffer,
            TensorBufferType::OpenClBufferFp16,
            TensorBufferType::OpenClTexture,
            TensorBufferType::OpenClTextureFp16,
            TensorBufferType::OpenClBufferPacked,
        ];
        for buffer_type in all.iter() {
            let raw = buffer_type.to_c_enum();
            // Compare via the raw value so the test does not depend on `PartialEq` for the enum.
            let back = TensorBufferType::from_c_enum(raw).ok().map(|t| t.to_c_enum());
            assert_eq!(back, Some(raw));
        }
    }
}