#[inline]
pub(crate) fn f16_bits_to_f32(h: u16) -> f32 {
    // Widen once, then split the binary16 fields.
    let bits = h as u32;
    let sign = (bits & 0x8000) << 16;
    let magnitude = bits & 0x7fff;

    // ±0: only the sign bit survives into the f32 encoding.
    if magnitude == 0 {
        return f32::from_bits(sign);
    }

    let exp = magnitude >> 10;
    let mant = magnitude & 0x03ff;

    let out = match exp {
        // Inf / NaN: max f32 exponent; the payload moves up 13 bits so a
        // nonzero f16 mantissa stays a nonzero (NaN) f32 mantissa.
        0x1f => sign | 0x7f80_0000 | (mant << 13),
        // Subnormal f16 (mant != 0 here): normalize by locating the top set
        // bit. A mantissa with MSB at position `top` has value
        // 1.frac × 2^(top-24), so the biased f32 exponent is top - 24 + 127.
        0 => {
            let top = 31 - mant.leading_zeros();
            let exp32 = top + 103;
            let mant32 = (mant << (23 - top)) & 0x007f_ffff;
            sign | (exp32 << 23) | mant32
        }
        // Normal: rebias exponent (bias 15 → 127) and widen the mantissa.
        _ => sign | ((exp + 112) << 23) | (mant << 13),
    };
    f32::from_bits(out)
}
/// Converts an `f32` to IEEE 754 binary16 bits, rounding to nearest even.
///
/// Handling per input class:
/// - ±Inf → ±0x7c00; NaN → NaN carrying the top 10 payload bits with the
///   quiet bit forced, so the mantissa is always nonzero.
/// - f32 zeros and subnormals (|v| < 2^-126) → signed zero.
/// - Overflow (rounds to ≥ 2^16) → signed infinity.
/// - 2^-25 ≤ |v| < 2^-14 → half subnormal, round-to-nearest-even.
/// - |v| < 2^-25 → signed zero (strictly below the 0 / 2^-24 midpoint).
#[inline]
pub(crate) fn f32_to_f16_bits(v: f32) -> u16 {
    let bits = v.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp_f32 = ((bits >> 23) & 0xff) as i32;
    let frac_f32 = bits & 0x7f_ffff;
    if exp_f32 == 0xff {
        if frac_f32 == 0 {
            // Infinity keeps only its sign.
            return sign | 0x7c00;
        }
        // NaN: keep the top payload bits; OR in the quiet bit so the result
        // stays NaN even if those bits happen to be all zero.
        let payload = (frac_f32 >> 13) as u16;
        return sign | 0x7c00 | payload | 0x0200;
    }
    if exp_f32 == 0 {
        // f32 zero/subnormal: magnitude < 2^-126, far below half's smallest
        // subnormal (2^-24) — flush to signed zero.
        return sign;
    }
    // Unbiased exponent rebased to half's bias (15).
    let target_exp = exp_f32 - 127 + 15;
    if target_exp >= 0x1f {
        // Too large for a finite half → infinity.
        return sign | 0x7c00;
    }
    if target_exp > 0 {
        // Normal half: drop 13 mantissa bits with round-to-nearest-even —
        // round up when the guard bit is set and either a lower (sticky) bit
        // is set or the kept mantissa is odd.
        let truncated = (frac_f32 >> 13) as u16;
        let round_bit = (frac_f32 >> 12) & 1;
        let sticky = (frac_f32 & 0xfff) != 0;
        let round_up = round_bit != 0 && (sticky || (truncated & 1) != 0);
        if !round_up {
            return sign | ((target_exp as u16) << 10) | truncated;
        }
        let new_mant = truncated + 1;
        if new_mant < 0x400 {
            return sign | ((target_exp as u16) << 10) | new_mant;
        }
        // Mantissa carry: bump the exponent, which may overflow to infinity.
        let new_exp = target_exp + 1;
        if new_exp >= 0x1f {
            return sign | 0x7c00;
        }
        return sign | ((new_exp as u16) << 10);
    }
    if target_exp < -10 {
        // |v| < 2^-25: strictly below the midpoint between zero and the
        // smallest half subnormal (2^-24), so round-to-nearest-even always
        // yields signed zero. target_exp == -11 means |v| = 1.f × 2^-26,
        // whose maximum is (2 - ulp) × 2^-26 < 2^-25 — the previous special
        // case returning `sign | 1` for that band over-rounded to ±2^-24,
        // disagreeing with IEEE 754 RNE and with `half::f16::from_f32`.
        return sign;
    }
    // Subnormal half: -10 <= target_exp <= 0, i.e. 2^-25 <= |v| < 2^-14.
    // Restore the implicit leading 1 and shift so the result is expressed in
    // units of 2^-24, rounding to nearest even.
    let shift = (14 - target_exp) as u32; // in 14..=24, so `shift - 1` is valid
    let full_mant = (1u32 << 23) | frac_f32;
    let round_bit = (full_mant >> (shift - 1)) & 1;
    let sticky_mask = (1u32 << (shift - 1)) - 1;
    let sticky = (full_mant & sticky_mask) != 0;
    let truncated = full_mant >> shift;
    let round_up = round_bit != 0 && (sticky || (truncated & 1) != 0);
    let result = if round_up { truncated + 1 } else { truncated };
    // A carry to result == 0x400 lands in exponent field 1, which encodes
    // exactly 2^-14 — the correct rounded value.
    sign | (result as u16)
}
/// Converts a slice of binary16 bit patterns into `f32` values, element-wise.
///
/// Dispatches through `archmage::incant!` — presumably selecting the
/// `cvt_f16_to_f32_v3` SIMD path on capable x86-64 CPUs and the
/// `cvt_f16_to_f32_scalar` fallback otherwise (candidate list `[v3, scalar]`).
///
/// # Panics
/// Panics if `src` and `dst` have different lengths.
pub(crate) fn f16_bits_to_f32_slice(src: &[u16], dst: &mut [f32]) {
    assert_eq!(
        src.len(),
        dst.len(),
        "f16_bits_to_f32_slice length mismatch"
    );
    // NOTE(review): the allow is presumably for cfg flags emitted by the
    // incant! expansion — confirm against the archmage docs.
    #[allow(unexpected_cfgs)]
    {
        archmage::incant!(cvt_f16_to_f32(src, dst), [v3, scalar])
    }
}
/// Converts a slice of `f32` values into binary16 bit patterns, element-wise.
///
/// Dispatches through `archmage::incant!` — presumably selecting the
/// `cvt_f32_to_f16_v3` SIMD path on capable x86-64 CPUs and the
/// `cvt_f32_to_f16_scalar` fallback otherwise (candidate list `[v3, scalar]`).
///
/// # Panics
/// Panics if `src` and `dst` have different lengths.
pub(crate) fn f32_to_f16_bits_slice(src: &[f32], dst: &mut [u16]) {
    assert_eq!(
        src.len(),
        dst.len(),
        "f32_to_f16_bits_slice length mismatch"
    );
    // NOTE(review): the allow is presumably for cfg flags emitted by the
    // incant! expansion — confirm against the archmage docs.
    #[allow(unexpected_cfgs)]
    {
        archmage::incant!(cvt_f32_to_f16(src, dst), [v3, scalar])
    }
}
/// Scalar fallback: widens each f16 bit pattern to `f32` one element at a
/// time. Converts `min(src.len(), dst.len())` elements (callers assert equal
/// lengths).
fn cvt_f16_to_f32_scalar(_tok: archmage::ScalarToken, src: &[u16], dst: &mut [f32]) {
    for (out, &half_bits) in dst.iter_mut().zip(src) {
        *out = f16_bits_to_f32(half_bits);
    }
}
/// Scalar fallback: narrows each `f32` to f16 bits one element at a time.
/// Converts `min(src.len(), dst.len())` elements (callers assert equal
/// lengths).
fn cvt_f32_to_f16_scalar(_tok: archmage::ScalarToken, src: &[f32], dst: &mut [u16]) {
    for (out, &value) in dst.iter_mut().zip(src) {
        *out = f32_to_f16_bits(value);
    }
}
/// x86-64-v3 path: widens f16 → f32 eight lanes at a time with the F16C
/// conversion (`_mm256_cvtph_ps`), then handles the remainder with the
/// scalar routine.
#[cfg(target_arch = "x86_64")]
#[archmage::arcane(import_intrinsics)]
fn cvt_f16_to_f32_v3(_tok: archmage::X64V3Token, src: &[u16], dst: &mut [f32]) {
    let n = src.len();
    let chunks = n / 8;
    for i in 0..chunks {
        // Fixed-size array refs — NOTE(review): assumes the wrappers that
        // `import_intrinsics` brings in accept `&[u16; 8]` / `&mut [f32; 8]`
        // instead of raw pointers; confirm against the archmage docs.
        let s_chunk: &[u16; 8] = (&src[i * 8..i * 8 + 8]).try_into().unwrap();
        let d_chunk: &mut [f32; 8] = (&mut dst[i * 8..i * 8 + 8]).try_into().unwrap();
        let packed = _mm_loadu_si128(s_chunk);
        let lanes = _mm256_cvtph_ps(packed);
        _mm256_storeu_ps(d_chunk, lanes);
    }
    // Tail (n % 8 elements): fall back to the scalar bit converter.
    let tail_start = chunks * 8;
    for i in tail_start..n {
        dst[i] = f16_bits_to_f32(src[i]);
    }
}
/// x86-64-v3 path: narrows f32 → f16 eight lanes at a time with the F16C
/// conversion (`_mm256_cvtps_ph`, round-to-nearest mode), then handles the
/// remainder with the scalar routine.
#[cfg(target_arch = "x86_64")]
#[archmage::arcane(import_intrinsics)]
fn cvt_f32_to_f16_v3(_tok: archmage::X64V3Token, src: &[f32], dst: &mut [u16]) {
    let n = src.len();
    let chunks = n / 8;
    for i in 0..chunks {
        // Fixed-size array refs — NOTE(review): assumes the wrappers that
        // `import_intrinsics` brings in accept `&[f32; 8]` / `&mut [u16; 8]`
        // instead of raw pointers; confirm against the archmage docs.
        let s_chunk: &[f32; 8] = (&src[i * 8..i * 8 + 8]).try_into().unwrap();
        let d_chunk: &mut [u16; 8] = (&mut dst[i * 8..i * 8 + 8]).try_into().unwrap();
        let lanes = _mm256_loadu_ps(s_chunk);
        // Round-to-nearest matches the scalar RNE rounding in f32_to_f16_bits.
        let packed = _mm256_cvtps_ph::<_MM_FROUND_TO_NEAREST_INT>(lanes);
        _mm_storeu_si128(d_chunk, packed);
    }
    // Tail (n % 8 elements): fall back to the scalar bit converter.
    let tail_start = chunks * 8;
    for i in tail_start..n {
        dst[i] = f32_to_f16_bits(src[i]);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips all 65536 f16 bit patterns through f32 and back.
    /// Non-NaN patterns must survive bit-exactly; NaNs only need to remain
    /// NaN in both directions (payloads may be altered/quieted).
    #[test]
    fn exhaustive_f16_f32_f16_roundtrip() {
        for bits in 0u16..=0xffff {
            let f = f16_bits_to_f32(bits);
            let back = f32_to_f16_bits(f);
            // f16 NaN: exponent field all ones and a nonzero mantissa.
            let is_f16_nan = (bits & 0x7c00) == 0x7c00 && (bits & 0x03ff) != 0;
            if is_f16_nan {
                assert!(
                    f.is_nan(),
                    "bits {:#06x} should produce f32 NaN, got {}",
                    bits,
                    f
                );
                let back_is_nan = (back & 0x7c00) == 0x7c00 && (back & 0x03ff) != 0;
                assert!(
                    back_is_nan,
                    "bits {:#06x} round-tripped to non-NaN {:#06x}",
                    bits, back
                );
                continue;
            }
            assert_eq!(
                bits, back,
                "bit-exact roundtrip failed for {:#06x} → {} → {:#06x}",
                bits, f, back
            );
        }
    }

    /// Compares f16→f32 widening against the reference `half` crate for every
    /// bit pattern; NaNs are compared only by NaN-ness (payloads may differ).
    #[test]
    fn f16_bits_to_f32_matches_half_crate() {
        for bits in 0u16..=0xffff {
            let ours = f16_bits_to_f32(bits);
            let theirs = half::f16::from_bits(bits).to_f32();
            if ours.is_nan() && theirs.is_nan() {
                continue;
            }
            assert_eq!(
                ours.to_bits(),
                theirs.to_bits(),
                "bits {:#06x}: ours={} ({:#010x}), theirs={} ({:#010x})",
                bits,
                ours,
                ours.to_bits(),
                theirs,
                theirs.to_bits()
            );
        }
    }

    /// Compares f32→f16 narrowing against the `half` crate on values derived
    /// from every f16: the exact value, the midpoint toward the next pattern
    /// (probing round-to-nearest-even ties), and the two f32 neighbors.
    /// NOTE(review): these samples appear never to fall inside the deep
    /// underflow band [2^-26, 2^-25) (would-be half exponent -11), so that
    /// rounding edge is not exercised here — worth adding explicit cases.
    #[test]
    fn f32_to_f16_bits_matches_half_crate_sampled() {
        for bits in 0u16..=0xffff {
            let center = f16_bits_to_f32(bits);
            if !center.is_finite() {
                continue;
            }
            let next = f16_bits_to_f32(bits.wrapping_add(1));
            let midpoint = center + (next - center) * 0.5;
            for v in [center, midpoint, center.next_up(), center.next_down()] {
                if v.is_nan() {
                    continue;
                }
                let ours = f32_to_f16_bits(v);
                let theirs = half::f16::from_f32(v).to_bits();
                assert_eq!(
                    ours,
                    theirs,
                    "f32 {} ({:#010x}): ours={:#06x}, theirs={:#06x}",
                    v,
                    v.to_bits(),
                    ours,
                    theirs
                );
            }
        }
    }

    /// The dispatching slice API must agree bit-for-bit with the scalar
    /// conversion for all f16 inputs (catches SIMD/scalar divergence).
    #[test]
    fn slice_f16_to_f32_simd_matches_scalar_exhaustive() {
        let bits: Vec<u16> = (0u16..=0xffff).collect();
        let mut via_slice = vec![0.0f32; bits.len()];
        let mut via_scalar = vec![0.0f32; bits.len()];
        f16_bits_to_f32_slice(&bits, &mut via_slice);
        for (i, &b) in bits.iter().enumerate() {
            via_scalar[i] = f16_bits_to_f32(b);
        }
        for i in 0..bits.len() {
            let a = via_slice[i];
            let b = via_scalar[i];
            if a.is_nan() && b.is_nan() {
                continue;
            }
            assert_eq!(
                a.to_bits(),
                b.to_bits(),
                "f16 bits {:#06x}: slice={} (bits {:#010x}), scalar={} (bits {:#010x})",
                bits[i],
                a,
                a.to_bits(),
                b,
                b.to_bits()
            );
        }
    }

    /// Same SIMD-vs-scalar agreement for f32→f16, over the sampled value set
    /// (exact f16 values, their f32 neighbors, and midpoints).
    #[test]
    fn slice_f32_to_f16_simd_matches_scalar_sampled() {
        let mut samples: Vec<f32> = Vec::new();
        for b in 0u16..=0xffff {
            let c = f16_bits_to_f32(b);
            samples.push(c);
            if c.is_finite() {
                samples.push(c.next_up());
                samples.push(c.next_down());
                let next = f16_bits_to_f32(b.wrapping_add(1));
                if next.is_finite() {
                    samples.push(c + (next - c) * 0.5);
                }
            }
        }
        let mut via_slice = vec![0u16; samples.len()];
        let mut via_scalar = vec![0u16; samples.len()];
        f32_to_f16_bits_slice(&samples, &mut via_slice);
        for (i, &v) in samples.iter().enumerate() {
            via_scalar[i] = f32_to_f16_bits(v);
        }
        for i in 0..samples.len() {
            if samples[i].is_nan() {
                // NaN outputs may differ in payload but must both stay NaN.
                let a_nan = (via_slice[i] & 0x7c00) == 0x7c00 && (via_slice[i] & 0x03ff) != 0;
                let b_nan = (via_scalar[i] & 0x7c00) == 0x7c00 && (via_scalar[i] & 0x03ff) != 0;
                assert!(a_nan && b_nan, "NaN input lost NaN-ness");
                continue;
            }
            assert_eq!(
                via_slice[i],
                via_scalar[i],
                "f32 {} ({:#010x}): slice={:#06x}, scalar={:#06x}",
                samples[i],
                samples[i].to_bits(),
                via_slice[i],
                via_scalar[i],
            );
        }
    }

    /// Spot checks of f32→f16 at encoding boundaries: zeros, ±1, the smallest
    /// normal (2^-14) and subnormal (2^-24) halves, the largest finite half
    /// (65504), overflow to infinity, and NaN quieting.
    #[test]
    fn f32_to_f16_boundary_cases() {
        assert_eq!(f32_to_f16_bits(0.0), 0x0000);
        assert_eq!(f32_to_f16_bits(-0.0), 0x8000);
        assert_eq!(f32_to_f16_bits(1.0), 0x3c00);
        assert_eq!(f32_to_f16_bits(-1.0), 0xbc00);
        assert_eq!(f32_to_f16_bits(2.0f32.powi(-14)), 0x0400);
        assert_eq!(f32_to_f16_bits(2.0f32.powi(-24)), 0x0001);
        assert_eq!(f32_to_f16_bits(65504.0), 0x7bff);
        // 65520 is the tie between 65504 and 2^16; nearest-even rounds to ∞.
        assert_eq!(f32_to_f16_bits(65520.0), 0x7c00);
        assert_eq!(f32_to_f16_bits(1e9), 0x7c00);
        assert_eq!(f32_to_f16_bits(-1e9), 0xfc00);
        assert_eq!(f32_to_f16_bits(f32::INFINITY), 0x7c00);
        assert_eq!(f32_to_f16_bits(f32::NEG_INFINITY), 0xfc00);
        assert!((f32_to_f16_bits(f32::NAN) & 0x7c00) == 0x7c00);
        assert!((f32_to_f16_bits(f32::NAN) & 0x03ff) != 0);
    }

    /// Spot checks of f16→f32 at the same boundaries, including exact values
    /// for the subnormal range endpoints.
    #[test]
    fn f16_bits_to_f32_boundary_cases() {
        assert_eq!(f16_bits_to_f32(0x0000).to_bits(), 0x0000_0000);
        assert_eq!(f16_bits_to_f32(0x8000).to_bits(), 0x8000_0000);
        assert_eq!(f16_bits_to_f32(0x3c00), 1.0);
        assert_eq!(f16_bits_to_f32(0xbc00), -1.0);
        let v = f16_bits_to_f32(0x0001);
        assert_eq!(v, 2.0f32.powi(-24));
        let v = f16_bits_to_f32(0x03ff);
        let expected = 1023.0 * 2.0f32.powi(-24);
        assert!((v - expected).abs() < 1e-30);
        let v = f16_bits_to_f32(0x0400);
        assert_eq!(v, 2.0f32.powi(-14));
        assert_eq!(f16_bits_to_f32(0x7bff), 65504.0);
        assert!(f16_bits_to_f32(0x7c00).is_infinite());
        assert!(f16_bits_to_f32(0x7c00).is_sign_positive());
        assert!(f16_bits_to_f32(0xfc00).is_infinite());
        assert!(f16_bits_to_f32(0xfc00).is_sign_negative());
        assert!(f16_bits_to_f32(0x7e00).is_nan());
        assert!(f16_bits_to_f32(0xffff).is_nan());
    }
}