use std::{
fs::File,
io::{BufRead, BufReader},
};
use diskann_wide::{Const, SIMDVector, SupportedLaneCount};
use half::f16;
fn test_file_path(name: &str) -> String {
format!("{}/test_data/{}", env!("CARGO_MANIFEST_DIR"), name)
}
const FLOAT16_INPUT_FILE: &str = "float16_conversion.txt";
fn parse_float16_reference_file() -> Vec<(f16, f32)> {
let file = File::open(test_file_path(FLOAT16_INPUT_FILE)).unwrap();
let reader = BufReader::new(file);
let mut output = Vec::new();
for line in reader.lines() {
let line = line.unwrap();
let mut split = line.split(", ");
let float16 = match split.next() {
Some(field) => field,
None => panic!("could not parse {}", line),
};
let float32 = match split.next() {
Some(field) => field,
None => panic!("could not parse {}", line),
};
assert!(split.next().is_none());
let float16: f16 = match u16::from_str_radix(float16.trim_start_matches("0x"), 16) {
Ok(number) => f16::from_bits(number),
Err(err) => panic!("could not parse {} as a u16, error: {}", float16, err),
};
let float32: f32 = match float32 {
"neg_infinity" => f32::NEG_INFINITY,
"infinity" => f32::INFINITY,
"nan" => f32::NAN,
other => match other.parse::<f32>() {
Ok(number) => number,
Err(err) => panic!("could not parse {} as a f32, error: {}", float32, err),
},
};
output.push((float16, float32))
}
assert_eq!(
output.len(),
65536,
"conversion file should have exactly 65536 entries"
);
output
}
fn test_f16_to_f32_exhaustive<T, U, const N: usize>(
cases: &[(f16, f32)],
convert: &dyn Fn(U) -> T,
opname: &str,
) where
Const<N>: SupportedLaneCount,
T: SIMDVector<Arch = diskann_wide::arch::Current, Scalar = f32, ConstLanes = Const<N>>,
U: SIMDVector<Arch = diskann_wide::arch::Current, Scalar = f16, ConstLanes = Const<N>>,
{
for case in cases.iter() {
let from = U::splat(diskann_wide::ARCH, case.0);
let converted: T = convert(from);
let check = |c: f32| {
if case.1.is_nan() {
assert!(c.is_nan(), "failed for case: {:?}. Op = {}", case, opname);
} else {
assert_eq!(c, case.1, "failed for case: {:?}. Op = {}", case, opname);
}
};
for c in converted.to_array() {
check(c);
}
let converted = diskann_wide::cast_f16_to_f32(case.0);
check(converted);
}
}
diskann_wide::alias!(f32x8);
diskann_wide::alias!(f16x8);
#[test]
fn test_f16_to_f32() {
let cases = parse_float16_reference_file();
test_f16_to_f32_exhaustive::<f32x8, f16x8, 8>(&cases, &|x| x.into(), "into");
test_f16_to_f32_exhaustive::<f32x8, f16x8, 8>(&cases, &|x| x.cast(), "cast");
}
fn test_f32_to_f16_exhaustive<T, U, const N: usize>(
cases: &[(f16, f32)],
convert: &dyn Fn(U) -> T,
opname: &str,
) where
Const<N>: SupportedLaneCount,
T: SIMDVector<Arch = diskann_wide::arch::Current, Scalar = f16, ConstLanes = Const<N>>,
U: SIMDVector<Arch = diskann_wide::arch::Current, Scalar = f32, ConstLanes = Const<N>>,
{
let test_conversion = |from: f32, expected: f16| {
let expected: f32 = diskann_wide::cast_f16_to_f32(expected);
let wide = U::splat(diskann_wide::ARCH, from);
let converted = convert(wide).to_array();
let check = |c: f16| {
let c: f32 = c.into();
if expected.is_nan() {
assert!(
c.is_nan(),
"failed for case: {:?}. Op = {}",
(from, expected),
opname
);
} else {
assert_eq!(
c,
expected,
"failed for case: {:?}. Op = {}",
(from, expected),
opname
);
}
};
for c in converted {
check(c);
}
let converted = diskann_wide::cast_f32_to_f16(from);
check(converted);
};
test_conversion(f32::INFINITY, f16::INFINITY);
test_conversion(f32::MAX, f16::INFINITY);
test_conversion(f32::NEG_INFINITY, f16::NEG_INFINITY);
test_conversion(f32::MIN, f16::NEG_INFINITY);
test_conversion(f32::NAN, f16::NAN);
let mut cases_handled = 0;
for i in 0..cases.len() - 1 {
let base = cases[i];
let next = cases[i + 1];
test_conversion(base.1, base.0);
cases_handled += 1;
if next.1.is_nan() {
break;
}
assert_eq!(base.1.total_cmp(&next.1), std::cmp::Ordering::Less);
if base.1.is_infinite() || next.1.is_infinite() {
continue;
}
let step = (next.1 - base.1) / 4.0;
test_conversion(base.1 + step, base.0);
test_conversion(base.1 + 3.0 * step, next.0);
let bits = base.0.to_bits();
if bits.is_multiple_of(2) {
test_conversion(base.1 + 2.0 * step, base.0);
} else {
test_conversion(base.1 + 2.0 * step, next.0);
}
}
assert_eq!(
cases_handled, 63490,
"expected to handle exactly 63490 non-NAN values for f16"
);
}
#[test]
fn test_f32_to_f16() {
let cases = parse_float16_reference_file();
test_f32_to_f16_exhaustive::<f16x8, f32x8, 8>(&cases, &|x| x.cast(), "cast");
}