anamnesis 0.4.4

Parse any tensor format, recover any precision — framework-agnostic FP8/GPTQ/AWQ/BnB dequantization, NPZ parsing, and PyTorch .pth conversion for Rust
Documentation
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Cross-validation tests against `PyTorch` reference dequantization.
//!
//! These tests load pre-computed fixtures generated by
//! `tests/fixtures/fp8_reference/generate.py` and compare anamnesis
//! dequantization output against `PyTorch`'s `float8_e4m3fn` → `bfloat16`
//! conversion. Each fixture is a 256×256 slice from a real model.

#![allow(
    clippy::panic,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::indexing_slicing,
    clippy::as_conversions,
    clippy::wildcard_enum_match_arm
)]

use std::time::Instant;

use anamnesis::{
    dequantize_fp8_to_bf16, dequantize_per_channel_fp8_to_bf16, dequantize_per_tensor_fp8_to_bf16,
    Dtype,
};

// ---------------------------------------------------------------------------
// Fixture parsing
// ---------------------------------------------------------------------------

/// One cross-validation case decoded from a binary fixture file.
///
/// On-disk layout; every integer is a little-endian `u32`:
///
/// - bytes `0..4`: scheme id (0 = fine-grained, 1 = per-tensor, 2 = per-channel)
/// - bytes `4..8`: scale dtype id (0 = `F32`, 1 = `BF16`, 2 = `F16`)
/// - bytes `8..12`: row count
/// - bytes `12..16`: column count
/// - bytes `16..20`: length of the weight section in bytes
/// - bytes `20..24`: length of the scale section in bytes
/// - bytes `24..28`: length of the expected-output section in bytes
/// - bytes `28..`: weight bytes, then scale bytes, then expected `BF16` bytes,
///   packed back to back
struct Fixture {
    /// Scheme id (0 = fine-grained, 1 = per-tensor, 2 = per-channel).
    scheme: u32,
    /// Number of weight rows.
    rows: usize,
    /// Number of weight columns.
    cols: usize,
    /// Raw FP8 weight bytes.
    weight_data: Vec<u8>,
    /// Element type of `scale_data` (`F32`, `BF16`, or `F16`).
    scale_dtype: Dtype,
    /// Raw scale bytes, encoded as `scale_dtype`.
    scale_data: Vec<u8>,
    /// Reference `BF16` output bytes to compare against.
    expected_bf16: Vec<u8>,
}

/// Decode the little-endian `u32` stored at `data[offset..offset + 4]`.
///
/// # Panics
///
/// Panics if fewer than four bytes are available starting at `offset`.
fn read_u32_le(data: &[u8], offset: usize) -> u32 {
    let word = &data[offset..offset + 4];
    u32::from_le_bytes([word[0], word[1], word[2], word[3]])
}

/// Decode a binary fixture (layout documented on [`Fixture`]).
///
/// # Panics
///
/// Panics with a descriptive message if `data` is shorter than the 28-byte
/// header, if the weight/scale/expected sections declared in the header do
/// not fit inside `data`, or if the scale dtype id is not 0, 1, or 2.
fn parse_fixture(data: &[u8]) -> Fixture {
    const HEADER_SIZE: usize = 28;

    // Validate the header up front so an empty or truncated fixture fails
    // with a clear message instead of an opaque slice-index panic.
    assert!(
        data.len() >= HEADER_SIZE,
        "fixture too short: {} bytes, header alone needs {HEADER_SIZE}",
        data.len()
    );

    let scheme = read_u32_le(data, 0);
    let scale_dtype_id = read_u32_le(data, 4);
    let rows = read_u32_le(data, 8) as usize;
    let cols = read_u32_le(data, 12) as usize;
    let weight_len = read_u32_le(data, 16) as usize;
    let scale_len = read_u32_le(data, 20) as usize;
    let expected_len = read_u32_le(data, 24) as usize;

    // Sections are packed back to back immediately after the header.
    let weight_start = HEADER_SIZE;
    let scale_start = weight_start + weight_len;
    let expected_start = scale_start + scale_len;
    let expected_end = expected_start + expected_len;

    // A header/payload disagreement means the fixture is corrupt or the
    // generator and this parser disagree on the layout; say so explicitly.
    assert!(
        expected_end <= data.len(),
        "fixture truncated: header declares {weight_len} weight + {scale_len} scale + \
         {expected_len} expected bytes ({expected_end} total with header), \
         but only {} bytes are present",
        data.len()
    );

    let scale_dtype = match scale_dtype_id {
        0 => Dtype::F32,
        1 => Dtype::BF16,
        2 => Dtype::F16,
        other => panic!("unknown scale dtype id: {other}"),
    };

    Fixture {
        scheme,
        scale_dtype,
        rows,
        cols,
        weight_data: data[weight_start..scale_start].to_vec(),
        scale_data: data[scale_start..expected_start].to_vec(),
        expected_bf16: data[expected_start..expected_end].to_vec(),
    }
}

// ---------------------------------------------------------------------------
// BF16 comparison
// ---------------------------------------------------------------------------

