Skip to main content

entrenar/generative/code_gan/
discriminator.rs

1//! Discriminator network for Code GAN.
2
3use rand::Rng;
4
5use super::config::DiscriminatorConfig;
6
/// Freshly initialised discriminator parameters, as returned by
/// `Discriminator::init_weights`.
///
/// The tuple elements are, in order:
/// 1. the token embedding table (one row per vocabulary entry),
/// 2. the dense-layer weight matrices (one matrix per layer, one row per
///    output unit),
/// 3. the dense-layer bias vectors (one vector per layer).
type DiscriminatorWeights = (Vec<Vec<f32>>, Vec<Vec<Vec<f32>>>, Vec<Vec<f32>>);
9
/// Discriminator network: classifies code as real or fake.
///
/// Token ids are looked up in an embedding table, the embeddings of a
/// fixed-length (padded or truncated) sequence are concatenated, and the
/// result is fed through a stack of dense layers ending in a single logit.
#[derive(Debug)]
pub struct Discriminator {
    /// Configuration the parameters were sized from (vocabulary size,
    /// sequence length, embedding width and hidden layer widths).
    pub config: DiscriminatorConfig,
    /// Token embeddings: `vocab_size` rows of `embed_dim` values each.
    embeddings: Vec<Vec<f32>>,
    /// Weights for each dense layer; each matrix has one row per output unit
    /// and one column per input unit.
    weights: Vec<Vec<Vec<f32>>>,
    /// Biases for each dense layer (one value per output unit), initialised
    /// to zero.
    biases: Vec<Vec<f32>>,
}
22
23impl Discriminator {
24    /// Create a new discriminator with random initialization
25    pub fn new(config: DiscriminatorConfig) -> Self {
26        use rand::SeedableRng;
27        let mut rng = rand::rngs::StdRng::from_os_rng();
28        let (embeddings, weights, biases) = Self::init_weights(&config, &mut rng);
29        Self { config, embeddings, weights, biases }
30    }
31
32    /// Create a new discriminator with a seed for reproducibility
33    pub fn with_seed(config: DiscriminatorConfig, seed: u64) -> Self {
34        use rand::SeedableRng;
35        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
36        let (embeddings, weights, biases) = Self::init_weights(&config, &mut rng);
37        Self { config, embeddings, weights, biases }
38    }
39
40    fn init_weights<R: Rng>(config: &DiscriminatorConfig, rng: &mut R) -> DiscriminatorWeights {
41        // Helper function for Box-Muller normal sampling
42        let sample_normal = |rng: &mut R, std: f64| -> f32 {
43            let u1: f64 = rng.random::<f64>().max(1e-10);
44            let u2: f64 = rng.random::<f64>();
45            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
46            (z * std) as f32
47        };
48
49        // Initialize embeddings
50        let embed_std = (1.0 / config.embed_dim as f64).sqrt();
51        let embeddings: Vec<Vec<f32>> = (0..config.vocab_size)
52            .map(|_| (0..config.embed_dim).map(|_| sample_normal(rng, embed_std)).collect())
53            .collect();
54
55        // Initialize dense layers
56        let input_dim = config.embed_dim * config.max_seq_len;
57        let mut dims = vec![input_dim];
58        dims.extend(&config.hidden_dims);
59        dims.push(1); // Output: single logit for real/fake
60
61        let mut weights = Vec::new();
62        let mut biases = Vec::new();
63
64        for i in 0..dims.len() - 1 {
65            let in_dim = dims[i];
66            let out_dim = dims[i + 1];
67
68            let std = (2.0 / (in_dim + out_dim) as f64).sqrt();
69
70            let w: Vec<Vec<f32>> = (0..out_dim)
71                .map(|_| (0..in_dim).map(|_| sample_normal(rng, std)).collect())
72                .collect();
73            let b: Vec<f32> = vec![0.0; out_dim];
74
75            weights.push(w);
76            biases.push(b);
77        }
78
79        (embeddings, weights, biases)
80    }
81
82    /// Discriminate: returns probability that input is real (valid code)
83    pub fn discriminate(&self, tokens: &[u32]) -> f32 {
84        // Pad or truncate to max_seq_len
85        let mut padded = tokens.to_vec();
86        padded.resize(self.config.max_seq_len, 0);
87
88        // Embed tokens
89        let mut x = Vec::with_capacity(self.config.max_seq_len * self.config.embed_dim);
90        for &token in &padded {
91            let token_idx = (token as usize).min(self.config.vocab_size - 1);
92            x.extend(&self.embeddings[token_idx]);
93        }
94
95        // Forward pass through dense layers
96        for (i, (w, b)) in self.weights.iter().zip(&self.biases).enumerate() {
97            x = Self::linear_forward(&x, w, b);
98            // Leaky ReLU for all but last layer
99            if i < self.weights.len() - 1 {
100                x = x.iter().map(|&v| if v > 0.0 { v } else { 0.01 * v }).collect();
101            }
102        }
103
104        // Sigmoid on output
105        sigmoid(x[0])
106    }
107
108    fn linear_forward(input: &[f32], weights: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
109        let output_dim = weights.len();
110        let mut output = Vec::with_capacity(output_dim);
111
112        for (i, w_row) in weights.iter().enumerate() {
113            let dot: f32 = w_row.iter().zip(input).map(|(a, b)| a * b).sum();
114            output.push(dot + bias[i]);
115        }
116
117        output
118    }
119
120    /// Get number of parameters
121    #[must_use]
122    pub fn num_parameters(&self) -> usize {
123        let embed_params = self.embeddings.len() * self.config.embed_dim;
124        let weight_params: usize = self.weights.iter().map(|w| w.len() * w[0].len()).sum();
125        let bias_params: usize = self.biases.iter().map(Vec::len).sum();
126        embed_params + weight_params + bias_params
127    }
128}
129
130/// Sigmoid activation function
131pub fn sigmoid(x: f32) -> f32 {
132    contract_pre_sigmoid!();
133    let result = 1.0 / (1.0 + (-x).exp());
134    contract_post_silu!(&[result]);
135    result
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Seed shared by every test so that failures are reproducible.
    const SEED: u64 = 42;

    /// Small single-hidden-layer configuration (vocabulary of 50, 8-wide
    /// embeddings, no dropout, no spectral norm) with the given sequence
    /// length.
    fn compact_config(max_seq_len: usize) -> DiscriminatorConfig {
        DiscriminatorConfig {
            vocab_size: 50,
            max_seq_len,
            embed_dim: 8,
            hidden_dims: vec![16],
            dropout: 0.0,
            spectral_norm: false,
        }
    }

    /// Construction with two hidden layers yields a non-empty parameter set.
    #[test]
    fn test_discriminator_creation() {
        let disc = Discriminator::with_seed(
            DiscriminatorConfig {
                vocab_size: 100,
                max_seq_len: 10,
                embed_dim: 16,
                hidden_dims: vec![32, 16],
                dropout: 0.1,
                spectral_norm: true,
            },
            SEED,
        );

        assert!(disc.num_parameters() > 0);
    }

    /// The sigmoid output must be a valid probability.
    #[test]
    fn test_discriminator_output_range() {
        let disc = Discriminator::with_seed(compact_config(8), SEED);

        let prob = disc.discriminate(&[1, 2, 3, 4, 5]);

        // Output should be in [0, 1] due to sigmoid
        assert!((0.0..=1.0).contains(&prob));
    }

    /// Scoring the same tokens twice gives the same probability.
    #[test]
    fn test_discriminator_deterministic() {
        let disc = Discriminator::with_seed(compact_config(8), SEED);
        let tokens: Vec<u32> = vec![1, 2, 3, 4, 5];

        let first = disc.discriminate(&tokens);
        let second = disc.discriminate(&tokens);

        assert!((first - second).abs() < 1e-6);
    }

    /// Spot checks at the midpoint and far into both tails.
    #[test]
    fn test_sigmoid_function() {
        let midpoint = sigmoid(0.0);
        assert!((midpoint - 0.5).abs() < 1e-6);

        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    proptest! {
        /// Any in-vocabulary sequence of 1 to 9 tokens scores within [0, 1].
        #[test]
        fn test_discriminator_output_bounds(tokens in prop::collection::vec(0u32..50, 1..10)) {
            let disc = Discriminator::with_seed(compact_config(10), SEED);

            let prob = disc.discriminate(&tokens);
            prop_assert!((0.0..=1.0).contains(&prob));
        }
    }
}