use std::any::Any;
use std::fmt::Debug;
use pjrt_sys::{
PJRT_Buffer_Type, PJRT_Buffer_Type_PJRT_Buffer_Type_BF16,
PJRT_Buffer_Type_PJRT_Buffer_Type_C128, PJRT_Buffer_Type_PJRT_Buffer_Type_C64,
PJRT_Buffer_Type_PJRT_Buffer_Type_F16, PJRT_Buffer_Type_PJRT_Buffer_Type_F32,
PJRT_Buffer_Type_PJRT_Buffer_Type_F64, PJRT_Buffer_Type_PJRT_Buffer_Type_F8E4M3B11FNUZ,
PJRT_Buffer_Type_PJRT_Buffer_Type_F8E4M3FN, PJRT_Buffer_Type_PJRT_Buffer_Type_F8E4M3FNUZ,
PJRT_Buffer_Type_PJRT_Buffer_Type_F8E5M2, PJRT_Buffer_Type_PJRT_Buffer_Type_F8E5M2FNUZ,
PJRT_Buffer_Type_PJRT_Buffer_Type_INVALID, PJRT_Buffer_Type_PJRT_Buffer_Type_PRED,
PJRT_Buffer_Type_PJRT_Buffer_Type_S16, PJRT_Buffer_Type_PJRT_Buffer_Type_S2,
PJRT_Buffer_Type_PJRT_Buffer_Type_S32, PJRT_Buffer_Type_PJRT_Buffer_Type_S4,
PJRT_Buffer_Type_PJRT_Buffer_Type_S64, PJRT_Buffer_Type_PJRT_Buffer_Type_S8,
PJRT_Buffer_Type_PJRT_Buffer_Type_TOKEN, PJRT_Buffer_Type_PJRT_Buffer_Type_U16,
PJRT_Buffer_Type_PJRT_Buffer_Type_U2, PJRT_Buffer_Type_PJRT_Buffer_Type_U32,
PJRT_Buffer_Type_PJRT_Buffer_Type_U4, PJRT_Buffer_Type_PJRT_Buffer_Type_U64,
PJRT_Buffer_Type_PJRT_Buffer_Type_U8,
};
use crate::{Error, Result};
/// Compile-time description of an element type that can be stored in a PJRT buffer.
///
/// In this module every implementor is a unit marker struct (e.g. [`Bool`], [`F32`])
/// that ties a Rust scalar ([`Type::ElemType`]) to its [`PrimitiveType`] tag.
pub trait Type: 'static + Debug + Copy + Sized {
    /// The Rust scalar type held in buffers of this type; its own
    /// [`ElemType::Type`] must point back to `Self`.
    type ElemType: ElemType<Type = Self>;

    /// Human-readable type name, e.g. `"f32"`.
    const NAME: &'static str;
    /// PJRT primitive-type tag for this element type.
    const PRIMITIVE_TYPE: PrimitiveType;
    /// A value of the marker type, used when building `Box<dyn DType>` objects.
    const TYPE: Self;
    /// Size in bytes of one element; defaults to `size_of` the element type.
    const SIZE: usize = std::mem::size_of::<<Self as Type>::ElemType>();
    /// Alignment in bytes of one element; defaults to `align_of` the element type.
    const ALIGNMENT: usize = std::mem::align_of::<<Self as Type>::ElemType>();
}
/// A Rust scalar type (e.g. `f32`, `bool`) that can be stored as a buffer element.
///
/// This is the inverse link of [`Type::ElemType`]: the scalar names its marker
/// [`Type`], and that marker's `ElemType` must be this scalar again.
pub trait ElemType: 'static + Debug + Copy + Sized {
    /// The marker type that describes this scalar.
    type Type: Type<ElemType = Self>;
}
/// Marker [`Type`] for `PrimitiveType::Pred` buffers, whose elements are Rust `bool`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Bool;

impl Type for Bool {
    type ElemType = bool;

    const NAME: &'static str = "bool";
    const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::Pred;
    const TYPE: Self = Self;
}

