use super::*;
use crate::Dtype;
/// Dense base weight of shape (2, 3) with rows `[1, 0, 0]` and `[0, 1, 0]`.
fn base_weight() -> Array {
    const VALUES: [f32; 6] = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
    Array::from_slice::<f32>(&VALUES, &(2, 3)).unwrap()
}

/// LoRA down-projection factor of shape (3, 2) with rows `[1, 0]`, `[0, 1]`, `[0, 0]`.
fn lora_a() -> Array {
    const VALUES: [f32; 6] = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0];
    Array::from_slice::<f32>(&VALUES, &(3, 2)).unwrap()
}

/// LoRA up-projection factor: the 2×2 identity.
fn lora_b() -> Array {
    const VALUES: [f32; 4] = [1.0, 0.0, 0.0, 1.0];
    Array::from_slice::<f32>(&VALUES, &(2, 2)).unwrap()
}

/// Plain LoRA factors built from [`lora_a`] and [`lora_b`], with no DoRA
/// magnitude vector.
fn plain_params() -> AdapterParams {
    let down = lora_a();
    let up = lora_b();
    AdapterParams {
        lora_a: down,
        lora_b: up,
        magnitude: None,
    }
}
/// Build an mlx-lm-native plain-LoRA (non-DoRA) [`LoraConfig`] that selects
/// layers through the mlx-lm `num_layers` window.
fn mlxlm_config(num_layers: i32, lora_parameters: LoraParameters) -> LoraConfig {
    let selection = AdapterSelection::MlxLm { num_layers };
    LoraConfig {
        fine_tune_type: FineTuneType::Lora,
        use_dora: false,
        lora_parameters,
        selection,
    }
}

/// Extract `num_layers` from an mlx-lm-native config.
///
/// # Panics
/// If the config selects layers via PEFT instead.
fn mlxlm_num_layers(cfg: &LoraConfig) -> i32 {
    let AdapterSelection::MlxLm { num_layers } = &cfg.selection else {
        panic!("expected an mlx-lm-native config, got PEFT");
    };
    *num_layers
}

/// Rank-2 LoRA parameters with a literal `scale` of 2.0 (no `alpha`, no
/// dropout) that target the given module `keys`.
fn keyed_params(keys: Vec<String>) -> LoraParameters {
    LoraParameters {
        keys,
        rank: 2,
        scale: Some(2.0),
        alpha: None,
        dropout: None,
    }
}
/// Assert that `a` and `b` have the same length and agree element-wise to
/// within the absolute tolerance `tol`.
///
/// Elements that compare exactly equal (including equal infinities, whose
/// difference would be NaN) are accepted. A NaN on either side is always
/// rejected. `#[track_caller]` makes a failure report the calling test's line
/// rather than this helper's.
///
/// # Panics
/// On a length mismatch, or at the first element pair outside `tol`.
#[track_caller]
fn approx_eq(a: &[f32], b: &[f32], tol: f32) {
    assert_eq!(a.len(), b.len(), "length mismatch: {a:?} vs {b:?}");
    for (i, (&x, &y)) in a.iter().zip(b).enumerate() {
        // `x == y` short-circuits matching infinities; NaN fails both checks.
        assert!(
            x == y || (x - y).abs() <= tol,
            "element {i}: ‖{x} - {y}‖ > {tol} ({a:?} vs {b:?})"
        );
    }
}
/// Hand-traced LoRA forward pass with x = `[1, 2, 3]`.
///
/// The base path x·Wᵀ is `[1, 2]`. The adapter path x·A·B is `[1, 2]`, which
/// the scale of 2 turns into `[2, 4]`. The expected total is `[3, 6]`.
#[test]
fn lora_linear_forward_hand_traced() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let dense = BaseLinear::dense(base_weight(), None).unwrap();
    let lora = LoRALinear::new(dense, plain_params(), 2.0).unwrap();
    let mut y = lora.forward(&x).unwrap();
    let actual = y.to_vec::<f32>().unwrap();
    approx_eq(&actual, &[3.0, 6.0], 1e-5);
}

/// With a base bias of `[10, 20]`, the hand-traced output `[3, 6]` becomes
/// `[13, 26]`.
#[test]
fn lora_linear_forward_with_bias() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let bias = Array::from_slice::<f32>(&[10.0, 20.0], &(2usize,)).unwrap();
    let dense = BaseLinear::dense(base_weight(), Some(bias)).unwrap();
    let lora = LoRALinear::new(dense, plain_params(), 2.0).unwrap();
    let mut y = lora.forward(&x).unwrap();
    let actual = y.to_vec::<f32>().unwrap();
    approx_eq(&actual, &[13.0, 26.0], 1e-5);
}

/// An all-zero `lora_b` contributes nothing, even with a large scale (20).
/// The output therefore equals the base path x·Wᵀ = `[1, 2]`.
#[test]
fn lora_linear_zero_b_is_identity() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let params = AdapterParams {
        lora_a: lora_a(),
        lora_b: Array::zeros::<f32>(&(2, 2)).unwrap(),
        magnitude: None,
    };
    let dense = BaseLinear::dense(base_weight(), None).unwrap();
    let lora = LoRALinear::new(dense, params, 20.0).unwrap();
    let mut y = lora.forward(&x).unwrap();
    let actual = y.to_vec::<f32>().unwrap();
    approx_eq(&actual, &[1.0, 2.0], 1e-5);
}
/// Fusing the adapter into a bias-free dense base (`fuse(false)`) must
/// reproduce the unfused forward output.
#[test]
fn lora_fuse_matches_forward() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let dense = BaseLinear::dense(base_weight(), None).unwrap();
    let lora = LoRALinear::new(dense, plain_params(), 2.0).unwrap();
    // Run the unfused path first; `fuse` is called on `lora` afterwards.
    let mut unfused = lora.forward(&x).unwrap();
    let fused_base = lora.fuse(false).unwrap();
    let mut fused = fused_base.base_output(&x).unwrap();
    let (got, want) = (
        fused.to_vec::<f32>().unwrap(),
        unfused.to_vec::<f32>().unwrap(),
    );
    approx_eq(&got, &want, 1e-5);
}

/// Same fused-vs-unfused equivalence when the dense base carries a bias.
#[test]
fn lora_fuse_with_bias_matches_forward() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let bias = Array::from_slice::<f32>(&[10.0, 20.0], &(2usize,)).unwrap();
    let dense = BaseLinear::dense(base_weight(), Some(bias)).unwrap();
    let lora = LoRALinear::new(dense, plain_params(), 2.0).unwrap();
    let mut unfused = lora.forward(&x).unwrap();
    let fused_base = lora.fuse(false).unwrap();
    let mut fused = fused_base.base_output(&x).unwrap();
    let (got, want) = (
        fused.to_vec::<f32>().unwrap(),
        unfused.to_vec::<f32>().unwrap(),
    );
    approx_eq(&got, &want, 1e-5);
}
/// Hand-traced DoRA forward pass with magnitude m = `[3, 3]`.
///
/// The expected values follow from the adapted weight W + 2·(A·B)ᵀ, whose
/// rows are `[3, 0, 0]` and `[0, 3, 0]` (row norms 3). That gives
/// m / ‖row‖ = 1, so the output equals the LoRA trace `[3, 6]`.
#[test]
fn dora_linear_forward_hand_traced() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let magnitude = Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap();
    let params = AdapterParams {
        lora_a: lora_a(),
        lora_b: lora_b(),
        magnitude: Some(magnitude),
    };
    let dense = BaseLinear::dense(base_weight(), None).unwrap();
    let dora = DoRALinear::new(dense, params, 2.0).unwrap();
    let mut y = dora.forward(&x).unwrap();
    let actual = y.to_vec::<f32>().unwrap();
    approx_eq(&actual, &[3.0, 6.0], 1e-5);
}

/// Halving the magnitude to `[1.5, 1.5]` halves the hand-traced output to
/// `[1.5, 3]`.
#[test]
fn dora_linear_forward_renorm_halves() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let magnitude = Array::from_slice::<f32>(&[1.5, 1.5], &(2usize,)).unwrap();
    let params = AdapterParams {
        lora_a: lora_a(),
        lora_b: lora_b(),
        magnitude: Some(magnitude),
    };
    let dense = BaseLinear::dense(base_weight(), None).unwrap();
    let dora = DoRALinear::new(dense, params, 2.0).unwrap();
    let mut y = dora.forward(&x).unwrap();
    let actual = y.to_vec::<f32>().unwrap();
    approx_eq(&actual, &[1.5, 3.0], 1e-5);
}

/// Fusing a DoRA layer with an uneven magnitude `[1.5, 2.5]` must reproduce
/// its forward output to within 1e-4.
#[test]
fn dora_fuse_matches_forward() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let magnitude = Array::from_slice::<f32>(&[1.5, 2.5], &(2usize,)).unwrap();
    let params = AdapterParams {
        lora_a: lora_a(),
        lora_b: lora_b(),
        magnitude: Some(magnitude),
    };
    let dense = BaseLinear::dense(base_weight(), None).unwrap();
    let dora = DoRALinear::new(dense, params, 2.0).unwrap();
    let mut unfused = dora.forward(&x).unwrap();
    let fused_base = dora.fuse(false).unwrap();
    let mut fused = fused_base.base_output(&x).unwrap();
    let (got, want) = (
        fused.to_vec::<f32>().unwrap(),
        unfused.to_vec::<f32>().unwrap(),
    );
    approx_eq(&got, &want, 1e-4);
}

/// `DoRALinear::new` rejects factors without a magnitude vector, returning
/// `Error::MissingField`.
#[test]
fn dora_requires_magnitude() {
    let dense = BaseLinear::dense(base_weight(), None).unwrap();
    let outcome = DoRALinear::new(dense, plain_params(), 2.0);
    let err = outcome.unwrap_err();
    assert!(matches!(err, Error::MissingField(_)));
}
/// QLoRA: a LoRA layer over an 8-bit affine-quantized base (group size 32)
/// must match the same adapter over the dense base to within 1e-2.
#[test]
fn qlora_forward_matches_dense_within_quant_error() {
    let (input_dims, output_dims) = (64usize, 2usize);
    // Row 0 is all ones, row 1 all halves.
    let wdata: Vec<f32> = [vec![1.0f32; input_dims], vec![0.5f32; input_dims]].concat();
    let dense_w = Array::from_slice::<f32>(&wdata, &(output_dims, input_dims)).unwrap();
    let params = AdapterParams {
        lora_a: Array::full::<f32>(&(input_dims, 2usize), 0.01).unwrap(),
        lora_b: Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap(),
        magnitude: None,
    };
    let x = Array::full::<f32>(&(1usize, input_dims), 1.0).unwrap();

    // Reference: the same adapter over the unquantized weight.
    let reference_base = BaseLinear::dense(dense_w.try_clone().unwrap(), None).unwrap();
    let reference = LoRALinear::new(reference_base, params.try_clone().unwrap(), 2.0).unwrap();
    let mut expected = reference.forward(&x).unwrap();

    let (w_q, scales, biases) = ops::quantized::quantize(&dense_w, 32, 8, "affine", None).unwrap();
    let quantized_base =
        BaseLinear::quantized(w_q, scales, biases, None, 32, 8, "affine".to_string()).unwrap();
    let quantized = LoRALinear::new(quantized_base, params, 2.0).unwrap();
    let mut actual = quantized.forward(&x).unwrap();

    let (got, want) = (
        actual.to_vec::<f32>().unwrap(),
        expected.to_vec::<f32>().unwrap(),
    );
    approx_eq(&got, &want, 1e-2);
}

/// Fusing a QLoRA layer with `fuse(true)` must yield a dense base
/// (`BaseLinear::Dense`) whose output matches the unfused forward to within
/// 1e-2.
#[test]
fn qlora_fuse_dequantize_matches_forward() {
    let (input_dims, output_dims) = (64usize, 2usize);
    let wdata: Vec<f32> = [vec![1.0f32; input_dims], vec![0.5f32; input_dims]].concat();
    let dense_w = Array::from_slice::<f32>(&wdata, &(output_dims, input_dims)).unwrap();
    let params = AdapterParams {
        lora_a: Array::full::<f32>(&(input_dims, 2usize), 0.01).unwrap(),
        lora_b: Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap(),
        magnitude: None,
    };
    let x = Array::full::<f32>(&(1usize, input_dims), 1.0).unwrap();

    let (w_q, scales, biases) = ops::quantized::quantize(&dense_w, 32, 8, "affine", None).unwrap();
    let quantized_base =
        BaseLinear::quantized(w_q, scales, biases, None, 32, 8, "affine".to_string()).unwrap();
    let quantized = LoRALinear::new(quantized_base, params, 2.0).unwrap();
    let mut unfused = quantized.forward(&x).unwrap();

    let fused_base = quantized.fuse(true).unwrap();
    assert!(matches!(fused_base, BaseLinear::Dense { .. }));
    let mut fused = fused_base.base_output(&x).unwrap();

    let (got, want) = (
        fused.to_vec::<f32>().unwrap(),
        unfused.to_vec::<f32>().unwrap(),
    );
    approx_eq(&got, &want, 1e-2);
}
/// mlx-lm-native config: `fine_tune_type`, `num_layers`, and nested
/// `lora_parameters` with a literal `scale` are read as-is.
#[test]
fn config_parse_lora_basic() {
    let json = r#"{
"fine_tune_type": "lora",
"num_layers": 4,
"lora_parameters": { "rank": 16, "scale": 20.0 }
}"#;
    let parsed = LoraConfig::from_json(json).unwrap();
    assert_eq!(parsed.fine_tune_type, FineTuneType::Lora);
    assert_eq!(mlxlm_num_layers(&parsed), 4);
    assert_eq!(parsed.rank(), 16);
    assert_eq!(parsed.scale(), 20.0);
    assert!(!parsed.is_dora());
}

/// A flat HuggingFace PEFT config (top-level `r`, `lora_alpha`,
/// `target_modules`, `lora_dropout`) maps onto rank, scale = alpha / r = 2,
/// dropout, and a PEFT module selection with no mlx-lm keys.
#[test]
fn config_parse_peft_flat_shape() {
    let json = r#"{
"peft_type": "LORA",
"r": 16,
"lora_alpha": 32.0,
"target_modules": ["q_proj", "v_proj"],
"lora_dropout": 0.05,
"bias": "none"
}"#;
    let parsed = LoraConfig::from_json(json).unwrap();
    assert_eq!(parsed.rank(), 16, "PEFT top-level `r` must populate `rank`");
    assert_eq!(parsed.scale(), 2.0);
    assert!(parsed.lora_parameters.keys_slice().is_empty());

    let peft_sel = parsed
        .peft()
        .expect("PEFT config must carry a PeftSelection");
    match &peft_sel.target_modules {
        Some(ModuleMatcher::List(names)) => {
            let expected = ["q_proj".to_string(), "v_proj".to_string()];
            assert_eq!(names, &expected);
        }
        other => panic!("expected a target_modules List, got {other:?}"),
    }

    assert_eq!(parsed.lora_parameters.dropout, Some(0.05));
    assert_eq!(parsed.fine_tune_type, FineTuneType::Lora);
    assert!(
        matches!(parsed.selection, AdapterSelection::Peft(_)),
        "a PEFT config must select via PeftSelection, never the mlx-lm num_layers window"
    );
    assert!(!parsed.is_dora());
}

/// PEFT `use_dora: true` selects DoRA; scale is still alpha / r = 16 / 8 = 2.
#[test]
fn config_parse_peft_use_dora() {
    let json = r#"{
"peft_type": "LORA",
"r": 8,
"lora_alpha": 16.0,
"target_modules": ["q_proj"],
"use_dora": true
}"#;
    let parsed = LoraConfig::from_json(json).unwrap();
    assert!(parsed.is_dora(), "PEFT `use_dora` must select DoRA");
    assert_eq!(parsed.scale(), 2.0);
}

/// The PEFT shape is recognized from its keys even without `peft_type`.
#[test]
fn config_parse_peft_no_peft_type_still_detected() {
    let json = r#"{
"r": 4,
"lora_alpha": 8.0,
"target_modules": ["o_proj"]
}"#;
    let parsed = LoraConfig::from_json(json).unwrap();
    assert_eq!(parsed.rank(), 4);
    assert_eq!(parsed.scale(), 2.0);
    let peft_sel = parsed
        .peft()
        .expect("PEFT shape detected ⇒ PeftSelection present");
    let only_o_proj = ["o_proj".to_string()];
    assert!(matches!(
        &peft_sel.target_modules,
        Some(ModuleMatcher::List(n)) if n == &only_o_proj
    ));
}

/// A PEFT config without `r` falls back to `DEFAULT_LORA_RANK`.
#[test]
fn config_parse_peft_default_rank_when_r_absent() {
    let json = r#"{ "peft_type": "LORA", "lora_alpha": 16.0, "target_modules": ["q_proj"] }"#;
    let parsed = LoraConfig::from_json(json).unwrap();
    assert_eq!(parsed.rank(), DEFAULT_LORA_RANK);
}

/// Non-LoRA PEFT methods are rejected at parse time.
#[test]
fn config_parse_peft_non_lora_peft_type_is_err() {
    let rejected = ["LOHA", "LOKR", "IA3", "PROMPT_TUNING"];
    for kind in rejected {
        let json = format!(r#"{{ "peft_type": "{kind}", "r": 8, "target_modules": ["q_proj"] }}"#);
        let outcome = LoraConfig::from_json(&json);
        assert!(outcome.is_err(), "peft_type {kind:?} must be rejected");
    }
}
/// `peft_type` is matched case-insensitively (`"Lora"` is accepted).
#[test]
fn config_parse_peft_type_case_insensitive() {
    let json = r#"{ "peft_type": "Lora", "r": 8, "lora_alpha": 16.0, "target_modules": ["q_proj"] }"#;
    let parsed = LoraConfig::from_json(json).unwrap();
    assert_eq!(parsed.rank(), 8);
}

/// A string-valued `target_modules` is compiled as a regex over module paths.
#[test]
fn config_parse_peft_target_modules_regex() {
    let json = r#"{ "peft_type": "LORA", "r": 8, "target_modules": ".*\\.(q|v)_proj" }"#;
    let parsed = LoraConfig::from_json(json).unwrap();
    let peft_sel = parsed.peft().unwrap();
    let target = match &peft_sel.target_modules {
        Some(ModuleMatcher::Regex(re)) => re,
        other => panic!("expected a target_modules Regex, got {other:?}"),
    };
    for hit in [
        "model.layers.0.self_attn.q_proj",
        "model.layers.7.self_attn.v_proj",
    ] {
        assert!(target.is_match(hit));
    }
    assert!(!target.is_match("model.layers.0.self_attn.k_proj"));
}

/// A `target_modules` regex that fails to compile makes parsing fail.
#[test]
fn config_parse_peft_invalid_regex_target_modules_is_err() {
    let json = r#"{ "peft_type": "LORA", "r": 8, "target_modules": "(unclosed" }"#;
    let outcome = LoraConfig::from_json(json);
    assert!(
        outcome.is_err(),
        "an uncompilable `target_modules` regex must be rejected"
    );
}

/// When a config has nested `lora_parameters`, stray flat PEFT keys
/// (`r`, `lora_alpha`) are ignored.
#[test]
fn config_lora_parameters_nesting_wins_over_flat_keys() {
    let json = r#"{
"fine_tune_type": "lora",
"num_layers": 3,
"lora_parameters": { "rank": 64, "scale": 8.0 },
"r": 1, "lora_alpha": 999.0
}"#;
    let parsed = LoraConfig::from_json(json).unwrap();
    assert_eq!(parsed.rank(), 64, "nested `lora_parameters.rank` wins");
    assert_eq!(
        parsed.scale(),
        8.0,
        "nested literal `scale` wins, flat keys ignored"
    );
    assert_eq!(mlxlm_num_layers(&parsed), 3);
}

/// `fine_tune_type: "dora"` selects DoRA, and `alpha` resolves the scale as
/// 32 / 8 = 4.
#[test]
fn config_parse_dora_and_alpha_scale() {
    let json = r#"{
"fine_tune_type": "dora",
"num_layers": 2,
"lora_parameters": { "rank": 8, "alpha": 32.0 }
}"#;
    let parsed = LoraConfig::from_json(json).unwrap();
    assert!(parsed.is_dora());
    assert_eq!(parsed.scale(), 4.0);
}

/// The `use_dora` flag turns an otherwise-`lora` config into DoRA.
#[test]
fn config_use_dora_flag() {
    let json = r#"{
"fine_tune_type": "lora",
"use_dora": true,
"lora_parameters": { "rank": 8, "scale": 10.0 }
}"#;
    let parsed = LoraConfig::from_json(json).unwrap();
    assert!(parsed.is_dora());
}

/// A config with only unrelated keys parses to the crate defaults.
#[test]
fn config_defaults_and_unknown_keys_ignored() {
    let json = r#"{ "optimizer": "adam", "learning_rate": 1e-4 }"#;
    let parsed = LoraConfig::from_json(json).unwrap();
    assert_eq!(parsed.fine_tune_type, FineTuneType::Lora);
    assert_eq!(mlxlm_num_layers(&parsed), DEFAULT_NUM_LAYERS);
    assert_eq!(parsed.rank(), DEFAULT_LORA_RANK);
    assert_eq!(parsed.scale(), DEFAULT_LORA_SCALE);
}

/// An unrecognized `fine_tune_type` is a parse error.
#[test]
fn config_unknown_fine_tune_type_is_err() {
    let json = r#"{ "fine_tune_type": "bogus" }"#;
    let outcome = LoraConfig::from_json(json);
    assert!(outcome.is_err());
}
/// `path_matches_key` accepts a key equal to the path or equal to its
/// trailing dotted components. It rejects a different module (`k_proj` vs
/// `q_proj`) and a bare textual suffix (`model.xq_proj` vs `q_proj`).
#[test]
fn path_key_matching() {
    let matching = [
        ("model.layers.27.self_attn.q_proj", "self_attn.q_proj"),
        ("self_attn.q_proj", "self_attn.q_proj"),
    ];
    for (path, key) in matching {
        assert!(path_matches_key(path, key), "{path} should match {key}");
    }

    let non_matching = [
        ("model.layers.27.self_attn.k_proj", "q_proj"),
        ("model.xq_proj", "q_proj"),
    ];
    for (path, key) in non_matching {
        assert!(!path_matches_key(path, key), "{path} should not match {key}");
    }
}

/// `parse_block_index` extracts `N` from `model.layers.N.…`. It returns `None`
/// for paths outside the transformer blocks.
#[test]
fn block_index_parsing() {
    let cases = [
        ("model.layers.27.self_attn.q_proj", Some(27)),
        ("model.layers.0.mlp.down_proj", Some(0)),
        ("model.embed_tokens", None),
        ("lm_head", None),
    ];
    for (path, expected) in cases {
        assert_eq!(parse_block_index(path), expected, "block index of {path}");
    }
}
/// Toy base-model weights: a `q_proj` for each of blocks 0..4, plus a
/// `k_proj` in block 0 and an `lm_head`. Every entry is [`base_weight`].
fn toy_weights() -> Weights {
    let mut weights = Weights::new();
    for block in 0..4 {
        let key = format!("model.layers.{block}.self_attn.q_proj.weight");
        weights.insert(key, base_weight());
    }
    weights.insert(
        "model.layers.0.self_attn.k_proj.weight".to_string(),
        base_weight(),
    );
    weights.insert("lm_head.weight".to_string(), base_weight());
    weights
}

/// Plain adapter factors for the `q_proj` of every toy block (0..4).
fn toy_adapter_params() -> HashMap<String, AdapterParams> {
    let all_blocks: Vec<i32> = (0..4).collect();
    toy_adapter_params_for(&all_blocks)
}

/// Plain adapter factors for the `q_proj` of each listed block, keyed by
/// module path (no `.weight` suffix).
fn toy_adapter_params_for(blocks: &[i32]) -> HashMap<String, AdapterParams> {
    blocks
        .iter()
        .map(|block| {
            (
                format!("model.layers.{block}.self_attn.q_proj"),
                plain_params(),
            )
        })
        .collect()
}
/// Explicit `keys` plus `num_layers = 2` adapts only the `q_proj` of the last
/// two of the four toy blocks (2 and 3). Earlier blocks, `k_proj`, and
/// `lm_head` are left alone.
#[test]
fn lora_layers_keys_and_num_layers_window() {
    let weights = toy_weights();
    let params = toy_adapter_params_for(&[2, 3]);
    let cfg = mlxlm_config(2, keyed_params(vec!["self_attn.q_proj".to_string()]));
    // `&params` (was mis-encoded as `¶ms`, which does not lex).
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 4).unwrap();
    assert!(layers.contains_key("model.layers.2.self_attn.q_proj"));
    assert!(layers.contains_key("model.layers.3.self_attn.q_proj"));
    assert!(!layers.contains_key("model.layers.0.self_attn.q_proj"));
    assert!(!layers.contains_key("model.layers.1.self_attn.q_proj"));
    assert!(!layers.contains_key("model.layers.0.self_attn.k_proj"));
    assert!(!layers.contains_key("lm_head"));
    assert_eq!(layers.len(), 2);
}

/// A `num_layers` window (16) larger than the model's four blocks adapts
/// every block that has factors.
#[test]
fn lora_layers_covers_all_blocks_when_num_layers_large() {
    let weights = toy_weights();
    let params = toy_adapter_params();
    let cfg = mlxlm_config(16, keyed_params(vec!["self_attn.q_proj".to_string()]));
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 4).unwrap();
    assert_eq!(layers.len(), 4);
}
/// Write an mlx-lm-style adapter into `dir`.
///
/// - `adapter_config.json` uses the given `fine_tune_type`, `num_layers` 16,
///   rank 2, scale 2.0 and keys `["self_attn.q_proj"]`.
/// - `adapters.safetensors` holds `lora_a`/`lora_b` factors for the `q_proj`
///   of blocks 0..4.
/// - When `with_m` is set, each block also gets a magnitude `m` of `[3, 3]`.
fn write_mock_adapter(dir: &Path, fine_tune_type: &str, with_m: bool) {
    let config = format!(
        r#"{{
"fine_tune_type": "{fine_tune_type}",
"num_layers": 16,
"lora_parameters": {{ "rank": 2, "scale": 2.0, "keys": ["self_attn.q_proj"] }}
}}"#
    );
    std::fs::write(dir.join("adapter_config.json"), config).unwrap();

    let mut tensors: HashMap<String, Array> = HashMap::new();
    for block in 0..4 {
        let prefix = format!("model.layers.{block}.self_attn.q_proj");
        tensors.insert(format!("{prefix}.lora_a"), lora_a());
        tensors.insert(format!("{prefix}.lora_b"), lora_b());
        if !with_m {
            continue;
        }
        let magnitude = Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap();
        tensors.insert(format!("{prefix}.m"), magnitude);
    }
    crate::io::save_safetensors(&dir.join("adapters.safetensors"), &tensors).unwrap();
}

/// End-to-end LoRA load from a mock mlx-lm adapter directory.
///
/// All four `q_proj` layers load as plain LoRA. The hand trace
/// (x = `[1, 2, 3]` → `[3, 6]`) is then reproduced through the loaded block-0
/// layer.
#[test]
fn load_adapters_lora_end_to_end() {
    let tmp = std::env::temp_dir().join(format!("mlxrs_lora_test_{}", std::process::id()));
    std::fs::create_dir_all(&tmp).unwrap();
    write_mock_adapter(&tmp, "lora", false);

    let weights = toy_weights();
    let layers = load_adapters(&weights, &tmp, None, 4).unwrap();
    assert_eq!(layers.len(), 4);
    let block0 = layers.get("model.layers.0.self_attn.q_proj");
    assert!(matches!(block0, Some(LoraLayer::Lora(_))));

    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let mut y = block0.unwrap().forward(&x).unwrap();
    let actual = y.to_vec::<f32>().unwrap();
    approx_eq(&actual, &[3.0, 6.0], 1e-5);
    std::fs::remove_dir_all(&tmp).ok();
}
/// Write an mlx-lm-style adapter with the raw `config_json` and rank-`r`
/// factors for the `q_proj` of blocks 0..4.
///
/// `lora_a` has shape (3, r) and `lora_b` has shape (r, 2), both filled with
/// 0.01. The factor rank `r` is independent of any rank declared in
/// `config_json`, which lets a test create a config/weights rank drift.
fn write_mock_adapter_rank(dir: &Path, config_json: &str, r: usize) {
    std::fs::write(dir.join("adapter_config.json"), config_json).unwrap();
    let down = Array::full::<f32>(&(3usize, r), 0.01).unwrap();
    let up = Array::full::<f32>(&(r, 2usize), 0.01).unwrap();
    let mut tensors: HashMap<String, Array> = HashMap::new();
    for block in 0..4 {
        let prefix = format!("model.layers.{block}.self_attn.q_proj");
        tensors.insert(format!("{prefix}.lora_a"), down.try_clone().unwrap());
        tensors.insert(format!("{prefix}.lora_b"), up.try_clone().unwrap());
    }
    crate::io::save_safetensors(&dir.join("adapters.safetensors"), &tensors).unwrap();
}

