1use std::collections::HashMap;
11use std::path::Path;
12
13use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
14use ferrotorch_nn::module::{Module, StateDict};
15use ferrotorch_serialize::load_safetensors;
16
17use crate::clip_text_encoder::{ClipTextConfig, ClipTextEncoder};
18use crate::config::VaeDecoderConfig;
19use crate::unet::UNet2DConditionModel;
20use crate::unet_config::UNet2DConditionConfig;
21use crate::vae::VaeDecoder;
22use crate::vae_encoder::{VaeEncoder, VaeEncoderConfig};
23
/// Keys from an input state dict that a `load_hf_state_dict` loader in this
/// module did not load into the target module.
///
/// Returned by the non-strict path of the HF loaders (in strict mode those
/// keys cause an error instead). `PartialEq`/`Eq` are derived so callers can
/// compare reports directly.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DropReport {
    /// Ignored keys, spelled exactly as they appeared in the input (including
    /// any `vae.` / `unet.` prefix), sorted ascending by the loaders.
    pub dropped: Vec<String>,
}
37
38impl<T: Float> VaeDecoder<T> {
39 pub fn load_hf_state_dict(
59 &mut self,
60 hf_state: &StateDict<T>,
61 strict: bool,
62 ) -> FerrotorchResult<DropReport> {
63 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
64 let mut dropped: Vec<String> = Vec::new();
65
66 for (k, v) in hf_state {
67 let after_vae = k
70 .strip_prefix("vae.")
71 .map_or_else(|| k.clone(), str::to_owned);
72 if after_vae.starts_with("post_quant_conv.") || after_vae.starts_with("decoder.") {
73 remapped.insert(after_vae, v.clone());
74 continue;
75 }
76 if strict {
77 return Err(FerrotorchError::InvalidArgument {
78 message: format!(
79 "VaeDecoder::load_hf_state_dict: key {k:?} is not under \
80 `post_quant_conv.*` / `decoder.*` (with optional `vae.` prefix) \
81 and strict mode is on. Pass strict=false to drop encoder / \
82 quant_conv keys."
83 ),
84 });
85 }
86 dropped.push(k.clone());
87 }
88 dropped.sort();
89 self.load_state_dict(&remapped, strict)?;
90 Ok(DropReport { dropped })
91 }
92}
93
94impl<T: Float> UNet2DConditionModel<T> {
99 pub fn load_hf_state_dict(
114 &mut self,
115 hf_state: &StateDict<T>,
116 strict: bool,
117 ) -> FerrotorchResult<DropReport> {
118 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
119 let mut dropped: Vec<String> = Vec::new();
120 for (k, v) in hf_state {
121 let after_unet = k
122 .strip_prefix("unet.")
123 .map_or_else(|| k.clone(), str::to_owned);
124 let is_unet_key = after_unet.starts_with("time_embedding.")
125 || after_unet.starts_with("conv_in.")
126 || after_unet.starts_with("down_blocks.")
127 || after_unet.starts_with("mid_block.")
128 || after_unet.starts_with("up_blocks.")
129 || after_unet.starts_with("conv_norm_out.")
130 || after_unet.starts_with("conv_out.");
131 if is_unet_key {
132 remapped.insert(after_unet, v.clone());
133 continue;
134 }
135 if strict {
136 return Err(FerrotorchError::InvalidArgument {
137 message: format!(
138 "UNet2DConditionModel::load_hf_state_dict: key {k:?} is not under \
139 a UNet prefix (with optional `unet.`) and strict mode is on."
140 ),
141 });
142 }
143 dropped.push(k.clone());
144 }
145 dropped.sort();
146 self.load_state_dict(&remapped, strict)?;
147 Ok(DropReport { dropped })
148 }
149}
150
151pub fn load_unet<T: Float>(
164 weights_path: &Path,
165 cfg: UNet2DConditionConfig,
166 strict: bool,
167) -> FerrotorchResult<(UNet2DConditionModel<T>, DropReport)> {
168 let state =
169 load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
170 message: format!(
171 "load_unet: failed to decode safetensors {}: {e}",
172 weights_path.display()
173 ),
174 })?;
175 let mut unet = UNet2DConditionModel::<T>::new(cfg)?;
176 let report = unet.load_hf_state_dict(&state, strict)?;
177 Ok((unet, report))
178}
179
180fn load_safetensors_clip_filtered<T: Float>(
192 weights_path: &Path,
193) -> FerrotorchResult<(StateDict<T>, bool)> {
194 use safetensors::SafeTensors;
195
196 let bytes = std::fs::read(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
197 message: format!(
198 "load_safetensors_clip_filtered: failed to read {}: {e}",
199 weights_path.display()
200 ),
201 })?;
202 let st = SafeTensors::deserialize(&bytes).map_err(|e| FerrotorchError::InvalidArgument {
203 message: format!(
204 "load_safetensors_clip_filtered: failed to parse {}: {e}",
205 weights_path.display()
206 ),
207 })?;
208 let mut keep: Vec<String> = Vec::new();
209 let mut had_position_ids = false;
210 for k in st.names() {
211 let s: &str = k.as_str();
212 if s == "embeddings.position_ids" || s == "text_model.embeddings.position_ids" {
215 had_position_ids = true;
216 continue;
217 }
218 keep.push(String::from(s));
219 }
220
221 let mut subset: Vec<(String, safetensors::tensor::TensorView<'_>)> =
225 Vec::with_capacity(keep.len());
226 for k in &keep {
227 let v = st.tensor(k).map_err(|e| FerrotorchError::InvalidArgument {
228 message: format!(
229 "load_safetensors_clip_filtered: missing tensor {k:?} after filter: {e}"
230 ),
231 })?;
232 subset.push((k.clone(), v));
233 }
234 let serialized =
235 safetensors::serialize(subset, &None).map_err(|e| FerrotorchError::InvalidArgument {
236 message: format!("load_safetensors_clip_filtered: re-serialize failed: {e}"),
237 })?;
238 let tmp = tempfile::NamedTempFile::new().map_err(|e| FerrotorchError::InvalidArgument {
239 message: format!("load_safetensors_clip_filtered: tempfile: {e}"),
240 })?;
241 std::fs::write(tmp.path(), &serialized).map_err(|e| FerrotorchError::InvalidArgument {
242 message: format!("load_safetensors_clip_filtered: tempfile write: {e}"),
243 })?;
244 let state = load_safetensors::<T>(tmp.path())?;
245 Ok((state, had_position_ids))
246}
247
248pub fn load_clip_text_encoder<T: Float>(
271 weights_path: &Path,
272 cfg: ClipTextConfig,
273 strict: bool,
274) -> FerrotorchResult<(ClipTextEncoder<T>, DropReport)> {
275 let (mut state, had_position_ids) =
276 load_safetensors_clip_filtered::<T>(weights_path).map_err(|e| {
277 FerrotorchError::InvalidArgument {
278 message: format!(
279 "load_clip_text_encoder: failed to decode safetensors {}: {e}",
280 weights_path.display()
281 ),
282 }
283 })?;
284
285 if had_position_ids {
291 let key = if state.keys().any(|k| k.starts_with("text_model.")) {
292 "text_model.embeddings.position_ids".to_string()
293 } else {
294 "embeddings.position_ids".to_string()
295 };
296 state.insert(key, ferrotorch_core::zeros::<T>(&[1])?);
297 }
298
299 let mut enc = ClipTextEncoder::<T>::new(cfg)?;
300 let report = enc.load_hf_state_dict(&state, strict)?;
301 Ok((enc, report))
302}
303
304pub fn load_vae_decoder<T: Float>(
319 weights_path: &Path,
320 cfg: VaeDecoderConfig,
321 strict: bool,
322) -> FerrotorchResult<(VaeDecoder<T>, DropReport)> {
323 let state =
324 load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
325 message: format!(
326 "load_vae_decoder: failed to decode safetensors {}: {e}",
327 weights_path.display()
328 ),
329 })?;
330 let mut decoder = VaeDecoder::<T>::new(cfg)?;
331 let report = decoder.load_hf_state_dict(&state, strict)?;
332 Ok((decoder, report))
333}
334
335impl<T: Float> VaeEncoder<T> {
340 pub fn load_hf_state_dict(
364 &mut self,
365 hf_state: &StateDict<T>,
366 strict: bool,
367 ) -> FerrotorchResult<DropReport> {
368 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
369 let mut dropped: Vec<String> = Vec::new();
370
371 for (k, v) in hf_state {
372 let after_vae = k
375 .strip_prefix("vae.")
376 .map_or_else(|| k.clone(), str::to_owned);
377 if after_vae.starts_with("encoder.") || after_vae.starts_with("quant_conv.") {
378 remapped.insert(after_vae, v.clone());
379 continue;
380 }
381 if strict {
382 return Err(FerrotorchError::InvalidArgument {
383 message: format!(
384 "VaeEncoder::load_hf_state_dict: key {k:?} is not under \
385 `encoder.*` / `quant_conv.*` (with optional `vae.` prefix) \
386 and strict mode is on. Pass strict=false to drop decoder / \
387 post_quant_conv keys."
388 ),
389 });
390 }
391 dropped.push(k.clone());
392 }
393 dropped.sort();
394 self.load_state_dict(&remapped, strict)?;
395 Ok(DropReport { dropped })
396 }
397}
398
399pub fn load_vae_encoder<T: Float>(
414 weights_path: &Path,
415 cfg: VaeEncoderConfig,
416 strict: bool,
417) -> FerrotorchResult<(VaeEncoder<T>, DropReport)> {
418 let state =
419 load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
420 message: format!(
421 "load_vae_encoder: failed to decode safetensors {}: {e}",
422 weights_path.display()
423 ),
424 })?;
425 let mut encoder = VaeEncoder::<T>::new(cfg)?;
426 let report = encoder.load_hf_state_dict(&state, strict)?;
427 Ok((encoder, report))
428}
429
// Tests for the HF state-dict loaders: safetensors round-trips, `vae.` prefix
// stripping, and the strict / non-strict handling of keys belonging to the
// other VAE half.
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};
    use ferrotorch_serialize::save_safetensors;
    use std::path::PathBuf;

    // Small config shared by every test, so construction and forward passes
    // stay cheap. It is also passed to `VaeEncoder::new` / `load_vae_encoder`.
    // NOTE(review): that only type-checks if `VaeEncoderConfig` is the same
    // type as `VaeDecoderConfig` (e.g. a type alias) — confirm.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    // Writes `v.state_dict()` to `model.safetensors` inside a fresh temp dir.
    // The `TempDir` guard is returned as well: the directory is removed when
    // the guard is dropped, so callers keep it alive (bound to `_d`).
    fn tmp_safetensors_from(v: &VaeDecoder<f32>) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.safetensors");
        let sd = v.state_dict();
        save_safetensors(&sd, &path).unwrap();
        (dir, path)
    }

    // Save a decoder to disk and reload it via `load_vae_decoder`. Nothing may
    // be dropped, and both decoders must agree (within 1e-5) on a 1x4x1x1 input.
    #[test]
    fn round_trip_safetensors_into_decoder() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let (_d, p) = tmp_safetensors_from(&src);
        let (dst, report) = load_vae_decoder::<f32>(&p, cfg.clone(), false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "round-trip should have empty drop list, got {:?}",
            report.dropped
        );
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5);
        }
    }

    // Non-strict: extra `encoder.*` / `quant_conv.*` keys are dropped and
    // reported in sorted order. Their [4, 4] shapes are arbitrary because they
    // are never loaded.
    #[test]
    fn load_hf_drops_encoder_keys_nonstrict() {
        let cfg = tiny_cfg();
        let mut v = VaeDecoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        hf_sd.insert(
            "quant_conv.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec![
                "encoder.conv_in.weight".to_string(),
                "quant_conv.weight".to_string(),
            ]
        );
    }

    // Strict: a single key outside the decoder prefixes makes the load fail.
    #[test]
    fn load_hf_strict_rejects_encoder_keys() {
        let cfg = tiny_cfg();
        let mut v = VaeDecoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = HashMap::new();
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    // Every key prefixed with `vae.` must still load (nothing dropped), and the
    // reloaded decoder must reproduce the source decoder's output.
    #[test]
    fn load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let bare = src.state_dict();
        let mut prefixed: StateDict<f32> = HashMap::new();
        for (k, v) in bare {
            prefixed.insert(format!("vae.{k}"), v);
        }
        let mut dst = VaeDecoder::<f32>::new(cfg).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5);
        }
    }

    // Encoder counterpart of `tmp_safetensors_from`. Keep the returned
    // `TempDir` alive for as long as the path is used.
    fn tmp_encoder_safetensors_from(v: &VaeEncoder<f32>) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.safetensors");
        let sd = v.state_dict();
        save_safetensors(&sd, &path).unwrap();
        (dir, path)
    }

    // Encoder round-trip through `load_vae_encoder` on a 1x3x8x8 image:
    // nothing dropped, outputs equal within 1e-5.
    #[test]
    fn round_trip_safetensors_into_encoder() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let (_d, p) = tmp_encoder_safetensors_from(&src);
        let (dst, report) = load_vae_encoder::<f32>(&p, cfg.clone(), false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "encoder round-trip should have empty drop list, got {:?}",
            report.dropped
        );
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5);
        }
    }

    // Non-strict encoder load: `decoder.*` / `post_quant_conv.*` keys are
    // dropped and reported in sorted order.
    #[test]
    fn encoder_load_hf_drops_decoder_keys_nonstrict() {
        let cfg = tiny_cfg();
        let mut v = VaeEncoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        hf_sd.insert(
            "decoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        hf_sd.insert(
            "post_quant_conv.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec![
                "decoder.conv_in.weight".to_string(),
                "post_quant_conv.weight".to_string(),
            ]
        );
    }

    // Strict encoder load: any decoder-side key makes the load fail.
    #[test]
    fn encoder_load_hf_strict_rejects_decoder_keys() {
        let cfg = tiny_cfg();
        let mut v = VaeEncoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = HashMap::new();
        hf_sd.insert(
            "decoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    // Encoder counterpart of `load_hf_strips_vae_prefix`.
    #[test]
    fn encoder_load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let bare = src.state_dict();
        let mut prefixed: StateDict<f32> = HashMap::new();
        for (k, v) in bare {
            prefixed.insert(format!("vae.{k}"), v);
        }
        let mut dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5);
        }
    }

    // A combined decoder+encoder state dict (no `vae.` prefix) must be
    // loadable by each half in non-strict mode. Each half may drop only the
    // other half's keys, and both must drop at least one.
    #[test]
    fn full_vae_checkpoint_loadable_by_both_halves() {
        let cfg = tiny_cfg();
        let dec_src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let enc_src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();

        let mut combined: StateDict<f32> = HashMap::new();
        for (k, v) in dec_src.state_dict() {
            combined.insert(k, v);
        }
        for (k, v) in enc_src.state_dict() {
            combined.insert(k, v);
        }

        let mut dec_dst = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let dec_rep = dec_dst.load_hf_state_dict(&combined, false).unwrap();
        let mut enc_dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let enc_rep = enc_dst.load_hf_state_dict(&combined, false).unwrap();

        // The decoder may only discard encoder-side keys.
        for k in &dec_rep.dropped {
            assert!(
                k.starts_with("encoder.") || k.starts_with("quant_conv."),
                "decoder dropped unexpected key: {k}"
            );
        }
        // The encoder may only discard decoder-side keys.
        for k in &enc_rep.dropped {
            assert!(
                k.starts_with("decoder.") || k.starts_with("post_quant_conv."),
                "encoder dropped unexpected key: {k}"
            );
        }
        assert!(
            !dec_rep.dropped.is_empty(),
            "decoder should have dropped some keys"
        );
        assert!(
            !enc_rep.dropped.is_empty(),
            "encoder should have dropped some keys"
        );
    }
}