1use std::collections::HashMap;
22use std::path::Path;
23
24use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
25use ferrotorch_nn::module::{Module, StateDict};
26use ferrotorch_serialize::load_safetensors;
27
28use crate::clip_text_encoder::{ClipTextConfig, ClipTextEncoder};
29use crate::config::VaeDecoderConfig;
30use crate::unet::UNet2DConditionModel;
31use crate::unet_config::UNet2DConditionConfig;
32use crate::vae::VaeDecoder;
33use crate::vae_encoder::{VaeEncoder, VaeEncoderConfig};
34
/// Outcome of a `load_hf_state_dict` call: the checkpoint keys that were
/// skipped instead of being loaded into the module.
#[derive(Clone, Debug, Default)]
pub struct DropReport {
    /// Original (un-stripped) names of the skipped checkpoint keys. The
    /// loaders defined in this file sort this list before returning it.
    pub dropped: Vec<String>,
}
48
49impl<T: Float> VaeDecoder<T> {
50 pub fn load_hf_state_dict(
70 &mut self,
71 hf_state: &StateDict<T>,
72 strict: bool,
73 ) -> FerrotorchResult<DropReport> {
74 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
75 let mut dropped: Vec<String> = Vec::new();
76
77 for (k, v) in hf_state {
78 let after_vae = k
81 .strip_prefix("vae.")
82 .map_or_else(|| k.clone(), str::to_owned);
83 if after_vae.starts_with("post_quant_conv.") || after_vae.starts_with("decoder.") {
84 remapped.insert(after_vae, v.clone());
85 continue;
86 }
87 if strict {
88 return Err(FerrotorchError::InvalidArgument {
89 message: format!(
90 "VaeDecoder::load_hf_state_dict: key {k:?} is not under \
91 `post_quant_conv.*` / `decoder.*` (with optional `vae.` prefix) \
92 and strict mode is on. Pass strict=false to drop encoder / \
93 quant_conv keys."
94 ),
95 });
96 }
97 dropped.push(k.clone());
98 }
99 dropped.sort();
100 self.load_state_dict(&remapped, strict)?;
101 Ok(DropReport { dropped })
102 }
103}
104
105impl<T: Float> UNet2DConditionModel<T> {
110 pub fn load_hf_state_dict(
125 &mut self,
126 hf_state: &StateDict<T>,
127 strict: bool,
128 ) -> FerrotorchResult<DropReport> {
129 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
130 let mut dropped: Vec<String> = Vec::new();
131 for (k, v) in hf_state {
132 let after_unet = k
133 .strip_prefix("unet.")
134 .map_or_else(|| k.clone(), str::to_owned);
135 let is_unet_key = after_unet.starts_with("time_embedding.")
136 || after_unet.starts_with("conv_in.")
137 || after_unet.starts_with("down_blocks.")
138 || after_unet.starts_with("mid_block.")
139 || after_unet.starts_with("up_blocks.")
140 || after_unet.starts_with("conv_norm_out.")
141 || after_unet.starts_with("conv_out.");
142 if is_unet_key {
143 remapped.insert(after_unet, v.clone());
144 continue;
145 }
146 if strict {
147 return Err(FerrotorchError::InvalidArgument {
148 message: format!(
149 "UNet2DConditionModel::load_hf_state_dict: key {k:?} is not under \
150 a UNet prefix (with optional `unet.`) and strict mode is on."
151 ),
152 });
153 }
154 dropped.push(k.clone());
155 }
156 dropped.sort();
157 self.load_state_dict(&remapped, strict)?;
158 Ok(DropReport { dropped })
159 }
160}
161
162pub fn load_unet<T: Float>(
175 weights_path: &Path,
176 cfg: UNet2DConditionConfig,
177 strict: bool,
178) -> FerrotorchResult<(UNet2DConditionModel<T>, DropReport)> {
179 let state =
180 load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
181 message: format!(
182 "load_unet: failed to decode safetensors {}: {e}",
183 weights_path.display()
184 ),
185 })?;
186 let mut unet = UNet2DConditionModel::<T>::new(cfg)?;
187 let report = unet.load_hf_state_dict(&state, strict)?;
188 Ok((unet, report))
189}
190
191fn load_safetensors_clip_filtered<T: Float>(
203 weights_path: &Path,
204) -> FerrotorchResult<(StateDict<T>, bool)> {
205 use safetensors::SafeTensors;
206
207 let bytes = std::fs::read(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
208 message: format!(
209 "load_safetensors_clip_filtered: failed to read {}: {e}",
210 weights_path.display()
211 ),
212 })?;
213 let st = SafeTensors::deserialize(&bytes).map_err(|e| FerrotorchError::InvalidArgument {
214 message: format!(
215 "load_safetensors_clip_filtered: failed to parse {}: {e}",
216 weights_path.display()
217 ),
218 })?;
219 let mut keep: Vec<String> = Vec::new();
220 let mut had_position_ids = false;
221 for k in st.names() {
222 let s: &str = k.as_str();
223 if s == "embeddings.position_ids" || s == "text_model.embeddings.position_ids" {
226 had_position_ids = true;
227 continue;
228 }
229 keep.push(String::from(s));
230 }
231
232 let mut subset: Vec<(String, safetensors::tensor::TensorView<'_>)> =
236 Vec::with_capacity(keep.len());
237 for k in &keep {
238 let v = st.tensor(k).map_err(|e| FerrotorchError::InvalidArgument {
239 message: format!(
240 "load_safetensors_clip_filtered: missing tensor {k:?} after filter: {e}"
241 ),
242 })?;
243 subset.push((k.clone(), v));
244 }
245 let serialized =
246 safetensors::serialize(subset, &None).map_err(|e| FerrotorchError::InvalidArgument {
247 message: format!("load_safetensors_clip_filtered: re-serialize failed: {e}"),
248 })?;
249 let tmp = tempfile::NamedTempFile::new().map_err(|e| FerrotorchError::InvalidArgument {
250 message: format!("load_safetensors_clip_filtered: tempfile: {e}"),
251 })?;
252 std::fs::write(tmp.path(), &serialized).map_err(|e| FerrotorchError::InvalidArgument {
253 message: format!("load_safetensors_clip_filtered: tempfile write: {e}"),
254 })?;
255 let state = load_safetensors::<T>(tmp.path())?;
256 Ok((state, had_position_ids))
257}
258
259pub fn load_clip_text_encoder<T: Float>(
282 weights_path: &Path,
283 cfg: ClipTextConfig,
284 strict: bool,
285) -> FerrotorchResult<(ClipTextEncoder<T>, DropReport)> {
286 let (mut state, had_position_ids) =
287 load_safetensors_clip_filtered::<T>(weights_path).map_err(|e| {
288 FerrotorchError::InvalidArgument {
289 message: format!(
290 "load_clip_text_encoder: failed to decode safetensors {}: {e}",
291 weights_path.display()
292 ),
293 }
294 })?;
295
296 if had_position_ids {
302 let key = if state.keys().any(|k| k.starts_with("text_model.")) {
303 "text_model.embeddings.position_ids".to_string()
304 } else {
305 "embeddings.position_ids".to_string()
306 };
307 state.insert(key, ferrotorch_core::zeros::<T>(&[1])?);
308 }
309
310 let mut enc = ClipTextEncoder::<T>::new(cfg)?;
311 let report = enc.load_hf_state_dict(&state, strict)?;
312 Ok((enc, report))
313}
314
315pub fn load_vae_decoder<T: Float>(
330 weights_path: &Path,
331 cfg: VaeDecoderConfig,
332 strict: bool,
333) -> FerrotorchResult<(VaeDecoder<T>, DropReport)> {
334 let state =
335 load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
336 message: format!(
337 "load_vae_decoder: failed to decode safetensors {}: {e}",
338 weights_path.display()
339 ),
340 })?;
341 let mut decoder = VaeDecoder::<T>::new(cfg)?;
342 let report = decoder.load_hf_state_dict(&state, strict)?;
343 Ok((decoder, report))
344}
345
346impl<T: Float> VaeEncoder<T> {
351 pub fn load_hf_state_dict(
375 &mut self,
376 hf_state: &StateDict<T>,
377 strict: bool,
378 ) -> FerrotorchResult<DropReport> {
379 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
380 let mut dropped: Vec<String> = Vec::new();
381
382 for (k, v) in hf_state {
383 let after_vae = k
386 .strip_prefix("vae.")
387 .map_or_else(|| k.clone(), str::to_owned);
388 if after_vae.starts_with("encoder.") || after_vae.starts_with("quant_conv.") {
389 remapped.insert(after_vae, v.clone());
390 continue;
391 }
392 if strict {
393 return Err(FerrotorchError::InvalidArgument {
394 message: format!(
395 "VaeEncoder::load_hf_state_dict: key {k:?} is not under \
396 `encoder.*` / `quant_conv.*` (with optional `vae.` prefix) \
397 and strict mode is on. Pass strict=false to drop decoder / \
398 post_quant_conv keys."
399 ),
400 });
401 }
402 dropped.push(k.clone());
403 }
404 dropped.sort();
405 self.load_state_dict(&remapped, strict)?;
406 Ok(DropReport { dropped })
407 }
408}
409
410pub fn load_vae_encoder<T: Float>(
425 weights_path: &Path,
426 cfg: VaeEncoderConfig,
427 strict: bool,
428) -> FerrotorchResult<(VaeEncoder<T>, DropReport)> {
429 let state =
430 load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
431 message: format!(
432 "load_vae_encoder: failed to decode safetensors {}: {e}",
433 weights_path.display()
434 ),
435 })?;
436 let mut encoder = VaeEncoder::<T>::new(cfg)?;
437 let report = encoder.load_hf_state_dict(&state, strict)?;
438 Ok((encoder, report))
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};
    use ferrotorch_serialize::save_safetensors;
    use std::path::PathBuf;

    /// Small configuration shared by every test in this module.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// Saves the decoder's state dict to `model.safetensors` in a fresh temp
    /// directory. The directory handle is returned so the caller keeps the
    /// file alive for as long as it needs the path.
    fn tmp_safetensors_from(v: &VaeDecoder<f32>) -> (tempfile::TempDir, PathBuf) {
        let scratch = tempfile::tempdir().unwrap();
        let file = scratch.path().join("model.safetensors");
        save_safetensors(&v.state_dict(), &file).unwrap();
        (scratch, file)
    }

    /// Save a decoder, reload it through `load_vae_decoder`, and check that
    /// nothing is dropped and both decoders produce the same output.
    #[test]
    fn round_trip_safetensors_into_decoder() {
        let cfg = tiny_cfg();
        let original = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let (_keep_alive, weights) = tmp_safetensors_from(&original);
        let (reloaded, report) = load_vae_decoder::<f32>(&weights, cfg.clone(), false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "round-trip should have empty drop list, got {:?}",
            report.dropped
        );

        let storage = TensorStorage::cpu(vec![0.01f32; 4]);
        let input = Tensor::from_storage(storage, vec![1, 4, 1, 1], false).unwrap();
        let expected = original.forward(&input).unwrap();
        let actual = reloaded.forward(&input).unwrap();
        let expected_vals = expected.data().unwrap();
        let actual_vals = actual.data().unwrap();
        for (e, a) in expected_vals.iter().zip(actual_vals.iter()) {
            assert!((e - a).abs() < 1e-5);
        }
    }

    /// Non-strict decoder loading skips encoder-side keys and reports them
    /// in sorted order.
    #[test]
    fn load_hf_drops_encoder_keys_nonstrict() {
        // Listed in sorted order: this is also the expected drop list.
        const FOREIGN: [&str; 2] = ["encoder.conv_in.weight", "quant_conv.weight"];

        let mut v = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        for name in FOREIGN {
            hf_sd.insert(name.into(), ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap());
        }

        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        let expected: Vec<String> = FOREIGN.iter().map(|s| s.to_string()).collect();
        assert_eq!(rep.dropped, expected);
    }

    /// Strict decoder loading rejects a dict that holds an encoder key.
    #[test]
    fn load_hf_strict_rejects_encoder_keys() {
        let mut v = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let hf_sd: StateDict<f32> = std::iter::once((
            String::from("encoder.conv_in.weight"),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        ))
        .collect();
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    /// Keys carrying a `vae.` prefix load exactly like bare keys.
    #[test]
    fn load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let original = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let prefixed: StateDict<f32> = original
            .state_dict()
            .into_iter()
            .map(|(k, v)| (format!("vae.{k}"), v))
            .collect();

        let mut reloaded = VaeDecoder::<f32>::new(cfg).unwrap();
        let rep = reloaded.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);

        let storage = TensorStorage::cpu(vec![0.01f32; 4]);
        let input = Tensor::from_storage(storage, vec![1, 4, 1, 1], false).unwrap();
        let expected = original.forward(&input).unwrap();
        let actual = reloaded.forward(&input).unwrap();
        let expected_vals = expected.data().unwrap();
        let actual_vals = actual.data().unwrap();
        for (e, a) in expected_vals.iter().zip(actual_vals.iter()) {
            assert!((e - a).abs() < 1e-5);
        }
    }

    /// Encoder counterpart of `tmp_safetensors_from`.
    fn tmp_encoder_safetensors_from(v: &VaeEncoder<f32>) -> (tempfile::TempDir, PathBuf) {
        let scratch = tempfile::tempdir().unwrap();
        let file = scratch.path().join("model.safetensors");
        save_safetensors(&v.state_dict(), &file).unwrap();
        (scratch, file)
    }

    /// Save an encoder, reload it through `load_vae_encoder`, and check that
    /// nothing is dropped and both encoders produce the same output.
    #[test]
    fn round_trip_safetensors_into_encoder() {
        let cfg = tiny_cfg();
        let original = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let (_keep_alive, weights) = tmp_encoder_safetensors_from(&original);
        let (reloaded, report) = load_vae_encoder::<f32>(&weights, cfg.clone(), false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "encoder round-trip should have empty drop list, got {:?}",
            report.dropped
        );

        let storage = TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]);
        let input = Tensor::from_storage(storage, vec![1, 3, 8, 8], false).unwrap();
        let expected = original.forward(&input).unwrap();
        let actual = reloaded.forward(&input).unwrap();
        let expected_vals = expected.data().unwrap();
        let actual_vals = actual.data().unwrap();
        for (e, a) in expected_vals.iter().zip(actual_vals.iter()) {
            assert!((e - a).abs() < 1e-5);
        }
    }

    /// Non-strict encoder loading skips decoder-side keys and reports them
    /// in sorted order.
    #[test]
    fn encoder_load_hf_drops_decoder_keys_nonstrict() {
        // Listed in sorted order: this is also the expected drop list.
        const FOREIGN: [&str; 2] = ["decoder.conv_in.weight", "post_quant_conv.weight"];

        let mut v = VaeEncoder::<f32>::new(tiny_cfg()).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        for name in FOREIGN {
            hf_sd.insert(name.into(), ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap());
        }

        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        let expected: Vec<String> = FOREIGN.iter().map(|s| s.to_string()).collect();
        assert_eq!(rep.dropped, expected);
    }

    /// Strict encoder loading rejects a dict that holds a decoder key.
    #[test]
    fn encoder_load_hf_strict_rejects_decoder_keys() {
        let mut v = VaeEncoder::<f32>::new(tiny_cfg()).unwrap();
        let hf_sd: StateDict<f32> = std::iter::once((
            String::from("decoder.conv_in.weight"),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        ))
        .collect();
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    /// Keys carrying a `vae.` prefix load exactly like bare keys.
    #[test]
    fn encoder_load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let original = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let prefixed: StateDict<f32> = original
            .state_dict()
            .into_iter()
            .map(|(k, v)| (format!("vae.{k}"), v))
            .collect();

        let mut reloaded = VaeEncoder::<f32>::new(cfg).unwrap();
        let rep = reloaded.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);

        let storage = TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]);
        let input = Tensor::from_storage(storage, vec![1, 3, 8, 8], false).unwrap();
        let expected = original.forward(&input).unwrap();
        let actual = reloaded.forward(&input).unwrap();
        let expected_vals = expected.data().unwrap();
        let actual_vals = actual.data().unwrap();
        for (e, a) in expected_vals.iter().zip(actual_vals.iter()) {
            assert!((e - a).abs() < 1e-5);
        }
    }

    /// A dict holding both decoder and encoder weights can be loaded
    /// non-strictly by each half, and each half drops only the other half's
    /// keys.
    #[test]
    fn full_vae_checkpoint_loadable_by_both_halves() {
        let cfg = tiny_cfg();
        let dec_src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let enc_src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();

        // Decoder entries first, then encoder entries, as one combined dict.
        let combined: StateDict<f32> = dec_src
            .state_dict()
            .into_iter()
            .chain(enc_src.state_dict())
            .collect();

        let mut dec_dst = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let dec_rep = dec_dst.load_hf_state_dict(&combined, false).unwrap();
        let mut enc_dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let enc_rep = enc_dst.load_hf_state_dict(&combined, false).unwrap();

        for k in &dec_rep.dropped {
            let is_encoder_side = k.starts_with("encoder.") || k.starts_with("quant_conv.");
            assert!(is_encoder_side, "decoder dropped unexpected key: {k}");
        }
        for k in &enc_rep.dropped {
            let is_decoder_side = k.starts_with("decoder.") || k.starts_with("post_quant_conv.");
            assert!(is_decoder_side, "encoder dropped unexpected key: {k}");
        }
        assert!(
            !dec_rep.dropped.is_empty(),
            "decoder should have dropped some keys"
        );
        assert!(
            !enc_rep.dropped.is_empty(),
            "encoder should have dropped some keys"
        );
    }
}