/// Write a HuggingFace-PEFT-style adapter into `dir`.
///
/// - `adapter_config.json` is the raw `config_json`.
/// - `adapter_model.safetensors` holds, for each module path in `paths`,
///   `base_model.model.<path>.lora_A.weight` of shape (r, 3) and
///   `base_model.model.<path>.lora_B.weight` of shape (2, r), both filled with
///   `value`.
/// - With `with_dora`, each path also gets a `lora_magnitude_vector` of
///   `[1, 1]`.
fn write_mock_peft_adapter(
    dir: &Path,
    config_json: &str,
    paths: &[&str],
    r: usize,
    with_dora: bool,
    value: f32,
) {
    std::fs::write(dir.join("adapter_config.json"), config_json).unwrap();
    let down = Array::full::<f32>(&(r, 3usize), value).unwrap();
    let up = Array::full::<f32>(&(2usize, r), value).unwrap();
    let mut tensors: HashMap<String, Array> = HashMap::new();
    for path in paths {
        let prefix = format!("base_model.model.{path}");
        tensors.insert(format!("{prefix}.lora_A.weight"), down.try_clone().unwrap());
        tensors.insert(format!("{prefix}.lora_B.weight"), up.try_clone().unwrap());
        if with_dora {
            tensors.insert(
                format!("{prefix}.lora_magnitude_vector"),
                Array::from_slice::<f32>(&[1.0, 1.0], &(2usize,)).unwrap(),
            );
        }
    }
    crate::io::save_safetensors(&dir.join("adapter_model.safetensors"), &tensors).unwrap();
}
/// End-to-end PEFT load (r = 16, alpha = 32).
///
/// All four `q_proj` layers load, `k_proj` is untouched, and the layer scale
/// is alpha / r = 2.
#[test]
fn load_adapters_peft_flat_shape_rank16_end_to_end() {
    let tmp = std::env::temp_dir().join(format!("mlxrs_peft_flat16_test_{}", std::process::id()));
    std::fs::create_dir_all(&tmp).unwrap();
    let cfg = r#"{
"peft_type": "LORA",
"r": 16,
"lora_alpha": 32.0,
"target_modules": ["self_attn.q_proj"],
"lora_dropout": 0.0,
"bias": "none"
}"#;
    let q_paths: Vec<String> = (0..4)
        .map(|block| format!("model.layers.{block}.self_attn.q_proj"))
        .collect();
    let q_refs: Vec<&str> = q_paths.iter().map(|p| p.as_str()).collect();
    write_mock_peft_adapter(&tmp, cfg, &q_refs, 16, false, 0.01);

    let weights = toy_weights();
    let layers = load_adapters(&weights, &tmp, None, 4).unwrap();
    assert_eq!(layers.len(), 4);
    assert!(layers.contains_key("model.layers.0.self_attn.q_proj"));
    assert!(!layers.contains_key("model.layers.0.self_attn.k_proj"));
    let Some(LoraLayer::Lora(block0)) = layers.get("model.layers.0.self_attn.q_proj") else {
        panic!("expected a LoRA layer");
    };
    assert_eq!(block0.scale(), 2.0, "PEFT scale must be lora_alpha/r");
    std::fs::remove_dir_all(&tmp).ok();
}

/// PEFT with r = 8, alpha = 32 gives scale 4. This checks that the scale is
/// derived from the config rather than falling back to `DEFAULT_LORA_SCALE`.
#[test]
fn load_adapters_peft_flat_shape_rank8_scale_not_default() {
    let tmp = std::env::temp_dir().join(format!("mlxrs_peft_flat8_test_{}", std::process::id()));
    std::fs::create_dir_all(&tmp).unwrap();
    let cfg = r#"{
"peft_type": "LORA",
"r": 8,
"lora_alpha": 32.0,
"target_modules": ["self_attn.q_proj"]
}"#;
    let q_paths: Vec<String> = (0..4)
        .map(|block| format!("model.layers.{block}.self_attn.q_proj"))
        .collect();
    let q_refs: Vec<&str> = q_paths.iter().map(|p| p.as_str()).collect();
    write_mock_peft_adapter(&tmp, cfg, &q_refs, 8, false, 0.01);

    let weights = toy_weights();
    let layers = load_adapters(&weights, &tmp, None, 4).unwrap();
    assert_eq!(layers.len(), 4);
    let Some(LoraLayer::Lora(block0)) = layers.get("model.layers.0.self_attn.q_proj") else {
        panic!("expected a LoRA layer");
    };
    assert_eq!(block0.scale(), 4.0);
    assert_ne!(block0.scale(), DEFAULT_LORA_SCALE);
    std::fs::remove_dir_all(&tmp).ok();
}
/// mlx-lm config declaring rank 8 while the stored factors are rank 16 must
/// fail with `Error::LengthMismatch`.
#[test]
fn load_adapters_rank_drift_is_shape_mismatch() {
    let tmp = std::env::temp_dir().join(format!("mlxrs_rankdrift_test_{}", std::process::id()));
    std::fs::create_dir_all(&tmp).unwrap();
    let cfg = r#"{
"fine_tune_type": "lora",
"num_layers": 16,
"lora_parameters": { "rank": 8, "alpha": 32.0, "keys": ["self_attn.q_proj"] }
}"#;
    write_mock_adapter_rank(&tmp, cfg, 16);

    let weights = toy_weights();
    let failure = load_adapters(&weights, &tmp, None, 4).unwrap_err();
    let is_length_mismatch = matches!(failure, Error::LengthMismatch(_));
    assert!(
        is_length_mismatch,
        "rank drift must be a LengthMismatch, got {failure:?}"
    );
    std::fs::remove_dir_all(&tmp).ok();
}

/// The same rank drift (config r = 8, factors rank 16) in a PEFT adapter must
/// also fail with `Error::LengthMismatch`.
#[test]
fn load_adapters_peft_rank_drift_is_shape_mismatch() {
    let tmp =
        std::env::temp_dir().join(format!("mlxrs_peft_rankdrift_test_{}", std::process::id()));
    std::fs::create_dir_all(&tmp).unwrap();
    let cfg = r#"{
"peft_type": "LORA",
"r": 8,
"lora_alpha": 32.0,
"target_modules": ["self_attn.q_proj"]
}"#;
    let q_paths: Vec<String> = (0..4)
        .map(|block| format!("model.layers.{block}.self_attn.q_proj"))
        .collect();
    let q_refs: Vec<&str> = q_paths.iter().map(|p| p.as_str()).collect();
    write_mock_peft_adapter(&tmp, cfg, &q_refs, 16, false, 0.01);

    let weights = toy_weights();
    let failure = load_adapters(&weights, &tmp, None, 4).unwrap_err();
    let is_length_mismatch = matches!(failure, Error::LengthMismatch(_));
    assert!(
        is_length_mismatch,
        "PEFT rank drift must be a LengthMismatch, got {failure:?}"
    );
    std::fs::remove_dir_all(&tmp).ok();
}
/// With magnitudes present, a `dora` adapter loads all four `q_proj` layers
/// as DoRA.
#[test]
fn load_adapters_dora_end_to_end() {
    let tmp = std::env::temp_dir().join(format!("mlxrs_dora_test_{}", std::process::id()));
    std::fs::create_dir_all(&tmp).unwrap();
    write_mock_adapter(&tmp, "dora", true);

    let weights = toy_weights();
    let layers = load_adapters(&weights, &tmp, None, 4).unwrap();
    assert_eq!(layers.len(), 4);
    let block0 = layers.get("model.layers.0.self_attn.q_proj");
    assert!(matches!(block0, Some(LoraLayer::Dora(_))));
    std::fs::remove_dir_all(&tmp).ok();
}

/// A `dora` adapter without `.m` magnitudes fails with `Error::MissingKey`.
#[test]
fn load_adapters_dora_missing_magnitude_is_err() {
    let tmp = std::env::temp_dir().join(format!("mlxrs_dora_nom_test_{}", std::process::id()));
    std::fs::create_dir_all(&tmp).unwrap();
    write_mock_adapter(&tmp, "dora", false);

    let weights = toy_weights();
    let failure = load_adapters(&weights, &tmp, None, 4).unwrap_err();
    assert!(
        matches!(failure, Error::MissingKey(_)),
        "missing DoRA magnitude must be MissingKey, got {failure:?}"
    );
    std::fs::remove_dir_all(&tmp).ok();
}

/// `fine_tune_type: "full"` is rejected with `Error::UnknownEnumValue`.
#[test]
fn load_adapters_full_is_unsupported_err() {
    let tmp = std::env::temp_dir().join(format!("mlxrs_full_test_{}", std::process::id()));
    std::fs::create_dir_all(&tmp).unwrap();
    write_mock_adapter(&tmp, "full", false);

    let weights = toy_weights();
    let failure = load_adapters(&weights, &tmp, None, 4).unwrap_err();
    assert!(
        matches!(failure, Error::UnknownEnumValue(_)),
        "fine_tune_type=full rejection must be UnknownEnumValue, got {failure:?}"
    );
    std::fs::remove_dir_all(&tmp).ok();
}

/// An unknown `fine_tune_type` surfaces as `Error::Parse`.
#[test]
fn load_adapters_unknown_fine_tune_type_is_err() {
    let tmp = std::env::temp_dir().join(format!("mlxrs_bogus_test_{}", std::process::id()));
    std::fs::create_dir_all(&tmp).unwrap();
    write_mock_adapter(&tmp, "bogus", false);

    let weights = toy_weights();
    let failure = load_adapters(&weights, &tmp, None, 4).unwrap_err();
    assert!(
        matches!(failure, Error::Parse(_)),
        "unknown fine_tune_type must be a serde Parse error, got {failure:?}"
    );
    std::fs::remove_dir_all(&tmp).ok();
}

/// A directory with `adapters.safetensors` but no `adapter_config.json` fails
/// with `Error::FileIo`.
#[test]
fn load_adapters_missing_config_is_err() {
    let tmp = std::env::temp_dir().join(format!("mlxrs_nocfg_test_{}", std::process::id()));
    std::fs::create_dir_all(&tmp).unwrap();
    let no_tensors: HashMap<String, Array> = HashMap::new();
    crate::io::save_safetensors(&tmp.join("adapters.safetensors"), &no_tensors).unwrap();

    let weights = toy_weights();
    let failure = load_adapters(&weights, &tmp, None, 4).unwrap_err();
    assert!(
        matches!(failure, Error::FileIo(_)),
        "missing adapter_config.json must be a FileIo error, got {failure:?}"
    );
    std::fs::remove_dir_all(&tmp).ok();
}

/// A nonexistent adapter directory fails with `Error::FileIo`.
#[test]
fn load_adapters_missing_dir_is_err() {
    let absent = std::env::temp_dir().join(format!("mlxrs_nodir_test_{}", std::process::id()));
    let weights = toy_weights();
    let failure = load_adapters(&weights, &absent, None, 4).unwrap_err();
    assert!(
        matches!(failure, Error::FileIo(_)),
        "missing adapter dir must be a FileIo error, got {failure:?}"
    );
}
/// A `lora_b` of shape (2, 3), whose output width disagrees with the base's
/// two outputs, is rejected with `Error::LengthMismatch`.
#[test]
fn lora_rejects_mismatched_output_dims() {
    let wide_b = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &(2, 3)).unwrap();
    let params = AdapterParams {
        lora_a: lora_a(),
        lora_b: wide_b,
        magnitude: None,
    };
    let dense = BaseLinear::dense(base_weight(), None).unwrap();
    let rejected = LoRALinear::new(dense, params, 2.0).unwrap_err();
    assert!(matches!(rejected, Error::LengthMismatch(_)));
}

/// A `lora_b` of shape (3, 2) against a rank-2 `lora_a` (inner dimensions
/// disagree) is rejected with `Error::LengthMismatch`.
#[test]
fn lora_rejects_rank_mismatch() {
    let tall_b = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0, 0.0, 0.0], &(3, 2)).unwrap();
    let params = AdapterParams {
        lora_a: lora_a(),
        lora_b: tall_b,
        magnitude: None,
    };
    let dense = BaseLinear::dense(base_weight(), None).unwrap();
    let rejected = LoRALinear::new(dense, params, 2.0).unwrap_err();
    assert!(matches!(rejected, Error::LengthMismatch(_)));
}

/// A `lora_a` with two input rows against a dense base with three inputs is
/// rejected with `Error::LengthMismatch`.
#[test]
fn lora_rejects_wrong_lora_a_input_dim_dense() {
    let short_a = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap();
    let params = AdapterParams {
        lora_a: short_a,
        lora_b: lora_b(),
        magnitude: None,
    };
    let dense = BaseLinear::dense(base_weight(), None).unwrap();
    let rejected = LoRALinear::new(dense, params, 2.0).unwrap_err();
    assert!(matches!(rejected, Error::LengthMismatch(_)));
}

/// Over a quantized base with 64 logical inputs, a `lora_a` with 32 input
/// rows is rejected with `Error::LengthMismatch`.
#[test]
fn lora_rejects_wrong_lora_a_input_dim_quantized() {
    let input_dims = 64usize;
    let wdata: Vec<f32> = [vec![1.0f32; input_dims], vec![0.5f32; input_dims]].concat();
    let dense_w = Array::from_slice::<f32>(&wdata, &(2, input_dims)).unwrap();
    let (w_q, scales, biases) = ops::quantized::quantize(&dense_w, 32, 8, "affine", None).unwrap();
    let quantized_base =
        BaseLinear::quantized(w_q, scales, biases, None, 32, 8, "affine".to_string()).unwrap();
    let params = AdapterParams {
        lora_a: Array::full::<f32>(&(32usize, 2usize), 0.01).unwrap(),
        lora_b: Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap(),
        magnitude: None,
    };
    let rejected = LoRALinear::new(quantized_base, params, 2.0).unwrap_err();
    assert!(matches!(rejected, Error::LengthMismatch(_)));
}

/// Over the same quantized base, a `lora_a` with the full 64 input rows is
/// accepted.
#[test]
fn lora_a_correct_input_dim_quantized_ok() {
    let input_dims = 64usize;
    let wdata: Vec<f32> = [vec![1.0f32; input_dims], vec![0.5f32; input_dims]].concat();
    let dense_w = Array::from_slice::<f32>(&wdata, &(2, input_dims)).unwrap();
    let (w_q, scales, biases) = ops::quantized::quantize(&dense_w, 32, 8, "affine", None).unwrap();
    let quantized_base =
        BaseLinear::quantized(w_q, scales, biases, None, 32, 8, "affine".to_string()).unwrap();
    let params = AdapterParams {
        lora_a: Array::full::<f32>(&(input_dims, 2usize), 0.01).unwrap(),
        lora_b: Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap(),
        magnitude: None,
    };
    let built = LoRALinear::new(quantized_base, params, 2.0);
    assert!(built.is_ok());
}
/// With only `alpha`, the scale is alpha / rank = 32 / 8 = 4.
#[test]
fn resolved_scale_alpha_only() {
    let params = LoraParameters {
        alpha: Some(32.0),
        scale: None,
        rank: 8,
        dropout: None,
        keys: Vec::new(),
    };
    assert_eq!(params.resolved_scale(), 4.0);
}

/// With only a literal `scale`, that value is used unchanged.
#[test]
fn resolved_scale_scale_only() {
    let params = LoraParameters {
        alpha: None,
        scale: Some(7.5),
        rank: 8,
        dropout: None,
        keys: Vec::new(),
    };
    assert_eq!(params.resolved_scale(), 7.5);
}

/// When both are present, alpha / rank (64 / 16 = 4) takes precedence over
/// the literal `scale`.
#[test]
fn resolved_scale_alpha_wins_over_scale() {
    let params = LoraParameters {
        alpha: Some(64.0),
        scale: Some(99.0),
        rank: 16,
        dropout: None,
        keys: Vec::new(),
    };
    assert_eq!(params.resolved_scale(), 4.0);
}

/// With neither `alpha` nor `scale`, the result is `DEFAULT_LORA_SCALE`.
#[test]
fn resolved_scale_neither_is_default() {
    let params = LoraParameters {
        alpha: None,
        scale: None,
        rank: 8,
        dropout: None,
        keys: Vec::new(),
    };
    assert_eq!(params.resolved_scale(), DEFAULT_LORA_SCALE);
}

/// With a non-positive rank, `alpha` is ignored.
///
/// The literal `scale` is used if present; otherwise the result is
/// `DEFAULT_LORA_SCALE`.
#[test]
fn resolved_scale_alpha_with_nonpositive_rank_falls_back() {
    let zero_rank = LoraParameters {
        alpha: Some(32.0),
        scale: Some(5.0),
        rank: 0,
        dropout: None,
        keys: Vec::new(),
    };
    assert_eq!(zero_rank.resolved_scale(), 5.0);

    let negative_rank = LoraParameters {
        alpha: Some(32.0),
        scale: None,
        rank: -1,
        dropout: None,
        keys: Vec::new(),
    };
    assert_eq!(negative_rank.resolved_scale(), DEFAULT_LORA_SCALE);
}
/// When a parsed config sets both `scale` and `alpha`, alpha / rank
/// (16 / 8 = 2) wins over the literal scale.
#[test]
fn config_both_scale_and_alpha_alpha_wins() {
    let json = r#"{
"fine_tune_type": "lora",
"lora_parameters": { "rank": 8, "scale": 50.0, "alpha": 16.0 }
}"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    assert_eq!(cfg.scale(), 2.0);
}

/// `num_layers = -1` means "all blocks": every toy `q_proj` is adapted.
#[test]
fn lora_layers_num_layers_negative_one_selects_all_blocks() {
    let weights = toy_weights();
    let params = toy_adapter_params();
    let cfg = mlxlm_config(-1, keyed_params(vec!["self_attn.q_proj".to_string()]));
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 4).unwrap();
    assert_eq!(layers.len(), 4, "num_layers=-1 must adapt all 4 blocks");
    for b in 0..4 {
        assert!(layers.contains_key(&format!("model.layers.{b}.self_attn.q_proj")));
    }
}

/// `num_layers = 0` is also treated as "all blocks".
#[test]
fn lora_layers_num_layers_zero_selects_all_blocks() {
    let weights = toy_weights();
    let params = toy_adapter_params();
    let cfg = mlxlm_config(0, keyed_params(vec!["self_attn.q_proj".to_string()]));
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 4).unwrap();
    assert_eq!(layers.len(), 4, "num_layers=0 must adapt all 4 blocks");
}
/// With explicit `keys`, every selected layer must have factors.
///
/// Only blocks 0 and 1 have them, so the call fails with `Error::MissingKey`.
/// The error names block 2 and cites the explicit-selection rule.
#[test]
fn lora_layers_explicit_key_missing_factors_is_err() {
    let weights = toy_weights();
    let params = toy_adapter_params_for(&[0, 1]);
    let cfg = mlxlm_config(16, keyed_params(vec!["self_attn.q_proj".to_string()]));
    let err = linear_to_lora_layers(&weights, &cfg, &params, None, 4).unwrap_err();
    match err {
        Error::MissingKey(p) => {
            assert!(
                p.context().contains("explicitly-selected adapter target"),
                "context names the explicit-selection rule: {}",
                p.context()
            );
            assert_eq!(p.key(), "model.layers.2.self_attn.q_proj");
        }
        other => panic!("expected Error::MissingKey, got {other:?}"),
    }
}

/// A factor group (block 99) that matches no base layer is rejected.
///
/// The error is `Error::LayerKeyed` naming that path and wrapping an
/// `InvariantViolation` about base-layer matching.
#[test]
fn lora_layers_unused_adapter_factor_is_err() {
    let weights = toy_weights();
    let mut params = toy_adapter_params();
    params.insert(
        "model.layers.99.self_attn.q_proj".to_string(),
        plain_params(),
    );
    let cfg = mlxlm_config(16, keyed_params(vec!["self_attn.q_proj".to_string()]));
    let err = linear_to_lora_layers(&weights, &cfg, &params, None, 4).unwrap_err();
    match err {
        Error::LayerKeyed(p) => {
            assert_eq!(p.layer(), "model.layers.99.self_attn.q_proj");
            let Error::InvariantViolation(iv) = p.inner() else {
                panic!(
                    "expected inner Error::InvariantViolation, got {:?}",
                    p.inner()
                );
            };
            assert!(
                iv.context().contains("adapter factor group")
                    && iv.requirement().contains("must match a base layer"),
                "inner violation should call out base-layer matching: {iv:?}"
            );
        }
        other => panic!("expected Error::LayerKeyed, got {other:?}"),
    }
}

/// A key that matches nothing and an empty factor map adapt zero layers.
/// That is reported as an `InvariantViolation` on the adapted-layer count.
#[test]
fn lora_layers_empty_result_is_err() {
    let weights = toy_weights();
    let params: HashMap<String, AdapterParams> = HashMap::new();
    let cfg = mlxlm_config(
        16,
        keyed_params(vec!["self_attn.nonexistent_proj".to_string()]),
    );
    let err = linear_to_lora_layers(&weights, &cfg, &params, None, 4).unwrap_err();
    match err {
        Error::InvariantViolation(p) => {
            assert_eq!(p.context(), "load_adapters: adapted-layer count");
            assert!(p.requirement().contains("must be >= 1"));
        }
        other => panic!("expected Error::InvariantViolation, got {other:?}"),
    }
}

