1use candle_core::{Result, Tensor};
8use candle_nn as nn;
9use candle_nn::Module;
10
11#[derive(Debug)]
17struct ResnetBlock {
18 norm1: nn::GroupNorm,
19 conv1: nn::Conv2d,
20 norm2: nn::GroupNorm,
21 conv2: nn::Conv2d,
22 residual_conv: Option<nn::Conv2d>,
23}
24
25impl ResnetBlock {
26 fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
27 let norm1 = nn::group_norm(32, in_channels, 1e-6, vs.pp("norm1"))?;
28 let conv1 = nn::conv2d(
29 in_channels,
30 out_channels,
31 3,
32 nn::Conv2dConfig {
33 padding: 1,
34 ..Default::default()
35 },
36 vs.pp("conv1"),
37 )?;
38 let norm2 = nn::group_norm(32, out_channels, 1e-6, vs.pp("norm2"))?;
39 let conv2 = nn::conv2d(
40 out_channels,
41 out_channels,
42 3,
43 nn::Conv2dConfig {
44 padding: 1,
45 ..Default::default()
46 },
47 vs.pp("conv2"),
48 )?;
49 let residual_conv = if in_channels != out_channels {
50 Some(nn::conv2d(
51 in_channels,
52 out_channels,
53 1,
54 Default::default(),
55 vs.pp("nin_shortcut"),
56 )?)
57 } else {
58 None
59 };
60 Ok(Self {
61 norm1,
62 conv1,
63 norm2,
64 conv2,
65 residual_conv,
66 })
67 }
68}
69
70impl Module for ResnetBlock {
71 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
72 let residual = if let Some(ref conv) = self.residual_conv {
73 conv.forward(xs)?
74 } else {
75 xs.clone()
76 };
77 let h = self.norm1.forward(xs)?.silu()?;
78 let h = self.conv1.forward(&h)?;
79 let h = self.norm2.forward(&h)?.silu()?;
80 let h = self.conv2.forward(&h)?;
81 h + residual
82 }
83}
84
85#[derive(Debug)]
87struct AttentionBlock {
88 group_norm: nn::GroupNorm,
89 to_qkv: nn::Conv2d,
90 to_out: nn::Conv2d,
91 channels: usize,
92}
93
94impl AttentionBlock {
95 fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
96 let group_norm = nn::group_norm(32, channels, 1e-6, vs.pp("group_norm"))?;
97 let to_qkv = nn::conv2d(
98 channels,
99 channels * 3,
100 1,
101 Default::default(),
102 vs.pp("to_qkv"),
103 )?;
104 let to_out = nn::conv2d(channels, channels, 1, Default::default(), vs.pp("to_out"))?;
105 Ok(Self {
106 group_norm,
107 to_qkv,
108 to_out,
109 channels,
110 })
111 }
112}
113
114impl Module for AttentionBlock {
115 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
116 let residual = xs;
117 let (b, _c, h, w) = xs.dims4()?;
118 let xs = self.group_norm.forward(xs)?;
119 let qkv = self.to_qkv.forward(&xs)?;
120 let qkv = qkv.reshape((b, 3, self.channels, h * w))?;
121 let q = qkv.narrow(1, 0, 1)?.squeeze(1)?;
122 let k = qkv.narrow(1, 1, 1)?.squeeze(1)?;
123 let v = qkv.narrow(1, 2, 1)?.squeeze(1)?;
124
125 let scale = (self.channels as f64).powf(-0.5);
126 let attn = (q.transpose(1, 2)?.matmul(&k)? * scale)?;
127 let attn = nn::ops::softmax_last_dim(&attn)?;
128 let out = v.matmul(&attn.transpose(1, 2)?)?;
129 let out = out.reshape((b, self.channels, h, w))?;
130 let out = self.to_out.forward(&out)?;
131 out + residual
132 }
133}
134
135#[derive(Debug)]
137struct Downsample {
138 conv: nn::Conv2d,
139}
140
141impl Downsample {
142 fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
143 let conv = nn::conv2d(
144 channels,
145 channels,
146 3,
147 nn::Conv2dConfig {
148 stride: 2,
149 padding: 1,
150 ..Default::default()
151 },
152 vs.pp("conv"),
153 )?;
154 Ok(Self { conv })
155 }
156}
157
158impl Module for Downsample {
159 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
160 self.conv.forward(xs)
161 }
162}
163
164#[derive(Debug)]
166struct Upsample {
167 conv: nn::Conv2d,
168}
169
170impl Upsample {
171 fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
172 let conv = nn::conv2d(
173 channels,
174 channels,
175 3,
176 nn::Conv2dConfig {
177 padding: 1,
178 ..Default::default()
179 },
180 vs.pp("conv"),
181 )?;
182 Ok(Self { conv })
183 }
184}
185
186impl Module for Upsample {
187 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
188 let (_, _, h, w) = xs.dims4()?;
189 let xs = xs.upsample_nearest2d(h * 2, w * 2)?;
190 self.conv.forward(&xs)
191 }
192}
193
194#[derive(Debug)]
200struct Encoder {
201 conv_in: nn::Conv2d,
202 down_blocks: Vec<Vec<ResnetBlock>>,
203 downsamplers: Vec<Option<Downsample>>,
204 mid_block_1: ResnetBlock,
205 mid_attn: AttentionBlock,
206 mid_block_2: ResnetBlock,
207 conv_norm_out: nn::GroupNorm,
208 conv_out: nn::Conv2d,
209}
210
211impl Encoder {
212 fn new(vs: nn::VarBuilder, in_channels: usize, latent_channels: usize) -> Result<Self> {
213 let block_channels = [128, 256, 512, 512];
214 let base_ch = block_channels[0];
215
216 let conv_in = nn::conv2d(
217 in_channels,
218 base_ch,
219 3,
220 nn::Conv2dConfig {
221 padding: 1,
222 ..Default::default()
223 },
224 vs.pp("conv_in"),
225 )?;
226
227 let mut down_blocks = Vec::new();
228 let mut downsamplers = Vec::new();
229 let mut ch = base_ch;
230 let vs_down = vs.pp("down_blocks");
231 for (i, &out_ch) in block_channels.iter().enumerate() {
232 let vs_block = vs_down.pp(i.to_string());
233 let mut resnets = Vec::new();
234 for j in 0..2 {
235 let in_ch = if j == 0 { ch } else { out_ch };
236 resnets.push(ResnetBlock::new(
237 vs_block.pp("resnets").pp(j.to_string()),
238 in_ch,
239 out_ch,
240 )?);
241 }
242 ch = out_ch;
243 down_blocks.push(resnets);
244 if i < block_channels.len() - 1 {
245 downsamplers.push(Some(Downsample::new(vs_block.pp("downsamplers.0"), ch)?));
246 } else {
247 downsamplers.push(None);
248 }
249 }
250
251 let vs_mid = vs.pp("mid_block");
252 let mid_block_1 = ResnetBlock::new(vs_mid.pp("resnets.0"), ch, ch)?;
253 let mid_attn = AttentionBlock::new(vs_mid.pp("attentions.0"), ch)?;
254 let mid_block_2 = ResnetBlock::new(vs_mid.pp("resnets.1"), ch, ch)?;
255
256 let conv_norm_out = nn::group_norm(32, ch, 1e-6, vs.pp("conv_norm_out"))?;
257 let conv_out = nn::conv2d(
259 ch,
260 latent_channels * 2,
261 3,
262 nn::Conv2dConfig {
263 padding: 1,
264 ..Default::default()
265 },
266 vs.pp("conv_out"),
267 )?;
268
269 Ok(Self {
270 conv_in,
271 down_blocks,
272 downsamplers,
273 mid_block_1,
274 mid_attn,
275 mid_block_2,
276 conv_norm_out,
277 conv_out,
278 })
279 }
280
281 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
282 let mut h = self.conv_in.forward(xs)?;
283
284 for (resnets, ds) in self.down_blocks.iter().zip(self.downsamplers.iter()) {
285 for resnet in resnets {
286 h = resnet.forward(&h)?;
287 }
288 if let Some(ref downsample) = ds {
289 h = downsample.forward(&h)?;
290 }
291 }
292
293 h = self.mid_block_1.forward(&h)?;
294 h = self.mid_attn.forward(&h)?;
295 h = self.mid_block_2.forward(&h)?;
296
297 h = self.conv_norm_out.forward(&h)?.silu()?;
298 self.conv_out.forward(&h)
299 }
300}
301
302#[derive(Debug)]
308struct Decoder {
309 conv_in: nn::Conv2d,
310 mid_block_1: ResnetBlock,
311 mid_attn: AttentionBlock,
312 mid_block_2: ResnetBlock,
313 up_blocks: Vec<Vec<ResnetBlock>>,
314 upsamplers: Vec<Option<Upsample>>,
315 conv_norm_out: nn::GroupNorm,
316 conv_out: nn::Conv2d,
317}
318
319impl Decoder {
320 fn new(vs: nn::VarBuilder, latent_channels: usize, out_channels: usize) -> Result<Self> {
321 let block_channels = [512, 512, 256, 128];
322 let first_ch = block_channels[0];
323
324 let conv_in = nn::conv2d(
325 latent_channels,
326 first_ch,
327 3,
328 nn::Conv2dConfig {
329 padding: 1,
330 ..Default::default()
331 },
332 vs.pp("conv_in"),
333 )?;
334
335 let vs_mid = vs.pp("mid_block");
336 let mid_block_1 = ResnetBlock::new(vs_mid.pp("resnets.0"), first_ch, first_ch)?;
337 let mid_attn = AttentionBlock::new(vs_mid.pp("attentions.0"), first_ch)?;
338 let mid_block_2 = ResnetBlock::new(vs_mid.pp("resnets.1"), first_ch, first_ch)?;
339
340 let mut up_blocks = Vec::new();
341 let mut upsamplers = Vec::new();
342 let mut ch = first_ch;
343 let vs_up = vs.pp("up_blocks");
344 for (i, &out_ch) in block_channels.iter().enumerate() {
345 let vs_block = vs_up.pp(i.to_string());
346 let mut resnets = Vec::new();
347 for j in 0..3 {
348 let in_ch = if j == 0 { ch } else { out_ch };
349 resnets.push(ResnetBlock::new(
350 vs_block.pp("resnets").pp(j.to_string()),
351 in_ch,
352 out_ch,
353 )?);
354 }
355 ch = out_ch;
356 up_blocks.push(resnets);
357 if i < block_channels.len() - 1 {
358 upsamplers.push(Some(Upsample::new(vs_block.pp("upsamplers.0"), ch)?));
359 } else {
360 upsamplers.push(None);
361 }
362 }
363
364 let conv_norm_out = nn::group_norm(32, ch, 1e-6, vs.pp("conv_norm_out"))?;
365 let conv_out = nn::conv2d(
366 ch,
367 out_channels,
368 3,
369 nn::Conv2dConfig {
370 padding: 1,
371 ..Default::default()
372 },
373 vs.pp("conv_out"),
374 )?;
375
376 Ok(Self {
377 conv_in,
378 mid_block_1,
379 mid_attn,
380 mid_block_2,
381 up_blocks,
382 upsamplers,
383 conv_norm_out,
384 conv_out,
385 })
386 }
387
388 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
389 let mut h = self.conv_in.forward(xs)?;
390
391 h = self.mid_block_1.forward(&h)?;
392 h = self.mid_attn.forward(&h)?;
393 h = self.mid_block_2.forward(&h)?;
394
395 for (resnets, us) in self.up_blocks.iter().zip(self.upsamplers.iter()) {
396 for resnet in resnets {
397 h = resnet.forward(&h)?;
398 }
399 if let Some(ref upsample) = us {
400 h = upsample.forward(&h)?;
401 }
402 }
403
404 h = self.conv_norm_out.forward(&h)?.silu()?;
405 self.conv_out.forward(&h)
406 }
407}
408
409#[derive(Debug)]
415pub struct Vae {
416 encoder: Encoder,
417 decoder: Decoder,
418 quant_conv: nn::Conv2d,
420 post_quant_conv: nn::Conv2d,
422 scaling_factor: f64,
424}
425
426impl Vae {
427 pub fn new(vs: nn::VarBuilder, latent_channels: usize, scaling_factor: f64) -> Result<Self> {
429 let encoder = Encoder::new(vs.pp("encoder"), 3, latent_channels)?;
430 let decoder = Decoder::new(vs.pp("decoder"), latent_channels, 3)?;
431 let quant_conv = nn::conv2d(
432 latent_channels * 2,
433 latent_channels * 2,
434 1,
435 Default::default(),
436 vs.pp("quant_conv"),
437 )?;
438 let post_quant_conv = nn::conv2d(
439 latent_channels,
440 latent_channels,
441 1,
442 Default::default(),
443 vs.pp("post_quant_conv"),
444 )?;
445 Ok(Self {
446 encoder,
447 decoder,
448 quant_conv,
449 post_quant_conv,
450 scaling_factor,
451 })
452 }
453
454 pub fn encode(&self, image: &Tensor) -> Result<Tensor> {
460 let h = self.encoder.forward(image)?;
461 let moments = self.quant_conv.forward(&h)?;
462 let channels = moments.dim(1)? / 2;
463 let mean = moments.narrow(1, 0, channels)?;
465 mean * self.scaling_factor
467 }
468
469 pub fn decode(&self, latents: &Tensor) -> Result<Tensor> {
475 let z = (latents * (1.0 / self.scaling_factor))?;
476 let z = self.post_quant_conv.forward(&z)?;
477 self.decoder.forward(&z)
478 }
479}