// Source file: entrenar/generative/code_gan/config.rs
use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct GeneratorConfig {
8 pub latent_dim: usize,
10 pub hidden_dims: Vec<usize>,
12 pub vocab_size: usize,
14 pub max_seq_len: usize,
16 pub dropout: f32,
18 pub batch_norm: bool,
20}
21
22impl Default for GeneratorConfig {
23 fn default() -> Self {
24 Self {
25 latent_dim: 128,
26 hidden_dims: vec![256, 512, 256],
27 vocab_size: 1000,
28 max_seq_len: 256,
29 dropout: 0.1,
30 batch_norm: true,
31 }
32 }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct DiscriminatorConfig {
38 pub vocab_size: usize,
40 pub max_seq_len: usize,
42 pub embed_dim: usize,
44 pub hidden_dims: Vec<usize>,
46 pub dropout: f32,
48 pub spectral_norm: bool,
50}
51
52impl Default for DiscriminatorConfig {
53 fn default() -> Self {
54 Self {
55 vocab_size: 1000,
56 max_seq_len: 256,
57 embed_dim: 64,
58 hidden_dims: vec![256, 128, 64],
59 dropout: 0.2,
60 spectral_norm: true,
61 }
62 }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct CodeGanConfig {
68 pub generator: GeneratorConfig,
70 pub discriminator: DiscriminatorConfig,
72 pub gen_lr: f32,
74 pub disc_lr: f32,
76 pub n_critic: usize,
78 pub gradient_penalty: f32,
80 pub label_smoothing: f32,
82 pub batch_size: usize,
84}
85
86impl Default for CodeGanConfig {
87 fn default() -> Self {
88 Self {
89 generator: GeneratorConfig::default(),
90 discriminator: DiscriminatorConfig::default(),
91 gen_lr: 0.0002,
92 disc_lr: 0.0002,
93 n_critic: 5,
94 gradient_penalty: 10.0,
95 label_smoothing: 0.1,
96 batch_size: 32,
97 }
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the integer, vector and flag defaults of the generator config.
    #[test]
    fn test_generator_config_default() {
        let config = GeneratorConfig::default();
        assert_eq!(config.latent_dim, 128);
        assert_eq!(config.hidden_dims, vec![256, 512, 256]);
        assert_eq!(config.vocab_size, 1000);
        assert_eq!(config.max_seq_len, 256);
        assert!(config.batch_norm);
    }

    /// Pins the integer, vector and flag defaults of the discriminator config.
    #[test]
    fn test_discriminator_config_default() {
        let config = DiscriminatorConfig::default();
        assert_eq!(config.vocab_size, 1000);
        assert_eq!(config.max_seq_len, 256);
        assert_eq!(config.embed_dim, 64);
        assert_eq!(config.hidden_dims, vec![256, 128, 64]);
        assert!(config.spectral_norm);
    }

    /// Checks the training defaults; learning rates are only checked for sign
    /// to avoid exact float comparison.
    #[test]
    fn test_code_gan_config_default() {
        let config = CodeGanConfig::default();
        assert_eq!(config.n_critic, 5);
        assert_eq!(config.batch_size, 32);
        assert!(config.gen_lr > 0.0);
        assert!(config.disc_lr > 0.0);
    }
}