use std::sync::atomic::{AtomicU64, Ordering};
use super::*;
use crate::embeddings::model::EmbeddingModelOutput;
/// Renders a minimal `config.json` document for a mock embedding model.
///
/// `model_type` is substituted verbatim into the `"model_type"` field; every other field is
/// fixed. The `mock_extra` key is the marker `mock_constructor` looks for in the raw JSON.
fn mock_config_json(model_type: &str) -> String {
    format!(
        r#"{{
"model_type": "{kind}",
"hidden_size": 8,
"num_hidden_layers": 2,
"vocab_size": 5,
"mock_extra": 7
}}"#,
        kind = model_type
    )
}
/// Mock embedding model handed back by `mock_constructor` in the loader tests.
struct MockLoadedEmbedding {
/// Size of the last axis of the hidden state produced by `forward`.
hidden: usize,
}
impl EmbeddingModel for MockLoadedEmbedding {
    /// Produces an all-zero hidden state of shape `(batch, seq, self.hidden)`.
    ///
    /// The attention mask is ignored. Ids that are not rank-2 are rejected with
    /// `Error::RankMismatch` carrying the observed rank and shape.
    fn forward(&self, input_ids: &Array, _attention_mask: &Array) -> Result<EmbeddingModelOutput> {
        // Guard clause: anything other than exactly two axes is a caller error.
        let &[batch, seq] = input_ids.shape().as_slice() else {
            let actual = input_ids.shape();
            let rank = actual.len() as u32;
            return Err(Error::RankMismatch(RankMismatchPayload::new(
                "MockLoadedEmbedding::forward expects rank-2 (batch, seq) ids",
                rank,
                actual,
            )));
        };

        let zeros = vec![0.0_f32; batch * seq * self.hidden];
        let hidden_state = Array::from_slice::<f32>(&zeros, &(batch, seq, self.hidden))?;
        Ok(EmbeddingModelOutput::from_hidden_state(hidden_state))
    }
}
/// Builds the constructor registered for the mock model type.
///
/// The constructor asserts that it was handed non-empty weights and the raw config JSON
/// (identified by its `mock_extra` key) before returning a `MockLoadedEmbedding` with
/// `hidden = 4`.
fn mock_constructor() -> EmbeddingModelConstructor {
    fn build(loaded: &LoadedEmbeddingModel) -> Result<Box<dyn EmbeddingModel>> {
        assert!(
            !loaded.weights.is_empty(),
            "constructor should receive the loaded weights"
        );
        assert!(
            loaded.config_json.contains("mock_extra"),
            "raw config json should reach the constructor"
        );
        Ok(Box::new(MockLoadedEmbedding { hidden: 4 }))
    }

    Box::new(build)
}
/// Creates and returns an empty scratch directory under the system temp dir.
///
/// The directory name combines `tag`, the current process id and a per-process counter, so
/// every call within one test binary yields a distinct path. Any directory already present at
/// that path is removed first.
///
/// # Panics
/// Panics if the directory cannot be created.
fn fresh_dir(tag: &str) -> PathBuf {
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
    let pid = std::process::id();
    let leaf = format!("mlxrs-emb-factory-{tag}-{}-{n}", pid);
    let path = std::env::temp_dir().join(leaf);

    // Best effort: the directory normally does not exist yet, so a failure here is ignored.
    let _ = std::fs::remove_dir_all(&path);
    std::fs::create_dir_all(&path).unwrap();
    path
}
/// Writes a three-word (`a`, `b`, `c`) whitespace-split word-level `tokenizer.json` into `dir`.
///
/// `a` doubles as the unknown token.
///
/// # Panics
/// Panics if the tokenizer cannot be built or saved.
fn write_tokenizer(dir: &Path) {
    use tokenizers::{
        Tokenizer as HfTokenizer, models::wordlevel::WordLevel, pre_tokenizers::whitespace::Whitespace,
    };

    let entries = [("a", 0u32), ("b", 1), ("c", 2)];
    let vocab = entries
        .iter()
        .map(|&(word, id)| (word.to_string(), id))
        .collect();
    let model = WordLevel::builder()
        .vocab(vocab)
        .unk_token("a".to_string())
        .build()
        .unwrap();

    let mut tokenizer = HfTokenizer::new(model);
    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
    let target = dir.join("tokenizer.json");
    tokenizer.save(target, false).unwrap();
}
/// Populates `dir` with a mock `config.json` and a one-tensor `model.safetensors`, but no
/// tokenizer files.
///
/// # Panics
/// Panics if either file cannot be written.
fn write_model_dir_no_tokenizer(dir: &Path, model_type: &str) {
    let config_path = dir.join("config.json");
    std::fs::write(config_path, mock_config_json(model_type)).unwrap();

    // A single 2x2 tensor is enough for the mock constructor's "weights are non-empty" check.
    let tensor = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2usize, 2)).unwrap();
    let mut weights: EmbeddingWeights = HashMap::new();
    weights.insert("mock.weight".to_owned(), tensor);

    let shard_path = dir.join("model.safetensors");
    crate::io::save_safetensors(&shard_path, &weights).unwrap();
}
/// Populates `dir` with a complete mock model: config and weights (via
/// `write_model_dir_no_tokenizer`) plus a `tokenizer.json` (via `write_tokenizer`).
fn write_model_dir(dir: &Path, model_type: &str) {
write_model_dir_no_tokenizer(dir, model_type);
write_tokenizer(dir);
}
/// Writes `1_Pooling/config.json` under `dir`, declaring mean-token pooling with a word
/// embedding dimension of 4.
///
/// # Panics
/// Panics if the directory or the file cannot be created.
fn write_pooling_config(dir: &Path) {
    const POOLING_JSON: &str = r#"{"word_embedding_dimension": 4, "pooling_mode_mean_tokens": true}"#;

    let module_dir = dir.join("1_Pooling");
    std::fs::create_dir_all(&module_dir).unwrap();

    let config_path = module_dir.join("config.json");
    std::fs::write(config_path, POOLING_JSON).unwrap();
}
/// Happy path: `load` dispatches on the config's `model_type` to the registered mock
/// constructor and returns a context whose model and tokenizer are both usable.
#[test]
fn load_dispatches_to_registered_mock_and_returns_context() {
    let dir = fresh_dir("dispatch");
    write_model_dir(&dir, "mockemb");
    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());
    let config = EmbeddingModelConfiguration::from_directory(&dir);

    let ctx = load(&config, &registry).expect("load should succeed");
    assert_eq!(ctx.model_type, "mockemb");
    // No `1_Pooling/config.json` was written for this fixture.
    assert!(ctx.pooling.is_none());

    // The mock model is built with hidden size 4, so a (1, 3) batch yields (1, 3, 4).
    let ids = Array::from_slice::<i32>(&[0, 1, 2], &(1usize, 3)).unwrap();
    let mask = Array::from_slice::<f32>(&[1.0, 1.0, 1.0], &(1usize, 3)).unwrap();
    let out = ctx.model.forward(&ids, &mask).unwrap();
    assert_eq!(out.last_hidden_state().shape(), vec![1, 3, 4]);

    // The tokenizer written by `write_tokenizer` splits on whitespace.
    let tok_ids = ctx.tokenizer.encode("a b c", false).unwrap();
    assert_eq!(tok_ids.len(), 3);
}
/// A configuration built with `from_id` from a local directory path resolves that path as
/// the model directory, and `load` succeeds from it.
#[test]
fn from_id_resolves_as_local_path() {
    let dir = fresh_dir("idpath");
    write_model_dir(&dir, "mockemb");
    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());

    let config = EmbeddingModelConfiguration::from_id(dir.to_str().unwrap());
    assert_eq!(config.model_directory(), dir.as_path());

    let ctx = load(&config, &registry).expect("id-as-local-path load should succeed");
    assert_eq!(ctx.model_type, "mockemb");
}
/// When `1_Pooling/config.json` exists, `load` parses it into the context's pooling config.
#[test]
fn pooling_config_is_loaded_when_present() {
    let dir = fresh_dir("pooling");
    write_model_dir(&dir, "mockemb");
    write_pooling_config(&dir);
    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());
    let config = EmbeddingModelConfiguration::from_directory(&dir);

    let ctx = load(&config, &registry).expect("load with pooling config");
    let pooling = ctx.pooling.expect("pooling config should be parsed");
    // The fixture enables `pooling_mode_mean_tokens` with `word_embedding_dimension` 4.
    assert_eq!(pooling.strategy(), crate::embeddings::PoolingStrategy::Mean);
    assert_eq!(pooling.dimension(), Some(4));
}
/// A `model_type` with no registered constructor yields an error that names the type.
#[test]
fn unknown_model_type_is_recoverable_error() {
    let dir = fresh_dir("unknown");
    write_model_dir(&dir, "nope");
    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());
    let config = EmbeddingModelConfiguration::from_directory(&dir);

    let Err(err) = load(&config, &registry) else {
        panic!("unknown model_type must error");
    };
    let msg = err.to_string();
    assert!(msg.contains("unsupported model type"), "got: {msg}");
    assert!(msg.contains("nope"), "error should name the type: {msg}");
}
/// Loading from a directory without `config.json` yields an error that names the file.
#[test]
fn missing_config_json_is_recoverable_error() {
    // The directory is left empty on purpose.
    let dir = fresh_dir("noconfig");
    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());
    let config = EmbeddingModelConfiguration::from_directory(&dir);

    let Err(err) = load(&config, &registry) else {
        panic!("missing config.json must error");
    };
    assert!(
        err.to_string().contains("config.json"),
        "error should name config.json: {err}"
    );
}
/// A `config.json` larger than `MAX_CONFIG_BYTES` yields an error that mentions the cap.
#[test]
fn oversized_config_json_is_recoverable_error() {
    let dir = fresh_dir("bigconfig");

    // Syntactically valid JSON whose padding alone exceeds the byte cap.
    let mut huge = String::from("{\"model_type\": \"mockemb\", \"pad\": \"");
    huge.push_str(&"x".repeat((MAX_CONFIG_BYTES as usize) + 16));
    huge.push_str("\"}");
    std::fs::write(dir.join("config.json"), huge).unwrap();

    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());
    let config = EmbeddingModelConfiguration::from_directory(&dir);
    let Err(err) = load(&config, &registry) else {
        panic!("oversized config.json must error");
    };
    assert!(
        err.to_string().contains("cap"),
        "error should mention the byte cap: {err}"
    );
}
/// A directory with config and tokenizer but no weight shards yields a "no model weights"
/// error.
#[test]
fn missing_weights_is_recoverable_error() {
    let dir = fresh_dir("noweights");
    std::fs::write(dir.join("config.json"), mock_config_json("mockemb")).unwrap();
    write_tokenizer(&dir);

    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());
    let config = EmbeddingModelConfiguration::from_directory(&dir);
    let Err(err) = load(&config, &registry) else {
        panic!("missing weights must error");
    };
    assert!(
        err.to_string().contains("no model weights"),
        "error should name the missing weights: {err}"
    );
}
/// An empty model directory path, however the configuration was built, is rejected with
/// `Error::EmptyInput`, and the message shows no sign of config or shard resolution.
#[test]
fn empty_model_directory_is_recoverable_error_before_any_scan() {
    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());

    for config in [
        EmbeddingModelConfiguration::from_directory(""),
        EmbeddingModelConfiguration::from_id(""),
        EmbeddingModelConfiguration::from_directory(PathBuf::new()),
    ] {
        assert!(
            config.model_directory().as_os_str().is_empty(),
            "fixture precondition: the model directory path must be empty"
        );

        let Err(err) = load(&config, &registry) else {
            panic!("an empty model directory path must be a recoverable error, not a load");
        };
        assert!(
            matches!(err, Error::EmptyInput(_)),
            "expected EmptyInput error; got {err:?}"
        );

        let msg = err.to_string();
        assert!(
            msg.contains("model directory path must not be empty"),
            "the error should explain the empty-path rejection; got: {msg}"
        );
        // Neither later stage (config read, shard discovery) may appear in the message.
        assert!(
            !msg.contains("config.json") && !msg.contains("no model weights"),
            "the empty path must be rejected before config/shard resolution; got: {msg}"
        );
    }
}
/// An empty `tokenizer_source` override is rejected with `Error::EmptyInput`, even though the
/// model directory itself is complete.
#[test]
fn empty_tokenizer_source_is_recoverable_error() {
    let model_dir = fresh_dir("empty-tok-src");
    write_model_dir(&model_dir, "mockemb");
    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());

    let config = EmbeddingModelConfiguration::from_directory(&model_dir).with_tokenizer_source("");
    assert!(
        config.tokenizer_directory().as_os_str().is_empty(),
        "fixture precondition: the tokenizer directory path must be empty"
    );

    let Err(err) = load(&config, &registry) else {
        panic!("an empty tokenizer_source path must be a recoverable error");
    };
    assert!(
        matches!(err, Error::EmptyInput(_)),
        "expected EmptyInput error; got {err:?}"
    );
    let msg = err.to_string();
    assert!(
        msg.contains("tokenizer directory path must not be empty"),
        "the error should explain the empty tokenizer-path rejection; got: {msg}"
    );
}
/// `collect_glob_shards` rejects an empty directory path with `Error::EmptyInput` for both the
/// recursive model pattern and the root-only weight pattern.
#[test]
fn collect_glob_shards_rejects_empty_dir() {
    let empty_dir = Path::new("");
    let patterns = ["**/model*.safetensors", "weight*.safetensors"];

    for suffix in patterns {
        let err = match collect_glob_shards(empty_dir, suffix) {
            Err(err) => err,
            Ok(_) => {
                panic!("collect_glob_shards must reject an empty dir, not scan the filesystem root")
            }
        };
        assert!(
            matches!(err, Error::EmptyInput(_)),
            "expected EmptyInput error; got {err:?}"
        );
        let rendered = err.to_string();
        assert!(
            rendered.contains("model directory path must not be empty"),
            "the error should explain the empty-path rejection; got: {err}"
        );
    }
}
/// With a `tokenizer_source` override, the tokenizer is loaded from that directory while
/// config and weights come from the model directory.
#[test]
fn tokenizer_source_loads_from_separate_directory() {
    // Model directory deliberately lacks tokenizer files.
    let model_dir = fresh_dir("split-model");
    write_model_dir_no_tokenizer(&model_dir, "mockemb");
    assert!(!model_dir.join("tokenizer.json").exists());

    let tok_dir = fresh_dir("split-tok");
    write_tokenizer(&tok_dir);

    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());
    let config =
        EmbeddingModelConfiguration::from_directory(&model_dir).with_tokenizer_source(&tok_dir);
    assert_eq!(config.tokenizer_directory(), tok_dir.as_path());

    let ctx = load(&config, &registry).expect("split-tokenizer load should succeed");
    let ids = ctx.tokenizer.encode("a b c", false).unwrap();
    assert_eq!(ids.len(), 3);
}
/// Registry lookups treat `-` and `_` in a model type as equivalent, and `remap_model_type`
/// rewrites `-` to `_` while leaving separator-free names unchanged.
#[test]
fn registry_contains_and_separator_normalization() {
    let registry = EmbeddingModelTypeRegistry::new().with("xlm-roberta", mock_constructor());

    // Both spellings of the registered name are found.
    for spelling in ["xlm-roberta", "xlm_roberta"] {
        assert!(registry.contains(spelling));
    }
    // An unrelated name is not.
    assert!(!registry.contains("bert"));

    assert_eq!(remap_model_type("xlm-roberta"), "xlm_roberta");
    assert_eq!(remap_model_type("bert"), "bert");
}
/// `register` returns `None` for a new key and `Some(_)` when it replaces an existing entry.
#[test]
fn register_replaces_and_returns_previous() {
    let mut registry = EmbeddingModelTypeRegistry::new();

    let first_insert = registry.register("mockemb", mock_constructor());
    assert!(first_insert.is_none());

    let replacement = registry.register("mockemb", mock_constructor());
    assert!(replacement.is_some());
}
/// A config declaring `xlm-roberta` dispatches to a constructor registered as `xlm_roberta`,
/// and the context reports the normalized name.
#[test]
fn separator_normalized_config_dispatches() {
    let dir = fresh_dir("sep-dispatch");
    write_model_dir(&dir, "xlm-roberta");
    let registry = EmbeddingModelTypeRegistry::new().with("xlm_roberta", mock_constructor());
    let config = EmbeddingModelConfiguration::from_directory(&dir);

    let ctx = load(&config, &registry).expect("separator-normalized dispatch");
    assert_eq!(ctx.model_type, "xlm_roberta");
}
/// An unsupported `model_type` is reported even when the weight file is unparseable, showing
/// the type check fails the load before the shard would be read.
#[test]
fn unsupported_model_type_does_not_touch_weights() {
    let dir = fresh_dir("unsupported-cheap");
    std::fs::write(dir.join("config.json"), mock_config_json("nope")).unwrap();
    // Garbage shard: parsing it would produce a different error than the one asserted below.
    std::fs::write(
        dir.join("model.safetensors"),
        b"this is not a safetensors file",
    )
    .unwrap();

    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());
    let config = EmbeddingModelConfiguration::from_directory(&dir);
    let Err(err) = load(&config, &registry) else {
        panic!("unsupported model_type must error");
    };
    let msg = err.to_string();
    assert!(
        msg.contains("unsupported model type"),
        "expected the unsupported-model error before any weight load, got: {msg}"
    );
    assert!(msg.contains("nope"), "error should name the type: {msg}");
}
/// A string-valued field present in the object is returned as `Some(value)`.
#[test]
fn extract_finds_present_string_field() {
    let json = r#"{"model_type": "bert", "hidden_size": 768}"#;
    let found = extract_string_field(json, "model_type").unwrap();
    assert_eq!(found, Some("bert".to_owned()));
}
/// Nested objects, arrays, numbers, booleans and nulls that precede the wanted key are
/// skipped, and the key's string value is still found.
#[test]
fn extract_skips_other_typed_values_before_match() {
    let json = r#"{
"nested": {"a": [1, 2, {"deep": true}], "b": "x"},
"arr": [true, false, null, 3.5e2],
"n": -12.0,
"flag": false,
"model_type": "qwen3"
}"#;
    let found = extract_string_field(json, "model_type").unwrap();
    assert_eq!(found, Some("qwen3".to_owned()));
}
/// A well-formed object that lacks the requested key yields `Ok(None)`.
#[test]
fn extract_returns_none_for_absent_field() {
    let json = r#"{"hidden_size": 768, "vocab_size": 5}"#;
    let found = extract_string_field(json, "model_type").unwrap();
    assert_eq!(found, None);
}
/// The empty object `{}` yields `Ok(None)`.
#[test]
fn extract_returns_none_for_empty_object() {
    let found = extract_string_field("{}", "model_type").unwrap();
    assert_eq!(found, None);
}
/// The JSON `\n` escape in the matched value is decoded to a newline, and the non-ASCII
/// character is carried through unchanged.
#[test]
fn extract_decodes_json_escapes() {
    let json = r#"{"model_type": "ab\nc😀"}"#;
    let decoded = extract_string_field(json, "model_type").unwrap();
    assert_eq!(decoded, Some("ab\nc😀".to_owned()));
}
/// A matched key whose value is a number rather than a string is an error.
#[test]
fn extract_rejects_non_string_matched_value() {
    let json = r#"{"model_type": 123}"#;
    let outcome = extract_string_field(json, "model_type");
    assert!(outcome.is_err(), "a numeric model_type must be rejected");
}
/// A document whose root is an array or a bare string is an error.
#[test]
fn extract_rejects_non_object_root() {
    let non_object_roots = [r#"["model_type"]"#, r#""bert""#];
    for root in non_object_roots {
        assert!(extract_string_field(root, "model_type").is_err());
    }
}
/// Structurally broken documents are errors, even when the wanted key appears earlier.
#[test]
fn extract_rejects_malformed_json() {
    let malformed = [
        // Unterminated object.
        r#"{"model_type": "bert""#,
        // Trailing comma.
        r#"{"model_type": "bert",}"#,
        // Missing colon.
        r#"{"model_type" "bert"}"#,
        // Trailing junk after the root value.
        r#"{"model_type": "bert"} junk"#,
        // Unterminated nested object after the match.
        r#"{"model_type": "bert", "x": {"a": 1"#,
    ];
    for json in malformed {
        assert!(extract_string_field(json, "model_type").is_err());
    }
}
/// Array nesting deeper than `MAX_JSON_DEPTH` is reported as an error.
#[test]
fn extract_rejects_pathologically_deep_nesting() {
    let brackets = "[".repeat(MAX_JSON_DEPTH + 8);
    let src = [r#"{"deep": "#, brackets.as_str()].concat();

    let outcome = extract_string_field(&src, "model_type");
    assert!(
        outcome.is_err(),
        "pathological nesting must be a recoverable error"
    );
}
/// When the key occurs more than once, the value of its last occurrence is returned.
#[test]
fn extract_duplicate_key_last_wins() {
    let twice = r#"{"model_type": "first", "model_type": "second"}"#;
    let picked = extract_string_field(twice, "model_type").unwrap();
    assert_eq!(
        picked,
        Some("second".to_owned()),
        "last duplicate key must win"
    );

    // Three occurrences with an unrelated key in between.
    let thrice = r#"{"model_type": "a", "x": 1, "model_type": "b", "model_type": "c"}"#;
    let picked = extract_string_field(thrice, "model_type").unwrap();
    assert_eq!(picked, Some("c".to_owned()));
}
/// A later duplicate of the key with a non-string value is an error, even though an earlier
/// occurrence was a valid string.
#[test]
fn extract_duplicate_key_non_string_later_value_is_rejected() {
    let json = r#"{"model_type": "ok", "model_type": 7}"#;
    let outcome = extract_string_field(json, "model_type");
    assert!(
        outcome.is_err(),
        "a non-string duplicate value must be rejected"
    );
}
/// Malformed number literals in a skipped (non-matching) value are rejected.
#[test]
fn extract_rejects_rfc8259_malformed_numbers() {
    let malformed = [
        r#"{"x": 01}"#,
        r#"{"x": 00}"#,
        r#"{"x": 1.}"#,
        r#"{"x": 1e}"#,
        r#"{"x": 1e+}"#,
        r#"{"x": 1E-}"#,
        r#"{"x": -}"#,
        r#"{"x": .5}"#,
        r#"{"x": 1..2}"#,
    ];
    for bad in malformed {
        let outcome = extract_string_field(bad, "model_type");
        assert!(
            outcome.is_err(),
            "malformed number must be rejected: {bad}"
        );
    }
}
/// Well-formed number literals in a skipped value are accepted (yielding `None`, since the
/// key is absent), while a number as the matched value is still an error.
#[test]
fn extract_accepts_rfc8259_valid_numbers() {
    let well_formed = [
        r#"{"x": 1}"#,
        r#"{"x": 1.0}"#,
        r#"{"x": 1e3}"#,
        r#"{"x": -1.5e-2}"#,
        r#"{"x": 0}"#,
        r#"{"x": 0.5}"#,
        r#"{"x": 10}"#,
        r#"{"x": 1E+10}"#,
        r#"{"x": -0}"#,
    ];
    for good in well_formed {
        let found = extract_string_field(good, "model_type").unwrap();
        assert_eq!(
            found,
            None,
            "valid number must be accepted (model_type absent ⇒ None): {good}"
        );
    }

    let numeric_match = extract_string_field(r#"{"model_type": 1.0}"#, "model_type");
    assert!(numeric_match.is_err());
}
/// `name_bytes_match` classifies a file name by byte prefix and suffix, so non-UTF-8 names
/// can still match, while names with the wrong prefix, the wrong suffix, or too few bytes
/// do not.
#[cfg(unix)]
#[test]
fn name_bytes_match_classifies_shard_names_at_byte_level() {
    use std::{ffi::OsString, os::unix::ffi::OsStringExt};

    const MODEL_PREFIX: &[u8] = b"model";
    const WEIGHT_PREFIX: &[u8] = b"weight";
    const SHARD_SUFFIX: &[u8] = b".safetensors";

    let os = |raw: &[u8]| OsString::from_vec(raw.to_vec());
    let is_model_shard = |raw: &[u8]| name_bytes_match(&os(raw), MODEL_PREFIX, SHARD_SUFFIX);
    let is_weight_shard = |raw: &[u8]| name_bytes_match(&os(raw), WEIGHT_PREFIX, SHARD_SUFFIX);

    // Positive cases: single-file, sharded, and legacy `weight*` names.
    assert!(is_model_shard(b"model.safetensors"));
    assert!(is_model_shard(b"model-00001-of-00002.safetensors"));
    assert!(is_weight_shard(b"weights.safetensors"));

    // A name with an invalid UTF-8 byte between prefix and suffix still matches.
    let mut bad = b"model".to_vec();
    bad.push(0xFF);
    bad.extend_from_slice(b".safetensors");
    assert!(
        is_model_shard(&bad),
        "a non-UTF-8 `model\\xff.safetensors` leaf must match the model shard pattern"
    );
    assert!(
        OsString::from_vec(bad.clone()).to_str().is_none(),
        "fixture precondition: the leaf must be non-UTF-8"
    );

    // Negative cases: unrelated file, wrong suffix, wrong prefix, suffix only, short non-UTF-8.
    assert!(!is_model_shard(b"tokenizer.json"));
    assert!(!is_model_shard(b"model.bin"));
    assert!(!is_weight_shard(b"model.safetensors"));
    assert!(!is_model_shard(b".safetensors"));
    assert!(!is_model_shard(&[b'x', 0xFF]));
}
/// Saves a safetensors file at `path` containing one two-element `f32` tensor stored under
/// `tensor_key`.
///
/// # Panics
/// Panics if the tensor cannot be created or the file cannot be saved.
fn write_one_tensor(path: &Path, tensor_key: &str) {
    let tensor = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap();
    let mut single: EmbeddingWeights = HashMap::new();
    single.insert(tensor_key.to_owned(), tensor);
    crate::io::save_safetensors(path, &single).unwrap();
}
/// A root shard's keys are kept verbatim; a shard in a subdirectory gets `<folder>.` prepended
/// to each key.
#[test]
fn load_weights_recurses_and_prefixes_nested_shards() {
    let dir = fresh_dir("recursive");
    write_one_tensor(&dir.join("model.safetensors"), "embeddings.weight");

    let component = dir.join("vision_model");
    std::fs::create_dir_all(&component).unwrap();
    write_one_tensor(&component.join("model.safetensors"), "encoder.weight");

    let weights = load_weights(&dir).expect("recursive load");
    let keys = || weights.keys().collect::<Vec<_>>();

    assert!(
        weights.contains_key("embeddings.weight"),
        "root-shard key must be verbatim; got {:?}",
        keys()
    );
    assert!(
        weights.contains_key("vision_model.encoder.weight"),
        "nested-shard key must be `<folder>.<key>`; got {:?}",
        keys()
    );
    assert_eq!(weights.len(), 2);
}
/// A model with shards only in subdirectories loads, and the folder prefixes keep the two
/// identically named tensors apart.
#[test]
fn load_weights_handles_nested_only_models() {
    let dir = fresh_dir("nested-only");

    let vision_dir = dir.join("vision_model");
    let text_dir = dir.join("text_model");
    for component in [&vision_dir, &text_dir] {
        std::fs::create_dir_all(component).unwrap();
    }
    // Both shards use the same tensor key on purpose.
    write_one_tensor(&vision_dir.join("model.safetensors"), "w");
    write_one_tensor(&text_dir.join("model.safetensors"), "w");

    let weights = load_weights(&dir).expect("nested-only load");
    assert!(weights.contains_key("vision_model.w"));
    assert!(weights.contains_key("text_model.w"));
    assert_eq!(weights.len(), 2);
}
/// For a shard at `a/b/model.safetensors`, the key prefix is `b` alone, not the full relative
/// path `a.b`.
#[test]
fn load_weights_prefixes_with_immediate_parent_only() {
    let dir = fresh_dir("deep-prefix");
    let innermost = dir.join("a").join("b");
    std::fs::create_dir_all(&innermost).unwrap();
    write_one_tensor(&innermost.join("model.safetensors"), "w");

    let weights = load_weights(&dir).expect("deep nested load");
    let keys = || weights.keys().collect::<Vec<_>>();

    assert!(
        weights.contains_key("b.w"),
        "prefix must be the immediate parent folder name; got {:?}",
        keys()
    );
    assert!(!weights.contains_key("a.b.w"));
}
/// The legacy `weight*.safetensors` pattern matches only in the model root; a nested
/// `weight.safetensors` is ignored.
#[test]
fn load_weights_backcompat_weight_glob_is_root_only() {
    let dir = fresh_dir("backcompat");
    write_one_tensor(&dir.join("weights.safetensors"), "root.w");

    let nested_dir = dir.join("sub");
    std::fs::create_dir_all(&nested_dir).unwrap();
    write_one_tensor(&nested_dir.join("weight.safetensors"), "nested.w");

    let weights = load_weights(&dir).expect("back-compat load");
    let keys = || weights.keys().collect::<Vec<_>>();

    assert!(weights.contains_key("root.w"));
    let nested_leaked = weights.contains_key("sub.nested.w") || weights.contains_key("nested.w");
    assert!(
        !nested_leaked,
        "the legacy weight glob is root-only; nested weight*.safetensors must be ignored; got {:?}",
        keys()
    );
    assert_eq!(weights.len(), 1);
}
/// A shard inside a `.`-prefixed directory is not loaded, while a normal sibling directory's
/// shard is.
#[test]
fn load_weights_excludes_hidden_directory_shards() {
    let dir = fresh_dir("hidden-dir");

    let visible_dir = dir.join("vision_model");
    std::fs::create_dir_all(&visible_dir).unwrap();
    write_one_tensor(&visible_dir.join("model.safetensors"), "encoder.weight");

    let dot_dir = dir.join(".hidden");
    std::fs::create_dir_all(&dot_dir).unwrap();
    write_one_tensor(&dot_dir.join("model.safetensors"), "stale.weight");

    let weights = load_weights(&dir).expect("load with a hidden sibling dir");
    let keys = || weights.keys().collect::<Vec<_>>();

    assert!(
        weights.contains_key("vision_model.encoder.weight"),
        "the normal nested shard must load; got {:?}",
        keys()
    );
    // None of the plausible spellings of the hidden shard's key may appear.
    let leaked = [".hidden.stale.weight", "hidden.stale.weight", "stale.weight"]
        .iter()
        .any(|key| weights.contains_key(*key));
    assert!(
        !leaked,
        "a `.`-prefixed directory's shard must be excluded (include_hidden=False); got {:?}",
        keys()
    );
    assert_eq!(weights.len(), 1);
}
/// A `.`-prefixed shard file next to a real shard is not loaded.
#[test]
fn load_weights_excludes_hidden_file_shards() {
    let dir = fresh_dir("hidden-file");
    write_one_tensor(&dir.join("model.safetensors"), "real.weight");
    write_one_tensor(&dir.join(".model.safetensors"), "hidden.weight");

    let weights = load_weights(&dir).expect("load with a hidden-file sibling");
    let keys = || weights.keys().collect::<Vec<_>>();

    assert!(weights.contains_key("real.weight"));
    assert!(
        !weights.contains_key("hidden.weight"),
        "a `.`-prefixed shard file must be excluded (include_hidden=False); got {:?}",
        keys()
    );
    assert_eq!(weights.len(), 1);
}
/// A component directory reached through a symlink is loaded, and its keys are prefixed with
/// the symlink's name rather than the target directory's name.
#[cfg(unix)]
#[test]
fn load_weights_follows_symlinked_component_directory() {
    let base = fresh_dir("symlink-dir");

    let model_dir = base.join("model");
    std::fs::create_dir_all(&model_dir).unwrap();
    write_one_tensor(&model_dir.join("model.safetensors"), "root.weight");

    // The real component lives OUTSIDE the model dir and is only reachable via the link.
    let target_dir = base.join("real_text_model");
    std::fs::create_dir_all(&target_dir).unwrap();
    write_one_tensor(&target_dir.join("model.safetensors"), "encoder.weight");
    std::os::unix::fs::symlink(&target_dir, model_dir.join("text_model")).unwrap();

    let weights = load_weights(&model_dir).expect("symlinked component dir must load");
    let keys = || weights.keys().collect::<Vec<_>>();

    assert!(weights.contains_key("root.weight"));
    assert!(
        weights.contains_key("text_model.encoder.weight"),
        "the symlinked component must load with the SYMLINK name as prefix; got {:?}",
        keys()
    );
    assert!(
        !weights.contains_key("real_text_model.encoder.weight"),
        "the prefix must be the path-as-walked symlink name, not the canonical target"
    );
    assert_eq!(weights.len(), 2);
}
/// A symlink pointing back at the model root does not make the walk loop forever, and every
/// tensor that ends up in the map has the shape of a real shard tensor.
#[cfg(unix)]
#[test]
fn load_weights_symlink_cycle_terminates() {
    let dir = fresh_dir("symlink-cycle");
    write_one_tensor(&dir.join("model.safetensors"), "root.weight");

    let nested_dir = dir.join("sub");
    std::fs::create_dir_all(&nested_dir).unwrap();
    write_one_tensor(&nested_dir.join("model.safetensors"), "nested.weight");
    // `sub/loop` -> model root: the cycle.
    std::os::unix::fs::symlink(&dir, nested_dir.join("loop")).unwrap();

    let weights = load_weights(&dir).expect("a symlink cycle must terminate, not hang");
    let keys = || weights.keys().collect::<Vec<_>>();

    assert!(
        weights.contains_key("root.weight"),
        "the root shard must load; got {:?}",
        keys()
    );
    assert!(
        weights.contains_key("sub.nested.weight"),
        "the nested shard must load with its immediate-parent prefix; got {:?}",
        keys()
    );
    // The exact key set is not pinned down here; only each tensor's shape is checked.
    for (key, value) in &weights {
        assert_eq!(
            value.shape(),
            vec![2],
            "every merged weight (including a cycle alias) must be a real shard \
             tensor; key {key:?} has shape {:?}",
            value.shape()
        );
    }
}
/// When a real directory and a symlink to it both sit inside the model dir, the shard is
/// loaded once per path, each time with that path's own prefix.
#[cfg(unix)]
#[test]
fn load_weights_walks_both_aliases_to_one_real_directory() {
    let dir = fresh_dir("alias-dirs");
    write_one_tensor(&dir.join("model.safetensors"), "root.weight");

    let target_dir = dir.join("real_text_model");
    std::fs::create_dir_all(&target_dir).unwrap();
    write_one_tensor(&target_dir.join("model.safetensors"), "encoder.weight");
    std::os::unix::fs::symlink(&target_dir, dir.join("text_model")).unwrap();

    let weights = load_weights(&dir).expect("aliased component dirs must both load");
    let keys = || weights.keys().collect::<Vec<_>>();

    assert!(weights.contains_key("root.weight"));
    assert!(
        weights.contains_key("real_text_model.encoder.weight"),
        "the real directory alias must be walked; got {:?}",
        keys()
    );
    assert!(
        weights.contains_key("text_model.encoder.weight"),
        "the symlink alias to the SAME real dir must ALSO be walked (each path \
         keeps its own prefix); got {:?}",
        keys()
    );
    // root + real dir + symlink alias.
    assert_eq!(weights.len(), 3);
}
/// A `model.safetensors` that is a dangling symlink fails the load with a `FileIo` error
/// naming it, even though a loadable `weights.safetensors` sits next to it.
#[cfg(unix)]
#[test]
fn load_weights_dangling_model_shard_fails_not_falls_back() {
    let dir = fresh_dir("dangling-shard");
    write_one_tensor(&dir.join("weights.safetensors"), "stale.weight");

    let missing_target = dir.join("does-not-exist.safetensors");
    let broken_link = dir.join("model.safetensors");
    std::os::unix::fs::symlink(missing_target, broken_link).unwrap();

    let err = match load_weights(&dir) {
        Err(err) => err,
        Ok(_) => panic!(
            "a dangling `model.safetensors` must fail the load, not fall back to \
             the stale `weight*.safetensors`"
        ),
    };
    assert!(
        matches!(err, Error::FileIo(_)),
        "expected FileIo error; got {err:?}"
    );
    let msg = err.to_string();
    assert!(
        msg.contains("model.safetensors"),
        "the error must name the broken shard; got: {msg}"
    );
}
/// A DIRECTORY named `model.safetensors` fails the load with a `FileIo` error naming it, even
/// though a loadable `weights.safetensors` sits next to it.
#[test]
fn load_weights_model_named_directory_fails_not_falls_back() {
    let dir = fresh_dir("model-named-dir");
    write_one_tensor(&dir.join("weights.safetensors"), "stale.weight");
    let impostor = dir.join("model.safetensors");
    std::fs::create_dir_all(impostor).unwrap();

    let err = match load_weights(&dir) {
        Err(err) => err,
        Ok(_) => panic!(
            "a `model.safetensors` DIRECTORY must fail the load, not be descended \
             and fall back to the stale `weight*.safetensors`"
        ),
    };
    assert!(
        matches!(err, Error::FileIo(_)),
        "expected FileIo error; got {err:?}"
    );
    let msg = err.to_string();
    assert!(
        msg.contains("model.safetensors"),
        "the error must name the offending entry; got: {msg}"
    );
}
/// A `model.safetensors` symlink that resolves to a directory fails the load with a `FileIo`
/// error naming it.
#[cfg(unix)]
#[test]
fn load_weights_model_named_symlink_to_directory_fails() {
    let dir = fresh_dir("model-named-symlink-dir");
    write_one_tensor(&dir.join("weights.safetensors"), "stale.weight");

    let target_dir = dir.join("real_dir");
    std::fs::create_dir_all(&target_dir).unwrap();
    std::os::unix::fs::symlink(&target_dir, dir.join("model.safetensors")).unwrap();

    let err = match load_weights(&dir) {
        Err(err) => err,
        Ok(_) => panic!("a `model.safetensors` symlink-to-directory must fail the load"),
    };
    assert!(
        matches!(err, Error::FileIo(_)),
        "expected FileIo error; got {err:?}"
    );
    let msg = err.to_string();
    assert!(
        msg.contains("model.safetensors"),
        "the error must name the offending entry; got: {msg}"
    );
}
/// A subdirectory that cannot be read is skipped rather than failing the load; the readable
/// sibling's shard is still returned.
///
/// The assertions are skipped when the environment does not enforce the `0o000` mode, i.e.
/// when `read_dir` on the locked directory still succeeds.
#[cfg(unix)]
#[test]
fn load_weights_suppresses_unreadable_subdir() {
    use std::os::unix::fs::PermissionsExt;

    let dir = fresh_dir("unreadable-subdir");
    let readable = dir.join("vision_model");
    std::fs::create_dir_all(&readable).unwrap();
    write_one_tensor(&readable.join("model.safetensors"), "encoder.weight");

    let locked = dir.join("locked");
    std::fs::create_dir_all(&locked).unwrap();
    std::fs::set_permissions(&locked, std::fs::Permissions::from_mode(0o000)).unwrap();

    // Probe enforcement and run the load while the directory is still locked, then restore
    // the mode before any assertion runs.
    let enforced = std::fs::read_dir(&locked).is_err();
    let outcome = load_weights(&dir);
    std::fs::set_permissions(&locked, std::fs::Permissions::from_mode(0o755)).unwrap();

    if !enforced {
        eprintln!(
            "skipping unreadable-subdir assertion: this environment does not \
             enforce directory read permission"
        );
        return;
    }

    let weights = outcome.expect("an unreadable nested dir must be skipped, not fail the load");
    assert!(
        weights.contains_key("vision_model.encoder.weight"),
        "the readable sibling's shard must still load; got {:?}",
        weights.keys().collect::<Vec<_>>()
    );
    assert_eq!(weights.len(), 1);
}
/// The recursive model pattern reaches a shard three levels deep, while every shard whose
/// path contains a `.`-prefixed component (directory or file) is left out.
#[test]
fn load_weights_glob_recurses_deeply_and_excludes_hidden() {
    let dir = fresh_dir("glob-deep-hidden");

    // Visible shards: root and `a/b/c`.
    write_one_tensor(&dir.join("model.safetensors"), "root.weight");
    let deep_dir = dir.join("a").join("b").join("c");
    std::fs::create_dir_all(&deep_dir).unwrap();
    write_one_tensor(&deep_dir.join("model.safetensors"), "deep.weight");

    // Hidden shards: in a dot-directory, as a dot-file, and below a dot-directory.
    let dot_dir = dir.join(".checkpoints");
    std::fs::create_dir_all(&dot_dir).unwrap();
    write_one_tensor(&dot_dir.join("model.safetensors"), "hidden_dir.weight");
    write_one_tensor(&dir.join(".model.safetensors"), "hidden_file.weight");
    let below_dot_dir = dir.join(".secret").join("text_model");
    std::fs::create_dir_all(&below_dot_dir).unwrap();
    write_one_tensor(
        &below_dot_dir.join("model.safetensors"),
        "under_hidden.weight",
    );

    let weights = load_weights(&dir).expect("deep recursive glob load");
    let keys = || weights.keys().collect::<Vec<_>>();

    assert!(
        weights.contains_key("root.weight"),
        "the root shard must load (** matches the model dir itself); got {:?}",
        keys()
    );
    assert!(
        weights.contains_key("c.deep.weight"),
        "the deeply-nested shard must load, prefixed by its immediate parent `c`; got {:?}",
        keys()
    );

    let must_be_absent = [
        "hidden_dir.weight",
        ".checkpoints.hidden_dir.weight",
        "checkpoints.hidden_dir.weight",
        "hidden_file.weight",
        "under_hidden.weight",
        "text_model.under_hidden.weight",
    ];
    for forbidden in must_be_absent {
        assert!(
            !weights.contains_key(forbidden),
            "a `.`-prefixed path component must exclude its shard \
             (path_has_hidden_component); leaked {forbidden:?} in {:?}",
            keys()
        );
    }
    assert_eq!(weights.len(), 2);
}
/// A model directory path that is not valid UTF-8 yields a `FileIo` error whose message says
/// so, rather than a panic.
#[cfg(unix)]
#[test]
fn load_weights_non_utf8_model_dir_is_recoverable_error() {
    use std::{ffi::OsStr, os::unix::ffi::OsStrExt};

    // Path bytes ending in 0xFF, which is never valid UTF-8. Nothing is created on disk.
    let raw: Vec<u8> = b"/tmp/mlxrs-emb-non-utf8-"
        .iter()
        .copied()
        .chain(std::iter::once(0xFF))
        .collect();
    let bad_dir = PathBuf::from(OsStr::from_bytes(&raw));
    assert!(
        bad_dir.to_str().is_none(),
        "test precondition: the constructed path must be non-UTF-8"
    );

    let err = match load_weights(&bad_dir) {
        Err(err) => err,
        Ok(_) => panic!("a non-UTF-8 model dir path must be a recoverable error, not a panic"),
    };
    assert!(
        matches!(err, Error::FileIo(_)),
        "expected FileIo error; got {err:?}"
    );
    let rendered = err.to_string();
    assert!(
        rendered.contains("not valid UTF-8"),
        "the error should explain the non-UTF-8 path rejection; got: {err}"
    );
}
/// Non-UTF-8 file names lying next to real shards do not disturb the load: both the root and
/// the nested shard are still returned.
///
/// Returns early, asserting nothing, when the filesystem refuses to create the non-UTF-8
/// file in the model root.
#[cfg(unix)]
#[test]
fn load_weights_non_utf8_descendant_does_not_panic() {
    use std::os::unix::ffi::OsStringExt;

    let dir = fresh_dir("non-utf8-child");
    write_one_tensor(&dir.join("model.safetensors"), "root.weight");
    let component = dir.join("text_model");
    std::fs::create_dir_all(&component).unwrap();
    write_one_tensor(&component.join("model.safetensors"), "encoder.weight");

    let junk_name = std::ffi::OsString::from_vec(vec![b'm', 0xFF]);
    let root_junk_written = std::fs::write(dir.join(&junk_name), b"junk").is_ok();
    if !root_junk_written {
        return;
    }
    // Best effort: the nested junk file is optional for this test.
    let _ = std::fs::write(component.join(&junk_name), b"junk");

    let weights = load_weights(&dir).expect("a non-UTF-8 descendant must not break the glob walk");
    let keys = || weights.keys().collect::<Vec<_>>();

    assert!(
        weights.contains_key("root.weight"),
        "the root shard must still load past a non-UTF-8 sibling; got {:?}",
        keys()
    );
    assert!(
        weights.contains_key("text_model.encoder.weight"),
        "the nested shard must still load past a non-UTF-8 sibling; got {:?}",
        keys()
    );
    assert_eq!(weights.len(), 2);
}
/// A shard whose parent folder name is not valid UTF-8 fails the load with a `FileIo` error
/// tagged `FileOp::Other("weight_shard_discovery")` / `InvalidData`, pointing at the shard.
///
/// Returns early, asserting nothing, when the filesystem refuses to create the non-UTF-8
/// directory.
#[cfg(unix)]
#[test]
fn load_weights_non_utf8_nested_parent_is_recoverable_error() {
    use std::os::unix::ffi::OsStringExt;

    let dir = fresh_dir("non-utf8-parent");
    let folder_name = std::ffi::OsString::from_vec(vec![b't', 0xFF]);
    let component = dir.join(&folder_name);
    if std::fs::create_dir_all(&component).is_err() {
        return;
    }
    write_one_tensor(&component.join("model.safetensors"), "encoder.weight");

    let err = match load_weights(&dir) {
        Err(err) => err,
        Ok(_) => panic!(
            "a nested shard under a non-UTF-8 parent folder must be a recoverable error, not a \
             silent root-merge or a panic"
        ),
    };
    let payload = match &err {
        Error::FileIo(payload) => payload,
        _ => panic!("expected Error::FileIo; got {err:?}"),
    };

    // Operation tag and io error kind.
    assert_eq!(
        payload.op(),
        FileOp::Other("weight_shard_discovery"),
        "non-UTF-8 parent rejection MUST surface as FileOp::Other(\"weight_shard_discovery\"); \
         got {:?}",
        payload.op()
    );
    assert_eq!(
        payload.inner().kind(),
        std::io::ErrorKind::InvalidData,
        "non-UTF-8 parent rejection MUST carry io::ErrorKind::InvalidData; got {:?}",
        payload.inner().kind()
    );

    // Human-readable message.
    let msg = err.to_string();
    assert!(
        msg.contains("non-UTF-8 parent directory name"),
        "the error should explain the non-UTF-8 parent rejection; got: {msg}"
    );

    // Reported path: under the model dir and ending in the shard file name.
    let reported = payload.path();
    assert!(
        reported.starts_with(&dir),
        "the error path must sit under the model directory; got {}",
        payload.path().display()
    );
    assert!(
        reported.ends_with("model.safetensors"),
        "the error path must end with the shard file name; got {}",
        payload.path().display()
    );
}
/// With both a malformed pooling config and an unparseable weight file present, the error
/// reported by `load` does not mention the weights, showing the pooling config is rejected
/// first.
#[test]
fn load_pooling_validated_before_heavy_io() {
    let dir = fresh_dir("pooling-first");
    std::fs::write(dir.join("config.json"), mock_config_json("mockemb")).unwrap();
    // Garbage shard: reading it would produce a safetensors-related error.
    std::fs::write(dir.join("model.safetensors"), b"not a safetensors file").unwrap();

    let pooling_dir = dir.join("1_Pooling");
    std::fs::create_dir_all(&pooling_dir).unwrap();
    std::fs::write(pooling_dir.join("config.json"), b"{ not valid json").unwrap();

    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());
    let config = EmbeddingModelConfiguration::from_directory(&dir);
    let Err(err) = load(&config, &registry) else {
        panic!("malformed pooling config must error");
    };
    let msg = err.to_string();
    assert!(
        !msg.contains("safetensors") && !msg.contains("no model weights"),
        "pooling must be validated before the weight load; got: {msg}"
    );
}
/// `EmbeddingModelConfiguration::new` keeps the identifier and the optional tokenizer source it
/// is given; with no tokenizer source the tokenizer directory equals the model directory.
#[test]
fn configuration_new_exposes_id_and_tokenizer_source() {
    // Case 1: a directory identifier with an explicit tokenizer location.
    let tokenizer_path = PathBuf::from("/tok/dir");
    let dir_id = EmbeddingIdentifier::Directory(PathBuf::from("/models/enc"));
    let explicit = EmbeddingModelConfiguration::new(dir_id.clone(), Some(tokenizer_path.clone()));
    assert_eq!(explicit.id(), &dir_id);
    assert!(explicit.id().is_directory());
    assert_eq!(explicit.tokenizer_source(), Some(tokenizer_path.as_path()));
    assert_eq!(explicit.tokenizer_directory(), tokenizer_path.as_path());
    assert_eq!(explicit.model_directory(), Path::new("/models/enc"));

    // Case 2: a string identifier and no tokenizer source; both directories derive from the id.
    let named_id = EmbeddingIdentifier::Id("org/name".to_owned());
    let fallback = EmbeddingModelConfiguration::new(named_id.clone(), None);
    assert_eq!(fallback.id(), &named_id);
    assert!(fallback.id().is_id());
    assert_eq!(fallback.tokenizer_source(), None);
    assert_eq!(fallback.model_directory(), Path::new("org/name"));
    assert_eq!(fallback.tokenizer_directory(), Path::new("org/name"));
}
/// The `LoadedEmbeddingModel` accessors hand back exactly the model type, raw config text and
/// weight map that were passed to `LoadedEmbeddingModel::new`.
#[test]
fn loaded_embedding_model_accessors_return_components() {
    const CONFIG_BODY: &str = r#"{"model_type": "bert"}"#;

    // A single-entry weight map is enough to check that the map is carried through.
    let tensor = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap();
    let weights: EmbeddingWeights = HashMap::from([("enc.weight".to_owned(), tensor)]);

    let loaded = LoadedEmbeddingModel::new("bert".to_owned(), CONFIG_BODY.to_owned(), weights);

    assert_eq!(loaded.model_type(), "bert");
    assert_eq!(loaded.config_json(), CONFIG_BODY);
    assert!(loaded.weights_ref().contains_key("enc.weight"));
    assert_eq!(loaded.weights_ref().len(), 1);
}
/// `EmbeddingModelTypeRegistry::create` dispatches to the constructor registered for the loaded
/// model's type, and the resulting model can run a forward pass.
#[test]
fn create_invokes_registered_constructor() {
    // The mock constructor asserts on non-empty weights and on `mock_extra` in the config, so
    // the fixture supplies both.
    let tensor = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap();
    let weights: EmbeddingWeights = HashMap::from([("mock.weight".to_owned(), tensor)]);
    let loaded =
        LoadedEmbeddingModel::new("mockemb".to_owned(), mock_config_json("mockemb"), weights);

    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());
    let model = registry
        .create(&loaded)
        .expect("registered type must construct");

    // One batch of two tokens; the mock model uses a hidden size of 4.
    let ids = Array::from_slice::<i32>(&[0, 1], &(1usize, 2)).unwrap();
    let mask = Array::from_slice::<f32>(&[1.0, 1.0], &(1usize, 2)).unwrap();
    let out = model.forward(&ids, &mask).unwrap();
    assert_eq!(out.last_hidden_state().shape(), vec![1, 2, 4]);
}
/// Asking the registry to build a model type it has no constructor for yields
/// `Error::UnknownEnumValue` naming the rejected type, rather than a panic.
#[test]
fn create_rejects_unregistered_model_type() {
    let tensor = Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap();
    let weights: EmbeddingWeights = HashMap::from([("x.weight".to_owned(), tensor)]);
    let unknown = LoadedEmbeddingModel::new("nope".to_owned(), mock_config_json("nope"), weights);

    // Only "mockemb" is registered, so "nope" has nowhere to dispatch.
    let registry = EmbeddingModelTypeRegistry::new().with("mockemb", mock_constructor());
    let err = match registry.create(&unknown) {
        Err(e) => e,
        Ok(_) => panic!("an unregistered model_type must be a recoverable error"),
    };
    let payload = match &err {
        Error::UnknownEnumValue(p) => p,
        _ => panic!("expected Error::UnknownEnumValue; got {err:?}"),
    };

    assert_eq!(
        payload.value(),
        "nope",
        "the rejected model type must be carried in the payload value"
    );
    assert!(payload.supported().is_empty());

    let msg = err.to_string();
    assert!(msg.contains("no constructor registered"), "got: {msg}");
    assert!(msg.contains("nope"), "error should name the type: {msg}");
}
/// A path that does not sit under the given root is never reported as hidden, whether the dot
/// component is in the path itself or in the root.
#[test]
fn path_has_hidden_component_not_under_root_is_false() {
    // Dot component in the candidate, but the candidate is outside the root.
    let outside = Path::new("/other/.hidden/model.safetensors");
    let plain_root = Path::new("/models/enc");
    assert!(!path_has_hidden_component(outside, plain_root));

    // Dot component in the root, and the candidate is not under that root.
    let candidate = Path::new("/models/enc/model.safetensors");
    let hidden_root = Path::new("/models/.enc");
    assert!(!path_has_hidden_component(candidate, hidden_root));
}
/// A `..` component is not itself treated as hidden, but a genuine dot-directory elsewhere in
/// the same path still is.
#[test]
fn path_has_hidden_component_ignores_non_normal_components() {
    let root = Path::new("/models/enc");

    // `..` starts with a dot yet must not count as a hidden component.
    let with_parent_dir = Path::new("/models/enc/sub/../model.safetensors");
    assert!(!path_has_hidden_component(with_parent_dir, root));

    // `.cache` after the `..` is a real hidden component.
    let with_hidden_dir = Path::new("/models/enc/sub/../.cache/model.safetensors");
    assert!(path_has_hidden_component(with_hidden_dir, root));
}
/// `glob_root` removes `.` components and leaves every other component (root, `..`, names)
/// untouched, so a bare file name can be stripped against the root derived from `"."`.
#[test]
fn glob_root_drops_curdir_and_keeps_structure() {
    // (input, expected) pairs; an empty expected string is the empty path.
    let expectations = [
        (".", ""),
        ("./", ""),
        ("./sub", "sub"),
        ("a/./b", "a/b"),
        ("/models/enc", "/models/enc"),
        ("../models/enc", "../models/enc"),
    ];
    for (input, expected) in expectations {
        assert_eq!(glob_root(Path::new(input)), PathBuf::from(expected));
    }

    let cwd_root = glob_root(Path::new("."));
    let stripped = Path::new("model.safetensors").strip_prefix(&cwd_root);
    assert!(stripped.is_ok());
}
/// `scan_non_utf8_shards` returns `Ok` for an unknown pattern and for an empty pattern, here
/// against a directory that does not exist.
#[test]
fn scan_non_utf8_shards_ignores_unrecognized_suffix() {
    let missing = Path::new("/nonexistent-xyz");
    for pattern in ["not-a-known-suffix", ""] {
        assert!(scan_non_utf8_shards(missing, pattern).is_ok());
    }
}
/// `scan_non_utf8_shards` returns `Ok` both for a directory that cannot be read (it does not
/// exist) and for an existing directory that contains no shards.
#[test]
fn scan_non_utf8_shards_unreadable_dir_is_suppressed() {
    let missing = Path::new("/no/such/dir-xyz");
    assert!(scan_non_utf8_shards(missing, "**/model*.safetensors").is_ok());

    // A freshly created, empty directory under both a recursive and a flat pattern.
    let empty_dir = fresh_dir("scan-empty");
    for pattern in ["**/model*.safetensors", "weight*.safetensors"] {
        assert!(scan_non_utf8_shards(&empty_dir, pattern).is_ok());
    }
}
/// `collect_glob_shards` finds a shard in the model directory and one in a sub-folder; the
/// former carries no prefix, the latter carries its parent folder name.
#[test]
fn collect_glob_shards_derives_root_and_nested_prefixes() {
    let dir = fresh_dir("collect-prefix");

    // Lay out one shard at the top level and one under `vision_model/`.
    let root_shard_path = dir.join("model.safetensors");
    write_one_tensor(&root_shard_path, "root.w");
    let vision = dir.join("vision_model");
    std::fs::create_dir_all(&vision).unwrap();
    let nested_shard_path = vision.join("model.safetensors");
    write_one_tensor(&nested_shard_path, "enc.w");

    let shards =
        collect_glob_shards(&dir, "**/model*.safetensors").expect("glob must collect both shards");
    assert_eq!(
        shards.len(),
        2,
        "a root + a nested shard must be discovered"
    );

    let root = shards
        .iter()
        .find(|shard| shard.path == root_shard_path)
        .expect("the root shard must be present");
    assert_eq!(
        root.prefix, None,
        "a genuine root shard keeps its keys verbatim (no prefix)"
    );

    let nested = shards
        .iter()
        .find(|shard| shard.path == nested_shard_path)
        .expect("the nested shard must be present");
    assert_eq!(
        nested.prefix.as_deref(),
        Some("vision_model"),
        "a nested shard's prefix is the IMMEDIATE parent folder name"
    );
}
/// With no `1_Pooling` entry in the model directory, `read_optional_pooling` is `Ok(None)`.
#[test]
fn read_optional_pooling_absent_pooling_dir_is_none() {
    let model_dir = fresh_dir("pool-absent");
    let pooling = read_optional_pooling(&model_dir).expect("absent pooling dir must be Ok");
    assert!(pooling.is_none());
}
/// An existing but empty `1_Pooling` directory is treated the same as no pooling config:
/// `read_optional_pooling` is `Ok(None)`.
#[test]
fn read_optional_pooling_real_dir_without_config_is_none() {
    let model_dir = fresh_dir("pool-empty");
    let pooling_dir = model_dir.join("1_Pooling");
    std::fs::create_dir_all(pooling_dir).unwrap();

    let pooling = read_optional_pooling(&model_dir)
        .expect("a real but empty 1_Pooling dir must be Ok(None)");
    assert!(pooling.is_none());
}
/// A `1_Pooling` symlink that resolves to a real, empty directory behaves like an empty
/// `1_Pooling` directory: `read_optional_pooling` is `Ok(None)`.
#[cfg(unix)]
#[test]
fn read_optional_pooling_symlink_to_real_dir_without_config_is_none() {
    let model_dir = fresh_dir("pool-symlink-dir");
    let target = model_dir.join("real_pooling");
    std::fs::create_dir_all(&target).unwrap();
    let link = model_dir.join("1_Pooling");
    std::os::unix::fs::symlink(&target, link).unwrap();

    let pooling = read_optional_pooling(&model_dir)
        .expect("a symlinked-to-real-dir 1_Pooling without config must be Ok(None)");
    assert!(pooling.is_none());
}
/// A `1_Pooling` symlink whose target does not exist is an error, not "no pooling": the result
/// is `Error::FileIo` with `FileOp::Stat` and a path ending in `1_Pooling`.
#[cfg(unix)]
#[test]
fn read_optional_pooling_dangling_symlink_parent_fails_closed() {
    let model_dir = fresh_dir("pool-dangling-parent");
    let missing_target = model_dir.join("does-not-exist");
    let link = model_dir.join("1_Pooling");
    std::os::unix::fs::symlink(missing_target, link).unwrap();

    let err = match read_optional_pooling(&model_dir) {
        Err(e) => e,
        Ok(_) => panic!("a dangling 1_Pooling symlink must fail closed, not be Ok(None)"),
    };
    let payload = match &err {
        Error::FileIo(p) => p,
        _ => panic!("expected Error::FileIo; got {err:?}"),
    };

    assert_eq!(payload.op(), FileOp::Stat);
    assert!(
        payload.path().ends_with("1_Pooling"),
        "the error must name the broken parent path; got {}",
        payload.path().display()
    );
}
/// A `1_Pooling` symlink that resolves to a regular file is an error: the result is
/// `Error::FileIo` with `FileOp::Stat`, a path ending in `config.json`, and text mentioning the
/// failed presence probe.
#[cfg(unix)]
#[test]
fn read_optional_pooling_symlink_to_file_fails_on_child_probe() {
    let model_dir = fresh_dir("pool-symlink-file");
    let file_target = model_dir.join("a_file");
    std::fs::write(&file_target, b"not a directory").unwrap();
    let link = model_dir.join("1_Pooling");
    std::os::unix::fs::symlink(&file_target, link).unwrap();

    let err = match read_optional_pooling(&model_dir) {
        Err(e) => e,
        Ok(_) => panic!("a 1_Pooling symlink to a file must fail closed, not be Ok(None)"),
    };
    let payload = match &err {
        Error::FileIo(p) => p,
        _ => panic!("expected Error::FileIo; got {err:?}"),
    };

    assert_eq!(payload.op(), FileOp::Stat);
    assert!(
        payload.path().ends_with("config.json"),
        "the error must name the probed child path; got {}",
        payload.path().display()
    );
    assert!(
        err.to_string().contains("presence probe failed"),
        "the error should explain the failed presence probe; got: {err}"
    );
}
/// When `config.json` is a directory instead of a file, `read_model_type` fails with
/// `Error::FileIo` / `FileOp::Stat` and explains that the entry is not a regular file.
#[test]
fn read_model_type_rejects_non_regular_config_handle() {
    let model_dir = fresh_dir("config-is-dir");
    // Create a *directory* under the name the reader expects to be a file.
    std::fs::create_dir_all(model_dir.join("config.json")).unwrap();

    let err = match read_model_type(&model_dir) {
        Err(e) => e,
        Ok(_) => panic!("a directory `config.json` must be rejected"),
    };
    let payload = match &err {
        Error::FileIo(p) => p,
        _ => panic!("expected Error::FileIo; got {err:?}"),
    };

    assert_eq!(payload.op(), FileOp::Stat);
    assert!(
        err.to_string().contains("not a regular file"),
        "the error should explain the non-regular handle; got: {err}"
    );
}
/// A `config.json` whose bytes are not valid UTF-8 is reported as `Error::Parse` with a message
/// that names the UTF-8 failure.
#[test]
fn read_model_type_rejects_non_utf8_bytes() {
    let model_dir = fresh_dir("config-non-utf8");
    // 0xFF and 0xFE can never appear in well-formed UTF-8.
    let invalid_utf8: [u8; 4] = [0xFF, 0xFE, 0x00, 0x01];
    std::fs::write(model_dir.join("config.json"), invalid_utf8).unwrap();

    let err = match read_model_type(&model_dir) {
        Err(e) => e,
        Ok(_) => panic!("non-UTF-8 config bytes must error"),
    };

    assert!(
        matches!(err, Error::Parse(_)),
        "expected Error::Parse; got {err:?}"
    );
    assert!(
        err.to_string().contains("not valid UTF-8"),
        "the error should explain the UTF-8 failure; got: {err}"
    );
}
/// A `config.json` that is valid text but not valid JSON is reported as `Error::Parse` with a
/// message that names the invalid model config.
#[test]
fn read_model_type_wraps_invalid_json_as_parse_error() {
    let model_dir = fresh_dir("config-bad-json");
    let config_path = model_dir.join("config.json");
    std::fs::write(config_path, b"{ this is not json").unwrap();

    let err = match read_model_type(&model_dir) {
        Err(e) => e,
        Ok(_) => panic!("malformed config.json must error"),
    };

    assert!(
        matches!(err, Error::Parse(_)),
        "expected Error::Parse; got {err:?}"
    );
    assert!(
        err.to_string().contains("invalid model config"),
        "the error should name the invalid config; got: {err}"
    );
}
/// A well-formed `config.json` without a `model_type` key yields `Error::MissingField` whose
/// payload names both the file (`config.json`) and the field (`model_type`).
#[test]
fn read_model_type_missing_field_is_typed_error() {
    let model_dir = fresh_dir("config-no-model-type");
    let config_without_type = r#"{"hidden_size": 8, "vocab_size": 5}"#;
    std::fs::write(model_dir.join("config.json"), config_without_type).unwrap();

    let err = match read_model_type(&model_dir) {
        Err(e) => e,
        Ok(_) => panic!("a config without model_type must error"),
    };
    let payload = match &err {
        Error::MissingField(p) => p,
        _ => panic!("expected Error::MissingField; got {err:?}"),
    };

    assert_eq!(payload.type_name(), "config.json");
    assert_eq!(payload.field(), "model_type");
}
/// `read_model_type` returns the model type in canonical form (`xlm-roberta` becomes
/// `xlm_roberta`) together with the config file's text exactly as written.
#[test]
fn read_model_type_canonicalizes_and_returns_raw_body() {
    let model_dir = fresh_dir("config-canon");
    let body = mock_config_json("xlm-roberta");
    let config_path = model_dir.join("config.json");
    std::fs::write(config_path, &body).unwrap();

    let (model_type, raw) = read_model_type(&model_dir).expect("a valid config must read");

    assert_eq!(model_type, "xlm_roberta");
    assert_eq!(raw, body, "the verbatim config body must be returned");
}
/// A `;` where `,` or `}` is required after a key/value pair makes `extract_string_field` fail
/// with a message that states the expected separators.
#[test]
fn extract_rejects_unexpected_byte_after_pair() {
    let document = r#"{"model_type": "bert"; "x": 1}"#;
    let msg = match extract_string_field(document, "model_type") {
        Err(m) => m,
        Ok(_) => panic!("a stray separator byte must be rejected"),
    };
    assert!(
        msg.contains("expected `,` or `}`"),
        "the error should name the separator expectation; got: {msg}"
    );
}
/// Every JSON string escape decodes correctly within a single value: `\u` escapes (one BMP code
/// unit and one surrogate pair) followed by the short escapes `\"`, `\\`, `\/`, `\b`, `\f`,
/// `\n`, `\r` and `\t`.
#[test]
fn extract_full_unicode_escape_machinery() {
    // `\u0041` is 'A' and `\uD83D\uDE00` is the surrogate pair for U+1F600. They are spelled as
    // escapes (not literal characters) so the `\u` decoding path is what produces them.
    let src = r#"{"model_type": "\u0041\uD83D\uDE00\"\\\/\b\f\n\r\t"}"#;
    let got = extract_string_field(src, "model_type")
        .unwrap()
        .expect("the escaped string value must decode");
    // `\b` and `\f` have no Rust string escape, so they are spliced in as chars.
    let expected = format!("A\u{1F600}\"\\/{}{}\n\r\t", '\u{08}', '\u{0C}');
    assert_eq!(got, expected, "all JSON escapes must decode correctly");
}
/// Malformed string contents — broken surrogate sequences, bad escapes, a raw control
/// character, and strings cut off early — must all make `extract_string_field` fail.
#[test]
fn extract_rejects_surrogate_and_escape_errors() {
    let malformed_values = [
        // High surrogate with nothing after it.
        r#"{"model_type": "\uD83D"}"#,
        // High surrogate followed by a plain character.
        r#"{"model_type": "\uD83Dx"}"#,
        // High surrogate followed by an escape that is not `\u`.
        r#"{"model_type": "\uD83D\nABCD"}"#,
        // High surrogate followed by a `\u` escape that is not a low surrogate.
        r#"{"model_type": "\uD83D\u0041"}"#,
        // Low surrogate on its own.
        r#"{"model_type": "\uDC00"}"#,
        // Escape letter that JSON does not define.
        r#"{"model_type": "\q"}"#,
        // `\u` followed by non-hex digits.
        r#"{"model_type": "\uZZZZ"}"#,
        // `\u` with too few hex digits.
        r#"{"model_type": "\u00"}"#,
        // Raw U+0001 control character inside the string.
        "{\"model_type\": \"ab\u{0001}\"}",
        // Input ends immediately after a backslash.
        r#"{"model_type": "\"#,
        // String is never closed.
        r#"{"model_type": "unterminated"#,
    ];
    for bad in malformed_values {
        let outcome = extract_string_field(bad, "model_type");
        assert!(
            outcome.is_err(),
            "a surrogate/escape/control error must be rejected: {bad:?}"
        );
    }
}
/// `extract_string_field` steps over values of every other shape (empty and nested objects and
/// arrays, `true`, `false`, `null`) and still finds a `model_type` that comes last.
#[test]
fn extract_skips_nested_objects_arrays_and_literals() {
    // The literal's lines start at column 0 so its content stays exactly as it was.
    let document = r#"{
"empty_obj": {},
"empty_arr": [],
"obj": {"a": 1, "b": [true, false, null], "c": {"d": "e"}},
"arr": [1, {"x": 2}, [3, 4], "s"],
"t": true,
"f": false,
"z": null,
"model_type": "deepskip"
}"#;
    let found = extract_string_field(document, "model_type").unwrap();
    assert_eq!(found, Some("deepskip".to_owned()));
}
/// Malformed values in a key that has to be skipped over (`x`) must still make
/// `extract_string_field` fail, even though the requested key is `model_type`.
#[test]
fn extract_rejects_invalid_literal_and_eoi_arms() {
    let malformed_documents = [
        // Truncated `true`, `null` and `false` literals.
        r#"{"x": tru}"#,
        r#"{"x": nul}"#,
        r#"{"x": fa}"#,
        // A byte that cannot start any value.
        r#"{"x": @}"#,
        // Input ends where a value should begin.
        r#"{"x": "#,
        // Array and object that are never closed.
        r#"{"x": [1, 2"#,
        r#"{"x": {"a": 1"#,
        // Array with a wrong separator, then with a trailing comma.
        r#"{"x": [1; 2]}"#,
        r#"{"x": [1, ]}"#,
        // Object with a wrong separator, then with a trailing comma.
        r#"{"x": {"a": 1 ; "b": 2}}"#,
        r#"{"x": {"a": 1,}}"#,
        // Key with no `:` before its value.
        r#"{"x" 1}"#,
    ];
    for bad in malformed_documents {
        let outcome = extract_string_field(bad, "model_type");
        assert!(
            outcome.is_err(),
            "a malformed value-skip input must be rejected: {bad:?}"
        );
    }
}
/// An empty document fails with a message that mentions the end of input, and a document made
/// only of whitespace fails as well.
#[test]
fn json_cursor_expect_reports_end_of_input() {
    let msg = match extract_string_field("", "model_type") {
        Err(m) => m,
        Ok(_) => panic!("an empty document must error"),
    };
    assert!(
        msg.contains("reached end of input"),
        "expect should report end-of-input; got: {msg}"
    );

    let whitespace_only = " \n\t ";
    assert!(extract_string_field(whitespace_only, "model_type").is_err());
}
/// `\uXXXX` escapes for BMP code points decode to the matching character, with either
/// lowercase or uppercase hex digits, and may be followed directly by plain text.
#[test]
fn extract_decodes_bmp_unicode_escape_success() {
    // Extracts `model_type` from `src`; `why` is the message used if the field is absent.
    let decode = |src: &str, why: &str| {
        extract_string_field(src, "model_type")
            .unwrap()
            .expect(why)
    };

    let got = decode(
        "{\"model_type\": \"caf\\u00e9\"}",
        "a BMP \\u escape value must decode",
    );
    assert_eq!(
        got, "caf\u{00e9}",
        "lowercase \\u00e9 must decode to U+00E9"
    );

    let got_upper = decode(
        "{\"model_type\": \"caf\\u00E9\"}",
        "an uppercase \\u escape value must decode",
    );
    assert_eq!(
        got_upper, "caf\u{00e9}",
        "uppercase \\u00E9 must decode to U+00E9"
    );

    let got_ascii = decode(
        "{\"model_type\": \"\\u0041BC\"}",
        "a \\u0041 escape must decode to 'A'",
    );
    assert_eq!(got_ascii, "ABC", "\\u0041 must decode to ASCII 'A'");
}
/// A well-formed high/low surrogate pair decodes to a single astral character: in the middle of
/// other text, with lowercase hex digits, and at the lowest astral code point (U+10000).
#[test]
fn extract_decodes_valid_surrogate_pair_success() {
    // Extracts `model_type` from `src`; `why` is the message used if the field is absent.
    let decode = |src: &str, why: &str| {
        extract_string_field(src, "model_type")
            .unwrap()
            .expect(why)
    };

    let got = decode(
        "{\"model_type\": \"x\\uD83D\\uDE00y\"}",
        "a valid surrogate pair must decode",
    );
    assert_eq!(got, "x\u{1F600}y", "\\uD83D\\uDE00 must decode to U+1F600");

    let got_lower = decode(
        "{\"model_type\": \"\\ud83d\\ude00\"}",
        "a lowercase-hex surrogate pair must decode",
    );
    assert_eq!(
        got_lower, "\u{1F600}",
        "lowercase surrogate pair must decode to U+1F600"
    );

    let got_min = decode(
        "{\"model_type\": \"\\uD800\\uDC00\"}",
        "the minimal astral surrogate pair must decode",
    );
    assert_eq!(
        got_min, "\u{10000}",
        "\\uD800\\uDC00 must decode to U+10000"
    );
}