Skip to main content

ferrotorch_diffusion/
blocks.rs

1//! Building blocks of the Stable-Diffusion VAE decoder.
2//!
3//! All blocks here match the diffusers `models/{resnet,upsampling,
4//! attention_processor}.py` / `models/unets/unet_2d_blocks.py` reference
5//! layout 1:1 in parameter naming and forward semantics so the upstream
6//! state dict (`runwayml/stable-diffusion-v1-5/vae/diffusion_pytorch_model.safetensors`)
7//! loads byte-for-byte.
8
9use std::collections::HashMap;
10
11use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
12use ferrotorch_nn::module::{Module, StateDict};
13use ferrotorch_nn::parameter::Parameter;
14use ferrotorch_nn::{
15    Conv2d, GroupNorm, InterpolateMode, Linear, SiLU, Upsample,
16};
17
18// ---------------------------------------------------------------------------
19// ResnetBlock2D
20// ---------------------------------------------------------------------------
21
/// `ResnetBlock2D` — the building block of every UNet/VAE up/down/mid stack.
///
/// The VAE flavour has no time embedding (`temb_channels=None` in
/// diffusers), so this implementation does not carry a
/// `time_emb_proj`. The forward pass is:
///
/// ```text
/// h = norm1(x); h = silu(h); h = conv1(h)
/// h = norm2(h); h = silu(h); h = conv2(h)
/// out = h + (x if in==out else conv_shortcut(x))
/// ```
///
/// No dropout is applied between the second `silu` and `conv2`.
///
/// The `output_scale_factor` from the diffusers reference is always 1.0
/// in the SD VAE so we hard-code it; if a future config requires
/// rescaling, add an explicit field.
///
/// State-dict keys: `norm1.*`, `conv1.*`, `norm2.*`, `conv2.*`, plus
/// `conv_shortcut.*` when the shortcut conv exists.
#[derive(Debug)]
pub struct ResnetBlock2D<T: Float> {
    /// First GroupNorm (over `in_channels`).
    pub norm1: GroupNorm<T>,
    /// First Conv2d (`in_channels -> out_channels`, k=3, pad=1).
    pub conv1: Conv2d<T>,
    /// Second GroupNorm (over `out_channels`).
    pub norm2: GroupNorm<T>,
    /// Second Conv2d (`out_channels -> out_channels`, k=3, pad=1).
    pub conv2: Conv2d<T>,
    /// Optional 1x1 shortcut conv (present iff `in_channels !=
    /// out_channels`).
    pub conv_shortcut: Option<Conv2d<T>>,
    /// SiLU non-linearity, reused at both activation sites of `forward`.
    activation: SiLU,
    /// Channel count that `forward` requires on axis 1 of its input.
    in_channels: usize,
    /// Output channel count. It is recorded for `Debug` output only and is not
    /// read by any method. Derived `Debug` impls do not count as a use for the
    /// `dead_code` lint, hence the allow.
    #[allow(dead_code)]
    out_channels: usize,
    /// Train/eval flag reported by `is_training`; `forward` does not read it.
    training: bool,
}
56
57impl<T: Float> ResnetBlock2D<T> {
58    /// Build a randomly-initialized `ResnetBlock2D`.
59    ///
60    /// # Errors
61    ///
62    /// Returns [`FerrotorchError::InvalidArgument`] on bad channel /
63    /// group config.
64    pub fn new(
65        in_channels: usize,
66        out_channels: usize,
67        norm_num_groups: usize,
68        eps: f64,
69    ) -> FerrotorchResult<Self> {
70        let norm1 = GroupNorm::<T>::new(norm_num_groups, in_channels, eps, true)?;
71        let conv1 = Conv2d::<T>::new(in_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
72        let norm2 = GroupNorm::<T>::new(norm_num_groups, out_channels, eps, true)?;
73        let conv2 = Conv2d::<T>::new(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
74        let conv_shortcut = if in_channels == out_channels {
75            None
76        } else {
77            Some(Conv2d::<T>::new(
78                in_channels,
79                out_channels,
80                (1, 1),
81                (1, 1),
82                (0, 0),
83                true,
84            )?)
85        };
86        Ok(Self {
87            norm1,
88            conv1,
89            norm2,
90            conv2,
91            conv_shortcut,
92            activation: SiLU::new(),
93            in_channels,
94            out_channels,
95            training: false,
96        })
97    }
98}
99
100impl<T: Float> Module<T> for ResnetBlock2D<T> {
101    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
102        if input.ndim() != 4 || input.shape()[1] != self.in_channels {
103            return Err(FerrotorchError::ShapeMismatch {
104                message: format!(
105                    "ResnetBlock2D::forward: expected [B, {}, H, W], got {:?}",
106                    self.in_channels,
107                    input.shape()
108                ),
109            });
110        }
111        // h = norm1(x); silu; conv1.
112        let mut h = self.norm1.forward(input)?;
113        h = self.activation.forward(&h)?;
114        h = self.conv1.forward(&h)?;
115        // h = norm2(h); silu; conv2 (no dropout at eval / inference time).
116        h = self.norm2.forward(&h)?;
117        h = self.activation.forward(&h)?;
118        h = self.conv2.forward(&h)?;
119        // Residual (optionally projected via 1x1 conv).
120        let res = if let Some(sc) = &self.conv_shortcut {
121            sc.forward(input)?
122        } else {
123            input.clone()
124        };
125        // out = h + res. `output_scale_factor=1.0` in the SD VAE so no
126        // division on the way out.
127        ferrotorch_core::grad_fns::arithmetic::add(&h, &res)
128    }
129
130    fn parameters(&self) -> Vec<&Parameter<T>> {
131        let mut out = Vec::new();
132        out.extend(self.norm1.parameters());
133        out.extend(self.conv1.parameters());
134        out.extend(self.norm2.parameters());
135        out.extend(self.conv2.parameters());
136        if let Some(sc) = &self.conv_shortcut {
137            out.extend(sc.parameters());
138        }
139        out
140    }
141
142    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
143        let mut out = Vec::new();
144        out.extend(self.norm1.parameters_mut());
145        out.extend(self.conv1.parameters_mut());
146        out.extend(self.norm2.parameters_mut());
147        out.extend(self.conv2.parameters_mut());
148        if let Some(sc) = &mut self.conv_shortcut {
149            out.extend(sc.parameters_mut());
150        }
151        out
152    }
153
154    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
155        let mut out = Vec::new();
156        for (n, p) in self.norm1.named_parameters() {
157            out.push((format!("norm1.{n}"), p));
158        }
159        for (n, p) in self.conv1.named_parameters() {
160            out.push((format!("conv1.{n}"), p));
161        }
162        for (n, p) in self.norm2.named_parameters() {
163            out.push((format!("norm2.{n}"), p));
164        }
165        for (n, p) in self.conv2.named_parameters() {
166            out.push((format!("conv2.{n}"), p));
167        }
168        if let Some(sc) = &self.conv_shortcut {
169            for (n, p) in sc.named_parameters() {
170                out.push((format!("conv_shortcut.{n}"), p));
171            }
172        }
173        out
174    }
175
176    fn train(&mut self) {
177        self.training = true;
178    }
179    fn eval(&mut self) {
180        self.training = false;
181    }
182    fn is_training(&self) -> bool {
183        self.training
184    }
185
186    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
187        let extract = |prefix: &str| -> StateDict<T> {
188            let p = format!("{prefix}.");
189            state
190                .iter()
191                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
192                .collect()
193        };
194        if strict {
195            for k in state.keys() {
196                let ok = k.starts_with("norm1.")
197                    || k.starts_with("conv1.")
198                    || k.starts_with("norm2.")
199                    || k.starts_with("conv2.")
200                    || k.starts_with("conv_shortcut.");
201                if !ok {
202                    return Err(FerrotorchError::InvalidArgument {
203                        message: format!("unexpected key in ResnetBlock2D state_dict: \"{k}\""),
204                    });
205                }
206            }
207        }
208        self.norm1.load_state_dict(&extract("norm1"), strict)?;
209        self.conv1.load_state_dict(&extract("conv1"), strict)?;
210        self.norm2.load_state_dict(&extract("norm2"), strict)?;
211        self.conv2.load_state_dict(&extract("conv2"), strict)?;
212        if let Some(sc) = self.conv_shortcut.as_mut() {
213            sc.load_state_dict(&extract("conv_shortcut"), strict)?;
214        }
215        Ok(())
216    }
217}
218
219// ---------------------------------------------------------------------------
220// AttnBlock2D
221// ---------------------------------------------------------------------------
222
/// Single-head spatial self-attention with residual + GroupNorm — the
/// VAE mid-block attention.
///
/// Matches `diffusers.models.attention_processor.Attention` configured
/// with:
///
///   - `heads = in_channels / attention_head_dim = 512 / 512 = 1`
///   - `norm_num_groups = 32` (so `group_norm` is enabled)
///   - `residual_connection = True`
///   - `bias = True`, `out_bias = True`
///   - `eps = 1e-6`
///
/// Because there is a single head, the attention scores are scaled by
/// `1 / sqrt(C)` over the full channel dimension. No output rescale
/// (diffusers' `rescale_output_factor`) is applied after the residual add.
///
/// State-dict layout (HF diffusers):
///
/// ```text
/// group_norm.{weight,bias}    [C], [C]
/// to_q.{weight,bias}          [C, C], [C]
/// to_k.{weight,bias}          [C, C], [C]
/// to_v.{weight,bias}          [C, C], [C]
/// to_out.0.{weight,bias}      [C, C], [C]      // to_out[1] is Dropout (no params)
/// ```
#[derive(Debug)]
pub struct AttnBlock2D<T: Float> {
    /// GroupNorm over the channel axis.
    pub group_norm: GroupNorm<T>,
    /// Query projection `Linear(C -> C, bias)`.
    pub to_q: Linear<T>,
    /// Key projection `Linear(C -> C, bias)`.
    pub to_k: Linear<T>,
    /// Value projection `Linear(C -> C, bias)`.
    pub to_v: Linear<T>,
    /// Output projection `to_out[0] = Linear(C -> C, bias)`.
    pub to_out_0: Linear<T>,
    /// Channel count `C` that `forward` requires on axis 1 of its input.
    channels: usize,
    /// Train/eval flag reported by `is_training`; `forward` does not read it.
    training: bool,
}
259
260impl<T: Float> AttnBlock2D<T> {
261    /// Build a randomly-initialized `AttnBlock2D`.
262    ///
263    /// # Errors
264    ///
265    /// Returns [`FerrotorchError::InvalidArgument`] when `channels`
266    /// is not divisible by `norm_num_groups` (the GroupNorm constructor
267    /// surfaces this).
268    pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
269        let group_norm = GroupNorm::<T>::new(norm_num_groups, channels, eps, true)?;
270        let to_q = Linear::<T>::new(channels, channels, true)?;
271        let to_k = Linear::<T>::new(channels, channels, true)?;
272        let to_v = Linear::<T>::new(channels, channels, true)?;
273        let to_out_0 = Linear::<T>::new(channels, channels, true)?;
274        Ok(Self {
275            group_norm,
276            to_q,
277            to_k,
278            to_v,
279            to_out_0,
280            channels,
281            training: false,
282        })
283    }
284}
285
286impl<T: Float> Module<T> for AttnBlock2D<T> {
287    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
288        if input.ndim() != 4 || input.shape()[1] != self.channels {
289            return Err(FerrotorchError::ShapeMismatch {
290                message: format!(
291                    "AttnBlock2D::forward: expected [B, {}, H, W], got {:?}",
292                    self.channels,
293                    input.shape()
294                ),
295            });
296        }
297        let b = input.shape()[0];
298        let c = input.shape()[1];
299        let h = input.shape()[2];
300        let w = input.shape()[3];
301        let hw = h * w;
302
303        // Residual (kept in the spatial layout).
304        let residual = input.clone();
305
306        // -- 1. Reshape to [B, HW, C]:
307        //       diffusers does `view(B, C, HW).transpose(1, 2)`.
308        //       Then it group-norms on the channel axis by re-transposing
309        //       back to [B, C, HW], norming, and re-transposing.
310        let hidden = input
311            .reshape_t(&[b as isize, c as isize, hw as isize])?
312            .transpose(1, 2)?
313            .contiguous()?;
314        // GroupNorm: input [B, C, HW] -> [B, HW, C] after the final transpose.
315        let normed_hwc = self
316            .group_norm
317            .forward(&hidden.transpose(1, 2)?.contiguous()?)?
318            .transpose(1, 2)?
319            .contiguous()?;
320
321        // -- 2. q / k / v projections (Linear over the trailing C dim).
322        let q = self.to_q.forward(&normed_hwc)?; // [B, HW, C]
323        let k = self.to_k.forward(&normed_hwc)?; // [B, HW, C]
324        let v = self.to_v.forward(&normed_hwc)?; // [B, HW, C]
325
326        // -- 3. Single-head attention:
327        //          scores = q @ k^T * (1/sqrt(C))
328        //          probs  = softmax(scores, dim=-1)
329        //          out    = probs @ v
330        //       For single-head the head-split is a no-op, so this is
331        //       just a per-batch BMM. `bmm` handles the [B, HW, HW]
332        //       intermediate without materialising a head axis.
333        let scale = (c as f64).sqrt().recip();
334        let scale_t = T::from(scale).ok_or_else(|| FerrotorchError::InvalidArgument {
335            message: "AttnBlock2D::forward: failed to cast attention scale into Float".into(),
336        })?;
337        let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
338        let k_t = k.transpose(1, 2)?.contiguous()?; // [B, C, HW]
339        let scores = q.bmm(&k_t)?; // [B, HW, HW]
340        let scores_scaled =
341            ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;
342        let probs = scores_scaled.softmax()?; // dim = -1 by default
343        let attended = probs.bmm(&v)?; // [B, HW, C]
344
345        // -- 4. Output projection.
346        let projected = self.to_out_0.forward(&attended)?; // [B, HW, C]
347
348        // -- 5. Back to [B, C, H, W] and add residual.
349        let back = projected
350            .transpose(1, 2)? // [B, C, HW]
351            .reshape_t(&[b as isize, c as isize, h as isize, w as isize])?
352            .contiguous()?;
353        ferrotorch_core::grad_fns::arithmetic::add(&back, &residual)
354    }
355
356    fn parameters(&self) -> Vec<&Parameter<T>> {
357        let mut out = Vec::new();
358        out.extend(self.group_norm.parameters());
359        out.extend(self.to_q.parameters());
360        out.extend(self.to_k.parameters());
361        out.extend(self.to_v.parameters());
362        out.extend(self.to_out_0.parameters());
363        out
364    }
365
366    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
367        let mut out = Vec::new();
368        out.extend(self.group_norm.parameters_mut());
369        out.extend(self.to_q.parameters_mut());
370        out.extend(self.to_k.parameters_mut());
371        out.extend(self.to_v.parameters_mut());
372        out.extend(self.to_out_0.parameters_mut());
373        out
374    }
375
376    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
377        let mut out = Vec::new();
378        for (n, p) in self.group_norm.named_parameters() {
379            out.push((format!("group_norm.{n}"), p));
380        }
381        for (n, p) in self.to_q.named_parameters() {
382            out.push((format!("to_q.{n}"), p));
383        }
384        for (n, p) in self.to_k.named_parameters() {
385            out.push((format!("to_k.{n}"), p));
386        }
387        for (n, p) in self.to_v.named_parameters() {
388            out.push((format!("to_v.{n}"), p));
389        }
390        for (n, p) in self.to_out_0.named_parameters() {
391            out.push((format!("to_out.0.{n}"), p));
392        }
393        out
394    }
395
396    fn train(&mut self) {
397        self.training = true;
398    }
399    fn eval(&mut self) {
400        self.training = false;
401    }
402    fn is_training(&self) -> bool {
403        self.training
404    }
405
406    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
407        let extract = |prefix: &str| -> StateDict<T> {
408            let p = format!("{prefix}.");
409            state
410                .iter()
411                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
412                .collect()
413        };
414        if strict {
415            for k in state.keys() {
416                let ok = k.starts_with("group_norm.")
417                    || k.starts_with("to_q.")
418                    || k.starts_with("to_k.")
419                    || k.starts_with("to_v.")
420                    || k.starts_with("to_out.0.");
421                if !ok {
422                    return Err(FerrotorchError::InvalidArgument {
423                        message: format!("unexpected key in AttnBlock2D state_dict: \"{k}\""),
424                    });
425                }
426            }
427        }
428        self.group_norm
429            .load_state_dict(&extract("group_norm"), strict)?;
430        self.to_q.load_state_dict(&extract("to_q"), strict)?;
431        self.to_k.load_state_dict(&extract("to_k"), strict)?;
432        self.to_v.load_state_dict(&extract("to_v"), strict)?;
433        self.to_out_0
434            .load_state_dict(&extract("to_out.0"), strict)?;
435        Ok(())
436    }
437}
438
439// ---------------------------------------------------------------------------
440// Upsample2D
441// ---------------------------------------------------------------------------
442
/// Diffusers-style `Upsample2D` — nearest-neighbor 2x interpolation
/// followed by a `Conv2d(C, C, k=3, pad=1, bias=True)`.
///
/// The interpolation target `[2H, 2W]` is computed from each input's spatial
/// size inside `forward`; it is not fixed at construction time.
///
/// State-dict key: `conv.{weight,bias}` (matching `name="conv"` in the
/// diffusers reference).
#[derive(Debug)]
pub struct Upsample2D<T: Float> {
    /// Output conv (also serves as the "post-upsample smoothing" filter).
    pub conv: Conv2d<T>,
    /// Channel count that `forward` requires on axis 1 of its input.
    channels: usize,
    /// Train/eval flag reported by `is_training`; `forward` does not read it.
    training: bool,
}
455
456impl<T: Float> Upsample2D<T> {
457    /// Build a randomly-initialized `Upsample2D` (`C -> C`, k=3, pad=1).
458    ///
459    /// # Errors
460    ///
461    /// Returns the underlying [`FerrotorchError`] on bad conv config.
462    pub fn new(channels: usize) -> FerrotorchResult<Self> {
463        let conv = Conv2d::<T>::new(channels, channels, (3, 3), (1, 1), (1, 1), true)?;
464        Ok(Self {
465            conv,
466            channels,
467            training: false,
468        })
469    }
470}
471
472impl<T: Float> Module<T> for Upsample2D<T> {
473    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
474        if input.ndim() != 4 || input.shape()[1] != self.channels {
475            return Err(FerrotorchError::ShapeMismatch {
476                message: format!(
477                    "Upsample2D::forward: expected [B, {}, H, W], got {:?}",
478                    self.channels,
479                    input.shape()
480                ),
481            });
482        }
483        let h = input.shape()[2];
484        let w = input.shape()[3];
485        let up = Upsample::new([h * 2, w * 2], InterpolateMode::Nearest);
486        let upsampled = up.forward(input)?;
487        self.conv.forward(&upsampled)
488    }
489
490    fn parameters(&self) -> Vec<&Parameter<T>> {
491        self.conv.parameters()
492    }
493    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
494        self.conv.parameters_mut()
495    }
496    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
497        self.conv
498            .named_parameters()
499            .into_iter()
500            .map(|(n, p)| (format!("conv.{n}"), p))
501            .collect()
502    }
503    fn train(&mut self) {
504        self.training = true;
505    }
506    fn eval(&mut self) {
507        self.training = false;
508    }
509    fn is_training(&self) -> bool {
510        self.training
511    }
512    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
513        let conv_sd: StateDict<T> = state
514            .iter()
515            .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
516            .collect();
517        if strict {
518            for k in state.keys() {
519                if !k.starts_with("conv.") {
520                    return Err(FerrotorchError::InvalidArgument {
521                        message: format!("unexpected key in Upsample2D state_dict: \"{k}\""),
522                    });
523                }
524            }
525        }
526        self.conv.load_state_dict(&conv_sd, strict)
527    }
528}
529
530// ---------------------------------------------------------------------------
531// Downsample2D
532// ---------------------------------------------------------------------------
533
534/// Diffusers-style `Downsample2D` — a single `Conv2d(C, C, k=3, stride=2,
535/// pad=1, bias=True)`.
536///
537/// SD-1.5 uses `use_conv=True` and `padding=1`, so there is no separate
538/// pre-conv padding step. State-dict key: `conv.{weight,bias}`.
539#[derive(Debug)]
540pub struct Downsample2D<T: Float> {
541    /// Output conv (k=3, stride=2, pad=1).
542    pub conv: Conv2d<T>,
543    channels: usize,
544    training: bool,
545}
546
547impl<T: Float> Downsample2D<T> {
548    /// Build a randomly-initialized `Downsample2D` (`C -> C`).
549    ///
550    /// # Errors
551    ///
552    /// Returns the underlying [`FerrotorchError`] on bad conv config.
553    pub fn new(channels: usize) -> FerrotorchResult<Self> {
554        let conv = Conv2d::<T>::new(channels, channels, (3, 3), (2, 2), (1, 1), true)?;
555        Ok(Self {
556            conv,
557            channels,
558            training: false,
559        })
560    }
561}
562
563impl<T: Float> Module<T> for Downsample2D<T> {
564    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
565        if input.ndim() != 4 || input.shape()[1] != self.channels {
566            return Err(FerrotorchError::ShapeMismatch {
567                message: format!(
568                    "Downsample2D::forward: expected [B, {}, H, W], got {:?}",
569                    self.channels,
570                    input.shape()
571                ),
572            });
573        }
574        self.conv.forward(input)
575    }
576    fn parameters(&self) -> Vec<&Parameter<T>> {
577        self.conv.parameters()
578    }
579    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
580        self.conv.parameters_mut()
581    }
582    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
583        self.conv
584            .named_parameters()
585            .into_iter()
586            .map(|(n, p)| (format!("conv.{n}"), p))
587            .collect()
588    }
589    fn train(&mut self) {
590        self.training = true;
591    }
592    fn eval(&mut self) {
593        self.training = false;
594    }
595    fn is_training(&self) -> bool {
596        self.training
597    }
598    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
599        let conv_sd: StateDict<T> = state
600            .iter()
601            .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
602            .collect();
603        if strict {
604            for k in state.keys() {
605                if !k.starts_with("conv.") {
606                    return Err(FerrotorchError::InvalidArgument {
607                        message: format!("unexpected key in Downsample2D state_dict: \"{k}\""),
608                    });
609                }
610            }
611        }
612        self.conv.load_state_dict(&conv_sd, strict)
613    }
614}
615
616// ---------------------------------------------------------------------------
617// UpDecoderBlock2D
618// ---------------------------------------------------------------------------
619
/// `UpDecoderBlock2D` — a stack of `layers_per_block + 1` resnets at
/// `out_channels`, optionally followed by an `Upsample2D`.
///
/// The constructor takes the resnet count directly (`num_resnets`), so
/// callers pass `layers_per_block + 1`. Only the first resnet maps
/// `in_channels -> out_channels`; the rest keep `out_channels`.
///
/// State-dict key layout:
///
/// ```text
/// resnets.{i}.<resnet keys>    i in 0..(layers_per_block + 1)
/// upsamplers.0.<upsample keys>  (present iff add_upsample)
/// ```
#[derive(Debug)]
pub struct UpDecoderBlock2D<T: Float> {
    /// `layers_per_block + 1` resnets at `out_channels`.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Optional upsample (absent on the last decoder block).
    pub upsamplers_0: Option<Upsample2D<T>>,
    /// Train/eval flag; `train`/`eval` also forward it to every sub-block.
    training: bool,
}
637
638impl<T: Float> UpDecoderBlock2D<T> {
639    /// Build a randomly-initialized `UpDecoderBlock2D`.
640    ///
641    /// # Errors
642    ///
643    /// Returns [`FerrotorchError::InvalidArgument`] on bad channel /
644    /// group config.
645    pub fn new(
646        in_channels: usize,
647        out_channels: usize,
648        num_resnets: usize,
649        norm_num_groups: usize,
650        eps: f64,
651        add_upsample: bool,
652    ) -> FerrotorchResult<Self> {
653        if num_resnets == 0 {
654            return Err(FerrotorchError::InvalidArgument {
655                message: "UpDecoderBlock2D: num_resnets must be > 0".into(),
656            });
657        }
658        let mut resnets = Vec::with_capacity(num_resnets);
659        for i in 0..num_resnets {
660            let in_c = if i == 0 { in_channels } else { out_channels };
661            resnets.push(ResnetBlock2D::<T>::new(
662                in_c,
663                out_channels,
664                norm_num_groups,
665                eps,
666            )?);
667        }
668        let upsamplers_0 = if add_upsample {
669            Some(Upsample2D::<T>::new(out_channels)?)
670        } else {
671            None
672        };
673        Ok(Self {
674            resnets,
675            upsamplers_0,
676            training: false,
677        })
678    }
679}
680
681impl<T: Float> Module<T> for UpDecoderBlock2D<T> {
682    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
683        let mut h = input.clone();
684        for r in &self.resnets {
685            h = r.forward(&h)?;
686        }
687        if let Some(u) = &self.upsamplers_0 {
688            h = u.forward(&h)?;
689        }
690        Ok(h)
691    }
692
693    fn parameters(&self) -> Vec<&Parameter<T>> {
694        let mut out = Vec::new();
695        for r in &self.resnets {
696            out.extend(r.parameters());
697        }
698        if let Some(u) = &self.upsamplers_0 {
699            out.extend(u.parameters());
700        }
701        out
702    }
703
704    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
705        let mut out = Vec::new();
706        for r in &mut self.resnets {
707            out.extend(r.parameters_mut());
708        }
709        if let Some(u) = self.upsamplers_0.as_mut() {
710            out.extend(u.parameters_mut());
711        }
712        out
713    }
714
715    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
716        let mut out = Vec::new();
717        for (i, r) in self.resnets.iter().enumerate() {
718            for (n, p) in r.named_parameters() {
719                out.push((format!("resnets.{i}.{n}"), p));
720            }
721        }
722        if let Some(u) = &self.upsamplers_0 {
723            for (n, p) in u.named_parameters() {
724                out.push((format!("upsamplers.0.{n}"), p));
725            }
726        }
727        out
728    }
729
730    fn train(&mut self) {
731        self.training = true;
732        for r in &mut self.resnets {
733            r.train();
734        }
735        if let Some(u) = self.upsamplers_0.as_mut() {
736            u.train();
737        }
738    }
739    fn eval(&mut self) {
740        self.training = false;
741        for r in &mut self.resnets {
742            r.eval();
743        }
744        if let Some(u) = self.upsamplers_0.as_mut() {
745            u.eval();
746        }
747    }
748    fn is_training(&self) -> bool {
749        self.training
750    }
751
752    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
753        let extract = |prefix: &str| -> StateDict<T> {
754            let p = format!("{prefix}.");
755            state
756                .iter()
757                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
758                .collect()
759        };
760        if strict {
761            for k in state.keys() {
762                let ok =
763                    k.starts_with("resnets.") || k.starts_with("upsamplers.0.");
764                if !ok {
765                    return Err(FerrotorchError::InvalidArgument {
766                        message: format!(
767                            "unexpected key in UpDecoderBlock2D state_dict: \"{k}\""
768                        ),
769                    });
770                }
771            }
772        }
773        for (i, r) in self.resnets.iter_mut().enumerate() {
774            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
775        }
776        if let Some(u) = self.upsamplers_0.as_mut() {
777            u.load_state_dict(&extract("upsamplers.0"), strict)?;
778        } else if strict {
779            // Verify no upsamplers keys leaked through if we have no upsample.
780            for k in state.keys() {
781                if k.starts_with("upsamplers.") {
782                    return Err(FerrotorchError::InvalidArgument {
783                        message: format!(
784                            "UpDecoderBlock2D has no upsampler but state_dict contains \"{k}\""
785                        ),
786                    });
787                }
788            }
789        }
790        Ok(())
791    }
792}
793
794// ---------------------------------------------------------------------------
795// UNetMidBlock2D (VAE flavour)
796// ---------------------------------------------------------------------------
797
/// `UNetMidBlock2D` configured the way the SD VAE uses it:
///
/// ```text
/// resnets[0] -> attentions[0] -> resnets[1]
/// ```
///
/// 2 resnets, 1 attention; all at `in_channels = out_channels = 512` for
/// SD 1.5. Diffusers always has `len(resnets) == num_layers + 1` and
/// `len(attentions) == num_layers`; in the VAE `num_layers = 1`, so two
/// resnets and one attention.
///
/// State-dict keys: `attentions.{i}.<attn keys>` and
/// `resnets.{i}.<resnet keys>`.
#[derive(Debug)]
pub struct UNetMidBlock2D<T: Float> {
    /// Pre-attention + post-attention resnets (size 2 for SD VAE).
    /// NOTE(review): `forward` indexes `resnets[0]` directly, so emptying this
    /// public Vec makes `forward` panic.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Single attention block.
    pub attentions: Vec<AttnBlock2D<T>>,
    /// Channel count passed to `new`; every sub-block uses it as both input
    /// and output width.
    channels: usize,
    /// Train/eval flag for this block.
    training: bool,
}
817
818impl<T: Float> UNetMidBlock2D<T> {
819    /// Build a randomly-initialized VAE-flavoured `UNetMidBlock2D`.
820    ///
821    /// # Errors
822    ///
823    /// Returns the underlying [`FerrotorchError`] on bad config.
824    pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
825        // SD VAE: num_layers=1 => 2 resnets + 1 attention.
826        let resnets = vec![
827            ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
828            ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
829        ];
830        let attentions = vec![AttnBlock2D::<T>::new(channels, norm_num_groups, eps)?];
831        Ok(Self {
832            resnets,
833            attentions,
834            channels,
835            training: false,
836        })
837    }
838}
839
840impl<T: Float> Module<T> for UNetMidBlock2D<T> {
841    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
842        // diffusers UNetMidBlock2D.forward (no temb):
843        //   h = resnets[0](x)
844        //   for (attn, resnet) in zip(attentions, resnets[1:]):
845        //       h = attn(h); h = resnet(h)
846        let mut h = self.resnets[0].forward(input)?;
847        for (attn, resnet) in self.attentions.iter().zip(self.resnets.iter().skip(1)) {
848            h = attn.forward(&h)?;
849            h = resnet.forward(&h)?;
850        }
851        Ok(h)
852    }
853
854    fn parameters(&self) -> Vec<&Parameter<T>> {
855        let mut out = Vec::new();
856        // The HF diffusers state-dict key order is `attentions` before
857        // `resnets`. We match the ordering used for state-dict
858        // round-trip — see `named_parameters()` for the canonical
859        // sequence.
860        for a in &self.attentions {
861            out.extend(a.parameters());
862        }
863        for r in &self.resnets {
864            out.extend(r.parameters());
865        }
866        out
867    }
868
869    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
870        let mut out = Vec::new();
871        for a in &mut self.attentions {
872            out.extend(a.parameters_mut());
873        }
874        for r in &mut self.resnets {
875            out.extend(r.parameters_mut());
876        }
877        out
878    }
879
880    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
881        let mut out = Vec::new();
882        // Match the order used by `safetensors` listings of the HF VAE:
883        // `attentions.*` first, then `resnets.*`.
884        for (i, a) in self.attentions.iter().enumerate() {
885            for (n, p) in a.named_parameters() {
886                out.push((format!("attentions.{i}.{n}"), p));
887            }
888        }
889        for (i, r) in self.resnets.iter().enumerate() {
890            for (n, p) in r.named_parameters() {
891                out.push((format!("resnets.{i}.{n}"), p));
892            }
893        }
894        out
895    }
896
897    fn train(&mut self) {
898        self.training = true;
899        for r in &mut self.resnets {
900            r.train();
901        }
902        for a in &mut self.attentions {
903            a.train();
904        }
905    }
906    fn eval(&mut self) {
907        self.training = false;
908        for r in &mut self.resnets {
909            r.eval();
910        }
911        for a in &mut self.attentions {
912            a.eval();
913        }
914    }
915    fn is_training(&self) -> bool {
916        self.training
917    }
918
919    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
920        let extract = |prefix: &str| -> StateDict<T> {
921            let p = format!("{prefix}.");
922            state
923                .iter()
924                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
925                .collect()
926        };
927        if strict {
928            for k in state.keys() {
929                let ok = k.starts_with("attentions.") || k.starts_with("resnets.");
930                if !ok {
931                    return Err(FerrotorchError::InvalidArgument {
932                        message: format!(
933                            "unexpected key in UNetMidBlock2D state_dict: \"{k}\""
934                        ),
935                    });
936                }
937            }
938        }
939        for (i, a) in self.attentions.iter_mut().enumerate() {
940            a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
941        }
942        for (i, r) in self.resnets.iter_mut().enumerate() {
943            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
944        }
945        let _ = self.channels; // silence dead_code warning under some lint configs
946        let _: HashMap<String, Tensor<T>> = HashMap::new(); // pull HashMap into scope
947        Ok(())
948    }
949}
950
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// CPU tensor of the given NCHW `shape`, every element equal to `value`,
    /// without gradient tracking.
    fn constant(value: f32, shape: [usize; 4]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(TensorStorage::cpu(vec![value; numel]), shape.to_vec(), false)
            .unwrap()
    }

    /// Keep only the names from a `named_parameters()` listing.
    fn names_of(pairs: Vec<(String, &Parameter<f32>)>) -> Vec<String> {
        pairs.into_iter().map(|(name, _)| name).collect()
    }

    /// Assert that every entry of `expected` appears in `names`.
    fn assert_has_all(names: &[String], expected: &[&str]) {
        for &k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn resnet_same_channels_no_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 32, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_none());
        let out = block.forward(&constant(0.01, [1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
        for v in out.data().unwrap().iter().copied() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn resnet_different_channels_has_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_some());
        let out = block.forward(&constant(0.01, [1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 64, 4, 4]);
    }

    #[test]
    fn resnet_named_parameters_layout() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        let names = names_of(block.named_parameters());
        assert_has_all(
            &names,
            &[
                "norm1.weight",
                "norm1.bias",
                "conv1.weight",
                "conv1.bias",
                "norm2.weight",
                "norm2.bias",
                "conv2.weight",
                "conv2.bias",
                "conv_shortcut.weight",
                "conv_shortcut.bias",
            ],
        );
    }

    #[test]
    fn attn_shape_and_residual() {
        // 32 channels split into 4 groups (8 channels each), tiny 4x4 map.
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        let out = attn.forward(&constant(0.01, [1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
        for v in out.data().unwrap().iter().copied() {
            assert!(v.is_finite(), "attn output non-finite: {v}");
        }
    }

    #[test]
    fn attn_named_parameters_layout() {
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        let names = names_of(attn.named_parameters());
        assert_has_all(
            &names,
            &[
                "group_norm.weight",
                "group_norm.bias",
                "to_q.weight",
                "to_q.bias",
                "to_k.weight",
                "to_k.bias",
                "to_v.weight",
                "to_v.bias",
                "to_out.0.weight",
                "to_out.0.bias",
            ],
        );
    }

    #[test]
    fn upsample2d_doubles_spatial() {
        let up = Upsample2D::<f32>::new(8).unwrap();
        let out = up.forward(&constant(0.05, [1, 8, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_with_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(16, 8, 2, 4, 1e-6, true).unwrap();
        let out = block.forward(&constant(0.05, [1, 16, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_no_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(8, 8, 2, 4, 1e-6, false).unwrap();
        let out = block.forward(&constant(0.05, [1, 8, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 3, 3]);
    }

    #[test]
    fn mid_block_shape() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        let out = mid.forward(&constant(0.05, [1, 16, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 16, 3, 3]);
    }

    #[test]
    fn mid_block_named_parameters_layout() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        let names = names_of(mid.named_parameters());
        assert_has_all(
            &names,
            &[
                "attentions.0.group_norm.weight",
                "attentions.0.to_q.weight",
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.1.conv2.weight",
            ],
        );
    }
}