use std::path::{Path, PathBuf};
use crate::{
error::{
ConvertPostSavePartialPayload, DurabilityWarningPayload, Error, FileIoPayload, FileOp,
InvariantViolationPayload, LayerKeyedPayload, ParsePayload, Result,
},
lm::{
convert::{self, TOKENIZER_EXTRA_FILES},
load::{self, Weights},
lora::{self, BaseEmbedding, BaseLinear, LoraLayer},
quant::{self, PerLayerQuantization},
},
};
pub fn fuse(
model_path: &Path,
adapter_path: &Path,
save_path: &Path,
dequantize: bool,
) -> Result<()> {
reject_hub_url("model_path", model_path)?;
reject_hub_url("adapter_path", adapter_path)?;
let (cfg_typed, config_json_text) = load::load_config(model_path)?;
let weights = load::load_weights(model_path)?;
std::fs::create_dir_all(save_path).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"fuse: cannot create save_path",
FileOp::Create,
save_path.to_path_buf(),
e,
))
})?;
let staging = StagingDir::create(save_path)?;
match convert::copy_tokenizer_and_extras(model_path, staging.path()) {
Ok(_outcome) => {
}
Err(snapshot_err) => {
return Err(snapshot_err);
}
}
if let Err(e) = load::load_tokenizer(staging.path(), &cfg_typed) {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
model_path.display().to_string(),
e,
)));
}
let parsed_quant = quant::parse_quantization(&config_json_text)?;
let lora_cfg = lora::read_adapter_config(adapter_path)?;
let fan_in_fan_out = lora_cfg.fan_in_fan_out();
let layers = lora::load_adapters_with_config(
&weights,
adapter_path,
&lora_cfg,
parsed_quant.as_ref(),
cfg_typed.num_hidden_layers,
)?;
let mut weights = weights;
for (path, layer) in &layers {
apply_fuse_to_weights(&mut weights, path, layer, dequantize, fan_in_fan_out)?;
}
let (out_weights, out_config_json, save_quant) = if dequantize {
let stripped_config = strip_quantization_blocks(&config_json_text)?;
let walk_quant = parsed_quant.unwrap_or_default();
let dense_weights = quant::dequantize_weights(weights, &walk_quant)?;
(
dense_weights,
stripped_config,
PerLayerQuantization::default(),
)
} else {
let save_quant = parsed_quant.unwrap_or_default();
(weights, config_json_text, save_quant)
};
let save_warning: Option<std::io::Error> =
match load::save(save_path, &out_weights, &out_config_json, &save_quant) {
Ok(()) => None,
Err(Error::DurabilityWarning(p)) if p.committed() => Some(p.into_source()),
Err(e) => return Err(e),
};
let promote_outcome = promote_staging_into_save_path(staging, save_path);
match promote_outcome {
Ok(post_promote) => {
let aggregate = crate::error::ConvertDurabilityWarnings {
committed: true,
save: save_warning,
post_copy_file: post_promote.post_promote_file,
post_copy_dir: post_promote.post_promote_dir,
};
match aggregate.count() {
0 => Ok(()),
1 => {
let (_, save, post_copy_file, post_copy_dir) = aggregate.into_parts();
let source = save
.or(post_copy_file)
.or(post_copy_dir)
.expect("count() == 1 guarantees exactly one Some field");
Err(Error::DurabilityWarning(DurabilityWarningPayload::new(
true, source,
)))
}
_ => Err(Error::ConvertDurabilityWarnings(aggregate)),
}
}
Err(promote_err) => Err(Error::ConvertPostSavePartial(
ConvertPostSavePartialPayload::new(true, save_warning, promote_err),
)),
}
}
/// Guard for a scratch directory created beneath a parent directory.
///
/// While the guard still holds the path, dropping it removes the directory tree;
/// [`StagingDir::consume`] hands the path to the caller and disarms that cleanup.
#[derive(Debug)]
struct StagingDir {
    /// `Some` while the guard owns the directory, `None` after `consume` or `drop`.
    path: Option<PathBuf>,
}