/// `bool` elements are described by the [`Bool`] marker.
impl ElemType for bool {
    type Type = Bool;
}
/// Marker [`Type`] for `PrimitiveType::F32` buffers, whose elements are Rust `f32`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct F32;

impl Type for F32 {
    type ElemType = f32;

    const NAME: &'static str = "f32";
    const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::F32;
    const TYPE: Self = Self;
}

/// `f32` elements are described by the [`F32`] marker.
impl ElemType for f32 {
    type Type = F32;
}
/// Marker [`Type`] for `PrimitiveType::F64` buffers, whose elements are Rust `f64`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct F64;

impl Type for F64 {
    type ElemType = f64;

    const NAME: &'static str = "f64";
    const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::F64;
    const TYPE: Self = Self;
}

/// `f64` elements are described by the [`F64`] marker.
impl ElemType for f64 {
    type Type = F64;
}
/// Marker [`Type`] for `PrimitiveType::S8` buffers, whose elements are Rust `i8`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct I8;

impl Type for I8 {
    type ElemType = i8;

    const NAME: &'static str = "i8";
    const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::S8;
    const TYPE: Self = Self;
}

/// `i8` elements are described by the [`I8`] marker.
impl ElemType for i8 {
    type Type = I8;
}
/// Marker [`Type`] for `PrimitiveType::S16` buffers, whose elements are Rust `i16`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct I16;

impl Type for I16 {
    type ElemType = i16;

    const NAME: &'static str = "i16";
    const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::S16;
    const TYPE: Self = Self;
}

/// `i16` elements are described by the [`I16`] marker.
impl ElemType for i16 {
    type Type = I16;
}
/// Marker [`Type`] for `PrimitiveType::S32` buffers, whose elements are Rust `i32`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct I32;

impl Type for I32 {
    type ElemType = i32;

    const NAME: &'static str = "i32";
    const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::S32;
    const TYPE: Self = Self;
}

/// `i32` elements are described by the [`I32`] marker.
impl ElemType for i32 {
    type Type = I32;
}
/// Marker [`Type`] for `PrimitiveType::S64` buffers, whose elements are Rust `i64`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct I64;

impl Type for I64 {
    type ElemType = i64;

    const NAME: &'static str = "i64";
    const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::S64;
    const TYPE: Self = Self;
}

/// `i64` elements are described by the [`I64`] marker.
impl ElemType for i64 {
    type Type = I64;
}
/// Marker [`Type`] for `PrimitiveType::U8` buffers, whose elements are Rust `u8`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct U8;

impl Type for U8 {
    type ElemType = u8;

    const NAME: &'static str = "u8";
    const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::U8;
    const TYPE: Self = Self;
}

/// `u8` elements are described by the [`U8`] marker.
impl ElemType for u8 {
    type Type = U8;
}
/// Marker [`Type`] for `PrimitiveType::U16` buffers, whose elements are Rust `u16`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct U16;

impl Type for U16 {
    type ElemType = u16;

    const NAME: &'static str = "u16";
    const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::U16;
    const TYPE: Self = Self;
}

/// `u16` elements are described by the [`U16`] marker.
impl ElemType for u16 {
    type Type = U16;
}
/// Marker [`Type`] for `PrimitiveType::U32` buffers, whose elements are Rust `u32`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct U32;

impl Type for U32 {
    type ElemType = u32;

    const NAME: &'static str = "u32";
    const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::U32;
    const TYPE: Self = Self;
}

/// `u32` elements are described by the [`U32`] marker.
impl ElemType for u32 {
    type Type = U32;
}
/// Marker [`Type`] for `PrimitiveType::U64` buffers, whose elements are Rust `u64`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct U64;

impl Type for U64 {
    type ElemType = u64;

    const NAME: &'static str = "u64";
    const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::U64;
    const TYPE: Self = Self;
}