/// With empty `keys` (auto-discovery), only layers that have factors
/// (blocks 2 and 3) are adapted. The missing ones are not an error.
#[test]
fn lora_layers_autodiscovery_partial_factors_is_ok() {
    let weights = toy_weights();
    let params = toy_adapter_params_for(&[2, 3]);
    let cfg = mlxlm_config(16, keyed_params(Vec::new()));
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 4).unwrap();
    assert_eq!(layers.len(), 2);
}
/// End-to-end through `load_adapters`: a factor group at block 42, which
/// (presumably) has no matching base layer in the toy weights, must be
/// rejected as `LayerKeyed(InvariantViolation)` instead of being ignored.
#[test]
fn load_adapters_unused_factor_end_to_end_is_err() {
    // Removes the scratch directory on drop, so it is cleaned up even when an
    // assertion below panics.
    struct ScratchDir(std::path::PathBuf);
    impl Drop for ScratchDir {
        fn drop(&mut self) {
            std::fs::remove_dir_all(&self.0).ok();
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_unused_test_{}", std::process::id()));
    // Start from an empty directory: leftovers from an earlier crashed run that
    // reused this pid must not leak stale files into this test.
    std::fs::remove_dir_all(&tmp).ok();
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = ScratchDir(tmp.clone());

    let config = r#"{
"fine_tune_type": "lora",
"num_layers": 16,
"lora_parameters": { "rank": 2, "scale": 2.0, "keys": ["self_attn.q_proj"] }
}"#;
    std::fs::write(tmp.join("adapter_config.json"), config).unwrap();

    // Factors for blocks 0..4, plus one orphan group at block 42.
    let mut arrays: HashMap<String, Array> = HashMap::new();
    for b in 0..4 {
        let path = format!("model.layers.{b}.self_attn.q_proj");
        arrays.insert(format!("{path}.lora_a"), lora_a());
        arrays.insert(format!("{path}.lora_b"), lora_b());
    }
    arrays.insert(
        "model.layers.42.self_attn.q_proj.lora_a".to_string(),
        lora_a(),
    );
    arrays.insert(
        "model.layers.42.self_attn.q_proj.lora_b".to_string(),
        lora_b(),
    );
    crate::io::save_safetensors(&tmp.join("adapters.safetensors"), &arrays).unwrap();

    let weights = toy_weights();
    let err = load_adapters(&weights, &tmp, None, 4).unwrap_err();
    assert!(
        matches!(err, Error::LayerKeyed(ref p) if matches!(p.inner(), Error::InvariantViolation(_))),
        "unused factor group must be LayerKeyed(InvariantViolation), got {err:?}"
    );
}
/// A config with an explicit key selection but an adapter file containing no
/// tensors at all must fail with `MissingKey`.
#[test]
fn load_adapters_empty_safetensors_is_err() {
    // Removes the scratch directory on drop, so it is cleaned up even when an
    // assertion below panics.
    struct ScratchDir(std::path::PathBuf);
    impl Drop for ScratchDir {
        fn drop(&mut self) {
            std::fs::remove_dir_all(&self.0).ok();
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_emptyst_test_{}", std::process::id()));
    // Start from an empty directory so stale files from a prior run cannot matter.
    std::fs::remove_dir_all(&tmp).ok();
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = ScratchDir(tmp.clone());

    let config = r#"{
"fine_tune_type": "lora",
"num_layers": 16,
"lora_parameters": { "rank": 2, "scale": 2.0, "keys": ["self_attn.q_proj"] }
}"#;
    std::fs::write(tmp.join("adapter_config.json"), config).unwrap();
    let arrays: HashMap<String, Array> = HashMap::new();
    crate::io::save_safetensors(&tmp.join("adapters.safetensors"), &arrays).unwrap();

    let weights = toy_weights();
    let err = load_adapters(&weights, &tmp, None, 4).unwrap_err();
    assert!(
        matches!(err, Error::MissingKey(_)),
        "explicit-selection w/o factors must be MissingKey, got {err:?}"
    );
}
/// A DoRA layer over an 8-bit affine-quantized base must agree with the same
/// DoRA layer over the dense base weight, up to quantization error (2e-2).
#[test]
fn qdora_forward_matches_dense_within_quant_error() {
    let (input_dims, output_dims) = (64usize, 2usize);
    let scale = 2.0f32;

    // Dense base weight: first row all 1.0, second row all 0.5.
    let wdata: Vec<f32> = std::iter::repeat(1.0f32)
        .take(input_dims)
        .chain(std::iter::repeat(0.5f32).take(input_dims))
        .collect();
    let dense_w = Array::from_slice::<f32>(&wdata, &(output_dims, input_dims)).unwrap();
    let la = Array::full::<f32>(&(input_dims, 2usize), 0.01).unwrap();
    let lb = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap();
    let bias = Array::from_slice::<f32>(&[3.0, -1.0], &(output_dims,)).unwrap();

    // Magnitude vector: norm (ord 2.0, axis 1) of dense weight + LoRA delta.
    let m = {
        let plain = AdapterParams {
            lora_a: la.try_clone().unwrap(),
            lora_b: lb.try_clone().unwrap(),
            magnitude: None,
        };
        let delta = lora_delta(&plain, scale).unwrap();
        let adapted = dense_w.add(&delta).unwrap();
        ops::linalg_full::norm(&adapted, 2.0, &[1], false).unwrap()
    };

    // Reference output: DoRA over the dense base.
    let dense_layer = DoRALinear::new(
        BaseLinear::dense(
            dense_w.try_clone().unwrap(),
            Some(bias.try_clone().unwrap()),
        )
        .unwrap(),
        AdapterParams {
            lora_a: la.try_clone().unwrap(),
            lora_b: lb.try_clone().unwrap(),
            magnitude: Some(m.try_clone().unwrap()),
        },
        scale,
    )
    .unwrap();
    let x = Array::full::<f32>(&(1usize, input_dims), 1.0).unwrap();
    let mut dense_out = dense_layer.forward(&x).unwrap();

    // Same adapter over the quantized base (group size 32, 8 bits, affine).
    let (w_q, scales, biases) = ops::quantized::quantize(&dense_w, 32, 8, "affine", None).unwrap();
    let q_layer = DoRALinear::new(
        BaseLinear::quantized(
            w_q,
            scales,
            biases,
            Some(bias.try_clone().unwrap()),
            32,
            8,
            "affine".to_string(),
        )
        .unwrap(),
        AdapterParams {
            lora_a: la,
            lora_b: lb,
            magnitude: Some(m),
        },
        scale,
    )
    .unwrap();
    let mut q_out = q_layer.forward(&x).unwrap();

    let quantized = q_out.to_vec::<f32>().unwrap();
    let reference = dense_out.to_vec::<f32>().unwrap();
    approx_eq(&quantized, &reference, 2e-2);
}
/// Fusing a QDoRA layer (`fuse(true)`) and running the fused base must match
/// the unfused `forward` output, up to quantization error (2e-2).
#[test]
fn qdora_forward_matches_fuse() {
    let (input_dims, output_dims) = (64usize, 2usize);

    // Base weight rows: all 1.0, then all 0.5.
    let wdata: Vec<f32> = [1.0f32, 0.5f32]
        .iter()
        .flat_map(|&v| std::iter::repeat(v).take(input_dims))
        .collect();
    let dense_w = Array::from_slice::<f32>(&wdata, &(output_dims, input_dims)).unwrap();
    let adapter = AdapterParams {
        lora_a: Array::full::<f32>(&(input_dims, 2usize), 0.01).unwrap(),
        lora_b: Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap(),
        magnitude: Some(Array::from_slice::<f32>(&[1.5, 2.5], &(output_dims,)).unwrap()),
    };
    let x = Array::full::<f32>(&(1usize, input_dims), 1.0).unwrap();

    let (w_q, scales, biases) = ops::quantized::quantize(&dense_w, 32, 8, "affine", None).unwrap();
    let base = BaseLinear::quantized(w_q, scales, biases, None, 32, 8, "affine".to_string()).unwrap();
    let q_layer = DoRALinear::new(base, adapter, 2.0).unwrap();

    let mut unfused_out = q_layer.forward(&x).unwrap();
    let mut fused_out = q_layer.fuse(true).unwrap().base_output(&x).unwrap();
    approx_eq(
        &fused_out.to_vec::<f32>().unwrap(),
        &unfused_out.to_vec::<f32>().unwrap(),
        2e-2,
    );
}
/// A directory sitting where `adapters.safetensors` should be must be refused
/// with `Error::FileIo` for that path, op `FileOp::Stat`, kind `InvalidInput`.
#[test]
fn load_adapters_non_regular_safetensors_is_err() {
    // Removes the scratch directory on drop, so it is cleaned up even when an
    // assertion below panics.
    struct ScratchDir(std::path::PathBuf);
    impl Drop for ScratchDir {
        fn drop(&mut self) {
            std::fs::remove_dir_all(&self.0).ok();
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_nonreg_test_{}", std::process::id()));
    // Start from an empty directory so stale files from a prior run cannot matter.
    std::fs::remove_dir_all(&tmp).ok();
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = ScratchDir(tmp.clone());

    let config = r#"{
"fine_tune_type": "lora",
"num_layers": 16,
"lora_parameters": { "rank": 2, "scale": 2.0, "keys": ["self_attn.q_proj"] }
}"#;
    std::fs::write(tmp.join("adapter_config.json"), config).unwrap();
    // A directory, not a regular file, at the adapter-weights path.
    std::fs::create_dir_all(tmp.join("adapters.safetensors")).unwrap();

    let weights = toy_weights();
    let err = load_adapters(&weights, &tmp, None, 4).unwrap_err();
    match err {
        Error::FileIo(p) => {
            assert_eq!(p.path(), tmp.join("adapters.safetensors").as_path());
            assert_eq!(p.op(), FileOp::Stat);
            assert_eq!(p.inner().kind(), std::io::ErrorKind::InvalidInput);
        }
        other => panic!("expected Error::FileIo(InvalidInput, Stat), got {other:?}"),
    }
}
/// An adapter file one byte over `MAX_ADAPTER_SAFETENSORS_BYTES` must be
/// refused with `CapExceeded`, reporting the cap name, the cap and the
/// observed size.
#[test]
fn load_adapters_oversized_safetensors_is_err() {
    // Removes the scratch directory on drop, so it is cleaned up even when an
    // assertion below panics.
    struct ScratchDir(std::path::PathBuf);
    impl Drop for ScratchDir {
        fn drop(&mut self) {
            std::fs::remove_dir_all(&self.0).ok();
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_oversize_test_{}", std::process::id()));
    // Start from an empty directory so stale files from a prior run cannot matter.
    std::fs::remove_dir_all(&tmp).ok();
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = ScratchDir(tmp.clone());

    let config = r#"{
"fine_tune_type": "lora",
"num_layers": 16,
"lora_parameters": { "rank": 2, "scale": 2.0, "keys": ["self_attn.q_proj"] }
}"#;
    std::fs::write(tmp.join("adapter_config.json"), config).unwrap();
    // Grow an empty file to cap + 1 bytes without building that buffer in memory.
    let f = std::fs::File::create(tmp.join("adapters.safetensors")).unwrap();
    f.set_len(MAX_ADAPTER_SAFETENSORS_BYTES + 1).unwrap();
    drop(f);

    let weights = toy_weights();
    let err = load_adapters(&weights, &tmp, None, 4).unwrap_err();
    match err {
        Error::CapExceeded(p) => {
            assert_eq!(p.cap_name(), "MAX_ADAPTER_SAFETENSORS_BYTES");
            assert_eq!(p.cap(), MAX_ADAPTER_SAFETENSORS_BYTES);
            assert_eq!(p.observed(), MAX_ADAPTER_SAFETENSORS_BYTES + 1);
        }
        other => panic!("expected Error::CapExceeded, got {other:?}"),
    }
}
/// Builds toy base weights for `n` blocks: every block gets a `q_proj` and a
/// `v_proj` weight (both `base_weight()`), and one `lm_head.weight` is added.
fn peft_toy_weights(n: usize) -> Weights {
    let mut weights = Weights::new();
    for block in 0..n {
        for proj in ["q_proj", "v_proj"] {
            weights.insert(
                format!("model.layers.{block}.self_attn.{proj}.weight"),
                base_weight(),
            );
        }
    }
    weights.insert("lm_head.weight".to_string(), base_weight());
    weights
}
/// A PEFT config with no `lora_alpha` falls back to `DEFAULT_PEFT_LORA_ALPHA`;
/// with `r = 16` the resulting scale is 0.5.
#[test]
fn peft_config_lora_alpha_defaults_to_8() {
    let cfg = LoraConfig::from_json(
        r#"{ "peft_type": "LORA", "r": 16, "target_modules": ["q_proj"] }"#,
    )
    .unwrap();
    let expected_alpha = Some(DEFAULT_PEFT_LORA_ALPHA);
    assert_eq!(cfg.lora_parameters.alpha, expected_alpha);
    assert_eq!(cfg.scale_for("model.layers.0.self_attn.q_proj"), 0.5);
}
/// Fields that only matter at training time (init mode, LoftQ/EVA/CorDA
/// configs, task type, Megatron settings, provenance) must parse without error
/// and leave rank/scale untouched.
#[test]
fn peft_config_accepts_and_ignores_training_only_fields() {
    let raw = r#"{
"peft_type": "LORA",
"r": 8,
"lora_alpha": 16.0,
"target_modules": ["q_proj"],
"init_lora_weights": "gaussian",
"loftq_config": {},
"eva_config": null,
"corda_config": null,
"task_type": "CAUSAL_LM",
"megatron_config": null,
"megatron_core": "megatron.core",
"revision": null,
"base_model_name_or_path": "meta-llama/Llama-3-8B"
}"#;
    let parsed = LoraConfig::from_json(raw);
    let cfg = parsed.expect("training-only fields must not error");
    // alpha 16 over r 8.
    assert_eq!(cfg.rank(), 8);
    assert_eq!(cfg.scale_for("q_proj"), 2.0);
}
/// `lora_bias: true` must be refused (LoRALinear has no lora_B bias term),
/// while an explicit `lora_bias: false` loads.
#[test]
fn peft_config_lora_bias_true_is_err() {
    let with_bias = r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "lora_bias": true
}"#;
    let without_bias = r#"{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "lora_bias": false }"#;

    let rejected = LoraConfig::from_json(with_bias);
    assert!(
        rejected.is_err(),
        "`lora_bias: true` must be rejected (no lora_B-bias term in LoRALinear)"
    );
    let ok = LoraConfig::from_json(without_bias);
    assert!(ok.is_ok());
}
/// PEFT `bias` modes that train base biases (`"all"`, `"lora_only"`) must be
/// refused with `Error::Parse`; `"none"` and an absent field both load.
#[test]
fn peft_config_bias_all_or_lora_only_is_err() {
    let rejected_modes = ["all", "lora_only"];
    for bias in rejected_modes {
        let json = format!(
            r#"{{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "bias": {bias:?} }}"#
        );
        let outcome = LoraConfig::from_json(&json);
        let err = outcome.expect_err(&format!("PEFT `bias: {bias:?}` must be rejected"));
        assert!(
            matches!(err, Error::Parse(_)),
            "expected Error::Parse for `bias: {bias:?}`, got {err:?}"
        );
    }

    let none = r#"{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "bias": "none" }"#;
    let absent = r#"{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"] }"#;
    assert!(LoraConfig::from_json(none).is_ok());
    assert!(LoraConfig::from_json(absent).is_ok());
}
/// A non-empty `modules_to_save` (full replacement weights) must be refused
/// with `Error::Parse`; an empty list is accepted.
#[test]
fn peft_config_nonempty_modules_to_save_is_err() {
    let nonempty = r#"{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "modules_to_save": ["embed_tokens", "lm_head"] }"#;
    let empty = r#"{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "modules_to_save": [] }"#;

    let err = LoraConfig::from_json(nonempty)
        .expect_err("non-empty `modules_to_save` must be rejected");
    let is_parse = matches!(err, Error::Parse(_));
    assert!(is_parse, "expected Error::Parse, got {err:?}");
    assert!(LoraConfig::from_json(empty).is_ok());
}
/// PEFT-prefixed tensors the loader cannot apply, namely a `.bias` next to the
/// LoRA factors or a `modules_to_save` full weight, must each be refused as
/// `LayerKeyed(InvariantViolation)` naming the offending key.
#[test]
fn peft_key_translation_rejects_sidecar_bias_and_modules_to_save_tensors() {
    // Case 1: a sidecar bias alongside otherwise-valid factors.
    let bias_key = "base_model.model.model.layers.0.self_attn.q_proj.bias";
    let with_bias: HashMap<String, Array> = HashMap::from([
        (
            "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight".to_string(),
            Array::zeros::<f32>(&(2, 3)).unwrap(),
        ),
        (
            "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight".to_string(),
            Array::zeros::<f32>(&(4, 2)).unwrap(),
        ),
        (
            bias_key.to_string(),
            Array::zeros::<f32>(&(4usize,)).unwrap(),
        ),
    ]);
    let bias_err = translate_peft_keys(with_bias)
        .expect_err("a PEFT-prefixed `.bias` tensor must be rejected, not silently dropped");
    let bias_payload = match bias_err {
        Error::LayerKeyed(ref payload) => payload,
        other => panic!("expected Error::LayerKeyed, got {other:?}"),
    };
    assert_eq!(
        bias_payload.layer(),
        bias_key,
        "the rejection must name the dropped key"
    );
    assert!(matches!(bias_payload.inner(), Error::InvariantViolation(_)));

    // Case 2: a `modules_to_save` full weight for the output head.
    let saved_key = "base_model.model.lm_head.weight";
    let with_saved: HashMap<String, Array> =
        HashMap::from([(saved_key.to_string(), base_weight())]);
    let saved_err = translate_peft_keys(with_saved)
        .expect_err("a PEFT-prefixed `modules_to_save` weight must be rejected");
    let Error::LayerKeyed(saved_payload) = saved_err else {
        panic!("expected LayerKeyed");
    };
    assert_eq!(saved_payload.layer(), saved_key);
    assert!(matches!(saved_payload.inner(), Error::InvariantViolation(_)));
}
/// Each PEFT adapter variant that changes inference (QALoRA, aLoRA, VeLoRA,
/// MonteCLoRA) must be refused as `Error::Parse` from `LoraConfig::from_json`,
/// with a message that names the offending field.
#[test]
fn peft_config_exotic_variants_are_rejected() {
    let base = r#""peft_type": "LORA", "r": 8, "lora_alpha": 16.0, "target_modules": ["q_proj"]"#;
    let exotic = [
        ("use_qalora", "true"),
        ("alora_invocation_tokens", "[1, 2, 3]"),
        ("velora_config", r#"{"rank": 4}"#),
        ("monteclora_config", r#"{"num_samples": 8}"#),
    ];
    for &(field, value) in &exotic {
        let json = format!("{{ {base}, {field:?}: {value} }}");
        let err = LoraConfig::from_json(&json).expect_err(&format!(
            "a PEFT adapter setting `{field}` must be rejected (it changes inference)"
        ));
        let parse_payload = match &err {
            Error::Parse(p) => p,
            _ => panic!("expected Error::Parse for `{field}`, got {err:?}"),
        };
        assert_eq!(parse_payload.context(), "LoraConfig::from_json");
        let msg = parse_payload.inner().to_string();
        assert!(
            msg.contains(field),
            "the rejection error for `{field}` should name the field; got: {msg}"
        );
    }
}
/// Active exotic-variant fields must be refused whatever the on-disk shape:
/// configs with no PEFT marker, and mlx-lm-shaped configs.
#[test]
fn peft_config_exotic_variant_rejection_is_shape_independent() {
    // NOTE(review): only `is_err()` is asserted. The "no-marker" inputs carry no
    // rank at all, so they might be rejected for that reason rather than by the
    // exotic-field guard; consider also checking that the error names the field.
    let cases = [
        ("no-marker use_qalora", r#"{ "use_qalora": true }"#),
        (
            "no-marker alora",
            r#"{ "alora_invocation_tokens": [7, 8] }"#,
        ),
        ("no-marker velora", r#"{ "velora_config": {"rank": 2} }"#),
        (
            "no-marker monteclora",
            r#"{ "monteclora_config": {"k": 1} }"#,
        ),
        (
            "mlx-lm-shape use_qalora",
            r#"{ "lora_parameters": { "rank": 8 }, "use_qalora": true }"#,
        ),
        (
            "mlx-lm-shape velora",
            r#"{ "fine_tune_type": "lora", "num_layers": 4,
"lora_parameters": { "rank": 8 }, "velora_config": {"x": 1} }"#,
        ),
    ];
    for &(label, json) in cases.iter() {
        let rejected = LoraConfig::from_json(json).is_err();
        assert!(
            rejected,
            "exotic-field config {label:?} must be rejected regardless of on-disk shape"
        );
    }
}
/// Exotic-variant fields present at their inactive defaults (`false`/`null`,
/// or a bare group size) must not trip the rejection, in either the PEFT or
/// the mlx-lm config shape.
#[test]
fn peft_config_exotic_variant_defaults_are_accepted() {
    let peft_shaped = r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0, "target_modules": ["q_proj"],
"use_qalora": false, "qalora_group_size": 16,
"alora_invocation_tokens": null, "velora_config": null, "monteclora_config": null
}"#;
    let mlx_shaped = r#"{ "lora_parameters": { "rank": 4 },
"use_qalora": false, "velora_config": null, "monteclora_config": null }"#;

    let cfg = LoraConfig::from_json(peft_shaped)
        .expect("exotic fields at their defaults must not trip the rejection");
    assert_eq!(cfg.rank(), 8);
    assert!(!cfg.is_dora());

    let mlx_outcome = LoraConfig::from_json(mlx_shaped);
    assert!(
        mlx_outcome.is_ok(),
        "exotic defaults on an mlx-lm-shaped config must not trip the guard"
    );
}
/// A set `arrow_config` changes the forward pass and must be refused with an
/// `Error::Parse` whose message names the field.
#[test]
fn peft_config_arrow_config_is_err() {
    let json = r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0, "target_modules": ["q_proj"],
"arrow_config": { "top_k": 3 }
}"#;
    let err = LoraConfig::from_json(json)
        .expect_err("`arrow_config` set must be rejected (forward variant)");
    let parse_payload = match &err {
        Error::Parse(p) => p,
        _ => panic!("expected Error::Parse, got {err:?}"),
    };
    let msg = parse_payload.inner().to_string();
    assert!(
        msg.contains("arrow_config"),
        "the rejection should name `arrow_config`; got: {msg}"
    );
}
/// A set `use_bdlora` must be refused with an `Error::Parse` whose message
/// names the field.
#[test]
fn peft_config_use_bdlora_is_err() {
    let json = r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0, "target_modules": ["q_proj"],
"use_bdlora": { "nblocks": 2 }
}"#;
    let err = LoraConfig::from_json(json).expect_err("`use_bdlora` set must be rejected");
    let parse_payload = match &err {
        Error::Parse(p) => p,
        _ => panic!("expected Error::Parse, got {err:?}"),
    };
    let msg = parse_payload.inner().to_string();
    assert!(
        msg.contains("use_bdlora"),
        "the rejection should name `use_bdlora`; got: {msg}"
    );
}
/// Unknown fields carrying an active value (object, number, `true`, string)
/// must be refused by the structural backstop with an `Error::Parse` naming
/// the field.
#[test]
fn peft_config_invented_unknown_active_field_is_err() {
    let unknown_fields = [
        ("some_future_variant", r#"{ "k": 1 }"#),
        ("another_future_knob", "7"),
        ("yet_another_variant", "true"),
        ("a_future_string_variant", r#""enabled""#),
    ];
    for &(field, value) in unknown_fields.iter() {
        let json = format!(
            r#"{{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], {field:?}: {value} }}"#
        );
        let err = LoraConfig::from_json(&json).expect_err(&format!(
            "an active unknown field `{field}` must be rejected by the structural backstop"
        ));
        let parse_payload = match &err {
            Error::Parse(p) => p,
            _ => panic!("expected Error::Parse for `{field}`, got {err:?}"),
        };
        let msg = parse_payload.inner().to_string();
        assert!(
            msg.contains(field),
            "the rejection for `{field}` should name the field; got: {msg}"
        );
    }
}
/// Unknown fields whose value is inactive (`null` or `false`) must be ignored,
/// both one at a time and several together.
#[test]
fn peft_config_unknown_field_inactive_value_is_accepted() {
    for value in ["null", "false"] {
        let json = format!(
            r#"{{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "some_future_variant": {value} }}"#
        );
        let cfg = match LoraConfig::from_json(&json) {
            Ok(cfg) => cfg,
            Err(e) => {
                panic!("an inactive (`{value}`) unknown field must be ignored, got: {e:?}")
            }
        };
        assert_eq!(cfg.rank(), 8);
    }

    let several = r#"{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"],
"future_a": null, "future_b": false, "future_c": null }"#;
    let outcome = LoraConfig::from_json(several);
    assert!(
        outcome.is_ok(),
        "multiple inactive unknown fields must all be ignored"
    );
}
/// Metadata and training-only fields must load even when they hold real,
/// non-default values. Rank and alpha still come from `r` and `lora_alpha`.
#[test]
fn peft_config_benign_fields_with_real_values_are_accepted() {
    let raw = r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0, "target_modules": ["q_proj"],
"task_type": "CAUSAL_LM",
"revision": "main",
"base_model_name_or_path": "meta-llama/Llama-3-8B",
"auto_mapping": { "base_model_class": "LlamaForCausalLM" },
"inference_mode": true,
"peft_version": "0.19.2.dev0",
"megatron_core": "megatron.core",
"megatron_config": { "tensor_model_parallel_size": 1 },
"runtime_config": { "ephemeral_gpu_offload": true },
"eva_config": { "rho": 2.0 },
"corda_config": { "corda_method": "ipm" },
"lora_ga_config": { "scale": "stable" },
"loftq_config": { "loftq_bits": 4 },
"qalora_group_size": 16,
"ensure_weight_tying": true
}"#;
    let parsed = LoraConfig::from_json(raw);
    let cfg = parsed.expect("benign metadata / training-only fields must load even when set");
    assert_eq!(cfg.rank(), 8);
    assert_eq!(cfg.lora_parameters.alpha, Some(16.0));
}
/// `init_lora_weights` is allowlisted.
///
/// Modes that mutate the base weight at init (PiSSA in any spelling, OLoRA,
/// CorDA, LoRA-GA, LoftQ) and unknown modes must be refused with an
/// `Error::Parse` naming the mode. Pure factor seeds (gaussian, eva,
/// orthogonal, true/false) must load.
#[test]
fn peft_config_init_lora_weights_allowlist_rejects_non_factor_modes() {
    let rejected_modes = [
        "pissa",
        "pissa_niter_4",
        "PISSA_NITER_16",
        "olora",
        "corda",
        "corda_v1",
        "lora_ga",
        "loftq",
        "some_future_init_mode",
    ];
    for mode in rejected_modes {
        let json = format!(
            r#"{{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "init_lora_weights": "{mode}" }}"#
        );
        match LoraConfig::from_json(&json) {
            Ok(_) => {
                panic!("`init_lora_weights: \"{mode}\"` must be rejected (mutates base weight at init)")
            }
            Err(Error::Parse(p)) => {
                let msg = p.inner().to_string();
                assert!(
                    msg.contains(mode),
                    "the rejection should name the mode `{mode}`; got: {msg}"
                );
            }
            Err(other) => panic!("expected Error::Parse for `{mode}`, got {other:?}"),
        }
    }

    let accepted_inits = ["\"gaussian\"", "\"eva\"", "\"orthogonal\"", "true", "false"];
    for init in accepted_inits.into_iter() {
        let json = format!(
            r#"{{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "init_lora_weights": {init} }}"#
        );
        let loaded = LoraConfig::from_json(&json).is_ok();
        assert!(
            loaded,
            "`init_lora_weights: {init}` is a pure factor seed and must load"
        );
    }
}
/// Sample fields the structural backstop must refuse when set
/// (`layer_replication`, `trainable_token_indices`, `target_parameters`).
/// Each error must be `Error::Parse` and name the field.
#[test]
fn peft_config_structural_reject_examples_layer_replication_and_token_indices() {
    let cases = [
        (
            "layer_replication",
            r#"{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "layer_replication": [[0, 4], [2, 5]] }"#,
        ),
        (
            "trainable_token_indices",
            r#"{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "trainable_token_indices": [0, 1, 2] }"#,
        ),
        (
            "target_parameters",
            r#"{ "peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": [], "target_parameters": ["feed_forward.experts.gate_up_proj"] }"#,
        ),
    ];
    for &(field, json) in cases.iter() {
        let parse_payload = match LoraConfig::from_json(json) {
            Err(Error::Parse(p)) => p,
            Err(other) => panic!("expected Error::Parse for `{field}`, got {other:?}"),
            Ok(_) => panic!("`{field}` set must be rejected by the structural backstop"),
        };
        let msg = parse_payload.inner().to_string();
        assert!(
            msg.contains(field),
            "the rejection should name `{field}`; got: {msg}"
        );
    }
}
/// A realistic, fully-populated PEFT `adapter_config.json` (every gated field
/// at its inactive default) must still load. It should yield rank 16,
/// alpha 32, scale 2.0 and a list-valued `target_modules`.
#[test]
fn peft_config_valid_flat_fixture_still_loads() {
    let json = r#"{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"auto_mapping": null,
"peft_version": "0.19.2.dev0",
"base_model_name_or_path": "meta-llama/Llama-3-8B",
"revision": null,
"inference_mode": true,
"r": 16,
"lora_alpha": 32.0,
"lora_dropout": 0.05,
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
"exclude_modules": null,
"bias": "none",
"use_rslora": false,
"use_dora": false,
"fan_in_fan_out": false,
"lora_bias": false,
"modules_to_save": null,
"init_lora_weights": true,
"layers_to_transform": null,
"layers_pattern": null,
"rank_pattern": {},
"alpha_pattern": {},
"megatron_config": null,
"megatron_core": "megatron.core",
"use_qalora": false,
"qalora_group_size": 16,
"alora_invocation_tokens": null,
"loftq_config": {},
"eva_config": null,
"corda_config": null,
"lora_ga_config": null,
"velora_config": null,
"monteclora_config": null,
"layer_replication": null,
"trainable_token_indices": null,
"target_parameters": null,
"use_bdlora": null,
"arrow_config": null,
"ensure_weight_tying": false,
"runtime_config": {"ephemeral_gpu_offload": false}
}"#;
    let cfg = LoraConfig::from_json(json).expect("a realistic PEFT-flat config must still load");
    assert_eq!(cfg.rank(), 16);
    assert_eq!(cfg.lora_parameters.alpha, Some(32.0));
    // Non-rsLoRA scale: alpha / r = 32 / 16.
    assert_eq!(cfg.scale_for("model.layers.0.self_attn.q_proj"), 2.0);
    let peft = cfg.peft().expect("PEFT selection");
    assert!(matches!(&peft.target_modules, Some(ModuleMatcher::List(_))));
}
/// mlx-lm-native configs keep accept-and-ignore semantics for unknown keys.
/// The stray object below must not block loading, and `num_layers` must reach
/// the `MlxLm` selection.
#[test]
fn mlx_lm_native_fixture_still_loads_with_unknown_keys() {
    let native = r#"{
"fine_tune_type": "lora",
"num_layers": 8,
"lora_parameters": { "rank": 8, "scale": 20.0, "dropout": 0.0, "keys": ["q_proj"] },
"some_native_extra_key": { "whatever": 1 }
}"#;
    let parsed = LoraConfig::from_json(native);
    let cfg = parsed.expect("mlx-lm-native shape must keep accept-and-ignore for unknown keys");
    assert_eq!(cfg.rank(), 8);
    let has_expected_window = matches!(
        cfg.selection,
        AdapterSelection::MlxLm { num_layers: 8 }
    );
    assert!(has_expected_window);
}
/// Embedding-LoRA tensors (`lora_embedding_A/B`, with or without `.weight`)
/// must each be refused as `LayerKeyed(InvariantViolation)` naming the key.
/// The requirement text must mention embeddings and must not be misfiled as a
/// bias or `modules_to_save` rejection.
#[test]
fn peft_key_translation_embedding_lora_precise_reject() {
    let suffixes = [
        ".lora_embedding_A",
        ".lora_embedding_B",
        ".lora_embedding_A.weight",
        ".lora_embedding_B.weight",
    ];
    let keys = suffixes
        .iter()
        .map(|suffix| format!("base_model.model.model.embed_tokens{suffix}"));
    for key in keys {
        let arrays: HashMap<String, Array> =
            HashMap::from([(key.clone(), Array::zeros::<f32>(&(2, 3)).unwrap())]);
        let p = match translate_peft_keys(arrays) {
            Err(Error::LayerKeyed(p)) => p,
            Err(other) => panic!("expected Error::LayerKeyed, got {other:?}"),
            Ok(_) => panic!("embedding-LoRA key {key:?} must be rejected, not accepted"),
        };
        assert_eq!(p.layer(), key, "LayerKeyed must name the offending key");
        let Error::InvariantViolation(iv) = p.inner() else {
            panic!(
                "expected inner Error::InvariantViolation, got {:?}",
                p.inner()
            );
        };
        let requirement = iv.requirement();
        assert!(
            requirement.to_lowercase().contains("embedding"),
            "the rejection requirement must mention embedding; got: {}",
            requirement
        );
        assert!(
            !requirement.contains("bias") && !requirement.contains("modules_to_save"),
            "embedding-LoRA must NOT be misclassified as bias/modules_to_save; got: {}",
            requirement
        );
    }
}
/// Every PEFT selection field must parse into `PeftSelection`, and the
/// per-module `rank_pattern` must actually be applied, not just read.
#[test]
fn peft_config_all_selection_fields_parse() {
    let json = r#"{
"peft_type": "LORA",
"r": 8,
"lora_alpha": 16.0,
"target_modules": ["q_proj", "v_proj"],
"exclude_modules": ["lm_head"],
"use_rslora": true,
"use_dora": false,
"fan_in_fan_out": true,
"layers_to_transform": [0, 2, 4],
"layers_pattern": "layers",
"rank_pattern": { "q_proj": 16 },
"alpha_pattern": { "v_proj": 64 }
}"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    let peft = cfg.peft().unwrap();
    assert!(matches!(&peft.target_modules, Some(ModuleMatcher::List(_))));
    assert!(matches!(
        &peft.exclude_modules,
        Some(ModuleMatcher::List(_))
    ));
    assert!(peft.use_rslora);
    assert!(peft.fan_in_fan_out);
    assert_eq!(peft.layers_to_transform.as_deref(), Some(&[0, 2, 4][..]));
    assert_eq!(peft.layers_pattern, vec!["layers".to_string()]);
    assert!(cfg.fan_in_fan_out());

    // `rank_pattern` overrides only the matching module; others keep `r`.
    assert_eq!(cfg.rank_for("model.layers.0.self_attn.q_proj"), 16);
    assert_eq!(cfg.rank_for("model.layers.0.self_attn.v_proj"), 8);
    // rsLoRA with the pattern rank: global alpha 16 / sqrt(16) = 4.0.
    // (`alpha_pattern` targets only v_proj, so q_proj keeps the global alpha.)
    assert_eq!(cfg.scale_for("model.layers.0.self_attn.q_proj"), 4.0);
}
/// With `use_rslora` the scale is alpha / sqrt(r): 32 / sqrt(16) = 8.
#[test]
fn peft_rslora_scale_is_alpha_over_sqrt_r() {
    let cfg = LoraConfig::from_json(
        r#"{
"peft_type": "LORA", "r": 16, "lora_alpha": 32.0,
"target_modules": ["q_proj"], "use_rslora": true
}"#,
    )
    .unwrap();
    let scale = cfg.scale_for("model.layers.0.self_attn.q_proj");
    assert_eq!(scale, 8.0);
}
/// Without `use_rslora` (absent means off), the scale is alpha / r:
/// 32 / 16 = 2.
#[test]
fn peft_non_rslora_scale_is_alpha_over_r() {
    let cfg = LoraConfig::from_json(
        r#"{
"peft_type": "LORA", "r": 16, "lora_alpha": 32.0,
"target_modules": ["q_proj"]
}"#,
    )
    .unwrap();
    let rslora_enabled = cfg.peft().unwrap().use_rslora;
    assert!(!rslora_enabled);
    let scale = cfg.scale_for("model.layers.0.self_attn.q_proj");
    assert_eq!(scale, 2.0);
}
/// `rank_pattern` overrides the rank (and hence the scale) only for matching
/// modules. q_proj gets r=32 (16/32 = 0.5); v_proj keeps r=8 (16/8 = 2.0).
#[test]
fn peft_rank_pattern_overrides_rank_per_module() {
    let cfg = LoraConfig::from_json(
        r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj", "v_proj"],
"rank_pattern": { "q_proj": 32 }
}"#,
    )
    .unwrap();
    let (q, v) = (
        "model.layers.0.self_attn.q_proj",
        "model.layers.0.self_attn.v_proj",
    );
    assert_eq!(cfg.rank_for(q), 32);
    assert_eq!(cfg.rank_for(v), 8);
    assert_eq!(cfg.scale_for(q), 0.5);
    assert_eq!(cfg.scale_for(v), 2.0);
}
/// `alpha_pattern` overrides alpha only for matching modules. q_proj gets
/// 64/8 = 8.0; v_proj keeps 16/8 = 2.0.
#[test]
fn peft_alpha_pattern_overrides_alpha_per_module() {
    let cfg = LoraConfig::from_json(
        r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj", "v_proj"],
"alpha_pattern": { "q_proj": 64 }
}"#,
    )
    .unwrap();
    let q_scale = cfg.scale_for("model.layers.0.self_attn.q_proj");
    let v_scale = cfg.scale_for("model.layers.0.self_attn.v_proj");
    assert_eq!(q_scale, 8.0);
    assert_eq!(v_scale, 2.0);
}
/// Rank and alpha patterns combine under rsLoRA: 64 / sqrt(16) = 16.
#[test]
fn peft_rank_and_alpha_pattern_with_rslora() {
    let cfg = LoraConfig::from_json(
        r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "use_rslora": true,
"rank_pattern": { "q_proj": 16 }, "alpha_pattern": { "q_proj": 64 }
}"#,
    )
    .unwrap();
    let scale = cfg.scale_for("model.layers.0.self_attn.q_proj");
    assert_eq!(scale, 16.0);
}
/// A plain pattern key matches a module name exactly or at a dotted segment
/// boundary, never mid-segment (`xq_proj`) and never for other modules.
#[test]
fn peft_pattern_lookup_anchors_at_segment_boundary() {
    let patterns = vec![("q_proj".to_string(), 99i32)];
    let expectations = [
        ("model.layers.0.self_attn.q_proj", Some(99)),
        ("q_proj", Some(99)),
        ("model.xq_proj", None),
        ("model.layers.0.mlp.down", None),
    ];
    for (module, expected) in expectations {
        assert_eq!(pattern_lookup(&patterns, module), expected);
    }
}
/// A regex pattern key restricts the override, here to q_proj in block 0 only.
#[test]
fn peft_pattern_lookup_regex_key() {
    let patterns = vec![("layers\\.0\\..*q_proj".to_string(), 64i32)];
    let in_block_0 = pattern_lookup(&patterns, "model.layers.0.self_attn.q_proj");
    let in_block_1 = pattern_lookup(&patterns, "model.layers.1.self_attn.q_proj");
    assert_eq!(in_block_0, Some(64));
    assert_eq!(in_block_1, None);
}
/// When several `rank_pattern` keys match, the first in JSON order wins.
/// Reversing the key order flips the winner, which rules out a sorted
/// tie-break.
#[test]
fn peft_rank_pattern_resolves_in_json_insertion_order_not_sorted() {
    let module = "model.layers.0.self_attn.q_proj";

    let forward = LoraConfig::from_json(
        r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"],
"rank_pattern": { "self_attn.q_proj": 11, ".*\\.q_proj": 22 }
}"#,
    )
    .unwrap();
    assert_eq!(
        forward.rank_for(module),
        11,
        "first-in-JSON-order key must win (a lexicographic sort would pick 22)"
    );

    let backward = LoraConfig::from_json(
        r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"],
"rank_pattern": { ".*\\.q_proj": 22, "self_attn.q_proj": 11 }
}"#,
    )
    .unwrap();
    assert_eq!(
        backward.rank_for(module),
        22,
        "with the order reversed the other key wins — pure insertion-order tie-break"
    );
}
/// When several `alpha_pattern` keys match, the first in JSON order wins.
///
/// With `r = 8`, picking alpha 40 gives scale 5.0 and picking alpha 80 gives
/// 10.0. Both orderings are checked. With only one ordering, a resolver that
/// simply preferred plain keys over regex keys would pass too.
#[test]
fn peft_alpha_pattern_resolves_in_json_insertion_order_not_sorted() {
    let json = r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"],
"alpha_pattern": { "q_proj": 40, ".*\\.q_proj": 80 }
}"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    assert_eq!(
        cfg.scale_for("model.layers.0.self_attn.q_proj"),
        5.0,
        "first-in-JSON-order key must win (a lexicographic sort would pick alpha 80)"
    );

    let reversed = r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"],
