klaster 0.2.0

Machine learning library providing modern clustering algorithms for the Rust programming language
Documentation
// Copyright (C) 2025 Piotr Jabłoński
// Extended copyright information can be found in the LICENSE file.

//! Convolutional autoencoder configuration and architecture.
//! Provides [`Autoencoder`] and [`AutoencoderConfig`] for building a convolutional encoder-decoder model used by SDC.
//!
//! # See also
//! [`crate::sdc::SDCConfig`]

use crate::sdc::dataset::Batch;
use burn::nn::loss::{MseLoss, Reduction};
use burn::tensor::backend::AutodiffBackend;
use burn::train::{RegressionOutput, TrainOutput, TrainStep, ValidStep};
use burn::{
    nn::{
        GroupNorm, GroupNormConfig, LeakyRelu, Linear, LinearConfig, PaddingConfig2d, Sigmoid,
        conv::{Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig},
    },
    prelude::*,
};

/// Convolutional autoencoder model used to learn latent embeddings.
/// Encodes images into a latent vector and reconstructs the input with a decoder.
///
/// Build it with [`AutoencoderConfig::init`] and run it with [`Autoencoder::forward`].
#[derive(Module, Debug)]
pub struct Autoencoder<B: Backend> {
    /// Convolution stack plus linear projection that produces the latent embedding.
    encoder: Encoder<B>,
    /// Linear expansion plus transposed-convolution stack that reconstructs the input.
    decoder: Decoder<B>,
}

/// Encoder half of [`Autoencoder`].
///
/// It applies two convolution stages, each followed by group normalization and Leaky ReLU.
/// A linear layer then projects the flattened feature map to the latent dimension.
#[derive(Module, Debug)]
struct Encoder<B: Backend> {
    /// First convolution (`input` -> `hidden` channels).
    conv1: Conv2d<B>,
    /// Group normalization over the `hidden` channels, applied after `conv1`.
    norm1: GroupNorm<B>,
    /// Second convolution (`hidden` -> `output` channels).
    conv2: Conv2d<B>,
    /// Group normalization over the `output` channels, applied after `conv2`.
    norm2: GroupNorm<B>,
    /// Projection from the flattened final feature map to `latent_dim`.
    linear: Linear<B>,
    /// Activation shared by both convolution stages.
    leaky_relu: LeakyRelu,
}

/// Decoder half of [`Autoencoder`].
///
/// A linear layer expands the latent vector back to the encoder's flattened feature-map
/// size. Two transposed convolutions follow: the first is followed by group normalization
/// and Leaky ReLU, and the second by a sigmoid.
#[derive(Module, Debug)]
struct Decoder<B: Backend> {
    /// Expansion from `latent_dim` to the flattened encoder feature-map size.
    linear: Linear<B>,
    /// First transposed convolution (`output` -> `hidden` channels).
    conv_trans1: ConvTranspose2d<B>,
    /// Group normalization over the `hidden` channels, applied after `conv_trans1`.
    norm1: GroupNorm<B>,
    /// Second transposed convolution (`hidden` -> `input` channels).
    conv_trans2: ConvTranspose2d<B>,
    /// Activation applied after `norm1`.
    leaky_relu: LeakyRelu,
    /// Final activation; squashes the reconstruction into the (0, 1) range.
    sigmoid: Sigmoid,
}

