use crate::loader::single_file::SingleFileBundle;
use std::collections::BTreeMap;
use thiserror::Error;
/// Result of renaming an SD1.5 single-file (A1111/LDM) checkpoint's keys to diffusers names.
///
/// Each map is keyed by the diffusers tensor name and stores the original checkpoint key it was
/// derived from, so every diffusers parameter can be traced back to its source tensor.
#[derive(Debug, Default, Clone)]
pub struct Sd15Remap {
    /// UNet: diffusers name → original `model.diffusion_model.*` key.
    pub unet: BTreeMap<String, String>,
    /// VAE: diffusers name → original checkpoint key.
    pub vae: BTreeMap<String, String>,
    /// CLIP-L text encoder: renamed key → original `cond_stage_model.transformer.*` key.
    pub clip_l: BTreeMap<String, String>,
    /// Original keys, from any component, for which no rename rule matched.
    pub unmapped: Vec<String>,
}
/// Fatal errors that `build_sd15_remap` may report.
///
/// No SD1.5 rename failure is currently fatal: keys without a rule are collected in
/// `Sd15Remap::unmapped` instead. The single variant is a placeholder that keeps the `Result`
/// signature stable for future error cases.
#[derive(Debug, Error)]
pub enum RemapError {
    /// Placeholder; not constructed anywhere in this module.
    #[error("placeholder — no fatal rename failures in SD1.5 yet")]
    // Silences the unused-variant warning, since nothing constructs this variant yet.
    #[allow(dead_code)]
    Placeholder,
}
/// Renames one A1111/LDM UNet key (`model.diffusion_model.*`) to its diffusers equivalent.
///
/// Fixed top-level tensors (time-embedding MLP and output head) are looked up directly.
/// Block tensors are dispatched on their `input_blocks.` / `middle_block.` / `output_blocks.`
/// prefix once their numeric indices have been split off.
///
/// Returns `None` when the prefix is missing, an index is malformed, or no rule applies.
pub fn apply_sd15_unet_rename(a1111_key: &str) -> Option<String> {
    let inner = a1111_key.strip_prefix("model.diffusion_model.")?;

    let fixed_name = match inner {
        "time_embed.0.weight" => Some("time_embedding.linear_1.weight"),
        "time_embed.0.bias" => Some("time_embedding.linear_1.bias"),
        "time_embed.2.weight" => Some("time_embedding.linear_2.weight"),
        "time_embed.2.bias" => Some("time_embedding.linear_2.bias"),
        "out.0.weight" => Some("conv_norm_out.weight"),
        "out.0.bias" => Some("conv_norm_out.bias"),
        "out.2.weight" => Some("conv_out.weight"),
        "out.2.bias" => Some("conv_out.bias"),
        _ => None,
    };
    if let Some(name) = fixed_name {
        return Some(name.to_owned());
    }

    if let Some(tail) = inner.strip_prefix("input_blocks.") {
        let (block, tail) = split_idx(tail)?;
        let (sub, suffix) = split_idx(tail)?;
        rename_unet_input_block(block, sub, suffix)
    } else if let Some(tail) = inner.strip_prefix("middle_block.") {
        let (sub, suffix) = split_idx(tail)?;
        rename_unet_middle_block(sub, suffix)
    } else if let Some(tail) = inner.strip_prefix("output_blocks.") {
        let (block, tail) = split_idx(tail)?;
        let (sub, suffix) = split_idx(tail)?;
        rename_unet_output_block(block, sub, suffix)
    } else {
        None
    }
}
/// Renames one VAE key from an SD1.5 single-file checkpoint.
///
/// SD1.5 has no VAE-specific rules of its own, so this simply forwards to the shared VAE
/// renamer in `crate::loader::vae_keys`; its result is returned unchanged.
pub fn apply_sd15_vae_rename(a1111_key: &str) -> Option<String> {
    use crate::loader::vae_keys::apply_vae_rename;

    apply_vae_rename(a1111_key)
}
/// Renames one CLIP-L text-encoder key from an SD1.5 single-file checkpoint.
///
/// The checkpoint stores the text encoder under `cond_stage_model.transformer.`. Dropping that
/// prefix yields the target name, provided the remainder is `text_model` or lives under
/// `text_model.`. Every other key yields `None`.
pub fn apply_sd15_clip_l_rename(a1111_key: &str) -> Option<String> {
    const CHECKPOINT_PREFIX: &str = "cond_stage_model.transformer.";

    let renamed = a1111_key.strip_prefix(CHECKPOINT_PREFIX)?;
    let under_text_model = renamed == "text_model" || renamed.starts_with("text_model.");
    under_text_model.then(|| renamed.to_owned())
}
pub fn build_sd15_remap(bundle: &SingleFileBundle) -> Result<Sd15Remap, RemapError> {
let mut out = Sd15Remap::default();
apply_into(
&bundle.unet_keys,
&mut out.unet,
&mut out.unmapped,
apply_sd15_unet_rename,
);
apply_into(
&bundle.vae_keys,
&mut out.vae,
&mut out.unmapped,
apply_sd15_vae_rename,
);
apply_into(
&bundle.clip_l_keys,
&mut out.clip_l,
&mut out.unmapped,
apply_sd15_clip_l_rename,
);
Ok(out)
}
/// Runs `rename` over every key in `src`, recording hits in `dst` and misses in `unmapped`.
///
/// `dst` is keyed by the renamed name and stores the original key as its value. If two source
/// keys rename to the same name, the later one replaces the earlier entry. Keys for which
/// `rename` returns `None` are appended to `unmapped` in `src` order.
fn apply_into<F>(
    src: &[String],
    dst: &mut BTreeMap<String, String>,
    unmapped: &mut Vec<String>,
    rename: F,
) where
    F: Fn(&str) -> Option<String>,
{
    src.iter().for_each(|original| {
        if let Some(renamed) = rename(original.as_str()) {
            dst.insert(renamed, original.to_owned());
        } else {
            unmapped.push(original.to_owned());
        }
    });
}
/// Splits a leading decimal index off a dotted key path: `"3.foo.bar"` → `(3, "foo.bar")`.
///
/// Only canonical indices are accepted. The head must be non-empty, made solely of ASCII
/// digits, and free of leading zeros (except `"0"` itself). `usize::from_str` alone would also
/// accept `"+3"` and `"03"`. That would let two distinct checkpoint keys collapse onto the same
/// renamed key, and `apply_into` would then silently overwrite one of them.
///
/// Returns `None` if:
/// * there is no `.` separator,
/// * the head is not a canonical index, or
/// * the index overflows `usize`.
fn split_idx(s: &str) -> Option<(usize, &str)> {
    let (head, tail) = s.split_once('.')?;
    let all_digits = !head.is_empty() && head.bytes().all(|b| b.is_ascii_digit());
    let leading_zero = head.len() > 1 && head.starts_with('0');
    if !all_digits || leading_zero {
        return None;
    }
    let idx: usize = head.parse().ok()?;
    Some((idx, tail))
}
/// Maps one `input_blocks.{block_idx}.{sub_idx}.{suffix}` UNet entry to its diffusers name.
///
/// The SD1.5 input blocks are laid out as:
/// * `0` — the input convolution (`conv_in`), at sub-index `0` only;
/// * `1..=9` — three blocks per down stage (`stage = (block_idx - 1) / 3`): two slots, each
///   holding a resnet (sub-index `0`) and an attention (sub-index `1`), then a downsampler;
/// * `10`, `11` — the last down stage (`down_blocks.3`), resnets only.
///
/// Returns `None` for indices outside that layout, or when the resnet/downsampler renamer
/// rejects `suffix`. Attention and `conv_in` suffixes are passed through unchanged.
fn rename_unet_input_block(block_idx: usize, sub_idx: usize, suffix: &str) -> Option<String> {
    // Highest `input_blocks` index in the SD1.5 UNet. Larger indices would otherwise be mapped
    // onto non-existent `down_blocks.3.resnets.2` / `down_blocks.4+` names.
    const LAST_INPUT_BLOCK: usize = 11;

    if block_idx == 0 {
        // Handle block 0 entirely here so `block_idx - 1` below can never underflow. A
        // malformed `input_blocks.0.1.*` key would otherwise panic in debug builds.
        return (sub_idx == 0).then(|| format!("conv_in.{suffix}"));
    }
    if block_idx > LAST_INPUT_BLOCK {
        return None;
    }

    let stage_idx = (block_idx - 1) / 3;
    let in_stage = (block_idx - 1) % 3;

    if stage_idx == 3 {
        // Blocks 10 and 11 each hold a single resnet: no attention, no downsampler.
        if sub_idx != 0 {
            return None;
        }
        return Some(format!(
            "down_blocks.3.resnets.{in_stage}.{}",
            rename_resnet_inner(suffix)?
        ));
    }

    match (in_stage, sub_idx) {
        (0 | 1, 0) => Some(format!(
            "down_blocks.{stage_idx}.resnets.{in_stage}.{}",
            rename_resnet_inner(suffix)?
        )),
        (0 | 1, 1) => Some(format!(
            "down_blocks.{stage_idx}.attentions.{in_stage}.{suffix}"
        )),
        (2, 0) => Some(format!(
            "down_blocks.{stage_idx}.downsamplers.0.{}",
            rename_downsampler_inner(suffix)?
        )),
        _ => None,
    }
}
fn rename_unet_middle_block(sub_idx: usize, suffix: &str) -> Option<String> {
match sub_idx {
0 => Some(format!(
"mid_block.resnets.0.{}",
rename_resnet_inner(suffix)?
)),
1 => Some(format!("mid_block.attentions.0.{suffix}")),
2 => Some(format!(
"mid_block.resnets.1.{}",
rename_resnet_inner(suffix)?
)),
_ => None,
}
}
/// Maps one `output_blocks.{block_idx}.{sub_idx}.{suffix}` UNet entry to its diffusers name.
///
/// SD1.5 has twelve output blocks (`0..=11`), three per up stage (`stage = block_idx / 3`):
/// * sub-index `0` is always a resnet;
/// * stages 1–3 add an attention at sub-index `1`;
/// * the last block of stages 0–2 also carries an upsampler — at sub-index `1` in stage 0
///   (which has no attention) and at sub-index `2` in stages 1 and 2.
///
/// Returns `None` for block indices past that layout, for unknown sub-indices, or when the
/// resnet renamer rejects `suffix`.
fn rename_unet_output_block(block_idx: usize, sub_idx: usize, suffix: &str) -> Option<String> {
    // Highest `output_blocks` index in the SD1.5 UNet. Anything beyond it would otherwise be
    // mapped onto a non-existent `up_blocks.4+`.
    const LAST_OUTPUT_BLOCK: usize = 11;
    if block_idx > LAST_OUTPUT_BLOCK {
        return None;
    }

    let stage_idx = block_idx / 3;
    let resnet_idx = block_idx % 3;
    if stage_idx == 0 {
        match sub_idx {
            0 => Some(format!(
                "up_blocks.0.resnets.{resnet_idx}.{}",
                rename_resnet_inner(suffix)?
            )),
            1 if block_idx == 2 => Some(format!(
                "up_blocks.0.upsamplers.0.{}",
                rename_upsampler_inner(suffix)?
            )),
            _ => None,
        }
    } else {
        match sub_idx {
            0 => Some(format!(
                "up_blocks.{stage_idx}.resnets.{resnet_idx}.{}",
                rename_resnet_inner(suffix)?
            )),
            1 => Some(format!(
                "up_blocks.{stage_idx}.attentions.{resnet_idx}.{suffix}",
            )),
            // The final up stage has no upsampler.
            2 if resnet_idx == 2 && stage_idx != 3 => Some(format!(
                "up_blocks.{stage_idx}.upsamplers.0.{}",
                rename_upsampler_inner(suffix)?
            )),
            _ => None,
        }
    }
}
/// Renames the part of a ResBlock key after the block prefix, e.g. `in_layers.2.weight` →
/// `conv1.weight`.
///
/// Only the `weight` and `bias` parameters of the six known sub-layers are recognised;
/// everything else yields `None`.
fn rename_resnet_inner(suffix: &str) -> Option<String> {
    let (layer, param) = suffix.rsplit_once('.')?;
    if !matches!(param, "weight" | "bias") {
        return None;
    }
    let diffusers_layer = match layer {
        "in_layers.0" => "norm1",
        "in_layers.2" => "conv1",
        "emb_layers.1" => "time_emb_proj",
        "out_layers.0" => "norm2",
        "out_layers.3" => "conv2",
        "skip_connection" => "conv_shortcut",
        _ => return None,
    };
    Some(format!("{diffusers_layer}.{param}"))
}
/// Renames a downsampler suffix: the checkpoint's `op.*` convolution becomes `conv.*`.
///
/// Returns `None` for any suffix that does not start with `op.`.
fn rename_downsampler_inner(suffix: &str) -> Option<String> {
    let param = suffix.strip_prefix("op.")?;
    Some(format!("conv.{param}"))
}
/// Renames an upsampler suffix. Checkpoint and diffusers names agree here, so the suffix is
/// returned unchanged and this never yields `None`.
fn rename_upsampler_inner(suffix: &str) -> Option<String> {
    let unchanged = String::from(suffix);
    Some(unchanged)
}
#[cfg(test)]
mod tests {
    use super::*;

