use super::super::import;
use super::super::*;
use std::collections::BTreeMap;
/// Build a placeholder tensor for shape-inference tests: a flat buffer of
/// `0.1`-valued f32s whose length is the product of `shape`, paired with the
/// shape itself.
fn dummy_tensor(shape: Vec<usize>) -> (Vec<f32>, Vec<usize>) {
    let element_count: usize = shape.iter().copied().product();
    let data = vec![0.1; element_count];
    (data, shape)
}
#[test]
fn test_infer_config_huggingface_naming() {
    // HuggingFace-style names ("model.embed_tokens", "model.layers.N.*"):
    // every config field should be inferred from the tensor shapes alone.
    let mut tensors = BTreeMap::from([(
        "model.embed_tokens.weight".to_string(),
        dummy_tensor(vec![1000, 128]),
    )]);
    for layer in 0..4 {
        let per_layer = [
            (
                format!("model.layers.{layer}.self_attn.q_proj.weight"),
                vec![128, 128],
            ),
            (
                format!("model.layers.{layer}.self_attn.q_proj.bias"),
                vec![128],
            ),
            (
                format!("model.layers.{layer}.self_attn.k_proj.weight"),
                vec![128, 128],
            ),
            (
                format!("model.layers.{layer}.mlp.gate_proj.weight"),
                vec![512, 128],
            ),
        ];
        for (name, shape) in per_layer {
            tensors.insert(name, dummy_tensor(shape));
        }
    }
    let config = infer_model_config_from_tensors(&tensors);
    assert!(config.is_some(), "Should infer config from HF tensors");
    let config = config.expect("config should be Some");
    assert_eq!(config.vocab_size, Some(1000));
    assert_eq!(config.hidden_size, Some(128));
    assert_eq!(config.num_layers, Some(4));
    assert_eq!(config.intermediate_size, Some(512));
    assert_eq!(
        config.architecture.as_deref(),
        Some("qwen2"),
        "model.layers pattern should detect qwen2"
    );
    assert!(config.num_heads.is_some());
    assert!(config.num_kv_heads.is_some());
}
#[test]
fn test_infer_config_gguf_naming() {
    // GGUF-style names ("token_embd", "blk.N.*"). The embedding shape here is
    // [hidden, vocab] — the opposite order of the HF test above — and the
    // asserts confirm the dimensions are still assigned correctly.
    let mut tensors = BTreeMap::from([(
        "token_embd.weight".to_string(),
        dummy_tensor(vec![64, 500]),
    )]);
    for block in 0..2 {
        let per_block = [
            (format!("blk.{block}.attn_q.weight"), vec![64, 64]),
            (format!("blk.{block}.attn_k.weight"), vec![64, 64]),
            (format!("blk.{block}.ffn_gate.weight"), vec![256, 64]),
        ];
        for (name, shape) in per_block {
            tensors.insert(name, dummy_tensor(shape));
        }
    }
    let config = infer_model_config_from_tensors(&tensors);
    assert!(config.is_some(), "Should infer config from GGUF tensors");
    let config = config.expect("config should be Some");
    assert_eq!(config.vocab_size, Some(500));
    assert_eq!(config.hidden_size, Some(64));
    assert_eq!(config.num_layers, Some(2));
    assert_eq!(config.intermediate_size, Some(256));
    assert_eq!(
        config.architecture.as_deref(),
        Some("unknown"),
        "blk. pattern should detect unknown (ambiguous GGUF naming)"
    );
}
#[test]
fn test_infer_config_gpt2_naming() {
    // GPT-2 style names: "wte.weight" embedding plus "transformer.h.N.*"
    // layers should be detected as the "gpt2" architecture.
    let mut tensors =
        BTreeMap::from([("wte.weight".to_string(), dummy_tensor(vec![5000, 256]))]);
    for layer in 0..6 {
        let per_layer = [
            (
                format!("transformer.h.{layer}.attn.query.weight"),
                vec![256, 256],
            ),
            (
                format!("transformer.h.{layer}.attn.key.weight"),
                vec![256, 256],
            ),
        ];
        for (name, shape) in per_layer {
            tensors.insert(name, dummy_tensor(shape));
        }
    }
    let config = infer_model_config_from_tensors(&tensors);
    assert!(config.is_some());
    let config = config.expect("config should be Some");
    assert_eq!(config.vocab_size, Some(5000));
    assert_eq!(config.hidden_size, Some(256));
    assert_eq!(config.num_layers, Some(6));
    assert_eq!(
        config.architecture.as_deref(),
        Some("gpt2"),
        "transformer.h pattern should detect gpt2"
    );
}
#[test]
fn test_infer_config_gqa_model() {
    // Grouped-query attention: the k_proj output dim (64) is half the q_proj
    // output dim (128), so twice as many query heads as KV heads must be
    // inferred.
    let tensors = BTreeMap::from([
        (
            "model.embed_tokens.weight".to_string(),
            dummy_tensor(vec![1000, 128]),
        ),
        (
            "model.layers.0.self_attn.q_proj.weight".to_string(),
            dummy_tensor(vec![128, 128]),
        ),
        (
            "model.layers.0.self_attn.k_proj.weight".to_string(),
            dummy_tensor(vec![64, 128]),
        ),
    ]);
    let config = infer_model_config_from_tensors(&tensors);
    assert!(config.is_some());
    let config = config.expect("config should be Some");
    assert_eq!(config.num_heads, Some(2));
    assert_eq!(config.num_kv_heads, Some(1));
}
#[test]
fn test_infer_config_mha_with_head_dim_64() {
    // Plain multi-head attention: q_proj and k_proj have identical shapes, so
    // the inferred head count and KV head count must agree.
    let tensors = BTreeMap::from([
        (
            "model.embed_tokens.weight".to_string(),
            dummy_tensor(vec![1000, 256]),
        ),
        (
            "model.layers.0.self_attn.q_proj.weight".to_string(),
            dummy_tensor(vec![256, 256]),
        ),
        (
            "model.layers.0.self_attn.k_proj.weight".to_string(),
            dummy_tensor(vec![256, 256]),
        ),
    ]);
    let config = infer_model_config_from_tensors(&tensors);
    assert!(config.is_some());
    let config = config.expect("config should be Some");
    assert_eq!(config.num_heads, Some(4));
    assert_eq!(config.num_kv_heads, Some(4));
}
#[test]
fn test_infer_config_no_embedding_returns_none() {
    // Without an embedding tensor there is nothing to derive vocab/hidden
    // sizes from, so inference must return None.
    let tensors = BTreeMap::from([(
        "model.layers.0.self_attn.q_proj.weight".to_string(),
        dummy_tensor(vec![128, 128]),
    )]);
    let result = infer_model_config_from_tensors(&tensors);
    assert!(result.is_none(), "Should return None without embedding tensor");
}
#[test]
fn test_infer_config_empty_tensors_returns_none() {
    // Degenerate input: an empty tensor map can never yield a config.
    let tensors = BTreeMap::new();
    let result = infer_model_config_from_tensors(&tensors);
    assert!(result.is_none(), "Should return None for empty tensor map");
}
#[test]
fn test_infer_config_1d_embedding_returns_none() {
    // A rank-1 embedding cannot encode both vocab and hidden dimensions, so
    // inference must refuse it.
    let tensors = BTreeMap::from([(
        "model.embed_tokens.weight".to_string(),
        dummy_tensor(vec![128]),
    )]);
    let result = infer_model_config_from_tensors(&tensors);
    assert!(result.is_none(), "Should return None for 1D embedding shape");
}
#[test]
fn test_infer_config_word_embeddings_naming() {
    // Alternative "word_embeddings" / "blocks.N" naming should still yield
    // vocab and hidden sizes.
    let tensors = BTreeMap::from([
        (
            "word_embeddings.weight".to_string(),
            dummy_tensor(vec![3000, 768]),
        ),
        (
            "blocks.0.attn.query.weight".to_string(),
            dummy_tensor(vec![768, 768]),
        ),
    ]);
    let config = infer_model_config_from_tensors(&tensors);
    assert!(config.is_some());
    let config = config.expect("config should be Some");
    assert_eq!(config.vocab_size, Some(3000));
    assert_eq!(config.hidden_size, Some(768));
}
#[test]
fn test_infer_config_rope_type_qwen() {
    // qwen2-style tensors (model.layers.* with a q_proj bias) should map to
    // rope_type 2.
    let tensors = BTreeMap::from([
        (
            "model.embed_tokens.weight".to_string(),
            dummy_tensor(vec![1000, 128]),
        ),
        (
            "model.layers.0.self_attn.q_proj.weight".to_string(),
            dummy_tensor(vec![128, 128]),
        ),
        (
            "model.layers.0.self_attn.q_proj.bias".to_string(),
            dummy_tensor(vec![128]),
        ),
    ]);
    let config = infer_model_config_from_tensors(&tensors).expect("config should be Some");
    assert_eq!(config.rope_type, Some(2), "qwen2 should have rope_type=2");
}
#[test]
fn test_infer_config_rope_type_gpt2() {
    // gpt2-style tensors ("wte", "transformer.h.N") should map to rope_type 0.
    let tensors = BTreeMap::from([
        ("wte.weight".to_string(), dummy_tensor(vec![1000, 128])),
        (
            "transformer.h.0.dummy".to_string(),
            dummy_tensor(vec![128, 128]),
        ),
    ]);
    let config = infer_model_config_from_tensors(&tensors).expect("config should be Some");
    assert_eq!(config.rope_type, Some(0), "gpt2 should have rope_type=0");
}
#[test]
fn test_infer_config_unknown_architecture() {
    // Unrecognized layer names: the architecture falls back to "unknown" and
    // no layers are counted, but a config is still produced (embedding exists).
    let tensors = BTreeMap::from([
        (
            "model.embed_tokens.weight".to_string(),
            dummy_tensor(vec![1000, 128]),
        ),
        (
            "some.random.tensor".to_string(),
            dummy_tensor(vec![128, 128]),
        ),
    ]);
    let config = infer_model_config_from_tensors(&tensors).expect("config should be Some");
    assert_eq!(config.architecture.as_deref(), Some("unknown"));
    assert_eq!(config.num_layers, Some(0));
}
#[test]
fn test_infer_config_ffn_up_naming() {
    // GGUF "ffn_up" naming should be recognized when inferring the
    // intermediate (feed-forward) size.
    let tensors = BTreeMap::from([
        ("token_embd.weight".to_string(), dummy_tensor(vec![64, 1000])),
        ("blk.0.ffn_up.weight".to_string(), dummy_tensor(vec![256, 64])),
    ]);
    let config = infer_model_config_from_tensors(&tensors).expect("config should be Some");
    assert_eq!(config.intermediate_size, Some(256));
}
#[test]
fn test_infer_config_fc1_naming() {
    // "fc1" is another accepted name for the intermediate projection.
    let tensors = BTreeMap::from([
        (
            "model.embed_tokens.weight".to_string(),
            dummy_tensor(vec![1000, 128]),
        ),
        (
            "model.layers.0.fc1.weight".to_string(),
            dummy_tensor(vec![512, 128]),
        ),
    ]);
    let config = infer_model_config_from_tensors(&tensors).expect("config should be Some");
    assert_eq!(config.intermediate_size, Some(512));
}
#[test]
fn test_compute_tensor_stats_normal_data() {
    // Clean finite data: every summary field is populated and none of the
    // anomaly counters (NaN/inf/zero) fire.
    let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    let stats = compute_tensor_stats("test_tensor", &values);
    assert_eq!(stats.name, "test_tensor");
    assert_eq!(stats.count, 5);
    assert!((stats.mean - 3.0).abs() < 1e-5);
    assert!((stats.min - 1.0).abs() < 1e-5);
    assert!((stats.max - 5.0).abs() < 1e-5);
    assert!(stats.std > 0.0);
    assert_eq!(stats.nan_count, 0);
    assert_eq!(stats.inf_count, 0);
    assert_eq!(stats.zero_count, 0);
}
#[test]
fn test_compute_tensor_stats_empty() {
    // Degenerate input: an empty slice yields all-zero statistics.
    let values: Vec<f32> = Vec::new();
    let stats = compute_tensor_stats("empty", &values);
    assert_eq!(stats.count, 0);
    assert_eq!(stats.mean, 0.0);
    assert_eq!(stats.min, 0.0);
    assert_eq!(stats.max, 0.0);
    assert_eq!(stats.std, 0.0);
}
#[test]
fn test_compute_tensor_stats_with_nan() {
    // NaNs are tallied in nan_count while the raw element count still
    // includes them.
    let values = vec![1.0, f32::NAN, 3.0];
    let stats = compute_tensor_stats("nan_test", &values);
    assert_eq!(stats.nan_count, 1);
    assert_eq!(stats.count, 3);
}
#[test]
fn test_compute_tensor_stats_with_inf() {
    // Both positive and negative infinities are counted in inf_count; the raw
    // element count still includes them.
    let values = vec![1.0, f32::INFINITY, f32::NEG_INFINITY, 2.0];
    let stats = compute_tensor_stats("inf_test", &values);
    assert_eq!(stats.inf_count, 2);
    assert_eq!(stats.count, 4);
}
#[test]
fn test_compute_tensor_stats_with_zeros() {
    // Exact zeros are tracked by a dedicated counter.
    let values = vec![0.0, 0.0, 1.0, 0.0];
    let stats = compute_tensor_stats("zero_test", &values);
    assert_eq!(stats.zero_count, 3);
}
#[test]
fn test_compute_std_normal() {
    // Sample with mean 5.0 and population std exactly 2.0; the assertion uses
    // a loose (1.0, 3.0) band to stay agnostic about the exact denominator.
    let sample = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
    let std = compute_std(&sample, 5.0_f32, sample.len());
    assert!(std > 0.0);
    assert!(
        std > 1.0 && std < 3.0,
        "std={std} should be between 1.0 and 3.0"
    );
}
#[test]
fn test_compute_std_single_value() {
    // A single observation has no spread.
    let sample = vec![42.0];
    let std = compute_std(&sample, 42.0, 1);
    assert_eq!(std, 0.0, "Std of single value should be 0");
}
#[test]
fn test_compute_std_zero_valid_count() {
    // All-NaN input with a valid count of 0: the std must come back 0 rather
    // than NaN.
    let sample = vec![f32::NAN, f32::NAN];
    let std = compute_std(&sample, 0.0, 0);
    assert_eq!(std, 0.0, "Std with 0 valid should be 0");
}
#[test]
fn test_tensor_accumulator_new() {
    // A fresh accumulator starts from identity values: zero sum and counts,
    // with min/max seeded at +/-infinity so the first sample replaces them.
    let accumulator = TensorAccumulator::new();
    assert_eq!(accumulator.sum, 0.0);
    assert_eq!(accumulator.nan_count, 0);
    assert_eq!(accumulator.inf_count, 0);
    assert_eq!(accumulator.zero_count, 0);
    assert_eq!(accumulator.valid_count, 0);
    assert_eq!(accumulator.min, f32::INFINITY);
    assert_eq!(accumulator.max, f32::NEG_INFINITY);
}
include!("infer_config_accumulator.rs");
include!("infer_config_gqa_heads.rs");
include!("infer_config_arch_fallback.rs");