//! A fully connected variational autoencoder (VAE) built on this crate's
//! autograd `Tensor` and `nn` building blocks.

use crate::autograd::Tensor;
use crate::nn::{Linear, Module, ReLU};
/// Output of a full VAE forward pass.
#[derive(Debug)]
pub struct VAEOutput {
    /// Decoder reconstruction of the input.
    pub reconstruction: Tensor,
    /// Mean of the approximate posterior q(z|x).
    pub mu: Tensor,
    /// Log-variance of the approximate posterior q(z|x).
    pub log_var: Tensor,
    /// Latent sample drawn via the reparameterization trick.
    pub z: Tensor,
}
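
/// A β-VAE built from fully connected layers with ReLU activations.
///
/// A minimal usage sketch (marked `ignore` because `Tensor` construction
/// depends on this crate's autograd API):
///
/// ```ignore
/// // `batch` is assumed to be a Tensor of shape [batch_size, 784].
/// let vae = VAE::new(784, vec![512, 256], 32).with_beta(4.0);
/// let output = vae.forward_vae(&batch);
/// let (total, recon, kl) = vae.loss(&output, &batch);
/// ```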
pub struct VAE {
    encoder_layers: Vec<Linear>,
    encoder_activation: ReLU,
    fc_mu: Linear,
    fc_log_var: Linear,
    decoder_layers: Vec<Linear>,
    decoder_activation: ReLU,
    output_layer: Linear,
    input_dim: usize,
    latent_dim: usize,
    hidden_dims: Vec<usize>,
    training: bool,
    /// Weight on the KL term; 1.0 recovers the standard VAE objective.
    beta: f32,
}

impl VAE {
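    /// Builds a VAE with the given input size, encoder hidden sizes, and
    /// latent size; the decoder mirrors `hidden_dims` in reverse. For
    /// example, `VAE::new(784, vec![512, 256], 32)` yields an encoder
    /// 784 -> 512 -> 256 -> (mu, log_var: 32) and a decoder
    /// 32 -> 256 -> 512 -> 784.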
    #[must_use]
    pub fn new(input_dim: usize, hidden_dims: Vec<usize>, latent_dim: usize) -> Self {
        // Encoder: input_dim -> hidden_dims[0] -> ... -> hidden_dims[last].
        let mut encoder_layers = Vec::new();
        let mut prev_dim = input_dim;
        for &hidden_dim in &hidden_dims {
            encoder_layers.push(Linear::new(prev_dim, hidden_dim));
            prev_dim = hidden_dim;
        }
        // With no hidden layers, mu/log_var project straight from the input.
        let last_hidden = *hidden_dims.last().unwrap_or(&input_dim);
        let fc_mu = Linear::new(last_hidden, latent_dim);
        let fc_log_var = Linear::new(last_hidden, latent_dim);
        // Decoder mirrors the encoder: latent_dim -> hidden_dims reversed.
        let mut decoder_layers = Vec::new();
        prev_dim = latent_dim;
        for &hidden_dim in hidden_dims.iter().rev() {
            decoder_layers.push(Linear::new(prev_dim, hidden_dim));
            prev_dim = hidden_dim;
        }
        let output_layer = Linear::new(prev_dim, input_dim);
        Self {
            encoder_layers,
            encoder_activation: ReLU::new(),
            fc_mu,
            fc_log_var,
            decoder_layers,
            decoder_activation: ReLU::new(),
            output_layer,
            input_dim,
            latent_dim,
            hidden_dims,
            training: true,
            beta: 1.0,
        }
    }

    /// Sets the KL weight (builder style). Values above 1.0 trade
    /// reconstruction quality for a more regular latent space, as in β-VAE.
    #[must_use]
    pub fn with_beta(mut self, beta: f32) -> Self {
        self.beta = beta;
        self
    }

    /// Encodes `x` into the mean and log-variance of the approximate
    /// posterior q(z|x).
    #[must_use]
    pub fn encode(&self, x: &Tensor) -> (Tensor, Tensor) {
        let mut h = x.clone();
        for layer in &self.encoder_layers {
            h = layer.forward(&h);
            h = self.encoder_activation.forward(&h);
        }
        let mu = self.fc_mu.forward(&h);
        let log_var = self.fc_log_var.forward(&h);
        (mu, log_var)
    }

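    /// Reparameterization trick: z = mu + sigma * epsilon, where
    /// sigma = exp(0.5 * log_var) and epsilon ~ N(0, I). Sampling through a
    /// deterministic transform keeps z differentiable with respect to mu and
    /// log_var. In eval mode the posterior mean is returned directly.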
    #[must_use]
    pub fn reparameterize(&self, mu: &Tensor, log_var: &Tensor) -> Tensor {
        if !self.training {
            return mu.clone();
        }
        let epsilon = sample_standard_normal(mu.shape());
        // exp(0.5 * log_var) recovers the standard deviation.
        let std = exp_half(log_var);
        // mu + std * epsilon, element-wise.
        add_mul(mu, &std, &epsilon)
    }

    /// Decodes a latent vector `z` back into input space.
    #[must_use]
    pub fn decode(&self, z: &Tensor) -> Tensor {
        let mut h = z.clone();
        for layer in &self.decoder_layers {
            h = layer.forward(&h);
            h = self.decoder_activation.forward(&h);
        }
        self.output_layer.forward(&h)
    }

    /// Runs the full pass: encode, reparameterize, decode.
    #[must_use]
    pub fn forward_vae(&self, x: &Tensor) -> VAEOutput {
        let (mu, log_var) = self.encode(x);
        let z = self.reparameterize(&mu, &log_var);
        let reconstruction = self.decode(&z);
        VAEOutput {
            reconstruction,
            mu,
            log_var,
            z,
        }
    }

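    /// Returns `(total, reconstruction, kl)` losses for the β-VAE objective:
    ///
    ///   total = MSE(reconstruction, target) + beta * KL(q(z|x) || N(0, I))
    ///
    /// where `kl_divergence_loss` presumably computes the closed form for a
    /// diagonal Gaussian posterior,
    /// -0.5 * sum(1 + log_var - mu^2 - exp(log_var)).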
    #[must_use]
    pub fn loss(&self, output: &VAEOutput, target: &Tensor) -> (f32, f32, f32) {
        let recon_loss = mse_loss(&output.reconstruction, target);
        let kl_loss = kl_divergence_loss(&output.mu, &output.log_var);
        let total_loss = recon_loss + self.beta * kl_loss;
        (total_loss, recon_loss, kl_loss)
    }

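    /// Draws `num_samples` latent vectors from the prior N(0, I) and decodes
    /// them into input space:
    ///
    /// ```ignore
    /// let generated = vae.sample(16); // shape [16, input_dim]
    /// ```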
    #[must_use]
    pub fn sample(&self, num_samples: usize) -> Tensor {
        let z = sample_standard_normal(&[num_samples, self.latent_dim]);
        self.decode(&z)
    }

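    /// Decodes `steps` points along the straight line between the posterior
    /// means of `x1` and `x2`, sweeping `alpha` from 0 to 1 (endpoints
    /// included when `steps >= 2`).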
    #[must_use]
    pub fn interpolate(&self, x1: &Tensor, x2: &Tensor, steps: usize) -> Vec<Tensor> {
        let (mu1, _) = self.encode(x1);
        let (mu2, _) = self.encode(x2);
        // Guard the denominator so `steps == 1` yields alpha = 0 rather than
        // the NaN produced by 0.0 / 0.0.
        let denom = steps.saturating_sub(1).max(1) as f32;
        let mut results = Vec::with_capacity(steps);
        for i in 0..steps {
            let alpha = i as f32 / denom;
            let z = lerp(&mu1, &mu2, alpha);
            results.push(self.decode(&z));
        }
        results
    }

    /// Dimensionality of the latent space.
    #[must_use]
    pub fn latent_dim(&self) -> usize {
        self.latent_dim
    }

    /// Dimensionality of the (flattened) input.
    #[must_use]
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Current weight on the KL term.
    #[must_use]
    pub fn beta(&self) -> f32 {
        self.beta
    }
}

impl Module for VAE {
    fn forward(&self, input: &Tensor) -> Tensor {
        let output = self.forward_vae(input);
        output.reconstruction
    }

    fn parameters(&self) -> Vec<&Tensor> {
        let mut params = Vec::new();
        for layer in &self.encoder_layers {
            params.extend(layer.parameters());
        }
        params.extend(self.fc_mu.parameters());
        params.extend(self.fc_log_var.parameters());
        for layer in &self.decoder_layers {
            params.extend(layer.parameters());
        }
        params.extend(self.output_layer.parameters());
        params
    }

    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        for layer in &mut self.encoder_layers {
            params.extend(layer.parameters_mut());
        }
        params.extend(self.fc_mu.parameters_mut());
        params.extend(self.fc_log_var.parameters_mut());
        for layer in &mut self.decoder_layers {
            params.extend(layer.parameters_mut());
        }
        params.extend(self.output_layer.parameters_mut());
        params
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn training(&self) -> bool {
        self.training
    }
}

impl std::fmt::Debug for VAE {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VAE")
            .field("input_dim", &self.input_dim)
            .field("hidden_dims", &self.hidden_dims)
            .field("latent_dim", &self.latent_dim)
            .field("beta", &self.beta)
            .finish_non_exhaustive()
    }
}
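
/// A class-conditional VAE. The `num_classes` field suggests the encoder and
/// decoder inputs are augmented with a label encoding; the constructor and
/// methods for this type are defined in `vae_conditional.rs`.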
pub struct ConditionalVAE {
    encoder_layers: Vec<Linear>,
    encoder_activation: ReLU,
    fc_mu: Linear,
    fc_log_var: Linear,
    decoder_layers: Vec<Linear>,
    decoder_activation: ReLU,
    output_layer: Linear,
    input_dim: usize,
    latent_dim: usize,
    num_classes: usize,
    hidden_dims: Vec<usize>,
    training: bool,
}

#[path = "vae_conditional.rs"]
mod vae_conditional;
#[allow(clippy::wildcard_imports)]
use vae_conditional::*;
#[path = "vae_tests.rs"]
mod vae_tests;