use std::collections::HashMap;
use std::path::Path;
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
use ferrotorch_nn::module::{Module, StateDict};
use ferrotorch_serialize::load_safetensors;
use crate::clip_text_encoder::{ClipTextConfig, ClipTextEncoder};
use crate::config::VaeDecoderConfig;
use crate::unet::UNet2DConditionModel;
use crate::unet_config::UNet2DConditionConfig;
use crate::vae::VaeDecoder;
use crate::vae_encoder::{VaeEncoder, VaeEncoderConfig};
/// Summary of the checkpoint keys a `load_hf_state_dict` call ignored.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare whole reports
/// directly instead of reaching into the field.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DropReport {
    /// Checkpoint keys, exactly as they appeared in the source state dict
    /// (prefix included), that were not forwarded to the module. The loaders
    /// in this file sort the list before returning it.
    pub dropped: Vec<String>,
}
impl<T: Float> VaeDecoder<T> {
    /// Loads decoder weights out of a HuggingFace-style VAE state dict.
    ///
    /// Each key has an optional leading `vae.` removed. Keys that then start
    /// with `post_quant_conv.` or `decoder.` are copied (under the stripped
    /// name) into a fresh state dict which is handed to `load_state_dict`
    /// together with `strict`.
    ///
    /// # Errors
    ///
    /// * With `strict == true`, the first visited key outside those prefixes
    ///   yields `FerrotorchError::InvalidArgument`. Which key is reported
    ///   depends on the map's iteration order.
    /// * Any error from `load_state_dict` is propagated.
    ///
    /// # Returns
    ///
    /// A [`DropReport`] listing, in sorted order, the original keys that were
    /// skipped (always empty when `strict` is on and the call succeeds).
    pub fn load_hf_state_dict(
        &mut self,
        hf_state: &StateDict<T>,
        strict: bool,
    ) -> FerrotorchResult<DropReport> {
        // Key prefixes (after removing `vae.`) that belong to the decoder half.
        const DECODER_PREFIXES: [&str; 2] = ["post_quant_conv.", "decoder."];

        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
        let mut dropped: Vec<String> = Vec::new();

        for (k, v) in hf_state {
            // Borrow the stripped name; it is only turned into an owned
            // `String` when the entry is actually kept.
            let local_name = k.strip_prefix("vae.").unwrap_or(k.as_str());
            let is_decoder_key = DECODER_PREFIXES
                .iter()
                .any(|prefix| local_name.starts_with(*prefix));

            if is_decoder_key {
                // NOTE(review): if both `vae.X` and `X` are present they map
                // to the same entry and the later-visited one overwrites.
                remapped.insert(local_name.to_owned(), v.clone());
            } else if strict {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "VaeDecoder::load_hf_state_dict: key {k:?} is not under \
                         `post_quant_conv.*` / `decoder.*` (with optional `vae.` prefix) \
                         and strict mode is on. Pass strict=false to drop encoder / \
                         quant_conv keys."
                    ),
                });
            } else {
                dropped.push(k.clone());
            }
        }

        // Map iteration order is unspecified; sort so the report is stable.
        dropped.sort_unstable();
        self.load_state_dict(&remapped, strict)?;
        Ok(DropReport { dropped })
    }
}
impl<T: Float> UNet2DConditionModel<T> {
    /// Loads UNet weights out of a HuggingFace-style state dict.
    ///
    /// Each key has an optional leading `unet.` removed. Keys that then start
    /// with one of the recognised UNet prefixes (`time_embedding.`,
    /// `conv_in.`, `down_blocks.`, `mid_block.`, `up_blocks.`,
    /// `conv_norm_out.`, `conv_out.`) are copied under the stripped name into
    /// a fresh state dict, which is passed to `load_state_dict` along with
    /// `strict`.
    ///
    /// # Errors
    ///
    /// * With `strict == true`, the first visited key outside those prefixes
    ///   yields `FerrotorchError::InvalidArgument`.
    /// * Any error from `load_state_dict` is propagated.
    ///
    /// # Returns
    ///
    /// A [`DropReport`] with the skipped original keys in sorted order.
    pub fn load_hf_state_dict(
        &mut self,
        hf_state: &StateDict<T>,
        strict: bool,
    ) -> FerrotorchResult<DropReport> {
        // Key prefixes (after removing `unet.`) that this loader accepts.
        const UNET_PREFIXES: [&str; 7] = [
            "time_embedding.",
            "conv_in.",
            "down_blocks.",
            "mid_block.",
            "up_blocks.",
            "conv_norm_out.",
            "conv_out.",
        ];

        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
        let mut dropped: Vec<String> = Vec::new();

        for (k, v) in hf_state {
            // Only allocate an owned key for entries that are kept.
            let local_name = k.strip_prefix("unet.").unwrap_or(k.as_str());
            let is_unet_key = UNET_PREFIXES
                .iter()
                .any(|prefix| local_name.starts_with(*prefix));

            if is_unet_key {
                remapped.insert(local_name.to_owned(), v.clone());
            } else if strict {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "UNet2DConditionModel::load_hf_state_dict: key {k:?} is not under \
                         a UNet prefix (with optional `unet.`) and strict mode is on."
                    ),
                });
            } else {
                dropped.push(k.clone());
            }
        }

        // Map iteration order is unspecified; sort so the report is stable.
        dropped.sort_unstable();
        self.load_state_dict(&remapped, strict)?;
        Ok(DropReport { dropped })
    }
}
/// Builds a [`UNet2DConditionModel`] from `cfg` and fills it with the weights
/// stored in the safetensors file at `weights_path`.
///
/// The file is decoded before the model is constructed, so an unreadable or
/// malformed file fails before `UNet2DConditionModel::new` runs.
///
/// # Errors
///
/// * `FerrotorchError::InvalidArgument` if the file cannot be decoded; the
///   message includes the path and the underlying error.
/// * Errors from `UNet2DConditionModel::new` and from
///   `load_hf_state_dict` (called with `strict`) are propagated.
///
/// # Returns
///
/// The loaded model together with the [`DropReport`] of ignored keys.
pub fn load_unet<T: Float>(
    weights_path: &Path,
    cfg: UNet2DConditionConfig,
    strict: bool,
) -> FerrotorchResult<(UNet2DConditionModel<T>, DropReport)> {
    let checkpoint = match load_safetensors::<T>(weights_path) {
        Ok(decoded) => decoded,
        Err(e) => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "load_unet: failed to decode safetensors {}: {e}",
                    weights_path.display()
                ),
            });
        }
    };

    let mut model = UNet2DConditionModel::<T>::new(cfg)?;
    let drop_report = model.load_hf_state_dict(&checkpoint, strict)?;
    Ok((model, drop_report))
}
fn load_safetensors_clip_filtered<T: Float>(
weights_path: &Path,
) -> FerrotorchResult<(StateDict<T>, bool)> {
use safetensors::SafeTensors;
let bytes =
std::fs::read(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
message: format!(
"load_safetensors_clip_filtered: failed to read {}: {e}",
weights_path.display()
),
})?;
let st = SafeTensors::deserialize(&bytes).map_err(|e| FerrotorchError::InvalidArgument {
message: format!(
"load_safetensors_clip_filtered: failed to parse {}: {e}",
weights_path.display()
),
})?;
let mut keep: Vec<String> = Vec::new();
let mut had_position_ids = false;
for k in st.names() {
let s: &str = k.as_str();
if s == "embeddings.position_ids" || s == "text_model.embeddings.position_ids" {
had_position_ids = true;
continue;
}
keep.push(String::from(s));
}
let mut subset: Vec<(String, safetensors::tensor::TensorView<'_>)> =
Vec::with_capacity(keep.len());
for k in &keep {
let v = st.tensor(k).map_err(|e| FerrotorchError::InvalidArgument {
message: format!(
"load_safetensors_clip_filtered: missing tensor {k:?} after filter: {e}"
),
})?;
subset.push((k.clone(), v));
}
let serialized = safetensors::serialize(subset, &None).map_err(|e| {
FerrotorchError::InvalidArgument {
message: format!("load_safetensors_clip_filtered: re-serialize failed: {e}"),
}
})?;
let tmp = tempfile::NamedTempFile::new().map_err(|e| FerrotorchError::InvalidArgument {
message: format!("load_safetensors_clip_filtered: tempfile: {e}"),
})?;
std::fs::write(tmp.path(), &serialized).map_err(|e| FerrotorchError::InvalidArgument {
message: format!("load_safetensors_clip_filtered: tempfile write: {e}"),
})?;
let state = load_safetensors::<T>(tmp.path())?;
Ok((state, had_position_ids))
}
/// Builds a [`ClipTextEncoder`] from `cfg` and fills it with the weights
/// stored in the safetensors file at `weights_path`.
///
/// The file is decoded through `load_safetensors_clip_filtered`, which leaves
/// out the `position_ids` entry. When that entry was present in the file, a
/// one-element zero tensor is inserted in its place: under
/// `text_model.embeddings.position_ids` if any decoded key starts with
/// `text_model.`, otherwise under `embeddings.position_ids`.
///
/// NOTE(review): the placeholder presumably keeps the key set consistent for
/// `ClipTextEncoder::load_hf_state_dict` — confirm against that method.
///
/// # Errors
///
/// * `FerrotorchError::InvalidArgument` if decoding fails; the message
///   includes the path and the underlying error.
/// * Errors from `ferrotorch_core::zeros`, `ClipTextEncoder::new` and
///   `load_hf_state_dict` (called with `strict`) are propagated.
///
/// # Returns
///
/// The loaded encoder together with the [`DropReport`] of ignored keys.
pub fn load_clip_text_encoder<T: Float>(
    weights_path: &Path,
    cfg: ClipTextConfig,
    strict: bool,
) -> FerrotorchResult<(ClipTextEncoder<T>, DropReport)> {
    let (mut state, had_position_ids) =
        match load_safetensors_clip_filtered::<T>(weights_path) {
            Ok(decoded) => decoded,
            Err(e) => {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "load_clip_text_encoder: failed to decode safetensors {}: {e}",
                        weights_path.display()
                    ),
                });
            }
        };

    if had_position_ids {
        // Re-create the removed entry using the naming scheme of the
        // remaining keys.
        let uses_text_model_prefix = state
            .keys()
            .any(|name| name.starts_with("text_model."));
        let placeholder_key = if uses_text_model_prefix {
            "text_model.embeddings.position_ids"
        } else {
            "embeddings.position_ids"
        };
        state.insert(
            placeholder_key.to_string(),
            ferrotorch_core::zeros::<T>(&[1])?,
        );
    }

    let mut encoder = ClipTextEncoder::<T>::new(cfg)?;
    let drop_report = encoder.load_hf_state_dict(&state, strict)?;
    Ok((encoder, drop_report))
}
/// Builds a [`VaeDecoder`] from `cfg` and fills it with the weights stored in
/// the safetensors file at `weights_path`.
///
/// # Errors
///
/// * `FerrotorchError::InvalidArgument` if the file cannot be decoded; the
///   message includes the path and the underlying error.
/// * Errors from `VaeDecoder::new` and from `load_hf_state_dict` (called with
///   `strict`) are propagated.
///
/// # Returns
///
/// The loaded decoder together with the [`DropReport`] of ignored keys.
pub fn load_vae_decoder<T: Float>(
    weights_path: &Path,
    cfg: VaeDecoderConfig,
    strict: bool,
) -> FerrotorchResult<(VaeDecoder<T>, DropReport)> {
    let checkpoint = match load_safetensors::<T>(weights_path) {
        Ok(decoded) => decoded,
        Err(e) => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "load_vae_decoder: failed to decode safetensors {}: {e}",
                    weights_path.display()
                ),
            });
        }
    };

    let mut model = VaeDecoder::<T>::new(cfg)?;
    let drop_report = model.load_hf_state_dict(&checkpoint, strict)?;
    Ok((model, drop_report))
}
impl<T: Float> VaeEncoder<T> {
    /// Loads encoder weights out of a HuggingFace-style VAE state dict.
    ///
    /// Each key has an optional leading `vae.` removed. Keys that then start
    /// with `encoder.` or `quant_conv.` are copied (under the stripped name)
    /// into a fresh state dict which is handed to `load_state_dict` together
    /// with `strict`.
    ///
    /// # Errors
    ///
    /// * With `strict == true`, the first visited key outside those prefixes
    ///   yields `FerrotorchError::InvalidArgument`. Which key is reported
    ///   depends on the map's iteration order.
    /// * Any error from `load_state_dict` is propagated.
    ///
    /// # Returns
    ///
    /// A [`DropReport`] listing, in sorted order, the original keys that were
    /// skipped.
    pub fn load_hf_state_dict(
        &mut self,
        hf_state: &StateDict<T>,
        strict: bool,
    ) -> FerrotorchResult<DropReport> {
        // Key prefixes (after removing `vae.`) that belong to the encoder half.
        const ENCODER_PREFIXES: [&str; 2] = ["encoder.", "quant_conv."];

        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
        let mut dropped: Vec<String> = Vec::new();

        for (k, v) in hf_state {
            // Borrow the stripped name; it is only turned into an owned
            // `String` when the entry is actually kept.
            let local_name = k.strip_prefix("vae.").unwrap_or(k.as_str());
            let is_encoder_key = ENCODER_PREFIXES
                .iter()
                .any(|prefix| local_name.starts_with(*prefix));

            if is_encoder_key {
                remapped.insert(local_name.to_owned(), v.clone());
            } else if strict {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "VaeEncoder::load_hf_state_dict: key {k:?} is not under \
                         `encoder.*` / `quant_conv.*` (with optional `vae.` prefix) \
                         and strict mode is on. Pass strict=false to drop decoder / \
                         post_quant_conv keys."
                    ),
                });
            } else {
                dropped.push(k.clone());
            }
        }

        // Map iteration order is unspecified; sort so the report is stable.
        dropped.sort_unstable();
        self.load_state_dict(&remapped, strict)?;
        Ok(DropReport { dropped })
    }
}
/// Builds a [`VaeEncoder`] from `cfg` and fills it with the weights stored in
/// the safetensors file at `weights_path`.
///
/// # Errors
///
/// * `FerrotorchError::InvalidArgument` if the file cannot be decoded; the
///   message includes the path and the underlying error.
/// * Errors from `VaeEncoder::new` and from `load_hf_state_dict` (called with
///   `strict`) are propagated.
///
/// # Returns
///
/// The loaded encoder together with the [`DropReport`] of ignored keys.
pub fn load_vae_encoder<T: Float>(
    weights_path: &Path,
    cfg: VaeEncoderConfig,
    strict: bool,
) -> FerrotorchResult<(VaeEncoder<T>, DropReport)> {
    let checkpoint = match load_safetensors::<T>(weights_path) {
        Ok(decoded) => decoded,
        Err(e) => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "load_vae_encoder: failed to decode safetensors {}: {e}",
                    weights_path.display()
                ),
            });
        }
    };

    let mut model = VaeEncoder::<T>::new(cfg)?;
    let drop_report = model.load_hf_state_dict(&checkpoint, strict)?;
    Ok((model, drop_report))
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};
    use ferrotorch_serialize::save_safetensors;
    use std::path::PathBuf;

    /// Small configuration shared by every test in this module.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// Saves `sd` as `model.safetensors` inside a fresh temporary directory.
    ///
    /// The directory guard is returned alongside the path; the caller has to
    /// keep it alive for as long as the file is needed.
    fn write_safetensors(sd: &StateDict<f32>) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.safetensors");
        save_safetensors(sd, &path).unwrap();
        (dir, path)
    }

    // ---- decoder ---------------------------------------------------------

    /// Weights saved from one decoder and loaded into another must produce
    /// the same output, with nothing reported as dropped.
    #[test]
    fn round_trip_safetensors_into_decoder() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let (_guard, path) = write_safetensors(&src.state_dict());
        let (dst, report) = load_vae_decoder::<f32>(&path, cfg, false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "round-trip should have empty drop list, got {:?}",
            report.dropped
        );

        let storage = TensorStorage::cpu(vec![0.01f32; 4]);
        let input = Tensor::from_storage(storage, vec![1, 4, 1, 1], false).unwrap();
        let expected = src.forward(&input).unwrap();
        let actual = dst.forward(&input).unwrap();
        assert!(expected
            .data()
            .unwrap()
            .iter()
            .zip(actual.data().unwrap().iter())
            .all(|(p, q)| (p - q).abs() < 1e-5));
    }

    /// Non-strict loading skips encoder-side keys and reports them sorted.
    #[test]
    fn load_hf_drops_encoder_keys_nonstrict() {
        let mut model = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let mut hf_sd: StateDict<f32> = model.state_dict();
        for foreign in ["encoder.conv_in.weight", "quant_conv.weight"] {
            hf_sd.insert(
                foreign.to_string(),
                ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
            );
        }

        let report = model.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            report.dropped,
            ["encoder.conv_in.weight", "quant_conv.weight"]
        );
    }

    /// Strict loading refuses a state dict that holds an encoder key.
    #[test]
    fn load_hf_strict_rejects_encoder_keys() {
        let mut model = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let hf_sd: StateDict<f32> = HashMap::from([(
            "encoder.conv_in.weight".to_string(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        )]);
        assert!(model.load_hf_state_dict(&hf_sd, true).is_err());
    }

    /// Keys carrying a `vae.` prefix load exactly like bare keys.
    #[test]
    fn load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let prefixed: StateDict<f32> = src
            .state_dict()
            .into_iter()
            .map(|(k, v)| (format!("vae.{k}"), v))
            .collect();

        let mut dst = VaeDecoder::<f32>::new(cfg).unwrap();
        let report = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(report.dropped.is_empty(), "got {:?}", report.dropped);

        let storage = TensorStorage::cpu(vec![0.01f32; 4]);
        let input = Tensor::from_storage(storage, vec![1, 4, 1, 1], false).unwrap();
        let expected = src.forward(&input).unwrap();
        let actual = dst.forward(&input).unwrap();
        assert!(expected
            .data()
            .unwrap()
            .iter()
            .zip(actual.data().unwrap().iter())
            .all(|(p, q)| (p - q).abs() < 1e-5));
    }

    // ---- encoder ---------------------------------------------------------

    /// Weights saved from one encoder and loaded into another must produce
    /// the same output, with nothing reported as dropped.
    #[test]
    fn round_trip_safetensors_into_encoder() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let (_guard, path) = write_safetensors(&src.state_dict());
        let (dst, report) = load_vae_encoder::<f32>(&path, cfg, false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "encoder round-trip should have empty drop list, got {:?}",
            report.dropped
        );

        let storage = TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]);
        let input = Tensor::from_storage(storage, vec![1, 3, 8, 8], false).unwrap();
        let expected = src.forward(&input).unwrap();
        let actual = dst.forward(&input).unwrap();
        assert!(expected
            .data()
            .unwrap()
            .iter()
            .zip(actual.data().unwrap().iter())
            .all(|(p, q)| (p - q).abs() < 1e-5));
    }

    /// Non-strict loading skips decoder-side keys and reports them sorted.
    #[test]
    fn encoder_load_hf_drops_decoder_keys_nonstrict() {
        let mut model = VaeEncoder::<f32>::new(tiny_cfg()).unwrap();
        let mut hf_sd: StateDict<f32> = model.state_dict();
        for foreign in ["decoder.conv_in.weight", "post_quant_conv.weight"] {
            hf_sd.insert(
                foreign.to_string(),
                ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
            );
        }

        let report = model.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            report.dropped,
            ["decoder.conv_in.weight", "post_quant_conv.weight"]
        );
    }

    /// Strict loading refuses a state dict that holds a decoder key.
    #[test]
    fn encoder_load_hf_strict_rejects_decoder_keys() {
        let mut model = VaeEncoder::<f32>::new(tiny_cfg()).unwrap();
        let hf_sd: StateDict<f32> = HashMap::from([(
            "decoder.conv_in.weight".to_string(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        )]);
        assert!(model.load_hf_state_dict(&hf_sd, true).is_err());
    }

    /// Keys carrying a `vae.` prefix load exactly like bare keys.
    #[test]
    fn encoder_load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let prefixed: StateDict<f32> = src
            .state_dict()
            .into_iter()
            .map(|(k, v)| (format!("vae.{k}"), v))
            .collect();

        let mut dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let report = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(report.dropped.is_empty(), "got {:?}", report.dropped);

        let storage = TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]);
        let input = Tensor::from_storage(storage, vec![1, 3, 8, 8], false).unwrap();
        let expected = src.forward(&input).unwrap();
        let actual = dst.forward(&input).unwrap();
        assert!(expected
            .data()
            .unwrap()
            .iter()
            .zip(actual.data().unwrap().iter())
            .all(|(p, q)| (p - q).abs() < 1e-5));
    }

    // ---- combined checkpoint ---------------------------------------------

    /// A state dict holding both halves can be loaded non-strictly by either
    /// module; each one drops only the other half's keys.
    #[test]
    fn full_vae_checkpoint_loadable_by_both_halves() {
        let cfg = tiny_cfg();
        let dec_src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let enc_src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();

        // Decoder entries first, then encoder entries, as one flat dict.
        let mut combined: StateDict<f32> = HashMap::new();
        combined.extend(dec_src.state_dict());
        combined.extend(enc_src.state_dict());

        let mut dec_dst = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let dec_rep = dec_dst.load_hf_state_dict(&combined, false).unwrap();
        let mut enc_dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let enc_rep = enc_dst.load_hf_state_dict(&combined, false).unwrap();

        for k in &dec_rep.dropped {
            assert!(
                k.starts_with("encoder.") || k.starts_with("quant_conv."),
                "decoder dropped unexpected key: {k}"
            );
        }
        for k in &enc_rep.dropped {
            assert!(
                k.starts_with("decoder.") || k.starts_with("post_quant_conv."),
                "encoder dropped unexpected key: {k}"
            );
        }
        assert!(
            !dec_rep.dropped.is_empty(),
            "decoder should have dropped some keys"
        );
        assert!(
            !enc_rep.dropped.is_empty(),
            "encoder should have dropped some keys"
        );
    }
}