oxigaf_diffusion/attention.rs
1//! Attention-based building blocks for multi-view diffusion.
2//!
3//! Implements the multi-view transformer block that replaces the standard
4//! SD 2.1 `BasicTransformerBlock` with additional layers:
5//!
6//! ## Multi-View Transformer Architecture
7//!
8//! Each `MultiViewTransformerBlock` contains five sequential operations:
9//!
10//! 1. **Self-Attention** (`attn1`): Attention within each view's spatial tokens
11//! 2. **Cross-View Attention** (`attn_cv`): Attention across all N views at each
12//! spatial position, enabling 3D consistency
13//! 3. **Text Cross-Attention** (`attn2`): Conditions on text embeddings
14//! (always zero in GAF since we don't use text prompts)
15//! 4. **IP-Adapter Cross-Attention** (`attn_ip`): Conditions on CLIP image
16//! embeddings from the reference photo, providing identity preservation
17//! 5. **Feed-Forward** (`ff`): GeGLU-activated MLP for feature processing
18//!
19//! ## IP-Adapter Mechanism
20//!
21//! The IP-Adapter layer enables pixel-level identity conditioning:
22//!
23//! - **Input**: CLIP ViT-H/14 encodes reference image → 257×1280 embeddings
24//! - **Projection**: Linear layer projects to cross_attention_dim (1024)
25//! - **Attention**: Each spatial position (h×w) attends to 257 image tokens
26//! - **Output**: Spatially-varying conditioning based on reference features
27//!
28//! When `ip_tokens=None` (CFG unconditional pass), the IP-Adapter layer is
29//! skipped entirely via early return, producing unconditional predictions.
30//!
31//! ## Flash Attention Support
32//!
33//! When the `flash_attention` feature is enabled, attention modules can use
34//! memory-efficient flash attention with O(N) memory complexity instead of
35//! O(N²). This is controlled via the `use_flash_attention` field in
36//! `DiffusionConfig`.
37//!
38//! Flash attention provides 2-4× memory reduction for large images without
39//! sacrificing accuracy (< 1e-3 numerical difference from standard attention).
40
41use candle_core::{DType, Result, Tensor, D};
42use candle_nn as nn;
43use candle_nn::Module;
44
45#[cfg(feature = "flash_attention")]
46use crate::flash_attention::{FlashAttention, FlashAttentionConfig};
47
48// ---------------------------------------------------------------------------
49// GeGLU activation
50// ---------------------------------------------------------------------------
51
52#[derive(Debug)]
53struct GeGlu {
54 proj: nn::Linear,
55}
56
57impl GeGlu {
58 fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
59 let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
60 Ok(Self { proj })
61 }
62}
63
64impl Module for GeGlu {
65 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
66 let hidden_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
67 &hidden_and_gate[0] * hidden_and_gate[1].gelu()?
68 }
69}
70
71// ---------------------------------------------------------------------------
72// Feed-forward network
73// ---------------------------------------------------------------------------
74
75#[derive(Debug)]
76struct FeedForward {
77 project_in: GeGlu,
78 linear_out: nn::Linear,
79}
80
81impl FeedForward {
82 fn new(vs: nn::VarBuilder, dim: usize, mult: usize) -> Result<Self> {
83 let inner_dim = dim * mult;
84 let vs = vs.pp("net");
85 let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
86 let linear_out = nn::linear(inner_dim, dim, vs.pp("2"))?;
87 Ok(Self {
88 project_in,
89 linear_out,
90 })
91 }
92}
93
94impl Module for FeedForward {
95 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
96 let xs = self.project_in.forward(xs)?;
97 self.linear_out.forward(&xs)
98 }
99}
100
101// ---------------------------------------------------------------------------
102// Cross-attention (used for self-attn, text cross-attn, cross-view, IP)
103// ---------------------------------------------------------------------------
104
105/// Cross-attention module with optional flash attention support.
106///
107/// When flash attention is enabled (via feature flag and configuration),
108/// uses memory-efficient O(N) block-wise attention computation instead
109/// of the standard O(N^2) attention matrix.
/// Cross-attention module with optional flash attention support.
///
/// When flash attention is enabled (via feature flag and configuration),
/// uses memory-efficient O(N) block-wise attention computation instead
/// of the standard O(N^2) attention matrix.
///
/// Also serves as self-attention when constructed with `context_dim = None`.
#[derive(Debug)]
pub struct CrossAttention {
    /// Query projection: `query_dim -> heads * dim_head`, no bias.
    to_q: nn::Linear,
    /// Key projection: `context_dim -> heads * dim_head`, no bias.
    to_k: nn::Linear,
    /// Value projection: `context_dim -> heads * dim_head`, no bias.
    to_v: nn::Linear,
    /// Output projection back to `query_dim`; weights loaded from `to_out.0`.
    to_out: nn::Linear,
    /// Number of attention heads.
    heads: usize,
    /// Per-head dimension.
    dim_head: usize,
    /// Softmax scaling factor, `1 / sqrt(dim_head)`.
    scale: f64,
    /// Flash attention module (when feature is enabled)
    #[cfg(feature = "flash_attention")]
    flash_attention: Option<FlashAttention>,
    /// Whether to use flash attention for this module
    use_flash_attention: bool,
}
125
impl CrossAttention {
    /// Create a new cross-attention module with standard attention.
    ///
    /// Convenience constructor; delegates to [`CrossAttention::new_with_flash`]
    /// with flash attention disabled (the block-size argument is then unused).
    pub fn new(
        vs: nn::VarBuilder,
        query_dim: usize,
        context_dim: Option<usize>,
        heads: usize,
        dim_head: usize,
    ) -> Result<Self> {
        Self::new_with_flash(vs, query_dim, context_dim, heads, dim_head, false, 64)
    }