/// Returns `true` if `bits` encodes a `BF16` NaN (all exponent bits set and a
/// non-zero mantissa).
fn bf16_is_nan(bits: u16) -> bool {
    (bits & 0x7F80) == 0x7F80 && (bits & 0x007F) != 0
}

/// Map a non-NaN `BF16` bit pattern onto a signed integer line on which
/// adjacent representable values differ by exactly 1.
///
/// `BF16` is sign-magnitude, so raw bit patterns of opposite sign are about
/// `0x8000` apart even when the values are numerically adjacent or equal
/// (`+0.0` is `0x0000`, `-0.0` is `0x8000`). Negating the magnitude of
/// negative values removes that jump and maps both zeros to 0.
fn bf16_ordered(bits: u16) -> i32 {
    let magnitude = i32::from(bits & 0x7FFF);
    if bits & 0x8000 == 0 {
        magnitude
    } else {
        -magnitude
    }
}

/// Compare two `BF16` byte slices, allowing up to `max_ulp_diff` ULP
/// (unit in the last place) difference per element.
///
/// ULP distance is measured on the ordered value line (see `bf16_ordered`),
/// so `+0.0` and `-0.0` compare equal and values straddling zero get their
/// true distance. Two NaNs always match; a NaN against a non-NaN is always a
/// mismatch and does not contribute to the reported maximum diff. The first
/// few mismatches are printed to stderr.
///
/// Returns the number of mismatched elements and the maximum ULP diff found.
///
/// # Panics
///
/// Panics if `actual` and `expected` differ in length.
fn compare_bf16(actual: &[u8], expected: &[u8], max_ulp_diff: u16) -> (usize, u16) {
    /// How many mismatching elements to describe on stderr.
    const MAX_REPORTED: usize = 5;

    assert_eq!(actual.len(), expected.len(), "output length mismatch");
    let mut mismatches = 0;
    let mut max_diff: u16 = 0;

    for (i, (a_pair, e_pair)) in actual
        .chunks_exact(2)
        .zip(expected.chunks_exact(2))
        .enumerate()
    {
        let a_bits = u16::from_le_bytes([a_pair[0], a_pair[1]]);
        let e_bits = u16::from_le_bytes([e_pair[0], e_pair[1]]);

        match (bf16_is_nan(a_bits), bf16_is_nan(e_bits)) {
            (true, true) => continue,
            (false, false) => {}
            _ => {
                mismatches += 1;
                if mismatches <= MAX_REPORTED {
                    eprintln!(
                        "  element {i}: actual=0x{a_bits:04X}, expected=0x{e_bits:04X}, \
                         NaN mismatch"
                    );
                }
                continue;
            }
        }

        // Non-NaN BF16 values map into [-0x7F80, 0x7F80], so the distance is
        // at most 0xFF00 and fits in a u16; saturate defensively anyway.
        let diff = bf16_ordered(a_bits).abs_diff(bf16_ordered(e_bits));
        let diff = u16::try_from(diff).unwrap_or(u16::MAX);

        max_diff = max_diff.max(diff);
        if diff > max_ulp_diff {
            mismatches += 1;
            // Report the first few *mismatches*, wherever they occur, rather
            // than only mismatches that happen to sit at indices 0..5.
            if mismatches <= MAX_REPORTED {
                eprintln!(
                    "  element {i}: actual=0x{a_bits:04X}, expected=0x{e_bits:04X}, diff={diff} ULP"
                );
            }
        }
    }
    (mismatches, max_diff)
}

// ---------------------------------------------------------------------------
// Scale reading helper
// ---------------------------------------------------------------------------

/// Decode the single scale value stored at the start of `data` as an `f32`.
///
/// `BF16` is widened by placing its 16 bits in the upper half of an `f32`
/// bit pattern; `F16` is converted via `half::f16`.
///
/// # Panics
///
/// Panics if `data` holds fewer bytes than one element of `dtype`, or if
/// `dtype` is not `F32`, `BF16`, or `F16`.
fn read_scalar_scale(data: &[u8], dtype: Dtype) -> f32 {
    // Both 16-bit formats start from the same two leading bytes.
    let two_bytes = || -> [u8; 2] { data[..2].try_into().unwrap() };

    match dtype {
        Dtype::F32 => f32::from_le_bytes(data[..4].try_into().unwrap()),
        Dtype::BF16 => {
            let upper_half = u32::from(u16::from_le_bytes(two_bytes()));
            f32::from_bits(upper_half << 16)
        }
        Dtype::F16 => half::f16::from_le_bytes(two_bytes()).to_f32(),
        _ => panic!("unexpected scale dtype: {dtype}"),
    }
}

// ---------------------------------------------------------------------------
// Unified test runner
// ---------------------------------------------------------------------------

