#![allow(non_camel_case_types)]
#[link(name = "numkong")]
extern "C" {
fn nk_f32_to_f16(src: *const f32, dest: *mut u16);
fn nk_f16_to_f32(src: *const u16, dest: *mut f32);
fn nk_f32_to_bf16(src: *const f32, dest: *mut u16);
fn nk_bf16_to_f32(src: *const u16, dest: *mut f32);
fn nk_f32_to_e4m3(src: *const f32, dest: *mut u8);
fn nk_e4m3_to_f32(src: *const u8, dest: *mut f32);
fn nk_f32_to_e5m2(src: *const f32, dest: *mut u8);
fn nk_e5m2_to_f32(src: *const u8, dest: *mut f32);
fn nk_f32_to_e2m3(src: *const f32, dest: *mut u8);
fn nk_e2m3_to_f32(src: *const u8, dest: *mut f32);
fn nk_f32_to_e3m2(src: *const f32, dest: *mut u8);
fn nk_e3m2_to_f32(src: *const u8, dest: *mut f32);
}
#[inline(always)]
pub(crate) fn f32_abs_compat(x: f32) -> f32 { f32::from_bits(x.to_bits() & 0x7FFF_FFFF) }
#[inline(always)]
pub(crate) fn f32_round_compat(x: f32) -> f32 {
let t = x as i32 as f32;
let d = x - t;
if d >= 0.5 {
t + 1.0
} else if d <= -0.5 {
t - 1.0
} else {
t
}
}
#[inline(always)]
pub(crate) fn f64_round_compat(x: f64) -> f64 {
let t = x as i64 as f64;
let d = x - t;
if d >= 0.5 {
t + 1.0
} else if d <= -0.5 {
t - 1.0
} else {
t
}
}
#[inline]
pub fn is_close(a: f64, b: f64, atol: f64, rtol: f64) -> bool {
let diff = if a > b { a - b } else { b - a };
diff <= atol + rtol * (if b >= 0.0 { b } else { -b })
}
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct f16(pub u16);
impl f16 {
pub const ZERO: Self = f16(0);
pub const ONE: Self = f16(0x3C00);
pub const NEG_ONE: Self = f16(0xBC00);
pub const NAN: Self = f16(0x7E00);
#[inline(always)]
pub fn from_f32(value: f32) -> Self {
let mut result: u16 = 0;
unsafe { nk_f32_to_f16(&value, &mut result) };
f16(result)
}
#[inline(always)]
pub fn to_f32(self) -> f32 {
let mut result: f32 = 0.0;
unsafe { nk_f16_to_f32(&self.0, &mut result) };
result
}
#[inline(always)]
pub fn is_nan(self) -> bool { self.to_f32().is_nan() }
#[inline(always)]
pub fn is_infinite(self) -> bool { self.to_f32().is_infinite() }
#[inline(always)]
pub fn is_finite(self) -> bool { self.to_f32().is_finite() }
#[inline(always)]
pub fn abs(self) -> Self { Self::from_f32(f32_abs_compat(self.to_f32())) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn floor(self) -> Self { Self::from_f32(self.to_f32().floor()) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn ceil(self) -> Self { Self::from_f32(self.to_f32().ceil()) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn round(self) -> Self { Self::from_f32(self.to_f32().round()) }
}
impl core::fmt::Debug for f16 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "f16({}, 0x{:04x})", self.to_f32(), self.0)
}
}
impl core::fmt::Display for f16 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if f.alternate() {
core::fmt::Display::fmt(&self.to_f32(), f)?;
write!(f, " [0x{:04x}]", self.0)
} else {
core::fmt::Display::fmt(&self.to_f32(), f)
}
}
}
impl core::fmt::LowerHex for f16 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::LowerHex::fmt(&self.0, f)
}
}
impl core::fmt::UpperHex for f16 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::UpperHex::fmt(&self.0, f)
}
}
impl core::fmt::Binary for f16 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Binary::fmt(&self.0, f)
}
}
impl core::ops::Add for f16 {
type Output = Self;
#[inline(always)]
fn add(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() + rhs.to_f32()) }
}
impl core::ops::Sub for f16 {
type Output = Self;
#[inline(always)]
fn sub(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() - rhs.to_f32()) }
}
impl core::ops::Mul for f16 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() * rhs.to_f32()) }
}
impl core::ops::Div for f16 {
type Output = Self;
#[inline(always)]
fn div(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() / rhs.to_f32()) }
}
impl core::ops::Neg for f16 {
type Output = Self;
#[inline(always)]
fn neg(self) -> Self::Output { Self::from_f32(-self.to_f32()) }
}
impl core::cmp::PartialOrd for f16 {
#[inline(always)]
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
self.to_f32().partial_cmp(&other.to_f32())
}
}
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct bf16(pub u16);
impl bf16 {
pub const ZERO: Self = bf16(0);
pub const ONE: Self = bf16(0x3F80);
pub const NEG_ONE: Self = bf16(0xBF80);
pub const NAN: Self = bf16(0x7FC0);
#[inline(always)]
pub fn from_f32(value: f32) -> Self {
let mut result: u16 = 0;
unsafe { nk_f32_to_bf16(&value, &mut result) };
bf16(result)
}
#[inline(always)]
pub fn to_f32(self) -> f32 {
let mut result: f32 = 0.0;
unsafe { nk_bf16_to_f32(&self.0, &mut result) };
result
}
#[inline(always)]
pub fn is_nan(self) -> bool { self.to_f32().is_nan() }
#[inline(always)]
pub fn is_infinite(self) -> bool { self.to_f32().is_infinite() }
#[inline(always)]
pub fn is_finite(self) -> bool { self.to_f32().is_finite() }
#[inline(always)]
pub fn abs(self) -> Self { Self::from_f32(f32_abs_compat(self.to_f32())) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn floor(self) -> Self { Self::from_f32(self.to_f32().floor()) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn ceil(self) -> Self { Self::from_f32(self.to_f32().ceil()) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn round(self) -> Self { Self::from_f32(self.to_f32().round()) }
}
impl core::fmt::Debug for bf16 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "bf16({}, 0x{:04x})", self.to_f32(), self.0)
}
}
impl core::fmt::Display for bf16 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if f.alternate() {
core::fmt::Display::fmt(&self.to_f32(), f)?;
write!(f, " [0x{:04x}]", self.0)
} else {
core::fmt::Display::fmt(&self.to_f32(), f)
}
}
}
impl core::fmt::LowerHex for bf16 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::LowerHex::fmt(&self.0, f)
}
}
impl core::fmt::UpperHex for bf16 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::UpperHex::fmt(&self.0, f)
}
}
impl core::fmt::Binary for bf16 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Binary::fmt(&self.0, f)
}
}
impl core::ops::Add for bf16 {
type Output = Self;
#[inline(always)]
fn add(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() + rhs.to_f32()) }
}
impl core::ops::Sub for bf16 {
type Output = Self;
#[inline(always)]
fn sub(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() - rhs.to_f32()) }
}
impl core::ops::Mul for bf16 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() * rhs.to_f32()) }
}
impl core::ops::Div for bf16 {
type Output = Self;
#[inline(always)]
fn div(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() / rhs.to_f32()) }
}
impl core::ops::Neg for bf16 {
type Output = Self;
#[inline(always)]
fn neg(self) -> Self::Output { Self::from_f32(-self.to_f32()) }
}
impl core::cmp::PartialOrd for bf16 {
#[inline(always)]
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
self.to_f32().partial_cmp(&other.to_f32())
}
}
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct e4m3(pub u8);
impl e4m3 {
pub const ZERO: Self = e4m3(0x00);
pub const ONE: Self = e4m3(0x38);
pub const NEG_ONE: Self = e4m3(0xB8);
pub const NAN: Self = e4m3(0x7F);
#[inline(always)]
pub fn from_f32(value: f32) -> Self {
let mut result: u8 = 0;
unsafe { nk_f32_to_e4m3(&value, &mut result) };
e4m3(result)
}
#[inline(always)]
pub fn to_f32(self) -> f32 {
let mut result: f32 = 0.0;
unsafe { nk_e4m3_to_f32(&self.0, &mut result) };
result
}
#[inline(always)]
pub fn is_nan(self) -> bool { (self.0 & 0x7F) == 0x7F }
#[inline(always)]
pub fn is_infinite(self) -> bool { false }
#[inline(always)]
pub fn is_finite(self) -> bool { !self.is_nan() }
#[inline(always)]
pub fn abs(self) -> Self { Self::from_f32(f32_abs_compat(self.to_f32())) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn floor(self) -> Self { Self::from_f32(self.to_f32().floor()) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn ceil(self) -> Self { Self::from_f32(self.to_f32().ceil()) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn round(self) -> Self { Self::from_f32(self.to_f32().round()) }
}
impl core::fmt::Debug for e4m3 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "e4m3({}, 0x{:02x})", self.to_f32(), self.0)
}
}
impl core::fmt::Display for e4m3 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if f.alternate() {
core::fmt::Display::fmt(&self.to_f32(), f)?;
write!(f, " [0x{:02x}]", self.0)
} else {
core::fmt::Display::fmt(&self.to_f32(), f)
}
}
}
impl core::fmt::LowerHex for e4m3 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::LowerHex::fmt(&self.0, f)
}
}
impl core::fmt::UpperHex for e4m3 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::UpperHex::fmt(&self.0, f)
}
}
impl core::fmt::Binary for e4m3 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Binary::fmt(&self.0, f)
}
}
impl core::ops::Add for e4m3 {
type Output = Self;
#[inline(always)]
fn add(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() + rhs.to_f32()) }
}
impl core::ops::Sub for e4m3 {
type Output = Self;
#[inline(always)]
fn sub(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() - rhs.to_f32()) }
}
impl core::ops::Mul for e4m3 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() * rhs.to_f32()) }
}
impl core::ops::Div for e4m3 {
type Output = Self;
#[inline(always)]
fn div(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() / rhs.to_f32()) }
}
impl core::ops::Neg for e4m3 {
type Output = Self;
#[inline(always)]
fn neg(self) -> Self::Output { Self::from_f32(-self.to_f32()) }
}
impl core::cmp::PartialOrd for e4m3 {
#[inline(always)]
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
self.to_f32().partial_cmp(&other.to_f32())
}
}
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct e5m2(pub u8);
impl e5m2 {
pub const ZERO: Self = e5m2(0x00);
pub const ONE: Self = e5m2(0x3C);
pub const NEG_ONE: Self = e5m2(0xBC);
pub const NAN: Self = e5m2(0x7F);
#[inline(always)]
pub fn from_f32(value: f32) -> Self {
let mut result: u8 = 0;
unsafe { nk_f32_to_e5m2(&value, &mut result) };
e5m2(result)
}
#[inline(always)]
pub fn to_f32(self) -> f32 {
let mut result: f32 = 0.0;
unsafe { nk_e5m2_to_f32(&self.0, &mut result) };
result
}
#[inline(always)]
pub fn is_nan(self) -> bool {
let exp = (self.0 >> 2) & 0x1F;
let mant = self.0 & 0x03;
exp == 0x1F && mant != 0
}
#[inline(always)]
pub fn is_infinite(self) -> bool {
let exp = (self.0 >> 2) & 0x1F;
let mant = self.0 & 0x03;
exp == 0x1F && mant == 0
}
#[inline(always)]
pub fn is_finite(self) -> bool {
let exp = (self.0 >> 2) & 0x1F;
exp != 0x1F
}
#[inline(always)]
pub fn abs(self) -> Self { Self::from_f32(f32_abs_compat(self.to_f32())) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn floor(self) -> Self { Self::from_f32(self.to_f32().floor()) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn ceil(self) -> Self { Self::from_f32(self.to_f32().ceil()) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn round(self) -> Self { Self::from_f32(self.to_f32().round()) }
}
impl core::fmt::Debug for e5m2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "e5m2({}, 0x{:02x})", self.to_f32(), self.0)
}
}
impl core::fmt::Display for e5m2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if f.alternate() {
core::fmt::Display::fmt(&self.to_f32(), f)?;
write!(f, " [0x{:02x}]", self.0)
} else {
core::fmt::Display::fmt(&self.to_f32(), f)
}
}
}
impl core::fmt::LowerHex for e5m2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::LowerHex::fmt(&self.0, f)
}
}
impl core::fmt::UpperHex for e5m2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::UpperHex::fmt(&self.0, f)
}
}
impl core::fmt::Binary for e5m2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Binary::fmt(&self.0, f)
}
}
impl core::ops::Add for e5m2 {
type Output = Self;
#[inline(always)]
fn add(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() + rhs.to_f32()) }
}
impl core::ops::Sub for e5m2 {
type Output = Self;
#[inline(always)]
fn sub(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() - rhs.to_f32()) }
}
impl core::ops::Mul for e5m2 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() * rhs.to_f32()) }
}
impl core::ops::Div for e5m2 {
type Output = Self;
#[inline(always)]
fn div(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() / rhs.to_f32()) }
}
impl core::ops::Neg for e5m2 {
type Output = Self;
#[inline(always)]
fn neg(self) -> Self::Output { Self::from_f32(-self.to_f32()) }
}
impl core::cmp::PartialOrd for e5m2 {
#[inline(always)]
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
self.to_f32().partial_cmp(&other.to_f32())
}
}
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct e2m3(pub u8);
impl e2m3 {
pub const ZERO: Self = e2m3(0x00);
pub const ONE: Self = e2m3(0x08);
pub const NEG_ONE: Self = e2m3(0x28);
#[inline(always)]
pub fn from_f32(value: f32) -> Self {
let mut result: u8 = 0;
unsafe { nk_f32_to_e2m3(&value, &mut result) };
e2m3(result)
}
#[inline(always)]
pub fn to_f32(self) -> f32 {
let mut result: f32 = 0.0;
unsafe { nk_e2m3_to_f32(&self.0, &mut result) };
result
}
#[inline(always)]
pub fn is_nan(self) -> bool {
false }
#[inline(always)]
pub fn is_infinite(self) -> bool { false }
#[inline(always)]
pub fn is_finite(self) -> bool { true }
#[inline(always)]
pub fn abs(self) -> Self {
e2m3(self.0 & 0x1F) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn floor(self) -> Self { Self::from_f32(self.to_f32().floor()) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn ceil(self) -> Self { Self::from_f32(self.to_f32().ceil()) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn round(self) -> Self { Self::from_f32(self.to_f32().round()) }
}
impl core::fmt::Debug for e2m3 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "e2m3({}, 0x{:02x})", self.to_f32(), self.0)
}
}
impl core::fmt::Display for e2m3 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if f.alternate() {
core::fmt::Display::fmt(&self.to_f32(), f)?;
write!(f, " [0x{:02x}]", self.0)
} else {
core::fmt::Display::fmt(&self.to_f32(), f)
}
}
}
impl core::fmt::LowerHex for e2m3 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::LowerHex::fmt(&self.0, f)
}
}
impl core::fmt::UpperHex for e2m3 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::UpperHex::fmt(&self.0, f)
}
}
impl core::fmt::Binary for e2m3 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Binary::fmt(&self.0, f)
}
}
impl core::ops::Add for e2m3 {
type Output = Self;
#[inline(always)]
fn add(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() + rhs.to_f32()) }
}
impl core::ops::Sub for e2m3 {
type Output = Self;
#[inline(always)]
fn sub(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() - rhs.to_f32()) }
}
impl core::ops::Mul for e2m3 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() * rhs.to_f32()) }
}
impl core::ops::Div for e2m3 {
type Output = Self;
#[inline(always)]
fn div(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() / rhs.to_f32()) }
}
impl core::ops::Neg for e2m3 {
type Output = Self;
#[inline(always)]
fn neg(self) -> Self::Output { Self::from_f32(-self.to_f32()) }
}
impl core::cmp::PartialOrd for e2m3 {
#[inline(always)]
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
self.to_f32().partial_cmp(&other.to_f32())
}
}
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct e3m2(pub u8);
impl e3m2 {
pub const ZERO: Self = e3m2(0x00);
pub const ONE: Self = e3m2(0x0C);
pub const NEG_ONE: Self = e3m2(0x2C);
#[inline(always)]
pub fn from_f32(value: f32) -> Self {
let mut result: u8 = 0;
unsafe { nk_f32_to_e3m2(&value, &mut result) };
e3m2(result)
}
#[inline(always)]
pub fn to_f32(self) -> f32 {
let mut result: f32 = 0.0;
unsafe { nk_e3m2_to_f32(&self.0, &mut result) };
result
}
#[inline(always)]
pub fn is_nan(self) -> bool {
let exp = (self.0 >> 2) & 0x07;
let mant = self.0 & 0x03;
exp == 0x07 && mant != 0
}
#[inline(always)]
pub fn is_infinite(self) -> bool {
let exp = (self.0 >> 2) & 0x07;
let mant = self.0 & 0x03;
exp == 0x07 && mant == 0
}
#[inline(always)]
pub fn is_finite(self) -> bool {
let exp = (self.0 >> 2) & 0x07;
exp != 0x07
}
#[inline(always)]
pub fn abs(self) -> Self {
e3m2(self.0 & 0x1F) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn floor(self) -> Self { Self::from_f32(self.to_f32().floor()) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn ceil(self) -> Self { Self::from_f32(self.to_f32().ceil()) }
#[cfg(feature = "std")]
#[inline(always)]
pub fn round(self) -> Self { Self::from_f32(self.to_f32().round()) }
}
impl core::fmt::Debug for e3m2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "e3m2({}, 0x{:02x})", self.to_f32(), self.0)
}
}
impl core::fmt::Display for e3m2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if f.alternate() {
core::fmt::Display::fmt(&self.to_f32(), f)?;
write!(f, " [0x{:02x}]", self.0)
} else {
core::fmt::Display::fmt(&self.to_f32(), f)
}
}
}
impl core::fmt::LowerHex for e3m2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::LowerHex::fmt(&self.0, f)
}
}
impl core::fmt::UpperHex for e3m2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::UpperHex::fmt(&self.0, f)
}
}
impl core::fmt::Binary for e3m2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Binary::fmt(&self.0, f)
}
}
impl core::ops::Add for e3m2 {
type Output = Self;
#[inline(always)]
fn add(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() + rhs.to_f32()) }
}
impl core::ops::Sub for e3m2 {
type Output = Self;
#[inline(always)]
fn sub(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() - rhs.to_f32()) }
}
impl core::ops::Mul for e3m2 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() * rhs.to_f32()) }
}
impl core::ops::Div for e3m2 {
type Output = Self;
#[inline(always)]
fn div(self, rhs: Self) -> Self::Output { Self::from_f32(self.to_f32() / rhs.to_f32()) }
}
impl core::ops::Neg for e3m2 {
type Output = Self;
#[inline(always)]
fn neg(self) -> Self::Output { Self::from_f32(-self.to_f32()) }
}
impl core::cmp::PartialOrd for e3m2 {
#[inline(always)]
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
self.to_f32().partial_cmp(&other.to_f32())
}
}
impl From<f32> for f16 {
#[inline(always)]
fn from(value: f32) -> Self { Self::from_f32(value) }
}
impl From<f16> for f32 {
#[inline(always)]
fn from(value: f16) -> Self { value.to_f32() }
}
impl From<f32> for bf16 {
#[inline(always)]
fn from(value: f32) -> Self { Self::from_f32(value) }
}
impl From<bf16> for f32 {
#[inline(always)]
fn from(value: bf16) -> Self { value.to_f32() }
}
impl From<f32> for e4m3 {
#[inline(always)]
fn from(value: f32) -> Self { Self::from_f32(value) }
}
impl From<e4m3> for f32 {
#[inline(always)]
fn from(value: e4m3) -> Self { value.to_f32() }
}
impl From<f32> for e5m2 {
#[inline(always)]
fn from(value: f32) -> Self { Self::from_f32(value) }
}
impl From<e5m2> for f32 {
#[inline(always)]
fn from(value: e5m2) -> Self { value.to_f32() }
}
impl From<f32> for e2m3 {
#[inline(always)]
fn from(value: f32) -> Self { Self::from_f32(value) }
}
impl From<e2m3> for f32 {
#[inline(always)]
fn from(value: e2m3) -> Self { value.to_f32() }
}
impl From<f32> for e3m2 {
#[inline(always)]
fn from(value: f32) -> Self { Self::from_f32(value) }
}
impl From<e3m2> for f32 {
#[inline(always)]
fn from(value: e3m2) -> Self { value.to_f32() }
}
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct u1x8(pub u8);
impl core::fmt::Debug for u1x8 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "u1x8(0b{:08b}, 0x{:02x})", self.0, self.0)
}
}
impl core::fmt::Display for u1x8 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if f.alternate() {
write!(f, "0b{:08b} [0x{:02x}]", self.0, self.0)
} else {
write!(f, "0b{:08b}", self.0)
}
}
}
impl core::fmt::LowerHex for u1x8 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::LowerHex::fmt(&self.0, f)
}
}
impl core::fmt::UpperHex for u1x8 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::UpperHex::fmt(&self.0, f)
}
}
impl core::fmt::Binary for u1x8 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Binary::fmt(&self.0, f)
}
}
impl u1x8 {
#[inline(always)]
pub const fn new(bits: u8) -> Self { u1x8(bits) }
#[inline(always)]
pub const fn bits(self) -> u8 { self.0 }
#[inline(always)]
pub const fn from_bools(
b0: bool,
b1: bool,
b2: bool,
b3: bool,
b4: bool,
b5: bool,
b6: bool,
b7: bool,
) -> Self {
u1x8(
(b0 as u8)
| ((b1 as u8) << 1)
| ((b2 as u8) << 2)
| ((b3 as u8) << 3)
| ((b4 as u8) << 4)
| ((b5 as u8) << 5)
| ((b6 as u8) << 6)
| ((b7 as u8) << 7),
)
}
#[inline(always)]
pub const fn to_bools(self) -> (bool, bool, bool, bool, bool, bool, bool, bool) {
(
(self.0 & 1) != 0,
(self.0 & 2) != 0,
(self.0 & 4) != 0,
(self.0 & 8) != 0,
(self.0 & 16) != 0,
(self.0 & 32) != 0,
(self.0 & 64) != 0,
(self.0 & 128) != 0,
)
}
}
impl From<(bool, bool, bool, bool, bool, bool, bool, bool)> for u1x8 {
#[inline(always)]
fn from(b: (bool, bool, bool, bool, bool, bool, bool, bool)) -> Self {
u1x8::from_bools(b.0, b.1, b.2, b.3, b.4, b.5, b.6, b.7)
}
}
impl From<u1x8> for (bool, bool, bool, bool, bool, bool, bool, bool) {
#[inline(always)]
fn from(v: u1x8) -> Self { v.to_bools() }
}
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct u4x2(pub u8);
impl core::fmt::Debug for u4x2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let (lo, hi) = self.to_u8s();
write!(f, "u4x2({}, {}, 0x{:02x})", lo, hi, self.0)
}
}
impl core::fmt::Display for u4x2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let (lo, hi) = self.to_u8s();
if f.alternate() {
write!(f, "({}, {}) [0x{:02x}]", lo, hi, self.0)
} else {
write!(f, "({}, {})", lo, hi)
}
}
}
impl core::fmt::LowerHex for u4x2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::LowerHex::fmt(&self.0, f)
}
}
impl core::fmt::UpperHex for u4x2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::UpperHex::fmt(&self.0, f)
}
}
impl core::fmt::Binary for u4x2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Binary::fmt(&self.0, f)
}
}
impl u4x2 {
#[inline(always)]
pub const fn new(packed: u8) -> Self { u4x2(packed) }
#[inline(always)]
pub const fn packed(self) -> u8 { self.0 }
#[inline(always)]
pub const fn from_u8s(lo: u8, hi: u8) -> Self {
let lo_sat = if lo > 15 { 15 } else { lo };
let hi_sat = if hi > 15 { 15 } else { hi };
u4x2(lo_sat | (hi_sat << 4))
}
#[inline(always)]
pub const fn to_u8s(self) -> (u8, u8) { (self.0 & 0x0F, self.0 >> 4) }
}
impl From<(u8, u8)> for u4x2 {
#[inline(always)]
fn from(v: (u8, u8)) -> Self { u4x2::from_u8s(v.0, v.1) }
}
impl From<u4x2> for (u8, u8) {
#[inline(always)]
fn from(v: u4x2) -> Self { v.to_u8s() }
}
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct i4x2(pub u8);
impl core::fmt::Debug for i4x2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let (lo, hi) = self.to_i8s();
write!(f, "i4x2({}, {}, 0x{:02x})", lo, hi, self.0)
}
}
impl core::fmt::Display for i4x2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let (lo, hi) = self.to_i8s();
if f.alternate() {
write!(f, "({}, {}) [0x{:02x}]", lo, hi, self.0)
} else {
write!(f, "({}, {})", lo, hi)
}
}
}
impl core::fmt::LowerHex for i4x2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::LowerHex::fmt(&self.0, f)
}
}
impl core::fmt::UpperHex for i4x2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::UpperHex::fmt(&self.0, f)
}
}
impl core::fmt::Binary for i4x2 {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Binary::fmt(&self.0, f)
}
}
impl i4x2 {
#[inline(always)]
pub const fn new(packed: u8) -> Self { i4x2(packed) }
#[inline(always)]
pub const fn packed(self) -> u8 { self.0 }
#[inline(always)]
pub const fn from_i8s(lo: i8, hi: i8) -> Self {
let lo_sat = if lo < -8 {
-8
} else if lo > 7 {
7
} else {
lo
};
let hi_sat = if hi < -8 {
-8
} else if hi > 7 {
7
} else {
hi
};
i4x2(((lo_sat as u8) & 0x0F) | (((hi_sat as u8) & 0x0F) << 4))
}
#[inline(always)]
pub const fn to_i8s(self) -> (i8, i8) {
let lo = (self.0 & 0x0F) as i8;
let hi = ((self.0 >> 4) & 0x0F) as i8;
let lo = if lo & 0x08 != 0 { lo | (!0x0Fi8) } else { lo };
let hi = if hi & 0x08 != 0 { hi | (!0x0Fi8) } else { hi };
(lo, hi)
}
}
impl From<(i8, i8)> for i4x2 {
#[inline(always)]
fn from(v: (i8, i8)) -> Self { i4x2::from_i8s(v.0, v.1) }
}
impl From<i4x2> for (i8, i8) {
#[inline(always)]
fn from(v: i4x2) -> Self { v.to_i8s() }
}
pub trait StorageElement: Sized + Copy + Clone + Default {
fn zero() -> Self;
fn one() -> Self;
fn dimensions_per_value() -> usize { 1 }
}
pub trait NumberLike: StorageElement {
fn from_f32(v: f32) -> Self;
fn to_f32(self) -> f32;
fn from_f64(v: f64) -> Self { Self::from_f32(v as f32) }
fn to_f64(self) -> f64 { self.to_f32() as f64 }
fn abs(self) -> Self { Self::from_f32(f32_abs_compat(self.to_f32())) }
fn is_nan(self) -> bool { self.to_f32().is_nan() }
fn is_finite(self) -> bool { self.to_f32().is_finite() }
fn is_infinite(self) -> bool { self.to_f32().is_infinite() }
fn has_infinity() -> bool { false }
fn has_nan() -> bool { false }
fn has_subnormals() -> bool { false }
fn max_value() -> f32 { f32::MAX }
fn min_positive() -> f32 { f32::MIN_POSITIVE }
}
pub trait FloatLike: NumberLike {}
impl<T: NumberLike> FloatLike for T {}
impl StorageElement for f32 {
fn zero() -> Self { 0.0 }
fn one() -> Self { 1.0 }
}
impl NumberLike for f32 {
fn from_f32(v: f32) -> Self { v }
fn to_f32(self) -> f32 { self }
fn from_f64(v: f64) -> Self { v as f32 }
fn to_f64(self) -> f64 { self as f64 }
fn abs(self) -> Self { f32_abs_compat(self) }
fn is_nan(self) -> bool { f32::is_nan(self) }
fn is_finite(self) -> bool { f32::is_finite(self) }
fn is_infinite(self) -> bool { f32::is_infinite(self) }
fn has_infinity() -> bool { true }
fn has_nan() -> bool { true }
fn has_subnormals() -> bool { true }
fn max_value() -> f32 { f32::MAX }
fn min_positive() -> f32 { f32::MIN_POSITIVE }
}
impl StorageElement for f64 {
fn zero() -> Self { 0.0 }
fn one() -> Self { 1.0 }
}
impl NumberLike for f64 {
fn from_f32(v: f32) -> Self { v as f64 }
fn to_f32(self) -> f32 { self as f32 }
fn from_f64(v: f64) -> Self { v }
fn to_f64(self) -> f64 { self }
fn abs(self) -> Self { f64::from_bits(self.to_bits() & 0x7FFF_FFFF_FFFF_FFFF) }
fn is_nan(self) -> bool { f64::is_nan(self) }
fn is_finite(self) -> bool { f64::is_finite(self) }
fn is_infinite(self) -> bool { f64::is_infinite(self) }
fn has_infinity() -> bool { true }
fn has_nan() -> bool { true }
fn has_subnormals() -> bool { true }
fn max_value() -> f32 { f32::MAX }
fn min_positive() -> f32 { f32::MIN_POSITIVE }
}
impl StorageElement for f16 {
fn zero() -> Self { f16(0) }
fn one() -> Self { f16::from_f32(1.0) }
}
impl NumberLike for f16 {
fn from_f32(v: f32) -> Self { f16::from_f32(v) }
fn to_f32(self) -> f32 { self.to_f32() }
fn abs(self) -> Self { self.abs() }
fn is_nan(self) -> bool { self.is_nan() }
fn is_finite(self) -> bool { self.is_finite() }
fn is_infinite(self) -> bool { self.is_infinite() }
fn has_infinity() -> bool { true }
fn has_nan() -> bool { true }
fn has_subnormals() -> bool { true }
fn max_value() -> f32 { 65504.0 }
fn min_positive() -> f32 { 6.1e-5 }
}
impl StorageElement for bf16 {
fn zero() -> Self { bf16(0) }
fn one() -> Self { bf16::from_f32(1.0) }
}
impl NumberLike for bf16 {
fn from_f32(v: f32) -> Self { bf16::from_f32(v) }
fn to_f32(self) -> f32 { self.to_f32() }
fn abs(self) -> Self { self.abs() }
fn is_nan(self) -> bool { self.is_nan() }
fn is_finite(self) -> bool { self.is_finite() }
fn is_infinite(self) -> bool { self.is_infinite() }
fn has_infinity() -> bool { true }
fn has_nan() -> bool { true }
fn has_subnormals() -> bool { true }
fn max_value() -> f32 { 3.4e38 }
fn min_positive() -> f32 { 1.2e-38 }
}
impl StorageElement for e4m3 {
fn zero() -> Self { e4m3(0) }
fn one() -> Self { e4m3::from_f32(1.0) }
}
impl NumberLike for e4m3 {
fn from_f32(v: f32) -> Self { e4m3::from_f32(v) }
fn to_f32(self) -> f32 { self.to_f32() }
fn abs(self) -> Self { self.abs() }
fn is_nan(self) -> bool { self.is_nan() }
fn is_finite(self) -> bool { self.is_finite() }
fn is_infinite(self) -> bool { self.is_infinite() }
fn has_nan() -> bool { true }
fn has_subnormals() -> bool { true }
fn max_value() -> f32 { 448.0 }
fn min_positive() -> f32 { 0.001953125 }
}
impl StorageElement for e5m2 {
fn zero() -> Self { e5m2(0) }
fn one() -> Self { e5m2::from_f32(1.0) }
}
impl NumberLike for e5m2 {
fn from_f32(v: f32) -> Self { e5m2::from_f32(v) }
fn to_f32(self) -> f32 { self.to_f32() }
fn abs(self) -> Self { self.abs() }
fn is_nan(self) -> bool { self.is_nan() }
fn is_finite(self) -> bool { self.is_finite() }
fn is_infinite(self) -> bool { self.is_infinite() }
fn has_infinity() -> bool { true }
fn has_nan() -> bool { true }
fn has_subnormals() -> bool { true }
fn max_value() -> f32 { 57344.0 }
fn min_positive() -> f32 { 6.103_515_6e-5 }
}
impl StorageElement for e2m3 {
fn zero() -> Self { e2m3(0) }
fn one() -> Self { e2m3::from_f32(1.0) }
}
impl NumberLike for e2m3 {
fn from_f32(v: f32) -> Self { e2m3::from_f32(v) }
fn to_f32(self) -> f32 { self.to_f32() }
fn abs(self) -> Self { self.abs() }
fn is_nan(self) -> bool { self.is_nan() }
fn is_finite(self) -> bool { self.is_finite() }
fn is_infinite(self) -> bool { self.is_infinite() }
fn has_subnormals() -> bool { true }
fn max_value() -> f32 { 7.5 }
fn min_positive() -> f32 { 0.0625 }
}
impl StorageElement for e3m2 {
fn zero() -> Self { e3m2(0) }
fn one() -> Self { e3m2::from_f32(1.0) }
}
impl NumberLike for e3m2 {
fn from_f32(v: f32) -> Self { e3m2::from_f32(v) }
fn to_f32(self) -> f32 { self.to_f32() }
fn abs(self) -> Self { self.abs() }
fn is_nan(self) -> bool { self.is_nan() }
fn is_finite(self) -> bool { self.is_finite() }
fn is_infinite(self) -> bool { self.is_infinite() }
fn has_subnormals() -> bool { true }
fn max_value() -> f32 { 28.0 }
fn min_positive() -> f32 { 0.125 }
}
impl StorageElement for i8 {
fn zero() -> Self { 0 }
fn one() -> Self { 1 }
}
impl NumberLike for i8 {
fn from_f32(v: f32) -> Self { f32_round_compat(v) as i8 }
fn to_f32(self) -> f32 { self as f32 }
fn from_f64(v: f64) -> Self { f64_round_compat(v) as i8 }
fn to_f64(self) -> f64 { self as f64 }
}
impl StorageElement for u8 {
fn zero() -> Self { 0 }
fn one() -> Self { 1 }
}
impl NumberLike for u8 {
fn from_f32(v: f32) -> Self {
let r = f32_round_compat(v);
if r < 0.0 {
0
} else {
r as u8
}
}
fn to_f32(self) -> f32 { self as f32 }
fn from_f64(v: f64) -> Self {
let r = f64_round_compat(v);
if r < 0.0 {
0
} else {
r as u8
}
}
fn to_f64(self) -> f64 { self as f64 }
}
impl StorageElement for i32 {
fn zero() -> Self { 0 }
fn one() -> Self { 1 }
}
impl NumberLike for i32 {
fn from_f32(v: f32) -> Self { f32_round_compat(v) as i32 }
fn to_f32(self) -> f32 { self as f32 }
fn from_f64(v: f64) -> Self { f64_round_compat(v) as i32 }
fn to_f64(self) -> f64 { self as f64 }
}
impl StorageElement for u32 {
fn zero() -> Self { 0 }
fn one() -> Self { 1 }
}
impl NumberLike for u32 {
fn from_f32(v: f32) -> Self {
let r = f32_round_compat(v);
if r < 0.0 {
0
} else {
r as u32
}
}
fn to_f32(self) -> f32 { self as f32 }
fn from_f64(v: f64) -> Self {
let r = f64_round_compat(v);
if r < 0.0 {
0
} else {
r as u32
}
}
fn to_f64(self) -> f64 { self as f64 }
}
impl StorageElement for i16 {
fn zero() -> Self { 0 }
fn one() -> Self { 1 }
}
impl NumberLike for i16 {
fn from_f32(v: f32) -> Self { f32_round_compat(v) as i16 }
fn to_f32(self) -> f32 { self as f32 }
fn from_f64(v: f64) -> Self { f64_round_compat(v) as i16 }
fn to_f64(self) -> f64 { self as f64 }
}
impl StorageElement for u16 {
fn zero() -> Self { 0 }
fn one() -> Self { 1 }
}
impl NumberLike for u16 {
fn from_f32(v: f32) -> Self {
let r = f32_round_compat(v);
if r < 0.0 {
0
} else {
r as u16
}
}
fn to_f32(self) -> f32 { self as f32 }
fn from_f64(v: f64) -> Self {
let r = f64_round_compat(v);
if r < 0.0 {
0
} else {
r as u16
}
}
fn to_f64(self) -> f64 { self as f64 }
}
impl StorageElement for i64 {
fn zero() -> Self { 0 }
fn one() -> Self { 1 }
}
impl NumberLike for i64 {
fn from_f32(v: f32) -> Self { f32_round_compat(v) as i64 }
fn to_f32(self) -> f32 { self as f32 }
fn from_f64(v: f64) -> Self { f64_round_compat(v) as i64 }
fn to_f64(self) -> f64 { self as f64 }
}
impl StorageElement for u64 {
fn zero() -> Self { 0 }
fn one() -> Self { 1 }
}
impl NumberLike for u64 {
fn from_f32(v: f32) -> Self {
let r = f32_round_compat(v);
if r < 0.0 {
0
} else {
r as u64
}
}
fn to_f32(self) -> f32 { self as f32 }
fn from_f64(v: f64) -> Self {
let r = f64_round_compat(v);
if r < 0.0 {
0
} else {
r as u64
}
}
fn to_f64(self) -> f64 { self as f64 }
}
impl StorageElement for usize {
fn zero() -> Self { 0 }
fn one() -> Self { 1 }
}
impl StorageElement for i4x2 {
fn zero() -> Self { i4x2::from((0i8, 0i8)) }
fn one() -> Self { i4x2::from((1i8, 1i8)) }
fn dimensions_per_value() -> usize { 2 }
}
impl NumberLike for i4x2 {
fn from_f32(v: f32) -> Self {
let r = f32_round_compat(v) as i8;
i4x2::from((r, r))
}
fn to_f32(self) -> f32 {
let (a, _) = self.into();
a as f32
}
}
impl StorageElement for u4x2 {
fn zero() -> Self { u4x2::from((0u8, 0u8)) }
fn one() -> Self { u4x2::from((1u8, 1u8)) }
fn dimensions_per_value() -> usize { 2 }
}
impl NumberLike for u4x2 {
fn from_f32(v: f32) -> Self {
let r = f32_round_compat(v);
let r = if r < 0.0 { 0u8 } else { r as u8 };
u4x2::from((r, r))
}
fn to_f32(self) -> f32 {
let (a, _) = self.into();
a as f32
}
}
impl StorageElement for u1x8 {
fn zero() -> Self { u1x8(0x00) }
fn one() -> Self { u1x8(0xFF) }
fn dimensions_per_value() -> usize { 8 }
}
impl NumberLike for u1x8 {
fn from_f32(v: f32) -> Self {
if v > 0.0 {
u1x8(0xFF)
} else {
u1x8(0x00)
}
}
fn to_f32(self) -> f32 { self.0.count_ones() as f32 }
}
#[inline(always)]
fn complex_mul_components<T>(lhs_re: T, lhs_im: T, rhs_re: T, rhs_im: T) -> (T, T)
where
T: Copy + core::ops::Add<Output = T> + core::ops::Sub<Output = T> + core::ops::Mul<Output = T>,
{
(
lhs_re * rhs_re - lhs_im * rhs_im,
lhs_re * rhs_im + lhs_im * rhs_re,
)
}
#[inline(always)]
fn complex_div_components<T>(lhs_re: T, lhs_im: T, rhs_re: T, rhs_im: T) -> (T, T)
where
T: Copy
+ core::ops::Add<Output = T>
+ core::ops::Sub<Output = T>
+ core::ops::Mul<Output = T>
+ core::ops::Div<Output = T>,
{
let denom = rhs_re * rhs_re + rhs_im * rhs_im;
(
(lhs_re * rhs_re + lhs_im * rhs_im) / denom,
(lhs_im * rhs_re - lhs_re * rhs_im) / denom,
)
}
#[doc(hidden)]
pub trait ComplexComponent:
NumberLike
+ core::ops::Add<Output = Self>
+ core::ops::Sub<Output = Self>
+ core::ops::Mul<Output = Self>
+ core::ops::Div<Output = Self>
+ core::ops::Neg<Output = Self>
{
type Norm: Copy + core::ops::Add<Output = Self::Norm>;
fn from_norm(value: Self::Norm) -> Self;
fn norm_component(self) -> Self::Norm;
}
impl ComplexComponent for f16 {
type Norm = f32;
fn from_norm(value: Self::Norm) -> Self { Self::from_f32(value) }
fn norm_component(self) -> Self::Norm { self.to_f32() * self.to_f32() }
}
impl ComplexComponent for bf16 {
type Norm = f32;
fn from_norm(value: Self::Norm) -> Self { Self::from_f32(value) }
fn norm_component(self) -> Self::Norm { self.to_f32() * self.to_f32() }
}
impl ComplexComponent for f32 {
type Norm = f32;
fn from_norm(value: Self::Norm) -> Self { value }
fn norm_component(self) -> Self::Norm { self * self }
}
impl ComplexComponent for f64 {
type Norm = f64;
fn from_norm(value: Self::Norm) -> Self { value }
fn norm_component(self) -> Self::Norm { self * self }
}
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct complex<T> {
pub re: T,
pub im: T,
}
impl<T> complex<T> {
pub const fn from_real_imag(re: T, im: T) -> Self { Self { re, im } }
pub const fn to_real_imag(self) -> (T, T)
where
T: Copy,
{
(self.re, self.im)
}
}
impl<T: ComplexComponent> complex<T> {
pub fn conj(self) -> Self {
Self {
re: self.re,
im: -self.im,
}
}
pub fn norm_sqr(self) -> T::Norm { self.re.norm_component() + self.im.norm_component() }
}
pub type f16c = complex<f16>;
pub type bf16c = complex<bf16>;
pub type f32c = complex<f32>;
pub type f64c = complex<f64>;
impl<T: ComplexComponent> From<f32> for complex<T> {
fn from(value: f32) -> Self {
Self {
re: T::from_f32(value),
im: T::zero(),
}
}
}
impl<T: ComplexComponent> core::ops::Add for complex<T> {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
Self {
re: self.re + rhs.re,
im: self.im + rhs.im,
}
}
}
impl<T: ComplexComponent> core::ops::Sub for complex<T> {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
Self {
re: self.re - rhs.re,
im: self.im - rhs.im,
}
}
}
impl<T: ComplexComponent> core::ops::Mul for complex<T> {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
let (re, im) = complex_mul_components(self.re, self.im, rhs.re, rhs.im);
Self { re, im }
}
}
impl<T: ComplexComponent> core::ops::Div for complex<T> {
type Output = Self;
fn div(self, rhs: Self) -> Self::Output {
let (re, im) = complex_div_components(self.re, self.im, rhs.re, rhs.im);
Self { re, im }
}
}
impl<T: ComplexComponent> core::ops::Neg for complex<T> {
type Output = Self;
fn neg(self) -> Self::Output {
Self {
re: -self.re,
im: -self.im,
}
}
}
impl<T: NumberLike> core::fmt::Display for complex<T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let re = self.re.to_f32();
let im = self.im.to_f32();
if let Some(prec) = f.precision() {
if im < 0.0 {
write!(f, "{:.prec$} - {:.prec$}i", re, -im)
} else {
write!(f, "{:.prec$} + {:.prec$}i", re, im)
}
} else if im < 0.0 {
write!(f, "{} - {}i", re, -im)
} else {
write!(f, "{} + {}i", re, im)
}
}
}
impl<T: ComplexComponent> StorageElement for complex<T> {
fn zero() -> Self {
Self {
re: T::zero(),
im: T::zero(),
}
}
fn one() -> Self {
Self {
re: T::one(),
im: T::zero(),
}
}
}
impl<T: ComplexComponent> NumberLike for complex<T> {
fn from_f32(v: f32) -> Self { Self::from(v) }
fn to_f32(self) -> f32 { self.re.to_f32() }
fn from_f64(v: f64) -> Self {
Self {
re: T::from_f64(v),
im: T::zero(),
}
}
fn to_f64(self) -> f64 { self.re.to_f64() }
fn abs(self) -> Self {
Self {
re: T::from_norm(self.norm_sqr()),
im: T::zero(),
}
}
fn is_nan(self) -> bool { self.re.is_nan() || self.im.is_nan() }
fn is_finite(self) -> bool { self.re.is_finite() && self.im.is_finite() }
fn is_infinite(self) -> bool { self.re.is_infinite() || self.im.is_infinite() }
fn has_infinity() -> bool { T::has_infinity() }
fn has_nan() -> bool { T::has_nan() }
fn has_subnormals() -> bool { T::has_subnormals() }
fn max_value() -> f32 { T::max_value() }
fn min_positive() -> f32 { T::min_positive() }
}
impl<T: ComplexComponent> FloatConvertible for complex<T> {
type DimScalar = complex<T>;
type Unpacked = [complex<T>; 1];
#[inline(always)]
fn unpack(self) -> [complex<T>; 1] { [self] }
#[inline(always)]
fn pack(dims: [complex<T>; 1]) -> Self { dims[0] }
}
pub trait FloatConvertible: NumberLike {
type DimScalar: Copy + Default + NumberLike;
type Unpacked: AsRef<[Self::DimScalar]> + AsMut<[Self::DimScalar]> + Copy + Default;
fn unpack(self) -> Self::Unpacked;
fn pack(dims: Self::Unpacked) -> Self;
}
impl FloatConvertible for f32 {
type DimScalar = f32;
type Unpacked = [f32; 1];
#[inline(always)]
fn unpack(self) -> [f32; 1] { [self] }
#[inline(always)]
fn pack(dims: [f32; 1]) -> Self { dims[0] }
}
impl FloatConvertible for f64 {
type DimScalar = f64;
type Unpacked = [f64; 1];
#[inline(always)]
fn unpack(self) -> [f64; 1] { [self] }
#[inline(always)]
fn pack(dims: [f64; 1]) -> Self { dims[0] }
}
impl FloatConvertible for f16 {
type DimScalar = f16;
type Unpacked = [f16; 1];
#[inline(always)]
fn unpack(self) -> [f16; 1] { [self] }
#[inline(always)]
fn pack(dims: [f16; 1]) -> Self { dims[0] }
}
impl FloatConvertible for bf16 {
type DimScalar = bf16;
type Unpacked = [bf16; 1];
#[inline(always)]
fn unpack(self) -> [bf16; 1] { [self] }
#[inline(always)]
fn pack(dims: [bf16; 1]) -> Self { dims[0] }
}
impl FloatConvertible for e4m3 {
type DimScalar = e4m3;
type Unpacked = [e4m3; 1];
#[inline(always)]
fn unpack(self) -> [e4m3; 1] { [self] }
#[inline(always)]
fn pack(dims: [e4m3; 1]) -> Self { dims[0] }
}
impl FloatConvertible for e5m2 {
type DimScalar = e5m2;
type Unpacked = [e5m2; 1];
#[inline(always)]
fn unpack(self) -> [e5m2; 1] { [self] }
#[inline(always)]
fn pack(dims: [e5m2; 1]) -> Self { dims[0] }
}
impl FloatConvertible for e2m3 {
type DimScalar = e2m3;
type Unpacked = [e2m3; 1];
#[inline(always)]
fn unpack(self) -> [e2m3; 1] { [self] }
#[inline(always)]
fn pack(dims: [e2m3; 1]) -> Self { dims[0] }
}
impl FloatConvertible for e3m2 {
type DimScalar = e3m2;
type Unpacked = [e3m2; 1];
#[inline(always)]
fn unpack(self) -> [e3m2; 1] { [self] }
#[inline(always)]
fn pack(dims: [e3m2; 1]) -> Self { dims[0] }
}
impl FloatConvertible for i8 {
type DimScalar = i8;
type Unpacked = [i8; 1];
#[inline(always)]
fn unpack(self) -> [i8; 1] { [self] }
#[inline(always)]
fn pack(dims: [i8; 1]) -> Self { dims[0] }
}
impl FloatConvertible for u8 {
type DimScalar = u8;
type Unpacked = [u8; 1];
#[inline(always)]
fn unpack(self) -> [u8; 1] { [self] }
#[inline(always)]
fn pack(dims: [u8; 1]) -> Self { dims[0] }
}
impl FloatConvertible for i16 {
type DimScalar = i16;
type Unpacked = [i16; 1];
#[inline(always)]
fn unpack(self) -> [i16; 1] { [self] }
#[inline(always)]
fn pack(dims: [i16; 1]) -> Self { dims[0] }
}
impl FloatConvertible for u16 {
type DimScalar = u16;
type Unpacked = [u16; 1];
#[inline(always)]
fn unpack(self) -> [u16; 1] { [self] }
#[inline(always)]
fn pack(dims: [u16; 1]) -> Self { dims[0] }
}
impl FloatConvertible for i32 {
type DimScalar = i32;
type Unpacked = [i32; 1];
#[inline(always)]
fn unpack(self) -> [i32; 1] { [self] }
#[inline(always)]
fn pack(dims: [i32; 1]) -> Self { dims[0] }
}
impl FloatConvertible for u32 {
type DimScalar = u32;
type Unpacked = [u32; 1];
#[inline(always)]
fn unpack(self) -> [u32; 1] { [self] }
#[inline(always)]
fn pack(dims: [u32; 1]) -> Self { dims[0] }
}
impl FloatConvertible for i64 {
type DimScalar = i64;
type Unpacked = [i64; 1];
#[inline(always)]
fn unpack(self) -> [i64; 1] { [self] }
#[inline(always)]
fn pack(dims: [i64; 1]) -> Self { dims[0] }
}
impl FloatConvertible for u64 {
type DimScalar = u64;
type Unpacked = [u64; 1];
#[inline(always)]
fn unpack(self) -> [u64; 1] { [self] }
#[inline(always)]
fn pack(dims: [u64; 1]) -> Self { dims[0] }
}
impl FloatConvertible for i4x2 {
type DimScalar = i8;
type Unpacked = [i8; 2];
#[inline(always)]
fn unpack(self) -> [i8; 2] {
let (lo, hi) = self.to_i8s();
[lo, hi]
}
#[inline(always)]
fn pack(dims: [i8; 2]) -> Self { i4x2::from_i8s(dims[0], dims[1]) }
}
impl FloatConvertible for u4x2 {
type DimScalar = u8;
type Unpacked = [u8; 2];
#[inline(always)]
fn unpack(self) -> [u8; 2] {
let (lo, hi) = self.to_u8s();
[lo, hi]
}
#[inline(always)]
fn pack(dims: [u8; 2]) -> Self { u4x2::from_u8s(dims[0], dims[1]) }
}
impl FloatConvertible for u1x8 {
type DimScalar = u8;
type Unpacked = [u8; 8];
#[inline(always)]
fn unpack(self) -> [u8; 8] {
let mut out = [0u8; 8];
for (i, slot) in out.iter_mut().enumerate() {
*slot = (self.0 >> i) & 1;
}
out
}
#[inline(always)]
fn pack(dims: [u8; 8]) -> Self {
let mut byte = 0u8;
for (i, &dim) in dims.iter().enumerate() {
if dim != 0 {
byte |= 1 << i;
}
}
u1x8(byte)
}
}
pub struct DimRef<'a, T: FloatConvertible> {
value: T::DimScalar,
_marker: core::marker::PhantomData<&'a T>,
}
impl<'a, T: FloatConvertible> DimRef<'a, T> {
#[inline]
pub fn new(value: T::DimScalar) -> Self {
Self {
value,
_marker: core::marker::PhantomData,
}
}
}
impl<T: FloatConvertible> Copy for DimRef<'_, T> {}
impl<T: FloatConvertible> Clone for DimRef<'_, T> {
#[inline]
fn clone(&self) -> Self { *self }
}
impl<T: FloatConvertible> core::ops::Deref for DimRef<'_, T> {
type Target = T::DimScalar;
#[inline]
fn deref(&self) -> &T::DimScalar { &self.value }
}
impl<T: FloatConvertible> PartialEq for DimRef<'_, T>
where
T::DimScalar: PartialEq,
{
#[inline]
fn eq(&self, other: &Self) -> bool { self.value == other.value }
}
impl<T: FloatConvertible> PartialOrd for DimRef<'_, T>
where
T::DimScalar: PartialOrd,
{
#[inline]
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
self.value.partial_cmp(&other.value)
}
}
impl<T: FloatConvertible> core::fmt::Debug for DimRef<'_, T>
where
T::DimScalar: core::fmt::Debug,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.value.fmt(f) }
}
impl<T: FloatConvertible> core::fmt::Display for DimRef<'_, T>
where
T::DimScalar: core::fmt::Display,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.value.fmt(f) }
}
pub struct DimMut<'a, T: FloatConvertible> {
ptr: *mut T,
sub_index: usize,
value: T::DimScalar,
_marker: core::marker::PhantomData<&'a mut T>,
}
impl<'a, T: FloatConvertible> DimMut<'a, T> {
#[inline]
pub unsafe fn new(ptr: *mut T, sub_index: usize, value: T::DimScalar) -> Self {
Self {
ptr,
sub_index,
value,
_marker: core::marker::PhantomData,
}
}
}
impl<T: FloatConvertible> core::ops::Deref for DimMut<'_, T> {
type Target = T::DimScalar;
#[inline]
fn deref(&self) -> &T::DimScalar { &self.value }
}
impl<T: FloatConvertible> core::ops::DerefMut for DimMut<'_, T> {
#[inline]
fn deref_mut(&mut self) -> &mut T::DimScalar { &mut self.value }
}
impl<T: FloatConvertible> Drop for DimMut<'_, T> {
fn drop(&mut self) {
let mut unpacked = unsafe { *self.ptr }.unpack();
unpacked.as_mut()[self.sub_index] = self.value;
unsafe { self.ptr.write(T::pack(unpacked)) };
}
}
impl<T: FloatConvertible> core::fmt::Debug for DimMut<'_, T>
where
T::DimScalar: core::fmt::Debug,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.value.fmt(f) }
}
impl<T: FloatConvertible> core::fmt::Display for DimMut<'_, T>
where
T::DimScalar: core::fmt::Display,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.value.fmt(f) }
}
#[cfg(test)]
pub(crate) trait TestableType: FloatLike {
fn atol() -> f64;
fn rtol() -> f64;
}
#[cfg(test)]
pub(crate) fn assert_close(actual: f64, expected: f64, atol: f64, rtol: f64, msg: &str) {
let tol = atol + rtol * expected.abs();
assert!(
(actual - expected).abs() <= tol,
"{}: expected {} but got {} (atol={}, rtol={}, tol={})",
msg,
expected,
actual,
atol,
rtol,
tol
);
}
#[cfg(test)]
impl TestableType for f32 {
fn atol() -> f64 { 1e-4 }
fn rtol() -> f64 { 1e-4 }
}
#[cfg(test)]
impl TestableType for f64 {
fn atol() -> f64 { 1e-9 }
fn rtol() -> f64 { 1e-9 }
}
#[cfg(test)]
impl TestableType for f16 {
fn atol() -> f64 { 0.05 }
fn rtol() -> f64 { 0.05 }
}
#[cfg(test)]
impl TestableType for bf16 {
fn atol() -> f64 { 0.1 }
fn rtol() -> f64 { 0.1 }
}
#[cfg(test)]
impl TestableType for e4m3 {
fn atol() -> f64 { 0.5 }
fn rtol() -> f64 { 0.1 }
}
#[cfg(test)]
impl TestableType for e5m2 {
fn atol() -> f64 { 1.0 }
fn rtol() -> f64 { 0.1 }
}
#[cfg(test)]
impl TestableType for e2m3 {
fn atol() -> f64 { 0.5 }
fn rtol() -> f64 { 0.1 }
}
#[cfg(test)]
impl TestableType for e3m2 {
fn atol() -> f64 { 0.5 }
fn rtol() -> f64 { 0.1 }
}
#[cfg(test)]
impl TestableType for i8 {
fn atol() -> f64 { 1.0 }
fn rtol() -> f64 { 0.0 }
}
#[cfg(test)]
impl TestableType for u8 {
fn atol() -> f64 { 1.0 }
fn rtol() -> f64 { 0.0 }
}
#[cfg(test)]
impl TestableType for i32 {
fn atol() -> f64 { 1.0 }
fn rtol() -> f64 { 0.0 }
}
#[cfg(test)]
impl TestableType for u32 {
fn atol() -> f64 { 1.0 }
fn rtol() -> f64 { 0.0 }
}
#[cfg(test)]
impl TestableType for i16 {
fn atol() -> f64 { 1.0 }
fn rtol() -> f64 { 0.0 }
}
#[cfg(test)]
impl TestableType for u16 {
fn atol() -> f64 { 1.0 }
fn rtol() -> f64 { 0.0 }
}
#[cfg(test)]
impl TestableType for i64 {
fn atol() -> f64 { 1.0 }
fn rtol() -> f64 { 0.0 }
}
#[cfg(test)]
impl TestableType for u64 {
fn atol() -> f64 { 1.0 }
fn rtol() -> f64 { 0.0 }
}
#[cfg(test)]
impl TestableType for i4x2 {
fn atol() -> f64 { 1.0 }
fn rtol() -> f64 { 0.0 }
}
#[cfg(test)]
impl TestableType for u4x2 {
fn atol() -> f64 { 1.0 }
fn rtol() -> f64 { 0.0 }
}
#[cfg(test)]
impl TestableType for u1x8 {
fn atol() -> f64 { 0.0 }
fn rtol() -> f64 { 0.0 }
}
#[cfg(test)]
mod tests {
use super::*;
fn assert_scalar_roundtrip<T: FloatLike>(original: f32, abs_tol: f32, rel_tol: f32) {
let converted = T::from_f32(original);
let roundtrip = NumberLike::to_f32(converted);
if original == 0.0 {
assert_eq!(roundtrip, 0.0, "Zero should roundtrip exactly");
return;
}
let abs_error = f32_abs_compat(roundtrip - original);
let rel_error = abs_error / f32_abs_compat(original);
assert!(
abs_error <= abs_tol || rel_error <= rel_tol,
"Roundtrip failed for {}: got {} (abs_err={:.6}, rel_err={:.6})",
original,
roundtrip,
abs_error,
rel_error
);
}
fn assert_scalar_almost_equal<T: FloatLike>(
actual: f32,
expected: f32,
abs_tol: f32,
rel_tol: f32,
context: &str,
) {
let abs_error = f32_abs_compat(actual - expected);
let rel_error = if expected != 0.0 {
abs_error / f32_abs_compat(expected)
} else {
abs_error
};
assert!(
abs_error <= abs_tol || rel_error <= rel_tol,
"{}: expected {} but got {} (abs_err={:.6}, rel_err={:.6})",
context,
expected,
actual,
abs_error,
rel_error
);
}
fn check_arithmetic<T>(a_val: f32, b_val: f32, abs_tol: f32, rel_tol: f32)
where
T: FloatLike
+ PartialOrd
+ PartialEq
+ core::ops::Add<Output = T>
+ core::ops::Sub<Output = T>
+ core::ops::Mul<Output = T>
+ core::ops::Div<Output = T>
+ core::ops::Neg<Output = T>,
{
let a = T::from_f32(a_val);
let b = T::from_f32(b_val);
assert_scalar_almost_equal::<T>(
NumberLike::to_f32(a + b),
a_val + b_val,
abs_tol,
rel_tol,
"add",
);
assert_scalar_almost_equal::<T>(
NumberLike::to_f32(a - b),
a_val - b_val,
abs_tol,
rel_tol,
"sub",
);
assert_scalar_almost_equal::<T>(
NumberLike::to_f32(a * b),
a_val * b_val,
abs_tol,
rel_tol,
"mul",
);
assert_scalar_almost_equal::<T>(
NumberLike::to_f32(a / b),
a_val / b_val,
abs_tol,
rel_tol,
"div",
);
assert_scalar_almost_equal::<T>(NumberLike::to_f32(-a), -a_val, abs_tol, rel_tol, "neg");
assert_eq!(NumberLike::to_f32(T::zero()), 0.0);
assert_scalar_almost_equal::<T>(NumberLike::to_f32(T::one()), 1.0, abs_tol, rel_tol, "ONE");
assert_scalar_almost_equal::<T>(
NumberLike::to_f32(T::from_f32(-1.0)),
-1.0,
abs_tol,
rel_tol,
"NEG_ONE",
);
assert!(a > b);
assert!(a == a);
assert_scalar_almost_equal::<T>(
NumberLike::to_f32(NumberLike::abs(-a)),
a_val,
abs_tol,
rel_tol,
"abs",
);
assert!(NumberLike::is_finite(a));
assert!(!NumberLike::is_nan(a));
if T::has_infinity() {
assert!(!NumberLike::is_infinite(a));
}
}
fn check_roundtrip<T: FloatLike>(values: &[f32], abs_tol: f32, rel_tol: f32) {
for &v in values {
assert_scalar_roundtrip::<T>(v, abs_tol, rel_tol);
}
}
fn check_edge_cases<T: FloatLike>() {
assert_eq!(NumberLike::to_f32(T::from_f32(0.0)), 0.0);
assert_eq!(NumberLike::to_f32(T::from_f32(-0.0)), 0.0);
if T::has_infinity() {
assert!(NumberLike::to_f32(T::from_f32(f32::INFINITY)).is_infinite());
assert!(NumberLike::to_f32(T::from_f32(f32::NEG_INFINITY)).is_infinite());
} else {
assert!(!NumberLike::to_f32(T::from_f32(f32::INFINITY)).is_infinite());
}
if T::has_nan() {
assert!(NumberLike::to_f32(T::from_f32(f32::NAN)).is_nan());
} else {
assert!(!NumberLike::to_f32(T::from_f32(f32::NAN)).is_nan());
}
let big = T::max_value() * 10.0;
let overflow = T::from_f32(big);
if T::has_infinity() {
assert!(
NumberLike::to_f32(overflow).is_infinite()
|| NumberLike::to_f32(overflow) >= T::max_value()
);
} else {
let v = NumberLike::to_f32(overflow);
assert!(!v.is_infinite() && !v.is_nan());
assert!(v <= T::max_value());
let neg = NumberLike::to_f32(T::from_f32(-big));
assert!(!neg.is_infinite() && !neg.is_nan());
assert!(neg >= -T::max_value());
}
}
fn check_subnormals<T: FloatLike>(values: &[f32], upper_bound: f32) {
for &val in values {
let roundtrip = NumberLike::to_f32(T::from_f32(val));
assert!(
roundtrip >= 0.0 && roundtrip < upper_bound,
"{} subnormal test failed for {}: got {}",
core::any::type_name::<T>(),
val,
roundtrip
);
}
}
#[test]
fn arithmetic_ieee_halfs() {
check_arithmetic::<f16>(3.5, 2.0, 0.002, 0.001);
check_arithmetic::<bf16>(3.5, 2.0, 0.016, 0.008);
}
#[test]
fn arithmetic_minifloats() {
check_arithmetic::<e4m3>(2.0, 1.5, 0.25, 0.125);
check_arithmetic::<e5m2>(2.0, 1.5, 0.5, 0.25);
check_arithmetic::<e2m3>(2.0, 1.5, 0.25, 0.125);
check_arithmetic::<e3m2>(4.0, 2.0, 0.5, 0.25);
}
#[test]
fn roundtrip() {
check_roundtrip::<f16>(
&[
0.0, 1.0, -1.0, 0.5, 2.0, 4.0, 8.0, 16.0, 100.0, 1000.0, 10000.0, 0.001, 0.0001,
0.00001, -100.0, -1000.0,
],
0.002,
0.001,
);
check_roundtrip::<bf16>(
&[
0.0, 1.0, -1.0, 0.5, 2.0, 10.0, 100.0, 1000.0, 1e6, 0.001, 1e-6, -100.0, -1000.0,
],
0.016,
0.008,
);
check_roundtrip::<e4m3>(
&[0.0, 1.0, -1.0, 0.5, 2.0, 4.0, 8.0, 16.0, 64.0, 128.0, 224.0],
0.25,
0.125,
);
check_roundtrip::<e5m2>(
&[
0.0, 1.0, -1.0, 0.5, 2.0, 4.0, 8.0, 16.0, 64.0, 256.0, 1024.0,
],
0.5,
0.25,
);
check_roundtrip::<e2m3>(
&[
0.0, 1.0, -1.0, 0.5, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 7.5, 0.25, -0.25, 0.125, -0.125,
-4.0, -6.0, -7.5,
],
0.25,
0.125,
);
check_roundtrip::<e3m2>(
&[
0.0, 1.0, -1.0, 0.5, 2.0, 4.0, 8.0, 16.0, 20.0, 24.0, 28.0, 0.25, -0.25, -20.0,
-28.0,
],
0.5,
0.25,
);
}
#[test]
fn edge_cases() {
check_edge_cases::<f16>();
check_edge_cases::<bf16>();
check_edge_cases::<e4m3>();
check_edge_cases::<e5m2>();
check_edge_cases::<e2m3>();
check_edge_cases::<e3m2>();
}
#[test]
fn subnormals() {
check_subnormals::<f16>(&[1e-5, 1e-6, 1e-7, 5e-6, 5e-7], 1e-4);
check_subnormals::<bf16>(&[1e-39, 1e-40, 1e-42], 1e-37);
check_subnormals::<e4m3>(&[0.001, 0.0005], 0.002);
check_subnormals::<e5m2>(&[0.00005, 0.00003, 0.00001], 0.0001);
check_subnormals::<e2m3>(&[0.03, 0.015], 0.07);
check_subnormals::<e3m2>(&[0.0625, 0.03], 0.15);
}
#[test]
fn half_crate_interop() {
use half::bf16 as HalfBF16;
use half::f16 as HalfF16;
for bits in 0u16..=u16::MAX {
let half_val = HalfF16::from_bits(bits).to_f32();
let nk_val = f16(bits).to_f32();
assert!(
half_val.to_bits() == nk_val.to_bits() || (half_val.is_nan() && nk_val.is_nan()),
"f16 mismatch at bits 0x{bits:04X}: half={half_val}, numkong={nk_val}"
);
}
for bits in 0u16..=u16::MAX {
let half_val = HalfBF16::from_bits(bits).to_f32();
let nk_val = bf16(bits).to_f32();
assert!(
half_val.to_bits() == nk_val.to_bits() || (half_val.is_nan() && nk_val.is_nan()),
"bf16 mismatch at bits 0x{bits:04X}: half={half_val}, numkong={nk_val}"
);
}
}
#[test]
fn is_close_exact() {
assert!(is_close(1.0, 1.0, 0.0, 0.0));
assert!(is_close(0.0, 0.0, 0.0, 0.0));
assert!(is_close(-5.0, -5.0, 0.0, 0.0));
}
#[test]
fn is_close_within_atol() {
assert!(is_close(1.0, 1.0 + 1e-7, 1e-6, 0.0));
assert!(!is_close(1.0, 1.0 + 1e-5, 1e-6, 0.0));
}
#[test]
fn is_close_within_rtol() {
assert!(is_close(100.0, 100.01, 0.0, 1e-3));
assert!(!is_close(100.0, 100.2, 0.0, 1e-3));
}
#[test]
fn is_close_nan_inf() {
assert!(!is_close(f64::NAN, f64::NAN, 1e-6, 1e-6));
assert!(!is_close(f64::NAN, 0.0, 1e-6, 1e-6));
assert!(!is_close(f64::INFINITY, f64::INFINITY, 0.0, 0.0));
assert!(!is_close(f64::INFINITY, f64::NEG_INFINITY, 0.0, 0.0));
}
}