/// IEEE 754 binary16 ("half precision") floating-point value stored as its raw bit pattern
/// (1 sign bit, 5 exponent bits, 10 fraction bits).
///
/// Equality and hashing are derived from the bit pattern, so `+0.0 != -0.0` while a NaN
/// compares equal to another NaN with identical bits. `Default` is `+0.0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Half(u16);
impl Half {
    /// Positive zero (`+0.0`).
    pub const ZERO: Half = Half(0);
    /// One (`1.0`).
    pub const ONE: Half = Half(0x3c00);
    /// Largest finite value (`65504.0`).
    pub const MAX: Half = Half(0x7bff);
    /// Smallest positive normal value (`2^-14`).
    pub const MIN_POSITIVE: Half = Half(0x0400);
    /// Positive infinity.
    pub const INFINITY: Half = Half(0x7c00);
    /// Negative infinity.
    pub const NEG_INFINITY: Half = Half(0xfc00);
    /// Canonical quiet NaN.
    pub const NAN: Half = Half(0x7e00);

    /// Reinterprets a raw binary16 bit pattern as a `Half`.
    pub const fn from_bits(bits: u16) -> Self {
        Half(bits)
    }

    /// Returns the raw binary16 bit pattern.
    pub const fn to_bits(self) -> u16 {
        self.0
    }

    /// Converts an `f32` to the nearest `Half`, rounding ties to even (the IEEE 754 default,
    /// and the same rule `BFloat16::from_f32` uses).
    ///
    /// - Finite values whose rounded magnitude exceeds [`Half::MAX`] become infinity
    ///   (e.g. `65520.0` rounds up to `+inf`).
    /// - Values below half of the smallest subnormal (`2^-24`) become a zero of the input's sign.
    /// - Infinities map to infinities; any NaN becomes a quiet NaN with the input's sign
    ///   (the payload is not preserved).
    pub fn from_f32(value: f32) -> Self {
        let bits = value.to_bits();
        // Sign bit moved straight into the binary16 sign position.
        let sign = ((bits >> 16) & 0x8000) as u16;
        let exp = ((bits >> 23) & 0xff) as i32;
        let frac = bits & 0x007f_ffff;

        if exp == 0xff {
            let special = if frac == 0 { 0x7c00 } else { 0x7e00 };
            return Half(sign | special);
        }

        // Re-bias the exponent from f32 (bias 127) to f16 (bias 15).
        let half_exp = exp - 127 + 15;
        if half_exp >= 31 {
            return Half(sign | 0x7c00);
        }
        if half_exp <= 0 {
            // Below 2^-25 every input is less than half of the smallest subnormal (2^-24),
            // so it rounds to zero. f32 zeros and subnormals (exp == 0) also land here.
            if half_exp < -10 {
                return Half(sign);
            }
            // f16 subnormal result: shift the 24-bit significand (implicit leading 1 made
            // explicit) so that one unit of the result equals 2^-24.
            let significand = frac | 0x0080_0000;
            let shift = (14 - half_exp) as u32; // 14..=24
            // A carry out of the 10 fraction bits yields 0x0400, which is exactly
            // MIN_POSITIVE, so no special case is needed.
            let magnitude = Self::shift_right_round_even(significand, shift);
            return Half(sign | magnitude as u16);
        }

        // Normal result: place the exponent directly above the 23 fraction bits and drop the
        // low 13 bits with rounding. A carry ripples into the exponent field; at the top of
        // the range it produces 0x7c00 (infinity), which is the correctly rounded result.
        let combined = ((half_exp as u32) << 23) | frac;
        let magnitude = Self::shift_right_round_even(combined, 13);
        Half(sign | magnitude as u16)
    }

    /// Returns `value >> shift` rounded to nearest, ties to even. `shift` must be in `1..=31`.
    fn shift_right_round_even(value: u32, shift: u32) -> u32 {
        let truncated = value >> shift;
        let remainder = value & ((1u32 << shift) - 1);
        let halfway = 1u32 << (shift - 1);
        if remainder > halfway || (remainder == halfway && truncated & 1 == 1) {
            truncated + 1
        } else {
            truncated
        }
    }

