use super::*;
use crate::lm::quant::{PerLayerQuantization, QuantMode, Quantization};
/// Creates an empty scratch directory under the system temp dir for one test.
///
/// The directory name combines `tag`, the current process id, and a
/// process-wide sequence number, so repeated calls (even with the same `tag`)
/// yield distinct paths. Any leftover directory at that path is removed first.
///
/// Panics if the directory cannot be created.
fn fresh_dir(tag: &str) -> std::path::PathBuf {
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
    let pid = std::process::id();

    let mut path = std::env::temp_dir();
    path.push(format!("mlxrs-lm-save-{tag}-{pid}-{seq}"));

    // Best-effort wipe of a stale directory; a missing one is not an error.
    let _ = std::fs::remove_dir_all(&path);
    std::fs::create_dir_all(&path).unwrap();
    path
}
/// Builds a one-dimensional `f32` array of `n` zeros to use as a dummy weight.
///
/// Panics if the array cannot be constructed.
fn f32_weight(n: usize) -> Array {
    let zeros = vec![0.0_f32; n];
    let shape = (n,);
    Array::from_slice::<f32>(&zeros, &shape).unwrap()
}
/// `array_nbytes` must equal element count times the element width in bytes.
#[test]
fn array_nbytes_is_count_times_dtype_size() {
    // 10 × f32.
    let floats = f32_weight(10);
    assert_eq!(array_nbytes(&floats).unwrap(), 40);

    // 3 × u8.
    let bytes = Array::from_slice::<u8>(&[1u8, 2, 3], &(3usize,)).unwrap();
    assert_eq!(array_nbytes(&bytes).unwrap(), 3);

    // 2×8 u32.
    let words = Array::from_slice::<u32>(&[0u32; 16], &(2usize, 8)).unwrap();
    assert_eq!(array_nbytes(&words).unwrap(), 64);
}
/// Four small tensors under the default cap must all land in a single shard.
#[test]
fn make_shards_all_fits_one_shard() {
    let weights: Weights = ["a", "b", "c", "d"]
        .into_iter()
        .map(|name| (name.to_string(), f32_weight(25)))
        .collect();

    let shards = make_shards(&weights, MAX_FILE_SIZE_GB).unwrap();

    assert_eq!(shards.len(), 1, "4×100 bytes fits in one 5-GiB shard");
    assert_eq!(shards[0].len(), 4);
}
/// With a zero cap every tensor overflows: the expected layout is an empty
/// leading shard followed by one shard per tensor, in key order.
#[test]
fn make_shards_zero_cap_empty_leading_then_one_weight_per_shard() {
    let names = ["a", "b", "c", "d"];
    let mut weights: Weights = HashMap::new();
    for name in names {
        weights.insert(name.to_string(), f32_weight(25));
    }

    let shards = make_shards(&weights, 0).unwrap();

    assert_eq!(shards.len(), 5);
    assert!(
        shards[0].is_empty(),
        "guard-free split flushes an empty leading shard"
    );
    // Shards 1..=4 hold "a".."d" respectively.
    for (shard, name) in shards[1..].iter().zip(names) {
        assert!(shard.contains_key(name));
    }
    assert!(shards[1..].iter().all(|shard| shard.len() == 1));
}
/// When the first tensor already exceeds the cap, an empty shard is emitted
/// ahead of it and the remaining tensor gets its own shard.
#[test]
fn make_shards_over_cap_first_tensor_empty_leading_shard() {
    let weights: Weights = HashMap::from([
        ("big".to_string(), f32_weight(100)),
        ("small".to_string(), f32_weight(1)),
    ]);

    let shards = make_shards(&weights, 0).unwrap();

    assert_eq!(
        shards.len(),
        3,
        "empty leading + over-cap tensor + remainder"
    );
    assert!(
        shards[0].is_empty(),
        "over-cap first tensor flushes an empty leading shard"
    );

    let over_cap = &shards[1];
    assert_eq!(over_cap.len(), 1);
    assert!(over_cap.contains_key("big"));

    let remainder = &shards[2];
    assert_eq!(remainder.len(), 1);
    assert!(remainder.contains_key("small"));
}
/// An empty weight map still produces exactly one (empty) shard.
#[test]
fn make_shards_empty_map_yields_one_empty_shard() {
    let no_weights: Weights = HashMap::new();

    let shards = make_shards(&no_weights, MAX_FILE_SIZE_GB).unwrap();

    assert_eq!(shards.len(), 1);
    let only = &shards[0];
    assert!(only.is_empty());
}
/// A single tensor under the default cap yields one shard holding that tensor.
#[test]
fn make_shards_single_weight_one_shard() {
    let weights: Weights = HashMap::from([("solo".to_string(), f32_weight(7))]);

    let shards = make_shards(&weights, MAX_FILE_SIZE_GB).unwrap();

    assert_eq!(shards.len(), 1);
    let only = &shards[0];
    assert_eq!(only.len(), 1);
    assert!(only.contains_key("solo"));
}
/// With default (no) quantization, the total is the sum of element counts.
#[test]
fn get_total_parameters_dense_sums_sizes() {
    let weights: Weights = HashMap::from([
        ("model.embed.weight".to_string(), f32_weight(25)),
        ("model.norm.weight".to_string(), f32_weight(7)),
    ]);
    let no_quant = PerLayerQuantization::default();

    let count = get_total_parameters(&weights, &no_quant).unwrap();

    assert_eq!(count, 32);
}
/// For a quantized layer (packed `weight` + `scales` + `biases`) under a
/// global affine(64, 4) config, the expected total is 128 for the packed
/// 2×8 `u32` weight plus 7 for the dense norm weight; the `scales` and
/// `biases` tensors contribute nothing.
#[test]
fn get_total_parameters_quantized_unpacks_weight_skips_scales_and_biases() {
    let two_zeros = || Array::from_slice::<f32>(&[0.0, 0.0], &(2usize,)).unwrap();

    let mut weights: Weights = HashMap::new();
    weights.insert(
        "model.layers.0.q_proj.weight".to_string(),
        Array::from_slice::<u32>(&[0u32; 16], &(2usize, 8)).unwrap(),
    );
    weights.insert("model.layers.0.q_proj.scales".to_string(), two_zeros());
    weights.insert("model.layers.0.q_proj.biases".to_string(), two_zeros());
    weights.insert("model.norm.weight".to_string(), f32_weight(7));

    let global_quant = PerLayerQuantization::from_global(Quantization::affine(64, 4));
    let count = get_total_parameters(&weights, &global_quant).unwrap();

    assert_eq!(count, 128 + 7);
}
/// A regular module `bias` tensor (singular) is counted like any other weight.
#[test]
fn get_total_parameters_counts_genuine_module_bias() {
    let weights: Weights = HashMap::from([
        ("model.fc.weight".to_string(), f32_weight(5)),
        ("model.fc.bias".to_string(), f32_weight(3)),
    ]);
    let no_quant = PerLayerQuantization::default();

    let count = get_total_parameters(&weights, &no_quant).unwrap();

    assert_eq!(count, 8);
}
/// A `biases` tensor with no sibling `scales` is counted as ordinary
/// parameters rather than skipped.
#[test]
fn get_total_parameters_orphan_biases_without_scales_is_counted() {
    let weights: Weights = HashMap::from([
        ("model.x.weight".to_string(), f32_weight(4)),
        ("model.x.biases".to_string(), f32_weight(2)),
    ]);
    let no_quant = PerLayerQuantization::default();

    let count = get_total_parameters(&weights, &no_quant).unwrap();

    assert_eq!(count, 6);
}
/// A layer that carries `scales` (i.e. looks quantized) but has no
/// quantization parameters configured must fail with a layer-keyed invariant
/// violation naming the layer.
#[test]
fn get_total_parameters_quantized_without_params_errors() {
    let weights: Weights = HashMap::from([
        (
            "model.q.weight".to_string(),
            Array::from_slice::<u32>(&[0u32; 8], &(2usize, 4)).unwrap(),
        ),
        (
            "model.q.scales".to_string(),
            Array::from_slice::<f32>(&[0.0, 0.0], &(2usize,)).unwrap(),
        ),
    ]);

    let outcome = get_total_parameters(&weights, &PerLayerQuantization::default());
    let keyed = match outcome {
        Err(Error::LayerKeyed(p)) => p,
        other => panic!("expected Error::LayerKeyed, got {other:?}"),
    };

    assert_eq!(keyed.layer(), "model.q");
    assert!(
        matches!(keyed.inner(), Error::InvariantViolation(iv)
            if iv.context().contains("quantized layer") && iv.requirement().contains("resolvable")),
        "expected inner InvariantViolation about resolvable quantization params, got {:?}",
        keyed.inner()
    );
}
/// A purely dense `f32` model must report 32 bits per weight.
#[test]
fn compute_bits_per_weight_dense_f32_is_32() {
    let weights: Weights = HashMap::from([("model.w.weight".to_string(), f32_weight(10))]);
    let no_quant = PerLayerQuantization::default();

    let bpw = compute_bits_per_weight(&weights, &no_quant).unwrap();

    let deviation = (bpw - 32.0).abs();
    assert!(deviation < 1e-9, "expected 32.0, got {bpw}");
}
/// Bits-per-weight for a quantized layer must include the `scales` and
/// `biases` bytes: 64 (packed weight) + 8 + 8 = 80 bytes over 128 parameters.
#[test]
fn compute_bits_per_weight_quantized_includes_scale_overhead() {
    let two_zeros = || Array::from_slice::<f32>(&[0.0, 0.0], &(2usize,)).unwrap();

    let mut weights: Weights = HashMap::new();
    weights.insert(
        "model.q.weight".to_string(),
        Array::from_slice::<u32>(&[0u32; 16], &(2usize, 8)).unwrap(),
    );
    weights.insert("model.q.scales".to_string(), two_zeros());
    weights.insert("model.q.biases".to_string(), two_zeros());

    let global_quant = PerLayerQuantization::from_global(Quantization::affine(64, 4));
    let bpw = compute_bits_per_weight(&weights, &global_quant).unwrap();

    let expected = 80.0 * 8.0 / 128.0;
    assert!(
        (bpw - expected).abs() < 1e-9,
        "expected {expected}, got {bpw}"
    );
}
/// With no parameters at all, `compute_bits_per_weight` must return
/// `Error::EmptyInput` whose context names the model parameters.
#[test]
fn compute_bits_per_weight_zero_params_errors() {
    let no_weights: Weights = HashMap::new();

    let outcome = compute_bits_per_weight(&no_weights, &PerLayerQuantization::default());

    let is_expected = matches!(&outcome, Err(Error::EmptyInput(p))
        if p.context() == "compute_bits_per_weight: model parameters");
    assert!(
        is_expected,
        "expected Error::EmptyInput naming `model parameters`, got {outcome:?}"
    );
}
/// The mock model must be reported as not supporting input embeddings.
#[test]
fn does_model_support_input_embeddings_false_for_text_model() {
    let text_model = crate::lm::model::MockModel::new(4);
    let supported = does_model_support_input_embeddings(&text_model);
    assert!(!supported);
}
/// Shard file names embed the generation id plus zero-padded
/// `<index>-of-<count>`, and differ whenever the generation id differs.
#[test]
fn shard_file_name_generation_tagged() {
    let gen_id = "1234567890123-deadbeef-00000000cafef00d";
    let name_for = |index, total| shard_file_name(gen_id, index, total);

    assert_eq!(
        name_for(1, 1),
        format!("model-gen-{gen_id}-00001-of-00001.safetensors")
    );
    assert_eq!(
        name_for(1, 3),
        format!("model-gen-{gen_id}-00001-of-00003.safetensors")
    );
    assert_eq!(
        name_for(3, 3),
        format!("model-gen-{gen_id}-00003-of-00003.safetensors")
    );

    let from_first = shard_file_name("first-gen-id", 1, 1);
    let from_second = shard_file_name("second-gen-id", 1, 1);
    assert_ne!(
        from_first, from_second,
        "different generation ids must produce different shard names"
    );
}
/// Two successive generation ids must differ, each must have the shape
/// `<decimal ts>-<8 hex pid>-<16 hex ctr>`, and the pid component must match
/// between the two calls.
///
/// NOTE(review): the counter component is only shape-checked here; nothing
/// asserts that it actually advanced between the calls.
#[test]
fn new_gen_id_shape_and_counter_advance() {
    let first = new_gen_id();
    let second = new_gen_id();
    assert_ne!(first, second, "successive new_gen_id() calls must differ");

    let all_hex = |s: &str| s.chars().all(|c| c.is_ascii_hexdigit());

    for id in [&first, &second] {
        let parts: Vec<&str> = id.split('-').collect();
        assert_eq!(
            parts.len(),
            3,
            "gen_id has 3 dash-separated components: {id}"
        );
        let (ts, pid, ctr) = (parts[0], parts[1], parts[2]);

        assert!(
            ts.len() >= 13,
            "ts_us is at least 13 chars wide (the format-spec pad): {}",
            ts
        );
        assert!(
            ts.chars().all(|c| c.is_ascii_digit()),
            "ts_us is decimal: {}",
            ts
        );

        assert_eq!(pid.len(), 8, "pid is 8 hex chars: {}", pid);
        assert!(all_hex(pid), "pid is hex: {}", pid);

        assert_eq!(ctr.len(), 16, "ctr is 16 hex chars: {}", ctr);
        assert!(all_hex(ctr), "ctr is hex: {}", ctr);
    }

    let first_parts: Vec<&str> = first.split('-').collect();
    let second_parts: Vec<&str> = second.split('-').collect();
    assert_eq!(
        first_parts[1], second_parts[1],
        "PID stable across calls in the same process"
    );
}
/// Saves a two-tensor model that fits in one shard, then verifies the shard
/// file name, the reloaded tensor values, and the contents and formatting of
/// the written index JSON.
#[test]
fn save_model_single_shard_round_trips() {
    let dir = fresh_dir("save-model-single");
    let index_path = dir.join("model.safetensors.index.json");

    // Inserted out of key order on purpose: the index is expected sorted.
    let weights: Weights = HashMap::from([
        (
            "model.b.weight".to_string(),
            Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3usize,)).unwrap(),
        ),
        (
            "model.a.weight".to_string(),
            Array::from_slice::<f32>(&[4.0, 5.0], &(2usize,)).unwrap(),
        ),
    ]);
    save_model(&dir, &weights, &PerLayerQuantization::default()).unwrap();

    let shard_paths = collect_sorted(&dir, |file| {
        file.starts_with("model-gen-") && file.ends_with("-00001-of-00001.safetensors")
    })
    .unwrap();
    assert_eq!(
        shard_paths.len(),
        1,
        "exactly one generation-tagged single shard file"
    );
    assert!(index_path.is_file());

    // Round-trip the tensor values.
    let mut loaded = load_weights(&dir).unwrap();
    assert_eq!(loaded.len(), 2);
    let mut read = |key: &str| loaded.get_mut(key).unwrap().to_vec::<f32>().unwrap();
    assert_eq!(read("model.a.weight"), vec![4.0, 5.0]);
    assert_eq!(read("model.b.weight"), vec![1.0, 2.0, 3.0]);

    // Inspect the index: metadata totals, weight map targets, key order.
    let index_text = std::fs::read_to_string(&index_path).unwrap();
    let index: serde_json::Value = serde_json::from_str(&index_text).unwrap();
    let metadata = &index["metadata"];
    assert_eq!(metadata["total_size"], 20);
    assert_eq!(metadata["total_parameters"], 5);

    let wm = index["weight_map"].as_object().unwrap();
    assert_eq!(wm.len(), 2);
    let shard_basename = shard_paths[0]
        .file_name()
        .unwrap()
        .to_string_lossy()
        .into_owned();
    assert_eq!(wm["model.a.weight"], shard_basename);
    assert_eq!(wm["model.b.weight"], shard_basename);
    let keys: Vec<&String> = wm.keys().collect();
    assert_eq!(keys, vec!["model.a.weight", "model.b.weight"]);

    // Formatting: pretty-printed, and no trailing newline.
    assert!(index_text.contains("\n \"metadata\""));
    assert!(
        !index_text.ends_with('\n'),
        "json.dump writes no trailing newline"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// The array reference stored in a shard must point at the very same object
/// held by the input map (pointer equality), i.e. no copy was made.
#[test]
fn make_shards_borrows_without_cloning() {
    let weights: Weights = HashMap::from([("x".to_string(), f32_weight(3))]);

    let shards = make_shards(&weights, MAX_FILE_SIZE_GB).unwrap();
    assert_eq!(shards.len(), 1);

    let from_shard: &Array = shards[0]["x"];
    let from_map: &Array = weights.get("x").unwrap();
    let same_object = std::ptr::eq(from_shard, from_map);
    assert!(
        same_object,
        "make_shards must borrow the input array, not clone it"
    );
}
/// Hand-writes two generation-tagged shards plus a matching index, then checks
/// that the shard names follow the expected pattern, that the index and the
/// on-disk shard set agree exactly, and that `load_weights` reads both tensors.
#[test]
fn save_model_multi_shard_naming_and_index_reload() {
    use std::collections::BTreeSet;

    let dir = fresh_dir("save-model-multi");
    let w0 = Array::from_slice::<f32>(&[10.0], &(1usize,)).unwrap();
    let w1 = Array::from_slice::<f32>(&[20.0, 21.0], &(2usize,)).unwrap();
    let shards: Vec<Shard<'_>> = vec![
        BTreeMap::from([("w0", &w0)]),
        BTreeMap::from([("w1", &w1)]),
    ];
    let total = shards.len();
    let gen_id = "1234567890123-deadbeef-00000000cafef00d";
    let meta: HashMap<String, String> =
        HashMap::from([("format".to_string(), "mlx".to_string())]);

    // Write each shard and record which file every key lives in.
    let mut weight_map: BTreeMap<String, String> = BTreeMap::new();
    let mut written: Vec<String> = Vec::new();
    for (zero_based, shard) in shards.iter().enumerate() {
        let ordinal = zero_based + 1;
        let file = shard_file_name(gen_id, ordinal, total);
        assert_eq!(
            file,
            format!(
                "model-gen-{gen_id}-{:05}-of-{:05}.safetensors",
                ordinal, total
            )
        );
        let target = dir.join(&file);
        crate::io::save_safetensors_view(&target, shard.iter().map(|(&k, &v)| (k, v)), &meta)
            .unwrap();
        weight_map.extend(shard.keys().map(|&key| (key.to_string(), file.clone())));
        written.push(file);
    }

    write_json_pretty_to_path(
        &dir.join("model.safetensors.index.json"),
        &serde_json::json!({
            "metadata": { "total_size": 12, "total_parameters": 3 },
            "weight_map": weight_map,
        }),
        "test: 2-shard index",
    )
    .unwrap();

    // The index, the directory listing and our own bookkeeping must agree.
    let on_disk: BTreeSet<String> = collect_sorted(&dir, |file| {
        file.starts_with("model-gen-") && file.ends_with(".safetensors")
    })
    .unwrap()
    .into_iter()
    .map(|path| path.file_name().unwrap().to_string_lossy().into_owned())
    .collect();
    let indexed: BTreeSet<String> = weight_map.values().cloned().collect();
    assert_eq!(
        on_disk, indexed,
        "index `weight_map` values must exactly match the on-disk shard set"
    );
    let expected: BTreeSet<String> = written.into_iter().collect();
    assert_eq!(
        indexed, expected,
        "index lists every generation-tagged shard we wrote, no more, no less"
    );

    let mut loaded = load_weights(&dir).unwrap();
    assert_eq!(loaded.len(), 2);
    let mut read = |key: &str| loaded.get_mut(key).unwrap().to_vec::<f32>().unwrap();
    assert_eq!(read("w0"), vec![10.0]);
    assert_eq!(read("w1"), vec![20.0, 21.0]);

    let _ = std::fs::remove_dir_all(&dir);
}
/// `save_config` must drop `_name_or_path` and `vision_config`, mirror
/// `quantization` into `quantization_config`, keep the other keys, emit keys
/// in sorted order, pretty-print, and write no trailing newline.
#[test]
fn save_config_cleans_mirrors_and_sorts() {
    let dir = fresh_dir("save-config");
    let out_path = dir.join("config.json");
    let raw_config = r#"{
"model_type": "qwen3",
"_name_or_path": "/tmp/should-be-dropped",
"vision_config": {"drop": "me"},
"hidden_size": 64,
"quantization": {"group_size": 64, "bits": 4}
}"#;

    save_config(raw_config, &out_path).unwrap();

    let text = std::fs::read_to_string(&out_path).unwrap();
    let parsed: serde_json::Value = serde_json::from_str(&text).unwrap();
    let obj = parsed.as_object().unwrap();

    // Dropped keys.
    for dropped in ["_name_or_path", "vision_config"] {
        assert!(!obj.contains_key(dropped));
    }

    // Quantization is kept and mirrored.
    assert_eq!(obj["quantization"]["bits"], 4);
    let mirrored = &obj["quantization_config"];
    assert_eq!(mirrored["bits"], 4);
    assert_eq!(mirrored["group_size"], 64);

    // Untouched keys survive.
    assert_eq!(obj["model_type"], "qwen3");
    assert_eq!(obj["hidden_size"], 64);

    // Key order as serialized must already be sorted.
    let keys: Vec<&String> = obj.keys().collect();
    let mut sorted = keys.clone();
    sorted.sort();
    assert_eq!(keys, sorted, "config.json keys must be sorted");

    assert!(text.contains("\n \""));
    assert!(!text.ends_with('\n'));

    let _ = std::fs::remove_dir_all(&dir);
}
/// `save_config` must reject valid-but-non-object JSON with an
/// `InvariantViolation`, and a non-JSON body with a `Parse` error.
#[test]
fn save_config_rejects_non_object_json() {
    let dir = fresh_dir("save-config-bad");
    let target = dir.join("config.json");

    // A JSON array is well-formed but not an object.
    let array_outcome = save_config("[1, 2, 3]", &target);
    let array_rejected = matches!(&array_outcome, Err(Error::InvariantViolation(iv))
        if iv.context() == "save_config: config JSON" && iv.requirement() == "must be an object");
    assert!(
        array_rejected,
        "expected Error::InvariantViolation for non-object JSON, got {array_outcome:?}"
    );

    // Garbage that does not parse at all.
    let garbage_outcome = save_config("not json at all", &target);
    let garbage_rejected = matches!(&garbage_outcome, Err(Error::Parse(p))
        if p.context() == "save_config: config" && p.input_kind() == "JSON");
    assert!(
        garbage_rejected,
        "expected Error::Parse for non-JSON body, got {garbage_outcome:?}"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// The top-level `save` driver must write one generation-tagged shard, the
/// index, and a cleaned `config.json`.
#[test]
fn save_driver_writes_weights_and_config() {
    let dir = fresh_dir("save-driver");
    let weights: Weights = HashMap::from([(
        "model.embed_tokens.weight".to_string(),
        Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(4usize,)).unwrap(),
    )]);
    let config = r#"{"model_type": "qwen3", "_name_or_path": "drop", "hidden_size": 8}"#;

    save(&dir, &weights, config, &PerLayerQuantization::default()).unwrap();

    // Weights side.
    let shard_paths = collect_sorted(&dir, |file| {
        file.starts_with("model-gen-") && file.ends_with(".safetensors")
    })
    .unwrap();
    assert_eq!(
        shard_paths.len(),
        1,
        "the save driver produced exactly one generation-tagged shard"
    );
    assert!(dir.join("model.safetensors.index.json").is_file());

    let mut loaded = load_weights(&dir).unwrap();
    let embed = loaded
        .get_mut("model.embed_tokens.weight")
        .unwrap()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(embed, vec![1.0, 2.0, 3.0, 4.0]);

    // Config side.
    let cfg_text = std::fs::read_to_string(dir.join("config.json")).unwrap();
    let cfg: serde_json::Value = serde_json::from_str(&cfg_text).unwrap();
    let cfg_obj = cfg.as_object().unwrap();
    assert!(!cfg_obj.contains_key("_name_or_path"));
    assert_eq!(cfg["model_type"], "qwen3");
    assert_eq!(cfg["hidden_size"], 8);

    let _ = std::fs::remove_dir_all(&dir);
}
/// Seeds the directory with a stale three-shard checkpoint (non-generation
/// file names) and its index, saves a new two-tensor model over it, and checks
/// that only the new tensors load and that the index points at one
/// generation-tagged shard.
#[test]
fn save_model_overwrite_loads_only_new_weights() {
    let dir = fresh_dir("save-model-overwrite-loads-new");
    let index_path = dir.join("model.safetensors.index.json");

    // --- stale checkpoint -------------------------------------------------
    let stale_vals = [
        ("stale.a.weight", vec![100.0_f32]),
        ("stale.b.weight", vec![200.0_f32, 201.0]),
        ("stale.c.weight", vec![300.0_f32, 301.0, 302.0]),
    ];
    let stale_arrays: Vec<(&str, Array)> = stale_vals
        .iter()
        .map(|(key, values)| {
            let array = Array::from_slice::<f32>(values, &(values.len(),)).unwrap();
            (*key, array)
        })
        .collect();
    let meta: HashMap<String, String> =
        HashMap::from([("format".to_string(), "mlx".to_string())]);
    let stale_total = stale_arrays.len();
    let mut stale_map: BTreeMap<String, String> = BTreeMap::new();
    for (position, (key, array)) in stale_arrays.iter().enumerate() {
        let file = format!(
            "model-{:05}-of-{:05}.safetensors",
            position + 1,
            stale_total
        );
        let target = dir.join(&file);
        crate::io::save_safetensors_view(&target, std::iter::once((*key, array)), &meta).unwrap();
        stale_map.insert((*key).to_string(), file);
    }
    write_json_pretty_to_path(
        &index_path,
        &serde_json::json!({
            "metadata": { "total_size": 24, "total_parameters": 6 },
            "weight_map": stale_map,
        }),
        "test: stale index",
    )
    .unwrap();

    // --- new save on top --------------------------------------------------
    let fresh: Weights = HashMap::from([
        (
            "fresh.x.weight".to_string(),
            Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap(),
        ),
        (
            "fresh.y.weight".to_string(),
            Array::from_slice::<f32>(&[3.0], &(1usize,)).unwrap(),
        ),
    ]);
    save_model(&dir, &fresh, &PerLayerQuantization::default()).unwrap();

    // --- only the new tensors are visible ---------------------------------
    let mut loaded = load_weights(&dir).unwrap();
    assert_eq!(loaded.len(), 2, "only the two new weights load back");
    assert!(loaded.contains_key("fresh.x.weight"));
    assert!(loaded.contains_key("fresh.y.weight"));
    assert!(!loaded.contains_key("stale.a.weight"));
    assert!(!loaded.contains_key("stale.b.weight"));
    assert!(!loaded.contains_key("stale.c.weight"));
    let fresh_x = loaded
        .get_mut("fresh.x.weight")
        .unwrap()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(fresh_x, vec![1.0, 2.0]);

    // --- the index references a single generation-tagged shard -------------
    let index_text = std::fs::read_to_string(&index_path).unwrap();
    let index: serde_json::Value = serde_json::from_str(&index_text).unwrap();
    let wm = index["weight_map"].as_object().unwrap();
    assert_eq!(wm.len(), 2);
    let shard_x = wm["fresh.x.weight"].as_str().unwrap();
    let shard_y = wm["fresh.y.weight"].as_str().unwrap();
    assert_eq!(
        shard_x, shard_y,
        "both new weights land in the same single shard"
    );
    let tagged =
        shard_x.starts_with("model-gen-") && shard_x.ends_with("-00001-of-00001.safetensors");
    assert!(tagged, "new shard is generation-tagged: got {shard_x}");

    let _ = std::fs::remove_dir_all(&dir);
}
/// Saving the same weights twice into one directory must still load back
/// exactly that one tensor with its original values.
#[test]
fn save_model_resave_same_checkpoint_is_stable() {
    let dir = fresh_dir("save-model-resave");
    let weights: Weights = HashMap::from([("m.w.weight".to_string(), f32_weight(4))]);

    // Two back-to-back saves of identical content.
    for _ in 0..2 {
        save_model(&dir, &weights, &PerLayerQuantization::default()).unwrap();
    }

    let mut loaded = load_weights(&dir).unwrap();
    assert_eq!(loaded.len(), 1);
    assert!(loaded.contains_key("m.w.weight"));
    let values = loaded
        .get_mut("m.w.weight")
        .unwrap()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(values, vec![0.0, 0.0, 0.0, 0.0]);

    let _ = std::fs::remove_dir_all(&dir);
}
/// Two saves into the same directory must leave both shard files on disk,
/// while the index (and therefore `load_weights`) refers only to the shard of
/// the second save.
#[test]
fn save_model_no_overwrite_of_old_shards() {
    let dir = fresh_dir("save-no-overwrite");

    // Basenames of every generation-tagged shard currently in `dir`.
    let list_shards = || -> Vec<String> {
        collect_sorted(&dir, |file| {
            file.starts_with("model-gen-") && file.ends_with(".safetensors")
        })
        .unwrap()
        .into_iter()
        .map(|path| path.file_name().unwrap().to_string_lossy().into_owned())
        .collect()
    };

    // First save.
    let first: Weights = HashMap::from([(
        "w.first.weight".to_string(),
        Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3usize,)).unwrap(),
    )]);
    save_model(&dir, &first, &PerLayerQuantization::default()).unwrap();
    let first_shards = list_shards();
    assert_eq!(first_shards.len(), 1, "first save writes one shard");

    std::thread::sleep(std::time::Duration::from_millis(5));

    // Second save, different content.
    let second: Weights = HashMap::from([(
        "w.second.weight".to_string(),
        Array::from_slice::<f32>(&[10.0, 20.0], &(2usize,)).unwrap(),
    )]);
    save_model(&dir, &second, &PerLayerQuantization::default()).unwrap();
    let all_shards = list_shards();
    assert_eq!(
        all_shards.len(),
        2,
        "both saves' shard files coexist on disk (no inline cleanup); got {all_shards:?}"
    );
    for s in &first_shards {
        assert!(
            all_shards.contains(s),
            "the first save's shard {s} must survive the second save"
        );
    }

    // The index must point only at the second save's shard.
    let index_text = std::fs::read_to_string(dir.join("model.safetensors.index.json")).unwrap();
    let index: serde_json::Value = serde_json::from_str(&index_text).unwrap();
    let wm = index["weight_map"].as_object().unwrap();
    assert_eq!(wm.len(), 1, "second save's index lists one weight");
    let indexed: std::collections::BTreeSet<String> = wm
        .values()
        .filter_map(|value| value.as_str().map(|s| s.to_string()))
        .collect();
    assert_eq!(
        indexed.len(),
        1,
        "all keys in the new index reference exactly one shard"
    );
    let indexed_shard = indexed.iter().next().unwrap().clone();
    assert!(
        !first_shards.contains(&indexed_shard),
        "the second save's index must not reference the first save's shard"
    );

    // Loading sees only the second checkpoint.
    let mut loaded = load_weights(&dir).unwrap();
    assert_eq!(loaded.len(), 1, "load sees only the new checkpoint");
    assert!(
        loaded.contains_key("w.second.weight"),
        "the second save's weight loads"
    );
    assert!(
        !loaded.contains_key("w.first.weight"),
        "the first save's weight is invisible to load (orphan on disk only)"
    );
    let second_values = loaded
        .get_mut("w.second.weight")
        .unwrap()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(second_values, vec![10.0, 20.0]);

    let _ = std::fs::remove_dir_all(&dir);
}
/// A `save_model` call that fails must leave the previously saved checkpoint
/// fully intact: same loadable tensors, same index text, byte-identical shard
/// files, and no `.tmp.safetensors` file left behind.
///
/// The failure is provoked by making the target directory read-only (mode
/// `0o500`) for the duration of the second save. Unix-only because it relies
/// on `PermissionsExt`.
///
/// NOTE(review): permission bits are usually not enforced for root, so this
/// test presumably fails at the `r.is_err()` assertion when run as root —
/// confirm against the CI environment.
#[cfg(unix)]
#[test]
fn save_model_failed_save_keeps_previous_checkpoint_intact() {
use std::os::unix::fs::PermissionsExt;
let dir = fresh_dir("save-model-failed-intact");
// Step 1: write a good checkpoint and snapshot everything it produced.
let mut orig: Weights = HashMap::new();
orig.insert(
"orig.a.weight".to_string(),
Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3usize,)).unwrap(),
);
orig.insert(
"orig.b.weight".to_string(),
Array::from_slice::<f32>(&[4.0, 5.0], &(2usize,)).unwrap(),
);
save_model(&dir, &orig, &PerLayerQuantization::default()).unwrap();
let orig_shards: Vec<std::path::PathBuf> = collect_sorted(&dir, |n| {
n.starts_with("model-gen-") && n.ends_with(".safetensors")
})
.unwrap();
assert!(
!orig_shards.is_empty(),
"the original save produced at least one generation-tagged shard"
);
// Raw bytes of every original shard, keyed by path, for the later comparison.
let orig_shard_bytes: BTreeMap<std::path::PathBuf, Vec<u8>> = orig_shards
.iter()
.map(|p| (p.clone(), std::fs::read(p).unwrap()))
.collect();
let orig_index = std::fs::read_to_string(dir.join("model.safetensors.index.json")).unwrap();
// Step 2: make the directory read+execute only for the owner so the next
// save cannot create files in it; remember the old mode to restore it.
let mut perms = std::fs::metadata(&dir).unwrap().permissions();
let orig_mode = perms.mode();
perms.set_mode(0o500); std::fs::set_permissions(&dir, perms).unwrap();
let mut replacement: Weights = HashMap::new();
replacement.insert("SHOULD.NOT.WIN.weight".to_string(), f32_weight(7));
// Capture the result without asserting yet: permissions must be restored
// first so the final cleanup (and a failing assertion) do not strand a
// read-only directory.
let r = save_model(&dir, &replacement, &PerLayerQuantization::default());
let mut restore = std::fs::metadata(&dir).unwrap().permissions();
restore.set_mode(orig_mode);
std::fs::set_permissions(&dir, restore).unwrap();
assert!(r.is_err(), "a save into a read-only dir must fail");
// Step 3: the original checkpoint still loads, with the original values.
let mut loaded = load_weights(&dir).unwrap();
assert_eq!(loaded.len(), 2, "only the original two weights load back");
assert!(loaded.contains_key("orig.a.weight"));
assert!(loaded.contains_key("orig.b.weight"));
assert!(
!loaded.contains_key("SHOULD.NOT.WIN.weight"),
"the failed save's weight must not have leaked in"
);
assert_eq!(
loaded
.get_mut("orig.a.weight")
.unwrap()
.to_vec::<f32>()
.unwrap(),
vec![1.0, 2.0, 3.0]
);
assert_eq!(
loaded
.get_mut("orig.b.weight")
.unwrap()
.to_vec::<f32>()
.unwrap(),
vec![4.0, 5.0]
);
// Step 4: on-disk artifacts are unchanged — index text and shard bytes.
assert_eq!(
std::fs::read_to_string(dir.join("model.safetensors.index.json")).unwrap(),
orig_index,
"the original index.json must survive the failed save unchanged"
);
for (path, bytes) in &orig_shard_bytes {
assert!(
path.is_file(),
"original shard {} must survive the failed save",
path.display()
);
assert_eq!(
&std::fs::read(path).unwrap(),
bytes,
"original shard {} must be byte-identical after the failed save",
path.display()
);
}
// Step 5: no staged tempfile (name ending in `.tmp.safetensors`) remains.
let leftover_tmp = std::fs::read_dir(&dir)
.unwrap()
.filter_map(|e| e.ok())
.any(|e| {
e.file_name()
.to_string_lossy()
.ends_with(".tmp.safetensors")
});
assert!(
!leftover_tmp,
"no partial tempfile may remain after a failed save"
);
let _ = std::fs::remove_dir_all(&dir);
}
/// Plants a directory at the index path so the save cannot complete, then
/// checks that `save_model` fails, leaves that directory untouched, and
/// removes every `.tmp.safetensors` file it staged.
#[test]
fn save_model_failed_save_rename_failure_cleans_up_tempfiles() {
    let dir = fresh_dir("save-model-failed-rename");
    let index_path = dir.join("model.safetensors.index.json");

    // A directory squatting on the index file name.
    std::fs::create_dir_all(&index_path).unwrap();

    let weights: Weights = HashMap::from([("m.w.weight".to_string(), f32_weight(4))]);
    let outcome = save_model(&dir, &weights, &PerLayerQuantization::default());

    assert!(
        outcome.is_err(),
        "rename of the index onto an existing directory must fail"
    );
    assert!(
        index_path.is_dir(),
        "the colliding directory at the index path must be left untouched"
    );

    let leftover_tmp = std::fs::read_dir(&dir)
        .unwrap()
        .filter_map(|entry| entry.ok())
        .map(|entry| entry.file_name().to_string_lossy().into_owned())
        .any(|name| name.ends_with(".tmp.safetensors"));
    assert!(
        !leftover_tmp,
        "every staged tempfile must be removed when a rename fails"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// A `save_config` call that fails must leave the previously written
/// `config.json` unchanged and leave no tempfile behind.
///
/// The failure is provoked by making the directory read-only (mode `0o500`)
/// around the second write. Unix-only because it relies on `PermissionsExt`.
///
/// NOTE(review): permission bits are usually not enforced for root, so the
/// `r.is_err()` assertion presumably fails when run as root — confirm against
/// the CI environment.
#[cfg(unix)]
#[test]
fn save_config_failed_write_keeps_previous_config_intact() {
use std::os::unix::fs::PermissionsExt;
let dir = fresh_dir("save-config-failed-intact");
let config_path = dir.join("config.json");
// Write a good config and remember its exact text.
save_config(r#"{"model_type": "good", "hidden_size": 8}"#, &config_path).unwrap();
let orig = std::fs::read_to_string(&config_path).unwrap();
// Drop write permission on the directory, remembering the old mode.
let mut perms = std::fs::metadata(&dir).unwrap().permissions();
let orig_mode = perms.mode();
perms.set_mode(0o500);
std::fs::set_permissions(&dir, perms).unwrap();
// Capture the result; restore permissions before asserting anything.
let r = save_config(r#"{"model_type": "SHOULD-NOT-WIN"}"#, &config_path);
let mut restore = std::fs::metadata(&dir).unwrap().permissions();
restore.set_mode(orig_mode);
std::fs::set_permissions(&dir, restore).unwrap();
assert!(r.is_err(), "a config write into a read-only dir must fail");
assert_eq!(
std::fs::read_to_string(&config_path).unwrap(),
orig,
"the original config.json must survive the failed write unchanged"
);
// NOTE(review): this only detects names ending in `.tmp.safetensors`;
// confirm that is the suffix the config writer actually uses for its
// tempfile, otherwise this check cannot catch a leftover.
let leftover_tmp = std::fs::read_dir(&dir)
.unwrap()
.filter_map(|e| e.ok())
.any(|e| {
e.file_name()
.to_string_lossy()
.ends_with(".tmp.safetensors")
});
assert!(
!leftover_tmp,
"no partial tempfile may remain after a failed config write"
);
let _ = std::fs::remove_dir_all(&dir);
}
/// When an index is present, `load_weights` must read only the shards it
/// lists; an extra `.safetensors` file in the directory is ignored.
#[test]
fn load_weights_ignores_stale_shards_not_in_index() {
    let dir = fresh_dir("load-ignores-stale");
    let meta: HashMap<String, String> =
        HashMap::from([("format".to_string(), "mlx".to_string())]);

    // The shard the index will reference.
    let real = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3usize,)).unwrap();
    let real_path = dir.join("model.safetensors");
    crate::io::save_safetensors_view(&real_path, std::iter::once(("real.weight", &real)), &meta)
        .unwrap();

    // A shard that exists on disk but is absent from the index.
    let stale = Array::from_slice::<f32>(&[99.0], &(1usize,)).unwrap();
    let stale_path = dir.join("model-00099-of-00099.safetensors");
    crate::io::save_safetensors_view(
        &stale_path,
        std::iter::once(("stale.weight", &stale)),
        &meta,
    )
    .unwrap();

    let weight_map: BTreeMap<String, String> = BTreeMap::from([(
        "real.weight".to_string(),
        "model.safetensors".to_string(),
    )]);
    write_json_pretty_to_path(
        &dir.join("model.safetensors.index.json"),
        &serde_json::json!({
            "metadata": { "total_size": 12, "total_parameters": 3 },
            "weight_map": weight_map,
        }),
        "test: index ignores stale",
    )
    .unwrap();

    let mut loaded = load_weights(&dir).unwrap();
    assert_eq!(
        loaded.len(),
        1,
        "only the indexed weight loads; the stale shard is invisible"
    );
    assert!(loaded.contains_key("real.weight"));
    assert!(
        !loaded.contains_key("stale.weight"),
        "an out-of-index shard must NOT resurrect tensors on load"
    );
    let real_values = loaded
        .get_mut("real.weight")
        .unwrap()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(real_values, vec![1.0, 2.0, 3.0]);

    let _ = std::fs::remove_dir_all(&dir);
}
/// Without an index file, a lone `model.safetensors` must still load.
#[test]
fn load_weights_no_index_single_model_safetensors_loads() {
    let dir = fresh_dir("load-single-no-index");
    let meta: HashMap<String, String> =
        HashMap::from([("format".to_string(), "mlx".to_string())]);

    let tensor = Array::from_slice::<f32>(&[7.0, 8.0], &(2usize,)).unwrap();
    let shard_path = dir.join("model.safetensors");
    crate::io::save_safetensors_view(
        &shard_path,
        std::iter::once(("only.weight", &tensor)),
        &meta,
    )
    .unwrap();

    // Precondition: really no index on disk.
    assert!(!dir.join("model.safetensors.index.json").exists());

    let mut loaded = load_weights(&dir).unwrap();
    assert_eq!(loaded.len(), 1);
    let values = loaded
        .get_mut("only.weight")
        .unwrap()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(values, vec![7.0, 8.0]);

    let _ = std::fs::remove_dir_all(&dir);
}
/// With neither an index nor `model.safetensors`, a `weights.safetensors`
/// file must be picked up by `load_weights`.
#[test]
fn load_weights_legacy_weights_safetensors_fallback_loads() {
    let dir = fresh_dir("load-legacy-weights");
    let meta: HashMap<String, String> =
        HashMap::from([("format".to_string(), "mlx".to_string())]);

    let tensor = Array::from_slice::<f32>(&[42.0], &(1usize,)).unwrap();
    let legacy_path = dir.join("weights.safetensors");
    crate::io::save_safetensors_view(
        &legacy_path,
        std::iter::once(("legacy.weight", &tensor)),
        &meta,
    )
    .unwrap();

    // Preconditions: the two other layouts are absent.
    assert!(!dir.join("model.safetensors").exists());
    assert!(!dir.join("model.safetensors.index.json").exists());

    let mut loaded = load_weights(&dir).unwrap();
    assert_eq!(loaded.len(), 1);
    let values = loaded
        .get_mut("legacy.weight")
        .unwrap()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(values, vec![42.0]);

    let _ = std::fs::remove_dir_all(&dir);
}
/// An index that names a shard which does not exist on disk must make
/// `load_weights` fail with a `FileIo` error (op `Stat`, kind `NotFound`)
/// whose path is the missing shard.
#[test]
fn load_weights_index_lists_missing_shard_errors() {
    let dir = fresh_dir("load-index-missing-shard");
    let present = "model-00001-of-00002.safetensors";
    let missing = "model-00002-of-00002.safetensors";

    let meta: HashMap<String, String> =
        HashMap::from([("format".to_string(), "mlx".to_string())]);
    let tensor = Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap();
    crate::io::save_safetensors_view(
        &dir.join(present),
        std::iter::once(("a.weight", &tensor)),
        &meta,
    )
    .unwrap();

    // The index claims two shards; only the first was written.
    let weight_map: BTreeMap<String, String> = BTreeMap::from([
        ("a.weight".to_string(), present.to_string()),
        ("b.weight".to_string(), missing.to_string()),
    ]);
    write_json_pretty_to_path(
        &dir.join("model.safetensors.index.json"),
        &serde_json::json!({
            "metadata": { "total_size": 8, "total_parameters": 2 },
            "weight_map": weight_map,
        }),
        "test: missing-shard index",
    )
    .unwrap();

    let failure = match load_weights(&dir) {
        Err(Error::FileIo(p)) => p,
        other => panic!("a missing indexed shard must be an Error::FileIo, got {other:?}"),
    };
    assert_eq!(failure.op(), FileOp::Stat);
    assert_eq!(failure.inner().kind(), std::io::ErrorKind::NotFound);
    assert_eq!(
        failure.path(),
        dir.join(missing).as_path(),
        "path must name the missing shard"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// A `weight_map` entry whose shard name is a relative path escaping the
/// directory must be rejected with a layer-keyed invariant violation that
/// names the offending mapping.
#[test]
fn load_weights_index_with_path_traversal_errors() {
    let dir = fresh_dir("load-index-path-traversal");

    let weight_map: BTreeMap<String, String> = BTreeMap::from([(
        "evil.weight".to_string(),
        "../../etc/passwd".to_string(),
    )]);
    write_json_pretty_to_path(
        &dir.join("model.safetensors.index.json"),
        &serde_json::json!({
            "metadata": { "total_size": 0, "total_parameters": 0 },
            "weight_map": weight_map,
        }),
        "test: path-traversal index",
    )
    .unwrap();

    let keyed = match load_weights(&dir) {
        Err(Error::LayerKeyed(p)) => p,
        other => panic!("a path-traversal shard name must be an Error::LayerKeyed, got {other:?}"),
    };

    let names_mapping = keyed.layer().contains("weight_map[evil.weight]")
        && keyed.layer().contains("../../etc/passwd");
    assert!(
        names_mapping,
        "layer should name the offending mapping, got `{}`",
        keyed.layer()
    );
    assert!(
        matches!(keyed.inner(), Error::InvariantViolation(iv)
            if iv.context().contains("weight_map shard name")
                && iv.requirement().contains("bare basename")),
        "expected inner InvariantViolation about bare basename, got {:?}",
        keyed.inner()
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// An index file that is not valid JSON must produce a layer-keyed error
/// naming the index path and wrapping a JSON `Parse` error.
#[test]
fn load_weights_malformed_index_errors() {
    let dir = fresh_dir("load-index-malformed");
    let index_path = dir.join("model.safetensors.index.json");
    std::fs::write(index_path, b"this is not valid JSON {{{").unwrap();

    let keyed = match load_weights(&dir) {
        Err(Error::LayerKeyed(p)) => p,
        other => panic!("expected Error::LayerKeyed for a malformed index, got {other:?}"),
    };

    assert!(
        keyed.layer().contains("model.safetensors.index.json"),
        "layer should name the index path, got `{}`",
        keyed.layer()
    );
    assert!(
        matches!(keyed.inner(), Error::Parse(pp)
            if pp.context() == "load_via_index: model weight index" && pp.input_kind() == "JSON"),
        "expected inner Error::Parse for malformed JSON, got {:?}",
        keyed.inner()
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// Loading from an empty directory must fail with a `FileIo` error (op
/// `Open`, kind `NotFound`, path = the directory) whose context mentions all
/// three file names the loader looks for.
#[test]
fn load_weights_empty_dir_errors_listing_layouts() {
    let dir = fresh_dir("load-empty");

    let failure = match load_weights(&dir) {
        Err(Error::FileIo(p)) => p,
        other => panic!("an empty dir must be an Error::FileIo, got {other:?}"),
    };

    assert_eq!(failure.path(), dir.as_path());
    assert_eq!(failure.op(), FileOp::Open);
    assert_eq!(failure.inner().kind(), std::io::ErrorKind::NotFound);

    let ctx = failure.context();
    let lists_every_tier = ctx.contains("model.safetensors.index.json")
        && ctx.contains("model.safetensors")
        && ctx.contains("weights.safetensors");
    assert!(
        lists_every_tier,
        "the context must list each resolver tier, got: {ctx}"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// Hand-builds a two-shard checkpoint, then replaces the index file with a
/// directory so that `save_model` must fail. Afterwards the old shards must
/// still exist, exactly one `model-gen-*` shard must be on disk, no
/// `.tmp.safetensors` file may remain, and restoring the old index must load
/// the old weights unchanged.
#[test]
fn save_model_torn_publish_before_index_rename_keeps_old_checkpoint() {
    let dir = fresh_dir("torn-publish-before-index-rename");
    let index_path = dir.join("model.safetensors.index.json");
    let old_a_path = dir.join("model-00001-of-00002.safetensors");
    let old_b_path = dir.join("model-00002-of-00002.safetensors");

    // Write the OLD checkpoint by hand: two shards plus an index naming them.
    let meta: HashMap<String, String> =
        HashMap::from([("format".to_string(), "mlx".to_string())]);
    let old_a = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap();
    let old_b = Array::from_slice::<f32>(&[3.0, 4.0, 5.0], &(3usize,)).unwrap();
    crate::io::save_safetensors_view(
        &old_a_path,
        std::iter::once(("old.a.weight", &old_a)),
        &meta,
    )
    .unwrap();
    crate::io::save_safetensors_view(
        &old_b_path,
        std::iter::once(("old.b.weight", &old_b)),
        &meta,
    )
    .unwrap();
    let old_wm: BTreeMap<String, String> = BTreeMap::from([
        (
            "old.a.weight".to_string(),
            "model-00001-of-00002.safetensors".to_string(),
        ),
        (
            "old.b.weight".to_string(),
            "model-00002-of-00002.safetensors".to_string(),
        ),
    ]);
    let old_index_text = serde_json::to_string(&serde_json::json!({
        "metadata": { "total_size": 20, "total_parameters": 5 },
        "weight_map": old_wm,
    }))
    .unwrap();
    std::fs::write(&index_path, old_index_text.as_bytes()).unwrap();

    // Sanity check: the hand-built checkpoint loads before the index is sabotaged.
    {
        let mut sanity = load_weights(&dir).unwrap();
        assert_eq!(sanity.len(), 2);
        assert_eq!(
            sanity
                .get_mut("old.a.weight")
                .unwrap()
                .to_vec::<f32>()
                .unwrap(),
            vec![1.0, 2.0]
        );
    }

    // Put a directory where the index file lives, then attempt a new save.
    std::fs::remove_file(&index_path).unwrap();
    std::fs::create_dir_all(&index_path).unwrap();
    let new_w: Weights = std::iter::once((
        "new.x.weight".to_string(),
        Array::from_slice::<f32>(&[100.0], &(1usize,)).unwrap(),
    ))
    .collect();
    let save_result = save_model(&dir, &new_w, &PerLayerQuantization::default());
    assert!(
        save_result.is_err(),
        "the index rename onto an existing directory must fail"
    );

    assert!(
        old_a_path.is_file(),
        "OLD shard 1 must survive the failed save"
    );
    assert!(
        old_b_path.is_file(),
        "OLD shard 2 must survive the failed save"
    );

    let new_shards_on_disk: Vec<std::path::PathBuf> = collect_sorted(&dir, |name| {
        name.starts_with("model-gen-") && name.ends_with(".safetensors")
    })
    .unwrap();
    assert_eq!(
        new_shards_on_disk.len(),
        1,
        "the NEW shard rename SHOULD have succeeded (it's the index rename that fails); \
         this asserts the torn-publish scenario the test is targeting"
    );

    let leftover_tmp = std::fs::read_dir(&dir)
        .unwrap()
        .flatten()
        .map(|entry| entry.file_name())
        .any(|name| name.to_string_lossy().ends_with(".tmp.safetensors"));
    assert!(
        !leftover_tmp,
        "every staged tempfile must be removed when the index rename fails"
    );

    // Put the OLD index back and confirm it still resolves to the OLD weights only.
    std::fs::remove_dir_all(&index_path).unwrap();
    std::fs::write(&index_path, old_index_text.as_bytes()).unwrap();
    let mut loaded = load_weights(&dir).unwrap();
    assert_eq!(
        loaded.len(),
        2,
        "the OLD checkpoint loads EXACTLY (both weights)"
    );
    assert_eq!(
        loaded
            .get_mut("old.a.weight")
            .unwrap()
            .to_vec::<f32>()
            .unwrap(),
        vec![1.0, 2.0],
        "old.a is byte-identical"
    );
    assert_eq!(
        loaded
            .get_mut("old.b.weight")
            .unwrap()
            .to_vec::<f32>()
            .unwrap(),
        vec![3.0, 4.0, 5.0],
        "old.b is byte-identical"
    );
    assert!(
        !loaded.contains_key("new.x.weight"),
        "the NEW shard is on disk but the OLD index ignores it"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// Saves a first checkpoint with `save_model`, then replaces the index file
/// with a directory so that a second `save_model` must fail. The first save's
/// shards must remain byte-identical, and restoring the first index must load
/// exactly the first weights.
#[test]
fn save_model_torn_after_shard_rename_before_index_rename_keeps_old_checkpoint() {
    let dir = fresh_dir("torn-after-shard-before-index");
    let index_path = dir.join("model.safetensors.index.json");

    let first: Weights = HashMap::from([
        (
            "first.alpha.weight".to_string(),
            Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3usize,)).unwrap(),
        ),
        (
            "first.beta.weight".to_string(),
            Array::from_slice::<f32>(&[4.0, 5.0], &(2usize,)).unwrap(),
        ),
    ]);
    save_model(&dir, &first, &PerLayerQuantization::default()).unwrap();

    // Record every shard the first save produced, along with its exact bytes.
    let old_shard_paths: Vec<std::path::PathBuf> = collect_sorted(&dir, |name| {
        name.starts_with("model-gen-") && name.ends_with(".safetensors")
    })
    .unwrap();
    assert!(
        !old_shard_paths.is_empty(),
        "first save produced at least one shard"
    );
    let mut old_shard_bytes: BTreeMap<std::path::PathBuf, Vec<u8>> = BTreeMap::new();
    for shard in &old_shard_paths {
        old_shard_bytes.insert(shard.clone(), std::fs::read(shard).unwrap());
    }
    let old_index_text = std::fs::read_to_string(&index_path).unwrap();

    // Swap the index file for a directory, pause briefly, then save again.
    std::fs::remove_file(&index_path).unwrap();
    std::fs::create_dir_all(&index_path).unwrap();
    std::thread::sleep(std::time::Duration::from_millis(5));
    let second: Weights = std::iter::once((
        "second.gamma.weight".to_string(),
        Array::from_slice::<f32>(&[100.0, 200.0], &(2usize,)).unwrap(),
    ))
    .collect();
    let second_result = save_model(&dir, &second, &PerLayerQuantization::default());
    assert!(
        second_result.is_err(),
        "the index rename onto an existing directory must fail"
    );

    for (path, bytes) in &old_shard_bytes {
        assert!(
            path.is_file(),
            "OLD shard {} must survive the failed save",
            path.display()
        );
        let current = std::fs::read(path).unwrap();
        assert_eq!(
            &current,
            bytes,
            "OLD shard {} must be byte-identical after the failed save",
            path.display()
        );
    }

    // Restore the first index and verify it loads only the first save's tensors.
    std::fs::remove_dir_all(&index_path).unwrap();
    std::fs::write(&index_path, old_index_text.as_bytes()).unwrap();
    let mut loaded = load_weights(&dir).unwrap();
    assert_eq!(loaded.len(), 2);
    assert!(loaded.contains_key("first.alpha.weight"));
    assert!(loaded.contains_key("first.beta.weight"));
    assert!(
        !loaded.contains_key("second.gamma.weight"),
        "the SECOND save's shard is on disk but the OLD index ignores it"
    );
    assert_eq!(
        loaded
            .get_mut("first.alpha.weight")
            .unwrap()
            .to_vec::<f32>()
            .unwrap(),
        vec![1.0, 2.0, 3.0]
    );
    assert_eq!(
        loaded
            .get_mut("first.beta.weight")
            .unwrap()
            .to_vec::<f32>()
            .unwrap(),
        vec![4.0, 5.0]
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// `fsync_dir` must return `Ok(())` for a freshly created temporary directory.
#[test]
fn fsync_dir_helper_basic() {
    let scratch = fresh_dir("fsync-dir-helper");
    // The annotation pins the helper's return type as well as its success.
    let outcome: std::io::Result<()> = fsync_dir(&scratch);
    outcome.expect("fsync_dir must succeed on a writable tmpdir");
    let _ = std::fs::remove_dir_all(&scratch);
}
/// After a successful `save`, a second `save` with a config string that is not
/// JSON must fail and leave every regular file in the directory byte-identical,
/// with no `.tmp.safetensors` file left behind and the original weights still
/// loadable.
#[test]
fn save_invalid_config_keeps_existing_checkpoint_byte_identical() {
    let dir = fresh_dir("save-invalid-config-intact");

    let w: Weights = HashMap::from([
        (
            "orig.a.weight".to_string(),
            Array::from_slice::<f32>(&[10.0, 20.0], &(2usize,)).unwrap(),
        ),
        (
            "orig.b.weight".to_string(),
            Array::from_slice::<f32>(&[30.0], &(1usize,)).unwrap(),
        ),
    ]);
    let good_config = r#"{"model_type": "qwen3", "hidden_size": 64}"#;
    save(&dir, &w, good_config, &PerLayerQuantization::default()).unwrap();

    // Maps each regular file's name to its contents (directories are skipped).
    let snapshot = |root: &Path| -> BTreeMap<String, Vec<u8>> {
        std::fs::read_dir(root)
            .unwrap()
            .flatten()
            .filter(|entry| entry.file_type().unwrap().is_file())
            .map(|entry| {
                let name = entry.file_name().to_string_lossy().into_owned();
                let bytes = std::fs::read(entry.path()).unwrap();
                (name, bytes)
            })
            .collect()
    };

    let before = snapshot(&dir);
    let has_gen_shard = before
        .keys()
        .any(|k| k.starts_with("model-gen-") && k.ends_with(".safetensors"));
    assert!(
        has_gen_shard,
        "the initial save produced a generation-tagged shard"
    );
    assert!(before.contains_key("model.safetensors.index.json"));
    assert!(before.contains_key("config.json"));

    let bad_config = "this is not valid JSON at all";
    let other_weights: Weights = std::iter::once((
        "SHOULD.NOT.WIN.weight".to_string(),
        Array::from_slice::<f32>(&[999.0], &(1usize,)).unwrap(),
    ))
    .collect();
    let failed = save(
        &dir,
        &other_weights,
        bad_config,
        &PerLayerQuantization::default(),
    );
    assert!(failed.is_err(), "an invalid config must abort the save");

    let after = snapshot(&dir);
    // Staged tempfiles are compared separately, so drop them from both sides.
    let strip_tmp = |mut files: BTreeMap<String, Vec<u8>>| -> BTreeMap<String, Vec<u8>> {
        files.retain(|name, _| !name.ends_with(".tmp.safetensors"));
        files
    };
    let leftover_tmp = after.keys().any(|k| k.ends_with(".tmp.safetensors"));
    assert_eq!(
        strip_tmp(before),
        strip_tmp(after),
        "every file under {} must be byte-identical after an invalid-config save",
        dir.display()
    );
    assert!(
        !leftover_tmp,
        "no staged config tempfile may remain after an invalid-config save"
    );

    let mut loaded = load_weights(&dir).unwrap();
    assert_eq!(loaded.len(), 2);
    assert_eq!(
        loaded
            .get_mut("orig.a.weight")
            .unwrap()
            .to_vec::<f32>()
            .unwrap(),
        vec![10.0, 20.0]
    );
    assert!(!loaded.contains_key("SHOULD.NOT.WIN.weight"));

    let _ = std::fs::remove_dir_all(&dir);
}
/// For each of `Mxfp4`, `Mxfp8` and `Nvfp4`, a weight set that contains a
/// `.biases` tensor must make `get_total_parameters` return an
/// `Error::LayerKeyed` wrapping an `Error::KeyCollision` for that key.
#[test]
fn get_total_parameters_scale_only_biases_is_error() {
    for mode in [QuantMode::Mxfp4, QuantMode::Mxfp8, QuantMode::Nvfp4] {
        let w: Weights = HashMap::from([
            (
                "model.layers.0.q_proj.weight".to_string(),
                Array::from_slice::<u32>(&[0u32; 8], &(2usize, 4)).unwrap(),
            ),
            (
                "model.layers.0.q_proj.scales".to_string(),
                Array::from_slice::<f32>(&[0.0, 0.0], &(2usize,)).unwrap(),
            ),
            (
                "model.layers.0.q_proj.biases".to_string(),
                Array::from_slice::<f32>(&[0.0, 0.0], &(2usize,)).unwrap(),
            ),
        ]);
        let quant = PerLayerQuantization::from_global(Quantization {
            group_size: 32,
            bits: 4,
            mode,
        });

        let result = get_total_parameters(&w, &quant);
        let keyed = match &result {
            Err(Error::LayerKeyed(p)) => p,
            _ => panic!(
                "a `.biases` under scale-only `{}` must be an Error::LayerKeyed, got {result:?}",
                mode.as_str()
            ),
        };
        let names_biases_key =
            keyed.layer().contains("q_proj") && keyed.layer().ends_with(".biases");
        assert!(
            names_biases_key,
            "layer should name the offending `.biases` key, got `{}`",
            keyed.layer()
        );

        let collision = match keyed.inner() {
            Error::KeyCollision(kp) => kp,
            _ => panic!("expected inner Error::KeyCollision, got {:?}", keyed.inner()),
        };
        // The context is expected to mention all three mode names, whichever is under test.
        let lists_all_modes = ["mxfp4", "mxfp8", "nvfp4"]
            .iter()
            .all(|&mode_name| collision.context().contains(mode_name));
        assert!(
            lists_all_modes,
            "context should list the scale-only modes, got: {}",
            collision.context()
        );
        assert_eq!(collision.key(), keyed.layer());
    }
}
/// With affine quantization (group size 32, 4 bits), the same
/// weight/scales/biases triple must be accepted and `get_total_parameters`
/// must report 64.
#[test]
fn get_total_parameters_affine_biases_still_skipped() {
    let w: Weights = HashMap::from([
        (
            "model.layers.0.q_proj.weight".to_string(),
            Array::from_slice::<u32>(&[0u32; 8], &(2usize, 4)).unwrap(),
        ),
        (
            "model.layers.0.q_proj.scales".to_string(),
            Array::from_slice::<f32>(&[0.0, 0.0], &(2usize,)).unwrap(),
        ),
        (
            "model.layers.0.q_proj.biases".to_string(),
            Array::from_slice::<f32>(&[0.0, 0.0], &(2usize,)).unwrap(),
        ),
    ]);
    let affine = Quantization::affine(32, 4);
    let quant = PerLayerQuantization::from_global(affine);

    let total = get_total_parameters(&w, &quant).unwrap();
    assert_eq!(
        total, 64,
        "affine `.biases` skipped, only unpacked weight counts"
    );
}
/// Two back-to-back `save_model` calls into different directories must each
/// produce exactly one `model-gen-*` shard, and the two shard file names must
/// differ.
#[test]
fn gen_id_is_collision_resistant_across_same_ms_saves() {
    let dir_a = fresh_dir("gen-id-collision-a");
    let dir_b = fresh_dir("gen-id-collision-b");

    let w: Weights = std::iter::once(("w.weight".to_string(), f32_weight(2))).collect();
    save_model(&dir_a, &w, &PerLayerQuantization::default()).unwrap();
    save_model(&dir_b, &w, &PerLayerQuantization::default()).unwrap();

    // File names (not full paths) of the generation-tagged shards in a directory.
    let basenames = |root: &Path| -> Vec<String> {
        let shards = collect_sorted(root, |name| {
            name.starts_with("model-gen-") && name.ends_with(".safetensors")
        })
        .unwrap();
        let mut names = Vec::new();
        for shard in shards {
            names.push(shard.file_name().unwrap().to_string_lossy().into_owned());
        }
        names
    };
    let names_a = basenames(&dir_a);
    let names_b = basenames(&dir_b);
    assert_eq!(names_a.len(), 1);
    assert_eq!(names_b.len(), 1);
    assert_ne!(
        names_a[0], names_b[0],
        "two same-process saves must produce distinct gen_id-tagged basenames; \
         got {names_a:?} == {names_b:?}"
    );

    let _ = std::fs::remove_dir_all(&dir_a);
    let _ = std::fs::remove_dir_all(&dir_b);
}
/// Plants a file at the shard path a forced generation id will produce, then
/// checks that `save_model` returns `Error::ShardPathCollision` for that path,
/// leaves the planted bytes untouched, and leaves no `.tmp.safetensors` file.
#[test]
fn save_model_refuses_to_overwrite_existing_shard_basename() {
    let dir = fresh_dir("save-refuses-overwrite");
    let forced_gen_id = "9999999999999-cafebabe-0000000000000042";
    let collision_path = dir.join(shard_file_name(forced_gen_id, 1, 1));
    let decoy_bytes = b"pre-existing decoy bytes that must NOT be overwritten";
    std::fs::write(&collision_path, decoy_bytes).unwrap();
    force_next_gen_id(forced_gen_id);

    let w: Weights = std::iter::once(("w.weight".to_string(), f32_weight(2))).collect();
    let outcome = save_model(&dir, &w, &PerLayerQuantization::default());
    let Err(Error::ShardPathCollision(reported)) = outcome else {
        panic!("expected Err(ShardPathCollision), got {outcome:?}");
    };
    assert_eq!(
        reported, collision_path,
        "the collision error names the planted path"
    );

    assert!(
        collision_path.is_file(),
        "the planted decoy at {} must still be a file",
        collision_path.display()
    );
    let decoy_now = std::fs::read(&collision_path).unwrap();
    assert_eq!(
        decoy_now, decoy_bytes,
        "the planted decoy must be byte-identical (hard_link refused to replace)"
    );

    let leftover_tmp = std::fs::read_dir(&dir)
        .unwrap()
        .flatten()
        .map(|entry| entry.file_name())
        .any(|name| name.to_string_lossy().ends_with(".tmp.safetensors"));
    assert!(
        !leftover_tmp,
        "no staged tempfile may remain after a ShardPathCollision abort"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// Writes a file at the final shard path before saving with a forced
/// generation id. `save_model` must return `Error::ShardPathCollision`, keep
/// that file's bytes, leave no `.tmp.safetensors` file, and never create the
/// index.
#[test]
fn save_model_concurrent_create_at_final_path_returns_collision_error_not_silent_overwrite() {
    let dir = fresh_dir("save-toctou-no-silent-overwrite");
    let forced_gen_id = "7777777777777-feedface-00000000beefcafe";
    let final_shard = dir.join(shard_file_name(forced_gen_id, 1, 1));
    let racer_bytes = b"racer-bytes: a concurrent writer's payload that MUST survive";
    std::fs::write(&final_shard, racer_bytes).unwrap();
    force_next_gen_id(forced_gen_id);

    let w: Weights = std::iter::once(("z.weight".to_string(), f32_weight(3))).collect();
    let outcome = save_model(&dir, &w, &PerLayerQuantization::default());
    let Err(Error::ShardPathCollision(reported)) = outcome else {
        panic!("expected Err(ShardPathCollision) from atomic no-replace hard_link, got {outcome:?}")
    };
    assert_eq!(
        reported, final_shard,
        "collision error names the planted (racer) path"
    );

    assert!(
        final_shard.is_file(),
        "the racer file at {} must still be a regular file",
        final_shard.display()
    );
    let racer_now = std::fs::read(&final_shard).unwrap();
    assert_eq!(
        racer_now, racer_bytes,
        "racer bytes must be byte-identical — atomic no-replace forbids silent overwrite"
    );

    let leftover_tmp = std::fs::read_dir(&dir)
        .unwrap()
        .flatten()
        .map(|entry| entry.file_name())
        .any(|name| name.to_string_lossy().ends_with(".tmp.safetensors"));
    assert!(
        !leftover_tmp,
        "no staged .tmp.safetensors may remain after a ShardPathCollision"
    );

    // A failed shard publish must not be followed by an index write.
    let index_path = dir.join("model.safetensors.index.json");
    let index_exists = index_path.exists();
    assert!(
        !index_exists,
        "no index commit may occur when shard publish fails: {} exists",
        index_path.display()
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// With an `fsync_dir` fault armed, `save_model` must still return `Ok` with
/// `CommitOutcome::CommittedWithDurabilityWarning` carrying the injected
/// error, and the saved weights must load back with their original values.
#[test]
fn save_model_post_index_fsync_failure_keeps_visible_checkpoint() {
    let dir = fresh_dir("post-index-fsync-failure");
    let w: Weights = HashMap::from([
        (
            "v.alpha.weight".to_string(),
            Array::from_slice::<f32>(&[7.0, 8.0, 9.0], &(3usize,)).unwrap(),
        ),
        (
            "v.beta.weight".to_string(),
            Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap(),
        ),
    ]);

    // The guard is held only across the save call and released straight after.
    let fault_guard = arm_fsync_dir_fault(1);
    let outcome = save_model(&dir, &w, &PerLayerQuantization::default())
        .expect("post-index fsync failure must NOT propagate as Err — it is a durability warning");
    drop(fault_guard);

    let CommitOutcome::CommittedWithDurabilityWarning(underlying) = outcome else {
        panic!("expected CommittedWithDurabilityWarning, got Committed")
    };
    let underlying_msg = underlying.to_string();
    assert!(
        underlying_msg.contains("injected fsync_dir failure"),
        "the durability warning carries the underlying io::Error: got {underlying_msg}"
    );

    let mut loaded = load_weights(&dir).unwrap();
    assert_eq!(loaded.len(), 2);
    assert!(loaded.contains_key("v.alpha.weight"));
    assert!(loaded.contains_key("v.beta.weight"));
    let alpha = loaded
        .get_mut("v.alpha.weight")
        .unwrap()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(alpha, vec![7.0, 8.0, 9.0]);
    let beta = loaded
        .get_mut("v.beta.weight")
        .unwrap()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(beta, vec![1.0]);

    let _ = std::fs::remove_dir_all(&dir);
}
/// With an `fsync_dir` fault armed, a second `save` must return
/// `Error::DurabilityWarning` with `committed() == true`, while `config.json`
/// and the weights on disk already reflect the new save and no
/// `.tmp.safetensors` file is left behind.
#[test]
fn save_post_commit_durability_warning_still_commits_config() {
    let dir = fresh_dir("save-post-commit-warning-commits-config");
    let config_path = dir.join("config.json");

    // First save: establishes the OLD config and weights.
    let w0: Weights = std::iter::once(("w.weight".to_string(), f32_weight(2))).collect();
    let before_config = r#"{"model_type": "OLD", "hidden_size": 4}"#;
    save(&dir, &w0, before_config, &PerLayerQuantization::default()).unwrap();
    let old_cfg = std::fs::read_to_string(&config_path).unwrap();
    assert!(
        old_cfg.contains("\"OLD\""),
        "the OLD config.json was written"
    );

    // Second save runs with the fsync fault armed.
    let w1: Weights = std::iter::once((
        "w.weight".to_string(),
        Array::from_slice::<f32>(&[5.0, 6.0], &(2usize,)).unwrap(),
    ))
    .collect();
    let after_config = r#"{"model_type": "NEW", "hidden_size": 8}"#;
    let fault_guard = arm_fsync_dir_fault(1);
    let r = save(&dir, &w1, after_config, &PerLayerQuantization::default());
    drop(fault_guard);

    let Err(Error::DurabilityWarning(warning)) = r else {
        panic!("expected Err(DurabilityWarning), got {r:?}");
    };
    assert!(
        warning.committed(),
        "save's DurabilityWarning must carry committed=true"
    );
    let carries_injected_error = warning
        .source()
        .to_string()
        .contains("injected fsync_dir failure");
    assert!(
        carries_injected_error,
        "the underlying io::Error must be preserved: got {}",
        warning.source()
    );

    let new_cfg = std::fs::read_to_string(&config_path).unwrap();
    assert!(
        new_cfg.contains("\"NEW\""),
        "the NEW config.json must be on disk: got {new_cfg}"
    );
    assert!(
        !new_cfg.contains("\"OLD\""),
        "the OLD config.json content must be gone: got {new_cfg}"
    );

    // Rebuild the expected file text: keys sorted, pretty-printed with the same indent.
    let expected_cfg = {
        let parsed: serde_json::Value = serde_json::from_str(after_config).unwrap();
        let sorted: BTreeMap<String, serde_json::Value> = parsed
            .as_object()
            .unwrap()
            .iter()
            .map(|(key, value)| (key.clone(), value.clone()))
            .collect();
        let mut rendered = Vec::new();
        let formatter = serde_json::ser::PrettyFormatter::with_indent(b" ");
        let mut serializer = serde_json::Serializer::with_formatter(&mut rendered, formatter);
        serde::Serialize::serialize(&sorted, &mut serializer).unwrap();
        String::from_utf8(rendered).unwrap()
    };
    assert_eq!(
        new_cfg, expected_cfg,
        "the NEW config.json must be byte-equal to the staged (cleaned/sorted) form"
    );

    let mut loaded = load_weights(&dir).unwrap();
    let w_values = loaded.get_mut("w.weight").unwrap().to_vec::<f32>().unwrap();
    assert_eq!(w_values, vec![5.0, 6.0]);

    let leftover = std::fs::read_dir(&dir)
        .unwrap()
        .flatten()
        .map(|entry| entry.file_name())
        .any(|name| name.to_string_lossy().ends_with(".tmp.safetensors"));
    assert!(!leftover, "no staged tempfile may leak");

    let _ = std::fs::remove_dir_all(&dir);
}
/// `open_excl_temp_shard` must create a file in the same directory as the
/// final path, with a name ending in `.tmp.safetensors`, and bytes written
/// through the returned handle must be readable at the returned path.
#[test]
fn open_excl_temp_shard_returns_file_and_path() {
    use std::io::Write as _;

    let dir = fresh_dir("load1-open-excl-shape");
    let final_path = dir.join("model-00001-of-00001.safetensors");
    let (mut handle, tmp) = open_excl_temp_shard(&final_path).unwrap();

    assert_eq!(tmp.parent().unwrap(), final_path.parent().unwrap());
    let tmp_name = tmp.file_name().unwrap().to_string_lossy();
    assert!(
        tmp_name.ends_with(".tmp.safetensors"),
        "tempfile must keep the .tmp.safetensors suffix, got {}",
        tmp.display()
    );
    assert!(tmp.exists(), "open_excl_temp_shard must create the file");

    // Write through the handle, close it, then read back by path.
    let payload = b"LOAD-1: fd-bound shard tempfile";
    handle.write_all(payload).unwrap();
    drop(handle);
    let on_disk = std::fs::read(&tmp).unwrap();
    assert_eq!(
        on_disk, payload,
        "bytes written through the returned File must land at the returned path"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// Opens a staging tempfile, then replaces its path with a symlink to a decoy
/// file before calling `save_safetensors_to_file` on the already-open handle.
/// The decoy's bytes and inode must be unchanged afterwards, and the staging
/// path must still be the symlink.
///
/// Gated on `cfg(unix)` because it relies on `std::os::unix::fs` (inode
/// numbers and `symlink`), which does not exist on other targets.
#[cfg(unix)]
#[test]
fn save_safetensors_to_file_writes_via_fd_not_reopen_by_path() {
    use std::os::unix::fs::MetadataExt;

    let dir = fresh_dir("load1-safetensors-fd-not-reopen");
    let staging = dir.join("staging.tmp.safetensors");
    let decoy = dir.join("decoy.target");
    std::fs::write(&decoy, b"DECOY: must not be overwritten").unwrap();
    let decoy_inode_before = std::fs::metadata(&decoy).unwrap().ino();

    let (mut staging_file, staging_path) = open_excl_temp_shard(&staging).unwrap();
    let staging_inode = std::fs::metadata(&staging_path).unwrap().ino();
    assert_ne!(
        staging_inode, decoy_inode_before,
        "test sanity: staging tempfile + decoy must be distinct inodes"
    );

    // Swap the staging path for a symlink to the decoy while the handle stays open.
    std::fs::remove_file(&staging_path).unwrap();
    std::os::unix::fs::symlink(&decoy, &staging_path).unwrap();

    let arr = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3usize,)).unwrap();
    let meta: HashMap<String, String> =
        HashMap::from([("format".to_string(), "mlx".to_string())]);
    crate::io::save_safetensors_to_file(&mut staging_file, std::iter::once(("w", &arr)), &meta)
        .unwrap();
    drop(staging_file);

    let decoy_after = std::fs::read(&decoy).unwrap();
    assert_eq!(
        decoy_after, b"DECOY: must not be overwritten",
        "decoy must not be touched by the fd-bound writer"
    );
    let decoy_inode_after = std::fs::metadata(&decoy).unwrap().ino();
    assert_eq!(
        decoy_inode_after, decoy_inode_before,
        "decoy inode must not have changed"
    );

    // `symlink_metadata` does not follow the link, so this inspects the path itself.
    let lmeta = std::fs::symlink_metadata(&staging_path).unwrap();
    assert!(lmeta.file_type().is_symlink());

    let _ = std::fs::remove_dir_all(&dir);
}
/// Opens a staging tempfile, then replaces its path with a symlink to a decoy
/// JSON file before calling `write_json_pretty` with the already-open handle.
/// The decoy's bytes and inode must be unchanged afterwards.
///
/// Gated on `cfg(unix)` because it relies on `std::os::unix::fs` (inode
/// numbers and `symlink`), which does not exist on other targets.
#[cfg(unix)]
#[test]
fn write_json_pretty_writes_via_fd_not_reopen_by_path() {
    use std::os::unix::fs::MetadataExt;

    let dir = fresh_dir("load1-json-fd-not-reopen");
    let staging = dir.join("staging.tmp.safetensors");
    let decoy = dir.join("decoy.json");
    std::fs::write(&decoy, b"{\"decoy\": true}").unwrap();
    let decoy_inode_before = std::fs::metadata(&decoy).unwrap().ino();

    let (mut staging_file, staging_path) = open_excl_temp_shard(&staging).unwrap();
    // Swap the staging path for a symlink to the decoy while the handle stays open.
    std::fs::remove_file(&staging_path).unwrap();
    std::os::unix::fs::symlink(&decoy, &staging_path).unwrap();

    let value = serde_json::json!({
        "metadata": { "total_size": 0, "total_parameters": 0 },
        "weight_map": {},
    });
    write_json_pretty(
        &mut staging_file,
        &staging_path,
        &value,
        "LOAD-1: json fd-bound",
    )
    .unwrap();
    drop(staging_file);

    let decoy_after = std::fs::read(&decoy).unwrap();
    assert_eq!(
        decoy_after, b"{\"decoy\": true}",
        "decoy JSON must be untouched by the fd-bound writer"
    );
    let decoy_inode_after = std::fs::metadata(&decoy).unwrap().ino();
    assert_eq!(
        decoy_inode_after, decoy_inode_before,
        "decoy inode must not have changed"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// Two arrays written through an open `File` with `save_safetensors_to_file`
/// must load back by path with `load_safetensors`, with the same values.
#[test]
fn save_safetensors_to_file_round_trips_via_path_load() {
    let dir = fresh_dir("load1-fd-round-trip");
    let path = dir.join("via_fd.safetensors");

    let arr_a = Array::from_slice::<f32>(&[1.0_f32, 2.0, 3.0, 4.0], &(4usize,)).unwrap();
    let arr_b = Array::from_slice::<f32>(&[10.0_f32, 20.0], &(2usize,)).unwrap();
    let meta: HashMap<String, String> =
        HashMap::from([("format".to_string(), "mlx".to_string())]);

    // Write via the file handle, flush to disk, and close before reading back.
    let mut out = std::fs::File::create(&path).unwrap();
    crate::io::save_safetensors_to_file(&mut out, [("a", &arr_a), ("b", &arr_b)], &meta)
        .unwrap();
    out.sync_all().unwrap();
    drop(out);

    let mut loaded = crate::io::load_safetensors(&path).unwrap();
    assert_eq!(loaded.len(), 2);
    let a_read = loaded.get_mut("a").unwrap().to_vec::<f32>().unwrap();
    let b_read = loaded.get_mut("b").unwrap().to_vec::<f32>().unwrap();
    assert_eq!(a_read, vec![1.0, 2.0, 3.0, 4.0]);
    assert_eq!(b_read, vec![10.0, 20.0]);

    let _ = std::fs::remove_dir_all(&dir);
}
/// Prefills a file with 100 bytes and leaves the cursor at offset 50, then
/// writes one tensor with `save_safetensors_to_file`. The result must load as
/// exactly that tensor and have the same length as a control file written by
/// `save_safetensors_view`.
#[test]
fn save_safetensors_to_file_truncates_prefilled_file_at_nonzero_offset() {
    use std::io::{Seek, SeekFrom, Write as _};

    let dir = fresh_dir("load1-fd-prefilled-nonzero");
    let path = dir.join("prefilled_nonzero.safetensors");
    let control_path = dir.join("control.safetensors");

    let mut prefilled = std::fs::OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        .truncate(true)
        .open(&path)
        .unwrap();
    prefilled.write_all(&[0xAB_u8; 100]).unwrap();
    prefilled.seek(SeekFrom::Start(50)).unwrap();

    let arr = Array::from_slice::<f32>(&[1.0_f32, 2.0, 3.0, 4.0], &(4usize,)).unwrap();
    let meta: HashMap<String, String> =
        HashMap::from([("format".to_string(), "mlx".to_string())]);
    crate::io::save_safetensors_to_file(&mut prefilled, std::iter::once(("w", &arr)), &meta)
        .unwrap();
    prefilled.sync_all().unwrap();
    drop(prefilled);

    let mut loaded = crate::io::load_safetensors(&path).unwrap();
    assert_eq!(loaded.len(), 1, "expected exactly one tensor in the file");
    let w = loaded.get_mut("w").unwrap().to_vec::<f32>().unwrap();
    assert_eq!(w, vec![1.0, 2.0, 3.0, 4.0]);

    // Control: the same single tensor written by the path-based writer to a new file.
    crate::io::save_safetensors_view(&control_path, std::iter::once(("w", &arr)), &meta)
        .unwrap();
    let on_disk = std::fs::metadata(&path).unwrap().len();
    let control_size = std::fs::metadata(&control_path).unwrap().len();
    assert_eq!(
        on_disk, control_size,
        "fd-bound writer on a prefilled-at-offset-50 file must produce the same \
         byte count as the path-based writer on a fresh file (proves rewind+truncate \
         wiped the 100-byte prefill); fd={on_disk}, control={control_size}"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// Prefills a file with 10000 bytes, rewinds to offset 0, then writes one
/// small tensor with `save_safetensors_to_file`. The result must load as
/// exactly that tensor and have the same length as a control file written by
/// `save_safetensors_view`.
#[test]
fn save_safetensors_to_file_truncates_prefilled_file_longer_than_new_payload() {
    use std::io::{Seek, SeekFrom, Write as _};

    let dir = fresh_dir("load1-fd-prefilled-longer");
    let path = dir.join("prefilled_longer.safetensors");
    let control_path = dir.join("control.safetensors");

    let mut prefilled = std::fs::OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        .truncate(true)
        .open(&path)
        .unwrap();
    prefilled.write_all(&[0xCD_u8; 10000]).unwrap();
    prefilled.seek(SeekFrom::Start(0)).unwrap();

    let arr = Array::from_slice::<f32>(&[7.0_f32, 8.0, 9.0], &(3usize,)).unwrap();
    let meta: HashMap<String, String> =
        HashMap::from([("format".to_string(), "mlx".to_string())]);
    crate::io::save_safetensors_to_file(&mut prefilled, std::iter::once(("w", &arr)), &meta)
        .unwrap();
    prefilled.sync_all().unwrap();
    drop(prefilled);

    let mut loaded = crate::io::load_safetensors(&path).unwrap();
    assert_eq!(loaded.len(), 1);
    let w = loaded.get_mut("w").unwrap().to_vec::<f32>().unwrap();
    assert_eq!(w, vec![7.0, 8.0, 9.0]);

    // Control: the same single tensor written by the path-based writer to a new file.
    crate::io::save_safetensors_view(&control_path, std::iter::once(("w", &arr)), &meta)
        .unwrap();
    let on_disk = std::fs::metadata(&path).unwrap().len();
    let control_size = std::fs::metadata(&control_path).unwrap().len();
    assert_eq!(
        on_disk, control_size,
        "fd-bound writer on a 10000-byte-prefilled file must produce the same byte \
         count as the path-based writer on a fresh file (proves set_len(0) truncated \
         trailing prefill); fd={on_disk}, control={control_size}"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// Calling `save_safetensors_to_file` with a metadata key containing an
/// interior NUL must return an error whose message mentions NUL, and the
/// pre-existing file contents must be unchanged afterwards.
#[test]
fn save_safetensors_to_file_preserves_existing_file_on_interior_nul_metadata() {
    let dir = fresh_dir("load1-fd-r2-nul-meta");
    let path = dir.join("preexisting_meta.safetensors");
    let original_bytes: &[u8] = b"existing valid safetensors payload here";
    std::fs::write(&path, original_bytes).unwrap();
    let original_len = original_bytes.len() as u64;
    let len_before = std::fs::metadata(&path).unwrap().len();
    assert_eq!(
        len_before, original_len,
        "pre-call: file size must equal prefill length"
    );

    // Open without `truncate` so the prefilled bytes are still there at call time.
    let mut file = std::fs::OpenOptions::new()
        .read(true)
        .write(true)
        .open(&path)
        .unwrap();
    let array = Array::from_slice::<f32>(&[1.0_f32, 2.0], &(2usize,)).unwrap();
    let bad_metadata: HashMap<String, String> =
        HashMap::from([("key\0with-nul".to_string(), "value".to_string())]);
    let result = crate::io::save_safetensors_to_file(
        &mut file,
        std::iter::once(("name", &array)),
        &bad_metadata,
    );
    let Err(failure) = result else {
        panic!("expected Err from interior-NUL in metadata key, got Ok");
    };
    let err_msg = failure.to_string();
    let mentions_nul = err_msg.contains("NUL") || err_msg.contains("nul");
    assert!(
        mentions_nul,
        "expected an interior-NUL error message, got: {err_msg}"
    );
    drop(file);

    let bytes_after = std::fs::read(&path).unwrap();
    assert_eq!(
        bytes_after, original_bytes,
        "DEFENSE-IN-DEPTH REGRESSION: input-validation Err from build_string_map must \
         return before the destructive seek+set_len so a caller-owned prefilled file is \
         byte-identical to its pre-call state on this error path. NOT a contract — see \
         save_safetensors_to_file's Destructive mutation doc section."
    );
    assert_eq!(
        bytes_after.len() as u64,
        original_len,
        "post-call: file size must still equal prefill length"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// Calling `save_safetensors_to_file` with an array name containing an
/// interior NUL must return an error whose message mentions NUL, and the
/// pre-existing file contents must be unchanged afterwards.
#[test]
fn save_safetensors_to_file_preserves_existing_file_on_interior_nul_array_name() {
    let dir = fresh_dir("load1-fd-r2-nul-name");
    let path = dir.join("preexisting_name.safetensors");
    let original_bytes: &[u8] = b"another distinct prefilled payload, array-name path";
    std::fs::write(&path, original_bytes).unwrap();
    let original_len = original_bytes.len() as u64;
    let len_before = std::fs::metadata(&path).unwrap().len();
    assert_eq!(
        len_before, original_len,
        "pre-call: file size must equal prefill length"
    );

    // Open without `truncate` so the prefilled bytes are still there at call time.
    let mut file = std::fs::OpenOptions::new()
        .read(true)
        .write(true)
        .open(&path)
        .unwrap();
    let array = Array::from_slice::<f32>(&[3.0_f32, 4.0, 5.0], &(3usize,)).unwrap();
    let good_metadata: HashMap<String, String> = HashMap::new();
    let bad_name = "arr\0with-nul";
    let result = crate::io::save_safetensors_to_file(
        &mut file,
        std::iter::once((bad_name, &array)),
        &good_metadata,
    );
    let Err(failure) = result else {
        panic!("expected Err from interior-NUL in array name, got Ok");
    };
    let err_msg = failure.to_string();
    let mentions_nul = err_msg.contains("NUL") || err_msg.contains("nul");
    assert!(
        mentions_nul,
        "expected an interior-NUL error message, got: {err_msg}"
    );
    drop(file);

    let bytes_after = std::fs::read(&path).unwrap();
    assert_eq!(
        bytes_after, original_bytes,
        "DEFENSE-IN-DEPTH REGRESSION (array-name path): input-validation Err from \
         build_array_map must return before the destructive seek+set_len so a \
         caller-owned prefilled file is byte-identical to its pre-call state on this \
         error path. NOT a contract — see save_safetensors_to_file's Destructive \
         mutation doc section."
    );
    assert_eq!(
        bytes_after.len() as u64,
        original_len,
        "post-call: file size must still equal prefill length"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// `save_safetensors_to_file` called with an empty metadata map must succeed,
/// and the written file must load back as the single saved array.
#[test]
fn save_safetensors_to_file_empty_metadata_succeeds_with_null_check() {
    let dir = fresh_dir("load1-fd-r3-empty-meta-ok");
    let path = dir.join("empty_meta_ok.safetensors");
    // `create_new` requires that the file does not exist yet.
    let mut file = std::fs::OpenOptions::new()
        .read(true)
        .write(true)
        .create_new(true)
        .open(&path)
        .unwrap();

    let arr = Array::from_slice::<f32>(&[1.5_f32, 2.5, 3.5], &(3usize,)).unwrap();
    let empty_metadata: HashMap<String, String> = HashMap::new();
    let write_result = crate::io::save_safetensors_to_file(
        &mut file,
        std::iter::once(("w", &arr)),
        &empty_metadata,
    );
    write_result.expect(
        "DEFENSE-IN-DEPTH REGRESSION: empty-metadata save_safetensors_to_file must \
         succeed — the NULL-sentinel guard in build_string_map must not reject valid \
         handles. See save_safetensors_to_file's Destructive mutation doc section.",
    );
    file.sync_all().unwrap();
    drop(file);

    let mut loaded = crate::io::load_safetensors(&path).unwrap();
    assert_eq!(loaded.len(), 1, "round-trip must yield exactly one array");
    let w = loaded.get_mut("w").unwrap().to_vec::<f32>().unwrap();
    assert_eq!(
        w,
        vec![1.5, 2.5, 3.5],
        "round-trip values must match the pre-save array"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// Source-text check on `src/io.rs`: within a window starting at each of
/// `fn build_array_map` and `fn build_string_map`, the text must still contain
/// the matching `mlx_map_string_to_*_new` call, a `ctx.is_null()` check, and
/// one of the recognised ways of draining the last-error slot.
#[test]
fn build_map_helpers_carry_null_sentinel_check() {
    /// Returns up to 3000 bytes of `src` starting at `start`, shortening the
    /// end so that it falls on a UTF-8 character boundary.
    // NOTE(review): the window is a fixed size measured from the `fn` keyword,
    // so text after the end of the helper may also satisfy the `contains`
    // checks below — confirm io.rs is laid out so this cannot mask a regression.
    fn window_from(src: &str, start: usize) -> &str {
        let mut end = (start + 3000).min(src.len());
        while end > start && !src.is_char_boundary(end) {
            end -= 1;
        }
        &src[start..end]
    }

    let src = std::fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/src/io.rs"))
        .expect("must be able to read mlxrs/src/io.rs to verify NULL-sentinel guards");

    let array_fn = src
        .find("fn build_array_map")
        .expect("build_array_map must exist in io.rs");
    let array_window = window_from(&src, array_fn);
    assert!(
        array_window.contains("mlx_map_string_to_array_new"),
        "DEFENSE-IN-DEPTH STRUCTURAL: build_array_map must still call \
         mlx_map_string_to_array_new"
    );
    assert!(
        array_window.contains("ctx.is_null()"),
        "DEFENSE-IN-DEPTH STRUCTURAL REGRESSION: build_array_map must check \
         `guard.0.ctx.is_null()` immediately after `mlx_map_string_to_array_new()` to \
         surface allocation-failure sentinels; the check appears to have been removed."
    );

    let string_fn = src
        .find("fn build_string_map")
        .expect("build_string_map must exist in io.rs");
    let string_window = window_from(&src, string_fn);
    assert!(
        string_window.contains("mlx_map_string_to_string_new"),
        "DEFENSE-IN-DEPTH STRUCTURAL: build_string_map must still call \
         mlx_map_string_to_string_new"
    );
    assert!(
        string_window.contains("ctx.is_null()"),
        "DEFENSE-IN-DEPTH STRUCTURAL REGRESSION: build_string_map must check \
         `guard.0.ctx.is_null()` immediately after `mlx_map_string_to_string_new()` — \
         without this guard, an allocation failure on the empty-metadata save path \
         returns a NULL-ctx sentinel through `Ok(NULL)` to the caller. The check \
         appears to have been removed."
    );

    // Any one of these spellings counts as draining the last-error slot.
    let drains_last = |window: &str| {
        ["take_last()", "LAST.with", "crate::error::take_last"]
            .iter()
            .any(|&needle| window.contains(needle))
    };
    assert!(
        drains_last(array_window),
        "DEFENSE-IN-DEPTH STRUCTURAL: build_array_map's NULL branch must DRAIN \
         crate::error::LAST (via take_last() or LAST.with(..).take()), not peek \
         — leaving a stale Err in the TLS pollutes later mlx-c calls on this thread."
    );
    assert!(
        drains_last(string_window),
        "DEFENSE-IN-DEPTH STRUCTURAL: build_string_map's NULL branch must DRAIN \
         crate::error::LAST (via take_last() or LAST.with(..).take()), not peek \
         — leaving a stale Err in the TLS pollutes later mlx-c calls on this thread."
    );
}
/// Structural regression guard over the text of `save_safetensors_to_file` in
/// `src/io.rs` (from its `pub fn` up to the following `struct WriterState`).
///
/// Checks, in order:
/// 1. `mlx_io_writer_new(` appears before `seek(SeekFrom::Start(0))`;
/// 2. a `ctx.is_null()` check appears on the writer-new line or the 10 lines after it;
/// 3. a `take_last()` drain appears on the writer-new line or the 20 lines after it;
/// 4. that `ctx.is_null()` check appears before `set_len(0)`.
///
/// The line windows are taken as prefixes of the function text itself rather than
/// re-joined copies, so the null-check offset compared against `set_len(0)` is a
/// true byte offset into the function body regardless of the file's line endings.
#[test]
fn save_safetensors_to_file_writer_new_precedes_truncate() {
    // Prefix of `text` covering its first `count` lines, terminators included.
    // Any offset found inside the prefix is also a valid offset into `text`.
    fn leading_lines(text: &str, count: usize) -> &str {
        let end: usize = text.split_inclusive('\n').take(count).map(str::len).sum();
        &text[..end]
    }

    let src = std::fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/src/io.rs"))
        .expect("must be able to read mlxrs/src/io.rs to verify writer-precedes-truncate ordering");
    let fn_start = src
        .find("pub fn save_safetensors_to_file")
        .expect("save_safetensors_to_file must exist in io.rs");
    let fn_tail = src[fn_start..]
        .find("struct WriterState")
        .expect("WriterState declaration must follow save_safetensors_to_file in io.rs");
    let body = &src[fn_start..fn_start + fn_tail];

    let writer_new_off = body.find("mlx_io_writer_new(").expect(
        "DEFENSE-IN-DEPTH STRUCTURAL: save_safetensors_to_file must construct the \
         mlx_io_writer via `mlx_io_writer_new(...)`; the writer-new call appears to \
         have been removed or renamed.",
    );
    let seek_off = body.find("seek(SeekFrom::Start(0))").expect(
        "DEFENSE-IN-DEPTH STRUCTURAL: save_safetensors_to_file must rewind via \
         `seek(SeekFrom::Start(0))`; the rewind appears to have been removed or renamed.",
    );
    let set_len_off = body.find("set_len(0)").expect(
        "DEFENSE-IN-DEPTH STRUCTURAL: save_safetensors_to_file must truncate via \
         `set_len(0)`; the truncate appears to have been removed or renamed.",
    );

    // Invariant 1: the writer is constructed before the file is rewound.
    assert!(
        writer_new_off < seek_off,
        "DEFENSE-IN-DEPTH STRUCTURAL REGRESSION: `mlx_io_writer_new(...)` must appear \
         BEFORE `seek(SeekFrom::Start(0))` inside save_safetensors_to_file so an \
         allocation failure in the writer ctor surfaces as Err before the destructive \
         truncate. Current ordering has writer-new at byte {writer_new_off} and \
         seek at byte {seek_off}.",
    );

    let post_writer_new = &body[writer_new_off..];

    // Invariant 2: NULL-ctx check on the writer-new line or the 10 lines after it.
    // Searching for "ctx.is_null()" also matches the ".ctx.is_null()" spelling.
    let null_check_window = leading_lines(post_writer_new, 11);
    assert!(
        null_check_window.contains("ctx.is_null()"),
        "DEFENSE-IN-DEPTH STRUCTURAL REGRESSION: within 10 lines after \
         `mlx_io_writer_new(...)` there must be an explicit `.ctx.is_null()` check \
         that drains the NULL-ctx sentinel before any destructive file mutation. \
         The check appears to have been removed.",
    );

    // Invariant 3: the error slot is drained on the writer-new line or the 20 lines after it.
    let drain_window = leading_lines(post_writer_new, 21);
    assert!(
        drain_window.contains("take_last()") || drain_window.contains("crate::error::take_last"),
        "DEFENSE-IN-DEPTH STRUCTURAL REGRESSION: within 20 lines after \
         `mlx_io_writer_new(...)` there must be a `take_last()` (or \
         `crate::error::take_last`) DRAIN of the TLS error slot — peeking would \
         leave a stale Err and poison the next unrelated mlx-c call on this thread. \
         The drain appears to have been removed or replaced with a peek.",
    );

    // Invariant 4: the null check precedes the truncate. The window is a prefix of
    // `post_writer_new`, so adding `writer_new_off` yields an exact offset into `body`.
    let null_check_local_off = null_check_window
        .find("ctx.is_null()")
        .expect("invariant-2 guard above asserted this exists; cannot fail here");
    let null_check_abs_off = writer_new_off + null_check_local_off;
    assert!(
        null_check_abs_off < set_len_off,
        "DEFENSE-IN-DEPTH STRUCTURAL REGRESSION: `set_len(0)` must appear AFTER the \
         `ctx.is_null()` check so a NULL-ctx writer sentinel cannot bypass the guard \
         and trigger the destructive truncate. Current ordering has the null check at \
         byte {null_check_abs_off} and set_len at byte {set_len_off}.",
    );
}
/// Happy path: saving one array with empty metadata over a file prefilled with
/// 50 bytes must succeed, and loading the file back must yield exactly that array.
#[test]
fn save_safetensors_to_file_writer_construction_precedes_truncate() {
    let scratch = fresh_dir("load1-fd-r4-writer-precedes-truncate");
    let target = scratch.join("prefilled_50_bytes.safetensors");

    // Seed the destination with 50 filler bytes through a handle that is closed
    // before the save handle is opened.
    {
        let mut seed = std::fs::OpenOptions::new()
            .read(true)
            .write(true)
            .create_new(true)
            .open(&target)
            .unwrap();
        std::io::Write::write_all(&mut seed, &[0xA5_u8; 50]).unwrap();
        seed.sync_all().unwrap();
    }

    let mut handle = std::fs::OpenOptions::new()
        .read(true)
        .write(true)
        .open(&target)
        .unwrap();
    let tensor = Array::from_slice::<f32>(&[7.0_f32, 8.0, 9.0], &(3usize,)).unwrap();
    let no_metadata: HashMap<String, String> = HashMap::new();
    let saved =
        crate::io::save_safetensors_to_file(&mut handle, std::iter::once(("v", &tensor)), &no_metadata);
    saved.expect(
        "DEFENSE-IN-DEPTH REGRESSION: happy-path save_safetensors_to_file with \
         empty metadata must succeed (writer construction reached + write completed)",
    );
    handle.sync_all().unwrap();
    drop(handle);

    let mut roundtrip = crate::io::load_safetensors(&target).unwrap();
    assert_eq!(
        roundtrip.len(),
        1,
        "DEFENSE-IN-DEPTH REGRESSION: round-trip must yield exactly one array \
         (saved one)"
    );
    let values = roundtrip.get_mut("v").unwrap().to_vec::<f32>().unwrap();
    assert_eq!(
        values,
        vec![7.0, 8.0, 9.0],
        "DEFENSE-IN-DEPTH REGRESSION: round-trip values must match the pre-save \
         array — a mismatch would indicate the write did not run or wrote stale \
         prefix bytes"
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Destructive-contract check: saving a zero-element array is expected to fail,
/// and after that failure the prefilled 50-byte file must be shorter than before.
#[test]
fn save_safetensors_to_file_truncates_on_mlx_internal_error_zero_element_array() {
    let scratch = fresh_dir("load1-fd-destructive-zero-elem");
    let target = scratch.join("destructive_contract.safetensors");

    // Prefill and confirm the starting size.
    let prefill: &[u8] = &[0xC3_u8; 50];
    std::fs::write(&target, prefill).unwrap();
    let original_len = prefill.len() as u64;
    let size_before = std::fs::metadata(&target).unwrap().len();
    assert_eq!(
        size_before, original_len,
        "pre-call: file size must equal prefill length"
    );

    let mut handle = std::fs::OpenOptions::new()
        .read(true)
        .write(true)
        .open(&target)
        .unwrap();
    let empty_tensor = Array::from_slice::<f32>(&[], &(0usize,)).unwrap();
    let no_metadata: HashMap<String, String> = HashMap::new();
    let outcome = crate::io::save_safetensors_to_file(
        &mut handle,
        std::iter::once(("zero", &empty_tensor)),
        &no_metadata,
    );
    assert!(
        outcome.is_err(),
        "expected Err from save_safetensors_to_file on a zero-element array — mlx-c's \
         safetensors writer rejects this shape. If the writer started accepting \
         zero-element arrays, pick another MLX-internal-rejection trigger to keep \
         coverage of the destructive-contract path."
    );
    drop(handle);

    let post_len = std::fs::metadata(&target).unwrap().len();
    assert!(
        post_len < original_len,
        "DESTRUCTIVE CONTRACT: save_safetensors_to_file MUST destructively mutate the \
         file on an MLX-internal writer error (the destructive seek+set_len runs \
         before mlx_save_safetensors_writer is invoked). The file size went from \
         {original_len} bytes to {post_len} bytes — if this assertion fires because \
         the file is BYTE-IDENTICAL to the prefill, the function silently regained a \
         byte-preservation contract it explicitly disclaims. See \
         save_safetensors_to_file's Destructive mutation doc section."
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// `array_nbytes` must report element count × element size for the 2-byte
/// (`u16`, `i16`, `f16`, `bf16`) and 8-byte (`u64`, `i64`, `f64`) dtypes.
#[test]
fn array_nbytes_two_byte_and_eight_byte_dtype_classes() {
    // Two-byte element types.
    let arr_u16 = Array::from_slice::<u16>(&[0u16; 5], &(5usize,)).unwrap();
    assert_eq!(array_nbytes(&arr_u16).unwrap(), 10);

    let arr_i16 = Array::from_slice::<i16>(&[0i16; 3], &(3usize,)).unwrap();
    assert_eq!(array_nbytes(&arr_i16).unwrap(), 6);

    let arr_f16 = Array::from_slice::<half::f16>(&[half::f16::ZERO; 4], &(4usize,)).unwrap();
    assert_eq!(array_nbytes(&arr_f16).unwrap(), 8);

    let arr_bf16 = Array::from_slice::<half::bf16>(&[half::bf16::ZERO; 2], &(2usize,)).unwrap();
    assert_eq!(array_nbytes(&arr_bf16).unwrap(), 4);

    // Eight-byte element types.
    let arr_u64 = Array::from_slice::<u64>(&[0u64; 3], &(3usize,)).unwrap();
    assert_eq!(array_nbytes(&arr_u64).unwrap(), 24);

    let arr_i64 = Array::from_slice::<i64>(&[0i64; 2], &(2usize,)).unwrap();
    assert_eq!(array_nbytes(&arr_i64).unwrap(), 16);

    let arr_f64 = Array::from_slice::<f64>(&[0.0f64; 5], &(5usize,)).unwrap();
    assert_eq!(array_nbytes(&arr_f64).unwrap(), 40);
}
/// A config JSON carrying the required fields plus extra keys must parse, expose
/// the required values, and leave the optional fields asserted below unset.
#[test]
fn config_from_json_parses_required_subset_ignores_unknown_keys() {
    let document = r#"{
"model_type": "qwen3",
"hidden_size": 64,
"num_hidden_layers": 12,
"num_attention_heads": 8,
"num_key_value_heads": 2,
"head_dim": 8,
"rope_theta": 10000.0,
"vocab_size": 32000,
"tie_word_embeddings": true,
"max_position_embeddings": 4096,
"quantization_config": {"unmodeled": "ignored"}
}"#;
    let parsed = Config::from_json(document).unwrap();

    // Required fields round-trip.
    assert_eq!(parsed.model_type(), "qwen3");
    assert_eq!(parsed.hidden_size, 64);
    assert_eq!(parsed.num_hidden_layers, 12);
    assert_eq!(parsed.num_attention_heads, 8);
    assert_eq!(parsed.num_key_value_heads, 2);
    assert_eq!(parsed.head_dim, 8);
    let theta_delta = (parsed.rope_theta - 10000.0).abs();
    assert!(theta_delta < 1e-3);
    assert_eq!(parsed.vocab_size, 32000);
    assert!(parsed.tie_word_embeddings);

    // Optional fields not supplied by the document stay empty.
    assert!(parsed.sliding_window.is_none());
    assert!(parsed.quantization.is_none());
    assert!(parsed.eos_token_id.is_none());
}
/// `eos_token_id` may be a single integer or a list; both forms must parse into
/// the matching `EosTokenId` variant and flatten through `into_ids`.
#[test]
fn config_from_json_eos_token_id_scalar_and_list_forms() {
    // Shared required fields, spliced into each document below.
    let required = r#""model_type":"m","hidden_size":1,"num_hidden_layers":1,
"num_attention_heads":1,"num_key_value_heads":1,"head_dim":1,
"rope_theta":1.0,"vocab_size":2,"tie_word_embeddings":false"#;

    // Scalar form.
    let scalar_doc = format!("{{{},\"eos_token_id\":128001}}", required);
    let from_scalar = Config::from_json(&scalar_doc).unwrap();
    assert_eq!(from_scalar.eos_token_id, Some(EosTokenId::Single(128001)));
    assert!(from_scalar.eos_token_id.as_ref().unwrap().is_single());
    assert_eq!(from_scalar.eos_token_id.unwrap().into_ids(), vec![128001]);

    // List form.
    let list_doc = format!("{{{},\"eos_token_id\":[1,2,3]}}", required);
    let from_list = Config::from_json(&list_doc).unwrap();
    assert_eq!(from_list.eos_token_id, Some(EosTokenId::Many(vec![1, 2, 3])));
    assert!(from_list.eos_token_id.as_ref().unwrap().is_many());
    assert_eq!(from_list.eos_token_id.unwrap().into_ids(), vec![1, 2, 3]);
}
/// A config document without `vocab_size` must be rejected as `Error::Parse`
/// tagged with the `Config::from_json` context and the model-config input kind.
#[test]
fn config_from_json_missing_required_field_is_parse_error() {
    let document = r#"{
"model_type": "m", "hidden_size": 1, "num_hidden_layers": 1,
"num_attention_heads": 1, "num_key_value_heads": 1, "head_dim": 1,
"rope_theta": 1.0, "tie_word_embeddings": false
}"#;
    let outcome = Config::from_json(document);
    let is_expected_parse_error = matches!(
        &outcome,
        Err(Error::Parse(p))
            if p.context() == "Config::from_json" && p.input_kind() == "model config JSON"
    );
    assert!(
        is_expected_parse_error,
        "expected Error::Parse for a missing required field, got {outcome:?}"
    );
}
/// Text that is not JSON at all must surface as `Error::Parse` with the
/// `Config::from_json` context.
#[test]
fn config_from_json_malformed_is_parse_error() {
    let outcome = Config::from_json("this is not json {{{");
    let is_parse_error =
        matches!(&outcome, Err(Error::Parse(p)) if p.context() == "Config::from_json");
    assert!(
        is_parse_error,
        "expected Error::Parse for malformed JSON, got {outcome:?}"
    );
}
/// Opening a path that routes *through* a regular file must be reported as a
/// `FileIo` open error whose kind is not `NotFound`.
#[cfg(unix)]
#[test]
fn read_bounded_config_file_open_error_other_than_notfound() {
    let scratch = fresh_dir("read-bounded-enotdir");

    // A regular file used as if it were a directory component.
    let regular_file = scratch.join("not_a_dir");
    std::fs::write(&regular_file, b"x").unwrap();
    let routed_through_file = regular_file.join("config.json");

    let outcome = read_bounded_config_file(&routed_through_file, "model config");
    let Err(Error::FileIo(payload)) = outcome else {
        panic!("expected Error::FileIo for an ENOTDIR open, got {outcome:?}");
    };
    assert_eq!(payload.op(), FileOp::Open);
    assert_eq!(payload.context(), "model config");
    assert_ne!(
        payload.inner().kind(),
        std::io::ErrorKind::NotFound,
        "the ENOTDIR open error must NOT be classified as NotFound (that path returns Ok(None))"
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// A path that does not exist is not an error: the reader must return `Ok(None)`.
#[test]
fn read_bounded_config_file_absent_is_ok_none() {
    let scratch = fresh_dir("read-bounded-absent");
    let absent = scratch.join("does-not-exist.json");

    let contents = read_bounded_config_file(&absent, "model config").unwrap();
    assert!(contents.is_none(), "an absent file must yield Ok(None)");

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Pointing the reader at a directory must fail as a `FileIo` stat error with
/// kind `InvalidInput`.
#[test]
fn read_bounded_config_file_rejects_non_regular_directory() {
    let scratch = fresh_dir("read-bounded-isdir");

    let outcome = read_bounded_config_file(&scratch, "model config");
    let Err(Error::FileIo(payload)) = outcome else {
        panic!("expected Error::FileIo for a directory target, got {outcome:?}");
    };
    assert_eq!(payload.op(), FileOp::Stat);
    assert_eq!(payload.context(), "model config");
    assert_eq!(payload.inner().kind(), std::io::ErrorKind::InvalidInput);

    let _ = std::fs::remove_dir_all(&scratch);
}
/// A 4-byte file read with a 3-byte cap must fail with `CapExceeded` carrying
/// the context, the cap name, the cap, and the observed size.
#[test]
fn read_bounded_text_file_cap_exceeded() {
    let scratch = fresh_dir("read-bounded-cap");
    let oversized = scratch.join("big.json");
    std::fs::write(&oversized, b"abcd").unwrap();

    let outcome = read_bounded_text_file(&oversized, "model config", 3);
    let Err(Error::CapExceeded(payload)) = outcome else {
        panic!("expected Error::CapExceeded for an over-cap body, got {outcome:?}");
    };
    assert_eq!(payload.context(), "model config");
    assert_eq!(payload.cap_name(), "max_bytes");
    assert_eq!(payload.cap(), 3);
    assert_eq!(payload.observed(), 4);

    let _ = std::fs::remove_dir_all(&scratch);
}
/// A file whose size equals the cap is accepted and returned verbatim.
#[test]
fn read_bounded_text_file_exactly_at_cap_ok() {
    let scratch = fresh_dir("read-bounded-at-cap");
    let at_cap = scratch.join("ok.json");
    std::fs::write(&at_cap, b"abc").unwrap();

    let contents = read_bounded_text_file(&at_cap, "model config", 3).unwrap();
    assert_eq!(contents.as_deref(), Some("abc"));

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Bytes that are not valid UTF-8 must come back as `LayerKeyed` (layer naming
/// the file) wrapping an `Error::Parse` whose input kind is `UTF-8`.
#[test]
fn read_bounded_text_file_non_utf8_is_layer_keyed_parse() {
    let scratch = fresh_dir("read-bounded-non-utf8");
    let binary_file = scratch.join("bad.json");
    std::fs::write(&binary_file, [0xFF_u8, 0xFE, 0x00]).unwrap();

    let outcome = read_bounded_text_file(&binary_file, "model config", 1024);
    let Err(Error::LayerKeyed(keyed)) = outcome else {
        panic!("expected Error::LayerKeyed for non-UTF-8, got {outcome:?}");
    };

    // Outer layer names the offending file.
    assert!(
        keyed.layer().contains("bad.json"),
        "layer should name the path, got `{}`",
        keyed.layer()
    );

    // Inner error is the UTF-8 parse failure.
    let inner_is_utf8_parse = matches!(
        keyed.inner(),
        Error::Parse(pp) if pp.context() == "model config" && pp.input_kind() == "UTF-8"
    );
    assert!(
        inner_is_utf8_parse,
        "expected inner Error::Parse(UTF-8), got {:?}",
        keyed.inner()
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// With only `config.json` present, `load_config` returns the parsed config and
/// the raw text exactly as written to disk.
#[test]
fn load_config_returns_typed_and_raw_no_generation_override() {
    let scratch = fresh_dir("load-config-basic");
    let on_disk = r#"{"model_type":"qwen3","hidden_size":16,"num_hidden_layers":2,
"num_attention_heads":4,"num_key_value_heads":2,"head_dim":4,
"rope_theta":10000.0,"vocab_size":100,"tie_word_embeddings":false,
"eos_token_id":7}"#;
    std::fs::write(scratch.join("config.json"), on_disk).unwrap();

    let (typed, raw_text) = load_config(&scratch).unwrap();
    assert_eq!(typed.model_type(), "qwen3");
    assert_eq!(typed.eos_token_id, Some(EosTokenId::Single(7)));
    assert_eq!(raw_text, on_disk);

    let _ = std::fs::remove_dir_all(&scratch);
}
/// An empty model directory must make `load_config` fail with a `FileIo` open
/// error of kind `NotFound` that names `<dir>/config.json`.
#[test]
fn load_config_missing_config_json_errors() {
    let scratch = fresh_dir("load-config-missing");

    let outcome = load_config(&scratch);
    let Err(Error::FileIo(payload)) = outcome else {
        panic!("expected Error::FileIo for a missing config.json, got {outcome:?}");
    };
    assert_eq!(payload.op(), FileOp::Open);
    assert_eq!(payload.inner().kind(), std::io::ErrorKind::NotFound);
    let expected_path = scratch.join("config.json");
    assert_eq!(payload.path(), expected_path.as_path());

    let _ = std::fs::remove_dir_all(&scratch);
}
/// When `generation_config.json` supplies `eos_token_id`, the typed config must
/// carry that value while the returned raw text is still `config.json` verbatim.
#[test]
fn load_config_generation_config_eos_override_replaces_in_place() {
    let scratch = fresh_dir("load-config-gen-override");
    let on_disk = r#"{"model_type":"m","hidden_size":1,"num_hidden_layers":1,
"num_attention_heads":1,"num_key_value_heads":1,"head_dim":1,
"rope_theta":1.0,"vocab_size":2,"tie_word_embeddings":false,
"eos_token_id":1}"#;
    std::fs::write(scratch.join("config.json"), on_disk).unwrap();

    let generation_path = scratch.join("generation_config.json");
    std::fs::write(generation_path, r#"{"eos_token_id":[10,20]}"#).unwrap();

    let (typed, raw_text) = load_config(&scratch).unwrap();
    assert_eq!(
        typed.eos_token_id,
        Some(EosTokenId::Many(vec![10, 20])),
        "generation_config eos must REPLACE config.json's in place"
    );
    assert_eq!(raw_text, on_disk);

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Walks `read_generation_eos` through a series of `generation_config.json`
/// states: absent, scalar 0, empty list, non-zero scalar, non-empty list,
/// malformed text, and a document without the key.
#[test]
fn read_generation_eos_truthiness_matrix() {
    let scratch = fresh_dir("gen-eos-matrix");
    let generation_path = scratch.join("generation_config.json");

    // No file at all.
    let _ = std::fs::remove_file(&generation_path);
    assert_eq!(read_generation_eos(&scratch), None, "absent → None");

    // Values the reader treats as "no override".
    std::fs::write(&generation_path, r#"{"eos_token_id":0}"#).unwrap();
    assert_eq!(read_generation_eos(&scratch), None, "scalar 0 → None");

    std::fs::write(&generation_path, r#"{"eos_token_id":[]}"#).unwrap();
    assert_eq!(read_generation_eos(&scratch), None, "empty list → None");

    // Values that produce an override.
    std::fs::write(&generation_path, r#"{"eos_token_id":42}"#).unwrap();
    assert_eq!(read_generation_eos(&scratch), Some(EosTokenId::Single(42)));

    std::fs::write(&generation_path, r#"{"eos_token_id":[0,5]}"#).unwrap();
    let expected_list = Some(EosTokenId::Many(vec![0, 5]));
    assert_eq!(
        read_generation_eos(&scratch),
        expected_list,
        "a non-empty list is truthy and preserves contents (incl. 0)"
    );

    // Unusable documents.
    std::fs::write(&generation_path, b"not json {{{").unwrap();
    assert_eq!(read_generation_eos(&scratch), None, "malformed → None");

    std::fs::write(&generation_path, r#"{"something_else":1}"#).unwrap();
    assert_eq!(read_generation_eos(&scratch), None, "missing key → None");

    let _ = std::fs::remove_dir_all(&scratch);
}
/// An index whose `weight_map` maps a tensor to a number instead of a shard
/// file name must fail with `LayerKeyed` wrapping an `InvariantViolation`.
#[test]
fn load_weights_index_non_string_shard_value_errors() {
    let scratch = fresh_dir("load-index-non-string");
    let index_path = scratch.join("model.safetensors.index.json");
    let index_doc = serde_json::json!({
        "metadata": { "total_size": 0, "total_parameters": 0 },
        "weight_map": { "w.weight": 123 },
    });
    write_json_pretty_to_path(&index_path, &index_doc, "test: non-string shard value").unwrap();

    let outcome = load_weights(&scratch);
    let Err(Error::LayerKeyed(keyed)) = outcome else {
        panic!("expected Error::LayerKeyed for a non-string shard value, got {outcome:?}");
    };
    assert!(
        keyed.layer().contains("weight_map[w.weight]"),
        "layer should name the offending mapping, got `{}`",
        keyed.layer()
    );
    let inner_is_string_invariant = matches!(
        keyed.inner(),
        Error::InvariantViolation(iv)
            if iv.context().contains("weight_map shard value") && iv.requirement().contains("string")
    );
    assert!(
        inner_is_string_invariant,
        "expected inner InvariantViolation about string value, got {:?}",
        keyed.inner()
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// An index whose `weight_map` is a string rather than an object must fail with
/// `LayerKeyed` (naming the index file) wrapping `MissingKey("weight_map")`.
#[test]
fn load_weights_index_without_weight_map_object_errors() {
    let scratch = fresh_dir("load-index-no-weight-map");
    let index_path = scratch.join("model.safetensors.index.json");
    let index_doc = serde_json::json!({
        "metadata": { "total_size": 0, "total_parameters": 0 },
        "weight_map": "not-an-object",
    });
    write_json_pretty_to_path(&index_path, &index_doc, "test: weight_map not an object").unwrap();

    let outcome = load_weights(&scratch);
    let Err(Error::LayerKeyed(keyed)) = outcome else {
        panic!("expected Error::LayerKeyed for a missing weight_map object, got {outcome:?}");
    };
    assert!(
        keyed.layer().contains("model.safetensors.index.json"),
        "layer should name the index path, got `{}`",
        keyed.layer()
    );
    let inner_names_weight_map =
        matches!(keyed.inner(), Error::MissingKey(mk) if mk.key() == "weight_map");
    assert!(
        inner_names_weight_map,
        "expected inner MissingKey naming `weight_map`, got {:?}",
        keyed.inner()
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// An index with an empty `weight_map` object loads successfully and yields no
/// weights.
#[test]
fn load_weights_index_empty_weight_map_yields_empty_weights() {
    let scratch = fresh_dir("load-index-empty-map");
    let index_path = scratch.join("model.safetensors.index.json");
    let index_doc = serde_json::json!({
        "metadata": { "total_size": 0, "total_parameters": 0 },
        "weight_map": {},
    });
    write_json_pretty_to_path(&index_path, &index_doc, "test: empty weight_map").unwrap();

    let weights = load_weights(&scratch).unwrap();
    assert!(
        weights.is_empty(),
        "an empty index weight_map yields an empty (but successful) load"
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Built without the `gguf` feature, a directory holding only `model.gguf` must
/// fail with `LayerKeyed` (naming the file) wrapping an `InvariantViolation`.
#[cfg(not(feature = "gguf"))]
#[test]
fn load_weights_gguf_present_without_feature_is_unsupported() {
    let scratch = fresh_dir("load-gguf-unsupported");
    std::fs::write(scratch.join("model.gguf"), b"GGUF placeholder bytes").unwrap();

    let outcome = load_weights(&scratch);
    let Err(Error::LayerKeyed(keyed)) = outcome else {
        panic!("expected Error::LayerKeyed for an unsupported GGUF, got {outcome:?}");
    };
    assert!(
        keyed.layer().contains("model.gguf"),
        "layer should name the gguf file, got `{}`",
        keyed.layer()
    );
    let inner_is_feature_invariant = matches!(
        keyed.inner(),
        Error::InvariantViolation(iv)
            if iv.context().contains("GGUF") && iv.requirement().contains("enabled")
    );
    assert!(
        inner_is_feature_invariant,
        "expected inner InvariantViolation about the gguf feature, got {:?}",
        keyed.inner()
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Passing a regular file where a model directory is expected must produce a
/// typed error (`FileIo` or `LayerKeyed`).
#[cfg(unix)]
#[test]
fn load_weights_dir_is_regular_file_stat_error() {
    let scratch = fresh_dir("load-dir-is-file");
    let regular_file = scratch.join("modelfile");
    std::fs::write(&regular_file, b"x").unwrap();

    let outcome = load_weights(&regular_file);
    let is_typed_error = matches!(&outcome, Err(Error::FileIo(_) | Error::LayerKeyed(_)));
    assert!(
        is_typed_error,
        "a non-directory model path must be a typed FileIo/LayerKeyed error, got {outcome:?}"
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Listing a directory that does not exist must fail with a `FileIo` read error
/// that names the directory.
#[test]
fn collect_sorted_unreadable_directory_errors() {
    let scratch = fresh_dir("collect-sorted-missing");
    let absent_dir = scratch.join("no-such-subdir");

    let outcome = collect_sorted(&absent_dir, |name| name.ends_with(".safetensors"));
    let Err(Error::FileIo(payload)) = outcome else {
        panic!("expected Error::FileIo for an unreadable directory, got {outcome:?}");
    };
    assert_eq!(payload.op(), FileOp::Read);
    assert_eq!(payload.context(), "cannot read model directory");
    assert_eq!(payload.path(), absent_dir.as_path());

    let _ = std::fs::remove_dir_all(&scratch);
}
/// A subdirectory whose name matches the filter is skipped; only the regular
/// file with a matching name is collected.
#[test]
fn collect_sorted_skips_matching_subdirectory_keeps_file() {
    let scratch = fresh_dir("collect-sorted-skip-dir");
    std::fs::create_dir_all(scratch.join("subdir.safetensors")).unwrap();
    let regular = scratch.join("real.safetensors");
    std::fs::write(&regular, b"x").unwrap();

    let collected = collect_sorted(&scratch, |name| name.ends_with(".safetensors")).unwrap();
    assert_eq!(
        collected.len(),
        1,
        "only the regular file is collected, the dir is skipped"
    );
    assert_eq!(collected[0], scratch.join("real.safetensors"));

    let _ = std::fs::remove_dir_all(&scratch);
}
/// A quantized layer resolved with `bits == 0` must fail with `LayerKeyed`
/// (layer `model.q`) wrapping an `OutOfRange` error about `bits`.
#[test]
fn get_total_parameters_nonpositive_bits_errors() {
    let packed = Array::from_slice::<u32>(&[0u32; 8], &(2usize, 4)).unwrap();
    let scales = Array::from_slice::<f32>(&[0.0, 0.0], &(2usize,)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("model.q.weight".to_string(), packed);
    weights.insert("model.q.scales".to_string(), scales);

    let zero_bit_quant = PerLayerQuantization::from_global(Quantization::affine(64, 0));
    let outcome = get_total_parameters(&weights, &zero_bit_quant);
    let Err(Error::LayerKeyed(keyed)) = outcome else {
        panic!("expected Error::LayerKeyed for bits<=0, got {outcome:?}");
    };
    assert_eq!(keyed.layer(), "model.q");
    let inner_is_bits_range = matches!(
        keyed.inner(),
        Error::OutOfRange(or)
            if or.context().contains("bits") && or.requirement().contains("> 0")
    );
    assert!(
        inner_is_bits_range,
        "expected inner OutOfRange about bits>0, got {:?}",
        keyed.inner()
    );
}
/// A weight/scales/biases triple with the default per-layer quantization must
/// fail with `LayerKeyed` (layer `model.q`) wrapping an `InvariantViolation`
/// whose requirement mentions "resolvable".
#[test]
fn get_total_parameters_biases_branch_unresolvable_quant_errors() {
    let packed = Array::from_slice::<u32>(&[0u32; 8], &(2usize, 4)).unwrap();
    let scales = Array::from_slice::<f32>(&[0.0, 0.0], &(2usize,)).unwrap();
    let biases = Array::from_slice::<f32>(&[0.0, 0.0], &(2usize,)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("model.q.weight".to_string(), packed);
    weights.insert("model.q.scales".to_string(), scales);
    weights.insert("model.q.biases".to_string(), biases);

    let outcome = get_total_parameters(&weights, &PerLayerQuantization::default());
    let Err(Error::LayerKeyed(keyed)) = outcome else {
        panic!("expected Error::LayerKeyed for an unresolvable quantized triple, got {outcome:?}");
    };
    assert_eq!(keyed.layer(), "model.q");
    let inner_is_resolvable_invariant = matches!(
        keyed.inner(),
        Error::InvariantViolation(iv) if iv.requirement().contains("resolvable")
    );
    assert!(
        inner_is_resolvable_invariant,
        "expected inner InvariantViolation about resolvable quant params, got {:?}",
        keyed.inner()
    );
}
/// `save_model` into a path whose parent is a regular file must fail with a
/// `FileIo` create error carrying the `save_model` directory context.
#[cfg(unix)]
#[test]
fn save_model_create_dir_failure_on_nondir_parent() {
    let scratch = fresh_dir("save-model-create-dir-fail");

    // A regular file standing where a directory component is needed.
    let file_in_the_way = scratch.join("iam_a_file");
    std::fs::write(&file_in_the_way, b"x").unwrap();
    let destination = file_in_the_way.join("sub");

    let mut weights: Weights = HashMap::new();
    weights.insert("w.weight".to_string(), f32_weight(2));

    let outcome = save_model(&destination, &weights, &PerLayerQuantization::default());
    let Err(Error::FileIo(payload)) = outcome else {
        panic!("expected Error::FileIo(Create) for a non-dir parent, got {outcome:?}");
    };
    assert_eq!(payload.op(), FileOp::Create);
    assert_eq!(payload.context(), "save_model: cannot create directory");

    let _ = std::fs::remove_dir_all(&scratch);
}
/// The top-level `save` driver, aimed at a path whose parent is a regular file,
/// must fail with a `FileIo` create error carrying the `save` destination context.
#[cfg(unix)]
#[test]
fn save_driver_create_dir_failure_on_nondir_parent() {
    let scratch = fresh_dir("save-driver-create-dir-fail");

    // A regular file standing where a directory component is needed.
    let file_in_the_way = scratch.join("iam_a_file");
    std::fs::write(&file_in_the_way, b"x").unwrap();
    let destination = file_in_the_way.join("sub");

    let mut weights: Weights = HashMap::new();
    weights.insert("w.weight".to_string(), f32_weight(2));

    let config_json = r#"{"model_type":"m"}"#;
    let quant = PerLayerQuantization::default();
    let outcome = save(&destination, &weights, config_json, &quant);
    let Err(Error::FileIo(payload)) = outcome else {
        panic!("expected Error::FileIo(Create) for a non-dir destination parent, got {outcome:?}");
    };
    assert_eq!(payload.op(), FileOp::Create);
    assert_eq!(payload.context(), "save: cannot create destination directory");

    let _ = std::fs::remove_dir_all(&scratch);
}
/// The root path `/` has no file-name component, so opening a temp shard for it
/// must fail with a `FileIo` stat error of kind `InvalidInput`.
#[cfg(unix)]
#[test]
fn open_excl_temp_shard_no_file_name_component_errors() {
    let root = std::path::Path::new("/");

    let outcome = open_excl_temp_shard(root);
    let Err(Error::FileIo(payload)) = outcome else {
        panic!("expected Error::FileIo for a path with no file_name, got {outcome:?}");
    };
    assert_eq!(payload.op(), FileOp::Stat);
    assert_eq!(
        payload.context(),
        "save: destination has no file_name component"
    );
    assert_eq!(payload.inner().kind(), std::io::ErrorKind::InvalidInput);
}
/// With no fault armed, `fsync_path` on an existing regular file returns `Ok`.
#[test]
fn fsync_path_succeeds_on_regular_file() {
    let scratch = fresh_dir("fsync-path-ok");
    let target = scratch.join("f.bin");
    std::fs::write(&target, b"durable").unwrap();

    // The annotation pins the return type to the crate's `Result<()>`.
    let outcome: Result<()> = fsync_path(&target);
    outcome.expect("fsync_path must succeed on a regular file");

    let _ = std::fs::remove_dir_all(&scratch);
}
/// With the fault armed at `skip = 1`, the first `fsync_path` call passes and
/// the second fails with a `FileIo` fsync error carrying the injected message.
#[test]
fn fsync_path_injector_skip_then_fail() {
    let scratch = fresh_dir("fsync-path-inject");
    let target = scratch.join("f.bin");
    std::fs::write(&target, b"x").unwrap();

    // The guard is dropped explicitly once both calls have been made.
    let fault = arm_fsync_path_fault(1);
    fsync_path(&target).expect("first call must pass (injector skip=1)");
    let second = fsync_path(&target);
    drop(fault);

    let Err(Error::FileIo(payload)) = second else {
        panic!("expected Error::FileIo(Fsync) from the fired injector, got {second:?}");
    };
    assert_eq!(payload.op(), FileOp::Fsync);
    let injected_text = payload.inner().to_string();
    assert!(
        injected_text.contains("injected fsync_path failure"),
        "the wrapped io::Error must carry the injected message, got: {}",
        payload.inner()
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// A fault armed with `PermissionDenied` and `skip = 0` must surface from
/// `fsync_path_io` with that same `io::ErrorKind`.
#[test]
fn fsync_path_io_preserves_injected_kind() {
    let scratch = fresh_dir("fsync-path-io-kind");
    let target = scratch.join("f.bin");
    std::fs::write(&target, b"x").unwrap();

    let fault = arm_fsync_path_fault_with_kind(0, std::io::ErrorKind::PermissionDenied);
    let outcome: std::io::Result<()> = fsync_path_io(&target);
    drop(fault);

    let failure = outcome.expect_err("injector fires on the first call (skip=0)");
    assert_eq!(
        failure.kind(),
        std::io::ErrorKind::PermissionDenied,
        "fsync_path_io must preserve the injected kind without collapsing to Other"
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// With the remove-then-fail fault armed at `skip = 1`, the first
/// `fsync_path_io` call passes and the second fails with `NotFound`.
#[test]
fn fsync_path_io_remove_then_fail_real_os_error() {
    let scratch = fresh_dir("fsync-path-remove-fail");
    let target = scratch.join("f.bin");
    std::fs::write(&target, b"x").unwrap();

    let fault = arm_fsync_path_fault_remove_then_fail(1);
    fsync_path_io(&target).expect("first remove-then-fail call passes (skip=1)");
    let second = fsync_path_io(&target);
    drop(fault);

    let failure = second.expect_err("the second call removes the file then fails on open");
    assert_eq!(
        failure.kind(),
        std::io::ErrorKind::NotFound,
        "the OS-level open of a removed file is NotFound"
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// The fd-bound fsync succeeds with no fault armed, and with a fault armed at
/// `skip = 0` it fails with a `FileIo` fsync error carrying the injected message.
#[test]
fn fsync_open_file_for_path_injector_fires() {
    let scratch = fresh_dir("fsync-fd-inject");
    let target = scratch.join("f.bin");

    // Baseline: create the file and sync it with no fault armed.
    {
        let created = std::fs::File::create(&target).unwrap();
        fsync_open_file_for_path(&created, &target)
            .expect("fd-bound fsync must succeed on an open file");
    }

    // Re-open, arm the fault, and sync again.
    let reopened = std::fs::File::open(&target).unwrap();
    let fault = arm_fsync_path_fault(0);
    let outcome = fsync_open_file_for_path(&reopened, &target);
    drop(fault);

    let Err(Error::FileIo(payload)) = outcome else {
        panic!("expected Error::FileIo(Fsync) from the fd-bound injector, got {outcome:?}");
    };
    assert_eq!(payload.op(), FileOp::Fsync);
    let injected_text = payload.inner().to_string();
    assert!(
        injected_text.contains("injected fsync_path failure"),
        "fd-bound injector must carry the injected message, got: {}",
        payload.inner()
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// With the fault armed at `skip = 1`, the first fd-bound fsync passes and the
/// second fails with a `FileIo` fsync error.
#[test]
fn fsync_open_file_for_path_injector_skip_then_fail() {
    let scratch = fresh_dir("fsync-fd-skip-then-fail");
    let target = scratch.join("f.bin");
    let handle = std::fs::File::create(&target).unwrap();

    let fault = arm_fsync_path_fault(1);
    fsync_open_file_for_path(&handle, &target).expect("first fd-bound call passes (skip=1)");
    let second = fsync_open_file_for_path(&handle, &target);
    drop(fault);

    let fired_as_fsync = matches!(&second, Err(Error::FileIo(p)) if p.op() == FileOp::Fsync);
    assert!(
        fired_as_fsync,
        "second fd-bound call must fire the injector, got {second:?}"
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// With the remove-then-fail fault armed at `skip = 0`, the fd-bound fsync still
/// returns `Ok`, and afterwards the path no longer exists.
#[cfg(unix)]
#[test]
fn fsync_open_file_for_path_remove_then_fail_still_succeeds() {
    let scratch = fresh_dir("fsync-fd-remove-then-fail");
    let target = scratch.join("f.bin");
    let handle = std::fs::File::create(&target).unwrap();

    let fault = arm_fsync_path_fault_remove_then_fail(0);
    let outcome = fsync_open_file_for_path(&handle, &target);
    drop(fault);

    outcome.expect("fd-bound remove-then-fail still syncs the live inode (POSIX)");
    assert!(!target.exists(), "the injector unlinked the path");

    let _ = std::fs::remove_dir_all(&scratch);
}
/// With the directory-fsync fault armed, `save_config` must return a committed
/// `DurabilityWarning` that preserves the injected error, and the new
/// `config.json` must be on disk.
#[test]
fn save_config_post_rename_fsync_failure_is_durability_warning() {
    let scratch = fresh_dir("save-config-durability");
    let target = scratch.join("config.json");

    let fault = arm_fsync_dir_fault(0);
    let outcome = save_config(r#"{"model_type":"qwen3","hidden_size":8}"#, &target);
    drop(fault);

    let Err(Error::DurabilityWarning(warning)) = outcome else {
        panic!("expected Err(DurabilityWarning), got {outcome:?}");
    };
    assert!(
        warning.committed(),
        "config DurabilityWarning must be committed=true"
    );
    let source_text = warning.source().to_string();
    assert!(
        source_text.contains("injected fsync_dir failure"),
        "the underlying io::Error must be preserved, got: {}",
        warning.source()
    );

    // The write itself landed despite the warning.
    assert!(target.is_file(), "config.json must be visible on disk");
    let written = std::fs::read_to_string(&target).unwrap();
    assert!(written.contains("qwen3"), "the new config content is on disk");

    let _ = std::fs::remove_dir_all(&scratch);
}
/// With the directory-fsync fault armed at `skip = 0`, `save_model` must fail
/// with a `FileIo` fsync error and leave no index file behind.
#[test]
fn save_model_pre_index_fsync_failure_is_hard_error() {
    let scratch = fresh_dir("save-model-pre-index-fsync-fail");
    let mut weights: Weights = HashMap::new();
    weights.insert("w.weight".to_string(), f32_weight(2));

    let fault = arm_fsync_dir_fault(0);
    let outcome = save_model(&scratch, &weights, &PerLayerQuantization::default());
    drop(fault);

    let Err(Error::FileIo(payload)) = outcome else {
        panic!("a pre-index fsync_dir failure must be a hard Error::FileIo, got {outcome:?}");
    };
    assert_eq!(payload.op(), FileOp::Fsync);
    assert_eq!(payload.context(), "save_model: fsync parent directory");
    let injected_text = payload.inner().to_string();
    assert!(
        injected_text.contains("injected fsync_dir failure"),
        "the wrapped io::Error must carry the injected message, got: {}",
        payload.inner()
    );

    let index_exists = scratch.join("model.safetensors.index.json").is_file();
    assert!(
        !index_exists,
        "no index may be committed on a pre-index fsync failure"
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Writing JSON to a path whose parent is a regular file must fail with a
/// `FileIo` create error carrying the caller-supplied context.
#[cfg(unix)]
#[test]
fn write_json_pretty_to_path_create_failure() {
    let scratch = fresh_dir("write-json-create-fail");

    // A regular file standing where a directory component is needed.
    let file_in_the_way = scratch.join("iam_a_file");
    std::fs::write(&file_in_the_way, b"x").unwrap();
    let destination = file_in_the_way.join("index.json");

    let document = serde_json::json!({ "weight_map": {} });
    let outcome = write_json_pretty_to_path(&destination, &document, "test: create failure");
    let Err(Error::FileIo(payload)) = outcome else {
        panic!("expected Error::FileIo(Create) for an ENOTDIR create, got {outcome:?}");
    };
    assert_eq!(payload.op(), FileOp::Create);
    assert_eq!(payload.context(), "test: create failure");

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Writing JSON through a handle opened read-only must fail with a `FileIo`
/// write error carrying the supplied context and path.
#[cfg(unix)]
#[test]
fn write_json_pretty_write_failure_on_readonly_fd() {
    let scratch = fresh_dir("write-json-write-fail");
    let target = scratch.join("ro.json");
    std::fs::write(&target, b"placeholder").unwrap();

    // Opened for reading only, so the write below cannot succeed.
    let mut read_only = std::fs::OpenOptions::new().read(true).open(&target).unwrap();
    let document = serde_json::json!({ "k": "v" });
    let outcome = write_json_pretty(&mut read_only, &target, &document, "test: write failure");

    let Err(Error::FileIo(payload)) = outcome else {
        panic!("expected Error::FileIo(Write) on a read-only fd, got {outcome:?}");
    };
    assert_eq!(payload.op(), FileOp::Write);
    assert_eq!(payload.context(), "test: write failure");
    assert_eq!(payload.path(), target.as_path());

    let _ = std::fs::remove_dir_all(&scratch);
}
/// A `model.safetensors` that is a symlink to itself must make `load_weights`
/// fail with a `path_is_file` stat error that is not classified as `NotFound`.
#[cfg(unix)]
#[test]
fn path_is_file_self_symlink_loop_is_stat_error() {
    let scratch = fresh_dir("path-is-file-eloop");
    assert!(!scratch.join("model.safetensors.index.json").exists());

    // Plant a symlink whose target is its own path.
    let self_link = scratch.join("model.safetensors");
    std::os::unix::fs::symlink(&self_link, &self_link).unwrap();
    let planted_type = std::fs::symlink_metadata(&self_link).unwrap().file_type();
    assert!(
        planted_type.is_symlink(),
        "the planted entry must be a symlink (the loop is in target resolution)"
    );

    let outcome = load_weights(&scratch);
    let Err(Error::FileIo(payload)) = outcome else {
        panic!("a self-symlink-loop model.safetensors must be an Error::FileIo, got {outcome:?}");
    };
    assert_eq!(payload.context(), "path_is_file");
    assert_eq!(payload.op(), FileOp::Stat);
    assert_eq!(payload.path(), self_link.as_path());
    assert_ne!(
        payload.inner().kind(),
        std::io::ErrorKind::NotFound,
        "a stat ELOOP must NOT be classified as NotFound (that path returns Ok(false))"
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// `collect_sorted` must fail when an entry matching the filter is a dangling symlink.
///
/// The expected error is `Error::FileIo` with context
/// `"collect_sorted: cannot stat entry"`, `FileOp::Stat`, and the symlink's path.
#[cfg(unix)]
#[test]
fn collect_sorted_dangling_symlink_entry_stat_error() {
    let scratch = fresh_dir("collect-sorted-dangling-symlink");

    // Plant a `.safetensors` symlink whose target is never created.
    let link = scratch.join("dangling.safetensors");
    let absent = scratch.join("nonexistent-target-file");
    std::os::unix::fs::symlink(&absent, &link).unwrap();

    // Sanity-check the fixture: the link itself exists, but following it fails.
    let link_meta = std::fs::symlink_metadata(&link).unwrap();
    assert!(
        link_meta.file_type().is_symlink(),
        "the planted entry must be a (dangling) symlink"
    );
    let followed = std::fs::metadata(&link);
    assert!(
        followed.is_err(),
        "stat-through the dangling symlink must fail (target is missing)"
    );

    let r = collect_sorted(&scratch, |name| name.ends_with(".safetensors"));
    let Err(Error::FileIo(p)) = r else {
        panic!("a dangling matched symlink must be an Error::FileIo, got {r:?}");
    };
    assert_eq!(p.context(), "collect_sorted: cannot stat entry");
    assert_eq!(p.op(), FileOp::Stat);
    assert_eq!(
        p.path(),
        link.as_path(),
        "the error must name the offending entry path"
    );

    // Best-effort cleanup; a failure to remove the scratch dir is not a test failure.
    let _ = std::fs::remove_dir_all(&scratch);
}
/// `save_model` must leave nothing behind when staging fails on an injected fsync fault.
///
/// With `arm_fsync_path_fault(1)` armed, saving a single-tensor model must return
/// `Error::FileIo` with `FileOp::Fsync` carrying the injected message. Afterwards the
/// output directory must contain no `*.tmp.safetensors` file, no
/// `model-gen-*.safetensors` shard, and no `model.safetensors.index.json`.
#[test]
fn save_model_staged_index_fsync_failure_cleans_shard_and_index_tempfiles() {
    let dir = fresh_dir("save-staged-index-fsync-cleanup");
    let mut w: Weights = HashMap::new();
    w.insert(
        "only.weight".to_string(),
        Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3usize,)).unwrap(),
    );

    // Keep the fault armed only for the duration of the save call.
    // NOTE(review): the argument `1` presumably selects which fsync_path call
    // fails — confirm against `arm_fsync_path_fault`.
    let fault_guard = arm_fsync_path_fault(1);
    let r = save_model(&dir, &w, &PerLayerQuantization::default());
    drop(fault_guard);

    let Err(Error::FileIo(p)) = r else {
        panic!("the staged index-fsync failure must propagate as Error::FileIo, got {r:?}");
    };
    assert_eq!(p.op(), FileOp::Fsync);
    let inner_text = p.inner().to_string();
    assert!(
        inner_text.contains("injected fsync_path failure"),
        "the propagated error must carry the injected fsync message, got: {}",
        p.inner()
    );

    // Snapshot the directory once. An unreadable entry panics here instead of being
    // skipped, so the "nothing left behind" checks below cannot pass on a partial listing.
    let names: Vec<String> = std::fs::read_dir(&dir)
        .unwrap()
        .map(|entry| entry.unwrap().file_name().to_string_lossy().into_owned())
        .collect();

    // NOTE(review): only `*.tmp.safetensors` names are matched; if the staged index
    // tempfile uses a different suffix, this check does not cover it — confirm.
    let leftover_tmp = names.iter().any(|n| n.ends_with(".tmp.safetensors"));
    assert!(
        !leftover_tmp,
        "every staged tempfile (shard + index) must be removed on a staging failure"
    );

    let published_shard = names
        .iter()
        .any(|n| n.starts_with("model-gen-") && n.ends_with(".safetensors"));
    assert!(
        !published_shard,
        "no final shard may be published when staging fails"
    );

    assert!(
        !dir.join("model.safetensors.index.json").exists(),
        "no index may be committed when staging fails"
    );

    // Best-effort cleanup; a failure to remove the scratch dir is not a test failure.
    let _ = std::fs::remove_dir_all(&dir);
}