Skip to main content

ferrotorch_diffusion/
blocks.rs

//! Building blocks of the Stable-Diffusion VAE (encoder and decoder).
//!
//! All blocks here match the diffusers `models/{resnet,upsampling,downsampling,
//! attention_processor}.py` / `models/unets/unet_2d_blocks.py` reference
//! layout 1:1 in parameter naming and forward semantics so the upstream
//! state dict (`runwayml/stable-diffusion-v1-5/vae/diffusion_pytorch_model.safetensors`)
//! loads byte-for-byte.
8
9use std::collections::HashMap;
10
11use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
12use ferrotorch_nn::module::{Module, StateDict};
13use ferrotorch_nn::parameter::Parameter;
14use ferrotorch_nn::{
15    Conv2d, GroupNorm, InterpolateMode, Linear, SiLU, Upsample,
16};
17
18// ---------------------------------------------------------------------------
19// ResnetBlock2D
20// ---------------------------------------------------------------------------
21
22/// `ResnetBlock2D` — the building block of every UNet/VAE up/down/mid stack.
23///
24/// The VAE flavour has no time embedding (`temb_channels=None` in
25/// diffusers), so this implementation does not carry a
26/// `time_emb_proj`. The forward pass is:
27///
28/// ```text
29/// h = norm1(x); h = silu(h); h = conv1(h)
30/// h = norm2(h); h = silu(h); h = conv2(h)
31/// out = h + (x if in==out else conv_shortcut(x))
32/// ```
33///
34/// The `output_scale_factor` from the diffusers reference is always 1.0
35/// in the SD VAE so we hard-code it; if a future config requires
36/// rescaling, add an explicit field.
37#[derive(Debug)]
38pub struct ResnetBlock2D<T: Float> {
39    /// First GroupNorm (over `in_channels`).
40    pub norm1: GroupNorm<T>,
41    /// First Conv2d (`in_channels -> out_channels`, k=3, pad=1).
42    pub conv1: Conv2d<T>,
43    /// Second GroupNorm (over `out_channels`).
44    pub norm2: GroupNorm<T>,
45    /// Second Conv2d (`out_channels -> out_channels`, k=3, pad=1).
46    pub conv2: Conv2d<T>,
47    /// Optional 1x1 shortcut conv (present iff `in_channels !=
48    /// out_channels`).
49    pub conv_shortcut: Option<Conv2d<T>>,
50    activation: SiLU,
51    in_channels: usize,
52    #[allow(dead_code)]
53    out_channels: usize,
54    training: bool,
55}
56
57impl<T: Float> ResnetBlock2D<T> {
58    /// Build a randomly-initialized `ResnetBlock2D`.
59    ///
60    /// # Errors
61    ///
62    /// Returns [`FerrotorchError::InvalidArgument`] on bad channel /
63    /// group config.
64    pub fn new(
65        in_channels: usize,
66        out_channels: usize,
67        norm_num_groups: usize,
68        eps: f64,
69    ) -> FerrotorchResult<Self> {
70        let norm1 = GroupNorm::<T>::new(norm_num_groups, in_channels, eps, true)?;
71        let conv1 = Conv2d::<T>::new(in_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
72        let norm2 = GroupNorm::<T>::new(norm_num_groups, out_channels, eps, true)?;
73        let conv2 = Conv2d::<T>::new(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
74        let conv_shortcut = if in_channels == out_channels {
75            None
76        } else {
77            Some(Conv2d::<T>::new(
78                in_channels,
79                out_channels,
80                (1, 1),
81                (1, 1),
82                (0, 0),
83                true,
84            )?)
85        };
86        Ok(Self {
87            norm1,
88            conv1,
89            norm2,
90            conv2,
91            conv_shortcut,
92            activation: SiLU::new(),
93            in_channels,
94            out_channels,
95            training: false,
96        })
97    }
98}
99
100impl<T: Float> Module<T> for ResnetBlock2D<T> {
101    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
102        if input.ndim() != 4 || input.shape()[1] != self.in_channels {
103            return Err(FerrotorchError::ShapeMismatch {
104                message: format!(
105                    "ResnetBlock2D::forward: expected [B, {}, H, W], got {:?}",
106                    self.in_channels,
107                    input.shape()
108                ),
109            });
110        }
111        // h = norm1(x); silu; conv1.
112        let mut h = self.norm1.forward(input)?;
113        h = self.activation.forward(&h)?;
114        h = self.conv1.forward(&h)?;
115        // h = norm2(h); silu; conv2 (no dropout at eval / inference time).
116        h = self.norm2.forward(&h)?;
117        h = self.activation.forward(&h)?;
118        h = self.conv2.forward(&h)?;
119        // Residual (optionally projected via 1x1 conv).
120        let res = if let Some(sc) = &self.conv_shortcut {
121            sc.forward(input)?
122        } else {
123            input.clone()
124        };
125        // out = h + res. `output_scale_factor=1.0` in the SD VAE so no
126        // division on the way out.
127        ferrotorch_core::grad_fns::arithmetic::add(&h, &res)
128    }
129
130    fn parameters(&self) -> Vec<&Parameter<T>> {
131        let mut out = Vec::new();
132        out.extend(self.norm1.parameters());
133        out.extend(self.conv1.parameters());
134        out.extend(self.norm2.parameters());
135        out.extend(self.conv2.parameters());
136        if let Some(sc) = &self.conv_shortcut {
137            out.extend(sc.parameters());
138        }
139        out
140    }
141
142    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
143        let mut out = Vec::new();
144        out.extend(self.norm1.parameters_mut());
145        out.extend(self.conv1.parameters_mut());
146        out.extend(self.norm2.parameters_mut());
147        out.extend(self.conv2.parameters_mut());
148        if let Some(sc) = &mut self.conv_shortcut {
149            out.extend(sc.parameters_mut());
150        }
151        out
152    }
153
154    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
155        let mut out = Vec::new();
156        for (n, p) in self.norm1.named_parameters() {
157            out.push((format!("norm1.{n}"), p));
158        }
159        for (n, p) in self.conv1.named_parameters() {
160            out.push((format!("conv1.{n}"), p));
161        }
162        for (n, p) in self.norm2.named_parameters() {
163            out.push((format!("norm2.{n}"), p));
164        }
165        for (n, p) in self.conv2.named_parameters() {
166            out.push((format!("conv2.{n}"), p));
167        }
168        if let Some(sc) = &self.conv_shortcut {
169            for (n, p) in sc.named_parameters() {
170                out.push((format!("conv_shortcut.{n}"), p));
171            }
172        }
173        out
174    }
175
176    fn train(&mut self) {
177        self.training = true;
178    }
179    fn eval(&mut self) {
180        self.training = false;
181    }
182    fn is_training(&self) -> bool {
183        self.training
184    }
185
186    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
187        let extract = |prefix: &str| -> StateDict<T> {
188            let p = format!("{prefix}.");
189            state
190                .iter()
191                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
192                .collect()
193        };
194        if strict {
195            for k in state.keys() {
196                let ok = k.starts_with("norm1.")
197                    || k.starts_with("conv1.")
198                    || k.starts_with("norm2.")
199                    || k.starts_with("conv2.")
200                    || k.starts_with("conv_shortcut.");
201                if !ok {
202                    return Err(FerrotorchError::InvalidArgument {
203                        message: format!("unexpected key in ResnetBlock2D state_dict: \"{k}\""),
204                    });
205                }
206            }
207        }
208        self.norm1.load_state_dict(&extract("norm1"), strict)?;
209        self.conv1.load_state_dict(&extract("conv1"), strict)?;
210        self.norm2.load_state_dict(&extract("norm2"), strict)?;
211        self.conv2.load_state_dict(&extract("conv2"), strict)?;
212        if let Some(sc) = self.conv_shortcut.as_mut() {
213            sc.load_state_dict(&extract("conv_shortcut"), strict)?;
214        }
215        Ok(())
216    }
217}
218
219// ---------------------------------------------------------------------------
220// AttnBlock2D
221// ---------------------------------------------------------------------------
222
223/// Single-head spatial self-attention with residual + GroupNorm — the
224/// VAE mid-block attention.
225///
226/// Matches `diffusers.models.attention_processor.Attention` configured
227/// with:
228///
229///   - `heads = in_channels / attention_head_dim = 512 / 512 = 1`
230///   - `norm_num_groups = 32` (so `group_norm` is enabled)
231///   - `residual_connection = True`
232///   - `bias = True`, `out_bias = True`
233///   - `eps = 1e-6`
234///
235/// State-dict layout (HF diffusers):
236///
237/// ```text
238/// group_norm.{weight,bias}    [C], [C]
239/// to_q.{weight,bias}          [C, C], [C]
240/// to_k.{weight,bias}          [C, C], [C]
241/// to_v.{weight,bias}          [C, C], [C]
242/// to_out.0.{weight,bias}      [C, C], [C]      // to_out[1] is Dropout (no params)
243/// ```
244#[derive(Debug)]
245pub struct AttnBlock2D<T: Float> {
246    /// GroupNorm over the channel axis.
247    pub group_norm: GroupNorm<T>,
248    /// Query projection `Linear(C -> C, bias)`.
249    pub to_q: Linear<T>,
250    /// Key projection `Linear(C -> C, bias)`.
251    pub to_k: Linear<T>,
252    /// Value projection `Linear(C -> C, bias)`.
253    pub to_v: Linear<T>,
254    /// Output projection `to_out[0] = Linear(C -> C, bias)`.
255    pub to_out_0: Linear<T>,
256    channels: usize,
257    training: bool,
258}
259
260impl<T: Float> AttnBlock2D<T> {
261    /// Build a randomly-initialized `AttnBlock2D`.
262    ///
263    /// # Errors
264    ///
265    /// Returns [`FerrotorchError::InvalidArgument`] when `channels`
266    /// is not divisible by `norm_num_groups` (the GroupNorm constructor
267    /// surfaces this).
268    pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
269        let group_norm = GroupNorm::<T>::new(norm_num_groups, channels, eps, true)?;
270        let to_q = Linear::<T>::new(channels, channels, true)?;
271        let to_k = Linear::<T>::new(channels, channels, true)?;
272        let to_v = Linear::<T>::new(channels, channels, true)?;
273        let to_out_0 = Linear::<T>::new(channels, channels, true)?;
274        Ok(Self {
275            group_norm,
276            to_q,
277            to_k,
278            to_v,
279            to_out_0,
280            channels,
281            training: false,
282        })
283    }
284}
285
286impl<T: Float> Module<T> for AttnBlock2D<T> {
287    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
288        if input.ndim() != 4 || input.shape()[1] != self.channels {
289            return Err(FerrotorchError::ShapeMismatch {
290                message: format!(
291                    "AttnBlock2D::forward: expected [B, {}, H, W], got {:?}",
292                    self.channels,
293                    input.shape()
294                ),
295            });
296        }
297        let b = input.shape()[0];
298        let c = input.shape()[1];
299        let h = input.shape()[2];
300        let w = input.shape()[3];
301        let hw = h * w;
302
303        // Residual (kept in the spatial layout).
304        let residual = input.clone();
305
306        // -- 1. Reshape to [B, HW, C]:
307        //       diffusers does `view(B, C, HW).transpose(1, 2)`.
308        //       Then it group-norms on the channel axis by re-transposing
309        //       back to [B, C, HW], norming, and re-transposing.
310        let hidden = input
311            .reshape_t(&[b as isize, c as isize, hw as isize])?
312            .transpose(1, 2)?
313            .contiguous()?;
314        // GroupNorm: input [B, C, HW] -> [B, HW, C] after the final transpose.
315        let normed_hwc = self
316            .group_norm
317            .forward(&hidden.transpose(1, 2)?.contiguous()?)?
318            .transpose(1, 2)?
319            .contiguous()?;
320
321        // -- 2. q / k / v projections (Linear over the trailing C dim).
322        let q = self.to_q.forward(&normed_hwc)?; // [B, HW, C]
323        let k = self.to_k.forward(&normed_hwc)?; // [B, HW, C]
324        let v = self.to_v.forward(&normed_hwc)?; // [B, HW, C]
325
326        // -- 3. Single-head attention:
327        //          scores = q @ k^T * (1/sqrt(C))
328        //          probs  = softmax(scores, dim=-1)
329        //          out    = probs @ v
330        //       For single-head the head-split is a no-op, so this is
331        //       just a per-batch BMM. `bmm` handles the [B, HW, HW]
332        //       intermediate without materialising a head axis.
333        let scale = (c as f64).sqrt().recip();
334        let scale_t = T::from(scale).ok_or_else(|| FerrotorchError::InvalidArgument {
335            message: "AttnBlock2D::forward: failed to cast attention scale into Float".into(),
336        })?;
337        let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
338        let k_t = k.transpose(1, 2)?.contiguous()?; // [B, C, HW]
339        let scores = q.bmm(&k_t)?; // [B, HW, HW]
340        let scores_scaled =
341            ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;
342        let probs = scores_scaled.softmax()?; // dim = -1 by default
343        let attended = probs.bmm(&v)?; // [B, HW, C]
344
345        // -- 4. Output projection.
346        let projected = self.to_out_0.forward(&attended)?; // [B, HW, C]
347
348        // -- 5. Back to [B, C, H, W] and add residual.
349        let back = projected
350            .transpose(1, 2)? // [B, C, HW]
351            .reshape_t(&[b as isize, c as isize, h as isize, w as isize])?
352            .contiguous()?;
353        ferrotorch_core::grad_fns::arithmetic::add(&back, &residual)
354    }
355
356    fn parameters(&self) -> Vec<&Parameter<T>> {
357        let mut out = Vec::new();
358        out.extend(self.group_norm.parameters());
359        out.extend(self.to_q.parameters());
360        out.extend(self.to_k.parameters());
361        out.extend(self.to_v.parameters());
362        out.extend(self.to_out_0.parameters());
363        out
364    }
365
366    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
367        let mut out = Vec::new();
368        out.extend(self.group_norm.parameters_mut());
369        out.extend(self.to_q.parameters_mut());
370        out.extend(self.to_k.parameters_mut());
371        out.extend(self.to_v.parameters_mut());
372        out.extend(self.to_out_0.parameters_mut());
373        out
374    }
375
376    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
377        let mut out = Vec::new();
378        for (n, p) in self.group_norm.named_parameters() {
379            out.push((format!("group_norm.{n}"), p));
380        }
381        for (n, p) in self.to_q.named_parameters() {
382            out.push((format!("to_q.{n}"), p));
383        }
384        for (n, p) in self.to_k.named_parameters() {
385            out.push((format!("to_k.{n}"), p));
386        }
387        for (n, p) in self.to_v.named_parameters() {
388            out.push((format!("to_v.{n}"), p));
389        }
390        for (n, p) in self.to_out_0.named_parameters() {
391            out.push((format!("to_out.0.{n}"), p));
392        }
393        out
394    }
395
396    fn train(&mut self) {
397        self.training = true;
398    }
399    fn eval(&mut self) {
400        self.training = false;
401    }
402    fn is_training(&self) -> bool {
403        self.training
404    }
405
406    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
407        let extract = |prefix: &str| -> StateDict<T> {
408            let p = format!("{prefix}.");
409            state
410                .iter()
411                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
412                .collect()
413        };
414        if strict {
415            for k in state.keys() {
416                let ok = k.starts_with("group_norm.")
417                    || k.starts_with("to_q.")
418                    || k.starts_with("to_k.")
419                    || k.starts_with("to_v.")
420                    || k.starts_with("to_out.0.");
421                if !ok {
422                    return Err(FerrotorchError::InvalidArgument {
423                        message: format!("unexpected key in AttnBlock2D state_dict: \"{k}\""),
424                    });
425                }
426            }
427        }
428        self.group_norm
429            .load_state_dict(&extract("group_norm"), strict)?;
430        self.to_q.load_state_dict(&extract("to_q"), strict)?;
431        self.to_k.load_state_dict(&extract("to_k"), strict)?;
432        self.to_v.load_state_dict(&extract("to_v"), strict)?;
433        self.to_out_0
434            .load_state_dict(&extract("to_out.0"), strict)?;
435        Ok(())
436    }
437}
438
439// ---------------------------------------------------------------------------
440// Upsample2D
441// ---------------------------------------------------------------------------
442
443/// Diffusers-style `Upsample2D` — nearest-neighbor 2x interpolation
444/// followed by a `Conv2d(C, C, k=3, pad=1, bias=True)`.
445///
446/// State-dict key: `conv.{weight,bias}` (matching `name="conv"` in the
447/// diffusers reference).
448#[derive(Debug)]
449pub struct Upsample2D<T: Float> {
450    /// Output conv (also serves as the "post-upsample smoothing" filter).
451    pub conv: Conv2d<T>,
452    channels: usize,
453    training: bool,
454}
455
456impl<T: Float> Upsample2D<T> {
457    /// Build a randomly-initialized `Upsample2D` (`C -> C`, k=3, pad=1).
458    ///
459    /// # Errors
460    ///
461    /// Returns the underlying [`FerrotorchError`] on bad conv config.
462    pub fn new(channels: usize) -> FerrotorchResult<Self> {
463        let conv = Conv2d::<T>::new(channels, channels, (3, 3), (1, 1), (1, 1), true)?;
464        Ok(Self {
465            conv,
466            channels,
467            training: false,
468        })
469    }
470}
471
472impl<T: Float> Module<T> for Upsample2D<T> {
473    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
474        if input.ndim() != 4 || input.shape()[1] != self.channels {
475            return Err(FerrotorchError::ShapeMismatch {
476                message: format!(
477                    "Upsample2D::forward: expected [B, {}, H, W], got {:?}",
478                    self.channels,
479                    input.shape()
480                ),
481            });
482        }
483        let h = input.shape()[2];
484        let w = input.shape()[3];
485        let up = Upsample::new([h * 2, w * 2], InterpolateMode::Nearest);
486        let upsampled = up.forward(input)?;
487        self.conv.forward(&upsampled)
488    }
489
490    fn parameters(&self) -> Vec<&Parameter<T>> {
491        self.conv.parameters()
492    }
493    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
494        self.conv.parameters_mut()
495    }
496    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
497        self.conv
498            .named_parameters()
499            .into_iter()
500            .map(|(n, p)| (format!("conv.{n}"), p))
501            .collect()
502    }
503    fn train(&mut self) {
504        self.training = true;
505    }
506    fn eval(&mut self) {
507        self.training = false;
508    }
509    fn is_training(&self) -> bool {
510        self.training
511    }
512    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
513        let conv_sd: StateDict<T> = state
514            .iter()
515            .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
516            .collect();
517        if strict {
518            for k in state.keys() {
519                if !k.starts_with("conv.") {
520                    return Err(FerrotorchError::InvalidArgument {
521                        message: format!("unexpected key in Upsample2D state_dict: \"{k}\""),
522                    });
523                }
524            }
525        }
526        self.conv.load_state_dict(&conv_sd, strict)
527    }
528}
529
530// ---------------------------------------------------------------------------
531// Downsample2D
532// ---------------------------------------------------------------------------
533
534/// Diffusers-style `Downsample2D` — a single `Conv2d(C, C, k=3, stride=2,
535/// pad=1, bias=True)`.
536///
537/// SD-1.5 uses `use_conv=True` and `padding=1`, so there is no separate
538/// pre-conv padding step. State-dict key: `conv.{weight,bias}`.
539#[derive(Debug)]
540pub struct Downsample2D<T: Float> {
541    /// Output conv (k=3, stride=2, pad=1).
542    pub conv: Conv2d<T>,
543    channels: usize,
544    training: bool,
545}
546
547impl<T: Float> Downsample2D<T> {
548    /// Build a randomly-initialized `Downsample2D` (`C -> C`).
549    ///
550    /// # Errors
551    ///
552    /// Returns the underlying [`FerrotorchError`] on bad conv config.
553    pub fn new(channels: usize) -> FerrotorchResult<Self> {
554        let conv = Conv2d::<T>::new(channels, channels, (3, 3), (2, 2), (1, 1), true)?;
555        Ok(Self {
556            conv,
557            channels,
558            training: false,
559        })
560    }
561}
562
563impl<T: Float> Module<T> for Downsample2D<T> {
564    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
565        if input.ndim() != 4 || input.shape()[1] != self.channels {
566            return Err(FerrotorchError::ShapeMismatch {
567                message: format!(
568                    "Downsample2D::forward: expected [B, {}, H, W], got {:?}",
569                    self.channels,
570                    input.shape()
571                ),
572            });
573        }
574        self.conv.forward(input)
575    }
576    fn parameters(&self) -> Vec<&Parameter<T>> {
577        self.conv.parameters()
578    }
579    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
580        self.conv.parameters_mut()
581    }
582    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
583        self.conv
584            .named_parameters()
585            .into_iter()
586            .map(|(n, p)| (format!("conv.{n}"), p))
587            .collect()
588    }
589    fn train(&mut self) {
590        self.training = true;
591    }
592    fn eval(&mut self) {
593        self.training = false;
594    }
595    fn is_training(&self) -> bool {
596        self.training
597    }
598    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
599        let conv_sd: StateDict<T> = state
600            .iter()
601            .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
602            .collect();
603        if strict {
604            for k in state.keys() {
605                if !k.starts_with("conv.") {
606                    return Err(FerrotorchError::InvalidArgument {
607                        message: format!("unexpected key in Downsample2D state_dict: \"{k}\""),
608                    });
609                }
610            }
611        }
612        self.conv.load_state_dict(&conv_sd, strict)
613    }
614}
615
616// ---------------------------------------------------------------------------
617// UpDecoderBlock2D
618// ---------------------------------------------------------------------------
619
620/// `UpDecoderBlock2D` — a stack of `layers_per_block + 1` resnets at
621/// `out_channels`, optionally followed by an `Upsample2D`.
622///
623/// State-dict key layout:
624///
625/// ```text
626/// resnets.{i}.<resnet keys>    i in 0..(layers_per_block + 1)
627/// upsamplers.0.<upsample keys>  (present iff add_upsample)
628/// ```
629#[derive(Debug)]
630pub struct UpDecoderBlock2D<T: Float> {
631    /// `layers_per_block + 1` resnets at `out_channels`.
632    pub resnets: Vec<ResnetBlock2D<T>>,
633    /// Optional upsample (absent on the last decoder block).
634    pub upsamplers_0: Option<Upsample2D<T>>,
635    training: bool,
636}
637
638impl<T: Float> UpDecoderBlock2D<T> {
639    /// Build a randomly-initialized `UpDecoderBlock2D`.
640    ///
641    /// # Errors
642    ///
643    /// Returns [`FerrotorchError::InvalidArgument`] on bad channel /
644    /// group config.
645    pub fn new(
646        in_channels: usize,
647        out_channels: usize,
648        num_resnets: usize,
649        norm_num_groups: usize,
650        eps: f64,
651        add_upsample: bool,
652    ) -> FerrotorchResult<Self> {
653        if num_resnets == 0 {
654            return Err(FerrotorchError::InvalidArgument {
655                message: "UpDecoderBlock2D: num_resnets must be > 0".into(),
656            });
657        }
658        let mut resnets = Vec::with_capacity(num_resnets);
659        for i in 0..num_resnets {
660            let in_c = if i == 0 { in_channels } else { out_channels };
661            resnets.push(ResnetBlock2D::<T>::new(
662                in_c,
663                out_channels,
664                norm_num_groups,
665                eps,
666            )?);
667        }
668        let upsamplers_0 = if add_upsample {
669            Some(Upsample2D::<T>::new(out_channels)?)
670        } else {
671            None
672        };
673        Ok(Self {
674            resnets,
675            upsamplers_0,
676            training: false,
677        })
678    }
679}
680
681impl<T: Float> Module<T> for UpDecoderBlock2D<T> {
682    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
683        let mut h = input.clone();
684        for r in &self.resnets {
685            h = r.forward(&h)?;
686        }
687        if let Some(u) = &self.upsamplers_0 {
688            h = u.forward(&h)?;
689        }
690        Ok(h)
691    }
692
693    fn parameters(&self) -> Vec<&Parameter<T>> {
694        let mut out = Vec::new();
695        for r in &self.resnets {
696            out.extend(r.parameters());
697        }
698        if let Some(u) = &self.upsamplers_0 {
699            out.extend(u.parameters());
700        }
701        out
702    }
703
704    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
705        let mut out = Vec::new();
706        for r in &mut self.resnets {
707            out.extend(r.parameters_mut());
708        }
709        if let Some(u) = self.upsamplers_0.as_mut() {
710            out.extend(u.parameters_mut());
711        }
712        out
713    }
714
715    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
716        let mut out = Vec::new();
717        for (i, r) in self.resnets.iter().enumerate() {
718            for (n, p) in r.named_parameters() {
719                out.push((format!("resnets.{i}.{n}"), p));
720            }
721        }
722        if let Some(u) = &self.upsamplers_0 {
723            for (n, p) in u.named_parameters() {
724                out.push((format!("upsamplers.0.{n}"), p));
725            }
726        }
727        out
728    }
729
730    fn train(&mut self) {
731        self.training = true;
732        for r in &mut self.resnets {
733            r.train();
734        }
735        if let Some(u) = self.upsamplers_0.as_mut() {
736            u.train();
737        }
738    }
739    fn eval(&mut self) {
740        self.training = false;
741        for r in &mut self.resnets {
742            r.eval();
743        }
744        if let Some(u) = self.upsamplers_0.as_mut() {
745            u.eval();
746        }
747    }
748    fn is_training(&self) -> bool {
749        self.training
750    }
751
752    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
753        let extract = |prefix: &str| -> StateDict<T> {
754            let p = format!("{prefix}.");
755            state
756                .iter()
757                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
758                .collect()
759        };
760        if strict {
761            for k in state.keys() {
762                let ok =
763                    k.starts_with("resnets.") || k.starts_with("upsamplers.0.");
764                if !ok {
765                    return Err(FerrotorchError::InvalidArgument {
766                        message: format!(
767                            "unexpected key in UpDecoderBlock2D state_dict: \"{k}\""
768                        ),
769                    });
770                }
771            }
772        }
773        for (i, r) in self.resnets.iter_mut().enumerate() {
774            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
775        }
776        if let Some(u) = self.upsamplers_0.as_mut() {
777            u.load_state_dict(&extract("upsamplers.0"), strict)?;
778        } else if strict {
779            // Verify no upsamplers keys leaked through if we have no upsample.
780            for k in state.keys() {
781                if k.starts_with("upsamplers.") {
782                    return Err(FerrotorchError::InvalidArgument {
783                        message: format!(
784                            "UpDecoderBlock2D has no upsampler but state_dict contains \"{k}\""
785                        ),
786                    });
787                }
788            }
789        }
790        Ok(())
791    }
792}
793
794// ---------------------------------------------------------------------------
795// DownEncoderBlock2D
796// ---------------------------------------------------------------------------
797
798/// `DownEncoderBlock2D` — a stack of `layers_per_block` resnets at
799/// `out_channels`, optionally followed by a `Downsample2D`.
800///
801/// This is the encoder-side mirror of [`UpDecoderBlock2D`]. The
802/// asymmetry in resnet count (encoder uses `layers_per_block` resnets,
803/// decoder uses `layers_per_block + 1`) matches the diffusers
804/// `AutoencoderKL` convention exactly.
805///
806/// State-dict key layout:
807///
808/// ```text
809/// resnets.{i}.<resnet keys>      i in 0..layers_per_block
810/// downsamplers.0.<downsample keys>   (present iff add_downsample)
811/// ```
812#[derive(Debug)]
813pub struct DownEncoderBlock2D<T: Float> {
814    /// `layers_per_block` resnets at `out_channels`. The first resnet
815    /// projects `in_channels -> out_channels` (via its 1x1 shortcut
816    /// when the channels differ); subsequent resnets are `out_channels`
817    /// throughout.
818    pub resnets: Vec<ResnetBlock2D<T>>,
819    /// Optional downsample (absent on the deepest encoder block, which
820    /// holds spatial resolution before the mid-block).
821    pub downsamplers_0: Option<Downsample2D<T>>,
822    training: bool,
823}
824
825impl<T: Float> DownEncoderBlock2D<T> {
826    /// Build a randomly-initialized `DownEncoderBlock2D`.
827    ///
828    /// # Errors
829    ///
830    /// Returns [`FerrotorchError::InvalidArgument`] when `num_resnets`
831    /// is zero, or any underlying [`FerrotorchError`] on bad channel /
832    /// group config.
833    pub fn new(
834        in_channels: usize,
835        out_channels: usize,
836        num_resnets: usize,
837        norm_num_groups: usize,
838        eps: f64,
839        add_downsample: bool,
840    ) -> FerrotorchResult<Self> {
841        if num_resnets == 0 {
842            return Err(FerrotorchError::InvalidArgument {
843                message: "DownEncoderBlock2D: num_resnets must be > 0".into(),
844            });
845        }
846        let mut resnets = Vec::with_capacity(num_resnets);
847        for i in 0..num_resnets {
848            let in_c = if i == 0 { in_channels } else { out_channels };
849            resnets.push(ResnetBlock2D::<T>::new(
850                in_c,
851                out_channels,
852                norm_num_groups,
853                eps,
854            )?);
855        }
856        let downsamplers_0 = if add_downsample {
857            Some(Downsample2D::<T>::new(out_channels)?)
858        } else {
859            None
860        };
861        Ok(Self {
862            resnets,
863            downsamplers_0,
864            training: false,
865        })
866    }
867}
868
869impl<T: Float> Module<T> for DownEncoderBlock2D<T> {
870    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
871        let mut h = input.clone();
872        for r in &self.resnets {
873            h = r.forward(&h)?;
874        }
875        if let Some(d) = &self.downsamplers_0 {
876            h = d.forward(&h)?;
877        }
878        Ok(h)
879    }
880
881    fn parameters(&self) -> Vec<&Parameter<T>> {
882        let mut out = Vec::new();
883        for r in &self.resnets {
884            out.extend(r.parameters());
885        }
886        if let Some(d) = &self.downsamplers_0 {
887            out.extend(d.parameters());
888        }
889        out
890    }
891
892    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
893        let mut out = Vec::new();
894        for r in &mut self.resnets {
895            out.extend(r.parameters_mut());
896        }
897        if let Some(d) = self.downsamplers_0.as_mut() {
898            out.extend(d.parameters_mut());
899        }
900        out
901    }
902
903    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
904        let mut out = Vec::new();
905        for (i, r) in self.resnets.iter().enumerate() {
906            for (n, p) in r.named_parameters() {
907                out.push((format!("resnets.{i}.{n}"), p));
908            }
909        }
910        if let Some(d) = &self.downsamplers_0 {
911            for (n, p) in d.named_parameters() {
912                out.push((format!("downsamplers.0.{n}"), p));
913            }
914        }
915        out
916    }
917
918    fn train(&mut self) {
919        self.training = true;
920        for r in &mut self.resnets {
921            r.train();
922        }
923        if let Some(d) = self.downsamplers_0.as_mut() {
924            d.train();
925        }
926    }
927    fn eval(&mut self) {
928        self.training = false;
929        for r in &mut self.resnets {
930            r.eval();
931        }
932        if let Some(d) = self.downsamplers_0.as_mut() {
933            d.eval();
934        }
935    }
936    fn is_training(&self) -> bool {
937        self.training
938    }
939
940    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
941        let extract = |prefix: &str| -> StateDict<T> {
942            let p = format!("{prefix}.");
943            state
944                .iter()
945                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
946                .collect()
947        };
948        if strict {
949            for k in state.keys() {
950                let ok =
951                    k.starts_with("resnets.") || k.starts_with("downsamplers.0.");
952                if !ok {
953                    return Err(FerrotorchError::InvalidArgument {
954                        message: format!(
955                            "unexpected key in DownEncoderBlock2D state_dict: \"{k}\""
956                        ),
957                    });
958                }
959            }
960        }
961        for (i, r) in self.resnets.iter_mut().enumerate() {
962            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
963        }
964        if let Some(d) = self.downsamplers_0.as_mut() {
965            d.load_state_dict(&extract("downsamplers.0"), strict)?;
966        } else if strict {
967            for k in state.keys() {
968                if k.starts_with("downsamplers.") {
969                    return Err(FerrotorchError::InvalidArgument {
970                        message: format!(
971                            "DownEncoderBlock2D has no downsampler but state_dict contains \"{k}\""
972                        ),
973                    });
974                }
975            }
976        }
977        Ok(())
978    }
979}
980
981// ---------------------------------------------------------------------------
982// UNetMidBlock2D (VAE flavour)
983// ---------------------------------------------------------------------------
984
985/// `UNetMidBlock2D` configured the way the SD VAE uses it:
986///
987/// ```text
988/// resnets[0] -> attentions[0] -> resnets[1]
989/// ```
990///
991/// 2 resnets, 1 attention; all at `in_channels = out_channels = 512` for
992/// SD 1.5. Diffusers always has `len(resnets) == num_layers + 1` and
993/// `len(attentions) == num_layers`; in the VAE `num_layers = 1`, so two
994/// resnets and one attention.
995#[derive(Debug)]
996pub struct UNetMidBlock2D<T: Float> {
997    /// Pre-attention + post-attention resnets (size 2 for SD VAE).
998    pub resnets: Vec<ResnetBlock2D<T>>,
999    /// Single attention block.
1000    pub attentions: Vec<AttnBlock2D<T>>,
1001    channels: usize,
1002    training: bool,
1003}
1004
1005impl<T: Float> UNetMidBlock2D<T> {
1006    /// Build a randomly-initialized VAE-flavoured `UNetMidBlock2D`.
1007    ///
1008    /// # Errors
1009    ///
1010    /// Returns the underlying [`FerrotorchError`] on bad config.
1011    pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
1012        // SD VAE: num_layers=1 => 2 resnets + 1 attention.
1013        let resnets = vec![
1014            ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
1015            ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
1016        ];
1017        let attentions = vec![AttnBlock2D::<T>::new(channels, norm_num_groups, eps)?];
1018        Ok(Self {
1019            resnets,
1020            attentions,
1021            channels,
1022            training: false,
1023        })
1024    }
1025}
1026
1027impl<T: Float> Module<T> for UNetMidBlock2D<T> {
1028    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1029        // diffusers UNetMidBlock2D.forward (no temb):
1030        //   h = resnets[0](x)
1031        //   for (attn, resnet) in zip(attentions, resnets[1:]):
1032        //       h = attn(h); h = resnet(h)
1033        let mut h = self.resnets[0].forward(input)?;
1034        for (attn, resnet) in self.attentions.iter().zip(self.resnets.iter().skip(1)) {
1035            h = attn.forward(&h)?;
1036            h = resnet.forward(&h)?;
1037        }
1038        Ok(h)
1039    }
1040
1041    fn parameters(&self) -> Vec<&Parameter<T>> {
1042        let mut out = Vec::new();
1043        // The HF diffusers state-dict key order is `attentions` before
1044        // `resnets`. We match the ordering used for state-dict
1045        // round-trip — see `named_parameters()` for the canonical
1046        // sequence.
1047        for a in &self.attentions {
1048            out.extend(a.parameters());
1049        }
1050        for r in &self.resnets {
1051            out.extend(r.parameters());
1052        }
1053        out
1054    }
1055
1056    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1057        let mut out = Vec::new();
1058        for a in &mut self.attentions {
1059            out.extend(a.parameters_mut());
1060        }
1061        for r in &mut self.resnets {
1062            out.extend(r.parameters_mut());
1063        }
1064        out
1065    }
1066
1067    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1068        let mut out = Vec::new();
1069        // Match the order used by `safetensors` listings of the HF VAE:
1070        // `attentions.*` first, then `resnets.*`.
1071        for (i, a) in self.attentions.iter().enumerate() {
1072            for (n, p) in a.named_parameters() {
1073                out.push((format!("attentions.{i}.{n}"), p));
1074            }
1075        }
1076        for (i, r) in self.resnets.iter().enumerate() {
1077            for (n, p) in r.named_parameters() {
1078                out.push((format!("resnets.{i}.{n}"), p));
1079            }
1080        }
1081        out
1082    }
1083
1084    fn train(&mut self) {
1085        self.training = true;
1086        for r in &mut self.resnets {
1087            r.train();
1088        }
1089        for a in &mut self.attentions {
1090            a.train();
1091        }
1092    }
1093    fn eval(&mut self) {
1094        self.training = false;
1095        for r in &mut self.resnets {
1096            r.eval();
1097        }
1098        for a in &mut self.attentions {
1099            a.eval();
1100        }
1101    }
1102    fn is_training(&self) -> bool {
1103        self.training
1104    }
1105
1106    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1107        let extract = |prefix: &str| -> StateDict<T> {
1108            let p = format!("{prefix}.");
1109            state
1110                .iter()
1111                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1112                .collect()
1113        };
1114        if strict {
1115            for k in state.keys() {
1116                let ok = k.starts_with("attentions.") || k.starts_with("resnets.");
1117                if !ok {
1118                    return Err(FerrotorchError::InvalidArgument {
1119                        message: format!(
1120                            "unexpected key in UNetMidBlock2D state_dict: \"{k}\""
1121                        ),
1122                    });
1123                }
1124            }
1125        }
1126        for (i, a) in self.attentions.iter_mut().enumerate() {
1127            a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
1128        }
1129        for (i, r) in self.resnets.iter_mut().enumerate() {
1130            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
1131        }
1132        let _ = self.channels; // silence dead_code warning under some lint configs
1133        let _: HashMap<String, Tensor<T>> = HashMap::new(); // pull HashMap into scope
1134        Ok(())
1135    }
1136}
1137
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// CPU `f32` tensor of the given shape with every element set to `fill`.
    fn filled(fill: f32, shape: &[usize]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(TensorStorage::cpu(vec![fill; numel]), shape.to_vec(), false).unwrap()
    }

    /// Names (without the parameters) reported by `named_parameters()`.
    fn param_names<M: Module<f32>>(module: &M) -> Vec<String> {
        module
            .named_parameters()
            .into_iter()
            .map(|(n, _)| n)
            .collect()
    }

    /// Assert that every key in `expected` appears in `names`.
    fn assert_has_keys(names: &[String], expected: &[&str]) {
        for &k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    // ----- ResnetBlock2D ---------------------------------------------------

    #[test]
    fn resnet_same_channels_no_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 32, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_none());
        let y = block.forward(&filled(0.01, &[1, 32, 4, 4])).unwrap();
        assert_eq!(y.shape(), &[1, 32, 4, 4]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn resnet_different_channels_has_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_some());
        let y = block.forward(&filled(0.01, &[1, 32, 4, 4])).unwrap();
        assert_eq!(y.shape(), &[1, 64, 4, 4]);
    }

    #[test]
    fn resnet_named_parameters_layout() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert_has_keys(
            &param_names(&block),
            &[
                "norm1.weight",
                "norm1.bias",
                "conv1.weight",
                "conv1.bias",
                "norm2.weight",
                "norm2.bias",
                "conv2.weight",
                "conv2.bias",
                "conv_shortcut.weight",
                "conv_shortcut.bias",
            ],
        );
    }

    // ----- AttnBlock2D -----------------------------------------------------

    #[test]
    fn attn_shape_and_residual() {
        // 32 channels, 4 groups => 8 channels per group; small spatial.
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        let y = attn.forward(&filled(0.01, &[1, 32, 4, 4])).unwrap();
        assert_eq!(y.shape(), &[1, 32, 4, 4]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite(), "attn output non-finite: {v}");
        }
    }

    #[test]
    fn attn_named_parameters_layout() {
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        assert_has_keys(
            &param_names(&attn),
            &[
                "group_norm.weight",
                "group_norm.bias",
                "to_q.weight",
                "to_q.bias",
                "to_k.weight",
                "to_k.bias",
                "to_v.weight",
                "to_v.bias",
                "to_out.0.weight",
                "to_out.0.bias",
            ],
        );
    }

    // ----- Upsample2D / UpDecoderBlock2D -----------------------------------

    #[test]
    fn upsample2d_doubles_spatial() {
        let up = Upsample2D::<f32>::new(8).unwrap();
        let y = up.forward(&filled(0.05, &[1, 8, 3, 3])).unwrap();
        assert_eq!(y.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_with_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(16, 8, 2, 4, 1e-6, true).unwrap();
        let y = block.forward(&filled(0.05, &[1, 16, 3, 3])).unwrap();
        assert_eq!(y.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_no_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(8, 8, 2, 4, 1e-6, false).unwrap();
        let y = block.forward(&filled(0.05, &[1, 8, 3, 3])).unwrap();
        assert_eq!(y.shape(), &[1, 8, 3, 3]);
    }

    // ----- DownEncoderBlock2D ----------------------------------------------

    #[test]
    fn down_encoder_block_shape_with_downsample() {
        // [1, 8, 4, 4] -> resnets project 8 -> 16 -> downsample halves spatial
        let block = DownEncoderBlock2D::<f32>::new(8, 16, 2, 4, 1e-6, true).unwrap();
        let y = block.forward(&filled(0.05, &[1, 8, 4, 4])).unwrap();
        assert_eq!(y.shape(), &[1, 16, 2, 2]);
    }

    #[test]
    fn down_encoder_block_shape_no_downsample() {
        // Last encoder block: same channels, no downsample (spatial preserved)
        let block = DownEncoderBlock2D::<f32>::new(16, 16, 2, 4, 1e-6, false).unwrap();
        let y = block.forward(&filled(0.05, &[1, 16, 2, 2])).unwrap();
        assert_eq!(y.shape(), &[1, 16, 2, 2]);
    }

    #[test]
    fn down_encoder_block_named_parameters_layout() {
        let block = DownEncoderBlock2D::<f32>::new(8, 16, 2, 4, 1e-6, true).unwrap();
        assert_has_keys(
            &param_names(&block),
            &[
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.0.conv_shortcut.weight",
                "resnets.1.norm1.weight",
                "downsamplers.0.conv.weight",
                "downsamplers.0.conv.bias",
            ],
        );
    }

    // ----- UNetMidBlock2D --------------------------------------------------

    #[test]
    fn mid_block_shape() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        let y = mid.forward(&filled(0.05, &[1, 16, 3, 3])).unwrap();
        assert_eq!(y.shape(), &[1, 16, 3, 3]);
    }

    #[test]
    fn mid_block_named_parameters_layout() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        assert_has_keys(
            &param_names(&mid),
            &[
                "attentions.0.group_norm.weight",
                "attentions.0.to_q.weight",
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.1.conv2.weight",
            ],
        );
    }
}