use std::path::Path;
use std::sync::Arc;
use tinyquant_core::codec::CompressedVector;
use tinyquant_io::compressed_vector::{from_bytes, to_bytes};
/// Root directory of the parity fixtures, resolved at compile time relative to this
/// crate's manifest directory so the path does not depend on the test's working directory.
const FIXTURE_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures");

/// Reads a fixture file into memory, panicking with the offending path on failure.
fn read_fixture(path: &Path) -> Vec<u8> {
    std::fs::read(path).unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()))
}

/// Loads fixture `case_id` and builds the `CompressedVector` it describes.
///
/// Files are read from `FIXTURE_DIR/<case_id>/`:
/// - `indices.u8.bin` — raw index bytes,
/// - `config_hash.txt` — the config hash string,
/// - `residual.u8.bin` — residual bytes (read only when `has_residual` is true),
/// - `expected.bin` — the reference serialization produced on the Python side.
///
/// Returns the constructed vector together with the raw `expected.bin` bytes.
///
/// # Panics
///
/// Panics if a required fixture file cannot be read, if `config_hash.txt` is not
/// valid UTF-8, or if `CompressedVector::new` rejects the loaded inputs.
fn load_case(
    case_id: &str,
    bit_width: u8,
    dim: u32,
    has_residual: bool,
) -> (CompressedVector, Vec<u8>) {
    let base = Path::new(FIXTURE_DIR).join(case_id);

    let indices = read_fixture(&base.join("indices.u8.bin"));

    let config_hash_path = base.join("config_hash.txt");
    let config_hash = std::fs::read_to_string(&config_hash_path)
        .unwrap_or_else(|e| panic!("failed to read {}: {e}", config_hash_path.display()));
    // Text fixtures commonly pick up a trailing newline (editors, `echo`, git CRLF
    // conversion). Left in place it would become part of the hash and break byte
    // parity with expected.bin. Assumes a real config hash never ends in CR/LF.
    let config_hash = config_hash.trim_end_matches(['\r', '\n']);

    let expected = read_fixture(&base.join("expected.bin"));

    let residual: Option<Box<[u8]>> = has_residual
        .then(|| read_fixture(&base.join("residual.u8.bin")).into_boxed_slice());

    let cv = CompressedVector::new(
        indices.into_boxed_slice(),
        residual,
        Arc::from(config_hash),
        dim,
        bit_width,
    )
    .unwrap_or_else(|e| panic!("CompressedVector::new failed for {case_id}: {e}"));

    (cv, expected)
}
/// Asserts that Rust's `to_bytes` output for fixture `case_id` is byte-identical to the
/// fixture's `expected.bin`. It then checks that `from_bytes` round-trips that output.
///
/// The round-trip compares indices, residual, dimension and bit width.
///
/// # Panics
///
/// Panics (failing the calling test) if the fixture cannot be loaded, if the serialized
/// bytes differ from `expected.bin`, or if any round-tripped field does not match.
fn assert_parity(case_id: &str, bit_width: u8, dim: u32, has_residual: bool) {
    let (cv, expected) = load_case(case_id, bit_width, dim, has_residual);
    let got = to_bytes(&cv);

    if got != expected {
        // Report where the buffers diverge instead of letting `assert_eq!` dump both
        // (potentially large) buffers in full. When one buffer is a strict prefix of
        // the other, the divergence is at the end of the shorter one.
        let offset = got
            .iter()
            .zip(expected.iter())
            .position(|(a, b)| a != b)
            .unwrap_or_else(|| got.len().min(expected.len()));
        panic!(
            "{case_id}: Rust to_bytes differs from Python expected.bin at byte {offset} \
             (rust={:?} py={:?}; len rust={} py={})",
            got.get(offset),
            expected.get(offset),
            got.len(),
            expected.len()
        );
    }

    let cv2 = from_bytes(&got)
        .unwrap_or_else(|e| panic!("{case_id}: from_bytes failed after to_bytes: {e}"));
    assert_eq!(
        cv2.indices(),
        cv.indices(),
        "{case_id}: round-trip indices mismatch"
    );
    assert_eq!(
        cv2.residual(),
        cv.residual(),
        "{case_id}: round-trip residual mismatch"
    );
    assert_eq!(
        cv2.dimension(),
        dim,
        "{case_id}: round-trip dimension mismatch"
    );
    assert_eq!(
        cv2.bit_width(),
        bit_width,
        "{case_id}: round-trip bit_width mismatch"
    );
}
/// Parity for fixture `case_01`: bit width 4, dimension 768, residual present.
#[test]
fn case_01_bw4_dim768_residual() {
    let (bit_width, dim, has_residual) = (4, 768, true);
    assert_parity("case_01", bit_width, dim, has_residual);
}
/// Parity for fixture `case_02`: bit width 2, dimension 768, no residual.
#[test]
fn case_02_bw2_dim768_no_residual() {
    let (bit_width, dim, has_residual) = (2, 768, false);
    assert_parity("case_02", bit_width, dim, has_residual);
}
/// Parity for fixture `case_03`: bit width 8, dimension 768, no residual.
#[test]
fn case_03_bw8_dim768_no_residual() {
    let (bit_width, dim, has_residual) = (8, 768, false);
    assert_parity("case_03", bit_width, dim, has_residual);
}
/// Parity for fixture `case_04`: bit width 4, dimension 1, no residual.
#[test]
fn case_04_bw4_dim1_no_residual() {
    let (bit_width, dim, has_residual) = (4, 1, false);
    assert_parity("case_04", bit_width, dim, has_residual);
}
/// Parity for fixture `case_05`: bit width 2, dimension 17, residual present.
#[test]
fn case_05_bw2_dim17_residual() {
    let (bit_width, dim, has_residual) = (2, 17, true);
    assert_parity("case_05", bit_width, dim, has_residual);
}
/// Parity for fixture `case_06`: bit width 4, dimension 15, no residual.
#[test]
fn case_06_bw4_dim15_no_residual() {
    let (bit_width, dim, has_residual) = (4, 15, false);
    assert_parity("case_06", bit_width, dim, has_residual);
}
/// Parity for fixture `case_07`: bit width 8, dimension 1536, residual present.
#[test]
fn case_07_bw8_dim1536_residual() {
    let (bit_width, dim, has_residual) = (8, 1536, true);
    assert_parity("case_07", bit_width, dim, has_residual);
}
/// Parity for fixture `case_08`: bit width 4, dimension 768, no residual.
#[test]
fn case_08_bw4_dim768_no_residual() {
    let (bit_width, dim, has_residual) = (4, 768, false);
    assert_parity("case_08", bit_width, dim, has_residual);
}
/// Parity for fixture `case_09`: bit width 2, dimension 16, no residual.
#[test]
fn case_09_bw2_dim16_no_residual() {
    let (bit_width, dim, has_residual) = (2, 16, false);
    assert_parity("case_09", bit_width, dim, has_residual);
}
/// Parity for fixture `case_10`: bit width 4, dimension 16, residual present.
#[test]
fn case_10_bw4_dim16_residual() {
    let (bit_width, dim, has_residual) = (4, 16, true);
    assert_parity("case_10", bit_width, dim, has_residual);
}