// Unit tests for `SyntheticWeightGenerator` and `TokenGenerator`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the seed-42 weight generator shared by most tests.
    ///
    /// Deliberately not named `gen`: `gen` is a reserved keyword in the Rust
    /// 2024 edition, so a function with that name stops compiling there (and
    /// trips the `keyword_idents_2024` lint on earlier editions).
    fn seeded_gen() -> SyntheticWeightGenerator {
        SyntheticWeightGenerator::new(42)
    }

    /// Two generators built from the same seed must emit identical weights.
    #[test]
    fn test_generator_deterministic() {
        let gen1 = SyntheticWeightGenerator::new(42);
        let gen2 = SyntheticWeightGenerator::new(42);
        let w1 = gen1.generate_f32(&[10, 10]);
        let w2 = gen2.generate_f32(&[10, 10]);
        assert_eq!(w1, w2, "Same seed should produce same weights");
    }

    /// Generators built from different seeds must not emit identical weights.
    #[test]
    fn test_generator_different_seeds() {
        let gen1 = SyntheticWeightGenerator::new(42);
        let gen2 = SyntheticWeightGenerator::new(43);
        let w1 = gen1.generate_f32(&[10, 10]);
        let w2 = gen2.generate_f32(&[10, 10]);
        assert_ne!(w1, w2, "Different seeds should produce different weights");
    }

    /// Byte sizes of the directly-called block-quantized generators.
    ///
    /// The expected sizes pin the per-block footprint: 18 bytes per 32
    /// elements for Q4_0 (scaling linearly with block count) and 34 bytes per
    /// 32 elements for Q8_0.
    #[test]
    fn test_quant_block_sizes() {
        // Alias keeps the case-table element type readable (clippy::type_complexity).
        type QuantFn = fn(&SyntheticWeightGenerator, usize) -> Vec<u8>;

        let g = seeded_gen();
        let cases: &[(&str, QuantFn, usize, usize)] = &[
            ("Q4_0 1-block", SyntheticWeightGenerator::generate_q4_0, 32, 18),
            ("Q4_0 2-blocks", SyntheticWeightGenerator::generate_q4_0, 64, 36),
            ("Q8_0 1-block", SyntheticWeightGenerator::generate_q8_0, 32, 34),
        ];
        for &(label, generate_fn, num_elements, expected) in cases {
            let data = generate_fn(&g, num_elements);
            assert_eq!(
                data.len(),
                expected,
                "{label}: expected {expected} bytes for {num_elements} elements"
            );
        }
    }

    /// Byte sizes produced by `generate_quant` for each `QuantType` it dispatches on.
    #[test]
    fn test_quant_dispatch() {
        let g = seeded_gen();
        let cases: &[(QuantType, usize, usize)] = &[
            // Dense formats: 4 bytes/element for F32, 2 bytes/element for F16 and BF16.
            (QuantType::F32, 100, 400),
            (QuantType::F16, 100, 200),
            (QuantType::BF16, 100, 200),
            (QuantType::Q4_K, 256, 144),
            // NOTE(review): these expect 32-element blocks of 22 / 34 bytes,
            // whereas GGML's Q5_K / Q6_K use 256-element super-blocks
            // (176 / 210 bytes). Confirm the generator's approximation is intended.
            (QuantType::Q5_K, 32, 22),
            (QuantType::Q6_K, 32, 34),
        ];
        for &(quant, num_elements, expected) in cases {
            let data = g.generate_quant(num_elements, quant);
            assert_eq!(
                data.len(),
                expected,
                "{quant:?}: expected {expected} bytes for {num_elements} elements"
            );
        }
    }

    /// A quantized model has one weight entry per layer and a hidden-dim output norm.
    #[test]
    fn test_model_weights_generation() {
        let config = ModelConfig::tiny();
        let weights = seeded_gen().generate_model_weights(&config, QuantType::Q4_0);
        assert_eq!(
            weights.layer_weights.len(),
            config.num_layers,
            "expected one layer-weight entry per configured layer"
        );
        assert_eq!(
            weights.output_norm.len(),
            config.hidden_dim,
            "output norm length should equal the hidden dimension"
        );
        assert!(weights.total_bytes() > 0, "generated model should not be empty");
    }

    /// F32 model weights report a non-zero size and the config's parameter count.
    #[test]
    fn test_model_weights_metrics() {
        let config = ModelConfig::tiny();
        let weights = seeded_gen().generate_model_weights(&config, QuantType::F32);
        assert!(weights.total_bytes() > 0, "generated model should not be empty");
        assert_eq!(
            weights.param_count(),
            config.param_count(),
            "weight parameter count should match the config"
        );
    }

    /// Generated tokens have the requested count and fall inside `1..256`.
    #[test]
    fn test_token_generator() {
        let tg = TokenGenerator::new(42, 256);
        let tokens = tg.generate(10);
        assert_eq!(tokens.len(), 10);
        // The lower bound excludes token 0; presumably 0 is a reserved id —
        // confirm against `TokenGenerator::generate`.
        for &t in &tokens {
            assert!((1..256).contains(&t), "token {t} outside 1..256");
        }
    }

    /// Same seed and vocab size yield the same token sequence.
    #[test]
    fn test_token_generator_deterministic() {
        let gen1 = TokenGenerator::new(42, 256);
        let gen2 = TokenGenerator::new(42, 256);
        assert_eq!(gen1.generate(10), gen2.generate(10));
    }

    /// Biased generation keeps the requested length and stays below the vocab size.
    #[test]
    fn test_token_generator_distribution() {
        let vocab_size = 1000;
        let tg = TokenGenerator::new(42, vocab_size);
        // An array suffices: the parameter is borrowed as a slice.
        let common = [1, 2, 3];
        let tokens = tg.generate_with_distribution(100, &common);
        assert_eq!(tokens.len(), 100);
        let limit = vocab_size as u32;
        for &t in &tokens {
            assert!(t < limit, "token {t} outside vocab of size {vocab_size}");
        }
    }

    /// An empty "common tokens" list must still produce the requested count.
    #[test]
    fn test_token_generator_distribution_empty_common() {
        let tg = TokenGenerator::new(42, 100);
        let tokens = tg.generate_with_distribution(10, &[]);
        assert_eq!(tokens.len(), 10);
    }

    /// F16 generation returns the requested count, all finite.
    #[test]
    fn test_f16_generation() {
        let weights = seeded_gen().generate_f16(&[100]);
        assert_eq!(weights.len(), 100);
        assert!(
            weights.iter().all(|w| w.is_finite()),
            "all f16 weights should be finite"
        );
    }

    /// Scaled F32 weights stay within the closed interval `[-scale, scale]`.
    #[test]
    fn test_generate_f32_scaled() {
        let scale = 10.0;
        let weights = seeded_gen().generate_f32_scaled(&[100], scale);
        assert_eq!(weights.len(), 100);
        for &w in &weights {
            assert!((-scale..=scale).contains(&w), "weight {w} outside [-{scale}, {scale}]");
        }
    }

    /// An empty shape is treated as a scalar: exactly one element is generated.
    #[test]
    fn test_f32_empty_shape() {
        let weights = seeded_gen().generate_f32(&[]);
        assert_eq!(weights.len(), 1, "empty shape should yield a single (scalar) element");
    }
}