// oxigaf_diffusion/vae.rs
//! Variational Autoencoder (SD 2.1 compatible).
//!
//! Encodes images to latent space and decodes latents back to pixel space.
//! This is a simplified but functional VAE that mirrors the architecture
//! used in Stable Diffusion 2.1.

use candle_core::{Result, Tensor};
use candle_nn as nn;
use candle_nn::Module;

// ---------------------------------------------------------------------------
// Building blocks
// ---------------------------------------------------------------------------

15/// A ResNet block used in both encoder and decoder.
16#[derive(Debug)]
17struct ResnetBlock {
18    norm1: nn::GroupNorm,
19    conv1: nn::Conv2d,
20    norm2: nn::GroupNorm,
21    conv2: nn::Conv2d,
22    residual_conv: Option<nn::Conv2d>,
23}
24
25impl ResnetBlock {
26    fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
27        let norm1 = nn::group_norm(32, in_channels, 1e-6, vs.pp("norm1"))?;
28        let conv1 = nn::conv2d(
29            in_channels,
30            out_channels,
31            3,
32            nn::Conv2dConfig {
33                padding: 1,
34                ..Default::default()
35            },
36            vs.pp("conv1"),
37        )?;
38        let norm2 = nn::group_norm(32, out_channels, 1e-6, vs.pp("norm2"))?;
39        let conv2 = nn::conv2d(
40            out_channels,
41            out_channels,
42            3,
43            nn::Conv2dConfig {
44                padding: 1,
45                ..Default::default()
46            },
47            vs.pp("conv2"),
48        )?;
49        let residual_conv = if in_channels != out_channels {
50            Some(nn::conv2d(
51                in_channels,
52                out_channels,
53                1,
54                Default::default(),
55                vs.pp("nin_shortcut"),
56            )?)
57        } else {
58            None
59        };
60        Ok(Self {
61            norm1,
62            conv1,
63            norm2,
64            conv2,
65            residual_conv,
66        })
67    }
68}
69
70impl Module for ResnetBlock {
71    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
72        let residual = if let Some(ref conv) = self.residual_conv {
73            conv.forward(xs)?
74        } else {
75            xs.clone()
76        };
77        let h = self.norm1.forward(xs)?.silu()?;
78        let h = self.conv1.forward(&h)?;
79        let h = self.norm2.forward(&h)?.silu()?;
80        let h = self.conv2.forward(&h)?;
81        h + residual
82    }
83}
84
85/// Self-attention block for the VAE mid-block.
86#[derive(Debug)]
87struct AttentionBlock {
88    group_norm: nn::GroupNorm,
89    to_qkv: nn::Conv2d,
90    to_out: nn::Conv2d,
91    channels: usize,
92}
93
94impl AttentionBlock {
95    fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
96        let group_norm = nn::group_norm(32, channels, 1e-6, vs.pp("group_norm"))?;
97        let to_qkv = nn::conv2d(
98            channels,
99            channels * 3,
100            1,
101            Default::default(),
102            vs.pp("to_qkv"),
103        )?;
104        let to_out = nn::conv2d(channels, channels, 1, Default::default(), vs.pp("to_out"))?;
105        Ok(Self {
106            group_norm,
107            to_qkv,
108            to_out,
109            channels,
110        })
111    }
112}
113
114impl Module for AttentionBlock {
115    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
116        let residual = xs;
117        let (b, _c, h, w) = xs.dims4()?;
118        let xs = self.group_norm.forward(xs)?;
119        let qkv = self.to_qkv.forward(&xs)?;
120        let qkv = qkv.reshape((b, 3, self.channels, h * w))?;
121        let q = qkv.narrow(1, 0, 1)?.squeeze(1)?;
122        let k = qkv.narrow(1, 1, 1)?.squeeze(1)?;
123        let v = qkv.narrow(1, 2, 1)?.squeeze(1)?;
124
125        let scale = (self.channels as f64).powf(-0.5);
126        let attn = (q.transpose(1, 2)?.matmul(&k)? * scale)?;
127        let attn = nn::ops::softmax_last_dim(&attn)?;
128        let out = v.matmul(&attn.transpose(1, 2)?)?;
129        let out = out.reshape((b, self.channels, h, w))?;
130        let out = self.to_out.forward(&out)?;
131        out + residual
132    }
133}
134
135/// Downsample block (strided convolution).
136#[derive(Debug)]
137struct Downsample {
138    conv: nn::Conv2d,
139}
140
141impl Downsample {
142    fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
143        let conv = nn::conv2d(
144            channels,
145            channels,
146            3,
147            nn::Conv2dConfig {
148                stride: 2,
149                padding: 1,
150                ..Default::default()
151            },
152            vs.pp("conv"),
153        )?;
154        Ok(Self { conv })
155    }
156}
157
158impl Module for Downsample {
159    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
160        self.conv.forward(xs)
161    }
162}
163
164/// Upsample block (nearest-neighbour interpolation + conv).
165#[derive(Debug)]
166struct Upsample {
167    conv: nn::Conv2d,
168}
169
170impl Upsample {
171    fn new(vs: nn::VarBuilder, channels: usize) -> Result<Self> {
172        let conv = nn::conv2d(
173            channels,
174            channels,
175            3,
176            nn::Conv2dConfig {
177                padding: 1,
178                ..Default::default()
179            },
180            vs.pp("conv"),
181        )?;
182        Ok(Self { conv })
183    }
184}
185
186impl Module for Upsample {
187    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
188        let (_, _, h, w) = xs.dims4()?;
189        let xs = xs.upsample_nearest2d(h * 2, w * 2)?;
190        self.conv.forward(&xs)
191    }
192}
193
// ---------------------------------------------------------------------------
// Encoder
// ---------------------------------------------------------------------------

