1use candle_core::{Result, Tensor};
33use candle_nn as nn;
34use candle_nn::Module;
35
36use crate::attention::MultiViewSpatialTransformer;
37use crate::camera::{timestep_embedding, CameraEmbedding, TimestepEmbedding};
38use crate::config::DiffusionConfig;
39use crate::DiffusionError;
40
/// Residual block with timestep conditioning: two GroupNorm -> SiLU -> Conv
/// stages with a projected time embedding added between them.
#[derive(Debug)]
struct ResBlock {
    norm1: nn::GroupNorm,
    conv1: nn::Conv2d,
    // Projects the time embedding to `out_ch` so it can be added per-channel.
    time_emb_proj: nn::Linear,
    norm2: nn::GroupNorm,
    conv2: nn::Conv2d,
    // 1x1 projection for the skip path; present only when in_ch != out_ch.
    residual_conv: Option<nn::Conv2d>,
}
55
56impl ResBlock {
57 fn new(vs: nn::VarBuilder, in_ch: usize, out_ch: usize, time_dim: usize) -> Result<Self> {
58 let norm1 = nn::group_norm(32, in_ch, 1e-5, vs.pp("norm1"))?;
59 let conv1 = nn::conv2d(
60 in_ch,
61 out_ch,
62 3,
63 nn::Conv2dConfig {
64 padding: 1,
65 ..Default::default()
66 },
67 vs.pp("conv1"),
68 )?;
69 let time_emb_proj = nn::linear(time_dim, out_ch, vs.pp("time_emb_proj"))?;
70 let norm2 = nn::group_norm(32, out_ch, 1e-5, vs.pp("norm2"))?;
71 let conv2 = nn::conv2d(
72 out_ch,
73 out_ch,
74 3,
75 nn::Conv2dConfig {
76 padding: 1,
77 ..Default::default()
78 },
79 vs.pp("conv2"),
80 )?;
81 let residual_conv = if in_ch != out_ch {
82 Some(nn::conv2d(
83 in_ch,
84 out_ch,
85 1,
86 Default::default(),
87 vs.pp("conv_shortcut"),
88 )?)
89 } else {
90 None
91 };
92 Ok(Self {
93 norm1,
94 conv1,
95 time_emb_proj,
96 norm2,
97 conv2,
98 residual_conv,
99 })
100 }
101
102 fn forward(&self, xs: &Tensor, time_emb: &Tensor) -> Result<Tensor> {
103 let residual = if let Some(ref conv) = self.residual_conv {
104 conv.forward(xs)?
105 } else {
106 xs.clone()
107 };
108 let h = self.norm1.forward(xs)?.silu()?;
109 let h = self.conv1.forward(&h)?;
110
111 let t = self.time_emb_proj.forward(&time_emb.silu()?)?;
113 let t = t.unsqueeze(2)?.unsqueeze(3)?;
114 let h = (h.clone() + t.broadcast_as(h.shape())?)?;
115
116 let h = self.norm2.forward(&h)?.silu()?;
117 let h = self.conv2.forward(&h)?;
118 h + residual
119 }
120}
121
/// Stride-2 3x3 convolution that halves the spatial resolution.
#[derive(Debug)]
struct Downsample2d {
    conv: nn::Conv2d,
}
127
128impl Downsample2d {
129 fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
130 let conv = nn::conv2d(
131 channels,
132 channels,
133 3,
134 nn::Conv2dConfig {
135 stride: 2,
136 padding: 1,
137 ..Default::default()
138 },
139 vs.pp("conv"),
140 )?;
141 Ok(Self { conv })
142 }
143}
144
impl Module for Downsample2d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Pure delegation: the stride-2 conv performs the downsampling.
        self.conv.forward(xs)
    }
}
150
/// Nearest-neighbor 2x upsampling followed by a 3x3 convolution.
#[derive(Debug)]
struct Upsample2d {
    conv: nn::Conv2d,
}
156
157impl Upsample2d {
158 fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
159 let conv = nn::conv2d(
160 channels,
161 channels,
162 3,
163 nn::Conv2dConfig {
164 padding: 1,
165 ..Default::default()
166 },
167 vs.pp("conv"),
168 )?;
169 Ok(Self { conv })
170 }
171}
172
173impl Module for Upsample2d {
174 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
175 let (_, _, h, w) = xs.dims4()?;
176 let xs = xs.upsample_nearest2d(h * 2, w * 2)?;
177 self.conv.forward(&xs)
178 }
179}
180
/// Encoder stage: a stack of ResBlocks (each optionally followed by a
/// multi-view transformer) plus an optional stride-2 downsampler.
#[derive(Debug)]
struct DownBlock {
    resnets: Vec<ResBlock>,
    // Empty when the stage has no attention; otherwise one per resnet.
    attentions: Vec<MultiViewSpatialTransformer>,
    downsample: Option<Downsample2d>,
}
192
193impl DownBlock {
194 #[allow(clippy::too_many_arguments)]
195 fn new(
196 vs: nn::VarBuilder,
197 in_ch: usize,
198 out_ch: usize,
199 time_dim: usize,
200 num_layers: usize,
201 has_attn: bool,
202 n_heads: usize,
203 d_head: usize,
204 depth: usize,
205 context_dim: usize,
206 ip_dim: usize,
207 num_views: usize,
208 num_groups: usize,
209 use_linear: bool,
210 has_downsample: bool,
211 ) -> Result<Self> {
212 let vs_res = vs.pp("resnets");
213 let mut resnets = Vec::with_capacity(num_layers);
214 for i in 0..num_layers {
215 let ich = if i == 0 { in_ch } else { out_ch };
216 resnets.push(ResBlock::new(
217 vs_res.pp(i.to_string()),
218 ich,
219 out_ch,
220 time_dim,
221 )?);
222 }
223
224 let mut attentions = Vec::new();
225 if has_attn {
226 let vs_attn = vs.pp("attentions");
227 for i in 0..num_layers {
228 attentions.push(MultiViewSpatialTransformer::new(
229 vs_attn.pp(i.to_string()),
230 out_ch,
231 n_heads,
232 d_head,
233 depth,
234 context_dim,
235 ip_dim,
236 num_views,
237 num_groups,
238 use_linear,
239 )?);
240 }
241 }
242
243 let downsample = if has_downsample {
244 Some(Downsample2d::new(vs.pp("downsamplers.0"), out_ch)?)
245 } else {
246 None
247 };
248
249 Ok(Self {
250 resnets,
251 attentions,
252 downsample,
253 })
254 }
255
256 fn forward(
257 &self,
258 xs: &Tensor,
259 time_emb: &Tensor,
260 context: Option<&Tensor>,
261 ip_tokens: Option<&Tensor>,
262 ) -> Result<(Tensor, Vec<Tensor>)> {
263 let mut h = xs.clone();
264 let mut skip_connections = Vec::new();
265
266 for (i, resnet) in self.resnets.iter().enumerate() {
267 h = resnet.forward(&h, time_emb)?;
268 if !self.attentions.is_empty() {
269 h = self.attentions[i].forward(&h, context, ip_tokens)?;
270 }
271 skip_connections.push(h.clone());
272 }
273
274 if let Some(ref ds) = self.downsample {
275 h = ds.forward(&h)?;
276 skip_connections.push(h.clone());
277 }
278
279 Ok((h, skip_connections))
280 }
281}
282
/// Bottleneck: ResBlock -> multi-view transformer -> ResBlock, all at the
/// deepest channel count.
#[derive(Debug)]
struct MidBlock {
    resnet1: ResBlock,
    attention: MultiViewSpatialTransformer,
    resnet2: ResBlock,
}
290
291impl MidBlock {
292 #[allow(clippy::too_many_arguments)]
293 fn new(
294 vs: nn::VarBuilder,
295 channels: usize,
296 time_dim: usize,
297 n_heads: usize,
298 d_head: usize,
299 depth: usize,
300 context_dim: usize,
301 ip_dim: usize,
302 num_views: usize,
303 num_groups: usize,
304 use_linear: bool,
305 ) -> Result<Self> {
306 let resnet1 = ResBlock::new(vs.pp("resnets.0"), channels, channels, time_dim)?;
307 let attention = MultiViewSpatialTransformer::new(
308 vs.pp("attentions.0"),
309 channels,
310 n_heads,
311 d_head,
312 depth,
313 context_dim,
314 ip_dim,
315 num_views,
316 num_groups,
317 use_linear,
318 )?;
319 let resnet2 = ResBlock::new(vs.pp("resnets.1"), channels, channels, time_dim)?;
320 Ok(Self {
321 resnet1,
322 attention,
323 resnet2,
324 })
325 }
326
327 fn forward(
328 &self,
329 xs: &Tensor,
330 time_emb: &Tensor,
331 context: Option<&Tensor>,
332 ip_tokens: Option<&Tensor>,
333 ) -> Result<Tensor> {
334 let h = self.resnet1.forward(xs, time_emb)?;
335 let h = self.attention.forward(&h, context, ip_tokens)?;
336 self.resnet2.forward(&h, time_emb)
337 }
338}
339
/// Decoder stage: resnets that each consume one skip connection (concatenated
/// on the channel axis), optional attention, and an optional 2x upsampler.
#[derive(Debug)]
struct UpBlock {
    resnets: Vec<ResBlock>,
    // Empty when the stage has no attention; otherwise one per resnet.
    attentions: Vec<MultiViewSpatialTransformer>,
    upsample: Option<Upsample2d>,
}
347
348impl UpBlock {
349 #[allow(clippy::too_many_arguments)]
350 fn new(
351 vs: nn::VarBuilder,
352 in_ch: usize,
353 out_ch: usize,
354 skip_ch: usize,
355 time_dim: usize,
356 num_layers: usize,
357 has_attn: bool,
358 n_heads: usize,
359 d_head: usize,
360 depth: usize,
361 context_dim: usize,
362 ip_dim: usize,
363 num_views: usize,
364 num_groups: usize,
365 use_linear: bool,
366 has_upsample: bool,
367 ) -> Result<Self> {
368 let vs_res = vs.pp("resnets");
369 let mut resnets = Vec::with_capacity(num_layers);
370 for i in 0..num_layers {
371 let ich = if i == 0 {
372 in_ch + skip_ch
373 } else {
374 out_ch + skip_ch
375 };
376 resnets.push(ResBlock::new(
377 vs_res.pp(i.to_string()),
378 ich,
379 out_ch,
380 time_dim,
381 )?);
382 }
383
384 let mut attentions = Vec::new();
385 if has_attn {
386 let vs_attn = vs.pp("attentions");
387 for i in 0..num_layers {
388 attentions.push(MultiViewSpatialTransformer::new(
389 vs_attn.pp(i.to_string()),
390 out_ch,
391 n_heads,
392 d_head,
393 depth,
394 context_dim,
395 ip_dim,
396 num_views,
397 num_groups,
398 use_linear,
399 )?);
400 }
401 }
402
403 let upsample = if has_upsample {
404 Some(Upsample2d::new(vs.pp("upsamplers.0"), out_ch)?)
405 } else {
406 None
407 };
408
409 Ok(Self {
410 resnets,
411 attentions,
412 upsample,
413 })
414 }
415
416 fn forward(
417 &self,
418 xs: &Tensor,
419 time_emb: &Tensor,
420 skip_connections: &mut Vec<Tensor>,
421 context: Option<&Tensor>,
422 ip_tokens: Option<&Tensor>,
423 ) -> std::result::Result<Tensor, DiffusionError> {
424 let mut h = xs.clone();
425
426 for (i, resnet) in self.resnets.iter().enumerate() {
427 let skip =
428 skip_connections
429 .pop()
430 .ok_or_else(|| DiffusionError::SkipConnectionUnderflow {
431 expected: self.resnets.len(),
432 available: i,
433 })?;
434 h = Tensor::cat(&[h, skip], 1)?;
435 h = resnet.forward(&h, time_emb)?;
436 if !self.attentions.is_empty() {
437 h = self.attentions[i].forward(&h, context, ip_tokens)?;
438 }
439 }
440
441 if let Some(ref us) = self.upsample {
442 h = us.forward(&h)?;
443 }
444
445 Ok(h)
446 }
447}
448
/// Multi-view diffusion UNet: conv-in, time/camera conditioning, encoder
/// stages, bottleneck, decoder stages with skip connections, conv-out.
#[derive(Debug)]
pub struct MultiViewUNet {
    conv_in: nn::Conv2d,
    time_embedding: TimestepEmbedding,
    // Projects camera poses into the same space as the time embedding.
    camera_embedding: CameraEmbedding,
    down_blocks: Vec<DownBlock>,
    mid_block: MidBlock,
    up_blocks: Vec<UpBlock>,
    conv_norm_out: nn::GroupNorm,
    conv_out: nn::Conv2d,
    // Retained so `forward` can read base_channels for the timestep embedding.
    config: DiffusionConfig,
}
478
479impl MultiViewUNet {
480 pub fn new(vs: nn::VarBuilder, config: &DiffusionConfig) -> Result<Self> {
482 let base = config.base_channels;
483 let time_embed_dim = config.time_embed_dim;
484
485 let conv_in = nn::conv2d(
487 config.unet_in_channels,
488 base,
489 3,
490 nn::Conv2dConfig {
491 padding: 1,
492 ..Default::default()
493 },
494 vs.pp("conv_in"),
495 )?;
496
497 let time_embedding = TimestepEmbedding::new(vs.pp("time_embedding"), base, time_embed_dim)?;
499
500 let camera_embedding = CameraEmbedding::new(
502 vs.pp("camera_embedding"),
503 config.camera_pose_dim,
504 time_embed_dim,
505 )?;
506
507 let mut down_blocks = Vec::new();
509 let num_stages = config.num_stages();
510 let vs_down = vs.pp("down_blocks");
511 let mut input_ch = base;
512 for i in 0..num_stages {
513 let output_ch = config.stage_channels(i);
514 let n_heads = output_ch / config.attention_head_dim[i];
515 let d_head = config.attention_head_dim[i];
516 let depth = config.transformer_layers_per_block[i];
517 let has_ds = i < num_stages - 1;
518
519 down_blocks.push(DownBlock::new(
520 vs_down.pp(i.to_string()),
521 input_ch,
522 output_ch,
523 time_embed_dim,
524 config.layers_per_block,
525 true, n_heads,
527 d_head,
528 depth,
529 config.cross_attention_dim,
530 config.clip_embed_dim,
531 config.num_views,
532 config.norm_num_groups,
533 config.use_linear_projection,
534 has_ds,
535 )?);
536 input_ch = output_ch;
537 }
538
539 let last_ch = config.stage_channels(num_stages - 1);
541 let mid_n_heads = last_ch / config.attention_head_dim[num_stages - 1];
542 let mid_d_head = config.attention_head_dim[num_stages - 1];
543 let mid_depth = config.transformer_layers_per_block[num_stages - 1];
544 let mid_block = MidBlock::new(
545 vs.pp("mid_block"),
546 last_ch,
547 time_embed_dim,
548 mid_n_heads,
549 mid_d_head,
550 mid_depth,
551 config.cross_attention_dim,
552 config.clip_embed_dim,
553 config.num_views,
554 config.norm_num_groups,
555 config.use_linear_projection,
556 )?;
557
558 let mut up_blocks = Vec::new();
560 let vs_up = vs.pp("up_blocks");
561 let reversed_channels: Vec<usize> = (0..num_stages)
562 .rev()
563 .map(|i| config.stage_channels(i))
564 .collect();
565 let mut prev_ch = last_ch;
566 for i in 0..num_stages {
567 let output_ch = reversed_channels[i];
568 let skip_ch = if i == 0 {
569 last_ch
570 } else {
571 reversed_channels[i - 1]
572 };
573 let stage_idx = num_stages - 1 - i;
574 let n_heads = output_ch / config.attention_head_dim[stage_idx];
575 let d_head = config.attention_head_dim[stage_idx];
576 let depth = config.transformer_layers_per_block[stage_idx];
577 let has_us = i < num_stages - 1;
578
579 up_blocks.push(UpBlock::new(
580 vs_up.pp(i.to_string()),
581 prev_ch,
582 output_ch,
583 skip_ch,
584 time_embed_dim,
585 config.layers_per_block + 1, true,
587 n_heads,
588 d_head,
589 depth,
590 config.cross_attention_dim,
591 config.clip_embed_dim,
592 config.num_views,
593 config.norm_num_groups,
594 config.use_linear_projection,
595 has_us,
596 )?);
597 prev_ch = output_ch;
598 }
599
600 let conv_norm_out = nn::group_norm(
602 config.norm_num_groups,
603 base,
604 config.norm_eps,
605 vs.pp("conv_norm_out"),
606 )?;
607 let conv_out = nn::conv2d(
608 base,
609 config.unet_out_channels,
610 3,
611 nn::Conv2dConfig {
612 padding: 1,
613 ..Default::default()
614 },
615 vs.pp("conv_out"),
616 )?;
617
618 Ok(Self {
619 conv_in,
620 time_embedding,
621 camera_embedding,
622 down_blocks,
623 mid_block,
624 up_blocks,
625 conv_norm_out,
626 conv_out,
627 config: config.clone(),
628 })
629 }
630
631 pub fn forward(
644 &self,
645 sample: &Tensor,
646 timestep: usize,
647 context: Option<&Tensor>,
648 camera_poses: Option<&Tensor>,
649 ip_tokens: Option<&Tensor>,
650 ) -> std::result::Result<Tensor, DiffusionError> {
651 let batch_size = sample.dim(0)?;
652 let device = sample.device();
653
654 let t_emb = timestep_embedding(
656 &Tensor::full(timestep as f32, (batch_size,), device)?,
657 self.config.base_channels,
658 )?;
659 let mut emb = self.time_embedding.forward(&t_emb)?;
660
661 if let Some(cam) = camera_poses {
663 let cam_emb = self.camera_embedding.forward(cam)?;
664 emb = (emb + cam_emb)?;
665 }
666
667 let mut h = self.conv_in.forward(sample)?;
669
670 let mut all_skips: Vec<Tensor> = Vec::new();
672 for down in &self.down_blocks {
673 let (out, skips) = down.forward(&h, &emb, context, ip_tokens)?;
674 h = out;
675 all_skips.extend(skips);
676 }
677
678 h = self.mid_block.forward(&h, &emb, context, ip_tokens)?;
680
681 for up in &self.up_blocks {
683 h = up.forward(&h, &emb, &mut all_skips, context, ip_tokens)?;
684 }
685
686 h = self.conv_norm_out.forward(&h)?.silu()?;
688 Ok(self.conv_out.forward(&h)?)
689 }
690}