Skip to main content

candle_transformers/models/stable_diffusion/
vae.rs

1#![allow(dead_code)]
2//! # Variational Auto-Encoder (VAE) Models.
3//!
//! Auto-encoder models compress their input into a (usually smaller) latent space
//! before expanding it back to its original shape, so the latent values form a
//! compressed representation of the original information.
7use super::unet_2d_blocks::{
8    DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
9    UpDecoderBlock2D, UpDecoderBlock2DConfig,
10};
11use candle::{Result, Tensor};
12use candle_nn as nn;
13use candle_nn::Module;
14
/// Hyper-parameters of the VAE encoder.
///
/// The encoder always stacks `DownEncoderBlock2D` blocks; only their widths,
/// depth and normalisation grouping are configurable here.
#[derive(Debug, Clone)]
struct EncoderConfig {
    /// Output channel count of each down block, in order.
    block_out_channels: Vec<usize>,
    /// Number of resnet layers inside each down block.
    layers_per_block: usize,
    /// Number of groups used by every group-norm layer.
    norm_num_groups: usize,
    /// When set, the final convolution emits twice the requested output channels.
    double_z: bool,
}

impl Default for EncoderConfig {
    fn default() -> Self {
        Self {
            double_z: true,
            norm_num_groups: 32,
            layers_per_block: 2,
            block_out_channels: vec![64],
        }
    }
}
34
35#[derive(Debug)]
36struct Encoder {
37    conv_in: nn::Conv2d,
38    down_blocks: Vec<DownEncoderBlock2D>,
39    mid_block: UNetMidBlock2D,
40    conv_norm_out: nn::GroupNorm,
41    conv_out: nn::Conv2d,
42    #[allow(dead_code)]
43    config: EncoderConfig,
44}
45
46impl Encoder {
47    fn new(
48        vs: nn::VarBuilder,
49        in_channels: usize,
50        out_channels: usize,
51        config: EncoderConfig,
52    ) -> Result<Self> {
53        let conv_cfg = nn::Conv2dConfig {
54            padding: 1,
55            ..Default::default()
56        };
57        let conv_in = nn::conv2d(
58            in_channels,
59            config.block_out_channels[0],
60            3,
61            conv_cfg,
62            vs.pp("conv_in"),
63        )?;
64        let mut down_blocks = vec![];
65        let vs_down_blocks = vs.pp("down_blocks");
66        for index in 0..config.block_out_channels.len() {
67            let out_channels = config.block_out_channels[index];
68            let in_channels = if index > 0 {
69                config.block_out_channels[index - 1]
70            } else {
71                config.block_out_channels[0]
72            };
73            let is_final = index + 1 == config.block_out_channels.len();
74            let cfg = DownEncoderBlock2DConfig {
75                num_layers: config.layers_per_block,
76                resnet_eps: 1e-6,
77                resnet_groups: config.norm_num_groups,
78                add_downsample: !is_final,
79                downsample_padding: 0,
80                ..Default::default()
81            };
82            let down_block = DownEncoderBlock2D::new(
83                vs_down_blocks.pp(index.to_string()),
84                in_channels,
85                out_channels,
86                cfg,
87            )?;
88            down_blocks.push(down_block)
89        }
90        let last_block_out_channels = *config.block_out_channels.last().unwrap();
91        let mid_cfg = UNetMidBlock2DConfig {
92            resnet_eps: 1e-6,
93            output_scale_factor: 1.,
94            attn_num_head_channels: None,
95            resnet_groups: Some(config.norm_num_groups),
96            ..Default::default()
97        };
98        let mid_block =
99            UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
100        let conv_norm_out = nn::group_norm(
101            config.norm_num_groups,
102            last_block_out_channels,
103            1e-6,
104            vs.pp("conv_norm_out"),
105        )?;
106        let conv_out_channels = if config.double_z {
107            2 * out_channels
108        } else {
109            out_channels
110        };
111        let conv_cfg = nn::Conv2dConfig {
112            padding: 1,
113            ..Default::default()
114        };
115        let conv_out = nn::conv2d(
116            last_block_out_channels,
117            conv_out_channels,
118            3,
119            conv_cfg,
120            vs.pp("conv_out"),
121        )?;
122        Ok(Self {
123            conv_in,
124            down_blocks,
125            mid_block,
126            conv_norm_out,
127            conv_out,
128            config,
129        })
130    }
131}
132
133impl Encoder {
134    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
135        let mut xs = xs.apply(&self.conv_in)?;
136        for down_block in self.down_blocks.iter() {
137            xs = xs.apply(down_block)?
138        }
139        let xs = self
140            .mid_block
141            .forward(&xs, None)?
142            .apply(&self.conv_norm_out)?;
143        nn::ops::silu(&xs)?.apply(&self.conv_out)
144    }
145}
146
/// Hyper-parameters of the VAE decoder.
///
/// The decoder always stacks `UpDecoderBlock2D` blocks; `block_out_channels`
/// is given in encoder order and walked in reverse when building the decoder.
#[derive(Debug, Clone)]
struct DecoderConfig {
    /// Block widths in encoder order (the decoder consumes them last-to-first).
    block_out_channels: Vec<usize>,
    /// Base number of resnet layers per up block (the decoder adds one).
    layers_per_block: usize,
    /// Number of groups used by every group-norm layer.
    norm_num_groups: usize,
}