"alpha_pattern": { ".*\\.q_proj": 80, "q_proj": 40 }
}"#;
    let cfg2 = LoraConfig::from_json(reversed).unwrap();
    assert_eq!(
        cfg2.scale_for("model.layers.0.self_attn.q_proj"),
        10.0,
        "with the order reversed the regex key must win — pure insertion-order tie-break"
    );
}
/// A list-valued `target_modules` adapts only the named projection in every
/// block: all four q_proj layers, never v_proj.
#[test]
fn peft_select_target_modules_list() {
    let weights = peft_toy_weights(4);
    let mut params = HashMap::new();
    for b in 0..4 {
        params.insert(format!("model.layers.{b}.self_attn.q_proj"), plain_params());
    }
    let json = r#"{ "peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ["q_proj"] }"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 4).unwrap();
    assert_eq!(layers.len(), 4);
    assert!(layers.contains_key("model.layers.3.self_attn.q_proj"));
    assert!(!layers.contains_key("model.layers.0.self_attn.v_proj"));
}
/// A string-valued `target_modules` is treated as a regex over module names.
/// It selects every block's q_proj and no v_proj.
#[test]
fn peft_select_target_modules_regex() {
    let weights = peft_toy_weights(3);
    let mut params = HashMap::new();
    for b in 0..3 {
        params.insert(format!("model.layers.{b}.self_attn.q_proj"), plain_params());
    }
    let json = r#"{ "peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ".*self_attn\\.q_proj" }"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 3).unwrap();
    assert_eq!(layers.len(), 3);
    assert!(layers.contains_key("model.layers.2.self_attn.q_proj"));
    assert!(!layers.contains_key("model.layers.0.self_attn.v_proj"));
}
/// A list-valued `exclude_modules` removes v_proj from a broad `.*_proj`
/// target regex, leaving only the q_proj layers.
#[test]
fn peft_select_exclude_modules_list() {
    let weights = peft_toy_weights(2);
    let mut params = HashMap::new();
    for b in 0..2 {
        params.insert(format!("model.layers.{b}.self_attn.q_proj"), plain_params());
    }
    let json = r#"{ "peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ".*_proj", "exclude_modules": ["v_proj"] }"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 2).unwrap();
    assert_eq!(layers.len(), 2);
    assert!(layers.contains_key("model.layers.0.self_attn.q_proj"));
    assert!(!layers.contains_key("model.layers.0.self_attn.v_proj"));
}
/// A string-valued `exclude_modules` is treated as a regex. `.*\.v_proj`
/// removes every v_proj from a broad `.*_proj` target.
#[test]
fn peft_select_exclude_modules_regex() {
    let weights = peft_toy_weights(2);
    let mut params = HashMap::new();
    for b in 0..2 {
        params.insert(format!("model.layers.{b}.self_attn.q_proj"), plain_params());
    }
    let json = r#"{ "peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ".*_proj", "exclude_modules": ".*\\.v_proj" }"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 2).unwrap();
    assert_eq!(layers.len(), 2);
    assert!(!layers.contains_key("model.layers.1.self_attn.v_proj"));
}
/// A bare integer `layers_to_transform` restricts adaptation to that single
/// block index (block 1 of 4).
#[test]
fn peft_select_layers_to_transform_int() {
    let weights = peft_toy_weights(4);
    let mut params = HashMap::new();
    params.insert(
        "model.layers.1.self_attn.q_proj".to_string(),
        plain_params(),
    );
    let json = r#"{ "peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ["q_proj"], "layers_to_transform": 1 }"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 4).unwrap();
    assert_eq!(layers.len(), 1);
    assert!(layers.contains_key("model.layers.1.self_attn.q_proj"));
}
/// A list `layers_to_transform` adapts exactly the listed block indices
/// (0 and 3 of 5) and skips the rest.
#[test]
fn peft_select_layers_to_transform_list() {
    let weights = peft_toy_weights(5);
    let mut params = HashMap::new();
    for b in [0, 3] {
        params.insert(format!("model.layers.{b}.self_attn.q_proj"), plain_params());
    }
    let json = r#"{ "peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ["q_proj"], "layers_to_transform": [0, 3] }"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 5).unwrap();
    assert_eq!(layers.len(), 2);
    assert!(layers.contains_key("model.layers.0.self_attn.q_proj"));
    assert!(layers.contains_key("model.layers.3.self_attn.q_proj"));
    assert!(!layers.contains_key("model.layers.1.self_attn.q_proj"));
}
/// `layers_pattern` names the container segment that carries the block index
/// (`h` in `transformer.h.<i>`). `layers_to_transform` then works on
/// non-`model.layers` layouts.
#[test]
fn peft_select_layers_pattern_custom_attr() {
    let mut weights = Weights::new();
    for b in 0..3 {
        weights.insert(
            format!("transformer.h.{b}.attn.c_attn.weight"),
            base_weight(),
        );
    }
    let mut params = HashMap::new();
    params.insert("transformer.h.2.attn.c_attn".to_string(), plain_params());
    let json = r#"{ "peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ["c_attn"], "layers_to_transform": [2],
"layers_pattern": "h" }"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 3).unwrap();
    assert_eq!(layers.len(), 1);
    assert!(layers.contains_key("transformer.h.2.attn.c_attn"));
}
/// Without `layers_to_transform`, PEFT adapts every block, even past 16. No
/// trailing-window limit is applied.
#[test]
fn peft_select_no_restriction_adapts_all_blocks_over_16() {
    let weights = peft_toy_weights(20);
    let mut params = HashMap::new();
    for b in 0..20 {
        params.insert(format!("model.layers.{b}.self_attn.q_proj"), plain_params());
    }
    let json = r#"{ "peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ["q_proj"] }"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 20).unwrap();
    assert_eq!(layers.len(), 20, "PEFT must adapt ALL 20 blocks, no window");
    assert!(layers.contains_key("model.layers.0.self_attn.q_proj"));
    assert!(layers.contains_key("model.layers.19.self_attn.q_proj"));
}
/// The string `"all-linear"` parses to the `ModuleMatcher::AllLinear` sentinel,
/// not a regex. In use it adapts every q_proj/v_proj but not `lm_head`.
#[test]
fn peft_target_modules_all_linear_string_is_sentinel_not_regex() {
    let json = r#"{ "peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": "all-linear" }"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    let peft = match &cfg.selection {
        AdapterSelection::Peft(p) => p,
        other => panic!("expected a PEFT selection, got {other:?}"),
    };
    assert!(
        matches!(peft.target_modules, Some(ModuleMatcher::AllLinear)),
        "the `all-linear` string must parse to the AllLinear sentinel, not a regex"
    );

    let weights = peft_toy_weights(3);
    let mut params = HashMap::new();
    for b in 0..3 {
        params.insert(format!("model.layers.{b}.self_attn.q_proj"), plain_params());
        params.insert(format!("model.layers.{b}.self_attn.v_proj"), plain_params());
    }
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 3).unwrap();
    assert_eq!(
        layers.len(),
        6,
        "all-linear adapts every linear minus the head"
    );
    for b in 0..3 {
        assert!(layers.contains_key(&format!("model.layers.{b}.self_attn.q_proj")));
        assert!(layers.contains_key(&format!("model.layers.{b}.self_attn.v_proj")));
    }
    assert!(
        !layers.contains_key("lm_head"),
        "all-linear must EXCLUDE the output head (lm_head)"
    );
}
/// The `all-linear` sentinel is matched case-insensitively.
#[test]
fn peft_target_modules_all_linear_is_case_insensitive() {
    let spellings = ["All-Linear", "ALL-LINEAR"];
    for s in spellings.into_iter() {
        let raw = format!(r#"{{ "peft_type": "LORA", "r": 2, "target_modules": {s:?} }}"#);
        let cfg = LoraConfig::from_json(&raw).unwrap();
        let is_sentinel = matches!(
            &cfg.selection,
            AdapterSelection::Peft(p) if matches!(p.target_modules, Some(ModuleMatcher::AllLinear))
        );
        assert!(
            is_sentinel,
            "`{s}` must be recognized as the all-linear sentinel (case-insensitive)"
        );
    }
}
/// Under `all-linear`, `peft_module_is_selected` accepts rank-2 projection
/// weights. It rejects rank-1 weights such as a layernorm, and always skips
/// the output head (`lm_head` / `model.lm_head`) even though it is rank-2.
#[test]
fn peft_target_modules_all_linear_excludes_head_and_non_rank2() {
    let q_w = base_weight();
    let norm_w = Array::zeros::<f32>(&(8usize,)).unwrap();
    let head_w = base_weight();
    let peft = PeftSelection {
        target_modules: Some(ModuleMatcher::AllLinear),
        exclude_modules: None,
        layers_to_transform: None,
        layers_pattern: Vec::new(),
        rank_pattern: Vec::new(),
        alpha_pattern: Vec::new(),
        use_rslora: false,
        fan_in_fan_out: false,
    };
    assert!(peft_module_is_selected(
        "model.layers.0.self_attn.q_proj",
        &q_w,
        &peft
    ));
    assert!(
        !peft_module_is_selected("model.layers.0.input_layernorm", &norm_w, &peft),
        "a rank-1 weight is not a linear — all-linear must skip it"
    );
    assert!(
        !peft_module_is_selected("lm_head", &head_w, &peft),
        "the output head is excluded by all-linear even though it is rank-2"
    );
    assert!(!peft_module_is_selected("model.lm_head", &head_w, &peft));
}
/// A `List` matcher accepts a key that equals an entry or ends in
/// `.<entry>`. It rejects an undotted suffix and a mere prefix.
#[test]
fn module_matcher_list_is_exact_or_dotted_suffix() {
    let matcher = ModuleMatcher::List(vec![String::from("q_proj")]);
    // Accepted: dotted suffix, then exact equality.
    for accepted in ["model.layers.0.self_attn.q_proj", "q_proj"] {
        assert!(matcher.matches(accepted));
    }
    // Rejected: suffix without a dot boundary, then a prefix-only hit.
    for rejected in ["model.xq_proj", "q_proj_extra"] {
        assert!(!matcher.matches(rejected));
    }
}
/// `Regex` matchers must match the whole key, not merely find a substring.
/// A pattern matches neither a longer key it is a prefix of, nor a key it is
/// only a suffix of.
#[test]
fn module_matcher_regex_is_full_match() {
    let anchored = ModuleMatcher::Regex(Box::new(Regex::new(r".*\.q_proj").unwrap()));
    assert!(anchored.matches("model.layers.0.self_attn.q_proj"));
    assert!(!anchored.matches("model.layers.0.self_attn.q_proj.bias"));

    let bare = ModuleMatcher::Regex(Box::new(Regex::new(r"q_proj").unwrap()));
    assert!(bare.matches("q_proj"));
    assert!(!bare.matches("model.layers.0.self_attn.q_proj"));
}
/// `peft_layer_index` behaves as follows:
/// - with no patterns, it finds the block index in `model.layers.N...`;
/// - with a custom pattern such as `h`, it finds the index in `transformer.h.N...`;
/// - for keys with no block index (e.g. `lm_head`), it yields `None`.
#[test]
fn peft_layer_index_default_and_custom_pattern() {
    let no_patterns: &[String] = &[];
    let gpt2_patterns = ["h".to_string()];
    let cases = [
        ("model.layers.7.self_attn.q_proj", no_patterns, Some(7)),
        ("transformer.h.3.attn.c_attn", &gpt2_patterns[..], Some(3)),
        ("lm_head", no_patterns, None),
    ];
    for (key, patterns, expected) in cases {
        assert_eq!(peft_layer_index(key, patterns), expected);
    }
}
/// PEFT checkpoint keys are rewritten to the mlx-lm layout:
/// - the `base_model.model.` prefix is stripped;
/// - `lora_A.weight` / `lora_B.weight` become `lora_a` / `lora_b`, with their
///   two axes swapped;
/// - the magnitude vector becomes `m`;
/// - keys that are not adapter tensors are dropped.
#[test]
fn peft_key_translation_strips_prefix_maps_suffix_transposes() {
    let peft_path = "base_model.model.model.layers.0.self_attn.q_proj";
    let raw: HashMap<String, Array> = HashMap::from([
        (
            format!("{peft_path}.lora_A.weight"),
            Array::zeros::<f32>(&(2, 3)).unwrap(),
        ),
        (
            format!("{peft_path}.lora_B.weight"),
            Array::zeros::<f32>(&(4, 2)).unwrap(),
        ),
        (
            format!("{peft_path}.lora_magnitude_vector"),
            Array::zeros::<f32>(&(4usize,)).unwrap(),
        ),
        ("some.stray.weight".to_string(), base_weight()),
    ]);
    let translated = translate_peft_keys(raw).unwrap();
    assert_eq!(translated.len(), 3, "3 LoRA tensors, the stray key dropped");

    let path = "model.layers.0.self_attn.q_proj";
    let key = |suffix: &str| format!("{path}.{suffix}");
    assert_eq!(translated[&key("lora_a")].shape(), &[3, 2]);
    assert_eq!(translated[&key("lora_b")].shape(), &[2, 4]);
    assert_eq!(translated[&key("m")].shape(), &[4]);
}
/// The magnitude vector is also recognized under the `.weight`-suffixed
/// spelling `lora_magnitude_vector.weight`.
#[test]
fn peft_key_translation_magnitude_vector_dot_weight_variant() {
    let magnitude = Array::zeros::<f32>(&(2usize,)).unwrap();
    let raw: HashMap<String, Array> = HashMap::from([(
        String::from("base_model.model.q_proj.lora_magnitude_vector.weight"),
        magnitude,
    )]);
    let translated = translate_peft_keys(raw).unwrap();
    assert!(translated.contains_key("q_proj.m"));
}
/// End-to-end PEFT load with `use_rslora: true`. All four q_proj modules
/// written to the adapter are grafted, and the scale is
/// `lora_alpha / sqrt(r)` = 32 / 4 = 8.
#[test]
fn peft_end_to_end_rslora_scale_and_all_blocks() {
    let dir = std::env::temp_dir().join(format!("mlxrs_peft_e2e_{}", std::process::id()));
    std::fs::create_dir_all(&dir).unwrap();
    let cfg = r#"{
"peft_type": "LORA", "r": 16, "lora_alpha": 32.0,
"target_modules": ["self_attn.q_proj"], "use_rslora": true
}"#;
    let module_paths: Vec<String> = (0..4)
        .map(|block| format!("model.layers.{block}.self_attn.q_proj"))
        .collect();
    let module_refs: Vec<&str> = module_paths.iter().map(|p| p.as_str()).collect();
    write_mock_peft_adapter(&dir, cfg, &module_refs, 16, false, 0.01);

    let weights = toy_weights();
    let layers = load_adapters(&weights, &dir, None, 4).unwrap();
    assert_eq!(layers.len(), 4);
    match layers.get("model.layers.0.self_attn.q_proj") {
        Some(LoraLayer::Lora(layer)) => {
            assert_eq!(layer.scale(), 8.0, "rsLoRA scale must be alpha/sqrt(r)");
        }
        _ => panic!("expected a LoRA layer"),
    }
    std::fs::remove_dir_all(&dir).ok();
}
/// End-to-end PEFT load with `use_dora: true`: each grafted q_proj must come
/// back as a DoRA layer.
#[test]
fn peft_end_to_end_dora_with_magnitude_vector() {
    let dir =
        std::env::temp_dir().join(format!("mlxrs_peft_dora_e2e_{}", std::process::id()));
    std::fs::create_dir_all(&dir).unwrap();
    let cfg = r#"{
"peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ["self_attn.q_proj"], "use_dora": true
}"#;
    let module_paths: Vec<String> = (0..4)
        .map(|block| format!("model.layers.{block}.self_attn.q_proj"))
        .collect();
    let module_refs: Vec<&str> = module_paths.iter().map(|p| p.as_str()).collect();
    // NOTE(review): the `true` flag presumably asks the mock writer to emit
    // magnitude vectors — confirm against `write_mock_peft_adapter`.
    write_mock_peft_adapter(&dir, cfg, &module_refs, 2, true, 0.01);

    let weights = toy_weights();
    let layers = load_adapters(&weights, &dir, None, 4).unwrap();
    assert_eq!(layers.len(), 4);
    let first = layers.get("model.layers.0.self_attn.q_proj");
    assert!(matches!(first, Some(LoraLayer::Dora(_))));
    std::fs::remove_dir_all(&dir).ok();
}
/// `rank_pattern` overrides the rank, and therefore the scale, per module.
/// Block 0 uses r = 4, giving a scale of 8 / 4 = 2. The other blocks keep the
/// default r = 2, giving a scale of 8 / 2 = 4.
#[test]
fn peft_end_to_end_rank_pattern_per_module_scale() {
    let dir = std::env::temp_dir().join(format!("mlxrs_peft_rankpat_{}", std::process::id()));
    std::fs::create_dir_all(&dir).unwrap();
    let cfg = r#"{
"peft_type": "LORA", "r": 2, "lora_alpha": 8.0,
"target_modules": ["self_attn.q_proj"],
"rank_pattern": { "layers\\.0\\..*q_proj": 4 }
}"#;
    std::fs::write(dir.join("adapter_config.json"), cfg).unwrap();

    let mut tensors: HashMap<String, Array> = HashMap::new();
    for block in 0..4 {
        // Factor shapes must agree with the rank `rank_pattern` assigns.
        let rank = match block {
            0 => 4,
            _ => 2,
        };
        let path = format!("model.layers.{block}.self_attn.q_proj");
        let a_weight = Array::full::<f32>(&(rank, 3usize), 0.01).unwrap();
        let b_weight = Array::full::<f32>(&(2usize, rank), 0.01).unwrap();
        tensors.insert(format!("base_model.model.{path}.lora_A.weight"), a_weight);
        tensors.insert(format!("base_model.model.{path}.lora_B.weight"), b_weight);
    }
    crate::io::save_safetensors(&dir.join("adapter_model.safetensors"), &tensors).unwrap();

    let weights = toy_weights();
    let layers = load_adapters(&weights, &dir, None, 4).unwrap();
    assert_eq!(layers.len(), 4);
    match layers.get("model.layers.0.self_attn.q_proj") {
        Some(LoraLayer::Lora(block0)) => {
            assert_eq!(block0.scale(), 2.0, "rank_pattern block-0 scale = alpha/4");
        }
        _ => panic!("expected a LoRA layer at block 0"),
    }
    match layers.get("model.layers.1.self_attn.q_proj") {
        Some(LoraLayer::Lora(block1)) => {
            assert_eq!(block1.scale(), 4.0, "default-rank block-1 scale = alpha/2");
        }
        _ => panic!("expected a LoRA layer at block 1"),
    }
    std::fs::remove_dir_all(&dir).ok();
}
/// `exclude_modules` removes matches from a broad `target_modules` regex.
/// Every `*_proj` is targeted except `v_proj`, so only the q_proj layers are
/// grafted.
#[test]
fn peft_end_to_end_exclude_modules() {
    let dir = std::env::temp_dir().join(format!("mlxrs_peft_excl_{}", std::process::id()));
    std::fs::create_dir_all(&dir).unwrap();
    let cfg = r#"{
"peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ".*_proj", "exclude_modules": ".*\\.v_proj"
}"#;
    let weights = peft_toy_weights(3);
    let module_paths: Vec<String> = (0..3)
        .map(|block| format!("model.layers.{block}.self_attn.q_proj"))
        .collect();
    let module_refs: Vec<&str> = module_paths.iter().map(|p| p.as_str()).collect();
    write_mock_peft_adapter(&dir, cfg, &module_refs, 2, false, 0.01);

    let layers = load_adapters(&weights, &dir, None, 3).unwrap();
    assert_eq!(layers.len(), 3);
    for block in 0..3 {
        let q_key = format!("model.layers.{block}.self_attn.q_proj");
        let v_key = format!("model.layers.{block}.self_attn.v_proj");
        assert!(layers.contains_key(&q_key));
        assert!(!layers.contains_key(&v_key));
    }
    std::fs::remove_dir_all(&dir).ok();
}
/// With `fan_in_fan_out: true`, the checkpoint stores the base weight
/// transposed relative to the standard layout. Grafting must undo that, so
/// the adapted layer's output equals that of the same layer built from the
/// standard-layout weight with `fan_in_fan_out: false`.
#[test]
fn peft_fan_in_fan_out_transposes_base_weight() {
    let path = "model.layers.0.self_attn.q_proj";
    let standard_w = base_weight();
    // The same matrix in fan-in-fan-out (transposed) storage order.
    let fifo_w = standard_w.transpose().unwrap();

    let mut std_weights = Weights::new();
    std_weights.insert(format!("{path}.weight"), standard_w);
    let mut fifo_weights = Weights::new();
    fifo_weights.insert(format!("{path}.weight"), fifo_w);

    let mut params = HashMap::new();
    params.insert(path.to_string(), plain_params());

    let std_cfg = LoraConfig::from_json(
        r#"{ "peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ["q_proj"], "fan_in_fan_out": false }"#,
    )
    .unwrap();
    let fifo_cfg = LoraConfig::from_json(
        r#"{ "peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ["q_proj"], "fan_in_fan_out": true }"#,
    )
    .unwrap();
    let std_layers = linear_to_lora_layers(&std_weights, &std_cfg, &params, None, 1).unwrap();
    let fifo_layers = linear_to_lora_layers(&fifo_weights, &fifo_cfg, &params, None, 1).unwrap();

    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let mut std_out = std_layers[path].forward(&x).unwrap();
    let mut fifo_out = fifo_layers[path].forward(&x).unwrap();
    approx_eq(
        &fifo_out.to_vec::<f32>().unwrap(),
        &std_out.to_vec::<f32>().unwrap(),
        1e-5,
    );
}
/// Building a quantized base linear with `fan_in_fan_out` requested must fail
/// with a layer-keyed `InvariantViolation`, not be silently accepted.
#[test]
fn peft_fan_in_fan_out_quantized_is_err() {
    let prefix = "model.layers.0.self_attn.q_proj";
    let weight_key = format!("{prefix}.weight");
    let mut weights = Weights::new();
    weights.insert(weight_key.clone(), Array::zeros::<u32>(&(8, 4)).unwrap());
    weights.insert(format!("{prefix}.scales"), Array::zeros::<f32>(&(8, 4)).unwrap());
    weights.insert(format!("{prefix}.biases"), Array::zeros::<f32>(&(8, 4)).unwrap());

    let global = crate::lm::quant::Quantization {
        group_size: 32,
        bits: 4,
        mode: crate::lm::quant::QuantMode::Affine,
    };
    let per_layer = crate::lm::quant::PerLayerQuantization::from_global(global);
    let fan_in_fan_out = true;
    let err = build_base_linear(
        &weights,
        prefix,
        &weights[weight_key.as_str()],
        Some(&per_layer),
        fan_in_fan_out,
    )
    .unwrap_err();

    let is_keyed_invariant = matches!(
        err,
        Error::LayerKeyed(ref payload)
            if matches!(payload.inner(), Error::InvariantViolation(_))
    );
    assert!(is_keyed_invariant);
}
/// PEFT adapters are read from `adapter_model.safetensors`. Only that file
/// exists here (not mlx-lm's `adapters.safetensors`), so a successful load
/// shows the PEFT filename was used.
#[test]
fn peft_load_uses_adapter_model_safetensors_filename() {
    let dir = std::env::temp_dir().join(format!("mlxrs_peft_fname_{}", std::process::id()));
    std::fs::create_dir_all(&dir).unwrap();
    let cfg = r#"{ "peft_type": "LORA", "r": 2, "lora_alpha": 4.0,
"target_modules": ["self_attn.q_proj"] }"#;
    let module_paths: Vec<String> = (0..4)
        .map(|block| format!("model.layers.{block}.self_attn.q_proj"))
        .collect();
    let module_refs: Vec<&str> = module_paths.iter().map(|p| p.as_str()).collect();
    write_mock_peft_adapter(&dir, cfg, &module_refs, 2, false, 0.01);

    let mlxlm_file = dir.join("adapters.safetensors");
    let peft_file = dir.join("adapter_model.safetensors");
    assert!(!mlxlm_file.exists());
    assert!(peft_file.exists());

    let weights = toy_weights();
    let layers = load_adapters(&weights, &dir, None, 4).unwrap();
    assert_eq!(layers.len(), 4);
    std::fs::remove_dir_all(&dir).ok();
}
/// An mlx-lm-native config (no `peft_type`) still parses to the `MlxLm`
/// selection. Its per-module accessors return the global rank and scale for
/// any module name, and `fan_in_fan_out` is off.
#[test]
fn mlxlm_native_path_unchanged_by_peft_work() {
    let json = r#"{
"fine_tune_type": "lora", "num_layers": 8,
"lora_parameters": { "rank": 4, "scale": 16.0, "keys": ["q_proj"] }
}"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    let is_native_eight = matches!(cfg.selection, AdapterSelection::MlxLm { num_layers: 8 });
    assert!(is_native_eight);
    assert!(cfg.peft().is_none());

    let probe = "anything";
    assert_eq!(cfg.scale_for(probe), 16.0);
    assert_eq!(cfg.rank_for(probe), 4);
    assert!(!cfg.fan_in_fan_out());
}
/// Hand-traced dense DoRA forward. With `scale = 2`, the LoRA factors triple
/// each base row, so every adapted row has norm 3. Setting the magnitude to
/// 3 makes the DoRA rescale 1, so the output equals the plain LoRA output
/// `[3, 6]`.
#[test]
fn dora_linear_forward_matches_python_reference() {
    let magnitude = Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap();
    let layer = DoRALinear::new(
        BaseLinear::dense(base_weight(), None).unwrap(),
        AdapterParams {
            lora_a: lora_a(),
            lora_b: lora_b(),
            magnitude: Some(magnitude),
        },
        2.0,
    )
    .unwrap();
    let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let mut output = layer.forward(&input).unwrap();
    approx_eq(&output.to_vec::<f32>().unwrap(), &[3.0, 6.0], 1e-5);
}
/// Fusing a dense DoRA layer into its base must reproduce the unfused forward
/// output to within 1e-4.
#[test]
fn dora_linear_fuse_into_base_round_trip() {
    let layer = DoRALinear::new(
        BaseLinear::dense(base_weight(), None).unwrap(),
        AdapterParams {
            lora_a: lora_a(),
            lora_b: lora_b(),
            magnitude: Some(Array::from_slice::<f32>(&[1.5, 2.5], &(2usize,)).unwrap()),
        },
        2.0,
    )
    .unwrap();
    let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let mut expected = layer.forward(&input).unwrap();
    let fused_base = layer.fuse(false).unwrap();
    let mut actual = fused_base.base_output(&input).unwrap();
    approx_eq(
        &actual.to_vec::<f32>().unwrap(),
        &expected.to_vec::<f32>().unwrap(),
        1e-4,
    );
}
/// With zero LoRA factors and unit magnitudes, a DoRA embedding over a 3×3
/// identity table returns the looked-up identity rows unchanged.
#[test]
fn dora_embedding_forward_matches_python_reference() {
    let (vocab, dims, rank) = (3usize, 3usize, 2usize);
    // Row-major identity: 1.0 on the diagonal, 0.0 elsewhere.
    let identity: Vec<f32> = (0..vocab * dims)
        .map(|i| if i / dims == i % dims { 1.0 } else { 0.0 })
        .collect();
    let base =
        BaseEmbedding::dense(Array::from_slice::<f32>(&identity, &(vocab, dims)).unwrap())
            .unwrap();
    let params = AdapterParams {
        lora_a: Array::zeros::<f32>(&(vocab, rank)).unwrap(),
        lora_b: Array::zeros::<f32>(&(rank, dims)).unwrap(),
        magnitude: Some(Array::from_slice::<f32>(&[1.0; 3], &(vocab,)).unwrap()),
    };
    let layer = DoRAEmbedding::new(base, params, 2.0).unwrap();
    let token_ids = Array::from_slice::<i32>(&[0, 2], &(2usize,)).unwrap();
    let mut rows = layer.forward(&token_ids).unwrap();
    approx_eq(
        &rows.to_vec::<f32>().unwrap(),
        &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
        1e-5,
    );
}
/// Each looked-up row is rescaled to its magnitude. With zero LoRA factors, a
/// unit identity row with `m = 0.5` comes back halved.
#[test]
fn dora_embedding_forward_per_token_renorm_halves() {
    let (vocab, dims, rank) = (3usize, 3usize, 2usize);
    // Row-major identity: 1.0 on the diagonal, 0.0 elsewhere.
    let identity: Vec<f32> = (0..vocab * dims)
        .map(|i| if i / dims == i % dims { 1.0 } else { 0.0 })
        .collect();
    let base =
        BaseEmbedding::dense(Array::from_slice::<f32>(&identity, &(vocab, dims)).unwrap())
            .unwrap();
    let params = AdapterParams {
        lora_a: Array::zeros::<f32>(&(vocab, rank)).unwrap(),
        lora_b: Array::zeros::<f32>(&(rank, dims)).unwrap(),
        magnitude: Some(Array::from_slice::<f32>(&[0.5; 3], &(vocab,)).unwrap()),
    };
    let layer = DoRAEmbedding::new(base, params, 1.0).unwrap();
    let token = Array::from_slice::<i32>(&[1], &(1usize,)).unwrap();
    let mut row = layer.forward(&token).unwrap();
    approx_eq(&row.to_vec::<f32>().unwrap(), &[0.0, 0.5, 0.0], 1e-5);
}
/// `as_linear` applies the embedding table as a linear map. Over an identity
/// table with zero LoRA factors and unit magnitudes, it returns its input
/// unchanged.
#[test]
fn dora_embedding_as_linear_one_hot_identity() {
    let (vocab, dims, rank) = (3usize, 3usize, 2usize);
    // Row-major identity: 1.0 on the diagonal, 0.0 elsewhere.
    let identity: Vec<f32> = (0..vocab * dims)
        .map(|i| if i / dims == i % dims { 1.0 } else { 0.0 })
        .collect();
    let base =
        BaseEmbedding::dense(Array::from_slice::<f32>(&identity, &(vocab, dims)).unwrap())
            .unwrap();
    let params = AdapterParams {
        lora_a: Array::zeros::<f32>(&(vocab, rank)).unwrap(),
        lora_b: Array::zeros::<f32>(&(rank, dims)).unwrap(),
        magnitude: Some(Array::from_slice::<f32>(&[1.0; 3], &(vocab,)).unwrap()),
    };
    let layer = DoRAEmbedding::new(base, params, 2.0).unwrap();
    let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let mut projected = layer.as_linear(&input).unwrap();
    approx_eq(&projected.to_vec::<f32>().unwrap(), &[1.0, 2.0, 3.0], 1e-5);
}
/// Fusing a DoRA embedding must not change its `as_linear` output (to within
/// 1e-4) for a non-trivial table, non-zero factors and distinct magnitudes.
#[test]
fn dora_embedding_fuse_round_trip() {
    let (vocab, dims, rank) = (3usize, 3usize, 2usize);
    #[rustfmt::skip]
    let table: [f32; 9] = [
        1.0, 0.5, 0.0,
        0.0, 1.0, 0.5,
        0.5, 0.0, 1.0,
    ];
    let layer = DoRAEmbedding::new(
        BaseEmbedding::dense(Array::from_slice::<f32>(&table, &(vocab, dims)).unwrap()).unwrap(),
        AdapterParams {
            lora_a: Array::from_slice::<f32>(&[0.1, 0.0, 0.0, 0.1, 0.1, 0.1], &(vocab, rank))
                .unwrap(),
            lora_b: Array::from_slice::<f32>(&[0.2, 0.0, 0.1, 0.0, 0.1, 0.2], &(rank, dims))
                .unwrap(),
            magnitude: Some(Array::from_slice::<f32>(&[1.5, 2.0, 1.2], &(vocab,)).unwrap()),
        },
        2.0,
    )
    .unwrap();
    let input = Array::from_slice::<f32>(&[1.0, 2.0, 0.5], &(1, dims)).unwrap();
    let mut before = layer.as_linear(&input).unwrap();
    let fused = layer.fuse().unwrap();
    let mut after = fused.as_linear(&input).unwrap();
    approx_eq(
        &after.to_vec::<f32>().unwrap(),
        &before.to_vec::<f32>().unwrap(),
        1e-4,
    );
}
/// Constructing a DoRA embedding without a magnitude vector must fail with
/// `Error::MissingField` naming `magnitude` and the `DoRAEmbedding::new`
/// constructor.
#[test]
fn dora_embedding_requires_magnitude() {
    let (vocab, dims, rank) = (3usize, 3usize, 2usize);
    let base = BaseEmbedding::dense(Array::zeros::<f32>(&(vocab, dims)).unwrap()).unwrap();
    let params_without_m = AdapterParams {
        lora_a: Array::zeros::<f32>(&(vocab, rank)).unwrap(),
        lora_b: Array::zeros::<f32>(&(rank, dims)).unwrap(),
        magnitude: None,
    };
    let err = DoRAEmbedding::new(base, params_without_m, 1.0).unwrap_err();
    let names_magnitude = matches!(&err, Error::MissingField(p)
        if p.type_name() == "DoRAEmbedding::new" && p.field().contains("magnitude"));
    assert!(
        names_magnitude,
        "expected Error::MissingField naming `magnitude`, got {err:?}"
    );
}
/// A `lora_a` whose leading axis (2) disagrees with the number of embeddings
/// (3) must be rejected with a `LengthMismatch` that names `lora_a`.
#[test]
fn dora_embedding_rejects_wrong_factor_shape() {
    let (vocab, dims, rank) = (3usize, 3usize, 2usize);
    let base = BaseEmbedding::dense(Array::zeros::<f32>(&(vocab, dims)).unwrap()).unwrap();
    let params = AdapterParams {
        // Two rows where the table has three embeddings.
        lora_a: Array::zeros::<f32>(&(2usize, rank)).unwrap(),
        lora_b: Array::zeros::<f32>(&(rank, dims)).unwrap(),
        magnitude: Some(Array::zeros::<f32>(&(vocab,)).unwrap()),
    };
    let err = DoRAEmbedding::new(base, params, 1.0).unwrap_err();
    let is_leading_axis_mismatch = matches!(&err, Error::LengthMismatch(p)
        if p.expected() == vocab && p.actual() == 2
            && p.context().contains("lora_a"));
    assert!(
        is_leading_axis_mismatch,
        "expected Error::LengthMismatch for wrong leading axis, got {err:?}"
    );
}
/// A DoRA layer over an 8-bit affine-quantized base must match the same layer
/// over the dense weight, to within quantization error (2e-2). The magnitude
/// is set to the per-row norm of the dense adapted weight.
#[test]
fn qdora_linear_forward_matches_python_reference() {
    let (input_dims, output_dims) = (64usize, 2usize);
    // Row 0 is all 1.0, row 1 is all 0.5.
    let wdata = [vec![1.0f32; input_dims], vec![0.5f32; input_dims]].concat();
    let dense_w = Array::from_slice::<f32>(&wdata, &(output_dims, input_dims)).unwrap();
    let la = Array::full::<f32>(&(input_dims, 2usize), 0.01).unwrap();
    let lb = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap();
    let scale = 2.0f32;

    // Magnitude = norm over axis 1 of the dense adapted weight W + ΔW.
    let delta = lora_delta(
        &AdapterParams {
            lora_a: la.try_clone().unwrap(),
            lora_b: lb.try_clone().unwrap(),
            magnitude: None,
        },
        scale,
    )
    .unwrap();
    let adapted = dense_w.add(&delta).unwrap();
    let m = ops::linalg_full::norm(&adapted, 2.0, &[1], false).unwrap();

    let dense_layer = DoRALinear::new(
        BaseLinear::dense(dense_w.try_clone().unwrap(), None).unwrap(),
        AdapterParams {
            lora_a: la.try_clone().unwrap(),
            lora_b: lb.try_clone().unwrap(),
            magnitude: Some(m.try_clone().unwrap()),
        },
        scale,
    )
    .unwrap();
    let x = Array::full::<f32>(&(1usize, input_dims), 1.0).unwrap();
    let mut dense_out = dense_layer.forward(&x).unwrap();

    let (w_q, scales, biases) =
        ops::quantized::quantize(&dense_w, 32, 8, "affine", None).unwrap();
    let q_layer = DoRALinear::new(
        BaseLinear::quantized(w_q, scales, biases, None, 32, 8, "affine".to_string()).unwrap(),
        AdapterParams {
            lora_a: la,
            lora_b: lb,
            magnitude: Some(m),
        },
        scale,
    )
    .unwrap();
    let mut q_out = q_layer.forward(&x).unwrap();
    approx_eq(
        &q_out.to_vec::<f32>().unwrap(),
        &dense_out.to_vec::<f32>().unwrap(),
        2e-2,
    );
}
/// `fuse(true)` on a quantized DoRA layer must yield a dense base. That base's
/// output must match the unfused quantized forward to within 2e-2.
#[test]
fn qdora_linear_fuse_round_trip() {
    let (input_dims, output_dims) = (64usize, 2usize);
    // Row 0 is all 1.0, row 1 is all 0.5.
    let wdata = [vec![1.0f32; input_dims], vec![0.5f32; input_dims]].concat();
    let dense_w = Array::from_slice::<f32>(&wdata, &(output_dims, input_dims)).unwrap();
    let (w_q, scales, biases) =
        ops::quantized::quantize(&dense_w, 32, 8, "affine", None).unwrap();
    let q_layer = DoRALinear::new(
        BaseLinear::quantized(w_q, scales, biases, None, 32, 8, "affine".to_string()).unwrap(),
        AdapterParams {
            lora_a: Array::full::<f32>(&(input_dims, 2usize), 0.01).unwrap(),
            lora_b: Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap(),
            magnitude: Some(Array::from_slice::<f32>(&[1.5, 2.5], &(output_dims,)).unwrap()),
        },
        2.0,
    )
    .unwrap();

    let x = Array::full::<f32>(&(1usize, input_dims), 1.0).unwrap();
    let mut via_forward = q_layer.forward(&x).unwrap();
    let fused = q_layer.fuse(true).unwrap();
    assert!(matches!(fused, BaseLinear::Dense { .. }));
    let mut via_fused = fused.base_output(&x).unwrap();
    approx_eq(
        &via_fused.to_vec::<f32>().unwrap(),
        &via_forward.to_vec::<f32>().unwrap(),
        2e-2,
    );
}
/// An mlx-lm `dora` adapter directory loads as DoRA layers on every q_proj.
/// Each layer carries a magnitude vector of shape `[2]`.
#[test]
fn load_dora_adapter_from_safetensors() {
    let dir = std::env::temp_dir().join(format!("mlxrs_a2_dora_load_{}", std::process::id()));
    std::fs::create_dir_all(&dir).unwrap();
    write_mock_adapter(&dir, "dora", true);

    let weights = toy_weights();
    let layers = load_adapters(&weights, &dir, None, 4).unwrap();
    assert_eq!(layers.len(), 4);
    let keys = (0..4).map(|block| format!("model.layers.{block}.self_attn.q_proj"));
    for key in keys {
        if let Some(LoraLayer::Dora(dora)) = layers.get(&key) {
            assert_eq!(dora.magnitude().shape(), &[2]);
        } else {
            let unexpected = layers.get(&key);
            panic!("expected DoRA layer at {key}, got {unexpected:?}");
        }
    }
    std::fs::remove_dir_all(&dir).ok();
}
/// `fine_tune_type: Dora` alone (with `use_dora` left `false`) must graft DoRA
/// layers onto exactly the configured key, here the `q_proj` of each of the
/// 4 blocks. Other projections and the head must be left untouched.
///
/// `num_layers` (16) is deliberately larger than the 4 blocks passed in, and
/// all 4 blocks are still expected to be adapted.
#[test]
fn linear_to_dora_layers_grafts_correctly() {
    let weights = toy_weights();
    let cfg = LoraConfig {
        fine_tune_type: FineTuneType::Dora,
        // `keyed_params` = rank 2, scale 2.0, no alpha, no dropout.
        lora_parameters: keyed_params(vec!["self_attn.q_proj".to_string()]),
        use_dora: false,
        selection: AdapterSelection::MlxLm { num_layers: 16 },
    };
    let mut params = HashMap::new();
    for b in 0..4 {
        params.insert(
            format!("model.layers.{b}.self_attn.q_proj"),
            AdapterParams {
                lora_a: lora_a(),
                lora_b: lora_b(),
                magnitude: Some(Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap()),
            },
        );
    }
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 4).unwrap();
    assert_eq!(layers.len(), 4);
    for b in 0..4 {
        let key = format!("model.layers.{b}.self_attn.q_proj");
        assert!(
            matches!(layers.get(&key), Some(LoraLayer::Dora(_))),
            "expected DoRA layer at {key}"
        );
    }
    assert!(!layers.contains_key("model.layers.0.self_attn.k_proj"));
    assert!(!layers.contains_key("lm_head"));
}
/// Round `x` through IEEE half precision (`f16`) and widen the result to
/// `f64`.
fn f16_rt(x: f32) -> f64 {
    let narrowed = half::f16::from_f32(x);
    narrowed.to_f64()
}
/// Round `x` through bfloat16 and widen the result to `f64`.
fn bf16_rt(x: f32) -> f64 {
    let narrowed = half::bf16::from_f32(x);
    narrowed.to_f64()
}
/// Mixed-precision DoRA-embedding fixture: a 4×4 base table, rank-2 LoRA
/// factors, per-row magnitudes and a unit scale.
///
/// The factors are chosen so that the LoRA delta almost cancels the base:
/// - `lora_a[e] = [w_e, 0]` and `lora_b[0] = [-0.99853; 4]`;
/// - so row `e` of `lora_a @ lora_b` is `-0.99853 * weight[e]`;
/// - so the adapted row `weight + delta` keeps only about 0.15% of the base
///   row.
///
/// The regression-oracle test relies on this cancellation to expose rounding
/// the delta to the narrow dtype before the norm.
#[allow(clippy::type_complexity)] // a one-off 5-tuple reads better than a dedicated struct here
fn mp_fixture() -> (
    Vec<Vec<f32>>, // weight_f32 [4][4]
    Vec<Vec<f32>>, // lora_a_f32 [4][2]
    Vec<Vec<f32>>, // lora_b_f32 [2][4]
    Vec<f32>,      // m_f32 [4]
    f32,           // scale
) {
    // One value per embedding row; every base row is constant across columns.
    let row_values = [1.0f32, 0.5, 0.25, -0.75];
    let weight_rows: Vec<Vec<f32>> = row_values.iter().map(|&v| vec![v; 4]).collect();
    let a_rows: Vec<Vec<f32>> = row_values.iter().map(|&v| vec![v, 0.0]).collect();
    let b_rows = vec![vec![-0.99853f32; 4], vec![0.0f32; 4]];
    let magnitudes = vec![1.0f32, 0.5, 0.25, 0.75];
    (weight_rows, a_rows, b_rows, magnitudes, 1.0)
}
/// Scalar f64 oracle for `DoRAEmbedding::forward` on a mixed-precision layer:
/// the base table is held in a narrow float type and the adapter factors in
/// f32.
///
/// For each token id `tid` in `ids` it models:
/// - `y`: the base row `weight_f32[tid]`, rounded through `rt`;
/// - `z`: the LoRA row `scale * (lora_a_f32[tid] @ lora_b_f32)`, computed in
///   f64;
/// - the pre-normalisation output `rt(y + rt(z))`;
/// - the DoRA rescale `m_f32[tid] / ‖y + z'‖₂`, where `z' = rt(z)` when
///   `cast_z_upfront` is `true` and `z' = z` (unrounded) otherwise.
///
/// `rt` rounds an `f32` through the narrow type and widens it to `f64`
/// (e.g. `f16_rt` / `bf16_rt`). The regression-oracle test runs this with
/// `cast_z_upfront = true` as the "upfront cast" pipeline it must beat.
///
/// Returns `ids.len() * dims` values as `f32`, one `dims`-long row per id.
/// Panics via plain indexing on out-of-range ids, an empty `weight_f32`, or
/// ragged factor tables.
// Every stage of the modelled pipeline is an explicit argument.
#[allow(clippy::too_many_arguments)]
fn forward_scalar_reference(
weight_f32: &[Vec<f32>],
lora_a_f32: &[Vec<f32>],
lora_b_f32: &[Vec<f32>],
m_f32: &[f32],
scale: f32,
ids: &[usize],
rt: fn(f32) -> f64,
cast_z_upfront: bool,
) -> Vec<f32> {
let dims = weight_f32[0].len();
let r = lora_a_f32[0].len();
let scale_f64 = scale as f64;
let mut out = Vec::with_capacity(ids.len() * dims);
for &tid in ids {
// Base row as stored in the narrow dtype.
let y_rt: Vec<f64> = weight_f32[tid].iter().map(|&w| rt(w)).collect();
// LoRA row for this token, accumulated in f64 with no rounding.
let mut z_uncast = vec![0.0f64; dims];
for d in 0..dims {
let mut acc = 0.0f64;
for k in 0..r {
acc += (lora_a_f32[tid][k] as f64) * (lora_b_f32[k][d] as f64);
}
z_uncast[d] = scale_f64 * acc;
}
// The same LoRA row after rounding to f32 and then through `rt`.
let z_cast: Vec<f64> = z_uncast.iter().map(|&v| rt(v as f32)).collect();
// Pre-normalisation output: narrow y plus narrow z, rounded through `rt` again.
let out_pre: Vec<f64> = (0..dims)
.map(|d| rt((y_rt[d] + z_cast[d]) as f32))
.collect();
// Only the "upfront cast" model feeds the rounded delta into the norm.
let z_for_norm = if cast_z_upfront { &z_cast } else { &z_uncast };
// L2 norm of the adapted row and the resulting per-row rescale factor.
let adapted: Vec<f64> = (0..dims).map(|d| y_rt[d] + z_for_norm[d]).collect();
let denom = adapted.iter().map(|v| v * v).sum::<f64>().sqrt();
let norm_scale = (m_f32[tid] as f64) / denom;
// Rescale in f64, then emit as f32.
for &op in &out_pre {
let scaled = norm_scale * op;
out.push(scaled as f32);
}
}
out
}
/// Scalar f64 oracle for `DoRAEmbedding::as_linear` on a mixed-precision layer.
/// The embedding table is applied as a linear map (output `e` for input row
/// `x` is `x · W_e`), followed by the per-embedding DoRA rescale.
///
/// Steps modelled:
/// - `delta[e] = scale * (lora_a_f32[e] @ lora_b_f32)` in f64. It is rounded
///   through `rt` before the norm only when `cast_delta_upfront` is `true`.
/// - `norm_scale[e] = m_f32[e] / ‖rt(W_e) + delta[e]‖₂`.
/// - For each input row `x` (rounded through `rt`):
///   - `y = x · rt(W_e)`;
///   - `z = scale * Σ_k (x · lora_b_f32[k]) * lora_a_f32[e][k]`;
///   - `out = norm_scale[e] * (y + rt(z))`.
///
/// `rt` rounds an `f32` through the narrow type and widens it to `f64`
/// (e.g. `f16_rt` / `bf16_rt`).
///
/// Returns `x_f32.len() * num_embeddings` values as `f32`, one row of
/// `num_embeddings` outputs per input row. Panics via plain indexing on an
/// empty `weight_f32` or ragged tables.
// Every stage of the modelled pipeline is an explicit argument.
#[allow(clippy::too_many_arguments)]
fn as_linear_scalar_reference(
weight_f32: &[Vec<f32>],
lora_a_f32: &[Vec<f32>],
lora_b_f32: &[Vec<f32>],
m_f32: &[f32],
scale: f32,
x_f32: &[Vec<f32>],
rt: fn(f32) -> f64,
cast_delta_upfront: bool,
) -> Vec<f32> {
let num_embeddings = weight_f32.len();
let dims = weight_f32[0].len();
let r = lora_a_f32[0].len();
let batch = x_f32.len();
let scale_f64 = scale as f64;
// Per-embedding LoRA delta, accumulated in f64 with no rounding.
let mut delta_uncast = vec![vec![0.0f64; dims]; num_embeddings];
for e in 0..num_embeddings {
for d in 0..dims {
let mut acc = 0.0f64;
for k in 0..r {
acc += (lora_a_f32[e][k] as f64) * (lora_b_f32[k][d] as f64);
}
delta_uncast[e][d] = scale_f64 * acc;
}
}
// The same delta after rounding to f32 and then through `rt`.
let delta_cast: Vec<Vec<f64>> = delta_uncast
.iter()
.map(|row| row.iter().map(|&v| rt(v as f32)).collect())
.collect();
// Only the "upfront cast" model feeds the rounded delta into the norm.
let delta_for_norm = if cast_delta_upfront {
&delta_cast
} else {
&delta_uncast
};
// Adapted table rows: narrow base row plus the selected delta.
let mut adapted = vec![vec![0.0f64; dims]; num_embeddings];
for e in 0..num_embeddings {
for d in 0..dims {
adapted[e][d] = rt(weight_f32[e][d]) + delta_for_norm[e][d];
}
}
// Per-row L2 norms and the resulting DoRA rescale factors.
let denom: Vec<f64> = adapted
.iter()
.map(|row| row.iter().map(|v| v * v).sum::<f64>().sqrt())
.collect();
let norm_scale: Vec<f64> = (0..num_embeddings)
.map(|e| (m_f32[e] as f64) / denom[e])
.collect();
let mut out = Vec::with_capacity(batch * num_embeddings);
for x_row in x_f32 {
// Input row as stored in the narrow dtype.
let x_rt: Vec<f64> = x_row.iter().map(|&v| rt(v)).collect();
for e in 0..num_embeddings {
// Base contribution: x · rt(W_e).
let mut y_be = 0.0f64;
for d in 0..dims {
y_be += x_rt[d] * rt(weight_f32[e][d]);
}
// LoRA contribution via the low-rank path: (x · B_k) for each k, then · A_e.
let xb: Vec<f64> = (0..r)
.map(|k| {
(0..dims)
.map(|d| x_rt[d] * (lora_b_f32[k][d] as f64))
.sum::<f64>()
})
.collect();
let z_be: f64 = (0..r).map(|k| xb[k] * (lora_a_f32[e][k] as f64)).sum();
// Scaled LoRA term is rounded through `rt` before being added.
let scaled_z_be = scale_f64 * z_be;
let scaled_z_cast = rt(scaled_z_be as f32);
let out_pre = y_be + scaled_z_cast;
out.push((norm_scale[e] * out_pre) as f32);
}
}
out
}
/// `forward` on an f16 base with f32 adapters must return f32. Its output
/// must match the scalar reference that keeps the LoRA delta unrounded for
/// the norm, to within 5e-3.
#[test]
fn dora_embedding_forward_mixed_precision_matches_reference_f16_base_f32_adapter() {
    let (table, a_rows, b_rows, magnitudes, scale) = mp_fixture();
    let (vocab, dims, rank) = (table.len(), table[0].len(), a_rows[0].len());
    let weight_f16 = Array::from_slice::<f32>(&table.concat(), &(vocab, dims))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let params = AdapterParams {
        lora_a: Array::from_slice::<f32>(&a_rows.concat(), &(vocab, rank)).unwrap(),
        lora_b: Array::from_slice::<f32>(&b_rows.concat(), &(rank, dims)).unwrap(),
        magnitude: Some(Array::from_slice::<f32>(&magnitudes, &(vocab,)).unwrap()),
    };
    let layer =
        DoRAEmbedding::new(BaseEmbedding::dense(weight_f16).unwrap(), params, scale).unwrap();

    let id_values: Vec<i32> = (0..vocab as i32).collect();
    let ids = Array::from_slice::<i32>(&id_values, &(vocab,)).unwrap();
    let out = layer.forward(&ids).unwrap();
    assert_eq!(
        out.dtype().unwrap(),
        Dtype::F32,
        "forward must return the promoted dtype = f32 (f16 base × f32 adapter)"
    );
    let mut out_f32 = out.astype(Dtype::F32).unwrap();
    let got = out_f32.to_vec::<f32>().unwrap();

    let token_ids: Vec<usize> = (0..vocab).collect();
    let want = forward_scalar_reference(
        &table,
        &a_rows,
        &b_rows,
        &magnitudes,
        scale,
        &token_ids,
        f16_rt,
        false,
    );
    approx_eq(&got, &want, 5e-3);
}
/// `forward` on a bf16 base with f32 adapters must return f32. Its output
/// must match the scalar reference that keeps the LoRA delta unrounded for
/// the norm, to within bf16 tolerance (5e-2).
#[test]
fn dora_embedding_forward_mixed_precision_matches_reference_bf16_base_f32_adapter() {
    let (table, a_rows, b_rows, magnitudes, scale) = mp_fixture();
    let (vocab, dims, rank) = (table.len(), table[0].len(), a_rows[0].len());
    let weight_bf16 = Array::from_slice::<f32>(&table.concat(), &(vocab, dims))
        .unwrap()
        .astype(Dtype::BF16)
        .unwrap();
    let params = AdapterParams {
        lora_a: Array::from_slice::<f32>(&a_rows.concat(), &(vocab, rank)).unwrap(),
        lora_b: Array::from_slice::<f32>(&b_rows.concat(), &(rank, dims)).unwrap(),
        magnitude: Some(Array::from_slice::<f32>(&magnitudes, &(vocab,)).unwrap()),
    };
    let layer =
        DoRAEmbedding::new(BaseEmbedding::dense(weight_bf16).unwrap(), params, scale).unwrap();

    let id_values: Vec<i32> = (0..vocab as i32).collect();
    let ids = Array::from_slice::<i32>(&id_values, &(vocab,)).unwrap();
    let out = layer.forward(&ids).unwrap();
    assert_eq!(
        out.dtype().unwrap(),
        Dtype::F32,
        "forward must return the promoted dtype = f32 (bf16 base × f32 adapter)"
    );
    let mut out_f32 = out.astype(Dtype::F32).unwrap();
    let got = out_f32.to_vec::<f32>().unwrap();

    let token_ids: Vec<usize> = (0..vocab).collect();
    let want = forward_scalar_reference(
        &table,
        &a_rows,
        &b_rows,
        &magnitudes,
        scale,
        &token_ids,
        bf16_rt,
        false,
    );
    approx_eq(&got, &want, 5e-2);
}
/// `as_linear` on an f16 base with f32 adapters, fed f16 inputs, must match
/// the scalar `as_linear` reference with an unrounded delta, to within 5e-3.
#[test]
fn dora_embedding_as_linear_mixed_precision_matches_reference_f16_base_f32_adapter() {
    let (table, a_rows, b_rows, magnitudes, scale) = mp_fixture();
    let (vocab, dims, rank) = (table.len(), table[0].len(), a_rows[0].len());
    let weight_f16 = Array::from_slice::<f32>(&table.concat(), &(vocab, dims))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let params = AdapterParams {
        lora_a: Array::from_slice::<f32>(&a_rows.concat(), &(vocab, rank)).unwrap(),
        lora_b: Array::from_slice::<f32>(&b_rows.concat(), &(rank, dims)).unwrap(),
        magnitude: Some(Array::from_slice::<f32>(&magnitudes, &(vocab,)).unwrap()),
    };
    let layer =
        DoRAEmbedding::new(BaseEmbedding::dense(weight_f16).unwrap(), params, scale).unwrap();

    let inputs: Vec<Vec<f32>> = vec![vec![1.0, 1.0, 1.0, 1.0], vec![0.5, -0.25, 0.75, -0.125]];
    let x_arr = Array::from_slice::<f32>(&inputs.concat(), &(inputs.len(), dims))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let out = layer.as_linear(&x_arr).unwrap();
    let mut out_f32 = out.astype(Dtype::F32).unwrap();
    let got = out_f32.to_vec::<f32>().unwrap();

    let want = as_linear_scalar_reference(
        &table,
        &a_rows,
        &b_rows,
        &magnitudes,
        scale,
        &inputs,
        f16_rt,
        false,
    );
    approx_eq(&got, &want, 5e-3);
}
/// `as_linear` on a bf16 base with f32 adapters, fed bf16 inputs, must match
/// the scalar `as_linear` reference with an unrounded delta, to within 5e-2.
#[test]
fn dora_embedding_as_linear_mixed_precision_matches_reference_bf16_base_f32_adapter() {
    let (table, a_rows, b_rows, magnitudes, scale) = mp_fixture();
    let (vocab, dims, rank) = (table.len(), table[0].len(), a_rows[0].len());
    let weight_bf16 = Array::from_slice::<f32>(&table.concat(), &(vocab, dims))
        .unwrap()
        .astype(Dtype::BF16)
        .unwrap();
    let params = AdapterParams {
        lora_a: Array::from_slice::<f32>(&a_rows.concat(), &(vocab, rank)).unwrap(),
        lora_b: Array::from_slice::<f32>(&b_rows.concat(), &(rank, dims)).unwrap(),
        magnitude: Some(Array::from_slice::<f32>(&magnitudes, &(vocab,)).unwrap()),
    };
    let layer =
        DoRAEmbedding::new(BaseEmbedding::dense(weight_bf16).unwrap(), params, scale).unwrap();

    let inputs: Vec<Vec<f32>> = vec![vec![1.0, 1.0, 1.0, 1.0], vec![0.5, -0.25, 0.75, -0.125]];
    let x_arr = Array::from_slice::<f32>(&inputs.concat(), &(inputs.len(), dims))
        .unwrap()
        .astype(Dtype::BF16)
        .unwrap();
    let out = layer.as_linear(&x_arr).unwrap();
    let mut out_f32 = out.astype(Dtype::F32).unwrap();
    let got = out_f32.to_vec::<f32>().unwrap();

    let want = as_linear_scalar_reference(
        &table,
        &a_rows,
        &b_rows,
        &magnitudes,
        scale,
        &inputs,
        bf16_rt,
        false,
    );
    approx_eq(&got, &want, 5e-2);
}
/// Regression oracle against rounding the LoRA delta to the narrow dtype
/// before the norm. On the cancellation fixture, all of the following must
/// hold:
/// - the layer tracks the "unrounded delta" reference to within 5e-3;
/// - the "upfront cast" reference is off by at least 1e-2;
/// - the layer is at least 5× closer to the former than to the latter.
#[test]
fn dora_embedding_forward_loses_precision_with_upfront_cast_regression_oracle() {
    let (table, a_rows, b_rows, magnitudes, scale) = mp_fixture();
    let (vocab, dims, rank) = (table.len(), table[0].len(), a_rows[0].len());
    let weight_f16 = Array::from_slice::<f32>(&table.concat(), &(vocab, dims))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let params = AdapterParams {
        lora_a: Array::from_slice::<f32>(&a_rows.concat(), &(vocab, rank)).unwrap(),
        lora_b: Array::from_slice::<f32>(&b_rows.concat(), &(rank, dims)).unwrap(),
        magnitude: Some(Array::from_slice::<f32>(&magnitudes, &(vocab,)).unwrap()),
    };
    let layer =
        DoRAEmbedding::new(BaseEmbedding::dense(weight_f16).unwrap(), params, scale).unwrap();

    let id_values: Vec<i32> = (0..vocab as i32).collect();
    let ids = Array::from_slice::<i32>(&id_values, &(vocab,)).unwrap();
    let out = layer.forward(&ids).unwrap();
    assert_eq!(
        out.dtype().unwrap(),
        Dtype::F32,
        "regression-oracle: forward must return the promoted dtype = f32"
    );
    let mut out_f32 = out.astype(Dtype::F32).unwrap();
    let got = out_f32.to_vec::<f32>().unwrap();

    let token_ids: Vec<usize> = (0..vocab).collect();
    let reference = |cast_upfront: bool| {
        forward_scalar_reference(
            &table,
            &a_rows,
            &b_rows,
            &magnitudes,
            scale,
            &token_ids,
            f16_rt,
            cast_upfront,
        )
    };
    let want_new = reference(false);
    let want_old = reference(true);

    // Largest absolute deviation between the layer output and a reference.
    let max_abs_err = |want: &[f32]| {
        got.iter()
            .zip(want)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max)
    };
    let new_max_err = max_abs_err(&want_new);
    let old_max_err = max_abs_err(&want_old);

    assert!(
        new_max_err <= 5e-3,
        "uncast pipeline must match scalar reference at fp16 tol; got max err {new_max_err}",
    );
    assert!(
        old_max_err >= 1e-2,
        "upfront-cast pipeline must mismatch the scalar reference noticeably; got max err {old_max_err} (cancellation fixture may need re-tuning)",
    );
    assert!(
        new_max_err * 5.0 <= old_max_err,
        "regression-oracle expected ≥5× tighter uncast-vs-upfront-cast fit; got uncast={new_max_err}, upfront-cast={old_max_err}",
    );
}
/// Mixed precision: a narrow (f16 / bf16) base table paired with an f32 adapter
/// must produce f32 output from `forward`, i.e. no final narrowing cast.
#[test]
fn dora_embedding_forward_returns_promoted_dtype_for_mixed_precision() {
    let narrow_dtypes = [(Dtype::F16, "f16"), (Dtype::BF16, "bf16")];
    for (narrow, label) in narrow_dtypes {
        let (weight_f32, lora_a_f32, lora_b_f32, m_f32, scale) = mp_fixture();
        let (num_embeddings, dims, r) =
            (weight_f32.len(), weight_f32[0].len(), lora_a_f32[0].len());

        let flat_w = weight_f32.iter().flatten().copied().collect::<Vec<f32>>();
        let flat_a = lora_a_f32.iter().flatten().copied().collect::<Vec<f32>>();
        let flat_b = lora_b_f32.iter().flatten().copied().collect::<Vec<f32>>();

        // Only the base weight is narrowed; the adapter stays f32.
        let weight_narrow = Array::from_slice::<f32>(&flat_w, &(num_embeddings, dims))
            .unwrap()
            .astype(narrow)
            .unwrap();
        let layer = DoRAEmbedding::new(
            BaseEmbedding::dense(weight_narrow).unwrap(),
            AdapterParams {
                lora_a: Array::from_slice::<f32>(&flat_a, &(num_embeddings, r)).unwrap(),
                lora_b: Array::from_slice::<f32>(&flat_b, &(r, dims)).unwrap(),
                magnitude: Some(Array::from_slice::<f32>(&m_f32, &(num_embeddings,)).unwrap()),
            },
            scale,
        )
        .unwrap();

        let row_ids = (0..num_embeddings as i32).collect::<Vec<i32>>();
        let ids = Array::from_slice::<i32>(&row_ids, &(num_embeddings,)).unwrap();
        let out = layer.forward(&ids).unwrap();
        let produced = out.dtype().unwrap();
        assert_eq!(
            produced,
            Dtype::F32,
            "forward must return promoted dtype f32 for {label} base × f32 adapter (no final narrowing astype)",
        );
    }
}
/// Uniform precision: when the base table and every adapter tensor share one
/// dtype, `forward` must return that same dtype (no promotion).
#[test]
fn dora_embedding_forward_preserves_base_dtype_for_uniform_precision() {
    let uniform_dtypes = [
        (Dtype::F32, "f32"),
        (Dtype::F16, "f16"),
        (Dtype::BF16, "bf16"),
    ];
    for (uniform, label) in uniform_dtypes {
        let (weight_f32, lora_a_f32, lora_b_f32, m_f32, scale) = mp_fixture();
        let num_embeddings = weight_f32.len();
        let dims = weight_f32[0].len();
        let r = lora_a_f32[0].len();
        // Every tensor is built in f32 and then cast to the dtype under test.
        let cast = |a: Array| a.astype(uniform).unwrap();

        let flat_w = weight_f32.iter().flatten().copied().collect::<Vec<f32>>();
        let flat_a = lora_a_f32.iter().flatten().copied().collect::<Vec<f32>>();
        let flat_b = lora_b_f32.iter().flatten().copied().collect::<Vec<f32>>();

        let base = BaseEmbedding::dense(cast(
            Array::from_slice::<f32>(&flat_w, &(num_embeddings, dims)).unwrap(),
        ))
        .unwrap();
        let params = AdapterParams {
            lora_a: cast(Array::from_slice::<f32>(&flat_a, &(num_embeddings, r)).unwrap()),
            lora_b: cast(Array::from_slice::<f32>(&flat_b, &(r, dims)).unwrap()),
            magnitude: Some(cast(
                Array::from_slice::<f32>(&m_f32, &(num_embeddings,)).unwrap(),
            )),
        };
        let layer = DoRAEmbedding::new(base, params, scale).unwrap();

        let row_ids = (0..num_embeddings as i32).collect::<Vec<i32>>();
        let ids = Array::from_slice::<i32>(&row_ids, &(num_embeddings,)).unwrap();
        let out = layer.forward(&ids).unwrap();
        assert_eq!(
            out.dtype().unwrap(),
            uniform,
            "forward must return {label} when base AND adapter are uniform {label} (no promotion)",
        );
    }
}
/// `scaled(arr, 0.5)` must keep `arr`'s dtype for f16, bf16 and f32 inputs.
#[test]
fn scaled_helper_coerces_scalar_to_array_dtype() {
    let dtypes = [
        (Dtype::F16, "f16"),
        (Dtype::BF16, "bf16"),
        (Dtype::F32, "f32"),
    ];
    for (dt, label) in dtypes {
        let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3usize,))
            .unwrap()
            .astype(dt)
            .unwrap();
        let halved = scaled(&input, 0.5).unwrap();
        let result_dtype = halved.dtype().unwrap();
        assert_eq!(
            result_dtype,
            dt,
            "scaled must coerce the scalar to the array's dtype and keep the {label} result in {label}",
        );
    }
}
#[test]
fn dora_embedding_forward_uniform_f16_adapter_returns_f16() {
dora_embedding_forward_uniform_dtype_case(Dtype::F16, "f16");
}
#[test]
fn dora_embedding_forward_uniform_bf16_adapter_returns_bf16() {
dora_embedding_forward_uniform_dtype_case(Dtype::BF16, "bf16");
}
fn dora_embedding_forward_uniform_dtype_case(uniform: Dtype, label: &str) {
let num_embeddings = 2usize;
let dims = 2usize;
let r = 1usize;
let weight = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(num_embeddings, dims))
.unwrap()
.astype(uniform)
.unwrap();
let base = BaseEmbedding::dense(weight).unwrap();
let lora_a = Array::from_slice::<f32>(&[1.0, 0.0], &(num_embeddings, r))
.unwrap()
.astype(uniform)
.unwrap();
let lora_b = Array::from_slice::<f32>(&[1.0, 0.0], &(r, dims))
.unwrap()
.astype(uniform)
.unwrap();
let m = Array::from_slice::<f32>(&[1.0, 1.0], &(num_embeddings,))
.unwrap()
.astype(uniform)
.unwrap();
let params = AdapterParams {
lora_a,
lora_b,
magnitude: Some(m),
};
let layer = DoRAEmbedding::new(base, params, 1.0f32).unwrap();
let ids = Array::from_slice::<i32>(&[0, 1], &(2usize,)).unwrap();
let out = layer.forward(&ids).unwrap();
assert_eq!(
out.dtype().unwrap(),
uniform,
"forward must return {label} for uniform-{label} base + adapter (scaled() coerces scalar to arr.dtype)",
);
let mut out_f32 = out.astype(Dtype::F32).unwrap();
let got = out_f32.to_vec::<f32>().unwrap();
approx_eq(&got, &[1.0, 0.0, 0.0, 1.0], 1e-3);
}
#[test]
fn dora_embedding_as_linear_uniform_f16_adapter_returns_f16() {
dora_embedding_as_linear_uniform_dtype_case(Dtype::F16, "f16");
}
#[test]
fn dora_embedding_as_linear_uniform_bf16_adapter_returns_bf16() {
dora_embedding_as_linear_uniform_dtype_case(Dtype::BF16, "bf16");
}
fn dora_embedding_as_linear_uniform_dtype_case(uniform: Dtype, label: &str) {
let num_embeddings = 2usize;
let dims = 2usize;
let r = 1usize;
let weight = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(num_embeddings, dims))
.unwrap()
.astype(uniform)
.unwrap();
let base = BaseEmbedding::dense(weight).unwrap();
let lora_a = Array::from_slice::<f32>(&[1.0, 0.0], &(num_embeddings, r))
.unwrap()
.astype(uniform)
.unwrap();
let lora_b = Array::from_slice::<f32>(&[1.0, 0.0], &(r, dims))
.unwrap()
.astype(uniform)
.unwrap();
let m = Array::from_slice::<f32>(&[1.0, 1.0], &(num_embeddings,))
.unwrap()
.astype(uniform)
.unwrap();
let params = AdapterParams {
lora_a,
lora_b,
magnitude: Some(m),
};
let layer = DoRAEmbedding::new(base, params, 1.0f32).unwrap();
let x = Array::from_slice::<f32>(&[1.0, 1.0], &(1usize, dims))
.unwrap()
.astype(uniform)
.unwrap();
let out = layer.as_linear(&x).unwrap();
assert_eq!(
out.dtype().unwrap(),
uniform,
"as_linear must return {label} for uniform-{label} base + adapter (scaled() coerces scalar to arr.dtype)",
);
let mut out_f32 = out.astype(Dtype::F32).unwrap();
let got = out_f32.to_vec::<f32>().unwrap();
approx_eq(&got, &[1.0, 1.0], 1e-3);
}
/// Uniform-f16 `DoRALinear::forward` must stay in f16 and reproduce `[3, 6]` for
/// `x = [1, 2, 3]` (the same values the f32 dispatch test expects).
#[test]
fn dora_linear_forward_uniform_f16_adapter_returns_f16() {
    let half = |a: Array| a.astype(Dtype::F16).unwrap();

    let magnitude = half(Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap());
    let params = AdapterParams {
        lora_a: half(lora_a()),
        lora_b: half(lora_b()),
        magnitude: Some(magnitude),
    };
    let base = BaseLinear::dense(half(base_weight()), None).unwrap();
    let layer = DoRALinear::new(base, params, 2.0).unwrap();

    let x = half(Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap());
    let out = layer.forward(&x).unwrap();
    assert_eq!(
        out.dtype().unwrap(),
        Dtype::F16,
        "DoRALinear::forward must return f16 for uniform-f16 base + adapter (trailing astype + scaled() coercion both contribute)",
    );
    let mut widened = out.astype(Dtype::F32).unwrap();
    approx_eq(&widened.to_vec::<f32>().unwrap(), &[3.0, 6.0], 1e-3);
}
/// A dangling symlink at `MLX_LM_ADAPTER_FILE` must not stop discovery: the lookup
/// falls back to the regular `PEFT_ADAPTER_FILE` in the same directory.
#[cfg(unix)]
#[test]
fn locate_adapter_safetensors_falls_back_when_preferred_is_broken_symlink() {
    use std::os::unix::fs::symlink;

    // Directory name is made unique with the process id and a nanosecond timestamp.
    let stamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let tmp = std::env::temp_dir().join(format!(
        "mlxrs_lora_broken_symlink_{}_{}",
        std::process::id(),
        stamp
    ));
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();

    let preferred = tmp.join(MLX_LM_ADAPTER_FILE);
    let fallback = tmp.join(PEFT_ADAPTER_FILE);
    // Preferred candidate points at a path that is never created.
    symlink(tmp.join("does_not_exist"), &preferred).unwrap();
    std::fs::write(&fallback, b"valid bytes").unwrap();

    let cfg = mlxlm_config(2, keyed_params(vec!["self_attn.q_proj".to_string()]));
    let found = locate_adapter_safetensors(&tmp, &cfg)
        .expect("expected fallback to be located when preferred is a broken symlink");
    assert_eq!(found, fallback);

    let _ = std::fs::remove_dir_all(&tmp);
}
/// A preferred candidate that exists but is not a regular file (here: a directory)
/// is a hard `Error::FileIo` from the stat probe, even if a valid fallback exists.
#[cfg(unix)]
#[test]
fn locate_adapter_safetensors_rejects_non_regular_preferred_path_even_with_valid_fallback() {
    let stamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let tmp = std::env::temp_dir().join(format!(
        "mlxrs_lora_nonreg_preferred_{}_{}",
        std::process::id(),
        stamp
    ));
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();

    let preferred = tmp.join(MLX_LM_ADAPTER_FILE);
    std::fs::create_dir(&preferred).unwrap();
    std::fs::write(tmp.join(PEFT_ADAPTER_FILE), b"valid bytes").unwrap();

    let cfg = mlxlm_config(2, keyed_params(vec!["self_attn.q_proj".to_string()]));
    let err = locate_adapter_safetensors(&tmp, &cfg)
        .expect_err("expected fail-fast Error::FileIo for non-regular preferred path");
    match err {
        Error::FileIo(payload) => {
            assert_eq!(
                payload.path(),
                preferred.as_path(),
                "path round-trips through FileIoPayload"
            );
            assert_eq!(
                payload.op(),
                FileOp::Stat,
                "non-regular surfaces from the stat probe"
            );
            assert_eq!(
                payload.inner().kind(),
                std::io::ErrorKind::InvalidInput,
                "non-regular candidates surface with InvalidInput",
            );
        }
        other => panic!("expected Error::FileIo for non-regular preferred path, got {other:?}"),
    }

    let _ = std::fs::remove_dir_all(&tmp);
}
/// With no preferred file, a fallback candidate that is a directory must surface
/// as `Error::FileIo` (stat probe, `InvalidInput`) naming the fallback path.
#[cfg(unix)]
#[test]
fn locate_adapter_safetensors_rejects_non_regular_fallback_path() {
    let stamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let tmp = std::env::temp_dir().join(format!(
        "mlxrs_lora_nonreg_fallback_{}_{}",
        std::process::id(),
        stamp
    ));
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();

    let fallback = tmp.join(PEFT_ADAPTER_FILE);
    std::fs::create_dir(&fallback).unwrap();

    let cfg = mlxlm_config(2, keyed_params(vec!["self_attn.q_proj".to_string()]));
    let err = locate_adapter_safetensors(&tmp, &cfg)
        .expect_err("expected fail-fast Error::FileIo for non-regular fallback path");
    match err {
        Error::FileIo(payload) => {
            assert_eq!(
                payload.path(),
                fallback.as_path(),
                "path round-trips through FileIoPayload"
            );
            assert_eq!(
                payload.op(),
                FileOp::Stat,
                "non-regular surfaces from the stat probe"
            );
            assert_eq!(
                payload.inner().kind(),
                std::io::ErrorKind::InvalidInput,
                "non-regular candidates surface with InvalidInput",
            );
        }
        other => panic!("expected Error::FileIo for non-regular fallback path, got {other:?}"),
    }

    let _ = std::fs::remove_dir_all(&tmp);
}
/// A self-referential symlink at the preferred path must be reported as a typed
/// `Error::FileIo` coming from the stat probe, not from open.
#[cfg(unix)]
#[test]
fn locate_adapter_safetensors_surfaces_symlink_loop_as_typed_file_io() {
    use std::os::unix::fs::symlink;

    let stamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let tmp = std::env::temp_dir().join(format!(
        "mlxrs_lora_symlink_loop_{}_{}",
        std::process::id(),
        stamp
    ));
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();

    // The link points at itself.
    let preferred = tmp.join(MLX_LM_ADAPTER_FILE);
    symlink(&preferred, &preferred).unwrap();

    let cfg = mlxlm_config(2, keyed_params(vec!["self_attn.q_proj".to_string()]));
    let err = locate_adapter_safetensors(&tmp, &cfg)
        .expect_err("expected typed FileIo error for symlink loop");
    if let Error::FileIo(payload) = err {
        assert_eq!(
            payload.path(),
            preferred.as_path(),
            "path round-trips through FileIoPayload"
        );
        assert_eq!(
            payload.op(),
            FileOp::Stat,
            "loop surfaces from the stat probe, not open"
        );
    } else {
        let other = err;
        panic!("expected Error::FileIo for symlink loop, got {other:?}");
    }

    let _ = std::fs::remove_dir_all(&tmp);
}
/// `FineTuneType` string forms (`as_str` / `Display`), variant predicates, and default.
#[test]
fn fine_tune_type_as_str_and_display_and_variant_tags() {
    let cases = [
        (FineTuneType::Lora, "lora"),
        (FineTuneType::Dora, "dora"),
        (FineTuneType::Full, "full"),
    ];
    for (kind, name) in cases {
        // `to_string` borrows, so it is checked before `as_str`.
        assert_eq!(kind.to_string(), name);
        assert_eq!(kind.as_str(), name);
    }

    assert!(FineTuneType::Lora.is_lora());
    assert!(FineTuneType::Dora.is_dora());
    assert!(FineTuneType::Full.is_full());
    assert!(!FineTuneType::Lora.is_dora());
    assert_eq!(FineTuneType::default(), FineTuneType::Lora);
}
/// A `rank_pattern` override that resolves to rank 0 yields scale 0.0. Modules
/// without an override keep the default rank and alpha (8 and 16.0 here → 2.0).
#[test]
fn peft_scale_for_nonpositive_resolved_rank_is_zero() {
    let peft = PeftSelection {
        target_modules: Some(ModuleMatcher::List(vec!["q_proj".to_string()])),
        rank_pattern: vec![("q_proj".to_string(), 0i32)],
        exclude_modules: None,
        layers_to_transform: None,
        layers_pattern: Vec::new(),
        alpha_pattern: Vec::new(),
        use_rslora: false,
        fan_in_fan_out: false,
    };

    let q_scale = peft.scale_for("model.layers.0.self_attn.q_proj", 8, 16.0);
    assert_eq!(q_scale, 0.0);

    let v_scale = peft.scale_for("model.layers.0.self_attn.v_proj", 8, 16.0);
    assert_eq!(v_scale, 2.0);
}
/// An invalid regex key is skipped rather than aborting the lookup. A table whose
/// only key is invalid yields `None`.
#[test]
fn pattern_lookup_skips_uncompilable_pattern_key() {
    let patterns = vec![
        ("(unclosed".to_string(), 11i32),
        ("q_proj".to_string(), 22i32),
    ];
    let hit = pattern_lookup(&patterns, "model.layers.0.self_attn.q_proj");
    assert_eq!(hit, Some(22));

    let only_bad = [("(bad".to_string(), 5)];
    let miss = pattern_lookup::<i32>(&only_bad, "anything");
    assert_eq!(miss, None);
}
/// `rank_pattern` / `alpha_pattern` must be JSON objects. A string or a number in
/// their place is reported as `Error::Parse`.
#[test]
fn ordered_pattern_non_map_value_is_parse_error() {
    // NOTE: raw-string continuation lines are kept at column 0 so the JSON text is
    // byte-identical to the original fixture.
    let string_rank_pattern = r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "rank_pattern": "not-a-map"
}"#;
    let err = LoraConfig::from_json(string_rank_pattern)
        .expect_err("a non-object `rank_pattern` must be a parse error");
    assert!(
        matches!(err, Error::Parse(_)),
        "expected Error::Parse for non-map rank_pattern, got {err:?}"
    );

    let numeric_alpha_pattern = r#"{
