Skip to main content

ferrotorch_diffusion/
attention.rs

//! Multi-head attention and the `Transformer2DModel` wrapper used by the
//! SD UNet's CrossAttn blocks.
//!
//! The parameter layout mirrors diffusers 1:1:
//!
//! ```text
//! Attention(query_dim, cross_attention_dim?, heads, dim_head)
//!   to_q.{weight,bias}     [inner, query_dim]   (inner = heads * dim_head)
//!   to_k.{weight,bias}     [inner, kv_dim]      (kv_dim = cross_attention_dim,
//!                                                or query_dim for self-attn)
//!   to_v.{weight,bias}     [inner, kv_dim]
//!   to_out.0.{weight,bias} [query_dim, inner]
//!   to_out.1               Dropout (no params)
//! ```
//!
//! For the SD-1.5 UNet, biases on `to_q/to_k/to_v` are disabled
//! (`bias=False`); the output projection `to_out.0` keeps its bias.
//!
//! `BasicTransformerBlock`:
//!
//! ```text
//! h0 = LayerNorm1(x)
//! h1 = Attention1(h0, h0, h0)             # self-attn
//! x  = x + h1
//! h0 = LayerNorm2(x)
//! h2 = Attention2(h0, encoder_hidden_states, encoder_hidden_states)  # cross-attn
//! x  = x + h2
//! h0 = LayerNorm3(x)
//! h3 = FeedForward(h0)                    # GEGLU + Linear
//! x  = x + h3
//! ```
//!
//! `Transformer2DModel`:
//!
//! ```text
//! GroupNorm(32) -> proj_in (Conv2d k=1) -> flatten to [B, HW, inner]
//!   -> N × BasicTransformerBlock -> reshape back to [B, inner, H, W]
//!   -> proj_out (Conv2d k=1) + residual
//! ```
//!
//! ## REQ status (per `.design/ferrotorch-diffusion/attention.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 | SHIPPED | `Attention::new` at `ferrotorch-diffusion/src/attention.rs:105..132`; consumer: `BasicTransformerBlock::new` at `attention.rs:486..488` calls `Attention::<T>::new` for both self-attn and cross-attn |
//! | REQ-2 | SHIPPED | `Attention::forward_xattn` at `ferrotorch-diffusion/src/attention.rs:144..220`; consumer: `BasicTransformerBlock::forward_xattn` at `attention.rs:526` and `attention.rs:530` invokes it |
//! | REQ-3 | SHIPPED | `FeedForward::new` at `ferrotorch-diffusion/src/attention.rs:342..353`; consumer: `BasicTransformerBlock::new` at `attention.rs:490` calls `FeedForward::<T>::new(dim, 4)` |
//! | REQ-4 | SHIPPED | `BasicTransformerBlock::new` at `ferrotorch-diffusion/src/attention.rs:476..501` and `forward_xattn` at `attention.rs:510..536`; consumer: `Transformer2DModel::new` at `attention.rs:694..699` and forward at `attention.rs:751` |
//! | REQ-5 | SHIPPED | `Transformer2DModel::new` at `ferrotorch-diffusion/src/attention.rs:680..710` and `forward_xattn` at `attention.rs:720..762`; consumer: `ferrotorch-diffusion/src/unet.rs:116`, `unet.rs:469`, `unet.rs:664` build cross-attn levels |
//! | REQ-6 | SHIPPED | error guards at `ferrotorch-diffusion/src/attention.rs:540..546` and `attention.rs:766..772`; consumer: every cross-attn caller in `ferrotorch-diffusion/src/unet.rs` uses `forward_xattn` explicitly |

