1use candle_core::{DType, Device, Result, Tensor};
11use candle_nn as nn;
12use candle_nn::Module;
13
14use crate::scheduler::{DdimScheduler, PredictionType};
15use crate::DiffusionError;
16
/// Strategy used to upscale 32x32 latents to 64x64.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpsamplerMode {
    /// Diffusion-based 2x latent upsampling with a small U-Net (requires weights).
    SdX2,
    /// Cheap non-diffusion fallback: direct spatial resize of the latents.
    BilinearVae,
}
29
/// One time-conditioned residual block: norm/SiLU/conv twice, with the
/// projected time embedding added between the two convolutions.
#[derive(Debug)]
struct UpsamplerResBlock {
    norm1: nn::GroupNorm,
    conv1: nn::Conv2d,
    // Projects the time embedding to the block's output channel count.
    time_emb_proj: nn::Linear,
    norm2: nn::GroupNorm,
    conv2: nn::Conv2d,
    // 1x1 shortcut projection, present only when in/out channel counts differ.
    residual_conv: Option<nn::Conv2d>,
}
44
45impl UpsamplerResBlock {
46 fn new(vs: nn::VarBuilder, in_ch: usize, out_ch: usize, time_dim: usize) -> Result<Self> {
47 let norm1 = nn::group_norm(32, in_ch, 1e-5, vs.pp("norm1"))?;
48 let conv1 = nn::conv2d(
49 in_ch,
50 out_ch,
51 3,
52 nn::Conv2dConfig {
53 padding: 1,
54 ..Default::default()
55 },
56 vs.pp("conv1"),
57 )?;
58 let time_emb_proj = nn::linear(time_dim, out_ch, vs.pp("time_emb_proj"))?;
59 let norm2 = nn::group_norm(32, out_ch, 1e-5, vs.pp("norm2"))?;
60 let conv2 = nn::conv2d(
61 out_ch,
62 out_ch,
63 3,
64 nn::Conv2dConfig {
65 padding: 1,
66 ..Default::default()
67 },
68 vs.pp("conv2"),
69 )?;
70 let residual_conv = if in_ch != out_ch {
71 Some(nn::conv2d(
72 in_ch,
73 out_ch,
74 1,
75 Default::default(),
76 vs.pp("conv_shortcut"),
77 )?)
78 } else {
79 None
80 };
81 Ok(Self {
82 norm1,
83 conv1,
84 time_emb_proj,
85 norm2,
86 conv2,
87 residual_conv,
88 })
89 }
90
91 fn forward(&self, xs: &Tensor, time_emb: &Tensor) -> Result<Tensor> {
92 let residual = if let Some(ref conv) = self.residual_conv {
93 conv.forward(xs)?
94 } else {
95 xs.clone()
96 };
97 let h = self.norm1.forward(xs)?.silu()?;
98 let h = self.conv1.forward(&h)?;
99
100 let t = self.time_emb_proj.forward(&time_emb.silu()?)?;
102 let t = t.unsqueeze(2)?.unsqueeze(3)?;
103 let h = (h.clone() + t.broadcast_as(h.shape())?)?;
104
105 let h = self.norm2.forward(&h)?.silu()?;
106 let h = self.conv2.forward(&h)?;
107 h + residual
108 }
109}
110
111#[derive(Debug)]
113struct UpsamplerDownsample {
114 conv: nn::Conv2d,
115}
116
117impl UpsamplerDownsample {
118 fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
119 let conv = nn::conv2d(
120 channels,
121 channels,
122 3,
123 nn::Conv2dConfig {
124 stride: 2,
125 padding: 1,
126 ..Default::default()
127 },
128 vs.pp("conv"),
129 )?;
130 Ok(Self { conv })
131 }
132}
133
134impl Module for UpsamplerDownsample {
135 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
136 self.conv.forward(xs)
137 }
138}
139
140#[derive(Debug)]
142struct UpsamplerUpsample {
143 conv: nn::Conv2d,
144}
145
146impl UpsamplerUpsample {
147 fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
148 let conv = nn::conv2d(
149 channels,
150 channels,
151 3,
152 nn::Conv2dConfig {
153 padding: 1,
154 ..Default::default()
155 },
156 vs.pp("conv"),
157 )?;
158 Ok(Self { conv })
159 }
160}
161
162impl Module for UpsamplerUpsample {
163 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
164 let (_, _, h, w) = xs.dims4()?;
165 let xs = xs.upsample_nearest2d(h * 2, w * 2)?;
166 self.conv.forward(&xs)
167 }
168}
169
170#[derive(Debug)]
172struct UpsamplerAttention {
173 group_norm: nn::GroupNorm,
174 to_qkv: nn::Conv2d,
175 to_out: nn::Conv2d,
176 channels: usize,
177}
178
179impl UpsamplerAttention {
180 fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
181 let group_norm = nn::group_norm(32, channels, 1e-5, vs.pp("group_norm"))?;
182 let to_qkv = nn::conv2d(
183 channels,
184 channels * 3,
185 1,
186 Default::default(),
187 vs.pp("to_qkv"),
188 )?;
189 let to_out = nn::conv2d(channels, channels, 1, Default::default(), vs.pp("to_out"))?;
190 Ok(Self {
191 group_norm,
192 to_qkv,
193 to_out,
194 channels,
195 })
196 }
197}
198
199impl Module for UpsamplerAttention {
200 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
201 let residual = xs;
202 let (b, _c, h, w) = xs.dims4()?;
203 let xs = self.group_norm.forward(xs)?;
204 let qkv = self.to_qkv.forward(&xs)?;
205 let qkv = qkv.reshape((b, 3, self.channels, h * w))?;
206 let q = qkv.narrow(1, 0, 1)?.squeeze(1)?;
207 let k = qkv.narrow(1, 1, 1)?.squeeze(1)?;
208 let v = qkv.narrow(1, 2, 1)?.squeeze(1)?;
209
210 let scale = (self.channels as f64).powf(-0.5);
211 let attn = (q.transpose(1, 2)?.matmul(&k)? * scale)?;
212 let attn = nn::ops::softmax_last_dim(&attn)?;
213 let out = v.matmul(&attn.transpose(1, 2)?)?;
214 let out = out.reshape((b, self.channels, h, w))?;
215 let out = self.to_out.forward(&out)?;
216 out + residual
217 }
218}
219
220fn timestep_embedding(timesteps: &Tensor, dim: usize) -> Result<Tensor> {
226 let half = dim / 2;
227 let freqs = Tensor::arange(0f32, half as f32, timesteps.device())?;
228 let freqs = (freqs * (-10000f64.ln() / half as f64))?;
229 let freqs = freqs.exp()?;
230 let args = timesteps
231 .unsqueeze(1)?
232 .broadcast_mul(&freqs.unsqueeze(0)?)?;
233 let sin = args.sin()?;
234 let cos = args.cos()?;
235 Tensor::cat(&[&cos, &sin], 1)
236}
237
238#[derive(Debug)]
240struct TimeEmbedding {
241 linear1: nn::Linear,
242 linear2: nn::Linear,
243}
244
245impl TimeEmbedding {
246 fn new(vs: nn::VarBuilder, in_dim: usize, out_dim: usize) -> Result<Self> {
247 let linear1 = nn::linear(in_dim, out_dim, vs.pp("linear_1"))?;
248 let linear2 = nn::linear(out_dim, out_dim, vs.pp("linear_2"))?;
249 Ok(Self { linear1, linear2 })
250 }
251
252 fn forward(&self, x: &Tensor) -> Result<Tensor> {
253 let h = self.linear1.forward(x)?.silu()?;
254 self.linear2.forward(&h)
255 }
256}
257
/// Minimal single-scale U-Net: one down stage (with attention), a middle
/// stage, and one up stage with skip connections, all time-conditioned.
#[derive(Debug)]
struct UpsamplerUNet {
    conv_in: nn::Conv2d,
    time_embedding: TimeEmbedding,
    down_res_1: UpsamplerResBlock,
    down_attn_1: Option<UpsamplerAttention>,
    down_res_2: UpsamplerResBlock,
    downsample: UpsamplerDownsample,
    mid_res_1: UpsamplerResBlock,
    mid_attn: UpsamplerAttention,
    mid_res_2: UpsamplerResBlock,
    upsample: UpsamplerUpsample,
    // The up-path res blocks consume 2x channels: `forward` concatenates a
    // skip connection onto `h` before each of them.
    up_res_1: UpsamplerResBlock,
    up_attn_1: Option<UpsamplerAttention>,
    up_res_2: UpsamplerResBlock,
    conv_norm_out: nn::GroupNorm,
    conv_out: nn::Conv2d,
}
288
impl UpsamplerUNet {
    /// Builds the fixed-size U-Net and loads all weights from `vs` using
    /// diffusers-style key paths (down_blocks.0.*, mid_block.*, up_blocks.0.*).
    ///
    /// NOTE(review): `in_channels` is 4 here, but `upsample_sdx2` concatenates
    /// two 4-channel latent maps (8 channels) before calling `forward` —
    /// `conv_in` would reject that input at runtime. One of the two sides is
    /// wrong; confirm the intended conditioning layout.
    fn new(vs: nn::VarBuilder) -> Result<Self> {
        let in_channels = 4;
        let out_channels = 4;
        let base_channels = 128;
        let time_dim = 512;

        let conv_in = nn::conv2d(
            in_channels,
            base_channels,
            3,
            nn::Conv2dConfig {
                padding: 1,
                ..Default::default()
            },
            vs.pp("conv_in"),
        )?;

        // in_dim = base_channels must match the hardcoded 128 in `forward`.
        let time_embedding = TimeEmbedding::new(vs.pp("time_embedding"), base_channels, time_dim)?;

        // Down path: res -> attn -> res -> stride-2 downsample.
        let down_res_1 = UpsamplerResBlock::new(
            vs.pp("down_blocks.0.resnets.0"),
            base_channels,
            base_channels,
            time_dim,
        )?;
        let down_attn_1 = Some(UpsamplerAttention::new(
            vs.pp("down_blocks.0.attentions.0"),
            base_channels,
        )?);
        let down_res_2 = UpsamplerResBlock::new(
            vs.pp("down_blocks.0.resnets.1"),
            base_channels,
            base_channels,
            time_dim,
        )?;
        let downsample =
            UpsamplerDownsample::new(vs.pp("down_blocks.0.downsamplers.0"), base_channels)?;

        // Middle stage at half resolution: res -> attn -> res.
        let mid_res_1 = UpsamplerResBlock::new(
            vs.pp("mid_block.resnets.0"),
            base_channels,
            base_channels,
            time_dim,
        )?;
        let mid_attn = UpsamplerAttention::new(vs.pp("mid_block.attentions.0"), base_channels)?;
        let mid_res_2 = UpsamplerResBlock::new(
            vs.pp("mid_block.resnets.1"),
            base_channels,
            base_channels,
            time_dim,
        )?;

        // Up path: res blocks take base_channels * 2 because `forward`
        // concatenates a skip connection before each one.
        let upsample = UpsamplerUpsample::new(vs.pp("up_blocks.0.upsamplers.0"), base_channels)?;
        let up_res_1 = UpsamplerResBlock::new(
            vs.pp("up_blocks.0.resnets.0"),
            base_channels * 2, base_channels,
            time_dim,
        )?;
        let up_attn_1 = Some(UpsamplerAttention::new(
            vs.pp("up_blocks.0.attentions.0"),
            base_channels,
        )?);
        let up_res_2 = UpsamplerResBlock::new(
            vs.pp("up_blocks.0.resnets.1"),
            base_channels * 2, base_channels,
            time_dim,
        )?;

        let conv_norm_out = nn::group_norm(32, base_channels, 1e-5, vs.pp("conv_norm_out"))?;
        let conv_out = nn::conv2d(
            base_channels,
            out_channels,
            3,
            nn::Conv2dConfig {
                padding: 1,
                ..Default::default()
            },
            vs.pp("conv_out"),
        )?;

        Ok(Self {
            conv_in,
            time_embedding,
            down_res_1,
            down_attn_1,
            down_res_2,
            downsample,
            mid_res_1,
            mid_attn,
            mid_res_2,
            upsample,
            up_res_1,
            up_attn_1,
            up_res_2,
            conv_norm_out,
            conv_out,
        })
    }

