1#![allow(dead_code)]
2use super::unet_2d_blocks::{
8 DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
9 UpDecoderBlock2D, UpDecoderBlock2DConfig,
10};
11use candle::{Result, Tensor};
12use candle_nn as nn;
13use candle_nn::Module;
14
/// Hyper-parameters of the VAE `Encoder`.
#[derive(Debug, Clone)]
struct EncoderConfig {
    /// Output channel count of each down block, in order.
    block_out_channels: Vec<usize>,
    /// Number of resnet layers in every down block.
    layers_per_block: usize,
    /// Group count shared by every group-norm layer.
    norm_num_groups: usize,
    /// When set, the final convolution emits `2 * out_channels` channels.
    double_z: bool,
}

impl Default for EncoderConfig {
    /// One 64-channel block, two layers per block, 32 norm groups and a doubled output.
    fn default() -> Self {
        let block_out_channels = Vec::from([64]);
        EncoderConfig {
            block_out_channels,
            layers_per_block: 2,
            norm_num_groups: 32,
            double_z: true,
        }
    }
}
34
/// VAE encoder: maps an input image tensor to the latent channels (doubled when
/// `EncoderConfig::double_z` is set).
#[derive(Debug)]
struct Encoder {
    /// 3x3 convolution from the input channels to `block_out_channels[0]`.
    conv_in: nn::Conv2d,
    /// One down block per entry of `config.block_out_channels`.
    down_blocks: Vec<DownEncoderBlock2D>,
    /// Mid block operating at the last block's channel count.
    mid_block: UNetMidBlock2D,
    /// Group norm applied before the final SiLU and `conv_out`.
    conv_norm_out: nn::GroupNorm,
    /// 3x3 convolution producing the (optionally doubled) latent channels.
    conv_out: nn::Conv2d,
    // Stored alongside the layers but not read by `forward`, hence the allow.
    #[allow(dead_code)]
    config: EncoderConfig,
}
45
46impl Encoder {
47 fn new(
48 vs: nn::VarBuilder,
49 in_channels: usize,
50 out_channels: usize,
51 config: EncoderConfig,
52 ) -> Result<Self> {
53 let conv_cfg = nn::Conv2dConfig {
54 padding: 1,
55 ..Default::default()
56 };
57 let conv_in = nn::conv2d(
58 in_channels,
59 config.block_out_channels[0],
60 3,
61 conv_cfg,
62 vs.pp("conv_in"),
63 )?;
64 let mut down_blocks = vec![];
65 let vs_down_blocks = vs.pp("down_blocks");
66 for index in 0..config.block_out_channels.len() {
67 let out_channels = config.block_out_channels[index];
68 let in_channels = if index > 0 {
69 config.block_out_channels[index - 1]
70 } else {
71 config.block_out_channels[0]
72 };
73 let is_final = index + 1 == config.block_out_channels.len();
74 let cfg = DownEncoderBlock2DConfig {
75 num_layers: config.layers_per_block,
76 resnet_eps: 1e-6,
77 resnet_groups: config.norm_num_groups,
78 add_downsample: !is_final,
79 downsample_padding: 0,
80 ..Default::default()
81 };
82 let down_block = DownEncoderBlock2D::new(
83 vs_down_blocks.pp(index.to_string()),
84 in_channels,
85 out_channels,
86 cfg,
87 )?;
88 down_blocks.push(down_block)
89 }
90 let last_block_out_channels = *config.block_out_channels.last().unwrap();
91 let mid_cfg = UNetMidBlock2DConfig {
92 resnet_eps: 1e-6,
93 output_scale_factor: 1.,
94 attn_num_head_channels: None,
95 resnet_groups: Some(config.norm_num_groups),
96 ..Default::default()
97 };
98 let mid_block =
99 UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
100 let conv_norm_out = nn::group_norm(
101 config.norm_num_groups,
102 last_block_out_channels,
103 1e-6,
104 vs.pp("conv_norm_out"),
105 )?;
106 let conv_out_channels = if config.double_z {
107 2 * out_channels
108 } else {
109 out_channels
110 };
111 let conv_cfg = nn::Conv2dConfig {
112 padding: 1,
113 ..Default::default()
114 };
115 let conv_out = nn::conv2d(
116 last_block_out_channels,
117 conv_out_channels,
118 3,
119 conv_cfg,
120 vs.pp("conv_out"),
121 )?;
122 Ok(Self {
123 conv_in,
124 down_blocks,
125 mid_block,
126 conv_norm_out,
127 conv_out,
128 config,
129 })
130 }
131}
132
133impl Encoder {
134 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
135 let mut xs = xs.apply(&self.conv_in)?;
136 for down_block in self.down_blocks.iter() {
137 xs = xs.apply(down_block)?
138 }
139 let xs = self
140 .mid_block
141 .forward(&xs, None)?
142 .apply(&self.conv_norm_out)?;
143 nn::ops::silu(&xs)?.apply(&self.conv_out)
144 }
145}
146
/// Hyper-parameters of the VAE `Decoder`.
#[derive(Debug, Clone)]
struct DecoderConfig {
    /// Channel counts per block in encoder order; the decoder walks them in reverse.
    block_out_channels: Vec<usize>,
    /// Base resnet layer count; each up block uses one more than this.
    layers_per_block: usize,
    /// Group count shared by every group-norm layer.
    norm_num_groups: usize,
}

