Skip to main content

ferrotorch_diffusion/
safetensors_loader.rs

//! Helpers that turn a path-to-safetensors into a loaded diffusion
//! component: [`VaeDecoder`], [`VaeEncoder`], [`UNet2DConditionModel`]
//! or [`ClipTextEncoder`].
//!
//! The pinned SD-1.5 VAE mirror carries the full VAE state-dict
//! (encoder + post_quant_conv + decoder + quant_conv). Decoding needs
//! only the decoder slice (`post_quant_conv.*` + `decoder.*`), so the
//! decoder loader drops everything else and returns a [`DropReport`] so
//! the pin script can audit the drop set. The other loaders in this
//! module report the keys they ignore through the same [`DropReport`].
9//!
10//! ## REQ status (per `.design/ferrotorch-diffusion/safetensors_loader.md`)
11//!
12//! | REQ | Status | Evidence |
13//! |---|---|---|
14//! | REQ-1 | SHIPPED | `DropReport` at `safetensors_loader.rs:31..36`; consumer: `safetensors_loader.rs:88..90` populates and returns it; `lib.rs:138` re-exports for pin-script auditing |
15//! | REQ-2 | SHIPPED | `VaeDecoder::load_hf_state_dict` at `safetensors_loader.rs:58..91`; consumer: `load_vae_decoder` at `safetensors_loader.rs:331` invokes it |
16//! | REQ-3 | SHIPPED | `VaeEncoder::load_hf_state_dict` at `safetensors_loader.rs:363..396`; consumer: `load_vae_encoder` at `safetensors_loader.rs:426` invokes it |
17//! | REQ-4 | SHIPPED | `UNet2DConditionModel::load_hf_state_dict` at `safetensors_loader.rs:113..148`; consumer: `load_unet` at `safetensors_loader.rs:176` invokes it |
18//! | REQ-5 | SHIPPED | `load_safetensors_clip_filtered` at `safetensors_loader.rs:191..246` and `load_clip_text_encoder` at `safetensors_loader.rs:270..302`; consumer: `examples/clip_text_encode_dump.rs:265` invokes `load_clip_text_encoder` on the SD-1.5 mirror |
19//! | REQ-6 | SHIPPED | `load_unet` at `safetensors_loader.rs:163..178`, `load_vae_decoder` at `safetensors_loader.rs:318..333`, `load_vae_encoder` at `safetensors_loader.rs:413..428`, `load_clip_text_encoder` at `safetensors_loader.rs:270..302`; consumer: all four examples (`unet_predict_dump.rs`, `vae_decode_dump.rs`, `clip_text_encode_dump.rs`, `sd_pipeline_dump.rs`) import and call them |
20
21use std::collections::HashMap;
22use std::path::Path;
23
24use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
25use ferrotorch_nn::module::{Module, StateDict};
26use ferrotorch_serialize::load_safetensors;
27
28use crate::clip_text_encoder::{ClipTextConfig, ClipTextEncoder};
29use crate::config::VaeDecoderConfig;
30use crate::unet::UNet2DConditionModel;
31use crate::unet_config::UNet2DConditionConfig;
32use crate::vae::VaeDecoder;
33use crate::vae_encoder::{VaeEncoder, VaeEncoderConfig};
34
/// Audit trail of upstream keys that a loader deliberately ignored.
///
/// Returned by the `load_hf_state_dict` methods and by the `load_*`
/// functions in this module ([`load_vae_decoder`], [`load_vae_encoder`],
/// [`load_unet`], [`load_clip_text_encoder`]). For a decoder load of a
/// full `AutoencoderKL` checkpoint the dropped keys are typically the
/// encoder + `quant_conv` weights. The pin script asserts the dropped
/// set equals the documented key surface so a silent parameter drop
/// cannot recur.
///
/// `PartialEq` / `Eq` are derived so two reports (or a report and an
/// expected value) can be compared directly.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DropReport {
    /// Keys present in the upstream state dict that the target module
    /// has no slot for, spelled exactly as they appeared upstream.
    /// Sorted for deterministic equality.
    pub dropped: Vec<String>,
}
48
49impl<T: Float> VaeDecoder<T> {
50    /// Load a HuggingFace AutoencoderKL state dict into this module.
51    ///
52    /// Accepts both:
53    ///   - `post_quant_conv.*` / `decoder.*` (bare-VAE layout, the
54    ///     normalised form the pin script produces)
55    ///   - `vae.post_quant_conv.*` / `vae.decoder.*` (when bundled
56    ///     inside a full SD pipeline checkpoint)
57    ///
58    /// Any other key (encoder, `quant_conv`, etc.) is recorded in the
59    /// returned [`DropReport`] (or, in strict mode, surfaces as
60    /// [`FerrotorchError::InvalidArgument`]).
61    ///
62    /// # Errors
63    ///
64    /// Forwards whatever each sub-module's `load_state_dict` returns
65    /// (`ShapeMismatch` on a wrong-shape tensor, `InvalidArgument` in
66    /// strict mode when a required tensor is missing). Strict mode will
67    /// surface `encoder.*` / `quant_conv.*` / etc. as errors; callers
68    /// with a full VAE checkpoint must pass `strict=false`.
69    pub fn load_hf_state_dict(
70        &mut self,
71        hf_state: &StateDict<T>,
72        strict: bool,
73    ) -> FerrotorchResult<DropReport> {
74        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
75        let mut dropped: Vec<String> = Vec::new();
76
77        for (k, v) in hf_state {
78            // Try (a) bare-VAE prefix → as-is; (b) full-pipeline
79            // `vae.<rest>` prefix → strip the `vae.` and accept.
80            let after_vae = k
81                .strip_prefix("vae.")
82                .map_or_else(|| k.clone(), str::to_owned);
83            if after_vae.starts_with("post_quant_conv.") || after_vae.starts_with("decoder.") {
84                remapped.insert(after_vae, v.clone());
85                continue;
86            }
87            if strict {
88                return Err(FerrotorchError::InvalidArgument {
89                    message: format!(
90                        "VaeDecoder::load_hf_state_dict: key {k:?} is not under \
91                         `post_quant_conv.*` / `decoder.*` (with optional `vae.` prefix) \
92                         and strict mode is on. Pass strict=false to drop encoder / \
93                         quant_conv keys."
94                    ),
95                });
96            }
97            dropped.push(k.clone());
98        }
99        dropped.sort();
100        self.load_state_dict(&remapped, strict)?;
101        Ok(DropReport { dropped })
102    }
103}
104
105// ---------------------------------------------------------------------------
106// UNet2DConditionModel loader
107// ---------------------------------------------------------------------------
108
109impl<T: Float> UNet2DConditionModel<T> {
110    /// Load a HuggingFace UNet state dict into this module.
111    ///
112    /// Accepts both:
113    ///   - bare-UNet layout (the pin script normalises to this form)
114    ///   - `unet.<rest>` prefix (full SD pipeline checkpoint)
115    ///
116    /// Any unrecognised key is recorded in the returned [`DropReport`]
117    /// (or surfaces as [`FerrotorchError::InvalidArgument`] in strict
118    /// mode).
119    ///
120    /// # Errors
121    ///
122    /// Forwards whatever each sub-module's `load_state_dict` returns
123    /// (shape mismatch / strict-mode missing key).
124    pub fn load_hf_state_dict(
125        &mut self,
126        hf_state: &StateDict<T>,
127        strict: bool,
128    ) -> FerrotorchResult<DropReport> {
129        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
130        let mut dropped: Vec<String> = Vec::new();
131        for (k, v) in hf_state {
132            let after_unet = k
133                .strip_prefix("unet.")
134                .map_or_else(|| k.clone(), str::to_owned);
135            let is_unet_key = after_unet.starts_with("time_embedding.")
136                || after_unet.starts_with("conv_in.")
137                || after_unet.starts_with("down_blocks.")
138                || after_unet.starts_with("mid_block.")
139                || after_unet.starts_with("up_blocks.")
140                || after_unet.starts_with("conv_norm_out.")
141                || after_unet.starts_with("conv_out.");
142            if is_unet_key {
143                remapped.insert(after_unet, v.clone());
144                continue;
145            }
146            if strict {
147                return Err(FerrotorchError::InvalidArgument {
148                    message: format!(
149                        "UNet2DConditionModel::load_hf_state_dict: key {k:?} is not under \
150                         a UNet prefix (with optional `unet.`) and strict mode is on."
151                    ),
152                });
153            }
154            dropped.push(k.clone());
155        }
156        dropped.sort();
157        self.load_state_dict(&remapped, strict)?;
158        Ok(DropReport { dropped })
159    }
160}
161
162/// Load a [`UNet2DConditionModel`] from a UNet
163/// `diffusion_pytorch_model.safetensors` file plus a parsed config.
164///
165/// `strict=false` is required when loading a full SD pipeline
166/// checkpoint (which carries `vae.*` / `text_encoder.*` keys); for a
167/// bare-UNet mirror (the form `pin_pretrained_diffusion_weights.py`
168/// uploads) `strict=true` is fine.
169///
170/// # Errors
171///
172/// Propagates safetensors parse errors, [`UNet2DConditionModel`]
173/// construction errors, and any per-key shape / strict-mode mismatch.
174pub fn load_unet<T: Float>(
175    weights_path: &Path,
176    cfg: UNet2DConditionConfig,
177    strict: bool,
178) -> FerrotorchResult<(UNet2DConditionModel<T>, DropReport)> {
179    let state =
180        load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
181            message: format!(
182                "load_unet: failed to decode safetensors {}: {e}",
183                weights_path.display()
184            ),
185        })?;
186    let mut unet = UNet2DConditionModel::<T>::new(cfg)?;
187    let report = unet.load_hf_state_dict(&state, strict)?;
188    Ok((unet, report))
189}
190
191// ---------------------------------------------------------------------------
192// ClipTextEncoder loader
193// ---------------------------------------------------------------------------
194
195/// Read a single non-sharded safetensors file into a typed `StateDict`,
196/// dropping any int64 `position_ids` buffer BEFORE the generic-`T`
197/// decode. The CLIP-text checkpoint ships an
198/// `embeddings.position_ids` (or `text_model.embeddings.position_ids`)
199/// `[1, 77]` int64 buffer that would poison a `load_safetensors::<f32>`
200/// pass because i64 is not representable as f32 in the underlying
201/// dispatch. Mirrors the trick `ferrotorch-bert`'s loader uses.
202fn load_safetensors_clip_filtered<T: Float>(
203    weights_path: &Path,
204) -> FerrotorchResult<(StateDict<T>, bool)> {
205    use safetensors::SafeTensors;
206
207    let bytes = std::fs::read(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
208        message: format!(
209            "load_safetensors_clip_filtered: failed to read {}: {e}",
210            weights_path.display()
211        ),
212    })?;
213    let st = SafeTensors::deserialize(&bytes).map_err(|e| FerrotorchError::InvalidArgument {
214        message: format!(
215            "load_safetensors_clip_filtered: failed to parse {}: {e}",
216            weights_path.display()
217        ),
218    })?;
219    let mut keep: Vec<String> = Vec::new();
220    let mut had_position_ids = false;
221    for k in st.names() {
222        let s: &str = k.as_str();
223        // The position_ids buffer is the only int64 surface in
224        // CLIPTextModel and it has no parameter slot on our side.
225        if s == "embeddings.position_ids" || s == "text_model.embeddings.position_ids" {
226            had_position_ids = true;
227            continue;
228        }
229        keep.push(String::from(s));
230    }
231
232    // Re-serialize only the kept tensors into an in-memory safetensors
233    // blob and feed that to `load_safetensors::<T>`. Reuses the audited
234    // generic decoder instead of re-implementing dtype dispatch here.
235    let mut subset: Vec<(String, safetensors::tensor::TensorView<'_>)> =
236        Vec::with_capacity(keep.len());
237    for k in &keep {
238        let v = st.tensor(k).map_err(|e| FerrotorchError::InvalidArgument {
239            message: format!(
240                "load_safetensors_clip_filtered: missing tensor {k:?} after filter: {e}"
241            ),
242        })?;
243        subset.push((k.clone(), v));
244    }
245    let serialized =
246        safetensors::serialize(subset, &None).map_err(|e| FerrotorchError::InvalidArgument {
247            message: format!("load_safetensors_clip_filtered: re-serialize failed: {e}"),
248        })?;
249    let tmp = tempfile::NamedTempFile::new().map_err(|e| FerrotorchError::InvalidArgument {
250        message: format!("load_safetensors_clip_filtered: tempfile: {e}"),
251    })?;
252    std::fs::write(tmp.path(), &serialized).map_err(|e| FerrotorchError::InvalidArgument {
253        message: format!("load_safetensors_clip_filtered: tempfile write: {e}"),
254    })?;
255    let state = load_safetensors::<T>(tmp.path())?;
256    Ok((state, had_position_ids))
257}
258
259/// Load a [`ClipTextEncoder`] from a CLIP text-tower
260/// `model.safetensors` file plus a parsed [`ClipTextConfig`].
261///
262/// Accepts both upstream layouts:
263///   - bare `embeddings.* / encoder.* / final_layer_norm.*` (what the
264///     pin script normalises to).
265///   - `text_model.<rest>` prefix (what the upstream HF checkpoint
266///     ships).
267///
268/// The int64 `embeddings.position_ids` buffer (a `[1, max_pos]`
269/// `arange(max_pos)` constant regenerated each forward pass) is
270/// dropped at decode time and surfaced via the returned
271/// [`DropReport`].
272///
273/// `strict=false` is required when the upstream checkpoint carries the
274/// position_ids buffer (the default for `runwayml/stable-diffusion-v1-5`'s
275/// `text_encoder/model.safetensors`).
276///
277/// # Errors
278///
279/// Propagates safetensors parse errors, [`ClipTextEncoder`] construction
280/// errors, and any per-key shape / strict-mode mismatch.
281pub fn load_clip_text_encoder<T: Float>(
282    weights_path: &Path,
283    cfg: ClipTextConfig,
284    strict: bool,
285) -> FerrotorchResult<(ClipTextEncoder<T>, DropReport)> {
286    let (mut state, had_position_ids) =
287        load_safetensors_clip_filtered::<T>(weights_path).map_err(|e| {
288            FerrotorchError::InvalidArgument {
289                message: format!(
290                    "load_clip_text_encoder: failed to decode safetensors {}: {e}",
291                    weights_path.display()
292                ),
293            }
294        })?;
295
296    // Re-insert a placeholder entry for the position_ids buffer (with
297    // the upstream key it actually used) so the model's DropReport
298    // captures it as an intentionally-dropped upstream key. The
299    // placeholder tensor is never consumed — `load_hf_state_dict`
300    // drops the entry before any parameter slot sees it.
301    if had_position_ids {
302        let key = if state.keys().any(|k| k.starts_with("text_model.")) {
303            "text_model.embeddings.position_ids".to_string()
304        } else {
305            "embeddings.position_ids".to_string()
306        };
307        state.insert(key, ferrotorch_core::zeros::<T>(&[1])?);
308    }
309
310    let mut enc = ClipTextEncoder::<T>::new(cfg)?;
311    let report = enc.load_hf_state_dict(&state, strict)?;
312    Ok((enc, report))
313}
314
315/// Load a [`VaeDecoder`] from a VAE `diffusion_pytorch_model.safetensors`
316/// file plus a parsed config.
317///
318/// `strict=false` is required for a full `AutoencoderKL` checkpoint
319/// (which ships encoder + quant_conv weights this decoder-only loader
320/// has no slot for). The returned [`DropReport`] captures every
321/// dropped key so the pin script can confirm the drop set is exactly
322/// the documented encoder/quant_conv surface.
323///
324/// # Errors
325///
326/// Propagates safetensors parse errors, [`VaeDecoder`] construction
327/// errors, and any per-key shape / strict-mode mismatch from the
328/// underlying load.
329pub fn load_vae_decoder<T: Float>(
330    weights_path: &Path,
331    cfg: VaeDecoderConfig,
332    strict: bool,
333) -> FerrotorchResult<(VaeDecoder<T>, DropReport)> {
334    let state =
335        load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
336            message: format!(
337                "load_vae_decoder: failed to decode safetensors {}: {e}",
338                weights_path.display()
339            ),
340        })?;
341    let mut decoder = VaeDecoder::<T>::new(cfg)?;
342    let report = decoder.load_hf_state_dict(&state, strict)?;
343    Ok((decoder, report))
344}
345
346// ---------------------------------------------------------------------------
347// VaeEncoder loader
348// ---------------------------------------------------------------------------
349
350impl<T: Float> VaeEncoder<T> {
351    /// Load a HuggingFace AutoencoderKL state dict into this module.
352    ///
353    /// Accepts both:
354    ///   - `encoder.*` / `quant_conv.*` (bare-VAE layout, the normalised
355    ///     form produced by the pin script for an encoder-only artifact)
356    ///   - `vae.encoder.*` / `vae.quant_conv.*` (full SD pipeline
357    ///     checkpoint, e.g. an upstream `runwayml/stable-diffusion-v1-5`
358    ///     `model.safetensors`)
359    ///
360    /// Any other key (decoder, `post_quant_conv`, UNet, text encoder, …)
361    /// is recorded in the returned [`DropReport`] (or, in strict mode,
362    /// surfaces as [`FerrotorchError::InvalidArgument`]). This mirrors
363    /// [`VaeDecoder::load_hf_state_dict`] so the same full-checkpoint
364    /// file can be loaded twice — once for the encoder and once for the
365    /// decoder — each time dropping the other half.
366    ///
367    /// # Errors
368    ///
369    /// Forwards whatever each sub-module's `load_state_dict` returns
370    /// (`ShapeMismatch` on a wrong-shape tensor, `InvalidArgument` in
371    /// strict mode when a required tensor is missing). Strict mode will
372    /// surface `decoder.*` / `post_quant_conv.*` / etc. as errors;
373    /// callers with a full VAE checkpoint must pass `strict=false`.
374    pub fn load_hf_state_dict(
375        &mut self,
376        hf_state: &StateDict<T>,
377        strict: bool,
378    ) -> FerrotorchResult<DropReport> {
379        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
380        let mut dropped: Vec<String> = Vec::new();
381
382        for (k, v) in hf_state {
383            // (a) bare-VAE prefix → as-is; (b) full-pipeline `vae.<rest>`
384            // prefix → strip the `vae.` and accept.
385            let after_vae = k
386                .strip_prefix("vae.")
387                .map_or_else(|| k.clone(), str::to_owned);
388            if after_vae.starts_with("encoder.") || after_vae.starts_with("quant_conv.") {
389                remapped.insert(after_vae, v.clone());
390                continue;
391            }
392            if strict {
393                return Err(FerrotorchError::InvalidArgument {
394                    message: format!(
395                        "VaeEncoder::load_hf_state_dict: key {k:?} is not under \
396                         `encoder.*` / `quant_conv.*` (with optional `vae.` prefix) \
397                         and strict mode is on. Pass strict=false to drop decoder / \
398                         post_quant_conv keys."
399                    ),
400                });
401            }
402            dropped.push(k.clone());
403        }
404        dropped.sort();
405        self.load_state_dict(&remapped, strict)?;
406        Ok(DropReport { dropped })
407    }
408}
409
410/// Load a [`VaeEncoder`] from a VAE `diffusion_pytorch_model.safetensors`
411/// file plus a parsed config.
412///
413/// Symmetric counterpart of [`load_vae_decoder`]. `strict=false` is
414/// required for a full `AutoencoderKL` checkpoint (which ships
415/// `decoder` + `post_quant_conv` weights this encoder-only loader has
416/// no slot for). The returned [`DropReport`] captures every dropped
417/// key so a pin script can audit the drop set.
418///
419/// # Errors
420///
421/// Propagates safetensors parse errors, [`VaeEncoder`] construction
422/// errors, and any per-key shape / strict-mode mismatch from the
423/// underlying load.
424pub fn load_vae_encoder<T: Float>(
425    weights_path: &Path,
426    cfg: VaeEncoderConfig,
427    strict: bool,
428) -> FerrotorchResult<(VaeEncoder<T>, DropReport)> {
429    let state =
430        load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
431            message: format!(
432                "load_vae_encoder: failed to decode safetensors {}: {e}",
433                weights_path.display()
434            ),
435        })?;
436    let mut encoder = VaeEncoder::<T>::new(cfg)?;
437    let report = encoder.load_hf_state_dict(&state, strict)?;
438    Ok((encoder, report))
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};
    use ferrotorch_serialize::save_safetensors;
    use std::path::PathBuf;

    /// Small VAE config shared by every test in this module (also used
    /// to construct the encoder).
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// Write `sd` to `model.safetensors` inside a fresh temp directory.
    ///
    /// The directory handle is returned alongside the path; the caller
    /// must keep it alive for as long as the file is needed.
    fn save_to_tmp(sd: &StateDict<f32>) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.safetensors");
        save_safetensors(sd, &path).unwrap();
        (dir, path)
    }

    /// Element-wise comparison of two output buffers.
    ///
    /// Lengths are checked first: a bare `zip` would silently stop at
    /// the shorter buffer and let a truncated output pass.
    fn assert_close(lhs: &[f32], rhs: &[f32]) {
        assert_eq!(lhs.len(), rhs.len(), "output element counts differ");
        for (i, (x, y)) in lhs.iter().zip(rhs.iter()).enumerate() {
            assert!((x - y).abs() < 1e-5, "element {i} differs: {x} vs {y}");
        }
    }

    /// Keys of `sd` in sorted order, matching how `DropReport::dropped`
    /// is ordered.
    fn sorted_keys(sd: &StateDict<f32>) -> Vec<String> {
        let mut keys: Vec<String> = sd.keys().cloned().collect();
        keys.sort();
        keys
    }

    #[test]
    fn round_trip_safetensors_into_decoder() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        // The on-disk file uses the bare-VAE prefix (no `vae.`); the
        // loader's strip-vae path is exercised by the dedicated test
        // below.
        let (_dir, path) = save_to_tmp(&src.state_dict());
        let (dst, report) = load_vae_decoder::<f32>(&path, cfg, false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "round-trip should have empty drop list, got {:?}",
            report.dropped
        );
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_close(&a.data().unwrap(), &b.data().unwrap());
    }

    #[test]
    fn load_hf_drops_encoder_keys_nonstrict() {
        let mut v = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        // Add an encoder key — this should be dropped.
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        // Add a quant_conv key — also dropped.
        hf_sd.insert(
            "quant_conv.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec![
                "encoder.conv_in.weight".to_string(),
                "quant_conv.weight".to_string(),
            ]
        );
    }

    #[test]
    fn load_hf_strict_rejects_encoder_keys() {
        let mut v = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let mut hf_sd: StateDict<f32> = HashMap::new();
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    #[test]
    fn load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        // Re-prefix with `vae.` (the layout SD pipeline checkpoints use).
        let prefixed: StateDict<f32> = src
            .state_dict()
            .into_iter()
            .map(|(k, v)| (format!("vae.{k}"), v))
            .collect();
        let mut dst = VaeDecoder::<f32>::new(cfg).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_close(&a.data().unwrap(), &b.data().unwrap());
    }

    #[test]
    fn round_trip_safetensors_into_encoder() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let (_dir, path) = save_to_tmp(&src.state_dict());
        let (dst, report) = load_vae_encoder::<f32>(&path, cfg, false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "encoder round-trip should have empty drop list, got {:?}",
            report.dropped
        );
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_close(&a.data().unwrap(), &b.data().unwrap());
    }

    #[test]
    fn encoder_load_hf_drops_decoder_keys_nonstrict() {
        let mut v = VaeEncoder::<f32>::new(tiny_cfg()).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        // Add decoder keys that the encoder should drop.
        hf_sd.insert(
            "decoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        hf_sd.insert(
            "post_quant_conv.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec![
                "decoder.conv_in.weight".to_string(),
                "post_quant_conv.weight".to_string(),
            ]
        );
    }

    #[test]
    fn encoder_load_hf_strict_rejects_decoder_keys() {
        let mut v = VaeEncoder::<f32>::new(tiny_cfg()).unwrap();
        let mut hf_sd: StateDict<f32> = HashMap::new();
        hf_sd.insert(
            "decoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    #[test]
    fn encoder_load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let prefixed: StateDict<f32> = src
            .state_dict()
            .into_iter()
            .map(|(k, v)| (format!("vae.{k}"), v))
            .collect();
        let mut dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_close(&a.data().unwrap(), &b.data().unwrap());
    }

    #[test]
    fn full_vae_checkpoint_loadable_by_both_halves() {
        // The interesting property: a single full-VAE state dict
        // (encoder + post_quant_conv + decoder + quant_conv) is loadable
        // by both VaeDecoder and VaeEncoder, each cleanly dropping the
        // half it doesn't own.
        let cfg = tiny_cfg();
        let dec_sd = VaeDecoder::<f32>::new(cfg.clone()).unwrap().state_dict();
        let enc_sd = VaeEncoder::<f32>::new(cfg.clone()).unwrap().state_dict();
        let dec_keys = sorted_keys(&dec_sd);
        let enc_keys = sorted_keys(&enc_sd);

        let mut combined: StateDict<f32> = HashMap::new();
        combined.extend(dec_sd);
        combined.extend(enc_sd);

        let mut dec_dst = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let dec_rep = dec_dst.load_hf_state_dict(&combined, false).unwrap();
        let mut enc_dst = VaeEncoder::<f32>::new(cfg).unwrap();
        let enc_rep = enc_dst.load_hf_state_dict(&combined, false).unwrap();

        // Prefix checks first: they name the offending key on failure.
        for k in &dec_rep.dropped {
            assert!(
                k.starts_with("encoder.") || k.starts_with("quant_conv."),
                "decoder dropped unexpected key: {k}"
            );
        }
        for k in &enc_rep.dropped {
            assert!(
                k.starts_with("decoder.") || k.starts_with("post_quant_conv."),
                "encoder dropped unexpected key: {k}"
            );
        }
        assert!(
            !dec_rep.dropped.is_empty(),
            "decoder should have dropped some keys"
        );
        assert!(
            !enc_rep.dropped.is_empty(),
            "encoder should have dropped some keys"
        );

        // "Exactly" means set equality, not just a prefix match: each
        // half must drop every key owned by the other half and nothing
        // else.
        assert_eq!(
            dec_rep.dropped, enc_keys,
            "decoder should drop exactly the encoder + quant_conv keys"
        );
        assert_eq!(
            enc_rep.dropped, dec_keys,
            "encoder should drop exactly the decoder + post_quant_conv keys"
        );
    }
}