entrenar/generative/code_gan/
generator.rs1use rand::Rng;
4
5use super::config::GeneratorConfig;
6use super::latent::LatentCode;
7
8#[derive(Debug)]
10pub struct Generator {
11 pub config: GeneratorConfig,
13 weights: Vec<Vec<Vec<f32>>>,
15 biases: Vec<Vec<f32>>,
17}
18
19impl Generator {
20 pub fn new(config: GeneratorConfig) -> Self {
22 use rand::SeedableRng;
23 let mut rng = rand::rngs::StdRng::from_os_rng();
24 let (weights, biases) = Self::init_weights(&config, &mut rng);
25 Self { config, weights, biases }
26 }
27
28 pub fn with_seed(config: GeneratorConfig, seed: u64) -> Self {
30 use rand::SeedableRng;
31 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
32 let (weights, biases) = Self::init_weights(&config, &mut rng);
33 Self { config, weights, biases }
34 }
35
36 fn init_weights<R: Rng>(
37 config: &GeneratorConfig,
38 rng: &mut R,
39 ) -> (Vec<Vec<Vec<f32>>>, Vec<Vec<f32>>) {
40 let mut dims = vec![config.latent_dim];
41 dims.extend(&config.hidden_dims);
42 dims.push(config.vocab_size * config.max_seq_len);
43
44 let mut weights = Vec::new();
45 let mut biases = Vec::new();
46
47 for i in 0..dims.len() - 1 {
48 let input_dim = dims[i];
49 let output_dim = dims[i + 1];
50
51 let std = (2.0 / (input_dim + output_dim) as f64).sqrt();
53
54 let w: Vec<Vec<f32>> = (0..output_dim)
55 .map(|_| {
56 (0..input_dim)
57 .map(|_| {
58 let u1: f64 = rng.random::<f64>().max(1e-10);
59 let u2: f64 = rng.random::<f64>();
60 let z =
61 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
62 (z * std) as f32
63 })
64 .collect()
65 })
66 .collect();
67 let b: Vec<f32> = vec![0.0; output_dim];
68
69 weights.push(w);
70 biases.push(b);
71 }
72
73 (weights, biases)
74 }
75
76 pub fn generate(&self, latent: &LatentCode) -> Vec<u32> {
78 assert_eq!(latent.dim(), self.config.latent_dim);
79
80 let mut x = latent.vector.clone();
82
83 for (w, b) in self.weights.iter().zip(&self.biases) {
84 x = Self::linear_forward(&x, w, b);
85 if w != self.weights.last().expect("non-empty weights") {
87 x = x.iter().map(|&v| v.max(0.0)).collect();
88 }
89 }
90
91 let vocab_size = self.config.vocab_size;
93 let max_seq_len = self.config.max_seq_len;
94
95 let mut tokens = Vec::with_capacity(max_seq_len);
96 for pos in 0..max_seq_len {
97 let start = pos * vocab_size;
98 let end = start + vocab_size;
99 if end <= x.len() {
100 let logits = &x[start..end];
101 let max_idx = logits
102 .iter()
103 .enumerate()
104 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
105 .map_or(0, |(i, _)| i as u32);
106 tokens.push(max_idx);
107 }
108 }
109
110 tokens
111 }
112
113 fn linear_forward(input: &[f32], weights: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
114 let output_dim = weights.len();
115 let mut output = Vec::with_capacity(output_dim);
116
117 for (i, w_row) in weights.iter().enumerate() {
118 let dot: f32 = w_row.iter().zip(input).map(|(a, b)| a * b).sum();
119 output.push(dot + bias[i]);
120 }
121
122 output
123 }
124
125 #[must_use]
127 pub fn num_parameters(&self) -> usize {
128 let weight_params: usize = self.weights.iter().map(|w| w.len() * w[0].len()).sum();
129 let bias_params: usize = self.biases.iter().map(Vec::len).sum();
130 weight_params + bias_params
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use rand::SeedableRng;

    /// Small single-hidden-layer configuration shared by the generation tests.
    fn small_config() -> GeneratorConfig {
        GeneratorConfig {
            latent_dim: 16,
            hidden_dims: vec![32],
            vocab_size: 50,
            max_seq_len: 8,
            dropout: 0.0,
            batch_norm: false,
        }
    }

    #[test]
    fn test_generator_creation() {
        let config = GeneratorConfig {
            latent_dim: 32,
            hidden_dims: vec![64, 64],
            vocab_size: 100,
            max_seq_len: 10,
            dropout: 0.1,
            batch_norm: true,
        };
        let generator = Generator::with_seed(config, 42);
        assert!(generator.num_parameters() > 0);
        // Layers 32->64, 64->64, 64->1000:
        // (2048 + 64) + (4096 + 64) + (64000 + 1000) parameters.
        assert_eq!(generator.num_parameters(), 71_272);
    }

    #[test]
    fn test_generator_generate() {
        let config = small_config();
        let generator = Generator::with_seed(config.clone(), 42);

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let z = LatentCode::sample(&mut rng, config.latent_dim);

        let tokens = generator.generate(&z);
        assert_eq!(tokens.len(), config.max_seq_len);
        assert!(tokens.iter().all(|&t| t < config.vocab_size as u32));
    }

    #[test]
    fn test_generator_deterministic() {
        let config = small_config();
        let z = LatentCode::new(vec![0.5; config.latent_dim]);

        // Repeated calls on one instance must agree.
        let generator = Generator::with_seed(config.clone(), 42);
        let tokens1 = generator.generate(&z);
        let tokens2 = generator.generate(&z);
        assert_eq!(tokens1, tokens2);

        // A second generator built from the same seed must have the same
        // weights and therefore produce the same tokens.
        let twin = Generator::with_seed(config, 42);
        assert_eq!(tokens1, twin.generate(&z));
    }

    proptest! {
        #[test]
        fn test_generator_output_valid_tokens(seed in 0u64..10000) {
            let config = small_config();
            let generator = Generator::with_seed(config.clone(), seed);

            let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
            let z = LatentCode::sample(&mut rng, config.latent_dim);

            let tokens = generator.generate(&z);
            prop_assert_eq!(tokens.len(), config.max_seq_len);
            prop_assert!(tokens.iter().all(|&t| t < config.vocab_size as u32));
        }
    }
}