use crate::error::{NeuralError, Result};
use crate::layers::{Dense, Dropout, Layer};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::SeedableRng;
use std::fmt::Debug;
use std::sync::atomic::{AtomicU64, Ordering};
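/// Training objective for the GAN.
///
/// `Standard` uses the original binary cross-entropy losses (non-saturating
/// generator loss); `WGANGP` uses Wasserstein losses with a gradient penalty
/// on real/fake interpolates.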
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GANMode {
Standard,
WGANGP,
}
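/// Architecture hyperparameters for the [`Generator`].
///
/// Note: `use_batch_norm` is stored in the config but is not currently
/// applied by [`Generator::new`].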
#[derive(Debug, Clone)]
pub struct GeneratorConfig {
pub noise_dim: usize,
pub hidden_dims: Vec<usize>,
pub output_dim: usize,
pub dropout_rate: f64,
pub use_batch_norm: bool,
}
impl GeneratorConfig {
pub fn standard(noise_dim: usize, output_dim: usize) -> Self {
Self {
noise_dim,
hidden_dims: vec![256, 512, 1024],
output_dim,
dropout_rate: 0.0,
use_batch_norm: true,
}
}
pub fn tiny(noise_dim: usize, output_dim: usize) -> Self {
Self {
noise_dim,
hidden_dims: vec![32, 64],
output_dim,
dropout_rate: 0.0,
use_batch_norm: false,
}
}
}
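/// Architecture hyperparameters for the [`Discriminator`], which always ends
/// in a single-unit output layer.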
#[derive(Debug, Clone)]
pub struct DiscriminatorConfig {
pub input_dim: usize,
pub hidden_dims: Vec<usize>,
pub dropout_rate: f64,
}
impl DiscriminatorConfig {
pub fn standard(input_dim: usize) -> Self {
Self {
input_dim,
hidden_dims: vec![1024, 512, 256],
dropout_rate: 0.2,
}
}
pub fn tiny(input_dim: usize) -> Self {
Self {
input_dim,
hidden_dims: vec![64, 32],
dropout_rate: 0.0,
}
}
}
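/// Complete GAN configuration: generator and discriminator architectures plus
/// training hyperparameters (`mode`, `gradient_penalty_weight`, `n_critic`).
///
/// # Examples
///
/// ```ignore
/// // Illustrative: adjust the import path to wherever this module lives.
/// let cfg = GANConfig::standard(100, 784)
///     .with_mode(GANMode::WGANGP)
///     .with_gradient_penalty(5.0)
///     .with_n_critic(3);
/// assert_eq!(cfg.n_critic, 3);
/// ```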
#[derive(Debug, Clone)]
pub struct GANConfig {
pub generator: GeneratorConfig,
pub discriminator: DiscriminatorConfig,
pub mode: GANMode,
pub gradient_penalty_weight: f64,
pub n_critic: usize,
}
impl GANConfig {
pub fn standard(noise_dim: usize, data_dim: usize) -> Self {
Self {
generator: GeneratorConfig::standard(noise_dim, data_dim),
discriminator: DiscriminatorConfig::standard(data_dim),
mode: GANMode::Standard,
gradient_penalty_weight: 10.0,
n_critic: 1,
}
}
pub fn wgan_gp(noise_dim: usize, data_dim: usize) -> Self {
Self {
generator: GeneratorConfig::standard(noise_dim, data_dim),
discriminator: DiscriminatorConfig::standard(data_dim),
mode: GANMode::WGANGP,
gradient_penalty_weight: 10.0,
n_critic: 5,
}
}
pub fn tiny(noise_dim: usize, data_dim: usize) -> Self {
Self {
generator: GeneratorConfig::tiny(noise_dim, data_dim),
discriminator: DiscriminatorConfig::tiny(data_dim),
mode: GANMode::Standard,
gradient_penalty_weight: 10.0,
n_critic: 1,
}
}
pub fn with_mode(mut self, mode: GANMode) -> Self {
self.mode = mode;
self
}
pub fn with_gradient_penalty(mut self, weight: f64) -> Self {
self.gradient_penalty_weight = weight;
self
}
pub fn with_n_critic(mut self, n: usize) -> Self {
self.n_critic = n.max(1);
self
}
}
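/// MLP generator: each hidden `Dense` layer is followed by ReLU (and optional
/// dropout); the output layer uses `tanh`, so samples lie in `[-1, 1]`.
///
/// Layer weights are initialized from fixed per-layer seeds, so construction
/// is deterministic.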
pub struct Generator<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
config: GeneratorConfig,
layers: Vec<Dense<F>>,
dropout: Option<Dropout<F>>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Generator<F> {
pub fn new(config: GeneratorConfig) -> Result<Self> {
if config.noise_dim == 0 || config.output_dim == 0 {
return Err(NeuralError::InvalidArchitecture(
"noise_dim and output_dim must be > 0".to_string(),
));
}
let mut layers = Vec::new();
let mut in_dim = config.noise_dim;
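// Distinct seed per layer so layers do not share identical initial weights.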
let mut seed: u8 = 110;
for &hdim in &config.hidden_dims {
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([seed; 32]);
seed = seed.wrapping_add(1);
layers.push(Dense::new(in_dim, hdim, None, &mut rng)?);
in_dim = hdim;
}
{
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([seed; 32]);
layers.push(Dense::new(in_dim, config.output_dim, None, &mut rng)?);
}
let dropout = if config.dropout_rate > 0.0 {
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([120; 32]);
Some(Dropout::new(config.dropout_rate, &mut rng)?)
} else {
None
};
Ok(Self {
config,
layers,
dropout,
})
}
pub fn config(&self) -> &GeneratorConfig {
&self.config
}
pub fn total_parameter_count(&self) -> usize {
self.layers.iter().map(|l| l.parameter_count()).sum()
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
for Generator<F>
{
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let mut x = input.clone();
for (i, layer) in self.layers.iter().enumerate() {
x = layer.forward(&x)?;
if i < self.layers.len() - 1 {
x = x.mapv(|v| v.max(F::zero()));
if let Some(ref drop) = self.dropout {
x = drop.forward(&x)?;
}
} else {
x = x.mapv(|v| v.tanh());
}
}
Ok(x)
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
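// Pass-through placeholder: gradients are not propagated through the layers.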
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
for layer in &mut self.layers {
layer.update(learning_rate)?;
}
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
self.layers.iter().flat_map(|l| l.params()).collect()
}
fn parameter_count(&self) -> usize {
self.total_parameter_count()
}
fn layer_type(&self) -> &str {
"Generator"
}
fn layer_description(&self) -> String {
format!(
"Generator(noise={}, hidden={:?}, output={}, params={})",
self.config.noise_dim,
self.config.hidden_dims,
self.config.output_dim,
self.total_parameter_count()
)
}
}
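/// MLP discriminator/critic: hidden `Dense` layers use LeakyReLU (slope 0.2)
/// with optional dropout; the final layer emits one raw logit per sample.
/// Construction mirrors [`Generator::new`], including deterministic seeding.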
pub struct Discriminator<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
config: DiscriminatorConfig,
layers: Vec<Dense<F>>,
dropout: Option<Dropout<F>>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Discriminator<F> {
pub fn new(config: DiscriminatorConfig) -> Result<Self> {
if config.input_dim == 0 {
return Err(NeuralError::InvalidArchitecture(
"input_dim must be > 0".to_string(),
));
}
let mut layers = Vec::new();
let mut in_dim = config.input_dim;
let mut seed: u8 = 130;
for &hdim in &config.hidden_dims {
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([seed; 32]);
seed = seed.wrapping_add(1);
layers.push(Dense::new(in_dim, hdim, None, &mut rng)?);
in_dim = hdim;
}
{
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([seed; 32]);
layers.push(Dense::new(in_dim, 1, None, &mut rng)?);
}
let dropout = if config.dropout_rate > 0.0 {
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([140; 32]);
Some(Dropout::new(config.dropout_rate, &mut rng)?)
} else {
None
};
Ok(Self {
config,
layers,
dropout,
})
}
pub fn config(&self) -> &DiscriminatorConfig {
&self.config
}
pub fn total_parameter_count(&self) -> usize {
self.layers.iter().map(|l| l.parameter_count()).sum()
}
}
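/// LeakyReLU with negative slope 0.2, i.e. `max(x, 0.2 * x)`.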
fn leaky_relu<F: Float>(x: F) -> F {
let slope = F::from(0.2).expect("leaky relu slope");
if x > F::zero() {
x
} else {
slope * x
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
for Discriminator<F>
{
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let mut x = input.clone();
for (i, layer) in self.layers.iter().enumerate() {
x = layer.forward(&x)?;
if i < self.layers.len() - 1 {
x = x.mapv(leaky_relu);
if let Some(ref drop) = self.dropout {
x = drop.forward(&x)?;
}
}
}
Ok(x)
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
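// Pass-through placeholder, as in `Generator::backward`.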
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
for layer in &mut self.layers {
layer.update(learning_rate)?;
}
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
self.layers.iter().flat_map(|l| l.params()).collect()
}
fn parameter_count(&self) -> usize {
self.total_parameter_count()
}
fn layer_type(&self) -> &str {
"Discriminator"
}
fn layer_description(&self) -> String {
format!(
"Discriminator(input={}, hidden={:?}, params={})",
self.config.input_dim,
self.config.hidden_dims,
self.total_parameter_count()
)
}
}
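/// Per-step training statistics returned by [`GAN::train_step`].
/// `gradient_penalty` is zero in `Standard` mode, and `g_loss` is zero on
/// steps where the generator is skipped (see `n_critic`).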
#[derive(Debug, Clone)]
pub struct GANStepResult<F: Float> {
pub d_loss: F,
pub g_loss: F,
pub gradient_penalty: F,
pub d_real_mean: F,
pub d_fake_mean: F,
}
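/// A generator/discriminator pair with a lightweight internal RNG used for
/// noise sampling and gradient-penalty interpolation. The RNG state lives in
/// an `AtomicU64` (relaxed ordering), so `&self` sampling needs no interior
/// mutability tricks, though racing callers may observe repeated subsequences.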
pub struct GAN<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
config: GANConfig,
pub generator: Generator<F>,
pub discriminator: Discriminator<F>,
rng_state: AtomicU64,
step_count: usize,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> GAN<F> {
pub fn new(config: GANConfig) -> Result<Self> {
let generator = Generator::new(config.generator.clone())?;
let discriminator = Discriminator::new(config.discriminator.clone())?;
Ok(Self {
config,
generator,
discriminator,
rng_state: AtomicU64::new(0xBEEF_CAFE_DEAD_BABEu64),
step_count: 0,
})
}
pub fn config(&self) -> &GANConfig {
&self.config
}
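/// Draws a `[batch_size, noise_dim]` matrix of standard-normal noise using
/// the Box-Muller transform over the internal xorshift generator.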
pub fn sample_noise(&self, batch_size: usize) -> Array<F, IxDyn> {
let noise_dim = self.config.generator.noise_dim;
let mut state = self.rng_state.load(Ordering::Relaxed);
let mut noise = Array::zeros(IxDyn(&[batch_size, noise_dim]));
for b in 0..batch_size {
for d in 0..noise_dim {
let u1 = xorshift_f64(&mut state).max(1e-10);
let u2 = xorshift_f64(&mut state);
let normal = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
noise[[b, d]] = F::from(normal).expect("noise conversion");
}
}
self.rng_state.store(state, Ordering::Relaxed);
noise
}
pub fn generate(&self, batch_size: usize) -> Result<Array<F, IxDyn>> {
let noise = self.sample_noise(batch_size);
self.generator.forward(&noise)
}
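/// Mean binary cross-entropy of `sigmoid(logits)` against a constant target,
/// with the sigmoid clamped to `[eps, 1 - eps]` for numerical stability.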
fn bce_loss(logits: &Array<F, IxDyn>, target: F) -> F {
let eps = F::from(1e-7).expect("eps");
let one = F::one();
let n = F::from(logits.len()).expect("n");
let mut loss = F::zero();
for &logit in logits.iter() {
let sigmoid = one / (one + (-logit).exp());
let sig_clamped = sigmoid.max(eps).min(one - eps);
loss += target * sig_clamped.ln() + (one - target) * (one - sig_clamped).ln();
}
-loss / n
}
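/// Critic loss for WGAN: `mean(D(fake)) - mean(D(real))`.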
fn wasserstein_d_loss(d_real: &Array<F, IxDyn>, d_fake: &Array<F, IxDyn>) -> F {
let n_real = F::from(d_real.len()).expect("n");
let n_fake = F::from(d_fake.len()).expect("n");
let real_mean: F = d_real.iter().copied().fold(F::zero(), |a, b| a + b) / n_real;
let fake_mean: F = d_fake.iter().copied().fold(F::zero(), |a, b| a + b) / n_fake;
fake_mean - real_mean
}
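/// Generator loss for WGAN: `-mean(D(fake))`.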
fn wasserstein_g_loss(d_fake: &Array<F, IxDyn>) -> F {
let n = F::from(d_fake.len()).expect("n");
let fake_mean: F = d_fake.iter().copied().fold(F::zero(), |a, b| a + b) / n;
-fake_mean
}
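/// WGAN-GP penalty `mean((||grad D(x_hat)|| - 1)^2)`, where each `x_hat` is a
/// random interpolation between a real and a fake sample. The gradient is
/// approximated per coordinate by forward finite differences with step
/// `1e-4`, so this costs one extra discriminator forward pass per dimension.
/// Inputs are expected to be 2-D (`[batch, data_dim]`).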
fn gradient_penalty(&self, real: &Array<F, IxDyn>, fake: &Array<F, IxDyn>) -> Result<F> {
let shape = real.shape();
let batch_size = shape[0];
let data_dim = shape[1];
let mut state = self.rng_state.load(Ordering::Relaxed);
let mut penalty = F::zero();
let epsilon_fd = F::from(1e-4).expect("fd epsilon");
for b in 0..batch_size {
let alpha = F::from(xorshift_f64(&mut state)).expect("alpha");
let mut interp = Array::zeros(IxDyn(&[1, data_dim]));
for d in 0..data_dim {
interp[[0, d]] = alpha * real[[b, d]] + (F::one() - alpha) * fake[[b, d]];
}
let d_interp = self.discriminator.forward(&interp)?;
let base_val = d_interp[[0, 0]];
let mut grad_norm_sq = F::zero();
for d in 0..data_dim {
let mut perturbed = interp.clone();
perturbed[[0, d]] += epsilon_fd;
let d_perturbed = self.discriminator.forward(&perturbed)?;
let grad_d = (d_perturbed[[0, 0]] - base_val) / epsilon_fd;
grad_norm_sq += grad_d * grad_d;
}
let grad_norm = grad_norm_sq.sqrt();
let diff = grad_norm - F::one();
penalty += diff * diff;
}
self.rng_state.store(state, Ordering::Relaxed);
Ok(penalty / F::from(batch_size).expect("batch"))
}
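/// Runs one training step on a `[batch, data_dim]` batch of real samples:
/// the discriminator loss is computed on every call, while the generator is
/// only evaluated and updated on every `n_critic`-th step.
///
/// A minimal loop sketch (illustrative; `next_real_batch` is hypothetical):
///
/// ```ignore
/// let mut gan: GAN<f64> = GAN::new(GANConfig::tiny(10, 20))?;
/// for _ in 0..100 {
///     let batch = next_real_batch();
///     let stats = gan.train_step(&batch, 0.001)?;
///     println!("d_loss={:?} g_loss={:?}", stats.d_loss, stats.g_loss);
/// }
/// ```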
pub fn train_step(
&mut self,
real_data: &Array<F, IxDyn>,
learning_rate: F,
) -> Result<GANStepResult<F>> {
let batch_size = real_data.shape()[0];
let fake_data = self.generate(batch_size)?;
let d_real = self.discriminator.forward(real_data)?;
let d_fake = self.discriminator.forward(&fake_data)?;
let (d_loss, gp) = match self.config.mode {
GANMode::Standard => {
let loss_real = Self::bce_loss(&d_real, F::one());
let loss_fake = Self::bce_loss(&d_fake, F::zero());
(loss_real + loss_fake, F::zero())
}
GANMode::WGANGP => {
let w_loss = Self::wasserstein_d_loss(&d_real, &d_fake);
let gp = self.gradient_penalty(real_data, &fake_data)?;
let lambda = F::from(self.config.gradient_penalty_weight).expect("lambda");
(w_loss + lambda * gp, gp)
}
};
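// The losses above are computed for monitoring; parameter updates delegate
// to each layer's `update`, since no explicit backward pass is run here.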
self.discriminator.update(learning_rate)?;
self.step_count += 1;
let n_critic = self.config.n_critic.max(1); // guard: a hand-built config could set n_critic = 0
let g_loss = if self.step_count % n_critic == 0 {
let fake_data = self.generate(batch_size)?;
let d_fake_for_g = self.discriminator.forward(&fake_data)?;
let g_loss = match self.config.mode {
GANMode::Standard => Self::bce_loss(&d_fake_for_g, F::one()),
GANMode::WGANGP => Self::wasserstein_g_loss(&d_fake_for_g),
};
self.generator.update(learning_rate)?;
g_loss
} else {
F::zero()
};
let n_real = F::from(d_real.len()).expect("n");
let n_fake = F::from(d_fake.len()).expect("n");
let d_real_mean: F = d_real.iter().copied().fold(F::zero(), |a, b| a + b) / n_real;
let d_fake_mean: F = d_fake.iter().copied().fold(F::zero(), |a, b| a + b) / n_fake;
Ok(GANStepResult {
d_loss,
g_loss,
gradient_penalty: gp,
d_real_mean,
d_fake_mean,
})
}
pub fn total_parameter_count(&self) -> usize {
self.generator.total_parameter_count() + self.discriminator.total_parameter_count()
}
}
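/// xorshift64 PRNG step mapped to a uniform `f64` in `[0, 1)` via the top 53
/// bits of the state. The state must be nonzero (zero is a fixed point).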
fn xorshift_f64(state: &mut u64) -> f64 {
let mut s = *state;
s ^= s << 13;
s ^= s >> 7;
s ^= s << 17;
*state = s;
(s >> 11) as f64 / ((1u64 << 53) as f64)
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for GAN<F> {
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
self.generator.forward(input)
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.generator.update(learning_rate)?;
self.discriminator.update(learning_rate)?;
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
let mut p = self.generator.params();
p.extend(self.discriminator.params());
p
}
fn parameter_count(&self) -> usize {
self.total_parameter_count()
}
fn layer_type(&self) -> &str {
"GAN"
}
fn layer_description(&self) -> String {
format!(
"GAN(mode={:?}, noise={}, data={}, g_params={}, d_params={})",
self.config.mode,
self.config.generator.noise_dim,
self.config.generator.output_dim,
self.generator.total_parameter_count(),
self.discriminator.total_parameter_count()
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generator_config_standard() {
let cfg = GeneratorConfig::standard(100, 784);
assert_eq!(cfg.noise_dim, 100);
assert_eq!(cfg.output_dim, 784);
assert_eq!(cfg.hidden_dims, vec![256, 512, 1024]);
}
#[test]
fn test_generator_config_tiny() {
let cfg = GeneratorConfig::tiny(10, 20);
assert_eq!(cfg.hidden_dims, vec![32, 64]);
}
#[test]
fn test_discriminator_config_standard() {
let cfg = DiscriminatorConfig::standard(784);
assert_eq!(cfg.input_dim, 784);
assert_eq!(cfg.hidden_dims, vec![1024, 512, 256]);
assert!((cfg.dropout_rate - 0.2).abs() < 1e-10);
}
#[test]
fn test_gan_config_standard() {
let cfg = GANConfig::standard(100, 784);
assert_eq!(cfg.mode, GANMode::Standard);
assert_eq!(cfg.n_critic, 1);
}
#[test]
fn test_gan_config_wgan_gp() {
let cfg = GANConfig::wgan_gp(100, 784);
assert_eq!(cfg.mode, GANMode::WGANGP);
assert_eq!(cfg.n_critic, 5);
assert!((cfg.gradient_penalty_weight - 10.0).abs() < 1e-10);
}
#[test]
fn test_gan_config_builder() {
let cfg = GANConfig::standard(10, 20)
.with_mode(GANMode::WGANGP)
.with_gradient_penalty(5.0)
.with_n_critic(3);
assert_eq!(cfg.mode, GANMode::WGANGP);
assert!((cfg.gradient_penalty_weight - 5.0).abs() < 1e-10);
assert_eq!(cfg.n_critic, 3);
}
#[test]
fn test_generator_creation() {
let cfg = GeneratorConfig::tiny(10, 20);
let gen: Generator<f64> = Generator::new(cfg).expect("Failed to create generator");
assert!(gen.total_parameter_count() > 0);
}
#[test]
fn test_generator_forward() {
let cfg = GeneratorConfig::tiny(10, 20);
let gen: Generator<f64> = Generator::new(cfg).expect("Failed to create generator");
let noise = Array::zeros(IxDyn(&[2, 10]));
let output = gen.forward(&noise).expect("Generator forward failed");
assert_eq!(output.shape(), &[2, 20]);
for &v in output.iter() {
assert!(v >= -1.0 && v <= 1.0, "tanh should bound output, got {}", v);
}
}
#[test]
fn test_generator_layer_trait() {
let cfg = GeneratorConfig::tiny(10, 20);
let gen: Generator<f64> = Generator::new(cfg).expect("Failed to create generator");
assert_eq!(gen.layer_type(), "Generator");
assert!(gen.parameter_count() > 0);
}
#[test]
fn test_generator_invalid() {
let cfg = GeneratorConfig {
noise_dim: 0,
hidden_dims: vec![],
output_dim: 10,
dropout_rate: 0.0,
use_batch_norm: false,
};
assert!(Generator::<f64>::new(cfg).is_err());
}
#[test]
fn test_discriminator_creation() {
let cfg = DiscriminatorConfig::tiny(20);
let disc: Discriminator<f64> =
Discriminator::new(cfg).expect("Failed to create discriminator");
assert!(disc.total_parameter_count() > 0);
}
#[test]
fn test_discriminator_forward() {
let cfg = DiscriminatorConfig::tiny(20);
let disc: Discriminator<f64> =
Discriminator::new(cfg).expect("Failed to create discriminator");
let input = Array::zeros(IxDyn(&[3, 20]));
let output = disc.forward(&input).expect("Discriminator forward failed");
assert_eq!(output.shape(), &[3, 1]);
}
#[test]
fn test_discriminator_layer_trait() {
let cfg = DiscriminatorConfig::tiny(20);
let disc: Discriminator<f64> =
Discriminator::new(cfg).expect("Failed to create discriminator");
assert_eq!(disc.layer_type(), "Discriminator");
}
#[test]
fn test_discriminator_invalid() {
let cfg = DiscriminatorConfig {
input_dim: 0,
hidden_dims: vec![],
dropout_rate: 0.0,
};
assert!(Discriminator::<f64>::new(cfg).is_err());
}
#[test]
fn test_gan_creation() {
let cfg = GANConfig::tiny(10, 20);
let gan: GAN<f64> = GAN::new(cfg).expect("Failed to create GAN");
assert!(gan.total_parameter_count() > 0);
}
#[test]
fn test_gan_generate() {
let cfg = GANConfig::tiny(10, 20);
let gan: GAN<f64> = GAN::new(cfg).expect("Failed to create GAN");
let samples = gan.generate(5).expect("Generation failed");
assert_eq!(samples.shape(), &[5, 20]);
}
#[test]
fn test_gan_sample_noise() {
let cfg = GANConfig::tiny(10, 20);
let gan: GAN<f64> = GAN::new(cfg).expect("Failed to create GAN");
let noise = gan.sample_noise(3);
assert_eq!(noise.shape(), &[3, 10]);
}
#[test]
fn test_gan_forward() {
let cfg = GANConfig::tiny(10, 20);
let gan: GAN<f64> = GAN::new(cfg).expect("Failed to create GAN");
let noise = Array::zeros(IxDyn(&[2, 10]));
let output = gan.forward(&noise).expect("GAN forward failed");
assert_eq!(output.shape(), &[2, 20]);
}
#[test]
fn test_gan_train_step_standard() {
let cfg = GANConfig::tiny(10, 20);
let mut gan: GAN<f64> = GAN::new(cfg).expect("Failed to create GAN");
let real_data = Array::zeros(IxDyn(&[4, 20]));
let result = gan
.train_step(&real_data, 0.001)
.expect("Train step failed");
assert!(result.d_loss.is_finite(), "d_loss should be finite");
assert!(result.g_loss.is_finite(), "g_loss should be finite");
}
#[test]
fn test_gan_train_step_wgan() {
let cfg = GANConfig::tiny(10, 20)
.with_mode(GANMode::WGANGP)
.with_n_critic(1);
let mut gan: GAN<f64> = GAN::new(cfg).expect("Failed to create WGAN");
let real_data = Array::zeros(IxDyn(&[4, 20]));
let result = gan
.train_step(&real_data, 0.001)
.expect("WGAN train step failed");
assert!(result.d_loss.is_finite(), "d_loss should be finite");
assert!(result.gradient_penalty.is_finite(), "GP should be finite");
}
#[test]
fn test_gan_multiple_train_steps() {
let cfg = GANConfig::tiny(10, 20);
let mut gan: GAN<f64> = GAN::new(cfg).expect("Failed to create GAN");
let real_data = Array::zeros(IxDyn(&[4, 20]));
for _ in 0..5 {
let result = gan
.train_step(&real_data, 0.001)
.expect("Train step failed");
assert!(result.d_loss.is_finite());
}
}
#[test]
fn test_gan_layer_trait() {
let cfg = GANConfig::tiny(10, 20);
let gan: GAN<f64> = GAN::new(cfg).expect("Failed to create GAN");
assert_eq!(gan.layer_type(), "GAN");
let desc = gan.layer_description();
assert!(desc.contains("GAN"));
assert!(desc.contains("Standard"));
}
#[test]
fn test_gan_update() {
let cfg = GANConfig::tiny(10, 20);
let mut gan: GAN<f64> = GAN::new(cfg).expect("Failed to create GAN");
gan.update(0.001).expect("Update failed");
}
#[test]
fn test_gan_params() {
let cfg = GANConfig::tiny(10, 20);
let gan: GAN<f64> = GAN::new(cfg).expect("Failed to create GAN");
let p = gan.params();
assert!(!p.is_empty());
}
#[test]
fn test_gan_f32() {
let cfg = GANConfig::tiny(10, 20);
let gan: GAN<f32> = GAN::new(cfg).expect("Failed to create f32 GAN");
let noise = Array::zeros(IxDyn(&[1, 10]));
let output = gan.forward(&noise).expect("f32 forward failed");
assert_eq!(output.shape(), &[1, 20]);
}
#[test]
fn test_bce_loss() {
let logits = Array::from_elem(IxDyn(&[4, 1]), 10.0_f64);
let loss = GAN::<f64>::bce_loss(&logits, 1.0);
assert!(
loss < 0.01,
"BCE loss should be small for correct prediction"
);
let loss_wrong = GAN::<f64>::bce_loss(&logits, 0.0);
assert!(loss_wrong > loss, "Wrong target should give higher loss");
}
#[test]
fn test_leaky_relu() {
assert!((leaky_relu(1.0_f64) - 1.0).abs() < 1e-10);
assert!((leaky_relu(0.0_f64) - 0.0).abs() < 1e-10);
assert!((leaky_relu(-1.0_f64) - (-0.2)).abs() < 1e-10);
assert!((leaky_relu(-5.0_f64) - (-1.0)).abs() < 1e-10);
}
#[test]
fn test_xorshift_f64_range() {
let mut state = 42u64;
for _ in 0..100 {
let v = xorshift_f64(&mut state);
assert!(v >= 0.0 && v < 1.0);
}
}
#[test]
fn test_n_critic_skips_generator() {
let cfg = GANConfig::tiny(10, 20).with_n_critic(3);
let mut gan: GAN<f64> = GAN::new(cfg).expect("Failed to create GAN");
let real_data = Array::zeros(IxDyn(&[2, 20]));
let r1 = gan.train_step(&real_data, 0.001).expect("Step 1");
assert!(
(r1.g_loss - 0.0).abs() < 1e-10,
"G should not train on step 1"
);
let r2 = gan.train_step(&real_data, 0.001).expect("Step 2");
assert!(
(r2.g_loss - 0.0).abs() < 1e-10,
"G should not train on step 2"
);
let r3 = gan.train_step(&real_data, 0.001).expect("Step 3");
assert!(r3.g_loss.is_finite(), "G should train on step 3");
}
}