use super::DType;
use bytemuck::{Pod, Zeroable};
use std::ops::{Add, Div, Mul, Sub};
/// A scalar element type with a runtime [`DType`] tag.
///
/// Implementors must be bytemuck `Pod`/`Zeroable`, thread-safe (`Send + Sync`), `'static`,
/// closed under `+ - * /`, and partially ordered. Every implementor can be converted to and
/// from `f64`/`f32`; these conversions may be lossy (see the individual methods).
pub trait Element:
    Copy
    + Clone
    + Send
    + Sync
    + Pod
    + Zeroable
    + 'static
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Div<Output = Self>
    + PartialOrd
{
    /// Runtime tag identifying this element type.
    const DTYPE: DType;

    /// Returns the zero value of the type.
    fn zero() -> Self;

    /// Returns the one value of the type.
    fn one() -> Self;

    /// Converts the value to `f64`.
    ///
    /// Not necessarily exact: 64-bit integers beyond 2^53 are rounded, and the complex
    /// implementations in this file return `magnitude()` rather than a component.
    fn to_f64(self) -> f64;

    /// Builds a value from an `f64`.
    ///
    /// The integer implementations in this file use `as` casts, which truncate toward zero,
    /// saturate at the type's bounds and map NaN to 0.
    fn from_f64(v: f64) -> Self;

    /// Converts the value to `f32`.
    ///
    /// The default narrows the result of [`Element::to_f64`]. When `to_f64` is itself inexact
    /// (e.g. 64-bit integers), that rounds twice and may miss the nearest `f32`, so such
    /// implementors should override this with a direct conversion.
    #[inline]
    fn to_f32(self) -> f32 {
        let wide = self.to_f64();
        wide as f32
    }

    /// Builds a value from an `f32`.
    ///
    /// The default widens to `f64` (always exact) and defers to [`Element::from_f64`].
    #[inline]
    fn from_f32(v: f32) -> Self {
        Self::from_f64(f64::from(v))
    }
}
/// `f64` is the reference type for [`Element`]'s conversions: `to_f64`/`from_f64` are the
/// identity and the `f32` conversions are single casts.
impl Element for f64 {
    const DTYPE: DType = DType::F64;

    #[inline]
    fn zero() -> Self {
        0.0_f64
    }

    #[inline]
    fn one() -> Self {
        1.0_f64
    }

    #[inline]
    fn to_f64(self) -> f64 {
        self
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        v
    }

    /// Narrows to the nearest `f32` (ties to even); out-of-range magnitudes become infinity.
    #[inline]
    fn to_f32(self) -> f32 {
        self as f32
    }

    /// Widens exactly; every `f32` is representable as an `f64`.
    #[inline]
    fn from_f32(v: f32) -> Self {
        f64::from(v)
    }
}
/// `f32` element: widening to `f64` is exact and narrowing from `f64` is a single cast.
impl Element for f32 {
    const DTYPE: DType = DType::F32;

    #[inline]
    fn zero() -> Self {
        0.0_f32
    }

    #[inline]
    fn one() -> Self {
        1.0_f32
    }

    /// Widens exactly.
    #[inline]
    fn to_f64(self) -> f64 {
        f64::from(self)
    }

    /// Narrows to the nearest `f32` (ties to even); out-of-range magnitudes become infinity.
    #[inline]
    fn from_f64(v: f64) -> Self {
        v as f32
    }

    #[inline]
    fn to_f32(self) -> f32 {
        self
    }

