use crate::circuit_integration::QuantumMLExecutor;
use crate::error::{MLError, Result};
use crate::scirs2_integration::{SciRS2Array, SciRS2Optimizer};
use crate::simulator_backends::{Observable, SimulatorBackend};
use quantrs2_circuit::prelude::*;
use scirs2_core::ndarray::{Array1, Array2, ArrayD, Axis, Dimension, IxDyn};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

/// Base trait for quantum neural-network modules, mirroring the familiar
/// `nn.Module` interface.
pub trait QuantumModule: Send + Sync {
    /// Runs the forward pass on `input`.
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array>;

    /// Returns the module's trainable parameters.
    fn parameters(&self) -> Vec<Parameter>;

    /// Switches between training (`true`) and evaluation (`false`) mode.
    fn train(&mut self, mode: bool);

    /// Returns `true` while the module is in training mode.
    fn training(&self) -> bool;

    /// Clears accumulated gradients on all parameters.
    fn zero_grad(&mut self);

    /// Returns the module's display name.
    fn name(&self) -> &str;
}
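
// A minimal sketch of implementing the trait above: an identity layer with no
// trainable parameters. Only items defined in this file are used; the module
// and type names are illustrative.
#[cfg(test)]
mod quantum_module_example {
    use super::*;

    struct IdentityModule {
        training: bool,
    }

    impl QuantumModule for IdentityModule {
        fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
            Ok(SciRS2Array::new(input.data.clone(), input.requires_grad))
        }
        fn parameters(&self) -> Vec<Parameter> {
            Vec::new()
        }
        fn train(&mut self, mode: bool) {
            self.training = mode;
        }
        fn training(&self) -> bool {
            self.training
        }
        fn zero_grad(&mut self) {}
        fn name(&self) -> &str {
            "IdentityModule"
        }
    }

    #[test]
    fn identity_preserves_input() {
        let mut module = IdentityModule { training: true };
        let x = SciRS2Array::new(ArrayD::from_elem(IxDyn(&[3]), 1.5), false);
        let y = module.forward(&x).expect("forward should succeed");
        assert_eq!(y.data, x.data);
    }
}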

/// A named, trainable tensor with a gradient-tracking flag.
#[derive(Debug, Clone)]
pub struct Parameter {
    /// The underlying tensor data.
    pub data: SciRS2Array,
    /// Human-readable parameter name.
    pub name: String,
    /// Whether gradients should be tracked for this parameter.
    pub requires_grad: bool,
}

impl Parameter {
    /// Creates a parameter that tracks gradients.
    pub fn new(data: SciRS2Array, name: impl Into<String>) -> Self {
        Self {
            data,
            name: name.into(),
            requires_grad: true,
        }
    }

    /// Creates a frozen parameter that does not track gradients.
    pub fn no_grad(data: SciRS2Array, name: impl Into<String>) -> Self {
        Self {
            data,
            name: name.into(),
            requires_grad: false,
        }
    }

    /// Returns the parameter's shape.
    pub fn shape(&self) -> &[usize] {
        self.data.data.shape()
    }

    /// Returns the total number of elements.
    pub fn numel(&self) -> usize {
        self.data.data.len()
    }
}

/// A linear (dense) layer for hybrid quantum-classical models.
pub struct QuantumLinear {
    /// Weight matrix, stored as `[out_features, in_features]`.
    weights: Parameter,
    /// Optional bias vector of length `out_features`.
    bias: Option<Parameter>,
    /// Number of input features.
    in_features: usize,
    /// Number of output features.
    out_features: usize,
    /// Whether the layer is in training mode.
    training: bool,
    /// Executor for quantum circuit evaluation.
    executor: QuantumMLExecutor<8>,
}

impl QuantumLinear {
    /// Creates a layer with zero-initialized weights and no bias.
    pub fn new(in_features: usize, out_features: usize) -> Result<Self> {
        let weight_data = ArrayD::zeros(IxDyn(&[out_features, in_features]));
        let weights = Parameter::new(SciRS2Array::with_grad(weight_data), "weight");

        Ok(Self {
            weights,
            bias: None,
            in_features,
            out_features,
            training: true,
            executor: QuantumMLExecutor::new(),
        })
    }

    /// Attaches a zero-initialized bias vector (builder style).
    pub fn with_bias(mut self) -> Result<Self> {
        let bias_data = ArrayD::zeros(IxDyn(&[self.out_features]));
        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
        Ok(self)
    }

    /// Initializes weights uniformly in `[-b, b]` with `b = sqrt(6 / (fan_in + fan_out))`.
    pub fn init_xavier_uniform(&mut self) -> Result<()> {
        let fan_in = self.in_features as f64;
        let fan_out = self.out_features as f64;
        let bound = (6.0 / (fan_in + fan_out)).sqrt();

        for elem in self.weights.data.data.iter_mut() {
            *elem = (fastrand::f64() * 2.0 - 1.0) * bound;
        }

        Ok(())
    }
}
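
// A usage sketch for the layer above: builds a biased linear layer and checks
// that Xavier-uniform init stays within sqrt(6 / (fan_in + fan_out)). Names
// are illustrative.
#[cfg(test)]
mod quantum_linear_examples {
    use super::*;

    #[test]
    fn xavier_init_respects_bound() {
        let mut layer = QuantumLinear::new(4, 2)
            .expect("layer should build")
            .with_bias()
            .expect("bias should attach");
        layer.init_xavier_uniform().expect("init should succeed");

        // For (in, out) = (4, 2) the Xavier bound is sqrt(6 / 6) = 1.0.
        let bound = (6.0f64 / (4.0 + 2.0)).sqrt();
        assert!(layer.weights.data.data.iter().all(|w| w.abs() <= bound));
    }
}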

impl QuantumModule for QuantumLinear {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Assumes `SciRS2Array::matmul` accepts the `[out_features, in_features]`
        // weight layout used above.
        let output = input.matmul(&self.weights.data)?;

        if let Some(ref bias) = self.bias {
            output.add(&bias.data)
        } else {
            Ok(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weights.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weights.data.zero_grad();
        if let Some(ref mut bias) = self.bias {
            bias.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumLinear"
    }
}

/// A 2D convolutional layer for hybrid quantum-classical models.
pub struct QuantumConv2d {
    /// Kernel weights, stored as `[out_channels, in_channels, kh, kw]`.
    weights: Parameter,
    /// Optional per-channel bias.
    bias: Option<Parameter>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels.
    out_channels: usize,
    /// Kernel height and width.
    kernel_size: (usize, usize),
    /// Vertical and horizontal stride.
    stride: (usize, usize),
    /// Vertical and horizontal zero-padding.
    padding: (usize, usize),
    /// Whether the layer is in training mode.
    training: bool,
}

impl QuantumConv2d {
    /// Creates a convolution with zero-initialized kernels, stride 1, and no padding.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
    ) -> Result<Self> {
        let weight_shape = [out_channels, in_channels, kernel_size.0, kernel_size.1];
        let weight_data = ArrayD::zeros(IxDyn(&weight_shape));
        let weights = Parameter::new(SciRS2Array::with_grad(weight_data), "weight");

        Ok(Self {
            weights,
            bias: None,
            in_channels,
            out_channels,
            kernel_size,
            stride: (1, 1),
            padding: (0, 0),
            training: true,
        })
    }

    /// Sets the stride (builder style).
    pub fn stride(mut self, stride: (usize, usize)) -> Self {
        self.stride = stride;
        self
    }

    /// Sets the padding (builder style).
    pub fn padding(mut self, padding: (usize, usize)) -> Self {
        self.padding = padding;
        self
    }

    /// Attaches a zero-initialized bias vector (builder style).
    pub fn with_bias(mut self) -> Result<Self> {
        let bias_data = ArrayD::zeros(IxDyn(&[self.out_channels]));
        self.bias = Some(Parameter::new(SciRS2Array::with_grad(bias_data), "bias"));
        Ok(self)
    }
}
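
// Builder usage sketch for the layer above; asserts only fields set in this
// file. The weight shape follows the `[out, in, kh, kw]` layout documented on
// the struct.
#[cfg(test)]
mod quantum_conv2d_examples {
    use super::*;

    #[test]
    fn builder_sets_stride_and_padding() {
        let conv = QuantumConv2d::new(3, 16, (3, 3))
            .expect("conv should build")
            .stride((2, 2))
            .padding((1, 1));

        assert_eq!(conv.stride, (2, 2));
        assert_eq!(conv.padding, (1, 1));
        assert_eq!(conv.weights.shape(), &[16, 3, 3, 3]);
    }
}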

impl QuantumModule for QuantumConv2d {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Placeholder: the convolution itself is not implemented yet, so the
        // input passes through unchanged (plus bias, if present).
        let output_data = input.data.clone();
        let mut output = SciRS2Array::new(output_data, input.requires_grad);

        if let Some(ref bias) = self.bias {
            output = output.add(&bias.data)?;
        }

        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weights.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.weights.data.zero_grad();
        if let Some(ref mut bias) = self.bias {
            bias.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumConv2d"
    }
}

/// Applies a fixed nonlinearity element-wise.
pub struct QuantumActivation {
    activation_type: ActivationType,
    training: bool,
}

/// Supported activation functions.
#[derive(Debug, Clone)]
pub enum ActivationType {
    QReLU,
    QSigmoid,
    QTanh,
    QSoftmax,
    Identity,
}

impl QuantumActivation {
    pub fn new(activation_type: ActivationType) -> Self {
        Self {
            activation_type,
            training: true,
        }
    }

    pub fn relu() -> Self {
        Self::new(ActivationType::QReLU)
    }

    pub fn sigmoid() -> Self {
        Self::new(ActivationType::QSigmoid)
    }

    pub fn tanh() -> Self {
        Self::new(ActivationType::QTanh)
    }

    pub fn softmax() -> Self {
        Self::new(ActivationType::QSoftmax)
    }
}

impl QuantumModule for QuantumActivation {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        match self.activation_type {
            ActivationType::QReLU => {
                let output_data = input.data.mapv(|x| x.max(0.0));
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QSigmoid => {
                let output_data = input.data.mapv(|x| 1.0 / (1.0 + (-x).exp()));
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QTanh => {
                let output_data = input.data.mapv(|x| x.tanh());
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::QSoftmax => {
                // Subtract the max for numerical stability before exponentiating.
                let max_val = input.data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let exp_data = input.data.mapv(|x| (x - max_val).exp());
                let sum_exp = exp_data.sum();
                let output_data = exp_data.mapv(|x| x / sum_exp);
                Ok(SciRS2Array::new(output_data, input.requires_grad))
            }
            ActivationType::Identity => {
                Ok(SciRS2Array::new(input.data.clone(), input.requires_grad))
            }
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        // Activations have no trainable parameters.
        Vec::new()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        // No parameters, so nothing to reset.
    }

    fn name(&self) -> &str {
        "QuantumActivation"
    }
}
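
// Sanity sketch for the QSoftmax branch above: outputs form a probability
// distribution (non-negative, summing to 1) regardless of input scale.
#[cfg(test)]
mod quantum_activation_examples {
    use super::*;

    #[test]
    fn softmax_sums_to_one() {
        let mut softmax = QuantumActivation::softmax();
        let input = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[3]), vec![1.0, 2.0, 3.0]).expect("valid shape"),
            false,
        );

        let output = softmax.forward(&input).expect("forward should succeed");
        let sum: f64 = output.data.sum();
        assert!((sum - 1.0).abs() < 1e-10);
        assert!(output.data.iter().all(|&p| p >= 0.0));
    }
}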

/// A container that chains modules and runs them in order.
pub struct QuantumSequential {
    modules: Vec<Box<dyn QuantumModule>>,
    training: bool,
}

impl QuantumSequential {
    pub fn new() -> Self {
        Self {
            modules: Vec::new(),
            training: true,
        }
    }

    pub fn add(mut self, module: Box<dyn QuantumModule>) -> Self {
        self.modules.push(module);
        self
    }

    pub fn len(&self) -> usize {
        self.modules.len()
    }

    pub fn is_empty(&self) -> bool {
        self.modules.is_empty()
    }
}

impl QuantumModule for QuantumSequential {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let mut output = input.clone();

        for module in &mut self.modules {
            output = module.forward(&output)?;
        }

        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut all_params = Vec::new();

        for module in &self.modules {
            all_params.extend(module.parameters());
        }

        all_params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
        for module in &mut self.modules {
            module.train(mode);
        }
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        for module in &mut self.modules {
            module.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "QuantumSequential"
    }
}

/// Loss functions for quantum model training.
pub trait QuantumLoss {
    /// Computes the scalar loss for a batch of predictions and targets.
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array>;

    /// Returns the loss function's display name.
    fn name(&self) -> &str;
}

/// Mean squared error loss.
pub struct QuantumMSELoss;

impl QuantumLoss for QuantumMSELoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let diff = predictions.data.clone() - &targets.data;
        let squared_diff = &diff * &diff;
        let mse = squared_diff.mean().ok_or_else(|| {
            MLError::InvalidConfiguration("Cannot compute mean of empty array".to_string())
        })?;

        // Store the scalar in a one-element array so callers can index it as `[[0]]`.
        let loss_data = ArrayD::from_elem(IxDyn(&[1]), mse);
        Ok(SciRS2Array::new(loss_data, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "MSELoss"
    }
}
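
// Worked example for the loss above: MSE of [1.0, 2.0] vs [1.5, 1.8] is
// (0.5^2 + 0.2^2) / 2 = 0.145.
#[cfg(test)]
mod quantum_mse_examples {
    use super::*;

    #[test]
    fn mse_matches_hand_computation() {
        let loss_fn = QuantumMSELoss;
        let preds = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0]).expect("valid shape"),
            false,
        );
        let targets = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.5, 1.8]).expect("valid shape"),
            false,
        );

        let loss = loss_fn.forward(&preds, &targets).expect("loss should succeed");
        assert!((loss.data[[0]] - 0.145).abs() < 1e-12);
    }
}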

/// Cross-entropy loss over softmax-normalized predictions.
pub struct QuantumCrossEntropyLoss;

impl QuantumLoss for QuantumCrossEntropyLoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        // Softmax with max-subtraction for numerical stability.
        let max_val = predictions
            .data
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_preds = predictions.data.mapv(|x| (x - max_val).exp());
        let sum_exp = exp_preds.sum();
        let softmax = exp_preds.mapv(|x| x / sum_exp);

        // Clamp to avoid ln(0) when a probability underflows to zero.
        let log_softmax = softmax.mapv(|x| x.max(f64::MIN_POSITIVE).ln());
        let cross_entropy = -(&targets.data * &log_softmax).sum();

        // Store the scalar in a one-element array so callers can index it as `[[0]]`.
        let loss_data = ArrayD::from_elem(IxDyn(&[1]), cross_entropy);
        Ok(SciRS2Array::new(loss_data, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "CrossEntropyLoss"
    }
}
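
// Worked example for the loss above: with all-zero logits the softmax is
// uniform (1/n), so the cross-entropy against a one-hot target is ln(n).
#[cfg(test)]
mod quantum_cross_entropy_examples {
    use super::*;

    #[test]
    fn cross_entropy_of_uniform_logits() {
        let loss_fn = QuantumCrossEntropyLoss;
        let logits = SciRS2Array::new(ArrayD::zeros(IxDyn(&[4])), false);

        let mut target_data = ArrayD::zeros(IxDyn(&[4]));
        target_data[[0]] = 1.0;
        let targets = SciRS2Array::new(target_data, false);

        let loss = loss_fn.forward(&logits, &targets).expect("loss should succeed");
        // Softmax of all-zero logits is 1/4 everywhere, so CE = -ln(1/4) = ln 4.
        assert!((loss.data[[0]] - 4.0f64.ln()).abs() < 1e-10);
    }
}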

/// High-level training loop for quantum models.
pub struct QuantumTrainer {
    /// The model being trained.
    model: Box<dyn QuantumModule>,
    /// Optimizer for parameter updates.
    optimizer: SciRS2Optimizer,
    /// Loss function.
    loss_fn: Box<dyn QuantumLoss>,
    /// Per-epoch metric history.
    history: TrainingHistory,
}

/// Records losses and accuracies across training and validation epochs.
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    /// Training losses, one per epoch.
    pub losses: Vec<f64>,
    /// Training accuracies, one per epoch (when reported).
    pub accuracies: Vec<f64>,
    /// Validation losses, one per epoch.
    pub val_losses: Vec<f64>,
    /// Validation accuracies, one per epoch (when reported).
    pub val_accuracies: Vec<f64>,
}

impl TrainingHistory {
    /// Creates an empty history.
    pub fn new() -> Self {
        Self {
            losses: Vec::new(),
            accuracies: Vec::new(),
            val_losses: Vec::new(),
            val_accuracies: Vec::new(),
        }
    }

    /// Records a training epoch's loss and optional accuracy.
    pub fn add_training(&mut self, loss: f64, accuracy: Option<f64>) {
        self.losses.push(loss);
        if let Some(acc) = accuracy {
            self.accuracies.push(acc);
        }
    }

    /// Records a validation epoch's loss and optional accuracy.
    pub fn add_validation(&mut self, loss: f64, accuracy: Option<f64>) {
        self.val_losses.push(loss);
        if let Some(acc) = accuracy {
            self.val_accuracies.push(acc);
        }
    }
}

impl QuantumTrainer {
    /// Creates a trainer from a model, optimizer, and loss function.
    pub fn new(
        model: Box<dyn QuantumModule>,
        optimizer: SciRS2Optimizer,
        loss_fn: Box<dyn QuantumLoss>,
    ) -> Self {
        Self {
            model,
            optimizer,
            loss_fn,
            history: TrainingHistory::new(),
        }
    }

    /// Runs one training epoch and returns the average loss.
    pub fn train_epoch(&mut self, dataloader: &mut dyn DataLoader) -> Result<f64> {
        self.model.train(true);
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        while let Some((inputs, targets)) = dataloader.next_batch()? {
            self.model.zero_grad();

            let predictions = self.model.forward(&inputs)?;

            let loss = self.loss_fn.forward(&predictions, &targets)?;
            total_loss += loss.data[[0]];

            // Hand the current parameters to the optimizer, keyed by position.
            let mut params = HashMap::new();
            for (i, param) in self.model.parameters().iter().enumerate() {
                params.insert(format!("param_{}", i), param.data.clone());
            }
            self.optimizer.step(&mut params)?;

            num_batches += 1;
        }

        let avg_loss = total_loss / num_batches as f64;
        self.history.add_training(avg_loss, None);
        Ok(avg_loss)
    }

    /// Evaluates the model and returns the average validation loss.
    pub fn evaluate(&mut self, dataloader: &mut dyn DataLoader) -> Result<f64> {
        self.model.train(false);
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        while let Some((inputs, targets)) = dataloader.next_batch()? {
            let predictions = self.model.forward(&inputs)?;

            let loss = self.loss_fn.forward(&predictions, &targets)?;
            total_loss += loss.data[[0]];

            num_batches += 1;
        }

        let avg_loss = total_loss / num_batches as f64;
        self.history.add_validation(avg_loss, None);
        Ok(avg_loss)
    }

    /// Returns the recorded training history.
    pub fn history(&self) -> &TrainingHistory {
        &self.history
    }
}

/// Iterates over (input, target) batches.
pub trait DataLoader {
    /// Returns the next batch, or `None` when the epoch is exhausted.
    fn next_batch(&mut self) -> Result<Option<(SciRS2Array, SciRS2Array)>>;

    /// Rewinds to the start of the dataset (reshuffling if enabled).
    fn reset(&mut self);

    /// Returns the configured batch size.
    fn batch_size(&self) -> usize;
}

/// An in-memory data loader over pre-loaded tensors, batched along axis 0.
pub struct MemoryDataLoader {
    inputs: SciRS2Array,
    targets: SciRS2Array,
    batch_size: usize,
    current_pos: usize,
    shuffle: bool,
    indices: Vec<usize>,
}

impl MemoryDataLoader {
    /// Creates a loader; `inputs` and `targets` must agree on the sample count.
    pub fn new(
        inputs: SciRS2Array,
        targets: SciRS2Array,
        batch_size: usize,
        shuffle: bool,
    ) -> Result<Self> {
        let num_samples = inputs.data.shape()[0];
        if targets.data.shape()[0] != num_samples {
            return Err(MLError::InvalidConfiguration(
                "Input and target batch sizes don't match".to_string(),
            ));
        }

        let indices: Vec<usize> = (0..num_samples).collect();

        Ok(Self {
            inputs,
            targets,
            batch_size,
            current_pos: 0,
            shuffle,
            indices,
        })
    }

    /// Fisher-Yates shuffle of the sample order (no-op when shuffling is off).
    fn shuffle_indices(&mut self) {
        if self.shuffle {
            for i in (1..self.indices.len()).rev() {
                let j = fastrand::usize(0..=i);
                self.indices.swap(i, j);
            }
        }
    }
}

impl DataLoader for MemoryDataLoader {
    fn next_batch(&mut self) -> Result<Option<(SciRS2Array, SciRS2Array)>> {
        if self.current_pos >= self.indices.len() {
            return Ok(None);
        }

        let end_pos = (self.current_pos + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.current_pos..end_pos];

        // Gather the selected samples along axis 0.
        let batch_inputs = SciRS2Array::new(
            self.inputs.data.select(Axis(0), batch_indices),
            self.inputs.requires_grad,
        );
        let batch_targets = SciRS2Array::new(
            self.targets.data.select(Axis(0), batch_indices),
            self.targets.requires_grad,
        );
        self.current_pos = end_pos;

        Ok(Some((batch_inputs, batch_targets)))
    }

    fn reset(&mut self) {
        self.current_pos = 0;
        self.shuffle_indices();
    }

    fn batch_size(&self) -> usize {
        self.batch_size
    }
}
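
// Batching sketch for the loader above: six samples at batch size 4 yield
// batches of 4 and 2. Relies on the `select`-based slicing in `next_batch`.
#[cfg(test)]
mod memory_dataloader_examples {
    use super::*;

    #[test]
    fn batches_cover_all_samples() {
        let inputs = SciRS2Array::new(ArrayD::zeros(IxDyn(&[6, 3])), false);
        let targets = SciRS2Array::new(ArrayD::zeros(IxDyn(&[6, 1])), false);
        let mut loader =
            MemoryDataLoader::new(inputs, targets, 4, false).expect("loader should build");

        let mut sizes = Vec::new();
        while let Some((batch_x, _batch_y)) = loader.next_batch().expect("batch should succeed") {
            sizes.push(batch_x.data.shape()[0]);
        }
        assert_eq!(sizes, vec![4, 2]);

        loader.reset();
        assert!(loader.next_batch().expect("batch should succeed").is_some());
    }
}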

/// Helpers for building common quantum neural network architectures.
pub mod quantum_nn {
    use super::*;

    /// Builds a feedforward stack: linear + activation per hidden layer, then a
    /// linear output layer.
    pub fn create_feedforward(
        input_size: usize,
        hidden_sizes: &[usize],
        output_size: usize,
        activation: ActivationType,
    ) -> Result<QuantumSequential> {
        let mut model = QuantumSequential::new();

        let mut prev_size = input_size;

        for &hidden_size in hidden_sizes {
            model = model.add(Box::new(
                QuantumLinear::new(prev_size, hidden_size)?.with_bias()?,
            ));
            model = model.add(Box::new(QuantumActivation::new(activation.clone())));
            prev_size = hidden_size;
        }

        model = model.add(Box::new(
            QuantumLinear::new(prev_size, output_size)?.with_bias()?,
        ));

        Ok(model)
    }

    /// Builds a small convolutional network followed by a linear classifier.
    pub fn create_cnn(input_channels: usize, num_classes: usize) -> Result<QuantumSequential> {
        let model = QuantumSequential::new()
            .add(Box::new(
                QuantumConv2d::new(input_channels, 32, (3, 3))?.with_bias()?,
            ))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(QuantumConv2d::new(32, 64, (3, 3))?.with_bias()?))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(QuantumLinear::new(64, num_classes)?.with_bias()?));

        Ok(model)
    }

    /// Initializes parameters according to `init_type`.
    ///
    /// NOTE: `QuantumModule::parameters` returns cloned `Parameter`s, so this
    /// only affects the model itself if `SciRS2Array` shares its storage.
    pub fn init_parameters(model: &mut dyn QuantumModule, init_type: InitType) -> Result<()> {
        for mut param in model.parameters() {
            match init_type {
                InitType::Xavier => {
                    // fan_in: product of all dimensions except the output (first) axis.
                    let fan_in = param.shape().iter().skip(1).product::<usize>() as f64;
                    let fan_out = param.shape()[0] as f64;
                    let bound = (6.0 / (fan_in + fan_out)).sqrt();

                    for elem in param.data.data.iter_mut() {
                        *elem = (fastrand::f64() * 2.0 - 1.0) * bound;
                    }
                }
                InitType::He => {
                    // He (Kaiming) uniform: bound = sqrt(6 / fan_in).
                    let fan_in = param.shape().iter().skip(1).product::<usize>() as f64;
                    let bound = (6.0 / fan_in).sqrt();

                    for elem in param.data.data.iter_mut() {
                        *elem = (fastrand::f64() * 2.0 - 1.0) * bound;
                    }
                }
                InitType::Normal(mean, std) => {
                    // Box-Muller transform for Gaussian samples from uniform draws.
                    for elem in param.data.data.iter_mut() {
                        let u1 = fastrand::f64().max(f64::MIN_POSITIVE);
                        let u2 = fastrand::f64();
                        let z = (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos();
                        *elem = mean + std * z;
                    }
                }
                InitType::Uniform(low, high) => {
                    for elem in param.data.data.iter_mut() {
                        *elem = low + (high - low) * fastrand::f64();
                    }
                }
            }
        }
        Ok(())
    }
}
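
// Architecture sketch for the builder above: a 4-8-2 feedforward network has
// three modules (linear, activation, linear) and four parameter tensors
// (two weight matrices, two bias vectors).
#[cfg(test)]
mod quantum_nn_examples {
    use super::*;

    #[test]
    fn feedforward_has_expected_structure() {
        let model = quantum_nn::create_feedforward(4, &[8], 2, ActivationType::QReLU)
            .expect("model should build");
        assert_eq!(model.len(), 3);
        assert_eq!(model.parameters().len(), 4);
    }
}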

/// Parameter initialization schemes.
#[derive(Debug, Clone, Copy)]
pub enum InitType {
    /// Xavier/Glorot uniform initialization.
    Xavier,
    /// He/Kaiming uniform initialization.
    He,
    /// Gaussian with the given mean and standard deviation.
    Normal(f64, f64),
    /// Uniform over the given `(low, high)` range.
    Uniform(f64, f64),
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_linear() {
        let linear = QuantumLinear::new(4, 2).expect("QuantumLinear creation should succeed");
        assert_eq!(linear.in_features, 4);
        assert_eq!(linear.out_features, 2);
        assert_eq!(linear.parameters().len(), 1);

        let linear_with_bias = linear.with_bias().expect("Adding bias should succeed");
        assert_eq!(linear_with_bias.parameters().len(), 2);
    }

    #[test]
    fn test_quantum_sequential() {
        let model = QuantumSequential::new()
            .add(Box::new(
                QuantumLinear::new(4, 8).expect("QuantumLinear creation should succeed"),
            ))
            .add(Box::new(QuantumActivation::relu()))
            .add(Box::new(
                QuantumLinear::new(8, 2).expect("QuantumLinear creation should succeed"),
            ));

        assert_eq!(model.len(), 3);
        assert!(!model.is_empty());
    }

    #[test]
    fn test_quantum_activation() {
        let mut relu = QuantumActivation::relu();
        let input_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![-1.0, 1.0])
            .expect("Valid shape for input data");
        let input = SciRS2Array::new(input_data, false);

        let output = relu.forward(&input).expect("Forward pass should succeed");
        assert_eq!(output.data[[0]], 0.0);
        assert_eq!(output.data[[1]], 1.0);
    }

    #[test]
    #[ignore]
    fn test_quantum_loss() {
        let mse_loss = QuantumMSELoss;

        let pred_data = ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0])
            .expect("Valid shape for predictions");
        let target_data =
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.5, 1.8]).expect("Valid shape for targets");

        let predictions = SciRS2Array::new(pred_data, false);
        let targets = SciRS2Array::new(target_data, false);

        let loss = mse_loss
            .forward(&predictions, &targets)
            .expect("Loss computation should succeed");
        assert!(loss.data[[0]] > 0.0);
    }

    #[test]
    fn test_parameter() {
        let data = ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0; 6])
            .expect("Valid shape for parameter data");
        let param = Parameter::new(SciRS2Array::new(data, true), "test_param");

        assert_eq!(param.name, "test_param");
        assert!(param.requires_grad);
        assert_eq!(param.shape(), &[2, 3]);
        assert_eq!(param.numel(), 6);
    }

    #[test]
    fn test_training_history() {
        let mut history = TrainingHistory::new();
        history.add_training(0.5, Some(0.8));
        history.add_validation(0.6, Some(0.7));

        assert_eq!(history.losses.len(), 1);
        assert_eq!(history.accuracies.len(), 1);
        assert_eq!(history.val_losses.len(), 1);
        assert_eq!(history.val_accuracies.len(), 1);
    }
}