impl Default for DecoderConfig {
    fn default() -> Self {
        Self {
            norm_num_groups: 32,
            layers_per_block: 2,
            block_out_channels: vec![64],
        }
    }
}
164
165#[derive(Debug)]
166struct Decoder {
167    conv_in: nn::Conv2d,
168    up_blocks: Vec<UpDecoderBlock2D>,
169    mid_block: UNetMidBlock2D,
170    conv_norm_out: nn::GroupNorm,
171    conv_out: nn::Conv2d,
172    #[allow(dead_code)]
173    config: DecoderConfig,
174}
175
176impl Decoder {
177    fn new(
178        vs: nn::VarBuilder,
179        in_channels: usize,
180        out_channels: usize,
181        config: DecoderConfig,
182    ) -> Result<Self> {
183        let n_block_out_channels = config.block_out_channels.len();
184        let last_block_out_channels = *config.block_out_channels.last().unwrap();
185        let conv_cfg = nn::Conv2dConfig {
186            padding: 1,
187            ..Default::default()
188        };
189        let conv_in = nn::conv2d(
190            in_channels,
191            last_block_out_channels,
192            3,
193            conv_cfg,
194            vs.pp("conv_in"),
195        )?;
196        let mid_cfg = UNetMidBlock2DConfig {
197            resnet_eps: 1e-6,
198            output_scale_factor: 1.,
199            attn_num_head_channels: None,
200            resnet_groups: Some(config.norm_num_groups),
201            ..Default::default()
202        };
203        let mid_block =
204            UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
205        let mut up_blocks = vec![];
206        let vs_up_blocks = vs.pp("up_blocks");
207        let reversed_block_out_channels: Vec<_> =
208            config.block_out_channels.iter().copied().rev().collect();
209        for index in 0..n_block_out_channels {
210            let out_channels = reversed_block_out_channels[index];
211            let in_channels = if index > 0 {
212                reversed_block_out_channels[index - 1]
213            } else {
214                reversed_block_out_channels[0]
215            };
216            let is_final = index + 1 == n_block_out_channels;
217            let cfg = UpDecoderBlock2DConfig {
218                num_layers: config.layers_per_block + 1,
219                resnet_eps: 1e-6,
220                resnet_groups: config.norm_num_groups,
221                add_upsample: !is_final,
222                ..Default::default()
223            };
224            let up_block = UpDecoderBlock2D::new(
225                vs_up_blocks.pp(index.to_string()),
226                in_channels,
227                out_channels,
228                cfg,
229            )?;
230            up_blocks.push(up_block)
231        }
232        let conv_norm_out = nn::group_norm(
233            config.norm_num_groups,
234            config.block_out_channels[0],
235            1e-6,
236            vs.pp("conv_norm_out"),
237        )?;
238        let conv_cfg = nn::Conv2dConfig {
239            padding: 1,
240            ..Default::default()
241        };
242        let conv_out = nn::conv2d(
243            config.block_out_channels[0],
244            out_channels,
245            3,
246            conv_cfg,
247            vs.pp("conv_out"),
248        )?;
249        Ok(Self {
250            conv_in,
251            up_blocks,
252            mid_block,
253            conv_norm_out,
254            conv_out,
255            config,
256        })
257    }
258}
259
260impl Decoder {
261    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
262        let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?;
263        for up_block in self.up_blocks.iter() {
264            xs = up_block.forward(&xs)?
265        }
266        let xs = self.conv_norm_out.forward(&xs)?;
267        let xs = nn::ops::silu(&xs)?;
268        self.conv_out.forward(&xs)
269    }
270}
271
/// Configuration of [`AutoEncoderKL`].
#[derive(Debug, Clone)]
pub struct AutoEncoderKLConfig {
    /// Output width of each encoder down block (the decoder uses them in reverse).
    pub block_out_channels: Vec<usize>,
    /// Resnet layers per encoder block; decoder blocks use one more.
    pub layers_per_block: usize,
    /// Channel count of the latent space (the encoder emits twice this many).
    pub latent_channels: usize,
    /// Number of groups used by every group-norm layer.
    pub norm_num_groups: usize,
    /// Apply a 1x1 `quant_conv` to the encoder output.
    pub use_quant_conv: bool,
    /// Apply a 1x1 `post_quant_conv` before decoding.
    pub use_post_quant_conv: bool,
}