    /// One denoising step: runs `sample` through the U-Net conditioned on
    /// `timestep` and returns the prediction (same channel count as conv_out).
    fn forward(&self, sample: &Tensor, timestep: usize) -> Result<Tensor> {
        let batch_size = sample.dim(0)?;
        let device = sample.device();

        // The embedding width (128) must equal TimeEmbedding's in_dim,
        // which `new` sets to base_channels = 128.
        let t_emb =
            timestep_embedding(&Tensor::full(timestep as f32, (batch_size,), device)?, 128)?;
        let emb = self.time_embedding.forward(&t_emb)?;

        let mut h = self.conv_in.forward(sample)?;

        // Skips are captured at full resolution and concatenated along the
        // channel dim on the way back up (hence the 2x-channel up res blocks).
        let skip1 = h.clone();
        h = self.down_res_1.forward(&h, &emb)?;
        if let Some(ref attn) = self.down_attn_1 {
            h = attn.forward(&h)?;
        }
        let skip2 = h.clone();
        h = self.down_res_2.forward(&h, &emb)?;
        h = self.downsample.forward(&h)?;

        h = self.mid_res_1.forward(&h, &emb)?;
        h = self.mid_attn.forward(&h)?;
        h = self.mid_res_2.forward(&h, &emb)?;

        h = self.upsample.forward(&h)?;
        h = Tensor::cat(&[h, skip2], 1)?;
        h = self.up_res_1.forward(&h, &emb)?;
        if let Some(ref attn) = self.up_attn_1 {
            h = attn.forward(&h)?;
        }
        h = Tensor::cat(&[h, skip1], 1)?;
        h = self.up_res_2.forward(&h, &emb)?;

        h = self.conv_norm_out.forward(&h)?.silu()?;
        self.conv_out.forward(&h)
    }
}
445
/// Upscales 32x32 latents to 64x64, either via a diffusion U-Net (`SdX2`)
/// or a plain spatial resize (`BilinearVae`).
#[derive(Debug)]
pub struct LatentUpsampler {
    mode: UpsamplerMode,
    // Present only in SdX2 mode; None for BilinearVae.
    unet: Option<UpsamplerUNet>,
    scheduler: DdimScheduler,
    device: Device,
}
484
485impl LatentUpsampler {
486 pub fn load(
498 mode: UpsamplerMode,
499 weights_path: &std::path::Path,
500 device: &Device,
501 ) -> std::result::Result<Self, DiffusionError> {
502 let unet = match mode {
503 UpsamplerMode::SdX2 => {
504 let dtype = DType::F32;
505 let safetensors_path = weights_path.join("diffusion_pytorch_model.safetensors");
506 let data = std::fs::read(&safetensors_path).map_err(|e| {
507 DiffusionError::ModelLoad(format!("Failed to read upsampler weights: {e}"))
508 })?;
509 let vb = nn::VarBuilder::from_buffered_safetensors(data, dtype, device)
510 .map_err(|e| DiffusionError::ModelLoad(format!("Upsampler VarBuilder: {e}")))?;
511 Some(UpsamplerUNet::new(vb).map_err(|e| {
512 DiffusionError::ModelLoad(format!("Upsampler U-Net build: {e}"))
513 })?)
514 }
515 UpsamplerMode::BilinearVae => None,
516 };
517
518 let scheduler = DdimScheduler::new(1000, PredictionType::Epsilon);
519
520 Ok(Self {
521 mode,
522 unet,
523 scheduler,
524 device: device.clone(),
525 })
526 }
527
528 pub fn upsample(
543 &mut self,
544 latents: &Tensor,
545 num_steps: usize,
546 ) -> std::result::Result<Tensor, DiffusionError> {
547 match self.mode {
548 UpsamplerMode::SdX2 => self.upsample_sdx2(latents, num_steps),
549 UpsamplerMode::BilinearVae => self.upsample_bilinear(latents),
550 }
551 }
552
553 fn upsample_sdx2(
555 &mut self,
556 latents: &Tensor,
557 num_steps: usize,
558 ) -> std::result::Result<Tensor, DiffusionError> {
559 let unet = self.unet.as_ref().ok_or_else(|| {
560 DiffusionError::Inference("SdX2 mode requires U-Net, but it's not loaded".to_string())
561 })?;
562
563 let (batch, ch, h, w) = latents
565 .dims4()
566 .map_err(|e| DiffusionError::Inference(format!("Invalid latent shape: {e}")))?;
567 if ch != 4 || h != 32 || w != 32 {
568 return Err(DiffusionError::InvalidLatentShape {
569 expected: vec![batch, 4, 32, 32],
570 got: vec![batch, ch, h, w],
571 });
572 }
573
574 let mut current = Tensor::randn(0f32, 1f32, (batch, 4, 64, 64), &self.device)
576 .map_err(|e| DiffusionError::Inference(format!("Noise init: {e}")))?;
577
578 self.scheduler.set_timesteps(num_steps);
580 let timesteps = self.scheduler.timesteps().to_vec();
581
582 for &t in ×teps {
584 let current_downsampled = current
586 .upsample_nearest2d(32, 32)
587 .map_err(|e| DiffusionError::Inference(format!("Downsample for condition: {e}")))?;
588
589 let model_input = Tensor::cat(&[¤t_downsampled, latents], 1)
591 .map_err(|e| DiffusionError::Inference(format!("Concat condition: {e}")))?;
592
593 let noise_pred =
595 unet.forward(&model_input, t)
596 .map_err(|e| DiffusionError::UnetForwardFailed {
597 timestep: t,
598 reason: format!("{e}"),
599 })?;
600
601 current = self
603 .scheduler
604 .step(&noise_pred, t, ¤t)
605 .map_err(|e| DiffusionError::Inference(format!("Scheduler step: {e}")))?;
606 }
607
608 Ok(current)
609 }
610
611 fn upsample_bilinear(&self, latents: &Tensor) -> std::result::Result<Tensor, DiffusionError> {
613 latents
614 .upsample_nearest2d(64, 64)
615 .map_err(|e| DiffusionError::Inference(format!("Bilinear upsample: {e}")))
616 }
617}
618
#[cfg(test)]
mod tests {
    use super::*;