impl StagingDir {
    /// Creates a fresh `.staging-fuse-<pid>-<hex>` directory inside `parent`.
    ///
    /// The hex suffix mixes the current wall-clock nanoseconds with a process-wide
    /// counter. A name collision (`AlreadyExists`) is retried up to `MAX_RETRIES` times.
    ///
    /// # Errors
    ///
    /// Returns `Error::FileIo` with `FileOp::Create` and `parent` as the path when
    /// `create_dir` fails for any other reason, or when every retry collided.
    fn create(parent: &Path) -> Result<Self> {
        use std::{
            io::ErrorKind,
            sync::atomic::{AtomicU64, Ordering},
            time::{SystemTime, UNIX_EPOCH},
        };

        static COUNTER: AtomicU64 = AtomicU64::new(0);
        const MAX_RETRIES: u32 = 16;

        let pid = std::process::id();
        let mut last_collision: Option<std::io::Error> = None;

        for _attempt in 0..MAX_RETRIES {
            // A clock before the epoch contributes 0; the counter still varies the name.
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map_or(0, |elapsed| elapsed.as_nanos() as u64);
            let tick = COUNTER.fetch_add(1, Ordering::Relaxed);
            let rand = nanos ^ tick.rotate_left(17);
            let candidate = parent.join(format!(".staging-fuse-{pid}-{rand:016x}"));

            match std::fs::create_dir(&candidate) {
                Ok(()) => {
                    return Ok(Self {
                        path: Some(candidate),
                    });
                }
                Err(e) if e.kind() == ErrorKind::AlreadyExists => last_collision = Some(e),
                Err(e) => {
                    return Err(Error::FileIo(FileIoPayload::new(
                        "fuse: cannot create staging dir",
                        FileOp::Create,
                        parent.to_path_buf(),
                        e,
                    )));
                }
            }
        }

        let source =
            last_collision.unwrap_or_else(|| std::io::Error::from(ErrorKind::AlreadyExists));
        Err(Error::FileIo(FileIoPayload::new(
            "fuse: exhausted staging-dir create_dir retries (MAX_RETRIES collisions — likely a \
             hostile staging-dir race or a filesystem refusing mkdir)",
            FileOp::Create,
            parent.to_path_buf(),
            source,
        )))
    }

    /// Borrows the staging directory path.
    ///
    /// # Panics
    ///
    /// Panics if the path has already been taken out of the guard.
    fn path(&self) -> &Path {
        self.path
            .as_ref()
            .map(PathBuf::as_path)
            .expect("StagingDir::path called after consume — should be unreachable")
    }

