Skip to main content

ferrotorch_diffusion/
vae.rs

//! Stable-Diffusion VAE decoder composition.
//!
//! The forward path mirrors `diffusers.AutoencoderKL.decode(z).sample`
//! for `runwayml/stable-diffusion-v1-5`:
//!
//! ```text
//! z (pre-divided by scaling_factor)
//!   -> post_quant_conv
//!   -> Decoder.conv_in
//!   -> Decoder.mid_block
//!   -> Decoder.up_blocks[0..N]
//!   -> Decoder.conv_norm_out -> SiLU -> Decoder.conv_out
//! ```
//!
//! The `decode_with_scaling` helper applies `z / scaling_factor` first and
//! then runs `forward`. It therefore matches what the SD pipeline does before
//! decoding, namely `AutoencoderKL.decode(z / scaling_factor).sample`.
//! `forward(z)` is the post-scaling path and accepts an already-divided latent.
//!
//! ## REQ status (per `.design/ferrotorch-diffusion/vae.md`)
//!
//! Evidence for items in this file is given by item name rather than by
//! `vae.rs` line number, because line numbers drift on every edit.
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 | SHIPPED | `Decoder<T>` and `Decoder::new` in `vae.rs`; consumer: `VaeDecoder::new` in `vae.rs` builds it; itself consumed by `safetensors_loader.rs:330` `load_vae_decoder` |
//! | REQ-2 | SHIPPED | `VaeDecoder<T>` and `VaeDecoder::new` in `vae.rs`; consumer: `safetensors_loader.rs:330` `load_vae_decoder` instantiates it; `pipeline.rs:75` carries it as a pipeline field |
//! | REQ-3 | SHIPPED | `Module<T>::forward` for `VaeDecoder<T>` in `vae.rs`; consumer: `pipeline.rs:227` `vae.decode_with_scaling(...)` (which calls `forward` internally) |
//! | REQ-4 | SHIPPED | `VaeDecoder::decode_with_scaling` in `vae.rs`; consumer: `pipeline.rs:227` and `examples/vae_decode_dump.rs` invoke it for SD-1.5 decoding |
//! | REQ-5 | SHIPPED | `Module<T>::load_state_dict` for `VaeDecoder<T>` in `vae.rs`; consumer: `safetensors_loader.rs:89` `VaeDecoder::load_hf_state_dict` calls `self.load_state_dict(&remapped, strict)` after stripping the `vae.` prefix |
28
29use std::collections::HashMap;
30
31use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
32use ferrotorch_nn::module::{Module, StateDict};
33use ferrotorch_nn::parameter::Parameter;
34use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
35
36use crate::blocks::{UNetMidBlock2D, UpDecoderBlock2D};
37use crate::config::VaeDecoderConfig;
38
39/// The bare `Decoder` half — matches `diffusers.models.autoencoders.vae.Decoder`.
40#[derive(Debug)]
41pub struct Decoder<T: Float> {
42    /// First conv: `latent_channels -> block_out_channels[-1]` (k=3, pad=1).
43    pub conv_in: Conv2d<T>,
44    /// VAE mid-block at `block_out_channels[-1]` channels.
45    pub mid_block: UNetMidBlock2D<T>,
46    /// Up-blocks in *decoder order* — block 0 operates at the highest
47    /// channel count and lowest spatial resolution.
48    pub up_blocks: Vec<UpDecoderBlock2D<T>>,
49    /// Final GroupNorm before the output conv (operates on
50    /// `block_out_channels[0]` channels).
51    pub conv_norm_out: GroupNorm<T>,
52    /// Output activation (SiLU).
53    pub conv_act: SiLU,
54    /// Output conv: `block_out_channels[0] -> out_channels` (k=3, pad=1).
55    pub conv_out: Conv2d<T>,
56    /// Frozen copy of the config.
57    pub config: VaeDecoderConfig,
58    training: bool,
59}
60
61impl<T: Float> Decoder<T> {
62    /// Build a randomly-initialized `Decoder`.
63    ///
64    /// # Errors
65    ///
66    /// Returns [`FerrotorchError::InvalidArgument`] for any invalid
67    /// config field (forwarded from [`VaeDecoderConfig::validate`]). In
68    /// particular `block_out_channels` must be non-empty — the `unwrap`
69    /// on `.last()` below is preceded by `cfg.validate()?` which checks
70    /// exactly that.
71    pub fn new(cfg: VaeDecoderConfig) -> FerrotorchResult<Self> {
72        cfg.validate()?;
73        let groups = cfg.norm_num_groups;
74        let resnet_eps = 1e-6_f64;
75        let top_channels =
76            *cfg.block_out_channels
77                .last()
78                .ok_or_else(|| FerrotorchError::InvalidArgument {
79                    message: "Decoder::new: block_out_channels is empty (should be unreachable \
80                              after validate)"
81                        .into(),
82                })?;
83
84        let conv_in = Conv2d::<T>::new(
85            cfg.latent_channels,
86            top_channels,
87            (3, 3),
88            (1, 1),
89            (1, 1),
90            true,
91        )?;
92
93        let mid_block = UNetMidBlock2D::<T>::new(top_channels, groups, resnet_eps)?;
94
95        let reversed: Vec<usize> = cfg.block_out_channels.iter().rev().copied().collect();
96        let mut up_blocks = Vec::with_capacity(reversed.len());
97        let mut prev_out = reversed[0];
98        let num_blocks = reversed.len();
99        let resnets = cfg.resnets_per_up_block();
100        for (i, &c) in reversed.iter().enumerate() {
101            let is_final = i == num_blocks - 1;
102            up_blocks.push(UpDecoderBlock2D::<T>::new(
103                prev_out, c, resnets, groups, resnet_eps, !is_final,
104            )?);
105            prev_out = c;
106        }
107
108        let bottom_channels = cfg.block_out_channels[0];
109        let conv_norm_out = GroupNorm::<T>::new(groups, bottom_channels, resnet_eps, true)?;
110        let conv_out = Conv2d::<T>::new(
111            bottom_channels,
112            cfg.out_channels,
113            (3, 3),
114            (1, 1),
115            (1, 1),
116            true,
117        )?;
118
119        Ok(Self {
120            conv_in,
121            mid_block,
122            up_blocks,
123            conv_norm_out,
124            conv_act: SiLU::new(),
125            conv_out,
126            config: cfg,
127            training: false,
128        })
129    }
130}
131
132impl<T: Float> Module<T> for Decoder<T> {
133    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
134        // Sanity check: [B, latent_channels, H_lat, W_lat].
135        let cfg = &self.config;
136        if input.ndim() != 4 || input.shape()[1] != cfg.latent_channels {
137            return Err(FerrotorchError::ShapeMismatch {
138                message: format!(
139                    "Decoder::forward: expected [B, {}, H, W], got {:?}",
140                    cfg.latent_channels,
141                    input.shape()
142                ),
143            });
144        }
145        let mut h = self.conv_in.forward(input)?;
146        h = self.mid_block.forward(&h)?;
147        for up in &self.up_blocks {
148            h = up.forward(&h)?;
149        }
150        h = self.conv_norm_out.forward(&h)?;
151        h = self.conv_act.forward(&h)?;
152        self.conv_out.forward(&h)
153    }
154
155    fn parameters(&self) -> Vec<&Parameter<T>> {
156        let mut out = Vec::new();
157        out.extend(self.conv_in.parameters());
158        out.extend(self.mid_block.parameters());
159        for b in &self.up_blocks {
160            out.extend(b.parameters());
161        }
162        out.extend(self.conv_norm_out.parameters());
163        out.extend(self.conv_out.parameters());
164        out
165    }
166
167    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
168        let mut out = Vec::new();
169        out.extend(self.conv_in.parameters_mut());
170        out.extend(self.mid_block.parameters_mut());
171        for b in &mut self.up_blocks {
172            out.extend(b.parameters_mut());
173        }
174        out.extend(self.conv_norm_out.parameters_mut());
175        out.extend(self.conv_out.parameters_mut());
176        out
177    }
178
179    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
180        let mut out = Vec::new();
181        for (n, p) in self.conv_in.named_parameters() {
182            out.push((format!("conv_in.{n}"), p));
183        }
184        for (n, p) in self.mid_block.named_parameters() {
185            out.push((format!("mid_block.{n}"), p));
186        }
187        for (i, b) in self.up_blocks.iter().enumerate() {
188            for (n, p) in b.named_parameters() {
189                out.push((format!("up_blocks.{i}.{n}"), p));
190            }
191        }
192        for (n, p) in self.conv_norm_out.named_parameters() {
193            out.push((format!("conv_norm_out.{n}"), p));
194        }
195        for (n, p) in self.conv_out.named_parameters() {
196            out.push((format!("conv_out.{n}"), p));
197        }
198        out
199    }
200
201    fn train(&mut self) {
202        self.training = true;
203        for b in &mut self.up_blocks {
204            b.train();
205        }
206        self.mid_block.train();
207    }
208    fn eval(&mut self) {
209        self.training = false;
210        for b in &mut self.up_blocks {
211            b.eval();
212        }
213        self.mid_block.eval();
214    }
215    fn is_training(&self) -> bool {
216        self.training
217    }
218
219    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
220        let extract = |prefix: &str| -> StateDict<T> {
221            let p = format!("{prefix}.");
222            state
223                .iter()
224                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
225                .collect()
226        };
227
228        if strict {
229            for k in state.keys() {
230                let ok = k.starts_with("conv_in.")
231                    || k.starts_with("mid_block.")
232                    || k.starts_with("up_blocks.")
233                    || k.starts_with("conv_norm_out.")
234                    || k.starts_with("conv_out.");
235                if !ok {
236                    return Err(FerrotorchError::InvalidArgument {
237                        message: format!("unexpected key in Decoder state_dict: \"{k}\""),
238                    });
239                }
240            }
241        }
242
243        self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
244        self.mid_block
245            .load_state_dict(&extract("mid_block"), strict)?;
246        for (i, b) in self.up_blocks.iter_mut().enumerate() {
247            b.load_state_dict(&extract(&format!("up_blocks.{i}")), strict)?;
248        }
249        self.conv_norm_out
250            .load_state_dict(&extract("conv_norm_out"), strict)?;
251        self.conv_out
252            .load_state_dict(&extract("conv_out"), strict)?;
253        Ok(())
254    }
255}
256
257/// `AutoencoderKL`-style VAE decoder = `post_quant_conv` + [`Decoder`].
258///
259/// The decoder pre-divides the latent by `config.scaling_factor` when
260/// using [`Self::decode_with_scaling`], matching
261/// `AutoencoderKL.decode(z).sample`. [`Module::forward`] expects the
262/// latent already pre-divided (this matches the order of operations the
263/// SD pipeline performs externally).
264#[derive(Debug)]
265pub struct VaeDecoder<T: Float> {
266    /// 1x1 post-quant projection over the 4 latent channels.
267    pub post_quant_conv: Conv2d<T>,
268    /// The actual `Decoder` stack.
269    pub decoder: Decoder<T>,
270    /// Frozen config copy.
271    pub config: VaeDecoderConfig,
272    training: bool,
273}
274
275impl<T: Float> VaeDecoder<T> {
276    /// Build a randomly-initialized `VaeDecoder`.
277    ///
278    /// # Errors
279    ///
280    /// Returns the underlying [`FerrotorchError`] on bad config dims.
281    pub fn new(cfg: VaeDecoderConfig) -> FerrotorchResult<Self> {
282        cfg.validate()?;
283        let post_quant_conv = Conv2d::<T>::new(
284            cfg.latent_channels,
285            cfg.latent_channels,
286            (1, 1),
287            (1, 1),
288            (0, 0),
289            true,
290        )?;
291        let decoder = Decoder::<T>::new(cfg.clone())?;
292        Ok(Self {
293            post_quant_conv,
294            decoder,
295            config: cfg,
296            training: false,
297        })
298    }
299
300    /// Decode a latent with the SD scaling convention:
301    /// `image = decoder(post_quant_conv(z / scaling_factor))`.
302    ///
303    /// # Errors
304    ///
305    /// Returns [`FerrotorchError::ShapeMismatch`] when the input is not
306    /// `[B, latent_channels, H, W]`. Propagates downstream op errors.
307    pub fn decode_with_scaling(&self, latent: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
308        let inv = self.config.scaling_factor.recip();
309        let inv_t = T::from(inv).ok_or_else(|| FerrotorchError::InvalidArgument {
310            message: format!(
311                "VaeDecoder::decode_with_scaling: cannot cast 1/{} into Float",
312                self.config.scaling_factor
313            ),
314        })?;
315        let inv_tensor = ferrotorch_core::scalar::<T>(inv_t)?;
316        let scaled = ferrotorch_core::grad_fns::arithmetic::mul(latent, &inv_tensor)?;
317        self.forward(&scaled)
318    }
319}
320
321impl<T: Float> Module<T> for VaeDecoder<T> {
322    /// Forward expects the post-scaled latent (the caller has already
323    /// divided by `scaling_factor`).
324    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
325        let cfg = &self.config;
326        if input.ndim() != 4 || input.shape()[1] != cfg.latent_channels {
327            return Err(FerrotorchError::ShapeMismatch {
328                message: format!(
329                    "VaeDecoder::forward: expected [B, {}, H, W], got {:?}",
330                    cfg.latent_channels,
331                    input.shape()
332                ),
333            });
334        }
335        let post = self.post_quant_conv.forward(input)?;
336        self.decoder.forward(&post)
337    }
338
339    fn parameters(&self) -> Vec<&Parameter<T>> {
340        let mut out = Vec::new();
341        out.extend(self.post_quant_conv.parameters());
342        out.extend(self.decoder.parameters());
343        out
344    }
345
346    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
347        let mut out = Vec::new();
348        out.extend(self.post_quant_conv.parameters_mut());
349        out.extend(self.decoder.parameters_mut());
350        out
351    }
352
353    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
354        let mut out = Vec::new();
355        for (n, p) in self.post_quant_conv.named_parameters() {
356            out.push((format!("post_quant_conv.{n}"), p));
357        }
358        for (n, p) in self.decoder.named_parameters() {
359            out.push((format!("decoder.{n}"), p));
360        }
361        out
362    }
363
364    fn train(&mut self) {
365        self.training = true;
366        self.decoder.train();
367    }
368    fn eval(&mut self) {
369        self.training = false;
370        self.decoder.eval();
371    }
372    fn is_training(&self) -> bool {
373        self.training
374    }
375
376    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
377        let extract = |prefix: &str| -> StateDict<T> {
378            let p = format!("{prefix}.");
379            state
380                .iter()
381                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
382                .collect()
383        };
384        if strict {
385            for k in state.keys() {
386                let ok = k.starts_with("post_quant_conv.") || k.starts_with("decoder.");
387                if !ok {
388                    return Err(FerrotorchError::InvalidArgument {
389                        message: format!("unexpected key in VaeDecoder state_dict: \"{k}\""),
390                    });
391                }
392            }
393        }
394        self.post_quant_conv
395            .load_state_dict(&extract("post_quant_conv"), strict)?;
396        self.decoder.load_state_dict(&extract("decoder"), strict)?;
397        let _: HashMap<String, Tensor<T>> = HashMap::new(); // keep HashMap import alive
398        Ok(())
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Tiny config that still exercises every architectural feature
    /// (mid-block attn, 4 up-blocks, channel-changing resnet shortcut)
    /// without making the test slow.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            // 4 blocks; the decoder walks them in reverse as [16, 16, 8, 4].
            // The 16 -> 8 and 8 -> 4 transitions change the channel count and
            // so exercise the resnet shortcut path (16 -> 16 does not).
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1, // => 2 resnets per up-block (faster)
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    #[test]
    fn decoder_forward_shape() {
        let cfg = tiny_cfg();
        let d = Decoder::<f32>::new(cfg.clone()).unwrap();
        // Latent [1, 4, 1, 1]: 3 of the 4 up-blocks upsample, and `conv_out`
        // maps to `out_channels` = 3, so the output is [1, 3, 8, 8].
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        let y = d.forward(&x).unwrap();
        // 1 -> 2 -> 4 -> 8 (3 upsamples, last block has no upsample).
        assert_eq!(y.shape(), &[1, 3, 8, 8]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite(), "decoder output non-finite: {v}");
        }
    }

    #[test]
    fn vae_decoder_named_parameters_include_post_quant_conv() {
        let cfg = tiny_cfg();
        let v = VaeDecoder::<f32>::new(cfg).unwrap();
        let names: Vec<String> = v.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "post_quant_conv.weight",
            "post_quant_conv.bias",
            "decoder.conv_in.weight",
            "decoder.mid_block.attentions.0.to_q.weight",
            "decoder.up_blocks.0.resnets.0.norm1.weight",
            "decoder.conv_norm_out.weight",
            "decoder.conv_out.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn vae_decoder_forward_shape() {
        let cfg = tiny_cfg();
        let v = VaeDecoder::<f32>::new(cfg).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        let y = v.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 3, 8, 8]);
    }

    #[test]
    fn vae_decoder_decode_with_scaling_matches_manual_div() {
        let cfg = tiny_cfg();
        let v = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.05f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        let inv = (1.0 / cfg.scaling_factor) as f32;
        let scaled_data: Vec<f32> = x.data().unwrap().iter().map(|&v| v * inv).collect();
        let scaled =
            Tensor::from_storage(TensorStorage::cpu(scaled_data), vec![1, 4, 1, 1], false).unwrap();
        let a = v.decode_with_scaling(&x).unwrap();
        let b = v.forward(&scaled).unwrap();
        // `zip` would silently truncate on a size mismatch; pin shapes first.
        assert_eq!(a.shape(), b.shape());
        for (p, q) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!(
                (p - q).abs() < 1e-4,
                "decode_with_scaling vs manual div differ: {p} vs {q}"
            );
        }
    }

    #[test]
    fn round_trip_state_dict() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let sd = src.state_dict();
        let mut dst = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        dst.load_state_dict(&sd, true).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        // `zip` would silently truncate on a size mismatch; pin shapes first.
        assert_eq!(a.shape(), b.shape());
        for (p, q) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((p - q).abs() < 1e-5, "round-trip differs: {p} vs {q}");
        }
    }
}