51use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
52use ferrotorch_nn::module::{Module, StateDict};
53use ferrotorch_nn::parameter::Parameter;
54use ferrotorch_nn::{Conv2d, GELU, GroupNorm, LayerNorm, Linear};
55
56// ---------------------------------------------------------------------------
57// Attention (multi-head, optional cross-attention)
58// ---------------------------------------------------------------------------
59
60/// Multi-head attention block. Supports self-attention (when `key` and
61/// `value` are derived from the same tensor as `query`) and
62/// cross-attention (when they come from `encoder_hidden_states`).
63///
64/// This is the `Attention` class in
65/// `diffusers.models.attention_processor` configured as it appears in
66/// SD-1.5's UNet — `bias = False` on q/k/v and `out_bias = True` on
67/// `to_out.0`, no `group_norm`, no `spatial_norm`, no
68/// `added_kv_proj_dim`.
69#[derive(Debug)]
70pub struct Attention<T: Float> {
71    /// Per-head dimension.
72    pub dim_head: usize,
73    /// Number of heads.
74    pub heads: usize,
75    /// `inner_dim = heads * dim_head`.
76    pub inner_dim: usize,
77    /// Query projection: `[inner_dim, query_dim]`.
78    pub to_q: Linear<T>,
79    /// Key projection: `[inner_dim, kv_dim]`.
80    pub to_k: Linear<T>,
81    /// Value projection: `[inner_dim, kv_dim]`.
82    pub to_v: Linear<T>,
83    /// Output projection: `[query_dim, inner_dim]` (with bias).
84    pub to_out_0: Linear<T>,
85    query_dim: usize,
86    kv_dim: usize,
87    scale: f64,
88    training: bool,
89}
90
91impl<T: Float> Attention<T> {
92    /// Build a randomly-initialized `Attention`.
93    ///
94    /// `cross_attention_dim = None` means self-attention
95    /// (`kv_dim = query_dim`); a `Some(_)` value enables cross-attention.
96    ///
97    /// `bias` controls `to_q/to_k/to_v` bias (SD-1.5 sets this to false).
98    /// `to_out.0` always has bias (matches diffusers default
99    /// `out_bias=True`).
100    ///
101    /// # Errors
102    ///
103    /// Returns the underlying [`FerrotorchError`] when any `Linear` size
104    /// is invalid.
105    pub fn new(
106        query_dim: usize,
107        cross_attention_dim: Option<usize>,
108        heads: usize,
109        dim_head: usize,
110        bias: bool,
111    ) -> FerrotorchResult<Self> {
112        let inner_dim = heads * dim_head;
113        let kv_dim = cross_attention_dim.unwrap_or(query_dim);
114        let to_q = Linear::<T>::new(query_dim, inner_dim, bias)?;
115        let to_k = Linear::<T>::new(kv_dim, inner_dim, bias)?;
116        let to_v = Linear::<T>::new(kv_dim, inner_dim, bias)?;
117        let to_out_0 = Linear::<T>::new(inner_dim, query_dim, true)?;
118        let scale = (dim_head as f64).sqrt().recip();
119        Ok(Self {
120            dim_head,
121            heads,
122            inner_dim,
123            to_q,
124            to_k,
125            to_v,
126            to_out_0,
127            query_dim,
128            kv_dim,
129            scale,
130            training: false,
131        })
132    }
133
134    /// Forward with optional encoder hidden states.
135    ///
136    /// `hidden_states` has shape `[B, N, query_dim]`. When
137    /// `encoder_hidden_states` is `None` this is self-attention; when
138    /// `Some([B, S, kv_dim])` it's cross-attention.
139    ///
140    /// # Errors
141    ///
142    /// Returns [`FerrotorchError::ShapeMismatch`] on rank / dim
143    /// disagreement.
144    pub fn forward_xattn(
145        &self,
146        hidden_states: &Tensor<T>,
147        encoder_hidden_states: Option<&Tensor<T>>,
148    ) -> FerrotorchResult<Tensor<T>> {
149        if hidden_states.ndim() != 3 || hidden_states.shape()[2] != self.query_dim {
150            return Err(FerrotorchError::ShapeMismatch {
151                message: format!(
152                    "Attention::forward_xattn: expected hidden_states [B, N, {}], got {:?}",
153                    self.query_dim,
154                    hidden_states.shape()
155                ),
156            });
157        }
158        let b = hidden_states.shape()[0];
159        let n = hidden_states.shape()[1];
160        // kv source.
161        let kv = encoder_hidden_states.unwrap_or(hidden_states);
162        if kv.ndim() != 3 || kv.shape()[0] != b || kv.shape()[2] != self.kv_dim {
163            return Err(FerrotorchError::ShapeMismatch {
164                message: format!(
165                    "Attention::forward_xattn: expected kv [B={b}, S, {}], got {:?}",
166                    self.kv_dim,
167                    kv.shape()
168                ),
169            });
170        }
171        let s = kv.shape()[1];
172
173        // -- Linear projections. q: [B, N, inner], k/v: [B, S, inner].
174        let q = self.to_q.forward(hidden_states)?;
175        let k = self.to_k.forward(kv)?;
176        let v = self.to_v.forward(kv)?;
177
178        // -- Reshape to per-head: [B, N, H, D] then transpose to
179        //    [B, H, N, D], collapse to [B*H, N, D] for the BMM kernel.
180        //    Same trick for k, v over S.
181        let h = self.heads;
182        let d = self.dim_head;
183        let q = q
184            .reshape_t(&[b as isize, n as isize, h as isize, d as isize])?
185            .transpose(1, 2)? // [B, H, N, D]
186            .contiguous()?
187            .reshape_t(&[(b * h) as isize, n as isize, d as isize])?;
188        let k = k
189            .reshape_t(&[b as isize, s as isize, h as isize, d as isize])?
190            .transpose(1, 2)? // [B, H, S, D]
191            .contiguous()?
192            .reshape_t(&[(b * h) as isize, s as isize, d as isize])?;
193        let v = v
194            .reshape_t(&[b as isize, s as isize, h as isize, d as isize])?
195            .transpose(1, 2)? // [B, H, S, D]
196            .contiguous()?
197            .reshape_t(&[(b * h) as isize, s as isize, d as isize])?;
198
199        // -- scores = (q @ k^T) * scale.
200        let k_t = k.transpose(1, 2)?.contiguous()?; // [B*H, D, S]
201        let scores = q.bmm(&k_t)?; // [B*H, N, S]
202        let scale_t = T::from(self.scale).ok_or_else(|| FerrotorchError::InvalidArgument {
203            message: "Attention::forward_xattn: failed to cast attention scale into Float".into(),
204        })?;
205        let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
206        let scores_scaled = ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;
207        let probs = scores_scaled.softmax()?; // [B*H, N, S]
208        let attended = probs.bmm(&v)?; // [B*H, N, D]
209
210        // -- Merge heads: [B*H, N, D] -> [B, H, N, D] -> [B, N, H, D]
211        //    -> [B, N, inner_dim].
212        let attended = attended
213            .reshape_t(&[b as isize, h as isize, n as isize, d as isize])?
214            .transpose(1, 2)? // [B, N, H, D]
215            .contiguous()?
216            .reshape_t(&[b as isize, n as isize, self.inner_dim as isize])?;
217
218        // -- Output projection.
219        self.to_out_0.forward(&attended)
220    }
221}
222
223impl<T: Float> Module<T> for Attention<T> {
224    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
225        // Self-attention path.
226        self.forward_xattn(input, None)
227    }
228
229    fn parameters(&self) -> Vec<&Parameter<T>> {
230        let mut o = Vec::new();
231        o.extend(self.to_q.parameters());
232        o.extend(self.to_k.parameters());
233        o.extend(self.to_v.parameters());
234        o.extend(self.to_out_0.parameters());
235        o
236    }
237    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
238        let mut o = Vec::new();
239        o.extend(self.to_q.parameters_mut());
240        o.extend(self.to_k.parameters_mut());
241        o.extend(self.to_v.parameters_mut());
242        o.extend(self.to_out_0.parameters_mut());
243        o
244    }
245    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
246        let mut o = Vec::new();
247        for (n, p) in self.to_q.named_parameters() {
248            o.push((format!("to_q.{n}"), p));
249        }
250        for (n, p) in self.to_k.named_parameters() {
251            o.push((format!("to_k.{n}"), p));
252        }
253        for (n, p) in self.to_v.named_parameters() {
254            o.push((format!("to_v.{n}"), p));
255        }
256        for (n, p) in self.to_out_0.named_parameters() {
257            o.push((format!("to_out.0.{n}"), p));
258        }
259        o
260    }
261    fn train(&mut self) {
262        self.training = true;
263    }
264    fn eval(&mut self) {
265        self.training = false;
266    }
267    fn is_training(&self) -> bool {
268        self.training
269    }
270    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
271        let extract = |prefix: &str| -> StateDict<T> {
272            let p = format!("{prefix}.");
273            state
274                .iter()
275                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
276                .collect()
277        };
278        if strict {
279            for k in state.keys() {
280                let ok = k.starts_with("to_q.")
281                    || k.starts_with("to_k.")
282                    || k.starts_with("to_v.")
283                    || k.starts_with("to_out.0.");
284                if !ok {
285                    return Err(FerrotorchError::InvalidArgument {
286                        message: format!("unexpected key in Attention state_dict: \"{k}\""),
287                    });
288                }
289            }
290        }
291        self.to_q.load_state_dict(&extract("to_q"), strict)?;
292        self.to_k.load_state_dict(&extract("to_k"), strict)?;
293        self.to_v.load_state_dict(&extract("to_v"), strict)?;
294        self.to_out_0
295            .load_state_dict(&extract("to_out.0"), strict)?;
296        Ok(())
297    }
298}
299
300// ---------------------------------------------------------------------------
301// FeedForward (GEGLU + Linear)
302// ---------------------------------------------------------------------------
303
304/// GEGLU-style feed-forward (matches the SD UNet's
305/// `BasicTransformerBlock.ff` exactly):
306///
307/// ```text
308/// net.0  = GEGLU(dim, dim * mult)
309///          = Linear(dim, 2 * dim * mult)         # proj
310///          x, gate = chunk(2, dim=-1)
311///          return x * gelu(gate)
312/// net.1  = Dropout (no params)
313/// net.2  = Linear(dim * mult, dim)
314/// ```
315///
316/// Diffusers' `FeedForward` defaults: `mult = 4`,
317/// `activation_fn = "geglu"`. SD-1.5 UNet uses both defaults.
318///
319/// State-dict layout:
320///
321/// ```text
322/// net.0.proj.{weight,bias}    [2 * dim_ff, dim], [2 * dim_ff]
323/// net.2.{weight,bias}         [dim, dim_ff],     [dim]
324/// ```
325#[derive(Debug)]
326pub struct FeedForward<T: Float> {
327    /// GEGLU's expansion projection (`dim -> 2 * dim_ff`).
328    pub net_0_proj: Linear<T>,
329    /// Output projection (`dim_ff -> dim`).
330    pub net_2: Linear<T>,
331    activation: GELU,
332    dim_ff: usize,
333    training: bool,
334}
335
336impl<T: Float> FeedForward<T> {
337    /// Build a GEGLU `FeedForward` (`dim_ff = dim * mult`).
338    ///
339    /// # Errors
340    ///
341    /// Returns the underlying [`FerrotorchError`] for invalid dims.
342    pub fn new(dim: usize, mult: usize) -> FerrotorchResult<Self> {
343        let dim_ff = dim * mult;
344        let net_0_proj = Linear::<T>::new(dim, 2 * dim_ff, true)?;
345        let net_2 = Linear::<T>::new(dim_ff, dim, true)?;
346        Ok(Self {
347            net_0_proj,
348            net_2,
349            activation: GELU::new(),
350            dim_ff,
351            training: false,
352        })
353    }
354}
355
356impl<T: Float> Module<T> for FeedForward<T> {
357    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
358        // proj -> chunk(2, -1) -> x * gelu(gate)
359        let proj = self.net_0_proj.forward(input)?;
360        // `chunk` operates on a positive dim — last axis here.
361        let last = proj.ndim() - 1;
362        let parts = proj.chunk(2, last)?;
363        if parts.len() != 2 {
364            return Err(FerrotorchError::ShapeMismatch {
365                message: format!(
366                    "FeedForward: chunk(2) returned {} parts (expected 2)",
367                    parts.len()
368                ),
369            });
370        }
371        let x = parts[0].contiguous()?;
372        let gate = parts[1].contiguous()?;
373        let gated = self.activation.forward(&gate)?;
374        let activated = ferrotorch_core::grad_fns::arithmetic::mul(&x, &gated)?;
375        self.net_2.forward(&activated)
376    }
377    fn parameters(&self) -> Vec<&Parameter<T>> {
378        let mut o = Vec::new();
379        o.extend(self.net_0_proj.parameters());
380        o.extend(self.net_2.parameters());
381        o
382    }
383    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
384        let mut o = Vec::new();
385        o.extend(self.net_0_proj.parameters_mut());
386        o.extend(self.net_2.parameters_mut());
387        o
388    }
389    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
390        let mut o = Vec::new();
391        for (n, p) in self.net_0_proj.named_parameters() {
392            o.push((format!("net.0.proj.{n}"), p));
393        }
394        for (n, p) in self.net_2.named_parameters() {
395            o.push((format!("net.2.{n}"), p));
396        }
397        o
398    }
399    fn train(&mut self) {
400        self.training = true;
401    }
402    fn eval(&mut self) {
403        self.training = false;
404    }
405    fn is_training(&self) -> bool {
406        self.training
407    }
408    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
409        let extract = |prefix: &str| -> StateDict<T> {
410            let p = format!("{prefix}.");
411            state
412                .iter()
413                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
414                .collect()
415        };
416        if strict {
417            for k in state.keys() {
418                let ok = k.starts_with("net.0.proj.") || k.starts_with("net.2.");
419                if !ok {
420                    return Err(FerrotorchError::InvalidArgument {
421                        message: format!("unexpected key in FeedForward state_dict: \"{k}\""),
422                    });
423                }
424            }
425        }
426        self.net_0_proj
427            .load_state_dict(&extract("net.0.proj"), strict)?;
428        self.net_2.load_state_dict(&extract("net.2"), strict)?;
429        let _ = self.dim_ff;
430        Ok(())
431    }
432}
433
434// ---------------------------------------------------------------------------
435// BasicTransformerBlock
436// ---------------------------------------------------------------------------
437
438/// Diffusers' `BasicTransformerBlock` configured the way SD-1.5's UNet
439/// uses it: pre-LayerNorm on every sub-layer, self-attn followed by
440/// cross-attn followed by GEGLU FeedForward, all with residuals.
441///
442/// State-dict layout:
443///
444/// ```text
445/// norm1.{weight,bias}   [dim], [dim]
446/// attn1.<keys>          # self-attn (Attention with cross_attention_dim=None)
447/// norm2.{weight,bias}
448/// attn2.<keys>          # cross-attn
449/// norm3.{weight,bias}
450/// ff.<keys>             # FeedForward (GEGLU)
451/// ```
452#[derive(Debug)]
453pub struct BasicTransformerBlock<T: Float> {
454    /// LayerNorm before self-attn.
455    pub norm1: LayerNorm<T>,
456    /// Self-attention.
457    pub attn1: Attention<T>,
458    /// LayerNorm before cross-attn.
459    pub norm2: LayerNorm<T>,
460    /// Cross-attention.
461    pub attn2: Attention<T>,
462    /// LayerNorm before FF.
463    pub norm3: LayerNorm<T>,
464    /// GEGLU FeedForward.
465    pub ff: FeedForward<T>,
466    dim: usize,
467    training: bool,
468}
469
470impl<T: Float> BasicTransformerBlock<T> {
471    /// Build a randomly-initialized `BasicTransformerBlock`.
472    ///
473    /// # Errors
474    ///
475    /// Returns the underlying [`FerrotorchError`] for invalid dims.
476    pub fn new(
477        dim: usize,
478        heads: usize,
479        dim_head: usize,
480        cross_attention_dim: usize,
481    ) -> FerrotorchResult<Self> {
482        // SD-1.5 sets `attention_bias = False` for the UNet (no bias on
483        // q/k/v); the output projection (`to_out.0`) keeps its bias
484        // unconditionally.
485        let norm1 = LayerNorm::<T>::new(vec![dim], 1e-5, true)?;
486        let attn1 = Attention::<T>::new(dim, None, heads, dim_head, false)?;
487        let norm2 = LayerNorm::<T>::new(vec![dim], 1e-5, true)?;
488        let attn2 = Attention::<T>::new(dim, Some(cross_attention_dim), heads, dim_head, false)?;
489        let norm3 = LayerNorm::<T>::new(vec![dim], 1e-5, true)?;
490        let ff = FeedForward::<T>::new(dim, 4)?;
491        Ok(Self {
492            norm1,
493            attn1,
494            norm2,
495            attn2,
496            norm3,
497            ff,
498            dim,
499            training: false,
500        })
501    }
502
503    /// Forward with optional encoder hidden states. Self-attn ignores
504    /// `encoder_hidden_states`, cross-attn uses it.
505    ///
506    /// # Errors
507    ///
508    /// Returns [`FerrotorchError::ShapeMismatch`] on rank disagreement,
509    /// underlying op errors otherwise.
510    pub fn forward_xattn(
511        &self,
512        x: &Tensor<T>,
513        encoder_hidden_states: &Tensor<T>,
514    ) -> FerrotorchResult<Tensor<T>> {
515        if x.ndim() != 3 || x.shape()[2] != self.dim {
516            return Err(FerrotorchError::ShapeMismatch {
517                message: format!(
518                    "BasicTransformerBlock::forward: expected x [B, N, {}], got {:?}",
519                    self.dim,
520                    x.shape()
521                ),
522            });
523        }
524        // Sub-block 1: self-attn.
525        let h1 = self.norm1.forward(x)?;
526        let h1 = self.attn1.forward_xattn(&h1, None)?;
527        let x = ferrotorch_core::grad_fns::arithmetic::add(&h1, x)?;
528        // Sub-block 2: cross-attn.
529        let h2 = self.norm2.forward(&x)?;
530        let h2 = self.attn2.forward_xattn(&h2, Some(encoder_hidden_states))?;
531        let x = ferrotorch_core::grad_fns::arithmetic::add(&h2, &x)?;
532        // Sub-block 3: FF.
533        let h3 = self.norm3.forward(&x)?;
534        let h3 = self.ff.forward(&h3)?;
535        ferrotorch_core::grad_fns::arithmetic::add(&h3, &x)
536    }
537}
538
539impl<T: Float> Module<T> for BasicTransformerBlock<T> {
540    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
541        Err(FerrotorchError::InvalidArgument {
542            message: "BasicTransformerBlock::forward: cross-attn requires \
543                      encoder_hidden_states — call forward_xattn instead"
544                .into(),
545        })
546    }
547
548    fn parameters(&self) -> Vec<&Parameter<T>> {
549        let mut o = Vec::new();
550        o.extend(self.norm1.parameters());
551        o.extend(self.attn1.parameters());
552        o.extend(self.norm2.parameters());
553        o.extend(self.attn2.parameters());
554        o.extend(self.norm3.parameters());
555        o.extend(self.ff.parameters());
556        o
557    }
558    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
559        let mut o = Vec::new();
560        o.extend(self.norm1.parameters_mut());
561        o.extend(self.attn1.parameters_mut());
562        o.extend(self.norm2.parameters_mut());
563        o.extend(self.attn2.parameters_mut());
564        o.extend(self.norm3.parameters_mut());
565        o.extend(self.ff.parameters_mut());
566        o
567    }
568    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
569        let mut o = Vec::new();
570        for (n, p) in self.norm1.named_parameters() {
571            o.push((format!("norm1.{n}"), p));
572        }
573        for (n, p) in self.attn1.named_parameters() {
574            o.push((format!("attn1.{n}"), p));
575        }
576        for (n, p) in self.norm2.named_parameters() {
577            o.push((format!("norm2.{n}"), p));
578        }
579        for (n, p) in self.attn2.named_parameters() {
580            o.push((format!("attn2.{n}"), p));
581        }
582        for (n, p) in self.norm3.named_parameters() {
583            o.push((format!("norm3.{n}"), p));
584        }
585        for (n, p) in self.ff.named_parameters() {
586            o.push((format!("ff.{n}"), p));
587        }
588        o
589    }
590    fn train(&mut self) {
591        self.training = true;
592    }
593    fn eval(&mut self) {
594        self.training = false;
595    }
596    fn is_training(&self) -> bool {
597        self.training
598    }
599    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
600        let extract = |prefix: &str| -> StateDict<T> {
601            let p = format!("{prefix}.");
602            state
603                .iter()
604                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
605                .collect()
606        };
607        if strict {
608            for k in state.keys() {
609                let ok = k.starts_with("norm1.")
610                    || k.starts_with("attn1.")
611                    || k.starts_with("norm2.")
612                    || k.starts_with("attn2.")
613                    || k.starts_with("norm3.")
614                    || k.starts_with("ff.");
615                if !ok {
616                    return Err(FerrotorchError::InvalidArgument {
617                        message: format!(
618                            "unexpected key in BasicTransformerBlock state_dict: \"{k}\""
619                        ),
620                    });
621                }
622            }
623        }
624        self.norm1.load_state_dict(&extract("norm1"), strict)?;
625        self.attn1.load_state_dict(&extract("attn1"), strict)?;
626        self.norm2.load_state_dict(&extract("norm2"), strict)?;
627        self.attn2.load_state_dict(&extract("attn2"), strict)?;
628        self.norm3.load_state_dict(&extract("norm3"), strict)?;
629        self.ff.load_state_dict(&extract("ff"), strict)?;
630        Ok(())
631    }
632}
633
634// ---------------------------------------------------------------------------
635// Transformer2DModel
636// ---------------------------------------------------------------------------
637
638/// Diffusers' `Transformer2DModel` configured the way SD-1.5's UNet
639/// uses it:
640///
641/// ```text
642/// h = norm(x)                              # GroupNorm(32, in_channels)
643/// h = proj_in(h)                           # Conv2d(C, inner, k=1) [use_linear_projection=False]
644/// h = h.permute(0, 2, 3, 1).reshape(B, H*W, inner)
645/// for block in transformer_blocks:
646///     h = block(h, encoder_hidden_states)
647/// h = h.reshape(B, H, W, inner).permute(0, 3, 1, 2)
648/// h = proj_out(h)                          # Conv2d(inner, C, k=1)
649/// return h + residual
650/// ```
651///
652/// SD-1.5 v1 uses `Conv2d` (not Linear) for `proj_in`/`proj_out`
653/// (`use_linear_projection=False`). `transformer_layers_per_block=1`
654/// (the diffusers default and the SD-1.5 v1 setting).
655#[derive(Debug)]
656pub struct Transformer2DModel<T: Float> {
657    /// GroupNorm before `proj_in`.
658    pub norm: GroupNorm<T>,
659    /// `proj_in`: Conv2d(C, inner, k=1).
660    pub proj_in: Conv2d<T>,
661    /// `N × BasicTransformerBlock`.
662    pub transformer_blocks: Vec<BasicTransformerBlock<T>>,
663    /// `proj_out`: Conv2d(inner, C, k=1).
664    pub proj_out: Conv2d<T>,
665    channels: usize,
666    inner_dim: usize,
667    training: bool,
668}
669
670impl<T: Float> Transformer2DModel<T> {
671    /// Build a randomly-initialized `Transformer2DModel`.
672    ///
673    /// `inner_dim = heads * dim_head` for the SD UNet (proj_in expands
674    /// only when these disagree; for SD it's always equal to
675    /// `in_channels`).
676    ///
677    /// # Errors
678    ///
679    /// Returns the underlying [`FerrotorchError`] for invalid dims.
680    pub fn new(
681        in_channels: usize,
682        heads: usize,
683        dim_head: usize,
684        num_layers: usize,
685        cross_attention_dim: usize,
686        norm_num_groups: usize,
687    ) -> FerrotorchResult<Self> {
688        let inner_dim = heads * dim_head;
689        let norm = GroupNorm::<T>::new(norm_num_groups, in_channels, 1e-6, true)?;
690        let proj_in = Conv2d::<T>::new(in_channels, inner_dim, (1, 1), (1, 1), (0, 0), true)?;
691        let proj_out = Conv2d::<T>::new(inner_dim, in_channels, (1, 1), (1, 1), (0, 0), true)?;
692        let mut transformer_blocks = Vec::with_capacity(num_layers);
693        for _ in 0..num_layers {
694            transformer_blocks.push(BasicTransformerBlock::<T>::new(
695                inner_dim,
696                heads,
697                dim_head,
698                cross_attention_dim,
699            )?);
700        }
701        Ok(Self {
702            norm,
703            proj_in,
704            transformer_blocks,
705            proj_out,
706            channels: in_channels,
707            inner_dim,
708            training: false,
709        })
710    }
711
712    /// Forward with encoder hidden states for cross-attn.
713    ///
714    /// `x` has shape `[B, C, H, W]`. The result has the same shape.
715    ///
716    /// # Errors
717    ///
718    /// Returns [`FerrotorchError::ShapeMismatch`] when the input is not
719    /// `[B, channels, H, W]`.
720    pub fn forward_xattn(
721        &self,
722        x: &Tensor<T>,
723        encoder_hidden_states: &Tensor<T>,
724    ) -> FerrotorchResult<Tensor<T>> {
725        if x.ndim() != 4 || x.shape()[1] != self.channels {
726            return Err(FerrotorchError::ShapeMismatch {
727                message: format!(
728                    "Transformer2DModel::forward: expected [B, {}, H, W], got {:?}",
729                    self.channels,
730                    x.shape()
731                ),
732            });
733        }
734        let b = x.shape()[0];
735        let c = x.shape()[1];
736        let h = x.shape()[2];
737        let w = x.shape()[3];
738        let hw = h * w;
739
740        let residual = x.clone();
741        // norm + proj_in (Conv2d k=1) keeps the [B, C', H, W] layout.
742        let mut hidden = self.norm.forward(x)?;
743        hidden = self.proj_in.forward(&hidden)?;
744        // [B, inner, H, W] -> [B, inner, HW] -> [B, HW, inner]
745        let mut hidden_seq = hidden
746            .reshape_t(&[b as isize, self.inner_dim as isize, hw as isize])?
747            .transpose(1, 2)?
748            .contiguous()?;
749        // Run the transformer blocks.
750        for block in &self.transformer_blocks {
751            hidden_seq = block.forward_xattn(&hidden_seq, encoder_hidden_states)?;
752        }
753        // Back to spatial: [B, HW, inner] -> [B, inner, HW] -> [B, inner, H, W]
754        let hidden_back = hidden_seq
755            .transpose(1, 2)?
756            .reshape_t(&[b as isize, self.inner_dim as isize, h as isize, w as isize])?
757            .contiguous()?;
758        // proj_out (Conv2d k=1) + residual.
759        let out = self.proj_out.forward(&hidden_back)?;
760        let _ = c;
761        ferrotorch_core::grad_fns::arithmetic::add(&out, &residual)
762    }
763}
764
765impl<T: Float> Module<T> for Transformer2DModel<T> {
766    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
767        Err(FerrotorchError::InvalidArgument {
768            message: "Transformer2DModel::forward: cross-attn requires \
769                      encoder_hidden_states — call forward_xattn instead"
770                .into(),
771        })
772    }
773
774    fn parameters(&self) -> Vec<&Parameter<T>> {
775        let mut o = Vec::new();
776        o.extend(self.norm.parameters());
777        o.extend(self.proj_in.parameters());
778        for b in &self.transformer_blocks {
779            o.extend(b.parameters());
780        }
781        o.extend(self.proj_out.parameters());
782        o
783    }
784    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
785        let mut o = Vec::new();
786        o.extend(self.norm.parameters_mut());
787        o.extend(self.proj_in.parameters_mut());
788        for b in &mut self.transformer_blocks {
789            o.extend(b.parameters_mut());
790        }
791        o.extend(self.proj_out.parameters_mut());
792        o
793    }
794    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
795        let mut o = Vec::new();
796        for (n, p) in self.norm.named_parameters() {
797            o.push((format!("norm.{n}"), p));
798        }
799        for (n, p) in self.proj_in.named_parameters() {
800            o.push((format!("proj_in.{n}"), p));
801        }
802        for (i, b) in self.transformer_blocks.iter().enumerate() {
803            for (n, p) in b.named_parameters() {
804                o.push((format!("transformer_blocks.{i}.{n}"), p));
805            }
806        }
807        for (n, p) in self.proj_out.named_parameters() {
808            o.push((format!("proj_out.{n}"), p));
809        }
810        o
811    }
812    fn train(&mut self) {
813        self.training = true;
814    }
815    fn eval(&mut self) {
816        self.training = false;
817    }
818    fn is_training(&self) -> bool {
819        self.training
820    }
821    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
822        let extract = |prefix: &str| -> StateDict<T> {
823            let p = format!("{prefix}.");
824            state
825                .iter()
826                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
827                .collect()
828        };
829        if strict {
830            for k in state.keys() {
831                let ok = k.starts_with("norm.")
832                    || k.starts_with("proj_in.")
833                    || k.starts_with("transformer_blocks.")
834                    || k.starts_with("proj_out.");
835                if !ok {
836                    return Err(FerrotorchError::InvalidArgument {
837                        message: format!(
838                            "unexpected key in Transformer2DModel state_dict: \"{k}\""
839                        ),
840                    });
841                }
842            }
843        }
844        self.norm.load_state_dict(&extract("norm"), strict)?;
845        self.proj_in.load_state_dict(&extract("proj_in"), strict)?;
846        for (i, b) in self.transformer_blocks.iter_mut().enumerate() {
847            b.load_state_dict(&extract(&format!("transformer_blocks.{i}")), strict)?;
848        }
849        self.proj_out
850            .load_state_dict(&extract("proj_out"), strict)?;
851        Ok(())
852    }
853}
854
// Shape and parameter-naming smoke tests for the attention stack.
// Inputs are constant-filled (0.01) tensors, so these tests check output
// shapes and key layout only, not numerical values.
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    // Self-attention (no cross-attention dim): output shape must equal the
    // [1, 5, 16] input shape.
    // NOTE(review): the constructor args are presumably
    // (query_dim, cross_attention_dim, heads, dim_head, qkv bias) — confirm
    // against `Attention::new`.
    #[test]
    fn attention_self_shape() {
        let a = Attention::<f32>::new(16, None, 4, 4, false).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 5 * 16]),
            vec![1, 5, 16],
            false,
        )
        .unwrap();
        let y = a.forward_xattn(&x, None).unwrap();
        assert_eq!(y.shape(), &[1, 5, 16]);
    }

    // Cross-attention with a 24-wide, 7-token context: the output keeps the
    // query's [1, 5, 16] shape, independent of the context sequence length.
    #[test]
    fn attention_cross_shape() {
        let a = Attention::<f32>::new(16, Some(24), 4, 4, false).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 5 * 16]),
            vec![1, 5, 16],
            false,
        )
        .unwrap();
        let ehs = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 7 * 24]),
            vec![1, 7, 24],
            false,
        )
        .unwrap();
        let y = a.forward_xattn(&x, Some(&ehs)).unwrap();
        assert_eq!(y.shape(), &[1, 5, 16]);
    }

    // FeedForward preserves the input shape and exposes diffusers-style keys
    // (`net.0.proj.*` for the GEGLU projection, `net.2.*` for the output Linear).
    #[test]
    fn feedforward_shape_and_keys() {
        let ff = FeedForward::<f32>::new(16, 2).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 5 * 16]),
            vec![1, 5, 16],
            false,
        )
        .unwrap();
        let y = ff.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 5, 16]);
        let names: Vec<String> = ff.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "net.0.proj.weight",
            "net.0.proj.bias",
            "net.2.weight",
            "net.2.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    // One BasicTransformerBlock (self-attn + cross-attn + FF) preserves the
    // [1, 5, 16] hidden-state shape given a [1, 7, 24] context.
    #[test]
    fn basic_transformer_block_shape() {
        let blk = BasicTransformerBlock::<f32>::new(16, 4, 4, 24).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 5 * 16]),
            vec![1, 5, 16],
            false,
        )
        .unwrap();
        let ehs = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 7 * 24]),
            vec![1, 7, 24],
            false,
        )
        .unwrap();
        let y = blk.forward_xattn(&x, &ehs).unwrap();
        assert_eq!(y.shape(), &[1, 5, 16]);
    }

    // Full Transformer2DModel round trip: spatial [1, 16, 3, 3] -> sequence ->
    // blocks -> spatial, so the output shape must equal the input shape.
    // NOTE(review): the meaning of each `new(16, 4, 4, 1, 24, 4)` argument is
    // not visible here; confirm against `Transformer2DModel::new`.
    #[test]
    fn transformer_2d_shape() {
        let t = Transformer2DModel::<f32>::new(16, 4, 4, 1, 24, 4).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 16 * 3 * 3]),
            vec![1, 16, 3, 3],
            false,
        )
        .unwrap();
        let ehs = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 5 * 24]),
            vec![1, 5, 24],
            false,
        )
        .unwrap();
        let y = t.forward_xattn(&x, &ehs).unwrap();
        assert_eq!(y.shape(), &[1, 16, 3, 3]);
    }

    // Spot-checks that `named_parameters` produces the dotted key layout used
    // by diffusers checkpoints, including nested block / FF paths.
    #[test]
    fn transformer_2d_named_parameters() {
        let t = Transformer2DModel::<f32>::new(16, 4, 4, 1, 24, 4).unwrap();
        let names: Vec<String> = t.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "norm.weight",
            "proj_in.weight",
            "proj_in.bias",
            "transformer_blocks.0.norm1.weight",
            "transformer_blocks.0.attn1.to_q.weight",
            "transformer_blocks.0.attn2.to_k.weight",
            "transformer_blocks.0.ff.net.0.proj.weight",
            "transformer_blocks.0.ff.net.2.weight",
            "proj_out.weight",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }
}