Skip to main content

ferrotorch_diffusion/
blocks.rs

//! Building blocks of the Stable-Diffusion VAE (decoder and encoder sides).
//!
//! All blocks here mirror the diffusers `models/{resnet,upsampling,
//! attention_processor}.py` / `models/unets/unet_2d_blocks.py` reference
//! layout 1:1 in parameter naming and forward semantics, so the upstream
//! state dict (`runwayml/stable-diffusion-v1-5/vae/diffusion_pytorch_model.safetensors`)
//! loads byte-for-byte.
8
9use std::collections::HashMap;
10
11use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
12use ferrotorch_nn::module::{Module, StateDict};
13use ferrotorch_nn::parameter::Parameter;
14use ferrotorch_nn::{Conv2d, GroupNorm, InterpolateMode, Linear, SiLU, Upsample};
15
16// ---------------------------------------------------------------------------
17// ResnetBlock2D
18// ---------------------------------------------------------------------------
19
/// `ResnetBlock2D` — the building block of every UNet/VAE up/down/mid stack.
///
/// The VAE flavour has no time embedding (`temb_channels=None` in
/// diffusers), so this implementation does not carry a
/// `time_emb_proj`. The forward pass is:
///
/// ```text
/// h = norm1(x); h = silu(h); h = conv1(h)
/// h = norm2(h); h = silu(h); h = conv2(h)
/// out = h + (x if in==out else conv_shortcut(x))
/// ```
///
/// The `output_scale_factor` from the diffusers reference is always 1.0
/// in the SD VAE, so it is hard-coded; if a future config requires
/// rescaling, add an explicit field.
#[derive(Debug)]
pub struct ResnetBlock2D<T: Float> {
    /// First GroupNorm (over `in_channels`).
    pub norm1: GroupNorm<T>,
    /// First Conv2d (`in_channels -> out_channels`, k=3, pad=1).
    pub conv1: Conv2d<T>,
    /// Second GroupNorm (over `out_channels`).
    pub norm2: GroupNorm<T>,
    /// Second Conv2d (`out_channels -> out_channels`, k=3, pad=1).
    pub conv2: Conv2d<T>,
    /// Optional 1x1 shortcut conv (present iff `in_channels !=
    /// out_channels`).
    pub conv_shortcut: Option<Conv2d<T>>,
    /// SiLU applied after each GroupNorm; parameter-free, shared by both stages.
    activation: SiLU,
    /// Expected size of the input's channel axis; validated in `forward`.
    in_channels: usize,
    // Not read by any method visible alongside this struct (it only shows up
    // in the derived `Debug` output, which the dead-code lint ignores), so the
    // lint is silenced explicitly rather than dropping the field.
    #[allow(dead_code)]
    out_channels: usize,
    /// Mode flag reported by `is_training`; the forward pass does not consult it.
    training: bool,
}
54
55impl<T: Float> ResnetBlock2D<T> {
56    /// Build a randomly-initialized `ResnetBlock2D`.
57    ///
58    /// # Errors
59    ///
60    /// Returns [`FerrotorchError::InvalidArgument`] on bad channel /
61    /// group config.
62    pub fn new(
63        in_channels: usize,
64        out_channels: usize,
65        norm_num_groups: usize,
66        eps: f64,
67    ) -> FerrotorchResult<Self> {
68        let norm1 = GroupNorm::<T>::new(norm_num_groups, in_channels, eps, true)?;
69        let conv1 = Conv2d::<T>::new(in_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
70        let norm2 = GroupNorm::<T>::new(norm_num_groups, out_channels, eps, true)?;
71        let conv2 = Conv2d::<T>::new(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true)?;
72        let conv_shortcut = if in_channels == out_channels {
73            None
74        } else {
75            Some(Conv2d::<T>::new(
76                in_channels,
77                out_channels,
78                (1, 1),
79                (1, 1),
80                (0, 0),
81                true,
82            )?)
83        };
84        Ok(Self {
85            norm1,
86            conv1,
87            norm2,
88            conv2,
89            conv_shortcut,
90            activation: SiLU::new(),
91            in_channels,
92            out_channels,
93            training: false,
94        })
95    }
96}
97
98impl<T: Float> Module<T> for ResnetBlock2D<T> {
99    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
100        if input.ndim() != 4 || input.shape()[1] != self.in_channels {
101            return Err(FerrotorchError::ShapeMismatch {
102                message: format!(
103                    "ResnetBlock2D::forward: expected [B, {}, H, W], got {:?}",
104                    self.in_channels,
105                    input.shape()
106                ),
107            });
108        }
109        // h = norm1(x); silu; conv1.
110        let mut h = self.norm1.forward(input)?;
111        h = self.activation.forward(&h)?;
112        h = self.conv1.forward(&h)?;
113        // h = norm2(h); silu; conv2 (no dropout at eval / inference time).
114        h = self.norm2.forward(&h)?;
115        h = self.activation.forward(&h)?;
116        h = self.conv2.forward(&h)?;
117        // Residual (optionally projected via 1x1 conv).
118        let res = if let Some(sc) = &self.conv_shortcut {
119            sc.forward(input)?
120        } else {
121            input.clone()
122        };
123        // out = h + res. `output_scale_factor=1.0` in the SD VAE so no
124        // division on the way out.
125        ferrotorch_core::grad_fns::arithmetic::add(&h, &res)
126    }
127
128    fn parameters(&self) -> Vec<&Parameter<T>> {
129        let mut out = Vec::new();
130        out.extend(self.norm1.parameters());
131        out.extend(self.conv1.parameters());
132        out.extend(self.norm2.parameters());
133        out.extend(self.conv2.parameters());
134        if let Some(sc) = &self.conv_shortcut {
135            out.extend(sc.parameters());
136        }
137        out
138    }
139
140    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
141        let mut out = Vec::new();
142        out.extend(self.norm1.parameters_mut());
143        out.extend(self.conv1.parameters_mut());
144        out.extend(self.norm2.parameters_mut());
145        out.extend(self.conv2.parameters_mut());
146        if let Some(sc) = &mut self.conv_shortcut {
147            out.extend(sc.parameters_mut());
148        }
149        out
150    }
151
152    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
153        let mut out = Vec::new();
154        for (n, p) in self.norm1.named_parameters() {
155            out.push((format!("norm1.{n}"), p));
156        }
157        for (n, p) in self.conv1.named_parameters() {
158            out.push((format!("conv1.{n}"), p));
159        }
160        for (n, p) in self.norm2.named_parameters() {
161            out.push((format!("norm2.{n}"), p));
162        }
163        for (n, p) in self.conv2.named_parameters() {
164            out.push((format!("conv2.{n}"), p));
165        }
166        if let Some(sc) = &self.conv_shortcut {
167            for (n, p) in sc.named_parameters() {
168                out.push((format!("conv_shortcut.{n}"), p));
169            }
170        }
171        out
172    }
173
174    fn train(&mut self) {
175        self.training = true;
176    }
177    fn eval(&mut self) {
178        self.training = false;
179    }
180    fn is_training(&self) -> bool {
181        self.training
182    }
183
184    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
185        let extract = |prefix: &str| -> StateDict<T> {
186            let p = format!("{prefix}.");
187            state
188                .iter()
189                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
190                .collect()
191        };
192        if strict {
193            for k in state.keys() {
194                let ok = k.starts_with("norm1.")
195                    || k.starts_with("conv1.")
196                    || k.starts_with("norm2.")
197                    || k.starts_with("conv2.")
198                    || k.starts_with("conv_shortcut.");
199                if !ok {
200                    return Err(FerrotorchError::InvalidArgument {
201                        message: format!("unexpected key in ResnetBlock2D state_dict: \"{k}\""),
202                    });
203                }
204            }
205        }
206        self.norm1.load_state_dict(&extract("norm1"), strict)?;
207        self.conv1.load_state_dict(&extract("conv1"), strict)?;
208        self.norm2.load_state_dict(&extract("norm2"), strict)?;
209        self.conv2.load_state_dict(&extract("conv2"), strict)?;
210        if let Some(sc) = self.conv_shortcut.as_mut() {
211            sc.load_state_dict(&extract("conv_shortcut"), strict)?;
212        }
213        Ok(())
214    }
215}
216
217// ---------------------------------------------------------------------------
218// AttnBlock2D
219// ---------------------------------------------------------------------------
220
/// Single-head spatial self-attention with residual + GroupNorm — the
/// VAE mid-block attention.
///
/// Matches `diffusers.models.attention_processor.Attention` configured
/// with:
///
///   - `heads = in_channels / attention_head_dim = 512 / 512 = 1`
///   - `norm_num_groups = 32` (so `group_norm` is enabled)
///   - `residual_connection = True`
///   - `bias = True`, `out_bias = True`
///   - `eps = 1e-6`
///
/// State-dict layout (HF diffusers):
///
/// ```text
/// group_norm.{weight,bias}    [C], [C]
/// to_q.{weight,bias}          [C, C], [C]
/// to_k.{weight,bias}          [C, C], [C]
/// to_v.{weight,bias}          [C, C], [C]
/// to_out.0.{weight,bias}      [C, C], [C]      // to_out[1] is Dropout (no params)
/// ```
#[derive(Debug)]
pub struct AttnBlock2D<T: Float> {
    /// GroupNorm over the channel axis.
    pub group_norm: GroupNorm<T>,
    /// Query projection `Linear(C -> C, bias)`.
    pub to_q: Linear<T>,
    /// Key projection `Linear(C -> C, bias)`.
    pub to_k: Linear<T>,
    /// Value projection `Linear(C -> C, bias)`.
    pub to_v: Linear<T>,
    /// Output projection `to_out[0] = Linear(C -> C, bias)`.
    pub to_out_0: Linear<T>,
    /// Channel count `C`; `forward` rejects inputs whose axis 1 differs.
    channels: usize,
    /// Mode flag reported by `is_training`; the forward pass does not consult it.
    training: bool,
}
257
258impl<T: Float> AttnBlock2D<T> {
259    /// Build a randomly-initialized `AttnBlock2D`.
260    ///
261    /// # Errors
262    ///
263    /// Returns [`FerrotorchError::InvalidArgument`] when `channels`
264    /// is not divisible by `norm_num_groups` (the GroupNorm constructor
265    /// surfaces this).
266    pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
267        let group_norm = GroupNorm::<T>::new(norm_num_groups, channels, eps, true)?;
268        let to_q = Linear::<T>::new(channels, channels, true)?;
269        let to_k = Linear::<T>::new(channels, channels, true)?;
270        let to_v = Linear::<T>::new(channels, channels, true)?;
271        let to_out_0 = Linear::<T>::new(channels, channels, true)?;
272        Ok(Self {
273            group_norm,
274            to_q,
275            to_k,
276            to_v,
277            to_out_0,
278            channels,
279            training: false,
280        })
281    }
282}
283
284impl<T: Float> Module<T> for AttnBlock2D<T> {
285    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
286        if input.ndim() != 4 || input.shape()[1] != self.channels {
287            return Err(FerrotorchError::ShapeMismatch {
288                message: format!(
289                    "AttnBlock2D::forward: expected [B, {}, H, W], got {:?}",
290                    self.channels,
291                    input.shape()
292                ),
293            });
294        }
295        let b = input.shape()[0];
296        let c = input.shape()[1];
297        let h = input.shape()[2];
298        let w = input.shape()[3];
299        let hw = h * w;
300
301        // Residual (kept in the spatial layout).
302        let residual = input.clone();
303
304        // -- 1. Reshape to [B, HW, C]:
305        //       diffusers does `view(B, C, HW).transpose(1, 2)`.
306        //       Then it group-norms on the channel axis by re-transposing
307        //       back to [B, C, HW], norming, and re-transposing.
308        let hidden = input
309            .reshape_t(&[b as isize, c as isize, hw as isize])?
310            .transpose(1, 2)?
311            .contiguous()?;
312        // GroupNorm: input [B, C, HW] -> [B, HW, C] after the final transpose.
313        let normed_hwc = self
314            .group_norm
315            .forward(&hidden.transpose(1, 2)?.contiguous()?)?
316            .transpose(1, 2)?
317            .contiguous()?;
318
319        // -- 2. q / k / v projections (Linear over the trailing C dim).
320        let q = self.to_q.forward(&normed_hwc)?; // [B, HW, C]
321        let k = self.to_k.forward(&normed_hwc)?; // [B, HW, C]
322        let v = self.to_v.forward(&normed_hwc)?; // [B, HW, C]
323
324        // -- 3. Single-head attention:
325        //          scores = q @ k^T * (1/sqrt(C))
326        //          probs  = softmax(scores, dim=-1)
327        //          out    = probs @ v
328        //       For single-head the head-split is a no-op, so this is
329        //       just a per-batch BMM. `bmm` handles the [B, HW, HW]
330        //       intermediate without materialising a head axis.
331        let scale = (c as f64).sqrt().recip();
332        let scale_t = T::from(scale).ok_or_else(|| FerrotorchError::InvalidArgument {
333            message: "AttnBlock2D::forward: failed to cast attention scale into Float".into(),
334        })?;
335        let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
336        let k_t = k.transpose(1, 2)?.contiguous()?; // [B, C, HW]
337        let scores = q.bmm(&k_t)?; // [B, HW, HW]
338        let scores_scaled = ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;
339        let probs = scores_scaled.softmax()?; // dim = -1 by default
340        let attended = probs.bmm(&v)?; // [B, HW, C]
341
342        // -- 4. Output projection.
343        let projected = self.to_out_0.forward(&attended)?; // [B, HW, C]
344
345        // -- 5. Back to [B, C, H, W] and add residual.
346        let back = projected
347            .transpose(1, 2)? // [B, C, HW]
348            .reshape_t(&[b as isize, c as isize, h as isize, w as isize])?
349            .contiguous()?;
350        ferrotorch_core::grad_fns::arithmetic::add(&back, &residual)
351    }
352
353    fn parameters(&self) -> Vec<&Parameter<T>> {
354        let mut out = Vec::new();
355        out.extend(self.group_norm.parameters());
356        out.extend(self.to_q.parameters());
357        out.extend(self.to_k.parameters());
358        out.extend(self.to_v.parameters());
359        out.extend(self.to_out_0.parameters());
360        out
361    }
362
363    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
364        let mut out = Vec::new();
365        out.extend(self.group_norm.parameters_mut());
366        out.extend(self.to_q.parameters_mut());
367        out.extend(self.to_k.parameters_mut());
368        out.extend(self.to_v.parameters_mut());
369        out.extend(self.to_out_0.parameters_mut());
370        out
371    }
372
373    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
374        let mut out = Vec::new();
375        for (n, p) in self.group_norm.named_parameters() {
376            out.push((format!("group_norm.{n}"), p));
377        }
378        for (n, p) in self.to_q.named_parameters() {
379            out.push((format!("to_q.{n}"), p));
380        }
381        for (n, p) in self.to_k.named_parameters() {
382            out.push((format!("to_k.{n}"), p));
383        }
384        for (n, p) in self.to_v.named_parameters() {
385            out.push((format!("to_v.{n}"), p));
386        }
387        for (n, p) in self.to_out_0.named_parameters() {
388            out.push((format!("to_out.0.{n}"), p));
389        }
390        out
391    }
392
393    fn train(&mut self) {
394        self.training = true;
395    }
396    fn eval(&mut self) {
397        self.training = false;
398    }
399    fn is_training(&self) -> bool {
400        self.training
401    }
402
403    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
404        let extract = |prefix: &str| -> StateDict<T> {
405            let p = format!("{prefix}.");
406            state
407                .iter()
408                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
409                .collect()
410        };
411        if strict {
412            for k in state.keys() {
413                let ok = k.starts_with("group_norm.")
414                    || k.starts_with("to_q.")
415                    || k.starts_with("to_k.")
416                    || k.starts_with("to_v.")
417                    || k.starts_with("to_out.0.");
418                if !ok {
419                    return Err(FerrotorchError::InvalidArgument {
420                        message: format!("unexpected key in AttnBlock2D state_dict: \"{k}\""),
421                    });
422                }
423            }
424        }
425        self.group_norm
426            .load_state_dict(&extract("group_norm"), strict)?;
427        self.to_q.load_state_dict(&extract("to_q"), strict)?;
428        self.to_k.load_state_dict(&extract("to_k"), strict)?;
429        self.to_v.load_state_dict(&extract("to_v"), strict)?;
430        self.to_out_0
431            .load_state_dict(&extract("to_out.0"), strict)?;
432        Ok(())
433    }
434}
435
436// ---------------------------------------------------------------------------
437// Upsample2D
438// ---------------------------------------------------------------------------
439
/// Diffusers-style `Upsample2D` — nearest-neighbor 2x interpolation
/// followed by a `Conv2d(C, C, k=3, pad=1, bias=True)`.
///
/// State-dict key: `conv.{weight,bias}` (matching `name="conv"` in the
/// diffusers reference).
#[derive(Debug)]
pub struct Upsample2D<T: Float> {
    /// Output conv (also serves as the "post-upsample smoothing" filter).
    pub conv: Conv2d<T>,
    /// Channel count `C`; `forward` rejects inputs whose axis 1 differs.
    channels: usize,
    /// Mode flag reported by `is_training`; the forward pass does not consult it.
    training: bool,
}
452
453impl<T: Float> Upsample2D<T> {
454    /// Build a randomly-initialized `Upsample2D` (`C -> C`, k=3, pad=1).
455    ///
456    /// # Errors
457    ///
458    /// Returns the underlying [`FerrotorchError`] on bad conv config.
459    pub fn new(channels: usize) -> FerrotorchResult<Self> {
460        let conv = Conv2d::<T>::new(channels, channels, (3, 3), (1, 1), (1, 1), true)?;
461        Ok(Self {
462            conv,
463            channels,
464            training: false,
465        })
466    }
467}
468
469impl<T: Float> Module<T> for Upsample2D<T> {
470    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
471        if input.ndim() != 4 || input.shape()[1] != self.channels {
472            return Err(FerrotorchError::ShapeMismatch {
473                message: format!(
474                    "Upsample2D::forward: expected [B, {}, H, W], got {:?}",
475                    self.channels,
476                    input.shape()
477                ),
478            });
479        }
480        let h = input.shape()[2];
481        let w = input.shape()[3];
482        let up = Upsample::new([h * 2, w * 2], InterpolateMode::Nearest);
483        let upsampled = up.forward(input)?;
484        self.conv.forward(&upsampled)
485    }
486
487    fn parameters(&self) -> Vec<&Parameter<T>> {
488        self.conv.parameters()
489    }
490    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
491        self.conv.parameters_mut()
492    }
493    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
494        self.conv
495            .named_parameters()
496            .into_iter()
497            .map(|(n, p)| (format!("conv.{n}"), p))
498            .collect()
499    }
500    fn train(&mut self) {
501        self.training = true;
502    }
503    fn eval(&mut self) {
504        self.training = false;
505    }
506    fn is_training(&self) -> bool {
507        self.training
508    }
509    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
510        let conv_sd: StateDict<T> = state
511            .iter()
512            .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
513            .collect();
514        if strict {
515            for k in state.keys() {
516                if !k.starts_with("conv.") {
517                    return Err(FerrotorchError::InvalidArgument {
518                        message: format!("unexpected key in Upsample2D state_dict: \"{k}\""),
519                    });
520                }
521            }
522        }
523        self.conv.load_state_dict(&conv_sd, strict)
524    }
525}
526
527// ---------------------------------------------------------------------------
528// Downsample2D
529// ---------------------------------------------------------------------------
530
/// Diffusers-style `Downsample2D` — a single `Conv2d(C, C, k=3, stride=2,
/// pad=1, bias=True)`.
///
/// This implementation uses `use_conv=True` with `padding=1` on the conv
/// itself, so there is no separate pre-conv padding step. State-dict key:
/// `conv.{weight,bias}`.
///
/// NOTE(review): diffusers' `AutoencoderKL` encoder appears to build its
/// down blocks with `downsample_padding=0`. In that case the reference
/// `Downsample2D` zero-pads asymmetrically (`F.pad(x, (0, 1, 0, 1))`) and
/// runs the conv with `padding=0`. Symmetric padding of 1 gives the same
/// output shape for even H/W but shifts the sampling window by one pixel.
/// Confirm against the reference before relying on encoder outputs matching
/// upstream numerically.
#[derive(Debug)]
pub struct Downsample2D<T: Float> {
    /// Output conv (k=3, stride=2, pad=1).
    pub conv: Conv2d<T>,
    /// Channel count `C`; `forward` rejects inputs whose axis 1 differs.
    channels: usize,
    /// Mode flag reported by `is_training`; the forward pass does not consult it.
    training: bool,
}
543
544impl<T: Float> Downsample2D<T> {
545    /// Build a randomly-initialized `Downsample2D` (`C -> C`).
546    ///
547    /// # Errors
548    ///
549    /// Returns the underlying [`FerrotorchError`] on bad conv config.
550    pub fn new(channels: usize) -> FerrotorchResult<Self> {
551        let conv = Conv2d::<T>::new(channels, channels, (3, 3), (2, 2), (1, 1), true)?;
552        Ok(Self {
553            conv,
554            channels,
555            training: false,
556        })
557    }
558}
559
560impl<T: Float> Module<T> for Downsample2D<T> {
561    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
562        if input.ndim() != 4 || input.shape()[1] != self.channels {
563            return Err(FerrotorchError::ShapeMismatch {
564                message: format!(
565                    "Downsample2D::forward: expected [B, {}, H, W], got {:?}",
566                    self.channels,
567                    input.shape()
568                ),
569            });
570        }
571        self.conv.forward(input)
572    }
573    fn parameters(&self) -> Vec<&Parameter<T>> {
574        self.conv.parameters()
575    }
576    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
577        self.conv.parameters_mut()
578    }
579    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
580        self.conv
581            .named_parameters()
582            .into_iter()
583            .map(|(n, p)| (format!("conv.{n}"), p))
584            .collect()
585    }
586    fn train(&mut self) {
587        self.training = true;
588    }
589    fn eval(&mut self) {
590        self.training = false;
591    }
592    fn is_training(&self) -> bool {
593        self.training
594    }
595    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
596        let conv_sd: StateDict<T> = state
597            .iter()
598            .filter_map(|(k, v)| k.strip_prefix("conv.").map(|r| (r.to_string(), v.clone())))
599            .collect();
600        if strict {
601            for k in state.keys() {
602                if !k.starts_with("conv.") {
603                    return Err(FerrotorchError::InvalidArgument {
604                        message: format!("unexpected key in Downsample2D state_dict: \"{k}\""),
605                    });
606                }
607            }
608        }
609        self.conv.load_state_dict(&conv_sd, strict)
610    }
611}
612
613// ---------------------------------------------------------------------------
614// UpDecoderBlock2D
615// ---------------------------------------------------------------------------
616
/// `UpDecoderBlock2D` — a stack of `layers_per_block + 1` resnets at
/// `out_channels`, optionally followed by an `Upsample2D`.
///
/// State-dict key layout:
///
/// ```text
/// resnets.{i}.<resnet keys>    i in 0..(layers_per_block + 1)
/// upsamplers.0.<upsample keys>  (present iff add_upsample)
/// ```
#[derive(Debug)]
pub struct UpDecoderBlock2D<T: Float> {
    /// `layers_per_block + 1` resnets at `out_channels`; the first one maps
    /// the block's `in_channels` to `out_channels`.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Optional upsample (absent on the last decoder block).
    pub upsamplers_0: Option<Upsample2D<T>>,
    /// Mode flag reported by `is_training`; `train`/`eval` also forward it to
    /// the child blocks.
    training: bool,
}
634
635impl<T: Float> UpDecoderBlock2D<T> {
636    /// Build a randomly-initialized `UpDecoderBlock2D`.
637    ///
638    /// # Errors
639    ///
640    /// Returns [`FerrotorchError::InvalidArgument`] on bad channel /
641    /// group config.
642    pub fn new(
643        in_channels: usize,
644        out_channels: usize,
645        num_resnets: usize,
646        norm_num_groups: usize,
647        eps: f64,
648        add_upsample: bool,
649    ) -> FerrotorchResult<Self> {
650        if num_resnets == 0 {
651            return Err(FerrotorchError::InvalidArgument {
652                message: "UpDecoderBlock2D: num_resnets must be > 0".into(),
653            });
654        }
655        let mut resnets = Vec::with_capacity(num_resnets);
656        for i in 0..num_resnets {
657            let in_c = if i == 0 { in_channels } else { out_channels };
658            resnets.push(ResnetBlock2D::<T>::new(
659                in_c,
660                out_channels,
661                norm_num_groups,
662                eps,
663            )?);
664        }
665        let upsamplers_0 = if add_upsample {
666            Some(Upsample2D::<T>::new(out_channels)?)
667        } else {
668            None
669        };
670        Ok(Self {
671            resnets,
672            upsamplers_0,
673            training: false,
674        })
675    }
676}
677
678impl<T: Float> Module<T> for UpDecoderBlock2D<T> {
679    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
680        let mut h = input.clone();
681        for r in &self.resnets {
682            h = r.forward(&h)?;
683        }
684        if let Some(u) = &self.upsamplers_0 {
685            h = u.forward(&h)?;
686        }
687        Ok(h)
688    }
689
690    fn parameters(&self) -> Vec<&Parameter<T>> {
691        let mut out = Vec::new();
692        for r in &self.resnets {
693            out.extend(r.parameters());
694        }
695        if let Some(u) = &self.upsamplers_0 {
696            out.extend(u.parameters());
697        }
698        out
699    }
700
701    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
702        let mut out = Vec::new();
703        for r in &mut self.resnets {
704            out.extend(r.parameters_mut());
705        }
706        if let Some(u) = self.upsamplers_0.as_mut() {
707            out.extend(u.parameters_mut());
708        }
709        out
710    }
711
712    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
713        let mut out = Vec::new();
714        for (i, r) in self.resnets.iter().enumerate() {
715            for (n, p) in r.named_parameters() {
716                out.push((format!("resnets.{i}.{n}"), p));
717            }
718        }
719        if let Some(u) = &self.upsamplers_0 {
720            for (n, p) in u.named_parameters() {
721                out.push((format!("upsamplers.0.{n}"), p));
722            }
723        }
724        out
725    }
726
727    fn train(&mut self) {
728        self.training = true;
729        for r in &mut self.resnets {
730            r.train();
731        }
732        if let Some(u) = self.upsamplers_0.as_mut() {
733            u.train();
734        }
735    }
736    fn eval(&mut self) {
737        self.training = false;
738        for r in &mut self.resnets {
739            r.eval();
740        }
741        if let Some(u) = self.upsamplers_0.as_mut() {
742            u.eval();
743        }
744    }
745    fn is_training(&self) -> bool {
746        self.training
747    }
748
749    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
750        let extract = |prefix: &str| -> StateDict<T> {
751            let p = format!("{prefix}.");
752            state
753                .iter()
754                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
755                .collect()
756        };
757        if strict {
758            for k in state.keys() {
759                let ok = k.starts_with("resnets.") || k.starts_with("upsamplers.0.");
760                if !ok {
761                    return Err(FerrotorchError::InvalidArgument {
762                        message: format!("unexpected key in UpDecoderBlock2D state_dict: \"{k}\""),
763                    });
764                }
765            }
766        }
767        for (i, r) in self.resnets.iter_mut().enumerate() {
768            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
769        }
770        if let Some(u) = self.upsamplers_0.as_mut() {
771            u.load_state_dict(&extract("upsamplers.0"), strict)?;
772        } else if strict {
773            // Verify no upsamplers keys leaked through if we have no upsample.
774            for k in state.keys() {
775                if k.starts_with("upsamplers.") {
776                    return Err(FerrotorchError::InvalidArgument {
777                        message: format!(
778                            "UpDecoderBlock2D has no upsampler but state_dict contains \"{k}\""
779                        ),
780                    });
781                }
782            }
783        }
784        Ok(())
785    }
786}
787
788// ---------------------------------------------------------------------------
789// DownEncoderBlock2D
790// ---------------------------------------------------------------------------
791
/// `DownEncoderBlock2D` — a stack of `layers_per_block` resnets at
/// `out_channels`, optionally followed by a `Downsample2D`.
///
/// This is the encoder-side mirror of [`UpDecoderBlock2D`]. The
/// asymmetry in resnet count (encoder uses `layers_per_block` resnets,
/// decoder uses `layers_per_block + 1`) matches the diffusers
/// `AutoencoderKL` convention exactly.
///
/// State-dict key layout:
///
/// ```text
/// resnets.{i}.<resnet keys>      i in 0..layers_per_block
/// downsamplers.0.<downsample keys>   (present iff add_downsample)
/// ```
#[derive(Debug)]
pub struct DownEncoderBlock2D<T: Float> {
    /// `layers_per_block` resnets at `out_channels`. The first resnet
    /// projects `in_channels -> out_channels` (via its 1x1 shortcut
    /// when the channels differ); subsequent resnets are `out_channels`
    /// throughout.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Optional downsample (absent on the deepest encoder block, which
    /// holds spatial resolution before the mid-block).
    pub downsamplers_0: Option<Downsample2D<T>>,
    /// Block-level mode flag; initialised to `false` (eval) by `new`.
    training: bool,
}
818
819impl<T: Float> DownEncoderBlock2D<T> {
820    /// Build a randomly-initialized `DownEncoderBlock2D`.
821    ///
822    /// # Errors
823    ///
824    /// Returns [`FerrotorchError::InvalidArgument`] when `num_resnets`
825    /// is zero, or any underlying [`FerrotorchError`] on bad channel /
826    /// group config.
827    pub fn new(
828        in_channels: usize,
829        out_channels: usize,
830        num_resnets: usize,
831        norm_num_groups: usize,
832        eps: f64,
833        add_downsample: bool,
834    ) -> FerrotorchResult<Self> {
835        if num_resnets == 0 {
836            return Err(FerrotorchError::InvalidArgument {
837                message: "DownEncoderBlock2D: num_resnets must be > 0".into(),
838            });
839        }
840        let mut resnets = Vec::with_capacity(num_resnets);
841        for i in 0..num_resnets {
842            let in_c = if i == 0 { in_channels } else { out_channels };
843            resnets.push(ResnetBlock2D::<T>::new(
844                in_c,
845                out_channels,
846                norm_num_groups,
847                eps,
848            )?);
849        }
850        let downsamplers_0 = if add_downsample {
851            Some(Downsample2D::<T>::new(out_channels)?)
852        } else {
853            None
854        };
855        Ok(Self {
856            resnets,
857            downsamplers_0,
858            training: false,
859        })
860    }
861}
862
863impl<T: Float> Module<T> for DownEncoderBlock2D<T> {
864    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
865        let mut h = input.clone();
866        for r in &self.resnets {
867            h = r.forward(&h)?;
868        }
869        if let Some(d) = &self.downsamplers_0 {
870            h = d.forward(&h)?;
871        }
872        Ok(h)
873    }
874
875    fn parameters(&self) -> Vec<&Parameter<T>> {
876        let mut out = Vec::new();
877        for r in &self.resnets {
878            out.extend(r.parameters());
879        }
880        if let Some(d) = &self.downsamplers_0 {
881            out.extend(d.parameters());
882        }
883        out
884    }
885
886    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
887        let mut out = Vec::new();
888        for r in &mut self.resnets {
889            out.extend(r.parameters_mut());
890        }
891        if let Some(d) = self.downsamplers_0.as_mut() {
892            out.extend(d.parameters_mut());
893        }
894        out
895    }
896
897    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
898        let mut out = Vec::new();
899        for (i, r) in self.resnets.iter().enumerate() {
900            for (n, p) in r.named_parameters() {
901                out.push((format!("resnets.{i}.{n}"), p));
902            }
903        }
904        if let Some(d) = &self.downsamplers_0 {
905            for (n, p) in d.named_parameters() {
906                out.push((format!("downsamplers.0.{n}"), p));
907            }
908        }
909        out
910    }
911
912    fn train(&mut self) {
913        self.training = true;
914        for r in &mut self.resnets {
915            r.train();
916        }
917        if let Some(d) = self.downsamplers_0.as_mut() {
918            d.train();
919        }
920    }
921    fn eval(&mut self) {
922        self.training = false;
923        for r in &mut self.resnets {
924            r.eval();
925        }
926        if let Some(d) = self.downsamplers_0.as_mut() {
927            d.eval();
928        }
929    }
930    fn is_training(&self) -> bool {
931        self.training
932    }
933
934    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
935        let extract = |prefix: &str| -> StateDict<T> {
936            let p = format!("{prefix}.");
937            state
938                .iter()
939                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
940                .collect()
941        };
942        if strict {
943            for k in state.keys() {
944                let ok = k.starts_with("resnets.") || k.starts_with("downsamplers.0.");
945                if !ok {
946                    return Err(FerrotorchError::InvalidArgument {
947                        message: format!(
948                            "unexpected key in DownEncoderBlock2D state_dict: \"{k}\""
949                        ),
950                    });
951                }
952            }
953        }
954        for (i, r) in self.resnets.iter_mut().enumerate() {
955            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
956        }
957        if let Some(d) = self.downsamplers_0.as_mut() {
958            d.load_state_dict(&extract("downsamplers.0"), strict)?;
959        } else if strict {
960            for k in state.keys() {
961                if k.starts_with("downsamplers.") {
962                    return Err(FerrotorchError::InvalidArgument {
963                        message: format!(
964                            "DownEncoderBlock2D has no downsampler but state_dict contains \"{k}\""
965                        ),
966                    });
967                }
968            }
969        }
970        Ok(())
971    }
972}
973
974// ---------------------------------------------------------------------------
975// UNetMidBlock2D (VAE flavour)
976// ---------------------------------------------------------------------------
977
/// `UNetMidBlock2D` configured the way the SD VAE uses it:
///
/// ```text
/// resnets[0] -> attentions[0] -> resnets[1]
/// ```
///
/// 2 resnets, 1 attention; all at `in_channels = out_channels = 512` for
/// SD 1.5. Diffusers always has `len(resnets) == num_layers + 1` and
/// `len(attentions) == num_layers`; in the VAE `num_layers = 1`, so two
/// resnets and one attention.
#[derive(Debug)]
pub struct UNetMidBlock2D<T: Float> {
    /// Pre-attention + post-attention resnets (size 2 for SD VAE).
    /// Exposed under the `resnets.{i}.` state-dict prefix.
    pub resnets: Vec<ResnetBlock2D<T>>,
    /// Single attention block, exposed under `attentions.{i}.`.
    pub attentions: Vec<AttnBlock2D<T>>,
    /// Channel count passed to `new`. It is stored but not consulted by
    /// `forward`, and is only touched elsewhere to quiet the `dead_code`
    /// lint.
    channels: usize,
    /// Train/eval flag: set by `Module::train` / `Module::eval`.
    /// Starts as `false`.
    training: bool,
}
997
998impl<T: Float> UNetMidBlock2D<T> {
999    /// Build a randomly-initialized VAE-flavoured `UNetMidBlock2D`.
1000    ///
1001    /// # Errors
1002    ///
1003    /// Returns the underlying [`FerrotorchError`] on bad config.
1004    pub fn new(channels: usize, norm_num_groups: usize, eps: f64) -> FerrotorchResult<Self> {
1005        // SD VAE: num_layers=1 => 2 resnets + 1 attention.
1006        let resnets = vec![
1007            ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
1008            ResnetBlock2D::<T>::new(channels, channels, norm_num_groups, eps)?,
1009        ];
1010        let attentions = vec![AttnBlock2D::<T>::new(channels, norm_num_groups, eps)?];
1011        Ok(Self {
1012            resnets,
1013            attentions,
1014            channels,
1015            training: false,
1016        })
1017    }
1018}
1019
1020impl<T: Float> Module<T> for UNetMidBlock2D<T> {
1021    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1022        // diffusers UNetMidBlock2D.forward (no temb):
1023        //   h = resnets[0](x)
1024        //   for (attn, resnet) in zip(attentions, resnets[1:]):
1025        //       h = attn(h); h = resnet(h)
1026        let mut h = self.resnets[0].forward(input)?;
1027        for (attn, resnet) in self.attentions.iter().zip(self.resnets.iter().skip(1)) {
1028            h = attn.forward(&h)?;
1029            h = resnet.forward(&h)?;
1030        }
1031        Ok(h)
1032    }
1033
1034    fn parameters(&self) -> Vec<&Parameter<T>> {
1035        let mut out = Vec::new();
1036        // The HF diffusers state-dict key order is `attentions` before
1037        // `resnets`. We match the ordering used for state-dict
1038        // round-trip — see `named_parameters()` for the canonical
1039        // sequence.
1040        for a in &self.attentions {
1041            out.extend(a.parameters());
1042        }
1043        for r in &self.resnets {
1044            out.extend(r.parameters());
1045        }
1046        out
1047    }
1048
1049    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1050        let mut out = Vec::new();
1051        for a in &mut self.attentions {
1052            out.extend(a.parameters_mut());
1053        }
1054        for r in &mut self.resnets {
1055            out.extend(r.parameters_mut());
1056        }
1057        out
1058    }
1059
1060    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1061        let mut out = Vec::new();
1062        // Match the order used by `safetensors` listings of the HF VAE:
1063        // `attentions.*` first, then `resnets.*`.
1064        for (i, a) in self.attentions.iter().enumerate() {
1065            for (n, p) in a.named_parameters() {
1066                out.push((format!("attentions.{i}.{n}"), p));
1067            }
1068        }
1069        for (i, r) in self.resnets.iter().enumerate() {
1070            for (n, p) in r.named_parameters() {
1071                out.push((format!("resnets.{i}.{n}"), p));
1072            }
1073        }
1074        out
1075    }
1076
1077    fn train(&mut self) {
1078        self.training = true;
1079        for r in &mut self.resnets {
1080            r.train();
1081        }
1082        for a in &mut self.attentions {
1083            a.train();
1084        }
1085    }
1086    fn eval(&mut self) {
1087        self.training = false;
1088        for r in &mut self.resnets {
1089            r.eval();
1090        }
1091        for a in &mut self.attentions {
1092            a.eval();
1093        }
1094    }
1095    fn is_training(&self) -> bool {
1096        self.training
1097    }
1098
1099    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1100        let extract = |prefix: &str| -> StateDict<T> {
1101            let p = format!("{prefix}.");
1102            state
1103                .iter()
1104                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1105                .collect()
1106        };
1107        if strict {
1108            for k in state.keys() {
1109                let ok = k.starts_with("attentions.") || k.starts_with("resnets.");
1110                if !ok {
1111                    return Err(FerrotorchError::InvalidArgument {
1112                        message: format!("unexpected key in UNetMidBlock2D state_dict: \"{k}\""),
1113                    });
1114                }
1115            }
1116        }
1117        for (i, a) in self.attentions.iter_mut().enumerate() {
1118            a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
1119        }
1120        for (i, r) in self.resnets.iter_mut().enumerate() {
1121            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
1122        }
1123        let _ = self.channels; // silence dead_code warning under some lint configs
1124        let _: HashMap<String, Tensor<T>> = HashMap::new(); // pull HashMap into scope
1125        Ok(())
1126    }
1127}
1128
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Constant-valued CPU tensor of the given `shape` (no grad tracking).
    fn filled(value: f32, shape: &[usize]) -> Tensor<f32> {
        let numel = shape.iter().product::<usize>();
        Tensor::from_storage(TensorStorage::cpu(vec![value; numel]), shape.to_vec(), false)
            .unwrap()
    }

    /// Just the names from a module's `named_parameters()`.
    fn param_names<M: Module<f32>>(module: &M) -> Vec<String> {
        module
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect()
    }

    /// Fail, listing every name, if any `expected` key is missing.
    fn assert_has_all(names: &[String], expected: &[&str]) {
        for &k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn resnet_same_channels_no_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 32, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_none());
        let out = block.forward(&filled(0.01, &[1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
        assert!(out.data().unwrap().iter().all(|v| v.is_finite()));
    }

    #[test]
    fn resnet_different_channels_has_shortcut() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert!(block.conv_shortcut.is_some());
        let out = block.forward(&filled(0.01, &[1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 64, 4, 4]);
    }

    #[test]
    fn resnet_named_parameters_layout() {
        let block = ResnetBlock2D::<f32>::new(32, 64, 32, 1e-6).unwrap();
        assert_has_all(
            &param_names(&block),
            &[
                "norm1.weight",
                "norm1.bias",
                "conv1.weight",
                "conv1.bias",
                "norm2.weight",
                "norm2.bias",
                "conv2.weight",
                "conv2.bias",
                "conv_shortcut.weight",
                "conv_shortcut.bias",
            ],
        );
    }

    #[test]
    fn attn_shape_and_residual() {
        // 32 channels, 4 groups => 8 channels per group; small spatial.
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        let out = attn.forward(&filled(0.01, &[1, 32, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 32, 4, 4]);
        if let Some(bad) = out.data().unwrap().iter().find(|v| !v.is_finite()) {
            panic!("attn output non-finite: {bad}");
        }
    }

    #[test]
    fn attn_named_parameters_layout() {
        let attn = AttnBlock2D::<f32>::new(32, 4, 1e-6).unwrap();
        assert_has_all(
            &param_names(&attn),
            &[
                "group_norm.weight",
                "group_norm.bias",
                "to_q.weight",
                "to_q.bias",
                "to_k.weight",
                "to_k.bias",
                "to_v.weight",
                "to_v.bias",
                "to_out.0.weight",
                "to_out.0.bias",
            ],
        );
    }

    #[test]
    fn upsample2d_doubles_spatial() {
        let up = Upsample2D::<f32>::new(8).unwrap();
        let out = up.forward(&filled(0.05, &[1, 8, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_with_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(16, 8, 2, 4, 1e-6, true).unwrap();
        let out = block.forward(&filled(0.05, &[1, 16, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 6, 6]);
    }

    #[test]
    fn up_decoder_block_shape_no_upsample() {
        let block = UpDecoderBlock2D::<f32>::new(8, 8, 2, 4, 1e-6, false).unwrap();
        let out = block.forward(&filled(0.05, &[1, 8, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 8, 3, 3]);
    }

    #[test]
    fn down_encoder_block_shape_with_downsample() {
        // [1, 8, 4, 4] -> resnets project 8 -> 16 -> downsample halves spatial
        let block = DownEncoderBlock2D::<f32>::new(8, 16, 2, 4, 1e-6, true).unwrap();
        let out = block.forward(&filled(0.05, &[1, 8, 4, 4])).unwrap();
        assert_eq!(out.shape(), &[1, 16, 2, 2]);
    }

    #[test]
    fn down_encoder_block_shape_no_downsample() {
        // Last encoder block: same channels, no downsample (spatial preserved)
        let block = DownEncoderBlock2D::<f32>::new(16, 16, 2, 4, 1e-6, false).unwrap();
        let out = block.forward(&filled(0.05, &[1, 16, 2, 2])).unwrap();
        assert_eq!(out.shape(), &[1, 16, 2, 2]);
    }

    #[test]
    fn down_encoder_block_named_parameters_layout() {
        let block = DownEncoderBlock2D::<f32>::new(8, 16, 2, 4, 1e-6, true).unwrap();
        assert_has_all(
            &param_names(&block),
            &[
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.0.conv_shortcut.weight",
                "resnets.1.norm1.weight",
                "downsamplers.0.conv.weight",
                "downsamplers.0.conv.bias",
            ],
        );
    }

    #[test]
    fn mid_block_shape() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        let out = mid.forward(&filled(0.05, &[1, 16, 3, 3])).unwrap();
        assert_eq!(out.shape(), &[1, 16, 3, 3]);
    }

    #[test]
    fn mid_block_named_parameters_layout() {
        let mid = UNetMidBlock2D::<f32>::new(16, 4, 1e-6).unwrap();
        assert_has_all(
            &param_names(&mid),
            &[
                "attentions.0.group_norm.weight",
                "attentions.0.to_q.weight",
                "resnets.0.norm1.weight",
                "resnets.0.conv1.weight",
                "resnets.1.conv2.weight",
            ],
        );
    }
}