/// `u64` elements are described by the [`U64`] marker.
impl ElemType for u64 {
    type Type = U64;
}
/// Element-type tag of a PJRT buffer.
///
/// Each variant's discriminant is the value of the matching
/// `PJRT_Buffer_Type_PJRT_Buffer_Type_*` constant from `pjrt_sys`, so
/// `primitive_type as i32` yields the raw C API value. The reverse direction
/// is the `TryFrom<PJRT_Buffer_Type>` impl, which rejects unknown values.
///
/// Only `Pred`, `S8`-`S64`, `U8`-`U64`, `F32` and `F64` currently have a
/// [`Type`] marker in this module; the remaining variants can be named but
/// cannot be turned into a [`DType`].
#[repr(i32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PrimitiveType {
    Invalid = PJRT_Buffer_Type_PJRT_Buffer_Type_INVALID as i32,
    /// Elements are `bool` (see [`Bool`]).
    Pred = PJRT_Buffer_Type_PJRT_Buffer_Type_PRED as i32,
    /// Elements are `i8` (see [`I8`]).
    S8 = PJRT_Buffer_Type_PJRT_Buffer_Type_S8 as i32,
    /// Elements are `i16` (see [`I16`]).
    S16 = PJRT_Buffer_Type_PJRT_Buffer_Type_S16 as i32,
    /// Elements are `i32` (see [`I32`]).
    S32 = PJRT_Buffer_Type_PJRT_Buffer_Type_S32 as i32,
    /// Elements are `i64` (see [`I64`]).
    S64 = PJRT_Buffer_Type_PJRT_Buffer_Type_S64 as i32,
    /// Elements are `u8` (see [`U8`]).
    U8 = PJRT_Buffer_Type_PJRT_Buffer_Type_U8 as i32,
    /// Elements are `u16` (see [`U16`]).
    U16 = PJRT_Buffer_Type_PJRT_Buffer_Type_U16 as i32,
    /// Elements are `u32` (see [`U32`]).
    U32 = PJRT_Buffer_Type_PJRT_Buffer_Type_U32 as i32,
    /// Elements are `u64` (see [`U64`]).
    U64 = PJRT_Buffer_Type_PJRT_Buffer_Type_U64 as i32,
    F16 = PJRT_Buffer_Type_PJRT_Buffer_Type_F16 as i32,
    /// Elements are `f32` (see [`F32`]).
    F32 = PJRT_Buffer_Type_PJRT_Buffer_Type_F32 as i32,
    /// Elements are `f64` (see [`F64`]).
    F64 = PJRT_Buffer_Type_PJRT_Buffer_Type_F64 as i32,
    BF16 = PJRT_Buffer_Type_PJRT_Buffer_Type_BF16 as i32,
    C64 = PJRT_Buffer_Type_PJRT_Buffer_Type_C64 as i32,
    C128 = PJRT_Buffer_Type_PJRT_Buffer_Type_C128 as i32,
    F8E5M2 = PJRT_Buffer_Type_PJRT_Buffer_Type_F8E5M2 as i32,
    F8E4M3FN = PJRT_Buffer_Type_PJRT_Buffer_Type_F8E4M3FN as i32,
    F8E4M3B11FNUZ = PJRT_Buffer_Type_PJRT_Buffer_Type_F8E4M3B11FNUZ as i32,
    F8E5M2FNUZ = PJRT_Buffer_Type_PJRT_Buffer_Type_F8E5M2FNUZ as i32,
    F8E4M3FNUZ = PJRT_Buffer_Type_PJRT_Buffer_Type_F8E4M3FNUZ as i32,
    S4 = PJRT_Buffer_Type_PJRT_Buffer_Type_S4 as i32,
    U4 = PJRT_Buffer_Type_PJRT_Buffer_Type_U4 as i32,
    Token = PJRT_Buffer_Type_PJRT_Buffer_Type_TOKEN as i32,
    S2 = PJRT_Buffer_Type_PJRT_Buffer_Type_S2 as i32,
    U2 = PJRT_Buffer_Type_PJRT_Buffer_Type_U2 as i32,
}
impl TryFrom<PrimitiveType> for Box<dyn DType> {
type Error = Error;
fn try_from(value: PrimitiveType) -> Result<Self> {
value.try_into_dtype()
}
}
impl PrimitiveType {
pub fn try_into_dtype(&self) -> Result<Box<dyn DType>> {
match self {
PrimitiveType::Invalid => todo!(),
PrimitiveType::Pred => Ok(Bool.boxed_dtype()),
PrimitiveType::S8 => Ok(I8.boxed_dtype()),
PrimitiveType::S16 => Ok(I16.boxed_dtype()),
PrimitiveType::S32 => Ok(I32.boxed_dtype()),
PrimitiveType::S64 => Ok(I64.boxed_dtype()),
PrimitiveType::U8 => Ok(U8.boxed_dtype()),
PrimitiveType::U16 => Ok(U16.boxed_dtype()),
PrimitiveType::U32 => Ok(U32.boxed_dtype()),
PrimitiveType::U64 => Ok(U64.boxed_dtype()),
PrimitiveType::F32 => Ok(F32.boxed_dtype()),
PrimitiveType::F64 => Ok(F64.boxed_dtype()),
PrimitiveType::F16 => todo!(),
PrimitiveType::BF16 => todo!(),
PrimitiveType::C64 => todo!(),
PrimitiveType::C128 => todo!(),
PrimitiveType::F8E5M2 => todo!(),
PrimitiveType::F8E4M3FN => todo!(),
PrimitiveType::F8E4M3B11FNUZ => todo!(),
PrimitiveType::F8E5M2FNUZ => todo!(),
PrimitiveType::F8E4M3FNUZ => todo!(),
PrimitiveType::S4 => todo!(),
PrimitiveType::U4 => todo!(),
PrimitiveType::Token => todo!(),
PrimitiveType::S2 => todo!(),
PrimitiveType::U2 => todo!(),
}
}
}
/// Converts a raw PJRT C API buffer-type value into a [`PrimitiveType`].
///
/// Every known `PJRT_Buffer_Type_PJRT_Buffer_Type_*` constant maps to the
/// variant of the same name.
///
/// # Errors
///
/// Returns [`Error::InvalidPrimitiveType`] with the raw value (cast to `i32`)
/// when `value` matches none of the known constants.
impl TryFrom<PJRT_Buffer_Type> for PrimitiveType {
    type Error = Error;

