/// Renames a VAE weight key that carries the `first_stage_model.` prefix.
///
/// The prefix is stripped, then:
/// * the four `quant_conv` / `post_quant_conv` weight and bias keys are
///   returned unchanged (minus the prefix);
/// * keys under `encoder.` or `decoder.` have the remainder rewritten by
///   `rename_vae_half` and the `encoder.` / `decoder.` prefix re-attached.
///
/// Returns `None` when the prefix is missing or the key is not recognised.
pub fn apply_vae_rename(a1111_key: &str) -> Option<String> {
    let inner = a1111_key.strip_prefix("first_stage_model.")?;

    // Top-level convolutions keep their names as-is.
    match inner {
        "quant_conv.weight"
        | "quant_conv.bias"
        | "post_quant_conv.weight"
        | "post_quant_conv.bias" => return Some(inner.to_owned()),
        _ => {}
    }

    // Everything else must live under exactly one of the two halves.
    let (half_name, rest) = inner.split_once('.')?;
    let half = match half_name {
        "encoder" => VaeHalf::Encoder,
        "decoder" => VaeHalf::Decoder,
        _ => return None,
    };
    let renamed = rename_vae_half(rest, half)?;
    Some(format!("{half_name}.{renamed}"))
}
/// Splits `"<index>.<tail>"` into the parsed index and the remaining tail.
///
/// Only the canonical decimal spelling of the index is accepted: ASCII digits
/// with no sign and no leading zeros. `str::parse::<usize>` on its own would
/// also accept `"+3"` and `"03"`, which would let several distinct source keys
/// collapse onto the same renamed key.
///
/// Returns `None` when there is no `.` separator, when the head is empty or
/// not a canonical index, or when the number does not fit in `usize`.
fn split_idx(s: &str) -> Option<(usize, &str)> {
    let (head, tail) = s.split_once('.')?;
    let digits_only = head.bytes().all(|b| b.is_ascii_digit());
    let padded = head.len() > 1 && head.starts_with('0');
    if !digits_only || padded {
        return None;
    }
    // An empty head passes the checks above but is rejected by `parse`.
    let idx: usize = head.parse().ok()?;
    Some((idx, tail))
}
/// Selects which half of the VAE a key suffix belongs to, and therefore which
/// stage naming (`down.` vs `up.`) `rename_vae_half` looks for.
#[derive(Clone, Copy)]
enum VaeHalf {
    /// Keys that were found under the `encoder.` prefix.
    Encoder,
    /// Keys that were found under the `decoder.` prefix.
    Decoder,
}
/// Renames the part of a key that follows `encoder.` / `decoder.`.
///
/// Handled shapes:
/// * `conv_in.*` / `conv_out.*` weight and bias: unchanged;
/// * `norm_out.*` weight and bias: `conv_norm_out.*`;
/// * `mid.<...>`: delegated to `rename_vae_mid`;
/// * encoder `down.<s>.block.<b>.<leaf>`: `down_blocks.<s>.resnets.<b>.<leaf'>`;
/// * encoder `down.<s>.downsample.conv.{weight,bias}`:
///   `down_blocks.<s>.downsamplers.0.conv.{weight,bias}`;
/// * decoder `up.<s>.block.<b>.<leaf>`: `up_blocks.<3 - s>.resnets.<b>.<leaf'>`;
/// * decoder `up.<s>.upsample.conv.{weight,bias}`:
///   `up_blocks.<3 - s>.upsamplers.0.conv.{weight,bias}`.
///
/// The sampler tail is restricted to `conv.weight` / `conv.bias` so that, like
/// every other branch, unknown leaves yield `None` instead of being forwarded
/// verbatim into the output key.
///
/// Returns `None` for any suffix that does not match one of the shapes above,
/// including decoder stages greater than 3.
fn rename_vae_half(suffix: &str, half: VaeHalf) -> Option<String> {
    // Leaves that sit directly under the encoder/decoder.
    match suffix {
        "conv_in.weight" | "conv_in.bias" | "conv_out.weight" | "conv_out.bias" => {
            return Some(suffix.to_string());
        }
        "norm_out.weight" => return Some("conv_norm_out.weight".to_string()),
        "norm_out.bias" => return Some("conv_norm_out.bias".to_string()),
        _ => {}
    }

    if let Some(rest) = suffix.strip_prefix("mid.") {
        return rename_vae_mid(rest);
    }

    // The two halves differ only in naming and in how the stage is numbered.
    let (src_stages, src_sampler, dst_blocks, dst_samplers) = match half {
        VaeHalf::Encoder => ("down.", "downsample.", "down_blocks", "downsamplers"),
        VaeHalf::Decoder => ("up.", "upsample.", "up_blocks", "upsamplers"),
    };
    let (src_stage, rest) = split_idx(suffix.strip_prefix(src_stages)?)?;
    let stage = match half {
        VaeHalf::Encoder => src_stage,
        // Decoder stages are numbered in reverse; the mapping is fixed at
        // `3 - stage`, so stages above 3 are rejected.
        VaeHalf::Decoder => 3usize.checked_sub(src_stage)?,
    };

    if let Some(rest) = rest.strip_prefix("block.") {
        let (block, tail) = split_idx(rest)?;
        return Some(format!(
            "{dst_blocks}.{stage}.resnets.{block}.{}",
            rename_resnet_inner_vae(tail)?
        ));
    }

    let tail = rest.strip_prefix(src_sampler)?;
    matches!(tail, "conv.weight" | "conv.bias")
        .then(|| format!("{dst_blocks}.{stage}.{dst_samplers}.0.{tail}"))
}
/// Renames the part of a key that follows `mid.`.
///
/// * `block_1.<leaf>` becomes `mid_block.resnets.0.<leaf'>`
/// * `block_2.<leaf>` becomes `mid_block.resnets.1.<leaf'>`
/// * `attn_1.<leaf>` becomes `mid_block.attentions.0.<leaf'>`
///
/// Resnet leaves go through `rename_resnet_inner_vae`, attention leaves
/// through `rename_vae_mid_attn`. Returns `None` for any other sub-module or
/// when the leaf itself is not recognised.
fn rename_vae_mid(suffix: &str) -> Option<String> {
    let (module, inner) = suffix.split_once('.')?;
    let renamed = match module {
        "block_1" => format!("mid_block.resnets.0.{}", rename_resnet_inner_vae(inner)?),
        "block_2" => format!("mid_block.resnets.1.{}", rename_resnet_inner_vae(inner)?),
        "attn_1" => format!("mid_block.attentions.0.{}", rename_vae_mid_attn(inner)?),
        _ => return None,
    };
    Some(renamed)
}
/// Renames a leaf of the mid-block attention module.
///
/// The leaf must be `<layer>.weight` or `<layer>.bias`, where `<layer>` is one
/// of `q`, `k`, `v`, `proj_out` or `norm`; these become `to_q`, `to_k`,
/// `to_v`, `to_out.0` and `group_norm` respectively. Anything else yields
/// `None`.
fn rename_vae_mid_attn(suffix: &str) -> Option<String> {
    // (source layer, renamed layer)
    const LAYERS: [(&str, &str); 5] = [
        ("q", "to_q"),
        ("k", "to_k"),
        ("v", "to_v"),
        ("proj_out", "to_out.0"),
        ("norm", "group_norm"),
    ];

    let (layer, param) = suffix.split_once('.')?;
    if param != "weight" && param != "bias" {
        return None;
    }
    LAYERS
        .iter()
        .find(|(src, _)| *src == layer)
        .map(|(_, dst)| format!("{dst}.{param}"))
}
/// Renames a leaf inside a VAE resnet block.
///
/// The leaf must be `<layer>.weight` or `<layer>.bias`. `norm1`, `conv1`,
/// `norm2` and `conv2` keep their names; `nin_shortcut` becomes
/// `conv_shortcut`. Anything else yields `None`.
fn rename_resnet_inner_vae(suffix: &str) -> Option<String> {
    let (layer, param) = suffix.split_once('.')?;
    if !matches!(param, "weight" | "bias") {
        return None;
    }
    let layer = match layer {
        "norm1" | "conv1" | "norm2" | "conv2" => layer,
        "nin_shortcut" => "conv_shortcut",
        _ => return None,
    };
    Some(format!("{layer}.{param}"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `key` is renamed to exactly `expected`.
    #[track_caller]
    fn assert_renamed(key: &str, expected: &str) {
        assert_eq!(apply_vae_rename(key).as_deref(), Some(expected));
    }

    /// Encoder stage and block indices are carried over unchanged.
    #[test]
    fn encoder_down_block_resnet_norm1() {
        assert_renamed(
            "first_stage_model.encoder.down.0.block.0.norm1.weight",
            "encoder.down_blocks.0.resnets.0.norm1.weight",
        );
    }

    /// Decoder stage 3 maps to `up_blocks.0`.
    #[test]
    fn decoder_up_stage_reversal() {
        assert_renamed(
            "first_stage_model.decoder.up.3.block.1.conv1.weight",
            "decoder.up_blocks.0.resnets.1.conv1.weight",
        );
    }

    /// The quant convolutions only lose the `first_stage_model.` prefix.
    #[test]
    fn quant_conv_pass_through() {
        assert_renamed("first_stage_model.quant_conv.weight", "quant_conv.weight");
        assert_renamed(
            "first_stage_model.post_quant_conv.bias",
            "post_quant_conv.bias",
        );
    }

    /// `proj_out` in the mid attention becomes `to_out.0`.
    #[test]
    fn mid_attn_proj_out_renames_to_to_out() {
        assert_renamed(
            "first_stage_model.encoder.mid.attn_1.proj_out.weight",
            "encoder.mid_block.attentions.0.to_out.0.weight",
        );
    }

    /// Unrecognised keys and keys without the expected prefix yield `None`.
    #[test]
    fn unknown_returns_none() {
        for key in ["first_stage_model.unknown.thing", "not_a_vae_prefix.thing"] {
            assert!(apply_vae_rename(key).is_none());
        }
    }
}