    #[inline]
    fn from_f32(v: f32) -> Self {
        v
    }
}
/// `i64` element.
///
/// Float-to-int conversions use `as` semantics: the fraction is truncated toward zero,
/// out-of-range values saturate to `i64::MIN`/`i64::MAX`, and NaN becomes 0. Int-to-float
/// conversions round to nearest (ties to even); magnitudes above 2^53 are not exact in `f64`.
impl Element for i64 {
    const DTYPE: DType = DType::I64;
    #[inline]
    fn to_f64(self) -> f64 {
        self as f64
    }
    #[inline]
    fn from_f64(v: f64) -> Self {
        v as i64
    }
    /// Casts straight to `f32`, rounding once.
    ///
    /// The trait default (`self.to_f64() as f32`) rounds twice for values above 2^53. When
    /// the intermediate `f64` lands exactly on an `f32` tie, the result is not the nearest
    /// `f32` (e.g. 2^60 + 2^36 + 1 would become 2^60 instead of 2^60 + 2^37).
    #[inline]
    fn to_f32(self) -> f32 {
        self as f32
    }
    /// Casts straight from `f32`. This gives the same result as the default path through
    /// `f64`, because `f32 -> f64` is exact, but it skips the intermediate step.
    #[inline]
    fn from_f32(v: f32) -> Self {
        v as i64
    }
    #[inline]
    fn zero() -> Self {
        0
    }
    #[inline]
    fn one() -> Self {
        1
    }
}
/// `i32` element. Widening to `f64` is exact; `from_f64` is an `as` cast (truncates toward
/// zero, saturates at the bounds, NaN becomes 0).
impl Element for i32 {
    const DTYPE: DType = DType::I32;

    #[inline]
    fn zero() -> Self {
        0
    }

    #[inline]
    fn one() -> Self {
        1
    }

    #[inline]
    fn to_f64(self) -> f64 {
        f64::from(self)
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        v as Self
    }
}
/// `i16` element. Widening to `f64` is exact; `from_f64` is an `as` cast (truncates toward
/// zero, saturates at the bounds, NaN becomes 0).
impl Element for i16 {
    const DTYPE: DType = DType::I16;

    #[inline]
    fn zero() -> Self {
        0
    }

    #[inline]
    fn one() -> Self {
        1
    }

    #[inline]
    fn to_f64(self) -> f64 {
        f64::from(self)
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        v as Self
    }
}
/// `i8` element. Widening to `f64` is exact; `from_f64` is an `as` cast (truncates toward
/// zero, saturates at the bounds, NaN becomes 0).
impl Element for i8 {
    const DTYPE: DType = DType::I8;

    #[inline]
    fn zero() -> Self {
        0
    }

    #[inline]
    fn one() -> Self {
        1
    }

    #[inline]
    fn to_f64(self) -> f64 {
        f64::from(self)
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        v as Self
    }
}
/// `u64` element.
///
/// Float-to-int conversions use `as` semantics: the fraction is truncated toward zero,
/// negative values and NaN become 0, and values above `u64::MAX` saturate. Int-to-float
/// conversions round to nearest (ties to even); values above 2^53 are not exact in `f64`.
impl Element for u64 {
    const DTYPE: DType = DType::U64;
    #[inline]
    fn to_f64(self) -> f64 {
        self as f64
    }
    #[inline]
    fn from_f64(v: f64) -> Self {
        v as u64
    }
    /// Casts straight to `f32`, rounding once.
    ///
    /// The trait default (`self.to_f64() as f32`) rounds twice for values above 2^53. When
    /// the intermediate `f64` lands exactly on an `f32` tie, the result is not the nearest
    /// `f32` (e.g. 2^60 + 2^36 + 1 would become 2^60 instead of 2^60 + 2^37).
    #[inline]
    fn to_f32(self) -> f32 {
        self as f32
    }
    /// Casts straight from `f32`. This gives the same result as the default path through
    /// `f64`, because `f32 -> f64` is exact, but it skips the intermediate step.
    #[inline]
    fn from_f32(v: f32) -> Self {
        v as u64
    }
    #[inline]
    fn zero() -> Self {
        0
    }
    #[inline]
    fn one() -> Self {
        1
    }
}
/// `u32` element. Widening to `f64` is exact; `from_f64` is an `as` cast (truncates toward
/// zero, saturates at `0`/`u32::MAX`, NaN becomes 0).
impl Element for u32 {
    const DTYPE: DType = DType::U32;

    #[inline]
    fn zero() -> Self {
        0
    }

    #[inline]
    fn one() -> Self {
        1
    }

    #[inline]
    fn to_f64(self) -> f64 {
        f64::from(self)
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        v as Self
    }
}
/// `u16` element. Widening to `f64` is exact; `from_f64` is an `as` cast (truncates toward
/// zero, saturates at `0`/`u16::MAX`, NaN becomes 0).
impl Element for u16 {
    const DTYPE: DType = DType::U16;

