Skip to main content

ferrotorch_diffusion/
vae.rs

//! Stable-Diffusion VAE decoder composition.
//!
//! The forward path mirrors `diffusers.AutoencoderKL.decode(z).sample`
//! for `runwayml/stable-diffusion-v1-5`:
//!
//! ```text
//! z (pre-divided by scaling_factor)
//!   -> post_quant_conv
//!   -> Decoder.conv_in
//!   -> Decoder.mid_block
//!   -> Decoder.up_blocks[0..N]
//!   -> Decoder.conv_norm_out -> SiLU -> Decoder.conv_out
//! ```
//!
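//! The `decode_with_scaling` helper applies `z / scaling_factor` first,
//! matching the Stable Diffusion pipeline's
//! `vae.decode(z / scaling_factor).sample` (`AutoencoderKL.decode` itself
//! does not scale). `forward(z)` is the post-scaling path: it corresponds to
//! `AutoencoderKL.decode(z).sample` and accepts an already-divided latent.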
18
19use std::collections::HashMap;
20
21use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
22use ferrotorch_nn::module::{Module, StateDict};
23use ferrotorch_nn::parameter::Parameter;
24use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
25
26use crate::blocks::{UNetMidBlock2D, UpDecoderBlock2D};
27use crate::config::VaeDecoderConfig;
28
/// The bare `Decoder` half — matches `diffusers.models.autoencoders.vae.Decoder`.
///
/// Built by [`Decoder::new`]; [`VaeDecoder`] wraps it behind a
/// `post_quant_conv`. The field names below are also the prefixes used by
/// `named_parameters` / `load_state_dict` (`conv_in.`, `mid_block.`,
/// `up_blocks.{i}.`, `conv_norm_out.`, `conv_out.`).
#[derive(Debug)]
pub struct Decoder<T: Float> {
    /// First conv: `latent_channels -> block_out_channels[-1]` (k=3, pad=1).
    pub conv_in: Conv2d<T>,
    /// VAE mid-block at `block_out_channels[-1]` channels.
    pub mid_block: UNetMidBlock2D<T>,
    /// Up-blocks in *decoder order* — block 0 operates at the highest
    /// channel count and lowest spatial resolution. Every block except the
    /// last one upsamples.
    pub up_blocks: Vec<UpDecoderBlock2D<T>>,
    /// Final GroupNorm before the output conv (operates on
    /// `block_out_channels[0]` channels).
    pub conv_norm_out: GroupNorm<T>,
    /// Output activation (SiLU); parameter-free.
    pub conv_act: SiLU,
    /// Output conv: `block_out_channels[0] -> out_channels` (k=3, pad=1).
    pub conv_out: Conv2d<T>,
    /// Copy of the config this decoder was built from (validated at
    /// construction; the field is public, so later edits are not re-checked).
    pub config: VaeDecoderConfig,
    // Mode flag returned by `Module::is_training`; set by `train` / `eval`.
    training: bool,
}
50
51impl<T: Float> Decoder<T> {
52    /// Build a randomly-initialized `Decoder`.
53    ///
54    /// # Errors
55    ///
56    /// Returns [`FerrotorchError::InvalidArgument`] for any invalid
57    /// config field (forwarded from [`VaeDecoderConfig::validate`]). In
58    /// particular `block_out_channels` must be non-empty — the `unwrap`
59    /// on `.last()` below is preceded by `cfg.validate()?` which checks
60    /// exactly that.
61    pub fn new(cfg: VaeDecoderConfig) -> FerrotorchResult<Self> {
62        cfg.validate()?;
63        let groups = cfg.norm_num_groups;
64        let resnet_eps = 1e-6_f64;
65        let top_channels =
66            *cfg.block_out_channels
67                .last()
68                .ok_or_else(|| FerrotorchError::InvalidArgument {
69                    message: "Decoder::new: block_out_channels is empty (should be unreachable \
70                              after validate)"
71                        .into(),
72                })?;
73
74        let conv_in = Conv2d::<T>::new(
75            cfg.latent_channels,
76            top_channels,
77            (3, 3),
78            (1, 1),
79            (1, 1),
80            true,
81        )?;
82
83        let mid_block = UNetMidBlock2D::<T>::new(top_channels, groups, resnet_eps)?;
84
85        let reversed: Vec<usize> = cfg.block_out_channels.iter().rev().copied().collect();
86        let mut up_blocks = Vec::with_capacity(reversed.len());
87        let mut prev_out = reversed[0];
88        let num_blocks = reversed.len();
89        let resnets = cfg.resnets_per_up_block();
90        for (i, &c) in reversed.iter().enumerate() {
91            let is_final = i == num_blocks - 1;
92            up_blocks.push(UpDecoderBlock2D::<T>::new(
93                prev_out, c, resnets, groups, resnet_eps, !is_final,
94            )?);
95            prev_out = c;
96        }
97
98        let bottom_channels = cfg.block_out_channels[0];
99        let conv_norm_out = GroupNorm::<T>::new(groups, bottom_channels, resnet_eps, true)?;
100        let conv_out = Conv2d::<T>::new(
101            bottom_channels,
102            cfg.out_channels,
103            (3, 3),
104            (1, 1),
105            (1, 1),
106            true,
107        )?;
108
109        Ok(Self {
110            conv_in,
111            mid_block,
112            up_blocks,
113            conv_norm_out,
114            conv_act: SiLU::new(),
115            conv_out,
116            config: cfg,
117            training: false,
118        })
119    }
120}
121
122impl<T: Float> Module<T> for Decoder<T> {
123    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
124        // Sanity check: [B, latent_channels, H_lat, W_lat].
125        let cfg = &self.config;
126        if input.ndim() != 4 || input.shape()[1] != cfg.latent_channels {
127            return Err(FerrotorchError::ShapeMismatch {
128                message: format!(
129                    "Decoder::forward: expected [B, {}, H, W], got {:?}",
130                    cfg.latent_channels,
131                    input.shape()
132                ),
133            });
134        }
135        let mut h = self.conv_in.forward(input)?;
136        h = self.mid_block.forward(&h)?;
137        for up in &self.up_blocks {
138            h = up.forward(&h)?;
139        }
140        h = self.conv_norm_out.forward(&h)?;
141        h = self.conv_act.forward(&h)?;
142        self.conv_out.forward(&h)
143    }
144
145    fn parameters(&self) -> Vec<&Parameter<T>> {
146        let mut out = Vec::new();
147        out.extend(self.conv_in.parameters());
148        out.extend(self.mid_block.parameters());
149        for b in &self.up_blocks {
150            out.extend(b.parameters());
151        }
152        out.extend(self.conv_norm_out.parameters());
153        out.extend(self.conv_out.parameters());
154        out
155    }
156
157    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
158        let mut out = Vec::new();
159        out.extend(self.conv_in.parameters_mut());
160        out.extend(self.mid_block.parameters_mut());
161        for b in &mut self.up_blocks {
162            out.extend(b.parameters_mut());
163        }
164        out.extend(self.conv_norm_out.parameters_mut());
165        out.extend(self.conv_out.parameters_mut());
166        out
167    }
168
169    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
170        let mut out = Vec::new();
171        for (n, p) in self.conv_in.named_parameters() {
172            out.push((format!("conv_in.{n}"), p));
173        }
174        for (n, p) in self.mid_block.named_parameters() {
175            out.push((format!("mid_block.{n}"), p));
176        }
177        for (i, b) in self.up_blocks.iter().enumerate() {
178            for (n, p) in b.named_parameters() {
179                out.push((format!("up_blocks.{i}.{n}"), p));
180            }
181        }
182        for (n, p) in self.conv_norm_out.named_parameters() {
183            out.push((format!("conv_norm_out.{n}"), p));
184        }
185        for (n, p) in self.conv_out.named_parameters() {
186            out.push((format!("conv_out.{n}"), p));
187        }
188        out
189    }
190
191    fn train(&mut self) {
192        self.training = true;
193        for b in &mut self.up_blocks {
194            b.train();
195        }
196        self.mid_block.train();
197    }
198    fn eval(&mut self) {
199        self.training = false;
200        for b in &mut self.up_blocks {
201            b.eval();
202        }
203        self.mid_block.eval();
204    }
205    fn is_training(&self) -> bool {
206        self.training
207    }
208
209    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
210        let extract = |prefix: &str| -> StateDict<T> {
211            let p = format!("{prefix}.");
212            state
213                .iter()
214                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
215                .collect()
216        };
217
218        if strict {
219            for k in state.keys() {
220                let ok = k.starts_with("conv_in.")
221                    || k.starts_with("mid_block.")
222                    || k.starts_with("up_blocks.")
223                    || k.starts_with("conv_norm_out.")
224                    || k.starts_with("conv_out.");
225                if !ok {
226                    return Err(FerrotorchError::InvalidArgument {
227                        message: format!("unexpected key in Decoder state_dict: \"{k}\""),
228                    });
229                }
230            }
231        }
232
233        self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
234        self.mid_block
235            .load_state_dict(&extract("mid_block"), strict)?;
236        for (i, b) in self.up_blocks.iter_mut().enumerate() {
237            b.load_state_dict(&extract(&format!("up_blocks.{i}")), strict)?;
238        }
239        self.conv_norm_out
240            .load_state_dict(&extract("conv_norm_out"), strict)?;
241        self.conv_out
242            .load_state_dict(&extract("conv_out"), strict)?;
243        Ok(())
244    }
245}
246
247/// `AutoencoderKL`-style VAE decoder = `post_quant_conv` + [`Decoder`].
248///
249/// The decoder pre-divides the latent by `config.scaling_factor` when
250/// using [`Self::decode_with_scaling`], matching
251/// `AutoencoderKL.decode(z).sample`. [`Module::forward`] expects the
252/// latent already pre-divided (this matches the order of operations the
253/// SD pipeline performs externally).
254#[derive(Debug)]
255pub struct VaeDecoder<T: Float> {
256    /// 1x1 post-quant projection over the 4 latent channels.
257    pub post_quant_conv: Conv2d<T>,
258    /// The actual `Decoder` stack.
259    pub decoder: Decoder<T>,
260    /// Frozen config copy.
261    pub config: VaeDecoderConfig,
262    training: bool,
263}
264
265impl<T: Float> VaeDecoder<T> {
266    /// Build a randomly-initialized `VaeDecoder`.
267    ///
268    /// # Errors
269    ///
270    /// Returns the underlying [`FerrotorchError`] on bad config dims.
271    pub fn new(cfg: VaeDecoderConfig) -> FerrotorchResult<Self> {
272        cfg.validate()?;
273        let post_quant_conv = Conv2d::<T>::new(
274            cfg.latent_channels,
275            cfg.latent_channels,
276            (1, 1),
277            (1, 1),
278            (0, 0),
279            true,
280        )?;
281        let decoder = Decoder::<T>::new(cfg.clone())?;
282        Ok(Self {
283            post_quant_conv,
284            decoder,
285            config: cfg,
286            training: false,
287        })
288    }
289
290    /// Decode a latent with the SD scaling convention:
291    /// `image = decoder(post_quant_conv(z / scaling_factor))`.
292    ///
293    /// # Errors
294    ///
295    /// Returns [`FerrotorchError::ShapeMismatch`] when the input is not
296    /// `[B, latent_channels, H, W]`. Propagates downstream op errors.
297    pub fn decode_with_scaling(&self, latent: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
298        let inv = self.config.scaling_factor.recip();
299        let inv_t = T::from(inv).ok_or_else(|| FerrotorchError::InvalidArgument {
300            message: format!(
301                "VaeDecoder::decode_with_scaling: cannot cast 1/{} into Float",
302                self.config.scaling_factor
303            ),
304        })?;
305        let inv_tensor = ferrotorch_core::scalar::<T>(inv_t)?;
306        let scaled = ferrotorch_core::grad_fns::arithmetic::mul(latent, &inv_tensor)?;
307        self.forward(&scaled)
308    }
309}
310
311impl<T: Float> Module<T> for VaeDecoder<T> {
312    /// Forward expects the post-scaled latent (the caller has already
313    /// divided by `scaling_factor`).
314    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
315        let cfg = &self.config;
316        if input.ndim() != 4 || input.shape()[1] != cfg.latent_channels {
317            return Err(FerrotorchError::ShapeMismatch {
318                message: format!(
319                    "VaeDecoder::forward: expected [B, {}, H, W], got {:?}",
320                    cfg.latent_channels,
321                    input.shape()
322                ),
323            });
324        }
325        let post = self.post_quant_conv.forward(input)?;
326        self.decoder.forward(&post)
327    }
328
329    fn parameters(&self) -> Vec<&Parameter<T>> {
330        let mut out = Vec::new();
331        out.extend(self.post_quant_conv.parameters());
332        out.extend(self.decoder.parameters());
333        out
334    }
335
336    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
337        let mut out = Vec::new();
338        out.extend(self.post_quant_conv.parameters_mut());
339        out.extend(self.decoder.parameters_mut());
340        out
341    }
342
343    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
344        let mut out = Vec::new();
345        for (n, p) in self.post_quant_conv.named_parameters() {
346            out.push((format!("post_quant_conv.{n}"), p));
347        }
348        for (n, p) in self.decoder.named_parameters() {
349            out.push((format!("decoder.{n}"), p));
350        }
351        out
352    }
353
354    fn train(&mut self) {
355        self.training = true;
356        self.decoder.train();
357    }
358    fn eval(&mut self) {
359        self.training = false;
360        self.decoder.eval();
361    }
362    fn is_training(&self) -> bool {
363        self.training
364    }
365
366    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
367        let extract = |prefix: &str| -> StateDict<T> {
368            let p = format!("{prefix}.");
369            state
370                .iter()
371                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
372                .collect()
373        };
374        if strict {
375            for k in state.keys() {
376                let ok = k.starts_with("post_quant_conv.") || k.starts_with("decoder.");
377                if !ok {
378                    return Err(FerrotorchError::InvalidArgument {
379                        message: format!("unexpected key in VaeDecoder state_dict: \"{k}\""),
380                    });
381                }
382            }
383        }
384        self.post_quant_conv
385            .load_state_dict(&extract("post_quant_conv"), strict)?;
386        self.decoder.load_state_dict(&extract("decoder"), strict)?;
387        let _: HashMap<String, Tensor<T>> = HashMap::new(); // keep HashMap import alive
388        Ok(())
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Tiny config that still exercises every architectural feature
    /// (mid-block attention, 4 up-blocks, channel-changing resnet
    /// shortcuts) without making the tests slow.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            // 4 blocks; the decoder walks them reversed as [16, 16, 8, 4],
            // so the 16->8 and 8->4 up-blocks take the channel-changing
            // resnet shortcut path (the first, 16->16, does not).
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1, // => 2 resnets per up-block (faster)
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// A `[1, 4, 1, 1]` latent (4 = `tiny_cfg().latent_channels`) filled with `value`.
    fn latent(value: f32) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(vec![value; 4]), vec![1, 4, 1, 1], false)
            .unwrap()
    }

    /// Assert equal shapes, then element-wise agreement within `tol`.
    /// Checking shapes first matters: `zip` alone would silently truncate.
    fn assert_close(a: &Tensor<f32>, b: &Tensor<f32>, tol: f32, what: &str) {
        assert_eq!(a.shape(), b.shape(), "{what}: shape mismatch");
        let (lhs, rhs) = (a.data().unwrap(), b.data().unwrap());
        for (p, q) in lhs.iter().zip(rhs.iter()) {
            assert!((p - q).abs() < tol, "{what}: {p} vs {q}");
        }
    }

    #[test]
    fn decoder_forward_shape() {
        let d = Decoder::<f32>::new(tiny_cfg()).unwrap();
        // latent [1, 4, 1, 1]: 3 upsamples (1 -> 2 -> 4 -> 8; the last
        // up-block does not upsample), then conv_out maps to 3 channels.
        let y = d.forward(&latent(0.01)).unwrap();
        assert_eq!(y.shape(), &[1, 3, 8, 8]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite(), "decoder output non-finite: {v}");
        }
    }

    #[test]
    fn vae_decoder_named_parameters_include_post_quant_conv() {
        let v = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let names: Vec<String> = v.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "post_quant_conv.weight",
            "post_quant_conv.bias",
            "decoder.conv_in.weight",
            "decoder.mid_block.attentions.0.to_q.weight",
            "decoder.up_blocks.0.resnets.0.norm1.weight",
            "decoder.conv_norm_out.weight",
            "decoder.conv_out.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn vae_decoder_forward_shape() {
        let v = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let y = v.forward(&latent(0.01)).unwrap();
        assert_eq!(y.shape(), &[1, 3, 8, 8]);
    }

    #[test]
    fn vae_decoder_decode_with_scaling_matches_manual_div() {
        let v = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let x = latent(0.05);
        let inv = (1.0 / v.config.scaling_factor) as f32;
        let scaled_data: Vec<f32> = x.data().unwrap().iter().map(|&e| e * inv).collect();
        let scaled =
            Tensor::from_storage(TensorStorage::cpu(scaled_data), vec![1, 4, 1, 1], false).unwrap();
        let via_helper = v.decode_with_scaling(&x).unwrap();
        let via_manual = v.forward(&scaled).unwrap();
        assert_close(
            &via_helper,
            &via_manual,
            1e-4,
            "decode_with_scaling vs manual div differ",
        );
    }

    #[test]
    fn round_trip_state_dict() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let sd = src.state_dict();
        let mut dst = VaeDecoder::<f32>::new(cfg).unwrap();
        dst.load_state_dict(&sd, true).unwrap();
        let x = latent(0.01);
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        assert_close(&a, &b, 1e-5, "round-trip differs");
    }
}