#[allow(unused_imports)]
use super::super::*;
#[cfg(test)]
mod tests_source_parsing {
    use super::*;

    /// `hf://org/repo` with no trailing path yields a HuggingFace source
    /// whose `file` is left unset.
    #[test]
    fn test_parse_hf_org_repo() {
        let parsed = Source::parse("hf://openai/whisper-tiny").unwrap();
        let expected = Source::HuggingFace {
            org: "openai".to_string(),
            repo: "whisper-tiny".to_string(),
            file: None,
        };
        assert_eq!(parsed, expected);
    }

    /// A third path segment is captured as the explicit file name.
    #[test]
    fn test_parse_hf_org_repo_file() {
        let parsed = Source::parse("hf://openai/whisper-tiny/model.safetensors").unwrap();
        let expected = Source::HuggingFace {
            org: "openai".to_string(),
            repo: "whisper-tiny".to_string(),
            file: Some("model.safetensors".to_string()),
        };
        assert_eq!(parsed, expected);
    }

    /// Sharded checkpoint file names (with dashes/digits) survive parsing intact.
    #[test]
    fn test_parse_hf_nested_file() {
        let parsed =
            Source::parse("hf://meta-llama/Llama-2-7b/pytorch_model-00001-of-00002.bin").unwrap();
        let expected = Source::HuggingFace {
            org: "meta-llama".to_string(),
            repo: "Llama-2-7b".to_string(),
            file: Some("pytorch_model-00001-of-00002.bin".to_string()),
        };
        assert_eq!(parsed, expected);
    }

    /// A relative filesystem path maps to `Source::Local`.
    #[test]
    fn test_parse_local_path() {
        let parsed = Source::parse("./models/model.safetensors").unwrap();
        assert_eq!(
            parsed,
            Source::Local(PathBuf::from("./models/model.safetensors"))
        );
    }

    /// An HTTPS URL maps to `Source::Url`, preserved verbatim.
    #[test]
    fn test_parse_url() {
        let parsed = Source::parse("https://example.com/model.safetensors").unwrap();
        assert_eq!(
            parsed,
            Source::Url("https://example.com/model.safetensors".to_string())
        );
    }

    /// An `hf://` URI missing the repo segment must be rejected.
    #[test]
    fn test_parse_hf_invalid() {
        assert!(Source::parse("hf://invalid").is_err());
    }

    /// `default_file` falls back to "model.safetensors" when `file` is None
    /// and returns the explicit file when one is set.
    #[test]
    fn test_default_file() {
        let without_file = Source::HuggingFace {
            org: "openai".to_string(),
            repo: "whisper".to_string(),
            file: None,
        };
        assert_eq!(without_file.default_file(), "model.safetensors");

        let with_file = Source::HuggingFace {
            org: "openai".to_string(),
            repo: "whisper".to_string(),
            file: Some("custom.safetensors".to_string()),
        };
        assert_eq!(with_file.default_file(), "custom.safetensors");
    }
}
#[cfg(test)]
mod tests_name_mapping {
    use super::*;

    /// Whisper checkpoints carry a `model.` prefix that mapping must strip.
    #[test]
    fn test_whisper_strips_model_prefix() {
        assert_eq!(
            Architecture::Whisper.map_name("model.encoder.conv1.weight"),
            "encoder.conv1.weight"
        );
    }

    /// Names that already lack the prefix pass through unchanged.
    #[test]
    fn test_whisper_no_prefix_unchanged() {
        assert_eq!(
            Architecture::Whisper.map_name("encoder.conv1.weight"),
            "encoder.conv1.weight"
        );
    }

    /// Decoder-side layer-norm names are prefix-stripped the same way.
    #[test]
    fn test_whisper_decoder_layer_norm_strips_prefix() {
        assert_eq!(
            Architecture::Whisper.map_name("model.decoder.layer_norm.weight"),
            "decoder.layer_norm.weight"
        );
    }

    /// Auto must behave as the identity mapping — no prefix stripping.
    #[test]
    fn test_auto_preserves_model_prefix() {
        let name = "model.encoder.layers.0.self_attn.q_proj.weight";
        assert_eq!(Architecture::Auto.map_name(name), name);
    }

