use std::{
collections::HashMap,
path::{Path, PathBuf},
};
use crate::{
array::Array,
dtype::Dtype,
error::{
ConvertPostSavePartialPayload, DurabilityWarningPayload, EmptyInputPayload, Error,
FileIoPayload, FileOp, InvariantViolationPayload, LayerKeyedPayload, ParsePayload, Result,
UnsupportedDtypePayload,
},
lm::{
load::{self, Weights},
quant::{self, PerLayerQuantization, QuantMode, Quantization, QuantizationOption},
},
};
/// Inputs to [`convert`].
///
/// The `Default` impl yields empty paths, `QuantMode::Affine`, and every
/// option unset / every flag off.
pub struct ConvertArgs {
/// Source checkpoint directory: config, weights and tokenizer are loaded from here.
pub hf_path: PathBuf,
/// Destination path; `convert` returns an error if it already exists.
pub mlx_path: PathBuf,
/// Quantize the weights. `convert` rejects this together with `dequantize`.
pub quantize: bool,
/// Quantization group size; `None` or a non-positive value selects the per-mode default.
pub q_group_size: Option<i32>,
/// Quantization bit width; `None` or a non-positive value selects the per-mode default.
pub q_bits: Option<i32>,
/// Quantization mode used for the global setting when `quantize` is set.
pub q_mode: QuantMode,
/// Explicit target float dtype (F16 / BF16 / F32 only). When `None`, `convert`
/// looks at `torch_dtype` and then `text_config.dtype` in the source config.
pub dtype: Option<Dtype>,
/// Must be `None`; `convert` returns an invariant-violation error otherwise.
pub upload_repo: Option<String>,
/// Must be `None`; `convert` returns an invariant-violation error otherwise.
pub revision: Option<String>,
/// Dequantize the weights. `convert` rejects this together with `quantize`.
pub dequantize: bool,
/// Optional per-layer decision hook consulted when `quantize` is set.
pub quant_predicate: Option<Box<dyn MixedQuantPredicate>>,
/// Accepted but not read by `convert` (it is destructured and discarded).
pub trust_remote_code: bool,
}
impl Default for ConvertArgs {
fn default() -> Self {
Self {
hf_path: PathBuf::new(),
mlx_path: PathBuf::new(),
quantize: false,
q_group_size: None,
q_bits: None,
q_mode: QuantMode::Affine,
dtype: None,
upload_repo: None,
revision: None,
dequantize: false,
quant_predicate: None,
trust_remote_code: false,
}
}
}
/// Per-layer quantization decision hook.
///
/// `convert` calls [`MixedQuantPredicate::decide`] with the layer path (the
/// weight key minus its `.weight` suffix) and the weight array. `Some(q)` means
/// "quantize this layer with `q`"; `None` means the layer is left unquantized.
pub trait MixedQuantPredicate {
/// Returns the quantization parameters for `layer_name`, or `None` to skip it.
fn decide(&self, layer_name: &str, weight: &Array) -> Option<Quantization>;
}
/// Named mixed-precision recipes. Each variant pairs a low bit width with a
/// high bit width, encoded in its name as `Mixed<low>_<high>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MixedQuantRecipe {
    /// 2-bit low / 6-bit high.
    Mixed2_6,
    /// 3-bit low / 4-bit high.
    Mixed3_4,
    /// 3-bit low / 6-bit high.
    Mixed3_6,
    /// 4-bit low / 6-bit high.
    Mixed4_6,
}

impl MixedQuantRecipe {
    /// Returns `(low_bits, high_bits)` for this recipe.
    fn bits(self) -> (i32, i32) {
        let low = match self {
            Self::Mixed2_6 => 2,
            Self::Mixed3_4 | Self::Mixed3_6 => 3,
            Self::Mixed4_6 => 4,
        };
        let high = match self {
            Self::Mixed3_4 => 4,
            Self::Mixed2_6 | Self::Mixed3_6 | Self::Mixed4_6 => 6,
        };
        (low, high)
    }
}
/// Predicate produced by [`mixed_quant_predicate`]: assigns `high_bits` to
/// selected layers and `low_bits` to everything else.
#[derive(Debug)]
pub struct DefaultMixedQuantPredicate {
/// Bit width used for layers that are not selected for extra precision.
low_bits: i32,
/// Bit width used for the selected layers.
high_bits: i32,
/// Group size reported in every decision.
group_size: i32,
/// Zero-based position of the layer-index segment within a dot-separated layer path.
layer_location: usize,
/// Layer count used to compute the first-eighth / last-eighth boundaries.
num_layers: i32,
}
impl MixedQuantPredicate for DefaultMixedQuantPredicate {
    /// Picks the bit width for `layer_name`; always returns `Some` with
    /// `QuantMode::Affine` and the configured group size.
    ///
    /// `high_bits` is used for any path containing `lm_head`, and for
    /// `v_proj` / `v_a_proj` / `v_b_proj` / `down_proj` paths whose layer index
    /// is in the first eighth, the last eighth, or every third layer in between.
    /// Everything else gets `low_bits`.
    fn decide(&self, layer_name: &str, _weight: &Array) -> Option<Quantization> {
        // Layer index is the dotted-path segment at `layer_location`; a missing
        // or non-numeric segment counts as layer 0.
        let layer_idx = match layer_name.split('.').nth(self.layer_location) {
            Some(segment) => segment.parse::<i32>().unwrap_or(0),
            None => 0,
        };

        let eighth = self.num_layers / 8;
        let use_more_bits = layer_idx < eighth
            || layer_idx >= 7 * self.num_layers / 8
            || (layer_idx - eighth).rem_euclid(3) == 2;

        let mentions = |needle: &str| layer_name.contains(needle);
        let sensitive_projection = mentions("v_proj")
            || mentions("v_a_proj")
            || mentions("v_b_proj")
            || mentions("down_proj");
        let wants_high = (use_more_bits && sensitive_projection) || mentions("lm_head");

        let bits = if wants_high {
            self.high_bits
        } else {
            self.low_bits
        };
        Some(Quantization {
            group_size: self.group_size,
            bits,
            mode: QuantMode::Affine,
        })
    }
}
pub fn mixed_quant_predicate(
recipe: MixedQuantRecipe,
weights: &Weights,
group_size: i32,
) -> Result<DefaultMixedQuantPredicate> {
let (low_bits, high_bits) = recipe.bits();
let mut down_keys: Vec<&str> = weights
.keys()
.filter_map(|k| {
if k.contains("down_proj") {
Some(k.strip_suffix(".weight").unwrap_or(k.as_str()))
} else {
None
}
})
.collect();
if down_keys.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"mixed_quant_predicate: model down_proj keys",
)));
}
down_keys.sort();
let first = down_keys[0];
let layer_location: usize = first
.split('.')
.position(|s| !s.is_empty() && s.chars().all(|c| c.is_ascii_digit()))
.ok_or_else(|| {
Error::LayerKeyed(LayerKeyedPayload::new(
first.to_string(),
Error::InvariantViolation(InvariantViolationPayload::new(
"mixed_quant_predicate: `down_proj` path",
"must contain a numeric layer-index segment (mirroring \
convert.py:43-45's `if k.isdigit(): break`)",
)),
))
})?;
let mut max_idx: i32 = -1;
for key in &down_keys {
if let Some(seg) = key.split('.').nth(layer_location)
&& let Ok(idx) = seg.parse::<i32>()
&& idx > max_idx
{
max_idx = idx;
}
}
let num_layers = if max_idx >= 0 { max_idx + 1 } else { 1 };
Ok(DefaultMixedQuantPredicate {
low_bits,
high_bits,
group_size,
layer_location,
num_layers,
})
}
pub fn convert(args: ConvertArgs) -> Result<()> {
let ConvertArgs {
hf_path,
mlx_path,
quantize,
q_group_size,
q_bits,
q_mode,
dtype,
upload_repo,
revision,
dequantize,
quant_predicate,
trust_remote_code: _, } = args;
if mlx_path.exists() {
return Err(Error::FileIo(FileIoPayload::new(
"convert: destination must not already exist (delete the file/directory or specify \
a new path to save to)",
FileOp::Stat,
mlx_path,
std::io::Error::from(std::io::ErrorKind::AlreadyExists),
)));
}
if upload_repo.is_some() {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"convert: `upload_repo`",
"must be None in mlxrs (HuggingFace Hub upload is out of scope — mlxrs is \
local-path-only; drop the kwarg or upload the result directory yourself)",
)));
}
if revision.is_some() {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"convert: `revision`",
"must be None in mlxrs (HuggingFace Hub download is out of scope — mlxrs is \
local-path-only; download the checkpoint yourself and pass its local path as `hf_path`)",
)));
}
if quantize && dequantize {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"convert: `quantize` && `dequantize` flags",
"must not both be true — choose either quantize or dequantize, not both \
(convert.py:146-147)",
)));
}
let (cfg_typed, config_json_text) = load::load_config(&hf_path)?;
let weights = load::load_weights(&hf_path)?;
let _tokenizer = load::load_tokenizer(&hf_path, &cfg_typed)?;
let resolved_dtype = resolve_target_dtype(dtype, &config_json_text)?;
let weights = if let Some(d) = resolved_dtype {
cast_floats_to_dtype(weights, d)?
} else {
weights
};
let (out_weights, out_config_json, per_layer_cfg) = if quantize {
let (gs, bits) = defaults_for_mode(q_mode, q_group_size, q_bits);
let decisions = build_predicate_decisions(quant_predicate.as_deref(), &weights, gs);
let (cfg, cfg_json) =
build_quantize_config(&config_json_text, gs, bits, q_mode, &decisions, &weights)?;
let eligible = |path: &str, _weight: &Array| -> bool {
match (quant_predicate.is_some(), decisions.get(path)) {
(true, Some(Some(_))) => true,
(true, Some(None)) => false,
(true, None) => false,
(false, _) => true,
}
};
let w = quant::quantize_weights(weights, &cfg, &eligible)?;
(w, cfg_json, cfg)
} else if dequantize {
let cfg = quant::parse_quantization(&config_json_text)?.unwrap_or_default();
let stripped = strip_quantization_blocks(&config_json_text)?;
let w = quant::dequantize_weights(weights, &cfg)?;
(w, stripped, PerLayerQuantization::default())
} else {
let cfg = quant::parse_quantization(&config_json_text)?.unwrap_or_default();
(weights, config_json_text, cfg)
};
let committed_warning: Option<std::io::Error> =
match load::save(&mlx_path, &out_weights, &out_config_json, &per_layer_cfg) {
Ok(()) => None,
Err(Error::DurabilityWarning(p)) if p.committed() => Some(p.into_source()),
Err(other) => return Err(other),
};
let copy_result = copy_tokenizer_and_extras(&hf_path, &mlx_path);
match (committed_warning, copy_result) {
(save, Ok(copy_outcome)) => {
let copy_warnings = match copy_outcome {
CopyOutcome::Committed => CopyDurabilityWarnings::default(),
CopyOutcome::CommittedWithDurabilityWarning(w) => w,
};
let aggregate = crate::error::ConvertDurabilityWarnings {
committed: true,
save,
post_copy_file: copy_warnings.post_copy_file,
post_copy_dir: copy_warnings.post_copy_dir,
};
match aggregate.count() {
0 => Ok(()),
1 => {
let (_, save, post_copy_file, post_copy_dir) = aggregate.into_parts();
let source = save
.or(post_copy_file)
.or(post_copy_dir)
.expect("count() == 1 guarantees exactly one Some field");
Err(Error::DurabilityWarning(DurabilityWarningPayload::new(
true, source,
)))
}
_ => Err(Error::ConvertDurabilityWarnings(aggregate)),
}
}
(None, Err(copy_err)) => Err(Error::ConvertPostSavePartial(
ConvertPostSavePartialPayload::new(true, None, copy_err),
)),
(Some(save_source), Err(copy_err)) => Err(Error::ConvertPostSavePartial(
ConvertPostSavePartialPayload::new(true, Some(save_source), copy_err),
)),
}
}
/// Resolves `(group_size, bits)` for `mode`.
///
/// A requested value is honoured only when it is strictly positive; `None`,
/// zero and negative values all fall back to the mode's default.
fn defaults_for_mode(mode: QuantMode, gs: Option<i32>, bits: Option<i32>) -> (i32, i32) {
    let (fallback_gs, fallback_bits) = match mode {
        QuantMode::Affine => (64, 4),
        QuantMode::Mxfp4 => (32, 4),
        QuantMode::Nvfp4 => (16, 4),
        QuantMode::Mxfp8 => (32, 8),
    };
    let positive_or = |requested: Option<i32>, fallback: i32| match requested {
        Some(value) if value > 0 => value,
        _ => fallback,
    };
    (
        positive_or(gs, fallback_gs),
        positive_or(bits, fallback_bits),
    )
}
fn resolve_target_dtype(explicit: Option<Dtype>, config_json: &str) -> Result<Option<Dtype>> {
if let Some(d) = explicit {
return match d {
Dtype::F16 | Dtype::BF16 | Dtype::F32 => Ok(Some(d)),
other => Err(Error::UnsupportedDtype(UnsupportedDtypePayload::new(
"convert: `dtype` (matches mlx_lm/convert.py:82 supported set MODEL_CONVERSION_DTYPES)",
other,
&[Dtype::F16, Dtype::BF16, Dtype::F32],
))),
};
}
let parsed: serde_json::Value = match serde_json::from_str(config_json) {
Ok(v) => v,
Err(_) => return Ok(None), };
if let Some(s) = parsed.get("torch_dtype").and_then(|v| v.as_str())
&& let Some(d) = parse_conversion_dtype(s)
{
return Ok(Some(d));
}
if let Some(text_cfg) = parsed.get("text_config")
&& let Some(s) = text_cfg.get("dtype").and_then(|v| v.as_str())
&& let Some(d) = parse_conversion_dtype(s)
{
return Ok(Some(d));
}
Ok(None)
}
/// Maps a dtype name from a config file to the matching conversion dtype.
///
/// Only the exact strings `float16`, `bfloat16` and `float32` are recognised;
/// every other input returns `None`.
fn parse_conversion_dtype(s: &str) -> Option<Dtype> {
    const NAMED: [(&str, Dtype); 3] = [
        ("float16", Dtype::F16),
        ("bfloat16", Dtype::BF16),
        ("float32", Dtype::F32),
    ];
    NAMED
        .iter()
        .find(|(name, _)| *name == s)
        .map(|&(_, dtype)| dtype)
}
fn cast_floats_to_dtype(weights: Weights, target: Dtype) -> Result<Weights> {
let mut out: Weights = HashMap::with_capacity(weights.len());
for (k, arr) in weights {
let dt = arr.dtype()?;
let is_floating = matches!(dt, Dtype::F16 | Dtype::F32 | Dtype::BF16);
if is_floating && dt != target {
out.insert(k, arr.astype(target)?);
} else {
out.insert(k, arr);
}
}
Ok(out)
}
/// Predicate verdicts keyed by layer path (weight key without `.weight`).
/// `Some(q)` = quantize with `q`; `None` = the predicate declined the layer.
type PredicateDecisions = HashMap<String, Option<Quantization>>;