    // `input_blocks.0.0` is the UNet input convolution; its weight and bias both map to
    // `conv_in.*`.
    #[test]
    fn unet_input_block_0_to_conv_in() {
        assert_eq!(
            apply_sd15_unet_rename("model.diffusion_model.input_blocks.0.0.weight").as_deref(),
            Some("conv_in.weight"),
        );
        assert_eq!(
            apply_sd15_unet_rename("model.diffusion_model.input_blocks.0.0.bias").as_deref(),
            Some("conv_in.bias"),
        );
    }

    // Middle-block attention (sub-index 1) keeps its inner transformer path verbatim.
    #[test]
    fn unet_middle_block_attention_transformer_block() {
        assert_eq!(
            apply_sd15_unet_rename(
                "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight",
            )
            .as_deref(),
            Some("mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight"),
        );
    }

    // Output block 2 ends up stage 0. It has a resnet at sub-index 0 and, because stage 0 has
    // no attention, its upsampler at sub-index 1. Output block 5 ends stage 1, where the
    // upsampler sits at sub-index 2, after the attention.
    #[test]
    fn unet_output_block_with_upsampler() {
        assert_eq!(
            apply_sd15_unet_rename("model.diffusion_model.output_blocks.2.0.in_layers.0.weight",)
                .as_deref(),
            Some("up_blocks.0.resnets.2.norm1.weight"),
        );
        assert_eq!(
            apply_sd15_unet_rename("model.diffusion_model.output_blocks.2.1.conv.weight")
                .as_deref(),
            Some("up_blocks.0.upsamplers.0.conv.weight"),
        );
        assert_eq!(
            apply_sd15_unet_rename("model.diffusion_model.output_blocks.5.2.conv.weight")
                .as_deref(),
            Some("up_blocks.1.upsamplers.0.conv.weight"),
        );
    }