    /// Llama names are already canonical; mapping is the identity.
    #[test]
    fn test_llama_mapping() {
        let name = "model.layers.0.self_attn.q_proj.weight";
        assert_eq!(Architecture::Llama.map_name(name), name);
    }

    /// BERT names are already canonical; mapping is the identity.
    #[test]
    fn test_bert_mapping() {
        let name = "bert.encoder.layer.0.attention.self.query.weight";
        assert_eq!(Architecture::Bert.map_name(name), name);
    }

    /// Qwen2 names are already canonical; mapping is the identity.
    #[test]
    fn test_qwen2_mapping() {
        let name = "model.layers.0.self_attn.q_proj.weight";
        assert_eq!(Architecture::Qwen2.map_name(name), name);
    }

    /// Every model_type string supported by tensor-names-v1.yaml must map to
    /// the matching Architecture variant.
    #[test]
    fn test_falsify_from_model_type_matches_yaml_architecture_map() {
        let supported: &[(&str, Architecture)] = &[
            ("qwen2", Architecture::Qwen2),
            ("qwen", Architecture::Qwen2),
            ("qwen2.5", Architecture::Qwen2),
            ("qwen3", Architecture::Qwen3),
            ("qwen3_5", Architecture::Qwen3_5),
            ("qwen3.5", Architecture::Qwen3_5),
            ("llama", Architecture::Llama),
            ("llama3", Architecture::Llama),
            ("whisper", Architecture::Whisper),
            ("bert", Architecture::Bert),
            ("gpt2", Architecture::Gpt2),
            ("phi", Architecture::Phi),
            ("phi3", Architecture::Phi),
            ("mistral", Architecture::Mistral),
            ("gemma", Architecture::Gemma),
            ("gemma2", Architecture::Gemma),
            ("deepseek", Architecture::DeepSeek),
        ];
        for (model_type, expected) in supported.iter().copied() {
            let got = Architecture::from_model_type(model_type);
            assert_eq!(
                got,
                Some(expected),
                "FALSIFY: from_model_type(\"{model_type}\") = {got:?}, expected Some({expected:?}) \
                 (tensor-names-v1.yaml)"
            );
        }
    }

    /// Unknown or empty model_type strings must yield None, never a guess.
    #[test]
    fn test_falsify_from_model_type_unknown_returns_none() {
        for name in ["jamba", "future_model_2027", ""].iter() {
            assert_eq!(
                Architecture::from_model_type(name),
                None,
                "FALSIFY: from_model_type(\"{name}\") should return None for unknown arch"
            );
        }
    }

    /// Both "phi" and "phi3" resolve to the single Phi variant.
    #[test]
    fn test_falsify_phi_architecture_mapping() {
        assert_eq!(
            Architecture::from_model_type("phi"),
            Some(Architecture::Phi)
        );
        assert_eq!(
            Architecture::from_model_type("phi3"),
            Some(Architecture::Phi)
        );
    }
}
#[cfg(test)]
mod tests_tensor_expectations {
    use super::*;

    /// A layer-norm weight tensor gets an expectation centered near 1.0.
    #[test]
    fn test_layer_norm_weight_expectation() {
        let exp = TensorExpectation::for_tensor("encoder.layer_norm.weight");
        assert!(exp.is_some());
        assert_eq!(exp.unwrap().mean_range, (0.5, 3.0));
    }

    /// A layer-norm bias tensor gets an expectation centered near 0.0.
    #[test]
    fn test_layer_norm_bias_expectation() {
        let exp = TensorExpectation::for_tensor("decoder.layers.0.self_attn_layer_norm.bias");
        assert!(exp.is_some());
        assert_eq!(exp.unwrap().mean_range, (-0.5, 0.5));
    }

    /// Linear (fully-connected) weights are expected to have near-zero mean.
    #[test]
    fn test_linear_weight_expectation() {
        let exp = TensorExpectation::for_tensor("encoder.layers.0.fc1.weight");
        assert!(exp.is_some());
        assert_eq!(exp.unwrap().mean_range, (-0.1, 0.1));
    }