    // Bilinear mode is weightless, so it can be exercised end-to-end on CPU.
    #[test]
    fn test_bilinear_upsampling() -> Result<()> {
        let device = Device::Cpu;
        let mut upsampler = LatentUpsampler {
            mode: UpsamplerMode::BilinearVae,
            unet: None,
            scheduler: DdimScheduler::new(1000, PredictionType::Epsilon),
            device: device.clone(),
        };

        let latents = Tensor::randn(0f32, 1f32, (2, 4, 32, 32), &device)?;
        let upsampled = upsampler
            .upsample(&latents, 10)
            .map_err(|e| candle_core::Error::Msg(format!("{e}")))?;

        assert_eq!(upsampled.dims(), &[2, 4, 64, 64]);
        Ok(())
    }

    #[test]
    fn test_upsampler_mode_equality() {
        assert_eq!(UpsamplerMode::SdX2, UpsamplerMode::SdX2);
        assert_eq!(UpsamplerMode::BilinearVae, UpsamplerMode::BilinearVae);
        assert_ne!(UpsamplerMode::SdX2, UpsamplerMode::BilinearVae);
    }

    #[test]
    fn test_timestep_embedding_shape() -> Result<()> {
        let device = Device::Cpu;
        let timesteps = Tensor::full(500f32, (4,), &device)?;
        let emb = timestep_embedding(&timesteps, 128)?;
        assert_eq!(emb.dims(), &[4, 128]);
        Ok(())
    }