    // The arms below are constant patterns naming `pjrt_sys` items that are not
    // UPPER_CASE; silence the naming lints that would otherwise fire on them.
    #[allow(non_upper_case_globals)]
    #[allow(non_snake_case)]
    fn try_from(value: PJRT_Buffer_Type) -> Result<Self> {
        match value {
            PJRT_Buffer_Type_PJRT_Buffer_Type_INVALID => Ok(Self::Invalid),
            PJRT_Buffer_Type_PJRT_Buffer_Type_PRED => Ok(Self::Pred),
            PJRT_Buffer_Type_PJRT_Buffer_Type_S8 => Ok(Self::S8),
            PJRT_Buffer_Type_PJRT_Buffer_Type_S16 => Ok(Self::S16),
            PJRT_Buffer_Type_PJRT_Buffer_Type_S32 => Ok(Self::S32),
            PJRT_Buffer_Type_PJRT_Buffer_Type_S64 => Ok(Self::S64),
            PJRT_Buffer_Type_PJRT_Buffer_Type_U8 => Ok(Self::U8),
            PJRT_Buffer_Type_PJRT_Buffer_Type_U16 => Ok(Self::U16),
            PJRT_Buffer_Type_PJRT_Buffer_Type_U32 => Ok(Self::U32),
            PJRT_Buffer_Type_PJRT_Buffer_Type_U64 => Ok(Self::U64),
            PJRT_Buffer_Type_PJRT_Buffer_Type_F16 => Ok(Self::F16),
            PJRT_Buffer_Type_PJRT_Buffer_Type_F32 => Ok(Self::F32),
            PJRT_Buffer_Type_PJRT_Buffer_Type_F64 => Ok(Self::F64),
            PJRT_Buffer_Type_PJRT_Buffer_Type_BF16 => Ok(Self::BF16),
            PJRT_Buffer_Type_PJRT_Buffer_Type_C64 => Ok(Self::C64),
            PJRT_Buffer_Type_PJRT_Buffer_Type_C128 => Ok(Self::C128),
            PJRT_Buffer_Type_PJRT_Buffer_Type_F8E5M2 => Ok(Self::F8E5M2),
            PJRT_Buffer_Type_PJRT_Buffer_Type_F8E4M3FN => Ok(Self::F8E4M3FN),
            PJRT_Buffer_Type_PJRT_Buffer_Type_F8E4M3B11FNUZ => Ok(Self::F8E4M3B11FNUZ),
            PJRT_Buffer_Type_PJRT_Buffer_Type_F8E5M2FNUZ => Ok(Self::F8E5M2FNUZ),
            PJRT_Buffer_Type_PJRT_Buffer_Type_F8E4M3FNUZ => Ok(Self::F8E4M3FNUZ),
            PJRT_Buffer_Type_PJRT_Buffer_Type_S4 => Ok(Self::S4),
            PJRT_Buffer_Type_PJRT_Buffer_Type_U4 => Ok(Self::U4),
            PJRT_Buffer_Type_PJRT_Buffer_Type_TOKEN => Ok(Self::Token),
            PJRT_Buffer_Type_PJRT_Buffer_Type_S2 => Ok(Self::S2),
            PJRT_Buffer_Type_PJRT_Buffer_Type_U2 => Ok(Self::U2),
            _ => Err(Error::InvalidPrimitiveType(value as i32)),
        }
    }
}
/// Object-safe runtime view of an element [`Type`].
///
/// Lets element types be handled uniformly as `&dyn DType` or `Box<dyn DType>`.
pub trait DType {
    /// The PJRT primitive-type tag.
    fn primitive_type(&self) -> PrimitiveType;

