use ghostflow_core::Tensor;
use crate::deep::layers::{Conv2d, Dense, BatchNorm2d};
use crate::deep::activations::{ReLU, LeakyReLU, Tanh, Sigmoid};
pub struct DCGANGenerator {
fc: Dense,
deconv1: Conv2d,
bn1: BatchNorm2d,
deconv2: Conv2d,
bn2: BatchNorm2d,
deconv3: Conv2d,
bn3: BatchNorm2d,
deconv4: Conv2d,
latent_dim: usize,
}
impl DCGANGenerator {
pub fn new(latent_dim: usize, img_channels: usize) -> Self {
DCGANGenerator {
fc: Dense::new(latent_dim, 512 * 4 * 4),
deconv1: Conv2d::new(512, 256, (4, 4)).stride((2, 2)).padding((1, 1)),
bn1: BatchNorm2d::new(256),
deconv2: Conv2d::new(256, 128, (4, 4)).stride((2, 2)).padding((1, 1)),
bn2: BatchNorm2d::new(128),
deconv3: Conv2d::new(128, 64, (4, 4)).stride((2, 2)).padding((1, 1)),
bn3: BatchNorm2d::new(64),
deconv4: Conv2d::new(64, img_channels, (4, 4)).stride((2, 2)).padding((1, 1)),
latent_dim,
}
}
pub fn forward(&mut self, z: &Tensor, training: bool) -> Tensor {
let mut out = self.fc.forward(z, training);
let batch_size = z.dims()[0];
out = Tensor::from_slice(out.data_f32(), &[batch_size, 512, 4, 4]).unwrap();
out = self.deconv1.forward(&out, training);
out = self.bn1.forward(&out, training);
out = ReLU::new().forward(&out);
out = self.deconv2.forward(&out, training);
out = self.bn2.forward(&out, training);
out = ReLU::new().forward(&out);
out = self.deconv3.forward(&out, training);
out = self.bn3.forward(&out, training);
out = ReLU::new().forward(&out);
out = self.deconv4.forward(&out, training);
Tanh::new().forward(&out)
}
}
pub struct DCGANDiscriminator {
conv1: Conv2d,
conv2: Conv2d,
bn2: BatchNorm2d,
conv3: Conv2d,
bn3: BatchNorm2d,
conv4: Conv2d,
bn4: BatchNorm2d,
fc: Dense,
}
impl DCGANDiscriminator {
pub fn new(img_channels: usize) -> Self {
DCGANDiscriminator {
conv1: Conv2d::new(img_channels, 64, (4, 4)).stride((2, 2)).padding((1, 1)),
conv2: Conv2d::new(64, 128, (4, 4)).stride((2, 2)).padding((1, 1)),
bn2: BatchNorm2d::new(128),
conv3: Conv2d::new(128, 256, (4, 4)).stride((2, 2)).padding((1, 1)),
bn3: BatchNorm2d::new(256),
conv4: Conv2d::new(256, 512, (4, 4)).stride((2, 2)).padding((1, 1)),
bn4: BatchNorm2d::new(512),
fc: Dense::new(512 * 4 * 4, 1),
}
}
pub fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.conv1.forward(x, training);
out = LeakyReLU::new(0.2).forward(&out);
out = self.conv2.forward(&out, training);
out = self.bn2.forward(&out, training);
out = LeakyReLU::new(0.2).forward(&out);
out = self.conv3.forward(&out, training);
out = self.bn3.forward(&out, training);
out = LeakyReLU::new(0.2).forward(&out);
out = self.conv4.forward(&out, training);
out = self.bn4.forward(&out, training);
out = LeakyReLU::new(0.2).forward(&out);
let batch_size = out.dims()[0];
out = Tensor::from_slice(out.data_f32(), &[batch_size, 512 * 4 * 4]).unwrap();
out = self.fc.forward(&out, training);
Sigmoid::new().forward(&out)
}
}
pub struct StyleGANMappingNetwork {
layers: Vec<Dense>,
latent_dim: usize,
style_dim: usize,
}
impl StyleGANMappingNetwork {
pub fn new(latent_dim: usize, style_dim: usize, num_layers: usize) -> Self {
let mut layers = Vec::new();
layers.push(Dense::new(latent_dim, style_dim));
for _ in 1..num_layers {
layers.push(Dense::new(style_dim, style_dim));
}
StyleGANMappingNetwork {
layers,
latent_dim,
style_dim,
}
}
pub fn forward(&mut self, z: &Tensor, training: bool) -> Tensor {
let mut out = z.clone();
for layer in &mut self.layers {
out = layer.forward(&out, training);
out = LeakyReLU::new(0.2).forward(&out);
}
out
}
}
pub struct AdaIN {
style_scale: Dense,
style_bias: Dense,
}
impl AdaIN {
pub fn new(style_dim: usize, num_features: usize) -> Self {
AdaIN {
style_scale: Dense::new(style_dim, num_features),
style_bias: Dense::new(style_dim, num_features),
}
}
pub fn forward(&mut self, x: &Tensor, style: &Tensor, training: bool) -> Tensor {
let (mean, std) = self.compute_instance_stats(x);
let normalized = self.normalize(x, &mean, &std);
let scale = self.style_scale.forward(style, training);
let bias = self.style_bias.forward(style, training);
self.apply_style(&normalized, &scale, &bias)
}
fn compute_instance_stats(&self, x: &Tensor) -> (Tensor, Tensor) {
let dims = x.dims();
let batch_size = dims[0];
let channels = dims[1];
let spatial_size = dims[2] * dims[3];
let data = x.data_f32();
let mut means = vec![0.0f32; batch_size * channels];
let mut stds = vec![0.0f32; batch_size * channels];
for b in 0..batch_size {
for c in 0..channels {
let mut sum = 0.0f32;
for s in 0..spatial_size {
let idx = (b * channels + c) * spatial_size + s;
sum += data[idx];
}
let mean = sum / spatial_size as f32;
means[b * channels + c] = mean;
let mut var_sum = 0.0f32;
for s in 0..spatial_size {
let idx = (b * channels + c) * spatial_size + s;
var_sum += (data[idx] - mean).powi(2);
}
stds[b * channels + c] = (var_sum / spatial_size as f32).sqrt();
}
}
(
Tensor::from_slice(&means, &[batch_size, channels]).unwrap(),
Tensor::from_slice(&stds, &[batch_size, channels]).unwrap(),
)
}
fn normalize(&self, x: &Tensor, mean: &Tensor, std: &Tensor) -> Tensor {
let dims = x.dims();
let batch_size = dims[0];
let channels = dims[1];
let spatial_size = dims[2] * dims[3];
let data = x.data_f32();
let mean_data = mean.data_f32();
let std_data = std.data_f32();
let mut result = vec![0.0f32; data.len()];
for b in 0..batch_size {
for c in 0..channels {
let m = mean_data[b * channels + c];
let s = std_data[b * channels + c].max(1e-8);
for s_idx in 0..spatial_size {
let idx = (b * channels + c) * spatial_size + s_idx;
result[idx] = (data[idx] - m) / s;
}
}
}
Tensor::from_slice(&result, dims).unwrap()
}
fn apply_style(&self, x: &Tensor, scale: &Tensor, bias: &Tensor) -> Tensor {
let dims = x.dims();
let batch_size = dims[0];
let channels = dims[1];
let spatial_size = dims[2] * dims[3];
let data = x.data_f32();
let scale_data = scale.data_f32();
let bias_data = bias.data_f32();
let mut result = vec![0.0f32; data.len()];
for b in 0..batch_size {
for c in 0..channels {
let s = scale_data[b * channels + c];
let b_val = bias_data[b * channels + c];
for s_idx in 0..spatial_size {
let idx = (b * channels + c) * spatial_size + s_idx;
result[idx] = data[idx] * s + b_val;
}
}
}
Tensor::from_slice(&result, dims).unwrap()
}
}
pub struct StyleGANSynthesisBlock {
conv1: Conv2d,
adain1: AdaIN,
conv2: Conv2d,
adain2: AdaIN,
}
impl StyleGANSynthesisBlock {
pub fn new(in_channels: usize, out_channels: usize, style_dim: usize) -> Self {
StyleGANSynthesisBlock {
conv1: Conv2d::new(in_channels, out_channels, (3, 3)).padding((1, 1)),
adain1: AdaIN::new(style_dim, out_channels),
conv2: Conv2d::new(out_channels, out_channels, (3, 3)).padding((1, 1)),
adain2: AdaIN::new(style_dim, out_channels),
}
}
pub fn forward(&mut self, x: &Tensor, style: &Tensor, training: bool) -> Tensor {
let mut out = self.conv1.forward(x, training);
out = self.adain1.forward(&out, style, training);
out = LeakyReLU::new(0.2).forward(&out);
out = self.conv2.forward(&out, training);
out = self.adain2.forward(&out, style, training);
LeakyReLU::new(0.2).forward(&out)
}
}
pub struct StyleGANGenerator {
mapping: StyleGANMappingNetwork,
constant_input: Vec<f32>,
synthesis_blocks: Vec<StyleGANSynthesisBlock>,
to_rgb: Conv2d,
}
impl StyleGANGenerator {
pub fn new(latent_dim: usize, style_dim: usize, img_channels: usize) -> Self {
use rand::prelude::*;
let mut rng = thread_rng();
let constant_input: Vec<f32> = (0..512 * 4 * 4)
.map(|_| rng.gen::<f32>() * 0.02 - 0.01)
.collect();
StyleGANGenerator {
mapping: StyleGANMappingNetwork::new(latent_dim, style_dim, 8),
constant_input,
synthesis_blocks: vec![
StyleGANSynthesisBlock::new(512, 512, style_dim),
StyleGANSynthesisBlock::new(512, 256, style_dim),
StyleGANSynthesisBlock::new(256, 128, style_dim),
StyleGANSynthesisBlock::new(128, 64, style_dim),
],
to_rgb: Conv2d::new(64, img_channels, (1, 1)),
}
}
pub fn forward(&mut self, z: &Tensor, training: bool) -> Tensor {
let style = self.mapping.forward(z, training);
let batch_size = z.dims()[0];
let mut out = Tensor::from_slice(&self.constant_input, &[1, 512, 4, 4]).unwrap();
let mut batch_data = Vec::new();
for _ in 0..batch_size {
batch_data.extend_from_slice(out.data_f32());
}
out = Tensor::from_slice(&batch_data, &[batch_size, 512, 4, 4]).unwrap();
for block in &mut self.synthesis_blocks {
out = block.forward(&out, &style, training);
out = self.upsample(&out);
}
out = self.to_rgb.forward(&out, training);
Tanh::new().forward(&out)
}
fn upsample(&self, x: &Tensor) -> Tensor {
let dims = x.dims();
let batch = dims[0];
let channels = dims[1];
let height = dims[2];
let width = dims[3];
let data = x.data_f32();
let new_height = height * 2;
let new_width = width * 2;
let mut result = vec![0.0f32; batch * channels * new_height * new_width];
for b in 0..batch {
for c in 0..channels {
for h in 0..height {
for w in 0..width {
let val = data[((b * channels + c) * height + h) * width + w];
for dh in 0..2 {
for dw in 0..2 {
let new_h = h * 2 + dh;
let new_w = w * 2 + dw;
let idx = ((b * channels + c) * new_height + new_h) * new_width + new_w;
result[idx] = val;
}
}
}
}
}
}
Tensor::from_slice(&result, &[batch, channels, new_height, new_width]).unwrap()
}
}
pub struct WGANCritic {
conv1: Conv2d,
conv2: Conv2d,
conv3: Conv2d,
conv4: Conv2d,
fc: Dense,
}
impl WGANCritic {
pub fn new(img_channels: usize) -> Self {
WGANCritic {
conv1: Conv2d::new(img_channels, 64, (4, 4)).stride((2, 2)).padding((1, 1)),
conv2: Conv2d::new(64, 128, (4, 4)).stride((2, 2)).padding((1, 1)),
conv3: Conv2d::new(128, 256, (4, 4)).stride((2, 2)).padding((1, 1)),
conv4: Conv2d::new(256, 512, (4, 4)).stride((2, 2)).padding((1, 1)),
fc: Dense::new(512 * 4 * 4, 1),
}
}
pub fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.conv1.forward(x, training);
out = LeakyReLU::new(0.2).forward(&out);
out = self.conv2.forward(&out, training);
out = LeakyReLU::new(0.2).forward(&out);
out = self.conv3.forward(&out, training);
out = LeakyReLU::new(0.2).forward(&out);
out = self.conv4.forward(&out, training);
out = LeakyReLU::new(0.2).forward(&out);
let batch_size = out.dims()[0];
out = Tensor::from_slice(out.data_f32(), &[batch_size, 512 * 4 * 4]).unwrap();
self.fc.forward(&out, training)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dcgan_generator() {
let mut gen = DCGANGenerator::new(100, 3);
let z = Tensor::from_slice(&vec![0.5f32; 100], &[1, 100]).unwrap();
let output = gen.forward(&z, false);
assert_eq!(output.dims()[1], 3);
}
#[test]
fn test_dcgan_discriminator() {
let mut disc = DCGANDiscriminator::new(3);
let img = Tensor::from_slice(&vec![0.5f32; 3 * 64 * 64], &[1, 3, 64, 64]).unwrap();
let output = disc.forward(&img, false);
assert_eq!(output.dims()[1], 1);
}
#[test]
fn test_stylegan_generator() {
let mut gen = StyleGANGenerator::new(512, 512, 3);
let z = Tensor::from_slice(&vec![0.5f32; 512], &[1, 512]).unwrap();
let output = gen.forward(&z, false);
assert_eq!(output.dims()[1], 3);
}
#[test]
fn test_wgan_critic() {
let mut critic = WGANCritic::new(3);
let img = Tensor::from_slice(&vec![0.5f32; 3 * 64 * 64], &[1, 3, 64, 64]).unwrap();
let output = critic.forward(&img, false);
assert_eq!(output.dims()[1], 1);
}
}
pub struct CycleGANGenerator {
initial: Vec<(Conv2d, BatchNorm2d)>,
downsampling: Vec<(Conv2d, BatchNorm2d)>,
residual_blocks: Vec<ResNetBlock>,
upsampling: Vec<(Conv2d, BatchNorm2d)>,
output_conv: Conv2d,
}
struct ResNetBlock {
conv1: Conv2d,
bn1: BatchNorm2d,
conv2: Conv2d,
bn2: BatchNorm2d,
}
impl ResNetBlock {
fn new(channels: usize) -> Self {
ResNetBlock {
conv1: Conv2d::new(channels, channels, (3, 3)).padding((1, 1)),
bn1: BatchNorm2d::new(channels),
conv2: Conv2d::new(channels, channels, (3, 3)).padding((1, 1)),
bn2: BatchNorm2d::new(channels),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let identity = x.clone();
let mut out = self.conv1.forward(x, training);
out = self.bn1.forward(&out, training);
out = ReLU::new().forward(&out);
out = self.conv2.forward(&out, training);
out = self.bn2.forward(&out, training);
let out_data = out.data_f32();
let id_data = identity.data_f32();
let result: Vec<f32> = out_data.iter()
.zip(id_data.iter())
.map(|(&o, &i)| o + i)
.collect();
Tensor::from_slice(&result, out.dims()).unwrap()
}
}
impl CycleGANGenerator {
pub fn new(img_channels: usize, num_residual_blocks: usize) -> Self {
let ngf = 64;
CycleGANGenerator {
initial: vec![
(Conv2d::new(img_channels, ngf, (7, 7)).padding((3, 3)), BatchNorm2d::new(ngf)),
],
downsampling: vec![
(Conv2d::new(ngf, ngf * 2, (3, 3)).stride((2, 2)).padding((1, 1)), BatchNorm2d::new(ngf * 2)),
(Conv2d::new(ngf * 2, ngf * 4, (3, 3)).stride((2, 2)).padding((1, 1)), BatchNorm2d::new(ngf * 4)),
],
residual_blocks: (0..num_residual_blocks).map(|_| ResNetBlock::new(ngf * 4)).collect(),
upsampling: vec![
(Conv2d::new(ngf * 4, ngf * 2, (3, 3)).padding((1, 1)), BatchNorm2d::new(ngf * 2)),
(Conv2d::new(ngf * 2, ngf, (3, 3)).padding((1, 1)), BatchNorm2d::new(ngf)),
],
output_conv: Conv2d::new(ngf, img_channels, (7, 7)).padding((3, 3)),
}
}
pub fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = x.clone();
for (conv, bn) in &mut self.initial {
out = conv.forward(&out, training);
out = bn.forward(&out, training);
out = ReLU::new().forward(&out);
}
for (conv, bn) in &mut self.downsampling {
out = conv.forward(&out, training);
out = bn.forward(&out, training);
out = ReLU::new().forward(&out);
}
for block in &mut self.residual_blocks {
out = block.forward(&out, training);
}
for (conv, bn) in &mut self.upsampling {
out = self.upsample(&out);
out = conv.forward(&out, training);
out = bn.forward(&out, training);
out = ReLU::new().forward(&out);
}
out = self.output_conv.forward(&out, training);
Tanh::new().forward(&out)
}
fn upsample(&self, x: &Tensor) -> Tensor {
let dims = x.dims();
let batch = dims[0];
let channels = dims[1];
let height = dims[2];
let width = dims[3];
let data = x.data_f32();
let new_height = height * 2;
let new_width = width * 2;
let mut result = vec![0.0f32; batch * channels * new_height * new_width];
for b in 0..batch {
for c in 0..channels {
for h in 0..height {
for w in 0..width {
let val = data[((b * channels + c) * height + h) * width + w];
for dh in 0..2 {
for dw in 0..2 {
let new_h = h * 2 + dh;
let new_w = w * 2 + dw;
let idx = ((b * channels + c) * new_height + new_h) * new_width + new_w;
result[idx] = val;
}
}
}
}
}
}
Tensor::from_slice(&result, &[batch, channels, new_height, new_width]).unwrap()
}
}
pub struct CycleGANDiscriminator {
layers: Vec<(Conv2d, Option<BatchNorm2d>)>,
}
impl CycleGANDiscriminator {
pub fn new(img_channels: usize) -> Self {
let ndf = 64;
CycleGANDiscriminator {
layers: vec![
(Conv2d::new(img_channels, ndf, (4, 4)).stride((2, 2)).padding((1, 1)), None),
(Conv2d::new(ndf, ndf * 2, (4, 4)).stride((2, 2)).padding((1, 1)), Some(BatchNorm2d::new(ndf * 2))),
(Conv2d::new(ndf * 2, ndf * 4, (4, 4)).stride((2, 2)).padding((1, 1)), Some(BatchNorm2d::new(ndf * 4))),
(Conv2d::new(ndf * 4, ndf * 8, (4, 4)).stride((1, 1)).padding((1, 1)), Some(BatchNorm2d::new(ndf * 8))),
(Conv2d::new(ndf * 8, 1, (4, 4)).stride((1, 1)).padding((1, 1)), None),
],
}
}
pub fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = x.clone();
for (i, (conv, bn_opt)) in self.layers.iter_mut().enumerate() {
out = conv.forward(&out, training);
if let Some(bn) = bn_opt {
out = bn.forward(&out, training);
}
if i < self.layers.len() - 1 {
out = LeakyReLU::new(0.2).forward(&out);
}
}
out
}
}
pub struct Pix2PixGenerator {
encoder: Vec<(Conv2d, BatchNorm2d)>,
decoder: Vec<(Conv2d, BatchNorm2d)>,
final_conv: Conv2d,
}
impl Pix2PixGenerator {
pub fn new(in_channels: usize, out_channels: usize) -> Self {
let ngf = 64;
Pix2PixGenerator {
encoder: vec![
(Conv2d::new(in_channels, ngf, (4, 4)).stride((2, 2)).padding((1, 1)), BatchNorm2d::new(ngf)),
(Conv2d::new(ngf, ngf * 2, (4, 4)).stride((2, 2)).padding((1, 1)), BatchNorm2d::new(ngf * 2)),
(Conv2d::new(ngf * 2, ngf * 4, (4, 4)).stride((2, 2)).padding((1, 1)), BatchNorm2d::new(ngf * 4)),
(Conv2d::new(ngf * 4, ngf * 8, (4, 4)).stride((2, 2)).padding((1, 1)), BatchNorm2d::new(ngf * 8)),
(Conv2d::new(ngf * 8, ngf * 8, (4, 4)).stride((2, 2)).padding((1, 1)), BatchNorm2d::new(ngf * 8)),
],
decoder: vec![
(Conv2d::new(ngf * 8, ngf * 8, (4, 4)).padding((1, 1)), BatchNorm2d::new(ngf * 8)),
(Conv2d::new(ngf * 16, ngf * 4, (4, 4)).padding((1, 1)), BatchNorm2d::new(ngf * 4)),
(Conv2d::new(ngf * 8, ngf * 2, (4, 4)).padding((1, 1)), BatchNorm2d::new(ngf * 2)),
(Conv2d::new(ngf * 4, ngf, (4, 4)).padding((1, 1)), BatchNorm2d::new(ngf)),
],
final_conv: Conv2d::new(ngf * 2, out_channels, (4, 4)).padding((1, 1)),
}
}
pub fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut encoder_outputs = Vec::new();
let mut out = x.clone();
for (i, (conv, bn)) in self.encoder.iter_mut().enumerate() {
out = conv.forward(&out, training);
if i > 0 {
out = bn.forward(&out, training);
}
out = LeakyReLU::new(0.2).forward(&out);
encoder_outputs.push(out.clone());
}
for (i, (conv, bn)) in self.decoder.iter_mut().enumerate() {
out = self.upsample(&out);
if i < encoder_outputs.len() - 1 {
let skip_idx = encoder_outputs.len() - 2 - i;
out = self.concatenate(&out, &encoder_outputs[skip_idx]);
}
out = conv.forward(&out, training);
out = bn.forward(&out, training);
out = ReLU::new().forward(&out);
}
out = self.upsample(&out);
out = self.concatenate(&out, &encoder_outputs[0]);
out = self.final_conv.forward(&out, training);
Tanh::new().forward(&out)
}
fn upsample(&self, x: &Tensor) -> Tensor {
let dims = x.dims();
let batch = dims[0];
let channels = dims[1];
let height = dims[2];
let width = dims[3];
let data = x.data_f32();
let new_height = height * 2;
let new_width = width * 2;
let mut result = vec![0.0f32; batch * channels * new_height * new_width];
for b in 0..batch {
for c in 0..channels {
for h in 0..height {
for w in 0..width {
let val = data[((b * channels + c) * height + h) * width + w];
for dh in 0..2 {
for dw in 0..2 {
let new_h = h * 2 + dh;
let new_w = w * 2 + dw;
let idx = ((b * channels + c) * new_height + new_h) * new_width + new_w;
result[idx] = val;
}
}
}
}
}
}
Tensor::from_slice(&result, &[batch, channels, new_height, new_width]).unwrap()
}
fn concatenate(&self, x1: &Tensor, x2: &Tensor) -> Tensor {
let dims1 = x1.dims();
let dims2 = x2.dims();
let batch = dims1[0];
let channels1 = dims1[1];
let channels2 = dims2[1];
let height = dims1[2];
let width = dims1[3];
let total_channels = channels1 + channels2;
let mut result = Vec::new();
for b in 0..batch {
for c in 0..channels1 {
for h in 0..height {
for w in 0..width {
let idx = ((b * channels1 + c) * height + h) * width + w;
result.push(x1.data_f32()[idx]);
}
}
}
for c in 0..channels2 {
for h in 0..height {
for w in 0..width {
let idx = ((b * channels2 + c) * height + h) * width + w;
result.push(x2.data_f32()[idx]);
}
}
}
}
Tensor::from_slice(&result, &[batch, total_channels, height, width]).unwrap()
}
}
pub struct ProGANGeneratorBlock {
conv1: Conv2d,
conv2: Conv2d,
bn1: BatchNorm2d,
bn2: BatchNorm2d,
}
impl ProGANGeneratorBlock {
pub fn new(in_channels: usize, out_channels: usize) -> Self {
ProGANGeneratorBlock {
conv1: Conv2d::new(in_channels, out_channels, (3, 3)).padding((1, 1)),
bn1: BatchNorm2d::new(out_channels),
conv2: Conv2d::new(out_channels, out_channels, (3, 3)).padding((1, 1)),
bn2: BatchNorm2d::new(out_channels),
}
}
pub fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.conv1.forward(x, training);
out = self.bn1.forward(&out, training);
out = LeakyReLU::new(0.2).forward(&out);
out = self.conv2.forward(&out, training);
out = self.bn2.forward(&out, training);
LeakyReLU::new(0.2).forward(&out)
}
}
pub struct ConditionalGANGenerator {
label_embedding: Dense,
fc: Dense,
deconv_blocks: Vec<(Conv2d, BatchNorm2d)>,
final_conv: Conv2d,
latent_dim: usize,
num_classes: usize,
}
impl ConditionalGANGenerator {
pub fn new(latent_dim: usize, num_classes: usize, img_channels: usize) -> Self {
let embedding_dim = 50;
ConditionalGANGenerator {
label_embedding: Dense::new(num_classes, embedding_dim),
fc: Dense::new(latent_dim + embedding_dim, 256 * 7 * 7),
deconv_blocks: vec![
(Conv2d::new(256, 128, (4, 4)).stride((2, 2)).padding((1, 1)), BatchNorm2d::new(128)),
(Conv2d::new(128, 64, (4, 4)).stride((2, 2)).padding((1, 1)), BatchNorm2d::new(64)),
],
final_conv: Conv2d::new(64, img_channels, (3, 3)).padding((1, 1)),
latent_dim,
num_classes,
}
}
pub fn forward(&mut self, z: &Tensor, labels: &Tensor, training: bool) -> Tensor {
let label_embed = self.label_embedding.forward(labels, training);
let combined = self.concatenate_vectors(z, &label_embed);
let mut out = self.fc.forward(&combined, training);
let batch_size = z.dims()[0];
out = Tensor::from_slice(out.data_f32(), &[batch_size, 256, 7, 7]).unwrap();
for (conv, bn) in &mut self.deconv_blocks {
out = conv.forward(&out, training);
out = bn.forward(&out, training);
out = ReLU::new().forward(&out);
}
out = self.final_conv.forward(&out, training);
Tanh::new().forward(&out)
}
fn concatenate_vectors(&self, x1: &Tensor, x2: &Tensor) -> Tensor {
let data1 = x1.data_f32();
let data2 = x2.data_f32();
let batch_size = x1.dims()[0];
let dim1 = x1.dims()[1];
let dim2 = x2.dims()[1];
let mut result = Vec::new();
for b in 0..batch_size {
for i in 0..dim1 {
result.push(data1[b * dim1 + i]);
}
for i in 0..dim2 {
result.push(data2[b * dim2 + i]);
}
}
Tensor::from_slice(&result, &[batch_size, dim1 + dim2]).unwrap()
}
}
pub struct BigGANGenerator {
linear: Dense,
blocks: Vec<BigGANBlock>,
output_conv: Conv2d,
latent_dim: usize,
}
struct BigGANBlock {
bn1: BatchNorm2d,
conv1: Conv2d,
bn2: BatchNorm2d,
conv2: Conv2d,
class_embed: Dense,
}
impl BigGANBlock {
fn new(in_channels: usize, out_channels: usize, num_classes: usize) -> Self {
BigGANBlock {
bn1: BatchNorm2d::new(in_channels),
conv1: Conv2d::new(in_channels, out_channels, (3, 3)).padding((1, 1)),
bn2: BatchNorm2d::new(out_channels),
conv2: Conv2d::new(out_channels, out_channels, (3, 3)).padding((1, 1)),
class_embed: Dense::new(num_classes, in_channels * 2),
}
}
fn forward(&mut self, x: &Tensor, class_label: &Tensor, training: bool) -> Tensor {
let class_emb = self.class_embed.forward(class_label, training);
let mut out = self.bn1.forward(x, training);
out = self.apply_class_conditioning(&out, &class_emb);
out = ReLU::new().forward(&out);
out = self.conv1.forward(&out, training);
out = self.bn2.forward(&out, training);
out = ReLU::new().forward(&out);
out = self.conv2.forward(&out, training);
self.upsample(&out)
}
fn apply_class_conditioning(&self, x: &Tensor, class_emb: &Tensor) -> Tensor {
x.clone() }
fn upsample(&self, x: &Tensor) -> Tensor {
let dims = x.dims();
let batch = dims[0];
let channels = dims[1];
let height = dims[2];
let width = dims[3];
let data = x.data_f32();
let new_height = height * 2;
let new_width = width * 2;
let mut result = vec![0.0f32; batch * channels * new_height * new_width];
for b in 0..batch {
for c in 0..channels {
for h in 0..height {
for w in 0..width {
let val = data[((b * channels + c) * height + h) * width + w];
for dh in 0..2 {
for dw in 0..2 {
let new_h = h * 2 + dh;
let new_w = w * 2 + dw;
let idx = ((b * channels + c) * new_height + new_h) * new_width + new_w;
result[idx] = val;
}
}
}
}
}
}
Tensor::from_slice(&result, &[batch, channels, new_height, new_width]).unwrap()
}
}
impl BigGANGenerator {
pub fn new(latent_dim: usize, num_classes: usize, img_channels: usize) -> Self {
BigGANGenerator {
linear: Dense::new(latent_dim, 128 * 4 * 4),
blocks: vec![
BigGANBlock::new(128, 128, num_classes),
BigGANBlock::new(128, 64, num_classes),
BigGANBlock::new(64, 32, num_classes),
],
output_conv: Conv2d::new(32, img_channels, (3, 3)).padding((1, 1)),
latent_dim,
}
}
pub fn forward(&mut self, z: &Tensor, class_label: &Tensor, training: bool) -> Tensor {
let mut out = self.linear.forward(z, training);
let batch_size = z.dims()[0];
out = Tensor::from_slice(out.data_f32(), &[batch_size, 128, 4, 4]).unwrap();
for block in &mut self.blocks {
out = block.forward(&out, class_label, training);
}
out = self.output_conv.forward(&out, training);
Tanh::new().forward(&out)
}
}
pub struct StyleGAN2Generator {
mapping: StyleGAN2Mapping,
synthesis: StyleGAN2Synthesis,
}
struct StyleGAN2Mapping {
layers: Vec<Dense>,
}
impl StyleGAN2Mapping {
fn new(latent_dim: usize, style_dim: usize) -> Self {
StyleGAN2Mapping {
layers: vec![
Dense::new(latent_dim, style_dim),
Dense::new(style_dim, style_dim),
Dense::new(style_dim, style_dim),
],
}
}
fn forward(&mut self, z: &Tensor, training: bool) -> Tensor {
let mut out = z.clone();
for layer in &mut self.layers {
out = layer.forward(&out, training);
out = LeakyReLU::new(0.2).forward(&out);
}
out
}
}
struct StyleGAN2Synthesis {
const_input: Vec<f32>,
blocks: Vec<StyleGAN2Block>,
to_rgb: Conv2d,
}
struct StyleGAN2Block {
conv1: Conv2d,
conv2: Conv2d,
style_mod1: Dense,
style_mod2: Dense,
noise_weight1: f32,
noise_weight2: f32,
}
impl StyleGAN2Block {
fn new(in_channels: usize, out_channels: usize, style_dim: usize) -> Self {
StyleGAN2Block {
conv1: Conv2d::new(in_channels, out_channels, (3, 3)).padding((1, 1)),
conv2: Conv2d::new(out_channels, out_channels, (3, 3)).padding((1, 1)),
style_mod1: Dense::new(style_dim, in_channels * 2),
style_mod2: Dense::new(style_dim, out_channels * 2),
noise_weight1: 0.0,
noise_weight2: 0.0,
}
}
fn forward(&mut self, x: &Tensor, style: &Tensor, training: bool) -> Tensor {
let style_mod1 = self.style_mod1.forward(style, training);
let mut out = self.modulate(x, &style_mod1);
out = self.conv1.forward(&out, training);
out = LeakyReLU::new(0.2).forward(&out);
let style_mod2 = self.style_mod2.forward(style, training);
out = self.modulate(&out, &style_mod2);
out = self.conv2.forward(&out, training);
LeakyReLU::new(0.2).forward(&out)
}
fn modulate(&self, x: &Tensor, style: &Tensor) -> Tensor {
x.clone() }
}
impl StyleGAN2Synthesis {
fn new(style_dim: usize, img_channels: usize) -> Self {
use rand::prelude::*;
let mut rng = thread_rng();
let const_input: Vec<f32> = (0..512 * 4 * 4)
.map(|_| rng.gen::<f32>() * 0.02 - 0.01)
.collect();
StyleGAN2Synthesis {
const_input,
blocks: vec![
StyleGAN2Block::new(512, 512, style_dim),
StyleGAN2Block::new(512, 256, style_dim),
StyleGAN2Block::new(256, 128, style_dim),
],
to_rgb: Conv2d::new(128, img_channels, (1, 1)),
}
}
fn forward(&mut self, style: &Tensor, training: bool) -> Tensor {
let batch_size = style.dims()[0];
let mut out = Tensor::from_slice(&self.const_input, &[1, 512, 4, 4]).unwrap();
let mut batch_data = Vec::new();
for _ in 0..batch_size {
batch_data.extend_from_slice(out.data_f32());
}
out = Tensor::from_slice(&batch_data, &[batch_size, 512, 4, 4]).unwrap();
for block in &mut self.blocks {
out = block.forward(&out, style, training);
}
out = self.to_rgb.forward(&out, training);
Tanh::new().forward(&out)
}
}
impl StyleGAN2Generator {
pub fn new(latent_dim: usize, style_dim: usize, img_channels: usize) -> Self {
StyleGAN2Generator {
mapping: StyleGAN2Mapping::new(latent_dim, style_dim),
synthesis: StyleGAN2Synthesis::new(style_dim, img_channels),
}
}
pub fn forward(&mut self, z: &Tensor, training: bool) -> Tensor {
let style = self.mapping.forward(z, training);
self.synthesis.forward(&style, training)
}
}
pub struct InfoGAN {
generator: InfoGANGenerator,
discriminator: InfoGANDiscriminator,
}
struct InfoGANGenerator {
fc: Dense,
deconv_layers: Vec<Conv2d>,
}
impl InfoGANGenerator {
fn new(latent_dim: usize, code_dim: usize, img_channels: usize) -> Self {
InfoGANGenerator {
fc: Dense::new(latent_dim + code_dim, 256 * 7 * 7),
deconv_layers: vec![
Conv2d::new(256, 128, (4, 4)).stride((2, 2)).padding((1, 1)),
Conv2d::new(128, 64, (4, 4)).stride((2, 2)).padding((1, 1)),
Conv2d::new(64, img_channels, (3, 3)).padding((1, 1)),
],
}
}
fn forward(&mut self, z: &Tensor, code: &Tensor, training: bool) -> Tensor {
let combined = self.concatenate(z, code);
let mut out = self.fc.forward(&combined, training);
let batch_size = z.dims()[0];
out = Tensor::from_slice(out.data_f32(), &[batch_size, 256, 7, 7]).unwrap();
for layer in &mut self.deconv_layers {
out = layer.forward(&out, training);
out = ReLU::new().forward(&out);
}
Tanh::new().forward(&out)
}
fn concatenate(&self, a: &Tensor, b: &Tensor) -> Tensor {
let a_data = a.data_f32();
let b_data = b.data_f32();
let mut result = Vec::new();
result.extend_from_slice(a_data);
result.extend_from_slice(b_data);
Tensor::from_slice(&result, &[a.dims()[0], a.dims()[1] + b.dims()[1]]).unwrap()
}
}
struct InfoGANDiscriminator {
conv_layers: Vec<Conv2d>,
fc_disc: Dense,
fc_code: Dense,
}
impl InfoGANDiscriminator {
fn new(img_channels: usize, code_dim: usize) -> Self {
InfoGANDiscriminator {
conv_layers: vec![
Conv2d::new(img_channels, 64, (4, 4)).stride((2, 2)).padding((1, 1)),
Conv2d::new(64, 128, (4, 4)).stride((2, 2)).padding((1, 1)),
Conv2d::new(128, 256, (4, 4)).stride((2, 2)).padding((1, 1)),
],
fc_disc: Dense::new(256 * 4 * 4, 1),
fc_code: Dense::new(256 * 4 * 4, code_dim),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> (Tensor, Tensor) {
let mut out = x.clone();
for layer in &mut self.conv_layers {
out = layer.forward(&out, training);
out = LeakyReLU::new(0.2).forward(&out);
}
let batch_size = out.dims()[0];
let flat_size = out.data_f32().len() / batch_size;
out = Tensor::from_slice(out.data_f32(), &[batch_size, flat_size]).unwrap();
let disc_out = self.fc_disc.forward(&out, training);
let code_out = self.fc_code.forward(&out, training);
(disc_out, code_out)
}
}
impl InfoGAN {
pub fn new(latent_dim: usize, code_dim: usize, img_channels: usize) -> Self {
InfoGAN {
generator: InfoGANGenerator::new(latent_dim, code_dim, img_channels),
discriminator: InfoGANDiscriminator::new(img_channels, code_dim),
}
}
pub fn forward_generator(&mut self, z: &Tensor, code: &Tensor, training: bool) -> Tensor {
self.generator.forward(z, code, training)
}
pub fn forward_discriminator(&mut self, x: &Tensor, training: bool) -> (Tensor, Tensor) {
self.discriminator.forward(x, training)
}
}
pub struct SAGANGenerator {
initial: Dense,
blocks: Vec<SAGANBlock>,
attention: SelfAttentionLayer,
output_conv: Conv2d,
}
struct SAGANBlock {
conv: Conv2d,
bn: BatchNorm2d,
}
impl SAGANBlock {
fn new(in_channels: usize, out_channels: usize) -> Self {
SAGANBlock {
conv: Conv2d::new(in_channels, out_channels, (3, 3)).padding((1, 1)),
bn: BatchNorm2d::new(out_channels),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.conv.forward(x, training);
out = self.bn.forward(&out, training);
ReLU::new().forward(&out)
}
}
struct SelfAttentionLayer {
query: Conv2d,
key: Conv2d,
value: Conv2d,
gamma: f32,
}
impl SelfAttentionLayer {
fn new(channels: usize) -> Self {
SelfAttentionLayer {
query: Conv2d::new(channels, channels / 8, (1, 1)),
key: Conv2d::new(channels, channels / 8, (1, 1)),
value: Conv2d::new(channels, channels, (1, 1)),
gamma: 0.0,
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let q = self.query.forward(x, training);
let k = self.key.forward(x, training);
let v = self.value.forward(x, training);
v
}
}
impl SAGANGenerator {
pub fn new(latent_dim: usize, img_channels: usize) -> Self {
SAGANGenerator {
initial: Dense::new(latent_dim, 256 * 4 * 4),
blocks: vec![
SAGANBlock::new(256, 256),
SAGANBlock::new(256, 128),
SAGANBlock::new(128, 64),
],
attention: SelfAttentionLayer::new(128),
output_conv: Conv2d::new(64, img_channels, (3, 3)).padding((1, 1)),
}
}
pub fn forward(&mut self, z: &Tensor, training: bool) -> Tensor {
let mut out = self.initial.forward(z, training);
let batch_size = z.dims()[0];
out = Tensor::from_slice(out.data_f32(), &[batch_size, 256, 4, 4]).unwrap();
for (i, block) in self.blocks.iter_mut().enumerate() {
out = block.forward(&out, training);
if i == 1 {
out = self.attention.forward(&out, training);
}
}
out = self.output_conv.forward(&out, training);
Tanh::new().forward(&out)
}
}
#[cfg(test)]
mod tests_gan_variants {
use super::*;
#[test]
fn test_biggan() {
let mut gen = BigGANGenerator::new(128, 1000, 3);
let z = Tensor::from_slice(&vec![0.5f32; 1 * 128], &[1, 128]).unwrap();
let label = Tensor::from_slice(&vec![0.5f32; 1 * 1000], &[1, 1000]).unwrap();
let output = gen.forward(&z, &label, false);
assert_eq!(output.dims()[1], 3);
}
#[test]
fn test_stylegan2() {
let mut gen = StyleGAN2Generator::new(512, 512, 3);
let z = Tensor::from_slice(&vec![0.5f32; 1 * 512], &[1, 512]).unwrap();
let output = gen.forward(&z, false);
assert_eq!(output.dims()[1], 3);
}
}