    // Spot-check of the delegated VAE renamer:
    // `encoder.down.N.block.M` → `encoder.down_blocks.N.resnets.M`.
    #[test]
    fn vae_encoder_down_block_resnet_norm1() {
        assert_eq!(
            apply_sd15_vae_rename("first_stage_model.encoder.down.0.block.0.norm1.weight",)
                .as_deref(),
            Some("encoder.down_blocks.0.resnets.0.norm1.weight"),
        );
    }

    // CLIP-L keys only lose their `cond_stage_model.transformer.` prefix.
    #[test]
    fn clip_l_text_model_self_attn_q_proj() {
        assert_eq!(
            apply_sd15_clip_l_rename(
                "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
            )
            .as_deref(),
            Some("text_model.encoder.layers.0.self_attn.q_proj.weight"),
        );
    }

    // Keys with no matching rule map to `None`, with one deliberate exception noted inline.
    #[test]
    fn unrecognized_keys_return_none() {
        assert!(apply_sd15_unet_rename("denoiser.sigmas").is_none());
        // Exception: `conv_in` passes any suffix through unchanged. Even an unknown parameter
        // name under `input_blocks.0.0` is therefore renamed (to `conv_in.unknown`).
        assert!(apply_sd15_unet_rename("model.diffusion_model.input_blocks.0.0.unknown").is_some());
        // An input-block index far past the SD1.5 layout must not produce a mapping.
        assert!(apply_sd15_unet_rename(
            "model.diffusion_model.input_blocks.99.0.in_layers.0.weight"
        )
        .is_none());
        assert!(apply_sd15_vae_rename("first_stage_model.unknown.thing").is_none());
        assert!(apply_sd15_clip_l_rename("cond_stage_model.unrelated.thing").is_none());
    }

