use std::ffi::CStr;
use std::fmt;
use std::ptr::NonNull;
use edgefirst_tflite_sys::{
self as sys, TfLiteTensor, TfLiteType_kTfLiteBFloat16, TfLiteType_kTfLiteBool,
TfLiteType_kTfLiteComplex128, TfLiteType_kTfLiteComplex64, TfLiteType_kTfLiteFloat16,
TfLiteType_kTfLiteFloat32, TfLiteType_kTfLiteFloat64, TfLiteType_kTfLiteInt16,
TfLiteType_kTfLiteInt32, TfLiteType_kTfLiteInt4, TfLiteType_kTfLiteInt64,
TfLiteType_kTfLiteInt8, TfLiteType_kTfLiteNoType, TfLiteType_kTfLiteResource,
TfLiteType_kTfLiteString, TfLiteType_kTfLiteUInt16, TfLiteType_kTfLiteUInt32,
TfLiteType_kTfLiteUInt64, TfLiteType_kTfLiteUInt8, TfLiteType_kTfLiteVariant,
};
use num_traits::FromPrimitive;
use crate::error::{Error, Result};
/// Element type of a TensorFlow Lite tensor.
///
/// Each discriminant is the matching `TfLiteType_*` constant from the C API
/// bindings, so raw type codes convert through [`FromPrimitive`]
/// (for example `TensorType::from_u32`).
// The `TfLiteType_*` constants are small non-negative codes (0..=19), so the
// `as isize` casts below cannot wrap in practice.
#[allow(clippy::cast_possible_wrap)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, num_derive::FromPrimitive)]
#[repr(isize)]
pub enum TensorType {
    /// Mirrors `kTfLiteNoType`.
    NoType = TfLiteType_kTfLiteNoType as isize,
    /// Mirrors `kTfLiteFloat32`.
    Float32 = TfLiteType_kTfLiteFloat32 as isize,
    /// Mirrors `kTfLiteInt32`.
    Int32 = TfLiteType_kTfLiteInt32 as isize,
    /// Mirrors `kTfLiteUInt8`.
    UInt8 = TfLiteType_kTfLiteUInt8 as isize,
    /// Mirrors `kTfLiteInt64`.
    Int64 = TfLiteType_kTfLiteInt64 as isize,
    /// Mirrors `kTfLiteString`.
    String = TfLiteType_kTfLiteString as isize,
    /// Mirrors `kTfLiteBool`.
    Bool = TfLiteType_kTfLiteBool as isize,
    /// Mirrors `kTfLiteInt16`.
    Int16 = TfLiteType_kTfLiteInt16 as isize,
    /// Mirrors `kTfLiteComplex64`.
    Complex64 = TfLiteType_kTfLiteComplex64 as isize,
    /// Mirrors `kTfLiteInt8`.
    Int8 = TfLiteType_kTfLiteInt8 as isize,
    /// Mirrors `kTfLiteFloat16`.
    Float16 = TfLiteType_kTfLiteFloat16 as isize,
    /// Mirrors `kTfLiteFloat64`.
    Float64 = TfLiteType_kTfLiteFloat64 as isize,
    /// Mirrors `kTfLiteComplex128`.
    Complex128 = TfLiteType_kTfLiteComplex128 as isize,
    /// Mirrors `kTfLiteUInt64`.
    UInt64 = TfLiteType_kTfLiteUInt64 as isize,
    /// Mirrors `kTfLiteResource`.
    Resource = TfLiteType_kTfLiteResource as isize,
    /// Mirrors `kTfLiteVariant`.
    Variant = TfLiteType_kTfLiteVariant as isize,
    /// Mirrors `kTfLiteUInt32`.
    UInt32 = TfLiteType_kTfLiteUInt32 as isize,
    /// Mirrors `kTfLiteUInt16`.
    UInt16 = TfLiteType_kTfLiteUInt16 as isize,
    /// Mirrors `kTfLiteInt4`.
    Int4 = TfLiteType_kTfLiteInt4 as isize,
    /// Mirrors `kTfLiteBFloat16`.
    BFloat16 = TfLiteType_kTfLiteBFloat16 as isize,
}
/// Per-tensor quantization parameters as reported by
/// `TfLiteTensorQuantizationParams`.
///
/// Per TensorFlow Lite's documented convention, a quantized value `q` maps
/// back to a real value as `scale * (q - zero_point)`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct QuantizationParams {
    /// Scale factor applied after subtracting `zero_point`.
    pub scale: f32,
    /// Quantized value that corresponds to a real value of zero.
    pub zero_point: i32,
}
/// Read-only view of a TensorFlow Lite tensor.
///
/// Every query goes through the borrowed C API bindings (`lib`) using the raw
/// tensor pointer (`ptr`).
pub struct Tensor<'lib> {
    /// Raw tensor handle; assumed to stay valid for as long as this view is used.
    pub(crate) ptr: *const TfLiteTensor,
    /// TensorFlow Lite C API bindings used for all tensor queries.
    pub(crate) lib: &'lib sys::tensorflowlite_c,
}
impl Tensor<'_> {
    /// Returns the element type reported by TensorFlow Lite.
    ///
    /// Raw type codes that do not correspond to any [`TensorType`] variant are
    /// mapped to [`TensorType::NoType`].
    #[must_use]
    pub fn tensor_type(&self) -> TensorType {
        // SAFETY: `self.ptr` is assumed to reference a live tensor for as long
        // as this view exists (established by whoever constructs the view).
        let raw = unsafe { self.lib.TfLiteTensorType(self.ptr) };
        FromPrimitive::from_u32(raw).unwrap_or(TensorType::NoType)
    }

    /// Returns the tensor's name.
    ///
    /// Returns `"<unnamed>"` when TensorFlow Lite reports a null name pointer,
    /// and `"<invalid-utf8>"` when the name is not valid UTF-8.
    #[must_use]
    pub fn name(&self) -> &str {
        // SAFETY: `self.ptr` is assumed to reference a live tensor.
        let raw = unsafe { self.lib.TfLiteTensorName(self.ptr) };
        if raw.is_null() {
            // `CStr::from_ptr` on a null pointer is undefined behaviour.
            return "<unnamed>";
        }
        // SAFETY: `raw` is non-null; it is assumed to be NUL-terminated and to
        // remain valid for the borrow of `self`.
        unsafe { CStr::from_ptr(raw) }
            .to_str()
            .unwrap_or("<invalid-utf8>")
    }

    /// Returns the number of dimensions (rank) of the tensor.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error when TensorFlow Lite reports a
    /// negative rank, which this wrapper treats as "dimensions not set".
    pub fn num_dims(&self) -> Result<usize> {
        // SAFETY: `self.ptr` is assumed to reference a live tensor.
        let n = unsafe { self.lib.TfLiteTensorNumDims(self.ptr) };
        usize::try_from(n).map_err(|_| {
            Error::invalid_argument(format!(
                "tensor `{}` does not have dimensions set",
                self.name()
            ))
        })
    }

    /// Returns the size of dimension `index`.
    ///
    /// # Errors
    ///
    /// Returns an error when the rank is unavailable, when `index` is out of
    /// bounds, or when TensorFlow Lite reports a negative size for the
    /// dimension. A negative size would otherwise become a huge `usize`.
    pub fn dim(&self, index: usize) -> Result<usize> {
        let num_dims = self.num_dims()?;
        if index >= num_dims {
            return Err(Error::invalid_argument(format!(
                "dimension index {index} out of bounds for tensor with {num_dims} dimensions"
            )));
        }
        // `index < num_dims`, and `num_dims` came from a non-negative `i32`,
        // so this cast can neither truncate nor wrap.
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        let i = index as i32;
        // SAFETY: `self.ptr` is assumed live and `i` is in bounds (checked above).
        let d = unsafe { self.lib.TfLiteTensorDim(self.ptr, i) };
        usize::try_from(d).map_err(|_| {
            Error::invalid_argument(format!(
                "tensor `{}` has negative size {d} at dimension {index}",
                self.name()
            ))
        })
    }

    /// Returns every dimension in index order (dimension 0 first).
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::num_dims`] or [`Self::dim`].
    pub fn shape(&self) -> Result<Vec<usize>> {
        (0..self.num_dims()?).map(|i| self.dim(i)).collect()
    }

    /// Returns the size in bytes of the tensor's data buffer, as reported by
    /// TensorFlow Lite.
    #[must_use]
    pub fn byte_size(&self) -> usize {
        // SAFETY: `self.ptr` is assumed to reference a live tensor.
        unsafe { self.lib.TfLiteTensorByteSize(self.ptr) }
    }

    /// Returns the total number of elements (the product of all dimensions).
    /// A rank-0 tensor has a volume of 1.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::shape`]. Also returns an
    /// invalid-argument error if the product overflows `usize` instead of
    /// silently wrapping.
    pub fn volume(&self) -> Result<usize> {
        let shape = self.shape()?;
        shape
            .iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
            .ok_or_else(|| {
                Error::invalid_argument(format!(
                    "tensor `{}` element count overflows usize for shape {shape:?}",
                    self.name()
                ))
            })
    }

    /// Returns the tensor's scale and zero point as reported by
    /// `TfLiteTensorQuantizationParams`.
    #[must_use]
    pub fn quantization_params(&self) -> QuantizationParams {
        // SAFETY: `self.ptr` is assumed to reference a live tensor.
        let params = unsafe { self.lib.TfLiteTensorQuantizationParams(self.ptr) };
        QuantizationParams {
            scale: params.scale,
            zero_point: params.zero_point,
        }
    }

    /// Views the tensor's data as a slice of [`Self::volume`] elements of `T`.
    ///
    /// The element type is not checked against [`Self::tensor_type`]. Only the
    /// buffer size, nullness, and alignment are validated.
    /// NOTE(review): `T: Copy` does not guarantee that every bit pattern is a
    /// valid `T` (e.g. `bool`, `char`, enums); callers should use plain
    /// numeric types.
    ///
    /// # Errors
    ///
    /// Returns an error if the volume cannot be computed, if the buffer is too
    /// small (including when the byte count overflows), if the data pointer is
    /// null, or if it is not aligned for `T`.
    pub fn as_slice<T: Copy>(&self) -> Result<&[T]> {
        let (ptr, len) = self.typed_data_ptr::<T>()?;
        // SAFETY: `typed_data_ptr` verified that `ptr` is non-null, aligned for
        // `T`, and that `len * size_of::<T>()` bytes fit in the tensor buffer.
        Ok(unsafe { std::slice::from_raw_parts(ptr, len) })
    }

    /// Checks that the tensor buffer can hold `volume()` elements of `T`.
    /// Returns the typed data pointer together with that element count.
    fn typed_data_ptr<T>(&self) -> Result<(*const T, usize)> {
        let volume = self.volume()?;
        let byte_size = self.byte_size();
        // `checked_mul` stops a wrapped product from slipping under `byte_size`
        // in release builds and producing an out-of-bounds slice.
        let fits = matches!(
            std::mem::size_of::<T>().checked_mul(volume),
            Some(needed) if needed <= byte_size
        );
        if !fits {
            return Err(Error::invalid_argument(format!(
                "tensor byte size {} is too small for {} elements of {}",
                byte_size,
                volume,
                std::any::type_name::<T>(),
            )));
        }
        // SAFETY: `self.ptr` is assumed to reference a live tensor.
        let ptr: *const T = unsafe { self.lib.TfLiteTensorData(self.ptr) }.cast::<T>();
        if ptr.is_null() {
            return Err(Error::null_pointer("TfLiteTensorData returned null"));
        }
        // `align_of` is always a power of two, so the low address bits give the
        // misalignment; `slice::from_raw_parts` requires an aligned pointer.
        let align = std::mem::align_of::<T>();
        if ((ptr as usize) & (align - 1)) != 0 {
            return Err(Error::invalid_argument(format!(
                "tensor data is not aligned to {align} bytes as required by {}",
                std::any::type_name::<T>(),
            )));
        }
        Ok((ptr, volume))
    }
}
impl fmt::Debug for Tensor<'_> {
    // Renders `name: d0xd1x... Type` via the shared summary writer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = self.name();
        let rank = self.num_dims();
        let element_type = self.tensor_type();
        write_tensor_debug(f, name, rank, |axis| self.dim(axis), element_type)
    }
}
impl fmt::Display for Tensor<'_> {
    // User-facing summary: `name: d0xd1x... Type`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dim_of = |axis: usize| self.dim(axis);
        write_tensor_debug(f, self.name(), self.num_dims(), dim_of, self.tensor_type())
    }
}
/// Mutable view of a TensorFlow Lite tensor.
///
/// Every query goes through the borrowed C API bindings (`lib`) using the
/// non-null tensor pointer (`ptr`).
pub struct TensorMut<'lib> {
    /// Non-null tensor handle; assumed to stay valid for as long as this view is used.
    pub(crate) ptr: NonNull<TfLiteTensor>,
    /// TensorFlow Lite C API bindings used for all tensor queries.
    pub(crate) lib: &'lib sys::tensorflowlite_c,
}
impl TensorMut<'_> {
    /// Returns the element type reported by TensorFlow Lite.
    ///
    /// Raw type codes that do not correspond to any [`TensorType`] variant are
    /// mapped to [`TensorType::NoType`].
    #[must_use]
    pub fn tensor_type(&self) -> TensorType {
        // SAFETY: `self.ptr` is assumed to reference a live tensor for as long
        // as this view exists (established by whoever constructs the view).
        let raw = unsafe { self.lib.TfLiteTensorType(self.ptr.as_ptr()) };
        FromPrimitive::from_u32(raw).unwrap_or(TensorType::NoType)
    }

    /// Returns the tensor's name.
    ///
    /// Returns `"<unnamed>"` when TensorFlow Lite reports a null name pointer,
    /// and `"<invalid-utf8>"` when the name is not valid UTF-8.
    #[must_use]
    pub fn name(&self) -> &str {
        // SAFETY: `self.ptr` is assumed to reference a live tensor.
        let raw = unsafe { self.lib.TfLiteTensorName(self.ptr.as_ptr()) };
        if raw.is_null() {
            // `CStr::from_ptr` on a null pointer is undefined behaviour.
            return "<unnamed>";
        }
        // SAFETY: `raw` is non-null; it is assumed to be NUL-terminated and to
        // remain valid for the borrow of `self`.
        unsafe { CStr::from_ptr(raw) }
            .to_str()
            .unwrap_or("<invalid-utf8>")
    }

    /// Returns the number of dimensions (rank) of the tensor.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error when TensorFlow Lite reports a
    /// negative rank, which this wrapper treats as "dimensions not set".
    pub fn num_dims(&self) -> Result<usize> {
        // SAFETY: `self.ptr` is assumed to reference a live tensor.
        let n = unsafe { self.lib.TfLiteTensorNumDims(self.ptr.as_ptr()) };
        usize::try_from(n).map_err(|_| {
            Error::invalid_argument(format!(
                "tensor `{}` does not have dimensions set",
                self.name()
            ))
        })
    }

    /// Returns the size of dimension `index`.
    ///
    /// # Errors
    ///
    /// Returns an error when the rank is unavailable, when `index` is out of
    /// bounds, or when TensorFlow Lite reports a negative size for the
    /// dimension. A negative size would otherwise become a huge `usize`.
    pub fn dim(&self, index: usize) -> Result<usize> {
        let num_dims = self.num_dims()?;
        if index >= num_dims {
            return Err(Error::invalid_argument(format!(
                "dimension index {index} out of bounds for tensor with {num_dims} dimensions"
            )));
        }
        // `index < num_dims`, and `num_dims` came from a non-negative `i32`,
        // so this cast can neither truncate nor wrap.
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        let i = index as i32;
        // SAFETY: `self.ptr` is assumed live and `i` is in bounds (checked above).
        let d = unsafe { self.lib.TfLiteTensorDim(self.ptr.as_ptr(), i) };
        usize::try_from(d).map_err(|_| {
            Error::invalid_argument(format!(
                "tensor `{}` has negative size {d} at dimension {index}",
                self.name()
            ))
        })
    }

    /// Returns every dimension in index order (dimension 0 first).
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::num_dims`] or [`Self::dim`].
    pub fn shape(&self) -> Result<Vec<usize>> {
        (0..self.num_dims()?).map(|i| self.dim(i)).collect()
    }

    /// Returns the size in bytes of the tensor's data buffer, as reported by
    /// TensorFlow Lite.
    #[must_use]
    pub fn byte_size(&self) -> usize {
        // SAFETY: `self.ptr` is assumed to reference a live tensor.
        unsafe { self.lib.TfLiteTensorByteSize(self.ptr.as_ptr()) }
    }

    /// Returns the total number of elements (the product of all dimensions).
    /// A rank-0 tensor has a volume of 1.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::shape`]. Also returns an
    /// invalid-argument error if the product overflows `usize` instead of
    /// silently wrapping.
    pub fn volume(&self) -> Result<usize> {
        let shape = self.shape()?;
        shape
            .iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
            .ok_or_else(|| {
                Error::invalid_argument(format!(
                    "tensor `{}` element count overflows usize for shape {shape:?}",
                    self.name()
                ))
            })
    }

    /// Returns the tensor's scale and zero point as reported by
    /// `TfLiteTensorQuantizationParams`.
    #[must_use]
    pub fn quantization_params(&self) -> QuantizationParams {
        // SAFETY: `self.ptr` is assumed to reference a live tensor.
        let params = unsafe { self.lib.TfLiteTensorQuantizationParams(self.ptr.as_ptr()) };
        QuantizationParams {
            scale: params.scale,
            zero_point: params.zero_point,
        }
    }

    /// Views the tensor's data as a slice of [`Self::volume`] elements of `T`.
    ///
    /// The element type is not checked against [`Self::tensor_type`].
    /// NOTE(review): `T: Copy` does not guarantee that every bit pattern is a
    /// valid `T`; callers should use plain numeric types.
    ///
    /// # Errors
    ///
    /// Returns an error if the volume cannot be computed, if the buffer is too
    /// small (including when the byte count overflows), if the data pointer is
    /// null, or if it is not aligned for `T`.
    pub fn as_slice<T: Copy>(&self) -> Result<&[T]> {
        let (ptr, len) = self.typed_data_ptr::<T>()?;
        // SAFETY: `typed_data_ptr` verified that `ptr` is non-null, aligned for
        // `T`, and that `len * size_of::<T>()` bytes fit in the tensor buffer.
        Ok(unsafe { std::slice::from_raw_parts(ptr, len) })
    }

    /// Views the tensor's data as a mutable slice of [`Self::volume`] elements
    /// of `T`.
    ///
    /// Applies the same validation as [`Self::as_slice`].
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::as_slice`].
    pub fn as_mut_slice<T: Copy>(&mut self) -> Result<&mut [T]> {
        let (ptr, len) = self.typed_data_ptr::<T>()?;
        // SAFETY: as in `as_slice`; `&mut self` prevents another slice from
        // this view existing at the same time.
        Ok(unsafe { std::slice::from_raw_parts_mut(ptr, len) })
    }

    /// Copies `data` into the tensor buffer.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Self::as_mut_slice`]. Also returns an
    /// invalid-argument error if `data.len()` differs from the tensor volume.
    pub fn copy_from_slice<T: Copy>(&mut self, data: &[T]) -> Result<()> {
        let slice = self.as_mut_slice::<T>()?;
        if data.len() != slice.len() {
            return Err(Error::invalid_argument(format!(
                "data length {} does not match tensor volume {}",
                data.len(),
                slice.len(),
            )));
        }
        slice.copy_from_slice(data);
        Ok(())
    }

    /// Checks that the tensor buffer can hold `volume()` elements of `T`.
    /// Returns the typed data pointer together with that element count.
    fn typed_data_ptr<T>(&self) -> Result<(*mut T, usize)> {
        let volume = self.volume()?;
        let byte_size = self.byte_size();
        // `checked_mul` stops a wrapped product from slipping under `byte_size`
        // in release builds and producing an out-of-bounds slice.
        let fits = matches!(
            std::mem::size_of::<T>().checked_mul(volume),
            Some(needed) if needed <= byte_size
        );
        if !fits {
            return Err(Error::invalid_argument(format!(
                "tensor byte size {} is too small for {} elements of {}",
                byte_size,
                volume,
                std::any::type_name::<T>(),
            )));
        }
        // SAFETY: `self.ptr` is assumed to reference a live tensor.
        let ptr = unsafe { self.lib.TfLiteTensorData(self.ptr.as_ptr()) }.cast::<T>();
        if ptr.is_null() {
            return Err(Error::null_pointer("TfLiteTensorData returned null"));
        }
        // `align_of` is always a power of two, so the low address bits give the
        // misalignment; slice construction requires an aligned pointer.
        let align = std::mem::align_of::<T>();
        if ((ptr as usize) & (align - 1)) != 0 {
            return Err(Error::invalid_argument(format!(
                "tensor data is not aligned to {align} bytes as required by {}",
                std::any::type_name::<T>(),
            )));
        }
        Ok((ptr, volume))
    }
}
impl fmt::Debug for TensorMut<'_> {
    // Renders `name: d0xd1x... Type` via the shared summary writer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = self.name();
        let rank = self.num_dims();
        let element_type = self.tensor_type();
        write_tensor_debug(f, name, rank, |axis| self.dim(axis), element_type)
    }
}
impl fmt::Display for TensorMut<'_> {
    // User-facing summary: `name: d0xd1x... Type`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dim_of = |axis: usize| self.dim(axis);
        write_tensor_debug(f, self.name(), self.num_dims(), dim_of, self.tensor_type())
    }
}
/// Writes a one-line tensor summary of the form `name: 1x224x224x3 Float32`.
///
/// If `num_dims` is an error, the rank is treated as zero and no dimensions
/// are printed. A dimension that `dim_fn` fails to produce is printed as `0`.
fn write_tensor_debug(
    f: &mut fmt::Formatter<'_>,
    name: &str,
    num_dims: Result<usize>,
    dim_fn: impl Fn(usize) -> Result<usize>,
    tensor_type: TensorType,
) -> fmt::Result {
    f.write_str(name)?;
    f.write_str(": ")?;
    let rank = num_dims.unwrap_or(0);
    for axis in 0..rank {
        let separator = if axis == 0 { "" } else { "x" };
        let extent = dim_fn(axis).unwrap_or(0);
        write!(f, "{separator}{extent}")?;
    }
    write!(f, " {:?}", tensor_type)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every `TensorType` variant paired with its raw TensorFlow Lite type code.
    const ALL_VARIANTS: [(isize, TensorType); 20] = [
        (0, TensorType::NoType),
        (1, TensorType::Float32),
        (2, TensorType::Int32),
        (3, TensorType::UInt8),
        (4, TensorType::Int64),
        (5, TensorType::String),
        (6, TensorType::Bool),
        (7, TensorType::Int16),
        (8, TensorType::Complex64),
        (9, TensorType::Int8),
        (10, TensorType::Float16),
        (11, TensorType::Float64),
        (12, TensorType::Complex128),
        (13, TensorType::UInt64),
        (14, TensorType::Resource),
        (15, TensorType::Variant),
        (16, TensorType::UInt32),
        (17, TensorType::UInt16),
        (18, TensorType::Int4),
        (19, TensorType::BFloat16),
    ];

    #[test]
    fn tensor_type_from_primitive_all_variants() {
        for &(raw, expected) in &ALL_VARIANTS {
            let result = TensorType::from_isize(raw);
            assert_eq!(
                result,
                Some(expected),
                "TensorType::from_isize({raw}) should be Some({expected:?})"
            );
        }
    }

    #[test]
    fn tensor_type_discriminants_match_raw_codes() {
        for &(raw, variant) in &ALL_VARIANTS {
            assert_eq!(
                variant as isize, raw,
                "{variant:?} has an unexpected discriminant"
            );
        }
    }

    #[test]
    fn tensor_type_from_u32_all_variants() {
        // The tensor accessors decode raw C type codes with `from_u32`, so each
        // code must map to exactly the same variant as its `isize` form.
        for &(raw, expected) in &ALL_VARIANTS {
            let raw = u32::try_from(raw).expect("raw type codes are non-negative");
            assert_eq!(
                TensorType::from_u32(raw),
                Some(expected),
                "TensorType::from_u32({raw}) should be Some({expected:?})"
            );
        }
    }

    #[test]
    fn tensor_type_unknown_value_returns_none() {
        assert_eq!(TensorType::from_isize(999), None);
        assert_eq!(TensorType::from_u32(999), None);
        assert_eq!(TensorType::from_isize(-1), None);
        assert_eq!(TensorType::from_isize(20), None);
    }

    #[test]
    fn tensor_type_clone() {
        let original = TensorType::Float32;
        let cloned = original;
        assert_eq!(original, cloned);
    }

    #[test]
    fn tensor_type_partial_eq() {
        assert_eq!(TensorType::Int8, TensorType::Int8);
        assert_ne!(TensorType::Int8, TensorType::UInt8);
    }

    #[test]
    fn tensor_type_hash() {
        let mut set = HashSet::new();
        set.insert(TensorType::Float32);
        set.insert(TensorType::Float32);
        set.insert(TensorType::Int32);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn tensor_type_all_variants_unique_in_hashset() {
        let set: HashSet<TensorType> = ALL_VARIANTS.iter().map(|&(_, ty)| ty).collect();
        assert_eq!(set.len(), 20);
    }

    #[test]
    fn tensor_type_debug_format() {
        assert_eq!(format!("{:?}", TensorType::Float32), "Float32");
        assert_eq!(format!("{:?}", TensorType::NoType), "NoType");
        assert_eq!(format!("{:?}", TensorType::BFloat16), "BFloat16");
        assert_eq!(format!("{:?}", TensorType::Complex128), "Complex128");
    }

    /// Drives `write_tensor_debug` through a `Display` impl so its output can be
    /// checked without a live TensorFlow Lite tensor.
    struct Rendered<'d> {
        dims: &'d [usize],
        tensor_type: TensorType,
    }

    impl fmt::Display for Rendered<'_> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write_tensor_debug(
                f,
                "input",
                Ok(self.dims.len()),
                |i| Ok(self.dims[i]),
                self.tensor_type,
            )
        }
    }

    #[test]
    fn write_tensor_debug_renders_shape_and_type() {
        let rendered = Rendered {
            dims: &[1, 224, 224, 3],
            tensor_type: TensorType::Float32,
        };
        assert_eq!(rendered.to_string(), "input: 1x224x224x3 Float32");

        let single = Rendered {
            dims: &[10],
            tensor_type: TensorType::Int8,
        };
        assert_eq!(single.to_string(), "input: 10 Int8");
    }

    #[test]
    fn quantization_params_construction() {
        let params = QuantizationParams {
            scale: 0.5,
            zero_point: 128,
        };
        assert!((params.scale - 0.5).abs() < f32::EPSILON);
        assert_eq!(params.zero_point, 128);
    }

    #[test]
    fn quantization_params_zero_values() {
        let params = QuantizationParams {
            scale: 0.0,
            zero_point: 0,
        };
        assert!(params.scale.abs() < f32::EPSILON);
        assert_eq!(params.zero_point, 0);
    }

    #[test]
    fn quantization_params_negative_zero_point() {
        let params = QuantizationParams {
            scale: 0.007_812_5,
            zero_point: -128,
        };
        assert!((params.scale - 0.007_812_5).abs() < f32::EPSILON);
        assert_eq!(params.zero_point, -128);
    }

    #[test]
    fn quantization_params_debug() {
        let params = QuantizationParams {
            scale: 1.0,
            zero_point: 0,
        };
        let debug = format!("{params:?}");
        assert!(debug.contains("QuantizationParams"));
        assert!(debug.contains("scale"));
        assert!(debug.contains("zero_point"));
    }

    #[test]
    fn quantization_params_clone() {
        let original = QuantizationParams {
            scale: 0.25,
            zero_point: 64,
        };
        let cloned = original;
        assert_eq!(original, cloned);
    }

    #[test]
    fn quantization_params_partial_eq() {
        let a = QuantizationParams {
            scale: 0.5,
            zero_point: 128,
        };
        let b = QuantizationParams {
            scale: 0.5,
            zero_point: 128,
        };
        let c = QuantizationParams {
            scale: 0.25,
            zero_point: 128,
        };
        assert_eq!(a, b);
        assert_ne!(a, c);
    }
}