Skip to main content

entrenar/generative/code_gan/gan/
code_gan.rs

1//! Code GAN main struct and training logic.
2
3use crate::generative::code_gan::config::CodeGanConfig;
4use crate::generative::code_gan::discriminator::Discriminator;
5use crate::generative::code_gan::generator::Generator;
6use crate::generative::code_gan::latent::LatentCode;
7
8use super::stats::CodeGanStats;
9use super::training_result::TrainingResult;
10
11/// Complete Code GAN for generating Rust AST
12pub struct CodeGan {
13    /// Configuration
14    pub config: CodeGanConfig,
15    /// Generator network
16    pub generator: Generator,
17    /// Discriminator network
18    pub discriminator: Discriminator,
19    /// Training statistics
20    pub stats: CodeGanStats,
21    /// Random number generator
22    rng: rand::rngs::StdRng,
23}
24
25impl CodeGan {
26    /// Create a new Code GAN
27    pub fn new(config: CodeGanConfig) -> Self {
28        use rand::SeedableRng;
29        let generator = Generator::new(config.generator.clone());
30        let discriminator = Discriminator::new(config.discriminator.clone());
31        Self {
32            config,
33            generator,
34            discriminator,
35            stats: CodeGanStats::default(),
36            rng: rand::rngs::StdRng::from_os_rng(),
37        }
38    }
39
40    /// Create a new Code GAN with a seed for reproducibility
41    pub fn with_seed(config: CodeGanConfig, seed: u64) -> Self {
42        use rand::SeedableRng;
43        let generator = Generator::with_seed(config.generator.clone(), seed);
44        let discriminator = Discriminator::with_seed(config.discriminator.clone(), seed + 1);
45        Self {
46            config,
47            generator,
48            discriminator,
49            stats: CodeGanStats::default(),
50            rng: rand::rngs::StdRng::seed_from_u64(seed),
51        }
52    }
53
54    /// Sample latent codes for generation
55    pub fn sample_latent(&mut self, batch_size: usize) -> Vec<LatentCode> {
56        (0..batch_size)
57            .map(|_| LatentCode::sample(&mut self.rng, self.config.generator.latent_dim))
58            .collect()
59    }
60
61    /// Generate code from latent codes
62    pub fn generate(&self, latent_codes: &[LatentCode]) -> Vec<Vec<u32>> {
63        latent_codes.iter().map(|z| self.generator.generate(z)).collect()
64    }
65
66    /// Generate a single code sample
67    pub fn generate_one(&mut self) -> Vec<u32> {
68        let z = LatentCode::sample(&mut self.rng, self.config.generator.latent_dim);
69        self.generator.generate(&z)
70    }
71
72    /// Discriminate a batch of code samples
73    pub fn discriminate(&self, samples: &[Vec<u32>]) -> Vec<f32> {
74        samples.iter().map(|tokens| self.discriminator.discriminate(tokens)).collect()
75    }
76
77    /// Compute discriminator loss (binary cross-entropy)
78    pub fn discriminator_loss(&self, real_samples: &[Vec<u32>], fake_samples: &[Vec<u32>]) -> f32 {
79        let real_probs = self.discriminate(real_samples);
80        let fake_probs = self.discriminate(fake_samples);
81
82        // BCE loss: -[y*log(p) + (1-y)*log(1-p)]
83        // For real: y=1, for fake: y=0
84        let smoothed_real = 1.0 - self.config.label_smoothing;
85
86        let real_loss: f32 =
87            real_probs.iter().map(|&p| -smoothed_real * p.max(1e-7).ln()).sum::<f32>()
88                / real_probs.len().max(1) as f32;
89
90        let fake_loss: f32 = fake_probs.iter().map(|&p| -(1.0 - p).max(1e-7).ln()).sum::<f32>()
91            / fake_probs.len().max(1) as f32;
92
93        real_loss + fake_loss
94    }
95
96    /// Compute generator loss (try to fool discriminator)
97    pub fn generator_loss(&self, fake_samples: &[Vec<u32>]) -> f32 {
98        let fake_probs = self.discriminate(fake_samples);
99
100        // Generator wants discriminator to output 1 (real) for fakes
101        let loss: f32 = fake_probs.iter().map(|&p| -p.max(1e-7).ln()).sum::<f32>()
102            / fake_probs.len().max(1) as f32;
103
104        loss
105    }
106
107    /// Detect mode collapse by measuring diversity of generated samples
108    pub fn detect_mode_collapse(&mut self, num_samples: usize) -> f32 {
109        use std::collections::HashSet;
110
111        let latent_codes = self.sample_latent(num_samples);
112        let samples = self.generate(&latent_codes);
113
114        // Count unique token sequences
115        let unique_seqs: HashSet<Vec<u32>> = samples.into_iter().collect();
116        let diversity = unique_seqs.len() as f32 / num_samples as f32;
117
118        // Also check token diversity
119        let all_tokens: HashSet<u32> =
120            unique_seqs.iter().flat_map(|seq| seq.iter().copied()).collect();
121
122        self.stats.unique_tokens = all_tokens.len();
123
124        // Mode collapse score: 1 - diversity
125        let mode_collapse_score = 1.0 - diversity;
126        self.stats.mode_collapse_score = mode_collapse_score;
127
128        mode_collapse_score
129    }
130
131    /// Interpolate between two latent codes and generate intermediate samples
132    pub fn interpolate(&self, z1: &LatentCode, z2: &LatentCode, steps: usize) -> Vec<Vec<u32>> {
133        (0..=steps)
134            .map(|i| {
135                let t = i as f32 / steps as f32;
136                let z = z1.slerp(z2, t);
137                self.generator.generate(&z)
138            })
139            .collect()
140    }
141
142    /// Get total number of parameters
143    #[must_use]
144    pub fn num_parameters(&self) -> usize {
145        self.generator.num_parameters() + self.discriminator.num_parameters()
146    }
147
148    /// Record training step
149    pub fn record_step(&mut self, result: &TrainingResult) {
150        self.stats.steps += 1;
151
152        if self.stats.gen_losses.len() >= 100 {
153            self.stats.gen_losses.pop_front();
154        }
155        self.stats.gen_losses.push_back(result.gen_loss);
156
157        if self.stats.disc_losses.len() >= 100 {
158            self.stats.disc_losses.pop_front();
159        }
160        self.stats.disc_losses.push_back(result.disc_loss);
161    }
162
163    /// Get average generator loss over recent history
164    #[must_use]
165    pub fn avg_gen_loss(&self) -> f32 {
166        if self.stats.gen_losses.is_empty() {
167            return 0.0;
168        }
169        self.stats.gen_losses.iter().sum::<f32>() / self.stats.gen_losses.len() as f32
170    }
171
172    /// Get average discriminator loss over recent history
173    #[must_use]
174    pub fn avg_disc_loss(&self) -> f32 {
175        if self.stats.disc_losses.is_empty() {
176            return 0.0;
177        }
178        self.stats.disc_losses.iter().sum::<f32>() / self.stats.disc_losses.len() as f32
179    }
180}