198/// VAE encoder: image → latent distribution parameters.
199#[derive(Debug)]
200struct Encoder {
201    conv_in: nn::Conv2d,
202    down_blocks: Vec<Vec<ResnetBlock>>,
203    downsamplers: Vec<Option<Downsample>>,
204    mid_block_1: ResnetBlock,
205    mid_attn: AttentionBlock,
206    mid_block_2: ResnetBlock,
207    conv_norm_out: nn::GroupNorm,
208    conv_out: nn::Conv2d,
209}
210
211impl Encoder {
212    fn new(vs: nn::VarBuilder, in_channels: usize, latent_channels: usize) -> Result<Self> {
213        let block_channels = [128, 256, 512, 512];
214        let base_ch = block_channels[0];
215
216        let conv_in = nn::conv2d(
217            in_channels,
218            base_ch,
219            3,
220            nn::Conv2dConfig {
221                padding: 1,
222                ..Default::default()
223            },
224            vs.pp("conv_in"),
225        )?;
226
227        let mut down_blocks = Vec::new();
228        let mut downsamplers = Vec::new();
229        let mut ch = base_ch;
230        let vs_down = vs.pp("down_blocks");
231        for (i, &out_ch) in block_channels.iter().enumerate() {
232            let vs_block = vs_down.pp(i.to_string());
233            let mut resnets = Vec::new();
234            for j in 0..2 {
235                let in_ch = if j == 0 { ch } else { out_ch };
236                resnets.push(ResnetBlock::new(
237                    vs_block.pp("resnets").pp(j.to_string()),
238                    in_ch,
239                    out_ch,
240                )?);
241            }
242            ch = out_ch;
243            down_blocks.push(resnets);
244            if i < block_channels.len() - 1 {
245                downsamplers.push(Some(Downsample::new(vs_block.pp("downsamplers.0"), ch)?));
246            } else {
247                downsamplers.push(None);
248            }
249        }
250
251        let vs_mid = vs.pp("mid_block");
252        let mid_block_1 = ResnetBlock::new(vs_mid.pp("resnets.0"), ch, ch)?;
253        let mid_attn = AttentionBlock::new(vs_mid.pp("attentions.0"), ch)?;
254        let mid_block_2 = ResnetBlock::new(vs_mid.pp("resnets.1"), ch, ch)?;
255
256        let conv_norm_out = nn::group_norm(32, ch, 1e-6, vs.pp("conv_norm_out"))?;
257        // Output 2× latent channels for mean + log_var
258        let conv_out = nn::conv2d(
259            ch,
260            latent_channels * 2,
261            3,
262            nn::Conv2dConfig {
263                padding: 1,
264                ..Default::default()
265            },
266            vs.pp("conv_out"),
267        )?;
268
269        Ok(Self {
270            conv_in,
271            down_blocks,
272            downsamplers,
273            mid_block_1,
274            mid_attn,
275            mid_block_2,
276            conv_norm_out,
277            conv_out,
278        })
279    }
280
281    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
282        let mut h = self.conv_in.forward(xs)?;
283
284        for (resnets, ds) in self.down_blocks.iter().zip(self.downsamplers.iter()) {
285            for resnet in resnets {
286                h = resnet.forward(&h)?;
287            }
288            if let Some(ref downsample) = ds {
289                h = downsample.forward(&h)?;
290            }
291        }
292
293        h = self.mid_block_1.forward(&h)?;
294        h = self.mid_attn.forward(&h)?;
295        h = self.mid_block_2.forward(&h)?;
296
297        h = self.conv_norm_out.forward(&h)?.silu()?;
298        self.conv_out.forward(&h)
299    }
300}
301
// ---------------------------------------------------------------------------
// Decoder
// ---------------------------------------------------------------------------