/// Runs `predicate` once over every structurally eligible weight and records
/// its verdicts.
///
/// A weight is eligible when its key ends in `.weight`, its shape has at least
/// two dimensions, and its last dimension is a multiple of `group_size`.
/// Returns an empty map when there is no predicate or `group_size` is not
/// positive.
fn build_predicate_decisions(
    predicate: Option<&dyn MixedQuantPredicate>,
    weights: &Weights,
    group_size: i32,
) -> PredicateDecisions {
    let Some(pred) = predicate else {
        return PredicateDecisions::new();
    };
    if group_size <= 0 {
        return PredicateDecisions::new();
    }
    // Safe: `group_size` is strictly positive here.
    let group = group_size as usize;

    weights
        .iter()
        .filter_map(|(key, tensor)| {
            let path = key.strip_suffix(".weight")?;
            let shape = tensor.shape();
            let divisible =
                shape.len() >= 2 && shape.last().is_some_and(|&last| last % group == 0);
            divisible.then(|| (path.to_string(), pred.decide(path, tensor)))
        })
        .collect()
}
fn build_quantize_config(
config_json: &str,
group_size: i32,
bits: i32,
mode: QuantMode,
decisions: &PredicateDecisions,
weights: &Weights,
) -> Result<(PerLayerQuantization, String)> {
let value: serde_json::Value = serde_json::from_str(config_json)
.map_err(|e| Error::Parse(ParsePayload::new("convert: source config", "JSON", e)))?;
let serde_json::Value::Object(mut map) = value else {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"convert: source config JSON",
"must be an object",
)));
};
let fine_grained = map.contains_key("quantization");
let global = Quantization {
group_size,
bits,
mode,
};
let mut per_layer_overrides: HashMap<String, QuantizationOption> = HashMap::new();
if !decisions.is_empty() {
for key in weights.keys() {
let Some(path) = key.strip_suffix(".weight") else {
continue;
};
let Some(decision) = decisions.get(path) else {
continue;
};
match decision {
Some(q) => {
if *q != global || fine_grained {
per_layer_overrides.insert(path.to_string(), QuantizationOption::Quantize(*q));
}
}
None => {
if fine_grained {
per_layer_overrides.insert(path.to_string(), QuantizationOption::Skip);
}
}
}
}
}
let mut quant_block = serde_json::Map::new();
quant_block.insert(
"group_size".to_string(),
serde_json::Value::Number(group_size.into()),
);
quant_block.insert("bits".to_string(), serde_json::Value::Number(bits.into()));
quant_block.insert(
"mode".to_string(),
serde_json::Value::String(mode.as_str().to_string()),
);
for (path, opt) in &per_layer_overrides {
match opt {
QuantizationOption::Skip => {
quant_block.insert(path.clone(), serde_json::Value::Bool(false));
}
QuantizationOption::Quantize(q) => {
let mut nested = serde_json::Map::new();
nested.insert(
"group_size".to_string(),
serde_json::Value::Number(q.group_size.into()),
);
nested.insert("bits".to_string(), serde_json::Value::Number(q.bits.into()));
nested.insert(
"mode".to_string(),
serde_json::Value::String(q.mode.as_str().to_string()),
);
quant_block.insert(path.clone(), serde_json::Value::Object(nested));
}
}
}
map.insert(
"quantization".to_string(),
serde_json::Value::Object(quant_block),
);
let updated_text = serde_json::to_string(&serde_json::Value::Object(map)).map_err(|e| {
Error::Parse(ParsePayload::new(
"convert: cannot re-serialize patched config",
"JSON",
e,
))
})?;
let live_cfg = PerLayerQuantization::new(Some(global), per_layer_overrides);
Ok((live_cfg, updated_text))
}
/// Returns `config_json` re-serialized without its top-level `quantization`
/// and `quantization_config` keys.
///
/// # Errors
/// `Error::Parse` when the input is not valid JSON or cannot be re-serialized;
/// `Error::InvariantViolation` when the JSON root is not an object.
fn strip_quantization_blocks(config_json: &str) -> Result<String> {
    let mut config: serde_json::Value = serde_json::from_str(config_json)
        .map_err(|e| Error::Parse(ParsePayload::new("convert: source config", "JSON", e)))?;

    match config.as_object_mut() {
        Some(fields) => {
            for key in ["quantization", "quantization_config"] {
                fields.remove(key);
            }
        }
        None => {
            return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                "convert: source config JSON",
                "must be an object",
            )));
        }
    }

    serde_json::to_string(&config).map_err(|e| {
        Error::Parse(ParsePayload::new(
            "convert: cannot re-serialize stripped config",
            "JSON",
            e,
        ))
    })
}
/// File names that `copy_tokenizer_and_extras` copies from the source
/// directory to the destination when they exist there as regular files.
/// Names that are absent in the source are silently skipped.
pub(crate) const TOKENIZER_EXTRA_FILES: &[&str] = &[
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"added_tokens.json",
"spiece.model",
"tokenizer.model",
"vocab.json",
"merges.txt",
"chat_template.jinja",
"generation_config.json",
];
/// Fsync failures collected by `copy_tokenizer_and_extras` after the copies
/// themselves succeeded. Both fields default to `None`.
#[derive(Debug, Default)]
pub struct CopyDurabilityWarnings {
/// First fsync failure on a copied file (later ones are not recorded).
pub post_copy_file: Option<std::io::Error>,
/// Failure of the final fsync on the destination directory.
pub post_copy_dir: Option<std::io::Error>,
}
/// Successful result of `copy_tokenizer_and_extras`.
#[derive(Debug)]
pub enum CopyOutcome {
/// No fsync warning was recorded (also returned when source and destination
/// are the same path and nothing was copied).
Committed,
/// The copies succeeded but at least one fsync failed; details are attached.
CommittedWithDurabilityWarning(CopyDurabilityWarnings),
}
pub fn copy_tokenizer_and_extras(src: &Path, dst: &Path) -> Result<CopyOutcome> {
if paths_are_same(src, dst) {
return Ok(CopyOutcome::Committed);
}
let mut warnings = CopyDurabilityWarnings::default();
let mut record_file_fsync_warning = |e: std::io::Error| {
if warnings.post_copy_file.is_none() {
warnings.post_copy_file = Some(e);
}
};
fn wrap_fsync_err(dst: &Path, e: &std::io::Error) -> std::io::Error {
std::io::Error::new(
e.kind(),
format!(
"copy_tokenizer_and_extras: fsync {} failed: {e}",
dst.display()
),
)
}
for name in TOKENIZER_EXTRA_FILES {
let s = src.join(name);
if !s.is_file() {
continue;
}
let d = dst.join(name);
std::fs::copy(&s, &d).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"copy_tokenizer_and_extras",
FileOp::Copy,
d.clone(),
e,
))
})?;
if let Err(e) = crate::lm::load::fsync_path_io(&d) {
record_file_fsync_warning(wrap_fsync_err(&d, &e));
}
}
let entries = std::fs::read_dir(src).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"copy_tokenizer_and_extras",
FileOp::Read,
src.to_path_buf(),
e,
))
})?;
for entry in entries {
let entry = entry.map_err(|e| {
Error::FileIo(FileIoPayload::new(
"copy_tokenizer_and_extras: cannot read entry in",
FileOp::Read,
src.to_path_buf(),
e,
))
})?;
let path = entry.path();
if !path.is_file() {
continue;
}
let Some(name) = path.file_name().and_then(|n| n.to_str()) else {
continue;
};
if !name.ends_with(".py") {
continue;
}
let d = dst.join(name);
std::fs::copy(&path, &d).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"copy_tokenizer_and_extras",
FileOp::Copy,
d.clone(),
e,
))
})?;
if let Err(e) = crate::lm::load::fsync_path_io(&d) {
record_file_fsync_warning(wrap_fsync_err(&d, &e));
}
}
if let Err(dir_err) = crate::lm::load::fsync_dir(dst) {
warnings.post_copy_dir = Some(dir_err);
}
Ok(
if warnings.post_copy_file.is_none() && warnings.post_copy_dir.is_none() {
CopyOutcome::Committed
} else {
CopyOutcome::CommittedWithDurabilityWarning(warnings)
},
)
}
/// Reports whether `src` and `dst` name the same location.
///
/// When both paths can be canonicalized the canonical forms are compared;
/// if either canonicalization fails the paths are compared as given.
fn paths_are_same(src: &Path, dst: &Path) -> bool {
    let canonical_src = std::fs::canonicalize(src);
    let canonical_dst = std::fs::canonicalize(dst);
    if let (Ok(a), Ok(b)) = (canonical_src, canonical_dst) {
        a == b
    } else {
        src == dst
    }
}
#[cfg(test)]
mod unit {
use super::*;
// With nothing requested each mode yields its own default pair; explicit
// positive values override the defaults.
#[test]
fn defaults_for_mode_table_matches_utils_py_800_808() {
    let per_mode_defaults = [
        (QuantMode::Affine, (64, 4)),
        (QuantMode::Mxfp4, (32, 4)),
        (QuantMode::Nvfp4, (16, 4)),
        (QuantMode::Mxfp8, (32, 8)),
    ];
    for (mode, expected) in per_mode_defaults {
        assert_eq!(defaults_for_mode(mode, None, None), expected);
    }
    let overridden = defaults_for_mode(QuantMode::Affine, Some(128), Some(8));
    assert_eq!(overridden, (128, 8));
}
// A requested group size of zero is treated as "not provided".
#[test]
fn defaults_for_mode_zero_group_size_is_falsy() {
    let cases = [
        (
            QuantMode::Affine,
            (64, 4),
            "Some(0) group_size falls back to mode default",
        ),
        (
            QuantMode::Mxfp4,
            (32, 4),
            "Some(0) group_size falls back to mxfp4 default",
        ),
    ];
    for (mode, expected, why) in cases {
        assert_eq!(defaults_for_mode(mode, Some(0), None), expected, "{why}");
    }
}
// A requested bit width of zero is treated as "not provided".
#[test]
fn defaults_for_mode_zero_bits_is_falsy() {
    let cases = [
        (
            QuantMode::Affine,
            (64, 4),
            "Some(0) bits falls back to mode default",
        ),
        (
            QuantMode::Mxfp8,
            (32, 8),
            "Some(0) bits falls back to mxfp8 default",
        ),
    ];
    for (mode, expected, why) in cases {
        assert_eq!(defaults_for_mode(mode, None, Some(0)), expected, "{why}");
    }
}
// Negative requests are ignored just like zero.
#[test]
fn defaults_for_mode_negative_also_falls_back() {
    let resolved = defaults_for_mode(QuantMode::Affine, Some(-1), Some(-2));
    assert_eq!(resolved, (64, 4));
}
// Only the three float names are recognised; everything else maps to `None`.
#[test]
fn parse_conversion_dtype_table_matches_convert_py_82() {
    let expectations = [
        ("float16", Some(Dtype::F16)),
        ("bfloat16", Some(Dtype::BF16)),
        ("float32", Some(Dtype::F32)),
        ("float64", None),
        ("int32", None),
        ("", None),
    ];
    for (name, expected) in expectations {
        assert_eq!(parse_conversion_dtype(name), expected);
    }
}
// The config says float32, but the caller's explicit BF16 takes precedence.
#[test]
fn resolve_target_dtype_explicit_wins() {
    let resolved =
        resolve_target_dtype(Some(Dtype::BF16), r#"{"torch_dtype":"float32"}"#).unwrap();
    assert_eq!(resolved, Some(Dtype::BF16));
}
// Without an explicit dtype, top-level `torch_dtype` is used.
#[test]
fn resolve_target_dtype_falls_back_to_torch_dtype() {
    let resolved = resolve_target_dtype(None, r#"{"torch_dtype":"bfloat16"}"#).unwrap();
    assert_eq!(resolved, Some(Dtype::BF16));
}
// With no `torch_dtype`, `text_config.dtype` is the next source.
#[test]
fn resolve_target_dtype_falls_back_to_text_config_dtype() {
    let resolved = resolve_target_dtype(None, r#"{"text_config":{"dtype":"float16"}}"#).unwrap();
    assert_eq!(resolved, Some(Dtype::F16));
}
// A dtype name outside the supported set yields no target dtype.
#[test]
fn resolve_target_dtype_unknown_is_none() {
    let resolved = resolve_target_dtype(None, r#"{"torch_dtype":"float64"}"#).unwrap();
    assert_eq!(resolved, None);
}
// An explicit integer dtype is rejected with a structured payload.
#[test]
fn resolve_target_dtype_explicit_i32_is_error() {
    let cfg = r#"{"torch_dtype":"float32"}"#;
    let payload = match resolve_target_dtype(Some(Dtype::I32), cfg) {
        Err(Error::UnsupportedDtype(p)) => p,
        other => panic!("expected Err(UnsupportedDtype), got {other:?}"),
    };
    assert_eq!(payload.dtype(), Dtype::I32);
    assert_eq!(payload.supported(), &[Dtype::F16, Dtype::BF16, Dtype::F32]);
    assert!(
        payload.context().contains("MODEL_CONVERSION_DTYPES"),
        "context names the reference set: {}",
        payload.context()
    );
}
// F64 is a float type but not one of the supported conversion targets.
#[test]
fn resolve_target_dtype_explicit_f64_is_error() {
    let payload = match resolve_target_dtype(Some(Dtype::F64), r#"{}"#) {
        Err(Error::UnsupportedDtype(p)) => p,
        other => panic!("expected Err(UnsupportedDtype), got {other:?}"),
    };
    assert_eq!(payload.dtype(), Dtype::F64);
    assert_eq!(payload.supported(), &[Dtype::F16, Dtype::BF16, Dtype::F32]);
}
// Complex dtypes are rejected as explicit targets.
#[test]
fn resolve_target_dtype_explicit_complex_is_error() {
    let payload = match resolve_target_dtype(Some(Dtype::Complex64), r#"{}"#) {
        Err(Error::UnsupportedDtype(p)) => p,
        other => panic!("expected Err(UnsupportedDtype), got {other:?}"),
    };
    assert_eq!(payload.dtype(), Dtype::Complex64);
    assert_eq!(payload.supported(), &[Dtype::F16, Dtype::BF16, Dtype::F32]);
}
// Bool is rejected as an explicit target.
#[test]
fn resolve_target_dtype_explicit_bool_is_error() {
    let payload = match resolve_target_dtype(Some(Dtype::Bool), r#"{}"#) {
        Err(Error::UnsupportedDtype(p)) => p,
        other => panic!("expected Err(UnsupportedDtype), got {other:?}"),
    };
    assert_eq!(payload.dtype(), Dtype::Bool);
    assert_eq!(payload.supported(), &[Dtype::F16, Dtype::BF16, Dtype::F32]);
}
// Both quantization keys disappear; unrelated keys survive.
#[test]
fn strip_quantization_blocks_removes_both_keys() {
    let cfg = r#"{
"model_type":"qwen3",
"quantization":{"group_size":64,"bits":4},
"quantization_config":{"group_size":64,"bits":4}
}"#;
    let stripped = strip_quantization_blocks(cfg).unwrap();
    let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
    for removed in ["quantization", "quantization_config"] {
        assert!(parsed.get(removed).is_none());
    }
    let model_type = parsed.get("model_type").and_then(|v| v.as_str());
    assert_eq!(model_type, Some("qwen3"));
}
use std::cell::RefCell;
// Test predicate that records how often `decide` is called.
// Interior mutability is needed because `decide` takes `&self`.
struct CountingPredicate {
// Number of `decide` calls per layer path.
counts: RefCell<HashMap<String, u32>>,
// Total number of `decide` calls across all layers.
cycle: RefCell<u32>,
}
impl CountingPredicate {
fn new() -> Self {
Self {
counts: RefCell::new(HashMap::new()),
cycle: RefCell::new(0),
}
}
fn max_count(&self) -> u32 {
self.counts.borrow().values().copied().max().unwrap_or(0)
}
fn paths_seen(&self) -> Vec<String> {
let mut paths: Vec<String> = self.counts.borrow().keys().cloned().collect();
paths.sort();
paths
}
}
impl MixedQuantPredicate for CountingPredicate {
    // Bumps the per-layer and global counters, then always answers with the
    // same affine 4-bit / group-64 setting.
    fn decide(&self, layer_name: &str, _weight: &Array) -> Option<Quantization> {
        {
            let mut counts = self.counts.borrow_mut();
            *counts.entry(layer_name.to_owned()).or_default() += 1;
        }
        *self.cycle.borrow_mut() += 1;

        let fixed = Quantization {
            group_size: 64,
            bits: 4,
            mode: QuantMode::Affine,
        };
        Some(fixed)
    }
}
// Three 2x64 weights are eligible; the rank-1 weight is not and must never
// reach the predicate.
#[test]
fn build_predicate_decisions_calls_predicate_once_per_eligible_layer() {
    const ELIGIBLE: [&str; 3] = ["layer.a", "layer.b", "layer.c"];

    let mut weights: Weights = HashMap::new();
    for path in ELIGIBLE {
        let tensor = Array::from_slice::<f32>(&[0.0_f32; 128], &(2usize, 64usize)).unwrap();
        weights.insert(format!("{path}.weight"), tensor);
    }
    let rank_one = Array::from_slice::<f32>(&[0.0_f32; 64], &(64usize,)).unwrap();
    weights.insert("layer.d.weight".to_string(), rank_one);

    let pred = CountingPredicate::new();
    let decisions = build_predicate_decisions(Some(&pred), &weights, 64);

    let counts = pred.counts.borrow();
    assert_eq!(counts.len(), 3, "exactly the 3 eligible layers visited");
    for path in ELIGIBLE {
        assert_eq!(
            counts.get(path).copied(),
            Some(1),
            "{path} called exactly once",
        );
    }
    assert!(
        !counts.contains_key("layer.d"),
        "structurally-ineligible layer never reaches the predicate"
    );

    assert_eq!(decisions.len(), 3, "decision map has 3 eligible entries");
    for path in ELIGIBLE {
        assert!(
            matches!(decisions.get(path), Some(Some(_))),
            "{path} decision recorded as Some(Some(_))",
        );
    }
}
// The config builder works purely from the recorded decisions: the call
// counters must be unchanged after it runs.
#[test]
fn build_quantize_config_does_not_reinvoke_predicate() {
    let mut weights: Weights = HashMap::new();
    for path in ["layer.a", "layer.b"] {
        let tensor = Array::from_slice::<f32>(&[0.0_f32; 128], &(2usize, 64usize)).unwrap();
        weights.insert(format!("{path}.weight"), tensor);
    }

    let pred = CountingPredicate::new();
    let decisions = build_predicate_decisions(Some(&pred), &weights, 64);
    let snapshot = pred.counts.borrow().clone();

    let _ = build_quantize_config("{}", 64, 4, QuantMode::Affine, &decisions, &weights).unwrap();

    assert_eq!(
        *pred.counts.borrow(),
        snapshot,
        "build_quantize_config must not re-invoke the predicate"
    );
    assert_eq!(pred.max_count(), 1, "every layer's count is still 1");
    let expected_paths = vec!["layer.a".to_string(), "layer.b".to_string()];
    assert_eq!(pred.paths_seen(), expected_paths);
}
// End-to-end: a directory-fsync fault during `convert` must surface as a
// committed `DurabilityWarning`, while the weights, config and tokenizer
// files still land in the destination.
#[test]
fn convert_durability_warning_still_copies_tokenizer_and_returns_warning() {
// Per-process scratch directory; any leftover from a previous run is removed.
let dir = std::env::temp_dir().join(format!("mlxrs_convert_durability_{}", std::process::id()));
let _ = std::fs::remove_dir_all(&dir);
let src = dir.join("src");
let dst = dir.join("dst");
std::fs::create_dir_all(&src).unwrap();
// Minimal source checkpoint: config, one 2x64 f32 weight, tokenizer files.
let plain_config = r#"{
"model_type":"qwen3","hidden_size":16,"num_hidden_layers":1,
"num_attention_heads":2,"num_key_value_heads":2,"head_dim":8,
"rope_theta":10000.0,"vocab_size":128,"tie_word_embeddings":false
}"#;
std::fs::write(src.join("config.json"), plain_config).unwrap();
let blob: Vec<f32> = (0..128).map(|i| (i as f32) * 0.01).collect();
let mut weights: Weights = HashMap::new();
weights.insert(
"layer.weight".to_string(),
Array::from_slice::<f32>(&blob, &(2usize, 64usize)).unwrap(),
);
crate::io::save_safetensors(&src.join("model.safetensors"), &weights).unwrap();
let tokenizer_json = include_str!("../../tests/fixtures/tokenizer.json");
let tokenizer_config_json = include_str!("../../tests/fixtures/tokenizer_config.json");
std::fs::write(src.join("tokenizer.json"), tokenizer_json).unwrap();
std::fs::write(src.join("tokenizer_config.json"), tokenizer_config_json).unwrap();
std::fs::write(
src.join("special_tokens_map.json"),
br#"{"eos_token":"</s>"}"#,
)
.unwrap();
std::fs::write(src.join("generation_config.json"), br#"{"max_length":32}"#).unwrap();
// NOTE(review): presumably arms a one-shot injected `fsync_dir` failure that
// stays active while the guard lives — confirm against `arm_fsync_dir_fault`.
let _guard = crate::lm::load::arm_fsync_dir_fault(1);
let r = convert(ConvertArgs {
hf_path: src.clone(),
mlx_path: dst.clone(),
..Default::default()
});
drop(_guard);
// Exactly one warning is expected, so it comes back as a plain DurabilityWarning.
match r {
Err(Error::DurabilityWarning(p)) => {
assert!(
p.committed(),
"convert's DurabilityWarning must carry committed=true"
);
assert!(
p.source()
.to_string()
.contains("injected fsync_dir failure"),
"underlying io::Error preserved: got {}",
p.source()
);
}
other => panic!("expected Err(DurabilityWarning), got {other:?}"),
}
// The destination must nevertheless be fully populated.
assert!(dst.join("config.json").is_file(), "config.json on disk");
assert!(
dst.join("model.safetensors.index.json").is_file(),
"index.json on disk"
);
for name in [
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"generation_config.json",
] {
assert!(
dst.join(name).is_file(),
"{name} copied despite the DurabilityWarning"
);
let a = std::fs::read(src.join(name)).unwrap();
let b = std::fs::read(dst.join(name)).unwrap();
assert_eq!(a, b, "{name} byte-equal at dst");
}
// Best-effort cleanup of the scratch directory.
let _ = std::fs::remove_dir_all(&dir);
}
// End-to-end (unix only): the save hits an injected directory-fsync fault AND
// the tokenizer copy fails because one source file is unreadable. `convert`
// must return `ConvertPostSavePartial` carrying both the save warning and the
// typed copy error.
#[cfg(unix)]
#[test]
fn convert_durability_warning_then_tokenizer_copy_failure_preserves_committed_signal() {
use std::os::unix::fs::PermissionsExt;
// Per-process scratch directory; any leftover from a previous run is removed.
let dir = std::env::temp_dir().join(format!(
"mlxrs_convert_durability_then_copyfail_{}",
std::process::id()
));
let _ = std::fs::remove_dir_all(&dir);
let src = dir.join("src");
let dst = dir.join("dst");
std::fs::create_dir_all(&src).unwrap();
// Minimal source checkpoint: config, one 2x64 f32 weight, tokenizer files.
let plain_config = r#"{
"model_type":"qwen3","hidden_size":16,"num_hidden_layers":1,
"num_attention_heads":2,"num_key_value_heads":2,"head_dim":8,
"rope_theta":10000.0,"vocab_size":128,"tie_word_embeddings":false
}"#;
std::fs::write(src.join("config.json"), plain_config).unwrap();
let blob: Vec<f32> = (0..128).map(|i| (i as f32) * 0.01).collect();
let mut weights: Weights = HashMap::new();
weights.insert(
"layer.weight".to_string(),
Array::from_slice::<f32>(&blob, &(2usize, 64usize)).unwrap(),
);
crate::io::save_safetensors(&src.join("model.safetensors"), &weights).unwrap();
let tokenizer_json = include_str!("../../tests/fixtures/tokenizer.json");
let tokenizer_config_json = include_str!("../../tests/fixtures/tokenizer_config.json");
std::fs::write(src.join("tokenizer.json"), tokenizer_json).unwrap();
std::fs::write(src.join("tokenizer_config.json"), tokenizer_config_json).unwrap();
// This file is made unreadable below so that copying it fails.
let chmod_target = src.join("special_tokens_map.json");
std::fs::write(&chmod_target, br#"{"eos_token":"</s>"}"#).unwrap();
std::fs::write(src.join("generation_config.json"), br#"{"max_length":32}"#).unwrap();
let mut perm = std::fs::metadata(&chmod_target).unwrap().permissions();
perm.set_mode(0o000);
std::fs::set_permissions(&chmod_target, perm).unwrap();
// Drop guard: restores mode 0o644 so the scratch directory can be cleaned up.
struct PermRestore(std::path::PathBuf);
impl Drop for PermRestore {
fn drop(&mut self) {
if let Ok(meta) = std::fs::metadata(&self.0) {
let mut p = meta.permissions();
p.set_mode(0o644);
let _ = std::fs::set_permissions(&self.0, p);
}
}
}
let _perm_guard = PermRestore(chmod_target);
// NOTE(review): presumably arms a one-shot injected `fsync_dir` failure that
// stays active while the guard lives — confirm against `arm_fsync_dir_fault`.
let _guard = crate::lm::load::arm_fsync_dir_fault(1);
let r = convert(ConvertArgs {
hf_path: src,
mlx_path: dst.clone(),
..Default::default()
});
drop(_guard);
match &r {
Err(Error::ConvertPostSavePartial(p)) => {
assert!(
p.committed(),
"ConvertPostSavePartial must carry committed=true (variant is \
reachable only after the observable commit point)"
);
// The save-side warning must survive alongside the copy failure.
let save_warning = p
.save_warning()
.expect("save_warning must be Some — the fsync-dir injector fired");
let _ = save_warning.kind();
assert!(
save_warning
.to_string()
.contains("injected fsync_dir failure"),
"save_warning preserves the verbatim fsync_dir io::Error \
message: got {save_warning}"
);
// The copy failure stays a typed error naming the function and the file.
let copy_error = p.copy_error();
assert!(
matches!(copy_error, Error::FileIo(_)),
"copy_error is the typed FileIo variant (machine-readable, \
not stringified); got: {copy_error:?}"
);
let copy_msg = copy_error.to_string();
assert!(
copy_msg.contains("copy_tokenizer_and_extras"),
"copy_error names copy_tokenizer_and_extras; got: {copy_msg}"
);
assert!(
copy_msg.contains("special_tokens_map.json"),
"copy_error names the failing file (special_tokens_map.json); \
got: {copy_msg}"
);
}
other => panic!(
"expected Err(ConvertPostSavePartial), got {other:?} — the post-save \
copy failure must surface the structured variant so the caller can \
machine-detect 'destination structurally incomplete'"
),
}
// Weights and config were committed before the copy step failed.
assert!(dst.join("config.json").is_file(), "config.json on disk");
assert!(
dst.join("model.safetensors.index.json").is_file(),
"index.json on disk"
);
let any_shard = std::fs::read_dir(&dst)
.unwrap()
.filter_map(|e| e.ok())
.any(|e| {
e.path()
.file_name()
.and_then(|n| n.to_str())
.map(|n| n.ends_with(".safetensors"))
.unwrap_or(false)
});
assert!(any_shard, "at least one shard committed on disk");
assert!(
!dst.join("special_tokens_map.json").is_file(),
"the chmod-000 source file MUST NOT have been copied (its \
read failed before any bytes were written to dst)"
);
// Restore permissions first, then remove the scratch directory.
drop(_perm_guard);
let _ = std::fs::remove_dir_all(&dir);
}
// End-to-end (unix only): the save succeeds cleanly but the tokenizer copy
// fails because one source file is unreadable. `convert` must return
// `ConvertPostSavePartial` with no save warning and the typed copy error.
#[cfg(unix)]
#[test]
fn convert_no_durability_warning_then_tokenizer_copy_failure_returns_partial_with_no_save_warning()
{
use std::os::unix::fs::PermissionsExt;
// Per-process scratch directory; any leftover from a previous run is removed.
let dir = std::env::temp_dir().join(format!(
"mlxrs_convert_clean_save_then_copyfail_{}",
std::process::id()
));
let _ = std::fs::remove_dir_all(&dir);
let src = dir.join("src");
let dst = dir.join("dst");
std::fs::create_dir_all(&src).unwrap();
// Minimal source checkpoint: config, one 2x64 f32 weight, tokenizer files.
let plain_config = r#"{
"model_type":"qwen3","hidden_size":16,"num_hidden_layers":1,
"num_attention_heads":2,"num_key_value_heads":2,"head_dim":8,
"rope_theta":10000.0,"vocab_size":128,"tie_word_embeddings":false
}"#;
std::fs::write(src.join("config.json"), plain_config).unwrap();
let blob: Vec<f32> = (0..128).map(|i| (i as f32) * 0.01).collect();
let mut weights: Weights = HashMap::new();
weights.insert(
"layer.weight".to_string(),
Array::from_slice::<f32>(&blob, &(2usize, 64usize)).unwrap(),
);
crate::io::save_safetensors(&src.join("model.safetensors"), &weights).unwrap();
let tokenizer_json = include_str!("../../tests/fixtures/tokenizer.json");
let tokenizer_config_json = include_str!("../../tests/fixtures/tokenizer_config.json");
std::fs::write(src.join("tokenizer.json"), tokenizer_json).unwrap();
std::fs::write(src.join("tokenizer_config.json"), tokenizer_config_json).unwrap();
// This file is made unreadable below so that copying it fails.
let chmod_target = src.join("special_tokens_map.json");
std::fs::write(&chmod_target, br#"{"eos_token":"</s>"}"#).unwrap();
std::fs::write(src.join("generation_config.json"), br#"{"max_length":32}"#).unwrap();
let mut perm = std::fs::metadata(&chmod_target).unwrap().permissions();
perm.set_mode(0o000);
std::fs::set_permissions(&chmod_target, perm).unwrap();
// Drop guard: restores mode 0o644 so the scratch directory can be cleaned up.
struct PermRestore(std::path::PathBuf);
impl Drop for PermRestore {
fn drop(&mut self) {
if let Ok(meta) = std::fs::metadata(&self.0) {
let mut p = meta.permissions();
p.set_mode(0o644);
let _ = std::fs::set_permissions(&self.0, p);
}
}
}
let _perm_guard = PermRestore(chmod_target);
// No fsync fault is armed here: the save itself is expected to be clean.
let r = convert(ConvertArgs {
hf_path: src,
mlx_path: dst.clone(),
..Default::default()
});
match &r {
Err(Error::ConvertPostSavePartial(p)) => {
assert!(
p.committed(),
"ConvertPostSavePartial must carry committed=true (variant is \
reachable only after the observable commit point)"
);
assert!(
p.save_warning().is_none(),
"save_warning must be None — the save returned plain Ok(()) \
with no fsync warning; got: {:?}",
p.save_warning()
);
// The copy failure stays a typed error naming the function and the file.
let copy_error = p.copy_error();
assert!(
matches!(copy_error, Error::FileIo(_)),
"copy_error is the typed FileIo variant; got: {copy_error:?}"
);
let copy_msg = copy_error.to_string();
assert!(
copy_msg.contains("copy_tokenizer_and_extras"),
"copy_error names copy_tokenizer_and_extras; got: {copy_msg}"
);
assert!(
copy_msg.contains("special_tokens_map.json"),
"copy_error names the failing file (special_tokens_map.json); \
got: {copy_msg}"
);
}
other => panic!(
"expected Err(ConvertPostSavePartial), got {other:?} — a clean \
save + post-save copy failure MUST surface the structured \
variant (the destination IS committed, structurally \
incomplete)"
),
}
// Weights and config were committed before the copy step failed.
assert!(dst.join("config.json").is_file(), "config.json on disk");
assert!(
dst.join("model.safetensors.index.json").is_file(),
"index.json on disk"
);
let any_shard = std::fs::read_dir(&dst)
.unwrap()
.filter_map(|e| e.ok())
.any(|e| {
e.path()
.file_name()
.and_then(|n| n.to_str())
.map(|n| n.ends_with(".safetensors"))
.unwrap_or(false)
});
assert!(any_shard, "at least one shard committed on disk");
assert!(
!dst.join("special_tokens_map.json").is_file(),
"the chmod-000 source file MUST NOT have been copied"
);
// Restore permissions first, then remove the scratch directory.
drop(_perm_guard);
let _ = std::fs::remove_dir_all(&dir);
}
/// Inspects a hand-built `Error::ConvertPostSavePartial` without touching the
/// filesystem: its `Display` text, the `std::error::Error::source()` link, and
/// the typed payload accessors (`committed`, `save_warning`, `copy_error`).
#[test]
fn convert_post_save_partial_error_chain_iterable() {
    // Inner copy failure: a typed FileIo payload around a PermissionDenied io::Error.
    let denied = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "EACCES on copy");
    let file_io = FileIoPayload::new(
        "copy_tokenizer_and_extras",
        crate::error::FileOp::Copy,
        std::path::PathBuf::from("special_tokens_map.json"),
        denied,
    );
    let payload = ConvertPostSavePartialPayload::new(
        true,
        Some(std::io::Error::other("fsync_dir warning")),
        Error::FileIo(file_io),
    );
    let err = Error::ConvertPostSavePartial(payload);

    // View the error only through the std trait object, as a generic caller would.
    let e: &dyn std::error::Error = &err;
    let top = e.to_string();
    assert!(
        top.contains("committed=true"),
        "Display carries the structured committed=true tag; got: {top}"
    );
    assert!(
        top.contains("destination directory may be incomplete"),
        "Display carries the structurally-incomplete hint; got: {top}"
    );

    let source = e.source().expect(
        "ConvertPostSavePartial has a #[source]-annotated chain — \
         calling .source() must return the copy_error",
    );
    let source_msg = source.to_string();
    assert!(
        source_msg.contains("EACCES on copy"),
        ".source() returns the copy_error (the actionable failure); \
         got: {source_msg}"
    );

    // Typed access: destructure the variant and walk the payload accessors.
    let Error::ConvertPostSavePartial(p) = &err else {
        unreachable!("constructed ConvertPostSavePartial above");
    };
    assert!(p.committed());
    let warning_text = p.save_warning().map(|w| w.to_string());
    assert_eq!(
        warning_text.as_deref(),
        Some("fsync_dir warning"),
        "save_warning is reachable via direct field access (typed accessor)"
    );
    let Error::FileIo(io_p) = p.copy_error() else {
        unreachable!("copy_error is FileIo (constructed above)");
    };
    assert_eq!(io_p.inner().kind(), std::io::ErrorKind::PermissionDenied);
    assert!(p.copy_error().to_string().contains("EACCES on copy"));
}
/// Creates a fresh conversion fixture under the system temp directory.
///
/// The scratch directory is named `mlxrs_convert_{tag}_{pid}`; any directory
/// left over from a previous run is removed first. A `src` subdirectory is
/// created and filled with `config.json`, a single-tensor `model.safetensors`,
/// the two tokenizer fixtures, `special_tokens_map.json` and
/// `generation_config.json`. The `dst` path is only computed, never created.
///
/// Returns `(workdir, src, dst)`.
fn build_save_fixture(tag: &str) -> (PathBuf, PathBuf, PathBuf) {
    let root = std::env::temp_dir().join(format!("mlxrs_convert_{tag}_{}", std::process::id()));
    // Best-effort wipe of a stale directory; a missing directory is not an error here.
    let _ = std::fs::remove_dir_all(&root);
    let (src, dst) = (root.join("src"), root.join("dst"));
    std::fs::create_dir_all(&src).unwrap();

    let plain_config = r#"{
"model_type":"qwen3","hidden_size":16,"num_hidden_layers":1,
"num_attention_heads":2,"num_key_value_heads":2,"head_dim":8,
"rope_theta":10000.0,"vocab_size":128,"tie_word_embeddings":false
}"#;
    std::fs::write(src.join("config.json"), plain_config).unwrap();

    // One 2x64 f32 tensor holding the ramp 0.00, 0.01, ...
    let ramp: Vec<f32> = (0..128).map(|step| (step as f32) * 0.01).collect();
    let weights: Weights = HashMap::from([(
        "layer.weight".to_string(),
        Array::from_slice::<f32>(&ramp, &(2usize, 64usize)).unwrap(),
    )]);
    crate::io::save_safetensors(&src.join("model.safetensors"), &weights).unwrap();

    // Tokenizer fixtures are embedded at compile time and written verbatim.
    for (name, text) in [
        (
            "tokenizer.json",
            include_str!("../../tests/fixtures/tokenizer.json"),
        ),
        (
            "tokenizer_config.json",
            include_str!("../../tests/fixtures/tokenizer_config.json"),
        ),
    ] {
        std::fs::write(src.join(name), text).unwrap();
    }

    // Small extra JSON files that sit next to the tokenizer in the source dir.
    let extras: [(&str, &[u8]); 2] = [
        ("special_tokens_map.json", br#"{"eos_token":"</s>"}"#),
        ("generation_config.json", br#"{"max_length":32}"#),
    ];
    for (name, bytes) in extras {
        std::fs::write(src.join(name), bytes).unwrap();
    }

    (root, src, dst)
}
/// Runs `convert` with the `fsync_path` fault injector armed (argument `3`) and
/// expects `Error::DurabilityWarning` with `committed()` true, while all four
/// tokenizer/extras files are present in `dst` and byte-identical to `src`.
///
/// NOTE(review): the meaning of the `3` passed to the injector is defined by
/// `arm_fsync_path_fault`, which is not visible here — confirm before changing.
#[test]
fn convert_post_copy_file_fsync_failure_returns_durability_warning() {
    let (scratch, src, dst) = build_save_fixture("file_fsync_fail");

    // The fault guard lives exactly for the duration of the convert call.
    let fault = crate::lm::load::arm_fsync_path_fault(3);
    let r = convert(ConvertArgs {
        hf_path: src.clone(),
        mlx_path: dst.clone(),
        ..Default::default()
    });
    drop(fault);

    // Pull the payload out first; any other outcome fails the test right here.
    let p = match &r {
        Err(Error::DurabilityWarning(p)) => p,
        Err(Error::ConvertPostSavePartial(_)) => panic!(
            "post-copy FSYNC warning must NOT surface as ConvertPostSavePartial — \
             that variant is reserved for `std::fs::copy` itself failing (file \
             did NOT reach disk). A post-copy fsync warning means data IS on \
             disk; only durability is uncertain (DurabilityWarning contract)."
        ),
        other => panic!("expected Err(DurabilityWarning), got {other:?}"),
    };
    assert!(
        p.committed(),
        "post-copy fsync warning carries committed=true (data IS on disk)"
    );
    let msg = p.source().to_string();
    assert!(
        msg.contains("injected fsync_path failure") || msg.contains("post-copy"),
        "source message references the post-copy fsync; got: {msg}"
    );

    // Every extras file must exist at the destination with identical bytes.
    let expected_files = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "generation_config.json",
    ];
    for name in expected_files {
        let copied = dst.join(name);
        assert!(
            copied.is_file(),
            "{name} IS on disk despite the post-copy fsync warning"
        );
        let a = std::fs::read(src.join(name)).unwrap();
        let b = std::fs::read(&copied).unwrap();
        assert_eq!(a, b, "{name} byte-equal at dst");
    }

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Runs `convert` with the `fsync_dir` fault injector armed (argument `3`) and
/// expects `Error::DurabilityWarning` with `committed()` true, while all four
/// tokenizer/extras files are present in `dst` and byte-identical to `src`.
///
/// NOTE(review): the semantics of the `3` argument belong to
/// `arm_fsync_dir_fault`, which is not visible here — confirm before changing.
#[test]
fn convert_post_copy_dir_fsync_failure_returns_durability_warning() {
    let (scratch, src, dst) = build_save_fixture("dir_fsync_fail");

    // The fault guard lives exactly for the duration of the convert call.
    let fault = crate::lm::load::arm_fsync_dir_fault(3);
    let r = convert(ConvertArgs {
        hf_path: src.clone(),
        mlx_path: dst.clone(),
        ..Default::default()
    });
    drop(fault);

    let p = match &r {
        Err(Error::DurabilityWarning(p)) => p,
        Err(Error::ConvertPostSavePartial(_)) => panic!(
            "post-copy DIR fsync warning must NOT surface as ConvertPostSavePartial — \
             that variant is reserved for `std::fs::copy` itself failing. A \
             post-copy dir-fsync warning means data IS on disk (every file's \
             own fsync passed); only the dir-entry durability is uncertain."
        ),
        other => panic!("expected Err(DurabilityWarning), got {other:?}"),
    };
    assert!(
        p.committed(),
        "post-copy dir-fsync warning carries committed=true (data IS on disk)"
    );
    let msg = p.source().to_string();
    assert!(
        msg.contains("injected fsync_dir failure") || msg.contains("post-copy fsync_dir"),
        "source message references the post-copy dir fsync; got: {msg}"
    );

    // Every extras file must exist at the destination with identical bytes.
    let expected_files = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "generation_config.json",
    ];
    for name in expected_files {
        let copied = dst.join(name);
        assert!(copied.is_file(), "{name} on disk");
        let a = std::fs::read(src.join(name)).unwrap();
        let b = std::fs::read(&copied).unwrap();
        assert_eq!(a, b, "{name} byte-equal at dst");
    }

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Arms both the `fsync_path` and `fsync_dir` injectors (each with argument
/// `3`), runs `convert`, and expects the aggregate
/// `Error::ConvertDurabilityWarnings` with `save` empty and both post-copy
/// slots filled. Also checks `count()` and that `.source()` yields the
/// file-fsync warning.
#[test]
fn convert_post_copy_both_fsyncs_fail_combined_message() {
    let (scratch, src, dst) = build_save_fixture("both_fsyncs_fail");

    // Arm path first, then dir; release in the same order after convert returns.
    let path_fault = crate::lm::load::arm_fsync_path_fault(3);
    let dir_fault = crate::lm::load::arm_fsync_dir_fault(3);
    let r = convert(ConvertArgs {
        hf_path: src,
        mlx_path: dst,
        ..Default::default()
    });
    drop(path_fault);
    drop(dir_fault);

    let agg = match &r {
        Err(Error::ConvertDurabilityWarnings(agg)) => agg,
        Err(Error::DurabilityWarning(_)) => panic!(
            "both-fsyncs-warn surfaces the multi-warning aggregate \
             ConvertDurabilityWarnings, NOT the single-warning \
             DurabilityWarning shape (typed access to each \
             boundary)"
        ),
        Err(Error::ConvertPostSavePartial(_)) => panic!(
            "both-fsyncs-warn surfaces ConvertDurabilityWarnings, NOT \
             ConvertPostSavePartial (data IS on disk; only durability \
             uncertain on two boundaries)"
        ),
        other => panic!("expected Err(ConvertDurabilityWarnings), got {other:?}"),
    };

    assert!(agg.committed, "committed=true even when both fsyncs warn");
    assert!(
        agg.save.is_none(),
        "save-side fsync passed (skip count steps past save's 3 \
         fsync_dir calls); got: {:?}",
        agg.save
    );

    let post_copy_file = agg
        .post_copy_file
        .as_ref()
        .expect("post_copy_file fsync warned (path injector fired on the 4th call)");
    // `.kind()` only compiles if the slot still holds a real io::Error.
    let _ = post_copy_file.kind();
    assert!(
        post_copy_file
            .to_string()
            .contains("injected fsync_path failure"),
        "post_copy_file preserves the verbatim file-fsync io::Error \
         message (no string fold); got: {post_copy_file}"
    );

    let post_copy_dir = agg
        .post_copy_dir
        .as_ref()
        .expect("post_copy_dir fsync warned (dir injector fired on the 4th call)");
    let _ = post_copy_dir.kind();
    assert!(
        post_copy_dir
            .to_string()
            .contains("injected fsync_dir failure"),
        "post_copy_dir preserves the verbatim dir-fsync io::Error \
         message (no string fold); got: {post_copy_dir}"
    );

    assert_eq!(
        agg.count(),
        2,
        "two non-None warning fields (post_copy_file + post_copy_dir)"
    );

    // The std error chain must expose the first populated slot.
    let e: &dyn std::error::Error = r.as_ref().err().unwrap();
    let source = e.source().expect(
        "ConvertDurabilityWarnings has a source chain via the \
         inner aggregate's std::error::Error impl",
    );
    assert!(
        source.to_string().contains("injected fsync_path failure"),
        ".source() returns the FIRST non-None warning \
         (post_copy_file when save is None); got: {source}"
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Three scenarios in one test:
///
/// 1. No injector armed: `convert` returns `Ok(())` and the four extras files
///    are byte-identical at the destination.
/// 2. `fsync_path` injector armed with `3`: `convert` must report
///    `Error::DurabilityWarning`.
/// 3. `fsync_dir` injector armed with `3`: `convert` must report
///    `Error::DurabilityWarning`.
///
/// Per the assertion messages, scenarios 2 and 3 exist so that removing the
/// post-copy fsync calls would show up as an unexpected `Ok(())`.
#[test]
fn convert_ok_implies_post_copy_fsyncs_called() {
    // Scenario 1: clean run.
    {
        let (scratch, src, dst) = build_save_fixture("happy_path");
        let r = convert(ConvertArgs {
            hf_path: src.clone(),
            mlx_path: dst.clone(),
            ..Default::default()
        });
        assert!(
            matches!(r, Ok(())),
            "happy path returns Ok(()) — every fsync passes; got: {r:?}"
        );
        let expected_files = [
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "generation_config.json",
        ];
        for name in expected_files {
            let copied = dst.join(name);
            assert!(
                copied.is_file(),
                "{name} on disk after happy-path convert"
            );
            let a = std::fs::read(src.join(name)).unwrap();
            let b = std::fs::read(&copied).unwrap();
            assert_eq!(a, b, "{name} byte-equal");
        }
        let _ = std::fs::remove_dir_all(&scratch);
    }

    // Scenario 2: file-fsync injector must be observed.
    {
        let (scratch, src, dst) = build_save_fixture("happy_path_spy_file");
        let fault = crate::lm::load::arm_fsync_path_fault(3);
        let r = convert(ConvertArgs {
            hf_path: src,
            mlx_path: dst,
            ..Default::default()
        });
        drop(fault);
        assert!(
            matches!(r, Err(Error::DurabilityWarning(_))),
            "fsync_path injector armed past every save-side call must be \
             observed by the post-copy file fsync loop — a silent removal \
             of that loop would leave the injector unfired and the result \
             Ok(()); got: {r:?}"
        );
        let _ = std::fs::remove_dir_all(&scratch);
    }

    // Scenario 3: dir-fsync injector must be observed.
    {
        let (scratch, src, dst) = build_save_fixture("happy_path_spy_dir");
        let fault = crate::lm::load::arm_fsync_dir_fault(3);
        let r = convert(ConvertArgs {
            hf_path: src,
            mlx_path: dst,
            ..Default::default()
        });
        drop(fault);
        assert!(
            matches!(r, Err(Error::DurabilityWarning(_))),
            "fsync_dir injector armed past every save-side call must be \
             observed by the post-copy `fsync_dir(dst)` call — a silent \
             removal of that call would leave the injector unfired and \
             the result Ok(()); got: {r:?}"
        );
        let _ = std::fs::remove_dir_all(&scratch);
    }
}
/// Builds a `ConvertDurabilityWarnings` aggregate by hand with the `save` and
/// `post_copy_dir` slots filled, converts it into `Error`, and checks that the
/// resulting variant exposes each slot and that `first_warning()` yields the
/// save warning. `convert` itself is not called.
#[test]
fn convert_save_and_post_copy_dir_warn_returns_aggregate() {
    let built = crate::error::ConvertDurabilityWarnings {
        committed: true,
        save: Some(std::io::Error::other("save-side fsync_dir warning")),
        post_copy_file: None,
        post_copy_dir: Some(std::io::Error::other("post-copy fsync_dir warning")),
    };
    assert_eq!(built.count(), 2, "two non-None fields");

    let err: Error = built.into();
    let agg = match &err {
        Error::ConvertDurabilityWarnings(agg) => agg,
        other => panic!(
            "the aggregate-count routing must produce \
             ConvertDurabilityWarnings, NOT {other:?}"
        ),
    };

    assert!(agg.committed);
    let save = agg
        .save
        .as_ref()
        .expect("save warning present (direct destructure)");
    assert!(save.to_string().contains("save-side fsync_dir warning"));
    assert!(
        agg.post_copy_file.is_none(),
        "post_copy_file is None: {:?}",
        agg.post_copy_file
    );
    let post_copy_dir = agg
        .post_copy_dir
        .as_ref()
        .expect("post_copy_dir warning present (direct destructure)");
    assert!(
        post_copy_dir
            .to_string()
            .contains("post-copy fsync_dir warning")
    );
    assert!(
        agg
            .first_warning()
            .unwrap()
            .to_string()
            .contains("save-side fsync_dir warning"),
        "first_warning() returns save when save is Some"
    );
}
/// Arms the `fsync_dir` injector with `2` and the `fsync_path` injector with
/// `3`, runs `convert`, and expects `Error::ConvertDurabilityWarnings` whose
/// `save` and `post_copy_file` slots are filled while `post_copy_dir` is empty.
/// The four extras files must still be byte-identical at the destination.
#[test]
fn convert_save_and_post_copy_file_warn_returns_aggregate() {
    let (scratch, src, dst) = build_save_fixture("save_and_postcopy_file_warn");

    // Arm dir first, then path; release in the same order after convert returns.
    let dir_fault = crate::lm::load::arm_fsync_dir_fault(2);
    let path_fault = crate::lm::load::arm_fsync_path_fault(3);
    let r = convert(ConvertArgs {
        hf_path: src.clone(),
        mlx_path: dst.clone(),
        ..Default::default()
    });
    drop(dir_fault);
    drop(path_fault);

    let agg = match &r {
        Err(Error::ConvertDurabilityWarnings(agg)) => agg,
        Err(Error::DurabilityWarning(_)) => panic!(
            "save warned + post-copy file fsync warned MUST surface the \
             typed aggregate ConvertDurabilityWarnings, NOT the \
             single-source DurabilityWarning (which would fold the two via \
             io::Error::other(format!(...)))"
        ),
        other => panic!("expected Err(ConvertDurabilityWarnings), got {other:?}"),
    };

    assert!(agg.committed, "committed=true");

    let save = agg
        .save
        .as_ref()
        .expect("save warning is Some (fsync_dir injector skip=2 fired on save's config-commit dir fsync)");
    // `.kind()` only compiles if the slot still holds a real io::Error.
    let _ = save.kind();
    assert!(
        save.to_string().contains("injected fsync_dir failure"),
        "save preserves the verbatim save-side fsync_dir io::Error \
         message; got: {save}"
    );

    let post_copy_file = agg.post_copy_file.as_ref().expect(
        "post_copy_file warning is Some (fsync_path injector skip=3 \
         fired on the first post-copy per-file fsync)",
    );
    let _ = post_copy_file.kind();
    assert!(
        post_copy_file
            .to_string()
            .contains("injected fsync_path failure"),
        "post_copy_file preserves the verbatim post-copy fsync_path \
         io::Error message; got: {post_copy_file}"
    );

    assert!(
        agg.post_copy_dir.is_none(),
        "post_copy_dir is None (single-shot fsync_dir guard fired \
         during save and is not re-armed for the post-copy dir fsync); \
         got: {:?}",
        agg.post_copy_dir
    );
    assert_eq!(agg.count(), 2, "two non-None warnings");
    assert!(
        agg
            .first_warning()
            .unwrap()
            .to_string()
            .contains("injected fsync_dir failure"),
        "first_warning() returns save when save is Some"
    );

    // Every extras file must exist at the destination with identical bytes.
    let expected_files = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "generation_config.json",
    ];
    for name in expected_files {
        let copied = dst.join(name);
        assert!(copied.is_file(), "{name} on disk");
        let a = std::fs::read(src.join(name)).unwrap();
        let b = std::fs::read(&copied).unwrap();
        assert_eq!(a, b, "{name} byte-equal at dst");
    }

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Exercises `.source()` on hand-built `ConvertDurabilityWarnings` aggregates
/// in four configurations: all three slots filled, only the two post-copy
/// slots, only `post_copy_dir`, and no slots at all. The expected source is
/// the first populated slot in the order save, post_copy_file, post_copy_dir;
/// with nothing populated `.source()` is `None` and `Display` still carries
/// `committed=true`.
#[test]
fn convert_durability_aggregate_error_chain_walkable() {
    // Case 1: every slot populated — save wins, and each slot stays typed.
    {
        let full = crate::error::ConvertDurabilityWarnings {
            committed: true,
            save: Some(std::io::Error::other("SAVE warning")),
            post_copy_file: Some(std::io::Error::other("PCF warning")),
            post_copy_dir: Some(std::io::Error::other("PCD warning")),
        };
        assert_eq!(full.count(), 3);
        let err: Error = full.into();
        let e: &dyn std::error::Error = &err;
        let source = e
            .source()
            .expect("source chain non-empty (any non-None warning)");
        assert!(
            source.to_string().contains("SAVE warning"),
            ".source() returns the save warning (highest priority when \
             present); got: {source}"
        );

        let Error::ConvertDurabilityWarnings(agg) = &err else {
            unreachable!("constructed ConvertDurabilityWarnings");
        };
        assert_eq!(agg.save.as_ref().unwrap().kind(), std::io::ErrorKind::Other);
        assert_eq!(
            agg.post_copy_file.as_ref().unwrap().kind(),
            std::io::ErrorKind::Other
        );
        assert_eq!(
            agg.post_copy_dir.as_ref().unwrap().kind(),
            std::io::ErrorKind::Other
        );
        assert!(agg.save.as_ref().unwrap().to_string().contains("SAVE"));
        assert!(
            agg
                .post_copy_file
                .as_ref()
                .unwrap()
                .to_string()
                .contains("PCF")
        );
        assert!(
            agg
                .post_copy_dir
                .as_ref()
                .unwrap()
                .to_string()
                .contains("PCD")
        );
    }

    // Case 2: save empty — post_copy_file is next in line.
    {
        let partial = crate::error::ConvertDurabilityWarnings {
            committed: true,
            save: None,
            post_copy_file: Some(std::io::Error::other("PCF only warning")),
            post_copy_dir: Some(std::io::Error::other("PCD only warning")),
        };
        let err: Error = partial.into();
        let e: &dyn std::error::Error = &err;
        let source = e.source().expect("source chain non-empty");
        assert!(
            source.to_string().contains("PCF only warning"),
            ".source() returns post_copy_file when save is None (next \
             priority); got: {source}"
        );
    }

    // Case 3: only post_copy_dir populated.
    {
        let lone = crate::error::ConvertDurabilityWarnings {
            committed: true,
            save: None,
            post_copy_file: None,
            post_copy_dir: Some(std::io::Error::other("PCD lone warning")),
        };
        let err: Error = lone.into();
        let e: &dyn std::error::Error = &err;
        let source = e.source().expect("source chain non-empty");
        assert!(
            source.to_string().contains("PCD lone warning"),
            ".source() returns post_copy_dir when both higher-priority \
             fields are None; got: {source}"
        );
    }

    // Case 4: nothing populated — no source, but Display still tags the commit.
    {
        let empty = crate::error::ConvertDurabilityWarnings {
            committed: true,
            save: None,
            post_copy_file: None,
            post_copy_dir: None,
        };
        assert_eq!(empty.count(), 0);
        let err: Error = empty.into();
        let e: &dyn std::error::Error = &err;
        assert!(
            e.source().is_none(),
            ".source() is None when every field is None"
        );
        assert!(
            err.to_string().contains("committed=true"),
            "Display carries the committed=true tag; got: {err}"
        );
    }
}
/// Arms the `fsync_path` injector with `(3, PermissionDenied)`, runs
/// `convert`, and expects an `Error::DurabilityWarning` whose source still
/// reports `ErrorKind::PermissionDenied` and the injector's message.
#[test]
fn convert_post_copy_file_warning_preserves_io_error_kind() {
    let (scratch, src, dst) = build_save_fixture("post_copy_file_kind");

    let fault =
        crate::lm::load::arm_fsync_path_fault_with_kind(3, std::io::ErrorKind::PermissionDenied);
    let r = convert(ConvertArgs {
        hf_path: src,
        mlx_path: dst,
        ..Default::default()
    });
    drop(fault);

    let p = match &r {
        Err(Error::DurabilityWarning(p)) => p,
        other => panic!("expected Err(DurabilityWarning), got {other:?}"),
    };
    assert!(p.committed(), "committed=true (data IS on disk)");

    let warning = p.source();
    assert_eq!(
        warning.kind(),
        std::io::ErrorKind::PermissionDenied,
        "post_copy_file warning preserves the injected ErrorKind \
         (PermissionDenied) end-to-end — an \
         `io::Error::other(message)` re-wrap would collapse this to \
         ErrorKind::Other; got: {:?} ({})",
        warning.kind(),
        warning
    );
    assert!(
        warning
            .to_string()
            .contains("injected fsync_path failure"),
        "source preserves the verbatim injector message: {}",
        warning
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Arms the `fsync_dir` injector with `(3, StorageFull)`, runs `convert`, and
/// expects an `Error::DurabilityWarning` whose source still reports
/// `ErrorKind::StorageFull` and the injector's message.
#[test]
fn convert_post_copy_dir_warning_preserves_io_error_kind() {
    let (scratch, src, dst) = build_save_fixture("post_copy_dir_kind");

    let fault = crate::lm::load::arm_fsync_dir_fault_with_kind(3, std::io::ErrorKind::StorageFull);
    let r = convert(ConvertArgs {
        hf_path: src,
        mlx_path: dst,
        ..Default::default()
    });
    drop(fault);

    let p = match &r {
        Err(Error::DurabilityWarning(p)) => p,
        other => panic!("expected Err(DurabilityWarning), got {other:?}"),
    };
    assert!(p.committed(), "committed=true (data IS on disk)");

    let warning = p.source();
    assert_eq!(
        warning.kind(),
        std::io::ErrorKind::StorageFull,
        "post_copy_dir warning preserves the injected ErrorKind \
         (StorageFull / ENOSPC) end-to-end; got: {:?} ({})",
        warning.kind(),
        warning
    );
    assert!(
        warning
            .to_string()
            .contains("injected fsync_dir failure"),
        "source preserves the verbatim injector message: {}",
        warning
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Arms the `fsync_dir` injector with `(1, ConnectionReset)`, runs `convert`,
/// and expects an `Error::DurabilityWarning` whose source still reports
/// `ErrorKind::ConnectionReset` and the injector's message.
#[test]
fn convert_save_warning_preserves_io_error_kind() {
    let (scratch, src, dst) = build_save_fixture("save_kind");

    let fault =
        crate::lm::load::arm_fsync_dir_fault_with_kind(1, std::io::ErrorKind::ConnectionReset);
    let r = convert(ConvertArgs {
        hf_path: src,
        mlx_path: dst,
        ..Default::default()
    });
    drop(fault);

    let p = match &r {
        Err(Error::DurabilityWarning(p)) => p,
        other => panic!("expected Err(DurabilityWarning), got {other:?}"),
    };
    assert!(p.committed(), "committed=true (save committed)");

    let warning = p.source();
    assert_eq!(
        warning.kind(),
        std::io::ErrorKind::ConnectionReset,
        "save-side warning preserves the injected ErrorKind \
         (ConnectionReset) end-to-end through \
         CommitOutcome::CommittedWithDurabilityWarning → \
         Error::DurabilityWarning → convert's committed_warning \
         stash → ConvertDurabilityWarnings.save; got: {:?} ({})",
        warning.kind(),
        warning
    );
    assert!(
        warning
            .to_string()
            .contains("injected fsync_dir failure"),
        "source preserves the verbatim injector message: {}",
        warning
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Arms the `fsync_path` injector with `(3, PermissionDenied)`, runs
/// `convert`, and checks that the resulting `DurabilityWarning` source message
/// carries the `copy_tokenizer_and_extras: fsync` tag, the `tokenizer.json`
/// file name, the full destination path, and the injector's original text,
/// while the `ErrorKind` is unchanged.
#[test]
fn convert_post_copy_file_warning_includes_destination_path() {
    let (scratch, src, dst) = build_save_fixture("post_copy_file_path_ctx");

    let fault =
        crate::lm::load::arm_fsync_path_fault_with_kind(3, std::io::ErrorKind::PermissionDenied);
    let r = convert(ConvertArgs {
        hf_path: src,
        mlx_path: dst.clone(),
        ..Default::default()
    });
    drop(fault);

    let p = match &r {
        Err(Error::DurabilityWarning(p)) => p,
        other => panic!("expected Err(DurabilityWarning), got {other:?}"),
    };
    assert!(p.committed(), "committed=true (data IS on disk)");

    let warning = p.source();
    assert_eq!(
        warning.kind(),
        std::io::ErrorKind::PermissionDenied,
        "the wrap preserves the underlying ErrorKind; \
         got: {:?} ({})",
        warning.kind(),
        warning
    );

    let msg = p.source().to_string();
    assert!(
        msg.contains("copy_tokenizer_and_extras: fsync"),
        "the wrap adds the operation tag `copy_tokenizer_and_extras: \
         fsync ...`; got: {msg}"
    );
    assert!(
        msg.contains("tokenizer.json"),
        "wrap names the destination filename (tokenizer.json); got: \
         {msg}"
    );
    let expected_dst = dst.join("tokenizer.json");
    let expected_text = expected_dst.display().to_string();
    assert!(
        msg.contains(&expected_text),
        "wrap names the full destination path ({}); got: {msg}",
        expected_dst.display()
    );
    assert!(
        msg.contains("injected fsync_path failure"),
        "wrap preserves the verbatim inner io::Error message; got: \
         {msg}"
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Arms the `remove_then_fail` variant of the `fsync_path` injector (argument
/// `3`), runs `convert`, and expects a `DurabilityWarning` whose source has
/// `ErrorKind::NotFound`, names `tokenizer.json` and its full destination
/// path, carries the `copy_tokenizer_and_extras: fsync` tag, and does NOT
/// contain the synthetic `injected fsync_path failure` marker.
#[test]
fn convert_post_copy_file_real_failure_includes_path_and_kind() {
    let (scratch, src, dst) = build_save_fixture("post_copy_file_real_fail");

    let fault = crate::lm::load::arm_fsync_path_fault_remove_then_fail(3);
    let r = convert(ConvertArgs {
        hf_path: src,
        mlx_path: dst.clone(),
        ..Default::default()
    });
    drop(fault);

    let p = match &r {
        Err(Error::DurabilityWarning(p)) => p,
        other => panic!(
            "expected Err(DurabilityWarning) carrying the real OS NotFound \
             (kind from File::open on a removed file) wrapped with the operation path + \
             operation context, got {other:?}"
        ),
    };
    assert!(
        p.committed(),
        "post-copy fsync warning carries committed=true (the file's \
         bytes reached disk via std::fs::copy BEFORE the fsync ran; \
         the durability-uncertain window is between copy returning \
         Ok and the fsync completing)"
    );

    let warning = p.source();
    assert_eq!(
        warning.kind(),
        std::io::ErrorKind::NotFound,
        "real OS failure path produces the natural File::open kind \
         (NotFound) — observed {:?} ({})",
        warning.kind(),
        warning
    );

    let msg = p.source().to_string();
    assert!(
        msg.contains("tokenizer.json"),
        "wrap names the destination filename (tokenizer.json) so the \
         caller can pinpoint WHICH copied file warned; got: {msg}"
    );
    let expected_dst = dst.join("tokenizer.json");
    let expected_text = expected_dst.display().to_string();
    assert!(
        msg.contains(&expected_text),
        "wrap names the full destination path ({}) so the caller can \
         navigate to the failing file directly; got: {msg}",
        expected_dst.display()
    );
    assert!(
        msg.contains("copy_tokenizer_and_extras: fsync"),
        "wrap adds the operation tag so a real OS error (which has no \
         path embedded — the message is OS-level text like `No such \
         file or directory (os error 2)`) can be traced back to the \
         post-copy fsync step in copy_tokenizer_and_extras; got: {msg}"
    );
    assert!(
        !msg.contains("injected fsync_path failure"),
        "this test drives the REAL OS failure path (File::open on a \
         removed file); the standard injector's synthesized marker \
         `injected fsync_path failure` must NOT appear — its presence \
         would mean the remove_then_fail injector regressed to using \
         the synthesized error path; got: {msg}"
    );

    let _ = std::fs::remove_dir_all(&scratch);
}
/// Feeds `mixed_quant_predicate` a single weight keyed
/// `model.decoder.down_proj.weight` (no numeric path segment) and expects
/// `Error::LayerKeyed` for layer `model.decoder.down_proj`, wrapping an
/// `Error::InvariantViolation` whose context mentions `down_proj` and whose
/// requirement mentions the numeric layer-index segment.
#[test]
fn mixed_quant_predicate_no_numeric_segment_is_layer_keyed_error() {
    let weights: Weights = HashMap::from([(
        "model.decoder.down_proj.weight".to_string(),
        Array::from_slice::<f32>(&[0.0_f32; 128], &(2usize, 64usize)).unwrap(),
    )]);

    // Outer layer: the error must be keyed by the offending layer path.
    let p = match mixed_quant_predicate(MixedQuantRecipe::Mixed3_6, &weights, 64) {
        Err(Error::LayerKeyed(p)) => p,
        other => panic!("expected Err(LayerKeyed), got {other:?}"),
    };
    assert_eq!(
        p.layer(),
        "model.decoder.down_proj",
        "LayerKeyed names the offending down_proj path"
    );

    // Inner layer: the wrapped cause must be an invariant violation.
    let iv = match p.inner() {
        Error::InvariantViolation(iv) => iv,
        other => panic!("expected inner InvariantViolation, got {other:?}"),
    };
    assert!(
        iv.context().contains("down_proj"),
        "inner InvariantViolation names the down_proj path context: {}",
        iv.context()
    );
    assert!(
        iv.requirement().contains("numeric layer-index segment"),
        "inner requirement quotes the numeric-segment rule: {}",
        iv.requirement()
    );
}
/// With no explicit dtype and a config string that is not valid JSON,
/// `resolve_target_dtype` must succeed and return `None`.
#[test]
fn resolve_target_dtype_unparseable_config_is_none() {
    for garbage in ["{ not json", "@@@"] {
        let resolved = resolve_target_dtype(None, garbage).unwrap();
        assert_eq!(resolved, None);
    }
}
/// With no explicit dtype, `resolve_target_dtype` must return `None` both when
/// the dtype keys hold unrecognized values (`int8`, `qint4`) and when
/// `text_config` has no dtype key at all.
#[test]
fn resolve_target_dtype_known_key_unknown_value_falls_through() {
    let configs = [
        // Dtype keys are present but their values are not recognized.
        r#"{"torch_dtype":"int8","text_config":{"dtype":"qint4"}}"#,
        // `text_config` exists but carries no dtype key.
        r#"{"text_config":{"hidden_size":16}}"#,
    ];
    for cfg in configs {
        let resolved = resolve_target_dtype(None, cfg).unwrap();
        assert_eq!(resolved, None);
    }
}
/// Passes an F32, an F16 and a U32 weight through
/// `cast_floats_to_dtype(.., Dtype::F16)` and checks the resulting dtypes:
/// F32 becomes F16, F16 stays F16, and U32 is left as U32. All three keys
/// must survive.
#[test]
fn cast_floats_to_dtype_casts_floats_passes_through_rest() {
    let weights: Weights = HashMap::from([
        (
            "a.weight".to_string(),
            Array::from_slice::<f32>(&[1.0_f32, 2.0, 3.0, 4.0], &(2usize, 2usize)).unwrap(),
        ),
        (
            "b.weight".to_string(),
            Array::from_slice::<f32>(&[5.0_f32, 6.0], &(1usize, 2usize))
                .unwrap()
                .astype(Dtype::F16)
                .unwrap(),
        ),
        (
            "c.weight".to_string(),
            Array::from_slice::<f32>(&[7.0_f32, 8.0], &(1usize, 2usize))
                .unwrap()
                .astype(Dtype::U32)
                .unwrap(),
        ),
    ]);

    let out = cast_floats_to_dtype(weights, Dtype::F16).unwrap();
    assert_eq!(out.len(), 3, "every key is preserved");

    let dtype_of = |key: &str| out.get(key).unwrap().dtype().unwrap();
    assert_eq!(
        dtype_of("a.weight"),
        Dtype::F16,
        "F32 weight cast to the F16 target"
    );
    assert_eq!(
        dtype_of("b.weight"),
        Dtype::F16,
        "already-F16 weight left at F16 (no redundant cast)"
    );
    assert_eq!(
        dtype_of("c.weight"),
        Dtype::U32,
        "non-floating U32 weight passed through untouched"
    );
}
/// `build_predicate_decisions` must return an empty map for a `group_size` of
/// `0` or `-8`, and the supplied predicate must never have been consulted
/// (`max_count()` stays at 0).
#[test]
fn build_predicate_decisions_nonpositive_group_size_returns_empty() {
    let weights: Weights = HashMap::from([(
        "layer.a.weight".to_string(),
        Array::from_slice::<f32>(&[0.0_f32; 128], &(2usize, 64usize)).unwrap(),
    )]);
    let pred = CountingPredicate::new();

    assert!(
        build_predicate_decisions(Some(&pred), &weights, 0).is_empty(),
        "group_size==0 yields an empty decision map"
    );
    assert!(
        build_predicate_decisions(Some(&pred), &weights, -8).is_empty(),
        "negative group_size yields an empty map"
    );
    assert_eq!(
        pred.max_count(),
        0,
        "the predicate is not invoked when group_size <= 0"
    );
}
/// Gives `build_predicate_decisions` four entries with `group_size = 64` and
/// expects only `layer.d` to reach the predicate and appear in the result.
///
/// NOTE(review): the per-entry skip reasons noted inline are inferred from the
/// test name and the visible keys/shapes — confirm against
/// `build_predicate_decisions`.
#[test]
fn build_predicate_decisions_skips_non_weight_and_indivisible_keys() {
    let weights: Weights = HashMap::from([
        // Key ends in `.bias`, not `.weight`.
        (
            "layer.a.bias".to_string(),
            Array::from_slice::<f32>(&[0.0_f32; 64], &(1usize, 64usize)).unwrap(),
        ),
        // Shape (2, 50): last dimension is not a multiple of 64.
        (
            "layer.b.weight".to_string(),
            Array::from_slice::<f32>(&[0.0_f32; 100], &(2usize, 50usize)).unwrap(),
        ),
        // One-dimensional shape (64,).
        (
            "layer.c.weight".to_string(),
            Array::from_slice::<f32>(&[0.0_f32; 64], &(64usize,)).unwrap(),
        ),
        // Shape (2, 64): the one entry expected to be eligible.
        (
            "layer.d.weight".to_string(),
            Array::from_slice::<f32>(&[0.0_f32; 128], &(2usize, 64usize)).unwrap(),
        ),
    ]);

    let pred = CountingPredicate::new();
    let decisions = build_predicate_decisions(Some(&pred), &weights, 64);

    assert_eq!(decisions.len(), 1, "exactly the one eligible layer");
    assert!(
        matches!(decisions.get("layer.d"), Some(Some(_))),
        "layer.d recorded with the predicate's decision"
    );
    assert_eq!(pred.paths_seen(), vec!["layer.d".to_string()]);

    ["layer.a", "layer.b", "layer.c"]
        .into_iter()
        .for_each(|skipped| {
            assert!(
                !decisions.contains_key(skipped),
                "{skipped} skipped before reaching the predicate"
            );
        });
}
/// A source config whose top-level JSON value is an array (`[1,2,3]`) must
/// make `build_quantize_config` fail with `Error::InvariantViolation`, with a
/// context mentioning `source config JSON` and the requirement
/// `must be an object`.
#[test]
fn build_quantize_config_non_object_config_is_error() {
    let weights: Weights = HashMap::new();
    let decisions: PredicateDecisions = HashMap::new();

    let outcome =
        build_quantize_config("[1,2,3]", 64, 4, QuantMode::Affine, &decisions, &weights);
    let p = match outcome {
        Err(Error::InvariantViolation(p)) => p,
        other => panic!("expected Err(InvariantViolation), got {other:?}"),
    };
    assert!(
        p.context().contains("source config JSON"),
        "context names the source config: {}",
        p.context()
    );
    assert_eq!(p.requirement(), "must be an object");
}
/// With a global default of (group_size 64, bits 4, affine) and a decision for
/// `layer.a` that differs in `bits` (8), `build_quantize_config` must:
///
/// * report the global default in `live.quantization`,
/// * record a `Quantize` override for `layer.a`, and
/// * emit JSON whose `quantization` block holds the global keys plus a nested
///   `layer.a` object carrying the override values.
#[test]
fn build_quantize_config_quantize_override_emits_nested_object() {
    let weights: Weights = HashMap::from([(
        "layer.a.weight".to_string(),
        Array::from_slice::<f32>(&[0.0_f32; 128], &(2usize, 64usize)).unwrap(),
    )]);
    let q_diff = Quantization {
        group_size: 64,
        bits: 8,
        mode: QuantMode::Affine,
    };
    let decisions: PredicateDecisions = HashMap::from([("layer.a".to_string(), Some(q_diff))]);

    let (live, text) =
        build_quantize_config("{}", 64, 4, QuantMode::Affine, &decisions, &weights).unwrap();

    // In-memory view.
    let expected_global = Quantization {
        group_size: 64,
        bits: 4,
        mode: QuantMode::Affine,
    };
    assert_eq!(
        live.quantization,
        Some(expected_global),
        "global default is the call's (group_size, bits, mode)"
    );
    assert_eq!(
        live.per_layer_ref().get("layer.a"),
        Some(&QuantizationOption::Quantize(q_diff)),
        "layer.a carries the differing-params Quantize override"
    );

    // Serialized view: global keys at the top of the block, override nested.
    let parsed: serde_json::Value = serde_json::from_str(&text).unwrap();
    let block = parsed.get("quantization").unwrap();
    assert_eq!(
        block.get("group_size").and_then(serde_json::Value::as_i64),
        Some(64)
    );
    assert_eq!(
        block.get("bits").and_then(serde_json::Value::as_i64),
        Some(4)
    );
    assert_eq!(
        block.get("mode").and_then(serde_json::Value::as_str),
        Some("affine")
    );

    let nested = block.get("layer.a").expect("per-layer override emitted");
    assert_eq!(
        nested.get("bits").and_then(serde_json::Value::as_i64),
        Some(8),
        "nested override carries the differing bits=8"
    );
    assert_eq!(
        nested.get("group_size").and_then(serde_json::Value::as_i64),
        Some(64)
    );
    assert_eq!(
        nested.get("mode").and_then(serde_json::Value::as_str),
        Some("affine")
    );
}
/// When the only decision equals the global default and the source config is
/// `{}`, `build_quantize_config` must record no per-layer override and emit a
/// `quantization` block containing exactly `group_size`, `bits` and `mode`.
#[test]
fn build_quantize_config_equal_default_no_override_when_not_fine_grained() {
    let weights: Weights = HashMap::from([(
        "layer.a.weight".to_string(),
        Array::from_slice::<f32>(&[0.0_f32; 128], &(2usize, 64usize)).unwrap(),
    )]);
    let global = Quantization {
        group_size: 64,
        bits: 4,
        mode: QuantMode::Affine,
    };
    let decisions: PredicateDecisions = HashMap::from([("layer.a".to_string(), Some(global))]);

    let (live, text) =
        build_quantize_config("{}", 64, 4, QuantMode::Affine, &decisions, &weights).unwrap();
    assert!(
        live.per_layer_ref().is_empty(),
        "no override when the decision equals the global and config is not fine-grained"
    );

    let parsed: serde_json::Value = serde_json::from_str(&text).unwrap();
    let block = parsed.get("quantization").unwrap().as_object().unwrap();
    assert_eq!(
        block.len(),
        3,
        "block carries only the global keys: {block:?}"
    );
    assert!(block.contains_key("group_size"));
    assert!(block.contains_key("bits"));
    assert!(block.contains_key("mode"));
}
/// With a source config that already contains a `quantization` block,
/// `build_quantize_config` must write an override for every decided layer:
/// `keep` (decision equal to the global) as a `Quantize` entry / nested JSON
/// object, and `drop` (decision `None`) as `Skip` / literal JSON `false`.
#[test]
fn build_quantize_config_fine_grained_writes_skip_and_default_overrides() {
    let weights: Weights = ["keep", "drop"]
        .into_iter()
        .map(|path| {
            (
                format!("{path}.weight"),
                Array::from_slice::<f32>(&[0.0_f32; 128], &(2usize, 64usize)).unwrap(),
            )
        })
        .collect();
    let global = Quantization {
        group_size: 64,
        bits: 4,
        mode: QuantMode::Affine,
    };
    let decisions: PredicateDecisions = HashMap::from([
        ("keep".to_string(), Some(global)),
        ("drop".to_string(), None),
    ]);

    let src = r#"{"quantization":{"group_size":64,"bits":4}}"#;
    let (live, text) =
        build_quantize_config(src, 64, 4, QuantMode::Affine, &decisions, &weights).unwrap();

    // In-memory overrides.
    let per_layer = live.per_layer_ref();
    assert_eq!(
        per_layer.get("keep"),
        Some(&QuantizationOption::Quantize(global)),
        "equal-to-global decision still written under fine_grained"
    );
    assert_eq!(
        per_layer.get("drop"),
        Some(&QuantizationOption::Skip),
        "None decision written as Skip under fine_grained"
    );

    // Serialized overrides.
    let parsed: serde_json::Value = serde_json::from_str(&text).unwrap();
    let block = parsed.get("quantization").unwrap();
    assert!(
        block.get("keep").is_some_and(|v| v.is_object()),
        "keep emitted as a nested object: {block:?}"
    );
    assert_eq!(
        block.get("drop"),
        Some(&serde_json::Value::Bool(false)),
        "Skip emitted as a literal false (BaseConfiguration shape)"
    );
}
/// Only a `.weight` key that has a predicate decision may produce a per-layer
/// override; a `.bias` key and a weight without a decision must not.
#[test]
fn build_quantize_config_skips_non_weight_and_undecided_keys() {
    let weights: Weights = HashMap::from([
        // Not a `.weight` entry.
        (
            "layer.bias".to_string(),
            Array::from_slice::<f32>(&[0.0_f32; 64], &(1usize, 64usize)).unwrap(),
        ),
        // A weight the decisions map says nothing about.
        (
            "layer.undecided.weight".to_string(),
            Array::from_slice::<f32>(&[0.0_f32; 128], &(2usize, 64usize)).unwrap(),
        ),
        // The single weight with a recorded decision.
        (
            "layer.decided.weight".to_string(),
            Array::from_slice::<f32>(&[0.0_f32; 128], &(2usize, 64usize)).unwrap(),
        ),
    ]);

    // 8 bits, whereas the builder is called with 4 bits below.
    let q = Quantization {
        group_size: 64,
        bits: 8,
        mode: QuantMode::Affine,
    };
    let decisions: PredicateDecisions = HashMap::from([("layer.decided".to_string(), Some(q))]);

    let (live, _text) =
        build_quantize_config("{}", 64, 4, QuantMode::Affine, &decisions, &weights).unwrap();

    assert_eq!(
        live.per_layer_ref().len(),
        1,
        "only the decided key produced an override"
    );
    assert_eq!(
        live.per_layer_ref().get("layer.decided"),
        Some(&QuantizationOption::Quantize(q))
    );
    assert!(!live.per_layer_ref().contains_key("layer.bias"));
    assert!(!live.per_layer_ref().contains_key("layer.undecided"));
}
/// A source config whose JSON top level is an array must be rejected with
/// `Error::InvariantViolation` naming the source config and the
/// "must be an object" requirement.
#[test]
fn strip_quantization_blocks_non_object_config_is_error() {
    let result = strip_quantization_blocks("[\"not\",\"an\",\"object\"]");

    // Anything other than the expected error variant fails the test outright.
    let p = match result {
        Err(Error::InvariantViolation(p)) => p,
        other => panic!("expected Err(InvariantViolation), got {other:?}"),
    };

    assert!(
        p.context().contains("source config JSON"),
        "context names the source config: {}",
        p.context()
    );
    assert_eq!(p.requirement(), "must be an object");
}
/// `copy_tokenizer_and_extras` must copy regular `*.py` files byte-for-byte
/// while leaving behind a non-`.py` file and a directory named `package.py`.
#[test]
fn copy_tokenizer_and_extras_copies_py_skips_non_py_and_dirs() {
    // Per-process scratch directory; clear any leftover from an earlier run.
    let dir = std::env::temp_dir().join(format!("mlxrs_convert_pyglob_{}", std::process::id()));
    let _ = std::fs::remove_dir_all(&dir);
    let src = dir.join("src");
    let dst = dir.join("dst");
    for sub in [&src, &dst] {
        std::fs::create_dir_all(sub).unwrap();
    }

    // Source fixtures: two `.py` files, one non-`.py` file, and a directory
    // whose name ends in `.py`.
    for (file_name, contents) in [
        ("modeling_custom.py", b"# hf model code\n".as_slice()),
        ("configuration_custom.py", b"# hf config code\n".as_slice()),
        ("README.md", b"readme\n".as_slice()),
    ] {
        std::fs::write(src.join(file_name), contents).unwrap();
    }
    std::fs::create_dir_all(src.join("package.py")).unwrap();

    let outcome = copy_tokenizer_and_extras(&src, &dst).unwrap();
    assert!(
        matches!(outcome, CopyOutcome::Committed),
        "no fsync faults → fully-durable Committed; got {outcome:?}"
    );

    // Both `.py` files arrived with identical bytes.
    for name in ["modeling_custom.py", "configuration_custom.py"] {
        assert!(dst.join(name).is_file(), "{name} copied via the *.py glob");
        let original = std::fs::read(src.join(name)).unwrap();
        let copied = std::fs::read(dst.join(name)).unwrap();
        assert_eq!(original, copied, "{name} byte-equal at dst");
    }

    // Nothing else was carried over.
    assert!(
        !dst.join("README.md").is_file(),
        "non-.py regular file is not copied by the glob"
    );
    assert!(
        !dst.join("package.py").exists(),
        ".py-named sub-directory is skipped (not a file)"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// Pointing `copy_tokenizer_and_extras` at a source directory that was never
/// created must surface `Error::FileIo` with a `Read` op, the source path, and
/// a context naming the helper.
#[test]
fn copy_tokenizer_and_extras_missing_src_dir_is_read_error() {
    let dir =
        std::env::temp_dir().join(format!("mlxrs_convert_missing_src_{}", std::process::id()));
    let _ = std::fs::remove_dir_all(&dir);

    // Only the destination is created; the source is deliberately absent.
    let src = dir.join("does_not_exist_src");
    let dst = dir.join("dst");
    std::fs::create_dir_all(&dst).unwrap();

    let p = match copy_tokenizer_and_extras(&src, &dst) {
        Err(Error::FileIo(p)) => p,
        other => panic!("expected Err(FileIo{{op:Read}}), got {other:?}"),
    };
    assert_eq!(p.op(), FileOp::Read, "read_dir failure is a Read op");
    assert_eq!(p.path(), src.as_path(), "the failing path is src");
    assert!(
        p.context().contains("copy_tokenizer_and_extras"),
        "context names the helper: {}",
        p.context()
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// For paths under a scratch directory this test never creates,
/// `paths_are_same` must report two different names as unequal and a path as
/// equal to itself.
#[test]
fn paths_are_same_fallback_textual_compare_when_uncanonicalizable() {
    // Nothing is written to disk here; only the path values matter.
    let base =
        std::env::temp_dir().join(format!("mlxrs_convert_paths_same_{}", std::process::id()));
    let a = base.join("nonexistent_a");
    let b = base.join("nonexistent_b");

    let distinct_match = paths_are_same(&a, &b);
    assert!(
        !distinct_match,
        "distinct uncanonicalizable paths compare unequal via the fallback"
    );

    let self_match = paths_are_same(&a, &a);
    assert!(
        self_match,
        "identical uncanonicalizable path compares equal via the fallback"
    );
}
/// End-to-end `convert` run with a predicate that returns a quantization for
/// every layer: the rank-2 projection weight must come back with `.scales`,
/// while the rank-1 norm weight must survive without any.
#[test]
fn convert_predicate_with_ineligible_weight_skips_via_true_none_arm() {
    // Predicate that answers `Quantization::affine(64, 4)` regardless of the
    // layer name or weight it is given.
    struct QuantizeAll;
    impl MixedQuantPredicate for QuantizeAll {
        fn decide(&self, _layer_name: &str, _weight: &Array) -> Option<Quantization> {
            Some(Quantization::affine(64, 4))
        }
    }

    // Per-process scratch tree; clear any leftover from an earlier run.
    let dir = std::env::temp_dir().join(format!(
        "mlxrs_convert_true_none_arm_{}",
        std::process::id()
    ));
    let _ = std::fs::remove_dir_all(&dir);
    let src = dir.join("src");
    let dst = dir.join("dst");
    std::fs::create_dir_all(&src).unwrap();

    // Source config without any quantization block.
    let plain_config = r#"{
"model_type":"qwen3","hidden_size":16,"num_hidden_layers":1,
"num_attention_heads":2,"num_key_value_heads":2,"head_dim":8,
"rope_theta":10000.0,"vocab_size":128,"tie_word_embeddings":false
}"#;
    std::fs::write(src.join("config.json"), plain_config).unwrap();

    // One rank-2 weight (2x64) and one rank-1 weight (64).
    let blob: Vec<f32> = (0..128).map(|i| (i as f32) * 0.01).collect();
    let weights: Weights = HashMap::from([
        (
            "model.layers.0.self_attn.q_proj.weight".to_string(),
            Array::from_slice::<f32>(&blob, &(2usize, 64usize)).unwrap(),
        ),
        (
            "model.norm.weight".to_string(),
            Array::from_slice::<f32>(&[0.0_f32; 64], &(64usize,)).unwrap(),
        ),
    ]);
    crate::io::save_safetensors(&src.join("model.safetensors"), &weights).unwrap();

    // Tokenizer fixtures are embedded at compile time and written next to the
    // weights in the source directory.
    for (file_name, contents) in [
        (
            "tokenizer.json",
            include_str!("../../tests/fixtures/tokenizer.json"),
        ),
        (
            "tokenizer_config.json",
            include_str!("../../tests/fixtures/tokenizer_config.json"),
        ),
    ] {
        std::fs::write(src.join(file_name), contents).unwrap();
    }

    let args = ConvertArgs {
        hf_path: src,
        mlx_path: dst.clone(),
        quantize: true,
        q_bits: Some(4),
        q_group_size: Some(64),
        q_mode: QuantMode::Affine,
        quant_predicate: Some(Box::new(QuantizeAll)),
        ..Default::default()
    };
    convert(args).unwrap();

    // Inspect what was actually written to the destination.
    let reloaded = load::load_weights(&dst).unwrap();
    assert!(
        reloaded.contains_key("model.layers.0.self_attn.q_proj.scales"),
        "eligible rank-2 layer quantized"
    );
    assert!(
        reloaded.contains_key("model.norm.weight"),
        "ineligible rank-1 weight still present"
    );
    assert!(
        !reloaded.contains_key("model.norm.scales"),
        "ineligible rank-1 weight NOT quantized (skipped via the (true, None) arm)"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
}