use std::fs;
use std::path::PathBuf;
use tinyquant_core::codec::RotationMatrix;
/// Fixture directory, relative to the crate manifest directory.
const FIXTURE_DIR: &str = "tests/fixtures/rotation";

/// Loads a raw fixture file of little-endian `f64` words from `FIXTURE_DIR`.
///
/// The file must hold exactly `expected_dim * expected_dim` words
/// (8 bytes each); the words are returned in file order.
///
/// # Panics
///
/// Panics if the file cannot be read, or if its byte length is not exactly
/// `expected_dim * expected_dim * 8`.
fn load_fixture(name: &str, expected_dim: usize) -> Vec<f64> {
    let fixture_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .join(FIXTURE_DIR)
        .join(name);
    let raw = match fs::read(&fixture_path) {
        Ok(raw) => raw,
        Err(e) => panic!("failed to read {}: {e}", fixture_path.display()),
    };
    let word_count = expected_dim * expected_dim;
    assert_eq!(raw.len(), word_count * 8, "fixture {name} wrong size");
    // chunks_exact(8) only yields full 8-byte words, and the length check
    // above guarantees there is no remainder.
    raw.chunks_exact(8)
        .map(|word| {
            let mut le_bytes = [0u8; 8];
            le_bytes.copy_from_slice(word);
            f64::from_le_bytes(le_bytes)
        })
        .collect()
}
/// Rebuilds the (seed = 42, dim = 64) rotation and requires every word to
/// equal the frozen fixture. The comparison uses `to_bits`, so it is an exact
/// bit-pattern match rather than a numeric `==`.
#[test]
fn seed_42_dim_64_matches_frozen_snapshot_bit_for_bit() {
    let snapshot = load_fixture("seed_42_dim_64.f64.bin", 64);
    let built = RotationMatrix::build(42, 64);
    let words = built.matrix();
    assert_eq!(words.len(), snapshot.len());
    words
        .iter()
        .zip(snapshot.iter())
        .enumerate()
        .for_each(|(i, (a, e))| {
            assert_eq!(
                a.to_bits(),
                e.to_bits(),
                "fixture mismatch at index {i}: {a} vs {e}"
            );
        });
}
/// Sanity-checks the stored dim-64 fixture itself. Treating it as 64
/// consecutive rows of 64 words, every pair of rows must have a dot product
/// within 1e-12 of the corresponding identity entry (1 on the diagonal,
/// 0 elsewhere).
#[test]
fn seed_42_dim_64_fixture_is_orthogonal_within_1e_12() {
    const DIM: usize = 64;
    let expected = load_fixture("seed_42_dim_64.f64.bin", DIM);
    for (i, row_i) in expected.chunks_exact(DIM).enumerate() {
        for (j, row_j) in expected.chunks_exact(DIM).enumerate() {
            // Sequential left-to-right accumulation, same order as k = 0..DIM.
            let acc = row_i
                .iter()
                .zip(row_j.iter())
                .fold(0.0f64, |sum, (x, y)| sum + x * y);
            let target = if i == j { 1.0 } else { 0.0 };
            assert!(
                (acc - target).abs() < 1e-12,
                "loaded fixture not orthogonal at ({i},{j}): {acc}"
            );
        }
    }
}
/// Rebuilds the (seed = 42, dim = 768) rotation and requires every word to
/// match the frozen fixture bit-for-bit (compared via `to_bits`).
///
/// Ignored by default per the attribute's reason string. When run manually
/// and it fails, the message gives the total number of differing words and
/// locates the first one (flat index, row, column, both values and their raw
/// bit patterns), so a divergence can be pinned down rather than only counted.
#[test]
#[ignore = "cross-runner SIMD ISA nondeterminism at dim=768; see R19"]
fn seed_42_dim_768_matches_frozen_snapshot_bit_for_bit() {
    const DIM: usize = 768;
    let expected = load_fixture("seed_42_dim_768.f64.bin", DIM);
    let rot = RotationMatrix::build(42, 768);
    let actual = rot.matrix();
    assert_eq!(actual.len(), expected.len());

    // Lazily yields (flat index, actual bits, expected bits) for each
    // differing word; the first one is kept for the report and the rest are
    // only counted.
    let mut diffs = actual
        .iter()
        .zip(expected.iter())
        .enumerate()
        .filter(|(_, (a, e))| a.to_bits() != e.to_bits())
        .map(|(idx, (a, e))| (idx, a.to_bits(), e.to_bits()));

    if let Some((idx, a_bits, e_bits)) = diffs.next() {
        let mismatches = 1 + diffs.count();
        panic!(
            "{mismatches} f64 words differ between build and fixture for (42, 768); \
             first at index {idx} (row {row}, col {col}): \
             {a} ({a_bits:#018x}) vs {e} ({e_bits:#018x})",
            row = idx / DIM,
            col = idx % DIM,
            a = f64::from_bits(a_bits),
            e = f64::from_bits(e_bits),
        );
    }
}
/// Checks orthogonality of a freshly built (seed = 42, dim = 768) rotation on
/// a fixed sample of row pairs instead of all 768 x 768 of them. Diagonal
/// pairs must have a dot product within 1e-12 of 1, and off-diagonal pairs
/// within 1e-12 of 0.
#[test]
fn seed_42_dim_768_build_is_orthogonal_within_1e_12() {
    // Mix of diagonal entries, adjacent rows, and distant rows, including the
    // first (0) and last (767) row indices.
    const SAMPLED_PAIRS: &[(usize, usize)] = &[
        (0, 0),
        (0, 1),
        (0, 7),
        (0, 100),
        (0, 767),
        (1, 1),
        (1, 2),
        (1, 500),
        (7, 7),
        (7, 8),
        (7, 256),
        (100, 100),
        (100, 101),
        (100, 300),
        (255, 255),
        (255, 256),
        (255, 767),
        (383, 383),
        (383, 384),
        (383, 500),
        (500, 500),
        (500, 501),
        (500, 767),
        (700, 700),
        (700, 701),
        (767, 767),
    ];
    let rot = RotationMatrix::build(42, 768);
    let m = rot.matrix();
    let dim = 768usize;
    for &(i, j) in SAMPLED_PAIRS {
        // Row-major dot product of rows i and j, accumulated in k order.
        let acc = (0..dim).fold(0.0f64, |sum, k| sum + m[i * dim + k] * m[j * dim + k]);
        let target = if i == j { 1.0 } else { 0.0 };
        let deviation = (acc - target).abs();
        assert!(deviation < 1e-12, "non-orthogonal at ({i},{j}): {acc}");
    }
}