use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
/// A decoder cross-attention weight name loses its leading `model.` segment.
#[test]
fn test_map_hf_to_apr_name_with_model_prefix() {
    let mapped = map_hf_to_apr_name("model.decoder.layers.0.encoder_attn.q_proj.weight");
    assert_eq!(mapped, "decoder.layers.0.encoder_attn.q_proj.weight");
}
/// A name that does not start with `model.` comes back unchanged.
#[test]
fn test_map_hf_to_apr_name_without_prefix() {
    let unprefixed = "decoder.layers.0.self_attn.k_proj.weight";
    // Input and expected output are the same string.
    assert_eq!(map_hf_to_apr_name(unprefixed), unprefixed);
}
/// The `model.` prefix is dropped from an encoder positional-embedding name.
#[test]
fn test_map_hf_to_apr_name_encoder() {
    const INPUT: &str = "model.encoder.embed_positions.weight";
    const EXPECTED: &str = "encoder.embed_positions.weight";

    assert_eq!(map_hf_to_apr_name(INPUT), EXPECTED);
}
/// The `model.` prefix is dropped from a decoder feed-forward weight name.
///
/// NOTE(review): the test is named `proj_out`, but the input exercised here is
/// an `fc2` weight — confirm whether a `proj_out` case was intended.
#[test]
fn test_map_hf_to_apr_name_proj_out() {
    let mapped = map_hf_to_apr_name("model.decoder.layers.3.fc2.weight");
    assert_eq!(mapped, "decoder.layers.3.fc2.weight");
}
/// An empty input name maps to an empty output name.
#[test]
fn test_map_hf_to_apr_name_empty() {
    assert_eq!(map_hf_to_apr_name(""), "");
}
/// An input consisting solely of the `model.` prefix maps to the empty string.
#[test]
fn test_map_hf_to_apr_name_only_model() {
    let prefix_only = "model.";
    let mapped = map_hf_to_apr_name(prefix_only);
    assert_eq!(mapped, "");
}
/// A top-level name such as `lm_head.weight` has no prefix to strip and is kept as-is.
#[test]
fn test_map_hf_to_apr_name_no_model_prefix() {
    let top_level = "lm_head.weight";
    assert_eq!(map_hf_to_apr_name(top_level), top_level);
}
/// `run` must return an error when pointed at `/nonexistent/model.apr`.
#[test]
fn test_run_file_not_found() {
    let missing = Path::new("/nonexistent/model.apr");
    let outcome = run(missing, "openai/whisper-tiny", None, 1e-5, false);
    assert!(outcome.is_err());
}
/// `run` must return an error for a `.apr` temp file whose contents are
/// the bytes `not a valid apr file`.
#[test]
fn test_run_invalid_apr() {
    let mut bogus = NamedTempFile::with_suffix(".apr").expect("create temp file");
    bogus.write_all(b"not a valid apr file").expect("write");

    // `bogus` stays in scope so the temp file still exists while `run` executes.
    let apr_path = bogus.path();
    let outcome = run(apr_path, "openai/whisper-tiny", None, 1e-5, false);
    assert!(outcome.is_err());
}
/// `run` must still return an error for a junk `.apr` file when a tensor
/// filter (`decoder.layers.0`) is supplied.
#[test]
fn test_run_with_tensor_filter() {
    let mut junk = NamedTempFile::with_suffix(".apr").expect("create temp file");
    junk.write_all(b"not valid").expect("write");

    let filter = Some("decoder.layers.0");
    let outcome = run(junk.path(), "openai/whisper-tiny", filter, 1e-5, false);
    assert!(outcome.is_err());
}
/// `run` must still return an error for a junk `.apr` file when the final
/// boolean flag is `true` (the JSON-output case, going by this test's name).
#[test]
fn test_run_with_json_output() {
    let mut junk = NamedTempFile::with_suffix(".apr").expect("create temp file");
    junk.write_all(b"not valid").expect("write");

    let json = true;
    let outcome = run(junk.path(), "openai/whisper-tiny", None, 1e-5, json);
    assert!(outcome.is_err());
}
/// `run` must still return an error for a junk `.apr` file when the numeric
/// argument is `1e-3` (a looser threshold, going by this test's name).
#[test]
fn test_run_with_custom_threshold() {
    let mut junk = NamedTempFile::with_suffix(".apr").expect("create temp file");
    junk.write_all(b"not valid").expect("write");

    let threshold = 1e-3;
    let outcome = run(junk.path(), "openai/whisper-tiny", None, threshold, false);
    assert!(outcome.is_err());
}
/// `run` must still return an error for a junk `.apr` file when the numeric
/// argument is `1e-8` (a tighter threshold, going by this test's name).
#[test]
fn test_run_with_strict_threshold() {
    let mut junk = NamedTempFile::with_suffix(".apr").expect("create temp file");
    junk.write_all(b"not valid").expect("write");

    let threshold = 1e-8;
    let outcome = run(junk.path(), "openai/whisper-tiny", None, threshold, false);
    assert!(outcome.is_err());
}
/// When the `safetensors-compare` feature is not enabled, `run` must fail with
/// `CliError::FeatureDisabled` whose payload names that feature.
///
/// Only compiled when the feature is off.
#[test]
#[cfg(not(feature = "safetensors-compare"))]
fn test_run_feature_disabled() {
    let mut file = NamedTempFile::with_suffix(".apr").expect("create temp file");
    file.write_all(b"not valid").expect("write");

    let result = run(file.path(), "openai/whisper-tiny", None, 1e-5, false);
    match result {
        Err(CliError::FeatureDisabled(feature)) => {
            assert_eq!(feature, "safetensors-compare");
        }
        // Report what was actually returned so a failure can be diagnosed
        // straight from the test output instead of re-running under a debugger.
        other => panic!("Expected FeatureDisabled error, got: {other:?}"),
    }
}
/// A segment containing an underscore (`layer_norm`) survives prefix stripping intact.
#[test]
fn test_map_hf_to_apr_name_special_chars() {
    assert_eq!(
        map_hf_to_apr_name("model.layer_norm.weight"),
        "layer_norm.weight"
    );
}
/// A deeply nested decoder name (layer 23) only loses its leading `model.` segment.
#[test]
fn test_map_hf_to_apr_name_deep_nesting() {
    const INPUT: &str = "model.decoder.layers.23.self_attn.k_proj.weight";
    const EXPECTED: &str = "decoder.layers.23.self_attn.k_proj.weight";

    assert_eq!(map_hf_to_apr_name(INPUT), EXPECTED);
}
/// Bias tensors are mapped the same way as weights: the `model.` prefix is removed.
#[test]
fn test_map_hf_to_apr_name_bias() {
    let mapped = map_hf_to_apr_name("model.encoder.layers.0.fc1.bias");
    assert_eq!(mapped, "encoder.layers.0.fc1.bias");
}
/// `model_` is not the `model.` prefix, so a name starting with it is left untouched.
#[test]
fn map_hf_name_starting_with_model_underscore_is_not_stripped() {
    let lookalike = "model_weights.layer.weight";
    assert_eq!(map_hf_to_apr_name(lookalike), lookalike);
}
/// The bare word `model` (no trailing dot) is returned unchanged.
#[test]
fn map_hf_name_model_alone_without_dot_is_unchanged() {
    let bare = "model";
    assert_eq!(map_hf_to_apr_name(bare), bare);
}
/// With a repeated `model.model.` prefix, only the first `model.` is removed.
#[test]
fn map_hf_name_double_model_prefix_strips_first_only() {
    let mapped = map_hf_to_apr_name("model.model.layer.weight");
    assert_eq!(mapped, "model.layer.weight");
}
/// A single-segment name with no dots at all is returned unchanged.
#[test]
fn map_hf_name_single_segment_no_dots() {
    let segment = "embed_tokens";
    assert_eq!(map_hf_to_apr_name(segment), segment);
}
/// `model.` followed by a single segment maps to just that segment.
#[test]
fn map_hf_name_model_dot_single_segment() {
    assert_eq!(map_hf_to_apr_name("model.weight"), "weight");
}
/// A three-digit layer index (127) is carried through unchanged.
#[test]
fn map_hf_name_preserves_large_layer_index() {
    const INPUT: &str = "model.decoder.layers.127.self_attn.v_proj.weight";
    const EXPECTED: &str = "decoder.layers.127.self_attn.v_proj.weight";

    assert_eq!(map_hf_to_apr_name(INPUT), EXPECTED);
}
/// Layer index 0 is carried through unchanged.
#[test]
fn map_hf_name_preserves_zero_layer_index() {
    let hf_name = "model.encoder.layers.0.layer_norm.weight";
    let expected = "encoder.layers.0.layer_norm.weight";
    assert_eq!(map_hf_to_apr_name(hf_name), expected);
}
/// A `transformer.h.…` style name has no `model.` prefix and is returned unchanged.
#[test]
fn map_hf_name_gpt_style_no_prefix() {
    let gpt_style = "transformer.h.0.attn.c_attn.weight";
    assert_eq!(map_hf_to_apr_name(gpt_style), gpt_style);
}
/// An `embeddings.word_embeddings` name under `model.` loses only that prefix.
#[test]
fn map_hf_name_bert_style_with_prefix() {
    assert_eq!(
        map_hf_to_apr_name("model.embeddings.word_embeddings.weight"),
        "embeddings.word_embeddings.weight"
    );
}
/// For a path that should not exist, `run` must fail with either
/// `CliError::FeatureDisabled` or `CliError::FileNotFound`.
///
/// Both variants are accepted, so the test does not depend on which one the
/// current build produces.
#[test]
fn run_nonexistent_file_returns_correct_variant() {
    let missing = Path::new("/this/path/does/not/exist.apr");
    let result = run(missing, "test/repo", None, 1e-5, false);
    assert!(result.is_err());

    let err = result.unwrap_err();
    let is_expected_variant = matches!(
        err,
        CliError::FeatureDisabled(_) | CliError::FileNotFound(_)
    );
    assert!(
        is_expected_variant,
        "Expected FeatureDisabled or FileNotFound, got: {err:?}"
    );
}
/// `run` must still return an error for a junk `.apr` file when the numeric argument is `0.0`.
#[test]
fn run_with_zero_threshold() {
    let mut junk = NamedTempFile::with_suffix(".apr").expect("create temp file");
    junk.write_all(b"not valid").expect("write");

    let threshold = 0.0;
    assert!(run(junk.path(), "test/repo", None, threshold, false).is_err());
}
/// `run` must still return an error for a junk `.apr` file when the numeric argument is `-1.0`.
#[test]
fn run_with_negative_threshold() {
    let mut junk = NamedTempFile::with_suffix(".apr").expect("create temp file");
    junk.write_all(b"not valid").expect("write");

    let threshold = -1.0;
    assert!(run(junk.path(), "test/repo", None, threshold, false).is_err());
}
/// `run` must still return an error for a junk `.apr` file when the repo name is empty.
#[test]
fn run_with_empty_repo_name() {
    let mut junk = NamedTempFile::with_suffix(".apr").expect("create temp file");
    junk.write_all(b"not valid").expect("write");

    let repo = "";
    assert!(run(junk.path(), repo, None, 1e-5, false).is_err());
}
/// `run` must still return an error for a junk `.apr` file when a filter, a
/// non-default numeric argument and the `true` flag are all supplied together.
#[test]
fn run_with_all_options_combined() {
    let mut junk = NamedTempFile::with_suffix(".apr").expect("create temp file");
    junk.write_all(b"not valid").expect("write");

    let filter = Some("encoder");
    let threshold = 1e-3;
    let json = true;
    let outcome = run(junk.path(), "openai/whisper-tiny", filter, threshold, json);
    assert!(outcome.is_err());
}
/// With the `safetensors-compare` feature off, `run` must return
/// `CliError::FeatureDisabled("safetensors-compare")` even when given an
/// arbitrary repo, filter, numeric argument and flag.
///
/// Only compiled when the feature is off.
#[test]
#[cfg(not(feature = "safetensors-compare"))]
fn run_feature_disabled_ignores_all_args() {
    let mut file = NamedTempFile::with_suffix(".apr").expect("create temp file");
    file.write_all(b"valid or not, doesn't matter")
        .expect("write");

    let result = run(file.path(), "any/repo", Some("any filter"), 42.0, true);
    let names_disabled_feature = match result {
        Err(CliError::FeatureDisabled(ref f)) => f == "safetensors-compare",
        _ => false,
    };
    assert!(
        names_disabled_feature,
        "Feature disabled should be returned regardless of args"
    );
}