    /// Takes the path out of the guard so that dropping it no longer removes the
    /// directory; the caller becomes responsible for it.
    ///
    /// # Panics
    ///
    /// Panics if the path has already been taken out of the guard.
    fn consume(mut self) -> PathBuf {
        let taken = self.path.take();
        taken.expect("StagingDir::consume called twice — should be unreachable")
    }
}
impl Drop for StagingDir {
fn drop(&mut self) {
if let Some(path) = self.path.take()
&& let Err(e) = std::fs::remove_dir_all(&path)
{
eprintln!(
"fuse: warning — could not remove staging dir {}: {e}",
path.display()
);
}
}
}
/// Non-fatal sync failures collected while promoting staged files into the save
/// directory; they are returned alongside a successful promotion rather than as `Err`.
struct PostPromoteWarnings {
/// First error reported by `fsync_path_io` for a promoted file, if any.
post_promote_file: Option<std::io::Error>,
/// Error reported by `fsync_dir` for the save directory, if any.
post_promote_dir: Option<std::io::Error>,
}
fn promote_staging_into_save_path(
staging: StagingDir,
save_path: &Path,
) -> Result<PostPromoteWarnings> {
let post_promote_file = promote_staging_inner(staging.path(), save_path)?;
let staging_path = staging.consume();
if let Err(e) = std::fs::remove_dir(&staging_path) {
eprintln!(
"fuse: warning — could not remove empty staging dir {}: {e}",
staging_path.display()
);
}
let post_promote_dir = crate::lm::load::fsync_dir(save_path).err();
Ok(PostPromoteWarnings {
post_promote_file,
post_promote_dir,
})
}
fn promote_staging_inner(staging: &Path, save_path: &Path) -> Result<Option<std::io::Error>> {
use std::collections::HashSet;
let mut staged_names: HashSet<String> = HashSet::new();
let entries = std::fs::read_dir(staging).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"fuse: cannot read staging dir",
FileOp::Read,
staging.to_path_buf(),
e,
))
})?;
let mut staged_paths: Vec<PathBuf> = Vec::new();
for entry in entries {
let entry = entry.map_err(|e| {
Error::FileIo(FileIoPayload::new(
"fuse: cannot read entry in staging dir",
FileOp::Read,
staging.to_path_buf(),
e,
))
})?;
let path = entry.path();
if !path.is_file() {
continue;
}
let Some(name) = path.file_name().and_then(|n| n.to_str()) else {
continue;
};
staged_names.insert(name.to_string());
staged_paths.push(path);
}
let mut post_promote_file: Option<std::io::Error> = None;
for staged_path in &staged_paths {
let Some(name) = staged_path.file_name() else {
continue;
};
let dst = save_path.join(name);
std::fs::rename(staged_path, &dst).map_err(|e| {
crate::Error::FileIo(FileIoPayload::new(
"fuse: cannot promote staged file to",
crate::error::FileOp::Rename,
dst.clone(),
e,
))
})?;
if let Err(e) = crate::lm::load::fsync_path_io(&dst) {
if post_promote_file.is_none() {
post_promote_file = Some(std::io::Error::new(
e.kind(),
format!("fuse: fsync {} failed: {e}", dst.display()),
));
}
}
}
for name in TOKENIZER_EXTRA_FILES {
if staged_names.contains(*name) {
continue;
}
let candidate = save_path.join(name);
remove_stale_reserved_path(&candidate, name)?;
}
let entries = std::fs::read_dir(save_path).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"fuse: cannot read save_path",
FileOp::Read,
save_path.to_path_buf(),
e,
))
})?;
for entry in entries {
let entry = entry.map_err(|e| {
Error::FileIo(FileIoPayload::new(
"fuse: cannot read entry in save_path",
FileOp::Read,
save_path.to_path_buf(),
e,
))
})?;
let path = entry.path();
let Some(name) = path.file_name().and_then(|n| n.to_str()) else {
continue;
};
if !name.ends_with(".py") {
continue;
}
if staged_names.contains(name) {
continue;
}
remove_stale_reserved_path(&path, name)?;
}
Ok(post_promote_file)
}
fn remove_stale_reserved_path(path: &Path, name: &str) -> Result<()> {
let meta = match std::fs::symlink_metadata(path) {
Ok(m) => m,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
Err(e) => {
return Err(Error::FileIo(FileIoPayload::new(
"fuse: cannot stat stale destination path",
FileOp::Stat,
path.to_path_buf(),
e,
)));
}
};
let ft = meta.file_type();
if ft.is_file() {
std::fs::remove_file(path).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"fuse: cannot remove stale destination file",
FileOp::Remove,
path.to_path_buf(),
e,
))
})?;
return Ok(());
}
let kind: &'static str = if ft.is_symlink() {
"symlink"
} else if ft.is_dir() {
"directory"
} else {
"non-regular file (FIFO, socket, or device)"
};
Err(Error::LayerKeyed(LayerKeyedPayload::new(
format!("{} ({})", path.display(), name),
Error::InvariantViolation(InvariantViolationPayload::new(
"fuse: stale destination path",
match kind {
"symlink" => {
"must not be a symlink (non-regular reserved path; remove manually or use a fresh save destination)"
}
"directory" => {
"must not be a directory (non-regular reserved path; remove manually or use a fresh save destination)"
}
_ => {
"must not be a non-regular file (FIFO, socket, or device) (non-regular reserved path; remove manually or use a fresh save destination)"
}
},
)),
)))
}
/// Replaces the tensors stored under `path` in `weights` with the fused output of `layer`.
///
/// The existing `.weight`, `.scales`, `.biases` and `.bias` entries for `path` are removed
/// first, so only the keys written by the fused output remain afterwards. Linear layers
/// are fused with `layer.fuse(dequantize)` and written by `insert_base_linear`; embedding
/// layers are fused with `layer.fuse_embedding()` and written by `insert_base_embedding`
/// (`dequantize` and `fan_in_fan_out` are not used on that path).
///
/// # Errors
///
/// Propagates failures from fusing the layer or from `insert_base_linear`. The old
/// entries have already been removed when such an error is returned.
fn apply_fuse_to_weights(
    weights: &mut Weights,
    path: &str,
    layer: &LoraLayer,
    dequantize: bool,
    fan_in_fan_out: bool,
) -> Result<()> {
    for suffix in ["weight", "scales", "biases", "bias"] {
        weights.remove(&format!("{path}.{suffix}"));
    }

    match layer {
        LoraLayer::DoraEmbedding(_) => {
            let fused_embedding = layer.fuse_embedding()?;
            insert_base_embedding(weights, path, fused_embedding);
            Ok(())
        }
        LoraLayer::Lora(_) | LoraLayer::Dora(_) => {
            let fused_linear = layer.fuse(dequantize)?;
            insert_base_linear(weights, path, fused_linear, fan_in_fan_out)
        }
    }
}
/// Writes a fused linear layer into `weights` under the `path` prefix.
///
/// - Dense output: writes `{path}.weight`, transposed first when `fan_in_fan_out` is set.
/// - Quantized output: writes `{path}.weight` and `{path}.scales`, plus `{path}.biases`
///   when quantization biases are present. The weight is never transposed here; debug
///   builds assert that `fan_in_fan_out` is false.
///
/// In both cases `{path}.bias` is written when the layer carries a bias.
///
/// # Errors
///
/// Propagates a failure from transposing the dense weight; nothing has been inserted
/// when that happens.
fn insert_base_linear(
    weights: &mut Weights,
    path: &str,
    fused: BaseLinear,
    fan_in_fan_out: bool,
) -> Result<()> {
    // Each arm writes its own tensors and hands back the optional layer bias, which is
    // stored the same way for both representations.
    let layer_bias = match fused {
        BaseLinear::Dense { weight, bias } => {
            let persisted = match fan_in_fan_out {
                true => weight.transpose()?,
                false => weight,
            };
            weights.insert(format!("{path}.weight"), persisted);
            bias
        }
        BaseLinear::Quantized {
            weight,
            scales,
            quant_biases,
            bias,
            ..
        } => {
            debug_assert!(
                !fan_in_fan_out,
                "insert_base_linear: fan_in_fan_out=true reached a Quantized fused output for \
                 {path:?}; the load side rejects this combination (lora.rs::build_base_linear \
                 3212-3221) — a packed quantized weight cannot be transposed without corrupting \
                 the bit-packing"
            );
            weights.insert(format!("{path}.weight"), weight);
            weights.insert(format!("{path}.scales"), scales);
            if let Some(qb) = quant_biases {
                weights.insert(format!("{path}.biases"), qb);
            }
            bias
        }
    };

    if let Some(b) = layer_bias {
        weights.insert(format!("{path}.bias"), b);
    }
    Ok(())
}
/// Writes a fused embedding into `weights` as `{path}.weight`; no other key is written.
fn insert_base_embedding(weights: &mut Weights, path: &str, fused: BaseEmbedding) {
    // `Dense` is the only representation matched here, so the pattern is irrefutable.
    let BaseEmbedding::Dense { weight } = fused;
    weights.insert(format!("{path}.weight"), weight);
}
/// Returns `config_json` re-serialized without its top-level `quantization` and
/// `quantization_config` entries.
///
/// All other top-level entries are carried over. The output is produced by
/// `serde_json::to_string`, so the original whitespace is not preserved.
///
/// # Errors
///
/// - `Error::Parse` when the input is not valid JSON or re-serialization fails.
/// - `Error::InvariantViolation` when the JSON root is not an object.
fn strip_quantization_blocks(config_json: &str) -> Result<String> {
    let mut root: serde_json::Value = serde_json::from_str(config_json)
        .map_err(|e| Error::Parse(ParsePayload::new("fuse: source config", "JSON", e)))?;

    let Some(fields) = root.as_object_mut() else {
        return Err(Error::InvariantViolation(InvariantViolationPayload::new(
            "fuse: source config JSON",
            "must be an object",
        )));
    };
    for key in ["quantization", "quantization_config"] {
        fields.remove(key);
    }

    serde_json::to_string(&root).map_err(|e| {
        Error::Parse(ParsePayload::new(
            "fuse: cannot re-serialize stripped config",
            "JSON",
            e,
        ))
    })
}
fn reject_hub_url(arg_name: &'static str, path: &Path) -> Result<()> {
let Some(s) = path.to_str() else {
return Ok(());
};
let repo_id = s
.strip_prefix("hf://")
.or_else(|| s.strip_prefix("https://huggingface.co/"))
.or_else(|| s.strip_prefix("http://huggingface.co/"));
if let Some(_repo_id) = repo_id {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
s.to_string(),
Error::InvariantViolation(InvariantViolationPayload::new(
arg_name,
"must be a LOCAL path, not a HuggingFace Hub URL (mlxrs is local-only and does \
not download from the Hub; fetch the model directory out of process — e.g. \
`huggingface-cli download <repo_id>` or `hf download <repo_id>` — and pass the \
resulting local path)",
)),
)));
}
Ok(())
}
// Unit tests for the private helpers of the fuse pipeline: URL rejection, config
// stripping, the staging-directory guard, the stale-path sweep, promotion, and the
// weight-map insertion helpers. Filesystem tests work inside per-test temp directories
// created by `fresh_dir`.
#[cfg(test)]
mod tests {
use super::*;
// An `hf://` argument is rejected; the error is keyed by the full URL and the inner
// violation names the rejected argument.
#[test]
fn reject_hub_url_strips_hf_prefix_in_hint() {
let err =
reject_hub_url("model_path", Path::new("hf://mlx-community/Qwen3-4B-bf16")).unwrap_err();
let Error::LayerKeyed(payload) = err else {
panic!("expected Error::LayerKeyed");
};
assert_eq!(
payload.layer(),
"hf://mlx-community/Qwen3-4B-bf16",
"carrier layer must be the rejected URL",
);
let Error::InvariantViolation(inner) = payload.inner() else {
panic!("expected inner Error::InvariantViolation");
};
assert_eq!(
inner.context(),
"model_path",
"inner context must name the rejected arg"
);
}
// An `https://huggingface.co/` argument is rejected and the error is keyed by the URL.
#[test]
fn reject_hub_url_strips_https_prefix_in_hint() {
let err = reject_hub_url(
"adapter_path",
Path::new("https://huggingface.co/owner/repo"),
)
.unwrap_err();
let Error::LayerKeyed(payload) = err else {
panic!("expected Error::LayerKeyed");
};
assert_eq!(
payload.layer(),
"https://huggingface.co/owner/repo",
"carrier layer must be the rejected URL"
);
}
// Absolute, relative, tilde-prefixed and bare local paths are all accepted.
#[test]
fn reject_hub_url_passes_through_local_paths() {
assert!(reject_hub_url("model_path", Path::new("/tmp/model")).is_ok());
assert!(reject_hub_url("model_path", Path::new("./relative/path")).is_ok());
assert!(reject_hub_url("model_path", Path::new("~/local/path")).is_ok());
assert!(reject_hub_url("model_path", Path::new("local")).is_ok());
}
// Both quantization keys are removed while unrelated keys survive.
#[test]
fn strip_quantization_blocks_removes_both_keys() {
let src = r#"{
"model_type": "qwen3",
"quantization": { "group_size": 64, "bits": 4 },
"quantization_config": { "group_size": 64, "bits": 4 }
}"#;
let stripped = strip_quantization_blocks(src).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
assert!(
parsed.get("quantization").is_none(),
"`quantization` removed"
);
assert!(
parsed.get("quantization_config").is_none(),
"`quantization_config` removed"
);
assert_eq!(
parsed.get("model_type").and_then(|v| v.as_str()),
Some("qwen3")
);
}
// A config without quantization keys round-trips with its values intact.
#[test]
fn strip_quantization_blocks_passes_through_without_keys() {
let src = r#"{ "model_type": "qwen3", "hidden_size": 16 }"#;
let stripped = strip_quantization_blocks(src).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
assert_eq!(
parsed.get("model_type").and_then(|v| v.as_str()),
Some("qwen3")
);
assert_eq!(parsed.get("hidden_size").and_then(|v| v.as_i64()), Some(16));
}
// A JSON array at the root is rejected as an invariant violation.
#[test]
fn strip_quantization_blocks_rejects_non_object_root() {
let src = "[1, 2, 3]";
let err = strip_quantization_blocks(src).unwrap_err();
let Error::InvariantViolation(p) = err else {
panic!("expected Error::InvariantViolation, got {err:?}");
};
assert_eq!(p.context(), "fuse: source config JSON");
assert_eq!(p.requirement(), "must be an object");
}
// Test fixture: returns a freshly created, empty directory under the system temp dir.
// The name combines `tag`, the process id and a per-process counter; any leftover
// directory with the same name is removed first.
fn fresh_dir(tag: &str) -> PathBuf {
use std::sync::atomic::{AtomicU64, Ordering};
static COUNTER: AtomicU64 = AtomicU64::new(0);
let n = COUNTER.fetch_add(1, Ordering::Relaxed);
let dir = std::env::temp_dir().join(format!("mlxrs-fuse-ut-{tag}-{}-{n}", std::process::id()));
let _ = std::fs::remove_dir_all(&dir);
std::fs::create_dir_all(&dir).unwrap();
dir
}
// `create` makes a `.staging-fuse-*` directory directly under the parent, and `consume`
// returns its path without deleting it.
#[test]
fn staging_dir_create_path_consume_keeps_dir() {
let parent = fresh_dir("staging_create");
let staging = StagingDir::create(&parent).unwrap();
let p = staging.path().to_path_buf();
assert!(p.is_dir(), "staging dir must exist after create");
assert_eq!(p.parent(), Some(parent.as_path()), "staged under parent");
assert!(
p.file_name()
.and_then(|n| n.to_str())
.is_some_and(|n| n.starts_with(".staging-fuse-")),
"staging dir basename must carry the `.staging-fuse-` prefix",
);
let returned = staging.consume();
assert_eq!(returned, p, "consume returns the staged path");
assert!(returned.is_dir(), "consume must NOT remove the dir");
std::fs::remove_dir_all(&parent).unwrap();
}
// Dropping an un-consumed guard removes the staging directory.
#[test]
fn staging_dir_drop_removes_dir() {
let parent = fresh_dir("staging_drop");
let staged_path;
{
let staging = StagingDir::create(&parent).unwrap();
staged_path = staging.path().to_path_buf();
assert!(staged_path.is_dir());
}
assert!(
!staged_path.exists(),
"Drop must remove the un-consumed staging dir",
);
std::fs::remove_dir_all(&parent).unwrap();
}
// Dropping the guard after the directory was removed externally must not panic.
#[test]
fn staging_dir_drop_tolerates_missing_dir() {
let parent = fresh_dir("staging_drop_missing");
{
let staging = StagingDir::create(&parent).unwrap();
std::fs::remove_dir_all(staging.path()).unwrap();
}
std::fs::remove_dir_all(&parent).unwrap();
}
// With a missing parent, `create` fails with a `FileIo`/`Create` error whose path is the
// parent and whose source is `NotFound`.
#[test]
fn staging_dir_create_errors_when_parent_missing() {
let parent = fresh_dir("staging_no_parent");
let missing = parent.join("does-not-exist");
let err = StagingDir::create(&missing).unwrap_err();
let Error::FileIo(p) = err else {
panic!("expected Error::FileIo, got {err:?}");
};
assert_eq!(p.op(), FileOp::Create);
assert_eq!(p.path(), missing.as_path(), "payload names the parent");
assert_eq!(p.inner().kind(), std::io::ErrorKind::NotFound);
std::fs::remove_dir_all(&parent).unwrap();
}
// A path that does not exist is treated as already clean.
#[test]
fn remove_stale_reserved_path_absent_is_ok() {
let dir = fresh_dir("rsrp_absent");
let missing = dir.join("generation_config.json");
remove_stale_reserved_path(&missing, "generation_config.json").unwrap();
std::fs::remove_dir_all(&dir).unwrap();
}
// A stale regular file is unlinked.
#[test]
fn remove_stale_reserved_path_removes_regular_file() {
let dir = fresh_dir("rsrp_file");
let path = dir.join("chat_template.jinja");
std::fs::write(&path, b"stale").unwrap();
assert!(path.is_file());
remove_stale_reserved_path(&path, "chat_template.jinja").unwrap();
assert!(!path.exists(), "stale regular file must be unlinked");
std::fs::remove_dir_all(&dir).unwrap();
}
// A directory at a reserved name is rejected with a "directory" requirement and left in
// place.
#[test]
fn remove_stale_reserved_path_rejects_directory() {
let dir = fresh_dir("rsrp_dir");
let path = dir.join("tokenizer_config.json");
std::fs::create_dir_all(&path).unwrap();
let err = remove_stale_reserved_path(&path, "tokenizer_config.json").unwrap_err();
let Error::LayerKeyed(payload) = err else {
panic!("expected Error::LayerKeyed, got {err:?}");
};
assert!(
payload.layer().contains("tokenizer_config.json"),
"carrier layer must name the offending basename; got {}",
payload.layer(),
);
let Error::InvariantViolation(inner) = payload.inner() else {
panic!("expected inner Error::InvariantViolation");
};
assert_eq!(inner.context(), "fuse: stale destination path");
assert!(
inner.requirement().contains("directory"),
"requirement must mark the kind as directory; got {}",
inner.requirement(),
);
assert!(path.is_dir(), "directory must NOT be removed by the sweep");
std::fs::remove_dir_all(&dir).unwrap();
}
// A symlink at a reserved name is rejected with a "symlink" requirement and left in
// place (Unix only).
#[cfg(unix)]
#[test]
fn remove_stale_reserved_path_rejects_symlink() {
let dir = fresh_dir("rsrp_symlink");
let target = dir.join("real_target_dir");
std::fs::create_dir_all(&target).unwrap();
let link = dir.join("special_tokens_map.json");
std::os::unix::fs::symlink(&target, &link).unwrap();
let err = remove_stale_reserved_path(&link, "special_tokens_map.json").unwrap_err();
let Error::LayerKeyed(payload) = err else {
panic!("expected Error::LayerKeyed, got {err:?}");
};
let Error::InvariantViolation(inner) = payload.inner() else {
panic!("expected inner Error::InvariantViolation");
};
assert!(
inner.requirement().contains("symlink"),
"requirement must mark the kind as symlink; got {}",
inner.requirement(),
);
assert!(
link.symlink_metadata().is_ok(),
"symlink must NOT be removed by the sweep",
);
std::fs::remove_dir_all(&dir).unwrap();
}
// A FIFO at a reserved name is rejected with a "non-regular" requirement (Unix only).
#[cfg(unix)]
#[test]
fn remove_stale_reserved_path_rejects_fifo() {
use std::ffi::CString;
let dir = fresh_dir("rsrp_fifo");
let path = dir.join("vocab.json");
let c = CString::new(path.as_os_str().to_str().unwrap()).unwrap();
// SAFETY: `c` is a NUL-terminated `CString` that stays alive for the whole call, so
// the pointer passed to `mkfifo` is valid for its duration.
let rc = unsafe { libc::mkfifo(c.as_ptr(), 0o644) };
assert_eq!(rc, 0, "mkfifo must succeed for the fixture");
let err = remove_stale_reserved_path(&path, "vocab.json").unwrap_err();
let Error::LayerKeyed(payload) = err else {
panic!("expected Error::LayerKeyed, got {err:?}");
};
let Error::InvariantViolation(inner) = payload.inner() else {
panic!("expected inner Error::InvariantViolation");
};
assert!(
inner.requirement().contains("non-regular"),
"requirement must mark the FIFO as non-regular; got {}",
inner.requirement(),
);
let _ = std::fs::remove_file(&path);
std::fs::remove_dir_all(&dir).unwrap();
}
// Staged regular files are promoted, stale reserved files and stale `*.py` files are
// swept, unrelated files are untouched, and staged subdirectories stay in staging.
#[test]
fn promote_staging_inner_promotes_and_sweeps() {
let save = fresh_dir("promote_inner");
let staging = save.join(".staging-fuse-test");
std::fs::create_dir_all(&staging).unwrap();
std::fs::write(staging.join("tokenizer.json"), b"{}\n").unwrap();
std::fs::write(staging.join("keep_me.py"), b"# keep\n").unwrap();
std::fs::create_dir_all(staging.join("a_subdir")).unwrap();
std::fs::write(save.join("generation_config.json"), b"stale").unwrap();
std::fs::write(save.join("stale_mod.py"), b"# stale").unwrap();
std::fs::write(save.join("config.json"), b"{}\n").unwrap();
let warn = promote_staging_inner(&staging, &save).unwrap();
assert!(warn.is_none(), "no fsync fault injected → no warning");
assert!(save.join("tokenizer.json").is_file(), "tokenizer promoted");
assert!(
save.join("keep_me.py").is_file(),
"snapshot *.py promoted + survives the sweep",
);
assert!(
!save.join("generation_config.json").exists(),
"stale TOKENIZER_EXTRA_FILES member must be swept",
);
assert!(
!save.join("stale_mod.py").exists(),
"stale *.py must be swept",
);
assert!(save.join("config.json").is_file(), "config.json untouched");
assert!(
staging.join("a_subdir").is_dir(),
"non-regular staged entry must be skipped (left in staging)",
);
std::fs::remove_dir_all(&save).unwrap();
}
// A missing staging directory surfaces as a `FileIo`/`Read` error naming that directory.
#[test]
fn promote_staging_inner_read_dir_error() {
let save = fresh_dir("promote_inner_readdir");
let missing_staging = save.join(".staging-fuse-absent");
let err = promote_staging_inner(&missing_staging, &save).unwrap_err();
let Error::FileIo(p) = err else {
panic!("expected Error::FileIo, got {err:?}");
};
assert_eq!(p.op(), FileOp::Read);
assert_eq!(p.path(), missing_staging.as_path());
std::fs::remove_dir_all(&save).unwrap();
}
// A directory occupying a reserved file name in the destination makes promotion fail
// with a `LayerKeyed` error naming it.
#[test]
fn promote_staging_inner_rejects_stale_directory_member() {
let save = fresh_dir("promote_inner_dir_member");
let staging = save.join(".staging-fuse-test");
std::fs::create_dir_all(&staging).unwrap();
std::fs::write(staging.join("tokenizer.json"), b"{}\n").unwrap();
std::fs::create_dir_all(save.join("added_tokens.json")).unwrap();
let err = promote_staging_inner(&staging, &save).unwrap_err();
let Error::LayerKeyed(payload) = err else {
panic!("expected Error::LayerKeyed, got {err:?}");
};
assert!(payload.layer().contains("added_tokens.json"));
std::fs::remove_dir_all(&save).unwrap();
}
// The outer promotion moves the staged file and removes the emptied staging directory.
#[test]
fn promote_staging_into_save_path_success() {
let save = fresh_dir("promote_outer_ok");
let staging = StagingDir::create(&save).unwrap();
let staging_path = staging.path().to_path_buf();
std::fs::write(staging_path.join("tokenizer.json"), b"{}\n").unwrap();
let out = promote_staging_into_save_path(staging, &save).unwrap();
assert!(out.post_promote_file.is_none());
assert!(
save.join("tokenizer.json").is_file(),
"staged file promoted to save_path",
);
assert!(
!staging_path.exists(),
"empty staging dir removed on success",
);
std::fs::remove_dir_all(&save).unwrap();
}
// When staging still holds a subdirectory after promotion, the call succeeds and the
// non-empty staging directory is left in place.
#[test]
fn promote_staging_into_save_path_warns_on_nonempty_staging() {
let save = fresh_dir("promote_outer_nonempty");
let staging = StagingDir::create(&save).unwrap();
let staging_path = staging.path().to_path_buf();
std::fs::write(staging_path.join("tokenizer.json"), b"{}\n").unwrap();
std::fs::create_dir_all(staging_path.join("stray_subdir")).unwrap();
let out = promote_staging_into_save_path(staging, &save).unwrap();
assert!(out.post_promote_file.is_none());
assert!(save.join("tokenizer.json").is_file(), "file still promoted");
assert!(
staging_path.join("stray_subdir").is_dir(),
"remove_dir leaves the non-empty staging dir in place (warn-only)",
);
std::fs::remove_dir_all(&save).unwrap();
}
// A path that is not valid UTF-8 cannot be matched against the URL prefixes and is
// accepted (Unix only).
#[cfg(unix)]
#[test]
fn reject_hub_url_passes_non_utf8_path() {
use std::{ffi::OsStr, os::unix::ffi::OsStrExt};
let bytes = b"/tmp/\xff\xfe-not-utf8";
let p = Path::new(OsStr::from_bytes(bytes));
assert!(
reject_hub_url("model_path", p).is_ok(),
"non-utf8 path is not a hub URL → Ok",
);
}
// Dense output without fan_in_fan_out: weight and bias are stored verbatim and no
// quantization keys are written.
#[test]
fn insert_base_linear_dense_with_bias() {
let mut weights: Weights = std::collections::HashMap::new();
let w = crate::Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &(2, 3)).unwrap();
let b = crate::Array::from_slice::<f32>(&[10.0, 20.0], &(2usize,)).unwrap();
let fused = BaseLinear::dense(w, Some(b)).unwrap();
insert_base_linear(&mut weights, "model.layer", fused, false).unwrap();
let mut wt = weights
.remove("model.layer.weight")
.expect(".weight written");
assert_eq!(
wt.shape(),
vec![2, 3],
"weight shape preserved (no transpose)"
);
assert_eq!(
wt.to_vec::<f32>().unwrap(),
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
"weight bytes inserted verbatim",
);
let mut bias = weights.remove("model.layer.bias").expect(".bias written");
assert_eq!(bias.to_vec::<f32>().unwrap(), vec![10.0, 20.0]);
assert!(
!weights.contains_key("model.layer.scales"),
"dense output writes no .scales",
);
assert!(
!weights.contains_key("model.layer.biases"),
"dense output writes no .biases",
);
}
// Dense output with fan_in_fan_out: the stored weight is the transpose of the input.
#[test]
fn insert_base_linear_dense_fan_in_fan_out_transposes() {
let mut weights: Weights = std::collections::HashMap::new();
let w = crate::Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &(2, 3)).unwrap();
let fused = BaseLinear::dense(w, None).unwrap();
insert_base_linear(&mut weights, "ffn", fused, true).unwrap();
let wt = weights.remove("ffn.weight").expect(".weight written");
assert_eq!(
wt.shape(),
vec![3, 2],
"fan_in_fan_out transposes [out, in] → [in, out]",
);
let mut wt_c = crate::ops::shape::contiguous(&wt, false).unwrap();
assert_eq!(
wt_c.to_vec::<f32>().unwrap(),
vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0],
"transposed weight bytes match the hand-computed Wᵀ",
);
assert!(!weights.contains_key("ffn.bias"), "no bias inserted");
}
// Quantized output writes weight, scales, quantization biases and the layer bias.
#[test]
fn insert_base_linear_quantized_full_triple() {
let mut weights: Weights = std::collections::HashMap::new();
let w = crate::Array::from_slice::<u32>(&[0u32, 1, 2, 3], &(2, 2)).unwrap();
let scales = crate::Array::from_slice::<f32>(&[0.5, 0.25], &(2usize,)).unwrap();
let qb = crate::Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap();
let bias = crate::Array::from_slice::<f32>(&[7.0, 8.0], &(2usize,)).unwrap();
let fused =
BaseLinear::quantized(w, scales, Some(qb), Some(bias), 32, 8, "affine".to_string()).unwrap();
insert_base_linear(&mut weights, "attn", fused, false).unwrap();
assert!(weights.contains_key("attn.weight"), ".weight written");
assert!(weights.contains_key("attn.scales"), ".scales written");
assert!(
weights.contains_key("attn.biases"),
".biases (quant_biases) written",
);
let mut bias = weights.remove("attn.bias").expect(".bias written");
assert_eq!(bias.to_vec::<f32>().unwrap(), vec![7.0, 8.0]);
}
// An embedding writes exactly one key, `.weight`, with the data unchanged.
#[test]
fn insert_base_embedding_writes_weight_only() {
let mut weights: Weights = std::collections::HashMap::new();
let w = crate::Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
let fused = BaseEmbedding::dense(w).unwrap();
insert_base_embedding(&mut weights, "tok_emb", fused);
let mut wt = weights.remove("tok_emb.weight").expect(".weight written");
assert_eq!(wt.shape(), vec![2, 2]);
assert_eq!(wt.to_vec::<f32>().unwrap(), vec![1.0, 2.0, 3.0, 4.0]);
assert_eq!(weights.len(), 0, "embedding writes ONLY the .weight key");
}
// Fusing a DoRA embedding replaces `.weight` with a [num_embeddings, dims] tensor and
// drops the pre-existing `.scales`, `.biases` and `.bias` entries, even with
// fan_in_fan_out set.
#[test]
fn apply_fuse_to_weights_embedding_rewrites_weight_and_drops_siblings() {
let num_embeddings = 3usize;
let dims = 3usize;
let r = 2usize;
let weight = crate::Array::from_slice::<f32>(
&[1.0, 0.5, 0.0, 0.0, 1.0, 0.5, 0.5, 0.0, 1.0],
&(num_embeddings, dims),
)
.unwrap();
let base = BaseEmbedding::dense(weight).unwrap();
let lora_a =
crate::Array::from_slice::<f32>(&[0.1, 0.0, 0.0, 0.1, 0.1, 0.1], &(num_embeddings, r))
.unwrap();
let lora_b =
crate::Array::from_slice::<f32>(&[0.2, 0.0, 0.1, 0.0, 0.1, 0.2], &(r, dims)).unwrap();
let m = crate::Array::from_slice::<f32>(&[1.5, 2.0, 1.2], &(num_embeddings,)).unwrap();
let params = lora::AdapterParams {
lora_a,
lora_b,
magnitude: Some(m),
};
let layer = LoraLayer::DoraEmbedding(lora::DoRAEmbedding::new(base, params, 2.0).unwrap());
let mut weights: Weights = std::collections::HashMap::new();
let path = "model.embed_tokens";
weights.insert(
format!("{path}.weight"),
crate::Array::from_slice::<f32>(&[0.0], &(1usize,)).unwrap(),
);
weights.insert(
format!("{path}.scales"),
crate::Array::from_slice::<f32>(&[0.0], &(1usize,)).unwrap(),
);
weights.insert(
format!("{path}.biases"),
crate::Array::from_slice::<f32>(&[0.0], &(1usize,)).unwrap(),
);
weights.insert(
format!("{path}.bias"),
crate::Array::from_slice::<f32>(&[0.0], &(1usize,)).unwrap(),
);
apply_fuse_to_weights(&mut weights, path, &layer, false, true).unwrap();
assert!(
weights.contains_key(&format!("{path}.weight")),
"fused embedding weight written",
);
let wt = weights.remove(&format!("{path}.weight")).unwrap();
assert_eq!(
wt.shape(),
vec![num_embeddings, dims],
"fused embedding weight is [num_embeddings, dims]",
);
assert!(
!weights.contains_key(&format!("{path}.scales")),
"stale .scales dropped on the dense embedding output",
);
assert!(
!weights.contains_key(&format!("{path}.biases")),
"stale .biases dropped",
);
assert!(
!weights.contains_key(&format!("{path}.bias")),
"stale .bias dropped (embedding has no bias)",
);
}
}