use std::path::Path;
use tinyquant_core::codec::{Codebook, Codec, CodecConfig, CompressedVector, Parallelism};
/// Loads the shared codebook training fixture as a flat `Vec<f32>`.
///
/// The file `tests/fixtures/codebook/training_n10000_d64.f32.bin`, resolved
/// relative to the crate root, is decoded as consecutive little-endian `f32`
/// values of four bytes each.
///
/// # Panics
///
/// Panics if the fixture cannot be read, or if its length is not a multiple of
/// four bytes. Without that check, `chunks_exact` would silently drop a trailing
/// partial value and a truncated fixture would go unnoticed.
fn load_training() -> Vec<f32> {
    let p = Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("tests/fixtures/codebook/training_n10000_d64.f32.bin");
    // The io::Error text alone does not say which file failed, so append the path.
    let bytes = std::fs::read(&p)
        .unwrap_or_else(|e| panic!("training fixture missing: {e} ({})", p.display()));
    let chunks = bytes.chunks_exact(4);
    assert!(
        chunks.remainder().is_empty(),
        "training fixture {} has {} bytes, which is not a multiple of 4",
        p.display(),
        bytes.len()
    );
    chunks
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
/// Trains a codebook on `training`, then compresses the first 64 x 64 values.
///
/// The codebook is trained with `CodecConfig::new(4, 42, 64, true)`.
/// `n_threads == 1` takes the serial path (`Parallelism::Serial`). Any other value
/// builds a local rayon pool with `num_threads(n_threads)` and compresses inside it
/// through `Parallelism::Custom`. This lets callers compare threaded output against
/// the serial reference.
///
/// # Panics
///
/// Panics if config construction, training, pool construction or compression
/// fails, or if `training` holds fewer than 64 * 64 values.
fn full_pipeline(training: &[f32], n_threads: usize) -> Vec<CompressedVector> {
    const ROWS: usize = 64;
    const COLS: usize = 64;

    let config = CodecConfig::new(4, 42, 64, true).unwrap();
    let codebook = Codebook::train(training, &config).unwrap();
    let codec = Codec::new();
    let batch = &training[..ROWS * COLS];

    if n_threads != 1 {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(n_threads)
            .build()
            .expect("failed to build local rayon pool");
        return pool.install(|| {
            // Each of the `count` work items goes through rayon. Inside `install`,
            // rayon's parallel iterators run on `pool` rather than the global pool.
            let parallelism = Parallelism::Custom(|count, body| {
                use rayon::prelude::*;
                (0..count).into_par_iter().for_each(body);
            });
            codec
                .compress_batch_with(batch, ROWS, COLS, &config, &codebook, parallelism)
                .unwrap()
        });
    }

    codec
        .compress_batch_with(batch, ROWS, COLS, &config, &codebook, Parallelism::Serial)
        .unwrap()
}
/// Asserts that two compressed batches are equal row by row.
///
/// The batch lengths are checked first. For each row, only `indices()` and
/// `residual()` are compared; no other part of `CompressedVector` is checked here.
/// `label` prefixes every failure message to identify the failing comparison.
///
/// # Panics
///
/// Panics on the first length, index or residual mismatch.
fn cvs_eq(a: &[CompressedVector], b: &[CompressedVector], label: &str) {
    assert_eq!(a.len(), b.len(), "{label}: length mismatch");
    a.iter().zip(b).enumerate().for_each(|(i, (lhs, rhs))| {
        assert_eq!(
            lhs.indices(),
            rhs.indices(),
            "{label}: indices mismatch at row {i}"
        );
        assert_eq!(
            lhs.residual(),
            rhs.residual(),
            "{label}: residual mismatch at row {i}"
        );
    });
}
/// Checks that the threaded pipeline is reproducible. Ten runs with the same
/// thread count (one baseline plus nine repeats) must all produce identical output.
#[test]
fn full_pipeline_deterministic_across_repeat_runs() {
    let training = load_training();
    // Force at least two workers so `full_pipeline` takes its rayon-pool branch
    // rather than the `n_threads == 1` serial branch.
    let n_threads = rayon::current_num_threads().max(2);
    let baseline = full_pipeline(&training, n_threads);
    (1..10).for_each(|run| {
        let repeat = full_pipeline(&training, n_threads);
        cvs_eq(&baseline, &repeat, &format!("repeat-run {run} (t={n_threads})"));
    });
}
/// Checks that threaded output matches the serial reference for several worker
/// counts: 2, 4, 8 and `n_cpu`.
///
/// `n_cpu` is the global rayon pool size, clamped to at least 2. Counts larger
/// than `n_cpu` are skipped, and each remaining count is run exactly once.
#[test]
fn full_pipeline_deterministic_across_thread_counts() {
    let training = load_training();
    let reference = full_pipeline(&training, 1);
    let n_cpu = rayon::current_num_threads().max(2);
    // After filtering, `n_cpu` is the largest entry, so the list stays sorted.
    // `dedup` then removes the duplicate when `n_cpu` is itself 2, 4 or 8;
    // without it that count would be trained and compressed twice.
    let mut thread_counts: Vec<usize> = [2_usize, 4, 8, n_cpu]
        .iter()
        .copied()
        .filter(|&t| t <= n_cpu)
        .collect();
    thread_counts.dedup();
    for t in thread_counts {
        let out = full_pipeline(&training, t);
        cvs_eq(&reference, &out, &format!("t={t}"));
    }
}