    /// Widens to `f32`. Finite values and infinities convert exactly; any NaN becomes the
    /// `f32` quiet NaN with the same sign (the payload is not preserved).
    pub fn to_f32(self) -> f32 {
        let bits = self.0 as u32;
        let sign = (bits >> 15) & 1;
        let exp = (bits >> 10) & 0x1f;
        let frac = bits & 0x3ff;
        if exp == 0 {
            if frac == 0 {
                return f32::from_bits(sign << 31);
            }
            // Subnormal: normalise by shifting until the implicit-1 position (bit 10) is
            // set, lowering the exponent once per shift.
            let mut frac = frac;
            let mut e = -14i32;
            while frac & 0x400 == 0 {
                frac <<= 1;
                e -= 1;
            }
            frac &= 0x3ff;
            let exp32 = (e + 127) as u32;
            let frac32 = frac << 13;
            return f32::from_bits((sign << 31) | (exp32 << 23) | frac32);
        }
        if exp == 0x1f {
            if frac == 0 {
                return f32::from_bits((sign << 31) | 0x7f800000);
            }
            return f32::from_bits((sign << 31) | 0x7fc00000);
        }
        // Normal: re-bias the exponent (15 -> 127) and left-align the 10 fraction bits.
        let exp32 = (exp as i32 - 15 + 127) as u32;
        let frac32 = frac << 13;
        f32::from_bits((sign << 31) | (exp32 << 23) | frac32)
    }

    /// Returns `true` for any NaN (all exponent bits set, non-zero fraction).
    pub fn is_nan(self) -> bool {
        (self.0 & 0x7c00) == 0x7c00 && (self.0 & 0x03ff) != 0
    }

    /// Returns `true` for `+inf` or `-inf`.
    pub fn is_infinite(self) -> bool {
        (self.0 & 0x7fff) == 0x7c00
    }

    /// Returns `true` when the value is neither infinite nor NaN.
    pub fn is_finite(self) -> bool {
        (self.0 & 0x7c00) != 0x7c00
    }

    /// Returns `true` for `+0.0` or `-0.0`.
    pub fn is_zero(self) -> bool {
        (self.0 & 0x7fff) == 0
    }
}
impl From<f32> for Half {
    /// Rounds `value` to the nearest `Half` via [`Half::from_f32`].
    fn from(value: f32) -> Self {
        Half::from_f32(value)
    }
}
impl From<Half> for f32 {
    /// Widens via [`Half::to_f32`].
    fn from(value: Half) -> Self {
        value.to_f32()
    }
}
impl From<f64> for Half {
    /// Converts an `f64` to `Half` with a single effective rounding step.
    ///
    /// The naive `Half::from_f32(value as f32)` rounds twice, once to `f32` and once to
    /// f16. That can misround values lying just off an f16 rounding boundary.
    /// Narrowing to `f32` with round-to-odd keeps a sticky bit for any discarded low-order
    /// bits. `f32` carries far more than two extra significand bits over f16, so the final
    /// rounding in `Half::from_f32` then behaves as if applied directly to `value`.
    fn from(value: f64) -> Self {
        // Narrows `value` to `f32`, rounding to odd: an inexact result is the bracketing
        // `f32` whose significand is odd.
        fn narrow_round_to_odd(value: f64) -> f32 {
            let narrowed = value as f32;
            // Exact results, NaN, and values that overflow `f32` (already far beyond the
            // f16 range, so infinity is correct) need no adjustment.
            if !narrowed.is_finite() || f64::from(narrowed) == value {
                return narrowed;
            }
            let bits = narrowed.to_bits();
            if bits & 1 == 1 {
                return narrowed;
            }
            // `value` lies strictly between `narrowed` and its neighbour on `value`'s side,
            // and that neighbour is odd. On the raw sign-magnitude bits, +1 moves one ulp
            // away from zero and -1 one ulp toward zero, for either sign.
            let toward_value = if value.abs() > f64::from(narrowed).abs() {
                bits + 1
            } else {
                bits - 1
            };
            f32::from_bits(toward_value)
        }
        Half::from_f32(narrow_round_to_odd(value))
    }
}
impl From<Half> for f64 {
    /// Widens via `f32`; every `f32` is exactly representable as `f64`.
    fn from(value: Half) -> Self {
        value.to_f32() as f64
    }
}
/// bfloat16 value stored as its raw bit pattern: the upper 16 bits of an IEEE 754 `f32`
/// (1 sign bit, 8 exponent bits, 7 fraction bits).
///
/// Equality and hashing are derived from the bit pattern, so `+0.0 != -0.0` while a NaN
/// compares equal to another NaN with identical bits. `Default` is `+0.0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct BFloat16(u16);
impl BFloat16 {
    /// Positive zero (`+0.0`).
    pub const ZERO: BFloat16 = BFloat16(0);
    /// One (`1.0`).
    pub const ONE: BFloat16 = BFloat16(0x3f80);
    /// Largest finite value (`(2 - 2^-7) * 2^127`, about `3.39e38`).
    pub const MAX: BFloat16 = BFloat16(0x7f7f);
    /// Smallest positive normal value (`2^-126`).
    pub const MIN_POSITIVE: BFloat16 = BFloat16(0x0080);
    /// Positive infinity.
    pub const INFINITY: BFloat16 = BFloat16(0x7f80);
    /// Negative infinity.
    pub const NEG_INFINITY: BFloat16 = BFloat16(0xff80);
    /// Canonical quiet NaN.
    pub const NAN: BFloat16 = BFloat16(0x7fc0);

