Skip to main content

ferrotorch_diffusion/
safetensors_loader.rs

1//! Helpers that turn a path-to-safetensors into a loaded
2//! [`VaeDecoder`].
3//!
4//! The pinned SD-1.5 VAE mirror carries the full VAE state-dict
5//! (encoder + post_quant_conv + decoder + quant_conv). Inference needs
6//! only the decoder slice (`post_quant_conv.*` + `decoder.*`). This
7//! loader drops everything else and returns a [`DropReport`] so the pin
8//! script can audit the drop set.
9
10use std::collections::HashMap;
11use std::path::Path;
12
13use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
14use ferrotorch_nn::module::{Module, StateDict};
15use ferrotorch_serialize::load_safetensors;
16
17use crate::clip_text_encoder::{ClipTextConfig, ClipTextEncoder};
18use crate::config::VaeDecoderConfig;
19use crate::unet::UNet2DConditionModel;
20use crate::unet_config::UNet2DConditionConfig;
21use crate::vae::VaeDecoder;
22use crate::vae_encoder::{VaeEncoder, VaeEncoderConfig};
23
/// Audit trail returned by the `load_hf_state_dict` methods and the
/// `load_*` helpers in this module (for example [`load_vae_decoder`] /
/// [`VaeDecoder::load_hf_state_dict`]).
///
/// Records upstream keys that were dropped because the module being
/// loaded has no slot for them (for the VAE decoder, typically the
/// encoder + `quant_conv` weights of a full `AutoencoderKL`
/// checkpoint). The pin script asserts the dropped set equals the
/// documented key surface so a silent parameter drop cannot recur.
///
/// `PartialEq` / `Eq` are derived so two reports (or a report and an
/// expected fixture) can be compared directly; because `dropped` is
/// sorted by the loaders, that comparison is order-stable.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DropReport {
    /// Keys present in the upstream state dict that the loaded module
    /// did not accept. Sorted for deterministic equality.
    pub dropped: Vec<String>,
}
37
38impl<T: Float> VaeDecoder<T> {
39    /// Load a HuggingFace AutoencoderKL state dict into this module.
40    ///
41    /// Accepts both:
42    ///   - `post_quant_conv.*` / `decoder.*` (bare-VAE layout, the
43    ///     normalised form the pin script produces)
44    ///   - `vae.post_quant_conv.*` / `vae.decoder.*` (when bundled
45    ///     inside a full SD pipeline checkpoint)
46    ///
47    /// Any other key (encoder, `quant_conv`, etc.) is recorded in the
48    /// returned [`DropReport`] (or, in strict mode, surfaces as
49    /// [`FerrotorchError::InvalidArgument`]).
50    ///
51    /// # Errors
52    ///
53    /// Forwards whatever each sub-module's `load_state_dict` returns
54    /// (`ShapeMismatch` on a wrong-shape tensor, `InvalidArgument` in
55    /// strict mode when a required tensor is missing). Strict mode will
56    /// surface `encoder.*` / `quant_conv.*` / etc. as errors; callers
57    /// with a full VAE checkpoint must pass `strict=false`.
58    pub fn load_hf_state_dict(
59        &mut self,
60        hf_state: &StateDict<T>,
61        strict: bool,
62    ) -> FerrotorchResult<DropReport> {
63        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
64        let mut dropped: Vec<String> = Vec::new();
65
66        for (k, v) in hf_state {
67            // Try (a) bare-VAE prefix → as-is; (b) full-pipeline
68            // `vae.<rest>` prefix → strip the `vae.` and accept.
69            let after_vae = k
70                .strip_prefix("vae.")
71                .map_or_else(|| k.clone(), str::to_owned);
72            if after_vae.starts_with("post_quant_conv.") || after_vae.starts_with("decoder.") {
73                remapped.insert(after_vae, v.clone());
74                continue;
75            }
76            if strict {
77                return Err(FerrotorchError::InvalidArgument {
78                    message: format!(
79                        "VaeDecoder::load_hf_state_dict: key {k:?} is not under \
80                         `post_quant_conv.*` / `decoder.*` (with optional `vae.` prefix) \
81                         and strict mode is on. Pass strict=false to drop encoder / \
82                         quant_conv keys."
83                    ),
84                });
85            }
86            dropped.push(k.clone());
87        }
88        dropped.sort();
89        self.load_state_dict(&remapped, strict)?;
90        Ok(DropReport { dropped })
91    }
92}
93
94// ---------------------------------------------------------------------------
95// UNet2DConditionModel loader
96// ---------------------------------------------------------------------------
97
98impl<T: Float> UNet2DConditionModel<T> {
99    /// Load a HuggingFace UNet state dict into this module.
100    ///
101    /// Accepts both:
102    ///   - bare-UNet layout (the pin script normalises to this form)
103    ///   - `unet.<rest>` prefix (full SD pipeline checkpoint)
104    ///
105    /// Any unrecognised key is recorded in the returned [`DropReport`]
106    /// (or surfaces as [`FerrotorchError::InvalidArgument`] in strict
107    /// mode).
108    ///
109    /// # Errors
110    ///
111    /// Forwards whatever each sub-module's `load_state_dict` returns
112    /// (shape mismatch / strict-mode missing key).
113    pub fn load_hf_state_dict(
114        &mut self,
115        hf_state: &StateDict<T>,
116        strict: bool,
117    ) -> FerrotorchResult<DropReport> {
118        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
119        let mut dropped: Vec<String> = Vec::new();
120        for (k, v) in hf_state {
121            let after_unet = k
122                .strip_prefix("unet.")
123                .map_or_else(|| k.clone(), str::to_owned);
124            let is_unet_key = after_unet.starts_with("time_embedding.")
125                || after_unet.starts_with("conv_in.")
126                || after_unet.starts_with("down_blocks.")
127                || after_unet.starts_with("mid_block.")
128                || after_unet.starts_with("up_blocks.")
129                || after_unet.starts_with("conv_norm_out.")
130                || after_unet.starts_with("conv_out.");
131            if is_unet_key {
132                remapped.insert(after_unet, v.clone());
133                continue;
134            }
135            if strict {
136                return Err(FerrotorchError::InvalidArgument {
137                    message: format!(
138                        "UNet2DConditionModel::load_hf_state_dict: key {k:?} is not under \
139                         a UNet prefix (with optional `unet.`) and strict mode is on."
140                    ),
141                });
142            }
143            dropped.push(k.clone());
144        }
145        dropped.sort();
146        self.load_state_dict(&remapped, strict)?;
147        Ok(DropReport { dropped })
148    }
149}
150
151/// Load a [`UNet2DConditionModel`] from a UNet
152/// `diffusion_pytorch_model.safetensors` file plus a parsed config.
153///
154/// `strict=false` is required when loading a full SD pipeline
155/// checkpoint (which carries `vae.*` / `text_encoder.*` keys); for a
156/// bare-UNet mirror (the form `pin_pretrained_diffusion_weights.py`
157/// uploads) `strict=true` is fine.
158///
159/// # Errors
160///
161/// Propagates safetensors parse errors, [`UNet2DConditionModel`]
162/// construction errors, and any per-key shape / strict-mode mismatch.
163pub fn load_unet<T: Float>(
164    weights_path: &Path,
165    cfg: UNet2DConditionConfig,
166    strict: bool,
167) -> FerrotorchResult<(UNet2DConditionModel<T>, DropReport)> {
168    let state =
169        load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
170            message: format!(
171                "load_unet: failed to decode safetensors {}: {e}",
172                weights_path.display()
173            ),
174        })?;
175    let mut unet = UNet2DConditionModel::<T>::new(cfg)?;
176    let report = unet.load_hf_state_dict(&state, strict)?;
177    Ok((unet, report))
178}
179
180// ---------------------------------------------------------------------------
181// ClipTextEncoder loader
182// ---------------------------------------------------------------------------
183
184/// Read a single non-sharded safetensors file into a typed `StateDict`,
185/// dropping any int64 `position_ids` buffer BEFORE the generic-`T`
186/// decode. The CLIP-text checkpoint ships an
187/// `embeddings.position_ids` (or `text_model.embeddings.position_ids`)
188/// `[1, 77]` int64 buffer that would poison a `load_safetensors::<f32>`
189/// pass because i64 is not representable as f32 in the underlying
190/// dispatch. Mirrors the trick `ferrotorch-bert`'s loader uses.
191fn load_safetensors_clip_filtered<T: Float>(
192    weights_path: &Path,
193) -> FerrotorchResult<(StateDict<T>, bool)> {
194    use safetensors::SafeTensors;
195
196    let bytes = std::fs::read(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
197        message: format!(
198            "load_safetensors_clip_filtered: failed to read {}: {e}",
199            weights_path.display()
200        ),
201    })?;
202    let st = SafeTensors::deserialize(&bytes).map_err(|e| FerrotorchError::InvalidArgument {
203        message: format!(
204            "load_safetensors_clip_filtered: failed to parse {}: {e}",
205            weights_path.display()
206        ),
207    })?;
208    let mut keep: Vec<String> = Vec::new();
209    let mut had_position_ids = false;
210    for k in st.names() {
211        let s: &str = k.as_str();
212        // The position_ids buffer is the only int64 surface in
213        // CLIPTextModel and it has no parameter slot on our side.
214        if s == "embeddings.position_ids" || s == "text_model.embeddings.position_ids" {
215            had_position_ids = true;
216            continue;
217        }
218        keep.push(String::from(s));
219    }
220
221    // Re-serialize only the kept tensors into an in-memory safetensors
222    // blob and feed that to `load_safetensors::<T>`. Reuses the audited
223    // generic decoder instead of re-implementing dtype dispatch here.
224    let mut subset: Vec<(String, safetensors::tensor::TensorView<'_>)> =
225        Vec::with_capacity(keep.len());
226    for k in &keep {
227        let v = st.tensor(k).map_err(|e| FerrotorchError::InvalidArgument {
228            message: format!(
229                "load_safetensors_clip_filtered: missing tensor {k:?} after filter: {e}"
230            ),
231        })?;
232        subset.push((k.clone(), v));
233    }
234    let serialized =
235        safetensors::serialize(subset, &None).map_err(|e| FerrotorchError::InvalidArgument {
236            message: format!("load_safetensors_clip_filtered: re-serialize failed: {e}"),
237        })?;
238    let tmp = tempfile::NamedTempFile::new().map_err(|e| FerrotorchError::InvalidArgument {
239        message: format!("load_safetensors_clip_filtered: tempfile: {e}"),
240    })?;
241    std::fs::write(tmp.path(), &serialized).map_err(|e| FerrotorchError::InvalidArgument {
242        message: format!("load_safetensors_clip_filtered: tempfile write: {e}"),
243    })?;
244    let state = load_safetensors::<T>(tmp.path())?;
245    Ok((state, had_position_ids))
246}
247
248/// Load a [`ClipTextEncoder`] from a CLIP text-tower
249/// `model.safetensors` file plus a parsed [`ClipTextConfig`].
250///
251/// Accepts both upstream layouts:
252///   - bare `embeddings.* / encoder.* / final_layer_norm.*` (what the
253///     pin script normalises to).
254///   - `text_model.<rest>` prefix (what the upstream HF checkpoint
255///     ships).
256///
257/// The int64 `embeddings.position_ids` buffer (a `[1, max_pos]`
258/// `arange(max_pos)` constant regenerated each forward pass) is
259/// dropped at decode time and surfaced via the returned
260/// [`DropReport`].
261///
262/// `strict=false` is required when the upstream checkpoint carries the
263/// position_ids buffer (the default for `runwayml/stable-diffusion-v1-5`'s
264/// `text_encoder/model.safetensors`).
265///
266/// # Errors
267///
268/// Propagates safetensors parse errors, [`ClipTextEncoder`] construction
269/// errors, and any per-key shape / strict-mode mismatch.
270pub fn load_clip_text_encoder<T: Float>(
271    weights_path: &Path,
272    cfg: ClipTextConfig,
273    strict: bool,
274) -> FerrotorchResult<(ClipTextEncoder<T>, DropReport)> {
275    let (mut state, had_position_ids) =
276        load_safetensors_clip_filtered::<T>(weights_path).map_err(|e| {
277            FerrotorchError::InvalidArgument {
278                message: format!(
279                    "load_clip_text_encoder: failed to decode safetensors {}: {e}",
280                    weights_path.display()
281                ),
282            }
283        })?;
284
285    // Re-insert a placeholder entry for the position_ids buffer (with
286    // the upstream key it actually used) so the model's DropReport
287    // captures it as an intentionally-dropped upstream key. The
288    // placeholder tensor is never consumed — `load_hf_state_dict`
289    // drops the entry before any parameter slot sees it.
290    if had_position_ids {
291        let key = if state.keys().any(|k| k.starts_with("text_model.")) {
292            "text_model.embeddings.position_ids".to_string()
293        } else {
294            "embeddings.position_ids".to_string()
295        };
296        state.insert(key, ferrotorch_core::zeros::<T>(&[1])?);
297    }
298
299    let mut enc = ClipTextEncoder::<T>::new(cfg)?;
300    let report = enc.load_hf_state_dict(&state, strict)?;
301    Ok((enc, report))
302}
303
304/// Load a [`VaeDecoder`] from a VAE `diffusion_pytorch_model.safetensors`
305/// file plus a parsed config.
306///
307/// `strict=false` is required for a full `AutoencoderKL` checkpoint
308/// (which ships encoder + quant_conv weights this decoder-only loader
309/// has no slot for). The returned [`DropReport`] captures every
310/// dropped key so the pin script can confirm the drop set is exactly
311/// the documented encoder/quant_conv surface.
312///
313/// # Errors
314///
315/// Propagates safetensors parse errors, [`VaeDecoder`] construction
316/// errors, and any per-key shape / strict-mode mismatch from the
317/// underlying load.
318pub fn load_vae_decoder<T: Float>(
319    weights_path: &Path,
320    cfg: VaeDecoderConfig,
321    strict: bool,
322) -> FerrotorchResult<(VaeDecoder<T>, DropReport)> {
323    let state =
324        load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
325            message: format!(
326                "load_vae_decoder: failed to decode safetensors {}: {e}",
327                weights_path.display()
328            ),
329        })?;
330    let mut decoder = VaeDecoder::<T>::new(cfg)?;
331    let report = decoder.load_hf_state_dict(&state, strict)?;
332    Ok((decoder, report))
333}
334
335// ---------------------------------------------------------------------------
336// VaeEncoder loader
337// ---------------------------------------------------------------------------
338
339impl<T: Float> VaeEncoder<T> {
340    /// Load a HuggingFace AutoencoderKL state dict into this module.
341    ///
342    /// Accepts both:
343    ///   - `encoder.*` / `quant_conv.*` (bare-VAE layout, the normalised
344    ///     form produced by the pin script for an encoder-only artifact)
345    ///   - `vae.encoder.*` / `vae.quant_conv.*` (full SD pipeline
346    ///     checkpoint, e.g. an upstream `runwayml/stable-diffusion-v1-5`
347    ///     `model.safetensors`)
348    ///
349    /// Any other key (decoder, `post_quant_conv`, UNet, text encoder, …)
350    /// is recorded in the returned [`DropReport`] (or, in strict mode,
351    /// surfaces as [`FerrotorchError::InvalidArgument`]). This mirrors
352    /// [`VaeDecoder::load_hf_state_dict`] so the same full-checkpoint
353    /// file can be loaded twice — once for the encoder and once for the
354    /// decoder — each time dropping the other half.
355    ///
356    /// # Errors
357    ///
358    /// Forwards whatever each sub-module's `load_state_dict` returns
359    /// (`ShapeMismatch` on a wrong-shape tensor, `InvalidArgument` in
360    /// strict mode when a required tensor is missing). Strict mode will
361    /// surface `decoder.*` / `post_quant_conv.*` / etc. as errors;
362    /// callers with a full VAE checkpoint must pass `strict=false`.
363    pub fn load_hf_state_dict(
364        &mut self,
365        hf_state: &StateDict<T>,
366        strict: bool,
367    ) -> FerrotorchResult<DropReport> {
368        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
369        let mut dropped: Vec<String> = Vec::new();
370
371        for (k, v) in hf_state {
372            // (a) bare-VAE prefix → as-is; (b) full-pipeline `vae.<rest>`
373            // prefix → strip the `vae.` and accept.
374            let after_vae = k
375                .strip_prefix("vae.")
376                .map_or_else(|| k.clone(), str::to_owned);
377            if after_vae.starts_with("encoder.") || after_vae.starts_with("quant_conv.") {
378                remapped.insert(after_vae, v.clone());
379                continue;
380            }
381            if strict {
382                return Err(FerrotorchError::InvalidArgument {
383                    message: format!(
384                        "VaeEncoder::load_hf_state_dict: key {k:?} is not under \
385                         `encoder.*` / `quant_conv.*` (with optional `vae.` prefix) \
386                         and strict mode is on. Pass strict=false to drop decoder / \
387                         post_quant_conv keys."
388                    ),
389                });
390            }
391            dropped.push(k.clone());
392        }
393        dropped.sort();
394        self.load_state_dict(&remapped, strict)?;
395        Ok(DropReport { dropped })
396    }
397}
398
399/// Load a [`VaeEncoder`] from a VAE `diffusion_pytorch_model.safetensors`
400/// file plus a parsed config.
401///
402/// Symmetric counterpart of [`load_vae_decoder`]. `strict=false` is
403/// required for a full `AutoencoderKL` checkpoint (which ships
404/// `decoder` + `post_quant_conv` weights this encoder-only loader has
405/// no slot for). The returned [`DropReport`] captures every dropped
406/// key so a pin script can audit the drop set.
407///
408/// # Errors
409///
410/// Propagates safetensors parse errors, [`VaeEncoder`] construction
411/// errors, and any per-key shape / strict-mode mismatch from the
412/// underlying load.
413pub fn load_vae_encoder<T: Float>(
414    weights_path: &Path,
415    cfg: VaeEncoderConfig,
416    strict: bool,
417) -> FerrotorchResult<(VaeEncoder<T>, DropReport)> {
418    let state =
419        load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
420            message: format!(
421                "load_vae_encoder: failed to decode safetensors {}: {e}",
422                weights_path.display()
423            ),
424        })?;
425    let mut encoder = VaeEncoder::<T>::new(cfg)?;
426    let report = encoder.load_hf_state_dict(&state, strict)?;
427    Ok((encoder, report))
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};
    use ferrotorch_serialize::save_safetensors;
    use std::path::PathBuf;

    /// Small config shared by every test so model construction and
    /// forward passes stay cheap.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// Write `sd` to `model.safetensors` inside a fresh temp dir.
    ///
    /// The `TempDir` is returned alongside the path so the caller keeps
    /// the directory alive for as long as the file is needed. Keys are
    /// written as-is (bare layout, no `vae.` prefix); the strip-prefix
    /// path is exercised by dedicated tests below.
    fn write_state_dict(sd: &StateDict<f32>) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.safetensors");
        save_safetensors(sd, &path).unwrap();
        (dir, path)
    }

    /// CPU tensor of the given shape with every element set to `0.01`.
    fn const_input(shape: Vec<usize>) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(TensorStorage::cpu(vec![0.01f32; numel]), shape, false).unwrap()
    }

    /// Element-wise comparison with an absolute tolerance of `1e-5`.
    ///
    /// The length check matters: a bare `zip` stops at the shorter
    /// side, so a truncated or empty output would otherwise pass.
    fn assert_all_close(lhs: &[f32], rhs: &[f32]) {
        assert_eq!(lhs.len(), rhs.len(), "output lengths differ");
        for (i, (a, b)) in lhs.iter().zip(rhs.iter()).enumerate() {
            assert!((a - b).abs() < 1e-5, "element {i} differs: {a} vs {b}");
        }
    }

    /// Sorted key list of a state dict, for comparing with `DropReport::dropped`.
    fn sorted_keys(sd: &StateDict<f32>) -> Vec<String> {
        let mut keys: Vec<String> = sd.keys().cloned().collect();
        keys.sort();
        keys
    }

    /// Copy of `sd` with every key prefixed by `vae.` (the layout SD
    /// pipeline checkpoints use).
    fn with_vae_prefix(sd: StateDict<f32>) -> StateDict<f32> {
        let mut prefixed: StateDict<f32> = HashMap::new();
        for (k, v) in sd {
            prefixed.insert(format!("vae.{k}"), v);
        }
        prefixed
    }

    #[test]
    fn round_trip_safetensors_into_decoder() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let (_d, p) = write_state_dict(&src.state_dict());
        let (dst, report) = load_vae_decoder::<f32>(&p, cfg.clone(), false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "round-trip should have empty drop list, got {:?}",
            report.dropped
        );
        let x = const_input(vec![1, 4, 1, 1]);
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_all_close(&a.data().unwrap(), &b.data().unwrap());
    }

    #[test]
    fn load_hf_drops_encoder_keys_nonstrict() {
        let cfg = tiny_cfg();
        let mut v = VaeDecoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        // Add an encoder key — this should be dropped.
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        // Add a quant_conv key — also dropped.
        hf_sd.insert(
            "quant_conv.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec![
                "encoder.conv_in.weight".to_string(),
                "quant_conv.weight".to_string(),
            ]
        );
    }

    #[test]
    fn load_hf_strict_rejects_encoder_keys() {
        let cfg = tiny_cfg();
        let mut v = VaeDecoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = HashMap::new();
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    #[test]
    fn load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let prefixed = with_vae_prefix(src.state_dict());
        let mut dst = VaeDecoder::<f32>::new(cfg).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);
        let x = const_input(vec![1, 4, 1, 1]);
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_all_close(&a.data().unwrap(), &b.data().unwrap());
    }

    #[test]
    fn round_trip_safetensors_into_encoder() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let (_d, p) = write_state_dict(&src.state_dict());
        let (dst, report) = load_vae_encoder::<f32>(&p, cfg.clone(), false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "encoder round-trip should have empty drop list, got {:?}",
            report.dropped
        );
        let x = const_input(vec![1, 3, 8, 8]);
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_all_close(&a.data().unwrap(), &b.data().unwrap());
    }

    #[test]
    fn encoder_load_hf_drops_decoder_keys_nonstrict() {
        let cfg = tiny_cfg();
        let mut v = VaeEncoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        // Add decoder keys that the encoder should drop.
        hf_sd.insert(
            "decoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        hf_sd.insert(
            "post_quant_conv.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec![
                "decoder.conv_in.weight".to_string(),
                "post_quant_conv.weight".to_string(),
            ]
        );
    }

    #[test]
    fn encoder_load_hf_strict_rejects_decoder_keys() {
        let cfg = tiny_cfg();
        let mut v = VaeEncoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = HashMap::new();
        hf_sd.insert(
            "decoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    #[test]
    fn encoder_load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let prefixed = with_vae_prefix(src.state_dict());
        let mut dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);
        let x = const_input(vec![1, 3, 8, 8]);
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_all_close(&a.data().unwrap(), &b.data().unwrap());
    }

    #[test]
    fn full_vae_checkpoint_loadable_by_both_halves() {
        // The interesting property: a single full-VAE state dict
        // (encoder + post_quant_conv + decoder + quant_conv) is loadable
        // by both VaeDecoder and VaeEncoder, each cleanly dropping the
        // half it doesn't own.
        let cfg = tiny_cfg();
        let dec_src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let enc_src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();

        let dec_keys = sorted_keys(&dec_src.state_dict());
        let enc_keys = sorted_keys(&enc_src.state_dict());
        assert!(!dec_keys.is_empty(), "decoder should expose some keys");
        assert!(!enc_keys.is_empty(), "encoder should expose some keys");

        let mut combined: StateDict<f32> = HashMap::new();
        combined.extend(dec_src.state_dict());
        combined.extend(enc_src.state_dict());

        let mut dec_dst = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let dec_rep = dec_dst.load_hf_state_dict(&combined, false).unwrap();
        let mut enc_dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let enc_rep = enc_dst.load_hf_state_dict(&combined, false).unwrap();

        // Decoder should have dropped exactly the encoder + quant_conv
        // keys — i.e. the encoder's whole key set, no more, no less.
        assert_eq!(
            dec_rep.dropped, enc_keys,
            "decoder drop set should equal the encoder key set"
        );
        // Encoder should have dropped exactly the decoder +
        // post_quant_conv keys — i.e. the decoder's whole key set.
        assert_eq!(
            enc_rep.dropped, dec_keys,
            "encoder drop set should equal the decoder key set"
        );
    }
}