//! Attention-based building blocks for multi-view diffusion.
//!
//! Implements the multi-view transformer block that replaces the standard
//! SD 2.1 `BasicTransformerBlock` with additional layers:
//!
//! ## Multi-View Transformer Architecture
//!
//! Each `MultiViewTransformerBlock` contains five sequential operations:
//!
//! 1. **Self-Attention** (`attn1`): Attention within each view's spatial tokens
//! 2. **Cross-View Attention** (`attn_cv`): Attention across all N views at each
//!    spatial position, enabling 3D consistency
//! 3. **Text Cross-Attention** (`attn2`): Conditions on text embeddings
//!    (always zero in GAF since we don't use text prompts)
//! 4. **IP-Adapter Cross-Attention** (`attn_ip`): Conditions on CLIP image
//!    embeddings from the reference photo, providing identity preservation
//! 5. **Feed-Forward** (`ff`): GeGLU-activated MLP for feature processing
//!
//! ## IP-Adapter Mechanism
//!
//! The IP-Adapter layer enables pixel-level identity conditioning:
//!
//! - **Input**: CLIP ViT-H/14 encodes reference image → 257×1280 embeddings
//! - **Projection**: Linear layer projects to cross_attention_dim (1024)
//! - **Attention**: Each spatial position (h×w) attends to 257 image tokens
//! - **Output**: Spatially-varying conditioning based on reference features
//!
//! When `ip_tokens=None` (CFG unconditional pass), the IP-Adapter layer is
//! skipped entirely via early return, producing unconditional predictions.
//!
//! ## Flash Attention Support
//!
//! When the `flash_attention` feature is enabled, attention modules can use
//! memory-efficient flash attention with O(N) memory complexity instead of
//! O(N²). This is controlled via the `use_flash_attention` field in
//! `DiffusionConfig`.
//!
//! Flash attention provides 2-4× memory reduction for large images without
//! sacrificing accuracy (< 1e-3 numerical difference from standard attention).
use candle_core::{DType, Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;

#[cfg(feature = "flash_attention")]
use crate::flash_attention::{FlashAttention, FlashAttentionConfig};

48// ---------------------------------------------------------------------------
49// GeGLU activation
50// ---------------------------------------------------------------------------
51
52#[derive(Debug)]
53struct GeGlu {
54    proj: nn::Linear,
55}
56
57impl GeGlu {
58    fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
59        let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
60        Ok(Self { proj })
61    }
62}
63
64impl Module for GeGlu {
65    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
66        let hidden_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
67        &hidden_and_gate[0] * hidden_and_gate[1].gelu()?
68    }
69}
70
71// ---------------------------------------------------------------------------
72// Feed-forward network
73// ---------------------------------------------------------------------------
74
75#[derive(Debug)]
76struct FeedForward {
77    project_in: GeGlu,
78    linear_out: nn::Linear,
79}
80
81impl FeedForward {
82    fn new(vs: nn::VarBuilder, dim: usize, mult: usize) -> Result<Self> {
83        let inner_dim = dim * mult;
84        let vs = vs.pp("net");
85        let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
86        let linear_out = nn::linear(inner_dim, dim, vs.pp("2"))?;
87        Ok(Self {
88            project_in,
89            linear_out,
90        })
91    }
92}
93
94impl Module for FeedForward {
95    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
96        let xs = self.project_in.forward(xs)?;
97        self.linear_out.forward(&xs)
98    }
99}
100
101// ---------------------------------------------------------------------------
102// Cross-attention (used for self-attn, text cross-attn, cross-view, IP)
103// ---------------------------------------------------------------------------
104
105/// Cross-attention module with optional flash attention support.
106///
107/// When flash attention is enabled (via feature flag and configuration),
108/// uses memory-efficient O(N) block-wise attention computation instead
109/// of the standard O(N^2) attention matrix.
110#[derive(Debug)]
111pub struct CrossAttention {
112    to_q: nn::Linear,
113    to_k: nn::Linear,
114    to_v: nn::Linear,
115    to_out: nn::Linear,
116    heads: usize,
117    dim_head: usize,
118    scale: f64,
119    /// Flash attention module (when feature is enabled)
120    #[cfg(feature = "flash_attention")]
121    flash_attention: Option<FlashAttention>,
122    /// Whether to use flash attention for this module
123    use_flash_attention: bool,
124}
125
126impl CrossAttention {
127    /// Create a new cross-attention module with standard attention.
128    pub fn new(
129        vs: nn::VarBuilder,
130        query_dim: usize,
131        context_dim: Option<usize>,
132        heads: usize,
133        dim_head: usize,
134    ) -> Result<Self> {
135        Self::new_with_flash(vs, query_dim, context_dim, heads, dim_head, false, 64)
136    }
137
138    /// Create a new cross-attention module with optional flash attention.
139    ///
140    /// # Arguments
141    ///
142    /// * `vs` - Variable builder for weight initialization
143    /// * `query_dim` - Query input dimension
144    /// * `context_dim` - Context dimension (None for self-attention)
145    /// * `heads` - Number of attention heads
146    /// * `dim_head` - Dimension per head
147    /// * `use_flash_attention` - Whether to use flash attention
148    /// * `flash_block_size` - Block size for flash attention tiling
149    #[allow(unused_variables)]
150    pub fn new_with_flash(
151        vs: nn::VarBuilder,
152        query_dim: usize,
153        context_dim: Option<usize>,
154        heads: usize,
155        dim_head: usize,
156        use_flash_attention: bool,
157        flash_block_size: usize,
158    ) -> Result<Self> {
159        let inner_dim = dim_head * heads;
160        let context_dim = context_dim.unwrap_or(query_dim);
161        let scale = 1.0 / (dim_head as f64).sqrt();
162        let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
163        let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
164        let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
165        let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
166
167        // Initialize flash attention if feature is enabled and requested
168        #[cfg(feature = "flash_attention")]
169        let flash_attention = if use_flash_attention {
170            let config = FlashAttentionConfig::with_block_size(flash_block_size);
171            Some(FlashAttention::new(dim_head, config))
172        } else {
173            None
174        };
175
176        Ok(Self {
177            to_q,
178            to_k,
179            to_v,
180            to_out,
181            heads,
182            dim_head,
183            scale,
184            #[cfg(feature = "flash_attention")]
185            flash_attention,
186            use_flash_attention,
187        })
188    }
189
190    /// Scaled-dot-product attention (standard or flash based on configuration).
191    ///
192    /// Automatically dispatches to flash attention when enabled and the feature
193    /// is available, otherwise uses standard O(N^2) attention.
194    pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
195        let context = context.unwrap_or(xs);
196        let (b, seq_len, _) = xs.dims3()?;
197        let q = self.to_q.forward(xs)?;
198        let k = self.to_k.forward(context)?;
199        let v = self.to_v.forward(context)?;
200
201        // Reshape to (B, heads, seq, dim_head) and make contiguous for matmul
202        let q = q
203            .reshape((b, seq_len, self.heads, self.dim_head))?
204            .transpose(1, 2)?
205            .contiguous()?;
206        let ctx_len = k.dim(1)?;
207        let k = k
208            .reshape((b, ctx_len, self.heads, self.dim_head))?
209            .transpose(1, 2)?
210            .contiguous()?;
211        let v = v
212            .reshape((b, ctx_len, self.heads, self.dim_head))?
213            .transpose(1, 2)?
214            .contiguous()?;
215
216        // Dispatch to flash attention or standard attention
217        #[cfg(feature = "flash_attention")]
218        let out = if let Some(flash) = &self.flash_attention {
219            flash.forward(&q, &k, &v)?
220        } else {
221            self.standard_attention(&q, &k, &v)?
222        };
223
224        #[cfg(not(feature = "flash_attention"))]
225        let out = self.standard_attention(&q, &k, &v)?;
226
227        // Reshape back to (B, seq, inner_dim)
228        let out = out
229            .transpose(1, 2)?
230            .contiguous()?
231            .reshape((b, seq_len, ()))?;
232        self.to_out.forward(&out)
233    }
234
235    /// Standard O(N^2) scaled-dot-product attention.
236    ///
237    /// Computes the full attention matrix. Used as fallback when flash
238    /// attention is disabled or unavailable.
239    fn standard_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
240        // Compute attention in f32 for numerical stability
241        let in_dtype = q.dtype();
242        let q = q.to_dtype(DType::F32)?;
243        let k = k.to_dtype(DType::F32)?;
244        let v = v.to_dtype(DType::F32)?;
245
246        let k_t = k.transpose(D::Minus2, D::Minus1)?.contiguous()?;
247        let attn = (q.matmul(&k_t)? * self.scale)?;
248        let attn = nn::ops::softmax_last_dim(&attn)?;
249        attn.matmul(&v)?.to_dtype(in_dtype)
250    }
251
252    /// Check if flash attention is enabled for this module.
253    ///
254    /// Returns `true` only if flash attention was requested during construction
255    /// AND the `flash_attention` feature is enabled.
256    pub fn is_flash_attention_enabled(&self) -> bool {
257        #[cfg(feature = "flash_attention")]
258        {
259            self.use_flash_attention && self.flash_attention.is_some()
260        }
261        #[cfg(not(feature = "flash_attention"))]
262        {
263            // Even if requested, flash attention is not available without the feature
264            let _ = self.use_flash_attention; // Suppress unused warning
265            false
266        }
267    }
268}
269
270// ---------------------------------------------------------------------------
271// Multi-view transformer block
272// ---------------------------------------------------------------------------
273
274/// A transformer block with multi-view cross-attention support.
275///
276/// Each block contains:
277/// 1. Self-attention (within each view)
278/// 2. Cross-view attention (across all N views)
279/// 3. Text/prompt cross-attention
280/// 4. IP cross-attention (reference image CLIP embedding)
281/// 5. Feed-forward network
282#[derive(Debug)]
283pub struct MultiViewTransformerBlock {
284    /// LayerNorm before self-attention
285    norm1: nn::LayerNorm,
286    /// Self-attention
287    attn1: CrossAttention,
288    /// LayerNorm before cross-view attention
289    norm_cv: nn::LayerNorm,
290    /// Cross-view attention
291    attn_cv: CrossAttention,
292    /// LayerNorm before text cross-attention
293    norm2: nn::LayerNorm,
294    /// Text cross-attention
295    attn2: CrossAttention,
296    /// LayerNorm before IP cross-attention
297    norm_ip: nn::LayerNorm,
298    /// IP-adapter cross-attention
299    attn_ip: CrossAttention,
300    /// LayerNorm before FFN
301    norm3: nn::LayerNorm,
302    /// Feed-forward network
303    ff: FeedForward,
304    /// Number of views
305    num_views: usize,
306}
307
308impl MultiViewTransformerBlock {
309    /// Create a new multi-view transformer block with standard attention.
310    pub fn new(
311        vs: nn::VarBuilder,
312        dim: usize,
313        n_heads: usize,
314        d_head: usize,
315        context_dim: usize,
316        ip_dim: usize,
317        num_views: usize,
318    ) -> Result<Self> {
319        Self::new_with_flash(
320            vs,
321            dim,
322            n_heads,
323            d_head,
324            context_dim,
325            ip_dim,
326            num_views,
327            false,
328            64,
329        )
330    }
331
332    /// Create a new multi-view transformer block with optional flash attention.
333    ///
334    /// # Arguments
335    ///
336    /// * `vs` - Variable builder for weight initialization
337    /// * `dim` - Hidden dimension
338    /// * `n_heads` - Number of attention heads
339    /// * `d_head` - Dimension per head
340    /// * `context_dim` - Text cross-attention context dimension
341    /// * `ip_dim` - IP-adapter context dimension
342    /// * `num_views` - Number of views for cross-view attention
343    /// * `use_flash_attention` - Whether to use flash attention
344    /// * `flash_block_size` - Block size for flash attention tiling
345    #[allow(clippy::too_many_arguments)]
346    pub fn new_with_flash(
347        vs: nn::VarBuilder,
348        dim: usize,
349        n_heads: usize,
350        d_head: usize,
351        context_dim: usize,
352        ip_dim: usize,
353        num_views: usize,
354        use_flash_attention: bool,
355        flash_block_size: usize,
356    ) -> Result<Self> {
357        let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
358        let attn1 = CrossAttention::new_with_flash(
359            vs.pp("attn1"),
360            dim,
361            None,
362            n_heads,
363            d_head,
364            use_flash_attention,
365            flash_block_size,
366        )?;
367
368        let norm_cv = nn::layer_norm(dim, 1e-5, vs.pp("norm_cv"))?;
369        // Cross-view attention typically has small sequence length (num_views),
370        // so flash attention may not be beneficial here
371        let attn_cv = CrossAttention::new(vs.pp("attn_cv"), dim, None, n_heads, d_head)?;
372
373        let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
374        let attn2 = CrossAttention::new_with_flash(
375            vs.pp("attn2"),
376            dim,
377            Some(context_dim),
378            n_heads,
379            d_head,
380            use_flash_attention,
381            flash_block_size,
382        )?;
383
384        let norm_ip = nn::layer_norm(dim, 1e-5, vs.pp("norm_ip"))?;
385        let attn_ip = CrossAttention::new_with_flash(
386            vs.pp("attn_ip"),
387            dim,
388            Some(ip_dim),
389            n_heads,
390            d_head,
391            use_flash_attention,
392            flash_block_size,
393        )?;
394
395        let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
396        let ff = FeedForward::new(vs.pp("ff"), dim, 4)?;
397
398        Ok(Self {
399            norm1,
400            attn1,
401            norm_cv,
402            attn_cv,
403            norm2,
404            attn2,
405            norm_ip,
406            attn_ip,
407            norm3,
408            ff,
409            num_views,
410        })
411    }
412
413    /// Forward pass.
414    ///
415    /// - `xs`: `(B*num_views, seq_len, dim)` — spatial tokens for all views (batched)
416    /// - `context`: `(B*num_views, ctx_len, context_dim)` — text encoder hidden states
417    /// - `ip_tokens`: `(B*num_views, ip_len, ip_dim)` — CLIP image embedding tokens
418    pub fn forward(
419        &self,
420        xs: &Tensor,
421        context: Option<&Tensor>,
422        ip_tokens: Option<&Tensor>,
423    ) -> Result<Tensor> {
424        let (bv, seq_len, dim) = xs.dims3()?;
425        let b = bv / self.num_views;
426
427        // 1. Self-attention (per-view)
428        let residual = xs;
429        let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + residual)?;
430
431        // 2. Cross-view attention
432        // Reshape so each position can attend across all views
433        let residual = &xs;
434        let normed = self.norm_cv.forward(&xs)?;
435        // (B*V, S, D) -> (B, V, S, D) -> (B, S, V, D) -> (B*S, V, D)
436        let cv_input = normed
437            .reshape((b, self.num_views, seq_len, dim))?
438            .transpose(1, 2)?
439            .reshape((b * seq_len, self.num_views, dim))?;
440        let cv_out = self.attn_cv.forward(&cv_input, None)?;
441        // (B*S, V, D) -> (B, S, V, D) -> (B, V, S, D) -> (B*V, S, D)
442        let cv_out = cv_out
443            .reshape((b, seq_len, self.num_views, dim))?
444            .transpose(1, 2)?
445            .reshape((bv, seq_len, dim))?;
446        let xs = (cv_out + residual)?;
447
448        // 3. Text cross-attention
449        let residual = &xs;
450        let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + residual)?;
451
452        // 4. IP cross-attention (reference image conditioning)
453        let xs = if let Some(ip) = ip_tokens {
454            let residual = &xs;
455            (self
456                .attn_ip
457                .forward(&self.norm_ip.forward(&xs)?, Some(ip))?
458                + residual)?
459        } else {
460            xs
461        };
462
463        // 5. Feed-forward
464        let residual = &xs;
465        self.ff.forward(&self.norm3.forward(&xs)?)? + residual
466    }
467}
468
469// ---------------------------------------------------------------------------
470// Multi-view spatial transformer (wraps projection + transformer blocks)
471// ---------------------------------------------------------------------------
472
473/// A spatial transformer that includes multi-view attention in every block.
474/// Replaces the standard `SpatialTransformer` from SD 2.1.
475#[derive(Debug)]
476pub struct MultiViewSpatialTransformer {
477    norm: nn::GroupNorm,
478    proj_in: nn::Linear,
479    transformer_blocks: Vec<MultiViewTransformerBlock>,
480    proj_out: nn::Linear,
481    use_linear_projection: bool,
482}
483
484impl MultiViewSpatialTransformer {
485    /// Create a new multi-view spatial transformer with standard attention.
486    #[allow(clippy::too_many_arguments)]
487    pub fn new(
488        vs: nn::VarBuilder,
489        in_channels: usize,
490        n_heads: usize,
491        d_head: usize,
492        depth: usize,
493        context_dim: usize,
494        ip_dim: usize,
495        num_views: usize,
496        num_groups: usize,
497        use_linear_projection: bool,
498    ) -> Result<Self> {
499        Self::new_with_flash(
500            vs,
501            in_channels,
502            n_heads,
503            d_head,
504            depth,
505            context_dim,
506            ip_dim,
507            num_views,
508            num_groups,
509            use_linear_projection,
510            false,
511            64,
512        )
513    }
514
515    /// Create a new multi-view spatial transformer with optional flash attention.
516    ///
517    /// # Arguments
518    ///
519    /// * `vs` - Variable builder for weight initialization
520    /// * `in_channels` - Number of input channels
521    /// * `n_heads` - Number of attention heads
522    /// * `d_head` - Dimension per head
523    /// * `depth` - Number of transformer blocks
524    /// * `context_dim` - Text cross-attention context dimension
525    /// * `ip_dim` - IP-adapter context dimension
526    /// * `num_views` - Number of views for cross-view attention
527    /// * `num_groups` - Number of groups for group normalization
528    /// * `use_linear_projection` - Whether to use linear projection
529    /// * `use_flash_attention` - Whether to use flash attention
530    /// * `flash_block_size` - Block size for flash attention tiling
531    #[allow(clippy::too_many_arguments)]
532    pub fn new_with_flash(
533        vs: nn::VarBuilder,
534        in_channels: usize,
535        n_heads: usize,
536        d_head: usize,
537        depth: usize,
538        context_dim: usize,
539        ip_dim: usize,
540        num_views: usize,
541        num_groups: usize,
542        use_linear_projection: bool,
543        use_flash_attention: bool,
544        flash_block_size: usize,
545    ) -> Result<Self> {
546        let inner_dim = n_heads * d_head;
547        let norm = nn::group_norm(num_groups, in_channels, 1e-6, vs.pp("norm"))?;
548        let proj_in = nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?;
549        let proj_out = nn::linear(inner_dim, in_channels, vs.pp("proj_out"))?;
550
551        let vs_tb = vs.pp("transformer_blocks");
552        let mut transformer_blocks = Vec::with_capacity(depth);
553        for i in 0..depth {
554            transformer_blocks.push(MultiViewTransformerBlock::new_with_flash(
555                vs_tb.pp(i.to_string()),
556                inner_dim,
557                n_heads,
558                d_head,
559                context_dim,
560                ip_dim,
561                num_views,
562                use_flash_attention,
563                flash_block_size,
564            )?);
565        }
566
567        Ok(Self {
568            norm,
569            proj_in,
570            transformer_blocks,
571            proj_out,
572            use_linear_projection,
573        })
574    }
575
576    /// Forward pass.
577    ///
578    /// - `xs`: `(B*V, C, H, W)` feature map
579    /// - `context`: optional text cross-attention context
580    /// - `ip_tokens`: optional IP-adapter tokens
581    pub fn forward(
582        &self,
583        xs: &Tensor,
584        context: Option<&Tensor>,
585        ip_tokens: Option<&Tensor>,
586    ) -> Result<Tensor> {
587        let (batch, _channel, height, width) = xs.dims4()?;
588        let residual = xs;
589
590        let xs = self.norm.forward(xs)?;
591        // Flatten spatial dims and optionally project
592        let inner_dim = if self.use_linear_projection {
593            let inner_dim = xs.dim(1)?;
594            let xs_flat =
595                xs.transpose(1, 2)?
596                    .transpose(2, 3)?
597                    .reshape((batch, height * width, inner_dim))?;
598            let xs_proj = self.proj_in.forward(&xs_flat)?;
599            // Process through transformer blocks
600            let mut h = xs_proj;
601            for block in &self.transformer_blocks {
602                h = block.forward(&h, context, ip_tokens)?;
603            }
604            let h = self.proj_out.forward(&h)?;
605            let result = h
606                .reshape((batch, height, width, inner_dim))?
607                .transpose(2, 3)?
608                .transpose(1, 2)?;
609            return result + residual;
610        } else {
611            xs.dim(1)?
612        };
613
614        // Conv-style projection path (for completeness, though SD 2.1 uses linear)
615        let xs_flat =
616            xs.transpose(1, 2)?
617                .transpose(2, 3)?
618                .reshape((batch, height * width, inner_dim))?;
619        let xs_proj = self.proj_in.forward(&xs_flat)?;
620        let mut h = xs_proj;
621        for block in &self.transformer_blocks {
622            h = block.forward(&h, context, ip_tokens)?;
623        }
624        let h = self.proj_out.forward(&h)?;
625        let result = h
626            .reshape((batch, height, width, inner_dim))?
627            .transpose(2, 3)?
628            .transpose(1, 2)?;
629        result + residual
630    }
631}