fn run_cross_validation(name: &str, data: &[u8], max_ulp: u16) {
    let fixture = parse_fixture(data);
    let total = fixture.rows * fixture.cols;

    // Dequantize with anamnesis and measure time.
    let start = Instant::now();
    let actual = match fixture.scheme {
        0 => dequantize_fp8_to_bf16(
            &fixture.weight_data,
            &fixture.scale_data,
            fixture.rows,
            fixture.cols,
            fixture.scale_dtype,
        )
        .expect("fine-grained dequant failed"),
        1 => {
            let scale = read_scalar_scale(&fixture.scale_data, fixture.scale_dtype);
            dequantize_per_tensor_fp8_to_bf16(&fixture.weight_data, scale)
                .expect("per-tensor dequant failed")
        }
        2 => dequantize_per_channel_fp8_to_bf16(
            &fixture.weight_data,
            &fixture.scale_data,
            fixture.rows,
            fixture.cols,
            fixture.scale_dtype,
        )
        .expect("per-channel dequant failed"),
        other => panic!("unknown scheme: {other}"),
    };
    let elapsed = start.elapsed();

    assert_eq!(actual.len(), fixture.expected_bf16.len());

    let (mismatches, max_diff) = compare_bf16(&actual, &fixture.expected_bf16, max_ulp);
    eprintln!(
        "{name}: {total} elements, {mismatches} mismatches, \
         max ULP diff = {max_diff}, anamnesis = {:.1} µs",
        elapsed.as_secs_f64() * 1e6
    );
    assert_eq!(
        mismatches, 0,
        "{name}: {mismatches}/{total} elements differ by more than {max_ulp} ULP"
    );
}

// ---------------------------------------------------------------------------
// Fine-grained (3 models)
// ---------------------------------------------------------------------------

/// EXAONE fine-grained fixture with `BF16` scales, checked to within 1 ULP.
#[test]
fn cross_validate_exaone_fine_grained() {
    let fixture: &[u8] = include_bytes!("fixtures/fp8_reference/exaone_fine_grained.bin");
    let max_ulp = 1;
    run_cross_validation("EXAONE fine-grained (BF16 scales)", fixture, max_ulp);
}

/// Qwen3-1.7B fine-grained fixture with `BF16` scales, checked to within 1 ULP.
#[test]
fn cross_validate_qwen3_1_7b_fine_grained() {
    let fixture: &[u8] = include_bytes!("fixtures/fp8_reference/qwen3_1_7b_fine_grained.bin");
    let max_ulp = 1;
    run_cross_validation("Qwen3-1.7B fine-grained (BF16 scales)", fixture, max_ulp);
}

/// Qwen3-4B fine-grained fixture with `F16` scales, checked to within 1 ULP.
#[test]
fn cross_validate_qwen3_4b_fine_grained_f16() {
    let fixture: &[u8] = include_bytes!("fixtures/fp8_reference/qwen3_4b_fine_grained_f16.bin");
    let max_ulp = 1;
    run_cross_validation("Qwen3-4B fine-grained (F16 scales)", fixture, max_ulp);
}

// ---------------------------------------------------------------------------
// Per-tensor (3 models)
// ---------------------------------------------------------------------------

/// Ministral per-tensor fixture with a `BF16` scalar scale, checked to within 1 ULP.
#[test]
fn cross_validate_ministral_per_tensor() {
    let fixture: &[u8] = include_bytes!("fixtures/fp8_reference/ministral_per_tensor.bin");
    let max_ulp = 1;
    run_cross_validation("Ministral per-tensor (BF16 scalar)", fixture, max_ulp);
}

/// Llama-static per-tensor fixture with a one-element `BF16` scale, checked to within 1 ULP.
#[test]
fn cross_validate_llama_static_per_tensor() {
    let fixture: &[u8] = include_bytes!("fixtures/fp8_reference/llama_static_per_tensor.bin");
    let max_ulp = 1;
    run_cross_validation("Llama-static per-tensor (BF16 [1])", fixture, max_ulp);
}

/// NVIDIA Llama per-tensor fixture with an `F32` scalar scale, checked to within 1 ULP.
#[test]
fn cross_validate_nvidia_llama_per_tensor() {
    let fixture: &[u8] = include_bytes!("fixtures/fp8_reference/nvidia_llama_per_tensor.bin");
    let max_ulp = 1;
    run_cross_validation("NVIDIA Llama per-tensor (F32 scalar)", fixture, max_ulp);
}

// ---------------------------------------------------------------------------
// Per-channel (1 model)
// ---------------------------------------------------------------------------

/// Llama-dynamic per-channel fixture with `[N, 1]` `BF16` scales, checked to within 1 ULP.
#[test]
fn cross_validate_llama_dynamic_per_channel() {
    let fixture: &[u8] = include_bytes!("fixtures/fp8_reference/llama_dynamic_per_channel.bin");
    let max_ulp = 1;
    run_cross_validation("Llama-dynamic per-channel (BF16 [N,1])", fixture, max_ulp);
}

// Phase 4.8 reader-generic safetensors header parsing is cross-validated
// against the upstream HuggingFace `safetensors` Python library in the
// dedicated `cross_validation_safetensors.rs` integration file (one fixture
// + one Python-sourced JSON reference per quantization scheme).