impl Default for AutoEncoderKLConfig {
    fn default() -> Self {
        Self {
            latent_channels: 4,
            layers_per_block: 1,
            norm_num_groups: 32,
            block_out_channels: vec![64],
            use_post_quant_conv: true,
            use_quant_conv: true,
        }
    }
}
294
295pub struct DiagonalGaussianDistribution {
296    mean: Tensor,
297    std: Tensor,
298}
299
300impl DiagonalGaussianDistribution {
301    pub fn new(parameters: &Tensor) -> Result<Self> {
302        let mut parameters = parameters.chunk(2, 1)?.into_iter();
303        let mean = parameters.next().unwrap();
304        let logvar = parameters.next().unwrap();
305        let std = (logvar * 0.5)?.exp()?;
306        Ok(DiagonalGaussianDistribution { mean, std })
307    }
308
309    pub fn sample(&self) -> Result<Tensor> {
310        let sample = self.mean.randn_like(0., 1.);
311        &self.mean + &self.std * sample
312    }
313}
314
315// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485
316// This implementation is specific to the config used in stable-diffusion-v1-5
317// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
318#[derive(Debug)]
319pub struct AutoEncoderKL {
320    encoder: Encoder,
321    decoder: Decoder,
322    quant_conv: Option<nn::Conv2d>,
323    post_quant_conv: Option<nn::Conv2d>,
324    pub config: AutoEncoderKLConfig,
325}
326
327impl AutoEncoderKL {
328    pub fn new(
329        vs: nn::VarBuilder,
330        in_channels: usize,
331        out_channels: usize,
332        config: AutoEncoderKLConfig,
333    ) -> Result<Self> {
334        let latent_channels = config.latent_channels;
335        let encoder_cfg = EncoderConfig {
336            block_out_channels: config.block_out_channels.clone(),
337            layers_per_block: config.layers_per_block,
338            norm_num_groups: config.norm_num_groups,
339            double_z: true,
340        };
341        let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
342        let decoder_cfg = DecoderConfig {
343            block_out_channels: config.block_out_channels.clone(),
344            layers_per_block: config.layers_per_block,
345            norm_num_groups: config.norm_num_groups,
346        };
347        let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
348        let conv_cfg = Default::default();
349
350        let quant_conv = {
351            if config.use_quant_conv {
352                Some(nn::conv2d(
353                    2 * latent_channels,
354                    2 * latent_channels,
355                    1,
356                    conv_cfg,
357                    vs.pp("quant_conv"),
358                )?)
359            } else {
360                None
361            }
362        };
363        let post_quant_conv = {
364            if config.use_post_quant_conv {
365                Some(nn::conv2d(
366                    latent_channels,
367                    latent_channels,
368                    1,
369                    conv_cfg,
370                    vs.pp("post_quant_conv"),
371                )?)
372            } else {
373                None
374            }
375        };
376        Ok(Self {
377            encoder,
378            decoder,
379            quant_conv,
380            post_quant_conv,
381            config,
382        })
383    }
384
385    /// Returns the distribution in the latent space.
386    pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
387        let xs = self.encoder.forward(xs)?;
388        let parameters = match &self.quant_conv {
389            None => xs,
390            Some(quant_conv) => quant_conv.forward(&xs)?,
391        };
392        DiagonalGaussianDistribution::new(&parameters)
393    }
394
395    /// Takes as input some sampled values.
396    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
397        let xs = match &self.post_quant_conv {
398            None => xs,
399            Some(post_quant_conv) => &post_quant_conv.forward(xs)?,
400        };
401        self.decoder.forward(xs)
402    }
403}