Skip to main content

ferrotorch_diffusion/
blocks.rs

1//! Building blocks of the Stable-Diffusion VAE decoder.
2//!
//! All blocks here match the diffusers `models/{resnet,upsampling,downsampling,
//! attention_processor}.py` / `models/unets/unet_2d_blocks.py` reference
//! layout 1:1 in parameter naming and forward semantics, so the upstream
//! state dict (`runwayml/stable-diffusion-v1-5/vae/diffusion_pytorch_model.safetensors`)
//! loads byte-for-byte without any key remapping.
8//!
9//! ## REQ status (per `.design/ferrotorch-diffusion/blocks.md`)
10//!
11//! | REQ | Status | Evidence |
12//! |---|---|---|
13//! | REQ-1 | SHIPPED | `ResnetBlock2D` at `ferrotorch-diffusion/src/blocks.rs:35..215`; consumer: `ferrotorch-diffusion/src/vae.rs` and `vae_encoder.rs` consume it transitively via `UpDecoderBlock2D`/`DownEncoderBlock2D`/`UNetMidBlock2D` |
14//! | REQ-2 | SHIPPED | `AttnBlock2D` at `ferrotorch-diffusion/src/blocks.rs:242..434`; consumer: `UNetMidBlock2D::new` at `blocks.rs:1010` calls `AttnBlock2D::<T>::new` and `ferrotorch-diffusion/src/vae.rs:83` builds it |
15//! | REQ-3 | SHIPPED | `Upsample2D` at `ferrotorch-diffusion/src/blocks.rs:446..525`; consumer: `ferrotorch-diffusion/src/unet.rs:48` imports it for UNet up blocks |
16//! | REQ-4 | SHIPPED | `Downsample2D` at `ferrotorch-diffusion/src/blocks.rs:536..611`; consumer: `ferrotorch-diffusion/src/unet.rs:48` imports it for UNet down blocks |
17//! | REQ-5 | SHIPPED | `UpDecoderBlock2D` at `ferrotorch-diffusion/src/blocks.rs:626..786`; consumer: `ferrotorch-diffusion/src/vae.rs:92` constructs each decoder up-block |
18//! | REQ-6 | SHIPPED | `DownEncoderBlock2D` at `ferrotorch-diffusion/src/blocks.rs:806..972`; consumer: `ferrotorch-diffusion/src/vae_encoder.rs:123` constructs each encoder down-block |
19//! | REQ-7 | SHIPPED | `UNetMidBlock2D` at `ferrotorch-diffusion/src/blocks.rs:988..1127`; consumer: `ferrotorch-diffusion/src/vae.rs:83` and `vae_encoder.rs:130` both invoke `UNetMidBlock2D::<T>::new` |
20
21use std::collections::HashMap;
22
23use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
24use ferrotorch_nn::module::{Module, StateDict};
25use ferrotorch_nn::parameter::Parameter;
26use ferrotorch_nn::{Conv2d, GroupNorm, InterpolateMode, Linear, SiLU, Upsample};
27
28// ---------------------------------------------------------------------------
29// ResnetBlock2D
30// ---------------------------------------------------------------------------
31
/// `ResnetBlock2D` — the building block of every UNet/VAE up/down/mid stack.
///
/// The VAE flavour has no time embedding (`temb_channels=None` in
/// diffusers), so this implementation does not carry a
/// `time_emb_proj`. The forward pass is:
///
/// ```text
/// h = norm1(x); h = silu(h); h = conv1(h)
/// h = norm2(h); h = silu(h); h = conv2(h)
/// out = h + (x if in==out else conv_shortcut(x))
/// ```
///
/// The `output_scale_factor` from the diffusers reference is always 1.0
/// in the SD VAE so we hard-code it; if a future config requires
/// rescaling, add an explicit field. No dropout layer is held, so the
/// forward pass is the same in training and eval mode.
///
/// Field names (`norm1`, `conv1`, `norm2`, `conv2`, `conv_shortcut`) are
/// exactly the diffusers state-dict prefixes.
#[derive(Debug)]
pub struct ResnetBlock2D<T: Float> {
    /// First GroupNorm (over `in_channels`).
    pub norm1: GroupNorm<T>,
    /// First Conv2d (`in_channels -> out_channels`, k=3, pad=1).
    pub conv1: Conv2d<T>,
    /// Second GroupNorm (over `out_channels`).
    pub norm2: GroupNorm<T>,
    /// Second Conv2d (`out_channels -> out_channels`, k=3, pad=1).
    pub conv2: Conv2d<T>,
    /// Optional 1x1 shortcut conv (present iff `in_channels !=
    /// out_channels`).
    pub conv_shortcut: Option<Conv2d<T>>,
    /// SiLU applied after each of the two GroupNorms.
    activation: SiLU,
    /// Channel count the forward pass requires on axis 1 of its input.
    in_channels: usize,
    // Kept for `Debug` output / future use; no visible code reads it,
    // hence the lint allowance.
    #[allow(dead_code)]
    out_channels: usize,
    /// Mode flag set by `Module::train` / `Module::eval`.
    training: bool,
}
66
67impl<T: Float> ResnetBlock2D<T> {
68    /// Build a randomly-initialized `ResnetBlock2D`.
69    ///
70    /// # Errors
71    ///
72    /// Returns [`FerrotorchError::InvalidArgument`] on bad channel /
73    /// group config.
74    pub fn new(
75        in_channels: usize,
76        out_channels: usize,
77        norm_num_groups: usize,
78        eps: f64,
79    ) -> FerrotorchResult<Self> {
80        let norm1 = GroupNorm::<T>::new(norm_num_groups, in_channels, eps, true)?;
81        let conv1 = Conv2d::<T>::new(in_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
82        let norm2 = GroupNorm::<T>::new(norm_num_groups, out_channels, eps, true)?;
83        let conv2 = Conv2d::<T>::new(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
84        let conv_shortcut = if in_channels == out_channels {
85            None
86        } else {
87            Some(Conv2d::<T>::new(
88                in_channels,
89                out_channels,
90                (1, 1),
91                (1, 1),
92                (0, 0),
93                true,
94            )?)
95        };
96        Ok(Self {
97            norm1,
98            conv1,
99            norm2,
100            conv2,
101            conv_shortcut,
102            activation: SiLU::new(),
103            in_channels,
104            out_channels,
105            training: false,
106        })
107    }
108}
109
110impl<T: Float> Module<T> for ResnetBlock2D<T> {
111    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
112        if input.ndim() != 4 || input.shape()[1] != self.in_channels {
113            return Err(FerrotorchError::ShapeMismatch {
114                message: format!(
115                    "ResnetBlock2D::forward: expected [B, {}, H, W], got {:?}",
116                    self.in_channels,
117                    input.shape()
118                ),
119            });
120        }
121        // h = norm1(x); silu; conv1.
122        let mut h = self.norm1.forward(input)?;
123        h = self.activation.forward(&h)?;
124        h = self.conv1.forward(&h)?;
125        // h = norm2(h); silu; conv2 (no dropout at eval / inference time).
126        h = self.norm2.forward(&h)?;
127        h = self.activation.forward(&h)?;
128        h = self.conv2.forward(&h)?;
129        // Residual (optionally projected via 1x1 conv).
130        let res = if let Some(sc) = &self.conv_shortcut {
131            sc.forward(input)?
132        } else {
133            input.clone()
134        };
135        // out = h + res. `output_scale_factor=1.0` in the SD VAE so no
136        // division on the way out.
137        ferrotorch_core::grad_fns::arithmetic::add(&h, &res)
138    }
139
140    fn parameters(&self) -> Vec<&Parameter<T>> {
141        let mut out = Vec::new();
142        out.extend(self.norm1.parameters());
143        out.extend(self.conv1.parameters());
144        out.extend(self.norm2.parameters());
145        out.extend(self.conv2.parameters());
146        if let Some(sc) = &self.conv_shortcut {
147            out.extend(sc.parameters());
148        }
149        out
150    }
151
152    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
153        let mut out = Vec::new();
154        out.extend(self.norm1.parameters_mut());
155        out.extend(self.conv1.parameters_mut());
156        out.extend(self.norm2.parameters_mut());
157        out.extend(self.conv2.parameters_mut());
158        if let Some(sc) = &mut self.conv_shortcut {
159            out.extend(sc.parameters_mut());
160        }
161        out
162    }
163
164    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
165        let mut out = Vec::new();
166        for (n, p) in self.norm1.named_parameters() {
167            out.push((format!("norm1.{n}"), p));
168        }
169        for (n, p) in self.conv1.named_parameters() {
170            out.push((format!("conv1.{n}"), p));
171        }
172        for (n, p) in self.norm2.named_parameters() {
173            out.push((format!("norm2.{n}"), p));
174        }
175        for (n, p) in self.conv2.named_parameters() {
176            out.push((format!("conv2.{n}"), p));
177        }
178        if let Some(sc) = &self.conv_shortcut {
179            for (n, p) in sc.named_parameters() {
180                out.push((format!("conv_shortcut.{n}"), p));
181            }
182        }
183        out
184    }
185
186    fn train(&mut self) {
187        self.training = true;
188    }
189    fn eval(&mut self) {
190        self.training = false;
191    }
192    fn is_training(&self) -> bool {
193        self.training
194    }
195
196    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
197        let extract = |prefix: &str| -> StateDict<T> {
198            let p = format!("{prefix}.");
199            state
200                .iter()
201                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
202                .collect()
203        };
204        if strict {
205            for k in state.keys() {
206                let ok = k.starts_with("norm1.")
207                    || k.starts_with("conv1.")
208                    || k.starts_with("norm2.")
209                    || k.starts_with("conv2.")
210                    || k.starts_with("conv_shortcut.");
211                if !ok {
212                    return Err(FerrotorchError::InvalidArgument {
213                        message: format!("unexpected key in ResnetBlock2D state_dict: \"{k}\""),
214                    });
215                }
216            }
217        }
218        self.norm1.load_state_dict(&extract("norm1"), strict)?;
219        self.conv1.load_state_dict(&extract("conv1"), strict)?;
220        self.norm2.load_state_dict(&extract("norm2"), strict)?;
221        self.conv2.load_state_dict(&extract("conv2"), strict)?;
222        if let Some(sc) = self.conv_shortcut.as_mut() {
223            sc.load_state_dict(&extract("conv_shortcut"), strict)?;
224        }
225        Ok(())
226    }
227}
228
229// ---------------------------------------------------------------------------
230// AttnBlock2D
231// ---------------------------------------------------------------------------
232
/// Single-head spatial self-attention with residual + GroupNorm — the
/// VAE mid-block attention.
///
/// Matches `diffusers.models.attention_processor.Attention` configured
/// with:
///
///   - `heads = in_channels / attention_head_dim = 512 / 512 = 1`
///   - `norm_num_groups = 32` (so `group_norm` is enabled)
///   - `residual_connection = True`
///   - `bias = True`, `out_bias = True`
///   - `eps = 1e-6`
///
/// State-dict layout (HF diffusers):
///
/// ```text
/// group_norm.{weight,bias}    [C], [C]
/// to_q.{weight,bias}          [C, C], [C]
/// to_k.{weight,bias}          [C, C], [C]
/// to_v.{weight,bias}          [C, C], [C]
/// to_out.0.{weight,bias}      [C, C], [C]      // to_out[1] is Dropout (no params)
/// ```
///
/// The `to_out[1]` Dropout is not modelled here, so the block computes the
/// same result in training and eval mode.
#[derive(Debug)]
pub struct AttnBlock2D<T: Float> {
    /// GroupNorm over the channel axis.
    pub group_norm: GroupNorm<T>,
    /// Query projection `Linear(C -> C, bias)`.
    pub to_q: Linear<T>,
    /// Key projection `Linear(C -> C, bias)`.
    pub to_k: Linear<T>,
    /// Value projection `Linear(C -> C, bias)`.
    pub to_v: Linear<T>,
    /// Output projection `to_out[0] = Linear(C -> C, bias)`.
    pub to_out_0: Linear<T>,
    /// Channel count `C`: required on axis 1 of the forward input and used
    /// for the `1/sqrt(C)` attention scale.
    channels: usize,
    /// Mode flag set by `Module::train` / `Module::eval`.
    training: bool,
}
269
270impl<T: Float> AttnBlock2D<T> {
271    /// Build a randomly-initialized `AttnBlock2D`.
272    ///
273    /// # Errors
274    ///
275    /// Returns [`FerrotorchError::InvalidArgument`] when `channels`
276    /// is not divisible by `norm_num_groups` (the GroupNorm constructor
277    /// surfaces this).
278    pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
279        let group_norm = GroupNorm::<T>::new(norm_num_groups, channels, eps, true)?;
280        let to_q = Linear::<T>::new(channels, channels, true)?;
281        let to_k = Linear::<T>::new(channels, channels, true)?;
282        let to_v = Linear::<T>::new(channels, channels, true)?;
283        let to_out_0 = Linear::<T>::new(channels, channels, true)?;
284        Ok(Self {
285            group_norm,
286            to_q,
287            to_k,
288            to_v,
289            to_out_0,
290            channels,
291            training: false,
292        })
293    }
294}
295
296impl<T: Float> Module<T> for AttnBlock2D<T> {
297    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
298        if input.ndim() != 4 || input.shape()[1] != self.channels {
299            return Err(FerrotorchError::ShapeMismatch {
300                message: format!(
301                    "AttnBlock2D::forward: expected [B, {}, H, W], got {:?}",
302                    self.channels,
303                    input.shape()
304                ),
305            });
306        }
307        let b = input.shape()[0];
308        let c = input.shape()[1];
309        let h = input.shape()[2];
310        let w = input.shape()[3];
311        let hw = h * w;
312
313        // Residual (kept in the spatial layout).
314        let residual = input.clone();
315
316        // -- 1. Reshape to [B, HW, C]:
317        //       diffusers does `view(B, C, HW).transpose(1, 2)`.
318        //       Then it group-norms on the channel axis by re-transposing
319        //       back to [B, C, HW], norming, and re-transposing.
320        let hidden = input
321            .reshape_t(&[b as isize, c as isize, hw as isize])?
322            .transpose(1, 2)?
323            .contiguous()?;
324        // GroupNorm: input [B, C, HW] -> [B, HW, C] after the final transpose.
325        let normed_hwc = self
326            .group_norm
327            .forward(&hidden.transpose(1, 2)?.contiguous()?)?
328            .transpose(1, 2)?
329            .contiguous()?;
330
331        // -- 2. q / k / v projections (Linear over the trailing C dim).
332        let q = self.to_q.forward(&normed_hwc)?; // [B, HW, C]
333        let k = self.to_k.forward(&normed_hwc)?; // [B, HW, C]
334        let v = self.to_v.forward(&normed_hwc)?; // [B, HW, C]
335
336        // -- 3. Single-head attention:
337        //          scores = q @ k^T * (1/sqrt(C))
338        //          probs  = softmax(scores, dim=-1)
339        //          out    = probs @ v
340        //       For single-head the head-split is a no-op, so this is
341        //       just a per-batch BMM. `bmm` handles the [B, HW, HW]
342        //       intermediate without materialising a head axis.
343        let scale = (c as f64).sqrt().recip();
344        let scale_t = T::from(scale).ok_or_else(|| FerrotorchError::InvalidArgument {
345            message: "AttnBlock2D::forward: failed to cast attention scale into Float".into(),
346        })?;
347        let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
348        let k_t = k.transpose(1, 2)?.contiguous()?; // [B, C, HW]
349        let scores = q.bmm(&k_t)?; // [B, HW, HW]
350        let scores_scaled = ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;
351        let probs = scores_scaled.softmax()?; // dim = -1 by default
352        let attended = probs.bmm(&v)?; // [B, HW, C]
353
354        // -- 4. Output projection.
355        let projected = self.to_out_0.forward(&attended)?; // [B, HW, C]
356
357        // -- 5. Back to [B, C, H, W] and add residual.
358        let back = projected
359            .transpose(1, 2)? // [B, C, HW]
360            .reshape_t(&[b as isize, c as isize, h as isize, w as isize])?
361            .contiguous()?;
362        ferrotorch_core::grad_fns::arithmetic::add(&back, &residual)
363    }
364
365    fn parameters(&self) -> Vec<&Parameter<T>> {
366        let mut out = Vec::new();
367        out.extend(self.group_norm.parameters());
368        out.extend(self.to_q.parameters());
369        out.extend(self.to_k.parameters());
370        out.extend(self.to_v.parameters());
371        out.extend(self.to_out_0.parameters());
372        out
373    }
374
375    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
376        let mut out = Vec::new();
377        out.extend(self.group_norm.parameters_mut());
378        out.extend(self.to_q.parameters_mut());
379        out.extend(self.to_k.parameters_mut());
380        out.extend(self.to_v.parameters_mut());
381        out.extend(self.to_out_0.parameters_mut());
382        out
383    }
384
385    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
386        let mut out = Vec::new();
387        for (n, p) in self.group_norm.named_parameters() {
388            out.push((format!("group_norm.{n}"), p));
389        }
390        for (n, p) in self.to_q.named_parameters() {
391            out.push((format!("to_q.{n}"), p));
392        }
393        for (n, p) in self.to_k.named_parameters() {
394            out.push((format!("to_k.{n}"), p));
395        }
396        for (n, p) in self.to_v.named_parameters() {
397            out.push((format!("to_v.{n}"), p));
398        }
399        for (n, p) in self.to_out_0.named_parameters() {
400            out.push((format!("to_out.0.{n}"), p));
401        }
402        out
403    }
404
405    fn train(&mut self) {
406        self.training = true;
407    }
408    fn eval(&mut self) {
409        self.training = false;
410    }
411    fn is_training(&self) -> bool {
412        self.training
413    }
414
415    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
416        let extract = |prefix: &str| -> StateDict<T> {
417            let p = format!("{prefix}.");
418            state
419                .iter()
420                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
421                .collect()
422        };
423        if strict {
424            for k in state.keys() {
425                let ok = k.starts_with("group_norm.")
426                    || k.starts_with("to_q.")
427                    || k.starts_with("to_k.")
428                    || k.starts_with("to_v.")
429                    || k.starts_with("to_out.0.");
430                if !ok {
431                    return Err(FerrotorchError::InvalidArgument {
432                        message: format!("unexpected key in AttnBlock2D state_dict: \"{k}\""),
433                    });
434                }
435            }
436        }
437        self.group_norm
438            .load_state_dict(&extract("group_norm"), strict)?;
439        self.to_q.load_state_dict(&extract("to_q"), strict)?;
440        self.to_k.load_state_dict(&extract("to_k"), strict)?;
441        self.to_v.load_state_dict(&extract("to_v"), strict)?;
442        self.to_out_0
443            .load_state_dict(&extract("to_out.0"), strict)?;
444        Ok(())
445    }
446}
447
448// ---------------------------------------------------------------------------
449// Upsample2D
450// ---------------------------------------------------------------------------
451
/// Diffusers-style `Upsample2D` — nearest-neighbor 2x interpolation
/// followed by a `Conv2d(C, C, k=3, pad=1, bias=True)`.
///
/// The channel count is unchanged; the spatial size doubles.
///
/// State-dict key: `conv.{weight,bias}` (matching `name="conv"` in the
/// diffusers reference).
#[derive(Debug)]
pub struct Upsample2D<T: Float> {
    /// Output conv (also serves as the "post-upsample smoothing" filter).
    pub conv: Conv2d<T>,
    /// Channel count the forward pass requires on axis 1 of its input.
    channels: usize,
    /// Mode flag set by `Module::train` / `Module::eval`.
    training: bool,
}
464
465impl<T: Float> Upsample2D<T> {
466    /// Build a randomly-initialized `Upsample2D` (`C -> C`, k=3, pad=1).
467    ///
468    /// # Errors
469    ///
470    /// Returns the underlying [`FerrotorchError`] on bad conv config.
471    pub fn new(channels: usize) -> FerrotorchResult<Self> {
472        let conv = Conv2d::<T>::new(channels, channels, (3, 3), (1, 1), (1, 1), true)?;
473        Ok(Self {
474            conv,
475            channels,
476            training: false,
477        })
478    }
479}
480
481impl<T: Float> Module<T> for Upsample2D<T> {
482    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
483        if input.ndim() != 4 || input.shape()[1] != self.channels {
484            return Err(FerrotorchError::ShapeMismatch {
485                message: format!(
486                    "Upsample2D::forward: expected [B, {}, H, W], got {:?}",
487                    self.channels,
488                    input.shape()
489                ),
490            });
491        }
492        let h = input.shape()[2];
493        let w = input.shape()[3];
494        let up = Upsample::new([h * 2, w * 2], InterpolateMode::Nearest);
495        let upsampled = up.forward(input)?;
496        self.conv.forward(&upsampled)
497    }
498
499    fn parameters(&self) -> Vec<&Parameter<T>> {
500        self.conv.parameters()
501    }
502    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
503        self.conv.parameters_mut()
504    }
505    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
506        self.conv
507            .named_parameters()
508            .into_iter()
509            .map(|(n, p)| (format!("conv.{n}"), p))
510            .collect()
511    }
512    fn train(&mut self) {
513        self.training = true;
514    }
515    fn eval(&mut self) {
516        self.training = false;
517    }
518    fn is_training(&self) -> bool {
519        self.training
520    }
521    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
522        let conv_sd: StateDict<T> = state
523            .iter()
524            .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
525            .collect();
526        if strict {
527            for k in state.keys() {
528                if !k.starts_with("conv.") {
529                    return Err(FerrotorchError::InvalidArgument {
530                        message: format!("unexpected key in Upsample2D state_dict: \"{k}\""),
531                    });
532                }
533            }
534        }
535        self.conv.load_state_dict(&conv_sd, strict)
536    }
537}
538
539// ---------------------------------------------------------------------------
540// Downsample2D
541// ---------------------------------------------------------------------------
542
/// Diffusers-style `Downsample2D` — a single `Conv2d(C, C, k=3, stride=2,
/// pad=1, bias=True)`.
///
/// SD-1.5 uses `use_conv=True` and `padding=1`, so there is no separate
/// pre-conv padding step. State-dict key: `conv.{weight,bias}`.
///
/// By the standard conv output-size formula, an `H x W` input becomes
/// `ceil(H/2) x ceil(W/2)`; the channel count is unchanged.
#[derive(Debug)]
pub struct Downsample2D<T: Float> {
    /// Output conv (k=3, stride=2, pad=1).
    pub conv: Conv2d<T>,
    /// Channel count the forward pass requires on axis 1 of its input.
    channels: usize,
    /// Mode flag set by `Module::train` / `Module::eval`.
    training: bool,
}
555
556impl<T: Float> Downsample2D<T> {
557    /// Build a randomly-initialized `Downsample2D` (`C -> C`).
558    ///
559    /// # Errors
560    ///
561    /// Returns the underlying [`FerrotorchError`] on bad conv config.
562    pub fn new(channels: usize) -> FerrotorchResult<Self> {
563        let conv = Conv2d::<T>::new(channels, channels, (3, 3), (2, 2), (1, 1), true)?;
564        Ok(Self {
565            conv,
566            channels,
567            training: false,
568        })
569    }
570}
571
572impl<T: Float> Module<T> for Downsample2D<T> {
573    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
574        if input.ndim() != 4 || input.shape()[1] != self.channels {
575            return Err(FerrotorchError::ShapeMismatch {
576                message: format!(
577                    "Downsample2D::forward: expected [B, {}, H, W], got {:?}",
578                    self.channels,
579                    input.shape()
580                ),
581            });
582        }
583        self.conv.forward(input)
584    }
585    fn parameters(&self) -> Vec<&Parameter<T>> {
586        self.conv.parameters()
587    }
588    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
589        self.conv.parameters_mut()
590    }
591    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
592        self.conv
593            .named_parameters()
594            .into_iter()
595            .map(|(n, p)| (format!("conv.{n}"), p))
596            .collect()
597    }
598    fn train(&mut self) {
599        self.training = true;
600    }
601    fn eval(&mut self) {
602        self.training = false;
603    }
604    fn is_training(&self) -> bool {
605        self.training
606    }
607    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
608        let conv_sd: StateDict<T> = state
609            .iter()
610            .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
611            .collect();
612        if strict {
613            for k in state.keys() {
614                if !k.starts_with("conv.") {
615                    return Err(FerrotorchError::InvalidArgument {
616                        message: format!("unexpected key in Downsample2D state_dict: \"{k}\""),
617                    });
618                }
619            }
620        }
621        self.conv.load_state_dict(&conv_sd, strict)
622    }
623}
624
625// ---------------------------------------------------------------------------
626// UpDecoderBlock2D
627// ---------------------------------------------------------------------------
628
/// `UpDecoderBlock2D` — a stack of `layers_per_block + 1` resnets at
/// `out_channels`, optionally followed by an `Upsample2D`.
///
/// The resnet count is whatever the constructor's `num_resnets` argument
/// is; callers are expected to pass `layers_per_block + 1`. Only the first
/// resnet maps `in_channels -> out_channels`.
///
/// State-dict key layout:
///
/// ```text
/// resnets.{i}.<resnet keys>    i in 0..(layers_per_block + 1)
/// upsamplers.0.<upsample keys>  (present iff add_upsample)
/// ```
#[derive(Debug)]
pub struct UpDecoderBlock2D<T: Float> {
    /// `layers_per_block + 1` resnets at `out_channels`.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Optional upsample (absent on the last decoder block).
    pub upsamplers_0: Option<Upsample2D<T>>,
    /// Mode flag set by `Module::train` / `Module::eval`.
    training: bool,
}
646
647impl<T: Float> UpDecoderBlock2D<T> {
648    /// Build a randomly-initialized `UpDecoderBlock2D`.
649    ///
650    /// # Errors
651    ///
652    /// Returns [`FerrotorchError::InvalidArgument`] on bad channel /
653    /// group config.
654    pub fn new(
655        in_channels: usize,
656        out_channels: usize,
657        num_resnets: usize,
658        norm_num_groups: usize,
659        eps: f64,
660        add_upsample: bool,
661    ) -> FerrotorchResult<Self> {
662        if num_resnets == 0 {
663            return Err(FerrotorchError::InvalidArgument {
664                message: "UpDecoderBlock2D: num_resnets must be > 0".into(),
665            });
666        }
667        let mut resnets = Vec::with_capacity(num_resnets);
668        for i in 0..num_resnets {
669            let in_c = if i == 0 { in_channels } else { out_channels };
670            resnets.push(ResnetBlock2D::<T>::new(
671                in_c,
672                out_channels,
673                norm_num_groups,
674                eps,
675            )?);
676        }
677        let upsamplers_0 = if add_upsample {
678            Some(Upsample2D::<T>::new(out_channels)?)
679        } else {
680            None
681        };
682        Ok(Self {
683            resnets,
684            upsamplers_0,
685            training: false,
686        })
687    }
688}
689
690impl<T: Float> Module<T> for UpDecoderBlock2D<T> {
691    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
692        let mut h = input.clone();
693        for r in &self.resnets {
694            h = r.forward(&h)?;
695        }
696        if let Some(u) = &self.upsamplers_0 {
697            h = u.forward(&h)?;
698        }
699        Ok(h)
700    }
701
702    fn parameters(&self) -> Vec<&Parameter<T>> {
703        let mut out = Vec::new();
704        for r in &self.resnets {
705            out.extend(r.parameters());
706        }
707        if let Some(u) = &self.upsamplers_0 {
708            out.extend(u.parameters());
709        }
710        out
711    }
712
713    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
714        let mut out = Vec::new();
715        for r in &mut self.resnets {
716            out.extend(r.parameters_mut());
717        }
718        if let Some(u) = self.upsamplers_0.as_mut() {
719            out.extend(u.parameters_mut());
720        }
721        out
722    }
723
724    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
725        let mut out = Vec::new();
726        for (i, r) in self.resnets.iter().enumerate() {
727            for (n, p) in r.named_parameters() {
728                out.push((format!("resnets.{i}.{n}"), p));
729            }
730        }
731        if let Some(u) = &self.upsamplers_0 {
732            for (n, p) in u.named_parameters() {
733                out.push((format!("upsamplers.0.{n}"), p));
734            }
735        }
736        out
737    }
738
739    fn train(&mut self) {
740        self.training = true;
741        for r in &mut self.resnets {
742            r.train();
743        }
744        if let Some(u) = self.upsamplers_0.as_mut() {
745            u.train();
746        }
747    }
748    fn eval(&mut self) {
749        self.training = false;
750        for r in &mut self.resnets {
751            r.eval();
752        }
753        if let Some(u) = self.upsamplers_0.as_mut() {
754            u.eval();
755        }
756    }
757    fn is_training(&self) -> bool {
758        self.training
759    }
760
761    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
762        let extract = |prefix: &str| -> StateDict<T> {
763            let p = format!("{prefix}.");
764            state
765                .iter()
766                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
767                .collect()
768        };
769        if strict {
770            for k in state.keys() {
771                let ok = k.starts_with("resnets.") || k.starts_with("upsamplers.0.");
772                if !ok {
773                    return Err(FerrotorchError::InvalidArgument {
774                        message: format!("unexpected key in UpDecoderBlock2D state_dict: \"{k}\""),
775                    });
776                }
777            }
778        }
779        for (i, r) in self.resnets.iter_mut().enumerate() {
780            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
781        }
782        if let Some(u) = self.upsamplers_0.as_mut() {
783            u.load_state_dict(&extract("upsamplers.0"), strict)?;
784        } else if strict {
785            // Verify no upsamplers keys leaked through if we have no upsample.
786            for k in state.keys() {
787                if k.starts_with("upsamplers.") {
788                    return Err(FerrotorchError::InvalidArgument {
789                        message: format!(
790                            "UpDecoderBlock2D has no upsampler but state_dict contains \"{k}\""
791                        ),
792                    });
793                }
794            }
795        }
796        Ok(())
797    }
798}
799
800// ---------------------------------------------------------------------------
801// DownEncoderBlock2D
802// ---------------------------------------------------------------------------
803
/// `DownEncoderBlock2D` — a stack of `layers_per_block` resnets at
/// `out_channels`, optionally followed by a `Downsample2D`.
///
/// This is the encoder-side mirror of [`UpDecoderBlock2D`]. The
/// asymmetry in resnet count (encoder uses `layers_per_block` resnets,
/// decoder uses `layers_per_block + 1`) matches the diffusers
/// `AutoencoderKL` convention exactly. The actual count is the
/// constructor's `num_resnets` argument.
///
/// State-dict key layout:
///
/// ```text
/// resnets.{i}.<resnet keys>      i in 0..layers_per_block
/// downsamplers.0.<downsample keys>   (present iff add_downsample)
/// ```
#[derive(Debug)]
pub struct DownEncoderBlock2D<T: Float> {
    /// `layers_per_block` resnets at `out_channels`. The first resnet
    /// projects `in_channels -> out_channels` (via its 1x1 shortcut
    /// when the channels differ); subsequent resnets are `out_channels`
    /// throughout.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Optional downsample (absent on the deepest encoder block, which
    /// holds spatial resolution before the mid-block).
    pub downsamplers_0: Option<Downsample2D<T>>,
    /// Mode flag set by `Module::train` / `Module::eval`.
    training: bool,
}
830
831impl<T: Float> DownEncoderBlock2D<T> {
832    /// Build a randomly-initialized `DownEncoderBlock2D`.
833    ///
834    /// # Errors
835    ///
836    /// Returns [`FerrotorchError::InvalidArgument`] when `num_resnets`
837    /// is zero, or any underlying [`FerrotorchError`] on bad channel /
838    /// group config.
839    pub fn new(
840        in_channels: usize,
841        out_channels: usize,
842        num_resnets: usize,
843        norm_num_groups: usize,
844        eps: f64,
845        add_downsample: bool,
846    ) -> FerrotorchResult<Self> {
847        if num_resnets == 0 {
848            return Err(FerrotorchError::InvalidArgument {
849                message: "DownEncoderBlock2D: num_resnets must be > 0".into(),
850            });
851        }
852        let mut resnets = Vec::with_capacity(num_resnets);
853        for i in 0..num_resnets {
854            let in_c = if i == 0 { in_channels } else { out_channels };
855            resnets.push(ResnetBlock2D::<T>::new(
856                in_c,
857                out_channels,
858                norm_num_groups,
859                eps,
860            )?);
861        }
862        let downsamplers_0 = if add_downsample {
863            Some(Downsample2D::<T>::new(out_channels)?)
864        } else {
865            None
866        };
867        Ok(Self {
868            resnets,
869            downsamplers_0,
870            training: false,
871        })
872    }
873}
874
875impl<T: Float> Module<T> for DownEncoderBlock2D<T> {
876    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
877        let mut h = input.clone();
878        for r in &self.resnets {
879            h = r.forward(&h)?;
880        }
881        if let Some(d) = &self.downsamplers_0 {
882            h = d.forward(&h)?;
883        }
884        Ok(h)
885    }
886
887    fn parameters(&self) -> Vec<&Parameter<T>> {
888        let mut out = Vec::new();
889        for r in &self.resnets {
890            out.extend(r.parameters());
891        }
892        if let Some(d) = &self.downsamplers_0 {
893            out.extend(d.parameters());
894        }
895        out
896    }
897
898    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
899        let mut out = Vec::new();
900        for r in &mut self.resnets {
901            out.extend(r.parameters_mut());
902        }
903        if let Some(d) = self.downsamplers_0.as_mut() {
904            out.extend(d.parameters_mut());
905        }
906        out
907    }
908
909    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
910        let mut out = Vec::new();
911        for (i, r) in self.resnets.iter().enumerate() {
912            for (n, p) in r.named_parameters() {
913                out.push((format!("resnets.{i}.{n}"), p));
914            }
915        }
916        if let Some(d) = &self.downsamplers_0 {
917            for (n, p) in d.named_parameters() {
918                out.push((format!("downsamplers.0.{n}"), p));
919            }
920        }
921        out
922    }
923
924    fn train(&mut self) {
925        self.training = true;
926        for r in &mut self.resnets {
927            r.train();
928        }
929        if let Some(d) = self.downsamplers_0.as_mut() {
930            d.train();
931        }
932    }
933    fn eval(&mut self) {
934        self.training = false;
935        for r in &mut self.resnets {
936            r.eval();
937        }
938        if let Some(d) = self.downsamplers_0.as_mut() {
939            d.eval();
940        }
941    }
942    fn is_training(&self) -> bool {
943        self.training
944    }
945
946    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
947        let extract = |prefix: &str| -> StateDict<T> {
948            let p = format!("{prefix}.");
949            state
950                .iter()
951                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
952                .collect()
953        };
954        if strict {
955            for k in state.keys() {
956                let ok = k.starts_with("resnets.") || k.starts_with("downsamplers.0.");
957                if !ok {
958                    return Err(FerrotorchError::InvalidArgument {
959                        message: format!(
960                            "unexpected key in DownEncoderBlock2D state_dict: \"{k}\""
961                        ),
962                    });
963                }
964            }
965        }
966        for (i, r) in self.resnets.iter_mut().enumerate() {
967            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
968        }
969        if let Some(d) = self.downsamplers_0.as_mut() {
970            d.load_state_dict(&extract("downsamplers.0"), strict)?;
971        } else if strict {
972            for k in state.keys() {
973                if k.starts_with("downsamplers.") {
974                    return Err(FerrotorchError::InvalidArgument {
975                        message: format!(
976                            "DownEncoderBlock2D has no downsampler but state_dict contains \"{k}\""
977                        ),
978                    });
979                }
980            }
981        }
982        Ok(())
983    }
984}
985
986// ---------------------------------------------------------------------------
987// UNetMidBlock2D (VAE flavour)
988// ---------------------------------------------------------------------------
989
/// `UNetMidBlock2D` configured the way the SD VAE uses it:
///
/// ```text
/// resnets[0] -> attentions[0] -> resnets[1]
/// ```
///
/// 2 resnets, 1 attention; all at `in_channels = out_channels = 512` for
/// SD 1.5. Diffusers always has `len(resnets) == num_layers + 1` and
/// `len(attentions) == num_layers`; in the VAE `num_layers = 1`, so two
/// resnets and one attention.
///
/// State-dict key layout (the order `named_parameters` emits them):
///
/// ```text
/// attentions.{i}.<attention keys>
/// resnets.{i}.<resnet keys>
/// ```
#[derive(Debug)]
pub struct UNetMidBlock2D<T: Float> {
    /// Pre-attention + post-attention resnets (size 2 for SD VAE).
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Single attention block.
    pub attentions: Vec<AttnBlock2D<T>>,
    // Channel count passed to `new`; every child operates at this width.
    channels: usize,
    // Set by `train()` / `eval()` and reported by `is_training()`; starts
    // `false` (eval mode) after construction.
    training: bool,
}
1009
1010impl<T: Float> UNetMidBlock2D<T> {
1011    /// Build a randomly-initialized VAE-flavoured `UNetMidBlock2D`.
1012    ///
1013    /// # Errors
1014    ///
1015    /// Returns the underlying [`FerrotorchError`] on bad config.
1016    pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
1017        // SD VAE: num_layers=1 => 2 resnets + 1 attention.
1018        let resnets = vec![
1019            ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
1020            ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
1021        ];
1022        let attentions = vec![AttnBlock2D::<T>::new(channels, norm_num_groups, eps)?];
1023        Ok(Self {
1024            resnets,
1025            attentions,
1026            channels,
1027            training: false,
1028        })
1029    }
1030}
1031
1032impl<T: Float> Module<T> for UNetMidBlock2D<T> {
1033    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1034        // diffusers UNetMidBlock2D.forward (no temb):
1035        //   h = resnets[0](x)
1036        //   for (attn, resnet) in zip(attentions, resnets[1:]):
1037        //       h = attn(h); h = resnet(h)
1038        let mut h = self.resnets[0].forward(input)?;
1039        for (attn, resnet) in self.attentions.iter().zip(self.resnets.iter().skip(1)) {
1040            h = attn.forward(&h)?;
1041            h = resnet.forward(&h)?;
1042        }
1043        Ok(h)
1044    }
1045
1046    fn parameters(&self) -> Vec<&Parameter<T>> {
1047        let mut out = Vec::new();
1048        // The HF diffusers state-dict key order is `attentions` before
1049        // `resnets`. We match the ordering used for state-dict
1050        // round-trip — see `named_parameters()` for the canonical
1051        // sequence.
1052        for a in &self.attentions {
1053            out.extend(a.parameters());
1054        }
1055        for r in &self.resnets {
1056            out.extend(r.parameters());
1057        }
1058        out
1059    }
1060
1061    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1062        let mut out = Vec::new();
1063        for a in &mut self.attentions {
1064            out.extend(a.parameters_mut());
1065        }
1066        for r in &mut self.resnets {
1067            out.extend(r.parameters_mut());
1068        }
1069        out
1070    }
1071
1072    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1073        let mut out = Vec::new();
1074        // Match the order used by `safetensors` listings of the HF VAE:
1075        // `attentions.*` first, then `resnets.*`.
1076        for (i, a) in self.attentions.iter().enumerate() {
1077            for (n, p) in a.named_parameters() {
1078                out.push((format!("attentions.{i}.{n}"), p));
1079            }
1080        }
1081        for (i, r) in self.resnets.iter().enumerate() {
1082            for (n, p) in r.named_parameters() {
1083                out.push((format!("resnets.{i}.{n}"), p));
1084            }
1085        }
1086        out
1087    }
1088
1089    fn train(&mut self) {
1090        self.training = true;
1091        for r in &mut self.resnets {
1092            r.train();
1093        }
1094        for a in &mut self.attentions {
1095            a.train();
1096        }
1097    }
1098    fn eval(&mut self) {
1099        self.training = false;
1100        for r in &mut self.resnets {
1101            r.eval();
1102        }
1103        for a in &mut self.attentions {
1104            a.eval();
1105        }
1106    }
1107    fn is_training(&self) -> bool {
1108        self.training
1109    }
1110
1111    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1112        let extract = |prefix: &str| -> StateDict<T> {
1113            let p = format!("{prefix}.");
1114            state
1115                .iter()
1116                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1117                .collect()
1118        };
1119        if strict {
1120            for k in state.keys() {
1121                let ok = k.starts_with("attentions.") || k.starts_with("resnets.");
1122                if !ok {
1123                    return Err(FerrotorchError::InvalidArgument {
1124                        message: format!("unexpected key in UNetMidBlock2D state_dict: \"{k}\""),
1125                    });
1126                }
1127            }
1128        }
1129        for (i, a) in self.attentions.iter_mut().enumerate() {
1130            a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
1131        }
1132        for (i, r) in self.resnets.iter_mut().enumerate() {
1133            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
1134        }
1135        let _ = self.channels; // silence dead_code warning under some lint configs
1136        let _: HashMap<String, Tensor<T>> = HashMap::new(); // pull HashMap into scope
1137        Ok(())
1138    }
1139}
1140
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// CPU tensor of shape `[1, channels, height, width]` with every element
    /// set to `value`, built via `Tensor::from_storage(.., false)`.
    fn constant(value: f32, channels: usize, height: usize, width: usize) -> Tensor<f32> {
        let storage = TensorStorage::cpu(vec![value; channels * height * width]);
        Tensor::from_storage(storage, vec![1, channels, height, width], false).unwrap()
    }

    /// Assert that every name in `expected` appears among the keys of
    /// `params` (the output of `named_parameters()`).
    fn assert_has_names<P>(params: Vec<(String, P)>, expected: &[&str]) {
        let names: Vec<String> = params.into_iter().map(|(name, _)| name).collect();
        for &k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn resnet_same_channels_no_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 32, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_none());
        let out = block.forward(&constant(0.01, 32, 4, 4)).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn resnet_different_channels_has_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_some());
        let out = block.forward(&constant(0.01, 32, 4, 4)).unwrap();
        assert_eq!(out.shape(), &[1, 64, 4, 4]);
    }

    #[test]
    fn resnet_named_parameters_layout() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert_has_names(
            block.named_parameters(),
            &[
                "norm1.weight",
                "norm1.bias",
                "conv1.weight",
                "conv1.bias",
                "norm2.weight",
                "norm2.bias",
                "conv2.weight",
                "conv2.bias",
                "conv_shortcut.weight",
                "conv_shortcut.bias",
            ],
        );
    }

    #[test]
    fn attn_shape_and_residual() {
        // 32 channels split into 4 groups (8 channels per group), 4x4 spatial.
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        let out = attn.forward(&constant(0.01, 32, 4, 4)).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite(), "attn output non-finite: {v}");
        }
    }

    #[test]
    fn attn_named_parameters_layout() {
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        assert_has_names(
            attn.named_parameters(),
            &[
                "group_norm.weight",
                "group_norm.bias",
                "to_q.weight",
                "to_q.bias",
                "to_k.weight",
                "to_k.bias",
                "to_v.weight",
                "to_v.bias",
                "to_out.0.weight",
                "to_out.0.bias",
            ],
        );
    }

    #[test]
    fn upsample2d_doubles_spatial() {
        let up = Upsample2D::<f32>::new(8).unwrap();
        let out = up.forward(&constant(0.05, 8, 3, 3)).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_with_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(16, 8, 2, 4, 1e-6, true).unwrap();
        let out = block.forward(&constant(0.05, 16, 3, 3)).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_no_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(8, 8, 2, 4, 1e-6, false).unwrap();
        let out = block.forward(&constant(0.05, 8, 3, 3)).unwrap();
        assert_eq!(out.shape(), &[1, 8, 3, 3]);
    }

    #[test]
    fn down_encoder_block_shape_with_downsample() {
        // [1, 8, 4, 4]: resnets project 8 -> 16, then the downsample halves
        // the spatial size.
        let block = DownEncoderBlock2D::<f32>::new(8, 16, 2, 4, 1e-6, true).unwrap();
        let out = block.forward(&constant(0.05, 8, 4, 4)).unwrap();
        assert_eq!(out.shape(), &[1, 16, 2, 2]);
    }

    #[test]
    fn down_encoder_block_shape_no_downsample() {
        // Deepest encoder block: same channels, no downsample, so the
        // spatial size is preserved.
        let block = DownEncoderBlock2D::<f32>::new(16, 16, 2, 4, 1e-6, false).unwrap();
        let out = block.forward(&constant(0.05, 16, 2, 2)).unwrap();
        assert_eq!(out.shape(), &[1, 16, 2, 2]);
    }

    #[test]
    fn down_encoder_block_named_parameters_layout() {
        let block = DownEncoderBlock2D::<f32>::new(8, 16, 2, 4, 1e-6, true).unwrap();
        assert_has_names(
            block.named_parameters(),
            &[
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.0.conv_shortcut.weight",
                "resnets.1.norm1.weight",
                "downsamplers.0.conv.weight",
                "downsamplers.0.conv.bias",
            ],
        );
    }

    #[test]
    fn mid_block_shape() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        let out = mid.forward(&constant(0.05, 16, 3, 3)).unwrap();
        assert_eq!(out.shape(), &[1, 16, 3, 3]);
    }

    #[test]
    fn mid_block_named_parameters_layout() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        assert_has_names(
            mid.named_parameters(),
            &[
                "attentions.0.group_norm.weight",
                "attentions.0.to_q.weight",
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.1.conv2.weight",
            ],
        );
    }
}