/// Configuration for the [`Autoencoder`] model.
/// Defines convolutional and normalization parameters for the encoder/decoder stack.
///
/// # Params
/// - `latent_dim`: Dimensionality of the latent space, i.e. the size of the embeddings
///   returned by [`Autoencoder::forward`],
/// - `input_dims`: Spatial dimensions of the input image as `[height, width]`,
/// - `channels`: Channel counts `[input, hidden, output]`. The encoder maps
///   `input -> hidden -> output`, and the decoder maps them back in reverse order,
/// - `groups`: Number of groups used by every group-normalization layer,
/// - `leaky_relu_slope`: Negative slope of the Leaky ReLU activation (default `0.01`),
/// - `kernel_size`: Convolution kernel size as `[height, width]` (default `[3, 3]`),
/// - `stride`: Convolution stride as `[height, width]` (default `[2, 2]`),
/// - `padding`: Padding used by both the convolutions and the transposed convolutions
///   (default `[1, 1]`),
/// - `output_padding`: Extra output padding for the transposed convolutions
///   (default `[1, 1]`).
///
/// # Note
/// Default values should work fine for the majority of simple image datasets (e.g. MNIST).
///
/// NOTE(review): `init` does not check that the decoder's output size equals `input_dims`.
/// With the default kernel/stride/padding, each input side presumably has to be divisible
/// by 4 (e.g. 28 for MNIST) for the reconstruction to match the input. Confirm this before
/// using other sizes. `groups` presumably must also divide both the hidden and the output
/// channel counts.
///
/// # See also
/// [`AutoencoderConfig::init`]
#[derive(Config, Debug)]
pub struct AutoencoderConfig {
    pub latent_dim: usize,
    pub input_dims: [usize; 2],
    pub channels: [usize; 3],
    pub groups: usize,
    #[config(default = "0.01")]
    pub leaky_relu_slope: f64,
    #[config(default = "[3, 3]")]
    pub kernel_size: [usize; 2],
    #[config(default = "[2, 2]")]
    pub stride: [usize; 2],
    #[config(default = "[1, 1]")]
    pub padding: [usize; 2],
    #[config(default = "[1, 1]")]
    pub output_padding: [usize; 2],
}

impl AutoencoderConfig {
    /// Number of convolution stages in the encoder (mirrored by the decoder's
    /// transposed-convolution stages).
    const LAYER_LEN: usize = 2;

    /// Initialize an [`Autoencoder`] from this configuration.
    ///
    /// # Note
    /// The flattened feature size feeding the encoder's linear layer is inferred from
    /// `input_dims`, `kernel_size`, `stride`, and `padding`. The convolution output-size
    /// formula is applied once per encoder stage.
    ///
    /// # Panics
    /// Panics if any `stride` component is zero. It also panics if, at some encoder stage,
    /// the kernel is larger than the padded feature map, since no valid convolution
    /// output exists in that case.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Autoencoder<B> {
        let [input_ch, hidden_ch, output_ch] = self.channels;

        // Spatial size of the encoder's final feature map.
        let [mut h, mut w] = self.input_dims;
        for _ in 0..Self::LAYER_LEN {
            h = Self::conv_output_len(h, self.kernel_size[0], self.stride[0], self.padding[0]);
            w = Self::conv_output_len(w, self.kernel_size[1], self.stride[1], self.padding[1]);
        }
        let flat_features = output_ch * h * w;