306/// VAE decoder: latent → image.
307#[derive(Debug)]
308struct Decoder {
309    conv_in: nn::Conv2d,
310    mid_block_1: ResnetBlock,
311    mid_attn: AttentionBlock,
312    mid_block_2: ResnetBlock,
313    up_blocks: Vec<Vec<ResnetBlock>>,
314    upsamplers: Vec<Option<Upsample>>,
315    conv_norm_out: nn::GroupNorm,
316    conv_out: nn::Conv2d,
317}
318
319impl Decoder {
320    fn new(vs: nn::VarBuilder, latent_channels: usize, out_channels: usize) -> Result<Self> {
321        let block_channels = [512, 512, 256, 128];
322        let first_ch = block_channels[0];
323
324        let conv_in = nn::conv2d(
325            latent_channels,
326            first_ch,
327            3,
328            nn::Conv2dConfig {
329                padding: 1,
330                ..Default::default()
331            },
332            vs.pp("conv_in"),
333        )?;
334
335        let vs_mid = vs.pp("mid_block");
336        let mid_block_1 = ResnetBlock::new(vs_mid.pp("resnets.0"), first_ch, first_ch)?;
337        let mid_attn = AttentionBlock::new(vs_mid.pp("attentions.0"), first_ch)?;
338        let mid_block_2 = ResnetBlock::new(vs_mid.pp("resnets.1"), first_ch, first_ch)?;
339
340        let mut up_blocks = Vec::new();
341        let mut upsamplers = Vec::new();
342        let mut ch = first_ch;
343        let vs_up = vs.pp("up_blocks");
344        for (i, &out_ch) in block_channels.iter().enumerate() {
345            let vs_block = vs_up.pp(i.to_string());
346            let mut resnets = Vec::new();
347            for j in 0..3 {
348                let in_ch = if j == 0 { ch } else { out_ch };
349                resnets.push(ResnetBlock::new(
350                    vs_block.pp("resnets").pp(j.to_string()),
351                    in_ch,
352                    out_ch,
353                )?);
354            }
355            ch = out_ch;
356            up_blocks.push(resnets);
357            if i < block_channels.len() - 1 {
358                upsamplers.push(Some(Upsample::new(vs_block.pp("upsamplers.0"), ch)?));
359            } else {
360                upsamplers.push(None);
361            }
362        }
363
364        let conv_norm_out = nn::group_norm(32, ch, 1e-6, vs.pp("conv_norm_out"))?;
365        let conv_out = nn::conv2d(
366            ch,
367            out_channels,
368            3,
369            nn::Conv2dConfig {
370                padding: 1,
371                ..Default::default()
372            },
373            vs.pp("conv_out"),
374        )?;
375
376        Ok(Self {
377            conv_in,
378            mid_block_1,
379            mid_attn,
380            mid_block_2,
381            up_blocks,
382            upsamplers,
383            conv_norm_out,
384            conv_out,
385        })
386    }
387
388    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
389        let mut h = self.conv_in.forward(xs)?;
390
391        h = self.mid_block_1.forward(&h)?;
392        h = self.mid_attn.forward(&h)?;
393        h = self.mid_block_2.forward(&h)?;
394
395        for (resnets, us) in self.up_blocks.iter().zip(self.upsamplers.iter()) {
396            for resnet in resnets {
397                h = resnet.forward(&h)?;
398            }
399            if let Some(ref upsample) = us {
400                h = upsample.forward(&h)?;
401            }
402        }
403
404        h = self.conv_norm_out.forward(&h)?.silu()?;
405        self.conv_out.forward(&h)
406    }
407}
408
// ---------------------------------------------------------------------------
// VAE public API
// ---------------------------------------------------------------------------