    /// Reinterprets a raw bfloat16 bit pattern.
    #[inline]
    pub const fn from_bits(bits: u16) -> Self {
        BFloat16(bits)
    }

    /// Returns the raw bfloat16 bit pattern.
    #[inline]
    pub const fn to_bits(self) -> u16 {
        self.0
    }

    /// Converts an `f32` to the nearest bfloat16, rounding ties to even.
    ///
    /// Finite values that round past [`BFloat16::MAX`] become infinity. A NaN stays a NaN
    /// with its sign and the top 7 payload bits, with the quiet bit forced on.
    pub fn from_f32(value: f32) -> Self {
        let bits = value.to_bits();
        if (bits & 0x7f80_0000) == 0x7f80_0000 && (bits & 0x007f_ffff) != 0 {
            // Force the quiet bit: a NaN whose payload sits only in the low 16 bits would
            // otherwise truncate to the infinity pattern.
            return BFloat16(((bits >> 16) | 0x0040) as u16);
        }
        // Adding 0x7fff plus the lowest kept bit rounds to nearest, ties to even. A carry
        // into the exponent correctly turns the largest values into infinity.
        let rounding_bias = 0x0000_7fff_u32 + ((bits >> 16) & 1);
        BFloat16(((bits + rounding_bias) >> 16) as u16)
    }

    /// Widens to `f32`. This is exact: the bits become the upper half of the `f32`.
    #[inline]
    pub fn to_f32(self) -> f32 {
        f32::from_bits((self.0 as u32) << 16)
    }