"peft_type": "LORA", "r": 8, "lora_alpha": 16.0,
"target_modules": ["q_proj"], "alpha_pattern": 7
}"#;
    let parsed = LoraConfig::from_json(numeric_alpha_pattern);
    assert!(matches!(parsed, Err(Error::Parse(_))));
}
/// A string `exclude_modules` is compiled as a regex; an uncompilable one is a parse error.
#[test]
fn config_parse_peft_invalid_regex_exclude_modules_is_err() {
    let json = r#"{ "peft_type": "LORA", "r": 8, "target_modules": ["q_proj"],
"exclude_modules": "(unclosed" }"#;
    let outcome = LoraConfig::from_json(json);
    let err = outcome.expect_err("an uncompilable `exclude_modules` regex must be rejected");
    assert!(matches!(err, Error::Parse(_)), "got {err:?}");
}
/// `"all-linear"` is special only in `target_modules`. In `exclude_modules` it is
/// parsed as an ordinary regex matcher.
#[test]
fn config_parse_peft_exclude_modules_all_linear_is_literal_regex() {
    let json = r#"{ "peft_type": "LORA", "r": 8, "target_modules": ["q_proj"],
"exclude_modules": "all-linear" }"#;
    let cfg = LoraConfig::from_json(json).unwrap();
    let exclude = &cfg.peft().unwrap().exclude_modules;
    assert!(
        matches!(exclude, Some(ModuleMatcher::Regex(_))),
        "all-linear in exclude_modules must be a literal Regex, not the AllLinear sentinel"
    );
}
/// A rank-1 weight handed to `BaseLinear::dense` is a `RankMismatch` naming the constructor.
#[test]
fn base_linear_dense_rejects_non_rank2_weight() {
    let flat_weight = Array::zeros::<f32>(&(6usize,)).unwrap();
    match BaseLinear::dense(flat_weight, None).unwrap_err() {
        Error::RankMismatch(payload) => {
            assert_eq!(payload.actual(), 1);
            assert_eq!(payload.actual_shape(), &[6]);
            assert!(payload.context().contains("BaseLinear::dense"));
        }
        other => panic!("expected Error::RankMismatch, got {other:?}"),
    }
}
/// The dense bias must have shape `[output_dims]`. Both a wrong length and a
/// wrong rank are reported as `ShapePairMismatch`.
#[test]
fn base_linear_dense_rejects_mismatched_bias_shape() {
    // base_weight() is 2×3, so the bias must be [2]; a [3] bias is wrong.
    let wrong_length = Array::zeros::<f32>(&(3usize,)).unwrap();
    match BaseLinear::dense(base_weight(), Some(wrong_length)).unwrap_err() {
        Error::ShapePairMismatch(payload) => {
            assert_eq!(payload.expected(), &[2]);
            assert_eq!(payload.actual(), &[3]);
            assert!(payload.context().contains("bias"));
        }
        other => panic!("expected Error::ShapePairMismatch, got {other:?}"),
    }

    let wrong_rank = Array::zeros::<f32>(&(2, 1)).unwrap();
    let outcome = BaseLinear::dense(base_weight(), Some(wrong_rank));
    assert!(matches!(outcome, Err(Error::ShapePairMismatch(_))));
}
/// Affine quantization needs `quant_biases`; passing `None` is a `MissingField` error.
#[test]
fn base_linear_quantized_affine_requires_quant_biases() {
    let packed = Array::zeros::<u32>(&(2, 1)).unwrap();
    let scales = Array::zeros::<f32>(&(2, 1)).unwrap();
    let outcome =
        BaseLinear::quantized(packed, scales, None, None, 32, 8, "affine".to_string());
    match outcome.unwrap_err() {
        Error::MissingField(payload) => {
            assert_eq!(payload.type_name(), "BaseLinear::quantized");
            assert!(payload.field().contains("quant_biases"));
        }
        other => panic!("expected Error::MissingField, got {other:?}"),
    }
}
/// Float quantization modes (mxfp4 / mxfp8 / nvfp4) must reject `quant_biases`.
#[test]
fn base_linear_quantized_float_modes_forbid_quant_biases() {
    let float_modes = ["mxfp4", "mxfp8", "nvfp4"];
    for mode in float_modes {
        let zeros_u32 = || Array::zeros::<u32>(&(2, 1)).unwrap();
        let zeros_f32 = || Array::zeros::<f32>(&(2, 1)).unwrap();
        let err = BaseLinear::quantized(
            zeros_u32(),
            zeros_f32(),
            Some(zeros_f32()),
            None,
            32,
            8,
            mode.to_string(),
        )
        .unwrap_err();
        assert!(
            matches!(err, Error::InvariantViolation(ref p) if p.context().contains("quant_biases")),
            "mode {mode:?}: expected InvariantViolation on quant_biases, got {err:?}"
        );
    }
}
/// An unrecognised mode string is reported as `UnknownEnumValue` with the supported list.
#[test]
fn base_linear_quantized_unknown_mode_is_err() {
    let packed = Array::zeros::<u32>(&(2, 1)).unwrap();
    let scales = Array::zeros::<f32>(&(2, 1)).unwrap();
    let outcome = BaseLinear::quantized(packed, scales, None, None, 32, 8, "int3".to_string());
    match outcome.unwrap_err() {
        Error::UnknownEnumValue(payload) => {
            assert_eq!(payload.value(), "int3");
            assert_eq!(payload.supported(), &["affine", "mxfp4", "mxfp8", "nvfp4"]);
        }
        other => panic!("expected Error::UnknownEnumValue, got {other:?}"),
    }
}
/// `bits == 0` and `group_size == 0` are each rejected as `OutOfRange`.
#[test]
fn base_linear_quantized_rejects_nonpositive_bits_and_group_size() {
    let quant_biases = Array::zeros::<f32>(&(2, 1)).unwrap();

    // group_size = 32, bits = 0.
    let zero_bits = BaseLinear::quantized(
        Array::zeros::<u32>(&(2, 1)).unwrap(),
        Array::zeros::<f32>(&(2, 1)).unwrap(),
        Some(quant_biases.try_clone().unwrap()),
        None,
        32,
        0,
        "affine".to_string(),
    );
    let err = zero_bits.unwrap_err();
    assert!(
        matches!(err, Error::OutOfRange(ref p) if p.context().contains("bits") && p.value() == "0"),
        "expected OutOfRange on bits, got {err:?}"
    );

    // group_size = 0, bits = 8.
    let zero_group = BaseLinear::quantized(
        Array::zeros::<u32>(&(2, 1)).unwrap(),
        Array::zeros::<f32>(&(2, 1)).unwrap(),
        Some(quant_biases),
        None,
        0,
        8,
        "affine".to_string(),
    );
    let err2 = zero_group.unwrap_err();
    assert!(
        matches!(err2, Error::OutOfRange(ref p) if p.context().contains("group_size") && p.value() == "0"),
        "expected OutOfRange on group_size, got {err2:?}"
    );
}
/// `LoRALinear::fuse(false)` over a quantized base returns a quantized base. The
/// fused layer's output matches the unfused forward within 8-bit quantization slack.
#[test]
fn lora_fuse_requantize_quantized_base_stays_quantized() {
    let input_dims = 64usize;
    let output_dims = 2usize;
    // Row 0 is all 1.0, row 1 is all 0.5.
    let wdata = [vec![1.0f32; input_dims], vec![0.5f32; input_dims]].concat();
    let dense_w = Array::from_slice::<f32>(&wdata, &(output_dims, input_dims)).unwrap();

    let params = AdapterParams {
        lora_a: Array::full::<f32>(&(input_dims, 2usize), 0.01).unwrap(),
        lora_b: Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap(),
        magnitude: None,
    };
    let x = Array::full::<f32>(&(1usize, input_dims), 1.0).unwrap();

    let (w_q, scales, biases) = ops::quantized::quantize(&dense_w, 32, 8, "affine", None).unwrap();
    let q_base =
        BaseLinear::quantized(w_q, scales, biases, None, 32, 8, "affine".to_string()).unwrap();
    let q_layer = LoRALinear::new(q_base, params, 2.0).unwrap();

    // Record the unfused output before fusing.
    let mut via_forward = q_layer.forward(&x).unwrap();
    let fused = q_layer.fuse(false).unwrap();
    assert!(
        matches!(fused, BaseLinear::Quantized { .. }),
        "fuse(false) over a quantized base must re-quantize"
    );

    let mut via_fused = fused.base_output(&x).unwrap();
    let fused_vals = via_fused.to_vec::<f32>().unwrap();
    let forward_vals = via_forward.to_vec::<f32>().unwrap();
    approx_eq(&fused_vals, &forward_vals, 2e-2);
}
/// QDoRA: `DoRALinear::fuse(false)` over a quantized base must:
/// * return a re-quantized `BaseLinear::Quantized`,
/// * re-attach the original output bias, and
/// * produce the same output as the unfused `forward`, up to 8-bit quantization
///   error. This mirrors the LoRA fuse test above; without it, a numerically wrong
///   fused weight would still pass the two structural checks.
#[test]
fn dora_fuse_requantize_quantized_base_stays_quantized() {
    let input_dims = 64usize;
    let output_dims = 2usize;
    // Row 0 is all 1.0, row 1 is all 0.5.
    let mut wdata = vec![1.0f32; input_dims];
    wdata.extend(vec![0.5f32; input_dims]);
    let dense_w = Array::from_slice::<f32>(&wdata, &(output_dims, input_dims)).unwrap();
    let la = Array::full::<f32>(&(input_dims, 2usize), 0.01).unwrap();
    let lb = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap();
    let m = Array::from_slice::<f32>(&[1.5, 2.5], &(output_dims,)).unwrap();
    let bias = Array::from_slice::<f32>(&[0.25, -0.5], &(output_dims,)).unwrap();
    let x = Array::full::<f32>(&(1usize, input_dims), 1.0).unwrap();

    let (w_q, scales, biases) = ops::quantized::quantize(&dense_w, 32, 8, "affine", None).unwrap();
    let q_base =
        BaseLinear::quantized(w_q, scales, biases, Some(bias), 32, 8, "affine".to_string()).unwrap();
    let q_layer = DoRALinear::new(
        q_base,
        AdapterParams {
            lora_a: la,
            lora_b: lb,
            magnitude: Some(m),
        },
        2.0,
    )
    .unwrap();

    // Record the unfused output before fusing so the fused layer can be compared with it.
    let mut via_forward = q_layer.forward(&x).unwrap();
    let fused = q_layer.fuse(false).unwrap();
    assert!(
        matches!(fused, BaseLinear::Quantized { .. }),
        "QDoRA fuse(false) must re-quantize the fused weight"
    );
    assert!(
        fused.bias().is_some(),
        "the original output bias is re-attached"
    );

    // Numerical equivalence check, using the same tolerance as the LoRA fuse test.
    let mut via_fused = fused.base_output(&x).unwrap();
    approx_eq(
        &via_fused.to_vec::<f32>().unwrap(),
        &via_forward.to_vec::<f32>().unwrap(),
        2e-2,
    );
}
/// `scale()` / `base()` on `LoRALinear`, plus `scale()` / `base()` / `magnitude()` on `DoRALinear`.
#[test]
fn linear_layer_accessors_expose_base_and_scale() {
    let with_bias = BaseLinear::dense(
        base_weight(),
        Some(Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap()),
    )
    .unwrap();
    let lora = LoRALinear::new(with_bias, plain_params(), 7.0).unwrap();
    assert_eq!(lora.scale(), 7.0);
    let lora_base = lora.base();
    assert!(matches!(lora_base, BaseLinear::Dense { .. }));
    assert!(lora_base.bias().is_some());

    let dora_params = AdapterParams {
        lora_a: lora_a(),
        lora_b: lora_b(),
        magnitude: Some(Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap()),
    };
    let dora = DoRALinear::new(BaseLinear::dense(base_weight(), None).unwrap(), dora_params, 9.0)
        .unwrap();
    assert_eq!(dora.scale(), 9.0);
    assert!(matches!(dora.base(), BaseLinear::Dense { .. }));
    assert_eq!(dora.magnitude().shape(), &[2]);
}
/// A DoRA magnitude vector must have one entry per output row (`[2]` here).
#[test]
fn dora_linear_rejects_wrong_magnitude_shape() {
    let too_long = Array::from_slice::<f32>(&[3.0, 3.0, 3.0], &(3usize,)).unwrap();
    let params = AdapterParams {
        lora_a: lora_a(),
        lora_b: lora_b(),
        magnitude: Some(too_long),
    };
    let base = BaseLinear::dense(base_weight(), None).unwrap();
    match DoRALinear::new(base, params, 2.0).unwrap_err() {
        Error::ShapePairMismatch(payload) => {
            assert_eq!(payload.expected(), &[2]);
            assert_eq!(payload.actual(), &[3]);
            assert!(payload.context().contains("magnitude"));
        }
        other => panic!("expected Error::ShapePairMismatch, got {other:?}"),
    }
}
/// A rank-1 `lora_a` rejected by `DoRALinear::new` must carry DoRA-specific context.
#[test]
fn dora_linear_rank_mismatch_uses_dora_validation_context() {
    let params = AdapterParams {
        lora_a: Array::zeros::<f32>(&(3usize,)).unwrap(),
        lora_b: lora_b(),
        magnitude: Some(Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap()),
    };
    let base = BaseLinear::dense(base_weight(), None).unwrap();
    match DoRALinear::new(base, params, 2.0).unwrap_err() {
        Error::RankMismatch(payload) => {
            let ctx = payload.context();
            assert!(
                ctx.contains("DoRALinear") && ctx.contains("lora_a"),
                "context must name the DoRALinear lora_a check: {}",
                ctx
            );
        }
        other => panic!("expected Error::RankMismatch, got {other:?}"),
    }
}
/// `DoRALinear::new` rejects a rank-1 `lora_b` (`RankMismatch`) and a `lora_b`
/// whose column count (3) differs from the base output dims (2) (`LengthMismatch`).
#[test]
fn dora_linear_lora_b_rank_and_output_dim_mismatch() {
    let m = Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap();

    let rank1_b = Array::zeros::<f32>(&(2usize,)).unwrap();
    let err = DoRALinear::new(
        BaseLinear::dense(base_weight(), None).unwrap(),
        AdapterParams {
            lora_a: lora_a(),
            lora_b: rank1_b,
            magnitude: Some(m.try_clone().unwrap()),
        },
        2.0,
    )
    .unwrap_err();
    assert!(
        matches!(err, Error::RankMismatch(ref p) if p.context().contains("DoRALinear") && p.context().contains("lora_b")),
        "got {err:?}"
    );

    let three_wide_b =
        Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &(2, 3)).unwrap();
    let err2 = DoRALinear::new(
        BaseLinear::dense(base_weight(), None).unwrap(),
        AdapterParams {
            lora_a: lora_a(),
            lora_b: three_wide_b,
            magnitude: Some(m),
        },
        2.0,
    )
    .unwrap_err();
    assert!(
        matches!(err2, Error::LengthMismatch(ref p)
            if p.context().contains("DoRALinear") && p.context().contains("output_dims")
                && p.expected() == 2 && p.actual() == 3),
        "got {err2:?}"
    );
}
/// `validate_config_rank` must catch a rank drift that affects only `lora_b`.
///
/// `lora_a` is `(3, 2)`, so its column count agrees with the configured rank 2.
/// `lora_b` has 3 rows instead of 2, so the error must name `lora_b` with
/// expected 2 / actual 3. A consistent pair passes under the DoRA context.
///
/// Fix: the original called `validate_config_rank(¶ms, …)`. `¶ms` is a
/// mis-encoded `&params` (an `&para` HTML-entity artefact) and does not compile.
#[test]
fn validate_config_rank_catches_lora_b_only_drift() {
    let a = Array::zeros::<f32>(&(3usize, 2usize)).unwrap();
    let b = Array::zeros::<f32>(&(3usize, 2usize)).unwrap();
    let params = AdapterParams {
        lora_a: a,
        lora_b: b,
        magnitude: None,
    };
    let err = validate_config_rank(&params, 2, LinearValidationContext::LoraLinear).unwrap_err();
    match err {
        Error::LengthMismatch(p) => {
            assert!(p.context().contains("lora_b"));
            assert_eq!(p.expected(), 2);
            assert_eq!(p.actual(), 3);
        }
        other => panic!("expected Error::LengthMismatch on lora_b, got {other:?}"),
    }
    let good = AdapterParams {
        lora_a: Array::zeros::<f32>(&(3usize, 2usize)).unwrap(),
        lora_b: Array::zeros::<f32>(&(2usize, 2usize)).unwrap(),
        magnitude: None,
    };
    assert!(validate_config_rank(&good, 2, LinearValidationContext::DoraLinear).is_ok());
}
/// Under the DoRA context, a configured rank (4) that disagrees with `lora_a`'s
/// column count (2) is a `LengthMismatch` whose context names both `DoRALinear`
/// and `lora_a`.
///
/// Fix: the original called `validate_config_rank(¶ms, …)`. `¶ms` is a
/// mis-encoded `&params` (an `&para` HTML-entity artefact) and does not compile.
#[test]
fn validate_config_rank_dora_context_lora_a_drift() {
    let params = AdapterParams {
        lora_a: Array::zeros::<f32>(&(3usize, 2usize)).unwrap(),
        lora_b: Array::zeros::<f32>(&(2usize, 2usize)).unwrap(),
        magnitude: None,
    };
    let err = validate_config_rank(&params, 4, LinearValidationContext::DoraLinear).unwrap_err();
    match err {
        Error::LengthMismatch(p) => {
            assert!(
                p.context().contains("DoRALinear") && p.context().contains("lora_a"),
                "DoRALinear config-rank lora_a context: {}",
                p.context()
            );
            assert_eq!(p.expected(), 4);
            assert_eq!(p.actual(), 2);
        }
        other => panic!("expected Error::LengthMismatch, got {other:?}"),
    }
}
/// A rank-1 `lora_a` rejected by `LoRALinear::new` must carry LoRA-specific context.
#[test]
fn lora_linear_rejects_rank1_lora_a_with_lora_context() {
    let params = AdapterParams {
        lora_a: Array::zeros::<f32>(&(3usize,)).unwrap(),
        lora_b: lora_b(),
        magnitude: None,
    };
    let base = BaseLinear::dense(base_weight(), None).unwrap();
    let err = LoRALinear::new(base, params, 2.0).unwrap_err();
    match err {
        Error::RankMismatch(payload) => {
            let ctx = payload.context();
            assert!(
                ctx.contains("LoRALinear") && ctx.contains("lora_a"),
                "context: {}",
                ctx
            );
        }
        other => panic!("expected Error::RankMismatch, got {other:?}"),
    }
}
/// Two `DoRALinear::new` errors, both carrying DoRA context:
/// * `lora_a` columns (2) vs `lora_b` rows (3) → shared-rank mismatch;
/// * `lora_a` rows (2) vs base input dims (3) → input-dims mismatch.
#[test]
fn dora_linear_shared_rank_and_input_dim_mismatch_use_dora_context() {
    let m = Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap();

    let rank_drift = AdapterParams {
        lora_a: Array::zeros::<f32>(&(3usize, 2usize)).unwrap(),
        lora_b: Array::zeros::<f32>(&(3usize, 2usize)).unwrap(),
        magnitude: Some(m.try_clone().unwrap()),
    };
    let err = DoRALinear::new(BaseLinear::dense(base_weight(), None).unwrap(), rank_drift, 2.0)
        .unwrap_err();
    assert!(
        matches!(err, Error::LengthMismatch(ref p)
            if p.context().contains("DoRALinear") && p.context().contains("shared rank")),
        "got {err:?}"
    );

    let input_drift = AdapterParams {
        lora_a: Array::zeros::<f32>(&(2usize, 2usize)).unwrap(),
        lora_b: lora_b(),
        magnitude: Some(m),
    };
    let err2 = DoRALinear::new(BaseLinear::dense(base_weight(), None).unwrap(), input_drift, 2.0)
        .unwrap_err();
    assert!(
        matches!(err2, Error::LengthMismatch(ref p)
            if p.context().contains("DoRALinear") && p.context().contains("input_dims")
                && p.expected() == 3 && p.actual() == 2),
        "got {err2:?}"
    );
}
/// `BaseEmbedding::dense` rejects rank-1 tables and exposes a valid table via `weight()`.
#[test]
fn base_embedding_dense_rejects_non_rank2_and_exposes_weight() {
    let flat_table = Array::zeros::<f32>(&(6usize,)).unwrap();
    let rejection = BaseEmbedding::dense(flat_table).unwrap_err();
    match rejection {
        Error::RankMismatch(payload) => {
            assert_eq!(payload.actual(), 1);
            assert!(payload.context().contains("BaseEmbedding::dense"));
        }
        other => panic!("expected Error::RankMismatch, got {other:?}"),
    }

    let table = BaseEmbedding::dense(Array::zeros::<f32>(&(4usize, 3usize)).unwrap()).unwrap();
    assert_eq!(table.weight().shape(), &[4, 3]);
}
/// A `DoRAEmbedding` magnitude must have one entry per embedding row.
#[test]
fn dora_embedding_rejects_wrong_magnitude_shape() {
    let (num_embeddings, dims, r) = (3usize, 3usize, 2usize);
    let base =
        BaseEmbedding::dense(Array::zeros::<f32>(&(num_embeddings, dims)).unwrap()).unwrap();
    // One entry too many.
    let oversized = Array::zeros::<f32>(&(num_embeddings + 1,)).unwrap();
    let params = AdapterParams {
        lora_a: Array::zeros::<f32>(&(num_embeddings, r)).unwrap(),
        lora_b: Array::zeros::<f32>(&(r, dims)).unwrap(),
        magnitude: Some(oversized),
    };
    match DoRAEmbedding::new(base, params, 1.0).unwrap_err() {
        Error::ShapePairMismatch(payload) => {
            assert_eq!(payload.expected(), &[3]);
            assert_eq!(payload.actual(), &[4]);
            assert!(payload.context().contains("num_embeddings"));
        }
        other => panic!("expected Error::ShapePairMismatch, got {other:?}"),
    }
}
/// `scale()`, `base()` and `magnitude()` on `DoRAEmbedding` return the constructor inputs.
#[test]
fn dora_embedding_accessors_expose_base_scale_magnitude() {
    let (num_embeddings, dims, r) = (3usize, 3usize, 2usize);
    let params = AdapterParams {
        lora_a: Array::zeros::<f32>(&(num_embeddings, r)).unwrap(),
        lora_b: Array::zeros::<f32>(&(r, dims)).unwrap(),
        magnitude: Some(
            Array::from_slice::<f32>(&[1.0, 1.0, 1.0], &(num_embeddings,)).unwrap(),
        ),
    };
    let base =
        BaseEmbedding::dense(Array::zeros::<f32>(&(num_embeddings, dims)).unwrap()).unwrap();
    let layer = DoRAEmbedding::new(base, params, 5.0).unwrap();

    assert_eq!(layer.scale(), 5.0);
    assert_eq!(layer.base().weight().shape(), &[3, 3]);
    assert_eq!(layer.magnitude().shape(), &[3]);
}
/// `DoRAEmbedding::new` rejects rank-1 `lora_a` and rank-1 `lora_b`, each with
/// context naming `DoRAEmbedding` and the offending tensor.
#[test]
fn dora_embedding_rejects_non_rank2_lora_a_and_lora_b() {
    let (num_embeddings, dims, r) = (3usize, 3usize, 2usize);
    let m = Array::zeros::<f32>(&(num_embeddings,)).unwrap();
    let fresh_base = || {
        BaseEmbedding::dense(Array::zeros::<f32>(&(num_embeddings, dims)).unwrap()).unwrap()
    };

    let flat_a = AdapterParams {
        lora_a: Array::zeros::<f32>(&(num_embeddings,)).unwrap(),
        lora_b: Array::zeros::<f32>(&(r, dims)).unwrap(),
        magnitude: Some(m.try_clone().unwrap()),
    };
    let err = DoRAEmbedding::new(fresh_base(), flat_a, 1.0).unwrap_err();
    assert!(
        matches!(err, Error::RankMismatch(ref p) if p.context().contains("DoRAEmbedding") && p.context().contains("lora_a")),
        "got {err:?}"
    );

    let flat_b = AdapterParams {
        lora_a: Array::zeros::<f32>(&(num_embeddings, r)).unwrap(),
        lora_b: Array::zeros::<f32>(&(r,)).unwrap(),
        magnitude: Some(m),
    };
    let err2 = DoRAEmbedding::new(fresh_base(), flat_b, 1.0).unwrap_err();
    assert!(
        matches!(err2, Error::RankMismatch(ref p) if p.context().contains("DoRAEmbedding") && p.context().contains("lora_b")),
        "got {err2:?}"
    );
}
/// `DoRAEmbedding::new` rejects:
/// * `lora_a` columns (2) ≠ `lora_b` rows (3) → shared-rank mismatch;
/// * `lora_b` columns (`dims + 1`) ≠ table dims → dims mismatch.
#[test]
fn dora_embedding_rejects_shared_rank_and_dims_mismatch() {
    let (num_embeddings, dims) = (3usize, 3usize);
    let m = Array::zeros::<f32>(&(num_embeddings,)).unwrap();
    let fresh_base = || {
        BaseEmbedding::dense(Array::zeros::<f32>(&(num_embeddings, dims)).unwrap()).unwrap()
    };

    let rank_drift = AdapterParams {
        lora_a: Array::zeros::<f32>(&(num_embeddings, 2usize)).unwrap(),
        lora_b: Array::zeros::<f32>(&(3usize, dims)).unwrap(),
        magnitude: Some(m.try_clone().unwrap()),
    };
    let err = DoRAEmbedding::new(fresh_base(), rank_drift, 1.0).unwrap_err();
    assert!(
        matches!(err, Error::LengthMismatch(ref p) if p.context().contains("shared rank")),
        "got {err:?}"
    );

    let dims_drift = AdapterParams {
        lora_a: Array::zeros::<f32>(&(num_embeddings, 2usize)).unwrap(),
        lora_b: Array::zeros::<f32>(&(2usize, dims + 1)).unwrap(),
        magnitude: Some(m),
    };
    let err2 = DoRAEmbedding::new(fresh_base(), dims_drift, 1.0).unwrap_err();
    assert!(
        matches!(err2, Error::LengthMismatch(ref p)
            if p.context().contains("dims") && p.expected() == dims && p.actual() == dims + 1),
        "got {err2:?}"
    );
}
/// Fixture: a `LoraLayer::DoraEmbedding` over a 3×3 identity table with zero
/// adapter tensors, unit magnitudes and scale 2.0.
fn one_hot_dora_embedding_layer() -> LoraLayer {
    let (num_embeddings, dims, r) = (3usize, 3usize, 2usize);
    #[rustfmt::skip]
    let identity = Array::from_slice::<f32>(
        &[1.0, 0.0, 0.0,
          0.0, 1.0, 0.0,
          0.0, 0.0, 1.0],
        &(num_embeddings, dims),
    ).unwrap();
    let unit_magnitude = Array::from_slice::<f32>(&[1.0, 1.0, 1.0], &(num_embeddings,)).unwrap();
    let params = AdapterParams {
        lora_a: Array::zeros::<f32>(&(num_embeddings, r)).unwrap(),
        lora_b: Array::zeros::<f32>(&(r, dims)).unwrap(),
        magnitude: Some(unit_magnitude),
    };
    let layer = DoRAEmbedding::new(BaseEmbedding::dense(identity).unwrap(), params, 2.0).unwrap();
    LoraLayer::DoraEmbedding(layer)
}
/// `LoraLayer::forward` routes each variant to its own layer's forward.
#[test]
fn lora_layer_forward_dispatches_per_variant() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();

    let lora_variant = LoraLayer::Lora(
        LoRALinear::new(BaseLinear::dense(base_weight(), None).unwrap(), plain_params(), 2.0)
            .unwrap(),
    );
    let mut lora_out = lora_variant.forward(&x).unwrap();
    approx_eq(&lora_out.to_vec::<f32>().unwrap(), &[3.0, 6.0], 1e-5);

    let dora_params = AdapterParams {
        lora_a: lora_a(),
        lora_b: lora_b(),
        magnitude: Some(Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap()),
    };
    let dora_variant = LoraLayer::Dora(
        DoRALinear::new(BaseLinear::dense(base_weight(), None).unwrap(), dora_params, 2.0)
            .unwrap(),
    );
    let mut dora_out = dora_variant.forward(&x).unwrap();
    approx_eq(&dora_out.to_vec::<f32>().unwrap(), &[3.0, 6.0], 1e-5);

    // Embedding variant: rows 0 and 2 of the identity table.
    let emb_variant = one_hot_dora_embedding_layer();
    let ids = Array::from_slice::<i32>(&[0, 2], &(2usize,)).unwrap();
    let mut emb_out = emb_variant.forward(&ids).unwrap();
    approx_eq(
        &emb_out.to_vec::<f32>().unwrap(),
        &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
        1e-5,
    );
}
/// Each accessor answers only for its own family:
/// - `base()` returns a `BaseLinear` for linear variants;
/// - `base_embedding()` returns a `BaseEmbedding` for embedding variants.
#[test]
fn lora_layer_base_and_base_embedding_accessors() {
    let emb = one_hot_dora_embedding_layer();
    let lora = LoraLayer::Lora(
        LoRALinear::new(BaseLinear::dense(base_weight(), None).unwrap(), plain_params(), 2.0)
            .unwrap(),
    );

    // Linear variant: only the linear accessor is populated.
    assert!(lora.base().is_some());
    assert!(lora.base_embedding().is_none());

    // Embedding variant: only the embedding accessor is populated, and it
    // exposes the 3x3 table it was built from.
    assert!(emb.base().is_none());
    assert!(emb.base_embedding().is_some());
    assert_eq!(emb.base_embedding().unwrap().weight().shape(), &[3, 3]);
}
/// `fuse` on an embedding variant is an `InvariantViolation`.
/// The error's requirement text directs the caller to `fuse_embedding`.
#[test]
fn lora_layer_fuse_on_embedding_variant_is_err() {
    let emb = one_hot_dora_embedding_layer();
    let p = match emb.fuse(false).unwrap_err() {
        Error::InvariantViolation(p) => p,
        other => panic!("expected Error::InvariantViolation, got {other:?}"),
    };
    assert!(p.context().contains("LoraLayer::fuse"));
    assert!(p.requirement().contains("fuse_embedding"));
}
/// `fuse_embedding` is only meaningful for embedding variants.
/// On either linear variant (LoRA or DoRA) it must report an `InvariantViolation`.
#[test]
fn lora_layer_fuse_embedding_on_linear_variant_is_err() {
    // LoRA linear: the diagnostic should steer callers toward `fuse`.
    let lora_inner =
        LoRALinear::new(BaseLinear::dense(base_weight(), None).unwrap(), plain_params(), 2.0)
            .unwrap();
    let lora = LoraLayer::Lora(lora_inner);
    let p = match lora.fuse_embedding().unwrap_err() {
        Error::InvariantViolation(p) => p,
        other => panic!("expected Error::InvariantViolation, got {other:?}"),
    };
    assert!(p.context().contains("LoraLayer::fuse_embedding"));
    assert!(p.requirement().contains("call `fuse`"));

    // DoRA linear: same error kind.
    let dora_adapter = AdapterParams {
        lora_a: lora_a(),
        lora_b: lora_b(),
        magnitude: Some(Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap()),
    };
    let dora_inner =
        DoRALinear::new(BaseLinear::dense(base_weight(), None).unwrap(), dora_adapter, 2.0)
            .unwrap();
    let dora = LoraLayer::Dora(dora_inner);
    assert!(matches!(
        dora.fuse_embedding(),
        Err(Error::InvariantViolation(_))
    ));
}
/// Fusing the identity DoRA embedding and running the result as a linear map
/// (`as_linear`) must return the input unchanged.
/// The fixture's adapter is all zero and its magnitudes are 1.
#[test]
fn lora_layer_fuse_embedding_round_trips_via_as_linear() {
    let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let fused = one_hot_dora_embedding_layer().fuse_embedding().unwrap();
    let mut projected = fused.as_linear(&input).unwrap();
    approx_eq(&projected.to_vec::<f32>().unwrap(), &[1.0, 2.0, 3.0], 1e-5);
}
/// `fuse(false)` succeeds for both linear variants (LoRA and DoRA).
#[test]
fn lora_layer_fuse_on_linear_variants_ok() {
    let dora_adapter = AdapterParams {
        lora_a: lora_a(),
        lora_b: lora_b(),
        magnitude: Some(Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap()),
    };

    let lora = LoraLayer::Lora(
        LoRALinear::new(BaseLinear::dense(base_weight(), None).unwrap(), plain_params(), 2.0)
            .unwrap(),
    );
    assert!(lora.fuse(false).is_ok());

    let dora = LoraLayer::Dora(
        DoRALinear::new(BaseLinear::dense(base_weight(), None).unwrap(), dora_adapter, 2.0)
            .unwrap(),
    );
    assert!(dora.fuse(false).is_ok());
}
/// `base_input_dims` / `base_output_dims` must report logical matrix dims:
/// - for a quantized base, the pre-quantization width (64);
/// - for the dense 2x3 fixture, its plain shape.
#[test]
fn base_input_dims_recovers_logical_width_for_quantized_base() {
    let input_dims = 64usize;

    // Dense 2 x 64 reference: row 0 is all 1.0, row 1 all 0.5.
    let wdata = [vec![1.0f32; input_dims], vec![0.5f32; input_dims]].concat();
    let dense_w = Array::from_slice::<f32>(&wdata, &(2, input_dims)).unwrap();

    let (w_q, scales, biases) = ops::quantized::quantize(&dense_w, 32, 8, "affine", None).unwrap();
    let q_base =
        BaseLinear::quantized(w_q, scales, biases, None, 32, 8, "affine".to_string()).unwrap();
    let dense_base = BaseLinear::dense(base_weight(), None).unwrap();

    assert_eq!(base_input_dims(&q_base).unwrap(), 64);
    assert_eq!(base_output_dims(&q_base).unwrap(), 2);

    assert_eq!(base_input_dims(&dense_base).unwrap(), 3);
    assert_eq!(base_output_dims(&dense_base).unwrap(), 2);
}
/// mlx-lm auto-discovery (empty `keys`) must only wrap rank-2 weights:
/// - the 2-D `q_proj` weight is adapted;
/// - the 1-D layernorm weight is skipped.
#[test]
fn mlxlm_autodiscovery_skips_non_rank2_weights() {
    let mut weights = Weights::new();
    weights.insert(
        "model.layers.0.self_attn.q_proj.weight".to_string(),
        base_weight(),
    );
    weights.insert(
        "model.layers.0.input_layernorm.weight".to_string(),
        Array::zeros::<f32>(&(2usize,)).unwrap(),
    );

    let mut params = HashMap::new();
    params.insert(
        "model.layers.0.self_attn.q_proj".to_string(),
        plain_params(),
    );

    // Empty `keys` exercises the auto-discovery path.
    let cfg = mlxlm_config(16, keyed_params(Vec::new()));
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 1).unwrap();
    assert_eq!(layers.len(), 1, "only the rank-2 q_proj is auto-discovered");
    assert!(layers.contains_key("model.layers.0.self_attn.q_proj"));
    assert!(!layers.contains_key("model.layers.0.input_layernorm"));
}
/// With `target_modules: None`, PEFT selection accepts any module whose weight
/// is rank 2 and rejects non-rank-2 weights such as 1-D norms.
#[test]
fn peft_autodiscovery_none_target_skips_non_rank2() {
    let rank2 = base_weight();
    let rank1 = Array::zeros::<f32>(&(4usize,)).unwrap();

    // No targets, exclusions, layer filters or per-module patterns.
    let peft = PeftSelection {
        target_modules: None,
        exclude_modules: None,
        layers_to_transform: None,
        layers_pattern: Vec::new(),
        rank_pattern: Vec::new(),
        alpha_pattern: Vec::new(),
        use_rslora: false,
        fan_in_fan_out: false,
    };

    assert!(peft_module_is_selected("model.layers.0.self_attn.q_proj", &rank2, &peft));
    assert!(!peft_module_is_selected("model.layers.0.input_layernorm", &rank1, &peft));
}
/// When `layers_to_transform` is set:
/// - a targeted module with no layer index in its path (`lm_head`) is deselected;
/// - the same module under `model.layers.0` is selected.
#[test]
fn peft_layers_to_transform_no_extractable_index_deselects() {
    let targets = ModuleMatcher::List(vec!["lm_head".to_string()]);
    let peft = PeftSelection {
        target_modules: Some(targets),
        exclude_modules: None,
        layers_to_transform: Some(vec![0]),
        layers_pattern: Vec::new(),
        rank_pattern: Vec::new(),
        alpha_pattern: Vec::new(),
        use_rslora: false,
        fan_in_fan_out: false,
    };

    // No `layers.<n>` segment: nothing to compare against `[0]`.
    assert!(!peft_module_is_selected("lm_head", &base_weight(), &peft));
    // Layer index 0 is in the allow-list.
    assert!(peft_module_is_selected("model.layers.0.lm_head", &base_weight(), &peft));
}
/// With `layers_to_transform: [5]`, a targeted module is selected only when
/// its layer index is 5; layer 0 is deselected.
#[test]
fn peft_layers_to_transform_index_not_in_list_deselects() {
    let targets = ModuleMatcher::List(vec!["q_proj".to_string()]);
    let peft = PeftSelection {
        target_modules: Some(targets),
        exclude_modules: None,
        layers_to_transform: Some(vec![5]),
        layers_pattern: Vec::new(),
        rank_pattern: Vec::new(),
        alpha_pattern: Vec::new(),
        use_rslora: false,
        fan_in_fan_out: false,
    };

    assert!(!peft_module_is_selected("model.layers.0.self_attn.q_proj", &base_weight(), &peft));
    assert!(peft_module_is_selected("model.layers.5.self_attn.q_proj", &base_weight(), &peft));
}
/// `linear_to_lora_layers` only considers `<module>.weight` entries. These are ignored:
/// - quantization siblings such as `.scales`;
/// - keys without a `.weight` suffix.
#[test]
fn linear_to_lora_layers_skips_non_weight_keys() {
    let mut weights = Weights::new();
    weights.insert(
        "model.layers.0.self_attn.q_proj.weight".to_string(),
        base_weight(),
    );
    weights.insert(
        "model.layers.0.self_attn.q_proj.scales".to_string(),
        Array::zeros::<f32>(&(2, 1)).unwrap(),
    );
    weights.insert("metadata_only_key".to_string(), base_weight());

    let mut params = HashMap::new();
    params.insert(
        "model.layers.0.self_attn.q_proj".to_string(),
        plain_params(),
    );

    let cfg = mlxlm_config(16, keyed_params(vec!["self_attn.q_proj".to_string()]));
    let layers = linear_to_lora_layers(&weights, &cfg, &params, None, 1).unwrap();
    assert_eq!(layers.len(), 1);
    assert!(layers.contains_key("model.layers.0.self_attn.q_proj"));
}
/// A configured LoRA rank of 0 must be rejected as `Error::OutOfRange`.
/// The error names the `rank` field and reports the offending value.
#[test]
fn load_adapters_nonpositive_rank_is_out_of_range() {
    // Removes the scratch dir when the test ends, including on assertion panics.
    // Without this, a failing run leaves the directory behind.
    struct RemoveOnDrop(std::path::PathBuf);
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_zerorank_{}", std::process::id()));
    // Clear leftovers from an earlier run that happened to reuse this pid.
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = RemoveOnDrop(tmp.clone());

    let cfg = LoraConfig {
        fine_tune_type: FineTuneType::Lora,
        lora_parameters: LoraParameters {
            rank: 0,
            scale: Some(2.0),
            alpha: None,
            keys: vec!["self_attn.q_proj".to_string()],
            dropout: None,
        },
        use_dora: false,
        selection: AdapterSelection::MlxLm { num_layers: 16 },
    };
    let weights = toy_weights();
    let err = load_adapters_with_config(&weights, &tmp, &cfg, None, 4).unwrap_err();
    match err {
        Error::OutOfRange(p) => {
            assert!(p.context().contains("rank"));
            assert_eq!(p.value(), "0");
        }
        other => panic!("expected Error::OutOfRange, got {other:?}"),
    }
}
/// With a valid `adapter_config.json` but no `adapters.safetensors`,
/// `load_adapters` must fail with `Error::FileIo`. The error carries:
/// - the expected weights path;
/// - `FileOp::Open` and `NotFound`.
#[test]
fn load_adapters_no_weights_file_at_all_is_not_found() {
    // Removes the scratch dir when the test ends, including on assertion panics.
    // Without this, a failing run leaves the directory behind.
    struct RemoveOnDrop(std::path::PathBuf);
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_nost_{}", std::process::id()));
    // Clear leftovers (e.g. a stray weights file) from an earlier run that
    // reused this pid.
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = RemoveOnDrop(tmp.clone());

    let config = r#"{
"fine_tune_type": "lora", "num_layers": 16,
"lora_parameters": { "rank": 2, "scale": 2.0, "keys": ["self_attn.q_proj"] }
}"#;
    std::fs::write(tmp.join("adapter_config.json"), config).unwrap();
    let weights = toy_weights();
    let err = load_adapters(&weights, &tmp, None, 4).unwrap_err();
    match err {
        Error::FileIo(p) => {
            assert_eq!(p.path(), tmp.join("adapters.safetensors").as_path());
            assert_eq!(p.op(), FileOp::Open);
            assert_eq!(p.inner().kind(), std::io::ErrorKind::NotFound);
        }
        other => panic!("expected Error::FileIo(NotFound), got {other:?}"),
    }
}
/// A lone `p.lora_a` with no partner is a `MissingKey`.
/// The error names `p.lora_b` as the missing key.
#[test]
fn split_adapter_params_lora_a_without_lora_b_is_err() {
    let arrays: HashMap<String, Array> = HashMap::from([("p.lora_a".to_string(), lora_a())]);
    let p = match split_adapter_params(arrays, false).unwrap_err() {
        Error::MissingKey(p) => p,
        other => panic!("expected Error::MissingKey, got {other:?}"),
    };
    assert_eq!(p.key(), "p.lora_b");
    assert!(p.context().contains("no matching `lora_b`"));
}
/// A lone `p.lora_b` with no partner is a `MissingKey`.
/// The error names `p.lora_a` as the missing key.
#[test]
fn split_adapter_params_lora_b_without_lora_a_is_err() {
    let arrays: HashMap<String, Array> = HashMap::from([("p.lora_b".to_string(), lora_b())]);
    let p = match split_adapter_params(arrays, false).unwrap_err() {
        Error::MissingKey(p) => p,
        other => panic!("expected Error::MissingKey, got {other:?}"),
    };
    assert_eq!(p.key(), "p.lora_a");
    assert!(p.context().contains("no matching `lora_a`"));
}
/// `split_adapter_params` behaves differently for DoRA and LoRA:
/// - DoRA mode requires a `<prefix>.m` magnitude alongside the pair;
/// - LoRA mode ignores a stray `.m` and unrelated keys, yielding one group.
#[test]
fn split_adapter_params_dora_missing_m_is_err_but_lora_ignores_stray_m() {
    // DoRA mode: complete lora_a/lora_b pair but no `p.m`.
    let dora_without_m: HashMap<String, Array> = HashMap::from([
        ("p.lora_a".to_string(), lora_a()),
        ("p.lora_b".to_string(), lora_b()),
    ]);
    let err = split_adapter_params(dora_without_m, true).unwrap_err();
    assert!(
        matches!(err, Error::MissingKey(ref p) if p.key() == "p.m"),
        "expected MissingKey naming p.m, got {err:?}"
    );

    // LoRA mode: extra `p.m` and an unrelated weight do not create groups.
    let lora_with_extras: HashMap<String, Array> = HashMap::from([
        ("p.lora_a".to_string(), lora_a()),
        ("p.lora_b".to_string(), lora_b()),
        (
            "p.m".to_string(),
            Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap(),
        ),
        ("unrelated.weight".to_string(), base_weight()),
    ]);
    let out = split_adapter_params(lora_with_extras, false).unwrap();
    assert_eq!(out.len(), 1);
    let p = out.get("p").unwrap();
    assert_eq!(p.lora_a.shape(), &[3, 2]);
    assert_eq!(p.lora_b.shape(), &[2, 2]);
}
/// A DoRA config with rank 4 and alpha 16.0 must parse as follows:
/// - `is_dora()` is true and `rank()` is 4;
/// - `scale()` is 4.0 (alpha / rank for these inputs);
/// - the selection is mlx-lm native with `num_layers: 5`.
#[test]
fn read_adapter_config_parses_valid_config() {
    // Removes the scratch dir when the test ends, including on assertion panics.
    struct RemoveOnDrop(std::path::PathBuf);
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_readcfg_{}", std::process::id()));
    // Clear leftovers from an earlier run that happened to reuse this pid.
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = RemoveOnDrop(tmp.clone());

    let config = r#"{
"fine_tune_type": "dora", "num_layers": 5,
"lora_parameters": { "rank": 4, "alpha": 16.0 }
}"#;
    std::fs::write(tmp.join("adapter_config.json"), config).unwrap();
    let cfg = read_adapter_config(&tmp).unwrap();
    assert!(cfg.is_dora());
    assert_eq!(cfg.rank(), 4);
    assert_eq!(cfg.scale(), 4.0);
    assert!(matches!(
        cfg.selection,
        AdapterSelection::MlxLm { num_layers: 5 }
    ));
}
/// A directory without `adapter_config.json` yields `Error::FileIo`
/// (`FileOp::Open`, `NotFound`) naming the config path.
#[test]
fn read_adapter_config_missing_file_is_file_io() {
    // Removes the scratch dir when the test ends, including on assertion panics.
    struct RemoveOnDrop(std::path::PathBuf);
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_readcfg_missing_{}", std::process::id()));
    // A leftover config from an earlier run with the same pid would make this
    // test observe a present file, so start from an empty directory.
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = RemoveOnDrop(tmp.clone());

    match read_adapter_config(&tmp).unwrap_err() {
        Error::FileIo(p) => {
            assert_eq!(p.path(), tmp.join("adapter_config.json").as_path());
            assert_eq!(p.op(), FileOp::Open);
            assert_eq!(p.inner().kind(), std::io::ErrorKind::NotFound);
        }
        other => panic!("expected Error::FileIo(NotFound, Open), got {other:?}"),
    }
}
/// If `adapter_config.json` exists but is a directory, reading it yields
/// `Error::FileIo` with `FileOp::Stat` and `InvalidInput`.
#[test]
fn read_adapter_config_non_regular_file_is_err() {
    // Removes the scratch dir when the test ends, including on assertion panics.
    struct RemoveOnDrop(std::path::PathBuf);
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_readcfg_dir_{}", std::process::id()));
    // Clear leftovers from an earlier run that happened to reuse this pid.
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = RemoveOnDrop(tmp.clone());

    std::fs::create_dir_all(tmp.join("adapter_config.json")).unwrap();
    match read_adapter_config(&tmp).unwrap_err() {
        Error::FileIo(p) => {
            assert_eq!(p.path(), tmp.join("adapter_config.json").as_path());
            assert_eq!(p.op(), FileOp::Stat);
            assert_eq!(p.inner().kind(), std::io::ErrorKind::InvalidInput);
        }
        other => panic!("expected Error::FileIo(InvalidInput, Stat), got {other:?}"),
    }
}
/// A config body one byte over `MAX_CONFIG_BYTES` must be rejected as
/// `Error::CapExceeded`. The error names the cap and reports the observed size.
#[test]
fn read_adapter_config_oversized_body_is_cap_exceeded() {
    // Removes the scratch dir when the test ends, including on assertion panics.
    // Otherwise a failing run leaves an oversized file behind.
    struct RemoveOnDrop(std::path::PathBuf);
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_readcfg_big_{}", std::process::id()));
    // Clear leftovers from an earlier run that happened to reuse this pid.
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = RemoveOnDrop(tmp.clone());

    let cap = crate::lm::load::MAX_CONFIG_BYTES;
    // Valid JSON (`{}`) padded with whitespace so the file exceeds the cap.
    let mut body = String::from("{}");
    body.push_str(&" ".repeat((cap as usize) + 1));
    std::fs::write(tmp.join("adapter_config.json"), &body).unwrap();
    match read_adapter_config(&tmp).unwrap_err() {
        Error::CapExceeded(p) => {
            assert_eq!(p.cap_name(), "MAX_CONFIG_BYTES");
            assert_eq!(p.cap(), cap);
            assert!(p.observed() > cap);
        }
        other => panic!("expected Error::CapExceeded, got {other:?}"),
    }
}
/// A config file whose bytes are not valid UTF-8 must surface as
/// `Error::Parse`, with a context that mentions `adapter_config.json`.
#[test]
fn read_adapter_config_non_utf8_body_is_parse_error() {
    // Removes the scratch dir when the test ends, including on assertion panics.
    struct RemoveOnDrop(std::path::PathBuf);
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_readcfg_badutf8_{}", std::process::id()));
    // Clear leftovers from an earlier run that happened to reuse this pid.
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = RemoveOnDrop(tmp.clone());

    // 0xFF can never appear in UTF-8.
    std::fs::write(tmp.join("adapter_config.json"), [0xFFu8, 0xFE, 0x00]).unwrap();
    match read_adapter_config(&tmp).unwrap_err() {
        Error::Parse(p) => {
            assert!(p.context().contains("adapter_config.json"));
        }
        other => panic!("expected Error::Parse for non-UTF-8 body, got {other:?}"),
    }
}
/// `base_input_dims` must refuse a quantized base whose weight is 1-D.
/// It reports `RankMismatch` with rank 1 and the actual shape `[8]`.
#[test]
fn base_input_dims_quantized_non_rank2_weight_is_rank_mismatch() {
    // The weight under test is 1-D; scales/biases are 1-D zero vectors too.
    let zeros8 = || Array::zeros::<f32>(&(8usize,)).unwrap();
    let bad_w = Array::zeros::<u32>(&(8usize,)).unwrap();
    let base = BaseLinear::quantized(
        bad_w,
        zeros8(),
        Some(zeros8()),
        None,
        32,
        8,
        "affine".to_string(),
    )
    .unwrap();

    let p = match base_input_dims(&base).unwrap_err() {
        Error::RankMismatch(p) => p,
        other => panic!("expected Error::RankMismatch, got {other:?}"),
    };
    assert_eq!(p.actual(), 1);
    assert_eq!(p.actual_shape(), &[8]);
    assert!(p.context().contains("quantized base weight"));
}
/// `base_output_dims` on a quantized base with a scalar (rank-0) weight must
/// report `RankMismatch` with rank 0, an empty shape, and an `output_dims` context.
#[test]
fn base_output_dims_rank0_quantized_weight_is_rank_mismatch() {
    // An empty shape makes the single value a 0-D array.
    let scalar_shape: [i32; 0] = [];
    let scalar_w = Array::from_slice::<u32>(&[0u32], &scalar_shape).unwrap();
    let one_zero = || Array::zeros::<f32>(&(1usize,)).unwrap();
    let base = BaseLinear::quantized(
        scalar_w,
        one_zero(),
        Some(one_zero()),
        None,
        32,
        8,
        "affine".to_string(),
    )
    .unwrap();

    let p = match base_output_dims(&base).unwrap_err() {
        Error::RankMismatch(p) => p,
        other => panic!("expected Error::RankMismatch, got {other:?}"),
    };
    assert_eq!(p.actual(), 0);
    assert!(p.actual_shape().is_empty());
    assert!(p.context().contains("output_dims"));
}
/// Given a resolvable quantization config plus `.scales`/`.biases` siblings,
/// `build_base_linear` must return a quantized base. That base must:
/// - carry the `.bias` sibling;
/// - approximately reproduce the dense result.
#[test]
fn build_base_linear_quantized_success_builds_quantized_base() {
    let input_dims = 64usize;
    let output_dims = 2usize;
    let path = "model.layers.0.self_attn.q_proj";
    let weight_key = format!("{path}.weight");

    // 8-bit affine quantization with group size 32, applied globally.
    let quant = crate::lm::quant::PerLayerQuantization::from_global(crate::lm::quant::Quantization {
        group_size: 32,
        bits: 8,
        mode: crate::lm::quant::QuantMode::Affine,
    });

    // Dense reference: row 0 all 1.0, row 1 all 0.5.
    let dense_data = [vec![1.0f32; input_dims], vec![0.5f32; input_dims]].concat();
    let dense_w = Array::from_slice::<f32>(&dense_data, &(output_dims, input_dims)).unwrap();
    let (w_q, scales, biases) = ops::quantized::quantize(&dense_w, 32, 8, "affine", None).unwrap();
    let out_bias = Array::from_slice::<f32>(&[10.0, -20.0], &(output_dims,)).unwrap();

    let mut weights = Weights::new();
    weights.insert(weight_key.clone(), w_q);
    weights.insert(format!("{path}.scales"), scales);
    weights.insert(format!("{path}.biases"), biases.unwrap());
    weights.insert(format!("{path}.bias"), out_bias);

    let base = build_base_linear(&weights, path, &weights[&weight_key], Some(&quant), false)
        .unwrap();
    assert!(
        matches!(base, BaseLinear::Quantized { .. }),
        "a resolvable quant + .scales sibling must build a quantized base"
    );
    assert!(base.bias().is_some());

    // All-ones input: row sums 64 and 32, plus bias [10, -20].
    let ones = Array::full::<f32>(&(1usize, input_dims), 1.0).unwrap();
    let mut y = base.base_output(&ones).unwrap();
    approx_eq(&y.to_vec::<f32>().unwrap(), &[74.0, 12.0], 2e-1);
}
/// `requantize_fused` on a dense base returns a dense base.
/// The result uses the supplied fused weight and bias, so the output is
/// the identity-like projection plus [1, 2].
#[test]
fn requantize_fused_on_dense_base_returns_dense_unchanged() {
    let dense = BaseLinear::dense(base_weight(), None).unwrap();
    let fused_weight = base_weight();
    let fused_bias = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap();

    let out = dense
        .requantize_fused(fused_weight, Some(fused_bias))
        .expect("dense requantize_fused must return a dense base");
    assert!(
        matches!(out, BaseLinear::Dense { .. }),
        "requantize_fused over a dense base must stay dense"
    );

    // W·x = [1, 2]; adding the fused bias [1, 2] gives [2, 4].
    let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let mut projected = out.base_output(&input).unwrap();
    approx_eq(&projected.to_vec::<f32>().unwrap(), &[2.0, 4.0], 1e-5);
}
/// `check_adapter_safetensors` on a path that does not exist yields
/// `Error::FileIo` (`FileOp::Open`, `NotFound`) naming that path.
#[test]
fn check_adapter_safetensors_missing_file_is_file_io_open() {
    // Removes the scratch dir when the test ends, including on assertion panics.
    struct RemoveOnDrop(std::path::PathBuf);
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_ckst_missing_{}", std::process::id()));
    // Clear leftovers from an earlier run that happened to reuse this pid.
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = RemoveOnDrop(tmp.clone());

    let missing = tmp.join("does_not_exist.safetensors");
    match check_adapter_safetensors(&missing).unwrap_err() {
        Error::FileIo(p) => {
            assert_eq!(p.path(), missing.as_path());
            assert_eq!(p.op(), FileOp::Open);
            assert_eq!(p.inner().kind(), std::io::ErrorKind::NotFound);
        }
        other => panic!("expected Error::FileIo(NotFound, Open), got {other:?}"),
    }
}
/// A small regular file passes `check_adapter_safetensors`.
/// Its contents are not a valid safetensors blob; per the assertion message,
/// only the pre-mmap stat gate is exercised here.
#[test]
fn check_adapter_safetensors_regular_file_within_cap_is_ok() {
    // Removes the scratch dir when the test ends, including on assertion panics.
    struct RemoveOnDrop(std::path::PathBuf);
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let tmp = std::env::temp_dir().join(format!("mlxrs_ckst_ok_{}", std::process::id()));
    // Clear leftovers from an earlier run that happened to reuse this pid.
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();
    let _cleanup = RemoveOnDrop(tmp.clone());

    let f = tmp.join("adapters.safetensors");
    std::fs::write(
        &f,
        b"not a real safetensors blob, but a regular file under cap",
    )
    .unwrap();
    assert!(
        check_adapter_safetensors(&f).is_ok(),
        "a small regular file must pass the pre-mmap stat gate"
    );
}
/// The DoRA linear variant exposes its `BaseLinear` through `base()`.
/// The base has no bias because it was built without one.
/// `base_embedding()` stays `None`.
#[test]
fn lora_layer_base_on_dora_variant_returns_some() {
    let adapter = AdapterParams {
        lora_a: lora_a(),
        lora_b: lora_b(),
        magnitude: Some(Array::from_slice::<f32>(&[3.0, 3.0], &(2usize,)).unwrap()),
    };
    let inner = DoRALinear::new(BaseLinear::dense(base_weight(), None).unwrap(), adapter, 2.0)
        .unwrap();
    let dora = LoraLayer::Dora(inner);

    let base = dora.base().expect("Dora variant exposes a BaseLinear");
    assert!(base.bias().is_none());
    assert!(dora.base_embedding().is_none());
}
/// When only `lora_b` disagrees with the configured rank, validation in the
/// DoRA-linear context reports a `LengthMismatch`. The error:
/// - names `DoRALinear` and `lora_b` in its context;
/// - reports expected 2 and actual 3.
#[test]
fn validate_config_rank_dora_context_lora_b_only_drift() {
    // lora_a is (in, r) = (3, 2), matching the config rank.
    // lora_b is (r, out) = (3, 2), so it carries rank 3.
    let params = AdapterParams {
        lora_a: Array::zeros::<f32>(&(3usize, 2usize)).unwrap(),
        lora_b: Array::zeros::<f32>(&(3usize, 2usize)).unwrap(),
        magnitude: None,
    };
    let err = validate_config_rank(&params, 2, LinearValidationContext::DoraLinear).unwrap_err();
    match err {
        Error::LengthMismatch(p) => {
            assert!(p.context().contains("DoRALinear"));
            assert!(p.context().contains("lora_b"));
            assert_eq!(p.expected(), 2);
            assert_eq!(p.actual(), 3);
        }
        other => panic!("expected Error::LengthMismatch on lora_b (DoRA context), got {other:?}"),
    }
}
/// Tied-head use of an embedding table: `as_linear` computes x · Wᵀ.
/// With x = [1, 0, 1], each output is the sum of columns 0 and 2 of the
/// matching row: [1 + 3, 4 + 6].
#[test]
fn base_embedding_as_linear_tied_head_hand_traced() {
    let table = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &(2, 3)).unwrap();
    let selector = Array::from_slice::<f32>(&[1.0, 0.0, 1.0], &(1, 3)).unwrap();
    let mut out = BaseEmbedding::dense(table).unwrap().as_linear(&selector).unwrap();
    assert_eq!(out.shape(), &[1, 2]);
    approx_eq(&out.to_vec::<f32>().unwrap(), &[4.0, 10.0], 1e-5);
}