    #[test]
    fn test_bilinear_with_different_batch_sizes() -> Result<()> {
        let device = Device::Cpu;
        let mut upsampler = LatentUpsampler {
            mode: UpsamplerMode::BilinearVae,
            unet: None,
            scheduler: DdimScheduler::new(1000, PredictionType::Epsilon),
            device: device.clone(),
        };

        for batch_size in [1, 2, 4, 8] {
            let latents = Tensor::randn(0f32, 1f32, (batch_size, 4, 32, 32), &device)?;
            let upsampled = upsampler
                .upsample(&latents, 10)
                .map_err(|e| candle_core::Error::Msg(format!("{e}")))?;
            assert_eq!(upsampled.dims(), &[batch_size, 4, 64, 64]);
        }
        Ok(())
    }

    // SdX2 without a loaded U-Net (and with a bad channel count) must error
    // rather than panic.
    #[test]
    fn test_invalid_latent_shape_sdx2() -> Result<()> {
        let device = Device::Cpu;
        let mut upsampler = LatentUpsampler {
            mode: UpsamplerMode::SdX2,
            unet: None,
            scheduler: DdimScheduler::new(1000, PredictionType::Epsilon),
            device: device.clone(),
        };

        let latents = Tensor::randn(0f32, 1f32, (2, 3, 32, 32), &device)?;
        let result = upsampler.upsample(&latents, 10);
        assert!(result.is_err());
        Ok(())
    }
}