//! Variational autoencoder for chess position manifold learning.
//! Part of the `chess_vector_engine` crate (`variational_autoencoder.rs`).
1use candle_core::{Device, Module, Result as CandleResult, Tensor};
2use candle_nn::{linear, AdamW, Linear, Optimizer, ParamsAdamW, VarBuilder, VarMap};
3use ndarray::Array2;
4use std::collections::HashMap;
5
/// Variational Autoencoder for chess position manifold learning with uncertainty quantification
pub struct VariationalAutoencoder {
    // Dimensionality of the input position vectors.
    input_dim: usize,
    // Dimensionality of the learned latent space.
    latent_dim: usize,
    // Compute device for all tensors (CPU by default; see `new`).
    device: Device,
    // Encoder network; `None` until `init_network` is called.
    encoder: Option<VariationalEncoder>,
    // Decoder network; `None` until `init_network` is called.
    decoder: Option<VariationalDecoder>,
    // Owns all trainable parameters for both networks.
    var_map: VarMap,
    // AdamW optimizer over `var_map`'s variables; `None` until `init_network`.
    optimizer: Option<AdamW>,
    beta: f32, // KL divergence weight for β-VAE
}
17
/// Variational encoder with mean and log-variance outputs
struct VariationalEncoder {
    // Shared hidden trunk (ReLU-activated in `encode`).
    shared_layers: Vec<Linear>,
    // Head producing the latent mean μ.
    mean_layer: Linear,
    // Head producing the latent log-variance log σ².
    logvar_layer: Linear,
}
24
/// Variational decoder
struct VariationalDecoder {
    // Hidden layers followed by the output layer (last entry, tanh-activated in `forward`).
    layers: Vec<Linear>,
}
29
30impl VariationalEncoder {
31    fn new(
32        vs: VarBuilder,
33        input_dim: usize,
34        hidden_dims: &[usize],
35        latent_dim: usize,
36    ) -> CandleResult<Self> {
37        let mut shared_layers = Vec::new();
38        let mut prev_dim = input_dim;
39
40        // Create shared hidden layers
41        for (i, &hidden_dim) in hidden_dims.iter().enumerate() {
42            let layer = linear(prev_dim, hidden_dim, vs.pp(format!("encoder.layer{i}")))?;
43            shared_layers.push(layer);
44            prev_dim = hidden_dim;
45        }
46
47        // Mean and log-variance branches
48        let mean_layer = linear(prev_dim, latent_dim, vs.pp("encoder.mean"))?;
49        let logvar_layer = linear(prev_dim, latent_dim, vs.pp("encoder.logvar"))?;
50
51        Ok(Self {
52            shared_layers,
53            mean_layer,
54            logvar_layer,
55        })
56    }
57
58    /// Forward pass returning mean and log-variance
59    fn encode(&self, x: &Tensor) -> CandleResult<(Tensor, Tensor)> {
60        let mut h = x.clone();
61
62        // Pass through shared layers with ReLU activation
63        for layer in &self.shared_layers {
64            h = layer.forward(&h)?.relu()?;
65        }
66
67        // Compute mean and log-variance
68        let mean = self.mean_layer.forward(&h)?;
69        let logvar = self.logvar_layer.forward(&h)?;
70
71        Ok((mean, logvar))
72    }
73
74    /// Reparameterization trick: z = μ + σ * ε
75    fn reparameterize(&self, mean: &Tensor, logvar: &Tensor) -> CandleResult<Tensor> {
76        let std = (logvar * 0.5)?.exp()?;
77        let eps = Tensor::randn_like(&std, 0.0, 1.0)?;
78        let scaled_eps = (&std * &eps)?;
79        mean + &scaled_eps
80    }
81}
82
83impl VariationalDecoder {
84    fn new(
85        vs: VarBuilder,
86        latent_dim: usize,
87        hidden_dims: &[usize],
88        output_dim: usize,
89    ) -> CandleResult<Self> {
90        let mut layers = Vec::new();
91        let mut prev_dim = latent_dim;
92
93        // Create hidden layers (reverse of encoder)
94        for (i, &hidden_dim) in hidden_dims.iter().rev().enumerate() {
95            let layer = linear(prev_dim, hidden_dim, vs.pp(format!("decoder.layer{i}")))?;
96            layers.push(layer);
97            prev_dim = hidden_dim;
98        }
99
100        // Output layer
101        let output_layer = linear(prev_dim, output_dim, vs.pp("decoder.output"))?;
102        layers.push(output_layer);
103
104        Ok(Self { layers })
105    }
106}
107
108impl Module for VariationalDecoder {
109    fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
110        let mut h = x.clone();
111
112        // All layers except last use ReLU
113        for (i, layer) in self.layers.iter().enumerate() {
114            h = layer.forward(&h)?;
115            if i < self.layers.len() - 1 {
116                h = h.relu()?;
117            } else {
118                // Output layer uses tanh to bound values
119                h = h.tanh()?;
120            }
121        }
122
123        Ok(h)
124    }
125}
126
127impl VariationalAutoencoder {
128    pub fn new(input_dim: usize, latent_dim: usize, beta: f32) -> Self {
129        let device = Device::Cpu; // Use CPU by default, can be upgraded to GPU later
130        let var_map = VarMap::new();
131
132        Self {
133            input_dim,
134            latent_dim,
135            device,
136            encoder: None,
137            decoder: None,
138            var_map,
139            optimizer: None,
140            beta,
141        }
142    }
143
144    /// Initialize the VAE network with configurable architecture
145    pub fn init_network(&mut self, hidden_dims: &[usize]) -> Result<(), String> {
146        let vs = VarBuilder::from_varmap(&self.var_map, candle_core::DType::F32, &self.device);
147
148        let encoder =
149            VariationalEncoder::new(vs.clone(), self.input_dim, hidden_dims, self.latent_dim)
150                .map_err(|_e| "Processing...".to_string())?;
151        let decoder = VariationalDecoder::new(vs, self.latent_dim, hidden_dims, self.input_dim)
152            .map_err(|_e| "Processing...".to_string())?;
153
154        // Initialize AdamW optimizer with lower learning rate for VAE stability
155        let adamw_params = ParamsAdamW {
156            lr: 0.0005,
157            beta1: 0.9,
158            beta2: 0.999,
159            eps: 1e-8,
160            weight_decay: 1e-4,
161        };
162        let optimizer = AdamW::new(self.var_map.all_vars(), adamw_params)
163            .map_err(|_e| "Processing...".to_string())?;
164
165        self.encoder = Some(encoder);
166        self.decoder = Some(decoder);
167        self.optimizer = Some(optimizer);
168
169        Ok(())
170    }
171
172    /// Forward pass through the VAE
173    pub fn forward(&self, x: &Tensor) -> CandleResult<(Tensor, Tensor, Tensor)> {
174        let encoder = self
175            .encoder
176            .as_ref()
177            .ok_or_else(|| candle_core::Error::Msg("Encoder not initialized".into()))?;
178        let decoder = self
179            .decoder
180            .as_ref()
181            .ok_or_else(|| candle_core::Error::Msg("Decoder not initialized".into()))?;
182
183        // Encode to get mean and log-variance
184        let (mean, logvar) = encoder.encode(x)?;
185
186        // Sample from latent distribution
187        let z = encoder.reparameterize(&mean, &logvar)?;
188
189        // Decode back to original space
190        let reconstruction = decoder.forward(&z)?;
191
192        Ok((reconstruction, mean, logvar))
193    }
194
195    /// Compute VAE loss (reconstruction + KL divergence)
196    pub fn compute_loss(
197        &self,
198        x: &Tensor,
199        reconstruction: &Tensor,
200        mean: &Tensor,
201        logvar: &Tensor,
202    ) -> CandleResult<Tensor> {
203        let batch_size = x.dims()[0] as f32;
204
205        // Reconstruction loss (MSE)
206        let diff = (x - reconstruction)?;
207        let squared = diff.powf(2.0)?;
208        let sum_tensor = squared.sum_all()?;
209        let batch_tensor = Tensor::new(batch_size, &self.device)?;
210        let recon_loss = (&sum_tensor / &batch_tensor)?;
211
212        // KL divergence: -0.5 * sum(1 + log(σ²) - μ² - σ²)
213        let kl_div = {
214            let var = logvar.exp()?;
215            let mean_sq = mean.powf(2.0)?;
216            let one_tensor = Tensor::ones_like(logvar)?;
217            let logvar_plus_one = (logvar + &one_tensor)?;
218            let minus_mean_sq = (&logvar_plus_one - &mean_sq)?;
219            let kl_per_dim = (&minus_mean_sq - &var)?;
220            let kl_sum = kl_per_dim.sum_all()?;
221            let neg_half = Tensor::new(-0.5f32, &self.device)?;
222            let kl_scaled = (&kl_sum * &neg_half)?;
223            let batch_tensor = Tensor::new(batch_size, &self.device)?;
224            (&kl_scaled / &batch_tensor)?
225        };
226
227        // Total loss with β weighting
228        let beta_tensor = Tensor::new(self.beta, &self.device)?;
229        let weighted_kl = (&kl_div * &beta_tensor)?;
230        let total_loss = (&recon_loss + &weighted_kl)?;
231
232        Ok(total_loss)
233    }
234
235    /// Train the VAE on a batch of position vectors
236    pub fn train_step(&mut self, vectors: &Array2<f32>) -> Result<f32, String> {
237        // Convert to tensor
238        let batch_tensor = self.array_to_tensor(vectors)?;
239
240        // Forward pass
241        let (reconstruction, mean, logvar) = self
242            .forward(&batch_tensor)
243            .map_err(|_e| "Processing...".to_string())?;
244
245        // Compute loss
246        let loss = self
247            .compute_loss(&batch_tensor, &reconstruction, &mean, &logvar)
248            .map_err(|_e| "Processing...".to_string())?;
249
250        // Get loss value for return
251        let loss_value = loss
252            .to_scalar::<f32>()
253            .map_err(|_e| "Processing...".to_string())?;
254
255        // Backward pass
256        let grads = loss.backward().map_err(|_e| "Processing...".to_string())?;
257
258        // Now get optimizer and step
259        let optimizer = self.optimizer.as_mut().ok_or("Optimizer not initialized")?;
260        optimizer
261            .step(&grads)
262            .map_err(|_e| "Processing...".to_string())?;
263
264        // Return loss value
265        Ok(loss_value)
266    }
267
268    /// Encode positions to latent space with uncertainty
269    pub fn encode(&self, vectors: &Array2<f32>) -> Result<(Array2<f32>, Array2<f32>), String> {
270        let encoder = self.encoder.as_ref().ok_or("Encoder not initialized")?;
271
272        let input_tensor = self.array_to_tensor(vectors)?;
273        let (mean, logvar) = encoder
274            .encode(&input_tensor)
275            .map_err(|_e| "Processing...".to_string())?;
276
277        let mean_array = self.tensor_to_array(&mean)?;
278        let logvar_array = self.tensor_to_array(&logvar)?;
279
280        Ok((mean_array, logvar_array))
281    }
282
283    /// Sample from the latent space
284    pub fn sample_latent(
285        &self,
286        mean: &Array2<f32>,
287        logvar: &Array2<f32>,
288    ) -> Result<Array2<f32>, String> {
289        let encoder = self.encoder.as_ref().ok_or("Encoder not initialized")?;
290
291        let mean_tensor = self.array_to_tensor(mean)?;
292        let logvar_tensor = self.array_to_tensor(logvar)?;
293
294        let z = encoder
295            .reparameterize(&mean_tensor, &logvar_tensor)
296            .map_err(|_e| "Processing...".to_string())?;
297
298        self.tensor_to_array(&z)
299    }
300
301    /// Decode from latent space
302    pub fn decode(&self, latent_vectors: &Array2<f32>) -> Result<Array2<f32>, String> {
303        let decoder = self.decoder.as_ref().ok_or("Decoder not initialized")?;
304
305        let latent_tensor = self.array_to_tensor(latent_vectors)?;
306        let output = decoder
307            .forward(&latent_tensor)
308            .map_err(|_e| "Processing...".to_string())?;
309
310        self.tensor_to_array(&output)
311    }
312
313    /// Full encoding pipeline (encode then sample)
314    pub fn encode_with_sampling(&self, vectors: &Array2<f32>) -> Result<Array2<f32>, String> {
315        let (mean, logvar) = self.encode(vectors)?;
316        self.sample_latent(&mean, &logvar)
317    }
318
319    /// Generate new samples from the learned manifold
320    pub fn generate(&self, num_samples: usize) -> Result<Array2<f32>, String> {
321        let _decoder = self.decoder.as_ref().ok_or("Decoder not initialized")?;
322
323        // Sample from standard normal distribution
324        let latent_samples = Array2::from_shape_fn((num_samples, self.latent_dim), |_| {
325            use rand::Rng;
326            let mut rng = rand::thread_rng();
327            rng.gen::<f32>() * 2.0 - 1.0 // Sample from [-1, 1]
328        });
329
330        self.decode(&latent_samples)
331    }
332
333    /// Get reconstruction quality metrics
334    pub fn evaluate_reconstruction(
335        &self,
336        vectors: &Array2<f32>,
337    ) -> Result<HashMap<String, f32>, String> {
338        let input_tensor = self.array_to_tensor(vectors)?;
339        let (reconstruction, _mean, _logvar) = self
340            .forward(&input_tensor)
341            .map_err(|_e| "Processing...".to_string())?;
342
343        let reconstruction_array = self.tensor_to_array(&reconstruction)?;
344
345        // Compute metrics
346        let mut metrics = HashMap::new();
347
348        // MSE
349        let mse = vectors
350            .iter()
351            .zip(reconstruction_array.iter())
352            .map(|(a, b)| (a - b).powi(2))
353            .sum::<f32>()
354            / (vectors.len() as f32);
355        metrics.insert("mse".to_string(), mse);
356
357        // RMSE
358        metrics.insert("rmse".to_string(), mse.sqrt());
359
360        // Mean absolute error
361        let mae = vectors
362            .iter()
363            .zip(reconstruction_array.iter())
364            .map(|(a, b)| (a - b).abs())
365            .sum::<f32>()
366            / (vectors.len() as f32);
367        metrics.insert("mae".to_string(), mae);
368
369        // Compression ratio
370        let compression_ratio = self.input_dim as f32 / self.latent_dim as f32;
371        metrics.insert("compression_ratio".to_string(), compression_ratio);
372
373        Ok(metrics)
374    }
375
376    /// Get the latent dimensionality
377    pub fn latent_dim(&self) -> usize {
378        self.latent_dim
379    }
380
381    /// Check if the VAE is initialized
382    pub fn is_initialized(&self) -> bool {
383        self.encoder.is_some() && self.decoder.is_some() && self.optimizer.is_some()
384    }
385
386    // Helper methods for tensor conversions
387    fn array_to_tensor(&self, array: &Array2<f32>) -> Result<Tensor, String> {
388        let shape = array.shape();
389        let data: Vec<f32> = array.iter().cloned().collect();
390        Tensor::from_vec(data, (shape[0], shape[1]), &self.device)
391            .map_err(|_e| "Processing...".to_string())
392    }
393
394    fn tensor_to_array(&self, tensor: &Tensor) -> Result<Array2<f32>, String> {
395        let shape = tensor.shape();
396        if shape.dims().len() != 2 {
397            return Err("Expected 2D tensor".to_string());
398        }
399
400        let data = tensor
401            .to_vec2::<f32>()
402            .map_err(|_e| "Processing...".to_string())?;
403
404        Array2::from_shape_vec((shape.dims()[0], shape.dims()[1]), data.concat())
405            .map_err(|_e| "Processing...".to_string())
406    }
407}
408
/// Configuration for VAE training
#[derive(Debug, Clone)]
pub struct VAEConfig {
    /// Widths of the encoder's hidden layers (the decoder mirrors them in reverse).
    pub hidden_dims: Vec<usize>,
    /// KL-divergence weight (β-VAE); 1.0 is a standard VAE.
    pub beta: f32,
    /// Optimizer learning rate.
    /// NOTE(review): `init_network` hardcodes lr = 0.0005 and does not read
    /// this field — confirm the intended wiring.
    pub learning_rate: f32,
    /// Number of samples per training batch.
    pub batch_size: usize,
    /// Number of passes over the training data.
    pub epochs: usize,
}
418
419impl Default for VAEConfig {
420    fn default() -> Self {
421        Self {
422            hidden_dims: vec![512, 256, 128], // Deeper architecture
423            beta: 1.0,                        // Standard VAE
424            learning_rate: 0.0005,
425            batch_size: 32,
426            epochs: 100,
427        }
428    }
429}
430
431impl VAEConfig {
432    /// Configuration for β-VAE with higher disentanglement
433    pub fn beta_vae(beta: f32) -> Self {
434        Self {
435            beta,
436            ..Default::default()
437        }
438    }
439
440    /// Configuration for high compression ratio
441    pub fn high_compression() -> Self {
442        Self {
443            hidden_dims: vec![512, 256, 128, 64], // More layers for better compression
444            beta: 0.5,                            // Lower KL weight for better reconstruction
445            ..Default::default()
446        }
447    }
448
449    /// Configuration optimized for chess positions
450    pub fn chess_optimized() -> Self {
451        Self {
452            hidden_dims: vec![512, 256, 128], // Balanced architecture
453            beta: 0.8,                        // Slightly favor reconstruction
454            learning_rate: 0.001,
455            batch_size: 64,
456            epochs: 150,
457        }
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Build and initialize a VAE (β = 1.0) with the default hidden stack.
    fn make_vae(input_dim: usize, latent_dim: usize) -> VariationalAutoencoder {
        let mut vae = VariationalAutoencoder::new(input_dim, latent_dim, 1.0);
        vae.init_network(&VAEConfig::default().hidden_dims).unwrap();
        vae
    }

    #[test]
    fn test_vae_initialization() {
        let vae = make_vae(1024, 128);
        assert!(vae.is_initialized());
        assert_eq!(vae.latent_dim(), 128);
    }

    #[test]
    fn test_vae_forward_pass() {
        let vae = make_vae(64, 16);
        let batch = Array2::from_elem((4, 64), 0.5);

        let latent = vae.encode_with_sampling(&batch).unwrap();
        assert_eq!(latent.shape(), &[4, 16]);
    }

    #[test]
    fn test_vae_reconstruction() {
        let vae = make_vae(32, 8);
        let batch = Array2::from_elem((2, 32), 0.3);

        let latent = vae.encode_with_sampling(&batch).unwrap();
        let roundtrip = vae.decode(&latent).unwrap();
        assert_eq!(roundtrip.shape(), batch.shape());
    }

    #[test]
    fn test_vae_generation() {
        let vae = make_vae(16, 4);
        let samples = vae.generate(3).unwrap();
        assert_eq!(samples.shape(), &[3, 16]);
    }
}