use alloc::{borrow::Cow, string::String, sync::Arc};
use core::fmt;
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum TensorElementType {
Float32,
Uint8,
Int8,
Uint16,
Int16,
Int32,
Int64,
String,
Bool,
Float16,
Float64,
Uint32,
Uint64,
Bfloat16,
Complex64,
Complex128,
Float8E4M3FN,
Float8E4M3FNUZ,
Float8E5M2,
Float8E5M2FNUZ,
Uint4,
Int4,
Undefined
}
impl TensorElementType {
pub fn byte_size(&self, container_capacity: usize) -> usize {
match self {
TensorElementType::Uint4 | TensorElementType::Int4 => container_capacity / 2,
TensorElementType::Bool | TensorElementType::Int8 | TensorElementType::Uint8 => container_capacity,
TensorElementType::Int16 | TensorElementType::Uint16 => container_capacity * 2,
TensorElementType::Int32 | TensorElementType::Uint32 => container_capacity * 4,
TensorElementType::Int64 | TensorElementType::Uint64 => container_capacity * 8,
TensorElementType::String => 0, TensorElementType::Float8E4M3FN | TensorElementType::Float8E4M3FNUZ | TensorElementType::Float8E5M2 | TensorElementType::Float8E5M2FNUZ => {
container_capacity * 4
}
TensorElementType::Float16 | TensorElementType::Bfloat16 => container_capacity * 2,
TensorElementType::Float32 => container_capacity * 4,
TensorElementType::Float64 => container_capacity * 8,
TensorElementType::Complex64 => container_capacity * 8,
TensorElementType::Complex128 => container_capacity * 16,
TensorElementType::Undefined => 0
}
}
}
impl fmt::Display for TensorElementType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
TensorElementType::Bool => "bool",
TensorElementType::Bfloat16 => "bf16",
TensorElementType::Float16 => "f16",
TensorElementType::Float32 => "f32",
TensorElementType::Float64 => "f64",
TensorElementType::Int16 => "i16",
TensorElementType::Int32 => "i32",
TensorElementType::Int64 => "i64",
TensorElementType::Int8 => "i8",
TensorElementType::Int4 => "i4",
TensorElementType::String => "String",
TensorElementType::Uint16 => "u16",
TensorElementType::Uint32 => "u32",
TensorElementType::Uint64 => "u64",
TensorElementType::Uint8 => "u8",
TensorElementType::Uint4 => "u4",
TensorElementType::Complex64 => "c64",
TensorElementType::Complex128 => "c128",
TensorElementType::Float8E4M3FN => "f8_e4m3fn",
TensorElementType::Float8E4M3FNUZ => "f8_e4m3fnuz",
TensorElementType::Float8E5M2 => "f8_e5m2",
TensorElementType::Float8E5M2FNUZ => "f8_e5m2fnuz",
TensorElementType::Undefined => "undefined"
})
}
}
impl From<TensorElementType> for ort_sys::ONNXTensorElementDataType {
fn from(val: TensorElementType) -> Self {
match val {
TensorElementType::Undefined => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED,
TensorElementType::Float32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
TensorElementType::Uint8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
TensorElementType::Int8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
TensorElementType::Uint16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
TensorElementType::Int16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
TensorElementType::Int32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
TensorElementType::Int64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
TensorElementType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
TensorElementType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
TensorElementType::Float16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
TensorElementType::Float64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
TensorElementType::Uint32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
TensorElementType::Uint64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
TensorElementType::Bfloat16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16,
TensorElementType::Int4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4,
TensorElementType::Uint4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4,
TensorElementType::Complex64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64,
TensorElementType::Complex128 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128,
TensorElementType::Float8E4M3FN => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN,
TensorElementType::Float8E4M3FNUZ => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ,
TensorElementType::Float8E5M2 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2,
TensorElementType::Float8E5M2FNUZ => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ
}
}
}
impl From<ort_sys::ONNXTensorElementDataType> for TensorElementType {
fn from(val: ort_sys::ONNXTensorElementDataType) -> Self {
match val {
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED => TensorElementType::Undefined,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => TensorElementType::Float32,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => TensorElementType::Uint8,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => TensorElementType::Int8,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => TensorElementType::Uint16,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => TensorElementType::Int16,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => TensorElementType::Int32,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => TensorElementType::Int64,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => TensorElementType::String,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => TensorElementType::Bool,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => TensorElementType::Float16,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => TensorElementType::Float64,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => TensorElementType::Uint32,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => TensorElementType::Uint64,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 => TensorElementType::Bfloat16,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 => TensorElementType::Int4,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4 => TensorElementType::Uint4,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 => TensorElementType::Complex64,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 => TensorElementType::Complex128,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN => TensorElementType::Float8E4M3FN,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ => TensorElementType::Float8E4M3FNUZ,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 => TensorElementType::Float8E5M2,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ => TensorElementType::Float8E5M2FNUZ
}
}
}
pub trait IntoTensorElementType {
fn into_tensor_element_type() -> TensorElementType;
private_trait!();
}
pub trait PrimitiveTensorElementType: IntoTensorElementType {
private_trait!();
}
macro_rules! impl_type_trait {
($type_:ty, $variant:ident) => {
impl IntoTensorElementType for $type_ {
fn into_tensor_element_type() -> TensorElementType {
TensorElementType::$variant
}
private_impl!();
}
impl PrimitiveTensorElementType for $type_ {
private_impl!();
}
};
}
impl_type_trait!(f32, Float32);
impl_type_trait!(u8, Uint8);
impl_type_trait!(i8, Int8);
impl_type_trait!(u16, Uint16);
impl_type_trait!(i16, Int16);
impl_type_trait!(i32, Int32);
impl_type_trait!(i64, Int64);
impl_type_trait!(bool, Bool);
#[cfg(feature = "half")]
#[cfg_attr(docsrs, doc(cfg(feature = "half")))]
impl_type_trait!(half::f16, Float16);
impl_type_trait!(f64, Float64);
impl_type_trait!(u32, Uint32);
impl_type_trait!(u64, Uint64);
#[cfg(feature = "half")]
#[cfg_attr(docsrs, doc(cfg(feature = "half")))]
impl_type_trait!(half::bf16, Bfloat16);
#[cfg(feature = "num-complex")]
#[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
impl_type_trait!(num_complex::Complex32, Complex64);
#[cfg(feature = "num-complex")]
#[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
impl_type_trait!(num_complex::Complex64, Complex128);
impl IntoTensorElementType for String {
fn into_tensor_element_type() -> TensorElementType {
TensorElementType::String
}
private_impl!();
}
pub trait Utf8Data {
fn as_utf8_bytes(&self) -> &[u8];
}
impl Utf8Data for String {
fn as_utf8_bytes(&self) -> &[u8] {
self.as_bytes()
}
}
impl Utf8Data for &str {
fn as_utf8_bytes(&self) -> &[u8] {
self.as_bytes()
}
}
impl Utf8Data for Cow<'_, str> {
fn as_utf8_bytes(&self) -> &[u8] {
self.as_bytes()
}
}
impl Utf8Data for Arc<str> {
fn as_utf8_bytes(&self) -> &[u8] {
self.as_bytes()
}
}