    /// Converts an `f64` to the nearest bfloat16 with a single effective rounding step.
    ///
    /// The naive `from_f32(value as f32)` rounds twice and can misround values lying just
    /// above or below a bfloat16 midpoint. For example, `1 + 2^-8 + 2^-40` collapses to the
    /// tie `1 + 2^-8` in `f32` and then rounds down to `1.0` instead of up.
    /// Narrowing with round-to-odd keeps a sticky bit for the discarded bits, so the
    /// final round-to-nearest-even in [`BFloat16::from_f32`] is exact.
    #[inline]
    pub fn from_f64(value: f64) -> Self {
        // Narrows `value` to `f32`, rounding to odd: an inexact result is the bracketing
        // `f32` whose significand is odd.
        fn narrow_round_to_odd(value: f64) -> f32 {
            let narrowed = value as f32;
            // Exact results, NaN, and values that overflow `f32` (beyond bfloat16's own
            // overflow threshold, so infinity is correct) need no adjustment.
            if !narrowed.is_finite() || f64::from(narrowed) == value {
                return narrowed;
            }
            let bits = narrowed.to_bits();
            if bits & 1 == 1 {
                return narrowed;
            }
            // `value` lies strictly between `narrowed` and its odd neighbour on `value`'s
            // side. On the raw sign-magnitude bits, +1 moves away from zero and -1 toward it.
            let toward_value = if value.abs() > f64::from(narrowed).abs() {
                bits + 1
            } else {
                bits - 1
            };
            f32::from_bits(toward_value)
        }
        BFloat16::from_f32(narrow_round_to_odd(value))
    }

    /// Widens to `f64` (exact).
    #[inline]
    pub fn to_f64(self) -> f64 {
        self.to_f32() as f64
    }

    /// Returns `true` for any NaN (all exponent bits set, non-zero fraction).
    #[inline]
    pub fn is_nan(self) -> bool {
        (self.0 & 0x7f80) == 0x7f80 && (self.0 & 0x007f) != 0
    }

    /// Returns `true` for `+inf` or `-inf`.
    #[inline]
    pub fn is_infinite(self) -> bool {
        (self.0 & 0x7fff) == 0x7f80
    }

    /// Returns `true` when the value is neither infinite nor NaN.
    #[inline]
    pub fn is_finite(self) -> bool {
        (self.0 & 0x7f80) != 0x7f80
    }

    /// Returns `true` for `+0.0` or `-0.0`.
    #[inline]
    pub fn is_zero(self) -> bool {
        (self.0 & 0x7fff) == 0
    }

    /// Returns `true` when the exponent field is zero and the fraction is non-zero.
    #[inline]
    pub fn is_subnormal(self) -> bool {
        (self.0 & 0x7f80) == 0 && (self.0 & 0x007f) != 0
    }

    /// Clears the sign bit. This also applies to NaNs.
    #[inline]
    pub fn abs(self) -> Self {
        BFloat16(self.0 & 0x7fff)
    }

