use super::*;
use std::fs;
/// Creates an empty scratch directory for a single test and returns its path.
///
/// The directory is placed under the system temp dir and named
/// `mlxrs_audio_load_<pid>_<name>`. Anything already present at that path is
/// removed first, so the returned directory starts out empty.
///
/// # Panics
///
/// Panics if the directory cannot be created.
fn temp_dir(name: &str) -> PathBuf {
    let pid = std::process::id();
    let mut scratch = std::env::temp_dir();
    scratch.push(format!("mlxrs_audio_load_{}_{}", pid, name));
    // Best effort: the directory usually does not exist yet, so a failure
    // here is expected and deliberately ignored.
    fs::remove_dir_all(&scratch).ok();
    fs::create_dir_all(&scratch).unwrap();
    scratch
}
/// An existing local directory, passed as a string, resolves to that same path.
#[test]
fn get_model_path_resolves_local_path() {
    let expected = temp_dir("resolves_local");
    let input = String::from(expected.to_string_lossy());
    let actual = get_model_path(&input).expect("local existing path resolves");
    assert_eq!(actual, expected);
}
/// A bare `org/name` repo id is rejected, and the error text mentions both the
/// local-directory requirement and the `huggingface-cli download` workaround.
#[test]
fn get_model_path_rejects_hf_hub_path() {
    let repo_id = "mlx-community/silero-vad";
    let msg = get_model_path(repo_id)
        .expect_err("non-local repo id must be rejected, not silently fetched")
        .to_string();

    let explains_policy = msg.contains("local on-disk directory");
    assert!(
        explains_policy,
        "error should explain the no-network policy, got: {msg}"
    );

    let names_workaround = msg.contains("huggingface-cli download");
    assert!(
        names_workaround,
        "error should point at the out-of-process workaround, got: {msg}"
    );
}
/// An absolute path that does not exist yields an error whose text contains
/// "local path not found".
#[test]
fn get_model_path_local_missing_is_clear_error() {
    let missing = "/definitely/does/not/exist/mlxrs-a9-missing";
    let msg = get_model_path(missing)
        .expect_err("missing local path must error, not fetch")
        .to_string();
    let names_case = msg.contains("local path not found");
    assert!(
        names_case,
        "error should name the local-not-found case, got: {msg}"
    );
}
/// `load_config` returns the text of `config.json` exactly as written to disk.
#[test]
fn load_config_reads_small_json() {
    let model_dir = temp_dir("load_config_small");
    let expected = r#"{ "model_type": "silero_vad", "hidden_size": 128 }"#;
    let config_file = model_dir.join("config.json");
    fs::write(&config_file, expected).unwrap();

    let text = load_config(&model_dir).expect("config.json reads");
    assert_eq!(text, expected);
}
/// A directory without `config.json` yields an error whose text contains
/// "audio model config not found".
#[test]
fn load_config_missing_is_clear_error() {
    let empty_dir = temp_dir("load_config_missing");
    let msg = load_config(&empty_dir)
        .expect_err("missing config.json must error")
        .to_string();
    let names_case = msg.contains("audio model config not found");
    assert!(
        names_case,
        "error should name the missing-config case, got: {msg}"
    );
}
/// A config with no quantization key parses successfully and yields `None`.
#[test]
fn apply_quantization_passes_through_unquantized_model() {
    let dense_config = r#"{ "model_type": "silero_vad", "hidden_size": 128 }"#;
    let parsed = apply_quantization(dense_config).expect("dense config parses");
    let is_dense = parsed.is_none();
    assert!(
        is_dense,
        "no quantization block → Ok(None), got Some(_) — broke the dense-model path"
    );
}
/// A top-level `quantization` object is parsed into the global default with
/// the `group_size` and `bits` it specifies.
#[test]
fn apply_quantization_parses_global_block() {
    let config = r#"{
"model_type": "silero_vad",
"quantization": { "group_size": 64, "bits": 4 }
}"#;
    let defaults = apply_quantization(config)
        .expect("quantized config parses")
        .expect("Some(PerLayerQuantization) for quantized config")
        .quantization
        .expect("global default present");
    assert_eq!(defaults.group_size, 64);
    assert_eq!(defaults.bits, 4);
}
/// Loading from a local directory yields a bundle that reports that directory,
/// the verbatim config text, and no quantization.
#[test]
fn base_load_model_local_path_resolves() {
    let model_dir = temp_dir("base_load_local");
    let config = r#"{ "model_type": "silero_vad" }"#;
    fs::write(model_dir.join("config.json"), config).unwrap();

    let dir_text = model_dir.to_string_lossy();
    let bundle = base_load_model(&dir_text).expect("local dir loads");

    assert_eq!(bundle.model_path(), model_dir);
    assert_eq!(bundle.config_json(), config);
    assert!(bundle.quantization().is_none());
}
/// A `quantization_config` object (with no `quantization` key) is parsed into
/// the global default.
#[test]
fn apply_quantization_parses_quantization_config_key() {
    let config = r#"{
"model_type": "voxtral",
"quantization_config": { "bits": 4, "group_size": 64 }
}"#;
    let defaults = apply_quantization(config)
        .expect("HF-key config parses")
        .expect("Some(PerLayerQuantization) for HF-key config")
        .quantization
        .expect("global default present");
    assert_eq!(defaults.group_size, 64);
    assert_eq!(defaults.bits, 4);
}
/// A quantization block that omits `group_size` still parses, and the
/// resulting global default reports a `group_size` of 64.
#[test]
fn apply_quantization_defaults_missing_group_size_to_64() {
    let config = r#"{
"model_type": "voxtral",
"quantization": { "bits": 4 }
}"#;
    let defaults = apply_quantization(config)
        .expect("missing-group_size config parses")
        .expect("Some(PerLayerQuantization) for default-injected config")
        .quantization
        .expect("global default present");
    assert_eq!(defaults.group_size, 64, "audio default group_size is 64");
    assert_eq!(defaults.bits, 4);
}
/// When both `quantization` and `quantization_config` are present with
/// different values, the values from `quantization` are the ones reported.
#[test]
fn apply_quantization_top_level_takes_precedence_over_quantization_config() {
    let config = r#"{
"model_type": "voxtral",
"quantization": { "bits": 8, "group_size": 32 },
"quantization_config": { "bits": 4, "group_size": 64 }
}"#;
    let defaults = apply_quantization(config)
        .expect("both-keys config parses")
        .expect("Some(PerLayerQuantization) for both-keys config")
        .quantization
        .expect("global default present");
    assert_eq!(defaults.bits, 8, "top-level `quantization` wins");
    assert_eq!(defaults.group_size, 32, "top-level `quantization` wins");
}
/// When `quantization` is JSON `null`, the values from `quantization_config`
/// are used instead.
#[test]
fn apply_quantization_null_primary_falls_back_to_quantization_config() {
    let config = r#"{
"model_type": "voxtral",
"quantization": null,
"quantization_config": { "bits": 4, "group_size": 64 }
}"#;
    let defaults = apply_quantization(config)
        .expect("null-primary config falls back")
        .expect("Some(PerLayerQuantization) from quantization_config fallback")
        .quantization
        .expect("global default present");
    assert_eq!(defaults.bits, 4, "fallback block's `bits` selected");
    assert_eq!(
        defaults.group_size, 64,
        "fallback block's `group_size` selected"
    );
}
/// A config whose only quantization key is `quantization_config: null` parses
/// successfully and yields `None`.
#[test]
fn apply_quantization_only_null_quantization_config_returns_none() {
    let config = r#"{ "model_type": "voxtral", "quantization_config": null }"#;
    let parsed =
        apply_quantization(config).expect("null-only quantization_config parses as dense");
    let is_dense = parsed.is_none();
    assert!(
        is_dense,
        "null quantization_config → Ok(None), matches upstream's no-op early return"
    );
}
/// A config whose only quantization key is `quantization: null` parses
/// successfully and yields `None`.
#[test]
fn apply_quantization_only_null_quantization_returns_none() {
    let config = r#"{ "model_type": "voxtral", "quantization": null }"#;
    let parsed = apply_quantization(config).expect("null-only quantization parses as dense");
    let is_dense = parsed.is_none();
    assert!(
        is_dense,
        "null quantization → Ok(None), matches upstream's no-op early return"
    );
}
/// A config with both `quantization` and `quantization_config` set to `null`
/// parses successfully and yields `None`.
#[test]
fn apply_quantization_both_null_returns_none() {
    let config = r#"{
"model_type": "voxtral",
"quantization": null,
"quantization_config": null
}"#;
    let parsed = apply_quantization(config).expect("both-null config parses as dense");
    let is_dense = parsed.is_none();
    assert!(
        is_dense,
        "both keys null → Ok(None), matches upstream's no-op early return"
    );
}
/// An `hf://`-prefixed input is rejected with `Error::OutOfRange`; the payload
/// value carries `repo_id=<org>/<name>` without the `hf://` prefix, and the
/// payload context mentions the `huggingface-cli download <repo>` workaround.
#[test]
fn get_model_path_hf_url_prefix_yields_clean_repo_id_in_error() {
    let err = get_model_path("hf://mlx-community/silero-vad")
        .expect_err("hf:// repo id must be rejected with a clean workaround");
    let payload = match &err {
        Error::OutOfRange(inner) => inner,
        _ => panic!("hf:// rejection must be OutOfRange, got: {err:?}"),
    };

    let value = payload.value();
    let embeds_clean_id = value.contains("repo_id=mlx-community/silero-vad");
    assert!(
        embeds_clean_id,
        "value should embed the clean repo id (repo_id=...), got: {value}"
    );

    // Only the text after the `repo_id=` marker is checked for the prefix.
    let (_, repo_id_segment) = value
        .split_once("repo_id=")
        .expect("repo_id= segment present in value");
    let leaks_prefix = repo_id_segment.contains("hf://");
    assert!(
        !leaks_prefix,
        "clean repo_id must not embed the `hf://` prefix, got: {repo_id_segment}"
    );

    assert!(
        payload
            .context()
            .contains("huggingface-cli download <repo>"),
        "context must reference the huggingface-cli workaround, got: {}",
        payload.context()
    );
}
/// A full `https://huggingface.co/...` URL is rejected with
/// `Error::OutOfRange`; the payload value carries `repo_id=<org>/<name>`
/// without the URL prefix, and the payload context mentions the
/// `huggingface-cli download <repo>` workaround.
#[test]
fn get_model_path_https_huggingface_url_yields_clean_repo_id_in_error() {
    let err = get_model_path("https://huggingface.co/mlx-community/silero-vad")
        .expect_err("https://huggingface.co/ URL must be rejected with a clean workaround");
    let payload = match &err {
        Error::OutOfRange(inner) => inner,
        _ => panic!("https://huggingface.co/ rejection must be OutOfRange, got: {err:?}"),
    };

    let value = payload.value();
    let embeds_clean_id = value.contains("repo_id=mlx-community/silero-vad");
    assert!(
        embeds_clean_id,
        "value should embed the clean repo id (repo_id=...), got: {value}"
    );

    // Only the text after the `repo_id=` marker is checked for the URL.
    let (_, repo_id_segment) = value
        .split_once("repo_id=")
        .expect("repo_id= segment present in value");
    let leaks_url = repo_id_segment.contains("https://huggingface.co/");
    assert!(
        !leaks_url,
        "clean repo_id must not embed the full URL, got: {repo_id_segment}"
    );

    assert!(
        payload
            .context()
            .contains("huggingface-cli download <repo>"),
        "context must reference the huggingface-cli workaround, got: {}",
        payload.context()
    );
}
/// Every per-domain `load` module exposes `MODEL_REMAPPING` as a
/// `&[(&str, &str)]`; the codec table is empty and all the others are not.
#[test]
// The test name embeds the constant name `MODEL_REMAPPING` verbatim.
#[allow(non_snake_case)]
fn per_domain_load_modules_expose_uniform_MODEL_REMAPPING() {
    // The explicit annotations make a type mismatch in any module a compile
    // error rather than a runtime failure.
    let tts: &[(&str, &str)] = crate::audio::tts::load::MODEL_REMAPPING;
    let stt: &[(&str, &str)] = crate::audio::stt::load::MODEL_REMAPPING;
    let sts: &[(&str, &str)] = crate::audio::sts::load::MODEL_REMAPPING;
    let vad: &[(&str, &str)] = crate::audio::vad::load::MODEL_REMAPPING;
    let lid: &[(&str, &str)] = crate::audio::lid::load::MODEL_REMAPPING;
    let codec: &[(&str, &str)] = crate::audio::codec::load::MODEL_REMAPPING;

    assert!(
        codec.is_empty(),
        "codec's MODEL_REMAPPING must be empty per upstream's no-remapping shape, got: {codec:?}"
    );

    // Checked in the same order as listed; the first empty table fails the test.
    let must_be_populated: [(&[(&str, &str)], &str); 5] = [
        (
            tts,
            "TTS MODEL_REMAPPING must mirror upstream's non-empty alias table",
        ),
        (
            stt,
            "STT MODEL_REMAPPING must mirror upstream's non-empty alias table",
        ),
        (
            sts,
            "STS MODEL_REMAPPING must mirror upstream's non-empty alias table",
        ),
        (
            vad,
            "VAD MODEL_REMAPPING must mirror upstream's non-empty alias table",
        ),
        (
            lid,
            "LID MODEL_REMAPPING must mirror upstream's non-empty alias table",
        ),
    ];
    for (table, why) in must_be_populated {
        assert!(!table.is_empty(), "{why}");
    }
}