// oxigaf_diffusion/unet.rs — multi-view U-Net.
//! Multi-view U-Net with camera-conditioned cross-view attention.
//!
//! The U-Net follows the SD 2.1 architecture but replaces every spatial
//! transformer block with a `MultiViewSpatialTransformer` that adds:
//!
//! 1. **Cross-view attention**: allows spatial positions to attend across all views
//! 2. **IP-Adapter conditioning**: a dedicated cross-attention layer (`attn_ip`) that
//!    conditions on CLIP image embeddings from the reference photo
//! 3. **Camera-pose conditioning**: camera extrinsics added to the timestep embedding
//!
//! ## IP-Adapter Integration
//!
//! Each transformer block contains four attention layers:
//! - `attn1`: Self-attention (within view)
//! - `attn_cv`: Cross-view attention (across views)
//! - `attn2`: Text cross-attention (unused in GAF, always zero)
//! - `attn_ip`: IP-Adapter cross-attention (reference image conditioning)
//!
//! When `ip_tokens` is `None` (unconditional pass), the `attn_ip` layer is
//! skipped entirely, producing the unconditional prediction for CFG.
//!
//! ## Architecture Details
//!
//! The U-Net structure:
//! - **Encoder**: 4 downsampling stages (320 → 640 → 1280 → 1280 channels)
//! - **Bottleneck**: ResBlock + Attention + ResBlock at 1280 channels
//! - **Decoder**: 4 upsampling stages with skip connections
//! - **Output**: GroupNorm + Conv → 4-channel latent prediction
//!
//! Each stage contains 2 ResBlocks, each followed by a
//! `MultiViewSpatialTransformer` (when attention is enabled).
32use candle_core::{Result, Tensor};
33use candle_nn as nn;
34use candle_nn::Module;
35
36use crate::attention::MultiViewSpatialTransformer;
37use crate::camera::{timestep_embedding, CameraEmbedding, TimestepEmbedding};
38use crate::config::DiffusionConfig;
39use crate::DiffusionError;
40
41// ---------------------------------------------------------------------------
42// Building blocks
43// ---------------------------------------------------------------------------
44
/// ResNet block with time-step conditioning.
#[derive(Debug)]
struct ResBlock {
    /// GroupNorm applied before the first conv.
    norm1: nn::GroupNorm,
    /// 3x3 conv mapping `in_ch` to `out_ch`.
    conv1: nn::Conv2d,
    /// Linear projection of the time embedding to `out_ch` channels.
    time_emb_proj: nn::Linear,
    /// GroupNorm applied before the second conv.
    norm2: nn::GroupNorm,
    /// 3x3 conv at `out_ch`.
    conv2: nn::Conv2d,
    /// 1x1 shortcut conv; present only when `in_ch != out_ch`.
    residual_conv: Option<nn::Conv2d>,
}
55
56impl ResBlock {
57    fn new(vs: nn::VarBuilder, in_ch: usize, out_ch: usize, time_dim: usize) -> Result<Self> {
58        let norm1 = nn::group_norm(32, in_ch, 1e-5, vs.pp("norm1"))?;
59        let conv1 = nn::conv2d(
60            in_ch,
61            out_ch,
62            3,
63            nn::Conv2dConfig {
64                padding: 1,
65                ..Default::default()
66            },
67            vs.pp("conv1"),
68        )?;
69        let time_emb_proj = nn::linear(time_dim, out_ch, vs.pp("time_emb_proj"))?;
70        let norm2 = nn::group_norm(32, out_ch, 1e-5, vs.pp("norm2"))?;
71        let conv2 = nn::conv2d(
72            out_ch,
73            out_ch,
74            3,
75            nn::Conv2dConfig {
76                padding: 1,
77                ..Default::default()
78            },
79            vs.pp("conv2"),
80        )?;
81        let residual_conv = if in_ch != out_ch {
82            Some(nn::conv2d(
83                in_ch,
84                out_ch,
85                1,
86                Default::default(),
87                vs.pp("conv_shortcut"),
88            )?)
89        } else {
90            None
91        };
92        Ok(Self {
93            norm1,
94            conv1,
95            time_emb_proj,
96            norm2,
97            conv2,
98            residual_conv,
99        })
100    }
101
102    fn forward(&self, xs: &Tensor, time_emb: &Tensor) -> Result<Tensor> {
103        let residual = if let Some(ref conv) = self.residual_conv {
104            conv.forward(xs)?
105        } else {
106            xs.clone()
107        };
108        let h = self.norm1.forward(xs)?.silu()?;
109        let h = self.conv1.forward(&h)?;
110
111        // Add time embedding: project then unsqueeze spatial dims
112        let t = self.time_emb_proj.forward(&time_emb.silu()?)?;
113        let t = t.unsqueeze(2)?.unsqueeze(3)?;
114        let h = (h.clone() + t.broadcast_as(h.shape())?)?;
115
116        let h = self.norm2.forward(&h)?.silu()?;
117        let h = self.conv2.forward(&h)?;
118        h + residual
119    }
120}
121
/// Downsample with strided convolution.
#[derive(Debug)]
struct Downsample2d {
    /// 3x3 conv, stride 2, padding 1 — halves the spatial dims.
    conv: nn::Conv2d,
}
127
128impl Downsample2d {
129    fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
130        let conv = nn::conv2d(
131            channels,
132            channels,
133            3,
134            nn::Conv2dConfig {
135                stride: 2,
136                padding: 1,
137                ..Default::default()
138            },
139            vs.pp("conv"),
140        )?;
141        Ok(Self { conv })
142    }
143}
144
145impl Module for Downsample2d {
146    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
147        self.conv.forward(xs)
148    }
149}
150
/// Upsample with nearest-neighbor interpolation + conv.
#[derive(Debug)]
struct Upsample2d {
    /// 3x3 conv (padding 1) applied after 2x nearest-neighbor upsampling.
    conv: nn::Conv2d,
}
156
157impl Upsample2d {
158    fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
159        let conv = nn::conv2d(
160            channels,
161            channels,
162            3,
163            nn::Conv2dConfig {
164                padding: 1,
165                ..Default::default()
166            },
167            vs.pp("conv"),
168        )?;
169        Ok(Self { conv })
170    }
171}
172
173impl Module for Upsample2d {
174    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
175        let (_, _, h, w) = xs.dims4()?;
176        let xs = xs.upsample_nearest2d(h * 2, w * 2)?;
177        self.conv.forward(&xs)
178    }
179}
180
181// ---------------------------------------------------------------------------
182// Down / Mid / Up blocks
183// ---------------------------------------------------------------------------
184
/// A single downsampling stage with ResBlocks + optional spatial transformer.
#[derive(Debug)]
struct DownBlock {
    /// Time-conditioned ResBlocks, one per layer.
    resnets: Vec<ResBlock>,
    /// Transformers, parallel to `resnets`; empty when attention is disabled.
    attentions: Vec<MultiViewSpatialTransformer>,
    /// Strided-conv downsampler; `None` on the last stage.
    downsample: Option<Downsample2d>,
}
192
193impl DownBlock {
194    #[allow(clippy::too_many_arguments)]
195    fn new(
196        vs: nn::VarBuilder,
197        in_ch: usize,
198        out_ch: usize,
199        time_dim: usize,
200        num_layers: usize,
201        has_attn: bool,
202        n_heads: usize,
203        d_head: usize,
204        depth: usize,
205        context_dim: usize,
206        ip_dim: usize,
207        num_views: usize,
208        num_groups: usize,
209        use_linear: bool,
210        has_downsample: bool,
211    ) -> Result<Self> {
212        let vs_res = vs.pp("resnets");
213        let mut resnets = Vec::with_capacity(num_layers);
214        for i in 0..num_layers {
215            let ich = if i == 0 { in_ch } else { out_ch };
216            resnets.push(ResBlock::new(
217                vs_res.pp(i.to_string()),
218                ich,
219                out_ch,
220                time_dim,
221            )?);
222        }
223
224        let mut attentions = Vec::new();
225        if has_attn {
226            let vs_attn = vs.pp("attentions");
227            for i in 0..num_layers {
228                attentions.push(MultiViewSpatialTransformer::new(
229                    vs_attn.pp(i.to_string()),
230                    out_ch,
231                    n_heads,
232                    d_head,
233                    depth,
234                    context_dim,
235                    ip_dim,
236                    num_views,
237                    num_groups,
238                    use_linear,
239                )?);
240            }
241        }
242
243        let downsample = if has_downsample {
244            Some(Downsample2d::new(vs.pp("downsamplers.0"), out_ch)?)
245        } else {
246            None
247        };
248
249        Ok(Self {
250            resnets,
251            attentions,
252            downsample,
253        })
254    }
255
256    fn forward(
257        &self,
258        xs: &Tensor,
259        time_emb: &Tensor,
260        context: Option<&Tensor>,
261        ip_tokens: Option<&Tensor>,
262    ) -> Result<(Tensor, Vec<Tensor>)> {
263        let mut h = xs.clone();
264        let mut skip_connections = Vec::new();
265
266        for (i, resnet) in self.resnets.iter().enumerate() {
267            h = resnet.forward(&h, time_emb)?;
268            if !self.attentions.is_empty() {
269                h = self.attentions[i].forward(&h, context, ip_tokens)?;
270            }
271            skip_connections.push(h.clone());
272        }
273
274        if let Some(ref ds) = self.downsample {
275            h = ds.forward(&h)?;
276            skip_connections.push(h.clone());
277        }
278
279        Ok((h, skip_connections))
280    }
281}
282
/// Mid-block: ResBlock + attention + ResBlock.
#[derive(Debug)]
struct MidBlock {
    /// First ResBlock (channel count unchanged).
    resnet1: ResBlock,
    /// Multi-view spatial transformer between the two ResBlocks.
    attention: MultiViewSpatialTransformer,
    /// Second ResBlock (channel count unchanged).
    resnet2: ResBlock,
}
290
291impl MidBlock {
292    #[allow(clippy::too_many_arguments)]
293    fn new(
294        vs: nn::VarBuilder,
295        channels: usize,
296        time_dim: usize,
297        n_heads: usize,
298        d_head: usize,
299        depth: usize,
300        context_dim: usize,
301        ip_dim: usize,
302        num_views: usize,
303        num_groups: usize,
304        use_linear: bool,
305    ) -> Result<Self> {
306        let resnet1 = ResBlock::new(vs.pp("resnets.0"), channels, channels, time_dim)?;
307        let attention = MultiViewSpatialTransformer::new(
308            vs.pp("attentions.0"),
309            channels,
310            n_heads,
311            d_head,
312            depth,
313            context_dim,
314            ip_dim,
315            num_views,
316            num_groups,
317            use_linear,
318        )?;
319        let resnet2 = ResBlock::new(vs.pp("resnets.1"), channels, channels, time_dim)?;
320        Ok(Self {
321            resnet1,
322            attention,
323            resnet2,
324        })
325    }
326
327    fn forward(
328        &self,
329        xs: &Tensor,
330        time_emb: &Tensor,
331        context: Option<&Tensor>,
332        ip_tokens: Option<&Tensor>,
333    ) -> Result<Tensor> {
334        let h = self.resnet1.forward(xs, time_emb)?;
335        let h = self.attention.forward(&h, context, ip_tokens)?;
336        self.resnet2.forward(&h, time_emb)
337    }
338}
339
/// A single upsampling stage with ResBlocks + optional spatial transformer.
#[derive(Debug)]
struct UpBlock {
    /// Time-conditioned ResBlocks; each consumes one popped skip connection.
    resnets: Vec<ResBlock>,
    /// Transformers, parallel to `resnets`; empty when attention is disabled.
    attentions: Vec<MultiViewSpatialTransformer>,
    /// Nearest-neighbor + conv upsampler; `None` on the last stage.
    upsample: Option<Upsample2d>,
}
347
348impl UpBlock {
349    #[allow(clippy::too_many_arguments)]
350    fn new(
351        vs: nn::VarBuilder,
352        in_ch: usize,
353        out_ch: usize,
354        skip_ch: usize,
355        time_dim: usize,
356        num_layers: usize,
357        has_attn: bool,
358        n_heads: usize,
359        d_head: usize,
360        depth: usize,
361        context_dim: usize,
362        ip_dim: usize,
363        num_views: usize,
364        num_groups: usize,
365        use_linear: bool,
366        has_upsample: bool,
367    ) -> Result<Self> {
368        let vs_res = vs.pp("resnets");
369        let mut resnets = Vec::with_capacity(num_layers);
370        for i in 0..num_layers {
371            let ich = if i == 0 {
372                in_ch + skip_ch
373            } else {
374                out_ch + skip_ch
375            };
376            resnets.push(ResBlock::new(
377                vs_res.pp(i.to_string()),
378                ich,
379                out_ch,
380                time_dim,
381            )?);
382        }
383
384        let mut attentions = Vec::new();
385        if has_attn {
386            let vs_attn = vs.pp("attentions");
387            for i in 0..num_layers {
388                attentions.push(MultiViewSpatialTransformer::new(
389                    vs_attn.pp(i.to_string()),
390                    out_ch,
391                    n_heads,
392                    d_head,
393                    depth,
394                    context_dim,
395                    ip_dim,
396                    num_views,
397                    num_groups,
398                    use_linear,
399                )?);
400            }
401        }
402
403        let upsample = if has_upsample {
404            Some(Upsample2d::new(vs.pp("upsamplers.0"), out_ch)?)
405        } else {
406            None
407        };
408
409        Ok(Self {
410            resnets,
411            attentions,
412            upsample,
413        })
414    }
415
416    fn forward(
417        &self,
418        xs: &Tensor,
419        time_emb: &Tensor,
420        skip_connections: &mut Vec<Tensor>,
421        context: Option<&Tensor>,
422        ip_tokens: Option<&Tensor>,
423    ) -> std::result::Result<Tensor, DiffusionError> {
424        let mut h = xs.clone();
425
426        for (i, resnet) in self.resnets.iter().enumerate() {
427            let skip =
428                skip_connections
429                    .pop()
430                    .ok_or_else(|| DiffusionError::SkipConnectionUnderflow {
431                        expected: self.resnets.len(),
432                        available: i,
433                    })?;
434            h = Tensor::cat(&[h, skip], 1)?;
435            h = resnet.forward(&h, time_emb)?;
436            if !self.attentions.is_empty() {
437                h = self.attentions[i].forward(&h, context, ip_tokens)?;
438            }
439        }
440
441        if let Some(ref us) = self.upsample {
442            h = us.forward(&h)?;
443        }
444
445        Ok(h)
446    }
447}
448
449// ---------------------------------------------------------------------------
450// Multi-view U-Net
451// ---------------------------------------------------------------------------
452
/// The multi-view U-Net for diffusion-based avatar generation.
///
/// Architecture matches SD 2.1 but with multi-view cross-attention in every
/// spatial transformer block, camera-pose conditioning added to the timestep
/// embedding, and IP-adapter cross-attention for reference-image conditioning.
#[derive(Debug)]
pub struct MultiViewUNet {
    /// Input convolution: in_channels → base_channels.
    conv_in: nn::Conv2d,
    /// Sinusoidal → MLP time embedding.
    time_embedding: TimestepEmbedding,
    /// Camera-pose → time-embedding-dim MLP.
    camera_embedding: CameraEmbedding,
    /// Downsampling stages, one per resolution level.
    down_blocks: Vec<DownBlock>,
    /// Bottleneck: ResBlock + attention + ResBlock.
    mid_block: MidBlock,
    /// Upsampling stages; consume the skip connections in reverse order.
    up_blocks: Vec<UpBlock>,
    /// Output: GroupNorm + conv → out_channels.
    conv_norm_out: nn::GroupNorm,
    conv_out: nn::Conv2d,
    /// Model config (cloned at construction; read again in `forward`).
    config: DiffusionConfig,
}
478
479impl MultiViewUNet {
480    /// Build the U-Net from a DiffusionConfig and VarBuilder.
481    pub fn new(vs: nn::VarBuilder, config: &DiffusionConfig) -> Result<Self> {
482        let base = config.base_channels;
483        let time_embed_dim = config.time_embed_dim;
484
485        // Input conv
486        let conv_in = nn::conv2d(
487            config.unet_in_channels,
488            base,
489            3,
490            nn::Conv2dConfig {
491                padding: 1,
492                ..Default::default()
493            },
494            vs.pp("conv_in"),
495        )?;
496
497        // Time embedding
498        let time_embedding = TimestepEmbedding::new(vs.pp("time_embedding"), base, time_embed_dim)?;
499
500        // Camera embedding
501        let camera_embedding = CameraEmbedding::new(
502            vs.pp("camera_embedding"),
503            config.camera_pose_dim,
504            time_embed_dim,
505        )?;
506
507        // Down blocks
508        let mut down_blocks = Vec::new();
509        let num_stages = config.num_stages();
510        let vs_down = vs.pp("down_blocks");
511        let mut input_ch = base;
512        for i in 0..num_stages {
513            let output_ch = config.stage_channels(i);
514            let n_heads = output_ch / config.attention_head_dim[i];
515            let d_head = config.attention_head_dim[i];
516            let depth = config.transformer_layers_per_block[i];
517            let has_ds = i < num_stages - 1;
518
519            down_blocks.push(DownBlock::new(
520                vs_down.pp(i.to_string()),
521                input_ch,
522                output_ch,
523                time_embed_dim,
524                config.layers_per_block,
525                true, // all stages have attention in SD 2.1 768-v
526                n_heads,
527                d_head,
528                depth,
529                config.cross_attention_dim,
530                config.clip_embed_dim,
531                config.num_views,
532                config.norm_num_groups,
533                config.use_linear_projection,
534                has_ds,
535            )?);
536            input_ch = output_ch;
537        }
538
539        // Mid block
540        let last_ch = config.stage_channels(num_stages - 1);
541        let mid_n_heads = last_ch / config.attention_head_dim[num_stages - 1];
542        let mid_d_head = config.attention_head_dim[num_stages - 1];
543        let mid_depth = config.transformer_layers_per_block[num_stages - 1];
544        let mid_block = MidBlock::new(
545            vs.pp("mid_block"),
546            last_ch,
547            time_embed_dim,
548            mid_n_heads,
549            mid_d_head,
550            mid_depth,
551            config.cross_attention_dim,
552            config.clip_embed_dim,
553            config.num_views,
554            config.norm_num_groups,
555            config.use_linear_projection,
556        )?;
557
558        // Up blocks (reverse order)
559        let mut up_blocks = Vec::new();
560        let vs_up = vs.pp("up_blocks");
561        let reversed_channels: Vec<usize> = (0..num_stages)
562            .rev()
563            .map(|i| config.stage_channels(i))
564            .collect();
565        let mut prev_ch = last_ch;
566        for i in 0..num_stages {
567            let output_ch = reversed_channels[i];
568            let skip_ch = if i == 0 {
569                last_ch
570            } else {
571                reversed_channels[i - 1]
572            };
573            let stage_idx = num_stages - 1 - i;
574            let n_heads = output_ch / config.attention_head_dim[stage_idx];
575            let d_head = config.attention_head_dim[stage_idx];
576            let depth = config.transformer_layers_per_block[stage_idx];
577            let has_us = i < num_stages - 1;
578
579            up_blocks.push(UpBlock::new(
580                vs_up.pp(i.to_string()),
581                prev_ch,
582                output_ch,
583                skip_ch,
584                time_embed_dim,
585                config.layers_per_block + 1, // +1 for the skip connection layer
586                true,
587                n_heads,
588                d_head,
589                depth,
590                config.cross_attention_dim,
591                config.clip_embed_dim,
592                config.num_views,
593                config.norm_num_groups,
594                config.use_linear_projection,
595                has_us,
596            )?);
597            prev_ch = output_ch;
598        }
599
600        // Output
601        let conv_norm_out = nn::group_norm(
602            config.norm_num_groups,
603            base,
604            config.norm_eps,
605            vs.pp("conv_norm_out"),
606        )?;
607        let conv_out = nn::conv2d(
608            base,
609            config.unet_out_channels,
610            3,
611            nn::Conv2dConfig {
612                padding: 1,
613                ..Default::default()
614            },
615            vs.pp("conv_out"),
616        )?;
617
618        Ok(Self {
619            conv_in,
620            time_embedding,
621            camera_embedding,
622            down_blocks,
623            mid_block,
624            up_blocks,
625            conv_norm_out,
626            conv_out,
627            config: config.clone(),
628        })
629    }
630
631    /// Forward pass.
632    ///
633    /// - `sample`: `(B*V, in_channels, H, W)` noisy latent input.
634    /// - `timestep`: scalar timestep (will be broadcast to batch).
635    /// - `context`: `(B*V, seq_len, cross_attn_dim)` text/null embedding.
636    /// - `camera_poses`: `(B*V, pose_dim)` flattened extrinsics.
637    /// - `ip_tokens`: `(B*V, ip_len, ip_dim)` CLIP image tokens.
638    ///
639    /// # Errors
640    ///
641    /// Returns `DiffusionError::SkipConnectionUnderflow` if skip connections are
642    /// exhausted before all up blocks have consumed them.
643    pub fn forward(
644        &self,
645        sample: &Tensor,
646        timestep: usize,
647        context: Option<&Tensor>,
648        camera_poses: Option<&Tensor>,
649        ip_tokens: Option<&Tensor>,
650    ) -> std::result::Result<Tensor, DiffusionError> {
651        let batch_size = sample.dim(0)?;
652        let device = sample.device();
653
654        // 1. Time embedding
655        let t_emb = timestep_embedding(
656            &Tensor::full(timestep as f32, (batch_size,), device)?,
657            self.config.base_channels,
658        )?;
659        let mut emb = self.time_embedding.forward(&t_emb)?;
660
661        // 2. Add camera embedding if provided
662        if let Some(cam) = camera_poses {
663            let cam_emb = self.camera_embedding.forward(cam)?;
664            emb = (emb + cam_emb)?;
665        }
666
667        // 3. Input conv
668        let mut h = self.conv_in.forward(sample)?;
669
670        // 4. Down blocks — collect skip connections
671        let mut all_skips: Vec<Tensor> = Vec::new();
672        for down in &self.down_blocks {
673            let (out, skips) = down.forward(&h, &emb, context, ip_tokens)?;
674            h = out;
675            all_skips.extend(skips);
676        }
677
678        // 5. Mid block
679        h = self.mid_block.forward(&h, &emb, context, ip_tokens)?;
680
681        // 6. Up blocks — consume skip connections
682        for up in &self.up_blocks {
683            h = up.forward(&h, &emb, &mut all_skips, context, ip_tokens)?;
684        }
685
686        // 7. Output
687        h = self.conv_norm_out.forward(&h)?.silu()?;
688        Ok(self.conv_out.forward(&h)?)
689    }
690}