    // End-to-end check: write a small safetensors checkpoint, load it as SD1.5, and verify
    // that `build_sd15_remap` routes each key into the right component map.
    #[test]
    fn build_sd15_remap_routes_keys_per_component() {
        use crate::loader::single_file::{load, SingleFileBundle};
        use mold_catalog::families::Family;
        use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
        use std::collections::HashMap;
        use std::path::PathBuf;
        // Process id plus a nanosecond timestamp make the temp file name unlikely to clash
        // with other test runs.
        let path: PathBuf = std::env::temp_dir().join(format!(
            "mold-loader-sd15-remap-{}-{}.safetensors",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
        ));
        // Keys cover the UNet, VAE and CLIP-L components, plus `denoiser.sigmas`, which
        // belongs to none of them.
        let keys: &[&str] = &[
            "model.diffusion_model.input_blocks.0.0.weight",
            "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight",
            "model.diffusion_model.output_blocks.5.2.conv.weight",
            "first_stage_model.encoder.down.0.block.0.norm1.weight",
            "first_stage_model.decoder.up.3.block.1.conv1.weight",
            "first_stage_model.quant_conv.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
            "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
            "cond_stage_model.transformer.text_model.final_layer_norm.weight",
            "denoiser.sigmas",
        ];
        // Every tensor is a single little-endian f32 zero; only the key names matter here.
        let f32_zero = 0.0f32.to_le_bytes().to_vec();
        let buffers: Vec<Vec<u8>> = keys.iter().map(|_| f32_zero.clone()).collect();
        let mut tensors: HashMap<String, TensorView<'_>> = HashMap::new();
        for (key, buf) in keys.iter().zip(buffers.iter()) {
            tensors.insert(
                (*key).to_string(),
                TensorView::new(SafeDtype::F32, vec![1], buf).unwrap(),
            );
        }
        serialize_to_file(&tensors, &None, &path).unwrap();
        // NOTE(review): `denoiser.sigmas` is written to the file, yet `unmapped` is expected to
        // be empty below. Presumably `load` drops keys outside the UNet/VAE/CLIP-L partitions;
        // confirm against `single_file::load`.
        let bundle: SingleFileBundle = load(&path, Family::Sd15).expect("load partition");
        let remap = build_sd15_remap(&bundle).expect("build remap");
        assert_eq!(
            remap.unet.get("conv_in.weight").map(|s| s.as_str()),
            Some("model.diffusion_model.input_blocks.0.0.weight"),
        );
        assert_eq!(
            remap
                .unet
                .get("mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight")
                .map(|s| s.as_str()),
            Some("model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight"),
        );
        assert_eq!(
            remap
                .unet
                .get("up_blocks.1.upsamplers.0.conv.weight")
                .map(|s| s.as_str()),
            Some("model.diffusion_model.output_blocks.5.2.conv.weight"),
        );
        assert!(remap
            .vae
            .contains_key("encoder.down_blocks.0.resnets.0.norm1.weight"));
        // The decoder's `up.N` numbering runs opposite to `up_blocks`, so `up.3` is
        // expected as `up_blocks.0`.
        assert!(remap
            .vae
            .contains_key("decoder.up_blocks.0.resnets.1.conv1.weight"));
        assert!(remap.vae.contains_key("quant_conv.weight"));
        assert!(remap
            .clip_l
            .contains_key("text_model.encoder.layers.0.self_attn.q_proj.weight"));
        assert!(remap
            .clip_l
            .contains_key("text_model.embeddings.token_embedding.weight"));
        assert!(remap
            .clip_l
            .contains_key("text_model.final_layer_norm.weight"));
        assert!(
            remap.unmapped.is_empty(),
            "expected no unmapped keys, got {:?}",
            remap.unmapped
        );
        // NOTE(review): if any assertion above panics, this cleanup is skipped and the temp
        // file is left behind.
        let _ = std::fs::remove_file(path);
    }
}