Skip to main content

ferrotorch_diffusion/
vae.rs

1//! Stable-Diffusion VAE decoder composition.
2//!
3//! The forward path mirrors `diffusers.AutoencoderKL.decode(z).sample`
4//! for `runwayml/stable-diffusion-v1-5`:
5//!
6//! ```text
7//! z (pre-divided by scaling_factor)
8//!   -> post_quant_conv
9//!   -> Decoder.conv_in
10//!   -> Decoder.mid_block
11//!   -> Decoder.up_blocks[0..N]
12//!   -> Decoder.conv_norm_out -> SiLU -> Decoder.conv_out
13//! ```
14//!
//! The `decode_with_scaling` helper divides by `scaling_factor` first, matching
//! the SD pipeline's `vae.decode(latents / scaling_factor)`. `forward(z)` is the
//! post-scaling path: it takes an already-divided latent, like `AutoencoderKL.decode(z)`.
18
19use std::collections::HashMap;
20
21use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
22use ferrotorch_nn::module::{Module, StateDict};
23use ferrotorch_nn::parameter::Parameter;
24use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
25
26use crate::blocks::{UNetMidBlock2D, UpDecoderBlock2D};
27use crate::config::VaeDecoderConfig;
28
29/// The bare `Decoder` half — matches `diffusers.models.autoencoders.vae.Decoder`.
30#[derive(Debug)]
31pub struct Decoder<T: Float> {
32    /// First conv: `latent_channels -> block_out_channels[-1]` (k=3, pad=1).
33    pub conv_in: Conv2d<T>,
34    /// VAE mid-block at `block_out_channels[-1]` channels.
35    pub mid_block: UNetMidBlock2D<T>,
36    /// Up-blocks in *decoder order* — block 0 operates at the highest
37    /// channel count and lowest spatial resolution.
38    pub up_blocks: Vec<UpDecoderBlock2D<T>>,
39    /// Final GroupNorm before the output conv (operates on
40    /// `block_out_channels[0]` channels).
41    pub conv_norm_out: GroupNorm<T>,
42    /// Output activation (SiLU).
43    pub conv_act: SiLU,
44    /// Output conv: `block_out_channels[0] -> out_channels` (k=3, pad=1).
45    pub conv_out: Conv2d<T>,
46    /// Frozen copy of the config.
47    pub config: VaeDecoderConfig,
48    training: bool,
49}
50
51impl<T: Float> Decoder<T> {
52    /// Build a randomly-initialized `Decoder`.
53    ///
54    /// # Errors
55    ///
56    /// Returns [`FerrotorchError::InvalidArgument`] for any invalid
57    /// config field (forwarded from [`VaeDecoderConfig::validate`]). In
58    /// particular `block_out_channels` must be non-empty — the `unwrap`
59    /// on `.last()` below is preceded by `cfg.validate()?` which checks
60    /// exactly that.
61    pub fn new(cfg: VaeDecoderConfig) -> FerrotorchResult<Self> {
62        cfg.validate()?;
63        let groups = cfg.norm_num_groups;
64        let resnet_eps = 1e-6_f64;
65        let top_channels =
66            *cfg.block_out_channels
67                .last()
68                .ok_or_else(|| FerrotorchError::InvalidArgument {
69                    message: "Decoder::new: block_out_channels is empty (should be unreachable \
70                              after validate)"
71                        .into(),
72                })?;
73
74        let conv_in = Conv2d::<T>::new(
75            cfg.latent_channels,
76            top_channels,
77            (3, 3),
78            (1, 1),
79            (1, 1),
80            true,
81        )?;
82
83        let mid_block = UNetMidBlock2D::<T>::new(top_channels, groups, resnet_eps)?;
84
85        let reversed: Vec<usize> = cfg.block_out_channels.iter().rev().copied().collect();
86        let mut up_blocks = Vec::with_capacity(reversed.len());
87        let mut prev_out = reversed[0];
88        let num_blocks = reversed.len();
89        let resnets = cfg.resnets_per_up_block();
90        for (i, &c) in reversed.iter().enumerate() {
91            let is_final = i == num_blocks - 1;
92            up_blocks.push(UpDecoderBlock2D::<T>::new(
93                prev_out,
94                c,
95                resnets,
96                groups,
97                resnet_eps,
98                !is_final,
99            )?);
100            prev_out = c;
101        }
102
103        let bottom_channels = cfg.block_out_channels[0];
104        let conv_norm_out =
105            GroupNorm::<T>::new(groups, bottom_channels, resnet_eps, true)?;
106        let conv_out = Conv2d::<T>::new(
107            bottom_channels,
108            cfg.out_channels,
109            (3, 3),
110            (1, 1),
111            (1, 1),
112            true,
113        )?;
114
115        Ok(Self {
116            conv_in,
117            mid_block,
118            up_blocks,
119            conv_norm_out,
120            conv_act: SiLU::new(),
121            conv_out,
122            config: cfg,
123            training: false,
124        })
125    }
126}
127
128impl<T: Float> Module<T> for Decoder<T> {
129    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
130        // Sanity check: [B, latent_channels, H_lat, W_lat].
131        let cfg = &self.config;
132        if input.ndim() != 4 || input.shape()[1] != cfg.latent_channels {
133            return Err(FerrotorchError::ShapeMismatch {
134                message: format!(
135                    "Decoder::forward: expected [B, {}, H, W], got {:?}",
136                    cfg.latent_channels,
137                    input.shape()
138                ),
139            });
140        }
141        let mut h = self.conv_in.forward(input)?;
142        h = self.mid_block.forward(&h)?;
143        for up in &self.up_blocks {
144            h = up.forward(&h)?;
145        }
146        h = self.conv_norm_out.forward(&h)?;
147        h = self.conv_act.forward(&h)?;
148        self.conv_out.forward(&h)
149    }
150
151    fn parameters(&self) -> Vec<&Parameter<T>> {
152        let mut out = Vec::new();
153        out.extend(self.conv_in.parameters());
154        out.extend(self.mid_block.parameters());
155        for b in &self.up_blocks {
156            out.extend(b.parameters());
157        }
158        out.extend(self.conv_norm_out.parameters());
159        out.extend(self.conv_out.parameters());
160        out
161    }
162
163    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
164        let mut out = Vec::new();
165        out.extend(self.conv_in.parameters_mut());
166        out.extend(self.mid_block.parameters_mut());
167        for b in &mut self.up_blocks {
168            out.extend(b.parameters_mut());
169        }
170        out.extend(self.conv_norm_out.parameters_mut());
171        out.extend(self.conv_out.parameters_mut());
172        out
173    }
174
175    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
176        let mut out = Vec::new();
177        for (n, p) in self.conv_in.named_parameters() {
178            out.push((format!("conv_in.{n}"), p));
179        }
180        for (n, p) in self.mid_block.named_parameters() {
181            out.push((format!("mid_block.{n}"), p));
182        }
183        for (i, b) in self.up_blocks.iter().enumerate() {
184            for (n, p) in b.named_parameters() {
185                out.push((format!("up_blocks.{i}.{n}"), p));
186            }
187        }
188        for (n, p) in self.conv_norm_out.named_parameters() {
189            out.push((format!("conv_norm_out.{n}"), p));
190        }
191        for (n, p) in self.conv_out.named_parameters() {
192            out.push((format!("conv_out.{n}"), p));
193        }
194        out
195    }
196
197    fn train(&mut self) {
198        self.training = true;
199        for b in &mut self.up_blocks {
200            b.train();
201        }
202        self.mid_block.train();
203    }
204    fn eval(&mut self) {
205        self.training = false;
206        for b in &mut self.up_blocks {
207            b.eval();
208        }
209        self.mid_block.eval();
210    }
211    fn is_training(&self) -> bool {
212        self.training
213    }
214
215    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
216        let extract = |prefix: &str| -> StateDict<T> {
217            let p = format!("{prefix}.");
218            state
219                .iter()
220                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
221                .collect()
222        };
223
224        if strict {
225            for k in state.keys() {
226                let ok = k.starts_with("conv_in.")
227                    || k.starts_with("mid_block.")
228                    || k.starts_with("up_blocks.")
229                    || k.starts_with("conv_norm_out.")
230                    || k.starts_with("conv_out.");
231                if !ok {
232                    return Err(FerrotorchError::InvalidArgument {
233                        message: format!("unexpected key in Decoder state_dict: \"{k}\""),
234                    });
235                }
236            }
237        }
238
239        self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
240        self.mid_block
241            .load_state_dict(&extract("mid_block"), strict)?;
242        for (i, b) in self.up_blocks.iter_mut().enumerate() {
243            b.load_state_dict(&extract(&format!("up_blocks.{i}")), strict)?;
244        }
245        self.conv_norm_out
246            .load_state_dict(&extract("conv_norm_out"), strict)?;
247        self.conv_out
248            .load_state_dict(&extract("conv_out"), strict)?;
249        Ok(())
250    }
251}
252
253/// `AutoencoderKL`-style VAE decoder = `post_quant_conv` + [`Decoder`].
254///
255/// The decoder pre-divides the latent by `config.scaling_factor` when
256/// using [`Self::decode_with_scaling`], matching
257/// `AutoencoderKL.decode(z).sample`. [`Module::forward`] expects the
258/// latent already pre-divided (this matches the order of operations the
259/// SD pipeline performs externally).
260#[derive(Debug)]
261pub struct VaeDecoder<T: Float> {
262    /// 1x1 post-quant projection over the 4 latent channels.
263    pub post_quant_conv: Conv2d<T>,
264    /// The actual `Decoder` stack.
265    pub decoder: Decoder<T>,
266    /// Frozen config copy.
267    pub config: VaeDecoderConfig,
268    training: bool,
269}
270
271impl<T: Float> VaeDecoder<T> {
272    /// Build a randomly-initialized `VaeDecoder`.
273    ///
274    /// # Errors
275    ///
276    /// Returns the underlying [`FerrotorchError`] on bad config dims.
277    pub fn new(cfg: VaeDecoderConfig) -> FerrotorchResult<Self> {
278        cfg.validate()?;
279        let post_quant_conv = Conv2d::<T>::new(
280            cfg.latent_channels,
281            cfg.latent_channels,
282            (1, 1),
283            (1, 1),
284            (0, 0),
285            true,
286        )?;
287        let decoder = Decoder::<T>::new(cfg.clone())?;
288        Ok(Self {
289            post_quant_conv,
290            decoder,
291            config: cfg,
292            training: false,
293        })
294    }
295
296    /// Decode a latent with the SD scaling convention:
297    /// `image = decoder(post_quant_conv(z / scaling_factor))`.
298    ///
299    /// # Errors
300    ///
301    /// Returns [`FerrotorchError::ShapeMismatch`] when the input is not
302    /// `[B, latent_channels, H, W]`. Propagates downstream op errors.
303    pub fn decode_with_scaling(&self, latent: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
304        let inv = self.config.scaling_factor.recip();
305        let inv_t = T::from(inv).ok_or_else(|| FerrotorchError::InvalidArgument {
306            message: format!(
307                "VaeDecoder::decode_with_scaling: cannot cast 1/{} into Float",
308                self.config.scaling_factor
309            ),
310        })?;
311        let inv_tensor = ferrotorch_core::scalar::<T>(inv_t)?;
312        let scaled = ferrotorch_core::grad_fns::arithmetic::mul(latent, &inv_tensor)?;
313        self.forward(&scaled)
314    }
315}
316
317impl<T: Float> Module<T> for VaeDecoder<T> {
318    /// Forward expects the post-scaled latent (the caller has already
319    /// divided by `scaling_factor`).
320    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
321        let cfg = &self.config;
322        if input.ndim() != 4 || input.shape()[1] != cfg.latent_channels {
323            return Err(FerrotorchError::ShapeMismatch {
324                message: format!(
325                    "VaeDecoder::forward: expected [B, {}, H, W], got {:?}",
326                    cfg.latent_channels,
327                    input.shape()
328                ),
329            });
330        }
331        let post = self.post_quant_conv.forward(input)?;
332        self.decoder.forward(&post)
333    }
334
335    fn parameters(&self) -> Vec<&Parameter<T>> {
336        let mut out = Vec::new();
337        out.extend(self.post_quant_conv.parameters());
338        out.extend(self.decoder.parameters());
339        out
340    }
341
342    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
343        let mut out = Vec::new();
344        out.extend(self.post_quant_conv.parameters_mut());
345        out.extend(self.decoder.parameters_mut());
346        out
347    }
348
349    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
350        let mut out = Vec::new();
351        for (n, p) in self.post_quant_conv.named_parameters() {
352            out.push((format!("post_quant_conv.{n}"), p));
353        }
354        for (n, p) in self.decoder.named_parameters() {
355            out.push((format!("decoder.{n}"), p));
356        }
357        out
358    }
359
360    fn train(&mut self) {
361        self.training = true;
362        self.decoder.train();
363    }
364    fn eval(&mut self) {
365        self.training = false;
366        self.decoder.eval();
367    }
368    fn is_training(&self) -> bool {
369        self.training
370    }
371
372    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
373        let extract = |prefix: &str| -> StateDict<T> {
374            let p = format!("{prefix}.");
375            state
376                .iter()
377                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
378                .collect()
379        };
380        if strict {
381            for k in state.keys() {
382                let ok = k.starts_with("post_quant_conv.") || k.starts_with("decoder.");
383                if !ok {
384                    return Err(FerrotorchError::InvalidArgument {
385                        message: format!("unexpected key in VaeDecoder state_dict: \"{k}\""),
386                    });
387                }
388            }
389        }
390        self.post_quant_conv
391            .load_state_dict(&extract("post_quant_conv"), strict)?;
392        self.decoder.load_state_dict(&extract("decoder"), strict)?;
393        let _: HashMap<String, Tensor<T>> = HashMap::new(); // keep HashMap import alive
394        Ok(())
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Tiny config that still exercises every architectural feature
    /// (mid-block attention, 4 up-blocks, channel-changing resnet shortcut)
    /// without making the tests slow.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            // 4 blocks. The decoder walks them reversed, i.e. [16, 16, 8, 4],
            // so the 16->8 and 8->4 transitions change width (exercising the
            // resnet shortcut) while 16->16 does not.
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1, // => 2 resnets per up-block (faster)
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// `[1, 4, 1, 1]` latent filled with `value` (4 == `tiny_cfg().latent_channels`).
    fn const_latent(value: f32) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(vec![value; 4]), vec![1, 4, 1, 1], false)
            .unwrap()
    }

    #[test]
    fn decoder_forward_shape() {
        let cfg = tiny_cfg();
        let decoder = Decoder::<f32>::new(cfg.clone()).unwrap();
        // latent [1, 4, 1, 1]: 4 up-blocks, all but the last upsample, so
        // spatial 1 -> 2 -> 4 -> 8; conv_out maps to 3 output channels.
        let out = decoder.forward(&const_latent(0.01)).unwrap();
        assert_eq!(out.shape(), &[1, 3, 8, 8]);
        for &val in out.data().unwrap() {
            assert!(val.is_finite(), "decoder output non-finite: {val}");
        }
    }

    #[test]
    fn decoder_rejects_bad_input_shape() {
        let decoder = Decoder::<f32>::new(tiny_cfg()).unwrap();
        // Wrong channel count (3 instead of 4).
        let wrong_channels =
            Tensor::from_storage(TensorStorage::cpu(vec![0.0f32; 3]), vec![1, 3, 1, 1], false)
                .unwrap();
        assert!(decoder.forward(&wrong_channels).is_err());
        // Wrong rank (1-D); must error rather than index out of bounds.
        let wrong_rank =
            Tensor::from_storage(TensorStorage::cpu(vec![0.0f32; 4]), vec![4], false).unwrap();
        assert!(decoder.forward(&wrong_rank).is_err());
    }

    #[test]
    fn vae_decoder_named_parameters_include_post_quant_conv() {
        let vae = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let names: Vec<String> = vae.named_parameters().into_iter().map(|(n, _)| n).collect();
        for key in [
            "post_quant_conv.weight",
            "post_quant_conv.bias",
            "decoder.conv_in.weight",
            "decoder.mid_block.attentions.0.to_q.weight",
            "decoder.up_blocks.0.resnets.0.norm1.weight",
            "decoder.conv_norm_out.weight",
            "decoder.conv_out.bias",
        ] {
            assert!(names.iter().any(|n| n == key), "missing {key} in {names:?}");
        }
    }

    #[test]
    fn vae_decoder_forward_shape() {
        let vae = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let out = vae.forward(&const_latent(0.01)).unwrap();
        assert_eq!(out.shape(), &[1, 3, 8, 8]);
    }

    #[test]
    fn vae_decoder_decode_with_scaling_matches_manual_div() {
        let cfg = tiny_cfg();
        let vae = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let raw = const_latent(0.05);
        let inv = (1.0 / cfg.scaling_factor) as f32;
        let scaled_data: Vec<f32> = raw.data().unwrap().iter().map(|&val| val * inv).collect();
        let scaled =
            Tensor::from_storage(TensorStorage::cpu(scaled_data), vec![1, 4, 1, 1], false)
                .unwrap();
        let via_helper = vae.decode_with_scaling(&raw).unwrap();
        let via_manual = vae.forward(&scaled).unwrap();
        for (a, b) in via_helper
            .data()
            .unwrap()
            .iter()
            .zip(via_manual.data().unwrap().iter())
        {
            assert!(
                (a - b).abs() < 1e-4,
                "decode_with_scaling vs manual div differ: {a} vs {b}"
            );
        }
    }

    #[test]
    fn round_trip_state_dict() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let sd = src.state_dict();
        let mut dst = VaeDecoder::<f32>::new(cfg).unwrap();
        dst.load_state_dict(&sd, true).unwrap();
        let latent = const_latent(0.01);
        let from_src = src.forward(&latent).unwrap();
        let from_dst = dst.forward(&latent).unwrap();
        for (a, b) in from_src
            .data()
            .unwrap()
            .iter()
            .zip(from_dst.data().unwrap().iter())
        {
            assert!((a - b).abs() < 1e-5, "round-trip differs: {a} vs {b}");
        }
    }
}