use candle_core::{Device, Module, Result as CandleResult, Tensor};
use candle_nn::{linear, AdamW, Linear, Optimizer, ParamsAdamW, VarBuilder, VarMap};
use ndarray::Array2;
use std::collections::HashMap;

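/// Fully connected β-VAE built on candle: the encoder produces the mean and
/// log-variance of a diagonal Gaussian posterior over the latent space, and the
/// decoder maps latent samples back to the input space. `beta` weights the KL
/// term in the training loss.
///
/// A minimal usage sketch (dimensions and the constant batch are illustrative,
/// not taken from any real dataset):
///
/// ```ignore
/// let mut vae = VariationalAutoencoder::new(1024, 128, 1.0);
/// vae.init_network(&VAEConfig::default().hidden_dims).unwrap();
/// let batch = ndarray::Array2::from_shape_fn((32, 1024), |_| 0.5f32);
/// let loss = vae.train_step(&batch).unwrap();
/// ```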
pub struct VariationalAutoencoder {
    input_dim: usize,
    latent_dim: usize,
    device: Device,
    encoder: Option<VariationalEncoder>,
    decoder: Option<VariationalDecoder>,
    var_map: VarMap,
    optimizer: Option<AdamW>,
    beta: f32,
}

struct VariationalEncoder {
    shared_layers: Vec<Linear>,
    mean_layer: Linear,
    logvar_layer: Linear,
}

struct VariationalDecoder {
    layers: Vec<Linear>,
}

impl VariationalEncoder {
    fn new(
        vs: VarBuilder,
        input_dim: usize,
        hidden_dims: &[usize],
        latent_dim: usize,
    ) -> CandleResult<Self> {
        let mut shared_layers = Vec::new();
        let mut prev_dim = input_dim;

        for (i, &hidden_dim) in hidden_dims.iter().enumerate() {
            let layer = linear(prev_dim, hidden_dim, vs.pp(format!("encoder.layer{i}")))?;
            shared_layers.push(layer);
            prev_dim = hidden_dim;
        }

        let mean_layer = linear(prev_dim, latent_dim, vs.pp("encoder.mean"))?;
        let logvar_layer = linear(prev_dim, latent_dim, vs.pp("encoder.logvar"))?;

        Ok(Self {
            shared_layers,
            mean_layer,
            logvar_layer,
        })
    }

    fn encode(&self, x: &Tensor) -> CandleResult<(Tensor, Tensor)> {
        let mut h = x.clone();

        for layer in &self.shared_layers {
            h = layer.forward(&h)?.relu()?;
        }

        let mean = self.mean_layer.forward(&h)?;
        let logvar = self.logvar_layer.forward(&h)?;

        Ok((mean, logvar))
    }

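    /// Reparameterization trick: z = mean + sigma * eps with eps ~ N(0, 1) and
    /// sigma = exp(0.5 * logvar), so sampling stays differentiable with respect
    /// to the encoder outputs.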
    fn reparameterize(&self, mean: &Tensor, logvar: &Tensor) -> CandleResult<Tensor> {
        let std = (logvar * 0.5)?.exp()?;
        let eps = std.randn_like(0.0, 1.0)?;
        let scaled_eps = (&std * &eps)?;
        mean + &scaled_eps
    }
}

impl VariationalDecoder {
    fn new(
        vs: VarBuilder,
        latent_dim: usize,
        hidden_dims: &[usize],
        output_dim: usize,
    ) -> CandleResult<Self> {
        let mut layers = Vec::new();
        let mut prev_dim = latent_dim;

        for (i, &hidden_dim) in hidden_dims.iter().rev().enumerate() {
            let layer = linear(prev_dim, hidden_dim, vs.pp(format!("decoder.layer{i}")))?;
            layers.push(layer);
            prev_dim = hidden_dim;
        }

        let output_layer = linear(prev_dim, output_dim, vs.pp("decoder.output"))?;
        layers.push(output_layer);

        Ok(Self { layers })
    }
}

impl Module for VariationalDecoder {
    fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
        let mut h = x.clone();

        for (i, layer) in self.layers.iter().enumerate() {
            h = layer.forward(&h)?;
            if i < self.layers.len() - 1 {
                h = h.relu()?;
            } else {
                // tanh on the final layer keeps reconstructions in [-1, 1]
                h = h.tanh()?;
            }
        }

        Ok(h)
    }
}

impl VariationalAutoencoder {
    pub fn new(input_dim: usize, latent_dim: usize, beta: f32) -> Self {
        let device = Device::Cpu;
        let var_map = VarMap::new();

        Self {
            input_dim,
            latent_dim,
            device,
            encoder: None,
            decoder: None,
            var_map,
            optimizer: None,
            beta,
        }
    }

    pub fn init_network(&mut self, hidden_dims: &[usize]) -> Result<(), String> {
        let vs = VarBuilder::from_varmap(&self.var_map, candle_core::DType::F32, &self.device);

        let encoder =
            VariationalEncoder::new(vs.clone(), self.input_dim, hidden_dims, self.latent_dim)
                .map_err(|e| format!("Failed to build encoder: {e}"))?;
        let decoder = VariationalDecoder::new(vs, self.latent_dim, hidden_dims, self.input_dim)
            .map_err(|e| format!("Failed to build decoder: {e}"))?;

        let adamw_params = ParamsAdamW {
            lr: 0.0005,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 1e-4,
        };
        let optimizer = AdamW::new(self.var_map.all_vars(), adamw_params)
            .map_err(|e| format!("Failed to create optimizer: {e}"))?;

        self.encoder = Some(encoder);
        self.decoder = Some(decoder);
        self.optimizer = Some(optimizer);

        Ok(())
    }

    pub fn forward(&self, x: &Tensor) -> CandleResult<(Tensor, Tensor, Tensor)> {
        let encoder = self
            .encoder
            .as_ref()
            .ok_or_else(|| candle_core::Error::Msg("Encoder not initialized".into()))?;
        let decoder = self
            .decoder
            .as_ref()
            .ok_or_else(|| candle_core::Error::Msg("Decoder not initialized".into()))?;

        let (mean, logvar) = encoder.encode(x)?;
        let z = encoder.reparameterize(&mean, &logvar)?;
        let reconstruction = decoder.forward(&z)?;

        Ok((reconstruction, mean, logvar))
    }

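    /// β-VAE objective, averaged over the batch (N = batch size):
    ///
    ///   loss = sum((x - reconstruction)^2) / N
    ///        + beta * (-0.5) * sum(1 + logvar - mean^2 - exp(logvar)) / N
    ///
    /// i.e. a summed squared-error reconstruction term plus the analytic KL
    /// divergence between N(mean, exp(logvar)) and a standard-normal prior.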
    pub fn compute_loss(
        &self,
        x: &Tensor,
        reconstruction: &Tensor,
        mean: &Tensor,
        logvar: &Tensor,
    ) -> CandleResult<Tensor> {
        let batch_size = x.dims()[0] as f32;

        // Reconstruction term: summed squared error, averaged over the batch.
        let diff = (x - reconstruction)?;
        let squared = diff.powf(2.0)?;
        let sum_tensor = squared.sum_all()?;
        let batch_tensor = Tensor::new(batch_size, &self.device)?;
        let recon_loss = (&sum_tensor / &batch_tensor)?;

        // KL term: -0.5 * sum(1 + logvar - mean^2 - exp(logvar)), averaged over the batch.
        let kl_div = {
            let var = logvar.exp()?;
            let mean_sq = mean.powf(2.0)?;
            let one_tensor = Tensor::ones_like(logvar)?;
            let logvar_plus_one = (logvar + &one_tensor)?;
            let minus_mean_sq = (&logvar_plus_one - &mean_sq)?;
            let kl_per_dim = (&minus_mean_sq - &var)?;
            let kl_sum = kl_per_dim.sum_all()?;
            let neg_half = Tensor::new(-0.5f32, &self.device)?;
            let kl_scaled = (&kl_sum * &neg_half)?;
            let batch_tensor = Tensor::new(batch_size, &self.device)?;
            (&kl_scaled / &batch_tensor)?
        };

        let beta_tensor = Tensor::new(self.beta, &self.device)?;
        let weighted_kl = (&kl_div * &beta_tensor)?;
        let total_loss = (&recon_loss + &weighted_kl)?;

        Ok(total_loss)
    }

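    /// Runs a single optimization step on one batch: forward pass, β-VAE loss,
    /// backward pass, then an AdamW parameter update. Returns the scalar loss.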
    pub fn train_step(&mut self, vectors: &Array2<f32>) -> Result<f32, String> {
        let batch_tensor = self.array_to_tensor(vectors)?;

        let (reconstruction, mean, logvar) = self
            .forward(&batch_tensor)
            .map_err(|e| format!("Forward pass failed: {e}"))?;

        let loss = self
            .compute_loss(&batch_tensor, &reconstruction, &mean, &logvar)
            .map_err(|e| format!("Loss computation failed: {e}"))?;

        let loss_value = loss
            .to_scalar::<f32>()
            .map_err(|e| format!("Failed to read loss value: {e}"))?;

        let grads = loss
            .backward()
            .map_err(|e| format!("Backward pass failed: {e}"))?;

        let optimizer = self.optimizer.as_mut().ok_or("Optimizer not initialized")?;
        optimizer
            .step(&grads)
            .map_err(|e| format!("Optimizer step failed: {e}"))?;

        Ok(loss_value)
    }

    pub fn encode(&self, vectors: &Array2<f32>) -> Result<(Array2<f32>, Array2<f32>), String> {
        let encoder = self.encoder.as_ref().ok_or("Encoder not initialized")?;

        let input_tensor = self.array_to_tensor(vectors)?;
        let (mean, logvar) = encoder
            .encode(&input_tensor)
            .map_err(|e| format!("Encoding failed: {e}"))?;

        let mean_array = self.tensor_to_array(&mean)?;
        let logvar_array = self.tensor_to_array(&logvar)?;

        Ok((mean_array, logvar_array))
    }

    pub fn sample_latent(
        &self,
        mean: &Array2<f32>,
        logvar: &Array2<f32>,
    ) -> Result<Array2<f32>, String> {
        let encoder = self.encoder.as_ref().ok_or("Encoder not initialized")?;

        let mean_tensor = self.array_to_tensor(mean)?;
        let logvar_tensor = self.array_to_tensor(logvar)?;

        let z = encoder
            .reparameterize(&mean_tensor, &logvar_tensor)
            .map_err(|e| format!("Latent sampling failed: {e}"))?;

        self.tensor_to_array(&z)
    }

    pub fn decode(&self, latent_vectors: &Array2<f32>) -> Result<Array2<f32>, String> {
        let decoder = self.decoder.as_ref().ok_or("Decoder not initialized")?;

        let latent_tensor = self.array_to_tensor(latent_vectors)?;
        let output = decoder
            .forward(&latent_tensor)
            .map_err(|e| format!("Decoding failed: {e}"))?;

        self.tensor_to_array(&output)
    }

    pub fn encode_with_sampling(&self, vectors: &Array2<f32>) -> Result<Array2<f32>, String> {
        let (mean, logvar) = self.encode(vectors)?;
        self.sample_latent(&mean, &logvar)
    }

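    /// Decodes `num_samples` randomly drawn latent vectors into synthetic outputs.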
    pub fn generate(&self, num_samples: usize) -> Result<Array2<f32>, String> {
        let _decoder = self.decoder.as_ref().ok_or("Decoder not initialized")?;

        // Draw latent codes uniformly from [-1, 1]; note this is a heuristic
        // rather than sampling from the N(0, 1) prior the KL term assumes.
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let latent_samples = Array2::from_shape_fn((num_samples, self.latent_dim), |_| {
            rng.gen::<f32>() * 2.0 - 1.0
        });

        self.decode(&latent_samples)
    }

    pub fn evaluate_reconstruction(
        &self,
        vectors: &Array2<f32>,
    ) -> Result<HashMap<String, f32>, String> {
        let input_tensor = self.array_to_tensor(vectors)?;
        let (reconstruction, _mean, _logvar) = self
            .forward(&input_tensor)
            .map_err(|e| format!("Forward pass failed: {e}"))?;

        let reconstruction_array = self.tensor_to_array(&reconstruction)?;

        let mut metrics = HashMap::new();

        let mse = vectors
            .iter()
            .zip(reconstruction_array.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            / (vectors.len() as f32);
        metrics.insert("mse".to_string(), mse);
        metrics.insert("rmse".to_string(), mse.sqrt());

        let mae = vectors
            .iter()
            .zip(reconstruction_array.iter())
            .map(|(a, b)| (a - b).abs())
            .sum::<f32>()
            / (vectors.len() as f32);
        metrics.insert("mae".to_string(), mae);

        let compression_ratio = self.input_dim as f32 / self.latent_dim as f32;
        metrics.insert("compression_ratio".to_string(), compression_ratio);

        Ok(metrics)
    }

    pub fn latent_dim(&self) -> usize {
        self.latent_dim
    }

    pub fn is_initialized(&self) -> bool {
        self.encoder.is_some() && self.decoder.is_some() && self.optimizer.is_some()
    }

    fn array_to_tensor(&self, array: &Array2<f32>) -> Result<Tensor, String> {
        let shape = array.shape();
        let data: Vec<f32> = array.iter().cloned().collect();
        Tensor::from_vec(data, (shape[0], shape[1]), &self.device)
            .map_err(|e| format!("Tensor creation failed: {e}"))
    }

    fn tensor_to_array(&self, tensor: &Tensor) -> Result<Array2<f32>, String> {
        let shape = tensor.shape();
        if shape.dims().len() != 2 {
            return Err("Expected 2D tensor".to_string());
        }

        let data = tensor
            .to_vec2::<f32>()
            .map_err(|e| format!("Tensor conversion failed: {e}"))?;

        Array2::from_shape_vec((shape.dims()[0], shape.dims()[1]), data.concat())
            .map_err(|e| format!("Array reshape failed: {e}"))
    }
}

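/// Hyperparameters for building and training the VAE. Note that `init_network`
/// currently hardcodes its AdamW learning rate at 0.0005, so `learning_rate`,
/// `batch_size`, and `epochs` are presumably consumed by the surrounding
/// training loop rather than by this module.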
#[derive(Debug, Clone)]
pub struct VAEConfig {
    pub hidden_dims: Vec<usize>,
    pub beta: f32,
    pub learning_rate: f32,
    pub batch_size: usize,
    pub epochs: usize,
}

impl Default for VAEConfig {
    fn default() -> Self {
        Self {
            hidden_dims: vec![512, 256, 128],
            beta: 1.0,
            learning_rate: 0.0005,
            batch_size: 32,
            epochs: 100,
        }
    }
}

impl VAEConfig {
    pub fn beta_vae(beta: f32) -> Self {
        Self {
            beta,
            ..Default::default()
        }
    }

    pub fn high_compression() -> Self {
        Self {
            hidden_dims: vec![512, 256, 128, 64],
            beta: 0.5,
            ..Default::default()
        }
    }

    pub fn chess_optimized() -> Self {
        Self {
            hidden_dims: vec![512, 256, 128],
            beta: 0.8,
            learning_rate: 0.001,
            batch_size: 64,
            epochs: 150,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_vae_initialization() {
        let mut vae = VariationalAutoencoder::new(1024, 128, 1.0);
        let config = VAEConfig::default();

        assert!(vae.init_network(&config.hidden_dims).is_ok());
        assert!(vae.is_initialized());
        assert_eq!(vae.latent_dim(), 128);
    }

    #[test]
    fn test_vae_forward_pass() {
        let mut vae = VariationalAutoencoder::new(64, 16, 1.0);
        let config = VAEConfig::default();
        vae.init_network(&config.hidden_dims).unwrap();

        let test_data = Array2::from_shape_fn((4, 64), |_| 0.5);
        let result = vae.encode_with_sampling(&test_data);

        assert!(result.is_ok());
        let encoded = result.unwrap();
        assert_eq!(encoded.shape(), &[4, 16]);
    }

    #[test]
    fn test_vae_reconstruction() {
        let mut vae = VariationalAutoencoder::new(32, 8, 1.0);
        let config = VAEConfig::default();
        vae.init_network(&config.hidden_dims).unwrap();

        let test_data = Array2::from_shape_fn((2, 32), |_| 0.3);
        let encoded = vae.encode_with_sampling(&test_data).unwrap();
        let decoded = vae.decode(&encoded).unwrap();

        assert_eq!(decoded.shape(), test_data.shape());
    }

    #[test]
    fn test_vae_generation() {
        let mut vae = VariationalAutoencoder::new(16, 4, 1.0);
        let config = VAEConfig::default();
        vae.init_network(&config.hidden_dims).unwrap();

        let generated = vae.generate(3);
        assert!(generated.is_ok());

        let samples = generated.unwrap();
        assert_eq!(samples.shape(), &[3, 16]);
    }
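
    // Training smoke test on synthetic data: a single AdamW step on CPU should
    // succeed and yield a finite loss value.
    #[test]
    fn test_vae_train_step() {
        let mut vae = VariationalAutoencoder::new(32, 8, 1.0);
        let config = VAEConfig::default();
        vae.init_network(&config.hidden_dims).unwrap();

        let batch = Array2::from_shape_fn((4, 32), |(i, j)| ((i + j) % 7) as f32 / 7.0);
        let loss = vae.train_step(&batch).unwrap();
        assert!(loss.is_finite());
    }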
514 }
515}