Skip to main content

ferrotorch_diffusion/
safetensors_loader.rs

//! Helpers that turn a path-to-safetensors into a loaded
//! [`VaeDecoder`], [`UNet2DConditionModel`], or [`ClipTextEncoder`].
//!
//! The pinned SD-1.5 VAE mirror carries the full VAE state-dict
//! (encoder + post_quant_conv + decoder + quant_conv). Inference needs
//! only the decoder slice (`post_quant_conv.*` + `decoder.*`). The VAE
//! loader drops everything else and returns a [`DropReport`] so the pin
//! script can audit the drop set. The UNet and CLIP text-encoder
//! loaders return a [`DropReport`] as well.
9
10use std::collections::HashMap;
11use std::path::Path;
12
13use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
14use ferrotorch_nn::module::{Module, StateDict};
15use ferrotorch_serialize::load_safetensors;
16
17use crate::clip_text_encoder::{ClipTextConfig, ClipTextEncoder};
18use crate::config::VaeDecoderConfig;
19use crate::unet::UNet2DConditionModel;
20use crate::unet_config::UNet2DConditionConfig;
21use crate::vae::VaeDecoder;
22
/// Audit trail returned by the loaders in this module
/// ([`load_vae_decoder`], [`load_unet`], [`load_clip_text_encoder`]) and
/// by the `load_hf_state_dict` methods they call.
///
/// Records upstream HF keys that were dropped because the target module
/// has no slot for them (for the VAE decoder, typically the encoder +
/// `quant_conv` weights of a full `AutoencoderKL` checkpoint). The pin
/// script asserts the dropped set equals the documented key surface so a
/// silent parameter drop cannot recur.
///
/// Implements `PartialEq` / `Eq` so two reports can be compared
/// directly; because `dropped` is sorted, the comparison is
/// deterministic.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DropReport {
    /// Keys present in the upstream state dict that were not loaded
    /// into the target module. Sorted for deterministic equality.
    pub dropped: Vec<String>,
}
36
37impl<T: Float> VaeDecoder<T> {
38    /// Load a HuggingFace AutoencoderKL state dict into this module.
39    ///
40    /// Accepts both:
41    ///   - `post_quant_conv.*` / `decoder.*` (bare-VAE layout, the
42    ///     normalised form the pin script produces)
43    ///   - `vae.post_quant_conv.*` / `vae.decoder.*` (when bundled
44    ///     inside a full SD pipeline checkpoint)
45    ///
46    /// Any other key (encoder, `quant_conv`, etc.) is recorded in the
47    /// returned [`DropReport`] (or, in strict mode, surfaces as
48    /// [`FerrotorchError::InvalidArgument`]).
49    ///
50    /// # Errors
51    ///
52    /// Forwards whatever each sub-module's `load_state_dict` returns
53    /// (`ShapeMismatch` on a wrong-shape tensor, `InvalidArgument` in
54    /// strict mode when a required tensor is missing). Strict mode will
55    /// surface `encoder.*` / `quant_conv.*` / etc. as errors; callers
56    /// with a full VAE checkpoint must pass `strict=false`.
57    pub fn load_hf_state_dict(
58        &mut self,
59        hf_state: &StateDict<T>,
60        strict: bool,
61    ) -> FerrotorchResult<DropReport> {
62        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
63        let mut dropped: Vec<String> = Vec::new();
64
65        for (k, v) in hf_state {
66            // Try (a) bare-VAE prefix → as-is; (b) full-pipeline
67            // `vae.<rest>` prefix → strip the `vae.` and accept.
68            let after_vae = k.strip_prefix("vae.").map_or_else(|| k.clone(), str::to_owned);
69            if after_vae.starts_with("post_quant_conv.") || after_vae.starts_with("decoder.") {
70                remapped.insert(after_vae, v.clone());
71                continue;
72            }
73            if strict {
74                return Err(FerrotorchError::InvalidArgument {
75                    message: format!(
76                        "VaeDecoder::load_hf_state_dict: key {k:?} is not under \
77                         `post_quant_conv.*` / `decoder.*` (with optional `vae.` prefix) \
78                         and strict mode is on. Pass strict=false to drop encoder / \
79                         quant_conv keys."
80                    ),
81                });
82            }
83            dropped.push(k.clone());
84        }
85        dropped.sort();
86        self.load_state_dict(&remapped, strict)?;
87        Ok(DropReport { dropped })
88    }
89}
90
91// ---------------------------------------------------------------------------
92// UNet2DConditionModel loader
93// ---------------------------------------------------------------------------
94
95impl<T: Float> UNet2DConditionModel<T> {
96    /// Load a HuggingFace UNet state dict into this module.
97    ///
98    /// Accepts both:
99    ///   - bare-UNet layout (the pin script normalises to this form)
100    ///   - `unet.<rest>` prefix (full SD pipeline checkpoint)
101    ///
102    /// Any unrecognised key is recorded in the returned [`DropReport`]
103    /// (or surfaces as [`FerrotorchError::InvalidArgument`] in strict
104    /// mode).
105    ///
106    /// # Errors
107    ///
108    /// Forwards whatever each sub-module's `load_state_dict` returns
109    /// (shape mismatch / strict-mode missing key).
110    pub fn load_hf_state_dict(
111        &mut self,
112        hf_state: &StateDict<T>,
113        strict: bool,
114    ) -> FerrotorchResult<DropReport> {
115        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
116        let mut dropped: Vec<String> = Vec::new();
117        for (k, v) in hf_state {
118            let after_unet = k.strip_prefix("unet.").map_or_else(|| k.clone(), str::to_owned);
119            let is_unet_key = after_unet.starts_with("time_embedding.")
120                || after_unet.starts_with("conv_in.")
121                || after_unet.starts_with("down_blocks.")
122                || after_unet.starts_with("mid_block.")
123                || after_unet.starts_with("up_blocks.")
124                || after_unet.starts_with("conv_norm_out.")
125                || after_unet.starts_with("conv_out.");
126            if is_unet_key {
127                remapped.insert(after_unet, v.clone());
128                continue;
129            }
130            if strict {
131                return Err(FerrotorchError::InvalidArgument {
132                    message: format!(
133                        "UNet2DConditionModel::load_hf_state_dict: key {k:?} is not under \
134                         a UNet prefix (with optional `unet.`) and strict mode is on."
135                    ),
136                });
137            }
138            dropped.push(k.clone());
139        }
140        dropped.sort();
141        self.load_state_dict(&remapped, strict)?;
142        Ok(DropReport { dropped })
143    }
144}
145
146/// Load a [`UNet2DConditionModel`] from a UNet
147/// `diffusion_pytorch_model.safetensors` file plus a parsed config.
148///
149/// `strict=false` is required when loading a full SD pipeline
150/// checkpoint (which carries `vae.*` / `text_encoder.*` keys); for a
151/// bare-UNet mirror (the form `pin_pretrained_diffusion_weights.py`
152/// uploads) `strict=true` is fine.
153///
154/// # Errors
155///
156/// Propagates safetensors parse errors, [`UNet2DConditionModel`]
157/// construction errors, and any per-key shape / strict-mode mismatch.
158pub fn load_unet<T: Float>(
159    weights_path: &Path,
160    cfg: UNet2DConditionConfig,
161    strict: bool,
162) -> FerrotorchResult<(UNet2DConditionModel<T>, DropReport)> {
163    let state =
164        load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
165            message: format!(
166                "load_unet: failed to decode safetensors {}: {e}",
167                weights_path.display()
168            ),
169        })?;
170    let mut unet = UNet2DConditionModel::<T>::new(cfg)?;
171    let report = unet.load_hf_state_dict(&state, strict)?;
172    Ok((unet, report))
173}
174
175// ---------------------------------------------------------------------------
176// ClipTextEncoder loader
177// ---------------------------------------------------------------------------
178
179/// Read a single non-sharded safetensors file into a typed `StateDict`,
180/// dropping any int64 `position_ids` buffer BEFORE the generic-`T`
181/// decode. The CLIP-text checkpoint ships an
182/// `embeddings.position_ids` (or `text_model.embeddings.position_ids`)
183/// `[1, 77]` int64 buffer that would poison a `load_safetensors::<f32>`
184/// pass because i64 is not representable as f32 in the underlying
185/// dispatch. Mirrors the trick `ferrotorch-bert`'s loader uses.
186fn load_safetensors_clip_filtered<T: Float>(
187    weights_path: &Path,
188) -> FerrotorchResult<(StateDict<T>, bool)> {
189    use safetensors::SafeTensors;
190
191    let bytes =
192        std::fs::read(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
193            message: format!(
194                "load_safetensors_clip_filtered: failed to read {}: {e}",
195                weights_path.display()
196            ),
197        })?;
198    let st = SafeTensors::deserialize(&bytes).map_err(|e| FerrotorchError::InvalidArgument {
199        message: format!(
200            "load_safetensors_clip_filtered: failed to parse {}: {e}",
201            weights_path.display()
202        ),
203    })?;
204    let mut keep: Vec<String> = Vec::new();
205    let mut had_position_ids = false;
206    for k in st.names() {
207        let s: &str = k.as_str();
208        // The position_ids buffer is the only int64 surface in
209        // CLIPTextModel and it has no parameter slot on our side.
210        if s == "embeddings.position_ids" || s == "text_model.embeddings.position_ids" {
211            had_position_ids = true;
212            continue;
213        }
214        keep.push(String::from(s));
215    }
216
217    // Re-serialize only the kept tensors into an in-memory safetensors
218    // blob and feed that to `load_safetensors::<T>`. Reuses the audited
219    // generic decoder instead of re-implementing dtype dispatch here.
220    let mut subset: Vec<(String, safetensors::tensor::TensorView<'_>)> =
221        Vec::with_capacity(keep.len());
222    for k in &keep {
223        let v = st.tensor(k).map_err(|e| FerrotorchError::InvalidArgument {
224            message: format!(
225                "load_safetensors_clip_filtered: missing tensor {k:?} after filter: {e}"
226            ),
227        })?;
228        subset.push((k.clone(), v));
229    }
230    let serialized = safetensors::serialize(subset, &None).map_err(|e| {
231        FerrotorchError::InvalidArgument {
232            message: format!("load_safetensors_clip_filtered: re-serialize failed: {e}"),
233        }
234    })?;
235    let tmp = tempfile::NamedTempFile::new().map_err(|e| FerrotorchError::InvalidArgument {
236        message: format!("load_safetensors_clip_filtered: tempfile: {e}"),
237    })?;
238    std::fs::write(tmp.path(), &serialized).map_err(|e| FerrotorchError::InvalidArgument {
239        message: format!("load_safetensors_clip_filtered: tempfile write: {e}"),
240    })?;
241    let state = load_safetensors::<T>(tmp.path())?;
242    Ok((state, had_position_ids))
243}
244
245/// Load a [`ClipTextEncoder`] from a CLIP text-tower
246/// `model.safetensors` file plus a parsed [`ClipTextConfig`].
247///
248/// Accepts both upstream layouts:
249///   - bare `embeddings.* / encoder.* / final_layer_norm.*` (what the
250///     pin script normalises to).
251///   - `text_model.<rest>` prefix (what the upstream HF checkpoint
252///     ships).
253///
254/// The int64 `embeddings.position_ids` buffer (a `[1, max_pos]`
255/// `arange(max_pos)` constant regenerated each forward pass) is
256/// dropped at decode time and surfaced via the returned
257/// [`DropReport`].
258///
259/// `strict=false` is required when the upstream checkpoint carries the
260/// position_ids buffer (the default for `runwayml/stable-diffusion-v1-5`'s
261/// `text_encoder/model.safetensors`).
262///
263/// # Errors
264///
265/// Propagates safetensors parse errors, [`ClipTextEncoder`] construction
266/// errors, and any per-key shape / strict-mode mismatch.
267pub fn load_clip_text_encoder<T: Float>(
268    weights_path: &Path,
269    cfg: ClipTextConfig,
270    strict: bool,
271) -> FerrotorchResult<(ClipTextEncoder<T>, DropReport)> {
272    let (mut state, had_position_ids) =
273        load_safetensors_clip_filtered::<T>(weights_path).map_err(|e| {
274            FerrotorchError::InvalidArgument {
275                message: format!(
276                    "load_clip_text_encoder: failed to decode safetensors {}: {e}",
277                    weights_path.display()
278                ),
279            }
280        })?;
281
282    // Re-insert a placeholder entry for the position_ids buffer (with
283    // the upstream key it actually used) so the model's DropReport
284    // captures it as an intentionally-dropped upstream key. The
285    // placeholder tensor is never consumed — `load_hf_state_dict`
286    // drops the entry before any parameter slot sees it.
287    if had_position_ids {
288        let key = if state
289            .keys()
290            .any(|k| k.starts_with("text_model."))
291        {
292            "text_model.embeddings.position_ids".to_string()
293        } else {
294            "embeddings.position_ids".to_string()
295        };
296        state.insert(key, ferrotorch_core::zeros::<T>(&[1])?);
297    }
298
299    let mut enc = ClipTextEncoder::<T>::new(cfg)?;
300    let report = enc.load_hf_state_dict(&state, strict)?;
301    Ok((enc, report))
302}
303
304/// Load a [`VaeDecoder`] from a VAE `diffusion_pytorch_model.safetensors`
305/// file plus a parsed config.
306///
307/// `strict=false` is required for a full `AutoencoderKL` checkpoint
308/// (which ships encoder + quant_conv weights this decoder-only loader
309/// has no slot for). The returned [`DropReport`] captures every
310/// dropped key so the pin script can confirm the drop set is exactly
311/// the documented encoder/quant_conv surface.
312///
313/// # Errors
314///
315/// Propagates safetensors parse errors, [`VaeDecoder`] construction
316/// errors, and any per-key shape / strict-mode mismatch from the
317/// underlying load.
318pub fn load_vae_decoder<T: Float>(
319    weights_path: &Path,
320    cfg: VaeDecoderConfig,
321    strict: bool,
322) -> FerrotorchResult<(VaeDecoder<T>, DropReport)> {
323    let state =
324        load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
325            message: format!(
326                "load_vae_decoder: failed to decode safetensors {}: {e}",
327                weights_path.display()
328            ),
329        })?;
330    let mut decoder = VaeDecoder::<T>::new(cfg)?;
331    let report = decoder.load_hf_state_dict(&state, strict)?;
332    Ok((decoder, report))
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};
    use ferrotorch_serialize::save_safetensors;
    use std::path::PathBuf;

    /// Smallest decoder configuration the tests build; keeps every test
    /// fast while still exercising the full load path.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// Save `v`'s state dict to `model.safetensors` inside a fresh temp
    /// directory. The `TempDir` guard is returned alongside the path so
    /// the caller keeps the directory alive for the test's duration.
    fn tmp_safetensors_from(v: &VaeDecoder<f32>) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.safetensors");
        // The on-disk file uses the bare-VAE prefix (no `vae.`); the
        // loader's strip-vae path is exercised by the dedicated test
        // below.
        let sd = v.state_dict();
        save_safetensors(&sd, &path).unwrap();
        (dir, path)
    }

    /// Run both decoders on the same `[1, 4, 1, 1]` probe latent and
    /// assert their outputs agree element-wise within `1e-5`.
    ///
    /// The lengths are compared first: `zip` stops at the shorter side,
    /// so without that check a truncated output would pass silently.
    fn assert_same_forward(src: &VaeDecoder<f32>, dst: &VaeDecoder<f32>) {
        let probe = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        let expected = src.forward(&probe).unwrap();
        let actual = dst.forward(&probe).unwrap();
        let expected_data = expected.data().unwrap();
        let actual_data = actual.data().unwrap();
        assert_eq!(
            expected_data.len(),
            actual_data.len(),
            "decoder outputs differ in element count"
        );
        for (lhs, rhs) in expected_data.iter().zip(actual_data.iter()) {
            assert!((lhs - rhs).abs() < 1e-5);
        }
    }

    #[test]
    fn round_trip_safetensors_into_decoder() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let (_d, p) = tmp_safetensors_from(&src);
        let (dst, report) = load_vae_decoder::<f32>(&p, cfg.clone(), false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "round-trip should have empty drop list, got {:?}",
            report.dropped
        );
        assert_same_forward(&src, &dst);
    }

    #[test]
    fn load_hf_drops_encoder_keys_nonstrict() {
        let cfg = tiny_cfg();
        let mut v = VaeDecoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        // Add an encoder key — this should be dropped.
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        // Add a quant_conv key — also dropped.
        hf_sd.insert(
            "quant_conv.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec![
                "encoder.conv_in.weight".to_string(),
                "quant_conv.weight".to_string(),
            ]
        );
    }

    #[test]
    fn load_hf_strict_rejects_encoder_keys() {
        let cfg = tiny_cfg();
        let mut v = VaeDecoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = HashMap::new();
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    #[test]
    fn load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let bare = src.state_dict();
        // Re-prefix with `vae.` (the layout SD pipeline checkpoints use).
        let mut prefixed: StateDict<f32> = HashMap::new();
        for (k, v) in bare {
            prefixed.insert(format!("vae.{k}"), v);
        }
        let mut dst = VaeDecoder::<f32>::new(cfg).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);
        assert_same_forward(&src, &dst);
    }
}
451}