413/// Variational Autoencoder for encoding/decoding between pixel and latent space.
414#[derive(Debug)]
415pub struct Vae {
416    encoder: Encoder,
417    decoder: Decoder,
418    /// Learned post-quant convolution (1×1).
419    quant_conv: nn::Conv2d,
420    /// Learned pre-decode convolution (1×1).
421    post_quant_conv: nn::Conv2d,
422    /// Scaling factor for the latent space.
423    scaling_factor: f64,
424}
425
426impl Vae {
427    /// Load a VAE from a VarBuilder.
428    pub fn new(vs: nn::VarBuilder, latent_channels: usize, scaling_factor: f64) -> Result<Self> {
429        let encoder = Encoder::new(vs.pp("encoder"), 3, latent_channels)?;
430        let decoder = Decoder::new(vs.pp("decoder"), latent_channels, 3)?;
431        let quant_conv = nn::conv2d(
432            latent_channels * 2,
433            latent_channels * 2,
434            1,
435            Default::default(),
436            vs.pp("quant_conv"),
437        )?;
438        let post_quant_conv = nn::conv2d(
439            latent_channels,
440            latent_channels,
441            1,
442            Default::default(),
443            vs.pp("post_quant_conv"),
444        )?;
445        Ok(Self {
446            encoder,
447            decoder,
448            quant_conv,
449            post_quant_conv,
450            scaling_factor,
451        })
452    }
453
454    /// Encode an image to latent space (returns the mean of the posterior).
455    ///
456    /// - `image`: `(B, 3, H, W)` tensor in `[-1, 1]` range.
457    ///
458    /// Returns `(B, latent_channels, H/8, W/8)`.
459    pub fn encode(&self, image: &Tensor) -> Result<Tensor> {
460        let h = self.encoder.forward(image)?;
461        let moments = self.quant_conv.forward(&h)?;
462        let channels = moments.dim(1)? / 2;
463        // Take the mean (first half of channels)
464        let mean = moments.narrow(1, 0, channels)?;
465        // Scale
466        mean * self.scaling_factor
467    }
468
469    /// Decode a latent tensor back to pixel space.
470    ///
471    /// - `latents`: `(B, latent_channels, h, w)` scaled latent tensor.
472    ///
473    /// Returns `(B, 3, H, W)` in `[-1, 1]` range.
474    pub fn decode(&self, latents: &Tensor) -> Result<Tensor> {
475        let z = (latents * (1.0 / self.scaling_factor))?;
476        let z = self.post_quant_conv.forward(&z)?;
477        self.decoder.forward(&z)
478    }
479}