entrenar/generative/code_gan/
discriminator.rs1use rand::Rng;
4
5use super::config::DiscriminatorConfig;
6
/// Dense row-major matrix stored as a vector of rows.
type Matrix = Vec<Vec<f32>>;

/// Freshly initialized parameters in the order `(embeddings, weights, biases)`:
/// the embedding table, one weight matrix per layer, and one bias vector per layer.
type DiscriminatorWeights = (Matrix, Vec<Matrix>, Vec<Vec<f32>>);
9
10#[derive(Debug)]
12pub struct Discriminator {
13 pub config: DiscriminatorConfig,
15 embeddings: Vec<Vec<f32>>,
17 weights: Vec<Vec<Vec<f32>>>,
19 biases: Vec<Vec<f32>>,
21}
22
23impl Discriminator {
24 pub fn new(config: DiscriminatorConfig) -> Self {
26 use rand::SeedableRng;
27 let mut rng = rand::rngs::StdRng::from_os_rng();
28 let (embeddings, weights, biases) = Self::init_weights(&config, &mut rng);
29 Self { config, embeddings, weights, biases }
30 }
31
32 pub fn with_seed(config: DiscriminatorConfig, seed: u64) -> Self {
34 use rand::SeedableRng;
35 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
36 let (embeddings, weights, biases) = Self::init_weights(&config, &mut rng);
37 Self { config, embeddings, weights, biases }
38 }
39
40 fn init_weights<R: Rng>(config: &DiscriminatorConfig, rng: &mut R) -> DiscriminatorWeights {
41 let sample_normal = |rng: &mut R, std: f64| -> f32 {
43 let u1: f64 = rng.random::<f64>().max(1e-10);
44 let u2: f64 = rng.random::<f64>();
45 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
46 (z * std) as f32
47 };
48
49 let embed_std = (1.0 / config.embed_dim as f64).sqrt();
51 let embeddings: Vec<Vec<f32>> = (0..config.vocab_size)
52 .map(|_| (0..config.embed_dim).map(|_| sample_normal(rng, embed_std)).collect())
53 .collect();
54
55 let input_dim = config.embed_dim * config.max_seq_len;
57 let mut dims = vec![input_dim];
58 dims.extend(&config.hidden_dims);
59 dims.push(1); let mut weights = Vec::new();
62 let mut biases = Vec::new();
63
64 for i in 0..dims.len() - 1 {
65 let in_dim = dims[i];
66 let out_dim = dims[i + 1];
67
68 let std = (2.0 / (in_dim + out_dim) as f64).sqrt();
69
70 let w: Vec<Vec<f32>> = (0..out_dim)
71 .map(|_| (0..in_dim).map(|_| sample_normal(rng, std)).collect())
72 .collect();
73 let b: Vec<f32> = vec![0.0; out_dim];
74
75 weights.push(w);
76 biases.push(b);
77 }
78
79 (embeddings, weights, biases)
80 }
81
82 pub fn discriminate(&self, tokens: &[u32]) -> f32 {
84 let mut padded = tokens.to_vec();
86 padded.resize(self.config.max_seq_len, 0);
87
88 let mut x = Vec::with_capacity(self.config.max_seq_len * self.config.embed_dim);
90 for &token in &padded {
91 let token_idx = (token as usize).min(self.config.vocab_size - 1);
92 x.extend(&self.embeddings[token_idx]);
93 }
94
95 for (i, (w, b)) in self.weights.iter().zip(&self.biases).enumerate() {
97 x = Self::linear_forward(&x, w, b);
98 if i < self.weights.len() - 1 {
100 x = x.iter().map(|&v| if v > 0.0 { v } else { 0.01 * v }).collect();
101 }
102 }
103
104 sigmoid(x[0])
106 }
107
108 fn linear_forward(input: &[f32], weights: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
109 let output_dim = weights.len();
110 let mut output = Vec::with_capacity(output_dim);
111
112 for (i, w_row) in weights.iter().enumerate() {
113 let dot: f32 = w_row.iter().zip(input).map(|(a, b)| a * b).sum();
114 output.push(dot + bias[i]);
115 }
116
117 output
118 }
119
120 #[must_use]
122 pub fn num_parameters(&self) -> usize {
123 let embed_params = self.embeddings.len() * self.config.embed_dim;
124 let weight_params: usize = self.weights.iter().map(|w| w.len() * w[0].len()).sum();
125 let bias_params: usize = self.biases.iter().map(Vec::len).sum();
126 embed_params + weight_params + bias_params
127 }
128}
129
130pub fn sigmoid(x: f32) -> f32 {
132 contract_pre_sigmoid!();
133 let result = 1.0 / (1.0 + (-x).exp());
134 contract_post_silu!(&[result]);
135 result
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Small single-hidden-layer configuration shared by the behavioral tests.
    fn small_config(max_seq_len: usize) -> DiscriminatorConfig {
        DiscriminatorConfig {
            vocab_size: 50,
            max_seq_len,
            embed_dim: 8,
            hidden_dims: vec![16],
            dropout: 0.0,
            spectral_norm: false,
        }
    }

    #[test]
    fn test_discriminator_creation() {
        let config = DiscriminatorConfig {
            vocab_size: 100,
            max_seq_len: 10,
            embed_dim: 16,
            hidden_dims: vec![32, 16],
            dropout: 0.1,
            spectral_norm: true,
        };
        let disc = Discriminator::with_seed(config, 42);

        // embeddings: 100 * 16
        // layers (weights + biases): 160 -> 32, 32 -> 16, 16 -> 1
        let expected = 100 * 16 + (160 * 32 + 32) + (32 * 16 + 16) + (16 + 1);
        assert_eq!(expected, 7297);
        assert_eq!(disc.num_parameters(), expected);
    }

    #[test]
    fn test_discriminator_output_range() {
        let disc = Discriminator::with_seed(small_config(8), 42);

        let tokens = vec![1, 2, 3, 4, 5];
        let prob = disc.discriminate(&tokens);

        assert!((0.0..=1.0).contains(&prob));
    }

    #[test]
    fn test_discriminator_deterministic() {
        let disc = Discriminator::with_seed(small_config(8), 42);
        let tokens = vec![1, 2, 3, 4, 5];

        // Repeated evaluation of the same instance is stable.
        let prob1 = disc.discriminate(&tokens);
        let prob2 = disc.discriminate(&tokens);
        assert!((prob1 - prob2).abs() < 1e-6);

        // A second instance built from the same seed behaves identically.
        let twin = Discriminator::with_seed(small_config(8), 42);
        assert_eq!(twin.num_parameters(), disc.num_parameters());
        assert!((twin.discriminate(&tokens) - prob1).abs() < 1e-6);
    }

    #[test]
    fn test_discriminator_pads_and_truncates() {
        let disc = Discriminator::with_seed(small_config(8), 42);

        // Short input is right-padded with token 0.
        let short = [1u32, 2, 3];
        let padded = [1u32, 2, 3, 0, 0, 0, 0, 0];
        assert!((disc.discriminate(&short) - disc.discriminate(&padded)).abs() < 1e-6);

        // Empty input is equivalent to a sequence of padding tokens.
        let all_padding = [0u32; 8];
        assert!((disc.discriminate(&[]) - disc.discriminate(&all_padding)).abs() < 1e-6);

        // Input longer than `max_seq_len` is truncated.
        let long = [1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        assert!((disc.discriminate(&long) - disc.discriminate(&long[..8])).abs() < 1e-6);
    }

    #[test]
    fn test_discriminator_clamps_out_of_vocab_tokens() {
        let disc = Discriminator::with_seed(small_config(8), 42);

        let last_in_vocab = disc.discriminate(&[49u32]);
        assert!((disc.discriminate(&[50u32]) - last_in_vocab).abs() < 1e-6);
        assert!((disc.discriminate(&[u32::MAX]) - last_in_vocab).abs() < 1e-6);
    }

    #[test]
    fn test_sigmoid_function() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    proptest! {
        #[test]
        fn test_discriminator_output_bounds(tokens in prop::collection::vec(0u32..50, 1..10)) {
            let disc = Discriminator::with_seed(small_config(10), 42);

            let prob = disc.discriminate(&tokens);
            prop_assert!((0.0..=1.0).contains(&prob));
        }

        // Covers empty input, input longer than `max_seq_len`, and token ids
        // outside the vocabulary.
        #[test]
        fn test_discriminator_output_bounds_any_input(
            tokens in prop::collection::vec(any::<u32>(), 0..32)
        ) {
            let disc = Discriminator::with_seed(small_config(10), 42);

            let prob = disc.discriminate(&tokens);
            prop_assert!((0.0..=1.0).contains(&prob));
        }
    }
}