use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratorConfig {
pub latent_dim: usize,
pub hidden_dims: Vec<usize>,
pub vocab_size: usize,
pub max_seq_len: usize,
pub dropout: f32,
pub batch_norm: bool,
}
impl Default for GeneratorConfig {
fn default() -> Self {
Self {
latent_dim: 128,
hidden_dims: vec![256, 512, 256],
vocab_size: 1000,
max_seq_len: 256,
dropout: 0.1,
batch_norm: true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscriminatorConfig {
pub vocab_size: usize,
pub max_seq_len: usize,
pub embed_dim: usize,
pub hidden_dims: Vec<usize>,
pub dropout: f32,
pub spectral_norm: bool,
}
impl Default for DiscriminatorConfig {
fn default() -> Self {
Self {
vocab_size: 1000,
max_seq_len: 256,
embed_dim: 64,
hidden_dims: vec![256, 128, 64],
dropout: 0.2,
spectral_norm: true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeGanConfig {
pub generator: GeneratorConfig,
pub discriminator: DiscriminatorConfig,
pub gen_lr: f32,
pub disc_lr: f32,
pub n_critic: usize,
pub gradient_penalty: f32,
pub label_smoothing: f32,
pub batch_size: usize,
}
impl Default for CodeGanConfig {
fn default() -> Self {
Self {
generator: GeneratorConfig::default(),
discriminator: DiscriminatorConfig::default(),
gen_lr: 0.0002,
disc_lr: 0.0002,
n_critic: 5,
gradient_penalty: 10.0,
label_smoothing: 0.1,
batch_size: 32,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact-match helper for float defaults that come from literals.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < f32::EPSILON
    }

    #[test]
    fn test_generator_config_default() {
        let config = GeneratorConfig::default();
        assert_eq!(config.latent_dim, 128);
        assert_eq!(config.hidden_dims, [256, 512, 256]);
        assert_eq!(config.vocab_size, 1000);
        assert_eq!(config.max_seq_len, 256);
        assert!(approx_eq(config.dropout, 0.1));
        assert!(config.batch_norm);
    }

    #[test]
    fn test_discriminator_config_default() {
        let config = DiscriminatorConfig::default();
        assert_eq!(config.vocab_size, 1000);
        assert_eq!(config.max_seq_len, 256);
        assert_eq!(config.embed_dim, 64);
        assert_eq!(config.hidden_dims, [256, 128, 64]);
        assert!(approx_eq(config.dropout, 0.2));
        assert!(config.spectral_norm);
    }

    #[test]
    fn test_code_gan_config_default() {
        let config = CodeGanConfig::default();
        assert_eq!(config.n_critic, 5);
        assert!(config.gen_lr > 0.0);
        assert!(config.disc_lr > 0.0);
        assert!(approx_eq(config.gen_lr, config.disc_lr));
        assert!(approx_eq(config.gradient_penalty, 10.0));
        assert!(approx_eq(config.label_smoothing, 0.1));
        assert_eq!(config.batch_size, 32);
    }

    /// The default generator and discriminator must describe the same token
    /// space, since nothing else in the config enforces it.
    #[test]
    fn test_code_gan_config_components_agree() {
        let config = CodeGanConfig::default();
        assert_eq!(config.generator.vocab_size, config.discriminator.vocab_size);
        assert_eq!(config.generator.max_seq_len, config.discriminator.max_seq_len);
    }
}