#![allow(
clippy::panic,
clippy::unwrap_used,
clippy::expect_used,
clippy::indexing_slicing,
clippy::as_conversions,
clippy::wildcard_enum_match_arm
)]
use std::time::Instant;
use anamnesis::{
dequantize_fp8_to_bf16, dequantize_per_channel_fp8_to_bf16, dequantize_per_tensor_fp8_to_bf16,
Dtype,
};
/// One parsed binary test fixture: an FP8 weight tensor, its quantization
/// scales, and the reference BF16 output the dequantizer must reproduce.
struct Fixture {
    // Dequantization scheme id: 0 = fine-grained, 1 = per-tensor scalar,
    // 2 = per-channel (dispatched on in `run_cross_validation`).
    scheme: u32,
    // Element type of `scale_data` (F32, BF16, or F16 per `parse_fixture`).
    scale_dtype: Dtype,
    // Logical tensor dimensions; rows * cols is the element count.
    rows: usize,
    cols: usize,
    // Raw FP8 weight bytes, one byte per element.
    weight_data: Vec<u8>,
    // Raw scale bytes; layout depends on `scheme` and `scale_dtype`.
    scale_data: Vec<u8>,
    // Expected output: little-endian BF16, two bytes per element.
    expected_bf16: Vec<u8>,
}
/// Decodes a little-endian `u32` from `data` starting at byte `offset`.
///
/// Panics if fewer than four bytes are available at `offset`.
fn read_u32_le(data: &[u8], offset: usize) -> u32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(&data[offset..offset + 4]);
    u32::from_le_bytes(word)
}
/// Splits a raw fixture blob into its header fields and the three payload
/// sections (FP8 weights, scales, expected BF16 output).
///
/// Layout: seven little-endian `u32` header words — scheme, scale dtype id,
/// rows, cols, then the three section byte lengths — followed by the three
/// payloads back to back. Panics on a truncated blob or an unknown
/// scale-dtype id.
fn parse_fixture(data: &[u8]) -> Fixture {
    // Pull all seven header words in one pass.
    let header: Vec<u32> = (0..7).map(|i| read_u32_le(data, i * 4)).collect();
    let weight_len = header[4] as usize;
    let scale_len = header[5] as usize;
    let expected_len = header[6] as usize;

    let scale_dtype = match header[1] {
        0 => Dtype::F32,
        1 => Dtype::BF16,
        2 => Dtype::F16,
        other => panic!("unknown scale dtype id: {other}"),
    };

    // Payload sections start right after the 28-byte header.
    let weights_at = 28;
    let scales_at = weights_at + weight_len;
    let expected_at = scales_at + scale_len;

    Fixture {
        scheme: header[0],
        scale_dtype,
        rows: header[2] as usize,
        cols: header[3] as usize,
        weight_data: data[weights_at..scales_at].to_vec(),
        scale_data: data[scales_at..expected_at].to_vec(),
        expected_bf16: data[expected_at..expected_at + expected_len].to_vec(),
    }
}
/// Compares two little-endian BF16 buffers element-wise in ULP space.
///
/// Returns `(mismatches, max_diff)`: the count of elements whose ULP
/// distance exceeds `max_ulp_diff`, and the largest ULP distance observed.
/// Two NaNs (any payload) compare equal; a NaN paired with a non-NaN is a
/// mismatch. Panics if the buffers differ in length.
fn compare_bf16(actual: &[u8], expected: &[u8], max_ulp_diff: u16) -> (usize, u16) {
    assert_eq!(actual.len(), expected.len(), "output length mismatch");

    // BF16 bit patterns are sign-magnitude, so a raw `abs_diff` of the bits
    // is wrong across the sign boundary: +0.0 (0x0000) vs -0.0 (0x8000)
    // would report 0x8000 ULP and count as a mismatch. Map the patterns
    // onto a monotonically ordered integer line first (flip negatives, bias
    // positives); same-sign distances are unchanged by this mapping.
    fn ordered(bits: u16) -> u16 {
        if bits & 0x8000 != 0 {
            !bits
        } else {
            bits | 0x8000
        }
    }

    let mut mismatches = 0;
    let mut max_diff: u16 = 0;
    for (i, (a_pair, e_pair)) in actual
        .chunks_exact(2)
        .zip(expected.chunks_exact(2))
        .enumerate()
    {
        let a_bits = u16::from_le_bytes([a_pair[0], a_pair[1]]);
        let e_bits = u16::from_le_bytes([e_pair[0], e_pair[1]]);
        // BF16 NaN: exponent all ones (0x7F80) with a non-zero mantissa.
        let a_is_nan = (a_bits & 0x7F80 == 0x7F80) && (a_bits & 0x007F != 0);
        let e_is_nan = (e_bits & 0x7F80 == 0x7F80) && (e_bits & 0x007F != 0);
        if a_is_nan && e_is_nan {
            // Both NaN: treated as equal regardless of payload.
            continue;
        }
        if a_is_nan != e_is_nan {
            mismatches += 1;
            continue;
        }
        let diff = ordered(a_bits).abs_diff(ordered(e_bits));
        if diff > max_ulp_diff {
            mismatches += 1;
            // Print only the first few offenders to keep the log readable.
            if i < 5 {
                eprintln!(
                    " element {i}: actual=0x{a_bits:04X}, expected=0x{e_bits:04X}, diff={diff} ULP"
                );
            }
        }
        if diff > max_diff {
            max_diff = diff;
        }
    }
    (mismatches, max_diff)
}
/// Reads a single scalar scale from the front of `data` and widens it to
/// `f32` according to `dtype`.
///
/// Panics for any dtype other than F32/BF16/F16, or if `data` is too short.
fn read_scalar_scale(data: &[u8], dtype: Dtype) -> f32 {
    match dtype {
        Dtype::F32 => f32::from_le_bytes(data[..4].try_into().unwrap()),
        Dtype::BF16 => {
            // A BF16 value is the upper half of an F32 bit pattern.
            let bits = u16::from_le_bytes(data[..2].try_into().unwrap());
            f32::from_bits(u32::from(bits) << 16)
        }
        Dtype::F16 => half::f16::from_le_bytes(data[..2].try_into().unwrap()).to_f32(),
        _ => panic!("unexpected scale dtype: {dtype}"),
    }
}
fn run_cross_validation(name: &str, data: &[u8], max_ulp: u16) {
let fixture = parse_fixture(data);
let total = fixture.rows * fixture.cols;
let start = Instant::now();
let actual = match fixture.scheme {
0 => dequantize_fp8_to_bf16(
&fixture.weight_data,
&fixture.scale_data,
fixture.rows,
fixture.cols,
fixture.scale_dtype,
)
.expect("fine-grained dequant failed"),
1 => {
let scale = read_scalar_scale(&fixture.scale_data, fixture.scale_dtype);
dequantize_per_tensor_fp8_to_bf16(&fixture.weight_data, scale)
.expect("per-tensor dequant failed")
}
2 => dequantize_per_channel_fp8_to_bf16(
&fixture.weight_data,
&fixture.scale_data,
fixture.rows,
fixture.cols,
fixture.scale_dtype,
)
.expect("per-channel dequant failed"),
other => panic!("unknown scheme: {other}"),
};
let elapsed = start.elapsed();
assert_eq!(actual.len(), fixture.expected_bf16.len());
let (mismatches, max_diff) = compare_bf16(&actual, &fixture.expected_bf16, max_ulp);
eprintln!(
"{name}: {total} elements, {mismatches} mismatches, \
max ULP diff = {max_diff}, anamnesis = {:.1} µs",
elapsed.as_secs_f64() * 1e6
);
assert_eq!(
mismatches, 0,
"{name}: {mismatches}/{total} elements differ by more than {max_ulp} ULP"
);
}
// EXAONE fixture: fine-grained scheme with BF16 scales, 1 ULP tolerance.
#[test]
fn cross_validate_exaone_fine_grained() {
    run_cross_validation(
        "EXAONE fine-grained (BF16 scales)",
        include_bytes!("fixtures/fp8_reference/exaone_fine_grained.bin"),
        1,
    );
}
// Qwen3-1.7B fixture: fine-grained scheme with BF16 scales, 1 ULP tolerance.
#[test]
fn cross_validate_qwen3_1_7b_fine_grained() {
    run_cross_validation(
        "Qwen3-1.7B fine-grained (BF16 scales)",
        include_bytes!("fixtures/fp8_reference/qwen3_1_7b_fine_grained.bin"),
        1,
    );
}
// Qwen3-4B fixture: fine-grained scheme with F16 scales, 1 ULP tolerance.
#[test]
fn cross_validate_qwen3_4b_fine_grained_f16() {
    run_cross_validation(
        "Qwen3-4B fine-grained (F16 scales)",
        include_bytes!("fixtures/fp8_reference/qwen3_4b_fine_grained_f16.bin"),
        1,
    );
}
// Ministral fixture: per-tensor scheme with a BF16 scalar scale, 1 ULP tolerance.
#[test]
fn cross_validate_ministral_per_tensor() {
    run_cross_validation(
        "Ministral per-tensor (BF16 scalar)",
        include_bytes!("fixtures/fp8_reference/ministral_per_tensor.bin"),
        1,
    );
}
// Llama static-quant fixture: per-tensor scheme, BF16 scale stored as a
// one-element tensor; 1 ULP tolerance.
#[test]
fn cross_validate_llama_static_per_tensor() {
    run_cross_validation(
        "Llama-static per-tensor (BF16 [1])",
        include_bytes!("fixtures/fp8_reference/llama_static_per_tensor.bin"),
        1,
    );
}
// NVIDIA Llama fixture: per-tensor scheme with an F32 scalar scale, 1 ULP tolerance.
#[test]
fn cross_validate_nvidia_llama_per_tensor() {
    run_cross_validation(
        "NVIDIA Llama per-tensor (F32 scalar)",
        include_bytes!("fixtures/fp8_reference/nvidia_llama_per_tensor.bin"),
        1,
    );
}
// Llama dynamic-quant fixture: per-channel scheme, BF16 scales shaped [N,1];
// 1 ULP tolerance.
#[test]
fn cross_validate_llama_dynamic_per_channel() {
    run_cross_validation(
        "Llama-dynamic per-channel (BF16 [N,1])",
        include_bytes!("fixtures/fp8_reference/llama_dynamic_per_channel.bin"),
        1,
    );
}