Skip to main content

ferrotorch_diffusion/
unet.rs

1//! Stable-Diffusion UNet2DConditionModel forward pass.
2//!
3//! Matches `diffusers.models.unets.UNet2DConditionModel.forward(sample,
4//! timestep, encoder_hidden_states).sample` for
5//! `runwayml/stable-diffusion-v1-5` (the
6//! `down_block_types=[CrossAttn × 3, DownBlock2D]` / `up_block_types=
7//! [UpBlock2D, CrossAttn × 3]` / `mid_block=UNetMidBlock2DCrossAttn`
8//! topology).
9//!
10//! ```text
11//! t  = time_proj(timesteps)                # [B, 320]
12//! t  = time_embedding(t)                   # [B, 1280]
13//! h0 = conv_in(sample)                     # [B, 320, H, W]
14//! skips = [h0]
15//! for db in down_blocks:                   # 4 down blocks
16//!     h, skip = db(h, t, ehs)
17//!     skips.extend(skip)
18//! h = mid_block(h, t, ehs)
19//! for ub in up_blocks:                     # 4 up blocks
20//!     pop the trailing N skip tensors and concat at each resnet stage
21//!     h = ub(h, t, ehs, skips_popped)
22//! h = silu(conv_norm_out(h))
23//! h = conv_out(h)                          # [B, 4, H, W]
24//! ```
25//!
26//! State-dict key layout (matches diffusers byte-for-byte):
27//!
28//! ```text
29//! time_embedding.linear_{1,2}.{weight,bias}
30//! conv_in.{weight,bias}
31//! down_blocks.{i}.{resnets,attentions,downsamplers}.<…>
32//! mid_block.{resnets,attentions}.<…>
33//! up_blocks.{i}.{resnets,attentions,upsamplers}.<…>
34//! conv_norm_out.{weight,bias}
35//! conv_out.{weight,bias}
36//! ```
37//!
38//! Note: the time embedding has NO sinusoidal parameters (it's the
39//! arithmetic-only `Timesteps` module followed by the trainable
40//! `TimestepEmbedding` MLP).
41//!
42//! ## REQ status (per `.design/ferrotorch-diffusion/unet.md`)
43//!
44//! | REQ | Status | Evidence |
45//! |---|---|---|
46//! | REQ-1 | SHIPPED | `CrossAttnDownBlock2D<T>` at `unet.rs:67..76` and `forward_t` at `unet.rs:147..165`; consumer: `AnyDownBlock::CrossAttn` at `unet.rs:1162` built by `UNet2DConditionModel::new` |
47//! | REQ-2 | SHIPPED | `DownBlock2D<T>` at `unet.rs:277..285` and `forward_t` at `unet.rs:327..343`; consumer: `AnyDownBlock::Plain` at `unet.rs:1174` built by `UNet2DConditionModel::new` |
48//! | REQ-3 | SHIPPED | `UNetMidBlock2DCrossAttn<T>` at `unet.rs:442..450` and `forward_t` at `unet.rs:491..504`; consumer: `UNet2DConditionModel::forward_t` at `unet.rs:1330` invokes the mid block |
49//! | REQ-4 | SHIPPED | `CrossAttnUpBlock2D<T>` at `unet.rs:600..610` and `forward_t` at `unet.rs:691..722`; consumer: `AnyUpBlock::CrossAttn` at `unet.rs:1217` built by `UNet2DConditionModel::new` |
50//! | REQ-5 | SHIPPED | `UpBlock2D<T>` at `unet.rs:829..836` and `forward_t` at `unet.rs:890..914`; consumer: `AnyUpBlock::Plain` at `unet.rs:1230` |
51//! | REQ-6 | SHIPPED | `AnyDownBlock<T>` at `unet.rs:1000..1046` and `AnyUpBlock<T>` at `unet.rs:1048..1099`; consumer: `UNet2DConditionModel::new` dispatches by `cfg.down_block_has_attn[i]` / `cfg.up_block_has_attn[i]` |
52//! | REQ-7 | SHIPPED | `UNet2DConditionModel::forward_t` at `unet.rs:1274..1362`; consumer: `pipeline.rs:142..143` in `cfg_eval` calls `self.unet.forward_t(...)` twice per diffusion step |
53//! | REQ-8 | SHIPPED | `Module<T>::load_state_dict` at `unet.rs:1444..1486`; consumer: `safetensors_loader.rs:146` `UNet2DConditionModel::load_hf_state_dict` calls `self.load_state_dict(&remapped, strict)` |
54
55use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
56use ferrotorch_nn::module::{Module, StateDict};
57use ferrotorch_nn::parameter::Parameter;
58use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
59
60use crate::attention::Transformer2DModel;
61use crate::blocks::{Downsample2D, Upsample2D};
62use crate::resnet_block_time::ResnetBlock2DTime;
63use crate::time_embedding::{TimestepEmbedding, Timesteps};
64use crate::unet_config::UNet2DConditionConfig;
65
66// ---------------------------------------------------------------------------
67// CrossAttnDownBlock2D
68// ---------------------------------------------------------------------------
69
/// `CrossAttnDownBlock2D` — `layers_per_block` × (ResnetBlock2DTime +
/// Transformer2DModel) + optional Downsample2D.
///
/// In the forward pass each resnet is immediately followed by its paired
/// transformer (cross-attending to the encoder hidden states); the optional
/// downsampler runs after the last pair.
///
/// State-dict layout:
///
/// ```text
/// resnets.{j}.<keys>          j in 0..layers_per_block
/// attentions.{j}.<keys>       j in 0..layers_per_block
/// downsamplers.0.<keys>        present iff add_downsample
/// ```
#[derive(Debug)]
pub struct CrossAttnDownBlock2D<T: Float> {
    /// `layers_per_block` time-conditioned resnets; `resnets[j]` runs right
    /// before `attentions[j]`.
    pub resnets: Vec<ResnetBlock2DTime<T>>,
    /// `layers_per_block` transformer blocks (cross-attn), one per resnet.
    pub attentions: Vec<Transformer2DModel<T>>,
    /// Optional downsampler, serialized under `downsamplers.0.*`.
    pub downsamplers_0: Option<Downsample2D<T>>,
    /// Train/eval flag reported by `Module::is_training`.
    training: bool,
}
90
91impl<T: Float> CrossAttnDownBlock2D<T> {
92    /// Build a randomly-initialized cross-attention down-block.
93    ///
94    /// # Errors
95    ///
96    /// Returns the underlying [`FerrotorchError`].
97    #[allow(clippy::too_many_arguments)]
98    pub fn new(
99        in_channels: usize,
100        out_channels: usize,
101        temb_channels: usize,
102        num_layers: usize,
103        attention_head_dim: usize,
104        cross_attention_dim: usize,
105        norm_num_groups: usize,
106        add_downsample: bool,
107        transformer_layers_per_block: usize,
108    ) -> FerrotorchResult<Self> {
109        // Diffusers footgun: for SD-1.5 the `attention_head_dim` config
110        // entry is actually the *number of heads*, not the per-head
111        // dimension. `num_attention_heads` was added later and defaults
112        // to `attention_head_dim` when unset (see
113        // diffusers/models/unets/unet_2d_condition.py: `num_attention_heads
114        // = num_attention_heads or attention_head_dim`). The per-head
115        // dimension is then `out_channels // num_heads`.
116        let heads = attention_head_dim;
117        let dim_head = out_channels / heads;
118        let mut resnets = Vec::with_capacity(num_layers);
119        let mut attentions = Vec::with_capacity(num_layers);
120        for j in 0..num_layers {
121            let in_c = if j == 0 { in_channels } else { out_channels };
122            resnets.push(ResnetBlock2DTime::<T>::new(
123                in_c,
124                out_channels,
125                temb_channels,
126                norm_num_groups,
127                1e-5,
128            )?);
129            attentions.push(Transformer2DModel::<T>::new(
130                out_channels,
131                heads,
132                dim_head,
133                transformer_layers_per_block,
134                cross_attention_dim,
135                norm_num_groups,
136            )?);
137        }
138        let downsamplers_0 = if add_downsample {
139            Some(Downsample2D::<T>::new(out_channels)?)
140        } else {
141            None
142        };
143        Ok(Self {
144            resnets,
145            attentions,
146            downsamplers_0,
147            training: false,
148        })
149    }
150
151    /// Forward; returns `(output, [skip_after_each_resnet+optional_downsample])`.
152    ///
153    /// Skips are pushed in diffusers order — every resnet+attn pair
154    /// emits one skip, and the downsampler (if present) emits the post-
155    /// downsample tensor as the trailing skip.
156    ///
157    /// # Errors
158    ///
159    /// Returns the underlying [`FerrotorchError`].
160    pub fn forward_t(
161        &self,
162        x: &Tensor<T>,
163        temb: &Tensor<T>,
164        encoder_hidden_states: &Tensor<T>,
165    ) -> FerrotorchResult<(Tensor<T>, Vec<Tensor<T>>)> {
166        let mut h = x.clone();
167        let mut skips = Vec::with_capacity(self.resnets.len() + 1);
168        for (r, a) in self.resnets.iter().zip(self.attentions.iter()) {
169            h = r.forward_t(&h, temb)?;
170            h = a.forward_xattn(&h, encoder_hidden_states)?;
171            skips.push(h.clone());
172        }
173        if let Some(d) = &self.downsamplers_0 {
174            h = d.forward(&h)?;
175            skips.push(h.clone());
176        }
177        Ok((h, skips))
178    }
179
180    fn named_params_internal(&self) -> Vec<(String, &Parameter<T>)> {
181        let mut o = Vec::new();
182        for (i, r) in self.resnets.iter().enumerate() {
183            for (n, p) in r.named_parameters() {
184                o.push((format!("resnets.{i}.{n}"), p));
185            }
186        }
187        for (i, a) in self.attentions.iter().enumerate() {
188            for (n, p) in a.named_parameters() {
189                o.push((format!("attentions.{i}.{n}"), p));
190            }
191        }
192        if let Some(d) = &self.downsamplers_0 {
193            for (n, p) in d.named_parameters() {
194                o.push((format!("downsamplers.0.{n}"), p));
195            }
196        }
197        o
198    }
199}
200
201impl<T: Float> Module<T> for CrossAttnDownBlock2D<T> {
202    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
203        Err(FerrotorchError::InvalidArgument {
204            message: "CrossAttnDownBlock2D::forward requires (x, temb, ehs) — call \
205                      forward_t instead"
206                .into(),
207        })
208    }
209
210    fn parameters(&self) -> Vec<&Parameter<T>> {
211        let mut o = Vec::new();
212        for r in &self.resnets {
213            o.extend(r.parameters());
214        }
215        for a in &self.attentions {
216            o.extend(a.parameters());
217        }
218        if let Some(d) = &self.downsamplers_0 {
219            o.extend(d.parameters());
220        }
221        o
222    }
223    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
224        let mut o = Vec::new();
225        for r in &mut self.resnets {
226            o.extend(r.parameters_mut());
227        }
228        for a in &mut self.attentions {
229            o.extend(a.parameters_mut());
230        }
231        if let Some(d) = self.downsamplers_0.as_mut() {
232            o.extend(d.parameters_mut());
233        }
234        o
235    }
236    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
237        self.named_params_internal()
238    }
239    fn train(&mut self) {
240        self.training = true;
241    }
242    fn eval(&mut self) {
243        self.training = false;
244    }
245    fn is_training(&self) -> bool {
246        self.training
247    }
248    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
249        let extract = |prefix: &str| -> StateDict<T> {
250            let p = format!("{prefix}.");
251            state
252                .iter()
253                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
254                .collect()
255        };
256        if strict {
257            for k in state.keys() {
258                let ok = k.starts_with("resnets.")
259                    || k.starts_with("attentions.")
260                    || k.starts_with("downsamplers.0.");
261                if !ok {
262                    return Err(FerrotorchError::InvalidArgument {
263                        message: format!(
264                            "unexpected key in CrossAttnDownBlock2D state_dict: \"{k}\""
265                        ),
266                    });
267                }
268            }
269        }
270        for (i, r) in self.resnets.iter_mut().enumerate() {
271            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
272        }
273        for (i, a) in self.attentions.iter_mut().enumerate() {
274            a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
275        }
276        if let Some(d) = self.downsamplers_0.as_mut() {
277            d.load_state_dict(&extract("downsamplers.0"), strict)?;
278        }
279        Ok(())
280    }
281}
282
283// ---------------------------------------------------------------------------
284// DownBlock2D (no cross-attn — the final SD down-block)
285// ---------------------------------------------------------------------------
286
/// `DownBlock2D` — `layers_per_block` ResnetBlock2DTime + optional
/// Downsample2D. No transformer attention.
///
/// State-dict layout:
///
/// ```text
/// resnets.{j}.<keys>          j in 0..layers_per_block
/// downsamplers.0.<keys>       present iff add_downsample
/// ```
#[derive(Debug)]
pub struct DownBlock2D<T: Float> {
    /// `layers_per_block` time-conditioned resnets, applied in order.
    pub resnets: Vec<ResnetBlock2DTime<T>>,
    /// Optional downsampler, applied after the last resnet.
    pub downsamplers_0: Option<Downsample2D<T>>,
    /// Train/eval flag reported by `Module::is_training`.
    training: bool,
}
297
298impl<T: Float> DownBlock2D<T> {
299    /// Build a randomly-initialized non-cross-attn down-block.
300    ///
301    /// # Errors
302    ///
303    /// Returns the underlying [`FerrotorchError`].
304    pub fn new(
305        in_channels: usize,
306        out_channels: usize,
307        temb_channels: usize,
308        num_layers: usize,
309        norm_num_groups: usize,
310        add_downsample: bool,
311    ) -> FerrotorchResult<Self> {
312        let mut resnets = Vec::with_capacity(num_layers);
313        for j in 0..num_layers {
314            let in_c = if j == 0 { in_channels } else { out_channels };
315            resnets.push(ResnetBlock2DTime::<T>::new(
316                in_c,
317                out_channels,
318                temb_channels,
319                norm_num_groups,
320                1e-5,
321            )?);
322        }
323        let downsamplers_0 = if add_downsample {
324            Some(Downsample2D::<T>::new(out_channels)?)
325        } else {
326            None
327        };
328        Ok(Self {
329            resnets,
330            downsamplers_0,
331            training: false,
332        })
333    }
334
335    /// Forward; returns `(output, skips)`.
336    ///
337    /// # Errors
338    ///
339    /// Returns the underlying [`FerrotorchError`].
340    pub fn forward_t(
341        &self,
342        x: &Tensor<T>,
343        temb: &Tensor<T>,
344    ) -> FerrotorchResult<(Tensor<T>, Vec<Tensor<T>>)> {
345        let mut h = x.clone();
346        let mut skips = Vec::with_capacity(self.resnets.len() + 1);
347        for r in &self.resnets {
348            h = r.forward_t(&h, temb)?;
349            skips.push(h.clone());
350        }
351        if let Some(d) = &self.downsamplers_0 {
352            h = d.forward(&h)?;
353            skips.push(h.clone());
354        }
355        Ok((h, skips))
356    }
357}
358
359impl<T: Float> Module<T> for DownBlock2D<T> {
360    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
361        Err(FerrotorchError::InvalidArgument {
362            message: "DownBlock2D::forward requires (x, temb) — call forward_t instead".into(),
363        })
364    }
365
366    fn parameters(&self) -> Vec<&Parameter<T>> {
367        let mut o = Vec::new();
368        for r in &self.resnets {
369            o.extend(r.parameters());
370        }
371        if let Some(d) = &self.downsamplers_0 {
372            o.extend(d.parameters());
373        }
374        o
375    }
376    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
377        let mut o = Vec::new();
378        for r in &mut self.resnets {
379            o.extend(r.parameters_mut());
380        }
381        if let Some(d) = self.downsamplers_0.as_mut() {
382            o.extend(d.parameters_mut());
383        }
384        o
385    }
386    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
387        let mut o = Vec::new();
388        for (i, r) in self.resnets.iter().enumerate() {
389            for (n, p) in r.named_parameters() {
390                o.push((format!("resnets.{i}.{n}"), p));
391            }
392        }
393        if let Some(d) = &self.downsamplers_0 {
394            for (n, p) in d.named_parameters() {
395                o.push((format!("downsamplers.0.{n}"), p));
396            }
397        }
398        o
399    }
400    fn train(&mut self) {
401        self.training = true;
402    }
403    fn eval(&mut self) {
404        self.training = false;
405    }
406    fn is_training(&self) -> bool {
407        self.training
408    }
409    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
410        let extract = |prefix: &str| -> StateDict<T> {
411            let p = format!("{prefix}.");
412            state
413                .iter()
414                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
415                .collect()
416        };
417        if strict {
418            for k in state.keys() {
419                let ok = k.starts_with("resnets.") || k.starts_with("downsamplers.0.");
420                if !ok {
421                    return Err(FerrotorchError::InvalidArgument {
422                        message: format!("unexpected key in DownBlock2D state_dict: \"{k}\""),
423                    });
424                }
425            }
426        }
427        for (i, r) in self.resnets.iter_mut().enumerate() {
428            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
429        }
430        if let Some(d) = self.downsamplers_0.as_mut() {
431            d.load_state_dict(&extract("downsamplers.0"), strict)?;
432        }
433        Ok(())
434    }
435}
436
437// ---------------------------------------------------------------------------
438// UNetMidBlock2DCrossAttn
439// ---------------------------------------------------------------------------
440
/// `UNetMidBlock2DCrossAttn` — the mid block of the SD UNet:
///
/// ```text
/// resnets[0] -> attentions[0] -> resnets[1]
/// ```
///
/// Channel width is preserved throughout (`channels -> channels`).
///
/// State-dict layout:
///
/// ```text
/// attentions.0.<keys>
/// resnets.0.<keys>
/// resnets.1.<keys>
/// ```
#[derive(Debug)]
pub struct UNetMidBlock2DCrossAttn<T: Float> {
    /// 2 time-conditioned resnets; `resnets[0]` runs first, then each
    /// later resnet follows the attention at the preceding index.
    pub resnets: Vec<ResnetBlock2DTime<T>>,
    /// 1 cross-attn transformer, sandwiched between the two resnets.
    pub attentions: Vec<Transformer2DModel<T>>,
    /// Train/eval flag reported by `Module::is_training`.
    training: bool,
}
462
463impl<T: Float> UNetMidBlock2DCrossAttn<T> {
464    /// Build a randomly-initialized mid block (`channels -> channels`).
465    ///
466    /// # Errors
467    ///
468    /// Returns the underlying [`FerrotorchError`].
469    pub fn new(
470        channels: usize,
471        temb_channels: usize,
472        attention_head_dim: usize,
473        cross_attention_dim: usize,
474        norm_num_groups: usize,
475        transformer_layers_per_block: usize,
476    ) -> FerrotorchResult<Self> {
477        // See CrossAttnDownBlock2D::new for the heads/dim_head footgun.
478        let heads = attention_head_dim;
479        let dim_head = channels / heads;
480        let r0 =
481            ResnetBlock2DTime::<T>::new(channels, channels, temb_channels, norm_num_groups, 1e-5)?;
482        let attn = Transformer2DModel::<T>::new(
483            channels,
484            heads,
485            dim_head,
486            transformer_layers_per_block,
487            cross_attention_dim,
488            norm_num_groups,
489        )?;
490        let r1 =
491            ResnetBlock2DTime::<T>::new(channels, channels, temb_channels, norm_num_groups, 1e-5)?;
492        Ok(Self {
493            resnets: vec![r0, r1],
494            attentions: vec![attn],
495            training: false,
496        })
497    }
498
499    /// Forward with time embedding and encoder hidden states.
500    ///
501    /// # Errors
502    ///
503    /// Returns the underlying [`FerrotorchError`].
504    pub fn forward_t(
505        &self,
506        x: &Tensor<T>,
507        temb: &Tensor<T>,
508        encoder_hidden_states: &Tensor<T>,
509    ) -> FerrotorchResult<Tensor<T>> {
510        let mut h = self.resnets[0].forward_t(x, temb)?;
511        for (a, r) in self.attentions.iter().zip(self.resnets.iter().skip(1)) {
512            h = a.forward_xattn(&h, encoder_hidden_states)?;
513            h = r.forward_t(&h, temb)?;
514        }
515        Ok(h)
516    }
517}
518
519impl<T: Float> Module<T> for UNetMidBlock2DCrossAttn<T> {
520    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
521        Err(FerrotorchError::InvalidArgument {
522            message: "UNetMidBlock2DCrossAttn::forward requires (x, temb, ehs) — call \
523                      forward_t instead"
524                .into(),
525        })
526    }
527
528    fn parameters(&self) -> Vec<&Parameter<T>> {
529        let mut o = Vec::new();
530        // Match diffusers: attentions before resnets at the state-dict
531        // level. The forward path uses resnets[0] first, but the named
532        // params are listed `attentions.*` then `resnets.*`.
533        for a in &self.attentions {
534            o.extend(a.parameters());
535        }
536        for r in &self.resnets {
537            o.extend(r.parameters());
538        }
539        o
540    }
541    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
542        let mut o = Vec::new();
543        for a in &mut self.attentions {
544            o.extend(a.parameters_mut());
545        }
546        for r in &mut self.resnets {
547            o.extend(r.parameters_mut());
548        }
549        o
550    }
551    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
552        let mut o = Vec::new();
553        for (i, a) in self.attentions.iter().enumerate() {
554            for (n, p) in a.named_parameters() {
555                o.push((format!("attentions.{i}.{n}"), p));
556            }
557        }
558        for (i, r) in self.resnets.iter().enumerate() {
559            for (n, p) in r.named_parameters() {
560                o.push((format!("resnets.{i}.{n}"), p));
561            }
562        }
563        o
564    }
565    fn train(&mut self) {
566        self.training = true;
567    }
568    fn eval(&mut self) {
569        self.training = false;
570    }
571    fn is_training(&self) -> bool {
572        self.training
573    }
574    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
575        let extract = |prefix: &str| -> StateDict<T> {
576            let p = format!("{prefix}.");
577            state
578                .iter()
579                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
580                .collect()
581        };
582        if strict {
583            for k in state.keys() {
584                let ok = k.starts_with("resnets.") || k.starts_with("attentions.");
585                if !ok {
586                    return Err(FerrotorchError::InvalidArgument {
587                        message: format!(
588                            "unexpected key in UNetMidBlock2DCrossAttn state_dict: \"{k}\""
589                        ),
590                    });
591                }
592            }
593        }
594        for (i, a) in self.attentions.iter_mut().enumerate() {
595            a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
596        }
597        for (i, r) in self.resnets.iter_mut().enumerate() {
598            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
599        }
600        Ok(())
601    }
602}
603
604// ---------------------------------------------------------------------------
605// CrossAttnUpBlock2D
606// ---------------------------------------------------------------------------
607
/// `CrossAttnUpBlock2D` — `(layers_per_block + 1)` pairs of
/// (ResnetBlock2DTime, Transformer2DModel) followed by an optional
/// Upsample2D.
///
/// Each resnet's input is `cat([h, skip], dim = 1)` (the channel axis),
/// where `skip` is one of the down-side skip tensors handed to `forward_t`.
#[derive(Debug)]
pub struct CrossAttnUpBlock2D<T: Float> {
    /// `(layers_per_block + 1)` time-conditioned resnets; `resnets[j]` runs
    /// right before `attentions[j]`.
    pub resnets: Vec<ResnetBlock2DTime<T>>,
    /// `(layers_per_block + 1)` transformer blocks, one per resnet.
    pub attentions: Vec<Transformer2DModel<T>>,
    /// Optional upsample, serialized under `upsamplers.0.*`.
    pub upsamplers_0: Option<Upsample2D<T>>,
    /// Train/eval flag reported by `Module::is_training`.
    training: bool,
}
622
623impl<T: Float> CrossAttnUpBlock2D<T> {
624    /// Build a randomly-initialized cross-attention up-block.
625    ///
626    /// # Errors
627    ///
628    /// Returns the underlying [`FerrotorchError`].
629    #[allow(clippy::too_many_arguments)]
630    pub fn new(
631        in_channels: usize,
632        out_channels: usize,
633        prev_output_channels: usize,
634        temb_channels: usize,
635        num_layers: usize,
636        attention_head_dim: usize,
637        cross_attention_dim: usize,
638        norm_num_groups: usize,
639        add_upsample: bool,
640        transformer_layers_per_block: usize,
641    ) -> FerrotorchResult<Self> {
642        // See CrossAttnDownBlock2D::new for the heads/dim_head footgun.
643        let heads = attention_head_dim;
644        let dim_head = out_channels / heads;
645        let mut resnets = Vec::with_capacity(num_layers);
646        let mut attentions = Vec::with_capacity(num_layers);
647        // Diffusers' resnet input channel widths for an up-block of
648        // length L:
649        //   resnet[0]:  in = prev_output + skip_at_same_resolution
650        //   resnet[i in 1..L-1]: in = out + skip_at_same_resolution
651        //   resnet[L-1]: in = out + in_channels (the skip at the
652        //                deepest resolution feeds back to the *previous*
653        //                resolution's res output)
654        // For the canonical SD shape:
655        //   skip for j = 0:  res_skip = out
656        //   skip for j = 1:  res_skip = out
657        //   skip for j = L-1: res_skip = in_channels (the channel count
658        //                that came from the down-side pre-downsample)
659        for j in 0..num_layers {
660            let res_skip = if j == num_layers - 1 {
661                in_channels
662            } else {
663                out_channels
664            };
665            let resnet_in = if j == 0 {
666                prev_output_channels + res_skip
667            } else {
668                out_channels + res_skip
669            };
670            resnets.push(ResnetBlock2DTime::<T>::new(
671                resnet_in,
672                out_channels,
673                temb_channels,
674                norm_num_groups,
675                1e-5,
676            )?);
677            attentions.push(Transformer2DModel::<T>::new(
678                out_channels,
679                heads,
680                dim_head,
681                transformer_layers_per_block,
682                cross_attention_dim,
683                norm_num_groups,
684            )?);
685        }
686        let upsamplers_0 = if add_upsample {
687            Some(Upsample2D::<T>::new(out_channels)?)
688        } else {
689            None
690        };
691        Ok(Self {
692            resnets,
693            attentions,
694            upsamplers_0,
695            training: false,
696        })
697    }
698
699    /// Forward, consuming `skips.len()` skip tensors (one per resnet).
700    ///
701    /// # Errors
702    ///
703    /// Returns the underlying [`FerrotorchError`].
704    pub fn forward_t(
705        &self,
706        x: &Tensor<T>,
707        skips: &[Tensor<T>],
708        temb: &Tensor<T>,
709        encoder_hidden_states: &Tensor<T>,
710    ) -> FerrotorchResult<Tensor<T>> {
711        if skips.len() != self.resnets.len() {
712            return Err(FerrotorchError::InvalidArgument {
713                message: format!(
714                    "CrossAttnUpBlock2D::forward_t: got {} skips for {} resnets",
715                    skips.len(),
716                    self.resnets.len()
717                ),
718            });
719        }
720        let mut h = x.clone();
721        for ((r, a), skip) in self
722            .resnets
723            .iter()
724            .zip(self.attentions.iter())
725            .zip(skips.iter())
726        {
727            h = ferrotorch_core::grad_fns::shape::cat(&[h.clone(), skip.clone()], 1)?;
728            h = r.forward_t(&h, temb)?;
729            h = a.forward_xattn(&h, encoder_hidden_states)?;
730        }
731        if let Some(u) = &self.upsamplers_0 {
732            h = u.forward(&h)?;
733        }
734        Ok(h)
735    }
736}
737
738impl<T: Float> Module<T> for CrossAttnUpBlock2D<T> {
739    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
740        Err(FerrotorchError::InvalidArgument {
741            message: "CrossAttnUpBlock2D::forward requires (x, skips, temb, ehs) — call \
742                      forward_t instead"
743                .into(),
744        })
745    }
746    fn parameters(&self) -> Vec<&Parameter<T>> {
747        let mut o = Vec::new();
748        for r in &self.resnets {
749            o.extend(r.parameters());
750        }
751        for a in &self.attentions {
752            o.extend(a.parameters());
753        }
754        if let Some(u) = &self.upsamplers_0 {
755            o.extend(u.parameters());
756        }
757        o
758    }
759    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
760        let mut o = Vec::new();
761        for r in &mut self.resnets {
762            o.extend(r.parameters_mut());
763        }
764        for a in &mut self.attentions {
765            o.extend(a.parameters_mut());
766        }
767        if let Some(u) = self.upsamplers_0.as_mut() {
768            o.extend(u.parameters_mut());
769        }
770        o
771    }
772    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
773        let mut o = Vec::new();
774        for (i, r) in self.resnets.iter().enumerate() {
775            for (n, p) in r.named_parameters() {
776                o.push((format!("resnets.{i}.{n}"), p));
777            }
778        }
779        for (i, a) in self.attentions.iter().enumerate() {
780            for (n, p) in a.named_parameters() {
781                o.push((format!("attentions.{i}.{n}"), p));
782            }
783        }
784        if let Some(u) = &self.upsamplers_0 {
785            for (n, p) in u.named_parameters() {
786                o.push((format!("upsamplers.0.{n}"), p));
787            }
788        }
789        o
790    }
791    fn train(&mut self) {
792        self.training = true;
793    }
794    fn eval(&mut self) {
795        self.training = false;
796    }
797    fn is_training(&self) -> bool {
798        self.training
799    }
800    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
801        let extract = |prefix: &str| -> StateDict<T> {
802            let p = format!("{prefix}.");
803            state
804                .iter()
805                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
806                .collect()
807        };
808        if strict {
809            for k in state.keys() {
810                let ok = k.starts_with("resnets.")
811                    || k.starts_with("attentions.")
812                    || k.starts_with("upsamplers.0.");
813                if !ok {
814                    return Err(FerrotorchError::InvalidArgument {
815                        message: format!(
816                            "unexpected key in CrossAttnUpBlock2D state_dict: \"{k}\""
817                        ),
818                    });
819                }
820            }
821        }
822        for (i, r) in self.resnets.iter_mut().enumerate() {
823            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
824        }
825        for (i, a) in self.attentions.iter_mut().enumerate() {
826            a.load_state_dict(&extract(&format!("attentions.{i}")), strict)?;
827        }
828        if let Some(u) = self.upsamplers_0.as_mut() {
829            u.load_state_dict(&extract("upsamplers.0"), strict)?;
830        }
831        Ok(())
832    }
833}
834
835// ---------------------------------------------------------------------------
836// UpBlock2D (no cross-attn — the first SD up-block)
837// ---------------------------------------------------------------------------
838
/// `UpBlock2D` — `(layers_per_block + 1)` ResnetBlock2DTime + optional
/// Upsample2D. No transformer attention.
///
/// In `forward_t`, each resnet first channel-concatenates one skip tensor
/// onto the running hidden state. Parameters are named `resnets.{i}.*` and,
/// when the upsampler exists, `upsamplers.0.*`.
#[derive(Debug)]
pub struct UpBlock2D<T: Float> {
    /// `(layers_per_block + 1)` time-conditioned resnets.
    pub resnets: Vec<ResnetBlock2DTime<T>>,
    /// Optional upsample (absent on the final up block).
    pub upsamplers_0: Option<Upsample2D<T>>,
    /// Train/eval flag set by `Module::train` / `Module::eval`.
    training: bool,
}
849
850impl<T: Float> UpBlock2D<T> {
851    /// Build a randomly-initialized non-cross-attn up-block.
852    ///
853    /// # Errors
854    ///
855    /// Returns the underlying [`FerrotorchError`].
856    #[allow(clippy::too_many_arguments)]
857    pub fn new(
858        in_channels: usize,
859        out_channels: usize,
860        prev_output_channels: usize,
861        temb_channels: usize,
862        num_layers: usize,
863        norm_num_groups: usize,
864        add_upsample: bool,
865    ) -> FerrotorchResult<Self> {
866        let mut resnets = Vec::with_capacity(num_layers);
867        for j in 0..num_layers {
868            let res_skip = if j == num_layers - 1 {
869                in_channels
870            } else {
871                out_channels
872            };
873            let resnet_in = if j == 0 {
874                prev_output_channels + res_skip
875            } else {
876                out_channels + res_skip
877            };
878            resnets.push(ResnetBlock2DTime::<T>::new(
879                resnet_in,
880                out_channels,
881                temb_channels,
882                norm_num_groups,
883                1e-5,
884            )?);
885        }
886        let upsamplers_0 = if add_upsample {
887            Some(Upsample2D::<T>::new(out_channels)?)
888        } else {
889            None
890        };
891        Ok(Self {
892            resnets,
893            upsamplers_0,
894            training: false,
895        })
896    }
897
898    /// Forward.
899    ///
900    /// # Errors
901    ///
902    /// Returns the underlying [`FerrotorchError`].
903    pub fn forward_t(
904        &self,
905        x: &Tensor<T>,
906        skips: &[Tensor<T>],
907        temb: &Tensor<T>,
908    ) -> FerrotorchResult<Tensor<T>> {
909        if skips.len() != self.resnets.len() {
910            return Err(FerrotorchError::InvalidArgument {
911                message: format!(
912                    "UpBlock2D::forward_t: got {} skips for {} resnets",
913                    skips.len(),
914                    self.resnets.len()
915                ),
916            });
917        }
918        let mut h = x.clone();
919        for (r, skip) in self.resnets.iter().zip(skips.iter()) {
920            h = ferrotorch_core::grad_fns::shape::cat(&[h.clone(), skip.clone()], 1)?;
921            h = r.forward_t(&h, temb)?;
922        }
923        if let Some(u) = &self.upsamplers_0 {
924            h = u.forward(&h)?;
925        }
926        Ok(h)
927    }
928}
929
930impl<T: Float> Module<T> for UpBlock2D<T> {
931    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
932        Err(FerrotorchError::InvalidArgument {
933            message: "UpBlock2D::forward requires (x, skips, temb) — call forward_t instead".into(),
934        })
935    }
936    fn parameters(&self) -> Vec<&Parameter<T>> {
937        let mut o = Vec::new();
938        for r in &self.resnets {
939            o.extend(r.parameters());
940        }
941        if let Some(u) = &self.upsamplers_0 {
942            o.extend(u.parameters());
943        }
944        o
945    }
946    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
947        let mut o = Vec::new();
948        for r in &mut self.resnets {
949            o.extend(r.parameters_mut());
950        }
951        if let Some(u) = self.upsamplers_0.as_mut() {
952            o.extend(u.parameters_mut());
953        }
954        o
955    }
956    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
957        let mut o = Vec::new();
958        for (i, r) in self.resnets.iter().enumerate() {
959            for (n, p) in r.named_parameters() {
960                o.push((format!("resnets.{i}.{n}"), p));
961            }
962        }
963        if let Some(u) = &self.upsamplers_0 {
964            for (n, p) in u.named_parameters() {
965                o.push((format!("upsamplers.0.{n}"), p));
966            }
967        }
968        o
969    }
970    fn train(&mut self) {
971        self.training = true;
972    }
973    fn eval(&mut self) {
974        self.training = false;
975    }
976    fn is_training(&self) -> bool {
977        self.training
978    }
979    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
980        let extract = |prefix: &str| -> StateDict<T> {
981            let p = format!("{prefix}.");
982            state
983                .iter()
984                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
985                .collect()
986        };
987        if strict {
988            for k in state.keys() {
989                let ok = k.starts_with("resnets.") || k.starts_with("upsamplers.0.");
990                if !ok {
991                    return Err(FerrotorchError::InvalidArgument {
992                        message: format!("unexpected key in UpBlock2D state_dict: \"{k}\""),
993                    });
994                }
995            }
996        }
997        for (i, r) in self.resnets.iter_mut().enumerate() {
998            r.load_state_dict(&extract(&format!("resnets.{i}")), strict)?;
999        }
1000        if let Some(u) = self.upsamplers_0.as_mut() {
1001            u.load_state_dict(&extract("upsamplers.0"), strict)?;
1002        }
1003        Ok(())
1004    }
1005}
1006
1007// ---------------------------------------------------------------------------
1008// Discriminated down/up block enum
1009// ---------------------------------------------------------------------------
1010
/// Discriminated wrapper over the two down-block variants.
///
/// Lets `UNet2DConditionModel` keep a single `Vec` of down blocks while each
/// entry is either cross-attention or plain, chosen per index by
/// `down_block_has_attn` in the config.
#[derive(Debug)]
pub enum AnyDownBlock<T: Float> {
    /// CrossAttn variant (resnets + transformer attention).
    CrossAttn(CrossAttnDownBlock2D<T>),
    /// Plain variant (resnets only; ignores `encoder_hidden_states`).
    Plain(DownBlock2D<T>),
}
1019
1020impl<T: Float> AnyDownBlock<T> {
1021    fn forward_t(
1022        &self,
1023        x: &Tensor<T>,
1024        temb: &Tensor<T>,
1025        encoder_hidden_states: &Tensor<T>,
1026    ) -> FerrotorchResult<(Tensor<T>, Vec<Tensor<T>>)> {
1027        match self {
1028            AnyDownBlock::CrossAttn(b) => b.forward_t(x, temb, encoder_hidden_states),
1029            AnyDownBlock::Plain(b) => b.forward_t(x, temb),
1030        }
1031    }
1032
1033    fn parameters(&self) -> Vec<&Parameter<T>> {
1034        match self {
1035            AnyDownBlock::CrossAttn(b) => b.parameters(),
1036            AnyDownBlock::Plain(b) => b.parameters(),
1037        }
1038    }
1039    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1040        match self {
1041            AnyDownBlock::CrossAttn(b) => b.parameters_mut(),
1042            AnyDownBlock::Plain(b) => b.parameters_mut(),
1043        }
1044    }
1045    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1046        match self {
1047            AnyDownBlock::CrossAttn(b) => b.named_parameters(),
1048            AnyDownBlock::Plain(b) => b.named_parameters(),
1049        }
1050    }
1051    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1052        match self {
1053            AnyDownBlock::CrossAttn(b) => b.load_state_dict(state, strict),
1054            AnyDownBlock::Plain(b) => b.load_state_dict(state, strict),
1055        }
1056    }
1057}
1058
/// Discriminated wrapper over the two up-block variants.
///
/// Lets `UNet2DConditionModel` keep a single `Vec` of up blocks while each
/// entry is either cross-attention or plain, chosen per index by
/// `up_block_has_attn` in the config.
#[derive(Debug)]
pub enum AnyUpBlock<T: Float> {
    /// CrossAttn variant (resnets + transformer attention).
    CrossAttn(CrossAttnUpBlock2D<T>),
    /// Plain variant (resnets only; ignores `encoder_hidden_states`).
    Plain(UpBlock2D<T>),
}
1067
1068impl<T: Float> AnyUpBlock<T> {
1069    fn forward_t(
1070        &self,
1071        x: &Tensor<T>,
1072        skips: &[Tensor<T>],
1073        temb: &Tensor<T>,
1074        encoder_hidden_states: &Tensor<T>,
1075    ) -> FerrotorchResult<Tensor<T>> {
1076        match self {
1077            AnyUpBlock::CrossAttn(b) => b.forward_t(x, skips, temb, encoder_hidden_states),
1078            AnyUpBlock::Plain(b) => b.forward_t(x, skips, temb),
1079        }
1080    }
1081    fn num_resnets(&self) -> usize {
1082        match self {
1083            AnyUpBlock::CrossAttn(b) => b.resnets.len(),
1084            AnyUpBlock::Plain(b) => b.resnets.len(),
1085        }
1086    }
1087    fn parameters(&self) -> Vec<&Parameter<T>> {
1088        match self {
1089            AnyUpBlock::CrossAttn(b) => b.parameters(),
1090            AnyUpBlock::Plain(b) => b.parameters(),
1091        }
1092    }
1093    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1094        match self {
1095            AnyUpBlock::CrossAttn(b) => b.parameters_mut(),
1096            AnyUpBlock::Plain(b) => b.parameters_mut(),
1097        }
1098    }
1099    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1100        match self {
1101            AnyUpBlock::CrossAttn(b) => b.named_parameters(),
1102            AnyUpBlock::Plain(b) => b.named_parameters(),
1103        }
1104    }
1105    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1106        match self {
1107            AnyUpBlock::CrossAttn(b) => b.load_state_dict(state, strict),
1108            AnyUpBlock::Plain(b) => b.load_state_dict(state, strict),
1109        }
1110    }
1111}
1112
1113// ---------------------------------------------------------------------------
1114// UNet2DConditionModel
1115// ---------------------------------------------------------------------------
1116
/// Diffusers-style `UNet2DConditionModel` — the SD-1.5 UNet.
///
/// Topology (number of down/up blocks, which of them carry cross-attention,
/// channel widths) is driven by the stored `config`; see
/// `UNet2DConditionModel::new`.
#[derive(Debug)]
pub struct UNet2DConditionModel<T: Float> {
    /// Parameter-free sinusoidal encoding.
    pub time_proj: Timesteps,
    /// Trainable MLP after the encoding.
    pub time_embedding: TimestepEmbedding<T>,
    /// `conv_in`: Conv2d(in_channels, block_out_channels[0], k=3, pad=1).
    pub conv_in: Conv2d<T>,
    /// One down block per `block_out_channels` entry (4 for SD-1.5).
    pub down_blocks: Vec<AnyDownBlock<T>>,
    /// 1 mid block, at the deepest channel count.
    pub mid_block: UNetMidBlock2DCrossAttn<T>,
    /// One up block per `block_out_channels` entry (in diffusers' order —
    /// index 0 is the deepest / lowest-resolution block, which is the
    /// no-cross-attn block in SD-1.5).
    pub up_blocks: Vec<AnyUpBlock<T>>,
    /// `conv_norm_out`: GroupNorm(norm_num_groups, block_out_channels[0]),
    /// eps = 1e-5.
    pub conv_norm_out: GroupNorm<T>,
    /// SiLU activation.
    pub conv_act: SiLU,
    /// `conv_out`: Conv2d(block_out_channels[0], out_channels, k=3, pad=1).
    pub conv_out: Conv2d<T>,
    /// Frozen config copy.
    pub config: UNet2DConditionConfig,
    /// Train/eval flag set by `Module::train` / `Module::eval`.
    training: bool,
}
1143
1144impl<T: Float> UNet2DConditionModel<T> {
1145    /// Build a randomly-initialized `UNet2DConditionModel`.
1146    ///
1147    /// # Errors
1148    ///
1149    /// Returns [`FerrotorchError::InvalidArgument`] for invalid config or
1150    /// underlying [`FerrotorchError`] from the conv constructor.
1151    pub fn new(cfg: UNet2DConditionConfig) -> FerrotorchResult<Self> {
1152        cfg.validate()?;
1153        let temb_channels = cfg.time_embed_dim();
1154        let bocs = &cfg.block_out_channels;
1155        let groups = cfg.norm_num_groups;
1156        let num_blocks = bocs.len();
1157
1158        // time projection: sinusoidal encoding into block_out_channels[0],
1159        // then Linear(block_out_channels[0] -> time_embed_dim) + SiLU +
1160        // Linear(time_embed_dim -> time_embed_dim).
1161        let time_proj = Timesteps::new(bocs[0], cfg.flip_sin_to_cos, cfg.freq_shift)?;
1162        let time_embedding = TimestepEmbedding::<T>::new(bocs[0], temb_channels)?;
1163
1164        // conv_in.
1165        let conv_in = Conv2d::<T>::new(cfg.in_channels, bocs[0], (3, 3), (1, 1), (1, 1), true)?;
1166
1167        // Down blocks.
1168        let mut down_blocks: Vec<AnyDownBlock<T>> = Vec::with_capacity(num_blocks);
1169        let mut prev = bocs[0];
1170        for i in 0..num_blocks {
1171            let out_c = bocs[i];
1172            let is_final = i == num_blocks - 1;
1173            let add_downsample = !is_final;
1174            let block: AnyDownBlock<T> = if cfg.down_block_has_attn[i] {
1175                AnyDownBlock::CrossAttn(CrossAttnDownBlock2D::<T>::new(
1176                    prev,
1177                    out_c,
1178                    temb_channels,
1179                    cfg.layers_per_block,
1180                    cfg.attention_head_dim,
1181                    cfg.cross_attention_dim,
1182                    groups,
1183                    add_downsample,
1184                    cfg.transformer_layers_per_block,
1185                )?)
1186            } else {
1187                AnyDownBlock::Plain(DownBlock2D::<T>::new(
1188                    prev,
1189                    out_c,
1190                    temb_channels,
1191                    cfg.layers_per_block,
1192                    groups,
1193                    add_downsample,
1194                )?)
1195            };
1196            down_blocks.push(block);
1197            prev = out_c;
1198        }
1199
1200        // Mid block (at the deepest channels = bocs[-1]).
1201        let mid_channels = bocs[num_blocks - 1];
1202        let mid_block = UNetMidBlock2DCrossAttn::<T>::new(
1203            mid_channels,
1204            temb_channels,
1205            cfg.attention_head_dim,
1206            cfg.cross_attention_dim,
1207            groups,
1208            cfg.transformer_layers_per_block,
1209        )?;
1210
1211        // Up blocks — iterate through reversed_block_out_channels.
1212        //   reversed = bocs.reverse() = e.g. [1280, 1280, 640, 320]
1213        //   For each up-block i in 0..num_blocks:
1214        //     in_channels = reversed[min(i+1, num_blocks-1)] (channel
1215        //                    count of the corresponding down-block's
1216        //                    pre-downsample resolution)
1217        //     out_channels = reversed[i]
1218        //     prev_output  = reversed[i-1] if i>0 else mid_channels
1219        //                    (== reversed[0])
1220        //   Diffusers gates `add_upsample` on `i < num_blocks - 1`.
1221        let mut up_blocks: Vec<AnyUpBlock<T>> = Vec::with_capacity(num_blocks);
1222        let reversed: Vec<usize> = bocs.iter().rev().copied().collect();
1223        let mut prev_up = mid_channels;
1224        for i in 0..num_blocks {
1225            let out_c = reversed[i];
1226            let in_c = reversed[(i + 1).min(num_blocks - 1)];
1227            let is_final = i == num_blocks - 1;
1228            let add_upsample = !is_final;
1229            let block: AnyUpBlock<T> = if cfg.up_block_has_attn[i] {
1230                AnyUpBlock::CrossAttn(CrossAttnUpBlock2D::<T>::new(
1231                    in_c,
1232                    out_c,
1233                    prev_up,
1234                    temb_channels,
1235                    cfg.layers_per_block + 1,
1236                    cfg.attention_head_dim,
1237                    cfg.cross_attention_dim,
1238                    groups,
1239                    add_upsample,
1240                    cfg.transformer_layers_per_block,
1241                )?)
1242            } else {
1243                AnyUpBlock::Plain(UpBlock2D::<T>::new(
1244                    in_c,
1245                    out_c,
1246                    prev_up,
1247                    temb_channels,
1248                    cfg.layers_per_block + 1,
1249                    groups,
1250                    add_upsample,
1251                )?)
1252            };
1253            up_blocks.push(block);
1254            prev_up = out_c;
1255        }
1256
1257        let conv_norm_out = GroupNorm::<T>::new(groups, bocs[0], 1e-5, true)?;
1258        let conv_out = Conv2d::<T>::new(bocs[0], cfg.out_channels, (3, 3), (1, 1), (1, 1), true)?;
1259
1260        Ok(Self {
1261            time_proj,
1262            time_embedding,
1263            conv_in,
1264            down_blocks,
1265            mid_block,
1266            up_blocks,
1267            conv_norm_out,
1268            conv_act: SiLU::new(),
1269            conv_out,
1270            config: cfg,
1271            training: false,
1272        })
1273    }
1274
1275    /// Run UNet forward.
1276    ///
1277    /// * `sample`: `[B, in_channels, H, W]` — noisy latent.
1278    /// * `timesteps`: `[B]` — diffusion timestep per batch entry.
1279    /// * `encoder_hidden_states`: `[B, S, cross_attention_dim]` — the
1280    ///   text-encoder conditioning.
1281    ///
1282    /// Returns the predicted noise `[B, out_channels, H, W]`.
1283    ///
1284    /// # Errors
1285    ///
1286    /// Returns [`FerrotorchError::ShapeMismatch`] for bad input ranks.
1287    pub fn forward_t(
1288        &self,
1289        sample: &Tensor<T>,
1290        timesteps: &Tensor<T>,
1291        encoder_hidden_states: &Tensor<T>,
1292    ) -> FerrotorchResult<Tensor<T>> {
1293        let cfg = &self.config;
1294        if sample.ndim() != 4 || sample.shape()[1] != cfg.in_channels {
1295            return Err(FerrotorchError::ShapeMismatch {
1296                message: format!(
1297                    "UNet2DConditionModel: expected sample [B, {}, H, W], got {:?}",
1298                    cfg.in_channels,
1299                    sample.shape()
1300                ),
1301            });
1302        }
1303        if timesteps.ndim() != 1 {
1304            return Err(FerrotorchError::ShapeMismatch {
1305                message: format!(
1306                    "UNet2DConditionModel: expected timesteps [B], got {:?}",
1307                    timesteps.shape()
1308                ),
1309            });
1310        }
1311        if encoder_hidden_states.ndim() != 3
1312            || encoder_hidden_states.shape()[2] != cfg.cross_attention_dim
1313        {
1314            return Err(FerrotorchError::ShapeMismatch {
1315                message: format!(
1316                    "UNet2DConditionModel: expected encoder_hidden_states [B, S, {}], got {:?}",
1317                    cfg.cross_attention_dim,
1318                    encoder_hidden_states.shape()
1319                ),
1320            });
1321        }
1322
1323        // 1. Time embedding: [B] -> [B, bocs[0]] -> [B, time_embed_dim].
1324        let t_enc = self.time_proj.forward_t(timesteps)?;
1325        let temb = self.time_embedding.forward(&t_enc)?;
1326
1327        // 2. conv_in.
1328        let mut h = self.conv_in.forward(sample)?;
1329        let mut skips: Vec<Tensor<T>> = Vec::new();
1330        // The initial conv_in output is the first skip (the down-block
1331        // loop pushes its own; diffusers prepends `sample` so this skip
1332        // is consumed by the *last* up-block's last resnet).
1333        skips.push(h.clone());
1334
1335        // 3. Down blocks.
1336        for db in &self.down_blocks {
1337            let (out, mut block_skips) = db.forward_t(&h, &temb, encoder_hidden_states)?;
1338            h = out;
1339            skips.append(&mut block_skips);
1340        }
1341
1342        // 4. Mid block.
1343        h = self.mid_block.forward_t(&h, &temb, encoder_hidden_states)?;
1344
1345        // 5. Up blocks. Each consumes the trailing N skips (where N =
1346        //    block.num_resnets()).
1347        for ub in &self.up_blocks {
1348            let n = ub.num_resnets();
1349            if skips.len() < n {
1350                return Err(FerrotorchError::InvalidArgument {
1351                    message: format!(
1352                        "UNet2DConditionModel: up-block needs {n} skips, only {} left",
1353                        skips.len()
1354                    ),
1355                });
1356            }
1357            let split_at = skips.len() - n;
1358            let popped = skips.split_off(split_at);
1359            // Diffusers iterates skips in the order they were popped
1360            // (most-recent-first per resnet inside the block). The
1361            // standard idiom is `res_samples = skips[-N:]` then the
1362            // block consumes them in reverse order (one per resnet).
1363            // We reverse `popped` so resnet[0] gets the most recently
1364            // pushed skip — the same as diffusers'
1365            // `down_block_res_samples[-1]` for `res_hidden_states_tuple
1366            // [-1]`.
1367            let popped_rev: Vec<Tensor<T>> = popped.into_iter().rev().collect();
1368            h = ub.forward_t(&h, &popped_rev, &temb, encoder_hidden_states)?;
1369        }
1370
1371        // 6. Output head.
1372        h = self.conv_norm_out.forward(&h)?;
1373        h = self.conv_act.forward(&h)?;
1374        self.conv_out.forward(&h)
1375    }
1376}
1377
1378impl<T: Float> Module<T> for UNet2DConditionModel<T> {
1379    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1380        Err(FerrotorchError::InvalidArgument {
1381            message: "UNet2DConditionModel::forward requires (sample, timesteps, ehs) — \
1382                      call forward_t instead"
1383                .into(),
1384        })
1385    }
1386
1387    fn parameters(&self) -> Vec<&Parameter<T>> {
1388        let mut o = Vec::new();
1389        o.extend(self.time_embedding.parameters());
1390        o.extend(self.conv_in.parameters());
1391        for db in &self.down_blocks {
1392            o.extend(db.parameters());
1393        }
1394        o.extend(self.mid_block.parameters());
1395        for ub in &self.up_blocks {
1396            o.extend(ub.parameters());
1397        }
1398        o.extend(self.conv_norm_out.parameters());
1399        o.extend(self.conv_out.parameters());
1400        o
1401    }
1402
1403    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1404        let mut o = Vec::new();
1405        o.extend(self.time_embedding.parameters_mut());
1406        o.extend(self.conv_in.parameters_mut());
1407        for db in &mut self.down_blocks {
1408            o.extend(db.parameters_mut());
1409        }
1410        o.extend(self.mid_block.parameters_mut());
1411        for ub in &mut self.up_blocks {
1412            o.extend(ub.parameters_mut());
1413        }
1414        o.extend(self.conv_norm_out.parameters_mut());
1415        o.extend(self.conv_out.parameters_mut());
1416        o
1417    }
1418
1419    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1420        let mut o = Vec::new();
1421        for (n, p) in self.time_embedding.named_parameters() {
1422            o.push((format!("time_embedding.{n}"), p));
1423        }
1424        for (n, p) in self.conv_in.named_parameters() {
1425            o.push((format!("conv_in.{n}"), p));
1426        }
1427        for (i, db) in self.down_blocks.iter().enumerate() {
1428            for (n, p) in db.named_parameters() {
1429                o.push((format!("down_blocks.{i}.{n}"), p));
1430            }
1431        }
1432        for (n, p) in self.mid_block.named_parameters() {
1433            o.push((format!("mid_block.{n}"), p));
1434        }
1435        for (i, ub) in self.up_blocks.iter().enumerate() {
1436            for (n, p) in ub.named_parameters() {
1437                o.push((format!("up_blocks.{i}.{n}"), p));
1438            }
1439        }
1440        for (n, p) in self.conv_norm_out.named_parameters() {
1441            o.push((format!("conv_norm_out.{n}"), p));
1442        }
1443        for (n, p) in self.conv_out.named_parameters() {
1444            o.push((format!("conv_out.{n}"), p));
1445        }
1446        o
1447    }
1448    fn train(&mut self) {
1449        self.training = true;
1450    }
1451    fn eval(&mut self) {
1452        self.training = false;
1453    }
1454    fn is_training(&self) -> bool {
1455        self.training
1456    }
1457    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1458        let extract = |prefix: &str| -> StateDict<T> {
1459            let p = format!("{prefix}.");
1460            state
1461                .iter()
1462                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1463                .collect()
1464        };
1465        if strict {
1466            for k in state.keys() {
1467                let ok = k.starts_with("time_embedding.")
1468                    || k.starts_with("conv_in.")
1469                    || k.starts_with("down_blocks.")
1470                    || k.starts_with("mid_block.")
1471                    || k.starts_with("up_blocks.")
1472                    || k.starts_with("conv_norm_out.")
1473                    || k.starts_with("conv_out.");
1474                if !ok {
1475                    return Err(FerrotorchError::InvalidArgument {
1476                        message: format!(
1477                            "unexpected key in UNet2DConditionModel state_dict: \"{k}\""
1478                        ),
1479                    });
1480                }
1481            }
1482        }
1483        self.time_embedding
1484            .load_state_dict(&extract("time_embedding"), strict)?;
1485        self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
1486        for (i, db) in self.down_blocks.iter_mut().enumerate() {
1487            db.load_state_dict(&extract(&format!("down_blocks.{i}")), strict)?;
1488        }
1489        self.mid_block
1490            .load_state_dict(&extract("mid_block"), strict)?;
1491        for (i, ub) in self.up_blocks.iter_mut().enumerate() {
1492            ub.load_state_dict(&extract(&format!("up_blocks.{i}")), strict)?;
1493        }
1494        self.conv_norm_out
1495            .load_state_dict(&extract("conv_norm_out"), strict)?;
1496        self.conv_out
1497            .load_state_dict(&extract("conv_out"), strict)?;
1498        Ok(())
1499    }
1500}
1501
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// A miniature 4-level config with the SD-1.5 attention layout
    /// (cross-attn on the first three down blocks / last three up blocks).
    fn tiny_cfg() -> UNet2DConditionConfig {
        UNet2DConditionConfig {
            in_channels: 4,
            out_channels: 4,
            block_out_channels: vec![16, 32, 64, 64],
            layers_per_block: 1,
            attention_head_dim: 8,
            cross_attention_dim: 24,
            norm_num_groups: 4,
            sample_size: 8,
            flip_sin_to_cos: true,
            freq_shift: 0.0,
            transformer_layers_per_block: 1,
            down_block_has_attn: vec![true, true, true, false],
            up_block_has_attn: vec![false, true, true, true],
        }
    }

    /// Non-grad CPU `f32` tensor of the given shape, every element `value`.
    fn filled(value: f32, shape: &[usize]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(TensorStorage::cpu(vec![value; numel]), shape.to_vec(), false)
            .unwrap()
    }

    #[test]
    fn unet_forward_shape() {
        let model = UNet2DConditionModel::<f32>::new(tiny_cfg().clone()).unwrap();
        let latent = filled(0.01, &[1, 4, 8, 8]);
        let t = filled(5.0, &[1]);
        let text = filled(0.01, &[1, 7, 24]);
        let out = model.forward_t(&latent, &t, &text).unwrap();
        assert_eq!(out.shape(), &[1, 4, 8, 8]);
    }

    #[test]
    fn unet_named_parameters_includes_canonical_keys() {
        const EXPECTED: [&str; 11] = [
            "time_embedding.linear_1.weight",
            "time_embedding.linear_2.bias",
            "conv_in.weight",
            "down_blocks.0.resnets.0.norm1.weight",
            "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight",
            "mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight",
            "mid_block.resnets.1.conv2.weight",
            "up_blocks.0.resnets.0.norm1.weight",
            "up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
            "conv_norm_out.weight",
            "conv_out.bias",
        ];
        let model = UNet2DConditionModel::<f32>::new(tiny_cfg()).unwrap();
        let names: Vec<String> = model
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        for k in EXPECTED {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }
}