Skip to main content

entrenar/generative/code_gan/
generator.rs

//! Generator network for Code GAN.
2
3use rand::Rng;
4
5use super::config::GeneratorConfig;
6use super::latent::LatentCode;
7
8/// Generator network: maps latent vectors to AST token sequences
9#[derive(Debug)]
10pub struct Generator {
11    /// Configuration
12    pub config: GeneratorConfig,
13    /// Weights for each layer (simplified representation)
14    weights: Vec<Vec<Vec<f32>>>,
15    /// Biases for each layer
16    biases: Vec<Vec<f32>>,
17}
18
19impl Generator {
20    /// Create a new generator with random initialization
21    pub fn new(config: GeneratorConfig) -> Self {
22        use rand::SeedableRng;
23        let mut rng = rand::rngs::StdRng::from_os_rng();
24        let (weights, biases) = Self::init_weights(&config, &mut rng);
25        Self { config, weights, biases }
26    }
27
28    /// Create a new generator with a seed for reproducibility
29    pub fn with_seed(config: GeneratorConfig, seed: u64) -> Self {
30        use rand::SeedableRng;
31        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
32        let (weights, biases) = Self::init_weights(&config, &mut rng);
33        Self { config, weights, biases }
34    }
35
36    fn init_weights<R: Rng>(
37        config: &GeneratorConfig,
38        rng: &mut R,
39    ) -> (Vec<Vec<Vec<f32>>>, Vec<Vec<f32>>) {
40        let mut dims = vec![config.latent_dim];
41        dims.extend(&config.hidden_dims);
42        dims.push(config.vocab_size * config.max_seq_len);
43
44        let mut weights = Vec::new();
45        let mut biases = Vec::new();
46
47        for i in 0..dims.len() - 1 {
48            let input_dim = dims[i];
49            let output_dim = dims[i + 1];
50
51            // Xavier initialization using Box-Muller transform
52            let std = (2.0 / (input_dim + output_dim) as f64).sqrt();
53
54            let w: Vec<Vec<f32>> = (0..output_dim)
55                .map(|_| {
56                    (0..input_dim)
57                        .map(|_| {
58                            let u1: f64 = rng.random::<f64>().max(1e-10);
59                            let u2: f64 = rng.random::<f64>();
60                            let z =
61                                (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
62                            (z * std) as f32
63                        })
64                        .collect()
65                })
66                .collect();
67            let b: Vec<f32> = vec![0.0; output_dim];
68
69            weights.push(w);
70            biases.push(b);
71        }
72
73        (weights, biases)
74    }
75
76    /// Generate AST tokens from a latent code
77    pub fn generate(&self, latent: &LatentCode) -> Vec<u32> {
78        assert_eq!(latent.dim(), self.config.latent_dim);
79
80        // Forward pass through network
81        let mut x = latent.vector.clone();
82
83        for (w, b) in self.weights.iter().zip(&self.biases) {
84            x = Self::linear_forward(&x, w, b);
85            // ReLU activation (except last layer)
86            if w != self.weights.last().expect("non-empty weights") {
87                x = x.iter().map(|&v| v.max(0.0)).collect();
88            }
89        }
90
91        // Reshape to (max_seq_len, vocab_size) and take argmax for each position
92        let vocab_size = self.config.vocab_size;
93        let max_seq_len = self.config.max_seq_len;
94
95        let mut tokens = Vec::with_capacity(max_seq_len);
96        for pos in 0..max_seq_len {
97            let start = pos * vocab_size;
98            let end = start + vocab_size;
99            if end <= x.len() {
100                let logits = &x[start..end];
101                let max_idx = logits
102                    .iter()
103                    .enumerate()
104                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
105                    .map_or(0, |(i, _)| i as u32);
106                tokens.push(max_idx);
107            }
108        }
109
110        tokens
111    }
112
113    fn linear_forward(input: &[f32], weights: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
114        let output_dim = weights.len();
115        let mut output = Vec::with_capacity(output_dim);
116
117        for (i, w_row) in weights.iter().enumerate() {
118            let dot: f32 = w_row.iter().zip(input).map(|(a, b)| a * b).sum();
119            output.push(dot + bias[i]);
120        }
121
122        output
123    }
124
125    /// Get number of parameters
126    #[must_use]
127    pub fn num_parameters(&self) -> usize {
128        let weight_params: usize = self.weights.iter().map(|w| w.len() * w[0].len()).sum();
129        let bias_params: usize = self.biases.iter().map(Vec::len).sum();
130        weight_params + bias_params
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use rand::SeedableRng;

    /// Small single-hidden-layer configuration shared by the generation tests.
    fn small_config() -> GeneratorConfig {
        GeneratorConfig {
            latent_dim: 16,
            hidden_dims: vec![32],
            vocab_size: 50,
            max_seq_len: 8,
            dropout: 0.0,
            batch_norm: false,
        }
    }

    /// A freshly built generator reports a non-zero parameter count.
    #[test]
    fn test_generator_creation() {
        let generator = Generator::with_seed(
            GeneratorConfig {
                latent_dim: 32,
                hidden_dims: vec![64, 64],
                vocab_size: 100,
                max_seq_len: 10,
                dropout: 0.1,
                batch_norm: true,
            },
            42,
        );
        assert!(generator.num_parameters() > 0);
    }

    /// Output has one token per position and every token is inside the vocabulary.
    #[test]
    fn test_generator_generate() {
        let config = small_config();
        let generator = Generator::with_seed(config.clone(), 42);

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let latent = LatentCode::sample(&mut rng, config.latent_dim);

        let tokens = generator.generate(&latent);
        assert_eq!(tokens.len(), config.max_seq_len);
        assert!(tokens.iter().all(|&t| t < config.vocab_size as u32));
    }

    /// Generating twice from the same latent code gives identical tokens.
    #[test]
    fn test_generator_deterministic() {
        let config = small_config();
        let generator = Generator::with_seed(config.clone(), 42);
        let latent = LatentCode::new(vec![0.5; config.latent_dim]);

        let first = generator.generate(&latent);
        let second = generator.generate(&latent);

        assert_eq!(first, second);
    }

    proptest! {
        /// For any seed, every generated token stays below the vocabulary size.
        #[test]
        fn test_generator_output_valid_tokens(seed in 0u64..10000) {
            let config = small_config();
            let generator = Generator::with_seed(config.clone(), seed);

            let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
            let latent = LatentCode::sample(&mut rng, config.latent_dim);

            let tokens = generator.generate(&latent);
            prop_assert!(tokens.iter().all(|&t| t < 50));
        }
    }
}