        Autoencoder {
            encoder: Encoder {
                conv1: Conv2dConfig::new([input_ch, hidden_ch], self.kernel_size)
                    .with_stride(self.stride)
                    .with_padding(PaddingConfig2d::Explicit(self.padding[0], self.padding[1]))
                    .init(device),
                norm1: GroupNormConfig::new(self.groups, hidden_ch).init(device),
                conv2: Conv2dConfig::new([hidden_ch, output_ch], self.kernel_size)
                    .with_stride(self.stride)
                    .with_padding(PaddingConfig2d::Explicit(self.padding[0], self.padding[1]))
                    .init(device),
                norm2: GroupNormConfig::new(self.groups, output_ch).init(device),
                linear: LinearConfig::new(flat_features, self.latent_dim).init(device),
                leaky_relu: LeakyRelu {
                    negative_slope: self.leaky_relu_slope,
                },
            },
            decoder: Decoder {
                linear: LinearConfig::new(self.latent_dim, flat_features).init(device),
                conv_trans1: ConvTranspose2dConfig::new([output_ch, hidden_ch], self.kernel_size)
                    .with_stride(self.stride)
                    .with_padding(self.padding)
                    .with_padding_out(self.output_padding)
                    .init(device),
                norm1: GroupNormConfig::new(self.groups, hidden_ch).init(device),
                conv_trans2: ConvTranspose2dConfig::new([hidden_ch, input_ch], self.kernel_size)
                    .with_stride(self.stride)
                    .with_padding(self.padding)
                    .with_padding_out(self.output_padding)
                    .init(device),
                leaky_relu: LeakyRelu {
                    negative_slope: self.leaky_relu_slope,
                },
                sigmoid: Sigmoid::new(),
            },
        }
    }

    /// Output length of a convolution along one spatial axis:
    /// `(len + 2 * padding - kernel) / stride + 1`.
    ///
    /// The padding is added *before* the kernel is subtracted. An input side smaller than
    /// the kernel but covered by padding therefore does not underflow `usize`.
    ///
    /// # Panics
    /// Panics if `stride` is zero or if `kernel` exceeds `len + 2 * padding`.
    fn conv_output_len(len: usize, kernel: usize, stride: usize, padding: usize) -> usize {
        assert!(stride > 0, "convolution stride must be non-zero");
        let padded = len + 2 * padding;
        let span = padded.checked_sub(kernel).unwrap_or_else(|| {
            panic!(
                "kernel size {kernel} exceeds padded input size {padded} \
                 (input {len}, padding {padding})"
            )
        });
        span / stride + 1
    }
}

impl<B: Backend> Autoencoder<B> {
    /// Forward pass returning reconstructed input and latent embeddings.
    ///
    /// # Data layout
    /// - Input: [batch, channels, height, width]
    /// - Output embeddings: [batch, latent_dim]
    pub fn forward(&self, x: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 2>) {
        // Encoder
        let x = self.encoder.conv1.forward(x);
        let x = self.encoder.norm1.forward(x);
        let x = self.encoder.leaky_relu.forward(x);

        let x = self.encoder.conv2.forward(x);
        let x = self.encoder.norm2.forward(x);
        let x = self.encoder.leaky_relu.forward(x);

        let [batch_size, channels, height, width] = x.dims();
        let x = x.reshape([batch_size, channels * height * width]);
        let embeddings = self.encoder.linear.forward(x);

        // Decoder
        let x = self.decoder.linear.forward(embeddings.clone());
        let x = x.reshape([batch_size, channels, height, width]);

        let x = self.decoder.conv_trans1.forward(x);
        let x = self.decoder.norm1.forward(x);
        let x = self.decoder.leaky_relu.forward(x);

        let x = self.decoder.conv_trans2.forward(x);
        let recon = self.decoder.sigmoid.forward(x);

        (recon, embeddings)
    }

    /// Forward pass returning regression output for autoencoder pretraining.
    ///
    /// # See also
    /// [`RegressionOutput`]
    pub fn forward_regression(&self, x: Tensor<B, 4>) -> RegressionOutput<B> {
        let (recon, _) = self.forward(x.clone());
        let loss = MseLoss::new().forward(recon.clone(), x.clone(), Reduction::Mean);

        RegressionOutput {
            loss,
            output: recon.flatten(1, 3),
            targets: x.flatten(1, 3),
        }
    }
}

impl<B: AutodiffBackend> TrainStep<Batch<B>, RegressionOutput<B>> for Autoencoder<B> {
    /// Training step: runs the reconstruction pass, then backpropagates its MSE loss.
    fn step(&self, batch: Batch<B>) -> TrainOutput<RegressionOutput<B>> {
        let regression = self.forward_regression(batch.images);
        let gradients = regression.loss.backward();
        TrainOutput::new(self, gradients, regression)
    }
}

impl<B: Backend> ValidStep<Batch<B>, RegressionOutput<B>> for Autoencoder<B> {
    /// Validation step: runs the reconstruction pass only; no backward pass happens here.
    fn step(&self, batch: Batch<B>) -> RegressionOutput<B> {
        let images = batch.images;
        self.forward_regression(images)
    }
}