    #[inline]
    fn zero() -> Self {
        0
    }

    #[inline]
    fn one() -> Self {
        1
    }

    #[inline]
    fn to_f64(self) -> f64 {
        f64::from(self)
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        v as Self
    }
}
/// `u8` element. Widening to `f64` is exact; `from_f64` is an `as` cast (truncates toward
/// zero, saturates at `0`/`u8::MAX`, NaN becomes 0).
impl Element for u8 {
    const DTYPE: DType = DType::U8;

    #[inline]
    fn zero() -> Self {
        0
    }

    #[inline]
    fn one() -> Self {
        1
    }

    #[inline]
    fn to_f64(self) -> f64 {
        f64::from(self)
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        v as Self
    }
}
/// `half::f16` element, compiled only with the `f16` feature.
///
/// Every method forwards to the `half` crate. The calls are written as type-qualified paths
/// (`half::f16::to_f64(self)` and so on). Rust resolves such paths to a type's inherent
/// associated items before trait items of the same name, so they reach `half`'s own
/// conversions rather than recursing into this impl.
#[cfg(feature = "f16")]
impl Element for half::f16 {
    const DTYPE: DType = DType::F16;

    #[inline]
    fn zero() -> Self {
        half::f16::ZERO
    }

    #[inline]
    fn one() -> Self {
        half::f16::ONE
    }

    #[inline]
    fn to_f64(self) -> f64 {
        half::f16::to_f64(self)
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        half::f16::from_f64(v)
    }

    #[inline]
    fn to_f32(self) -> f32 {
        half::f16::to_f32(self)
    }

    #[inline]
    fn from_f32(v: f32) -> Self {
        half::f16::from_f32(v)
    }
}
/// `half::bf16` element, compiled only with the `f16` feature.
///
/// Every method forwards to the `half` crate. The calls are written as type-qualified paths
/// (`half::bf16::to_f64(self)` and so on). Rust resolves such paths to a type's inherent
/// associated items before trait items of the same name, so they reach `half`'s own
/// conversions rather than recursing into this impl.
#[cfg(feature = "f16")]
impl Element for half::bf16 {
    const DTYPE: DType = DType::BF16;

    #[inline]
    fn zero() -> Self {
        half::bf16::ZERO
    }

    #[inline]
    fn one() -> Self {
        half::bf16::ONE
    }

    #[inline]
    fn to_f64(self) -> f64 {
        half::bf16::to_f64(self)
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        half::bf16::from_f64(v)
    }

    #[inline]
    fn to_f32(self) -> f32 {
        half::bf16::to_f32(self)
    }

    #[inline]
    fn from_f32(v: f32) -> Self {
        half::bf16::from_f32(v)
    }
}
/// `FP8E4M3` element. All conversions funnel through the type's own `f32` conversions.
///
/// `self.to_f32()` and `Self::from_f32(v)` below are meant to reach inherent methods of the
/// same name defined in the `fp8` module; Rust prefers those over this trait's methods.
/// NOTE(review): that relies on `fp8` providing inherent `to_f32(self)` / `from_f32` —
/// without them these calls would recurse into this impl. Confirm against `fp8`.
///
/// `from_f64` narrows to `f32` first, so the input is rounded twice. NOTE(review): if the FP8
/// conversion rounds to nearest, rare tie cases may differ by one FP8 step from a direct
/// `f64 -> FP8` conversion.
impl Element for super::fp8::FP8E4M3 {
    const DTYPE: DType = DType::FP8E4M3;

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn to_f32(self) -> f32 {
        self.to_f32()
    }

    #[inline]
    fn from_f32(v: f32) -> Self {
        Self::from_f32(v)
    }

