#![cfg(feature = "lm")]
use std::{collections::HashMap, fs, io::Write, path::PathBuf, process};
use mlxrs::{
Array, io,
lm::load::{self, Config},
};
/// Returns a fresh, empty scratch directory under the system temp dir whose
/// name embeds this process's id and `name`.
///
/// Any directory already at that path (e.g. left behind by an earlier run that
/// got the same PID) is removed first, so a test never sees stale files.
///
/// # Panics
///
/// Panics if a pre-existing directory cannot be removed (for any reason other
/// than it not existing), or if the new directory cannot be created.
fn temp_dir(name: &str) -> PathBuf {
    let dir = std::env::temp_dir().join(format!("mlxrs_lm_load_{}_{}", process::id(), name));
    match fs::remove_dir_all(&dir) {
        Ok(()) => {}
        // Nothing to clean up: the usual case on a fresh run.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
        // Any other failure would leave stale contents behind and make the
        // test observe files it never wrote, so fail loudly instead.
        Err(e) => panic!("failed to clear stale temp dir {}: {e}", dir.display()),
    }
    fs::create_dir_all(&dir)
        .unwrap_or_else(|e| panic!("failed to create temp dir {}: {e}", dir.display()));
    dir
}
/// Fixture `tokenizer.json`, embedded at compile time (path is relative to this source file).
const TOKENIZER_JSON: &str = include_str!("fixtures/tokenizer.json");
/// Fixture `tokenizer_config.json`, embedded at compile time (path is relative to this source file).
const TOKENIZER_CONFIG_JSON: &str = include_str!("fixtures/tokenizer_config.json");
/// A qwen3-style `config.json` used as the "everything present" fixture.
///
/// It sets the required architecture fields, the optional `sliding_window` and
/// `quantization` entries, and extra keys (e.g. `some_future_key`) that
/// `Config::from_json` must tolerate without erroring.
const FULL_CONFIG_JSON: &str = r#"{
"model_type": "qwen3",
"hidden_size": 1024,
"num_hidden_layers": 24,
"num_attention_heads": 16,
"num_key_value_heads": 8,
"head_dim": 64,
"rope_theta": 1000000.0,
"vocab_size": 151936,
"tie_word_embeddings": true,
"sliding_window": 4096,
"quantization": { "group_size": 64, "bits": 4 },
"max_position_embeddings": 32768,
"some_future_key": [1, 2, 3]
}"#;
/// `FULL_CONFIG_JSON` parses, populating every asserted field. Unknown keys such
/// as `some_future_key` must not cause a parse error.
#[test]
fn config_parses_minimal_and_ignores_extra() {
    let config = Config::from_json(FULL_CONFIG_JSON).unwrap();

    // Required architecture fields.
    assert_eq!(config.model_type(), "qwen3");
    assert_eq!(config.hidden_size, 1024);
    assert_eq!(config.num_hidden_layers, 24);
    assert_eq!(config.num_attention_heads, 16);
    assert_eq!(config.num_key_value_heads, 8);
    assert_eq!(config.head_dim, 64);
    assert_eq!(config.rope_theta, 1_000_000.0);
    assert_eq!(config.vocab_size, 151_936);
    assert!(config.tie_word_embeddings);

    // Optional fields that this fixture does provide.
    assert_eq!(config.sliding_window, Some(4096));
    let quant = config.quantization.expect("quantization block present");
    assert_eq!(quant.group_size, 64);
    assert_eq!(quant.bits, 4);
}
/// When `sliding_window` and `quantization` are absent, they come back as `None`.
/// An unrelated extra key is tolerated.
#[test]
fn config_optionals_default_when_absent() {
    // Required keys only, plus one unknown key; no optional blocks.
    let json = r#"{
"model_type": "llama",
"hidden_size": 512,
"num_hidden_layers": 4,
"num_attention_heads": 8,
"num_key_value_heads": 8,
"head_dim": 64,
"rope_theta": 10000.0,
"vocab_size": 32000,
"tie_word_embeddings": false,
"unrelated": "ignored"
}"#;
    let parsed = Config::from_json(json).unwrap();

    assert_eq!(parsed.model_type(), "llama");
    assert!(!parsed.tie_word_embeddings);
    // Absent optionals default to "not present".
    assert_eq!(parsed.sliding_window, None);
    assert!(parsed.quantization.is_none());
}
/// Omitting a required key (`num_hidden_layers`) must yield `Error::Parse`,
/// tagged with the `Config::from_json` context and the "model config JSON"
/// input kind.
#[test]
fn config_missing_required_is_parse_error() {
    // A qwen3 config with `num_hidden_layers` deliberately left out.
    let json = r#"{
"model_type": "qwen3",
"hidden_size": 1024,
"num_attention_heads": 16,
"num_key_value_heads": 8,
"head_dim": 64,
"rope_theta": 1000000.0,
"vocab_size": 151936,
"tie_word_embeddings": true
}"#;
    let err = Config::from_json(json).unwrap_err();
    let is_tagged_parse_error = matches!(
        &err,
        mlxrs::Error::Parse(p)
            if p.context() == "Config::from_json" && p.input_kind() == "model config JSON"
    );
    assert!(
        is_tagged_parse_error,
        "expected Error::Parse from Config::from_json, got {err:?}"
    );
}
/// Syntactically broken JSON must yield the same tagged `Error::Parse` as a
/// schema violation does.
#[test]
fn config_invalid_json_is_parse_error() {
    let malformed = "{ not json";
    let err = Config::from_json(malformed).unwrap_err();
    let is_tagged_parse_error = matches!(
        &err,
        mlxrs::Error::Parse(p)
            if p.context() == "Config::from_json" && p.input_kind() == "model config JSON"
    );
    assert!(
        is_tagged_parse_error,
        "expected Error::Parse from Config::from_json, got {err:?}"
    );
}
/// Builds a small `f32` test [`Array`] from `v` using the two-element `shape` tuple.
///
/// Panics if `Array::from_slice` rejects the data/shape combination.
fn small(v: &[f32], shape: (usize, usize)) -> Array {
    let built = Array::from_slice(v, &shape);
    built.unwrap()
}
/// A sharded checkpoint (index JSON plus two shard files) loads as one merged
/// map. The quantized layer's `weight`/`scales`/`biases` triple comes back
/// intact with unchanged values.
#[test]
fn weights_merges_shards_and_keeps_quant_triples() {
    let dir = temp_dir("shards");

    // Shard 1 holds a complete quantization triple for layer `a`.
    let mut shard1 = HashMap::new();
    shard1.insert("a.weight".to_string(), small(&[1.0, 2.0, 3.0, 4.0], (2, 2)));
    shard1.insert("a.scales".to_string(), small(&[0.5, 0.25], (1, 2)));
    shard1.insert("a.biases".to_string(), small(&[0.1, 0.2], (1, 2)));
    // Shard 2 holds an unrelated tensor.
    let mut shard2 = HashMap::new();
    shard2.insert("b.weight".to_string(), small(&[9.0, 8.0], (1, 2)));
    io::save_safetensors(&dir.join("model-00001-of-00002.safetensors"), &shard1).unwrap();
    io::save_safetensors(&dir.join("model-00002-of-00002.safetensors"), &shard2).unwrap();

    // Keep the index metadata consistent with the shards: 10 f32 values in
    // total, i.e. 40 bytes of tensor payload.
    fs::write(
        dir.join("model.safetensors.index.json"),
        br#"{
"metadata": { "total_size": 40, "total_parameters": 10 },
"weight_map": {
"a.weight": "model-00001-of-00002.safetensors",
"a.scales": "model-00001-of-00002.safetensors",
"a.biases": "model-00001-of-00002.safetensors",
"b.weight": "model-00002-of-00002.safetensors"
}
}"#,
    )
    .unwrap();

    let mut weights = load::load_weights(&dir).unwrap();
    assert_eq!(weights.len(), 4, "all four keys merged");

    // Each key must be present exactly once with its original values.
    let expected: [(&str, &[f32]); 4] = [
        ("a.weight", &[1.0, 2.0, 3.0, 4.0]),
        ("a.scales", &[0.5, 0.25]),
        ("a.biases", &[0.1, 0.2]),
        ("b.weight", &[9.0, 8.0]),
    ];
    for &(key, values) in &expected {
        let mut array = weights
            .remove(key)
            .unwrap_or_else(|| panic!("merged weights missing {key}"));
        assert_eq!(array.to_vec::<f32>().unwrap(), values, "values of {key}");
    }
}
/// A directory holding a single `model.safetensors` and no index file loads
/// directly.
#[test]
fn weights_single_unsharded_safetensors() {
    let dir = temp_dir("single");

    let mut tensors = HashMap::new();
    tensors.insert("tok_embeddings.weight".to_string(), small(&[1.0, 2.0], (1, 2)));
    io::save_safetensors(&dir.join("model.safetensors"), &tensors).unwrap();

    let loaded = load::load_weights(&dir).unwrap();
    assert_eq!(loaded.len(), 1);
    assert!(loaded.contains_key("tok_embeddings.weight"));
}
/// An empty directory must fail with `Error::FileIo` of kind `NotFound`, and
/// the reported path must be the directory itself.
#[test]
fn weights_missing_is_file_io_error() {
    let dir = temp_dir("empty");
    let err = load::load_weights(&dir).unwrap_err();
    let is_not_found_for_dir = matches!(
        &err,
        mlxrs::Error::FileIo(p)
            if p.path() == dir.as_path() && p.inner().kind() == std::io::ErrorKind::NotFound
    );
    assert!(
        is_not_found_for_dir,
        "expected Error::FileIo(NotFound) for a dir with no weights, got {err:?}"
    );
}
/// Creates a complete model directory named `name` that `load::load` can read.
///
/// It contains:
/// - `config.json`: the qwen3 `FULL_CONFIG_JSON`;
/// - `model.safetensors`: a single unsharded file holding `model.embed_tokens.weight`;
/// - `tokenizer.json` and `tokenizer_config.json`: the fixture tokenizer files.
///
/// Panics on any I/O or serialization failure.
fn write_model_dir(name: &str) -> PathBuf {
    let dir = temp_dir(name);
    fs::write(dir.join("config.json"), FULL_CONFIG_JSON).unwrap();

    let mut weights = HashMap::new();
    weights.insert(
        "model.embed_tokens.weight".to_string(),
        small(&[1.0, 2.0, 3.0, 4.0], (2, 2)),
    );
    io::save_safetensors(&dir.join("model.safetensors"), &weights).unwrap();

    // `fs::write` creates/truncates and writes in one call, like `write_meta`.
    fs::write(dir.join("tokenizer.json"), TOKENIZER_JSON).unwrap();
    fs::write(dir.join("tokenizer_config.json"), TOKENIZER_CONFIG_JSON).unwrap();
    dir
}
/// End-to-end `load::load` on a complete model directory returns the parsed
/// config, the weight map and a working tokenizer.
#[test]
fn load_returns_config_weights_tokenizer() {
    let dir = write_model_dir("full");
    let (cfg, weights, tok) = load::load(&dir).unwrap();

    assert_eq!(cfg.model_type(), "qwen3");
    assert_eq!(cfg.num_hidden_layers, 24);
    assert_eq!(cfg.sliding_window, Some(4096));
    assert!(weights.contains_key("model.embed_tokens.weight"));

    // The exact ids are pinned by the fixture tokenizer. Equality already
    // implies non-empty, so no separate emptiness check is needed.
    let ids = tok.encode("hello world", false).unwrap();
    assert_eq!(ids, vec![3, 4]);
}
/// With weights and `tokenizer.json` present but no `config.json`,
/// `load::load` must fail with `Error::FileIo` (`NotFound`). The error must
/// name the missing `config.json` path.
#[test]
fn load_missing_config_is_file_io_error() {
    let dir = temp_dir("no_config");

    let mut tensors = HashMap::new();
    tensors.insert("w".to_string(), small(&[1.0], (1, 1)));
    io::save_safetensors(&dir.join("model.safetensors"), &tensors).unwrap();
    fs::write(dir.join("tokenizer.json"), TOKENIZER_JSON.as_bytes()).unwrap();

    let expected_path = dir.join("config.json");
    let outcome = load::load(&dir);
    match outcome {
        Ok(_) => panic!("expected Err when config.json absent, got Ok"),
        Err(mlxrs::Error::FileIo(p))
            if p.inner().kind() == std::io::ErrorKind::NotFound
                && p.path() == expected_path => {}
        Err(other) => panic!("expected Error::FileIo(NotFound) for missing config.json, got {other:?}"),
    }
}
/// Writes the non-weight metadata files into `dir`: a tiny llama `config.json`
/// (no `eos_token_id`), plus the fixture `tokenizer.json` and
/// `tokenizer_config.json`.
///
/// Panics if any file cannot be written.
fn write_meta(dir: &std::path::Path) {
    let files: [(&str, &[u8]); 3] = [
        (
            "config.json",
            br#"{"model_type":"llama","hidden_size":8,"num_hidden_layers":2,"num_attention_heads":2,"num_key_value_heads":2,"head_dim":4,"rope_theta":10000.0,"vocab_size":32,"tie_word_embeddings":false}"#,
        ),
        ("tokenizer.json", TOKENIZER_JSON.as_bytes()),
        ("tokenizer_config.json", TOKENIZER_CONFIG_JSON.as_bytes()),
    ];
    for &(file_name, contents) in &files {
        fs::write(dir.join(file_name), contents).unwrap();
    }
}
fn loadable(name: &str) -> PathBuf {
let d = temp_dir(name);
write_meta(&d);
let mut m = HashMap::new();
m.insert("w".to_string(), small(&[1.0], (1, 1)));
io::save_safetensors(&d.join("model.safetensors"), &m).unwrap();
d
}
/// In an HF-snapshot-style directory, `model.safetensors` is a symlink into
/// `blobs/`. `load::load` must resolve the link and load its tensors rather
/// than skip it.
#[cfg(unix)]
#[test]
fn load_follows_symlinked_weights_hf_snapshot_layout() {
    let dir = temp_dir("symlink_weights");
    write_meta(&dir);

    // The real tensor file lives under blobs/ and is reachable only via the link.
    let blob_dir = dir.join("blobs");
    fs::create_dir_all(&blob_dir).unwrap();
    let blob_path = blob_dir.join("blob.safetensors");
    let mut tensors = HashMap::new();
    tensors.insert("blk.0.weight".to_string(), small(&[1.0, 2.0, 3.0, 4.0], (2, 2)));
    io::save_safetensors(&blob_path, &tensors).unwrap();
    std::os::unix::fs::symlink(&blob_path, dir.join("model.safetensors")).unwrap();

    let (_config, weights, _tokenizer) = load::load(&dir)
        .expect("a HF-snapshot-style dir whose model.safetensors is a symlink must load");
    let loaded = weights
        .get("blk.0.weight")
        .expect("symlinked model.safetensors must be resolved & loaded, not skipped");
    assert_eq!(loaded.shape(), vec![2, 2]);
}
/// Checks how `load::load` resolves EOS ids from `config.json` and
/// `generation_config.json`.
///
/// - A truthy value (non-zero scalar or non-empty list) REPLACES the
///   tokenizer's own EOS set rather than merging with it. It is also copied
///   into `Config::eos_token_id` with its scalar/list shape preserved.
/// - `generation_config.json` takes precedence over `config.json`.
/// - A falsy value (`0` or `[]`) or a malformed `generation_config.json`
///   leaves both the tokenizer default and `Config::eos_token_id` untouched.
#[test]
fn load_resolves_eos_set_replace_not_merge() {
    use load::EosTokenId::{Many, Single};
    use std::collections::BTreeSet;

    let set = |ids: &[u32]| ids.iter().copied().collect::<BTreeSet<u32>>();

    // The `write_meta` llama config with `eos` spliced in verbatim as raw
    // JSON (a scalar or a list).
    let cfg_with_eos = |eos: &str| {
        format!(
            r#"{{"model_type":"llama","hidden_size":8,"num_hidden_layers":2,"num_attention_heads":2,"num_key_value_heads":2,"head_dim":4,"rope_theta":10000.0,"vocab_size":32,"tie_word_embeddings":false,"eos_token_id":{eos}}}"#
        )
    };

    // Builds a loadable dir named `name`. It can optionally override
    // config.json's eos and/or write a generation_config.json body verbatim.
    let fixture = |name: &str, cfg_eos: Option<&str>, gen_config: Option<&str>| {
        let dir = loadable(name);
        if let Some(eos) = cfg_eos {
            fs::write(dir.join("config.json"), cfg_with_eos(eos)).unwrap();
        }
        if let Some(body) = gen_config {
            fs::write(dir.join("generation_config.json"), body).unwrap();
        }
        dir
    };

    // Loads `dir` and returns the parsed config plus the tokenizer's EOS set.
    let load_eos = |dir: &PathBuf, what: &str| {
        let (cfg, _weights, tok) =
            load::load(dir).unwrap_or_else(|e| panic!("{what}: load failed: {e:?}"));
        let eos: BTreeSet<u32> = tok.eos_token_ids_iter().collect();
        (cfg, eos)
    };

    // Baseline: no eos anywhere, so the tokenizer keeps its own set.
    let (c0, base) = load_eos(&fixture("eos_base", None, None), "baseline loads");
    assert_eq!(c0.eos_token_id, None, "no config/gen eos ⇒ Config eos None");
    // Every id used below must be distinguishable from the fixture's defaults.
    let probe_ids = set(&[0, 7, 8, 9, 10, 4242, 4243, 4244]);
    assert!(
        base.is_disjoint(&probe_ids),
        "test ids must be outside the fixture's base eos set: {base:?}"
    );
    assert!(
        !base.is_empty(),
        "fixture tokenizer must have its own eos for the replace guard"
    );

    // generation_config list.
    let (c1, t1) = load_eos(
        &fixture("eos_gen_list", None, Some(r#"{"eos_token_id":[4242,4243]}"#)),
        "eos_gen_list",
    );
    assert_eq!(t1, set(&[4242, 4243]), "list REPLACES, not merges");
    assert_eq!(
        c1.eos_token_id,
        Some(Many(vec![4242, 4243])),
        "Config eos overwritten (list, shape preserved)"
    );

    // generation_config scalar.
    let (c2, t2) = load_eos(
        &fixture("eos_gen_int", None, Some(r#"{"eos_token_id":4244}"#)),
        "eos_gen_int",
    );
    assert_eq!(t2, set(&[4244]), "scalar REPLACES, not merges");
    assert_eq!(
        c2.eos_token_id,
        Some(Single(4244)),
        "Config eos overwritten (scalar, shape preserved)"
    );

    // A list containing 0 is still truthy; 0 is kept as an id.
    let (cl0, tl0) = load_eos(
        &fixture("eos_gen_list0", None, Some(r#"{"eos_token_id":[0,4242]}"#)),
        "eos_gen_list0",
    );
    assert_eq!(tl0, set(&[0, 4242]), "list [0,..] keeps 0");
    assert_eq!(
        cl0.eos_token_id,
        Some(Many(vec![0, 4242])),
        "Config eos list keeps 0"
    );

    // Scalar 0 is falsy and is ignored.
    let (cz, tz) = load_eos(
        &fixture("eos_gen_zero", None, Some(r#"{"eos_token_id":0}"#)),
        "eos_gen_zero",
    );
    assert_eq!(tz, base, "falsy scalar 0 ⇒ tokenizer default");
    assert_eq!(cz.eos_token_id, None, "falsy scalar 0 ⇒ Config eos untouched");

    // An empty list is falsy and is ignored.
    let (ce, te) = load_eos(
        &fixture("eos_gen_empty", None, Some(r#"{"eos_token_id":[]}"#)),
        "eos_gen_empty",
    );
    assert_eq!(te, base, "empty list is falsy ⇒ tokenizer default");
    assert_eq!(ce.eos_token_id, None, "empty list ⇒ Config eos untouched");

    // config.json scalar, no generation_config.
    let (cc, tc) = load_eos(&fixture("eos_cfg_int", Some("7"), None), "eos_cfg_int");
    assert_eq!(tc, set(&[7]), "config.json eos REPLACES default");
    assert_eq!(
        cc.eos_token_id,
        Some(Single(7)),
        "Config eos = config.json (scalar, no gen)"
    );

    // config.json list, no generation_config.
    let (ccl, tcl) = load_eos(&fixture("eos_cfg_list", Some("[9,10]"), None), "eos_cfg_list");
    assert_eq!(tcl, set(&[9, 10]), "config.json list REPLACES default");
    assert_eq!(
        ccl.eos_token_id,
        Some(Many(vec![9, 10])),
        "Config eos = config.json (list, no gen)"
    );

    // Both files set eos; generation_config wins.
    let (cp, tp) = load_eos(
        &fixture("eos_precedence", Some("7"), Some(r#"{"eos_token_id":8}"#)),
        "eos_precedence",
    );
    assert_eq!(tp, set(&[8]), "generation_config overrides config.json");
    assert_eq!(
        cp.eos_token_id,
        Some(Single(8)),
        "Config eos overwritten by truthy generation_config (precedence)"
    );

    // A malformed generation_config is tolerated and ignored.
    let (cb, tb) = load_eos(
        &fixture("eos_bad", None, Some("{ not json")),
        "malformed generation_config is tolerated",
    );
    assert_eq!(tb, base, "malformed ⇒ tokenizer default");
    assert_eq!(cb.eos_token_id, None, "malformed ⇒ Config eos untouched");
}