// entrenar/generative/code_gan/config.rs

//! Configuration types for Code GAN components.

3use serde::{Deserialize, Serialize};
4
5/// Configuration for the Generator network
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct GeneratorConfig {
8    /// Dimension of the latent space
9    pub latent_dim: usize,
10    /// Hidden layer sizes
11    pub hidden_dims: Vec<usize>,
12    /// Output vocabulary size (number of AST token types)
13    pub vocab_size: usize,
14    /// Maximum sequence length to generate
15    pub max_seq_len: usize,
16    /// Dropout rate during training
17    pub dropout: f32,
18    /// Use batch normalization
19    pub batch_norm: bool,
20}
21
22impl Default for GeneratorConfig {
23    fn default() -> Self {
24        Self {
25            latent_dim: 128,
26            hidden_dims: vec![256, 512, 256],
27            vocab_size: 1000,
28            max_seq_len: 256,
29            dropout: 0.1,
30            batch_norm: true,
31        }
32    }
33}
34
35/// Configuration for the Discriminator network
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct DiscriminatorConfig {
38    /// Input vocabulary size (number of AST token types)
39    pub vocab_size: usize,
40    /// Maximum sequence length to process
41    pub max_seq_len: usize,
42    /// Embedding dimension for tokens
43    pub embed_dim: usize,
44    /// Hidden layer sizes
45    pub hidden_dims: Vec<usize>,
46    /// Dropout rate during training
47    pub dropout: f32,
48    /// Use spectral normalization
49    pub spectral_norm: bool,
50}
51
52impl Default for DiscriminatorConfig {
53    fn default() -> Self {
54        Self {
55            vocab_size: 1000,
56            max_seq_len: 256,
57            embed_dim: 64,
58            hidden_dims: vec![256, 128, 64],
59            dropout: 0.2,
60            spectral_norm: true,
61        }
62    }
63}
64
65/// Configuration for the complete Code GAN
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct CodeGanConfig {
68    /// Generator configuration
69    pub generator: GeneratorConfig,
70    /// Discriminator configuration
71    pub discriminator: DiscriminatorConfig,
72    /// Learning rate for generator
73    pub gen_lr: f32,
74    /// Learning rate for discriminator
75    pub disc_lr: f32,
76    /// Number of discriminator updates per generator update
77    pub n_critic: usize,
78    /// Gradient penalty coefficient (for WGAN-GP)
79    pub gradient_penalty: f32,
80    /// Label smoothing for real samples
81    pub label_smoothing: f32,
82    /// Batch size for training
83    pub batch_size: usize,
84}
85
86impl Default for CodeGanConfig {
87    fn default() -> Self {
88        Self {
89            generator: GeneratorConfig::default(),
90            discriminator: DiscriminatorConfig::default(),
91            gen_lr: 0.0002,
92            disc_lr: 0.0002,
93            n_critic: 5,
94            gradient_penalty: 10.0,
95            label_smoothing: 0.1,
96            batch_size: 32,
97        }
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generator_config_default() {
        let config = GeneratorConfig::default();
        assert_eq!(config.latent_dim, 128);
        assert_eq!(config.hidden_dims, vec![256, 512, 256]);
        assert_eq!(config.vocab_size, 1000);
        assert_eq!(config.max_seq_len, 256);
        assert!((0.0..1.0).contains(&config.dropout));
        assert!(config.batch_norm);
    }

    #[test]
    fn test_discriminator_config_default() {
        let config = DiscriminatorConfig::default();
        assert_eq!(config.vocab_size, 1000);
        assert_eq!(config.max_seq_len, 256);
        assert_eq!(config.embed_dim, 64);
        assert_eq!(config.hidden_dims, vec![256, 128, 64]);
        assert!((0.0..1.0).contains(&config.dropout));
        assert!(config.spectral_norm);
    }

    #[test]
    fn test_code_gan_config_default() {
        let config = CodeGanConfig::default();
        assert_eq!(config.n_critic, 5);
        assert!(config.gen_lr > 0.0);
        assert!(config.disc_lr > 0.0);
        assert!(config.gradient_penalty >= 0.0);
        assert!((0.0..1.0).contains(&config.label_smoothing));
        assert!(config.batch_size > 0);
    }

    /// The default generator must produce sequences the default discriminator can consume:
    /// same token vocabulary, and no longer than the discriminator's maximum length.
    #[test]
    fn test_default_components_are_compatible() {
        let config = CodeGanConfig::default();
        assert_eq!(config.generator.vocab_size, config.discriminator.vocab_size);
        assert!(config.generator.max_seq_len <= config.discriminator.max_seq_len);
    }
}