use crate::loader::single_file::SingleFileBundle;
use crate::loader::vae_keys::apply_vae_rename;
use std::collections::BTreeMap;
use thiserror::Error;
/// Describes how a single diffusers tensor is produced from a source checkpoint key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RenameOutput {
/// The source tensor maps one-to-one onto the diffusers key held in the payload.
Direct(String),
/// The diffusers tensor is one piece of a fused source tensor.
///
/// This module only records the slicing parameters; the slicing itself is not
/// performed here (presumably the consumer cuts `num_components` equal chunks
/// along `axis` — confirm against the loader that applies the remap).
FusedSlice {
/// Destination key in the diffusers layout.
diffusers_key: String,
/// Axis of the fused source tensor along which the pieces are laid out.
axis: usize,
/// Zero-based position of this piece among the `num_components` pieces.
component: usize,
/// Total number of pieces the fused source tensor is divided into.
num_components: usize,
},
}
/// Result of renaming every key of an SDXL single-file checkpoint.
///
/// Every map is keyed by the *diffusers* key, so a lookup answers "which source
/// tensor feeds this diffusers tensor?".
#[derive(Debug, Default, Clone)]
pub struct SdxlRemap {
/// UNet: diffusers key -> original checkpoint key.
pub unet: BTreeMap<String, String>,
/// VAE: diffusers key -> original checkpoint key.
pub vae: BTreeMap<String, String>,
/// CLIP-L text encoder: diffusers key -> original checkpoint key.
pub clip_l: BTreeMap<String, String>,
/// CLIP-G text encoder: diffusers key -> (original checkpoint key, how to obtain the tensor).
/// Several entries may share one original key when that key is a fused tensor.
pub clip_g: BTreeMap<String, (String, RenameOutput)>,
/// Original keys, from any component, for which the rename function returned `None`.
pub unmapped: Vec<String>,
}
/// Errors that building an SDXL remap can report.
///
/// Unrecognized keys are not errors: they are collected in `SdxlRemap::unmapped`.
#[derive(Debug, Error)]
pub enum RemapError {
/// Stand-in variant that keeps the `Result` return type in place; no code
/// visible in this file constructs it.
#[error("placeholder — no fatal rename failures in SDXL yet")]
// The variant is never constructed here, hence the lint suppression.
#[allow(dead_code)]
Placeholder,
}
/// Translates an A1111-style SDXL UNet key into its diffusers equivalent.
///
/// The key must start with `model.diffusion_model.`. Fixed top-level tensors
/// (time embedding, label embedding, output head) are looked up directly; block
/// tensors are dispatched on their group name to the per-group renamers.
///
/// Returns `None` when the prefix is missing, a block index is not a valid
/// number, or the key does not match any known SDXL UNet position.
pub fn apply_sdxl_unet_rename(a1111_key: &str) -> Option<String> {
    let inner = a1111_key.strip_prefix("model.diffusion_model.")?;

    // Tensors that live outside the input/middle/output block hierarchy.
    let fixed_name: Option<&str> = match inner {
        "time_embed.0.weight" => Some("time_embedding.linear_1.weight"),
        "time_embed.0.bias" => Some("time_embedding.linear_1.bias"),
        "time_embed.2.weight" => Some("time_embedding.linear_2.weight"),
        "time_embed.2.bias" => Some("time_embedding.linear_2.bias"),
        "out.0.weight" => Some("conv_norm_out.weight"),
        "out.0.bias" => Some("conv_norm_out.bias"),
        "out.2.weight" => Some("conv_out.weight"),
        "out.2.bias" => Some("conv_out.bias"),
        "label_emb.0.0.weight" => Some("add_embedding.linear_1.weight"),
        "label_emb.0.0.bias" => Some("add_embedding.linear_1.bias"),
        "label_emb.0.2.weight" => Some("add_embedding.linear_2.weight"),
        "label_emb.0.2.bias" => Some("add_embedding.linear_2.bias"),
        _ => None,
    };
    if let Some(name) = fixed_name {
        return Some(name.to_owned());
    }

    // Everything else is `<group>.<indices...>.<suffix>`; dispatch on the group.
    let (group, rest) = inner.split_once('.')?;
    match group {
        "input_blocks" => {
            let (block_idx, rest) = split_idx(rest)?;
            let (sub_idx, suffix) = split_idx(rest)?;
            rename_unet_input_block(block_idx, sub_idx, suffix)
        }
        "middle_block" => {
            let (sub_idx, suffix) = split_idx(rest)?;
            rename_unet_middle_block(sub_idx, suffix)
        }
        "output_blocks" => {
            let (block_idx, rest) = split_idx(rest)?;
            let (sub_idx, suffix) = split_idx(rest)?;
            rename_unet_output_block(block_idx, sub_idx, suffix)
        }
        _ => None,
    }
}
/// Translates an SDXL CLIP-L key into its diffusers equivalent.
///
/// CLIP-L tensors already use the target naming, so the rename only removes the
/// `conditioner.embedders.0.transformer.` prefix. The remainder must be
/// `text_model` itself or live under `text_model.`; anything else yields `None`.
pub fn apply_sdxl_clip_l_rename(a1111_key: &str) -> Option<String> {
    a1111_key
        .strip_prefix("conditioner.embedders.0.transformer.")
        .filter(|rest| *rest == "text_model" || rest.starts_with("text_model."))
        .map(str::to_owned)
}
/// Translates an SDXL CLIP-G (OpenCLIP layout) key into diffusers outputs.
///
/// The key must start with `conditioner.embedders.1.model.`. A single source
/// key can yield several outputs, because fused attention projections are
/// split by the per-layer renamer.
///
/// Returns `None` when the prefix is missing or the key is not recognized.
pub fn apply_sdxl_clip_g_rename(a1111_key: &str) -> Option<Vec<RenameOutput>> {
    let inner = a1111_key.strip_prefix("conditioner.embedders.1.model.")?;

    // Tensors that sit outside the transformer layer stack.
    let whole_tensor = match inner {
        "token_embedding.weight" => Some("text_model.embeddings.token_embedding.weight"),
        "positional_embedding" => Some("text_model.embeddings.position_embedding.weight"),
        "ln_final.weight" => Some("text_model.final_layer_norm.weight"),
        "ln_final.bias" => Some("text_model.final_layer_norm.bias"),
        "text_projection" => Some("text_projection.weight"),
        _ => None,
    };

    match whole_tensor {
        Some(name) => Some(vec![RenameOutput::Direct(name.to_owned())]),
        None => {
            // Per-layer tensors: `transformer.resblocks.<layer>.<suffix>`.
            let rest = inner.strip_prefix("transformer.resblocks.")?;
            let (layer_idx, suffix) = split_idx(rest)?;
            rename_clip_g_resblock(layer_idx, suffix)
        }
    }
}
pub fn build_sdxl_remap(bundle: &SingleFileBundle) -> Result<SdxlRemap, RemapError> {
let mut out = SdxlRemap::default();
apply_into(
&bundle.unet_keys,
&mut out.unet,
&mut out.unmapped,
apply_sdxl_unet_rename,
);
apply_into(
&bundle.vae_keys,
&mut out.vae,
&mut out.unmapped,
apply_vae_rename,
);
apply_into(
&bundle.clip_l_keys,
&mut out.clip_l,
&mut out.unmapped,
apply_sdxl_clip_l_rename,
);
if let Some(clip_g_keys) = bundle.clip_g_keys.as_ref() {
for key in clip_g_keys {
match apply_sdxl_clip_g_rename(key) {
Some(outputs) => {
for output in outputs {
let diffusers_key = match &output {
RenameOutput::Direct(k) => k.clone(),
RenameOutput::FusedSlice { diffusers_key, .. } => diffusers_key.clone(),
};
out.clip_g.insert(diffusers_key, (key.clone(), output));
}
}
None => out.unmapped.push(key.clone()),
}
}
}
Ok(out)
}
/// Renames every key in `src` and files the outcome.
///
/// For each key, `rename` is called once. A `Some(new_key)` result stores
/// `new_key -> original key` in `dst` (a later key with the same `new_key`
/// replaces the earlier entry); a `None` result appends the original key to
/// `unmapped`.
fn apply_into<F>(
    src: &[String],
    dst: &mut BTreeMap<String, String>,
    unmapped: &mut Vec<String>,
    rename: F,
) where
    F: Fn(&str) -> Option<String>,
{
    for original in src.iter() {
        if let Some(renamed) = rename(original.as_str()) {
            dst.insert(renamed, original.clone());
        } else {
            unmapped.push(original.clone());
        }
    }
}
/// Splits `"<index>.<tail>"` into the numeric index and the remaining tail.
///
/// Returns `None` when `s` contains no `.`, or when the segment before the
/// first `.` is not a canonical decimal index.
///
/// Only canonical spellings are accepted: ASCII digits with no sign and no
/// redundant leading zero. `usize::from_str` on its own also accepts forms such
/// as `+1` and `01`, which would make distinct source keys collapse onto the
/// same renamed key and silently overwrite one another in the remap tables.
fn split_idx(s: &str) -> Option<(usize, &str)> {
    let (head, tail) = s.split_once('.')?;

    let only_digits = !head.is_empty() && head.bytes().all(|b| b.is_ascii_digit());
    let no_padding = head.len() == 1 || !head.starts_with('0');
    if !(only_digits && no_padding) {
        return None;
    }

    // Parsing can still fail on overflow for absurdly long digit runs.
    let idx: usize = head.parse().ok()?;
    Some((idx, tail))
}
/// Renames a key located at `input_blocks.<block_idx>.<sub_idx>.<suffix>`.
///
/// Layout handled here:
/// * block 0, sub-module 0: the stem convolution (`conv_in`);
/// * blocks 1-3: down stage 0 (two resnets, then a downsampler; no attention);
/// * blocks 4-6: down stage 1 (two resnet+attention pairs, then a downsampler);
/// * blocks 7-8: down stage 2 (two resnet+attention pairs, no downsampler).
///
/// Returns `None` for any other position or for an unrecognized suffix.
fn rename_unet_input_block(block_idx: usize, sub_idx: usize, suffix: &str) -> Option<String> {
    // Resolve which down stage the block belongs to and its slot inside that stage.
    let (stage, slot): (usize, usize) = match block_idx {
        0 => return (sub_idx == 0).then(|| format!("conv_in.{suffix}")),
        1..=3 => (0, block_idx - 1),
        4..=6 => (1, block_idx - 4),
        7..=8 => (2, block_idx - 7),
        _ => return None,
    };

    match (slot, sub_idx) {
        // The first two slots of every stage start with a resnet.
        (0 | 1, 0) => {
            let inner = rename_resnet_inner(suffix)?;
            Some(format!("down_blocks.{stage}.resnets.{slot}.{inner}"))
        }
        // Attention follows the resnet everywhere except in stage 0.
        (0 | 1, 1) if stage != 0 => {
            Some(format!("down_blocks.{stage}.attentions.{slot}.{suffix}"))
        }
        // The third slot (only present in stages 0 and 1) is the downsampler.
        (2, 0) => {
            let inner = rename_downsampler_inner(suffix)?;
            Some(format!("down_blocks.{stage}.downsamplers.0.{inner}"))
        }
        _ => None,
    }
}
fn rename_unet_middle_block(sub_idx: usize, suffix: &str) -> Option<String> {
match sub_idx {
0 => Some(format!(
"mid_block.resnets.0.{}",
rename_resnet_inner(suffix)?
)),
1 => Some(format!("mid_block.attentions.0.{suffix}")),
2 => Some(format!(
"mid_block.resnets.1.{}",
rename_resnet_inner(suffix)?
)),
_ => None,
}
}
/// Renames a key located at `output_blocks.<block_idx>.<sub_idx>.<suffix>`.
///
/// Output blocks are grouped three per up stage. Stages 0 and 1 hold a resnet
/// (sub-module 0) and an attention block (sub-module 1) in every slot, plus an
/// upsampler (sub-module 2) in their last slot. Stage 2 holds resnets only.
/// Block indices beyond stage 2 and any other sub-module yield `None`.
fn rename_unet_output_block(block_idx: usize, sub_idx: usize, suffix: &str) -> Option<String> {
    let (stage, slot) = (block_idx / 3, block_idx % 3);
    if stage > 2 {
        return None;
    }
    // Only the first two up stages carry attention and an upsampler.
    let has_attention = stage < 2;

    match sub_idx {
        0 => rename_resnet_inner(suffix)
            .map(|inner| format!("up_blocks.{stage}.resnets.{slot}.{inner}")),
        1 if has_attention => Some(format!("up_blocks.{stage}.attentions.{slot}.{suffix}")),
        2 if has_attention && slot == 2 => rename_upsampler_inner(suffix)
            .map(|inner| format!("up_blocks.{stage}.upsamplers.0.{inner}")),
        _ => None,
    }
}
/// Renames the part of a resnet key that follows the block position.
///
/// The suffix must be `<module>.weight` or `<module>.bias`, where `<module>` is
/// one of the known A1111 resnet sub-modules. Anything else yields `None`.
fn rename_resnet_inner(suffix: &str) -> Option<String> {
    // Separate the trailing parameter name from the sub-module path.
    let (module, param) = suffix.rsplit_once('.')?;
    if !matches!(param, "weight" | "bias") {
        return None;
    }

    let target = match module {
        "in_layers.0" => "norm1",
        "in_layers.2" => "conv1",
        "emb_layers.1" => "time_emb_proj",
        "out_layers.0" => "norm2",
        "out_layers.3" => "conv2",
        "skip_connection" => "conv_shortcut",
        _ => return None,
    };
    Some(format!("{target}.{param}"))
}
/// Renames a downsampler suffix by replacing its leading `op.` with `conv.`.
///
/// Returns `None` when the suffix does not start with `op.`.
fn rename_downsampler_inner(suffix: &str) -> Option<String> {
    let tail = suffix.strip_prefix("op.")?;
    Some(format!("conv.{tail}"))
}
/// Renames an upsampler suffix.
///
/// The suffix is returned unchanged, so this always yields `Some`. The
/// `Option` return type matches the other `rename_*_inner` helpers.
fn rename_upsampler_inner(suffix: &str) -> Option<String> {
    Some(String::from(suffix))
}
/// Renames a key located at `transformer.resblocks.<layer_idx>.<suffix>`.
///
/// Layer norms, the attention output projection and the MLP map one-to-one and
/// produce a single [`RenameOutput::Direct`]. The fused `attn.in_proj_weight` /
/// `attn.in_proj_bias` tensors produce three [`RenameOutput::FusedSlice`]
/// entries, for `q_proj`, `k_proj` and `v_proj` in that order, along axis 0.
///
/// Returns `None` for any other suffix.
fn rename_clip_g_resblock(layer_idx: usize, suffix: &str) -> Option<Vec<RenameOutput>> {
    let layer = format!("text_model.encoder.layers.{layer_idx}");

    // One-to-one renames: only the part after the layer prefix changes.
    let one_to_one = match suffix {
        "ln_1.weight" => Some("layer_norm1.weight"),
        "ln_1.bias" => Some("layer_norm1.bias"),
        "ln_2.weight" => Some("layer_norm2.weight"),
        "ln_2.bias" => Some("layer_norm2.bias"),
        "attn.out_proj.weight" => Some("self_attn.out_proj.weight"),
        "attn.out_proj.bias" => Some("self_attn.out_proj.bias"),
        "mlp.c_fc.weight" => Some("mlp.fc1.weight"),
        "mlp.c_fc.bias" => Some("mlp.fc1.bias"),
        "mlp.c_proj.weight" => Some("mlp.fc2.weight"),
        "mlp.c_proj.bias" => Some("mlp.fc2.bias"),
        _ => None,
    };
    if let Some(tail) = one_to_one {
        return Some(vec![RenameOutput::Direct(format!("{layer}.{tail}"))]);
    }

    // Fused Q/K/V projection: one source tensor feeds three diffusers tensors.
    let kind = match suffix {
        "attn.in_proj_weight" => "weight",
        "attn.in_proj_bias" => "bias",
        _ => return None,
    };
    let slices = ["q", "k", "v"]
        .iter()
        .enumerate()
        .map(|(component, proj)| RenameOutput::FusedSlice {
            diffusers_key: format!("{layer}.self_attn.{proj}_proj.{kind}"),
            axis: 0,
            component,
            num_components: 3,
        })
        .collect();
    Some(slices)
}
// Unit tests for the SDXL key renamers, plus one end-to-end test that goes
// through the single-file loader.
#[cfg(test)]
mod tests {
use super::*;
// Input block 0 / sub-module 0 is the stem convolution and maps to `conv_in`.
#[test]
fn unet_input_block_0_to_conv_in() {
assert_eq!(
apply_sdxl_unet_rename("model.diffusion_model.input_blocks.0.0.weight").as_deref(),
Some("conv_in.weight"),
);
assert_eq!(
apply_sdxl_unet_rename("model.diffusion_model.input_blocks.0.0.bias").as_deref(),
Some("conv_in.bias"),
);
}
// Input blocks 1-3 form down stage 0: resnets and a downsampler, no attention.
#[test]
fn unet_input_block_stage_0_resnet_only() {
assert_eq!(
apply_sdxl_unet_rename("model.diffusion_model.input_blocks.1.0.in_layers.0.weight",)
.as_deref(),
Some("down_blocks.0.resnets.0.norm1.weight"),
);
assert_eq!(
apply_sdxl_unet_rename("model.diffusion_model.input_blocks.3.0.op.weight").as_deref(),
Some("down_blocks.0.downsamplers.0.conv.weight"),
);
// An attention key at sub-module 1 of a stage-0 block must not be mapped.
assert!(apply_sdxl_unet_rename(
"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight",
)
.is_none());
}
// Attention suffixes are carried over verbatim, including the transformer block index.
#[test]
fn unet_input_block_stage_2_attention_with_transformer_layer_5() {
assert_eq!(
apply_sdxl_unet_rename(
"model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_q.weight",
)
.as_deref(),
Some("down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_q.weight"),
);
}
// Output blocks 6-8 form up stage 2, which holds resnets only; the upsampler
// of up stage 1 sits at sub-module 2 of output block 5.
#[test]
fn unet_output_block_top_stage_no_attention() {
assert_eq!(
apply_sdxl_unet_rename("model.diffusion_model.output_blocks.6.0.in_layers.0.weight",)
.as_deref(),
Some("up_blocks.2.resnets.0.norm1.weight"),
);
assert!(apply_sdxl_unet_rename(
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight",
)
.is_none());
assert_eq!(
apply_sdxl_unet_rename("model.diffusion_model.output_blocks.5.2.conv.weight")
.as_deref(),
Some("up_blocks.1.upsamplers.0.conv.weight"),
);
}
// `label_emb.0.{0,2}` maps to `add_embedding.linear_{1,2}`.
#[test]
fn unet_label_emb_to_add_embedding() {
assert_eq!(
apply_sdxl_unet_rename("model.diffusion_model.label_emb.0.0.weight").as_deref(),
Some("add_embedding.linear_1.weight"),
);
assert_eq!(
apply_sdxl_unet_rename("model.diffusion_model.label_emb.0.2.bias").as_deref(),
Some("add_embedding.linear_2.bias"),
);
}
// CLIP-L keys only lose their `conditioner.embedders.0.transformer.` prefix.
#[test]
fn clip_l_strip_new_prefix() {
assert_eq!(
apply_sdxl_clip_l_rename(
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
)
.as_deref(),
Some("text_model.encoder.layers.0.self_attn.q_proj.weight"),
);
}
// A CLIP-G layer norm is a one-to-one rename and yields a single `Direct` output.
#[test]
fn clip_g_resblock_layer_norm_direct() {
let outputs = apply_sdxl_clip_g_rename(
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight",
)
.expect("ln_1 must rename");
assert_eq!(outputs.len(), 1);
assert_eq!(
outputs[0],
RenameOutput::Direct("text_model.encoder.layers.0.layer_norm1.weight".to_string()),
);
}
// The fused attention weight yields three slices, in q, k, v order, along axis 0.
#[test]
fn clip_g_attn_in_proj_weight_splits_q_k_v() {
let outputs = apply_sdxl_clip_g_rename(
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
)
.expect("in_proj_weight must rename");
assert_eq!(outputs.len(), 3, "Q/K/V must produce three slice outputs");
let expected: Vec<RenameOutput> = vec![
RenameOutput::FusedSlice {
diffusers_key: "text_model.encoder.layers.0.self_attn.q_proj.weight".to_string(),
axis: 0,
component: 0,
num_components: 3,
},
RenameOutput::FusedSlice {
diffusers_key: "text_model.encoder.layers.0.self_attn.k_proj.weight".to_string(),
axis: 0,
component: 1,
num_components: 3,
},
RenameOutput::FusedSlice {
diffusers_key: "text_model.encoder.layers.0.self_attn.v_proj.weight".to_string(),
axis: 0,
component: 2,
num_components: 3,
},
];
assert_eq!(outputs, expected);
}
// Same split for the fused bias, checked on a non-zero layer index.
#[test]
fn clip_g_attn_in_proj_bias_splits_q_k_v() {
let outputs = apply_sdxl_clip_g_rename(
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias",
)
.expect("in_proj_bias must rename");
assert_eq!(outputs.len(), 3);
for (i, comp) in outputs.iter().enumerate() {
match comp {
RenameOutput::FusedSlice {
diffusers_key,
axis,
component,
num_components,
} => {
// The position in the output vector must match the q/k/v component index.
let expected_letter = ["q", "k", "v"][i];
assert_eq!(
diffusers_key,
&format!(
"text_model.encoder.layers.7.self_attn.{expected_letter}_proj.bias"
),
);
assert_eq!(*axis, 0);
assert_eq!(*component, i);
assert_eq!(*num_components, 3);
}
_ => panic!("expected FusedSlice for in_proj_bias, got {comp:?}"),
}
}
}
// `mlp.c_fc` is renamed to `mlp.fc1`.
#[test]
fn clip_g_mlp_c_fc_renames_to_fc1() {
let outputs = apply_sdxl_clip_g_rename(
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight",
)
.expect("mlp.c_fc must rename");
assert_eq!(
outputs,
vec![RenameOutput::Direct(
"text_model.encoder.layers.3.mlp.fc1.weight".to_string()
)],
);
}
// The bare `text_projection` tensor gains a `.weight` suffix.
#[test]
fn clip_g_text_projection_to_text_projection_weight() {
let outputs = apply_sdxl_clip_g_rename("conditioner.embedders.1.model.text_projection")
.expect("text_projection must rename");
assert_eq!(
outputs,
vec![RenameOutput::Direct("text_projection.weight".to_string())],
);
}
// Keys with a wrong prefix, an out-of-range block index or an unknown name
// must be reported as `None` by every renamer.
#[test]
fn unrecognized_keys_return_none() {
assert!(apply_sdxl_unet_rename("denoiser.sigmas").is_none());
assert!(apply_sdxl_unet_rename(
"model.diffusion_model.input_blocks.99.0.in_layers.0.weight"
)
.is_none());
assert!(apply_sdxl_clip_l_rename("conditioner.unrelated.thing").is_none());
assert!(apply_sdxl_clip_g_rename("conditioner.embedders.1.model.unknown.thing").is_none());
// A `cond_stage_model.` prefix is not accepted by the SDXL CLIP-L renamer.
assert!(apply_sdxl_clip_l_rename(
"cond_stage_model.transformer.text_model.final_layer_norm.weight",
)
.is_none());
}
// End-to-end: write a tiny safetensors file, partition it with the loader,
// build the remap, and check that every component's keys land in the right map.
#[test]
fn build_sdxl_remap_routes_keys_per_component_with_clip_g_fused_split() {
use crate::loader::single_file::{load, SingleFileBundle};
use mold_catalog::families::Family;
use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
use std::collections::HashMap;
use std::path::PathBuf;
// File name is derived from the process id and the current time in nanoseconds.
let path: PathBuf = std::env::temp_dir().join(format!(
"mold-loader-sdxl-remap-{}-{}.safetensors",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos(),
));
let keys: &[&str] = &[
"model.diffusion_model.input_blocks.0.0.weight",
"model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_q.weight",
"model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight",
"model.diffusion_model.output_blocks.5.2.conv.weight",
"model.diffusion_model.label_emb.0.0.weight",
"first_stage_model.encoder.down.0.block.0.norm1.weight",
"first_stage_model.decoder.up.3.block.1.conv1.weight",
"first_stage_model.quant_conv.weight",
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
"conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight",
"conditioner.embedders.0.transformer.text_model.final_layer_norm.weight",
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight",
"conditioner.embedders.1.model.text_projection",
"conditioner.embedders.1.model.token_embedding.weight",
// NOTE(review): no renamer in this file accepts this key, yet `unmapped` is
// asserted empty below — presumably `load` does not route it into any
// component key list; confirm against the loader.
"denoiser.sigmas",
];
// Every tensor is a single f32 zero; only the key names matter for this test.
let f32_zero = 0.0f32.to_le_bytes().to_vec();
let buffers: Vec<Vec<u8>> = keys.iter().map(|_| f32_zero.clone()).collect();
let mut tensors: HashMap<String, TensorView<'_>> = HashMap::new();
for (key, buf) in keys.iter().zip(buffers.iter()) {
tensors.insert(
(*key).to_string(),
TensorView::new(SafeDtype::F32, vec![1], buf).unwrap(),
);
}
serialize_to_file(&tensors, &None, &path).unwrap();
let bundle: SingleFileBundle = load(&path, Family::Sdxl).expect("load partition");
let remap = build_sdxl_remap(&bundle).expect("build remap");
// UNet: maps are keyed by diffusers key and hold the original key as value.
assert_eq!(
remap.unet.get("conv_in.weight").map(|s| s.as_str()),
Some("model.diffusion_model.input_blocks.0.0.weight"),
);
assert_eq!(
remap
.unet
.get("down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_q.weight")
.map(|s| s.as_str()),
Some("model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_q.weight",),
);
assert_eq!(
remap
.unet
.get("up_blocks.1.upsamplers.0.conv.weight")
.map(|s| s.as_str()),
Some("model.diffusion_model.output_blocks.5.2.conv.weight"),
);
assert_eq!(
remap
.unet
.get("add_embedding.linear_1.weight")
.map(|s| s.as_str()),
Some("model.diffusion_model.label_emb.0.0.weight"),
);
// VAE: renamed by `apply_vae_rename`, which is defined in another module.
assert!(remap
.vae
.contains_key("encoder.down_blocks.0.resnets.0.norm1.weight"));
assert!(remap
.vae
.contains_key("decoder.up_blocks.0.resnets.1.conv1.weight"));
assert!(remap.vae.contains_key("quant_conv.weight"));
// CLIP-L: prefix stripped, names otherwise unchanged.
assert!(remap
.clip_l
.contains_key("text_model.encoder.layers.0.self_attn.q_proj.weight"));
assert!(remap
.clip_l
.contains_key("text_model.embeddings.token_embedding.weight"));
assert!(remap
.clip_l
.contains_key("text_model.final_layer_norm.weight"));
// CLIP-G: the fused in_proj tensor must appear as three entries sharing one source key.
let q = remap
.clip_g
.get("text_model.encoder.layers.0.self_attn.q_proj.weight")
.expect("q_proj must be present after fused split");
assert_eq!(
q.0.as_str(),
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
);
match &q.1 {
RenameOutput::FusedSlice {
axis,
component,
num_components,
..
} => {
assert_eq!(*axis, 0);
assert_eq!(*component, 0);
assert_eq!(*num_components, 3);
}
_ => panic!("q_proj must be FusedSlice, got {:?}", q.1),
}
let k = remap
.clip_g
.get("text_model.encoder.layers.0.self_attn.k_proj.weight")
.expect("k_proj");
let v = remap
.clip_g
.get("text_model.encoder.layers.0.self_attn.v_proj.weight")
.expect("v_proj");
assert_eq!(k.0, q.0);
assert_eq!(v.0, q.0);
match (&k.1, &v.1) {
(
RenameOutput::FusedSlice { component: kc, .. },
RenameOutput::FusedSlice { component: vc, .. },
) => {
assert_eq!(*kc, 1);
assert_eq!(*vc, 2);
}
_ => panic!("k_proj / v_proj must both be FusedSlice"),
}
assert!(remap
.clip_g
.contains_key("text_model.encoder.layers.0.layer_norm1.weight"));
assert!(remap.clip_g.contains_key("text_projection.weight"));
assert!(remap
.clip_g
.contains_key("text_model.embeddings.token_embedding.weight"));
assert!(
remap.unmapped.is_empty(),
"expected no unmapped keys, got {:?}",
remap.unmapped
);
// Best-effort cleanup: a removal error is ignored, and this line is not
// reached when an assertion above panics.
let _ = std::fs::remove_file(path);
}
}