impl Default for DecoderConfig {
    /// A single 64-channel block, two layers per block and 32 norm groups.
    fn default() -> Self {
        const CHANNELS: usize = 64;
        Self {
            norm_num_groups: 32,
            layers_per_block: 2,
            block_out_channels: vec![CHANNELS],
        }
    }
}
164
/// VAE decoder: maps latents back to an image tensor with `out_channels` channels.
#[derive(Debug)]
struct Decoder {
    /// 3x3 convolution from the latent channels to the last entry of `block_out_channels`.
    conv_in: nn::Conv2d,
    /// One up block per entry of `block_out_channels`, visited in reverse order.
    up_blocks: Vec<UpDecoderBlock2D>,
    /// Mid block operating at the last entry of `block_out_channels`.
    mid_block: UNetMidBlock2D,
    /// Group norm over `block_out_channels[0]` channels, applied before SiLU and `conv_out`.
    conv_norm_out: nn::GroupNorm,
    /// 3x3 convolution producing the output image channels.
    conv_out: nn::Conv2d,
    // Stored alongside the layers but not read by `forward`, hence the allow.
    #[allow(dead_code)]
    config: DecoderConfig,
}
175
176impl Decoder {
177 fn new(
178 vs: nn::VarBuilder,
179 in_channels: usize,
180 out_channels: usize,
181 config: DecoderConfig,
182 ) -> Result<Self> {
183 let n_block_out_channels = config.block_out_channels.len();
184 let last_block_out_channels = *config.block_out_channels.last().unwrap();
185 let conv_cfg = nn::Conv2dConfig {
186 padding: 1,
187 ..Default::default()
188 };
189 let conv_in = nn::conv2d(
190 in_channels,
191 last_block_out_channels,
192 3,
193 conv_cfg,
194 vs.pp("conv_in"),
195 )?;
196 let mid_cfg = UNetMidBlock2DConfig {
197 resnet_eps: 1e-6,
198 output_scale_factor: 1.,
199 attn_num_head_channels: None,
200 resnet_groups: Some(config.norm_num_groups),
201 ..Default::default()
202 };
203 let mid_block =
204 UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
205 let mut up_blocks = vec![];
206 let vs_up_blocks = vs.pp("up_blocks");
207 let reversed_block_out_channels: Vec<_> =
208 config.block_out_channels.iter().copied().rev().collect();
209 for index in 0..n_block_out_channels {
210 let out_channels = reversed_block_out_channels[index];
211 let in_channels = if index > 0 {
212 reversed_block_out_channels[index - 1]
213 } else {
214 reversed_block_out_channels[0]
215 };
216 let is_final = index + 1 == n_block_out_channels;
217 let cfg = UpDecoderBlock2DConfig {
218 num_layers: config.layers_per_block + 1,
219 resnet_eps: 1e-6,
220 resnet_groups: config.norm_num_groups,
221 add_upsample: !is_final,
222 ..Default::default()
223 };
224 let up_block = UpDecoderBlock2D::new(
225 vs_up_blocks.pp(index.to_string()),
226 in_channels,
227 out_channels,
228 cfg,
229 )?;
230 up_blocks.push(up_block)
231 }
232 let conv_norm_out = nn::group_norm(
233 config.norm_num_groups,
234 config.block_out_channels[0],
235 1e-6,
236 vs.pp("conv_norm_out"),
237 )?;
238 let conv_cfg = nn::Conv2dConfig {
239 padding: 1,
240 ..Default::default()
241 };
242 let conv_out = nn::conv2d(
243 config.block_out_channels[0],
244 out_channels,
245 3,
246 conv_cfg,
247 vs.pp("conv_out"),
248 )?;
249 Ok(Self {
250 conv_in,
251 up_blocks,
252 mid_block,
253 conv_norm_out,
254 conv_out,
255 config,
256 })
257 }
258}
259
260impl Decoder {
261 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
262 let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?;
263 for up_block in self.up_blocks.iter() {
264 xs = up_block.forward(&xs)?
265 }
266 let xs = self.conv_norm_out.forward(&xs)?;
267 let xs = nn::ops::silu(&xs)?;
268 self.conv_out.forward(&xs)
269 }
270}
271
/// Configuration of `AutoEncoderKL`.
#[derive(Debug, Clone)]
pub struct AutoEncoderKLConfig {
    /// Output channel count of each encoder block (walked in reverse by the decoder).
    pub block_out_channels: Vec<usize>,
    /// Resnet layers per encoder block; decoder blocks use one more.
    pub layers_per_block: usize,
    /// Number of latent channels.
    pub latent_channels: usize,
    /// Group count shared by every group-norm layer.
    pub norm_num_groups: usize,
    /// Whether `quant_conv` is loaded and applied after the encoder.
    pub use_quant_conv: bool,
    /// Whether `post_quant_conv` is loaded and applied before the decoder.
    pub use_post_quant_conv: bool,
}

