use rand::Rng;
use super::config::GeneratorConfig;
use super::latent::LatentCode;
/// Feed-forward network that turns a latent code into a sequence of token ids.
#[derive(Debug)]
pub struct Generator {
/// Configuration the layer sizes are derived from when the generator is built.
pub config: GeneratorConfig,
/// One weight matrix per layer, indexed as `weights[layer][output][input]`.
weights: Vec<Vec<Vec<f32>>>,
/// One bias vector per layer, holding one entry per output of that layer.
biases: Vec<Vec<f32>>,
}
impl Generator {
/// Builds a generator whose initial weights come from an RNG seeded by the operating system.
///
/// Every call therefore yields different weights; use [`Generator::with_seed`] when the
/// weights have to be reproducible.
pub fn new(config: GeneratorConfig) -> Self {
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    let mut os_seeded = StdRng::from_os_rng();
    let (weights, biases) = Self::init_weights(&config, &mut os_seeded);
    Self {
        config,
        weights,
        biases,
    }
}
/// Builds a generator whose initial weights are drawn from an RNG seeded with `seed`.
///
/// The weights depend only on `config` and `seed`, which makes this constructor suitable
/// for tests and other situations that need repeatable output.
pub fn with_seed(config: GeneratorConfig, seed: u64) -> Self {
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    let mut seeded = StdRng::seed_from_u64(seed);
    let (weights, biases) = Self::init_weights(&config, &mut seeded);
    Self {
        config,
        weights,
        biases,
    }
}
/// Creates the weight matrices and bias vectors for every layer of the network.
///
/// The layer sizes are `latent_dim`, followed by each entry of `hidden_dims`, followed by
/// `vocab_size * max_seq_len`. Every weight matrix is stored as `output_dim` rows of
/// `input_dim` values sampled from a normal distribution with standard deviation
/// `sqrt(2 / (input_dim + output_dim))`; all biases start at zero.
///
/// Returns `(weights, biases)`, each with one entry per layer.
fn init_weights<R: Rng>(
    config: &GeneratorConfig,
    rng: &mut R,
) -> (Vec<Vec<Vec<f32>>>, Vec<Vec<f32>>) {
    let mut layer_sizes = vec![config.latent_dim];
    layer_sizes.extend(config.hidden_dims.iter().copied());
    layer_sizes.push(config.vocab_size * config.max_seq_len);

    // `layer_sizes` always holds at least the latent and output sizes, so this cannot underflow.
    let layer_count = layer_sizes.len() - 1;
    let mut weights = Vec::with_capacity(layer_count);
    let mut biases = Vec::with_capacity(layer_count);

    for pair in layer_sizes.windows(2) {
        let (fan_in, fan_out) = (pair[0], pair[1]);
        let sigma = (2.0 / (fan_in + fan_out) as f64).sqrt();

        let mut matrix: Vec<Vec<f32>> = Vec::with_capacity(fan_out);
        for _ in 0..fan_out {
            let mut row: Vec<f32> = Vec::with_capacity(fan_in);
            for _ in 0..fan_in {
                // Box-Muller transform: two uniform samples become one standard normal sample.
                // `u1` is clamped away from zero so that `ln(u1)` stays finite.
                let u1: f64 = rng.random::<f64>().max(1e-10);
                let u2: f64 = rng.random::<f64>();
                let radius = (-2.0 * u1.ln()).sqrt();
                let angle = 2.0 * std::f64::consts::PI * u2;
                row.push((radius * angle.cos() * sigma) as f32);
            }
            matrix.push(row);
        }

        weights.push(matrix);
        biases.push(vec![0.0_f32; fan_out]);
    }

    (weights, biases)
}
/// Decodes a latent code into a sequence of token ids.
///
/// The latent vector is passed through every layer in order, with ReLU applied after each
/// layer except the last one. The final activations are then read as consecutive groups of
/// `vocab_size` logits, one group per sequence position, and the index of the largest logit
/// in each group becomes the token for that position.
///
/// Returns at most `max_seq_len` tokens; a position whose logit range would fall outside
/// the final activations is skipped.
///
/// # Panics
///
/// Panics if `latent.dim()` differs from `self.config.latent_dim`.
pub fn generate(&self, latent: &LatentCode) -> Vec<u32> {
    assert_eq!(latent.dim(), self.config.latent_dim);

    // Identify the output layer by position. Comparing weight matrices by value would walk
    // the whole output matrix on every call and misclassifies layers that happen to compare
    // equal (or unequal to themselves, if a weight were NaN).
    let output_layer = self.weights.len().saturating_sub(1);

    let mut x = latent.vector.clone();
    for (layer, (w, b)) in self.weights.iter().zip(&self.biases).enumerate() {
        x = Self::linear_forward(&x, w, b);
        if layer != output_layer {
            // ReLU on hidden layers only; the output layer keeps its raw logits.
            for v in &mut x {
                *v = v.max(0.0);
            }
        }
    }

    let vocab_size = self.config.vocab_size;
    let max_seq_len = self.config.max_seq_len;
    let mut tokens = Vec::with_capacity(max_seq_len);
    for pos in 0..max_seq_len {
        let start = pos * vocab_size;
        // `get` yields `None` when the range runs past the activations, e.g. after the
        // public `config` was changed without rebuilding the weights.
        if let Some(logits) = x.get(start..start + vocab_size) {
            let max_idx = logits
                .iter()
                .enumerate()
                // Incomparable values (NaN) are treated as equal; among equal maxima the
                // last one wins, as documented for `Iterator::max_by`.
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                // An empty logit group falls back to token 0. The cast only truncates for
                // vocabularies larger than `u32::MAX`.
                .map_or(0, |(i, _)| i as u32);
            tokens.push(max_idx);
        }
    }
    tokens
}
/// Applies one fully connected layer: `output[i] = dot(weights[i], input) + bias[i]`.
///
/// Each row of `weights` produces one output value. The dot product stops at the shorter of
/// the row and `input`, so surplus elements on either side are ignored.
///
/// # Panics
///
/// Panics if `bias` has fewer entries than `weights` has rows.
fn linear_forward(input: &[f32], weights: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
    weights
        .iter()
        .enumerate()
        .map(|(row_idx, row)| {
            let weighted_sum: f32 = row.iter().zip(input).map(|(w, x)| w * x).sum();
            weighted_sum + bias[row_idx]
        })
        .collect()
}
/// Returns the total number of trainable scalars: every weight plus every bias.
///
/// Weights are counted row by row, so a layer with zero outputs contributes nothing
/// instead of panicking on a missing first row.
#[must_use]
pub fn num_parameters(&self) -> usize {
    let weight_params: usize = self.weights.iter().flatten().map(Vec::len).sum();
    let bias_params: usize = self.biases.iter().map(Vec::len).sum();
    weight_params + bias_params
}
}
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
/// A seeded generator allocates exactly the parameters implied by its layer sizes.
#[test]
fn test_generator_creation() {
    let config = GeneratorConfig {
        latent_dim: 32,
        hidden_dims: vec![64, 64],
        vocab_size: 100,
        max_seq_len: 10,
        dropout: 0.1,
        batch_norm: true,
    };
    // `gen` is a reserved keyword in the 2024 edition, hence the longer name.
    let generator = Generator::with_seed(config, 42);
    // Layers 32 -> 64 -> 64 -> 1000 (= 100 * 10):
    // weights 2048 + 4096 + 64000, biases 64 + 64 + 1000.
    assert_eq!(generator.num_parameters(), 71_272);
}
/// Generation yields one in-vocabulary token per sequence position.
#[test]
fn test_generator_generate() {
    use rand::SeedableRng;

    let config = GeneratorConfig {
        latent_dim: 16,
        hidden_dims: vec![32],
        vocab_size: 50,
        max_seq_len: 8,
        dropout: 0.0,
        batch_norm: false,
    };
    // `gen` is a reserved keyword in the 2024 edition, hence the longer name.
    let generator = Generator::with_seed(config.clone(), 42);
    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
    let z = LatentCode::sample(&mut rng, config.latent_dim);
    let tokens = generator.generate(&z);
    assert_eq!(tokens.len(), config.max_seq_len);
    // Widen the token rather than narrowing the vocabulary size, so nothing can truncate.
    assert!(tokens.iter().all(|&t| (t as usize) < config.vocab_size));
}
/// The same seed and latent code always decode to the same tokens.
#[test]
fn test_generator_deterministic() {
    let config = GeneratorConfig {
        latent_dim: 16,
        hidden_dims: vec![32],
        vocab_size: 50,
        max_seq_len: 8,
        dropout: 0.0,
        batch_norm: false,
    };
    // `gen` is a reserved keyword in the 2024 edition, hence the longer name.
    let generator = Generator::with_seed(config.clone(), 42);
    let z = LatentCode::new(vec![0.5; config.latent_dim]);
    let tokens1 = generator.generate(&z);
    let tokens2 = generator.generate(&z);
    assert_eq!(tokens1, tokens2);

    // A second generator built from the same seed must agree as well; otherwise the
    // weight initialisation itself would not be reproducible.
    let twin = Generator::with_seed(config.clone(), 42);
    assert_eq!(twin.generate(&z), tokens1);
}
proptest! {
    /// For any seed, every generated token lies inside the configured vocabulary and the
    /// sequence has the configured length.
    #[test]
    fn test_generator_output_valid_tokens(seed in 0u64..10000) {
        use rand::SeedableRng;

        let config = GeneratorConfig {
            latent_dim: 16,
            hidden_dims: vec![32],
            vocab_size: 50,
            max_seq_len: 8,
            dropout: 0.0,
            batch_norm: false,
        };
        // `gen` is a reserved keyword in the 2024 edition, hence the longer name.
        let generator = Generator::with_seed(config.clone(), seed);
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let z = LatentCode::sample(&mut rng, config.latent_dim);
        let tokens = generator.generate(&z);
        prop_assert_eq!(tokens.len(), config.max_seq_len);
        // Compare against the configured size instead of a repeated literal.
        prop_assert!(tokens.iter().all(|&t| (t as usize) < config.vocab_size));
    }
}
}