    /// Create a new cross-attention module with optional flash attention.
    ///
    /// # Arguments
    ///
    /// * `vs` - Variable builder for weight initialization
    /// * `query_dim` - Query input dimension
    /// * `context_dim` - Context dimension (None for self-attention)
    /// * `heads` - Number of attention heads
    /// * `dim_head` - Dimension per head
    /// * `use_flash_attention` - Whether to use flash attention
    /// * `flash_block_size` - Block size for flash attention tiling
    // The allow is needed because `use_flash_attention`/`flash_block_size` are
    // only read when the `flash_attention` feature is compiled in.
    #[allow(unused_variables)]
    pub fn new_with_flash(
        vs: nn::VarBuilder,
        query_dim: usize,
        context_dim: Option<usize>,
        heads: usize,
        dim_head: usize,
        use_flash_attention: bool,
        flash_block_size: usize,
    ) -> Result<Self> {
        let inner_dim = dim_head * heads;
        // Self-attention when no context dim is provided: keys and values are
        // projected from the query space.
        let context_dim = context_dim.unwrap_or(query_dim);
        let scale = 1.0 / (dim_head as f64).sqrt();
        // q/k/v have no bias; the output projection does and is stored under
        // `to_out.0` in the checkpoint.
        let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
        let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
        let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
        let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;

        // Initialize flash attention if feature is enabled and requested
        #[cfg(feature = "flash_attention")]
        let flash_attention = if use_flash_attention {
            let config = FlashAttentionConfig::with_block_size(flash_block_size);
            Some(FlashAttention::new(dim_head, config))
        } else {
            None
        };

        Ok(Self {
            to_q,
            to_k,
            to_v,
            to_out,
            heads,
            dim_head,
            scale,
            #[cfg(feature = "flash_attention")]
            flash_attention,
            use_flash_attention,
        })
    }

    /// Scaled-dot-product attention (standard or flash based on configuration).
    ///
    /// Automatically dispatches to flash attention when enabled and the feature
    /// is available, otherwise uses standard O(N^2) attention.
    ///
    /// - `xs`: `(B, seq_len, query_dim)` queries
    /// - `context`: optional `(B, ctx_len, context_dim)` keys/values source;
    ///   when `None` this is self-attention over `xs`.
    pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
        let context = context.unwrap_or(xs);
        let (b, seq_len, _) = xs.dims3()?;
        let q = self.to_q.forward(xs)?;
        let k = self.to_k.forward(context)?;
        let v = self.to_v.forward(context)?;