    /// Human-readable type name, e.g. `"f32"`.
    fn name(&self) -> &'static str;

    /// Size in bytes of one element.
    fn size(&self) -> usize;

    /// Alignment in bytes of one element.
    fn alignment(&self) -> usize;

    /// Returns an owned, boxed copy of this descriptor.
    fn boxed_dtype(&self) -> Box<dyn DType>;

    /// Upcasts to [`Any`] so callers can downcast to a concrete marker type.
    fn as_any(&self) -> &dyn Any;
}
impl Clone for Box<dyn DType> {
fn clone(&self) -> Self {
self.boxed_dtype()
}
}
/// Every compile-time [`Type`] marker is also a runtime [`DType`]; each method
/// forwards to the matching associated item of [`Type`].
impl<T> DType for T
where
    T: Type,
{
    fn name(&self) -> &'static str {
        <T as Type>::NAME
    }

    fn primitive_type(&self) -> PrimitiveType {
        <T as Type>::PRIMITIVE_TYPE
    }

    fn size(&self) -> usize {
        <T as Type>::SIZE
    }

    fn alignment(&self) -> usize {
        <T as Type>::ALIGNMENT
    }

    fn as_any(&self) -> &dyn Any {
        // `Type` requires `'static`, which is what allows this upcast.
        self as &dyn Any
    }

    fn boxed_dtype(&self) -> Box<dyn DType> {
        // Box the canonical marker value declared by the `Type` impl.
        let marker: T = <T as Type>::TYPE;
        Box::new(marker)
    }
}
/// Borrows a value as a `&dyn DType` trait object.
pub trait AsDType {
    /// Returns `self` viewed through the [`DType`] trait object.
    fn as_dtype(&self) -> &dyn DType;
}
/// Any sized [`DType`] implementor can be viewed as `&dyn DType` by unsizing.
impl<D> AsDType for D
where
    D: DType,
{
    fn as_dtype(&self) -> &dyn DType {
        self as &dyn DType
    }
}