use crate::flux2::Flux2Config;
use crate::loader::{RenameOutput, Sd15Remap, SdxlRemap};
use anyhow::{anyhow, Context, Result};
use candle_core::{safetensors::MmapedSafetensors, DType, Device, Tensor};
use candle_nn::var_builder::SimpleBackend;
use std::collections::BTreeMap;
use std::fs::File;
use std::io::Read;
use std::path::Path;
/// Rejects a safetensors checkpoint whose on-disk size is smaller than its own header implies.
///
/// Reads the 8-byte little-endian header length, then the JSON header, and
/// compares the largest `data_offsets` end value (counted from the end of the
/// header) against the size reported by the filesystem. The `__metadata__`
/// entry is skipped, as is any entry that lacks a two-element `data_offsets`
/// array whose second element is an unsigned integer.
///
/// # Errors
///
/// Fails when the file cannot be stat'ed, opened or read, when the declared
/// header length cannot be represented as `usize`, when the header is not a
/// JSON object, or when the header or the tensor data would extend past the
/// end of the file.
fn check_safetensors_not_truncated(path: &Path) -> Result<()> {
    let file_size = std::fs::metadata(path)
        .with_context(|| format!("stat {} for size check", path.display()))?
        .len();
    let mut f =
        File::open(path).with_context(|| format!("open {} for size check", path.display()))?;

    let mut len_buf = [0u8; 8];
    f.read_exact(&mut len_buf).with_context(|| {
        format!(
            "read safetensors header length at {} (file is only {} bytes — likely truncated)",
            path.display(),
            file_size,
        )
    })?;
    let header_len = u64::from_le_bytes(len_buf);

    // Saturating add: a hostile length prefix must not wrap around and slip
    // under the size comparison below.
    let header_end = 8u64.saturating_add(header_len);
    if header_end > file_size {
        return Err(anyhow!(
            "checkpoint at {} is truncated: file is {} bytes but the safetensors header alone \
             needs {} bytes (8-byte length prefix + {} declared header length). \
             Re-download the model — the file is incomplete.",
            path.display(),
            file_size,
            header_end,
            header_len,
        ));
    }

    // The length comes straight from the file; convert it with a checked
    // conversion so a value wider than `usize` is reported instead of being
    // silently truncated by an `as` cast.
    let header_bytes = usize::try_from(header_len).with_context(|| {
        format!(
            "safetensors header length {} at {} does not fit in usize on this platform",
            header_len,
            path.display(),
        )
    })?;
    let mut header_buf = vec![0u8; header_bytes];
    f.read_exact(&mut header_buf)
        .with_context(|| format!("read safetensors header at {}", path.display()))?;
    let header: serde_json::Value = serde_json::from_slice(&header_buf)
        .with_context(|| format!("parse safetensors header JSON at {}", path.display()))?;
    let obj = header.as_object().ok_or_else(|| {
        anyhow!(
            "safetensors header at {} is not a JSON object",
            path.display(),
        )
    })?;

    // Largest declared end offset across all tensor entries (0 when none).
    let max_end: u64 = obj
        .iter()
        .filter(|(k, _)| k.as_str() != "__metadata__")
        .filter_map(|(_, v)| {
            v.get("data_offsets")
                .and_then(|x| x.as_array())
                .filter(|a| a.len() == 2)
                .and_then(|a| a[1].as_u64())
        })
        .max()
        .unwrap_or(0);

    let expected_total = header_end.saturating_add(max_end);
    if expected_total > file_size {
        let missing = expected_total - file_size;
        return Err(anyhow!(
            "checkpoint at {} is truncated: file is {} bytes but the safetensors header declares \
             tensor data ending at {} bytes ({} bytes missing). \
             The download is incomplete — re-fetch the model.",
            path.display(),
            file_size,
            expected_total,
            missing,
        ));
    }
    Ok(())
}
/// Selects which piece of an NVFP4-quantized weight a backend entry resolves to.
///
/// Each variant is handled by `SingleFileBackend::load_nvfp4_component`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Nvfp4Component {
/// The `<base>.weight` tensor; the loader requires it to be `U8`.
Packed,
/// The `<base>.weight_scale` tensor; the loader requires it to be `F8E4M3`.
BlockScales,
/// The `<base>.weight_scale_2` tensor, converted to `F32` on load.
TensorScale,
/// Not read from the checkpoint: the loader returns a 3-element tensor
/// holding `[axis, component, num_components]`.
SliceMeta {
axis: u32,
component: u32,
num_components: u32,
},
}
/// Rule describing how one requested tensor name is produced from the checkpoint.
///
/// Interpreted by `SingleFileBackend::lookup`.
#[derive(Debug, Clone)]
enum BackendEntry {
/// Load `source_key` unchanged.
Direct { source_key: String },
/// Load `source_key`, split `axis` into `num_components` equal parts and
/// return part number `component`.
Slice {
source_key: String,
axis: usize,
component: usize,
num_components: usize,
},
/// Resolve one NVFP4 piece of the tensor family rooted at `source_base`.
Nvfp4Component {
source_base: String,
component: Nvfp4Component,
},
/// Load `source_key` and exchange the two halves along `axis`.
SwapHalves { source_key: String, axis: usize },
}
/// Tensor backend that serves requested tensor names out of a single
/// memory-mapped safetensors checkpoint via a table of rename rules.
pub struct SingleFileBackend {
/// The memory-mapped checkpoint that tensors are loaded from.
st: MmapedSafetensors,
/// Requested tensor name -> rule for producing it from `st`.
entries: BTreeMap<String, BackendEntry>,
}
impl SingleFileBackend {
/// Shared constructor: validates the checkpoint size, memory-maps it and
/// pairs the mapping with the supplied rule table.
///
/// # Errors
///
/// Fails when the truncation check rejects the file or the mmap cannot be
/// created.
fn from_entries(checkpoint: &Path, entries: BTreeMap<String, BackendEntry>) -> Result<Self> {
    // Run the cheap size validation first so a partial download is reported
    // with a readable message before any mapping is attempted.
    let validation = check_safetensors_not_truncated(checkpoint);
    validation.with_context(|| {
        format!(
            "validate single-file checkpoint at {}",
            checkpoint.display(),
        )
    })?;

    // NOTE(review): `MmapedSafetensors::new` is `unsafe`; presumably the
    // caller must ensure the file is not modified while mapped — confirm
    // against candle's documentation.
    let mapped = unsafe { MmapedSafetensors::new(checkpoint) };
    let st = mapped
        .with_context(|| format!("mmap single-file checkpoint at {}", checkpoint.display()))?;

    Ok(Self { st, entries })
}
/// Turns a `requested name -> checkpoint key` map into a rule table made
/// entirely of [`BackendEntry::Direct`] entries.
fn direct_entries(remap_slice: &BTreeMap<String, String>) -> BTreeMap<String, BackendEntry> {
    remap_slice
        .iter()
        .map(|(requested, checkpoint_key)| {
            let rule = BackendEntry::Direct {
                source_key: checkpoint_key.clone(),
            };
            (requested.clone(), rule)
        })
        .collect()
}
/// Builds the rule table for the CLIP-G remap.
///
/// `RenameOutput::Direct` becomes a [`BackendEntry::Direct`];
/// `RenameOutput::FusedSlice` becomes a [`BackendEntry::Slice`] carrying the
/// same axis, component index and component count.
fn clip_g_entries(
    clip_g_remap: &BTreeMap<String, (String, RenameOutput)>,
) -> BTreeMap<String, BackendEntry> {
    clip_g_remap
        .iter()
        .map(|(requested, (checkpoint_key, output))| {
            let source_key = checkpoint_key.clone();
            let rule = match output {
                RenameOutput::FusedSlice {
                    axis,
                    component,
                    num_components,
                    ..
                } => BackendEntry::Slice {
                    source_key,
                    axis: *axis,
                    component: *component,
                    num_components: *num_components,
                },
                RenameOutput::Direct(_) => BackendEntry::Direct { source_key },
            };
            (requested.clone(), rule)
        })
        .collect()
}
pub fn from_sd15_remap(checkpoint: &Path, remap: &Sd15Remap) -> Result<Self> {
let mut entries: BTreeMap<String, BackendEntry> = BTreeMap::new();
for (diffusers, a1111) in remap
.unet
.iter()
.chain(remap.vae.iter())
.chain(remap.clip_l.iter())
{
entries.insert(
diffusers.clone(),
BackendEntry::Direct {
source_key: a1111.clone(),
},
);
}
Self::from_entries(checkpoint, entries)
}
/// Builds a backend that only knows the SD 1.5 UNet remap.
pub fn from_sd15_unet(checkpoint: &Path, remap: &Sd15Remap) -> Result<Self> {
    let unet_rules = Self::direct_entries(&remap.unet);
    Self::from_entries(checkpoint, unet_rules)
}
/// Builds a backend that only knows the SD 1.5 VAE remap.
pub fn from_sd15_vae(checkpoint: &Path, remap: &Sd15Remap) -> Result<Self> {
    let vae_rules = Self::direct_entries(&remap.vae);
    Self::from_entries(checkpoint, vae_rules)
}
/// Builds a backend that only knows the SD 1.5 CLIP-L remap.
pub fn from_sd15_clip_l(checkpoint: &Path, remap: &Sd15Remap) -> Result<Self> {
    let clip_l_rules = Self::direct_entries(&remap.clip_l);
    Self::from_entries(checkpoint, clip_l_rules)
}
pub fn from_sdxl_remap(checkpoint: &Path, remap: &SdxlRemap) -> Result<Self> {
let mut entries: BTreeMap<String, BackendEntry> = BTreeMap::new();
for (diffusers, a1111) in remap
.unet
.iter()
.chain(remap.vae.iter())
.chain(remap.clip_l.iter())
{
entries.insert(
diffusers.clone(),
BackendEntry::Direct {
source_key: a1111.clone(),
},
);
}
for (diffusers, entry) in Self::clip_g_entries(&remap.clip_g) {
entries.insert(diffusers, entry);
}
Self::from_entries(checkpoint, entries)
}
/// Builds a backend that only knows the SDXL UNet remap.
pub fn from_sdxl_unet(checkpoint: &Path, remap: &SdxlRemap) -> Result<Self> {
    let unet_rules = Self::direct_entries(&remap.unet);
    Self::from_entries(checkpoint, unet_rules)
}
/// Builds a backend that only knows the SDXL VAE remap.
pub fn from_sdxl_vae(checkpoint: &Path, remap: &SdxlRemap) -> Result<Self> {
    let vae_rules = Self::direct_entries(&remap.vae);
    Self::from_entries(checkpoint, vae_rules)
}
/// Builds a backend that only knows the SDXL CLIP-L remap.
pub fn from_sdxl_clip_l(checkpoint: &Path, remap: &SdxlRemap) -> Result<Self> {
    let clip_l_rules = Self::direct_entries(&remap.clip_l);
    Self::from_entries(checkpoint, clip_l_rules)
}
/// Builds a backend that only knows the SDXL CLIP-G remap (direct and
/// fused-slice rules).
pub fn from_sdxl_clip_g(checkpoint: &Path, remap: &SdxlRemap) -> Result<Self> {
    let clip_g_rules = Self::clip_g_entries(&remap.clip_g);
    Self::from_entries(checkpoint, clip_g_rules)
}
pub fn from_flux2_singlefile(checkpoint: &Path, cfg: &Flux2Config) -> Result<Self> {
let format = crate::flux2::detect_format(checkpoint).with_context(|| {
format!("peek single-file Flux.2 header at {}", checkpoint.display(),)
})?;
let (prefix, quant) = match format {
crate::flux2::Flux2SingleFileFormat::Nvfp4 => ("model.diffusion_model.", Quant::Nvfp4),
crate::flux2::Flux2SingleFileFormat::BflNative => {
("model.diffusion_model.", Quant::None)
}
crate::flux2::Flux2SingleFileFormat::BflNativeRoot => ("", Quant::None),
crate::flux2::Flux2SingleFileFormat::Diffusers
| crate::flux2::Flux2SingleFileFormat::Unknown => {
return Err(anyhow!(
"checkpoint does not look like a BFL-native single-file \
Flux.2 (no model.diffusion_model.* or root-level BFL keys \
found — expected Civitai/ComfyUI export)",
));
}
};
let nvfp4_bases: std::collections::BTreeSet<String> = if quant == Quant::Nvfp4 {
collect_nvfp4_bases(checkpoint).with_context(|| {
format!(
"enumerate NVFP4 bases in {} (header peek)",
checkpoint.display(),
)
})?
} else {
std::collections::BTreeSet::new()
};
let rms_suffix = detect_rms_norm_suffix(checkpoint, prefix).with_context(|| {
format!(
"probe RMSNorm tensor suffix in {} (header peek)",
checkpoint.display(),
)
})?;
let entries = build_flux2_entries(cfg, prefix, quant, &nvfp4_bases, rms_suffix);
Self::from_entries(checkpoint, entries)
}
/// Resolves `diffusers_key` through the rule table and loads the tensor.
///
/// # Errors
///
/// Returns an error when no rule exists for the key, when the underlying
/// load fails, when a `Slice` rule's axis length is not a multiple of a
/// non-zero `num_components`, or when a `SwapHalves` rule's axis length is
/// odd.
fn lookup(&self, diffusers_key: &str, dev: &Device) -> candle_core::Result<Tensor> {
    let Some(rule) = self.entries.get(diffusers_key) else {
        return Err(candle_core::Error::Msg(format!(
            "single-file backend: no rename rule for diffusers key '{diffusers_key}'"
        )));
    };

    match rule {
        BackendEntry::Nvfp4Component {
            source_base,
            component,
        } => self.load_nvfp4_component(source_base, *component, dev),

        BackendEntry::Direct { source_key } => self.st.load(source_key, dev),

        BackendEntry::SwapHalves { source_key, axis } => {
            let whole = self.st.load(source_key, dev)?;
            let total = whole.dim(*axis)?;
            if total % 2 != 0 {
                return Err(candle_core::Error::Msg(format!(
                    "single-file backend: SwapHalves source '{source_key}' axis {axis} dim {total} is odd",
                )));
            }
            let half = total / 2;
            let leading = whole.narrow(*axis, 0, half)?;
            let trailing = whole.narrow(*axis, half, half)?;
            // Second half first, then the first half.
            Tensor::cat(&[&trailing, &leading], *axis)
        }

        BackendEntry::Slice {
            source_key,
            axis,
            component,
            num_components,
        } => {
            let fused = self.st.load(source_key, dev)?;
            let total = fused.dim(*axis)?;
            // The zero test comes first so the remainder below cannot divide by zero.
            let splits_evenly = *num_components != 0 && total % num_components == 0;
            if !splits_evenly {
                return Err(candle_core::Error::Msg(format!(
                    "single-file backend: source tensor '{source_key}' axis {axis} dim {total} is not divisible by num_components {num_components}",
                )));
            }
            let width = total / num_components;
            fused.narrow(*axis, component * width, width)
        }
    }
}
fn load_nvfp4_component(
&self,
source_base: &str,
component: Nvfp4Component,
_requested_dev: &Device,
) -> candle_core::Result<Tensor> {
let cpu = Device::Cpu;
match component {
Nvfp4Component::Packed => {
let weight_key = format!("{source_base}.weight");
let t = self.st.load(&weight_key, &cpu)?;
if t.dtype() != DType::U8 {
return Err(candle_core::Error::Msg(format!(
"NVFP4: expected '{weight_key}' to be U8 packed FP4, got {:?}",
t.dtype()
)));
}
Ok(t)
}
Nvfp4Component::BlockScales => {
let scale_key = format!("{source_base}.weight_scale");
let t = self.st.load(&scale_key, &cpu)?;
if t.dtype() != DType::F8E4M3 {
return Err(candle_core::Error::Msg(format!(
"NVFP4: expected '{scale_key}' to be F8E4M3 block scales, got {:?}",
t.dtype()
)));
}
Ok(t)
}
Nvfp4Component::TensorScale => {
let scale2_key = format!("{source_base}.weight_scale_2");
let t = self.st.load(&scale2_key, &cpu)?;
t.to_dtype(DType::F32)
}
Nvfp4Component::SliceMeta {
axis,
component,
num_components,
} => Tensor::from_vec(vec![axis, component, num_components], 3, &cpu),
}
}
}
/// Quantization scheme selected for a Flux.2 single-file checkpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Quant {
/// No NVFP4 routing; tensors are mapped with plain rules.
None,
/// `.weight` tensors whose base is in the NVFP4 base set are routed to NVFP4 sub-keys.
Nvfp4,
}
/// Returns `bfl_full_key` without a single trailing `.weight`; keys that do
/// not end in `.weight` are returned unchanged.
fn weight_base(bfl_full_key: &str) -> &str {
    match bfl_full_key.strip_suffix(".weight") {
        Some(base) => base,
        None => bfl_full_key,
    }
}
/// Lists the tensor base names that carry both NVFP4 scale tensors.
///
/// Only the safetensors JSON header is read. A base `B` is reported when the
/// header contains both `B.weight_scale` and `B.weight_scale_2`; the
/// `__metadata__` entry is skipped.
///
/// # Errors
///
/// Fails when the file cannot be opened, stat'ed or read, when the declared
/// header length exceeds what the file can hold, or when the header is not a
/// JSON object.
fn collect_nvfp4_bases(path: &Path) -> Result<std::collections::BTreeSet<String>> {
    use std::collections::BTreeSet;
    use std::fs::File;
    use std::io::Read;

    let mut file = File::open(path).with_context(|| format!("open {}", path.display()))?;
    let file_size = file
        .metadata()
        .with_context(|| format!("stat {}", path.display()))?
        .len();

    let mut len_buf = [0u8; 8];
    file.read_exact(&mut len_buf)
        .with_context(|| format!("read safetensors header length at {}", path.display()))?;
    let declared_len = u64::from_le_bytes(len_buf);

    // The length prefix is untrusted: check it against the real file size
    // BEFORE allocating, so a corrupt or truncated file yields an error
    // rather than an attempt to allocate an absurd buffer.
    if declared_len > file_size.saturating_sub(8) {
        return Err(anyhow!(
            "safetensors header at {} declares {} bytes but the file is only {} bytes — \
             the checkpoint is truncated or corrupt",
            path.display(),
            declared_len,
            file_size,
        ));
    }
    let header_len = usize::try_from(declared_len).with_context(|| {
        format!(
            "safetensors header length {} at {} does not fit in usize on this platform",
            declared_len,
            path.display(),
        )
    })?;

    let mut header_buf = vec![0u8; header_len];
    file.read_exact(&mut header_buf)
        .with_context(|| format!("read safetensors header at {}", path.display()))?;
    let header: serde_json::Value =
        serde_json::from_slice(&header_buf).with_context(|| "parse safetensors header JSON")?;
    let obj = header
        .as_object()
        .ok_or_else(|| anyhow!("safetensors header is not a JSON object"))?;

    let mut has_scale: BTreeSet<String> = BTreeSet::new();
    let mut has_scale_2: BTreeSet<String> = BTreeSet::new();
    for key in obj.keys() {
        if key == "__metadata__" {
            continue;
        }
        // A key ending in `.weight_scale_2` does not end in `.weight_scale`,
        // so the two suffix tests cannot both match the same key.
        if let Some(base) = key.strip_suffix(".weight_scale") {
            has_scale.insert(base.to_string());
        } else if let Some(base) = key.strip_suffix(".weight_scale_2") {
            has_scale_2.insert(base.to_string());
        }
    }

    // Keep only bases that have both scale tensors.
    Ok(has_scale.intersection(&has_scale_2).cloned().collect())
}
/// Decides which tensor-name suffix the checkpoint uses for its RMSNorm
/// parameters by probing the header for
/// `<prefix>double_blocks.0.img_attn.norm.query_norm.{scale,weight}`.
///
/// Returns `"weight"` only when the `.weight` probe key exists and the
/// `.scale` probe key does not; in every other case (including neither key
/// being present) the result is `"scale"`.
///
/// # Errors
///
/// Fails when the file cannot be opened, stat'ed or read, when the declared
/// header length exceeds what the file can hold, or when the header is not a
/// JSON object.
fn detect_rms_norm_suffix(path: &Path, prefix: &str) -> Result<&'static str> {
    use std::fs::File;
    use std::io::Read;

    let mut file = File::open(path).with_context(|| format!("open {}", path.display()))?;
    let file_size = file
        .metadata()
        .with_context(|| format!("stat {}", path.display()))?
        .len();

    let mut len_buf = [0u8; 8];
    file.read_exact(&mut len_buf)
        .with_context(|| format!("read safetensors header length at {}", path.display()))?;
    let declared_len = u64::from_le_bytes(len_buf);

    // The length prefix is untrusted: check it against the real file size
    // BEFORE allocating, so a corrupt or truncated file yields an error
    // rather than an attempt to allocate an absurd buffer.
    if declared_len > file_size.saturating_sub(8) {
        return Err(anyhow!(
            "safetensors header at {} declares {} bytes but the file is only {} bytes — \
             the checkpoint is truncated or corrupt",
            path.display(),
            declared_len,
            file_size,
        ));
    }
    let header_len = usize::try_from(declared_len).with_context(|| {
        format!(
            "safetensors header length {} at {} does not fit in usize on this platform",
            declared_len,
            path.display(),
        )
    })?;

    let mut header_buf = vec![0u8; header_len];
    file.read_exact(&mut header_buf)
        .with_context(|| format!("read safetensors header at {}", path.display()))?;
    let header: serde_json::Value =
        serde_json::from_slice(&header_buf).with_context(|| "parse safetensors header JSON")?;
    let obj = header
        .as_object()
        .ok_or_else(|| anyhow!("safetensors header is not a JSON object"))?;

    let probe_scale = format!("{prefix}double_blocks.0.img_attn.norm.query_norm.scale");
    let probe_weight = format!("{prefix}double_blocks.0.img_attn.norm.query_norm.weight");

    // `.scale` wins when present and is also the fallback when neither probe
    // key exists.
    if !obj.contains_key(&probe_scale) && obj.contains_key(&probe_weight) {
        Ok("weight")
    } else {
        Ok("scale")
    }
}
/// Expands one `.weight` key into its three NVFP4 sub-key rules, in the
/// order packed data, block scales, tensor scale.
///
/// Each sub-key is `diffusers_key` followed by `.nvfp4_packed`,
/// `.nvfp4_block_scales` or `.nvfp4_tensor_scale`, and points at
/// `source_base`.
fn nvfp4_subkeys(diffusers_key: &str, source_base: &str) -> Vec<(String, BackendEntry)> {
    debug_assert!(
        diffusers_key.ends_with(".weight"),
        "NVFP4 routing only applies to `.weight` keys"
    );

    let pieces = [
        ("nvfp4_packed", Nvfp4Component::Packed),
        ("nvfp4_block_scales", Nvfp4Component::BlockScales),
        ("nvfp4_tensor_scale", Nvfp4Component::TensorScale),
    ];

    let mut rules = Vec::with_capacity(pieces.len());
    for (suffix, component) in pieces {
        rules.push((
            format!("{diffusers_key}.{suffix}"),
            BackendEntry::Nvfp4Component {
                source_base: source_base.to_string(),
                component,
            },
        ));
    }
    rules
}
/// Maps one requested key onto `<prefix><bfl_suffix>`.
///
/// Yields the three NVFP4 sub-key rules when NVFP4 is active, the source
/// suffix ends in `.weight` and its base is listed in `nvfp4_bases`;
/// otherwise yields a single [`BackendEntry::Direct`] rule.
fn direct(
    diffusers: &str,
    bfl_suffix: &str,
    prefix: &str,
    quant: Quant,
    nvfp4_bases: &std::collections::BTreeSet<String>,
) -> Vec<(String, BackendEntry)> {
    let source_key = format!("{prefix}{bfl_suffix}");
    let source_base = weight_base(&source_key).to_string();

    let is_quantized_weight = matches!(quant, Quant::Nvfp4)
        && bfl_suffix.ends_with(".weight")
        && nvfp4_bases.contains(&source_base);

    if is_quantized_weight {
        nvfp4_subkeys(diffusers, &source_base)
    } else {
        let rule = BackendEntry::Direct { source_key };
        vec![(diffusers.to_string(), rule)]
    }
}
/// Maps one requested key onto part `component` of a three-way fused tensor
/// at `<prefix><bfl_suffix>`, split along axis 0.
///
/// Without NVFP4 routing this is a single [`BackendEntry::Slice`] rule. With
/// NVFP4 routing it is the three NVFP4 sub-key rules followed by a
/// `.nvfp4_slice_meta` rule recording axis 0, the component index and the
/// component count of 3.
fn slice_qkv(
    diffusers: &str,
    bfl_suffix: &str,
    component: usize,
    prefix: &str,
    quant: Quant,
    nvfp4_bases: &std::collections::BTreeSet<String>,
) -> Vec<(String, BackendEntry)> {
    let source_key = format!("{prefix}{bfl_suffix}");
    let source_base = weight_base(&source_key).to_string();

    let is_quantized = matches!(quant, Quant::Nvfp4) && nvfp4_bases.contains(&source_base);
    if is_quantized {
        let mut rules = nvfp4_subkeys(diffusers, &source_base);
        let meta = Nvfp4Component::SliceMeta {
            axis: 0,
            component: component as u32,
            num_components: 3,
        };
        rules.push((
            format!("{diffusers}.nvfp4_slice_meta"),
            BackendEntry::Nvfp4Component {
                source_base,
                component: meta,
            },
        ));
        return rules;
    }

    let rule = BackendEntry::Slice {
        source_key,
        axis: 0,
        component,
        num_components: 3,
    };
    vec![(diffusers.to_string(), rule)]
}
/// Generates the full rule table for a Flux.2 single-file checkpoint.
///
/// Covers the top-level projections, the `norm_out.linear.weight` half-swap,
/// the optional `vector_in` and guidance embedder layers, `cfg.depth` double
/// blocks and `cfg.depth_single_blocks` single blocks. `rms_suffix` is the
/// final path segment used for the checkpoint's RMSNorm tensors.
fn build_flux2_entries(
    cfg: &Flux2Config,
    prefix: &str,
    quant: Quant,
    nvfp4_bases: &std::collections::BTreeSet<String>,
    rms_suffix: &str,
) -> BTreeMap<String, BackendEntry> {
    // Thin wrappers so the tables below only spell out the key pairs.
    let plain = |diffusers: &str, bfl: &str| direct(diffusers, bfl, prefix, quant, nvfp4_bases);
    let fused = |diffusers: &str, bfl: &str, component: usize| {
        slice_qkv(diffusers, bfl, component, prefix, quant, nvfp4_bases)
    };

    let mut table: BTreeMap<String, BackendEntry> = BTreeMap::new();

    // ---- top-level projections -------------------------------------------
    let top_level = [
        ("x_embedder.weight", "img_in.weight"),
        ("context_embedder.weight", "txt_in.weight"),
        (
            "time_guidance_embed.timestep_embedder.linear_1.weight",
            "time_in.in_layer.weight",
        ),
        (
            "time_guidance_embed.timestep_embedder.linear_2.weight",
            "time_in.out_layer.weight",
        ),
        ("proj_out.weight", "final_layer.linear.weight"),
        (
            "double_stream_modulation_img.linear.weight",
            "double_stream_modulation_img.lin.weight",
        ),
        (
            "double_stream_modulation_txt.linear.weight",
            "double_stream_modulation_txt.lin.weight",
        ),
        (
            "single_stream_modulation.linear.weight",
            "single_stream_modulation.lin.weight",
        ),
    ];
    for (diffusers, bfl) in top_level {
        table.extend(plain(diffusers, bfl));
    }

    // The final adaLN weight is always mapped with a half-swap along axis 0.
    table.insert(
        "norm_out.linear.weight".to_string(),
        BackendEntry::SwapHalves {
            source_key: format!("{prefix}final_layer.adaLN_modulation.1.weight"),
            axis: 0,
        },
    );

    // ---- optional embedders ----------------------------------------------
    if cfg.vec_in_dim > 0 {
        let vector_in = [
            ("vector_in.linear_1.weight", "vector_in.in_layer.weight"),
            ("vector_in.linear_2.weight", "vector_in.out_layer.weight"),
        ];
        for (diffusers, bfl) in vector_in {
            table.extend(plain(diffusers, bfl));
        }
    }
    if cfg.guidance_embed {
        let guidance = [
            (
                "time_guidance_embed.guidance_embedder.linear_1.weight",
                "guidance_in.in_layer.weight",
            ),
            (
                "time_guidance_embed.guidance_embedder.linear_2.weight",
                "guidance_in.out_layer.weight",
            ),
        ];
        for (diffusers, bfl) in guidance {
            table.extend(plain(diffusers, bfl));
        }
    }

    // ---- double (image + text) blocks ------------------------------------
    for i in 0..cfg.depth {
        let diffusers_block = format!("transformer_blocks.{i}");
        let bfl_block = format!("double_blocks.{i}");

        // Image stream: fused qkv split into to_q / to_k / to_v.
        let img_qkv = format!("{bfl_block}.img_attn.qkv.weight");
        for (component, proj) in [(0usize, "to_q"), (1, "to_k"), (2, "to_v")] {
            let diffusers = format!("{diffusers_block}.attn.{proj}.weight");
            table.extend(fused(&diffusers, &img_qkv, component));
        }
        let img_direct: [(&str, String); 5] = [
            ("attn.to_out.0.weight", "img_attn.proj.weight".to_string()),
            (
                "attn.norm_q.weight",
                format!("img_attn.norm.query_norm.{rms_suffix}"),
            ),
            (
                "attn.norm_k.weight",
                format!("img_attn.norm.key_norm.{rms_suffix}"),
            ),
            ("ff.linear_in.weight", "img_mlp.0.weight".to_string()),
            ("ff.linear_out.weight", "img_mlp.2.weight".to_string()),
        ];
        for (d_suffix, b_suffix) in img_direct {
            table.extend(plain(
                &format!("{diffusers_block}.{d_suffix}"),
                &format!("{bfl_block}.{b_suffix}"),
            ));
        }

        // Text stream: fused qkv split into add_q_proj / add_k_proj / add_v_proj.
        let txt_qkv = format!("{bfl_block}.txt_attn.qkv.weight");
        for (component, proj) in [(0usize, "add_q_proj"), (1, "add_k_proj"), (2, "add_v_proj")] {
            let diffusers = format!("{diffusers_block}.attn.{proj}.weight");
            table.extend(fused(&diffusers, &txt_qkv, component));
        }
        let txt_direct: [(&str, String); 5] = [
            ("attn.to_add_out.weight", "txt_attn.proj.weight".to_string()),
            (
                "attn.norm_added_q.weight",
                format!("txt_attn.norm.query_norm.{rms_suffix}"),
            ),
            (
                "attn.norm_added_k.weight",
                format!("txt_attn.norm.key_norm.{rms_suffix}"),
            ),
            (
                "ff_context.linear_in.weight",
                "txt_mlp.0.weight".to_string(),
            ),
            (
                "ff_context.linear_out.weight",
                "txt_mlp.2.weight".to_string(),
            ),
        ];
        for (d_suffix, b_suffix) in txt_direct {
            table.extend(plain(
                &format!("{diffusers_block}.{d_suffix}"),
                &format!("{bfl_block}.{b_suffix}"),
            ));
        }
    }

    // ---- single blocks ---------------------------------------------------
    for i in 0..cfg.depth_single_blocks {
        let diffusers_block = format!("single_transformer_blocks.{i}");
        let bfl_block = format!("single_blocks.{i}");
        let single_direct: [(&str, String); 4] = [
            ("attn.to_qkv_mlp_proj.weight", "linear1.weight".to_string()),
            ("attn.to_out.weight", "linear2.weight".to_string()),
            (
                "attn.norm_q.weight",
                format!("norm.query_norm.{rms_suffix}"),
            ),
            ("attn.norm_k.weight", format!("norm.key_norm.{rms_suffix}")),
        ];
        for (d_suffix, b_suffix) in single_direct {
            table.extend(plain(
                &format!("{diffusers_block}.{d_suffix}"),
                &format!("{bfl_block}.{b_suffix}"),
            ));
        }
    }

    table
}
impl SimpleBackend for SingleFileBackend {
fn get(
&self,
s: candle_core::Shape,
name: &str,
_h: candle_nn::Init,
dtype: DType,
dev: &Device,
) -> candle_core::Result<Tensor> {
let t = self.lookup(name, dev)?;
if t.shape() != &s {
return Err(candle_core::Error::UnexpectedShape {
msg: format!("single-file backend: shape mismatch for {name}"),
expected: s,
got: t.shape().clone(),
}
.bt());
}
if t.dtype() != dtype {
t.to_dtype(dtype)
} else {
Ok(t)
}
}
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
let t = self.lookup(name, dev)?;
if t.dtype() != dtype {
t.to_dtype(dtype)
} else {
Ok(t)
}
}
fn contains_tensor(&self, name: &str) -> bool {
self.entries.contains_key(name)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::loader::single_file::load as load_bundle;
use crate::loader::{build_sd15_remap, build_sdxl_remap};
use mold_catalog::families::Family;
use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
use std::collections::HashMap;
use std::path::PathBuf;
/// Owned-key convenience wrapper around `write_synthetic`.
fn write_synthetic_with_tensors(
    name: &str,
    tensors: &[(String, Vec<usize>, Vec<f32>)],
) -> PathBuf {
    let mut borrowed: Vec<(&str, Vec<usize>, Vec<f32>)> = Vec::with_capacity(tensors.len());
    for (key, shape, data) in tensors {
        borrowed.push((key.as_str(), shape.clone(), data.clone()));
    }
    write_synthetic(name, &borrowed)
}
/// Writes the given F32 tensors to a uniquely named safetensors file in the
/// system temp directory and returns its path.
///
/// The file name combines `name`, the process id and the current time in
/// nanoseconds since the Unix epoch.
fn write_synthetic(name: &str, tensors: &[(&str, Vec<usize>, Vec<f32>)]) -> PathBuf {
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    let file_name = format!(
        "mold-sf-backend-{}-{}-{}.safetensors",
        name,
        std::process::id(),
        nanos,
    );
    let path = std::env::temp_dir().join(file_name);

    // Little-endian byte image of every tensor; kept alive in `raw` because
    // the views below borrow from it.
    let raw: Vec<Vec<u8>> = tensors
        .iter()
        .map(|(_, _, values)| values.iter().flat_map(|v| v.to_le_bytes()).collect())
        .collect();

    let mut views: HashMap<String, TensorView<'_>> = HashMap::new();
    for ((key, shape, _), bytes) in tensors.iter().zip(raw.iter()) {
        let view = TensorView::new(SafeDtype::F32, shape.clone(), bytes).unwrap();
        views.insert((*key).to_string(), view);
    }

    serialize_to_file(&views, &None, &path).unwrap();
    path
}
/// A plain SD 1.5 rename must return the checkpoint tensor unchanged.
#[test]
fn sd15_backend_resolves_diffusers_key_to_a1111_tensor() {
    let values: Vec<f32> = vec![1.5, 2.5, 3.5, 4.5];
    let path = write_synthetic(
        "sd15-direct",
        &[(
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
            vec![2, 2],
            values.clone(),
        )],
    );

    let bundle = load_bundle(&path, Family::Sd15).expect("partition sd15");
    let remap = build_sd15_remap(&bundle).expect("build remap");
    let backend = SingleFileBackend::from_sd15_remap(&path, &remap).expect("backend");

    let tensor = SimpleBackend::get_unchecked(
        &backend,
        "text_model.encoder.layers.0.self_attn.q_proj.weight",
        DType::F32,
        &Device::Cpu,
    )
    .expect("direct lookup must hit");

    assert_eq!(tensor.dims(), &[2, 2]);
    let flat: Vec<f32> = tensor.flatten_all().unwrap().to_vec1().unwrap();
    assert_eq!(flat, values);

    let _ = std::fs::remove_file(path);
}
/// The fused CLIP-G `in_proj_weight` must be served as three `[d, d]` slices.
#[test]
fn sdxl_backend_slices_clip_g_fused_qkv_weight() {
    let d: usize = 4;
    // Three stacked [d, d] blocks filled with 1.0, 2.0 and 3.0 respectively.
    let data: Vec<f32> = (1..=3i32)
        .flat_map(|component| std::iter::repeat(component as f32).take(d * d))
        .collect();

    let path = write_synthetic(
        "sdxl-fused-qkv-w",
        &[
            (
                "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
                vec![2, 2],
                vec![0.1, 0.2, 0.3, 0.4],
            ),
            (
                "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
                vec![3 * d, d],
                data,
            ),
        ],
    );

    let bundle = load_bundle(&path, Family::Sdxl).expect("partition sdxl");
    let remap = build_sdxl_remap(&bundle).expect("build remap");
    let backend = SingleFileBackend::from_sdxl_remap(&path, &remap).expect("backend");
    let dev = Device::Cpu;

    let cases = [
        (
            0usize,
            "text_model.encoder.layers.0.self_attn.q_proj.weight",
            1.0f32,
        ),
        (1, "text_model.encoder.layers.0.self_attn.k_proj.weight", 2.0),
        (2, "text_model.encoder.layers.0.self_attn.v_proj.weight", 3.0),
    ];
    for (component, diffusers_key, expected_value) in cases {
        let t = SimpleBackend::get_unchecked(&backend, diffusers_key, DType::F32, &dev)
            .unwrap_or_else(|e| panic!("slice lookup for component {component}: {e}"));
        assert_eq!(
            t.dims(),
            &[d, d],
            "{diffusers_key}: slice must be [d, d], not full [3*d, d]",
        );
        let flat: Vec<f32> = t.flatten_all().unwrap().to_vec1().unwrap();
        assert!(
            flat.iter().all(|&v| v == expected_value),
            "{diffusers_key}: every value must be {expected_value} (got {flat:?})",
        );
    }

    let _ = std::fs::remove_file(path);
}
/// The fused CLIP-G `in_proj_bias` must be served as three `[d]` slices.
#[test]
fn sdxl_backend_slices_clip_g_fused_qkv_bias() {
    let d: usize = 5;
    // Three stacked runs of length d filled with 10.0, 20.0 and 30.0.
    let data: Vec<f32> = (1..=3i32)
        .flat_map(|component| std::iter::repeat(component as f32 * 10.0).take(d))
        .collect();

    let path = write_synthetic(
        "sdxl-fused-qkv-b",
        &[
            (
                "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
                vec![2, 2],
                vec![0.1, 0.2, 0.3, 0.4],
            ),
            (
                "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias",
                vec![3 * d],
                data,
            ),
        ],
    );

    let bundle = load_bundle(&path, Family::Sdxl).expect("partition sdxl");
    let remap = build_sdxl_remap(&bundle).expect("build remap");
    let backend = SingleFileBackend::from_sdxl_remap(&path, &remap).expect("backend");
    let dev = Device::Cpu;

    let cases = [
        (
            0usize,
            "text_model.encoder.layers.0.self_attn.q_proj.bias",
            10.0f32,
        ),
        (1, "text_model.encoder.layers.0.self_attn.k_proj.bias", 20.0),
        (2, "text_model.encoder.layers.0.self_attn.v_proj.bias", 30.0),
    ];
    for (component, diffusers_key, expected_value) in cases {
        let t = SimpleBackend::get_unchecked(&backend, diffusers_key, DType::F32, &dev)
            .unwrap_or_else(|e| panic!("bias slice for component {component}: {e}"));
        assert_eq!(
            t.dims(),
            &[d],
            "{diffusers_key}: 1D bias slice must be [d], not [3*d]",
        );
        let flat: Vec<f32> = t.to_vec1().unwrap();
        assert!(
            flat.iter().all(|&v| v == expected_value),
            "{diffusers_key}: every value must be {expected_value} (got {flat:?})",
        );
    }

    let _ = std::fs::remove_file(path);
}
/// Asking for a name with no rule must fail with a readable message.
#[test]
fn backend_unmapped_key_returns_error() {
    let path = write_synthetic(
        "sd15-empty-lookup",
        &[(
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
            vec![1],
            vec![0.0],
        )],
    );

    let bundle = load_bundle(&path, Family::Sd15).expect("partition sd15");
    let remap = build_sd15_remap(&bundle).expect("build remap");
    let backend = SingleFileBackend::from_sd15_remap(&path, &remap).expect("backend");

    let outcome = SimpleBackend::get_unchecked(
        &backend,
        "totally.bogus.key.no.diffusers.path",
        DType::F32,
        &Device::Cpu,
    );
    let err = outcome.expect_err("unmapped key must error");
    let mentions_rule = err.to_string().contains("no rename rule");
    assert!(
        mentions_rule,
        "expected legible error mentioning 'no rename rule', got: {err}",
    );

    let _ = std::fs::remove_file(path);
}
/// CLIP-L and CLIP-G share a requested key; the scoped factories must each
/// return their own tensor while the all-in-one backend lets CLIP-G win.
#[test]
fn sdxl_clip_l_scoped_backend_returns_clip_l_tensor_when_keys_collide_with_clip_g() {
    let d_l: usize = 4;
    let d_g: usize = 6;
    let l_data: Vec<f32> = (0..d_l * d_l).map(|i| 0.5 + i as f32 * 0.1).collect();
    let g_qkv_data: Vec<f32> = (0..3 * d_g * d_g).map(|i| 10.0 + i as f32).collect();

    let path = write_synthetic(
        "sdxl-no-collision-clip-l",
        &[
            (
                "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
                vec![d_l, d_l],
                l_data.clone(),
            ),
            (
                "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
                vec![3 * d_g, d_g],
                g_qkv_data,
            ),
        ],
    );
    let bundle = load_bundle(&path, Family::Sdxl).expect("partition sdxl");
    let remap = build_sdxl_remap(&bundle).expect("build remap");

    // Every backend below is asked for the same colliding key.
    let shared_key = "text_model.encoder.layers.0.self_attn.q_proj.weight";
    let fetch = |backend: &SingleFileBackend, what: &str| {
        SimpleBackend::get_unchecked(backend, shared_key, DType::F32, &Device::Cpu).expect(what)
    };

    let all_in_one = SingleFileBackend::from_sdxl_remap(&path, &remap).expect("backend");
    let collided = fetch(&all_in_one, "lookup");
    assert_eq!(
        collided.dims(),
        &[d_g, d_g],
        "all-in-one from_sdxl_remap must still exhibit the collision (CLIP-G wins) — \
         this is the bug that motivates the scoped factories",
    );

    let backend_l =
        SingleFileBackend::from_sdxl_clip_l(&path, &remap).expect("clip-l scoped backend");
    let t_l = fetch(&backend_l, "clip-l scoped lookup");
    assert_eq!(
        t_l.dims(),
        &[d_l, d_l],
        "CLIP-L scoped backend must return CLIP-L's [d_l, d_l] tensor, \
         not CLIP-G's [d_g, d_g] slice — collision elimination is the \
         whole point of the scoped factory",
    );
    let flat: Vec<f32> = t_l.flatten_all().unwrap().to_vec1().unwrap();
    assert_eq!(flat, l_data, "values must match the CLIP-L source tensor");

    let backend_g =
        SingleFileBackend::from_sdxl_clip_g(&path, &remap).expect("clip-g scoped backend");
    let t_g = fetch(&backend_g, "clip-g scoped lookup");
    assert_eq!(
        t_g.dims(),
        &[d_g, d_g],
        "CLIP-G scoped backend keeps the slice semantics — q_proj is the \
         0th component of the [3*d_g, d_g] in_proj_weight slab",
    );

    let _ = std::fs::remove_file(path);
}
/// A key that exists only for CLIP-G must be visible through the CLIP-G
/// scoped backend and invisible through the CLIP-L scoped one.
#[test]
fn sdxl_clip_l_scoped_backend_excludes_clip_g_only_keys() {
    let path = write_synthetic(
        "sdxl-clip-l-isolation",
        &[
            (
                "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
                vec![2, 2],
                vec![0.1, 0.2, 0.3, 0.4],
            ),
            (
                "conditioner.embedders.1.model.text_projection",
                vec![1],
                vec![99.0],
            ),
        ],
    );
    let bundle = load_bundle(&path, Family::Sdxl).unwrap();
    let remap = build_sdxl_remap(&bundle).unwrap();
    let clip_g_only = "text_projection.weight";

    let backend_l = SingleFileBackend::from_sdxl_clip_l(&path, &remap).unwrap();
    let l_has_it = SimpleBackend::contains_tensor(&backend_l, clip_g_only);
    assert!(
        !l_has_it,
        "CLIP-L scoped backend must not advertise CLIP-G-only keys",
    );

    let backend_g = SingleFileBackend::from_sdxl_clip_g(&path, &remap).unwrap();
    let g_has_it = SimpleBackend::contains_tensor(&backend_g, clip_g_only);
    assert!(
        g_has_it,
        "CLIP-G scoped backend must include CLIP-G's text_projection.weight",
    );

    let _ = std::fs::remove_file(path);
}
/// `get` must reject a rank-2 request for a tensor stored as `[c, c, 1, 1]`
/// and accept the matching rank-4 request.
#[test]
fn backend_get_validates_shape_so_candle_attnblock_falls_through_to_conv_path() {
    let c = 4usize;
    let on_disk: Vec<f32> = (0..c * c).map(|i| 1.0 + i as f32 * 0.1).collect();
    let path = write_synthetic(
        "sd15-vae-attn-conv-shape",
        &[(
            "first_stage_model.encoder.mid.attn_1.q.weight",
            vec![c, c, 1, 1],
            on_disk.clone(),
        )],
    );
    let bundle = load_bundle(&path, Family::Sd15).expect("partition");
    let remap = build_sd15_remap(&bundle).expect("remap");
    let backend = SingleFileBackend::from_sd15_vae(&path, &remap).expect("backend");
    let dev = Device::Cpu;
    let diffusers_key = "encoder.mid_block.attentions.0.to_q.weight";

    // Same request each time; only the expected shape varies.
    let probe = |shape: candle_core::Shape| {
        SimpleBackend::get(
            &backend,
            shape,
            diffusers_key,
            candle_nn::Init::Const(0.0),
            DType::F32,
            &dev,
        )
    };

    let err = probe(candle_core::Shape::from((c, c))).expect_err(
        "rank-2 probe must error so candle's get_qkv_linear falls through to the conv path",
    );
    let msg = err.to_string();
    assert!(
        msg.contains("shape mismatch") || msg.contains("UnexpectedShape"),
        "expected shape-mismatch error so candle's Err-arm fires; got: {msg}",
    );

    let t = probe(candle_core::Shape::from((c, c, 1, 1)))
        .expect("rank-4 probe must succeed for A1111 Conv2d 1×1 weight");
    assert_eq!(t.dims(), &[c, c, 1, 1]);
    let flat: Vec<f32> = t.flatten_all().unwrap().to_vec1().unwrap();
    assert_eq!(flat, on_disk);

    let _ = std::fs::remove_file(path);
}
/// An F32 tensor requested as F16 must come back as F16.
#[test]
fn backend_dtype_promotes_when_caller_requests_other_dtype() {
    let path = write_synthetic(
        "sd15-dtype",
        &[(
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
            vec![1],
            vec![1.0],
        )],
    );
    let bundle = load_bundle(&path, Family::Sd15).unwrap();
    let remap = build_sd15_remap(&bundle).unwrap();
    let backend = SingleFileBackend::from_sd15_remap(&path, &remap).unwrap();

    let requested = DType::F16;
    let converted = SimpleBackend::get_unchecked(
        &backend,
        "text_model.encoder.layers.0.self_attn.q_proj.weight",
        requested,
        &Device::Cpu,
    )
    .expect("F16 lookup");
    assert_eq!(converted.dtype(), DType::F16);

    let _ = std::fs::remove_file(path);
}
/// Flux.2 configuration used by the fixtures: one double block, one single
/// block, no `vector_in` and no guidance embedder.
fn flux2_test_config() -> Flux2Config {
    let cfg = Flux2Config {
        // Block counts and optional layers — these drive the rule table.
        depth: 1,
        depth_single_blocks: 1,
        vec_in_dim: 0,
        guidance_embed: false,
        // Remaining model dimensions.
        in_channels: 128,
        context_in_dim: 7680,
        hidden_size: 3072,
        mlp_ratio: 3.0,
        num_heads: 24,
        axes_dim: vec![32; 4],
        theta: 2000,
    };
    cfg
}
/// Writes the BFL-layout fixture using the `scale` RMSNorm suffix.
fn write_flux2_bfl_fixture(cfg: &Flux2Config, override_qkv: Option<Vec<f32>>) -> PathBuf {
    let rms_suffix = "scale";
    write_flux2_bfl_fixture_with_rms(cfg, override_qkv, rms_suffix)
}
/// Writes a synthetic Flux.2 checkpoint under the `model.diffusion_model`
/// prefix and returns the path produced by `write_synthetic`.
///
/// * `cfg` — `depth` and `depth_single_blocks` decide how many double and
///   single blocks are emitted.
/// * `override_qkv` — when `Some`, its contents are used verbatim for both the
///   `img_attn` and `txt_attn` fused QKV weights of every double block, with
///   shape `[3 * d, d]` where `d = isqrt(len / 3)`. When `None`, a `[3, 1]`
///   ramp tensor is written instead.
/// * `rms_suffix` — final key segment of the query/key norm tensors.
///
/// Every tensor other than an overridden QKV holds the ramp `0, 1, 2, …`.
fn write_flux2_bfl_fixture_with_rms(
    cfg: &Flux2Config,
    override_qkv: Option<Vec<f32>>,
    rms_suffix: &str,
) -> PathBuf {
    type Entry = (String, Vec<usize>, Vec<f32>);

    /// Builds an entry whose data is `0.0, 1.0, …` with one value per element.
    fn ramp(key: String, shape: Vec<usize>) -> Entry {
        let count: usize = shape.iter().product();
        let values: Vec<f32> = (0..count).map(|i| i as f32).collect();
        (key, shape, values)
    }

    let prefix = "model.diffusion_model";
    let mut entries: Vec<Entry> = Vec::new();

    // Top-level (non-block) tensors.
    let top_level: [(&str, [usize; 2]); 9] = [
        ("img_in.weight", [1, 1]),
        ("txt_in.weight", [1, 1]),
        ("time_in.in_layer.weight", [1, 1]),
        ("time_in.out_layer.weight", [1, 1]),
        ("final_layer.linear.weight", [1, 1]),
        ("final_layer.adaLN_modulation.1.weight", [2, 1]),
        ("double_stream_modulation_img.lin.weight", [1, 1]),
        ("double_stream_modulation_txt.lin.weight", [1, 1]),
        ("single_stream_modulation.lin.weight", [1, 1]),
    ];
    for (suffix, shape) in top_level {
        entries.push(ramp(format!("{prefix}.{suffix}"), shape.to_vec()));
    }

    // Double blocks: fused QKV for both streams first, then the per-stream
    // projection / norm / MLP tensors (img before txt).
    for i in 0..cfg.depth {
        let block = format!("{prefix}.double_blocks.{i}");
        for stream in ["img", "txt"] {
            let key = format!("{block}.{stream}_attn.qkv.weight");
            match override_qkv.as_deref() {
                Some(data) => {
                    let d = (data.len() / 3).isqrt();
                    entries.push((key, vec![3 * d, d], data.to_vec()));
                }
                None => entries.push(ramp(key, vec![3, 1])),
            }
        }
        for stream in ["img", "txt"] {
            let tails = [
                "attn.proj.weight".to_string(),
                format!("attn.norm.query_norm.{rms_suffix}"),
                format!("attn.norm.key_norm.{rms_suffix}"),
                "mlp.0.weight".to_string(),
                "mlp.2.weight".to_string(),
            ];
            for tail in tails {
                entries.push(ramp(format!("{block}.{stream}_{tail}"), vec![1, 1]));
            }
        }
    }

    // Single blocks.
    for i in 0..cfg.depth_single_blocks {
        let block = format!("{prefix}.single_blocks.{i}");
        let tails = [
            "linear1.weight".to_string(),
            "linear2.weight".to_string(),
            format!("norm.query_norm.{rms_suffix}"),
            format!("norm.key_norm.{rms_suffix}"),
        ];
        for tail in tails {
            entries.push(ramp(format!("{block}.{tail}"), vec![1, 1]));
        }
    }

    // `write_synthetic` takes borrowed keys, so re-project the owned entries.
    let borrowed: Vec<(&str, Vec<usize>, Vec<f32>)> = entries
        .iter()
        .map(|(key, shape, data)| (key.as_str(), shape.clone(), data.clone()))
        .collect();
    write_synthetic("flux2-bfl", &borrowed)
}
/// Loads the default fixture and checks that each listed diffusers-style
/// top-level key resolves to a `[1, 1]` tensor, and that
/// `norm_out.linear.weight` resolves with its `[2, 1]` shape intact.
#[test]
fn flux2_singlefile_backend_remaps_top_level_keys() {
    const RENAMED_KEYS: [&str; 8] = [
        "x_embedder.weight",
        "context_embedder.weight",
        "time_guidance_embed.timestep_embedder.linear_1.weight",
        "time_guidance_embed.timestep_embedder.linear_2.weight",
        "proj_out.weight",
        "double_stream_modulation_img.linear.weight",
        "double_stream_modulation_txt.linear.weight",
        "single_stream_modulation.linear.weight",
    ];

    let cfg = flux2_test_config();
    let path = write_flux2_bfl_fixture(&cfg, None);
    let backend =
        SingleFileBackend::from_flux2_singlefile(&path, &cfg).expect("flux2 backend must load");
    let cpu = Device::Cpu;
    let fetch = |key: &str| SimpleBackend::get_unchecked(&backend, key, DType::F32, &cpu);

    for key in RENAMED_KEYS {
        let tensor = fetch(key).unwrap_or_else(|e| panic!("{key}: {e}"));
        assert_eq!(tensor.dims(), &[1, 1], "{key}: shape");
    }

    let norm_out =
        fetch("norm_out.linear.weight").expect("norm_out.linear.weight must be accessible");
    assert_eq!(
        norm_out.dims(),
        &[2, 1],
        "norm_out.linear.weight: SwapHalves preserves shape"
    );

    // Best-effort cleanup of the temp fixture.
    let _ = std::fs::remove_file(path);
}
/// Checks that the fused double-block QKV weight is split into `to_q`,
/// `to_k` and `to_v`.
///
/// The fused tensor is three stacked `d x d` blocks filled with 1, 2 and 3;
/// each sliced lookup must come back as `[d, d]` holding only its own fill.
#[test]
fn flux2_singlefile_backend_slices_double_block_qkv() {
    let d = 4usize;
    let data: Vec<f32> = [1.0f32, 2.0, 3.0]
        .into_iter()
        .flat_map(|fill| std::iter::repeat(fill).take(d * d))
        .collect();

    let cfg = flux2_test_config();
    let path = write_flux2_bfl_fixture(&cfg, Some(data));
    let backend = SingleFileBackend::from_flux2_singlefile(&path, &cfg).expect("backend");
    let dev = Device::Cpu;

    for (component, name) in ["to_q", "to_k", "to_v"].into_iter().enumerate() {
        // Component 0 was filled with 1.0, component 1 with 2.0, and so on.
        let sentinel = (component + 1) as f32;
        let key = format!("transformer_blocks.0.attn.{name}.weight");
        let tensor = match SimpleBackend::get_unchecked(&backend, &key, DType::F32, &dev) {
            Ok(tensor) => tensor,
            Err(e) => panic!("{key}: {e}"),
        };
        assert_eq!(tensor.dims(), &[d, d], "{key}: slice shape");
        let flat = tensor.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        let uniform = flat.iter().all(|&v| v == sentinel);
        assert!(
            uniform,
            "{key} (component {component}): values must all be {sentinel}, got {flat:?}",
        );
    }

    let _ = std::fs::remove_file(path);
}
/// Writes the fixture with RMSNorm tensors stored under a `.weight` suffix
/// (instead of the default `.scale`) and checks that the backend still loads
/// and that every listed norm key can be fetched.
#[test]
fn flux2_singlefile_backend_loads_rms_norm_weight_suffix() {
    const NORM_KEYS: [&str; 6] = [
        "transformer_blocks.0.attn.norm_q.weight",
        "transformer_blocks.0.attn.norm_k.weight",
        "transformer_blocks.0.attn.norm_added_q.weight",
        "transformer_blocks.0.attn.norm_added_k.weight",
        "single_transformer_blocks.0.attn.norm_q.weight",
        "single_transformer_blocks.0.attn.norm_k.weight",
    ];

    let cfg = flux2_test_config();
    let path = write_flux2_bfl_fixture_with_rms(&cfg, None, "weight");
    let backend = SingleFileBackend::from_flux2_singlefile(&path, &cfg)
        .expect("RMSNorm `.weight`-suffix checkpoint must load");
    let dev = Device::Cpu;

    for key in NORM_KEYS {
        // Only accessibility is asserted; the tensor itself is discarded.
        if let Err(e) = SimpleBackend::get_unchecked(&backend, key, DType::F32, &dev) {
            panic!("{key}: {e}");
        }
    }

    let _ = std::fs::remove_file(path);
}
/// Checks that the fused text-stream QKV weight of a double block is split
/// into `add_q_proj`, `add_k_proj` and `add_v_proj`.
///
/// The fused tensor is three stacked `d x d` blocks filled with 10, 20 and
/// 30; each sliced lookup must come back as `[d, d]` holding only its own
/// fill. On mismatch the observed values are printed, matching the sibling
/// `to_q`/`to_k`/`to_v` test.
#[test]
fn flux2_singlefile_backend_slices_double_block_added_qkv() {
    let d = 4usize;
    let mut data: Vec<f32> = Vec::with_capacity(3 * d * d);
    for component in 1..=3 {
        // One d x d block per component, filled with component * 10.
        data.extend(std::iter::repeat((component as f32) * 10.0).take(d * d));
    }
    let cfg = flux2_test_config();
    let path = write_flux2_bfl_fixture(&cfg, Some(data));
    let backend = SingleFileBackend::from_flux2_singlefile(&path, &cfg).expect("backend");
    let dev = Device::Cpu;
    for (component, sentinel, name) in [
        (0usize, 10.0f32, "add_q_proj"),
        (1, 20.0, "add_k_proj"),
        (2, 30.0, "add_v_proj"),
    ] {
        let key = format!("transformer_blocks.0.attn.{name}.weight");
        let t = SimpleBackend::get_unchecked(&backend, &key, DType::F32, &dev)
            .unwrap_or_else(|e| panic!("{key}: {e}"));
        assert_eq!(t.dims(), &[d, d], "{key}: slice shape");
        let flat: Vec<f32> = t.flatten_all().unwrap().to_vec1().unwrap();
        assert!(
            flat.iter().all(|&v| v == sentinel),
            "{key} (component {component}): values must all be {sentinel}, got {flat:?}",
        );
    }
    // Best-effort cleanup of the temp fixture.
    let _ = std::fs::remove_file(path);
}
/// Serializes raw, explicitly typed tensors to a uniquely named safetensors
/// file in the system temp directory and returns its path.
///
/// Each entry is `(key, dtype, shape, raw bytes)`. The file name combines
/// `name`, the process id and the current time in nanoseconds.
///
/// Panics if the clock is before the Unix epoch, if a tensor view cannot be
/// constructed from an entry, or if serialization fails.
fn write_typed_synthetic(
    name: &str,
    tensors: Vec<(String, SafeDtype, Vec<usize>, Vec<u8>)>,
) -> PathBuf {
    let pid = std::process::id();
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    let file_name = format!("mold-sf-backend-typed-{name}-{pid}-{nanos}.safetensors");
    let path = std::env::temp_dir().join(file_name);

    // The views borrow the byte buffers owned by `tensors`.
    let views: HashMap<String, TensorView<'_>> = tensors
        .iter()
        .map(|(key, dtype, shape, bytes)| {
            let view = TensorView::new(*dtype, shape.clone(), bytes).unwrap();
            (key.clone(), view)
        })
        .collect();

    serialize_to_file(&views, &None, &path).unwrap();
    path
}
/// Returns the three raw tensors describing a single NVFP4-style layer at
/// `model.diffusion_model.double_blocks.0.img_attn.proj`:
///
/// * `.weight` — `U8`, shape `[1, 8]`, every byte `0x22`;
/// * `.weight_scale` — `F8_E4M3`, shape `[1, 1]`, single byte `0x38`;
/// * `.weight_scale_2` — `F32` scalar (empty shape) holding `weight_scale_2`
///   as little-endian bytes.
fn one_layer_nvfp4_bytes(weight_scale_2: f32) -> Vec<(String, SafeDtype, Vec<usize>, Vec<u8>)> {
    let base = "model.diffusion_model.double_blocks.0.img_attn.proj";
    let key = |suffix: &str| format!("{base}.{suffix}");

    let packed = (key("weight"), SafeDtype::U8, vec![1, 8], vec![0x22u8; 8]);
    let block_scale = (
        key("weight_scale"),
        SafeDtype::F8_E4M3,
        vec![1, 1],
        vec![0x38u8],
    );
    let tensor_scale = (
        key("weight_scale_2"),
        SafeDtype::F32,
        Vec::new(),
        weight_scale_2.to_le_bytes().to_vec(),
    );

    vec![packed, block_scale, tensor_scale]
}
/// Smoke test: a checkpoint containing one NVFP4-style layer must load.
#[test]
fn flux2_singlefile_backend_accepts_nvfp4_format() {
    let cfg = flux2_test_config();
    let path = write_typed_synthetic("flux2-nvfp4-accept", one_layer_nvfp4_bytes(0.5));
    let loaded = SingleFileBackend::from_flux2_singlefile(&path, &cfg);
    // Keep the backend alive until after the assertion, as before.
    let _backend = loaded.expect("NVFP4 checkpoint must now load (no longer rejected)");
    let _ = std::fs::remove_file(path);
}
/// Checks that a single NVFP4-style layer is exposed through three sub-keys
/// (`nvfp4_packed`, `nvfp4_block_scales`, `nvfp4_tensor_scale`) under
/// `transformer_blocks.0.attn.to_out.0.weight`.
#[test]
fn flux2_nvfp4_emits_three_subkeys_per_layer() {
    let cfg = flux2_test_config();
    let path = write_typed_synthetic("flux2-nvfp4-subkeys", one_layer_nvfp4_bytes(0.5));
    let backend = SingleFileBackend::from_flux2_singlefile(&path, &cfg).expect("nvfp4 backend");

    let expected_keys = ["nvfp4_packed", "nvfp4_block_scales", "nvfp4_tensor_scale"]
        .map(|sub| format!("transformer_blocks.0.attn.to_out.0.weight.{sub}"));
    for key in &expected_keys {
        let present = SimpleBackend::contains_tensor(&backend, key);
        assert!(present, "{key}: NVFP4 sub-key must be present");
    }

    let _ = std::fs::remove_file(path);
}
/// Checks that an NVFP4-style layer is NOT also exposed under the plain
/// `.weight` or `.scale_weight` keys.
#[test]
fn flux2_nvfp4_does_not_emit_bare_weight_for_nvfp4_layers() {
    const FORBIDDEN_KEYS: [&str; 2] = [
        "transformer_blocks.0.attn.to_out.0.weight",
        "transformer_blocks.0.attn.to_out.0.scale_weight",
    ];

    let cfg = flux2_test_config();
    let path = write_typed_synthetic("flux2-nvfp4-no-bare", one_layer_nvfp4_bytes(0.5));
    let backend = SingleFileBackend::from_flux2_singlefile(&path, &cfg).expect("nvfp4 backend");

    for key in FORBIDDEN_KEYS {
        let present = SimpleBackend::contains_tensor(&backend, key);
        assert!(
            !present,
            "{key}: bare weight / scale_weight must NOT exist for NVFP4 layers (sub-keys only)",
        );
    }

    let _ = std::fs::remove_file(path);
}
/// Fetches each NVFP4 sub-key of the single fixture layer and checks the
/// returned dtype: `U8` (with shape `[1, 8]`) for the packed weights,
/// `F8E4M3` for the block scales, and an `F32` scalar equal to 0.5 for the
/// tensor scale.
#[test]
fn flux2_nvfp4_subkey_lookup_returns_cpu_tensors_with_correct_dtypes() {
    let cfg = flux2_test_config();
    let path = write_typed_synthetic("flux2-nvfp4-subkey-dtypes", one_layer_nvfp4_bytes(0.5));
    let backend = SingleFileBackend::from_flux2_singlefile(&path, &cfg).expect("backend");
    let dev = Device::Cpu;
    let base = "transformer_blocks.0.attn.to_out.0.weight";
    let lookup = |sub: &str, dtype: DType| {
        SimpleBackend::get_unchecked(&backend, &format!("{base}.{sub}"), dtype, &dev)
    };

    let packed = lookup("nvfp4_packed", DType::U8).expect("packed lookup");
    assert_eq!(packed.dtype(), DType::U8);
    assert_eq!(packed.dims(), &[1, 8]);

    let scales = lookup("nvfp4_block_scales", DType::F8E4M3).expect("scales lookup");
    assert_eq!(scales.dtype(), DType::F8E4M3);

    let tscale = lookup("nvfp4_tensor_scale", DType::F32).expect("tensor_scale lookup");
    assert_eq!(tscale.dtype(), DType::F32);
    let v = tscale.to_scalar::<f32>().unwrap();
    assert!((v - 0.5).abs() < 1e-6, "tensor_scale must be 0.5, got {v}",);

    let _ = std::fs::remove_file(path);
}
/// Checks the metadata emitted when a fused NVFP4-style QKV weight is sliced.
///
/// The fixture holds one fused `img_attn.qkv` layer with three rows of eight
/// bytes (`0x22`, `0x44`, `0x66`). For each of `to_q`, `to_k`, `to_v` the
/// backend must expose a `U32` `nvfp4_slice_meta` tensor of shape `[3]` equal
/// to `[0, component, 3]`, plus an `nvfp4_packed` sub-key.
#[test]
fn flux2_nvfp4_slice_qkv_emits_meta() {
    const COMPONENTS: [&str; 3] = ["to_q", "to_k", "to_v"];

    // One distinct fill byte per fused component, eight bytes each.
    let weight_bytes: Vec<u8> = [0x22u8, 0x44, 0x66]
        .into_iter()
        .flat_map(|fill| std::iter::repeat(fill).take(8))
        .collect();
    let scale_bytes = vec![0x38u8; 3];
    let scale_2_bytes = 0.5f32.to_le_bytes().to_vec();

    let qkv_base = "model.diffusion_model.double_blocks.0.img_attn.qkv";
    let tensors = vec![
        (
            format!("{qkv_base}.weight"),
            SafeDtype::U8,
            vec![3, 8],
            weight_bytes,
        ),
        (
            format!("{qkv_base}.weight_scale"),
            SafeDtype::F8_E4M3,
            vec![3, 1],
            scale_bytes,
        ),
        (
            format!("{qkv_base}.weight_scale_2"),
            SafeDtype::F32,
            vec![],
            scale_2_bytes,
        ),
    ];

    let path = write_typed_synthetic("flux2-nvfp4-slice-meta", tensors);
    let cfg = flux2_test_config();
    let backend = SingleFileBackend::from_flux2_singlefile(&path, &cfg).expect("backend");

    // First pass: slice metadata for every component.
    for (comp_name, component) in COMPONENTS.into_iter().zip(0u32..) {
        let meta_key = format!("transformer_blocks.0.attn.{comp_name}.weight.nvfp4_slice_meta");
        let present = SimpleBackend::contains_tensor(&backend, &meta_key);
        assert!(present, "{meta_key}: slice meta sub-key must be present");

        let meta = SimpleBackend::get_unchecked(&backend, &meta_key, DType::U32, &Device::Cpu)
            .expect("slice meta lookup");
        assert_eq!(meta.dtype(), DType::U32);
        assert_eq!(meta.dims(), &[3]);
        let v = meta.to_vec1::<u32>().unwrap();
        assert_eq!(
            v,
            vec![0u32, component, 3],
            "{meta_key}: meta must encode [axis=0, component={component}, num_components=3]",
        );
    }

    // Second pass: the packed sub-key must exist for every component too.
    for comp_name in COMPONENTS {
        let packed_key = format!("transformer_blocks.0.attn.{comp_name}.weight.nvfp4_packed");
        let present = SimpleBackend::contains_tensor(&backend, &packed_key);
        assert!(
            present,
            "{packed_key}: packed sub-key must be present for sliced QKV component",
        );
    }

    let _ = std::fs::remove_file(path);
}
/// Opt-in probe against a real NVFP4 checkpoint.
///
/// Ignored by default. When run explicitly it reads the checkpoint path from
/// `MOLD_NVFP4_PROBE_PATH` (returning early with a note if unset), builds the
/// single-file backend with the `klein_9b` config, constructs a
/// `Flux2Transformer` on CPU with a BF16 `VarBuilder`, and prints timings to
/// stderr.
#[test]
#[ignore = "requires MOLD_NVFP4_PROBE_PATH env var pointing at a real NVFP4 .safetensors"]
fn flux2_nvfp4_real_file_loads_full_klein_9b() {
    use crate::flux2::transformer::Flux2Transformer;
    use std::time::Instant;

    let Ok(raw_path) = std::env::var("MOLD_NVFP4_PROBE_PATH") else {
        eprintln!("skipping: MOLD_NVFP4_PROBE_PATH not set");
        return;
    };
    let path = std::path::PathBuf::from(raw_path);
    assert!(
        path.is_file(),
        "MOLD_NVFP4_PROBE_PATH must point at a real file (got {})",
        path.display(),
    );

    let size_bytes = std::fs::metadata(&path).unwrap().len();
    let size_gb = size_bytes as f64 / 1e9;
    eprintln!("probing NVFP4 file: {} ({:.2} GB)", path.display(), size_gb,);

    let cfg = crate::flux2::Flux2Config::klein_9b();

    // Stage 1: backend construction.
    let started = Instant::now();
    let backend =
        SingleFileBackend::from_flux2_singlefile(&path, &cfg).expect("backend construction");
    eprintln!("  backend constructed in {:?}", started.elapsed());

    // Stage 2: full transformer construction through a VarBuilder.
    let dev = Device::Cpu;
    let vb = candle_nn::VarBuilder::from_backend(Box::new(backend), DType::BF16, dev.clone());
    let transformer_started = Instant::now();
    let _transformer = Flux2Transformer::new(&cfg, vb)
        .expect("Flux2Transformer::new must succeed end-to-end on the real NVFP4 checkpoint");
    eprintln!("  transformer loaded in {:?}", transformer_started.elapsed());
    eprintln!(
        "  total time: {:?} (every NVFP4 layer set up streaming dequant; BF16 cache populated lazily on first forward)",
        started.elapsed(),
    );
}
/// A checkpoint whose only key is the un-prefixed `x_embedder.weight` must be
/// rejected, and the error must mention `model.diffusion_model`.
#[test]
fn flux2_singlefile_backend_rejects_non_bfl_native_under_nvfp4_routing_too() {
    let cfg = flux2_test_config();
    let path = write_synthetic(
        "flux2-already-diffusers-no-prefix",
        &[("x_embedder.weight", vec![1, 1], vec![1.0])],
    );

    let Err(err) = SingleFileBackend::from_flux2_singlefile(&path, &cfg) else {
        panic!("non-BFL-native must be rejected");
    };
    let mentions_prefix = err.to_string().contains("model.diffusion_model");
    assert!(
        mentions_prefix,
        "error must mention model.diffusion_model, got: {err}",
    );

    let _ = std::fs::remove_file(path);
}
/// A freshly written, complete safetensors file must pass the truncation
/// check.
#[test]
fn check_safetensors_not_truncated_passes_for_intact_file() {
    let key = "model.diffusion_model.img_in.weight";
    let path = write_synthetic("intact", &[(key, vec![2, 2], vec![1.0, 2.0, 3.0, 4.0])]);

    let outcome = check_safetensors_not_truncated(&path);
    outcome.expect("intact file must validate");

    let _ = std::fs::remove_file(path);
}
/// Writes a valid file, chops 16 bytes off its end, and checks that the
/// truncation check rejects it with a message that says "truncated" and
/// hints at re-downloading.
#[test]
fn check_safetensors_not_truncated_flags_short_file() {
    let key = "model.diffusion_model.img_in.weight";
    let path = write_synthetic(
        "truncated",
        &[(key, vec![4, 4], (0..16).map(|i| i as f32).collect())],
    );

    // Shrink the file in place; the handle is closed at the end of the scope.
    let full_size = std::fs::metadata(&path).unwrap().len();
    {
        let file = std::fs::OpenOptions::new()
            .write(true)
            .open(&path)
            .unwrap();
        file.set_len(full_size - 16).unwrap();
    }

    let err =
        check_safetensors_not_truncated(&path).expect_err("truncated file must be rejected");
    let msg = err.to_string();
    assert!(
        msg.contains("truncated"),
        "error must say 'truncated', got: {msg}",
    );
    let has_hint = ["missing", "Re-download", "re-fetch"]
        .into_iter()
        .any(|hint| msg.contains(hint));
    assert!(has_hint, "error must hint at re-downloading, got: {msg}");

    let _ = std::fs::remove_file(path);
}
/// Truncates the default fixture by 4 bytes and checks that backend
/// construction fails with an error chain containing both the
/// "validate single-file checkpoint" wrapper and the "truncated" root cause.
#[test]
fn from_flux2_singlefile_surfaces_truncation_clearly() {
    let cfg = flux2_test_config();
    let path = write_flux2_bfl_fixture(&cfg, None);

    // Shrink the file in place; the handle is closed at the end of the scope.
    let full_size = std::fs::metadata(&path).unwrap().len();
    {
        let file = std::fs::OpenOptions::new()
            .write(true)
            .open(&path)
            .unwrap();
        file.set_len(full_size - 4).unwrap();
    }

    let Err(err) = SingleFileBackend::from_flux2_singlefile(&path, &cfg) else {
        panic!("truncated Flux.2 single-file must be rejected before mmap");
    };
    // `{:#}` renders the whole context chain on one line.
    let chained = format!("{err:#}");
    let has_wrapper = chained.contains("validate single-file checkpoint");
    let has_root = chained.contains("truncated");
    assert!(
        has_wrapper && has_root,
        "expected outer wrapper + truncated root, got: {chained}",
    );

    let _ = std::fs::remove_file(path);
}
/// A checkpoint that contains only the un-prefixed `x_embedder.weight` key
/// must be rejected with an error that mentions `model.diffusion_model`.
#[test]
fn flux2_singlefile_backend_rejects_non_bfl_native() {
    let cfg = flux2_test_config();
    let path = write_synthetic(
        "flux2-already-diffusers",
        &[("x_embedder.weight", vec![1, 1], vec![1.0])],
    );

    let Err(err) = SingleFileBackend::from_flux2_singlefile(&path, &cfg) else {
        panic!("non-BFL-native must be rejected");
    };
    let mentions_prefix = err.to_string().contains("model.diffusion_model");
    assert!(
        mentions_prefix,
        "error must mention model.diffusion_model, got: {err}",
    );

    let _ = std::fs::remove_file(path);
}
/// Verifies that the two halves of `final_layer.adaLN_modulation.1.weight`
/// are swapped when the tensor is exposed as `norm_out.linear.weight`.
///
/// The source tensor has `n` rows: the first half is filled with 10.0 and the
/// second half with 20.0. After loading, the test expects the 20.0 rows first
/// and the 10.0 rows second, with the `[n, 1]` shape unchanged.
///
/// The single-block tensors use the same key names as
/// `write_flux2_bfl_fixture_with_rms` (`linear1`, `linear2`,
/// `norm.{query,key}_norm.scale`) so this fixture has the same layout as the
/// other Flux.2 fixtures in this module.
#[test]
fn flux2_singlefile_backend_swaps_ada_ln_halves_for_diffusers_ordering() {
    let n = 4usize;
    let half = n / 2;
    // Source order: 10.0 rows first, 20.0 rows second.
    let mut ada_data: Vec<f32> = vec![10.0f32; half];
    ada_data.extend(vec![20.0f32; half]);

    let cfg = flux2_test_config();
    let prefix = "model.diffusion_model";
    let mut tensors: Vec<(String, Vec<usize>, Vec<f32>)> = Vec::new();

    // Top-level tensors other than the adaLN weight are zero-filled.
    for suffix in [
        "img_in.weight",
        "txt_in.weight",
        "time_in.in_layer.weight",
        "time_in.out_layer.weight",
        "final_layer.linear.weight",
        "double_stream_modulation_img.lin.weight",
        "double_stream_modulation_txt.lin.weight",
        "single_stream_modulation.lin.weight",
    ] {
        tensors.push((format!("{prefix}.{suffix}"), vec![1, 1], vec![0.0f32]));
    }
    // The tensor under test.
    tensors.push((
        format!("{prefix}.final_layer.adaLN_modulation.1.weight"),
        vec![n, 1],
        ada_data,
    ));

    for i in 0..cfg.depth {
        for suffix in [
            "img_attn.qkv.weight",
            "txt_attn.qkv.weight",
            "img_attn.proj.weight",
            "img_attn.norm.query_norm.scale",
            "img_attn.norm.key_norm.scale",
            "img_mlp.0.weight",
            "img_mlp.2.weight",
            "txt_attn.proj.weight",
            "txt_attn.norm.query_norm.scale",
            "txt_attn.norm.key_norm.scale",
            "txt_mlp.0.weight",
            "txt_mlp.2.weight",
        ] {
            tensors.push((
                format!("{prefix}.double_blocks.{i}.{suffix}"),
                vec![3, 1],
                vec![0.0; 3],
            ));
        }
    }

    // Same single-block key names as the shared fixture writer.
    for i in 0..cfg.depth_single_blocks {
        for suffix in [
            "linear1.weight",
            "linear2.weight",
            "norm.query_norm.scale",
            "norm.key_norm.scale",
        ] {
            tensors.push((
                format!("{prefix}.single_blocks.{i}.{suffix}"),
                vec![1, 1],
                vec![0.0],
            ));
        }
    }

    let path = write_synthetic_with_tensors("flux2-ada-swap-test", &tensors);
    let backend =
        SingleFileBackend::from_flux2_singlefile(&path, &cfg).expect("backend must load");
    let dev = Device::Cpu;
    let t = SimpleBackend::get_unchecked(&backend, "norm_out.linear.weight", DType::F32, &dev)
        .expect("norm_out.linear.weight must be accessible");
    assert_eq!(t.dims(), &[n, 1], "SwapHalves must preserve shape ({n}, 1)");

    let vals: Vec<f32> = t.flatten_all().unwrap().to_vec1().unwrap();
    let (first_half, second_half) = vals.split_at(half);
    for (row, &v) in first_half.iter().enumerate() {
        assert!(
            (v - 20.0).abs() < 1e-6,
            "row {row}: expected 20.0 (scale, now first) after swap, got {v}",
        );
    }
    for (offset, &v) in second_half.iter().enumerate() {
        // Report the row's index in the full tensor, not within the half.
        let row = half + offset;
        assert!(
            (v - 10.0).abs() < 1e-6,
            "row {row}: expected 10.0 (shift, now second) after swap, got {v}",
        );
    }

    // Best-effort cleanup of the temp fixture.
    let _ = std::fs::remove_file(path);
}
}