    /// Embedding tables are recognized and assigned some expectation.
    #[test]
    fn test_embedding_expectation() {
        assert!(TensorExpectation::for_tensor("decoder.embed_tokens.weight").is_some());
    }

    /// Healthy layer-norm statistics pass the LAYER_NORM_WEIGHT check.
    #[test]
    fn test_check_layer_norm_valid() {
        let healthy = TensorStats {
            name: "encoder.layer_norm.weight".to_string(),
            count: 384,
            min: 0.5,
            max: 2.0,
            mean: 1.0,
            std: 0.3,
            nan_count: 0,
            inf_count: 0,
            zero_count: 0,
        };
        assert!(TensorExpectation::LAYER_NORM_WEIGHT.check(&healthy).is_ok());
    }

    /// A mean far outside (0.5, 3.0) must fail, and the error message must
    /// report both the offending mean and the range violation.
    #[test]
    fn test_check_layer_norm_invalid_mean() {
        let anomalous = TensorStats {
            name: "decoder.layer_norm.weight".to_string(),
            count: 384,
            min: 5.0,
            max: 15.0,
            mean: 11.0,
            std: 2.0,
            nan_count: 0,
            inf_count: 0,
            zero_count: 0,
        };
        let result = TensorExpectation::LAYER_NORM_WEIGHT.check(&anomalous);
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("mean=11"));
        assert!(message.contains("outside expected range"));
    }

    /// RMSNorm-style names (input_layernorm, post_attention_layernorm,
    /// model.norm) all resolve to the wide (-1.0, 10.0) mean range.
    #[test]
    fn test_rmsnorm_weight_detection() {
        for name in [
            "model.layers.0.input_layernorm.weight",
            "model.layers.5.post_attention_layernorm.weight",
            "model.norm.weight",
        ] {
            let exp = TensorExpectation::for_tensor(name);
            assert!(exp.is_some());
            assert_eq!(exp.unwrap().mean_range, (-1.0, 10.0));
        }
    }

    /// Trained RMSNorm weights (mean well below 1.0) must still pass.
    #[test]
    fn test_rmsnorm_accepts_trained_weights() {
        let trained = TensorStats {
            name: "model.layers.0.input_layernorm.weight".to_string(),
            count: 2048,
            min: -0.2,
            max: 0.8,
            mean: 0.05,
            std: 0.15,
            nan_count: 0,
            inf_count: 0,
            zero_count: 0,
        };
        assert!(TensorExpectation::RMSNORM_WEIGHT.check(&trained).is_ok());
    }
}
#[cfg(test)]
mod tests_converter_builder {
    use super::*;

    /// Each builder method records its setting; the chain accumulates all of
    /// architecture, validation, quantization, and compression.
    #[test]
    fn test_converter_builder_chain() {
        let built = AprConverter::new()
            .source("hf://openai/whisper-tiny")
            .unwrap()
            .architecture(Architecture::Whisper)
            .validate(ValidationConfig::Strict)
            .quantize(QuantizationType::Int8)
            .compress(Compression::Lz4);
        assert_eq!(built.architecture, Architecture::Whisper);
        assert_eq!(built.validation, ValidationConfig::Strict);
        assert_eq!(built.quantize, Some(QuantizationType::Int8));
        assert_eq!(built.compress, Some(Compression::Lz4));
    }

    /// Converting without ever setting a source must fail.
    #[test]
    fn test_converter_no_source_error() {
        assert!(AprConverter::new().convert().is_err());
    }
}
#[cfg(test)]
mod tests_import_options {
    use super::*;

    /// Default import options: auto architecture, strict validation, no
    /// quantization/compression, non-strict mode, caching enabled.
    #[test]
    fn test_default_options() {
        let defaults = ImportOptions::default();
        assert!(defaults.cache);
        assert!(!defaults.strict);
        assert_eq!(defaults.architecture, Architecture::Auto);
        assert_eq!(defaults.validation, ValidationConfig::Strict);
        assert_eq!(defaults.quantize, None);
        assert_eq!(defaults.compress, None);
    }
}
include!("core_conversion.rs");
include!("core_convert.rs");
include!("core_rosetta_gqa.rs");
include!("core_q4k_q6k_roundtrip.rs");
include!("streaming_quantize_test.rs");