use crate::circuit_integration::QuantumMLExecutor;
use crate::error::{MLError, Result};
use crate::scirs2_integration::{SciRS2Array, SciRS2Optimizer};
use crate::simulator_backends::{Observable, SimulatorBackend};
use quantrs2_circuit::prelude::*;
use scirs2_core::ndarray::{Array1, Array2, ArrayD, Axis, Dimension, IxDyn};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

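/// Core interface for trainable quantum layers, mirroring the familiar
/// `torch.nn.Module` contract: a forward pass, parameter access, train/eval
/// mode switching, and gradient reset.
///
/// A minimal usage sketch (kept out of doctests; assumes the `QuantumLinear`
/// layer defined below and the `SciRS2Array` constructors used in this file):
///
/// ```ignore
/// let mut layer = QuantumLinear::new(4, 2)?.with_bias()?;
/// layer.train(true);                        // enable training mode
/// let x = SciRS2Array::new(ArrayD::zeros(IxDyn(&[1, 4])), false);
/// let y = layer.forward(&x)?;               // run the forward pass
/// layer.zero_grad();                        // clear accumulated gradients
/// ```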
pub trait QuantumModule: Send + Sync {
    /// Run the forward pass, producing the module's output for `input`.
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array>;

    /// Return (clones of) all trainable parameters of this module.
    fn parameters(&self) -> Vec<Parameter>;

    /// Switch between training (`true`) and evaluation (`false`) mode.
    fn train(&mut self, mode: bool);

    /// Whether the module is currently in training mode.
    fn training(&self) -> bool;

    /// Reset the gradients of all parameters to zero.
    fn zero_grad(&mut self);

    /// A human-readable name for the module.
    fn name(&self) -> &str;
}

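/// A named tensor that a module treats as trainable state.
///
/// A construction sketch (assumes the `SciRS2Array` API used throughout this
/// file):
///
/// ```ignore
/// let w = Parameter::new(
///     SciRS2Array::with_grad(ArrayD::zeros(IxDyn(&[2, 3]))),
///     "weight",
/// );
/// assert_eq!(w.shape(), &[2, 3]);
/// assert_eq!(w.numel(), 6);
/// ```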
#[derive(Debug, Clone)]
pub struct Parameter {
    /// The underlying tensor data (and, if enabled, its gradient).
    pub data: SciRS2Array,
    /// Name used to identify the parameter, e.g. in optimizer state.
    pub name: String,
    /// Whether gradients should be tracked for this parameter.
    pub requires_grad: bool,
}

impl Parameter {
    /// Create a trainable parameter (gradients enabled).
    pub fn new(data: SciRS2Array, name: impl Into<String>) -> Self {
        Self {
            data,
            name: name.into(),
            requires_grad: true,
        }
    }

    /// Create a frozen parameter (gradients disabled).
    pub fn no_grad(data: SciRS2Array, name: impl Into<String>) -> Self {
        Self {
            data,
            name: name.into(),
            requires_grad: false,
        }
    }

    /// The shape of the underlying tensor.
    pub fn shape(&self) -> &[usize] {
        self.data.data.shape()
    }

    /// Total number of elements in the tensor.
    pub fn numel(&self) -> usize {
        self.data.data.len()
    }
}

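/// A linear (fully connected) layer executed with quantum-assisted primitives.
///
/// A builder-style construction sketch:
///
/// ```ignore
/// let mut layer = QuantumLinear::new(4, 2)?.with_bias()?;
/// layer.init_xavier_uniform()?;             // Xavier/Glorot uniform init
/// ```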
pub struct QuantumLinear {
    /// Weight matrix, stored PyTorch-style as `[out_features, in_features]`.
    weights: Parameter,
    /// Optional bias vector of shape `[out_features]`.
    bias: Option<Parameter>,
    in_features: usize,
    out_features: usize,
    training: bool,
    /// Executor used to offload the layer onto an 8-qubit circuit backend.
    executor: QuantumMLExecutor<8>,
}

impl QuantumLinear {
    pub fn new(in_features: usize, out_features: usize) -> Result<Self> {
        let weight_data = ArrayD::zeros(IxDyn(&[out_features, in_features]));
        let weights = Parameter::new(SciRS2Array::with_grad(weight_data), "weight");

        Ok(Self {
            weights,
            bias: None,
            in_features,
            out_features,
            training: true,
            executor: QuantumMLExecutor::new(),
        })
    }

    /// Add a zero-initialized bias term (builder style).
    pub fn with_bias(mut self) -> Result<Self> {
        let bias_data = ArrayD::zeros(IxDyn(&[self.out_features]));
        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
        Ok(self)
    }

    /// Initialize the weights with Xavier/Glorot uniform initialization:
    /// samples drawn from `U(-b, b)` with `b = sqrt(6 / (fan_in + fan_out))`.
    pub fn init_xavier_uniform(&mut self) -> Result<()> {
        let fan_in = self.in_features as f64;
        let fan_out = self.out_features as f64;
        let bound = (6.0 / (fan_in + fan_out)).sqrt();

        for elem in self.weights.data.data.iter_mut() {
            *elem = (fastrand::f64() * 2.0 - 1.0) * bound;
        }

        Ok(())
    }
}

impl QuantumModule for QuantumLinear {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Weights are stored `[out_features, in_features]`; this assumes
        // `SciRS2Array::matmul` applies the weight matrix accordingly (as in
        // PyTorch's `F.linear`, which computes `x · Wᵀ`).
        let output = input.matmul(&self.weights.data)?;

        if let Some(ref bias) = self.bias {
            output.add(&bias.data)
        } else {
            Ok(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weights.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weights.data.zero_grad();
        if let Some(ref mut bias) = self.bias {
            bias.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumLinear"
    }
}

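/// A 2D convolution layer with quantum-inspired parameterization.
///
/// A builder-style construction sketch:
///
/// ```ignore
/// let conv = QuantumConv2d::new(3, 16, (3, 3))?
///     .stride((2, 2))
///     .padding((1, 1))
///     .with_bias()?;
/// ```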
pub struct QuantumConv2d {
    /// Kernel weights, shaped `[out_channels, in_channels, kh, kw]`.
    weights: Parameter,
    /// Optional bias vector of shape `[out_channels]`.
    bias: Option<Parameter>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    training: bool,
}

impl QuantumConv2d {
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
    ) -> Result<Self> {
        let weight_shape = [out_channels, in_channels, kernel_size.0, kernel_size.1];
        let weight_data = ArrayD::zeros(IxDyn(&weight_shape));
        let weights = Parameter::new(SciRS2Array::with_grad(weight_data), "weight");

        Ok(Self {
            weights,
            bias: None,
            in_channels,
            out_channels,
            kernel_size,
            stride: (1, 1),
            padding: (0, 0),
            training: true,
        })
    }

    /// Set the stride (builder style); defaults to `(1, 1)`.
    pub fn stride(mut self, stride: (usize, usize)) -> Self {
        self.stride = stride;
        self
    }

    /// Set the zero-padding (builder style); defaults to `(0, 0)`.
    pub fn padding(mut self, padding: (usize, usize)) -> Self {
        self.padding = padding;
        self
    }

    /// Add a zero-initialized bias term (builder style).
    pub fn with_bias(mut self) -> Result<Self> {
        let bias_data = ArrayD::zeros(IxDyn(&[self.out_channels]));
        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
        Ok(self)
    }
}

impl QuantumModule for QuantumConv2d {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // TODO: the convolution itself is not implemented yet; this is an
        // identity passthrough placeholder so the layer can be composed and
        // exercised end to end until a real kernel is wired in.
        let output_data = input.data.clone();
        let mut output = SciRS2Array::new(output_data, input.requires_grad);

        if let Some(ref bias) = self.bias {
            output = output.add(&bias.data)?;
        }

        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weights.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weights.data.zero_grad();
        if let Some(ref mut bias) = self.bias {
            bias.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumConv2d"
    }
}

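/// An element-wise nonlinearity applied between quantum layers.
///
/// A usage sketch:
///
/// ```ignore
/// let mut relu = QuantumActivation::relu();
/// let x = SciRS2Array::new(
///     ArrayD::from_shape_vec(IxDyn(&[2]), vec![-1.0, 1.0])?,
///     false,
/// );
/// let y = relu.forward(&x)?;                // [-1.0, 1.0] -> [0.0, 1.0]
/// ```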
pub struct QuantumActivation {
    activation_type: ActivationType,
    training: bool,
}

/// The nonlinearities supported by [`QuantumActivation`].
#[derive(Debug, Clone)]
pub enum ActivationType {
    /// Rectified linear unit: `max(x, 0)`.
    QReLU,
    /// Logistic sigmoid: `1 / (1 + e^{-x})`.
    QSigmoid,
    /// Hyperbolic tangent.
    QTanh,
    /// Softmax over all elements (with max-subtraction for stability).
    QSoftmax,
    /// Pass the input through unchanged.
    Identity,
}

impl QuantumActivation {
    pub fn new(activation_type: ActivationType) -> Self {
        Self {
            activation_type,
            training: true,
        }
    }

    pub fn relu() -> Self {
        Self::new(ActivationType::QReLU)
    }

    pub fn sigmoid() -> Self {
        Self::new(ActivationType::QSigmoid)
    }

    pub fn tanh() -> Self {
        Self::new(ActivationType::QTanh)
    }

    pub fn softmax() -> Self {
        Self::new(ActivationType::QSoftmax)
    }
}

impl QuantumModule for QuantumActivation {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        match self.activation_type {
            ActivationType::QReLU => {
                let output_data = input.data.mapv(|x| x.max(0.0));
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QSigmoid => {
                let output_data = input.data.mapv(|x| 1.0 / (1.0 + (-x).exp()));
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QTanh => {
                let output_data = input.data.mapv(|x| x.tanh());
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QSoftmax => {
                // Subtract the max before exponentiating for numerical stability.
                let max_val = input.data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let exp_data = input.data.mapv(|x| (x - max_val).exp());
                let sum_exp = exp_data.sum();
                let output_data = exp_data.mapv(|x| x / sum_exp);
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::Identity => {
                Ok(SciRS2Array::new(input.data.clone(), input.requires_grad))
            }
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        // Activations have no parameters, so there is nothing to reset.
    }

    fn name(&self) -> &str {
        "QuantumActivation"
    }
}

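/// A container that chains modules and runs them in insertion order.
///
/// A construction sketch:
///
/// ```ignore
/// let model = QuantumSequential::new()
///     .add(Box::new(QuantumLinear::new(4, 8)?.with_bias()?))
///     .add(Box::new(QuantumActivation::relu()))
///     .add(Box::new(QuantumLinear::new(8, 2)?.with_bias()?));
/// assert_eq!(model.len(), 3);
/// ```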
pub struct QuantumSequential {
    modules: Vec<Box<dyn QuantumModule>>,
    training: bool,
}

impl QuantumSequential {
    pub fn new() -> Self {
        Self {
            modules: Vec::new(),
            training: true,
        }
    }

    /// Append a module to the end of the chain (builder style).
    pub fn add(mut self, module: Box<dyn QuantumModule>) -> Self {
        self.modules.push(module);
        self
    }

    /// Number of modules in the chain.
    pub fn len(&self) -> usize {
        self.modules.len()
    }

    pub fn is_empty(&self) -> bool {
        self.modules.is_empty()
    }
}

impl QuantumModule for QuantumSequential {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let mut output = input.clone();

        for module in &mut self.modules {
            output = module.forward(&output)?;
        }

        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut all_params = Vec::new();

        for module in &self.modules {
            all_params.extend(module.parameters());
        }

        all_params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
        for module in &mut self.modules {
            module.train(mode);
        }
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        for module in &mut self.modules {
            module.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumSequential"
    }
}

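/// Interface for loss functions over quantum model outputs.
///
/// A usage sketch with the MSE loss defined below:
///
/// ```ignore
/// let loss_fn = QuantumMSELoss;
/// let loss = loss_fn.forward(&predictions, &targets)?;
/// println!("{}: {}", loss_fn.name(), loss.data[[0]]);
/// ```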
pub trait QuantumLoss {
    /// Compute the scalar loss for `predictions` against `targets`.
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array>;

    /// A human-readable name for the loss.
    fn name(&self) -> &str;
}

/// Mean squared error: `mean((predictions - targets)^2)`.
pub struct QuantumMSELoss;

impl QuantumLoss for QuantumMSELoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let diff = predictions.data.clone() - &targets.data;
        let squared_diff = &diff * &diff;
        let mse = squared_diff.mean().unwrap_or(0.0);

        // Store the scalar as a one-element array so callers can read it
        // with `loss.data[[0]]`.
        let loss_data = ArrayD::from_elem(IxDyn(&[1]), mse);
        Ok(SciRS2Array::new(loss_data, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "MSELoss"
    }
}

/// Cross-entropy between a softmax over `predictions` and one-hot `targets`.
pub struct QuantumCrossEntropyLoss;

impl QuantumLoss for QuantumCrossEntropyLoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        // Numerically stable softmax: subtract the max before exponentiating.
        let max_val = predictions
            .data
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_preds = predictions.data.mapv(|x| (x - max_val).exp());
        let sum_exp = exp_preds.sum();
        let softmax = exp_preds.mapv(|x| x / sum_exp);

        // Clamp away from zero so the logarithm stays finite.
        let log_softmax = softmax.mapv(|x| x.max(1e-12).ln());
        let cross_entropy = -(&targets.data * &log_softmax).sum();

        let loss_data = ArrayD::from_elem(IxDyn(&[1]), cross_entropy);
        Ok(SciRS2Array::new(loss_data, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "CrossEntropyLoss"
    }
}

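/// Drives the optimization loop: forward pass, loss, optimizer step, and
/// bookkeeping of the training history.
///
/// A training-loop sketch (assumes an optimizer built from the `SciRS2Optimizer`
/// API and `MemoryDataLoader`s as defined below):
///
/// ```ignore
/// let mut trainer = QuantumTrainer::new(
///     Box::new(model),
///     optimizer,
///     Box::new(QuantumMSELoss),
/// );
/// for epoch in 0..10 {
///     let train_loss = trainer.train_epoch(&mut train_loader)?;
///     let val_loss = trainer.evaluate(&mut val_loader)?;
///     println!("epoch {epoch}: train={train_loss:.4} val={val_loss:.4}");
///     train_loader.reset();                 // rewind for the next epoch
///     val_loader.reset();
/// }
/// ```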
pub struct QuantumTrainer {
    model: Box<dyn QuantumModule>,
    optimizer: SciRS2Optimizer,
    loss_fn: Box<dyn QuantumLoss>,
    history: TrainingHistory,
}

/// Per-epoch loss/accuracy records for training and validation.
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    pub losses: Vec<f64>,
    pub accuracies: Vec<f64>,
    pub val_losses: Vec<f64>,
    pub val_accuracies: Vec<f64>,
}

impl TrainingHistory {
    pub fn new() -> Self {
        Self {
            losses: Vec::new(),
            accuracies: Vec::new(),
            val_losses: Vec::new(),
            val_accuracies: Vec::new(),
        }
    }

    pub fn add_training(&mut self, loss: f64, accuracy: Option<f64>) {
        self.losses.push(loss);
        if let Some(acc) = accuracy {
            self.accuracies.push(acc);
        }
    }

    pub fn add_validation(&mut self, loss: f64, accuracy: Option<f64>) {
        self.val_losses.push(loss);
        if let Some(acc) = accuracy {
            self.val_accuracies.push(acc);
        }
    }
}

impl QuantumTrainer {
    pub fn new(
        model: Box<dyn QuantumModule>,
        optimizer: SciRS2Optimizer,
        loss_fn: Box<dyn QuantumLoss>,
    ) -> Self {
        Self {
            model,
            optimizer,
            loss_fn,
            history: TrainingHistory::new(),
        }
    }

    /// Run one pass over `dataloader` in training mode and return the
    /// average batch loss.
    pub fn train_epoch(&mut self, dataloader: &mut dyn DataLoader) -> Result<f64> {
        self.model.train(true);
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        while let Some((inputs, targets)) = dataloader.next_batch()? {
            self.model.zero_grad();

            let predictions = self.model.forward(&inputs)?;

            let loss = self.loss_fn.forward(&predictions, &targets)?;
            total_loss += loss.data[[0]];

            // Hand parameter copies to the optimizer. Note: because
            // `parameters()` returns clones, the updates are not yet written
            // back into the model; that requires shared parameter storage.
            let mut params = HashMap::new();
            for (i, param) in self.model.parameters().iter().enumerate() {
                params.insert(format!("param_{}", i), param.data.clone());
            }
            self.optimizer.step(&mut params)?;

            num_batches += 1;
        }

        if num_batches == 0 {
            return Err(MLError::InvalidConfiguration(
                "DataLoader produced no batches".to_string(),
            ));
        }

        let avg_loss = total_loss / num_batches as f64;
        self.history.add_training(avg_loss, None);
        Ok(avg_loss)
    }

    /// Run one pass over `dataloader` in evaluation mode and return the
    /// average batch loss.
    pub fn evaluate(&mut self, dataloader: &mut dyn DataLoader) -> Result<f64> {
        self.model.train(false);
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        while let Some((inputs, targets)) = dataloader.next_batch()? {
            let predictions = self.model.forward(&inputs)?;

            let loss = self.loss_fn.forward(&predictions, &targets)?;
            total_loss += loss.data[[0]];

            num_batches += 1;
        }

        if num_batches == 0 {
            return Err(MLError::InvalidConfiguration(
                "DataLoader produced no batches".to_string(),
            ));
        }

        let avg_loss = total_loss / num_batches as f64;
        self.history.add_validation(avg_loss, None);
        Ok(avg_loss)
    }

    pub fn history(&self) -> &TrainingHistory {
        &self.history
    }
}

/// Batch iteration interface feeding `(input, target)` pairs to the trainer.
pub trait DataLoader {
    /// Return the next `(inputs, targets)` batch, or `None` when exhausted.
    fn next_batch(&mut self) -> Result<Option<(SciRS2Array, SciRS2Array)>>;

    /// Rewind to the start of the dataset (reshuffling if enabled).
    fn reset(&mut self);

    /// The configured batch size.
    fn batch_size(&self) -> usize;
}

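/// An in-memory `DataLoader` that slices pre-loaded tensors into batches,
/// optionally shuffling between epochs.
///
/// A usage sketch (assumes `inputs` and `targets` share their first dimension):
///
/// ```ignore
/// let mut loader = MemoryDataLoader::new(inputs, targets, 32, true)?;
/// while let Some((x, y)) = loader.next_batch()? {
///     // consume the batch...
/// }
/// loader.reset();                           // rewind and reshuffle
/// ```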
pub struct MemoryDataLoader {
    inputs: SciRS2Array,
    targets: SciRS2Array,
    batch_size: usize,
    current_pos: usize,
    shuffle: bool,
    indices: Vec<usize>,
}

impl MemoryDataLoader {
    pub fn new(
        inputs: SciRS2Array,
        targets: SciRS2Array,
        batch_size: usize,
        shuffle: bool,
    ) -> Result<Self> {
        let num_samples = inputs.data.shape()[0];
        if targets.data.shape()[0] != num_samples {
            return Err(MLError::InvalidConfiguration(
                "Input and target batch sizes don't match".to_string(),
            ));
        }

        let indices: Vec<usize> = (0..num_samples).collect();

        Ok(Self {
            inputs,
            targets,
            batch_size,
            current_pos: 0,
            shuffle,
            indices,
        })
    }

    /// Fisher-Yates shuffle of the sample indices (no-op when `shuffle` is off).
    fn shuffle_indices(&mut self) {
        if self.shuffle {
            for i in (1..self.indices.len()).rev() {
                let j = fastrand::usize(0..=i);
                self.indices.swap(i, j);
            }
        }
    }
}

impl DataLoader for MemoryDataLoader {
    fn next_batch(&mut self) -> Result<Option<(SciRS2Array, SciRS2Array)>> {
        if self.current_pos >= self.indices.len() {
            return Ok(None);
        }

        let end_pos = (self.current_pos + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.current_pos..end_pos];

        // Gather the selected samples along the batch axis.
        let batch_inputs = SciRS2Array::new(
            self.inputs.data.select(Axis(0), batch_indices),
            self.inputs.requires_grad,
        );
        let batch_targets = SciRS2Array::new(
            self.targets.data.select(Axis(0), batch_indices),
            self.targets.requires_grad,
        );
        self.current_pos = end_pos;

        Ok(Some((batch_inputs, batch_targets)))
    }

    fn reset(&mut self) {
        self.current_pos = 0;
        self.shuffle_indices();
    }

    fn batch_size(&self) -> usize {
        self.batch_size
    }
}

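/// Convenience constructors for common quantum neural network topologies.
///
/// A sketch building a 4-8-8-2 MLP with ReLU activations:
///
/// ```ignore
/// let model = quantum_nn::create_feedforward(4, &[8, 8], 2, ActivationType::QReLU)?;
/// assert_eq!(model.len(), 5);   // 3 linear layers interleaved with 2 activations
/// ```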
pub mod quantum_nn {
    use super::*;

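    /// Standard normal sample via the Box-Muller transform, built only on
    /// `fastrand` (already used throughout this file). Added so the `He` and
    /// `Normal` initializers below draw genuinely Gaussian samples rather
    /// than uniform ones.
    fn randn() -> f64 {
        let u1 = fastrand::f64().max(f64::MIN_POSITIVE); // avoid ln(0)
        let u2 = fastrand::f64();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
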
    /// Build a fully connected network `input_size -> hidden_sizes... -> output_size`,
    /// with the given activation after every hidden layer.
    pub fn create_feedforward(
        input_size: usize,
        hidden_sizes: &[usize],
        output_size: usize,
        activation: ActivationType,
    ) -> Result<QuantumSequential> {
        let mut model = QuantumSequential::new();

        let mut prev_size = input_size;

        for &hidden_size in hidden_sizes {
            model = model.add(Box::new(
                QuantumLinear::new(prev_size, hidden_size)?.with_bias()?,
            ));
            model = model.add(Box::new(QuantumActivation::new(activation.clone())));
            prev_size = hidden_size;
        }

        // Output layer (no activation; apply one externally if needed).
        model = model.add(Box::new(
            QuantumLinear::new(prev_size, output_size)?.with_bias()?,
        ));

        Ok(model)
    }

    /// Build a small convolutional classifier. Note: until `QuantumConv2d`
    /// implements a real convolution (it is currently a passthrough), the
    /// final `QuantumLinear` assumes a 64-element feature vector; a flatten
    /// step would be needed once the convolutions produce spatial maps.
    pub fn create_cnn(input_channels: usize, num_classes: usize) -> Result<QuantumSequential> {
        let model = QuantumSequential::new()
            .add(Box::new(
                QuantumConv2d::new(input_channels, 32, (3, 3))?.with_bias()?,
            ))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(QuantumConv2d::new(32, 64, (3, 3))?.with_bias()?))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(QuantumLinear::new(64, num_classes)?.with_bias()?));

        Ok(model)
    }

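    /// Initialize every parameter of `model` according to `init_type`.
    ///
    /// A usage sketch (note the caveat inside the function: because
    /// `parameters()` returns clones, the writes only reach the model if the
    /// underlying `SciRS2Array` storage is shared):
    ///
    /// ```ignore
    /// let mut model = create_feedforward(4, &[8], 2, ActivationType::QReLU)?;
    /// init_parameters(&mut model, InitType::Xavier)?;
    /// ```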
    pub fn init_parameters(model: &mut dyn QuantumModule, init_type: InitType) -> Result<()> {
        // Caveat: `parameters()` yields clones, so these writes update the
        // model only if `SciRS2Array` clones share their storage.
        for mut param in model.parameters() {
            match init_type {
                InitType::Xavier => {
                    // Uniform in [-b, b] with b = sqrt(6 / (fan_in + fan_out));
                    // fan_in is the product of all non-leading dimensions.
                    let fan_in = param.shape().iter().skip(1).product::<usize>() as f64;
                    let fan_out = param.shape()[0] as f64;
                    let bound = (6.0 / (fan_in + fan_out)).sqrt();

                    for elem in param.data.data.iter_mut() {
                        *elem = (fastrand::f64() * 2.0 - 1.0) * bound;
                    }
                }
                InitType::He => {
                    // Gaussian with std = sqrt(2 / fan_in).
                    let fan_in = param.shape().iter().skip(1).product::<usize>() as f64;
                    let std = (2.0 / fan_in).sqrt();

                    for elem in param.data.data.iter_mut() {
                        *elem = randn() * std;
                    }
                }
                InitType::Normal(mean, std) => {
                    for elem in param.data.data.iter_mut() {
                        *elem = mean + std * randn();
                    }
                }
                InitType::Uniform(low, high) => {
                    for elem in param.data.data.iter_mut() {
                        *elem = low + (high - low) * fastrand::f64();
                    }
                }
            }
        }
        Ok(())
    }
}

/// Weight initialization schemes for [`quantum_nn::init_parameters`].
#[derive(Debug, Clone, Copy)]
pub enum InitType {
    /// Xavier/Glorot uniform initialization.
    Xavier,
    /// He (Kaiming) normal initialization.
    He,
    /// Gaussian with the given `(mean, std)`.
    Normal(f64, f64),
    /// Uniform on the given `(low, high)` interval.
    Uniform(f64, f64),
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_linear() {
        let linear = QuantumLinear::new(4, 2).unwrap();
        assert_eq!(linear.in_features, 4);
        assert_eq!(linear.out_features, 2);
        assert_eq!(linear.parameters().len(), 1);

        let linear_with_bias = linear.with_bias().unwrap();
        assert_eq!(linear_with_bias.parameters().len(), 2);
    }

    #[test]
    fn test_quantum_sequential() {
        let model = QuantumSequential::new()
            .add(Box::new(QuantumLinear::new(4, 8).unwrap()))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(QuantumLinear::new(8, 2).unwrap()));

        assert_eq!(model.len(), 3);
        assert!(!model.is_empty());
    }

    #[test]
    fn test_quantum_activation() {
        let mut relu = QuantumActivation::relu();
        let input_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![-1.0, 1.0]).unwrap();
        let input = SciRS2Array::new(input_data, false);

        let output = relu.forward(&input).unwrap();
        assert_eq!(output.data[[0]], 0.0);
        assert_eq!(output.data[[1]], 1.0);
    }

    #[test]
    #[ignore]
    fn test_quantum_loss() {
        let mse_loss = QuantumMSELoss;

        let pred_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0]).unwrap();
        let target_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.5, 1.8]).unwrap();

        let predictions = SciRS2Array::new(pred_data, false);
        let targets = SciRS2Array::new(target_data, false);

        let loss = mse_loss.forward(&predictions, &targets).unwrap();
        assert!(loss.data[[0]] > 0.0);
    }

    #[test]
    fn test_parameter() {
        let data = ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0; 6]).unwrap();
        let param = Parameter::new(SciRS2Array::new(data, true), "test_param");

        assert_eq!(param.name, "test_param");
        assert!(param.requires_grad);
        assert_eq!(param.shape(), &[2, 3]);
        assert_eq!(param.numel(), 6);
    }

    #[test]
    fn test_training_history() {
        let mut history = TrainingHistory::new();
        history.add_training(0.5, Some(0.8));
        history.add_validation(0.6, Some(0.7));

        assert_eq!(history.losses.len(), 1);
        assert_eq!(history.accuracies.len(), 1);
        assert_eq!(history.val_losses.len(), 1);
        assert_eq!(history.val_accuracies.len(), 1);
    }
}