    #[inline]
    fn to_f64(self) -> f64 {
        f64::from(self.to_f32())
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        let narrowed = v as f32;
        Self::from_f32(narrowed)
    }
}
/// `FP8E5M2` element. All conversions funnel through the type's own `f32` conversions.
///
/// `self.to_f32()` and `Self::from_f32(v)` below are meant to reach inherent methods of the
/// same name defined in the `fp8` module; Rust prefers those over this trait's methods.
/// NOTE(review): that relies on `fp8` providing inherent `to_f32(self)` / `from_f32` —
/// without them these calls would recurse into this impl. Confirm against `fp8`.
///
/// `from_f64` narrows to `f32` first, so the input is rounded twice. NOTE(review): if the FP8
/// conversion rounds to nearest, rare tie cases may differ by one FP8 step from a direct
/// `f64 -> FP8` conversion.
impl Element for super::fp8::FP8E5M2 {
    const DTYPE: DType = DType::FP8E5M2;

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn to_f32(self) -> f32 {
        self.to_f32()
    }

    #[inline]
    fn from_f32(v: f32) -> Self {
        Self::from_f32(v)
    }

    #[inline]
    fn to_f64(self) -> f64 {
        f64::from(self.to_f32())
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        let narrowed = v as f32;
        Self::from_f32(narrowed)
    }
}
/// `Complex64` element, whose components are built from `f32` values.
///
/// The real-valued conversions are lossy by design. `to_f64` returns `self.magnitude()`
/// (presumably the modulus |z|, discarding the phase — confirm in `complex`). `from_f64`
/// builds a value with zero imaginary part. `to_f32`/`from_f32` use the trait defaults.
impl Element for super::complex::Complex64 {
    const DTYPE: DType = DType::Complex64;

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn to_f64(self) -> f64 {
        let modulus = self.magnitude();
        modulus as f64
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        let re = v as f32;
        Self::new(re, 0.0)
    }
}
/// `Complex128` element, whose components are built from `f64` values.
///
/// The real-valued conversions are lossy by design. `to_f64` returns `self.magnitude()`
/// (presumably the modulus |z|, discarding the phase — confirm in `complex`). `from_f64`
/// builds a value with zero imaginary part. `to_f32`/`from_f32` use the trait defaults.
impl Element for super::complex::Complex128 {
    const DTYPE: DType = DType::Complex128;

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn to_f64(self) -> f64 {
        let modulus = self.magnitude();
        modulus
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        let re = v;
        Self::new(re, 0.0)
    }
}
#[cfg(test)]
mod tests {
    use super::super::fp8::{FP8E4M3, FP8E5M2};
    use super::*;

    /// Each primitive reports the expected runtime tag.
    #[test]
    fn test_element_dtype() {
        let pairs = [
            (f64::DTYPE, DType::F64),
            (f32::DTYPE, DType::F32),
            (i32::DTYPE, DType::I32),
            (u8::DTYPE, DType::U8),
        ];
        for (actual, expected) in pairs.iter() {
            assert_eq!(actual, expected);
        }
    }

    /// Float narrowing and float-to-int conversion behave like the corresponding `as` casts.
    #[test]
    fn test_element_conversions() {
        let narrowed_then_widened = f32::from_f64(2.5).to_f64();
        assert_eq!(narrowed_then_widened, f64::from(2.5f32));
        assert_eq!(i32::from_f64(42.0), 42);
    }

    /// The FP8 formats report their own runtime tags.
    #[test]
    fn test_fp8_element_dtype() {
        assert_eq!(FP8E4M3::DTYPE, DType::FP8E4M3);
        assert_eq!(FP8E5M2::DTYPE, DType::FP8E5M2);
    }

    /// Round trips through FP8 stay within each format's coarse precision, and
    /// `zero`/`one` map to 0 and (approximately) 1.
    #[test]
    fn test_fp8_element_conversions() {
        fn check_units(zero: f32, one: f32) {
            assert_eq!(zero, 0.0);
            assert!((one - 1.0).abs() < 0.01);
        }

        let two = FP8E4M3::from_f64(2.0).to_f64();
        assert!((two - 2.0).abs() < 0.1);
        let hundred = FP8E5M2::from_f64(100.0).to_f64();
        assert!((hundred - 100.0).abs() < 15.0);

        check_units(FP8E4M3::zero().to_f32(), FP8E4M3::one().to_f32());
        check_units(FP8E5M2::zero().to_f32(), FP8E5M2::one().to_f32());
    }
}