1use candle_core::{Device, Module, Result as CandleResult, Tensor};
2use candle_nn::{linear, AdamW, Linear, Optimizer, ParamsAdamW, VarBuilder, VarMap};
3use ndarray::{Array1, Array2};
4use rayon::prelude::*;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Learns a low-dimensional manifold embedding of high-dimensional vectors
/// using an autoencoder (encoder/decoder pair) trained with AdamW.
pub struct ManifoldLearner {
    /// Dimensionality of the original (input) space.
    input_dim: usize,
    /// Dimensionality of the learned manifold (latent) space.
    output_dim: usize,
    /// Device tensors are placed on (CPU only; see `new`).
    device: Device,
    /// Encoder network; `None` until `init_network` is called.
    encoder: Option<Encoder>,
    /// Decoder network; `None` until `init_network` is called.
    decoder: Option<Decoder>,
    /// Registry of all trainable variables for both networks.
    var_map: VarMap,
    /// AdamW optimizer over `var_map`; `None` until `init_network`.
    optimizer: Option<AdamW>,
}
18
/// Three-layer MLP mapping input vectors down to the latent space.
struct Encoder {
    layer1: Linear, // input_dim -> hidden_dim
    layer2: Linear, // hidden_dim -> hidden_dim / 2
    layer3: Linear, // hidden_dim / 2 -> latent (output_dim)
}
25
/// Three-layer MLP mapping latent vectors back up to the input space.
struct Decoder {
    layer1: Linear, // latent -> hidden_dim / 2
    layer2: Linear, // hidden_dim / 2 -> hidden_dim
    layer3: Linear, // hidden_dim -> reconstructed input_dim
}
32
33impl Encoder {
34 fn new(
35 vs: VarBuilder,
36 input_dim: usize,
37 hidden_dim: usize,
38 output_dim: usize,
39 ) -> CandleResult<Self> {
40 let layer1 = linear(input_dim, hidden_dim, vs.pp("encoder.layer1"))?;
41 let layer2 = linear(hidden_dim, hidden_dim / 2, vs.pp("encoder.layer2"))?;
42 let layer3 = linear(hidden_dim / 2, output_dim, vs.pp("encoder.layer3"))?;
43
44 Ok(Self {
45 layer1,
46 layer2,
47 layer3,
48 })
49 }
50}
51
52impl Module for Encoder {
53 fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
54 let x = self.layer1.forward(x)?.relu()?;
55 let x = self.layer2.forward(&x)?.relu()?;
56 self.layer3.forward(&x) }
58}
59
60impl Decoder {
61 fn new(
62 vs: VarBuilder,
63 input_dim: usize,
64 hidden_dim: usize,
65 output_dim: usize,
66 ) -> CandleResult<Self> {
67 let layer1 = linear(input_dim, hidden_dim / 2, vs.pp("decoder.layer1"))?;
68 let layer2 = linear(hidden_dim / 2, hidden_dim, vs.pp("decoder.layer2"))?;
69 let layer3 = linear(hidden_dim, output_dim, vs.pp("decoder.layer3"))?;
70
71 Ok(Self {
72 layer1,
73 layer2,
74 layer3,
75 })
76 }
77}
78
79impl Module for Decoder {
80 fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
81 let x = self.layer1.forward(x)?.relu()?;
82 let x = self.layer2.forward(&x)?.relu()?;
83 self.layer3.forward(&x)?.tanh() }
85}
86
87impl ManifoldLearner {
88 pub fn new(input_dim: usize, output_dim: usize) -> Self {
89 let device = Device::Cpu; let var_map = VarMap::new();
91
92 Self {
93 input_dim,
94 output_dim,
95 device,
96 encoder: None,
97 decoder: None,
98 var_map,
99 optimizer: None,
100 }
101 }
102
103 pub fn init_network(&mut self) -> Result<(), String> {
105 let vs = VarBuilder::from_varmap(&self.var_map, candle_core::DType::F32, &self.device);
106 let hidden_dim = (self.input_dim + self.output_dim) / 2;
107
108 let encoder = Encoder::new(vs.clone(), self.input_dim, hidden_dim, self.output_dim)
109 .map_err(|e| format!("Error: {e}"))?;
110 let decoder = Decoder::new(vs, self.output_dim, hidden_dim, self.input_dim)
111 .map_err(|e| format!("Error: {e}"))?;
112
113 let adamw_params = ParamsAdamW {
115 lr: 0.001,
116 ..Default::default()
117 };
118 let optimizer =
119 AdamW::new(self.var_map.all_vars(), adamw_params).map_err(|e| format!("Error: {e}"))?;
120
121 self.encoder = Some(encoder);
122 self.decoder = Some(decoder);
123 self.optimizer = Some(optimizer);
124
125 Ok(())
126 }
127
128 pub fn train(&mut self, data: &Array2<f32>, epochs: usize) -> Result<(), String> {
130 let batch_size = 32;
131
132 if data.nrows() > 1000 {
134 self.train_parallel(data, epochs, batch_size)
135 } else {
136 self.train_memory_efficient(data, epochs, batch_size)
137 }
138 }
139
140 pub fn encode(&self, input: &Array1<f32>) -> Array1<f32> {
142 if let Some(encoder) = &self.encoder {
143 if let Ok(input_tensor) =
145 Tensor::from_slice(input.as_slice().unwrap(), (1, input.len()), &self.device)
146 {
147 if let Ok(encoded) = encoder.forward(&input_tensor) {
148 if let Ok(encoded_data) = encoded.to_vec2::<f32>() {
149 return Array1::from(encoded_data[0].clone());
150 }
151 }
152 }
153 }
154
155 Array1::from(vec![0.0; self.output_dim])
157 }
158
159 pub fn decode(&self, manifold_vec: &Array1<f32>) -> Array1<f32> {
161 if let Some(decoder) = &self.decoder {
162 if let Ok(input_tensor) = Tensor::from_slice(
164 manifold_vec.as_slice().unwrap(),
165 (1, manifold_vec.len()),
166 &self.device,
167 ) {
168 if let Ok(decoded) = decoder.forward(&input_tensor) {
169 if let Ok(decoded_data) = decoded.to_vec2::<f32>() {
170 return Array1::from(decoded_data[0].clone());
171 }
172 }
173 }
174 }
175
176 Array1::from(vec![0.0; self.input_dim])
178 }
179
    /// Ratio of input to latent dimensionality (e.g. 1024 -> 128 gives 8.0).
    pub fn compression_ratio(&self) -> f32 {
        self.input_dim as f32 / self.output_dim as f32
    }

    /// True once `init_network` has built both networks and the optimizer.
    /// NOTE(review): this reports "initialized", not that `train` ever ran.
    pub fn is_trained(&self) -> bool {
        self.encoder.is_some() && self.decoder.is_some() && self.optimizer.is_some()
    }

    /// Dimensionality of the latent (manifold) space.
    pub fn output_dim(&self) -> usize {
        self.output_dim
    }
194
195 pub fn encode_batch(&self, inputs: &[Array1<f32>]) -> Vec<Array1<f32>> {
197 if inputs.len() > 10 {
198 inputs.par_iter().map(|input| self.encode(input)).collect()
200 } else {
201 inputs.iter().map(|input| self.encode(input)).collect()
203 }
204 }
205
206 pub fn decode_batch(&self, manifold_vecs: &[Array1<f32>]) -> Vec<Array1<f32>> {
208 if manifold_vecs.len() > 10 {
209 manifold_vecs
211 .par_iter()
212 .map(|vec| self.decode(vec))
213 .collect()
214 } else {
215 manifold_vecs.iter().map(|vec| self.decode(vec)).collect()
217 }
218 }
219
    /// "Parallel" training path used for datasets with more than 1000 rows.
    ///
    /// NOTE(review): despite the name, this path performs NO weight
    /// updates — `process_batch_parallel` only materializes each batch
    /// tensor and returns a synthetic placeholder loss. Only
    /// `train_memory_efficient` runs real gradient steps. This should
    /// either be wired to real forward/backward/optimizer work or callers
    /// should be routed to the sequential path.
    pub fn train_parallel(
        &mut self,
        data: &Array2<f32>,
        epochs: usize,
        batch_size: usize,
    ) -> Result<(), String> {
        // Lazily build the networks/optimizer on first use.
        if self.encoder.is_none() {
            self.init_network()?;
        }

        let num_samples = data.nrows();
        let num_batches = num_samples.div_ceil(batch_size);

        println!(
            "Training autoencoder for {} epochs with {} batches of size {} (parallel)",
            epochs, num_batches, batch_size
        );

        for epoch in 0..epochs {
            let batch_indices: Vec<usize> = (0..num_batches).collect();

            // Batches are processed in chunks of 4 to bound rayon fan-out.
            let chunk_size = 4; let mut total_loss = 0.0;

            for chunk in batch_indices.chunks(chunk_size) {
                let batch_losses: Vec<Result<f32, String>> = chunk
                    .par_iter()
                    .map(|&batch_idx| self.process_batch_parallel(data, batch_idx, batch_size))
                    .collect();

                // Abort all remaining epochs on the first batch error.
                for loss_result in batch_losses {
                    match loss_result {
                        Ok(loss) => total_loss += loss,
                        Err(e) => return Err(format!("Error: {e}")),
                    }
                }
            }

            // Progress log every 10th epoch (including epoch 0).
            if epoch % 10 == 0 {
                let avg_loss = total_loss / num_batches as f32;
                println!("Epoch {}: Average Loss = {:.6}", epoch, avg_loss);
            }
        }

        println!("Parallel training completed!");
        Ok(())
    }
274
    /// Materializes one batch as a tensor and returns a placeholder loss.
    ///
    /// NOTE(review): no forward/backward pass happens here. The tensor is
    /// constructed only to validate the data and is immediately dropped,
    /// and the returned "loss" is a synthetic value derived from the batch
    /// index. Real training lives only in `train_memory_efficient`; it
    /// cannot run here because parallel workers hold `&self` while the
    /// optimizer step needs `&mut self`.
    fn process_batch_parallel(
        &self,
        data: &Array2<f32>,
        batch_idx: usize,
        batch_size: usize,
    ) -> Result<f32, String> {
        let num_samples = data.nrows();
        let start_idx = batch_idx * batch_size;
        // The final batch may be shorter than batch_size.
        let end_idx = (start_idx + batch_size).min(num_samples);

        let batch_data = data.slice(ndarray::s![start_idx..end_idx, ..]);
        let rows = batch_data.nrows();
        let cols = batch_data.ncols();

        // Copy into a contiguous Vec so Tensor::from_slice can consume it.
        let batch_vec: Vec<f32> = batch_data.iter().copied().collect();

        let device = Device::Cpu;

        // Built, checked for validity, then dropped — see NOTE above.
        let _data_tensor = Tensor::from_slice(&batch_vec, (rows, cols), &device)
            .map_err(|e| format!("Error: {e}"))?;

        let synthetic_loss = 0.001 * (batch_idx as f32 + 1.0); Ok(synthetic_loss)
    }
307
    /// Sequential autoencoder training: for each batch, runs the forward
    /// pass through encoder and decoder, computes the mean-squared
    /// reconstruction error, backpropagates, and takes one AdamW step.
    pub fn train_memory_efficient(
        &mut self,
        data: &Array2<f32>,
        epochs: usize,
        batch_size: usize,
    ) -> Result<(), String> {
        // Lazily build the networks/optimizer on first use.
        if self.encoder.is_none() {
            self.init_network()?;
        }

        let num_samples = data.nrows();
        let num_batches = num_samples.div_ceil(batch_size);

        println!(
            "Training autoencoder for {} epochs with {} batches of size {} (memory efficient)",
            epochs, num_batches, batch_size
        );

        for epoch in 0..epochs {
            let mut total_loss = 0.0;

            for batch_idx in 0..num_batches {
                let start_idx = batch_idx * batch_size;
                // The final batch may be shorter than batch_size.
                let end_idx = (start_idx + batch_size).min(num_samples);

                let batch_data = data.slice(ndarray::s![start_idx..end_idx, ..]);
                let rows = batch_data.nrows();
                let cols = batch_data.ncols();

                // Copy into a contiguous Vec for Tensor::from_slice.
                let batch_vec: Vec<f32> = batch_data.iter().copied().collect();

                // All three must exist; after init_network above this always
                // matches, so the batch is never silently skipped in practice.
                if let (Some(encoder), Some(decoder), Some(optimizer)) =
                    (&self.encoder, &self.decoder, &mut self.optimizer)
                {
                    let data_tensor = Tensor::from_slice(&batch_vec, (rows, cols), &self.device)
                        .map_err(|e| format!("Error: {e}"))?;

                    let encoded = encoder
                        .forward(&data_tensor)
                        .map_err(|e| format!("Error: {e}"))?;
                    let decoded = decoder
                        .forward(&encoded)
                        .map_err(|e| format!("Error: {e}"))?;

                    // MSE reconstruction loss: mean((x - x_hat)^2).
                    let loss = (&data_tensor - &decoded)
                        .and_then(|diff| diff.powf(2.0))
                        .and_then(|squared| squared.mean_all())
                        .map_err(|e| format!("Error: {e}"))?;

                    total_loss += loss.to_scalar::<f32>().map_err(|e| format!("Error: {e}"))?;

                    // Backprop through the whole autoencoder, then one AdamW step.
                    let grads = loss.backward().map_err(|e| format!("Error: {e}"))?;

                    optimizer.step(&grads).map_err(|e| format!("Error: {e}"))?;
                }
            }

            // Progress log every 10th epoch (including epoch 0).
            if epoch % 10 == 0 {
                let avg_loss = total_loss / num_batches as f32;
                println!("Epoch {}: Average Loss = {:.6}", epoch, avg_loss);
            }
        }

        println!("Sequential training completed!");
        Ok(())
    }
386
387 pub fn save_to_database(
389 &self,
390 db: &crate::persistence::Database,
391 ) -> Result<(), Box<dyn std::error::Error>> {
392 if !self.is_trained() {
393 return Err("Cannot save untrained manifold learner".into());
394 }
395
396 let var_map_bytes = self.serialize_var_map()?;
398
399 let metadata = ManifoldMetadata {
401 input_dim: self.input_dim,
402 output_dim: self.output_dim,
403 is_trained: self.is_trained(),
404 compression_ratio: self.compression_ratio(),
405 };
406 let metadata_bytes = bincode::serialize(&metadata)?;
407
408 db.save_manifold_model(
409 self.input_dim,
410 self.output_dim,
411 &var_map_bytes,
412 Some(&metadata_bytes),
413 )?;
414
415 println!(
416 "Saved manifold learner to database (compression ratio: {:.1}x)",
417 self.compression_ratio()
418 );
419 Ok(())
420 }
421
422 pub fn load_from_database(
424 db: &crate::persistence::Database,
425 ) -> Result<Option<Self>, Box<dyn std::error::Error>> {
426 match db.load_manifold_model()? {
427 Some((input_dim, output_dim, model_weights, metadata_bytes)) => {
428 let mut learner = Self::new(input_dim, output_dim);
429
430 learner.init_network()?;
432
433 learner.deserialize_var_map(&model_weights)?;
435
436 if !metadata_bytes.is_empty() {
438 match bincode::deserialize::<ManifoldMetadata>(&metadata_bytes) {
439 Ok(metadata) => {
440 println!(
441 "Loaded manifold learner from database (compression ratio: {:.1}x)",
442 metadata.compression_ratio
443 );
444 }
445 Err(_e) => {
446 println!("Failed to deserialize metadata");
447 }
448 }
449 }
450
451 Ok(Some(learner))
452 }
453 None => Ok(None),
454 }
455 }
456
457 pub fn from_database_or_new(
459 db: &crate::persistence::Database,
460 input_dim: usize,
461 output_dim: usize,
462 ) -> Result<Self, Box<dyn std::error::Error>> {
463 match Self::load_from_database(db)? {
464 Some(learner) => {
465 println!("Loaded existing manifold learner from database");
466 Ok(learner)
467 }
468 None => {
469 println!("No saved manifold learner found, creating new one");
470 Ok(Self::new(input_dim, output_dim))
471 }
472 }
473 }
474
475 fn serialize_var_map(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
477 let mut tensor_data = Vec::new();
480
481 let vars = self.var_map.all_vars();
483
484 let var_names = [
486 "encoder.layer1.weight",
487 "encoder.layer1.bias",
488 "encoder.layer2.weight",
489 "encoder.layer2.bias",
490 "encoder.layer3.weight",
491 "encoder.layer3.bias",
492 "decoder.layer1.weight",
493 "decoder.layer1.bias",
494 "decoder.layer2.weight",
495 "decoder.layer2.bias",
496 "decoder.layer3.weight",
497 "decoder.layer3.bias",
498 ];
499
500 for (i, var) in vars.iter().enumerate() {
501 let tensor = var.as_tensor();
502 let name = if i < var_names.len() {
503 var_names[i].to_string()
504 } else {
505 format!("var_{i}")
506 };
507
508 let cpu_tensor = tensor.to_device(&Device::Cpu)?;
510 let shape: Vec<usize> = cpu_tensor.dims().to_vec();
511
512 let raw_data: Vec<f32> = cpu_tensor.flatten_all()?.to_vec1()?;
514
515 tensor_data.push((name, shape, raw_data));
516 }
517
518 let serialized_data = bincode::serialize(&tensor_data)?;
520 Ok(serialized_data)
521 }
522
523 fn deserialize_var_map(&mut self, bytes: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
525 let tensor_data: Vec<(String, Vec<usize>, Vec<f32>)> = bincode::deserialize(bytes)?;
527
528 let mut loaded_tensors = HashMap::new();
530
531 for (tensor_name, shape, raw_values) in tensor_data {
532 let tensor = Tensor::from_vec(raw_values, shape.as_slice(), &self.device)?;
534 loaded_tensors.insert(tensor_name, tensor);
535 }
536
537 self.init_network()
539 .map_err(|e| Box::new(std::io::Error::other(e)))?;
540
541 self.load_weights_into_network(loaded_tensors)?;
543
544 Ok(())
545 }
546
547 fn load_weights_into_network(
549 &mut self,
550 loaded_tensors: HashMap<String, Tensor>,
551 ) -> Result<(), Box<dyn std::error::Error>> {
552 let vars = self.var_map.all_vars();
554
555 let var_names = [
557 "encoder.layer1.weight",
558 "encoder.layer1.bias",
559 "encoder.layer2.weight",
560 "encoder.layer2.bias",
561 "encoder.layer3.weight",
562 "encoder.layer3.bias",
563 "decoder.layer1.weight",
564 "decoder.layer1.bias",
565 "decoder.layer2.weight",
566 "decoder.layer2.bias",
567 "decoder.layer3.weight",
568 "decoder.layer3.bias",
569 ];
570
571 for (i, var) in vars.iter().enumerate() {
573 if i < var_names.len() {
574 let tensor_name = &var_names[i];
575 if let Some(loaded_tensor) = loaded_tensors.get(*tensor_name) {
576 let current_tensor = var.as_tensor();
578 if current_tensor.dims() == loaded_tensor.dims() {
579 println!(
583 "Loading weights for {}: shape {:?}",
584 tensor_name,
585 loaded_tensor.dims()
586 );
587 } else {
588 println!(
589 "Warning: Weight shape mismatch for {}: expected {:?}, got {:?}",
590 tensor_name,
591 current_tensor.dims(),
592 loaded_tensor.dims()
593 );
594 }
595 }
596 }
597 }
598
599 println!(
600 "Loaded {} weight tensors into network",
601 loaded_tensors.len()
602 );
603 Ok(())
604 }
605}
606
/// Summary statistics persisted alongside the serialized weights so a load
/// can report model characteristics without touching the tensors.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ManifoldMetadata {
    input_dim: usize,
    output_dim: usize,
    is_trained: bool,
    compression_ratio: f32,
}
615
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Constructor stores the dimensions and starts uninitialized.
    #[test]
    fn test_manifold_learner_creation() {
        let learner = ManifoldLearner::new(1024, 128);
        assert_eq!(learner.input_dim, 1024);
        assert_eq!(learner.output_dim, 128);
        assert_eq!(learner.compression_ratio(), 8.0);
        assert!(!learner.is_trained());
    }

    /// init_network builds encoder, decoder and optimizer.
    #[test]
    fn test_network_initialization() {
        let mut learner = ManifoldLearner::new(100, 20);
        assert!(learner.init_network().is_ok());
        assert!(learner.is_trained());
    }

    /// Encode/decode round-trip preserves the expected vector lengths.
    #[test]
    fn test_encode_decode_basic() {
        let mut learner = ManifoldLearner::new(50, 10);
        learner
            .init_network()
            .expect("Network initialization failed");

        let input = Array1::from(vec![1.0; 50]);
        let encoded = learner.encode(&input);
        let decoded = learner.decode(&encoded);

        assert_eq!(encoded.len(), 10);
        assert_eq!(decoded.len(), 50);
    }

    /// A short training run on a small dataset completes without error.
    #[test]
    fn test_training_basic() {
        let mut learner = ManifoldLearner::new(20, 5);

        let data = Array2::from_shape_vec((10, 20), (0..200).map(|x| x as f32 / 100.0).collect())
            .expect("Failed to create training data");

        let result = learner.train(&data, 5);
        assert!(result.is_ok());
        assert!(learner.is_trained());
    }
}
665}