Skip to main content

ferrotorch_diffusion/
attention.rs

1//! Multi-head attention + the `Transformer2DModel` wrapper used by the
2//! SD UNet's CrossAttn blocks.
3//!
4//! Diffusers' layout 1:1:
5//!
6//! ```text
7//! Attention(query_dim, cross_attention_dim?, heads, dim_head)
8//!   to_q.{weight,bias}    [inner, query_dim]   (inner = heads * dim_head)
9//!   to_k.{weight,bias}    [inner, kv_dim]       (kv_dim = cross_attention_dim
10//!                                                or query_dim for self-attn)
11//!   to_v.{weight,bias}    [inner, kv_dim]
12//!   to_out.0.{weight,bias}[query_dim, inner]
13//!   to_out.1              Dropout (no params)
14//! ```
15//!
16//! For SD-1.5 UNet, biases on `to_q/to_k/to_v` are disabled
17//! (`bias=False`); the output projection `to_out.0` keeps its bias.
18//!
19//! `BasicTransformerBlock`:
20//!
21//! ```text
22//! h0 = LayerNorm1(x)
23//! h1 = Attention1(h0, h0, h0)             # self-attn
24//! x  = x + h1
25//! h0 = LayerNorm2(x)
26//! h2 = Attention2(h0, encoder_hidden, …)  # cross-attn
27//! x  = x + h2
28//! h0 = LayerNorm3(x)
29//! h3 = FeedForward(h0)                    # GEGLU + Linear
30//! x  = x + h3
31//! ```
32//!
33//! `Transformer2DModel`:
34//!
35//! ```text
36//! GroupNorm(32) -> proj_in (Conv2d k=1) -> flatten [B, HW, C]
37//!   -> N × BasicTransformerBlock -> reshape back -> proj_out + residual
38//! ```
39
40use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
41use ferrotorch_nn::module::{Module, StateDict};
42use ferrotorch_nn::parameter::Parameter;
43use ferrotorch_nn::{Conv2d, GELU, GroupNorm, LayerNorm, Linear};
44
45// ---------------------------------------------------------------------------
46// Attention (multi-head, optional cross-attention)
47// ---------------------------------------------------------------------------
48
/// Multi-head attention block. Supports self-attention (when `key` and
/// `value` are derived from the same tensor as `query`) and
/// cross-attention (when they come from `encoder_hidden_states`).
///
/// This is the `Attention` class in
/// `diffusers.models.attention_processor` configured as it appears in
/// SD-1.5's UNet — `bias = False` on q/k/v and `out_bias = True` on
/// `to_out.0`, no `group_norm`, no `spatial_norm`, no
/// `added_kv_proj_dim`.
///
/// Parameters are exposed under the key prefixes `to_q.`, `to_k.`,
/// `to_v.` and `to_out.0.` (see `named_parameters`).
#[derive(Debug)]
pub struct Attention<T: Float> {
    /// Per-head dimension.
    pub dim_head: usize,
    /// Number of heads.
    pub heads: usize,
    /// `inner_dim = heads * dim_head`.
    pub inner_dim: usize,
    /// Query projection: `[inner_dim, query_dim]`.
    pub to_q: Linear<T>,
    /// Key projection: `[inner_dim, kv_dim]`.
    pub to_k: Linear<T>,
    /// Value projection: `[inner_dim, kv_dim]`.
    pub to_v: Linear<T>,
    /// Output projection: `[query_dim, inner_dim]` (with bias).
    pub to_out_0: Linear<T>,
    /// Width the last axis of `hidden_states` must have; checked at the
    /// top of `forward_xattn`.
    query_dim: usize,
    /// Width the last axis of the key/value source must have: the
    /// constructor's `cross_attention_dim`, or `query_dim` when that was
    /// `None`.
    kv_dim: usize,
    /// Score scaling factor `1 / sqrt(dim_head)`, stored as `f64` and cast
    /// to `T` on every forward call.
    scale: f64,
    /// Flag set by `Module::train` / cleared by `Module::eval`; starts as
    /// `false`. The visible forward path does not read it.
    training: bool,
}
79
80impl<T: Float> Attention<T> {
81    /// Build a randomly-initialized `Attention`.
82    ///
83    /// `cross_attention_dim = None` means self-attention
84    /// (`kv_dim = query_dim`); a `Some(_)` value enables cross-attention.
85    ///
86    /// `bias` controls `to_q/to_k/to_v` bias (SD-1.5 sets this to false).
87    /// `to_out.0` always has bias (matches diffusers default
88    /// `out_bias=True`).
89    ///
90    /// # Errors
91    ///
92    /// Returns the underlying [`FerrotorchError`] when any `Linear` size
93    /// is invalid.
94    pub fn new(
95        query_dim: usize,
96        cross_attention_dim: Option<usize>,
97        heads: usize,
98        dim_head: usize,
99        bias: bool,
100    ) -> FerrotorchResult<Self> {
101        let inner_dim = heads * dim_head;
102        let kv_dim = cross_attention_dim.unwrap_or(query_dim);
103        let to_q = Linear::<T>::new(query_dim, inner_dim, bias)?;
104        let to_k = Linear::<T>::new(kv_dim, inner_dim, bias)?;
105        let to_v = Linear::<T>::new(kv_dim, inner_dim, bias)?;
106        let to_out_0 = Linear::<T>::new(inner_dim, query_dim, true)?;
107        let scale = (dim_head as f64).sqrt().recip();
108        Ok(Self {
109            dim_head,
110            heads,
111            inner_dim,
112            to_q,
113            to_k,
114            to_v,
115            to_out_0,
116            query_dim,
117            kv_dim,
118            scale,
119            training: false,
120        })
121    }
122
123    /// Forward with optional encoder hidden states.
124    ///
125    /// `hidden_states` has shape `[B, N, query_dim]`. When
126    /// `encoder_hidden_states` is `None` this is self-attention; when
127    /// `Some([B, S, kv_dim])` it's cross-attention.
128    ///
129    /// # Errors
130    ///
131    /// Returns [`FerrotorchError::ShapeMismatch`] on rank / dim
132    /// disagreement.
133    pub fn forward_xattn(
134        &self,
135        hidden_states: &Tensor<T>,
136        encoder_hidden_states: Option<&Tensor<T>>,
137    ) -> FerrotorchResult<Tensor<T>> {
138        if hidden_states.ndim() != 3 || hidden_states.shape()[2] != self.query_dim {
139            return Err(FerrotorchError::ShapeMismatch {
140                message: format!(
141                    "Attention::forward_xattn: expected hidden_states [B, N, {}], got {:?}",
142                    self.query_dim,
143                    hidden_states.shape()
144                ),
145            });
146        }
147        let b = hidden_states.shape()[0];
148        let n = hidden_states.shape()[1];
149        // kv source.
150        let kv = encoder_hidden_states.unwrap_or(hidden_states);
151        if kv.ndim() != 3 || kv.shape()[0] != b || kv.shape()[2] != self.kv_dim {
152            return Err(FerrotorchError::ShapeMismatch {
153                message: format!(
154                    "Attention::forward_xattn: expected kv [B={b}, S, {}], got {:?}",
155                    self.kv_dim,
156                    kv.shape()
157                ),
158            });
159        }
160        let s = kv.shape()[1];
161
162        // -- Linear projections. q: [B, N, inner], k/v: [B, S, inner].
163        let q = self.to_q.forward(hidden_states)?;
164        let k = self.to_k.forward(kv)?;
165        let v = self.to_v.forward(kv)?;
166
167        // -- Reshape to per-head: [B, N, H, D] then transpose to
168        //    [B, H, N, D], collapse to [B*H, N, D] for the BMM kernel.
169        //    Same trick for k, v over S.
170        let h = self.heads;
171        let d = self.dim_head;
172        let q = q
173            .reshape_t(&[b as isize, n as isize, h as isize, d as isize])?
174            .transpose(1, 2)? // [B, H, N, D]
175            .contiguous()?
176            .reshape_t(&[(b * h) as isize, n as isize, d as isize])?;
177        let k = k
178            .reshape_t(&[b as isize, s as isize, h as isize, d as isize])?
179            .transpose(1, 2)? // [B, H, S, D]
180            .contiguous()?
181            .reshape_t(&[(b * h) as isize, s as isize, d as isize])?;
182        let v = v
183            .reshape_t(&[b as isize, s as isize, h as isize, d as isize])?
184            .transpose(1, 2)? // [B, H, S, D]
185            .contiguous()?
186            .reshape_t(&[(b * h) as isize, s as isize, d as isize])?;
187
188        // -- scores = (q @ k^T) * scale.
189        let k_t = k.transpose(1, 2)?.contiguous()?; // [B*H, D, S]
190        let scores = q.bmm(&k_t)?; // [B*H, N, S]
191        let scale_t = T::from(self.scale).ok_or_else(|| FerrotorchError::InvalidArgument {
192            message: "Attention::forward_xattn: failed to cast attention scale into Float".into(),
193        })?;
194        let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
195        let scores_scaled = ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;
196        let probs = scores_scaled.softmax()?; // [B*H, N, S]
197        let attended = probs.bmm(&v)?; // [B*H, N, D]
198
199        // -- Merge heads: [B*H, N, D] -> [B, H, N, D] -> [B, N, H, D]
200        //    -> [B, N, inner_dim].
201        let attended = attended
202            .reshape_t(&[b as isize, h as isize, n as isize, d as isize])?
203            .transpose(1, 2)? // [B, N, H, D]
204            .contiguous()?
205            .reshape_t(&[b as isize, n as isize, self.inner_dim as isize])?;
206
207        // -- Output projection.
208        self.to_out_0.forward(&attended)
209    }
210}
211
212impl<T: Float> Module<T> for Attention<T> {
213    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
214        // Self-attention path.
215        self.forward_xattn(input, None)
216    }
217
218    fn parameters(&self) -> Vec<&Parameter<T>> {
219        let mut o = Vec::new();
220        o.extend(self.to_q.parameters());
221        o.extend(self.to_k.parameters());
222        o.extend(self.to_v.parameters());
223        o.extend(self.to_out_0.parameters());
224        o
225    }
226    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
227        let mut o = Vec::new();
228        o.extend(self.to_q.parameters_mut());
229        o.extend(self.to_k.parameters_mut());
230        o.extend(self.to_v.parameters_mut());
231        o.extend(self.to_out_0.parameters_mut());
232        o
233    }
234    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
235        let mut o = Vec::new();
236        for (n, p) in self.to_q.named_parameters() {
237            o.push((format!("to_q.{n}"), p));
238        }
239        for (n, p) in self.to_k.named_parameters() {
240            o.push((format!("to_k.{n}"), p));
241        }
242        for (n, p) in self.to_v.named_parameters() {
243            o.push((format!("to_v.{n}"), p));
244        }
245        for (n, p) in self.to_out_0.named_parameters() {
246            o.push((format!("to_out.0.{n}"), p));
247        }
248        o
249    }
250    fn train(&mut self) {
251        self.training = true;
252    }
253    fn eval(&mut self) {
254        self.training = false;
255    }
256    fn is_training(&self) -> bool {
257        self.training
258    }
259    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
260        let extract = |prefix: &str| -> StateDict<T> {
261            let p = format!("{prefix}.");
262            state
263                .iter()
264                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
265                .collect()
266        };
267        if strict {
268            for k in state.keys() {
269                let ok = k.starts_with("to_q.")
270                    || k.starts_with("to_k.")
271                    || k.starts_with("to_v.")
272                    || k.starts_with("to_out.0.");
273                if !ok {
274                    return Err(FerrotorchError::InvalidArgument {
275                        message: format!("unexpected key in Attention state_dict: \"{k}\""),
276                    });
277                }
278            }
279        }
280        self.to_q.load_state_dict(&extract("to_q"), strict)?;
281        self.to_k.load_state_dict(&extract("to_k"), strict)?;
282        self.to_v.load_state_dict(&extract("to_v"), strict)?;
283        self.to_out_0
284            .load_state_dict(&extract("to_out.0"), strict)?;
285        Ok(())
286    }
287}
288
289// ---------------------------------------------------------------------------
290// FeedForward (GEGLU + Linear)
291// ---------------------------------------------------------------------------
292
/// GEGLU-style feed-forward (matches the SD UNet's
/// `BasicTransformerBlock.ff` exactly):
///
/// ```text
/// net.0  = GEGLU(dim, dim * mult)
///          = Linear(dim, 2 * dim * mult)         # proj
///          x, gate = chunk(2, dim=-1)
///          return x * gelu(gate)
/// net.1  = Dropout (no params)
/// net.2  = Linear(dim * mult, dim)
/// ```
///
/// Diffusers' `FeedForward` defaults: `mult = 4`,
/// `activation_fn = "geglu"`. SD-1.5 UNet uses both defaults.
///
/// State-dict layout:
///
/// ```text
/// net.0.proj.{weight,bias}    [2 * dim_ff, dim], [2 * dim_ff]
/// net.2.{weight,bias}         [dim, dim_ff],     [dim]
/// ```
#[derive(Debug)]
pub struct FeedForward<T: Float> {
    /// GEGLU's expansion projection (`dim -> 2 * dim_ff`).
    pub net_0_proj: Linear<T>,
    /// Output projection (`dim_ff -> dim`).
    pub net_2: Linear<T>,
    /// Activation applied to the gate half of the projection in `forward`.
    activation: GELU,
    /// Hidden width, `dim * mult` as computed in `new`. In the visible code
    /// it is only touched by a `let _ =` in `load_state_dict`.
    dim_ff: usize,
    /// Flag set by `Module::train` / cleared by `Module::eval`; starts as
    /// `false`. The visible forward path does not read it.
    training: bool,
}
324
325impl<T: Float> FeedForward<T> {
326    /// Build a GEGLU `FeedForward` (`dim_ff = dim * mult`).
327    ///
328    /// # Errors
329    ///
330    /// Returns the underlying [`FerrotorchError`] for invalid dims.
331    pub fn new(dim: usize, mult: usize) -> FerrotorchResult<Self> {
332        let dim_ff = dim * mult;
333        let net_0_proj = Linear::<T>::new(dim, 2 * dim_ff, true)?;
334        let net_2 = Linear::<T>::new(dim_ff, dim, true)?;
335        Ok(Self {
336            net_0_proj,
337            net_2,
338            activation: GELU::new(),
339            dim_ff,
340            training: false,
341        })
342    }
343}
344
345impl<T: Float> Module<T> for FeedForward<T> {
346    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
347        // proj -> chunk(2, -1) -> x * gelu(gate)
348        let proj = self.net_0_proj.forward(input)?;
349        // `chunk` operates on a positive dim — last axis here.
350        let last = proj.ndim() - 1;
351        let parts = proj.chunk(2, last)?;
352        if parts.len() != 2 {
353            return Err(FerrotorchError::ShapeMismatch {
354                message: format!(
355                    "FeedForward: chunk(2) returned {} parts (expected 2)",
356                    parts.len()
357                ),
358            });
359        }
360        let x = parts[0].contiguous()?;
361        let gate = parts[1].contiguous()?;
362        let gated = self.activation.forward(&gate)?;
363        let activated = ferrotorch_core::grad_fns::arithmetic::mul(&x, &gated)?;
364        self.net_2.forward(&activated)
365    }
366    fn parameters(&self) -> Vec<&Parameter<T>> {
367        let mut o = Vec::new();
368        o.extend(self.net_0_proj.parameters());
369        o.extend(self.net_2.parameters());
370        o
371    }
372    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
373        let mut o = Vec::new();
374        o.extend(self.net_0_proj.parameters_mut());
375        o.extend(self.net_2.parameters_mut());
376        o
377    }
378    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
379        let mut o = Vec::new();
380        for (n, p) in self.net_0_proj.named_parameters() {
381            o.push((format!("net.0.proj.{n}"), p));
382        }
383        for (n, p) in self.net_2.named_parameters() {
384            o.push((format!("net.2.{n}"), p));
385        }
386        o
387    }
388    fn train(&mut self) {
389        self.training = true;
390    }
391    fn eval(&mut self) {
392        self.training = false;
393    }
394    fn is_training(&self) -> bool {
395        self.training
396    }
397    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
398        let extract = |prefix: &str| -> StateDict<T> {
399            let p = format!("{prefix}.");
400            state
401                .iter()
402                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
403                .collect()
404        };
405        if strict {
406            for k in state.keys() {
407                let ok = k.starts_with("net.0.proj.") || k.starts_with("net.2.");
408                if !ok {
409                    return Err(FerrotorchError::InvalidArgument {
410                        message: format!("unexpected key in FeedForward state_dict: \"{k}\""),
411                    });
412                }
413            }
414        }
415        self.net_0_proj
416            .load_state_dict(&extract("net.0.proj"), strict)?;
417        self.net_2.load_state_dict(&extract("net.2"), strict)?;
418        let _ = self.dim_ff;
419        Ok(())
420    }
421}
422
423// ---------------------------------------------------------------------------
424// BasicTransformerBlock
425// ---------------------------------------------------------------------------
426
/// Diffusers' `BasicTransformerBlock` configured the way SD-1.5's UNet
/// uses it: pre-LayerNorm on every sub-layer, self-attn followed by
/// cross-attn followed by GEGLU FeedForward, all with residuals.
///
/// State-dict layout:
///
/// ```text
/// norm1.{weight,bias}   [dim], [dim]
/// attn1.<keys>          # self-attn (Attention with cross_attention_dim=None)
/// norm2.{weight,bias}
/// attn2.<keys>          # cross-attn
/// norm3.{weight,bias}
/// ff.<keys>             # FeedForward (GEGLU)
/// ```
#[derive(Debug)]
pub struct BasicTransformerBlock<T: Float> {
    /// LayerNorm before self-attn.
    pub norm1: LayerNorm<T>,
    /// Self-attention.
    pub attn1: Attention<T>,
    /// LayerNorm before cross-attn.
    pub norm2: LayerNorm<T>,
    /// Cross-attention.
    pub attn2: Attention<T>,
    /// LayerNorm before FF.
    pub norm3: LayerNorm<T>,
    /// GEGLU FeedForward.
    pub ff: FeedForward<T>,
    /// Feature width the last axis of the input must have; checked at the
    /// top of `forward_xattn`.
    dim: usize,
    /// Flag set by `Module::train` / cleared by `Module::eval`; starts as
    /// `false`. The visible forward path does not read it.
    training: bool,
}
458
459impl<T: Float> BasicTransformerBlock<T> {
460    /// Build a randomly-initialized `BasicTransformerBlock`.
461    ///
462    /// # Errors
463    ///
464    /// Returns the underlying [`FerrotorchError`] for invalid dims.
465    pub fn new(
466        dim: usize,
467        heads: usize,
468        dim_head: usize,
469        cross_attention_dim: usize,
470    ) -> FerrotorchResult<Self> {
471        // SD-1.5 sets `attention_bias = False` for the UNet (no bias on
472        // q/k/v); the output projection (`to_out.0`) keeps its bias
473        // unconditionally.
474        let norm1 = LayerNorm::<T>::new(vec![dim], 1e-5, true)?;
475        let attn1 = Attention::<T>::new(dim, None, heads, dim_head, false)?;
476        let norm2 = LayerNorm::<T>::new(vec![dim], 1e-5, true)?;
477        let attn2 = Attention::<T>::new(dim, Some(cross_attention_dim), heads, dim_head, false)?;
478        let norm3 = LayerNorm::<T>::new(vec![dim], 1e-5, true)?;
479        let ff = FeedForward::<T>::new(dim, 4)?;
480        Ok(Self {
481            norm1,
482            attn1,
483            norm2,
484            attn2,
485            norm3,
486            ff,
487            dim,
488            training: false,
489        })
490    }
491
492    /// Forward with optional encoder hidden states. Self-attn ignores
493    /// `encoder_hidden_states`, cross-attn uses it.
494    ///
495    /// # Errors
496    ///
497    /// Returns [`FerrotorchError::ShapeMismatch`] on rank disagreement,
498    /// underlying op errors otherwise.
499    pub fn forward_xattn(
500        &self,
501        x: &Tensor<T>,
502        encoder_hidden_states: &Tensor<T>,
503    ) -> FerrotorchResult<Tensor<T>> {
504        if x.ndim() != 3 || x.shape()[2] != self.dim {
505            return Err(FerrotorchError::ShapeMismatch {
506                message: format!(
507                    "BasicTransformerBlock::forward: expected x [B, N, {}], got {:?}",
508                    self.dim,
509                    x.shape()
510                ),
511            });
512        }
513        // Sub-block 1: self-attn.
514        let h1 = self.norm1.forward(x)?;
515        let h1 = self.attn1.forward_xattn(&h1, None)?;
516        let x = ferrotorch_core::grad_fns::arithmetic::add(&h1, x)?;
517        // Sub-block 2: cross-attn.
518        let h2 = self.norm2.forward(&x)?;
519        let h2 = self.attn2.forward_xattn(&h2, Some(encoder_hidden_states))?;
520        let x = ferrotorch_core::grad_fns::arithmetic::add(&h2, &x)?;
521        // Sub-block 3: FF.
522        let h3 = self.norm3.forward(&x)?;
523        let h3 = self.ff.forward(&h3)?;
524        ferrotorch_core::grad_fns::arithmetic::add(&h3, &x)
525    }
526}
527
528impl<T: Float> Module<T> for BasicTransformerBlock<T> {
529    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
530        Err(FerrotorchError::InvalidArgument {
531            message: "BasicTransformerBlock::forward: cross-attn requires \
532                      encoder_hidden_states — call forward_xattn instead"
533                .into(),
534        })
535    }
536
537    fn parameters(&self) -> Vec<&Parameter<T>> {
538        let mut o = Vec::new();
539        o.extend(self.norm1.parameters());
540        o.extend(self.attn1.parameters());
541        o.extend(self.norm2.parameters());
542        o.extend(self.attn2.parameters());
543        o.extend(self.norm3.parameters());
544        o.extend(self.ff.parameters());
545        o
546    }
547    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
548        let mut o = Vec::new();
549        o.extend(self.norm1.parameters_mut());
550        o.extend(self.attn1.parameters_mut());
551        o.extend(self.norm2.parameters_mut());
552        o.extend(self.attn2.parameters_mut());
553        o.extend(self.norm3.parameters_mut());
554        o.extend(self.ff.parameters_mut());
555        o
556    }
557    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
558        let mut o = Vec::new();
559        for (n, p) in self.norm1.named_parameters() {
560            o.push((format!("norm1.{n}"), p));
561        }
562        for (n, p) in self.attn1.named_parameters() {
563            o.push((format!("attn1.{n}"), p));
564        }
565        for (n, p) in self.norm2.named_parameters() {
566            o.push((format!("norm2.{n}"), p));
567        }
568        for (n, p) in self.attn2.named_parameters() {
569            o.push((format!("attn2.{n}"), p));
570        }
571        for (n, p) in self.norm3.named_parameters() {
572            o.push((format!("norm3.{n}"), p));
573        }
574        for (n, p) in self.ff.named_parameters() {
575            o.push((format!("ff.{n}"), p));
576        }
577        o
578    }
579    fn train(&mut self) {
580        self.training = true;
581    }
582    fn eval(&mut self) {
583        self.training = false;
584    }
585    fn is_training(&self) -> bool {
586        self.training
587    }
588    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
589        let extract = |prefix: &str| -> StateDict<T> {
590            let p = format!("{prefix}.");
591            state
592                .iter()
593                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
594                .collect()
595        };
596        if strict {
597            for k in state.keys() {
598                let ok = k.starts_with("norm1.")
599                    || k.starts_with("attn1.")
600                    || k.starts_with("norm2.")
601                    || k.starts_with("attn2.")
602                    || k.starts_with("norm3.")
603                    || k.starts_with("ff.");
604                if !ok {
605                    return Err(FerrotorchError::InvalidArgument {
606                        message: format!(
607                            "unexpected key in BasicTransformerBlock state_dict: \"{k}\""
608                        ),
609                    });
610                }
611            }
612        }
613        self.norm1.load_state_dict(&extract("norm1"), strict)?;
614        self.attn1.load_state_dict(&extract("attn1"), strict)?;
615        self.norm2.load_state_dict(&extract("norm2"), strict)?;
616        self.attn2.load_state_dict(&extract("attn2"), strict)?;
617        self.norm3.load_state_dict(&extract("norm3"), strict)?;
618        self.ff.load_state_dict(&extract("ff"), strict)?;
619        Ok(())
620    }
621}
622
623// ---------------------------------------------------------------------------
624// Transformer2DModel
625// ---------------------------------------------------------------------------
626
/// Diffusers' `Transformer2DModel` configured the way SD-1.5's UNet
/// uses it:
///
/// ```text
/// h = norm(x)                              # GroupNorm(32, in_channels)
/// h = proj_in(h)                           # Conv2d(C, inner, k=1) [use_linear_projection=False]
/// h = h.permute(0, 2, 3, 1).reshape(B, H*W, inner)
/// for block in transformer_blocks:
///     h = block(h, encoder_hidden_states)
/// h = h.reshape(B, H, W, inner).permute(0, 3, 1, 2)
/// h = proj_out(h)                          # Conv2d(inner, C, k=1)
/// return h + residual
/// ```
///
/// SD-1.5 v1 uses `Conv2d` (not Linear) for `proj_in`/`proj_out`
/// (`use_linear_projection=False`). `transformer_layers_per_block=1`
/// (the diffusers default and the SD-1.5 v1 setting).
#[derive(Debug)]
pub struct Transformer2DModel<T: Float> {
    /// GroupNorm before `proj_in`.
    pub norm: GroupNorm<T>,
    /// `proj_in`: Conv2d(C, inner, k=1).
    pub proj_in: Conv2d<T>,
    /// `N × BasicTransformerBlock`.
    pub transformer_blocks: Vec<BasicTransformerBlock<T>>,
    /// `proj_out`: Conv2d(inner, C, k=1).
    pub proj_out: Conv2d<T>,
    /// Channel count axis 1 of the input must have (`in_channels` from
    /// `new`); checked at the top of `forward_xattn`.
    channels: usize,
    /// `heads * dim_head`: the channel count between `proj_in` and
    /// `proj_out`, and the feature width seen by the transformer blocks.
    inner_dim: usize,
    /// Flag set by `Module::train` / cleared by `Module::eval`; starts as
    /// `false`. The visible forward path does not read it.
    training: bool,
}
658
659impl<T: Float> Transformer2DModel<T> {
660    /// Build a randomly-initialized `Transformer2DModel`.
661    ///
662    /// `inner_dim = heads * dim_head` for the SD UNet (proj_in expands
663    /// only when these disagree; for SD it's always equal to
664    /// `in_channels`).
665    ///
666    /// # Errors
667    ///
668    /// Returns the underlying [`FerrotorchError`] for invalid dims.
669    pub fn new(
670        in_channels: usize,
671        heads: usize,
672        dim_head: usize,
673        num_layers: usize,
674        cross_attention_dim: usize,
675        norm_num_groups: usize,
676    ) -> FerrotorchResult<Self> {
677        let inner_dim = heads * dim_head;
678        let norm = GroupNorm::<T>::new(norm_num_groups, in_channels, 1e-6, true)?;
679        let proj_in = Conv2d::<T>::new(in_channels, inner_dim, (1, 1), (1, 1), (0, 0), true)?;
680        let proj_out = Conv2d::<T>::new(inner_dim, in_channels, (1, 1), (1, 1), (0, 0), true)?;
681        let mut transformer_blocks = Vec::with_capacity(num_layers);
682        for _ in 0..num_layers {
683            transformer_blocks.push(BasicTransformerBlock::<T>::new(
684                inner_dim,
685                heads,
686                dim_head,
687                cross_attention_dim,
688            )?);
689        }
690        Ok(Self {
691            norm,
692            proj_in,
693            transformer_blocks,
694            proj_out,
695            channels: in_channels,
696            inner_dim,
697            training: false,
698        })
699    }
700
701    /// Forward with encoder hidden states for cross-attn.
702    ///
703    /// `x` has shape `[B, C, H, W]`. The result has the same shape.
704    ///
705    /// # Errors
706    ///
707    /// Returns [`FerrotorchError::ShapeMismatch`] when the input is not
708    /// `[B, channels, H, W]`.
709    pub fn forward_xattn(
710        &self,
711        x: &Tensor<T>,
712        encoder_hidden_states: &Tensor<T>,
713    ) -> FerrotorchResult<Tensor<T>> {
714        if x.ndim() != 4 || x.shape()[1] != self.channels {
715            return Err(FerrotorchError::ShapeMismatch {
716                message: format!(
717                    "Transformer2DModel::forward: expected [B, {}, H, W], got {:?}",
718                    self.channels,
719                    x.shape()
720                ),
721            });
722        }
723        let b = x.shape()[0];
724        let c = x.shape()[1];
725        let h = x.shape()[2];
726        let w = x.shape()[3];
727        let hw = h * w;
728
729        let residual = x.clone();
730        // norm + proj_in (Conv2d k=1) keeps the [B, C', H, W] layout.
731        let mut hidden = self.norm.forward(x)?;
732        hidden = self.proj_in.forward(&hidden)?;
733        // [B, inner, H, W] -> [B, inner, HW] -> [B, HW, inner]
734        let mut hidden_seq = hidden
735            .reshape_t(&[b as isize, self.inner_dim as isize, hw as isize])?
736            .transpose(1, 2)?
737            .contiguous()?;
738        // Run the transformer blocks.
739        for block in &self.transformer_blocks {
740            hidden_seq = block.forward_xattn(&hidden_seq, encoder_hidden_states)?;
741        }
742        // Back to spatial: [B, HW, inner] -> [B, inner, HW] -> [B, inner, H, W]
743        let hidden_back = hidden_seq
744            .transpose(1, 2)?
745            .reshape_t(&[b as isize, self.inner_dim as isize, h as isize, w as isize])?
746            .contiguous()?;
747        // proj_out (Conv2d k=1) + residual.
748        let out = self.proj_out.forward(&hidden_back)?;
749        let _ = c;
750        ferrotorch_core::grad_fns::arithmetic::add(&out, &residual)
751    }
752}
753
754impl<T: Float> Module<T> for Transformer2DModel<T> {
755    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
756        Err(FerrotorchError::InvalidArgument {
757            message: "Transformer2DModel::forward: cross-attn requires \
758                      encoder_hidden_states — call forward_xattn instead"
759                .into(),
760        })
761    }
762
763    fn parameters(&self) -> Vec<&Parameter<T>> {
764        let mut o = Vec::new();
765        o.extend(self.norm.parameters());
766        o.extend(self.proj_in.parameters());
767        for b in &self.transformer_blocks {
768            o.extend(b.parameters());
769        }
770        o.extend(self.proj_out.parameters());
771        o
772    }
773    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
774        let mut o = Vec::new();
775        o.extend(self.norm.parameters_mut());
776        o.extend(self.proj_in.parameters_mut());
777        for b in &mut self.transformer_blocks {
778            o.extend(b.parameters_mut());
779        }
780        o.extend(self.proj_out.parameters_mut());
781        o
782    }
783    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
784        let mut o = Vec::new();
785        for (n, p) in self.norm.named_parameters() {
786            o.push((format!("norm.{n}"), p));
787        }
788        for (n, p) in self.proj_in.named_parameters() {
789            o.push((format!("proj_in.{n}"), p));
790        }
791        for (i, b) in self.transformer_blocks.iter().enumerate() {
792            for (n, p) in b.named_parameters() {
793                o.push((format!("transformer_blocks.{i}.{n}"), p));
794            }
795        }
796        for (n, p) in self.proj_out.named_parameters() {
797            o.push((format!("proj_out.{n}"), p));
798        }
799        o
800    }
801    fn train(&mut self) {
802        self.training = true;
803    }
804    fn eval(&mut self) {
805        self.training = false;
806    }
807    fn is_training(&self) -> bool {
808        self.training
809    }
810    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
811        let extract = |prefix: &str| -> StateDict<T> {
812            let p = format!("{prefix}.");
813            state
814                .iter()
815                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
816                .collect()
817        };
818        if strict {
819            for k in state.keys() {
820                let ok = k.starts_with("norm.")
821                    || k.starts_with("proj_in.")
822                    || k.starts_with("transformer_blocks.")
823                    || k.starts_with("proj_out.");
824                if !ok {
825                    return Err(FerrotorchError::InvalidArgument {
826                        message: format!(
827                            "unexpected key in Transformer2DModel state_dict: \"{k}\""
828                        ),
829                    });
830                }
831            }
832        }
833        self.norm.load_state_dict(&extract("norm"), strict)?;
834        self.proj_in.load_state_dict(&extract("proj_in"), strict)?;
835        for (i, b) in self.transformer_blocks.iter_mut().enumerate() {
836            b.load_state_dict(&extract(&format!("transformer_blocks.{i}")), strict)?;
837        }
838        self.proj_out
839            .load_state_dict(&extract("proj_out"), strict)?;
840        Ok(())
841    }
842}
843
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Builds a CPU `f32` tensor of the given shape with every element set to
    /// `0.01`, passing `false` as the final `from_storage` flag.
    fn filled(shape: &[usize]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; numel]),
            shape.to_vec(),
            false,
        )
        .unwrap()
    }

    /// Panics unless every entry of `expected` appears verbatim in `names`.
    fn assert_has_keys(names: &[String], expected: &[&str]) {
        for &k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    /// Self-attention keeps the `[batch, tokens, query_dim]` shape of its input.
    #[test]
    fn attention_self_shape() {
        let attn = Attention::<f32>::new(16, None, 4, 4, false).unwrap();
        let hidden = filled(&[1, 5, 16]);
        let out = attn.forward_xattn(&hidden, None).unwrap();
        assert_eq!(out.shape(), &[1, 5, 16]);
    }

    /// Cross-attention output follows the query's shape, not the context's.
    #[test]
    fn attention_cross_shape() {
        let attn = Attention::<f32>::new(16, Some(24), 4, 4, false).unwrap();
        let hidden = filled(&[1, 5, 16]);
        let context = filled(&[1, 7, 24]);
        let out = attn.forward_xattn(&hidden, Some(&context)).unwrap();
        assert_eq!(out.shape(), &[1, 5, 16]);
    }

    /// The feed-forward block preserves shape and exposes the expected
    /// `net.0.proj.*` / `net.2.*` parameter names.
    #[test]
    fn feedforward_shape_and_keys() {
        let ff = FeedForward::<f32>::new(16, 2).unwrap();
        let hidden = filled(&[1, 5, 16]);
        let out = ff.forward(&hidden).unwrap();
        assert_eq!(out.shape(), &[1, 5, 16]);

        let names: Vec<String> = ff.named_parameters().into_iter().map(|(n, _)| n).collect();
        assert_has_keys(
            &names,
            &[
                "net.0.proj.weight",
                "net.0.proj.bias",
                "net.2.weight",
                "net.2.bias",
            ],
        );
    }

    /// A full transformer block returns a tensor shaped like its hidden input.
    #[test]
    fn basic_transformer_block_shape() {
        let block = BasicTransformerBlock::<f32>::new(16, 4, 4, 24).unwrap();
        let hidden = filled(&[1, 5, 16]);
        let context = filled(&[1, 7, 24]);
        let out = block.forward_xattn(&hidden, &context).unwrap();
        assert_eq!(out.shape(), &[1, 5, 16]);
    }

    /// The 2D wrapper maps a `[1, 16, 3, 3]` input back to the same 4-D shape.
    #[test]
    fn transformer_2d_shape() {
        let model = Transformer2DModel::<f32>::new(16, 4, 4, 1, 24, 4).unwrap();
        let image = filled(&[1, 16, 3, 3]);
        let context = filled(&[1, 5, 24]);
        let out = model.forward_xattn(&image, &context).unwrap();
        assert_eq!(out.shape(), &[1, 16, 3, 3]);
    }

    /// Parameter names of the 2D wrapper carry the expected dotted prefixes.
    #[test]
    fn transformer_2d_named_parameters() {
        let model = Transformer2DModel::<f32>::new(16, 4, 4, 1, 24, 4).unwrap();
        let names: Vec<String> = model
            .named_parameters()
            .into_iter()
            .map(|(n, _)| n)
            .collect();
        assert_has_keys(
            &names,
            &[
                "norm.weight",
                "proj_in.weight",
                "proj_in.bias",
                "transformer_blocks.0.norm1.weight",
                "transformer_blocks.0.attn1.to_q.weight",
                "transformer_blocks.0.attn2.to_k.weight",
                "transformer_blocks.0.ff.net.0.proj.weight",
                "transformer_blocks.0.ff.net.2.weight",
                "proj_out.weight",
            ],
        );
    }
}