        // Reshape to (B, heads, seq, dim_head) and make contiguous for matmul
        let q = q
            .reshape((b, seq_len, self.heads, self.dim_head))?
            .transpose(1, 2)?
            .contiguous()?;
        // Context length may differ from the query length in cross-attention.
        let ctx_len = k.dim(1)?;
        let k = k
            .reshape((b, ctx_len, self.heads, self.dim_head))?
            .transpose(1, 2)?
            .contiguous()?;
        let v = v
            .reshape((b, ctx_len, self.heads, self.dim_head))?
            .transpose(1, 2)?
            .contiguous()?;

        // Dispatch to flash attention or standard attention
        // NOTE(review): the flash path is assumed to apply the 1/sqrt(dim_head)
        // scaling internally (it receives `dim_head` at construction) — confirm
        // against the `FlashAttention` implementation.
        #[cfg(feature = "flash_attention")]
        let out = if let Some(flash) = &self.flash_attention {
            flash.forward(&q, &k, &v)?
        } else {
            self.standard_attention(&q, &k, &v)?
        };

        #[cfg(not(feature = "flash_attention"))]
        let out = self.standard_attention(&q, &k, &v)?;

        // Reshape back to (B, seq, inner_dim)
        let out = out
            .transpose(1, 2)?
            .contiguous()?
            .reshape((b, seq_len, ()))?;
        self.to_out.forward(&out)
    }

    /// Standard O(N^2) scaled-dot-product attention.
    ///
    /// Computes the full attention matrix. Used as fallback when flash
    /// attention is disabled or unavailable.
    ///
    /// Inputs are `(B, heads, len, dim_head)`; output has the query's shape
    /// and dtype.
    fn standard_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
        // Compute attention in f32 for numerical stability
        let in_dtype = q.dtype();
        let q = q.to_dtype(DType::F32)?;
        let k = k.to_dtype(DType::F32)?;
        let v = v.to_dtype(DType::F32)?;

        let k_t = k.transpose(D::Minus2, D::Minus1)?.contiguous()?;
        let attn = (q.matmul(&k_t)? * self.scale)?;
        let attn = nn::ops::softmax_last_dim(&attn)?;
        // Cast back to the caller's dtype (e.g. f16) after the softmax/matmul.
        attn.matmul(&v)?.to_dtype(in_dtype)
    }

    /// Check if flash attention is enabled for this module.
    ///
    /// Returns `true` only if flash attention was requested during construction
    /// AND the `flash_attention` feature is enabled.
    pub fn is_flash_attention_enabled(&self) -> bool {
        #[cfg(feature = "flash_attention")]
        {
            self.use_flash_attention && self.flash_attention.is_some()
        }
        #[cfg(not(feature = "flash_attention"))]
        {
            // Even if requested, flash attention is not available without the feature
            let _ = self.use_flash_attention; // Suppress unused warning
            false
        }
    }
}
269
270// ---------------------------------------------------------------------------
271// Multi-view transformer block
272// ---------------------------------------------------------------------------
273
274/// A transformer block with multi-view cross-attention support.
275///
276/// Each block contains:
277/// 1. Self-attention (within each view)
278/// 2. Cross-view attention (across all N views)
279/// 3. Text/prompt cross-attention
280/// 4. IP cross-attention (reference image CLIP embedding)
281/// 5. Feed-forward network
/// A transformer block with multi-view cross-attention support.
///
/// Each block contains:
/// 1. Self-attention (within each view)
/// 2. Cross-view attention (across all N views)
/// 3. Text/prompt cross-attention
/// 4. IP cross-attention (reference image CLIP embedding)
/// 5. Feed-forward network
///
/// All five sub-layers use pre-norm residual connections
/// (`x + sublayer(norm(x))`); see `forward`.
#[derive(Debug)]
pub struct MultiViewTransformerBlock {
    /// LayerNorm before self-attention
    norm1: nn::LayerNorm,
    /// Self-attention
    attn1: CrossAttention,
    /// LayerNorm before cross-view attention
    norm_cv: nn::LayerNorm,
    /// Cross-view attention
    attn_cv: CrossAttention,
    /// LayerNorm before text cross-attention
    norm2: nn::LayerNorm,
    /// Text cross-attention
    attn2: CrossAttention,
    /// LayerNorm before IP cross-attention
    norm_ip: nn::LayerNorm,
    /// IP-adapter cross-attention (skipped when no IP tokens are provided)
    attn_ip: CrossAttention,
    /// LayerNorm before FFN
    norm3: nn::LayerNorm,
    /// Feed-forward network
    ff: FeedForward,
    /// Number of views; the leading batch dimension must be `B * num_views`
    num_views: usize,
}
307
308impl MultiViewTransformerBlock {
309 /// Create a new multi-view transformer block with standard attention.
310 pub fn new(
311 vs: nn::VarBuilder,
312 dim: usize,
313 n_heads: usize,
314 d_head: usize,
315 context_dim: usize,
316 ip_dim: usize,
317 num_views: usize,
318 ) -> Result<Self> {
319 Self::new_with_flash(
320 vs,
321 dim,
322 n_heads,
323 d_head,
324 context_dim,
325 ip_dim,
326 num_views,
327 false,
328 64,
329 )
330 }
331
332 /// Create a new multi-view transformer block with optional flash attention.
333 ///
334 /// # Arguments
335 ///
336 /// * `vs` - Variable builder for weight initialization
337 /// * `dim` - Hidden dimension
338 /// * `n_heads` - Number of attention heads
339 /// * `d_head` - Dimension per head
340 /// * `context_dim` - Text cross-attention context dimension
341 /// * `ip_dim` - IP-adapter context dimension
342 /// * `num_views` - Number of views for cross-view attention
343 /// * `use_flash_attention` - Whether to use flash attention
344 /// * `flash_block_size` - Block size for flash attention tiling
345 #[allow(clippy::too_many_arguments)]
346 pub fn new_with_flash(
347 vs: nn::VarBuilder,
348 dim: usize,
349 n_heads: usize,
350 d_head: usize,
351 context_dim: usize,
352 ip_dim: usize,
353 num_views: usize,
354 use_flash_attention: bool,
355 flash_block_size: usize,
356 ) -> Result<Self> {
357 let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
358 let attn1 = CrossAttention::new_with_flash(
359 vs.pp("attn1"),
360 dim,
361 None,
362 n_heads,
363 d_head,
364 use_flash_attention,
365 flash_block_size,
366 )?;
367
368 let norm_cv = nn::layer_norm(dim, 1e-5, vs.pp("norm_cv"))?;
369 // Cross-view attention typically has small sequence length (num_views),
370 // so flash attention may not be beneficial here
371 let attn_cv = CrossAttention::new(vs.pp("attn_cv"), dim, None, n_heads, d_head)?;
372
373 let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
374 let attn2 = CrossAttention::new_with_flash(
375 vs.pp("attn2"),
376 dim,
377 Some(context_dim),
378 n_heads,
379 d_head,
380 use_flash_attention,
381 flash_block_size,
382 )?;
383
384 let norm_ip = nn::layer_norm(dim, 1e-5, vs.pp("norm_ip"))?;
385 let attn_ip = CrossAttention::new_with_flash(
386 vs.pp("attn_ip"),
387 dim,
388 Some(ip_dim),
389 n_heads,
390 d_head,
391 use_flash_attention,
392 flash_block_size,
393 )?;
394
395 let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
396 let ff = FeedForward::new(vs.pp("ff"), dim, 4)?;
397
398 Ok(Self {
399 norm1,
400 attn1,
401 norm_cv,
402 attn_cv,
403 norm2,
404 attn2,
405 norm_ip,
406 attn_ip,
407 norm3,
408 ff,
409 num_views,
410 })
411 }
412
413 /// Forward pass.
414 ///
415 /// - `xs`: `(B*num_views, seq_len, dim)` — spatial tokens for all views (batched)
416 /// - `context`: `(B*num_views, ctx_len, context_dim)` — text encoder hidden states
417 /// - `ip_tokens`: `(B*num_views, ip_len, ip_dim)` — CLIP image embedding tokens
418 pub fn forward(
419 &self,
420 xs: &Tensor,
421 context: Option<&Tensor>,
422 ip_tokens: Option<&Tensor>,
423 ) -> Result<Tensor> {
424 let (bv, seq_len, dim) = xs.dims3()?;
425 let b = bv / self.num_views;
426
427 // 1. Self-attention (per-view)
428 let residual = xs;
429 let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + residual)?;
430
431 // 2. Cross-view attention
432 // Reshape so each position can attend across all views
433 let residual = &xs;
434 let normed = self.norm_cv.forward(&xs)?;
435 // (B*V, S, D) -> (B, V, S, D) -> (B, S, V, D) -> (B*S, V, D)
436 let cv_input = normed
437 .reshape((b, self.num_views, seq_len, dim))?
438 .transpose(1, 2)?
439 .reshape((b * seq_len, self.num_views, dim))?;
440 let cv_out = self.attn_cv.forward(&cv_input, None)?;
441 // (B*S, V, D) -> (B, S, V, D) -> (B, V, S, D) -> (B*V, S, D)
442 let cv_out = cv_out
443 .reshape((b, seq_len, self.num_views, dim))?
444 .transpose(1, 2)?
445 .reshape((bv, seq_len, dim))?;
446 let xs = (cv_out + residual)?;
447
448 // 3. Text cross-attention
449 let residual = &xs;
450 let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + residual)?;
451
452 // 4. IP cross-attention (reference image conditioning)
453 let xs = if let Some(ip) = ip_tokens {
454 let residual = &xs;
455 (self
456 .attn_ip
457 .forward(&self.norm_ip.forward(&xs)?, Some(ip))?
458 + residual)?
459 } else {
460 xs
461 };
462
463 // 5. Feed-forward
464 let residual = &xs;
465 self.ff.forward(&self.norm3.forward(&xs)?)? + residual
466 }
467}
468
469// ---------------------------------------------------------------------------
470// Multi-view spatial transformer (wraps projection + transformer blocks)
471// ---------------------------------------------------------------------------
472
473/// A spatial transformer that includes multi-view attention in every block.
474/// Replaces the standard `SpatialTransformer` from SD 2.1.
/// A spatial transformer that includes multi-view attention in every block.
/// Replaces the standard `SpatialTransformer` from SD 2.1.
#[derive(Debug)]
pub struct MultiViewSpatialTransformer {
    /// GroupNorm applied to the incoming `(B*V, C, H, W)` feature map.
    norm: nn::GroupNorm,
    /// Projection from `in_channels` into the transformer inner dimension.
    proj_in: nn::Linear,
    /// `depth` stacked multi-view transformer blocks.
    transformer_blocks: Vec<MultiViewTransformerBlock>,
    /// Projection from the inner dimension back to `in_channels`.
    proj_out: nn::Linear,
    /// Selects the projection code path in `forward`.
    use_linear_projection: bool,
}
483
484impl MultiViewSpatialTransformer {
485 /// Create a new multi-view spatial transformer with standard attention.
486 #[allow(clippy::too_many_arguments)]
487 pub fn new(
488 vs: nn::VarBuilder,
489 in_channels: usize,
490 n_heads: usize,
491 d_head: usize,
492 depth: usize,
493 context_dim: usize,
494 ip_dim: usize,
495 num_views: usize,
496 num_groups: usize,
497 use_linear_projection: bool,
498 ) -> Result<Self> {
499 Self::new_with_flash(
500 vs,
501 in_channels,
502 n_heads,
503 d_head,
504 depth,
505 context_dim,
506 ip_dim,
507 num_views,
508 num_groups,
509 use_linear_projection,
510 false,
511 64,
512 )
513 }
514
515 /// Create a new multi-view spatial transformer with optional flash attention.
516 ///
517 /// # Arguments
518 ///
519 /// * `vs` - Variable builder for weight initialization
520 /// * `in_channels` - Number of input channels
521 /// * `n_heads` - Number of attention heads
522 /// * `d_head` - Dimension per head
523 /// * `depth` - Number of transformer blocks
524 /// * `context_dim` - Text cross-attention context dimension
525 /// * `ip_dim` - IP-adapter context dimension
526 /// * `num_views` - Number of views for cross-view attention
527 /// * `num_groups` - Number of groups for group normalization
528 /// * `use_linear_projection` - Whether to use linear projection
529 /// * `use_flash_attention` - Whether to use flash attention
530 /// * `flash_block_size` - Block size for flash attention tiling
531 #[allow(clippy::too_many_arguments)]
532 pub fn new_with_flash(
533 vs: nn::VarBuilder,
534 in_channels: usize,
535 n_heads: usize,
536 d_head: usize,
537 depth: usize,
538 context_dim: usize,
539 ip_dim: usize,
540 num_views: usize,
541 num_groups: usize,
542 use_linear_projection: bool,
543 use_flash_attention: bool,
544 flash_block_size: usize,
545 ) -> Result<Self> {
546 let inner_dim = n_heads * d_head;
547 let norm = nn::group_norm(num_groups, in_channels, 1e-6, vs.pp("norm"))?;
548 let proj_in = nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?;
549 let proj_out = nn::linear(inner_dim, in_channels, vs.pp("proj_out"))?;
550
551 let vs_tb = vs.pp("transformer_blocks");
552 let mut transformer_blocks = Vec::with_capacity(depth);
553 for i in 0..depth {
554 transformer_blocks.push(MultiViewTransformerBlock::new_with_flash(
555 vs_tb.pp(i.to_string()),
556 inner_dim,
557 n_heads,
558 d_head,
559 context_dim,
560 ip_dim,
561 num_views,
562 use_flash_attention,
563 flash_block_size,
564 )?);
565 }
566
567 Ok(Self {
568 norm,
569 proj_in,
570 transformer_blocks,
571 proj_out,
572 use_linear_projection,
573 })
574 }
575
576 /// Forward pass.
577 ///
578 /// - `xs`: `(B*V, C, H, W)` feature map
579 /// - `context`: optional text cross-attention context
580 /// - `ip_tokens`: optional IP-adapter tokens
581 pub fn forward(
582 &self,
583 xs: &Tensor,
584 context: Option<&Tensor>,
585 ip_tokens: Option<&Tensor>,
586 ) -> Result<Tensor> {
587 let (batch, _channel, height, width) = xs.dims4()?;
588 let residual = xs;
589
590 let xs = self.norm.forward(xs)?;
591 // Flatten spatial dims and optionally project
592 let inner_dim = if self.use_linear_projection {
593 let inner_dim = xs.dim(1)?;
594 let xs_flat =
595 xs.transpose(1, 2)?
596 .transpose(2, 3)?
597 .reshape((batch, height * width, inner_dim))?;
598 let xs_proj = self.proj_in.forward(&xs_flat)?;
599 // Process through transformer blocks
600 let mut h = xs_proj;
601 for block in &self.transformer_blocks {
602 h = block.forward(&h, context, ip_tokens)?;
603 }
604 let h = self.proj_out.forward(&h)?;
605 let result = h
606 .reshape((batch, height, width, inner_dim))?
607 .transpose(2, 3)?
608 .transpose(1, 2)?;
609 return result + residual;
610 } else {
611 xs.dim(1)?
612 };
613
614 // Conv-style projection path (for completeness, though SD 2.1 uses linear)
615 let xs_flat =
616 xs.transpose(1, 2)?
617 .transpose(2, 3)?
618 .reshape((batch, height * width, inner_dim))?;
619 let xs_proj = self.proj_in.forward(&xs_flat)?;
620 let mut h = xs_proj;
621 for block in &self.transformer_blocks {
622 h = block.forward(&h, context, ip_tokens)?;
623 }
624 let h = self.proj_out.forward(&h)?;
625 let result = h
626 .reshape((batch, height, width, inner_dim))?
627 .transpose(2, 3)?
628 .transpose(1, 2)?;
629 result + residual
630 }
631}