1use std::collections::HashMap;
11use std::path::Path;
12
13use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
14use ferrotorch_nn::module::{Module, StateDict};
15use ferrotorch_serialize::load_safetensors;
16
17use crate::clip_text_encoder::{ClipTextConfig, ClipTextEncoder};
18use crate::config::VaeDecoderConfig;
19use crate::unet::UNet2DConditionModel;
20use crate::unet_config::UNet2DConditionConfig;
21use crate::vae::VaeDecoder;
22use crate::vae_encoder::{VaeEncoder, VaeEncoderConfig};
23
/// Checkpoint keys that a `load_hf_state_dict` call skipped instead of loading.
///
/// The loaders in this module return one of these so callers can audit which
/// tensors from a checkpoint were not consumed by the target module. In strict
/// mode those loaders fail instead of skipping, so a successful strict load
/// always yields an empty report.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DropReport {
    /// Original (un-stripped) checkpoint keys that were skipped. The loaders
    /// in this module sort this list lexicographically before returning it.
    pub dropped: Vec<String>,
}
37
38impl<T: Float> VaeDecoder<T> {
39 pub fn load_hf_state_dict(
59 &mut self,
60 hf_state: &StateDict<T>,
61 strict: bool,
62 ) -> FerrotorchResult<DropReport> {
63 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
64 let mut dropped: Vec<String> = Vec::new();
65
66 for (k, v) in hf_state {
67 let after_vae = k.strip_prefix("vae.").map_or_else(|| k.clone(), str::to_owned);
70 if after_vae.starts_with("post_quant_conv.") || after_vae.starts_with("decoder.") {
71 remapped.insert(after_vae, v.clone());
72 continue;
73 }
74 if strict {
75 return Err(FerrotorchError::InvalidArgument {
76 message: format!(
77 "VaeDecoder::load_hf_state_dict: key {k:?} is not under \
78 `post_quant_conv.*` / `decoder.*` (with optional `vae.` prefix) \
79 and strict mode is on. Pass strict=false to drop encoder / \
80 quant_conv keys."
81 ),
82 });
83 }
84 dropped.push(k.clone());
85 }
86 dropped.sort();
87 self.load_state_dict(&remapped, strict)?;
88 Ok(DropReport { dropped })
89 }
90}
91
92impl<T: Float> UNet2DConditionModel<T> {
97 pub fn load_hf_state_dict(
112 &mut self,
113 hf_state: &StateDict<T>,
114 strict: bool,
115 ) -> FerrotorchResult<DropReport> {
116 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
117 let mut dropped: Vec<String> = Vec::new();
118 for (k, v) in hf_state {
119 let after_unet = k.strip_prefix("unet.").map_or_else(|| k.clone(), str::to_owned);
120 let is_unet_key = after_unet.starts_with("time_embedding.")
121 || after_unet.starts_with("conv_in.")
122 || after_unet.starts_with("down_blocks.")
123 || after_unet.starts_with("mid_block.")
124 || after_unet.starts_with("up_blocks.")
125 || after_unet.starts_with("conv_norm_out.")
126 || after_unet.starts_with("conv_out.");
127 if is_unet_key {
128 remapped.insert(after_unet, v.clone());
129 continue;
130 }
131 if strict {
132 return Err(FerrotorchError::InvalidArgument {
133 message: format!(
134 "UNet2DConditionModel::load_hf_state_dict: key {k:?} is not under \
135 a UNet prefix (with optional `unet.`) and strict mode is on."
136 ),
137 });
138 }
139 dropped.push(k.clone());
140 }
141 dropped.sort();
142 self.load_state_dict(&remapped, strict)?;
143 Ok(DropReport { dropped })
144 }
145}
146
147pub fn load_unet<T: Float>(
160 weights_path: &Path,
161 cfg: UNet2DConditionConfig,
162 strict: bool,
163) -> FerrotorchResult<(UNet2DConditionModel<T>, DropReport)> {
164 let state =
165 load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
166 message: format!(
167 "load_unet: failed to decode safetensors {}: {e}",
168 weights_path.display()
169 ),
170 })?;
171 let mut unet = UNet2DConditionModel::<T>::new(cfg)?;
172 let report = unet.load_hf_state_dict(&state, strict)?;
173 Ok((unet, report))
174}
175
176fn load_safetensors_clip_filtered<T: Float>(
188 weights_path: &Path,
189) -> FerrotorchResult<(StateDict<T>, bool)> {
190 use safetensors::SafeTensors;
191
192 let bytes =
193 std::fs::read(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
194 message: format!(
195 "load_safetensors_clip_filtered: failed to read {}: {e}",
196 weights_path.display()
197 ),
198 })?;
199 let st = SafeTensors::deserialize(&bytes).map_err(|e| FerrotorchError::InvalidArgument {
200 message: format!(
201 "load_safetensors_clip_filtered: failed to parse {}: {e}",
202 weights_path.display()
203 ),
204 })?;
205 let mut keep: Vec<String> = Vec::new();
206 let mut had_position_ids = false;
207 for k in st.names() {
208 let s: &str = k.as_str();
209 if s == "embeddings.position_ids" || s == "text_model.embeddings.position_ids" {
212 had_position_ids = true;
213 continue;
214 }
215 keep.push(String::from(s));
216 }
217
218 let mut subset: Vec<(String, safetensors::tensor::TensorView<'_>)> =
222 Vec::with_capacity(keep.len());
223 for k in &keep {
224 let v = st.tensor(k).map_err(|e| FerrotorchError::InvalidArgument {
225 message: format!(
226 "load_safetensors_clip_filtered: missing tensor {k:?} after filter: {e}"
227 ),
228 })?;
229 subset.push((k.clone(), v));
230 }
231 let serialized = safetensors::serialize(subset, &None).map_err(|e| {
232 FerrotorchError::InvalidArgument {
233 message: format!("load_safetensors_clip_filtered: re-serialize failed: {e}"),
234 }
235 })?;
236 let tmp = tempfile::NamedTempFile::new().map_err(|e| FerrotorchError::InvalidArgument {
237 message: format!("load_safetensors_clip_filtered: tempfile: {e}"),
238 })?;
239 std::fs::write(tmp.path(), &serialized).map_err(|e| FerrotorchError::InvalidArgument {
240 message: format!("load_safetensors_clip_filtered: tempfile write: {e}"),
241 })?;
242 let state = load_safetensors::<T>(tmp.path())?;
243 Ok((state, had_position_ids))
244}
245
246pub fn load_clip_text_encoder<T: Float>(
269 weights_path: &Path,
270 cfg: ClipTextConfig,
271 strict: bool,
272) -> FerrotorchResult<(ClipTextEncoder<T>, DropReport)> {
273 let (mut state, had_position_ids) =
274 load_safetensors_clip_filtered::<T>(weights_path).map_err(|e| {
275 FerrotorchError::InvalidArgument {
276 message: format!(
277 "load_clip_text_encoder: failed to decode safetensors {}: {e}",
278 weights_path.display()
279 ),
280 }
281 })?;
282
283 if had_position_ids {
289 let key = if state
290 .keys()
291 .any(|k| k.starts_with("text_model."))
292 {
293 "text_model.embeddings.position_ids".to_string()
294 } else {
295 "embeddings.position_ids".to_string()
296 };
297 state.insert(key, ferrotorch_core::zeros::<T>(&[1])?);
298 }
299
300 let mut enc = ClipTextEncoder::<T>::new(cfg)?;
301 let report = enc.load_hf_state_dict(&state, strict)?;
302 Ok((enc, report))
303}
304
305pub fn load_vae_decoder<T: Float>(
320 weights_path: &Path,
321 cfg: VaeDecoderConfig,
322 strict: bool,
323) -> FerrotorchResult<(VaeDecoder<T>, DropReport)> {
324 let state =
325 load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
326 message: format!(
327 "load_vae_decoder: failed to decode safetensors {}: {e}",
328 weights_path.display()
329 ),
330 })?;
331 let mut decoder = VaeDecoder::<T>::new(cfg)?;
332 let report = decoder.load_hf_state_dict(&state, strict)?;
333 Ok((decoder, report))
334}
335
336impl<T: Float> VaeEncoder<T> {
341 pub fn load_hf_state_dict(
365 &mut self,
366 hf_state: &StateDict<T>,
367 strict: bool,
368 ) -> FerrotorchResult<DropReport> {
369 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
370 let mut dropped: Vec<String> = Vec::new();
371
372 for (k, v) in hf_state {
373 let after_vae = k.strip_prefix("vae.").map_or_else(|| k.clone(), str::to_owned);
376 if after_vae.starts_with("encoder.") || after_vae.starts_with("quant_conv.") {
377 remapped.insert(after_vae, v.clone());
378 continue;
379 }
380 if strict {
381 return Err(FerrotorchError::InvalidArgument {
382 message: format!(
383 "VaeEncoder::load_hf_state_dict: key {k:?} is not under \
384 `encoder.*` / `quant_conv.*` (with optional `vae.` prefix) \
385 and strict mode is on. Pass strict=false to drop decoder / \
386 post_quant_conv keys."
387 ),
388 });
389 }
390 dropped.push(k.clone());
391 }
392 dropped.sort();
393 self.load_state_dict(&remapped, strict)?;
394 Ok(DropReport { dropped })
395 }
396}
397
398pub fn load_vae_encoder<T: Float>(
413 weights_path: &Path,
414 cfg: VaeEncoderConfig,
415 strict: bool,
416) -> FerrotorchResult<(VaeEncoder<T>, DropReport)> {
417 let state =
418 load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
419 message: format!(
420 "load_vae_encoder: failed to decode safetensors {}: {e}",
421 weights_path.display()
422 ),
423 })?;
424 let mut encoder = VaeEncoder::<T>::new(cfg)?;
425 let report = encoder.load_hf_state_dict(&state, strict)?;
426 Ok((encoder, report))
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};
    use ferrotorch_serialize::save_safetensors;
    use std::path::PathBuf;

    /// Tiny VAE config (4 latent channels, 3 image channels, sample size 8)
    /// shared by the decoder and encoder tests.
    ///
    /// NOTE(review): it is also passed to `VaeEncoder::new`, so
    /// `VaeEncoderConfig` is presumably the same type (or an alias); confirm.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// Assert that two forward outputs are non-empty, have the same length,
    /// and agree element-wise to within 1e-5.
    ///
    /// `zip` stops at the shorter input, so without the length checks an empty
    /// or truncated output would make the comparison pass vacuously.
    fn assert_outputs_close(a: &[f32], b: &[f32]) {
        assert!(!a.is_empty(), "forward output is empty");
        assert_eq!(a.len(), b.len(), "forward outputs differ in length");
        for (i, (x, y)) in a.iter().zip(b).enumerate() {
            assert!((x - y).abs() < 1e-5, "outputs differ at index {i}: {x} vs {y}");
        }
    }

    /// Save `v`'s state dict to `model.safetensors` in a fresh temp dir.
    /// The returned `TempDir` must be kept alive while the path is used.
    fn tmp_safetensors_from(v: &VaeDecoder<f32>) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.safetensors");
        let sd = v.state_dict();
        save_safetensors(&sd, &path).unwrap();
        (dir, path)
    }

    #[test]
    fn round_trip_safetensors_into_decoder() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let (_d, p) = tmp_safetensors_from(&src);
        let (dst, report) = load_vae_decoder::<f32>(&p, cfg.clone(), false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "round-trip should have empty drop list, got {:?}",
            report.dropped
        );
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_outputs_close(&a.data().unwrap(), &b.data().unwrap());
    }

    #[test]
    fn load_hf_drops_encoder_keys_nonstrict() {
        let cfg = tiny_cfg();
        let mut v = VaeDecoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        // Keys a full VAE checkpoint carries that the decoder must skip.
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        hf_sd.insert(
            "quant_conv.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec![
                "encoder.conv_in.weight".to_string(),
                "quant_conv.weight".to_string(),
            ]
        );
    }

    #[test]
    fn load_hf_strict_rejects_encoder_keys() {
        let cfg = tiny_cfg();
        let mut v = VaeDecoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = HashMap::new();
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    #[test]
    fn load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let prefixed: StateDict<f32> = src
            .state_dict()
            .into_iter()
            .map(|(k, v)| (format!("vae.{k}"), v))
            .collect();
        let mut dst = VaeDecoder::<f32>::new(cfg).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_outputs_close(&a.data().unwrap(), &b.data().unwrap());
    }

    /// Save `v`'s state dict to `model.safetensors` in a fresh temp dir.
    /// The returned `TempDir` must be kept alive while the path is used.
    fn tmp_encoder_safetensors_from(v: &VaeEncoder<f32>) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.safetensors");
        let sd = v.state_dict();
        save_safetensors(&sd, &path).unwrap();
        (dir, path)
    }

    #[test]
    fn round_trip_safetensors_into_encoder() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let (_d, p) = tmp_encoder_safetensors_from(&src);
        let (dst, report) = load_vae_encoder::<f32>(&p, cfg.clone(), false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "encoder round-trip should have empty drop list, got {:?}",
            report.dropped
        );
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_outputs_close(&a.data().unwrap(), &b.data().unwrap());
    }

    #[test]
    fn encoder_load_hf_drops_decoder_keys_nonstrict() {
        let cfg = tiny_cfg();
        let mut v = VaeEncoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        // Keys a full VAE checkpoint carries that the encoder must skip.
        hf_sd.insert(
            "decoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        hf_sd.insert(
            "post_quant_conv.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec![
                "decoder.conv_in.weight".to_string(),
                "post_quant_conv.weight".to_string(),
            ]
        );
    }

    #[test]
    fn encoder_load_hf_strict_rejects_decoder_keys() {
        let cfg = tiny_cfg();
        let mut v = VaeEncoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = HashMap::new();
        hf_sd.insert(
            "decoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    #[test]
    fn encoder_load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let prefixed: StateDict<f32> = src
            .state_dict()
            .into_iter()
            .map(|(k, v)| (format!("vae.{k}"), v))
            .collect();
        let mut dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_outputs_close(&a.data().unwrap(), &b.data().unwrap());
    }

    #[test]
    fn full_vae_checkpoint_loadable_by_both_halves() {
        let cfg = tiny_cfg();
        // Build a combined decoder + encoder state dict, as a full VAE
        // checkpoint would contain, and load it into each half.
        let dec_src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let enc_src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();

        let mut combined: StateDict<f32> = HashMap::new();
        combined.extend(dec_src.state_dict());
        combined.extend(enc_src.state_dict());

        let mut dec_dst = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let dec_rep = dec_dst.load_hf_state_dict(&combined, false).unwrap();
        let mut enc_dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let enc_rep = enc_dst.load_hf_state_dict(&combined, false).unwrap();

        // Each half must drop only the other half's keys.
        for k in &dec_rep.dropped {
            assert!(
                k.starts_with("encoder.") || k.starts_with("quant_conv."),
                "decoder dropped unexpected key: {k}"
            );
        }
        for k in &enc_rep.dropped {
            assert!(
                k.starts_with("decoder.") || k.starts_with("post_quant_conv."),
                "encoder dropped unexpected key: {k}"
            );
        }
        assert!(!dec_rep.dropped.is_empty(), "decoder should have dropped some keys");
        assert!(!enc_rep.dropped.is_empty(), "encoder should have dropped some keys");
    }
}