tinyquant-io 0.0.0

Serialization, mmap, and file I/O for TinyQuant.
Documentation
//! Byte-parity tests: Rust `to_bytes` must produce bit-identical output to
//! the Python `CompressedVector.to_bytes()` reference for each of the 10
//! canonical fixture cases.
//!
//! Fixtures are generated by:
//!   `python scripts/generate_rust_fixtures.py serialization`
//!
//! Each case lives in `tests/fixtures/<case_id>/` and contains:
//! - `indices.u8.bin`  — raw u8 index bytes (one per dimension)
//! - `residual.u8.bin` — raw residual bytes (present only when residual=true)
//! - `config_hash.txt` — UTF-8 config hash string
//! - `expected.bin`    — Python-produced serialized bytes (ground truth)

use std::path::Path;
use std::sync::Arc;
use tinyquant_core::codec::CompressedVector;
use tinyquant_io::compressed_vector::{from_bytes, to_bytes};

/// Root directory of the fixture cases, assembled at compile time by appending
/// `/tests/fixtures` to the value of the `CARGO_MANIFEST_DIR` environment variable.
const FIXTURE_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures");

/// Build the `CompressedVector` for one fixture case and fetch its reference bytes.
///
/// Reads `indices.u8.bin`, `config_hash.txt` and `expected.bin` from
/// `FIXTURE_DIR/<case_id>/`, plus `residual.u8.bin` when `has_residual` is set.
/// Returns the vector constructed from those inputs (together with `dim` and
/// `bit_width`) and the raw contents of `expected.bin`.
///
/// # Panics
///
/// Panics if any required file cannot be read, or if `CompressedVector::new`
/// returns an error.
fn load_case(
    case_id: &str,
    bit_width: u8,
    dim: u32,
    has_residual: bool,
) -> (CompressedVector, Vec<u8>) {
    let case_dir = Path::new(FIXTURE_DIR).join(case_id);

    // Shared reader for the raw binary files; the panic names the offending path.
    let read_bin = |file_name: &str| -> Vec<u8> {
        let path = case_dir.join(file_name);
        match std::fs::read(&path) {
            Ok(bytes) => bytes,
            Err(e) => panic!("failed to read {}: {e}", path.display()),
        }
    };

    // Files are read in the same order every time so the first missing file
    // is always the one reported.
    let index_bytes = read_bin("indices.u8.bin");

    let hash_path = case_dir.join("config_hash.txt");
    let hash_text: String = match std::fs::read_to_string(&hash_path) {
        Ok(text) => text,
        Err(e) => panic!("failed to read {}: {e}", hash_path.display()),
    };

    let expected_bytes = read_bin("expected.bin");

    // The residual file is only touched when the case declares one.
    let residual_bytes: Option<Box<[u8]>> =
        has_residual.then(|| read_bin("residual.u8.bin").into_boxed_slice());

    match CompressedVector::new(
        index_bytes.into_boxed_slice(),
        residual_bytes,
        Arc::from(hash_text.as_str()),
        dim,
        bit_width,
    ) {
        Ok(vector) => (vector, expected_bytes),
        Err(e) => panic!("CompressedVector::new failed for {case_id}: {e}"),
    }
}

