Skip to main content

ferrotorch_diffusion/
safetensors_loader.rs

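//! Helpers that turn a path-to-safetensors into a loaded diffusion
//! module: [`VaeDecoder`], [`VaeEncoder`], [`UNet2DConditionModel`] or
//! [`ClipTextEncoder`].
//!
//! The pinned SD-1.5 VAE mirror carries the full VAE state-dict
//! (encoder + post_quant_conv + decoder + quant_conv). Decoder inference
//! needs only the decoder slice (`post_quant_conv.*` + `decoder.*`), so
//! the decoder loader drops everything else and returns a [`DropReport`]
//! so the pin script can audit the drop set. The VAE-encoder and UNet
//! loaders in this file apply the same accept-or-drop pattern to their
//! own key prefixes.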
9
10use std::collections::HashMap;
11use std::path::Path;
12
13use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
14use ferrotorch_nn::module::{Module, StateDict};
15use ferrotorch_serialize::load_safetensors;
16
17use crate::clip_text_encoder::{ClipTextConfig, ClipTextEncoder};
18use crate::config::VaeDecoderConfig;
19use crate::unet::UNet2DConditionModel;
20use crate::unet_config::UNet2DConditionConfig;
21use crate::vae::VaeDecoder;
22use crate::vae_encoder::{VaeEncoder, VaeEncoderConfig};
23
/// Audit trail returned by the loaders in this module, e.g.
/// [`load_vae_decoder`] / [`VaeDecoder::load_hf_state_dict`].
///
/// Records upstream HF keys that were dropped because the module being
/// loaded has no slot for them (for the VAE decoder, typically the
/// encoder + `quant_conv` weights of a full `AutoencoderKL` checkpoint).
/// The pin script asserts the dropped set equals the documented key
/// surface so a silent parameter drop cannot recur.
///
/// Implements `PartialEq` / `Eq` so two reports (or a report and an
/// expected fixture) can be compared directly.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DropReport {
    /// Keys present in the upstream state dict that did not belong to
    /// the module being loaded. The `load_hf_state_dict` methods in
    /// this file sort the list before returning it, so equality is
    /// deterministic.
    pub dropped: Vec<String>,
}
37
38impl<T: Float> VaeDecoder<T> {
39    /// Load a HuggingFace AutoencoderKL state dict into this module.
40    ///
41    /// Accepts both:
42    ///   - `post_quant_conv.*` / `decoder.*` (bare-VAE layout, the
43    ///     normalised form the pin script produces)
44    ///   - `vae.post_quant_conv.*` / `vae.decoder.*` (when bundled
45    ///     inside a full SD pipeline checkpoint)
46    ///
47    /// Any other key (encoder, `quant_conv`, etc.) is recorded in the
48    /// returned [`DropReport`] (or, in strict mode, surfaces as
49    /// [`FerrotorchError::InvalidArgument`]).
50    ///
51    /// # Errors
52    ///
53    /// Forwards whatever each sub-module's `load_state_dict` returns
54    /// (`ShapeMismatch` on a wrong-shape tensor, `InvalidArgument` in
55    /// strict mode when a required tensor is missing). Strict mode will
56    /// surface `encoder.*` / `quant_conv.*` / etc. as errors; callers
57    /// with a full VAE checkpoint must pass `strict=false`.
58    pub fn load_hf_state_dict(
59        &mut self,
60        hf_state: &StateDict<T>,
61        strict: bool,
62    ) -> FerrotorchResult<DropReport> {
63        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
64        let mut dropped: Vec<String> = Vec::new();
65
66        for (k, v) in hf_state {
67            // Try (a) bare-VAE prefix → as-is; (b) full-pipeline
68            // `vae.<rest>` prefix → strip the `vae.` and accept.
69            let after_vae = k.strip_prefix("vae.").map_or_else(|| k.clone(), str::to_owned);
70            if after_vae.starts_with("post_quant_conv.") || after_vae.starts_with("decoder.") {
71                remapped.insert(after_vae, v.clone());
72                continue;
73            }
74            if strict {
75                return Err(FerrotorchError::InvalidArgument {
76                    message: format!(
77                        "VaeDecoder::load_hf_state_dict: key {k:?} is not under \
78                         `post_quant_conv.*` / `decoder.*` (with optional `vae.` prefix) \
79                         and strict mode is on. Pass strict=false to drop encoder / \
80                         quant_conv keys."
81                    ),
82                });
83            }
84            dropped.push(k.clone());
85        }
86        dropped.sort();
87        self.load_state_dict(&remapped, strict)?;
88        Ok(DropReport { dropped })
89    }
90}
91
92// ---------------------------------------------------------------------------
93// UNet2DConditionModel loader
94// ---------------------------------------------------------------------------
95
96impl<T: Float> UNet2DConditionModel<T> {
97    /// Load a HuggingFace UNet state dict into this module.
98    ///
99    /// Accepts both:
100    ///   - bare-UNet layout (the pin script normalises to this form)
101    ///   - `unet.<rest>` prefix (full SD pipeline checkpoint)
102    ///
103    /// Any unrecognised key is recorded in the returned [`DropReport`]
104    /// (or surfaces as [`FerrotorchError::InvalidArgument`] in strict
105    /// mode).
106    ///
107    /// # Errors
108    ///
109    /// Forwards whatever each sub-module's `load_state_dict` returns
110    /// (shape mismatch / strict-mode missing key).
111    pub fn load_hf_state_dict(
112        &mut self,
113        hf_state: &StateDict<T>,
114        strict: bool,
115    ) -> FerrotorchResult<DropReport> {
116        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
117        let mut dropped: Vec<String> = Vec::new();
118        for (k, v) in hf_state {
119            let after_unet = k.strip_prefix("unet.").map_or_else(|| k.clone(), str::to_owned);
120            let is_unet_key = after_unet.starts_with("time_embedding.")
121                || after_unet.starts_with("conv_in.")
122                || after_unet.starts_with("down_blocks.")
123                || after_unet.starts_with("mid_block.")
124                || after_unet.starts_with("up_blocks.")
125                || after_unet.starts_with("conv_norm_out.")
126                || after_unet.starts_with("conv_out.");
127            if is_unet_key {
128                remapped.insert(after_unet, v.clone());
129                continue;
130            }
131            if strict {
132                return Err(FerrotorchError::InvalidArgument {
133                    message: format!(
134                        "UNet2DConditionModel::load_hf_state_dict: key {k:?} is not under \
135                         a UNet prefix (with optional `unet.`) and strict mode is on."
136                    ),
137                });
138            }
139            dropped.push(k.clone());
140        }
141        dropped.sort();
142        self.load_state_dict(&remapped, strict)?;
143        Ok(DropReport { dropped })
144    }
145}
146
147/// Load a [`UNet2DConditionModel`] from a UNet
148/// `diffusion_pytorch_model.safetensors` file plus a parsed config.
149///
150/// `strict=false` is required when loading a full SD pipeline
151/// checkpoint (which carries `vae.*` / `text_encoder.*` keys); for a
152/// bare-UNet mirror (the form `pin_pretrained_diffusion_weights.py`
153/// uploads) `strict=true` is fine.
154///
155/// # Errors
156///
157/// Propagates safetensors parse errors, [`UNet2DConditionModel`]
158/// construction errors, and any per-key shape / strict-mode mismatch.
159pub fn load_unet<T: Float>(
160    weights_path: &Path,
161    cfg: UNet2DConditionConfig,
162    strict: bool,
163) -> FerrotorchResult<(UNet2DConditionModel<T>, DropReport)> {
164    let state =
165        load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
166            message: format!(
167                "load_unet: failed to decode safetensors {}: {e}",
168                weights_path.display()
169            ),
170        })?;
171    let mut unet = UNet2DConditionModel::<T>::new(cfg)?;
172    let report = unet.load_hf_state_dict(&state, strict)?;
173    Ok((unet, report))
174}
175
176// ---------------------------------------------------------------------------
177// ClipTextEncoder loader
178// ---------------------------------------------------------------------------
179
180/// Read a single non-sharded safetensors file into a typed `StateDict`,
181/// dropping any int64 `position_ids` buffer BEFORE the generic-`T`
182/// decode. The CLIP-text checkpoint ships an
183/// `embeddings.position_ids` (or `text_model.embeddings.position_ids`)
184/// `[1, 77]` int64 buffer that would poison a `load_safetensors::<f32>`
185/// pass because i64 is not representable as f32 in the underlying
186/// dispatch. Mirrors the trick `ferrotorch-bert`'s loader uses.
187fn load_safetensors_clip_filtered<T: Float>(
188    weights_path: &Path,
189) -> FerrotorchResult<(StateDict<T>, bool)> {
190    use safetensors::SafeTensors;
191
192    let bytes =
193        std::fs::read(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
194            message: format!(
195                "load_safetensors_clip_filtered: failed to read {}: {e}",
196                weights_path.display()
197            ),
198        })?;
199    let st = SafeTensors::deserialize(&bytes).map_err(|e| FerrotorchError::InvalidArgument {
200        message: format!(
201            "load_safetensors_clip_filtered: failed to parse {}: {e}",
202            weights_path.display()
203        ),
204    })?;
205    let mut keep: Vec<String> = Vec::new();
206    let mut had_position_ids = false;
207    for k in st.names() {
208        let s: &str = k.as_str();
209        // The position_ids buffer is the only int64 surface in
210        // CLIPTextModel and it has no parameter slot on our side.
211        if s == "embeddings.position_ids" || s == "text_model.embeddings.position_ids" {
212            had_position_ids = true;
213            continue;
214        }
215        keep.push(String::from(s));
216    }
217
218    // Re-serialize only the kept tensors into an in-memory safetensors
219    // blob and feed that to `load_safetensors::<T>`. Reuses the audited
220    // generic decoder instead of re-implementing dtype dispatch here.
221    let mut subset: Vec<(String, safetensors::tensor::TensorView<'_>)> =
222        Vec::with_capacity(keep.len());
223    for k in &keep {
224        let v = st.tensor(k).map_err(|e| FerrotorchError::InvalidArgument {
225            message: format!(
226                "load_safetensors_clip_filtered: missing tensor {k:?} after filter: {e}"
227            ),
228        })?;
229        subset.push((k.clone(), v));
230    }
231    let serialized = safetensors::serialize(subset, &None).map_err(|e| {
232        FerrotorchError::InvalidArgument {
233            message: format!("load_safetensors_clip_filtered: re-serialize failed: {e}"),
234        }
235    })?;
236    let tmp = tempfile::NamedTempFile::new().map_err(|e| FerrotorchError::InvalidArgument {
237        message: format!("load_safetensors_clip_filtered: tempfile: {e}"),
238    })?;
239    std::fs::write(tmp.path(), &serialized).map_err(|e| FerrotorchError::InvalidArgument {
240        message: format!("load_safetensors_clip_filtered: tempfile write: {e}"),
241    })?;
242    let state = load_safetensors::<T>(tmp.path())?;
243    Ok((state, had_position_ids))
244}
245
246/// Load a [`ClipTextEncoder`] from a CLIP text-tower
247/// `model.safetensors` file plus a parsed [`ClipTextConfig`].
248///
249/// Accepts both upstream layouts:
250///   - bare `embeddings.* / encoder.* / final_layer_norm.*` (what the
251///     pin script normalises to).
252///   - `text_model.<rest>` prefix (what the upstream HF checkpoint
253///     ships).
254///
255/// The int64 `embeddings.position_ids` buffer (a `[1, max_pos]`
256/// `arange(max_pos)` constant regenerated each forward pass) is
257/// dropped at decode time and surfaced via the returned
258/// [`DropReport`].
259///
260/// `strict=false` is required when the upstream checkpoint carries the
261/// position_ids buffer (the default for `runwayml/stable-diffusion-v1-5`'s
262/// `text_encoder/model.safetensors`).
263///
264/// # Errors
265///
266/// Propagates safetensors parse errors, [`ClipTextEncoder`] construction
267/// errors, and any per-key shape / strict-mode mismatch.
268pub fn load_clip_text_encoder<T: Float>(
269    weights_path: &Path,
270    cfg: ClipTextConfig,
271    strict: bool,
272) -> FerrotorchResult<(ClipTextEncoder<T>, DropReport)> {
273    let (mut state, had_position_ids) =
274        load_safetensors_clip_filtered::<T>(weights_path).map_err(|e| {
275            FerrotorchError::InvalidArgument {
276                message: format!(
277                    "load_clip_text_encoder: failed to decode safetensors {}: {e}",
278                    weights_path.display()
279                ),
280            }
281        })?;
282
283    // Re-insert a placeholder entry for the position_ids buffer (with
284    // the upstream key it actually used) so the model's DropReport
285    // captures it as an intentionally-dropped upstream key. The
286    // placeholder tensor is never consumed — `load_hf_state_dict`
287    // drops the entry before any parameter slot sees it.
288    if had_position_ids {
289        let key = if state
290            .keys()
291            .any(|k| k.starts_with("text_model."))
292        {
293            "text_model.embeddings.position_ids".to_string()
294        } else {
295            "embeddings.position_ids".to_string()
296        };
297        state.insert(key, ferrotorch_core::zeros::<T>(&[1])?);
298    }
299
300    let mut enc = ClipTextEncoder::<T>::new(cfg)?;
301    let report = enc.load_hf_state_dict(&state, strict)?;
302    Ok((enc, report))
303}
304
305/// Load a [`VaeDecoder`] from a VAE `diffusion_pytorch_model.safetensors`
306/// file plus a parsed config.
307///
308/// `strict=false` is required for a full `AutoencoderKL` checkpoint
309/// (which ships encoder + quant_conv weights this decoder-only loader
310/// has no slot for). The returned [`DropReport`] captures every
311/// dropped key so the pin script can confirm the drop set is exactly
312/// the documented encoder/quant_conv surface.
313///
314/// # Errors
315///
316/// Propagates safetensors parse errors, [`VaeDecoder`] construction
317/// errors, and any per-key shape / strict-mode mismatch from the
318/// underlying load.
319pub fn load_vae_decoder<T: Float>(
320    weights_path: &Path,
321    cfg: VaeDecoderConfig,
322    strict: bool,
323) -> FerrotorchResult<(VaeDecoder<T>, DropReport)> {
324    let state =
325        load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
326            message: format!(
327                "load_vae_decoder: failed to decode safetensors {}: {e}",
328                weights_path.display()
329            ),
330        })?;
331    let mut decoder = VaeDecoder::<T>::new(cfg)?;
332    let report = decoder.load_hf_state_dict(&state, strict)?;
333    Ok((decoder, report))
334}
335
336// ---------------------------------------------------------------------------
337// VaeEncoder loader
338// ---------------------------------------------------------------------------
339
340impl<T: Float> VaeEncoder<T> {
341    /// Load a HuggingFace AutoencoderKL state dict into this module.
342    ///
343    /// Accepts both:
344    ///   - `encoder.*` / `quant_conv.*` (bare-VAE layout, the normalised
345    ///     form produced by the pin script for an encoder-only artifact)
346    ///   - `vae.encoder.*` / `vae.quant_conv.*` (full SD pipeline
347    ///     checkpoint, e.g. an upstream `runwayml/stable-diffusion-v1-5`
348    ///     `model.safetensors`)
349    ///
350    /// Any other key (decoder, `post_quant_conv`, UNet, text encoder, …)
351    /// is recorded in the returned [`DropReport`] (or, in strict mode,
352    /// surfaces as [`FerrotorchError::InvalidArgument`]). This mirrors
353    /// [`VaeDecoder::load_hf_state_dict`] so the same full-checkpoint
354    /// file can be loaded twice — once for the encoder and once for the
355    /// decoder — each time dropping the other half.
356    ///
357    /// # Errors
358    ///
359    /// Forwards whatever each sub-module's `load_state_dict` returns
360    /// (`ShapeMismatch` on a wrong-shape tensor, `InvalidArgument` in
361    /// strict mode when a required tensor is missing). Strict mode will
362    /// surface `decoder.*` / `post_quant_conv.*` / etc. as errors;
363    /// callers with a full VAE checkpoint must pass `strict=false`.
364    pub fn load_hf_state_dict(
365        &mut self,
366        hf_state: &StateDict<T>,
367        strict: bool,
368    ) -> FerrotorchResult<DropReport> {
369        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
370        let mut dropped: Vec<String> = Vec::new();
371
372        for (k, v) in hf_state {
373            // (a) bare-VAE prefix → as-is; (b) full-pipeline `vae.<rest>`
374            // prefix → strip the `vae.` and accept.
375            let after_vae = k.strip_prefix("vae.").map_or_else(|| k.clone(), str::to_owned);
376            if after_vae.starts_with("encoder.") || after_vae.starts_with("quant_conv.") {
377                remapped.insert(after_vae, v.clone());
378                continue;
379            }
380            if strict {
381                return Err(FerrotorchError::InvalidArgument {
382                    message: format!(
383                        "VaeEncoder::load_hf_state_dict: key {k:?} is not under \
384                         `encoder.*` / `quant_conv.*` (with optional `vae.` prefix) \
385                         and strict mode is on. Pass strict=false to drop decoder / \
386                         post_quant_conv keys."
387                    ),
388                });
389            }
390            dropped.push(k.clone());
391        }
392        dropped.sort();
393        self.load_state_dict(&remapped, strict)?;
394        Ok(DropReport { dropped })
395    }
396}
397
398/// Load a [`VaeEncoder`] from a VAE `diffusion_pytorch_model.safetensors`
399/// file plus a parsed config.
400///
401/// Symmetric counterpart of [`load_vae_decoder`]. `strict=false` is
402/// required for a full `AutoencoderKL` checkpoint (which ships
403/// `decoder` + `post_quant_conv` weights this encoder-only loader has
404/// no slot for). The returned [`DropReport`] captures every dropped
405/// key so a pin script can audit the drop set.
406///
407/// # Errors
408///
409/// Propagates safetensors parse errors, [`VaeEncoder`] construction
410/// errors, and any per-key shape / strict-mode mismatch from the
411/// underlying load.
412pub fn load_vae_encoder<T: Float>(
413    weights_path: &Path,
414    cfg: VaeEncoderConfig,
415    strict: bool,
416) -> FerrotorchResult<(VaeEncoder<T>, DropReport)> {
417    let state =
418        load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
419            message: format!(
420                "load_vae_encoder: failed to decode safetensors {}: {e}",
421                weights_path.display()
422            ),
423        })?;
424    let mut encoder = VaeEncoder::<T>::new(cfg)?;
425    let report = encoder.load_hf_state_dict(&state, strict)?;
426    Ok((encoder, report))
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};
    use ferrotorch_serialize::save_safetensors;
    use std::path::PathBuf;

    /// Assert two forward outputs hold the same number of elements and
    /// agree element-wise within 1e-5. The explicit count check matters:
    /// a bare `zip` stops at the shorter side and would let outputs of
    /// different length pass.
    macro_rules! assert_outputs_match {
        ($lhs:expr, $rhs:expr) => {{
            let lhs_tensor = $lhs;
            let rhs_tensor = $rhs;
            let lhs = lhs_tensor.data().unwrap();
            let rhs = rhs_tensor.data().unwrap();
            assert_eq!(
                lhs.iter().count(),
                rhs.iter().count(),
                "forward outputs have different element counts"
            );
            for (x, y) in lhs.iter().zip(rhs.iter()) {
                assert!((x - y).abs() < 1e-5);
            }
        }};
    }

    /// Small config shared by every test in this module.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// Write `sd` to `model.safetensors` inside a fresh temp dir. The
    /// returned `TempDir` must be kept alive for as long as the path is
    /// used. Keys are written as given (bare prefix, no `vae.`); the
    /// strip-`vae.` path is covered by dedicated tests below.
    fn write_state_dict(sd: &StateDict<f32>) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.safetensors");
        save_safetensors(sd, &path).unwrap();
        (dir, path)
    }

    #[test]
    fn round_trip_safetensors_into_decoder() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let (_d, p) = write_state_dict(&src.state_dict());
        let (dst, report) = load_vae_decoder::<f32>(&p, cfg.clone(), false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "round-trip should have empty drop list, got {:?}",
            report.dropped
        );
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        assert_outputs_match!(src.forward(&x).unwrap(), dst.forward(&x).unwrap());
    }

    #[test]
    fn load_hf_drops_encoder_keys_nonstrict() {
        let cfg = tiny_cfg();
        let mut v = VaeDecoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        // Add an encoder key — this should be dropped.
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        // Add a quant_conv key — also dropped.
        hf_sd.insert(
            "quant_conv.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec![
                "encoder.conv_in.weight".to_string(),
                "quant_conv.weight".to_string(),
            ]
        );
    }

    #[test]
    fn load_hf_strict_rejects_encoder_keys() {
        let cfg = tiny_cfg();
        let mut v = VaeDecoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = HashMap::new();
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    #[test]
    fn load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        // Re-prefix with `vae.` (the layout SD pipeline checkpoints use).
        let prefixed: StateDict<f32> = src
            .state_dict()
            .into_iter()
            .map(|(k, v)| (format!("vae.{k}"), v))
            .collect();
        let mut dst = VaeDecoder::<f32>::new(cfg).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        assert_outputs_match!(src.forward(&x).unwrap(), dst.forward(&x).unwrap());
    }

    #[test]
    fn round_trip_safetensors_into_encoder() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let (_d, p) = write_state_dict(&src.state_dict());
        let (dst, report) = load_vae_encoder::<f32>(&p, cfg.clone(), false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "encoder round-trip should have empty drop list, got {:?}",
            report.dropped
        );
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        assert_outputs_match!(src.forward(&x).unwrap(), dst.forward(&x).unwrap());
    }

    #[test]
    fn encoder_load_hf_drops_decoder_keys_nonstrict() {
        let cfg = tiny_cfg();
        let mut v = VaeEncoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        // Add decoder keys that the encoder should drop.
        hf_sd.insert(
            "decoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        hf_sd.insert(
            "post_quant_conv.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec![
                "decoder.conv_in.weight".to_string(),
                "post_quant_conv.weight".to_string(),
            ]
        );
    }

    #[test]
    fn encoder_load_hf_strict_rejects_decoder_keys() {
        let cfg = tiny_cfg();
        let mut v = VaeEncoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = HashMap::new();
        hf_sd.insert(
            "decoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    #[test]
    fn encoder_load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let prefixed: StateDict<f32> = src
            .state_dict()
            .into_iter()
            .map(|(k, v)| (format!("vae.{k}"), v))
            .collect();
        let mut dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        assert_outputs_match!(src.forward(&x).unwrap(), dst.forward(&x).unwrap());
    }

    #[test]
    fn full_vae_checkpoint_loadable_by_both_halves() {
        // The interesting property: a single full-VAE state dict
        // (encoder + post_quant_conv + decoder + quant_conv) is loadable
        // by both VaeDecoder and VaeEncoder, each cleanly dropping the
        // half it doesn't own.
        let cfg = tiny_cfg();
        let dec_src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let enc_src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();

        let dec_sd = dec_src.state_dict();
        let enc_sd = enc_src.state_dict();

        // Each loader is expected to drop precisely the other half's
        // keys, reported in sorted order.
        let mut decoder_keys: Vec<String> = dec_sd.keys().cloned().collect();
        decoder_keys.sort();
        let mut encoder_keys: Vec<String> = enc_sd.keys().cloned().collect();
        encoder_keys.sort();

        let mut combined: StateDict<f32> = HashMap::new();
        combined.extend(dec_sd);
        combined.extend(enc_sd);

        let mut dec_dst = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let dec_rep = dec_dst.load_hf_state_dict(&combined, false).unwrap();
        let mut enc_dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let enc_rep = enc_dst.load_hf_state_dict(&combined, false).unwrap();

        // Prefix checks first: they give a more readable failure than
        // the full-list comparison below.
        for k in &dec_rep.dropped {
            assert!(
                k.starts_with("encoder.") || k.starts_with("quant_conv."),
                "decoder dropped unexpected key: {k}"
            );
        }
        for k in &enc_rep.dropped {
            assert!(
                k.starts_with("decoder.") || k.starts_with("post_quant_conv."),
                "encoder dropped unexpected key: {k}"
            );
        }
        assert!(!dec_rep.dropped.is_empty(), "decoder should have dropped some keys");
        assert!(!enc_rep.dropped.is_empty(), "encoder should have dropped some keys");

        // Decoder should have dropped exactly the encoder + quant_conv keys.
        assert_eq!(
            dec_rep.dropped, encoder_keys,
            "decoder drop set should equal the encoder's key set"
        );
        // Encoder should have dropped exactly the decoder + post_quant_conv keys.
        assert_eq!(
            enc_rep.dropped, decoder_keys,
            "encoder drop set should equal the decoder's key set"
        );
    }
}