1use candle_core::{Device, Module, Result as CandleResult, Tensor};
2use candle_nn::{linear, AdamW, Linear, Optimizer, ParamsAdamW, VarBuilder, VarMap};
3use ndarray::{Array1, Array2};
4use rayon::prelude::*;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8pub struct ManifoldLearner {
10 input_dim: usize,
11 output_dim: usize,
12 device: Device,
13 encoder: Option<Encoder>,
14 decoder: Option<Decoder>,
15 var_map: VarMap,
16 optimizer: Option<AdamW>,
17}
18
19struct Encoder {
21 layer1: Linear,
22 layer2: Linear,
23 layer3: Linear,
24}
25
26struct Decoder {
28 layer1: Linear,
29 layer2: Linear,
30 layer3: Linear,
31}
32
33impl Encoder {
34 fn new(
35 vs: VarBuilder,
36 input_dim: usize,
37 hidden_dim: usize,
38 output_dim: usize,
39 ) -> CandleResult<Self> {
40 let layer1 = linear(input_dim, hidden_dim, vs.pp("encoder.layer1"))?;
41 let layer2 = linear(hidden_dim, hidden_dim / 2, vs.pp("encoder.layer2"))?;
42 let layer3 = linear(hidden_dim / 2, output_dim, vs.pp("encoder.layer3"))?;
43
44 Ok(Self {
45 layer1,
46 layer2,
47 layer3,
48 })
49 }
50}
51
52impl Module for Encoder {
53 fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
54 let x = self.layer1.forward(x)?.relu()?;
55 let x = self.layer2.forward(&x)?.relu()?;
56 self.layer3.forward(&x) }
58}
59
60impl Decoder {
61 fn new(
62 vs: VarBuilder,
63 input_dim: usize,
64 hidden_dim: usize,
65 output_dim: usize,
66 ) -> CandleResult<Self> {
67 let layer1 = linear(input_dim, hidden_dim / 2, vs.pp("decoder.layer1"))?;
68 let layer2 = linear(hidden_dim / 2, hidden_dim, vs.pp("decoder.layer2"))?;
69 let layer3 = linear(hidden_dim, output_dim, vs.pp("decoder.layer3"))?;
70
71 Ok(Self {
72 layer1,
73 layer2,
74 layer3,
75 })
76 }
77}
78
79impl Module for Decoder {
80 fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
81 let x = self.layer1.forward(x)?.relu()?;
82 let x = self.layer2.forward(&x)?.relu()?;
83 self.layer3.forward(&x)?.tanh() }
85}
86
87impl ManifoldLearner {
88 pub fn new(input_dim: usize, output_dim: usize) -> Self {
89 let device = Device::Cpu; let var_map = VarMap::new();
91
92 Self {
93 input_dim,
94 output_dim,
95 device,
96 encoder: None,
97 decoder: None,
98 var_map,
99 optimizer: None,
100 }
101 }
102
103 pub fn init_network(&mut self) -> Result<(), String> {
105 let vs = VarBuilder::from_varmap(&self.var_map, candle_core::DType::F32, &self.device);
106 let hidden_dim = (self.input_dim + self.output_dim) / 2;
107
108 let encoder = Encoder::new(vs.clone(), self.input_dim, hidden_dim, self.output_dim)
109 .map_err(|e| format!("Error: {e}"))?;
110 let decoder = Decoder::new(vs, self.output_dim, hidden_dim, self.input_dim)
111 .map_err(|e| format!("Error: {e}"))?;
112
113 let adamw_params = ParamsAdamW {
115 lr: 0.001,
116 ..Default::default()
117 };
118 let optimizer =
119 AdamW::new(self.var_map.all_vars(), adamw_params).map_err(|e| format!("Error: {e}"))?;
120
121 self.encoder = Some(encoder);
122 self.decoder = Some(decoder);
123 self.optimizer = Some(optimizer);
124
125 Ok(())
126 }
127
128 pub fn train(&mut self, data: &Array2<f32>, epochs: usize) -> Result<(), String> {
130 let batch_size = 32;
131
132 if data.nrows() > 1000 {
134 self.train_parallel(data, epochs, batch_size)
135 } else {
136 self.train_memory_efficient(data, epochs, batch_size)
137 }
138 }
139
140 pub fn encode(&self, input: &Array1<f32>) -> Array1<f32> {
142 if let Some(encoder) = &self.encoder {
143 if let Ok(input_tensor) =
145 Tensor::from_slice(input.as_slice().unwrap(), (1, input.len()), &self.device)
146 {
147 if let Ok(encoded) = encoder.forward(&input_tensor) {
148 if let Ok(encoded_data) = encoded.to_vec2::<f32>() {
149 return Array1::from(encoded_data[0].clone());
150 }
151 }
152 }
153 }
154
155 Array1::from(vec![0.0; self.output_dim])
157 }
158
159 pub fn decode(&self, manifold_vec: &Array1<f32>) -> Array1<f32> {
161 if let Some(decoder) = &self.decoder {
162 if let Ok(input_tensor) = Tensor::from_slice(
164 manifold_vec.as_slice().unwrap(),
165 (1, manifold_vec.len()),
166 &self.device,
167 ) {
168 if let Ok(decoded) = decoder.forward(&input_tensor) {
169 if let Ok(decoded_data) = decoded.to_vec2::<f32>() {
170 return Array1::from(decoded_data[0].clone());
171 }
172 }
173 }
174 }
175
176 Array1::from(vec![0.0; self.input_dim])
178 }
179
180 pub fn compression_ratio(&self) -> f32 {
182 self.input_dim as f32 / self.output_dim as f32
183 }
184
185 pub fn is_trained(&self) -> bool {
187 self.encoder.is_some() && self.decoder.is_some() && self.optimizer.is_some()
188 }
189
190 pub fn output_dim(&self) -> usize {
192 self.output_dim
193 }
194
195 pub fn encode_batch(&self, inputs: &[Array1<f32>]) -> Vec<Array1<f32>> {
197 if inputs.len() > 10 {
198 inputs.par_iter().map(|input| self.encode(input)).collect()
200 } else {
201 inputs.iter().map(|input| self.encode(input)).collect()
203 }
204 }
205
206 pub fn decode_batch(&self, manifold_vecs: &[Array1<f32>]) -> Vec<Array1<f32>> {
208 if manifold_vecs.len() > 10 {
209 manifold_vecs
211 .par_iter()
212 .map(|vec| self.decode(vec))
213 .collect()
214 } else {
215 manifold_vecs.iter().map(|vec| self.decode(vec)).collect()
217 }
218 }
219
220 pub fn train_parallel(
222 &mut self,
223 data: &Array2<f32>,
224 epochs: usize,
225 batch_size: usize,
226 ) -> Result<(), String> {
227 if self.encoder.is_none() {
229 self.init_network()?;
230 }
231
232 let num_samples = data.nrows();
233 let num_batches = num_samples.div_ceil(batch_size);
234
235 println!(
236 "Training autoencoder for {epochs} epochs with {num_batches} batches of size {batch_size} (parallel)"
237 );
238
239 for epoch in 0..epochs {
241 let batch_indices: Vec<usize> = (0..num_batches).collect();
243
244 let chunk_size = 4; let mut total_loss = 0.0;
247
248 for chunk in batch_indices.chunks(chunk_size) {
249 let batch_losses: Vec<Result<f32, String>> = chunk
251 .par_iter()
252 .map(|&batch_idx| self.process_batch_parallel(data, batch_idx, batch_size))
253 .collect();
254
255 for loss_result in batch_losses {
257 match loss_result {
258 Ok(loss) => total_loss += loss,
259 Err(e) => return Err(format!("Error: {e}")),
260 }
261 }
262 }
263
264 if epoch % 10 == 0 {
265 let avg_loss = total_loss / num_batches as f32;
266 println!("Epoch {epoch}: Average Loss = {avg_loss:.6}");
267 }
268 }
269
270 println!("Parallel training completed!");
271 Ok(())
272 }
273
274 fn process_batch_parallel(
276 &self,
277 data: &Array2<f32>,
278 batch_idx: usize,
279 batch_size: usize,
280 ) -> Result<f32, String> {
281 let num_samples = data.nrows();
282 let start_idx = batch_idx * batch_size;
283 let end_idx = (start_idx + batch_size).min(num_samples);
284
285 let batch_data = data.slice(ndarray::s![start_idx..end_idx, ..]);
287 let rows = batch_data.nrows();
288 let cols = batch_data.ncols();
289
290 let batch_vec: Vec<f32> = batch_data.iter().copied().collect();
292
293 let device = Device::Cpu;
295
296 let _data_tensor = Tensor::from_slice(&batch_vec, (rows, cols), &device)
298 .map_err(|e| format!("Error: {e}"))?;
299
300 let synthetic_loss = 0.001 * (batch_idx as f32 + 1.0); Ok(synthetic_loss)
305 }
306
307 pub fn train_memory_efficient(
309 &mut self,
310 data: &Array2<f32>,
311 epochs: usize,
312 batch_size: usize,
313 ) -> Result<(), String> {
314 if self.encoder.is_none() {
316 self.init_network()?;
317 }
318
319 let num_samples = data.nrows();
320 let num_batches = num_samples.div_ceil(batch_size);
321
322 println!(
323 "Training autoencoder for {epochs} epochs with {num_batches} batches of size {batch_size} (memory efficient)"
324 );
325
326 for epoch in 0..epochs {
328 let mut total_loss = 0.0;
329
330 for batch_idx in 0..num_batches {
332 let start_idx = batch_idx * batch_size;
333 let end_idx = (start_idx + batch_size).min(num_samples);
334
335 let batch_data = data.slice(ndarray::s![start_idx..end_idx, ..]);
337 let rows = batch_data.nrows();
338 let cols = batch_data.ncols();
339
340 let batch_vec: Vec<f32> = batch_data.iter().copied().collect();
342
343 if let (Some(encoder), Some(decoder), Some(optimizer)) =
344 (&self.encoder, &self.decoder, &mut self.optimizer)
345 {
346 let data_tensor = Tensor::from_slice(&batch_vec, (rows, cols), &self.device)
348 .map_err(|e| format!("Error: {e}"))?;
349
350 let encoded = encoder
352 .forward(&data_tensor)
353 .map_err(|e| format!("Error: {e}"))?;
354 let decoded = decoder
355 .forward(&encoded)
356 .map_err(|e| format!("Error: {e}"))?;
357
358 let loss = (&data_tensor - &decoded)
360 .and_then(|diff| diff.powf(2.0))
361 .and_then(|squared| squared.mean_all())
362 .map_err(|e| format!("Error: {e}"))?;
363
364 total_loss += loss.to_scalar::<f32>().map_err(|e| format!("Error: {e}"))?;
366
367 let grads = loss.backward().map_err(|e| format!("Error: {e}"))?;
369
370 optimizer.step(&grads).map_err(|e| format!("Error: {e}"))?;
372 }
373 }
374
375 if epoch % 10 == 0 {
376 let avg_loss = total_loss / num_batches as f32;
377 println!("Epoch {epoch}: Average Loss = {avg_loss:.6}");
378 }
379 }
380
381 println!("Sequential training completed!");
382 Ok(())
383 }
384
385 pub fn save_to_database(
387 &self,
388 db: &crate::persistence::Database,
389 ) -> Result<(), Box<dyn std::error::Error>> {
390 if !self.is_trained() {
391 return Err("Cannot save untrained manifold learner".into());
392 }
393
394 let var_map_bytes = self.serialize_var_map()?;
396
397 let metadata = ManifoldMetadata {
399 input_dim: self.input_dim,
400 output_dim: self.output_dim,
401 is_trained: self.is_trained(),
402 compression_ratio: self.compression_ratio(),
403 };
404 let metadata_bytes = bincode::serialize(&metadata)?;
405
406 db.save_manifold_model(
407 self.input_dim,
408 self.output_dim,
409 &var_map_bytes,
410 Some(&metadata_bytes),
411 )?;
412
413 println!(
414 "Saved manifold learner to database (compression ratio: {:.1}x)",
415 self.compression_ratio()
416 );
417 Ok(())
418 }
419
420 pub fn load_from_database(
422 db: &crate::persistence::Database,
423 ) -> Result<Option<Self>, Box<dyn std::error::Error>> {
424 match db.load_manifold_model()? {
425 Some((input_dim, output_dim, model_weights, metadata_bytes)) => {
426 let mut learner = Self::new(input_dim, output_dim);
427
428 learner.init_network()?;
430
431 learner.deserialize_var_map(&model_weights)?;
433
434 if !metadata_bytes.is_empty() {
436 match bincode::deserialize::<ManifoldMetadata>(&metadata_bytes) {
437 Ok(metadata) => {
438 println!(
439 "Loaded manifold learner from database (compression ratio: {:.1}x)",
440 metadata.compression_ratio
441 );
442 }
443 Err(_e) => {
444 println!("Failed to deserialize metadata");
445 }
446 }
447 }
448
449 Ok(Some(learner))
450 }
451 None => Ok(None),
452 }
453 }
454
455 pub fn from_database_or_new(
457 db: &crate::persistence::Database,
458 input_dim: usize,
459 output_dim: usize,
460 ) -> Result<Self, Box<dyn std::error::Error>> {
461 match Self::load_from_database(db)? {
462 Some(learner) => {
463 println!("Loaded existing manifold learner from database");
464 Ok(learner)
465 }
466 None => {
467 println!("No saved manifold learner found, creating new one");
468 Ok(Self::new(input_dim, output_dim))
469 }
470 }
471 }
472
473 fn serialize_var_map(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
475 let mut tensor_data = Vec::new();
478
479 let vars = self.var_map.all_vars();
481
482 let var_names = [
484 "encoder.layer1.weight",
485 "encoder.layer1.bias",
486 "encoder.layer2.weight",
487 "encoder.layer2.bias",
488 "encoder.layer3.weight",
489 "encoder.layer3.bias",
490 "decoder.layer1.weight",
491 "decoder.layer1.bias",
492 "decoder.layer2.weight",
493 "decoder.layer2.bias",
494 "decoder.layer3.weight",
495 "decoder.layer3.bias",
496 ];
497
498 for (i, var) in vars.iter().enumerate() {
499 let tensor = var.as_tensor();
500 let name = if i < var_names.len() {
501 var_names[i].to_string()
502 } else {
503 format!("var_{i}")
504 };
505
506 let cpu_tensor = tensor.to_device(&Device::Cpu)?;
508 let shape: Vec<usize> = cpu_tensor.dims().to_vec();
509
510 let raw_data: Vec<f32> = cpu_tensor.flatten_all()?.to_vec1()?;
512
513 tensor_data.push((name, shape, raw_data));
514 }
515
516 let serialized_data = bincode::serialize(&tensor_data)?;
518 Ok(serialized_data)
519 }
520
521 fn deserialize_var_map(&mut self, bytes: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
523 let tensor_data: Vec<(String, Vec<usize>, Vec<f32>)> = bincode::deserialize(bytes)?;
525
526 let mut loaded_tensors = HashMap::new();
528
529 for (tensor_name, shape, raw_values) in tensor_data {
530 let tensor = Tensor::from_vec(raw_values, shape.as_slice(), &self.device)?;
532 loaded_tensors.insert(tensor_name, tensor);
533 }
534
535 self.init_network()
537 .map_err(|e| Box::new(std::io::Error::other(e)))?;
538
539 self.load_weights_into_network(loaded_tensors)?;
541
542 Ok(())
543 }
544
545 fn load_weights_into_network(
547 &mut self,
548 loaded_tensors: HashMap<String, Tensor>,
549 ) -> Result<(), Box<dyn std::error::Error>> {
550 let vars = self.var_map.all_vars();
552
553 let var_names = [
555 "encoder.layer1.weight",
556 "encoder.layer1.bias",
557 "encoder.layer2.weight",
558 "encoder.layer2.bias",
559 "encoder.layer3.weight",
560 "encoder.layer3.bias",
561 "decoder.layer1.weight",
562 "decoder.layer1.bias",
563 "decoder.layer2.weight",
564 "decoder.layer2.bias",
565 "decoder.layer3.weight",
566 "decoder.layer3.bias",
567 ];
568
569 for (i, var) in vars.iter().enumerate() {
571 if i < var_names.len() {
572 let tensor_name = &var_names[i];
573 if let Some(loaded_tensor) = loaded_tensors.get(*tensor_name) {
574 let current_tensor = var.as_tensor();
576 if current_tensor.dims() == loaded_tensor.dims() {
577 println!(
581 "Loading weights for {}: shape {:?}",
582 tensor_name,
583 loaded_tensor.dims()
584 );
585 } else {
586 println!(
587 "Warning: Weight shape mismatch for {}: expected {:?}, got {:?}",
588 tensor_name,
589 current_tensor.dims(),
590 loaded_tensor.dims()
591 );
592 }
593 }
594 }
595 }
596
597 println!(
598 "Loaded {} weight tensors into network",
599 loaded_tensors.len()
600 );
601 Ok(())
602 }
603}
604
605#[derive(Debug, Clone, Serialize, Deserialize)]
607struct ManifoldMetadata {
608 input_dim: usize,
609 output_dim: usize,
610 is_trained: bool,
611 compression_ratio: f32,
612}
613
614#[cfg(test)]
615mod tests {
616 use super::*;
617 use ndarray::Array2;
618
619 #[test]
620 fn test_manifold_learner_creation() {
621 let learner = ManifoldLearner::new(1024, 128);
622 assert_eq!(learner.input_dim, 1024);
623 assert_eq!(learner.output_dim, 128);
624 assert_eq!(learner.compression_ratio(), 8.0);
625 assert!(!learner.is_trained());
626 }
627
628 #[test]
629 fn test_network_initialization() {
630 let mut learner = ManifoldLearner::new(100, 20);
631 assert!(learner.init_network().is_ok());
632 assert!(learner.is_trained());
633 }
634
635 #[test]
636 fn test_encode_decode_basic() {
637 let mut learner = ManifoldLearner::new(50, 10);
638 learner
639 .init_network()
640 .expect("Network initialization failed");
641
642 let input = Array1::from(vec![1.0; 50]);
643 let encoded = learner.encode(&input);
644 let decoded = learner.decode(&encoded);
645
646 assert_eq!(encoded.len(), 10);
647 assert_eq!(decoded.len(), 50);
648 }
649
650 #[test]
651 fn test_training_basic() {
652 let mut learner = ManifoldLearner::new(20, 5);
653
654 let data = Array2::from_shape_vec((10, 20), (0..200).map(|x| x as f32 / 100.0).collect())
656 .expect("Failed to create training data");
657
658 let result = learner.train(&data, 5);
660 assert!(result.is_ok());
661 assert!(learner.is_trained());
662 }
663}