/// Check one fixture case for byte parity with the Python reference, then round-trip it.
///
/// Steps, in order:
/// 1. Load the case with `load_case` and serialize the vector with `to_bytes`.
/// 2. Require the serialized bytes to equal the contents of `expected.bin`.
/// 3. Decode those bytes with `from_bytes` and compare the decoded indices,
///    residual, dimension and bit width against the inputs.
/// 4. Re-encode the decoded vector and require the bytes to match the first
///    encoding. This covers any serialized field that has no dedicated
///    assertion above (presumably including the config hash — confirm against
///    the wire format).
///
/// # Panics
///
/// Panics if the case cannot be loaded, if `from_bytes` returns an error, or
/// if any of the comparisons fails.
fn assert_parity(case_id: &str, bit_width: u8, dim: u32, has_residual: bool) {
    let (cv, expected) = load_case(case_id, bit_width, dim, has_residual);
    let got = to_bytes(&cv);

    assert_eq!(
        got,
        expected,
        "{case_id}: Rust to_bytes differs from Python expected.bin (len rust={} py={})",
        got.len(),
        expected.len()
    );

    // Also verify round-trip decode
    let cv2 = from_bytes(&got)
        .unwrap_or_else(|e| panic!("{case_id}: from_bytes failed after to_bytes: {e}"));
    assert_eq!(
        cv2.indices(),
        cv.indices(),
        "{case_id}: round-trip indices mismatch"
    );
    assert_eq!(
        cv2.residual(),
        cv.residual(),
        "{case_id}: round-trip residual mismatch"
    );
    assert_eq!(
        cv2.dimension(),
        dim,
        "{case_id}: round-trip dimension mismatch"
    );
    assert_eq!(
        cv2.bit_width(),
        bit_width,
        "{case_id}: round-trip bit_width mismatch"
    );

    // The field-by-field checks above do not look at every input that went
    // into the vector; encoding the decoded value again and comparing bytes
    // closes that gap without relying on further accessors.
    assert_eq!(
        to_bytes(&cv2),
        got,
        "{case_id}: re-encoding the decoded vector changed the serialized bytes"
    );
}

/// Parity for `case_01`: bit width 4, dimension 768, residual present.
#[test]
fn case_01_bw4_dim768_residual() {
    let (bit_width, dim, has_residual) = (4, 768, true);
    assert_parity("case_01", bit_width, dim, has_residual);
}

/// Parity for `case_02`: bit width 2, dimension 768, no residual.
#[test]
fn case_02_bw2_dim768_no_residual() {
    let (bit_width, dim, has_residual) = (2, 768, false);
    assert_parity("case_02", bit_width, dim, has_residual);
}

/// Parity for `case_03`: bit width 8, dimension 768, no residual.
#[test]
fn case_03_bw8_dim768_no_residual() {
    let (bit_width, dim, has_residual) = (8, 768, false);
    assert_parity("case_03", bit_width, dim, has_residual);
}

/// Parity for `case_04`: bit width 4, dimension 1, no residual.
#[test]
fn case_04_bw4_dim1_no_residual() {
    let (bit_width, dim, has_residual) = (4, 1, false);
    assert_parity("case_04", bit_width, dim, has_residual);
}

/// Parity for `case_05`: bit width 2, dimension 17, residual present.
#[test]
fn case_05_bw2_dim17_residual() {
    let (bit_width, dim, has_residual) = (2, 17, true);
    assert_parity("case_05", bit_width, dim, has_residual);
}

/// Parity for `case_06`: bit width 4, dimension 15, no residual.
#[test]
fn case_06_bw4_dim15_no_residual() {
    let (bit_width, dim, has_residual) = (4, 15, false);
    assert_parity("case_06", bit_width, dim, has_residual);
}

/// Parity for `case_07`: bit width 8, dimension 1536, residual present.
#[test]
fn case_07_bw8_dim1536_residual() {
    let (bit_width, dim, has_residual) = (8, 1536, true);
    assert_parity("case_07", bit_width, dim, has_residual);
}

/// Parity for `case_08`: bit width 4, dimension 768, no residual.
#[test]
fn case_08_bw4_dim768_no_residual() {
    let (bit_width, dim, has_residual) = (4, 768, false);
    assert_parity("case_08", bit_width, dim, has_residual);
}

/// Parity for `case_09`: bit width 2, dimension 16, no residual.
#[test]
fn case_09_bw2_dim16_no_residual() {
    let (bit_width, dim, has_residual) = (2, 16, false);
    assert_parity("case_09", bit_width, dim, has_residual);
}

/// Parity for `case_10`: bit width 4, dimension 16, residual present.
#[test]
fn case_10_bw4_dim16_residual() {
    let (bit_width, dim, has_residual) = (4, 16, true);
    assert_parity("case_10", bit_width, dim, has_residual);
}