use crate::bit_writer::BitWriter;
use crate::error::Result;
pub fn f32_to_f16_bits(value: f32) -> Result<u16> {
let bits = value.to_bits();
let sign = (bits >> 31) & 1;
let exp = ((bits >> 23) & 0xFF) as i32;
let mantissa = bits & 0x7F_FFFF;
if exp == 0 && mantissa == 0 {
return Ok((sign << 15) as u16);
}
if exp == 0xFF {
return Err(crate::error::Error::InvalidInput(
"F16 cannot encode Inf or NaN".into(),
));
}
let new_exp = exp - 127 + 15;
if new_exp >= 31 {
return Err(crate::error::Error::InvalidInput(format!(
"F16 overflow: {value} exceeds max representable (65504)"
)));
}
if new_exp <= 0 {
if new_exp < -10 {
return Ok((sign << 15) as u16);
}
let m = mantissa | 0x80_0000;
let shift = 1 - new_exp;
let half_mantissa = (m >> (13 + shift)) as u16;
return Ok(((sign << 15) as u16) | half_mantissa);
}
let half_mantissa = (mantissa >> 13) as u16;
let half_exp = (new_exp as u16) << 10;
Ok(((sign << 15) as u16) | half_exp | half_mantissa)
}
/// Expands an IEEE 754 binary16 ("f16") bit pattern into the `f32` it represents.
///
/// Every finite f16 value, including subnormals and signed zeros, is exactly
/// representable in f32, so this conversion is lossless. An all-ones f16
/// exponent yields infinity (zero mantissa) or a NaN whose payload is the f16
/// mantissa moved into the top bits of the f32 mantissa.
pub fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign = u32::from(bits >> 15) << 31;
    let exp = u32::from((bits >> 10) & 0x1F);
    let mantissa = u32::from(bits & 0x3FF);

    let magnitude = match exp {
        // Signed zero.
        0 if mantissa == 0 => 0,
        // Subnormal: value = mantissa * 2^-24. Renormalise so that the leading
        // set bit becomes f32's implicit 1. If that bit sits at position p
        // (0..=9), the value is 1.xxx * 2^(p - 24), which is an f32 biased
        // exponent of p - 24 + 127 = p + 103.
        0 => {
            let p = 31 - mantissa.leading_zeros();
            let f32_exp = p + 103;
            let f32_mantissa = (mantissa << (23 - p)) & 0x7F_FFFF;
            (f32_exp << 23) | f32_mantissa
        }
        // Infinity (mantissa == 0) or NaN (payload kept in the high bits).
        31 => (0xFF << 23) | (mantissa << 13),
        // Normal: re-bias the exponent from 15 to 127 and widen the mantissa.
        _ => ((exp + 112) << 23) | (mantissa << 13),
    };
    f32::from_bits(sign | magnitude)
}
pub fn f16_roundtrip(value: f32) -> Result<f32> {
Ok(f16_bits_to_f32(f32_to_f16_bits(value)?))
}
/// Encodes `value` as f16 and appends the resulting 16 bits to `writer`.
///
/// # Errors
///
/// Fails if `value` cannot be represented as f16 (see [`f32_to_f16_bits`]), or
/// propagates the error from `writer.write`.
pub fn write_f16(value: f32, writer: &mut BitWriter) -> Result<()> {
    let half = u64::from(f32_to_f16_bits(value)?);
    writer.write(16, half)?;
    Ok(())
}
pub fn write_lf_quant(writer: &mut BitWriter, dc_quant_custom: Option<[f32; 3]>) -> Result<()> {
match dc_quant_custom {
None => {
writer.write(1, 1)?; }
Some(dq) => {
writer.write(1, 0)?; for &q in &dq {
write_f16(q * 128.0, writer)?;
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Values that f16 represents exactly must survive a round trip unchanged.
    #[test]
    fn test_f16_roundtrip_exact_values() {
        for &v in &[0.0f32, 1.0, -1.0, 0.5, -0.5, 2.0, 0.25, 65504.0] {
            assert_eq!(f16_roundtrip(v).unwrap(), v, "f16_roundtrip({v}) failed");
        }
    }

    /// Both signed zeros are preserved, including the sign of -0.0.
    #[test]
    fn test_f16_roundtrip_zero() {
        assert_eq!(f16_roundtrip(0.0).unwrap(), 0.0);
        let neg_zero = f16_roundtrip(-0.0).unwrap();
        assert!(neg_zero.is_sign_negative() && neg_zero == 0.0);
    }

    /// Inexact values lose precision once, then are stable under further round trips.
    #[test]
    fn test_f16_roundtrip_truncation() {
        let v = 1.0 / 3.0;
        let rt = f16_roundtrip(v).unwrap();
        assert!((rt - v).abs() < 0.001);
        assert_eq!(f16_roundtrip(rt).unwrap(), rt);
    }

    /// DC quant factors scaled by 128 must decode close to the encoder's inverse factor.
    #[test]
    fn test_f16_dc_quant_roundtrip() {
        let enc_factors = [65536.0f32, 4096.0, 4096.0];
        for &ef in &enc_factors {
            let dc_quant = 1.0 / ef;
            let f16_val = dc_quant * 128.0;
            let rt = f16_roundtrip(f16_val).unwrap();
            let decoder_dc_quant = rt / 128.0;
            let decoder_inv = 1.0 / decoder_dc_quant;
            assert!(
                (decoder_inv - ef).abs() / ef < 0.01,
                "enc_factor {ef}: decoder sees {decoder_inv}"
            );
        }
    }

    #[test]
    fn test_f16_overflow_rejects() {
        assert!(f16_roundtrip(100000.0).is_err());
        assert!(f32_to_f16_bits(100000.0).is_err());
    }

    #[test]
    fn test_f16_inf_nan_rejects() {
        assert!(f32_to_f16_bits(f32::INFINITY).is_err());
        assert!(f32_to_f16_bits(f32::NEG_INFINITY).is_err());
        assert!(f32_to_f16_bits(f32::NAN).is_err());
    }

    /// Small positive input stays positive and stable. Note that 0.0001 is
    /// still a *normal* f16 (the smallest normal is 2^-14 ≈ 6.1e-5), so this
    /// test does not reach the subnormal path.
    #[test]
    fn test_f16_small_values() {
        let small = f16_roundtrip(0.0001).unwrap();
        assert!(small > 0.0 && small < 0.001, "got {small}");
        assert_eq!(f16_roundtrip(small).unwrap(), small);
    }

    /// Pins exact encoder output for normal values against known binary16 patterns.
    #[test]
    fn test_f32_to_f16_bits_normal_patterns() {
        assert_eq!(f32_to_f16_bits(1.0).unwrap(), 0x3C00);
        assert_eq!(f32_to_f16_bits(-2.0).unwrap(), 0xC000);
        assert_eq!(f32_to_f16_bits(65504.0).unwrap(), 0x7BFF);
        // Smallest normal f16, 2^-14.
        assert_eq!(f32_to_f16_bits(1.0 / 16_384.0).unwrap(), 0x0400);
        // Bits below the f16 mantissa are truncated, not rounded: 1 + 3*2^-12
        // lies above the halfway point to the next f16 but still encodes as 1.0.
        assert_eq!(f32_to_f16_bits(1.0 + 3.0 / 4096.0).unwrap(), 0x3C00);
    }

    /// Exercises the encoder's subnormal path (magnitudes below 2^-14).
    #[test]
    fn test_f32_to_f16_bits_subnormal_patterns() {
        let min_subnormal = 1.0f32 / 16_777_216.0; // 2^-24
        assert_eq!(f32_to_f16_bits(min_subnormal).unwrap(), 0x0001);
        assert_eq!(f32_to_f16_bits(-min_subnormal).unwrap(), 0x8001);
        assert_eq!(f32_to_f16_bits(1023.0 * min_subnormal).unwrap(), 0x03FF);
        // 2^-25 is below the smallest subnormal and truncates to zero.
        assert_eq!(f32_to_f16_bits(min_subnormal / 2.0).unwrap(), 0x0000);
    }

    #[test]
    fn test_f16_bits_to_f32_inf() {
        assert!(f16_bits_to_f32(0x7C00).is_infinite());
        assert!(f16_bits_to_f32(0x7C00) > 0.0);
        assert!(f16_bits_to_f32(0xFC00).is_infinite());
        assert!(f16_bits_to_f32(0xFC00) < 0.0);
    }

    #[test]
    fn test_f16_bits_to_f32_nan() {
        assert!(f16_bits_to_f32(0x7C01).is_nan());
    }

    /// `write_f16` emits exactly 16 bits.
    #[test]
    fn test_write_f16() {
        let mut writer = BitWriter::new();
        write_f16(1.0, &mut writer).unwrap();
        assert_eq!(writer.bits_written(), 16);
    }
}