use super::builtins::{
TYPE_SCALAR_F16, TYPE_SCALAR_F32, TYPE_SCALAR_F64, TYPE_SCALAR_I32, TYPE_SCALAR_U8,
TYPE_TENSOR, TYPE_TENSOR_BF16, TYPE_TENSOR_BOOL, TYPE_TENSOR_F16, TYPE_TENSOR_F32,
TYPE_TENSOR_F64, TYPE_TENSOR_I16, TYPE_TENSOR_I32, TYPE_TENSOR_I64, TYPE_TENSOR_I8,
TYPE_TENSOR_U16, TYPE_TENSOR_U32, TYPE_TENSOR_U64, TYPE_TENSOR_U8,
};
use super::TypeNode;
/// Associates a Rust type with the static [`TypeNode`] that describes it.
///
/// Implementors must be thread-safe (`Send + Sync`) and must not borrow
/// non-`'static` data.
pub trait Storage: 'static + Send + Sync {
    /// The type descriptor for `Self`.
    const TYPE: &'static TypeNode;
}
impl Storage for [f32] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_F32;
}
impl Storage for [f64] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_F64;
}
impl Storage for [half::f16] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_F16;
}
impl Storage for [half::bf16] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_BF16;
}
impl Storage for [u8] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_U8;
}
impl Storage for [u16] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_U16;
}
impl Storage for [u32] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_U32;
}
impl Storage for [u64] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_U64;
}
impl Storage for [i8] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_I8;
}
impl Storage for [i16] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_I16;
}
impl Storage for [i32] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_I32;
}
impl Storage for [i64] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_I64;
}
impl Storage for [bool] {
const TYPE: &'static TypeNode = &TYPE_TENSOR_BOOL;
}
impl Storage for f32 {
const TYPE: &'static TypeNode = &TYPE_SCALAR_F32;
}
impl Storage for f64 {
const TYPE: &'static TypeNode = &TYPE_SCALAR_F64;
}
impl Storage for u16 {
const TYPE: &'static TypeNode = &TYPE_SCALAR_F16;
}
impl Storage for u8 {
const TYPE: &'static TypeNode = &TYPE_SCALAR_U8;
}
impl Storage for i32 {
const TYPE: &'static TypeNode = &TYPE_SCALAR_I32;
}
/// A type-erased tensor: its data as untyped bytes, tagged with the element
/// [`Dtype`] and the shape needed to interpret them.
///
/// NOTE(review): nothing here checks that `bytes.len()` matches `shape` and the
/// element size of `dtype`. Presumably constructors/callers uphold this; confirm.
#[derive(Debug, Clone)]
pub struct AnyTensor {
    /// Raw tensor data (byte order/layout not established here; TODO confirm).
    pub bytes: Vec<u8>,
    /// Element type of the values stored in `bytes`.
    pub dtype: Dtype,
    /// Tensor shape, one `usize` per dimension.
    pub shape: Vec<usize>,
}

impl Storage for AnyTensor {
    // The element type is only known at runtime, so this uses the generic
    // tensor node rather than a per-dtype `TYPE_TENSOR_*` node.
    const TYPE: &'static TypeNode = &TYPE_TENSOR;
}
/// Runtime element type of a tensor.
///
/// Each variant corresponds to one `TYPE_TENSOR_*` type node (see
/// [`Dtype::type_node`]). `Hash` is derived alongside `Eq` so a `Dtype` can be
/// used as a `HashMap`/`HashSet` key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Dtype {
    /// Elements of type `f32`.
    F32,
    /// Elements of type `f64`.
    F64,
    /// Elements of type `half::f16`.
    F16,
    /// Elements of type `half::bf16`.
    BF16,
    /// Elements of type `u8`.
    U8,
    /// Elements of type `u16`.
    U16,
    /// Elements of type `u32`.
    U32,
    /// Elements of type `u64`.
    U64,
    /// Elements of type `i8`.
    I8,
    /// Elements of type `i16`.
    I16,
    /// Elements of type `i32`.
    I32,
    /// Elements of type `i64`.
    I64,
    /// Elements of type `bool`.
    Bool,
}
impl Dtype {
    /// Returns the tensor [`TypeNode`] for this element type.
    ///
    /// The result matches `Storage::TYPE` of the corresponding slice type,
    /// e.g. `Dtype::F32` and `[f32]` both map to `TYPE_TENSOR_F32`.
    pub fn type_node(self) -> &'static TypeNode {
        match self {
            // Floating point.
            Self::F32 => &TYPE_TENSOR_F32,
            Self::F64 => &TYPE_TENSOR_F64,
            Self::F16 => &TYPE_TENSOR_F16,
            Self::BF16 => &TYPE_TENSOR_BF16,
            // Unsigned integers.
            Self::U8 => &TYPE_TENSOR_U8,
            Self::U16 => &TYPE_TENSOR_U16,
            Self::U32 => &TYPE_TENSOR_U32,
            Self::U64 => &TYPE_TENSOR_U64,
            // Signed integers.
            Self::I8 => &TYPE_TENSOR_I8,
            Self::I16 => &TYPE_TENSOR_I16,
            Self::I32 => &TYPE_TENSOR_I32,
            Self::I64 => &TYPE_TENSOR_I64,
            // Booleans.
            Self::Bool => &TYPE_TENSOR_BOOL,
        }
    }
}