entrenar/generative/code_gan/gan/
code_gan.rs1use crate::generative::code_gan::config::CodeGanConfig;
4use crate::generative::code_gan::discriminator::Discriminator;
5use crate::generative::code_gan::generator::Generator;
6use crate::generative::code_gan::latent::LatentCode;
7
8use super::stats::CodeGanStats;
9use super::training_result::TrainingResult;
10
11pub struct CodeGan {
13 pub config: CodeGanConfig,
15 pub generator: Generator,
17 pub discriminator: Discriminator,
19 pub stats: CodeGanStats,
21 rng: rand::rngs::StdRng,
23}
24
25impl CodeGan {
26 pub fn new(config: CodeGanConfig) -> Self {
28 use rand::SeedableRng;
29 let generator = Generator::new(config.generator.clone());
30 let discriminator = Discriminator::new(config.discriminator.clone());
31 Self {
32 config,
33 generator,
34 discriminator,
35 stats: CodeGanStats::default(),
36 rng: rand::rngs::StdRng::from_os_rng(),
37 }
38 }
39
40 pub fn with_seed(config: CodeGanConfig, seed: u64) -> Self {
42 use rand::SeedableRng;
43 let generator = Generator::with_seed(config.generator.clone(), seed);
44 let discriminator = Discriminator::with_seed(config.discriminator.clone(), seed + 1);
45 Self {
46 config,
47 generator,
48 discriminator,
49 stats: CodeGanStats::default(),
50 rng: rand::rngs::StdRng::seed_from_u64(seed),
51 }
52 }
53
54 pub fn sample_latent(&mut self, batch_size: usize) -> Vec<LatentCode> {
56 (0..batch_size)
57 .map(|_| LatentCode::sample(&mut self.rng, self.config.generator.latent_dim))
58 .collect()
59 }
60
61 pub fn generate(&self, latent_codes: &[LatentCode]) -> Vec<Vec<u32>> {
63 latent_codes.iter().map(|z| self.generator.generate(z)).collect()
64 }
65
66 pub fn generate_one(&mut self) -> Vec<u32> {
68 let z = LatentCode::sample(&mut self.rng, self.config.generator.latent_dim);
69 self.generator.generate(&z)
70 }
71
72 pub fn discriminate(&self, samples: &[Vec<u32>]) -> Vec<f32> {
74 samples.iter().map(|tokens| self.discriminator.discriminate(tokens)).collect()
75 }
76
77 pub fn discriminator_loss(&self, real_samples: &[Vec<u32>], fake_samples: &[Vec<u32>]) -> f32 {
79 let real_probs = self.discriminate(real_samples);
80 let fake_probs = self.discriminate(fake_samples);
81
82 let smoothed_real = 1.0 - self.config.label_smoothing;
85
86 let real_loss: f32 =
87 real_probs.iter().map(|&p| -smoothed_real * p.max(1e-7).ln()).sum::<f32>()
88 / real_probs.len().max(1) as f32;
89
90 let fake_loss: f32 = fake_probs.iter().map(|&p| -(1.0 - p).max(1e-7).ln()).sum::<f32>()
91 / fake_probs.len().max(1) as f32;
92
93 real_loss + fake_loss
94 }
95
96 pub fn generator_loss(&self, fake_samples: &[Vec<u32>]) -> f32 {
98 let fake_probs = self.discriminate(fake_samples);
99
100 let loss: f32 = fake_probs.iter().map(|&p| -p.max(1e-7).ln()).sum::<f32>()
102 / fake_probs.len().max(1) as f32;
103
104 loss
105 }
106
107 pub fn detect_mode_collapse(&mut self, num_samples: usize) -> f32 {
109 use std::collections::HashSet;
110
111 let latent_codes = self.sample_latent(num_samples);
112 let samples = self.generate(&latent_codes);
113
114 let unique_seqs: HashSet<Vec<u32>> = samples.into_iter().collect();
116 let diversity = unique_seqs.len() as f32 / num_samples as f32;
117
118 let all_tokens: HashSet<u32> =
120 unique_seqs.iter().flat_map(|seq| seq.iter().copied()).collect();
121
122 self.stats.unique_tokens = all_tokens.len();
123
124 let mode_collapse_score = 1.0 - diversity;
126 self.stats.mode_collapse_score = mode_collapse_score;
127
128 mode_collapse_score
129 }
130
131 pub fn interpolate(&self, z1: &LatentCode, z2: &LatentCode, steps: usize) -> Vec<Vec<u32>> {
133 (0..=steps)
134 .map(|i| {
135 let t = i as f32 / steps as f32;
136 let z = z1.slerp(z2, t);
137 self.generator.generate(&z)
138 })
139 .collect()
140 }
141
142 #[must_use]
144 pub fn num_parameters(&self) -> usize {
145 self.generator.num_parameters() + self.discriminator.num_parameters()
146 }
147
148 pub fn record_step(&mut self, result: &TrainingResult) {
150 self.stats.steps += 1;
151
152 if self.stats.gen_losses.len() >= 100 {
153 self.stats.gen_losses.pop_front();
154 }
155 self.stats.gen_losses.push_back(result.gen_loss);
156
157 if self.stats.disc_losses.len() >= 100 {
158 self.stats.disc_losses.pop_front();
159 }
160 self.stats.disc_losses.push_back(result.disc_loss);
161 }
162
163 #[must_use]
165 pub fn avg_gen_loss(&self) -> f32 {
166 if self.stats.gen_losses.is_empty() {
167 return 0.0;
168 }
169 self.stats.gen_losses.iter().sum::<f32>() / self.stats.gen_losses.len() as f32
170 }
171
172 #[must_use]
174 pub fn avg_disc_loss(&self) -> f32 {
175 if self.stats.disc_losses.is_empty() {
176 return 0.0;
177 }
178 self.stats.disc_losses.iter().sum::<f32>() / self.stats.disc_losses.len() as f32
179 }
180}