use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use serde_json::Value;
use tinyquant_core::codec::{CodecConfig, SUPPORTED_BIT_WIDTHS};
use tinyquant_core::errors::CodecError;
// JSON fixture embedded at compile time. It holds an `entries` array whose
// items carry `bit_width`, `seed`, `dimension`, `residual_enabled` and the
// expected `config_hash`. The hashes are presumably generated by the Python
// reference implementation (per the test names below) — TODO confirm the generator.
const FIXTURE_JSON: &str = include_str!("fixtures/config_hashes.json");
/// Hashes `value` with a freshly constructed `DefaultHasher` and returns the digest.
///
/// Used to check that values comparing equal also hash equally (the `Eq`/`Hash`
/// consistency contract).
fn hash_of<T: Hash>(value: &T) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}
#[test]
fn new_bw2_ok() {
    // A 2-bit config must construct and expose a 4-entry codebook.
    let config = CodecConfig::new(2, 0, 32, false).expect("bw=2 must succeed");
    let (width, entries) = (config.bit_width(), config.num_codebook_entries());
    assert_eq!(width, 2);
    assert_eq!(entries, 4);
}
#[test]
fn new_bw4_ok() {
    // A 4-bit config (with seed and residual enabled) must construct and expose 16 entries.
    let config = CodecConfig::new(4, 42, 768, true).expect("bw=4 must succeed");
    let (width, entries) = (config.bit_width(), config.num_codebook_entries());
    assert_eq!(width, 4);
    assert_eq!(entries, 16);
}
#[test]
fn new_bw8_ok() {
    // An 8-bit config must construct and expose a 256-entry codebook.
    let config = CodecConfig::new(8, 7, 1536, true).expect("bw=8 must succeed");
    let (width, entries) = (config.bit_width(), config.num_codebook_entries());
    assert_eq!(width, 8);
    assert_eq!(entries, 256);
}
#[test]
fn new_bw3_rejects_with_unsupported() {
    // A width outside the supported set is rejected and the offending value is echoed back.
    let wanted = CodecError::UnsupportedBitWidth { got: 3 };
    let rejection = CodecConfig::new(3, 0, 32, false).expect_err("bw=3 must fail");
    assert_eq!(rejection, wanted);
}
#[test]
fn new_dim_zero_rejects_with_invalid_dimension() {
    // A zero dimension is invalid even when the bit width is supported.
    let wanted = CodecError::InvalidDimension { got: 0 };
    let rejection = CodecConfig::new(4, 0, 0, false).expect_err("dim=0 must fail");
    assert_eq!(rejection, wanted);
}
#[test]
fn supported_bit_widths_constant_is_2_4_8() {
    // Pins the exported list of widths, including its order.
    let widths = SUPPORTED_BIT_WIDTHS;
    assert_eq!(widths, &[2, 4, 8]);
}
#[test]
fn num_codebook_entries_table() {
    // Codebook size for each supported width, with the other inputs held fixed.
    let entries_for = |bw: u8| {
        CodecConfig::new(bw, 0, 8, false)
            .unwrap()
            .num_codebook_entries()
    };
    assert_eq!(entries_for(2), 4);
    assert_eq!(entries_for(4), 16);
    assert_eq!(entries_for(8), 256);
}
#[test]
fn config_hash_matches_python_spot_check_bw4_seed42_dim768_res_on() {
    // A single known-good hash, checked independently of the JSON fixture sweep.
    const WANT: &str = "cc50f0a21077b1971aadb240e4209ff1acc3b9da32233100e481f03b5161fbed";
    let config = CodecConfig::new(4, 42, 768, true).unwrap();
    let digest = config.config_hash();
    assert_eq!(digest.as_ref(), WANT);
}
/// Sweeps every entry of the embedded fixture and checks that `config_hash()`
/// reproduces the recorded hash for the same construction inputs.
///
/// Fixture values are range-checked rather than `as`-cast. An out-of-range
/// value, such as a `bit_width` of 260, must fail loudly instead of silently
/// truncating to a different (possibly valid) input. Every malformed field
/// reports its triple index and key.
#[test]
fn config_hash_matches_python_all_120_triples() {
    let root: Value = serde_json::from_str(FIXTURE_JSON).expect("fixture parses");
    let entries = root
        .get("entries")
        .and_then(Value::as_array)
        .expect("entries array present");
    assert_eq!(entries.len(), 120, "fixture sweep must cover 120 triples");
    for (i, entry) in entries.iter().enumerate() {
        // Reads a non-negative integer field, naming the triple and key on failure.
        let field_u64 = |key: &str| -> u64 {
            entry[key]
                .as_u64()
                .unwrap_or_else(|| panic!("triple {i}: `{key}` missing or not a u64"))
        };
        let bit_width = u8::try_from(field_u64("bit_width"))
            .unwrap_or_else(|_| panic!("triple {i}: `bit_width` does not fit in u8"));
        let seed = field_u64("seed");
        let dimension = u32::try_from(field_u64("dimension"))
            .unwrap_or_else(|_| panic!("triple {i}: `dimension` does not fit in u32"));
        let residual_enabled = entry["residual_enabled"]
            .as_bool()
            .unwrap_or_else(|| panic!("triple {i}: `residual_enabled` missing or not a bool"));
        let expected = entry["config_hash"]
            .as_str()
            .unwrap_or_else(|| panic!("triple {i}: `config_hash` missing or not a string"));
        let cfg = CodecConfig::new(bit_width, seed, dimension, residual_enabled)
            .unwrap_or_else(|e| panic!("triple {i} rejected: {e}"));
        assert_eq!(
            cfg.config_hash().as_ref(),
            expected,
            "triple {i}: bw={bit_width} seed={seed} dim={dimension} res={residual_enabled}"
        );
    }
}
#[test]
fn equality_ignores_cached_hash_field() {
    // Two independently built configs from identical inputs must be equal and hash equally.
    let build = || CodecConfig::new(4, 42, 768, true).unwrap();
    let (first, second) = (build(), build());
    assert_eq!(first, second);
    assert_eq!(hash_of(&first), hash_of(&second));
}
#[test]
fn inequality_when_any_field_differs() {
    let base = CodecConfig::new(4, 42, 768, true).unwrap();
    // Each variant differs from `base` in exactly one construction input.
    let variants = [
        CodecConfig::new(2, 42, 768, true).unwrap(),  // bit_width
        CodecConfig::new(4, 43, 768, true).unwrap(),  // seed
        CodecConfig::new(4, 42, 769, true).unwrap(),  // dimension
        CodecConfig::new(4, 42, 768, false).unwrap(), // residual_enabled
    ];
    for other in &variants {
        assert_ne!(&base, other);
    }
}
#[test]
fn accessors_report_construction_inputs() {
    // Each getter must hand back exactly what was passed to `new`.
    let config = CodecConfig::new(8, 99, 1536, false).unwrap();
    let observed = (
        config.bit_width(),
        config.seed(),
        config.dimension(),
        config.residual_enabled(),
    );
    assert_eq!(observed, (8, 99, 1536, false));
}
/// `Debug` output must mention every primary construction input by name and
/// show the distinctive seed and dimension values used here.
///
/// Assumes the `Debug` impl labels fields with the same names as the public
/// accessors (`bit_width` is already relied on below). Adjust the list if the
/// impl uses other labels.
#[test]
fn debug_includes_all_primary_fields() {
    let cfg = CodecConfig::new(4, 42, 768, true).unwrap();
    let dbg = format!("{cfg:?}");
    for field in ["bit_width", "seed", "dimension", "residual_enabled"] {
        assert!(dbg.contains(field), "Debug output missing `{field}`: {dbg}");
    }
    assert!(dbg.contains("42"), "Debug output missing seed value: {dbg}");
    assert!(dbg.contains("768"), "Debug output missing dimension value: {dbg}");
}