    /// Flips the sign bit. This also applies to zeros, infinities and NaNs.
    #[inline]
    pub fn neg(self) -> Self {
        BFloat16(self.0 ^ 0x8000)
    }
}
impl From<f32> for BFloat16 {
    /// Rounds to the nearest bfloat16 (see [`BFloat16::from_f32`]).
    #[inline]
    fn from(value: f32) -> Self {
        Self::from_f32(value)
    }
}
impl From<BFloat16> for f32 {
    /// Exact widening (see [`BFloat16::to_f32`]).
    #[inline]
    fn from(value: BFloat16) -> Self {
        BFloat16::to_f32(value)
    }
}
impl From<f64> for BFloat16 {
    /// Rounds to the nearest bfloat16 (see [`BFloat16::from_f64`]).
    #[inline]
    fn from(value: f64) -> Self {
        Self::from_f64(value)
    }
}
impl From<BFloat16> for f64 {
    /// Exact widening (see [`BFloat16::to_f64`]).
    #[inline]
    fn from(value: BFloat16) -> Self {
        BFloat16::to_f64(value)
    }
}
impl From<Half> for BFloat16 {
    /// Converts by widening the `Half` to `f32` and then rounding that to bfloat16.
    #[inline]
    fn from(value: Half) -> Self {
        let widened = Half::to_f32(value);
        Self::from_f32(widened)
    }
}
impl From<BFloat16> for Half {
    /// Converts by widening the bfloat16 to `f32` and then rounding that to `Half`.
    #[inline]
    fn from(value: BFloat16) -> Self {
        let widened = BFloat16::to_f32(value);
        Self::from_f32(widened)
    }
}
impl std::fmt::Display for BFloat16 {
    /// Formats the value exactly like the equivalent `f32`, honouring the caller's
    /// formatter options (precision, width, fill, alignment, sign).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Forward to f32's Display impl rather than `write!(f, "{}", ..)`: the latter starts
        // a fresh format spec and silently drops flags such as `{:.3}` or `{:>8}`.
        std::fmt::Display::fmt(&self.to_f32(), f)
    }
}
impl std::ops::Add for BFloat16 {
type Output = BFloat16;
fn add(self, rhs: BFloat16) -> BFloat16 {
BFloat16::from_f32(self.to_f32() + rhs.to_f32())
}
}
impl std::ops::Sub for BFloat16 {
type Output = BFloat16;
fn sub(self, rhs: BFloat16) -> BFloat16 {
BFloat16::from_f32(self.to_f32() - rhs.to_f32())
}
}
impl std::ops::Mul for BFloat16 {
type Output = BFloat16;
fn mul(self, rhs: BFloat16) -> BFloat16 {
BFloat16::from_f32(self.to_f32() * rhs.to_f32())
}
}
impl std::ops::Div for BFloat16 {
type Output = BFloat16;
fn div(self, rhs: BFloat16) -> BFloat16 {
BFloat16::from_f32(self.to_f32() / rhs.to_f32())
}
}
impl std::ops::Neg for BFloat16 {
type Output = BFloat16;
fn neg(self) -> BFloat16 {
self.neg()
}
}
impl PartialOrd for BFloat16 {
    /// Numeric (IEEE 754) ordering, made consistent with the bitwise derived `PartialEq`.
    ///
    /// `PartialOrd` requires `a == b` exactly when `partial_cmp(a, b) == Some(Equal)`.
    /// Equality compares bit patterns, so this ordering follows three rules:
    /// - Identical bits (including the same NaN) compare `Equal`.
    /// - The only numerically-equal pair with different bits, `-0.0` / `+0.0`, is ordered
    ///   `-0.0 < +0.0` (as in IEEE 754 totalOrder).
    /// - Every other comparison involving a NaN returns `None`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        use std::cmp::Ordering;
        if self.0 == other.0 {
            return Some(Ordering::Equal);
        }
        match self.to_f32().partial_cmp(&other.to_f32()) {
            // `to_f32` is an exact bit extension, so different bits that compare equal can
            // only be zeros of opposite sign; break the tie on the sign bit.
            Some(Ordering::Equal) => Some(if self.0 & 0x8000 != 0 {
                Ordering::Less
            } else {
                Ordering::Greater
            }),
            ordering => ordering,
        }
    }
}
/// Rounds every `f32` in `src` to bfloat16, preserving order.
pub fn f32_slice_to_bf16(src: &[f32]) -> Vec<BFloat16> {
    src.iter()
        .map(|&value| BFloat16::from_f32(value))
        .collect::<Vec<_>>()
}
pub fn bf16_slice_to_f32(src: &[BFloat16]) -> Vec<f32> {
src.iter().copied().map(BFloat16::to_f32).collect()
}
/// Rounds every `f64` in `src` to bfloat16 via `BFloat16::from_f64`, preserving order.
pub fn f64_slice_to_bf16(src: &[f64]) -> Vec<BFloat16> {
    src.iter()
        .map(|&value| BFloat16::from_f64(value))
        .collect::<Vec<_>>()
}
/// Returns the Euclidean (L2) norm of `grads`, i.e. `sqrt(sum(g^2))`.
///
/// Squares are accumulated in `f64`. A single bfloat16 value can be as large as ~3.39e38,
/// whose square overflows `f32`, and long `f32` running sums lose precision. The result
/// is narrowed to `f32` only at the end, so it is infinite only if the norm itself
/// exceeds `f32::MAX`. An empty slice yields `0.0`; any NaN element propagates.
pub fn bf16_grad_norm(grads: &[BFloat16]) -> f32 {
    grads
        .iter()
        .map(|g| {
            let v = f64::from(g.to_f32());
            v * v
        })
        .sum::<f64>()
        .sqrt() as f32
}