impl Default for AutoEncoderKLConfig {
    /// One 64-channel block, one layer per block, 4 latent channels, 32 norm groups and
    /// both 1x1 quantisation convolutions enabled.
    fn default() -> Self {
        let with_quant_convs = true;
        Self {
            block_out_channels: vec![64],
            layers_per_block: 1,
            latent_channels: 4,
            norm_num_groups: 32,
            use_quant_conv: with_quant_convs,
            use_post_quant_conv: with_quant_convs,
        }
    }
}
294
295pub struct DiagonalGaussianDistribution {
296 mean: Tensor,
297 std: Tensor,
298}
299
300impl DiagonalGaussianDistribution {
301 pub fn new(parameters: &Tensor) -> Result<Self> {
302 let mut parameters = parameters.chunk(2, 1)?.into_iter();
303 let mean = parameters.next().unwrap();
304 let logvar = parameters.next().unwrap();
305 let std = (logvar * 0.5)?.exp()?;
306 Ok(DiagonalGaussianDistribution { mean, std })
307 }
308
309 pub fn sample(&self) -> Result<Tensor> {
310 let sample = self.mean.randn_like(0., 1.);
311 &self.mean + &self.std * sample
312 }
313}
314
/// Auto-encoder pairing an `Encoder`, whose output is turned into a
/// `DiagonalGaussianDistribution` over latents, with a `Decoder` mapping latents back to
/// images.
#[derive(Debug)]
pub struct AutoEncoderKL {
    encoder: Encoder,
    decoder: Decoder,
    /// Optional 1x1 convolution applied to the encoder output (`use_quant_conv`).
    quant_conv: Option<nn::Conv2d>,
    /// Optional 1x1 convolution applied to latents before decoding (`use_post_quant_conv`).
    post_quant_conv: Option<nn::Conv2d>,
    /// Configuration this model was built from.
    pub config: AutoEncoderKLConfig,
}
326
327impl AutoEncoderKL {
328 pub fn new(
329 vs: nn::VarBuilder,
330 in_channels: usize,
331 out_channels: usize,
332 config: AutoEncoderKLConfig,
333 ) -> Result<Self> {
334 let latent_channels = config.latent_channels;
335 let encoder_cfg = EncoderConfig {
336 block_out_channels: config.block_out_channels.clone(),
337 layers_per_block: config.layers_per_block,
338 norm_num_groups: config.norm_num_groups,
339 double_z: true,
340 };
341 let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
342 let decoder_cfg = DecoderConfig {
343 block_out_channels: config.block_out_channels.clone(),
344 layers_per_block: config.layers_per_block,
345 norm_num_groups: config.norm_num_groups,
346 };
347 let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
348 let conv_cfg = Default::default();
349
350 let quant_conv = {
351 if config.use_quant_conv {
352 Some(nn::conv2d(
353 2 * latent_channels,
354 2 * latent_channels,
355 1,
356 conv_cfg,
357 vs.pp("quant_conv"),
358 )?)
359 } else {
360 None
361 }
362 };
363 let post_quant_conv = {
364 if config.use_post_quant_conv {
365 Some(nn::conv2d(
366 latent_channels,
367 latent_channels,
368 1,
369 conv_cfg,
370 vs.pp("post_quant_conv"),
371 )?)
372 } else {
373 None
374 }
375 };
376 Ok(Self {
377 encoder,
378 decoder,
379 quant_conv,
380 post_quant_conv,
381 config,
382 })
383 }
384
385 pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
387 let xs = self.encoder.forward(xs)?;
388 let parameters = match &self.quant_conv {
389 None => xs,
390 Some(quant_conv) => quant_conv.forward(&xs)?,
391 };
392 DiagonalGaussianDistribution::new(¶meters)
393 }
394
395 pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
397 let xs = match &self.post_quant_conv {
398 None => xs,
399 Some(post_quant_conv) => &post_quant_conv.forward(xs)?,
400 };
401 self.decoder.forward(xs)
402 }
403}